diff --git a/src/driftpy/types.py b/src/driftpy/types.py index 3d79f007..87c49900 100644 --- a/src/driftpy/types.py +++ b/src/driftpy/types.py @@ -1,3 +1,4 @@ +import zlib import inspect from dataclasses import dataclass, field @@ -48,6 +49,14 @@ def stack_trace(): return f"{file_name}:{line_number}" +def compress(data: bytes) -> bytes: + return zlib.compress(data, level=9) + + +def decompress(data: bytes) -> bytes: + return zlib.decompress(data) + + @_rust_enum class SwapDirection: Add = constructor() @@ -845,6 +854,12 @@ class UserAccount: padding: list[int] = field(default_factory=lambda: [0] * 21) +@dataclass +class PickledUser: + pubkey: Pubkey + data: bytes + + @dataclass class UserFees: total_fee_paid: int diff --git a/src/driftpy/user_map/user_map.py b/src/driftpy/user_map/user_map.py index b6873539..8c08edac 100644 --- a/src/driftpy/user_map/user_map.py +++ b/src/driftpy/user_map/user_map.py @@ -1,6 +1,6 @@ import asyncio import jsonrpcclient -import traceback +import pickle import base64 from typing import Any, Container, Optional, Dict @@ -14,7 +14,7 @@ from driftpy.drift_user import DriftUser from driftpy.account_subscription_config import AccountSubscriptionConfig -from driftpy.types import OrderRecord, UserAccount +from driftpy.types import OrderRecord, PickledUser, UserAccount, compress, decompress from driftpy.user_map.user_map_config import UserMapConfig, PollingConfig from driftpy.user_map.websocket_sub import WebsocketSubscription @@ -31,11 +31,13 @@ class UserMap(UserMapInterface, DLOBSource): def __init__(self, config: UserMapConfig): self.user_map: Dict[str, DriftUser] = {} + self.raw: Dict[str, bytes] = {} self.last_number_of_sub_accounts = None self.sync_lock = asyncio.Lock() self.drift_client: DriftClient = config.drift_client self.latest_slot: int = 0 self.is_subscribed = False + self.last_dumped_slot = 0 if config.connection: self.connection = config.connection else: @@ -95,6 +97,9 @@ def size(self) -> int: def values(self): return iter(self.user_map.values()) + def clear(self): + self.user_map.clear() + def get_user_authority(self, user_account_public_key: str) -> Optional[Pubkey]: user = self.user_map.get(user_account_public_key) if not user: @@ -162,14 +167,17 @@ async def sync(self) -> None: rpc_response_values = parsed_resp.result["value"] program_account_buffer_map: Dict[str, Container[Any]] = {} + raw: Dict[str, bytes] = {} # parse the gPA data before inserting for program_account in rpc_response_values: pubkey = program_account["pubkey"] - data = decode_user( - base64.b64decode(program_account["account"]["data"][0]) - ) + raw_bytes = base64.b64decode(program_account["account"]["data"][0]) + data = decode_user(raw_bytes) program_account_buffer_map[str(pubkey)] = data + raw[str(pubkey)] = raw_bytes + + self.raw = raw # "idempotent" insert into usermap for pubkey in program_account_buffer_map.keys(): @@ -217,3 +225,27 @@ async def get_DLOB(self, slot: int): def get_slot(self) -> int: return self.latest_slot + + def get_last_dump_filepath(self) -> str: + return f"usermap_{self.last_dumped_slot}.pkl" + + async def load(self, filename: Optional[str] = None): + if not filename: + filename = self.get_last_dump_filepath() + start = filename.index("_") + 1 + end = filename.index(".") + slot = int(filename[start:end]) + with open(filename, "rb") as f: + users: list[PickledUser] = pickle.load(f) + for user in users: + data = decode_user(decompress(user.data)) + await self.add_pubkey(user.pubkey, DataAndSlot(slot, data)) + + def dump(self): + users = [] + for pubkey, user in self.raw.items(): + users.append(PickledUser(pubkey=pubkey, data=compress(user))) + self.last_dumped_slot = self.get_slot() + filename = f"usermap_{self.last_dumped_slot}.pkl" + with open(filename, "wb") as f: + pickle.dump(users, f, pickle.HIGHEST_PROTOCOL)