From d62171fb72e8d3136910ab861256da9418c088a3 Mon Sep 17 00:00:00 2001 From: Jake Shadle Date: Fri, 10 Nov 2023 10:33:47 +0100 Subject: [PATCH 1/7] Checkpoint --- Cargo.toml | 13 +- benches/read_write.rs | 142 +++++++++++++++++ benches/shared.rs | 327 ++++++++++++++++++++++++++++++++++++++ benches/throughput.rs | 264 ------------------------------ src/cli.rs | 6 +- src/cli/agent.rs | 2 +- src/cli/manage.rs | 2 +- src/cli/proxy.rs | 20 +-- src/cli/proxy/sessions.rs | 21 +-- src/cli/relay.rs | 2 +- src/lib.rs | 20 +++ src/net/maxmind_db.rs | 2 +- src/net/xds.rs | 2 +- src/test.rs | 17 +- 14 files changed, 536 insertions(+), 304 deletions(-) create mode 100644 benches/read_write.rs create mode 100644 benches/shared.rs delete mode 100644 benches/throughput.rs diff --git a/Cargo.toml b/Cargo.toml index 851f638a16..9c111a5e59 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -62,7 +62,7 @@ edition = "2021" exclude = ["docs", "build", "examples", "image"] [[bench]] -name = "throughput" +name = "read_write" harness = false test = true @@ -133,13 +133,16 @@ strum_macros = "0.25.2" sys-info = "0.9.1" [dev-dependencies] -regex = "1.9.6" -criterion = { version = "0.5.1", features = ["html_reports"] } +divan = "0.1.2" once_cell = "1.18.0" -tracing-test = "0.2.4" pretty_assertions = "1.4.0" -tempfile = "3.8.0" rand = "0.8.5" +regex = "1.9.6" +tracing-test = "0.2.4" +tempfile = "3.8.0" + +[target.'cfg(target_os = "linux")'.dev-dependencies] +libc = "0.2" [build-dependencies] tonic-build = { version = "0.10.2", default_features = false, features = [ diff --git a/benches/read_write.rs b/benches/read_write.rs new file mode 100644 index 0000000000..6b6e8245f8 --- /dev/null +++ b/benches/read_write.rs @@ -0,0 +1,142 @@ +mod shared; + +use divan::Bencher; +use shared::*; + +use std::thread::spawn; + +fn main() { + divan::main(); +} + +/// We use this to run each benchmark on the different packets, note the size +/// of the packet rather than than packet index is used to give better output +/// from divan +const SIZES: &[usize] = &[254, 508, 1500]; + +#[inline] +fn counter(psize: usize) -> impl divan::counter::Counter { + divan::counter::BytesCount::new(psize * NUMBER_OF_PACKETS) +} + +#[inline] +fn get_packet_from_size() -> &'static [u8] { + PACKETS + .iter() + .find(|p| p.len() == N) + .expect("failed to find appropriately sized packet") +} + +mod read { + use super::*; + + #[divan::bench(consts = SIZES)] + fn direct(b: Bencher) { + let (writer, reader) = socket_pair(None, None); + let (tx, rx) = channel(); + let packet = get_packet_from_size::(); + + let writer = Writer::new(writer, reader.local_addr().unwrap(), rx, packet); + + spawn(move || loop { + if !writer.write_all(NUMBER_OF_PACKETS) { + break; + } + }); + + b.counter(counter(N)).bench_local(|| { + read_to_end(&reader, &tx, NUMBER_OF_PACKETS, N); + }); + } + + #[divan::bench(consts = SIZES)] + fn quilkin(b: Bencher) { + let (writer, reader) = socket_pair(None, None); + let (tx, rx) = channel(); + let packet = get_packet_from_size::(); + + //quilkin::test::enable_log("quilkin=debug"); + + let _quilkin_loop = QuilkinLoop::spinup(READ_QUILKIN_PORT, reader.local_addr().unwrap()); + + let writer = Writer::new( + writer, + (Ipv4Addr::LOCALHOST, READ_QUILKIN_PORT).into(), + rx, + packet, + ); + + std::thread::sleep(std::time::Duration::from_millis(100)); + + spawn(move || loop { + if !writer.write_all(NUMBER_OF_PACKETS) { + break; + } + }); + + b.counter(counter(N)).bench_local(|| { + read_to_end(&reader, &tx, NUMBER_OF_PACKETS, N); + }); + } +} + +mod write { + use super::*; + + #[divan::bench(consts = SIZES)] + fn direct(b: Bencher) { + let (writer, reader) = socket_pair(None, None); + let (tx, rx) = channel(); + let packet = get_packet_from_size::(); + + let writer = Writer::new(writer, reader.local_addr().unwrap(), rx, packet); + + let (loop_tx, loop_rx) = mpsc::sync_channel(1); + + spawn(move || { + while let Ok((num, size)) = loop_rx.recv() { + read_to_end(&reader, &tx, num, size); + } + }); + + b.counter(counter(N)).bench_local(|| { + // Signal the read loop to run + loop_tx.send((NUMBER_OF_PACKETS, N)).unwrap(); + + writer.write_all(NUMBER_OF_PACKETS); + }); + } + + #[divan::bench(consts = SIZES)] + fn quilkin(b: Bencher) { + let (writer, reader) = socket_pair(None, None); + let (tx, rx) = channel(); + let packet = get_packet_from_size::(); + + let (loop_tx, loop_rx) = mpsc::sync_channel(1); + + let _quilkin_loop = QuilkinLoop::spinup(WRITE_QUILKIN_PORT, reader.local_addr().unwrap()); + + let writer = Writer::new( + writer, + (Ipv4Addr::LOCALHOST, WRITE_QUILKIN_PORT).into(), + rx, + packet, + ); + + std::thread::sleep(std::time::Duration::from_millis(100)); + + spawn(move || { + while let Ok((num, size)) = loop_rx.recv() { + read_to_end(&reader, &tx, num, size); + } + }); + + b.counter(counter(N)).bench_local(|| { + // Signal the read loop to run + loop_tx.send((NUMBER_OF_PACKETS, N)).unwrap(); + + writer.write_all(NUMBER_OF_PACKETS); + }); + } +} diff --git a/benches/shared.rs b/benches/shared.rs new file mode 100644 index 0000000000..47872748ca --- /dev/null +++ b/benches/shared.rs @@ -0,0 +1,327 @@ +pub use std::{ + net::{Ipv4Addr, SocketAddr, SocketAddrV4, UdpSocket}, + sync::{atomic, mpsc, Arc}, +}; + +pub const READ_QUILKIN_PORT: u16 = 9001; +pub const WRITE_QUILKIN_PORT: u16 = 9002; + +pub const MESSAGE_SIZE: usize = 0xffff; +pub const NUMBER_OF_PACKETS: usize = 10_000; + +pub const PACKETS: &[&[u8]] = &[ + // Half IPv4 MTU. + &[0xffu8; 254], + // IPv4 MTU. + &[0xffu8; 508], + // Ethernet MTU. + &[0xffu8; 1500], +]; + +pub fn make_socket(addr: SocketAddr) -> UdpSocket { + let socket = UdpSocket::bind(addr).expect("failed to bind"); + // socket + // .set_read_timeout(Some(std::time::Duration::from_millis(1))) + // .expect("failed to set read timeout"); + socket + .set_nonblocking(true) + .expect("failed to set non-blocking"); + socket +} + +#[derive(Debug)] +pub enum ReadLoopMsg { + #[allow(dead_code)] + Blocked(PacketStats), + Acked(PacketStats), + Finished(PacketStats), +} + +#[derive(Debug)] +pub struct PacketStats { + /// Number of individual receives that were completed + pub num_packets: usize, + /// Total number of bytes received + pub size_packets: usize, +} + +#[inline] +pub fn channel() -> (mpsc::Sender, mpsc::Receiver) { + mpsc::channel() +} + +#[inline] +pub fn socket_pair(write: Option, read: Option) -> (UdpSocket, UdpSocket) { + let w = make_socket((Ipv4Addr::LOCALHOST, write.unwrap_or_default()).into()); + let r = make_socket((Ipv4Addr::LOCALHOST, read.unwrap_or_default()).into()); + + (w, r) +} + +/// Writes never block even if the kernel's ring buffer is full, so we occasionally +/// ack chunks so the writer isn't waiting until the reader is blocked due to +/// ring buffer exhaustion in case +const CHUNK_SIZE: usize = 32 * 1024; +const ENABLE_GSO: bool = false; + +const fn batch_size(packet_size: usize) -> usize { + const MAX_GSO_SEGMENTS: usize = 64; + + let max_packets = CHUNK_SIZE / packet_size; + if !ENABLE_GSO { + return max_packets; + } + + // No min in const :( + if max_packets < MAX_GSO_SEGMENTS { + max_packets + } else { + MAX_GSO_SEGMENTS + } +} + +/// Runs a loop, reading from the socket until all the expected number of bytes (based on packet count and size) +/// have been successfully received. +/// +/// If the recv would block, a message is sent to request more bytes be sent to the socket, +/// we do this because while recv will fail if the timeout is surpassed and there is no +/// data to read, send (at least on linux) will never block on loopback even if there +/// not enough room in the ring buffer to hold the specified bytes +pub fn read_to_end( + socket: &UdpSocket, + tx: &mpsc::Sender, + packet_count: usize, + packet_size: usize, +) { + let mut packet = [0; MESSAGE_SIZE]; + + let mut num_packets = 0; + let mut size_packets = 0; + + let expected = packet_count * packet_size; + + let batch_size = batch_size(packet_size); + let mut batch_end = batch_size; + + while size_packets < expected { + let length = match socket.recv_from(&mut packet) { + Ok(t) => t.0, + Err(ref err) if matches!(err.kind(), std::io::ErrorKind::WouldBlock) => { + continue; + } + Err(err) => panic!("failed waiting for packet: {err}"), + }; + + num_packets += 1; + size_packets += length; + + if num_packets >= batch_end { + if tx + .send(ReadLoopMsg::Acked(PacketStats { + num_packets, + size_packets, + })) + .is_err() + { + return; + } + + batch_end += batch_size; + } + } + + let _ = tx.send(ReadLoopMsg::Finished(PacketStats { + num_packets, + size_packets, + })); +} + +pub struct Writer { + #[cfg(target_os = "linux")] + socket: socket2::Socket, + #[cfg(not(target_os = "linux"))] + socket: UdpSocket, + destination: SocketAddr, + rx: mpsc::Receiver, + batch_size: usize, + packet: &'static [u8], + #[cfg(unix)] + slices: Vec>, +} + +impl Writer { + pub fn new( + socket: UdpSocket, + destination: SocketAddr, + rx: mpsc::Receiver, + packet: &'static [u8], + ) -> Self { + let batch_size = batch_size(packet.len()); + + #[cfg(target_os = "linux")] + let (socket, slices) = { + let socket = socket2::Socket::from(socket); + + (socket, vec![std::io::IoSlice::new(packet); batch_size]) + }; + + Self { + socket, + destination, + rx, + batch_size, + packet, + #[cfg(target_os = "linux")] + slices, + } + } + + pub fn write_all(&self, packet_count: usize) -> bool { + use std::{mem, ptr}; + + // The value of the auxiliary data to put in the control message. + let segment_size = self.packet.len() as u16; + + #[cfg(target_os = "linux")] + let (dst, buf, layout) = { + // The number of bytes needed for this control message. + let cmsg_size = unsafe { libc::CMSG_SPACE(mem::size_of_val(&segment_size) as _) }; + let layout = std::alloc::Layout::from_size_align( + cmsg_size as usize, + mem::align_of::(), + ) + .unwrap(); + let buf = unsafe { std::alloc::alloc(layout) }; + + (socket2::SockAddr::from(self.destination), buf, layout) + }; + + let send_batch = |received: usize| { + let to_send = (packet_count - received).min(self.batch_size); + + // GSO, see https://github.com/flub/socket-use/blob/main/src/bin/sendmsg_gso.rs + #[cfg(target_os = "linux")] + { + if !ENABLE_GSO { + for _ in 0..to_send { + self.socket.send_to(self.packet, &dst).unwrap(); + } + return; + } + + let mut msg: libc::msghdr = unsafe { std::mem::zeroed() }; + + // Set the single destination and the payloads of each datagram + msg.msg_name = dst.as_ptr() as *mut _; + msg.msg_namelen = dst.len(); + msg.msg_iov = self.slices.as_ptr() as *mut _; + msg.msg_iovlen = to_send; + + msg.msg_control = buf as *mut _; + msg.msg_controllen = layout.size(); + + let cmsg: &mut libc::cmsghdr = unsafe { + let cmsg = libc::CMSG_FIRSTHDR(&msg); + let cmsg_zeroed: libc::cmsghdr = mem::zeroed(); + ptr::copy_nonoverlapping(&cmsg_zeroed, cmsg, 1); + cmsg.as_mut().unwrap() + }; + cmsg.cmsg_level = libc::SOL_UDP; + cmsg.cmsg_type = libc::UDP_SEGMENT; + cmsg.cmsg_len = + unsafe { libc::CMSG_LEN(mem::size_of_val(&segment_size) as _) } as libc::size_t; + unsafe { ptr::write(libc::CMSG_DATA(cmsg) as *mut u16, segment_size) }; + + use std::os::fd::AsRawFd; + if unsafe { libc::sendmsg(self.socket.as_raw_fd(), &msg, 0) } == -1 { + panic!("failed to send batch: {}", std::io::Error::last_os_error()); + } + } + + #[cfg(not(target_os = "linux"))] + { + for _ in 0..to_send { + self.socket.send_to(self.packet, self.destination).unwrap(); + } + } + }; + + // Queue 2 batches at the beginning, giving the reader enough work to do + // after the initial batch has been read + send_batch(0); + send_batch(self.batch_size); + + let mut finished = false; + while let Ok(msg) = self.rx.recv() { + match msg { + ReadLoopMsg::Blocked(ps) => { + panic!("reader blocked {ps:?}"); + } + ReadLoopMsg::Acked(ps) => { + send_batch(ps.num_packets); + } + ReadLoopMsg::Finished(ps) => { + assert_eq!(ps.size_packets, self.packet.len() * packet_count); + finished = true; + break; + } + } + } + + // Don't leak the buf + unsafe { std::alloc::dealloc(buf, layout) }; + finished + } +} + +pub struct QuilkinLoop { + shutdown: Option, + thread: Option>, +} + +impl QuilkinLoop { + /// Run and instance of quilkin that sends and receives data from the given address. + pub fn spinup(port: u16, endpoint: SocketAddr) -> Self { + let (shutdown_tx, shutdown_rx) = + quilkin::make_shutdown_channel(quilkin::ShutdownKind::Benching); + + let thread = std::thread::spawn(move || { + let runtime = tokio::runtime::Runtime::new().unwrap(); + let config = Arc::new(quilkin::Config::default()); + config.clusters.modify(|clusters| { + clusters + .insert_default([quilkin::net::endpoint::Endpoint::new(endpoint.into())].into()) + }); + + let proxy = quilkin::cli::Proxy { + port, + qcmp_port: runtime + .block_on(quilkin::test::available_addr( + &quilkin::test::AddressType::Random, + )) + .port(), + ..<_>::default() + }; + + runtime.block_on(async move { + let admin = quilkin::cli::Admin::Proxy(<_>::default()); + proxy.run(config, admin, shutdown_rx).await.unwrap(); + }); + }); + + Self { + shutdown: Some(shutdown_tx), + thread: Some(thread), + } + } +} + +impl Drop for QuilkinLoop { + fn drop(&mut self) { + let Some(stx) = self.shutdown.take() else { + return; + }; + stx.send(quilkin::ShutdownKind::Benching).unwrap(); + self.thread.take().unwrap().join().unwrap(); + } +} diff --git a/benches/throughput.rs b/benches/throughput.rs deleted file mode 100644 index 335338273e..0000000000 --- a/benches/throughput.rs +++ /dev/null @@ -1,264 +0,0 @@ -use std::net::{Ipv4Addr, SocketAddr, UdpSocket}; -use std::sync::{atomic, mpsc, Arc}; -use std::thread::sleep; -use std::time; - -use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion}; -use once_cell::sync::Lazy; -use quilkin::test::AddressType; - -const MESSAGE_SIZE: usize = 0xffff; -const DEFAULT_MESSAGE: [u8; 0xffff] = [0xff; 0xffff]; -const BENCH_LOOP_ADDR: &str = "127.0.0.1:8002"; -const FEEDBACK_LOOP_ADDR: &str = "127.0.0.1:8001"; -const QUILKIN_ADDR: &str = "127.0.0.1:8000"; -const NUMBER_OF_PACKETS: usize = 10_000; - -const PACKETS: &[&[u8]] = &[ - // Half IPv4 MTU. - &[0xffu8; 254], - // IPv4 MTU. - &[0xffu8; 508], - // Ethernet MTU. - &[0xffu8; 1500], -]; - -/// Run and instance of quilkin that sends and received data -/// from the given address. -fn run_quilkin(port: u16, endpoint: SocketAddr) { - std::thread::spawn(move || { - let runtime = tokio::runtime::Runtime::new().unwrap(); - let config = Arc::new(quilkin::Config::default()); - config.clusters.modify(|clusters| { - clusters.insert_default([quilkin::net::endpoint::Endpoint::new(endpoint.into())].into()) - }); - - let proxy = quilkin::cli::Proxy { - port, - qcmp_port: runtime - .block_on(quilkin::test::available_addr(&AddressType::Random)) - .port(), - ..<_>::default() - }; - - runtime.block_on(async move { - let (_shutdown_tx, shutdown_rx) = tokio::sync::watch::channel::<()>(()); - let admin = quilkin::cli::Admin::Proxy(<_>::default()); - proxy.run(config, admin, shutdown_rx).await.unwrap(); - }); - }); -} - -static THROUGHPUT_SERVER_INIT: Lazy<()> = Lazy::new(|| { - run_quilkin(8000, FEEDBACK_LOOP_ADDR.parse().unwrap()); -}); - -static FEEDBACK_LOOP: Lazy<()> = Lazy::new(|| { - std::thread::spawn(|| { - let socket = UdpSocket::bind(FEEDBACK_LOOP_ADDR).unwrap(); - socket - .set_read_timeout(Some(std::time::Duration::from_millis(500))) - .unwrap(); - - loop { - let mut packet = [0; MESSAGE_SIZE]; - let (_, addr) = socket.recv_from(&mut packet).unwrap(); - let length = packet.iter().position(|&x| x == 0).unwrap_or(packet.len()); - let packet = &packet[..length]; - assert_eq!(packet, &DEFAULT_MESSAGE[..length]); - socket.send_to(packet, addr).unwrap(); - } - }); -}); - -fn throughput_benchmark(c: &mut Criterion) { - Lazy::force(&FEEDBACK_LOOP); - Lazy::force(&THROUGHPUT_SERVER_INIT); - // Sleep to give the servers some time to warm-up. - std::thread::sleep(std::time::Duration::from_millis(500)); - let socket = UdpSocket::bind(BENCH_LOOP_ADDR).unwrap(); - socket - .set_read_timeout(Some(std::time::Duration::from_millis(500))) - .unwrap(); - let mut packet = [0; MESSAGE_SIZE]; - - let mut group = c.benchmark_group("throughput"); - for message in PACKETS { - group.sample_size(NUMBER_OF_PACKETS); - group.sampling_mode(criterion::SamplingMode::Flat); - group.throughput(criterion::Throughput::Bytes(message.len() as u64)); - group.bench_with_input( - BenchmarkId::new("direct", format!("{} bytes", message.len())), - &message, - |b, message| { - b.iter(|| { - socket.send_to(message, FEEDBACK_LOOP_ADDR).unwrap(); - socket.recv_from(&mut packet).unwrap(); - }) - }, - ); - group.bench_with_input( - BenchmarkId::new("quilkin", format!("{} bytes", message.len())), - &message, - |b, message| { - b.iter(|| { - socket.send_to(message, QUILKIN_ADDR).unwrap(); - socket.recv_from(&mut packet).unwrap(); - }) - }, - ); - } - group.finish(); -} - -const WRITE_LOOP_ADDR: &str = "127.0.0.1:8003"; -const READ_LOOP_ADDR: &str = "127.0.0.1:8004"; - -const READ_QUILKIN_PORT: u16 = 9001; -static READ_SERVER_INIT: Lazy<()> = Lazy::new(|| { - run_quilkin(READ_QUILKIN_PORT, READ_LOOP_ADDR.parse().unwrap()); -}); - -const WRITE_QUILKIN_PORT: u16 = 9002; -static WRITE_SERVER_INIT: Lazy<()> = Lazy::new(|| { - run_quilkin(WRITE_QUILKIN_PORT, WRITE_LOOP_ADDR.parse().unwrap()); -}); - -/// Binds a socket to `addr`, and waits for an initial packet to be sent to it to establish -/// a connection. After which any `Vec` sent to the returned channel will result in that -/// data being send via that connection - thereby skipping the proxy `read` operation. -fn write_feedback(addr: SocketAddr) -> mpsc::Sender> { - let (write_tx, write_rx) = mpsc::channel::>(); - std::thread::spawn(move || { - let socket = UdpSocket::bind(addr).unwrap(); - socket - .set_read_timeout(Some(std::time::Duration::from_millis(500))) - .unwrap(); - let mut packet = [0; MESSAGE_SIZE]; - let (_, source) = socket.recv_from(&mut packet).unwrap(); - while let Ok(packet) = write_rx.recv() { - socket.send_to(packet.as_slice(), source).unwrap(); - } - }); - write_tx -} - -fn readwrite_benchmark(c: &mut Criterion) { - Lazy::force(&READ_SERVER_INIT); - - // start a feedback server for read operations, that sends a response through a channel, - // thereby skipping a proxy connection on the return. - let (read_tx, read_rx) = mpsc::channel::>(); - std::thread::spawn(move || { - let socket = UdpSocket::bind(READ_LOOP_ADDR).unwrap(); - socket - .set_read_timeout(Some(std::time::Duration::from_millis(500))) - .unwrap(); - let mut packet = [0; MESSAGE_SIZE]; - loop { - let (length, _) = socket.recv_from(&mut packet).unwrap(); - let packet = &packet[..length]; - assert_eq!(packet, &DEFAULT_MESSAGE[..length]); - - if read_tx.send(packet.to_vec()).is_err() { - return; - } - } - }); - - // start a feedback server for a direct write benchmark. - let direct_write_addr = (Ipv4Addr::LOCALHOST, 9004).into(); - let direct_write_tx = write_feedback(direct_write_addr); - - // start a feedback server for a quilkin write benchmark. - let quilkin_write_addr = (Ipv4Addr::LOCALHOST, WRITE_QUILKIN_PORT); - let quilkin_write_tx = write_feedback(WRITE_LOOP_ADDR.parse().unwrap()); - Lazy::force(&WRITE_SERVER_INIT); - - // Sleep to give the servers some time to warm-up. - std::thread::sleep(std::time::Duration::from_millis(150)); - - let socket = UdpSocket::bind((Ipv4Addr::LOCALHOST, 0)).unwrap(); - socket - .set_read_timeout(Some(std::time::Duration::from_millis(500))) - .unwrap(); - - // prime the direct write connection - socket.send_to(PACKETS[0], direct_write_addr).unwrap(); - - // we need to send packets at least once a minute, otherwise the endpoint session expires. - // So setting up a ping packet for the write test. - // TODO(markmandel): If we ever make session timeout configurable, we can remove this. - let ping_socket = socket.try_clone().unwrap(); - let stop = Arc::new(atomic::AtomicBool::default()); - let ping_stop = stop.clone(); - std::thread::spawn(move || { - while !ping_stop.load(atomic::Ordering::Relaxed) { - ping_socket.send_to(PACKETS[0], quilkin_write_addr).unwrap(); - sleep(time::Duration::from_secs(30)); - } - }); - - let mut group = c.benchmark_group("readwrite"); - - for message in PACKETS { - group.sample_size(NUMBER_OF_PACKETS); - group.sampling_mode(criterion::SamplingMode::Flat); - group.throughput(criterion::Throughput::Bytes(message.len() as u64)); - - // direct read - group.bench_with_input( - BenchmarkId::new("direct-read", format!("{} bytes", message.len())), - &message, - |b, message| { - b.iter(|| { - socket.send_to(message, READ_LOOP_ADDR).unwrap(); - read_rx.recv().unwrap(); - }) - }, - ); - // quilkin read - let addr = (Ipv4Addr::LOCALHOST, READ_QUILKIN_PORT); - group.bench_with_input( - BenchmarkId::new("quilkin-read", format!("{} bytes", message.len())), - &message, - |b, message| { - b.iter(|| { - socket.send_to(message, addr).unwrap(); - read_rx.recv().unwrap(); - }) - }, - ); - - // direct write - let mut packet = [0; MESSAGE_SIZE]; - group.bench_with_input( - BenchmarkId::new("direct-write", format!("{} bytes", message.len())), - &message, - |b, message| { - b.iter(|| { - direct_write_tx.send(message.to_vec()).unwrap(); - socket.recv(&mut packet).unwrap(); - }) - }, - ); - - // quilkin write - let mut packet = [0; MESSAGE_SIZE]; - group.bench_with_input( - BenchmarkId::new("quilkin-write", format!("{} bytes", message.len())), - &message, - |b, message| { - b.iter(|| { - quilkin_write_tx.send(message.to_vec()).unwrap(); - socket.recv(&mut packet).unwrap(); - }) - }, - ); - } - - stop.store(true, atomic::Ordering::Relaxed); -} - -criterion_group!(benches, readwrite_benchmark, throughput_benchmark); -criterion_main!(benches); diff --git a/src/cli.rs b/src/cli.rs index 0af904e2ad..0b789c3398 100644 --- a/src/cli.rs +++ b/src/cli.rs @@ -23,7 +23,7 @@ use std::{ use clap::builder::TypedValueParser; use clap::crate_version; -use tokio::{signal, sync::watch}; +use tokio::signal; use crate::Config; use strum_macros::{Display, EnumString}; @@ -181,7 +181,7 @@ impl Cli { mode.server(config.clone(), self.admin_address); } - let (shutdown_tx, shutdown_rx) = watch::channel::<()>(()); + let (shutdown_tx, shutdown_rx) = crate::make_shutdown_channel(Default::default()); #[cfg(target_os = "linux")] let mut sig_term_fut = signal::unix::signal(signal::unix::SignalKind::terminate())?; @@ -200,7 +200,7 @@ impl Cli { tracing::info!(%signal, "shutting down from signal"); // Don't unwrap in order to ensure that we execute // any subsequent shutdown tasks. - shutdown_tx.send(()).ok(); + shutdown_tx.send(crate::ShutdownKind::Normal).ok(); }); match self.command { diff --git a/src/cli/agent.rs b/src/cli/agent.rs index e714790f48..9bfb85e436 100644 --- a/src/cli/agent.rs +++ b/src/cli/agent.rs @@ -76,7 +76,7 @@ impl Agent { &self, config: Arc, mode: Admin, - mut shutdown_rx: tokio::sync::watch::Receiver<()>, + mut shutdown_rx: crate::ShutdownRx, ) -> crate::Result<()> { let locality = (self.region.is_some() || self.zone.is_some() || self.sub_zone.is_some()) .then(|| crate::net::endpoint::Locality { diff --git a/src/cli/manage.rs b/src/cli/manage.rs index e17fa9b110..2740091e39 100644 --- a/src/cli/manage.rs +++ b/src/cli/manage.rs @@ -57,7 +57,7 @@ impl Manage { &self, config: std::sync::Arc, mode: Admin, - mut shutdown_rx: tokio::sync::watch::Receiver<()>, + mut shutdown_rx: crate::ShutdownRx, ) -> crate::Result<()> { let locality = (self.region.is_some() || self.zone.is_some() || self.sub_zone.is_some()) .then(|| crate::net::endpoint::Locality { diff --git a/src/cli/proxy.rs b/src/cli/proxy.rs index 66c8397bcd..11ad5e88c7 100644 --- a/src/cli/proxy.rs +++ b/src/cli/proxy.rs @@ -37,7 +37,7 @@ use crate::filters::FilterFactory; use crate::{ filters::{Filter, ReadContext}, net::{xds::ResourceType, DualStackLocalSocket}, - Config, Result, + Config, Result, ShutdownRx, }; define_port!(7777); @@ -87,7 +87,7 @@ impl Proxy { &self, config: std::sync::Arc, mode: Admin, - mut shutdown_rx: tokio::sync::watch::Receiver<()>, + mut shutdown_rx: ShutdownRx, ) -> crate::Result<()> { let _mmdb_task = self.mmdb.clone().map(|source| { tokio::spawn(async move { @@ -167,12 +167,14 @@ impl Proxy { .await .map_err(|error| eyre::eyre!(error))?; - tracing::info!(sessions=%sessions.sessions().len(), "waiting for active sessions to expire"); - while sessions.sessions().is_not_empty() { - tokio::time::sleep(Duration::from_secs(1)).await; - tracing::debug!(sessions=%sessions.sessions().len(), "sessions still active"); + if *shutdown_rx.borrow() == crate::ShutdownKind::Normal { + tracing::info!(sessions=%sessions.sessions().len(), "waiting for active sessions to expire"); + while sessions.sessions().is_not_empty() { + tokio::time::sleep(Duration::from_secs(1)).await; + tracing::debug!(sessions=%sessions.sessions().len(), "sessions still active"); + } + tracing::info!("all sessions expired"); } - tracing::info!("all sessions expired"); Ok(()) } @@ -605,7 +607,7 @@ mod tests { ) .unwrap(), ), - tokio::sync::watch::channel(()).1, + crate::make_shutdown_channel(crate::ShutdownKind::Testing).1, ), } .spawn(); @@ -655,7 +657,7 @@ mod tests { let sessions = SessionPool::new( config.clone(), shared_socket.clone(), - tokio::sync::watch::channel(()).1, + crate::make_shutdown_channel(crate::ShutdownKind::Testing).1, ); proxy diff --git a/src/cli/proxy/sessions.rs b/src/cli/proxy/sessions.rs index 3c84b89926..1095a7b09e 100644 --- a/src/cli/proxy/sessions.rs +++ b/src/cli/proxy/sessions.rs @@ -21,14 +21,11 @@ use std::{ time::Duration, }; -use tokio::{ - sync::{watch, RwLock}, - time::Instant, -}; +use tokio::{sync::RwLock, time::Instant}; use crate::{ config::Config, filters::Filter, net::maxmind_db::IpNetEntry, net::DualStackLocalSocket, - Loggable, + Loggable, ShutdownRx, }; pub(crate) mod metrics; @@ -48,7 +45,7 @@ pub struct SessionPool { storage: Arc>, session_map: SessionMap, downstream_socket: Arc, - shutdown_rx: watch::Receiver<()>, + shutdown_rx: ShutdownRx, config: Arc, } @@ -68,7 +65,7 @@ impl SessionPool { pub fn new( config: Arc, downstream_socket: Arc, - shutdown_rx: watch::Receiver<()>, + shutdown_rx: ShutdownRx, ) -> Arc { const SESSION_TIMEOUT_SECONDS: Duration = Duration::from_secs(60); const SESSION_EXPIRY_POLL_INTERVAL: Duration = Duration::from_secs(60); @@ -499,11 +496,15 @@ impl Loggable for Error { #[cfg(test)] mod tests { use super::*; - use crate::test::{available_addr, AddressType, TestHelper}; + use crate::{ + make_shutdown_channel, + test::{available_addr, AddressType, TestHelper}, + ShutdownKind, ShutdownTx, + }; use std::sync::Arc; - async fn new_pool(config: impl Into>) -> (Arc, watch::Sender<()>) { - let (tx, rx) = watch::channel(()); + async fn new_pool(config: impl Into>) -> (Arc, ShutdownTx) { + let (tx, rx) = make_shutdown_channel(ShutdownKind::Testing); ( SessionPool::new( Arc::new(config.into().unwrap_or_default()), diff --git a/src/cli/relay.rs b/src/cli/relay.rs index aa3189367c..08b1e1e96a 100644 --- a/src/cli/relay.rs +++ b/src/cli/relay.rs @@ -61,7 +61,7 @@ impl Relay { &self, config: Arc, mode: crate::cli::Admin, - mut shutdown_rx: tokio::sync::watch::Receiver<()>, + mut shutdown_rx: crate::ShutdownRx, ) -> crate::Result<()> { let xds_server = crate::net::xds::server::spawn(self.xds_port, config.clone()); let mds_server = tokio::spawn(crate::net::xds::server::control_plane_discovery_server( diff --git a/src/lib.rs b/src/lib.rs index c2d7ea5815..b9b0d1ba46 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -40,6 +40,26 @@ pub use quilkin_macros::include_proto; pub(crate) use self::net::maxmind_db::MaxmindDb; +#[derive(Copy, Clone, PartialEq, Default, Debug)] +pub enum ShutdownKind { + /// Normal shutdown kind, the receiver should perform proper shutdown procedures + #[default] + Normal, + /// In a testing environment, some or all shutdown behavior may be skippable + Testing, + /// In a benching environment, some or all shutdown behavior may be skippable + Benching, +} + +/// Receiver for a shutdown event. +pub type ShutdownRx = tokio::sync::watch::Receiver; +pub type ShutdownTx = tokio::sync::watch::Sender; + +#[inline] +pub fn make_shutdown_channel(init: ShutdownKind) -> (ShutdownTx, ShutdownRx) { + tokio::sync::watch::channel(init) +} + /// A type which can be logged, usually error types. pub(crate) trait Loggable { /// Output a log. diff --git a/src/net/maxmind_db.rs b/src/net/maxmind_db.rs index 22f4d41fea..c3250b7912 100644 --- a/src/net/maxmind_db.rs +++ b/src/net/maxmind_db.rs @@ -62,7 +62,7 @@ impl MaxmindDb { let mmdb = match crate::MaxmindDb::instance().clone() { Some(mmdb) => mmdb, None => { - tracing::debug!("skipping mmdb telemetry, no maxmind database available"); + tracing::trace!("skipping mmdb telemetry, no maxmind database available"); return None; } }; diff --git a/src/net/xds.rs b/src/net/xds.rs index f82dbdf3eb..793b3612fa 100644 --- a/src/net/xds.rs +++ b/src/net/xds.rs @@ -190,7 +190,7 @@ mod tests { // Test that the client can handle the manager dropping out. let handle = tokio::spawn(server::spawn(xds_port, xds_config.clone())); - let (_shutdown_tx, shutdown_rx) = tokio::sync::watch::channel(()); + let (_shutdown_tx, shutdown_rx) = tokio::sync::watch::channel(crate::ShutdownKind::Testing); tokio::spawn(server::spawn(xds_port, xds_config.clone())); let client_proxy = crate::cli::Proxy { port: client_addr.port(), diff --git a/src/test.rs b/src/test.rs index a15eb66311..e9bfa5da84 100644 --- a/src/test.rs +++ b/src/test.rs @@ -18,7 +18,7 @@ use std::net::Ipv4Addr; /// Common utilities for testing use std::{net::SocketAddr, str::from_utf8, sync::Arc, sync::Once}; -use tokio::sync::{mpsc, oneshot, watch}; +use tokio::sync::{mpsc, oneshot}; use tracing_subscriber::EnvFilter; use crate::{ @@ -27,6 +27,7 @@ use crate::{ net::endpoint::metadata::Value, net::endpoint::{Endpoint, EndpointAddress}, net::DualStackLocalSocket, + ShutdownKind, ShutdownRx, ShutdownTx, }; static LOG_ONCE: Once = Once::new(); @@ -118,8 +119,8 @@ impl StaticFilter for TestFilter { #[derive(Default)] pub struct TestHelper { /// Channel to subscribe to, and trigger the shutdown of created resources. - shutdown_ch: Option<(watch::Sender<()>, watch::Receiver<()>)>, - server_shutdown_tx: Vec>>, + shutdown_ch: Option<(ShutdownTx, ShutdownRx)>, + server_shutdown_tx: Vec>, } /// Returned from [creating a socket](TestHelper::open_socket_and_recv_single_packet) @@ -134,7 +135,7 @@ impl Drop for TestHelper { fn drop(&mut self) { for shutdown_tx in self.server_shutdown_tx.iter_mut().flat_map(|tx| tx.take()) { shutdown_tx - .send(()) + .send(ShutdownKind::Testing) .map_err(|error| { tracing::warn!( %error, @@ -145,7 +146,7 @@ impl Drop for TestHelper { } if let Some((shutdown_tx, _)) = self.shutdown_ch.take() { - shutdown_tx.send(()).unwrap(); + shutdown_tx.send(ShutdownKind::Testing).unwrap(); } } } @@ -273,7 +274,7 @@ impl TestHelper { server: crate::cli::Proxy, with_admin: Option>, ) { - let (shutdown_tx, shutdown_rx) = watch::channel::<()>(()); + let (shutdown_tx, shutdown_rx) = crate::make_shutdown_channel(crate::ShutdownKind::Testing); self.server_shutdown_tx.push(Some(shutdown_tx)); let mode = crate::cli::Admin::Proxy(<_>::default()); @@ -287,12 +288,12 @@ impl TestHelper { } /// Returns a receiver subscribed to the helper's shutdown event. - async fn get_shutdown_subscriber(&mut self) -> watch::Receiver<()> { + async fn get_shutdown_subscriber(&mut self) -> ShutdownRx { // If this is the first call, then we set up the channel first. match self.shutdown_ch { Some((_, ref rx)) => rx.clone(), None => { - let ch = watch::channel(()); + let ch = crate::make_shutdown_channel(crate::ShutdownKind::Testing); let recv = ch.1.clone(); self.shutdown_ch = Some(ch); recv From eae4f7a574ce344264fe932d055bfa1a9ed5f2d3 Mon Sep 17 00:00:00 2001 From: Jake Shadle Date: Mon, 13 Nov 2023 17:18:59 +0100 Subject: [PATCH 2/7] Fix test --- tests/qcmp.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/qcmp.rs b/tests/qcmp.rs index 596a0e7ab8..59caf90ac5 100644 --- a/tests/qcmp.rs +++ b/tests/qcmp.rs @@ -49,7 +49,7 @@ async fn agent_ping() { ..<_>::default() }; let server_config = std::sync::Arc::new(quilkin::Config::default()); - let (_tx, rx) = tokio::sync::watch::channel(()); + let (_tx, rx) = quilkin::make_shutdown_channel(quilkin::ShutdownKind::Testing); let admin = quilkin::cli::Admin::Agent(<_>::default()); tokio::spawn(async move { agent From a85a80ecdea14df9f73cf1b89c189537f0eb24c1 Mon Sep 17 00:00:00 2001 From: Jake Shadle Date: Mon, 13 Nov 2023 17:19:27 +0100 Subject: [PATCH 3/7] Add debug info in release Helps debugging benchmarks --- Cargo.toml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 9c111a5e59..320664ff0d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -141,9 +141,6 @@ regex = "1.9.6" tracing-test = "0.2.4" tempfile = "3.8.0" -[target.'cfg(target_os = "linux")'.dev-dependencies] -libc = "0.2" - [build-dependencies] tonic-build = { version = "0.10.2", default_features = false, features = [ "transport", @@ -156,3 +153,6 @@ protobuf-src = { version = "1.1.0", optional = true } default = ["vendor-protoc"] instrument = [] vendor-protoc = ["dep:protobuf-src"] + +[profile.release] +debug = true From 9c3ff6932357c6e0931169e3286fb1f6ce6a67a1 Mon Sep 17 00:00:00 2001 From: Jake Shadle Date: Mon, 13 Nov 2023 17:19:52 +0100 Subject: [PATCH 4/7] Move spammy message to trace --- src/cli/proxy.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/cli/proxy.rs b/src/cli/proxy.rs index 11ad5e88c7..12e7bbefff 100644 --- a/src/cli/proxy.rs +++ b/src/cli/proxy.rs @@ -304,7 +304,7 @@ impl DownstreamReceiveWorkerConfig { let mut buf = vec![0; 1 << 16]; let mut last_received_at = None; loop { - tracing::debug!( + tracing::trace!( id = worker_id, port = ?socket.local_ipv6_addr().map(|addr| addr.port()), "Awaiting packet" From a12e139dca27c2e2e7fff09a1d5d5fc9c4dffef9 Mon Sep 17 00:00:00 2001 From: Jake Shadle Date: Mon, 13 Nov 2023 17:20:02 +0100 Subject: [PATCH 5/7] Cleanup --- benches/read_write.rs | 107 +++++-------- benches/shared.rs | 345 +++++++++++++++++++++++++----------------- 2 files changed, 249 insertions(+), 203 deletions(-) diff --git a/benches/read_write.rs b/benches/read_write.rs index 6b6e8245f8..b6c6c62536 100644 --- a/benches/read_write.rs +++ b/benches/read_write.rs @@ -3,99 +3,73 @@ mod shared; use divan::Bencher; use shared::*; -use std::thread::spawn; - fn main() { divan::main(); } -/// We use this to run each benchmark on the different packets, note the size -/// of the packet rather than than packet index is used to give better output -/// from divan -const SIZES: &[usize] = &[254, 508, 1500]; - #[inline] fn counter(psize: usize) -> impl divan::counter::Counter { - divan::counter::BytesCount::new(psize * NUMBER_OF_PACKETS) -} - -#[inline] -fn get_packet_from_size() -> &'static [u8] { - PACKETS - .iter() - .find(|p| p.len() == N) - .expect("failed to find appropriately sized packet") + divan::counter::BytesCount::new(psize * NUMBER_OF_PACKETS as usize) } +#[divan::bench_group(sample_count = 10)] mod read { use super::*; - #[divan::bench(consts = SIZES)] + #[divan::bench(consts = PACKET_SIZES)] fn direct(b: Bencher) { let (writer, reader) = socket_pair(None, None); let (tx, rx) = channel(); - let packet = get_packet_from_size::(); - - let writer = Writer::new(writer, reader.local_addr().unwrap(), rx, packet); + let writer = Writer::::new(writer, reader.local_addr().unwrap(), rx); - spawn(move || loop { + spawn(format!("direct_writer_{N}"), move || loop { if !writer.write_all(NUMBER_OF_PACKETS) { break; } }); b.counter(counter(N)).bench_local(|| { - read_to_end(&reader, &tx, NUMBER_OF_PACKETS, N); + read_to_end::(&reader, &tx, NUMBER_OF_PACKETS); }); } - #[divan::bench(consts = SIZES)] + #[divan::bench(consts = PACKET_SIZES)] fn quilkin(b: Bencher) { let (writer, reader) = socket_pair(None, None); let (tx, rx) = channel(); - let packet = get_packet_from_size::(); - //quilkin::test::enable_log("quilkin=debug"); - - let _quilkin_loop = QuilkinLoop::spinup(READ_QUILKIN_PORT, reader.local_addr().unwrap()); - - let writer = Writer::new( - writer, - (Ipv4Addr::LOCALHOST, READ_QUILKIN_PORT).into(), - rx, - packet, - ); + let quilkin_loop = QuilkinLoop::spinup(READ_QUILKIN_PORT, reader.local_addr().unwrap()); + let writer = Writer::::new(writer, (Ipv4Addr::LOCALHOST, READ_QUILKIN_PORT).into(), rx); + let _quilkin_loop = writer.wait_ready(quilkin_loop, &reader); - std::thread::sleep(std::time::Duration::from_millis(100)); - - spawn(move || loop { + spawn(format!("quilkin_writer_{N}"), move || loop { if !writer.write_all(NUMBER_OF_PACKETS) { break; } }); b.counter(counter(N)).bench_local(|| { - read_to_end(&reader, &tx, NUMBER_OF_PACKETS, N); + read_to_end::(&reader, &tx, NUMBER_OF_PACKETS); }); } } +#[divan::bench_group(sample_count = 10)] mod write { use super::*; - #[divan::bench(consts = SIZES)] + #[divan::bench(consts = PACKET_SIZES)] fn direct(b: Bencher) { let (writer, reader) = socket_pair(None, None); let (tx, rx) = channel(); - let packet = get_packet_from_size::(); - let writer = Writer::new(writer, reader.local_addr().unwrap(), rx, packet); + let writer = Writer::::new(writer, reader.local_addr().unwrap(), rx); let (loop_tx, loop_rx) = mpsc::sync_channel(1); - spawn(move || { - while let Ok((num, size)) = loop_rx.recv() { - read_to_end(&reader, &tx, num, size); + spawn(format!("direct_reader_{N}"), move || { + while let Ok((num, _size)) = loop_rx.recv() { + read_to_end::(&reader, &tx, num); } }); @@ -107,36 +81,39 @@ mod write { }); } - #[divan::bench(consts = SIZES)] + #[divan::bench(consts = PACKET_SIZES)] fn quilkin(b: Bencher) { let (writer, reader) = socket_pair(None, None); let (tx, rx) = channel(); - let packet = get_packet_from_size::(); - let (loop_tx, loop_rx) = mpsc::sync_channel(1); + //quilkin::test::enable_log("quilkin=debug"); - let _quilkin_loop = QuilkinLoop::spinup(WRITE_QUILKIN_PORT, reader.local_addr().unwrap()); + let quilkin_loop = QuilkinLoop::spinup(WRITE_QUILKIN_PORT, reader.local_addr().unwrap()); + let writer = Writer::::new(writer, (Ipv4Addr::LOCALHOST, WRITE_QUILKIN_PORT).into(), rx); + let _quilkin_loop = writer.wait_ready(quilkin_loop, &reader); - let writer = Writer::new( - writer, - (Ipv4Addr::LOCALHOST, WRITE_QUILKIN_PORT).into(), - rx, - packet, - ); + let thread = { + let (loop_tx, loop_rx) = mpsc::sync_channel(1); - std::thread::sleep(std::time::Duration::from_millis(100)); + let thread = spawn(format!("quilkin_reader_{}", N), move || { + while let Ok((num, _size)) = loop_rx.recv() { + read_to_end::(&reader, &tx, num); + } + }); - spawn(move || { - while let Ok((num, size)) = loop_rx.recv() { - read_to_end(&reader, &tx, num, size); - } - }); + let mut wtf = 0; - b.counter(counter(N)).bench_local(|| { - // Signal the read loop to run - loop_tx.send((NUMBER_OF_PACKETS, N)).unwrap(); + b.counter(counter(N)).bench_local(|| { + // Signal the read loop to run + loop_tx.send((NUMBER_OF_PACKETS, N)).unwrap(); + wtf += 1; - writer.write_all(NUMBER_OF_PACKETS); - }); + writer.write_all(NUMBER_OF_PACKETS); + }); + + thread + }; + + thread.join().unwrap(); } } diff --git a/benches/shared.rs b/benches/shared.rs index 47872748ca..841fc0e0ba 100644 --- a/benches/shared.rs +++ b/benches/shared.rs @@ -6,16 +6,12 @@ pub use std::{ pub const READ_QUILKIN_PORT: u16 = 9001; pub const WRITE_QUILKIN_PORT: u16 = 9002; -pub const MESSAGE_SIZE: usize = 0xffff; -pub const NUMBER_OF_PACKETS: usize = 10_000; - -pub const PACKETS: &[&[u8]] = &[ - // Half IPv4 MTU. - &[0xffu8; 254], - // IPv4 MTU. - &[0xffu8; 508], - // Ethernet MTU. - &[0xffu8; 1500], +pub const NUMBER_OF_PACKETS: u16 = 10_000; + +pub const PACKET_SIZES: &[usize] = &[ + 254, // Half IPv4 MTU. + 508, // IPv4 MTU. + 1500, // Ethernet MTU. ]; pub fn make_socket(addr: SocketAddr) -> UdpSocket { @@ -29,6 +25,17 @@ pub fn make_socket(addr: SocketAddr) -> UdpSocket { socket } +#[inline] +pub fn spawn(name: impl Into, func: F) -> std::thread::JoinHandle<()> +where + F: FnOnce() + Send + 'static, +{ + std::thread::Builder::new() + .name(name.into()) + .spawn(func) + .unwrap() +} + #[derive(Debug)] pub enum ReadLoopMsg { #[allow(dead_code)] @@ -40,7 +47,7 @@ pub enum ReadLoopMsg { #[derive(Debug)] pub struct PacketStats { /// Number of individual receives that were completed - pub num_packets: usize, + pub num_packets: u16, /// Total number of bytes received pub size_packets: usize, } @@ -61,23 +68,10 @@ pub fn socket_pair(write: Option, read: Option) -> (UdpSocket, UdpSock /// Writes never block even if the kernel's ring buffer is full, so we occasionally /// ack chunks so the writer isn't waiting until the reader is blocked due to /// ring buffer exhaustion in case -const CHUNK_SIZE: usize = 32 * 1024; -const ENABLE_GSO: bool = false; - -const fn batch_size(packet_size: usize) -> usize { - const MAX_GSO_SEGMENTS: usize = 64; +const CHUNK_SIZE: usize = 8 * 1024; - let max_packets = CHUNK_SIZE / packet_size; - if !ENABLE_GSO { - return max_packets; - } - - // No min in const :( - if max_packets < MAX_GSO_SEGMENTS { - max_packets - } else { - MAX_GSO_SEGMENTS - } +const fn batch_size(packet_size: usize) -> u16 { + (CHUNK_SIZE / packet_size) as u16 } /// Runs a loop, reading from the socket until all the expected number of bytes (based on packet count and size) @@ -87,21 +81,74 @@ const fn batch_size(packet_size: usize) -> usize { /// we do this because while recv will fail if the timeout is surpassed and there is no /// data to read, send (at least on linux) will never block on loopback even if there /// not enough room in the ring buffer to hold the specified bytes -pub fn read_to_end( +pub fn read_to_end( socket: &UdpSocket, tx: &mpsc::Sender, - packet_count: usize, - packet_size: usize, + packet_count: u16, ) { - let mut packet = [0; MESSAGE_SIZE]; + use std::fmt; + + let mut packet = [0; N]; let mut num_packets = 0; let mut size_packets = 0; - let expected = packet_count * packet_size; + let expected = packet_count as usize * N; - let batch_size = batch_size(packet_size); - let mut batch_end = batch_size; + let batch_size = batch_size(N); + + struct Batch { + received: usize, + bits: Vec, + range: std::ops::Range, + } + + impl fmt::Debug for Batch { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + writeln!(f, "{:?}", self.range)?; + + let side = (self.range.len() as f32).sqrt().ceil() as usize; + + for ch in self.bits.chunks(side) { + f.write_str("\n")?; + for v in ch { + if *v { + f.write_str("x")?; + } else { + f.write_str(".")?; + } + } + } + + Ok(()) + } + } + + let mut batch_i = 0u16; + let mut batch_range = || -> std::ops::Range { + let start = batch_size * batch_i; + + if start > packet_count { + return 0..0; + } + + batch_i += 1; + start..(start + batch_size).min(packet_count) + }; + + // We can have a max of 2 batches in flight at a time + let mut batches = [ + Batch { + received: 0, + bits: vec![false; batch_size as usize], + range: batch_range(), + }, + Batch { + received: 0, + bits: vec![false; batch_size as usize], + range: batch_range(), + }, + ]; while size_packets < expected { let length = match socket.recv_from(&mut packet) { @@ -112,144 +159,145 @@ pub fn read_to_end( Err(err) => panic!("failed waiting for packet: {err}"), }; - num_packets += 1; - size_packets += length; + assert_eq!(length, N); + + { + let seq = (packet[1] as u16) << 8 | packet[0] as u16; - if num_packets >= batch_end { - if tx - .send(ReadLoopMsg::Acked(PacketStats { - num_packets, - size_packets, - })) - .is_err() - { - return; + if seq > num_packets { + dbg!(&batches[0]); + dbg!(&batches[1]); } - batch_end += batch_size; + let batch = batches.iter_mut().find(|b| b.range.contains(&seq)).unwrap(); + + batch.received += 1; + if batch.received == batch.range.len() { + batch.received = 0; + batch.range = batch_range(); + + if tx + .send(ReadLoopMsg::Acked(PacketStats { + num_packets, + size_packets, + })) + .is_err() + { + return; + } + } } + + num_packets += 1; + size_packets += length; } - let _ = tx.send(ReadLoopMsg::Finished(PacketStats { - num_packets, - size_packets, - })); + match socket.recv_from(&mut packet) { + Ok(t) => panic!("writer sent more data than was intended: {t:?}"), + Err(ref err) if matches!(err.kind(), std::io::ErrorKind::WouldBlock) => { + let _ = tx.send(ReadLoopMsg::Finished(PacketStats { + num_packets, + size_packets, + })); + } + Err(err) => panic!("failed waiting for packet: {err}"), + } } -pub struct Writer { - #[cfg(target_os = "linux")] - socket: socket2::Socket, - #[cfg(not(target_os = "linux"))] +pub struct Writer { socket: UdpSocket, destination: SocketAddr, rx: mpsc::Receiver, - batch_size: usize, - packet: &'static [u8], - #[cfg(unix)] - slices: Vec>, } -impl Writer { +impl Writer { pub fn new( socket: UdpSocket, destination: SocketAddr, rx: mpsc::Receiver, - packet: &'static [u8], ) -> Self { - let batch_size = batch_size(packet.len()); - - #[cfg(target_os = "linux")] - let (socket, slices) = { - let socket = socket2::Socket::from(socket); - - (socket, vec![std::io::IoSlice::new(packet); batch_size]) - }; - Self { socket, destination, rx, - batch_size, - packet, - #[cfg(target_os = "linux")] - slices, } } - pub fn write_all(&self, packet_count: usize) -> bool { - use std::{mem, ptr}; + /// Waits until a write is received by the specified socket + pub fn wait_ready(&self, quilkin: QuilkinLoop, reader: &UdpSocket) -> QuilkinLoop { + const MAX_WAIT: std::time::Duration = std::time::Duration::from_secs(10); - // The value of the auxiliary data to put in the control message. - let segment_size = self.packet.len() as u16; + let start = std::time::Instant::now(); - #[cfg(target_os = "linux")] - let (dst, buf, layout) = { - // The number of bytes needed for this control message. - let cmsg_size = unsafe { libc::CMSG_SPACE(mem::size_of_val(&segment_size) as _) }; - let layout = std::alloc::Layout::from_size_align( - cmsg_size as usize, - mem::align_of::(), - ) - .unwrap(); - let buf = unsafe { std::alloc::alloc(layout) }; + let send_packet = [0xaa; 1]; + let mut recv_packet = [0x00; 1]; - (socket2::SockAddr::from(self.destination), buf, layout) - }; - - let send_batch = |received: usize| { - let to_send = (packet_count - received).min(self.batch_size); + // Temporarily make the socket blocking + reader.set_nonblocking(false).unwrap(); + reader + .set_read_timeout(Some(std::time::Duration::from_millis(10))) + .unwrap(); - // GSO, see https://github.com/flub/socket-use/blob/main/src/bin/sendmsg_gso.rs - #[cfg(target_os = "linux")] - { - if !ENABLE_GSO { - for _ in 0..to_send { - self.socket.send_to(self.packet, &dst).unwrap(); + while start.elapsed() < MAX_WAIT { + self.socket.send_to(&send_packet, self.destination).unwrap(); + + match reader.recv_from(&mut recv_packet) { + Ok(_) => { + assert_eq!(send_packet, recv_packet); + + // Drain until block just in case + loop { + match reader.recv_from(&mut recv_packet) { + Ok(_) => {} + Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => { + reader.set_nonblocking(true).unwrap(); + reader.set_read_timeout(None).unwrap(); + return quilkin; + } + Err(err) => { + panic!("failed to drain read socket: {err:?}"); + } + } } - return; } - - let mut msg: libc::msghdr = unsafe { std::mem::zeroed() }; - - // Set the single destination and the payloads of each datagram - msg.msg_name = dst.as_ptr() as *mut _; - msg.msg_namelen = dst.len(); - msg.msg_iov = self.slices.as_ptr() as *mut _; - msg.msg_iovlen = to_send; - - msg.msg_control = buf as *mut _; - msg.msg_controllen = layout.size(); - - let cmsg: &mut libc::cmsghdr = unsafe { - let cmsg = libc::CMSG_FIRSTHDR(&msg); - let cmsg_zeroed: libc::cmsghdr = mem::zeroed(); - ptr::copy_nonoverlapping(&cmsg_zeroed, cmsg, 1); - cmsg.as_mut().unwrap() - }; - cmsg.cmsg_level = libc::SOL_UDP; - cmsg.cmsg_type = libc::UDP_SEGMENT; - cmsg.cmsg_len = - unsafe { libc::CMSG_LEN(mem::size_of_val(&segment_size) as _) } as libc::size_t; - unsafe { ptr::write(libc::CMSG_DATA(cmsg) as *mut u16, segment_size) }; - - use std::os::fd::AsRawFd; - if unsafe { libc::sendmsg(self.socket.as_raw_fd(), &msg, 0) } == -1 { - panic!("failed to send batch: {}", std::io::Error::last_os_error()); + Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {} + Err(ref e) if e.kind() == std::io::ErrorKind::Interrupted => { + println!("debugger might have attached"); + } + Err(err) => { + panic!("failed to wait on read socket: {err:?}"); } } + } - #[cfg(not(target_os = "linux"))] - { - for _ in 0..to_send { - self.socket.send_to(self.packet, self.destination).unwrap(); - } + panic!("waited for {MAX_WAIT:?} for quilkin"); + } + + pub fn write_all(&self, packet_count: u16) -> bool { + let batch_size = batch_size(N); + + let mut packet_buf = [0xffu8; N]; + + let mut send_batch = |sent: u16| -> u16 { + let to_send = (packet_count - sent).min(batch_size); + + for seq in sent..sent + to_send { + let b = seq.to_ne_bytes(); + packet_buf[0] = b[0]; + packet_buf[1] = b[1]; + + self.socket.send_to(&packet_buf, self.destination).unwrap(); } + + to_send }; + let mut sent_packets = 0; + // Queue 2 batches at the beginning, giving the reader enough work to do // after the initial batch has been read - send_batch(0); - send_batch(self.batch_size); + sent_packets += send_batch(sent_packets); + sent_packets += send_batch(sent_packets); let mut finished = false; while let Ok(msg) = self.rx.recv() { @@ -258,34 +306,53 @@ impl Writer { panic!("reader blocked {ps:?}"); } ReadLoopMsg::Acked(ps) => { - send_batch(ps.num_packets); + if sent_packets < packet_count { + assert!(sent_packets > ps.num_packets); + sent_packets += send_batch(sent_packets); + } } ReadLoopMsg::Finished(ps) => { - assert_eq!(ps.size_packets, self.packet.len() * packet_count); + assert_eq!(sent_packets, ps.num_packets); + assert_eq!(ps.size_packets, N * packet_count as usize); finished = true; break; } } } - // Don't leak the buf - unsafe { std::alloc::dealloc(buf, layout) }; finished } } +#[allow(dead_code)] pub struct QuilkinLoop { shutdown: Option, thread: Option>, + port: u16, + endpoint: SocketAddr, } impl QuilkinLoop { /// Run and instance of quilkin that sends and receives data from the given address. pub fn spinup(port: u16, endpoint: SocketAddr) -> Self { + Self::spinup_inner(port, endpoint) + } + + #[allow(dead_code)] + fn reinit(self) -> Self { + let port = self.port; + let endpoint = self.endpoint; + + drop(self); + + Self::spinup_inner(port, endpoint) + } + + fn spinup_inner(port: u16, endpoint: SocketAddr) -> Self { let (shutdown_tx, shutdown_rx) = quilkin::make_shutdown_channel(quilkin::ShutdownKind::Benching); - let thread = std::thread::spawn(move || { + let thread = spawn("quilkin", move || { let runtime = tokio::runtime::Runtime::new().unwrap(); let config = Arc::new(quilkin::Config::default()); config.clusters.modify(|clusters| { @@ -312,6 +379,8 @@ impl QuilkinLoop { Self { shutdown: Some(shutdown_tx), thread: Some(thread), + port, + endpoint, } } } From cd15c0638f7891f83a6003ed8d6b9f8f154580c7 Mon Sep 17 00:00:00 2001 From: Jake Shadle Date: Mon, 13 Nov 2023 17:23:05 +0100 Subject: [PATCH 6/7] Remove debug code --- benches/read_write.rs | 3 --- benches/shared.rs | 31 ------------------------------- 2 files changed, 34 deletions(-) diff --git a/benches/read_write.rs b/benches/read_write.rs index b6c6c62536..a731bb42ab 100644 --- a/benches/read_write.rs +++ b/benches/read_write.rs @@ -101,12 +101,9 @@ mod write { } }); - let mut wtf = 0; - b.counter(counter(N)).bench_local(|| { // Signal the read loop to run loop_tx.send((NUMBER_OF_PACKETS, N)).unwrap(); - wtf += 1; writer.write_all(NUMBER_OF_PACKETS); }); diff --git a/benches/shared.rs b/benches/shared.rs index 841fc0e0ba..cc28baf657 100644 --- a/benches/shared.rs +++ b/benches/shared.rs @@ -86,8 +86,6 @@ pub fn read_to_end( tx: &mpsc::Sender, packet_count: u16, ) { - use std::fmt; - let mut packet = [0; N]; let mut num_packets = 0; @@ -99,31 +97,9 @@ pub fn read_to_end( struct Batch { received: usize, - bits: Vec, range: std::ops::Range, } - impl fmt::Debug for Batch { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - writeln!(f, "{:?}", self.range)?; - - let side = (self.range.len() as f32).sqrt().ceil() as usize; - - for ch in self.bits.chunks(side) { - f.write_str("\n")?; - for v in ch { - if *v { - f.write_str("x")?; - } else { - f.write_str(".")?; - } - } - } - - Ok(()) - } - } - let mut batch_i = 0u16; let mut batch_range = || -> std::ops::Range { let start = batch_size * batch_i; @@ -140,12 +116,10 @@ pub fn read_to_end( let mut batches = [ Batch { received: 0, - bits: vec![false; batch_size as usize], range: batch_range(), }, Batch { received: 0, - bits: vec![false; batch_size as usize], range: batch_range(), }, ]; @@ -164,11 +138,6 @@ pub fn read_to_end( { let seq = (packet[1] as u16) << 8 | packet[0] as u16; - if seq > num_packets { - dbg!(&batches[0]); - dbg!(&batches[1]); - } - let batch = batches.iter_mut().find(|b| b.range.contains(&seq)).unwrap(); batch.received += 1; From 5d71f5d37fdff2cc5fac61603352793e56f2270d Mon Sep 17 00:00:00 2001 From: Jake Shadle Date: Mon, 13 Nov 2023 17:49:45 +0100 Subject: [PATCH 7/7] Fix example --- Cargo.toml | 2 +- examples/quilkin-filter-example/src/main.rs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 320664ff0d..91f88558b7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -64,7 +64,7 @@ exclude = ["docs", "build", "examples", "image"] [[bench]] name = "read_write" harness = false -test = true +test = false [dependencies] # Local diff --git a/examples/quilkin-filter-example/src/main.rs b/examples/quilkin-filter-example/src/main.rs index b05e40e9f8..1f5c284fd5 100644 --- a/examples/quilkin-filter-example/src/main.rs +++ b/examples/quilkin-filter-example/src/main.rs @@ -93,7 +93,7 @@ impl StaticFilter for Greet { async fn main() -> quilkin::Result<()> { quilkin::filters::FilterRegistry::register(vec![Greet::factory()].into_iter()); - let (_shutdown_tx, shutdown_rx) = tokio::sync::watch::channel(()); + let (_shutdown_tx, shutdown_rx) = quilkin::make_shutdown_channel(quilkin::ShutdownKind::Normal); let proxy = quilkin::Proxy::default(); let config = quilkin::Config::default(); config.filters.store(std::sync::Arc::new(