diff --git a/src/driftpy/accounts/types.py b/src/driftpy/accounts/types.py index d4a11b7c..8c460c1f 100644 --- a/src/driftpy/accounts/types.py +++ b/src/driftpy/accounts/types.py @@ -12,6 +12,7 @@ UserAccount, OraclePriceData, StateAccount, + UserStatsAccount, ) T = TypeVar("T") @@ -115,3 +116,23 @@ async def update_data(self, data: Optional[DataAndSlot[UserAccount]]): @abstractmethod def get_user_account_and_slot(self) -> Optional[DataAndSlot[UserAccount]]: pass + + +class UserStatsAccountSubscriber: + @abstractmethod + async def subscribe(self): + pass + + @abstractmethod + def unsubscribe(self): + pass + + @abstractmethod + async def fetch(self): + pass + + @abstractmethod + def get_user_stats_account_and_slot( + self, + ) -> Optional[DataAndSlot[UserStatsAccount]]: + pass diff --git a/src/driftpy/accounts/ws/__init__.py b/src/driftpy/accounts/ws/__init__.py index 58e298b4..aec82557 100644 --- a/src/driftpy/accounts/ws/__init__.py +++ b/src/driftpy/accounts/ws/__init__.py @@ -1,2 +1,3 @@ from .drift_client import * from .user import * +from .user_stats import * diff --git a/src/driftpy/accounts/ws/account_subscriber.py b/src/driftpy/accounts/ws/account_subscriber.py index be3804e9..4238da44 100644 --- a/src/driftpy/accounts/ws/account_subscriber.py +++ b/src/driftpy/accounts/ws/account_subscriber.py @@ -6,7 +6,11 @@ from solana.rpc.commitment import Commitment from driftpy.accounts import get_account_data_and_slot -from driftpy.accounts import UserAccountSubscriber, DataAndSlot +from driftpy.accounts import ( + UserAccountSubscriber, + DataAndSlot, + UserStatsAccountSubscriber, +) import websockets import websockets.exceptions # force eager imports @@ -19,7 +23,9 @@ T = TypeVar("T") -class WebsocketAccountSubscriber(UserAccountSubscriber, Generic[T]): +class WebsocketAccountSubscriber( + UserAccountSubscriber, UserStatsAccountSubscriber, Generic[T] +): def __init__( self, pubkey: Pubkey, diff --git a/src/driftpy/accounts/ws/user_stats.py b/src/driftpy/accounts/ws/user_stats.py new file mode 100644 index 00000000..2e569377 --- /dev/null +++ b/src/driftpy/accounts/ws/user_stats.py @@ -0,0 +1,16 @@ +from typing import Optional + +from driftpy.accounts import DataAndSlot +from driftpy.types import UserStatsAccount + +from driftpy.accounts.ws.account_subscriber import WebsocketAccountSubscriber +from driftpy.accounts.types import UserStatsAccountSubscriber + + +class WebsocketUserStatsAccountSubscriber( + WebsocketAccountSubscriber[UserStatsAccount], UserStatsAccountSubscriber +): + def get_user_stats_account_and_slot( + self, + ) -> Optional[DataAndSlot[UserStatsAccount]]: + return self.data_and_slot diff --git a/src/driftpy/drift_user_stats.py b/src/driftpy/drift_user_stats.py new file mode 100644 index 00000000..a1cf83f9 --- /dev/null +++ b/src/driftpy/drift_user_stats.py @@ -0,0 +1,73 @@ +from dataclasses import dataclass +from typing import Optional + +from solders.pubkey import Pubkey +from solana.rpc.commitment import Commitment, Confirmed + +from driftpy.accounts.types import DataAndSlot +from driftpy.types import ReferrerInfo, UserStatsAccount +from driftpy.accounts.ws.user_stats import WebsocketUserStatsAccountSubscriber +from driftpy.addresses import ( + get_user_account_public_key, + get_user_stats_account_public_key, +) +from driftpy.drift_client import DriftClient + + +@dataclass +class UserStatsSubscriptionConfig: + commitment: Commitment = Confirmed + resub_timeout_ms: Optional[int] = None + initial_data: Optional[DataAndSlot[UserStatsAccount]] = None + + +class DriftUserStats: + def __init__( + self, + drift_client: DriftClient, + user_stats_account_pubkey: Pubkey, + config: UserStatsSubscriptionConfig, + ): + self.drift_client = drift_client + self.user_stats_account_pubkey = user_stats_account_pubkey + self.account_subscriber = WebsocketUserStatsAccountSubscriber( + user_stats_account_pubkey, + drift_client.program, + config.commitment, + initial_data=config.initial_data, + ) + self.subscribed = False + + async def subscribe(self) -> bool: + if self.subscribed: + return + + await self.account_subscriber.subscribe() + self.subscribed = True + + return self.subscribed + + async def fetch_accounts(self): + await self.account_subscriber.fetch() + + def unsubscribe(self): + self.account_subscriber.unsubscribe() + + def get_account_and_slot(self) -> DataAndSlot[UserStatsAccount]: + return self.account_subscriber.get_user_stats_account_and_slot() + + def get_account(self) -> UserStatsAccount: + return self.account_subscriber.get_user_stats_account_and_slot().data + + def get_referrer_info(self) -> Optional[ReferrerInfo]: + if self.get_account().referrer == Pubkey.default(): + return None + else: + return ReferrerInfo( + get_user_account_public_key( + self.drift_client.program_id, self.get_account().referrer, 0 + ), + get_user_stats_account_public_key( + self.drift_client.program_id, self.get_account().referrer + ), + ) diff --git a/src/driftpy/memcmp.py b/src/driftpy/memcmp.py index e49c7982..c0fee737 100644 --- a/src/driftpy/memcmp.py +++ b/src/driftpy/memcmp.py @@ -26,3 +26,7 @@ def get_market_type_filter(market_type: MarketType) -> MemcmpOpts: return MemcmpOpts( 0, base58.b58encode(_account_discriminator("SpotMarket")).decode() ) + + +def get_user_stats_filter() -> MemcmpOpts: + return MemcmpOpts(0, base58.b58encode(_account_discriminator("UserStats")).decode()) diff --git a/src/driftpy/user_map/user_map.py b/src/driftpy/user_map/user_map.py index 6e16e3c7..b6873539 100644 --- a/src/driftpy/user_map/user_map.py +++ b/src/driftpy/user_map/user_map.py @@ -14,7 +14,7 @@ from driftpy.drift_user import DriftUser from driftpy.account_subscription_config import AccountSubscriptionConfig -from driftpy.types import UserAccount +from driftpy.types import OrderRecord, UserAccount from driftpy.user_map.user_map_config import UserMapConfig, PollingConfig from driftpy.user_map.websocket_sub import WebsocketSubscription @@ -127,6 +127,9 @@ async def add_pubkey( self.user_map[str(user_account_public_key)] = user + async def update_with_order_record(self, record: OrderRecord): + self.must_get(str(record.user)) + async def sync(self) -> None: async with self.sync_lock: try: @@ -199,7 +202,6 @@ async def sync(self) -> None: except Exception as e: print(f"Error in UserMap.sync(): {e}") - traceback.print_exc() # this is used as a callback for ws subscriptions to update data as its streamed async def update_user_account(self, key: str, data: DataAndSlot[UserAccount]): diff --git a/src/driftpy/user_map/user_map_config.py b/src/driftpy/user_map/user_map_config.py index 638f8101..58e22fdb 100644 --- a/src/driftpy/user_map/user_map_config.py +++ b/src/driftpy/user_map/user_map_config.py @@ -4,30 +4,40 @@ from typing import Optional, Union from driftpy.drift_client import DriftClient + @dataclass class UserAccountFilterCriteria: # only return users that have open orders has_open_orders: bool + @dataclass class PollingConfig: frequency: int commitment: Optional[Commitment] = None + @dataclass class WebsocketConfig: resub_timeout_ms: Optional[int] = None commitment: Optional[Commitment] = None + @dataclass class UserMapConfig: drift_client: DriftClient subscription_config: Union[PollingConfig, WebsocketConfig] - # connection object to use specifically for the UserMap. + # connection object to use specifically for the UserMap. # If None, will use the drift_client's connection connection: Optional[AsyncClient] = None # True to skip the initial load of user_accounts via gPA - skip_initial_load: Optional[bool] = None + skip_initial_load: Optional[bool] = False # True to include idle users when loading. # Defaults to false to decrease # of accounts subscribed to include_idle: Optional[bool] = None + + +@dataclass +class UserStatsMapConfig: + drift_client: DriftClient + connection: Optional[AsyncClient] = None diff --git a/src/driftpy/user_map/userstats_map.py b/src/driftpy/user_map/userstats_map.py new file mode 100644 index 00000000..c3424794 --- /dev/null +++ b/src/driftpy/user_map/userstats_map.py @@ -0,0 +1,233 @@ +import asyncio +import base64 +import traceback + +from typing import Dict, Optional +import jsonrpcclient + +from solders.pubkey import Pubkey +from driftpy.accounts.types import DataAndSlot + +from driftpy.addresses import get_user_stats_account_public_key +from driftpy.drift_user_stats import DriftUserStats, UserStatsSubscriptionConfig +from driftpy.memcmp import get_user_stats_filter +from driftpy.types import ( + NewUserRecord, + DepositRecord, + InsuranceFundStakeRecord, + LPRecord, + FundingPaymentRecord, + LiquidationRecord, + SettlePnlRecord, + OrderRecord, + OrderActionRecord, + UserStatsAccount, +) +from driftpy.events.types import WrappedEvent +from driftpy.user_map.user_map_config import UserStatsMapConfig +from driftpy.user_map.user_map import UserMap + + +class UserStatsMap: + def __init__(self, config: UserStatsMapConfig): + self.user_stats_map: Dict[str, DriftUserStats] = {} + + self.sync_lock = asyncio.Lock() + self.drift_client = config.drift_client + self.latest_slot: int = 0 + self.connection = config.connection or config.drift_client.connection + + async def subscribe(self): + if self.size() > 0: + return + + await self.drift_client.subscribe() + + await self.sync() + + async def sync(self): + async with self.sync_lock: + try: + filters = [ + { + "memcmp": { + "offset": 0, + "bytes": f"{get_user_stats_filter().bytes}", + } + } + ] + + rpc_request = jsonrpcclient.request( + "getProgramAccounts", + [ + str(self.drift_client.program_id), + {"filters": filters, "encoding": "base64", "withContext": True}, + ], + ) + + post = self.connection._provider.session.post( + self.connection._provider.endpoint_uri, + json=rpc_request, + headers={"content-encoding": "gzip"}, + ) + + resp = await asyncio.wait_for(post, timeout=10) + + parsed_resp = jsonrpcclient.parse(resp.json()) + + slot = int(parsed_resp.result["context"]["slot"]) + + self.latest_slot = slot + + rpc_response_values = parsed_resp.result["value"] + + program_account_buffer_map: Dict[str, UserStatsAccount] = {} + + for program_account in rpc_response_values: + pubkey = program_account["pubkey"] + buffer = base64.b64decode(program_account["account"]["data"][0]) + data = self.drift_client.program.coder.accounts.decode(buffer) + program_account_buffer_map[str(pubkey)] = data + + for pubkey in program_account_buffer_map.keys(): + data = program_account_buffer_map.get(pubkey) + if not self.has(pubkey): + await self.add_user_stat( + Pubkey.from_string(pubkey), DataAndSlot(slot, data) + ) + else: + self.update_user_stat(pubkey, DataAndSlot(slot, data)) + + await asyncio.sleep(0) + + keys_to_delete = [] + for key in list(self.user_stats_map.keys()): + if key not in program_account_buffer_map: + keys_to_delete.append(key) + await asyncio.sleep(0) + + for key in keys_to_delete: + del self.user_stats_map[key] + + except Exception as e: + print(f"Error in UserStatsMap.sync(): {e}") + traceback.print_exc() + + def unsubscribe(self): + keys = list(self.user_stats_map.keys()) + for key in keys: + self.user_stats_map[key].unsubscribe() + del self.user_stats_map[key] + + async def add_user_stat( + self, + authority: Pubkey, + user_stats: Optional[DataAndSlot[UserStatsAccount]] = None, + ): + user_stat = DriftUserStats( + self.drift_client, + get_user_stats_account_public_key(self.drift_client.program_id, authority), + UserStatsSubscriptionConfig(initial_data=user_stats), + ) + self.user_stats_map[str(authority)] = user_stat + + async def update_user_stat( + self, authority: Pubkey, user_stats: DataAndSlot[UserStatsAccount] + ): + await self.must_get(str(authority), user_stats) + self.user_stats_map[str(authority)] = user_stats + + async def update_with_order_record(self, record: OrderRecord, user_map: UserMap): + user = await user_map.must_get(str(record.user)) + await self.must_get(str(user.get_user_account().authority)) + + async def update_with_event_record( + self, record: WrappedEvent, user_map: Optional[UserMap] = None + ): + if record.event_type == "DepositRecord": + deposit_record: DepositRecord = record + await self.must_get(str(deposit_record.user_authority)) + + elif record.event_type == "FundingPaymentRecord": + funding_payment_record: FundingPaymentRecord = record + await self.must_get(str(funding_payment_record.user_authority)) + + elif record.event_type == "LiquidationRecord": + if not user_map: + return + + liq_record: LiquidationRecord = record + + user = await user_map.must_get(str(liq_record.user)) + await self.must_get(str(user.get_user_account().authority)) + + liquidator = await user_map.must_get(str(liq_record.liquidator)) + await self.must_get(str(liquidator.get_user_account().authority)) + + elif record.event_type == "OrderRecord": + if not user_map: + return + + order_record: OrderRecord = record + await user_map.update_with_order_record(order_record) + + elif record.event_type == "OrderActionRecord": + if not user_map: + return + + action_record: OrderActionRecord = record + + if action_record.taker: + taker = await user_map.must_get(str(action_record.taker)) + await self.must_get(str(taker.get_user_account().authority)) + + if action_record.maker: + maker = await user_map.must_get(str(action_record.maker)) + await self.must_get(str(maker.get_user_account().authority)) + + elif record.event_type == "SettlePnlRecord": + if not user_map: + return + + settle_record: SettlePnlRecord = record + + user = await user_map.must_get(str(settle_record.user)) + await self.must_get(str(user.get_user_account().authority)) + + elif record.event_type == "NewUserRecord": + new_user_record: NewUserRecord = record + + await self.must_get(str(new_user_record.user_authority)) + + elif record.event_type == "LPRecord": + if not user_map: + return + + lp_record: LPRecord = record + + user = await user_map.must_get(str(lp_record.user)) + await self.must_get(str(user.get_user_account().authority)) + + elif record.event_type == "InsuranceFundStakeRecord": + stake_record: InsuranceFundStakeRecord = record + + await self.must_get(str(stake_record.authority)) + + def values(self): + return self.user_stats_map.values() + + def size(self): + return len(self.user_stats_map) + + def has(self, pubkey: str) -> bool: + return pubkey in self.user_stats_map + + def get(self, pubkey: str): + return self.user_stats_map.get(pubkey) + + async def must_get( + self, pubkey: str, user_stats: Optional[DataAndSlot[UserStatsAccount]] = None + ): + if not self.has(pubkey): + await self.add_user_stat(Pubkey.from_string(pubkey), user_stats) + return self.get(pubkey) diff --git a/src/driftpy/user_map/websocket_sub.py b/src/driftpy/user_map/websocket_sub.py index 4b8485b1..9a9a904a 100644 --- a/src/driftpy/user_map/websocket_sub.py +++ b/src/driftpy/user_map/websocket_sub.py @@ -1,11 +1,9 @@ from typing import Callable, Optional, TypeVar -from driftpy.account_subscription_config import AccountSubscriptionConfig from driftpy.accounts.ws.program_account_subscriber import ( WebSocketProgramAccountSubscriber, ) from driftpy.memcmp import get_user_filter, get_non_idle_user_filter from driftpy.accounts.types import UpdateCallback, WebsocketProgramAccountOptions -from driftpy.user_map.types import ConfigType T = TypeVar("T")