Skip to content

Commit

Permalink
bigz/add-perp--orders-tests
Browse files Browse the repository at this point in the history
  • Loading branch information
0xbigz committed Nov 15, 2023
1 parent 3562a80 commit 4aff04b
Show file tree
Hide file tree
Showing 4 changed files with 84 additions and 34 deletions.
64 changes: 33 additions & 31 deletions src/driftpy/drift_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -635,7 +635,7 @@ async def place_spot_order(
return await self.send_ixs(
[
self.get_increase_compute_ix(),
await self.get_place_spot_order_ix(order_params, maker_info, user_id),
await self.get_place_spot_order_ix(order_params, user_id),
]
)

Expand Down Expand Up @@ -713,62 +713,64 @@ async def place_perp_order(
user_id: int = 0,
):
return await self.send_ixs(
[
[
self.get_increase_compute_ix(),
await self.get_place_perp_order_ix(order_params, maker_info, user_id),
(await self.get_place_perp_order_ix(order_params, user_id))[-1]
]

)

async def get_place_perp_order_ix(
self,
order_params: OrderParams,
user_id: int = 0,
):
) -> TransactionInstruction:
user_account_public_key = self.get_user_account_public_key(user_id)
remaining_accounts = await self.get_remaining_accounts(
writable_market_index=order_params.market_index, user_id=user_id
)

ix = self.program.instruction["place_perp_order"](
order_params,
ctx=Context(
accounts={
"state": self.get_state_public_key(),
"user": user_account_public_key,
"authority": self.signer.public_key,
},
remaining_accounts=remaining_accounts,
),
)
order_params,
ctx=Context(
accounts={
"state": self.get_state_public_key(),
"user": user_account_public_key,
"authority": self.signer.public_key,
},
remaining_accounts=remaining_accounts,
),
)

return ix

async def get_place_perp_orders_ix(
self,
order_params: List[OrderParams],
user_id: int = 0,
cancel_all=True
):
user_account_public_key = self.get_user_account_public_key(user_id)
writeable_market_indexes = list(set([x.market_index for x in order_params]))
remaining_accounts = await self.get_remaining_accounts(
writable_market_index=writeable_market_indexes, user_id=user_id
)

ixs = [
self.program.instruction["cancel_orders"](
None,
None,
None,
ctx=Context(
accounts={
"state": self.get_state_public_key(),
"user": self.get_user_account_public_key(user_id),
"authority": self.signer.public_key,
},
remaining_accounts=remaining_accounts,
),
)
]
ixs = []
if cancel_all:
ixs.append(
self.program.instruction["cancel_orders"](
None,
None,
None,
ctx=Context(
accounts={
"state": self.get_state_public_key(),
"user": self.get_user_account_public_key(user_id),
"authority": self.signer.public_key,
},
remaining_accounts=remaining_accounts,
),
))
for order_param in order_params:
ix = self.program.instruction["place_perp_order"](
order_param,
Expand Down Expand Up @@ -950,7 +952,7 @@ def default_order_params(
price=0,
market_index=market_index,
reduce_only=False,
post_only=PostOnlyParam.NONE(),
post_only=PostOnlyParams.NONE(),
immediate_or_cancel=False,
trigger_price=0,
trigger_condition=OrderTriggerCondition.ABOVE(),
Expand Down
10 changes: 10 additions & 0 deletions src/driftpy/drift_user.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,16 @@ async def get_user(self):
self.program, self.authority, self.subaccount_id
)


async def get_open_orders(self,
# market_type: MarketType,
# market_index: int,
# position_direction: PositionDirection
):
user: User = await self.get_user()
return user.orders


async def get_spot_market_liability(
self,
market_index=None,
Expand Down
6 changes: 3 additions & 3 deletions src/driftpy/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ class OracleSource:
PYTH_STABLE_COIN = constructor()

@_rust_enum
class PostOnlyParam:
class PostOnlyParams:
NONE = constructor()
MUST_POST_ONLY = constructor()
TRY_POST_ONLY = constructor()
Expand Down Expand Up @@ -298,7 +298,7 @@ class OrderParams:
price: int
market_index: int
reduce_only: bool
post_only: PostOnlyParam
post_only: PostOnlyParams
immediate_or_cancel: bool
max_ts: Optional[int]
trigger_price: Optional[int]
Expand All @@ -314,7 +314,7 @@ class ModifyOrderParams:
base_asset_amount: Optional[int]
price: Optional[int]
reduce_only: Optional[bool]
post_only: Optional[PostOnlyParam]
post_only: Optional[PostOnlyParams]
immediate_or_cancel: Optional[bool]
max_ts: Optional[int]
trigger_price: Optional[int]
Expand Down
38 changes: 38 additions & 0 deletions tests/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,14 @@
from driftpy.constants.numeric_constants import (
PRICE_PRECISION,
AMM_RESERVE_PRECISION,
BASE_PRECISION,
QUOTE_PRECISION,
SPOT_BALANCE_PRECISION,
SPOT_WEIGHT_PRECISION,
)
from math import sqrt

from driftpy.drift_user import User as DriftUser
from driftpy.drift_client import DriftClient
from driftpy.setup.helpers import (
_create_mint,
Expand All @@ -28,6 +30,8 @@
PositionDirection,
OracleSource,
PerpMarket,
OrderType,
OrderParams
# SwapDirection,
)
from driftpy.accounts import (
Expand Down Expand Up @@ -201,6 +205,40 @@ async def test_usdc_deposit(
== USDC_AMOUNT / QUOTE_PRECISION * SPOT_BALANCE_PRECISION
)

@mark.asyncio
async def test_open_orders(
drift_client: Admin,
):

drift_user = DriftUser(drift_client)
user_account = await get_user_account(
drift_client.program, drift_client.authority
)

assert(len(user_account.orders)==32)
assert(user_account.orders[0].market_index == 0)

open_orders = await drift_user.get_open_orders()
assert(len(open_orders)==32)
assert(open_orders==user_account.orders)

order_params: OrderParams = drift_client.default_order_params(
OrderType.MARKET(), 0, int(1 * BASE_PRECISION), PositionDirection.LONG()
)
order_params.user_order_id = 169
ixs = await drift_client.get_place_perp_orders_ix([order_params])
await drift_client.send_ixs(ixs)
open_orders_after = await drift_user.get_open_orders()
assert(open_orders_after[0].base_asset_amount == BASE_PRECISION)
assert(open_orders_after[0].order_id == 1)
assert(open_orders_after[0].user_order_id == 169)

await drift_client.cancel_order(1, 0)
open_orders_after2 = await drift_user.get_open_orders()
assert(open_orders_after2[0].base_asset_amount == 0)




@mark.asyncio
async def test_update_curve(
Expand Down

0 comments on commit 4aff04b

Please sign in to comment.