Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix test flakiness #388

Merged
merged 3 commits into from
Dec 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Cargo-minimal.lock
Original file line number Diff line number Diff line change
Expand Up @@ -1450,9 +1450,9 @@ dependencies = [

[[package]]
name = "ohttp-relay"
version = "0.0.8"
version = "0.0.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7850c40a0aebcba289d3252c0a45f93cba6ad4b0c46b88a5fc51dba6ddce8632"
checksum = "4f8e8aef13b8327b680aaaca807aa11ba5979fc5858203e7b77c68128ede61a2"
dependencies = [
"futures",
"http",
Expand Down
4 changes: 2 additions & 2 deletions Cargo-recent.lock
Original file line number Diff line number Diff line change
Expand Up @@ -1450,9 +1450,9 @@ dependencies = [

[[package]]
name = "ohttp-relay"
version = "0.0.8"
version = "0.0.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7850c40a0aebcba289d3252c0a45f93cba6ad4b0c46b88a5fc51dba6ddce8632"
checksum = "4f8e8aef13b8327b680aaaca807aa11ba5979fc5858203e7b77c68128ede61a2"
dependencies = [
"futures",
"http",
Expand Down
2 changes: 1 addition & 1 deletion payjoin-cli/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ url = { version = "2.3.1", features = ["serde"] }
[dev-dependencies]
bitcoind = { version = "0.36.0", features = ["0_21_2"] }
http = "1"
ohttp-relay = "0.0.8"
ohttp-relay = { version = "0.0.9", features = ["_test-util"] }
once_cell = "1"
payjoin-directory = { path = "../payjoin-directory", features = ["_danger-local-https"] }
testcontainers = "0.15.0"
Expand Down
47 changes: 30 additions & 17 deletions payjoin-cli/tests/e2e.rs
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,11 @@ mod e2e {
payjoin_sent.unwrap().unwrap_or(Some(false)).unwrap(),
"Payjoin send was not detected"
);

fn find_free_port() -> u16 {
let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
listener.local_addr().unwrap().port()
}
}

#[cfg(feature = "v2")]
Expand All @@ -170,6 +175,7 @@ mod e2e {
use url::Url;

type Error = Box<dyn std::error::Error + 'static>;
type BoxSendSyncError = Box<dyn std::error::Error + Send + Sync>;
type Result<T> = std::result::Result<T, Error>;

static INIT_TRACING: OnceCell<()> = OnceCell::new();
Expand All @@ -178,18 +184,26 @@ mod e2e {

init_tracing();
let (cert, key) = local_cert_key();
let ohttp_relay_port = find_free_port();
let ohttp_relay = Url::parse(&format!("http://localhost:{}", ohttp_relay_port)).unwrap();
let directory_port = find_free_port();
let directory = Url::parse(&format!("https://localhost:{}", directory_port)).unwrap();
let docker: Cli = Cli::default();
let db = docker.run(Redis);
let db_host = format!("127.0.0.1:{}", db.get_host_port_ipv4(6379));
let (port, directory_handle) =
init_directory(db_host, (cert.clone(), key)).await.expect("Failed to init directory");
let directory = Url::parse(&format!("https://localhost:{}", port)).unwrap();

let gateway_origin = http::Uri::from_str(directory.as_str()).unwrap();
let (ohttp_relay_port, ohttp_relay_handle) =
ohttp_relay::listen_tcp_on_free_port(gateway_origin)
.await
.expect("Failed to init ohttp relay");
let ohttp_relay = Url::parse(&format!("http://localhost:{}", ohttp_relay_port)).unwrap();

let temp_dir = env::temp_dir();
let receiver_db_path = temp_dir.join("receiver_db");
let sender_db_path = temp_dir.join("sender_db");
let result: Result<()> = tokio::select! {
res = ohttp_relay::listen_tcp(ohttp_relay_port, gateway_origin) => Err(format!("Ohttp relay is long running: {:?}", res).into()),
res = init_directory(directory_port, (cert.clone(), key)) => Err(format!("Directory server is long running: {:?}", res).into()),
res = ohttp_relay_handle => Err(format!("Ohttp relay is long running: {:?}", res).into()),
res = directory_handle => Err(format!("Directory server is long running: {:?}", res).into()),
res = send_receive_cli_async(ohttp_relay, directory, cert, receiver_db_path.clone(), sender_db_path.clone()) => res.map_err(|e| format!("send_receive failed: {:?}", e).into()),
};

Expand Down Expand Up @@ -476,13 +490,17 @@ mod e2e {
Err("Timeout waiting for service to be ready".into())
}

async fn init_directory(port: u16, local_cert_key: (Vec<u8>, Vec<u8>)) -> Result<()> {
let docker: Cli = Cli::default();
async fn init_directory(
db_host: String,
local_cert_key: (Vec<u8>, Vec<u8>),
) -> std::result::Result<
(u16, tokio::task::JoinHandle<std::result::Result<(), BoxSendSyncError>>),
BoxSendSyncError,
> {
println!("Database running on {}", db_host);
let timeout = Duration::from_secs(2);
let db = docker.run(Redis);
let db_host = format!("127.0.0.1:{}", db.get_host_port_ipv4(6379));
println!("Database running on {}", db.get_host_port_ipv4(6379));
payjoin_directory::listen_tcp_with_tls(port, db_host, timeout, local_cert_key).await
payjoin_directory::listen_tcp_with_tls_on_free_port(db_host, timeout, local_cert_key)
.await
}

// generates or gets a DER encoded localhost cert and key.
Expand Down Expand Up @@ -521,11 +539,6 @@ mod e2e {
}
}

fn find_free_port() -> u16 {
let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
listener.local_addr().unwrap().port()
}

async fn cleanup_temp_file(path: &std::path::Path) {
if let Err(e) = fs::remove_dir_all(path).await {
eprintln!("Failed to remove {:?}: {}", path, e);
Expand Down
101 changes: 66 additions & 35 deletions payjoin-directory/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,67 @@ const ID_LENGTH: usize = 13;
mod db;
use crate::db::DbPool;

#[cfg(feature = "_danger-local-https")]
type BoxError = Box<dyn std::error::Error + Send + Sync>;

#[cfg(feature = "_danger-local-https")]
pub async fn listen_tcp_with_tls_on_free_port(
db_host: String,
timeout: Duration,
cert_key: (Vec<u8>, Vec<u8>),
) -> Result<(u16, tokio::task::JoinHandle<Result<(), BoxError>>), BoxError> {
let listener = tokio::net::TcpListener::bind("[::]:0").await?;
let port = listener.local_addr()?.port();
println!("Directory server binding to port {}", listener.local_addr()?);
let handle = listen_tcp_with_tls_on_listener(listener, db_host, timeout, cert_key).await?;
Ok((port, handle))
}

// Helper function to avoid code duplication
#[cfg(feature = "_danger-local-https")]
async fn listen_tcp_with_tls_on_listener(
listener: tokio::net::TcpListener,
db_host: String,
timeout: Duration,
tls_config: (Vec<u8>, Vec<u8>),
) -> Result<tokio::task::JoinHandle<Result<(), BoxError>>, BoxError> {
let pool = DbPool::new(timeout, db_host).await?;
let ohttp = Arc::new(Mutex::new(init_ohttp()?));
let tls_acceptor = init_tls_acceptor(tls_config)?;
// Spawn the connection handling loop in a separate task
let handle = tokio::spawn(async move {
while let Ok((stream, _)) = listener.accept().await {
let pool = pool.clone();
let ohttp = ohttp.clone();
let tls_acceptor = tls_acceptor.clone();
tokio::spawn(async move {
let tls_stream = match tls_acceptor.accept(stream).await {
Ok(tls_stream) => tls_stream,
Err(e) => {
error!("TLS accept error: {}", e);
return;
}
};
if let Err(err) = http1::Builder::new()
.serve_connection(
TokioIo::new(tls_stream),
service_fn(move |req| {
serve_payjoin_directory(req, pool.clone(), ohttp.clone())
}),
)
.with_upgrades()
.await
{
error!("Error serving connection: {:?}", err);
}
});
}
Ok(())
});
Ok(handle)
}

// Modify existing listen_tcp_with_tls to use the new helper
pub async fn listen_tcp(
port: u16,
db_host: String,
Expand Down Expand Up @@ -73,41 +134,11 @@ pub async fn listen_tcp_with_tls(
port: u16,
db_host: String,
timeout: Duration,
tls_config: (Vec<u8>, Vec<u8>),
) -> Result<(), Box<dyn std::error::Error>> {
let pool = DbPool::new(timeout, db_host).await?;
let ohttp = Arc::new(Mutex::new(init_ohttp()?));
let bind_addr = SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), port);
let tls_acceptor = init_tls_acceptor(tls_config)?;
let listener = TcpListener::bind(bind_addr).await?;
while let Ok((stream, _)) = listener.accept().await {
let pool = pool.clone();
let ohttp = ohttp.clone();
let tls_acceptor = tls_acceptor.clone();
tokio::spawn(async move {
let tls_stream = match tls_acceptor.accept(stream).await {
Ok(tls_stream) => tls_stream,
Err(e) => {
error!("TLS accept error: {}", e);
return;
}
};
if let Err(err) = http1::Builder::new()
.serve_connection(
TokioIo::new(tls_stream),
service_fn(move |req| {
serve_payjoin_directory(req, pool.clone(), ohttp.clone())
}),
)
.with_upgrades()
.await
{
error!("Error serving connection: {:?}", err);
}
});
}

Ok(())
cert_key: (Vec<u8>, Vec<u8>),
) -> Result<tokio::task::JoinHandle<Result<(), BoxError>>, BoxError> {
let addr = format!("0.0.0.0:{}", port);
let listener = tokio::net::TcpListener::bind(&addr).await?;
listen_tcp_with_tls_on_listener(listener, db_host, timeout, cert_key).await
}

#[cfg(feature = "_danger-local-https")]
Expand Down
2 changes: 1 addition & 1 deletion payjoin/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ serde_json = "1.0.108"
bitcoind = { version = "0.36.0", features = ["0_21_2"] }
http = "1"
payjoin-directory = { path = "../payjoin-directory", features = ["_danger-local-https"] }
ohttp-relay = "0.0.8"
ohttp-relay = { version = "0.0.9", features = ["_test-util"] }
nothingmuch marked this conversation as resolved.
Show resolved Hide resolved
once_cell = "1"
rcgen = { version = "0.11" }
reqwest = { version = "0.12", default-features = false, features = ["rustls-tls"] }
Expand Down
Loading
Loading