diff --git a/aggregator/src/aggregator/aggregation_job_creator.rs b/aggregator/src/aggregator/aggregation_job_creator.rs
index e1c9c57b8..4f2e912c6 100644
--- a/aggregator/src/aggregator/aggregation_job_creator.rs
+++ b/aggregator/src/aggregator/aggregation_job_creator.rs
@@ -4,6 +4,7 @@ use fixed::{
     types::extra::{U15, U31},
     FixedI16, FixedI32,
 };
+use futures::future::try_join_all;
 use janus_aggregator_core::{
     datastore::models::{AggregationJob, AggregationJobState},
     datastore::{self, Datastore},
@@ -33,8 +34,15 @@ use prio::{
     },
 };
 use rand::{random, thread_rng, Rng};
-use std::{collections::HashMap, sync::Arc, time::Duration};
-use tokio::time::{self, sleep_until, Instant, MissedTickBehavior};
+use std::{
+    collections::{HashMap, HashSet},
+    sync::Arc,
+    time::Duration,
+};
+use tokio::{
+    time::{self, sleep_until, Instant, MissedTickBehavior},
+    try_join,
+};
 use tracing::{debug, error, info};
 use trillium_tokio::{CloneCounterObserver, Stopper};
 
@@ -536,6 +544,7 @@ impl AggregationJobCreator {
 
                 // Generate aggregation jobs & report aggregations based on the reports we read.
                 let mut aggregation_job_writer = AggregationJobWriter::new(Arc::clone(&task));
+                let mut report_ids_to_scrub = HashSet::new();
                 for agg_job_reports in reports.chunks(this.max_aggregation_job_size) {
                     if agg_job_reports.len() < this.min_aggregation_job_size {
                         if !agg_job_reports.is_empty() {
@@ -593,12 +602,22 @@
                             ))
                         })
                         .collect::<Result<_, _>>()?;
+                    report_ids_to_scrub
+                        .extend(agg_job_reports.iter().map(|report| *report.metadata().id()));
                     aggregation_job_writer.put(aggregation_job, report_aggregations)?;
                 }
 
                 // Write the aggregation jobs & report aggregations we created.
-                aggregation_job_writer.write(tx, vdaf).await?;
+                try_join!(
+                    aggregation_job_writer.write(tx, vdaf),
+                    try_join_all(
+                        report_ids_to_scrub
+                            .iter()
+                            .map(|report_id| tx.scrub_client_report(task.id(), report_id))
+                    )
+                )?;
+
                 Ok(!aggregation_job_writer.is_empty())
             })
         })
@@ -2593,7 +2612,7 @@ mod tests {
         for<'a> A::PrepareState: ParameterizedDecode<(&'a A, usize)>,
         A::PublicShare: PartialEq,
     {
-        try_join!(
+        let (agg_jobs_and_report_ids, batches) = try_join!(
             try_join_all(
                 tx.get_aggregation_jobs_for_task(task_id)
                     .await
@@ -2610,7 +2629,7 @@
                         .await
                         .map(|report_aggs| {
                             // Verify that each report aggregation has the expected state.
-                            let report_ids = report_aggs
+                            let report_ids: Vec<_> = report_aggs
                                 .into_iter()
                                 .map(|ra| {
                                     let want_ra_state =
@@ -2632,6 +2651,29 @@
             ),
             tx.get_batches_for_task(task_id),
         )
-        .unwrap()
+        .unwrap();
+
+        // Verify that all reports we saw a report aggregation for are scrubbed.
+        let all_seen_report_ids: HashSet<_> = agg_jobs_and_report_ids
+            .iter()
+            .map(|(_, report_ids)| report_ids.iter())
+            .flatten()
+            .collect();
+        for report_id in &all_seen_report_ids {
+            tx.verify_client_report_scrubbed(task_id, report_id).await;
+        }
+
+        // Verify that all reports we did not see a report aggregation for are not scrubbed. (We do
+        // so by reading the report, since reading a report will fail if the report is scrubbed.)
+        for report_id in want_ra_states.keys() {
+            if all_seen_report_ids.contains(report_id) {
+                continue;
+            }
+            tx.get_client_report(vdaf, task_id, report_id)
+                .await
+                .unwrap();
+        }
+
+        (agg_jobs_and_report_ids, batches)
     }
 }
diff --git a/aggregator/src/aggregator/batch_creator.rs b/aggregator/src/aggregator/batch_creator.rs
index b5952079d..f23178a72 100644
--- a/aggregator/src/aggregator/batch_creator.rs
+++ b/aggregator/src/aggregator/batch_creator.rs
@@ -7,13 +7,13 @@ use janus_aggregator_core::datastore::{
 };
 use janus_core::time::{Clock, DurationExt, TimeExt};
 use janus_messages::{
-    query_type::FixedSize, AggregationJobStep, BatchId, Duration, Interval, TaskId, Time,
+    query_type::FixedSize, AggregationJobStep, BatchId, Duration, Interval, ReportId, TaskId, Time,
 };
 use prio::{codec::Encode, vdaf::Aggregator};
 use rand::random;
 use std::{
     cmp::{max, min, Ordering},
-    collections::{binary_heap::PeekMut, hash_map, BinaryHeap, HashMap, VecDeque},
+    collections::{binary_heap::PeekMut, hash_map, BinaryHeap, HashMap, HashSet, VecDeque},
     ops::RangeInclusive,
     sync::Arc,
 };
@@ -32,8 +32,9 @@ where
 {
     properties: Properties,
     aggregation_job_writer: &'a mut AggregationJobWriter<SEED_SIZE, FixedSize, A>,
-    map: HashMap<Option<Time>, Bucket<SEED_SIZE, A>>,
+    buckets: HashMap<Option<Time>, Bucket<SEED_SIZE, A>>,
     new_batches: Vec<(BatchId, Option
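
Note: below is a minimal, self-contained sketch of the write-and-scrub concurrency pattern the first hunk adopts, combining the real `tokio::try_join!` and `futures::future::try_join_all` APIs. The `Tx` type, its methods, the `Error` type, and the numeric report IDs are hypothetical stand-ins for the Janus datastore transaction and message types, not actual Janus APIs.

```rust
use futures::future::try_join_all;
use std::collections::HashSet;
use tokio::try_join;

#[derive(Debug)]
struct Error;

// Hypothetical stand-in for the datastore transaction; not a Janus API.
struct Tx;

impl Tx {
    // Stand-in for `aggregation_job_writer.write(tx, vdaf)`.
    async fn write_aggregation_jobs(&self) -> Result<(), Error> {
        Ok(())
    }

    // Stand-in for `tx.scrub_client_report(task.id(), report_id)`.
    async fn scrub_client_report(&self, report_id: &u64) -> Result<(), Error> {
        println!("scrubbing report {report_id}");
        Ok(())
    }
}

#[tokio::main]
async fn main() -> Result<(), Error> {
    let tx = Tx;
    let report_ids_to_scrub: HashSet<u64> = HashSet::from([1, 2, 3]);

    // Run the batched aggregation-job write and the per-report scrubs
    // concurrently. try_join_all collapses the scrub futures into one
    // future; try_join! fails fast with the first error from either side,
    // so the surrounding transaction can be rolled back.
    try_join!(
        tx.write_aggregation_jobs(),
        try_join_all(
            report_ids_to_scrub
                .iter()
                .map(|report_id| tx.scrub_client_report(report_id)),
        ),
    )?;
    Ok(())
}
```

Running the scrubs concurrently with the aggregation-job write keeps the transaction short, and the fail-fast semantics of `try_join!` mean a failure in any scrub (or in the write) surfaces immediately rather than after all other futures complete.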