From 40b7763356be6b42075ca049d19203f60683db5a Mon Sep 17 00:00:00 2001 From: Brandon Pitman Date: Mon, 25 Nov 2024 16:16:11 -0800 Subject: [PATCH] Message & request changes for async aggregation/collection. Implementation of asynchronous aggregation is forthcoming. With this change, we use the Pending/Finished statuses for aggregation/collection as-specified, but both the Leader & Helper support only synchronous aggregation. --- aggregator/src/aggregator.rs | 100 +++--- .../src/aggregator/aggregate_init_tests.rs | 44 ++- .../aggregator/aggregation_job_continue.rs | 14 +- .../src/aggregator/aggregation_job_driver.rs | 20 +- .../aggregation_job_driver/tests.rs | 138 ++++---- .../src/aggregator/collection_job_tests.rs | 94 ++++-- aggregator/src/aggregator/http_handlers.rs | 74 +++-- .../tests/aggregation_job_continue.rs | 33 +- .../tests/aggregation_job_init.rs | 56 ++-- .../http_handlers/tests/collection_job.rs | 76 +++-- .../http_handlers/tests/helper_e2e.rs | 21 +- aggregator/src/aggregator/taskprov_tests.rs | 72 ++-- collector/src/lib.rs | 308 +++++++++++------- messages/src/batch_mode.rs | 26 +- messages/src/lib.rs | 198 +++++------ messages/src/tests/aggregation.rs | 89 ++--- messages/src/tests/collection.rs | 29 +- tools/src/bin/dap_decode.rs | 18 +- 18 files changed, 844 insertions(+), 566 deletions(-) diff --git a/aggregator/src/aggregator.rs b/aggregator/src/aggregator.rs index 00023e46e..91910b491 100644 --- a/aggregator/src/aggregator.rs +++ b/aggregator/src/aggregator.rs @@ -68,8 +68,8 @@ use janus_messages::{ taskprov::TaskConfig, AggregateShare, AggregateShareAad, AggregateShareReq, AggregationJobContinueReq, AggregationJobId, AggregationJobInitializeReq, AggregationJobResp, AggregationJobStep, - BatchSelector, Collection, CollectionJobId, CollectionReq, Duration, ExtensionType, HpkeConfig, - HpkeConfigList, InputShareAad, Interval, PartialBatchSelector, PlaintextInputShare, + BatchSelector, CollectionJobId, CollectionJobReq, CollectionJobResp, Duration, ExtensionType, + HpkeConfig, HpkeConfigList, InputShareAad, Interval, PartialBatchSelector, PlaintextInputShare, PrepareResp, PrepareStepResult, Report, ReportError, ReportIdChecksum, ReportShare, Role, TaskId, }; @@ -529,14 +529,14 @@ impl Aggregator { } /// Handle a collection job creation request. Only supported by the leader. `req_bytes` is an - /// encoded [`CollectionReq`]. + /// encoded [`CollectionJobReq`]. Returns an encoded [`CollectionJobResp`] on success. async fn handle_create_collection_job( &self, task_id: &TaskId, collection_job_id: &CollectionJobId, req_bytes: &[u8], auth_token: Option, - ) -> Result<(), Error> { + ) -> Result, Error> { let task_aggregator = self .task_aggregators .get(task_id) @@ -566,7 +566,7 @@ impl Aggregator { task_id: &TaskId, collection_job_id: &CollectionJobId, auth_token: Option, - ) -> Result>, Error> { + ) -> Result, Error> { let task_aggregator = self .task_aggregators .get(task_id) @@ -1034,7 +1034,7 @@ impl TaskAggregator { datastore: &Datastore, collection_job_id: &CollectionJobId, req_bytes: &[u8], - ) -> Result<(), Error> { + ) -> Result, Error> { self.vdaf_ops .handle_create_collection_job( datastore, @@ -1049,7 +1049,7 @@ impl TaskAggregator { &self, datastore: &Datastore, collection_job_id: &CollectionJobId, - ) -> Result>, Error> { + ) -> Result, Error> { self.vdaf_ops .handle_get_collection_job(datastore, Arc::clone(&self.task), collection_job_id) .await @@ -1834,19 +1834,20 @@ impl VdafOps { } // This is a repeated request. Send the same response we computed last time. - return Ok(Some(AggregationJobResp::new( - tx.get_report_aggregations_for_aggregation_job( - vdaf, - &Role::Helper, - task_id, - aggregation_job_id, - ) - .await? - .iter() - .filter_map(ReportAggregation::last_prep_resp) - .cloned() - .collect(), - ))); + return Ok(Some(AggregationJobResp::Finished { + prepare_resps: tx + .get_report_aggregations_for_aggregation_job( + vdaf, + &Role::Helper, + task_id, + aggregation_job_id, + ) + .await? + .iter() + .filter_map(ReportAggregation::last_prep_resp) + .cloned() + .collect(), + })); } Ok(None) @@ -2358,11 +2359,11 @@ impl VdafOps { let (mut prep_resps_by_agg_job, counters) = aggregation_job_writer.write(tx, vdaf).await?; Ok(( - AggregationJobResp::new( - prep_resps_by_agg_job + AggregationJobResp::Finished { + prepare_resps: prep_resps_by_agg_job .remove(aggregation_job.id()) .unwrap_or_default(), - ), + }, counters, )) }) @@ -2473,13 +2474,13 @@ impl VdafOps { } } return Ok(( - AggregationJobResp::new( - report_aggregations + AggregationJobResp::Finished { + prepare_resps: report_aggregations .iter() .filter_map(ReportAggregation::last_prep_resp) .cloned() .collect(), - ), + }, TaskAggregationCounter::default(), )); } else if aggregation_job.step().increment() != req.step() { @@ -2574,7 +2575,7 @@ impl VdafOps { task: Arc, collection_job_id: &CollectionJobId, collection_req_bytes: &[u8], - ) -> Result<(), Error> { + ) -> Result, Error> { match task.batch_mode() { task::BatchMode::TimeInterval => { vdaf_ops_dispatch!(self, (vdaf, _, VdafType, VERIFY_KEY_LENGTH) => { @@ -2612,19 +2613,19 @@ impl VdafOps { vdaf: Arc, collection_job_id: &CollectionJobId, req_bytes: &[u8], - ) -> Result<(), Error> + ) -> Result, Error> where A::AggregationParam: 'static + Send + Sync + PartialEq + Eq + Hash, A::AggregateShare: Send + Sync, { let req = - Arc::new(CollectionReq::::get_decoded(req_bytes).map_err(Error::MessageDecode)?); + Arc::new(CollectionJobReq::::get_decoded(req_bytes).map_err(Error::MessageDecode)?); let aggregation_param = Arc::new( A::AggregationParam::get_decoded(req.aggregation_parameter()) .map_err(Error::MessageDecode)?, ); - Ok(datastore + datastore .run_tx("collect", move |tx| { let (task, vdaf, collection_job_id, req, aggregation_param) = ( Arc::clone(&task), @@ -2715,7 +2716,11 @@ impl VdafOps { Ok(()) }) }) - .await?) + .await?; + + CollectionJobResp::::Processing + .get_encoded() + .map_err(Error::MessageEncode) } /// Handle GET requests to the leader's `tasks/{task-id}/collection_jobs/{collection-job-id}` @@ -2727,7 +2732,7 @@ impl VdafOps { datastore: &Datastore, task: Arc, collection_job_id: &CollectionJobId, - ) -> Result>, Error> { + ) -> Result, Error> { match task.batch_mode() { task::BatchMode::TimeInterval => { vdaf_ops_dispatch!(self, (vdaf, _, VdafType, VERIFY_KEY_LENGTH) => { @@ -2765,7 +2770,7 @@ impl VdafOps { task: Arc, vdaf: Arc, collection_job_id: &CollectionJobId, - ) -> Result>, Error> + ) -> Result, Error> where A::AggregationParam: Send + Sync, A::AggregateShare: Send + Sync, @@ -2790,7 +2795,7 @@ impl VdafOps { match collection_job.state() { CollectionJobState::Start => { debug!(%collection_job_id, task_id = %task.id(), "collection job has not run yet"); - Ok(None) + Ok(CollectionJobResp::::Processing) } CollectionJobState::Finished { @@ -2839,19 +2844,15 @@ impl VdafOps { .map_err(Error::MessageEncode)?, )?; - Ok(Some( - Collection::::new( - PartialBatchSelector::new( - B::partial_batch_identifier(collection_job.batch_identifier()).clone(), - ), - *report_count, - *client_timestamp_interval, - encrypted_leader_aggregate_share, - encrypted_helper_aggregate_share.clone(), - ) - .get_encoded() - .map_err(Error::MessageEncode)?, - )) + Ok(CollectionJobResp::::Finished { + partial_batch_selector: PartialBatchSelector::new( + B::partial_batch_identifier(collection_job.batch_identifier()).clone(), + ), + report_count: *report_count, + interval: *client_timestamp_interval, + leader_encrypted_agg_share: encrypted_leader_aggregate_share, + helper_encrypted_agg_share: encrypted_helper_aggregate_share.clone(), + }) } CollectionJobState::Abandoned => Err(Error::AbandonedCollectionJob( @@ -2863,6 +2864,11 @@ impl VdafOps { Err(Error::DeletedCollectionJob(*task.id(), *collection_job_id)) } } + .and_then(|collection_job_resp| { + collection_job_resp + .get_encoded() + .map_err(Error::MessageEncode) + }) } #[tracing::instrument(skip(self, datastore, task), fields(task_id = ?task.id()), err(level = Level::DEBUG))] diff --git a/aggregator/src/aggregator/aggregate_init_tests.rs b/aggregator/src/aggregator/aggregate_init_tests.rs index 0f66ec5a0..85b44b414 100644 --- a/aggregator/src/aggregator/aggregate_init_tests.rs +++ b/aggregator/src/aggregator/aggregate_init_tests.rs @@ -212,19 +212,23 @@ async fn setup_aggregate_init_test_for_vdaf< &test_case.handler, ) .await; - assert_eq!(response.status(), Some(Status::Ok)); + assert_eq!(response.status(), Some(Status::Created)); - let aggregation_job_init_resp: AggregationJobResp = decode_response_body(&mut response).await; + let aggregation_job_resp: AggregationJobResp = decode_response_body(&mut response).await; + let prepare_resps = assert_matches!( + &aggregation_job_resp, + AggregationJobResp::Finished { prepare_resps } => prepare_resps + ); assert_eq!( - aggregation_job_init_resp.prepare_resps().len(), + prepare_resps.len(), test_case.aggregation_job_init_req.prepare_inits().len(), ); assert_matches!( - aggregation_job_init_resp.prepare_resps()[0].result(), + prepare_resps[0].result(), &PrepareStepResult::Continue { .. } ); - test_case.aggregation_job_init_resp = Some(aggregation_job_init_resp); + test_case.aggregation_job_init_resp = Some(aggregation_job_resp); test_case } @@ -345,7 +349,7 @@ async fn aggregation_job_init_authorization_dap_auth_token() { .run_async(&test_case.handler) .await; - assert_eq!(response.status(), Some(Status::Ok)); + assert_eq!(response.status(), Some(Status::Created)); } #[rstest::rstest] @@ -420,12 +424,14 @@ async fn aggregation_job_init_unexpected_taskprov_extension() { &test_case.handler, ) .await; - assert_eq!(response.status(), Some(Status::Ok)); - - let want_aggregation_job_resp = AggregationJobResp::new(Vec::from([PrepareResp::new( - report_id, - PrepareStepResult::Reject(ReportError::InvalidMessage), - )])); + assert_eq!(response.status(), Some(Status::Created)); + + let want_aggregation_job_resp = AggregationJobResp::Finished { + prepare_resps: Vec::from([PrepareResp::new( + report_id, + PrepareStepResult::Reject(ReportError::InvalidMessage), + )]), + }; let got_aggregation_job_resp: AggregationJobResp = decode_response_body(&mut response).await; assert_eq!(want_aggregation_job_resp, got_aggregation_job_resp); } @@ -589,19 +595,23 @@ async fn aggregation_job_intolerable_clock_skew() { &test_case.handler, ) .await; - assert_eq!(response.status(), Some(Status::Ok)); + assert_eq!(response.status(), Some(Status::Created)); - let aggregation_job_init_resp: AggregationJobResp = decode_response_body(&mut response).await; + let aggregation_job_resp: AggregationJobResp = decode_response_body(&mut response).await; + let prepare_resps = assert_matches!( + aggregation_job_resp, + AggregationJobResp::Finished { prepare_resps } => prepare_resps + ); assert_eq!( - aggregation_job_init_resp.prepare_resps().len(), + prepare_resps.len(), test_case.aggregation_job_init_req.prepare_inits().len(), ); assert_matches!( - aggregation_job_init_resp.prepare_resps()[0].result(), + prepare_resps[0].result(), &PrepareStepResult::Continue { .. } ); assert_matches!( - aggregation_job_init_resp.prepare_resps()[1].result(), + prepare_resps[1].result(), &PrepareStepResult::Reject(ReportError::ReportTooEarly) ); } diff --git a/aggregator/src/aggregator/aggregation_job_continue.rs b/aggregator/src/aggregator/aggregation_job_continue.rs index eff2fa4ee..c9ae5089d 100644 --- a/aggregator/src/aggregator/aggregation_job_continue.rs +++ b/aggregator/src/aggregator/aggregation_job_continue.rs @@ -286,11 +286,11 @@ impl VdafOps { aggregation_job_writer.put(aggregation_job, report_aggregations_to_write)?; let (mut prep_resps_by_agg_job, counters) = aggregation_job_writer.write(tx, vdaf).await?; Ok(( - AggregationJobResp::new( - prep_resps_by_agg_job + AggregationJobResp::Finished { + prepare_resps: prep_resps_by_agg_job .remove(&aggregation_job_id) .unwrap_or_default(), - ), + }, counters, )) } @@ -334,7 +334,7 @@ pub mod test_util { ) -> AggregationJobResp { let mut test_conn = post_aggregation_job(task, aggregation_job_id, request, handler).await; - assert_eq!(test_conn.status(), Some(Status::Ok)); + assert_eq!(test_conn.status(), Some(Status::Accepted)); assert_headers!(&test_conn, "content-type" => (AggregationJobResp::MEDIA_TYPE)); decode_response_body::(&mut test_conn).await } @@ -582,14 +582,14 @@ mod tests { // Validate response. assert_eq!( first_continue_response, - AggregationJobResp::new( - test_case + AggregationJobResp::Finished { + prepare_resps: test_case .first_continue_request .prepare_steps() .iter() .map(|step| PrepareResp::new(*step.report_id(), PrepareStepResult::Finished)) .collect() - ) + } ); test_case.first_continue_response = Some(first_continue_response); diff --git a/aggregator/src/aggregator/aggregation_job_driver.rs b/aggregator/src/aggregator/aggregation_job_driver.rs index 3a3d74624..ff042d8ea 100644 --- a/aggregator/src/aggregator/aggregation_job_driver.rs +++ b/aggregator/src/aggregator/aggregation_job_driver.rs @@ -543,7 +543,9 @@ where // If there are no prepare inits to send (because every report aggregation was filtered // by the block above), don't send a request to the Helper at all and process an // artificial aggregation job response instead, which will finish the aggregation job. - AggregationJobResp::new(Vec::new()) + AggregationJobResp::Finished { + prepare_resps: Vec::new(), + } }; let aggregation_job = Arc::unwrap_or_clone(aggregation_job); @@ -770,17 +772,27 @@ where A::PrepareState: Send + Sync + Encode, A::PublicShare: Send + Sync, { + let prepare_resps = match helper_resp { + // TODO(#3436): implement asynchronous aggregation + AggregationJobResp::Processing => { + return Err(Error::Internal( + "asynchronous aggregation not yet implemented".into(), + )) + } + AggregationJobResp::Finished { prepare_resps } => prepare_resps, + }; + // Handle response, computing the new report aggregations to be stored. let expected_report_aggregation_count = report_aggregations_to_write.len() + stepped_aggregations.len(); - if stepped_aggregations.len() != helper_resp.prepare_resps().len() { + if stepped_aggregations.len() != prepare_resps.len() { return Err(Error::Internal( "missing, duplicate, out-of-order, or unexpected prepare steps in response" .to_string(), )); } for (stepped_aggregation, helper_prep_resp) in - stepped_aggregations.iter().zip(helper_resp.prepare_resps()) + stepped_aggregations.iter().zip(&prepare_resps) { if stepped_aggregation.report_aggregation.report_id() != helper_prep_resp.report_id() { return Err(Error::Internal( @@ -810,7 +822,7 @@ where ); let ctx = vdaf_application_context(&task_id); - stepped_aggregations.into_par_iter().zip(helper_resp.prepare_resps()).try_for_each( + stepped_aggregations.into_par_iter().zip(prepare_resps).try_for_each( |(stepped_aggregation, helper_prep_resp)| { let _entered = span.enter(); diff --git a/aggregator/src/aggregator/aggregation_job_driver/tests.rs b/aggregator/src/aggregator/aggregation_job_driver/tests.rs index 23c4b8639..b6e01f067 100644 --- a/aggregator/src/aggregator/aggregation_job_driver/tests.rs +++ b/aggregator/src/aggregator/aggregation_job_driver/tests.rs @@ -168,12 +168,14 @@ async fn aggregation_job_driver() { "PUT", AggregationJobInitializeReq::::MEDIA_TYPE, AggregationJobResp::MEDIA_TYPE, - AggregationJobResp::new(Vec::from([PrepareResp::new( - *report.metadata().id(), - PrepareStepResult::Continue { - message: transcript.helper_prepare_transitions[0].message.clone(), - }, - )])) + AggregationJobResp::Finished { + prepare_resps: Vec::from([PrepareResp::new( + *report.metadata().id(), + PrepareStepResult::Continue { + message: transcript.helper_prepare_transitions[0].message.clone(), + }, + )]), + } .get_encoded() .unwrap(), ), @@ -181,10 +183,12 @@ async fn aggregation_job_driver() { "POST", AggregationJobContinueReq::MEDIA_TYPE, AggregationJobResp::MEDIA_TYPE, - AggregationJobResp::new(Vec::from([PrepareResp::new( - *report.metadata().id(), - PrepareStepResult::Finished, - )])) + AggregationJobResp::Finished { + prepare_resps: Vec::from([PrepareResp::new( + *report.metadata().id(), + PrepareStepResult::Finished, + )]), + } .get_encoded() .unwrap(), ), @@ -526,12 +530,14 @@ async fn step_time_interval_aggregation_job_init_single_step() { transcript.leader_prepare_transitions[0].message.clone(), )]), ); - let helper_response = AggregationJobResp::new(Vec::from([PrepareResp::new( - *report.metadata().id(), - PrepareStepResult::Continue { - message: transcript.helper_prepare_transitions[0].message.clone(), - }, - )])); + let helper_response = AggregationJobResp::Finished { + prepare_resps: Vec::from([PrepareResp::new( + *report.metadata().id(), + PrepareStepResult::Continue { + message: transcript.helper_prepare_transitions[0].message.clone(), + }, + )]), + }; let mocked_aggregate_failure = server .mock( "PUT", @@ -885,12 +891,14 @@ async fn step_time_interval_aggregation_job_init_two_steps() { transcript.leader_prepare_transitions[0].message.clone(), )]), ); - let helper_response = AggregationJobResp::new(Vec::from([PrepareResp::new( - *report.metadata().id(), - PrepareStepResult::Continue { - message: transcript.helper_prepare_transitions[0].message.clone(), - }, - )])); + let helper_response = AggregationJobResp::Finished { + prepare_resps: Vec::from([PrepareResp::new( + *report.metadata().id(), + PrepareStepResult::Continue { + message: transcript.helper_prepare_transitions[0].message.clone(), + }, + )]), + }; let (header, value) = agg_auth_token.request_authentication(); let mocked_aggregate_success = server .mock( @@ -1229,24 +1237,26 @@ async fn step_time_interval_aggregation_job_init_partially_garbage_collected() { ), ]), ); - let helper_response = AggregationJobResp::new(Vec::from([ - PrepareResp::new( - *gc_eligible_report.metadata().id(), - PrepareStepResult::Continue { - message: gc_eligible_transcript.helper_prepare_transitions[0] - .message - .clone(), - }, - ), - PrepareResp::new( - *gc_ineligible_report.metadata().id(), - PrepareStepResult::Continue { - message: gc_ineligible_transcript.helper_prepare_transitions[0] - .message - .clone(), - }, - ), - ])); + let helper_response = AggregationJobResp::Finished { + prepare_resps: Vec::from([ + PrepareResp::new( + *gc_eligible_report.metadata().id(), + PrepareStepResult::Continue { + message: gc_eligible_transcript.helper_prepare_transitions[0] + .message + .clone(), + }, + ), + PrepareResp::new( + *gc_ineligible_report.metadata().id(), + PrepareStepResult::Continue { + message: gc_ineligible_transcript.helper_prepare_transitions[0] + .message + .clone(), + }, + ), + ]), + }; let (header, value) = agg_auth_token.request_authentication(); let mocked_aggregate_init = server .mock( @@ -1515,12 +1525,14 @@ async fn step_leader_selected_aggregation_job_init_single_step() { transcript.leader_prepare_transitions[0].message.clone(), )]), ); - let helper_response = AggregationJobResp::new(Vec::from([PrepareResp::new( - *report.metadata().id(), - PrepareStepResult::Continue { - message: transcript.helper_prepare_transitions[0].message.clone(), - }, - )])); + let helper_response = AggregationJobResp::Finished { + prepare_resps: Vec::from([PrepareResp::new( + *report.metadata().id(), + PrepareStepResult::Continue { + message: transcript.helper_prepare_transitions[0].message.clone(), + }, + )]), + }; let mocked_aggregate_failure = server .mock( "PUT", @@ -1797,12 +1809,14 @@ async fn step_leader_selected_aggregation_job_init_two_steps() { transcript.leader_prepare_transitions[0].message.clone(), )]), ); - let helper_response = AggregationJobResp::new(Vec::from([PrepareResp::new( - *report.metadata().id(), - PrepareStepResult::Continue { - message: transcript.helper_prepare_transitions[0].message.clone(), - }, - )])); + let helper_response = AggregationJobResp::Finished { + prepare_resps: Vec::from([PrepareResp::new( + *report.metadata().id(), + PrepareStepResult::Continue { + message: transcript.helper_prepare_transitions[0].message.clone(), + }, + )]), + }; let (header, value) = agg_auth_token.request_authentication(); let mocked_aggregate_success = server .mock( @@ -2083,10 +2097,12 @@ async fn step_time_interval_aggregation_job_continue() { transcript.leader_prepare_transitions[1].message.clone(), )]), ); - let helper_response = AggregationJobResp::new(Vec::from([PrepareResp::new( - *report.metadata().id(), - PrepareStepResult::Finished, - )])); + let helper_response = AggregationJobResp::Finished { + prepare_resps: Vec::from([PrepareResp::new( + *report.metadata().id(), + PrepareStepResult::Finished, + )]), + }; let mocked_aggregate_failure = server .mock( "POST", @@ -2386,10 +2402,12 @@ async fn step_leader_selected_aggregation_job_continue() { transcript.leader_prepare_transitions[1].message.clone(), )]), ); - let helper_response = AggregationJobResp::new(Vec::from([PrepareResp::new( - *report.metadata().id(), - PrepareStepResult::Finished, - )])); + let helper_response = AggregationJobResp::Finished { + prepare_resps: Vec::from([PrepareResp::new( + *report.metadata().id(), + PrepareStepResult::Finished, + )]), + }; let mocked_aggregate_failure = server .mock( "POST", diff --git a/aggregator/src/aggregator/collection_job_tests.rs b/aggregator/src/aggregator/collection_job_tests.rs index 87c46d060..4dd49520a 100644 --- a/aggregator/src/aggregator/collection_job_tests.rs +++ b/aggregator/src/aggregator/collection_job_tests.rs @@ -3,6 +3,7 @@ use crate::aggregator::{ test_util::BATCH_AGGREGATION_SHARD_COUNT, Config, }; +use assert_matches::assert_matches; use http::StatusCode; use janus_aggregator_core::{ datastore::{ @@ -28,8 +29,8 @@ use janus_core::{ }; use janus_messages::{ batch_mode::{BatchMode as BatchModeTrait, LeaderSelected, TimeInterval}, - AggregateShareAad, AggregationJobStep, BatchId, BatchSelector, Collection, CollectionJobId, - CollectionReq, Interval, Query, Role, Time, + AggregateShareAad, AggregationJobStep, BatchId, BatchSelector, CollectionJobId, + CollectionJobReq, CollectionJobResp, Interval, Query, Role, Time, }; use prio::{ codec::{Decode, Encode}, @@ -59,7 +60,7 @@ impl CollectionJobTestCase { pub(super) async fn put_collection_job_with_auth_token( &self, collection_job_id: &CollectionJobId, - request: &CollectionReq, + request: &CollectionJobReq, auth_token: Option<&AuthenticationToken>, ) -> TestConn { let mut test_conn = put(self @@ -75,7 +76,7 @@ impl CollectionJobTestCase { test_conn .with_request_header( KnownHeaderName::ContentType, - CollectionReq::::MEDIA_TYPE, + CollectionJobReq::::MEDIA_TYPE, ) .with_request_body(request.get_encoded().unwrap()) .run_async(&self.handler) @@ -85,7 +86,7 @@ impl CollectionJobTestCase { pub(super) async fn put_collection_job( &self, collection_job_id: &CollectionJobId, - request: &CollectionReq, + request: &CollectionJobReq, ) -> TestConn { self.put_collection_job_with_auth_token( collection_job_id, @@ -329,7 +330,7 @@ async fn collection_job_success_leader_selected() { let leader_aggregate_share = dummy::AggregateShare(0); let helper_aggregate_share = dummy::AggregateShare(1); let aggregation_param = dummy::AggregationParam::default(); - let request = CollectionReq::new( + let request = CollectionJobReq::new( Query::new_leader_selected(), aggregation_param.get_encoded().unwrap(), ); @@ -337,13 +338,27 @@ async fn collection_job_success_leader_selected() { for _ in 0..2 { let collection_job_id: CollectionJobId = random(); - let test_conn = test_case + let mut test_conn = test_case .put_collection_job(&collection_job_id, &request) .await; assert_eq!(test_conn.status(), Some(Status::Created)); + assert_headers!( + &test_conn, + "content-type" => (CollectionJobResp::::MEDIA_TYPE) + ); + let collect_resp: CollectionJobResp = + decode_response_body(&mut test_conn).await; + assert_eq!(collect_resp, CollectionJobResp::::Processing); - let test_conn = test_case.get_collection_job(&collection_job_id).await; - assert_eq!(test_conn.status(), Some(Status::Accepted)); + let mut test_conn = test_case.get_collection_job(&collection_job_id).await; + assert_eq!(test_conn.status(), Some(Status::Ok)); + assert_headers!( + &test_conn, + "content-type" => (CollectionJobResp::::MEDIA_TYPE) + ); + let collect_resp: CollectionJobResp = + decode_response_body(&mut test_conn).await; + assert_eq!(collect_resp, CollectionJobResp::::Processing); // Update the collection job with the aggregate shares. collection job should now be complete. let batch_id = test_case @@ -408,19 +423,38 @@ async fn collection_job_success_leader_selected() { } let mut test_conn = test_case.get_collection_job(&collection_job_id).await; - assert_headers!(&test_conn, "content-type" => (Collection::::MEDIA_TYPE)); - - let collect_resp: Collection = decode_response_body(&mut test_conn).await; - assert_eq!( - collect_resp.report_count(), - test_case.task.min_batch_size() + 1 + assert_headers!(&test_conn, "content-type" => (CollectionJobResp::::MEDIA_TYPE)); + + let collect_resp: CollectionJobResp = + decode_response_body(&mut test_conn).await; + let ( + report_count, + interval, + leader_encrypted_aggregate_share, + helper_encrypted_aggregate_share, + ) = assert_matches!( + collect_resp, + CollectionJobResp::Finished { + report_count, + interval, + leader_encrypted_agg_share, + helper_encrypted_agg_share, + .. + } => ( + report_count, + interval, + leader_encrypted_agg_share, + helper_encrypted_agg_share + ) ); - assert_eq!(collect_resp.interval(), &spanned_interval); + + assert_eq!(report_count, test_case.task.min_batch_size() + 1); + assert_eq!(interval, spanned_interval); let decrypted_leader_aggregate_share = hpke::open( test_case.task.collector_hpke_keypair(), &HpkeApplicationInfo::new(&Label::AggregateShare, &Role::Leader, &Role::Collector), - collect_resp.leader_encrypted_aggregate_share(), + &leader_encrypted_aggregate_share, &AggregateShareAad::new( *test_case.task.id(), aggregation_param.get_encoded().unwrap(), @@ -438,7 +472,7 @@ async fn collection_job_success_leader_selected() { let decrypted_helper_aggregate_share = hpke::open( test_case.task.collector_hpke_keypair(), &HpkeApplicationInfo::new(&Label::AggregateShare, &Role::Helper, &Role::Collector), - collect_resp.helper_encrypted_aggregate_share(), + &helper_encrypted_aggregate_share, &AggregateShareAad::new( *test_case.task.id(), aggregation_param.get_encoded().unwrap(), @@ -483,7 +517,7 @@ async fn collection_job_put_idempotence_time_interval() { .await; let collection_job_id = random(); - let request = CollectionReq::new( + let request = CollectionJobReq::new( Query::new_time_interval( Interval::new( Time::from_seconds_since_epoch(0), @@ -535,7 +569,7 @@ async fn collection_job_put_idempotence_time_interval_varied_collection_id() { .await; let collection_job_ids = HashSet::from(random::<[CollectionJobId; 2]>()); - let request = CollectionReq::new( + let request = CollectionJobReq::new( Query::new_time_interval( Interval::new( Time::from_seconds_since_epoch(0), @@ -592,7 +626,7 @@ async fn collection_job_put_idempotence_time_interval_mutate_time_interval() { .await; let collection_job_id = random(); - let request = CollectionReq::new( + let request = CollectionJobReq::new( Query::new_time_interval( Interval::new( Time::from_seconds_since_epoch(0), @@ -608,7 +642,7 @@ async fn collection_job_put_idempotence_time_interval_mutate_time_interval() { .await; assert_eq!(response.status(), Some(Status::Created)); - let mutated_request = CollectionReq::new( + let mutated_request = CollectionJobReq::new( Query::new_time_interval( Interval::new( Time::from_seconds_since_epoch(test_case.task.time_precision().as_seconds()), @@ -633,7 +667,7 @@ async fn collection_job_put_idempotence_time_interval_mutate_aggregation_param() .await; let collection_job_id = random(); - let request = CollectionReq::new( + let request = CollectionJobReq::new( Query::new_time_interval( Interval::new( Time::from_seconds_since_epoch(0), @@ -649,7 +683,7 @@ async fn collection_job_put_idempotence_time_interval_mutate_aggregation_param() .await; assert_eq!(response.status(), Some(Status::Created)); - let mutated_request = CollectionReq::new( + let mutated_request = CollectionJobReq::new( Query::new_time_interval( Interval::new( Time::from_seconds_since_epoch(0), @@ -672,7 +706,7 @@ async fn collection_job_put_idempotence_leader_selected() { setup_leader_selected_current_batch_collection_job_test_case().await; let collection_job_id = random(); - let request = CollectionReq::new( + let request = CollectionJobReq::new( Query::new_leader_selected(), dummy::AggregationParam(0).get_encoded().unwrap(), ); @@ -723,7 +757,7 @@ async fn collection_job_put_idempotence_leader_selected_mutate_aggregation_param let (test_case, _, _, _) = setup_leader_selected_current_batch_collection_job_test_case().await; let collection_job_id = random(); - let request = CollectionReq::new( + let request = CollectionJobReq::new( Query::new_leader_selected(), dummy::AggregationParam(0).get_encoded().unwrap(), ); @@ -734,7 +768,7 @@ async fn collection_job_put_idempotence_leader_selected_mutate_aggregation_param assert_eq!(response.status(), Some(Status::Created)); - let mutated_request = CollectionReq::new( + let mutated_request = CollectionJobReq::new( Query::new_leader_selected(), dummy::AggregationParam(1).get_encoded().unwrap(), ); @@ -752,7 +786,7 @@ async fn collection_job_put_idempotence_leader_selected_no_extra_reports() { let collection_job_id_1 = random(); let collection_job_id_2 = random(); - let request: Arc> = Arc::new(CollectionReq::new( + let request: Arc> = Arc::new(CollectionJobReq::new( Query::new_leader_selected(), dummy::AggregationParam(0).get_encoded().unwrap(), )); @@ -765,7 +799,7 @@ async fn collection_job_put_idempotence_leader_selected_no_extra_reports() { // Fetch the first collection job, to advance the current batch. let response = test_case.get_collection_job(&collection_job_id_1).await; - assert_eq!(response.status(), Some(Status::Accepted)); + assert_eq!(response.status(), Some(Status::Ok)); // Create the second collection job. let response = test_case @@ -776,7 +810,7 @@ async fn collection_job_put_idempotence_leader_selected_no_extra_reports() { // Fetch the second collection job, to advance the current batch. There are now no outstanding // batches left. let response = test_case.get_collection_job(&collection_job_id_2).await; - assert_eq!(response.status(), Some(Status::Accepted)); + assert_eq!(response.status(), Some(Status::Ok)); // Re-send the collection job creation requests to confirm they are still idempotent. let response = test_case diff --git a/aggregator/src/aggregator/http_handlers.rs b/aggregator/src/aggregator/http_handlers.rs index 4aeb1efad..ac3ebcce7 100644 --- a/aggregator/src/aggregator/http_handlers.rs +++ b/aggregator/src/aggregator/http_handlers.rs @@ -17,8 +17,8 @@ use janus_core::{ use janus_messages::{ batch_mode::TimeInterval, codec::Decode, problem_type::DapProblemType, taskprov::TaskConfig, AggregateShare, AggregateShareReq, AggregationJobContinueReq, AggregationJobId, - AggregationJobInitializeReq, AggregationJobResp, Collection, CollectionJobId, CollectionReq, - HpkeConfigList, Report, TaskId, + AggregationJobInitializeReq, AggregationJobResp, CollectionJobId, CollectionJobReq, + CollectionJobResp, HpkeConfigList, Report, TaskId, }; use opentelemetry::{ metrics::{Counter, Meter}, @@ -199,10 +199,11 @@ const CORS_PREFLIGHT_CACHE_AGE: u32 = 24 * 60 * 60; /// Wrapper around a type that implements [`Encode`]. It acts as a Trillium handler, encoding the /// inner object and sending it as the response body, setting the Content-Type header to the -/// provided media type, and setting the status to 200. +/// provided media type, and setting the status to the specified value (or 200 if unspecified). struct EncodedBody { object: T, media_type: &'static str, + status: Status, } impl EncodedBody @@ -210,7 +211,15 @@ where T: Encode, { fn new(object: T, media_type: &'static str) -> Self { - Self { object, media_type } + Self { + object, + media_type, + status: Status::Ok, + } + } + + fn with_status(self, status: Status) -> Self { + Self { status, ..self } } } @@ -223,7 +232,9 @@ where match self.object.get_encoded() { Ok(encoded) => conn .with_response_header(KnownHeaderName::ContentType, self.media_type) - .ok(encoded), + .with_status(self.status) + .with_body(encoded) + .halt(), Err(e) => Error::MessageEncode(e).run(conn).await, } } @@ -545,7 +556,7 @@ async fn aggregation_jobs_put( .await .ok_or(Error::ClientDisconnected)??; - Ok(EncodedBody::new(response, AggregationJobResp::MEDIA_TYPE)) + Ok(EncodedBody::new(response, AggregationJobResp::MEDIA_TYPE).with_status(Status::Created)) } /// API handler for the "/tasks/.../aggregation_jobs/..." POST endpoint. @@ -570,7 +581,7 @@ async fn aggregation_jobs_post( .await .ok_or(Error::ClientDisconnected)??; - Ok(EncodedBody::new(response, AggregationJobResp::MEDIA_TYPE)) + Ok(EncodedBody::new(response, AggregationJobResp::MEDIA_TYPE).with_status(Status::Accepted)) } /// API handler for the "/tasks/.../aggregation_jobs/..." DELETE endpoint. @@ -598,22 +609,29 @@ async fn aggregation_jobs_delete( async fn collection_jobs_put( conn: &mut Conn, (State(aggregator), BodyBytes(body)): (State>>, BodyBytes), -) -> Result { - validate_content_type(conn, CollectionReq::::MEDIA_TYPE)?; +) -> Result<(), Error> { + validate_content_type(conn, CollectionJobReq::::MEDIA_TYPE)?; let task_id = parse_task_id(conn)?; let collection_job_id = parse_collection_job_id(conn)?; let auth_token = parse_auth_token(&task_id, conn)?; - conn.cancel_on_disconnect(aggregator.handle_create_collection_job( - &task_id, - &collection_job_id, - &body, - auth_token, - )) - .await - .ok_or(Error::ClientDisconnected)??; + let response_bytes = conn + .cancel_on_disconnect(aggregator.handle_create_collection_job( + &task_id, + &collection_job_id, + &body, + auth_token, + )) + .await + .ok_or(Error::ClientDisconnected)??; - Ok(Status::Created) + conn.response_headers_mut().insert( + KnownHeaderName::ContentType, + CollectionJobResp::::MEDIA_TYPE, + ); + conn.set_status(Status::Created); + conn.set_body(response_bytes); + Ok(()) } /// API handler for the "/tasks/.../collection_jobs/..." GET endpoint. @@ -624,7 +642,7 @@ async fn collection_jobs_get( let task_id = parse_task_id(conn)?; let collection_job_id = parse_collection_job_id(conn)?; let auth_token = parse_auth_token(&task_id, conn)?; - let response_opt = conn + let response_bytes = conn .cancel_on_disconnect(aggregator.handle_get_collection_job( &task_id, &collection_job_id, @@ -632,17 +650,13 @@ async fn collection_jobs_get( )) .await .ok_or(Error::ClientDisconnected)??; - match response_opt { - Some(response_bytes) => { - conn.response_headers_mut().insert( - KnownHeaderName::ContentType, - Collection::::MEDIA_TYPE, - ); - conn.set_status(Status::Ok); - conn.set_body(response_bytes); - } - None => conn.set_status(Status::Accepted), - } + + conn.response_headers_mut().insert( + KnownHeaderName::ContentType, + CollectionJobResp::::MEDIA_TYPE, + ); + conn.set_status(Status::Ok); + conn.set_body(response_bytes); Ok(()) } diff --git a/aggregator/src/aggregator/http_handlers/tests/aggregation_job_continue.rs b/aggregator/src/aggregator/http_handlers/tests/aggregation_job_continue.rs index 500c5a044..9d9fcc328 100644 --- a/aggregator/src/aggregator/http_handlers/tests/aggregation_job_continue.rs +++ b/aggregator/src/aggregator/http_handlers/tests/aggregation_job_continue.rs @@ -9,6 +9,7 @@ use crate::aggregator::{ BATCH_AGGREGATION_SHARD_COUNT, }, }; +use assert_matches::assert_matches; use futures::future::try_join_all; use janus_aggregator_core::{ batch_mode::CollectableBatchMode, @@ -268,13 +269,15 @@ async fn aggregate_continue() { // Validate response. assert_eq!( aggregate_resp, - AggregationJobResp::new(Vec::from([ - PrepareResp::new(*report_metadata_0.id(), PrepareStepResult::Finished), - PrepareResp::new( - *report_metadata_2.id(), - PrepareStepResult::Reject(ReportError::BatchCollected), - ) - ])) + AggregationJobResp::Finished { + prepare_resps: Vec::from([ + PrepareResp::new(*report_metadata_0.id(), PrepareStepResult::Finished), + PrepareResp::new( + *report_metadata_2.id(), + PrepareStepResult::Reject(ReportError::BatchCollected), + ) + ]) + } ); // Validate datastore. @@ -1133,9 +1136,11 @@ async fn aggregate_continue_leader_sends_non_continue_or_finish_transition() { let resp = post_aggregation_job_and_decode(&task, &aggregation_job_id, &request, &handler).await; - assert_eq!(resp.prepare_resps().len(), 1); + let prepare_resps = + assert_matches!(resp, AggregationJobResp::Finished{prepare_resps} => prepare_resps); + assert_eq!(prepare_resps.len(), 1); assert_eq!( - resp.prepare_resps()[0], + prepare_resps[0], PrepareResp::new( *report_metadata.id(), PrepareStepResult::Reject(ReportError::VdafPrepError), @@ -1247,10 +1252,12 @@ async fn aggregate_continue_prep_step_fails() { post_aggregation_job_and_decode(&task, &aggregation_job_id, &request, &handler).await; assert_eq!( aggregate_resp, - AggregationJobResp::new(Vec::from([PrepareResp::new( - *report_metadata.id(), - PrepareStepResult::Reject(ReportError::VdafPrepError), - )]),) + AggregationJobResp::Finished { + prepare_resps: Vec::from([PrepareResp::new( + *report_metadata.id(), + PrepareStepResult::Reject(ReportError::VdafPrepError), + )]) + } ); // Check datastore state. diff --git a/aggregator/src/aggregator/http_handlers/tests/aggregation_job_init.rs b/aggregator/src/aggregator/http_handlers/tests/aggregation_job_init.rs index 858bf7f95..05ddc0074 100644 --- a/aggregator/src/aggregator/http_handlers/tests/aggregation_job_init.rs +++ b/aggregator/src/aggregator/http_handlers/tests/aggregation_job_init.rs @@ -533,17 +533,21 @@ async fn aggregate_init() { for _ in 0..2 { let mut test_conn = put_aggregation_job(&task, &aggregation_job_id, &request, &handler).await; - assert_eq!(test_conn.status(), Some(Status::Ok)); + assert_eq!(test_conn.status(), Some(Status::Created)); assert_headers!( &test_conn, "content-type" => (AggregationJobResp::MEDIA_TYPE) ); let aggregate_resp: AggregationJobResp = decode_response_body(&mut test_conn).await; + let prepare_resps = assert_matches!( + aggregate_resp, + AggregationJobResp::Finished { prepare_resps } => prepare_resps + ); // Validate response. - assert_eq!(aggregate_resp.prepare_resps().len(), 10); + assert_eq!(prepare_resps.len(), 10); - let prepare_step_0 = aggregate_resp.prepare_resps().first().unwrap(); + let prepare_step_0 = prepare_resps.first().unwrap(); assert_eq!( prepare_step_0.report_id(), prepare_init_0.report_share().metadata().id() @@ -552,7 +556,7 @@ async fn aggregate_init() { assert_eq!(message, &transcript_0.helper_prepare_transitions[0].message); }); - let prepare_step_1 = aggregate_resp.prepare_resps().get(1).unwrap(); + let prepare_step_1 = prepare_resps.get(1).unwrap(); assert_eq!( prepare_step_1.report_id(), prepare_init_1.report_share().metadata().id() @@ -562,7 +566,7 @@ async fn aggregate_init() { &PrepareStepResult::Reject(ReportError::HpkeDecryptError) ); - let prepare_step_2 = aggregate_resp.prepare_resps().get(2).unwrap(); + let prepare_step_2 = prepare_resps.get(2).unwrap(); assert_eq!( prepare_step_2.report_id(), prepare_init_2.report_share().metadata().id() @@ -572,7 +576,7 @@ async fn aggregate_init() { &PrepareStepResult::Reject(ReportError::InvalidMessage) ); - let prepare_step_3 = aggregate_resp.prepare_resps().get(3).unwrap(); + let prepare_step_3 = prepare_resps.get(3).unwrap(); assert_eq!( prepare_step_3.report_id(), prepare_init_3.report_share().metadata().id() @@ -582,7 +586,7 @@ async fn aggregate_init() { &PrepareStepResult::Reject(ReportError::HpkeUnknownConfigId) ); - let prepare_step_4 = aggregate_resp.prepare_resps().get(4).unwrap(); + let prepare_step_4 = prepare_resps.get(4).unwrap(); assert_eq!( prepare_step_4.report_id(), prepare_init_4.report_share().metadata().id() @@ -592,7 +596,7 @@ async fn aggregate_init() { &PrepareStepResult::Reject(ReportError::ReportReplayed) ); - let prepare_step_5 = aggregate_resp.prepare_resps().get(5).unwrap(); + let prepare_step_5 = prepare_resps.get(5).unwrap(); assert_eq!( prepare_step_5.report_id(), prepare_init_5.report_share().metadata().id() @@ -602,7 +606,7 @@ async fn aggregate_init() { &PrepareStepResult::Reject(ReportError::BatchCollected) ); - let prepare_step_6 = aggregate_resp.prepare_resps().get(6).unwrap(); + let prepare_step_6 = prepare_resps.get(6).unwrap(); assert_eq!( prepare_step_6.report_id(), prepare_init_6.report_share().metadata().id() @@ -612,7 +616,7 @@ async fn aggregate_init() { &PrepareStepResult::Reject(ReportError::InvalidMessage), ); - let prepare_step_7 = aggregate_resp.prepare_resps().get(7).unwrap(); + let prepare_step_7 = prepare_resps.get(7).unwrap(); assert_eq!( prepare_step_7.report_id(), prepare_init_7.report_share().metadata().id() @@ -622,7 +626,7 @@ async fn aggregate_init() { &PrepareStepResult::Reject(ReportError::InvalidMessage), ); - let prepare_step_8 = aggregate_resp.prepare_resps().get(8).unwrap(); + let prepare_step_8 = prepare_resps.get(8).unwrap(); assert_eq!( prepare_step_8.report_id(), prepare_init_8.report_share().metadata().id() @@ -632,7 +636,7 @@ async fn aggregate_init() { &PrepareStepResult::Reject(ReportError::InvalidMessage), ); - let prepare_step_9 = aggregate_resp.prepare_resps().get(9).unwrap(); + let prepare_step_9 = prepare_resps.get(9).unwrap(); assert_eq!( prepare_step_9.report_id(), prepare_init_9.report_share().metadata().id() @@ -787,10 +791,14 @@ async fn aggregate_init_batch_already_collected() { .run_async(&handler) .await; - assert_eq!(test_conn.status(), Some(Status::Ok)); + assert_eq!(test_conn.status(), Some(Status::Created)); let aggregate_resp: AggregationJobResp = decode_response_body(&mut test_conn).await; + let prepare_resps = assert_matches!( + aggregate_resp, + AggregationJobResp::Finished { prepare_resps } => prepare_resps + ); - let prepare_step = aggregate_resp.prepare_resps().first().unwrap(); + let prepare_step = prepare_resps.first().unwrap(); assert_eq!( prepare_step.report_id(), prepare_init.report_share().metadata().id() @@ -841,17 +849,21 @@ async fn aggregate_init_prep_init_failed() { // Send request, and parse response. let aggregation_job_id: AggregationJobId = random(); let mut test_conn = put_aggregation_job(&task, &aggregation_job_id, &request, &handler).await; - assert_eq!(test_conn.status(), Some(Status::Ok)); + assert_eq!(test_conn.status(), Some(Status::Created)); assert_headers!( &test_conn, "content-type" => (AggregationJobResp::MEDIA_TYPE) ); let aggregate_resp: AggregationJobResp = decode_response_body(&mut test_conn).await; + let prepare_resps = assert_matches!( + aggregate_resp, + AggregationJobResp::Finished { prepare_resps } => prepare_resps + ); // Validate response. - assert_eq!(aggregate_resp.prepare_resps().len(), 1); + assert_eq!(prepare_resps.len(), 1); - let prepare_step = aggregate_resp.prepare_resps().first().unwrap(); + let prepare_step = prepare_resps.first().unwrap(); assert_eq!( prepare_step.report_id(), prepare_init.report_share().metadata().id() @@ -901,17 +913,21 @@ async fn aggregate_init_prep_step_failed() { let aggregation_job_id: AggregationJobId = random(); let mut test_conn = put_aggregation_job(&task, &aggregation_job_id, &request, &handler).await; - assert_eq!(test_conn.status(), Some(Status::Ok)); + assert_eq!(test_conn.status(), Some(Status::Created)); assert_headers!( &test_conn, "content-type" => (AggregationJobResp::MEDIA_TYPE) ); let aggregate_resp: AggregationJobResp = decode_response_body(&mut test_conn).await; + let prepare_resps = assert_matches!( + aggregate_resp, + AggregationJobResp::Finished { prepare_resps } => prepare_resps + ); // Validate response. - assert_eq!(aggregate_resp.prepare_resps().len(), 1); + assert_eq!(prepare_resps.len(), 1); - let prepare_step = aggregate_resp.prepare_resps().first().unwrap(); + let prepare_step = prepare_resps.first().unwrap(); assert_eq!( prepare_step.report_id(), prepare_init.report_share().metadata().id() diff --git a/aggregator/src/aggregator/http_handlers/tests/collection_job.rs b/aggregator/src/aggregator/http_handlers/tests/collection_job.rs index 75458fe50..d7c7971be 100644 --- a/aggregator/src/aggregator/http_handlers/tests/collection_job.rs +++ b/aggregator/src/aggregator/http_handlers/tests/collection_job.rs @@ -2,6 +2,7 @@ use crate::aggregator::{ collection_job_tests::setup_collection_job_test_case, http_handlers::test_util::{decode_response_body, take_problem_details, HttpHandlerTest}, }; +use assert_matches::assert_matches; use janus_aggregator_core::{ batch_mode::AccumulableBatchMode, datastore::models::{CollectionJob, CollectionJobState}, @@ -12,8 +13,8 @@ use janus_core::{ vdaf::VdafInstance, }; use janus_messages::{ - batch_mode::TimeInterval, AggregateShareAad, BatchSelector, Collection, CollectionJobId, - CollectionReq, Duration, Interval, Query, Role, Time, + batch_mode::TimeInterval, AggregateShareAad, BatchSelector, CollectionJobId, CollectionJobReq, + CollectionJobResp, Duration, Interval, Query, Role, Time, }; use prio::{ codec::{Decode, Encode}, @@ -35,7 +36,7 @@ async fn collection_job_put_request_to_helper() { .await; let collection_job_id: CollectionJobId = random(); - let request = CollectionReq::new( + let request = CollectionJobReq::new( Query::new_time_interval( Interval::new( Time::from_seconds_since_epoch(0), @@ -70,7 +71,7 @@ async fn collection_job_put_request_invalid_batch_interval() { .await; let collection_job_id: CollectionJobId = random(); - let request = CollectionReq::new( + let request = CollectionJobReq::new( Query::new_time_interval( Interval::new( Time::from_seconds_since_epoch(0), @@ -106,7 +107,7 @@ async fn collection_job_put_request_invalid_aggregation_parameter() { .await; let collection_job_id: CollectionJobId = random(); - let request = CollectionReq::new( + let request = CollectionJobReq::new( Query::new_time_interval( Interval::new( Time::from_seconds_since_epoch(0), @@ -152,7 +153,7 @@ async fn collection_job_put_request_invalid_batch_size() { datastore.put_aggregator_task(&leader_task).await.unwrap(); let collection_job_id: CollectionJobId = random(); - let request = CollectionReq::new( + let request = CollectionJobReq::new( Query::new_time_interval( Interval::new( Time::from_seconds_since_epoch(0), @@ -168,7 +169,7 @@ async fn collection_job_put_request_invalid_batch_size() { .with_request_header(header, value) .with_request_header( KnownHeaderName::ContentType, - CollectionReq::::MEDIA_TYPE, + CollectionJobReq::::MEDIA_TYPE, ) .with_request_body(request.get_encoded().unwrap()) .run_async(&handler) @@ -200,7 +201,7 @@ async fn collection_job_put_request_unauthenticated() { ) .unwrap(); let collection_job_id: CollectionJobId = random(); - let req = CollectionReq::new( + let req = CollectionJobReq::new( Query::new_time_interval(batch_interval), dummy::AggregationParam::default().get_encoded().unwrap(), ); @@ -275,7 +276,7 @@ async fn collection_job_get_request_unauthenticated_collection_jobs() { .unwrap(); let collection_job_id: CollectionJobId = random(); - let request = CollectionReq::new( + let request = CollectionJobReq::new( Query::new_time_interval(batch_interval), dummy::AggregationParam::default().get_encoded().unwrap(), ); @@ -360,12 +361,12 @@ async fn collection_job_success_time_interval() { let helper_aggregate_share = dummy::AggregateShare(1); let collection_job_id: CollectionJobId = random(); - let request = CollectionReq::new( + let request = CollectionJobReq::new( Query::new_time_interval(batch_interval), aggregation_param.get_encoded().unwrap(), ); - let test_conn = test_case + let mut test_conn = test_case .put_collection_job(&collection_job_id, &request) .await; @@ -397,9 +398,21 @@ async fn collection_job_success_time_interval() { assert_eq!(want_collection_job, got_collection_job); assert_eq!(test_conn.status(), Some(Status::Created)); + assert_headers!( + &test_conn, + "content-type" => (CollectionJobResp::::MEDIA_TYPE) + ); + let collect_resp: CollectionJobResp = decode_response_body(&mut test_conn).await; + assert_eq!(collect_resp, CollectionJobResp::::Processing); - let test_conn = test_case.get_collection_job(&collection_job_id).await; - assert_eq!(test_conn.status(), Some(Status::Accepted)); + let mut test_conn = test_case.get_collection_job(&collection_job_id).await; + assert_eq!(test_conn.status(), Some(Status::Ok)); + assert_headers!( + &test_conn, + "content-type" => (CollectionJobResp::::MEDIA_TYPE) + ); + let collect_resp: CollectionJobResp = decode_response_body(&mut test_conn).await; + assert_eq!(collect_resp, CollectionJobResp::::Processing); // Update the collection job with the aggregate shares and some aggregation jobs. collection // job should now be complete. @@ -457,17 +470,32 @@ async fn collection_job_success_time_interval() { assert_eq!(test_conn.status(), Some(Status::Ok)); assert_headers!( &test_conn, - "content-type" => (Collection::::MEDIA_TYPE) + "content-type" => (CollectionJobResp::::MEDIA_TYPE) + ); + let collect_resp: CollectionJobResp = decode_response_body(&mut test_conn).await; + let ( + report_count, + interval, + leader_encrypted_aggregate_share, + helper_encrypted_aggregate_share, + ) = assert_matches!( + collect_resp, + CollectionJobResp::Finished{ + report_count, + interval, + leader_encrypted_agg_share, + helper_encrypted_agg_share, + .. + } => (report_count, interval, leader_encrypted_agg_share, helper_encrypted_agg_share) ); - let collect_resp: Collection = decode_response_body(&mut test_conn).await; - assert_eq!(collect_resp.report_count(), 12); - assert_eq!(collect_resp.interval(), &batch_interval); + assert_eq!(report_count, 12); + assert_eq!(interval, batch_interval); let decrypted_leader_aggregate_share = hpke::open( test_case.task.collector_hpke_keypair(), &HpkeApplicationInfo::new(&Label::AggregateShare, &Role::Leader, &Role::Collector), - collect_resp.leader_encrypted_aggregate_share(), + &leader_encrypted_aggregate_share, &AggregateShareAad::new( *test_case.task.id(), aggregation_param.get_encoded().unwrap(), @@ -485,7 +513,7 @@ async fn collection_job_success_time_interval() { let decrypted_helper_aggregate_share = hpke::open( test_case.task.collector_hpke_keypair(), &HpkeApplicationInfo::new(&Label::AggregateShare, &Role::Helper, &Role::Collector), - collect_resp.helper_encrypted_aggregate_share(), + &helper_encrypted_aggregate_share, &AggregateShareAad::new( *test_case.task.id(), aggregation_param.get_encoded().unwrap(), @@ -532,7 +560,7 @@ async fn collection_job_put_request_batch_queried_multiple_times() { .await; // Sending this request will consume a query for [0, time_precision). - let request = CollectionReq::new( + let request = CollectionJobReq::new( Query::new_time_interval(interval), dummy::AggregationParam(0).get_encoded().unwrap(), ); @@ -542,7 +570,7 @@ async fn collection_job_put_request_batch_queried_multiple_times() { assert_eq!(test_conn.status(), Some(Status::Created)); // This request will not be allowed due to the query count already being consumed. - let invalid_request = CollectionReq::new( + let invalid_request = CollectionJobReq::new( Query::new_time_interval(interval), dummy::AggregationParam(1).get_encoded().unwrap(), ); @@ -570,7 +598,7 @@ async fn collection_job_put_request_batch_overlap() { .await; // Sending this request will consume a query for [0, 2 * time_precision). - let request = CollectionReq::new( + let request = CollectionJobReq::new( Query::new_time_interval( Interval::new( Time::from_seconds_since_epoch(0), @@ -586,7 +614,7 @@ async fn collection_job_put_request_batch_overlap() { assert_eq!(test_conn.status(), Some(Status::Created)); // This request will not be allowed due to overlapping with the previous request. - let invalid_request = CollectionReq::new( + let invalid_request = CollectionJobReq::new( Query::new_time_interval(interval), dummy::AggregationParam(1).get_encoded().unwrap(), ); @@ -634,7 +662,7 @@ async fn delete_collection_job() { assert_eq!(test_conn.status(), Some(Status::NotFound)); // Create a collection job - let request = CollectionReq::new( + let request = CollectionJobReq::new( Query::new_time_interval(batch_interval), dummy::AggregationParam::default().get_encoded().unwrap(), ); diff --git a/aggregator/src/aggregator/http_handlers/tests/helper_e2e.rs b/aggregator/src/aggregator/http_handlers/tests/helper_e2e.rs index d8b10a4c1..1e5799e30 100644 --- a/aggregator/src/aggregator/http_handlers/tests/helper_e2e.rs +++ b/aggregator/src/aggregator/http_handlers/tests/helper_e2e.rs @@ -10,6 +10,7 @@ use prio::{ vdaf::dummy, }; use rand::random; +use trillium::Status; use trillium_testing::assert_status; use crate::aggregator::{ @@ -94,29 +95,37 @@ async fn helper_aggregation_report_share_replay() { // Make aggregation job initialization requests, and check the prepare step results. let mut test_conn = put_aggregation_job(&task, &aggregation_job_id_1, &agg_init_req_1, &handler).await; - assert_status!(test_conn, 200); + assert_status!(test_conn, Status::Created); let agg_init_resp_1 = AggregationJobResp::get_decoded(take_response_body(&mut test_conn).await.as_ref()).unwrap(); + let prepare_resps_1 = assert_matches!( + agg_init_resp_1, + AggregationJobResp::Finished { prepare_resps } => prepare_resps + ); assert_matches!( - agg_init_resp_1.prepare_resps()[0].result(), + prepare_resps_1[0].result(), PrepareStepResult::Continue { .. } ); assert_matches!( - agg_init_resp_1.prepare_resps()[1].result(), + prepare_resps_1[1].result(), PrepareStepResult::Continue { .. } ); let mut test_conn = put_aggregation_job(&task, &aggregation_job_id_2, &agg_init_req_2, &handler).await; - assert_status!(test_conn, 200); + assert_status!(test_conn, Status::Created); let agg_init_resp_2 = AggregationJobResp::get_decoded(take_response_body(&mut test_conn).await.as_ref()).unwrap(); + let prepare_resps_2 = assert_matches!( + agg_init_resp_2, + AggregationJobResp::Finished { prepare_resps } => prepare_resps + ); assert_matches!( - agg_init_resp_2.prepare_resps()[0].result(), + prepare_resps_2[0].result(), PrepareStepResult::Reject(ReportError::ReportReplayed) ); assert_matches!( - agg_init_resp_2.prepare_resps()[1].result(), + prepare_resps_2[1].result(), PrepareStepResult::Continue { .. } ); diff --git a/aggregator/src/aggregator/taskprov_tests.rs b/aggregator/src/aggregator/taskprov_tests.rs index 71993bb00..104a59772 100644 --- a/aggregator/src/aggregator/taskprov_tests.rs +++ b/aggregator/src/aggregator/taskprov_tests.rs @@ -321,15 +321,19 @@ async fn taskprov_aggregate_init() { .run_async(&test.handler) .await; - assert_eq!(test_conn.status(), Some(Status::Ok), "{name}"); + assert_eq!(test_conn.status(), Some(Status::Created), "{name}"); assert_headers!( &test_conn, "content-type" => (AggregationJobResp::MEDIA_TYPE) ); let aggregate_resp: AggregationJobResp = decode_response_body(&mut test_conn).await; + let prepare_resps = assert_matches!( + aggregate_resp, + AggregationJobResp::Finished { prepare_resps } => prepare_resps + ); - assert_eq!(aggregate_resp.prepare_resps().len(), 1, "{}", name); - let prepare_step = aggregate_resp.prepare_resps().first().unwrap(); + assert_eq!(prepare_resps.len(), 1, "{}", name); + let prepare_step = prepare_resps.first().unwrap(); assert_eq!( prepare_step.report_id(), report_share.metadata().id(), @@ -424,15 +428,19 @@ async fn taskprov_aggregate_init_missing_extension() { .run_async(&test.handler) .await; - assert_eq!(test_conn.status(), Some(Status::Ok)); + assert_eq!(test_conn.status(), Some(Status::Created)); assert_headers!( &test_conn, "content-type" => (AggregationJobResp::MEDIA_TYPE) ); let aggregate_resp: AggregationJobResp = decode_response_body(&mut test_conn).await; + let prepare_resps = assert_matches!( + aggregate_resp, + AggregationJobResp::Finished { prepare_resps } => prepare_resps + ); - assert_eq!(aggregate_resp.prepare_resps().len(), 1); - let prepare_step = aggregate_resp.prepare_resps().first().unwrap(); + assert_eq!(prepare_resps.len(), 1); + let prepare_step = prepare_resps.first().unwrap(); assert_eq!(prepare_step.report_id(), report_share.metadata().id(),); assert_eq!( prepare_step.result(), @@ -507,15 +515,19 @@ async fn taskprov_aggregate_init_malformed_extension() { .run_async(&test.handler) .await; - assert_eq!(test_conn.status(), Some(Status::Ok)); + assert_eq!(test_conn.status(), Some(Status::Created)); assert_headers!( &test_conn, "content-type" => (AggregationJobResp::MEDIA_TYPE) ); let aggregate_resp: AggregationJobResp = decode_response_body(&mut test_conn).await; + let prepare_resps = assert_matches!( + aggregate_resp, + AggregationJobResp::Finished { prepare_resps } => prepare_resps + ); - assert_eq!(aggregate_resp.prepare_resps().len(), 1); - let prepare_step = aggregate_resp.prepare_resps().first().unwrap(); + assert_eq!(prepare_resps.len(), 1); + let prepare_step = prepare_resps.first().unwrap(); assert_eq!(prepare_step.report_id(), report_share.metadata().id(),); assert_eq!( prepare_step.result(), @@ -928,7 +940,7 @@ async fn taskprov_aggregate_continue() { .run_async(&test.handler) .await; - assert_eq!(test_conn.status(), Some(Status::Ok)); + assert_eq!(test_conn.status(), Some(Status::Accepted)); assert_headers!(&test_conn, "content-type" => (AggregationJobResp::MEDIA_TYPE)); let aggregate_resp: AggregationJobResp = decode_response_body(&mut test_conn).await; @@ -936,10 +948,12 @@ async fn taskprov_aggregate_continue() { // authorization of the request. assert_eq!( aggregate_resp, - AggregationJobResp::new(Vec::from([PrepareResp::new( - *report_share.metadata().id(), - PrepareStepResult::Finished - )])) + AggregationJobResp::Finished { + prepare_resps: Vec::from([PrepareResp::new( + *report_share.metadata().id(), + PrepareStepResult::Finished + )]) + } ); } @@ -1097,12 +1111,16 @@ async fn end_to_end() { .run_async(&test.handler) .await; - assert_eq!(test_conn.status(), Some(Status::Ok)); + assert_eq!(test_conn.status(), Some(Status::Created)); assert_headers!(&test_conn, "content-type" => (AggregationJobResp::MEDIA_TYPE)); let aggregation_job_resp: AggregationJobResp = decode_response_body(&mut test_conn).await; + let prepare_resps = assert_matches!( + aggregation_job_resp, + AggregationJobResp::Finished { prepare_resps } => prepare_resps + ); - assert_eq!(aggregation_job_resp.prepare_resps().len(), 1); - let prepare_resp = &aggregation_job_resp.prepare_resps()[0]; + assert_eq!(prepare_resps.len(), 1); + let prepare_resp = &prepare_resps[0]; assert_eq!(prepare_resp.report_id(), report_share.metadata().id()); let message = assert_matches!( prepare_resp.result(), @@ -1137,12 +1155,16 @@ async fn end_to_end() { .run_async(&test.handler) .await; - assert_eq!(test_conn.status(), Some(Status::Ok)); + assert_eq!(test_conn.status(), Some(Status::Accepted)); assert_headers!(&test_conn, "content-type" => (AggregationJobResp::MEDIA_TYPE)); let aggregation_job_resp: AggregationJobResp = decode_response_body(&mut test_conn).await; + let prepare_resps = assert_matches!( + aggregation_job_resp, + AggregationJobResp::Finished { prepare_resps } => prepare_resps + ); - assert_eq!(aggregation_job_resp.prepare_resps().len(), 1); - let prepare_resp = &aggregation_job_resp.prepare_resps()[0]; + assert_eq!(prepare_resps.len(), 1); + let prepare_resp = &prepare_resps[0]; assert_eq!(prepare_resp.report_id(), report_share.metadata().id()); assert_matches!(prepare_resp.result(), PrepareStepResult::Finished); @@ -1239,12 +1261,16 @@ async fn end_to_end_sumvec_hmac() { .run_async(&test.handler) .await; - assert_eq!(test_conn.status(), Some(Status::Ok)); + assert_eq!(test_conn.status(), Some(Status::Created)); assert_headers!(&test_conn, "content-type" => (AggregationJobResp::MEDIA_TYPE)); let aggregation_job_resp: AggregationJobResp = decode_response_body(&mut test_conn).await; + let prepare_resps = assert_matches!( + aggregation_job_resp, + AggregationJobResp::Finished { prepare_resps } => prepare_resps + ); - assert_eq!(aggregation_job_resp.prepare_resps().len(), 1); - let prepare_resp = &aggregation_job_resp.prepare_resps()[0]; + assert_eq!(prepare_resps.len(), 1); + let prepare_resp = &prepare_resps[0]; assert_eq!(prepare_resp.report_id(), report_share.metadata().id()); let message = assert_matches!(prepare_resp.result(), PrepareStepResult::Continue { message } => message.clone()); assert_eq!(message, transcript.helper_prepare_transitions[0].message); diff --git a/collector/src/lib.rs b/collector/src/lib.rs index b0923dd09..6853d445d 100644 --- a/collector/src/lib.rs +++ b/collector/src/lib.rs @@ -78,8 +78,8 @@ use janus_core::{ }; use janus_messages::{ batch_mode::{BatchMode, TimeInterval}, - AggregateShareAad, BatchSelector, Collection as CollectionMessage, CollectionJobId, - CollectionReq, PartialBatchSelector, Query, Role, TaskId, + AggregateShareAad, BatchSelector, CollectionJobId, CollectionJobReq, CollectionJobResp, + PartialBatchSelector, Query, Role, TaskId, }; use prio::{ codec::{Decode, Encode, ParameterizedDecode}, @@ -486,7 +486,7 @@ impl Collector { aggregation_parameter: &V::AggregationParam, ) -> Result, Error> { let collect_request = - CollectionReq::new(query.clone(), aggregation_parameter.get_encoded()?) + CollectionJobReq::new(query.clone(), aggregation_parameter.get_encoded()?) .get_encoded()?; let collection_job_url = self.collection_job_uri(collection_job_id)?; @@ -495,7 +495,7 @@ impl Collector { let (auth_header, auth_value) = self.authentication.request_authentication(); self.http_client .put(collection_job_url.clone()) - .header(CONTENT_TYPE, CollectionReq::::MEDIA_TYPE) + .header(CONTENT_TYPE, CollectionJobReq::::MEDIA_TYPE) .body(collect_request.clone()) .header(auth_header, auth_value) .send() @@ -551,21 +551,12 @@ impl Collector { let response = match response_res { // Successful response. Ok(response) => { - let status = response.status(); - match status { - StatusCode::OK => response, - StatusCode::ACCEPTED => { - let retry_after_opt = response - .headers() - .get(RETRY_AFTER) - .map(RetryAfter::try_from) - .transpose()?; - return Ok(PollResult::NotReady(retry_after_opt)); - } - _ => { - return Err(Error::Http(Box::new(HttpErrorResponse::from(status)))); - } + if response.status() != StatusCode::OK { + return Err(Error::Http(Box::new(HttpErrorResponse::from( + response.status(), + )))); } + response } // HTTP-level error. @@ -579,21 +570,45 @@ impl Collector { .headers() .get(CONTENT_TYPE) .ok_or(Error::BadContentType(None))?; - if content_type != CollectionMessage::::MEDIA_TYPE { + if content_type != CollectionJobResp::::MEDIA_TYPE { return Err(Error::BadContentType(Some(content_type.clone()))); } - let collect_response = CollectionMessage::::get_decoded(response.body())?; + let collect_response = CollectionJobResp::::get_decoded(response.body())?; + let ( + partial_batch_selector, + report_count, + interval, + leader_encrypted_aggregate_share, + helper_encrypted_aggregate_share, + ) = match &collect_response { + CollectionJobResp::Finished { + partial_batch_selector, + report_count, + interval, + leader_encrypted_agg_share, + helper_encrypted_agg_share, + } => ( + partial_batch_selector, + report_count, + interval, + leader_encrypted_agg_share, + helper_encrypted_agg_share, + ), + + CollectionJobResp::Processing => { + let retry_after_opt = response + .headers() + .get(RETRY_AFTER) + .map(RetryAfter::try_from) + .transpose()?; + return Ok(PollResult::NotReady(retry_after_opt)); + } + }; let aggregate_shares = [ - ( - Role::Leader, - collect_response.leader_encrypted_aggregate_share(), - ), - ( - Role::Helper, - collect_response.helper_encrypted_aggregate_share(), - ), + (Role::Leader, leader_encrypted_aggregate_share), + (Role::Helper, helper_encrypted_aggregate_share), ] .into_iter() .map(|(role, encrypted_aggregate_share)| { @@ -607,7 +622,7 @@ impl Collector { BatchSelector::::new(B::batch_identifier_for_collection( &job.query, &collect_response, - )), + )?), ) .get_encoded()?, )?; @@ -619,23 +634,18 @@ impl Collector { }) .collect::, Error>>()?; - let report_count = collect_response - .report_count() - .try_into() - .map_err(|_| Error::ReportCountOverflow)?; - let aggregate_result = - self.vdaf - .unshard(&job.aggregation_parameter, aggregate_shares, report_count)?; + let aggregate_result = self.vdaf.unshard( + &job.aggregation_parameter, + aggregate_shares, + usize::try_from(*report_count).map_err(|_| Error::ReportCountOverflow)?, + )?; Ok(PollResult::CollectionResult(Collection { - partial_batch_selector: collect_response.partial_batch_selector().clone(), - report_count: collect_response.report_count(), + partial_batch_selector: partial_batch_selector.clone(), + report_count: *report_count, interval: ( - Utc.from_utc_datetime(&collect_response.interval().start().as_naive_date_time()?), - collect_response - .interval() - .duration() - .as_chrono_duration()?, + Utc.from_utc_datetime(&interval.start().as_naive_date_time()?), + interval.duration().as_chrono_duration()?, ), aggregate_result, })) @@ -770,9 +780,9 @@ mod tests { use janus_messages::{ batch_mode::{LeaderSelected, TimeInterval}, problem_type::DapProblemType, - AggregateShareAad, BatchId, BatchSelector, Collection as CollectionMessage, - CollectionJobId, CollectionReq, Duration, HpkeCiphertext, Interval, PartialBatchSelector, - Query, Role, TaskId, Time, + AggregateShareAad, BatchId, BatchSelector, CollectionJobId, CollectionJobReq, + CollectionJobResp, Duration, HpkeCiphertext, Interval, PartialBatchSelector, Query, Role, + TaskId, Time, }; use mockito::Matcher; use prio::{ @@ -819,31 +829,31 @@ mod tests { collector: &Collector, aggregation_parameter: &V::AggregationParam, batch_interval: Interval, - ) -> CollectionMessage { + ) -> CollectionJobResp { let associated_data = AggregateShareAad::new( collector.task_id, aggregation_parameter.get_encoded().unwrap(), BatchSelector::new_time_interval(batch_interval), ); - CollectionMessage::new( - PartialBatchSelector::new_time_interval(), - 1, - batch_interval, - hpke::seal( + CollectionJobResp::Finished { + partial_batch_selector: PartialBatchSelector::new_time_interval(), + report_count: 1, + interval: batch_interval, + leader_encrypted_agg_share: hpke::seal( collector.hpke_keypair.config(), &HpkeApplicationInfo::new(&Label::AggregateShare, &Role::Leader, &Role::Collector), &transcript.leader_aggregate_share.get_encoded().unwrap(), &associated_data.get_encoded().unwrap(), ) .unwrap(), - hpke::seal( + helper_encrypted_agg_share: hpke::seal( collector.hpke_keypair.config(), &HpkeApplicationInfo::new(&Label::AggregateShare, &Role::Helper, &Role::Collector), &transcript.helper_aggregate_share.get_encoded().unwrap(), &associated_data.get_encoded().unwrap(), ) .unwrap(), - ) + } } fn build_collect_response_fixed< @@ -854,31 +864,32 @@ mod tests { collector: &Collector, aggregation_parameter: &V::AggregationParam, batch_id: BatchId, - ) -> CollectionMessage { + ) -> CollectionJobResp { let associated_data = AggregateShareAad::new( collector.task_id, aggregation_parameter.get_encoded().unwrap(), BatchSelector::new_leader_selected(batch_id), ); - CollectionMessage::new( - PartialBatchSelector::new_leader_selected(batch_id), - 1, - Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)).unwrap(), - hpke::seal( + CollectionJobResp::Finished { + partial_batch_selector: PartialBatchSelector::new_leader_selected(batch_id), + report_count: 1, + interval: Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) + .unwrap(), + leader_encrypted_agg_share: hpke::seal( collector.hpke_keypair.config(), &HpkeApplicationInfo::new(&Label::AggregateShare, &Role::Leader, &Role::Collector), &transcript.leader_aggregate_share.get_encoded().unwrap(), &associated_data.get_encoded().unwrap(), ) .unwrap(), - hpke::seal( + helper_encrypted_agg_share: hpke::seal( collector.hpke_keypair.config(), &HpkeApplicationInfo::new(&Label::AggregateShare, &Role::Helper, &Role::Collector), &transcript.helper_aggregate_share.get_encoded().unwrap(), &associated_data.get_encoded().unwrap(), ) .unwrap(), - ) + } } #[test] @@ -924,6 +935,7 @@ mod tests { Duration::from_seconds(3600), ) .unwrap(); + let processing_collect_resp = CollectionJobResp::::Processing; let collect_resp = build_collect_response_time(&transcript, &collector, &(), batch_interval); let matcher = collection_uri_regex_matcher(&collector.task_id); @@ -932,7 +944,7 @@ mod tests { .mock("PUT", matcher.clone()) .match_header( CONTENT_TYPE.as_str(), - CollectionReq::::MEDIA_TYPE, + CollectionJobReq::::MEDIA_TYPE, ) .with_status(500) .expect(1) @@ -942,7 +954,7 @@ mod tests { .mock("PUT", matcher) .match_header( CONTENT_TYPE.as_str(), - CollectionReq::::MEDIA_TYPE, + CollectionJobReq::::MEDIA_TYPE, ) .match_header(auth_header, auth_value.as_str()) .with_status(201) @@ -972,7 +984,12 @@ mod tests { .await; let mocked_collect_accepted = server .mock("GET", collection_job_path.as_str()) - .with_status(202) + .with_status(200) + .with_header( + CONTENT_TYPE.as_str(), + CollectionJobResp::::MEDIA_TYPE, + ) + .with_body(processing_collect_resp.get_encoded().unwrap()) .expect(2) .create_async() .await; @@ -982,7 +999,7 @@ mod tests { .with_status(200) .with_header( CONTENT_TYPE.as_str(), - CollectionMessage::::MEDIA_TYPE, + CollectionJobResp::::MEDIA_TYPE, ) .with_body(collect_resp.get_encoded().unwrap()) .expect(1) @@ -1032,7 +1049,7 @@ mod tests { .mock("PUT", matcher) .match_header( CONTENT_TYPE.as_str(), - CollectionReq::::MEDIA_TYPE, + CollectionJobReq::::MEDIA_TYPE, ) .with_status(201) .expect(1) @@ -1055,7 +1072,7 @@ mod tests { .with_status(200) .with_header( CONTENT_TYPE.as_str(), - CollectionMessage::::MEDIA_TYPE, + CollectionJobResp::::MEDIA_TYPE, ) .with_body(collect_resp.get_encoded().unwrap()) .expect(1) @@ -1100,7 +1117,7 @@ mod tests { .mock("PUT", matcher) .match_header( CONTENT_TYPE.as_str(), - CollectionReq::::MEDIA_TYPE, + CollectionJobReq::::MEDIA_TYPE, ) .with_status(201) .expect(1) @@ -1124,7 +1141,7 @@ mod tests { .with_status(200) .with_header( CONTENT_TYPE.as_str(), - CollectionMessage::::MEDIA_TYPE, + CollectionJobResp::::MEDIA_TYPE, ) .with_body(collect_resp.get_encoded().unwrap()) .expect(1) @@ -1179,7 +1196,7 @@ mod tests { .mock("PUT", matcher) .match_header( CONTENT_TYPE.as_str(), - CollectionReq::::MEDIA_TYPE, + CollectionJobReq::::MEDIA_TYPE, ) .with_status(201) .expect(1) @@ -1203,7 +1220,7 @@ mod tests { .with_status(200) .with_header( CONTENT_TYPE.as_str(), - CollectionMessage::::MEDIA_TYPE, + CollectionJobResp::::MEDIA_TYPE, ) .with_body(collect_resp.get_encoded().unwrap()) .expect(1) @@ -1243,7 +1260,7 @@ mod tests { .mock("PUT", matcher) .match_header( CONTENT_TYPE.as_str(), - CollectionReq::::MEDIA_TYPE, + CollectionJobReq::::MEDIA_TYPE, ) .with_status(201) .expect(1) @@ -1266,7 +1283,7 @@ mod tests { .with_status(200) .with_header( CONTENT_TYPE.as_str(), - CollectionMessage::::MEDIA_TYPE, + CollectionJobResp::::MEDIA_TYPE, ) .with_body(collect_resp.get_encoded().unwrap()) .expect(1) @@ -1323,7 +1340,7 @@ mod tests { .mock("PUT", matcher) .match_header( CONTENT_TYPE.as_str(), - CollectionReq::::MEDIA_TYPE, + CollectionJobReq::::MEDIA_TYPE, ) .match_header(AUTHORIZATION.as_str(), "Bearer AAAAAAAAAAAAAAAA") .with_status(201) @@ -1349,7 +1366,7 @@ mod tests { .with_status(200) .with_header( CONTENT_TYPE.as_str(), - CollectionMessage::::MEDIA_TYPE, + CollectionJobResp::::MEDIA_TYPE, ) .with_body(collect_resp.get_encoded().unwrap()) .expect(1) @@ -1385,7 +1402,7 @@ mod tests { .mock("PUT", matcher.clone()) .match_header( CONTENT_TYPE.as_str(), - CollectionReq::::MEDIA_TYPE, + CollectionJobReq::::MEDIA_TYPE, ) .with_status(500) .expect_at_least(1) @@ -1412,7 +1429,7 @@ mod tests { .mock("PUT", matcher.clone()) .match_header( CONTENT_TYPE.as_str(), - CollectionReq::::MEDIA_TYPE, + CollectionJobReq::::MEDIA_TYPE, ) .with_status(500) .with_header("Content-Type", "application/problem+json") @@ -1437,7 +1454,7 @@ mod tests { .mock("PUT", matcher) .match_header( CONTENT_TYPE.as_str(), - CollectionReq::::MEDIA_TYPE, + CollectionJobReq::::MEDIA_TYPE, ) .with_status(400) .with_header("Content-Type", "application/problem+json") @@ -1476,7 +1493,7 @@ mod tests { .mock("PUT", matcher.clone()) .match_header( CONTENT_TYPE.as_str(), - CollectionReq::::MEDIA_TYPE, + CollectionJobReq::::MEDIA_TYPE, ) .with_status(201) .expect(1) @@ -1559,7 +1576,7 @@ mod tests { .with_status(200) .with_header( CONTENT_TYPE.as_str(), - CollectionMessage::::MEDIA_TYPE, + CollectionJobResp::::MEDIA_TYPE, ) .with_body(b"") .expect_at_least(1) @@ -1576,24 +1593,24 @@ mod tests { .with_status(200) .with_header( CONTENT_TYPE.as_str(), - CollectionMessage::::MEDIA_TYPE, + CollectionJobResp::::MEDIA_TYPE, ) .with_body( - CollectionMessage::new( - PartialBatchSelector::new_time_interval(), - 1, - batch_interval, - HpkeCiphertext::new( + CollectionJobResp::Finished { + partial_batch_selector: PartialBatchSelector::new_time_interval(), + report_count: 1, + interval: batch_interval, + leader_encrypted_agg_share: HpkeCiphertext::new( *collector.hpke_keypair.config().id(), Vec::new(), Vec::new(), ), - HpkeCiphertext::new( + helper_encrypted_agg_share: HpkeCiphertext::new( *collector.hpke_keypair.config().id(), Vec::new(), Vec::new(), ), - ) + } .get_encoded() .unwrap(), ) @@ -1611,31 +1628,31 @@ mod tests { ().get_encoded().unwrap(), BatchSelector::new_time_interval(batch_interval), ); - let collect_resp = CollectionMessage::new( - PartialBatchSelector::new_time_interval(), - 1, - batch_interval, - hpke::seal( + let collect_resp = CollectionJobResp::Finished { + partial_batch_selector: PartialBatchSelector::new_time_interval(), + report_count: 1, + interval: batch_interval, + leader_encrypted_agg_share: hpke::seal( collector.hpke_keypair.config(), &HpkeApplicationInfo::new(&Label::AggregateShare, &Role::Leader, &Role::Collector), b"bad", &associated_data.get_encoded().unwrap(), ) .unwrap(), - hpke::seal( + helper_encrypted_agg_share: hpke::seal( collector.hpke_keypair.config(), &HpkeApplicationInfo::new(&Label::AggregateShare, &Role::Helper, &Role::Collector), b"bad", &associated_data.get_encoded().unwrap(), ) .unwrap(), - ); + }; let mock_collection_job_bad_shares = server .mock("GET", collection_job_path.as_str()) .with_status(200) .with_header( CONTENT_TYPE.as_str(), - CollectionMessage::::MEDIA_TYPE, + CollectionJobResp::::MEDIA_TYPE, ) .with_body(collect_resp.get_encoded().unwrap()) .expect_at_least(1) @@ -1647,11 +1664,11 @@ mod tests { mock_collection_job_bad_shares.assert_async().await; - let collect_resp = CollectionMessage::new( - PartialBatchSelector::new_time_interval(), - 1, - batch_interval, - hpke::seal( + let collect_resp = CollectionJobResp::Finished { + partial_batch_selector: PartialBatchSelector::new_time_interval(), + report_count: 1, + interval: batch_interval, + leader_encrypted_agg_share: hpke::seal( collector.hpke_keypair.config(), &HpkeApplicationInfo::new(&Label::AggregateShare, &Role::Leader, &Role::Collector), &AggregateShare::from(OutputShare::from(Vec::from([Field64::from(0)]))) @@ -1660,7 +1677,7 @@ mod tests { &associated_data.get_encoded().unwrap(), ) .unwrap(), - hpke::seal( + helper_encrypted_agg_share: hpke::seal( collector.hpke_keypair.config(), &HpkeApplicationInfo::new(&Label::AggregateShare, &Role::Helper, &Role::Collector), &AggregateShare::from(OutputShare::from(Vec::from([ @@ -1672,13 +1689,13 @@ mod tests { &associated_data.get_encoded().unwrap(), ) .unwrap(), - ); + }; let mock_collection_job_wrong_length = server .mock("GET", collection_job_path.as_str()) .with_status(200) .with_header( CONTENT_TYPE.as_str(), - CollectionMessage::::MEDIA_TYPE, + CollectionJobResp::::MEDIA_TYPE, ) .with_body(collect_resp.get_encoded().unwrap()) .expect_at_least(1) @@ -1716,7 +1733,7 @@ mod tests { .mock("PUT", matcher.clone()) .match_header( CONTENT_TYPE.as_str(), - CollectionReq::::MEDIA_TYPE, + CollectionJobReq::::MEDIA_TYPE, ) .with_status(201) .expect(1) @@ -1737,9 +1754,16 @@ mod tests { "/tasks/{}/collection_jobs/{}", collector.task_id, job.collection_job_id ); + let collect_resp = CollectionJobResp::::Processing; + let mock_collect_poll_no_retry_after = server .mock("GET", collection_job_path.as_str()) - .with_status(202) + .with_status(200) + .with_header( + CONTENT_TYPE.as_str(), + CollectionJobResp::::MEDIA_TYPE, + ) + .with_body(collect_resp.get_encoded().unwrap()) .expect(1) .create_async() .await; @@ -1751,8 +1775,13 @@ mod tests { let mock_collect_poll_retry_after_60s = server .mock("GET", collection_job_path.as_str()) - .with_status(202) + .with_status(200) .with_header("Retry-After", "60") + .with_header( + CONTENT_TYPE.as_str(), + CollectionJobResp::::MEDIA_TYPE, + ) + .with_body(collect_resp.get_encoded().unwrap()) .expect(1) .create_async() .await; @@ -1764,8 +1793,13 @@ mod tests { let mock_collect_poll_retry_after_date_time = server .mock("GET", collection_job_path.as_str()) - .with_status(202) + .with_status(200) .with_header("Retry-After", "Wed, 21 Oct 2015 07:28:00 GMT") + .with_header( + CONTENT_TYPE.as_str(), + CollectionJobResp::::MEDIA_TYPE, + ) + .with_body(collect_resp.get_encoded().unwrap()) .expect(1) .create_async() .await; @@ -1806,17 +1840,29 @@ mod tests { (), ); + let collect_resp = CollectionJobResp::::Processing; + let mock_collect_poll_retry_after_1s = server .mock("GET", collection_job_path.as_str()) - .with_status(202) + .with_status(200) .with_header("Retry-After", "1") + .with_header( + CONTENT_TYPE.as_str(), + CollectionJobResp::::MEDIA_TYPE, + ) + .with_body(collect_resp.get_encoded().unwrap()) .expect(1) .create_async() .await; let mock_collect_poll_retry_after_10s = server .mock("GET", collection_job_path.as_str()) - .with_status(202) + .with_status(200) .with_header("Retry-After", "10") + .with_header( + CONTENT_TYPE.as_str(), + CollectionJobResp::::MEDIA_TYPE, + ) + .with_body(collect_resp.get_encoded().unwrap()) .expect(1) .create_async() .await; @@ -1832,22 +1878,37 @@ mod tests { let near_future_formatted = near_future.format("%a, %d %b %Y %H:%M:%S GMT").to_string(); let mock_collect_poll_retry_after_near_future = server .mock("GET", collection_job_path.as_str()) - .with_status(202) + .with_status(200) .with_header("Retry-After", &near_future_formatted) + .with_header( + CONTENT_TYPE.as_str(), + CollectionJobResp::::MEDIA_TYPE, + ) + .with_body(collect_resp.get_encoded().unwrap()) .expect(1) .create_async() .await; let mock_collect_poll_retry_after_past = server .mock("GET", collection_job_path.as_str()) - .with_status(202) + .with_status(200) .with_header("Retry-After", "Mon, 01 Jan 1900 00:00:00 GMT") + .with_header( + CONTENT_TYPE.as_str(), + CollectionJobResp::::MEDIA_TYPE, + ) + .with_body(collect_resp.get_encoded().unwrap()) .expect(1) .create_async() .await; let mock_collect_poll_retry_after_far_future = server .mock("GET", collection_job_path.as_str()) - .with_status(202) + .with_status(200) .with_header("Retry-After", "Wed, 01 Jan 3000 00:00:00 GMT") + .with_header( + CONTENT_TYPE.as_str(), + CollectionJobResp::::MEDIA_TYPE, + ) + .with_body(collect_resp.get_encoded().unwrap()) .expect(1) .create_async() .await; @@ -1870,7 +1931,12 @@ mod tests { std::time::Duration::from_millis(10); let mock_collect_poll_no_retry_after = server .mock("GET", collection_job_path.as_str()) - .with_status(202) + .with_status(200) + .with_header( + CONTENT_TYPE.as_str(), + CollectionJobResp::::MEDIA_TYPE, + ) + .with_body(collect_resp.get_encoded().unwrap()) .expect_at_least(1) .create_async() .await; @@ -1965,6 +2031,7 @@ mod tests { Duration::from_seconds(3600), ) .unwrap(); + let processing_collect_resp = CollectionJobResp::::Processing; let collect_resp = build_collect_response_time(&transcript, &collector, &(), batch_interval); @@ -1986,19 +2053,24 @@ mod tests { .await; let mocked_collect_accepted = server .mock("GET", collection_job_path.as_str()) + .with_status(200) .match_header(CONTENT_LENGTH.as_str(), "0") - .with_status(202) + .with_header( + CONTENT_TYPE.as_str(), + CollectionJobResp::::MEDIA_TYPE, + ) + .with_body(processing_collect_resp.get_encoded().unwrap()) .expect(2) .create_async() .await; let mocked_collect_complete = server .mock("GET", collection_job_path.as_str()) + .with_status(200) .match_header(auth_header, auth_value.as_str()) .match_header(CONTENT_LENGTH.as_str(), "0") - .with_status(200) .with_header( CONTENT_TYPE.as_str(), - CollectionMessage::::MEDIA_TYPE, + CollectionJobResp::::MEDIA_TYPE, ) .with_body(collect_resp.get_encoded().unwrap()) .expect(1) diff --git a/messages/src/batch_mode.rs b/messages/src/batch_mode.rs index b94c68dbd..0d7ee901f 100644 --- a/messages/src/batch_mode.rs +++ b/messages/src/batch_mode.rs @@ -1,4 +1,4 @@ -use crate::{BatchId, Collection, Interval, Query}; +use crate::{BatchId, CollectionJobResp, Error, Interval, Query}; use anyhow::anyhow; use num_enum::TryFromPrimitive; use prio::codec::{CodecError, Decode, Encode}; @@ -55,8 +55,8 @@ pub trait BatchMode: Clone + Debug + PartialEq + Eq + Send + Sync + 'static { /// Retrieves the batch identifier associated with an ongoing collection. fn batch_identifier_for_collection( query: &Query, - collect_resp: &Collection, - ) -> Self::BatchIdentifier; + collection_job_resp: &CollectionJobResp, + ) -> Result; } /// Represents the `time-interval` DAP batch mode. @@ -76,9 +76,9 @@ impl BatchMode for TimeInterval { fn batch_identifier_for_collection( query: &Query, - _: &Collection, - ) -> Self::BatchIdentifier { - *query.batch_interval() + _: &CollectionJobResp, + ) -> Result { + Ok(*query.batch_interval()) } } @@ -101,9 +101,17 @@ impl BatchMode for LeaderSelected { fn batch_identifier_for_collection( _: &Query, - collect_resp: &Collection, - ) -> Self::BatchIdentifier { - *collect_resp.partial_batch_selector().batch_identifier() + collection_job_resp: &CollectionJobResp, + ) -> Result { + match collection_job_resp { + CollectionJobResp::Processing => Err(Error::InvalidParameter( + "collection job resp in Processing state", + )), + CollectionJobResp::Finished { + partial_batch_selector, + .. + } => Ok(*partial_batch_selector.batch_identifier()), + } } } diff --git a/messages/src/lib.rs b/messages/src/lib.rs index b27925640..fa5b0bf4f 100644 --- a/messages/src/lib.rs +++ b/messages/src/lib.rs @@ -1527,15 +1527,15 @@ impl Decode for Query { /// aggregate shares for a given batch. #[derive(Clone, Educe, PartialEq, Eq)] #[educe(Debug)] -pub struct CollectionReq { +pub struct CollectionJobReq { query: Query, #[educe(Debug(ignore))] aggregation_parameter: Vec, } -impl CollectionReq { +impl CollectionJobReq { /// The media type associated with this protocol message. - pub const MEDIA_TYPE: &'static str = "application/dap-collect-req"; + pub const MEDIA_TYPE: &'static str = "application/dap-collection-job-req"; /// Constructs a new collect request from its components. pub fn new(query: Query, aggregation_parameter: Vec) -> Self { @@ -1556,7 +1556,7 @@ impl CollectionReq { } } -impl Encode for CollectionReq { +impl Encode for CollectionJobReq { fn encode(&self, bytes: &mut Vec) -> Result<(), CodecError> { self.query.encode(bytes)?; encode_u32_items(bytes, &(), &self.aggregation_parameter) @@ -1567,7 +1567,7 @@ impl Encode for CollectionReq { } } -impl Decode for CollectionReq { +impl Decode for CollectionJobReq { fn decode(bytes: &mut Cursor<&[u8]>) -> Result { let query = Query::decode(bytes)?; let aggregation_parameter = decode_u32_items(&(), bytes)?; @@ -1704,95 +1704,86 @@ impl Distribution for Standard { /// DAP protocol message representing a leader's response to the collector's request to provide /// aggregate shares for a given query. #[derive(Clone, Debug, PartialEq, Eq)] -pub struct Collection { - partial_batch_selector: PartialBatchSelector, - report_count: u64, - interval: Interval, - leader_encrypted_agg_share: HpkeCiphertext, - helper_encrypted_agg_share: HpkeCiphertext, -} - -impl Collection { - /// The media type associated with this protocol message. - pub const MEDIA_TYPE: &'static str = "application/dap-collection"; - - /// Constructs a new collection. - pub fn new( +pub enum CollectionJobResp { + Processing, + Finished { partial_batch_selector: PartialBatchSelector, report_count: u64, interval: Interval, leader_encrypted_agg_share: HpkeCiphertext, helper_encrypted_agg_share: HpkeCiphertext, - ) -> Self { - Self { - partial_batch_selector, - report_count, - interval, - leader_encrypted_agg_share, - helper_encrypted_agg_share, - } - } - - /// Retrieves the batch selector associated with this collection. - pub fn partial_batch_selector(&self) -> &PartialBatchSelector { - &self.partial_batch_selector - } - - /// Retrieves the number of reports that were aggregated into this collection. - pub fn report_count(&self) -> u64 { - self.report_count - } - - /// Retrieves the interval spanned by the reports aggregated into this collection. - pub fn interval(&self) -> &Interval { - &self.interval - } - - /// Retrieves the leader encrypted aggregate share associated with this collection. - pub fn leader_encrypted_aggregate_share(&self) -> &HpkeCiphertext { - &self.leader_encrypted_agg_share - } + }, +} - /// Retrieves the helper encrypted aggregate share associated with this collection. - pub fn helper_encrypted_aggregate_share(&self) -> &HpkeCiphertext { - &self.helper_encrypted_agg_share - } +impl CollectionJobResp { + /// The media type associated with this protocol message. + pub const MEDIA_TYPE: &'static str = "application/dap-collection-job-resp"; } -impl Encode for Collection { +impl Encode for CollectionJobResp { fn encode(&self, bytes: &mut Vec) -> Result<(), CodecError> { - self.partial_batch_selector.encode(bytes)?; - self.report_count.encode(bytes)?; - self.interval.encode(bytes)?; - self.leader_encrypted_agg_share.encode(bytes)?; - self.helper_encrypted_agg_share.encode(bytes) + // The encoding includes an implicit discriminator byte called CollectionJobStatus in the + // DAP specification. + match self { + Self::Processing => 0u8.encode(bytes), + Self::Finished { + partial_batch_selector, + report_count, + interval, + leader_encrypted_agg_share, + helper_encrypted_agg_share, + } => { + 1u8.encode(bytes)?; + partial_batch_selector.encode(bytes)?; + report_count.encode(bytes)?; + interval.encode(bytes)?; + leader_encrypted_agg_share.encode(bytes)?; + helper_encrypted_agg_share.encode(bytes) + } + } } fn encoded_len(&self) -> Option { - Some( - self.partial_batch_selector.encoded_len()? - + self.report_count.encoded_len()? - + self.interval.encoded_len()? - + self.leader_encrypted_agg_share.encoded_len()? - + self.helper_encrypted_agg_share.encoded_len()?, - ) + Some(match self { + Self::Processing => 1, + Self::Finished { + partial_batch_selector, + report_count, + interval, + leader_encrypted_agg_share, + helper_encrypted_agg_share, + } => { + 1 + partial_batch_selector.encoded_len()? + + report_count.encoded_len()? + + interval.encoded_len()? + + leader_encrypted_agg_share.encoded_len()? + + helper_encrypted_agg_share.encoded_len()? + } + }) } } -impl Decode for Collection { +impl Decode for CollectionJobResp { fn decode(bytes: &mut Cursor<&[u8]>) -> Result { - let partial_batch_selector = PartialBatchSelector::decode(bytes)?; - let report_count = u64::decode(bytes)?; - let interval = Interval::decode(bytes)?; - let leader_encrypted_agg_share = HpkeCiphertext::decode(bytes)?; - let helper_encrypted_agg_share = HpkeCiphertext::decode(bytes)?; - - Ok(Self { - partial_batch_selector, - report_count, - interval, - leader_encrypted_agg_share, - helper_encrypted_agg_share, + let val = u8::decode(bytes)?; + Ok(match val { + 0 => Self::Processing, + 1 => { + let partial_batch_selector = PartialBatchSelector::decode(bytes)?; + let report_count = u64::decode(bytes)?; + let interval = Interval::decode(bytes)?; + let leader_encrypted_agg_share = HpkeCiphertext::decode(bytes)?; + let helper_encrypted_agg_share = HpkeCiphertext::decode(bytes)?; + + Self::Finished { + partial_batch_selector, + report_count, + interval, + leader_encrypted_agg_share, + helper_encrypted_agg_share, + } + } + _ => return Err(CodecError::UnexpectedValue), }) } } @@ -2491,43 +2482,54 @@ impl Decode for AggregationJobContinueReq { /// DAP protocol message representing the response to an aggregation job initialization or /// continuation request. #[derive(Clone, Debug, PartialEq, Eq)] -pub struct AggregationJobResp { - prepare_resps: Vec, +pub enum AggregationJobResp { + Processing, + Finished { prepare_resps: Vec }, } impl AggregationJobResp { /// The media type associated with this protocol message. pub const MEDIA_TYPE: &'static str = "application/dap-aggregation-job-resp"; - - /// Constructs a new aggregate continuation response from its components. - pub fn new(prepare_resps: Vec) -> Self { - Self { prepare_resps } - } - - /// Gets the prepare responses associated with this aggregate continuation response. - pub fn prepare_resps(&self) -> &[PrepareResp] { - &self.prepare_resps - } } impl Encode for AggregationJobResp { fn encode(&self, bytes: &mut Vec) -> Result<(), CodecError> { - encode_u32_items(bytes, &(), &self.prepare_resps) + // The encoding includes an implicit discriminator byte called AggregationJobStatus in the + // DAP specification. + match self { + Self::Processing => 0u8.encode(bytes), + Self::Finished { prepare_resps } => { + 1u8.encode(bytes)?; + encode_u32_items(bytes, &(), prepare_resps) + } + } } fn encoded_len(&self) -> Option { - let mut length = 4; - for prepare_resp in self.prepare_resps.iter() { - length += prepare_resp.encoded_len()?; - } - Some(length) + Some(match self { + Self::Processing => 1, + Self::Finished { prepare_resps } => { + let mut len = 5; + for prepare_resp in prepare_resps { + len += prepare_resp.encoded_len()?; + } + len + } + }) } } impl Decode for AggregationJobResp { fn decode(bytes: &mut Cursor<&[u8]>) -> Result { - let prepare_resps = decode_u32_items(&(), bytes)?; - Ok(Self { prepare_resps }) + let val = u8::decode(bytes)?; + Ok(match val { + 0 => Self::Processing, + 1 => { + let prepare_resps = decode_u32_items(&(), bytes)?; + Self::Finished { prepare_resps } + } + _ => return Err(CodecError::UnexpectedValue), + }) } } diff --git a/messages/src/tests/aggregation.rs b/messages/src/tests/aggregation.rs index 1ff9e84c2..502ffa18a 100644 --- a/messages/src/tests/aggregation.rs +++ b/messages/src/tests/aggregation.rs @@ -693,53 +693,64 @@ fn roundtrip_aggregation_job_continue_req() { #[test] fn roundtrip_aggregation_job_resp() { - roundtrip_encoding(&[( - AggregationJobResp { - prepare_resps: Vec::from([ - PrepareResp { - report_id: ReportId::from([ - 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, - ]), - result: PrepareStepResult::Continue { - message: PingPongMessage::Continue { - prep_msg: Vec::from("01234"), - prep_share: Vec::from("56789"), + roundtrip_encoding(&[ + ( + AggregationJobResp::Processing, + concat!( + "00", // status + ), + ), + ( + AggregationJobResp::Finished { + prepare_resps: Vec::from([ + PrepareResp { + report_id: ReportId::from([ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, + ]), + result: PrepareStepResult::Continue { + message: PingPongMessage::Continue { + prep_msg: Vec::from("01234"), + prep_share: Vec::from("56789"), + }, }, }, - }, - PrepareResp { - report_id: ReportId::from([ - 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, - ]), - result: PrepareStepResult::Finished, - }, - ]), - }, - concat!(concat!( - // prepare_steps - "00000039", // length + PrepareResp { + report_id: ReportId::from([ + 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, + ]), + result: PrepareStepResult::Finished, + }, + ]), + }, concat!( - "0102030405060708090A0B0C0D0E0F10", // report_id - "00", // prepare_step_result + "01", // status concat!( - "00000013", // ping pong message length - "01", // ping pong message type + // prepare_steps + "00000039", // length concat!( - // prep_msg - "00000005", // prep_msg length - "3031323334", // opaque data + "0102030405060708090A0B0C0D0E0F10", // report_id + "00", // prepare_step_result + concat!( + "00000013", // ping pong message length + "01", // ping pong message type + concat!( + // prep_msg + "00000005", // prep_msg length + "3031323334", // opaque data + ), + concat!( + // prep_share + "00000005", // prep_share length + "3536373839", // opaque data + ) + ), ), concat!( - // prep_share - "00000005", // prep_share length - "3536373839", // opaque data + "100F0E0D0C0B0A090807060504030201", // report_id + "01", // prepare_step_result ) ), ), - concat!( - "100F0E0D0C0B0A090807060504030201", // report_id - "01", // prepare_step_result - ) - ),), - )]) + ), + ]) } diff --git a/messages/src/tests/collection.rs b/messages/src/tests/collection.rs index 94a55dfb5..25840b587 100644 --- a/messages/src/tests/collection.rs +++ b/messages/src/tests/collection.rs @@ -1,16 +1,17 @@ use crate::{ roundtrip_encoding, AggregateShare, AggregateShareAad, AggregateShareReq, BatchId, - BatchSelector, Collection, CollectionReq, Duration, HpkeCiphertext, HpkeConfigId, Interval, - LeaderSelected, PartialBatchSelector, Query, ReportIdChecksum, TaskId, Time, TimeInterval, + BatchSelector, CollectionJobReq, CollectionJobResp, Duration, HpkeCiphertext, HpkeConfigId, + Interval, LeaderSelected, PartialBatchSelector, Query, ReportIdChecksum, TaskId, Time, + TimeInterval, }; use prio::codec::Decode; #[test] -fn roundtrip_collection_req() { +fn roundtrip_collection_job_req() { // TimeInterval. roundtrip_encoding(&[ ( - CollectionReq:: { + CollectionJobReq:: { query: Query { query_body: Interval::new( Time::from_seconds_since_epoch(54321), @@ -39,7 +40,7 @@ fn roundtrip_collection_req() { ), ), ( - CollectionReq:: { + CollectionJobReq:: { query: Query { query_body: Interval::new( Time::from_seconds_since_epoch(48913), @@ -72,7 +73,7 @@ fn roundtrip_collection_req() { // LeaderSelected. roundtrip_encoding(&[ ( - CollectionReq:: { + CollectionJobReq:: { query: Query { query_body: () }, aggregation_parameter: Vec::new(), }, @@ -90,7 +91,7 @@ fn roundtrip_collection_req() { ), ), ( - CollectionReq:: { + CollectionJobReq:: { query: Query { query_body: () }, aggregation_parameter: Vec::from("012345"), }, @@ -144,7 +145,7 @@ fn roundtrip_partial_batch_selector() { } #[test] -fn roundtrip_collection() { +fn roundtrip_collection_job_resp() { let interval = Interval { start: Time::from_seconds_since_epoch(54321), duration: Duration::from_seconds(12345), @@ -152,7 +153,7 @@ fn roundtrip_collection() { // TimeInterval. roundtrip_encoding(&[ ( - Collection { + CollectionJobResp::Finished { partial_batch_selector: PartialBatchSelector::new_time_interval(), report_count: 0, interval, @@ -168,6 +169,7 @@ fn roundtrip_collection() { ), }, concat!( + "01", // status concat!( // partial_batch_selector "01", // batch_mode @@ -211,7 +213,7 @@ fn roundtrip_collection() { ), ), ( - Collection { + CollectionJobResp::Finished { partial_batch_selector: PartialBatchSelector::new_time_interval(), report_count: 23, interval, @@ -227,6 +229,7 @@ fn roundtrip_collection() { ), }, concat!( + "01", // status concat!( // partial_batch_selector "01", // batch_mode @@ -274,7 +277,7 @@ fn roundtrip_collection() { // LeaderSelected. roundtrip_encoding(&[ ( - Collection { + CollectionJobResp::Finished { partial_batch_selector: PartialBatchSelector::new_leader_selected(BatchId::from( [3u8; 32], )), @@ -292,6 +295,7 @@ fn roundtrip_collection() { ), }, concat!( + "01", // status concat!( // partial_batch_selector "02", // batch_mode @@ -335,7 +339,7 @@ fn roundtrip_collection() { ), ), ( - Collection { + CollectionJobResp::Finished { partial_batch_selector: PartialBatchSelector::new_leader_selected(BatchId::from( [4u8; 32], )), @@ -353,6 +357,7 @@ fn roundtrip_collection() { ), }, concat!( + "01", // status concat!( // partial_batch_selector "02", // batch_mode diff --git a/tools/src/bin/dap_decode.rs b/tools/src/bin/dap_decode.rs index 3d6d72d71..350009cac 100644 --- a/tools/src/bin/dap_decode.rs +++ b/tools/src/bin/dap_decode.rs @@ -3,7 +3,7 @@ use clap::{Parser, ValueEnum}; use janus_messages::{ batch_mode::{LeaderSelected, TimeInterval}, AggregateShare, AggregateShareReq, AggregationJobContinueReq, AggregationJobInitializeReq, - AggregationJobResp, Collection, CollectionReq, HpkeConfig, HpkeConfigList, Report, + AggregationJobResp, CollectionJobReq, CollectionJobResp, HpkeConfig, HpkeConfigList, Report, }; use prio::codec::Decode; use std::{ @@ -82,22 +82,22 @@ fn decode_dap_message(message_file: &str, media_type: &MediaType) -> Result { - if let Ok(decoded) = CollectionReq::::get_decoded(&message_buf) { - let message: CollectionReq = decoded; + if let Ok(decoded) = CollectionJobReq::::get_decoded(&message_buf) { + let message: CollectionJobReq = decoded; Box::new(message) } else { - let message: CollectionReq = - CollectionReq::::get_decoded(&message_buf)?; + let message: CollectionJobReq = + CollectionJobReq::::get_decoded(&message_buf)?; Box::new(message) } } MediaType::Collection => { - if let Ok(decoded) = Collection::::get_decoded(&message_buf) { - let message: Collection = decoded; + if let Ok(decoded) = CollectionJobResp::::get_decoded(&message_buf) { + let message: CollectionJobResp = decoded; Box::new(message) } else { - let message: Collection = - Collection::::get_decoded(&message_buf)?; + let message: CollectionJobResp = + CollectionJobResp::::get_decoded(&message_buf)?; Box::new(message) } }