Skip to content

Commit

Permalink
Late initialise upstream socket to prevent session map lock (#781)
Browse files Browse the repository at this point in the history
* Late initialise upstream socket to prevent session map lock

* Add exception for `option-ext`

* allow deprecated chrono function
  • Loading branch information
XAMPPRocky authored Sep 6, 2023
1 parent 8497e59 commit f32387f
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 62 deletions.
3 changes: 2 additions & 1 deletion deny.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ exceptions = [
# Each entry is the crate and version constraint, and its specific allow
# list
{ name ="webpki-roots", version = "0.25.0", allow = ["MPL-2.0"] },
{ name ="webpki-roots", version = "0.23.0", allow = ["MPL-2.0"] }
{ name ="webpki-roots", version = "0.23.0", allow = ["MPL-2.0"] },
{ name ="option-ext", version = "0.2.0", allow = ["MPL-2.0"] }
]

[[licenses.clarify]]
Expand Down
27 changes: 10 additions & 17 deletions src/proxy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,10 @@ use tokio::net::UdpSocket;
use crate::{
endpoint::{Endpoint, EndpointAddress},
filters::{Filter, ReadContext},
ttl_map::TryResult,
Config,
};

pub use sessions::{Session, SessionArgs, SessionKey, SessionMap};
pub use sessions::{Session, SessionKey, SessionMap};

/// Packet received from local port
#[derive(Debug)]
Expand Down Expand Up @@ -204,25 +203,21 @@ impl DownstreamReceiveWorkerConfig {
dest: endpoint.address.clone(),
};

let send_future = match sessions.try_get(&session_key) {
TryResult::Present(entry) => entry.send(packet),
TryResult::Absent => {
let session_args = SessionArgs {
config: config.clone(),
source: session_key.source.clone(),
downstream_socket: downstream_socket.clone(),
dest: endpoint.clone(),
let send_future = match sessions.get(&session_key) {
Some(entry) => entry.send(packet),
None => {
let session = Session::new(
config.clone(),
session_key.source.clone(),
downstream_socket.clone(),
endpoint.clone(),
asn_info,
};
)?;

let session = session_args.into_session().await?;
let future = session.send(packet);
sessions.insert(session_key, session);
future
}
TryResult::Locked => {
return Err(PipelineError::SessionMapLocked);
}
};

send_future.await
Expand All @@ -233,8 +228,6 @@ impl DownstreamReceiveWorkerConfig {
pub enum PipelineError {
#[error("No upstream endpoints available")]
NoUpstreamEndpoints,
#[error("session map was locked")]
SessionMapLocked,
#[error("filter {0}")]
Filter(#[from] crate::filters::FilterError),
#[error("qcmp: {0}")]
Expand Down
105 changes: 61 additions & 44 deletions src/proxy/sessions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,12 @@ pub(crate) mod metrics;

use std::sync::Arc;

use tokio::{net::UdpSocket, select, sync::watch, time::Instant};
use tokio::{
net::UdpSocket,
select,
sync::{watch, OnceCell},
time::Instant,
};

use crate::{
endpoint::{Endpoint, EndpointAddress},
Expand All @@ -35,7 +40,7 @@ pub struct Session {
/// created_at is time at which the session was created
created_at: Instant,
/// socket that sends and receives from and to the endpoint address
upstream_socket: Arc<UdpSocket>,
upstream_socket: Arc<OnceCell<Arc<UdpSocket>>>,
/// dest is where to send data to
dest: Endpoint,
/// address of original sender
Expand Down Expand Up @@ -68,63 +73,77 @@ struct ReceivedPacketContext<'a> {
dest: EndpointAddress,
}

pub struct SessionArgs {
pub config: Arc<crate::Config>,
pub source: EndpointAddress,
pub downstream_socket: Arc<UdpSocket>,
pub dest: Endpoint,
pub asn_info: Option<IpNetEntry>,
}

impl SessionArgs {
/// Creates a new Session, and starts the process of receiving udp sockets
/// from its ephemeral port from endpoint(s)
pub async fn into_session(self) -> Result<Session, super::PipelineError> {
Session::new(self).await
}
}

impl Session {
/// internal constructor for a Session from SessionArgs
#[tracing::instrument(skip_all)]
async fn new(args: SessionArgs) -> Result<Self, super::PipelineError> {
let addr = (std::net::Ipv4Addr::UNSPECIFIED, 0);
let upstream_socket = Arc::new(UdpSocket::bind(addr).await?);
upstream_socket
.connect(args.dest.address.to_socket_addr().await?)
.await?;
pub fn new(
config: Arc<crate::Config>,
source: EndpointAddress,
downstream_socket: Arc<UdpSocket>,
dest: Endpoint,
asn_info: Option<IpNetEntry>,
) -> Result<Self, super::PipelineError> {
let (shutdown_tx, shutdown_rx) = watch::channel::<()>(());

let s = Session {
config: args.config.clone(),
upstream_socket,
source: args.source.clone(),
dest: args.dest,
config: config.clone(),
upstream_socket: Arc::new(OnceCell::new()),
source: source.clone(),
dest,
created_at: Instant::now(),
shutdown_tx,
asn_info: args.asn_info,
asn_info,
};

tracing::debug!(source = %s.source, dest = ?s.dest, "Session created");

self::metrics::total_sessions().inc();
s.active_session_metric().inc();
s.run(args.downstream_socket, shutdown_rx);
s.run(downstream_socket, shutdown_rx);
Ok(s)
}

fn upstream_socket(
&self,
) -> impl std::future::Future<Output = Result<Arc<UdpSocket>, super::PipelineError>> {
let upstream_socket = self.upstream_socket.clone();
let address = self.dest.address.clone();

async move {
upstream_socket
.get_or_try_init(|| async {
let upstream_socket =
UdpSocket::bind((std::net::Ipv4Addr::UNSPECIFIED, 0)).await?;
upstream_socket
.connect(address.to_socket_addr().await?)
.await?;
Ok(Arc::new(upstream_socket))
})
.await
.cloned()
}
}

/// run starts processing receiving upstream udp packets
/// and sending them back downstream
fn run(&self, downstream_socket: Arc<UdpSocket>, mut shutdown_rx: watch::Receiver<()>) {
let source = self.source.clone();
let config = self.config.clone();
let endpoint = self.dest.clone();
let upstream_socket = self.upstream_socket.clone();
let upstream_socket = self.upstream_socket();
let asn_info = self.asn_info.clone();

tokio::spawn(async move {
let mut buf: Vec<u8> = vec![0; 65535];
let mut last_received_at = None;
let upstream_socket = match upstream_socket.await {
Ok(socket) => socket,
Err(error) => {
tracing::error!(%error, "upstream socket failed to initialise");
return;
}
};

loop {
tracing::debug!(source = %source, dest = ?endpoint, "Awaiting incoming packet");
let asn_info = asn_info.as_ref();
Expand Down Expand Up @@ -226,8 +245,8 @@ impl Session {
contents = %crate::utils::base64_encode(buf),
"sending packet upstream");

let socket = self.upstream_socket.clone();
async move { socket.send(buf).await.map_err(From::from) }
let socket = self.upstream_socket();
async move { socket.await?.send(buf).await.map_err(From::from) }
}
}

Expand Down Expand Up @@ -277,7 +296,7 @@ mod tests {

use crate::{
endpoint::{Endpoint, EndpointAddress},
proxy::sessions::{ReceivedPacketContext, SessionArgs},
proxy::sessions::ReceivedPacketContext,
test_utils::{create_socket, new_test_config, TestHelper},
};

Expand All @@ -289,17 +308,15 @@ mod tests {
let socket = Arc::new(create_socket().await);
let msg = "hello";

let sess = Session::new(SessionArgs {
config: <_>::default(),
source: addr.clone(),
downstream_socket: socket.clone(),
dest: endpoint,
asn_info: None,
})
.await
.unwrap();
let sess =
Session::new(<_>::default(), addr.clone(), socket.clone(), endpoint, None).unwrap();

sess.send(msg.as_bytes()).await.unwrap();
sess.upstream_socket()
.await
.unwrap()
.send(msg.as_bytes())
.await
.unwrap();

let mut buf = vec![0; 1024];
let (size, recv_addr) = timeout(Duration::from_secs(5), socket.recv_from(&mut buf))
Expand Down

0 comments on commit f32387f

Please sign in to comment.