Skip to content

Commit

Permalink
feat: WeakList implementation to reduce weakrefs overhead
Browse files Browse the repository at this point in the history
  • Loading branch information
BobTheBuidler committed Nov 28, 2024
1 parent 5204df3 commit abb17b0
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 36 deletions.
8 changes: 4 additions & 4 deletions dank_mids/_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,11 +125,11 @@ def coroutines(self) -> Generator[Union["_Batch", Awaitable[RawResponse]], None,
check_len = min(CHECK, self.controller.batcher.step)
# Go thru the multicalls and add calls to the batch
for mcall in self.multicalls.values():
# NOTE: If a multicall has less than `CHECK` calls, we should just throw the calls into a jsonrpc batch individually.
try: # NOTE: This should be faster than using len().
mcall[check_len]
if len(mcall) >= check_len:
working_batch.append(mcall, skip_check=True)
except IndexError:
else:
# NOTE: If a multicall has less than `check_len` calls, we should
# just throw the calls into a jsonrpc batch individually.
working_batch.extend(mcall, skip_check=True)
if working_batch.is_full:
yield working_batch
Expand Down
95 changes: 63 additions & 32 deletions dank_mids/_requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -495,40 +495,73 @@ def semaphore(self) -> a_sync.Semaphore:
_Request = TypeVar("_Request", bound=_RequestMeta)


class _RequestRef(weakref.ref):
    """Weak reference that remembers the ``id()`` key it is stored under.

    A weakref callback receives the (now dead) reference object rather than
    the referent, so the key needed to evict the entry must travel with the
    reference itself.
    """

    __slots__ = ("key",)

    def __new__(cls, ob, callback, key):
        self = weakref.ref.__new__(cls, ob, callback)
        self.key = key
        return self

    def __init__(self, ob, callback, key):
        # `key` is consumed by __new__; weakref.ref.__init__ must not see it.
        super().__init__(ob, callback)


class WeakRequestList(Generic[_Request]):
    """A list-like container that holds its items by weak reference only.

    Items are stored in a dict keyed by ``id(item)``, so membership tests and
    removal are O(1) and appending the same object twice keeps a single entry.
    An entry is evicted automatically when its item is garbage collected.

    Args:
        data: Optional iterable of items to append on construction.
    """

    def __init__(self, data=None):
        self._refs = {}  # Mapping from object ID to weak reference
        if data is not None:
            for item in data:
                self.append(item)

    def _gc_callback(self, item: "weakref.ref[_Request]") -> None:
        """Evict the entry for a collected item.

        Args:
            item: The dead weak reference. ``weakref`` passes the reference
                object to the callback, never the original referent.
        """
        key = item.key
        # Only drop the entry if it still belongs to this exact reference, so
        # an entry stored later under the same id is never removed by mistake.
        if self._refs.get(key) is item:
            del self._refs[key]

    def append(self, item: _Request) -> None:
        """Add ``item``, keeping only a weak reference to it."""
        key = id(item)
        self._refs[key] = _RequestRef(item, self._gc_callback, key)

    def extend(self, items: Iterable[_Request]) -> None:
        """Append every item from ``items``."""
        for item in items:
            self.append(item)

    def __len__(self) -> int:
        return len(self._refs)

    def __bool__(self) -> bool:
        return bool(self._refs)

    def remove(self, item: _Request) -> None:
        """Remove ``item`` from the container.

        Raises:
            ValueError: If ``item`` is not present.
        """
        obj_id = id(item)
        ref = self._refs.get(obj_id)
        # The identity check guards against a stale entry whose id matches.
        if ref is not None and ref() is item:
            del self._refs[obj_id]
        else:
            raise ValueError("list.remove(x): x not in list")

    def __contains__(self, item: _Request) -> bool:
        ref = self._refs.get(id(item))
        return ref is not None and ref() is item

    def __iter__(self) -> Iterator[_Request]:
        """Yield each item that is still alive, in insertion order."""
        # Iterate over a snapshot: an item may be collected while the consumer
        # runs between yields, and the resulting eviction would otherwise raise
        # "dictionary changed size during iteration".
        for ref in tuple(self._refs.values()):
            item = ref()
            if item is not None:
                yield item

    def __repr__(self):
        return f"WeakList([{', '.join(repr(item) for item in self)}])"


class _Batch(_RequestMeta[List[_Response]], Iterable[_Request]):
__slots__ = "_calls", "_lock", "_daemon"
_calls: List["weakref.ref[_Request]"]
__slots__ = "calls", "_lock", "_daemon"
calls: WeakRequestList[_Request]

def __init__(self, controller: "DankMiddlewareController", calls: Iterable[_Request]):
self.controller = controller
self._calls = [weakref.ref(call) for call in calls]
self.calls = WeakRequestList(calls)
self._lock = _AlertingRLock(name=self.__class__.__name__)
super().__init__()

def __bool__(self) -> bool:
try:
next(self.calls)
return True
except StopIteration:
return False

@overload
def __getitem__(self, ix: int) -> _Request: ...
@overload
def __getitem__(self, ix: slice) -> Tuple[_Request, ...]: ...
def __getitem__(self, ix: Union[int, slice]) -> Union[_Request, Tuple[_Request, ...]]:
return tuple(self.calls)[ix]
return bool(self.calls)

def __iter__(self) -> Iterator[_Request]:
return self.calls
return iter(self.calls)

def __len__(self) -> int:
return sum(1 for _ in self.calls)

@property
def calls(self) -> Iterator[_Request]:
"Returns a list of calls. Creates a temporary strong reference to each call in the batch, if it still exists."
return (call for ref in self._calls if (call := ref()))
return len(self.calls)

@property
def bisected(self) -> Generator[Tuple[_Request, ...], None, None]:
Expand All @@ -544,7 +577,7 @@ def is_full(self) -> bool:

def append(self, call: _Request, skip_check: bool = False) -> None:
with self._lock:
self._calls.append(weakref.ref(call))
self.calls.append(call)
# self._len += 1
if not skip_check:
if self.is_full:
Expand All @@ -554,7 +587,7 @@ def append(self, call: _Request, skip_check: bool = False) -> None:

def extend(self, calls: Iterable[_Request], skip_check: bool = False) -> None:
with self._lock:
self._calls.extend(weakref.ref(call) for call in calls)
self.calls.extend(calls)
if not skip_check:
if self.is_full:
self.start()
Expand Down Expand Up @@ -642,7 +675,7 @@ def __repr__(self) -> str:

@cached_property
def block(self) -> BlockId:
return next(self.calls).block
return next(iter(self.calls)).block

@property
def calldata(self) -> str:
Expand Down Expand Up @@ -759,20 +792,18 @@ async def spoof_response(

async def decode(self, data: PartialResponse) -> List[Tuple[bool, bytes]]:
start = time.time()
if ENVS.OPERATION_MODE.infura: # type: ignore [attr-defined]
if ENVS.OPERATION_MODE.infura or len(self) < 100:
# decode synchronously
retval = mcall_decode(data)
else:
try: # NOTE: Quickly check for length without counting each item with `len`.
if not ENVS.OPERATION_MODE.application: # type: ignore [attr-defined]
self[100]
try:
retval = await ENVS.MULTICALL_DECODER_PROCESSES.run(mcall_decode, data) # type: ignore [attr-defined]
except IndexError:
retval = mcall_decode(data)
except BrokenProcessPool:
# TODO: Move this somewhere else
logger.critical("Oh fuck, you broke the %s while decoding %s", ENVS.MULTICALL_DECODER_PROCESSES, data) # type: ignore [attr-defined]
ENVS.MULTICALL_DECODER_PROCESSES = AsyncProcessPoolExecutor(ENVS.MULTICALL_DECODER_PROCESSES._max_workers) # type: ignore [attr-defined,assignment]
retval = mcall_decode(data)

stats.log_duration(f"multicall decoding for {len(self)} calls", start)
# Raise any Exceptions that may have come out of the process pool.
if isinstance(retval, Exception):
Expand Down Expand Up @@ -807,7 +838,7 @@ async def bisect_and_retry(self, e: Exception) -> List[RPCResponse]:

@set_done
async def _exec_single_call(self) -> None:
await next(self.calls).make_request()
await next(iter(self.calls)).make_request()

def _post_future_cleanup(self) -> None:
with suppress(KeyError):
Expand Down

0 comments on commit abb17b0

Please sign in to comment.