Skip to content

Commit

Permalink
add account suffix to getters on drift client and drift user
Browse files Browse the repository at this point in the history
  • Loading branch information
crispheaney committed Nov 26, 2023
1 parent 649d4b7 commit c6037eb
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 51 deletions.
32 changes: 24 additions & 8 deletions src/driftpy/drift_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def get_user(self, sub_account_id=0) -> DriftUser:
return self.users[sub_account_id]

async def get_user_account(self, sub_account_id=0) -> UserAccount:
return await self.get_user(sub_account_id).get_user()
return await self.get_user(sub_account_id).get_user_account()

def switch_active_user(self, sub_account_id: int):
self.active_sub_account_id = sub_account_id
Expand All @@ -164,17 +164,21 @@ def get_state_public_key(self):
def get_user_stats_public_key(self):
return get_user_stats_account_public_key(self.program_id, self.authority)

async def get_state(self) -> Optional[StateAccount]:
async def get_state_account(self) -> Optional[StateAccount]:
state_and_slot = await self.account_subscriber.get_state_account_and_slot()
return getattr(state_and_slot, "data", None)

async def get_perp_market(self, market_index: int) -> Optional[PerpMarketAccount]:
async def get_perp_market_account(
self, market_index: int
) -> Optional[PerpMarketAccount]:
perp_market_and_slot = await self.account_subscriber.get_perp_market_and_slot(
market_index
)
return getattr(perp_market_and_slot, "data", None)

async def get_spot_market(self, market_index: int) -> Optional[SpotMarketAccount]:
async def get_spot_market_account(
self, market_index: int
) -> Optional[SpotMarketAccount]:
spot_market_and_slot = await self.account_subscriber.get_spot_market_and_slot(
market_index
)
Expand All @@ -186,6 +190,18 @@ async def get_oracle_price_data(self, oracle: Pubkey) -> Optional[OraclePriceDat
)
return getattr(oracle_price_data_and_slot, "data", None)

async def get_oracle_price_data_for_perp_market(
self, market_index: int
) -> Optional[OraclePriceData]:
oracle = (await self.get_perp_market_account(market_index)).amm.oracle
return await self.get_oracle_price_data(oracle)

async def get_oracle_price_data_for_spot_market(
self, market_index: int
) -> Optional[OraclePriceData]:
oracle = (await self.get_spot_market_account(market_index)).oracle
return await self.get_oracle_price_data(oracle)

async def fetch_market_lookup_table(self) -> AddressLookupTableAccount:
if self.market_lookup_table_account is not None:
return self.market_lookup_table_account
Expand Down Expand Up @@ -367,7 +383,7 @@ async def add_perp_market_to_remaining_account_maps(
spot_market_account_map: dict[int, AccountMeta],
perp_market_account_map: dict[int, AccountMeta],
) -> None:
perp_market_account = await self.get_perp_market(market_index)
perp_market_account = await self.get_perp_market_account(market_index)

perp_market_account_map[market_index] = AccountMeta(
pubkey=perp_market_account.pubkey, is_signer=False, is_writable=writable
Expand All @@ -391,7 +407,7 @@ async def add_spot_market_to_remaining_account_maps(
oracle_account_map: dict[str, AccountMeta],
spot_market_account_map: dict[int, AccountMeta],
) -> None:
spot_market_account = await self.get_spot_market(market_index)
spot_market_account = await self.get_spot_market_account(market_index)

spot_market_account_map[market_index] = AccountMeta(
pubkey=spot_market_account.pubkey, is_signer=False, is_writable=writable
Expand Down Expand Up @@ -473,7 +489,7 @@ async def get_withdraw_collateral_ix(
reduce_only: bool = False,
sub_account_id: int = 0,
):
spot_market = await self.get_spot_market(spot_market_index)
spot_market = await self.get_spot_market_account(spot_market_index)
remaining_accounts = await self.get_remaining_accounts(
user_accounts=[await self.get_user_account(sub_account_id)],
writable_spot_market_indexes=[spot_market_index],
Expand Down Expand Up @@ -1061,7 +1077,7 @@ async def get_user_position(
market_index: int,
sub_account_id: int = 0,
) -> Optional[PerpPosition]:
user = await self.get_user(sub_account_id).get_user()
user = await self.get_user(sub_account_id).get_user_account()

found = False
for position in user.perp_positions:
Expand Down
92 changes: 56 additions & 36 deletions src/driftpy/drift_user.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,26 +51,27 @@ async def subscribe(self):
def unsubscribe(self):
self.account_subscriber.unsubscribe()

async def get_spot_oracle_data(
self, spot_market: SpotMarketAccount
async def get_oracle_data_for_spot_market(
self, market_index: int
) -> Optional[OraclePriceData]:
return await self.drift_client.get_oracle_price_data(spot_market.oracle)
return await self.drift_client.get_oracle_price_data_for_spot_market(
market_index
)

async def get_perp_oracle_data(
self, perp_market: PerpMarketAccount
async def get_oracle_data_for_perp_market(
self, market_index: int
) -> Optional[OraclePriceData]:
return await self.drift_client.get_oracle_price_data(perp_market.amm.oracle)

async def get_state(self) -> StateAccount:
return await self.drift_client.get_state()
return await self.drift_client.get_oracle_price_data_for_perp_market(
market_index
)

async def get_spot_market(self, market_index: int) -> SpotMarketAccount:
return await self.drift_client.get_spot_market(market_index)
async def get_perp_market_account(self, market_index: int) -> PerpMarketAccount:
return await self.drift_client.get_perp_market_account(market_index)

async def get_perp_market(self, market_index: int) -> PerpMarketAccount:
return await self.drift_client.get_perp_market(market_index)
async def get_spot_market_account(self, market_index: int) -> SpotMarketAccount:
return await self.drift_client.get_spot_market_account(market_index)

async def get_user(self) -> UserAccount:
async def get_user_account(self) -> UserAccount:
return (await self.account_subscriber.get_user_account_and_slot()).data

async def get_open_orders(
Expand All @@ -79,7 +80,7 @@ async def get_open_orders(
# market_index: int,
# position_direction: PositionDirection
):
user: UserAccount = await self.get_user()
user: UserAccount = await self.get_user_account()
return user.orders

async def get_spot_market_liability(
Expand All @@ -89,15 +90,17 @@ async def get_spot_market_liability(
liquidation_buffer=None,
include_open_orders=None,
):
user = await self.get_user()
user = await self.get_user_account()
total_liability = 0
for position in user.spot_positions:
if is_spot_position_available(position) or (
market_index is not None and position.market_index != market_index
):
continue

spot_market = await self.get_spot_market(position.market_index)
spot_market = await self.drift_client.get_spot_market_account(
position.market_index
)

if position.market_index == QUOTE_SPOT_MARKET_INDEX:
if str(position.balance_type) == "SpotBalanceType.Borrow()":
Expand All @@ -114,7 +117,9 @@ async def get_spot_market_liability(
else:
continue

oracle_data = await self.get_spot_oracle_data(spot_market)
oracle_data = await self.drift_client.get_oracle_price_data(
spot_market.oracle
)
if not include_open_orders:
if str(position.balance_type) == "SpotBalanceType.Borrow()":
token_amount = get_token_amount(
Expand Down Expand Up @@ -166,16 +171,20 @@ async def get_total_perp_liability(
liquidation_buffer: Optional[int] = 0,
include_open_orders: bool = False,
):
user = await self.get_user()
user = await self.get_user_account()

unrealized_pnl = 0
for position in user.perp_positions:
market = await self.get_perp_market(position.market_index)
market = await self.drift_client.get_perp_market_account(
position.market_index
)

if position.lp_shares > 0:
pass

price = (await self.get_perp_oracle_data(market)).price
price = (
await self.drift_client.get_oracle_price_data(market.amm.oracle)
).price
base_asset_amount = (
calculate_worst_case_base_asset_amount(position)
if include_open_orders
Expand Down Expand Up @@ -206,11 +215,11 @@ async def get_total_perp_liability(
async def can_be_liquidated(self) -> bool:
total_collateral = await self.get_total_collateral()

user = await self.get_user()
user = await self.get_user_account()
liquidation_buffer = None
if user.being_liquidated:
liquidation_buffer = (
await self.get_state()
await self.drift_client.get_state_account()
).liquidation_margin_buffer_ratio

maintenance_req = await self.get_margin_requirement(
Expand Down Expand Up @@ -265,7 +274,7 @@ async def get_user_spot_position(
self,
market_index: int,
) -> Optional[SpotPosition]:
user = await self.get_user()
user = await self.get_user_account()

found = False
for position in user.spot_positions:
Expand All @@ -285,7 +294,7 @@ async def get_user_position(
self,
market_index: int,
) -> Optional[PerpPosition]:
user = await self.get_user()
user = await self.get_user_account()

found = False
for position in user.perp_positions:
Expand All @@ -304,18 +313,24 @@ async def get_unrealized_pnl(
market_index: int = None,
with_weight_margin_category: Optional[MarginCategory] = None,
):
user = await self.get_user()
quote_spot_market = await self.get_spot_market(QUOTE_SPOT_MARKET_INDEX)
user = await self.get_user_account()
quote_spot_market = await self.drift_client.get_spot_market_account(
QUOTE_SPOT_MARKET_INDEX
)

unrealized_pnl = 0
position: PerpPosition
for position in user.perp_positions:
if market_index is not None and position.market_index != market_index:
continue

market = await self.get_perp_market(position.market_index)
market = await self.drift_client.get_perp_market_account(
position.market_index
)

oracle_data = await self.get_perp_oracle_data(market)
oracle_data = await self.drift_client.get_oracle_price_data(
market.amm.oracle
)
position_unrealized_pnl = calculate_position_pnl_with_oracle(
market, position, oracle_data, with_funding
)
Expand Down Expand Up @@ -345,15 +360,17 @@ async def get_spot_market_asset_value(
include_open_orders=True,
market_index: Optional[int] = None,
):
user = await self.get_user()
user = await self.get_user_account()
total_value = 0
for position in user.spot_positions:
if is_spot_position_available(position) or (
market_index is not None and position.market_index != market_index
):
continue

spot_market = await self.get_spot_market(position.market_index)
spot_market = await self.drift_client.get_spot_market_account(
position.market_index
)

if position.market_index == QUOTE_SPOT_MARKET_INDEX:
spot_token_value = get_token_amount(
Expand All @@ -373,7 +390,9 @@ async def get_spot_market_asset_value(
total_value += spot_token_value
continue

oracle_data = await self.get_spot_oracle_data(spot_market)
oracle_data = await self.drift_client.get_oracle_price_data(
spot_market.oracle
)

if not include_open_orders:
token_amount = get_token_amount(
Expand Down Expand Up @@ -435,11 +454,10 @@ async def get_perp_liq_price(
margin_req = await self.get_margin_requirement(MarginCategory.MAINTENANCE)
delta_liq = total_collateral - margin_req

perp_market = await self.get_perp_market(perp_market_index)
delta_per_baa = delta_liq / (position.base_asset_amount / AMM_RESERVE_PRECISION)

oracle_price = (
await self.get_perp_oracle_data(perp_market)
await self.get_oracle_data_for_perp_market(perp_market_index)
).price / PRICE_PRECISION

liq_price = oracle_price - (delta_per_baa / QUOTE_PRECISION)
Expand All @@ -462,7 +480,7 @@ async def get_spot_liq_price(
)
delta_liq = total_collateral - margin_req

spot_market = await self.get_spot_market(spot_market_index)
spot_market = await self.drift_client.get_spot_market_account(spot_market_index)
token_amount = get_token_amount(
position.scaled_balance, spot_market, position.balance_type
)
Expand Down Expand Up @@ -491,7 +509,9 @@ async def get_spot_liq_price(
case _:
raise Exception(f"Invalid balance type: {position.balance_type}")

price = (await self.get_spot_oracle_data(spot_market)).price
price = (
await self.drift_client.get_oracle_price_data(spot_market.oracle)
).price
liq_price = price + liq_price_delta
liq_price /= PRICE_PRECISION

Expand Down
14 changes: 7 additions & 7 deletions tests/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ async def test_usdc_deposit(
USDC_AMOUNT, 0, user_usdc_account.pubkey(), user_initialized=True
)
await drift_client.get_user(0).account_subscriber.update_cache()
user_account = await drift_client.get_user(0).get_user()
user_account = await drift_client.get_user(0).get_user_account()
assert (
user_account.spot_positions[0].scaled_balance
== USDC_AMOUNT / QUOTE_PRECISION * SPOT_BALANCE_PRECISION
Expand All @@ -226,7 +226,7 @@ async def test_open_orders(
account_subscription=AccountSubscriptionConfig("cached"),
)
await drift_user.subscribe()
user_account = await drift_client.get_user(0).get_user()
user_account = await drift_client.get_user(0).get_user_account()

assert len(user_account.orders) == 32
assert user_account.orders[0].market_index == 0
Expand Down Expand Up @@ -290,14 +290,14 @@ async def test_add_remove_liquidity(

await drift_client.add_liquidity(n_shares, 0)
await drift_client.get_user(0).account_subscriber.update_cache()
user_account = await drift_client.get_user(0).get_user()
user_account = await drift_client.get_user(0).get_user_account()
assert user_account.perp_positions[0].lp_shares == n_shares

await drift_client.settle_lp(drift_client.authority, 0)

await drift_client.remove_liquidity(n_shares, 0)
await drift_client.get_user(0).account_subscriber.update_cache()
user_account = await drift_client.get_user(0).get_user()
user_account = await drift_client.get_user(0).get_user_account()
assert user_account.perp_positions[0].lp_shares == 0


Expand Down Expand Up @@ -344,15 +344,15 @@ async def test_open_close_position(
# print(tx)

await drift_client.get_user(0).account_subscriber.update_cache()
user_account = await drift_client.get_user(0).get_user()
user_account = await drift_client.get_user(0).get_user_account()

assert user_account.perp_positions[0].base_asset_amount == baa
assert user_account.perp_positions[0].quote_asset_amount < 0

await drift_client.close_position(0)

await drift_client.get_user(0).account_subscriber.update_cache()
user_account = await drift_client.get_user(0).get_user()
user_account = await drift_client.get_user(0).get_user_account()
assert user_account.perp_positions[0].base_asset_amount == 0
assert user_account.perp_positions[0].quote_asset_amount < 0

Expand Down Expand Up @@ -394,7 +394,7 @@ async def test_liq_perp(
drift_client: Admin, usdc_mint: Keypair, workspace: WorkspaceType
):
market = await get_perp_market_account(drift_client.program, 0)
user_account = await drift_client.get_user(0).get_user()
user_account = await drift_client.get_user(0).get_user_account()

liq, _ = await _airdrop_user(drift_client.program.provider)
liq_drift_client = DriftClient(
Expand Down

0 comments on commit c6037eb

Please sign in to comment.