From 2232c2a814504f6ad5524616ead7418f5a308702 Mon Sep 17 00:00:00 2001 From: Chris Heaney Date: Sun, 26 Nov 2023 20:14:51 -0500 Subject: [PATCH] tweak naming for get_perp_position and get_spot_position --- examples/limit_order_grid.py | 4 ++-- examples/start_lp.py | 2 +- src/driftpy/drift_client.py | 24 +++++++----------------- tests/test.py | 2 +- 4 files changed, 11 insertions(+), 21 deletions(-) diff --git a/examples/limit_order_grid.py b/examples/limit_order_grid.py index 061da05f..a13b79e1 100644 --- a/examples/limit_order_grid.py +++ b/examples/limit_order_grid.py @@ -140,7 +140,7 @@ async def main( market.amm.historical_oracle_data.last_oracle_price / PRICE_PRECISION ) # current_price = 20.00 - current_pos_raw = drift_user.get_user_position(market_index) + current_pos_raw = drift_user.get_perp_position(market_index) if current_pos_raw is not None: current_pos = current_pos_raw.base_asset_amount / float(BASE_PRECISION) else: @@ -158,7 +158,7 @@ async def main( market.historical_oracle_data.last_oracle_price / PRICE_PRECISION ) - spot_pos = await drift_user.get_user_spot_position(market_index) + spot_pos = await drift_user.get_spot_position(market_index) tokens = get_token_amount( spot_pos.scaled_balance, market, spot_pos.balance_type ) diff --git a/examples/start_lp.py b/examples/start_lp.py index d9e85d67..13fbaadd 100644 --- a/examples/start_lp.py +++ b/examples/start_lp.py @@ -106,7 +106,7 @@ async def main( print("confirming tx...") await connection.confirm_transaction(sig) - position = dc.get_user_position(market_index) + position = dc.get_perp_position(market_index) market = await get_perp_market_account(dc.program, market_index) percent_provided = (position.lp_shares / market.amm.sqrt_k) * 100 print(f"lp shares: {position.lp_shares}") diff --git a/src/driftpy/drift_client.py b/src/driftpy/drift_client.py index b08499bb..63539f60 100644 --- a/src/driftpy/drift_client.py +++ b/src/driftpy/drift_client.py @@ -1038,44 +1038,34 @@ async def get_settle_lp_ix( ), ) - def get_user_spot_position( + def get_spot_position( self, market_index: int, sub_account_id: int = 0, ) -> Optional[SpotPosition]: user = self.get_user_account(sub_account_id) - found = False for position in user.spot_positions: if ( position.market_index == market_index and not is_spot_position_available(position) ): - found = True - break + return position - if not found: - return None + return None - return position - - def get_user_position( + def get_perp_position( self, market_index: int, sub_account_id: int = 0, ) -> Optional[PerpPosition]: user = self.get_user(sub_account_id).get_user_account() - found = False for position in user.perp_positions: if position.market_index == market_index and not is_available(position): - found = True - break - - if not found: - return None + return position - return position + return None def default_order_params( self, order_type, market_index, base_asset_amount, direction @@ -1815,7 +1805,7 @@ async def close_position( async def get_close_position_ix( self, market_index: int, limit_price: int = 0, sub_account_id: int = 0 ): - position = self.get_user_position(market_index, sub_account_id) + position = self.get_perp_position(market_index, sub_account_id) if position is None or position.base_asset_amount == 0: print("=> user has no position to close...") return diff --git a/tests/test.py b/tests/test.py index 81632a7a..5f1b8165 100644 --- a/tests/test.py +++ b/tests/test.py @@ -446,5 +446,5 @@ async def test_liq_perp( # liq takes on position await liq_drift_client.get_user(0).account_subscriber.update_cache() - position = liq_drift_client.get_user_position(0) + position = liq_drift_client.get_perp_position(0) assert position.base_asset_amount != 0