diff --git a/src/driftpy/accounts/bulk_account_loader.py b/src/driftpy/accounts/bulk_account_loader.py index 2b705df1..66ec03cc 100644 --- a/src/driftpy/accounts/bulk_account_loader.py +++ b/src/driftpy/accounts/bulk_account_loader.py @@ -12,7 +12,7 @@ @dataclass class AccountToLoad: pubkey: Pubkey - callback: Mapping[int, Callable[[bytes, int], None]] + callbacks: Mapping[int, Callable[[bytes, int], None]] @dataclass @@ -44,12 +44,12 @@ def add_account( ) -> int: existing_size = len(self.accounts_to_load) - callback_id = self.callback_id + callback_id = self.get_callback_id() pubkey_str = str(pubkey) existing_account_to_load = self.accounts_to_load.get(pubkey_str) if existing_account_to_load is not None: - existing_account_to_load.callback[callback_id] = callback + existing_account_to_load.callbacks[callback_id] = callback else: callbacks = {} callbacks[callback_id] = callback @@ -65,7 +65,24 @@ def get_callback_id(self) -> int: return self.callback_id def _start_loading(self): - self.task = asyncio.create_task(self.load()) + if self.task is None: + self.task = asyncio.create_task(self.load()) + + def remove_account(self, pubkey: Pubkey, callback_id: int): + pubkey_str = str(pubkey) + existing_account_to_load = self.accounts_to_load.get(pubkey_str) + if existing_account_to_load is not None: + del existing_account_to_load.callbacks[callback_id] + if len(existing_account_to_load.callbacks) == 0: + del self.accounts_to_load[pubkey_str] + + if len(self.accounts_to_load) == 0: + self._stop_loading() + + def _stop_loading(self): + if self.task is not None: + self.task.cancel() + self.task = None def chunks(self, array: List, size: int) -> List[List]: return [array[i : i + size] for i in range(0, len(array), size)] @@ -144,11 +161,6 @@ async def load_chunk(self, chunk: List[List[AccountToLoad]]): def handle_callbacks( self, account_to_load: AccountToLoad, buffer: Optional[bytes], slot: int ): - for cb in account_to_load.callback.values(): + for cb in account_to_load.callbacks.values(): if bytes is not None: cb(buffer, slot) - - def unsubscribe(self): - if self.task is not None: - self.task.cancel() - self.task = None