From b5789d380d15d3c3d625050c841017e7c646d6b0 Mon Sep 17 00:00:00 2001 From: sina <20732540+SinaKhalili@users.noreply.github.com> Date: Fri, 6 Dec 2024 16:10:26 -0800 Subject: [PATCH] Ruff: apply formatting --- src/driftpy/drift_client.py | 13 ++++++--- src/driftpy/drift_user.py | 10 +++---- src/driftpy/drift_user_stats.py | 1 - src/driftpy/math/repeg.py | 23 +++++++++------ src/driftpy/tx/jito_subscriber.py | 18 ++++++++++-- tests/ci/devnet.py | 48 +++++++++++++++---------------- tests/ci/mainnet.py | 48 +++++++++++++++---------------- tests/integration/liq.py | 5 +++- tests/integration/oracle.py | 16 ++++++++--- tests/math/spreads.py | 14 +++++++-- 10 files changed, 119 insertions(+), 77 deletions(-) diff --git a/src/driftpy/drift_client.py b/src/driftpy/drift_client.py index 9d069576..5b33b4f7 100644 --- a/src/driftpy/drift_client.py +++ b/src/driftpy/drift_client.py @@ -958,7 +958,6 @@ async def get_deposit_collateral_ix( reduce_only: Optional[bool] = False, user_initialized: Optional[bool] = True, ) -> List[Instruction]: - sub_account_id = self.get_sub_account_id_for_ix(sub_account_id) spot_market_account = self.get_spot_market_account(spot_market_index) if not spot_market_account: @@ -1624,7 +1623,9 @@ def get_place_and_take_perp_order_ix( maker_infos = ( maker_info if isinstance(maker_info, list) - else [maker_info] if maker_info else [] + else [maker_info] + if maker_info + else [] ) user_accounts = [self.get_user_account(sub_account_id)] @@ -1724,7 +1725,9 @@ def get_place_and_take_spot_order_ix( maker_infos = ( maker_info if isinstance(maker_info, list) - else [maker_info] if maker_info else [] + else [maker_info] + if maker_info + else [] ) for maker_info in maker_infos: user_accounts.append(maker_info.maker_user_account) @@ -2682,7 +2685,9 @@ async def get_fill_perp_order_ix( maker_info = ( maker_info if isinstance(maker_info, list) - else [maker_info] if maker_info else [] + else [maker_info] + if maker_info + else [] ) user_accounts = [user_account] diff --git a/src/driftpy/drift_user.py b/src/driftpy/drift_user.py index 7435c38b..6f36ca38 100644 --- a/src/driftpy/drift_user.py +++ b/src/driftpy/drift_user.py @@ -123,9 +123,8 @@ def get_perp_position(self, market_index: int) -> Optional[PerpPosition]: def get_spot_position(self, market_index: int) -> Optional[SpotPosition]: for position in self.get_user_account().spot_positions: - if ( - position.market_index == market_index - and not is_spot_position_available(position) + if position.market_index == market_index and not is_spot_position_available( + position ): return position @@ -358,9 +357,8 @@ def get_user_spot_position( found = False for position in user.spot_positions: - if ( - position.market_index == market_index - and not is_spot_position_available(position) + if position.market_index == market_index and not is_spot_position_available( + position ): found = True break diff --git a/src/driftpy/drift_user_stats.py b/src/driftpy/drift_user_stats.py index f199f88d..d567db7a 100644 --- a/src/driftpy/drift_user_stats.py +++ b/src/driftpy/drift_user_stats.py @@ -27,7 +27,6 @@ def __init__( user_stats_account_pubkey: Pubkey, config: UserStatsSubscriptionConfig, ): - self.drift_client = drift_client self.user_stats_account_pubkey = user_stats_account_pubkey self.account_subscriber = WebsocketUserStatsAccountSubscriber( diff --git a/src/driftpy/math/repeg.py b/src/driftpy/math/repeg.py index adcb216a..58db64e8 100644 --- a/src/driftpy/math/repeg.py +++ b/src/driftpy/math/repeg.py @@ -131,13 +131,20 @@ def calculate_adjust_k_cost(amm: AMM, numerator: int, denominator: int) -> int: p = numerator * PRICE_PRECISION // denominator - cost = (quote_scale * PERCENTAGE_PRECISION * PERCENTAGE_PRECISION // (x + d)) - ( - quote_scale - * p - * PERCENTAGE_PRECISION - * PERCENTAGE_PRECISION - // PRICE_PRECISION - // (x * p // PRICE_PRECISION + d) - ) // PERCENTAGE_PRECISION // PERCENTAGE_PRECISION // AMM_TO_QUOTE_PRECISION_RATIO // PEG_PRECISION + cost = ( + (quote_scale * PERCENTAGE_PRECISION * PERCENTAGE_PRECISION // (x + d)) + - ( + quote_scale + * p + * PERCENTAGE_PRECISION + * PERCENTAGE_PRECISION + // PRICE_PRECISION + // (x * p // PRICE_PRECISION + d) + ) + // PERCENTAGE_PRECISION + // PERCENTAGE_PRECISION + // AMM_TO_QUOTE_PRECISION_RATIO + // PEG_PRECISION + ) return cost * -1 diff --git a/src/driftpy/tx/jito_subscriber.py b/src/driftpy/tx/jito_subscriber.py index 578496e8..cf5789c3 100644 --- a/src/driftpy/tx/jito_subscriber.py +++ b/src/driftpy/tx/jito_subscriber.py @@ -13,7 +13,13 @@ from jito_searcher_client.async_searcher import get_async_searcher_client # type: ignore from jito_searcher_client.generated.searcher_pb2_grpc import SearcherServiceStub # type: ignore -from jito_searcher_client.generated.searcher_pb2 import ConnectedLeadersResponse, ConnectedLeadersRequest, GetTipAccountsRequest, GetTipAccountsResponse, SubscribeBundleResultsRequest # type: ignore +from jito_searcher_client.generated.searcher_pb2 import ( + ConnectedLeadersResponse, + ConnectedLeadersRequest, + GetTipAccountsRequest, + GetTipAccountsResponse, + SubscribeBundleResultsRequest, +) # type: ignore class JitoSubscriber: @@ -43,14 +49,20 @@ async def _subscribe(self): self.bundle_subscription = self.searcher_client.SubscribeBundleResults( SubscribeBundleResultsRequest() ) - tip_accounts: GetTipAccountsResponse = await self.searcher_client.GetTipAccounts(GetTipAccountsRequest()) # type: ignore + tip_accounts: GetTipAccountsResponse = ( + await self.searcher_client.GetTipAccounts(GetTipAccountsRequest()) + ) # type: ignore for account in tip_accounts.accounts: self.tip_accounts.append(Pubkey.from_string(account)) while True: try: self.cache.clear() current_slot = (await self.connection.get_slot(Confirmed)).value - leaders: ConnectedLeadersResponse = await self.searcher_client.GetConnectedLeaders(ConnectedLeadersRequest()) # type: ignore + leaders: ConnectedLeadersResponse = ( + await self.searcher_client.GetConnectedLeaders( + ConnectedLeadersRequest() + ) + ) # type: ignore for slot_list in leaders.connected_validators.values(): slots = slot_list.slots for slot in slots: diff --git a/tests/ci/devnet.py b/tests/ci/devnet.py index 17182323..7a4bfddf 100644 --- a/tests/ci/devnet.py +++ b/tests/ci/devnet.py @@ -72,12 +72,12 @@ async def test_devnet_constants(rpc_url: str): expected.market_index == received.market_index ), f"Devnet Perp: Expected market index {expected.market_index}, got {received.market_index} {market_info} for {expected.symbol}" - assert str(expected.oracle) == str( - received.amm.oracle + assert ( + str(expected.oracle) == str(received.amm.oracle) ), f"Devnet Perp: Expected oracle {expected.oracle}, got {received.amm.oracle} {market_info} for {expected.symbol}" - assert str(expected.oracle_source) == str( - received.amm.oracle_source + assert ( + str(expected.oracle_source) == str(received.amm.oracle_source) ), f"Devnet Perp: Expected oracle source {expected.oracle_source}, got {received.amm.oracle_source} {market_info} for {expected.symbol}" expected_spot_markets = sorted( @@ -103,12 +103,12 @@ async def test_devnet_constants(rpc_url: str): expected.market_index == received.market_index ), f"Devnet Spot: Expected market index {expected.market_index}, got {received.market_index} {market_info} for {expected.symbol}" - assert str(expected.oracle) == str( - received.oracle + assert ( + str(expected.oracle) == str(received.oracle) ), f"Devnet Spot: Expected oracle {expected.oracle}, got {received.oracle} {market_info} for {expected.symbol}" - assert str(expected.oracle_source) == str( - received.oracle_source + assert ( + str(expected.oracle_source) == str(received.oracle_source) ), f"Devnet Spot: Expected oracle source {expected.oracle_source}, got {received.oracle_source} {market_info} for {expected.symbol}" @@ -130,28 +130,28 @@ async def test_devnet_cached(rpc_url: str): perp_markets = drift_client.get_perp_market_accounts() print(f"1. Got: {len(perp_markets)}") - assert len(perp_markets) == len( - devnet_perp_market_configs + assert ( + len(perp_markets) == len(devnet_perp_market_configs) ), f"Expected {len(devnet_perp_market_configs)} perp markets, got {len(perp_markets)}" spot_markets = drift_client.get_spot_market_accounts() print(f"1. Got: {len(spot_markets)}") - assert len(spot_markets) == len( - devnet_spot_market_configs + assert ( + len(spot_markets) == len(devnet_spot_market_configs) ), f"Expected {len(devnet_spot_market_configs)} spot markets, got {len(spot_markets)}" await drift_client.account_subscriber.update_cache() perp_markets = drift_client.get_perp_market_accounts() print(f"2. Got: {len(perp_markets)}") - assert len(perp_markets) == len( - devnet_perp_market_configs + assert ( + len(perp_markets) == len(devnet_perp_market_configs) ), f"Expected {len(devnet_perp_market_configs)} perp markets, got {len(perp_markets)}" spot_markets = drift_client.get_spot_market_accounts() print(f"2. Got: {len(spot_markets)}") - assert len(spot_markets) == len( - devnet_spot_market_configs + assert ( + len(spot_markets) == len(devnet_spot_market_configs) ), f"Expected {len(devnet_spot_market_configs)} spot markets, got {len(spot_markets)}" print("Unsubscribing from Drift Client") @@ -177,14 +177,14 @@ async def test_devnet_ws(rpc_url: str): perp_markets = drift_client.get_perp_market_accounts() print(f"1. Got: {len(perp_markets)}") - assert len(perp_markets) == len( - devnet_perp_market_configs + assert ( + len(perp_markets) == len(devnet_perp_market_configs) ), f"Expected {len(devnet_perp_market_configs)} perp markets, got {len(perp_markets)}" spot_markets = drift_client.get_spot_market_accounts() print(f"1. Got: {len(spot_markets)}") - assert len(spot_markets) == len( - devnet_spot_market_configs + assert ( + len(spot_markets) == len(devnet_spot_market_configs) ), f"Expected {len(devnet_spot_market_configs)} spot markets, got {len(spot_markets)}" # wait for some updates @@ -192,14 +192,14 @@ async def test_devnet_ws(rpc_url: str): perp_markets = drift_client.get_perp_market_accounts() print(f"2. Got: {len(perp_markets)}") - assert len(perp_markets) == len( - devnet_perp_market_configs + assert ( + len(perp_markets) == len(devnet_perp_market_configs) ), f"Expected {len(devnet_perp_market_configs)} perp markets, got {len(perp_markets)}" spot_markets = drift_client.get_spot_market_accounts() print(f"2. Got: {len(spot_markets)}") - assert len(spot_markets) == len( - devnet_spot_market_configs + assert ( + len(spot_markets) == len(devnet_spot_market_configs) ), f"Expected {len(devnet_spot_market_configs)} spot markets, got {len(spot_markets)}" print("Unsubscribing from Drift Client") diff --git a/tests/ci/mainnet.py b/tests/ci/mainnet.py index bb13251e..a143f4bc 100644 --- a/tests/ci/mainnet.py +++ b/tests/ci/mainnet.py @@ -59,11 +59,11 @@ async def test_mainnet_constants(rpc_url: str): assert ( expected.market_index == received.market_index ), f"Perp: Expected market index {expected.market_index}, got {received.market_index} Market: {received.pubkey}" - assert str(expected.oracle) == str( - received.amm.oracle + assert ( + str(expected.oracle) == str(received.amm.oracle) ), f"Perp: Expected oracle {expected.oracle}, got {received.amm.oracle} Market: {received.pubkey} Market Index: {received.market_index}" - assert str(expected.oracle_source) == str( - received.amm.oracle_source + assert ( + str(expected.oracle_source) == str(received.amm.oracle_source) ), f"Perp: Expected oracle source {expected.oracle_source}, got {received.amm.oracle_source} Market: {received.pubkey} Market Index: {received.market_index}" expected_spot_markets = sorted( @@ -77,11 +77,11 @@ async def test_mainnet_constants(rpc_url: str): assert ( expected.market_index == received.market_index ), f"Spot: Expected market index {expected.market_index}, got {received.market_index} Market: {received.pubkey}" - assert str(expected.oracle) == str( - received.oracle + assert ( + str(expected.oracle) == str(received.oracle) ), f"Spot: Expected oracle {expected.oracle}, got {received.oracle} Market: {received.pubkey} Market Index: {received.market_index}" - assert str(expected.oracle_source) == str( - received.oracle_source + assert ( + str(expected.oracle_source) == str(received.oracle_source) ), f"Spot: Expected oracle source {expected.oracle_source}, got {received.oracle_source} Market: {received.pubkey} Market Index: {received.market_index}" @@ -101,28 +101,28 @@ async def test_mainnet_cached(rpc_url: str): perp_markets = drift_client.get_perp_market_accounts() print(f"1. Got: {len(perp_markets)}") - assert len(perp_markets) == len( - mainnet_perp_market_configs + assert ( + len(perp_markets) == len(mainnet_perp_market_configs) ), f"Expected {len(mainnet_perp_market_configs)} perp markets, got {len(perp_markets)}" spot_markets = drift_client.get_spot_market_accounts() print(f"1. Got: {len(spot_markets)}") - assert len(spot_markets) == len( - mainnet_spot_market_configs + assert ( + len(spot_markets) == len(mainnet_spot_market_configs) ), f"Expected {len(mainnet_spot_market_configs)} spot markets, got {len(spot_markets)}" await drift_client.account_subscriber.update_cache() perp_markets = drift_client.get_perp_market_accounts() print(f"2. Got: {len(perp_markets)}") - assert len(perp_markets) == len( - mainnet_perp_market_configs + assert ( + len(perp_markets) == len(mainnet_perp_market_configs) ), f"Expected {len(mainnet_perp_market_configs)} perp markets, got {len(perp_markets)}" spot_markets = drift_client.get_spot_market_accounts() print(f"2. Got: {len(spot_markets)}") - assert len(spot_markets) == len( - mainnet_spot_market_configs + assert ( + len(spot_markets) == len(mainnet_spot_market_configs) ), f"Expected {len(mainnet_spot_market_configs)} spot markets, got {len(spot_markets)}" print("Unsubscribing from Drift Client") @@ -147,14 +147,14 @@ async def test_mainnet_ws(rpc_url: str): perp_markets = drift_client.get_perp_market_accounts() print(f"1. Got: {len(perp_markets)}") - assert len(perp_markets) == len( - mainnet_perp_market_configs + assert ( + len(perp_markets) == len(mainnet_perp_market_configs) ), f"Expected {len(mainnet_perp_market_configs)} perp markets, got {len(perp_markets)}" spot_markets = drift_client.get_spot_market_accounts() print(f"1. Got: {len(spot_markets)}") - assert len(spot_markets) == len( - mainnet_spot_market_configs + assert ( + len(spot_markets) == len(mainnet_spot_market_configs) ), f"Expected {len(mainnet_spot_market_configs)} spot markets, got {len(spot_markets)}" # wait for some updates @@ -162,14 +162,14 @@ async def test_mainnet_ws(rpc_url: str): perp_markets = drift_client.get_perp_market_accounts() print(f"2. Got: {len(perp_markets)}") - assert len(perp_markets) == len( - mainnet_perp_market_configs + assert ( + len(perp_markets) == len(mainnet_perp_market_configs) ), f"Expected {len(mainnet_perp_market_configs)} perp markets, got {len(perp_markets)}" spot_markets = drift_client.get_spot_market_accounts() print(f"2. Got: {len(spot_markets)}") - assert len(spot_markets) == len( - mainnet_spot_market_configs + assert ( + len(spot_markets) == len(mainnet_spot_market_configs) ), f"Expected {len(mainnet_spot_market_configs)} spot markets, got {len(spot_markets)}" print("Unsubscribing from Drift Client") diff --git a/tests/integration/liq.py b/tests/integration/liq.py index a2b4888b..3d3d3f7f 100644 --- a/tests/integration/liq.py +++ b/tests/integration/liq.py @@ -151,7 +151,10 @@ async def test_perp_liq_price( await drift_client.get_user(0).account_subscriber.update_cache() await drift_client.deposit(USDC_AMOUNT, 0, usdc_acc.pubkey()) await drift_client.open_position( - PositionDirection.Long(), (175 * BASE_PRECISION) // 10, 0, 0 # 17.5 SOL + PositionDirection.Long(), + (175 * BASE_PRECISION) // 10, + 0, + 0, # 17.5 SOL ) lp_shares = drift_client.get_user_account().perp_positions[0].lp_shares diff --git a/tests/integration/oracle.py b/tests/integration/oracle.py index 79c1094b..c7d1d89e 100644 --- a/tests/integration/oracle.py +++ b/tests/integration/oracle.py @@ -155,13 +155,17 @@ async def test_polling( await asyncio.sleep(20) - perp_oracle_price_before = (polling_drift_client.get_oracle_price_data_for_perp_market(0)).price # type: ignore + perp_oracle_price_before = ( + polling_drift_client.get_oracle_price_data_for_perp_market(0) + ).price # type: ignore print(f"perp_oracle_price_before: {perp_oracle_price_before}") assert perp_oracle_price_before == 30 * PRICE_PRECISION await asyncio.sleep(10) - perp_oracle_price_after = (polling_drift_client.get_oracle_price_data_for_perp_market(0)).price # type: ignore + perp_oracle_price_after = ( + polling_drift_client.get_oracle_price_data_for_perp_market(0) + ).price # type: ignore print(f"perp_oracle_price_after: {perp_oracle_price_after}") assert perp_oracle_price_after == 100 * PRICE_PRECISION @@ -171,13 +175,17 @@ async def test_polling( await asyncio.sleep(20) - spot_oracle_price_before = (polling_drift_client.get_oracle_price_data_for_spot_market(1)).price # type: ignore + spot_oracle_price_before = ( + polling_drift_client.get_oracle_price_data_for_spot_market(1) + ).price # type: ignore print(f"spot_oracle_price_before: {spot_oracle_price_before}") assert spot_oracle_price_before == 30 * PRICE_PRECISION await asyncio.sleep(10) - spot_oracle_price_after = (polling_drift_client.get_oracle_price_data_for_spot_market(1)).price # type: ignore + spot_oracle_price_after = ( + polling_drift_client.get_oracle_price_data_for_spot_market(1) + ).price # type: ignore print(f"spot_oracle_price_after: {spot_oracle_price_after}") assert spot_oracle_price_after == 100 * PRICE_PRECISION diff --git a/tests/math/spreads.py b/tests/math/spreads.py index 683f0a71..e22ed6eb 100644 --- a/tests/math/spreads.py +++ b/tests/math/spreads.py @@ -370,7 +370,12 @@ async def test_spread_reserves_with_offset(): now = int(time.time()) oracle_price_data = OraclePriceData( - int(13.553 * PRICE_PRECISION), 69, 1, 0, 0, True # kek + int(13.553 * PRICE_PRECISION), + 69, + 1, + 0, + 0, + True, # kek ) bid_reserves, ask_reserves = calculate_spread_reserves(amm, oracle_price_data, now) @@ -519,7 +524,12 @@ async def test_spread_reserves_with_negative_offset(): now = int(time.time()) oracle_price_data = OraclePriceData( - int(13.553 * PRICE_PRECISION), 69, 1, 0, 0, True # kek + int(13.553 * PRICE_PRECISION), + 69, + 1, + 0, + 0, + True, # kek ) bid_reserves, ask_reserves = calculate_spread_reserves(amm, oracle_price_data, now)