diff --git a/examples/tonic.rs b/examples/tonic.rs index 5da3b68..c538f36 100644 --- a/examples/tonic.rs +++ b/examples/tonic.rs @@ -1,11 +1,6 @@ use std::net::SocketAddr; use std::sync::Arc; -use bytes::Bytes; -use http::Request; -use http_body_util::{Full}; -use hyper::body::Incoming; -use hyper::Response; use hyper_util::rt::TokioExecutor; use hyper_util::server::conn::auto::Builder as HttpConnectionBuilder; use hyper_util::service::TowerToHyperService; @@ -13,9 +8,7 @@ use rustls::ServerConfig; use tokio::net::TcpListener; use tokio_stream::wrappers::TcpListenerStream; use tonic::server::NamedService; -use tonic::service::AxumRouter; use tonic::transport::Server; -use tower::ServiceExt; use tracing::{info, Level}; use postel::{load_certs, load_private_key, serve_http_with_shutdown}; @@ -30,9 +23,7 @@ impl NamedService for GreeterService { #[tokio::main] async fn main() -> Result<(), Box> { // Initialize logging - tracing_subscriber::fmt() - .with_max_level(Level::INFO) - .init(); + tracing_subscriber::fmt().with_max_level(Level::INFO).init(); // Configure server address let addr = SocketAddr::from(([127, 0, 0, 1], 8443)); @@ -72,13 +63,10 @@ async fn main() -> Result<(), Box> { let server = tokio::spawn(async move { info!("Server starting up..."); - let svc = >>::map_response::<_, Response>>(Server::builder() + let svc = Server::builder() .add_service(health_service) .into_service() - .into_axum_router(), |res| { - // TODO the issue here is that streams are not sync - res.map(|_| Full::new(Bytes::from("placeholder"))) - }); + .into_axum_router(); let hyper_svc = TowerToHyperService::new(svc); @@ -92,8 +80,8 @@ async fn main() -> Result<(), Box> { info!("Shutdown signal received"); }), ) - .await - .expect("Server failed unexpectedly"); + .await + .expect("Server failed unexpectedly"); }); // Keep the main thread running until Ctrl+C @@ -105,4 +93,4 @@ async fn main() -> Result<(), Box> { server.await?; Ok(()) -} \ No newline at end of file +} diff --git a/src/error.rs b/src/error.rs index 2b73e0a..5cf0eb5 100644 --- a/src/error.rs +++ b/src/error.rs @@ -58,7 +58,7 @@ pub(crate) fn handle_accept_error(e: impl Into) -> ControlFlow; +type Source = crate::Error; /// Represents errors that originate from the server. /// This struct provides a public API for error handling. diff --git a/src/http.rs b/src/http.rs index 28213d7..1aa6cf8 100644 --- a/src/http.rs +++ b/src/http.rs @@ -1,26 +1,25 @@ -use crate::io::Transport; -use std::future::pending; -use std::{future::Future, pin::pin, sync::Arc}; -use tokio_rustls::TlsAcceptor; +use std::future::{pending, Future}; +use std::pin::pin; +use std::sync::Arc; +use std::time::Duration; use bytes::Bytes; +use futures::stream::{FuturesUnordered, StreamExt}; use http::{Request, Response}; use http_body::Body; use hyper::body::Incoming; +use hyper::rt::{Read, Write}; use hyper::service::Service; -use hyper_util::rt::TokioTimer; -use hyper_util::{ - rt::TokioIo, - server::conn::auto::{Builder as HttpConnectionBuilder, HttpServerConnExec}, -}; +use hyper_util::rt::{TokioIo, TokioTimer}; +use hyper_util::server::conn::auto::{Builder as HttpConnectionBuilder, HttpServerConnExec}; use tokio::io::{AsyncRead, AsyncWrite}; use tokio::time::sleep; -use tokio::time::Duration; +use tokio_rustls::TlsAcceptor; use tokio_stream::Stream; -use tokio_stream::StreamExt as _; use tracing::{debug, trace}; use crate::fuse::Fuse; +use crate::io::Transport; /// Sleeps for a specified duration or waits indefinitely. /// @@ -38,26 +37,51 @@ async fn sleep_or_pending(wait_for: Option) { }; } -/// Serves HTTP an HTTP connection on the transport from a hyper service backend. +/// Handles TLS connection acceptance with proper error handling +async fn accept_tls_connection( + io: IO, + tls_acceptor: Arc, +) -> Result, crate::Error> +where + IO: AsyncRead + AsyncWrite + Unpin + Send + 'static, +{ + // Perform TLS handshake in a blocking task to avoid impacting the runtime + // Because this is one of the most computationally heavy things the sever does. + // In the case of ECDSA and very fast handshakes, this has more downside + // than upside, but in the case of RSA and slow handshakes, this is a good idea. + // It amortizes out to about 2 µs of overhead per connection. + // and moves this computationally heavy task off the main thread pool. + match tokio::task::spawn_blocking(move || { + tokio::runtime::Handle::current().block_on(tls_acceptor.accept(io)) + }) + .await + { + Ok(Ok(stream)) => Ok(stream), + // This connection was malformed and the server was unable to handle it + Ok(Err(e)) => Err(e.into()), + Err(e) => Err(e.into()), + } +} + +/// Serves an HTTP connection, managing its lifecycle and handling requests. /// -/// This method handles an HTTP connection on a given transport `IO`, processing requests through -/// the provided service and managing the connection lifecycle. +/// This function takes a connection and processes HTTP requests using the provided service, +/// handling connection shutdown and cleanup appropriately. /// /// # Type Parameters /// -/// * `B`: The body type for the HTTP response. -/// * `IO`: The I/O type for the HTTP connection. -/// * `S`: The service type that processes HTTP requests. -/// * `E`: The executor type for the HTTP server connection. +/// * `B`: The body type for HTTP responses +/// * `IO`: The I/O type for the connection +/// * `S`: The service type that processes HTTP requests +/// * `E`: The executor type for the server connection /// /// # Arguments /// -/// * `hyper_io`: The I/O object representing the inbound hyper IO stream. -/// * `hyper_service`: The hyper `Service` implementation used to process HTTP requests. -/// * `builder`: A `Builder` used to create and serve the HTTP connection. -/// * `watcher`: An optional `tokio::sync::watch::Receiver` for graceful shutdown signaling. -/// * `max_connection_age`: An optional `Duration` specifying the maximum age of the connection -/// before initiating a graceful shutdown. +/// * `hyper_io`: The I/O object for the connection +/// * `hyper_service`: The service implementation for processing requests +/// * `builder`: Configuration builder for the connection +/// * `watcher`: Optional shutdown signal receiver +/// * `max_connection_age`: Optional maximum connection lifetime #[inline] pub async fn serve_http_connection( hyper_io: IO, @@ -66,34 +90,27 @@ pub async fn serve_http_connection( watcher: Option>, max_connection_age: Option, ) where - B: Body + Send + 'static, - B::Data: Send, - B::Error: Into> + Send + Sync, - IO: hyper::rt::Read + hyper::rt::Write + Unpin + Send + 'static, + B: Body + 'static, + B::Error: Into, + IO: Read + Write + Unpin + Send + 'static, S: Service, Response = Response> + Clone + Send + 'static, - S::Future: Send + 'static, - S::Error: Into> + Send, - E: HttpServerConnExec + Send + Sync + 'static, + S::Future: Send, + S::Error: Into, + E: HttpServerConnExec, { - // Set up a fused future for the watcher - let mut watcher = watcher.clone(); + // Set up shutdown signal monitoring + let mut watcher = watcher; let mut sig = pin!(Fuse { inner: watcher.as_mut().map(|w| w.changed()), }); - // Set up the sleep future for max connection age + // Configure connection lifetime monitoring let sleep = sleep_or_pending(max_connection_age); tokio::pin!(sleep); - // TODO(It's absolutely terrible that we have to clone the builder here) - // and configure it rather than passing it in. - // this is due to an API flaw in the hyper_util crate. - // this builder doesn't have a way to convert back to a builder - // once you start building. - let mut builder = builder.clone(); builder - // HTTP/1 settings + // HTTP/1 optimizations .http1() // Enable half-close for better connection handling .half_close(true) @@ -107,7 +124,7 @@ pub async fn serve_http_connection( .preserve_header_case(true) // Disable automatic title casing of headers to reduce processing overhead .title_case_headers(false) - // HTTP/2 settings + // HTTP/2 optimizations .http2() // Add the timer to the builder to avoid potential issues .timer(TokioTimer::new()) @@ -142,10 +159,8 @@ pub async fn serve_http_connection( // Here we wait for the http connection to terminate loop { tokio::select! { - // Handle the connection result - rv = &mut conn => { - if let Err(err) = rv { - // Log any errors that occur while serving the HTTP connection + result = &mut conn => { + if let Err(err) = result { debug!("failed serving HTTP connection: {:#}", err); } break; @@ -375,13 +390,6 @@ pub async fn serve_http_connection( /// Ok(()) /// } /// ``` -/// -/// # Notes -/// -/// - The server will continue to accept new connections until the `signal` future resolves. -/// - When using TLS, make sure to provide a properly configured `ServerConfig`. -/// - The function will return when all connections have been closed after the shutdown signal. -#[inline] pub async fn serve_http_with_shutdown( service: S, incoming: I, @@ -395,13 +403,13 @@ where IO: AsyncRead + AsyncWrite + Unpin + Send + 'static, IE: Into + Send + 'static, S: Service, Response = Response> + Clone + Send + 'static, - S::Future: Send + 'static, - S::Error: Into> + Send, - ResBody: Body + Send + Sync + 'static, - ResBody::Error: Into + Send + Sync, - E: HttpServerConnExec + Send + Sync + 'static, + S::Future: Send, + S::Error: Into, + ResBody: Body + Send + 'static, + ResBody::Error: Into + Send, + E: HttpServerConnExec + Send + 'static, { - // Create a channel for signaling graceful shutdown to listening connections + // Initialize shutdown signaling let (signal_tx, signal_rx) = tokio::sync::watch::channel(()); let signal_tx = Arc::new(signal_tx); @@ -411,16 +419,15 @@ where // The signal future that will resolve when the server should shut down let mut sig = pin!(Fuse { inner: signal }); - // Prepare the incoming stream of TCP connections - // from the provided stream of IO objects, which is coming - // most likely from a TCP stream. + // Configure TLS if enabled + let tls_acceptor = tls_config.map(|config| Arc::new(TlsAcceptor::from(config))); + + // Prepare connection handling let incoming = crate::tcp::serve_tcp_incoming(incoming); // Pin the incoming stream to the stack let mut incoming = pin!(incoming); - - // Create TLS acceptor if TLS config is provided - let tls_acceptor = tls_config.map(TlsAcceptor::from); + let mut active_connections = FuturesUnordered::new(); // Enter the main server loop loop { @@ -433,97 +440,50 @@ where trace!("signal received, shutting down"); break; }, - // Wait for the next IO result from the incoming stream - io = incoming.next() => { - // If we got an IO result from the incoming stream - // This effectively demultiplexes the incoming stream of IO objects, - // which each represent a connection which may then be individually - // streamed/handled. - // - // So this is effectively a demultiplexer for the incoming stream of IO objects. - // - // Because of the way the stream handling is implemented, - // the responses are multiplexed back over the same stream to the client. - // However, that would not be intuitive just from looking it this code - // because the reverse multiplexing is "invisible" to the reader. - let io = match io { - // We check if it's a valid stream - Some(Ok(io)) => io, - // or if it's a non-fatal error - Some(Err(e)) => { + Some(io_result) = incoming.next() => { + let io = match io_result { + Ok(io) => io, + Err(e) => { trace!("error accepting connection: {:#}", e); - // if it's a non-fatal error, we continue processing IO objects continue; - }, - None => { - // If we got a fatal error, meaning we lost connection or something else - // we break out of the loop - break - }, + } }; trace!("TCP streaming connection accepted"); - // For each of these TCP streams, we are going to want to - // spawn a new task to handle the connection. + let connection_service = service.clone(); + let connection_builder = builder.clone(); + let connection_signal_rx = graceful.then_some(signal_rx.clone()); + + let transport = if let Some(tls_acceptor) = &tls_acceptor { + match accept_tls_connection(io, Arc::clone(tls_acceptor)).await { + Ok(tls_stream) => Transport::new_tls(tls_stream), + Err(e) => { + // This connection failed to handshake + debug!("TLS handshake failed: {:#}", e); + continue; + } + } + } else { + Transport::new_plain(io) + }; - // Clone necessary values for the spawned task - let service = service.clone(); - let builder = builder.clone(); - let tls_acceptor = tls_acceptor.clone(); - let signal_rx = signal_rx.clone(); + // Convert our abstracted tokio transport into a hyper transport + let hyper_io = TokioIo::new(transport); - // Spawn a new task to handle this connection - tokio::spawn(async move { - // Abstract the transport layer for hyper - - let transport = if let Some(tls_acceptor) = &tls_acceptor { - // If TLS is enabled, then we perform a TLS handshake - // Clone the TLS acceptor and IO for use in the blocking task - let tls_acceptor = tls_acceptor.clone(); - let io = io; - - match tokio::task::spawn_blocking(move || { - // Perform the TLS handshake in a blocking task - // Because this is one of the most computationally heavy things the sever does. - // In the case of ECDSA and very fast handshakes, this has more downside - // than upside, but in the case of RSA and slow handshakes, this is a good idea. - // It amortizes out to about 2 µs of overhead per connection. - // and moves this computationally heavy task off the main thread pool. - tokio::runtime::Handle::current().block_on(tls_acceptor.accept(io)) - }).await { - // Handle the result of the TLS handshake - Ok(Ok(tls_stream)) => Transport::new_tls(tls_stream), - Ok(Err(e)) => { - // This connection failed to handshake - debug!("TLS handshake failed: {:#}", e); - return; - }, - Err(e) => { - // This connection was malformed and the server was unable to handle it - debug!("TLS handshake task panicked: {:#}", e); - return; - } - - } - } - else { - // If TLS is not enabled, then we use a plain transport - Transport::new_plain(io) - }; - - // Convert our abstracted tokio transport into a hyper transport - let hyper_io = TokioIo::new(transport); - - // Serve the HTTP connections on this transport - serve_http_connection( - hyper_io, - service, - builder, - graceful.then_some(signal_rx), - None - ).await; - }); + // Create future for serving the connection + let conn_future = serve_http_connection( + hyper_io, + connection_service, + connection_builder, + connection_signal_rx, + None + ); + + active_connections.push(conn_future); + }, + Some(_) = active_connections.next(), if !active_connections.is_empty() => { + trace!("Connection completed, {} active", active_connections.len()); } } } @@ -539,8 +499,11 @@ where signal_tx.receiver_count() ); - // Wait for all connections to close - // TODO(Add a timeout here, optionally) + while !active_connections.is_empty() { + if let Some(_) = active_connections.next().await { + trace!("Connection closed during shutdown"); + } + } signal_tx.closed().await; } @@ -550,21 +513,20 @@ where #[cfg(test)] mod tests { use super::*; - use crate::{load_certs, load_private_key}; use bytes::Bytes; - use http_body_util::{BodyExt, Empty, Full}; - use hyper::{body::Incoming, Request, Response, StatusCode}; + use http::StatusCode; + use http_body_util::{BodyExt, Full}; + use hyper::{Request, Response}; use hyper_util::rt::TokioExecutor; - use hyper_util::service::TowerToHyperService; use rustls::ServerConfig; use std::net::SocketAddr; - use std::sync::Arc; use std::time::Duration; use tokio::net::{TcpListener, TcpStream}; use tokio::sync::oneshot; use tokio_stream::wrappers::TcpListenerStream; - async fn echo(req: Request) -> Result>, hyper::Error> { + // Common test handler used by both HTTP and HTTPS tests + async fn test_handler(req: Request) -> Result>, hyper::Error> { match (req.method(), req.uri().path()) { (&hyper::Method::GET, "/") => { Ok(Response::new(Full::new(Bytes::from("Hello, World!")))) @@ -573,6 +535,14 @@ mod tests { let body = req.collect().await?.to_bytes(); Ok(Response::new(Full::new(body))) } + (&hyper::Method::GET, "/delay") => { + tokio::time::sleep(Duration::from_millis(100)).await; + Ok(Response::new(Full::new(Bytes::from("Delayed response")))) + } + (&hyper::Method::GET, "/large") => { + let large_data = vec![b'x'; 1024 * 1024]; // 1MB response + Ok(Response::new(Full::new(Bytes::from(large_data)))) + } _ => { let mut res = Response::new(Full::new(Bytes::from("Not Found"))); *res.status_mut() = StatusCode::NOT_FOUND; @@ -581,128 +551,99 @@ mod tests { } } - async fn setup_test_server(addr: SocketAddr) -> (TcpListenerStream, SocketAddr) { + // Helper for setting up test server + async fn setup_test_server( + // TODO this is not passed through in any meaningful way yet + _max_conn_age: Option + ) -> (SocketAddr, oneshot::Sender<()>) { + let addr = SocketAddr::from(([127, 0, 0, 1], 0)); let listener = TcpListener::bind(addr).await.unwrap(); let server_addr = listener.local_addr().unwrap(); let incoming = TcpListenerStream::new(listener); - (incoming, server_addr) - } - async fn create_test_tls_config() -> Arc { - let certs = load_certs("examples/sample.pem").unwrap(); - let key = load_private_key("examples/sample.rsa").unwrap(); - let config = ServerConfig::builder() - .with_no_client_auth() - .with_single_cert(certs, key) - .unwrap(); - Arc::new(config) - } + let (shutdown_tx, shutdown_rx) = oneshot::channel(); - async fn send_request( - addr: SocketAddr, - req: Request>, - ) -> Result, Box> { - let stream = TcpStream::connect(addr).await?; - let io = TokioIo::new(stream); - - let (mut sender, conn) = hyper::client::conn::http1::handshake(io).await?; - tokio::spawn(async move { - if let Err(err) = conn.await { - eprintln!("Connection failed: {:?}", err); - } - }); + let builder = HttpConnectionBuilder::new(TokioExecutor::new()); + let service = hyper::service::service_fn(test_handler); - Ok(sender.send_request(req).await?) - } + tokio::spawn(serve_http_with_shutdown( + service, + incoming, + builder, + None, + Some(async { shutdown_rx.await.ok(); }), + )); - // HTTP Tests + (server_addr, shutdown_tx) + } - mod http_tests { + mod payload_tests { use super::*; #[tokio::test] - async fn test_http_basic_requests() { - let addr = SocketAddr::from(([127, 0, 0, 1], 0)); - let (incoming, server_addr) = setup_test_server(addr).await; - - let (shutdown_tx, shutdown_rx) = oneshot::channel(); + async fn test_large_payload() { + let (addr, shutdown_tx) = setup_test_server(None).await; + let stream = TcpStream::connect(addr).await.unwrap(); + let io = TokioIo::new(stream); + let (mut sender, conn) = hyper::client::conn::http1::handshake(io).await.unwrap(); - let http_server_builder = HttpConnectionBuilder::new(TokioExecutor::new()); - let tower_service_fn = tower::service_fn(echo); - let hyper_service = TowerToHyperService::new(tower_service_fn); - - let server = tokio::spawn(serve_http_with_shutdown( - hyper_service, - incoming, - http_server_builder, - None, - Some(async { - shutdown_rx.await.ok(); - }), - )); + tokio::spawn(async move { + if let Err(err) = conn.await { + eprintln!("Connection failed: {:?}", err); + } + }); - // Test GET request + // Test large response let req = Request::builder() - .uri("/") - .body(Empty::::new()) + .uri("/large") + .body(Full::new(Bytes::new())) .unwrap(); - let res = send_request(server_addr, req).await.unwrap(); + let res = sender.send_request(req).await.unwrap(); assert_eq!(res.status(), StatusCode::OK); let body = res.collect().await.unwrap().to_bytes(); - assert_eq!(&body[..], b"Hello, World!"); + assert_eq!(body.len(), 1024 * 1024); - // Test POST request + // Test large request + let large_data = vec![b'x'; 1024 * 1024]; let req = Request::builder() .method(hyper::Method::POST) .uri("/echo") - .body(Empty::::new()) + .body(Full::new(Bytes::from(large_data.clone()))) .unwrap(); - let res = send_request(server_addr, req).await.unwrap(); + let res = sender.send_request(req).await.unwrap(); assert_eq!(res.status(), StatusCode::OK); - - // Test 404 response - let req = Request::builder() - .uri("/not_found") - .body(Empty::::new()) - .unwrap(); - let res = send_request(server_addr, req).await.unwrap(); - assert_eq!(res.status(), StatusCode::NOT_FOUND); + let body = res.collect().await.unwrap().to_bytes(); + assert_eq!(body.len(), large_data.len()); shutdown_tx.send(()).unwrap(); - server.await.unwrap().unwrap(); } #[tokio::test] - async fn test_http_concurrent_requests() { - let addr = SocketAddr::from(([127, 0, 0, 1], 0)); - let (incoming, server_addr) = setup_test_server(addr).await; + async fn test_concurrent_large_payloads() { + let (addr, shutdown_tx) = setup_test_server(None).await; + let mut handles = Vec::new(); - let (shutdown_tx, shutdown_rx) = oneshot::channel(); - - let http_server_builder = HttpConnectionBuilder::new(TokioExecutor::new()); - let tower_service_fn = tower::service_fn(echo); - let hyper_service = TowerToHyperService::new(tower_service_fn); + for _ in 0..3 { + let addr = addr; + let handle = tokio::spawn(async move { + let stream = TcpStream::connect(addr).await.unwrap(); + let io = TokioIo::new(stream); + let (mut sender, conn) = hyper::client::conn::http1::handshake(io).await.unwrap(); - let server = tokio::spawn(serve_http_with_shutdown( - hyper_service, - incoming, - http_server_builder, - None, - Some(async { - shutdown_rx.await.ok(); - }), - )); + tokio::spawn(async move { + if let Err(err) = conn.await { + eprintln!("Connection failed: {:?}", err); + } + }); - let mut handles = vec![]; - for _ in 0..10 { - let addr = server_addr; - let handle = tokio::spawn(async move { let req = Request::builder() - .uri("/") - .body(Empty::::new()) + .uri("/large") + .body(Full::new(Bytes::new())) .unwrap(); - let res = send_request(addr, req).await.unwrap(); + let res = sender.send_request(req).await.unwrap(); assert_eq!(res.status(), StatusCode::OK); + let body = res.collect().await.unwrap(); + assert_eq!(body.to_bytes().len(), 1024 * 1024); }); handles.push(handle); } @@ -712,133 +653,126 @@ mod tests { } shutdown_tx.send(()).unwrap(); - server.await.unwrap().unwrap(); } + } + + mod shutdown_tests { + use super::*; #[tokio::test] - async fn test_http_graceful_shutdown() { - let addr = SocketAddr::from(([127, 0, 0, 1], 0)); - let (incoming, server_addr) = setup_test_server(addr).await; + async fn test_graceful_shutdown_with_active_requests() { + let (addr, shutdown_tx) = setup_test_server(None).await; - let (shutdown_tx, shutdown_rx) = oneshot::channel(); + // Start a slow request + let slow_req = tokio::spawn(async move { + let stream = TcpStream::connect(addr).await.unwrap(); + let io = TokioIo::new(stream); + let (mut sender, conn) = hyper::client::conn::http1::handshake(io).await.unwrap(); - let http_server_builder = HttpConnectionBuilder::new(TokioExecutor::new()); - let tower_service_fn = tower::service_fn(echo); - let hyper_service = TowerToHyperService::new(tower_service_fn); + tokio::spawn(async move { + if let Err(err) = conn.await { + eprintln!("Connection failed: {:?}", err); + } + }); - let server = tokio::spawn(serve_http_with_shutdown( - hyper_service, - incoming, - http_server_builder, - None, - Some(async { - shutdown_rx.await.ok(); - }), - )); + let req = Request::builder() + .uri("/delay") + .body(Full::new(Bytes::new())) + .unwrap(); + sender.send_request(req).await + }); - // Send a request before shutdown - let req = Request::builder() - .uri("/") - .body(Empty::::new()) - .unwrap(); - let res = send_request(server_addr, req) - .await - .expect("Failed to send initial request"); + // Wait a bit then initiate shutdown + tokio::time::sleep(Duration::from_millis(50)).await; + shutdown_tx.send(()).unwrap(); + + // The slow request should complete successfully + let res = slow_req.await.unwrap().unwrap(); assert_eq!(res.status(), StatusCode::OK); + let body = res.collect().await.unwrap().to_bytes(); + assert_eq!(&body[..], b"Delayed response"); + } - // Initiate graceful shutdown - shutdown_tx.send(()).unwrap(); + #[tokio::test] + async fn test_shutdown_rejects_new_connections() { + let (addr, shutdown_tx) = setup_test_server(None).await; - // Wait for the server to shut down - let shutdown_timeout = Duration::from_millis(150); - let shutdown_result = tokio::time::timeout(shutdown_timeout, async { - loop { - tokio::time::sleep(Duration::from_millis(10)).await; - let req = Request::builder() - .uri("/") - .body(Empty::::new()) - .unwrap(); - match send_request(server_addr, req).await { - Ok(_) => continue, // Server still accepting connections - Err(e) if e.to_string().contains("Connection refused") => { - // Server has shut down as expected - return Ok(()); - } - Err(e) => return Err(e), // Unexpected error - } - } - }) - .await; + // Send shutdown signal + shutdown_tx.send(()).unwrap(); - match shutdown_result { - Ok(Ok(())) => println!("Server shut down successfully"), - Ok(Err(e)) => panic!("Unexpected error during shutdown: {}", e), - Err(_) => panic!("Timeout waiting for server to shut down"), - } + // Wait a bit for shutdown to process + tokio::time::sleep(Duration::from_millis(50)).await; - // Ensure the server task completes - server.await.unwrap().unwrap(); + // Attempt to connect should fail + let result = TcpStream::connect(addr).await; + assert!(result.is_err()); } } - // HTTPS Tests - mod https_tests { use super::*; + use crate::{load_certs, load_private_key}; use crate::test::helper::RUSTLS; use once_cell::sync::Lazy; - async fn create_https_client() -> ( + async fn setup_test_tls_config() -> Arc { + let certs = load_certs("examples/sample.pem").unwrap(); + let key = load_private_key("examples/sample.rsa").unwrap(); + let config = ServerConfig::builder() + .with_no_client_auth() + .with_single_cert(certs, key) + .unwrap(); + Arc::new(config) + } + + async fn setup_test_client() -> ( tokio_rustls::TlsConnector, rustls::pki_types::ServerName<'static>, ) { - let mut root_cert_store = rustls::RootCertStore::empty(); - root_cert_store.add_parsable_certificates(load_certs("examples/sample.pem").unwrap()); + let mut root_store = rustls::RootCertStore::empty(); + root_store.add_parsable_certificates(load_certs("examples/sample.pem").unwrap()); let client_config = rustls::ClientConfig::builder() - .with_root_certificates(root_cert_store) + .with_root_certificates(root_store) .with_no_client_auth(); - let tls_connector = tokio_rustls::TlsConnector::from(Arc::new(client_config)); - let domain = rustls::pki_types::ServerName::try_from("localhost") - .expect("Failed to create ServerName"); + let connector = tokio_rustls::TlsConnector::from(Arc::new(client_config)); + let domain = rustls::pki_types::ServerName::try_from("localhost").unwrap(); - (tls_connector, domain) + (connector, domain) } - #[tokio::test] - async fn test_https_connection() { - Lazy::force(&RUSTLS); - + async fn setup_tls_test_server() -> (SocketAddr, oneshot::Sender<()>, Arc) { let addr = SocketAddr::from(([127, 0, 0, 1], 0)); - let (incoming, server_addr) = setup_test_server(addr).await; - - let tls_config = create_test_tls_config().await; + let listener = TcpListener::bind(addr).await.unwrap(); + let server_addr = listener.local_addr().unwrap(); + let incoming = TcpListenerStream::new(listener); let (shutdown_tx, shutdown_rx) = oneshot::channel(); + let tls_config = setup_test_tls_config().await; + let builder = HttpConnectionBuilder::new(TokioExecutor::new()); + let service = hyper::service::service_fn(test_handler); - let http_server_builder = HttpConnectionBuilder::new(TokioExecutor::new()); - let tower_service_fn = tower::service_fn(echo); - let hyper_service = TowerToHyperService::new(tower_service_fn); - - let server = tokio::spawn(serve_http_with_shutdown( - hyper_service, + tokio::spawn(serve_http_with_shutdown( + service, incoming, - http_server_builder, - Some(tls_config), - Some(async { - shutdown_rx.await.ok(); - }), + builder, + Some(tls_config.clone()), + Some(async { shutdown_rx.await.ok(); }), )); - let (tls_connector, domain) = create_https_client().await; + (server_addr, shutdown_tx, tls_config) + } - let tcp_stream = TcpStream::connect(server_addr).await.unwrap(); - let tls_stream = tls_connector.connect(domain, tcp_stream).await.unwrap(); + async fn connect_tls_client( + addr: SocketAddr, + connector: tokio_rustls::TlsConnector, + domain: rustls::pki_types::ServerName<'static>, + ) -> hyper::client::conn::http1::SendRequest> { + let tcp = TcpStream::connect(addr).await.unwrap(); + let tls_stream = connector.connect(domain, tcp).await.unwrap(); + let io = TokioIo::new(tls_stream); - let (mut sender, conn) = - hyper::client::conn::http1::handshake(TokioIo::new(tls_stream)) - .await - .unwrap(); + let (sender, conn) = hyper::client::conn::http1::handshake(io).await.unwrap(); tokio::spawn(async move { if let Err(err) = conn.await { @@ -846,132 +780,202 @@ mod tests { } }); - let req = Request::builder() - .uri("/") - .body(Empty::::new()) - .unwrap(); - - let res = sender.send_request(req).await.unwrap(); - assert_eq!(res.status(), StatusCode::OK); + sender + } - let body = res.collect().await.unwrap().to_bytes(); - assert_eq!(&body[..], b"Hello, World!"); + mod tls_connection_tests { + use super::*; - shutdown_tx.send(()).unwrap(); - server.await.unwrap().unwrap(); - } + #[tokio::test] + async fn test_tls_basic_request() { + Lazy::force(&RUSTLS); + let (addr, shutdown_tx, _) = setup_tls_test_server().await; + let (connector, domain) = setup_test_client().await; + let mut sender = connect_tls_client(addr, connector, domain).await; - #[tokio::test] - async fn test_https_invalid_client_cert() { - Lazy::force(&RUSTLS); + let req = Request::builder() + .uri("/") + .body(Full::new(Bytes::new())) + .unwrap(); - let addr = SocketAddr::from(([127, 0, 0, 1], 0)); - let (incoming, server_addr) = setup_test_server(addr).await; + let res = sender.send_request(req).await.unwrap(); + assert_eq!(res.status(), StatusCode::OK); + let body = res.collect().await.unwrap().to_bytes(); + assert_eq!(&body[..], b"Hello, World!"); - let tls_config = create_test_tls_config().await; - let (shutdown_tx, shutdown_rx) = oneshot::channel(); + shutdown_tx.send(()).unwrap(); + } - let http_server_builder = HttpConnectionBuilder::new(TokioExecutor::new()); - let tower_service_fn = tower::service_fn(echo); - let hyper_service = TowerToHyperService::new(tower_service_fn); + #[tokio::test] + async fn test_tls_multiple_requests_same_connection() { + Lazy::force(&RUSTLS); + let (addr, shutdown_tx, _) = setup_tls_test_server().await; + let (connector, domain) = setup_test_client().await; + let mut sender = connect_tls_client(addr, connector, domain).await; - let server = tokio::spawn(serve_http_with_shutdown( - hyper_service, - incoming, - http_server_builder, - Some(tls_config), - Some(async { - shutdown_rx.await.ok(); - }), - )); + // Send multiple requests on the same connection + for _ in 0..3 { + let req = Request::builder() + .uri("/") + .body(Full::new(Bytes::new())) + .unwrap(); - let client_config = rustls::ClientConfig::builder() - .with_root_certificates(rustls::RootCertStore::empty()) - .with_no_client_auth(); + let res = sender.send_request(req).await.unwrap(); + assert_eq!(res.status(), StatusCode::OK); + let body = res.collect().await.unwrap().to_bytes(); + assert_eq!(&body[..], b"Hello, World!"); + } - let tls_connector = tokio_rustls::TlsConnector::from(Arc::new(client_config)); + shutdown_tx.send(()).unwrap(); + } - let tcp_stream = TcpStream::connect(server_addr).await.unwrap(); - let domain = rustls::pki_types::ServerName::try_from("localhost").unwrap(); + #[tokio::test] + async fn test_tls_concurrent_connections() { + Lazy::force(&RUSTLS); + let (addr, shutdown_tx, _) = setup_tls_test_server().await; + let mut handles = Vec::new(); + + // Create multiple concurrent TLS connections + for _ in 0..5 { + let addr = addr; + let handle = tokio::spawn(async move { + let (connector, domain) = setup_test_client().await; + let mut sender = connect_tls_client(addr, connector, domain).await; + + let req = Request::builder() + .uri("/") + .body(Full::new(Bytes::new())) + .unwrap(); + + let res = sender.send_request(req).await.unwrap(); + assert_eq!(res.status(), StatusCode::OK); + let body = res.collect().await.unwrap().to_bytes(); + assert_eq!(&body[..], b"Hello, World!"); + }); + handles.push(handle); + } - let result = tls_connector.connect(domain, tcp_stream).await; - assert!( - result.is_err(), - "Expected TLS connection to fail due to invalid client certificate" - ); + for handle in handles { + handle.await.unwrap(); + } - shutdown_tx.send(()).unwrap(); - server.await.unwrap().unwrap(); + shutdown_tx.send(()).unwrap(); + } } - #[tokio::test] - async fn test_https_graceful_shutdown() { - Lazy::force(&RUSTLS); - let addr = SocketAddr::from(([127, 0, 0, 1], 0)); - let (incoming, server_addr) = setup_test_server(addr).await; + mod tls_error_tests { + use super::*; - let tls_config = create_test_tls_config().await; - let (shutdown_tx, shutdown_rx) = oneshot::channel(); + #[tokio::test] + async fn test_invalid_client_cert() { + Lazy::force(&RUSTLS); + let (addr, shutdown_tx, _) = setup_tls_test_server().await; - let http_server_builder = HttpConnectionBuilder::new(TokioExecutor::new()); - let tower_service_fn = tower::service_fn(echo); - let hyper_service = TowerToHyperService::new(tower_service_fn); + // Create a client with empty root store (won't trust server cert) + let client_config = rustls::ClientConfig::builder() + .with_root_certificates(rustls::RootCertStore::empty()) + .with_no_client_auth(); - let server = tokio::spawn(serve_http_with_shutdown( - hyper_service, - incoming, - http_server_builder, - Some(tls_config), - Some(async { - shutdown_rx.await.ok(); - }), - )); + let connector = tokio_rustls::TlsConnector::from(Arc::new(client_config)); + let domain = rustls::pki_types::ServerName::try_from("localhost").unwrap(); + + let tcp = TcpStream::connect(addr).await.unwrap(); + let result = connector.connect(domain, tcp).await; + + // Should fail due to untrusted certificate + assert!(result.is_err()); + + shutdown_tx.send(()).unwrap(); + } - let (tls_connector, domain) = create_https_client().await; + #[tokio::test] + async fn test_wrong_hostname() { + Lazy::force(&RUSTLS); + let (addr, shutdown_tx, _) = setup_tls_test_server().await; + let (connector, _) = setup_test_client().await; + + // Try to connect with wrong hostname + let wrong_domain = rustls::pki_types::ServerName::try_from("wronghost").unwrap(); + let tcp = TcpStream::connect(addr).await.unwrap(); + let result = connector.connect(wrong_domain, tcp).await; + + // Should fail due to hostname mismatch + assert!(result.is_err()); + + shutdown_tx.send(()).unwrap(); + } + } - // Establish a connection - let tcp_stream = TcpStream::connect(server_addr).await.unwrap(); - let tls_stream = tls_connector.connect(domain, tcp_stream).await.unwrap(); + mod tls_payload_tests { + use super::*; - let (mut sender, conn) = - hyper::client::conn::http1::handshake(TokioIo::new(tls_stream)) - .await + #[tokio::test] + async fn test_tls_large_payload() { + Lazy::force(&RUSTLS); + let (addr, shutdown_tx, _) = setup_tls_test_server().await; + let (connector, domain) = setup_test_client().await; + let mut sender = connect_tls_client(addr, connector, domain).await; + + // Test large response + let req = Request::builder() + .uri("/large") + .body(Full::new(Bytes::new())) + .unwrap(); + let res = sender.send_request(req).await.unwrap(); + assert_eq!(res.status(), StatusCode::OK); + let body = res.collect().await.unwrap().to_bytes(); + assert_eq!(body.len(), 1024 * 1024); + + // Test large request + let large_data = vec![b'x'; 1024 * 1024]; + let req = Request::builder() + .method(hyper::Method::POST) + .uri("/echo") + .body(Full::new(Bytes::from(large_data.clone()))) .unwrap(); + let res = sender.send_request(req).await.unwrap(); + assert_eq!(res.status(), StatusCode::OK); + let body = res.collect().await.unwrap().to_bytes(); + assert_eq!(body.len(), large_data.len()); - tokio::spawn(async move { - if let Err(err) = conn.await { - eprintln!("Connection failed: {:?}", err); - } - }); + shutdown_tx.send(()).unwrap(); + } + } - // Send a request - let req = Request::builder() - .uri("/") - .body(Empty::::new()) - .unwrap(); + mod tls_shutdown_tests { + use super::*; - let res = sender.send_request(req).await.unwrap(); - assert_eq!(res.status(), StatusCode::OK); + #[tokio::test] + async fn test_tls_graceful_shutdown() { + Lazy::force(&RUSTLS); + let (addr, shutdown_tx, _) = setup_tls_test_server().await; + let (connector, domain) = setup_test_client().await; + let mut sender = connect_tls_client(addr, connector, domain).await; - // Initiate graceful shutdown - shutdown_tx.send(()).unwrap(); + // Send a request before shutdown + let req = Request::builder() + .uri("/") + .body(Full::new(Bytes::new())) + .unwrap(); + let res = sender.send_request(req).await.unwrap(); + assert_eq!(res.status(), StatusCode::OK); - // Wait a bit to allow the server to start shutting down - tokio::time::sleep(Duration::from_millis(10)).await; + // Initiate shutdown + shutdown_tx.send(()).unwrap(); - // Try to send another request, it should fail - let req = Request::builder() - .uri("/") - .body(Empty::::new()) - .unwrap(); + // Wait a bit + tokio::time::sleep(Duration::from_millis(50)).await; - let result = sender.send_request(req).await; - assert!( - result.is_err(), - "Expected request to fail after graceful shutdown" - ); + // Try to send another request on the same connection + let req = Request::builder() + .uri("/") + .body(Full::new(Bytes::new())) + .unwrap(); + let result = sender.send_request(req).await; - server.await.unwrap().unwrap(); + // Should fail as connection is shutting down + assert!(result.is_err()); + } } } -} +} \ No newline at end of file diff --git a/src/lib.rs b/src/lib.rs index ab2a5c1..e979f48 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,4 +1,3 @@ - pub use error::{Error as TransportError, Kind as TransportErrorKind}; pub use http::serve_http_connection; pub use http::serve_http_with_shutdown;