Skip to content

Commit

Permalink
add ws and polling account subscribers
Browse files Browse the repository at this point in the history
bulk account loader
  • Loading branch information
crispheaney authored Nov 22, 2023
2 parents 971ac53 + ef800e1 commit 8c85414
Show file tree
Hide file tree
Showing 18 changed files with 780 additions and 83 deletions.
171 changes: 171 additions & 0 deletions src/driftpy/accounts/bulk_account_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
import asyncio
from typing import Mapping, Callable, List, Optional
from dataclasses import dataclass
import jsonrpcclient
from base64 import b64decode

from solana.rpc.commitment import Commitment
from solana.rpc.async_api import AsyncClient
from solders.pubkey import Pubkey


@dataclass
class AccountToLoad:
pubkey: Pubkey
callbacks: dict[int, Callable[[bytes, int], None]]


@dataclass
class BufferAndSlot:
slot: int
buffer: Optional[bytes]


GET_MULTIPLE_ACCOUNTS_CHUNK_SIZE = 99


class BulkAccountLoader:
def __init__(
self,
connection: AsyncClient,
commitment: Commitment = "confirmed",
frequency: float = 1,
):
self.connection = connection
self.commitment = commitment
self.frequency = frequency
self.task = None
self.load_task = None
self.callback_id = 0
self.accounts_to_load: dict[str, AccountToLoad] = {}
self.buffer_and_slot_map: dict[str, BufferAndSlot] = {}

def add_account(
self, pubkey: Pubkey, callback: Callable[[bytes, int], None]
) -> int:
existing_size = len(self.accounts_to_load)

callback_id = self.get_callback_id()

pubkey_str = str(pubkey)
existing_account_to_load = self.accounts_to_load.get(pubkey_str)
if existing_account_to_load is not None:
existing_account_to_load.callbacks[callback_id] = callback
else:
callbacks = {}
callbacks[callback_id] = callback
self.accounts_to_load[pubkey_str] = AccountToLoad(pubkey, callbacks)

if existing_size == 0:
self._start_loading()

return callback_id

def get_callback_id(self) -> int:
self.callback_id += 1
return self.callback_id

def _start_loading(self):
if self.task is None:

async def loop():
while True:
await self.load()
await asyncio.sleep(self.frequency)

self.task = asyncio.create_task(loop())

def remove_account(self, pubkey: Pubkey, callback_id: int):
pubkey_str = str(pubkey)
existing_account_to_load = self.accounts_to_load.get(pubkey_str)
if existing_account_to_load is not None:
del existing_account_to_load.callbacks[callback_id]
if len(existing_account_to_load.callbacks) == 0:
del self.accounts_to_load[pubkey_str]

if len(self.accounts_to_load) == 0:
self._stop_loading()

def _stop_loading(self):
if self.task is not None:
self.task.cancel()
self.task = None

def chunks(self, array: List, size: int) -> List[List]:
return [array[i : i + size] for i in range(0, len(array), size)]

async def load(self):
chunks = self.chunks(
self.chunks(
list(self.accounts_to_load.values()),
GET_MULTIPLE_ACCOUNTS_CHUNK_SIZE,
),
10,
)

await asyncio.gather(*[self.load_chunk(chunk) for chunk in chunks])

async def load_chunk(self, chunk: List[List[AccountToLoad]]):
if len(chunk) == 0:
return

rpc_requests = []
for accounts_to_load in chunk:
pubkeys_to_send = [
str(accounts_to_load.pubkey) for accounts_to_load in accounts_to_load
]
rpc_request = jsonrpcclient.request(
"getMultipleAccounts",
params=[
pubkeys_to_send,
{"encoding": "base64", "commitment": self.commitment},
],
)
rpc_requests.append(rpc_request)

try:
post = self.connection._provider.session.post(
self.connection._provider.endpoint_uri,
json=rpc_requests,
headers={"content-encoding": "gzip"},
)
resp = await asyncio.wait_for(post, timeout=10)
except asyncio.TimeoutError:
print("request to rpc timed out")
return

parsed_resp = jsonrpcclient.parse(resp.json())

for rpc_result, chunk_accounts in zip(parsed_resp, chunk):
if isinstance(rpc_result, jsonrpcclient.Error):
print(f"Failed to get info about accounts: {rpc_result.message}")
continue

slot = rpc_result.result["context"]["slot"]

for i, account_to_load in enumerate(chunk_accounts):
pubkey_str = str(account_to_load.pubkey)
old_buffer_and_slot = self.buffer_and_slot_map.get(pubkey_str)

if old_buffer_and_slot is not None and slot <= old_buffer_and_slot.slot:
continue

new_buffer = None
if rpc_result.result["value"][i] is not None:
new_buffer = b64decode(rpc_result.result["value"][i]["data"][0])

if (
old_buffer_and_slot is None
or new_buffer != old_buffer_and_slot.buffer
):
self.handle_callbacks(account_to_load, new_buffer, slot)
self.buffer_and_slot_map[pubkey_str] = BufferAndSlot(
slot, new_buffer
)

def handle_callbacks(
self, account_to_load: AccountToLoad, buffer: Optional[bytes], slot: int
):
for cb in account_to_load.callbacks.values():
if bytes is not None:
cb(buffer, slot)
8 changes: 7 additions & 1 deletion src/driftpy/accounts/cache/drift_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ def __init__(self, program: Program, commitment: Commitment = "confirmed"):
self.commitment = commitment
self.cache = None

async def subscribe(self):
await self.cache_if_needed()

async def update_cache(self):
if self.cache is None:
self.cache = {}
Expand Down Expand Up @@ -83,7 +86,7 @@ async def get_spot_market_and_slot(
await self.cache_if_needed()
return self.cache["spot_markets"][market_index]

async def get_oracle_data_and_slot(
async def get_oracle_price_data_and_slot(
self, oracle: Pubkey
) -> Optional[DataAndSlot[OraclePriceData]]:
await self.cache_if_needed()
Expand All @@ -92,3 +95,6 @@ async def get_oracle_data_and_slot(
async def cache_if_needed(self):
if self.cache is None:
await self.update_cache()

def unsubscribe(self):
self.cache = None
6 changes: 6 additions & 0 deletions src/driftpy/accounts/cache/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ def __init__(
self.user_pubkey = user_pubkey
self.user_and_slot = None

async def subscribe(self):
await self.cache_if_needed()

async def update_cache(self):
user_and_slot = await get_user_account_and_slot(self.program, self.user_pubkey)
self.user_and_slot = user_and_slot
Expand All @@ -32,3 +35,6 @@ async def get_user_account_and_slot(self) -> Optional[DataAndSlot[User]]:
async def cache_if_needed(self):
if self.user_and_slot is None:
await self.update_cache()

def unsubscribe(self):
self.user_and_slot = None
11 changes: 8 additions & 3 deletions src/driftpy/accounts/get_accounts.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import base64
from typing import cast
from typing import cast, Optional, Callable
from solders.pubkey import Pubkey
from anchorpy import Program, ProgramAccount
from solana.rpc.commitment import Commitment
Expand All @@ -10,7 +10,10 @@


async def get_account_data_and_slot(
address: Pubkey, program: Program, commitment: Commitment = "processed"
address: Pubkey,
program: Program,
commitment: Commitment = "processed",
decode: Optional[Callable[[bytes], T]] = None,
) -> Optional[DataAndSlot[T]]:
account_info = await program.provider.connection.get_account_info(
address,
Expand All @@ -24,7 +27,9 @@ async def get_account_data_and_slot(
slot = account_info.context.slot
data = account_info.value.data

decoded_data = program.coder.accounts.decode(data)
decoded_data = (
decode(data) if decode is not None else program.coder.accounts.decode(data)
)

return DataAndSlot(slot, decoded_data)

Expand Down
57 changes: 29 additions & 28 deletions src/driftpy/accounts/oracle.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from solders.pubkey import Pubkey
from pythclient.pythaccounts import PythPriceInfo, _ACCOUNT_HEADER_BYTES, EmaType
from solana.rpc.async_api import AsyncClient
import base64
import struct


Expand All @@ -21,26 +20,12 @@ async def get_oracle_price_data_and_slot(
if "Pyth" in str(oracle_source):
rpc_reponse = await connection.get_account_info(address)
rpc_response_slot = rpc_reponse.context.slot
(pyth_price_info, last_slot, twac, twap) = await _parse_pyth_price_info(
rpc_reponse
)

scale = 1
if "1K" in str(oracle_source):
scale = 1e3
elif "1M" in str(oracle_source):
scale = 1e6

oracle_data = OraclePriceData(
price=convert_pyth_price(pyth_price_info.price, scale),
slot=pyth_price_info.pub_slot,
confidence=convert_pyth_price(pyth_price_info.confidence_interval, scale),
twap=convert_pyth_price(twap, scale),
twap_confidence=convert_pyth_price(twac, scale),
has_sufficient_number_of_datapoints=True,
oracle_price_data = decode_pyth_price_info(
rpc_reponse.value.data, oracle_source
)

return DataAndSlot(data=oracle_data, slot=rpc_response_slot)
return DataAndSlot(data=oracle_price_data, slot=rpc_response_slot)
elif "Quote" in str(oracle_source):
return DataAndSlot(
data=OraclePriceData(PRICE_PRECISION, 0, 1, 1, 0, True), slot=0
Expand All @@ -49,11 +34,10 @@ async def get_oracle_price_data_and_slot(
raise NotImplementedError("Unsupported Oracle Source", str(oracle_source))


async def _parse_pyth_price_info(
resp: GetAccountInfoResp,
) -> (PythPriceInfo, int, int, int):
buffer = resp.value.data

def decode_pyth_price_info(
buffer: bytes,
oracle_source=OracleSource.PYTH(),
) -> OraclePriceData:
offset = _ACCOUNT_HEADER_BYTES
_, exponent, _ = struct.unpack_from("<IiI", buffer, offset)

Expand All @@ -73,9 +57,26 @@ async def _parse_pyth_price_info(

offset += 160

return (
PythPriceInfo.deserialise(buffer, offset, exponent=exponent),
last_slot,
twac,
twap,
pyth_price_info = PythPriceInfo.deserialise(buffer, offset, exponent=exponent)

scale = 1
if "1K" in str(oracle_source):
scale = 1e3
elif "1M" in str(oracle_source):
scale = 1e6

return OraclePriceData(
price=convert_pyth_price(pyth_price_info.price, scale),
slot=pyth_price_info.pub_slot,
confidence=convert_pyth_price(pyth_price_info.confidence_interval, scale),
twap=convert_pyth_price(twap, scale),
twap_confidence=convert_pyth_price(twac, scale),
has_sufficient_number_of_datapoints=True,
)


def get_oracle_decode_fn(oracle_source: OracleSource):
if "Pyth" in str(oracle_source):
return lambda data: decode_pyth_price_info(data, oracle_source)
else:
raise Exception("Unknown oracle source")
2 changes: 2 additions & 0 deletions src/driftpy/accounts/polling/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .drift_client import *
from .user import *
Loading

0 comments on commit 8c85414

Please sign in to comment.