diff --git a/examples/view.py b/examples/view.py index ef0d8716..e8d0eb49 100644 --- a/examples/view.py +++ b/examples/view.py @@ -3,84 +3,57 @@ from driftpy.constants.config import configs from anchorpy import Provider -import json from anchorpy import Wallet from solana.rpc.async_api import AsyncClient -from driftpy.clearing_house import ClearingHouse, is_available +from driftpy.clearing_house import ClearingHouse from driftpy.accounts import * from solana.keypair import Keypair +from driftpy.math.positions import is_available -# todo: airdrop udsc + init account for any kp -# rn do it through UI from driftpy.clearing_house_user import ClearingHouseUser -from driftpy.constants.numeric_constants import AMM_RESERVE_PRECISION -from solana.rpc import commitment -import pprint - -async def view_logs( - sig: str, - connection: AsyncClient -): - connection._commitment = commitment.Confirmed - logs = '' - try: - await connection.confirm_transaction(sig, commitment.Confirmed) - logs = (await connection.get_transaction(sig))["result"]["meta"]["logMessages"] - finally: - connection._commitment = commitment.Processed - pprint.pprint(logs) async def main( - keypath, - env, - url, + authority, + subaccount, ): - with open(keypath, 'r') as f: secret = json.load(f) - kp = Keypair.from_secret_key(bytes(secret)) - print('using public key:', kp.public_key) + authority = PublicKey(authority) + + import time + s = time.time() + env = 'mainnet' config = configs[env] - wallet = Wallet(kp) - connection = AsyncClient(url) + wallet = Wallet(Keypair()) # throwaway + connection = AsyncClient(config.default_http) provider = Provider(connection, wallet) ch = ClearingHouse.from_config(config, provider) - chu = ClearingHouseUser(ch) + chu = ClearingHouseUser(ch, authority=authority, subaccount_id=subaccount, use_cache=True) + await chu.set_cache() total_collateral = await chu.get_total_collateral() print('total collateral:', total_collateral) print('leverage:', await chu.get_leverage()) - user = await ch.get_user() + user = await chu.get_user() print('perp positions:') for position in user.perp_positions: if not is_available(position): - # market = await get_perp_market_account(ch.program, position.market_index) print('>', position) - + + print(time.time() - s) print('done! :)') if __name__ == '__main__': import argparse - import os parser = argparse.ArgumentParser() - parser.add_argument('--keypath', type=str, required=False, default=os.environ.get('ANCHOR_WALLET')) - parser.add_argument('--env', type=str, default='devnet') + parser.add_argument('--pubkey', type=str, required=True) + parser.add_argument('--subacc', type=int, required=False, default=0) args = parser.parse_args() - if args.keypath is None: - raise NotImplementedError("need to provide keypath or set ANCHOR_WALLET") - - match args.env: - case 'devnet': - url = 'https://api.devnet.solana.com' - case _: - raise NotImplementedError('only devnet env supported') - import asyncio asyncio.run(main( - args.keypath, - args.env, - url, + args.pubkey, + args.subacc )) diff --git a/src/driftpy/clearing_house_user.py b/src/driftpy/clearing_house_user.py index ea831482..a36856ba 100644 --- a/src/driftpy/clearing_house_user.py +++ b/src/driftpy/clearing_house_user.py @@ -57,9 +57,7 @@ def __init__( self.connection = self.program.provider.connection self.subaccount_id = subaccount_id self.use_cache = use_cache - - if self.use_cache: - self.set_cache() + self.cache_is_set = False # cache all state, perpmarket, oracle, etc. in single cache -- user calls reload # when they want to update the data? @@ -69,39 +67,74 @@ def __init__( # if state = cache => get cached_market else get new market async def set_cache(self): + self.cache_is_set = True + self.CACHE = {} state = await get_state_account(self.program) self.CACHE['state'] = state spot_markets = [] + spot_market_oracle_data = [] for i in range(state.number_of_spot_markets): spot_market = await get_spot_market_account( self.program, i ) spot_markets.append(spot_market) + + if i == 0: + spot_market_oracle_data.append(1) + else: + oracle_data = await get_oracle_data(self.connection, spot_market.oracle) + spot_market_oracle_data.append(oracle_data) + self.CACHE['spot_markets'] = spot_markets + self.CACHE['spot_market_oracles'] = spot_market_oracle_data perp_markets = [] + perp_market_oracle_data = [] for i in range(state.number_of_markets): perp_market = await get_perp_market_account( self.program, i ) perp_markets.append(perp_market) + + oracle_data = await get_oracle_data(self.connection, perp_market.amm.oracle) + perp_market_oracle_data.append(oracle_data) + self.CACHE['perp_markets'] = perp_markets + self.CACHE['perp_market_oracles'] = perp_market_oracle_data user = await get_user_account( self.program, self.authority, self.subaccount_id ) self.CACHE['user'] = user + + async def get_spot_oracle_data(self, spot_market: SpotMarket): + if self.use_cache: + assert self.cache_is_set, 'must call clearing_house_user.set_cache() first' + return self.CACHE['spot_market_oracles'][spot_market.market_index] + else: + oracle_data = await get_oracle_data(self.connection, spot_market.oracle) + return oracle_data + + async def get_perp_oracle_data(self, perp_market: PerpMarket): + if self.use_cache: + assert self.cache_is_set, 'must call clearing_house_user.set_cache() first' + return self.CACHE['perp_market_oracles'][perp_market.market_index] + else: + oracle_data = await get_oracle_data(self.connection, perp_market.amm.oracle) + return oracle_data async def get_state(self): if self.use_cache: + assert self.cache_is_set, 'must call clearing_house_user.set_cache() first' return self.CACHE['state'] else: return await get_state_account(self.program) async def get_spot_market(self, i): if self.use_cache: + assert self.cache_is_set, 'must call clearing_house_user.set_cache() first' return self.CACHE['spot_markets'][i] else: return await get_spot_market_account( @@ -110,6 +143,7 @@ async def get_spot_market(self, i): async def get_perp_market(self, i): if self.use_cache: + assert self.cache_is_set, 'must call clearing_house_user.set_cache() first' return self.CACHE['perp_markets'][i] else: return await get_perp_market_account( @@ -118,6 +152,7 @@ async def get_perp_market(self, i): async def get_user(self): if self.use_cache: + assert self.cache_is_set, 'must call clearing_house_user.set_cache() first' return self.CACHE['user'] else: return await get_user_account( @@ -156,7 +191,7 @@ async def get_spot_market_liability( else: continue - oracle_data = await get_oracle_data(self.connection, spot_market.oracle) + oracle_data = await self.get_spot_oracle_data(spot_market) if not include_open_orders: if str(position.balance_type) == "SpotBalanceType.Borrow()": token_amount = get_token_amount( @@ -217,7 +252,7 @@ async def get_total_perp_positon( if position.lp_shares > 0: pass - price = (await get_oracle_data(self.connection, market.amm.oracle)).price + price = (await self.get_perp_oracle_data(market)).price base_asset_amount = ( calculate_worst_case_base_asset_amount(position) if include_open_orders @@ -346,9 +381,8 @@ async def get_unrealized_pnl( continue market = await self.get_perp_market(position.market_index) - oracle_data = await get_oracle_data( - self.program.provider.connection, market.amm.oracle - ) + + oracle_data = await self.get_perp_oracle_data(market) position_unrealized_pnl = calculate_position_pnl( market, position, oracle_data, with_funding ) @@ -385,9 +419,7 @@ async def get_spot_market_asset_value( total_value += spot_token_value continue - oracle_data = await get_oracle_data( - self.program.provider.connection, spot_market.oracle - ) + oracle_data = await self.get_spot_oracle_data(spot_market) if not include_open_orders: if str(position.balance_type) == "SpotBalanceType.Deposit()":