Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Late initialise upstream socket to prevent session map lock #781

Merged
merged 4 commits into from
Sep 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion deny.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ exceptions = [
# Each entry is the crate and version constraint, and its specific allow
# list
{ name ="webpki-roots", version = "0.25.0", allow = ["MPL-2.0"] },
{ name ="webpki-roots", version = "0.23.0", allow = ["MPL-2.0"] }
{ name ="webpki-roots", version = "0.23.0", allow = ["MPL-2.0"] },
{ name ="option-ext", version = "0.2.0", allow = ["MPL-2.0"] }
]

[[licenses.clarify]]
Expand Down
27 changes: 10 additions & 17 deletions src/proxy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,10 @@ use tokio::net::UdpSocket;
use crate::{
endpoint::{Endpoint, EndpointAddress},
filters::{Filter, ReadContext},
ttl_map::TryResult,
Config,
};

pub use sessions::{Session, SessionArgs, SessionKey, SessionMap};
pub use sessions::{Session, SessionKey, SessionMap};

/// Packet received from local port
#[derive(Debug)]
Expand Down Expand Up @@ -204,25 +203,21 @@ impl DownstreamReceiveWorkerConfig {
dest: endpoint.address.clone(),
};

let send_future = match sessions.try_get(&session_key) {
TryResult::Present(entry) => entry.send(packet),
TryResult::Absent => {
let session_args = SessionArgs {
config: config.clone(),
source: session_key.source.clone(),
downstream_socket: downstream_socket.clone(),
dest: endpoint.clone(),
let send_future = match sessions.get(&session_key) {
Some(entry) => entry.send(packet),
None => {
let session = Session::new(
config.clone(),
session_key.source.clone(),
downstream_socket.clone(),
endpoint.clone(),
asn_info,
};
)?;

let session = session_args.into_session().await?;
let future = session.send(packet);
sessions.insert(session_key, session);
future
}
TryResult::Locked => {
return Err(PipelineError::SessionMapLocked);
}
};

send_future.await
Expand All @@ -233,8 +228,6 @@ impl DownstreamReceiveWorkerConfig {
pub enum PipelineError {
#[error("No upstream endpoints available")]
NoUpstreamEndpoints,
#[error("session map was locked")]
SessionMapLocked,
#[error("filter {0}")]
Filter(#[from] crate::filters::FilterError),
#[error("qcmp: {0}")]
Expand Down
105 changes: 61 additions & 44 deletions src/proxy/sessions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,12 @@ pub(crate) mod metrics;

use std::sync::Arc;

use tokio::{net::UdpSocket, select, sync::watch, time::Instant};
use tokio::{
net::UdpSocket,
select,
sync::{watch, OnceCell},
time::Instant,
};

use crate::{
endpoint::{Endpoint, EndpointAddress},
Expand All @@ -35,7 +40,7 @@ pub struct Session {
/// created_at is time at which the session was created
created_at: Instant,
/// socket that sends and receives from and to the endpoint address
upstream_socket: Arc<UdpSocket>,
upstream_socket: Arc<OnceCell<Arc<UdpSocket>>>,
/// dest is where to send data to
dest: Endpoint,
/// address of original sender
Expand Down Expand Up @@ -68,63 +73,77 @@ struct ReceivedPacketContext<'a> {
dest: EndpointAddress,
}

pub struct SessionArgs {
pub config: Arc<crate::Config>,
pub source: EndpointAddress,
pub downstream_socket: Arc<UdpSocket>,
pub dest: Endpoint,
pub asn_info: Option<IpNetEntry>,
}

impl SessionArgs {
/// Creates a new Session, and starts the process of receiving udp sockets
/// from its ephemeral port from endpoint(s)
pub async fn into_session(self) -> Result<Session, super::PipelineError> {
Session::new(self).await
}
}

impl Session {
/// internal constructor for a Session from SessionArgs
#[tracing::instrument(skip_all)]
async fn new(args: SessionArgs) -> Result<Self, super::PipelineError> {
let addr = (std::net::Ipv4Addr::UNSPECIFIED, 0);
let upstream_socket = Arc::new(UdpSocket::bind(addr).await?);
upstream_socket
.connect(args.dest.address.to_socket_addr().await?)
.await?;
pub fn new(
config: Arc<crate::Config>,
source: EndpointAddress,
downstream_socket: Arc<UdpSocket>,
dest: Endpoint,
asn_info: Option<IpNetEntry>,
) -> Result<Self, super::PipelineError> {
let (shutdown_tx, shutdown_rx) = watch::channel::<()>(());

let s = Session {
config: args.config.clone(),
upstream_socket,
source: args.source.clone(),
dest: args.dest,
config: config.clone(),
upstream_socket: Arc::new(OnceCell::new()),
source: source.clone(),
dest,
created_at: Instant::now(),
shutdown_tx,
asn_info: args.asn_info,
asn_info,
};

tracing::debug!(source = %s.source, dest = ?s.dest, "Session created");

self::metrics::total_sessions().inc();
s.active_session_metric().inc();
s.run(args.downstream_socket, shutdown_rx);
s.run(downstream_socket, shutdown_rx);
Ok(s)
}

fn upstream_socket(
&self,
) -> impl std::future::Future<Output = Result<Arc<UdpSocket>, super::PipelineError>> {
let upstream_socket = self.upstream_socket.clone();
let address = self.dest.address.clone();

async move {
upstream_socket
.get_or_try_init(|| async {
let upstream_socket =
UdpSocket::bind((std::net::Ipv4Addr::UNSPECIFIED, 0)).await?;
upstream_socket
.connect(address.to_socket_addr().await?)
.await?;
Ok(Arc::new(upstream_socket))
})
.await
.cloned()
}
}

/// run starts processing receiving upstream udp packets
/// and sending them back downstream
fn run(&self, downstream_socket: Arc<UdpSocket>, mut shutdown_rx: watch::Receiver<()>) {
let source = self.source.clone();
let config = self.config.clone();
let endpoint = self.dest.clone();
let upstream_socket = self.upstream_socket.clone();
let upstream_socket = self.upstream_socket();
let asn_info = self.asn_info.clone();

tokio::spawn(async move {
let mut buf: Vec<u8> = vec![0; 65535];
let mut last_received_at = None;
let upstream_socket = match upstream_socket.await {
Ok(socket) => socket,
Err(error) => {
tracing::error!(%error, "upstream socket failed to initialise");
return;
}
};

loop {
tracing::debug!(source = %source, dest = ?endpoint, "Awaiting incoming packet");
let asn_info = asn_info.as_ref();
Expand Down Expand Up @@ -226,8 +245,8 @@ impl Session {
contents = %crate::utils::base64_encode(buf),
"sending packet upstream");

let socket = self.upstream_socket.clone();
async move { socket.send(buf).await.map_err(From::from) }
let socket = self.upstream_socket();
async move { socket.await?.send(buf).await.map_err(From::from) }
}
}

Expand Down Expand Up @@ -277,7 +296,7 @@ mod tests {

use crate::{
endpoint::{Endpoint, EndpointAddress},
proxy::sessions::{ReceivedPacketContext, SessionArgs},
proxy::sessions::ReceivedPacketContext,
test_utils::{create_socket, new_test_config, TestHelper},
};

Expand All @@ -289,17 +308,15 @@ mod tests {
let socket = Arc::new(create_socket().await);
let msg = "hello";

let sess = Session::new(SessionArgs {
config: <_>::default(),
source: addr.clone(),
downstream_socket: socket.clone(),
dest: endpoint,
asn_info: None,
})
.await
.unwrap();
let sess =
Session::new(<_>::default(), addr.clone(), socket.clone(), endpoint, None).unwrap();

sess.send(msg.as_bytes()).await.unwrap();
sess.upstream_socket()
.await
.unwrap()
.send(msg.as_bytes())
.await
.unwrap();

let mut buf = vec![0; 1024];
let (size, recv_addr) = timeout(Duration::from_secs(5), socket.recv_from(&mut buf))
Expand Down