Scrub reports during aggregation job creation.
branlwyd committed Jan 25, 2024
1 parent 1fac6e7 commit 7e5dbf7
Showing 4 changed files with 100 additions and 37 deletions.
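In outline: once a client report has been folded into a new aggregation job, its payload columns are no longer needed, so the transaction that writes the job now also scrubs the reports it consumed. Below is a minimal sketch of that shape, assuming the tokio and futures crates; the Report and Tx types and write_aggregation_job are stand-ins invented for illustration — only scrub_client_report and the try_join!/try_join_all composition mirror the real change.

use futures::future::try_join_all;
use std::collections::HashSet;
use tokio::try_join;

// Stand-ins for Janus's real types; only the overall shape matches the commit.
#[derive(Clone)]
struct Report {
    id: u64,
}

struct Tx;

impl Tx {
    async fn write_aggregation_job(&self, reports: &[Report]) -> Result<(), String> {
        println!("wrote aggregation job over {} reports", reports.len());
        Ok(())
    }

    async fn scrub_client_report(&self, id: u64) -> Result<(), String> {
        println!("scrubbed report {id}");
        Ok(())
    }
}

#[tokio::main]
async fn main() -> Result<(), String> {
    let tx = Tx;
    let reports: Vec<Report> = (0..5).map(|id| Report { id }).collect();

    // Chunk reports into jobs, recording the ID of every report consumed.
    let mut jobs = Vec::new();
    let mut report_ids_to_scrub = HashSet::new();
    for chunk in reports.chunks(2) {
        jobs.push(chunk.to_vec());
        report_ids_to_scrub.extend(chunk.iter().map(|report| report.id));
    }

    // Write the jobs and scrub the consumed reports concurrently; sharing a
    // transaction means they succeed or fail as a unit.
    try_join!(
        try_join_all(jobs.iter().map(|job| tx.write_aggregation_job(job))),
        try_join_all(
            report_ids_to_scrub
                .iter()
                .map(|id| tx.scrub_client_report(*id))
        ),
    )?;
    Ok(())
}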
54 changes: 48 additions & 6 deletions aggregator/src/aggregator/aggregation_job_creator.rs
@@ -4,6 +4,7 @@ use fixed::{
types::extra::{U15, U31},
FixedI16, FixedI32,
};
use futures::future::try_join_all;
use janus_aggregator_core::{
datastore::models::{AggregationJob, AggregationJobState},
datastore::{self, Datastore},
@@ -33,8 +34,15 @@ use prio::{
},
};
use rand::{random, thread_rng, Rng};
use std::{collections::HashMap, sync::Arc, time::Duration};
use tokio::time::{self, sleep_until, Instant, MissedTickBehavior};
use std::{
collections::{HashMap, HashSet},
sync::Arc,
time::Duration,
};
use tokio::{
time::{self, sleep_until, Instant, MissedTickBehavior},
try_join,
};
use tracing::{debug, error, info};
use trillium_tokio::{CloneCounterObserver, Stopper};

@@ -536,6 +544,7 @@ impl<C: Clock + 'static> AggregationJobCreator<C> {

// Generate aggregation jobs & report aggregations based on the reports we read.
let mut aggregation_job_writer = AggregationJobWriter::new(Arc::clone(&task));
let mut report_ids_to_scrub = HashSet::new();
for agg_job_reports in reports.chunks(this.max_aggregation_job_size) {
if agg_job_reports.len() < this.min_aggregation_job_size {
if !agg_job_reports.is_empty() {
@@ -593,12 +602,22 @@
))
})
.collect::<Result<_, datastore::Error>>()?;
report_ids_to_scrub
.extend(agg_job_reports.iter().map(|report| *report.metadata().id()));

aggregation_job_writer.put(aggregation_job, report_aggregations)?;
}

// Write the aggregation jobs & report aggregations we created, and scrub the reports they consumed.
aggregation_job_writer.write(tx, vdaf).await?;
try_join!(
aggregation_job_writer.write(tx, vdaf),
try_join_all(
report_ids_to_scrub
.iter()
.map(|report_id| tx.scrub_client_report(task.id(), report_id))
)
)?;

Ok(!aggregation_job_writer.is_empty())
})
})
@@ -2593,7 +2612,7 @@ mod tests {
for<'a> A::PrepareState: ParameterizedDecode<(&'a A, usize)>,
A::PublicShare: PartialEq,
{
try_join!(
let (agg_jobs_and_report_ids, batches) = try_join!(
try_join_all(
tx.get_aggregation_jobs_for_task(task_id)
.await
@@ -2610,7 +2629,7 @@
.await
.map(|report_aggs| {
// Verify that each report aggregation has the expected state.
let report_ids = report_aggs
let report_ids: Vec<_> = report_aggs
.into_iter()
.map(|ra| {
let want_ra_state =
@@ -2632,6 +2651,29 @@
),
tx.get_batches_for_task(task_id),
)
.unwrap()
.unwrap();

// Verify that all reports we saw a report aggregation for are scrubbed.
let all_seen_report_ids: HashSet<_> = agg_jobs_and_report_ids
    .iter()
    .flat_map(|(_, report_ids)| report_ids.iter())
    .collect();
for report_id in &all_seen_report_ids {
tx.verify_client_report_scrubbed(task_id, report_id).await;
}

// Verify that all reports we did not see a report aggregation for are not scrubbed. (We do
// so by reading the report, since reading a report will fail if the report is scrubbed.)
for report_id in want_ra_states.keys() {
if all_seen_report_ids.contains(report_id) {
continue;
}
tx.get_client_report(vdaf, task_id, report_id)
.await
.unwrap();
}

(agg_jobs_and_report_ids, batches)
}
}
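The write path above composes one heterogeneous pair of futures with tokio's try_join! and a homogeneous collection of scrub calls with futures' try_join_all. A small self-contained illustration of how the two combinators behave (assuming the tokio and futures crates; step is a made-up function):

// Fails on one specific input so the short-circuiting is visible.
async fn step(n: u32) -> Result<u32, String> {
    if n == 3 {
        Err(format!("step {n} failed"))
    } else {
        Ok(n * 2)
    }
}

#[tokio::main]
async fn main() {
    use futures::future::try_join_all;
    use tokio::try_join;

    // try_join! awaits a fixed set of (possibly differently typed) futures;
    // try_join_all awaits a collection of same-typed futures, yielding a Vec.
    // Both run their futures concurrently and surface the first error, which
    // is what lets the job writes and the report scrubs fail or commit together.
    assert_eq!(
        try_join!(step(0), try_join_all((4..8).map(step))).unwrap(),
        (0, vec![8, 10, 12, 14])
    );
    assert_eq!(
        try_join!(step(0), try_join_all((1..5).map(step))).unwrap_err(),
        "step 3 failed"
    );
}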
31 changes: 23 additions & 8 deletions aggregator/src/aggregator/batch_creator.rs
@@ -7,13 +7,13 @@ use janus_aggregator_core::datastore::{
};
use janus_core::time::{Clock, DurationExt, TimeExt};
use janus_messages::{
query_type::FixedSize, AggregationJobStep, BatchId, Duration, Interval, TaskId, Time,
query_type::FixedSize, AggregationJobStep, BatchId, Duration, Interval, ReportId, TaskId, Time,
};
use prio::{codec::Encode, vdaf::Aggregator};
use rand::random;
use std::{
cmp::{max, min, Ordering},
collections::{binary_heap::PeekMut, hash_map, BinaryHeap, HashMap, VecDeque},
collections::{binary_heap::PeekMut, hash_map, BinaryHeap, HashMap, HashSet, VecDeque},
ops::RangeInclusive,
sync::Arc,
};
@@ -32,8 +32,9 @@
{
properties: Properties,
aggregation_job_writer: &'a mut AggregationJobWriter<SEED_SIZE, FixedSize, A>,
map: HashMap<Option<Time>, Bucket<SEED_SIZE, A>>,
buckets: HashMap<Option<Time>, Bucket<SEED_SIZE, A>>,
new_batches: Vec<(BatchId, Option<Time>)>,
report_ids_to_scrub: HashSet<ReportId>,
}

/// Common properties used by [`BatchCreator`]. This is broken out into a separate structure to make
@@ -72,8 +73,9 @@
task_batch_time_window_size,
},
aggregation_job_writer,
map: HashMap::new(),
buckets: HashMap::new(),
new_batches: Vec::new(),
report_ids_to_scrub: HashSet::new(),
}
}

@@ -95,15 +97,15 @@
.to_batch_interval_start(&batch_time_window_size)
})
.transpose()?;
let mut map_entry = self.map.entry(time_bucket_start_opt);
let mut map_entry = self.buckets.entry(time_bucket_start_opt);
let bucket = match &mut map_entry {
hash_map::Entry::Occupied(occupied) => occupied.get_mut(),
hash_map::Entry::Vacant(_) => {
// Lazily find existing unfilled batches.
let outstanding_batches = tx
.get_outstanding_batches(&self.properties.task_id, &time_bucket_start_opt)
.await?;
self.map
self.buckets
.entry(time_bucket_start_opt)
.or_insert_with(|| Bucket::new(outstanding_batches))
}
@@ -115,6 +117,7 @@
Self::process_batches(
&self.properties,
self.aggregation_job_writer,
&mut self.report_ids_to_scrub,
&mut self.new_batches,
&time_bucket_start_opt,
bucket,
@@ -136,6 +139,7 @@
fn process_batches(
properties: &Properties,
aggregation_job_writer: &mut AggregationJobWriter<SEED_SIZE, FixedSize, A>,
report_ids_to_scrub: &mut HashSet<ReportId>,
new_batches: &mut Vec<(BatchId, Option<Time>)>,
time_bucket_start: &Option<Time>,
bucket: &mut Bucket<SEED_SIZE, A>,
@@ -182,6 +186,7 @@
desired_aggregation_job_size,
&mut bucket.unaggregated_reports,
aggregation_job_writer,
report_ids_to_scrub,
)?;
largest_outstanding_batch.add_reports(desired_aggregation_job_size);
} else {
@@ -209,6 +214,7 @@
desired_aggregation_job_size,
&mut bucket.unaggregated_reports,
aggregation_job_writer,
report_ids_to_scrub,
)?;
largest_outstanding_batch.add_reports(desired_aggregation_job_size);
} else {
@@ -249,6 +255,7 @@
desired_aggregation_job_size,
&mut bucket.unaggregated_reports,
aggregation_job_writer,
report_ids_to_scrub,
)?;

// Loop to the top of this method to create more aggregation jobs in this newly
Expand All @@ -268,6 +275,7 @@ where
aggregation_job_size: usize,
unaggregated_reports: &mut VecDeque<LeaderStoredReport<SEED_SIZE, A>>,
aggregation_job_writer: &mut AggregationJobWriter<SEED_SIZE, FixedSize, A>,
report_ids_to_scrub: &mut HashSet<ReportId>,
) -> Result<(), Error> {
let aggregation_job_id = random();
debug!(
@@ -280,7 +288,7 @@
let mut min_client_timestamp = None;
let mut max_client_timestamp = None;

let report_aggregations = (0u64..)
let report_aggregations: Vec<_> = (0u64..)
.zip(unaggregated_reports.drain(..aggregation_job_size))
.map(|(ord, report)| {
let client_timestamp = *report.metadata().time();
@@ -294,6 +302,7 @@
report.as_start_leader_report_aggregation(aggregation_job_id, ord)
})
.collect();
report_ids_to_scrub.extend(report_aggregations.iter().map(|ra| *ra.report_id()));

let min_client_timestamp = min_client_timestamp.unwrap(); // unwrap safety: aggregation_job_size > 0
let max_client_timestamp = max_client_timestamp.unwrap(); // unwrap safety: aggregation_job_size > 0
@@ -329,10 +338,11 @@
// be smaller than max_aggregation_job_size. We will only create jobs smaller than
// min_aggregation_job_size if the remaining headroom in a batch requires it, otherwise
// remaining reports will be added to unaggregated_report_ids, to be marked as unaggregated.
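// For example, with min_aggregation_job_size = 50 and max_aggregation_job_size = 100,
// 230 pending reports would yield two aggregation jobs of 100 reports each; the
// remaining 30 only form a job if some batch's headroom requires them, and are
// otherwise left unaggregated for a later creation pass.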
for (time_bucket_start, mut bucket) in self.map.into_iter() {
for (time_bucket_start, mut bucket) in self.buckets.into_iter() {
Self::process_batches(
&self.properties,
self.aggregation_job_writer,
&mut self.report_ids_to_scrub,
&mut self.new_batches,
&time_bucket_start,
&mut bucket,
Expand All @@ -348,6 +358,11 @@ where

try_join!(
self.aggregation_job_writer.write(tx, vdaf),
try_join_all(
self.report_ids_to_scrub
.iter()
.map(|report_id| tx.scrub_client_report(&self.properties.task_id, report_id))
),
try_join_all(
self.new_batches
.iter()
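A detail of the plumbing above: process_batches and create_aggregation_job receive report_ids_to_scrub as an explicit &mut parameter rather than being &mut self methods, so callers can hand out mutable borrows of several BatchCreator fields (the writer, the scrub set, new_batches) at once. A minimal sketch of that borrow-splitting pattern, with hypothetical stand-in types:

use std::collections::HashSet;

struct Creator {
    writer: Vec<String>,        // stand-in for the aggregation job writer
    ids_to_scrub: HashSet<u64>, // stand-in for HashSet<ReportId>
}

impl Creator {
    fn run(&mut self) {
        // Passing two fields as separate arguments keeps the mutable borrows
        // disjoint; a `&mut self` helper called here would instead conflict
        // with any other outstanding borrow of a field of `self`.
        Self::process(&mut self.writer, &mut self.ids_to_scrub, 42);
    }

    fn process(writer: &mut Vec<String>, ids_to_scrub: &mut HashSet<u64>, id: u64) {
        writer.push(format!("report aggregation for report {id}"));
        ids_to_scrub.insert(id);
    }
}

fn main() {
    let mut creator = Creator {
        writer: Vec::new(),
        ids_to_scrub: HashSet::new(),
    };
    creator.run();
    assert!(creator.ids_to_scrub.contains(&42));
}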
28 changes: 28 additions & 0 deletions aggregator_core/src/datastore.rs
@@ -1562,6 +1562,34 @@ impl<C: Clock> Transaction<'_, C> {
)
}

#[cfg(feature = "test-util")]
#[cfg_attr(docsrs, doc(cfg(feature = "test-util")))]
pub async fn verify_client_report_scrubbed(&self, task_id: &TaskId, report_id: &ReportId) {
let row = self
.query_one(
"SELECT
client_reports.extensions,
client_reports.public_share,
client_reports.leader_input_share,
client_reports.helper_encrypted_input_share
FROM client_reports
JOIN tasks ON tasks.id = client_reports.task_id
WHERE tasks.task_id = $1
AND client_reports.report_id = $2",
&[task_id.as_ref(), report_id.as_ref()],
)
.await
.unwrap();

assert_eq!(row.get::<_, Option<Vec<u8>>>("extensions"), None);
assert_eq!(row.get::<_, Option<Vec<u8>>>("public_share"), None);
assert_eq!(row.get::<_, Option<Vec<u8>>>("leader_input_share"), None);
assert_eq!(
row.get::<_, Option<Vec<u8>>>("helper_encrypted_input_share"),
None
);
}

/// put_report_share stores a report share, given its associated task ID.
///
/// This method is intended for use by aggregators acting in the Helper role; notably, it does
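verify_client_report_scrubbed is compiled only when the crate's test-util feature is enabled, and the doc(cfg(...)) attribute makes docs.rs label it as feature-gated. A minimal sketch of that gating pattern (assert_scrubbed is a hypothetical helper, not part of Janus):

// In the library's Cargo.toml:
//   [features]
//   test-util = []
//
// The helper is then compiled only when a dependent crate enables the feature
// (typically from its dev-dependencies), keeping test-only code and its
// panicking assertions out of production builds.
#[cfg(feature = "test-util")]
#[cfg_attr(docsrs, doc(cfg(feature = "test-util")))]
pub fn assert_scrubbed(column: Option<Vec<u8>>) {
    assert!(column.is_none(), "payload column should have been scrubbed to NULL");
}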
24 changes: 1 addition & 23 deletions aggregator_core/src/datastore/tests.rs
@@ -660,29 +660,7 @@ async fn roundtrip_report(ephemeral_datastore: EphemeralDatastore) {
Box::pin(async move {
tx.scrub_client_report(&task_id, &report_id).await.unwrap();

let row = tx
.query_one(
"SELECT
client_reports.extensions,
client_reports.public_share,
client_reports.leader_input_share,
client_reports.helper_encrypted_input_share
FROM client_reports
JOIN tasks ON tasks.id = client_reports.task_id
WHERE tasks.task_id = $1
AND client_reports.report_id = $2",
&[&task_id.as_ref(), &report_id.as_ref()],
)
.await
.unwrap();

assert_eq!(row.get::<_, Option<Vec<u8>>>("extensions"), None);
assert_eq!(row.get::<_, Option<Vec<u8>>>("public_share"), None);
assert_eq!(row.get::<_, Option<Vec<u8>>>("leader_input_share"), None);
assert_eq!(
row.get::<_, Option<Vec<u8>>>("helper_encrypted_input_share"),
None
);
tx.verify_client_report_scrubbed(&task_id, &report_id).await;

assert_matches!(
tx.get_client_report::<0, dummy_vdaf::Vdaf>(
