Skip to content

Commit

Permalink
Refactor DriftClient, add test
Browse files Browse the repository at this point in the history
- some formatting
- use oracle id in `PollingDriftClientAccountSubscriber`
- update idl
- add test but it fails
  • Loading branch information
SinaKhalili committed Dec 11, 2024
1 parent b5789d3 commit a4cf18b
Show file tree
Hide file tree
Showing 5 changed files with 349 additions and 65 deletions.
44 changes: 28 additions & 16 deletions src/driftpy/accounts/polling/drift_client.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,28 @@
import asyncio
from typing import Optional, Sequence, Union

from anchorpy import Program
from driftpy.accounts import DataAndSlot
from driftpy.accounts import DriftClientAccountSubscriber
from anchorpy.program.core import Program
from solders.pubkey import Pubkey

from driftpy.accounts import DataAndSlot, DriftClientAccountSubscriber
from driftpy.accounts.bulk_account_loader import BulkAccountLoader
from driftpy.accounts.oracle import get_oracle_decode_fn
from driftpy.addresses import get_perp_market_public_key
from driftpy.addresses import get_spot_market_public_key
from driftpy.addresses import get_state_public_key
from driftpy.addresses import (
get_perp_market_public_key,
get_spot_market_public_key,
get_state_public_key,
)
from driftpy.constants.config import find_all_market_and_oracles
from driftpy.oracles.oracle_id import get_oracle_id
from driftpy.types import OracleInfo
from driftpy.types import OraclePriceData
from driftpy.types import OracleSource
from driftpy.types import PerpMarketAccount
from driftpy.types import SpotMarketAccount
from driftpy.types import stack_trace
from driftpy.types import StateAccount
from solders.pubkey import Pubkey
from driftpy.types import (
OracleInfo,
OraclePriceData,
OracleSource,
PerpMarketAccount,
SpotMarketAccount,
StateAccount,
stack_trace,
)


class PollingDriftClientAccountSubscriber(DriftClientAccountSubscriber):
Expand Down Expand Up @@ -195,9 +199,8 @@ def get_spot_market_and_slot(
return self.spot_markets.get(market_index)

def get_oracle_price_data_and_slot(
self, oracle: Pubkey, oracle_source: OracleSource
self, oracle_id: str
) -> Optional[DataAndSlot[OraclePriceData]]:
oracle_id = get_oracle_id(oracle, oracle_source)
return self.oracle.get(oracle_id)

def get_market_accounts_and_slots(self) -> list[DataAndSlot[PerpMarketAccount]]:
Expand Down Expand Up @@ -241,9 +244,18 @@ async def _set_spot_oracle_map(self):
def get_oracle_price_data_and_slot_for_perp_market(
self, market_index: int
) -> Union[DataAndSlot[OraclePriceData], None]:
print(
"==> PollingDriftClientAccountSubscriber: Getting oracle price data for perp market",
market_index,
)
print(self.perp_markets)
print(self.spot_markets)
perp_market_account = self.get_perp_market_and_slot(market_index)
oracle = self.perp_oracle_map.get(market_index)

print("Perp market account: ", perp_market_account)
print("Oracle: ", oracle)

if not perp_market_account or not oracle:
return None

Expand Down
10 changes: 9 additions & 1 deletion src/driftpy/constants/perp_markets.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from dataclasses import dataclass

from driftpy.types import OracleSource
from solders.pubkey import Pubkey # type: ignore

from driftpy.types import OracleSource


@dataclass
class PerpMarketConfig:
Expand Down Expand Up @@ -648,4 +649,11 @@ class PerpMarketConfig:
oracle=Pubkey.from_string("AmjHowvVkVJApCPUiwV9CdHVFn29LiBYZQqtZQ3xMqdg"),
oracle_source=OracleSource.PythPull(),
),
PerpMarketConfig(
symbol="ME-PERP",
base_asset_symbol="ME",
market_index=61,
oracle=Pubkey.from_string("FLQjrmEPGwbCKRYZ1eYM5FPccHBrCv2cN4GBu3mWfmPH"),
oracle_source=OracleSource.PythPull(),
),
]
108 changes: 60 additions & 48 deletions src/driftpy/drift_client.py
Original file line number Diff line number Diff line change
@@ -1,68 +1,64 @@
import json
import os
from pathlib import Path
import random
import string
from pathlib import Path
from typing import List, Optional, Tuple, Union

import anchorpy
from anchorpy import Context
from anchorpy import Idl
from anchorpy import Program
from anchorpy import Provider
from anchorpy import Wallet
from deprecated import deprecated
import driftpy
from driftpy.account_subscription_config import AccountSubscriptionConfig
from driftpy.accounts import *
from driftpy.address_lookup_table import get_address_lookup_table
from driftpy.addresses import get_sequencer_public_key_and_bump
from driftpy.constants import BASE_PRECISION
from driftpy.constants import PRICE_PRECISION
from driftpy.constants.config import configs
from driftpy.constants.config import DEVNET_SEQUENCER_PROGRAM_ID
from driftpy.constants.config import DRIFT_PROGRAM_ID
from driftpy.constants.config import DriftEnv
from driftpy.constants.config import SEQUENCER_PROGRAM_ID
from driftpy.constants.numeric_constants import QUOTE_SPOT_MARKET_INDEX
from driftpy.constants.spot_markets import WRAPPED_SOL_MINT
from driftpy.decode.utils import decode_name
from driftpy.drift_user import DriftUser
from driftpy.drift_user_stats import DriftUserStats
from driftpy.drift_user_stats import UserStatsSubscriptionConfig
from driftpy.math.perp_position import is_available
from driftpy.math.spot_market import cast_to_spot_precision
from driftpy.math.spot_position import is_spot_position_available
from driftpy.name import encode_name
from driftpy.tx.standard_tx_sender import StandardTxSender
from driftpy.tx.types import TxSender
from driftpy.tx.types import TxSigAndSlot
import requests
from anchorpy import Context, Idl, Program, Provider, Wallet
from deprecated import deprecated
from solana.rpc.async_api import AsyncClient
from solana.rpc.commitment import Processed
from solana.rpc.types import TxOpts
from solana.transaction import AccountMeta
from solders import system_program
from solders.address_lookup_table_account import AddressLookupTableAccount
from solders.compute_budget import set_compute_unit_limit
from solders.compute_budget import set_compute_unit_price
from solders.compute_budget import set_compute_unit_limit, set_compute_unit_price
from solders.instruction import Instruction
from solders.keypair import Keypair
from solders.pubkey import Pubkey
from solders.signature import Signature
from solders.system_program import ID
from solders.system_program import ID as SYS_PROGRAM_ID
from solders.sysvar import RENT
from solders.transaction import Legacy
from solders.transaction import TransactionVersion
from spl.token.constants import ASSOCIATED_TOKEN_PROGRAM_ID
from spl.token.constants import TOKEN_PROGRAM_ID
from spl.token.instructions import close_account
from spl.token.instructions import CloseAccountParams
from spl.token.instructions import get_associated_token_address
from spl.token.instructions import initialize_account
from spl.token.instructions import InitializeAccountParams
from solders.transaction import Legacy, TransactionVersion
from spl.token.constants import ASSOCIATED_TOKEN_PROGRAM_ID, TOKEN_PROGRAM_ID
from spl.token.instructions import (
CloseAccountParams,
InitializeAccountParams,
close_account,
get_associated_token_address,
initialize_account,
)

import driftpy
from driftpy.account_subscription_config import AccountSubscriptionConfig
from driftpy.accounts import *
from driftpy.accounts.cache import CachedDriftClientAccountSubscriber
from driftpy.accounts.demo import DemoDriftClientAccountSubscriber
from driftpy.address_lookup_table import get_address_lookup_table
from driftpy.addresses import get_sequencer_public_key_and_bump
from driftpy.constants import BASE_PRECISION, PRICE_PRECISION
from driftpy.constants.config import (
DEVNET_SEQUENCER_PROGRAM_ID,
DRIFT_PROGRAM_ID,
SEQUENCER_PROGRAM_ID,
DriftEnv,
configs,
)
from driftpy.constants.numeric_constants import QUOTE_SPOT_MARKET_INDEX
from driftpy.constants.spot_markets import WRAPPED_SOL_MINT
from driftpy.decode.utils import decode_name
from driftpy.drift_user import DriftUser
from driftpy.drift_user_stats import DriftUserStats, UserStatsSubscriptionConfig
from driftpy.math.perp_position import is_available
from driftpy.math.spot_market import cast_to_spot_precision
from driftpy.math.spot_position import is_spot_position_available
from driftpy.name import encode_name
from driftpy.tx.standard_tx_sender import StandardTxSender
from driftpy.tx.types import TxSender, TxSigAndSlot

DEFAULT_USER_NAME = "Main Account"

Expand Down Expand Up @@ -228,8 +224,6 @@ async def subscribe(self):
await self.add_user_stats(self.authority)

def resurrect(self, spot_markets, perp_markets, spot_oracles, perp_oracles):
from driftpy.accounts.cache import CachedDriftClientAccountSubscriber

if not isinstance(self.account_subscriber, CachedDriftClientAccountSubscriber):
raise ValueError(
'You can only resurrect a DriftClient that was initialized with AccountSubscriptionConfig("cached")'
Expand Down Expand Up @@ -346,14 +340,27 @@ def get_quote_spot_market_account(self) -> Optional[SpotMarketAccount]:
return getattr(spot_market_and_slot, "data", None)

def get_oracle_price_data(self, oracle_id: str) -> Optional[OraclePriceData]:
oracle_price_data_and_slot = (
self.account_subscriber.get_oracle_price_data_and_slot(oracle_id)
if self.account_subscriber is None:
return None

data_and_slot = self.account_subscriber.get_oracle_price_data_and_slot(
oracle_id
)
return getattr(oracle_price_data_and_slot, "data", None)

if data_and_slot is None:
return None

return getattr(data_and_slot, "data", None)

def get_oracle_price_data_for_perp_market(
self, market_index: int
) -> Optional[OraclePriceData]:
if self.account_subscriber is None:
raise ValueError("No account subscriber found")

if isinstance(self.account_subscriber, DemoDriftClientAccountSubscriber):
raise ValueError("Cannot get market for demo subscriber")

data = self.account_subscriber.get_oracle_price_data_and_slot_for_perp_market(
market_index
)
Expand All @@ -369,6 +376,11 @@ def get_oracle_price_data_for_perp_market(
def get_oracle_price_data_for_spot_market(
self, market_index: int
) -> Optional[OraclePriceData]:
if self.account_subscriber is None:
return None
if isinstance(self.account_subscriber, DemoDriftClientAccountSubscriber):
raise ValueError("Cannot get market for demo subscriber")

data = self.account_subscriber.get_oracle_price_data_and_slot_for_spot_market(
market_index
)
Expand Down
21 changes: 21 additions & 0 deletions src/driftpy/idl/drift.json
Original file line number Diff line number Diff line change
Expand Up @@ -1857,6 +1857,27 @@
],
"args": []
},
{
"name": "logUserBalances",
"accounts": [
{
"name": "state",
"isMut": false,
"isSigner": false
},
{
"name": "authority",
"isMut": false,
"isSigner": true
},
{
"name": "user",
"isMut": true,
"isSigner": false
}
],
"args": []
},
{
"name": "disableUserHighLeverageMode",
"accounts": [
Expand Down
Loading

0 comments on commit a4cf18b

Please sign in to comment.