Skip to content

Commit

Permalink
Calls inspect_account() on the fee payer (anza-xyz#2718)
Browse files Browse the repository at this point in the history
  • Loading branch information
brooksprumo authored and ray-kast committed Nov 27, 2024
1 parent b09ea22 commit 883a50a
Show file tree
Hide file tree
Showing 4 changed files with 462 additions and 5 deletions.
125 changes: 122 additions & 3 deletions svm/src/account_loader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -608,12 +608,15 @@ mod tests {
transaction::{Result, SanitizedTransaction, Transaction, TransactionError},
transaction_context::{TransactionAccount, TransactionContext},
},
std::{borrow::Cow, collections::HashMap, sync::Arc},
std::{borrow::Cow, cell::RefCell, collections::HashMap, sync::Arc},
};

#[derive(Default)]
struct TestCallbacks {
accounts_map: HashMap<Pubkey, AccountSharedData>,
#[allow(clippy::type_complexity)]
inspected_accounts:
RefCell<HashMap<Pubkey, Vec<(Option<AccountSharedData>, /* is_writable */ bool)>>>,
}

impl TransactionProcessingCallback for TestCallbacks {
Expand All @@ -624,6 +627,23 @@ mod tests {
fn get_account_shared_data(&self, pubkey: &Pubkey) -> Option<AccountSharedData> {
self.accounts_map.get(pubkey).cloned()
}

fn inspect_account(
&self,
address: &Pubkey,
account_state: AccountState,
is_writable: bool,
) {
let account = match account_state {
AccountState::Dead => None,
AccountState::Alive(account) => Some(account.clone()),
};
self.inspected_accounts
.borrow_mut()
.entry(*address)
.or_default()
.push((account, is_writable));
}
}

fn load_accounts_with_features_and_rent(
Expand All @@ -640,7 +660,10 @@ mod tests {
for (pubkey, account) in accounts {
accounts_map.insert(*pubkey, account.clone());
}
let callbacks = TestCallbacks { accounts_map };
let callbacks = TestCallbacks {
accounts_map,
..Default::default()
};
load_accounts(
&callbacks,
&[sanitized_tx],
Expand Down Expand Up @@ -929,7 +952,10 @@ mod tests {
for (pubkey, account) in accounts {
accounts_map.insert(*pubkey, account.clone());
}
let callbacks = TestCallbacks { accounts_map };
let callbacks = TestCallbacks {
accounts_map,
..Default::default()
};
load_accounts(
&callbacks,
&[tx],
Expand Down Expand Up @@ -2108,4 +2134,97 @@ mod tests {
assert_eq!(account.rent_epoch(), 0);
assert_eq!(account.lamports(), 0);
}

// Ensure `TransactionProcessingCallback::inspect_account()` is called when
// loading accounts for transaction processing.
#[test]
fn test_inspect_account_non_fee_payer() {
let mut mock_bank = TestCallbacks::default();

let address0 = Pubkey::new_unique(); // <-- fee payer
let address1 = Pubkey::new_unique(); // <-- initially alive
let address2 = Pubkey::new_unique(); // <-- initially dead
let address3 = Pubkey::new_unique(); // <-- program

let mut account0 = AccountSharedData::default();
account0.set_lamports(1_000_000_000);
mock_bank.accounts_map.insert(address0, account0.clone());

let mut account1 = AccountSharedData::default();
account1.set_lamports(2_000_000_000);
mock_bank.accounts_map.insert(address1, account1.clone());

// account2 *not* added to the bank's accounts_map

let mut account3 = AccountSharedData::default();
account3.set_lamports(4_000_000_000);
account3.set_executable(true);
account3.set_owner(native_loader::id());
mock_bank.accounts_map.insert(address3, account3.clone());

let message = Message {
account_keys: vec![address0, address1, address2, address3],
header: MessageHeader::default(),
instructions: vec![
CompiledInstruction {
program_id_index: 3,
accounts: vec![0],
data: vec![],
},
CompiledInstruction {
program_id_index: 3,
accounts: vec![1, 2],
data: vec![],
},
CompiledInstruction {
program_id_index: 3,
accounts: vec![1],
data: vec![],
},
],
recent_blockhash: Hash::new_unique(),
};
let sanitized_message = new_unchecked_sanitized_message(message);
let sanitized_transaction = SanitizedTransaction::new_for_tests(
sanitized_message,
vec![Signature::new_unique()],
false,
);
let validation_result = Ok(ValidatedTransactionDetails {
loaded_fee_payer_account: LoadedTransactionAccount {
account: account0.clone(),
..LoadedTransactionAccount::default()
},
..ValidatedTransactionDetails::default()
});
let _load_results = load_accounts(
&mock_bank,
&[sanitized_transaction],
vec![validation_result],
&mut TransactionErrorMetrics::default(),
None,
&FeatureSet::default(),
&RentCollector::default(),
&ProgramCacheForTxBatch::default(),
);

// ensure the loaded accounts are inspected
let mut actual_inspected_accounts: Vec<_> = mock_bank
.inspected_accounts
.borrow()
.iter()
.map(|(k, v)| (*k, v.clone()))
.collect();
actual_inspected_accounts.sort_unstable_by(|a, b| a.0.cmp(&b.0));

let mut expected_inspected_accounts = vec![
// *not* key0, since it is loaded during fee payer validation
(address1, vec![(Some(account1), true)]),
(address2, vec![(None, true)]),
(address3, vec![(Some(account3), false)]),
];
expected_inspected_accounts.sort_unstable_by(|a, b| a.0.cmp(&b.0));

assert_eq!(actual_inspected_accounts, expected_inspected_accounts,);
}
}
94 changes: 93 additions & 1 deletion svm/src/transaction_processor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use {
transaction_account_state_info::TransactionAccountStateInfo,
transaction_error_metrics::TransactionErrorMetrics,
transaction_execution_result::{ExecutedTransaction, TransactionExecutionDetails},
transaction_processing_callback::TransactionProcessingCallback,
transaction_processing_callback::{AccountState, TransactionProcessingCallback},
transaction_processing_result::{ProcessedTransaction, TransactionProcessingResult},
},
log::debug,
Expand Down Expand Up @@ -434,6 +434,12 @@ impl<FG: ForkGraph> TransactionBatchProcessor<FG> {
return Err(TransactionError::AccountNotFound);
};

callbacks.inspect_account(
fee_payer_address,
AccountState::Alive(&fee_payer_account),
true, // <-- is_writable
);

let fee_payer_loaded_rent_epoch = fee_payer_account.rent_epoch();
let fee_payer_rent_debit = collect_rent_from_account(
feature_set,
Expand Down Expand Up @@ -1034,6 +1040,9 @@ mod tests {
#[derive(Default, Clone)]
pub struct MockBankCallback {
pub account_shared_data: Arc<RwLock<HashMap<Pubkey, AccountSharedData>>>,
#[allow(clippy::type_complexity)]
pub inspected_accounts:
Arc<RwLock<HashMap<Pubkey, Vec<(Option<AccountSharedData>, /* is_writable */ bool)>>>>,
}

impl TransactionProcessingCallback for MockBankCallback {
Expand Down Expand Up @@ -1065,6 +1074,24 @@ mod tests {
.unwrap()
.insert(*program_id, account_data);
}

fn inspect_account(
&self,
address: &Pubkey,
account_state: AccountState,
is_writable: bool,
) {
let account = match account_state {
AccountState::Dead => None,
AccountState::Alive(account) => Some(account.clone()),
};
self.inspected_accounts
.write()
.unwrap()
.entry(*address)
.or_default()
.push((account, is_writable));
}
}

#[test]
Expand Down Expand Up @@ -1853,6 +1880,7 @@ mod tests {
mock_accounts.insert(*fee_payer_address, fee_payer_account.clone());
let mock_bank = MockBankCallback {
account_shared_data: Arc::new(RwLock::new(mock_accounts)),
..Default::default()
};

let mut error_counters = TransactionErrorMetrics::default();
Expand Down Expand Up @@ -1930,6 +1958,7 @@ mod tests {
mock_accounts.insert(*fee_payer_address, fee_payer_account.clone());
let mock_bank = MockBankCallback {
account_shared_data: Arc::new(RwLock::new(mock_accounts)),
..Default::default()
};

let mut error_counters = TransactionErrorMetrics::default();
Expand Down Expand Up @@ -2014,6 +2043,7 @@ mod tests {
mock_accounts.insert(*fee_payer_address, fee_payer_account.clone());
let mock_bank = MockBankCallback {
account_shared_data: Arc::new(RwLock::new(mock_accounts)),
..Default::default()
};

let mut error_counters = TransactionErrorMetrics::default();
Expand Down Expand Up @@ -2051,6 +2081,7 @@ mod tests {
mock_accounts.insert(*fee_payer_address, fee_payer_account.clone());
let mock_bank = MockBankCallback {
account_shared_data: Arc::new(RwLock::new(mock_accounts)),
..Default::default()
};

let mut error_counters = TransactionErrorMetrics::default();
Expand Down Expand Up @@ -2086,6 +2117,7 @@ mod tests {
mock_accounts.insert(*fee_payer_address, fee_payer_account.clone());
let mock_bank = MockBankCallback {
account_shared_data: Arc::new(RwLock::new(mock_accounts)),
..Default::default()
};

let mut error_counters = TransactionErrorMetrics::default();
Expand Down Expand Up @@ -2177,6 +2209,7 @@ mod tests {
mock_accounts.insert(*fee_payer_address, fee_payer_account.clone());
let mock_bank = MockBankCallback {
account_shared_data: Arc::new(RwLock::new(mock_accounts)),
..Default::default()
};

let mut error_counters = TransactionErrorMetrics::default();
Expand Down Expand Up @@ -2242,6 +2275,7 @@ mod tests {
mock_accounts.insert(*fee_payer_address, fee_payer_account.clone());
let mock_bank = MockBankCallback {
account_shared_data: Arc::new(RwLock::new(mock_accounts)),
..Default::default()
};

let mut error_counters = TransactionErrorMetrics::default();
Expand Down Expand Up @@ -2294,6 +2328,7 @@ mod tests {

let mock_bank = MockBankCallback {
account_shared_data: Arc::new(RwLock::new(mock_accounts)),
..Default::default()
};

let mut error_counters = TransactionErrorMetrics::default();
Expand All @@ -2318,4 +2353,61 @@ mod tests {
result.err()
);
}

// Ensure `TransactionProcessingCallback::inspect_account()` is called when
// validating the fee payer, since that's when the fee payer account is loaded.
#[test]
fn test_inspect_account_fee_payer() {
let fee_payer_address = Pubkey::new_unique();
let fee_payer_account = AccountSharedData::new_rent_epoch(
123_000_000_000,
0,
&Pubkey::default(),
RENT_EXEMPT_RENT_EPOCH,
);
let mock_bank = MockBankCallback::default();
mock_bank
.account_shared_data
.write()
.unwrap()
.insert(fee_payer_address, fee_payer_account.clone());

let message = new_unchecked_sanitized_message(Message::new_with_blockhash(
&[
ComputeBudgetInstruction::set_compute_unit_limit(2000u32),
ComputeBudgetInstruction::set_compute_unit_price(1_000_000_000),
],
Some(&fee_payer_address),
&Hash::new_unique(),
));
let batch_processor = TransactionBatchProcessor::<TestForkGraph>::default();
batch_processor
.validate_transaction_fee_payer(
&mock_bank,
None,
&message,
CheckedTransactionDetails {
nonce: None,
lamports_per_signature: 5000,
},
&FeatureSet::default(),
&FeeStructure::default(),
&RentCollector::default(),
&mut TransactionErrorMetrics::default(),
)
.unwrap();

// ensure the fee payer is an inspected account
let actual_inspected_accounts: Vec<_> = mock_bank
.inspected_accounts
.read()
.unwrap()
.iter()
.map(|(k, v)| (*k, v.clone()))
.collect();
assert_eq!(
actual_inspected_accounts.as_slice(),
&[(fee_payer_address, vec![(Some(fee_payer_account), true)])],
);
}
}
Loading

0 comments on commit 883a50a

Please sign in to comment.