diff --git a/dank_mids/_batch.py b/dank_mids/_batch.py
index e6fd85f9..6d433d9a 100644
--- a/dank_mids/_batch.py
+++ b/dank_mids/_batch.py
@@ -125,11 +125,11 @@ def coroutines(self) -> Generator[Union["_Batch", Awaitable[RawResponse]], None,
         check_len = min(CHECK, self.controller.batcher.step)
         # Go thru the multicalls and add calls to the batch
         for mcall in self.multicalls.values():
-            # NOTE: If a multicall has less than `CHECK` calls, we should just throw the calls into a jsonrpc batch individually.
-            try:  # NOTE: This should be faster than using len().
-                mcall[check_len]
+            if len(mcall) >= check_len:
                 working_batch.append(mcall, skip_check=True)
-            except IndexError:
+            else:
+                # NOTE: If a multicall has fewer than `check_len` calls, we should
+                # just throw the calls into a jsonrpc batch individually.
                 working_batch.extend(mcall, skip_check=True)
             if working_batch.is_full:
                 yield working_batch
diff --git a/dank_mids/_requests.py b/dank_mids/_requests.py
index bf154114..a6089040 100644
--- a/dank_mids/_requests.py
+++ b/dank_mids/_requests.py
@@ -495,40 +495,73 @@ def semaphore(self) -> a_sync.Semaphore:
 _Request = TypeVar("_Request", bound=_RequestMeta)
 
 
+class WeakRequestList(Generic[_Request]):
+    def __init__(self, data=None):
+        self._refs = {}  # Mapping from object ID to weak reference
+        if data is not None:
+            for item in data:
+                self.append(item)
+
+    def _gc_callback(self, obj_id: int) -> None:
+        # Called when a weakly-referenced item is garbage collected
+        self._refs.pop(obj_id, None)  # Safely remove the entry if it still exists
+
+    def append(self, item: _Request) -> None:
+        obj_id = id(item)  # weakref passes the dead ref (not the item) to the callback
+        ref = weakref.ref(item, lambda _: self._gc_callback(obj_id))
+        self._refs[obj_id] = ref
+
+    def extend(self, items: Iterable[_Request]) -> None:
+        for item in items:
+            self.append(item)
+
+    def __len__(self) -> int:
+        return len(self._refs)
+
+    def __bool__(self) -> bool:
+        return bool(self._refs)
+
+    def remove(self, item: _Request) -> None:
+        obj_id = id(item)
+        ref = self._refs.get(obj_id)
+        if ref is not None and ref() is item:
+            del self._refs[obj_id]
+        else:
+            raise ValueError("list.remove(x): x not in list")
+
+    def __contains__(self, item: _Request) -> bool:
+        ref = self._refs.get(id(item))
+        return ref is not None and ref() is item
+
+    def __iter__(self) -> Iterator[_Request]:
+        for ref in self._refs.values():
+            item = ref()
+            if item is not None:
+                yield item
+
+    def __repr__(self):
+        # Render like a list repr, prefixed with the class name
+        return f"WeakRequestList([{', '.join(repr(item) for item in self)}])"
+
+
 class _Batch(_RequestMeta[List[_Response]], Iterable[_Request]):
-    __slots__ = "_calls", "_lock", "_daemon"
-    _calls: List["weakref.ref[_Request]"]
+    __slots__ = "calls", "_lock", "_daemon"
+    calls: WeakRequestList[_Request]
 
     def __init__(self, controller: "DankMiddlewareController", calls: Iterable[_Request]):
         self.controller = controller
-        self._calls = [weakref.ref(call) for call in calls]
+        self.calls = WeakRequestList(calls)
         self._lock = _AlertingRLock(name=self.__class__.__name__)
         super().__init__()
 
     def __bool__(self) -> bool:
-        try:
-            next(self.calls)
-            return True
-        except StopIteration:
-            return False
-
-    @overload
-    def __getitem__(self, ix: int) -> _Request: ...
-    @overload
-    def __getitem__(self, ix: slice) -> Tuple[_Request, ...]: ...
-    def __getitem__(self, ix: Union[int, slice]) -> Union[_Request, Tuple[_Request, ...]]:
-        return tuple(self.calls)[ix]
+        return bool(self.calls)
 
     def __iter__(self) -> Iterator[_Request]:
-        return self.calls
+        return iter(self.calls)
 
     def __len__(self) -> int:
-        return sum(1 for _ in self.calls)
-
-    @property
-    def calls(self) -> Iterator[_Request]:
-        "Returns a list of calls. Creates a temporary strong reference to each call in the batch, if it still exists."
-        return (call for ref in self._calls if (call := ref()))
+        return len(self.calls)
 
     @property
     def bisected(self) -> Generator[Tuple[_Request, ...], None, None]:
@@ -544,7 +577,7 @@ def is_full(self) -> bool:
 
     def append(self, call: _Request, skip_check: bool = False) -> None:
         with self._lock:
-            self._calls.append(weakref.ref(call))
+            self.calls.append(call)
             # self._len += 1
             if not skip_check:
                 if self.is_full:
@@ -554,7 +587,7 @@ def append(self, call: _Request, skip_check: bool = False) -> None:
 
     def extend(self, calls: Iterable[_Request], skip_check: bool = False) -> None:
         with self._lock:
-            self._calls.extend(weakref.ref(call) for call in calls)
+            self.calls.extend(calls)
             if not skip_check:
                 if self.is_full:
                     self.start()
@@ -642,7 +675,7 @@ def __repr__(self) -> str:
 
     @cached_property
    def block(self) -> BlockId:
-        return next(self.calls).block
+        return next(iter(self.calls)).block
 
     @property
     def calldata(self) -> str:
@@ -759,20 +792,18 @@ async def spoof_response(
 
     async def decode(self, data: PartialResponse) -> List[Tuple[bool, bytes]]:
         start = time.time()
-        if ENVS.OPERATION_MODE.infura:  # type: ignore [attr-defined]
+        if ENVS.OPERATION_MODE.infura or len(self) < 100:
+            # decode synchronously
             retval = mcall_decode(data)
         else:
-            try:  # NOTE: Quickly check for length without counting each item with `len`.
-                if not ENVS.OPERATION_MODE.application:  # type: ignore [attr-defined]
-                    self[100]
+            try:
                 retval = await ENVS.MULTICALL_DECODER_PROCESSES.run(mcall_decode, data)  # type: ignore [attr-defined]
-            except IndexError:
-                retval = mcall_decode(data)
             except BrokenProcessPool:
                 # TODO: Move this somewhere else
                 logger.critical("Oh fuck, you broke the %s while decoding %s", ENVS.MULTICALL_DECODER_PROCESSES, data)  # type: ignore [attr-defined]
                 ENVS.MULTICALL_DECODER_PROCESSES = AsyncProcessPoolExecutor(ENVS.MULTICALL_DECODER_PROCESSES._max_workers)  # type: ignore [attr-defined,assignment]
                 retval = mcall_decode(data)
+
         stats.log_duration(f"multicall decoding for {len(self)} calls", start)
         # Raise any Exceptions that may have come out of the process pool.
         if isinstance(retval, Exception):
@@ -807,7 +838,7 @@ async def bisect_and_retry(self, e: Exception) -> List[RPCResponse]:
 
     @set_done
     async def _exec_single_call(self) -> None:
-        await next(self.calls).make_request()
+        await next(iter(self.calls)).make_request()
 
     def _post_future_cleanup(self) -> None:
         with suppress(KeyError):
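Not part of the patch — a minimal, self-contained sketch (the Item and MiniWeakList names are illustrative only) of the weak-reference bookkeeping the new WeakRequestList relies on: each entry is a weakref whose GC callback removes the entry by the original object's id, so __len__ and __bool__ stay accurate without keeping the calls alive.

import gc
import weakref


class Item:
    """Hypothetical stand-in for a request object; any weakref-able object works."""


class MiniWeakList:
    """Simplified version of the WeakRequestList pattern above."""

    def __init__(self):
        self._refs = {}  # id(item) -> weakref.ref

    def append(self, item):
        obj_id = id(item)
        # The callback receives the dead weakref, so close over the item's id instead.
        self._refs[obj_id] = weakref.ref(item, lambda _: self._refs.pop(obj_id, None))

    def __len__(self):
        return len(self._refs)

    def __iter__(self):
        return (item for ref in self._refs.values() if (item := ref()) is not None)


calls = MiniWeakList()
keep, drop = Item(), Item()
calls.append(keep)
calls.append(drop)
assert len(calls) == 2

del drop      # release the only strong reference to `drop`
gc.collect()  # the weakref callback prunes the dead entry (immediate in CPython)
assert len(calls) == 1 and list(calls) == [keep]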