diff --git a/src/driftpy/drift_client.py b/src/driftpy/drift_client.py index 6e2be3fe..0bc4288f 100644 --- a/src/driftpy/drift_client.py +++ b/src/driftpy/drift_client.py @@ -153,7 +153,7 @@ def get_user(self, sub_account_id=0) -> DriftUser: return self.users[sub_account_id] async def get_user_account(self, sub_account_id=0) -> UserAccount: - return await self.get_user(sub_account_id).get_user() + return await self.get_user(sub_account_id).get_user_account() def switch_active_user(self, sub_account_id: int): self.active_sub_account_id = sub_account_id @@ -164,17 +164,21 @@ def get_state_public_key(self): def get_user_stats_public_key(self): return get_user_stats_account_public_key(self.program_id, self.authority) - async def get_state(self) -> Optional[StateAccount]: + async def get_state_account(self) -> Optional[StateAccount]: state_and_slot = await self.account_subscriber.get_state_account_and_slot() return getattr(state_and_slot, "data", None) - async def get_perp_market(self, market_index: int) -> Optional[PerpMarketAccount]: + async def get_perp_market_account( + self, market_index: int + ) -> Optional[PerpMarketAccount]: perp_market_and_slot = await self.account_subscriber.get_perp_market_and_slot( market_index ) return getattr(perp_market_and_slot, "data", None) - async def get_spot_market(self, market_index: int) -> Optional[SpotMarketAccount]: + async def get_spot_market_account( + self, market_index: int + ) -> Optional[SpotMarketAccount]: spot_market_and_slot = await self.account_subscriber.get_spot_market_and_slot( market_index ) @@ -186,6 +190,18 @@ async def get_oracle_price_data(self, oracle: Pubkey) -> Optional[OraclePriceDat ) return getattr(oracle_price_data_and_slot, "data", None) + async def get_oracle_price_data_for_perp_market( + self, market_index: int + ) -> Optional[OraclePriceData]: + oracle = (await self.get_perp_market_account(market_index)).amm.oracle + return await self.get_oracle_price_data(oracle) + + async def get_oracle_price_data_for_spot_market( + self, market_index: int + ) -> Optional[OraclePriceData]: + oracle = (await self.get_spot_market_account(market_index)).oracle + return await self.get_oracle_price_data(oracle) + async def fetch_market_lookup_table(self) -> AddressLookupTableAccount: if self.market_lookup_table_account is not None: return self.market_lookup_table_account @@ -367,7 +383,7 @@ async def add_perp_market_to_remaining_account_maps( spot_market_account_map: dict[int, AccountMeta], perp_market_account_map: dict[int, AccountMeta], ) -> None: - perp_market_account = await self.get_perp_market(market_index) + perp_market_account = await self.get_perp_market_account(market_index) perp_market_account_map[market_index] = AccountMeta( pubkey=perp_market_account.pubkey, is_signer=False, is_writable=writable @@ -391,7 +407,7 @@ async def add_spot_market_to_remaining_account_maps( oracle_account_map: dict[str, AccountMeta], spot_market_account_map: dict[int, AccountMeta], ) -> None: - spot_market_account = await self.get_spot_market(market_index) + spot_market_account = await self.get_spot_market_account(market_index) spot_market_account_map[market_index] = AccountMeta( pubkey=spot_market_account.pubkey, is_signer=False, is_writable=writable @@ -473,7 +489,7 @@ async def get_withdraw_collateral_ix( reduce_only: bool = False, sub_account_id: int = 0, ): - spot_market = await self.get_spot_market(spot_market_index) + spot_market = await self.get_spot_market_account(spot_market_index) remaining_accounts = await self.get_remaining_accounts( user_accounts=[await self.get_user_account(sub_account_id)], writable_spot_market_indexes=[spot_market_index], @@ -1061,7 +1077,7 @@ async def get_user_position( market_index: int, sub_account_id: int = 0, ) -> Optional[PerpPosition]: - user = await self.get_user(sub_account_id).get_user() + user = await self.get_user(sub_account_id).get_user_account() found = False for position in user.perp_positions: diff --git a/src/driftpy/drift_user.py b/src/driftpy/drift_user.py index edb837e0..dc582030 100644 --- a/src/driftpy/drift_user.py +++ b/src/driftpy/drift_user.py @@ -51,26 +51,27 @@ async def subscribe(self): def unsubscribe(self): self.account_subscriber.unsubscribe() - async def get_spot_oracle_data( - self, spot_market: SpotMarketAccount + async def get_oracle_data_for_spot_market( + self, market_index: int ) -> Optional[OraclePriceData]: - return await self.drift_client.get_oracle_price_data(spot_market.oracle) + return await self.drift_client.get_oracle_price_data_for_spot_market( + market_index + ) - async def get_perp_oracle_data( - self, perp_market: PerpMarketAccount + async def get_oracle_data_for_perp_market( + self, market_index: int ) -> Optional[OraclePriceData]: - return await self.drift_client.get_oracle_price_data(perp_market.amm.oracle) - - async def get_state(self) -> StateAccount: - return await self.drift_client.get_state() + return await self.drift_client.get_oracle_price_data_for_perp_market( + market_index + ) - async def get_spot_market(self, market_index: int) -> SpotMarketAccount: - return await self.drift_client.get_spot_market(market_index) + async def get_perp_market_account(self, market_index: int) -> PerpMarketAccount: + return await self.drift_client.get_perp_market_account(market_index) - async def get_perp_market(self, market_index: int) -> PerpMarketAccount: - return await self.drift_client.get_perp_market(market_index) + async def get_spot_market_account(self, market_index: int) -> SpotMarketAccount: + return await self.drift_client.get_spot_market_account(market_index) - async def get_user(self) -> UserAccount: + async def get_user_account(self) -> UserAccount: return (await self.account_subscriber.get_user_account_and_slot()).data async def get_open_orders( @@ -79,7 +80,7 @@ async def get_open_orders( # market_index: int, # position_direction: PositionDirection ): - user: UserAccount = await self.get_user() + user: UserAccount = await self.get_user_account() return user.orders async def get_spot_market_liability( @@ -89,7 +90,7 @@ async def get_spot_market_liability( liquidation_buffer=None, include_open_orders=None, ): - user = await self.get_user() + user = await self.get_user_account() total_liability = 0 for position in user.spot_positions: if is_spot_position_available(position) or ( @@ -97,7 +98,9 @@ async def get_spot_market_liability( ): continue - spot_market = await self.get_spot_market(position.market_index) + spot_market = await self.drift_client.get_spot_market_account( + position.market_index + ) if position.market_index == QUOTE_SPOT_MARKET_INDEX: if str(position.balance_type) == "SpotBalanceType.Borrow()": @@ -114,7 +117,9 @@ async def get_spot_market_liability( else: continue - oracle_data = await self.get_spot_oracle_data(spot_market) + oracle_data = await self.drift_client.get_oracle_price_data( + spot_market.oracle + ) if not include_open_orders: if str(position.balance_type) == "SpotBalanceType.Borrow()": token_amount = get_token_amount( @@ -166,16 +171,20 @@ async def get_total_perp_liability( liquidation_buffer: Optional[int] = 0, include_open_orders: bool = False, ): - user = await self.get_user() + user = await self.get_user_account() unrealized_pnl = 0 for position in user.perp_positions: - market = await self.get_perp_market(position.market_index) + market = await self.drift_client.get_perp_market_account( + position.market_index + ) if position.lp_shares > 0: pass - price = (await self.get_perp_oracle_data(market)).price + price = ( + await self.drift_client.get_oracle_price_data(market.amm.oracle) + ).price base_asset_amount = ( calculate_worst_case_base_asset_amount(position) if include_open_orders @@ -206,11 +215,11 @@ async def get_total_perp_liability( async def can_be_liquidated(self) -> bool: total_collateral = await self.get_total_collateral() - user = await self.get_user() + user = await self.get_user_account() liquidation_buffer = None if user.being_liquidated: liquidation_buffer = ( - await self.get_state() + await self.drift_client.get_state_account() ).liquidation_margin_buffer_ratio maintenance_req = await self.get_margin_requirement( @@ -265,7 +274,7 @@ async def get_user_spot_position( self, market_index: int, ) -> Optional[SpotPosition]: - user = await self.get_user() + user = await self.get_user_account() found = False for position in user.spot_positions: @@ -285,7 +294,7 @@ async def get_user_position( self, market_index: int, ) -> Optional[PerpPosition]: - user = await self.get_user() + user = await self.get_user_account() found = False for position in user.perp_positions: @@ -304,8 +313,10 @@ async def get_unrealized_pnl( market_index: int = None, with_weight_margin_category: Optional[MarginCategory] = None, ): - user = await self.get_user() - quote_spot_market = await self.get_spot_market(QUOTE_SPOT_MARKET_INDEX) + user = await self.get_user_account() + quote_spot_market = await self.drift_client.get_spot_market_account( + QUOTE_SPOT_MARKET_INDEX + ) unrealized_pnl = 0 position: PerpPosition @@ -313,9 +324,13 @@ async def get_unrealized_pnl( if market_index is not None and position.market_index != market_index: continue - market = await self.get_perp_market(position.market_index) + market = await self.drift_client.get_perp_market_account( + position.market_index + ) - oracle_data = await self.get_perp_oracle_data(market) + oracle_data = await self.drift_client.get_oracle_price_data( + market.amm.oracle + ) position_unrealized_pnl = calculate_position_pnl_with_oracle( market, position, oracle_data, with_funding ) @@ -345,7 +360,7 @@ async def get_spot_market_asset_value( include_open_orders=True, market_index: Optional[int] = None, ): - user = await self.get_user() + user = await self.get_user_account() total_value = 0 for position in user.spot_positions: if is_spot_position_available(position) or ( @@ -353,7 +368,9 @@ async def get_spot_market_asset_value( ): continue - spot_market = await self.get_spot_market(position.market_index) + spot_market = await self.drift_client.get_spot_market_account( + position.market_index + ) if position.market_index == QUOTE_SPOT_MARKET_INDEX: spot_token_value = get_token_amount( @@ -373,7 +390,9 @@ async def get_spot_market_asset_value( total_value += spot_token_value continue - oracle_data = await self.get_spot_oracle_data(spot_market) + oracle_data = await self.drift_client.get_oracle_price_data( + spot_market.oracle + ) if not include_open_orders: token_amount = get_token_amount( @@ -435,11 +454,10 @@ async def get_perp_liq_price( margin_req = await self.get_margin_requirement(MarginCategory.MAINTENANCE) delta_liq = total_collateral - margin_req - perp_market = await self.get_perp_market(perp_market_index) delta_per_baa = delta_liq / (position.base_asset_amount / AMM_RESERVE_PRECISION) oracle_price = ( - await self.get_perp_oracle_data(perp_market) + await self.get_oracle_data_for_perp_market(perp_market_index) ).price / PRICE_PRECISION liq_price = oracle_price - (delta_per_baa / QUOTE_PRECISION) @@ -462,7 +480,7 @@ async def get_spot_liq_price( ) delta_liq = total_collateral - margin_req - spot_market = await self.get_spot_market(spot_market_index) + spot_market = await self.drift_client.get_spot_market_account(spot_market_index) token_amount = get_token_amount( position.scaled_balance, spot_market, position.balance_type ) @@ -491,7 +509,9 @@ async def get_spot_liq_price( case _: raise Exception(f"Invalid balance type: {position.balance_type}") - price = (await self.get_spot_oracle_data(spot_market)).price + price = ( + await self.drift_client.get_oracle_price_data(spot_market.oracle) + ).price liq_price = price + liq_price_delta liq_price /= PRICE_PRECISION diff --git a/tests/test.py b/tests/test.py index f45e8ca4..9abeb6c2 100644 --- a/tests/test.py +++ b/tests/test.py @@ -210,7 +210,7 @@ async def test_usdc_deposit( USDC_AMOUNT, 0, user_usdc_account.pubkey(), user_initialized=True ) await drift_client.get_user(0).account_subscriber.update_cache() - user_account = await drift_client.get_user(0).get_user() + user_account = await drift_client.get_user(0).get_user_account() assert ( user_account.spot_positions[0].scaled_balance == USDC_AMOUNT / QUOTE_PRECISION * SPOT_BALANCE_PRECISION @@ -226,7 +226,7 @@ async def test_open_orders( account_subscription=AccountSubscriptionConfig("cached"), ) await drift_user.subscribe() - user_account = await drift_client.get_user(0).get_user() + user_account = await drift_client.get_user(0).get_user_account() assert len(user_account.orders) == 32 assert user_account.orders[0].market_index == 0 @@ -290,14 +290,14 @@ async def test_add_remove_liquidity( await drift_client.add_liquidity(n_shares, 0) await drift_client.get_user(0).account_subscriber.update_cache() - user_account = await drift_client.get_user(0).get_user() + user_account = await drift_client.get_user(0).get_user_account() assert user_account.perp_positions[0].lp_shares == n_shares await drift_client.settle_lp(drift_client.authority, 0) await drift_client.remove_liquidity(n_shares, 0) await drift_client.get_user(0).account_subscriber.update_cache() - user_account = await drift_client.get_user(0).get_user() + user_account = await drift_client.get_user(0).get_user_account() assert user_account.perp_positions[0].lp_shares == 0 @@ -344,7 +344,7 @@ async def test_open_close_position( # print(tx) await drift_client.get_user(0).account_subscriber.update_cache() - user_account = await drift_client.get_user(0).get_user() + user_account = await drift_client.get_user(0).get_user_account() assert user_account.perp_positions[0].base_asset_amount == baa assert user_account.perp_positions[0].quote_asset_amount < 0 @@ -352,7 +352,7 @@ async def test_open_close_position( await drift_client.close_position(0) await drift_client.get_user(0).account_subscriber.update_cache() - user_account = await drift_client.get_user(0).get_user() + user_account = await drift_client.get_user(0).get_user_account() assert user_account.perp_positions[0].base_asset_amount == 0 assert user_account.perp_positions[0].quote_asset_amount < 0 @@ -394,7 +394,7 @@ async def test_liq_perp( drift_client: Admin, usdc_mint: Keypair, workspace: WorkspaceType ): market = await get_perp_market_account(drift_client.program, 0) - user_account = await drift_client.get_user(0).get_user() + user_account = await drift_client.get_user(0).get_user_account() liq, _ = await _airdrop_user(drift_client.program.provider) liq_drift_client = DriftClient(