diff --git a/core/src/banking_stage/forwarder.rs b/core/src/banking_stage/forwarder.rs index 41e8e09fb372d2..d48c1556fa206e 100644 --- a/core/src/banking_stage/forwarder.rs +++ b/core/src/banking_stage/forwarder.rs @@ -311,6 +311,7 @@ mod tests { unprocessed_packet_batches::{DeserializedPacket, UnprocessedPacketBatches}, unprocessed_transaction_storage::ThreadType, }, + solana_client::rpc_client::SerializableTransaction, solana_gossip::cluster_info::{ClusterInfo, Node}, solana_ledger::{blockstore::Blockstore, genesis_utils::GenesisConfigInfo}, solana_perf::packet::PacketFlags, @@ -320,13 +321,25 @@ mod tests { hash::Hash, poh_config::PohConfig, signature::Keypair, signer::Signer, system_transaction, transaction::VersionedTransaction, }, - solana_streamer::recvmmsg::recv_mmsg, - std::sync::atomic::AtomicBool, + solana_streamer::{ + nonblocking::testing_utilities::{ + setup_quic_server_with_sockets, SpawnTestServerResult, TestServerConfig, + }, + quic::rt, + }, + std::{ + sync::atomic::AtomicBool, + time::{Duration, Instant}, + }, tempfile::TempDir, + tokio::time::sleep, }; struct TestSetup { _ledger_dir: TempDir, + blockhash: Hash, + rent_min_balance: u64, + bank_forks: Arc>, poh_recorder: Arc>, exit: Arc, @@ -363,6 +376,9 @@ mod tests { TestSetup { _ledger_dir: ledger_path, + blockhash: genesis_config.hash(), + rent_min_balance: genesis_config.rent.minimum_balance(0), + bank_forks, poh_recorder, exit, @@ -372,11 +388,52 @@ mod tests { } } + async fn check_all_received( + socket: UdpSocket, + expected_num_packets: usize, + expected_packet_size: usize, + expected_blockhash: &Hash, + ) { + let SpawnTestServerResult { + join_handle, + exit, + receiver, + server_address: _, + stats: _, + } = setup_quic_server_with_sockets(vec![socket], None, TestServerConfig::default()); + + let now = Instant::now(); + let mut total_packets = 0; + while now.elapsed().as_secs() < 5 { + if let Ok(packets) = receiver.try_recv() { + total_packets += packets.len(); + for packet in packets.iter() { + assert_eq!(packet.meta().size, expected_packet_size); + let tx: VersionedTransaction = packet.deserialize_slice(..).unwrap(); + assert_eq!( + tx.get_recent_blockhash(), + expected_blockhash, + "Unexpected blockhash, tx: {tx:?}, expected blockhash: {expected_blockhash}." + ); + } + } else { + sleep(Duration::from_millis(100)).await; + } + if total_packets >= expected_num_packets { + break; + } + } + assert_eq!(total_packets, expected_num_packets); + + exit.store(true, Ordering::Relaxed); + join_handle.await.unwrap(); + } + #[test] - #[ignore] fn test_forwarder_budget() { - solana_logger::setup(); let TestSetup { + blockhash, + rent_min_balance, bank_forks, poh_recorder, exit, @@ -390,17 +447,21 @@ mod tests { let tx = system_transaction::transfer( &Keypair::new(), &solana_sdk::pubkey::new_rand(), - 1, - Hash::new_unique(), + rent_min_balance, + blockhash, ); - let packet = Packet::from_data(None, tx).unwrap(); + let mut packet = Packet::from_data(None, tx).unwrap(); + // unstaked transactions will not be forwarded + packet.meta_mut().set_from_staked_node(true); + let expected_packet_size = packet.meta().size; let deserialized_packet = DeserializedPacket::new(packet).unwrap(); let test_cases = vec![ ("budget-restricted", DataBudget::restricted(), 0), ("budget-available", DataBudget::default(), 1), ]; - for (name, data_budget, expected_num_forwarded) in test_cases { + let runtime = rt("solQuicTestRt".to_string()); + for (_name, data_budget, expected_num_forwarded) in test_cases { let mut forwarder = Forwarder::new( poh_recorder.clone(), bank_forks.clone(), @@ -425,14 +486,13 @@ mod tests { &mut TracerPacketStats::new(0), ); - let recv_socket = &local_node.sockets.tpu_forwards[0]; - recv_socket - .set_nonblocking(expected_num_forwarded == 0) - .unwrap(); - - let mut packets = vec![Packet::default(); 2]; - let num_received = recv_mmsg(recv_socket, &mut packets[..]).unwrap_or_default(); - assert_eq!(num_received, expected_num_forwarded, "{name}"); + let recv_socket = &local_node.sockets.tpu_forwards_quic[0]; + runtime.block_on(check_all_received( + (*recv_socket).try_clone().unwrap(), + expected_num_forwarded, + expected_packet_size, + &blockhash, + )); } exit.store(true, Ordering::Relaxed); @@ -440,10 +500,10 @@ mod tests { } #[test] - #[ignore] fn test_handle_forwarding() { - solana_logger::setup(); let TestSetup { + blockhash, + rent_min_balance, bank_forks, poh_recorder, exit, @@ -453,36 +513,58 @@ mod tests { .. } = setup(); - // packets are deserialized upon receiving, failed packets will not be - // forwarded; Therefore need to create real packets here. let keypair = Keypair::new(); let pubkey = solana_sdk::pubkey::new_rand(); - let fwd_block_hash = Hash::new_unique(); + // forwarded packets will not be forwarded again let forwarded_packet = { - let transaction = system_transaction::transfer(&keypair, &pubkey, 1, fwd_block_hash); + let transaction = + system_transaction::transfer(&keypair, &pubkey, rent_min_balance, blockhash); let mut packet = Packet::from_data(None, transaction).unwrap(); packet.meta_mut().flags |= PacketFlags::FORWARDED; DeserializedPacket::new(packet).unwrap() }; - - let normal_block_hash = Hash::new_unique(); - let normal_packet = { - let transaction = system_transaction::transfer(&keypair, &pubkey, 1, normal_block_hash); + // packets from unstaked nodes will not be forwarded + let unstaked_packet = { + let transaction = + system_transaction::transfer(&keypair, &pubkey, rent_min_balance, blockhash); + let packet = Packet::from_data(None, transaction).unwrap(); + DeserializedPacket::new(packet).unwrap() + }; + // packets with incorrect blockhash will be filtered out + let incorrect_blockhash_packet = { + let transaction = + system_transaction::transfer(&keypair, &pubkey, rent_min_balance, Hash::default()); let packet = Packet::from_data(None, transaction).unwrap(); DeserializedPacket::new(packet).unwrap() }; + // maybe also add packet without stake and packet with incorrect blockhash? + let (expected_packet_size, normal_packet) = { + let transaction = system_transaction::transfer(&keypair, &pubkey, 1, blockhash); + let mut packet = Packet::from_data(None, transaction).unwrap(); + packet.meta_mut().set_from_staked_node(true); + (packet.meta().size, DeserializedPacket::new(packet).unwrap()) + }; + let mut unprocessed_packet_batches = UnprocessedTransactionStorage::new_transaction_storage( - UnprocessedPacketBatches::from_iter(vec![forwarded_packet, normal_packet], 2), + UnprocessedPacketBatches::from_iter( + vec![ + forwarded_packet, + unstaked_packet, + incorrect_blockhash_packet, + normal_packet, + ], + 4, + ), ThreadType::Transactions, ); let connection_cache = ConnectionCache::new("connection_cache_test"); let test_cases = vec![ - ("fwd-normal", true, vec![normal_block_hash], 2), - ("fwd-no-op", true, vec![], 2), - ("fwd-no-hold", false, vec![], 0), + ("fwd-normal", true, 2, 1), + ("fwd-no-op", true, 2, 0), + ("fwd-no-hold", false, 0, 0), ]; let mut forwarder = Forwarder::new( @@ -492,7 +574,8 @@ mod tests { Arc::new(connection_cache), Arc::new(DataBudget::default()), ); - for (name, hold, expected_ids, expected_num_unprocessed) in test_cases { + let runtime = rt("solQuicTestRt".to_string()); + for (name, hold, expected_num_unprocessed, expected_num_processed) in test_cases { let stats = BankingStageStats::default(); forwarder.handle_forwarding( &mut unprocessed_packet_batches, @@ -502,24 +585,14 @@ mod tests { &mut TracerPacketStats::new(0), ); - let recv_socket = &local_node.sockets.tpu_forwards[0]; - recv_socket - .set_nonblocking(expected_ids.is_empty()) - .unwrap(); - - let mut packets = vec![Packet::default(); 2]; - let num_received = recv_mmsg(recv_socket, &mut packets[..]).unwrap_or_default(); - assert_eq!(num_received, expected_ids.len(), "{name}"); - for (i, expected_id) in expected_ids.iter().enumerate() { - assert_eq!(packets[i].meta().size, 215); - let recv_transaction: VersionedTransaction = - packets[i].deserialize_slice(..).unwrap(); - assert_eq!( - recv_transaction.message.recent_blockhash(), - expected_id, - "{name}" - ); - } + let recv_socket = &local_node.sockets.tpu_forwards_quic[0]; + + runtime.block_on(check_all_received( + (*recv_socket).try_clone().unwrap(), + expected_num_processed, + expected_packet_size, + &blockhash, + )); let num_unprocessed_packets: usize = unprocessed_packet_batches.len(); assert_eq!(num_unprocessed_packets, expected_num_unprocessed, "{name}"); diff --git a/streamer/src/nonblocking/testing_utilities.rs b/streamer/src/nonblocking/testing_utilities.rs index 4a63458e7c6d74..ab87334c7cc4c9 100644 --- a/streamer/src/nonblocking/testing_utilities.rs +++ b/streamer/src/nonblocking/testing_utilities.rs @@ -136,13 +136,7 @@ pub struct SpawnTestServerResult { pub fn setup_quic_server( option_staked_nodes: Option, - TestServerConfig { - max_connections_per_peer, - max_staked_connections, - max_unstaked_connections, - max_streams_per_ms, - max_connections_per_ipaddr_per_minute, - }: TestServerConfig, + config: TestServerConfig, ) -> SpawnTestServerResult { let sockets = { #[cfg(not(target_os = "windows"))] @@ -171,7 +165,20 @@ pub fn setup_quic_server( vec![UdpSocket::bind("127.0.0.1:0").unwrap()] } }; + setup_quic_server_with_sockets(sockets, option_staked_nodes, config) +} +pub fn setup_quic_server_with_sockets( + sockets: Vec, + option_staked_nodes: Option, + TestServerConfig { + max_connections_per_peer, + max_staked_connections, + max_unstaked_connections, + max_streams_per_ms, + max_connections_per_ipaddr_per_minute, + }: TestServerConfig, +) -> SpawnTestServerResult { let exit = Arc::new(AtomicBool::new(false)); let (sender, receiver) = unbounded(); let keypair = Keypair::new(); diff --git a/streamer/src/quic.rs b/streamer/src/quic.rs index b5f78c753da92c..3d15a42bfdad05 100644 --- a/streamer/src/quic.rs +++ b/streamer/src/quic.rs @@ -153,7 +153,7 @@ pub(crate) fn configure_server( Ok((server_config, cert_chain_pem)) } -fn rt(name: String) -> Runtime { +pub fn rt(name: String) -> Runtime { tokio::runtime::Builder::new_multi_thread() .thread_name(name) .enable_all()