diff --git a/aggregator/src/aggregator.rs b/aggregator/src/aggregator.rs index aec21fc65..88938fc09 100644 --- a/aggregator/src/aggregator.rs +++ b/aggregator/src/aggregator.rs @@ -5,7 +5,7 @@ use crate::{ aggregator::{ accumulator::Accumulator, aggregate_share::compute_aggregate_share, - error::{handle_ping_pong_error, ReportRejectedReason}, + error::{handle_ping_pong_error, ReportRejection, ReportRejectionReason}, error::{BatchMismatch, OptOutReason}, query_type::{CollectableQueryType, UploadableQueryType}, report_writer::{ReportWriteBatcher, WritableReport}, @@ -203,6 +203,11 @@ pub struct Config { /// the cost of collection. pub batch_aggregation_shard_count: u64, + /// Defines the number of shards to break report counters into. Increasing this value will + /// reduce the amount of database contention during report uploads, while increasing the cost + /// of getting task metrics. + pub task_counter_shard_count: u64, + /// Defines how often to refresh the global HPKE configs cache. This affects how often an aggregator /// becomes aware of key state changes. pub global_hpke_configs_refresh_interval: StdDuration, @@ -216,6 +221,7 @@ impl Default for Config { max_upload_batch_size: 1, max_upload_batch_write_delay: StdDuration::ZERO, batch_aggregation_shard_count: 1, + task_counter_shard_count: 1, global_hpke_configs_refresh_interval: GlobalHpkeKeypairCache::DEFAULT_REFRESH_INTERVAL, taskprov_config: TaskprovConfig::default(), } @@ -231,6 +237,7 @@ impl Aggregator { ) -> Result { let report_writer = Arc::new(ReportWriteBatcher::new( Arc::clone(&datastore), + cfg.task_counter_shard_count, cfg.max_upload_batch_size, cfg.max_upload_batch_write_delay, )); @@ -1394,14 +1401,16 @@ impl VdafOps { C: Clock, Q: UploadableQueryType, { - // Shorthand function for generating an Error::ReportRejected with proper parameters. + // Shorthand function for generating an Error::ReportRejected with proper parameters and + // recording it in the report_writer. let reject_report = |reason| { - Arc::new(Error::ReportRejected( - *task.id(), - *report.metadata().id(), - *report.metadata().time(), - reason, - )) + let report_id = *report.metadata().id(); + let report_time = *report.metadata().time(); + async move { + let rejection = ReportRejection::new(*task.id(), report_id, report_time, reason); + report_writer.write_report(Err(rejection)).await?; + Ok::<_, Arc>(Arc::new(Error::ReportRejected(rejection))) + } }; let report_deadline = clock @@ -1412,18 +1421,14 @@ impl VdafOps { // Reject reports from too far in the future. // https://www.ietf.org/archive/id/draft-ietf-ppm-dap-07.html#section-4.4.2-21 if report.metadata().time().is_after(&report_deadline) { - return Err(Arc::new(Error::ReportTooEarly( - *task.id(), - *report.metadata().id(), - *report.metadata().time(), - ))); + return Err(reject_report(ReportRejectionReason::TooEarly).await?); } // Reject reports after a task has expired. 
// https://www.ietf.org/archive/id/draft-ietf-ppm-dap-07.html#section-4.4.2-20 if let Some(task_expiration) = task.task_expiration() { if report.metadata().time().is_after(task_expiration) { - return Err(reject_report(ReportRejectedReason::TaskExpired)); + return Err(reject_report(ReportRejectionReason::TaskExpired).await?); } } @@ -1435,7 +1440,7 @@ impl VdafOps { .add(report_expiry_age) .map_err(|err| Arc::new(Error::from(err)))?; if clock.now().is_after(&report_expiry_time) { - return Err(reject_report(ReportRejectedReason::TooOld)); + return Err(reject_report(ReportRejectionReason::Expired).await?); } } @@ -1455,9 +1460,7 @@ impl VdafOps { "public share decoding failed", ); upload_decode_failure_counter.add(1, &[]); - return Err(reject_report( - ReportRejectedReason::PublicShareDecodeFailure, - )); + return Err(reject_report(ReportRejectionReason::DecodeFailure).await?); } }; @@ -1486,10 +1489,10 @@ impl VdafOps { // Verify that the report's HPKE config ID is known. // https://www.ietf.org/archive/id/draft-ietf-ppm-dap-07.html#section-4.4.2-17 (None, None) => { - return Err(Arc::new(Error::OutdatedHpkeConfig( - *task.id(), + return Err(reject_report(ReportRejectionReason::OutdatedHpkeConfig( *report.leader_encrypted_input_share().config_id(), - ))); + )) + .await?); } (None, Some(global_hpke_keypair)) => try_hpke_open(&global_hpke_keypair), (Some(task_hpke_keypair), None) => try_hpke_open(task_hpke_keypair), @@ -1513,7 +1516,7 @@ impl VdafOps { "Report decryption failed", ); upload_decrypt_failure_counter.add(1, &[]); - return Err(reject_report(ReportRejectedReason::LeaderDecryptFailure)); + return Err(reject_report(ReportRejectionReason::DecryptFailure).await?); } }; @@ -1540,9 +1543,7 @@ impl VdafOps { "Leader input share decoding failed", ); upload_decode_failure_counter.add(1, &[]); - return Err(reject_report( - ReportRejectedReason::LeaderInputShareDecodeFailure, - )); + return Err(reject_report(ReportRejectionReason::DecodeFailure).await?); } }; @@ -1556,7 +1557,9 @@ impl VdafOps { ); report_writer - .write_report(WritableReport::::new(vdaf, report)) + .write_report(Ok(Box::new(WritableReport::::new( + vdaf, report, + )))) .await } } @@ -3248,14 +3251,14 @@ pub(crate) mod test_util { #[cfg(test)] mod tests { use crate::aggregator::{ - error::ReportRejectedReason, test_util::default_aggregator_config, Aggregator, Config, + error::ReportRejectionReason, test_util::default_aggregator_config, Aggregator, Config, Error, }; use assert_matches::assert_matches; use futures::future::try_join_all; use janus_aggregator_core::{ datastore::{ - models::{CollectionJob, CollectionJobState}, + models::{CollectionJob, CollectionJobState, TaskUploadCounter}, test_util::{ephemeral_datastore, EphemeralDatastore}, Datastore, }, @@ -3417,17 +3420,29 @@ mod tests { .unwrap(); // Verify that the original report, rather than the modified report, is stored. 
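        // Also confirm that the task's upload counter recorded exactly one accepted report.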
- let got_report = ds + let (got_report, got_counter) = ds .run_unnamed_tx(|tx| { let vdaf = vdaf.clone(); let task_id = *task.id(); let report_id = *report.metadata().id(); - Box::pin(async move { tx.get_client_report(&vdaf, &task_id, &report_id).await }) + Box::pin(async move { + Ok(( + tx.get_client_report(&vdaf, &task_id, &report_id) + .await + .unwrap(), + tx.get_task_upload_counter(&task_id).await.unwrap(), + )) + }) }) .await - .unwrap() .unwrap(); - assert!(got_report.eq_report(&vdaf, leader_task.current_hpke_key(), &report)); + assert!(got_report + .unwrap() + .eq_report(&vdaf, leader_task.current_hpke_key(), &report)); + assert_eq!( + got_counter, + TaskUploadCounter::new(*task.id(), 0, 0, 0, 0, 0, 1, 0, 0) + ) } #[tokio::test] @@ -3472,13 +3487,25 @@ mod tests { .collect(); assert_eq!(want_report_ids, got_report_ids); + + let got_counters = datastore + .run_unnamed_tx(|tx| { + let task_id = *task.id(); + Box::pin(async move { tx.get_task_upload_counter(&task_id).await }) + }) + .await + .unwrap(); + assert_eq!( + got_counters, + TaskUploadCounter::new(*task.id(), 0, 0, 0, 0, 0, 100, 0, 0), + ); } #[tokio::test] async fn upload_wrong_hpke_config_id() { install_test_trace_subscriber(); - let (_, aggregator, clock, task, _, _ephemeral_datastore) = + let (_, aggregator, clock, task, datastore, _ephemeral_datastore) = setup_upload_test(default_aggregator_config()).await; let leader_task = task.leader_view().unwrap(); let report = create_report(&leader_task, clock.now()); @@ -3502,10 +3529,30 @@ mod tests { report.helper_encrypted_input_share().clone(), ); - assert_matches!(aggregator.handle_upload(task.id(), &report.get_encoded()).await.unwrap_err().as_ref(), Error::OutdatedHpkeConfig(task_id, config_id) => { - assert_eq!(task.id(), task_id); - assert_eq!(config_id, &unused_hpke_config_id); + let result = aggregator + .handle_upload(task.id(), &report.get_encoded()) + .await + .unwrap_err(); + assert_matches!(result.as_ref(), Error::ReportRejected(rejection) => { + assert_eq!(task.id(), rejection.task_id()); + assert_eq!(report.metadata().id(), rejection.report_id()); + assert_eq!(report.metadata().time(), rejection.time()); + assert_matches!(rejection.reason(), ReportRejectionReason::OutdatedHpkeConfig(id) => { + assert_eq!(id, &unused_hpke_config_id); + }) }); + + let got_counters = datastore + .run_unnamed_tx(|tx| { + let task_id = *task.id(); + Box::pin(async move { tx.get_task_upload_counter(&task_id).await }) + }) + .await + .unwrap(); + assert_eq!( + got_counters, + TaskUploadCounter::new(*task.id(), 0, 0, 0, 0, 1, 0, 0, 0), + ); } #[tokio::test] @@ -3535,13 +3582,25 @@ mod tests { .unwrap(); assert_eq!(task.id(), got_report.task_id()); assert_eq!(report.metadata(), got_report.metadata()); + + let got_counters = datastore + .run_unnamed_tx(|tx| { + let task_id = *task.id(); + Box::pin(async move { tx.get_task_upload_counter(&task_id).await }) + }) + .await + .unwrap(); + assert_eq!( + got_counters, + TaskUploadCounter::new(*task.id(), 0, 0, 0, 0, 0, 1, 0, 0), + ); } #[tokio::test] async fn upload_report_in_the_future_past_clock_skew() { install_test_trace_subscriber(); - let (_, aggregator, clock, task, _, _ephemeral_datastore) = + let (_, aggregator, clock, task, datastore, _ephemeral_datastore) = setup_upload_test(default_aggregator_config()).await; let report = create_report( &task.leader_view().unwrap(), @@ -3557,12 +3616,24 @@ mod tests { .handle_upload(task.id(), &report.get_encoded()) .await .unwrap_err(); - - assert_matches!(upload_error.as_ref(), 
Error::ReportTooEarly(task_id, report_id, time) => { - assert_eq!(task.id(), task_id); - assert_eq!(report.metadata().id(), report_id); - assert_eq!(report.metadata().time(), time); + assert_matches!(upload_error.as_ref(), Error::ReportRejected(rejection) => { + assert_eq!(task.id(), rejection.task_id()); + assert_eq!(report.metadata().id(), rejection.report_id()); + assert_eq!(report.metadata().time(), rejection.time()); + assert_matches!(rejection.reason(), ReportRejectionReason::TooEarly); }); + + let got_counters = datastore + .run_unnamed_tx(|tx| { + let task_id = *task.id(); + Box::pin(async move { tx.get_task_upload_counter(&task_id).await }) + }) + .await + .unwrap(); + assert_eq!( + got_counters, + TaskUploadCounter::new(*task.id(), 0, 0, 0, 0, 0, 0, 1, 0), + ); } #[tokio::test] @@ -3612,12 +3683,25 @@ mod tests { .unwrap_err(); assert_matches!( error.as_ref(), - Error::ReportRejected(err_task_id, err_report_id, err_time, ReportRejectedReason::IntervalAlreadyCollected) => { - assert_eq!(task.id(), err_task_id); - assert_eq!(report.metadata().id(), err_report_id); - assert_eq!(report.metadata().time(), err_time); + Error::ReportRejected(rejection) => { + assert_eq!(task.id(), rejection.task_id()); + assert_eq!(report.metadata().id(), rejection.report_id()); + assert_eq!(report.metadata().time(), rejection.time()); + assert_matches!(rejection.reason(), ReportRejectionReason::IntervalCollected); } ); + + let got_counters = datastore + .run_unnamed_tx(|tx| { + let task_id = *task.id(); + Box::pin(async move { tx.get_task_upload_counter(&task_id).await }) + }) + .await + .unwrap(); + assert_eq!( + got_counters, + TaskUploadCounter::new(*task.id(), 1, 0, 0, 0, 0, 0, 0, 0), + ); } #[tokio::test] @@ -3697,6 +3781,246 @@ mod tests { } } + #[tokio::test] + async fn upload_report_task_expired() { + install_test_trace_subscriber(); + + let (_, aggregator, clock, _, datastore, _ephemeral_datastore) = + setup_upload_test(default_aggregator_config()).await; + + let task = TaskBuilder::new(QueryType::TimeInterval, VdafInstance::Prio3Count) + .with_task_expiration(Some(clock.now())) + .build() + .leader_view() + .unwrap(); + datastore.put_aggregator_task(&task).await.unwrap(); + + // Advance the clock to expire the task. + clock.advance(&Duration::from_seconds(1)); + let report = create_report(&task, clock.now()); + + // Try to upload the report, verify that we get the expected error. 
+ let error = aggregator + .handle_upload(task.id(), &report.get_encoded()) + .await + .unwrap_err(); + assert_matches!( + error.as_ref(), + Error::ReportRejected(rejection) => { + assert_eq!(task.id(), rejection.task_id()); + assert_eq!(report.metadata().id(), rejection.report_id()); + assert_eq!(report.metadata().time(), rejection.time()); + assert_matches!(rejection.reason(), ReportRejectionReason::TaskExpired); + } + ); + + let got_counters = datastore + .run_unnamed_tx(|tx| { + let task_id = *task.id(); + Box::pin(async move { tx.get_task_upload_counter(&task_id).await }) + }) + .await + .unwrap(); + assert_eq!( + got_counters, + TaskUploadCounter::new(*task.id(), 0, 0, 0, 0, 0, 0, 0, 1), + ); + } + + #[tokio::test] + async fn upload_report_report_expired() { + install_test_trace_subscriber(); + + let (_, aggregator, clock, _, datastore, _ephemeral_datastore) = + setup_upload_test(default_aggregator_config()).await; + + let task = TaskBuilder::new(QueryType::TimeInterval, VdafInstance::Prio3Count) + .with_report_expiry_age(Some(Duration::from_seconds(60))) + .build() + .leader_view() + .unwrap(); + datastore.put_aggregator_task(&task).await.unwrap(); + + let report = create_report(&task, clock.now()); + + // Advance the clock to expire the report. + clock.advance(&Duration::from_seconds(61)); + + // Try to upload the report, verify that we get the expected error. + let error = aggregator + .handle_upload(task.id(), &report.get_encoded()) + .await + .unwrap_err(); + assert_matches!( + error.as_ref(), + Error::ReportRejected(rejection) => { + assert_eq!(task.id(), rejection.task_id()); + assert_eq!(report.metadata().id(), rejection.report_id()); + assert_eq!(report.metadata().time(), rejection.time()); + assert_matches!(rejection.reason(), ReportRejectionReason::Expired); + } + ); + + let got_counters = datastore + .run_unnamed_tx(|tx| { + let task_id = *task.id(); + Box::pin(async move { tx.get_task_upload_counter(&task_id).await }) + }) + .await + .unwrap(); + assert_eq!( + got_counters, + TaskUploadCounter::new(*task.id(), 0, 0, 0, 1, 0, 0, 0, 0), + ); + } + + #[tokio::test] + async fn upload_report_faulty_encryption() { + install_test_trace_subscriber(); + let (_, aggregator, clock, task, datastore, _ephemeral_datastore) = + setup_upload_test(default_aggregator_config()).await; + + let task = task.leader_view().unwrap(); + + // Encrypt with the wrong key. + let report = create_report_custom( + &task, + clock.now(), + random(), + &generate_test_hpke_config_and_private_key_with_id( + (*task.current_hpke_key().config().id()).into(), + ), + ); + + // Try to upload the report, verify that we get the expected error. 
+ let error = aggregator + .handle_upload(task.id(), &report.get_encoded()) + .await + .unwrap_err(); + assert_matches!( + error.as_ref(), + Error::ReportRejected(rejection) => { + assert_eq!(task.id(), rejection.task_id()); + assert_eq!(report.metadata().id(), rejection.report_id()); + assert_eq!(report.metadata().time(), rejection.time()); + assert_matches!(rejection.reason(), ReportRejectionReason::DecryptFailure); + } + ); + + let got_counters = datastore + .run_unnamed_tx(|tx| { + let task_id = *task.id(); + Box::pin(async move { tx.get_task_upload_counter(&task_id).await }) + }) + .await + .unwrap(); + assert_eq!( + got_counters, + TaskUploadCounter::new(*task.id(), 0, 0, 1, 0, 0, 0, 0, 0), + ); + } + + #[tokio::test] + async fn upload_report_public_share_decode_failure() { + install_test_trace_subscriber(); + let (_, aggregator, clock, task, datastore, _ephemeral_datastore) = + setup_upload_test(default_aggregator_config()).await; + + let task = task.leader_view().unwrap(); + + let mut report = create_report(&task, clock.now()); + report = Report::new( + report.metadata().clone(), + // Some obviously wrong public share. + vec![0; 10], + report.leader_encrypted_input_share().clone(), + report.helper_encrypted_input_share().clone(), + ); + + // Try to upload the report, verify that we get the expected error. + let error = aggregator + .handle_upload(task.id(), &report.get_encoded()) + .await + .unwrap_err(); + assert_matches!( + error.as_ref(), + Error::ReportRejected(rejection) => { + assert_eq!(task.id(), rejection.task_id()); + assert_eq!(report.metadata().id(), rejection.report_id()); + assert_eq!(report.metadata().time(), rejection.time()); + assert_matches!(rejection.reason(), ReportRejectionReason::DecodeFailure); + } + ); + + let got_counters = datastore + .run_unnamed_tx(|tx| { + let task_id = *task.id(); + Box::pin(async move { tx.get_task_upload_counter(&task_id).await }) + }) + .await + .unwrap(); + assert_eq!( + got_counters, + TaskUploadCounter::new(*task.id(), 0, 1, 0, 0, 0, 0, 0, 0), + ); + } + + #[tokio::test] + async fn upload_report_leader_input_share_decode_failure() { + install_test_trace_subscriber(); + let (_, aggregator, clock, task, datastore, _ephemeral_datastore) = + setup_upload_test(default_aggregator_config()).await; + + let task = task.leader_view().unwrap(); + + let mut report = create_report(&task, clock.now()); + report = Report::new( + report.metadata().clone(), + report.public_share().to_vec(), + hpke::seal( + task.current_hpke_key().config(), + &HpkeApplicationInfo::new(&Label::InputShare, &Role::Client, &Role::Leader), + // Some obviously wrong payload. + &PlaintextInputShare::new(Vec::new(), vec![0; 100]).get_encoded(), + &InputShareAad::new( + *task.id(), + report.metadata().clone(), + report.public_share().to_vec(), + ) + .get_encoded(), + ) + .unwrap(), + report.helper_encrypted_input_share().clone(), + ); + + // Try to upload the report, verify that we get the expected error. 
+ let error = aggregator + .handle_upload(task.id(), &report.get_encoded()) + .await + .unwrap_err(); + assert_matches!( + error.as_ref(), + Error::ReportRejected(rejection) => { + assert_eq!(task.id(), rejection.task_id()); + assert_eq!(report.metadata().id(), rejection.report_id()); + assert_eq!(report.metadata().time(), rejection.time()); + assert_matches!(rejection.reason(), ReportRejectionReason::DecodeFailure); + } + ); + + let got_counters = datastore + .run_unnamed_tx(|tx| { + let task_id = *task.id(); + Box::pin(async move { tx.get_task_upload_counter(&task_id).await }) + }) + .await + .unwrap(); + assert_eq!( + got_counters, + TaskUploadCounter::new(*task.id(), 0, 1, 0, 0, 0, 0, 0, 0), + ); + } + pub(crate) fn generate_helper_report_share>( task_id: TaskId, report_metadata: ReportMetadata, diff --git a/aggregator/src/aggregator/error.rs b/aggregator/src/aggregator/error.rs index 1feb4d494..855e78171 100644 --- a/aggregator/src/aggregator/error.rs +++ b/aggregator/src/aggregator/error.rs @@ -1,4 +1,7 @@ -use janus_aggregator_core::{datastore, task}; +use janus_aggregator_core::{ + datastore::{self, models::TaskUploadIncrementor}, + task, +}; use janus_core::http::HttpErrorResponse; use janus_messages::{ AggregationJobId, AggregationJobStep, CollectionJobId, HpkeConfigId, Interval, PrepareError, @@ -29,14 +32,9 @@ pub enum Error { /// Error handling a message. #[error("invalid message: {0}")] Message(#[from] janus_messages::Error), - /// Corresponds to `reportRejected` in DAP. A report was rejected for some reason that is not - /// specified in DAP. - #[error("task {0}: report {1} rejected: {2}")] - ReportRejected(TaskId, ReportId, Time, ReportRejectedReason), - /// Corresponds to `reportTooEarly` in DAP. A report was rejected because the timestamp is too - /// far in the future. - #[error("task {0}: report {1} too early: {2}")] - ReportTooEarly(TaskId, ReportId, Time), + /// Catch-all error for invalid reports. + #[error("{0}")] + ReportRejected(ReportRejection), /// Corresponds to `invalidMessage` in DAP. #[error("task {0:?}: invalid message: {1}")] InvalidMessage(Option, &'static str), @@ -69,9 +67,6 @@ pub enum Error { /// An attempt was made to act on a collection job that has been abandoned by the aggregator. #[error("abandoned collection job: {0}")] AbandonedCollectionJob(CollectionJobId), - /// Corresponds to `outdatedHpkeConfig` in DAP. - #[error("task {0}: outdated HPKE config: {1}")] - OutdatedHpkeConfig(TaskId, HpkeConfigId), /// Corresponds to `unauthorizedRequest` in DAP. #[error("task {0}: unauthorized request")] UnauthorizedRequest(TaskId), @@ -145,37 +140,109 @@ pub enum Error { DifferentialPrivacy(VdafError), } -#[derive(Debug)] -pub enum ReportRejectedReason { - IntervalAlreadyCollected, - LeaderDecryptFailure, - LeaderInputShareDecodeFailure, - PublicShareDecodeFailure, +/// Contains details that describe the report and why it was rejected. 
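///
/// Usage sketch (values are illustrative; the accessors are the ones defined below):
///
/// ```ignore
/// let rejection = ReportRejection::new(task_id, report_id, time, ReportRejectionReason::TaskExpired);
/// assert_eq!(rejection.reason(), &ReportRejectionReason::TaskExpired);
/// // Renders as "task {task_id}, report {report_id}, time {time}, rejected TaskExpired".
/// println!("{rejection}");
/// ```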
+#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct ReportRejection { + task_id: TaskId, + report_id: ReportId, + time: Time, + reason: ReportRejectionReason, +} + +impl ReportRejection { + pub fn new( + task_id: TaskId, + report_id: ReportId, + time: Time, + reason: ReportRejectionReason, + ) -> Self { + Self { + task_id, + report_id, + time, + reason, + } + } + + pub fn task_id(&self) -> &TaskId { + &self.task_id + } + + pub fn report_id(&self) -> &ReportId { + &self.report_id + } + + pub fn time(&self) -> &Time { + &self.time + } + + pub fn reason(&self) -> &ReportRejectionReason { + &self.reason + } +} + +impl Display for ReportRejection { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + write!( + f, + "task {}, report {}, time {}, rejected {}", + self.task_id, self.report_id, self.time, self.reason + ) + } +} + +/// Indicates why a report was rejected. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum ReportRejectionReason { + IntervalCollected, + DecryptFailure, + DecodeFailure, TaskExpired, - TooOld, + Expired, + TooEarly, + OutdatedHpkeConfig(HpkeConfigId), } -impl ReportRejectedReason { +impl ReportRejectionReason { pub fn detail(&self) -> &'static str { match self { - ReportRejectedReason::IntervalAlreadyCollected => { + ReportRejectionReason::IntervalCollected => { "Report falls into a time interval that has already been collected." } - ReportRejectedReason::LeaderDecryptFailure => { - "Leader's report share could not be decrypted." - } - ReportRejectedReason::LeaderInputShareDecodeFailure => { - "Leader's input share could not be decoded." + ReportRejectionReason::DecryptFailure => "Report share could not be decrypted.", + ReportRejectionReason::DecodeFailure => "Report could not be decoded.", + ReportRejectionReason::TaskExpired => "Task has expired.", + ReportRejectionReason::Expired => "Report timestamp is too old.", + ReportRejectionReason::TooEarly => "Report timestamp is too far in the future.", + ReportRejectionReason::OutdatedHpkeConfig(_) => { + "Report is using an outdated HPKE configuration." } - ReportRejectedReason::PublicShareDecodeFailure => { - "Report public share could not be decoded." + } + } +} + +impl From<&ReportRejectionReason> for TaskUploadIncrementor { + fn from(value: &ReportRejectionReason) -> Self { + match value { + ReportRejectionReason::IntervalCollected => TaskUploadIncrementor::IntervalCollected, + ReportRejectionReason::DecryptFailure => TaskUploadIncrementor::ReportDecryptFailure, + ReportRejectionReason::DecodeFailure => TaskUploadIncrementor::ReportDecodeFailure, + ReportRejectionReason::TaskExpired => TaskUploadIncrementor::TaskExpired, + ReportRejectionReason::Expired => TaskUploadIncrementor::ReportExpired, + ReportRejectionReason::TooEarly => TaskUploadIncrementor::ReportTooEarly, + ReportRejectionReason::OutdatedHpkeConfig(_) => { + TaskUploadIncrementor::ReportOutdatedKey } - ReportRejectedReason::TaskExpired => "Task has expired.", - ReportRejectedReason::TooOld => "Report timestamp is too old.", } } } +impl Display for ReportRejectionReason { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + write!(f, "{:?}", self) + } +} + /// Errors that cause the aggregator to opt-out of a taskprov task. 
#[derive(Debug, thiserror::Error)] pub enum OptOutReason { @@ -199,8 +266,11 @@ impl Error { Error::InvalidConfiguration(_) => "invalid_configuration", Error::MessageDecode(_) => "message_decode", Error::Message(_) => "message", - Error::ReportRejected(_, _, _, _) => "report_rejected", - Error::ReportTooEarly(_, _, _) => "report_too_early", + Error::ReportRejected(rejection) => match rejection.reason { + ReportRejectionReason::TooEarly => "report_too_early", + ReportRejectionReason::OutdatedHpkeConfig(_) => "outdated_hpke_config", + _ => "report_rejected", + }, Error::InvalidMessage(_, _) => "unrecognized_message", Error::StepMismatch { .. } => "step_mismatch", Error::UnrecognizedTask(_) => "unrecognized_task", @@ -209,7 +279,6 @@ impl Error { Error::DeletedCollectionJob(_) => "deleted_collection_job", Error::AbandonedCollectionJob(_) => "abandoned_collection_job", Error::UnrecognizedCollectionJob(_) => "unrecognized_collection_job", - Error::OutdatedHpkeConfig(_, _) => "outdated_hpke_config", Error::UnauthorizedRequest(_) => "unauthorized_request", Error::Datastore(_) => "datastore", Error::Vdaf(_) => "vdaf", diff --git a/aggregator/src/aggregator/http_handlers.rs b/aggregator/src/aggregator/http_handlers.rs index a6bfe22c0..ae5386890 100644 --- a/aggregator/src/aggregator/http_handlers.rs +++ b/aggregator/src/aggregator/http_handlers.rs @@ -1,4 +1,4 @@ -use super::{Aggregator, Config, Error}; +use super::{error::ReportRejectionReason, Aggregator, Config, Error}; use crate::aggregator::problem_details::{ProblemDetailsConnExt, ProblemDocument}; use async_trait::async_trait; use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine}; @@ -46,11 +46,21 @@ impl Handler for Error { Error::InvalidConfiguration(_) => conn.with_status(Status::InternalServerError), Error::MessageDecode(_) => conn .with_problem_document(&ProblemDocument::new_dap(DapProblemType::InvalidMessage)), - Error::ReportRejected(task_id, _, _, reason) => conn.with_problem_document( - &ProblemDocument::new_dap(DapProblemType::ReportRejected) - .with_task_id(task_id) - .with_detail(reason.detail()), - ), + Error::ReportRejected(rejection) => match rejection.reason() { + ReportRejectionReason::OutdatedHpkeConfig(_) => conn.with_problem_document( + &ProblemDocument::new_dap(DapProblemType::OutdatedConfig) + .with_task_id(rejection.task_id()), + ), + ReportRejectionReason::TooEarly => conn.with_problem_document( + &ProblemDocument::new_dap(DapProblemType::ReportTooEarly) + .with_task_id(rejection.task_id()), + ), + _ => conn.with_problem_document( + &ProblemDocument::new_dap(DapProblemType::ReportRejected) + .with_task_id(rejection.task_id()) + .with_detail(rejection.reason().detail()), + ), + }, Error::InvalidMessage(task_id, _) => { let mut doc = ProblemDocument::new_dap(DapProblemType::InvalidMessage); if let Some(task_id) = task_id { @@ -86,12 +96,6 @@ impl Handler for Error { .with_collection_job_id(collection_job_id), ), Error::UnrecognizedCollectionJob(_) => conn.with_status(Status::NotFound), - Error::OutdatedHpkeConfig(task_id, _) => conn.with_problem_document( - &ProblemDocument::new_dap(DapProblemType::OutdatedConfig).with_task_id(task_id), - ), - Error::ReportTooEarly(task_id, _, _) => conn.with_problem_document( - &ProblemDocument::new_dap(DapProblemType::ReportTooEarly).with_task_id(task_id), - ), Error::UnauthorizedRequest(task_id) => conn.with_problem_document( &ProblemDocument::new_dap(DapProblemType::UnauthorizedRequest) .with_task_id(task_id), @@ -682,7 +686,7 @@ mod tests { }, 
collection_job_tests::setup_collection_job_test_case, empty_batch_aggregations, - error::ReportRejectedReason, + error::ReportRejectionReason, http_handlers::{ aggregator_handler, aggregator_handler_with_aggregator, test_util::{decode_response_body, take_problem_details}, @@ -1196,7 +1200,7 @@ mod tests { "reportRejected", "Report could not be processed.", task.id(), - Some(ReportRejectedReason::TooOld.detail()), + Some(ReportRejectionReason::Expired.detail()), ) .await; @@ -1286,7 +1290,7 @@ mod tests { "reportRejected", "Report could not be processed.", task_expire_soon.id(), - Some(ReportRejectedReason::TaskExpired.detail()), + Some(ReportRejectionReason::TaskExpired.detail()), ) .await; @@ -1314,7 +1318,7 @@ mod tests { "reportRejected", "Report could not be processed.", leader_task.id(), - Some(ReportRejectedReason::PublicShareDecodeFailure.detail()), + Some(ReportRejectionReason::DecodeFailure.detail()), ) .await; @@ -1339,7 +1343,7 @@ mod tests { "reportRejected", "Report could not be processed.", leader_task.id(), - Some(ReportRejectedReason::LeaderDecryptFailure.detail()), + Some(ReportRejectionReason::DecryptFailure.detail()), ) .await; @@ -1376,7 +1380,7 @@ mod tests { "reportRejected", "Report could not be processed.", leader_task.id(), - Some(ReportRejectedReason::LeaderInputShareDecodeFailure.detail()), + Some(ReportRejectionReason::DecodeFailure.detail()), ) .await; diff --git a/aggregator/src/aggregator/problem_details.rs b/aggregator/src/aggregator/problem_details.rs index f2cccd775..230a25f4d 100644 --- a/aggregator/src/aggregator/problem_details.rs +++ b/aggregator/src/aggregator/problem_details.rs @@ -109,7 +109,7 @@ impl ProblemDetailsConnExt for Conn { #[cfg(test)] mod tests { use crate::aggregator::{ - error::{BatchMismatch, ReportRejectedReason}, + error::{BatchMismatch, ReportRejection, ReportRejectionReason}, send_request_to_helper, Error, }; use assert_matches::assert_matches; @@ -119,7 +119,7 @@ mod tests { use janus_core::time::{Clock, RealClock}; use janus_messages::{ problem_type::{DapProblemType, DapProblemTypeParseError}, - Duration, HpkeConfigId, Interval, ReportIdChecksum, + Duration, Interval, ReportIdChecksum, }; use opentelemetry::metrics::Unit; use rand::random; @@ -179,15 +179,37 @@ mod tests { TestCase::new(Box::new(|| Error::InvalidConfiguration("test")), None), TestCase::new( Box::new(|| { - Error::ReportRejected( + Error::ReportRejected(ReportRejection::new( random(), random(), RealClock::default().now(), - ReportRejectedReason::TaskExpired - ) + ReportRejectionReason::TaskExpired + )) }), Some(DapProblemType::ReportRejected), ), + TestCase::new( + Box::new(|| { + Error::ReportRejected(ReportRejection::new( + random(), + random(), + RealClock::default().now(), + ReportRejectionReason::TooEarly + )) + }), + Some(DapProblemType::ReportTooEarly), + ), + TestCase::new( + Box::new(|| { + Error::ReportRejected(ReportRejection::new( + random(), + random(), + RealClock::default().now(), + ReportRejectionReason::OutdatedHpkeConfig(random()), + )) + }), + Some(DapProblemType::OutdatedConfig), + ), TestCase::new( Box::new(|| Error::InvalidMessage(Some(random()), "test")), Some(DapProblemType::InvalidMessage), @@ -204,16 +226,6 @@ mod tests { Box::new(|| Error::UnrecognizedAggregationJob(random(), random())), Some(DapProblemType::UnrecognizedAggregationJob), ), - TestCase::new( - Box::new(|| Error::OutdatedHpkeConfig(random(), HpkeConfigId::from(0))), - Some(DapProblemType::OutdatedConfig), - ), - TestCase::new( - Box::new(|| { - 
Error::ReportTooEarly(random(), random(), RealClock::default().now()) - }), - Some(DapProblemType::ReportTooEarly), - ), TestCase::new( Box::new(|| Error::UnauthorizedRequest(random())), Some(DapProblemType::UnauthorizedRequest), diff --git a/aggregator/src/aggregator/query_type.rs b/aggregator/src/aggregator/query_type.rs index 2177021f0..365c7e0af 100644 --- a/aggregator/src/aggregator/query_type.rs +++ b/aggregator/src/aggregator/query_type.rs @@ -1,4 +1,7 @@ -use super::{error::ReportRejectedReason, Error}; +use super::{ + error::{ReportRejection, ReportRejectionReason}, + Error, +}; use async_trait::async_trait; use janus_aggregator_core::{ datastore::{self, models::LeaderStoredReport, Transaction}, @@ -23,7 +26,7 @@ pub trait UploadableQueryType: QueryType { tx: &Transaction<'_, C>, vdaf: &A, report: &LeaderStoredReport, - ) -> Result<(), datastore::Error> + ) -> Result<(), Error> where A::InputShare: Send + Sync, A::PublicShare: Send + Sync; @@ -39,7 +42,7 @@ impl UploadableQueryType for TimeInterval { tx: &Transaction<'_, C>, vdaf: &A, report: &LeaderStoredReport, - ) -> Result<(), datastore::Error> + ) -> Result<(), Error> where A::InputShare: Send + Sync, A::PublicShare: Send + Sync, @@ -55,15 +58,12 @@ impl UploadableQueryType for TimeInterval { ) .await?; if !conflicting_collect_jobs.is_empty() { - return Err(datastore::Error::User( - Error::ReportRejected( - *report.task_id(), - *report.metadata().id(), - *report.metadata().time(), - ReportRejectedReason::IntervalAlreadyCollected, - ) - .into(), - )); + return Err(Error::ReportRejected(ReportRejection::new( + *report.task_id(), + *report.metadata().id(), + *report.metadata().time(), + ReportRejectionReason::IntervalCollected, + ))); } Ok(()) } @@ -79,7 +79,7 @@ impl UploadableQueryType for FixedSize { _: &Transaction<'_, C>, _: &A, _: &LeaderStoredReport, - ) -> Result<(), datastore::Error> { + ) -> Result<(), Error> { // Fixed-size tasks associate reports to batches at time of aggregation rather than at time // of upload, and there are no other relevant checks to apply here, so this method simply // returns Ok(()). 
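A note on the `validate_uploaded_report` signature change above: returning the aggregator's `Error` directly, rather than wrapping it in `datastore::Error::User`, lets the report writer pattern-match `Error::ReportRejected` and translate each rejection reason into the matching counter column. A minimal, self-contained sketch of that translation idiom, using simplified stand-ins for `ReportRejectionReason` and `TaskUploadIncrementor` (the real enums, with more variants, live in `error.rs` and `datastore/models.rs`):

```rust
// Simplified stand-ins for the patch's enums.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Reason {
    IntervalCollected,
    TooEarly,
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Incrementor {
    IntervalCollected,
    ReportTooEarly,
}

// Mirrors `impl From<&ReportRejectionReason> for TaskUploadIncrementor`:
// borrowing the reason lets callers keep the rejection for error reporting.
impl From<&Reason> for Incrementor {
    fn from(value: &Reason) -> Self {
        match value {
            Reason::IntervalCollected => Incrementor::IntervalCollected,
            Reason::TooEarly => Incrementor::ReportTooEarly,
        }
    }
}

fn main() {
    // The upload path calls `rejection.reason().into()` to pick the column
    // to increment; this is the same conversion in miniature.
    let reason = Reason::IntervalCollected;
    let incrementor: Incrementor = (&reason).into();
    assert_eq!(incrementor, Incrementor::IntervalCollected);
}
```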
diff --git a/aggregator/src/aggregator/report_writer.rs b/aggregator/src/aggregator/report_writer.rs index de1bae6c2..acc4f790b 100644 --- a/aggregator/src/aggregator/report_writer.rs +++ b/aggregator/src/aggregator/report_writer.rs
@@ -1,9 +1,15 @@
 use crate::aggregator::{query_type::UploadableQueryType, Error};
 use async_trait::async_trait;
 use futures::future::join_all;
-use janus_aggregator_core::datastore::{self, models::LeaderStoredReport, Datastore, Transaction};
+use janus_aggregator_core::datastore::{
+    self,
+    models::{LeaderStoredReport, TaskUploadIncrementor},
+    Datastore, Transaction,
+};
 use janus_core::time::Clock;
+use janus_messages::TaskId;
 use prio::vdaf;
+use rand::{thread_rng, Rng};
 use std::{fmt::Debug, marker::PhantomData, mem::replace, sync::Arc, time::Duration};
 use tokio::{
     select,
@@ -12,14 +18,14 @@ use tokio::{
 };
 use tracing::debug;

-type ReportWriteBatcherSender<C> = mpsc::Sender<(
-    Box<dyn ReportWriter<C>>,
-    oneshot::Sender<Result<(), Arc<Error>>>,
-)>;
-type ReportWriteBatcherReceiver<C> = mpsc::Receiver<(
-    Box<dyn ReportWriter<C>>,
-    oneshot::Sender<Result<(), Arc<Error>>>,
-)>;
+use super::error::ReportRejection;
+
+type ReportResult<C> = Result<Box<dyn ReportWriter<C>>, ReportRejection>;
+
+type ResultSender = oneshot::Sender<Result<(), Arc<Error>>>;
+
+type ReportWriteBatcherSender<C> = mpsc::Sender<(ReportResult<C>, ResultSender)>;
+type ReportWriteBatcherReceiver<C> = mpsc::Receiver<(ReportResult<C>, ResultSender)>;

 pub struct ReportWriteBatcher<C: Clock> {
     report_tx: ReportWriteBatcherSender<C>,
 }

 impl<C: Clock> ReportWriteBatcher<C> {
     pub fn new(
         ds: Arc<Datastore<C>>,
+        counter_shard_count: u64,
         max_batch_size: usize,
         max_batch_write_delay: Duration,
     ) -> Self {
         let (report_tx, report_rx) = mpsc::channel(1);
         tokio::spawn(async move {
-            Self::run_upload_batcher(ds, report_rx, max_batch_size, max_batch_write_delay).await
+            Self::run_upload_batcher(
+                ds,
+                report_rx,
+                counter_shard_count,
+                max_batch_size,
+                max_batch_write_delay,
+            )
+            .await
         });

         Self { report_tx }
     }

-    pub async fn write_report<R: ReportWriter<C> + 'static>(
-        &self,
-        report: R,
-    ) -> Result<(), Arc<Error>> {
+    pub async fn write_report(&self, report_result: ReportResult<C>) -> Result<(), Arc<Error>> {
         // Send report to be written.
         // Unwrap safety: report_rx is not dropped until ReportWriteBatcher is dropped.
         let (rslt_tx, rslt_rx) = oneshot::channel();
-        self.report_tx
-            .send((Box::new(report), rslt_tx))
-            .await
-            .unwrap();
+        self.report_tx.send((report_result, rslt_tx)).await.unwrap();

         // Await the result of writing the report.
         // Unwrap safety: rslt_tx is always sent on before being dropped, and is never closed.
@@ -61,13 +69,13 @@ impl<C: Clock> ReportWriteBatcher<C> {
     async fn run_upload_batcher(
         ds: Arc<Datastore<C>>,
         mut report_rx: ReportWriteBatcherReceiver<C>,
+        counter_shard_count: u64,
         max_batch_size: usize,
         max_batch_write_delay: Duration,
     ) {
         let mut is_done = false;
         let mut batch_expiry = Instant::now();
-        let mut report_writers = Vec::with_capacity(max_batch_size);
-        let mut result_txs = Vec::with_capacity(max_batch_size);
+        let mut report_results = Vec::with_capacity(max_batch_size);

         while !is_done {
             // Wait for an event of interest.
             let write_batch = select! {
@@ -76,36 +84,34 @@ impl<C: Clock> ReportWriteBatcher<C> {
                 item = report_rx.recv() => {
                     match item {
                         // We got an item. Add it to the current batch of reports to be written.
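                        // (An item is either an accepted report to be stored, or a
                        // ReportRejection whose counter update rides in the same batch.)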
- Some((report_writer, rslt_tx)) => { - if report_writers.is_empty() { + Some(report) => { + if report_results.is_empty() { batch_expiry = Instant::now() + max_batch_write_delay; } - report_writers.push(report_writer); - result_txs.push(rslt_tx); - report_writers.len() >= max_batch_size + report_results.push(report); + report_results.len() >= max_batch_size } // The channel is closed. Note this, and write any final reports that may be // batched before shutting down. None => { is_done = true; - !report_writers.is_empty() + !report_results.is_empty() }, } }, // ... or the current batch, if there is one, times out. - _ = sleep_until(batch_expiry), if !report_writers.is_empty() => true, + _ = sleep_until(batch_expiry), if !report_results.is_empty() => true, }; // If the event made us want to write the current batch to storage, do so. if write_batch { let ds = Arc::clone(&ds); - let result_writers = - replace(&mut report_writers, Vec::with_capacity(max_batch_size)); - let result_txs = replace(&mut result_txs, Vec::with_capacity(max_batch_size)); + let report_results = + replace(&mut report_results, Vec::with_capacity(max_batch_size)); tokio::spawn(async move { - Self::write_batch(ds, result_writers, result_txs).await; + Self::write_batch(ds, counter_shard_count, report_results).await; }); } } @@ -114,19 +120,43 @@ impl ReportWriteBatcher { #[tracing::instrument(skip_all)] async fn write_batch( ds: Arc>, - report_writers: Vec>>, - result_txs: Vec>>>, + counter_shard_count: u64, + mut report_results: Vec<(ReportResult, ResultSender)>, ) { - // Check preconditions. - assert_eq!(report_writers.len(), result_txs.len()); + let ord = thread_rng().gen_range(0..counter_shard_count); + + // Sort by task ID to prevent deadlocks with concurrently running transactions. Since we are + // using the same ord for all statements, we do not need to sort by ord. + report_results.sort_unstable_by_key(|writer| match &writer.0 { + Ok(report_writer) => *report_writer.task_id(), + Err(rejection) => *rejection.task_id(), + }); // Run all report writes concurrently. - let report_writers = Arc::new(report_writers); + let (report_results, result_senders): (Vec>, Vec) = + report_results.into_iter().unzip(); + let report_results = Arc::new(report_results); let rslts = ds .run_tx("upload", |tx| { - let report_writers = Arc::clone(&report_writers); + let report_results = Arc::clone(&report_results); Box::pin(async move { - Ok(join_all(report_writers.iter().map(|rw| rw.write_report(tx))).await) + Ok( + join_all(report_results.iter().map(|report_result| async move { + match report_result { + Ok(report_writer) => report_writer.write_report(tx, ord).await, + Err(rejection) => { + tx.increment_task_upload_counter( + rejection.task_id(), + ord, + &rejection.reason().into(), + ) + .await?; + Ok(()) + } + } + })) + .await, + ) }) }) .await; @@ -134,12 +164,9 @@ impl ReportWriteBatcher { match rslts { Ok(rslts) => { // Individual, per-request results. - assert_eq!(result_txs.len(), rslts.len()); // sanity check: should be guaranteed. - for (rslt_tx, rslt) in result_txs.into_iter().zip(rslts.into_iter()) { - if rslt_tx - .send(rslt.map_err(|err| Arc::new(Error::from(err)))) - .is_err() - { + assert_eq!(result_senders.len(), rslts.len()); // sanity check: should be guaranteed. 
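                // Fan each per-report result back to the uploader that requested it;
                // a dropped receiver just means that request was cancelled.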
+ for (rslt_tx, rslt) in result_senders.into_iter().zip(rslts.into_iter()) { + if rslt_tx.send(rslt.map_err(Arc::new)).is_err() { debug!( "ReportWriter couldn't send result to requester (request cancelled?)" ); @@ -149,7 +176,7 @@ impl ReportWriteBatcher { Err(err) => { // Total-transaction failures are given to all waiting report uploaders. let err = Arc::new(Error::from(err)); - for rslt_tx in result_txs.into_iter() { + for rslt_tx in result_senders.into_iter() { if rslt_tx.send(Err(Arc::clone(&err))).is_err() { debug!( "ReportWriter couldn't send result to requester (request cancelled?)" @@ -163,7 +190,8 @@ impl ReportWriteBatcher { #[async_trait] pub trait ReportWriter: Debug + Send + Sync { - async fn write_report(&self, tx: &Transaction) -> Result<(), datastore::Error>; + fn task_id(&self) -> &TaskId; + async fn write_report(&self, tx: &Transaction, ord: u64) -> Result<(), Error>; } #[derive(Debug)] @@ -210,19 +238,45 @@ where C: Clock, Q: UploadableQueryType, { - async fn write_report(&self, tx: &Transaction) -> Result<(), datastore::Error> { - Q::validate_uploaded_report(tx, self.vdaf.as_ref(), &self.report).await?; - - // Store the report. - match tx - .put_client_report::(&self.vdaf, &self.report) - .await - { - // If the report already existed in the datastore, assume it is a duplicate and return - // OK. - Ok(()) | Err(datastore::Error::MutationTargetAlreadyExists) => Ok(()), + fn task_id(&self) -> &TaskId { + self.report.task_id() + } - err => err, + async fn write_report(&self, tx: &Transaction, ord: u64) -> Result<(), Error> { + // Some validation requires we query the database. Thus it's still possible to reject a + // report at this stage. + match Q::validate_uploaded_report(tx, self.vdaf.as_ref(), &self.report).await { + Ok(_) => { + let result = tx + .put_client_report::(&self.vdaf, &self.report) + .await; + match result { + Ok(_) => { + tx.increment_task_upload_counter( + self.report.task_id(), + ord, + &TaskUploadIncrementor::ReportSuccess, + ) + .await?; + Ok(()) + } + // Assume this was a duplicate report, return OK but don't increment the counter + // so we avoid double counting successful reports. + Err(datastore::Error::MutationTargetAlreadyExists) => Ok(()), + Err(error) => Err(error.into()), + } + } + Err(error) => { + if let Error::ReportRejected(rejection) = error { + tx.increment_task_upload_counter( + rejection.task_id(), + ord, + &rejection.reason().into(), + ) + .await?; + } + Err(error) + } } } } diff --git a/aggregator/src/binaries/aggregator.rs b/aggregator/src/binaries/aggregator.rs index a3fb382ba..535357df8 100644 --- a/aggregator/src/binaries/aggregator.rs +++ b/aggregator/src/binaries/aggregator.rs @@ -358,6 +358,12 @@ pub struct Config { /// the cost of collection. pub batch_aggregation_shard_count: u64, + /// Defines the number of shards to break report counters into. Increasing this value will + /// reduce the amount of database contention during report uploads, while increasing the cost + /// of getting task metrics. + #[serde(default = "default_task_counter_shard_count")] + pub task_counter_shard_count: u64, + /// Defines how often to refresh the global HPKE configs cache in milliseconds. This affects how /// often an aggregator becomes aware of key state changes. If unspecified, default is defined /// by [`GlobalHpkeKeypairCache::DEFAULT_REFRESH_INTERVAL`]. 
You shouldn't normally have to @@ -366,6 +372,10 @@ pub struct Config { pub global_hpke_configs_refresh_interval: Option, } +fn default_task_counter_shard_count() -> u64 { + 1 +} + #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] pub struct GarbageCollectorConfig { /// How frequently garbage collection is run, in seconds. @@ -417,6 +427,7 @@ impl Config { self.max_upload_batch_write_delay_ms, ), batch_aggregation_shard_count: self.batch_aggregation_shard_count, + task_counter_shard_count: self.task_counter_shard_count, taskprov_config: self.taskprov_config.clone(), global_hpke_configs_refresh_interval: match self.global_hpke_configs_refresh_interval { Some(duration) => Duration::from_millis(duration), @@ -500,6 +511,7 @@ mod tests { max_upload_batch_size: 100, max_upload_batch_write_delay_ms: 250, batch_aggregation_shard_count: 32, + task_counter_shard_count: 64, taskprov_config: TaskprovConfig::default(), global_hpke_configs_refresh_interval: None, }) diff --git a/aggregator/tests/integration/graceful_shutdown.rs b/aggregator/tests/integration/graceful_shutdown.rs index b774efebb..76652b2c9 100644 --- a/aggregator/tests/integration/graceful_shutdown.rs +++ b/aggregator/tests/integration/graceful_shutdown.rs @@ -265,6 +265,7 @@ async fn aggregator_shutdown() { max_upload_batch_size: 100, max_upload_batch_write_delay_ms: 250, batch_aggregation_shard_count: 32, + task_counter_shard_count: 64, global_hpke_configs_refresh_interval: None, }; diff --git a/aggregator_core/src/datastore.rs b/aggregator_core/src/datastore.rs index 7fcbcfc1a..197d91740 100644 --- a/aggregator_core/src/datastore.rs +++ b/aggregator_core/src/datastore.rs @@ -5,7 +5,8 @@ use self::models::{ AggregatorRole, AuthenticationTokenType, Batch, BatchAggregation, CollectionJob, CollectionJobState, CollectionJobStateCode, GlobalHpkeKeypair, HpkeKeyState, LeaderStoredReport, Lease, LeaseToken, OutstandingBatch, ReportAggregation, - ReportAggregationState, ReportAggregationStateCode, SqlInterval, + ReportAggregationState, ReportAggregationStateCode, SqlInterval, TaskUploadCounter, + TaskUploadIncrementor, }; use crate::{ query_type::{AccumulableQueryType, CollectableQueryType}, @@ -4731,6 +4732,111 @@ impl Transaction<'_, C> { .await?; check_single_row_mutation(self.execute(&stmt, &[&aggregator_url, &role]).await?) } + + /// Get the [`TaskUploadCounter`] for a task. This is aggregated across all shards. 
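///
/// Usage sketch (within a datastore transaction; `task_id` is illustrative):
///
/// ```ignore
/// let counter = tx.get_task_upload_counter(&task_id).await?;
/// ```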
+ #[tracing::instrument(skip(self), err)] + pub async fn get_task_upload_counter( + &self, + task_id: &TaskId, + ) -> Result { + let stmt = self + .prepare_cached( + "SELECT + COALESCE(SUM(interval_collected)::BIGINT, 0) AS interval_collected, + COALESCE(SUM(report_decode_failure)::BIGINT, 0) AS report_decode_failure, + COALESCE(SUM(report_decrypt_failure)::BIGINT, 0) AS report_decrypt_failure, + COALESCE(SUM(report_expired)::BIGINT, 0) AS report_expired, + COALESCE(SUM(report_outdated_key)::BIGINT, 0) AS report_outdated_key, + COALESCE(SUM(report_success)::BIGINT, 0) AS report_success, + COALESCE(SUM(report_too_early)::BIGINT, 0) AS report_too_early, + COALESCE(SUM(task_expired)::BIGINT, 0) AS task_expired + FROM task_upload_counters + WHERE task_id = (SELECT id FROM tasks WHERE task_id = $1)", + ) + .await?; + + let row = self.query_one(&stmt, &[task_id.as_ref()]).await?; + Ok(TaskUploadCounter { + task_id: *task_id, + interval_collected: row.get_bigint_and_convert("interval_collected")?, + report_decode_failure: row.get_bigint_and_convert("report_decode_failure")?, + report_decrypt_failure: row.get_bigint_and_convert("report_decrypt_failure")?, + report_expired: row.get_bigint_and_convert("report_expired")?, + report_outdated_key: row.get_bigint_and_convert("report_outdated_key")?, + report_success: row.get_bigint_and_convert("report_success")?, + report_too_early: row.get_bigint_and_convert("report_too_early")?, + task_expired: row.get_bigint_and_convert("task_expired")?, + }) + } + + /// Add one to the counter associated with the given [`TaskId`]. The column to increment is given + /// by [`TaskUploadIncrementor`]. This is sharded, requiring an `ord` parameter to determine which + /// shard to add to. `ord` should be randomly generated by the caller. + #[tracing::instrument(skip(self), err)] + pub async fn increment_task_upload_counter( + &self, + task_id: &TaskId, + ord: u64, + incrementor: &TaskUploadIncrementor, + ) -> Result<(), Error> { + // Brute force each possible query. We cannot parameterize column names in prepared + // statements and we want to avoid the hazards of string interpolation into SQL. 
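        // Every statement shares one shape: an upsert keyed on the UNIQUE(task_id, ord)
        // constraint, so a (task, shard) counter row is created on first use and
        // incremented atomically thereafter.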
+ let stmt = self + .prepare_cached(match incrementor { + TaskUploadIncrementor::IntervalCollected => { + "INSERT INTO task_upload_counters (task_id, ord, interval_collected) + VALUES ((SELECT id FROM tasks WHERE task_id = $1), $2, 1) + ON CONFLICT (task_id, ord) DO UPDATE + SET interval_collected = task_upload_counters.interval_collected + 1" + } + TaskUploadIncrementor::ReportDecodeFailure => { + "INSERT INTO task_upload_counters (task_id, ord, report_decode_failure) + VALUES ((SELECT id FROM tasks WHERE task_id = $1), $2, 1) + ON CONFLICT (task_id, ord) DO UPDATE + SET report_decode_failure = task_upload_counters.report_decode_failure + 1" + } + TaskUploadIncrementor::ReportDecryptFailure => { + "INSERT INTO task_upload_counters (task_id, ord, report_decrypt_failure) + VALUES ((SELECT id FROM tasks WHERE task_id = $1), $2, 1) + ON CONFLICT (task_id, ord) DO UPDATE + SET report_decrypt_failure = task_upload_counters.report_decrypt_failure + 1" + } + TaskUploadIncrementor::ReportExpired => { + "INSERT INTO task_upload_counters (task_id, ord, report_expired) + VALUES ((SELECT id FROM tasks WHERE task_id = $1), $2, 1) + ON CONFLICT (task_id, ord) DO UPDATE + SET report_expired = task_upload_counters.report_expired + 1" + } + TaskUploadIncrementor::ReportOutdatedKey => { + "INSERT INTO task_upload_counters (task_id, ord, report_outdated_key) + VALUES ((SELECT id FROM tasks WHERE task_id = $1), $2, 1) + ON CONFLICT (task_id, ord) DO UPDATE + SET report_outdated_key = task_upload_counters.report_outdated_key + 1" + } + TaskUploadIncrementor::ReportSuccess => { + "INSERT INTO task_upload_counters (task_id, ord, report_success) + VALUES ((SELECT id FROM tasks WHERE task_id = $1), $2, 1) + ON CONFLICT (task_id, ord) DO UPDATE + SET report_success = task_upload_counters.report_success + 1" + } + TaskUploadIncrementor::ReportTooEarly => { + "INSERT INTO task_upload_counters (task_id, ord, report_too_early) + VALUES ((SELECT id FROM tasks WHERE task_id = $1), $2, 1) + ON CONFLICT (task_id, ord) DO UPDATE + SET report_too_early = task_upload_counters.report_too_early + 1" + } + TaskUploadIncrementor::TaskExpired => { + "INSERT INTO task_upload_counters (task_id, ord, task_expired) + VALUES ((SELECT id FROM tasks WHERE task_id = $1), $2, 1) + ON CONFLICT (task_id, ord) DO UPDATE + SET task_expired = task_upload_counters.task_expired + 1" + } + }) + .await?; + let params: &[&(dyn ToSql + Sync)] = &[task_id.as_ref(), &i64::try_from(ord)?]; + + check_single_row_mutation(self.execute(&stmt, params).await?) + } } fn check_insert(row_count: u64) -> Result<(), Error> { diff --git a/aggregator_core/src/datastore/models.rs b/aggregator_core/src/datastore/models.rs index 995fa34ef..dd65c3806 100644 --- a/aggregator_core/src/datastore/models.rs +++ b/aggregator_core/src/datastore/models.rs @@ -1353,7 +1353,6 @@ pub enum CollectionJobStateCode { /// AggregateShareJob represents a row in the `aggregate_share_jobs` table, used by helpers to /// store the results of handling an AggregateShareReq from the leader. - #[derive(Clone, Derivative)] #[derivative(Debug)] pub struct AggregateShareJob< @@ -1539,7 +1538,6 @@ pub enum BatchState { } /// Represents the state of a given batch (and aggregation parameter). - #[derive(Clone, Derivative)] #[derivative(Debug)] pub struct Batch> { @@ -1847,3 +1845,70 @@ impl GlobalHpkeKeypair { &self.updated_at } } + +/// Per-task counts of uploaded reports and upload attempts. 
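/// Values are read back as sums over the sharded `task_upload_counters` rows; see
/// `get_task_upload_counter`.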
+#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)] +pub struct TaskUploadCounter { + pub(crate) task_id: TaskId, + + pub(crate) interval_collected: u64, + pub(crate) report_decode_failure: u64, + pub(crate) report_decrypt_failure: u64, + pub(crate) report_expired: u64, + pub(crate) report_outdated_key: u64, + pub(crate) report_success: u64, + pub(crate) report_too_early: u64, + pub(crate) task_expired: u64, +} + +impl TaskUploadCounter { + /// Create a new [`TaskUploadCounter`]. + /// + /// This is locked behind test-util since production code shouldn't need to manually create a + /// counter. + #[allow(clippy::too_many_arguments)] + #[cfg(feature = "test-util")] + pub fn new( + task_id: TaskId, + interval_collected: u64, + report_decode_failure: u64, + report_decrypt_failure: u64, + report_expired: u64, + report_outdated_key: u64, + report_success: u64, + report_too_early: u64, + task_expired: u64, + ) -> Self { + Self { + task_id, + interval_collected, + report_decode_failure, + report_decrypt_failure, + report_expired, + report_outdated_key, + report_success, + report_too_early, + task_expired, + } + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum TaskUploadIncrementor { + /// A report fell into a time interval that has already been collected. + IntervalCollected, + /// A report could not be decoded. + ReportDecodeFailure, + /// A report could not be decrypted. + ReportDecryptFailure, + /// A report contains a timestamp too far in the past. + ReportExpired, + /// A report is encrypted with an old or unknown HPKE key. + ReportOutdatedKey, + /// A report was successfully uploaded. + ReportSuccess, + /// A report contains a timestamp too far in the future. + ReportTooEarly, + /// A report was submitted to the task after the task's expiry. 
+ TaskExpired, +} diff --git a/aggregator_core/src/datastore/tests.rs b/aggregator_core/src/datastore/tests.rs index 24dfec808..860f02a1b 100644 --- a/aggregator_core/src/datastore/tests.rs +++ b/aggregator_core/src/datastore/tests.rs @@ -5,6 +5,7 @@ use crate::{ AggregationJobState, Batch, BatchAggregation, BatchAggregationState, BatchState, CollectionJob, CollectionJobState, GlobalHpkeKeypair, HpkeKeyState, LeaderStoredReport, Lease, OutstandingBatch, ReportAggregation, ReportAggregationState, SqlInterval, + TaskUploadCounter, TaskUploadIncrementor, }, schema_versions_template, test_util::{ephemeral_datastore_schema_version, generate_aead_key, EphemeralDatastore}, @@ -7438,3 +7439,79 @@ async fn reject_expired_reports_with_same_id(ephemeral_datastore: EphemeralDatas .await; assert_matches!(result, Err(Error::MutationTargetAlreadyExists)); } + +#[rstest_reuse::apply(schema_versions_template)] +#[tokio::test] +async fn roundtrip_task_upload_counter(ephemeral_datastore: EphemeralDatastore) { + install_test_trace_subscriber(); + let clock = MockClock::default(); + let datastore = ephemeral_datastore.datastore(clock.clone()).await; + + let task = TaskBuilder::new(task::QueryType::TimeInterval, VdafInstance::Fake) + .build() + .leader_view() + .unwrap(); + + datastore.put_aggregator_task(&task).await.unwrap(); + + datastore + .run_unnamed_tx(|tx| { + let task_id = *task.id(); + Box::pin(async move { + let counter = tx.get_task_upload_counter(&task_id).await.unwrap(); + assert_eq!( + counter, + TaskUploadCounter { + task_id, + interval_collected: 0, + report_decode_failure: 0, + report_decrypt_failure: 0, + report_expired: 0, + report_success: 0, + report_too_early: 0, + report_outdated_key: 0, + task_expired: 0, + } + ); + + for case in [ + (TaskUploadIncrementor::IntervalCollected, 2), + (TaskUploadIncrementor::ReportDecodeFailure, 4), + (TaskUploadIncrementor::ReportDecryptFailure, 6), + (TaskUploadIncrementor::ReportExpired, 8), + (TaskUploadIncrementor::ReportOutdatedKey, 10), + (TaskUploadIncrementor::ReportSuccess, 100), + (TaskUploadIncrementor::ReportTooEarly, 25), + (TaskUploadIncrementor::TaskExpired, 12), + ] { + let ord = thread_rng().gen_range(0..32); + try_join_all( + (0..case.1) + .map(|_| tx.increment_task_upload_counter(&task_id, ord, &case.0)), + ) + .await + .unwrap(); + } + + let counter = tx.get_task_upload_counter(&task_id).await.unwrap(); + assert_eq!( + counter, + TaskUploadCounter { + task_id, + interval_collected: 2, + report_decode_failure: 4, + report_decrypt_failure: 6, + report_expired: 8, + report_outdated_key: 10, + report_success: 100, + report_too_early: 25, + task_expired: 12, + } + ); + + Ok(()) + }) + }) + .await + .unwrap(); +} diff --git a/db/00000000000001_initial_schema.down.sql b/db/00000000000001_initial_schema.down.sql index 6a1156840..d2d870567 100644 --- a/db/00000000000001_initial_schema.down.sql +++ b/db/00000000000001_initial_schema.down.sql @@ -25,6 +25,7 @@ DROP INDEX client_reports_task_and_timestamp_unaggregated_index CASCADE; DROP TABLE client_reports CASCADE; DROP TABLE task_hpke_keys CASCADE; DROP INDEX task_id_index CASCADE; +DROP TABLE task_upload_counters CASCADE; DROP TABLE tasks CASCADE; DROP TABLE taskprov_aggregator_auth_tokens; DROP TABLE taskprov_collector_auth_tokens; diff --git a/db/00000000000001_initial_schema.up.sql b/db/00000000000001_initial_schema.up.sql index 3b9f203c7..8ffa65dd7 100644 --- a/db/00000000000001_initial_schema.up.sql +++ b/db/00000000000001_initial_schema.up.sql @@ -133,6 +133,26 @@ CREATE TABLE tasks( ); 
CREATE INDEX task_id_index ON tasks(task_id); +-- Per task report upload counters. +CREATE TABLE task_upload_counters( + id BIGINT GENERATED ALWAYS AS IDENTITY PRIMARY KEY, -- artificial ID, internal-only + task_id BIGINT NOT NULL, + + interval_collected BIGINT NOT NULL DEFAULT 0, -- Reports submitted for an interval that was already collected. + report_decode_failure BIGINT NOT NULL DEFAULT 0, -- Reports which failed to decode. + report_decrypt_failure BIGINT NOT NULL DEFAULT 0, -- Reports which failed to decrypt. + report_expired BIGINT NOT NULL DEFAULT 0, -- Reports that were older than the task's report_expiry_age. + report_outdated_key BIGINT NOT NULL DEFAULT 0, -- Reports that were encrypted with an unknown or outdated HPKE key. + report_success BIGINT NOT NULL DEFAULT 0, -- Reports that were successfully uploaded. + report_too_early BIGINT NOT NULL DEFAULT 0, -- Reports whose timestamp is too far in the future. + task_expired BIGINT NOT NULL DEFAULT 0, -- Reports sent to the task while it is expired. + + ord BIGINT NOT NULL, -- Index of this task_upload_counters shard. + + CONSTRAINT fk_task_id FOREIGN KEY(task_id) REFERENCES tasks(id) ON DELETE CASCADE, + CONSTRAINT task_upload_counters_unique UNIQUE(task_id, ord) +); + -- The HPKE public keys (aka configs) and private keys used by a given task. CREATE TABLE task_hpke_keys( id BIGINT GENERATED ALWAYS AS IDENTITY PRIMARY KEY, -- artificial ID, internal-only diff --git a/integration_tests/src/janus.rs b/integration_tests/src/janus.rs index 331a11c33..0d943376f 100644 --- a/integration_tests/src/janus.rs +++ b/integration_tests/src/janus.rs @@ -159,6 +159,7 @@ impl JanusInProcess { max_upload_batch_size: 100, max_upload_batch_write_delay_ms: 100, batch_aggregation_shard_count: 32, + task_counter_shard_count: 64, global_hpke_configs_refresh_interval: None, }; let aggregation_job_creator_options = AggregationJobCreatorOptions {
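End to end, the sharding scheme configured by `task_counter_shard_count` works like this: each write bumps one randomly chosen `ord` row, so concurrent uploads for the same task rarely contend on a single row, and each read sums across all shards. A standalone sketch of the idea, with an in-memory map standing in for the `task_upload_counters` table (assumes the `rand` crate, which the patch already depends on):

```rust
use rand::{thread_rng, Rng};
use std::collections::HashMap;

fn main() {
    let shard_count = 64u64;
    // ord -> report_success, standing in for one task's task_upload_counters rows.
    let mut shards: HashMap<u64, u64> = HashMap::new();

    // Write side: each upload bumps a randomly chosen shard; in Postgres this
    // is the INSERT ... ON CONFLICT (task_id, ord) DO UPDATE upsert.
    for _ in 0..1_000 {
        let ord = thread_rng().gen_range(0..shard_count);
        *shards.entry(ord).or_insert(0) += 1;
    }

    // Read side: task metrics are the sum over all shards, mirroring the
    // COALESCE(SUM(report_success)::BIGINT, 0) query in get_task_upload_counter.
    let report_success: u64 = shards.values().sum();
    assert_eq!(report_success, 1_000);
}
```

Higher shard counts spread write contention further at the cost of scanning more rows per metrics read, which is exactly the tradeoff the `task_counter_shard_count` doc comment describes.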