Fixed size: make max_batch_size optional
`draft-ietf-ppm-dap-09` makes the `max_batch_size` parameter for
fixed-size query tasks optional. If it's absent, we aim to create
batches of exactly `min_batch_size` reports.

Part of #2389
tgeoghegan committed Jan 26, 2024
1 parent 5ccce50 commit 69d0273
Showing 15 changed files with 294 additions and 70 deletions.
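
The heart of the change is the task's query-type configuration: `max_batch_size` becomes `Option<u64>`. A minimal, self-contained sketch of the two resulting configurations (hypothetical values, and a stand-in enum for the real `TaskQueryType` used in the diff below):

```rust
// Stand-in for Janus's task query type, shown only to illustrate the shape
// of the change; the real definition lives in the Janus task modules.
enum TaskQueryType {
    FixedSize {
        max_batch_size: Option<u64>,
        batch_time_window_size: Option<u64>, // real type: janus_messages::Duration
    },
}

fn main() {
    // Explicit cap: batches may grow up to 400 reports (hypothetical value).
    let _bounded = TaskQueryType::FixedSize {
        max_batch_size: Some(400),
        batch_time_window_size: None,
    };
    // No cap: the aggregator targets batches of exactly `min_batch_size`
    // reports, per draft-ietf-ppm-dap-09 section 4.1.2.
    let _min_size_only = TaskQueryType::FixedSize {
        max_batch_size: None,
        batch_time_window_size: None,
    };
}
```
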
195 changes: 188 additions & 7 deletions aggregator/src/aggregator/aggregation_job_creator.rs
@@ -631,7 +631,7 @@ impl<C: Clock + 'static> AggregationJobCreator<C> {
self: Arc<Self>,
task: Arc<AggregatorTask>,
vdaf: Arc<A>,
- task_max_batch_size: u64,
+ task_max_batch_size: Option<u64>,
task_batch_time_window_size: Option<janus_messages::Duration>,
) -> anyhow::Result<bool>
where
@@ -646,7 +646,7 @@ impl<C: Clock + 'static> AggregationJobCreator<C> {
{
let (task_min_batch_size, task_max_batch_size) = (
usize::try_from(task.min_batch_size())?,
- usize::try_from(task_max_batch_size)?,
+ task_max_batch_size.map(usize::try_from).transpose()?,
);
Ok(self
.datastore
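
The `Option<u64>` to `Option<usize>` conversion above leans on a standard-library idiom: `map` the fallible `usize::try_from` over the option, then `transpose` the resulting `Option<Result<_, _>>` into a `Result<Option<_>, _>` so `?` can propagate any overflow error. A standalone sketch (the `to_usize` helper is hypothetical):

```rust
use std::num::TryFromIntError;

// Mirrors the conversion in the hunk above: None stays None, Some(v) is
// converted, and a conversion error surfaces through the Result.
fn to_usize(v: Option<u64>) -> Result<Option<usize>, TryFromIntError> {
    v.map(usize::try_from).transpose()
}

fn main() -> Result<(), TryFromIntError> {
    assert_eq!(to_usize(None)?, None);
    assert_eq!(to_usize(Some(5))?, Some(5));
    Ok(())
}
```
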
@@ -1430,7 +1430,7 @@ mod tests {
let task = Arc::new(
TaskBuilder::new(
TaskQueryType::FixedSize {
- max_batch_size: MAX_BATCH_SIZE as u64,
+ max_batch_size: Some(MAX_BATCH_SIZE as u64),
batch_time_window_size: None,
},
VdafInstance::Prio3Count,
@@ -1637,7 +1637,7 @@ mod tests {
let task = Arc::new(
TaskBuilder::new(
TaskQueryType::FixedSize {
- max_batch_size: MAX_BATCH_SIZE as u64,
+ max_batch_size: Some(MAX_BATCH_SIZE as u64),
batch_time_window_size: None,
},
VdafInstance::Prio3Count,
@@ -1794,7 +1794,7 @@ mod tests {
let task = Arc::new(
TaskBuilder::new(
TaskQueryType::FixedSize {
- max_batch_size: MAX_BATCH_SIZE as u64,
+ max_batch_size: Some(MAX_BATCH_SIZE as u64),
batch_time_window_size: None,
},
VdafInstance::Prio3Count,
@@ -2056,7 +2056,7 @@ mod tests {
let task = Arc::new(
TaskBuilder::new(
TaskQueryType::FixedSize {
- max_batch_size: MAX_BATCH_SIZE as u64,
+ max_batch_size: Some(MAX_BATCH_SIZE as u64),
batch_time_window_size: None,
},
VdafInstance::Prio3Count,
@@ -2325,7 +2325,7 @@ mod tests {
let task = Arc::new(
TaskBuilder::new(
TaskQueryType::FixedSize {
- max_batch_size: MAX_BATCH_SIZE as u64,
+ max_batch_size: Some(MAX_BATCH_SIZE as u64),
batch_time_window_size: Some(batch_time_window_size),
},
VdafInstance::Prio3Count,
@@ -2589,6 +2589,187 @@ mod tests {
);
}

#[tokio::test]
async fn create_aggregation_jobs_for_fixed_size_task_no_max_batch_size() {
// Setup.
install_test_trace_subscriber();
let clock: MockClock = MockClock::default();
let ephemeral_datastore = ephemeral_datastore().await;
let ds = ephemeral_datastore.datastore(clock.clone()).await;

const MIN_AGGREGATION_JOB_SIZE: usize = 50;
const MAX_AGGREGATION_JOB_SIZE: usize = 60;
const MIN_BATCH_SIZE: usize = 200;

let task = Arc::new(
TaskBuilder::new(
TaskQueryType::FixedSize {
max_batch_size: None,
batch_time_window_size: None,
},
VdafInstance::Prio3Count,
)
.with_min_batch_size(MIN_BATCH_SIZE as u64)
.build()
.leader_view()
.unwrap(),
);

// Create MIN_BATCH_SIZE + MIN_BATCH_SIZE + MIN_AGGREGATION_JOB_SIZE reports. We expect
// aggregation jobs to be created containing all these reports, but only two batches to
// be filled to MIN_BATCH_SIZE; the remaining reports land in a third, partial batch.
let report_time = clock.now();
let vdaf = Arc::new(Prio3::new_count(2).unwrap());
let helper_hpke_keypair = generate_test_hpke_config_and_private_key();
let reports: Arc<Vec<_>> = Arc::new(
iter::repeat_with(|| {
let report_metadata = ReportMetadata::new(random(), report_time);
let transcript = run_vdaf(
vdaf.as_ref(),
task.vdaf_verify_key().unwrap().as_bytes(),
&(),
report_metadata.id(),
&false,
);
LeaderStoredReport::generate(
*task.id(),
report_metadata,
helper_hpke_keypair.config(),
Vec::new(),
&transcript,
)
})
.take(MIN_BATCH_SIZE + MIN_BATCH_SIZE + MIN_AGGREGATION_JOB_SIZE)
.collect(),
);

let report_ids: HashSet<ReportId> = reports
.iter()
.map(|report| *report.metadata().id())
.collect();

ds.run_unnamed_tx(|tx| {
let task = Arc::clone(&task);
let vdaf = Arc::clone(&vdaf);
let reports = Arc::clone(&reports);

Box::pin(async move {
tx.put_aggregator_task(&task).await.unwrap();
for report in reports.iter() {
tx.put_client_report(vdaf.as_ref(), report).await.unwrap();
}
Ok(())
})
})
.await
.unwrap();

// Run.
let job_creator = Arc::new(AggregationJobCreator::new(
ds,
noop_meter(),
Duration::from_secs(3600),
Duration::from_secs(1),
MIN_AGGREGATION_JOB_SIZE,
MAX_AGGREGATION_JOB_SIZE,
));
Arc::clone(&job_creator)
.create_aggregation_jobs_for_task(Arc::clone(&task))
.await
.unwrap();

// Verify.
let want_ra_states: Arc<HashMap<_, _>> = Arc::new(
reports
.iter()
.map(|report| {
(
*report.metadata().id(),
report
.as_start_leader_report_aggregation(random(), 0)
.state()
.clone(),
)
})
.collect(),
);
let (outstanding_batches, (agg_jobs, _)) =
job_creator
.datastore
.run_unnamed_tx(|tx| {
let task = Arc::clone(&task);
let vdaf = Arc::clone(&vdaf);
let want_ra_states = Arc::clone(&want_ra_states);

Box::pin(async move {
Ok((
tx.get_outstanding_batches(task.id(), &None).await.unwrap(),
read_and_verify_aggregate_info_for_task::<
VERIFY_KEY_LENGTH,
FixedSize,
_,
_,
>(
tx, vdaf.as_ref(), task.id(), want_ra_states.as_ref()
)
.await,
))
})
})
.await
.unwrap();

// Verify outstanding batches.
let mut total_max_size = 0;
println!("{outstanding_batches:?}");

for outstanding_batch in &outstanding_batches {
assert_eq!(outstanding_batch.size().start(), &0);
assert!(
outstanding_batch.size().end() == &MIN_BATCH_SIZE
|| outstanding_batch.size().end() == &MIN_AGGREGATION_JOB_SIZE
);
total_max_size += *outstanding_batch.size().end();
}
assert_eq!(
total_max_size,
2 * MIN_BATCH_SIZE + MIN_AGGREGATION_JOB_SIZE
);
let batch_ids: HashSet<_> = outstanding_batches
.iter()
.map(|outstanding_batch| *outstanding_batch.id())
.collect();

// Verify aggregation jobs.
let mut seen_report_ids = HashSet::new();
let mut batches_with_small_agg_jobs = HashSet::new();
for (agg_job, report_ids) in agg_jobs {
// Aggregation jobs are created in step 0.
assert_eq!(agg_job.step(), AggregationJobStep::from(0));

// Every batch corresponds to one of the outstanding batches.
assert!(batch_ids.contains(agg_job.batch_id()));

// At most one aggregation job per batch will be smaller than the normal minimum
// aggregation job size.
if report_ids.len() < MIN_AGGREGATION_JOB_SIZE {
assert!(!batches_with_small_agg_jobs.contains(agg_job.batch_id()));
batches_with_small_agg_jobs.insert(*agg_job.batch_id());
}

// The aggregation job is at most MAX_AGGREGATION_JOB_SIZE in size.
assert!(report_ids.len() <= MAX_AGGREGATION_JOB_SIZE);

// Report IDs are not repeated across or inside aggregation jobs.
for report_id in report_ids {
assert!(!seen_report_ids.contains(&report_id));
seen_report_ids.insert(report_id);
}
}

// Every client report was added to some aggregation job.
assert_eq!(report_ids, seen_report_ids);
}
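
As a worked check of the assertions above: the test inserts 450 reports (200 + 200 + 50), and with `max_batch_size: None` the effective cap is `MIN_BATCH_SIZE` = 200, so the expected outcome is two batches capped at 200 plus one partial batch of 50. A sketch with a hypothetical helper:

```rust
// Hypothetical helper: distributes `total` reports into batches capped at
// `effective_max`, matching the batch sizes the test expects to observe.
fn expected_batch_caps(total: usize, effective_max: usize) -> Vec<usize> {
    let mut caps = vec![effective_max; total / effective_max];
    if total % effective_max != 0 {
        caps.push(total % effective_max);
    }
    caps
}

fn main() {
    // 200 + 200 + 50 reports, effective cap of MIN_BATCH_SIZE = 200.
    assert_eq!(expected_batch_caps(450, 200), vec![200, 200, 50]);
    // Their sum matches the test's `total_max_size` assertion.
    assert_eq!(expected_batch_caps(450, 200).iter().sum::<usize>(), 450);
}
```
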

/// Test helper function that reads all aggregation jobs & batches for a given task ID,
/// returning the aggregation jobs, the report IDs included in the aggregation job, and the
/// batches. Report IDs are returned in the order they are included in the aggregation job, and
6 changes: 3 additions & 3 deletions aggregator/src/aggregator/aggregation_job_driver.rs
@@ -2246,7 +2246,7 @@ mod tests {

let task = TaskBuilder::new(
QueryType::FixedSize {
- max_batch_size: 10,
+ max_batch_size: Some(10),
batch_time_window_size: None,
},
VdafInstance::Prio3Count,
@@ -2501,7 +2501,7 @@ mod tests {

let task = TaskBuilder::new(
QueryType::FixedSize {
- max_batch_size: 10,
+ max_batch_size: Some(10),
batch_time_window_size: None,
},
VdafInstance::Poplar1 { bits: 1 },
@@ -3171,7 +3171,7 @@ mod tests {

let task = TaskBuilder::new(
QueryType::FixedSize {
- max_batch_size: 10,
+ max_batch_size: Some(10),
batch_time_window_size: None,
},
VdafInstance::Poplar1 { bits: 1 },
23 changes: 16 additions & 7 deletions aggregator/src/aggregator/batch_creator.rs
@@ -45,7 +45,7 @@ struct Properties {
max_aggregation_job_size: usize,
task_id: TaskId,
task_min_batch_size: usize,
- task_max_batch_size: usize,
+ effective_task_max_batch_size: usize,
task_batch_time_window_size: Option<Duration>,
}

@@ -59,7 +59,7 @@ where
max_aggregation_job_size: usize,
task_id: TaskId,
task_min_batch_size: usize,
- task_max_batch_size: usize,
+ task_max_batch_size: Option<usize>,
task_batch_time_window_size: Option<Duration>,
aggregation_job_writer: &'a mut AggregationJobWriter<SEED_SIZE, FixedSize, A>,
) -> Self {
@@ -69,7 +69,13 @@
max_aggregation_job_size,
task_id,
task_min_batch_size,
- task_max_batch_size,
+ // If the task has no explicit max_batch_size set, then our goal is to create
+ // batches of exactly min_batch_size reports, so we use that value as the effective
+ // maximum batch size, but we may create batches which exceed this size. See
+ // process_batches, below.
+ //
+ // https://datatracker.ietf.org/doc/html/draft-ietf-ppm-dap-09#section-4.1.2-6
+ effective_task_max_batch_size: task_max_batch_size.unwrap_or(task_min_batch_size),
task_batch_time_window_size,
},
aggregation_job_writer,
@@ -153,7 +159,8 @@
return Ok(());
}
// Discard any outstanding batches that do not currently have room for more reports.
- if largest_outstanding_batch.max_size() >= properties.task_max_batch_size {
+ if largest_outstanding_batch.max_size() >= properties.effective_task_max_batch_size
+ {
PeekMut::pop(largest_outstanding_batch);
continue;
}
@@ -164,7 +171,8 @@
bucket.unaggregated_reports.len(),
properties.max_aggregation_job_size,
),
- properties.task_max_batch_size - largest_outstanding_batch.max_size(),
+ properties.effective_task_max_batch_size
+ - largest_outstanding_batch.max_size(),
);
if (desired_aggregation_job_size >= properties.min_aggregation_job_size)
|| (largest_outstanding_batch.max_size() < properties.task_min_batch_size
@@ -205,7 +213,8 @@
// any more.
let desired_aggregation_job_size = min(
properties.max_aggregation_job_size,
- properties.task_max_batch_size - largest_outstanding_batch.max_size(),
+ properties.effective_task_max_batch_size
+ - largest_outstanding_batch.max_size(),
);
if bucket.unaggregated_reports.len() >= desired_aggregation_job_size {
Self::create_aggregation_job(
@@ -236,7 +245,7 @@
bucket.unaggregated_reports.len(),
properties.max_aggregation_job_size,
),
- properties.task_max_batch_size,
+ properties.effective_task_max_batch_size,
);
if desired_aggregation_job_size >= new_batch_threshold {
let batch_id = random();
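
Taken together, the batch_creator.rs hunks implement one rule, per draft-ietf-ppm-dap-09 section 4.1.2: the effective cap on batch growth is the explicit `max_batch_size` when present, otherwise `min_batch_size`. A sketch in hypothetical free-function form (the real code stores the value in the `Properties` struct at construction time):

```rust
// Sketch of the effective-cap rule from the constructor hunk above.
fn effective_task_max_batch_size(
    task_min_batch_size: usize,
    task_max_batch_size: Option<usize>,
) -> usize {
    task_max_batch_size.unwrap_or(task_min_batch_size)
}

fn main() {
    assert_eq!(effective_task_max_batch_size(200, Some(400)), 400);
    assert_eq!(effective_task_max_batch_size(200, None), 200);
}
```
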
8 changes: 4 additions & 4 deletions aggregator/src/aggregator/collection_job_tests.rs
@@ -168,7 +168,7 @@ async fn setup_fixed_size_current_batch_collection_job_test_case(
let test_case = setup_collection_job_test_case(
Role::Leader,
QueryType::FixedSize {
- max_batch_size: 10,
+ max_batch_size: Some(10),
batch_time_window_size: None,
},
)
@@ -792,7 +792,7 @@ async fn collection_job_put_idempotence_fixed_size_by_batch_id() {
let test_case = setup_collection_job_test_case(
Role::Leader,
QueryType::FixedSize {
- max_batch_size: 10,
+ max_batch_size: Some(10),
batch_time_window_size: None,
},
)
@@ -845,7 +845,7 @@ async fn collection_job_put_idempotence_fixed_size_by_batch_id_mutate_batch_id()
let test_case = setup_collection_job_test_case(
Role::Leader,
QueryType::FixedSize {
- max_batch_size: 10,
+ max_batch_size: Some(10),
batch_time_window_size: None,
},
)
@@ -914,7 +914,7 @@ async fn collection_job_put_idempotence_fixed_size_by_batch_id_mutate_aggregatio
let test_case = setup_collection_job_test_case(
Role::Leader,
QueryType::FixedSize {
- max_batch_size: 10,
+ max_batch_size: Some(10),
batch_time_window_size: None,
},
)