Skip to content

Commit

Permalink
Task rewrite: use AggregatorTask in query_type
Browse files Browse the repository at this point in the history
Adopts `janus_aggregator_core::task::AggregatorTask` in the
`janus_aggregator_core::query_type` and `janus_aggregator::query_type`
modules, making use of the routines for converting between
`janus_aggregator_core::task::Task` and
`janus_aggregator_core::task::AggregatorTask` to make this minimally
intrusive.

Part of #1524
  • Loading branch information
tgeoghegan committed Sep 29, 2023
1 parent 0bf9101 commit 6e60456
Show file tree
Hide file tree
Showing 10 changed files with 96 additions and 70 deletions.
32 changes: 18 additions & 14 deletions aggregator/src/aggregator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ use janus_aggregator_core::{
Datastore, Error as DatastoreError, Transaction,
},
query_type::AccumulableQueryType,
task::{self, Task, VerifyKey},
task::{self, AggregatorTask, Task, VerifyKey},
taskprov::{self, PeerAggregator},
};
#[cfg(feature = "test-util")]
Expand Down Expand Up @@ -1612,7 +1612,7 @@ impl VdafOps {
for (ord, prepare_init) in req.prepare_inits().iter().enumerate() {
// Compute intervals for each batch identifier included in this aggregation job.
let batch_identifier = Q::to_batch_identifier(
&task,
&task.view_for_role()?,
req.batch_selector().batch_identifier(),
prepare_init.report_share().metadata().time(),
)?;
Expand Down Expand Up @@ -2270,8 +2270,9 @@ impl VdafOps {
}
}

let aggregator_task = task.view_for_role()?;
let collection_identifier =
Q::collection_identifier_for_query(tx, &task, req.query())
Q::collection_identifier_for_query(tx, &aggregator_task, req.query())
.await?
.ok_or_else(|| {
datastore::Error::User(
Expand All @@ -2285,7 +2286,8 @@ impl VdafOps {

// Check that the batch interval is valid for the task
// https://www.ietf.org/archive/id/draft-ietf-ppm-dap-02.html#section-4.5.6.1.1
if !Q::validate_collection_identifier(&task, &collection_identifier) {
if !Q::validate_collection_identifier(&aggregator_task, &collection_identifier)
{
return Err(datastore::Error::User(
Error::BatchInvalid(*task.id(), format!("{collection_identifier}"))
.into(),
Expand All @@ -2297,14 +2299,14 @@ impl VdafOps {
Q::validate_query_count::<SEED_SIZE, C, A>(
tx,
&vdaf,
&task,
&aggregator_task,
&collection_identifier,
&aggregation_param,
),
Q::count_client_reports(tx, &task, &collection_identifier),
Q::count_client_reports(tx, &aggregator_task, &collection_identifier),
try_join_all(
Q::batch_identifiers_for_collection_identifier(
&task,
&aggregator_task,
&collection_identifier
)
.map(|batch_identifier| {
Expand All @@ -2324,7 +2326,7 @@ impl VdafOps {
),
try_join_all(
Q::batch_identifiers_for_collection_identifier(
&task,
&aggregator_task,
&collection_identifier
)
.map(|batch_identifier| {
Expand Down Expand Up @@ -2544,10 +2546,11 @@ impl VdafOps {
)
})?;

let aggregator_task = task.view_for_role()?;
let (batches, _) = try_join!(
Q::get_batches_for_collection_identifier(
tx,
&task,
&aggregator_task,
collection_job.batch_identifier(),
collection_job.aggregation_parameter()
),
Expand Down Expand Up @@ -2802,7 +2805,7 @@ impl VdafOps {

// §4.4.4.3: check that the batch interval meets the requirements from §4.6
if !Q::validate_collection_identifier(
&task,
&task.view_for_role()?,
aggregate_share_req.batch_selector().batch_identifier(),
) {
return Err(Error::BatchInvalid(
Expand Down Expand Up @@ -2838,6 +2841,7 @@ impl VdafOps {
Arc::clone(&aggregate_share_req),
);
Box::pin(async move {
let aggregator_task = task.view_for_role()?;
// Check if we have already serviced an aggregate share request with these
// parameters and serve the cached results if so.
let aggregation_param = A::AggregationParam::get_decoded(
Expand Down Expand Up @@ -2870,15 +2874,15 @@ impl VdafOps {
let (batch_aggregations, _) = try_join!(
Q::get_batch_aggregations_for_collection_identifier(
tx,
&task,
&aggregator_task,
vdaf.as_ref(),
aggregate_share_req.batch_selector().batch_identifier(),
&aggregation_param
),
Q::validate_query_count::<SEED_SIZE, C, A>(
tx,
vdaf.as_ref(),
&task,
&aggregator_task,
aggregate_share_req.batch_selector().batch_identifier(),
&aggregation_param,
)
Expand All @@ -2888,7 +2892,7 @@ impl VdafOps {
// currently-nonexistent batch aggregation, we write (empty) batch
// aggregations for any that have not already been written to storage.
let empty_batch_aggregations = empty_batch_aggregations(
&task,
&aggregator_task,
batch_aggregation_shard_count,
aggregate_share_req.batch_selector().batch_identifier(),
&aggregation_param,
Expand Down Expand Up @@ -2982,7 +2986,7 @@ fn empty_batch_aggregations<
Q: CollectableQueryType,
A: vdaf::Aggregator<SEED_SIZE, 16> + Send + Sync + 'static,
>(
task: &Task,
task: &AggregatorTask,
batch_aggregation_shard_count: u64,
batch_identifier: &Q::BatchIdentifier,
aggregation_param: &A::AggregationParam,
Expand Down
2 changes: 1 addition & 1 deletion aggregator/src/aggregator/accumulator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ impl<const SEED_SIZE: usize, Q: AccumulableQueryType, A: vdaf::Aggregator<SEED_S
output_share: &A::OutputShare,
) -> Result<(), datastore::Error> {
let batch_identifier =
Q::to_batch_identifier(&self.task, partial_batch_identifier, client_timestamp)?;
Q::to_batch_identifier(&self.task.view_for_role()?, partial_batch_identifier, client_timestamp)?;
let client_timestamp_interval =
Interval::from_time(client_timestamp).map_err(|e| datastore::Error::User(e.into()))?;
let batch_aggregation_fn = || {
Expand Down
20 changes: 15 additions & 5 deletions aggregator/src/aggregator/aggregation_job_creator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -745,8 +745,12 @@ mod tests {
Role::Leader,
)
.build();
let batch_identifier =
TimeInterval::to_batch_identifier(&leader_task, &(), &report_time).unwrap();
let batch_identifier = TimeInterval::to_batch_identifier(
&leader_task.view_for_role().unwrap(),
&(),
&report_time,
)
.unwrap();
let leader_report = LeaderStoredReport::new_dummy(*leader_task.id(), report_time);

let helper_task = TaskBuilder::new(
Expand Down Expand Up @@ -872,7 +876,9 @@ mod tests {
// Create 2 max-size batches, a min-size batch, one extra report (which will be added to the
// min-size batch).
let report_time = clock.now();
let batch_identifier = TimeInterval::to_batch_identifier(&task, &(), &report_time).unwrap();
let batch_identifier =
TimeInterval::to_batch_identifier(&task.view_for_role().unwrap(), &(), &report_time)
.unwrap();
let reports: Vec<_> =
iter::repeat_with(|| LeaderStoredReport::new_dummy(*task.id(), report_time))
.take(2 * MAX_AGGREGATION_JOB_SIZE + MIN_AGGREGATION_JOB_SIZE + 1)
Expand Down Expand Up @@ -980,7 +986,9 @@ mod tests {
.build(),
);
let report_time = clock.now();
let batch_identifier = TimeInterval::to_batch_identifier(&task, &(), &report_time).unwrap();
let batch_identifier =
TimeInterval::to_batch_identifier(&task.view_for_role().unwrap(), &(), &report_time)
.unwrap();
let first_report = LeaderStoredReport::new_dummy(*task.id(), report_time);
let second_report = LeaderStoredReport::new_dummy(*task.id(), report_time);

Expand Down Expand Up @@ -1113,7 +1121,9 @@ mod tests {

// Create a min-size batch.
let report_time = clock.now();
let batch_identifier = TimeInterval::to_batch_identifier(&task, &(), &report_time).unwrap();
let batch_identifier =
TimeInterval::to_batch_identifier(&task.view_for_role().unwrap(), &(), &report_time)
.unwrap();
let reports: Vec<_> =
iter::repeat_with(|| LeaderStoredReport::new_dummy(*task.id(), report_time))
.take(2 * MAX_AGGREGATION_JOB_SIZE + MIN_AGGREGATION_JOB_SIZE + 1)
Expand Down
22 changes: 14 additions & 8 deletions aggregator/src/aggregator/aggregation_job_driver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -988,7 +988,8 @@ mod tests {
.now()
.to_batch_interval_start(task.time_precision())
.unwrap();
let batch_identifier = TimeInterval::to_batch_identifier(&task, &(), &time).unwrap();
let batch_identifier =
TimeInterval::to_batch_identifier(&task.view_for_role().unwrap(), &(), &time).unwrap();
let report_metadata = ReportMetadata::new(random(), time);
let verify_key: VerifyKey<VERIFY_KEY_LENGTH> = task.vdaf_verify_key().unwrap();
let measurement = IdpfInput::from_bools(&[true]);
Expand Down Expand Up @@ -1286,7 +1287,8 @@ mod tests {
.now()
.to_batch_interval_start(task.time_precision())
.unwrap();
let batch_identifier = TimeInterval::to_batch_identifier(&task, &(), &time).unwrap();
let batch_identifier =
TimeInterval::to_batch_identifier(&task.view_for_role().unwrap(), &(), &time).unwrap();
let report_metadata = ReportMetadata::new(random(), time);
let verify_key: VerifyKey<VERIFY_KEY_LENGTH> = task.vdaf_verify_key().unwrap();

Expand Down Expand Up @@ -1646,7 +1648,8 @@ mod tests {
.now()
.to_batch_interval_start(task.time_precision())
.unwrap();
let batch_identifier = TimeInterval::to_batch_identifier(&task, &(), &time).unwrap();
let batch_identifier =
TimeInterval::to_batch_identifier(&task.view_for_role().unwrap(), &(), &time).unwrap();
let report_metadata = ReportMetadata::new(random(), time);
let verify_key: VerifyKey<VERIFY_KEY_LENGTH> = task.vdaf_verify_key().unwrap();
let measurement = IdpfInput::from_bools(&[true]);
Expand Down Expand Up @@ -2404,7 +2407,8 @@ mod tests {
.now()
.to_batch_interval_start(task.time_precision())
.unwrap();
let active_batch_identifier = TimeInterval::to_batch_identifier(&task, &(), &time).unwrap();
let active_batch_identifier =
TimeInterval::to_batch_identifier(&task.view_for_role().unwrap(), &(), &time).unwrap();
let other_batch_identifier = Interval::new(
active_batch_identifier
.start()
Expand Down Expand Up @@ -2727,7 +2731,7 @@ mod tests {
_,
>(
tx,
&task,
&task.view_for_role().unwrap(),
&vdaf,
&Interval::new(
report_metadata
Expand Down Expand Up @@ -3096,7 +3100,7 @@ mod tests {
VERIFY_KEY_LENGTH,
Poplar1<XofShake128, 16>,
_,
>(tx, &task, &vdaf, &batch_id, &aggregation_param)
>(tx, &task.view_for_role().unwrap(), &vdaf, &batch_id, &aggregation_param)
.await?;
let batch = tx
.get_batch(task.id(), &batch_id, &aggregation_param)
Expand Down Expand Up @@ -3162,7 +3166,8 @@ mod tests {
.now()
.to_batch_interval_start(task.time_precision())
.unwrap();
let batch_identifier = TimeInterval::to_batch_identifier(&task, &(), &time).unwrap();
let batch_identifier =
TimeInterval::to_batch_identifier(&task.view_for_role().unwrap(), &(), &time).unwrap();
let report_metadata = ReportMetadata::new(random(), time);
let verify_key: VerifyKey<VERIFY_KEY_LENGTH> = task.vdaf_verify_key().unwrap();

Expand Down Expand Up @@ -3370,7 +3375,8 @@ mod tests {
.now()
.to_batch_interval_start(task.time_precision())
.unwrap();
let batch_identifier = TimeInterval::to_batch_identifier(&task, &(), &time).unwrap();
let batch_identifier =
TimeInterval::to_batch_identifier(&task.view_for_role().unwrap(), &(), &time).unwrap();
let report_metadata = ReportMetadata::new(random(), time);
let transcript = run_vdaf(&vdaf, verify_key.as_bytes(), &(), report_metadata.id(), &0);
let report = generate_report::<VERIFY_KEY_LENGTH, Prio3Count>(
Expand Down
10 changes: 7 additions & 3 deletions aggregator/src/aggregator/aggregation_job_writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ impl<const SEED_SIZE: usize, Q: CollectableQueryType, A: vdaf::Aggregator<SEED_S
.iter()
.map(|ra| {
Q::to_batch_identifier(
&self.task,
&self.task.view_for_role()?,
info.aggregation_job.partial_batch_identifier(),
ra.time(),
)
Expand Down Expand Up @@ -468,13 +468,14 @@ impl<const SEED_SIZE: usize, Q: CollectableQueryType, A: vdaf::Aggregator<SEED_S

// Find all batches which are relevant to a collection job that just had a batch move into
// CLOSED state.
let aggregator_task = self.task.view_for_role()?;
let relevant_batches: Arc<HashMap<_, _>> = Arc::new({
let batches = Arc::new(Mutex::new(batches));
let relevant_batch_identifiers: HashSet<_> = affected_collection_jobs
.values()
.flat_map(|collection_job| {
Q::batch_identifiers_for_collection_identifier(
&self.task,
&aggregator_task,
collection_job.batch_identifier(),
)
})
Expand Down Expand Up @@ -519,7 +520,10 @@ impl<const SEED_SIZE: usize, Q: CollectableQueryType, A: vdaf::Aggregator<SEED_S
async move {
let mut is_collectable = true;
for batch_identifier in Q::batch_identifiers_for_collection_identifier(
&self.task,
&self
.task
.view_for_role()
.map_err(|e| Error::User(e.into()))?,
collection_job.batch_identifier(),
) {
let batch = match relevant_batches.get(&batch_identifier) {
Expand Down
4 changes: 2 additions & 2 deletions aggregator/src/aggregator/collection_job_driver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ impl CollectionJobDriver {
let batch_aggregations: Vec<_> =
Q::get_batch_aggregations_for_collection_identifier(
tx,
&task,
&task.leader_view()?,
vdaf.as_ref(),
collection_job.batch_identifier(),
collection_job.aggregation_parameter(),
Expand All @@ -177,7 +177,7 @@ impl CollectionJobDriver {
// transactionally to avoid the possibility of overwriting other transactions'
// updates to batch aggregations.
let empty_batch_aggregations = empty_batch_aggregations(
&task,
&task.leader_view()?,
batch_aggregation_shard_count,
collection_job.batch_identifier(),
collection_job.aggregation_parameter(),
Expand Down
12 changes: 6 additions & 6 deletions aggregator/src/aggregator/http_handlers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2817,7 +2817,7 @@ mod tests {
.unwrap();
let second_batch_want_batch_aggregations =
empty_batch_aggregations::<VERIFY_KEY_LENGTH, TimeInterval, Poplar1<XofShake128, 16>>(
&task,
&task.view_for_role().unwrap(),
BATCH_AGGREGATION_SHARD_COUNT,
&second_batch_identifier,
&aggregation_param,
Expand Down Expand Up @@ -2969,7 +2969,7 @@ mod tests {
_,
>(
tx,
&task,
&task.view_for_role().unwrap(),
&vdaf,
&Interval::new(
report_metadata_0
Expand Down Expand Up @@ -3051,7 +3051,7 @@ mod tests {
_,
>(
tx,
&task,
&task.view_for_role().unwrap(),
&vdaf,
&Interval::new(
report_metadata_2
Expand Down Expand Up @@ -3272,7 +3272,7 @@ mod tests {
_,
>(
tx,
&task,
&task.view_for_role().unwrap(),
&vdaf,
&Interval::new(
report_metadata_0
Expand Down Expand Up @@ -3359,7 +3359,7 @@ mod tests {
_,
>(
tx,
&task,
&task.view_for_role().unwrap(),
&vdaf,
&Interval::new(
report_metadata_2
Expand Down Expand Up @@ -4338,7 +4338,7 @@ mod tests {
let test_case = setup_collection_job_test_case(Role::Leader, QueryType::TimeInterval).await;

let batch_interval = TimeInterval::to_batch_identifier(
&test_case.task,
&test_case.task.view_for_role().unwrap(),
&(),
&Time::from_seconds_since_epoch(0),
)
Expand Down
Loading

0 comments on commit 6e60456

Please sign in to comment.