Skip to content

Commit

Permalink
tweak naming for get_perp_position and get_spot_position
Browse files Browse the repository at this point in the history
  • Loading branch information
crispheaney committed Nov 27, 2023
1 parent 5059ea1 commit 2232c2a
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 21 deletions.
4 changes: 2 additions & 2 deletions examples/limit_order_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ async def main(
market.amm.historical_oracle_data.last_oracle_price / PRICE_PRECISION
)
# current_price = 20.00
current_pos_raw = drift_user.get_user_position(market_index)
current_pos_raw = drift_user.get_perp_position(market_index)
if current_pos_raw is not None:
current_pos = current_pos_raw.base_asset_amount / float(BASE_PRECISION)
else:
Expand All @@ -158,7 +158,7 @@ async def main(
market.historical_oracle_data.last_oracle_price / PRICE_PRECISION
)

spot_pos = await drift_user.get_user_spot_position(market_index)
spot_pos = await drift_user.get_spot_position(market_index)
tokens = get_token_amount(
spot_pos.scaled_balance, market, spot_pos.balance_type
)
Expand Down
2 changes: 1 addition & 1 deletion examples/start_lp.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ async def main(
print("confirming tx...")
await connection.confirm_transaction(sig)

position = dc.get_user_position(market_index)
position = dc.get_perp_position(market_index)
market = await get_perp_market_account(dc.program, market_index)
percent_provided = (position.lp_shares / market.amm.sqrt_k) * 100
print(f"lp shares: {position.lp_shares}")
Expand Down
24 changes: 7 additions & 17 deletions src/driftpy/drift_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1038,44 +1038,34 @@ async def get_settle_lp_ix(
),
)

def get_user_spot_position(
def get_spot_position(
self,
market_index: int,
sub_account_id: int = 0,
) -> Optional[SpotPosition]:
user = self.get_user_account(sub_account_id)

found = False
for position in user.spot_positions:
if (
position.market_index == market_index
and not is_spot_position_available(position)
):
found = True
break
return position

if not found:
return None
return None

return position

def get_user_position(
def get_perp_position(
self,
market_index: int,
sub_account_id: int = 0,
) -> Optional[PerpPosition]:
user = self.get_user(sub_account_id).get_user_account()

found = False
for position in user.perp_positions:
if position.market_index == market_index and not is_available(position):
found = True
break

if not found:
return None
return position

return position
return None

def default_order_params(
self, order_type, market_index, base_asset_amount, direction
Expand Down Expand Up @@ -1815,7 +1805,7 @@ async def close_position(
async def get_close_position_ix(
self, market_index: int, limit_price: int = 0, sub_account_id: int = 0
):
position = self.get_user_position(market_index, sub_account_id)
position = self.get_perp_position(market_index, sub_account_id)
if position is None or position.base_asset_amount == 0:
print("=> user has no position to close...")
return
Expand Down
2 changes: 1 addition & 1 deletion tests/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,5 +446,5 @@ async def test_liq_perp(

# liq takes on position
await liq_drift_client.get_user(0).account_subscriber.update_cache()
position = liq_drift_client.get_user_position(0)
position = liq_drift_client.get_perp_position(0)
assert position.base_asset_amount != 0

0 comments on commit 2232c2a

Please sign in to comment.