diff --git a/deny.toml b/deny.toml index c83f7de876..12a140ef21 100644 --- a/deny.toml +++ b/deny.toml @@ -31,7 +31,8 @@ exceptions = [ # Each entry is the crate and version constraint, and its specific allow # list { name ="webpki-roots", version = "0.25.0", allow = ["MPL-2.0"] }, - { name ="webpki-roots", version = "0.23.0", allow = ["MPL-2.0"] } + { name ="webpki-roots", version = "0.23.0", allow = ["MPL-2.0"] }, + { name ="option-ext", version = "0.2.0", allow = ["MPL-2.0"] } ] [[licenses.clarify]] diff --git a/src/proxy.rs b/src/proxy.rs index ca4f843a7b..fcc0574989 100644 --- a/src/proxy.rs +++ b/src/proxy.rs @@ -23,11 +23,10 @@ use tokio::net::UdpSocket; use crate::{ endpoint::{Endpoint, EndpointAddress}, filters::{Filter, ReadContext}, - ttl_map::TryResult, Config, }; -pub use sessions::{Session, SessionArgs, SessionKey, SessionMap}; +pub use sessions::{Session, SessionKey, SessionMap}; /// Packet received from local port #[derive(Debug)] @@ -204,25 +203,21 @@ impl DownstreamReceiveWorkerConfig { dest: endpoint.address.clone(), }; - let send_future = match sessions.try_get(&session_key) { - TryResult::Present(entry) => entry.send(packet), - TryResult::Absent => { - let session_args = SessionArgs { - config: config.clone(), - source: session_key.source.clone(), - downstream_socket: downstream_socket.clone(), - dest: endpoint.clone(), + let send_future = match sessions.get(&session_key) { + Some(entry) => entry.send(packet), + None => { + let session = Session::new( + config.clone(), + session_key.source.clone(), + downstream_socket.clone(), + endpoint.clone(), asn_info, - }; + )?; - let session = session_args.into_session().await?; let future = session.send(packet); sessions.insert(session_key, session); future } - TryResult::Locked => { - return Err(PipelineError::SessionMapLocked); - } }; send_future.await @@ -233,8 +228,6 @@ impl DownstreamReceiveWorkerConfig { pub enum PipelineError { #[error("No upstream endpoints available")] NoUpstreamEndpoints, - #[error("session map was locked")] - SessionMapLocked, #[error("filter {0}")] Filter(#[from] crate::filters::FilterError), #[error("qcmp: {0}")] diff --git a/src/proxy/sessions.rs b/src/proxy/sessions.rs index a7e142fcfd..59db2ce723 100644 --- a/src/proxy/sessions.rs +++ b/src/proxy/sessions.rs @@ -18,7 +18,12 @@ pub(crate) mod metrics; use std::sync::Arc; -use tokio::{net::UdpSocket, select, sync::watch, time::Instant}; +use tokio::{ + net::UdpSocket, + select, + sync::{watch, OnceCell}, + time::Instant, +}; use crate::{ endpoint::{Endpoint, EndpointAddress}, @@ -35,7 +40,7 @@ pub struct Session { /// created_at is time at which the session was created created_at: Instant, /// socket that sends and receives from and to the endpoint address - upstream_socket: Arc, + upstream_socket: Arc>>, /// dest is where to send data to dest: Endpoint, /// address of original sender @@ -68,63 +73,77 @@ struct ReceivedPacketContext<'a> { dest: EndpointAddress, } -pub struct SessionArgs { - pub config: Arc, - pub source: EndpointAddress, - pub downstream_socket: Arc, - pub dest: Endpoint, - pub asn_info: Option, -} - -impl SessionArgs { - /// Creates a new Session, and starts the process of receiving udp sockets - /// from its ephemeral port from endpoint(s) - pub async fn into_session(self) -> Result { - Session::new(self).await - } -} - impl Session { /// internal constructor for a Session from SessionArgs #[tracing::instrument(skip_all)] - async fn new(args: SessionArgs) -> Result { - let addr = (std::net::Ipv4Addr::UNSPECIFIED, 0); - let upstream_socket = Arc::new(UdpSocket::bind(addr).await?); - upstream_socket - .connect(args.dest.address.to_socket_addr().await?) - .await?; + pub fn new( + config: Arc, + source: EndpointAddress, + downstream_socket: Arc, + dest: Endpoint, + asn_info: Option, + ) -> Result { let (shutdown_tx, shutdown_rx) = watch::channel::<()>(()); let s = Session { - config: args.config.clone(), - upstream_socket, - source: args.source.clone(), - dest: args.dest, + config: config.clone(), + upstream_socket: Arc::new(OnceCell::new()), + source: source.clone(), + dest, created_at: Instant::now(), shutdown_tx, - asn_info: args.asn_info, + asn_info, }; tracing::debug!(source = %s.source, dest = ?s.dest, "Session created"); self::metrics::total_sessions().inc(); s.active_session_metric().inc(); - s.run(args.downstream_socket, shutdown_rx); + s.run(downstream_socket, shutdown_rx); Ok(s) } + fn upstream_socket( + &self, + ) -> impl std::future::Future, super::PipelineError>> { + let upstream_socket = self.upstream_socket.clone(); + let address = self.dest.address.clone(); + + async move { + upstream_socket + .get_or_try_init(|| async { + let upstream_socket = + UdpSocket::bind((std::net::Ipv4Addr::UNSPECIFIED, 0)).await?; + upstream_socket + .connect(address.to_socket_addr().await?) + .await?; + Ok(Arc::new(upstream_socket)) + }) + .await + .cloned() + } + } + /// run starts processing receiving upstream udp packets /// and sending them back downstream fn run(&self, downstream_socket: Arc, mut shutdown_rx: watch::Receiver<()>) { let source = self.source.clone(); let config = self.config.clone(); let endpoint = self.dest.clone(); - let upstream_socket = self.upstream_socket.clone(); + let upstream_socket = self.upstream_socket(); let asn_info = self.asn_info.clone(); tokio::spawn(async move { let mut buf: Vec = vec![0; 65535]; let mut last_received_at = None; + let upstream_socket = match upstream_socket.await { + Ok(socket) => socket, + Err(error) => { + tracing::error!(%error, "upstream socket failed to initialise"); + return; + } + }; + loop { tracing::debug!(source = %source, dest = ?endpoint, "Awaiting incoming packet"); let asn_info = asn_info.as_ref(); @@ -226,8 +245,8 @@ impl Session { contents = %crate::utils::base64_encode(buf), "sending packet upstream"); - let socket = self.upstream_socket.clone(); - async move { socket.send(buf).await.map_err(From::from) } + let socket = self.upstream_socket(); + async move { socket.await?.send(buf).await.map_err(From::from) } } } @@ -277,7 +296,7 @@ mod tests { use crate::{ endpoint::{Endpoint, EndpointAddress}, - proxy::sessions::{ReceivedPacketContext, SessionArgs}, + proxy::sessions::ReceivedPacketContext, test_utils::{create_socket, new_test_config, TestHelper}, }; @@ -289,17 +308,15 @@ mod tests { let socket = Arc::new(create_socket().await); let msg = "hello"; - let sess = Session::new(SessionArgs { - config: <_>::default(), - source: addr.clone(), - downstream_socket: socket.clone(), - dest: endpoint, - asn_info: None, - }) - .await - .unwrap(); + let sess = + Session::new(<_>::default(), addr.clone(), socket.clone(), endpoint, None).unwrap(); - sess.send(msg.as_bytes()).await.unwrap(); + sess.upstream_socket() + .await + .unwrap() + .send(msg.as_bytes()) + .await + .unwrap(); let mut buf = vec![0; 1024]; let (size, recv_addr) = timeout(Duration::from_secs(5), socket.recv_from(&mut buf))