Skip to content

Commit

Permalink
Add type hints to validator_service
Browse files Browse the repository at this point in the history
  • Loading branch information
DimaStebaev committed Apr 24, 2024
1 parent 0e4f8c4 commit e296c9c
Show file tree
Hide file tree
Showing 9 changed files with 133 additions and 71 deletions.
3 changes: 1 addition & 2 deletions skale/contracts/base_contract.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,8 +138,7 @@ def wrapper(
method=method_name
)

should_wait = tx_hash is not None and wait_for
if should_wait:
if tx_hash is not None and wait_for:
receipt = self.skale.wallet.wait(tx_hash)

should_confirm = receipt is not None and confirmation_blocks > 0
Expand Down
145 changes: 94 additions & 51 deletions skale/contracts/manager/delegation/validator_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,16 @@
# You should have received a copy of the GNU Affero General Public License
# along with SKALE.py. If not, see <https://www.gnu.org/licenses/>.

from typing import Any, Dict, List
from eth_typing import ChecksumAddress
from web3 import Web3
from web3.contract.contract import ContractFunction
from web3.types import Wei

from skale.contracts.base_contract import BaseContract, transaction_method
from skale.types.validator import Validator, ValidatorId, ValidatorWithId
from skale.utils.helper import format_fields

from skale.transactions.result import TxRes


FIELDS = [
'name', 'validator_address', 'requested_address', 'description', 'fee_rate',
Expand All @@ -35,16 +38,16 @@
class ValidatorService(BaseContract):
"""Wrapper for ValidatorService.sol functions"""

def __get_raw(self, _id) -> list:
def __get_raw(self, _id: ValidatorId) -> List[Any]:
"""Returns raw validator info.
:returns: Raw validator info
:rtype: list
"""
return self.contract.functions.validators(_id).call()
return list(self.contract.functions.validators(_id).call())

@format_fields(FIELDS)
def get(self, _id) -> list:
def untyped_get(self, _id: ValidatorId) -> List[Any]:
"""Returns validator info.
:returns: Validator info
Expand All @@ -55,25 +58,36 @@ def get(self, _id) -> list:
validator.append(trusted)
return validator

def get_with_id(self, _id) -> dict:
def get(self, _id: ValidatorId) -> Validator:
untyped_validator = self.untyped_get(_id)
if untyped_validator is None:
raise ValueError('Validator with id ', _id, ' is missing')
if isinstance(untyped_validator, dict):
return self._to_validator(untyped_validator)
if isinstance(untyped_validator, list):
return self._to_validator(untyped_validator[0])
raise TypeError(_id)

def get_with_id(self, _id: ValidatorId) -> ValidatorWithId:
"""Returns validator info with ID.
:returns: Validator info with ID
:rtype: dict
"""
validator = self.get(_id)
validator['id'] = _id
return validator
return ValidatorWithId({'id': _id, **validator})
# validator['id'] = _id
# return validator

def number_of_validators(self):
def number_of_validators(self) -> int:
"""Returns number of registered validators.
:returns: List of validators
:rtype: int
"""
return self.contract.functions.numberOfValidators().call()
return int(self.contract.functions.numberOfValidators().call())

def ls(self, trusted_only=False):
def ls(self, trusted_only: bool = False) -> List[ValidatorWithId]:
"""Returns list of registered validators.
:returns: List of validators
Expand All @@ -84,30 +98,42 @@ def ls(self, trusted_only=False):
self.get_with_id(val_id)
for val_id in self.get_trusted_validator_ids()
] if trusted_only else [
self.get_with_id(val_id)
self.get_with_id(ValidatorId(val_id))
for val_id in range(1, number_of_validators + 1)
]
return validators

def get_linked_addresses_by_validator_address(self, address: str) -> list:
def get_linked_addresses_by_validator_address(
self,
address: ChecksumAddress
) -> List[ChecksumAddress]:
"""Returns list of node addresses linked to the validator address.
:returns: List of node addresses
:rtype: list
"""
return self.contract.functions.getMyNodesAddresses().call({
'from': address
})
return [
Web3.to_checksum_address(address)
for address
in self.contract.functions.getMyNodesAddresses().call({'from': address})
]

def get_linked_addresses_by_validator_id(self, validator_id: int) -> list:
def get_linked_addresses_by_validator_id(
self,
validator_id: ValidatorId
) -> List[ChecksumAddress]:
"""Returns list of node addresses linked to the validator ID.
:returns: List of node addresses
:rtype: list
"""
return self.contract.functions.getNodeAddresses(validator_id).call()
return [
Web3.to_checksum_address(address)
for address
in self.contract.functions.getNodeAddresses(validator_id).call()
]

def is_main_address(self, validator_address: str) -> bool:
def is_main_address(self, validator_address: ChecksumAddress) -> bool:
"""Checks if provided address is the main validator address
:returns: True if provided address is the main validator address, otherwise False
Expand All @@ -125,59 +151,63 @@ def is_main_address(self, validator_address: str) -> bool:

return validator_address == validator['validator_address']

def validator_address_exists(self, validator_address: str) -> bool:
def validator_address_exists(self, validator_address: ChecksumAddress) -> bool:
"""Checks if there is a validator with provided address
:returns: True if validator exists, otherwise False
:rtype: bool
"""
return self.contract.functions.validatorAddressExists(validator_address).call()
return bool(self.contract.functions.validatorAddressExists(validator_address).call())

def validator_exists(self, validator_id: str) -> bool:
def validator_exists(self, validator_id: ValidatorId) -> bool:
"""Checks if there is a validator with provided ID
:returns: True if validator exists, otherwise False
:rtype: bool
"""
return self.contract.functions.validatorExists(validator_id).call()
return bool(self.contract.functions.validatorExists(validator_id).call())

def validator_id_by_address(self, validator_address: str) -> int:
def validator_id_by_address(self, validator_address: ChecksumAddress) -> ValidatorId:
"""Returns validator ID by validator address
:returns: Validator ID
:rtype: int
"""
return self.contract.functions.getValidatorId(validator_address).call()
return ValidatorId(self.contract.functions.getValidatorId(validator_address).call())

def get_trusted_validator_ids(self) -> list:
def get_trusted_validator_ids(self) -> List[ValidatorId]:
"""Returns list of trusted validators id.
:returns: List of trusted validators id
:rtype: list
"""
return self.contract.functions.getTrustedValidators().call()
return [
ValidatorId(id)
for id
in self.contract.functions.getTrustedValidators().call()
]

@transaction_method
def _enable_validator(self, validator_id: int) -> TxRes:
def _enable_validator(self, validator_id: ValidatorId) -> ContractFunction:
"""For internal usage only"""
return self.contract.functions.enableValidator(validator_id)

@transaction_method
def _disable_validator(self, validator_id: int) -> TxRes:
def _disable_validator(self, validator_id: ValidatorId) -> ContractFunction:
"""For internal usage only"""
return self.contract.functions.disableValidator(validator_id)

def _is_authorized_validator(self, validator_id: int) -> bool:
def _is_authorized_validator(self, validator_id: ValidatorId) -> bool:
"""For internal usage only"""
return self.contract.functions.isAuthorizedValidator(validator_id).call()
return bool(self.contract.functions.isAuthorizedValidator(validator_id).call())

def is_accepting_new_requests(self, validator_id: int) -> bool:
def is_accepting_new_requests(self, validator_id: ValidatorId) -> bool:
"""For internal usage only"""
return self.contract.functions.isAcceptingNewRequests(validator_id).call()
return bool(self.contract.functions.isAcceptingNewRequests(validator_id).call())

@transaction_method
def register_validator(self, name: str, description: str, fee_rate: int,
min_delegation_amount: int) -> TxRes:
min_delegation_amount: int) -> ContractFunction:
"""Registers a new validator in the SKALE Manager contracts.
:param name: Validator name
Expand All @@ -194,13 +224,13 @@ def register_validator(self, name: str, description: str, fee_rate: int,
return self.contract.functions.registerValidator(
name, description, fee_rate, min_delegation_amount)

def get_link_node_signature(self, validator_id: int) -> str:
def get_link_node_signature(self, validator_id: ValidatorId) -> str:
unsigned_hash = Web3.solidity_keccak(['uint256'], [validator_id])
signed_hash = self.skale.wallet.sign_hash(unsigned_hash.hex())
return signed_hash.signature.hex()

@transaction_method
def link_node_address(self, node_address: str, signature: str) -> TxRes:
def link_node_address(self, node_address: ChecksumAddress, signature: str) -> ContractFunction:
"""Link node address to your validator account.
:param node_address: Address of the node to link
Expand All @@ -213,7 +243,7 @@ def link_node_address(self, node_address: str, signature: str) -> TxRes:
return self.contract.functions.linkNodeAddress(node_address, signature)

@transaction_method
def unlink_node_address(self, node_address: str) -> TxRes:
def unlink_node_address(self, node_address: ChecksumAddress) -> ContractFunction:
"""Unlink node address from your validator account.
:param node_address: Address of the node to unlink
Expand All @@ -224,7 +254,7 @@ def unlink_node_address(self, node_address: str) -> TxRes:
return self.contract.functions.unlinkNodeAddress(node_address)

@transaction_method
def disable_whitelist(self) -> TxRes:
def disable_whitelist(self) -> ContractFunction:
""" Disable validator whitelist. Master key only transaction.
:returns: Transaction results
:rtype: TxRes
Expand All @@ -236,18 +266,18 @@ def get_use_whitelist(self) -> bool:
:returns: useWhitelist value
:rtype: bool
"""
return self.contract.functions.useWhitelist().call()
return bool(self.contract.functions.useWhitelist().call())

def get_and_update_bond_amount(self, validator_id: int) -> int:
def get_and_update_bond_amount(self, validator_id: ValidatorId) -> int:
"""Return amount of token that validator delegated to himself
:param validator_id: id of the validator
:returns:
:rtype: int
"""
return self.contract.functions.getAndUpdateBondAmount(validator_id).call()
return int(self.contract.functions.getAndUpdateBondAmount(validator_id).call())

@transaction_method
def set_validator_mda(self, minimum_delegation_amount: int) -> TxRes:
def set_validator_mda(self, minimum_delegation_amount: Wei) -> ContractFunction:
""" Allows a validator to set the minimum delegation amount.
:param new_minimum_delegation_amount: Minimum delegation amount
Expand All @@ -258,7 +288,7 @@ def set_validator_mda(self, minimum_delegation_amount: int) -> TxRes:
return self.contract.functions.setValidatorMDA(minimum_delegation_amount)

@transaction_method
def request_for_new_address(self, new_validator_address: str) -> TxRes:
def request_for_new_address(self, new_validator_address: ChecksumAddress) -> ContractFunction:
""" Allows a validator to request a new address.
:param new_validator_address: New validator address
Expand All @@ -269,7 +299,7 @@ def request_for_new_address(self, new_validator_address: str) -> TxRes:
return self.contract.functions.requestForNewAddress(new_validator_address)

@transaction_method
def confirm_new_address(self, validator_id: int) -> TxRes:
def confirm_new_address(self, validator_id: ValidatorId) -> ContractFunction:
""" Confirm change of the address.
:param validator_id: ID of the validator
Expand All @@ -280,7 +310,7 @@ def confirm_new_address(self, validator_id: int) -> TxRes:
return self.contract.functions.confirmNewAddress(validator_id)

@transaction_method
def set_validator_name(self, new_name: str) -> TxRes:
def set_validator_name(self, new_name: str) -> ContractFunction:
""" Allows a validator to change the name.
:param new_name: New validator name
Expand All @@ -291,7 +321,7 @@ def set_validator_name(self, new_name: str) -> TxRes:
return self.contract.functions.setValidatorName(new_name)

@transaction_method
def set_validator_description(self, new_description: str) -> TxRes:
def set_validator_description(self, new_description: str) -> ContractFunction:
""" Allows a validator to change the name.
:param new_description: New validator description
Expand All @@ -302,11 +332,24 @@ def set_validator_description(self, new_description: str) -> TxRes:
return self.contract.functions.setValidatorDescription(new_description)

@transaction_method
def grant_role(self, role: bytes, address: str) -> TxRes:
def grant_role(self, role: bytes, address: ChecksumAddress) -> ContractFunction:
return self.contract.functions.grantRole(role, address)

def validator_manager_role(self) -> bytes:
return self.contract.functions.VALIDATOR_MANAGER_ROLE().call()

def has_role(self, role: bytes, address: str) -> bool:
return self.contract.functions.hasRole(role, address).call()
return bytes(self.contract.functions.VALIDATOR_MANAGER_ROLE().call())

def has_role(self, role: bytes, address: ChecksumAddress) -> bool:
return bool(self.contract.functions.hasRole(role, address).call())

def _to_validator(self, untyped_validator: Dict[str, Any]) -> Validator:
return Validator({
'name': str(untyped_validator['name']),
'validator_address': ChecksumAddress(untyped_validator['validator_address']),
'requested_address': ChecksumAddress(untyped_validator['requested_address']),
'description': str(untyped_validator['description']),
'fee_rate': int(untyped_validator['fee_rate']),
'registration_time': int(untyped_validator['registration_time']),
'minimum_delegation_amount': Wei(untyped_validator['minimum_delegation_amount']),
'accept_new_requests': bool(untyped_validator['accept_new_requests']),
'trusted': bool(untyped_validator['trusted'])
})
4 changes: 2 additions & 2 deletions skale/skale_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@
from skale_contracts import skale_contracts

from skale.contracts.base_contract import BaseContract
from skale.wallets import BaseWallet
from skale.utils.exceptions import InvalidWalletError, EmptyWalletError
from skale.utils.web3_utils import default_gas_price, init_web3
from skale.wallets import BaseWallet

from skale.contracts.contract_manager import ContractManager

Expand Down Expand Up @@ -75,7 +75,7 @@ def gas_price(self):
return default_gas_price(self.web3)

@property
def wallet(self):
def wallet(self) -> BaseWallet:
if not self._wallet:
raise EmptyWalletError('No wallet provided')
return self._wallet
Expand Down
21 changes: 20 additions & 1 deletion skale/types/validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,26 @@
# You should have received a copy of the GNU Affero General Public License
# along with SKALE.py. If not, see <https://www.gnu.org/licenses/>.

from typing import NewType
from typing import NewType, TypedDict

from eth_typing import ChecksumAddress
from web3.types import Wei


ValidatorId = NewType('ValidatorId', int)


class Validator(TypedDict):
name: str
validator_address: ChecksumAddress
requested_address: ChecksumAddress
description: str
fee_rate: int
registration_time: int
minimum_delegation_amount: Wei
accept_new_requests: bool
trusted: bool


class ValidatorWithId(Validator):
id: ValidatorId
Loading

0 comments on commit e296c9c

Please sign in to comment.