diff --git a/src/driftpy/accounts/cache/drift_client.py b/src/driftpy/accounts/cache/drift_client.py index 29c20eb9..75905249 100644 --- a/src/driftpy/accounts/cache/drift_client.py +++ b/src/driftpy/accounts/cache/drift_client.py @@ -20,6 +20,9 @@ def __init__(self, program: Program, commitment: Commitment = "confirmed"): self.commitment = commitment self.cache = None + async def subscribe(self): + await self.cache_if_needed() + async def update_cache(self): if self.cache is None: self.cache = {} @@ -83,7 +86,7 @@ async def get_spot_market_and_slot( await self.cache_if_needed() return self.cache["spot_markets"][market_index] - async def get_oracle_data_and_slot( + async def get_oracle_price_data_and_slot( self, oracle: Pubkey ) -> Optional[DataAndSlot[OraclePriceData]]: await self.cache_if_needed() @@ -92,3 +95,6 @@ async def get_oracle_data_and_slot( async def cache_if_needed(self): if self.cache is None: await self.update_cache() + + def unsubscribe(self): + self.cache = None diff --git a/src/driftpy/accounts/cache/user.py b/src/driftpy/accounts/cache/user.py index 6c79f3ab..068cdbe7 100644 --- a/src/driftpy/accounts/cache/user.py +++ b/src/driftpy/accounts/cache/user.py @@ -21,6 +21,9 @@ def __init__( self.user_pubkey = user_pubkey self.user_and_slot = None + async def subscribe(self): + await self.cache_if_needed() + async def update_cache(self): user_and_slot = await get_user_account_and_slot(self.program, self.user_pubkey) self.user_and_slot = user_and_slot @@ -32,3 +35,6 @@ async def get_user_account_and_slot(self) -> Optional[DataAndSlot[User]]: async def cache_if_needed(self): if self.user_and_slot is None: await self.update_cache() + + def unsubscribe(self): + self.user_and_slot = None diff --git a/src/driftpy/accounts/get_accounts.py b/src/driftpy/accounts/get_accounts.py index 31fc8837..812fb687 100644 --- a/src/driftpy/accounts/get_accounts.py +++ b/src/driftpy/accounts/get_accounts.py @@ -1,5 +1,5 @@ import base64 -from typing import cast +from typing import cast, Optional, Callable from solders.pubkey import Pubkey from anchorpy import Program, ProgramAccount from solana.rpc.commitment import Commitment @@ -10,7 +10,10 @@ async def get_account_data_and_slot( - address: Pubkey, program: Program, commitment: Commitment = "processed" + address: Pubkey, + program: Program, + commitment: Commitment = "processed", + decode: Optional[Callable[[bytes], T]] = None, ) -> Optional[DataAndSlot[T]]: account_info = await program.provider.connection.get_account_info( address, @@ -24,7 +27,9 @@ async def get_account_data_and_slot( slot = account_info.context.slot data = account_info.value.data - decoded_data = program.coder.accounts.decode(data) + decoded_data = ( + decode(data) if decode is not None else program.coder.accounts.decode(data) + ) return DataAndSlot(slot, decoded_data) diff --git a/src/driftpy/accounts/oracle.py b/src/driftpy/accounts/oracle.py index 5e2a28c8..1e980139 100644 --- a/src/driftpy/accounts/oracle.py +++ b/src/driftpy/accounts/oracle.py @@ -7,7 +7,6 @@ from solders.pubkey import Pubkey from pythclient.pythaccounts import PythPriceInfo, _ACCOUNT_HEADER_BYTES, EmaType from solana.rpc.async_api import AsyncClient -import base64 import struct @@ -21,26 +20,12 @@ async def get_oracle_price_data_and_slot( if "Pyth" in str(oracle_source): rpc_reponse = await connection.get_account_info(address) rpc_response_slot = rpc_reponse.context.slot - (pyth_price_info, last_slot, twac, twap) = await _parse_pyth_price_info( - rpc_reponse - ) - scale = 1 - if "1K" in str(oracle_source): - scale = 1e3 - elif "1M" in str(oracle_source): - scale = 1e6 - - oracle_data = OraclePriceData( - price=convert_pyth_price(pyth_price_info.price, scale), - slot=pyth_price_info.pub_slot, - confidence=convert_pyth_price(pyth_price_info.confidence_interval, scale), - twap=convert_pyth_price(twap, scale), - twap_confidence=convert_pyth_price(twac, scale), - has_sufficient_number_of_datapoints=True, + oracle_price_data = decode_pyth_price_info( + rpc_reponse.value.data, oracle_source ) - return DataAndSlot(data=oracle_data, slot=rpc_response_slot) + return DataAndSlot(data=oracle_price_data, slot=rpc_response_slot) elif "Quote" in str(oracle_source): return DataAndSlot( data=OraclePriceData(PRICE_PRECISION, 0, 1, 1, 0, True), slot=0 @@ -49,11 +34,10 @@ async def get_oracle_price_data_and_slot( raise NotImplementedError("Unsupported Oracle Source", str(oracle_source)) -async def _parse_pyth_price_info( - resp: GetAccountInfoResp, -) -> (PythPriceInfo, int, int, int): - buffer = resp.value.data - +def decode_pyth_price_info( + buffer: bytes, + oracle_source=OracleSource.PYTH(), +) -> OraclePriceData: offset = _ACCOUNT_HEADER_BYTES _, exponent, _ = struct.unpack_from(" Optional[DataAndSlot[State]]: pass @@ -40,7 +48,7 @@ async def get_spot_market_and_slot( pass @abstractmethod - async def get_oracle_data_and_slot( + async def get_oracle_price_data_and_slot( self, oracle: Pubkey ) -> Optional[DataAndSlot[OraclePriceData]]: pass diff --git a/src/driftpy/accounts/ws/__init__.py b/src/driftpy/accounts/ws/__init__.py new file mode 100644 index 00000000..58e298b4 --- /dev/null +++ b/src/driftpy/accounts/ws/__init__.py @@ -0,0 +1,2 @@ +from .drift_client import * +from .user import * diff --git a/src/driftpy/accounts/ws/account_subscriber.py b/src/driftpy/accounts/ws/account_subscriber.py new file mode 100644 index 00000000..d8862f48 --- /dev/null +++ b/src/driftpy/accounts/ws/account_subscriber.py @@ -0,0 +1,92 @@ +import asyncio +from typing import Optional + +from anchorpy import Program +from solders.pubkey import Pubkey +from solana.rpc.commitment import Commitment + +from driftpy.accounts import get_account_data_and_slot +from driftpy.accounts import UserAccountSubscriber, DataAndSlot + +import websockets +import websockets.exceptions # force eager imports +from solana.rpc.websocket_api import connect + +from typing import cast, Generic, TypeVar, Callable + +T = TypeVar("T") + + +class WebsocketAccountSubscriber(UserAccountSubscriber, Generic[T]): + def __init__( + self, + pubkey: Pubkey, + program: Program, + commitment: Commitment = "confirmed", + decode: Optional[Callable[[bytes], T]] = None, + ): + self.program = program + self.commitment = commitment + self.pubkey = pubkey + self.data_and_slot = None + self.task = None + self.decode = ( + decode if decode is not None else self.program.coder.accounts.decode + ) + + async def subscribe(self): + if self.data_and_slot is None: + await self.fetch() + + self.task = asyncio.create_task(self.subscribe_ws()) + return self.task + + async def subscribe_ws(self): + ws_endpoint = self.program.provider.connection._provider.endpoint_uri.replace( + "https", "wss" + ).replace("http", "ws") + async for ws in connect(ws_endpoint): + try: + await ws.account_subscribe( # type: ignore + self.pubkey, + commitment=self.commitment, + encoding="base64", + ) + first_resp = await ws.recv() + subscription_id = cast(int, first_resp[0].result) # type: ignore + + async for msg in ws: + try: + slot = int(msg[0].result.context.slot) # type: ignore + + if msg[0].result.value is None: + continue + + account_bytes = cast(bytes, msg[0].result.value.data) # type: ignore + decoded_data = self.decode(account_bytes) + self._update_data(DataAndSlot(slot, decoded_data)) + except Exception: + print(f"Error processing account data") + break + await ws.account_unsubscribe(subscription_id) # type: ignore + except websockets.exceptions.ConnectionClosed: + print("Websocket closed, reconnecting...") + continue + + async def fetch(self): + new_data = await get_account_data_and_slot( + self.pubkey, self.program, self.commitment, self.decode + ) + + self._update_data(new_data) + + def _update_data(self, new_data: Optional[DataAndSlot[T]]): + if new_data is None: + return + + if self.data_and_slot is None or new_data.slot > self.data_and_slot.slot: + self.data_and_slot = new_data + + def unsubscribe(self): + self.task.cancel() + self.task = None diff --git a/src/driftpy/accounts/ws/drift_client.py b/src/driftpy/accounts/ws/drift_client.py new file mode 100644 index 00000000..2eaa5572 --- /dev/null +++ b/src/driftpy/accounts/ws/drift_client.py @@ -0,0 +1,121 @@ +from anchorpy import Program +from solders.pubkey import Pubkey +from solana.rpc.commitment import Commitment + +from driftpy.accounts.types import DriftClientAccountSubscriber, DataAndSlot +from typing import Optional + +from driftpy.accounts.ws.account_subscriber import WebsocketAccountSubscriber +from driftpy.types import PerpMarket, SpotMarket, OraclePriceData, State + +from driftpy.addresses import * + +from driftpy.types import OracleSource + +from driftpy.accounts.oracle import decode_pyth_price_info + + +class WebsocketDriftClientAccountSubscriber(DriftClientAccountSubscriber): + def __init__(self, program: Program, commitment: Commitment = "confirmed"): + self.program = program + self.commitment = commitment + self.state_subscriber = None + self.spot_market_subscribers = {} + self.perp_market_subscribers = {} + self.oracle_subscribers = {} + + async def subscribe(self): + state_public_key = get_state_public_key(self.program.program_id) + self.state_subscriber = WebsocketAccountSubscriber[State]( + state_public_key, self.program, self.commitment + ) + await self.state_subscriber.subscribe() + + for i in range(self.state_subscriber.data_and_slot.data.number_of_spot_markets): + await self.subscribe_to_spot_market(i) + + for i in range(self.state_subscriber.data_and_slot.data.number_of_markets): + await self.subscribe_to_perp_market(i) + + async def subscribe_to_spot_market(self, market_index: int): + if market_index in self.spot_market_subscribers: + return + + spot_market_public_key = get_spot_market_public_key( + self.program.program_id, market_index + ) + spot_market_subscriber = WebsocketAccountSubscriber[SpotMarket]( + spot_market_public_key, self.program, self.commitment + ) + await spot_market_subscriber.subscribe() + self.spot_market_subscribers[market_index] = spot_market_subscriber + + spot_market = spot_market_subscriber.data_and_slot.data + await self.subscribe_to_oracle(spot_market.oracle, spot_market.oracle_source) + + async def subscribe_to_perp_market(self, market_index: int): + if market_index in self.perp_market_subscribers: + return + + perp_market_public_key = get_perp_market_public_key( + self.program.program_id, market_index + ) + perp_market_subscriber = WebsocketAccountSubscriber[PerpMarket]( + perp_market_public_key, self.program, self.commitment + ) + await perp_market_subscriber.subscribe() + self.perp_market_subscribers[market_index] = perp_market_subscriber + + perp_market = perp_market_subscriber.data_and_slot.data + await self.subscribe_to_oracle( + perp_market.amm.oracle, perp_market.amm.oracle_source + ) + + async def subscribe_to_oracle(self, oracle: Pubkey, oracle_source: OracleSource): + if oracle == Pubkey.default(): + return + + if str(oracle) in self.oracle_subscribers: + return + + oracle_subscriber = WebsocketAccountSubscriber[OraclePriceData]( + oracle, + self.program, + self.commitment, + self._get_oracle_decode_fn(oracle_source), + ) + await oracle_subscriber.subscribe() + self.oracle_subscribers[str(oracle)] = oracle_subscriber + + def _get_oracle_decode_fn(self, oracle_source: OracleSource): + if "Pyth" in str(oracle_source): + return lambda data: decode_pyth_price_info(data, oracle_source) + else: + raise Exception("Unknown oracle source") + + async def get_state_account_and_slot(self) -> Optional[DataAndSlot[State]]: + return self.state_subscriber.data_and_slot + + async def get_perp_market_and_slot( + self, market_index: int + ) -> Optional[DataAndSlot[PerpMarket]]: + return self.perp_market_subscribers[market_index].data_and_slot + + async def get_spot_market_and_slot( + self, market_index: int + ) -> Optional[DataAndSlot[SpotMarket]]: + return self.spot_market_subscribers[market_index].data_and_slot + + async def get_oracle_price_data_and_slot( + self, oracle: Pubkey + ) -> Optional[DataAndSlot[OraclePriceData]]: + return self.oracle_subscribers[str(oracle)].data_and_slot + + def unsubscribe(self): + self.state_subscriber.unsubscribe() + for spot_market_subscriber in self.spot_market_subscribers.values(): + spot_market_subscriber.unsubscribe() + for perp_market_subscriber in self.perp_market_subscribers.values(): + perp_market_subscriber.unsubscribe() + for oracle_subscriber in self.oracle_subscribers.values(): + oracle_subscriber.unsubscribe() diff --git a/src/driftpy/accounts/ws/user.py b/src/driftpy/accounts/ws/user.py new file mode 100644 index 00000000..425958ed --- /dev/null +++ b/src/driftpy/accounts/ws/user.py @@ -0,0 +1,14 @@ +from typing import Optional + +from driftpy.accounts import DataAndSlot +from driftpy.types import User + +from driftpy.accounts.ws.account_subscriber import WebsocketAccountSubscriber +from driftpy.accounts.types import UserAccountSubscriber + + +class WebsocketUserAccountSubscriber( + WebsocketAccountSubscriber[User], UserAccountSubscriber +): + async def get_user_account_and_slot(self) -> Optional[DataAndSlot[User]]: + return self.data_and_slot diff --git a/src/driftpy/admin.py b/src/driftpy/admin.py index ddd4bb4e..0811bbe2 100644 --- a/src/driftpy/admin.py +++ b/src/driftpy/admin.py @@ -148,14 +148,14 @@ async def initialize_spot_market( ): state_public_key = get_state_public_key(self.program_id) state = await get_state_account(self.program) - spot_index = state.number_of_spot_markets + spot_market_index = state.number_of_spot_markets - spot_public_key = get_spot_market_public_key(self.program_id, spot_index) + spot_public_key = get_spot_market_public_key(self.program_id, spot_market_index) spot_vault_public_key = get_spot_market_vault_public_key( - self.program_id, spot_index + self.program_id, spot_market_index ) insurance_vault_public_key = get_insurance_fund_vault_public_key( - self.program_id, spot_index + self.program_id, spot_market_index ) return await self.program.rpc["initialize_spot_market"]( diff --git a/src/driftpy/drift_client.py b/src/driftpy/drift_client.py index 1b00cf59..d9f6d432 100644 --- a/src/driftpy/drift_client.py +++ b/src/driftpy/drift_client.py @@ -29,7 +29,7 @@ from driftpy.math.positions import is_available, is_spot_position_available from driftpy.accounts import DriftClientAccountSubscriber -from driftpy.accounts.cache import CachedDriftClientAccountSubscriber +from driftpy.accounts.ws import WebsocketDriftClientAccountSubscriber DEFAULT_USER_NAME = "Main Account" @@ -72,7 +72,7 @@ def __init__( self.subaccounts = [0] if account_subscriber is None: - account_subscriber = CachedDriftClientAccountSubscriber(self.program) + account_subscriber = WebsocketDriftClientAccountSubscriber(self.program) self.account_subscriber = account_subscriber @@ -116,6 +116,12 @@ def from_config(config: Config, provider: Provider, authority: Keypair = None): return drift_client + async def subscribe(self): + await self.account_subscriber.subscribe() + + def unsubscribe(self): + self.account_subscriber.unsubscribe() + def get_user_account_public_key(self, user_id=0) -> Pubkey: return get_user_account_public_key(self.program_id, self.authority, user_id) @@ -787,11 +793,10 @@ async def place_perp_order( user_id: int = 0, ): return await self.send_ixs( - [ + [ self.get_increase_compute_ix(), - (await self.get_place_perp_order_ix(order_params, user_id))[-1] + (await self.get_place_perp_order_ix(order_params, user_id))[-1], ] - ) async def get_place_perp_order_ix( @@ -805,24 +810,21 @@ async def get_place_perp_order_ix( ) ix = self.program.instruction["place_perp_order"]( - order_params, - ctx=Context( - accounts={ - "state": self.get_state_public_key(), - "user": user_account_public_key, - "authority": self.signer.pubkey(), - }, - remaining_accounts=remaining_accounts, - ), - ) + order_params, + ctx=Context( + accounts={ + "state": self.get_state_public_key(), + "user": user_account_public_key, + "authority": self.signer.pubkey(), + }, + remaining_accounts=remaining_accounts, + ), + ) return ix async def get_place_perp_orders_ix( - self, - order_params: List[OrderParams], - user_id: int = 0, - cancel_all=True + self, order_params: List[OrderParams], user_id: int = 0, cancel_all=True ): user_account_public_key = self.get_user_account_public_key(user_id) writeable_market_indexes = list(set([x.market_index for x in order_params])) @@ -844,7 +846,8 @@ async def get_place_perp_orders_ix( }, remaining_accounts=remaining_accounts, ), - )) + ) + ) for order_param in order_params: ix = self.program.instruction["place_perp_order"]( order_param, diff --git a/src/driftpy/drift_user.py b/src/driftpy/drift_user.py index 1e7d4922..7a854ef3 100644 --- a/src/driftpy/drift_user.py +++ b/src/driftpy/drift_user.py @@ -1,5 +1,5 @@ from driftpy.accounts import UserAccountSubscriber -from driftpy.accounts.cache import CachedUserAccountSubscriber +from driftpy.accounts.ws import WebsocketUserAccountSubscriber from driftpy.drift_client import DriftClient from driftpy.math.positions import * from driftpy.math.margin import * @@ -40,12 +40,18 @@ def __init__( ) if account_subscriber is None: - account_subscriber = CachedUserAccountSubscriber( + account_subscriber = WebsocketUserAccountSubscriber( self.user_public_key, self.program ) self.account_subscriber = account_subscriber + async def subscribe(self): + await self.account_subscriber.subscribe() + + def unsubscribe(self): + self.account_subscriber.unsubscribe() + async def get_spot_oracle_data( self, spot_market: SpotMarket ) -> Optional[OraclePriceData]: @@ -68,16 +74,15 @@ async def get_perp_market(self, market_index: int) -> PerpMarket: async def get_user(self) -> User: return (await self.account_subscriber.get_user_account_and_slot()).data - - async def get_open_orders(self, - # market_type: MarketType, - # market_index: int, - # position_direction: PositionDirection - ): + async def get_open_orders( + self, + # market_type: MarketType, + # market_index: int, + # position_direction: PositionDirection + ): user: User = await self.get_user() return user.orders - async def get_spot_market_liability( self, market_index=None, diff --git a/tests/test.py b/tests/test.py index 4fea2c76..1a5c91ef 100644 --- a/tests/test.py +++ b/tests/test.py @@ -1,3 +1,5 @@ +import asyncio + from pytest import fixture, mark from pytest_asyncio import fixture as async_fixture from solders.keypair import Keypair @@ -12,6 +14,7 @@ SPOT_BALANCE_PRECISION, SPOT_WEIGHT_PRECISION, ) +from driftpy.accounts.cache import CachedUserAccountSubscriber, CachedDriftClientAccountSubscriber from math import sqrt from driftpy.drift_user import DriftUser @@ -83,8 +86,9 @@ def provider(program: Program) -> Provider: @async_fixture(scope="session") async def drift_client(program: Program, usdc_mint: Keypair) -> Admin: - admin = Admin(program) + admin = Admin(program, account_subscriber=CachedDriftClientAccountSubscriber(program)) await admin.initialize(usdc_mint.pubkey(), admin_controls_prices=True) + await admin.subscribe() return admin @@ -130,6 +134,8 @@ async def test_initialized_spot_market_2( maintenance_liability_weight=main_liab_weight, ) + await drift_client.account_subscriber.update_cache() + spot_market = await get_spot_market_account(admin_drift_client.program, 1) assert spot_market.market_index == 1 print(spot_market.market_index) @@ -148,6 +154,9 @@ async def initialized_market(drift_client: Admin, workspace: WorkspaceType) -> P PERIODICITY, ) + + await drift_client.account_subscriber.update_cache() + return sol_usd @@ -207,7 +216,8 @@ async def test_open_orders( drift_client: Admin, ): - drift_user = DriftUser(drift_client) + drift_user = DriftUser(drift_client, account_subscriber=CachedUserAccountSubscriber(drift_client.get_user_account_public_key(), drift_client.program)) + await drift_user.subscribe() user_account = await drift_client.get_user(0) assert(len(user_account.orders)==32) @@ -224,6 +234,7 @@ async def test_open_orders( ixs = await drift_client.get_place_perp_orders_ix([order_params]) await drift_client.send_ixs(ixs) await drift_user.account_subscriber.update_cache() + await asyncio.sleep(1) open_orders_after = await drift_user.get_open_orders() assert(open_orders_after[0].base_asset_amount == BASE_PRECISION) assert(open_orders_after[0].order_id == 1) @@ -376,7 +387,7 @@ async def test_liq_perp( user_account = await drift_client.get_user(0) liq, _ = await _airdrop_user(drift_client.program.provider) - liq_drift_client = DriftClient(drift_client.program, liq) + liq_drift_client = DriftClient(drift_client.program, liq, account_subscriber=CachedDriftClientAccountSubscriber(drift_client.program)) usdc_acc = await _create_and_mint_user_usdc( usdc_mint, drift_client.program.provider, USDC_AMOUNT, liq.pubkey() )