Skip to content

Commit

Permalink
Ruff: apply formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
SinaKhalili committed Dec 11, 2024
1 parent fb211d1 commit b5789d3
Show file tree
Hide file tree
Showing 10 changed files with 119 additions and 77 deletions.
13 changes: 9 additions & 4 deletions src/driftpy/drift_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -958,7 +958,6 @@ async def get_deposit_collateral_ix(
reduce_only: Optional[bool] = False,
user_initialized: Optional[bool] = True,
) -> List[Instruction]:

sub_account_id = self.get_sub_account_id_for_ix(sub_account_id)
spot_market_account = self.get_spot_market_account(spot_market_index)
if not spot_market_account:
Expand Down Expand Up @@ -1624,7 +1623,9 @@ def get_place_and_take_perp_order_ix(
maker_infos = (
maker_info
if isinstance(maker_info, list)
else [maker_info] if maker_info else []
else [maker_info]
if maker_info
else []
)

user_accounts = [self.get_user_account(sub_account_id)]
Expand Down Expand Up @@ -1724,7 +1725,9 @@ def get_place_and_take_spot_order_ix(
maker_infos = (
maker_info
if isinstance(maker_info, list)
else [maker_info] if maker_info else []
else [maker_info]
if maker_info
else []
)
for maker_info in maker_infos:
user_accounts.append(maker_info.maker_user_account)
Expand Down Expand Up @@ -2682,7 +2685,9 @@ async def get_fill_perp_order_ix(
maker_info = (
maker_info
if isinstance(maker_info, list)
else [maker_info] if maker_info else []
else [maker_info]
if maker_info
else []
)

user_accounts = [user_account]
Expand Down
10 changes: 4 additions & 6 deletions src/driftpy/drift_user.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,9 +123,8 @@ def get_perp_position(self, market_index: int) -> Optional[PerpPosition]:

def get_spot_position(self, market_index: int) -> Optional[SpotPosition]:
for position in self.get_user_account().spot_positions:
if (
position.market_index == market_index
and not is_spot_position_available(position)
if position.market_index == market_index and not is_spot_position_available(
position
):
return position

Expand Down Expand Up @@ -358,9 +357,8 @@ def get_user_spot_position(

found = False
for position in user.spot_positions:
if (
position.market_index == market_index
and not is_spot_position_available(position)
if position.market_index == market_index and not is_spot_position_available(
position
):
found = True
break
Expand Down
1 change: 0 additions & 1 deletion src/driftpy/drift_user_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ def __init__(
user_stats_account_pubkey: Pubkey,
config: UserStatsSubscriptionConfig,
):

self.drift_client = drift_client
self.user_stats_account_pubkey = user_stats_account_pubkey
self.account_subscriber = WebsocketUserStatsAccountSubscriber(
Expand Down
23 changes: 15 additions & 8 deletions src/driftpy/math/repeg.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,13 +131,20 @@ def calculate_adjust_k_cost(amm: AMM, numerator: int, denominator: int) -> int:

p = numerator * PRICE_PRECISION // denominator

cost = (quote_scale * PERCENTAGE_PRECISION * PERCENTAGE_PRECISION // (x + d)) - (
quote_scale
* p
* PERCENTAGE_PRECISION
* PERCENTAGE_PRECISION
// PRICE_PRECISION
// (x * p // PRICE_PRECISION + d)
) // PERCENTAGE_PRECISION // PERCENTAGE_PRECISION // AMM_TO_QUOTE_PRECISION_RATIO // PEG_PRECISION
cost = (
(quote_scale * PERCENTAGE_PRECISION * PERCENTAGE_PRECISION // (x + d))
- (
quote_scale
* p
* PERCENTAGE_PRECISION
* PERCENTAGE_PRECISION
// PRICE_PRECISION
// (x * p // PRICE_PRECISION + d)
)
// PERCENTAGE_PRECISION
// PERCENTAGE_PRECISION
// AMM_TO_QUOTE_PRECISION_RATIO
// PEG_PRECISION
)

return cost * -1
18 changes: 15 additions & 3 deletions src/driftpy/tx/jito_subscriber.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,13 @@

from jito_searcher_client.async_searcher import get_async_searcher_client # type: ignore
from jito_searcher_client.generated.searcher_pb2_grpc import SearcherServiceStub # type: ignore
from jito_searcher_client.generated.searcher_pb2 import ConnectedLeadersResponse, ConnectedLeadersRequest, GetTipAccountsRequest, GetTipAccountsResponse, SubscribeBundleResultsRequest # type: ignore
from jito_searcher_client.generated.searcher_pb2 import (
ConnectedLeadersResponse,
ConnectedLeadersRequest,
GetTipAccountsRequest,
GetTipAccountsResponse,
SubscribeBundleResultsRequest,
) # type: ignore


class JitoSubscriber:
Expand Down Expand Up @@ -43,14 +49,20 @@ async def _subscribe(self):
self.bundle_subscription = self.searcher_client.SubscribeBundleResults(
SubscribeBundleResultsRequest()
)
tip_accounts: GetTipAccountsResponse = await self.searcher_client.GetTipAccounts(GetTipAccountsRequest()) # type: ignore
tip_accounts: GetTipAccountsResponse = (
await self.searcher_client.GetTipAccounts(GetTipAccountsRequest())
) # type: ignore
for account in tip_accounts.accounts:
self.tip_accounts.append(Pubkey.from_string(account))
while True:
try:
self.cache.clear()
current_slot = (await self.connection.get_slot(Confirmed)).value
leaders: ConnectedLeadersResponse = await self.searcher_client.GetConnectedLeaders(ConnectedLeadersRequest()) # type: ignore
leaders: ConnectedLeadersResponse = (
await self.searcher_client.GetConnectedLeaders(
ConnectedLeadersRequest()
)
) # type: ignore
for slot_list in leaders.connected_validators.values():
slots = slot_list.slots
for slot in slots:
Expand Down
48 changes: 24 additions & 24 deletions tests/ci/devnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,12 +72,12 @@ async def test_devnet_constants(rpc_url: str):
expected.market_index == received.market_index
), f"Devnet Perp: Expected market index {expected.market_index}, got {received.market_index} {market_info} for {expected.symbol}"

assert str(expected.oracle) == str(
received.amm.oracle
assert (
str(expected.oracle) == str(received.amm.oracle)
), f"Devnet Perp: Expected oracle {expected.oracle}, got {received.amm.oracle} {market_info} for {expected.symbol}"

assert str(expected.oracle_source) == str(
received.amm.oracle_source
assert (
str(expected.oracle_source) == str(received.amm.oracle_source)
), f"Devnet Perp: Expected oracle source {expected.oracle_source}, got {received.amm.oracle_source} {market_info} for {expected.symbol}"

expected_spot_markets = sorted(
Expand All @@ -103,12 +103,12 @@ async def test_devnet_constants(rpc_url: str):
expected.market_index == received.market_index
), f"Devnet Spot: Expected market index {expected.market_index}, got {received.market_index} {market_info} for {expected.symbol}"

assert str(expected.oracle) == str(
received.oracle
assert (
str(expected.oracle) == str(received.oracle)
), f"Devnet Spot: Expected oracle {expected.oracle}, got {received.oracle} {market_info} for {expected.symbol}"

assert str(expected.oracle_source) == str(
received.oracle_source
assert (
str(expected.oracle_source) == str(received.oracle_source)
), f"Devnet Spot: Expected oracle source {expected.oracle_source}, got {received.oracle_source} {market_info} for {expected.symbol}"


Expand All @@ -130,28 +130,28 @@ async def test_devnet_cached(rpc_url: str):

perp_markets = drift_client.get_perp_market_accounts()
print(f"1. Got: {len(perp_markets)}")
assert len(perp_markets) == len(
devnet_perp_market_configs
assert (
len(perp_markets) == len(devnet_perp_market_configs)
), f"Expected {len(devnet_perp_market_configs)} perp markets, got {len(perp_markets)}"

spot_markets = drift_client.get_spot_market_accounts()
print(f"1. Got: {len(spot_markets)}")
assert len(spot_markets) == len(
devnet_spot_market_configs
assert (
len(spot_markets) == len(devnet_spot_market_configs)
), f"Expected {len(devnet_spot_market_configs)} spot markets, got {len(spot_markets)}"

await drift_client.account_subscriber.update_cache()

perp_markets = drift_client.get_perp_market_accounts()
print(f"2. Got: {len(perp_markets)}")
assert len(perp_markets) == len(
devnet_perp_market_configs
assert (
len(perp_markets) == len(devnet_perp_market_configs)
), f"Expected {len(devnet_perp_market_configs)} perp markets, got {len(perp_markets)}"

spot_markets = drift_client.get_spot_market_accounts()
print(f"2. Got: {len(spot_markets)}")
assert len(spot_markets) == len(
devnet_spot_market_configs
assert (
len(spot_markets) == len(devnet_spot_market_configs)
), f"Expected {len(devnet_spot_market_configs)} spot markets, got {len(spot_markets)}"

print("Unsubscribing from Drift Client")
Expand All @@ -177,29 +177,29 @@ async def test_devnet_ws(rpc_url: str):

perp_markets = drift_client.get_perp_market_accounts()
print(f"1. Got: {len(perp_markets)}")
assert len(perp_markets) == len(
devnet_perp_market_configs
assert (
len(perp_markets) == len(devnet_perp_market_configs)
), f"Expected {len(devnet_perp_market_configs)} perp markets, got {len(perp_markets)}"

spot_markets = drift_client.get_spot_market_accounts()
print(f"1. Got: {len(spot_markets)}")
assert len(spot_markets) == len(
devnet_spot_market_configs
assert (
len(spot_markets) == len(devnet_spot_market_configs)
), f"Expected {len(devnet_spot_market_configs)} spot markets, got {len(spot_markets)}"

# wait for some updates
await asyncio.sleep(10)

perp_markets = drift_client.get_perp_market_accounts()
print(f"2. Got: {len(perp_markets)}")
assert len(perp_markets) == len(
devnet_perp_market_configs
assert (
len(perp_markets) == len(devnet_perp_market_configs)
), f"Expected {len(devnet_perp_market_configs)} perp markets, got {len(perp_markets)}"

spot_markets = drift_client.get_spot_market_accounts()
print(f"2. Got: {len(spot_markets)}")
assert len(spot_markets) == len(
devnet_spot_market_configs
assert (
len(spot_markets) == len(devnet_spot_market_configs)
), f"Expected {len(devnet_spot_market_configs)} spot markets, got {len(spot_markets)}"

print("Unsubscribing from Drift Client")
Expand Down
48 changes: 24 additions & 24 deletions tests/ci/mainnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,11 @@ async def test_mainnet_constants(rpc_url: str):
assert (
expected.market_index == received.market_index
), f"Perp: Expected market index {expected.market_index}, got {received.market_index} Market: {received.pubkey}"
assert str(expected.oracle) == str(
received.amm.oracle
assert (
str(expected.oracle) == str(received.amm.oracle)
), f"Perp: Expected oracle {expected.oracle}, got {received.amm.oracle} Market: {received.pubkey} Market Index: {received.market_index}"
assert str(expected.oracle_source) == str(
received.amm.oracle_source
assert (
str(expected.oracle_source) == str(received.amm.oracle_source)
), f"Perp: Expected oracle source {expected.oracle_source}, got {received.amm.oracle_source} Market: {received.pubkey} Market Index: {received.market_index}"

expected_spot_markets = sorted(
Expand All @@ -77,11 +77,11 @@ async def test_mainnet_constants(rpc_url: str):
assert (
expected.market_index == received.market_index
), f"Spot: Expected market index {expected.market_index}, got {received.market_index} Market: {received.pubkey}"
assert str(expected.oracle) == str(
received.oracle
assert (
str(expected.oracle) == str(received.oracle)
), f"Spot: Expected oracle {expected.oracle}, got {received.oracle} Market: {received.pubkey} Market Index: {received.market_index}"
assert str(expected.oracle_source) == str(
received.oracle_source
assert (
str(expected.oracle_source) == str(received.oracle_source)
), f"Spot: Expected oracle source {expected.oracle_source}, got {received.oracle_source} Market: {received.pubkey} Market Index: {received.market_index}"


Expand All @@ -101,28 +101,28 @@ async def test_mainnet_cached(rpc_url: str):

perp_markets = drift_client.get_perp_market_accounts()
print(f"1. Got: {len(perp_markets)}")
assert len(perp_markets) == len(
mainnet_perp_market_configs
assert (
len(perp_markets) == len(mainnet_perp_market_configs)
), f"Expected {len(mainnet_perp_market_configs)} perp markets, got {len(perp_markets)}"

spot_markets = drift_client.get_spot_market_accounts()
print(f"1. Got: {len(spot_markets)}")
assert len(spot_markets) == len(
mainnet_spot_market_configs
assert (
len(spot_markets) == len(mainnet_spot_market_configs)
), f"Expected {len(mainnet_spot_market_configs)} spot markets, got {len(spot_markets)}"

await drift_client.account_subscriber.update_cache()

perp_markets = drift_client.get_perp_market_accounts()
print(f"2. Got: {len(perp_markets)}")
assert len(perp_markets) == len(
mainnet_perp_market_configs
assert (
len(perp_markets) == len(mainnet_perp_market_configs)
), f"Expected {len(mainnet_perp_market_configs)} perp markets, got {len(perp_markets)}"

spot_markets = drift_client.get_spot_market_accounts()
print(f"2. Got: {len(spot_markets)}")
assert len(spot_markets) == len(
mainnet_spot_market_configs
assert (
len(spot_markets) == len(mainnet_spot_market_configs)
), f"Expected {len(mainnet_spot_market_configs)} spot markets, got {len(spot_markets)}"

print("Unsubscribing from Drift Client")
Expand All @@ -147,29 +147,29 @@ async def test_mainnet_ws(rpc_url: str):

perp_markets = drift_client.get_perp_market_accounts()
print(f"1. Got: {len(perp_markets)}")
assert len(perp_markets) == len(
mainnet_perp_market_configs
assert (
len(perp_markets) == len(mainnet_perp_market_configs)
), f"Expected {len(mainnet_perp_market_configs)} perp markets, got {len(perp_markets)}"

spot_markets = drift_client.get_spot_market_accounts()
print(f"1. Got: {len(spot_markets)}")
assert len(spot_markets) == len(
mainnet_spot_market_configs
assert (
len(spot_markets) == len(mainnet_spot_market_configs)
), f"Expected {len(mainnet_spot_market_configs)} spot markets, got {len(spot_markets)}"

# wait for some updates
await asyncio.sleep(10)

perp_markets = drift_client.get_perp_market_accounts()
print(f"2. Got: {len(perp_markets)}")
assert len(perp_markets) == len(
mainnet_perp_market_configs
assert (
len(perp_markets) == len(mainnet_perp_market_configs)
), f"Expected {len(mainnet_perp_market_configs)} perp markets, got {len(perp_markets)}"

spot_markets = drift_client.get_spot_market_accounts()
print(f"2. Got: {len(spot_markets)}")
assert len(spot_markets) == len(
mainnet_spot_market_configs
assert (
len(spot_markets) == len(mainnet_spot_market_configs)
), f"Expected {len(mainnet_spot_market_configs)} spot markets, got {len(spot_markets)}"

print("Unsubscribing from Drift Client")
Expand Down
5 changes: 4 additions & 1 deletion tests/integration/liq.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,10 @@ async def test_perp_liq_price(
await drift_client.get_user(0).account_subscriber.update_cache()
await drift_client.deposit(USDC_AMOUNT, 0, usdc_acc.pubkey())
await drift_client.open_position(
PositionDirection.Long(), (175 * BASE_PRECISION) // 10, 0, 0 # 17.5 SOL
PositionDirection.Long(),
(175 * BASE_PRECISION) // 10,
0,
0, # 17.5 SOL
)

lp_shares = drift_client.get_user_account().perp_positions[0].lp_shares
Expand Down
Loading

0 comments on commit b5789d3

Please sign in to comment.