From 9ab81da6a000ef902b636b49b402b3f2ac8e8245 Mon Sep 17 00:00:00 2001 From: Chris Heaney Date: Mon, 27 Nov 2023 21:15:33 -0500 Subject: [PATCH] add some missing user methods such as get_perp_position --- src/driftpy/drift_user.py | 44 ++++++++++++++++++++++++++++++++++++--- 1 file changed, 41 insertions(+), 3 deletions(-) diff --git a/src/driftpy/drift_user.py b/src/driftpy/drift_user.py index bc1c4122..f2311d04 100644 --- a/src/driftpy/drift_user.py +++ b/src/driftpy/drift_user.py @@ -70,11 +70,32 @@ def get_spot_market_account(self, market_index: int) -> SpotMarketAccount: def get_user_account(self) -> UserAccount: return self.account_subscriber.get_user_account_and_slot().data + def get_token_amount(self, market_index: int) -> int: + spot_position = self.get_spot_position(market_index) + if spot_position is None: + return 0 + + spot_market = self.get_spot_market_account(market_index) + return get_token_amount( + spot_position.scaled_balance, spot_market, spot_position.balance_type + ) + + def get_order(self, order_id: int) -> Optional[Order]: + for order in self.get_user_account().orders: + if order.order_id == order_id: + return order + + return None + + def get_order_by_user_order_id(self, user_order_id: int): + for order in self.get_user_account().orders: + if order.user_order_id == user_order_id: + return order + + return None + def get_open_orders( self, - # market_type: MarketType, - # market_index: int, - # position_direction: PositionDirection ): return list( filter( @@ -83,6 +104,23 @@ def get_open_orders( ) ) + def get_perp_position(self, market_index: int) -> Optional[PerpPosition]: + for position in self.get_user_account().perp_positions: + if position.market_index == market_index and not is_available(position): + return position + + return None + + def get_spot_position(self, market_index: int) -> Optional[SpotPosition]: + for position in self.get_user_account().spot_positions: + if ( + position.market_index == market_index + and not is_spot_position_available(position) + ): + return position + + return None + def get_spot_market_liability( self, market_index=None,