diff --git a/examples/if_stake.py b/examples/if_stake.py index 8fab3288..919f2204 100644 --- a/examples/if_stake.py +++ b/examples/if_stake.py @@ -28,6 +28,14 @@ async def view_logs( connection._commitment = commitment.Processed pprint.pprint(logs) +async def does_account_exist( + connection, address +): + rpc_resp = await connection.get_account_info(address) + if rpc_resp["result"]["value"] is None: + return False + return True + async def main( keypath, env, @@ -63,10 +71,11 @@ async def main( print('ATA addr:', ata) if operation == 'add' or operation == 'remove' and spot_market_index == 1: - from spl.token.instructions import create_associated_token_account - ix = create_associated_token_account(ch.authority, ch.authority, spot_market.mint) - await ch.send_ixs(ix) ata = get_associated_token_address(ch.authority, spot_market.mint) + if not does_account_exist(connection, ata): + from spl.token.instructions import create_associated_token_account + ix = create_associated_token_account(ch.authority, ch.authority, spot_market.mint) + await ch.send_ixs(ix) ch.spot_market_atas[spot_market_index] = ata # send to WSOL and sync @@ -94,12 +103,10 @@ async def main( print('confirmation failed exiting...') return - rpc_resp = ( - await connection.get_account_info(get_insurance_fund_stake_public_key( - ch.program_id, kp.public_key, spot_market_index - )) + if_addr = get_insurance_fund_stake_public_key( + ch.program_id, kp.public_key, spot_market_index ) - if rpc_resp["result"]["value"] is None: + if not does_account_exist(connection, if_addr): print('initializing stake account...') sig = await ch.initialize_insurance_fund_stake(spot_market_index) print(sig) @@ -107,6 +114,12 @@ async def main( print('adding stake ....') sig = await ch.add_insurance_fund_stake(spot_market_index, if_amount) print(sig) + + elif operation == 'cancel': + print('canceling...') + sig = await ch.cancel_request_remove_insurance_fund_stake(spot_market_index) + print(sig) + elif operation == 'remove': resp = input('confirm removing stake: Y?') if resp != 'Y': @@ -136,10 +149,15 @@ async def main( await view_logs(ix, connection) print('removing if stake...') - ix = await ch.remove_insurance_fund_stake( - spot_market_index - ) - await view_logs(ix, connection) + try: + ix = await ch.remove_insurance_fund_stake( + spot_market_index + ) + await view_logs(ix, connection) + except Exception as e: + print('unable to unstake -- likely bc not enough time has passed since request') + print(e) + return elif operation == 'view': if_stake = await get_if_stake_account(ch.program, ch.authority, spot_market_index) @@ -182,7 +200,7 @@ async def main( parser.add_argument('--env', type=str, default='devnet') parser.add_argument('--amount', type=float, required=False) parser.add_argument('--market', type=int, required=True) - parser.add_argument('--operation', choices=['remove', 'add', 'view', 'settle'], required=True) + parser.add_argument('--operation', choices=['remove', 'add', 'view', 'settle', 'cancel'], required=True) args = parser.parse_args() diff --git a/src/driftpy/clearing_house.py b/src/driftpy/clearing_house.py index d184bd22..4f46dbfd 100644 --- a/src/driftpy/clearing_house.py +++ b/src/driftpy/clearing_house.py @@ -1297,6 +1297,44 @@ async def get_request_remove_insurance_fund_stake_ix( ), ) + async def cancel_request_remove_insurance_fund_stake(self, spot_market_index: int): + return await self.send_ixs( + await self.get_cancel_request_remove_insurance_fund_stake_ix(spot_market_index) + ) + + async def get_cancel_request_remove_insurance_fund_stake_ix(self, spot_market_index: int): + ra = await self.get_remaining_accounts( + writable_spot_market_index=spot_market_index + ) + + return self.program.instruction["cancel_request_remove_insurance_fund_stake"]( + spot_market_index, + ctx=Context( + accounts={ + "state": get_state_public_key(self.program_id), + "spot_market": get_spot_market_public_key( + self.program_id, spot_market_index + ), + "insurance_fund_stake": get_insurance_fund_stake_public_key( + self.program_id, self.authority, spot_market_index + ), + "user_stats": get_user_stats_account_public_key( + self.program_id, self.authority + ), + "authority": self.authority, + "insurance_fund_vault": get_insurance_fund_vault_public_key( + self.program_id, spot_market_index + ), + "drift_signer": get_clearing_house_signer_public_key( + self.program_id + ), + "user_token_account": self.spot_market_atas[spot_market_index], + "token_program": TOKEN_PROGRAM_ID, + }, + remaining_accounts=ra, + ), + ) + async def remove_insurance_fund_stake(self, spot_market_index: int): return await self.send_ixs( await self.get_remove_insurance_fund_stake_ix(spot_market_index)