From cebc9c618ce9ace63b3c69b79c061048cbaeddef Mon Sep 17 00:00:00 2001 From: Chris Heaney Date: Mon, 16 Dec 2024 17:12:09 -0500 Subject: [PATCH 1/3] ws test working --- src/driftpy/accounts/ws/drift_client.py | 13 +++-- src/driftpy/admin.py | 54 ++++++++++++++++--- src/driftpy/setup/helpers.py | 4 +- src/driftpy/types.py | 10 ++-- tests/integration/test_oracle_diff_sources.py | 2 +- 5 files changed, 65 insertions(+), 18 deletions(-) diff --git a/src/driftpy/accounts/ws/drift_client.py b/src/driftpy/accounts/ws/drift_client.py index d6f7bc57..1b866766 100644 --- a/src/driftpy/accounts/ws/drift_client.py +++ b/src/driftpy/accounts/ws/drift_client.py @@ -47,7 +47,9 @@ def __init__( self.spot_market_map = None self.perp_market_map = None self.spot_market_oracle_map: dict[int, Pubkey] = {} + self.spot_market_oracle_strings_map: dict[int, str] = {} self.perp_market_oracle_map: dict[int, Pubkey] = {} + self.perp_market_oracle_strings_map: dict[int, str] = {} async def subscribe(self): if self.is_subscribed(): @@ -178,7 +180,7 @@ async def subscribe_to_oracle(self, full_oracle_wrapper: FullOracleWrapper): async def subscribe_to_oracle_info(self, oracle_info: OracleInfo): oracle_id = get_oracle_id(oracle_info.pubkey, oracle_info.source) - if oracle_id == Pubkey.default(): + if oracle_info.pubkey == Pubkey.default(): return if oracle_id in self.oracle_subscribers: @@ -221,7 +223,7 @@ async def add_oracle(self, oracle_info: OracleInfo): if oracle_id in self.oracle_subscribers: return True - if oracle_id == Pubkey.default(): + if oracle_info.pubkey == Pubkey.default(): return True return await self.subscribe_to_oracle_info(oracle_info) @@ -299,6 +301,7 @@ async def _set_perp_oracle_map(self): OracleInfo(oracle, perp_market_account.amm.oracle_source) ) self.perp_market_oracle_map[market_index] = oracle + self.perp_market_oracle_strings_map[market_index] = oracle_id async def _set_spot_oracle_map(self): spot_markets = self.get_spot_market_accounts_and_slots() @@ -315,12 +318,14 @@ async def _set_spot_oracle_map(self): OracleInfo(oracle, spot_market_account.oracle_source) ) self.spot_market_oracle_map[market_index] = oracle + self.spot_market_oracle_strings_map[market_index] = oracle_id def get_oracle_price_data_and_slot_for_perp_market( self, market_index: int ) -> Union[DataAndSlot[OraclePriceData], None]: perp_market_account = self.get_perp_market_and_slot(market_index) oracle = self.perp_market_oracle_map.get(market_index) + oracle_id = self.perp_market_oracle_strings_map.get(market_index) if not perp_market_account or not oracle: return None @@ -328,7 +333,6 @@ def get_oracle_price_data_and_slot_for_perp_market( if perp_market_account.data.amm.oracle != oracle: asyncio.create_task(self._set_perp_oracle_map()) - oracle_id = get_oracle_id(oracle, perp_market_account.data.amm.oracle_source) return self.get_oracle_price_data_and_slot(oracle_id) def get_oracle_price_data_and_slot_for_spot_market( @@ -336,6 +340,7 @@ def get_oracle_price_data_and_slot_for_spot_market( ) -> Union[DataAndSlot[OraclePriceData], None]: spot_market_account = self.get_spot_market_and_slot(market_index) oracle = self.spot_market_oracle_map.get(market_index) + oracle_id = self.spot_market_oracle_strings_map.get(market_index) if not spot_market_account or not oracle: return None @@ -343,5 +348,5 @@ def get_oracle_price_data_and_slot_for_spot_market( if spot_market_account.data.oracle != oracle: asyncio.create_task(self._set_spot_oracle_map()) - oracle_id = get_oracle_id(oracle, spot_market_account.data.oracle_source) return self.get_oracle_price_data_and_slot(oracle_id) + diff --git a/src/driftpy/admin.py b/src/driftpy/admin.py index ef0fc9e1..892316fc 100644 --- a/src/driftpy/admin.py +++ b/src/driftpy/admin.py @@ -9,8 +9,8 @@ from driftpy.drift_client import ( DriftClient, ) -from driftpy.constants.numeric_constants import PEG_PRECISION -from driftpy.types import OracleGuardRails, OracleSource, PrelaunchOracleParams +from driftpy.constants.numeric_constants import BASE_PRECISION, PEG_PRECISION, PRICE_PRECISION +from driftpy.types import AssetTier, ContractTier, OracleGuardRails, OracleSource, PrelaunchOracleParams from driftpy.addresses import * from driftpy.accounts import get_state_account from driftpy.constants.numeric_constants import ( @@ -61,10 +61,24 @@ async def initialize_perp_market( periodicity: int, peg_multiplier: int = PEG_PRECISION, oracle_source: OracleSource = OracleSource.Pyth(), + contract_tier: ContractTier = ContractTier.Speculative(), margin_ratio_initial: int = 2000, margin_ratio_maintenance: int = 500, - liquidation_fee: int = 0, + liquidator_fee: int = 0, + if_liquidator_fee: int = 10000, + imf_factor: int = 0, active_status: bool = True, + base_spread: int = 0, + max_spread: int = 142500, + max_open_interest: int = 0, + max_revenue_withdraw_per_period: int = 0, + quote_max_insurance: int = 0, + order_step_size: int = BASE_PRECISION // 10000, + order_tick_size: int = PRICE_PRECISION // 100000, + min_order_size: int = BASE_PRECISION // 10000, + concentration_coef_scale: int = 1, + curve_update_intensity: int = 0, + amm_jit_intensity: int = 0, name: list = [0] * 32, ) -> Signature: state_public_key = get_state_public_key(self.program.program_id) @@ -81,10 +95,24 @@ async def initialize_perp_market( periodicity, peg_multiplier, oracle_source, + contract_tier, margin_ratio_initial, margin_ratio_maintenance, - liquidation_fee, + liquidator_fee, + if_liquidator_fee, + imf_factor, active_status, + base_spread, + max_spread, + max_open_interest, + max_revenue_withdraw_per_period, + quote_max_insurance, + order_step_size, + order_tick_size, + min_order_size, + concentration_coef_scale, + curve_update_intensity, + amm_jit_intensity, name, ctx=Context( accounts={ @@ -111,7 +139,14 @@ async def initialize_spot_market( initial_liability_weight: int = SPOT_WEIGHT_PRECISION, maintenance_liability_weight: int = SPOT_WEIGHT_PRECISION, imf_factor: int = 0, - liquidation_fee: int = 0, + liquidator_fee: int = 0, + if_liquidation_fee: int = 0, + scale_initial_asset_weight_start: int = 0, + withdraw_guard_threshold: int = 0, + order_tick_size: int = 1, + order_step_size: int = 1, + if_total_factor: int = 0, + asset_tier: AssetTier = AssetTier.COLLATERAL(), active_status: bool = True, name: list = [0] * 32, ): @@ -137,8 +172,15 @@ async def initialize_spot_market( initial_liability_weight, maintenance_liability_weight, imf_factor, - liquidation_fee, + liquidator_fee, + if_liquidation_fee, active_status, + asset_tier, + scale_initial_asset_weight_start, + withdraw_guard_threshold, + order_tick_size, + order_step_size, + if_total_factor, name, ctx=Context( accounts={ diff --git a/src/driftpy/setup/helpers.py b/src/driftpy/setup/helpers.py index 3d667802..083e6206 100644 --- a/src/driftpy/setup/helpers.py +++ b/src/driftpy/setup/helpers.py @@ -334,7 +334,7 @@ async def mock_oracle( async def initialize_sol_spot_market( - admin: Admin, sol_oracle: Pubkey, sol_mint: Pubkey = NATIVE_MINT + admin: Admin, sol_oracle: Pubkey, sol_mint: Pubkey = NATIVE_MINT, oracle_source: OracleSource = OracleSource.Pyth() ): optimal_utilization = SPOT_RATE_PRECISION // 2 optimal_rate = SPOT_RATE_PRECISION * 20 @@ -352,7 +352,7 @@ async def initialize_sol_spot_market( optimal_rate, max_rate, sol_oracle, - OracleSource.Pyth(), + oracle_source, initial_asset_weight, maintenance_asset_weight, initial_liability_weight, diff --git a/src/driftpy/types.py b/src/driftpy/types.py index 611de84a..ea10e876 100644 --- a/src/driftpy/types.py +++ b/src/driftpy/types.py @@ -240,7 +240,6 @@ class MarginCalculationMode: class OracleSource: Pyth = constructor() Switchboard = constructor() - SwitchboardOnDemand = constructor() QuoteAsset = constructor() Pyth1K = constructor() Pyth1M = constructor() @@ -250,6 +249,7 @@ class OracleSource: Pyth1KPull = constructor() Pyth1MPull = constructor() PythStableCoinPull = constructor() + SwitchboardOnDemand = constructor() @_rust_enum @@ -359,16 +359,16 @@ class OrderStatus: class OracleSourceNum: PYTH = 0 + SWITCHBOARD = 6 + QUOTE_ASSET = 7 PYTH_1K = 1 PYTH_1M = 2 + PYTH_STABLE_COIN = 8 + PRELAUNCH = 10 PYTH_PULL = 3 PYTH_1K_PULL = 4 PYTH_1M_PULL = 5 - SWITCHBOARD = 6 - QUOTE_ASSET = 7 - PYTH_STABLE_COIN = 8 PYTH_STABLE_COIN_PULL = 9 - PRELAUNCH = 10 SWITCHBOARD_ON_DEMAND = 11 diff --git a/tests/integration/test_oracle_diff_sources.py b/tests/integration/test_oracle_diff_sources.py index 5eee8715..18ce331f 100644 --- a/tests/integration/test_oracle_diff_sources.py +++ b/tests/integration/test_oracle_diff_sources.py @@ -77,7 +77,7 @@ async def admin_client(program: Program, usdc_mint: Keypair) -> Admin: admin = Admin( program.provider.connection, program.provider.wallet, - account_subscription=AccountSubscriptionConfig("polling"), + account_subscription=AccountSubscriptionConfig("websocket"), perp_market_indexes=market_indexes, spot_market_indexes=spot_market_indexes, ) From 1ff07f32c6c4b901006c3c73b4f11df770e8cde1 Mon Sep 17 00:00:00 2001 From: Chris Heaney Date: Mon, 16 Dec 2024 18:15:08 -0500 Subject: [PATCH 2/3] fix polling test --- src/driftpy/accounts/bulk_account_loader.py | 6 ++ src/driftpy/accounts/polling/drift_client.py | 59 +++++++++---------- tests/integration/test_oracle_diff_sources.py | 1 + 3 files changed, 36 insertions(+), 30 deletions(-) diff --git a/src/driftpy/accounts/bulk_account_loader.py b/src/driftpy/accounts/bulk_account_loader.py index 2826c2c8..0db960af 100644 --- a/src/driftpy/accounts/bulk_account_loader.py +++ b/src/driftpy/accounts/bulk_account_loader.py @@ -59,6 +59,12 @@ def add_account( if existing_size == 0: self._start_loading() + # If the account is already loaded, call the callback immediately + if existing_account_to_load is not None: + buffer_and_slot = self.buffer_and_slot_map.get(pubkey_str) + if buffer_and_slot is not None and buffer_and_slot.buffer is not None: + self.handle_callbacks(existing_account_to_load, buffer_and_slot.buffer, buffer_and_slot.slot) + return callback_id def get_callback_id(self) -> int: diff --git a/src/driftpy/accounts/polling/drift_client.py b/src/driftpy/accounts/polling/drift_client.py index 4f17bf06..b27275bb 100644 --- a/src/driftpy/accounts/polling/drift_client.py +++ b/src/driftpy/accounts/polling/drift_client.py @@ -13,6 +13,7 @@ get_state_public_key, ) from driftpy.constants.config import find_all_market_and_oracles +from driftpy.oracles.oracle_id import get_oracle_id from driftpy.types import ( OracleInfo, OraclePriceData, @@ -49,7 +50,9 @@ def __init__( self.spot_markets = {} self.oracle = {} self.perp_oracle_map: dict[int, Pubkey] = {} + self.perp_oracle_strings_map: dict[int, str] = {} self.spot_oracle_map: dict[int, Pubkey] = {} + self.spot_oracle_strings_map: dict[int, str] = {} async def subscribe(self): if len(self.callbacks) != 0: @@ -138,33 +141,30 @@ def cb(buffer: bytes, slot: int): return cb async def add_oracle(self, oracle: Pubkey, oracle_source: OracleSource): - if oracle == Pubkey.default() or oracle in self.oracle: - return True - - oracle_str = str(oracle) - if oracle_str in self.callbacks: + oracle_id = get_oracle_id(oracle, oracle_source) + if oracle == Pubkey.default() or oracle_id in self.oracle: return True callback_id = self.bulk_account_loader.add_account( - oracle, self._get_oracle_callback(oracle_str, oracle_source) + oracle, self._get_oracle_callback(oracle_id, oracle_source) ) - self.callbacks[oracle_str] = callback_id + self.callbacks[oracle_id] = callback_id - await self._wait_for_oracle(3, oracle_str) + await self._wait_for_oracle(3, oracle_id) return True - async def _wait_for_oracle(self, tries: int, oracle: str): + async def _wait_for_oracle(self, tries: int, oracle_id: str): while tries > 0: await asyncio.sleep(self.bulk_account_loader.frequency) - if oracle in self.bulk_account_loader.buffer_and_slot_map: + if oracle_id in self.oracle: return tries -= 1 print( - f"WARNING: Oracle: {oracle} not found after {tries * self.bulk_account_loader.frequency} seconds, Location: {stack_trace()}" + f"WARNING: Oracle: {oracle_id} not found after {tries * self.bulk_account_loader.frequency} seconds, Location: {stack_trace()}" ) - def _get_oracle_callback(self, oracle_str: str, oracle_source: OracleSource): + def _get_oracle_callback(self, oracle_id: str, oracle_source: OracleSource): decode = get_oracle_decode_fn(oracle_source) def cb(buffer: bytes, slot: int): @@ -172,7 +172,7 @@ def cb(buffer: bytes, slot: int): return decoded_data = decode(buffer) - self.oracle[oracle_str] = DataAndSlot(slot, decoded_data) + self.oracle[oracle_id] = DataAndSlot(slot, decoded_data) return cb @@ -197,9 +197,9 @@ def get_spot_market_and_slot( return self.spot_markets.get(market_index) def get_oracle_price_data_and_slot( - self, oracle: Pubkey + self, oracle_id: str ) -> Optional[DataAndSlot[OraclePriceData]]: - return self.oracle.get(str(oracle)) + return self.oracle.get(oracle_id) def get_market_accounts_and_slots(self) -> list[DataAndSlot[PerpMarketAccount]]: return [ @@ -221,9 +221,12 @@ async def _set_perp_oracle_map(self): market_account = market.data market_index = market_account.market_index oracle = market_account.amm.oracle - if oracle not in self.oracle: - await self.add_oracle(oracle, market_account.amm.oracle_source) + oracle_source = market_account.amm.oracle_source + oracle_id = get_oracle_id(oracle, oracle_source) + if oracle_id not in self.oracle: + await self.add_oracle(oracle, oracle_source) self.perp_oracle_map[market_index] = oracle + self.perp_oracle_strings_map[market_index] = oracle_id async def _set_spot_oracle_map(self): spot_markets = self.get_spot_market_accounts_and_slots() @@ -231,24 +234,19 @@ async def _set_spot_oracle_map(self): market_account = market.data market_index = market_account.market_index oracle = market_account.oracle - if oracle not in self.oracle: - await self.add_oracle(oracle, market_account.oracle_source) + oracle_source = market_account.oracle_source + oracle_id = get_oracle_id(oracle, oracle_source) + if oracle_id not in self.oracle: + await self.add_oracle(oracle, oracle_source) self.spot_oracle_map[market_index] = oracle + self.spot_oracle_strings_map[market_index] = oracle_id def get_oracle_price_data_and_slot_for_perp_market( self, market_index: int ) -> Union[DataAndSlot[OraclePriceData], None]: - print( - "==> PollingDriftClientAccountSubscriber: Getting oracle price data for perp market", - market_index, - ) - print(self.perp_markets) - print(self.spot_markets) perp_market_account = self.get_perp_market_and_slot(market_index) oracle = self.perp_oracle_map.get(market_index) - - print("Perp market account: ", perp_market_account) - print("Oracle: ", oracle) + oracle_id = self.perp_oracle_strings_map.get(market_index) if not perp_market_account or not oracle: return None @@ -256,13 +254,14 @@ def get_oracle_price_data_and_slot_for_perp_market( if str(perp_market_account.data.amm.oracle) != str(oracle): asyncio.create_task(self._set_perp_oracle_map()) - return self.get_oracle_price_data_and_slot(oracle) + return self.get_oracle_price_data_and_slot(oracle_id) def get_oracle_price_data_and_slot_for_spot_market( self, market_index: int ) -> Union[DataAndSlot[OraclePriceData], None]: spot_market_account = self.get_spot_market_and_slot(market_index) oracle = self.spot_oracle_map.get(market_index) + oracle_id = self.spot_oracle_strings_map.get(market_index) if not spot_market_account or not oracle: return None @@ -270,4 +269,4 @@ def get_oracle_price_data_and_slot_for_spot_market( if str(spot_market_account.data.oracle) != str(oracle): asyncio.create_task(self._set_spot_oracle_map()) - return self.get_oracle_price_data_and_slot(oracle) + return self.get_oracle_price_data_and_slot(oracle_id) diff --git a/tests/integration/test_oracle_diff_sources.py b/tests/integration/test_oracle_diff_sources.py index 18ce331f..be0a5eb8 100644 --- a/tests/integration/test_oracle_diff_sources.py +++ b/tests/integration/test_oracle_diff_sources.py @@ -155,6 +155,7 @@ async def test_polling( oracle_infos=oracle_infos, ) await polling_client.subscribe() + print("sub done") # Verify spot market oracles oracle_data_for_spot_market_1 = ( From 4fc0e35a4d9d7df7402cc6ccecffcab4b1d9089c Mon Sep 17 00:00:00 2001 From: Chris Heaney Date: Mon, 16 Dec 2024 18:51:30 -0500 Subject: [PATCH 3/3] handle unsub for oracles --- src/driftpy/accounts/polling/drift_client.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/driftpy/accounts/polling/drift_client.py b/src/driftpy/accounts/polling/drift_client.py index b27275bb..9199a1b4 100644 --- a/src/driftpy/accounts/polling/drift_client.py +++ b/src/driftpy/accounts/polling/drift_client.py @@ -39,6 +39,7 @@ def __init__( self.program = program self.is_subscribed = False self.callbacks: dict[str, int] = {} + self.oracle_callbacks: dict[str, int] = {} self.perp_market_indexes = perp_market_indexes self.spot_market_indexes = spot_market_indexes @@ -148,7 +149,7 @@ async def add_oracle(self, oracle: Pubkey, oracle_source: OracleSource): callback_id = self.bulk_account_loader.add_account( oracle, self._get_oracle_callback(oracle_id, oracle_source) ) - self.callbacks[oracle_id] = callback_id + self.oracle_callbacks[oracle_id] = callback_id await self._wait_for_oracle(3, oracle_id) @@ -181,7 +182,14 @@ async def unsubscribe(self): self.bulk_account_loader.remove_account( Pubkey.from_string(pubkey_str), callback_id ) + + for oracle_id, callback_id in self.oracle_callbacks.items(): + self.bulk_account_loader.remove_account( + Pubkey.from_string(oracle_id.split("-")[0]), callback_id + ) + self.callbacks.clear() + self.oracle_callbacks.clear() def get_state_account_and_slot(self) -> Optional[DataAndSlot[StateAccount]]: return self.state