diff --git a/core/src/banking_stage/forward_packet_batches_by_accounts.rs b/core/src/banking_stage/forward_packet_batches_by_accounts.rs index e01ca3b213b81e..04c75322eae13b 100644 --- a/core/src/banking_stage/forward_packet_batches_by_accounts.rs +++ b/core/src/banking_stage/forward_packet_batches_by_accounts.rs @@ -1,10 +1,10 @@ use { super::immutable_deserialized_packet::ImmutableDeserializedPacket, + core::num::NonZeroU64, solana_cost_model::{ block_cost_limits, cost_model::CostModel, cost_tracker::{CostTracker, UpdatedCosts}, - transaction_cost::TransactionCost, }, solana_perf::packet::Packet, solana_sdk::{feature_set::FeatureSet, transaction::SanitizedTransaction}, @@ -66,9 +66,8 @@ pub struct ForwardPacketBatchesByAccounts { cost_tracker: CostTracker, // Compute Unit limits for each batch - batch_vote_limit: u64, - batch_block_limit: u64, - batch_account_limit: u64, + batch_block_limit: NonZeroU64, + batch_account_limit: NonZeroU64, } impl ForwardPacketBatchesByAccounts { @@ -81,22 +80,32 @@ impl ForwardPacketBatchesByAccounts { .map(|_| ForwardBatch::default()) .collect(); - let batch_vote_limit = block_cost_limits::MAX_VOTE_UNITS.saturating_div(limit_ratio as u64); + let batch_vote_limit = + NonZeroU64::new(block_cost_limits::MAX_VOTE_UNITS.saturating_div(limit_ratio as u64)) + .expect("batch vote limit must not be zero"); let batch_block_limit = - block_cost_limits::MAX_BLOCK_UNITS.saturating_div(limit_ratio as u64); - let batch_account_limit = - block_cost_limits::MAX_WRITABLE_ACCOUNT_UNITS.saturating_div(limit_ratio as u64); + NonZeroU64::new(block_cost_limits::MAX_BLOCK_UNITS.saturating_div(limit_ratio as u64)) + .expect("batch block limit must not be zero"); + let batch_account_limit = NonZeroU64::new( + block_cost_limits::MAX_WRITABLE_ACCOUNT_UNITS.saturating_div(limit_ratio as u64), + ) + .expect("batch account limit must not be zero"); let mut cost_tracker = CostTracker::default(); cost_tracker.set_limits( - batch_account_limit.saturating_mul(number_of_batches as u64), - batch_block_limit.saturating_mul(number_of_batches as u64), - batch_vote_limit.saturating_mul(number_of_batches as u64), + batch_account_limit + .get() + .saturating_mul(number_of_batches as u64), + batch_block_limit + .get() + .saturating_mul(number_of_batches as u64), + batch_vote_limit + .get() + .saturating_mul(number_of_batches as u64), ); Self { forward_batches, cost_tracker, - batch_vote_limit, batch_block_limit, batch_account_limit, } @@ -111,7 +120,7 @@ impl ForwardPacketBatchesByAccounts { let tx_cost = CostModel::calculate_cost(sanitized_transaction, feature_set); if let Ok(updated_costs) = self.cost_tracker.try_add(&tx_cost) { - let batch_index = self.get_batch_index_by_updated_costs(&tx_cost, &updated_costs); + let batch_index = self.get_batch_index_by_updated_costs(&updated_costs); if let Some(forward_batch) = self.forward_batches.get_mut(batch_index) { forward_batch.forwardable_packets.push(immutable_packet); @@ -143,24 +152,12 @@ impl ForwardPacketBatchesByAccounts { // would be exceeded. Eg, if by block limit, it can be put into batch #1; by vote limit, it can // be put into batch #2; and by account limit, it can be put into batch #3; then it should be // put into batch #3 to satisfy all batch limits. - fn get_batch_index_by_updated_costs( - &self, - tx_cost: &TransactionCost, - updated_costs: &UpdatedCosts, - ) -> usize { - let Some(batch_index_by_block_limit) = - updated_costs.updated_block_cost.checked_div(match tx_cost { - TransactionCost::SimpleVote { .. } => self.batch_vote_limit, - TransactionCost::Transaction(_) => self.batch_block_limit, - }) - else { - unreachable!("batch vote limit or block limit must not be zero") - }; - + fn get_batch_index_by_updated_costs(&self, updated_costs: &UpdatedCosts) -> usize { + let batch_index_by_block_cost = + updated_costs.updated_block_cost / self.batch_block_limit.get(); let batch_index_by_account_limit = - updated_costs.updated_costliest_account_cost / self.batch_account_limit; - - batch_index_by_block_limit.max(batch_index_by_account_limit) as usize + updated_costs.updated_costliest_account_cost / self.batch_account_limit.get(); + batch_index_by_block_cost.max(batch_index_by_account_limit) as usize } } @@ -169,7 +166,6 @@ mod tests { use { super::*, crate::banking_stage::unprocessed_packet_batches::DeserializedPacket, - solana_cost_model::transaction_cost::UsageCostDetails, solana_sdk::{ compute_budget::ComputeBudgetInstruction, feature_set::FeatureSet, message::Message, pubkey::Pubkey, system_instruction, transaction::Transaction, @@ -343,48 +339,16 @@ mod tests { fn test_get_batch_index_by_updated_costs() { let test_cost = 99; - // check against vote limit only - { - let mut forward_packet_batches_by_accounts = - ForwardPacketBatchesByAccounts::new_with_default_batch_limits(); - forward_packet_batches_by_accounts.batch_vote_limit = test_cost + 1; - - let transaction_cost = TransactionCost::SimpleVote { - writable_accounts: vec![], - }; - assert_eq!( - 0, - forward_packet_batches_by_accounts.get_batch_index_by_updated_costs( - &transaction_cost, - &UpdatedCosts { - updated_block_cost: test_cost, - updated_costliest_account_cost: 0 - } - ) - ); - assert_eq!( - 1, - forward_packet_batches_by_accounts.get_batch_index_by_updated_costs( - &transaction_cost, - &UpdatedCosts { - updated_block_cost: test_cost + 1, - updated_costliest_account_cost: 0 - } - ) - ); - } - // check against block limit only { let mut forward_packet_batches_by_accounts = ForwardPacketBatchesByAccounts::new_with_default_batch_limits(); - forward_packet_batches_by_accounts.batch_block_limit = test_cost + 1; + forward_packet_batches_by_accounts.batch_block_limit = + NonZeroU64::new(test_cost + 1).unwrap(); - let transaction_cost = TransactionCost::Transaction(UsageCostDetails::default()); assert_eq!( 0, forward_packet_batches_by_accounts.get_batch_index_by_updated_costs( - &transaction_cost, &UpdatedCosts { updated_block_cost: test_cost, updated_costliest_account_cost: 0 @@ -394,7 +358,6 @@ mod tests { assert_eq!( 1, forward_packet_batches_by_accounts.get_batch_index_by_updated_costs( - &transaction_cost, &UpdatedCosts { updated_block_cost: test_cost + 1, updated_costliest_account_cost: 0 @@ -407,13 +370,12 @@ mod tests { { let mut forward_packet_batches_by_accounts = ForwardPacketBatchesByAccounts::new_with_default_batch_limits(); - forward_packet_batches_by_accounts.batch_account_limit = test_cost + 1; + forward_packet_batches_by_accounts.batch_account_limit = + NonZeroU64::new(test_cost + 1).unwrap(); - let transaction_cost = TransactionCost::Transaction(UsageCostDetails::default()); assert_eq!( 0, forward_packet_batches_by_accounts.get_batch_index_by_updated_costs( - &transaction_cost, &UpdatedCosts { updated_block_cost: 0, updated_costliest_account_cost: test_cost @@ -423,7 +385,6 @@ mod tests { assert_eq!( 1, forward_packet_batches_by_accounts.get_batch_index_by_updated_costs( - &transaction_cost, &UpdatedCosts { updated_block_cost: 0, updated_costliest_account_cost: test_cost + 1 @@ -439,15 +400,14 @@ mod tests { { let mut forward_packet_batches_by_accounts = ForwardPacketBatchesByAccounts::new_with_default_batch_limits(); - forward_packet_batches_by_accounts.batch_block_limit = test_cost + 1; - forward_packet_batches_by_accounts.batch_vote_limit = test_cost / 2 + 1; - forward_packet_batches_by_accounts.batch_account_limit = test_cost / 3 + 1; + forward_packet_batches_by_accounts.batch_block_limit = + NonZeroU64::new(test_cost + 1).unwrap(); + forward_packet_batches_by_accounts.batch_account_limit = + NonZeroU64::new(test_cost / 3 + 1).unwrap(); - let transaction_cost = TransactionCost::Transaction(UsageCostDetails::default()); assert_eq!( 2, forward_packet_batches_by_accounts.get_batch_index_by_updated_costs( - &transaction_cost, &UpdatedCosts { updated_block_cost: test_cost, updated_costliest_account_cost: test_cost