diff --git a/aggregator/src/aggregator.rs b/aggregator/src/aggregator.rs index 25898d692..195b1c77e 100644 --- a/aggregator/src/aggregator.rs +++ b/aggregator/src/aggregator.rs @@ -5,7 +5,7 @@ use crate::{ aggregator::{ accumulator::Accumulator, aggregate_share::compute_aggregate_share, - error::{handle_ping_pong_error, ReportRejectedReason}, + error::{handle_ping_pong_error, ReportRejection, ReportRejectionReason}, error::{BatchMismatch, OptOutReason}, query_type::{CollectableQueryType, UploadableQueryType}, report_writer::{ReportWriteBatcher, WritableReport}, @@ -48,6 +48,7 @@ use janus_core::{ http::HttpErrorResponse, time::{Clock, DurationExt, IntervalExt, TimeExt}, vdaf::{VdafInstance, VERIFY_KEY_LENGTH}, + Runtime, }; use janus_messages::{ query_type::{FixedSize, TimeInterval}, @@ -187,7 +188,7 @@ pub struct Aggregator { } /// Config represents a configuration for an Aggregator. -#[derive(Debug, PartialEq, Eq)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] pub struct Config { /// Defines the maximum size of a batch of uploaded reports which will be written in a single /// transaction. @@ -203,6 +204,11 @@ pub struct Config { /// the cost of collection. pub batch_aggregation_shard_count: u64, + /// Defines the number of shards to break report counters into. Increasing this value will + /// reduce the amount of database contention during report uploads, while increasing the cost + /// of getting task metrics. + pub task_counter_shard_count: u64, + /// Defines how often to refresh the global HPKE configs cache. This affects how often an aggregator /// becomes aware of key state changes. pub global_hpke_configs_refresh_interval: StdDuration, @@ -216,6 +222,7 @@ impl Default for Config { max_upload_batch_size: 1, max_upload_batch_write_delay: StdDuration::ZERO, batch_aggregation_shard_count: 1, + task_counter_shard_count: 32, global_hpke_configs_refresh_interval: GlobalHpkeKeypairCache::DEFAULT_REFRESH_INTERVAL, taskprov_config: TaskprovConfig::default(), } @@ -223,14 +230,17 @@ impl Default for Config { } impl Aggregator { - async fn new( + async fn new( datastore: Arc>, clock: C, + runtime: R, meter: &Meter, cfg: Config, ) -> Result { let report_writer = Arc::new(ReportWriteBatcher::new( Arc::clone(&datastore), + runtime, + cfg.task_counter_shard_count, cfg.max_upload_batch_size, cfg.max_upload_batch_write_delay, )); @@ -1400,14 +1410,16 @@ impl VdafOps { C: Clock, Q: UploadableQueryType, { - // Shorthand function for generating an Error::ReportRejected with proper parameters. + // Shorthand function for generating an Error::ReportRejected with proper parameters and + // recording it in the report_writer. let reject_report = |reason| { - Arc::new(Error::ReportRejected( - *task.id(), - *report.metadata().id(), - *report.metadata().time(), - reason, - )) + let report_id = *report.metadata().id(); + let report_time = *report.metadata().time(); + async move { + let rejection = ReportRejection::new(*task.id(), report_id, report_time, reason); + report_writer.write_rejection(rejection).await; + Ok::<_, Arc>(Arc::new(Error::ReportRejected(rejection))) + } }; let report_deadline = clock @@ -1418,18 +1430,14 @@ impl VdafOps { // Reject reports from too far in the future. 
// https://www.ietf.org/archive/id/draft-ietf-ppm-dap-07.html#section-4.4.2-21 if report.metadata().time().is_after(&report_deadline) { - return Err(Arc::new(Error::ReportTooEarly( - *task.id(), - *report.metadata().id(), - *report.metadata().time(), - ))); + return Err(reject_report(ReportRejectionReason::TooEarly).await?); } // Reject reports after a task has expired. // https://www.ietf.org/archive/id/draft-ietf-ppm-dap-07.html#section-4.4.2-20 if let Some(task_expiration) = task.task_expiration() { if report.metadata().time().is_after(task_expiration) { - return Err(reject_report(ReportRejectedReason::TaskExpired)); + return Err(reject_report(ReportRejectionReason::TaskExpired).await?); } } @@ -1441,7 +1449,7 @@ impl VdafOps { .add(report_expiry_age) .map_err(|err| Arc::new(Error::from(err)))?; if clock.now().is_after(&report_expiry_time) { - return Err(reject_report(ReportRejectedReason::TooOld)); + return Err(reject_report(ReportRejectionReason::Expired).await?); } } @@ -1461,9 +1469,7 @@ impl VdafOps { "public share decoding failed", ); upload_decode_failure_counter.add(1, &[]); - return Err(reject_report( - ReportRejectedReason::PublicShareDecodeFailure, - )); + return Err(reject_report(ReportRejectionReason::DecodeFailure).await?); } }; @@ -1492,10 +1498,10 @@ impl VdafOps { // Verify that the report's HPKE config ID is known. // https://www.ietf.org/archive/id/draft-ietf-ppm-dap-07.html#section-4.4.2-17 (None, None) => { - return Err(Arc::new(Error::OutdatedHpkeConfig( - *task.id(), + return Err(reject_report(ReportRejectionReason::OutdatedHpkeConfig( *report.leader_encrypted_input_share().config_id(), - ))); + )) + .await?); } (None, Some(global_hpke_keypair)) => try_hpke_open(&global_hpke_keypair), (Some(task_hpke_keypair), None) => try_hpke_open(task_hpke_keypair), @@ -1519,7 +1525,7 @@ impl VdafOps { "Report decryption failed", ); upload_decrypt_failure_counter.add(1, &[]); - return Err(reject_report(ReportRejectedReason::LeaderDecryptFailure)); + return Err(reject_report(ReportRejectionReason::DecryptFailure).await?); } }; @@ -1546,9 +1552,7 @@ impl VdafOps { "Leader input share decoding failed", ); upload_decode_failure_counter.add(1, &[]); - return Err(reject_report( - ReportRejectedReason::LeaderInputShareDecodeFailure, - )); + return Err(reject_report(ReportRejectionReason::DecodeFailure).await?); } }; @@ -1562,7 +1566,9 @@ impl VdafOps { ); report_writer - .write_report(WritableReport::::new(vdaf, report)) + .write_report(Box::new(WritableReport::::new( + vdaf, report, + ))) .await } } @@ -3273,14 +3279,14 @@ pub(crate) mod test_util { #[cfg(test)] mod tests { use crate::aggregator::{ - error::ReportRejectedReason, test_util::default_aggregator_config, Aggregator, Config, + error::ReportRejectionReason, test_util::default_aggregator_config, Aggregator, Config, Error, }; use assert_matches::assert_matches; use futures::future::try_join_all; use janus_aggregator_core::{ datastore::{ - models::{CollectionJob, CollectionJobState}, + models::{CollectionJob, CollectionJobState, TaskUploadCounter}, test_util::{ephemeral_datastore, EphemeralDatastore}, Datastore, }, @@ -3295,9 +3301,13 @@ mod tests { self, test_util::generate_test_hpke_config_and_private_key_with_id, HpkeApplicationInfo, HpkeKeypair, Label, }, - test_util::install_test_trace_subscriber, + test_util::{ + install_test_trace_subscriber, + runtime::{TestRuntime, TestRuntimeManager}, + }, time::{Clock, MockClock, TimeExt}, vdaf::{VdafInstance, VERIFY_KEY_LENGTH}, + Runtime, }; use janus_messages::{ 
query_type::TimeInterval, Duration, Extension, HpkeCiphertext, HpkeConfig, HpkeConfigId, @@ -3367,6 +3377,23 @@ mod tests { Arc>, EphemeralDatastore, ) { + setup_upload_test_with_runtime(TestRuntime::default(), cfg).await + } + + async fn setup_upload_test_with_runtime( + runtime: R, + cfg: Config, + ) -> ( + Prio3Count, + Aggregator, + MockClock, + Task, + Arc>, + EphemeralDatastore, + ) + where + R: Runtime + Send + Sync + 'static, + { let clock = MockClock::default(); let vdaf = Prio3Count::new_count(2).unwrap(); let task = TaskBuilder::new(QueryType::TimeInterval, VdafInstance::Prio3Count).build(); @@ -3378,9 +3405,15 @@ mod tests { datastore.put_aggregator_task(&leader_task).await.unwrap(); - let aggregator = Aggregator::new(Arc::clone(&datastore), clock.clone(), &noop_meter(), cfg) - .await - .unwrap(); + let aggregator = Aggregator::new( + Arc::clone(&datastore), + clock.clone(), + runtime, + &noop_meter(), + cfg, + ) + .await + .unwrap(); ( vdaf, @@ -3442,17 +3475,29 @@ mod tests { .unwrap(); // Verify that the original report, rather than the modified report, is stored. - let got_report = ds + let (got_report, got_counter) = ds .run_unnamed_tx(|tx| { let vdaf = vdaf.clone(); let task_id = *task.id(); let report_id = *report.metadata().id(); - Box::pin(async move { tx.get_client_report(&vdaf, &task_id, &report_id).await }) + Box::pin(async move { + Ok(( + tx.get_client_report(&vdaf, &task_id, &report_id) + .await + .unwrap(), + tx.get_task_upload_counter(&task_id).await.unwrap(), + )) + }) }) .await - .unwrap() .unwrap(); - assert!(got_report.eq_report(&vdaf, leader_task.current_hpke_key(), &report)); + assert!(got_report + .unwrap() + .eq_report(&vdaf, leader_task.current_hpke_key(), &report)); + assert_eq!( + got_counter, + Some(TaskUploadCounter::new(0, 0, 0, 0, 0, 1, 0, 0)) + ) } #[tokio::test] @@ -3497,14 +3542,31 @@ mod tests { .collect(); assert_eq!(want_report_ids, got_report_ids); + + let got_counters = datastore + .run_unnamed_tx(|tx| { + let task_id = *task.id(); + Box::pin(async move { tx.get_task_upload_counter(&task_id).await }) + }) + .await + .unwrap(); + assert_eq!( + got_counters, + Some(TaskUploadCounter::new(0, 0, 0, 0, 0, 100, 0, 0)) + ); } #[tokio::test] async fn upload_wrong_hpke_config_id() { install_test_trace_subscriber(); - let (_, aggregator, clock, task, _, _ephemeral_datastore) = - setup_upload_test(default_aggregator_config()).await; + let mut runtime_manager = TestRuntimeManager::new(); + let (_, aggregator, clock, task, datastore, _ephemeral_datastore) = + setup_upload_test_with_runtime( + runtime_manager.with_label("aggregator"), + default_aggregator_config(), + ) + .await; let leader_task = task.leader_view().unwrap(); let report = create_report(&leader_task, clock.now()); @@ -3527,10 +3589,35 @@ mod tests { report.helper_encrypted_input_share().clone(), ); - assert_matches!(aggregator.handle_upload(task.id(), &report.get_encoded()).await.unwrap_err().as_ref(), Error::OutdatedHpkeConfig(task_id, config_id) => { - assert_eq!(task.id(), task_id); - assert_eq!(config_id, &unused_hpke_config_id); + let result = aggregator + .handle_upload(task.id(), &report.get_encoded()) + .await + .unwrap_err(); + assert_matches!(result.as_ref(), Error::ReportRejected(rejection) => { + assert_eq!(task.id(), rejection.task_id()); + assert_eq!(report.metadata().id(), rejection.report_id()); + assert_eq!(report.metadata().time(), rejection.time()); + assert_matches!(rejection.reason(), ReportRejectionReason::OutdatedHpkeConfig(id) => { + assert_eq!(id, 
&unused_hpke_config_id); + }) }); + + // Wait for the report writer to have completed one write task. + runtime_manager + .wait_for_completed_tasks("aggregator", 1) + .await; + + let got_counters = datastore + .run_unnamed_tx(|tx| { + let task_id = *task.id(); + Box::pin(async move { tx.get_task_upload_counter(&task_id).await }) + }) + .await + .unwrap(); + assert_eq!( + got_counters, + Some(TaskUploadCounter::new(0, 0, 0, 0, 1, 0, 0, 0)) + ) } #[tokio::test] @@ -3560,14 +3647,30 @@ mod tests { .unwrap(); assert_eq!(task.id(), got_report.task_id()); assert_eq!(report.metadata(), got_report.metadata()); + + let got_counters = datastore + .run_unnamed_tx(|tx| { + let task_id = *task.id(); + Box::pin(async move { tx.get_task_upload_counter(&task_id).await }) + }) + .await + .unwrap(); + assert_eq!( + got_counters, + Some(TaskUploadCounter::new(0, 0, 0, 0, 0, 1, 0, 0)) + ) } #[tokio::test] async fn upload_report_in_the_future_past_clock_skew() { install_test_trace_subscriber(); - - let (_, aggregator, clock, task, _, _ephemeral_datastore) = - setup_upload_test(default_aggregator_config()).await; + let mut runtime_manager = TestRuntimeManager::new(); + let (_, aggregator, clock, task, datastore, _ephemeral_datastore) = + setup_upload_test_with_runtime( + runtime_manager.with_label("aggregator"), + default_aggregator_config(), + ) + .await; let report = create_report( &task.leader_view().unwrap(), clock @@ -3582,20 +3685,42 @@ mod tests { .handle_upload(task.id(), &report.get_encoded()) .await .unwrap_err(); - - assert_matches!(upload_error.as_ref(), Error::ReportTooEarly(task_id, report_id, time) => { - assert_eq!(task.id(), task_id); - assert_eq!(report.metadata().id(), report_id); - assert_eq!(report.metadata().time(), time); + assert_matches!(upload_error.as_ref(), Error::ReportRejected(rejection) => { + assert_eq!(task.id(), rejection.task_id()); + assert_eq!(report.metadata().id(), rejection.report_id()); + assert_eq!(report.metadata().time(), rejection.time()); + assert_matches!(rejection.reason(), ReportRejectionReason::TooEarly); }); + + // Wait for the report writer to have completed one write task. + runtime_manager + .wait_for_completed_tasks("aggregator", 1) + .await; + + let got_counters = datastore + .run_unnamed_tx(|tx| { + let task_id = *task.id(); + Box::pin(async move { tx.get_task_upload_counter(&task_id).await }) + }) + .await + .unwrap(); + assert_eq!( + got_counters, + Some(TaskUploadCounter::new(0, 0, 0, 0, 0, 0, 1, 0)) + ) } #[tokio::test] async fn upload_report_for_collected_batch() { install_test_trace_subscriber(); + let mut runtime_manager = TestRuntimeManager::new(); let (_, aggregator, clock, task, datastore, _ephemeral_datastore) = - setup_upload_test(default_aggregator_config()).await; + setup_upload_test_with_runtime( + runtime_manager.with_label("aggregator"), + default_aggregator_config(), + ) + .await; let report = create_report(&task.leader_view().unwrap(), clock.now()); // Insert a collection job for the batch interval including our report. 
@@ -3637,12 +3762,30 @@ mod tests { .unwrap_err(); assert_matches!( error.as_ref(), - Error::ReportRejected(err_task_id, err_report_id, err_time, ReportRejectedReason::IntervalAlreadyCollected) => { - assert_eq!(task.id(), err_task_id); - assert_eq!(report.metadata().id(), err_report_id); - assert_eq!(report.metadata().time(), err_time); + Error::ReportRejected(rejection) => { + assert_eq!(task.id(), rejection.task_id()); + assert_eq!(report.metadata().id(), rejection.report_id()); + assert_eq!(report.metadata().time(), rejection.time()); + assert_matches!(rejection.reason(), ReportRejectionReason::IntervalCollected); } ); + + // Wait for the report writer to have completed one write task. + runtime_manager + .wait_for_completed_tasks("aggregator", 1) + .await; + + let got_counters = datastore + .run_unnamed_tx(|tx| { + let task_id = *task.id(); + Box::pin(async move { tx.get_task_upload_counter(&task_id).await }) + }) + .await + .unwrap(); + assert_eq!( + got_counters, + Some(TaskUploadCounter::new(1, 0, 0, 0, 0, 0, 0, 0)) + ) } #[tokio::test] @@ -3722,6 +3865,294 @@ mod tests { } } + #[tokio::test] + async fn upload_report_task_expired() { + install_test_trace_subscriber(); + let mut runtime_manager = TestRuntimeManager::new(); + let (_, aggregator, clock, _, datastore, _ephemeral_datastore) = + setup_upload_test_with_runtime( + runtime_manager.with_label("aggregator"), + default_aggregator_config(), + ) + .await; + + let task = TaskBuilder::new(QueryType::TimeInterval, VdafInstance::Prio3Count) + .with_task_expiration(Some(clock.now())) + .build() + .leader_view() + .unwrap(); + datastore.put_aggregator_task(&task).await.unwrap(); + + // Advance the clock to expire the task. + clock.advance(&Duration::from_seconds(1)); + let report = create_report(&task, clock.now()); + + // Try to upload the report, verify that we get the expected error. + let error = aggregator + .handle_upload(task.id(), &report.get_encoded()) + .await + .unwrap_err(); + assert_matches!( + error.as_ref(), + Error::ReportRejected(rejection) => { + assert_eq!(task.id(), rejection.task_id()); + assert_eq!(report.metadata().id(), rejection.report_id()); + assert_eq!(report.metadata().time(), rejection.time()); + assert_matches!(rejection.reason(), ReportRejectionReason::TaskExpired); + } + ); + + // Wait for the report writer to have completed one write task. + runtime_manager + .wait_for_completed_tasks("aggregator", 1) + .await; + + let got_counters = datastore + .run_unnamed_tx(|tx| { + let task_id = *task.id(); + Box::pin(async move { tx.get_task_upload_counter(&task_id).await }) + }) + .await + .unwrap(); + assert_eq!( + got_counters, + Some(TaskUploadCounter::new(0, 0, 0, 0, 0, 0, 0, 1)) + ) + } + + #[tokio::test] + async fn upload_report_report_expired() { + install_test_trace_subscriber(); + let mut runtime_manager = TestRuntimeManager::new(); + let (_, aggregator, clock, _, datastore, _ephemeral_datastore) = + setup_upload_test_with_runtime( + runtime_manager.with_label("aggregator"), + default_aggregator_config(), + ) + .await; + + let task = TaskBuilder::new(QueryType::TimeInterval, VdafInstance::Prio3Count) + .with_report_expiry_age(Some(Duration::from_seconds(60))) + .build() + .leader_view() + .unwrap(); + datastore.put_aggregator_task(&task).await.unwrap(); + + let report = create_report(&task, clock.now()); + + // Advance the clock to expire the report. + clock.advance(&Duration::from_seconds(61)); + + // Try to upload the report, verify that we get the expected error. 
+ let error = aggregator + .handle_upload(task.id(), &report.get_encoded()) + .await + .unwrap_err(); + assert_matches!( + error.as_ref(), + Error::ReportRejected(rejection) => { + assert_eq!(task.id(), rejection.task_id()); + assert_eq!(report.metadata().id(), rejection.report_id()); + assert_eq!(report.metadata().time(), rejection.time()); + assert_matches!(rejection.reason(), ReportRejectionReason::Expired); + } + ); + + // Wait for the report writer to have completed one write task. + runtime_manager + .wait_for_completed_tasks("aggregator", 1) + .await; + + let got_counters = datastore + .run_unnamed_tx(|tx| { + let task_id = *task.id(); + Box::pin(async move { tx.get_task_upload_counter(&task_id).await }) + }) + .await + .unwrap(); + assert_eq!( + got_counters, + Some(TaskUploadCounter::new(0, 0, 0, 1, 0, 0, 0, 0)) + ) + } + + #[tokio::test] + async fn upload_report_faulty_encryption() { + install_test_trace_subscriber(); + let mut runtime_manager = TestRuntimeManager::new(); + let (_, aggregator, clock, task, datastore, _ephemeral_datastore) = + setup_upload_test_with_runtime( + runtime_manager.with_label("aggregator"), + default_aggregator_config(), + ) + .await; + + let task = task.leader_view().unwrap(); + + // Encrypt with the wrong key. + let report = create_report_custom( + &task, + clock.now(), + random(), + &generate_test_hpke_config_and_private_key_with_id( + (*task.current_hpke_key().config().id()).into(), + ), + ); + + // Try to upload the report, verify that we get the expected error. + let error = aggregator + .handle_upload(task.id(), &report.get_encoded()) + .await + .unwrap_err(); + assert_matches!( + error.as_ref(), + Error::ReportRejected(rejection) => { + assert_eq!(task.id(), rejection.task_id()); + assert_eq!(report.metadata().id(), rejection.report_id()); + assert_eq!(report.metadata().time(), rejection.time()); + assert_matches!(rejection.reason(), ReportRejectionReason::DecryptFailure); + } + ); + + // Wait for the report writer to have completed one write task. + runtime_manager + .wait_for_completed_tasks("aggregator", 1) + .await; + + let got_counters = datastore + .run_unnamed_tx(|tx| { + let task_id = *task.id(); + Box::pin(async move { tx.get_task_upload_counter(&task_id).await }) + }) + .await + .unwrap(); + assert_eq!( + got_counters, + Some(TaskUploadCounter::new(0, 0, 1, 0, 0, 0, 0, 0)) + ) + } + + #[tokio::test] + async fn upload_report_public_share_decode_failure() { + install_test_trace_subscriber(); + let mut runtime_manager = TestRuntimeManager::new(); + let (_, aggregator, clock, task, datastore, _ephemeral_datastore) = + setup_upload_test_with_runtime( + runtime_manager.with_label("aggregator"), + default_aggregator_config(), + ) + .await; + + let task = task.leader_view().unwrap(); + + let mut report = create_report(&task, clock.now()); + report = Report::new( + report.metadata().clone(), + // Some obviously wrong public share. + Vec::from([0; 10]), + report.leader_encrypted_input_share().clone(), + report.helper_encrypted_input_share().clone(), + ); + + // Try to upload the report, verify that we get the expected error. 
+ let error = aggregator + .handle_upload(task.id(), &report.get_encoded()) + .await + .unwrap_err(); + assert_matches!( + error.as_ref(), + Error::ReportRejected(rejection) => { + assert_eq!(task.id(), rejection.task_id()); + assert_eq!(report.metadata().id(), rejection.report_id()); + assert_eq!(report.metadata().time(), rejection.time()); + assert_matches!(rejection.reason(), ReportRejectionReason::DecodeFailure); + } + ); + + // Wait for the report writer to have completed one write task. + runtime_manager + .wait_for_completed_tasks("aggregator", 1) + .await; + + let got_counters = datastore + .run_unnamed_tx(|tx| { + let task_id = *task.id(); + Box::pin(async move { tx.get_task_upload_counter(&task_id).await }) + }) + .await + .unwrap(); + assert_eq!( + got_counters, + Some(TaskUploadCounter::new(0, 1, 0, 0, 0, 0, 0, 0)) + ) + } + + #[tokio::test] + async fn upload_report_leader_input_share_decode_failure() { + install_test_trace_subscriber(); + let mut runtime_manager = TestRuntimeManager::new(); + let (_, aggregator, clock, task, datastore, _ephemeral_datastore) = + setup_upload_test_with_runtime( + runtime_manager.with_label("aggregator"), + default_aggregator_config(), + ) + .await; + + let task = task.leader_view().unwrap(); + + let mut report = create_report(&task, clock.now()); + report = Report::new( + report.metadata().clone(), + report.public_share().to_vec(), + hpke::seal( + task.current_hpke_key().config(), + &HpkeApplicationInfo::new(&Label::InputShare, &Role::Client, &Role::Leader), + // Some obviously wrong payload. + &PlaintextInputShare::new(Vec::new(), vec![0; 100]).get_encoded(), + &InputShareAad::new( + *task.id(), + report.metadata().clone(), + report.public_share().to_vec(), + ) + .get_encoded(), + ) + .unwrap(), + report.helper_encrypted_input_share().clone(), + ); + + // Try to upload the report, verify that we get the expected error. + let error = aggregator + .handle_upload(task.id(), &report.get_encoded()) + .await + .unwrap_err(); + assert_matches!( + error.as_ref(), + Error::ReportRejected(rejection) => { + assert_eq!(task.id(), rejection.task_id()); + assert_eq!(report.metadata().id(), rejection.report_id()); + assert_eq!(report.metadata().time(), rejection.time()); + assert_matches!(rejection.reason(), ReportRejectionReason::DecodeFailure); + } + ); + + // Wait for the report writer to have completed one write task. 
+ runtime_manager + .wait_for_completed_tasks("aggregator", 1) + .await; + + let got_counters = datastore + .run_unnamed_tx(|tx| { + let task_id = *task.id(); + Box::pin(async move { tx.get_task_upload_counter(&task_id).await }) + }) + .await + .unwrap(); + assert_eq!( + got_counters, + Some(TaskUploadCounter::new(0, 1, 0, 0, 0, 0, 0, 0)) + ) + } + pub(crate) fn generate_helper_report_share>( task_id: TaskId, report_metadata: ReportMetadata, diff --git a/aggregator/src/aggregator/aggregate_init_tests.rs b/aggregator/src/aggregator/aggregate_init_tests.rs index 69f89f79f..6723b68dd 100644 --- a/aggregator/src/aggregator/aggregate_init_tests.rs +++ b/aggregator/src/aggregator/aggregate_init_tests.rs @@ -21,7 +21,9 @@ use janus_aggregator_core::{ }; use janus_core::{ auth_tokens::{AuthenticationToken, DAP_AUTH_HEADER}, - test_util::{dummy_vdaf, install_test_trace_subscriber, run_vdaf, VdafTranscript}, + test_util::{ + dummy_vdaf, install_test_trace_subscriber, run_vdaf, runtime::TestRuntime, VdafTranscript, + }, time::{Clock, MockClock, TimeExt as _}, vdaf::VdafInstance, }; @@ -259,6 +261,7 @@ async fn setup_aggregate_init_test_without_sending_request< let handler = aggregator_handler( Arc::clone(&datastore), clock.clone(), + TestRuntime::default(), &noop_meter(), Config::default(), ) diff --git a/aggregator/src/aggregator/aggregation_job_continue.rs b/aggregator/src/aggregator/aggregation_job_continue.rs index f9fdcd528..88bed44e3 100644 --- a/aggregator/src/aggregator/aggregation_job_continue.rs +++ b/aggregator/src/aggregator/aggregation_job_continue.rs @@ -399,7 +399,7 @@ mod tests { test_util::noop_meter, }; use janus_core::{ - test_util::install_test_trace_subscriber, + test_util::{install_test_trace_subscriber, runtime::TestRuntime}, time::{IntervalExt, MockClock}, vdaf::{VdafInstance, VERIFY_KEY_LENGTH}, }; @@ -530,6 +530,7 @@ mod tests { let handler = aggregator_handler( Arc::clone(&datastore), clock, + TestRuntime::default(), &meter, default_aggregator_config(), ) diff --git a/aggregator/src/aggregator/collection_job_tests.rs b/aggregator/src/aggregator/collection_job_tests.rs index befc165a0..96ac667f3 100644 --- a/aggregator/src/aggregator/collection_job_tests.rs +++ b/aggregator/src/aggregator/collection_job_tests.rs @@ -28,6 +28,7 @@ use janus_core::{ test_util::{ dummy_vdaf::{self, AggregationParam}, install_test_trace_subscriber, + runtime::TestRuntime, }, time::{Clock, IntervalExt, MockClock}, vdaf::VdafInstance, @@ -143,6 +144,7 @@ pub(crate) async fn setup_collection_job_test_case( let handler = aggregator_handler( Arc::clone(&datastore), clock.clone(), + TestRuntime::default(), &noop_meter(), Config { batch_aggregation_shard_count: 32, diff --git a/aggregator/src/aggregator/error.rs b/aggregator/src/aggregator/error.rs index 1feb4d494..855e78171 100644 --- a/aggregator/src/aggregator/error.rs +++ b/aggregator/src/aggregator/error.rs @@ -1,4 +1,7 @@ -use janus_aggregator_core::{datastore, task}; +use janus_aggregator_core::{ + datastore::{self, models::TaskUploadIncrementor}, + task, +}; use janus_core::http::HttpErrorResponse; use janus_messages::{ AggregationJobId, AggregationJobStep, CollectionJobId, HpkeConfigId, Interval, PrepareError, @@ -29,14 +32,9 @@ pub enum Error { /// Error handling a message. #[error("invalid message: {0}")] Message(#[from] janus_messages::Error), - /// Corresponds to `reportRejected` in DAP. A report was rejected for some reason that is not - /// specified in DAP. 
- #[error("task {0}: report {1} rejected: {2}")] - ReportRejected(TaskId, ReportId, Time, ReportRejectedReason), - /// Corresponds to `reportTooEarly` in DAP. A report was rejected because the timestamp is too - /// far in the future. - #[error("task {0}: report {1} too early: {2}")] - ReportTooEarly(TaskId, ReportId, Time), + /// Catch-all error for invalid reports. + #[error("{0}")] + ReportRejected(ReportRejection), /// Corresponds to `invalidMessage` in DAP. #[error("task {0:?}: invalid message: {1}")] InvalidMessage(Option, &'static str), @@ -69,9 +67,6 @@ pub enum Error { /// An attempt was made to act on a collection job that has been abandoned by the aggregator. #[error("abandoned collection job: {0}")] AbandonedCollectionJob(CollectionJobId), - /// Corresponds to `outdatedHpkeConfig` in DAP. - #[error("task {0}: outdated HPKE config: {1}")] - OutdatedHpkeConfig(TaskId, HpkeConfigId), /// Corresponds to `unauthorizedRequest` in DAP. #[error("task {0}: unauthorized request")] UnauthorizedRequest(TaskId), @@ -145,37 +140,109 @@ pub enum Error { DifferentialPrivacy(VdafError), } -#[derive(Debug)] -pub enum ReportRejectedReason { - IntervalAlreadyCollected, - LeaderDecryptFailure, - LeaderInputShareDecodeFailure, - PublicShareDecodeFailure, +/// Contains details that describe the report and why it was rejected. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct ReportRejection { + task_id: TaskId, + report_id: ReportId, + time: Time, + reason: ReportRejectionReason, +} + +impl ReportRejection { + pub fn new( + task_id: TaskId, + report_id: ReportId, + time: Time, + reason: ReportRejectionReason, + ) -> Self { + Self { + task_id, + report_id, + time, + reason, + } + } + + pub fn task_id(&self) -> &TaskId { + &self.task_id + } + + pub fn report_id(&self) -> &ReportId { + &self.report_id + } + + pub fn time(&self) -> &Time { + &self.time + } + + pub fn reason(&self) -> &ReportRejectionReason { + &self.reason + } +} + +impl Display for ReportRejection { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + write!( + f, + "task {}, report {}, time {}, rejected {}", + self.task_id, self.report_id, self.time, self.reason + ) + } +} + +/// Indicates why a report was rejected. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum ReportRejectionReason { + IntervalCollected, + DecryptFailure, + DecodeFailure, TaskExpired, - TooOld, + Expired, + TooEarly, + OutdatedHpkeConfig(HpkeConfigId), } -impl ReportRejectedReason { +impl ReportRejectionReason { pub fn detail(&self) -> &'static str { match self { - ReportRejectedReason::IntervalAlreadyCollected => { + ReportRejectionReason::IntervalCollected => { "Report falls into a time interval that has already been collected." } - ReportRejectedReason::LeaderDecryptFailure => { - "Leader's report share could not be decrypted." - } - ReportRejectedReason::LeaderInputShareDecodeFailure => { - "Leader's input share could not be decoded." + ReportRejectionReason::DecryptFailure => "Report share could not be decrypted.", + ReportRejectionReason::DecodeFailure => "Report could not be decoded.", + ReportRejectionReason::TaskExpired => "Task has expired.", + ReportRejectionReason::Expired => "Report timestamp is too old.", + ReportRejectionReason::TooEarly => "Report timestamp is too far in the future.", + ReportRejectionReason::OutdatedHpkeConfig(_) => { + "Report is using an outdated HPKE configuration." } - ReportRejectedReason::PublicShareDecodeFailure => { - "Report public share could not be decoded." 
+ } + } +} + +impl From<&ReportRejectionReason> for TaskUploadIncrementor { + fn from(value: &ReportRejectionReason) -> Self { + match value { + ReportRejectionReason::IntervalCollected => TaskUploadIncrementor::IntervalCollected, + ReportRejectionReason::DecryptFailure => TaskUploadIncrementor::ReportDecryptFailure, + ReportRejectionReason::DecodeFailure => TaskUploadIncrementor::ReportDecodeFailure, + ReportRejectionReason::TaskExpired => TaskUploadIncrementor::TaskExpired, + ReportRejectionReason::Expired => TaskUploadIncrementor::ReportExpired, + ReportRejectionReason::TooEarly => TaskUploadIncrementor::ReportTooEarly, + ReportRejectionReason::OutdatedHpkeConfig(_) => { + TaskUploadIncrementor::ReportOutdatedKey } - ReportRejectedReason::TaskExpired => "Task has expired.", - ReportRejectedReason::TooOld => "Report timestamp is too old.", } } } +impl Display for ReportRejectionReason { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + write!(f, "{:?}", self) + } +} + /// Errors that cause the aggregator to opt-out of a taskprov task. #[derive(Debug, thiserror::Error)] pub enum OptOutReason { @@ -199,8 +266,11 @@ impl Error { Error::InvalidConfiguration(_) => "invalid_configuration", Error::MessageDecode(_) => "message_decode", Error::Message(_) => "message", - Error::ReportRejected(_, _, _, _) => "report_rejected", - Error::ReportTooEarly(_, _, _) => "report_too_early", + Error::ReportRejected(rejection) => match rejection.reason { + ReportRejectionReason::TooEarly => "report_too_early", + ReportRejectionReason::OutdatedHpkeConfig(_) => "outdated_hpke_config", + _ => "report_rejected", + }, Error::InvalidMessage(_, _) => "unrecognized_message", Error::StepMismatch { .. } => "step_mismatch", Error::UnrecognizedTask(_) => "unrecognized_task", @@ -209,7 +279,6 @@ impl Error { Error::DeletedCollectionJob(_) => "deleted_collection_job", Error::AbandonedCollectionJob(_) => "abandoned_collection_job", Error::UnrecognizedCollectionJob(_) => "unrecognized_collection_job", - Error::OutdatedHpkeConfig(_, _) => "outdated_hpke_config", Error::UnauthorizedRequest(_) => "unauthorized_request", Error::Datastore(_) => "datastore", Error::Vdaf(_) => "vdaf", diff --git a/aggregator/src/aggregator/http_handlers.rs b/aggregator/src/aggregator/http_handlers.rs index 38ce77886..d0769b1d7 100644 --- a/aggregator/src/aggregator/http_handlers.rs +++ b/aggregator/src/aggregator/http_handlers.rs @@ -1,4 +1,4 @@ -use super::{Aggregator, Config, Error}; +use super::{error::ReportRejectionReason, Aggregator, Config, Error}; use crate::aggregator::problem_details::{ProblemDetailsConnExt, ProblemDocument}; use async_trait::async_trait; use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine}; @@ -8,6 +8,7 @@ use janus_core::{ http::extract_bearer_token, taskprov::TASKPROV_HEADER, time::Clock, + Runtime, }; use janus_messages::{ codec::Decode, problem_type::DapProblemType, query_type::TimeInterval, taskprov::TaskConfig, @@ -45,11 +46,21 @@ impl Handler for Error { Error::InvalidConfiguration(_) => conn.with_status(Status::InternalServerError), Error::MessageDecode(_) => conn .with_problem_document(&ProblemDocument::new_dap(DapProblemType::InvalidMessage)), - Error::ReportRejected(task_id, _, _, reason) => conn.with_problem_document( - &ProblemDocument::new_dap(DapProblemType::ReportRejected) - .with_task_id(task_id) - .with_detail(reason.detail()), - ), + Error::ReportRejected(rejection) => match rejection.reason() { + ReportRejectionReason::OutdatedHpkeConfig(_) => conn.with_problem_document( + 
&ProblemDocument::new_dap(DapProblemType::OutdatedConfig) + .with_task_id(rejection.task_id()), + ), + ReportRejectionReason::TooEarly => conn.with_problem_document( + &ProblemDocument::new_dap(DapProblemType::ReportTooEarly) + .with_task_id(rejection.task_id()), + ), + _ => conn.with_problem_document( + &ProblemDocument::new_dap(DapProblemType::ReportRejected) + .with_task_id(rejection.task_id()) + .with_detail(rejection.reason().detail()), + ), + }, Error::InvalidMessage(task_id, _) => { let mut doc = ProblemDocument::new_dap(DapProblemType::InvalidMessage); if let Some(task_id) = task_id { @@ -85,12 +96,6 @@ impl Handler for Error { .with_collection_job_id(collection_job_id), ), Error::UnrecognizedCollectionJob(_) => conn.with_status(Status::NotFound), - Error::OutdatedHpkeConfig(task_id, _) => conn.with_problem_document( - &ProblemDocument::new_dap(DapProblemType::OutdatedConfig).with_task_id(task_id), - ), - Error::ReportTooEarly(task_id, _, _) => conn.with_problem_document( - &ProblemDocument::new_dap(DapProblemType::ReportTooEarly).with_task_id(task_id), - ), Error::UnauthorizedRequest(task_id) => conn.with_problem_document( &ProblemDocument::new_dap(DapProblemType::UnauthorizedRequest) .with_task_id(task_id), @@ -239,13 +244,18 @@ pub(crate) static COLLECTION_JOB_ROUTE: &str = "tasks/:task_id/collection_jobs/: pub(crate) static AGGREGATE_SHARES_ROUTE: &str = "tasks/:task_id/aggregate_shares"; /// Constructs a Trillium handler for the aggregator. -pub async fn aggregator_handler( +pub async fn aggregator_handler( datastore: Arc>, clock: C, + runtime: R, meter: &Meter, cfg: Config, -) -> Result { - let aggregator = Arc::new(Aggregator::new(datastore, clock, meter, cfg).await?); +) -> Result +where + C: Clock, + R: Runtime + Send + Sync + 'static, +{ + let aggregator = Arc::new(Aggregator::new(datastore, clock, runtime, meter, cfg).await?); aggregator_handler_with_aggregator(aggregator, meter).await } @@ -687,7 +697,7 @@ mod tests { }, collection_job_tests::setup_collection_job_test_case, empty_batch_aggregations, - error::ReportRejectedReason, + error::ReportRejectionReason, http_handlers::{ aggregator_handler, aggregator_handler_with_aggregator, test_util::{decode_response_body, take_problem_details}, @@ -731,6 +741,7 @@ mod tests { test_util::{ dummy_vdaf::{self, AggregationParam, OutputShare}, install_test_trace_subscriber, run_vdaf, + runtime::TestRuntime, }, time::{Clock, DurationExt, IntervalExt, MockClock, TimeExt}, vdaf::{VdafInstance, VERIFY_KEY_LENGTH}, @@ -781,6 +792,7 @@ mod tests { let handler = aggregator_handler( datastore.clone(), clock.clone(), + TestRuntime::default(), &noop_meter(), default_aggregator_config(), ) @@ -876,6 +888,7 @@ mod tests { crate::aggregator::Aggregator::new( datastore.clone(), clock.clone(), + TestRuntime::default(), &noop_meter(), Config::default(), ) @@ -1024,6 +1037,7 @@ mod tests { crate::aggregator::Aggregator::new( datastore.clone(), clock.clone(), + TestRuntime::default(), &noop_meter(), cfg, ) @@ -1201,7 +1215,7 @@ mod tests { "reportRejected", "Report could not be processed.", task.id(), - Some(ReportRejectedReason::TooOld.detail()), + Some(ReportRejectionReason::Expired.detail()), ) .await; @@ -1291,7 +1305,7 @@ mod tests { "reportRejected", "Report could not be processed.", task_expire_soon.id(), - Some(ReportRejectedReason::TaskExpired.detail()), + Some(ReportRejectionReason::TaskExpired.detail()), ) .await; @@ -1319,7 +1333,7 @@ mod tests { "reportRejected", "Report could not be processed.", leader_task.id(), - 
Some(ReportRejectedReason::PublicShareDecodeFailure.detail()), + Some(ReportRejectionReason::DecodeFailure.detail()), ) .await; @@ -1344,7 +1358,7 @@ mod tests { "reportRejected", "Report could not be processed.", leader_task.id(), - Some(ReportRejectedReason::LeaderDecryptFailure.detail()), + Some(ReportRejectionReason::DecryptFailure.detail()), ) .await; @@ -1381,7 +1395,7 @@ mod tests { "reportRejected", "Report could not be processed.", leader_task.id(), - Some(ReportRejectedReason::LeaderInputShareDecodeFailure.detail()), + Some(ReportRejectionReason::DecodeFailure.detail()), ) .await; @@ -2203,6 +2217,7 @@ mod tests { let handler = aggregator_handler( datastore.clone(), clock.clone(), + TestRuntime::default(), &noop_meter(), default_aggregator_config(), ) diff --git a/aggregator/src/aggregator/problem_details.rs b/aggregator/src/aggregator/problem_details.rs index f2cccd775..230a25f4d 100644 --- a/aggregator/src/aggregator/problem_details.rs +++ b/aggregator/src/aggregator/problem_details.rs @@ -109,7 +109,7 @@ impl ProblemDetailsConnExt for Conn { #[cfg(test)] mod tests { use crate::aggregator::{ - error::{BatchMismatch, ReportRejectedReason}, + error::{BatchMismatch, ReportRejection, ReportRejectionReason}, send_request_to_helper, Error, }; use assert_matches::assert_matches; @@ -119,7 +119,7 @@ mod tests { use janus_core::time::{Clock, RealClock}; use janus_messages::{ problem_type::{DapProblemType, DapProblemTypeParseError}, - Duration, HpkeConfigId, Interval, ReportIdChecksum, + Duration, Interval, ReportIdChecksum, }; use opentelemetry::metrics::Unit; use rand::random; @@ -179,15 +179,37 @@ mod tests { TestCase::new(Box::new(|| Error::InvalidConfiguration("test")), None), TestCase::new( Box::new(|| { - Error::ReportRejected( + Error::ReportRejected(ReportRejection::new( random(), random(), RealClock::default().now(), - ReportRejectedReason::TaskExpired - ) + ReportRejectionReason::TaskExpired + )) }), Some(DapProblemType::ReportRejected), ), + TestCase::new( + Box::new(|| { + Error::ReportRejected(ReportRejection::new( + random(), + random(), + RealClock::default().now(), + ReportRejectionReason::TooEarly + )) + }), + Some(DapProblemType::ReportTooEarly), + ), + TestCase::new( + Box::new(|| { + Error::ReportRejected(ReportRejection::new( + random(), + random(), + RealClock::default().now(), + ReportRejectionReason::OutdatedHpkeConfig(random()), + )) + }), + Some(DapProblemType::OutdatedConfig), + ), TestCase::new( Box::new(|| Error::InvalidMessage(Some(random()), "test")), Some(DapProblemType::InvalidMessage), @@ -204,16 +226,6 @@ mod tests { Box::new(|| Error::UnrecognizedAggregationJob(random(), random())), Some(DapProblemType::UnrecognizedAggregationJob), ), - TestCase::new( - Box::new(|| Error::OutdatedHpkeConfig(random(), HpkeConfigId::from(0))), - Some(DapProblemType::OutdatedConfig), - ), - TestCase::new( - Box::new(|| { - Error::ReportTooEarly(random(), random(), RealClock::default().now()) - }), - Some(DapProblemType::ReportTooEarly), - ), TestCase::new( Box::new(|| Error::UnauthorizedRequest(random())), Some(DapProblemType::UnauthorizedRequest), diff --git a/aggregator/src/aggregator/query_type.rs b/aggregator/src/aggregator/query_type.rs index 2177021f0..365c7e0af 100644 --- a/aggregator/src/aggregator/query_type.rs +++ b/aggregator/src/aggregator/query_type.rs @@ -1,4 +1,7 @@ -use super::{error::ReportRejectedReason, Error}; +use super::{ + error::{ReportRejection, ReportRejectionReason}, + Error, +}; use async_trait::async_trait; use janus_aggregator_core::{ 
datastore::{self, models::LeaderStoredReport, Transaction}, @@ -23,7 +26,7 @@ pub trait UploadableQueryType: QueryType { tx: &Transaction<'_, C>, vdaf: &A, report: &LeaderStoredReport, - ) -> Result<(), datastore::Error> + ) -> Result<(), Error> where A::InputShare: Send + Sync, A::PublicShare: Send + Sync; @@ -39,7 +42,7 @@ impl UploadableQueryType for TimeInterval { tx: &Transaction<'_, C>, vdaf: &A, report: &LeaderStoredReport, - ) -> Result<(), datastore::Error> + ) -> Result<(), Error> where A::InputShare: Send + Sync, A::PublicShare: Send + Sync, @@ -55,15 +58,12 @@ impl UploadableQueryType for TimeInterval { ) .await?; if !conflicting_collect_jobs.is_empty() { - return Err(datastore::Error::User( - Error::ReportRejected( - *report.task_id(), - *report.metadata().id(), - *report.metadata().time(), - ReportRejectedReason::IntervalAlreadyCollected, - ) - .into(), - )); + return Err(Error::ReportRejected(ReportRejection::new( + *report.task_id(), + *report.metadata().id(), + *report.metadata().time(), + ReportRejectionReason::IntervalCollected, + ))); } Ok(()) } @@ -79,7 +79,7 @@ impl UploadableQueryType for FixedSize { _: &Transaction<'_, C>, _: &A, _: &LeaderStoredReport, - ) -> Result<(), datastore::Error> { + ) -> Result<(), Error> { // Fixed-size tasks associate reports to batches at time of aggregation rather than at time // of upload, and there are no other relevant checks to apply here, so this method simply // returns Ok(()). diff --git a/aggregator/src/aggregator/report_writer.rs b/aggregator/src/aggregator/report_writer.rs index de1bae6c2..f79064950 100644 --- a/aggregator/src/aggregator/report_writer.rs +++ b/aggregator/src/aggregator/report_writer.rs @@ -1,9 +1,15 @@ use crate::aggregator::{query_type::UploadableQueryType, Error}; use async_trait::async_trait; use futures::future::join_all; -use janus_aggregator_core::datastore::{self, models::LeaderStoredReport, Datastore, Transaction}; -use janus_core::time::Clock; +use janus_aggregator_core::datastore::{ + self, + models::{LeaderStoredReport, TaskUploadIncrementor}, + Datastore, Transaction, +}; +use janus_core::{time::Clock, Runtime}; +use janus_messages::TaskId; use prio::vdaf; +use rand::{thread_rng, Rng}; use std::{fmt::Debug, marker::PhantomData, mem::replace, sync::Arc, time::Duration}; use tokio::{ select, @@ -12,62 +18,93 @@ use tokio::{ }; use tracing::debug; -type ReportWriteBatcherSender = mpsc::Sender<( - Box>, - oneshot::Sender>>, -)>; -type ReportWriteBatcherReceiver = mpsc::Receiver<( - Box>, - oneshot::Sender>>, -)>; +use super::error::ReportRejection; + +type ReportResult = Result>, ReportRejection>; + +type ResultSender = oneshot::Sender>>; + +type ReportWriteBatcherSender = mpsc::Sender<(ReportResult, Option)>; +type ReportWriteBatcherReceiver = mpsc::Receiver<(ReportResult, Option)>; -pub struct ReportWriteBatcher { +pub struct ReportWriteBatcher { report_tx: ReportWriteBatcherSender, } impl ReportWriteBatcher { - pub fn new( + pub fn new( ds: Arc>, + runtime: R, + counter_shard_count: u64, max_batch_size: usize, max_batch_write_delay: Duration, ) -> Self { let (report_tx, report_rx) = mpsc::channel(1); - tokio::spawn(async move { - Self::run_upload_batcher(ds, report_rx, max_batch_size, max_batch_write_delay).await + let runtime = Arc::new(runtime); + let runtime_clone = Arc::clone(&runtime); + runtime.spawn(async move { + Self::run_upload_batcher( + ds, + runtime_clone, + report_rx, + counter_shard_count, + max_batch_size, + max_batch_write_delay, + ) + .await }); Self { report_tx } } - pub async fn 
write_report<R: ReportWriter<C> + 'static>(
+    /// Save a report rejection to the database.
+    ///
+    /// This function does not wait for the result of the batch write, because we do not want
+    /// clients to retry bad reports, even in the case of a server error.
+    pub async fn write_rejection(&self, report_rejection: ReportRejection) {
+        // Unwrap safety: report_rx is not dropped until ReportWriteBatcher is dropped.
+        self.report_tx
+            .send((Err(report_rejection), None))
+            .await
+            .unwrap();
+    }
+
+    /// Save a report to the database.
+    ///
+    /// This function waits for and returns the result of the batch write.
+    pub async fn write_report(
         &self,
-        report: R,
+        report_writer: Box<dyn ReportWriter<C>>,
     ) -> Result<(), Arc<Error>> {
         // Send report to be written.
         // Unwrap safety: report_rx is not dropped until ReportWriteBatcher is dropped.
-        let (rslt_tx, rslt_rx) = oneshot::channel();
+        let (result_tx, result_rx) = oneshot::channel();
         self.report_tx
-            .send((Box::new(report), rslt_tx))
+            .send((Ok(report_writer), Some(result_tx)))
             .await
             .unwrap();

         // Await the result of writing the report.
-        // Unwrap safety: rslt_tx is always sent on before being dropped, and is never closed.
-        rslt_rx.await.unwrap()
+        // Unwrap safety: result_tx is always sent on before being dropped, and is never closed.
+        result_rx.await.unwrap()
     }

-    #[tracing::instrument(skip(ds, report_rx))]
-    async fn run_upload_batcher(
+    #[tracing::instrument(
+        name = "ReportWriteBatcher::run_upload_batcher",
+        skip(ds, runtime, report_rx)
+    )]
+    async fn run_upload_batcher<R: Runtime + Send + Sync + 'static>(
         ds: Arc<Datastore<C>>,
+        runtime: Arc<R>,
         mut report_rx: ReportWriteBatcherReceiver<C>,
+        counter_shard_count: u64,
         max_batch_size: usize,
         max_batch_write_delay: Duration,
     ) {
         let mut is_done = false;
         let mut batch_expiry = Instant::now();
-        let mut report_writers = Vec::with_capacity(max_batch_size);
-        let mut result_txs = Vec::with_capacity(max_batch_size);
+        let mut report_results = Vec::with_capacity(max_batch_size);
         while !is_done {
             // Wait for an event of interest.
             let write_batch = select! {
@@ -76,36 +113,34 @@ impl<C: Clock> ReportWriteBatcher<C> {
                 item = report_rx.recv() => {
                     match item {
                         // We got an item. Add it to the current batch of reports to be written.
-                        Some((report_writer, rslt_tx)) => {
-                            if report_writers.is_empty() {
+                        Some(report) => {
+                            if report_results.is_empty() {
                                 batch_expiry = Instant::now() + max_batch_write_delay;
                             }
-                            report_writers.push(report_writer);
-                            result_txs.push(rslt_tx);
-                            report_writers.len() >= max_batch_size
+                            report_results.push(report);
+                            report_results.len() >= max_batch_size
                         }

                         // The channel is closed. Note this, and write any final reports that may be
                         // batched before shutting down.
                         None => {
                             is_done = true;
-                            !report_writers.is_empty()
+                            !report_results.is_empty()
                         },
                     }
                 },

                 // ... or the current batch, if there is one, times out.
-                _ = sleep_until(batch_expiry), if !report_writers.is_empty() => true,
+                _ = sleep_until(batch_expiry), if !report_results.is_empty() => true,
             };

             // If the event made us want to write the current batch to storage, do so.
            if write_batch {
                 let ds = Arc::clone(&ds);
-                let result_writers =
-                    replace(&mut report_writers, Vec::with_capacity(max_batch_size));
-                let result_txs = replace(&mut result_txs, Vec::with_capacity(max_batch_size));
-                tokio::spawn(async move {
-                    Self::write_batch(ds, result_writers, result_txs).await;
+                let report_results =
+                    replace(&mut report_results, Vec::with_capacity(max_batch_size));
+                runtime.spawn(async move {
+                    Self::write_batch(ds, counter_shard_count, report_results).await;
                 });
             }
         }
     }

@@ -114,48 +149,71 @@ impl<C: Clock> ReportWriteBatcher<C> {
     #[tracing::instrument(skip_all)]
     async fn write_batch(
         ds: Arc<Datastore<C>>,
-        report_writers: Vec<Box<dyn ReportWriter<C>>>,
-        result_txs: Vec<oneshot::Sender<Result<(), Arc<Error>>>>,
+        counter_shard_count: u64,
+        mut report_results: Vec<(ReportResult<C>, Option<ResultSender>)>,
     ) {
-        // Check preconditions.
-        assert_eq!(report_writers.len(), result_txs.len());
+        let ord = thread_rng().gen_range(0..counter_shard_count);
+
+        // Sort by task ID to prevent deadlocks with concurrently running transactions. Since we are
+        // using the same ord for all statements, we do not need to sort by ord.
+        report_results.sort_unstable_by_key(|writer| match &writer.0 {
+            Ok(report_writer) => *report_writer.task_id(),
+            Err(rejection) => *rejection.task_id(),
+        });

         // Run all report writes concurrently.
-        let report_writers = Arc::new(report_writers);
-        let rslts = ds
+        let (report_results, result_senders): (Vec<ReportResult<C>>, Vec<Option<ResultSender>>) =
+            report_results.into_iter().unzip();
+        let report_results = Arc::new(report_results);
+        let results = ds
             .run_tx("upload", |tx| {
-                let report_writers = Arc::clone(&report_writers);
+                let report_results = Arc::clone(&report_results);
                 Box::pin(async move {
-                    Ok(join_all(report_writers.iter().map(|rw| rw.write_report(tx))).await)
+                    Ok(
+                        join_all(report_results.iter().map(|report_result| async move {
+                            match report_result {
+                                Ok(report_writer) => report_writer.write_report(tx, ord).await,
+                                Err(rejection) => {
+                                    tx.increment_task_upload_counter(
+                                        rejection.task_id(),
+                                        ord,
+                                        &rejection.reason().into(),
+                                    )
+                                    .await?;
+                                    Ok(())
+                                }
+                            }
+                        }))
+                        .await,
+                    )
                 })
             })
             .await;

-        match rslts {
-            Ok(rslts) => {
+        match results {
+            Ok(results) => {
                 // Individual, per-request results.
-                assert_eq!(result_txs.len(), rslts.len()); // sanity check: should be guaranteed.
-                for (rslt_tx, rslt) in result_txs.into_iter().zip(rslts.into_iter()) {
-                    if rslt_tx
-                        .send(rslt.map_err(|err| Arc::new(Error::from(err))))
-                        .is_err()
-                    {
-                        debug!(
-                            "ReportWriter couldn't send result to requester (request cancelled?)"
-                        );
+                assert_eq!(result_senders.len(), results.len()); // sanity check: should be guaranteed.
+                for (result_tx, result) in result_senders.into_iter().zip(results.into_iter()) {
+                    if let Some(result_tx) = result_tx {
+                        if result_tx.send(result.map_err(Arc::new)).is_err() {
+                            debug!(
+                                "ReportWriter couldn't send result to requester (request cancelled?)"
+                            );
+                        }
+                    }
                 }
             }
             Err(err) => {
                 // Total-transaction failures are given to all waiting report uploaders.
                let err = Arc::new(Error::from(err));
-                for rslt_tx in result_txs.into_iter() {
-                    if rslt_tx.send(Err(Arc::clone(&err))).is_err() {
+                result_senders.into_iter().flatten().for_each(|result_tx| {
+                    if result_tx.send(Err(Arc::clone(&err))).is_err() {
                         debug!(
                             "ReportWriter couldn't send result to requester (request cancelled?)"
                         );
                     };
-                }
+                })
             }
         };
     }

@@ -163,7 +221,8 @@ impl<C: Clock> ReportWriteBatcher<C> {

 #[async_trait]
 pub trait ReportWriter<C: Clock>: Debug + Send + Sync {
-    async fn write_report(&self, tx: &Transaction<C>) -> Result<(), datastore::Error>;
+    fn task_id(&self) -> &TaskId;
+    async fn write_report(&self, tx: &Transaction<C>, ord: u64) -> Result<(), Error>;
 }

 #[derive(Debug)]
@@ -210,19 +269,45 @@
 where
     C: Clock,
     Q: UploadableQueryType,
 {
-    async fn write_report(&self, tx: &Transaction<C>) -> Result<(), datastore::Error> {
-        Q::validate_uploaded_report(tx, self.vdaf.as_ref(), &self.report).await?;
-
-        // Store the report.
-        match tx
-            .put_client_report::<SEED_SIZE, A>(&self.vdaf, &self.report)
-            .await
-        {
-            // If the report already existed in the datastore, assume it is a duplicate and return
-            // OK.
-            Ok(()) | Err(datastore::Error::MutationTargetAlreadyExists) => Ok(()),
+    fn task_id(&self) -> &TaskId {
+        self.report.task_id()
+    }

-            err => err,
+    async fn write_report(&self, tx: &Transaction<C>, ord: u64) -> Result<(), Error> {
+        // Some validation requires querying the database, so it is still possible to reject a
+        // report at this stage.
+        match Q::validate_uploaded_report(tx, self.vdaf.as_ref(), &self.report).await {
+            Ok(_) => {
+                let result = tx
+                    .put_client_report::<SEED_SIZE, A>(&self.vdaf, &self.report)
+                    .await;
+                match result {
+                    Ok(_) => {
+                        tx.increment_task_upload_counter(
+                            self.report.task_id(),
+                            ord,
+                            &TaskUploadIncrementor::ReportSuccess,
+                        )
+                        .await?;
+                        Ok(())
+                    }
+                    // Assume this was a duplicate report; return OK, but don't increment the
+                    // counter, so that we avoid double-counting successful reports.
+ Err(datastore::Error::MutationTargetAlreadyExists) => Ok(()), + Err(error) => Err(error.into()), + } + } + Err(error) => { + if let Error::ReportRejected(rejection) = error { + tx.increment_task_upload_counter( + rejection.task_id(), + ord, + &rejection.reason().into(), + ) + .await?; + } + Err(error) + } } } } diff --git a/aggregator/src/aggregator/taskprov_tests.rs b/aggregator/src/aggregator/taskprov_tests.rs index 0c62b5186..c6e5be09e 100644 --- a/aggregator/src/aggregator/taskprov_tests.rs +++ b/aggregator/src/aggregator/taskprov_tests.rs @@ -34,7 +34,7 @@ use janus_core::{ }, report_id::ReportIdChecksumExt, taskprov::TASKPROV_HEADER, - test_util::{install_test_trace_subscriber, VdafTranscript}, + test_util::{install_test_trace_subscriber, runtime::TestRuntime, VdafTranscript}, time::{Clock, DurationExt, MockClock, TimeExt}, vdaf::VERIFY_KEY_LENGTH, }; @@ -118,6 +118,7 @@ impl TaskprovTestCase { let handler = aggregator_handler( Arc::clone(&datastore), clock.clone(), + TestRuntime::default(), &noop_meter(), Config { taskprov_config: TaskprovConfig { enabled: true }, diff --git a/aggregator/src/binaries/aggregator.rs b/aggregator/src/binaries/aggregator.rs index a3fb382ba..3b5d6dfe5 100644 --- a/aggregator/src/binaries/aggregator.rs +++ b/aggregator/src/binaries/aggregator.rs @@ -9,7 +9,7 @@ use clap::Parser; use derivative::Derivative; use janus_aggregator_api::{self, aggregator_api_handler}; use janus_aggregator_core::datastore::Datastore; -use janus_core::{auth_tokens::AuthenticationToken, time::RealClock}; +use janus_core::{auth_tokens::AuthenticationToken, time::RealClock, TokioRuntime}; use opentelemetry::metrics::Meter; use serde::{de, Deserialize, Deserializer, Serialize}; use std::{ @@ -63,6 +63,7 @@ async fn run_aggregator( aggregator_handler( Arc::clone(&datastore), clock, + TokioRuntime, &meter, config.aggregator_config(), ) @@ -358,6 +359,12 @@ pub struct Config { /// the cost of collection. pub batch_aggregation_shard_count: u64, + /// Defines the number of shards to break report counters into. Increasing this value will + /// reduce the amount of database contention during report uploads, while increasing the cost + /// of getting task metrics. + #[serde(default = "default_task_counter_shard_count")] + pub task_counter_shard_count: u64, + /// Defines how often to refresh the global HPKE configs cache in milliseconds. This affects how /// often an aggregator becomes aware of key state changes. If unspecified, default is defined /// by [`GlobalHpkeKeypairCache::DEFAULT_REFRESH_INTERVAL`]. You shouldn't normally have to @@ -366,6 +373,10 @@ pub struct Config { pub global_hpke_configs_refresh_interval: Option, } +fn default_task_counter_shard_count() -> u64 { + 32 +} + #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] pub struct GarbageCollectorConfig { /// How frequently garbage collection is run, in seconds. 
@@ -417,7 +428,8 @@ impl Config { self.max_upload_batch_write_delay_ms, ), batch_aggregation_shard_count: self.batch_aggregation_shard_count, - taskprov_config: self.taskprov_config.clone(), + task_counter_shard_count: self.task_counter_shard_count, + taskprov_config: self.taskprov_config, global_hpke_configs_refresh_interval: match self.global_hpke_configs_refresh_interval { Some(duration) => Duration::from_millis(duration), None => GlobalHpkeKeypairCache::DEFAULT_REFRESH_INTERVAL, @@ -500,6 +512,7 @@ mod tests { max_upload_batch_size: 100, max_upload_batch_write_delay_ms: 250, batch_aggregation_shard_count: 32, + task_counter_shard_count: 64, taskprov_config: TaskprovConfig::default(), global_hpke_configs_refresh_interval: None, }) diff --git a/aggregator/src/config.rs b/aggregator/src/config.rs index b716b22cb..c650d917c 100644 --- a/aggregator/src/config.rs +++ b/aggregator/src/config.rs @@ -112,7 +112,7 @@ fn format_database_url(url: &Url, fmt: &mut std::fmt::Formatter) -> Result<(), s /// options are implementation-specific. /// /// [spec]: https://datatracker.ietf.org/doc/draft-wang-ppm-dap-taskprov/ -#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Default)] +#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize, Default)] pub struct TaskprovConfig { /// Whether to enable the extension or not. Enabling this changes the behavior /// of the aggregator consistent with the taskprov [specification][spec]. diff --git a/aggregator/src/metrics/tests.rs b/aggregator/src/metrics/tests.rs index d65cbae6c..6513c8ca4 100644 --- a/aggregator/src/metrics/tests.rs +++ b/aggregator/src/metrics/tests.rs @@ -4,7 +4,7 @@ use http::StatusCode; use janus_aggregator_core::datastore::test_util::ephemeral_datastore; use janus_core::{ retries::{retry_http_request, test_http_request_exponential_backoff}, - test_util::install_test_trace_subscriber, + test_util::{install_test_trace_subscriber, runtime::TestRuntime}, time::MockClock, }; use opentelemetry::metrics::MeterProvider as _; @@ -86,6 +86,7 @@ async fn http_metrics() { let handler = aggregator_handler( datastore.clone(), clock.clone(), + TestRuntime::default(), &meter, default_aggregator_config(), ) diff --git a/aggregator/tests/integration/graceful_shutdown.rs b/aggregator/tests/integration/graceful_shutdown.rs index b774efebb..76652b2c9 100644 --- a/aggregator/tests/integration/graceful_shutdown.rs +++ b/aggregator/tests/integration/graceful_shutdown.rs @@ -265,6 +265,7 @@ async fn aggregator_shutdown() { max_upload_batch_size: 100, max_upload_batch_write_delay_ms: 250, batch_aggregation_shard_count: 32, + task_counter_shard_count: 64, global_hpke_configs_refresh_interval: None, }; diff --git a/aggregator_api/src/lib.rs b/aggregator_api/src/lib.rs index 8f72da4a7..db725fd24 100644 --- a/aggregator_api/src/lib.rs +++ b/aggregator_api/src/lib.rs @@ -93,6 +93,10 @@ pub fn aggregator_api_handler( .post("/tasks", instrumented(api(post_task::))) .get("/tasks/:task_id", instrumented(api(get_task::))) .delete("/tasks/:task_id", instrumented(api(delete_task::))) + .get( + "/tasks/:task_id/metrics/uploads", + instrumented(api(get_task_upload_metrics::)), + ) .get( "/tasks/:task_id/metrics", instrumented(api(get_task_metrics::)), diff --git a/aggregator_api/src/models.rs b/aggregator_api/src/models.rs index 735fb74b2..04fe5ca14 100644 --- a/aggregator_api/src/models.rs +++ b/aggregator_api/src/models.rs @@ -1,7 +1,7 @@ use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine}; use derivative::Derivative; use 
-    datastore::models::{GlobalHpkeKeypair, HpkeKeyState},
+    datastore::models::{GlobalHpkeKeypair, HpkeKeyState, TaskUploadCounter},
     task::{AggregatorTask, QueryType},
     taskprov::{PeerAggregator, VerifyKeyInit},
 };
@@ -173,6 +173,9 @@ pub(crate) struct GetTaskMetricsResp {
     pub(crate) report_aggregations: u64,
 }
 
+#[derive(Serialize)]
+pub(crate) struct GetTaskUploadMetricsResp(pub(crate) TaskUploadCounter);
+
 #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
 pub(crate) struct GlobalHpkeConfigResp {
     pub(crate) config: HpkeConfig,
diff --git a/aggregator_api/src/routes.rs b/aggregator_api/src/routes.rs
index f6d4b8347..8b699a1a5 100644
--- a/aggregator_api/src/routes.rs
+++ b/aggregator_api/src/routes.rs
@@ -1,9 +1,9 @@
 use crate::{
     models::{
         AggregatorApiConfig, AggregatorRole, DeleteTaskprovPeerAggregatorReq, GetTaskIdsResp,
-        GetTaskMetricsResp, GlobalHpkeConfigResp, PatchGlobalHpkeConfigReq, PostTaskReq,
-        PostTaskprovPeerAggregatorReq, PutGlobalHpkeConfigReq, SupportedVdaf, TaskResp,
-        TaskprovPeerAggregatorResp,
+        GetTaskMetricsResp, GetTaskUploadMetricsResp, GlobalHpkeConfigResp,
+        PatchGlobalHpkeConfigReq, PostTaskReq, PostTaskprovPeerAggregatorReq,
+        PutGlobalHpkeConfigReq, SupportedVdaf, TaskResp, TaskprovPeerAggregatorResp,
     },
     Config, ConnExt, Error,
 };
@@ -276,6 +276,20 @@ pub(super) async fn get_task_metrics<C: Clock>(
     }))
 }
 
+pub(super) async fn get_task_upload_metrics<C: Clock>(
+    conn: &mut Conn,
+    State(ds): State<Arc<Datastore<C>>>,
+) -> Result<Json<GetTaskUploadMetricsResp>, Error> {
+    let task_id = conn.task_id_param()?;
+    Ok(Json(GetTaskUploadMetricsResp(
+        ds.run_tx("get_task_upload_metrics", |tx| {
+            Box::pin(async move { tx.get_task_upload_counter(&task_id).await })
+        })
+        .await?
+        .ok_or(Error::NotFound)?,
+    )))
+}
+
 pub(super) async fn get_global_hpke_configs<C: Clock>(
     _: &mut Conn,
     State(ds): State<Arc<Datastore<C>>>,
diff --git a/aggregator_api/src/tests.rs b/aggregator_api/src/tests.rs
index 11e34b6e2..e2837bd97 100644
--- a/aggregator_api/src/tests.rs
+++ b/aggregator_api/src/tests.rs
@@ -1,9 +1,10 @@
 use crate::{
     aggregator_api_handler,
     models::{
-        DeleteTaskprovPeerAggregatorReq, GetTaskIdsResp, GetTaskMetricsResp, GlobalHpkeConfigResp,
-        PatchGlobalHpkeConfigReq, PostTaskReq, PostTaskprovPeerAggregatorReq,
-        PutGlobalHpkeConfigReq, TaskResp, TaskprovPeerAggregatorResp,
+        DeleteTaskprovPeerAggregatorReq, GetTaskIdsResp, GetTaskMetricsResp,
+        GetTaskUploadMetricsResp, GlobalHpkeConfigResp, PatchGlobalHpkeConfigReq, PostTaskReq,
+        PostTaskprovPeerAggregatorReq, PutGlobalHpkeConfigReq, TaskResp,
+        TaskprovPeerAggregatorResp,
     },
     Config, CONTENT_TYPE,
 };
@@ -14,7 +15,7 @@ use janus_aggregator_core::{
     datastore::{
         models::{
             AggregationJob, AggregationJobState, HpkeKeyState, LeaderStoredReport,
-            ReportAggregation, ReportAggregationState,
+            ReportAggregation, ReportAggregationState, TaskUploadCounter, TaskUploadIncrementor,
         },
         test_util::{ephemeral_datastore, EphemeralDatastore},
         Datastore,
     },
@@ -865,6 +866,78 @@ async fn get_task_metrics() {
     );
 }
 
+#[tokio::test]
+async fn get_task_upload_metrics() {
+    let (handler, _ephemeral_datastore, ds) = setup_api_test().await;
+    let task_id = ds
+        .run_unnamed_tx(|tx| {
+            Box::pin(async move {
+                let task = TaskBuilder::new(QueryType::TimeInterval, VdafInstance::Fake)
+                    .build()
+                    .leader_view()
+                    .unwrap();
+                let task_id = *task.id();
+                tx.put_aggregator_task(&task).await?;
+
+                for case in [
+                    (TaskUploadIncrementor::ReportDecryptFailure, 2),
+                    (TaskUploadIncrementor::ReportExpired, 4),
+                    (TaskUploadIncrementor::ReportOutdatedKey, 6),
+                    (TaskUploadIncrementor::ReportSuccess, 100),
+                    (TaskUploadIncrementor::ReportTooEarly, 25),
+                    (TaskUploadIncrementor::TaskExpired, 12),
+                ] {
+                    let ord = thread_rng().gen_range(0..32);
+                    try_join_all(
+                        (0..case.1)
+                            .map(|_| tx.increment_task_upload_counter(&task_id, ord, &case.0)),
+                    )
+                    .await
+                    .unwrap();
+                }
+
+                Ok(task_id)
+            })
+        })
+        .await
+        .unwrap();
+
+    // Verify: requesting metrics on a task returns the correct result.
+    assert_response!(
+        get(&format!("/tasks/{}/metrics/uploads", &task_id))
+            .with_request_header("Authorization", format!("Bearer {AUTH_TOKEN}"))
+            .with_request_header("Accept", CONTENT_TYPE)
+            .run_async(&handler)
+            .await,
+        Status::Ok,
+        serde_json::to_string(&GetTaskUploadMetricsResp(TaskUploadCounter::new(
+            0, 0, 2, 4, 6, 100, 25, 12
+        )))
+        .unwrap(),
+    );
+
+    // Verify: requesting metrics on a nonexistent task returns NotFound.
+    assert_response!(
+        get(&format!("/tasks/{}/metrics/uploads", &random::<TaskId>()))
+            .with_request_header("Authorization", format!("Bearer {AUTH_TOKEN}"))
+            .with_request_header("Accept", CONTENT_TYPE)
+            .run_async(&handler)
+            .await,
+        Status::NotFound,
+        "",
+    );
+
+    // Verify: unauthorized requests are denied appropriately.
+    assert_response!(
+        get(&format!("/tasks/{}/metrics/uploads", &task_id))
+            .with_request_header("Accept", CONTENT_TYPE)
+            .run_async(&handler)
+            .await,
+        Status::Unauthorized,
+        "",
+    );
+}
+
 #[tokio::test]
 async fn get_global_hpke_configs() {
     let (handler, _ephemeral_datastore, ds) = setup_api_test().await;
@@ -1984,3 +2057,36 @@ fn get_task_metrics_resp_serialization() {
         ],
     )
 }
+
+#[test]
+fn get_task_upload_metrics_serialization() {
+    assert_ser_tokens(
+        &GetTaskUploadMetricsResp(TaskUploadCounter::new(0, 1, 2, 3, 4, 5, 6, 7)),
+        &[
+            Token::NewtypeStruct {
+                name: "GetTaskUploadMetricsResp",
+            },
+            Token::Struct {
+                name: "TaskUploadCounter",
+                len: 8,
+            },
+            Token::Str("interval_collected"),
+            Token::U64(0),
+            Token::Str("report_decode_failure"),
+            Token::U64(1),
+            Token::Str("report_decrypt_failure"),
+            Token::U64(2),
+            Token::Str("report_expired"),
+            Token::U64(3),
+            Token::Str("report_outdated_key"),
+            Token::U64(4),
+            Token::Str("report_success"),
+            Token::U64(5),
+            Token::Str("report_too_early"),
+            Token::U64(6),
+            Token::Str("task_expired"),
+            Token::U64(7),
+            Token::StructEnd,
+        ],
+    )
+}
diff --git a/aggregator_core/src/datastore.rs b/aggregator_core/src/datastore.rs
index ea77e3213..e808763b9 100644
--- a/aggregator_core/src/datastore.rs
+++ b/aggregator_core/src/datastore.rs
@@ -5,7 +5,8 @@ use self::models::{
     AggregatorRole, AuthenticationTokenType, Batch, BatchAggregation, CollectionJob,
     CollectionJobState, CollectionJobStateCode, GlobalHpkeKeypair, HpkeKeyState,
     LeaderStoredReport, Lease, LeaseToken, OutstandingBatch, ReportAggregation,
-    ReportAggregationState, ReportAggregationStateCode, SqlInterval,
+    ReportAggregationState, ReportAggregationStateCode, SqlInterval, TaskUploadCounter,
+    TaskUploadIncrementor,
 };
 use crate::{
     query_type::{AccumulableQueryType, CollectableQueryType},
@@ -99,7 +100,7 @@ macro_rules! supported_schema_versions {
 // version is seen, [`Datastore::new`] fails.
 //
 // Note that the latest supported version must be first in the list.
-supported_schema_versions!(3, 2, 1);
+supported_schema_versions!(3, 2);
 
 /// Datastore represents a datastore for Janus, with support for transactional reads and writes.
 /// In practice, Datastore instances are currently backed by a PostgreSQL database.
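The Transaction methods added below implement both halves of the sharded counter, and the read query leans on two PostgreSQL behaviors worth keeping in mind: SUM over a BIGINT column yields NUMERIC (hence the ::BIGINT casts back), and an aggregate over zero matching rows yields a single all-NULL row rather than no rows (hence query_one plus a nullable first column standing in for "no counters"). A standalone sketch of that read pattern with tokio-postgres, against a hypothetical counters table rather than the actual Janus schema:

    use tokio_postgres::{Client, Error};

    async fn total_report_success(client: &Client, task_id: i64) -> Result<Option<i64>, Error> {
        // SUM(bigint) is NUMERIC in PostgreSQL, so cast back to BIGINT; over
        // zero rows the aggregate is NULL, but query_one still sees one row.
        let row = client
            .query_one(
                "SELECT SUM(report_success)::BIGINT FROM counters WHERE task_id = $1",
                &[&task_id],
            )
            .await?;
        // A NULL aggregate surfaces as Option::None, distinguishing "no
        // counter rows at all" from a genuine zero count.
        row.try_get::<_, Option<i64>>(0)
    }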
@@ -4758,6 +4759,75 @@ impl<C: Clock> Transaction<'_, C> {
         .await?;
         check_single_row_mutation(self.execute(&stmt, &[&aggregator_url, &role]).await?)
     }
+
+    /// Get the [`TaskUploadCounter`] for a task. This is aggregated across all shards. Returns
+    /// `None` if no counter shards exist yet, e.g. because the task doesn't exist or because no
+    /// uploads have been recorded for it.
+    #[tracing::instrument(skip(self), err)]
+    pub async fn get_task_upload_counter(
+        &self,
+        task_id: &TaskId,
+    ) -> Result<Option<TaskUploadCounter>, Error> {
+        let stmt = self
+            .prepare_cached(
+                "SELECT
+                    SUM(interval_collected)::BIGINT AS interval_collected,
+                    SUM(report_decode_failure)::BIGINT AS report_decode_failure,
+                    SUM(report_decrypt_failure)::BIGINT AS report_decrypt_failure,
+                    SUM(report_expired)::BIGINT AS report_expired,
+                    SUM(report_outdated_key)::BIGINT AS report_outdated_key,
+                    SUM(report_success)::BIGINT AS report_success,
+                    SUM(report_too_early)::BIGINT AS report_too_early,
+                    SUM(task_expired)::BIGINT AS task_expired
+                FROM task_upload_counters
+                WHERE task_id = (SELECT id FROM tasks WHERE task_id = $1)",
+            )
+            .await?;
+
+        let row = self.query_one(&stmt, &[task_id.as_ref()]).await?;
+        let interval_collected = row.get_nullable_bigint_and_convert("interval_collected")?;
+        Ok(match interval_collected {
+            Some(interval_collected) => Some(TaskUploadCounter {
+                interval_collected,
+                // The remaining columns are non-NULL whenever the first one is, since every
+                // counter column carries a DEFAULT 0 clause, so we don't need to treat them as
+                // nullable.
+                report_decode_failure: row.get_bigint_and_convert("report_decode_failure")?,
+                report_decrypt_failure: row.get_bigint_and_convert("report_decrypt_failure")?,
+                report_expired: row.get_bigint_and_convert("report_expired")?,
+                report_outdated_key: row.get_bigint_and_convert("report_outdated_key")?,
+                report_success: row.get_bigint_and_convert("report_success")?,
+                report_too_early: row.get_bigint_and_convert("report_too_early")?,
+                task_expired: row.get_bigint_and_convert("task_expired")?,
+            }),
+            None => None,
+        })
+    }
+
+    /// Add one to the counter associated with the given [`TaskId`]. The column to increment is
+    /// given by [`TaskUploadIncrementor`]. The counter is sharded, so an `ord` parameter selects
+    /// which shard to add to; `ord` should be randomly generated by the caller.
+    #[tracing::instrument(skip(self), err)]
+    pub async fn increment_task_upload_counter(
+        &self,
+        task_id: &TaskId,
+        ord: u64,
+        incrementor: &TaskUploadIncrementor,
+    ) -> Result<(), Error> {
+        // SQL injection safety: The possible inputs of TaskUploadIncrementor and the resulting
+        // .column() are constrained to values that are known to be safe for interpolation. The
+        // calling function cannot supply arbitrary strings.
+        let column = incrementor.column();
+        let stmt = format!(
+            "INSERT INTO task_upload_counters (task_id, ord, {column})
+            VALUES ((SELECT id FROM tasks WHERE task_id = $1), $2, 1)
+            ON CONFLICT (task_id, ord) DO UPDATE
+                SET {column} = task_upload_counters.{column} + 1"
+        );
+        let stmt = self.prepare_cached(&stmt).await?;
+        check_single_row_mutation(
+            self.execute(&stmt, &[task_id.as_ref(), &i64::try_from(ord)?])
+                .await?,
+        )
+    }
 }
 
 fn check_insert(row_count: u64) -> Result<(), Error> {
diff --git a/aggregator_core/src/datastore/models.rs b/aggregator_core/src/datastore/models.rs
index 995fa34ef..750151557 100644
--- a/aggregator_core/src/datastore/models.rs
+++ b/aggregator_core/src/datastore/models.rs
@@ -1353,7 +1353,6 @@ pub enum CollectionJobStateCode {
 
 /// AggregateShareJob represents a row in the `aggregate_share_jobs` table, used by helpers to
 /// store the results of handling an AggregateShareReq from the leader.
-
 #[derive(Clone, Derivative)]
 #[derivative(Debug)]
 pub struct AggregateShareJob<
@@ -1539,7 +1538,6 @@ pub enum BatchState {
 }
 
 /// Represents the state of a given batch (and aggregation parameter).
-
 #[derive(Clone, Derivative)]
 #[derivative(Debug)]
 pub struct Batch<Q: QueryType> {
@@ -1847,3 +1845,81 @@ impl GlobalHpkeKeypair {
         &self.updated_at
     }
 }
+
+/// Per-task counts of uploaded reports and upload attempts.
+#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
+pub struct TaskUploadCounter {
+    pub(crate) interval_collected: u64,
+    pub(crate) report_decode_failure: u64,
+    pub(crate) report_decrypt_failure: u64,
+    pub(crate) report_expired: u64,
+    pub(crate) report_outdated_key: u64,
+    pub(crate) report_success: u64,
+    pub(crate) report_too_early: u64,
+    pub(crate) task_expired: u64,
+}
+
+impl TaskUploadCounter {
+    /// Create a new [`TaskUploadCounter`].
+    ///
+    /// This is locked behind test-util since production code shouldn't need to manually create a
+    /// counter.
+    #[allow(clippy::too_many_arguments)]
+    #[cfg(feature = "test-util")]
+    pub fn new(
+        interval_collected: u64,
+        report_decode_failure: u64,
+        report_decrypt_failure: u64,
+        report_expired: u64,
+        report_outdated_key: u64,
+        report_success: u64,
+        report_too_early: u64,
+        task_expired: u64,
+    ) -> Self {
+        Self {
+            interval_collected,
+            report_decode_failure,
+            report_decrypt_failure,
+            report_expired,
+            report_outdated_key,
+            report_success,
+            report_too_early,
+            task_expired,
+        }
+    }
+}
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub enum TaskUploadIncrementor {
+    /// A report fell into a time interval that has already been collected.
+    IntervalCollected,
+    /// A report could not be decoded.
+    ReportDecodeFailure,
+    /// A report could not be decrypted.
+    ReportDecryptFailure,
+    /// A report contains a timestamp too far in the past.
+    ReportExpired,
+    /// A report is encrypted with an old or unknown HPKE key.
+    ReportOutdatedKey,
+    /// A report was successfully uploaded.
+    ReportSuccess,
+    /// A report contains a timestamp too far in the future.
+    ReportTooEarly,
+    /// A report was submitted to the task after the task's expiry.
+    TaskExpired,
+}
+
+impl TaskUploadIncrementor {
+    pub(crate) fn column(&self) -> &'static str {
+        match self {
+            TaskUploadIncrementor::IntervalCollected => "interval_collected",
+            TaskUploadIncrementor::ReportDecodeFailure => "report_decode_failure",
+            TaskUploadIncrementor::ReportDecryptFailure => "report_decrypt_failure",
+            TaskUploadIncrementor::ReportExpired => "report_expired",
+            TaskUploadIncrementor::ReportOutdatedKey => "report_outdated_key",
+            TaskUploadIncrementor::ReportSuccess => "report_success",
+            TaskUploadIncrementor::ReportTooEarly => "report_too_early",
+            TaskUploadIncrementor::TaskExpired => "task_expired",
+        }
+    }
+}
diff --git a/aggregator_core/src/datastore/tests.rs b/aggregator_core/src/datastore/tests.rs
index c9e152c43..0a16a5fd0 100644
--- a/aggregator_core/src/datastore/tests.rs
+++ b/aggregator_core/src/datastore/tests.rs
@@ -5,6 +5,7 @@ use crate::{
         AggregationJobState, Batch, BatchAggregation, BatchAggregationState, BatchState,
         CollectionJob, CollectionJobState, GlobalHpkeKeypair, HpkeKeyState, LeaderStoredReport,
         Lease, OutstandingBatch, ReportAggregation, ReportAggregationState, SqlInterval,
+        TaskUploadCounter, TaskUploadIncrementor,
     },
     schema_versions_template,
     test_util::{ephemeral_datastore_schema_version, generate_aead_key, EphemeralDatastore},
@@ -7512,3 +7513,65 @@ async fn reject_expired_reports_with_same_id(ephemeral_datastore: EphemeralDatas
     .await;
     assert_matches!(result, Err(Error::MutationTargetAlreadyExists));
 }
+
+#[rstest_reuse::apply(schema_versions_template)]
+#[tokio::test]
+async fn roundtrip_task_upload_counter(ephemeral_datastore: EphemeralDatastore) {
+    install_test_trace_subscriber();
+    let clock = MockClock::default();
+    let datastore = ephemeral_datastore.datastore(clock.clone()).await;
+
+    let task = TaskBuilder::new(task::QueryType::TimeInterval, VdafInstance::Fake)
+        .build()
+        .leader_view()
+        .unwrap();
+
+    datastore.put_aggregator_task(&task).await.unwrap();
+
+    datastore
+        .run_unnamed_tx(|tx| {
+            let task_id = *task.id();
+            Box::pin(async move {
+                let counter = tx.get_task_upload_counter(&task_id).await.unwrap();
+                assert_eq!(counter, None);
+
+                for case in [
+                    (TaskUploadIncrementor::IntervalCollected, 2),
+                    (TaskUploadIncrementor::ReportDecodeFailure, 4),
+                    (TaskUploadIncrementor::ReportDecryptFailure, 6),
+                    (TaskUploadIncrementor::ReportExpired, 8),
+                    (TaskUploadIncrementor::ReportOutdatedKey, 10),
+                    (TaskUploadIncrementor::ReportSuccess, 100),
+                    (TaskUploadIncrementor::ReportTooEarly, 25),
+                    (TaskUploadIncrementor::TaskExpired, 12),
+                ] {
+                    let ord = thread_rng().gen_range(0..32);
+                    try_join_all(
+                        (0..case.1)
+                            .map(|_| tx.increment_task_upload_counter(&task_id, ord, &case.0)),
+                    )
+                    .await
+                    .unwrap();
+                }
+
+                let counter = tx.get_task_upload_counter(&task_id).await.unwrap();
+                assert_eq!(
+                    counter,
+                    Some(TaskUploadCounter {
+                        interval_collected: 2,
+                        report_decode_failure: 4,
+                        report_decrypt_failure: 6,
+                        report_expired: 8,
+                        report_outdated_key: 10,
+                        report_success: 100,
+                        report_too_early: 25,
+                        task_expired: 12,
+                    })
+                );
+
+                Ok(())
+            })
+        })
+        .await
+        .unwrap();
+}
diff --git a/docs/samples/advanced_config/aggregator.yaml b/docs/samples/advanced_config/aggregator.yaml
index 41e44fdc4..61ccd51fb 100644
--- a/docs/samples/advanced_config/aggregator.yaml
+++ b/docs/samples/advanced_config/aggregator.yaml
@@ -97,6 +97,11 @@ max_upload_batch_write_delay_ms: 250
 # than the equivalent setting in the collection job driver. (required)
 batch_aggregation_shard_count: 32
 
+# Number of sharded database records per task counter. Increasing this value will reduce the amount
+# of database contention during report uploads, while increasing the cost of getting task metrics.
+# (optional, default: 32)
+task_counter_shard_count: 32
+
 # Configuration for the taskprov extension. If enabled, this changes the behavior of the
 # aggregator as described in draft-wang-ppm-dap-taskprov. (optional)
 taskprov_config:
diff --git a/integration_tests/src/janus.rs b/integration_tests/src/janus.rs
index 331a11c33..0d943376f 100644
--- a/integration_tests/src/janus.rs
+++ b/integration_tests/src/janus.rs
@@ -159,6 +159,7 @@ impl JanusInProcess {
             max_upload_batch_size: 100,
             max_upload_batch_write_delay_ms: 100,
             batch_aggregation_shard_count: 32,
+            task_counter_shard_count: 64,
             global_hpke_configs_refresh_interval: None,
         };
         let aggregation_job_creator_options = AggregationJobCreatorOptions {
diff --git a/interop_binaries/src/bin/janus_interop_aggregator.rs b/interop_binaries/src/bin/janus_interop_aggregator.rs
index 62cc30947..005a60ede 100644
--- a/interop_binaries/src/bin/janus_interop_aggregator.rs
+++ b/interop_binaries/src/bin/janus_interop_aggregator.rs
@@ -14,6 +14,7 @@ use janus_aggregator_core::{
 use janus_core::{
     auth_tokens::{AuthenticationToken, AuthenticationTokenHash},
     time::RealClock,
+    Runtime, TokioRuntime,
 };
 use janus_interop_binaries::{
     status::{ERROR, SUCCESS},
@@ -129,8 +130,9 @@ async fn handle_add_task(
     .context("error adding task to database")
 }
 
-async fn make_handler(
+async fn make_handler<R: Runtime>(
     datastore: Arc<Datastore<RealClock>>,
+    runtime: R,
     meter: &Meter,
     dap_serving_prefix: String,
 ) -> anyhow::Result<impl Handler> {
@@ -138,6 +140,7 @@
     let dap_handler = aggregator_handler(
         Arc::clone(&datastore),
         RealClock::default(),
+        runtime,
         meter,
         aggregator::Config {
             max_upload_batch_size: 100,
@@ -247,6 +250,7 @@ async fn main() -> anyhow::Result<()> {
     // endpoints.
     let handler = make_handler(
         Arc::clone(&datastore),
+        TokioRuntime,
         &ctx.meter,
         ctx.config.dap_serving_prefix,
     )
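Taken together, these changes expose per-task upload counters at GET /tasks/:task_id/metrics/uploads on the aggregator API, authenticated with a bearer token as exercised in the tests above. A sketch of a client call using reqwest; the base URL, token, and Accept value below are deployment-specific placeholders, not values fixed by this change:

    use reqwest::Client;

    async fn fetch_upload_metrics(
        base_url: &str,
        task_id: &str,
        token: &str,
    ) -> reqwest::Result<String> {
        Client::new()
            .get(format!("{base_url}/tasks/{task_id}/metrics/uploads"))
            .bearer_auth(token)
            // The aggregator API expects its own media type in Accept; the
            // exact value is version-specific, so treat this as a placeholder.
            .header("Accept", "application/vnd.janus.aggregator+json;version=0.1")
            .send()
            .await?
            .error_for_status()?
            .text()
            .await
    }

On success the body is the flat JSON object produced by the GetTaskUploadMetricsResp newtype, i.e. the TaskUploadCounter fields shown in the serialization test: interval_collected, report_decode_failure, report_decrypt_failure, report_expired, report_outdated_key, report_success, report_too_early, and task_expired.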