Skip to content

Commit

Permalink
Add the certififcate chain to TlsSession
Browse files Browse the repository at this point in the history
This allows users to query the certificate chain when in the TlsSession
event callbacks (for dc or for on_tls_exporter_ready). The API added
here allocates a Vec<Vec<u8>> which is plausibly quite wasteful, so the
API is for now doc(hidden) to avoid stabilizing it.
  • Loading branch information
Mark-Simulacrum committed Oct 14, 2024
1 parent 7752afb commit 284500e
Show file tree
Hide file tree
Showing 7 changed files with 182 additions and 0 deletions.
18 changes: 18 additions & 0 deletions quic/s2n-quic-core/src/crypto/tls.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0

#[cfg(feature = "alloc")]
use alloc::vec::Vec;
#[cfg(feature = "alloc")]
pub use bytes::{Bytes, BytesMut};
use core::fmt::Debug;
Expand Down Expand Up @@ -35,6 +37,19 @@ impl TlsExportError {
}
}

#[derive(Debug)]
#[non_exhaustive]
pub enum ChainError {
#[non_exhaustive]
Failure,
}

impl ChainError {
pub fn failure() -> Self {
ChainError::Failure
}
}

pub trait TlsSession: Send {
/// See <https://datatracker.ietf.org/doc/html/rfc5705> and <https://www.rfc-editor.org/rfc/rfc8446>.
fn tls_exporter(
Expand All @@ -45,6 +60,9 @@ pub trait TlsSession: Send {
) -> Result<(), TlsExportError>;

fn cipher_suite(&self) -> CipherSuite;

#[cfg(feature = "alloc")]
fn peer_cert_chain_der(&self) -> Result<Vec<Vec<u8>>, ChainError>;
}

//= https://www.rfc-editor.org/rfc/rfc9000#section-4
Expand Down
4 changes: 4 additions & 0 deletions quic/s2n-quic-core/src/crypto/tls/testing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,10 @@ impl TlsSession for Session {
fn cipher_suite(&self) -> CipherSuite {
CipherSuite::TLS_AES_128_GCM_SHA256
}

fn peer_cert_chain_der(&self) -> Result<Vec<Vec<u8>>, tls::ChainError> {
Err(tls::ChainError::failure())
}
}

#[derive(Debug)]
Expand Down
9 changes: 9 additions & 0 deletions quic/s2n-quic-core/src/event.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
// SPDX-License-Identifier: Apache-2.0

use crate::{connection, endpoint};
#[cfg(feature = "alloc")]
use alloc::vec::Vec;
use core::{ops::RangeInclusive, time::Duration};

mod generated;
Expand Down Expand Up @@ -149,6 +151,13 @@ impl<'a> TlsSession<'a> {
self.session.tls_exporter(label, context, output)
}

// Currently intended only for unstable usage
#[doc(hidden)]
#[cfg(feature = "alloc")]
pub fn peer_cert_chain_der(&self) -> Result<Vec<Vec<u8>>, crate::crypto::tls::ChainError> {
self.session.peer_cert_chain_der()
}

pub fn cipher_suite(&self) -> crate::event::api::CipherSuite {
self.session.cipher_suite().into_event()
}
Expand Down
11 changes: 11 additions & 0 deletions quic/s2n-quic-rustls/src/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,17 @@ impl tls::TlsSession for Session {
CipherSuite::Unknown
}
}

fn peer_cert_chain_der(&self) -> Result<Vec<Vec<u8>>, tls::ChainError> {
let err = tls::ChainError::failure();
Ok(self
.connection
.peer_certificates()
.ok_or(err)?
.iter()
.map(|v| v.to_vec())
.collect())
}
}

impl fmt::Debug for Session {
Expand Down
10 changes: 10 additions & 0 deletions quic/s2n-quic-tls/src/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,16 @@ impl tls::TlsSession for Session {
fn cipher_suite(&self) -> CipherSuite {
self.state.cipher_suite()
}

fn peer_cert_chain_der(&self) -> Result<Vec<Vec<u8>>, tls::ChainError> {
self.connection
.peer_cert_chain()
.map_err(|_| tls::ChainError::failure())?
.iter()
.map(|v| Ok(v?.der()?.to_vec()))
.collect::<Result<Vec<Vec<u8>>, s2n_tls::error::Error>>()
.map_err(|_| tls::ChainError::failure())
}
}

impl tls::Session for Session {
Expand Down
2 changes: 2 additions & 0 deletions quic/s2n-quic/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ mod skip_packets;
// build options than s2n-tls. We should build the rustls provider with
// mTLS enabled and remove the `cfg(target_os("windows"))`.
#[cfg(not(target_os = "windows"))]
mod chain;
#[cfg(not(target_os = "windows"))]
mod client_handshake_confirm;
#[cfg(not(target_os = "windows"))]
mod dc;
Expand Down
128 changes: 128 additions & 0 deletions quic/s2n-quic/src/tests/chain.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0

//! This module shows an example of an event provider that accesses certificate chains
//! from QUIC connections on both client and server.

use super::*;
use crate::provider::event::events::{self, ConnectionInfo, ConnectionMeta, Subscriber};

struct Chain;

#[derive(Default)]
struct ChainContext {
chain: Option<Vec<Vec<u8>>>,
sender: Option<tokio::sync::mpsc::Sender<Vec<Vec<u8>>>>,
}

impl Subscriber for Chain {
type ConnectionContext = ChainContext;

#[inline]
fn create_connection_context(
&mut self,
_: &ConnectionMeta,
_info: &ConnectionInfo,
) -> Self::ConnectionContext {
ChainContext::default()
}

fn on_tls_exporter_ready(
&mut self,
context: &mut Self::ConnectionContext,
_meta: &ConnectionMeta,
event: &events::TlsExporterReady,
) {
if let Some(sender) = context.sender.take() {
sender
.blocking_send(event.session.peer_cert_chain_der().unwrap())
.unwrap();
} else {
context.chain = Some(event.session.peer_cert_chain_der().unwrap());
}
}
}

fn start_server(
mut server: Server,
server_chain: tokio::sync::mpsc::Sender<Vec<Vec<u8>>>,
) -> crate::provider::io::testing::Result<SocketAddr> {
let server_addr = server.local_addr()?;

// accept connections and echo back
spawn(async move {
while let Some(mut connection) = server.accept().await {
let chain = connection
.query_event_context_mut(|ctx: &mut ChainContext| {
if let Some(chain) = ctx.chain.take() {
Some(chain)
} else {
ctx.sender = Some(server_chain.clone());
None
}
})
.unwrap();
if let Some(chain) = chain {
server_chain.send(chain).await.unwrap();
}
}
});

Ok(server_addr)
}

fn tls_test<C>(f: fn(crate::Connection, Vec<Vec<u8>>) -> C)
where
C: 'static + core::future::Future<Output = ()> + Send,
{
let model = Model::default();
model.set_delay(Duration::from_millis(50));

test(model, |handle| {
let server = Server::builder()
.with_io(handle.builder().build()?)?
.with_tls(build_server_mtls_provider(certificates::MTLS_CA_CERT)?)?
.with_event((Chain, tracing_events()))?
.start()?;
let (send, server_chain) = tokio::sync::mpsc::channel(1);
let server_chain = Arc::new(tokio::sync::Mutex::new(server_chain));

let addr = start_server(server, send)?;

let client = Client::builder()
.with_io(handle.builder().build().unwrap())?
.with_tls(build_client_mtls_provider(certificates::MTLS_CA_CERT)?)?
.with_event((Chain, tracing_events()))?
.start()?;

// show it working for several connections
for _ in 0..10 {
let client = client.clone();
let server_chain = server_chain.clone();
primary::spawn(async move {
let connect = Connect::new(addr).with_server_name("localhost");
let conn = client.connect(connect).await.unwrap();
delay(Duration::from_millis(100)).await;
let server_chain = server_chain.lock().await.recv().await.unwrap();
f(conn, server_chain).await;
});
}

Ok(addr)
})
.unwrap();
}

#[test]
fn happy_case() {
tls_test(|mut conn, server_chain| async move {
let client_chain = conn
.query_event_context_mut(|ctx: &mut ChainContext| ctx.chain.take().unwrap())
.unwrap();
// these are DER-encoded and we lack nice conversion functions, so just assert some simple
// properties.
assert!(server_chain.len() > 1);
assert!(client_chain.len() > 1);
assert_ne!(server_chain, client_chain);
});
}

0 comments on commit 284500e

Please sign in to comment.