diff --git a/aggregator/src/aggregator.rs b/aggregator/src/aggregator.rs index 86b324a620..00092ffb24 100644 --- a/aggregator/src/aggregator.rs +++ b/aggregator/src/aggregator.rs @@ -48,6 +48,7 @@ use janus_core::{ http::HttpErrorResponse, time::{Clock, DurationExt, IntervalExt, TimeExt}, vdaf::{VdafInstance, VERIFY_KEY_LENGTH}, + Runtime, }; use janus_messages::{ query_type::{FixedSize, TimeInterval}, @@ -229,14 +230,16 @@ impl Default for Config { } impl Aggregator { - async fn new( + async fn new( datastore: Arc>, clock: C, + runtime: R, meter: &Meter, cfg: Config, ) -> Result { let report_writer = Arc::new(ReportWriteBatcher::new( Arc::clone(&datastore), + runtime, cfg.task_counter_shard_count, cfg.max_upload_batch_size, cfg.max_upload_batch_write_delay, @@ -3290,9 +3293,13 @@ mod tests { self, test_util::generate_test_hpke_config_and_private_key_with_id, HpkeApplicationInfo, HpkeKeypair, Label, }, - test_util::install_test_trace_subscriber, + test_util::{ + install_test_trace_subscriber, + runtime::{TestRuntime, TestRuntimeManager}, + }, time::{Clock, MockClock, TimeExt}, vdaf::{VdafInstance, VERIFY_KEY_LENGTH}, + Runtime, }; use janus_messages::{ query_type::TimeInterval, Duration, Extension, HpkeCiphertext, HpkeConfig, HpkeConfigId, @@ -3305,7 +3312,6 @@ mod tests { }; use rand::random; use std::{collections::HashSet, iter, sync::Arc, time::Duration as StdDuration}; - use tokio::time::sleep; pub(super) fn create_report_custom( task: &AggregatorTask, @@ -3367,6 +3373,23 @@ mod tests { Arc>, EphemeralDatastore, ) { + setup_upload_test_with_runtime(TestRuntime::default(), cfg).await + } + + async fn setup_upload_test_with_runtime( + runtime: R, + cfg: Config, + ) -> ( + Prio3Count, + Aggregator, + MockClock, + Task, + Arc>, + EphemeralDatastore, + ) + where + R: Runtime + Send + Sync + 'static, + { let clock = MockClock::default(); let vdaf = Prio3Count::new_count(2).unwrap(); let task = TaskBuilder::new(QueryType::TimeInterval, VdafInstance::Prio3Count).build(); @@ -3378,9 +3401,15 @@ mod tests { datastore.put_aggregator_task(&leader_task).await.unwrap(); - let aggregator = Aggregator::new(Arc::clone(&datastore), clock.clone(), &noop_meter(), cfg) - .await - .unwrap(); + let aggregator = Aggregator::new( + Arc::clone(&datastore), + clock.clone(), + runtime, + &noop_meter(), + cfg, + ) + .await + .unwrap(); ( vdaf, @@ -3527,9 +3556,13 @@ mod tests { async fn upload_wrong_hpke_config_id() { install_test_trace_subscriber(); - let config = default_aggregator_config(); + let mut runtime_manager = TestRuntimeManager::new(); let (_, aggregator, clock, task, datastore, _ephemeral_datastore) = - setup_upload_test(config).await; + setup_upload_test_with_runtime( + runtime_manager.with_label("aggregator"), + default_aggregator_config(), + ) + .await; let leader_task = task.leader_view().unwrap(); let report = create_report(&leader_task, clock.now()); @@ -3565,8 +3598,10 @@ mod tests { }) }); - // Wait out the batch write delay so the report status can be flushed to the DB. - sleep(config.max_upload_batch_write_delay * 2).await; + // Wait for the report writer to have completed one write task. + runtime_manager + .wait_for_completed_tasks("aggregator", 1) + .await; let got_counters = datastore .run_unnamed_tx(|tx| { @@ -3625,9 +3660,13 @@ mod tests { #[tokio::test] async fn upload_report_in_the_future_past_clock_skew() { install_test_trace_subscriber(); - let config = default_aggregator_config(); + let mut runtime_manager = TestRuntimeManager::new(); let (_, aggregator, clock, task, datastore, _ephemeral_datastore) = - setup_upload_test(config).await; + setup_upload_test_with_runtime( + runtime_manager.with_label("aggregator"), + default_aggregator_config(), + ) + .await; let report = create_report( &task.leader_view().unwrap(), clock @@ -3649,8 +3688,10 @@ mod tests { assert_matches!(rejection.reason(), ReportRejectionReason::TooEarly); }); - // Wait out the batch write delay so the report status can be flushed to the DB. - sleep(config.max_upload_batch_write_delay * 2).await; + // Wait for the report writer to have completed one write task. + runtime_manager + .wait_for_completed_tasks("aggregator", 1) + .await; let got_counters = datastore .run_unnamed_tx(|tx| { @@ -3669,9 +3710,13 @@ mod tests { async fn upload_report_for_collected_batch() { install_test_trace_subscriber(); - let config = default_aggregator_config(); + let mut runtime_manager = TestRuntimeManager::new(); let (_, aggregator, clock, task, datastore, _ephemeral_datastore) = - setup_upload_test(config).await; + setup_upload_test_with_runtime( + runtime_manager.with_label("aggregator"), + default_aggregator_config(), + ) + .await; let report = create_report(&task.leader_view().unwrap(), clock.now()); // Insert a collection job for the batch interval including our report. @@ -3721,8 +3766,10 @@ mod tests { } ); - // Wait out the batch write delay so the report status can be flushed to the DB. - sleep(config.max_upload_batch_write_delay * 2).await; + // Wait for the report writer to have completed one write task. + runtime_manager + .wait_for_completed_tasks("aggregator", 1) + .await; let got_counters = datastore .run_unnamed_tx(|tx| { @@ -3817,9 +3864,13 @@ mod tests { #[tokio::test] async fn upload_report_task_expired() { install_test_trace_subscriber(); - let config = default_aggregator_config(); + let mut runtime_manager = TestRuntimeManager::new(); let (_, aggregator, clock, _, datastore, _ephemeral_datastore) = - setup_upload_test(config).await; + setup_upload_test_with_runtime( + runtime_manager.with_label("aggregator"), + default_aggregator_config(), + ) + .await; let task = TaskBuilder::new(QueryType::TimeInterval, VdafInstance::Prio3Count) .with_task_expiration(Some(clock.now())) @@ -3847,8 +3898,10 @@ mod tests { } ); - // Wait out the batch write delay so the report status can be flushed to the DB. - sleep(config.max_upload_batch_write_delay * 2).await; + // Wait for the report writer to have completed one write task. + runtime_manager + .wait_for_completed_tasks("aggregator", 1) + .await; let got_counters = datastore .run_unnamed_tx(|tx| { @@ -3866,9 +3919,13 @@ mod tests { #[tokio::test] async fn upload_report_report_expired() { install_test_trace_subscriber(); - let config = default_aggregator_config(); + let mut runtime_manager = TestRuntimeManager::new(); let (_, aggregator, clock, _, datastore, _ephemeral_datastore) = - setup_upload_test(config).await; + setup_upload_test_with_runtime( + runtime_manager.with_label("aggregator"), + default_aggregator_config(), + ) + .await; let task = TaskBuilder::new(QueryType::TimeInterval, VdafInstance::Prio3Count) .with_report_expiry_age(Some(Duration::from_seconds(60))) @@ -3897,8 +3954,10 @@ mod tests { } ); - // Wait out the batch write delay so the report status can be flushed to the DB. - sleep(config.max_upload_batch_write_delay * 2).await; + // Wait for the report writer to have completed one write task. + runtime_manager + .wait_for_completed_tasks("aggregator", 1) + .await; let got_counters = datastore .run_unnamed_tx(|tx| { @@ -3916,9 +3975,13 @@ mod tests { #[tokio::test] async fn upload_report_faulty_encryption() { install_test_trace_subscriber(); - let config = default_aggregator_config(); + let mut runtime_manager = TestRuntimeManager::new(); let (_, aggregator, clock, task, datastore, _ephemeral_datastore) = - setup_upload_test(config).await; + setup_upload_test_with_runtime( + runtime_manager.with_label("aggregator"), + default_aggregator_config(), + ) + .await; let task = task.leader_view().unwrap(); @@ -3947,8 +4010,10 @@ mod tests { } ); - // Wait out the batch write delay so the report status can be flushed to the DB. - sleep(config.max_upload_batch_write_delay * 2).await; + // Wait for the report writer to have completed one write task. + runtime_manager + .wait_for_completed_tasks("aggregator", 1) + .await; let got_counters = datastore .run_unnamed_tx(|tx| { @@ -3966,9 +4031,13 @@ mod tests { #[tokio::test] async fn upload_report_public_share_decode_failure() { install_test_trace_subscriber(); - let config = default_aggregator_config(); + let mut runtime_manager = TestRuntimeManager::new(); let (_, aggregator, clock, task, datastore, _ephemeral_datastore) = - setup_upload_test(config).await; + setup_upload_test_with_runtime( + runtime_manager.with_label("aggregator"), + default_aggregator_config(), + ) + .await; let task = task.leader_view().unwrap(); @@ -3996,8 +4065,10 @@ mod tests { } ); - // Wait out the batch write delay so the report status can be flushed to the DB. - sleep(config.max_upload_batch_write_delay * 2).await; + // Wait for the report writer to have completed one write task. + runtime_manager + .wait_for_completed_tasks("aggregator", 1) + .await; let got_counters = datastore .run_unnamed_tx(|tx| { @@ -4015,9 +4086,13 @@ mod tests { #[tokio::test] async fn upload_report_leader_input_share_decode_failure() { install_test_trace_subscriber(); - let config = default_aggregator_config(); + let mut runtime_manager = TestRuntimeManager::new(); let (_, aggregator, clock, task, datastore, _ephemeral_datastore) = - setup_upload_test(config).await; + setup_upload_test_with_runtime( + runtime_manager.with_label("aggregator"), + default_aggregator_config(), + ) + .await; let task = task.leader_view().unwrap(); @@ -4059,8 +4134,10 @@ mod tests { } ); - // Wait out the batch write delay so the report status can be flushed to the DB. - sleep(config.max_upload_batch_write_delay * 2).await; + // Wait for the report writer to have completed one write task. + runtime_manager + .wait_for_completed_tasks("aggregator", 1) + .await; let got_counters = datastore .run_unnamed_tx(|tx| { diff --git a/aggregator/src/aggregator/aggregate_init_tests.rs b/aggregator/src/aggregator/aggregate_init_tests.rs index 56ddb42fe8..125b7c7849 100644 --- a/aggregator/src/aggregator/aggregate_init_tests.rs +++ b/aggregator/src/aggregator/aggregate_init_tests.rs @@ -21,7 +21,9 @@ use janus_aggregator_core::{ }; use janus_core::{ auth_tokens::{AuthenticationToken, DAP_AUTH_HEADER}, - test_util::{dummy_vdaf, install_test_trace_subscriber, run_vdaf, VdafTranscript}, + test_util::{ + dummy_vdaf, install_test_trace_subscriber, run_vdaf, runtime::TestRuntime, VdafTranscript, + }, time::{Clock, MockClock, TimeExt as _}, vdaf::VdafInstance, }; @@ -259,6 +261,7 @@ async fn setup_aggregate_init_test_without_sending_request< let handler = aggregator_handler( Arc::clone(&datastore), clock.clone(), + TestRuntime::default(), &noop_meter(), Config::default(), ) diff --git a/aggregator/src/aggregator/aggregation_job_continue.rs b/aggregator/src/aggregator/aggregation_job_continue.rs index 932d1c5b42..72408e8aa3 100644 --- a/aggregator/src/aggregator/aggregation_job_continue.rs +++ b/aggregator/src/aggregator/aggregation_job_continue.rs @@ -399,7 +399,7 @@ mod tests { test_util::noop_meter, }; use janus_core::{ - test_util::install_test_trace_subscriber, + test_util::{install_test_trace_subscriber, runtime::TestRuntime}, time::{IntervalExt, MockClock}, vdaf::{VdafInstance, VERIFY_KEY_LENGTH}, }; @@ -530,6 +530,7 @@ mod tests { let handler = aggregator_handler( Arc::clone(&datastore), clock, + TestRuntime::default(), &meter, default_aggregator_config(), ) diff --git a/aggregator/src/aggregator/collection_job_tests.rs b/aggregator/src/aggregator/collection_job_tests.rs index a406e396c8..ef6c2bc006 100644 --- a/aggregator/src/aggregator/collection_job_tests.rs +++ b/aggregator/src/aggregator/collection_job_tests.rs @@ -28,6 +28,7 @@ use janus_core::{ test_util::{ dummy_vdaf::{self, AggregationParam}, install_test_trace_subscriber, + runtime::TestRuntime, }, time::{Clock, IntervalExt, MockClock}, vdaf::VdafInstance, @@ -143,6 +144,7 @@ pub(crate) async fn setup_collection_job_test_case( let handler = aggregator_handler( Arc::clone(&datastore), clock.clone(), + TestRuntime::default(), &noop_meter(), Config { batch_aggregation_shard_count: 32, diff --git a/aggregator/src/aggregator/http_handlers.rs b/aggregator/src/aggregator/http_handlers.rs index 7f1d628759..786b86126a 100644 --- a/aggregator/src/aggregator/http_handlers.rs +++ b/aggregator/src/aggregator/http_handlers.rs @@ -8,6 +8,7 @@ use janus_core::{ http::extract_bearer_token, taskprov::TASKPROV_HEADER, time::Clock, + Runtime, }; use janus_messages::{ codec::Decode, problem_type::DapProblemType, query_type::TimeInterval, taskprov::TaskConfig, @@ -243,13 +244,18 @@ pub(crate) static COLLECTION_JOB_ROUTE: &str = "tasks/:task_id/collection_jobs/: pub(crate) static AGGREGATE_SHARES_ROUTE: &str = "tasks/:task_id/aggregate_shares"; /// Constructs a Trillium handler for the aggregator. -pub async fn aggregator_handler( +pub async fn aggregator_handler( datastore: Arc>, clock: C, + runtime: R, meter: &Meter, cfg: Config, -) -> Result { - let aggregator = Arc::new(Aggregator::new(datastore, clock, meter, cfg).await?); +) -> Result +where + C: Clock, + R: Runtime + Send + Sync + 'static, +{ + let aggregator = Arc::new(Aggregator::new(datastore, clock, runtime, meter, cfg).await?); aggregator_handler_with_aggregator(aggregator, meter).await } @@ -735,6 +741,7 @@ mod tests { test_util::{ dummy_vdaf::{self, AggregationParam, OutputShare}, install_test_trace_subscriber, run_vdaf, + runtime::TestRuntime, }, time::{Clock, DurationExt, IntervalExt, MockClock, TimeExt}, vdaf::{VdafInstance, VERIFY_KEY_LENGTH}, @@ -785,6 +792,7 @@ mod tests { let handler = aggregator_handler( datastore.clone(), clock.clone(), + TestRuntime::default(), &noop_meter(), default_aggregator_config(), ) @@ -880,6 +888,7 @@ mod tests { crate::aggregator::Aggregator::new( datastore.clone(), clock.clone(), + TestRuntime::default(), &noop_meter(), Config::default(), ) @@ -1028,6 +1037,7 @@ mod tests { crate::aggregator::Aggregator::new( datastore.clone(), clock.clone(), + TestRuntime::default(), &noop_meter(), cfg, ) @@ -2213,6 +2223,7 @@ mod tests { let handler = aggregator_handler( datastore.clone(), clock.clone(), + TestRuntime::default(), &noop_meter(), default_aggregator_config(), ) diff --git a/aggregator/src/aggregator/report_writer.rs b/aggregator/src/aggregator/report_writer.rs index c248be23db..02c1962a06 100644 --- a/aggregator/src/aggregator/report_writer.rs +++ b/aggregator/src/aggregator/report_writer.rs @@ -6,7 +6,7 @@ use janus_aggregator_core::datastore::{ models::{LeaderStoredReport, TaskUploadIncrementor}, Datastore, Transaction, }; -use janus_core::time::Clock; +use janus_core::{time::Clock, Runtime}; use janus_messages::TaskId; use prio::vdaf; use rand::{thread_rng, Rng}; @@ -27,22 +27,26 @@ type ResultSender = oneshot::Sender>>; type ReportWriteBatcherSender = mpsc::Sender<(ReportResult, Option)>; type ReportWriteBatcherReceiver = mpsc::Receiver<(ReportResult, Option)>; -pub struct ReportWriteBatcher { +pub struct ReportWriteBatcher { report_tx: ReportWriteBatcherSender, } impl ReportWriteBatcher { - pub fn new( + pub fn new( ds: Arc>, + runtime: R, counter_shard_count: u64, max_batch_size: usize, max_batch_write_delay: Duration, ) -> Self { let (report_tx, report_rx) = mpsc::channel(1); - tokio::spawn(async move { + let runtime = Arc::new(runtime); + let runtime_clone = Arc::clone(&runtime); + runtime.spawn(async move { Self::run_upload_batcher( ds, + runtime_clone, report_rx, counter_shard_count, max_batch_size, @@ -86,9 +90,13 @@ impl ReportWriteBatcher { result_rx.await.unwrap() } - #[tracing::instrument(name = "ReportWriteBatcher::run_upload_batcher", skip(ds, report_rx))] - async fn run_upload_batcher( + #[tracing::instrument( + name = "ReportWriteBatcher::run_upload_batcher", + skip(ds, runtime, report_rx) + )] + async fn run_upload_batcher( ds: Arc>, + runtime: Arc, mut report_rx: ReportWriteBatcherReceiver, counter_shard_count: u64, max_batch_size: usize, @@ -131,7 +139,7 @@ impl ReportWriteBatcher { let ds = Arc::clone(&ds); let report_results = replace(&mut report_results, Vec::with_capacity(max_batch_size)); - tokio::spawn(async move { + runtime.spawn(async move { Self::write_batch(ds, counter_shard_count, report_results).await; }); } diff --git a/aggregator/src/aggregator/taskprov_tests.rs b/aggregator/src/aggregator/taskprov_tests.rs index f3a8f5b466..ac7a2250f7 100644 --- a/aggregator/src/aggregator/taskprov_tests.rs +++ b/aggregator/src/aggregator/taskprov_tests.rs @@ -34,7 +34,7 @@ use janus_core::{ }, report_id::ReportIdChecksumExt, taskprov::TASKPROV_HEADER, - test_util::{install_test_trace_subscriber, VdafTranscript}, + test_util::{install_test_trace_subscriber, runtime::TestRuntime, VdafTranscript}, time::{Clock, DurationExt, MockClock, TimeExt}, vdaf::VERIFY_KEY_LENGTH, }; @@ -118,6 +118,7 @@ impl TaskprovTestCase { let handler = aggregator_handler( Arc::clone(&datastore), clock.clone(), + TestRuntime::default(), &noop_meter(), Config { taskprov_config: TaskprovConfig { enabled: true }, diff --git a/aggregator/src/binaries/aggregator.rs b/aggregator/src/binaries/aggregator.rs index dc93066124..affe335a63 100644 --- a/aggregator/src/binaries/aggregator.rs +++ b/aggregator/src/binaries/aggregator.rs @@ -9,7 +9,7 @@ use clap::Parser; use derivative::Derivative; use janus_aggregator_api::{self, aggregator_api_handler}; use janus_aggregator_core::datastore::Datastore; -use janus_core::{auth_tokens::AuthenticationToken, time::RealClock}; +use janus_core::{auth_tokens::AuthenticationToken, time::RealClock, TokioRuntime}; use opentelemetry::metrics::Meter; use serde::{de, Deserialize, Deserializer, Serialize}; use std::{ @@ -63,6 +63,7 @@ async fn run_aggregator( aggregator_handler( Arc::clone(&datastore), clock, + TokioRuntime, &meter, config.aggregator_config(), ) diff --git a/aggregator/src/metrics/tests.rs b/aggregator/src/metrics/tests.rs index d65cbae6c1..6513c8ca4c 100644 --- a/aggregator/src/metrics/tests.rs +++ b/aggregator/src/metrics/tests.rs @@ -4,7 +4,7 @@ use http::StatusCode; use janus_aggregator_core::datastore::test_util::ephemeral_datastore; use janus_core::{ retries::{retry_http_request, test_http_request_exponential_backoff}, - test_util::install_test_trace_subscriber, + test_util::{install_test_trace_subscriber, runtime::TestRuntime}, time::MockClock, }; use opentelemetry::metrics::MeterProvider as _; @@ -86,6 +86,7 @@ async fn http_metrics() { let handler = aggregator_handler( datastore.clone(), clock.clone(), + TestRuntime::default(), &meter, default_aggregator_config(), ) diff --git a/interop_binaries/src/bin/janus_interop_aggregator.rs b/interop_binaries/src/bin/janus_interop_aggregator.rs index 62cc309477..005a60ede2 100644 --- a/interop_binaries/src/bin/janus_interop_aggregator.rs +++ b/interop_binaries/src/bin/janus_interop_aggregator.rs @@ -14,6 +14,7 @@ use janus_aggregator_core::{ use janus_core::{ auth_tokens::{AuthenticationToken, AuthenticationTokenHash}, time::RealClock, + Runtime, TokioRuntime, }; use janus_interop_binaries::{ status::{ERROR, SUCCESS}, @@ -129,8 +130,9 @@ async fn handle_add_task( .context("error adding task to database") } -async fn make_handler( +async fn make_handler( datastore: Arc>, + runtime: R, meter: &Meter, dap_serving_prefix: String, ) -> anyhow::Result { @@ -138,6 +140,7 @@ async fn make_handler( let dap_handler = aggregator_handler( Arc::clone(&datastore), RealClock::default(), + runtime, meter, aggregator::Config { max_upload_batch_size: 100, @@ -247,6 +250,7 @@ async fn main() -> anyhow::Result<()> { // endpoints. let handler = make_handler( Arc::clone(&datastore), + TokioRuntime, &ctx.meter, ctx.config.dap_serving_prefix, )