Skip to content

Commit

Permalink
Add type hints to nodes
Browse files Browse the repository at this point in the history
  • Loading branch information
DimaStebaev committed Apr 24, 2024
1 parent 84d8059 commit 48eabb1
Show file tree
Hide file tree
Showing 2 changed files with 109 additions and 57 deletions.
143 changes: 98 additions & 45 deletions skale/contracts/manager/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,17 @@

import socket
from enum import IntEnum
from typing import Any, Dict, List, Tuple, TypedDict, cast

from Crypto.Hash import keccak
from eth_typing import BlockNumber, ChecksumAddress
from web3.contract.contract import ContractFunction
from web3.exceptions import BadFunctionCallOutput, ContractLogicError

from skale.contracts.base_contract import BaseContract, transaction_method
from skale.transactions.result import TxRes

from skale.types.node import NodeId, Port
from skale.types.validator import ValidatorId
from skale.utils.exceptions import InvalidNodeIdError
from skale.utils.helper import format_fields

Expand All @@ -43,134 +47,183 @@ class NodeStatus(IntEnum):
IN_MAINTENANCE = 3


class Node(TypedDict):
name: str
ip: bytes
publicIP: bytes
port: Port
start_block: BlockNumber
last_reward_date: int
finish_time: int
status: NodeStatus
validator_id: ValidatorId
publicKey: str
domain_name: str


class Nodes(BaseContract):
def __get_raw(self, node_id):
def __get_raw(self, node_id: NodeId) -> List[Any]:
try:
return self.contract.functions.nodes(node_id).call()
return list(self.contract.functions.nodes(node_id).call())
except (ContractLogicError, ValueError, BadFunctionCallOutput):
raise InvalidNodeIdError(node_id)

def __get_raw_w_pk(self, node_id):
def __get_raw_w_pk(self, node_id: NodeId) -> List[Any]:
raw_node_struct = self.__get_raw(node_id)
raw_node_struct.append(self.get_node_public_key(node_id))
return raw_node_struct

def __get_raw_w_pk_w_domain(self, node_id):
def __get_raw_w_pk_w_domain(self, node_id: NodeId) -> List[Any]:
raw_node_struct_w_pk = self.__get_raw_w_pk(node_id)
raw_node_struct_w_pk.append(self.get_domain_name(node_id))
return raw_node_struct_w_pk

@format_fields(FIELDS)
def get(self, node_id):
def untyped_get(self, node_id: NodeId) -> List[Any]:
return self.__get_raw_w_pk_w_domain(node_id)

def get(self, node_id: NodeId) -> Node:
node = self.untyped_get(node_id)
if node is None:
raise ValueError('Node with id ', node_id, ' is not found')
if isinstance(node, dict):
return self._to_node(node)
if isinstance(node, list):
return self._to_node(node[0])
raise ValueError("Can't process returned node type")

@format_fields(FIELDS)
def get_by_name(self, name):
def get_by_name(self, name: str) -> List[Any]:
name_hash = self.name_to_id(name)
_id = self.contract.functions.nodesNameToIndex(name_hash).call()
return self.__get_raw_w_pk_w_domain(_id)

def get_nodes_number(self):
return self.contract.functions.getNumberOfNodes().call()
def get_nodes_number(self) -> int:
return int(self.contract.functions.getNumberOfNodes().call())

def get_active_node_ids(self):
def get_active_node_ids(self) -> List[NodeId]:
nodes_number = self.get_nodes_number()
return [
node_id
NodeId(node_id)
for node_id in range(0, nodes_number)
if self.get_node_status(node_id) == NodeStatus.ACTIVE
if self.get_node_status(NodeId(node_id)) == NodeStatus.ACTIVE
]

def get_active_node_ips(self):
def get_active_node_ips(self) -> List[str]:
nodes_number = self.get_nodes_number()
return [
self.get(node_id)['ip']
self.get(NodeId(node_id))['ip']
for node_id in range(0, nodes_number)
if self.get_node_status(node_id) == NodeStatus.ACTIVE
if self.get_node_status(NodeId(node_id)) == NodeStatus.ACTIVE
]

def name_to_id(self, name):
def name_to_id(self, name: str) -> bytes:
keccak_hash = keccak.new(data=name.encode("utf8"), digest_bits=256)
return keccak_hash.digest()

def is_node_name_available(self, name):
def is_node_name_available(self, name: str) -> bool:
node_id = self.name_to_id(name)
return not self.contract.functions.nodesNameCheck(node_id).call()

def is_node_ip_available(self, ip):
def is_node_ip_available(self, ip: str) -> bool:
ip_bytes = socket.inet_aton(ip)
return not self.contract.functions.nodesIPCheck(ip_bytes).call()

def node_name_to_index(self, name):
def node_name_to_index(self, name: str) -> int:
name_hash = self.name_to_id(name)
return self.contract.functions.nodesNameToIndex(name_hash).call()
return int(self.contract.functions.nodesNameToIndex(name_hash).call())

def get_node_status(self, node_id):
def get_node_status(self, node_id: NodeId) -> NodeStatus:
try:
return self.contract.functions.getNodeStatus(node_id).call()
return NodeStatus(self.contract.functions.getNodeStatus(node_id).call())
except (ContractLogicError, ValueError, BadFunctionCallOutput):
raise InvalidNodeIdError(node_id)

def get_node_finish_time(self, node_id):
def get_node_finish_time(self, node_id: NodeId) -> int:
try:
return self.contract.functions.getNodeFinishTime(node_id).call()
return int(self.contract.functions.getNodeFinishTime(node_id).call())
except (ContractLogicError, ValueError, BadFunctionCallOutput):
raise InvalidNodeIdError(node_id)

def __get_node_public_key_raw(self, node_id):
def __get_node_public_key_raw(self, node_id: NodeId) -> Tuple[bytes, bytes]:
try:
return self.contract.functions.getNodePublicKey(node_id).call()
return cast(
Tuple[bytes, bytes],
self.contract.functions.getNodePublicKey(node_id).call()
)
except (ContractLogicError, ValueError, BadFunctionCallOutput):
raise InvalidNodeIdError(node_id)

def get_node_public_key(self, node_id):
def get_node_public_key(self, node_id: NodeId) -> str:
raw_key = self.__get_node_public_key_raw(node_id)
key_bytes = raw_key[0] + raw_key[1]
return self.skale.web3.to_hex(key_bytes)

def get_validator_node_indices(self, validator_id: int) -> list:
def get_validator_node_indices(self, validator_id: int) -> list[NodeId]:
"""Returns list of node indices to the validator
:returns: List of trusted node indices
:rtype: list
"""
return self.contract.functions.getValidatorNodeIndexes(validator_id).call()
return [
NodeId(id)
for id
in self.contract.functions.getValidatorNodeIndexes(validator_id).call()
]

def get_last_change_ip_time(self, node_id: int) -> list:
return self.contract.functions.getLastChangeIpTime(node_id).call()
def get_last_change_ip_time(self, node_id: NodeId) -> int:
return int(self.contract.functions.getLastChangeIpTime(node_id).call())

@transaction_method
def set_node_in_maintenance(self, node_id):
def set_node_in_maintenance(self, node_id: NodeId) -> ContractFunction:
return self.contract.functions.setNodeInMaintenance(node_id)

@transaction_method
def remove_node_from_in_maintenance(self, node_id):
def remove_node_from_in_maintenance(self, node_id: NodeId) -> ContractFunction:
return self.contract.functions.removeNodeFromInMaintenance(node_id)

@transaction_method
def set_domain_name(self, node_id: int, domain_name: str):
def set_domain_name(self, node_id: NodeId, domain_name: str) -> ContractFunction:
return self.contract.functions.setDomainName(node_id, domain_name)

def get_domain_name(self, node_id: int):
return self.contract.functions.getNodeDomainName(node_id).call()
def get_domain_name(self, node_id: NodeId) -> str:
return str(self.contract.functions.getNodeDomainName(node_id).call())

@transaction_method
def grant_role(self, role: bytes, owner: str) -> TxRes:
def grant_role(self, role: bytes, owner: ChecksumAddress) -> ContractFunction:
return self.contract.functions.grantRole(role, owner)

def has_role(self, role: bytes, address: str) -> bool:
return self.contract.functions.hasRole(role, address).call()
def has_role(self, role: bytes, address: ChecksumAddress) -> bool:
return bool(self.contract.functions.hasRole(role, address).call())

def node_manager_role(self):
return self.contract.functions.NODE_MANAGER_ROLE().call()
def node_manager_role(self) -> bytes:
return bytes(self.contract.functions.NODE_MANAGER_ROLE().call())

def compliance_role(self):
return self.contract.functions.COMPLIANCE_ROLE().call()
def compliance_role(self) -> bytes:
return bytes(self.contract.functions.COMPLIANCE_ROLE().call())

@transaction_method
def init_exit(self, node_id: int) -> TxRes:
def init_exit(self, node_id: NodeId) -> ContractFunction:
return self.contract.functions.initExit(node_id)

@transaction_method
def change_ip(self, node_id: int, ip: bytes, public_ip: bytes) -> TxRes:
def change_ip(self, node_id: NodeId, ip: bytes, public_ip: bytes) -> ContractFunction:
return self.contract.functions.changeIP(node_id, ip, public_ip)

def _to_node(self, untyped_node: Dict[str, Any]) -> Node:
for key in Node.__annotations__:
if key not in untyped_node:
raise ValueError(f"Key: {key} is not available in node.")
return Node({
'name': str(untyped_node['name']),
'ip': bytes(untyped_node['ip']),
'publicIP': bytes(untyped_node['publicIP']),
'port': Port(untyped_node['port']),
'start_block': BlockNumber(untyped_node['start_block']),
'last_reward_date': int(untyped_node['last_reward_date']),
'finish_time': int(untyped_node['finish_time']),
'status': NodeStatus(untyped_node['status']),
'validator_id': ValidatorId(untyped_node['validator_id']),
'publicKey': str(untyped_node['publicKey']),
'domain_name': str(untyped_node['domain_name']),
})
23 changes: 11 additions & 12 deletions skale/utils/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
import sys
from logging import Formatter, StreamHandler
from random import randint
from typing import TYPE_CHECKING, Any, Callable, Generator, cast
from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, List, cast

from skale.config import ENV
from skale.types.node import Port
Expand All @@ -45,18 +45,20 @@ def decapitalize(s: str) -> str:
return s[:1].lower() + s[1:] if s else ''


WrapperReturnType = Dict[str, Any] | List[Dict[str, Any]] | None


def format_fields(
fields: list[str],
flist: bool = False
) -> Callable[
[Callable[
...,
[
Callable[
...,
dict[str, Any] | list[dict[str, Any]] | Any | None
List[Any]
]
]],
Callable[..., dict[str, Any] | list[dict[str, Any]] | Any | None]
],
Callable[..., WrapperReturnType]
]:
"""
Transform array to object with passed fields
Expand All @@ -71,16 +73,13 @@ def my_method()
def real_decorator(
function: Callable[
...,
Callable[
...,
dict[str, Any] | list[dict[str, Any]] | Any | None
]
List[Any]
]
) -> Callable[..., dict[str, Any] | list[dict[str, Any]] | Any | None]:
) -> Callable[..., WrapperReturnType]:
def wrapper(
*args: Any,
**kwargs: Any
) -> dict[str, Any] | list[dict[str, Any]] | Any | None:
) -> WrapperReturnType:
result = function(*args, **kwargs)

if result is None:
Expand Down

0 comments on commit 48eabb1

Please sign in to comment.