Skip to content

Commit

Permalink
Add type hints to node_rotation
Browse files Browse the repository at this point in the history
  • Loading branch information
DimaStebaev committed Apr 24, 2024
1 parent 48eabb1 commit 15a528c
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 49 deletions.
80 changes: 42 additions & 38 deletions skale/contracts/manager/node_rotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,24 @@
# along with SKALE.py. If not, see <https://www.gnu.org/licenses/>.
""" NodeRotation.sol functions """

from __future__ import annotations
import logging
import functools
import warnings
from typing import TYPE_CHECKING, List, TypedDict, cast
from dataclasses import dataclass

from eth_typing import ChecksumAddress

from skale.contracts.base_contract import BaseContract, transaction_method
from skale.transactions.result import TxRes
from web3.contract.contract import ContractFunction
from web3.exceptions import ContractLogicError

from skale.types.node import NodeId
from skale.types.schain import SchainHash, SchainName

if TYPE_CHECKING:
from skale.contracts.manager.schains import SChains


logger = logging.getLogger(__name__)

Expand All @@ -42,67 +51,62 @@ class Rotation:
rotation_counter: int


class RotationSwap(TypedDict):
schain_id: SchainHash
finished_rotation: int


class NodeRotation(BaseContract):
"""Wrapper for NodeRotation.sol functions"""

@property
@functools.lru_cache()
def schains(self):
return self.skale.schains
def schains(self) -> SChains:
from skale.contracts.manager.schains import SChains
return cast(SChains, self.skale.schains)

def get_rotation_obj(self, schain_name) -> Rotation:
def get_rotation(self, schain_name: SchainName) -> Rotation:
schain_id = self.schains.name_to_id(schain_name)
rotation_data = self.contract.functions.getRotation(schain_id).call()
return Rotation(*rotation_data)

def get_rotation(self, schain_name):
warnings.warn('Deprecated, will be removed in v6', DeprecationWarning)
schain_id = self.schains.name_to_id(schain_name)
rotation_data = self.contract.functions.getRotation(schain_id).call()
return {
'leaving_node': rotation_data[0],
'new_node': rotation_data[1],
'freeze_until': rotation_data[2],
'rotation_id': rotation_data[3]
}

def get_leaving_history(self, node_id):
def get_leaving_history(self, node_id: NodeId) -> List[RotationSwap]:
raw_history = self.contract.functions.getLeavingHistory(node_id).call()
history = [
{
'schain_id': schain[0],
'finished_rotation': schain[1]
}
RotationSwap({
'schain_id': SchainHash(schain[0]),
'finished_rotation': int(schain[1])
})
for schain in raw_history
]
return history

def get_schain_finish_ts(self, node_id: int, schain_name: str) -> int | None:
def get_schain_finish_ts(self, node_id: NodeId, schain_name: SchainName) -> int | None:
raw_history = self.contract.functions.getLeavingHistory(node_id).call()
schain_id = self.skale.schains.name_to_id(schain_name)
finish_ts = next(
(schain[1] for schain in raw_history if '0x' + schain[0].hex() == schain_id), None)
if not finish_ts:
return None
return finish_ts
return int(finish_ts)

def is_rotation_in_progress(self, schain_name) -> bool:
def is_rotation_in_progress(self, schain_name: SchainName) -> bool:
schain_id = self.schains.name_to_id(schain_name)
return self.contract.functions.isRotationInProgress(schain_id).call()
return bool(self.contract.functions.isRotationInProgress(schain_id).call())

def is_new_node_found(self, schain_name) -> bool:
def is_new_node_found(self, schain_name: SchainName) -> bool:
schain_id = self.schains.name_to_id(schain_name)
return self.contract.functions.isNewNodeFound(schain_id).call()
return bool(self.contract.functions.isNewNodeFound(schain_id).call())

def is_rotation_active(self, schain_name) -> bool:
def is_rotation_active(self, schain_name: SchainName) -> bool:
"""
The public function that tells whether rotation is in the active phase - the new group is
already generated
"""
finish_ts_reached = self.is_finish_ts_reached(schain_name)
return self.is_rotation_in_progress(schain_name) and not finish_ts_reached

def is_finish_ts_reached(self, schain_name) -> bool:
def is_finish_ts_reached(self, schain_name: SchainName) -> bool:
rotation = self.skale.node_rotation.get_rotation_obj(schain_name)
schain_finish_ts = self.get_schain_finish_ts(rotation.leaving_node_id, schain_name)

Expand All @@ -115,24 +119,24 @@ def is_finish_ts_reached(self, schain_name) -> bool:
logger.info(f'current_ts: {current_ts}, schain_finish_ts: {schain_finish_ts}')
return current_ts > schain_finish_ts

def wait_for_new_node(self, schain_name):
def wait_for_new_node(self, schain_name: SchainName) -> bool:
schain_id = self.schains.name_to_id(schain_name)
return self.contract.functions.waitForNewNode(schain_id).call()
return bool(self.contract.functions.waitForNewNode(schain_id).call())

@transaction_method
def grant_role(self, role: bytes, owner: str) -> TxRes:
def grant_role(self, role: bytes, owner: str) -> ContractFunction:
return self.contract.functions.grantRole(role, owner)

def has_role(self, role: bytes, address: str) -> bool:
return self.contract.functions.hasRole(role, address).call()
def has_role(self, role: bytes, address: ChecksumAddress) -> bool:
return bool(self.contract.functions.hasRole(role, address).call())

def debugger_role(self):
return self.contract.functions.DEBUGGER_ROLE().call()
def debugger_role(self) -> bytes:
return bytes(self.contract.functions.DEBUGGER_ROLE().call())

def get_previous_node(self, schain_name: str, node_id: int) -> int | None:
def get_previous_node(self, schain_name: SchainName, node_id: NodeId) -> NodeId | None:
schain_id = self.schains.name_to_id(schain_name)
try:
return self.contract.functions.getPreviousNode(schain_id, node_id).call()
return NodeId(self.contract.functions.getPreviousNode(schain_id, node_id).call())
except (ContractLogicError, ValueError) as e:
if NO_PREVIOUS_NODE_EXCEPTION_TEXT in str(e):
return None
Expand Down
2 changes: 1 addition & 1 deletion skale/contracts/manager/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def get_active_node_ids(self) -> List[NodeId]:
if self.get_node_status(NodeId(node_id)) == NodeStatus.ACTIVE
]

def get_active_node_ips(self) -> List[str]:
def get_active_node_ips(self) -> List[bytes]:
nodes_number = self.get_nodes_number()
return [
self.get(NodeId(node_id))['ip']
Expand Down
11 changes: 1 addition & 10 deletions tests/manager/node_rotation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,7 @@


def test_get_rotation(skale):
assert skale.node_rotation.get_rotation(DEFAULT_SCHAIN_NAME) == {
'leaving_node': 0,
'new_node': 0,
'freeze_until': 0,
'rotation_id': 0
}


def test_get_rotation_obj(skale):
assert skale.node_rotation.get_rotation_obj(DEFAULT_SCHAIN_NAME) == Rotation(
assert skale.node_rotation.get_rotation(DEFAULT_SCHAIN_NAME) == Rotation(
leaving_node_id=0,
new_node_id=0,
freeze_until=0,
Expand Down

0 comments on commit 15a528c

Please sign in to comment.