From 7f65d1df9fbc680b3968e24debdbf1efc52389ea Mon Sep 17 00:00:00 2001 From: BobTheBuidler Date: Sat, 9 Nov 2024 02:59:52 +0000 Subject: [PATCH] feat: implement evmspec --- dank_mids/_method.py | 133 ++++++++++++++ dank_mids/_requests.py | 90 ++++++---- dank_mids/controller.py | 3 +- dank_mids/eth.py | 239 +++++++++++++++++++++++++ dank_mids/types.py | 377 ++++++++++++++++------------------------ docs/conf.py | 122 ++++++++----- poetry.lock | 38 +++- pyproject.toml | 4 +- 8 files changed, 686 insertions(+), 320 deletions(-) create mode 100644 dank_mids/_method.py create mode 100644 dank_mids/eth.py diff --git a/dank_mids/_method.py b/dank_mids/_method.py new file mode 100644 index 00000000..c9c9db9d --- /dev/null +++ b/dank_mids/_method.py @@ -0,0 +1,133 @@ +from importlib.metadata import version +from typing import Tuple, Type + +from typing_extensions import Self +from web3.eth import BaseEth +from web3._utils.method_formatters import ERROR_FORMATTERS, NULL_RESULT_FORMATTERS +from web3._utils.blocks import select_method_for_block_identifier +from web3._utils.rpc_abi import RPC +from web3.method import Method, TFunc, _apply_request_formatters, default_root_munger +from web3.types import BlockIdentifier + + +WEB3_MAJOR_VERSION = int(version("web3").split(".")[0]) + +return_as_is = lambda x: x + + +class MethodNoFormat(Method[TFunc]): + """Bypasses web3py default result formatters.""" + + def process_params(self, module, *args, **kwargs): + params = self.input_munger(module, args, kwargs) + + if self.method_choice_depends_on_args: + # If the method choice depends on the args that get passed in, + # the first parameter determines which method needs to be called + self.json_rpc_method = self.method_choice_depends_on_args(value=params[0]) + + pending_or_latest_filter_methods = [ + RPC.eth_newPendingTransactionFilter, + RPC.eth_newBlockFilter, + ] + if self.json_rpc_method in pending_or_latest_filter_methods: + # For pending or latest filter methods, use params to determine + # which method to call, but don't pass them through with the request + params = [] + + method = self.method_selector_fn() + request = (method, _apply_request_formatters(params, self.request_formatters(method))) + response_formatters = ( + return_as_is, + ERROR_FORMATTERS.get(self.json_rpc_method, return_as_is), # type: ignore [arg-type] + NULL_RESULT_FORMATTERS.get(self.json_rpc_method, return_as_is), # type: ignore [arg-type] + ) + return request, response_formatters + + @classmethod + def default(cls, method: RPC) -> Self: + return cls(method, [default_root_munger]) + + +def bypass_chainid_formatter(eth: Type[BaseEth]) -> None: + eth._chain_id = MethodNoFormat(RPC.eth_chainId) + + +def bypass_getbalance_formatter(eth: Type[BaseEth]) -> None: + eth._get_balance = MethodNoFormat(RPC.eth_getBalance, mungers=[eth.block_id_munger]) + + +def bypass_blocknumber_formatter(eth: Type[BaseEth]) -> None: + eth.get_block_number = MethodNoFormat(RPC.eth_blockNumber) + + +def bypass_transaction_count_formatter(eth: Type[BaseEth]) -> None: + eth._get_transaction_count = MethodNoFormat( + RPC.eth_getTransactionCount, mungers=[eth.block_id_munger] + ) + + +def bypass_log_formatter(eth: Type[BaseEth]) -> None: + eth._get_logs = MethodNoFormat.default(RPC.eth_getLogs) + eth._get_logs_raw = MethodNoFormat.default(f"{RPC.eth_getLogs}_raw") + eth.get_filter_logs = MethodNoFormat.default(RPC.eth_getFilterLogs) + eth.get_filter_changes = MethodNoFormat.default(RPC.eth_getFilterChanges) + + +def bypass_transaction_receipt_formatter(eth: Type[BaseEth]) -> None: + attr_name = "_transaction_receipt" if WEB3_MAJOR_VERSION >= 6 else "_get_transaction_receipt" + setattr(eth, attr_name, MethodNoFormat.default(RPC.eth_getTransactionReceipt)) + + +def bypass_transaction_formatter(eth: Type[BaseEth]) -> None: + eth._get_transaction = MethodNoFormat.default(RPC.eth_getTransactionByHash) + + +_block_selectors = dict( + if_predefined=RPC.eth_getBlockByNumber, + if_hash=RPC.eth_getBlockByHash, + if_number=RPC.eth_getBlockByNumber, +) + + +def bypass_block_formatters(eth: Type[BaseEth]) -> None: + if WEB3_MAJOR_VERSION >= 6: + get_block_munger = eth.get_block_munger + else: + + def get_block_munger( + self, block_identifier: BlockIdentifier, full_transactions: bool = False + ) -> Tuple[BlockIdentifier, bool]: + return (block_identifier, full_transactions) + + eth._get_block = MethodNoFormat( + method_choice_depends_on_args=select_method_for_block_identifier(**_block_selectors), + mungers=[get_block_munger], + ) + + +def bypass_eth_call_formatter(eth: Type[BaseEth]) -> None: + eth._call = MethodNoFormat(RPC.eth_call, mungers=[eth.call_munger]) + + +def bypass_get_code_formatter(eth: Type[BaseEth]) -> None: + eth._get_code = MethodNoFormat(RPC.eth_getCode, mungers=[eth.block_id_munger]) + + +skip_formatters = ( + bypass_chainid_formatter, + bypass_getbalance_formatter, + bypass_blocknumber_formatter, + bypass_transaction_count_formatter, + bypass_eth_call_formatter, + bypass_get_code_formatter, + bypass_log_formatter, + bypass_transaction_receipt_formatter, + bypass_transaction_formatter, + bypass_block_formatters, +) + + +def bypass_formatters(eth): + for bypass in skip_formatters: + bypass(eth) diff --git a/dank_mids/_requests.py b/dank_mids/_requests.py index cfbaba76..1922077e 100644 --- a/dank_mids/_requests.py +++ b/dank_mids/_requests.py @@ -154,7 +154,7 @@ def _should_batch_method(method: str) -> bool: class RPCRequest(_RequestMeta[RawResponse]): - __slots__ = "method", "params", "should_batch", "_started", "_retry", "_daemon" + __slots__ = "method", "params", "should_batch", "raw", "_started", "_retry", "_daemon" def __init__( self, @@ -165,7 +165,12 @@ def __init__( ): self.controller = controller """The DankMiddlewareController that created this request.""" - self.method = method + if method[-4:] == "_raw": + self.method = method[:-4] + self.raw = True + else: + self.method = method + self.raw = False """The rpc method for this request.""" self.params = params """The parameters to send with this request, if any.""" @@ -245,33 +250,48 @@ async def get_response(self) -> RPCResponse: # type: ignore [override] # JIT json decoding if isinstance(self.response, RawResponse): - response = self.response.decode(partial=True).to_dict(self.method) - error: Optional[RPCError] - if error := response.get("error"): # type: ignore [assignment] - if error["message"].lower() in ["invalid request", "parse error"]: - if self.controller._time_of_request_type_change == 0: - self.controller.request_type = Request - self.controller._time_of_request_type_change = time.time() - if time.time() - self.controller._time_of_request_type_change <= 600: - logger.debug( - "your node says the partial request was invalid but its okay, we can use the full jsonrpc spec instead" - ) - return await self.controller(self.method, self.params) - response["error"] = dict(error) - response["error"]["dankmids_added_context"] = self.request - # I'm 99.99999% sure that any errd call has no result and we only get this field from mscspec object defs - # But I'll check it anyway to be safe - if result := response.pop("result", None): - response["result"] = result - logger.debug("error response for %s: %s", self, response) + response = self.response.decode(partial=True) + if response.error is None: + if self.raw: + return {"result": response.result} + response_dict = response.to_dict(self.method) + assert "result" in response_dict or "error" in response_dict, ( + response_dict, + type(response_dict), + ) + return response_dict + + if response.error.message.lower() in ["invalid request", "parse error"]: + if self.controller._time_of_request_type_change == 0: + self.controller.request_type = Request + self.controller._time_of_request_type_change = time.time() + if time.time() - self.controller._time_of_request_type_change <= 600: + logger.debug( + "your node says the partial request was invalid but its okay, we can use the full jsonrpc spec instead" + ) + method = self.method + if self.raw: + method += "_raw" + return await self.controller(method, self.params) + + error = dict(response.error.items()) + error["dankmids_added_context"] = self.request + + response = response.to_dict(self.method) + response["error"] = error + logger.debug("error response for %s: %s", self, response) return response # If we have an Exception here it came from the goofy sync_call thing I need to get rid of. # We raise it here so it traces back up to the caller if isinstance(self.response, Exception): - __raise_more_detailed_exc(self.request, self.response) + _raise_more_detailed_exc(self.request, self.response) # Less optimal decoding # TODO: refactor this out + assert "result" in self.response or "error" in self.response, ( + self.response, + type(self.response), + ) return self.response @set_done @@ -289,7 +309,10 @@ async def get_response_unbatched(self) -> RPCResponse: # type: ignore [override t.cancel() for task in done: return await task - return self.response.decode(partial=True).to_dict(self.method) + response = self.response.decode(partial=True) + retval = {"result": response.result} if self.raw else response.to_dict(self.method) + assert "result" in retval or "error" in retval, (retval, type(retval)) + return retval @set_done async def spoof_response(self, data: Union[RawResponse, bytes, Exception]) -> None: @@ -352,8 +375,12 @@ async def create_duplicate( # Creating the task before awaiting the new call ensures the new call will grab the semaphore immediately # and then the task will try to acquire at the very next event loop _run_once cycle logger.warning("%s got stuck, we're creating a new one", self) - retval = await self.controller(self.method, self.params) + method = self.method + if self.raw: + method += "_raw" + retval = await self.controller(method, self.params) await self.semaphore.acquire() + assert "result" in retval or "error" in retval, (retval, type(retval)) return retval @@ -581,18 +608,11 @@ def mcall_encode(data: List[Tuple[bool, bytes]]) -> bytes: def mcall_decode(data: PartialResponse) -> Union[List[Tuple[bool, bytes]], Exception]: try: - decoded = data.decode_result("eth_call")[2:] # type: ignore [arg-type] - decoded = bytes.fromhex(decoded) - return mcall_decoder(decoding.ContextFramesBytesIO(decoded))[2] + return mcall_decoder(decoding.ContextFramesBytesIO(data.decode_result("eth_call")))[2] except Exception as e: # NOTE: We need to safely bring any Exceptions back out of the ProcessPool - try: - # We do this goofy thing since we can't `return Exc() from e` - raise e.__class__( - *e.args, data.decode_result() if isinstance(data, PartialResponse) else data - ) from e - except Exception as new_e: - return new_e + e.args = (*e.args, data.decode_result() if isinstance(data, PartialResponse) else data) + return e class Multicall(_Batch[RPCResponse, eth_call]): @@ -1208,7 +1228,7 @@ def __format_error(request: PartialRequest, response: PartialResponse) -> Attrib return AttributeDict.recursive(error) -def __raise_more_detailed_exc(request: PartialRequest, exc: Exception) -> NoReturn: +def _raise_more_detailed_exc(request: PartialRequest, exc: Exception) -> NoReturn: if isinstance(exc, ClientResponseError): raise DankMidsClientResponseError(exc, request) from exc try: diff --git a/dank_mids/controller.py b/dank_mids/controller.py index 4cb81359..da183ea2 100644 --- a/dank_mids/controller.py +++ b/dank_mids/controller.py @@ -6,6 +6,7 @@ import eth_retry from eth_typing import BlockNumber, ChecksumAddress from eth_utils import to_checksum_address +from evmspec._ids import ChainId from msgspec import Struct from multicall.constants import MULTICALL2_ADDRESSES, MULTICALL_ADDRESSES from multicall.multicall import NotSoBrightBatcher @@ -21,7 +22,7 @@ from dank_mids._uid import UIDGenerator, _AlertingRLock from dank_mids.helpers import _codec, _helpers, _session from dank_mids.semaphores import _MethodQueues, _MethodSemaphores, BlockSemaphore -from dank_mids.types import BlockId, ChainId, PartialRequest, RawResponse, Request +from dank_mids.types import BlockId, PartialRequest, RawResponse, Request try: from multicall.constants import MULTICALL3_ADDRESSES diff --git a/dank_mids/eth.py b/dank_mids/eth.py new file mode 100644 index 00000000..fbbcc6de --- /dev/null +++ b/dank_mids/eth.py @@ -0,0 +1,239 @@ +from typing import ( + Awaitable, + Callable, + List, + Literal, + Sequence, + Tuple, + Type, + TypedDict, + Union, + overload, +) + +from async_lru import alru_cache +from async_property import async_cached_property +from eth_typing import BlockNumber +from evmspec import AnyTransaction, FilterTrace, Transaction, TransactionRLP, TransactionReceipt +from evmspec.block import TinyBlock +from evmspec.data import TransactionHash, UnixTimestamp, _DecodeHook, _decode_hook +from evmspec.log import Log +from evmspec.receipt import Status +from msgspec import Raw, Struct, ValidationError, json +from web3._utils.blocks import select_method_for_block_identifier +from web3._utils.rpc_abi import RPC +from web3.eth import AsyncEth +from web3.method import default_root_munger +from web3.types import Address, BlockIdentifier, ChecksumAddress, ENS, HexStr + +from dank_mids._method import ( + WEB3_MAJOR_VERSION, + MethodNoFormat, + bypass_formatters, + _block_selectors, +) +from dank_mids.types import T + + +class TraceFilterParams(TypedDict, total=False): # type: ignore [call-arg] + after: int + count: int + fromAddress: Sequence[Union[Address, ChecksumAddress, ENS]] + fromBlock: BlockIdentifier + toAddress: Sequence[Union[Address, ChecksumAddress, ENS]] + toBlock: BlockIdentifier + + +class DankEth(AsyncEth): + _block_cache_ttl = 0.1 + + # eth_chainId + + @async_cached_property + async def chain_id(self) -> int: + return await self._chain_id() + + _chain_id: MethodNoFormat[Callable[[], Awaitable[int]]] + + try: + _chain_id = MethodNoFormat(RPC.eth_chainId, is_property=True) + except TypeError as e: # NOTE: older web3.py versions cant use `is_property` kwarg + if str(e) != "__init__() got an unexpected keyword argument 'is_property'": + raise + _chain_id = MethodNoFormat(RPC.eth_chainId, mungers=None) + + # eth_blockNumber + + @property + @alru_cache(maxsize=None, ttl=_block_cache_ttl) + async def block_number(self) -> BlockNumber: # type: ignore [override] + return await self.get_block_number() + + async def get_block_number(self) -> BlockNumber: # type: ignore [override] + return await self._get_block_number() # type: ignore [misc] + + _get_block_number: MethodNoFormat[Callable[[], Awaitable[BlockNumber]]] + + try: + _get_block_number = MethodNoFormat(RPC.eth_blockNumber, is_property=True) + except TypeError as e: # NOTE: older web3.py versions cant use `is_property` kwarg + if str(e) != "__init__() got an unexpected keyword argument 'is_property'": + raise + _get_block_number = MethodNoFormat(RPC.eth_blockNumber, mungers=None) + + async def get_block_timestamp(self, block_identifier: int) -> UnixTimestamp: + """ + Retrieves only the timestamp from a specific block. + + This method skips decoding the rest of the Block response data. + + Args: + block_identifier: The block number from which to retrieve the transactions. + hashes_only: If True, only transaction hashes will be returned. + + Returns: + A list of :class:`~Transaction` data objects from the block, or a list of transaction hashes. + + Example: + >>> [print(tx.hash) for tx in await dank_mids.eth.get_transactions(12345678)] + """ + try: # TypeError: 'str' object cannot be interpreted as an integer + block_identifier = hex(block_identifier) + finally: + block_bytes = await self._get_block_raw(block_identifier, False) + return json.decode( + block_bytes, type=_Timestamped, dec_hook=UnixTimestamp._decode_hook + ).timestamp + + @overload + async def get_transactions(self, block_identifier: Union[int, HexStr]) -> List[Transaction]: ... + @overload + async def get_transactions( + self, block_identifier: Union[int, HexStr], hashes_only: Literal[True] + ) -> List[TransactionHash]: ... + @overload + async def get_transactions( + self, block_identifier: Union[int, HexStr], hashes_only: Literal[False] + ) -> List[Transaction]: ... + async def get_transactions( + self, block_identifier: Union[int, HexStr], hashes_only: bool = False + ) -> Union[List[Transaction], List[TransactionHash]]: + """ + Retrieves only the transactions from a specific block. + + This method skips decoding the rest of the Block response data. + + Args: + block_identifier: The block number from which to retrieve the transactions. + hashes_only: If True, only transaction hashes will be returned. + + Returns: + A list of :class:`~Transaction` data objects from the block, or a list of transaction hashes. + + Example: + >>> [print(tx.hash) for tx in await dank_mids.eth.get_transactions(12345678)] + """ + try: # TypeError: 'str' object cannot be interpreted as an integer + block_identifier = hex(block_identifier) # type: ignore [arg-type, assignment] + finally: + block_bytes = await self._get_block_raw(block_identifier, not hashes_only) + return json.decode(block_bytes, type=TinyBlock, dec_hook=_decode_hook).transactions + + async def get_transaction_receipt( + self, + *args, + decode_to: Type[T] = TransactionReceipt, + decode_hook: _DecodeHook[T] = _decode_hook, + **kwargs, + ) -> T: + receipt_bytes = await self._get_transaction_receipt_raw(*args, **kwargs) + return json.decode(receipt_bytes, type=decode_to, dec_hook=decode_hook) + + async def get_transaction_status(self, transaction_hash: str) -> Status: + tiny_receipt = await self.get_transaction_receipt( + transaction_hash, + decode_to=_Statusable, + decode_hook=lambda enum_cls, data: enum_cls(data), + ) + return tiny_receipt.status + + async def trace_filter( + self, + filter_params: TraceFilterParams, + decode_to: Type[T] = List[FilterTrace], + decode_hook: _DecodeHook[T] = _decode_hook, + ) -> T: + traces_bytes = await self._trace_filter(filter_params) + try: + return json.decode(traces_bytes, type=decode_to, dec_hook=decode_hook) + except ValidationError: + if decode_to.__origin__ is not list: + raise + + traces_raw = json.decode(traces_bytes, type=List[Raw]) + traces = [] + trace_cls = decode_to.__args__[0] + for raw in traces_raw: + try: + traces.append(json.decode(raw, type=trace_cls, dec_hook=decode_hook)) + except ValidationError as e: + e.args = *e.args, json.decode(raw) + raise + + async def trace_transaction(self, transaction_hash: str) -> List[FilterTrace]: + return await self._trace_transaction(transaction_hash) + + _get_transaction_raw: MethodNoFormat[Callable[[HexStr], Awaitable[Raw]]] = MethodNoFormat(f"{RPC.eth_getTransactionByHash}_raw", mungers=[default_root_munger]) # type: ignore [arg-type,var-annotated] + + async def get_transaction(self, transaction_hash: str) -> AnyTransaction: # type: ignore [override] + transaction_bytes = await self._get_transaction_raw(transaction_hash) + try: + return json.decode(transaction_bytes, type=Transaction, dec_hook=_decode_hook) + except ValidationError: + try: + return json.decode(transaction_bytes, type=TransactionRLP, dec_hook=_decode_hook) + except ValidationError as e: + e.args = *e.args, json.decode(transaction_bytes) + raise + + async def get_logs( + self, + *args, + decode_to: Type[T] = Tuple[Log, ...], # type: ignore [assignment] + decode_hook: _DecodeHook[T] = _decode_hook, + **kwargs, + ) -> T: + logs_bytes = await self._get_logs_raw(*args, **kwargs) # type: ignore [attr-defined] + return json.decode(logs_bytes, type=decode_to, dec_hook=decode_hook) + + meth = MethodNoFormat.default(RPC.eth_getTransactionReceipt) # type: ignore [arg-type, var-annotated] + if WEB3_MAJOR_VERSION >= 6: + _transaction_receipt = meth + else: + _get_transaction_receipt = meth + + _get_transaction_receipt_raw = MethodNoFormat.default(f"{RPC.eth_getTransactionReceipt}_raw") + + _get_block_raw: MethodNoFormat[Callable[..., Awaitable[Raw]]] = MethodNoFormat( + method_choice_depends_on_args=select_method_for_block_identifier( + **{k: f"{v}_raw" for k, v in _block_selectors.items()} + ), + mungers=[AsyncEth.get_block_munger], + ) + _trace_filter = MethodNoFormat.default(f"{RPC.trace_filter}_raw") + _trace_transaction = MethodNoFormat.default(RPC.trace_transaction) + + +# TODO: this is super hacky, make it not. +bypass_formatters(DankEth) + + +class _Statusable(Struct, frozen=True): + + status: Status + + +class _Timestamped(Struct, frozen=True): # type: ignore [call-arg] + + timestamp: UnixTimestamp + """The Unix timestamp for when the block was collated.""" diff --git a/dank_mids/types.py b/dank_mids/types.py index 1a199fb3..611b9742 100644 --- a/dank_mids/types.py +++ b/dank_mids/types.py @@ -8,22 +8,32 @@ Coroutine, DefaultDict, Dict, + Iterable, Iterator, List, Literal, Mapping, NewType, + NoReturn, Optional, Set, Tuple, + Type, TypedDict, TypeVar, Union, overload, ) -import msgspec -from eth_typing import ChecksumAddress +import evmspec +from dictstruct import DictStruct +from eth_typing import ChecksumAddress, HexStr +from evmspec._ids import ChainId +from evmspec.block import BaseBlock, Block, MinedBlock, ShanghaiCapellaBlock +from evmspec.data import Address, BlockNumber, Wei, uint, _decode_hook +from evmspec.log import Log +from hexbytes import HexBytes +from msgspec import UNSET, Raw, ValidationError, json from web3.datastructures import AttributeDict from web3.types import RPCEndpoint, RPCResponse @@ -38,8 +48,9 @@ if TYPE_CHECKING: from dank_mids._requests import Multicall -ChainId = NewType("ChainId", int) -"""A type representing the unique integer identifier for a blockchain network.""" + +T = TypeVar("T") + BlockId = NewType("BlockId", str) """A type representing the identifier for a specific block in the blockchain.""" @@ -84,85 +95,7 @@ """A type alias for a nested dictionary structure.""" -class _DictStruct(msgspec.Struct): - """A base class enhancing :class:`~msgspec.Struct` with additional dictionary-like functionality.""" - - def __bool__(self) -> bool: - """A Struct will always exist.""" - return True - - def __getitem__(self, attr: str) -> Any: - """ - Allow dictionary-style access to attributes. - - Args: - attr: The name of the attribute to access. - - Returns: - The value of the attribute. - """ - try: - return getattr(self, attr) - except AttributeError: - raise KeyError(attr) from None - - def __getattr__(self, attr: str) -> Any: - """ - Get the value of an attribute, raising AttributeError if the value is :obj:`msgspec.UNSET`. - - Parameters: - attr: The name of the attribute to fetch. - - Raises: - AttributeError: If the value is :obj:`~msgspec.UNSET`. - - Returns: - The value of the attribute. - """ - attr = super().__getattr__(attr) - if attr is msgspec.UNSET: - raise AttributeError(f"'{type(self).__name__}' object has no attribute '{attr}'") - return attr - - def __iter__(self) -> Iterator[str]: - """ - Iterate thru the keys of the Struct. - - Yields: - Struct key. - """ - for field in self.__struct_fields__: - if getattr(self, field, msgspec.UNSET) is not msgspec.UNSET: - yield field - - def __len__(self) -> int: - """ - The number of keys in the Struct. - - Returns: - The number of keys. - """ - return len(list(self)) - - def keys(self) -> Iterator[str]: - yield from self - - def items(self) -> Iterator[Tuple[str, Any]]: - for key in self.__struct_fields__: - try: - yield key, getattr(self, key) - except AttributeError: - continue - - def values(self) -> Iterator[Any]: - for key in self.__struct_fields__: - try: - yield getattr(self, key) - except AttributeError: - continue - - -class PartialRequest(_DictStruct, frozen=True): # type: ignore [call-arg] +class PartialRequest(DictStruct, frozen=True, omit_defaults=True, repr_omit_defaults=True): # type: ignore [call-arg] """ Represents a partial JSON-RPC request. @@ -184,7 +117,7 @@ class PartialRequest(_DictStruct, frozen=True): # type: ignore [call-arg] @property def data(self) -> bytes: - return msgspec.json.encode(self) + return json.encode(self) class Request(PartialRequest): @@ -198,7 +131,7 @@ class Request(PartialRequest): """The JSON-RPC version, always set to "2.0".""" -class Error(_DictStruct, frozen=True): # type: ignore [call-arg] +class Error(DictStruct, frozen=True, omit_defaults=True, repr_omit_defaults=True): # type: ignore [call-arg] """ Represents an error in a JSON-RPC response. """ @@ -209,7 +142,7 @@ class Error(_DictStruct, frozen=True): # type: ignore [call-arg] message: str """The error message.""" - data: Optional[Any] = msgspec.UNSET + data: Optional[Any] = UNSET """Additional error data, if any.""" @@ -218,110 +151,21 @@ class Error(_DictStruct, frozen=True): # type: ignore [call-arg] _str_responses: Set[str] = set() -class Log(_DictStruct, frozen=True): # type: ignore [call-arg] - removed: Optional[bool] - logIndex: Optional[str] - transactionIndex: Optional[str] - transactionHash: str - blockHash: Optional[str] - blockNumber: Optional[str] - address: Optional[str] - data: Optional[str] - topics: Optional[List[str]] - - -class AccessListEntry(_DictStruct, frozen=True): # type: ignore [call-arg] - address: str - storageKeys: List[str] - - -AccessList = List[AccessListEntry] - -# TODO: use the types from snek -Transaction = Dict[str, Union[str, None, AccessList]] - - -class FeeStats(_DictStruct, frozen=True): # type: ignore [call-arg] - """Arbitrum includes this in the `feeStats` field of a tx receipt.""" - - l1Calldata: str - l2Storage: str - l1Transaction: str - l2Computation: str - - -class ArbitrumFeeStats(_DictStruct, frozen=True): # type: ignore [call-arg] - """Arbitrum includes these with a tx receipt.""" - - paid: FeeStats - """ - The breakdown of gas paid for the transaction. - - (price * unitsUsed) - """ - # These 2 attributes do not always exist - unitsUsed: FeeStats = msgspec.UNSET - """The breakdown of units of gas used for the transaction.""" - prices: FeeStats = msgspec.UNSET - """The breakdown of gas prices for the transaction.""" - - -class TransactionReceipt(_DictStruct, frozen=True, omit_defaults=True): # type: ignore [call-arg] - transactionHash: str - blockHash: str - blockNumber: str - logsBloom: str - contractAddress: Optional[str] - transactionIndex: Optional[str] - returnCode: str - effectiveGasPrice: str - gasUsed: str - cumulativeGasUsed: str - returnData: str - logs: List[Log] - - # These fields are only present on Arbitrum. - l1BlockNumber: str = msgspec.UNSET - """This field is only present on Arbitrum.""" - l1InboxBatchInfo: Optional[str] = msgspec.UNSET - """This field is only present on Arbitrum.""" - feeStats: ArbitrumFeeStats = msgspec.UNSET - """This field is only present on Arbitrum.""" - - -class Block(_DictStruct, frozen=True): - parentHash: str - sha3Uncles: str - miner: str - stateRoot: str - transactionsRoot: str - receiptsRoot: str - logsBloom: str - number: str - gasLimit: str - gasUsed: str - timestamp: str - extraData: str - mixHash: str - nonce: str - size: str - uncles: List[str] - transactions: List[Union[str, Transaction]] - - _RETURN_TYPES = { - "eth_call": str, - "eth_chainId": str, - "eth_getCode": str, + "eth_call": HexBytes, + "eth_chainId": ChainId, + "eth_getCode": HexBytes, "eth_getLogs": List[Log], - "eth_getBalance": str, - "eth_blockNumber": str, # TODO: see if we can decode this straight to an int - "eth_accounts": List[str], + "eth_getBalance": Wei, + "eth_blockNumber": BlockNumber, + "eth_accounts": List[Address], "eth_getBlockByNumber": Block, - "eth_getTransactionCount": str, - "eth_getTransactionByHash": Transaction, - "eth_getTransactionReceipt": TransactionReceipt, - "erigon_getHeaderByNumber": Dict[str, Union[str, int, bool, None]], + "eth_getTransactionCount": uint, + "eth_getTransactionByHash": evmspec.Transaction, + "eth_getTransactionReceipt": evmspec.FullTransactionReceipt, + "erigon_getHeaderByNumber": evmspec.ErigonBlockHeader, + "trace_filter": List[evmspec.FilterTrace], + "trace_transaction": List[evmspec.FilterTrace], } """ A dictionary mapping RPC method names to their expected return types. @@ -333,14 +177,14 @@ class Block(_DictStruct, frozen=True): _chainstack_429_msg = "You've exceeded the RPS limit available on the current plan." -class PartialResponse(_DictStruct, frozen=True): +class PartialResponse(DictStruct, frozen=True, omit_defaults=True, repr_omit_defaults=True): """ Represents a partial JSON-RPC response. We use these to more efficiently decode responses from the node. """ - result: msgspec.Raw = None # type: ignore + result: Raw = None # type: ignore """The result of the RPC call, if successful.""" error: Optional[Error] = None @@ -379,18 +223,16 @@ def payload_too_large(self) -> bool: def to_dict(self, method: Optional[RPCEndpoint] = None) -> RPCResponse: # type: ignore [override] """Returns a complete dictionary representation of this response ``Struct``.""" - data: RPCResponse = {} - for field, attr in self.items(): - if attr is None: - continue - if field == "result": - attr = self.decode_result(method=method, _caller=self) - data[field] = AttributeDict(attr) if isinstance(attr, Mapping) else attr # type: ignore [literal-required] + data: RPCResponse = { + key: self.decode_result(method=method, caller=self) if key == "result" else value + for key, value in self.items() + if value is not None + } return data def decode_result( - self, method: Optional[RPCEndpoint] = None, _caller=None - ) -> Union[str, AttributeDict]: + self, method: Optional[RPCEndpoint] = None, *, caller=None + ) -> Union[HexBytes, Wei, uint, ChainId, BlockNumber, AttributeDict]: # NOTE: These must be added to the `_RETURN_TYPES` constant above manually if method and (typ := _RETURN_TYPES.get(method)): if method in [ @@ -405,39 +247,70 @@ def decode_result( "erigon_getHeaderByNumber", ]: try: - return msgspec.json.decode(self.result, type=typ) - except (msgspec.ValidationError, TypeError) as e: - raise ValueError( - e, - f"method: {method} result: {msgspec.json.decode(self.result)}", - ).with_traceback(e.__traceback__) from e + return better_decode( + self.result, type=typ, dec_hook=_decode_hook, method=method + ) + except Exception as e: + if typ is not Block: + raise + + if e.args[0] == "Object contains unknown field `totalDifficulty`": + try: + # NOTE should we do this?? + # _RETURN_TYPES[method] = MinedBlock + return better_decode( + self.result, type=MinedBlock, dec_hook=_decode_hook, method=method + ) + except ValidationError as e2: + if e2.args[0] != "Object contains unknown field `baseFeePerGas`": + raise + result = better_decode( + self.result, type=BaseBlock, dec_hook=_decode_hook, method=method + ) + _RETURN_TYPES[method] = BaseBlock # all blocks on base are BaseBlocks + return result + + elif e.args[0] == "Object contains unknown field `withdrawals`": + return better_decode( + self.result, + type=ShanghaiCapellaBlock, + dec_hook=_decode_hook, + method=method, + ) + else: + raise + + return better_decode( + self.result, type=typ, dec_hook=_decode_hook, method=method + ) + + start = time() try: - start = time() - decoded = msgspec.json.decode(self.result, type=typ) - if _caller: - stats.log_duration(f"decoding {type(_caller)} {method}", start) - return AttributeDict(decoded) if isinstance(decoded, dict) else decoded - except (msgspec.ValidationError, TypeError) as e: + decoded = better_decode(self.result, type=typ, dec_hook=_decode_hook, method=method) + except (ValidationError, TypeError) as e: stats.logger.log_validation_error(self, e) + raise + + if caller: + stats.log_duration(f"decoding {type(caller)} {method}", start) + return decoded # We have some semi-smart logic for providing decoder hints even if method not in `_RETURN_TYPES` if method: try: if method in _dict_responses: - decoded = AttributeDict( - msgspec.json.decode(self.result, type=_nested_dict_of_stuff) - ) + decoded = AttributeDict(json.decode(self.result, type=_nested_dict_of_stuff)) stats.logger.log_types(method, decoded) return decoded elif method in _str_responses: # TODO: finish adding methods and get rid of this - stats.logger.devhint(f"Must add `{method}: str` to `_RETURN_TYPES`") - return msgspec.json.decode(self.result, type=str) - except (msgspec.ValidationError, TypeError) as e: + stats.logger.devhint("Must add `%s: str` to `_RETURN_TYPES`", method) + return json.decode(self.result, type=str) + except (ValidationError, TypeError) as e: stats.logger.log_validation_error(method, e) # In this case we can provide no hints, let's let the decoder figure it out - decoded = msgspec.json.decode(self.result) + decoded = json.decode(self.result) if isinstance(decoded, str): if method: _str_responses.add(method) @@ -445,7 +318,7 @@ def decode_result( elif isinstance(decoded, dict): if method: _dict_responses.add(method) - return AttributeDict.recursive(decoded) + return AttributeDict(decoded) elif isinstance(decoded, list): if method is None: return decoded @@ -454,7 +327,7 @@ def decode_result( ) -class Response(PartialResponse): +class Response(PartialResponse, omit_defaults=True, repr_omit_defaults=True): # type: ignore [call-arg] """ Represents a complete JSON-RPC response. @@ -475,31 +348,33 @@ class RawResponse: They represent either a successful or a failed response, stored as pre-decoded bytes. """ - def __init__(self, raw: msgspec.Raw) -> None: + def __init__(self, raw: Raw) -> None: self._raw = raw - """The `msgspec.Raw` object wrapped by this wrapper.""" + """The :class:`Raw` object wrapped by this wrapper.""" @overload def decode(self, partial: Literal[True]) -> PartialResponse: ... @overload def decode(self, partial: Literal[False] = False) -> Response: ... def decode(self, partial: bool = False) -> Union[Response, PartialResponse]: - """Decode the wrapped `msgspec.Raw` object into a `Response` or a `PartialResponse`.""" - try: - return msgspec.json.decode(self._raw, type=PartialResponse if partial else Response) - except (msgspec.ValidationError, TypeError) as e: - e.args = (*e.args, f"decoded: {msgspec.json.decode(self._raw)}") - raise + """Decode the wrapped :class:`Raw` object into a :class:`Response` or a :class:`PartialResponse`.""" + return better_decode(self._raw, type=PartialResponse if partial else Response) JSONRPCBatchRequest = List[Request] # NOTE: A PartialResponse result implies a failure response from the rpc. JSONRPCBatchResponse = Union[List[RawResponse], PartialResponse] # We need this for proper decoding. -JSONRPCBatchResponseRaw = Union[List[msgspec.Raw], PartialResponse] +JSONRPCBatchResponseRaw = Union[List[Raw], PartialResponse] + + +StrEncodable = Union[ChecksumAddress, HexStr] +Encodable = Union[int, StrEncodable, HexBytes, bytes] +RpcThing = Union[HexStr, List[HexStr], Dict[str, HexStr]] -def _encode_hook(obj: Any) -> Any: + +def _encode_hook(obj: Encodable) -> RpcThing: """ A hook function for encoding objects during JSON serialization. @@ -512,6 +387,44 @@ def _encode_hook(obj: Any) -> Any: Raises: NotImplementedError: If the object type is not supported for encoding. """ - if isinstance(obj, AttributeDict): - return dict(obj) - raise NotImplementedError(type(obj)) + try: + # We just assume `obj` is an int subclass instead of performing if checks because it usually is. + return hex(int(obj)) # type: ignore [return-value] + except TypeError as e: + # I put this here for AttributeDicts which come from eth_getLogs params + # but I check for mapping so it can work with user custom classes + if not isinstance(obj, Mapping): + raise TypeError(obj, type(obj)) from e + return dict({k: _rudimentary_encode_dict_value(v) for k, v in obj.items()}) + except ValueError as e: + # NOTE: The error is probably this if `obj` is a string: + # ValueError: invalid literal for int() with base 10:""" + if not isinstance(obj, HexBytes): + e.args = *e.args, obj, type(obj) + raise ValueError(obj, type(obj)) from e + return obj.hex() # type: ignore [return-value] + + +def _rudimentary_encode_dict_value(value): + # I dont think this needs to be robust, time will tell + return hex(value) if isinstance(value, int) else value + + +def better_decode( + data: Raw, + *, + type: Optional[Type[T]] = None, + dec_hook: Optional[Callable[[Type, object], T]] = None, + method: Optional[str] = None, +) -> T: + try: + return json.decode(data, type=type, dec_hook=dec_hook) + except (ValidationError, TypeError) as e: + extra_args = [ + f"type: {type.__module__}.{type.__qualname__}", + f"result: {json.decode(data)}", + ] + if method: + extra_args.insert(0, f"method: {method}") + e.args = (*e.args, *extra_args) + raise diff --git a/docs/conf.py b/docs/conf.py index d1743c87..81a84894 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -14,61 +14,67 @@ logger = logging.getLogger(__name__) -network.connect('mainnet') +network.connect("mainnet") -project = 'dank_mids' -copyright = '2024, BobTheBuidler' -author = 'BobTheBuidler' +project = "dank_mids" +copyright = "2024, BobTheBuidler" +author = "BobTheBuidler" # -- General configuration --------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration extensions = [ - 'sphinx.ext.autodoc', - 'sphinx.ext.napoleon', - 'sphinx.ext.intersphinx', + "sphinx.ext.autodoc", + "sphinx.ext.napoleon", + "sphinx.ext.intersphinx", ] -templates_path = ['_templates'] -exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] +templates_path = ["_templates"] +exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] intersphinx_mapping = { - 'a_sync': ('https://bobthebuidler.github.io/ez-a-sync', None), - 'brownie': ('https://eth-brownie.readthedocs.io/en/stable/', None), - 'hexbytes': ('https://hexbytes.readthedocs.io/en/stable/', None), - 'python': ('https://docs.python.org/3', None), - 'typing_extensions': ('https://typing-extensions.readthedocs.io/en/latest/', None), - 'web3': ('https://web3py.readthedocs.io/en/stable/', None), + "a_sync": ("https://bobthebuidler.github.io/ez-a-sync", None), + "brownie": ("https://eth-brownie.readthedocs.io/en/stable/", None), + "dictstruct": ("https://bobthebuidler.github.io/dictstruct", None), + "evmspec": ("https://bobthebuidler.github.io/evmspec", None), + "hexbytes": ("https://hexbytes.readthedocs.io/en/stable/", None), + "python": ("https://docs.python.org/3", None), + "typing_extensions": ("https://typing-extensions.readthedocs.io/en/latest/", None), + "web3": ("https://web3py.readthedocs.io/en/stable/", None), } # -- Options for HTML output ------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output -html_theme = 'sphinx_rtd_theme' -html_static_path = ['_static'] +html_theme = "sphinx_rtd_theme" +html_static_path = ["_static"] autodoc_default_options = { - 'special-members': ','.join([ - '__init__', - '__call__', - '__getitem__', - '__iter__', - '__aiter__', - '__next__', - '__anext__', - ]), - 'inherited-members': True, - 'member-order': 'groupwise', + "special-members": ",".join( + [ + "__init__", + "__call__", + "__getitem__", + "__iter__", + "__aiter__", + "__next__", + "__anext__", + ] + ), + "inherited-members": True, + "member-order": "groupwise", # hide private methods that aren't relevant to us here - 'exclude-members': ','.join([ - '__new__', - '_abc_impl', - '_fget', - '_fset', - '_fdel', - '_ASyncSingletonMeta__instances', - '_ASyncSingletonMeta__lock', - ]), + "exclude-members": ",".join( + [ + "__new__", + "_abc_impl", + "_fget", + "_fset", + "_fdel", + "_ASyncSingletonMeta__instances", + "_ASyncSingletonMeta__lock", + ] + ), } autodoc_typehints = "description" # Don't show class signature with the class' name. @@ -76,7 +82,7 @@ automodule_generate_module_stub = True -sys.path.insert(0, os.path.abspath('./dank_mids')) +sys.path.insert(0, os.path.abspath("./dank_mids")) def skip_specific_members(app, what, name, obj, skip, options): @@ -84,23 +90,47 @@ def skip_specific_members(app, what, name, obj, skip, options): Function to exclude specific members for a particular module. """ exclusions = { - 'dank_mids.types': {'__iter__', 'get', 'update', 'clear', 'copy', 'keys', 'values', 'items', 'fromkeys', 'pop', 'popitem', 'setdefault'}, + "dank_mids.types": { + "__iter__", + "get", + "update", + "clear", + "copy", + "keys", + "values", + "items", + "fromkeys", + "pop", + "popitem", + "setdefault", + }, } - - current_module = getattr(obj, '__module__', None) + + current_module = getattr(obj, "__module__", None) logger.info(f"module: {current_module} name: {name} obj: {obj}") if current_module in exclusions and name in exclusions[current_module]: return True # Skip the __init__ and __call__ members of any NewType objects we defined. - if current_module == "typing" and hasattr(obj, "__self__") and type(obj.__self__).__name__ == "NewType" and name in ["__init__", "__call__"]: + if ( + current_module == "typing" + and hasattr(obj, "__self__") + and type(obj.__self__).__name__ == "NewType" + and name in ["__init__", "__call__"] + ): return True - + # Skip the __init__, args, and with_traceback members of all Exceptions - if current_module is None and hasattr(obj, '__objclass__') and issubclass(obj.__objclass__, BaseException) and name in ["__init__", "args", "with_traceback"]: + if ( + current_module is None + and hasattr(obj, "__objclass__") + and issubclass(obj.__objclass__, BaseException) + and name in ["__init__", "args", "with_traceback"] + ): return True - + return skip + def setup(app): - app.connect('autodoc-skip-member', skip_specific_members) + app.connect("autodoc-skip-member", skip_specific_members) diff --git a/poetry.lock b/poetry.lock index 7c14ec59..8356df8c 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1307,6 +1307,20 @@ files = [ {file = "dataclassy-0.11.1.tar.gz", hash = "sha256:ad6622cb91e644d13f68768558983fbc22c90a8ff7e355638485d18b9baf1198"}, ] +[[package]] +name = "dictstruct" +version = "0.0.1" +description = "A msgspec.Struct implementation compatible with the standard dictionary API" +optional = false +python-versions = "<3.13,>=3.8" +files = [ + {file = "dictstruct-0.0.1-py3-none-any.whl", hash = "sha256:efa89f4cf077f0b8e7eec325dd833aaa577d92d95d8a9fca18f08cea40bc3cfa"}, + {file = "dictstruct-0.0.1.tar.gz", hash = "sha256:423fae56834bb02b6eb071b474de93c386e310331a8baa7e6adee983a4d5e37b"}, +] + +[package.dependencies] +msgspec = ">=0.18.5" + [[package]] name = "eip712" version = "0.1.0" @@ -1944,6 +1958,22 @@ docs = ["sphinx (>=5.0.0)", "sphinx-rtd-theme (>=1.0.0)", "towncrier (>=21,<22)" lint = ["black (>=23)", "flake8 (==3.8.3)", "isort (>=5.11.0)", "mypy (==0.971)", "pydocstyle (>=5.0.0)", "types-setuptools"] test = ["hypothesis (>=4.43.0)", "mypy (==0.971)", "pytest (>=7.0.0)", "pytest-xdist (>=2.4.0)", "types-setuptools"] +[[package]] +name = "evmspec" +version = "0.0.1" +description = "A collection of msgspec.Struct definitions for use with the Ethereum Virtual Machine" +optional = false +python-versions = "<3.13,>=3.8" +files = [ + {file = "evmspec-0.0.1-py3-none-any.whl", hash = "sha256:6b973d4bf5e7f9cfd99c5725a0ac2c6e5c180b240a7baf6a1f9c3c8b4e867010"}, + {file = "evmspec-0.0.1.tar.gz", hash = "sha256:c66483d14db0314e5ce24583e6cc56ffc50c52b1b3972d30e26c457f6629c520"}, +] + +[package.dependencies] +dictstruct = ">=0.0.1" +hexbytes = "*" +typing_extensions = ">=4.0.0" + [[package]] name = "execnet" version = "1.9.0" @@ -4210,13 +4240,13 @@ telegram = ["requests"] [[package]] name = "typed-envs" -version = "0.0.4" +version = "0.0.2" description = "typed_envs is used to create specialized EnvironmentVariable objects that behave exactly the same as any other instance of the `typ` used to create them." optional = false python-versions = "*" files = [ - {file = "typed-envs-0.0.4.tar.gz", hash = "sha256:5a71ec50cc1b274c39455885d276a97ae0000eecf0a1d1023bdc5f698fccb6fe"}, - {file = "typed_envs-0.0.4-py3-none-any.whl", hash = "sha256:f995ecbccff283ed6579f33187ea253d9d37fdd652866ec312b8ccc508d4c07b"}, + {file = "typed-envs-0.0.2.tar.gz", hash = "sha256:7113e60f489936344f03c2e65fdbbbb1115cc8e215b823627d446396efcf4327"}, + {file = "typed_envs-0.0.2-py3-none-any.whl", hash = "sha256:8bd4c31e52011cfc7616c09d5b2db239c30190cdff938653c8cd5a05581ca6b8"}, ] [[package]] @@ -4978,4 +5008,4 @@ type = ["pytest-mypy"] [metadata] lock-version = "2.0" python-versions = ">=3.8,<3.13" -content-hash = "e4a181d7d21dea0548e8a59e67732d8616946c96217376ab038d462a7c146ca1" +content-hash = "ac5b33e9bfb86d90ea13df856dd077095e65b07e9bae73200176fbbf8d8f74ff" diff --git a/pyproject.toml b/pyproject.toml index 46628c5e..232e30e6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,10 +13,10 @@ readme = "README.md" python = ">=3.8,<3.13" aiofiles = "*" eth-retry = "^0.1.15" +evmspec = "^0.0.1" ez-a-sync = ">=0.20.7,<0.24" -msgspec = "*" multicall = ">=0.6.2,<1" -typed-envs = ">=0.0.2" +typed-envs = "^0.0.2" web3 = ">=5.27,!=5.29.*,!=5.30.*,!=5.31.1,!=5.31.2,<8" [tool.poetry.group.dev.dependencies]