diff --git a/.bumpversion.cfg b/.bumpversion.cfg index e3e14e89..23412676 100644 --- a/.bumpversion.cfg +++ b/.bumpversion.cfg @@ -1,5 +1,5 @@ [bumpversion] -current_version = 0.7.58 +current_version = 0.7.59 commit = True tag = True tag_name = {new_version} diff --git a/pyproject.toml b/pyproject.toml index 9653ab9b..13c5b1bc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "driftpy" -version = "0.7.58" +version = "0.7.59" description = "A Python client for the Drift DEX" authors = ["x19 ", "bigz ", "frank "] license = "MIT" diff --git a/src/driftpy/__init__.py b/src/driftpy/__init__.py index 4411885f..ad8b4551 100644 --- a/src/driftpy/__init__.py +++ b/src/driftpy/__init__.py @@ -1 +1 @@ -__version__ = "0.7.58" +__version__ = "0.7.59" diff --git a/src/driftpy/accounts/cache/drift_client.py b/src/driftpy/accounts/cache/drift_client.py index 028a1af8..41761f0d 100644 --- a/src/driftpy/accounts/cache/drift_client.py +++ b/src/driftpy/accounts/cache/drift_client.py @@ -169,12 +169,12 @@ def resurrect( for market_index, oracle_price_data in spot_oracles.items(): corresponding_market = self.cache["spot_markets"][market_index] - oracle_pubkey = corresponding_market.oracle + oracle_pubkey = corresponding_market.data.oracle self.cache["oracle_price_data"][str(oracle_pubkey)] = oracle_price_data for market_index, oracle_price_data in perp_oracles.items(): corresponding_market = self.cache["perp_markets"][market_index] - oracle_pubkey = corresponding_market.amm.oracle + oracle_pubkey = corresponding_market.data.amm.oracle self.cache["oracle_price_data"][str(oracle_pubkey)] = oracle_price_data def get_state_account_and_slot(self) -> Optional[DataAndSlot[StateAccount]]: diff --git a/src/driftpy/drift_client.py b/src/driftpy/drift_client.py index c24296a9..54624a2f 100644 --- a/src/driftpy/drift_client.py +++ b/src/driftpy/drift_client.py @@ -296,11 +296,14 @@ def get_oracle_price_data_for_perp_market( data = self.account_subscriber.get_oracle_price_data_and_slot_for_perp_market( market_index ) - return getattr( - data, - "data", - None, - ) + if isinstance(data, DataAndSlot): + return getattr( + data, + "data", + None, + ) + + return data def get_oracle_price_data_for_spot_market( self, market_index: int @@ -308,11 +311,14 @@ def get_oracle_price_data_for_spot_market( data = self.account_subscriber.get_oracle_price_data_and_slot_for_spot_market( market_index ) - return getattr( - data, - "data", - None, - ) + if isinstance(data, DataAndSlot): + return getattr( + data, + "data", + None, + ) + + return data def convert_to_spot_precision(self, amount: Union[int, float], market_index) -> int: spot_market = self.get_spot_market_account(market_index) diff --git a/src/driftpy/market_map/market_map.py b/src/driftpy/market_map/market_map.py index 4472b14d..fc5a4708 100644 --- a/src/driftpy/market_map/market_map.py +++ b/src/driftpy/market_map/market_map.py @@ -1,5 +1,6 @@ import asyncio import base64 +import os import pickle import traceback @@ -156,6 +157,8 @@ async def dump(self): async def load(self, filename: Optional[str] = None): if not filename: filename = self.get_last_dump_filepath() + if not os.path.exists(filename): + raise FileNotFoundError(f"File {filename} not found") start = filename.index("_") + 1 end = filename.index(".") slot = int(filename[start:end]) diff --git a/src/driftpy/pickle/vat.py b/src/driftpy/pickle/vat.py index 1e32aafb..d424c033 100644 --- a/src/driftpy/pickle/vat.py +++ b/src/driftpy/pickle/vat.py @@ -1,4 +1,6 @@ import pickle +import os +from typing import Optional from driftpy.drift_client import DriftClient from driftpy.market_map.market_map import MarketMap from driftpy.types import PickledData @@ -37,21 +39,30 @@ async def pickle(self): await self.dump_oracles() - async def unpickle(self): + async def unpickle( + self, + users_filename: Optional[str] = None, + user_stats_filename: Optional[str] = None, + spot_markets_filename: Optional[str] = None, + perp_markets_filename: Optional[str] = None, + spot_oracles_filename: Optional[str] = None, + perp_oracles_filename: Optional[str] = None, + ): self.users.clear() self.user_stats.clear() self.spot_markets.clear() self.perp_markets.clear() - await self.users.load() - await self.user_stats.load() - await self.spot_markets.load() - await self.perp_markets.load() + await self.users.load(users_filename) + await self.user_stats.load(user_stats_filename) + await self.spot_markets.load(spot_markets_filename) + await self.perp_markets.load(perp_markets_filename) + + self.load_oracles(spot_oracles_filename, perp_oracles_filename) self.drift_client.resurrect( self.spot_markets, self.perp_markets, self.spot_oracles, self.perp_oracles ) - self.load_oracles() async def dump_oracles(self): perp_oracles = [] @@ -80,13 +91,30 @@ async def dump_oracles(self): with open(f"spotoracles_{self.last_oracle_slot}.pkl", "wb") as f: pickle.dump(spot_oracles, f) - def load_oracles(self): - with open(f"perporacles_{self.last_oracle_slot}.pkl", "rb") as f: - perp_oracles: list[PickledData] = pickle.load(f) - for oracle in perp_oracles: - self.perp_oracles[oracle.pubkey] = oracle.data - - with open(f"spotoracles_{self.last_oracle_slot}.pkl", "rb") as f: - spot_oracles: list[PickledData] = pickle.load(f) - for oracle in spot_oracles: - self.spot_oracles[oracle.pubkey] = oracle.data + def load_oracles( + self, spot_filename: Optional[str] = None, perp_filename: Optional[str] = None + ): + if perp_filename is None: + perp_filename = f"perporacles_{self.last_oracle_slot}.pkl" + if spot_filename is None: + spot_filename = f"spotoracles_{self.last_oracle_slot}.pkl" + + if os.path.exists(perp_filename): + with open(perp_filename, "rb") as f: + perp_oracles: list[PickledData] = pickle.load(f) + for oracle in perp_oracles: + self.perp_oracles[ + oracle.pubkey + ] = oracle.data # oracle.pubkey is actually a market index + else: + raise FileNotFoundError(f"File {perp_filename} not found") + + if os.path.exists(spot_filename): + with open(spot_filename, "rb") as f: + spot_oracles: list[PickledData] = pickle.load(f) + for oracle in spot_oracles: + self.spot_oracles[ + oracle.pubkey + ] = oracle.data # oracle.pubkey is actually a market index + else: + raise FileNotFoundError(f"File {spot_filename} not found") diff --git a/src/driftpy/user_map/user_map.py b/src/driftpy/user_map/user_map.py index 55b12c76..2ec80c17 100644 --- a/src/driftpy/user_map/user_map.py +++ b/src/driftpy/user_map/user_map.py @@ -1,4 +1,5 @@ import asyncio +import os import jsonrpcclient import pickle import base64 @@ -232,6 +233,8 @@ def get_last_dump_filepath(self) -> str: async def load(self, filename: Optional[str] = None): if not filename: filename = self.get_last_dump_filepath() + if not os.path.exists(filename): + raise FileNotFoundError(f"File {filename} not found") start = filename.index("_") + 1 end = filename.index(".") slot = int(filename[start:end]) diff --git a/src/driftpy/user_map/userstats_map.py b/src/driftpy/user_map/userstats_map.py index 993cff67..56ef0320 100644 --- a/src/driftpy/user_map/userstats_map.py +++ b/src/driftpy/user_map/userstats_map.py @@ -1,5 +1,6 @@ import asyncio import base64 +import os import pickle import traceback @@ -252,6 +253,8 @@ def clear(self): async def load(self, filename: Optional[str] = None): if not filename: filename = self.get_last_dump_filepath() + if not os.path.exists(filename): + raise FileNotFoundError(f"File {filename} not found") start = filename.index("_") + 1 end = filename.index(".") slot = int(filename[start:end])