Skip to content

Commit

Permalink
account subscriber getters are sync
Browse files Browse the repository at this point in the history
  • Loading branch information
crispheaney committed Nov 26, 2023
1 parent c6037eb commit 193f2ee
Show file tree
Hide file tree
Showing 12 changed files with 427 additions and 402 deletions.
294 changes: 175 additions & 119 deletions examples/limit_order_grid.py

Large diffs are not rendered by default.

145 changes: 75 additions & 70 deletions examples/start_lp.py
Original file line number Diff line number Diff line change
@@ -1,49 +1,50 @@
import sys
sys.path.append('../src/')

sys.path.append("../src/")

from driftpy.constants.config import configs
from anchorpy import Provider
import json
import json
from anchorpy import Wallet
from solana.rpc.async_api import AsyncClient
from driftpy.drift_client import DriftClient
from driftpy.accounts import *
from solana.keypair import Keypair

# todo: airdrop udsc + init account for any kp
# rn do it through UI
# rn do it through UI
from driftpy.drift_user import DriftUser
from driftpy.constants.numeric_constants import AMM_RESERVE_PRECISION
from solana.rpc import commitment
import pprint
from driftpy.constants.numeric_constants import QUOTE_PRECISION

async def view_logs(
sig: str,
connection: AsyncClient
):
connection._commitment = commitment.Confirmed
logs = ''
try:

async def view_logs(sig: str, connection: AsyncClient):
connection._commitment = commitment.Confirmed
logs = ""
try:
await connection.confirm_transaction(sig, commitment.Confirmed)
logs = (await connection.get_transaction(sig))["result"]["meta"]["logMessages"]
finally:
connection._commitment = commitment.Processed
connection._commitment = commitment.Processed
pprint.pprint(logs)


async def main(
keypath,
env,
url,
keypath,
env,
url,
market_index,
liquidity_amount,
operation,
):
with open(keypath, 'r') as f: secret = json.load(f)
with open(keypath, "r") as f:
secret = json.load(f)
kp = Keypair.from_secret_key(bytes(secret))
print('using public key:', kp.public_key)
print('market:', market_index)
print("using public key:", kp.public_key)
print("market:", market_index)

config = configs[env]
wallet = Wallet(kp)
connection = AsyncClient(url)
Expand All @@ -53,98 +54,102 @@ async def main(
drift_user = User(dc)

total_collateral = await drift_user.get_total_collateral()
print('total collateral:', total_collateral/QUOTE_PRECISION)
print("total collateral:", total_collateral / QUOTE_PRECISION)

if total_collateral == 0:
print('cannot lp with 0 collateral')
print("cannot lp with 0 collateral")
return

market = await get_perp_market_account(
dc.program,
market_index
)
market = await get_perp_market_account(dc.program, market_index)
lp_amount = liquidity_amount * AMM_RESERVE_PRECISION
lp_amount -= lp_amount % market.amm.order_step_size
lp_amount = int(lp_amount)
print('standardized lp amount:', lp_amount / AMM_RESERVE_PRECISION)
print("standardized lp amount:", lp_amount / AMM_RESERVE_PRECISION)

if lp_amount < market.amm.order_step_size:
print('lp amount too small - exiting...')


print(f'{operation}ing {lp_amount} lp shares...')
print("lp amount too small - exiting...")

print(f"{operation}ing {lp_amount} lp shares...")

sig = None
if operation == 'add':
resp = input('confirm adding liquidity: Y?')
if resp != 'Y':
print('confirmation failed exiting...')
if operation == "add":
resp = input("confirm adding liquidity: Y?")
if resp != "Y":
print("confirmation failed exiting...")
return
sig = await dc.add_liquidity(lp_amount, market_index)
print(sig)

elif operation == 'remove':
resp = input('confirm removing liquidity: Y?')
if resp != 'Y':
print('confirmation failed exiting...')
elif operation == "remove":
resp = input("confirm removing liquidity: Y?")
if resp != "Y":
print("confirmation failed exiting...")
return
sig = await dc.remove_liquidity(lp_amount, market_index)
print(sig)

elif operation == 'view':
elif operation == "view":
pass

elif operation == 'settle':
resp = input('confirm settling revenue to if stake: Y?')
if resp != 'Y':
print('confirmation failed exiting...')
elif operation == "settle":
resp = input("confirm settling revenue to if stake: Y?")
if resp != "Y":
print("confirmation failed exiting...")
return
sig = await dc.settle_lp(dc.authority, market_index)
print(sig)
else:

else:
return

if sig:
print('confirming tx...')
print("confirming tx...")
await connection.confirm_transaction(sig)

position = await dc.get_user_position(market_index)
position = dc.get_user_position(market_index)
market = await get_perp_market_account(dc.program, market_index)
percent_provided = (position.lp_shares / market.amm.sqrt_k) * 100
percent_provided = (position.lp_shares / market.amm.sqrt_k) * 100
print(f"lp shares: {position.lp_shares}")
print(f"providing {percent_provided}% of total market liquidity")
print('done! :)')
print("done! :)")

if __name__ == '__main__':

if __name__ == "__main__":
import argparse
import os
import os

parser = argparse.ArgumentParser()
parser.add_argument('--keypath', type=str, required=False, default=os.environ.get('ANCHOR_WALLET'))
parser.add_argument('--env', type=str, default='devnet')
parser.add_argument('--amount', type=float, required=False)
parser.add_argument('--market', type=int, required=True)
parser.add_argument('--operation', choices=['remove', 'add', 'view', 'settle'], required=True)
parser.add_argument(
"--keypath", type=str, required=False, default=os.environ.get("ANCHOR_WALLET")
)
parser.add_argument("--env", type=str, default="devnet")
parser.add_argument("--amount", type=float, required=False)
parser.add_argument("--market", type=int, required=True)
parser.add_argument(
"--operation", choices=["remove", "add", "view", "settle"], required=True
)
args = parser.parse_args()

if args.keypath is None:
raise NotImplementedError("need to provide keypath or set ANCHOR_WALLET")

match args.env:
case 'devnet':
url = 'https://api.devnet.solana.com'
case 'mainnet':
url = 'https://api.mainnet-beta.solana.com'
case "devnet":
url = "https://api.devnet.solana.com"
case "mainnet":
url = "https://api.mainnet-beta.solana.com"
case _:
raise NotImplementedError('only devnet/mainnet env supported')
raise NotImplementedError("only devnet/mainnet env supported")

import asyncio
asyncio.run(main(
args.keypath,
args.env,
url,
args.market,
args.amount,
args.operation,
))

asyncio.run(
main(
args.keypath,
args.env,
url,
args.market,
args.amount,
args.operation,
)
)
18 changes: 5 additions & 13 deletions src/driftpy/accounts/cache/drift_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def __init__(self, program: Program, commitment: Commitment = "confirmed"):
self.cache = None

async def subscribe(self):
await self.cache_if_needed()
await self.update_cache()

async def update_cache(self):
if self.cache is None:
Expand Down Expand Up @@ -75,31 +75,23 @@ async def update_cache(self):

self.cache["oracle_price_data"] = oracle_data

async def get_state_account_and_slot(self) -> Optional[DataAndSlot[StateAccount]]:
await self.cache_if_needed()
def get_state_account_and_slot(self) -> Optional[DataAndSlot[StateAccount]]:
return self.cache["state"]

async def get_perp_market_and_slot(
def get_perp_market_and_slot(
self, market_index: int
) -> Optional[DataAndSlot[PerpMarketAccount]]:
await self.cache_if_needed()
return self.cache["perp_markets"][market_index]

async def get_spot_market_and_slot(
def get_spot_market_and_slot(
self, market_index: int
) -> Optional[DataAndSlot[SpotMarketAccount]]:
await self.cache_if_needed()
return self.cache["spot_markets"][market_index]

async def get_oracle_price_data_and_slot(
def get_oracle_price_data_and_slot(
self, oracle: Pubkey
) -> Optional[DataAndSlot[OraclePriceData]]:
await self.cache_if_needed()
return self.cache["oracle_price_data"][str(oracle)]

async def cache_if_needed(self):
if self.cache is None:
await self.update_cache()

def unsubscribe(self):
self.cache = None
9 changes: 2 additions & 7 deletions src/driftpy/accounts/cache/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,19 +22,14 @@ def __init__(
self.user_and_slot = None

async def subscribe(self):
await self.cache_if_needed()
await self.update_cache()

async def update_cache(self):
user_and_slot = await get_user_account_and_slot(self.program, self.user_pubkey)
self.user_and_slot = user_and_slot

async def get_user_account_and_slot(self) -> Optional[DataAndSlot[UserAccount]]:
await self.cache_if_needed()
def get_user_account_and_slot(self) -> Optional[DataAndSlot[UserAccount]]:
return self.user_and_slot

async def cache_if_needed(self):
if self.user_and_slot is None:
await self.update_cache()

def unsubscribe(self):
self.user_and_slot = None
8 changes: 4 additions & 4 deletions src/driftpy/accounts/polling/drift_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,20 +149,20 @@ def unsubscribe(self):
)
self.callbacks.clear()

async def get_state_account_and_slot(self) -> Optional[DataAndSlot[StateAccount]]:
def get_state_account_and_slot(self) -> Optional[DataAndSlot[StateAccount]]:
return self.state

async def get_perp_market_and_slot(
def get_perp_market_and_slot(
self, market_index: int
) -> Optional[DataAndSlot[PerpMarketAccount]]:
return self.perp_markets.get(market_index)

async def get_spot_market_and_slot(
def get_spot_market_and_slot(
self, market_index: int
) -> Optional[DataAndSlot[SpotMarketAccount]]:
return self.spot_markets.get(market_index)

async def get_oracle_price_data_and_slot(
def get_oracle_price_data_and_slot(
self, oracle: Pubkey
) -> Optional[DataAndSlot[OraclePriceData]]:
return self.oracle.get(str(oracle))
2 changes: 1 addition & 1 deletion src/driftpy/accounts/polling/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,5 +70,5 @@ def unsubscribe(self):

self.callback_id = None

async def get_user_account_and_slot(self) -> Optional[DataAndSlot[UserAccount]]:
def get_user_account_and_slot(self) -> Optional[DataAndSlot[UserAccount]]:
return self.data_and_slot
10 changes: 5 additions & 5 deletions src/driftpy/accounts/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,23 +32,23 @@ def unsubscribe(self):
pass

@abstractmethod
async def get_state_account_and_slot(self) -> Optional[DataAndSlot[StateAccount]]:
def get_state_account_and_slot(self) -> Optional[DataAndSlot[StateAccount]]:
pass

@abstractmethod
async def get_perp_market_and_slot(
def get_perp_market_and_slot(
self, market_index: int
) -> Optional[DataAndSlot[PerpMarketAccount]]:
pass

@abstractmethod
async def get_spot_market_and_slot(
def get_spot_market_and_slot(
self, market_index: int
) -> Optional[DataAndSlot[SpotMarketAccount]]:
pass

@abstractmethod
async def get_oracle_price_data_and_slot(
def get_oracle_price_data_and_slot(
self, oracle: Pubkey
) -> Optional[DataAndSlot[OraclePriceData]]:
pass
Expand All @@ -64,5 +64,5 @@ def unsubscribe(self):
pass

@abstractmethod
async def get_user_account_and_slot(self) -> Optional[DataAndSlot[UserAccount]]:
def get_user_account_and_slot(self) -> Optional[DataAndSlot[UserAccount]]:
pass
8 changes: 4 additions & 4 deletions src/driftpy/accounts/ws/drift_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,20 +92,20 @@ async def subscribe_to_oracle(self, oracle: Pubkey, oracle_source: OracleSource)
await oracle_subscriber.subscribe()
self.oracle_subscribers[str(oracle)] = oracle_subscriber

async def get_state_account_and_slot(self) -> Optional[DataAndSlot[StateAccount]]:
def get_state_account_and_slot(self) -> Optional[DataAndSlot[StateAccount]]:
return self.state_subscriber.data_and_slot

async def get_perp_market_and_slot(
def get_perp_market_and_slot(
self, market_index: int
) -> Optional[DataAndSlot[PerpMarketAccount]]:
return self.perp_market_subscribers[market_index].data_and_slot

async def get_spot_market_and_slot(
def get_spot_market_and_slot(
self, market_index: int
) -> Optional[DataAndSlot[SpotMarketAccount]]:
return self.spot_market_subscribers[market_index].data_and_slot

async def get_oracle_price_data_and_slot(
def get_oracle_price_data_and_slot(
self, oracle: Pubkey
) -> Optional[DataAndSlot[OraclePriceData]]:
return self.oracle_subscribers[str(oracle)].data_and_slot
Expand Down
2 changes: 1 addition & 1 deletion src/driftpy/accounts/ws/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,5 @@
class WebsocketUserAccountSubscriber(
WebsocketAccountSubscriber[UserAccount], UserAccountSubscriber
):
async def get_user_account_and_slot(self) -> Optional[DataAndSlot[UserAccount]]:
def get_user_account_and_slot(self) -> Optional[DataAndSlot[UserAccount]]:
return self.data_and_slot
Loading

0 comments on commit 193f2ee

Please sign in to comment.