Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

account subscription ws #40

Merged
merged 8 commits into from
Nov 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion src/driftpy/accounts/cache/drift_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ def __init__(self, program: Program, commitment: Commitment = "confirmed"):
self.commitment = commitment
self.cache = None

async def subscribe(self):
await self.cache_if_needed()

async def update_cache(self):
if self.cache is None:
self.cache = {}
Expand Down Expand Up @@ -83,7 +86,7 @@ async def get_spot_market_and_slot(
await self.cache_if_needed()
return self.cache["spot_markets"][market_index]

async def get_oracle_data_and_slot(
async def get_oracle_price_data_and_slot(
self, oracle: Pubkey
) -> Optional[DataAndSlot[OraclePriceData]]:
await self.cache_if_needed()
Expand All @@ -92,3 +95,6 @@ async def get_oracle_data_and_slot(
async def cache_if_needed(self):
if self.cache is None:
await self.update_cache()

def unsubscribe(self):
self.cache = None
6 changes: 6 additions & 0 deletions src/driftpy/accounts/cache/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ def __init__(
self.user_pubkey = user_pubkey
self.user_and_slot = None

async def subscribe(self):
await self.cache_if_needed()

async def update_cache(self):
user_and_slot = await get_user_account_and_slot(self.program, self.user_pubkey)
self.user_and_slot = user_and_slot
Expand All @@ -32,3 +35,6 @@ async def get_user_account_and_slot(self) -> Optional[DataAndSlot[User]]:
async def cache_if_needed(self):
if self.user_and_slot is None:
await self.update_cache()

def unsubscribe(self):
self.user_and_slot = None
11 changes: 8 additions & 3 deletions src/driftpy/accounts/get_accounts.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import base64
from typing import cast
from typing import cast, Optional, Callable
from solders.pubkey import Pubkey
from anchorpy import Program, ProgramAccount
from solana.rpc.commitment import Commitment
Expand All @@ -10,7 +10,10 @@


async def get_account_data_and_slot(
address: Pubkey, program: Program, commitment: Commitment = "processed"
address: Pubkey,
program: Program,
commitment: Commitment = "processed",
decode: Optional[Callable[[bytes], T]] = None,
) -> Optional[DataAndSlot[T]]:
account_info = await program.provider.connection.get_account_info(
address,
Expand All @@ -24,7 +27,9 @@ async def get_account_data_and_slot(
slot = account_info.context.slot
data = account_info.value.data

decoded_data = program.coder.accounts.decode(data)
decoded_data = (
decode(data) if decode is not None else program.coder.accounts.decode(data)
)

return DataAndSlot(slot, decoded_data)

Expand Down
50 changes: 22 additions & 28 deletions src/driftpy/accounts/oracle.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from solders.pubkey import Pubkey
from pythclient.pythaccounts import PythPriceInfo, _ACCOUNT_HEADER_BYTES, EmaType
from solana.rpc.async_api import AsyncClient
import base64
import struct


Expand All @@ -21,26 +20,12 @@ async def get_oracle_price_data_and_slot(
if "Pyth" in str(oracle_source):
rpc_reponse = await connection.get_account_info(address)
rpc_response_slot = rpc_reponse.context.slot
(pyth_price_info, last_slot, twac, twap) = await _parse_pyth_price_info(
rpc_reponse
)

scale = 1
if "1K" in str(oracle_source):
scale = 1e3
elif "1M" in str(oracle_source):
scale = 1e6

oracle_data = OraclePriceData(
price=convert_pyth_price(pyth_price_info.price, scale),
slot=pyth_price_info.pub_slot,
confidence=convert_pyth_price(pyth_price_info.confidence_interval, scale),
twap=convert_pyth_price(twap, scale),
twap_confidence=convert_pyth_price(twac, scale),
has_sufficient_number_of_datapoints=True,
oracle_price_data = decode_pyth_price_info(
rpc_reponse.value.data, oracle_source
)

return DataAndSlot(data=oracle_data, slot=rpc_response_slot)
return DataAndSlot(data=oracle_price_data, slot=rpc_response_slot)
elif "Quote" in str(oracle_source):
return DataAndSlot(
data=OraclePriceData(PRICE_PRECISION, 0, 1, 1, 0, True), slot=0
Expand All @@ -49,11 +34,10 @@ async def get_oracle_price_data_and_slot(
raise NotImplementedError("Unsupported Oracle Source", str(oracle_source))


async def _parse_pyth_price_info(
resp: GetAccountInfoResp,
) -> (PythPriceInfo, int, int, int):
buffer = resp.value.data

def decode_pyth_price_info(
buffer: bytes,
oracle_source=OracleSource.PYTH(),
) -> OraclePriceData:
offset = _ACCOUNT_HEADER_BYTES
_, exponent, _ = struct.unpack_from("<IiI", buffer, offset)

Expand All @@ -73,9 +57,19 @@ async def _parse_pyth_price_info(

offset += 160

return (
PythPriceInfo.deserialise(buffer, offset, exponent=exponent),
last_slot,
twac,
twap,
pyth_price_info = PythPriceInfo.deserialise(buffer, offset, exponent=exponent)

scale = 1
if "1K" in str(oracle_source):
scale = 1e3
elif "1M" in str(oracle_source):
scale = 1e6

return OraclePriceData(
price=convert_pyth_price(pyth_price_info.price, scale),
slot=pyth_price_info.pub_slot,
confidence=convert_pyth_price(pyth_price_info.confidence_interval, scale),
twap=convert_pyth_price(twap, scale),
twap_confidence=convert_pyth_price(twac, scale),
has_sufficient_number_of_datapoints=True,
)
10 changes: 9 additions & 1 deletion src/driftpy/accounts/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,14 @@ class DataAndSlot(Generic[T]):


class DriftClientAccountSubscriber:
@abstractmethod
async def subscribe(self):
pass

@abstractmethod
def unsubscribe(self):
pass

@abstractmethod
async def get_state_account_and_slot(self) -> Optional[DataAndSlot[State]]:
pass
Expand All @@ -40,7 +48,7 @@ async def get_spot_market_and_slot(
pass

@abstractmethod
async def get_oracle_data_and_slot(
async def get_oracle_price_data_and_slot(
self, oracle: Pubkey
) -> Optional[DataAndSlot[OraclePriceData]]:
pass
Expand Down
2 changes: 2 additions & 0 deletions src/driftpy/accounts/ws/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .drift_client import *
from .user import *
92 changes: 92 additions & 0 deletions src/driftpy/accounts/ws/account_subscriber.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
import asyncio
from typing import Optional

from anchorpy import Program
from solders.pubkey import Pubkey
from solana.rpc.commitment import Commitment

from driftpy.accounts import get_account_data_and_slot
from driftpy.accounts import UserAccountSubscriber, DataAndSlot

import websockets
import websockets.exceptions # force eager imports
from solana.rpc.websocket_api import connect

from typing import cast, Generic, TypeVar, Callable

T = TypeVar("T")


class WebsocketAccountSubscriber(UserAccountSubscriber, Generic[T]):
def __init__(
self,
pubkey: Pubkey,
program: Program,
commitment: Commitment = "confirmed",
decode: Optional[Callable[[bytes], T]] = None,
):
self.program = program
self.commitment = commitment
self.pubkey = pubkey
self.data_and_slot = None
self.task = None
self.decode = (
decode if decode is not None else self.program.coder.accounts.decode
)

async def subscribe(self):
if self.data_and_slot is None:
await self.fetch()

self.task = asyncio.create_task(self.subscribe_ws())
return self.task

async def subscribe_ws(self):
ws_endpoint = self.program.provider.connection._provider.endpoint_uri.replace(
"https", "wss"
).replace("http", "ws")
async for ws in connect(ws_endpoint):
try:
await ws.account_subscribe( # type: ignore
self.pubkey,
commitment=self.commitment,
encoding="base64",
)
first_resp = await ws.recv()
subscription_id = cast(int, first_resp[0].result) # type: ignore

async for msg in ws:
try:
slot = int(msg[0].result.context.slot) # type: ignore

if msg[0].result.value is None:
continue

account_bytes = cast(bytes, msg[0].result.value.data) # type: ignore
decoded_data = self.decode(account_bytes)
self._update_data(DataAndSlot(slot, decoded_data))
except Exception:
print(f"Error processing account data")
break
await ws.account_unsubscribe(subscription_id) # type: ignore
except websockets.exceptions.ConnectionClosed:
print("Websocket closed, reconnecting...")
continue

async def fetch(self):
new_data = await get_account_data_and_slot(
self.pubkey, self.program, self.commitment, self.decode
)

self._update_data(new_data)

def _update_data(self, new_data: Optional[DataAndSlot[T]]):
if new_data is None:
return

if self.data_and_slot is None or new_data.slot > self.data_and_slot.slot:
self.data_and_slot = new_data

def unsubscribe(self):
self.task.cancel()
self.task = None
121 changes: 121 additions & 0 deletions src/driftpy/accounts/ws/drift_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
from anchorpy import Program
from solders.pubkey import Pubkey
from solana.rpc.commitment import Commitment

from driftpy.accounts.types import DriftClientAccountSubscriber, DataAndSlot
from typing import Optional

from driftpy.accounts.ws.account_subscriber import WebsocketAccountSubscriber
from driftpy.types import PerpMarket, SpotMarket, OraclePriceData, State

from driftpy.addresses import *

from driftpy.types import OracleSource

from driftpy.accounts.oracle import decode_pyth_price_info


class WebsocketDriftClientAccountSubscriber(DriftClientAccountSubscriber):
def __init__(self, program: Program, commitment: Commitment = "confirmed"):
self.program = program
self.commitment = commitment
self.state_subscriber = None
self.spot_market_subscribers = {}
self.perp_market_subscribers = {}
self.oracle_subscribers = {}

async def subscribe(self):
state_public_key = get_state_public_key(self.program.program_id)
self.state_subscriber = WebsocketAccountSubscriber[State](
state_public_key, self.program, self.commitment
)
await self.state_subscriber.subscribe()

for i in range(self.state_subscriber.data_and_slot.data.number_of_spot_markets):
await self.subscribe_to_spot_market(i)

for i in range(self.state_subscriber.data_and_slot.data.number_of_markets):
await self.subscribe_to_perp_market(i)

async def subscribe_to_spot_market(self, market_index: int):
if market_index in self.spot_market_subscribers:
return

spot_market_public_key = get_spot_market_public_key(
self.program.program_id, market_index
)
spot_market_subscriber = WebsocketAccountSubscriber[SpotMarket](
spot_market_public_key, self.program, self.commitment
)
await spot_market_subscriber.subscribe()
self.spot_market_subscribers[market_index] = spot_market_subscriber

spot_market = spot_market_subscriber.data_and_slot.data
await self.subscribe_to_oracle(spot_market.oracle, spot_market.oracle_source)

async def subscribe_to_perp_market(self, market_index: int):
if market_index in self.perp_market_subscribers:
return

perp_market_public_key = get_perp_market_public_key(
self.program.program_id, market_index
)
perp_market_subscriber = WebsocketAccountSubscriber[PerpMarket](
perp_market_public_key, self.program, self.commitment
)
await perp_market_subscriber.subscribe()
self.perp_market_subscribers[market_index] = perp_market_subscriber

perp_market = perp_market_subscriber.data_and_slot.data
await self.subscribe_to_oracle(
perp_market.amm.oracle, perp_market.amm.oracle_source
)

async def subscribe_to_oracle(self, oracle: Pubkey, oracle_source: OracleSource):
if oracle == Pubkey.default():
return

if str(oracle) in self.oracle_subscribers:
return

oracle_subscriber = WebsocketAccountSubscriber[OraclePriceData](
oracle,
self.program,
self.commitment,
self._get_oracle_decode_fn(oracle_source),
)
await oracle_subscriber.subscribe()
self.oracle_subscribers[str(oracle)] = oracle_subscriber

def _get_oracle_decode_fn(self, oracle_source: OracleSource):
if "Pyth" in str(oracle_source):
return lambda data: decode_pyth_price_info(data, oracle_source)
else:
raise Exception("Unknown oracle source")

async def get_state_account_and_slot(self) -> Optional[DataAndSlot[State]]:
return self.state_subscriber.data_and_slot

async def get_perp_market_and_slot(
self, market_index: int
) -> Optional[DataAndSlot[PerpMarket]]:
return self.perp_market_subscribers[market_index].data_and_slot

async def get_spot_market_and_slot(
self, market_index: int
) -> Optional[DataAndSlot[SpotMarket]]:
return self.spot_market_subscribers[market_index].data_and_slot

async def get_oracle_price_data_and_slot(
self, oracle: Pubkey
) -> Optional[DataAndSlot[OraclePriceData]]:
return self.oracle_subscribers[str(oracle)].data_and_slot

def unsubscribe(self):
self.state_subscriber.unsubscribe()
for spot_market_subscriber in self.spot_market_subscribers.values():
spot_market_subscriber.unsubscribe()
for perp_market_subscriber in self.perp_market_subscribers.values():
perp_market_subscriber.unsubscribe()
for oracle_subscriber in self.oracle_subscribers.values():
oracle_subscriber.unsubscribe()
Loading
Loading