Skip to content

Commit

Permalink
Merge pull request #212 from drift-labs/crispheaney/oracle-source-test
Browse files Browse the repository at this point in the history
tests for oracle sources
  • Loading branch information
SinaKhalili authored Dec 16, 2024
2 parents 975fad0 + 4fc0e35 commit 4c98403
Show file tree
Hide file tree
Showing 7 changed files with 109 additions and 48 deletions.
6 changes: 6 additions & 0 deletions src/driftpy/accounts/bulk_account_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,12 @@ def add_account(
if existing_size == 0:
self._start_loading()

# If the account is already loaded, call the callback immediately
if existing_account_to_load is not None:
buffer_and_slot = self.buffer_and_slot_map.get(pubkey_str)
if buffer_and_slot is not None and buffer_and_slot.buffer is not None:
self.handle_callbacks(existing_account_to_load, buffer_and_slot.buffer, buffer_and_slot.slot)

return callback_id

def get_callback_id(self) -> int:
Expand Down
67 changes: 37 additions & 30 deletions src/driftpy/accounts/polling/drift_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
get_state_public_key,
)
from driftpy.constants.config import find_all_market_and_oracles
from driftpy.oracles.oracle_id import get_oracle_id
from driftpy.types import (
OracleInfo,
OraclePriceData,
Expand All @@ -38,6 +39,7 @@ def __init__(
self.program = program
self.is_subscribed = False
self.callbacks: dict[str, int] = {}
self.oracle_callbacks: dict[str, int] = {}

self.perp_market_indexes = perp_market_indexes
self.spot_market_indexes = spot_market_indexes
Expand All @@ -49,7 +51,9 @@ def __init__(
self.spot_markets = {}
self.oracle = {}
self.perp_oracle_map: dict[int, Pubkey] = {}
self.perp_oracle_strings_map: dict[int, str] = {}
self.spot_oracle_map: dict[int, Pubkey] = {}
self.spot_oracle_strings_map: dict[int, str] = {}

async def subscribe(self):
if len(self.callbacks) != 0:
Expand Down Expand Up @@ -138,41 +142,38 @@ def cb(buffer: bytes, slot: int):
return cb

async def add_oracle(self, oracle: Pubkey, oracle_source: OracleSource):
if oracle == Pubkey.default() or oracle in self.oracle:
return True

oracle_str = str(oracle)
if oracle_str in self.callbacks:
oracle_id = get_oracle_id(oracle, oracle_source)
if oracle == Pubkey.default() or oracle_id in self.oracle:
return True

callback_id = self.bulk_account_loader.add_account(
oracle, self._get_oracle_callback(oracle_str, oracle_source)
oracle, self._get_oracle_callback(oracle_id, oracle_source)
)
self.callbacks[oracle_str] = callback_id
self.oracle_callbacks[oracle_id] = callback_id

await self._wait_for_oracle(3, oracle_str)
await self._wait_for_oracle(3, oracle_id)

return True

async def _wait_for_oracle(self, tries: int, oracle: str):
async def _wait_for_oracle(self, tries: int, oracle_id: str):
while tries > 0:
await asyncio.sleep(self.bulk_account_loader.frequency)
if oracle in self.bulk_account_loader.buffer_and_slot_map:
if oracle_id in self.oracle:
return
tries -= 1
print(
f"WARNING: Oracle: {oracle} not found after {tries * self.bulk_account_loader.frequency} seconds, Location: {stack_trace()}"
f"WARNING: Oracle: {oracle_id} not found after {tries * self.bulk_account_loader.frequency} seconds, Location: {stack_trace()}"
)

def _get_oracle_callback(self, oracle_str: str, oracle_source: OracleSource):
def _get_oracle_callback(self, oracle_id: str, oracle_source: OracleSource):
decode = get_oracle_decode_fn(oracle_source)

def cb(buffer: bytes, slot: int):
if buffer is None:
return

decoded_data = decode(buffer)
self.oracle[oracle_str] = DataAndSlot(slot, decoded_data)
self.oracle[oracle_id] = DataAndSlot(slot, decoded_data)

return cb

Expand All @@ -181,7 +182,14 @@ async def unsubscribe(self):
self.bulk_account_loader.remove_account(
Pubkey.from_string(pubkey_str), callback_id
)

for oracle_id, callback_id in self.oracle_callbacks.items():
self.bulk_account_loader.remove_account(
Pubkey.from_string(oracle_id.split("-")[0]), callback_id
)

self.callbacks.clear()
self.oracle_callbacks.clear()

def get_state_account_and_slot(self) -> Optional[DataAndSlot[StateAccount]]:
return self.state
Expand All @@ -197,9 +205,9 @@ def get_spot_market_and_slot(
return self.spot_markets.get(market_index)

def get_oracle_price_data_and_slot(
self, oracle: Pubkey
self, oracle_id: str
) -> Optional[DataAndSlot[OraclePriceData]]:
return self.oracle.get(str(oracle))
return self.oracle.get(oracle_id)

def get_market_accounts_and_slots(self) -> list[DataAndSlot[PerpMarketAccount]]:
return [
Expand All @@ -221,53 +229,52 @@ async def _set_perp_oracle_map(self):
market_account = market.data
market_index = market_account.market_index
oracle = market_account.amm.oracle
if oracle not in self.oracle:
await self.add_oracle(oracle, market_account.amm.oracle_source)
oracle_source = market_account.amm.oracle_source
oracle_id = get_oracle_id(oracle, oracle_source)
if oracle_id not in self.oracle:
await self.add_oracle(oracle, oracle_source)
self.perp_oracle_map[market_index] = oracle
self.perp_oracle_strings_map[market_index] = oracle_id

async def _set_spot_oracle_map(self):
spot_markets = self.get_spot_market_accounts_and_slots()
for market in spot_markets:
market_account = market.data
market_index = market_account.market_index
oracle = market_account.oracle
if oracle not in self.oracle:
await self.add_oracle(oracle, market_account.oracle_source)
oracle_source = market_account.oracle_source
oracle_id = get_oracle_id(oracle, oracle_source)
if oracle_id not in self.oracle:
await self.add_oracle(oracle, oracle_source)
self.spot_oracle_map[market_index] = oracle
self.spot_oracle_strings_map[market_index] = oracle_id

def get_oracle_price_data_and_slot_for_perp_market(
self, market_index: int
) -> Union[DataAndSlot[OraclePriceData], None]:
print(
"==> PollingDriftClientAccountSubscriber: Getting oracle price data for perp market",
market_index,
)
print(self.perp_markets)
print(self.spot_markets)
perp_market_account = self.get_perp_market_and_slot(market_index)
oracle = self.perp_oracle_map.get(market_index)

print("Perp market account: ", perp_market_account)
print("Oracle: ", oracle)
oracle_id = self.perp_oracle_strings_map.get(market_index)

if not perp_market_account or not oracle:
return None

if str(perp_market_account.data.amm.oracle) != str(oracle):
asyncio.create_task(self._set_perp_oracle_map())

return self.get_oracle_price_data_and_slot(oracle)
return self.get_oracle_price_data_and_slot(oracle_id)

def get_oracle_price_data_and_slot_for_spot_market(
self, market_index: int
) -> Union[DataAndSlot[OraclePriceData], None]:
spot_market_account = self.get_spot_market_and_slot(market_index)
oracle = self.spot_oracle_map.get(market_index)
oracle_id = self.spot_oracle_strings_map.get(market_index)

if not spot_market_account or not oracle:
return None

if str(spot_market_account.data.oracle) != str(oracle):
asyncio.create_task(self._set_spot_oracle_map())

return self.get_oracle_price_data_and_slot(oracle)
return self.get_oracle_price_data_and_slot(oracle_id)
13 changes: 9 additions & 4 deletions src/driftpy/accounts/ws/drift_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,9 @@ def __init__(
self.spot_market_map = None
self.perp_market_map = None
self.spot_market_oracle_map: dict[int, Pubkey] = {}
self.spot_market_oracle_strings_map: dict[int, str] = {}
self.perp_market_oracle_map: dict[int, Pubkey] = {}
self.perp_market_oracle_strings_map: dict[int, str] = {}

async def subscribe(self):
if self.is_subscribed():
Expand Down Expand Up @@ -178,7 +180,7 @@ async def subscribe_to_oracle(self, full_oracle_wrapper: FullOracleWrapper):

async def subscribe_to_oracle_info(self, oracle_info: OracleInfo):
oracle_id = get_oracle_id(oracle_info.pubkey, oracle_info.source)
if oracle_id == Pubkey.default():
if oracle_info.pubkey == Pubkey.default():
return

if oracle_id in self.oracle_subscribers:
Expand Down Expand Up @@ -221,7 +223,7 @@ async def add_oracle(self, oracle_info: OracleInfo):
if oracle_id in self.oracle_subscribers:
return True

if oracle_id == Pubkey.default():
if oracle_info.pubkey == Pubkey.default():
return True

return await self.subscribe_to_oracle_info(oracle_info)
Expand Down Expand Up @@ -299,6 +301,7 @@ async def _set_perp_oracle_map(self):
OracleInfo(oracle, perp_market_account.amm.oracle_source)
)
self.perp_market_oracle_map[market_index] = oracle
self.perp_market_oracle_strings_map[market_index] = oracle_id

async def _set_spot_oracle_map(self):
spot_markets = self.get_spot_market_accounts_and_slots()
Expand All @@ -315,33 +318,35 @@ async def _set_spot_oracle_map(self):
OracleInfo(oracle, spot_market_account.oracle_source)
)
self.spot_market_oracle_map[market_index] = oracle
self.spot_market_oracle_strings_map[market_index] = oracle_id

def get_oracle_price_data_and_slot_for_perp_market(
self, market_index: int
) -> Union[DataAndSlot[OraclePriceData], None]:
perp_market_account = self.get_perp_market_and_slot(market_index)
oracle = self.perp_market_oracle_map.get(market_index)
oracle_id = self.perp_market_oracle_strings_map.get(market_index)

if not perp_market_account or not oracle:
return None

if perp_market_account.data.amm.oracle != oracle:
asyncio.create_task(self._set_perp_oracle_map())

oracle_id = get_oracle_id(oracle, perp_market_account.data.amm.oracle_source)
return self.get_oracle_price_data_and_slot(oracle_id)

def get_oracle_price_data_and_slot_for_spot_market(
self, market_index: int
) -> Union[DataAndSlot[OraclePriceData], None]:
spot_market_account = self.get_spot_market_and_slot(market_index)
oracle = self.spot_market_oracle_map.get(market_index)
oracle_id = self.spot_market_oracle_strings_map.get(market_index)

if not spot_market_account or not oracle:
return None

if spot_market_account.data.oracle != oracle:
asyncio.create_task(self._set_spot_oracle_map())

oracle_id = get_oracle_id(oracle, spot_market_account.data.oracle_source)
return self.get_oracle_price_data_and_slot(oracle_id)

54 changes: 48 additions & 6 deletions src/driftpy/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
from driftpy.drift_client import (
DriftClient,
)
from driftpy.constants.numeric_constants import PEG_PRECISION
from driftpy.types import OracleGuardRails, OracleSource, PrelaunchOracleParams
from driftpy.constants.numeric_constants import BASE_PRECISION, PEG_PRECISION, PRICE_PRECISION
from driftpy.types import AssetTier, ContractTier, OracleGuardRails, OracleSource, PrelaunchOracleParams
from driftpy.addresses import *
from driftpy.accounts import get_state_account
from driftpy.constants.numeric_constants import (
Expand Down Expand Up @@ -61,10 +61,24 @@ async def initialize_perp_market(
periodicity: int,
peg_multiplier: int = PEG_PRECISION,
oracle_source: OracleSource = OracleSource.Pyth(),
contract_tier: ContractTier = ContractTier.Speculative(),
margin_ratio_initial: int = 2000,
margin_ratio_maintenance: int = 500,
liquidation_fee: int = 0,
liquidator_fee: int = 0,
if_liquidator_fee: int = 10000,
imf_factor: int = 0,
active_status: bool = True,
base_spread: int = 0,
max_spread: int = 142500,
max_open_interest: int = 0,
max_revenue_withdraw_per_period: int = 0,
quote_max_insurance: int = 0,
order_step_size: int = BASE_PRECISION // 10000,
order_tick_size: int = PRICE_PRECISION // 100000,
min_order_size: int = BASE_PRECISION // 10000,
concentration_coef_scale: int = 1,
curve_update_intensity: int = 0,
amm_jit_intensity: int = 0,
name: list = [0] * 32,
) -> Signature:
state_public_key = get_state_public_key(self.program.program_id)
Expand All @@ -81,10 +95,24 @@ async def initialize_perp_market(
periodicity,
peg_multiplier,
oracle_source,
contract_tier,
margin_ratio_initial,
margin_ratio_maintenance,
liquidation_fee,
liquidator_fee,
if_liquidator_fee,
imf_factor,
active_status,
base_spread,
max_spread,
max_open_interest,
max_revenue_withdraw_per_period,
quote_max_insurance,
order_step_size,
order_tick_size,
min_order_size,
concentration_coef_scale,
curve_update_intensity,
amm_jit_intensity,
name,
ctx=Context(
accounts={
Expand All @@ -111,7 +139,14 @@ async def initialize_spot_market(
initial_liability_weight: int = SPOT_WEIGHT_PRECISION,
maintenance_liability_weight: int = SPOT_WEIGHT_PRECISION,
imf_factor: int = 0,
liquidation_fee: int = 0,
liquidator_fee: int = 0,
if_liquidation_fee: int = 0,
scale_initial_asset_weight_start: int = 0,
withdraw_guard_threshold: int = 0,
order_tick_size: int = 1,
order_step_size: int = 1,
if_total_factor: int = 0,
asset_tier: AssetTier = AssetTier.COLLATERAL(),
active_status: bool = True,
name: list = [0] * 32,
):
Expand All @@ -137,8 +172,15 @@ async def initialize_spot_market(
initial_liability_weight,
maintenance_liability_weight,
imf_factor,
liquidation_fee,
liquidator_fee,
if_liquidation_fee,
active_status,
asset_tier,
scale_initial_asset_weight_start,
withdraw_guard_threshold,
order_tick_size,
order_step_size,
if_total_factor,
name,
ctx=Context(
accounts={
Expand Down
4 changes: 2 additions & 2 deletions src/driftpy/setup/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,7 @@ async def mock_oracle(


async def initialize_sol_spot_market(
admin: Admin, sol_oracle: Pubkey, sol_mint: Pubkey = NATIVE_MINT
admin: Admin, sol_oracle: Pubkey, sol_mint: Pubkey = NATIVE_MINT, oracle_source: OracleSource = OracleSource.Pyth()
):
optimal_utilization = SPOT_RATE_PRECISION // 2
optimal_rate = SPOT_RATE_PRECISION * 20
Expand All @@ -352,7 +352,7 @@ async def initialize_sol_spot_market(
optimal_rate,
max_rate,
sol_oracle,
OracleSource.Pyth(),
oracle_source,
initial_asset_weight,
maintenance_asset_weight,
initial_liability_weight,
Expand Down
Loading

0 comments on commit 4c98403

Please sign in to comment.