diff --git a/payjoin-cli/tests/e2e.rs b/payjoin-cli/tests/e2e.rs index 439dd90b..57dc2be5 100644 --- a/payjoin-cli/tests/e2e.rs +++ b/payjoin-cli/tests/e2e.rs @@ -170,6 +170,7 @@ mod e2e { use url::Url; type Error = Box; + type BoxSendSyncError = Box; type Result = std::result::Result; static INIT_TRACING: OnceCell<()> = OnceCell::new(); @@ -180,8 +181,13 @@ mod e2e { let (cert, key) = local_cert_key(); let ohttp_relay_port = find_free_port(); let ohttp_relay = Url::parse(&format!("http://localhost:{}", ohttp_relay_port)).unwrap(); - let directory_port = find_free_port(); - let directory = Url::parse(&format!("https://localhost:{}", directory_port)).unwrap(); + let docker: Cli = Cli::default(); + let db = docker.run(Redis); + let db_host = format!("127.0.0.1:{}", db.get_host_port_ipv4(6379)); + let (port, directory_handle) = + init_directory(db_host, (cert.clone(), key)).await.expect("Failed to init directory"); + let directory = Url::parse(&format!("https://localhost:{}", port)).unwrap(); + let gateway_origin = http::Uri::from_str(directory.as_str()).unwrap(); let temp_dir = env::temp_dir(); @@ -189,7 +195,7 @@ mod e2e { let sender_db_path = temp_dir.join("sender_db"); let result: Result<()> = tokio::select! { res = ohttp_relay::listen_tcp(ohttp_relay_port, gateway_origin) => Err(format!("Ohttp relay is long running: {:?}", res).into()), - res = init_directory(directory_port, (cert.clone(), key)) => Err(format!("Directory server is long running: {:?}", res).into()), + res = directory_handle => Err(format!("Directory server is long running: {:?}", res).into()), res = send_receive_cli_async(ohttp_relay, directory, cert, receiver_db_path.clone(), sender_db_path.clone()) => res.map_err(|e| format!("send_receive failed: {:?}", e).into()), }; @@ -476,13 +482,17 @@ mod e2e { Err("Timeout waiting for service to be ready".into()) } - async fn init_directory(port: u16, local_cert_key: (Vec, Vec)) -> Result<()> { - let docker: Cli = Cli::default(); + async fn init_directory( + db_host: String, + local_cert_key: (Vec, Vec), + ) -> std::result::Result< + (u16, tokio::task::JoinHandle>), + BoxSendSyncError, + > { + println!("Database running on {}", db_host); let timeout = Duration::from_secs(2); - let db = docker.run(Redis); - let db_host = format!("127.0.0.1:{}", db.get_host_port_ipv4(6379)); - println!("Database running on {}", db.get_host_port_ipv4(6379)); - payjoin_directory::listen_tcp_with_tls(port, db_host, timeout, local_cert_key).await + payjoin_directory::listen_tcp_with_tls_on_free_port(db_host, timeout, local_cert_key) + .await } // generates or gets a DER encoded localhost cert and key. diff --git a/payjoin-directory/src/lib.rs b/payjoin-directory/src/lib.rs index 6d633e22..eb2a3fd9 100644 --- a/payjoin-directory/src/lib.rs +++ b/payjoin-directory/src/lib.rs @@ -36,6 +36,65 @@ const ID_LENGTH: usize = 13; mod db; use crate::db::DbPool; +type BoxError = Box; + +#[cfg(feature = "_danger-local-https")] +pub async fn listen_tcp_with_tls_on_free_port( + db_host: String, + timeout: Duration, + cert_key: (Vec, Vec), +) -> Result<(u16, tokio::task::JoinHandle>), BoxError> { + let listener = tokio::net::TcpListener::bind("[::]:0").await?; + let port = listener.local_addr()?.port(); + println!("Directory server binding to port {}", listener.local_addr()?); + let handle = listen_tcp_with_tls_on_listener(listener, db_host, timeout, cert_key).await?; + Ok((port, handle)) +} + +// Helper function to avoid code duplication +async fn listen_tcp_with_tls_on_listener( + listener: tokio::net::TcpListener, + db_host: String, + timeout: Duration, + tls_config: (Vec, Vec), +) -> Result>, BoxError> { + let pool = DbPool::new(timeout, db_host).await?; + let ohttp = Arc::new(Mutex::new(init_ohttp()?)); + let tls_acceptor = init_tls_acceptor(tls_config)?; + // Spawn the connection handling loop in a separate task + let handle = tokio::spawn(async move { + while let Ok((stream, _)) = listener.accept().await { + let pool = pool.clone(); + let ohttp = ohttp.clone(); + let tls_acceptor = tls_acceptor.clone(); + tokio::spawn(async move { + let tls_stream = match tls_acceptor.accept(stream).await { + Ok(tls_stream) => tls_stream, + Err(e) => { + error!("TLS accept error: {}", e); + return; + } + }; + if let Err(err) = http1::Builder::new() + .serve_connection( + TokioIo::new(tls_stream), + service_fn(move |req| { + serve_payjoin_directory(req, pool.clone(), ohttp.clone()) + }), + ) + .with_upgrades() + .await + { + error!("Error serving connection: {:?}", err); + } + }); + } + Ok(()) + }); + Ok(handle) +} + +// Modify existing listen_tcp_with_tls to use the new helper pub async fn listen_tcp( port: u16, db_host: String, @@ -73,41 +132,11 @@ pub async fn listen_tcp_with_tls( port: u16, db_host: String, timeout: Duration, - tls_config: (Vec, Vec), -) -> Result<(), Box> { - let pool = DbPool::new(timeout, db_host).await?; - let ohttp = Arc::new(Mutex::new(init_ohttp()?)); - let bind_addr = SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), port); - let tls_acceptor = init_tls_acceptor(tls_config)?; - let listener = TcpListener::bind(bind_addr).await?; - while let Ok((stream, _)) = listener.accept().await { - let pool = pool.clone(); - let ohttp = ohttp.clone(); - let tls_acceptor = tls_acceptor.clone(); - tokio::spawn(async move { - let tls_stream = match tls_acceptor.accept(stream).await { - Ok(tls_stream) => tls_stream, - Err(e) => { - error!("TLS accept error: {}", e); - return; - } - }; - if let Err(err) = http1::Builder::new() - .serve_connection( - TokioIo::new(tls_stream), - service_fn(move |req| { - serve_payjoin_directory(req, pool.clone(), ohttp.clone()) - }), - ) - .with_upgrades() - .await - { - error!("Error serving connection: {:?}", err); - } - }); - } - - Ok(()) + cert_key: (Vec, Vec), +) -> Result>, BoxError> { + let addr = format!("0.0.0.0:{}", port); + let listener = tokio::net::TcpListener::bind(&addr).await?; + listen_tcp_with_tls_on_listener(listener, db_host, timeout, cert_key).await } #[cfg(feature = "_danger-local-https")] diff --git a/payjoin/tests/integration.rs b/payjoin/tests/integration.rs index 69587f05..ee7c3043 100644 --- a/payjoin/tests/integration.rs +++ b/payjoin/tests/integration.rs @@ -187,6 +187,8 @@ mod integration { use super::*; + type BoxSendSyncError = Box; + static TESTS_TIMEOUT: Lazy = Lazy::new(|| Duration::from_secs(20)); static WAIT_SERVICE_INTERVAL: Lazy = Lazy::new(|| Duration::from_secs(3)); @@ -197,10 +199,17 @@ mod integration { .expect("Invalid OhttpKeys"); let (cert, key) = local_cert_key(); - let port = find_free_port(); + let docker: Cli = Cli::default(); + let db = docker.run(Redis); + let db_host = format!("127.0.0.1:{}", db.get_host_port_ipv4(6379)); + + let (port, directory_handle) = init_directory(db_host, (cert.clone(), key)) + .await + .expect("Failed to init directory"); let directory = Url::parse(&format!("https://localhost:{}", port)).unwrap(); + tokio::select!( - err = init_directory(port, (cert.clone(), key)) => panic!("Directory server exited early: {:?}", err), + err = directory_handle => panic!("Directory server exited early: {:?}", err), res = try_request_with_bad_keys(directory, bad_ohttp_keys, cert) => { assert_eq!( res.unwrap().headers().get("content-type").unwrap(), @@ -234,12 +243,18 @@ mod integration { let ohttp_relay_port = find_free_port(); let ohttp_relay = Url::parse(&format!("http://localhost:{}", ohttp_relay_port)).unwrap(); - let directory_port = find_free_port(); + let docker: Cli = Cli::default(); + let db = docker.run(Redis); + let db_host = format!("127.0.0.1:{}", db.get_host_port_ipv4(6379)); + + let (directory_port, directory_handle) = init_directory(db_host, (cert.clone(), key)) + .await + .expect("Failed to init directory"); let directory = Url::parse(&format!("https://localhost:{}", directory_port)).unwrap(); let gateway_origin = http::Uri::from_str(directory.as_str()).unwrap(); tokio::select!( err = ohttp_relay::listen_tcp(ohttp_relay_port, gateway_origin) => panic!("Ohttp relay exited early: {:?}", err), - err = init_directory(directory_port, (cert.clone(), key)) => panic!("Directory server exited early: {:?}", err), + err = directory_handle => panic!("Directory server exited early: {:?}", err), res = do_expiration_tests(ohttp_relay, directory, cert) => assert!(res.is_ok(), "v2 send receive failed: {:#?}", res) ); @@ -303,12 +318,18 @@ mod integration { let ohttp_relay_port = find_free_port(); let ohttp_relay = Url::parse(&format!("http://localhost:{}", ohttp_relay_port)).unwrap(); - let directory_port = find_free_port(); + let docker: Cli = Cli::default(); + let db = docker.run(Redis); + let db_host = format!("127.0.0.1:{}", db.get_host_port_ipv4(6379)); + + let (directory_port, directory_handle) = init_directory(db_host, (cert.clone(), key)) + .await + .expect("Failed to init directory"); let directory = Url::parse(&format!("https://localhost:{}", directory_port)).unwrap(); let gateway_origin = http::Uri::from_str(directory.as_str()).unwrap(); tokio::select!( err = ohttp_relay::listen_tcp(ohttp_relay_port, gateway_origin) => panic!("Ohttp relay exited early: {:?}", err), - err = init_directory(directory_port, (cert.clone(), key)) => panic!("Directory server exited early: {:?}", err), + err = directory_handle => panic!("Directory server exited early: {:?}", err), res = do_v2_send_receive(ohttp_relay, directory, cert) => assert!(res.is_ok(), "v2 send receive failed: {:#?}", res) ); @@ -432,12 +453,18 @@ mod integration { let ohttp_relay_port = find_free_port(); let ohttp_relay = Url::parse(&format!("http://localhost:{}", ohttp_relay_port)).unwrap(); - let directory_port = find_free_port(); + let docker: Cli = Cli::default(); + let db = docker.run(Redis); + let db_host = format!("127.0.0.1:{}", db.get_host_port_ipv4(6379)); + + let (directory_port, directory_handle) = init_directory(db_host, (cert.clone(), key)) + .await + .expect("Failed to init directory"); let directory = Url::parse(&format!("https://localhost:{}", directory_port)).unwrap(); let gateway_origin = http::Uri::from_str(directory.as_str()).unwrap(); tokio::select!( err = ohttp_relay::listen_tcp(ohttp_relay_port, gateway_origin) => panic!("Ohttp relay exited early: {:?}", err), - err = init_directory(directory_port, (cert.clone(), key)) => panic!("Directory server exited early: {:?}", err), + err = directory_handle => panic!("Directory server exited early: {:?}", err), res = do_v2_send_receive(ohttp_relay, directory, cert) => assert!(res.is_ok(), "v2 send receive failed: {:#?}", res) ); @@ -644,12 +671,17 @@ mod integration { let ohttp_relay_port = find_free_port(); let ohttp_relay = Url::parse(&format!("http://localhost:{}", ohttp_relay_port)).unwrap(); - let directory_port = find_free_port(); + let docker: Cli = Cli::default(); + let db = docker.run(Redis); + let db_host = format!("127.0.0.1:{}", db.get_host_port_ipv4(6379)); + let (directory_port, directory_handle) = init_directory(db_host, (cert.clone(), key)) + .await + .expect("Failed to init directory"); let directory = Url::parse(&format!("https://localhost:{}", directory_port)).unwrap(); let gateway_origin = http::Uri::from_str(directory.as_str()).unwrap(); tokio::select!( err = ohttp_relay::listen_tcp(ohttp_relay_port, gateway_origin) => panic!("Ohttp relay exited early: {:?}", err), - err = init_directory(directory_port, (cert.clone(), key)) => panic!("Directory server exited early: {:?}", err), + err = directory_handle => panic!("Directory server exited early: {:?}", err), res = do_v1_to_v2(ohttp_relay, directory, cert) => assert!(res.is_ok()), ); @@ -771,15 +803,14 @@ mod integration { } async fn init_directory( - port: u16, + db_host: String, local_cert_key: (Vec, Vec), - ) -> Result<(), BoxError> { - let docker: Cli = Cli::default(); + ) -> Result<(u16, tokio::task::JoinHandle>), BoxSendSyncError> + { + println!("Database running on {}", db_host); let timeout = Duration::from_secs(2); - let db = docker.run(Redis); - let db_host = format!("127.0.0.1:{}", db.get_host_port_ipv4(6379)); - println!("Database running on {}", db.get_host_port_ipv4(6379)); - payjoin_directory::listen_tcp_with_tls(port, db_host, timeout, local_cert_key).await + payjoin_directory::listen_tcp_with_tls_on_free_port(db_host, timeout, local_cert_key) + .await } // generates or gets a DER encoded localhost cert and key. @@ -920,7 +951,6 @@ mod integration { while start.elapsed() < *TESTS_TIMEOUT { let request_result = agent.get(health_url.as_str()).send().await.map_err(|_| "Bad request")?; - match request_result.status() { StatusCode::OK => return Ok(()), StatusCode::NOT_FOUND => return Err("Endpoint not found"),