diff --git a/Cargo.lock b/Cargo.lock index 052ee9821c..824ccb2525 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -347,9 +347,9 @@ dependencies = [ [[package]] name = "cc" -version = "1.1.7" +version = "1.1.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "26a5c3fd7bfa1ce3897a3a3501d362b2d87b7f2583ebcb4a949ec25911025cbc" +checksum = "504bdec147f2cc13c8b57ed9401fd8a147cc66b67ad5cb241394244f2c947549" [[package]] name = "cfg-if" @@ -1197,7 +1197,7 @@ dependencies = [ "httpdate", "itoa", "pin-project-lite", - "socket2 0.5.7", + "socket2", "tokio", "tower-service", "tracing", @@ -1311,9 +1311,9 @@ dependencies = [ [[package]] name = "hyper-util" -version = "0.1.6" +version = "0.1.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3ab92f4f49ee4fb4f997c784b7a2e0fa70050211e0b6a287f898c3c9785ca956" +checksum = "cde7055719c54e36e95e8719f95883f22072a48ede39db7fc17a4e1d5281e9b9" dependencies = [ "bytes", "futures-channel", @@ -1322,7 +1322,7 @@ dependencies = [ "http-body 1.0.1", "hyper 1.4.1", "pin-project-lite", - "socket2 0.5.7", + "socket2", "tokio", "tower", "tower-service", @@ -1426,7 +1426,7 @@ version = "0.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b58db92f96b720de98181bbbe63c831e87005ab460c1bf306eb2622b4707997f" dependencies = [ - "socket2 0.5.7", + "socket2", "widestring", "windows-sys 0.48.0", "winreg", @@ -2285,7 +2285,7 @@ dependencies = [ "quilkin", "rand", "serde_json", - "socket2 0.5.7", + "socket2", "tempfile", "tokio", "tracing", @@ -2325,11 +2325,13 @@ dependencies = [ "hyper 1.4.1", "hyper-rustls", "hyper-util", + "io-uring", "ipnetwork", "k8s-openapi", "kube", "kube-core", "lasso", + "libc", "libflate", "lz4_flex", "maxminddb", @@ -2353,8 +2355,9 @@ dependencies = [ "serde_regex", "serde_stacker", "serde_yaml", + "slab", "snap", - "socket2 0.5.7", + "socket2", "stable-eyre", "strum", "strum_macros", @@ -2364,7 +2367,6 @@ dependencies = [ "time", "tokio", "tokio-stream", - "tokio-uring", "tonic", "tower", "tracing", @@ -2614,9 +2616,9 @@ dependencies = [ [[package]] name = "rustls-pki-types" -version = "1.7.0" +version = "1.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "976295e77ce332211c0d24d92c0e83e50f5c5f046d11082cea19f3df13a3562d" +checksum = "fc0a2ce646f8655401bb81e7927b812614bd5d91dbc968696be50603510fcaf0" [[package]] name = "rustls-webpki" @@ -2752,9 +2754,9 @@ dependencies = [ [[package]] name = "serde" -version = "1.0.204" +version = "1.0.205" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bc76f558e0cbb2a839d37354c575f1dc3fdc6546b5be373ba43d95f231bf7c12" +checksum = "e33aedb1a7135da52b7c21791455563facbbcc43d0f0f66165b42c21b3dfb150" dependencies = [ "serde_derive", ] @@ -2771,9 +2773,9 @@ dependencies = [ [[package]] name = "serde_derive" -version = "1.0.204" +version = "1.0.205" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e0cd7e117be63d3c3678776753929474f3b04a43a080c744d6b0ae2a8c28e222" +checksum = "692d6f5ac90220161d6774db30c662202721e64aed9058d2c394f451261420c1" dependencies = [ "proc-macro2", "quote", @@ -2922,16 +2924,6 @@ version = "1.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1b6b67fb9a61334225b5b790716f609cd58395f895b3fe8b328786812a40bc3b" -[[package]] -name = "socket2" -version = "0.4.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9f7916fc008ca5542385b89a3d3ce689953c143e9304a9bf8beec1de48994c0d" 
-dependencies = [ - "libc", - "winapi", -] - [[package]] name = "socket2" version = "0.5.7" @@ -3072,15 +3064,15 @@ dependencies = [ [[package]] name = "tempfile" -version = "3.11.0" +version = "3.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b8fcd239983515c23a32fb82099f97d0b11b8c72f654ed659363a95c3dad7a53" +checksum = "04cbcdd0c794ebb0d4cf35e88edd2f7d2c4c3e9a5a6dab322839b321c6a87a64" dependencies = [ "cfg-if", "fastrand", "once_cell", "rustix", - "windows-sys 0.52.0", + "windows-sys 0.59.0", ] [[package]] @@ -3170,7 +3162,7 @@ dependencies = [ "parking_lot", "pin-project-lite", "signal-hook-registry", - "socket2 0.5.7", + "socket2", "tokio-macros", "tracing", "windows-sys 0.52.0", @@ -3230,21 +3222,6 @@ dependencies = [ "tokio-util", ] -[[package]] -name = "tokio-uring" -version = "0.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "748482e3e13584a34664a710168ad5068e8cb1d968aa4ffa887e83ca6dd27967" -dependencies = [ - "bytes", - "futures-util", - "io-uring", - "libc", - "slab", - "socket2 0.4.10", - "tokio", -] - [[package]] name = "tokio-util" version = "0.7.11" diff --git a/Cargo.toml b/Cargo.toml index b9c1bef1c5..63dc6b44b2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -38,6 +38,9 @@ categories = ["game-development", "network-programming"] edition.workspace = true exclude = ["docs", "build", "examples", "image"] +[lints] +workspace = true + [[bench]] name = "read_write" harness = false @@ -144,8 +147,10 @@ version = "0.1" features = ["client", "client-legacy"] [target.'cfg(target_os = "linux")'.dependencies] +io-uring = { version = "0.6", default-features = false } +libc = "0.2" +slab = "0.4" sys-info = "0.9.1" -tokio-uring = { version = "0.5", features = ["bytes"] } pprof = { version = "0.13.0", features = ["prost", "prost-codec"] } [dev-dependencies] @@ -225,3 +230,6 @@ fixedstr = { version = "0.5", features = ["flex-str"] } parking_lot = "0.12.1" schemars = { version = "0.8.15", features = ["bytes", "url"] } url = { version = "2.4.1", features = ["serde"] } + +[workspace.lints.clippy] +undocumented_unsafe_blocks = "deny" \ No newline at end of file diff --git a/crates/agones/Cargo.toml b/crates/agones/Cargo.toml index 41bd18e37c..16187d468d 100644 --- a/crates/agones/Cargo.toml +++ b/crates/agones/Cargo.toml @@ -23,6 +23,9 @@ license = "Apache-2.0" description = "End to end integration tests to be run against a Kubernetes cluster with Agones installed" readme = "README.md" +[lints] +workspace = true + [dependencies] base64.workspace = true futures.workspace = true diff --git a/crates/macros/Cargo.toml b/crates/macros/Cargo.toml index 3218fdd146..7549b58950 100644 --- a/crates/macros/Cargo.toml +++ b/crates/macros/Cargo.toml @@ -30,7 +30,8 @@ edition = "2018" [lib] proc-macro = true -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html +[lints] +workspace = true [dependencies] proc-macro2 = "1.0.58" diff --git a/crates/quilkin-proto/Cargo.toml b/crates/quilkin-proto/Cargo.toml index a1adcd196c..13e5757ab8 100644 --- a/crates/quilkin-proto/Cargo.toml +++ b/crates/quilkin-proto/Cargo.toml @@ -20,7 +20,8 @@ version = "0.1.0" edition.workspace = true license.workspace = true -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html +[lints] +workspace = true [dependencies] prost.workspace = true diff --git a/crates/test/Cargo.toml b/crates/test/Cargo.toml index f8f5ba01f5..7a90f45ec4 100644 --- a/crates/test/Cargo.toml +++ 
b/crates/test/Cargo.toml @@ -20,6 +20,9 @@ version = "0.1.0" edition = "2021" publish = false +[lints] +workspace = true + [dependencies] async-channel.workspace = true once_cell.workspace = true diff --git a/crates/test/tests/proxy.rs b/crates/test/tests/proxy.rs index ee8fb0ce51..cb0b8a63b4 100644 --- a/crates/test/tests/proxy.rs +++ b/crates/test/tests/proxy.rs @@ -124,17 +124,17 @@ trace_test!(uring_receiver, { config, tx, BUFFER_POOL.clone(), - shutdown_rx, + shutdown_rx.clone(), ), } - .spawn() + .spawn(shutdown_rx) .await .expect("failed to spawn task"); // Drop the socket, otherwise it can drop(ws); - sb.timeout(500, ready.notified()).await; + let _ = sb.timeout(500, ready).await; let msg = "hello-downstream"; tracing::debug!("sending packet"); @@ -166,7 +166,7 @@ trace_test!( config.clone(), tx, BUFFER_POOL.clone(), - shutdown_rx, + shutdown_rx.clone(), ); const WORKER_COUNT: usize = 3; @@ -179,12 +179,13 @@ trace_test!( &sessions, rx, BUFFER_POOL.clone(), + shutdown_rx, ) .await .unwrap(); for wn in workers { - sb.timeout(200, wn.notified()).await; + let _ = sb.timeout(200, wn).await; } let socket = std::sync::Arc::new(sb.client()); diff --git a/src/codec/qcmp.rs b/src/codec/qcmp.rs index c1f29b9099..e91f52bdce 100644 --- a/src/codec/qcmp.rs +++ b/src/codec/qcmp.rs @@ -38,35 +38,19 @@ const PING: u8 = 0; const PONG: u8 = 1; pub struct QcmpPacket { - buf: Vec, + buf: [u8; MAX_QCMP_PACKET_LEN], len: usize, } impl Default for QcmpPacket { fn default() -> Self { Self { - buf: vec![0; MAX_QCMP_PACKET_LEN], + buf: [0; MAX_QCMP_PACKET_LEN], len: 0, } } } -#[cfg(target_os = "linux")] -unsafe impl tokio_uring::buf::IoBuf for QcmpPacket { - fn stable_ptr(&self) -> *const u8 { - self.buf.as_ptr() - } - - fn bytes_init(&self) -> usize { - self.len - } - - fn bytes_total(&self) -> usize { - self.buf.len() - } -} - -#[cfg(not(target_os = "linux"))] impl std::ops::Deref for QcmpPacket { type Target = [u8]; @@ -208,7 +192,8 @@ impl Measurement for QcmpMeasurement { } } -pub fn spawn(socket: socket2::Socket, mut shutdown_rx: crate::ShutdownRx) { +#[cfg(not(target_os = "linux"))] +pub fn spawn(socket: socket2::Socket, mut shutdown_rx: crate::ShutdownRx) -> crate::Result<()> { let port = crate::net::socket_port(&socket); uring_spawn!(uring_span!(tracing::debug_span!("qcmp")), async move { @@ -266,6 +251,191 @@ pub fn spawn(socket: socket2::Socket, mut shutdown_rx: crate::ShutdownRx) { }; } }); + + Ok(()) +} + +#[cfg(target_os = "linux")] +pub fn spawn(socket: socket2::Socket, mut shutdown_rx: crate::ShutdownRx) -> crate::Result<()> { + use crate::components::proxy::io_uring_shared::EventFd; + use eyre::Context as _; + + let port = crate::net::socket_port(&socket); + + // Create an eventfd so we can signal to the qcmp loop when we want to exit + let mut shutdown_event = EventFd::new()?; + let shutdown = shutdown_event.writer(); + + // Spawn a task on the main loop whose sole purpose is to signal the eventfd + tokio::task::spawn(async move { + let _ = shutdown_rx.changed().await; + shutdown.write(1); + }); + + let _thread_span = uring_span!(tracing::debug_span!("qcmp").or_current()); + let dispatcher = tracing::dispatcher::get_default(|d| d.clone()); + + std::thread::Builder::new() + .name("qcmp".into()) + .spawn(move || -> eyre::Result<()> { + let _guard = tracing::dispatcher::set_default(&dispatcher); + + let mut ring = io_uring::IoUring::new(3).context("unable to create io uring")?; + let (submitter, mut sq, mut cq) = ring.split(); + + const RECV: u64 = 0; + const SEND: u64 = 1; + const SHUTDOWN: 
u64 = 2; + + // Queue the read from the shutdown eventfd used to signal when the loop + // should exit + let entry = shutdown_event.io_uring_entry().user_data(SHUTDOWN); + // SAFETY: the memory being written to is located on the stack inside the shutdown event, and is alive + // at least as long as the uring loop + unsafe { + sq.push(&entry).context("unable to insert io-uring entry")?; + } + + // Our loop is simple and only ever processes one ping/pong pair at a time + // so we just reuse the same buffer for both receives and sends + let mut buf = QcmpPacket::default(); + // SAFETY: msghdr is POD + let mut msghdr: libc::msghdr = unsafe { std::mem::zeroed() }; + // SAFETY: sockaddr_storage is POD + let addr = unsafe { + socket2::SockAddr::new( + std::mem::zeroed(), + std::mem::size_of::<libc::sockaddr_storage>() as _, + ) + }; + + let mut iov = libc::iovec { + iov_base: buf.buf.as_mut_ptr() as *mut _, + iov_len: 0, + }; + + msghdr.msg_iov = std::ptr::addr_of_mut!(iov); + msghdr.msg_iovlen = 1; + msghdr.msg_name = addr.as_ptr() as *mut libc::sockaddr_storage as *mut _; + msghdr.msg_namelen = addr.len(); + + let msghdr_mut = std::ptr::addr_of_mut!(msghdr); + + let socket = DualStackLocalSocket::new(port) + .context("failed to create already bound qcmp socket")?; + let socket_fd = socket.raw_fd(); + + let enqueue_recv = + |sq: &mut io_uring::SubmissionQueue, iov: &mut libc::iovec| -> eyre::Result<()> { + iov.iov_len = MAX_QCMP_PACKET_LEN; + let entry = io_uring::opcode::RecvMsg::new(socket_fd, msghdr_mut) + .build() + .user_data(RECV); + // SAFETY: the memory being written to is located on the stack and outlives the uring loop + unsafe { sq.push(&entry) }.context("unable to insert io-uring entry")?; + Ok(()) + }; + + enqueue_recv(&mut sq, &mut iov)?; + + sq.sync(); + + loop { + match submitter.submit_and_wait(1) { + Ok(_) => {} + Err(ref err) if err.raw_os_error() == Some(libc::EBUSY) => {} + Err(err) => { + return Err(err).context("failed to submit io-uring operations"); + } + } + cq.sync(); + + let mut has_pending_send = false; + for cqe in &mut cq { + let ret = cqe.result(); + + match cqe.user_data() { + RECV => { + if ret < 0 { + let error = std::io::Error::from_raw_os_error(-ret).to_string(); + tracing::error!(%error, "failed to receive QCMP packet"); + continue; + } + + buf.len = ret as _; + let received_at = UtcTimestamp::now(); + let command = match Protocol::parse(&buf) { + Ok(Some(command)) => command, + Ok(None) => { + tracing::debug!("rejected non-QCMP packet"); + continue; + } + Err(error) => { + tracing::debug!(%error, "rejected malformed packet"); + continue; + } + }; + + let Protocol::Ping { + client_timestamp, + nonce, + } = command + else { + tracing::warn!("rejected unsupported QCMP packet"); + continue; + }; + + Protocol::ping_reply(nonce, client_timestamp, received_at) + .encode(&mut buf); + + tracing::debug!("sending QCMP ping reply"); + + // Update the iovec with the actual length of the pong + iov.iov_len = buf.len; + + // Note we don't have to do anything else with the msghdr + // as the recv has already filled in the socket address + // of the sender, which is also our destination + + { + let entry = io_uring::opcode::SendMsg::new( + socket_fd, + std::ptr::addr_of!(msghdr), + ) + .build() + .user_data(SEND); + // SAFETY: the memory being read from is located on the stack and outlives the uring loop + if unsafe { sq.push(&entry) }.is_err() { + tracing::error!("failed to enqueue QCMP pong response"); + continue; + } + } + + has_pending_send = true; + } + SEND => { + if ret < 0 { + let error = 
std::io::Error::from_raw_os_error(-ret).to_string(); + tracing::error!(%error, "failed to send QCMP response"); + } + } + SHUTDOWN => { + tracing::info!("QCMP thread was signaled to shutdown"); + return Ok(()); + } + ud => unreachable!("io-uring user data {ud} is invalid"), + } + } + + if !has_pending_send { + enqueue_recv(&mut sq, &mut iov)?; + } + + sq.sync(); + } + })?; + + Ok(()) } /// The set of possible QCMP commands. @@ -680,7 +850,7 @@ mod tests { let addr = socket.local_addr().unwrap().as_socket().unwrap(); let (_tx, rx) = crate::make_shutdown_channel(Default::default()); - spawn(socket, rx); + spawn(socket, rx).unwrap(); let delay = Duration::from_millis(50); let node = QcmpMeasurement::with_artificial_delay(delay).unwrap(); diff --git a/src/components/agent.rs b/src/components/agent.rs index c1b86d5021..eb0b86b50a 100644 --- a/src/components/agent.rs +++ b/src/components/agent.rs @@ -86,7 +86,7 @@ impl Agent { None }; - crate::codec::qcmp::spawn(self.qcmp_socket, shutdown_rx.clone()); + crate::codec::qcmp::spawn(self.qcmp_socket, shutdown_rx.clone())?; shutdown_rx.changed().await.map_err(From::from) } } diff --git a/src/components/proxy.rs b/src/components/proxy.rs index 2b9d669726..66299b12e9 100644 --- a/src/components/proxy.rs +++ b/src/components/proxy.rs @@ -2,8 +2,10 @@ mod error; pub mod packet_router; mod sessions; +#[cfg(target_os = "linux")] +pub(crate) mod io_uring_shared; + use super::RunArgs; -use crate::pool::PoolBuffer; pub use error::{ErrorMap, PipelineError}; pub use sessions::SessionPool; use std::{ @@ -14,6 +16,17 @@ use std::{ }, }; +pub struct SendPacket { + pub destination: SocketAddr, + pub data: crate::pool::FrozenPoolBuffer, + pub asn_info: Option, +} + +pub struct RecvPacket { + pub source: SocketAddr, + pub data: crate::pool::PoolBuffer, +} + #[derive(Clone, Debug)] pub struct Ready { pub idle_request_interval: std::time::Duration, @@ -182,11 +195,7 @@ impl Proxy { let id = config.id.load(); let num_workers = self.num_workers.get(); - let (upstream_sender, upstream_receiver) = async_channel::bounded::<( - PoolBuffer, - Option, - SocketAddr, - )>(250); + let (upstream_sender, upstream_receiver) = async_channel::bounded(250); let buffer_pool = Arc::new(crate::pool::BufferPool::new(num_workers, 64 * 1024)); let sessions = SessionPool::new( config.clone(), @@ -262,10 +271,11 @@ impl Proxy { &sessions, upstream_receiver, buffer_pool, + shutdown_rx.clone(), ) .await?; - crate::codec::qcmp::spawn(self.qcmp, shutdown_rx.clone()); + crate::codec::qcmp::spawn(self.qcmp, shutdown_rx.clone())?; crate::net::phoenix::spawn( self.phoenix, config.clone(), @@ -274,7 +284,7 @@ impl Proxy { )?; for notification in worker_notifications { - notification.notified().await; + let _ = notification.await; } tracing::info!("Quilkin is ready"); diff --git a/src/components/proxy/io_uring_shared.rs b/src/components/proxy/io_uring_shared.rs new file mode 100644 index 0000000000..3133a8e338 --- /dev/null +++ b/src/components/proxy/io_uring_shared.rs @@ -0,0 +1,770 @@ +//! We have two cases in the proxy where io-uring is used that are _almost_ identical +//! so this just has a shared implementation of utilities +//! +//! Note there is also the QCMP loop, but that one is simpler and is different +//! 
enough that it doesn't make sense to share the same code + +use crate::{ + components::proxy::{self, PipelineError}, + metrics, + net::maxmind_db::MetricsIpNetEntry, + pool::{FrozenPoolBuffer, PoolBuffer}, + time::UtcTimestamp, +}; +use io_uring::{squeue::Entry, types::Fd}; +use socket2::SockAddr; +use std::{ + os::fd::{AsRawFd, FromRawFd}, + sync::Arc, +}; + +/// A simple wrapper around [eventfd](https://man7.org/linux/man-pages/man2/eventfd.2.html) +/// +/// We use eventfd to signal io-uring loops from async tasks; it is essentially +/// the equivalent of a signalling 64-bit cross-process atomic +pub(crate) struct EventFd { + fd: std::os::fd::OwnedFd, + val: u64, +} + +#[derive(Clone)] +pub(crate) struct EventFdWriter { + fd: i32, +} + +impl EventFdWriter { + #[inline] + pub(crate) fn write(&self, val: u64) { + // SAFETY: we have a valid descriptor, and most of the errors that apply + // to the general write call that eventfd_write wraps are not applicable + // + // Note that while the docs state eventfd_write is glibc, it is implemented + // on musl as well, and is really just a write of 8 bytes + unsafe { + libc::eventfd_write(self.fd, val); + } + } +} + +impl EventFd { + #[inline] + pub(crate) fn new() -> std::io::Result<Self> { + // SAFETY: We have no invariants to uphold, but we do need to check the + // return value + let fd = unsafe { libc::eventfd(0, 0) }; + + // This can fail for various reasons, mostly around resource limits; if + // this is hit there is either something really wrong (OOM, too many file + // descriptors), or resource limits were externally placed that were too strict + if fd == -1 { + return Err(std::io::Error::last_os_error()); + } + + Ok(Self { + // SAFETY: we've validated the file descriptor + fd: unsafe { std::os::fd::OwnedFd::from_raw_fd(fd) }, + val: 0, + }) + } + + #[inline] + pub(crate) fn writer(&self) -> EventFdWriter { + EventFdWriter { + fd: self.fd.as_raw_fd(), + } + } + + /// Constructs an io-uring entry to read (i.e. wait) on this eventfd + #[inline] + pub(crate) fn io_uring_entry(&mut self) -> Entry { + io_uring::opcode::Read::new( + Fd(self.fd.as_raw_fd()), + &mut self.val as *mut u64 as *mut u8, + 8, + ) + .build() + } +} + +struct RecvPacket { + /// The buffer filled with data during recv_from + buffer: PoolBuffer, + /// The IP of the sender + source: std::net::SocketAddr, +} + +struct SendPacket { + /// The destination address of the packet + destination: SockAddr, + /// The packet data being sent + buffer: FrozenPoolBuffer, + /// The asn info for the sender, used for metrics + asn_info: Option<MetricsIpNetEntry>, +} + +/// A simple double buffer for queueing packets that need to be sent; each enqueue +/// notifies an eventfd that sends are available +#[derive(Clone)] +struct PendingSends { + packets: Arc<parking_lot::Mutex<Vec<SendPacket>>>, + notify: EventFdWriter, +} + +impl PendingSends { + pub fn new(notify: EventFdWriter) -> Self { + Self { + packets: Default::default(), + notify, + } + } + + #[inline] + pub fn push(&self, packet: SendPacket) { + self.packets.lock().push(packet); + self.notify.write(1); + } + + #[inline] + pub fn swap(&self, swap: Vec<SendPacket>) -> Vec<SendPacket> { + std::mem::replace(&mut *self.packets.lock(), swap) + } +} + +enum LoopPacketInner { + Recv(RecvPacket), + Send(SendPacket), +} + +/// A packet that is currently on the io-uring loop, either being received or sent +/// +/// The struct is expected to be pinned at a location in memory in a slab, as we +/// give pointers to the internal data in the struct, which also contains +/// referential pointers that need to stay pinned until the I/O 
is complete +#[repr(C)] +struct LoopPacket { + msghdr: libc::msghdr, + addr: libc::sockaddr_storage, + packet: Option, + io_vec: libc::iovec, +} + +impl LoopPacket { + #[inline] + fn new() -> Self { + Self { + // SAFETY: msghdr is POD + msghdr: unsafe { std::mem::zeroed() }, + packet: None, + io_vec: libc::iovec { + iov_base: std::ptr::null_mut(), + iov_len: 0, + }, + // SAFETY: sockaddr_storage is POD + addr: unsafe { std::mem::zeroed() }, + } + } + + #[inline] + fn set_packet(&mut self, mut packet: LoopPacketInner) { + match &mut packet { + LoopPacketInner::Recv(recv) => { + // For receives, the length of the buffer is the total capacity + self.io_vec.iov_base = recv.buffer.as_mut_ptr().cast(); + self.io_vec.iov_len = recv.buffer.capacity(); + } + LoopPacketInner::Send(send) => { + // For sends, the length of the buffer is the actual number of initialized bytes, + // and note that iov_base is a *mut even though for sends the buffer is not actually + // mutated + self.io_vec.iov_base = send.buffer.as_ptr() as *mut u8 as *mut _; + self.io_vec.iov_len = send.buffer.len(); + + // SAFETY: both pointers are valid at this point, with the same size + unsafe { + std::ptr::copy_nonoverlapping( + send.destination.as_ptr().cast(), + &mut self.addr, + 1, + ); + } + } + } + + // Increment the refcount of the buffer to ensure it stays alive for the + // duration of the I/O + self.packet = Some(packet); + + self.msghdr.msg_iov = std::ptr::addr_of_mut!(self.io_vec); + self.msghdr.msg_iovlen = 1; + self.msghdr.msg_name = std::ptr::addr_of_mut!(self.addr).cast(); + self.msghdr.msg_namelen = std::mem::size_of::() as _; + } + + #[inline] + fn finalize_recv(mut self, ret: usize) -> RecvPacket { + let LoopPacketInner::Recv(mut recv) = self.packet.take().unwrap() else { + unreachable!("finalized a send packet") + }; + + // SAFETY: we're initialising it with correctly sized data + let mut source = unsafe { + SockAddr::new( + self.addr, + std::mem::size_of::() as _, + ) + } + .as_socket() + .unwrap(); + source.set_ip(source.ip().to_canonical()); + + recv.source = source; + recv.buffer.set_len(ret); + recv + } + + #[inline] + fn finalize_send(mut self) -> SendPacket { + let LoopPacketInner::Send(send) = self.packet.take().unwrap() else { + unreachable!("finalized a recv packet") + }; + + send + } +} + +pub enum PacketProcessorCtx { + Router { + config: Arc, + sessions: Arc, + error_sender: super::error::ErrorSender, + /// Receiver for upstream packets being sent to this downstream + upstream_receiver: crate::components::proxy::sessions::DownstreamReceiver, + worker_id: usize, + }, + SessionPool { + pool: Arc, + downstream_receiver: tokio::sync::mpsc::Receiver, + port: u16, + }, +} + +/// Spawns worker tasks +/// +/// One task processes received packets, notifying the io-uring loop when a +/// packet finishes processing, the other receives packets to send and notifies +/// the io-uring loop when there are 1 or more packets available to be sent +fn spawn_workers( + rt: &tokio::runtime::Runtime, + ctx: PacketProcessorCtx, + pending_sends: PendingSends, + packet_processed_event: EventFdWriter, + mut shutdown_rx: crate::ShutdownRx, + shutdown_event: EventFdWriter, +) -> tokio::sync::mpsc::Sender { + let (tx, mut rx) = tokio::sync::mpsc::channel::(1); + + // Spawn a task that just monitors the shutdown receiver to notify the io-uring loop to exit + rt.spawn(async move { + // The result is uninteresting, either a shutdown has been signalled, or all senders have been dropped + // which equates to the same thing + let _ = 
shutdown_rx.changed().await; + shutdown_event.write(1); + }); + + match ctx { + PacketProcessorCtx::Router { + config, + sessions, + error_sender, + worker_id, + upstream_receiver, + } => { + rt.spawn(async move { + let mut last_received_at = None; + + let mut error_acc = super::error::ErrorAccumulator::new(error_sender); + + while let Some(packet) = rx.recv().await { + let received_at = UtcTimestamp::now(); + if let Some(last_received_at) = last_received_at { + metrics::packet_jitter(metrics::READ, &metrics::EMPTY) + .set((received_at - last_received_at).nanos()); + } + last_received_at = Some(received_at); + + let ds_packet = proxy::packet_router::DownstreamPacket { + contents: packet.buffer, + source: packet.source, + }; + + crate::components::proxy::packet_router::DownstreamReceiveWorkerConfig::process_task( + ds_packet, + worker_id, + &config, + &sessions, + &mut error_acc, + ) + .await; + + packet_processed_event.write(1); + } + }); + + rt.spawn(async move { + while let Ok(packet) = upstream_receiver.recv().await { + let packet = SendPacket { + destination: packet.destination.into(), + buffer: packet.data, + asn_info: packet.asn_info, + }; + pending_sends.push(packet); + } + }); + } + PacketProcessorCtx::SessionPool { + pool, + port, + mut downstream_receiver, + } => { + rt.spawn(async move { + let mut last_received_at = None; + + while let Some(packet) = rx.recv().await { + pool.process_received_upstream_packet( + packet.buffer, + packet.source, + port, + &mut last_received_at, + ) + .await; + + packet_processed_event.write(1); + } + }); + + rt.spawn(async move { + while let Some(packet) = downstream_receiver.recv().await { + let packet = SendPacket { + destination: packet.destination.into(), + buffer: packet.data, + asn_info: packet.asn_info, + }; + pending_sends.push(packet); + } + }); + } + } + + tx +} + +#[inline] +fn empty_net_addr() -> std::net::SocketAddr { + std::net::SocketAddr::new(std::net::IpAddr::V4(std::net::Ipv4Addr::UNSPECIFIED), 0) +} + +enum Token { + /// Packet received + Recv { key: usize }, + /// Packet sent + Send { key: usize }, + /// Recv packet processed + RecvPacketProcessed, + /// One or more packets are ready to be sent + PendingsSends, + /// Loop shutdown requested + Shutdown, +} + +struct LoopCtx<'uring> { + sq: io_uring::squeue::SubmissionQueue<'uring, Entry>, + backlog: std::collections::VecDeque, + socket_fd: Fd, + tokens: slab::Slab, + /// Packets currently being received or sent in the io-uring loop + loop_packets: slab::Slab, +} + +impl<'uring> LoopCtx<'uring> { + #[inline] + fn sync(&mut self) { + self.sq.sync(); + } + + /// Enqueues a recv_from on the socket + #[inline] + fn enqueue_recv(&mut self, buffer: crate::pool::PoolBuffer) { + let packet = LoopPacketInner::Recv(RecvPacket { + buffer, + source: empty_net_addr(), + }); + + let (key, msghdr) = { + let entry = self.loop_packets.vacant_entry(); + let key = entry.key(); + let pp = entry.insert(LoopPacket::new()); + pp.set_packet(packet); + (key, std::ptr::addr_of_mut!(pp.msghdr)) + }; + + let token = self.tokens.insert(Token::Recv { key }); + self.push( + io_uring::opcode::RecvMsg::new(self.socket_fd, msghdr) + .build() + .user_data(token as _), + ); + } + + /// Enqueues a send_to on the socket + #[inline] + fn enqueue_send(&mut self, packet: SendPacket) { + // We rely on sends using state with stable addresses, but realistically we should + // never be at capacity + if self.loop_packets.capacity() - self.loop_packets.len() == 0 { + metrics::errors_total( + metrics::WRITE, + "io-uring packet 
send slab is at capacity", + &packet.asn_info.as_ref().into(), + ); + return; + } + + let (key, msghdr) = { + let entry = self.loop_packets.vacant_entry(); + let key = entry.key(); + let pp = entry.insert(LoopPacket::new()); + pp.set_packet(LoopPacketInner::Send(packet)); + (key, std::ptr::addr_of!(pp.msghdr)) + }; + + let token = self.tokens.insert(Token::Send { key }); + self.push( + io_uring::opcode::SendMsg::new(self.socket_fd, msghdr) + .build() + .user_data(token as _), + ); + } + + #[inline] + fn pop_packet(&mut self, key: usize) -> LoopPacket { + self.loop_packets.remove(key) + } + + /// For now we have a backlog, but this would basically mean that we are receiving + /// more upstream packets than we can send downstream, which should? never happen + #[inline] + fn process_backlog(&mut self, submitter: &io_uring::Submitter<'uring>) -> std::io::Result<()> { + loop { + if self.sq.is_full() { + match submitter.submit() { + Ok(_) => (), + Err(ref err) if err.raw_os_error() == Some(libc::EBUSY) => break, + Err(err) => return Err(err), + } + } + self.sq.sync(); + + match self.backlog.pop_front() { + // SAFETY: Same as Self::push, all memory pointed to in our ops are pinned at + // stable locations in memory + Some(sqe) => unsafe { + let _ = self.sq.push(&sqe); + }, + None => break, + } + } + + Ok(()) + } + + #[inline] + fn push_with_token(&mut self, entry: Entry, token: Token) { + let token = self.tokens.insert(token); + self.push(entry.user_data(token as _)); + } + + #[inline] + fn push(&mut self, entry: Entry) { + // SAFETY: we keep all memory/file descriptors alive and in a stable locations + // for the duration of the I/O requests + unsafe { + if self.sq.push(&entry).is_err() { + self.backlog.push_back(entry); + } + } + } + + #[inline] + fn remove(&mut self, token: usize) -> Token { + self.tokens.remove(token) + } +} + +pub struct IoUringLoop { + runtime: tokio::runtime::Runtime, + socket: crate::net::DualStackLocalSocket, + concurrent_sends: usize, +} + +impl IoUringLoop { + pub fn new( + concurrent_sends: u16, + socket: crate::net::DualStackLocalSocket, + ) -> Result { + let runtime = tokio::runtime::Builder::new_multi_thread() + .enable_all() + .max_blocking_threads(1) + .worker_threads(3) + .build()?; + + Ok(Self { + runtime, + concurrent_sends: concurrent_sends as _, + socket, + }) + } + + pub fn spawn( + self, + thread_name: String, + ctx: PacketProcessorCtx, + buffer_pool: Arc, + shutdown: crate::ShutdownRx, + ) -> Result, PipelineError> { + let dispatcher = tracing::dispatcher::get_default(|d| d.clone()); + let (tx, rx) = tokio::sync::oneshot::channel(); + + let rt = self.runtime; + let socket = self.socket; + let concurrent_sends = self.concurrent_sends; + + let mut ring = io_uring::IoUring::new((concurrent_sends + 3) as _)?; + + // Used to notify the uring loop when 1 or more packets have been queued + // up to be sent to a remote address + let mut pending_sends_event = EventFd::new()?; + // Used to notify the uring when a received packet has finished + // processing and we can perform another recv, as we (currently) only + // ever process a single packet at a time + let mut process_event = EventFd::new()?; + // Used to notify the uring loop to shutdown + let mut shutdown_event = EventFd::new()?; + + std::thread::Builder::new() + .name(thread_name) + .spawn(move || { + let _guard = tracing::dispatcher::set_default(&dispatcher); + + let tokens = slab::Slab::with_capacity(concurrent_sends + 1 + 1 + 1); + let loop_packets = slab::Slab::with_capacity(concurrent_sends + 1); + + 
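// A minimal sketch (hypothetical names, not part of the patch) of the slab/user_data
// token scheme being set up here: every submission stores its slab key in user_data,
// and the completion side removes that key to recover which operation finished and
// reclaim its per-op state.
fn token_roundtrip_sketch() {
    enum DemoToken {
        Recv { key: usize },
        Send { key: usize },
    }
    let mut tokens = slab::Slab::new();
    // On submit: stash the state and tag the sqe with the slab key
    let user_data = tokens.insert(DemoToken::Recv { key: 7 }) as u64;
    // ... the kernel completes the op; cqe.user_data() hands the same value back ...
    match tokens.remove(user_data as usize) {
        DemoToken::Recv { key } => assert_eq!(key, 7),
        DemoToken::Send { .. } => unreachable!("no send was queued in this sketch"),
    }
}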
// Create an eventfd to notify the uring thread (this one) of + // pending sends + let pending_sends = PendingSends::new(pending_sends_event.writer()); + // Just double buffer the pending writes for simplicity + let mut double_pending_sends = Vec::new(); + + // When sending packets, this is the direction used when updating metrics + let send_dir = if matches!(ctx, PacketProcessorCtx::Router { .. }) { + metrics::WRITE + } else { + metrics::READ + }; + + // Spawn the worker tasks that process in an async context unlike + // our io-uring loop below + let process_packet_tx = spawn_workers( + &rt, + ctx, + pending_sends.clone(), + process_event.writer(), + shutdown, + shutdown_event.writer(), + ); + + let (submitter, sq, mut cq) = ring.split(); + + let mut loop_ctx = LoopCtx { + sq, + socket_fd: socket.raw_fd(), + backlog: Default::default(), + loop_packets, + tokens, + }; + + loop_ctx.enqueue_recv(buffer_pool.clone().alloc()); + loop_ctx + .push_with_token(pending_sends_event.io_uring_entry(), Token::PendingsSends); + loop_ctx.push_with_token(shutdown_event.io_uring_entry(), Token::Shutdown); + + // Sync always needs to be called when entries have been pushed + // onto the submission queue for the loop to actually function (ie, similar to await on futures) + loop_ctx.sync(); + + // Notify that we have set everything up + let _ = tx.send(()); + + // The core io uring loop + 'io: loop { + match submitter.submit_and_wait(1) { + Ok(_) => {} + Err(ref err) if err.raw_os_error() == Some(libc::EBUSY) => {} + Err(error) => { + tracing::error!(%error, "io-uring submit_and_wait failed"); + return; + } + } + cq.sync(); + + if let Err(error) = loop_ctx.process_backlog(&submitter) { + tracing::error!(%error, "failed to process io-uring backlog"); + return; + } + + // Now actually process all of the completed io requests + for cqe in &mut cq { + let ret = cqe.result(); + let token_index = cqe.user_data() as usize; + + let token = loop_ctx.remove(token_index); + match token { + Token::Recv { key } => { + // Pop the packet regardless of whether we failed or not so that + // we don't consume a buffer slot forever + let packet = loop_ctx.pop_packet(key); + + if ret < 0 { + let error = std::io::Error::from_raw_os_error(-ret); + tracing::error!(%error, "error receiving packet"); + loop_ctx.enqueue_recv(buffer_pool.clone().alloc()); + continue; + } + + let packet = packet.finalize_recv(ret as usize); + if process_packet_tx.blocking_send(packet).is_err() { + unreachable!("packet process thread has a pending packet"); + } + + // Queue the wait for the processing of the packet to finish + loop_ctx.push_with_token( + process_event.io_uring_entry(), + Token::RecvPacketProcessed, + ); + } + Token::RecvPacketProcessed => { + loop_ctx.enqueue_recv(buffer_pool.clone().alloc()); + } + Token::PendingsSends => { + double_pending_sends = pending_sends.swap(double_pending_sends); + loop_ctx.push_with_token( + pending_sends_event.io_uring_entry(), + Token::PendingsSends, + ); + + for pending in + double_pending_sends.drain(0..double_pending_sends.len()) + { + loop_ctx.enqueue_send(pending); + } + } + Token::Send { key } => { + let packet = loop_ctx.pop_packet(key).finalize_send(); + let asn_info = packet.asn_info.as_ref().into(); + + if ret < 0 { + let source = + std::io::Error::from_raw_os_error(-ret).to_string(); + metrics::errors_total(send_dir, &source, &asn_info).inc(); + metrics::packets_dropped_total(send_dir, &source, &asn_info) + .inc(); + } else if ret as usize != packet.buffer.len() { + 
metrics::packets_total(send_dir, &asn_info).inc(); + metrics::errors_total( + send_dir, + "sent bytes != packet length", + &asn_info, + ) + .inc(); + } else { + metrics::packets_total(send_dir, &asn_info).inc(); + metrics::bytes_total(send_dir, &asn_info).inc_by(ret as u64); + } + } + Token::Shutdown => { + tracing::info!("io-uring loop shutdown requested"); + break 'io; + } + } + } + + loop_ctx.sync(); + } + })?; + + Ok(rx) + } +} + +#[cfg(test)] +mod test { + use super::*; + + /// This is just a sanity check that eventfd, which we use to notify the io-uring + /// loop of events from async tasks, functions as we need to, namely that + /// an event posted before the I/O request is submitted to the I/O loop still + /// triggers the completion of the I/O request + #[test] + #[cfg(target_os = "linux")] + #[allow(clippy::undocumented_unsafe_blocks)] + fn eventfd_works_as_expected() { + let mut event = EventFd::new().unwrap(); + let event_writer = event.writer(); + + // Write even before we create the loop + event_writer.write(1); + + let mut ring = io_uring::IoUring::new(2).unwrap(); + let (submitter, mut sq, mut cq) = ring.split(); + + unsafe { + sq.push(&event.io_uring_entry().user_data(1)).unwrap(); + } + + sq.sync(); + + loop { + match submitter.submit_and_wait(1) { + Ok(_) => {} + Err(ref err) if err.raw_os_error() == Some(libc::EBUSY) => {} + Err(error) => { + panic!("oh no {error}"); + } + } + cq.sync(); + + for cqe in &mut cq { + assert_eq!(cqe.result(), 8); + + match cqe.user_data() { + // This was written before the loop started, but now write to the event + // before queuing up the next read + 1 => { + assert_eq!(event.val, 1); + event_writer.write(9999); + + unsafe { + sq.push(&event.io_uring_entry().user_data(2)).unwrap(); + } + } + 2 => { + assert_eq!(event.val, 9999); + return; + } + _ => unreachable!(), + } + } + + sq.sync(); + } + } +} diff --git a/src/components/proxy/packet_router.rs b/src/components/proxy/packet_router.rs index 0fcf118a73..b083e2194e 100644 --- a/src/components/proxy/packet_router.rs +++ b/src/components/proxy/packet_router.rs @@ -6,18 +6,21 @@ use crate::{ filters::{Filter as _, ReadContext}, metrics, pool::PoolBuffer, - time::UtcTimestamp, Config, }; use std::{net::SocketAddr, sync::Arc}; use tokio::sync::mpsc; +#[cfg(target_os = "linux")] +mod io_uring; +#[cfg(not(target_os = "linux"))] +mod reference; + /// Packet received from local port -#[derive(Debug)] -struct DownstreamPacket { - contents: PoolBuffer, - received_at: UtcTimestamp, - source: SocketAddr, +pub(crate) struct DownstreamPacket { + pub(crate) contents: PoolBuffer, + //received_at: UtcTimestamp, + pub(crate) source: SocketAddr, } /// Represents the required arguments to run a worker task that @@ -35,140 +38,9 @@ pub struct DownstreamReceiveWorkerConfig { } impl DownstreamReceiveWorkerConfig { - pub async fn spawn(self) -> eyre::Result> { - let Self { - worker_id, - upstream_receiver, - port, - config, - sessions, - error_sender, - buffer_pool, - } = self; - - let notify = Arc::new(tokio::sync::Notify::new()); - let is_ready = notify.clone(); - - let thread_span = - uring_span!(tracing::debug_span!("receiver", id = worker_id).or_current()); - - let worker = uring_spawn!(thread_span, async move { - let mut last_received_at = None; - let socket = crate::net::DualStackLocalSocket::new(port) - .unwrap() - .make_refcnt(); - - tracing::trace!(port, "bound worker"); - let send_socket = socket.clone(); - - let inner_task = async move { - is_ready.notify_one(); - - loop { - tokio::select! 
{ - result = upstream_receiver.recv() => { - match result { - Err(error) => { - tracing::trace!(%error, "error receiving packet"); - metrics::errors_total( - metrics::WRITE, - &error.to_string(), - &metrics::EMPTY, - ) - .inc(); - } - Ok((data, asn_info, send_addr)) => { - let (result, _) = send_socket.send_to(data, send_addr).await; - let asn_info = asn_info.as_ref().into(); - match result { - Ok(size) => { - metrics::packets_total(metrics::WRITE, &asn_info) - .inc(); - metrics::bytes_total(metrics::WRITE, &asn_info) - .inc_by(size as u64); - } - Err(error) => { - let source = error.to_string(); - metrics::errors_total( - metrics::WRITE, - &source, - &asn_info, - ) - .inc(); - metrics::packets_dropped_total( - metrics::WRITE, - &source, - &asn_info, - ) - .inc(); - } - } - } - } - } - } - } - }; - - cfg_if::cfg_if! { - if #[cfg(debug_assertions)] { - uring_inner_spawn!(inner_task.instrument(tracing::debug_span!("upstream").or_current())); - } else { - uring_inner_spawn!(inner_task); - } - } - - let mut error_acc = super::error::ErrorAccumulator::new(error_sender); - - loop { - // Initialize a buffer for the UDP packet. We use the maximum size of a UDP - // packet, which is the maximum value of 16 a bit integer. - let buffer = buffer_pool.clone().alloc(); - - let (result, contents) = socket.recv_from(buffer).await; - - match result { - Ok((_size, mut source)) => { - source.set_ip(source.ip().to_canonical()); - let packet = DownstreamPacket { - received_at: UtcTimestamp::now(), - contents, - source, - }; - - if let Some(last_received_at) = last_received_at { - metrics::packet_jitter(metrics::READ, &metrics::EMPTY) - .set((packet.received_at - last_received_at).nanos()); - } - last_received_at = Some(packet.received_at); - - Self::process_task( - packet, - source, - worker_id, - &config, - &sessions, - &mut error_acc, - ) - .await; - } - Err(error) => { - tracing::error!(%error, "error receiving packet"); - return; - } - } - } - }); - - use eyre::WrapErr as _; - worker.await.context("failed to spawn receiver task")??; - Ok(notify) - } - #[inline] - async fn process_task( + pub(crate) async fn process_task( packet: DownstreamPacket, - source: std::net::SocketAddr, worker_id: usize, config: &Arc, sessions: &Arc, @@ -177,7 +49,7 @@ impl DownstreamReceiveWorkerConfig { tracing::trace!( id = worker_id, size = packet.contents.len(), - source = %source, + source = %packet.source, "received packet from downstream" ); @@ -257,7 +129,8 @@ pub async fn spawn_receivers( sessions: &Arc, upstream_receiver: DownstreamReceiver, buffer_pool: Arc, -) -> crate::Result>> { + shutdown: crate::ShutdownRx, +) -> crate::Result>> { let (error_sender, mut error_receiver) = mpsc::channel(128); let port = crate::net::socket_port(&socket); @@ -274,7 +147,7 @@ pub async fn spawn_receivers( buffer_pool: buffer_pool.clone(), }; - worker_notifications.push(worker.spawn().await?); + worker_notifications.push(worker.spawn(shutdown.clone()).await?); } drop(error_sender); diff --git a/src/components/proxy/packet_router/io_uring.rs b/src/components/proxy/packet_router/io_uring.rs new file mode 100644 index 0000000000..d2b3916011 --- /dev/null +++ b/src/components/proxy/packet_router/io_uring.rs @@ -0,0 +1,39 @@ +use eyre::Context as _; + +impl super::DownstreamReceiveWorkerConfig { + pub async fn spawn( + self, + shutdown: crate::ShutdownRx, + ) -> eyre::Result> { + use crate::components::proxy::io_uring_shared; + + let Self { + worker_id, + upstream_receiver, + port, + config, + sessions, + error_sender, + buffer_pool, + } = 
self; + + let socket = + crate::net::DualStackLocalSocket::new(port).context("failed to bind socket")?; + + let io_loop = io_uring_shared::IoUringLoop::new(2000, socket)?; + io_loop + .spawn( + format!("packet-router-{worker_id}"), + io_uring_shared::PacketProcessorCtx::Router { + config, + sessions, + error_sender, + upstream_receiver, + worker_id, + }, + buffer_pool, + shutdown, + ) + .context("failed to spawn io-uring loop") + } +} diff --git a/src/components/proxy/packet_router/reference.rs new file mode 100644 index 0000000000..8a2e5fc2b7 --- /dev/null +++ b/src/components/proxy/packet_router/reference.rs @@ -0,0 +1,133 @@ +//! The reference implementation is used for non-Linux targets + +impl super::DownstreamReceiveWorkerConfig { + pub async fn spawn( + self, + _shutdown: crate::ShutdownRx, + ) -> eyre::Result<tokio::sync::oneshot::Receiver<()>> { + let Self { + worker_id, + upstream_receiver, + port, + config, + sessions, + error_sender, + buffer_pool, + } = self; + + let (tx, rx) = tokio::sync::oneshot::channel(); + + let thread_span = + uring_span!(tracing::debug_span!("receiver", id = worker_id).or_current()); + + let worker = uring_spawn!(thread_span, async move { + let mut last_received_at = None; + let socket = crate::net::DualStackLocalSocket::new(port) + .unwrap() + .make_refcnt(); + + tracing::trace!(port, "bound worker"); + let send_socket = socket.clone(); + + let inner_task = async move { + let _ = tx.send(()); + + loop { + tokio::select! { + result = upstream_receiver.recv() => { + match result { + Err(error) => { + tracing::trace!(%error, "error receiving packet"); + crate::metrics::errors_total( + crate::metrics::WRITE, + &error.to_string(), + &crate::metrics::EMPTY, + ) + .inc(); + } + Ok(crate::components::proxy::SendPacket { + destination, + asn_info, + data, + }) => { + let (result, _) = send_socket.send_to(data, destination).await; + let asn_info = asn_info.as_ref().into(); + match result { + Ok(size) => { + crate::metrics::packets_total(crate::metrics::WRITE, &asn_info) + .inc(); + crate::metrics::bytes_total(crate::metrics::WRITE, &asn_info) + .inc_by(size as u64); + } + Err(error) => { + let source = error.to_string(); + crate::metrics::errors_total( + crate::metrics::WRITE, + &source, + &asn_info, + ) + .inc(); + crate::metrics::packets_dropped_total( + crate::metrics::WRITE, + &source, + &asn_info, + ) + .inc(); + } + } + } + } + } + } + } + }; + + cfg_if::cfg_if! { + if #[cfg(debug_assertions)] { + uring_inner_spawn!(inner_task.instrument(tracing::debug_span!("upstream").or_current())); + } else { + uring_inner_spawn!(inner_task); + } + } + + let mut error_acc = + crate::components::proxy::error::ErrorAccumulator::new(error_sender); + + loop { + // Initialize a buffer for the UDP packet. We use the maximum size of a UDP + // packet, which is the maximum value of a 16 bit integer. 
+ let buffer = buffer_pool.clone().alloc(); + + let (result, contents) = socket.recv_from(buffer).await; + let received_at = crate::time::UtcTimestamp::now(); + + match result { + Ok((_size, mut source)) => { + source.set_ip(source.ip().to_canonical()); + let packet = super::DownstreamPacket { contents, source }; + + if let Some(last_received_at) = last_received_at { + crate::metrics::packet_jitter( + crate::metrics::READ, + &crate::metrics::EMPTY, + ) + .set((received_at - last_received_at).nanos()); + } + last_received_at = Some(received_at); + + Self::process_task(packet, worker_id, &config, &sessions, &mut error_acc) + .await; + } + Err(error) => { + tracing::error!(%error, "error receiving packet"); + return; + } + } + } + }); + + use eyre::WrapErr as _; + worker.await.context("failed to spawn receiver task")?; + Ok(rx) + } +} diff --git a/src/components/proxy/sessions.rs b/src/components/proxy/sessions.rs index a693e6280f..e7c43742bd 100644 --- a/src/components/proxy/sessions.rs +++ b/src/components/proxy/sessions.rs @@ -28,13 +28,11 @@ use tokio::{ }; use crate::{ + components::proxy::{PipelineError, SendPacket}, config::Config, filters::Filter, metrics, - net::{ - maxmind_db::{IpNetEntry, MetricsIpNetEntry}, - DualStackLocalSocket, - }, + net::maxmind_db::{IpNetEntry, MetricsIpNetEntry}, pool::{BufferPool, FrozenPoolBuffer, PoolBuffer}, time::UtcTimestamp, Loggable, ShutdownRx, @@ -43,11 +41,16 @@ use crate::{ pub(crate) mod inner_metrics; pub type SessionMap = crate::collections::ttl::TtlMap; -type ChannelData = (PoolBuffer, Option, SocketAddr); -type UpstreamChannelData = (FrozenPoolBuffer, Option, SocketAddr); -type UpstreamSender = mpsc::Sender; -type DownstreamSender = async_channel::Sender; -pub type DownstreamReceiver = async_channel::Receiver; + +#[cfg(target_os = "linux")] +mod io_uring; +#[cfg(not(target_os = "linux"))] +mod reference; + +type UpstreamSender = mpsc::Sender; + +type DownstreamSender = async_channel::Sender; +pub type DownstreamReceiver = async_channel::Receiver; #[derive(PartialEq, Eq, Hash)] pub enum SessionError { @@ -142,102 +145,24 @@ impl SessionPool { .as_socket() .ok_or(SessionError::SocketAddressUnavailable)? .port(); - let (tx, mut downstream_receiver) = mpsc::channel::(15); - - let pool = self.clone(); - - let initialised = uring_spawn!( - uring_span!(tracing::debug_span!("session pool")), - async move { - let mut last_received_at = None; - let mut shutdown_rx = pool.shutdown_rx.clone(); - let (tx, mut rx) = tokio::sync::oneshot::channel(); - - cfg_if::cfg_if! 
{ - if #[cfg(target_os = "linux")] { - let socket = std::rc::Rc::new(DualStackLocalSocket::from_raw(raw_socket)); - } else { - let socket = std::sync::Arc::new(DualStackLocalSocket::from_raw(raw_socket)); - } - }; - let socket2 = socket.clone(); - - uring_inner_spawn!(async move { - loop { - match downstream_receiver.recv().await { - None => { - metrics::errors_total( - metrics::WRITE, - "downstream channel closed", - &metrics::EMPTY, - ) - .inc(); - break; - } - Some((data, asn_info, send_addr)) => { - tracing::trace!(%send_addr, length = data.len(), "sending packet upstream"); - let (result, _) = socket2.send_to(data, send_addr).await; - let asn_info = asn_info.as_ref().into(); - match result { - Ok(size) => { - metrics::packets_total(metrics::READ, &asn_info).inc(); - metrics::bytes_total(metrics::READ, &asn_info) - .inc_by(size as u64); - } - Err(error) => { - tracing::trace!(%error, "sending packet upstream failed"); - let source = error.to_string(); - metrics::errors_total(metrics::READ, &source, &asn_info) - .inc(); - metrics::packets_dropped_total( - metrics::READ, - &source, - &asn_info, - ) - .inc(); - } - } - } - } - } - - let _ = tx.send(()); - }); - - loop { - let buf = pool.buffer_pool.clone().alloc(); - tokio::select! { - received = socket.recv_from(buf) => { - let (result, buf) = received; - match result { - Err(error) => { - tracing::trace!(%error, "error receiving packet"); - metrics::errors_total(metrics::WRITE, &error.to_string(), &metrics::EMPTY).inc(); - }, - Ok((_size, recv_addr)) => pool.process_received_upstream_packet(buf, recv_addr, port, &mut last_received_at).await, - } - } - _ = shutdown_rx.changed() => { - tracing::debug!("Closing upstream socket loop"); - return; - } - _ = &mut rx => { - tracing::debug!("Closing upstream socket loop, downstream closed"); - return; - } - } - } - } - ); + let (downstream_sender, downstream_receiver) = mpsc::channel::(15); - initialised.await.unwrap()?; + let initialised = self + .clone() + .spawn_session(raw_socket, port, downstream_receiver)?; + initialised + .await + .map_err(|_err| PipelineError::ChannelClosed)?; - self.ports_to_sockets.write().await.insert(port, tx.clone()); - self.create_session_from_existing_socket(key, tx, port) + self.ports_to_sockets + .write() + .await + .insert(port, downstream_sender.clone()); + self.create_session_from_existing_socket(key, downstream_sender, port) .await } - async fn process_received_upstream_packet( + pub(crate) async fn process_received_upstream_packet( self: &Arc, packet: PoolBuffer, mut recv_addr: SocketAddr, @@ -425,17 +350,19 @@ impl SessionPool { return Err((asn_info, err.into())); } - let packet = context.contents; + let packet = context.contents.freeze(); tracing::trace!(%source, %dest, length = packet.len(), "sending packet downstream"); downstream_sender - .try_send((packet, asn_info, dest)) + .try_send(SendPacket { + data: packet, + destination: dest, + asn_info, + }) .map_err(|error| match error { - async_channel::TrySendError::Closed((_, asn_info, _)) => { - (asn_info, Error::ChannelClosed) - } - async_channel::TrySendError::Full((_, asn_info, _)) => { - (asn_info, Error::ChannelFull) + async_channel::TrySendError::Closed(packet) => { + (packet.asn_info, Error::ChannelClosed) } + async_channel::TrySendError::Full(packet) => (packet.asn_info, Error::ChannelFull), })?; Ok(()) } @@ -456,7 +383,11 @@ impl SessionPool { let (asn_info, sender) = self.get(key).await?; sender - .try_send((packet, asn_info, key.dest)) + .try_send(crate::components::proxy::SendPacket { + data: 
packet, + asn_info, + destination: key.dest, + }) .map_err(|error| match error { TrySendError::Closed(_) => super::PipelineError::ChannelClosed, TrySendError::Full(_) => super::PipelineError::ChannelFull, @@ -805,11 +736,11 @@ mod tests { pool.send(key, alloc_buffer(msg).freeze()).await.unwrap(); - let (data, _, _) = tokio::time::timeout(std::time::Duration::from_secs(1), receiver.recv()) + let packet = tokio::time::timeout(std::time::Duration::from_secs(1), receiver.recv()) .await .unwrap() .unwrap(); - assert_eq!(msg, &*data); + assert_eq!(msg, &*packet.data); } } diff --git a/src/components/proxy/sessions/inner_metrics.rs b/src/components/proxy/sessions/inner_metrics.rs index 6ee30c0e2a..bc2902ffa0 100644 --- a/src/components/proxy/sessions/inner_metrics.rs +++ b/src/components/proxy/sessions/inner_metrics.rs @@ -41,6 +41,7 @@ pub(crate) fn active_sessions(asn: Option<&crate::net::maxmind_db::IpNetEntry>) let len = crate::metrics::itoa(asnfo.id, &mut asn); ACTIVE_SESSIONS.with_label_values(&[ + // SAFETY: itoa only writes ASCII unsafe { std::str::from_utf8_unchecked(&asn[..len as _]) }, &asnfo.as_name, &asnfo.as_cc, diff --git a/src/components/proxy/sessions/io_uring.rs b/src/components/proxy/sessions/io_uring.rs new file mode 100644 index 0000000000..ea19e0bf95 --- /dev/null +++ b/src/components/proxy/sessions/io_uring.rs @@ -0,0 +1,36 @@ +use std::sync::Arc; + +static SESSION_COUNTER: std::sync::atomic::AtomicUsize = std::sync::atomic::AtomicUsize::new(0); + +impl super::SessionPool { + pub(super) fn spawn_session( + self: Arc, + raw_socket: socket2::Socket, + port: u16, + downstream_receiver: tokio::sync::mpsc::Receiver, + ) -> Result, crate::components::proxy::PipelineError> { + use crate::components::proxy::io_uring_shared; + + let pool = self; + let id = SESSION_COUNTER.fetch_add(1, std::sync::atomic::Ordering::Relaxed); + let _thread_span = uring_span!(tracing::debug_span!("session", id).or_current()); + + let io_loop = io_uring_shared::IoUringLoop::new( + 2000, + crate::net::DualStackLocalSocket::from_raw(raw_socket), + )?; + let buffer_pool = pool.buffer_pool.clone(); + let shutdown = pool.shutdown_rx.clone(); + + io_loop.spawn( + format!("session-{id}"), + io_uring_shared::PacketProcessorCtx::SessionPool { + pool, + downstream_receiver, + port, + }, + buffer_pool, + shutdown, + ) + } +} diff --git a/src/components/proxy/sessions/reference.rs b/src/components/proxy/sessions/reference.rs new file mode 100644 index 0000000000..204f7b160f --- /dev/null +++ b/src/components/proxy/sessions/reference.rs @@ -0,0 +1,106 @@ +impl super::SessionPool { + pub(super) fn spawn_session( + self: std::sync::Arc, + raw_socket: socket2::Socket, + port: u16, + mut downstream_receiver: tokio::sync::mpsc::Receiver, + ) -> Result, crate::components::proxy::PipelineError> { + let pool = self; + + let rx = uring_spawn!( + uring_span!(tracing::debug_span!("session pool")), + async move { + let mut last_received_at = None; + let mut shutdown_rx = pool.shutdown_rx.clone(); + + let socket = + std::sync::Arc::new(crate::net::DualStackLocalSocket::from_raw(raw_socket)); + let socket2 = socket.clone(); + let (tx, mut rx) = tokio::sync::oneshot::channel(); + + uring_inner_spawn!(async move { + loop { + match downstream_receiver.recv().await { + None => { + crate::metrics::errors_total( + crate::metrics::WRITE, + "downstream channel closed", + &crate::metrics::EMPTY, + ) + .inc(); + break; + } + Some(crate::components::proxy::SendPacket { + destination, + data, + asn_info, + }) => { + 
tracing::trace!(%destination, length = data.len(), "sending packet upstream"); + let (result, _) = socket2.send_to(data, destination).await; + let asn_info = asn_info.as_ref().into(); + match result { + Ok(size) => { + crate::metrics::packets_total( + crate::metrics::READ, + &asn_info, + ) + .inc(); + crate::metrics::bytes_total( + crate::metrics::READ, + &asn_info, + ) + .inc_by(size as u64); + } + Err(error) => { + tracing::trace!(%error, "sending packet upstream failed"); + let source = error.to_string(); + crate::metrics::errors_total( + crate::metrics::READ, + &source, + &asn_info, + ) + .inc(); + crate::metrics::packets_dropped_total( + crate::metrics::READ, + &source, + &asn_info, + ) + .inc(); + } + } + } + } + } + + let _ = tx.send(()); + }); + + loop { + let buf = pool.buffer_pool.clone().alloc(); + tokio::select! { + received = socket.recv_from(buf) => { + let (result, buf) = received; + match result { + Err(error) => { + tracing::trace!(%error, "error receiving packet"); + crate::metrics::errors_total(crate::metrics::WRITE, &error.to_string(), &crate::metrics::EMPTY).inc(); + }, + Ok((_size, recv_addr)) => pool.process_received_upstream_packet(buf, recv_addr, port, &mut last_received_at).await, + } + } + _ = shutdown_rx.changed() => { + tracing::debug!("Closing upstream socket loop"); + return; + } + _ = &mut rx => { + tracing::debug!("Closing upstream socket loop, downstream closed"); + return; + } + } + } + } + ); + + Ok(rx) + } +} diff --git a/src/metrics.rs b/src/metrics.rs index e07a0369c8..66edead875 100644 --- a/src/metrics.rs +++ b/src/metrics.rs @@ -86,6 +86,7 @@ pub struct AsnInfo<'a> { impl<'a> AsnInfo<'a> { #[inline] fn asn_str(&self) -> &str { + // SAFETY: we only write ASCII in itoa unsafe { std::str::from_utf8_unchecked(&self.asn[..self.asn_len as _]) } } } @@ -287,6 +288,7 @@ mod test { let mut asn = [0u8; 10]; let len = super::itoa(num, &mut asn); + // SAFETY: itoa only writes ASCII let asn_str = unsafe { std::str::from_utf8_unchecked(&asn[..len as _]) }; assert_eq!(asn_str, exp); diff --git a/src/net.rs b/src/net.rs index 157f9f2ba0..2ef7d09e3e 100644 --- a/src/net.rs +++ b/src/net.rs @@ -15,61 +15,33 @@ */ /// On linux spawns a io-uring runtime + thread, everywhere else spawns a regular tokio task. +#[cfg(not(target_os = "linux"))] macro_rules! uring_spawn { ($span:expr, $future:expr) => {{ - let (tx, rx) = tokio::sync::oneshot::channel::>(); + let (tx, rx) = tokio::sync::oneshot::channel::<()>(); use tracing::Instrument as _; - cfg_if::cfg_if! 
{ - if #[cfg(target_os = "linux")] { - let dispatcher = tracing::dispatcher::get_default(|d| d.clone()); - std::thread::Builder::new().name("io-uring".into()).spawn(move || { - let _guard = tracing::dispatcher::set_default(&dispatcher); - - match tokio_uring::Runtime::new(&tokio_uring::builder().entries(2048)) { - Ok(runtime) => { - let _ = tx.send(Ok(())); - - if let Some(span) = $span { - runtime.block_on($future.instrument(span)); - } else { - runtime.block_on($future); - } - } - Err(error) => { - let _ = tx.send(Err(error)); - } - }; - }).expect("failed to spawn io-uring thread"); - } else { - use tracing::instrument::WithSubscriber as _; - - let fut = async move { - let _ = tx.send(Ok(())); - $future.await - }; - - if let Some(span) = $span { - tokio::spawn(fut.instrument(span).with_current_subscriber()); - } else { - tokio::spawn(fut.with_current_subscriber()); - } - } + use tracing::instrument::WithSubscriber as _; + + let fut = async move { + let _ = tx.send(()); + $future.await + }; + + if let Some(span) = $span { + tokio::spawn(fut.instrument(span).with_current_subscriber()); + } else { + tokio::spawn(fut.with_current_subscriber()); } rx }}; } /// On linux spawns a io-uring task, everywhere else spawns a regular tokio task. +#[cfg(not(target_os = "linux"))] macro_rules! uring_inner_spawn { ($future:expr) => { - cfg_if::cfg_if! { - if #[cfg(target_os = "linux")] { - tokio_uring::spawn($future); - } else { - tokio::spawn($future); - } - } + tokio::spawn($future); }; } @@ -82,7 +54,7 @@ macro_rules! uring_span { if #[cfg(debug_assertions)] { Some($span) } else { - None + Option::::None } } }}; @@ -105,7 +77,7 @@ use socket2::{Protocol, Socket, Type}; cfg_if::cfg_if! { if #[cfg(target_os = "linux")] { - use tokio_uring::net::UdpSocket; + use std::net::UdpSocket; } else { use tokio::net::UdpSocket; } @@ -121,7 +93,6 @@ fn socket_with_reuse_and_address(addr: SocketAddr) -> std::io::Result if #[cfg(target_os = "linux")] { raw_socket_with_reuse_and_address(addr) .map(From::from) - .map(UdpSocket::from_std) } else { epoll_socket_with_reuse_and_address(addr) } @@ -196,7 +167,7 @@ impl DualStackLocalSocket { let local_addr = socket.local_addr().unwrap(); cfg_if::cfg_if! { if #[cfg(target_os = "linux")] { - let socket = UdpSocket::from_std(socket); + let socket = socket; } else { // This is only for macOS and Windows (non-production platforms), // and should never happen anyway, so unwrap here is fine. @@ -234,15 +205,7 @@ impl DualStackLocalSocket { } cfg_if::cfg_if! 
{ - if #[cfg(target_os = "linux")] { - pub async fn recv_from(&self, buf: B) -> (io::Result<(usize, SocketAddr)>, B) { - self.socket.recv_from(buf).await - } - - pub async fn send_to(&self, buf: B, target: SocketAddr) -> (io::Result, B) { - self.socket.send_to(buf, target).await - } - } else { + if #[cfg(not(target_os = "linux"))] { pub async fn recv_from>(&self, mut buf: B) -> (io::Result<(usize, SocketAddr)>, B) { let result = self.socket.recv_from(&mut buf).await; (result, buf) @@ -252,6 +215,12 @@ impl DualStackLocalSocket { let result = self.socket.send_to(&buf, target).await; (result, buf) } + } else { + #[inline] + pub fn raw_fd(&self) -> io_uring::types::Fd { + use std::os::fd::AsRawFd; + io_uring::types::Fd(self.socket.as_raw_fd()) + } } } diff --git a/src/net/phoenix.rs b/src/net/phoenix.rs index 4a069c8395..89cf1eda88 100644 --- a/src/net/phoenix.rs +++ b/src/net/phoenix.rs @@ -829,7 +829,7 @@ mod tests { let (_tx, rx) = crate::make_shutdown_channel(Default::default()); let socket = raw_socket_with_reuse(qcmp_port).unwrap(); - crate::codec::qcmp::spawn(socket, rx.clone()); + crate::codec::qcmp::spawn(socket, rx.clone()).unwrap(); tokio::time::sleep(Duration::from_millis(150)).await; let measurement = diff --git a/src/pool.rs b/src/pool.rs index 9cea257eca..d37bd995e4 100644 --- a/src/pool.rs +++ b/src/pool.rs @@ -93,7 +93,7 @@ impl fmt::Debug for BufferPool { } pub struct PoolBuffer { - inner: BytesMut, + pub(crate) inner: BytesMut, owner: Arc, prefix: Option, suffix: Option, @@ -106,6 +106,11 @@ impl PoolBuffer { self.inner.len() } + #[inline] + pub fn capacity(&self) -> usize { + self.inner.capacity() + } + #[inline] pub fn is_empty(&self) -> bool { self.inner.is_empty() @@ -187,6 +192,15 @@ impl PoolBuffer { inner: Arc::new(self), } } + + /// Sets the length (number of initialized bytes) for the buffer + #[inline] + #[cfg(target_os = "linux")] + pub(crate) fn set_len(&mut self, len: usize) { + // SAFETY: len is the length as returned from the kernel on a successful + // recv_from call + unsafe { self.inner.set_len(len) } + } } impl fmt::Debug for PoolBuffer { @@ -222,37 +236,6 @@ impl std::ops::DerefMut for PoolBuffer { } } -#[cfg(target_os = "linux")] -unsafe impl tokio_uring::buf::IoBufMut for PoolBuffer { - #[inline] - fn stable_mut_ptr(&mut self) -> *mut u8 { - self.inner.stable_mut_ptr() - } - - #[inline] - unsafe fn set_init(&mut self, pos: usize) { - self.inner.set_init(pos) - } -} - -#[cfg(target_os = "linux")] -unsafe impl tokio_uring::buf::IoBuf for PoolBuffer { - #[inline] - fn stable_ptr(&self) -> *const u8 { - self.inner.stable_ptr() - } - - #[inline] - fn bytes_init(&self) -> usize { - self.inner.bytes_init() - } - - #[inline] - fn bytes_total(&self) -> usize { - self.inner.bytes_total() - } -} - impl Drop for PoolBuffer { #[inline] fn drop(&mut self) { @@ -288,24 +271,6 @@ impl FrozenPoolBuffer { } } -#[cfg(target_os = "linux")] -unsafe impl tokio_uring::buf::IoBuf for FrozenPoolBuffer { - #[inline] - fn stable_ptr(&self) -> *const u8 { - self.inner.stable_ptr() - } - - #[inline] - fn bytes_init(&self) -> usize { - self.inner.bytes_init() - } - - #[inline] - fn bytes_total(&self) -> usize { - self.inner.bytes_total() - } -} - impl std::ops::Deref for FrozenPoolBuffer { type Target = [u8];
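// Appendix: a condensed, self-contained sketch of the eventfd wake-up pattern this
// patch builds on (assuming only the io-uring and libc crates; `eventfd_wakeup_sketch`
// and its locals are illustrative names). A write to the eventfd from any thread
// completes a Read queued on the ring, which is how async tasks nudge the otherwise
// blocking submit_and_wait loop.
#[cfg(target_os = "linux")]
fn eventfd_wakeup_sketch() -> std::io::Result<()> {
    use io_uring::{opcode, types::Fd, IoUring};
    use std::os::fd::{AsRawFd as _, FromRawFd as _};

    // SAFETY: eventfd has no preconditions; the -1 error return is checked below
    let raw = unsafe { libc::eventfd(0, 0) };
    if raw == -1 {
        return Err(std::io::Error::last_os_error());
    }
    // SAFETY: we own the freshly created, validated descriptor
    let fd = unsafe { std::os::fd::OwnedFd::from_raw_fd(raw) };

    let mut ring = IoUring::new(2)?;
    let mut val = 0u64;
    // Queue an 8 byte read of the eventfd counter, tagged via user_data
    let read = opcode::Read::new(Fd(fd.as_raw_fd()), &mut val as *mut u64 as *mut u8, 8)
        .build()
        .user_data(1);
    // SAFETY: `val` and `fd` outlive the submission and the wait below
    unsafe { ring.submission().push(&read).expect("submission queue full") };

    // Simulate the signal an async task would send through EventFdWriter::write
    // SAFETY: valid descriptor; eventfd_write is just an 8 byte write
    unsafe { libc::eventfd_write(fd.as_raw_fd(), 1) };

    ring.submit_and_wait(1)?; // returns once the eventfd read completes
    let cqe = ring.completion().next().expect("expected one completion");
    assert_eq!(cqe.result(), 8); // eventfd reads always transfer 8 bytes
    assert_eq!((cqe.user_data(), val), (1, 1));
    Ok(())
}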
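// Also worth illustrating: the double-buffer swap PendingSends performs (assuming the
// parking_lot crate already in the workspace; `DemoPacket` is an illustrative stand-in
// for SendPacket). Producers push under a short lock, and the consumer trades an empty
// Vec for the full one, so the lock is never held during I/O and both allocations are
// reused indefinitely.
fn double_buffer_sketch() {
    struct DemoPacket(u64);
    let pending = parking_lot::Mutex::new(Vec::<DemoPacket>::new());

    // Producer side: push a packet and (in the real code) write to the eventfd
    pending.lock().push(DemoPacket(1));
    pending.lock().push(DemoPacket(2));

    // Consumer side (the io-uring loop): swap in the spare buffer, then drain the full one
    let mut spare: Vec<DemoPacket> = Vec::new();
    spare = std::mem::replace(&mut *pending.lock(), spare);
    assert_eq!(spare.len(), 2);
    spare.clear(); // keep the allocation for the next swap
}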