Skip to content

Commit

Permalink
frank/user stats map (#89)
Browse files Browse the repository at this point in the history
* feat: user stats map

* chore: remove unnecessary async

* fix: default factory for order params defaults
  • Loading branch information
soundsonacid authored Apr 9, 2024
1 parent 7f24dea commit 45cd32e
Show file tree
Hide file tree
Showing 10 changed files with 372 additions and 8 deletions.
21 changes: 21 additions & 0 deletions src/driftpy/accounts/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
UserAccount,
OraclePriceData,
StateAccount,
UserStatsAccount,
)

T = TypeVar("T")
Expand Down Expand Up @@ -115,3 +116,23 @@ async def update_data(self, data: Optional[DataAndSlot[UserAccount]]):
@abstractmethod
def get_user_account_and_slot(self) -> Optional[DataAndSlot[UserAccount]]:
pass


class UserStatsAccountSubscriber:
@abstractmethod
async def subscribe(self):
pass

@abstractmethod
def unsubscribe(self):
pass

@abstractmethod
async def fetch(self):
pass

@abstractmethod
def get_user_stats_account_and_slot(
self,
) -> Optional[DataAndSlot[UserStatsAccount]]:
pass
1 change: 1 addition & 0 deletions src/driftpy/accounts/ws/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from .drift_client import *
from .user import *
from .user_stats import *
10 changes: 8 additions & 2 deletions src/driftpy/accounts/ws/account_subscriber.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,11 @@
from solana.rpc.commitment import Commitment

from driftpy.accounts import get_account_data_and_slot
from driftpy.accounts import UserAccountSubscriber, DataAndSlot
from driftpy.accounts import (
UserAccountSubscriber,
DataAndSlot,
UserStatsAccountSubscriber,
)

import websockets
import websockets.exceptions # force eager imports
Expand All @@ -19,7 +23,9 @@
T = TypeVar("T")


class WebsocketAccountSubscriber(UserAccountSubscriber, Generic[T]):
class WebsocketAccountSubscriber(
UserAccountSubscriber, UserStatsAccountSubscriber, Generic[T]
):
def __init__(
self,
pubkey: Pubkey,
Expand Down
16 changes: 16 additions & 0 deletions src/driftpy/accounts/ws/user_stats.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from typing import Optional

from driftpy.accounts import DataAndSlot
from driftpy.types import UserStatsAccount

from driftpy.accounts.ws.account_subscriber import WebsocketAccountSubscriber
from driftpy.accounts.types import UserStatsAccountSubscriber


class WebsocketUserStatsAccountSubscriber(
WebsocketAccountSubscriber[UserStatsAccount], UserStatsAccountSubscriber
):
def get_user_stats_account_and_slot(
self,
) -> Optional[DataAndSlot[UserStatsAccount]]:
return self.data_and_slot
73 changes: 73 additions & 0 deletions src/driftpy/drift_user_stats.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
from dataclasses import dataclass
from typing import Optional

from solders.pubkey import Pubkey
from solana.rpc.commitment import Commitment, Confirmed

from driftpy.accounts.types import DataAndSlot
from driftpy.types import ReferrerInfo, UserStatsAccount
from driftpy.accounts.ws.user_stats import WebsocketUserStatsAccountSubscriber
from driftpy.addresses import (
get_user_account_public_key,
get_user_stats_account_public_key,
)
from driftpy.drift_client import DriftClient


@dataclass
class UserStatsSubscriptionConfig:
commitment: Commitment = Confirmed
resub_timeout_ms: Optional[int] = None
initial_data: Optional[DataAndSlot[UserStatsAccount]] = None


class DriftUserStats:
def __init__(
self,
drift_client: DriftClient,
user_stats_account_pubkey: Pubkey,
config: UserStatsSubscriptionConfig,
):
self.drift_client = drift_client
self.user_stats_account_pubkey = user_stats_account_pubkey
self.account_subscriber = WebsocketUserStatsAccountSubscriber(
user_stats_account_pubkey,
drift_client.program,
config.commitment,
initial_data=config.initial_data,
)
self.subscribed = False

async def subscribe(self) -> bool:
if self.subscribed:
return

await self.account_subscriber.subscribe()
self.subscribed = True

return self.subscribed

async def fetch_accounts(self):
await self.account_subscriber.fetch()

def unsubscribe(self):
self.account_subscriber.unsubscribe()

def get_account_and_slot(self) -> DataAndSlot[UserStatsAccount]:
return self.account_subscriber.get_user_stats_account_and_slot()

def get_account(self) -> UserStatsAccount:
return self.account_subscriber.get_user_stats_account_and_slot().data

def get_referrer_info(self) -> Optional[ReferrerInfo]:
if self.get_account().referrer == Pubkey.default():
return None
else:
return ReferrerInfo(
get_user_account_public_key(
self.drift_client.program_id, self.get_account().referrer, 0
),
get_user_stats_account_public_key(
self.drift_client.program_id, self.get_account().referrer
),
)
4 changes: 4 additions & 0 deletions src/driftpy/memcmp.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,7 @@ def get_market_type_filter(market_type: MarketType) -> MemcmpOpts:
return MemcmpOpts(
0, base58.b58encode(_account_discriminator("SpotMarket")).decode()
)


def get_user_stats_filter() -> MemcmpOpts:
return MemcmpOpts(0, base58.b58encode(_account_discriminator("UserStats")).decode())
6 changes: 4 additions & 2 deletions src/driftpy/user_map/user_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from driftpy.drift_user import DriftUser
from driftpy.account_subscription_config import AccountSubscriptionConfig

from driftpy.types import UserAccount
from driftpy.types import OrderRecord, UserAccount

from driftpy.user_map.user_map_config import UserMapConfig, PollingConfig
from driftpy.user_map.websocket_sub import WebsocketSubscription
Expand Down Expand Up @@ -127,6 +127,9 @@ async def add_pubkey(

self.user_map[str(user_account_public_key)] = user

async def update_with_order_record(self, record: OrderRecord):
self.must_get(str(record.user))

async def sync(self) -> None:
async with self.sync_lock:
try:
Expand Down Expand Up @@ -199,7 +202,6 @@ async def sync(self) -> None:

except Exception as e:
print(f"Error in UserMap.sync(): {e}")
traceback.print_exc()

# this is used as a callback for ws subscriptions to update data as its streamed
async def update_user_account(self, key: str, data: DataAndSlot[UserAccount]):
Expand Down
14 changes: 12 additions & 2 deletions src/driftpy/user_map/user_map_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,30 +4,40 @@
from typing import Optional, Union
from driftpy.drift_client import DriftClient


@dataclass
class UserAccountFilterCriteria:
# only return users that have open orders
has_open_orders: bool


@dataclass
class PollingConfig:
frequency: int
commitment: Optional[Commitment] = None


@dataclass
class WebsocketConfig:
resub_timeout_ms: Optional[int] = None
commitment: Optional[Commitment] = None


@dataclass
class UserMapConfig:
drift_client: DriftClient
subscription_config: Union[PollingConfig, WebsocketConfig]
# connection object to use specifically for the UserMap.
# connection object to use specifically for the UserMap.
# If None, will use the drift_client's connection
connection: Optional[AsyncClient] = None
# True to skip the initial load of user_accounts via gPA
skip_initial_load: Optional[bool] = None
skip_initial_load: Optional[bool] = False
# True to include idle users when loading.
# Defaults to false to decrease # of accounts subscribed to
include_idle: Optional[bool] = None


@dataclass
class UserStatsMapConfig:
drift_client: DriftClient
connection: Optional[AsyncClient] = None
Loading

0 comments on commit 45cd32e

Please sign in to comment.