Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: implement evmspec #273

Merged
merged 1 commit into from
Nov 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 133 additions & 0 deletions dank_mids/_method.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
from importlib.metadata import version
from typing import Tuple, Type

from typing_extensions import Self
from web3.eth import BaseEth
from web3._utils.method_formatters import ERROR_FORMATTERS, NULL_RESULT_FORMATTERS
from web3._utils.blocks import select_method_for_block_identifier
from web3._utils.rpc_abi import RPC
from web3.method import Method, TFunc, _apply_request_formatters, default_root_munger
from web3.types import BlockIdentifier


WEB3_MAJOR_VERSION = int(version("web3").split(".")[0])

return_as_is = lambda x: x


class MethodNoFormat(Method[TFunc]):
"""Bypasses web3py default result formatters."""

def process_params(self, module, *args, **kwargs):
params = self.input_munger(module, args, kwargs)

if self.method_choice_depends_on_args:
# If the method choice depends on the args that get passed in,
# the first parameter determines which method needs to be called
self.json_rpc_method = self.method_choice_depends_on_args(value=params[0])

pending_or_latest_filter_methods = [
RPC.eth_newPendingTransactionFilter,
RPC.eth_newBlockFilter,
]
if self.json_rpc_method in pending_or_latest_filter_methods:
# For pending or latest filter methods, use params to determine
# which method to call, but don't pass them through with the request
params = []

method = self.method_selector_fn()
request = (method, _apply_request_formatters(params, self.request_formatters(method)))
response_formatters = (
return_as_is,
ERROR_FORMATTERS.get(self.json_rpc_method, return_as_is), # type: ignore [arg-type]
NULL_RESULT_FORMATTERS.get(self.json_rpc_method, return_as_is), # type: ignore [arg-type]
)
return request, response_formatters

@classmethod
def default(cls, method: RPC) -> Self:
return cls(method, [default_root_munger])


def bypass_chainid_formatter(eth: Type[BaseEth]) -> None:
eth._chain_id = MethodNoFormat(RPC.eth_chainId)


def bypass_getbalance_formatter(eth: Type[BaseEth]) -> None:
eth._get_balance = MethodNoFormat(RPC.eth_getBalance, mungers=[eth.block_id_munger])


def bypass_blocknumber_formatter(eth: Type[BaseEth]) -> None:
eth.get_block_number = MethodNoFormat(RPC.eth_blockNumber)


def bypass_transaction_count_formatter(eth: Type[BaseEth]) -> None:
eth._get_transaction_count = MethodNoFormat(
RPC.eth_getTransactionCount, mungers=[eth.block_id_munger]
)


def bypass_log_formatter(eth: Type[BaseEth]) -> None:
eth._get_logs = MethodNoFormat.default(RPC.eth_getLogs)
eth._get_logs_raw = MethodNoFormat.default(f"{RPC.eth_getLogs}_raw")
eth.get_filter_logs = MethodNoFormat.default(RPC.eth_getFilterLogs)
eth.get_filter_changes = MethodNoFormat.default(RPC.eth_getFilterChanges)


def bypass_transaction_receipt_formatter(eth: Type[BaseEth]) -> None:
attr_name = "_transaction_receipt" if WEB3_MAJOR_VERSION >= 6 else "_get_transaction_receipt"
setattr(eth, attr_name, MethodNoFormat.default(RPC.eth_getTransactionReceipt))


def bypass_transaction_formatter(eth: Type[BaseEth]) -> None:
eth._get_transaction = MethodNoFormat.default(RPC.eth_getTransactionByHash)


_block_selectors = dict(
if_predefined=RPC.eth_getBlockByNumber,
if_hash=RPC.eth_getBlockByHash,
if_number=RPC.eth_getBlockByNumber,
)


def bypass_block_formatters(eth: Type[BaseEth]) -> None:
if WEB3_MAJOR_VERSION >= 6:
get_block_munger = eth.get_block_munger
else:

def get_block_munger(
self, block_identifier: BlockIdentifier, full_transactions: bool = False
) -> Tuple[BlockIdentifier, bool]:
return (block_identifier, full_transactions)

eth._get_block = MethodNoFormat(
method_choice_depends_on_args=select_method_for_block_identifier(**_block_selectors),
mungers=[get_block_munger],
)


def bypass_eth_call_formatter(eth: Type[BaseEth]) -> None:
eth._call = MethodNoFormat(RPC.eth_call, mungers=[eth.call_munger])


def bypass_get_code_formatter(eth: Type[BaseEth]) -> None:
eth._get_code = MethodNoFormat(RPC.eth_getCode, mungers=[eth.block_id_munger])


skip_formatters = (
bypass_chainid_formatter,
bypass_getbalance_formatter,
bypass_blocknumber_formatter,
bypass_transaction_count_formatter,
bypass_eth_call_formatter,
bypass_get_code_formatter,
bypass_log_formatter,
bypass_transaction_receipt_formatter,
bypass_transaction_formatter,
bypass_block_formatters,
)


def bypass_formatters(eth):
for bypass in skip_formatters:
bypass(eth)
90 changes: 55 additions & 35 deletions dank_mids/_requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def _should_batch_method(method: str) -> bool:


class RPCRequest(_RequestMeta[RawResponse]):
__slots__ = "method", "params", "should_batch", "_started", "_retry", "_daemon"
__slots__ = "method", "params", "should_batch", "raw", "_started", "_retry", "_daemon"

def __init__(
self,
Expand All @@ -165,7 +165,12 @@ def __init__(
):
self.controller = controller
"""The DankMiddlewareController that created this request."""
self.method = method
if method[-4:] == "_raw":
self.method = method[:-4]
self.raw = True
else:
self.method = method
self.raw = False
"""The rpc method for this request."""
self.params = params
"""The parameters to send with this request, if any."""
Expand Down Expand Up @@ -245,33 +250,48 @@ async def get_response(self) -> RPCResponse: # type: ignore [override]

# JIT json decoding
if isinstance(self.response, RawResponse):
response = self.response.decode(partial=True).to_dict(self.method)
error: Optional[RPCError]
if error := response.get("error"): # type: ignore [assignment]
if error["message"].lower() in ["invalid request", "parse error"]:
if self.controller._time_of_request_type_change == 0:
self.controller.request_type = Request
self.controller._time_of_request_type_change = time.time()
if time.time() - self.controller._time_of_request_type_change <= 600:
logger.debug(
"your node says the partial request was invalid but its okay, we can use the full jsonrpc spec instead"
)
return await self.controller(self.method, self.params)
response["error"] = dict(error)
response["error"]["dankmids_added_context"] = self.request
# I'm 99.99999% sure that any errd call has no result and we only get this field from mscspec object defs
# But I'll check it anyway to be safe
if result := response.pop("result", None):
response["result"] = result
logger.debug("error response for %s: %s", self, response)
response = self.response.decode(partial=True)
if response.error is None:
if self.raw:
return {"result": response.result}
response_dict = response.to_dict(self.method)
assert "result" in response_dict or "error" in response_dict, (
response_dict,
type(response_dict),
)
return response_dict

if response.error.message.lower() in ["invalid request", "parse error"]:
if self.controller._time_of_request_type_change == 0:
self.controller.request_type = Request
self.controller._time_of_request_type_change = time.time()
if time.time() - self.controller._time_of_request_type_change <= 600:
logger.debug(
"your node says the partial request was invalid but its okay, we can use the full jsonrpc spec instead"
)
method = self.method
if self.raw:
method += "_raw"
return await self.controller(method, self.params)

error = dict(response.error.items())
error["dankmids_added_context"] = self.request

response = response.to_dict(self.method)
response["error"] = error
logger.debug("error response for %s: %s", self, response)
return response

# If we have an Exception here it came from the goofy sync_call thing I need to get rid of.
# We raise it here so it traces back up to the caller
if isinstance(self.response, Exception):
__raise_more_detailed_exc(self.request, self.response)
_raise_more_detailed_exc(self.request, self.response)
# Less optimal decoding
# TODO: refactor this out
assert "result" in self.response or "error" in self.response, (
self.response,
type(self.response),
)
return self.response

@set_done
Expand All @@ -289,7 +309,10 @@ async def get_response_unbatched(self) -> RPCResponse: # type: ignore [override
t.cancel()
for task in done:
return await task
return self.response.decode(partial=True).to_dict(self.method)
response = self.response.decode(partial=True)
retval = {"result": response.result} if self.raw else response.to_dict(self.method)
assert "result" in retval or "error" in retval, (retval, type(retval))
return retval

@set_done
async def spoof_response(self, data: Union[RawResponse, bytes, Exception]) -> None:
Expand Down Expand Up @@ -352,8 +375,12 @@ async def create_duplicate(
# Creating the task before awaiting the new call ensures the new call will grab the semaphore immediately
# and then the task will try to acquire at the very next event loop _run_once cycle
logger.warning("%s got stuck, we're creating a new one", self)
retval = await self.controller(self.method, self.params)
method = self.method
if self.raw:
method += "_raw"
retval = await self.controller(method, self.params)
await self.semaphore.acquire()
assert "result" in retval or "error" in retval, (retval, type(retval))
return retval


Expand Down Expand Up @@ -581,18 +608,11 @@ def mcall_encode(data: List[Tuple[bool, bytes]]) -> bytes:

def mcall_decode(data: PartialResponse) -> Union[List[Tuple[bool, bytes]], Exception]:
try:
decoded = data.decode_result("eth_call")[2:] # type: ignore [arg-type]
decoded = bytes.fromhex(decoded)
return mcall_decoder(decoding.ContextFramesBytesIO(decoded))[2]
return mcall_decoder(decoding.ContextFramesBytesIO(data.decode_result("eth_call")))[2]
except Exception as e:
# NOTE: We need to safely bring any Exceptions back out of the ProcessPool
try:
# We do this goofy thing since we can't `return Exc() from e`
raise e.__class__(
*e.args, data.decode_result() if isinstance(data, PartialResponse) else data
) from e
except Exception as new_e:
return new_e
e.args = (*e.args, data.decode_result() if isinstance(data, PartialResponse) else data)
return e


class Multicall(_Batch[RPCResponse, eth_call]):
Expand Down Expand Up @@ -1208,7 +1228,7 @@ def __format_error(request: PartialRequest, response: PartialResponse) -> Attrib
return AttributeDict.recursive(error)


def __raise_more_detailed_exc(request: PartialRequest, exc: Exception) -> NoReturn:
def _raise_more_detailed_exc(request: PartialRequest, exc: Exception) -> NoReturn:
if isinstance(exc, ClientResponseError):
raise DankMidsClientResponseError(exc, request) from exc
try:
Expand Down
3 changes: 2 additions & 1 deletion dank_mids/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import eth_retry
from eth_typing import BlockNumber, ChecksumAddress
from eth_utils import to_checksum_address
from evmspec._ids import ChainId
from msgspec import Struct
from multicall.constants import MULTICALL2_ADDRESSES, MULTICALL_ADDRESSES
from multicall.multicall import NotSoBrightBatcher
Expand All @@ -21,7 +22,7 @@
from dank_mids._uid import UIDGenerator, _AlertingRLock
from dank_mids.helpers import _codec, _helpers, _session
from dank_mids.semaphores import _MethodQueues, _MethodSemaphores, BlockSemaphore
from dank_mids.types import BlockId, ChainId, PartialRequest, RawResponse, Request
from dank_mids.types import BlockId, PartialRequest, RawResponse, Request

try:
from multicall.constants import MULTICALL3_ADDRESSES
Expand Down
Loading
Loading