Skip to content

Commit

Permalink
cache oracle prices too
Browse files Browse the repository at this point in the history
  • Loading branch information
0xNineteen committed Nov 15, 2022
1 parent ed65e4a commit b9b52c0
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 58 deletions.
67 changes: 20 additions & 47 deletions examples/view.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,84 +3,57 @@

from driftpy.constants.config import configs
from anchorpy import Provider
import json
from anchorpy import Wallet
from solana.rpc.async_api import AsyncClient
from driftpy.clearing_house import ClearingHouse, is_available
from driftpy.clearing_house import ClearingHouse
from driftpy.accounts import *
from solana.keypair import Keypair
from driftpy.math.positions import is_available

# todo: airdrop udsc + init account for any kp
# rn do it through UI
from driftpy.clearing_house_user import ClearingHouseUser
from driftpy.constants.numeric_constants import AMM_RESERVE_PRECISION
from solana.rpc import commitment
import pprint

async def view_logs(
sig: str,
connection: AsyncClient
):
connection._commitment = commitment.Confirmed
logs = ''
try:
await connection.confirm_transaction(sig, commitment.Confirmed)
logs = (await connection.get_transaction(sig))["result"]["meta"]["logMessages"]
finally:
connection._commitment = commitment.Processed
pprint.pprint(logs)

async def main(
keypath,
env,
url,
authority,
subaccount,
):
with open(keypath, 'r') as f: secret = json.load(f)
kp = Keypair.from_secret_key(bytes(secret))
print('using public key:', kp.public_key)
authority = PublicKey(authority)

import time
s = time.time()

env = 'mainnet'
config = configs[env]
wallet = Wallet(kp)
connection = AsyncClient(url)
wallet = Wallet(Keypair()) # throwaway
connection = AsyncClient(config.default_http)
provider = Provider(connection, wallet)

ch = ClearingHouse.from_config(config, provider)
chu = ClearingHouseUser(ch)
chu = ClearingHouseUser(ch, authority=authority, subaccount_id=subaccount, use_cache=True)
await chu.set_cache()

total_collateral = await chu.get_total_collateral()
print('total collateral:', total_collateral)
print('leverage:', await chu.get_leverage())

user = await ch.get_user()
user = await chu.get_user()
print('perp positions:')
for position in user.perp_positions:
if not is_available(position):
# market = await get_perp_market_account(ch.program, position.market_index)
print('>', position)


print(time.time() - s)
print('done! :)')

if __name__ == '__main__':
import argparse
import os
parser = argparse.ArgumentParser()
parser.add_argument('--keypath', type=str, required=False, default=os.environ.get('ANCHOR_WALLET'))
parser.add_argument('--env', type=str, default='devnet')
parser.add_argument('--pubkey', type=str, required=True)
parser.add_argument('--subacc', type=int, required=False, default=0)
args = parser.parse_args()

if args.keypath is None:
raise NotImplementedError("need to provide keypath or set ANCHOR_WALLET")

match args.env:
case 'devnet':
url = 'https://api.devnet.solana.com'
case _:
raise NotImplementedError('only devnet env supported')

import asyncio
asyncio.run(main(
args.keypath,
args.env,
url,
args.pubkey,
args.subacc
))

54 changes: 43 additions & 11 deletions src/driftpy/clearing_house_user.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,7 @@ def __init__(
self.connection = self.program.provider.connection
self.subaccount_id = subaccount_id
self.use_cache = use_cache

if self.use_cache:
self.set_cache()
self.cache_is_set = False

# cache all state, perpmarket, oracle, etc. in single cache -- user calls reload
# when they want to update the data?
Expand All @@ -69,39 +67,74 @@ def __init__(
# if state = cache => get cached_market else get new market

async def set_cache(self):
self.cache_is_set = True

self.CACHE = {}
state = await get_state_account(self.program)
self.CACHE['state'] = state

spot_markets = []
spot_market_oracle_data = []
for i in range(state.number_of_spot_markets):
spot_market = await get_spot_market_account(
self.program, i
)
spot_markets.append(spot_market)

if i == 0:
spot_market_oracle_data.append(1)
else:
oracle_data = await get_oracle_data(self.connection, spot_market.oracle)
spot_market_oracle_data.append(oracle_data)

self.CACHE['spot_markets'] = spot_markets
self.CACHE['spot_market_oracles'] = spot_market_oracle_data

perp_markets = []
perp_market_oracle_data = []
for i in range(state.number_of_markets):
perp_market = await get_perp_market_account(
self.program, i
)
perp_markets.append(perp_market)

oracle_data = await get_oracle_data(self.connection, perp_market.amm.oracle)
perp_market_oracle_data.append(oracle_data)

self.CACHE['perp_markets'] = perp_markets
self.CACHE['perp_market_oracles'] = perp_market_oracle_data

user = await get_user_account(
self.program, self.authority, self.subaccount_id
)
self.CACHE['user'] = user

async def get_spot_oracle_data(self, spot_market: SpotMarket):
if self.use_cache:
assert self.cache_is_set, 'must call clearing_house_user.set_cache() first'
return self.CACHE['spot_market_oracles'][spot_market.market_index]
else:
oracle_data = await get_oracle_data(self.connection, spot_market.oracle)
return oracle_data

async def get_perp_oracle_data(self, perp_market: PerpMarket):
if self.use_cache:
assert self.cache_is_set, 'must call clearing_house_user.set_cache() first'
return self.CACHE['perp_market_oracles'][perp_market.market_index]
else:
oracle_data = await get_oracle_data(self.connection, perp_market.amm.oracle)
return oracle_data

async def get_state(self):
if self.use_cache:
assert self.cache_is_set, 'must call clearing_house_user.set_cache() first'
return self.CACHE['state']
else:
return await get_state_account(self.program)

async def get_spot_market(self, i):
if self.use_cache:
assert self.cache_is_set, 'must call clearing_house_user.set_cache() first'
return self.CACHE['spot_markets'][i]
else:
return await get_spot_market_account(
Expand All @@ -110,6 +143,7 @@ async def get_spot_market(self, i):

async def get_perp_market(self, i):
if self.use_cache:
assert self.cache_is_set, 'must call clearing_house_user.set_cache() first'
return self.CACHE['perp_markets'][i]
else:
return await get_perp_market_account(
Expand All @@ -118,6 +152,7 @@ async def get_perp_market(self, i):

async def get_user(self):
if self.use_cache:
assert self.cache_is_set, 'must call clearing_house_user.set_cache() first'
return self.CACHE['user']
else:
return await get_user_account(
Expand Down Expand Up @@ -156,7 +191,7 @@ async def get_spot_market_liability(
else:
continue

oracle_data = await get_oracle_data(self.connection, spot_market.oracle)
oracle_data = await self.get_spot_oracle_data(spot_market)
if not include_open_orders:
if str(position.balance_type) == "SpotBalanceType.Borrow()":
token_amount = get_token_amount(
Expand Down Expand Up @@ -217,7 +252,7 @@ async def get_total_perp_positon(
if position.lp_shares > 0:
pass

price = (await get_oracle_data(self.connection, market.amm.oracle)).price
price = (await self.get_perp_oracle_data(market)).price
base_asset_amount = (
calculate_worst_case_base_asset_amount(position)
if include_open_orders
Expand Down Expand Up @@ -346,9 +381,8 @@ async def get_unrealized_pnl(
continue

market = await self.get_perp_market(position.market_index)
oracle_data = await get_oracle_data(
self.program.provider.connection, market.amm.oracle
)

oracle_data = await self.get_perp_oracle_data(market)
position_unrealized_pnl = calculate_position_pnl(
market, position, oracle_data, with_funding
)
Expand Down Expand Up @@ -385,9 +419,7 @@ async def get_spot_market_asset_value(
total_value += spot_token_value
continue

oracle_data = await get_oracle_data(
self.program.provider.connection, spot_market.oracle
)
oracle_data = await self.get_spot_oracle_data(spot_market)

if not include_open_orders:
if str(position.balance_type) == "SpotBalanceType.Deposit()":
Expand Down

0 comments on commit b9b52c0

Please sign in to comment.