Skip to content

Commit

Permalink
Message & request changes for async aggregation/collection.
Browse files Browse the repository at this point in the history
Implementation of asynchronous aggregation is forthcoming. With this
change, we use the Pending/Finished statuses for aggregation/collection
as-specified, but both the Leader & Helper support only synchronous
aggregation.
  • Loading branch information
branlwyd committed Nov 27, 2024
1 parent a2c0398 commit 40b7763
Show file tree
Hide file tree
Showing 18 changed files with 844 additions and 566 deletions.
100 changes: 53 additions & 47 deletions aggregator/src/aggregator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,8 @@ use janus_messages::{
taskprov::TaskConfig,
AggregateShare, AggregateShareAad, AggregateShareReq, AggregationJobContinueReq,
AggregationJobId, AggregationJobInitializeReq, AggregationJobResp, AggregationJobStep,
BatchSelector, Collection, CollectionJobId, CollectionReq, Duration, ExtensionType, HpkeConfig,
HpkeConfigList, InputShareAad, Interval, PartialBatchSelector, PlaintextInputShare,
BatchSelector, CollectionJobId, CollectionJobReq, CollectionJobResp, Duration, ExtensionType,
HpkeConfig, HpkeConfigList, InputShareAad, Interval, PartialBatchSelector, PlaintextInputShare,
PrepareResp, PrepareStepResult, Report, ReportError, ReportIdChecksum, ReportShare, Role,
TaskId,
};
Expand Down Expand Up @@ -529,14 +529,14 @@ impl<C: Clock> Aggregator<C> {
}

/// Handle a collection job creation request. Only supported by the leader. `req_bytes` is an
/// encoded [`CollectionReq`].
/// encoded [`CollectionJobReq`]. Returns an encoded [`CollectionJobResp`] on success.
async fn handle_create_collection_job(
&self,
task_id: &TaskId,
collection_job_id: &CollectionJobId,
req_bytes: &[u8],
auth_token: Option<AuthenticationToken>,
) -> Result<(), Error> {
) -> Result<Vec<u8>, Error> {
let task_aggregator = self
.task_aggregators
.get(task_id)
Expand Down Expand Up @@ -566,7 +566,7 @@ impl<C: Clock> Aggregator<C> {
task_id: &TaskId,
collection_job_id: &CollectionJobId,
auth_token: Option<AuthenticationToken>,
) -> Result<Option<Vec<u8>>, Error> {
) -> Result<Vec<u8>, Error> {
let task_aggregator = self
.task_aggregators
.get(task_id)
Expand Down Expand Up @@ -1034,7 +1034,7 @@ impl<C: Clock> TaskAggregator<C> {
datastore: &Datastore<C>,
collection_job_id: &CollectionJobId,
req_bytes: &[u8],
) -> Result<(), Error> {
) -> Result<Vec<u8>, Error> {
self.vdaf_ops
.handle_create_collection_job(
datastore,
Expand All @@ -1049,7 +1049,7 @@ impl<C: Clock> TaskAggregator<C> {
&self,
datastore: &Datastore<C>,
collection_job_id: &CollectionJobId,
) -> Result<Option<Vec<u8>>, Error> {
) -> Result<Vec<u8>, Error> {
self.vdaf_ops
.handle_get_collection_job(datastore, Arc::clone(&self.task), collection_job_id)
.await
Expand Down Expand Up @@ -1834,19 +1834,20 @@ impl VdafOps {
}

// This is a repeated request. Send the same response we computed last time.
return Ok(Some(AggregationJobResp::new(
tx.get_report_aggregations_for_aggregation_job(
vdaf,
&Role::Helper,
task_id,
aggregation_job_id,
)
.await?
.iter()
.filter_map(ReportAggregation::last_prep_resp)
.cloned()
.collect(),
)));
return Ok(Some(AggregationJobResp::Finished {
prepare_resps: tx
.get_report_aggregations_for_aggregation_job(
vdaf,
&Role::Helper,
task_id,
aggregation_job_id,
)
.await?
.iter()
.filter_map(ReportAggregation::last_prep_resp)
.cloned()
.collect(),
}));
}

Ok(None)
Expand Down Expand Up @@ -2358,11 +2359,11 @@ impl VdafOps {
let (mut prep_resps_by_agg_job, counters) =
aggregation_job_writer.write(tx, vdaf).await?;
Ok((
AggregationJobResp::new(
prep_resps_by_agg_job
AggregationJobResp::Finished {
prepare_resps: prep_resps_by_agg_job
.remove(aggregation_job.id())
.unwrap_or_default(),
),
},
counters,
))
})
Expand Down Expand Up @@ -2473,13 +2474,13 @@ impl VdafOps {
}
}
return Ok((
AggregationJobResp::new(
report_aggregations
AggregationJobResp::Finished {
prepare_resps: report_aggregations
.iter()
.filter_map(ReportAggregation::last_prep_resp)
.cloned()
.collect(),
),
},
TaskAggregationCounter::default(),
));
} else if aggregation_job.step().increment() != req.step() {
Expand Down Expand Up @@ -2574,7 +2575,7 @@ impl VdafOps {
task: Arc<AggregatorTask>,
collection_job_id: &CollectionJobId,
collection_req_bytes: &[u8],
) -> Result<(), Error> {
) -> Result<Vec<u8>, Error> {
match task.batch_mode() {
task::BatchMode::TimeInterval => {
vdaf_ops_dispatch!(self, (vdaf, _, VdafType, VERIFY_KEY_LENGTH) => {
Expand Down Expand Up @@ -2612,19 +2613,19 @@ impl VdafOps {
vdaf: Arc<A>,
collection_job_id: &CollectionJobId,
req_bytes: &[u8],
) -> Result<(), Error>
) -> Result<Vec<u8>, Error>
where
A::AggregationParam: 'static + Send + Sync + PartialEq + Eq + Hash,
A::AggregateShare: Send + Sync,
{
let req =
Arc::new(CollectionReq::<B>::get_decoded(req_bytes).map_err(Error::MessageDecode)?);
Arc::new(CollectionJobReq::<B>::get_decoded(req_bytes).map_err(Error::MessageDecode)?);
let aggregation_param = Arc::new(
A::AggregationParam::get_decoded(req.aggregation_parameter())
.map_err(Error::MessageDecode)?,
);

Ok(datastore
datastore
.run_tx("collect", move |tx| {
let (task, vdaf, collection_job_id, req, aggregation_param) = (
Arc::clone(&task),
Expand Down Expand Up @@ -2715,7 +2716,11 @@ impl VdafOps {
Ok(())
})
})
.await?)
.await?;

CollectionJobResp::<B>::Processing
.get_encoded()
.map_err(Error::MessageEncode)
}

/// Handle GET requests to the leader's `tasks/{task-id}/collection_jobs/{collection-job-id}`
Expand All @@ -2727,7 +2732,7 @@ impl VdafOps {
datastore: &Datastore<C>,
task: Arc<AggregatorTask>,
collection_job_id: &CollectionJobId,
) -> Result<Option<Vec<u8>>, Error> {
) -> Result<Vec<u8>, Error> {
match task.batch_mode() {
task::BatchMode::TimeInterval => {
vdaf_ops_dispatch!(self, (vdaf, _, VdafType, VERIFY_KEY_LENGTH) => {
Expand Down Expand Up @@ -2765,7 +2770,7 @@ impl VdafOps {
task: Arc<AggregatorTask>,
vdaf: Arc<A>,
collection_job_id: &CollectionJobId,
) -> Result<Option<Vec<u8>>, Error>
) -> Result<Vec<u8>, Error>
where
A::AggregationParam: Send + Sync,
A::AggregateShare: Send + Sync,
Expand All @@ -2790,7 +2795,7 @@ impl VdafOps {
match collection_job.state() {
CollectionJobState::Start => {
debug!(%collection_job_id, task_id = %task.id(), "collection job has not run yet");
Ok(None)
Ok(CollectionJobResp::<B>::Processing)
}

CollectionJobState::Finished {
Expand Down Expand Up @@ -2839,19 +2844,15 @@ impl VdafOps {
.map_err(Error::MessageEncode)?,
)?;

Ok(Some(
Collection::<B>::new(
PartialBatchSelector::new(
B::partial_batch_identifier(collection_job.batch_identifier()).clone(),
),
*report_count,
*client_timestamp_interval,
encrypted_leader_aggregate_share,
encrypted_helper_aggregate_share.clone(),
)
.get_encoded()
.map_err(Error::MessageEncode)?,
))
Ok(CollectionJobResp::<B>::Finished {
partial_batch_selector: PartialBatchSelector::new(
B::partial_batch_identifier(collection_job.batch_identifier()).clone(),
),
report_count: *report_count,
interval: *client_timestamp_interval,
leader_encrypted_agg_share: encrypted_leader_aggregate_share,
helper_encrypted_agg_share: encrypted_helper_aggregate_share.clone(),
})
}

CollectionJobState::Abandoned => Err(Error::AbandonedCollectionJob(
Expand All @@ -2863,6 +2864,11 @@ impl VdafOps {
Err(Error::DeletedCollectionJob(*task.id(), *collection_job_id))
}
}
.and_then(|collection_job_resp| {
collection_job_resp
.get_encoded()
.map_err(Error::MessageEncode)
})
}

#[tracing::instrument(skip(self, datastore, task), fields(task_id = ?task.id()), err(level = Level::DEBUG))]
Expand Down
44 changes: 27 additions & 17 deletions aggregator/src/aggregator/aggregate_init_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -212,19 +212,23 @@ async fn setup_aggregate_init_test_for_vdaf<
&test_case.handler,
)
.await;
assert_eq!(response.status(), Some(Status::Ok));
assert_eq!(response.status(), Some(Status::Created));

let aggregation_job_init_resp: AggregationJobResp = decode_response_body(&mut response).await;
let aggregation_job_resp: AggregationJobResp = decode_response_body(&mut response).await;
let prepare_resps = assert_matches!(
&aggregation_job_resp,
AggregationJobResp::Finished { prepare_resps } => prepare_resps
);
assert_eq!(
aggregation_job_init_resp.prepare_resps().len(),
prepare_resps.len(),
test_case.aggregation_job_init_req.prepare_inits().len(),
);
assert_matches!(
aggregation_job_init_resp.prepare_resps()[0].result(),
prepare_resps[0].result(),
&PrepareStepResult::Continue { .. }
);

test_case.aggregation_job_init_resp = Some(aggregation_job_init_resp);
test_case.aggregation_job_init_resp = Some(aggregation_job_resp);
test_case
}

Expand Down Expand Up @@ -345,7 +349,7 @@ async fn aggregation_job_init_authorization_dap_auth_token() {
.run_async(&test_case.handler)
.await;

assert_eq!(response.status(), Some(Status::Ok));
assert_eq!(response.status(), Some(Status::Created));
}

#[rstest::rstest]
Expand Down Expand Up @@ -420,12 +424,14 @@ async fn aggregation_job_init_unexpected_taskprov_extension() {
&test_case.handler,
)
.await;
assert_eq!(response.status(), Some(Status::Ok));

let want_aggregation_job_resp = AggregationJobResp::new(Vec::from([PrepareResp::new(
report_id,
PrepareStepResult::Reject(ReportError::InvalidMessage),
)]));
assert_eq!(response.status(), Some(Status::Created));

let want_aggregation_job_resp = AggregationJobResp::Finished {
prepare_resps: Vec::from([PrepareResp::new(
report_id,
PrepareStepResult::Reject(ReportError::InvalidMessage),
)]),
};
let got_aggregation_job_resp: AggregationJobResp = decode_response_body(&mut response).await;
assert_eq!(want_aggregation_job_resp, got_aggregation_job_resp);
}
Expand Down Expand Up @@ -589,19 +595,23 @@ async fn aggregation_job_intolerable_clock_skew() {
&test_case.handler,
)
.await;
assert_eq!(response.status(), Some(Status::Ok));
assert_eq!(response.status(), Some(Status::Created));

let aggregation_job_init_resp: AggregationJobResp = decode_response_body(&mut response).await;
let aggregation_job_resp: AggregationJobResp = decode_response_body(&mut response).await;
let prepare_resps = assert_matches!(
aggregation_job_resp,
AggregationJobResp::Finished { prepare_resps } => prepare_resps
);
assert_eq!(
aggregation_job_init_resp.prepare_resps().len(),
prepare_resps.len(),
test_case.aggregation_job_init_req.prepare_inits().len(),
);
assert_matches!(
aggregation_job_init_resp.prepare_resps()[0].result(),
prepare_resps[0].result(),
&PrepareStepResult::Continue { .. }
);
assert_matches!(
aggregation_job_init_resp.prepare_resps()[1].result(),
prepare_resps[1].result(),
&PrepareStepResult::Reject(ReportError::ReportTooEarly)
);
}
Expand Down
14 changes: 7 additions & 7 deletions aggregator/src/aggregator/aggregation_job_continue.rs
Original file line number Diff line number Diff line change
Expand Up @@ -286,11 +286,11 @@ impl VdafOps {
aggregation_job_writer.put(aggregation_job, report_aggregations_to_write)?;
let (mut prep_resps_by_agg_job, counters) = aggregation_job_writer.write(tx, vdaf).await?;
Ok((
AggregationJobResp::new(
prep_resps_by_agg_job
AggregationJobResp::Finished {
prepare_resps: prep_resps_by_agg_job
.remove(&aggregation_job_id)
.unwrap_or_default(),
),
},
counters,
))
}
Expand Down Expand Up @@ -334,7 +334,7 @@ pub mod test_util {
) -> AggregationJobResp {
let mut test_conn = post_aggregation_job(task, aggregation_job_id, request, handler).await;

assert_eq!(test_conn.status(), Some(Status::Ok));
assert_eq!(test_conn.status(), Some(Status::Accepted));
assert_headers!(&test_conn, "content-type" => (AggregationJobResp::MEDIA_TYPE));
decode_response_body::<AggregationJobResp>(&mut test_conn).await
}
Expand Down Expand Up @@ -582,14 +582,14 @@ mod tests {
// Validate response.
assert_eq!(
first_continue_response,
AggregationJobResp::new(
test_case
AggregationJobResp::Finished {
prepare_resps: test_case
.first_continue_request
.prepare_steps()
.iter()
.map(|step| PrepareResp::new(*step.report_id(), PrepareStepResult::Finished))
.collect()
)
}
);

test_case.first_continue_response = Some(first_continue_response);
Expand Down
Loading

0 comments on commit 40b7763

Please sign in to comment.