Skip to content

Commit

Permalink
feat: implement evmspec (#273)
Browse files Browse the repository at this point in the history
  • Loading branch information
BobTheBuidler authored Nov 9, 2024
1 parent 7c427d2 commit 50b3305
Show file tree
Hide file tree
Showing 8 changed files with 686 additions and 320 deletions.
133 changes: 133 additions & 0 deletions dank_mids/_method.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
from importlib.metadata import version
from typing import Tuple, Type

from typing_extensions import Self
from web3.eth import BaseEth
from web3._utils.method_formatters import ERROR_FORMATTERS, NULL_RESULT_FORMATTERS
from web3._utils.blocks import select_method_for_block_identifier
from web3._utils.rpc_abi import RPC
from web3.method import Method, TFunc, _apply_request_formatters, default_root_munger
from web3.types import BlockIdentifier


WEB3_MAJOR_VERSION = int(version("web3").split(".")[0])

return_as_is = lambda x: x


class MethodNoFormat(Method[TFunc]):
"""Bypasses web3py default result formatters."""

def process_params(self, module, *args, **kwargs):
params = self.input_munger(module, args, kwargs)

if self.method_choice_depends_on_args:
# If the method choice depends on the args that get passed in,
# the first parameter determines which method needs to be called
self.json_rpc_method = self.method_choice_depends_on_args(value=params[0])

pending_or_latest_filter_methods = [
RPC.eth_newPendingTransactionFilter,
RPC.eth_newBlockFilter,
]
if self.json_rpc_method in pending_or_latest_filter_methods:
# For pending or latest filter methods, use params to determine
# which method to call, but don't pass them through with the request
params = []

method = self.method_selector_fn()
request = (method, _apply_request_formatters(params, self.request_formatters(method)))
response_formatters = (
return_as_is,
ERROR_FORMATTERS.get(self.json_rpc_method, return_as_is), # type: ignore [arg-type]
NULL_RESULT_FORMATTERS.get(self.json_rpc_method, return_as_is), # type: ignore [arg-type]
)
return request, response_formatters

@classmethod
def default(cls, method: RPC) -> Self:
return cls(method, [default_root_munger])


def bypass_chainid_formatter(eth: Type[BaseEth]) -> None:
eth._chain_id = MethodNoFormat(RPC.eth_chainId)


def bypass_getbalance_formatter(eth: Type[BaseEth]) -> None:
eth._get_balance = MethodNoFormat(RPC.eth_getBalance, mungers=[eth.block_id_munger])


def bypass_blocknumber_formatter(eth: Type[BaseEth]) -> None:
eth.get_block_number = MethodNoFormat(RPC.eth_blockNumber)


def bypass_transaction_count_formatter(eth: Type[BaseEth]) -> None:
eth._get_transaction_count = MethodNoFormat(
RPC.eth_getTransactionCount, mungers=[eth.block_id_munger]
)


def bypass_log_formatter(eth: Type[BaseEth]) -> None:
eth._get_logs = MethodNoFormat.default(RPC.eth_getLogs)
eth._get_logs_raw = MethodNoFormat.default(f"{RPC.eth_getLogs}_raw")
eth.get_filter_logs = MethodNoFormat.default(RPC.eth_getFilterLogs)
eth.get_filter_changes = MethodNoFormat.default(RPC.eth_getFilterChanges)


def bypass_transaction_receipt_formatter(eth: Type[BaseEth]) -> None:
attr_name = "_transaction_receipt" if WEB3_MAJOR_VERSION >= 6 else "_get_transaction_receipt"
setattr(eth, attr_name, MethodNoFormat.default(RPC.eth_getTransactionReceipt))


def bypass_transaction_formatter(eth: Type[BaseEth]) -> None:
eth._get_transaction = MethodNoFormat.default(RPC.eth_getTransactionByHash)


_block_selectors = dict(
if_predefined=RPC.eth_getBlockByNumber,
if_hash=RPC.eth_getBlockByHash,
if_number=RPC.eth_getBlockByNumber,
)


def bypass_block_formatters(eth: Type[BaseEth]) -> None:
if WEB3_MAJOR_VERSION >= 6:
get_block_munger = eth.get_block_munger
else:

def get_block_munger(
self, block_identifier: BlockIdentifier, full_transactions: bool = False
) -> Tuple[BlockIdentifier, bool]:
return (block_identifier, full_transactions)

eth._get_block = MethodNoFormat(
method_choice_depends_on_args=select_method_for_block_identifier(**_block_selectors),
mungers=[get_block_munger],
)


def bypass_eth_call_formatter(eth: Type[BaseEth]) -> None:
eth._call = MethodNoFormat(RPC.eth_call, mungers=[eth.call_munger])


def bypass_get_code_formatter(eth: Type[BaseEth]) -> None:
eth._get_code = MethodNoFormat(RPC.eth_getCode, mungers=[eth.block_id_munger])


skip_formatters = (
bypass_chainid_formatter,
bypass_getbalance_formatter,
bypass_blocknumber_formatter,
bypass_transaction_count_formatter,
bypass_eth_call_formatter,
bypass_get_code_formatter,
bypass_log_formatter,
bypass_transaction_receipt_formatter,
bypass_transaction_formatter,
bypass_block_formatters,
)


def bypass_formatters(eth):
for bypass in skip_formatters:
bypass(eth)
90 changes: 55 additions & 35 deletions dank_mids/_requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def _should_batch_method(method: str) -> bool:


class RPCRequest(_RequestMeta[RawResponse]):
__slots__ = "method", "params", "should_batch", "_started", "_retry", "_daemon"
__slots__ = "method", "params", "should_batch", "raw", "_started", "_retry", "_daemon"

def __init__(
self,
Expand All @@ -165,7 +165,12 @@ def __init__(
):
self.controller = controller
"""The DankMiddlewareController that created this request."""
self.method = method
if method[-4:] == "_raw":
self.method = method[:-4]
self.raw = True
else:
self.method = method
self.raw = False
"""The rpc method for this request."""
self.params = params
"""The parameters to send with this request, if any."""
Expand Down Expand Up @@ -245,33 +250,48 @@ async def get_response(self) -> RPCResponse: # type: ignore [override]

# JIT json decoding
if isinstance(self.response, RawResponse):
response = self.response.decode(partial=True).to_dict(self.method)
error: Optional[RPCError]
if error := response.get("error"): # type: ignore [assignment]
if error["message"].lower() in ["invalid request", "parse error"]:
if self.controller._time_of_request_type_change == 0:
self.controller.request_type = Request
self.controller._time_of_request_type_change = time.time()
if time.time() - self.controller._time_of_request_type_change <= 600:
logger.debug(
"your node says the partial request was invalid but its okay, we can use the full jsonrpc spec instead"
)
return await self.controller(self.method, self.params)
response["error"] = dict(error)
response["error"]["dankmids_added_context"] = self.request
# I'm 99.99999% sure that any errd call has no result and we only get this field from mscspec object defs
# But I'll check it anyway to be safe
if result := response.pop("result", None):
response["result"] = result
logger.debug("error response for %s: %s", self, response)
response = self.response.decode(partial=True)
if response.error is None:
if self.raw:
return {"result": response.result}
response_dict = response.to_dict(self.method)
assert "result" in response_dict or "error" in response_dict, (
response_dict,
type(response_dict),
)
return response_dict

if response.error.message.lower() in ["invalid request", "parse error"]:
if self.controller._time_of_request_type_change == 0:
self.controller.request_type = Request
self.controller._time_of_request_type_change = time.time()
if time.time() - self.controller._time_of_request_type_change <= 600:
logger.debug(
"your node says the partial request was invalid but its okay, we can use the full jsonrpc spec instead"
)
method = self.method
if self.raw:
method += "_raw"
return await self.controller(method, self.params)

error = dict(response.error.items())
error["dankmids_added_context"] = self.request

response = response.to_dict(self.method)
response["error"] = error
logger.debug("error response for %s: %s", self, response)
return response

# If we have an Exception here it came from the goofy sync_call thing I need to get rid of.
# We raise it here so it traces back up to the caller
if isinstance(self.response, Exception):
__raise_more_detailed_exc(self.request, self.response)
_raise_more_detailed_exc(self.request, self.response)
# Less optimal decoding
# TODO: refactor this out
assert "result" in self.response or "error" in self.response, (
self.response,
type(self.response),
)
return self.response

@set_done
Expand All @@ -289,7 +309,10 @@ async def get_response_unbatched(self) -> RPCResponse: # type: ignore [override
t.cancel()
for task in done:
return await task
return self.response.decode(partial=True).to_dict(self.method)
response = self.response.decode(partial=True)
retval = {"result": response.result} if self.raw else response.to_dict(self.method)
assert "result" in retval or "error" in retval, (retval, type(retval))
return retval

@set_done
async def spoof_response(self, data: Union[RawResponse, bytes, Exception]) -> None:
Expand Down Expand Up @@ -352,8 +375,12 @@ async def create_duplicate(
# Creating the task before awaiting the new call ensures the new call will grab the semaphore immediately
# and then the task will try to acquire at the very next event loop _run_once cycle
logger.warning("%s got stuck, we're creating a new one", self)
retval = await self.controller(self.method, self.params)
method = self.method
if self.raw:
method += "_raw"
retval = await self.controller(method, self.params)
await self.semaphore.acquire()
assert "result" in retval or "error" in retval, (retval, type(retval))
return retval


Expand Down Expand Up @@ -581,18 +608,11 @@ def mcall_encode(data: List[Tuple[bool, bytes]]) -> bytes:

def mcall_decode(data: PartialResponse) -> Union[List[Tuple[bool, bytes]], Exception]:
try:
decoded = data.decode_result("eth_call")[2:] # type: ignore [arg-type]
decoded = bytes.fromhex(decoded)
return mcall_decoder(decoding.ContextFramesBytesIO(decoded))[2]
return mcall_decoder(decoding.ContextFramesBytesIO(data.decode_result("eth_call")))[2]
except Exception as e:
# NOTE: We need to safely bring any Exceptions back out of the ProcessPool
try:
# We do this goofy thing since we can't `return Exc() from e`
raise e.__class__(
*e.args, data.decode_result() if isinstance(data, PartialResponse) else data
) from e
except Exception as new_e:
return new_e
e.args = (*e.args, data.decode_result() if isinstance(data, PartialResponse) else data)
return e


class Multicall(_Batch[RPCResponse, eth_call]):
Expand Down Expand Up @@ -1208,7 +1228,7 @@ def __format_error(request: PartialRequest, response: PartialResponse) -> Attrib
return AttributeDict.recursive(error)


def __raise_more_detailed_exc(request: PartialRequest, exc: Exception) -> NoReturn:
def _raise_more_detailed_exc(request: PartialRequest, exc: Exception) -> NoReturn:
if isinstance(exc, ClientResponseError):
raise DankMidsClientResponseError(exc, request) from exc
try:
Expand Down
3 changes: 2 additions & 1 deletion dank_mids/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import eth_retry
from eth_typing import BlockNumber, ChecksumAddress
from eth_utils import to_checksum_address
from evmspec._ids import ChainId
from msgspec import Struct
from multicall.constants import MULTICALL2_ADDRESSES, MULTICALL_ADDRESSES
from multicall.multicall import NotSoBrightBatcher
Expand All @@ -21,7 +22,7 @@
from dank_mids._uid import UIDGenerator, _AlertingRLock
from dank_mids.helpers import _codec, _helpers, _session
from dank_mids.semaphores import _MethodQueues, _MethodSemaphores, BlockSemaphore
from dank_mids.types import BlockId, ChainId, PartialRequest, RawResponse, Request
from dank_mids.types import BlockId, PartialRequest, RawResponse, Request

try:
from multicall.constants import MULTICALL3_ADDRESSES
Expand Down
Loading

0 comments on commit 50b3305

Please sign in to comment.