Skip to content

Commit

Permalink
Upgrade testcontainers to 0.16 (#3072)
Browse files Browse the repository at this point in the history
  • Loading branch information
divergentdave authored Apr 30, 2024
1 parent a899f94 commit 61bc054
Show file tree
Hide file tree
Showing 15 changed files with 406 additions and 310 deletions.
307 changes: 256 additions & 51 deletions Cargo.lock

Large diffs are not rendered by default.

24 changes: 6 additions & 18 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,17 +1,5 @@
[workspace]
members = [
"aggregator",
"aggregator_api",
"aggregator_core",
"client",
"collector",
"core",
"integration_tests",
"interop_binaries",
"messages",
"tools",
"xtask",
]
members = ["aggregator", "aggregator_api", "aggregator_core", "client", "collector", "core", "integration_tests", "interop_binaries", "messages", "tools", "xtask"]
resolver = "2"

[workspace.package]
Expand Down Expand Up @@ -45,7 +33,7 @@ janus_core = { version = "0.6", path = "core" }
janus_integration_tests = { version = "0.6", path = "integration_tests" }
janus_interop_binaries = { version = "0.6", path = "interop_binaries" }
janus_messages = { version = "0.6", path = "messages" }
k8s-openapi = { version = "0.20.0", features = ["v1_24"] } # keep this version in sync with what is referenced by the indirect dependency via `kube`
k8s-openapi = { version = "0.20.0", features = ["v1_24"] } # keep this version in sync with what is referenced by the indirect dependency via `kube`
kube = { version = "0.87.2", default-features = false, features = ["client", "rustls-tls"] }
opentelemetry = { version = "0.22", features = ["metrics"] }
opentelemetry_sdk = { version = "0.22", features = ["metrics"] }
Expand All @@ -57,7 +45,7 @@ serde_yaml = "0.9.34"
rand = "0.8"
reqwest = { version = "0.12.4", default-features = false, features = ["rustls-tls"] }
rstest = "0.18.2"
testcontainers = "0.15.0"
testcontainers = "0.16.5"
thiserror = "1.0"
tokio = { version = "1.37", features = ["full", "tracing"] }
trillium = "0.2.19"
Expand All @@ -82,6 +70,6 @@ debug = 0
# relatively fast compilation. It is intended for use in size-constrained testing scenarios, e.g.
# building a binary artifact that ends up embedded in another binary.
inherits = "dev"
opt-level = "z" # Optimize for size.
debug = false # Do not generate debug info.
strip = true # Strip symbols from binary.
opt-level = "z" # Optimize for size.
debug = false # Do not generate debug info.
strip = true # Strip symbols from binary.
17 changes: 8 additions & 9 deletions aggregator/src/binary_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -514,7 +514,7 @@ mod tests {
use janus_aggregator_core::datastore::test_util::ephemeral_datastore;
use janus_core::test_util::{
install_test_trace_subscriber,
testcontainers::{container_client, Postgres, Volume},
testcontainers::{Postgres, Volume},
};
use opentelemetry::metrics::MeterProvider as _;
use opentelemetry_sdk::{
Expand All @@ -523,7 +523,7 @@ mod tests {
testing::metrics::InMemoryMetricsExporter,
};
use std::{collections::HashMap, fs};
use testcontainers::RunnableImage;
use testcontainers::{core::Mount, runners::AsyncRunner, RunnableImage};
use tokio::task::spawn_blocking;
use tracing_subscriber::{reload, EnvFilter};
use trillium::Status;
Expand Down Expand Up @@ -604,7 +604,6 @@ mod tests {
async fn postgres_tls_connection() {
install_test_trace_subscriber();

let client = container_client();
// We need to be careful about providing the certificate and private key to the Postgres
// container. The key must have '-rw-------' permissions, and both must be readable by the
// postgres user, which has UID 70 inside the container at time of writing. Merely mounting
Expand All @@ -629,16 +628,16 @@ mod tests {
.to_string(),
]),
))
.with_volume((
.with_mount(Mount::bind_mount(
fs::canonicalize("tests/tls_files")
.unwrap()
.into_os_string()
.into_string()
.unwrap(),
"/etc/ssl/postgresql_host",
))
.with_volume((volume.name(), "/etc/ssl/postgresql"));
let setup_container = client.run(setup_image);
.with_mount(Mount::volume_mount(volume.name(), "/etc/ssl/postgresql"));
let setup_container = setup_image.start().await;
drop(setup_container);

let image = RunnableImage::from((
Expand All @@ -652,10 +651,10 @@ mod tests {
"ssl_key_file=/etc/ssl/postgresql/127.0.0.1-key.pem".to_string(),
]),
))
.with_volume((volume.name(), "/etc/ssl/postgresql"));
let db_container = client.run(image);
.with_mount(Mount::volume_mount(volume.name(), "/etc/ssl/postgresql"));
let db_container = image.start().await;
const POSTGRES_DEFAULT_PORT: u16 = 5432;
let port = db_container.get_host_port_ipv4(POSTGRES_DEFAULT_PORT);
let port = db_container.get_host_port_ipv4(POSTGRES_DEFAULT_PORT).await;

let db_config = DbConfig {
url: format!("postgres://[email protected]:{port}/postgres?sslmode=require")
Expand Down
48 changes: 10 additions & 38 deletions aggregator_core/src/datastore/test_util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,21 +19,19 @@ use sqlx::{
use std::{
path::PathBuf,
str::FromStr,
sync::{Arc, Barrier, Weak},
thread::{self, JoinHandle},
sync::{Arc, Weak},
time::Duration,
};
use testcontainers::RunnableImage;
use tokio::sync::{oneshot, Mutex};
use testcontainers::{runners::AsyncRunner, ContainerAsync, RunnableImage};
use tokio::sync::Mutex;
use tokio_postgres::{connect, Config, NoTls};
use tracing::trace;

use super::SUPPORTED_SCHEMA_VERSIONS;

struct EphemeralDatabase {
_db_container: ContainerAsync<Postgres>,
port_number: u16,
shutdown_barrier: Arc<Barrier>,
join_handle: Option<JoinHandle<()>>,
}

impl EphemeralDatabase {
Expand All @@ -51,30 +49,15 @@ impl EphemeralDatabase {
}

async fn start() -> Self {
let (port_tx, port_rx) = oneshot::channel();
let shutdown_barrier = Arc::new(Barrier::new(2));
let join_handle = thread::spawn({
let shutdown_barrier = Arc::clone(&shutdown_barrier);
move || {
// Start an instance of Postgres running in a container.
let container_client = testcontainers::clients::Cli::default();
let db_container = container_client.run(RunnableImage::from(Postgres::default()));
const POSTGRES_DEFAULT_PORT: u16 = 5432;
let port_number = db_container.get_host_port_ipv4(POSTGRES_DEFAULT_PORT);
trace!("Postgres container is up with port {port_number}");
port_tx.send(port_number).unwrap();

// Wait for the barrier as a shutdown signal.
shutdown_barrier.wait();
trace!("Shutting down Postgres container with port {port_number}");
}
});
let port_number = port_rx.await.unwrap();
// Start an instance of Postgres running in a container.
let db_container = RunnableImage::from(Postgres::default()).start().await;
const POSTGRES_DEFAULT_PORT: u16 = 5432;
let port_number = db_container.get_host_port_ipv4(POSTGRES_DEFAULT_PORT).await;
trace!("Postgres container is up with port {port_number}");

Self {
_db_container: db_container,
port_number,
shutdown_barrier,
join_handle: Some(join_handle),
}
}

Expand All @@ -86,17 +69,6 @@ impl EphemeralDatabase {
}
}

impl Drop for EphemeralDatabase {
fn drop(&mut self) {
// Wait on the shutdown barrier, which will cause the container-management thread to
// begin shutdown. Then wait for the container-management thread itself to terminate.
// This guarantees container shutdown finishes before dropping the EphemeralDatabase
// completes.
self.shutdown_barrier.wait();
self.join_handle.take().unwrap().join().unwrap();
}
}

/// EphemeralDatastore represents an ephemeral datastore instance. It has methods allowing
/// creation of Datastores, as well as the ability to retrieve the underlying connection pool.
///
Expand Down
20 changes: 2 additions & 18 deletions core/src/test_util/testcontainers.rs
Original file line number Diff line number Diff line change
@@ -1,23 +1,7 @@
//! Testing functionality that interacts with the testcontainers library.
use std::{
collections::HashMap,
process::Command,
sync::{Arc, Mutex, Weak},
};
use testcontainers::{clients::Cli, core::WaitFor, Image};

/// Returns a container client, possibly shared with other callers to this function.
pub fn container_client() -> Arc<Cli> {
static CONTAINER_CLIENT_MU: Mutex<Weak<Cli>> = Mutex::new(Weak::new());

let mut container_client = CONTAINER_CLIENT_MU.lock().unwrap();
container_client.upgrade().unwrap_or_else(|| {
let client = Arc::new(Cli::default());
*container_client = Arc::downgrade(&client);
client
})
}
use std::{collections::HashMap, process::Command};
use testcontainers::{core::WaitFor, Image};

/// A [`testcontainers::Image`] that provides a Postgres server.
#[derive(Debug)]
Expand Down
1 change: 1 addition & 0 deletions deny.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ allow = [
"Unicode-DFS-2016",
"OpenSSL",
"Unlicense",
"CC0-1.0",
]

[[licenses.clarify]]
Expand Down
39 changes: 18 additions & 21 deletions integration_tests/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use prio::{
use rand::random;
use serde_json::{json, Value};
use std::env;
use testcontainers::{clients::Cli, core::WaitFor, Image, RunnableImage};
use testcontainers::{core::WaitFor, runners::AsyncRunner, Image, RunnableImage};
use url::Url;

/// Extension trait to encode measurements for VDAFs as JSON objects, according to
Expand Down Expand Up @@ -148,7 +148,6 @@ pub enum ClientBackend<'a> {
/// Uploads reports by starting a containerized client implementation, and sending it requests
/// using draft-dcook-ppm-dap-interop-test-design.
Container {
container_client: &'a Cli,
container_image: InteropClient,
network: &'a str,
},
Expand All @@ -161,7 +160,7 @@ impl<'a> ClientBackend<'a> {
task_parameters: &TaskParameters,
(leader_port, helper_port): (u16, u16),
vdaf: V,
) -> anyhow::Result<ClientImplementation<'a, V>>
) -> anyhow::Result<ClientImplementation<V>>
where
V: vdaf::Client<16> + InteropClientEncoding,
{
Expand All @@ -174,26 +173,25 @@ impl<'a> ClientBackend<'a> {
.await
.map_err(Into::into),
ClientBackend::Container {
container_client,
container_image,
network,
} => Ok(ClientImplementation::new_container(
test_name,
container_client,
container_image.clone(),
network,
task_parameters,
vdaf,
)),
)
.await),
}
}
}

pub struct ContainerClientImplementation<'d, V>
pub struct ContainerClientImplementation<V>
where
V: vdaf::Client<16>,
{
_container: ContainerLogsDropGuard<'d, InteropClient>,
_container: ContainerLogsDropGuard<InteropClient>,
leader: Url,
helper: Url,
task_id: TaskId,
Expand All @@ -206,23 +204,23 @@ where

/// A DAP client implementation, specialized to work with a particular VDAF. See also
/// [`ClientBackend`].
pub enum ClientImplementation<'d, V>
pub enum ClientImplementation<V>
where
V: vdaf::Client<16>,
{
InProcess { client: Client<V> },
Container(Box<ContainerClientImplementation<'d, V>>),
Container(Box<ContainerClientImplementation<V>>),
}

impl<'d, V> ClientImplementation<'d, V>
impl<V> ClientImplementation<V>
where
V: vdaf::Client<16> + InteropClientEncoding,
{
pub async fn new_in_process(
task_parameters: &TaskParameters,
(leader_port, helper_port): (u16, u16),
vdaf: V,
) -> Result<ClientImplementation<'static, V>, janus_client::Error> {
) -> Result<ClientImplementation<V>, janus_client::Error> {
let (leader_aggregator_endpoint, helper_aggregator_endpoint) = task_parameters
.endpoint_fragments
.endpoints_for_host_client(leader_port, helper_port);
Expand All @@ -237,9 +235,8 @@ where
Ok(ClientImplementation::InProcess { client })
}

pub fn new_container(
pub async fn new_container(
test_name: &str,
container_client: &'d Cli,
container_image: InteropClient,
network: &str,
task_parameters: &TaskParameters,
Expand All @@ -249,14 +246,14 @@ where
let client_container_name = format!("client-{random_part}");
let container = ContainerLogsDropGuard::new_janus(
test_name,
container_client.run(
RunnableImage::from(container_image)
.with_network(network)
.with_env_var(get_rust_log_level())
.with_container_name(client_container_name),
),
RunnableImage::from(container_image)
.with_network(network)
.with_env_var(get_rust_log_level())
.with_container_name(client_container_name)
.start()
.await,
);
let host_port = container.get_host_port_ipv4(8080);
let host_port = container.get_host_port_ipv4(8080).await;
let http_client = reqwest::Client::new();
let (leader_aggregator_endpoint, helper_aggregator_endpoint) = task_parameters
.endpoint_fragments
Expand Down
Loading

0 comments on commit 61bc054

Please sign in to comment.