Skip to content

Commit

Permalink
add remove_account to bulk account loader
Browse files Browse the repository at this point in the history
  • Loading branch information
crispheaney committed Nov 19, 2023
1 parent a0d3815 commit c74ced4
Showing 1 changed file with 22 additions and 10 deletions.
32 changes: 22 additions & 10 deletions src/driftpy/accounts/bulk_account_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
@dataclass
class AccountToLoad:
pubkey: Pubkey
callback: Mapping[int, Callable[[bytes, int], None]]
callbacks: Mapping[int, Callable[[bytes, int], None]]


@dataclass
Expand Down Expand Up @@ -44,12 +44,12 @@ def add_account(
) -> int:
existing_size = len(self.accounts_to_load)

callback_id = self.callback_id
callback_id = self.get_callback_id()

pubkey_str = str(pubkey)
existing_account_to_load = self.accounts_to_load.get(pubkey_str)
if existing_account_to_load is not None:
existing_account_to_load.callback[callback_id] = callback
existing_account_to_load.callbacks[callback_id] = callback
else:
callbacks = {}
callbacks[callback_id] = callback
Expand All @@ -65,7 +65,24 @@ def get_callback_id(self) -> int:
return self.callback_id

def _start_loading(self):
self.task = asyncio.create_task(self.load())
if self.task is None:
self.task = asyncio.create_task(self.load())

def remove_account(self, pubkey: Pubkey, callback_id: int):
pubkey_str = str(pubkey)
existing_account_to_load = self.accounts_to_load.get(pubkey_str)
if existing_account_to_load is not None:
del existing_account_to_load.callbacks[callback_id]
if len(existing_account_to_load.callbacks) == 0:
del self.accounts_to_load[pubkey_str]

if len(self.accounts_to_load) == 0:
self._stop_loading()

def _stop_loading(self):
if self.task is not None:
self.task.cancel()
self.task = None

def chunks(self, array: List, size: int) -> List[List]:
return [array[i : i + size] for i in range(0, len(array), size)]
Expand Down Expand Up @@ -144,11 +161,6 @@ async def load_chunk(self, chunk: List[List[AccountToLoad]]):
def handle_callbacks(
self, account_to_load: AccountToLoad, buffer: Optional[bytes], slot: int
):
for cb in account_to_load.callback.values():
for cb in account_to_load.callbacks.values():
if bytes is not None:
cb(buffer, slot)

def unsubscribe(self):
if self.task is not None:
self.task.cancel()
self.task = None

0 comments on commit c74ced4

Please sign in to comment.