Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add account subscription interface #36

Merged
merged 9 commits into from
Nov 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
.tmp
.DS_Store
node.txt
accounts/
keypairs/
test-ledger/

Expand Down Expand Up @@ -157,4 +156,6 @@ cython_debug/
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
.idea/

scratch
6 changes: 3 additions & 3 deletions examples/limit_order_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from driftpy.types import *
#MarketType, OrderType, OrderParams, PositionDirection, OrderTriggerCondition
from driftpy.accounts import get_perp_market_account, get_spot_market_account
from driftpy.math.oracle import get_oracle_data
from driftpy.accounts.oracle import get_oracle_price_data_and_slot
from driftpy.math.spot_market import get_signed_token_amount, get_token_amount
from driftpy.drift_client import DriftClient
from driftpy.drift_user import DriftUser
Expand Down Expand Up @@ -118,7 +118,7 @@ async def main(
drift_acct.program, market_index
)
try:
oracle_data = await get_oracle_data(connection, market.amm.oracle)
oracle_data = await get_oracle_price_data_and_slot(connection, market.amm.oracle)
current_price = oracle_data.price/PRICE_PRECISION
except:
current_price = market.amm.historical_oracle_data.last_oracle_price/PRICE_PRECISION
Expand All @@ -132,7 +132,7 @@ async def main(
else:
market = await get_spot_market_account( drift_acct.program, market_index)
try:
oracle_data = await get_oracle_data(connection, market.oracle)
oracle_data = await get_oracle_price_data_and_slot(connection, market.oracle)
current_price = oracle_data.price/PRICE_PRECISION
except:
current_price = market.historical_oracle_data.last_oracle_price/PRICE_PRECISION
Expand Down
70 changes: 0 additions & 70 deletions src/driftpy/accounts.py

This file was deleted.

2 changes: 2 additions & 0 deletions src/driftpy/accounts/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .get_accounts import *
from .types import *
2 changes: 2 additions & 0 deletions src/driftpy/accounts/cache/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .drift_client import *
from .user import *
78 changes: 78 additions & 0 deletions src/driftpy/accounts/cache/drift_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
from anchorpy import Program
from solana.publickey import PublicKey
from solana.rpc.commitment import Commitment

from driftpy.accounts import get_state_account_and_slot, get_spot_market_account_and_slot, \
get_perp_market_account_and_slot
from driftpy.accounts.oracle import get_oracle_price_data_and_slot
from driftpy.accounts.types import DriftClientAccountSubscriber, DataAndSlot
from typing import Optional

from driftpy.types import PerpMarket, SpotMarket, OraclePriceData, State


class CachedDriftClientAccountSubscriber(DriftClientAccountSubscriber):
def __init__(self, program: Program, commitment: Commitment = "confirmed"):
self.program = program
self.commitment = commitment
self.cache = None

async def update_cache(self):
if self.cache is None:
self.cache = {}

state_and_slot = await get_state_account_and_slot(self.program)
self.cache["state"] = state_and_slot

oracle_data = {}

spot_markets = []
for i in range(state_and_slot.data.number_of_spot_markets):
spot_market_and_slot = await get_spot_market_account_and_slot(self.program, i)
spot_markets.append(spot_market_and_slot)

oracle_price_data_and_slot = await get_oracle_price_data_and_slot(
self.program.provider.connection,
spot_market_and_slot.data.oracle,
spot_market_and_slot.data.oracle_source

)
oracle_data[str(spot_market_and_slot.data.oracle)] = oracle_price_data_and_slot

self.cache["spot_markets"] = spot_markets

perp_markets = []
for i in range(state_and_slot.data.number_of_markets):
perp_market_and_slot = await get_perp_market_account_and_slot(self.program, i)
perp_markets.append(perp_market_and_slot)

oracle_price_data_and_slot = await get_oracle_price_data_and_slot(
self.program.provider.connection,
perp_market_and_slot.data.amm.oracle,
perp_market_and_slot.data.amm.oracle_source
)
oracle_data[str(perp_market_and_slot.data.amm.oracle)] = oracle_price_data_and_slot

self.cache["perp_markets"] = perp_markets

self.cache["oracle_price_data"] = oracle_data

async def get_state_account_and_slot(self) -> Optional[DataAndSlot[State]]:
await self.cache_if_needed()
return self.cache["state"]

async def get_perp_market_and_slot(self, market_index: int) -> Optional[DataAndSlot[PerpMarket]]:
await self.cache_if_needed()
return self.cache["perp_markets"][market_index]

async def get_spot_market_and_slot(self, market_index: int) -> Optional[DataAndSlot[SpotMarket]]:
await self.cache_if_needed()
return self.cache["spot_markets"][market_index]

async def get_oracle_data_and_slot(self, oracle: PublicKey) -> Optional[DataAndSlot[OraclePriceData]]:
await self.cache_if_needed()
return self.cache["oracle_price_data"][str(oracle)]

async def cache_if_needed(self):
if self.cache is None:
await self.update_cache()
29 changes: 29 additions & 0 deletions src/driftpy/accounts/cache/user.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from typing import Optional

from anchorpy import Program
from solana.publickey import PublicKey
from solana.rpc.commitment import Commitment

from driftpy.accounts import get_user_account_and_slot
from driftpy.accounts import UserAccountSubscriber, DataAndSlot
from driftpy.types import User


class CachedUserAccountSubscriber(UserAccountSubscriber):
def __init__(self, user_pubkey: PublicKey, program: Program, commitment: Commitment = "confirmed"):
self.program = program
self.commitment = commitment
self.user_pubkey = user_pubkey
self.user_and_slot = None

async def update_cache(self):
user_and_slot = await get_user_account_and_slot(self.program, self.user_pubkey)
self.user_and_slot = user_and_slot

async def get_user_account_and_slot(self) -> Optional[DataAndSlot[User]]:
await self.cache_if_needed()
return self.user_and_slot

async def cache_if_needed(self):
if self.user_and_slot is None:
await self.update_cache()
103 changes: 103 additions & 0 deletions src/driftpy/accounts/get_accounts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
import base64
from typing import cast
from solana.publickey import PublicKey
from anchorpy import Program, ProgramAccount
from solana.rpc.commitment import Commitment

from driftpy.types import *
from driftpy.addresses import *
from .types import DataAndSlot, T


async def get_account_data_and_slot(address: PublicKey, program: Program, commitment: Commitment = "processed") -> Optional[
DataAndSlot[T]]:
account_info = await program.provider.connection.get_account_info(
address,
encoding="base64",
commitment=commitment,
)

if not account_info["result"]["value"]:
return None

slot = account_info["result"]["context"]["slot"]
data = base64.b64decode(account_info["result"]["value"]["data"][0])

decoded_data = program.coder.accounts.decode(data)

return DataAndSlot(slot, decoded_data)


async def get_state_account_and_slot(program: Program) -> DataAndSlot[State]:
state_public_key = get_state_public_key(program.program_id)
return await get_account_data_and_slot(state_public_key, program)


async def get_state_account(program: Program) -> State:
return (await get_state_account_and_slot(program)).data


async def get_if_stake_account(
program: Program, authority: PublicKey, spot_market_index: int
) -> InsuranceFundStake:
if_stake_pk = get_insurance_fund_stake_public_key(
program.program_id, authority, spot_market_index
)
response = await program.account["InsuranceFundStake"].fetch(if_stake_pk)
return cast(InsuranceFundStake, response)


async def get_user_stats_account(
program: Program,
authority: PublicKey,
) -> UserStats:
user_stats_public_key = get_user_stats_account_public_key(
program.program_id,
authority,
)
response = await program.account["UserStats"].fetch(user_stats_public_key)
return cast(UserStats, response)

async def get_user_account_and_slot(
program: Program,
user_public_key: PublicKey,
) -> DataAndSlot[User]:
return await get_account_data_and_slot(user_public_key, program)

async def get_user_account(
program: Program,
user_public_key: PublicKey,
) -> User:
return (await get_user_account_and_slot(program, user_public_key)).data


async def get_perp_market_account_and_slot(program: Program, market_index: int) -> Optional[DataAndSlot[PerpMarket]]:
perp_market_public_key = get_perp_market_public_key(program.program_id, market_index)
return await get_account_data_and_slot(perp_market_public_key, program)


async def get_perp_market_account(program: Program, market_index: int) -> PerpMarket:
return (await get_perp_market_account_and_slot(program, market_index)).data


async def get_all_perp_market_accounts(program: Program) -> list[ProgramAccount]:
return await program.account["PerpMarket"].all()


async def get_spot_market_account_and_slot(
program: Program, spot_market_index: int
) -> DataAndSlot[SpotMarket]:
spot_market_public_key = get_spot_market_public_key(
program.program_id, spot_market_index
)
return await get_account_data_and_slot(spot_market_public_key, program)


async def get_spot_market_account(
program: Program, spot_market_index: int
) -> SpotMarket:
return (await get_spot_market_account_and_slot(program, spot_market_index)).data


async def get_all_spot_market_accounts(program: Program) -> list[ProgramAccount]:
return await program.account["SpotMarket"].all()
Loading
Loading