Skip to content

Commit

Permalink
vat (#169)
Browse files Browse the repository at this point in the history
  • Loading branch information
soundsonacid authored Jun 14, 2024
1 parent 84ad89b commit b46ac92
Show file tree
Hide file tree
Showing 5 changed files with 218 additions and 6 deletions.
81 changes: 79 additions & 2 deletions src/driftpy/market_map/market_map.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
import base64
import pickle
import traceback

from typing import Dict, Optional, Union
Expand All @@ -10,8 +11,15 @@

from driftpy.market_map.market_map_config import MarketMapConfig
from driftpy.market_map.websocket_sub import WebsocketSubscription
from driftpy.memcmp import get_market_type_filter
from driftpy.types import PerpMarketAccount, SpotMarketAccount, is_variant
from driftpy.types import (
PerpMarketAccount,
PickledData,
SpotMarketAccount,
is_variant,
compress,
decompress,
market_type_to_string,
)

GenericMarketType = Union[SpotMarketAccount, PerpMarketAccount]

Expand Down Expand Up @@ -87,3 +95,72 @@ async def update_market(
) -> None:
await self.must_get(data.data.market_index, data)
self.market_map[data.data.market_index] = data

def clear(self):
self.market_map.clear()

def get_last_dump_filepath(self) -> str:
return f"{market_type_to_string(self.market_type)}_{self.latest_slot}.pkl"

async def dump(self):
try:
filters = []
if is_variant(self.market_type, "Perp"):
filters.append({"memcmp": {"offset": 0, "bytes": "2pTyMkwXuti"}})
else:
filters.append({"memcmp": {"offset": 0, "bytes": "HqqNdyfVbzv"}})

rpc_request = jsonrpcclient.request(
"getProgramAccounts",
[
str(self.program.program_id),
{"filters": filters, "encoding": "base64", "withContext": True},
],
)

post = self.connection._provider.session.post(
self.connection._provider.endpoint_uri,
json=rpc_request,
headers={"content-encoding": "gzip"},
)

resp = await asyncio.wait_for(post, timeout=30)

parsed_resp = jsonrpcclient.parse(resp.json())

slot = int(parsed_resp.result["context"]["slot"])

self.latest_slot = slot

rpc_response_values = parsed_resp.result["value"]

raw: Dict[str, bytes] = {}

for market in rpc_response_values:
pubkey = market["pubkey"]
raw_bytes = base64.b64decode(market["account"]["data"][0])
raw[str(pubkey)] = raw_bytes

markets = []
for pubkey, market in raw.items():
markets.append(PickledData(pubkey=pubkey, data=compress(market)))
filename = (
f"{market_type_to_string(self.market_type)}_{self.latest_slot}.pkl"
)
with open(filename, "wb") as f:
pickle.dump(markets, f)

except Exception as e:
print(f"error in marketmap pickle: {e}")

async def load(self, filename: Optional[str] = None):
if not filename:
filename = self.get_last_dump_filepath()
start = filename.index("_") + 1
end = filename.index(".")
slot = int(filename[start:end])
with open(filename, "rb") as f:
markets: list[PickledData] = pickle.load(f)
for market in markets:
data = self.program.coder.accounts.decode(decompress(market.data))
await self.add_market(data.market_index, DataAndSlot(slot, data))
89 changes: 89 additions & 0 deletions src/driftpy/pickle/vat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
import pickle
from driftpy.drift_client import DriftClient
from driftpy.market_map.market_map import MarketMap
from driftpy.types import OraclePriceData, PickledData
from driftpy.user_map.user_map import UserMap
from driftpy.user_map.userstats_map import UserStatsMap


# generally this is not intended for use with websocket drift client subscriber
class Vat:
def __init__(
self,
drift_client: DriftClient,
users: UserMap,
user_stats: UserStatsMap,
spot_markets: MarketMap,
perp_markets: MarketMap,
):
self.drift_client = drift_client
self.users = users
self.user_stats = user_stats
self.spot_markets = spot_markets
self.perp_markets = perp_markets
self.last_oracle_slot = 0
self.market_index_to_perp_price = {}
self.market_index_to_spot_price = {}

async def pickle(self):
await self.users.sync()
self.users.dump()

await self.user_stats.sync()
self.user_stats.dump()

await self.spot_markets.dump()
await self.perp_markets.dump()

await self.dump_oracles()

async def unpickle(self):
self.users.clear()
self.user_stats.clear()
self.spot_markets.clear()
self.perp_markets.clear()

await self.users.load()
await self.user_stats.load()
await self.spot_markets.load()
await self.perp_markets.load()

self.load_oracles()

async def dump_oracles(self):
perp_oracles = []
for market in self.drift_client.get_perp_market_accounts():
oracle_price_data = self.drift_client.get_oracle_price_data_for_perp_market(
market.market_index
)
perp_oracles.append(
PickledData(pubkey=market.market_index, data=oracle_price_data)
)

spot_oracles = []
for market in self.drift_client.get_spot_market_accounts():
oracle_price = self.drift_client.get_oracle_price_data_for_spot_market(
market.market_index
)
spot_oracles.append(
PickledData(pubkey=market.market_index, data=oracle_price)
)

self.last_oracle_slot = await self.drift_client.connection.get_slot()

with open(f"perporacles_{self.last_oracle_slot}.pkl", "wb") as f:
pickle.dump(perp_oracles, f)

with open(f"spotoracles_{self.last_oracle_slot}.pkl", "wb") as f:
pickle.dump(spot_oracles, f)

def load_oracles(self):
with open(f"perporacles_{self.last_oracle_slot}.pkl", "rb") as f:
perp_oracles: list[PickledData] = pickle.load(f)
for oracle in perp_oracles:
self.market_index_to_perp_price[oracle.pubkey] = oracle.data

with open(f"spotoracles_{self.last_oracle_slot}.pkl", "rb") as f:
spot_oracles: list[PickledData] = pickle.load(f)
for oracle in spot_oracles:
self.market_index_to_spot_price[oracle.pubkey] = oracle.data
3 changes: 2 additions & 1 deletion src/driftpy/types.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import struct
import zlib
import inspect

Expand Down Expand Up @@ -855,7 +856,7 @@ class UserAccount:


@dataclass
class PickledUser:
class PickledData:
pubkey: Pubkey
data: bytes

Expand Down
6 changes: 3 additions & 3 deletions src/driftpy/user_map/user_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from driftpy.drift_user import DriftUser
from driftpy.account_subscription_config import AccountSubscriptionConfig

from driftpy.types import OrderRecord, PickledUser, UserAccount, compress, decompress
from driftpy.types import OrderRecord, PickledData, UserAccount, compress, decompress

from driftpy.user_map.user_map_config import UserMapConfig, PollingConfig
from driftpy.user_map.websocket_sub import WebsocketSubscription
Expand Down Expand Up @@ -236,15 +236,15 @@ async def load(self, filename: Optional[str] = None):
end = filename.index(".")
slot = int(filename[start:end])
with open(filename, "rb") as f:
users: list[PickledUser] = pickle.load(f)
users: list[PickledData] = pickle.load(f)
for user in users:
data = decode_user(decompress(user.data))
await self.add_pubkey(user.pubkey, DataAndSlot(slot, data))

def dump(self):
users = []
for pubkey, user in self.raw.items():
users.append(PickledUser(pubkey=pubkey, data=compress(user)))
users.append(PickledData(pubkey=pubkey, data=compress(user)))
self.last_dumped_slot = self.get_slot()
filename = f"usermap_{self.last_dumped_slot}.pkl"
with open(filename, "wb") as f:
Expand Down
45 changes: 45 additions & 0 deletions src/driftpy/user_map/userstats_map.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
import base64
import pickle
import traceback

from typing import Dict, Optional
Expand All @@ -18,10 +19,13 @@
LPRecord,
FundingPaymentRecord,
LiquidationRecord,
PickledData,
SettlePnlRecord,
OrderRecord,
OrderActionRecord,
UserStatsAccount,
compress,
decompress,
)
from driftpy.events.types import WrappedEvent
from driftpy.user_map.user_map_config import UserStatsMapConfig
Expand All @@ -33,8 +37,10 @@ def __init__(self, config: UserStatsMapConfig):
self.user_stats_map: Dict[str, DriftUserStats] = {}

self.sync_lock = asyncio.Lock()
self.raw: Dict[str, bytes] = {}
self.drift_client = config.drift_client
self.latest_slot: int = 0
self.last_dumped_slot: int = 0
self.connection = config.connection or config.drift_client.connection

async def subscribe(self):
Expand Down Expand Up @@ -82,12 +88,16 @@ async def sync(self):
rpc_response_values = parsed_resp.result["value"]

program_account_buffer_map: Dict[str, UserStatsAccount] = {}
raw: Dict[str, bytes] = {}

for program_account in rpc_response_values:
pubkey = program_account["pubkey"]
buffer = base64.b64decode(program_account["account"]["data"][0])
data = self.drift_client.program.coder.accounts.decode(buffer)
program_account_buffer_map[str(pubkey)] = data
raw[str(pubkey)] = buffer

self.raw = raw

for pubkey in program_account_buffer_map.keys():
data = program_account_buffer_map.get(pubkey)
Expand Down Expand Up @@ -231,3 +241,38 @@ async def must_get(
if not self.has(pubkey):
await self.add_user_stat(Pubkey.from_string(pubkey), user_stats)
return self.get(pubkey)

def get_last_dump_filepath(self) -> str:
return f"userstats_{self.last_dumped_slot}.pkl"

def clear(self):
self.user_stats_map.clear()

async def load(self, filename: Optional[str] = None):
if not filename:
filename = self.get_last_dump_filepath()
start = filename.index("_") + 1
end = filename.index(".")
slot = int(filename[start:end])
with open(filename, "rb") as f:
user_stats: list[PickledData] = pickle.load(f)
for user_stat in user_stats:
data = self.drift_client.program.coder.accounts.decode(
decompress(user_stat.data)
)
await self.add_user_stat(
Pubkey.from_string(str(user_stat.pubkey)), DataAndSlot(slot, data)
)

def dump(self):
user_stats = []
for _pubkey, user_stat in self.raw.items():
decoded: UserStatsAccount = self.drift_client.program.coder.accounts.decode(
user_stat
)
auth = decoded.authority
user_stats.append(PickledData(pubkey=auth, data=compress(user_stat)))
self.last_dumped_slot = self.latest_slot
filename = f"userstats_{self.last_dumped_slot}.pkl"
with open(filename, "wb") as f:
pickle.dump(user_stats, f, pickle.HIGHEST_PROTOCOL)

0 comments on commit b46ac92

Please sign in to comment.