From 88c1962c4571ec5aa6cd59761e95b18fba6bf98f Mon Sep 17 00:00:00 2001
From: Tim Geoghegan
Date: Thu, 24 Aug 2023 15:17:40 -0700
Subject: [PATCH] DAP-05: Specialize to exactly two aggregators

This commit makes changes to various messages so that we handle exactly
one leader and exactly one helper, as specified in [DAP-05][1]. These
changes are voluminous but mostly mechanical.

This commit also adds variants for `Poplar1` to a couple of enums and
some macros, because later changes will need them.

[1]: https://datatracker.ietf.org/doc/draft-ietf-ppm-dap/05/

Part of #1669
---
 Cargo.lock                                    |   1 +
 aggregator/src/aggregator.rs                  | 157 +++---
 .../src/aggregator/aggregation_job_creator.rs | 406 ++++++-------
 .../src/aggregator/aggregation_job_driver.rs  | 531 ++++++++----------
 .../src/aggregator/collection_job_driver.rs   |  11 +-
 .../src/aggregator/collection_job_tests.rs    |  49 +-
 aggregator/src/aggregator/http_handlers.rs    | 292 +++++-----
 aggregator/src/aggregator/taskprov_tests.rs   |  52 +-
 aggregator/src/bin/janus_cli.rs               |  10 +-
 aggregator_api/src/models.rs                  |   8 +-
 aggregator_api/src/routes.rs                  |  12 +-
 aggregator_api/src/tests.rs                   |  14 +-
 aggregator_core/src/datastore.rs              |  45 +-
 aggregator_core/src/datastore/tests.rs        | 121 ++--
 aggregator_core/src/task.rs                   | 206 +++----
 aggregator_core/src/taskprov.rs               |   6 +-
 client/Cargo.toml                             |   3 +-
 client/src/lib.rs                             | 129 ++---
 collector/src/lib.rs                          | 393 +++++--------
 core/Cargo.toml                               |   5 +-
 core/src/task.rs                              | 193 ++++---
 db/00000000000001_initial_schema.up.sql       |  29 +-
 docs/samples/tasks.yaml                       |  12 +-
 integration_tests/src/client.rs               |  11 +-
 integration_tests/src/daphne.rs               |   8 +-
 integration_tests/src/janus.rs                |   6 +-
 integration_tests/src/lib.rs                  |  12 +-
 integration_tests/tests/common/mod.rs         |  22 +-
 integration_tests/tests/daphne.rs             |  15 +-
 .../src/bin/janus_interop_aggregator.rs       |   3 +-
 .../src/bin/janus_interop_client.rs           |   7 +-
 interop_binaries/src/lib.rs                   |   4 +-
 interop_binaries/tests/end_to_end.rs          |   4 +-
 messages/src/lib.rs                           | 413 +++++++++-----
 34 files changed, 1574 insertions(+), 1616 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index d99657dd5..05a8fb679 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1986,6 +1986,7 @@ dependencies = [
  "derivative",
  "http",
  "http-api-problem",
+ "itertools",
  "janus_core",
  "janus_messages",
  "mockito",
diff --git a/aggregator/src/aggregator.rs b/aggregator/src/aggregator.rs
index f7074c53c..c14a9182e 100644
--- a/aggregator/src/aggregator.rs
+++ b/aggregator/src/aggregator.rs
@@ -41,7 +41,7 @@ use janus_core::test_util::dummy_vdaf;
 use janus_core::{
     hpke::{self, HpkeApplicationInfo, HpkeKeypair, Label},
     http::response_to_problem_details,
-    task::{AuthenticationToken, VdafInstance, PRIO3_VERIFY_KEY_LENGTH},
+    task::{AuthenticationToken, VdafInstance, VERIFY_KEY_LENGTH},
     time::{Clock, DurationExt, IntervalExt, TimeExt},
 };
 use janus_messages::{
@@ -50,8 +50,8 @@ use janus_messages::{
     taskprov::TaskConfig,
     AggregateShare, AggregateShareAad, AggregateShareReq, AggregationJobContinueReq,
     AggregationJobId, AggregationJobInitializeReq, AggregationJobResp, AggregationJobRound,
-    BatchSelector, Collection, CollectionJobId, CollectionReq, Duration, HpkeCiphertext,
-    HpkeConfig, HpkeConfigList, InputShareAad, Interval, PartialBatchSelector, PlaintextInputShare,
+    BatchSelector, Collection, CollectionJobId, CollectionReq, Duration, HpkeConfig,
+    HpkeConfigList, InputShareAad, Interval, PartialBatchSelector, PlaintextInputShare,
     PrepareStep, PrepareStepResult, Report, ReportIdChecksum, ReportShare, ReportShareError, Role,
     TaskId,
 };
@@ -67,6 +67,8 @@ use prio::{
     codec::{Decode, Encode, ParameterizedDecode},
     vdaf::{
         self,
+        poplar1::Poplar1,
+        prg::PrgSha3,
         prio3::{Prio3, Prio3Count, Prio3Histogram, Prio3Sum, Prio3SumVecMultithreaded},
     },
 };
@@ -543,7 +545,7 @@ impl<C: Clock> Aggregator<C> {
         // have to use the peer aggregator's collector config rather than the main task.
         let collector_hpke_config =
             if self.cfg.taskprov_config.enabled && taskprov_task_config.is_some() {
-                let (peer_aggregator, _) = self
+                let (peer_aggregator, _, _) = self
                     .taskprov_authorize_request(
                         &Role::Leader,
                         task_id,
@@ -634,7 +636,7 @@ impl<C: Clock> Aggregator<C> {
         task_config: &TaskConfig,
         aggregator_auth_token: Option<&AuthenticationToken>,
     ) -> Result<(), Error> {
-        let (peer_aggregator, aggregator_urls) = self
+        let (peer_aggregator, leader_url, helper_url) = self
             .taskprov_authorize_request(peer_role, task_id, task_config, aggregator_auth_token)
             .await?;
 
@@ -665,7 +667,8 @@ impl<C: Clock> Aggregator<C> {
 
         let task = taskprov::Task::new(
             *task_id,
-            aggregator_urls,
+            leader_url,
+            helper_url,
             task_config.query_config().query().try_into()?,
             vdaf_instance,
             our_role,
@@ -715,13 +718,13 @@ impl<C: Clock> Aggregator<C> {
         task_id: &TaskId,
         task_config: &TaskConfig,
         aggregator_auth_token: Option<&AuthenticationToken>,
-    ) -> Result<(&PeerAggregator, Vec<Url>), Error> {
+    ) -> Result<(&PeerAggregator, Url, Url), Error> {
         let aggregator_urls = task_config
             .aggregator_endpoints()
             .iter()
             .map(|url| url.try_into())
             .collect::<Result<Vec<_>, _>>()?;
-        if aggregator_urls.len() < 2 {
+        if aggregator_urls.len() != 2 {
             return Err(Error::UnrecognizedMessage(
                 Some(*task_id),
                 "taskprov configuration is missing one or both aggregators",
@@ -754,7 +757,11 @@ impl<C: Clock> Aggregator<C> {
             ?peer_aggregator,
             "taskprov: authorized request"
         );
-        Ok((peer_aggregator, aggregator_urls))
+        Ok((
+            peer_aggregator,
+            aggregator_urls[Role::Leader.index().unwrap()].clone(),
+            aggregator_urls[Role::Helper.index().unwrap()].clone(),
+        ))
     }
 
     #[cfg(feature = "test-util")]
@@ -834,6 +841,12 @@ impl<C: Clock> TaskAggregator<C> {
                 VdafOps::Prio3FixedPoint64BitBoundedL2VecSum(Arc::new(vdaf), verify_key)
             }
 
+            VdafInstance::Poplar1 { bits } => {
+                let vdaf = Poplar1::new_sha3(*bits);
+                let verify_key = task.primary_vdaf_verify_key()?;
+                VdafOps::Poplar1(Arc::new(vdaf), verify_key)
+            }
+
             #[cfg(feature = "test-util")]
             VdafInstance::Fake => VdafOps::Fake(Arc::new(dummy_vdaf::Vdaf::new())),
 
@@ -1004,32 +1017,27 @@ impl<C: Clock> TaskAggregator<C> {
 /// VdafOps stores VDAF-specific operations for a TaskAggregator in a non-generic way.
 #[allow(clippy::enum_variant_names)]
 enum VdafOps {
-    Prio3Count(Arc<Prio3Count>, VerifyKey<PRIO3_VERIFY_KEY_LENGTH>),
-    Prio3CountVec(
-        Arc<Prio3SumVecMultithreaded>,
-        VerifyKey<PRIO3_VERIFY_KEY_LENGTH>,
-    ),
-    Prio3Sum(Arc<Prio3Sum>, VerifyKey<PRIO3_VERIFY_KEY_LENGTH>),
-    Prio3SumVec(
-        Arc<Prio3SumVecMultithreaded>,
-        VerifyKey<PRIO3_VERIFY_KEY_LENGTH>,
-    ),
-    Prio3Histogram(Arc<Prio3Histogram>, VerifyKey<PRIO3_VERIFY_KEY_LENGTH>),
+    Prio3Count(Arc<Prio3Count>, VerifyKey<VERIFY_KEY_LENGTH>),
+    Prio3CountVec(Arc<Prio3SumVecMultithreaded>, VerifyKey<VERIFY_KEY_LENGTH>),
+    Prio3Sum(Arc<Prio3Sum>, VerifyKey<VERIFY_KEY_LENGTH>),
+    Prio3SumVec(Arc<Prio3SumVecMultithreaded>, VerifyKey<VERIFY_KEY_LENGTH>),
+    Prio3Histogram(Arc<Prio3Histogram>, VerifyKey<VERIFY_KEY_LENGTH>),
     #[cfg(feature = "fpvec_bounded_l2")]
     Prio3FixedPoint16BitBoundedL2VecSum(
         Arc<Prio3FixedPointBoundedL2VecSumMultithreaded<FixedI16<U15>>>,
-        VerifyKey<PRIO3_VERIFY_KEY_LENGTH>,
+        VerifyKey<VERIFY_KEY_LENGTH>,
     ),
     #[cfg(feature = "fpvec_bounded_l2")]
     Prio3FixedPoint32BitBoundedL2VecSum(
         Arc<Prio3FixedPointBoundedL2VecSumMultithreaded<FixedI32<U31>>>,
-        VerifyKey<PRIO3_VERIFY_KEY_LENGTH>,
+        VerifyKey<VERIFY_KEY_LENGTH>,
     ),
     #[cfg(feature = "fpvec_bounded_l2")]
     Prio3FixedPoint64BitBoundedL2VecSum(
         Arc<Prio3FixedPointBoundedL2VecSumMultithreaded<FixedI64<U63>>>,
-        VerifyKey<PRIO3_VERIFY_KEY_LENGTH>,
+        VerifyKey<VERIFY_KEY_LENGTH>,
     ),
+    Poplar1(Arc<Poplar1<PrgSha3, 16>>, VerifyKey<VERIFY_KEY_LENGTH>),
 
     #[cfg(feature = "test-util")]
     Fake(Arc<dummy_vdaf::Vdaf>),
@@ -1047,7 +1055,7 @@ macro_rules! vdaf_ops_dispatch {
                 let $vdaf = vdaf;
                 let $verify_key = verify_key;
                 type $Vdaf = ::prio::vdaf::prio3::Prio3Count;
-                const $VERIFY_KEY_LENGTH: usize = ::janus_core::task::PRIO3_VERIFY_KEY_LENGTH;
+                const $VERIFY_KEY_LENGTH: usize = ::janus_core::task::VERIFY_KEY_LENGTH;
                 $body
             }
 
@@ -1055,7 +1063,7 @@ vdaf_ops_dispatch {
                 let $vdaf = vdaf;
                 let $verify_key = verify_key;
                 type $Vdaf = ::prio::vdaf::prio3::Prio3SumVecMultithreaded;
-                const $VERIFY_KEY_LENGTH: usize = ::janus_core::task::PRIO3_VERIFY_KEY_LENGTH;
+                const $VERIFY_KEY_LENGTH: usize = ::janus_core::task::VERIFY_KEY_LENGTH;
                 $body
             }
 
@@ -1063,7 +1071,7 @@ vdaf_ops_dispatch {
                 let $vdaf = vdaf;
                 let $verify_key = verify_key;
                 type $Vdaf = ::prio::vdaf::prio3::Prio3Sum;
-                const $VERIFY_KEY_LENGTH: usize = ::janus_core::task::PRIO3_VERIFY_KEY_LENGTH;
+                const $VERIFY_KEY_LENGTH: usize = ::janus_core::task::VERIFY_KEY_LENGTH;
                 $body
             }
 
@@ -1071,7 +1079,7 @@ vdaf_ops_dispatch {
                 let $vdaf = vdaf;
                 let $verify_key = verify_key;
                 type $Vdaf = ::prio::vdaf::prio3::Prio3SumVecMultithreaded;
-                const $VERIFY_KEY_LENGTH: usize = ::janus_core::task::PRIO3_VERIFY_KEY_LENGTH;
+                const $VERIFY_KEY_LENGTH: usize = ::janus_core::task::VERIFY_KEY_LENGTH;
                 $body
             }
 
@@ -1079,7 +1087,7 @@ vdaf_ops_dispatch {
                 let $vdaf = vdaf;
                 let $verify_key = verify_key;
                 type $Vdaf = ::prio::vdaf::prio3::Prio3Histogram;
-                const $VERIFY_KEY_LENGTH: usize = ::janus_core::task::PRIO3_VERIFY_KEY_LENGTH;
+                const $VERIFY_KEY_LENGTH: usize = ::janus_core::task::VERIFY_KEY_LENGTH;
                 $body
             }
 
@@ -1089,7 +1097,7 @@ vdaf_ops_dispatch {
                 let $vdaf = vdaf;
                 let $verify_key = verify_key;
                 type $Vdaf = ::prio::vdaf::prio3::Prio3FixedPointBoundedL2VecSumMultithreaded<FixedI16<U15>>;
-                const $VERIFY_KEY_LENGTH: usize = ::janus_core::task::PRIO3_VERIFY_KEY_LENGTH;
+                const $VERIFY_KEY_LENGTH: usize = ::janus_core::task::VERIFY_KEY_LENGTH;
                 $body
             }
 
@@ -1099,7 +1107,7 @@ vdaf_ops_dispatch {
                 let $vdaf = vdaf;
                 let $verify_key = verify_key;
                 type $Vdaf = ::prio::vdaf::prio3::Prio3FixedPointBoundedL2VecSumMultithreaded<FixedI32<U31>>;
-                const $VERIFY_KEY_LENGTH: usize = ::janus_core::task::PRIO3_VERIFY_KEY_LENGTH;
+                const $VERIFY_KEY_LENGTH: usize = ::janus_core::task::VERIFY_KEY_LENGTH;
                 $body
             }
 
@@ -1109,7 +1117,15 @@ vdaf_ops_dispatch {
                 let $vdaf = vdaf;
                 let $verify_key = verify_key;
                 type $Vdaf = ::prio::vdaf::prio3::Prio3FixedPointBoundedL2VecSumMultithreaded<FixedI64<U63>>;
-                const $VERIFY_KEY_LENGTH: usize = ::janus_core::task::PRIO3_VERIFY_KEY_LENGTH;
+                const $VERIFY_KEY_LENGTH: usize = ::janus_core::task::VERIFY_KEY_LENGTH;
+                $body
+            }
+
+            crate::aggregator::VdafOps::Poplar1(vdaf, verify_key) => {
+                let $vdaf = vdaf;
+                let $verify_key = verify_key;
+                type $Vdaf = ::prio::vdaf::poplar1::Poplar1<::prio::vdaf::prg::PrgSha3, 16>;
+                const $VERIFY_KEY_LENGTH: usize = ::janus_core::task::VERIFY_KEY_LENGTH;
                 $body
             }
 
@@ -1301,17 +1317,6 @@ impl VdafOps {
         C: Clock,
         Q: UploadableQueryType,
     {
-        // The leader's report is the first one.
- // https://www.ietf.org/archive/id/draft-ietf-ppm-dap-02.html#section-4.3.2 - if report.encrypted_input_shares().len() != 2 { - return Err(Arc::new(Error::UnrecognizedMessage( - Some(*task.id()), - "unexpected number of encrypted shares in report", - ))); - } - let leader_encrypted_input_share = - &report.encrypted_input_shares()[Role::Leader.index().unwrap()]; - let report_deadline = clock .now() .add(task.tolerable_clock_skew()) @@ -1380,7 +1385,7 @@ impl VdafOps { hpke_keypair.config(), hpke_keypair.private_key(), &HpkeApplicationInfo::new(&Label::InputShare, &Role::Client, task.role()), - leader_encrypted_input_share, + report.leader_encrypted_input_share(), &InputShareAad::new( *task.id(), report.metadata().clone(), @@ -1391,11 +1396,11 @@ impl VdafOps { }; let global_hpke_keypair = - global_hpke_keypairs.keypair(leader_encrypted_input_share.config_id()); + global_hpke_keypairs.keypair(report.leader_encrypted_input_share().config_id()); let task_hpke_keypair = task .hpke_keys() - .get(leader_encrypted_input_share.config_id()); + .get(report.leader_encrypted_input_share().config_id()); let decryption_result = match (task_hpke_keypair, global_hpke_keypair) { // Verify that the report's HPKE config ID is known. @@ -1403,7 +1408,7 @@ impl VdafOps { (None, None) => { return Err(Arc::new(Error::OutdatedHpkeConfig( *task.id(), - *leader_encrypted_input_share.config_id(), + *report.leader_encrypted_input_share().config_id(), ))); } (None, Some(global_hpke_keypair)) => try_hpke_open(&global_hpke_keypair), @@ -1453,16 +1458,13 @@ impl VdafOps { } }; - let helper_encrypted_input_share = - &report.encrypted_input_shares()[Role::Helper.index().unwrap()]; - let report = LeaderStoredReport::new( *task.id(), report.metadata().clone(), public_share, Vec::from(leader_plaintext_input_share.extensions()), leader_input_share, - helper_encrypted_input_share.clone(), + report.helper_encrypted_input_share().clone(), ); report_writer @@ -2577,10 +2579,8 @@ impl VdafOps { ), *report_count, spanned_interval, - Vec::::from([ - encrypted_leader_aggregate_share, - encrypted_helper_aggregate_share.clone(), - ]), + encrypted_leader_aggregate_share, + encrypted_helper_aggregate_share.clone(), ) .get_encoded(), )) @@ -3066,7 +3066,7 @@ mod tests { self, test_util::generate_test_hpke_config_and_private_key_with_id, HpkeApplicationInfo, HpkeKeypair, Label, }, - task::{VdafInstance, PRIO3_VERIFY_KEY_LENGTH}, + task::{VdafInstance, VERIFY_KEY_LENGTH}, test_util::install_test_trace_subscriber, time::{Clock, MockClock, TimeExt}, }; @@ -3132,7 +3132,8 @@ mod tests { Report::new( report_metadata, public_share.get_encoded(), - Vec::from([leader_ciphertext, helper_ciphertext]), + leader_ciphertext, + helper_ciphertext, ) } @@ -3275,29 +3276,6 @@ mod tests { assert_eq!(want_report_ids, got_report_ids); } - #[tokio::test] - async fn upload_wrong_number_of_encrypted_shares() { - install_test_trace_subscriber(); - - let (_, aggregator, clock, task, _, _ephemeral_datastore) = - setup_upload_test(default_aggregator_config()).await; - let report = create_report(&task, clock.now()); - let report = Report::new( - report.metadata().clone(), - report.public_share().to_vec(), - Vec::from([report.encrypted_input_shares()[0].clone()]), - ); - - assert_matches!( - aggregator - .handle_upload(task.id(), &report.get_encoded()) - .await - .unwrap_err() - .as_ref(), - Error::UnrecognizedMessage(_, _) - ); - } - #[tokio::test] async fn upload_wrong_hpke_config_id() { install_test_trace_subscriber(); @@ -3314,16 +3292,15 @@ mod tests { let report 
= Report::new( report.metadata().clone(), report.public_share().to_vec(), - Vec::from([ - HpkeCiphertext::new( - unused_hpke_config_id, - report.encrypted_input_shares()[0] - .encapsulated_key() - .to_vec(), - report.encrypted_input_shares()[0].payload().to_vec(), - ), - report.encrypted_input_shares()[1].clone(), - ]), + HpkeCiphertext::new( + unused_hpke_config_id, + report + .leader_encrypted_input_share() + .encapsulated_key() + .to_vec(), + report.leader_encrypted_input_share().payload().to_vec(), + ), + report.helper_encrypted_input_share().clone(), ); assert_matches!(aggregator.handle_upload(task.id(), &report.get_encoded()).await.unwrap_err().as_ref(), Error::OutdatedHpkeConfig(task_id, config_id) => { @@ -3409,7 +3386,7 @@ mod tests { let task = task.clone(); Box::pin(async move { tx.put_collection_job(&CollectionJob::< - PRIO3_VERIFY_KEY_LENGTH, + VERIFY_KEY_LENGTH, TimeInterval, Prio3Count, >::new( diff --git a/aggregator/src/aggregator/aggregation_job_creator.rs b/aggregator/src/aggregator/aggregation_job_creator.rs index 7c5bda43e..b370c80f6 100644 --- a/aggregator/src/aggregator/aggregation_job_creator.rs +++ b/aggregator/src/aggregator/aggregation_job_creator.rs @@ -12,7 +12,7 @@ use janus_aggregator_core::{ task::{self, Task}, }; use janus_core::{ - task::{VdafInstance, PRIO3_VERIFY_KEY_LENGTH}, + task::{VdafInstance, VERIFY_KEY_LENGTH}, time::{Clock, DurationExt as _, TimeExt as _}, }; use janus_messages::{ @@ -264,33 +264,33 @@ impl AggregationJobCreator { match (task.query_type(), task.vdaf()) { (task::QueryType::TimeInterval, VdafInstance::Prio3Count) => { let vdaf = Arc::new(Prio3::new_count(2)?); - self.create_aggregation_jobs_for_time_interval_task_no_param::(task, vdaf) + self.create_aggregation_jobs_for_time_interval_task_no_param::(task, vdaf) .await } (task::QueryType::TimeInterval, VdafInstance::Prio3CountVec { length }) => { let vdaf = Arc::new(Prio3::new_sum_vec_multithreaded(2, 1, *length)?); self.create_aggregation_jobs_for_time_interval_task_no_param::< - PRIO3_VERIFY_KEY_LENGTH, + VERIFY_KEY_LENGTH, Prio3SumVecMultithreaded >(task, vdaf).await } (task::QueryType::TimeInterval, VdafInstance::Prio3Sum { bits }) => { let vdaf = Arc::new(Prio3::new_sum(2, *bits)?); - self.create_aggregation_jobs_for_time_interval_task_no_param::(task, vdaf) + self.create_aggregation_jobs_for_time_interval_task_no_param::(task, vdaf) .await } (task::QueryType::TimeInterval, VdafInstance::Prio3SumVec { bits, length }) => { let vdaf = Arc::new(Prio3::new_sum_vec_multithreaded(2, *bits, *length)?); - self.create_aggregation_jobs_for_time_interval_task_no_param::(task, vdaf) + self.create_aggregation_jobs_for_time_interval_task_no_param::(task, vdaf) .await } (task::QueryType::TimeInterval, VdafInstance::Prio3Histogram { length }) => { let vdaf = Arc::new(Prio3::new_histogram(2, *length)?); - self.create_aggregation_jobs_for_time_interval_task_no_param::(task, vdaf) + self.create_aggregation_jobs_for_time_interval_task_no_param::(task, vdaf) .await } @@ -303,7 +303,7 @@ impl AggregationJobCreator { Arc::new(Prio3::new_fixedpoint_boundedl2_vec_sum_multithreaded( 2, *length, )?); - self.create_aggregation_jobs_for_time_interval_task_no_param::>>(task, vdaf) + self.create_aggregation_jobs_for_time_interval_task_no_param::>>(task, vdaf) .await } @@ -316,7 +316,7 @@ impl AggregationJobCreator { Arc::new(Prio3::new_fixedpoint_boundedl2_vec_sum_multithreaded( 2, *length, )?); - self.create_aggregation_jobs_for_time_interval_task_no_param::>>(task, vdaf) + 
self.create_aggregation_jobs_for_time_interval_task_no_param::>>(task, vdaf) .await } @@ -329,7 +329,7 @@ impl AggregationJobCreator { Arc::new(Prio3::new_fixedpoint_boundedl2_vec_sum_multithreaded( 2, *length, )?); - self.create_aggregation_jobs_for_time_interval_task_no_param::>>(task, vdaf) + self.create_aggregation_jobs_for_time_interval_task_no_param::>>(task, vdaf) .await } @@ -344,7 +344,7 @@ impl AggregationJobCreator { let max_batch_size = *max_batch_size; let batch_time_window_size = *batch_time_window_size; self.create_aggregation_jobs_for_fixed_size_task_no_param::< - PRIO3_VERIFY_KEY_LENGTH, + VERIFY_KEY_LENGTH, Prio3Count, >(task, vdaf, max_batch_size, batch_time_window_size).await } @@ -360,7 +360,7 @@ impl AggregationJobCreator { let max_batch_size = *max_batch_size; let batch_time_window_size = *batch_time_window_size; self.create_aggregation_jobs_for_fixed_size_task_no_param::< - PRIO3_VERIFY_KEY_LENGTH, + VERIFY_KEY_LENGTH, Prio3SumVecMultithreaded >(task, vdaf, max_batch_size, batch_time_window_size).await } @@ -376,7 +376,7 @@ impl AggregationJobCreator { let max_batch_size = *max_batch_size; let batch_time_window_size = *batch_time_window_size; self.create_aggregation_jobs_for_fixed_size_task_no_param::< - PRIO3_VERIFY_KEY_LENGTH, + VERIFY_KEY_LENGTH, Prio3Sum, >(task, vdaf, max_batch_size, batch_time_window_size).await } @@ -392,7 +392,7 @@ impl AggregationJobCreator { let max_batch_size = *max_batch_size; let batch_time_window_size = *batch_time_window_size; self.create_aggregation_jobs_for_fixed_size_task_no_param::< - PRIO3_VERIFY_KEY_LENGTH, + VERIFY_KEY_LENGTH, Prio3SumVecMultithreaded, >(task, vdaf, max_batch_size, batch_time_window_size).await } @@ -408,7 +408,7 @@ impl AggregationJobCreator { let max_batch_size = *max_batch_size; let batch_time_window_size = *batch_time_window_size; self.create_aggregation_jobs_for_fixed_size_task_no_param::< - PRIO3_VERIFY_KEY_LENGTH, + VERIFY_KEY_LENGTH, Prio3Histogram, >(task, vdaf, max_batch_size, batch_time_window_size).await } @@ -428,7 +428,7 @@ impl AggregationJobCreator { let max_batch_size = *max_batch_size; let batch_time_window_size = *batch_time_window_size; self.create_aggregation_jobs_for_fixed_size_task_no_param::< - PRIO3_VERIFY_KEY_LENGTH, + VERIFY_KEY_LENGTH, Prio3FixedPointBoundedL2VecSumMultithreaded>, >(task, vdaf, max_batch_size, batch_time_window_size).await } @@ -448,7 +448,7 @@ impl AggregationJobCreator { let max_batch_size = *max_batch_size; let batch_time_window_size = *batch_time_window_size; self.create_aggregation_jobs_for_fixed_size_task_no_param::< - PRIO3_VERIFY_KEY_LENGTH, + VERIFY_KEY_LENGTH, Prio3FixedPointBoundedL2VecSumMultithreaded>, >(task, vdaf, max_batch_size, batch_time_window_size).await } @@ -468,7 +468,7 @@ impl AggregationJobCreator { let max_batch_size = *max_batch_size; let batch_time_window_size = *batch_time_window_size; self.create_aggregation_jobs_for_fixed_size_task_no_param::< - PRIO3_VERIFY_KEY_LENGTH, + VERIFY_KEY_LENGTH, Prio3FixedPointBoundedL2VecSumMultithreaded>, >(task, vdaf, max_batch_size, batch_time_window_size).await } @@ -657,7 +657,7 @@ mod tests { test_util::noop_meter, }; use janus_core::{ - task::{VdafInstance, PRIO3_VERIFY_KEY_LENGTH}, + task::{VdafInstance, VERIFY_KEY_LENGTH}, test_util::{dummy_vdaf, install_test_trace_subscriber}, time::{Clock, DurationExt, IntervalExt, MockClock, TimeExt}, }; @@ -749,7 +749,7 @@ mod tests { Box::pin(async move { let (leader_aggregations, leader_batches) = read_aggregate_info_for_task::< - PRIO3_VERIFY_KEY_LENGTH, + 
VERIFY_KEY_LENGTH, TimeInterval, Prio3Count, _, @@ -757,7 +757,7 @@ mod tests { .await; let (helper_aggregations, helper_batches) = read_aggregate_info_for_task::< - PRIO3_VERIFY_KEY_LENGTH, + VERIFY_KEY_LENGTH, TimeInterval, Prio3Count, _, @@ -860,22 +860,23 @@ mod tests { .unwrap(); // Verify. - let (agg_jobs, batches) = job_creator - .datastore - .run_tx(|tx| { - let task = task.clone(); - Box::pin(async move { - Ok(read_aggregate_info_for_task::< - PRIO3_VERIFY_KEY_LENGTH, - TimeInterval, - Prio3Count, - _, - >(tx, task.id()) - .await) + let (agg_jobs, batches) = + job_creator + .datastore + .run_tx(|tx| { + let task = task.clone(); + Box::pin(async move { + Ok(read_aggregate_info_for_task::< + VERIFY_KEY_LENGTH, + TimeInterval, + Prio3Count, + _, + >(tx, task.id()) + .await) + }) }) - }) - .await - .unwrap(); + .await + .unwrap(); let mut seen_report_ids = HashSet::new(); for (agg_job, report_ids) in &agg_jobs { // Jobs are created in round 0 @@ -957,22 +958,23 @@ mod tests { .unwrap(); // Verify -- we haven't received enough reports yet, so we don't create anything. - let (agg_jobs, batches) = job_creator - .datastore - .run_tx(|tx| { - let task = Arc::clone(&task); - Box::pin(async move { - Ok(read_aggregate_info_for_task::< - PRIO3_VERIFY_KEY_LENGTH, - TimeInterval, - Prio3Count, - _, - >(tx, task.id()) - .await) + let (agg_jobs, batches) = + job_creator + .datastore + .run_tx(|tx| { + let task = Arc::clone(&task); + Box::pin(async move { + Ok(read_aggregate_info_for_task::< + VERIFY_KEY_LENGTH, + TimeInterval, + Prio3Count, + _, + >(tx, task.id()) + .await) + }) }) - }) - .await - .unwrap(); + .await + .unwrap(); assert!(agg_jobs.is_empty()); assert!(batches.is_empty()); @@ -996,22 +998,23 @@ mod tests { .unwrap(); // Verify -- the additional report we wrote allows an aggregation job to be created. - let (agg_jobs, batches) = job_creator - .datastore - .run_tx(|tx| { - let task = Arc::clone(&task); - Box::pin(async move { - Ok(read_aggregate_info_for_task::< - PRIO3_VERIFY_KEY_LENGTH, - TimeInterval, - Prio3Count, - _, - >(tx, task.id()) - .await) + let (agg_jobs, batches) = + job_creator + .datastore + .run_tx(|tx| { + let task = Arc::clone(&task); + Box::pin(async move { + Ok(read_aggregate_info_for_task::< + VERIFY_KEY_LENGTH, + TimeInterval, + Prio3Count, + _, + >(tx, task.id()) + .await) + }) }) - }) - .await - .unwrap(); + .await + .unwrap(); assert_eq!(agg_jobs.len(), 1); let report_ids: HashSet<_> = agg_jobs.into_iter().next().unwrap().1.into_iter().collect(); assert_eq!( @@ -1074,16 +1077,14 @@ mod tests { tx.put_client_report(&dummy_vdaf::Vdaf::new(), report) .await?; } - tx.put_batch( - &Batch::::new( - *task.id(), - batch_identifier, - (), - BatchState::Closed, - 0, - Interval::from_time(&report_time).unwrap(), - ), - ) + tx.put_batch(&Batch::::new( + *task.id(), + batch_identifier, + (), + BatchState::Closed, + 0, + Interval::from_time(&report_time).unwrap(), + )) .await?; Ok(()) }) @@ -1106,22 +1107,23 @@ mod tests { .unwrap(); // Verify. 
- let (agg_jobs, batches) = job_creator - .datastore - .run_tx(|tx| { - let task = task.clone(); - Box::pin(async move { - Ok(read_aggregate_info_for_task::< - PRIO3_VERIFY_KEY_LENGTH, - TimeInterval, - Prio3Count, - _, - >(tx, task.id()) - .await) + let (agg_jobs, batches) = + job_creator + .datastore + .run_tx(|tx| { + let task = task.clone(); + Box::pin(async move { + Ok(read_aggregate_info_for_task::< + VERIFY_KEY_LENGTH, + TimeInterval, + Prio3Count, + _, + >(tx, task.id()) + .await) + }) }) - }) - .await - .unwrap(); + .await + .unwrap(); let mut seen_report_ids = HashSet::new(); for (agg_job, report_ids) in &agg_jobs { // Job immediately finished since all reports are in a closed batch. @@ -1228,25 +1230,26 @@ mod tests { .unwrap(); // Verify. - let (outstanding_batches, (agg_jobs, batches)) = job_creator - .datastore - .run_tx(|tx| { - let task = Arc::clone(&task); - Box::pin(async move { - Ok(( - tx.get_outstanding_batches(task.id(), &None).await?, - read_aggregate_info_for_task::< - PRIO3_VERIFY_KEY_LENGTH, - FixedSize, - Prio3Count, - _, - >(tx, task.id()) - .await, - )) + let (outstanding_batches, (agg_jobs, batches)) = + job_creator + .datastore + .run_tx(|tx| { + let task = Arc::clone(&task); + Box::pin(async move { + Ok(( + tx.get_outstanding_batches(task.id(), &None).await?, + read_aggregate_info_for_task::< + VERIFY_KEY_LENGTH, + FixedSize, + Prio3Count, + _, + >(tx, task.id()) + .await, + )) + }) }) - }) - .await - .unwrap(); + .await + .unwrap(); // Verify outstanding batches. let mut total_max_size = 0; @@ -1388,25 +1391,26 @@ mod tests { .unwrap(); // Verify. - let (outstanding_batches, (agg_jobs, batches)) = job_creator - .datastore - .run_tx(|tx| { - let task = Arc::clone(&task); - Box::pin(async move { - Ok(( - tx.get_outstanding_batches(task.id(), &None).await?, - read_aggregate_info_for_task::< - PRIO3_VERIFY_KEY_LENGTH, - FixedSize, - Prio3Count, - _, - >(tx, task.id()) - .await, - )) + let (outstanding_batches, (agg_jobs, batches)) = + job_creator + .datastore + .run_tx(|tx| { + let task = Arc::clone(&task); + Box::pin(async move { + Ok(( + tx.get_outstanding_batches(task.id(), &None).await?, + read_aggregate_info_for_task::< + VERIFY_KEY_LENGTH, + FixedSize, + Prio3Count, + _, + >(tx, task.id()) + .await, + )) + }) }) - }) - .await - .unwrap(); + .await + .unwrap(); // Verify outstanding batches and aggregation jobs. assert_eq!(outstanding_batches.len(), 0); @@ -1503,25 +1507,26 @@ mod tests { .unwrap(); // Verify. - let (outstanding_batches, (agg_jobs, _batches)) = job_creator - .datastore - .run_tx(|tx| { - let task = Arc::clone(&task); - Box::pin(async move { - Ok(( - tx.get_outstanding_batches(task.id(), &None).await?, - read_aggregate_info_for_task::< - PRIO3_VERIFY_KEY_LENGTH, - FixedSize, - Prio3Count, - _, - >(tx, task.id()) - .await, - )) + let (outstanding_batches, (agg_jobs, _batches)) = + job_creator + .datastore + .run_tx(|tx| { + let task = Arc::clone(&task); + Box::pin(async move { + Ok(( + tx.get_outstanding_batches(task.id(), &None).await?, + read_aggregate_info_for_task::< + VERIFY_KEY_LENGTH, + FixedSize, + Prio3Count, + _, + >(tx, task.id()) + .await, + )) + }) }) - }) - .await - .unwrap(); + .await + .unwrap(); // Verify sizes of batches and aggregation jobs. let mut outstanding_batch_sizes = outstanding_batches @@ -1559,25 +1564,26 @@ mod tests { .unwrap(); // Verify. 
- let (outstanding_batches, (agg_jobs, _batches)) = job_creator - .datastore - .run_tx(|tx| { - let task = Arc::clone(&task); - Box::pin(async move { - Ok(( - tx.get_outstanding_batches(task.id(), &None).await?, - read_aggregate_info_for_task::< - PRIO3_VERIFY_KEY_LENGTH, - FixedSize, - Prio3Count, - _, - >(tx, task.id()) - .await, - )) + let (outstanding_batches, (agg_jobs, _batches)) = + job_creator + .datastore + .run_tx(|tx| { + let task = Arc::clone(&task); + Box::pin(async move { + Ok(( + tx.get_outstanding_batches(task.id(), &None).await?, + read_aggregate_info_for_task::< + VERIFY_KEY_LENGTH, + FixedSize, + Prio3Count, + _, + >(tx, task.id()) + .await, + )) + }) }) - }) - .await - .unwrap(); + .await + .unwrap(); let batch_ids: HashSet<_> = outstanding_batches .iter() .map(|outstanding_batch| *outstanding_batch.id()) @@ -1682,25 +1688,26 @@ mod tests { .unwrap(); // Verify. - let (outstanding_batches, (agg_jobs, _batches)) = job_creator - .datastore - .run_tx(|tx| { - let task = Arc::clone(&task); - Box::pin(async move { - Ok(( - tx.get_outstanding_batches(task.id(), &None).await?, - read_aggregate_info_for_task::< - PRIO3_VERIFY_KEY_LENGTH, - FixedSize, - Prio3Count, - _, - >(tx, task.id()) - .await, - )) + let (outstanding_batches, (agg_jobs, _batches)) = + job_creator + .datastore + .run_tx(|tx| { + let task = Arc::clone(&task); + Box::pin(async move { + Ok(( + tx.get_outstanding_batches(task.id(), &None).await?, + read_aggregate_info_for_task::< + VERIFY_KEY_LENGTH, + FixedSize, + Prio3Count, + _, + >(tx, task.id()) + .await, + )) + }) }) - }) - .await - .unwrap(); + .await + .unwrap(); // Verify sizes of batches and aggregation jobs. let mut outstanding_batch_sizes = outstanding_batches @@ -1745,25 +1752,26 @@ mod tests { .unwrap(); // Verify. 
- let (outstanding_batches, (agg_jobs, _batches)) = job_creator - .datastore - .run_tx(|tx| { - let task = Arc::clone(&task); - Box::pin(async move { - Ok(( - tx.get_outstanding_batches(task.id(), &None).await?, - read_aggregate_info_for_task::< - PRIO3_VERIFY_KEY_LENGTH, - FixedSize, - Prio3Count, - _, - >(tx, task.id()) - .await, - )) + let (outstanding_batches, (agg_jobs, _batches)) = + job_creator + .datastore + .run_tx(|tx| { + let task = Arc::clone(&task); + Box::pin(async move { + Ok(( + tx.get_outstanding_batches(task.id(), &None).await?, + read_aggregate_info_for_task::< + VERIFY_KEY_LENGTH, + FixedSize, + Prio3Count, + _, + >(tx, task.id()) + .await, + )) + }) }) - }) - .await - .unwrap(); + .await + .unwrap(); let batch_ids: HashSet<_> = outstanding_batches .iter() .map(|outstanding_batch| *outstanding_batch.id()) @@ -1892,7 +1900,7 @@ mod tests { tx.get_outstanding_batches(task.id(), &Some(time_bucket_start_2)) .await?, read_aggregate_info_for_task::< - PRIO3_VERIFY_KEY_LENGTH, + VERIFY_KEY_LENGTH, FixedSize, Prio3Count, _, diff --git a/aggregator/src/aggregator/aggregation_job_driver.rs b/aggregator/src/aggregator/aggregation_job_driver.rs index ce8cf7d3f..8db01020a 100644 --- a/aggregator/src/aggregator/aggregation_job_driver.rs +++ b/aggregator/src/aggregator/aggregation_job_driver.rs @@ -890,7 +890,7 @@ mod tests { self, test_util::generate_test_hpke_config_and_private_key, HpkeApplicationInfo, Label, }, report_id::ReportIdChecksumExt, - task::{VdafInstance, PRIO3_VERIFY_KEY_LENGTH}, + task::{VdafInstance, VERIFY_KEY_LENGTH}, test_util::{install_test_trace_subscriber, run_vdaf, runtime::TestRuntimeManager}, time::{Clock, IntervalExt, MockClock, TimeExt}, Runtime, @@ -936,10 +936,7 @@ mod tests { VdafInstance::Prio3Count, Role::Leader, ) - .with_aggregator_endpoints(Vec::from([ - Url::parse("http://irrelevant").unwrap(), // leader URL doesn't matter - Url::parse(&server.url()).unwrap(), - ])) + .with_helper_aggregator_endpoint(Url::parse(&server.url()).unwrap()) .build(); let time = clock @@ -948,8 +945,7 @@ mod tests { .unwrap(); let batch_identifier = TimeInterval::to_batch_identifier(&task, &(), &time).unwrap(); let report_metadata = ReportMetadata::new(random(), time); - let verify_key: VerifyKey = - task.primary_vdaf_verify_key().unwrap(); + let verify_key: VerifyKey = task.primary_vdaf_verify_key().unwrap(); let transcript = run_vdaf( vdaf.as_ref(), @@ -961,7 +957,7 @@ mod tests { let agg_auth_token = task.primary_aggregator_auth_token().clone(); let helper_hpke_keypair = generate_test_hpke_config_and_private_key(); - let report = generate_report::( + let report = generate_report::( *task.id(), report_metadata, helper_hpke_keypair.config(), @@ -982,7 +978,7 @@ mod tests { .await?; tx.put_aggregation_job(&AggregationJob::< - PRIO3_VERIFY_KEY_LENGTH, + VERIFY_KEY_LENGTH, TimeInterval, Prio3Count, >::new( @@ -996,34 +992,31 @@ mod tests { AggregationJobRound::from(0), )) .await?; - tx.put_report_aggregation(&ReportAggregation::< - PRIO3_VERIFY_KEY_LENGTH, - Prio3Count, - >::new( - *task.id(), - aggregation_job_id, - *report.metadata().id(), - *report.metadata().time(), - 0, - None, - ReportAggregationState::Start, - )) - .await?; - - tx.put_batch( - &Batch::::new( + tx.put_report_aggregation( + &ReportAggregation::::new( *task.id(), - batch_identifier, - (), - BatchState::Closing, - 1, - Interval::from_time(&time).unwrap(), + aggregation_job_id, + *report.metadata().id(), + *report.metadata().time(), + 0, + None, + ReportAggregationState::Start, ), ) .await?; + 
tx.put_batch(&Batch::::new( + *task.id(), + batch_identifier, + (), + BatchState::Closing, + 1, + Interval::from_time(&time).unwrap(), + )) + .await?; + let collection_job = - CollectionJob::::new( + CollectionJob::::new( *task.id(), random(), batch_identifier, @@ -1128,7 +1121,7 @@ mod tests { } let want_aggregation_job = - AggregationJob::::new( + AggregationJob::::new( *task.id(), aggregation_job_id, (), @@ -1138,7 +1131,7 @@ mod tests { AggregationJobState::Finished, AggregationJobRound::from(2), ); - let want_report_aggregation = ReportAggregation::::new( + let want_report_aggregation = ReportAggregation::::new( *task.id(), aggregation_job_id, *report.metadata().id(), @@ -1147,7 +1140,7 @@ mod tests { None, ReportAggregationState::Finished, ); - let want_batch = Batch::::new( + let want_batch = Batch::::new( *task.id(), batch_identifier, (), @@ -1166,7 +1159,7 @@ mod tests { Box::pin(async move { let aggregation_job = tx - .get_aggregation_job::( + .get_aggregation_job::( task.id(), &aggregation_job_id, ) @@ -1217,10 +1210,7 @@ mod tests { VdafInstance::Prio3Count, Role::Leader, ) - .with_aggregator_endpoints(Vec::from([ - Url::parse("http://irrelevant").unwrap(), // leader URL doesn't matter - Url::parse(&server.url()).unwrap(), - ])) + .with_helper_aggregator_endpoint(Url::parse(&server.url()).unwrap()) .build(); let time = clock @@ -1229,8 +1219,7 @@ mod tests { .unwrap(); let batch_identifier = TimeInterval::to_batch_identifier(&task, &(), &time).unwrap(); let report_metadata = ReportMetadata::new(random(), time); - let verify_key: VerifyKey = - task.primary_vdaf_verify_key().unwrap(); + let verify_key: VerifyKey = task.primary_vdaf_verify_key().unwrap(); let transcript = run_vdaf( vdaf.as_ref(), @@ -1242,7 +1231,7 @@ mod tests { let agg_auth_token = task.primary_aggregator_auth_token(); let helper_hpke_keypair = generate_test_hpke_config_and_private_key(); - let report = generate_report::( + let report = generate_report::( *task.id(), report_metadata, helper_hpke_keypair.config(), @@ -1250,7 +1239,7 @@ mod tests { Vec::new(), transcript.input_shares.clone(), ); - let repeated_extension_report = generate_report::( + let repeated_extension_report = generate_report::( *task.id(), ReportMetadata::new(random(), time), helper_hpke_keypair.config(), @@ -1279,7 +1268,7 @@ mod tests { .await?; tx.put_aggregation_job(&AggregationJob::< - PRIO3_VERIFY_KEY_LENGTH, + VERIFY_KEY_LENGTH, TimeInterval, Prio3Count, >::new( @@ -1293,58 +1282,53 @@ mod tests { AggregationJobRound::from(0), )) .await?; - tx.put_report_aggregation(&ReportAggregation::< - PRIO3_VERIFY_KEY_LENGTH, - Prio3Count, - >::new( - *task.id(), - aggregation_job_id, - *report.metadata().id(), - *report.metadata().time(), - 0, - None, - ReportAggregationState::Start, - )) - .await?; - tx.put_report_aggregation(&ReportAggregation::< - PRIO3_VERIFY_KEY_LENGTH, - Prio3Count, - >::new( - *task.id(), - aggregation_job_id, - *repeated_extension_report.metadata().id(), - *repeated_extension_report.metadata().time(), - 1, - None, - ReportAggregationState::Start, - )) - .await?; - tx.put_report_aggregation(&ReportAggregation::< - PRIO3_VERIFY_KEY_LENGTH, - Prio3Count, - >::new( - *task.id(), - aggregation_job_id, - missing_report_id, - time, - 2, - None, - ReportAggregationState::Start, - )) + tx.put_report_aggregation( + &ReportAggregation::::new( + *task.id(), + aggregation_job_id, + *report.metadata().id(), + *report.metadata().time(), + 0, + None, + ReportAggregationState::Start, + ), + ) .await?; - - tx.put_batch( - &Batch::::new( + 
tx.put_report_aggregation( + &ReportAggregation::::new( *task.id(), - batch_identifier, - (), - BatchState::Closing, + aggregation_job_id, + *repeated_extension_report.metadata().id(), + *repeated_extension_report.metadata().time(), 1, - Interval::from_time(&time).unwrap(), + None, + ReportAggregationState::Start, + ), + ) + .await?; + tx.put_report_aggregation( + &ReportAggregation::::new( + *task.id(), + aggregation_job_id, + missing_report_id, + time, + 2, + None, + ReportAggregationState::Start, ), ) .await?; + tx.put_batch(&Batch::::new( + *task.id(), + batch_identifier, + (), + BatchState::Closing, + 1, + Interval::from_time(&time).unwrap(), + )) + .await?; + Ok(tx .acquire_incomplete_aggregation_jobs(&StdDuration::from_secs(60), 1) .await? @@ -1435,7 +1419,7 @@ mod tests { mocked_aggregate_success.assert_async().await; let want_aggregation_job = - AggregationJob::::new( + AggregationJob::::new( *task.id(), aggregation_job_id, (), @@ -1447,7 +1431,7 @@ mod tests { ); let leader_prep_state = transcript.leader_prep_state(0).clone(); let prep_msg = transcript.prepare_messages[0].clone(); - let want_report_aggregation = ReportAggregation::::new( + let want_report_aggregation = ReportAggregation::::new( *task.id(), aggregation_job_id, *report.metadata().id(), @@ -1457,7 +1441,7 @@ mod tests { ReportAggregationState::Waiting(leader_prep_state, Some(prep_msg)), ); let want_repeated_extension_report_aggregation = - ReportAggregation::::new( + ReportAggregation::::new( *task.id(), aggregation_job_id, *repeated_extension_report.metadata().id(), @@ -1467,7 +1451,7 @@ mod tests { ReportAggregationState::Failed(ReportShareError::UnrecognizedMessage), ); let want_missing_report_report_aggregation = - ReportAggregation::::new( + ReportAggregation::::new( *task.id(), aggregation_job_id, missing_report_id, @@ -1476,7 +1460,7 @@ mod tests { None, ReportAggregationState::Failed(ReportShareError::ReportDropped), ); - let want_batch = Batch::::new( + let want_batch = Batch::::new( *task.id(), batch_identifier, (), @@ -1501,7 +1485,7 @@ mod tests { ); Box::pin(async move { let aggregation_job = tx - .get_aggregation_job::( + .get_aggregation_job::( task.id(), &aggregation_job_id, ) @@ -1584,10 +1568,7 @@ mod tests { VdafInstance::Prio3Count, Role::Leader, ) - .with_aggregator_endpoints(Vec::from([ - Url::parse("http://irrelevant").unwrap(), // leader URL doesn't matter - Url::parse(&server.url()).unwrap(), - ])) + .with_helper_aggregator_endpoint(Url::parse(&server.url()).unwrap()) .build(); let report_metadata = ReportMetadata::new( @@ -1597,8 +1578,7 @@ mod tests { .to_batch_interval_start(task.time_precision()) .unwrap(), ); - let verify_key: VerifyKey = - task.primary_vdaf_verify_key().unwrap(); + let verify_key: VerifyKey = task.primary_vdaf_verify_key().unwrap(); let transcript = run_vdaf( vdaf.as_ref(), @@ -1610,7 +1590,7 @@ mod tests { let agg_auth_token = task.primary_aggregator_auth_token(); let helper_hpke_keypair = generate_test_hpke_config_and_private_key(); - let report = generate_report::( + let report = generate_report::( *task.id(), report_metadata, helper_hpke_keypair.config(), @@ -1629,7 +1609,7 @@ mod tests { tx.put_client_report(vdaf.borrow(), &report).await?; tx.put_aggregation_job(&AggregationJob::< - PRIO3_VERIFY_KEY_LENGTH, + VERIFY_KEY_LENGTH, FixedSize, Prio3Count, >::new( @@ -1643,32 +1623,29 @@ mod tests { AggregationJobRound::from(0), )) .await?; - tx.put_report_aggregation(&ReportAggregation::< - PRIO3_VERIFY_KEY_LENGTH, - Prio3Count, - >::new( - *task.id(), - 
aggregation_job_id, - *report.metadata().id(), - *report.metadata().time(), - 0, - None, - ReportAggregationState::Start, - )) - .await?; - - tx.put_batch( - &Batch::::new( + tx.put_report_aggregation( + &ReportAggregation::::new( *task.id(), - batch_id, - (), - BatchState::Open, - 1, - Interval::from_time(report.metadata().time()).unwrap(), + aggregation_job_id, + *report.metadata().id(), + *report.metadata().time(), + 0, + None, + ReportAggregationState::Start, ), ) .await?; + tx.put_batch(&Batch::::new( + *task.id(), + batch_id, + (), + BatchState::Open, + 1, + Interval::from_time(report.metadata().time()).unwrap(), + )) + .await?; + Ok(tx .acquire_incomplete_aggregation_jobs(&StdDuration::from_secs(60), 1) .await? @@ -1758,18 +1735,16 @@ mod tests { mocked_aggregate_failure.assert_async().await; mocked_aggregate_success.assert_async().await; - let want_aggregation_job = - AggregationJob::::new( - *task.id(), - aggregation_job_id, - (), - batch_id, - Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) - .unwrap(), - AggregationJobState::InProgress, - AggregationJobRound::from(1), - ); - let want_report_aggregation = ReportAggregation::::new( + let want_aggregation_job = AggregationJob::::new( + *task.id(), + aggregation_job_id, + (), + batch_id, + Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)).unwrap(), + AggregationJobState::InProgress, + AggregationJobRound::from(1), + ); + let want_report_aggregation = ReportAggregation::::new( *task.id(), aggregation_job_id, *report.metadata().id(), @@ -1781,7 +1756,7 @@ mod tests { Some(transcript.prepare_messages[0].clone()), ), ); - let want_batch = Batch::::new( + let want_batch = Batch::::new( *task.id(), batch_id, (), @@ -1796,7 +1771,7 @@ mod tests { (Arc::clone(&vdaf), task.clone(), *report.metadata().id()); Box::pin(async move { let aggregation_job = tx - .get_aggregation_job::( + .get_aggregation_job::( task.id(), &aggregation_job_id, ) @@ -1840,10 +1815,7 @@ mod tests { VdafInstance::Prio3Count, Role::Leader, ) - .with_aggregator_endpoints(Vec::from([ - Url::parse("http://irrelevant").unwrap(), // leader URL doesn't matter - Url::parse(&server.url()).unwrap(), - ])) + .with_helper_aggregator_endpoint(Url::parse(&server.url()).unwrap()) .build(); let time = clock .now() @@ -1864,8 +1836,7 @@ mod tests { ) .unwrap(); let report_metadata = ReportMetadata::new(random(), time); - let verify_key: VerifyKey = - task.primary_vdaf_verify_key().unwrap(); + let verify_key: VerifyKey = task.primary_vdaf_verify_key().unwrap(); let transcript = run_vdaf( vdaf.as_ref(), @@ -1877,7 +1848,7 @@ mod tests { let agg_auth_token = task.primary_aggregator_auth_token(); let helper_hpke_keypair = generate_test_hpke_config_and_private_key(); - let report = generate_report::( + let report = generate_report::( *task.id(), report_metadata, helper_hpke_keypair.config(), @@ -1909,7 +1880,7 @@ mod tests { .await?; tx.put_aggregation_job(&AggregationJob::< - PRIO3_VERIFY_KEY_LENGTH, + VERIFY_KEY_LENGTH, TimeInterval, Prio3Count, >::new( @@ -1923,45 +1894,40 @@ mod tests { AggregationJobRound::from(1), )) .await?; - tx.put_report_aggregation(&ReportAggregation::< - PRIO3_VERIFY_KEY_LENGTH, - Prio3Count, - >::new( - *task.id(), - aggregation_job_id, - *report.metadata().id(), - *report.metadata().time(), - 0, - None, - ReportAggregationState::Waiting(leader_prep_state, Some(prep_msg)), - )) - .await?; - - tx.put_batch( - &Batch::::new( + tx.put_report_aggregation( + &ReportAggregation::::new( *task.id(), - 
active_batch_identifier, - (), - BatchState::Closing, - 1, - Interval::from_time(report.metadata().time()).unwrap(), + aggregation_job_id, + *report.metadata().id(), + *report.metadata().time(), + 0, + None, + ReportAggregationState::Waiting(leader_prep_state, Some(prep_msg)), ), ) .await?; - tx.put_batch( - &Batch::::new( - *task.id(), - other_batch_identifier, - (), - BatchState::Closing, - 1, - Interval::EMPTY, - ), - ) + + tx.put_batch(&Batch::::new( + *task.id(), + active_batch_identifier, + (), + BatchState::Closing, + 1, + Interval::from_time(report.metadata().time()).unwrap(), + )) + .await?; + tx.put_batch(&Batch::::new( + *task.id(), + other_batch_identifier, + (), + BatchState::Closing, + 1, + Interval::EMPTY, + )) .await?; let collection_job = - CollectionJob::::new( + CollectionJob::::new( *task.id(), random(), collection_identifier, @@ -2056,7 +2022,7 @@ mod tests { mocked_aggregate_success.assert_async().await; let want_aggregation_job = - AggregationJob::::new( + AggregationJob::::new( *task.id(), aggregation_job_id, (), @@ -2066,7 +2032,7 @@ mod tests { AggregationJobState::Finished, AggregationJobRound::from(2), ); - let want_report_aggregation = ReportAggregation::::new( + let want_report_aggregation = ReportAggregation::::new( *task.id(), aggregation_job_id, *report.metadata().id(), @@ -2081,7 +2047,7 @@ mod tests { .to_batch_interval_start(task.time_precision()) .unwrap(); let want_batch_aggregations = Vec::from([BatchAggregation::< - PRIO3_VERIFY_KEY_LENGTH, + VERIFY_KEY_LENGTH, TimeInterval, Prio3Count, >::new( @@ -2095,7 +2061,7 @@ mod tests { Interval::from_time(report.metadata().time()).unwrap(), ReportIdChecksum::for_report_id(report.metadata().id()), )]); - let want_active_batch = Batch::::new( + let want_active_batch = Batch::::new( *task.id(), active_batch_identifier, (), @@ -2103,7 +2069,7 @@ mod tests { 0, Interval::from_time(report.metadata().time()).unwrap(), ); - let want_other_batch = Batch::::new( + let want_other_batch = Batch::::new( *task.id(), other_batch_identifier, (), @@ -2128,7 +2094,7 @@ mod tests { Box::pin(async move { let aggregation_job = tx - .get_aggregation_job::( + .get_aggregation_job::( task.id(), &aggregation_job_id, ) @@ -2146,7 +2112,7 @@ mod tests { .unwrap(); let batch_aggregations = TimeInterval::get_batch_aggregations_for_collection_identifier::< - PRIO3_VERIFY_KEY_LENGTH, + VERIFY_KEY_LENGTH, Prio3Count, _, >( @@ -2236,10 +2202,7 @@ mod tests { VdafInstance::Prio3Count, Role::Leader, ) - .with_aggregator_endpoints(Vec::from([ - Url::parse("http://irrelevant").unwrap(), // leader URL doesn't matter - Url::parse(&server.url()).unwrap(), - ])) + .with_helper_aggregator_endpoint(Url::parse(&server.url()).unwrap()) .build(); let report_metadata = ReportMetadata::new( random(), @@ -2248,8 +2211,7 @@ mod tests { .to_batch_interval_start(task.time_precision()) .unwrap(), ); - let verify_key: VerifyKey = - task.primary_vdaf_verify_key().unwrap(); + let verify_key: VerifyKey = task.primary_vdaf_verify_key().unwrap(); let transcript = run_vdaf( vdaf.as_ref(), @@ -2261,7 +2223,7 @@ mod tests { let agg_auth_token = task.primary_aggregator_auth_token(); let helper_hpke_keypair = generate_test_hpke_config_and_private_key(); - let report = generate_report::( + let report = generate_report::( *task.id(), report_metadata, helper_hpke_keypair.config(), @@ -2291,7 +2253,7 @@ mod tests { tx.put_client_report(vdaf.borrow(), &report).await?; tx.put_aggregation_job(&AggregationJob::< - PRIO3_VERIFY_KEY_LENGTH, + VERIFY_KEY_LENGTH, FixedSize, 
Prio3Count, >::new( @@ -2305,34 +2267,31 @@ mod tests { AggregationJobRound::from(1), )) .await?; - tx.put_report_aggregation(&ReportAggregation::< - PRIO3_VERIFY_KEY_LENGTH, - Prio3Count, - >::new( - *task.id(), - aggregation_job_id, - *report.metadata().id(), - *report.metadata().time(), - 0, - None, - ReportAggregationState::Waiting(leader_prep_state, Some(prep_msg)), - )) - .await?; - - tx.put_batch( - &Batch::::new( + tx.put_report_aggregation( + &ReportAggregation::::new( *task.id(), - batch_id, - (), - BatchState::Closing, - 1, - Interval::from_time(report.metadata().time()).unwrap(), + aggregation_job_id, + *report.metadata().id(), + *report.metadata().time(), + 0, + None, + ReportAggregationState::Waiting(leader_prep_state, Some(prep_msg)), ), ) .await?; + tx.put_batch(&Batch::::new( + *task.id(), + batch_id, + (), + BatchState::Closing, + 1, + Interval::from_time(report.metadata().time()).unwrap(), + )) + .await?; + let collection_job = - CollectionJob::::new( + CollectionJob::::new( *task.id(), random(), batch_id, @@ -2426,18 +2385,16 @@ mod tests { mocked_aggregate_failure.assert_async().await; mocked_aggregate_success.assert_async().await; - let want_aggregation_job = - AggregationJob::::new( - *task.id(), - aggregation_job_id, - (), - batch_id, - Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) - .unwrap(), - AggregationJobState::Finished, - AggregationJobRound::from(2), - ); - let want_report_aggregation = ReportAggregation::::new( + let want_aggregation_job = AggregationJob::::new( + *task.id(), + aggregation_job_id, + (), + batch_id, + Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)).unwrap(), + AggregationJobState::Finished, + AggregationJobRound::from(2), + ); + let want_report_aggregation = ReportAggregation::::new( *task.id(), aggregation_job_id, *report.metadata().id(), @@ -2446,22 +2403,21 @@ mod tests { None, ReportAggregationState::Finished, ); - let want_batch_aggregations = Vec::from([BatchAggregation::< - PRIO3_VERIFY_KEY_LENGTH, - FixedSize, - Prio3Count, - >::new( - *task.id(), - batch_id, - (), - 0, - BatchAggregationState::Aggregating, - Some(leader_aggregate_share), - 1, - Interval::from_time(report.metadata().time()).unwrap(), - ReportIdChecksum::for_report_id(report.metadata().id()), - )]); - let want_batch = Batch::::new( + let want_batch_aggregations = + Vec::from([ + BatchAggregation::::new( + *task.id(), + batch_id, + (), + 0, + BatchAggregationState::Aggregating, + Some(leader_aggregate_share), + 1, + Interval::from_time(report.metadata().time()).unwrap(), + ReportIdChecksum::for_report_id(report.metadata().id()), + ), + ]); + let want_batch = Batch::::new( *task.id(), batch_id, (), @@ -2486,7 +2442,7 @@ mod tests { Box::pin(async move { let aggregation_job = tx - .get_aggregation_job::( + .get_aggregation_job::( task.id(), &aggregation_job_id, ) @@ -2504,7 +2460,7 @@ mod tests { .unwrap(); let batch_aggregations = FixedSize::get_batch_aggregations_for_collection_identifier::< - PRIO3_VERIFY_KEY_LENGTH, + VERIFY_KEY_LENGTH, Prio3Count, _, >(tx, &task, &vdaf, &batch_id, &()) @@ -2572,8 +2528,7 @@ mod tests { .unwrap(); let batch_identifier = TimeInterval::to_batch_identifier(&task, &(), &time).unwrap(); let report_metadata = ReportMetadata::new(random(), time); - let verify_key: VerifyKey = - task.primary_vdaf_verify_key().unwrap(); + let verify_key: VerifyKey = task.primary_vdaf_verify_key().unwrap(); let transcript = run_vdaf( vdaf.as_ref(), @@ -2584,7 +2539,7 @@ mod tests { ); let 
helper_hpke_keypair = generate_test_hpke_config_and_private_key(); - let report = generate_report::( + let report = generate_report::( *task.id(), report_metadata, helper_hpke_keypair.config(), @@ -2594,18 +2549,16 @@ mod tests { ); let aggregation_job_id = random(); - let aggregation_job = - AggregationJob::::new( - *task.id(), - aggregation_job_id, - (), - (), - Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) - .unwrap(), - AggregationJobState::InProgress, - AggregationJobRound::from(0), - ); - let report_aggregation = ReportAggregation::::new( + let aggregation_job = AggregationJob::::new( + *task.id(), + aggregation_job_id, + (), + (), + Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)).unwrap(), + AggregationJobState::InProgress, + AggregationJobRound::from(0), + ); + let report_aggregation = ReportAggregation::::new( *task.id(), aggregation_job_id, *report.metadata().id(), @@ -2630,16 +2583,14 @@ mod tests { tx.put_aggregation_job(&aggregation_job).await?; tx.put_report_aggregation(&report_aggregation).await?; - tx.put_batch( - &Batch::::new( - *task.id(), - batch_identifier, - (), - BatchState::Open, - 1, - Interval::from_time(report.metadata().time()).unwrap(), - ), - ) + tx.put_batch(&Batch::::new( + *task.id(), + batch_identifier, + (), + BatchState::Open, + 1, + Interval::from_time(report.metadata().time()).unwrap(), + )) .await?; Ok(tx @@ -2669,7 +2620,7 @@ mod tests { // longer be acquired. let want_aggregation_job = aggregation_job.with_state(AggregationJobState::Abandoned); let want_report_aggregation = report_aggregation; - let want_batch = Batch::::new( + let want_batch = Batch::::new( *task.id(), batch_identifier, (), @@ -2684,7 +2635,7 @@ mod tests { (Arc::clone(&vdaf), task.clone(), *report.metadata().id()); Box::pin(async move { let aggregation_job = tx - .get_aggregation_job::( + .get_aggregation_job::( task.id(), &aggregation_job_id, ) @@ -2779,15 +2730,11 @@ mod tests { VdafInstance::Prio3Count, Role::Leader, ) - .with_aggregator_endpoints(Vec::from([ - Url::parse("http://irrelevant").unwrap(), // leader URL doesn't matter - Url::parse(&server.url()).unwrap(), - ])) + .with_helper_aggregator_endpoint(Url::parse(&server.url()).unwrap()) .build(); let agg_auth_token = task.primary_aggregator_auth_token(); let aggregation_job_id = random(); - let verify_key: VerifyKey = - task.primary_vdaf_verify_key().unwrap(); + let verify_key: VerifyKey = task.primary_vdaf_verify_key().unwrap(); let helper_hpke_keypair = generate_test_hpke_config_and_private_key(); @@ -2799,7 +2746,7 @@ mod tests { let batch_identifier = TimeInterval::to_batch_identifier(&task, &(), &time).unwrap(); let report_metadata = ReportMetadata::new(random(), time); let transcript = run_vdaf(&vdaf, verify_key.as_bytes(), &(), report_metadata.id(), &0); - let report = generate_report::( + let report = generate_report::( *task.id(), report_metadata, helper_hpke_keypair.config(), @@ -2821,7 +2768,7 @@ mod tests { tx.put_client_report(&vdaf, &report).await?; tx.put_aggregation_job(&AggregationJob::< - PRIO3_VERIFY_KEY_LENGTH, + VERIFY_KEY_LENGTH, TimeInterval, Prio3Count, >::new( @@ -2837,7 +2784,7 @@ mod tests { .await?; tx.put_report_aggregation( - &ReportAggregation::::new( + &ReportAggregation::::new( *task.id(), aggregation_job_id, *report.metadata().id(), @@ -2849,16 +2796,14 @@ mod tests { ) .await?; - tx.put_batch( - &Batch::::new( - *task.id(), - batch_identifier, - (), - BatchState::Open, - 1, - 
Interval::from_time(report.metadata().time()).unwrap(), - ), - ) + tx.put_batch(&Batch::::new( + *task.id(), + batch_identifier, + (), + BatchState::Open, + 1, + Interval::from_time(report.metadata().time()).unwrap(), + )) .await?; Ok(()) @@ -2976,7 +2921,7 @@ mod tests { .unwrap(); assert_eq!( got_aggregation_job, - AggregationJob::::new( + AggregationJob::::new( *task.id(), aggregation_job_id, (), @@ -2989,7 +2934,7 @@ mod tests { ); assert_eq!( got_batch, - Batch::::new( + Batch::::new( *task.id(), batch_identifier, (), diff --git a/aggregator/src/aggregator/collection_job_driver.rs b/aggregator/src/aggregator/collection_job_driver.rs index 43611c828..5336502d2 100644 --- a/aggregator/src/aggregator/collection_job_driver.rs +++ b/aggregator/src/aggregator/collection_job_driver.rs @@ -557,7 +557,6 @@ mod tests { use rand::random; use std::{str, sync::Arc, time::Duration as StdDuration}; use trillium_tokio::Stopper; - use url::Url; async fn setup_collection_job_test_case( server: &mut mockito::Server, @@ -571,10 +570,7 @@ mod tests { ) { let time_precision = Duration::from_seconds(500); let task = TaskBuilder::new(QueryType::TimeInterval, VdafInstance::Fake, Role::Leader) - .with_aggregator_endpoints(Vec::from([ - Url::parse("http://irrelevant").unwrap(), // leader URL doesn't matter - Url::parse(&server.url()).unwrap(), - ])) + .with_helper_aggregator_endpoint(server.url().parse().unwrap()) .with_time_precision(time_precision) .with_min_batch_size(10) .build(); @@ -712,10 +708,7 @@ mod tests { let time_precision = Duration::from_seconds(500); let task = TaskBuilder::new(QueryType::TimeInterval, VdafInstance::Fake, Role::Leader) - .with_aggregator_endpoints(Vec::from([ - Url::parse("http://irrelevant").unwrap(), // leader URL doesn't matter - Url::parse(&server.url()).unwrap(), - ])) + .with_helper_aggregator_endpoint(server.url().parse().unwrap()) .with_time_precision(time_precision) .with_min_batch_size(10) .build(); diff --git a/aggregator/src/aggregator/collection_job_tests.rs b/aggregator/src/aggregator/collection_job_tests.rs index 81768fb4f..b167ccc8f 100644 --- a/aggregator/src/aggregator/collection_job_tests.rs +++ b/aggregator/src/aggregator/collection_job_tests.rs @@ -1,10 +1,4 @@ -use crate::aggregator::{ - http_handlers::{ - aggregator_handler, - test_util::{decode_response_body, take_problem_details}, - }, - Config, -}; +use crate::aggregator::{http_handlers::aggregator_handler, Config}; use http::StatusCode; use janus_aggregator_core::{ datastore::{ @@ -42,7 +36,6 @@ use serde_json::json; use std::sync::Arc; use trillium::{Handler, KnownHeaderName, Status}; use trillium_testing::{ - assert_headers, prelude::{post, put}, TestConn, }; @@ -353,10 +346,23 @@ async fn collection_job_success_fixed_size() { } let mut test_conn = test_case.post_collection_job(&collection_job_id).await; + assert_eq!(test_conn.status(), Some(Status::Ok)); - assert_headers!(&test_conn, "content-type" => (Collection::::MEDIA_TYPE)); + assert_eq!( + test_conn + .response_headers() + .get(KnownHeaderName::ContentType) + .unwrap(), + Collection::::MEDIA_TYPE + ); + let body_bytes = test_conn + .take_response_body() + .unwrap() + .into_bytes() + .await + .unwrap(); + let collect_resp = Collection::::get_decoded(body_bytes.as_ref()).unwrap(); - let collect_resp: Collection = decode_response_body(&mut test_conn).await; assert_eq!( collect_resp.report_count(), test_case.task.min_batch_size() + 1 @@ -367,13 +373,12 @@ async fn collection_job_success_fixed_size() { 
.align_to_time_precision(test_case.task.time_precision()) .unwrap(), ); - assert_eq!(collect_resp.encrypted_aggregate_shares().len(), 2); let decrypted_leader_aggregate_share = hpke::open( test_case.task.collector_hpke_config().unwrap(), test_case.collector_hpke_keypair.private_key(), &HpkeApplicationInfo::new(&Label::AggregateShare, &Role::Leader, &Role::Collector), - &collect_resp.encrypted_aggregate_shares()[0], + collect_resp.leader_encrypted_aggregate_share(), &AggregateShareAad::new( *test_case.task.id(), BatchSelector::new_fixed_size(batch_id), @@ -391,7 +396,7 @@ async fn collection_job_success_fixed_size() { test_case.task.collector_hpke_config().unwrap(), test_case.collector_hpke_keypair.private_key(), &HpkeApplicationInfo::new(&Label::AggregateShare, &Role::Helper, &Role::Collector), - &collect_resp.encrypted_aggregate_shares()[1], + collect_resp.helper_encrypted_aggregate_share(), &AggregateShareAad::new( *test_case.task.id(), BatchSelector::new_fixed_size(batch_id), @@ -417,7 +422,23 @@ async fn collection_job_success_fixed_size() { .await; assert_eq!(test_conn.status(), Some(Status::BadRequest)); assert_eq!( - take_problem_details(&mut test_conn).await, + test_conn + .response_headers() + .get(KnownHeaderName::ContentType) + .unwrap(), + "application/problem+json" + ); + let problem_details: serde_json::Value = serde_json::from_slice( + &test_conn + .take_response_body() + .unwrap() + .into_bytes() + .await + .unwrap(), + ) + .unwrap(); + assert_eq!( + problem_details, json!({ "status": StatusCode::BAD_REQUEST.as_u16(), "type": "urn:ietf:params:ppm:dap:error:batchInvalid", diff --git a/aggregator/src/aggregator/http_handlers.rs b/aggregator/src/aggregator/http_handlers.rs index a2e9d81fe..e89ccb2c0 100644 --- a/aggregator/src/aggregator/http_handlers.rs +++ b/aggregator/src/aggregator/http_handlers.rs @@ -686,7 +686,7 @@ mod tests { HpkeApplicationInfo, HpkeKeypair, Label, }, report_id::ReportIdChecksumExt, - task::{AuthenticationToken, VdafInstance, PRIO3_VERIFY_KEY_LENGTH}, + task::{AuthenticationToken, VdafInstance, VERIFY_KEY_LENGTH}, test_util::{dummy_vdaf, install_test_trace_subscriber, run_vdaf}, time::{Clock, DurationExt, IntervalExt, MockClock, TimeExt}, }; @@ -976,7 +976,8 @@ mod tests { let task_id = *task.id(); let task = taskprov::Task::new( task_id, - task.aggregator_endpoints().to_vec(), + task.leader_aggregator_endpoint().clone(), + task.helper_aggregator_endpoint().clone(), *task.query_type(), task.vdaf().clone(), *task.role(), @@ -1169,7 +1170,8 @@ mod tests { .unwrap(), ), report.public_share().to_vec(), - report.encrypted_input_shares().to_vec(), + report.leader_encrypted_input_share().clone(), + report.helper_encrypted_input_share().clone(), ); let mut test_conn = put(task.report_upload_uri().unwrap().path()) .with_request_header(KnownHeaderName::ContentType, Report::MEDIA_TYPE) @@ -1185,26 +1187,6 @@ mod tests { ) .await; - // should reject a report with only one share with the unrecognizedMessage type. 
- let bad_report = Report::new( - report.metadata().clone(), - report.public_share().to_vec(), - Vec::from([report.encrypted_input_shares()[0].clone()]), - ); - let mut test_conn = put(task.report_upload_uri().unwrap().path()) - .with_request_header(KnownHeaderName::ContentType, Report::MEDIA_TYPE) - .with_request_body(bad_report.get_encoded()) - .run_async(&handler) - .await; - check_response( - &mut test_conn, - Status::BadRequest, - "unrecognizedMessage", - "The message type for a response was incorrect or the payload was malformed.", - task.id(), - ) - .await; - // should reject a report using the wrong HPKE config for the leader, and reply with // the error type outdatedConfig. let unused_hpke_config_id = (0..) @@ -1214,16 +1196,15 @@ mod tests { let bad_report = Report::new( report.metadata().clone(), report.public_share().to_vec(), - Vec::from([ - HpkeCiphertext::new( - unused_hpke_config_id, - report.encrypted_input_shares()[0] - .encapsulated_key() - .to_vec(), - report.encrypted_input_shares()[0].payload().to_vec(), - ), - report.encrypted_input_shares()[1].clone(), - ]), + HpkeCiphertext::new( + unused_hpke_config_id, + report + .leader_encrypted_input_share() + .encapsulated_key() + .to_vec(), + report.leader_encrypted_input_share().payload().to_vec(), + ), + report.helper_encrypted_input_share().clone(), ); let mut test_conn = put(task.report_upload_uri().unwrap().path()) .with_request_header(KnownHeaderName::ContentType, Report::MEDIA_TYPE) @@ -1249,7 +1230,8 @@ mod tests { let bad_report = Report::new( ReportMetadata::new(*report.metadata().id(), bad_report_time), report.public_share().to_vec(), - report.encrypted_input_shares().to_vec(), + report.leader_encrypted_input_share().clone(), + report.helper_encrypted_input_share().clone(), ); let mut test_conn = put(task.report_upload_uri().unwrap().path()) .with_request_header(KnownHeaderName::ContentType, Report::MEDIA_TYPE) @@ -1326,7 +1308,8 @@ mod tests { .unwrap(), ), report.public_share().to_vec(), - report.encrypted_input_shares().to_vec(), + report.leader_encrypted_input_share().clone(), + report.helper_encrypted_input_share().clone(), ) .get_encoded(), ) @@ -1976,6 +1959,7 @@ mod tests { let mut saw_conflicting_aggregation_job = false; let mut saw_non_conflicting_aggregation_job = false; let mut saw_new_aggregation_job = false; + for aggregation_job in aggregation_jobs { if aggregation_job.eq(&conflicting_aggregation_job) { saw_conflicting_aggregation_job = true; @@ -2473,8 +2457,7 @@ mod tests { .build(); let vdaf = Arc::new(Prio3::new_count(2).unwrap()); - let verify_key: VerifyKey = - task.primary_vdaf_verify_key().unwrap(); + let verify_key: VerifyKey = task.primary_vdaf_verify_key().unwrap(); let hpke_key = task.current_hpke_key(); // report_share_0 is a "happy path" report. 
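// The aggregation job rows these tests write are parameterized as
// AggregationJob<VERIFY_KEY_LENGTH, TimeInterval, Prio3Count> with the renamed
// constant; a sketch of the fully spelled-out datastore write, using the same
// task/job identifiers as the surrounding test:
tx.put_aggregation_job(&AggregationJob::<VERIFY_KEY_LENGTH, TimeInterval, Prio3Count>::new(
    *task.id(),
    aggregation_job_id,
    (),  // aggregation parameter (unit for Prio3Count)
    (),  // partial batch identifier (unit for TimeInterval)
    Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)).unwrap(),
    AggregationJobState::InProgress,
    AggregationJobRound::from(0),
))
.await?;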
@@ -2585,7 +2568,7 @@ mod tests { tx.put_report_share(task.id(), &report_share_2).await?; tx.put_aggregation_job(&AggregationJob::< - PRIO3_VERIFY_KEY_LENGTH, + VERIFY_KEY_LENGTH, TimeInterval, Prio3Count, >::new( @@ -2600,7 +2583,7 @@ mod tests { )) .await?; - tx.put_report_aggregation::( + tx.put_report_aggregation::( &ReportAggregation::new( *task.id(), aggregation_job_id, @@ -2612,7 +2595,7 @@ mod tests { ), ) .await?; - tx.put_report_aggregation::( + tx.put_report_aggregation::( &ReportAggregation::new( *task.id(), aggregation_job_id, @@ -2624,7 +2607,7 @@ mod tests { ), ) .await?; - tx.put_report_aggregation::( + tx.put_report_aggregation::( &ReportAggregation::new( *task.id(), aggregation_job_id, @@ -2637,7 +2620,7 @@ mod tests { ) .await?; - tx.put_aggregate_share_job::( + tx.put_aggregate_share_job::( &AggregateShareJob::new( *task.id(), Interval::new( @@ -2688,31 +2671,32 @@ mod tests { ); // Validate datastore. - let (aggregation_job, report_aggregations) = - datastore - .run_tx(|tx| { - let (vdaf, task) = (Arc::clone(&vdaf), task.clone()); - Box::pin(async move { - let aggregation_job = tx - .get_aggregation_job::( + let (aggregation_job, report_aggregations) = datastore + .run_tx(|tx| { + let (vdaf, task) = (Arc::clone(&vdaf), task.clone()); + Box::pin(async move { + let aggregation_job = tx + .get_aggregation_job::( task.id(), &aggregation_job_id, ) - .await.unwrap().unwrap(); - let report_aggregations = tx - .get_report_aggregations_for_aggregation_job( - vdaf.as_ref(), - &Role::Helper, - task.id(), - &aggregation_job_id, - ) - .await - .unwrap(); - Ok((aggregation_job, report_aggregations)) - }) + .await + .unwrap() + .unwrap(); + let report_aggregations = tx + .get_report_aggregations_for_aggregation_job( + vdaf.as_ref(), + &Role::Helper, + task.id(), + &aggregation_job_id, + ) + .await + .unwrap(); + Ok((aggregation_job, report_aggregations)) }) - .await - .unwrap(); + }) + .await + .unwrap(); assert_eq!( aggregation_job, @@ -2789,8 +2773,7 @@ mod tests { ); let vdaf = Prio3::new_count(2).unwrap(); - let verify_key: VerifyKey = - task.primary_vdaf_verify_key().unwrap(); + let verify_key: VerifyKey = task.primary_vdaf_verify_key().unwrap(); let hpke_key = task.current_hpke_key(); // report_share_0 is a "happy path" report. 
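// The batch rows written in the hunks that follow are parameterized as
// Batch<VERIFY_KEY_LENGTH, TimeInterval, Prio3Count>; a sketch of the full,
// spelled-out form of the write (same arguments as the loop body below):
tx.put_batch(&Batch::<VERIFY_KEY_LENGTH, TimeInterval, Prio3Count>::new(
    *task.id(),
    batch_identifier,
    (),                // aggregation parameter (unit for Prio3Count)
    BatchState::Closed,
    0,                 // outstanding aggregation jobs
    batch_identifier,  // client timestamp interval
))
.await
.unwrap();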
@@ -2892,7 +2875,7 @@ mod tests { ) .unwrap(); let second_batch_want_batch_aggregations = - empty_batch_aggregations::( + empty_batch_aggregations::( &task, BATCH_AGGREGATION_SHARD_COUNT, &second_batch_identifier, @@ -2929,7 +2912,7 @@ mod tests { tx.put_report_share(task.id(), &report_share_2).await?; tx.put_aggregation_job(&AggregationJob::< - PRIO3_VERIFY_KEY_LENGTH, + VERIFY_KEY_LENGTH, TimeInterval, Prio3Count, >::new( @@ -2944,57 +2927,52 @@ mod tests { )) .await?; - tx.put_report_aggregation(&ReportAggregation::< - PRIO3_VERIFY_KEY_LENGTH, - Prio3Count, - >::new( - *task.id(), - aggregation_job_id_0, - *report_metadata_0.id(), - *report_metadata_0.time(), - 0, - None, - ReportAggregationState::Waiting(prep_state_0, None), - )) + tx.put_report_aggregation( + &ReportAggregation::::new( + *task.id(), + aggregation_job_id_0, + *report_metadata_0.id(), + *report_metadata_0.time(), + 0, + None, + ReportAggregationState::Waiting(prep_state_0, None), + ), + ) .await?; - tx.put_report_aggregation(&ReportAggregation::< - PRIO3_VERIFY_KEY_LENGTH, - Prio3Count, - >::new( - *task.id(), - aggregation_job_id_0, - *report_metadata_1.id(), - *report_metadata_1.time(), - 1, - None, - ReportAggregationState::Waiting(prep_state_1, None), - )) + tx.put_report_aggregation( + &ReportAggregation::::new( + *task.id(), + aggregation_job_id_0, + *report_metadata_1.id(), + *report_metadata_1.time(), + 1, + None, + ReportAggregationState::Waiting(prep_state_1, None), + ), + ) .await?; - tx.put_report_aggregation(&ReportAggregation::< - PRIO3_VERIFY_KEY_LENGTH, - Prio3Count, - >::new( - *task.id(), - aggregation_job_id_0, - *report_metadata_2.id(), - *report_metadata_2.time(), - 2, - None, - ReportAggregationState::Waiting(prep_state_2, None), - )) + tx.put_report_aggregation( + &ReportAggregation::::new( + *task.id(), + aggregation_job_id_0, + *report_metadata_2.id(), + *report_metadata_2.time(), + 2, + None, + ReportAggregationState::Waiting(prep_state_2, None), + ), + ) .await?; for batch_identifier in [first_batch_identifier, second_batch_identifier] { - tx.put_batch( - &Batch::::new( - *task.id(), - batch_identifier, - (), - BatchState::Closed, - 0, - batch_identifier, - ), - ) + tx.put_batch(&Batch::::new( + *task.id(), + batch_identifier, + (), + BatchState::Closed, + 0, + batch_identifier, + )) .await .unwrap() } @@ -3042,7 +3020,7 @@ mod tests { (task.clone(), vdaf.clone(), report_metadata_0.clone()); Box::pin(async move { TimeInterval::get_batch_aggregations_for_collection_identifier::< - PRIO3_VERIFY_KEY_LENGTH, + VERIFY_KEY_LENGTH, Prio3Count, _, >( @@ -3066,7 +3044,7 @@ mod tests { .unwrap() .into_iter() .map(|agg| { - BatchAggregation::::new( + BatchAggregation::::new( *agg.task_id(), *agg.batch_identifier(), (), @@ -3114,7 +3092,7 @@ mod tests { (task.clone(), vdaf.clone(), report_metadata_2.clone()); Box::pin(async move { TimeInterval::get_batch_aggregations_for_collection_identifier::< - PRIO3_VERIFY_KEY_LENGTH, + VERIFY_KEY_LENGTH, Prio3Count, _, >( @@ -3249,7 +3227,7 @@ mod tests { tx.put_report_share(task.id(), &report_share_5).await?; tx.put_aggregation_job(&AggregationJob::< - PRIO3_VERIFY_KEY_LENGTH, + VERIFY_KEY_LENGTH, TimeInterval, Prio3Count, >::new( @@ -3264,44 +3242,41 @@ mod tests { )) .await?; - tx.put_report_aggregation(&ReportAggregation::< - PRIO3_VERIFY_KEY_LENGTH, - Prio3Count, - >::new( - *task.id(), - aggregation_job_id_1, - *report_metadata_3.id(), - *report_metadata_3.time(), - 3, - None, - ReportAggregationState::Waiting(prep_state_3, None), - )) + 
tx.put_report_aggregation( + &ReportAggregation::::new( + *task.id(), + aggregation_job_id_1, + *report_metadata_3.id(), + *report_metadata_3.time(), + 3, + None, + ReportAggregationState::Waiting(prep_state_3, None), + ), + ) .await?; - tx.put_report_aggregation(&ReportAggregation::< - PRIO3_VERIFY_KEY_LENGTH, - Prio3Count, - >::new( - *task.id(), - aggregation_job_id_1, - *report_metadata_4.id(), - *report_metadata_4.time(), - 4, - None, - ReportAggregationState::Waiting(prep_state_4, None), - )) + tx.put_report_aggregation( + &ReportAggregation::::new( + *task.id(), + aggregation_job_id_1, + *report_metadata_4.id(), + *report_metadata_4.time(), + 4, + None, + ReportAggregationState::Waiting(prep_state_4, None), + ), + ) .await?; - tx.put_report_aggregation(&ReportAggregation::< - PRIO3_VERIFY_KEY_LENGTH, - Prio3Count, - >::new( - *task.id(), - aggregation_job_id_1, - *report_metadata_5.id(), - *report_metadata_5.time(), - 5, - None, - ReportAggregationState::Waiting(prep_state_5, None), - )) + tx.put_report_aggregation( + &ReportAggregation::::new( + *task.id(), + aggregation_job_id_1, + *report_metadata_5.id(), + *report_metadata_5.time(), + 5, + None, + ReportAggregationState::Waiting(prep_state_5, None), + ), + ) .await?; Ok(()) @@ -3340,7 +3315,7 @@ mod tests { (task.clone(), vdaf.clone(), report_metadata_0.clone()); Box::pin(async move { TimeInterval::get_batch_aggregations_for_collection_identifier::< - PRIO3_VERIFY_KEY_LENGTH, + VERIFY_KEY_LENGTH, Prio3Count, _, >( @@ -3364,7 +3339,7 @@ mod tests { .unwrap() .into_iter() .map(|agg| { - BatchAggregation::::new( + BatchAggregation::::new( *agg.task_id(), *agg.batch_identifier(), (), @@ -3417,7 +3392,7 @@ mod tests { (task.clone(), vdaf.clone(), report_metadata_2.clone()); Box::pin(async move { TimeInterval::get_batch_aggregations_for_collection_identifier::< - PRIO3_VERIFY_KEY_LENGTH, + VERIFY_KEY_LENGTH, Prio3Count, _, >( @@ -4412,13 +4387,12 @@ mod tests { assert_eq!(collect_resp.report_count(), 12); assert_eq!(collect_resp.interval(), &batch_interval); - assert_eq!(collect_resp.encrypted_aggregate_shares().len(), 2); let decrypted_leader_aggregate_share = hpke::open( test_case.task.collector_hpke_config().unwrap(), test_case.collector_hpke_keypair.private_key(), &HpkeApplicationInfo::new(&Label::AggregateShare, &Role::Leader, &Role::Collector), - &collect_resp.encrypted_aggregate_shares()[0], + collect_resp.leader_encrypted_aggregate_share(), &AggregateShareAad::new( *test_case.task.id(), BatchSelector::new_time_interval(batch_interval), @@ -4436,7 +4410,7 @@ mod tests { test_case.task.collector_hpke_config().unwrap(), test_case.collector_hpke_keypair.private_key(), &HpkeApplicationInfo::new(&Label::AggregateShare, &Role::Helper, &Role::Collector), - &collect_resp.encrypted_aggregate_shares()[1], + collect_resp.helper_encrypted_aggregate_share(), &AggregateShareAad::new( *test_case.task.id(), BatchSelector::new_time_interval(batch_interval), diff --git a/aggregator/src/aggregator/taskprov_tests.rs b/aggregator/src/aggregator/taskprov_tests.rs index 000d867f3..3db8c8b59 100644 --- a/aggregator/src/aggregator/taskprov_tests.rs +++ b/aggregator/src/aggregator/taskprov_tests.rs @@ -30,7 +30,7 @@ use janus_core::{ HpkeKeypair, Label, }, report_id::ReportIdChecksumExt, - task::PRIO3_VERIFY_KEY_LENGTH, + task::VERIFY_KEY_LENGTH, taskprov::TASKPROV_HEADER, test_util::{install_test_trace_subscriber, run_vdaf, VdafTranscript}, time::{Clock, DurationExt, MockClock, TimeExt}, @@ -159,10 +159,8 @@ async fn setup_taskprov_test() -> 
TaskprovTestCase { let task = janus_aggregator_core::taskprov::Task::new( task_id, - Vec::from([ - url::Url::parse("https://leader.example.com/").unwrap(), - url::Url::parse("https://helper.example.com/").unwrap(), - ]), + url::Url::parse("https://leader.example.com/").unwrap(), + url::Url::parse("https://helper.example.com/").unwrap(), QueryType::FixedSize { max_batch_size: max_batch_size as u64, batch_time_window_size: None, @@ -696,36 +694,32 @@ async fn taskprov_aggregate_continue() { tx.put_report_share(task.id(), &report_share).await?; - tx.put_aggregation_job(&AggregationJob::< - PRIO3_VERIFY_KEY_LENGTH, - FixedSize, - TestVdaf, - >::new( - *task.id(), - aggregation_job_id, - (), - batch_id, - Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) - .unwrap(), - AggregationJobState::InProgress, - AggregationJobRound::from(0), - )) - .await?; - - tx.put_report_aggregation::( - &ReportAggregation::new( + tx.put_aggregation_job( + &AggregationJob::::new( *task.id(), aggregation_job_id, - *report_metadata.id(), - *report_metadata.time(), - 0, - None, - ReportAggregationState::Waiting(prep_state, None), + (), + batch_id, + Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) + .unwrap(), + AggregationJobState::InProgress, + AggregationJobRound::from(0), ), ) .await?; - tx.put_aggregate_share_job::( + tx.put_report_aggregation::(&ReportAggregation::new( + *task.id(), + aggregation_job_id, + *report_metadata.id(), + *report_metadata.time(), + 0, + None, + ReportAggregationState::Waiting(prep_state, None), + )) + .await?; + + tx.put_aggregate_share_job::( &AggregateShareJob::new( *task.id(), batch_id, diff --git a/aggregator/src/bin/janus_cli.rs b/aggregator/src/bin/janus_cli.rs index 0c3780a43..821b05821 100644 --- a/aggregator/src/bin/janus_cli.rs +++ b/aggregator/src/bin/janus_cli.rs @@ -717,9 +717,8 @@ mod tests { // YAML contains no task ID, VDAF verify keys, aggregator auth tokens, collector auth tokens // or HPKE keys. 
let serialized_task_yaml = r#" -- aggregator_endpoints: - - https://leader - - https://helper +- leader_aggregator_endpoint: https://leader + helper_aggregator_endpoint: https://helper query_type: TimeInterval vdaf: !Prio3Sum bits: 2 @@ -740,9 +739,8 @@ mod tests { aggregator_auth_tokens: [] collector_auth_tokens: [] hpke_keys: [] -- aggregator_endpoints: - - https://leader - - https://helper +- leader_aggregator_endpoint: https://leader + helper_aggregator_endpoint: https://helper query_type: TimeInterval vdaf: !Prio3Sum bits: 2 diff --git a/aggregator_api/src/models.rs b/aggregator_api/src/models.rs index 84e103c0b..b8721f498 100644 --- a/aggregator_api/src/models.rs +++ b/aggregator_api/src/models.rs @@ -135,11 +135,11 @@ impl TryFrom<&Task> for TaskResp { // https://github.com/divviup/janus/issues/1524 // Return the aggregator endpoint URL for the role opposite our own - let peer_aggregator_endpoint = task.aggregator_endpoints()[match task.role() { - Role::Leader => 1, - Role::Helper => 0, + let peer_aggregator_endpoint = match task.role() { + Role::Leader => task.helper_aggregator_endpoint(), + Role::Helper => task.leader_aggregator_endpoint(), _ => return Err("illegal aggregator role in task"), - }] + } .clone(); if task.vdaf_verify_keys().len() != 1 { diff --git a/aggregator_api/src/routes.rs b/aggregator_api/src/routes.rs index 4ccf326c4..8a2e5db4a 100644 --- a/aggregator_api/src/routes.rs +++ b/aggregator_api/src/routes.rs @@ -93,9 +93,9 @@ pub(super) async fn post_task( // TODO(#1524): clean this up with `aggregator_core::task::Task` changes // unwrap safety: this fake URL is valid let fake_aggregator_url = Url::parse("http://never-used.example.com").unwrap(); - let aggregator_endpoints = match req.role { - Role::Leader => Vec::from([fake_aggregator_url, req.peer_aggregator_endpoint]), - Role::Helper => Vec::from([req.peer_aggregator_endpoint, fake_aggregator_url]), + let (leader_aggregator_endpoint, helper_aggregator_endpoint) = match req.role { + Role::Leader => (fake_aggregator_url, req.peer_aggregator_endpoint), + Role::Helper => (req.peer_aggregator_endpoint, fake_aggregator_url), _ => unreachable!(), }; @@ -155,7 +155,8 @@ pub(super) async fn post_task( let task = Arc::new( Task::new( task_id, - aggregator_endpoints, + leader_aggregator_endpoint, + helper_aggregator_endpoint, /* query_type */ req.query_type, /* vdaf */ req.vdaf, /* role */ req.role, @@ -182,7 +183,8 @@ pub(super) async fn post_task( if let Some(existing_task) = tx.get_task(task.id()).await? { // Check whether the existing task in the DB corresponds to the incoming task, ignoring // those fields that are randomly generated. - if existing_task.aggregator_endpoints() == task.aggregator_endpoints() + if existing_task.leader_aggregator_endpoint() == task.leader_aggregator_endpoint() + && existing_task.helper_aggregator_endpoint() == task.helper_aggregator_endpoint() && existing_task.query_type() == task.query_type() && existing_task.vdaf() == task.vdaf() && existing_task.vdaf_verify_keys() == task.vdaf_verify_keys() diff --git a/aggregator_api/src/tests.rs b/aggregator_api/src/tests.rs index 073b9b2d5..422172771 100644 --- a/aggregator_api/src/tests.rs +++ b/aggregator_api/src/tests.rs @@ -321,8 +321,8 @@ async fn post_task_helper_no_optional_fields() { // Verify that the task written to the datastore matches the request... 
assert_eq!( // The other aggregator endpoint in the datastore task is fake - req.peer_aggregator_endpoint, - got_task.aggregator_endpoints()[0] + &req.peer_aggregator_endpoint, + got_task.leader_aggregator_endpoint() ); assert_eq!(&req.query_type, got_task.query_type()); assert_eq!(&req.vdaf, got_task.vdaf()); @@ -522,8 +522,8 @@ async fn post_task_leader_all_optional_fields() { // Verify that the task written to the datastore matches the request... assert_eq!( // The other aggregator endpoint in the datastore task is fake - req.peer_aggregator_endpoint, - got_task.aggregator_endpoints()[1] + &req.peer_aggregator_endpoint, + got_task.helper_aggregator_endpoint() ); assert_eq!(&req.query_type, got_task.query_type()); assert_eq!(&req.vdaf, got_task.vdaf()); @@ -1731,10 +1731,8 @@ fn post_task_req_serialization() { fn task_resp_serialization() { let task = Task::new( TaskId::from([0u8; 32]), - Vec::from([ - "https://leader.com/".parse().unwrap(), - "https://helper.com/".parse().unwrap(), - ]), + "https://leader.com/".parse().unwrap(), + "https://helper.com/".parse().unwrap(), QueryType::FixedSize { max_batch_size: 999, batch_time_window_size: None, diff --git a/aggregator_core/src/datastore.rs b/aggregator_core/src/datastore.rs index 434da2141..387001a68 100644 --- a/aggregator_core/src/datastore.rs +++ b/aggregator_core/src/datastore.rs @@ -526,20 +526,15 @@ impl Transaction<'_, C> { /// Writes a task into the datastore. #[tracing::instrument(skip(self, task), fields(task_id = ?task.id()), err)] pub async fn put_task(&self, task: &Task) -> Result<(), Error> { - let endpoints: Vec<_> = task - .aggregator_endpoints() - .iter() - .map(Url::as_str) - .collect(); - // Main task insert. let stmt = self .prepare_cached( "INSERT INTO tasks ( - task_id, aggregator_role, aggregator_endpoints, query_type, vdaf, - max_batch_query_count, task_expiration, report_expiry_age, min_batch_size, - time_precision, tolerable_clock_skew, collector_hpke_config) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12) + task_id, aggregator_role, leader_aggregator_endpoint, + helper_aggregator_endpoint, query_type, vdaf, max_batch_query_count, + task_expiration, report_expiry_age, min_batch_size, time_precision, + tolerable_clock_skew, collector_hpke_config) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13) ON CONFLICT DO NOTHING", ) .await?; @@ -549,7 +544,10 @@ impl Transaction<'_, C> { &[ /* task_id */ &task.id().as_ref(), /* aggregator_role */ &AggregatorRole::from_role(*task.role())?, - /* aggregator_endpoints */ &endpoints, + /* leader_aggregator_endpoint */ + &task.leader_aggregator_endpoint().as_str(), + /* helper_aggregator_endpoint */ + &task.helper_aggregator_endpoint().as_str(), /* query_type */ &Json(task.query_type()), /* vdaf */ &Json(task.vdaf()), /* max_batch_query_count */ @@ -746,9 +744,9 @@ impl Transaction<'_, C> { let params: &[&(dyn ToSql + Sync)] = &[&task_id.as_ref()]; let stmt = self .prepare_cached( - "SELECT aggregator_role, aggregator_endpoints, query_type, vdaf, - max_batch_query_count, task_expiration, report_expiry_age, min_batch_size, - time_precision, tolerable_clock_skew, collector_hpke_config + "SELECT aggregator_role, leader_aggregator_endpoint, helper_aggregator_endpoint, + query_type, vdaf, max_batch_query_count, task_expiration, report_expiry_age, + min_batch_size, time_precision, tolerable_clock_skew, collector_hpke_config FROM tasks WHERE task_id = $1", ) .await?; @@ -818,9 +816,10 @@ impl Transaction<'_, C> { pub async fn get_tasks(&self) -> Result, 
Error> { let stmt = self .prepare_cached( - "SELECT task_id, aggregator_role, aggregator_endpoints, query_type, vdaf, - max_batch_query_count, task_expiration, report_expiry_age, min_batch_size, - time_precision, tolerable_clock_skew, collector_hpke_config + "SELECT task_id, aggregator_role, leader_aggregator_endpoint, + helper_aggregator_endpoint, query_type, vdaf, max_batch_query_count, + task_expiration, report_expiry_age, min_batch_size, time_precision, + tolerable_clock_skew, collector_hpke_config FROM tasks", ) .await?; @@ -955,11 +954,10 @@ impl Transaction<'_, C> { ) -> Result { // Scalar task parameters. let aggregator_role: AggregatorRole = row.get("aggregator_role"); - let endpoints = row - .get::<_, Vec>("aggregator_endpoints") - .into_iter() - .map(|endpoint| Ok(Url::parse(&endpoint)?)) - .collect::>()?; + let leader_aggregator_endpoint = + row.get::<_, String>("leader_aggregator_endpoint").parse()?; + let helper_aggregator_endpoint = + row.get::<_, String>("helper_aggregator_endpoint").parse()?; let query_type = row.try_get::<_, Json>("query_type")?.0; let vdaf = row.try_get::<_, Json>("vdaf")?.0; let max_batch_query_count = row.get_bigint_and_convert("max_batch_query_count")?; @@ -1056,7 +1054,8 @@ impl Transaction<'_, C> { let task = Task::new_without_validation( *task_id, - endpoints, + leader_aggregator_endpoint, + helper_aggregator_endpoint, query_type, vdaf, aggregator_role.as_role(), diff --git a/aggregator_core/src/datastore/tests.rs b/aggregator_core/src/datastore/tests.rs index 4f5b3a348..37e9775c9 100644 --- a/aggregator_core/src/datastore/tests.rs +++ b/aggregator_core/src/datastore/tests.rs @@ -28,7 +28,7 @@ use janus_core::{ hpke::{ self, test_util::generate_test_hpke_config_and_private_key, HpkeApplicationInfo, Label, }, - task::{VdafInstance, PRIO3_VERIFY_KEY_LENGTH}, + task::{VdafInstance, VERIFY_KEY_LENGTH}, test_util::{ dummy_vdaf::{self, AggregateShare, AggregationParam}, install_test_trace_subscriber, run_vdaf, @@ -1420,7 +1420,7 @@ async fn aggregation_job_acquire_release(ephemeral_datastore: EphemeralDatastore let task_id = *task.id(); async move { tx.put_aggregation_job(&AggregationJob::< - PRIO3_VERIFY_KEY_LENGTH, + VERIFY_KEY_LENGTH, TimeInterval, Prio3Count, >::new( @@ -1446,36 +1446,33 @@ async fn aggregation_job_acquire_release(ephemeral_datastore: EphemeralDatastore .await?; // Write an aggregation job that is finished. We don't want to retrieve this one. - tx.put_aggregation_job(&AggregationJob::< - PRIO3_VERIFY_KEY_LENGTH, - TimeInterval, - Prio3Count, - >::new( - *task.id(), - random(), - (), - (), - Interval::new(OLDEST_ALLOWED_REPORT_TIMESTAMP, Duration::from_seconds(1)).unwrap(), - AggregationJobState::Finished, - AggregationJobRound::from(1), - )) + tx.put_aggregation_job( + &AggregationJob::::new( + *task.id(), + random(), + (), + (), + Interval::new(OLDEST_ALLOWED_REPORT_TIMESTAMP, Duration::from_seconds(1)) + .unwrap(), + AggregationJobState::Finished, + AggregationJobRound::from(1), + ), + ) .await?; // Write an expired aggregation job. We don't want to retrieve this one, either. 
- tx.put_aggregation_job(&AggregationJob::< - PRIO3_VERIFY_KEY_LENGTH, - TimeInterval, - Prio3Count, - >::new( - *task.id(), - random(), - (), - (), - Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) - .unwrap(), - AggregationJobState::InProgress, - AggregationJobRound::from(0), - )) + tx.put_aggregation_job( + &AggregationJob::::new( + *task.id(), + random(), + (), + (), + Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) + .unwrap(), + AggregationJobState::InProgress, + AggregationJobRound::from(0), + ), + ) .await?; // Write an aggregation job for a task that we are taking on the helper role for. @@ -1487,20 +1484,18 @@ async fn aggregation_job_acquire_release(ephemeral_datastore: EphemeralDatastore ) .build(); tx.put_task(&helper_task).await?; - tx.put_aggregation_job(&AggregationJob::< - PRIO3_VERIFY_KEY_LENGTH, - TimeInterval, - Prio3Count, - >::new( - *helper_task.id(), - random(), - (), - (), - Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) - .unwrap(), - AggregationJobState::InProgress, - AggregationJobRound::from(0), - )) + tx.put_aggregation_job( + &AggregationJob::::new( + *helper_task.id(), + random(), + (), + (), + Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) + .unwrap(), + AggregationJobState::InProgress, + AggregationJobRound::from(0), + ), + ) .await }) }) @@ -1739,7 +1734,7 @@ async fn aggregation_job_not_found(ephemeral_datastore: EphemeralDatastore) { let rslt = ds .run_tx(|tx| { Box::pin(async move { - tx.get_aggregation_job::( + tx.get_aggregation_job::( &random(), &random(), ) @@ -1753,7 +1748,7 @@ async fn aggregation_job_not_found(ephemeral_datastore: EphemeralDatastore) { let rslt = ds .run_tx(|tx| { Box::pin(async move { - tx.update_aggregation_job::( + tx.update_aggregation_job::( &AggregationJob::new( random(), random(), @@ -1884,12 +1879,12 @@ async fn roundtrip_report_aggregation(ephemeral_datastore: EphemeralDatastore) { let report_id = random(); let vdaf = Arc::new(Prio3::new_count(2).unwrap()); - let verify_key: [u8; PRIO3_VERIFY_KEY_LENGTH] = random(); + let verify_key: [u8; VERIFY_KEY_LENGTH] = random(); let vdaf_transcript = run_vdaf(vdaf.as_ref(), &verify_key, &(), &report_id, &0); let leader_prep_state = vdaf_transcript.leader_prep_state(0); for (ord, state) in [ - ReportAggregationState::::Start, + ReportAggregationState::::Start, ReportAggregationState::Waiting( leader_prep_state.clone(), Some(vdaf_transcript.prepare_messages[0].clone()), @@ -1920,7 +1915,7 @@ async fn roundtrip_report_aggregation(ephemeral_datastore: EphemeralDatastore) { Box::pin(async move { tx.put_task(&task).await?; tx.put_aggregation_job(&AggregationJob::< - PRIO3_VERIFY_KEY_LENGTH, + VERIFY_KEY_LENGTH, TimeInterval, Prio3Count, >::new( @@ -2259,7 +2254,7 @@ async fn get_report_aggregations_for_aggregation_job(ephemeral_datastore: Epheme let report_id = random(); let vdaf = Arc::new(Prio3::new_count(2).unwrap()); - let verify_key: [u8; PRIO3_VERIFY_KEY_LENGTH] = random(); + let verify_key: [u8; VERIFY_KEY_LENGTH] = random(); let vdaf_transcript = run_vdaf(vdaf.as_ref(), &verify_key, &(), &report_id, &0); let task = TaskBuilder::new( @@ -2280,25 +2275,23 @@ async fn get_report_aggregations_for_aggregation_job(ephemeral_datastore: Epheme ); Box::pin(async move { tx.put_task(&task).await?; - tx.put_aggregation_job(&AggregationJob::< - PRIO3_VERIFY_KEY_LENGTH, - TimeInterval, - Prio3Count, - >::new( - *task.id(), - aggregation_job_id, - (), - (), - 
Interval::new(OLDEST_ALLOWED_REPORT_TIMESTAMP, Duration::from_seconds(1)) - .unwrap(), - AggregationJobState::InProgress, - AggregationJobRound::from(0), - )) + tx.put_aggregation_job( + &AggregationJob::::new( + *task.id(), + aggregation_job_id, + (), + (), + Interval::new(OLDEST_ALLOWED_REPORT_TIMESTAMP, Duration::from_seconds(1)) + .unwrap(), + AggregationJobState::InProgress, + AggregationJobRound::from(0), + ), + ) .await?; let mut want_report_aggregations = Vec::new(); for (ord, state) in [ - ReportAggregationState::::Start, + ReportAggregationState::::Start, ReportAggregationState::Waiting(prep_state.clone(), Some(prep_msg)), ReportAggregationState::Finished, ReportAggregationState::Failed(ReportShareError::VdafPrepError), diff --git a/aggregator_core/src/task.rs b/aggregator_core/src/task.rs index 41f5c6d06..467868750 100644 --- a/aggregator_core/src/task.rs +++ b/aggregator_core/src/task.rs @@ -14,11 +14,7 @@ use janus_messages::{ }; use rand::{distributions::Standard, random, thread_rng, Rng}; use serde::{de::Error as _, Deserialize, Deserializer, Serialize, Serializer}; -use std::{ - array::TryFromSliceError, - collections::HashMap, - fmt::{self, Formatter}, -}; +use std::{array::TryFromSliceError, collections::HashMap}; use url::Url; /// Errors that methods and functions in this module may return. @@ -101,10 +97,12 @@ impl TryFrom<&SecretBytes> for VerifyKey { pub struct Task { /// Unique identifier for the task. task_id: TaskId, - /// URLs relative to which aggregator API endpoints are found. The first - /// entry is the leader's. - #[derivative(Debug(format_with = "fmt_vector_of_urls"))] - aggregator_endpoints: Vec, + /// URL relative to which the Leader's API endpoints are found. + #[derivative(Debug(format_with = "std::fmt::Display::fmt"))] + leader_aggregator_endpoint: Url, + /// URL relative to which the Helper's API endpoints are found. + #[derivative(Debug(format_with = "std::fmt::Display::fmt"))] + helper_aggregator_endpoint: Url, /// The query type this task uses to generate batches. query_type: QueryType, /// The VDAF this task executes. @@ -145,7 +143,8 @@ impl Task { #[allow(clippy::too_many_arguments)] pub fn new>( task_id: TaskId, - aggregator_endpoints: Vec, + leader_aggregator_endpoint: Url, + helper_aggregator_endpoint: Url, query_type: QueryType, vdaf: VdafInstance, role: Role, @@ -163,7 +162,8 @@ impl Task { ) -> Result { let task = Self::new_without_validation( task_id, - aggregator_endpoints, + leader_aggregator_endpoint, + helper_aggregator_endpoint, query_type, vdaf, role, @@ -188,7 +188,8 @@ impl Task { #[allow(clippy::too_many_arguments)] pub(crate) fn new_without_validation>( task_id: TaskId, - mut aggregator_endpoints: Vec, + leader_aggregator_endpoint: Url, + helper_aggregator_endpoint: Url, query_type: QueryType, vdaf: VdafInstance, role: Role, @@ -204,13 +205,6 @@ impl Task { collector_auth_tokens: Vec, hpke_keys: I, ) -> Self { - // Ensure provided aggregator endpoints end with a slash, as we will be joining additional - // path segments into these endpoints & the Url::join implementation is persnickety about - // the slash at the end of the path. - for url in &mut aggregator_endpoints { - url_ensure_trailing_slash(url); - } - // Compute hpke_configs mapping cfg.id -> (cfg, key). 
let hpke_keys: HashMap = hpke_keys .into_iter() @@ -219,7 +213,11 @@ impl Task { Self { task_id, - aggregator_endpoints, + // Ensure provided aggregator endpoints end with a slash, as we will be joining + // additional path segments into these endpoints & the Url::join implementation is + // persnickety about the slash at the end of the path. + leader_aggregator_endpoint: url_ensure_trailing_slash(leader_aggregator_endpoint), + helper_aggregator_endpoint: url_ensure_trailing_slash(helper_aggregator_endpoint), query_type, vdaf, role, @@ -239,10 +237,6 @@ impl Task { /// Validates using criteria common to all tasks regardless of their provenance. pub(crate) fn validate_common(&self) -> Result<(), Error> { - // DAP currently only supports configurations of exactly two aggregators. - if self.aggregator_endpoints.len() != 2 { - return Err(Error::InvalidParameter("aggregator_endpoints")); - } if !self.role.is_aggregator() { return Err(Error::InvalidParameter("role")); } @@ -302,9 +296,14 @@ impl Task { &self.task_id } - /// Retrieves the aggregator endpoints associated with this task in natural order. - pub fn aggregator_endpoints(&self) -> &[Url] { - &self.aggregator_endpoints + /// Retrieves the Leader's aggregator endpoint associated with this task. + pub fn leader_aggregator_endpoint(&self) -> &Url { + &self.leader_aggregator_endpoint + } + + /// Retrieves the Helper's aggregator endpoint associated with this task. + pub fn helper_aggregator_endpoint(&self) -> &Url { + &self.helper_aggregator_endpoint } /// Retrieves the query type associated with this task. @@ -402,12 +401,6 @@ impl Task { } } - /// Returns the [`Url`] relative to which the server performing `role` serves its API. - pub fn aggregator_url(&self, role: &Role) -> Result<&Url, Error> { - let index = role.index().ok_or(Error::InvalidParameter(role.as_str()))?; - Ok(&self.aggregator_endpoints[index]) - } - /// Returns the [`AuthenticationToken`] currently used by this aggregator to authenticate itself /// to other aggregators. pub fn primary_aggregator_auth_token(&self) -> &AuthenticationToken { @@ -463,14 +456,14 @@ impl Task { /// Returns the URI at which reports may be uploaded for this task. pub fn report_upload_uri(&self) -> Result { Ok(self - .aggregator_url(&Role::Leader)? + .leader_aggregator_endpoint() .join(&format!("{}/reports", self.tasks_path()))?) } /// Returns the URI at which the helper resource for the specified aggregation job ID can be /// accessed. pub fn aggregation_job_uri(&self, aggregation_job_id: &AggregationJobId) -> Result { - Ok(self.aggregator_url(&Role::Helper)?.join(&format!( + Ok(self.helper_aggregator_endpoint().join(&format!( "{}/aggregation_jobs/{aggregation_job_id}", self.tasks_path() ))?) @@ -479,34 +472,27 @@ impl Task { /// Returns the URI at which the helper aggregate shares resource can be accessed. pub fn aggregate_shares_uri(&self) -> Result { Ok(self - .aggregator_url(&Role::Helper)? + .helper_aggregator_endpoint() .join(&format!("{}/aggregate_shares", self.tasks_path()))?) } /// Returns the URI at which the leader resource for the specified collection job ID can be /// accessed. pub fn collection_job_uri(&self, collection_job_id: &CollectionJobId) -> Result { - Ok(self.aggregator_url(&Role::Leader)?.join(&format!( + Ok(self.leader_aggregator_endpoint().join(&format!( "{}/collection_jobs/{collection_job_id}", self.tasks_path() ))?) 
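    // For example (a sketch, assuming `tasks_path()` renders "tasks/{task_id}"):
    // with helper_aggregator_endpoint = "https://helper.example.com/", the
    // aggregation_job_uri above joins to
    //   https://helper.example.com/tasks/{task_id}/aggregation_jobs/{aggregation_job_id}
    // while report_upload_uri and collection_job_uri join against the leader
    // endpoint the same way. The trailing slash added in new_without_validation is
    // what makes Url::join treat the endpoint as a directory instead of replacing
    // its final path segment.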
} } -fn fmt_vector_of_urls(urls: &Vec, f: &mut Formatter<'_>) -> fmt::Result { - let mut list = f.debug_list(); - for url in urls { - list.entry(&format!("{url}")); - } - list.finish() -} - /// SerializedTask is an intermediate representation for tasks being serialized via the Serialize & /// Deserialize traits. #[derive(Clone, Serialize, Deserialize)] pub struct SerializedTask { task_id: Option, - aggregator_endpoints: Vec, + leader_aggregator_endpoint: Url, + helper_aggregator_endpoint: Url, query_type: QueryType, vdaf: VdafInstance, role: Role, @@ -587,7 +573,8 @@ impl Serialize for Task { SerializedTask { task_id: Some(self.task_id), - aggregator_endpoints: self.aggregator_endpoints.clone(), + leader_aggregator_endpoint: self.leader_aggregator_endpoint.clone(), + helper_aggregator_endpoint: self.helper_aggregator_endpoint.clone(), query_type: self.query_type, vdaf: self.vdaf.clone(), role: self.role, @@ -628,7 +615,8 @@ impl TryFrom for Task { Task::new( task_id, - serialized_task.aggregator_endpoints, + serialized_task.leader_aggregator_endpoint, + serialized_task.helper_aggregator_endpoint, serialized_task.query_type, serialized_task.vdaf, serialized_task.role, @@ -655,7 +643,6 @@ impl<'de> Deserialize<'de> for Task { } } -// This is public to allow use in integration tests. #[cfg(feature = "test-util")] #[cfg_attr(docsrs, doc(cfg(feature = "test-util")))] pub mod test_util { @@ -665,7 +652,7 @@ pub mod test_util { }; use janus_core::{ hpke::{test_util::generate_test_hpke_config_and_private_key, HpkeKeypair}, - task::{AuthenticationToken, VdafInstance, PRIO3_VERIFY_KEY_LENGTH}, + task::{AuthenticationToken, VdafInstance, VERIFY_KEY_LENGTH}, time::DurationExt, }; use janus_messages::{Duration, HpkeConfig, HpkeConfigId, Role, TaskId, Time}; @@ -681,7 +668,7 @@ pub mod test_util { // All "real" VDAFs use a verify key of length 16 currently. (Poplar1 may not, but it's // not yet done being specified, so choosing 16 bytes is fine for testing.) - _ => PRIO3_VERIFY_KEY_LENGTH, + _ => VERIFY_KEY_LENGTH, } } @@ -725,10 +712,8 @@ pub mod test_util { Self( Task::new( task_id, - Vec::from([ - "https://leader.endpoint".parse().unwrap(), - "https://helper.endpoint".parse().unwrap(), - ]), + "https://leader.endpoint".parse().unwrap(), + "https://helper.endpoint".parse().unwrap(), query_type, vdaf, role, @@ -748,22 +733,35 @@ pub mod test_util { ) } + /// Gets the leader aggregator endpoint for the eventual task. + pub fn leader_aggregator_endpoint(&self) -> &Url { + self.0.leader_aggregator_endpoint() + } + + /// Gets the helper aggregator endpoint for the eventual task. + pub fn helper_aggregator_endpoint(&self) -> &Url { + self.0.helper_aggregator_endpoint() + } + /// Associates the eventual task with the given task ID. pub fn with_id(self, task_id: TaskId) -> Self { Self(Task { task_id, ..self.0 }) } - /// Associates the eventual task with the given aggregator endpoints. - pub fn with_aggregator_endpoints(self, aggregator_endpoints: Vec) -> Self { + /// Associates the eventual task with the given aggregator endpoint for the Leader. + pub fn with_leader_aggregator_endpoint(self, leader_aggregator_endpoint: Url) -> Self { Self(Task { - aggregator_endpoints, + leader_aggregator_endpoint, ..self.0 }) } - /// Retrieves the aggregator endpoints associated with this task builder. - pub fn aggregator_endpoints(&self) -> &[Url] { - self.0.aggregator_endpoints() + /// Associates the eventual task with the given aggregator endpoint for the Helper. 
+ pub fn with_helper_aggregator_endpoint(self, helper_aggregator_endpoint: Url) -> Self { + Self(Task { + helper_aggregator_endpoint, + ..self.0 + }) } /// Associates the eventual task with the given aggregator role. @@ -884,7 +882,7 @@ mod tests { use assert_matches::assert_matches; use janus_core::{ hpke::{test_util::generate_test_hpke_config_and_private_key, HpkeKeypair, HpkePrivateKey}, - task::{AuthenticationToken, PRIO3_VERIFY_KEY_LENGTH}, + task::{AuthenticationToken, VERIFY_KEY_LENGTH}, test_util::roundtrip_encoding, time::DurationExt, }; @@ -895,7 +893,6 @@ mod tests { use rand::random; use serde_json::json; use serde_test::{assert_de_tokens, assert_tokens, Token}; - use url::Url; #[test] fn task_serialization() { @@ -919,14 +916,12 @@ mod tests { // As leader, we receive an error if no collector auth token is specified. Task::new( random(), - Vec::from([ - "http://leader_endpoint".parse().unwrap(), - "http://helper_endpoint".parse().unwrap(), - ]), + "http://leader_endpoint".parse().unwrap(), + "http://helper_endpoint".parse().unwrap(), QueryType::TimeInterval, VdafInstance::Prio3Count, Role::Leader, - Vec::from([SecretBytes::new([0; PRIO3_VERIFY_KEY_LENGTH].into())]), + Vec::from([SecretBytes::new([0; VERIFY_KEY_LENGTH].into())]), 0, None, None, @@ -943,14 +938,12 @@ mod tests { // As leader, we receive no error if a collector auth token is specified. Task::new( random(), - Vec::from([ - "http://leader_endpoint".parse().unwrap(), - "http://helper_endpoint".parse().unwrap(), - ]), + "http://leader_endpoint".parse().unwrap(), + "http://helper_endpoint".parse().unwrap(), QueryType::TimeInterval, VdafInstance::Prio3Count, Role::Leader, - Vec::from([SecretBytes::new([0; PRIO3_VERIFY_KEY_LENGTH].into())]), + Vec::from([SecretBytes::new([0; VERIFY_KEY_LENGTH].into())]), 0, None, None, @@ -967,14 +960,12 @@ mod tests { // As helper, we receive no error if no collector auth token is specified. Task::new( random(), - Vec::from([ - "http://leader_endpoint".parse().unwrap(), - "http://helper_endpoint".parse().unwrap(), - ]), + "http://leader_endpoint".parse().unwrap(), + "http://helper_endpoint".parse().unwrap(), QueryType::TimeInterval, VdafInstance::Prio3Count, Role::Helper, - Vec::from([SecretBytes::new([0; PRIO3_VERIFY_KEY_LENGTH].into())]), + Vec::from([SecretBytes::new([0; VERIFY_KEY_LENGTH].into())]), 0, None, None, @@ -991,14 +982,12 @@ mod tests { // As helper, we receive an error if a collector auth token is specified. 
Task::new( random(), - Vec::from([ - "http://leader_endpoint".parse().unwrap(), - "http://helper_endpoint".parse().unwrap(), - ]), + "http://leader_endpoint".parse().unwrap(), + "http://helper_endpoint".parse().unwrap(), QueryType::TimeInterval, VdafInstance::Prio3Count, Role::Helper, - Vec::from([SecretBytes::new([0; PRIO3_VERIFY_KEY_LENGTH].into())]), + Vec::from([SecretBytes::new([0; VERIFY_KEY_LENGTH].into())]), 0, None, None, @@ -1017,14 +1006,12 @@ mod tests { fn aggregator_endpoints_end_in_slash() { let task = Task::new( random(), - Vec::from([ - "http://leader_endpoint/foo/bar".parse().unwrap(), - "http://helper_endpoint".parse().unwrap(), - ]), + "http://leader_endpoint/foo/bar".parse().unwrap(), + "http://helper_endpoint".parse().unwrap(), QueryType::TimeInterval, VdafInstance::Prio3Count, Role::Leader, - Vec::from([SecretBytes::new([0; PRIO3_VERIFY_KEY_LENGTH].into())]), + Vec::from([SecretBytes::new([0; VERIFY_KEY_LENGTH].into())]), 0, None, None, @@ -1039,11 +1026,12 @@ mod tests { .unwrap(); assert_eq!( - task.aggregator_endpoints, - Vec::from([ - "http://leader_endpoint/foo/bar/".parse().unwrap(), - "http://helper_endpoint/".parse().unwrap() - ]) + task.leader_aggregator_endpoint, + "http://leader_endpoint/foo/bar/".parse().unwrap(), + ); + assert_eq!( + task.helper_aggregator_endpoint, + "http://helper_endpoint/".parse().unwrap(), ); } @@ -1066,10 +1054,8 @@ mod tests { VdafInstance::Prio3Count, Role::Leader, ) - .with_aggregator_endpoints(Vec::from([ - Url::parse("https://leader.com/prefix/").unwrap(), - Url::parse("https://helper.com/prefix/").unwrap(), - ])) + .with_leader_aggregator_endpoint("https://leader.com/prefix/".parse().unwrap()) + .with_helper_aggregator_endpoint("https://helper.com/prefix/".parse().unwrap()) .build(), ), ] { @@ -1097,10 +1083,8 @@ mod tests { assert_tokens( &Task::new( TaskId::from([0; 32]), - Vec::from([ - "https://example.com/".parse().unwrap(), - "https://example.net/".parse().unwrap(), - ]), + "https://example.com/".parse().unwrap(), + "https://example.net/".parse().unwrap(), QueryType::TimeInterval, VdafInstance::Prio3Count, Role::Leader, @@ -1141,16 +1125,15 @@ mod tests { &[ Token::Struct { name: "SerializedTask", - len: 16, + len: 17, }, Token::Str("task_id"), Token::Some, Token::Str("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"), - Token::Str("aggregator_endpoints"), - Token::Seq { len: Some(2) }, + Token::Str("leader_aggregator_endpoint"), Token::Str("https://example.com/"), + Token::Str("helper_aggregator_endpoint"), Token::Str("https://example.net/"), - Token::SeqEnd, Token::Str("query_type"), Token::UnitVariant { name: "QueryType", @@ -1287,10 +1270,8 @@ mod tests { assert_tokens( &Task::new( TaskId::from([255; 32]), - Vec::from([ - "https://example.com/".parse().unwrap(), - "https://example.net/".parse().unwrap(), - ]), + "https://example.com/".parse().unwrap(), + "https://example.net/".parse().unwrap(), QueryType::FixedSize { max_batch_size: 10, batch_time_window_size: None, @@ -1331,16 +1312,15 @@ mod tests { &[ Token::Struct { name: "SerializedTask", - len: 16, + len: 17, }, Token::Str("task_id"), Token::Some, Token::Str("__________________________________________8"), - Token::Str("aggregator_endpoints"), - Token::Seq { len: Some(2) }, + Token::Str("leader_aggregator_endpoint"), Token::Str("https://example.com/"), + Token::Str("helper_aggregator_endpoint"), Token::Str("https://example.net/"), - Token::SeqEnd, Token::Str("query_type"), Token::StructVariant { name: "QueryType", diff --git a/aggregator_core/src/taskprov.rs 
b/aggregator_core/src/taskprov.rs index fb8facfe6..edf745c68 100644 --- a/aggregator_core/src/taskprov.rs +++ b/aggregator_core/src/taskprov.rs @@ -272,7 +272,8 @@ impl Task { #[allow(clippy::too_many_arguments)] pub fn new( task_id: TaskId, - aggregator_endpoints: Vec, + leader_aggregator_endpoint: Url, + helper_aggregator_endpoint: Url, query_type: QueryType, vdaf: VdafInstance, role: Role, @@ -286,7 +287,8 @@ impl Task { ) -> Result { let task = Self(task::Task::new_without_validation( task_id, - aggregator_endpoints, + leader_aggregator_endpoint, + helper_aggregator_endpoint, query_type, vdaf, role, diff --git a/client/Cargo.toml b/client/Cargo.toml index 977d39680..f889e9824 100644 --- a/client/Cargo.toml +++ b/client/Cargo.toml @@ -14,11 +14,12 @@ backoff = { version = "0.4.0", features = ["tokio"] } derivative = "2.2.0" http = "0.2.9" http-api-problem = "0.57.0" +itertools.workspace = true janus_core.workspace = true janus_messages.workspace = true prio.workspace = true rand = "0.8" -reqwest = { version = "0.11.19", default-features = false, features = ["rustls-tls", "json"] } +reqwest = { version = "0.11.18", default-features = false, features = ["rustls-tls", "json"] } thiserror.workspace = true tokio.workspace = true tracing = "0.1.37" diff --git a/client/src/lib.rs b/client/src/lib.rs index 20d820c2e..eb508a5be 100644 --- a/client/src/lib.rs +++ b/client/src/lib.rs @@ -4,6 +4,7 @@ use backoff::ExponentialBackoff; use derivative::Derivative; use http::header::CONTENT_TYPE; use http_api_problem::HttpApiProblem; +use itertools::Itertools; use janus_core::{ hpke::{self, HpkeApplicationInfo, Label}, http::response_to_problem_details, @@ -12,18 +13,15 @@ use janus_core::{ time::{Clock, TimeExt}, }; use janus_messages::{ - Duration, HpkeCiphertext, HpkeConfig, HpkeConfigList, InputShareAad, PlaintextInputShare, - Report, ReportId, ReportMetadata, Role, TaskId, + Duration, HpkeConfig, HpkeConfigList, InputShareAad, PlaintextInputShare, Report, ReportId, + ReportMetadata, Role, TaskId, }; use prio::{ codec::{Decode, Encode}, vdaf, }; use rand::random; -use std::{ - fmt::{self, Formatter}, - io::Cursor, -}; +use std::io::Cursor; use url::Url; #[derive(Debug, thiserror::Error)] @@ -60,10 +58,12 @@ static CLIENT_USER_AGENT: &str = concat!( pub struct ClientParameters { /// Unique identifier for the task. task_id: TaskId, - /// URLs relative to which aggregator API endpoints are found. The first - /// entry is the leader's. - #[derivative(Debug(format_with = "fmt_vector_of_urls"))] - aggregator_endpoints: Vec, + /// URL relative to which the Leader's API endpoints are found. + #[derivative(Debug(format_with = "std::fmt::Display::fmt"))] + leader_aggregator_endpoint: Url, + /// URL relative to which the Helper's API endpoints are found. + #[derivative(Debug(format_with = "std::fmt::Display::fmt"))] + helper_aggregator_endpoint: Url, /// The time precision of the task. This value is shared by all parties in the protocol, and is /// used to compute report timestamps. time_precision: Duration, @@ -73,10 +73,16 @@ pub struct ClientParameters { impl ClientParameters { /// Creates a new set of client task parameters. 
-    pub fn new(task_id: TaskId, aggregator_endpoints: Vec<Url>, time_precision: Duration) -> Self {
+    pub fn new(
+        task_id: TaskId,
+        leader_aggregator_endpoint: Url,
+        helper_aggregator_endpoint: Url,
+        time_precision: Duration,
+    ) -> Self {
         Self::new_with_backoff(
             task_id,
-            aggregator_endpoints,
+            leader_aggregator_endpoint,
+            helper_aggregator_endpoint,
             time_precision,
             http_request_exponential_backoff(),
         )
@@ -85,35 +91,32 @@ impl ClientParameters {
     /// Creates a new set of client task parameters with non-default HTTP request retry parameters.
     pub fn new_with_backoff(
         task_id: TaskId,
-        mut aggregator_endpoints: Vec<Url>,
+        leader_aggregator_endpoint: Url,
+        helper_aggregator_endpoint: Url,
         time_precision: Duration,
         http_request_retry_parameters: ExponentialBackoff,
     ) -> Self {
-        // Ensure provided aggregator endpoints end with a slash, as we will be joining additional
-        // path segments into these endpoints & the Url::join implementation is persnickety about
-        // the slash at the end of the path.
-        for url in &mut aggregator_endpoints {
-            url_ensure_trailing_slash(url);
-        }
-
         Self {
             task_id,
-            aggregator_endpoints,
+            leader_aggregator_endpoint: url_ensure_trailing_slash(leader_aggregator_endpoint),
+            helper_aggregator_endpoint: url_ensure_trailing_slash(helper_aggregator_endpoint),
             time_precision,
             http_request_retry_parameters,
         }
     }

-    /// The URL relative to which the API endpoints for the aggregator may be
-    /// found, if the role is an aggregator, or an error otherwise.
+    /// The URL relative to which the API endpoints for the aggregator may be found, if the role is
+    /// an aggregator, or an error otherwise.
     fn aggregator_endpoint(&self, role: &Role) -> Result<&Url, Error> {
-        Ok(&self.aggregator_endpoints[role
-            .index()
-            .ok_or(Error::InvalidParameter("role is not an aggregator"))?])
+        match role {
+            Role::Leader => Ok(&self.leader_aggregator_endpoint),
+            Role::Helper => Ok(&self.helper_aggregator_endpoint),
+            _ => Err(Error::InvalidParameter("role is not an aggregator")),
+        }
     }

-    /// URL from which the HPKE configuration for the server filling `role` may
-    /// be fetched per draft-gpew-priv-ppm §4.3.1
+    /// URL from which the HPKE configuration for the server filling `role` may be fetched per
+    /// draft-gpew-priv-ppm §4.3.1
     fn hpke_config_endpoint(&self, role: &Role) -> Result<Url, Error> {
         Ok(self.aggregator_endpoint(role)?.join("hpke_config")?)
     }
@@ -121,21 +124,13 @@ impl ClientParameters {
     // URI to which reports may be uploaded for the provided task.
     fn reports_resource_uri(&self, task_id: &TaskId) -> Result<Url, Error> {
         Ok(self
-            .aggregator_endpoint(&Role::Leader)?
+            .leader_aggregator_endpoint
             .join(&format!("tasks/{task_id}/reports"))?)
     }
 }

-fn fmt_vector_of_urls(urls: &Vec<Url>, f: &mut Formatter<'_>) -> fmt::Result {
-    let mut list = f.debug_list();
-    for url in urls {
-        list.entry(&format!("{url}"));
-    }
-    list.finish()
-}
-
-/// Fetches HPKE configuration from the specified aggregator using the
-/// aggregator endpoints in the provided [`ClientParameters`].
+/// Fetches HPKE configuration from the specified aggregator using the aggregator endpoints in the
+/// provided [`ClientParameters`].
#[tracing::instrument(err)] pub async fn aggregator_hpke_config( client_parameters: &ClientParameters, @@ -228,14 +223,14 @@ impl, C: Clock> Client { let report_metadata = ReportMetadata::new(report_id, time); let encoded_public_share = public_share.get_encoded(); - let encrypted_input_shares: Vec = [ + let (leader_encrypted_input_share, helper_encrypted_input_share) = [ (&self.leader_hpke_config, &Role::Leader), (&self.helper_hpke_config, &Role::Helper), ] .into_iter() .zip(input_shares) .map(|((hpke_config, receiver_role), input_share)| { - Ok(hpke::seal( + hpke::seal( hpke_config, &HpkeApplicationInfo::new(&Label::InputShare, &Role::Client, receiver_role), &PlaintextInputShare::new( @@ -249,14 +244,16 @@ impl, C: Clock> Client { encoded_public_share.clone(), ) .get_encoded(), - )?) + ) }) - .collect::>()?; + .collect_tuple() + .expect("iterator to yield two items"); // expect safety: iterator contains two items. Ok(Report::new( report_metadata, encoded_public_share, - encrypted_input_shares, + leader_encrypted_input_share?, + helper_encrypted_input_share?, )) } @@ -309,14 +306,15 @@ mod tests { use url::Url; fn setup_client>( - server: &mut mockito::Server, + server: &mockito::Server, vdaf_client: V, ) -> Client { let server_url = Url::parse(&server.url()).unwrap(); Client::new( ClientParameters::new_with_backoff( random(), - Vec::from([server_url.clone(), server_url]), + server_url.clone(), + server_url, Duration::from_seconds(1), test_http_request_exponential_backoff(), ), @@ -332,19 +330,18 @@ mod tests { fn aggregator_endpoints_end_in_slash() { let client_parameters = ClientParameters::new( random(), - Vec::from([ - "http://leader_endpoint/foo/bar".parse().unwrap(), - "http://helper_endpoint".parse().unwrap(), - ]), + "http://leader_endpoint/foo/bar".parse().unwrap(), + "http://helper_endpoint".parse().unwrap(), Duration::from_seconds(1), ); assert_eq!( - client_parameters.aggregator_endpoints, - Vec::from([ - "http://leader_endpoint/foo/bar/".parse().unwrap(), - "http://helper_endpoint/".parse().unwrap() - ]) + client_parameters.leader_aggregator_endpoint, + "http://leader_endpoint/foo/bar/".parse().unwrap() + ); + assert_eq!( + client_parameters.helper_aggregator_endpoint, + "http://helper_endpoint/".parse().unwrap() ); } @@ -352,7 +349,7 @@ mod tests { async fn upload_prio3_count() { install_test_trace_subscriber(); let mut server = mockito::Server::new_async().await; - let client = setup_client(&mut server, Prio3::new_count(2).unwrap()); + let client = setup_client(&server, Prio3::new_count(2).unwrap()); let mocked_upload = server .mock( @@ -373,9 +370,9 @@ mod tests { #[tokio::test] async fn upload_prio3_invalid_measurement() { install_test_trace_subscriber(); - let mut server = mockito::Server::new_async().await; + let server = mockito::Server::new_async().await; let vdaf = Prio3::new_sum(2, 16).unwrap(); - let client = setup_client(&mut server, vdaf); + let client = setup_client(&server, vdaf); // 65536 is too big for a 16 bit sum and will be rejected by the VDAF. 
// Make sure we get the right error variant but otherwise we aren't @@ -387,7 +384,7 @@ mod tests { async fn upload_prio3_http_status_code() { install_test_trace_subscriber(); let mut server = mockito::Server::new_async().await; - let client = setup_client(&mut server, Prio3::new_count(2).unwrap()); + let client = setup_client(&server, Prio3::new_count(2).unwrap()); let mocked_upload = server .mock( @@ -414,7 +411,7 @@ mod tests { async fn upload_problem_details() { install_test_trace_subscriber(); let mut server = mockito::Server::new_async().await; - let client = setup_client(&mut server, Prio3::new_count(2).unwrap()); + let client = setup_client(&server, Prio3::new_count(2).unwrap()); let mocked_upload = server .mock( @@ -455,8 +452,12 @@ mod tests { async fn upload_bad_time_precision() { install_test_trace_subscriber(); - let client_parameters = - ClientParameters::new(random(), Vec::new(), Duration::from_seconds(0)); + let client_parameters = ClientParameters::new( + random(), + "https://leader.endpoint".parse().unwrap(), + "https://helper.endpoint".parse().unwrap(), + Duration::from_seconds(0), + ); let client = Client::new( client_parameters, Prio3::new_count(2).unwrap(), @@ -472,9 +473,9 @@ mod tests { #[test] fn report_timestamp() { install_test_trace_subscriber(); - let mut server = mockito::Server::new(); + let server = mockito::Server::new(); let vdaf = Prio3::new_count(2).unwrap(); - let mut client = setup_client(&mut server, vdaf); + let mut client = setup_client(&server, vdaf); client.parameters.time_precision = Duration::from_seconds(100); client.clock = MockClock::new(Time::from_seconds_since_epoch(101)); diff --git a/collector/src/lib.rs b/collector/src/lib.rs index 11683a4c4..63d2fedc1 100644 --- a/collector/src/lib.rs +++ b/collector/src/lib.rs @@ -113,8 +113,6 @@ pub enum Error { Codec(#[from] prio::codec::CodecError), #[error("aggregate share decoding error")] AggregateShareDecode, - #[error("expected two aggregate shares, got {0}")] - AggregateShareCount(usize), #[error("VDAF error: {0}")] Vdaf(#[from] prio::vdaf::VdafError), #[error("HPKE error: {0}")] @@ -180,17 +178,14 @@ impl CollectorParameters { /// Creates a new set of collector task parameters. pub fn new( task_id: TaskId, - mut leader_endpoint: Url, + leader_endpoint: Url, authentication: AuthenticationToken, hpke_config: HpkeConfig, hpke_private_key: HpkePrivateKey, ) -> CollectorParameters { - // Ensure the provided leader endpoint ends with a slash. - url_ensure_trailing_slash(&mut leader_endpoint); - CollectorParameters { task_id, - leader_endpoint, + leader_endpoint: url_ensure_trailing_slash(leader_endpoint), authentication, hpke_config, hpke_private_key, @@ -241,8 +236,10 @@ struct CollectionJob where Q: QueryType, { - /// The collection job ID. - collection_job_id: CollectionJobId, + /// The URL provided by the leader aggregator, where the collect response will be available + /// upon completion. + #[derivative(Debug(format_with = "std::fmt::Display::fmt"))] + collection_job_url: Url, /// The collect request's query. query: Query, /// The aggregation parameter used in this collect request. 
@@ -252,12 +249,12 @@ where impl CollectionJob { fn new( - collection_job_id: CollectionJobId, + collection_job_url: Url, query: Query, aggregation_parameter: P, ) -> CollectionJob { CollectionJob { - collection_job_id, + collection_job_url, query, aggregation_parameter, } @@ -393,8 +390,7 @@ impl Collector { ) -> Result, Error> { let collect_request = CollectionReq::new(query.clone(), aggregation_parameter.get_encoded()); - let collection_job_id = random(); - let collection_job_url = self.parameters.collection_job_uri(collection_job_id)?; + let collection_job_url = self.parameters.collection_job_uri(random())?; let response_res = retry_http_request( self.parameters.http_request_retry_parameters.clone(), @@ -433,7 +429,7 @@ impl Collector { }; Ok(CollectionJob::new( - collection_job_id, + collection_job_url, query, aggregation_parameter.clone(), )) @@ -446,14 +442,13 @@ impl Collector { &self, job: &CollectionJob, ) -> Result, Error> { - let collection_job_url = self.parameters.collection_job_uri(job.collection_job_id)?; let response_res = retry_http_request( self.parameters.http_request_retry_parameters.clone(), || async { let (auth_header, auth_value) = self.parameters.authentication.request_authentication(); self.http_client - .post(collection_job_url.clone()) + .post(job.collection_job_url.clone()) .header(auth_header, auth_value) .send() .await @@ -501,41 +496,40 @@ impl Collector { } let collect_response = CollectionMessage::::get_decoded(&response.bytes().await?)?; - if collect_response.encrypted_aggregate_shares().len() != 2 { - return Err(Error::AggregateShareCount( - collect_response.encrypted_aggregate_shares().len(), - )); - } - let aggregate_shares_bytes = collect_response - .encrypted_aggregate_shares() - .iter() - .zip(&[Role::Leader, Role::Helper]) - .map(|(encrypted_aggregate_share, role)| { - hpke::open( - &self.parameters.hpke_config, - &self.parameters.hpke_private_key, - &HpkeApplicationInfo::new(&hpke::Label::AggregateShare, role, &Role::Collector), - encrypted_aggregate_share, - &AggregateShareAad::new( - self.parameters.task_id, - BatchSelector::::new(Q::batch_identifier_for_collection( - &job.query, - &collect_response, - )), - ) - .get_encoded(), - ) - }); - let aggregate_shares = aggregate_shares_bytes - .map(|bytes| { - V::AggregateShare::get_decoded_with_param( - &(&self.vdaf_collector, &job.aggregation_parameter), - &bytes?, + let aggregate_shares = [ + ( + Role::Leader, + collect_response.leader_encrypted_aggregate_share(), + ), + ( + Role::Helper, + collect_response.helper_encrypted_aggregate_share(), + ), + ] + .into_iter() + .map(|(role, encrypted_aggregate_share)| { + let bytes = hpke::open( + &self.parameters.hpke_config, + &self.parameters.hpke_private_key, + &HpkeApplicationInfo::new(&hpke::Label::AggregateShare, &role, &Role::Collector), + encrypted_aggregate_share, + &AggregateShareAad::new( + self.parameters.task_id, + BatchSelector::::new(Q::batch_identifier_for_collection( + &job.query, + &collect_response, + )), ) - .map_err(|_err| Error::AggregateShareDecode) - }) - .collect::, Error>>()?; + .get_encoded(), + )?; + V::AggregateShare::get_decoded_with_param( + &(&self.vdaf_collector, &job.aggregation_parameter), + &bytes, + ) + .map_err(|_err| Error::AggregateShareDecode) + }) + .collect::, Error>>()?; let report_count = collect_response .report_count() @@ -634,6 +628,29 @@ impl Collector { } } +#[cfg(feature = "test-util")] +#[cfg_attr(docsrs, doc(cfg(feature = "test-util")))] +pub mod test_util { + use crate::{Collection, Collector, Error}; + 
use janus_messages::{query_type::QueryType, Query}; + use prio::vdaf; + + pub async fn collect_with_rewritten_url( + collector: &Collector, + query: Query, + aggregation_parameter: &V::AggregationParam, + host: &str, + port: u16, + ) -> Result, Error> { + let mut job = collector + .start_collection(query, aggregation_parameter) + .await?; + job.collection_job_url.set_host(Some(host))?; + job.collection_job_url.set_port(Some(port)).unwrap(); + collector.poll_until_complete(&job).await + } +} + #[cfg(test)] mod tests { use crate::{ @@ -711,30 +728,20 @@ mod tests { PartialBatchSelector::new_time_interval(), 1, batch_interval, - Vec::::from([ - hpke::seal( - ¶meters.hpke_config, - &HpkeApplicationInfo::new( - &Label::AggregateShare, - &Role::Leader, - &Role::Collector, - ), - &transcript.aggregate_shares[0].get_encoded(), - &associated_data.get_encoded(), - ) - .unwrap(), - hpke::seal( - ¶meters.hpke_config, - &HpkeApplicationInfo::new( - &Label::AggregateShare, - &Role::Helper, - &Role::Collector, - ), - &transcript.aggregate_shares[1].get_encoded(), - &associated_data.get_encoded(), - ) - .unwrap(), - ]), + hpke::seal( + ¶meters.hpke_config, + &HpkeApplicationInfo::new(&Label::AggregateShare, &Role::Leader, &Role::Collector), + &transcript.aggregate_shares[0].get_encoded(), + &associated_data.get_encoded(), + ) + .unwrap(), + hpke::seal( + ¶meters.hpke_config, + &HpkeApplicationInfo::new(&Label::AggregateShare, &Role::Helper, &Role::Collector), + &transcript.aggregate_shares[1].get_encoded(), + &associated_data.get_encoded(), + ) + .unwrap(), ) } @@ -749,30 +756,20 @@ mod tests { PartialBatchSelector::new_fixed_size(batch_id), 1, Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)).unwrap(), - Vec::::from([ - hpke::seal( - ¶meters.hpke_config, - &HpkeApplicationInfo::new( - &Label::AggregateShare, - &Role::Leader, - &Role::Collector, - ), - &transcript.aggregate_shares[0].get_encoded(), - &associated_data.get_encoded(), - ) - .unwrap(), - hpke::seal( - ¶meters.hpke_config, - &HpkeApplicationInfo::new( - &Label::AggregateShare, - &Role::Helper, - &Role::Collector, - ), - &transcript.aggregate_shares[1].get_encoded(), - &associated_data.get_encoded(), - ) - .unwrap(), - ]), + hpke::seal( + ¶meters.hpke_config, + &HpkeApplicationInfo::new(&Label::AggregateShare, &Role::Leader, &Role::Collector), + &transcript.aggregate_shares[0].get_encoded(), + &associated_data.get_encoded(), + ) + .unwrap(), + hpke::seal( + ¶meters.hpke_config, + &HpkeApplicationInfo::new(&Label::AggregateShare, &Role::Helper, &Role::Collector), + &transcript.aggregate_shares[1].get_encoded(), + &associated_data.get_encoded(), + ) + .unwrap(), ) } @@ -857,24 +854,20 @@ mod tests { let job = job.unwrap(); assert_eq!(job.query.batch_interval(), &batch_interval); - let collection_job_path = format!( - "/tasks/{}/collection_jobs/{}", - collector.parameters.task_id, job.collection_job_id - ); let mocked_collect_error = server - .mock("POST", collection_job_path.as_str()) + .mock("POST", job.collection_job_url.path()) .with_status(500) .expect(1) .create_async() .await; let mocked_collect_accepted = server - .mock("POST", collection_job_path.as_str()) + .mock("POST", job.collection_job_url.path()) .with_status(202) .expect(2) .create_async() .await; let mocked_collect_complete = server - .mock("POST", collection_job_path.as_str()) + .mock("POST", job.collection_job_url.path()) .match_header(auth_header, auth_value.as_str()) .with_status(200) .with_header( @@ -946,12 +939,8 @@ mod tests { 
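// Editorial sketch, not part of the patch: intended use of the new
// `test_util::collect_with_rewritten_url` helper defined above. The collector, query,
// aggregation parameter, and `leader_port` are assumed to be supplied by the caller,
// e.g. an integration test talking to a port-forwarded leader.
async fn collect_via_localhost<V: prio::vdaf::Collector, Q: janus_messages::query_type::QueryType>(
    collector: &janus_collector::Collector<V>,
    query: janus_messages::Query<Q>,
    aggregation_parameter: &V::AggregationParam,
    leader_port: u16,
) -> Result<janus_collector::Collection<V::AggregateResult, Q>, janus_collector::Error> {
    // The helper starts the collection job, rewrites the leader-supplied collection job
    // URL (now stored on `CollectionJob`) to point at 127.0.0.1:leader_port, then polls
    // that URL until the job completes.
    janus_collector::test_util::collect_with_rewritten_url(
        collector,
        query,
        aggregation_parameter,
        "127.0.0.1",
        leader_port,
    )
    .await
}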
assert_eq!(job.query.batch_interval(), &batch_interval); mocked_collect_start_success.assert_async().await; - let collection_job_path = format!( - "/tasks/{}/collection_jobs/{}", - collector.parameters.task_id, job.collection_job_id - ); let mocked_collect_complete = server - .mock("POST", collection_job_path.as_str()) + .mock("POST", job.collection_job_url.path()) .with_status(200) .with_header( CONTENT_TYPE.as_str(), @@ -1018,12 +1007,8 @@ mod tests { mocked_collect_start_success.assert_async().await; - let collection_job_path = format!( - "/tasks/{}/collection_jobs/{}", - collector.parameters.task_id, job.collection_job_id - ); let mocked_collect_complete = server - .mock("POST", collection_job_path.as_str()) + .mock("POST", job.collection_job_url.path()) .with_status(200) .with_header( CONTENT_TYPE.as_str(), @@ -1099,12 +1084,8 @@ mod tests { mocked_collect_start_success.assert_async().await; - let collection_job_path = format!( - "/tasks/{}/collection_jobs/{}", - collector.parameters.task_id, job.collection_job_id - ); let mocked_collect_complete = server - .mock("POST", collection_job_path.as_str()) + .mock("POST", job.collection_job_url.path()) .with_status(200) .with_header( CONTENT_TYPE.as_str(), @@ -1173,12 +1154,8 @@ mod tests { mocked_collect_start_success.assert_async().await; - let collection_job_path = format!( - "/tasks/{}/collection_jobs/{}", - collector.parameters.task_id, job.collection_job_id - ); let mocked_collect_complete = server - .mock("POST", collection_job_path.as_str()) + .mock("POST", job.collection_job_url.path()) .with_status(200) .with_header( CONTENT_TYPE.as_str(), @@ -1257,12 +1234,8 @@ mod tests { let job = job.unwrap(); assert_eq!(job.query.batch_interval(), &batch_interval); - let collection_job_path = format!( - "/tasks/{}/collection_jobs/{}", - collector.parameters.task_id, job.collection_job_id - ); let mocked_collect_complete = server - .mock("POST", collection_job_path.as_str()) + .mock("POST", job.collection_job_url.path()) .match_header(AUTHORIZATION.as_str(), "Bearer AAAAAAAAAAAAAAAA") .with_status(200) .with_header( @@ -1428,12 +1401,8 @@ mod tests { mock_collect_start.assert_async().await; mock_collection_job_server_error.assert_async().await; - let collection_job_path = format!( - "/tasks/{}/collection_jobs/{}", - collector.parameters.task_id, job.collection_job_id - ); let mock_collection_job_server_error_details = server - .mock("POST", collection_job_path.as_str()) + .mock("POST", job.collection_job_url.path()) .with_status(500) .with_header("Content-Type", "application/problem+json") .with_body("{\"type\": \"http://example.com/test_server_error\"}") @@ -1453,7 +1422,7 @@ mod tests { .await; let mock_collection_job_bad_request = server - .mock("POST", collection_job_path.as_str()) + .mock("POST", job.collection_job_url.path()) .with_status(400) .with_header("Content-Type", "application/problem+json") .with_body(concat!( @@ -1476,7 +1445,7 @@ mod tests { mock_collection_job_bad_request.assert_async().await; let mock_collection_job_bad_message_bytes = server - .mock("POST", collection_job_path.as_str()) + .mock("POST", job.collection_job_url.path()) .with_status(200) .with_header( CONTENT_TYPE.as_str(), @@ -1492,33 +1461,8 @@ mod tests { mock_collection_job_bad_message_bytes.assert_async().await; - let mock_collection_job_bad_share_count = server - .mock("POST", collection_job_path.as_str()) - .with_status(200) - .with_header( - CONTENT_TYPE.as_str(), - CollectionMessage::::MEDIA_TYPE, - ) - .with_body( - CollectionMessage::new( - 
PartialBatchSelector::new_time_interval(), - 0, - batch_interval, - Vec::new(), - ) - .get_encoded(), - ) - .expect_at_least(1) - .create_async() - .await; - - let error = collector.poll_once(&job).await.unwrap_err(); - assert_matches!(error, Error::AggregateShareCount(0)); - - mock_collection_job_bad_share_count.assert_async().await; - let mock_collection_job_bad_ciphertext = server - .mock("POST", collection_job_path.as_str()) + .mock("POST", job.collection_job_url.path()) .with_status(200) .with_header( CONTENT_TYPE.as_str(), @@ -1529,18 +1473,16 @@ mod tests { PartialBatchSelector::new_time_interval(), 1, batch_interval, - Vec::from([ - HpkeCiphertext::new( - *collector.parameters.hpke_config.id(), - Vec::new(), - Vec::new(), - ), - HpkeCiphertext::new( - *collector.parameters.hpke_config.id(), - Vec::new(), - Vec::new(), - ), - ]), + HpkeCiphertext::new( + *collector.parameters.hpke_config.id(), + Vec::new(), + Vec::new(), + ), + HpkeCiphertext::new( + *collector.parameters.hpke_config.id(), + Vec::new(), + Vec::new(), + ), ) .get_encoded(), ) @@ -1561,33 +1503,23 @@ mod tests { PartialBatchSelector::new_time_interval(), 1, batch_interval, - Vec::from([ - hpke::seal( - &collector.parameters.hpke_config, - &HpkeApplicationInfo::new( - &Label::AggregateShare, - &Role::Leader, - &Role::Collector, - ), - b"bad", - &associated_data.get_encoded(), - ) - .unwrap(), - hpke::seal( - &collector.parameters.hpke_config, - &HpkeApplicationInfo::new( - &Label::AggregateShare, - &Role::Helper, - &Role::Collector, - ), - b"bad", - &associated_data.get_encoded(), - ) - .unwrap(), - ]), + hpke::seal( + &collector.parameters.hpke_config, + &HpkeApplicationInfo::new(&Label::AggregateShare, &Role::Leader, &Role::Collector), + b"bad", + &associated_data.get_encoded(), + ) + .unwrap(), + hpke::seal( + &collector.parameters.hpke_config, + &HpkeApplicationInfo::new(&Label::AggregateShare, &Role::Helper, &Role::Collector), + b"bad", + &associated_data.get_encoded(), + ) + .unwrap(), ); let mock_collection_job_bad_shares = server - .mock("POST", collection_job_path.as_str()) + .mock("POST", job.collection_job_url.path()) .with_status(200) .with_header( CONTENT_TYPE.as_str(), @@ -1607,38 +1539,28 @@ mod tests { PartialBatchSelector::new_time_interval(), 1, batch_interval, - Vec::from([ - hpke::seal( - &collector.parameters.hpke_config, - &HpkeApplicationInfo::new( - &Label::AggregateShare, - &Role::Leader, - &Role::Collector, - ), - &AggregateShare::from(OutputShare::from(Vec::from([Field64::from(0)]))) - .get_encoded(), - &associated_data.get_encoded(), - ) - .unwrap(), - hpke::seal( - &collector.parameters.hpke_config, - &HpkeApplicationInfo::new( - &Label::AggregateShare, - &Role::Helper, - &Role::Collector, - ), - &AggregateShare::from(OutputShare::from(Vec::from([ - Field64::from(0), - Field64::from(0), - ]))) + hpke::seal( + &collector.parameters.hpke_config, + &HpkeApplicationInfo::new(&Label::AggregateShare, &Role::Leader, &Role::Collector), + &AggregateShare::from(OutputShare::from(Vec::from([Field64::from(0)]))) .get_encoded(), - &associated_data.get_encoded(), - ) - .unwrap(), - ]), + &associated_data.get_encoded(), + ) + .unwrap(), + hpke::seal( + &collector.parameters.hpke_config, + &HpkeApplicationInfo::new(&Label::AggregateShare, &Role::Helper, &Role::Collector), + &AggregateShare::from(OutputShare::from(Vec::from([ + Field64::from(0), + Field64::from(0), + ]))) + .get_encoded(), + &associated_data.get_encoded(), + ) + .unwrap(), ); let mock_collection_job_wrong_length = server - .mock("POST", 
collection_job_path.as_str()) + .mock("POST", job.collection_job_url.path()) .with_status(200) .with_header( CONTENT_TYPE.as_str(), @@ -1655,7 +1577,7 @@ mod tests { mock_collection_job_wrong_length.assert_async().await; let mock_collection_job_always_fail = server - .mock("POST", collection_job_path.as_str()) + .mock("POST", job.collection_job_url.path()) .with_status(500) .expect_at_least(3) .create_async() @@ -1697,12 +1619,8 @@ mod tests { .unwrap(); mock_collect_start.assert_async().await; - let collection_job_path = format!( - "/tasks/{}/collection_jobs/{}", - collector.parameters.task_id, job.collection_job_id - ); let mock_collect_poll_no_retry_after = server - .mock("POST", collection_job_path.as_str()) + .mock("POST", job.collection_job_url.path()) .with_status(202) .expect(1) .create_async() @@ -1714,7 +1632,7 @@ mod tests { mock_collect_poll_no_retry_after.assert_async().await; let mock_collect_poll_retry_after_60s = server - .mock("POST", collection_job_path.as_str()) + .mock("POST", job.collection_job_url.path()) .with_status(202) .with_header("Retry-After", "60") .expect(1) @@ -1727,7 +1645,7 @@ mod tests { mock_collect_poll_retry_after_60s.assert_async().await; let mock_collect_poll_retry_after_date_time = server - .mock("POST", collection_job_path.as_str()) + .mock("POST", job.collection_job_url.path()) .with_status(202) .with_header("Retry-After", "Wed, 21 Oct 2015 07:28:00 GMT") .expect(1) @@ -1761,13 +1679,14 @@ mod tests { collector.parameters.task_id ); + let collection_job_url = format!("{}{collection_job_path}", server.url()); let batch_interval = Interval::new( Time::from_seconds_since_epoch(1_000_000), Duration::from_seconds(3600), ) .unwrap(); let job = CollectionJob::new( - collection_job_id, + collection_job_url.parse().unwrap(), Query::new_time_interval(batch_interval), (), ); diff --git a/core/Cargo.toml b/core/Cargo.toml index 7443470ea..2d90048ed 100644 --- a/core/Cargo.toml +++ b/core/Cargo.toml @@ -49,7 +49,7 @@ k8s-openapi = { workspace = true, optional = true } lazy_static = { version = "1", optional = true } prio.workspace = true rand = "0.8" -reqwest = { version = "0.11.19", default-features = false, features = ["rustls-tls", "json"] } +reqwest = { version = "0.11.18", default-features = false, features = ["rustls-tls", "json"] } ring = "0.16.20" serde.workspace = true serde_json = { workspace = true, optional = true } @@ -64,13 +64,12 @@ tracing = "0.1.37" tracing-log = { version = "0.1.3", optional = true } tracing-subscriber = { version = "0.3", features = ["std", "env-filter", "fmt"], optional = true } trillium.workspace = true +url = "2.4.0" [dev-dependencies] fixed = "1.23" hex = { version = "0.4", features = ["serde"] } # ensure this remains compatible with the non-dev dependency janus_core = { path = ".", features = ["test-util"] } -kube.workspace = true mockito = "1.1.0" rstest.workspace = true serde_test.workspace = true -url = "2.4.0" diff --git a/core/src/task.rs b/core/src/task.rs index f7de447eb..80e280e63 100644 --- a/core/src/task.rs +++ b/core/src/task.rs @@ -3,16 +3,16 @@ use derivative::Derivative; use http::header::AUTHORIZATION; use janus_messages::taskprov; use rand::{distributions::Standard, prelude::Distribution}; -use reqwest::Url; use ring::constant_time; use serde::{de::Error, Deserialize, Deserializer, Serialize}; use std::str; +use url::Url; /// HTTP header where auth tokens are provided in messages between participants. 
pub const DAP_AUTH_HEADER: &str = "DAP-Auth-Token"; -/// The length of the verify key parameter for Prio3 VDAF instantiations. -pub const PRIO3_VERIFY_KEY_LENGTH: usize = 16; +/// The length of the verify key parameter for Prio3 & Poplar1 VDAF instantiations. +pub const VERIFY_KEY_LENGTH: usize = 16; /// Identifiers for supported VDAFs, corresponding to definitions in /// [draft-irtf-cfrg-vdaf-03][1] and implementations in [`prio::vdaf::prio3`]. @@ -67,9 +67,8 @@ impl VdafInstance { | VdafInstance::FakeFailsPrepInit | VdafInstance::FakeFailsPrepStep => 0, - // All "real" VDAFs use a verify key of length 16 currently. (Poplar1 may not, but it's - // not yet done being specified, so choosing 16 bytes is fine for testing.) - _ => PRIO3_VERIFY_KEY_LENGTH, + // All "real" VDAFs use a verify key of length 16 currently. + _ => VERIFY_KEY_LENGTH, } } } @@ -101,35 +100,41 @@ impl TryFrom<&taskprov::VdafType> for VdafInstance { #[macro_export] macro_rules! vdaf_dispatch_impl_base { // Provide the dispatched type only, don't construct a VDAF instance. - (impl match base $vdaf_instance:expr, (_, $Vdaf:ident, $VERIFY_KEY_LENGTH:ident) => $body:tt) => { + (impl match base $vdaf_instance:expr, (_, $Vdaf:ident, $VERIFY_KEY_LEN:ident) => $body:tt) => { match $vdaf_instance { ::janus_core::task::VdafInstance::Prio3Count => { type $Vdaf = ::prio::vdaf::prio3::Prio3Count; - const $VERIFY_KEY_LENGTH: usize = ::janus_core::task::PRIO3_VERIFY_KEY_LENGTH; + const $VERIFY_KEY_LEN: usize = ::janus_core::task::VERIFY_KEY_LENGTH; $body } ::janus_core::task::VdafInstance::Prio3CountVec { length } => { type $Vdaf = ::prio::vdaf::prio3::Prio3SumVecMultithreaded; - const $VERIFY_KEY_LENGTH: usize = ::janus_core::task::PRIO3_VERIFY_KEY_LENGTH; + const $VERIFY_KEY_LEN: usize = ::janus_core::task::VERIFY_KEY_LENGTH; $body } ::janus_core::task::VdafInstance::Prio3Sum { bits } => { type $Vdaf = ::prio::vdaf::prio3::Prio3Sum; - const $VERIFY_KEY_LENGTH: usize = ::janus_core::task::PRIO3_VERIFY_KEY_LENGTH; + const $VERIFY_KEY_LEN: usize = ::janus_core::task::VERIFY_KEY_LENGTH; $body } ::janus_core::task::VdafInstance::Prio3SumVec { bits, length } => { type $Vdaf = ::prio::vdaf::prio3::Prio3SumVecMultithreaded; - const $VERIFY_KEY_LENGTH: usize = ::janus_core::task::PRIO3_VERIFY_KEY_LENGTH; + const $VERIFY_KEY_LEN: usize = ::janus_core::task::VERIFY_KEY_LENGTH; $body } ::janus_core::task::VdafInstance::Prio3Histogram { buckets } => { type $Vdaf = ::prio::vdaf::prio3::Prio3Histogram; - const $VERIFY_KEY_LENGTH: usize = ::janus_core::task::PRIO3_VERIFY_KEY_LENGTH; + const $VERIFY_KEY_LEN: usize = ::janus_core::task::VERIFY_KEY_LENGTH; + $body + } + + ::janus_core::task::VdafInstance::Poplar1 { bits } => { + type $Vdaf = ::prio::vdaf::poplar1::Poplar1<::prio::vdaf::prg::PrgSha3, 16>; + const $VERIFY_KEY_LEN: usize = ::janus_core::task::VERIFY_KEY_LENGTH; $body } @@ -138,12 +143,12 @@ macro_rules! vdaf_dispatch_impl_base { }; // Construct a VDAF instance, and provide that to the block as well. 
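// Editorial sketch, not part of the patch: the concrete bindings the new `Poplar1`
// dispatch arm above produces, using the paths shown in the hunks; `bits = 8` is an
// arbitrary example value.
fn poplar1_bindings_example() {
    type Vdaf = prio::vdaf::poplar1::Poplar1<prio::vdaf::prg::PrgSha3, 16>;
    const VERIFY_KEY_LEN: usize = janus_core::task::VERIFY_KEY_LENGTH;
    let vdaf: Vdaf = prio::vdaf::poplar1::Poplar1::new_sha3(8);
    // Poplar1 now shares the uniform 16-byte verify key length with the Prio3 VDAFs,
    // which is what lets the renamed VERIFY_KEY_LENGTH constant cover both.
    let verify_key = rand::random::<[u8; VERIFY_KEY_LEN]>();
    let _ = (vdaf, verify_key);
}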
- (impl match base $vdaf_instance:expr, ($vdaf:ident, $Vdaf:ident, $VERIFY_KEY_LENGTH:ident) => $body:tt) => { + (impl match base $vdaf_instance:expr, ($vdaf:ident, $Vdaf:ident, $VERIFY_KEY_LEN:ident) => $body:tt) => { match $vdaf_instance { ::janus_core::task::VdafInstance::Prio3Count => { let $vdaf = ::prio::vdaf::prio3::Prio3::new_count(2)?; type $Vdaf = ::prio::vdaf::prio3::Prio3Count; - const $VERIFY_KEY_LENGTH: usize = ::janus_core::task::PRIO3_VERIFY_KEY_LENGTH; + const $VERIFY_KEY_LEN: usize = ::janus_core::task::VERIFY_KEY_LENGTH; $body } @@ -151,14 +156,14 @@ macro_rules! vdaf_dispatch_impl_base { // Prio3CountVec is implemented as a 1-bit sum vec let $vdaf = ::prio::vdaf::prio3::Prio3::new_sum_vec_multithreaded(2, 1, *length)?; type $Vdaf = ::prio::vdaf::prio3::Prio3SumVecMultithreaded; - const $VERIFY_KEY_LENGTH: usize = ::janus_core::task::PRIO3_VERIFY_KEY_LENGTH; + const $VERIFY_KEY_LEN: usize = ::janus_core::task::VERIFY_KEY_LENGTH; $body } ::janus_core::task::VdafInstance::Prio3Sum { bits } => { let $vdaf = ::prio::vdaf::prio3::Prio3::new_sum(2, *bits)?; type $Vdaf = ::prio::vdaf::prio3::Prio3Sum; - const $VERIFY_KEY_LENGTH: usize = ::janus_core::task::PRIO3_VERIFY_KEY_LENGTH; + const $VERIFY_KEY_LEN: usize = ::janus_core::task::VERIFY_KEY_LENGTH; $body } @@ -166,14 +171,21 @@ macro_rules! vdaf_dispatch_impl_base { let $vdaf = ::prio::vdaf::prio3::Prio3::new_sum_vec_multithreaded(2, *bits, *length)?; type $Vdaf = ::prio::vdaf::prio3::Prio3SumVecMultithreaded; - const $VERIFY_KEY_LENGTH: usize = ::janus_core::task::PRIO3_VERIFY_KEY_LENGTH; + const $VERIFY_KEY_LEN: usize = ::janus_core::task::VERIFY_KEY_LENGTH; $body } ::janus_core::task::VdafInstance::Prio3Histogram { length } => { let $vdaf = ::prio::vdaf::prio3::Prio3::new_histogram(2, *length)?; type $Vdaf = ::prio::vdaf::prio3::Prio3Histogram; - const $VERIFY_KEY_LENGTH: usize = ::janus_core::task::PRIO3_VERIFY_KEY_LENGTH; + const $VERIFY_KEY_LEN: usize = ::janus_core::task::VERIFY_KEY_LENGTH; + $body + } + + ::janus_core::task::VdafInstance::Poplar1 { bits } => { + let $vdaf = ::prio::vdaf::poplar1::Poplar1::new_sha3(*bits); + type $Vdaf = ::prio::vdaf::poplar1::Poplar1<::prio::vdaf::prg::PrgSha3, 16>; + const $VERIFY_KEY_LEN: usize = ::janus_core::task::VERIFY_KEY_LENGTH; $body } @@ -187,13 +199,13 @@ macro_rules! vdaf_dispatch_impl_base { #[macro_export] macro_rules! vdaf_dispatch_impl_fpvec_bounded_l2 { // Provide the dispatched type only, don't construct a VDAF instance. - (impl match fpvec_bounded_l2 $vdaf_instance:expr, (_, $Vdaf:ident, $VERIFY_KEY_LENGTH:ident) => $body:tt) => { + (impl match fpvec_bounded_l2 $vdaf_instance:expr, (_, $Vdaf:ident, $VERIFY_KEY_LEN:ident) => $body:tt) => { match $vdaf_instance { ::janus_core::task::VdafInstance::Prio3FixedPoint16BitBoundedL2VecSum { length } => { type $Vdaf = ::prio::vdaf::prio3::Prio3FixedPointBoundedL2VecSumMultithreaded< ::fixed::FixedI16<::fixed::types::extra::U15>, >; - const $VERIFY_KEY_LENGTH: usize = ::janus_core::task::PRIO3_VERIFY_KEY_LENGTH; + const $VERIFY_KEY_LEN: usize = ::janus_core::task::VERIFY_KEY_LENGTH; $body } @@ -201,7 +213,7 @@ macro_rules! vdaf_dispatch_impl_fpvec_bounded_l2 { type $Vdaf = ::prio::vdaf::prio3::Prio3FixedPointBoundedL2VecSumMultithreaded< ::fixed::FixedI32<::fixed::types::extra::U31>, >; - const $VERIFY_KEY_LENGTH: usize = ::janus_core::task::PRIO3_VERIFY_KEY_LENGTH; + const $VERIFY_KEY_LEN: usize = ::janus_core::task::VERIFY_KEY_LENGTH; $body } @@ -209,7 +221,7 @@ macro_rules! 
vdaf_dispatch_impl_fpvec_bounded_l2 { type $Vdaf = ::prio::vdaf::prio3::Prio3FixedPointBoundedL2VecSumMultithreaded< ::fixed::FixedI64<::fixed::types::extra::U63>, >; - const $VERIFY_KEY_LENGTH: usize = ::janus_core::task::PRIO3_VERIFY_KEY_LENGTH; + const $VERIFY_KEY_LEN: usize = ::janus_core::task::VERIFY_KEY_LENGTH; $body } @@ -218,7 +230,7 @@ macro_rules! vdaf_dispatch_impl_fpvec_bounded_l2 { }; // Construct a VDAF instance, and provide that to the block as well. - (impl match fpvec_bounded_l2 $vdaf_instance:expr, ($vdaf:ident, $Vdaf:ident, $VERIFY_KEY_LENGTH:ident) => $body:tt) => { + (impl match fpvec_bounded_l2 $vdaf_instance:expr, ($vdaf:ident, $Vdaf:ident, $VERIFY_KEY_LEN:ident) => $body:tt) => { match $vdaf_instance { ::janus_core::task::VdafInstance::Prio3FixedPoint16BitBoundedL2VecSum { length } => { let $vdaf = @@ -228,7 +240,7 @@ macro_rules! vdaf_dispatch_impl_fpvec_bounded_l2 { type $Vdaf = ::prio::vdaf::prio3::Prio3FixedPointBoundedL2VecSumMultithreaded< ::fixed::FixedI16<::fixed::types::extra::U15>, >; - const $VERIFY_KEY_LENGTH: usize = ::janus_core::task::PRIO3_VERIFY_KEY_LENGTH; + const $VERIFY_KEY_LEN: usize = ::janus_core::task::VERIFY_KEY_LENGTH; $body } @@ -240,7 +252,7 @@ macro_rules! vdaf_dispatch_impl_fpvec_bounded_l2 { type $Vdaf = ::prio::vdaf::prio3::Prio3FixedPointBoundedL2VecSumMultithreaded< ::fixed::FixedI32<::fixed::types::extra::U31>, >; - const $VERIFY_KEY_LENGTH: usize = ::janus_core::task::PRIO3_VERIFY_KEY_LENGTH; + const $VERIFY_KEY_LEN: usize = ::janus_core::task::VERIFY_KEY_LENGTH; $body } @@ -252,7 +264,7 @@ macro_rules! vdaf_dispatch_impl_fpvec_bounded_l2 { type $Vdaf = ::prio::vdaf::prio3::Prio3FixedPointBoundedL2VecSumMultithreaded< ::fixed::FixedI64<::fixed::types::extra::U63>, >; - const $VERIFY_KEY_LENGTH: usize = ::janus_core::task::PRIO3_VERIFY_KEY_LENGTH; + const $VERIFY_KEY_LEN: usize = ::janus_core::task::VERIFY_KEY_LENGTH; $body } @@ -266,23 +278,23 @@ macro_rules! vdaf_dispatch_impl_fpvec_bounded_l2 { #[macro_export] macro_rules! vdaf_dispatch_impl_test_util { // Provide the dispatched type only, don't construct a VDAF instance. - (impl match test_util $vdaf_instance:expr, (_, $Vdaf:ident, $VERIFY_KEY_LENGTH:ident) => $body:tt) => { + (impl match test_util $vdaf_instance:expr, (_, $Vdaf:ident, $VERIFY_KEY_LEN:ident) => $body:tt) => { match $vdaf_instance { ::janus_core::task::VdafInstance::Fake => { type $Vdaf = ::janus_core::test_util::dummy_vdaf::Vdaf; - const $VERIFY_KEY_LENGTH: usize = 0; + const $VERIFY_KEY_LEN: usize = 0; $body } ::janus_core::task::VdafInstance::FakeFailsPrepInit => { type $Vdaf = ::janus_core::test_util::dummy_vdaf::Vdaf; - const $VERIFY_KEY_LENGTH: usize = 0; + const $VERIFY_KEY_LEN: usize = 0; $body } ::janus_core::task::VdafInstance::FakeFailsPrepStep => { type $Vdaf = ::janus_core::test_util::dummy_vdaf::Vdaf; - const $VERIFY_KEY_LENGTH: usize = 0; + const $VERIFY_KEY_LEN: usize = 0; $body } @@ -291,12 +303,12 @@ macro_rules! vdaf_dispatch_impl_test_util { }; // Construct a VDAF instance, and provide that to the block as well. 
- (impl match test_util $vdaf_instance:expr, ($vdaf:ident, $Vdaf:ident, $VERIFY_KEY_LENGTH:ident) => $body:tt) => { + (impl match test_util $vdaf_instance:expr, ($vdaf:ident, $Vdaf:ident, $VERIFY_KEY_LEN:ident) => $body:tt) => { match $vdaf_instance { ::janus_core::task::VdafInstance::Fake => { let $vdaf = ::janus_core::test_util::dummy_vdaf::Vdaf::new(); type $Vdaf = ::janus_core::test_util::dummy_vdaf::Vdaf; - const $VERIFY_KEY_LENGTH: usize = 0; + const $VERIFY_KEY_LEN: usize = 0; $body } @@ -306,26 +318,30 @@ macro_rules! vdaf_dispatch_impl_test_util { ::std::result::Result::Err(::prio::vdaf::VdafError::Uncategorized( "FakeFailsPrepInit failed at prep_init".to_string(), )) - } + }, ); type $Vdaf = ::janus_core::test_util::dummy_vdaf::Vdaf; - const $VERIFY_KEY_LENGTH: usize = 0; + const $VERIFY_KEY_LEN: usize = 0; $body } ::janus_core::task::VdafInstance::FakeFailsPrepStep => { let $vdaf = ::janus_core::test_util::dummy_vdaf::Vdaf::new().with_prep_step_fn( - || -> Result< - ::prio::vdaf::PrepareTransition<::janus_core::test_util::dummy_vdaf::Vdaf, 0, 16>, - ::prio::vdaf::VdafError, - > { - ::std::result::Result::Err(::prio::vdaf::VdafError::Uncategorized( - "FakeFailsPrepStep failed at prep_step".to_string(), - )) - } - ); + || -> Result< + ::prio::vdaf::PrepareTransition< + ::janus_core::test_util::dummy_vdaf::Vdaf, + 0, + 16, + >, + ::prio::vdaf::VdafError, + > { + ::std::result::Result::Err(::prio::vdaf::VdafError::Uncategorized( + "FakeFailsPrepStep failed at prep_step".to_string(), + )) + }, + ); type $Vdaf = ::janus_core::test_util::dummy_vdaf::Vdaf; - const $VERIFY_KEY_LENGTH: usize = 0; + const $VERIFY_KEY_LEN: usize = 0; $body } @@ -339,26 +355,27 @@ macro_rules! vdaf_dispatch_impl_test_util { #[macro_export] macro_rules! vdaf_dispatch_impl { // Provide the dispatched type only, don't construct a VDAF instance. - (impl match all $vdaf_instance:expr, (_, $Vdaf:ident, $VERIFY_KEY_LENGTH:ident) => $body:tt) => { + (impl match all $vdaf_instance:expr, (_, $Vdaf:ident, $VERIFY_KEY_LEN:ident) => $body:tt) => { match $vdaf_instance { ::janus_core::task::VdafInstance::Prio3Count | ::janus_core::task::VdafInstance::Prio3CountVec { .. } | ::janus_core::task::VdafInstance::Prio3Sum { .. } | ::janus_core::task::VdafInstance::Prio3SumVec { .. } - | ::janus_core::task::VdafInstance::Prio3Histogram { .. } => { - ::janus_core::vdaf_dispatch_impl_base!(impl match base $vdaf_instance, (_, $Vdaf, $VERIFY_KEY_LENGTH) => $body) + | ::janus_core::task::VdafInstance::Prio3Histogram { .. } + | ::janus_core::task::VdafInstance::Poplar1 { .. } => { + ::janus_core::vdaf_dispatch_impl_base!(impl match base $vdaf_instance, (_, $Vdaf, $VERIFY_KEY_LEN) => $body) } ::janus_core::task::VdafInstance::Prio3FixedPoint16BitBoundedL2VecSum { .. } | ::janus_core::task::VdafInstance::Prio3FixedPoint32BitBoundedL2VecSum { .. } | ::janus_core::task::VdafInstance::Prio3FixedPoint64BitBoundedL2VecSum { .. 
} => { - ::janus_core::vdaf_dispatch_impl_fpvec_bounded_l2!(impl match fpvec_bounded_l2 $vdaf_instance, (_, $Vdaf, $VERIFY_KEY_LENGTH) => $body) + ::janus_core::vdaf_dispatch_impl_fpvec_bounded_l2!(impl match fpvec_bounded_l2 $vdaf_instance, (_, $Vdaf, $VERIFY_KEY_LEN) => $body) } ::janus_core::task::VdafInstance::Fake | ::janus_core::task::VdafInstance::FakeFailsPrepInit | ::janus_core::task::VdafInstance::FakeFailsPrepStep => { - ::janus_core::vdaf_dispatch_impl_test_util!(impl match test_util $vdaf_instance, (_, $Vdaf, $VERIFY_KEY_LENGTH) => $body) + ::janus_core::vdaf_dispatch_impl_test_util!(impl match test_util $vdaf_instance, (_, $Vdaf, $VERIFY_KEY_LEN) => $body) } _ => panic!("VDAF {:?} is not yet supported", $vdaf_instance), @@ -366,26 +383,27 @@ macro_rules! vdaf_dispatch_impl { }; // Construct a VDAF instance, and provide that to the block as well. - (impl match all $vdaf_instance:expr, ($vdaf:ident, $Vdaf:ident, $VERIFY_KEY_LENGTH:ident) => $body:tt) => { + (impl match all $vdaf_instance:expr, ($vdaf:ident, $Vdaf:ident, $VERIFY_KEY_LEN:ident) => $body:tt) => { match $vdaf_instance { ::janus_core::task::VdafInstance::Prio3Count | ::janus_core::task::VdafInstance::Prio3CountVec { .. } | ::janus_core::task::VdafInstance::Prio3Sum { .. } | ::janus_core::task::VdafInstance::Prio3SumVec { .. } - | ::janus_core::task::VdafInstance::Prio3Histogram { .. } => { - ::janus_core::vdaf_dispatch_impl_base!(impl match base $vdaf_instance, ($vdaf, $Vdaf, $VERIFY_KEY_LENGTH) => $body) + | ::janus_core::task::VdafInstance::Prio3Histogram { .. } + | ::janus_core::task::VdafInstance::Poplar1 { .. } => { + ::janus_core::vdaf_dispatch_impl_base!(impl match base $vdaf_instance, ($vdaf, $Vdaf, $VERIFY_KEY_LEN) => $body) } ::janus_core::task::VdafInstance::Prio3FixedPoint16BitBoundedL2VecSum { .. } | ::janus_core::task::VdafInstance::Prio3FixedPoint32BitBoundedL2VecSum { .. } | ::janus_core::task::VdafInstance::Prio3FixedPoint64BitBoundedL2VecSum { .. } => { - ::janus_core::vdaf_dispatch_impl_fpvec_bounded_l2!(impl match fpvec_bounded_l2 $vdaf_instance, ($vdaf, $Vdaf, $VERIFY_KEY_LENGTH) => $body) + ::janus_core::vdaf_dispatch_impl_fpvec_bounded_l2!(impl match fpvec_bounded_l2 $vdaf_instance, ($vdaf, $Vdaf, $VERIFY_KEY_LEN) => $body) } ::janus_core::task::VdafInstance::Fake | ::janus_core::task::VdafInstance::FakeFailsPrepInit | ::janus_core::task::VdafInstance::FakeFailsPrepStep => { - ::janus_core::vdaf_dispatch_impl_test_util!(impl match test_util $vdaf_instance, ($vdaf, $Vdaf, $VERIFY_KEY_LENGTH) => $body) + ::janus_core::vdaf_dispatch_impl_test_util!(impl match test_util $vdaf_instance, ($vdaf, $Vdaf, $VERIFY_KEY_LEN) => $body) } _ => panic!("VDAF {:?} is not yet supported", $vdaf_instance), @@ -398,20 +416,21 @@ macro_rules! vdaf_dispatch_impl { #[macro_export] macro_rules! vdaf_dispatch_impl { // Provide the dispatched type only, don't construct a VDAF instance. - (impl match all $vdaf_instance:expr, (_, $Vdaf:ident, $VERIFY_KEY_LENGTH:ident) => $body:tt) => { + (impl match all $vdaf_instance:expr, (_, $Vdaf:ident, $VERIFY_KEY_LEN:ident) => $body:tt) => { match $vdaf_instance { ::janus_core::task::VdafInstance::Prio3Count | ::janus_core::task::VdafInstance::Prio3CountVec { .. } | ::janus_core::task::VdafInstance::Prio3Sum { .. } | ::janus_core::task::VdafInstance::Prio3SumVec { .. } - | ::janus_core::task::VdafInstance::Prio3Histogram { .. 
} => { - ::janus_core::vdaf_dispatch_impl_base!(impl match base $vdaf_instance, (_, $Vdaf, $VERIFY_KEY_LENGTH) => $body) + | ::janus_core::task::VdafInstance::Prio3Histogram { .. } + | ::janus_core::task::VdafInstance::Poplar1 { .. } => { + ::janus_core::vdaf_dispatch_impl_base!(impl match base $vdaf_instance, (_, $Vdaf, $VERIFY_KEY_LEN) => $body) } ::janus_core::task::VdafInstance::Prio3FixedPoint16BitBoundedL2VecSum { .. } | ::janus_core::task::VdafInstance::Prio3FixedPoint32BitBoundedL2VecSum { .. } | ::janus_core::task::VdafInstance::Prio3FixedPoint64BitBoundedL2VecSum { .. } => { - ::janus_core::vdaf_dispatch_impl_fpvec_bounded_l2!(impl match fpvec_bounded_l2 $vdaf_instance, (_, $Vdaf, $VERIFY_KEY_LENGTH) => $body) + ::janus_core::vdaf_dispatch_impl_fpvec_bounded_l2!(impl match fpvec_bounded_l2 $vdaf_instance, (_, $Vdaf, $VERIFY_KEY_LEN) => $body) } _ => panic!("VDAF {:?} is not yet supported", $vdaf_instance), @@ -419,20 +438,21 @@ macro_rules! vdaf_dispatch_impl { }; // Construct a VDAF instance, and provide that to the block as well. - (impl match all $vdaf_instance:expr, ($vdaf:ident, $Vdaf:ident, $VERIFY_KEY_LENGTH:ident) => $body:tt) => { + (impl match all $vdaf_instance:expr, ($vdaf:ident, $Vdaf:ident, $VERIFY_KEY_LEN:ident) => $body:tt) => { match $vdaf_instance { ::janus_core::task::VdafInstance::Prio3Count | ::janus_core::task::VdafInstance::Prio3CountVec { .. } | ::janus_core::task::VdafInstance::Prio3Sum { .. } | ::janus_core::task::VdafInstance::Prio3SumVec { .. } - | ::janus_core::task::VdafInstance::Prio3Histogram { .. } => { - ::janus_core::vdaf_dispatch_impl_base!(impl match base $vdaf_instance, ($vdaf, $Vdaf, $VERIFY_KEY_LENGTH) => $body) + | ::janus_core::task::VdafInstance::Prio3Histogram { .. } + | ::janus_core::task::VdafInstance::Poplar1 { .. } => { + ::janus_core::vdaf_dispatch_impl_base!(impl match base $vdaf_instance, ($vdaf, $Vdaf, $VERIFY_KEY_LEN) => $body) } ::janus_core::task::VdafInstance::Prio3FixedPoint16BitBoundedL2VecSum { .. } | ::janus_core::task::VdafInstance::Prio3FixedPoint32BitBoundedL2VecSum { .. } | ::janus_core::task::VdafInstance::Prio3FixedPoint64BitBoundedL2VecSum { .. } => { - ::janus_core::vdaf_dispatch_impl_fpvec_bounded_l2!(impl match fpvec_bounded_l2 $vdaf_instance, ($vdaf, $Vdaf, $VERIFY_KEY_LENGTH) => $body) + ::janus_core::vdaf_dispatch_impl_fpvec_bounded_l2!(impl match fpvec_bounded_l2 $vdaf_instance, ($vdaf, $Vdaf, $VERIFY_KEY_LEN) => $body) } _ => panic!("VDAF {:?} is not yet supported", $vdaf_instance), @@ -445,20 +465,21 @@ macro_rules! vdaf_dispatch_impl { #[macro_export] macro_rules! vdaf_dispatch_impl { // Provide the dispatched type only, don't construct a VDAF instance. - (impl match all $vdaf_instance:expr, (_, $Vdaf:ident, $VERIFY_KEY_LENGTH:ident) => $body:tt) => { + (impl match all $vdaf_instance:expr, (_, $Vdaf:ident, $VERIFY_KEY_LEN:ident) => $body:tt) => { match $vdaf_instance { ::janus_core::task::VdafInstance::Prio3Count | ::janus_core::task::VdafInstance::Prio3CountVec { .. } | ::janus_core::task::VdafInstance::Prio3Sum { .. } | ::janus_core::task::VdafInstance::Prio3SumVec { .. } - | ::janus_core::task::VdafInstance::Prio3Histogram { .. } => { - ::janus_core::vdaf_dispatch_impl_base!(impl match base $vdaf_instance, (_, $Vdaf, $VERIFY_KEY_LENGTH) => $body) + | ::janus_core::task::VdafInstance::Prio3Histogram { .. } + | ::janus_core::task::VdafInstance::Poplar1 { .. 
} => { + ::janus_core::vdaf_dispatch_impl_base!(impl match base $vdaf_instance, (_, $Vdaf, $VERIFY_KEY_LEN) => $body) } ::janus_core::task::VdafInstance::Fake | ::janus_core::task::VdafInstance::FakeFailsPrepInit | ::janus_core::task::VdafInstance::FakeFailsPrepStep => { - ::janus_core::vdaf_dispatch_impl_test_util!(impl match test_util $vdaf_instance, (_, $Vdaf, $VERIFY_KEY_LENGTH) => $body) + ::janus_core::vdaf_dispatch_impl_test_util!(impl match test_util $vdaf_instance, (_, $Vdaf, $VERIFY_KEY_LEN) => $body) } _ => panic!("VDAF {:?} is not yet supported", $vdaf_instance), @@ -466,20 +487,21 @@ macro_rules! vdaf_dispatch_impl { }; // Construct a VDAF instance, and provide that to the block as well. - (impl match all $vdaf_instance:expr, ($vdaf:ident, $Vdaf:ident, $VERIFY_KEY_LENGTH:ident) => $body:tt) => { + (impl match all $vdaf_instance:expr, ($vdaf:ident, $Vdaf:ident, $VERIFY_KEY_LEN:ident) => $body:tt) => { match $vdaf_instance { ::janus_core::task::VdafInstance::Prio3Count | ::janus_core::task::VdafInstance::Prio3CountVec { .. } | ::janus_core::task::VdafInstance::Prio3Sum { .. } | ::janus_core::task::VdafInstance::Prio3SumVec { .. } - | ::janus_core::task::VdafInstance::Prio3Histogram { .. } => { - ::janus_core::vdaf_dispatch_impl_base!(impl match base $vdaf_instance, ($vdaf, $Vdaf, $VERIFY_KEY_LENGTH) => $body) + | ::janus_core::task::VdafInstance::Prio3Histogram { .. } + | ::janus_core::task::VdafInstance::Poplar1 { .. } => { + ::janus_core::vdaf_dispatch_impl_base!(impl match base $vdaf_instance, ($vdaf, $Vdaf, $VERIFY_KEY_LEN) => $body) } ::janus_core::task::VdafInstance::Fake | ::janus_core::task::VdafInstance::FakeFailsPrepInit | ::janus_core::task::VdafInstance::FakeFailsPrepStep => { - ::janus_core::vdaf_dispatch_impl_test_util!(impl match test_util $vdaf_instance, ($vdaf, $Vdaf, $VERIFY_KEY_LENGTH) => $body) + ::janus_core::vdaf_dispatch_impl_test_util!(impl match test_util $vdaf_instance, ($vdaf, $Vdaf, $VERIFY_KEY_LEN) => $body) } _ => panic!("VDAF {:?} is not yet supported", $vdaf_instance), @@ -492,14 +514,15 @@ macro_rules! vdaf_dispatch_impl { #[macro_export] macro_rules! vdaf_dispatch_impl { // Provide the dispatched type only, don't construct a VDAF instance. - (impl match all $vdaf_instance:expr, (_, $Vdaf:ident, $VERIFY_KEY_LENGTH:ident) => $body:tt) => { + (impl match all $vdaf_instance:expr, (_, $Vdaf:ident, $VERIFY_KEY_LEN:ident) => $body:tt) => { match $vdaf_instance { ::janus_core::task::VdafInstance::Prio3Count | ::janus_core::task::VdafInstance::Prio3CountVec { .. } | ::janus_core::task::VdafInstance::Prio3Sum { .. } | ::janus_core::task::VdafInstance::Prio3SumVec { .. } - | ::janus_core::task::VdafInstance::Prio3Histogram { .. } => { - ::janus_core::vdaf_dispatch_impl_base!(impl match base $vdaf_instance, (_, $Vdaf, $VERIFY_KEY_LENGTH) => $body) + | ::janus_core::task::VdafInstance::Prio3Histogram { .. } + | ::janus_core::task::VdafInstance::Poplar1 { .. } => { + ::janus_core::vdaf_dispatch_impl_base!(impl match base $vdaf_instance, (_, $Vdaf, $VERIFY_KEY_LEN) => $body) } _ => panic!("VDAF {:?} is not yet supported", $vdaf_instance), @@ -507,14 +530,15 @@ macro_rules! vdaf_dispatch_impl { }; // Construct a VDAF instance, and provide that to the block as well. 
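// Editorial sketch, not part of the patch: generic code can let `vdaf_dispatch!`
// (defined below) construct the task's VDAF and size its verify key in one place; the
// arms above bind VERIFY_KEY_LEN to 16 for the "real" VDAFs and 0 for the test fakes.
fn random_verify_key_for(
    instance: &janus_core::task::VdafInstance,
) -> Result<Vec<u8>, prio::vdaf::VdafError> {
    janus_core::vdaf_dispatch!(instance, (vdaf, Vdaf, VERIFY_KEY_LEN) => {
        let _: &Vdaf = &vdaf; // the concrete VDAF type is also in scope for generic calls
        Ok(rand::random::<[u8; VERIFY_KEY_LEN]>().to_vec())
    })
}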
- (impl match all $vdaf_instance:expr, ($vdaf:ident, $Vdaf:ident, $VERIFY_KEY_LENGTH:ident) => $body:tt) => { + (impl match all $vdaf_instance:expr, ($vdaf:ident, $Vdaf:ident, $VERIFY_KEY_LEN:ident) => $body:tt) => { match $vdaf_instance { ::janus_core::task::VdafInstance::Prio3Count | ::janus_core::task::VdafInstance::Prio3CountVec { .. } | ::janus_core::task::VdafInstance::Prio3Sum { .. } | ::janus_core::task::VdafInstance::Prio3SumVec { .. } - | ::janus_core::task::VdafInstance::Prio3Histogram { .. } => { - ::janus_core::vdaf_dispatch_impl_base!(impl match base $vdaf_instance, ($vdaf, $Vdaf, $VERIFY_KEY_LENGTH) => $body) + | ::janus_core::task::VdafInstance::Prio3Histogram { .. } + | ::janus_core::task::VdafInstance::Poplar1 { .. } => { + ::janus_core::vdaf_dispatch_impl_base!(impl match base $vdaf_instance, ($vdaf, $Vdaf, $VERIFY_KEY_LEN) => $body) } _ => panic!("VDAF {:?} is not yet supported", $vdaf_instance), @@ -541,21 +565,21 @@ macro_rules! vdaf_dispatch_impl { /// # } /// # fn test() -> Result<(), prio::vdaf::VdafError> { /// # let vdaf = janus_core::task::VdafInstance::Prio3Count; -/// vdaf_dispatch!(&vdaf, (vdaf, VdafType, VERIFY_KEY_LENGTH) => { -/// handle_request_generic::(&vdaf) +/// vdaf_dispatch!(&vdaf, (vdaf, VdafType, VERIFY_KEY_LEN) => { +/// handle_request_generic::(&vdaf) /// }) /// # } /// ``` #[macro_export] macro_rules! vdaf_dispatch { // Provide the dispatched type only, don't construct a VDAF instance. - ($vdaf_instance:expr, (_, $Vdaf:ident, $VERIFY_KEY_LENGTH:ident) => $body:tt) => { - ::janus_core::vdaf_dispatch_impl!(impl match all $vdaf_instance, (_, $Vdaf, $VERIFY_KEY_LENGTH) => $body) + ($vdaf_instance:expr, (_, $Vdaf:ident, $VERIFY_KEY_LEN:ident) => $body:tt) => { + ::janus_core::vdaf_dispatch_impl!(impl match all $vdaf_instance, (_, $Vdaf, $VERIFY_KEY_LEN) => $body) }; // Construct a VDAF instance, and provide that to the block as well. - ($vdaf_instance:expr, ($vdaf:ident, $Vdaf:ident, $VERIFY_KEY_LENGTH:ident) => $body:tt) => { - ::janus_core::vdaf_dispatch_impl!(impl match all $vdaf_instance, ($vdaf, $Vdaf, $VERIFY_KEY_LENGTH) => $body) + ($vdaf_instance:expr, ($vdaf:ident, $Vdaf:ident, $VERIFY_KEY_LEN:ident) => $body:tt) => { + ::janus_core::vdaf_dispatch_impl!(impl match all $vdaf_instance, ($vdaf, $Vdaf, $VERIFY_KEY_LEN) => $body) }; } @@ -710,15 +734,16 @@ impl Distribution for Standard { } } -/// Modifies a [`Url`] in place to ensure it ends with a slash. +/// Returns the given [`Url`], possibly modified to end with a slash. /// /// Aggregator endpoint URLs should end with a slash if they will be used with [`Url::join`], /// because that method will drop the last path component of the base URL if it does not end with a /// slash. -pub fn url_ensure_trailing_slash(url: &mut Url) { +pub fn url_ensure_trailing_slash(mut url: Url) -> Url { if !url.as_str().ends_with('/') { url.set_path(&format!("{}/", url.path())); } + url } #[cfg(test)] diff --git a/db/00000000000001_initial_schema.up.sql b/db/00000000000001_initial_schema.up.sql index e28bfc016..bd9f03634 100644 --- a/db/00000000000001_initial_schema.up.sql +++ b/db/00000000000001_initial_schema.up.sql @@ -75,19 +75,20 @@ CREATE TABLE taskprov_collector_auth_tokens( -- Corresponds to a DAP task, containing static data associated with the task. 
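// Editorial sketch, not part of the patch: why `url_ensure_trailing_slash` (changed
// above to take and return the `Url` by value) matters for `Url::join`, per its doc
// comment; the endpoint values are arbitrary examples.
fn trailing_slash_example() {
    let base = url::Url::parse("https://aggregator.example/dap").unwrap();
    // Without a trailing slash, `join` drops the last path component of the base.
    assert_eq!(
        base.join("hpke_config").unwrap().as_str(),
        "https://aggregator.example/hpke_config"
    );
    // After normalization, joined paths keep the endpoint prefix.
    let base = janus_core::task::url_ensure_trailing_slash(base);
    assert_eq!(
        base.join("hpke_config").unwrap().as_str(),
        "https://aggregator.example/dap/hpke_config"
    );
}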
CREATE TABLE tasks( - id BIGINT GENERATED ALWAYS AS IDENTITY PRIMARY KEY, -- artificial ID, internal-only - task_id BYTEA UNIQUE NOT NULL, -- 32-byte TaskID as defined by the DAP specification - aggregator_role AGGREGATOR_ROLE NOT NULL, -- the role of this aggregator for this task - aggregator_endpoints TEXT[] NOT NULL, -- aggregator HTTPS endpoints, leader first - query_type JSONB NOT NULL, -- the query type in use for this task, along with its parameters - vdaf JSON NOT NULL, -- the VDAF instance in use for this task, along with its parameters - max_batch_query_count BIGINT NOT NULL, -- the maximum number of times a given batch may be collected - task_expiration TIMESTAMP, -- the time after which client reports are no longer accepted - report_expiry_age BIGINT, -- the maximum age of a report before it is considered expired (and acceptable for garbage collection), in seconds. NULL means that GC is disabled. - min_batch_size BIGINT NOT NULL, -- the minimum number of reports in a batch to allow it to be collected - time_precision BIGINT NOT NULL, -- the duration to which clients are expected to round their report timestamps, in seconds - tolerable_clock_skew BIGINT NOT NULL, -- the maximum acceptable clock skew to allow between client and aggregator, in seconds - collector_hpke_config BYTEA -- the HPKE config of the collector (encoded HpkeConfig message) + id BIGINT GENERATED ALWAYS AS IDENTITY PRIMARY KEY, -- artificial ID, internal-only + task_id BYTEA UNIQUE NOT NULL, -- 32-byte TaskID as defined by the DAP specification + aggregator_role AGGREGATOR_ROLE NOT NULL, -- the role of this aggregator for this task + leader_aggregator_endpoint TEXT NOT NULL, -- Leader's API endpoint + helper_aggregator_endpoint TEXT NOT NULL, -- Helper's API endpoint + query_type JSONB NOT NULL, -- the query type in use for this task, along with its parameters + vdaf JSON NOT NULL, -- the VDAF instance in use for this task, along with its parameters + max_batch_query_count BIGINT NOT NULL, -- the maximum number of times a given batch may be collected + task_expiration TIMESTAMP, -- the time after which client reports are no longer accepted + report_expiry_age BIGINT, -- the maximum age of a report before it is considered expired (and acceptable for garbage collection), in seconds. NULL means that GC is disabled. + min_batch_size BIGINT NOT NULL, -- the minimum number of reports in a batch to allow it to be collected + time_precision BIGINT NOT NULL, -- the duration to which clients are expected to round their report timestamps, in seconds + tolerable_clock_skew BIGINT NOT NULL, -- the maximum acceptable clock skew to allow between client and aggregator, in seconds + collector_hpke_config BYTEA -- the HPKE config of the collector (encoded HpkeConfig message) ); CREATE INDEX task_id_index ON tasks(task_id); @@ -325,4 +326,4 @@ CREATE TABLE outstanding_batches( CONSTRAINT outstanding_batches_unique_task_id_batch_id UNIQUE(task_id, batch_id), CONSTRAINT fk_task_id FOREIGN KEY(task_id) REFERENCES tasks(id) ON DELETE CASCADE ); -CREATE INDEX outstanding_batches_task_and_time_bucket_index ON outstanding_batches (task_id, time_bucket_start); \ No newline at end of file +CREATE INDEX outstanding_batches_task_and_time_bucket_index ON outstanding_batches (task_id, time_bucket_start); diff --git a/docs/samples/tasks.yaml b/docs/samples/tasks.yaml index 78cd97eb0..d04e6a51e 100644 --- a/docs/samples/tasks.yaml +++ b/docs/samples/tasks.yaml @@ -6,10 +6,9 @@ # DAP's recommendation. 
task_id: "G9YKXjoEjfoU7M_fi_o2H0wmzavRb2sBFHeykeRhDMk" - # HTTPS endpoints of the leader and helper aggregators, in a list. - aggregator_endpoints: - - "https://example.com/" - - "https://example.net/" + # HTTPS endpoints of the leader and helper aggregators. + leader_aggregator_endpoint: "https://example.com/" + helper_aggregator_endpoint: "https://example.net/" # The DAP query type. See below for an example of a fixed-size task query_type: TimeInterval @@ -102,9 +101,8 @@ private_key: wFRYwiypcHC-mkGP1u3XQgIvtnlkQlUfZjgtM_zRsnI - task_id: "D-hCKPuqL2oTf7ZVRVyMP5VGt43EAEA8q34mDf6p1JE" - aggregator_endpoints: - - "https://example.org/" - - "https://example.com/" + leader_aggregator_endpoint: "https://example.org/" + helper_aggregator_endpoint: "https://example.com/" # For tasks using the fixed size query type, an additional `max_batch_size` # parameter must be provided. query_type: !FixedSize diff --git a/integration_tests/src/client.rs b/integration_tests/src/client.rs index 28f00218c..84097a8f4 100644 --- a/integration_tests/src/client.rs +++ b/integration_tests/src/client.rs @@ -216,12 +216,13 @@ where (leader_port, helper_port): (u16, u16), vdaf: V, ) -> Result, janus_client::Error> { - let aggregator_endpoints = task_parameters + let (leader_aggregator_endpoint, helper_aggregator_endpoint) = task_parameters .endpoint_fragments .port_forwarded_endpoints(leader_port, helper_port); let client_parameters = ClientParameters::new( task_parameters.task_id, - aggregator_endpoints, + leader_aggregator_endpoint, + helper_aggregator_endpoint, task_parameters.time_precision, ); let http_client = default_http_client()?; @@ -267,13 +268,13 @@ where let container = ContainerLogsDropGuard::new(container); let host_port = container.get_host_port_ipv4(8080); let http_client = reqwest::Client::new(); - let aggregator_endpoints = task_parameters + let (leader_aggregator_endpoint, helper_aggregator_endpoint) = task_parameters .endpoint_fragments .container_network_endpoints(); ClientImplementation::Container(Box::new(ContainerClientImplementation { _container: container, - leader: aggregator_endpoints[Role::Leader.index().unwrap()].clone(), - helper: aggregator_endpoints[Role::Helper.index().unwrap()].clone(), + leader: leader_aggregator_endpoint, + helper: helper_aggregator_endpoint, task_id: task_parameters.task_id, time_precision: task_parameters.time_precision, vdaf, diff --git a/integration_tests/src/daphne.rs b/integration_tests/src/daphne.rs index 6a4d894b7..cf6367cc9 100644 --- a/integration_tests/src/daphne.rs +++ b/integration_tests/src/daphne.rs @@ -26,15 +26,17 @@ impl<'a> Daphne<'a> { /// Create and start a new hermetic Daphne test instance in the given Docker network, configured /// to service the given task. The aggregator port is also exposed to the host. pub async fn new(container_client: &'a Cli, network: &str, task: &Task) -> Daphne<'a> { - let image_name_and_tag = match task.role() { + let (endpoint, image_name_and_tag) = match task.role() { Role::Leader => panic!("A leader container image for Daphne is not yet available"), - Role::Helper => DAPHNE_HELPER_IMAGE_NAME_AND_TAG, + Role::Helper => ( + task.helper_aggregator_endpoint(), + DAPHNE_HELPER_IMAGE_NAME_AND_TAG, + ), Role::Collector | Role::Client => unreachable!(), }; let (image_name, image_tag) = image_name_and_tag.rsplit_once(':').unwrap(); // Start the Daphne test container running. 
- let endpoint = task.aggregator_url(task.role()).unwrap(); let runnable_image = RunnableImage::from(GenericImage::new(image_name, image_tag)) .with_network(network) .with_container_name(endpoint.host_str().unwrap()); diff --git a/integration_tests/src/janus.rs b/integration_tests/src/janus.rs index 9a0e1582e..11e6d1c98 100644 --- a/integration_tests/src/janus.rs +++ b/integration_tests/src/janus.rs @@ -23,7 +23,11 @@ impl<'a> Janus<'a> { /// to service the given task. The aggregator port is also exposed to the host. pub async fn new(container_client: &'a Cli, network: &str, task: &Task) -> Janus<'a> { // Start the Janus interop aggregator container running. - let endpoint = task.aggregator_url(task.role()).unwrap(); + let endpoint = match task.role() { + Role::Leader => task.leader_aggregator_endpoint(), + Role::Helper => task.helper_aggregator_endpoint(), + _ => panic!("unexpected task role"), + }; let container = container_client.run( RunnableImage::from(Aggregator::default()) .with_network(network) diff --git a/integration_tests/src/lib.rs b/integration_tests/src/lib.rs index 7a7b720da..96c79da26 100644 --- a/integration_tests/src/lib.rs +++ b/integration_tests/src/lib.rs @@ -44,19 +44,19 @@ impl EndpointFragments { .unwrap() } - pub fn port_forwarded_endpoints(&self, leader_port: u16, helper_port: u16) -> Vec { - Vec::from([ + pub fn port_forwarded_endpoints(&self, leader_port: u16, helper_port: u16) -> (Url, Url) { + ( self.port_forwarded_leader_endpoint(leader_port), Url::parse(&format!( "http://127.0.0.1:{helper_port}{}", self.helper_endpoint_path )) .unwrap(), - ]) + ) } - pub fn container_network_endpoints(&self) -> Vec { - Vec::from([ + pub fn container_network_endpoints(&self) -> (Url, Url) { + ( Url::parse(&format!( "http://{}:8080{}", self.leader_endpoint_host, self.leader_endpoint_path @@ -67,6 +67,6 @@ impl EndpointFragments { self.helper_endpoint_host, self.helper_endpoint_path )) .unwrap(), - ]) + ) } } diff --git a/integration_tests/tests/common/mod.rs b/integration_tests/tests/common/mod.rs index bc97c6030..a3d9296db 100644 --- a/integration_tests/tests/common/mod.rs +++ b/integration_tests/tests/common/mod.rs @@ -1,7 +1,9 @@ use backoff::{future::retry, ExponentialBackoffBuilder}; use itertools::Itertools; use janus_aggregator_core::task::{test_util::TaskBuilder, QueryType}; -use janus_collector::{Collection, Collector, CollectorParameters}; +use janus_collector::{ + test_util::collect_with_rewritten_url, Collection, Collector, CollectorParameters, +}; use janus_core::{ hpke::test_util::generate_test_hpke_config_and_private_key, retries::test_http_request_exponential_backoff, @@ -21,6 +23,7 @@ use prio::vdaf::{self, prio3::Prio3}; use rand::{random, thread_rng, Rng}; use std::{iter, time::Duration as StdDuration}; use tokio::time::{self, sleep}; +use url::Url; /// Returns a tuple of [`TaskParameters`], a task builder for the leader, and a task builder for the /// helper. 
@@ -37,7 +40,12 @@ pub fn test_task_builders( }; let collector_keypair = generate_test_hpke_config_and_private_key(); let leader_task = TaskBuilder::new(query_type, vdaf.clone(), Role::Leader) - .with_aggregator_endpoints(endpoint_fragments.container_network_endpoints()) + .with_leader_aggregator_endpoint( + Url::parse(&format!("http://leader-{endpoint_random_value}:8080/")).unwrap(), + ) + .with_helper_aggregator_endpoint( + Url::parse(&format!("http://helper-{endpoint_random_value}:8080/")).unwrap(), + ) .with_min_batch_size(46) .with_collector_hpke_config(collector_keypair.config().clone()); let helper_task = leader_task @@ -74,6 +82,8 @@ pub async fn collect_generic<'a, V, Q>( collector: &Collector, query: Query, aggregation_parameter: &V::AggregationParam, + host: &str, + port: u16, ) -> Result, janus_collector::Error> where V: vdaf::Client<16> + vdaf::Collector + InteropClientEncoding, @@ -90,7 +100,9 @@ where retry(backoff, || { let query = query.clone(); async move { - match collector.collect(query, aggregation_parameter).await { + match collect_with_rewritten_url(collector, query, aggregation_parameter, host, port) + .await + { Ok(collection) => Ok(collection), Err( error @ janus_collector::Error::Http { @@ -162,6 +174,8 @@ pub async fn submit_measurements_and_verify_aggregate_generic( &collector, Query::new_time_interval(batch_interval), &test_case.aggregation_parameter, + "127.0.0.1", + leader_port, ) .await .unwrap(); @@ -180,6 +194,8 @@ pub async fn submit_measurements_and_verify_aggregate_generic( &collector, Query::new_fixed_size(FixedSizeQuery::CurrentBatch), &test_case.aggregation_parameter, + "127.0.0.1", + leader_port, ) .await; match collection_res { diff --git a/integration_tests/tests/daphne.rs b/integration_tests/tests/daphne.rs index 2c6ee5f28..fa53888a2 100644 --- a/integration_tests/tests/daphne.rs +++ b/integration_tests/tests/daphne.rs @@ -6,7 +6,6 @@ use janus_core::{ }; use janus_integration_tests::{client::ClientBackend, daphne::Daphne, janus::Janus}; use janus_interop_binaries::test_util::generate_network_name; -use janus_messages::Role; mod common; @@ -26,9 +25,10 @@ async fn daphne_janus() { let [leader_task, helper_task]: [Task; 2] = [leader_task, helper_task] .into_iter() .map(|task| { - let mut endpoints = task.aggregator_endpoints().to_vec(); - endpoints[Role::Leader.index().unwrap()].set_path("/v04/"); - task.with_aggregator_endpoints(endpoints).build() + let mut leader_aggregator_endpoint = task.leader_aggregator_endpoint().clone(); + leader_aggregator_endpoint.set_path("/v04/"); + task.with_leader_aggregator_endpoint(leader_aggregator_endpoint) + .build() }) .collect::>() .try_into() @@ -63,9 +63,10 @@ async fn janus_daphne() { let [leader_task, helper_task]: [Task; 2] = [leader_task, helper_task] .into_iter() .map(|task| { - let mut endpoints = task.aggregator_endpoints().to_vec(); - endpoints[Role::Helper.index().unwrap()].set_path("/v04/"); - task.with_aggregator_endpoints(endpoints).build() + let mut helper_aggregator_endpoint = task.helper_aggregator_endpoint().clone(); + helper_aggregator_endpoint.set_path("/v04/"); + task.with_helper_aggregator_endpoint(helper_aggregator_endpoint) + .build() }) .collect::>() .try_into() diff --git a/interop_binaries/src/bin/janus_interop_aggregator.rs b/interop_binaries/src/bin/janus_interop_aggregator.rs index c08f7602c..cd7c0a16d 100644 --- a/interop_binaries/src/bin/janus_interop_aggregator.rs +++ b/interop_binaries/src/bin/janus_interop_aggregator.rs @@ -88,7 +88,8 @@ async fn handle_add_task( let task = 
Task::new( request.task_id, - Vec::from([request.leader, request.helper]), + request.leader, + request.helper, query_type, vdaf, request.role.into(), diff --git a/interop_binaries/src/bin/janus_interop_client.rs b/interop_binaries/src/bin/janus_interop_client.rs index cca595006..c8b1dbe2d 100644 --- a/interop_binaries/src/bin/janus_interop_client.rs +++ b/interop_binaries/src/bin/janus_interop_client.rs @@ -79,11 +79,8 @@ async fn handle_upload_generic>( .context("invalid base64url content in \"task_id\"")?; let task_id = TaskId::get_decoded(&task_id_bytes).context("invalid length of TaskId")?; let time_precision = Duration::from_seconds(request.time_precision); - let client_parameters = ClientParameters::new( - task_id, - Vec::::from([request.leader, request.helper]), - time_precision, - ); + let client_parameters = + ClientParameters::new(task_id, request.leader, request.helper, time_precision); let leader_hpke_config = janus_client::aggregator_hpke_config( &client_parameters, diff --git a/interop_binaries/src/lib.rs b/interop_binaries/src/lib.rs index 926b8535a..d65cf60cc 100644 --- a/interop_binaries/src/lib.rs +++ b/interop_binaries/src/lib.rs @@ -287,8 +287,8 @@ impl From for AggregatorAddTaskRequest { }; Self { task_id: *task.id(), - leader: task.aggregator_url(&Role::Leader).unwrap().clone(), - helper: task.aggregator_url(&Role::Helper).unwrap().clone(), + leader: task.leader_aggregator_endpoint().clone(), + helper: task.helper_aggregator_endpoint().clone(), vdaf: task.vdaf().clone().into(), leader_authentication_token: String::from_utf8( task.primary_aggregator_auth_token().as_ref().to_vec(), diff --git a/interop_binaries/tests/end_to_end.rs b/interop_binaries/tests/end_to_end.rs index 325db2025..f0012a88d 100644 --- a/interop_binaries/tests/end_to_end.rs +++ b/interop_binaries/tests/end_to_end.rs @@ -2,7 +2,7 @@ use backoff::{backoff::Backoff, ExponentialBackoffBuilder}; use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine}; use futures::future::join_all; use janus_core::{ - task::PRIO3_VERIFY_KEY_LENGTH, + task::VERIFY_KEY_LENGTH, test_util::{install_test_trace_subscriber, testcontainers::container_client}, time::{Clock, RealClock, TimeExt}, }; @@ -110,7 +110,7 @@ async fn run( let task_id: TaskId = random(); let aggregator_auth_token = URL_SAFE_NO_PAD.encode(random::<[u8; 16]>()); let collector_auth_token = URL_SAFE_NO_PAD.encode(random::<[u8; 16]>()); - let vdaf_verify_key = rand::random::<[u8; PRIO3_VERIFY_KEY_LENGTH]>(); + let vdaf_verify_key = rand::random::<[u8; VERIFY_KEY_LENGTH]>(); let task_id_encoded = URL_SAFE_NO_PAD.encode(task_id.get_encoded()); let vdaf_verify_key_encoded = URL_SAFE_NO_PAD.encode(vdaf_verify_key); diff --git a/messages/src/lib.rs b/messages/src/lib.rs index 88aad22d5..816565b1f 100644 --- a/messages/src/lib.rs +++ b/messages/src/lib.rs @@ -1306,7 +1306,8 @@ impl Decode for PlaintextInputShare { pub struct Report { metadata: ReportMetadata, public_share: Vec, - encrypted_input_shares: Vec, + leader_encrypted_input_share: HpkeCiphertext, + helper_encrypted_input_share: HpkeCiphertext, } impl Report { @@ -1317,12 +1318,14 @@ impl Report { pub fn new( metadata: ReportMetadata, public_share: Vec, - encrypted_input_shares: Vec, + leader_encrypted_input_share: HpkeCiphertext, + helper_encrypted_input_share: HpkeCiphertext, ) -> Self { Self { metadata, public_share, - encrypted_input_shares, + leader_encrypted_input_share, + helper_encrypted_input_share, } } @@ -1331,13 +1334,19 @@ impl Report { &self.metadata } + /// Retrieve the public share 
@@ -1331,13 +1334,19 @@ impl Report {
     pub fn metadata(&self) -> &ReportMetadata {
         &self.metadata
     }
 
+    /// Retrieve the public share from this report.
     pub fn public_share(&self) -> &[u8] {
         &self.public_share
     }
 
-    /// Get this report's encrypted input shares.
-    pub fn encrypted_input_shares(&self) -> &[HpkeCiphertext] {
-        &self.encrypted_input_shares
+    /// Retrieve the encrypted leader input share from this report.
+    pub fn leader_encrypted_input_share(&self) -> &HpkeCiphertext {
+        &self.leader_encrypted_input_share
+    }
+
+    /// Retrieve the encrypted helper input share from this report.
+    pub fn helper_encrypted_input_share(&self) -> &HpkeCiphertext {
+        &self.helper_encrypted_input_share
     }
 }
 
@@ -1345,17 +1354,16 @@ impl Encode for Report {
     fn encode(&self, bytes: &mut Vec<u8>) {
         self.metadata.encode(bytes);
         encode_u32_items(bytes, &(), &self.public_share);
-        encode_u32_items(bytes, &(), &self.encrypted_input_shares);
+        self.leader_encrypted_input_share.encode(bytes);
+        self.helper_encrypted_input_share.encode(bytes);
     }
 
     fn encoded_len(&self) -> Option<usize> {
         let mut length = self.metadata.encoded_len()?;
         length += 4;
         length += self.public_share.len();
-        length += 4;
-        for encrypted_input_share in self.encrypted_input_shares.iter() {
-            length += encrypted_input_share.encoded_len()?;
-        }
+        length += self.leader_encrypted_input_share.encoded_len()?;
+        length += self.helper_encrypted_input_share.encoded_len()?;
         Some(length)
     }
 }
 
@@ -1364,12 +1372,14 @@ impl Decode for Report {
     fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> {
         let metadata = ReportMetadata::decode(bytes)?;
         let public_share = decode_u32_items(&(), bytes)?;
-        let encrypted_input_shares = decode_u32_items(&(), bytes)?;
+        let leader_encrypted_input_share = HpkeCiphertext::decode(bytes)?;
+        let helper_encrypted_input_share = HpkeCiphertext::decode(bytes)?;
 
         Ok(Self {
             metadata,
             public_share,
-            encrypted_input_shares,
+            leader_encrypted_input_share,
+            helper_encrypted_input_share,
         })
     }
 }
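On the wire the two shares are now written back to back with no u32 count in front of them; each `HpkeCiphertext` carries its own internal length prefixes. A hedged round-trip check using the `prio::codec` traits (a sketch only, assuming `Report` derives `PartialEq` as the test vectors below imply):

    use janus_messages::Report;
    use prio::codec::{Decode, Encode};

    fn roundtrip_smoke(report: Report) {
        let bytes = report.get_encoded();
        // encoded_len() must agree with the actual encoding length...
        assert_eq!(report.encoded_len().unwrap(), bytes.len());
        // ...and decoding the bytes must reproduce an equal message.
        assert_eq!(Report::get_decoded(&bytes).unwrap(), report);
    }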
@@ -1673,7 +1683,8 @@ pub struct Collection<Q: QueryType> {
     partial_batch_selector: PartialBatchSelector<Q>,
     report_count: u64,
     interval: Interval,
-    encrypted_aggregate_shares: Vec<HpkeCiphertext>,
+    leader_encrypted_agg_share: HpkeCiphertext,
+    helper_encrypted_agg_share: HpkeCiphertext,
 }
 
 impl<Q: QueryType> Collection<Q> {
@@ -1685,34 +1696,41 @@ impl<Q: QueryType> Collection<Q> {
     pub fn new(
         partial_batch_selector: PartialBatchSelector<Q>,
         report_count: u64,
         interval: Interval,
-        encrypted_aggregate_shares: Vec<HpkeCiphertext>,
+        leader_encrypted_agg_share: HpkeCiphertext,
+        helper_encrypted_agg_share: HpkeCiphertext,
     ) -> Self {
         Self {
             partial_batch_selector,
             report_count,
             interval,
-            encrypted_aggregate_shares,
+            leader_encrypted_agg_share,
+            helper_encrypted_agg_share,
         }
     }
 
-    /// Gets the batch selector associated with this collection.
+    /// Retrieves the batch selector associated with this collection.
     pub fn partial_batch_selector(&self) -> &PartialBatchSelector<Q> {
         &self.partial_batch_selector
     }
 
-    /// Gets the number of reports that were aggregated into this collection.
+    /// Retrieves the number of reports that were aggregated into this collection.
     pub fn report_count(&self) -> u64 {
         self.report_count
     }
 
-    /// Gets the interval spanned by the reports aggregated into this collection.
+    /// Retrieves the interval spanned by the reports aggregated into this collection.
     pub fn interval(&self) -> &Interval {
         &self.interval
     }
 
-    /// Gets the encrypted aggregate shares associated with this collection.
-    pub fn encrypted_aggregate_shares(&self) -> &[HpkeCiphertext] {
-        &self.encrypted_aggregate_shares
+    /// Retrieves the leader encrypted aggregate share associated with this collection.
+    pub fn leader_encrypted_aggregate_share(&self) -> &HpkeCiphertext {
+        &self.leader_encrypted_agg_share
+    }
+
+    /// Retrieves the helper encrypted aggregate share associated with this collection.
+    pub fn helper_encrypted_aggregate_share(&self) -> &HpkeCiphertext {
+        &self.helper_encrypted_agg_share
     }
 }
 
@@ -1721,18 +1739,18 @@ impl<Q: QueryType> Encode for Collection<Q> {
     fn encode(&self, bytes: &mut Vec<u8>) {
         self.partial_batch_selector.encode(bytes);
         self.report_count.encode(bytes);
         self.interval.encode(bytes);
-        encode_u32_items(bytes, &(), &self.encrypted_aggregate_shares);
+        self.leader_encrypted_agg_share.encode(bytes);
+        self.helper_encrypted_agg_share.encode(bytes);
     }
 
     fn encoded_len(&self) -> Option<usize> {
-        let mut length = self.partial_batch_selector.encoded_len()?
-            + self.report_count.encoded_len()?
-            + self.interval.encoded_len()?;
-        length += 4;
-        for encrypted_aggregate_share in self.encrypted_aggregate_shares.iter() {
-            length += encrypted_aggregate_share.encoded_len()?;
-        }
-        Some(length)
+        Some(
+            self.partial_batch_selector.encoded_len()?
+                + self.report_count.encoded_len()?
+                + self.interval.encoded_len()?
+                + self.leader_encrypted_agg_share.encoded_len()?
+                + self.helper_encrypted_agg_share.encoded_len()?,
+        )
     }
 }
 
@@ -1741,13 +1759,15 @@ impl<Q: QueryType> Decode for Collection<Q> {
     let partial_batch_selector = PartialBatchSelector::decode(bytes)?;
     let report_count = u64::decode(bytes)?;
     let interval = Interval::decode(bytes)?;
-    let encrypted_aggregate_shares = decode_u32_items(&(), bytes)?;
+    let leader_encrypted_agg_share = HpkeCiphertext::decode(bytes)?;
+    let helper_encrypted_agg_share = HpkeCiphertext::decode(bytes)?;
 
     Ok(Self {
         partial_batch_selector,
         report_count,
         interval,
-        encrypted_aggregate_shares,
+        leader_encrypted_agg_share,
+        helper_encrypted_agg_share,
     })
     }
 }
@@ -3249,7 +3269,16 @@ mod tests {
                     Time::from_seconds_since_epoch(12345),
                 ),
                 Vec::new(),
-                Vec::new(),
+                HpkeCiphertext::new(
+                    HpkeConfigId::from(42),
+                    Vec::from("012345"),
+                    Vec::from("543210"),
+                ),
+                HpkeCiphertext::new(
+                    HpkeConfigId::from(13),
+                    Vec::from("abce"),
+                    Vec::from("abfd"),
+                ),
             ),
             concat!(
                 concat!(
                    "00000000", // length
                ),
                concat!(
-                    // encrypted_input_shares
-                    "00000000", // length
-                )
+                    // leader_encrypted_input_share
+                    "2A", // config_id
+                    concat!(
+                        // encapsulated_context
+                        "0006", // length
+                        "303132333435" // opaque data
+                    ),
+                    concat!(
+                        // payload
+                        "00000006", // length
+                        "353433323130", // opaque data
+                    ),
+                ),
+                concat!(
+                    // helper_encrypted_input_share
+                    "0D", // config_id
+                    concat!(
+                        // encapsulated_context
+                        "0004", // length
+                        "61626365", // opaque data
+                    ),
+                    concat!(
+                        // payload
+                        "00000004", // length
+                        "61626664", // opaque data
+                    ),
+                ),
            ),
        ),
        (
            ReportMetadata::new(
                Time::from_seconds_since_epoch(54321),
            ),
            Vec::from("3210"),
-            Vec::from([
-                HpkeCiphertext::new(
-                    HpkeConfigId::from(42),
-                    Vec::from("012345"),
-                    Vec::from("543210"),
-                ),
-                HpkeCiphertext::new(
-                    HpkeConfigId::from(13),
-                    Vec::from("abce"),
-                    Vec::from("abfd"),
-                ),
-            ]),
+            HpkeCiphertext::new(
+                HpkeConfigId::from(42),
+                Vec::from("012345"),
+                Vec::from("543210"),
+            ),
+            HpkeCiphertext::new(
+                HpkeConfigId::from(13),
+                Vec::from("abce"),
+                Vec::from("abfd"),
+            ),
        ),
        concat!(
            concat!(
                "33323130", // opaque data
            ),
            concat!(
-                // encrypted_input_shares
-                "00000022", // length
+                // leader_encrypted_input_share
+                "2A", // config_id
                concat!(
-                    "2A", // config_id
-                    concat!(
-                        // encapsulated_context
-                        "0006", // length
-                        "303132333435" // opaque data
-                    ),
-                    concat!(
-                        // payload
-                        "00000006", // length
-                        "353433323130", // opaque data
-                    ),
+                    // encapsulated_context
+                    "0006", // length
+                    "303132333435" // opaque data
                ),
                concat!(
-                    "0D", // config_id
-                    concat!(
-                        // encapsulated_context
-                        "0004", // length
-                        "61626365", // opaque data
-                    ),
-                    concat!(
-                        // payload
-                        "00000004", // length
-                        "61626664", // opaque data
-                    ),
+                    // payload
+                    "00000006", // length
+                    "353433323130", // opaque data
+                ),
+            ),
+            concat!(
+                // helper_encrypted_input_share
+                "0D", // config_id
+                concat!(
+                    // encapsulated_context
+                    "0004", // length
+                    "61626365", // opaque data
+                ),
+                concat!(
+                    // payload
+                    "00000004", // length
+                    "61626664", // opaque data
                ),
            ),
        ),
@@ -3582,7 +3631,16 @@
                partial_batch_selector: PartialBatchSelector::new_time_interval(),
                report_count: 0,
                interval,
-                encrypted_aggregate_shares: Vec::new(),
+                leader_encrypted_agg_share: HpkeCiphertext::new(
+                    HpkeConfigId::from(10),
+                    Vec::from("0123"),
+                    Vec::from("4567"),
+                ),
+                helper_encrypted_agg_share: HpkeCiphertext::new(
+                    HpkeConfigId::from(12),
+                    Vec::from("01234"),
+                    Vec::from("567"),
+                ),
            },
            concat!(
                concat!(
                    "0000000000003039", // duration
                ),
                concat!(
-                    // encrypted_aggregate_shares
-                    "00000000", // length
+                    // leader_encrypted_agg_share
+                    "0A", // config_id
+                    concat!(
+                        // encapsulated_context
+                        "0004", // length
+                        "30313233", // opaque data
+                    ),
+                    concat!(
+                        // payload
+                        "00000004", // length
+                        "34353637", // opaque data
+                    ),
+                ),
+                concat!(
+                    // helper_encrypted_agg_share
+                    "0C", // config_id
+                    concat!(
+                        // encapsulated_context
+                        "0005", // length
+                        "3031323334", // opaque data
+                    ),
+                    concat!(
+                        // payload
+                        "00000003", // length
+                        "353637", // opaque data
+                    ),
                )
            ),
        ),
@@ -3606,18 +3688,16 @@
            Collection {
                partial_batch_selector: PartialBatchSelector::new_time_interval(),
                report_count: 23,
                interval,
-                encrypted_aggregate_shares: Vec::from([
-                    HpkeCiphertext::new(
-                        HpkeConfigId::from(10),
-                        Vec::from("0123"),
-                        Vec::from("4567"),
-                    ),
-                    HpkeCiphertext::new(
-                        HpkeConfigId::from(12),
-                        Vec::from("01234"),
-                        Vec::from("567"),
-                    ),
-                ]),
+                leader_encrypted_agg_share: HpkeCiphertext::new(
+                    HpkeConfigId::from(10),
+                    Vec::from("0123"),
+                    Vec::from("4567"),
+                ),
+                helper_encrypted_agg_share: HpkeCiphertext::new(
+                    HpkeConfigId::from(12),
+                    Vec::from("01234"),
+                    Vec::from("567"),
+                ),
            },
            concat!(
                concat!(
                    "0000000000003039", // duration
                ),
                concat!(
-                    // encrypted_aggregate_shares
-                    "0000001E", // length
+                    // leader_encrypted_agg_share
+                    "0A", // config_id
                    concat!(
-                        "0A", // config_id
-                        concat!(
-                            // encapsulated_context
-                            "0004", // length
-                            "30313233", // opaque data
-                        ),
-                        concat!(
-                            // payload
-                            "00000004", // length
-                            "34353637", // opaque data
-                        ),
+                        // encapsulated_context
+                        "0004", // length
+                        "30313233", // opaque data
                    ),
                    concat!(
-                        "0C", // config_id
-                        concat!(
-                            // encapsulated_context
-                            "0005", // length
-                            "3031323334", // opaque data
-                        ),
-                        concat!(
-                            // payload
-                            "00000003", // length
-                            "353637", // opaque data
-                        ),
-                    )
+                        // payload
+                        "00000004", // length
+                        "34353637", // opaque data
+                    ),
+                ),
+                concat!(
+                    // helper_encrypted_agg_share
+                    "0C", // config_id
+                    concat!(
+                        // encapsulated_context
+                        "0005", // length
+                        "3031323334", // opaque data
+                    ),
+                    concat!(
+                        // payload
+                        "00000003", // length
+                        "353637", // opaque data
+                    ),
                )
            ),
        ),
@@ -3673,7 +3751,16 @@
                )),
                report_count: 0,
                interval,
-                encrypted_aggregate_shares: Vec::new(),
+                leader_encrypted_agg_share: HpkeCiphertext::new(
+                    HpkeConfigId::from(10),
+                    Vec::from("0123"),
+                    Vec::from("4567"),
+                ),
+                helper_encrypted_agg_share: HpkeCiphertext::new(
+                    HpkeConfigId::from(12),
+                    Vec::from("01234"),
+                    Vec::from("567"),
+                ),
            },
            concat!(
                concat!(
@@ -3688,8 +3775,32 @@
                    "0000000000003039", // duration
                ),
                concat!(
-                    // encrypted_aggregate_shares
-                    "00000000", // length
+                    // leader_encrypted_agg_share
+                    "0A", // config_id
+                    concat!(
+                        // encapsulated_context
+                        "0004", // length
+                        "30313233", // opaque data
+                    ),
+                    concat!(
+                        // payload
+                        "00000004", // length
+                        "34353637", // opaque data
+                    ),
+                ),
+                concat!(
+                    // helper_encrypted_agg_share
+                    "0C", // config_id
+                    concat!(
+                        // encapsulated_context
+                        "0005", // length
+                        "3031323334", // opaque data
+                    ),
+                    concat!(
+                        // payload
+                        "00000003", // length
+                        "353637", // opaque data
+                    ),
                )
            ),
        ),
@@ -3700,18 +3811,16 @@
                )),
                report_count: 23,
                interval,
-                encrypted_aggregate_shares: Vec::from([
-                    HpkeCiphertext::new(
-                        HpkeConfigId::from(10),
-                        Vec::from("0123"),
-                        Vec::from("4567"),
-                    ),
-                    HpkeCiphertext::new(
-                        HpkeConfigId::from(12),
-                        Vec::from("01234"),
-                        Vec::from("567"),
-                    ),
-                ]),
+                leader_encrypted_agg_share: HpkeCiphertext::new(
+                    HpkeConfigId::from(10),
+                    Vec::from("0123"),
+                    Vec::from("4567"),
+                ),
+                helper_encrypted_agg_share: HpkeCiphertext::new(
+                    HpkeConfigId::from(12),
+                    Vec::from("01234"),
+                    Vec::from("567"),
+                ),
            },
            concat!(
                concat!(
@@ -3726,34 +3835,32 @@
                    "0000000000003039", // duration
                ),
                concat!(
-                    // encrypted_aggregate_shares
-                    "0000001E", // length
+                    // leader_encrypted_agg_share
+                    "0A", // config_id
                    concat!(
-                        "0A", // config_id
-                        concat!(
-                            // encapsulated_context
-                            "0004", // length
-                            "30313233", // opaque data
-                        ),
-                        concat!(
-                            // payload
-                            "00000004", // length
-                            "34353637", // opaque data
-                        ),
+                        // encapsulated_context
+                        "0004", // length
+                        "30313233", // opaque data
                    ),
                    concat!(
-                        "0C", // config_id
-                        concat!(
-                            // encapsulated_context
-                            "0005", // length
-                            "3031323334", // opaque data
-                        ),
-                        concat!(
-                            // payload
-                            "00000003", // length
-                            "353637", // opaque data
-                        ),
-                    )
+                        // payload
+                        "00000004", // length
+                        "34353637", // opaque data
+                    ),
+                ),
+                concat!(
+                    // helper_encrypted_agg_share
+                    "0C", // config_id
+                    concat!(
+                        // encapsulated_context
+                        "0005", // length
+                        "3031323334", // opaque data
+                    ),
+                    concat!(
+                        // payload
+                        "00000003", // length
+                        "353637", // opaque data
+                    ),
                )
            ),
        ),