diff --git a/tests/bankrun/test_bankrun.py b/tests/bankrun/test_bankrun.py index 18ae5c95..34d4b97c 100644 --- a/tests/bankrun/test_bankrun.py +++ b/tests/bankrun/test_bankrun.py @@ -1,16 +1,18 @@ +from dataclasses import dataclass from pathlib import Path from typing import Optional, Tuple from pytest import mark, raises from solders.account import Account -from solders.bankrun import ProgramTestContext, start +from solders.bankrun import BanksClient, ProgramTestContext, start from solders.clock import Clock from solders.instruction import AccountMeta, Instruction +from solders.keypair import Keypair from solders.message import Message from solders.pubkey import Pubkey from solders.rent import Rent from solders.system_program import transfer -from solders.transaction import TransactionError, VersionedTransaction +from solders.transaction import Transaction, TransactionError, VersionedTransaction async def helloworld_program( @@ -56,8 +58,15 @@ async def helloworld_program_via_set_account( return context, program_id, greeted_pubkey -@mark.asyncio -async def test_helloworld() -> None: +@dataclass +class HelloworldSetup: # noqa: D101 + client: BanksClient + payer: Keypair + msg: Message + greeted_pubkey: Pubkey + + +async def helloworld_setup() -> HelloworldSetup: # https://github.com/solana-labs/example-helloworld/blob/36eb41d1290732786e13bd097668d8676254a139/src/program-rust/tests/lib.rs context, program_id, greeted_pubkey = await helloworld_program() ix = Instruction( @@ -72,7 +81,31 @@ async def test_helloworld() -> None: assert greeted_account_before is not None assert greeted_account_before.data == bytes([0, 0, 0, 0]) msg = Message.new_with_blockhash([ix], payer.pubkey(), blockhash) - tx = VersionedTransaction(msg, [payer]) + return HelloworldSetup(client, payer, msg, greeted_pubkey) + + +@mark.asyncio +async def test_helloworld() -> None: + setup = await helloworld_setup() + msg = setup.msg + payer = setup.payer + client = setup.client + greeted_pubkey = setup.greeted_pubkey + tx = Transaction([payer], msg, msg.recent_blockhash) + await client.process_transaction(tx) + greeted_account_after = await client.get_account(greeted_pubkey) + assert greeted_account_after is not None + assert greeted_account_after.data == bytes([1, 0, 0, 0]) + + +@mark.asyncio +async def test_helloworld_legacy_tx() -> None: + setup = await helloworld_setup() + msg = setup.msg + payer = setup.payer + client = setup.client + greeted_pubkey = setup.greeted_pubkey + tx = Transaction([payer], msg, msg.recent_blockhash) await client.process_transaction(tx) greeted_account_after = await client.get_account(greeted_pubkey) assert greeted_account_after is not None