From a4cf18ba80cfd5609345f4fa5ad604eb24b20d7b Mon Sep 17 00:00:00 2001 From: sina <20732540+SinaKhalili@users.noreply.github.com> Date: Wed, 11 Dec 2024 04:11:08 -0800 Subject: [PATCH] Refactor DriftClient, add test - some formatting - use oracle id in `PollingDriftClientAccountSubscriber` - update idl - add test but it fails --- src/driftpy/accounts/polling/drift_client.py | 44 ++-- src/driftpy/constants/perp_markets.py | 10 +- src/driftpy/drift_client.py | 108 ++++---- src/driftpy/idl/drift.json | 21 ++ tests/integration/test_oracle_diff_sources.py | 231 ++++++++++++++++++ 5 files changed, 349 insertions(+), 65 deletions(-) create mode 100644 tests/integration/test_oracle_diff_sources.py diff --git a/src/driftpy/accounts/polling/drift_client.py b/src/driftpy/accounts/polling/drift_client.py index 5098da03..14a7fb79 100644 --- a/src/driftpy/accounts/polling/drift_client.py +++ b/src/driftpy/accounts/polling/drift_client.py @@ -1,24 +1,28 @@ import asyncio from typing import Optional, Sequence, Union -from anchorpy import Program -from driftpy.accounts import DataAndSlot -from driftpy.accounts import DriftClientAccountSubscriber +from anchorpy.program.core import Program +from solders.pubkey import Pubkey + +from driftpy.accounts import DataAndSlot, DriftClientAccountSubscriber from driftpy.accounts.bulk_account_loader import BulkAccountLoader from driftpy.accounts.oracle import get_oracle_decode_fn -from driftpy.addresses import get_perp_market_public_key -from driftpy.addresses import get_spot_market_public_key -from driftpy.addresses import get_state_public_key +from driftpy.addresses import ( + get_perp_market_public_key, + get_spot_market_public_key, + get_state_public_key, +) from driftpy.constants.config import find_all_market_and_oracles from driftpy.oracles.oracle_id import get_oracle_id -from driftpy.types import OracleInfo -from driftpy.types import OraclePriceData -from driftpy.types import OracleSource -from driftpy.types import PerpMarketAccount -from driftpy.types import SpotMarketAccount -from driftpy.types import stack_trace -from driftpy.types import StateAccount -from solders.pubkey import Pubkey +from driftpy.types import ( + OracleInfo, + OraclePriceData, + OracleSource, + PerpMarketAccount, + SpotMarketAccount, + StateAccount, + stack_trace, +) class PollingDriftClientAccountSubscriber(DriftClientAccountSubscriber): @@ -195,9 +199,8 @@ def get_spot_market_and_slot( return self.spot_markets.get(market_index) def get_oracle_price_data_and_slot( - self, oracle: Pubkey, oracle_source: OracleSource + self, oracle_id: str ) -> Optional[DataAndSlot[OraclePriceData]]: - oracle_id = get_oracle_id(oracle, oracle_source) return self.oracle.get(oracle_id) def get_market_accounts_and_slots(self) -> list[DataAndSlot[PerpMarketAccount]]: @@ -241,9 +244,18 @@ async def _set_spot_oracle_map(self): def get_oracle_price_data_and_slot_for_perp_market( self, market_index: int ) -> Union[DataAndSlot[OraclePriceData], None]: + print( + "==> PollingDriftClientAccountSubscriber: Getting oracle price data for perp market", + market_index, + ) + print(self.perp_markets) + print(self.spot_markets) perp_market_account = self.get_perp_market_and_slot(market_index) oracle = self.perp_oracle_map.get(market_index) + print("Perp market account: ", perp_market_account) + print("Oracle: ", oracle) + if not perp_market_account or not oracle: return None diff --git a/src/driftpy/constants/perp_markets.py b/src/driftpy/constants/perp_markets.py index ed6670a2..96109c7d 100644 --- a/src/driftpy/constants/perp_markets.py +++ b/src/driftpy/constants/perp_markets.py @@ -1,8 +1,9 @@ from dataclasses import dataclass -from driftpy.types import OracleSource from solders.pubkey import Pubkey # type: ignore +from driftpy.types import OracleSource + @dataclass class PerpMarketConfig: @@ -648,4 +649,11 @@ class PerpMarketConfig: oracle=Pubkey.from_string("AmjHowvVkVJApCPUiwV9CdHVFn29LiBYZQqtZQ3xMqdg"), oracle_source=OracleSource.PythPull(), ), + PerpMarketConfig( + symbol="ME-PERP", + base_asset_symbol="ME", + market_index=61, + oracle=Pubkey.from_string("FLQjrmEPGwbCKRYZ1eYM5FPccHBrCv2cN4GBu3mWfmPH"), + oracle_source=OracleSource.PythPull(), + ), ] diff --git a/src/driftpy/drift_client.py b/src/driftpy/drift_client.py index 5b33b4f7..248df5f9 100644 --- a/src/driftpy/drift_client.py +++ b/src/driftpy/drift_client.py @@ -1,51 +1,21 @@ import json import os -from pathlib import Path import random import string +from pathlib import Path from typing import List, Optional, Tuple, Union import anchorpy -from anchorpy import Context -from anchorpy import Idl -from anchorpy import Program -from anchorpy import Provider -from anchorpy import Wallet -from deprecated import deprecated -import driftpy -from driftpy.account_subscription_config import AccountSubscriptionConfig -from driftpy.accounts import * -from driftpy.address_lookup_table import get_address_lookup_table -from driftpy.addresses import get_sequencer_public_key_and_bump -from driftpy.constants import BASE_PRECISION -from driftpy.constants import PRICE_PRECISION -from driftpy.constants.config import configs -from driftpy.constants.config import DEVNET_SEQUENCER_PROGRAM_ID -from driftpy.constants.config import DRIFT_PROGRAM_ID -from driftpy.constants.config import DriftEnv -from driftpy.constants.config import SEQUENCER_PROGRAM_ID -from driftpy.constants.numeric_constants import QUOTE_SPOT_MARKET_INDEX -from driftpy.constants.spot_markets import WRAPPED_SOL_MINT -from driftpy.decode.utils import decode_name -from driftpy.drift_user import DriftUser -from driftpy.drift_user_stats import DriftUserStats -from driftpy.drift_user_stats import UserStatsSubscriptionConfig -from driftpy.math.perp_position import is_available -from driftpy.math.spot_market import cast_to_spot_precision -from driftpy.math.spot_position import is_spot_position_available -from driftpy.name import encode_name -from driftpy.tx.standard_tx_sender import StandardTxSender -from driftpy.tx.types import TxSender -from driftpy.tx.types import TxSigAndSlot import requests +from anchorpy import Context, Idl, Program, Provider, Wallet +from deprecated import deprecated from solana.rpc.async_api import AsyncClient from solana.rpc.commitment import Processed from solana.rpc.types import TxOpts from solana.transaction import AccountMeta from solders import system_program from solders.address_lookup_table_account import AddressLookupTableAccount -from solders.compute_budget import set_compute_unit_limit -from solders.compute_budget import set_compute_unit_price +from solders.compute_budget import set_compute_unit_limit, set_compute_unit_price from solders.instruction import Instruction from solders.keypair import Keypair from solders.pubkey import Pubkey @@ -53,16 +23,42 @@ from solders.system_program import ID from solders.system_program import ID as SYS_PROGRAM_ID from solders.sysvar import RENT -from solders.transaction import Legacy -from solders.transaction import TransactionVersion -from spl.token.constants import ASSOCIATED_TOKEN_PROGRAM_ID -from spl.token.constants import TOKEN_PROGRAM_ID -from spl.token.instructions import close_account -from spl.token.instructions import CloseAccountParams -from spl.token.instructions import get_associated_token_address -from spl.token.instructions import initialize_account -from spl.token.instructions import InitializeAccountParams +from solders.transaction import Legacy, TransactionVersion +from spl.token.constants import ASSOCIATED_TOKEN_PROGRAM_ID, TOKEN_PROGRAM_ID +from spl.token.instructions import ( + CloseAccountParams, + InitializeAccountParams, + close_account, + get_associated_token_address, + initialize_account, +) +import driftpy +from driftpy.account_subscription_config import AccountSubscriptionConfig +from driftpy.accounts import * +from driftpy.accounts.cache import CachedDriftClientAccountSubscriber +from driftpy.accounts.demo import DemoDriftClientAccountSubscriber +from driftpy.address_lookup_table import get_address_lookup_table +from driftpy.addresses import get_sequencer_public_key_and_bump +from driftpy.constants import BASE_PRECISION, PRICE_PRECISION +from driftpy.constants.config import ( + DEVNET_SEQUENCER_PROGRAM_ID, + DRIFT_PROGRAM_ID, + SEQUENCER_PROGRAM_ID, + DriftEnv, + configs, +) +from driftpy.constants.numeric_constants import QUOTE_SPOT_MARKET_INDEX +from driftpy.constants.spot_markets import WRAPPED_SOL_MINT +from driftpy.decode.utils import decode_name +from driftpy.drift_user import DriftUser +from driftpy.drift_user_stats import DriftUserStats, UserStatsSubscriptionConfig +from driftpy.math.perp_position import is_available +from driftpy.math.spot_market import cast_to_spot_precision +from driftpy.math.spot_position import is_spot_position_available +from driftpy.name import encode_name +from driftpy.tx.standard_tx_sender import StandardTxSender +from driftpy.tx.types import TxSender, TxSigAndSlot DEFAULT_USER_NAME = "Main Account" @@ -228,8 +224,6 @@ async def subscribe(self): await self.add_user_stats(self.authority) def resurrect(self, spot_markets, perp_markets, spot_oracles, perp_oracles): - from driftpy.accounts.cache import CachedDriftClientAccountSubscriber - if not isinstance(self.account_subscriber, CachedDriftClientAccountSubscriber): raise ValueError( 'You can only resurrect a DriftClient that was initialized with AccountSubscriptionConfig("cached")' @@ -346,14 +340,27 @@ def get_quote_spot_market_account(self) -> Optional[SpotMarketAccount]: return getattr(spot_market_and_slot, "data", None) def get_oracle_price_data(self, oracle_id: str) -> Optional[OraclePriceData]: - oracle_price_data_and_slot = ( - self.account_subscriber.get_oracle_price_data_and_slot(oracle_id) + if self.account_subscriber is None: + return None + + data_and_slot = self.account_subscriber.get_oracle_price_data_and_slot( + oracle_id ) - return getattr(oracle_price_data_and_slot, "data", None) + + if data_and_slot is None: + return None + + return getattr(data_and_slot, "data", None) def get_oracle_price_data_for_perp_market( self, market_index: int ) -> Optional[OraclePriceData]: + if self.account_subscriber is None: + raise ValueError("No account subscriber found") + + if isinstance(self.account_subscriber, DemoDriftClientAccountSubscriber): + raise ValueError("Cannot get market for demo subscriber") + data = self.account_subscriber.get_oracle_price_data_and_slot_for_perp_market( market_index ) @@ -369,6 +376,11 @@ def get_oracle_price_data_for_perp_market( def get_oracle_price_data_for_spot_market( self, market_index: int ) -> Optional[OraclePriceData]: + if self.account_subscriber is None: + return None + if isinstance(self.account_subscriber, DemoDriftClientAccountSubscriber): + raise ValueError("Cannot get market for demo subscriber") + data = self.account_subscriber.get_oracle_price_data_and_slot_for_spot_market( market_index ) diff --git a/src/driftpy/idl/drift.json b/src/driftpy/idl/drift.json index af373bb8..b676e46a 100644 --- a/src/driftpy/idl/drift.json +++ b/src/driftpy/idl/drift.json @@ -1857,6 +1857,27 @@ ], "args": [] }, + { + "name": "logUserBalances", + "accounts": [ + { + "name": "state", + "isMut": false, + "isSigner": false + }, + { + "name": "authority", + "isMut": false, + "isSigner": true + }, + { + "name": "user", + "isMut": true, + "isSigner": false + } + ], + "args": [] + }, { "name": "disableUserHighLeverageMode", "accounts": [ diff --git a/tests/integration/test_oracle_diff_sources.py b/tests/integration/test_oracle_diff_sources.py new file mode 100644 index 00000000..5eee8715 --- /dev/null +++ b/tests/integration/test_oracle_diff_sources.py @@ -0,0 +1,231 @@ +import asyncio +from math import sqrt + +from anchorpy.program.core import Program +from anchorpy.provider import Provider +from anchorpy.pytest_plugin import workspace_fixture +from anchorpy.workspace import WorkspaceType +from pytest import fixture, mark +from pytest_asyncio import fixture as async_fixture +from solana.rpc.commitment import Commitment +from solders.keypair import Keypair +from solders.pubkey import Pubkey + +from driftpy.account_subscription_config import AccountSubscriptionConfig +from driftpy.accounts.bulk_account_loader import BulkAccountLoader +from driftpy.admin import Admin +from driftpy.constants.numeric_constants import ( + PEG_PRECISION, + PRICE_PRECISION, + QUOTE_PRECISION, +) +from driftpy.drift_client import DriftClient +from driftpy.setup.helpers import ( + _create_and_mint_user_usdc, + _create_mint, + initialize_sol_spot_market, + mock_oracle, +) +from driftpy.types import OracleInfo, OracleSource + +workspace = workspace_fixture("protocol-v2", scope="function") + +MANTISSA_SQRT_SCALE = int(sqrt(PRICE_PRECISION)) +AMM_INITIAL_BAA = (5 * 10**13) * MANTISSA_SQRT_SCALE +AMM_INITIAL_QAA = (5 * 10**13) * MANTISSA_SQRT_SCALE +USDC_AMOUNT = 10 * QUOTE_PRECISION + + +@fixture(scope="session") +def event_loop(): + loop = asyncio.get_event_loop_policy().new_event_loop() + yield loop + loop.close() + + +@async_fixture(scope="function") +async def usdc_mint(provider: Provider): + return await _create_mint(provider) + + +@async_fixture(scope="function") +async def user_usdc_account( + usdc_mint: Keypair, + provider: Provider, +): + return await _create_and_mint_user_usdc( + usdc_mint, provider, USDC_AMOUNT * 2, provider.wallet.public_key + ) + + +@fixture(scope="function") +def program(workspace: WorkspaceType) -> Program: + """Create a Program instance.""" + return workspace["drift"] + + +@fixture(scope="function") +def provider(program: Program) -> Provider: + return program.provider + + +@async_fixture(scope="function") +async def admin_client(program: Program, usdc_mint: Keypair) -> Admin: + market_indexes = [0, 1] + spot_market_indexes = [0, 1, 2] + + admin = Admin( + program.provider.connection, + program.provider.wallet, + account_subscription=AccountSubscriptionConfig("polling"), + perp_market_indexes=market_indexes, + spot_market_indexes=spot_market_indexes, + ) + await admin.initialize(usdc_mint.pubkey(), admin_controls_prices=True) + await admin.subscribe() + return admin + + +@async_fixture(scope="function") +async def sol_oracle(workspace: WorkspaceType): + oracle_program = workspace["pyth"] + oracle = await mock_oracle(oracle_program, 3, -7) + return oracle + + +@async_fixture(scope="function") +async def setup_markets(admin_client: Admin, usdc_mint: Keypair, sol_oracle: Pubkey): + # Initialize markets + await admin_client.initialize_spot_market(usdc_mint.pubkey()) + + # Initialize SOL spot markets with different oracle sources + mint = await _create_mint(admin_client.program.provider) + await initialize_sol_spot_market( + admin_client, + sol_oracle, + mint.pubkey(), + oracle_source=OracleSource.Pyth(), # type: ignore + ) + await initialize_sol_spot_market( + admin_client, + sol_oracle, + mint.pubkey(), + oracle_source=OracleSource.Pyth1K(), # type: ignore + ) + + # Initialize perp markets + await admin_client.initialize_perp_market( + 0, + sol_oracle, + AMM_INITIAL_BAA, + AMM_INITIAL_QAA, + 0, + 3 * PEG_PRECISION, + oracle_source=OracleSource.Pyth(), # type: ignore + ) + await admin_client.initialize_perp_market( + 1, + sol_oracle, + AMM_INITIAL_BAA, + AMM_INITIAL_QAA, + 0, + 3000 * PEG_PRECISION, + oracle_source=OracleSource.Pyth1K(), # type: ignore + ) + + +@mark.asyncio +async def test_polling( + program: Program, admin_client: Admin, sol_oracle: Pubkey, setup_markets +): + oracle_infos = [ + OracleInfo(sol_oracle, OracleSource.Pyth()), # type: ignore + OracleInfo(sol_oracle, OracleSource.Pyth1K()), # type: ignore + ] + + polling_client = DriftClient( + program.provider.connection, + program.provider.wallet, + account_subscription=AccountSubscriptionConfig( + "polling", + bulk_account_loader=BulkAccountLoader(program.provider.connection), + ), + spot_market_indexes=[0, 1, 2], + perp_market_indexes=[0, 1], + oracle_infos=oracle_infos, + ) + await polling_client.subscribe() + + # Verify spot market oracles + oracle_data_for_spot_market_1 = ( + polling_client.get_oracle_price_data_for_spot_market(1) + ) + assert oracle_data_for_spot_market_1 is not None + spot_price_1 = oracle_data_for_spot_market_1.price + assert spot_price_1 == 3 * PRICE_PRECISION + + oracle_data_for_spot_market_2 = ( + polling_client.get_oracle_price_data_for_spot_market(2) + ) + assert oracle_data_for_spot_market_2 is not None + spot_price_2 = oracle_data_for_spot_market_2.price + assert spot_price_2 == 3000 * PRICE_PRECISION + + # Verify perp market oracles + oracle_data_for_perp_market_0 = ( + polling_client.get_oracle_price_data_for_perp_market(0) + ) + assert oracle_data_for_perp_market_0 is not None + perp_price_0 = oracle_data_for_perp_market_0.price + assert perp_price_0 == 3 * PRICE_PRECISION + + oracle_data_for_perp_market_1 = ( + polling_client.get_oracle_price_data_for_perp_market(1) + ) + assert oracle_data_for_perp_market_1 is not None + perp_price_1 = oracle_data_for_perp_market_1.price + assert perp_price_1 == 3000 * PRICE_PRECISION + + +@mark.asyncio +async def test_ws( + program: Program, admin_client: Admin, sol_oracle: Pubkey, setup_markets +): + oracle_infos = [ + OracleInfo(sol_oracle, OracleSource.Pyth()), # type: ignore + OracleInfo(sol_oracle, OracleSource.Pyth1K()), # type: ignore + ] + + ws_client = DriftClient( + program.provider.connection, + program.provider.wallet, + account_subscription=AccountSubscriptionConfig( + "websocket", commitment=Commitment("processed") + ), + spot_market_indexes=[0, 1, 2], + perp_market_indexes=[0, 1], + oracle_infos=oracle_infos, + ) + await ws_client.subscribe() + + # Verify spot market oracles + oracle_data_for_spot_market_1 = ws_client.get_oracle_price_data_for_spot_market(1) + assert oracle_data_for_spot_market_1 is not None + spot_price_1 = oracle_data_for_spot_market_1.price + assert spot_price_1 == 3 * PRICE_PRECISION + + oracle_data_for_spot_market_2 = ws_client.get_oracle_price_data_for_spot_market(2) + assert oracle_data_for_spot_market_2 is not None + spot_price_2 = oracle_data_for_spot_market_2.price + assert spot_price_2 == 3000 * PRICE_PRECISION + + # Verify perp market oracles + oracle_data_for_perp_market_0 = ws_client.get_oracle_price_data_for_perp_market(0) + assert oracle_data_for_perp_market_0 is not None + perp_price_0 = oracle_data_for_perp_market_0.price + assert perp_price_0 == 3 * PRICE_PRECISION + + oracle_data_for_perp_market_1 = ws_client.get_oracle_price_data_for_perp_market(1) + assert oracle_data_for_perp_market_1 is not None + perp_price_1 = oracle_data_for_perp_market_1.price + assert perp_price_1 == 3000 * PRICE_PRECISION