Skip to content

Commit

Permalink
Merge pull request #358 from lidofinance/feat/health-checker
Browse files Browse the repository at this point in the history
feat: add endpoints health checker
  • Loading branch information
F4ever authored Apr 13, 2023
2 parents 2cebd83 + d162e95 commit 8842852
Show file tree
Hide file tree
Showing 9 changed files with 105 additions and 18 deletions.
32 changes: 16 additions & 16 deletions src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from typing import cast

from prometheus_client import start_http_server
from web3_multi_provider import FallbackProvider
from web3.middleware import simple_cache_middleware

from src import variables
Expand All @@ -20,6 +19,7 @@
ConsensusClientModule,
KeysAPIClientModule,
LidoValidatorsProvider,
FallbackProviderModule
)
from src.web3py.middleware import metrics_collector
from src.web3py.typings import Web3
Expand Down Expand Up @@ -57,7 +57,7 @@ def main(module: OracleModule):
start_http_server(variables.PROMETHEUS_PORT)

logger.info({'msg': 'Initialize multi web3 provider.'})
web3 = Web3(FallbackProvider(variables.EXECUTION_CLIENT_URI))
web3 = Web3(FallbackProviderModule(variables.EXECUTION_CLIENT_URI))

logger.info({'msg': 'Modify web3 with custom contract function call.'})
tweak_w3_contracts(web3)
Expand All @@ -68,6 +68,8 @@ def main(module: OracleModule):
logger.info({'msg': 'Initialize keys api client.'})
kac = KeysAPIClientModule(variables.KEYS_API_URI, web3)

check_providers_chain_ids(web3, cc, kac)

web3.attach_modules({
'lido_contracts': LidoContracts,
'lido_validators': LidoValidatorsProvider,
Expand All @@ -81,7 +83,6 @@ def main(module: OracleModule):
web3.middleware_onion.add(simple_cache_middleware)

logger.info({'msg': 'Sanity checks.'})
check_providers_chain_ids(web3)

if module == OracleModule.ACCOUNTING:
logger.info({'msg': 'Initialize Accounting module.'})
Expand All @@ -101,19 +102,18 @@ def check():
return ChecksModule().execute_module()


def check_providers_chain_ids(web3: Web3):
execution_chain_id = web3.eth.chain_id
consensus_chain_id = int(web3.cc.get_config_spec().DEPOSIT_CHAIN_ID)
chain_ids = [
Web3.to_int(hexstr=provider.make_request("eth_chainId", []).get('result'))
for provider in cast(FallbackProvider, web3.provider)._providers # type: ignore[attr-defined] # pylint: disable=protected-access
]
keys_api_chain_id = web3.kac.get_status().chainId
if any(execution_chain_id != chain_id for chain_id in [*chain_ids, consensus_chain_id, keys_api_chain_id]):
raise ValueError('Different chain ids detected:\n'
f'Execution chain ids: {", ".join(map(str, chain_ids))}\n'
f'Consensus chain id: {consensus_chain_id}\n'
f'Keys API chain id: {keys_api_chain_id}\n')
def check_providers_chain_ids(web3: Web3, cc: ConsensusClientModule, kac: KeysAPIClientModule):
keys_api_chain_id = kac.check_providers_consistency()
consensus_chain_id = cc.check_providers_consistency()
execution_chain_id = cast(FallbackProviderModule, web3.provider).check_providers_consistency()

if execution_chain_id == consensus_chain_id == keys_api_chain_id:
return

raise ValueError('Different chain ids detected:\n'
f'Execution chain id: {execution_chain_id}\n'
f'Consensus chain id: {consensus_chain_id}\n'
f'Keys API chain id: {keys_api_chain_id}\n')


if __name__ == '__main__':
Expand Down
2 changes: 1 addition & 1 deletion src/modules/checks/suites/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def ejector(web3, skip_locator):

def check_providers_chain_ids(web3):
"""Make sure all providers are on the same chain"""
chain_ids_check(web3)
chain_ids_check(web3, web3.cc, web3.kac)


def check_accounting_contract_configs(accounting):
Expand Down
6 changes: 6 additions & 0 deletions src/providers/consensus/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,3 +156,9 @@ def __raise_last_missed_slot_error(self, errors: list[Exception]) -> Exception |
return error

return None

def _get_chain_id_with_provider(self, provider_index: int) -> int:
data, _ = self._get_without_fallbacks(self.hosts[provider_index], self.API_GET_SPEC)
if not isinstance(data, dict):
raise ValueError("Expected mapping response from getSpec")
return int(BeaconSpecResponse.from_response(**data).DEPOSIT_CHAIN_ID)
11 changes: 10 additions & 1 deletion src/providers/http_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
from requests.exceptions import ConnectionError as RequestsConnectionError
from urllib3 import Retry

from src.web3py.extensions.consistency import ProviderConsistencyModule


logger = logging.getLogger(__name__)


Expand All @@ -27,7 +30,7 @@ def __init__(self, *args, status: int, text: str):
super().__init__(*args)


class HTTPProvider(ABC):
class HTTPProvider(ProviderConsistencyModule, ABC):
"""
Base HTTP Provider with metrics and retry strategy integrated inside.
"""
Expand Down Expand Up @@ -159,3 +162,9 @@ def _get_without_fallbacks(
meta = {}

return data, meta

def get_all_providers(self) -> list[str]:
return self.hosts

def _get_chain_id_with_provider(self, provider_index: int) -> int:
raise NotImplementedError("_chain_id should be implemented")
4 changes: 4 additions & 0 deletions src/providers/keys/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,3 +54,7 @@ def get_status(self) -> KeysApiStatus:
"""Docs: https://keys-api.lido.fi/api/static/index.html#/status/StatusController_get"""
data, _ = self._get(self.STATUS)
return KeysApiStatus.from_response(**cast(dict, data))

def _get_chain_id_with_provider(self, provider_index: int) -> int:
data, _ = self._get_without_fallbacks(self.hosts[provider_index], self.STATUS)
return KeysApiStatus.from_response(**cast(dict, data)).chainId
2 changes: 2 additions & 0 deletions src/web3py/extensions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,5 @@
from src.web3py.extensions.consensus import ConsensusClientModule
from src.web3py.extensions.contracts import LidoContracts
from src.web3py.extensions.lido_validators import LidoValidatorsProvider
from src.web3py.extensions.fallback import FallbackProviderModule
from src.web3py.extensions.consistency import ProviderConsistencyModule
47 changes: 47 additions & 0 deletions src/web3py/extensions/consistency.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from typing import Any, Optional
from abc import abstractmethod, ABC


class InconsistentProviders(Exception):
pass


class NotHealthyProvider(Exception):
pass


class ProviderConsistencyModule(ABC):
"""
A class that provides HTTP provider ability to check that
provided hosts are alive and chain ids are same.
Methods must be implemented:
def get_all_providers(self) -> [any]:
def _get_chain_id_with_provider(self, int) -> int:
"""
def check_providers_consistency(self) -> Optional[int]:
chain_id = None

for provider_index in range(len(self.get_all_providers())):
try:
curr_chain_id = self._get_chain_id_with_provider(provider_index)
except Exception as error:
raise NotHealthyProvider(f'Provider [{provider_index}] does not responding.') from error

if chain_id is None:
chain_id = curr_chain_id
elif chain_id != curr_chain_id:
raise InconsistentProviders(f'Different chain ids detected for {provider_index=}. '
f'Expected {curr_chain_id=}, got {chain_id=}.')

return chain_id

@abstractmethod
def get_all_providers(self) -> list[Any]:
"""Returns list of hosts or providers."""
raise NotImplementedError("get_all_providers should be implemented")

@abstractmethod
def _get_chain_id_with_provider(self, provider_index: int) -> int:
"""Does a health check call and returns chain_id for current host"""
raise NotImplementedError("get_chain_id should be implemented")
13 changes: 13 additions & 0 deletions src/web3py/extensions/fallback.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from typing import Any

from web3_multi_provider import FallbackProvider
from src.web3py.extensions.consistency import ProviderConsistencyModule
from web3 import Web3


class FallbackProviderModule(ProviderConsistencyModule, FallbackProvider):
def get_all_providers(self) -> list[Any]:
return self._providers # type: ignore[attr-defined]

def _get_chain_id_with_provider(self, provider_index: int) -> int:
return Web3.to_int(hexstr=self._providers[provider_index].make_request("eth_chainId", []).get('result')) # type: ignore[attr-defined]
6 changes: 6 additions & 0 deletions tests/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,12 @@ def _get(
return response["response"]
raise NoMockException('There is no mock for response')

def get_all_hosts(self) -> list:
return []

def get_chain_id(self, host) -> int:
return 0


class UpdateResponsesHTTPProvider(HTTPProvider, Module, UpdateResponses):
def __init__(self, mock_path: Path, host: str, w3: Web3):
Expand Down

0 comments on commit 8842852

Please sign in to comment.