diff --git a/src/driftpy/drift_client.py b/src/driftpy/drift_client.py index 24220682..1e5f93cf 100644 --- a/src/driftpy/drift_client.py +++ b/src/driftpy/drift_client.py @@ -635,7 +635,7 @@ async def place_spot_order( return await self.send_ixs( [ self.get_increase_compute_ix(), - await self.get_place_spot_order_ix(order_params, maker_info, user_id), + await self.get_place_spot_order_ix(order_params, user_id), ] ) @@ -713,33 +713,34 @@ async def place_perp_order( user_id: int = 0, ): return await self.send_ixs( - [ + [ self.get_increase_compute_ix(), - await self.get_place_perp_order_ix(order_params, maker_info, user_id), + (await self.get_place_perp_order_ix(order_params, user_id))[-1] ] + ) async def get_place_perp_order_ix( self, order_params: OrderParams, user_id: int = 0, - ): + ) -> TransactionInstruction: user_account_public_key = self.get_user_account_public_key(user_id) remaining_accounts = await self.get_remaining_accounts( writable_market_index=order_params.market_index, user_id=user_id ) ix = self.program.instruction["place_perp_order"]( - order_params, - ctx=Context( - accounts={ - "state": self.get_state_public_key(), - "user": user_account_public_key, - "authority": self.signer.public_key, - }, - remaining_accounts=remaining_accounts, - ), - ) + order_params, + ctx=Context( + accounts={ + "state": self.get_state_public_key(), + "user": user_account_public_key, + "authority": self.signer.public_key, + }, + remaining_accounts=remaining_accounts, + ), + ) return ix @@ -747,28 +748,29 @@ async def get_place_perp_orders_ix( self, order_params: List[OrderParams], user_id: int = 0, + cancel_all=True ): user_account_public_key = self.get_user_account_public_key(user_id) writeable_market_indexes = list(set([x.market_index for x in order_params])) remaining_accounts = await self.get_remaining_accounts( writable_market_index=writeable_market_indexes, user_id=user_id ) - - ixs = [ - self.program.instruction["cancel_orders"]( - None, - None, - None, - ctx=Context( - accounts={ - "state": self.get_state_public_key(), - "user": self.get_user_account_public_key(user_id), - "authority": self.signer.public_key, - }, - remaining_accounts=remaining_accounts, - ), - ) - ] + ixs = [] + if cancel_all: + ixs.append( + self.program.instruction["cancel_orders"]( + None, + None, + None, + ctx=Context( + accounts={ + "state": self.get_state_public_key(), + "user": self.get_user_account_public_key(user_id), + "authority": self.signer.public_key, + }, + remaining_accounts=remaining_accounts, + ), + )) for order_param in order_params: ix = self.program.instruction["place_perp_order"]( order_param, @@ -950,7 +952,7 @@ def default_order_params( price=0, market_index=market_index, reduce_only=False, - post_only=PostOnlyParam.NONE(), + post_only=PostOnlyParams.NONE(), immediate_or_cancel=False, trigger_price=0, trigger_condition=OrderTriggerCondition.ABOVE(), diff --git a/src/driftpy/drift_user.py b/src/driftpy/drift_user.py index bba280bd..95f17d6d 100644 --- a/src/driftpy/drift_user.py +++ b/src/driftpy/drift_user.py @@ -215,6 +215,16 @@ async def get_user(self): self.program, self.authority, self.subaccount_id ) + + async def get_open_orders(self, + # market_type: MarketType, + # market_index: int, + # position_direction: PositionDirection + ): + user: User = await self.get_user() + return user.orders + + async def get_spot_market_liability( self, market_index=None, diff --git a/src/driftpy/types.py b/src/driftpy/types.py index d38d57bc..22716281 100644 --- a/src/driftpy/types.py +++ b/src/driftpy/types.py @@ -175,7 +175,7 @@ class OracleSource: PYTH_STABLE_COIN = constructor() @_rust_enum -class PostOnlyParam: +class PostOnlyParams: NONE = constructor() MUST_POST_ONLY = constructor() TRY_POST_ONLY = constructor() @@ -298,7 +298,7 @@ class OrderParams: price: int market_index: int reduce_only: bool - post_only: PostOnlyParam + post_only: PostOnlyParams immediate_or_cancel: bool max_ts: Optional[int] trigger_price: Optional[int] @@ -314,7 +314,7 @@ class ModifyOrderParams: base_asset_amount: Optional[int] price: Optional[int] reduce_only: Optional[bool] - post_only: Optional[PostOnlyParam] + post_only: Optional[PostOnlyParams] immediate_or_cancel: Optional[bool] max_ts: Optional[int] trigger_price: Optional[int] diff --git a/tests/test.py b/tests/test.py index c43bed89..77e5c813 100644 --- a/tests/test.py +++ b/tests/test.py @@ -7,12 +7,14 @@ from driftpy.constants.numeric_constants import ( PRICE_PRECISION, AMM_RESERVE_PRECISION, + BASE_PRECISION, QUOTE_PRECISION, SPOT_BALANCE_PRECISION, SPOT_WEIGHT_PRECISION, ) from math import sqrt +from driftpy.drift_user import User as DriftUser from driftpy.drift_client import DriftClient from driftpy.setup.helpers import ( _create_mint, @@ -28,6 +30,8 @@ PositionDirection, OracleSource, PerpMarket, + OrderType, + OrderParams # SwapDirection, ) from driftpy.accounts import ( @@ -201,6 +205,40 @@ async def test_usdc_deposit( == USDC_AMOUNT / QUOTE_PRECISION * SPOT_BALANCE_PRECISION ) +@mark.asyncio +async def test_open_orders( + drift_client: Admin, +): + + drift_user = DriftUser(drift_client) + user_account = await get_user_account( + drift_client.program, drift_client.authority + ) + + assert(len(user_account.orders)==32) + assert(user_account.orders[0].market_index == 0) + + open_orders = await drift_user.get_open_orders() + assert(len(open_orders)==32) + assert(open_orders==user_account.orders) + + order_params: OrderParams = drift_client.default_order_params( + OrderType.MARKET(), 0, int(1 * BASE_PRECISION), PositionDirection.LONG() + ) + order_params.user_order_id = 169 + ixs = await drift_client.get_place_perp_orders_ix([order_params]) + await drift_client.send_ixs(ixs) + open_orders_after = await drift_user.get_open_orders() + assert(open_orders_after[0].base_asset_amount == BASE_PRECISION) + assert(open_orders_after[0].order_id == 1) + assert(open_orders_after[0].user_order_id == 169) + + await drift_client.cancel_order(1, 0) + open_orders_after2 = await drift_user.get_open_orders() + assert(open_orders_after2[0].base_asset_amount == 0) + + + @mark.asyncio async def test_update_curve(