Skip to content

Commit

Permalink
chore: add TLS tests
Browse files Browse the repository at this point in the history
  • Loading branch information
0xAlcibiades committed Sep 11, 2024
1 parent 386e93a commit 7333d38
Show file tree
Hide file tree
Showing 2 changed files with 322 additions and 0 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ tokio-rustls = "0.26.0"
tokio-stream = { version = "0.1.16", features = ["net"] }
tower = { version = "0.5.1", features = ["util"] }
tracing = "0.1.40"
tracing-subscriber = "0.3.18"

[dev-dependencies]
hyper = { version = "1.4.1", features = ["client"] }
Expand Down
321 changes: 321 additions & 0 deletions src/tls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,3 +102,324 @@ fn load_private_key(filename: &str) -> io::Result<PrivateKeyDer<'static>> {
rustls_pemfile::private_key(&mut reader)?
.ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "No private key found in file"))
}

Check warning on line 104 in src/tls.rs

View check run for this annotation

Codecov / codecov/patch

src/tls.rs#L97-L104

Added lines #L97 - L104 were not covered by tests

#[cfg(test)]
mod tests {
use super::*;
use crate::tcp::serve_tcp_incoming;
use futures::StreamExt;
use rustls::pki_types::{CertificateDer, PrivateKeyDer, ServerName};
use rustls::{ClientConfig, ServerConfig};
use std::net::SocketAddr;
use std::pin::Pin;
use std::sync::Arc;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::{TcpListener, TcpStream};
use tokio_rustls::TlsAcceptor;
use tokio_stream::wrappers::TcpListenerStream;
use tracing::{debug, error, info, warn};

// Helper function to create a TLS acceptor for testing
async fn create_test_tls_acceptor() -> io::Result<TlsAcceptor> {
debug!("Creating test TLS acceptor");
let certs = load_certs("examples/sample.pem")?;
let key = load_private_key("examples/sample.rsa")?;

let config = ServerConfig::builder()
.with_no_client_auth()
.with_single_cert(certs, key)
.map_err(|e| {
error!("Failed to create ServerConfig: {}", e);
io::Error::new(io::ErrorKind::Other, e)

Check warning on line 133 in src/tls.rs

View check run for this annotation

Codecov / codecov/patch

src/tls.rs#L132-L133

Added lines #L132 - L133 were not covered by tests
})?;

Ok(TlsAcceptor::from(Arc::new(config)))
}

#[tokio::test]
async fn test_tls_incoming_success() -> Result<(), Box<dyn std::error::Error>> {
let _guard = tracing_subscriber::fmt()
.with_max_level(tracing::Level::DEBUG)
.try_init();

info!("Starting test_tls_incoming_success");
let addr = SocketAddr::from(([127, 0, 0, 1], 0));
let listener = TcpListener::bind(addr).await?;
let server_addr = listener.local_addr()?;
debug!("Server listening on {}", server_addr);
let incoming = TcpListenerStream::new(listener);

let tls_acceptor = create_test_tls_acceptor().await?;

// Use serve_tcp_incoming to handle TCP connections
let tcp_incoming = serve_tcp_incoming(incoming);

// Spawn the server task
let server_task = tokio::spawn(async move {
debug!("Server task started");
let mut tls_stream = Box::pin(tls_incoming(tcp_incoming, tls_acceptor));
let result = tls_stream.next().await;
debug!("Server received connection: {:?}", result.is_some());
result
});

// Connect to the server with a TLS client
let mut root_store = rustls::RootCertStore::empty();
let certs = load_certs("examples/sample.pem")?;
root_store.add_parsable_certificates(certs);

let client_config = ClientConfig::builder()
.with_root_certificates(root_store)
.with_no_client_auth();

let connector = tokio_rustls::TlsConnector::from(Arc::new(client_config));

debug!("Client connecting to {}", server_addr);
let tcp_stream = TcpStream::connect(server_addr).await?;
let domain = ServerName::try_from("localhost")?;
let _client_stream = connector.connect(domain, tcp_stream).await?;
debug!("Client connected successfully");

// Wait for the server to accept the connection
let result = server_task
.await?
.ok_or("Server task completed without result")?;
match result {
Ok(_) => info!("TLS connection established successfully"),
Err(ref e) => error!("TLS connection failed: {}", e),
}
assert!(result.is_ok());

Ok(())
}

#[tokio::test]
async fn test_tls_incoming_invalid_cert() -> Result<(), Box<dyn std::error::Error>> {
let _guard = tracing_subscriber::fmt()
.with_max_level(tracing::Level::DEBUG)
.try_init();

info!("Starting test_tls_incoming_invalid_cert");
let addr = SocketAddr::from(([127, 0, 0, 1], 0));
let listener = TcpListener::bind(addr).await?;
let server_addr = listener.local_addr()?;
debug!("Server listening on {}", server_addr);
let incoming = TcpListenerStream::new(listener);

// Create a TLS acceptor with an invalid certificate
let invalid_cert = vec![CertificateDer::from(vec![0; 32])]; // Invalid certificate
let key = load_private_key("examples/sample.rsa")?;

// Expect this to fail and log the error
let config_result = ServerConfig::builder()
.with_no_client_auth()
.with_single_cert(invalid_cert, key);

match config_result {
Ok(_) => warn!("ServerConfig creation unexpectedly succeeded with invalid cert"),
Err(e) => info!("ServerConfig creation failed as expected: {}", e),
}

// Use a valid certificate for the server to allow the test to continue
let valid_certs = load_certs("examples/sample.pem")?;
let valid_key = load_private_key("examples/sample.rsa")?;
let config = ServerConfig::builder()
.with_no_client_auth()
.with_single_cert(valid_certs, valid_key)
.expect("ServerConfig creation should succeed with valid cert");

let tls_acceptor = TlsAcceptor::from(Arc::new(config));

// Use serve_tcp_incoming to handle TCP connections
let tcp_incoming = serve_tcp_incoming(incoming);

// Spawn the server task
let server_task = tokio::spawn(async move {
debug!("Server task started");
let mut tls_stream = Box::pin(tls_incoming(tcp_incoming, tls_acceptor));
let result = tls_stream.next().await;
debug!("Server received connection: {:?}", result.is_some());
result
});

// Connect to the server with a TLS client that doesn't trust the server's certificate
let connector = tokio_rustls::TlsConnector::from(Arc::new(
ClientConfig::builder()
.with_root_certificates(rustls::RootCertStore::empty())
.with_no_client_auth(),
));

debug!("Client connecting to {}", server_addr);
let tcp_stream = TcpStream::connect(server_addr).await?;
let domain = ServerName::try_from("localhost")?;

// This connection should fail due to certificate verification
let client_result = connector.connect(domain, tcp_stream).await;
match &client_result {
Ok(_) => warn!("Client connection succeeded unexpectedly"),
Err(e) => info!("Client connection failed as expected: {}", e),
}
assert!(client_result.is_err());

// The server should not encounter an error, but the connection should not be established
let server_result = server_task
.await?
.ok_or("Server task completed without result")?;
match &server_result {
Ok(_) => warn!("Server accepted connection unexpectedly"),
Err(e) => info!("Server did not establish connection as expected: {}", e),
}
assert!(server_result.is_err());

Ok(())
}

#[tokio::test]
async fn test_tls_incoming_client_hello_timeout() -> Result<(), Box<dyn std::error::Error>> {
let _guard = tracing_subscriber::fmt()
.with_max_level(tracing::Level::DEBUG)
.try_init();

info!("Starting test_tls_incoming_client_hello_timeout");
let addr = SocketAddr::from(([127, 0, 0, 1], 0));
let listener = TcpListener::bind(addr).await?;
let server_addr = listener.local_addr()?;
debug!("Server listening on {}", server_addr);
let incoming = TcpListenerStream::new(listener);

let tls_acceptor = create_test_tls_acceptor().await?;

// Use serve_tcp_incoming to handle TCP connections
let tcp_incoming = serve_tcp_incoming(incoming);

// Spawn the server task
let server_task = tokio::spawn(async move {
debug!("Server task started");
let mut tls_stream = Box::pin(tls_incoming(tcp_incoming, tls_acceptor));
let result =
tokio::time::timeout(std::time::Duration::from_secs(1), tls_stream.next()).await;
debug!("Server task completed with result: {:?}", result.is_err());
result
});

// Connect with a regular TCP client (no TLS handshake)
debug!("Client connecting with plain TCP to {}", server_addr);
let _tcp_stream = TcpStream::connect(server_addr).await?;
debug!("Client connected with plain TCP");

// The server task should timeout
let result = server_task.await?;
match result {
Ok(_) => warn!("Server did not timeout as expected"),
Err(ref e) => info!("Server timed out as expected: {}", e),
}
assert!(result.is_err()); // Timeout error

Ok(())
}

#[tokio::test]
async fn test_load_certs() -> io::Result<()> {
let _guard = tracing_subscriber::fmt()
.with_max_level(tracing::Level::DEBUG)
.try_init();

info!("Starting test_load_certs");
let certs = load_certs("examples/sample.pem")?;
debug!("Loaded {} certificates", certs.len());
assert!(!certs.is_empty(), "Certificate file should not be empty");
Ok(())
}

#[tokio::test]
async fn test_load_private_key() -> io::Result<()> {
let _guard = tracing_subscriber::fmt()
.with_max_level(tracing::Level::DEBUG)
.try_init();

info!("Starting test_load_private_key");
let key = load_private_key("examples/sample.rsa")?;
debug!("Loaded private key, length: {}", key.secret_der().len());
assert!(
!key.secret_der().is_empty(),
"Private key should not be empty"
);
Ok(())
}

// Simulating the tls_incoming function for testing purposes
// Replace this with your actual implementation
fn tls_incoming<IO>(
incoming: impl Stream<Item = Result<IO, Error>> + Send + 'static,
tls_acceptor: TlsAcceptor,
) -> impl Stream<Item = Result<tokio_rustls::server::TlsStream<IO>, Error>> + Send + 'static
where
IO: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
Box::pin(incoming.then(move |result| {
let tls_acceptor = tls_acceptor.clone();
async move {
match result {
Ok(io) => {
debug!("Accepting TLS connection");
let accept_result = tls_acceptor.accept(io).await.map_err(Error::from);
match &accept_result {
Ok(_) => debug!("TLS connection accepted successfully"),
Err(e) => warn!("Failed to accept TLS connection: {}", e),
}
accept_result
}
Err(e) => {
warn!("Error in incoming connection: {}", e);
Err(e)

Check warning on line 374 in src/tls.rs

View check run for this annotation

Codecov / codecov/patch

src/tls.rs#L372-L374

Added lines #L372 - L374 were not covered by tests
}
}
}
}))
}

// Helper function to load certificates
fn load_certs(filename: &str) -> io::Result<Vec<CertificateDer<'static>>> {
debug!("Loading certificates from {}", filename);
let certfile = std::fs::File::open(filename).map_err(|e| {
error!("Failed to open certificate file: {}", e);
io::Error::new(
io::ErrorKind::Other,
format!("failed to open {}: {}", filename, e),
)

Check warning on line 389 in src/tls.rs

View check run for this annotation

Codecov / codecov/patch

src/tls.rs#L385-L389

Added lines #L385 - L389 were not covered by tests
})?;
let mut reader = io::BufReader::new(certfile);
let certs = rustls_pemfile::certs(&mut reader)
.collect::<Result<Vec<_>, _>>()
.map_err(|e| {
error!("Failed to parse certificates: {}", e);
io::Error::new(io::ErrorKind::Other, e)

Check warning on line 396 in src/tls.rs

View check run for this annotation

Codecov / codecov/patch

src/tls.rs#L395-L396

Added lines #L395 - L396 were not covered by tests
})?;
debug!("Loaded {} certificates", certs.len());
Ok(certs)
}

// Helper function to load private key
fn load_private_key(filename: &str) -> io::Result<PrivateKeyDer<'static>> {
debug!("Loading private key from {}", filename);
let keyfile = std::fs::File::open(filename).map_err(|e| {
error!("Failed to open private key file: {}", e);
io::Error::new(
io::ErrorKind::Other,
format!("failed to open {}: {}", filename, e),
)

Check warning on line 410 in src/tls.rs

View check run for this annotation

Codecov / codecov/patch

src/tls.rs#L406-L410

Added lines #L406 - L410 were not covered by tests
})?;
let mut reader = io::BufReader::new(keyfile);
let key = rustls_pemfile::private_key(&mut reader)
.map_err(|e| {
error!("Failed to parse private key: {}", e);
io::Error::new(io::ErrorKind::Other, e)

Check warning on line 416 in src/tls.rs

View check run for this annotation

Codecov / codecov/patch

src/tls.rs#L415-L416

Added lines #L415 - L416 were not covered by tests
})?
.ok_or_else(|| {
error!("No private key found in file");
io::Error::new(io::ErrorKind::Other, "no private key found")

Check warning on line 420 in src/tls.rs

View check run for this annotation

Codecov / codecov/patch

src/tls.rs#L419-L420

Added lines #L419 - L420 were not covered by tests
})?;
debug!("Loaded private key");
Ok(key)
}
}

0 comments on commit 7333d38

Please sign in to comment.