From 774f0d8251d9a3805347a1c61dd8e533e7e2ba10 Mon Sep 17 00:00:00 2001 From: behzad nouri Date: Sun, 15 Dec 2024 11:32:06 -0600 Subject: [PATCH] removes redundant vector allocations before calling sendmmsg::batch_send streamer::sendmmsg::batch_send only requires an ExactSizeIterator: https://github.com/anza-xyz/agave/blob/566bb9565/streamer/src/sendmmsg.rs#L203-L204 https://github.com/anza-xyz/agave/blob/566bb9565/streamer/src/sendmmsg.rs#L166-L175 Collecting an iterator into a vector before calling batch_send is unnecessary and only adds overhead. In particular, multi_target_send used in retransmitting shreds can be used without doing an additional vector allocation: https://github.com/anza-xyz/agave/blob/566bb9565/streamer/src/sendmmsg.rs#L219 --- core/src/banking_stage/forwarder.rs | 5 +- core/src/forwarding_stage.rs | 3 +- core/src/repair/repair_service.rs | 14 ++--- core/src/repair/serve_repair.rs | 9 ++- streamer/src/sendmmsg.rs | 56 ++++++++++++------- streamer/src/streamer.rs | 2 +- turbine/src/broadcast_stage.rs | 5 +- .../broadcast_duplicates_run.rs | 8 +-- udp-client/src/udp_client.rs | 16 ++---- 9 files changed, 62 insertions(+), 56 deletions(-) diff --git a/core/src/banking_stage/forwarder.rs b/core/src/banking_stage/forwarder.rs index 0527010c4b2a63..a2cea72b836bf3 100644 --- a/core/src/banking_stage/forwarder.rs +++ b/core/src/banking_stage/forwarder.rs @@ -23,7 +23,6 @@ use { solana_sdk::{pubkey::Pubkey, transport::TransportError}, solana_streamer::sendmmsg::batch_send, std::{ - iter::repeat, net::{SocketAddr, UdpSocket}, sync::{atomic::Ordering, Arc, RwLock}, }, @@ -281,8 +280,8 @@ impl Forwarder { match forward_option { ForwardOption::ForwardTpuVote => { // The vote must be forwarded using only UDP. 
- let pkts: Vec<_> = packet_vec.into_iter().zip(repeat(*addr)).collect(); - batch_send(&self.socket, &pkts).map_err(|err| err.into()) + let pkts = packet_vec.iter().map(|pkt| (pkt, addr)); + batch_send(&self.socket, pkts).map_err(TransportError::from) } ForwardOption::ForwardTransaction => { let conn = self.connection_cache.get_connection(addr); diff --git a/core/src/forwarding_stage.rs b/core/src/forwarding_stage.rs index 9075eefcdb3ccd..97ebd29d95f468 100644 --- a/core/src/forwarding_stage.rs +++ b/core/src/forwarding_stage.rs @@ -72,7 +72,8 @@ impl VoteClient { } fn send_batch(&self, batch: &mut Vec<(Vec, SocketAddr)>) { - let _res = batch_send(&self.udp_socket, batch); + let pkts = batch.iter().map(|(bytes, addr)| (bytes, addr)); + let _res = batch_send(&self.udp_socket, pkts); batch.clear(); } } diff --git a/core/src/repair/repair_service.rs b/core/src/repair/repair_service.rs index 37bdeacf9d59a6..2973186229519a 100644 --- a/core/src/repair/repair_service.rs +++ b/core/src/repair/repair_service.rs @@ -646,15 +646,13 @@ impl RepairService { let mut batch_send_repairs_elapsed = Measure::start("batch_send_repairs_elapsed"); if !batch.is_empty() { - match batch_send(repair_socket, &batch) { + let num_pkts = batch.len(); + let batch = batch.iter().map(|(bytes, addr)| (bytes, addr)); + match batch_send(repair_socket, batch) { Ok(()) => (), Err(SendPktsError::IoError(err, num_failed)) => { error!( - "{} batch_send failed to send {}/{} packets first error {:?}", - id, - num_failed, - batch.len(), - err + "{id} batch_send failed to send {num_failed}/{num_pkts} packets first error {err:?}" ); } } @@ -954,10 +952,10 @@ impl RepairService { ServeRepair::repair_proto_to_bytes(&request_proto, &identity_keypair).unwrap(); // Prepare packet batch to send - let reqs = [(packet_buf, address)]; + let reqs = [(&packet_buf, address)]; // Send packet batch - match batch_send(repair_socket, &reqs[..]) { + match batch_send(repair_socket, reqs) { Ok(()) => { debug!("successfully sent 
repair request to {pubkey} / {address}!"); } diff --git a/core/src/repair/serve_repair.rs b/core/src/repair/serve_repair.rs index 60b9c7a64ad14a..b37722d6db8364 100644 --- a/core/src/repair/serve_repair.rs +++ b/core/src/repair/serve_repair.rs @@ -1249,14 +1249,13 @@ impl ServeRepair { } } if !pending_pongs.is_empty() { - match batch_send(repair_socket, &pending_pongs) { + let num_pkts = pending_pongs.len(); + let pending_pongs = pending_pongs.iter().map(|(bytes, addr)| (bytes, addr)); + match batch_send(repair_socket, pending_pongs) { Ok(()) => (), Err(SendPktsError::IoError(err, num_failed)) => { warn!( - "batch_send failed to send {}/{} packets. First error: {:?}", - num_failed, - pending_pongs.len(), - err + "batch_send failed to send {num_failed}/{num_pkts} packets. First error: {err:?}" ); } } diff --git a/streamer/src/sendmmsg.rs b/streamer/src/sendmmsg.rs index 0afedcf628fd30..84a9d90ccbb707 100644 --- a/streamer/src/sendmmsg.rs +++ b/streamer/src/sendmmsg.rs @@ -16,7 +16,6 @@ use { std::{ borrow::Borrow, io, - iter::repeat, net::{SocketAddr, UdpSocket}, }, thiserror::Error, @@ -35,11 +34,15 @@ impl From for TransportError { } } +// The type and lifetime constraints are overspecified to match 'linux' code. 
#[cfg(not(target_os = "linux"))] -pub fn batch_send(sock: &UdpSocket, packets: &[(T, S)]) -> Result<(), SendPktsError> +pub fn batch_send<'a, S, T: 'a + ?Sized>( + sock: &UdpSocket, + packets: impl IntoIterator, +) -> Result<(), SendPktsError> where S: Borrow, - T: AsRef<[u8]>, + &'a T: AsRef<[u8]>, { let mut num_failed = 0; let mut erropt = None; @@ -158,12 +161,17 @@ fn sendmmsg_retry(sock: &UdpSocket, hdrs: &mut [mmsghdr]) -> Result<(), SendPkts const MAX_IOV: usize = libc::UIO_MAXIOV as usize; #[cfg(target_os = "linux")] -pub fn batch_send_max_iov(sock: &UdpSocket, packets: &[(T, S)]) -> Result<(), SendPktsError> +fn batch_send_max_iov<'a, S, T: 'a + ?Sized>( + sock: &UdpSocket, + packets: impl IntoIterator, +) -> Result<(), SendPktsError> where S: Borrow, - T: AsRef<[u8]>, + &'a T: AsRef<[u8]>, { - assert!(packets.len() <= MAX_IOV); + let packets = packets.into_iter(); + let num_packets = packets.len(); + debug_assert!(num_packets <= MAX_IOV); let mut iovs = [MaybeUninit::uninit(); MAX_IOV]; let mut addrs = [MaybeUninit::uninit(); MAX_IOV]; @@ -177,13 +185,13 @@ where // SAFETY: The first `packets.len()` elements of `hdrs`, `iovs`, and `addrs` are // guaranteed to be initialized by `mmsghdr_for_packet` before this loop. let hdrs_slice = - unsafe { std::slice::from_raw_parts_mut(hdrs.as_mut_ptr() as *mut mmsghdr, packets.len()) }; + unsafe { std::slice::from_raw_parts_mut(hdrs.as_mut_ptr() as *mut mmsghdr, num_packets) }; let result = sendmmsg_retry(sock, hdrs_slice); // SAFETY: The first `packets.len()` elements of `hdrs`, `iovs`, and `addrs` are // guaranteed to be initialized by `mmsghdr_for_packet` before this loop. 
- for (hdr, iov, addr) in izip!(&mut hdrs, &mut iovs, &mut addrs).take(packets.len()) { + for (hdr, iov, addr) in izip!(&mut hdrs, &mut iovs, &mut addrs).take(num_packets) { unsafe { hdr.assume_init_drop(); iov.assume_init_drop(); @@ -194,13 +202,23 @@ where result } +// Need &'a to ensure that raw packet pointers obtained in mmsghdr_for_packet +// stay valid. #[cfg(target_os = "linux")] -pub fn batch_send(sock: &UdpSocket, packets: &[(T, S)]) -> Result<(), SendPktsError> +pub fn batch_send<'a, S, T: 'a + ?Sized>( + sock: &UdpSocket, + packets: impl IntoIterator, +) -> Result<(), SendPktsError> where S: Borrow, - T: AsRef<[u8]>, + &'a T: AsRef<[u8]>, { - for chunk in packets.chunks(MAX_IOV) { + let mut packets = packets.into_iter(); + loop { + let chunk = packets.by_ref().take(MAX_IOV); + if chunk.len() == 0 { + break; + } batch_send_max_iov(sock, chunk)?; } Ok(()) @@ -216,8 +234,8 @@ where T: AsRef<[u8]>, { let dests = dests.iter().map(Borrow::borrow); - let pkts: Vec<_> = repeat(&packet).zip(dests).collect(); - batch_send(sock, &pkts) + let pkts = dests.map(|addr| (&packet, addr)); + batch_send(sock, pkts) } #[cfg(test)] @@ -246,7 +264,7 @@ mod tests { let packets: Vec<_> = (0..32).map(|_| vec![0u8; PACKET_DATA_SIZE]).collect(); let packet_refs: Vec<_> = packets.iter().map(|p| (&p[..], &addr)).collect(); - let sent = batch_send(&sender, &packet_refs[..]).ok(); + let sent = batch_send(&sender, packet_refs).ok(); assert_eq!(sent, Some(())); let mut packets = vec![Packet::default(); 32]; @@ -277,7 +295,7 @@ mod tests { }) .collect(); - let sent = batch_send(&sender, &packet_refs[..]).ok(); + let sent = batch_send(&sender, packet_refs).ok(); assert_eq!(sent, Some(())); let mut packets = vec![Packet::default(); 32]; @@ -345,7 +363,7 @@ mod tests { let dest_refs: Vec<_> = vec![&ip4, &ip6, &ip4]; let sender = bind_to_unspecified().expect("bind"); - let res = batch_send(&sender, &packet_refs[..]); + let res = batch_send(&sender, packet_refs); assert_matches!(res, 
Err(SendPktsError::IoError(_, /*num_failed*/ 1))); let res = multi_target_send(&sender, &packets[0], &dest_refs); assert_matches!(res, Err(SendPktsError::IoError(_, /*num_failed*/ 1))); @@ -366,7 +384,7 @@ mod tests { (&packets[3][..], &ipv4broadcast), (&packets[4][..], &ipv4local), ]; - match batch_send(&sender, &packet_refs[..]) { + match batch_send(&sender, packet_refs) { Ok(()) => panic!(), Err(SendPktsError::IoError(ioerror, num_failed)) => { assert_matches!(ioerror.kind(), ErrorKind::PermissionDenied); @@ -382,7 +400,7 @@ mod tests { (&packets[3][..], &ipv4local), (&packets[4][..], &ipv4broadcast), ]; - match batch_send(&sender, &packet_refs[..]) { + match batch_send(&sender, packet_refs) { Ok(()) => panic!(), Err(SendPktsError::IoError(ioerror, num_failed)) => { assert_matches!(ioerror.kind(), ErrorKind::PermissionDenied); @@ -398,7 +416,7 @@ mod tests { (&packets[3][..], &ipv4broadcast), (&packets[4][..], &ipv4local), ]; - match batch_send(&sender, &packet_refs[..]) { + match batch_send(&sender, packet_refs) { Ok(()) => panic!(), Err(SendPktsError::IoError(ioerror, num_failed)) => { assert_matches!(ioerror.kind(), ErrorKind::PermissionDenied); diff --git a/streamer/src/streamer.rs b/streamer/src/streamer.rs index b7ce3f6e2756f3..b8dc0eb1699b30 100644 --- a/streamer/src/streamer.rs +++ b/streamer/src/streamer.rs @@ -362,7 +362,7 @@ fn recv_send( let data = pkt.data(..)?; socket_addr_space.check(&addr).then_some((data, addr)) }); - batch_send(sock, &packets.collect::>())?; + batch_send(sock, packets.collect::>())?; Ok(()) } diff --git a/turbine/src/broadcast_stage.rs b/turbine/src/broadcast_stage.rs index 1ecdae4b1c15bf..20d763bfdcc2fc 100644 --- a/turbine/src/broadcast_stage.rs +++ b/turbine/src/broadcast_stage.rs @@ -477,8 +477,9 @@ pub fn broadcast_shreds( shred_select.stop(); transmit_stats.shred_select += shred_select.as_us(); + let num_udp_packets = packets.len(); let mut send_mmsg_time = Measure::start("send_mmsg"); - match batch_send(s, &packets[..]) 
{ + match batch_send(s, packets) { Ok(()) => (), Err(SendPktsError::IoError(ioerr, num_failed)) => { transmit_stats.dropped_packets_udp += num_failed; @@ -487,7 +488,7 @@ pub fn broadcast_shreds( } send_mmsg_time.stop(); transmit_stats.send_mmsg_elapsed += send_mmsg_time.as_us(); - transmit_stats.total_packets += packets.len() + quic_packets.len(); + transmit_stats.total_packets += num_udp_packets + quic_packets.len(); for (shred, addr) in quic_packets { let shred = Bytes::from(shred::Payload::unwrap_or_clone(shred.clone())); if let Err(err) = quic_endpoint_sender.blocking_send((addr, shred)) { diff --git a/turbine/src/broadcast_stage/broadcast_duplicates_run.rs b/turbine/src/broadcast_stage/broadcast_duplicates_run.rs index 6708bbf3f9055d..f5d5c0dce01604 100644 --- a/turbine/src/broadcast_stage/broadcast_duplicates_run.rs +++ b/turbine/src/broadcast_stage/broadcast_duplicates_run.rs @@ -392,13 +392,7 @@ impl BroadcastRun for BroadcastDuplicatesRun { .flatten() .collect(); - match batch_send(sock, &packets) { - Ok(()) => (), - Err(SendPktsError::IoError(ioerr, _)) => { - return Err(Error::Io(ioerr)); - } - } - Ok(()) + batch_send(sock, packets).map_err(|SendPktsError::IoError(err, _)| Error::Io(err)) } fn record(&mut self, receiver: &RecordReceiver, blockstore: &Blockstore) -> Result<()> { diff --git a/udp-client/src/udp_client.rs b/udp-client/src/udp_client.rs index 4a257a2d748092..39f6864d5c153d 100644 --- a/udp-client/src/udp_client.rs +++ b/udp-client/src/udp_client.rs @@ -2,7 +2,6 @@ //! 
an interface for sending data use { - core::iter::repeat, solana_connection_cache::client_connection::ClientConnection, solana_streamer::sendmmsg::batch_send, solana_transaction_error::TransportResult, @@ -37,18 +36,15 @@ impl ClientConnection for UdpClientConnection { } fn send_data_batch(&self, buffers: &[Vec]) -> TransportResult<()> { - let pkts: Vec<_> = buffers.iter().zip(repeat(self.server_addr())).collect(); - batch_send(&self.socket, &pkts)?; - Ok(()) + let addr = self.server_addr(); + let pkts = buffers.iter().map(|bytes| (bytes, addr)); + Ok(batch_send(&self.socket, pkts)?) } fn send_data_batch_async(&self, buffers: Vec>) -> TransportResult<()> { - let pkts: Vec<_> = buffers - .into_iter() - .zip(repeat(self.server_addr())) - .collect(); - batch_send(&self.socket, &pkts)?; - Ok(()) + let addr = self.server_addr(); + let pkts = buffers.iter().map(|bytes| (bytes, addr)); + Ok(batch_send(&self.socket, pkts)?) } fn send_data(&self, buffer: &[u8]) -> TransportResult<()> {