From a155d817904e0fc27569aee80adca0c97ae7166b Mon Sep 17 00:00:00 2001
From: Brandon Pitman
Date: Wed, 29 Mar 2023 14:18:02 -0700
Subject: [PATCH] DAP-05 ping-pong topology

This change implements the DAP-05 ping-pong topology, in which the
aggregators take turns preprocessing prepare shares into prepare
messages. While this topology first appeared in DAP-05, this
implementation follows DAP-06. It depends on the implementation of the
VDAF ping-pong topology added to crate `prio` in [1], which in turn
conforms to the specification in VDAF-07.

In the ping-pong topology, each DAP-layer step of aggregation now spans
two VDAF rounds: an aggregator uses the prepare message it receives
from its peer to advance by one VDAF round, and can then use the
prepare share it just computed, together with the peer's prepare share,
to advance by another.

This changes which intermediate values the aggregators must store. A
leader that is continuing/waiting will have computed a prepare state, a
prepare message for the current round, and a prepare share for the next
round. The naive implementation would store all three objects in the
database, significantly increasing per-report storage use. To mitigate
this, the leader instead stores a
`prio::topology::ping_pong::PingPongTransition`, which contains a
prepare state and a prepare message (both generally much smaller than
prepare shares) and from which the next prepare state and, importantly,
the prepare share can be recomputed.

On the helper side, there is no way around storing the prepare share:
we store the most recently computed `PrepareResp` so that we can handle
aggregation job requests idempotently. But to avoid storing prepare
messages twice, the continuing/waiting helper stores just a prepare
state and a `PingPongMessage`.

The main benefit of this change is reducing the number of round trips
between aggregators needed to prepare reports.

Quite a few tests used Prio3 but depended on the leader or helper being
in the `Waiting` state after aggregation initialization; under the
ping-pong topology, a one-round VDAF like Prio3 now finishes during
initialization. Accordingly, those tests are changed to run Poplar1,
which takes two rounds.
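To make the leader-side storage optimization above concrete, here is a
minimal sketch (not part of the patch) of resuming a continuing/waiting
leader from a persisted `PingPongTransition`. Only
`PingPongTransition::evaluate` and the `PingPongState` variants are
taken from the diff below; the function shape, the Poplar1
instantiation, and the `PingPongError` return type are illustrative
assumptions.

    use prio::{
        topology::ping_pong::{PingPongError, PingPongState, PingPongTransition},
        vdaf::{poplar1::Poplar1, xof::XofShake128},
    };

    type Vdaf = Poplar1<XofShake128, 16>;

    /// Hypothetical: resume a continuing/waiting leader from the
    /// `PingPongTransition` it persisted, instead of storing the prepare
    /// state, prepare message, and next-round prepare share individually.
    fn resume_leader(
        vdaf: &Vdaf,
        stored: &PingPongTransition<16, 16, Vdaf>,
    ) -> Result<(), PingPongError> {
        // Recompute the next prepare state along with the outgoing message,
        // which carries the leader's prepare share for the next round.
        let (state, outgoing) = stored.evaluate(vdaf)?;
        match state {
            // Still preparing: `outgoing` goes to the helper in the next
            // AggregationJobContinueReq, and the transition stays stored.
            PingPongState::Continued(_prep_state) => { /* ... */ }
            // Preparation is complete for this report: commit the output
            // share.
            PingPongState::Finished(_output_share) => { /* ... */ }
        }
        let _ = outgoing;
        Ok(())
    }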
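Similarly, a hedged sketch of the helper's side of one continuation
step, mirroring the `helper_continued`/`PingPongContinuedValue`
handling in `aggregation_job_continue.rs` below. (At initialization,
the analogous entry point is `helper_initialized`, visible in the
aggregator.rs hunk.) The concrete VDAF, the parameter types, and the
persistence/response steps in comments are assumptions, not the
patch's exact code.

    use prio::{
        topology::ping_pong::{
            PingPongContinuedValue, PingPongError, PingPongMessage, PingPongState,
            PingPongTopology,
        },
        vdaf::{
            poplar1::{Poplar1, Poplar1AggregationParam},
            xof::XofShake128,
        },
    };

    type Vdaf = Poplar1<XofShake128, 16>;
    type PrepState = <Vdaf as prio::vdaf::Aggregator<16, 16>>::PrepareState;

    /// Hypothetical: one DAP continuation step on the helper.
    /// `stored_prep_state` was persisted at the previous step (the
    /// `WaitingHelper` report aggregation state); `inbound` is the leader's
    /// `PingPongMessage` from the AggregationJobContinueReq.
    fn helper_step(
        vdaf: &Vdaf,
        agg_param: &Poplar1AggregationParam,
        stored_prep_state: PrepState,
        inbound: &PingPongMessage,
    ) -> Result<(), PingPongError> {
        match vdaf.helper_continued(
            PingPongState::Continued(stored_prep_state),
            agg_param,
            inbound,
        )? {
            // The helper has an outgoing message: evaluating the transition
            // yields the next state plus the message for the PrepareResp.
            PingPongContinuedValue::WithMessage { transition } => {
                match transition.evaluate(vdaf)? {
                    // Not finished: persist the new prepare state and await
                    // the leader's next message.
                    (PingPongState::Continued(_prep_state), _outgoing) => { /* ... */ }
                    // Finished: accumulate the output share and still send
                    // the outgoing message so the leader can finish too.
                    (PingPongState::Finished(_output_share), _outgoing) => { /* ... */ }
                }
            }
            // The exchange ended on the inbound message: accumulate the
            // output share and respond `Finished` with no further message.
            PingPongContinuedValue::FinishedNoMessage { output_share: _ } => { /* ... */ }
        }
        Ok(())
    }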
[1]: https://github.com/divviup/libprio-rs/pull/683

Co-authored-by: Tim Geoghegan
Part of #1669
---
 Cargo.lock                                    |    4 +-
 Cargo.toml                                    |    3 +-
 aggregator/Cargo.toml                         |    2 +-
 aggregator/src/aggregator.rs                  |  285 +--
 aggregator/src/aggregator/accumulator.rs      |    2 +-
 .../src/aggregator/aggregate_init_tests.rs    |  265 ++-
 .../aggregator/aggregation_job_continue.rs    |  366 ++--
 .../src/aggregator/aggregation_job_creator.rs |  158 +-
 .../src/aggregator/aggregation_job_driver.rs  | 1596 ++++++++++++-----
 .../src/aggregator/aggregation_job_writer.rs  |    4 +-
 aggregator/src/aggregator/error.rs            |   71 +-
 aggregator/src/aggregator/http_handlers.rs    | 1585 ++++++++--------
 aggregator/src/aggregator/taskprov_tests.rs   |  210 ++-
 aggregator_api/Cargo.toml                     |    2 +-
 aggregator_core/Cargo.toml                    |    2 +-
 aggregator_core/src/datastore.rs              |  122 +-
 aggregator_core/src/datastore/models.rs       |   81 +-
 aggregator_core/src/datastore/tests.rs        |  194 +-
 collector/src/lib.rs                          |    8 +-
 core/Cargo.toml                               |    2 +-
 core/src/test_util/mod.rs                     |  291 ++-
 db/00000000000001_initial_schema.up.sql       |    8 +-
 integration_tests/Cargo.toml                  |    4 +-
 interop_binaries/Cargo.toml                   |    2 +-
 messages/Cargo.toml                           |    4 +-
 messages/src/lib.rs                           |  842 ++++++---
 tools/Cargo.toml                              |    2 +-
 27 files changed, 3921 insertions(+), 2194 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index 2f59ae56a..6c612dff8 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -2937,9 +2937,9 @@ checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de"
 
 [[package]]
 name = "prio"
-version = "0.15.0"
+version = "0.15.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "fe7591b152d20a8a992f8b3a5daf6bc9e38e7fb347e3694ed9238eddc7e57332"
+checksum = "e2e546dc580118e2120309c8aa7bb0da8deabd4b848289c486e9429a27c05594"
 dependencies = [
  "aes",
  "bitvec",
diff --git a/Cargo.toml b/Cargo.toml
index 032738078..233d17380 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -24,6 +24,7 @@ version = "0.6.0"
 
 [workspace.dependencies]
 anyhow = "1"
+base64 = "0.21.3"
 # Disable default features to disable compatibility with the old `time` crate, and we also don't
 # (yet) need other default features.
 # https://docs.rs/chrono/latest/chrono/#duration
@@ -42,7 +43,7 @@ janus_messages = { version = "0.6", path = "messages" }
 k8s-openapi = { version = "0.18.0", features = ["v1_24"] } # keep this version in sync with what is referenced by the indirect dependency via `kube`
 kube = { version = "0.82.2", default-features = false, features = ["client", "rustls-tls"] }
 opentelemetry = { version = "0.20", features = ["metrics"] }
-prio = { version = "0.15.0", features = ["multithreaded"] }
+prio = { version = "0.15.1", features = ["multithreaded", "experimental"] }
 serde = { version = "1.0.188", features = ["derive"] }
 serde_json = "1.0.106"
 serde_test = "1.0.175"
diff --git a/aggregator/Cargo.toml b/aggregator/Cargo.toml
index 358720736..3dca700ff 100644
--- a/aggregator/Cargo.toml
+++ b/aggregator/Cargo.toml
@@ -35,7 +35,7 @@ test-util = [
 async-trait = "0.1"
 anyhow.workspace = true
 backoff = { version = "0.4.0", features = ["tokio"] }
-base64 = "0.21.4"
+base64.workspace = true
 bytes = "1.5.0"
 chrono.workspace = true
 clap = { version = "4.4.2", features = ["derive", "env"] }
diff --git a/aggregator/src/aggregator.rs b/aggregator/src/aggregator.rs
index 90a64ad92..879ad9d03 100644
--- a/aggregator/src/aggregator.rs
+++ b/aggregator/src/aggregator.rs
@@ -52,7 +52,7 @@ use janus_messages::{
     AggregationJobId, AggregationJobInitializeReq, AggregationJobResp, AggregationJobRound,
     BatchSelector, Collection, CollectionJobId, CollectionReq, Duration, HpkeConfig,
     HpkeConfigList, InputShareAad, Interval, PartialBatchSelector, PlaintextInputShare,
-    PrepareStep, PrepareStepResult, Report, ReportIdChecksum, ReportShare, ReportShareError, Role,
+    PrepareError, PrepareResp, PrepareStepResult, Report, ReportIdChecksum, ReportShare, Role,
     TaskId,
 };
 use opentelemetry::{
@@ -65,6 +65,7 @@ use prio::vdaf::prio3::Prio3FixedPointBoundedL2VecSumMultithreaded;
 use prio::vdaf::{PrepareTransition, VdafError};
 use prio::{
     codec::{Decode, Encode, ParameterizedDecode},
+    topology::ping_pong::{PingPongState, PingPongTopology},
     vdaf::{
         self,
         poplar1::Poplar1,
@@ -87,6 +88,8 @@ use tokio::{sync::Mutex, try_join};
 use tracing::{debug, info, trace_span, warn};
 use url::Url;
 
+use self::{accumulator::Accumulator, error::handle_ping_pong_error};
+
 pub mod accumulator;
 #[cfg(test)]
 mod aggregate_init_tests;
@@ -131,6 +134,9 @@ pub(crate) fn aggregate_step_failure_counter(meter: &Meter) -> Counter {
         "decrypt_failure",
         "input_share_decode_failure",
         "public_share_decode_failure",
+        "prepare_message_decode_failure",
+        "leader_prep_share_decode_failure",
+        "helper_prep_share_decode_failure",
         "continue_mismatch",
         "accumulate_failure",
         "finish_mismatch",
@@ -138,6 +144,7 @@
         "plaintext_input_share_decode_failure",
         "duplicate_extension",
         "missing_client_report",
+        "missing_prepare_message",
     ] {
         aggregate_step_failure_counter.add(0, &[KeyValue::new("type", failure_type)]);
     }
@@ -387,6 +394,7 @@ impl Aggregator {
                 &self.datastore,
                 &self.global_hpke_keypairs,
                 &self.aggregate_step_failure_counter,
+                self.cfg.batch_aggregation_shard_count,
                 aggregation_job_id,
                 req_bytes,
             )
@@ -931,6 +939,7 @@ impl TaskAggregator {
         datastore: &Datastore,
         global_hpke_keypairs: &GlobalHpkeKeypairCache,
         aggregate_step_failure_counter: &Counter,
+        batch_aggregation_shard_count: u64,
        aggregation_job_id: &AggregationJobId,
        req_bytes: &[u8],
    ) -> Result {
@@ -940,6 +949,7 @@ impl TaskAggregator {
                 global_hpke_keypairs,
                 aggregate_step_failure_counter,
                 Arc::clone(&self.task),
+
batch_aggregation_shard_count, aggregation_job_id, req_bytes, ) @@ -1223,6 +1233,7 @@ impl VdafOps { global_hpke_keypairs: &GlobalHpkeKeypairCache, aggregate_step_failure_counter: &Counter, task: Arc, + batch_aggregation_shard_count: u64, aggregation_job_id: &AggregationJobId, req_bytes: &[u8], ) -> Result { @@ -1235,6 +1246,7 @@ impl VdafOps { vdaf, aggregate_step_failure_counter, task, + batch_aggregation_shard_count, aggregation_job_id, verify_key, req_bytes, @@ -1250,6 +1262,7 @@ impl VdafOps { vdaf, aggregate_step_failure_counter, task, + batch_aggregation_shard_count, aggregation_job_id, verify_key, req_bytes, @@ -1551,6 +1564,7 @@ impl VdafOps { vdaf: &A, aggregate_step_failure_counter: &Counter, task: Arc, + batch_aggregation_shard_count: u64, aggregation_job_id: &AggregationJobId, verify_key: &VerifyKey, req_bytes: &[u8], @@ -1573,9 +1587,9 @@ impl VdafOps { // If two ReportShare messages have the same report ID, then the helper MUST abort with // error "unrecognizedMessage". (§4.4.4.1) - let mut seen_report_ids = HashSet::with_capacity(req.report_shares().len()); - for share in req.report_shares() { - if !seen_report_ids.insert(share.metadata().id()) { + let mut seen_report_ids = HashSet::with_capacity(req.prepare_inits().len()); + for prepare_init in req.prepare_inits() { + if !seen_report_ids.insert(*prepare_init.report_share().metadata().id()) { return Err(Error::UnrecognizedMessage( Some(*task.id()), "aggregate request contains duplicate report IDs", @@ -1589,53 +1603,71 @@ impl VdafOps { let mut interval_per_batch_identifier: HashMap = HashMap::new(); let agg_param = A::AggregationParam::get_decoded(req.aggregation_parameter())?; - for (ord, report_share) in req.report_shares().iter().enumerate() { + + let mut accumulator = Accumulator::::new( + Arc::clone(&task), + batch_aggregation_shard_count, + agg_param.clone(), + ); + + for (ord, prepare_init) in req.prepare_inits().iter().enumerate() { // Compute intervals for each batch identifier included in this aggregation job. let batch_identifier = Q::to_batch_identifier( &task, req.batch_selector().batch_identifier(), - report_share.metadata().time(), + prepare_init.report_share().metadata().time(), )?; match interval_per_batch_identifier.entry(batch_identifier) { Entry::Occupied(mut entry) => { - *entry.get_mut() = entry.get().merged_with(report_share.metadata().time())?; + *entry.get_mut() = entry + .get() + .merged_with(prepare_init.report_share().metadata().time())?; } Entry::Vacant(entry) => { - entry.insert(Interval::from_time(report_share.metadata().time())?); + entry.insert(Interval::from_time( + prepare_init.report_share().metadata().time(), + )?); } } - let task_hpke_keypair = task - .hpke_keys() - .get(report_share.encrypted_input_share().config_id()); - - let global_hpke_keypair = - global_hpke_keypairs.keypair(report_share.encrypted_input_share().config_id()); - // If decryption fails, then the aggregator MUST fail with error `hpke-decrypt-error`. 
(§4.4.2.2) let try_hpke_open = |hpke_keypair: &HpkeKeypair| { hpke::open( hpke_keypair.config(), hpke_keypair.private_key(), &HpkeApplicationInfo::new(&Label::InputShare, &Role::Client, &Role::Helper), - report_share.encrypted_input_share(), + prepare_init.report_share().encrypted_input_share(), &InputShareAad::new( *task.id(), - report_share.metadata().clone(), - report_share.public_share().to_vec(), + prepare_init.report_share().metadata().clone(), + prepare_init.report_share().public_share().to_vec(), ) .get_encoded(), ) }; + let global_hpke_keypair = global_hpke_keypairs.keypair( + prepare_init + .report_share() + .encrypted_input_share() + .config_id(), + ); + + let task_hpke_keypair = task.hpke_keys().get( + prepare_init + .report_share() + .encrypted_input_share() + .config_id(), + ); + let check_keypairs = if task_hpke_keypair.is_none() && global_hpke_keypair.is_none() { info!( - config_id = %report_share.encrypted_input_share().config_id(), + config_id = %prepare_init.report_share().encrypted_input_share().config_id(), "Helper encrypted input share references unknown HPKE config ID" ); aggregate_step_failure_counter .add(1, &[KeyValue::new("type", "unknown_hpke_config_id")]); - Err(ReportShareError::HpkeUnknownConfigId) + Err(PrepareError::HpkeUnknownConfigId) } else { Ok(()) }; @@ -1657,21 +1689,21 @@ impl VdafOps { .map_err(|error| { info!( task_id = %task.id(), - metadata = ?report_share.metadata(), + metadata = ?prepare_init.report_share().metadata(), ?error, "Couldn't decrypt helper's report share" ); aggregate_step_failure_counter .add(1, &[KeyValue::new("type", "decrypt_failure")]); - ReportShareError::HpkeDecryptError + PrepareError::HpkeDecryptError }) }); let plaintext_input_share = plaintext.and_then(|plaintext| { let plaintext_input_share = PlaintextInputShare::get_decoded(&plaintext).map_err(|error| { - info!(task_id = %task.id(), metadata = ?report_share.metadata(), ?error, "Couldn't decode helper's plaintext input share"); + info!(task_id = %task.id(), metadata = ?prepare_init.report_share().metadata(), ?error, "Couldn't decode helper's plaintext input share"); aggregate_step_failure_counter.add(1, &[KeyValue::new("type", "plaintext_input_share_decode_failure")]); - ReportShareError::UnrecognizedMessage + PrepareError::UnrecognizedMessage })?; // Check for repeated extensions. 
let mut extension_types = HashSet::new(); @@ -1679,9 +1711,9 @@ impl VdafOps { .extensions() .iter() .all(|extension| extension_types.insert(extension.extension_type())) { - info!(task_id = %task.id(), metadata = ?report_share.metadata(), "Received report share with duplicate extensions"); + info!(task_id = %task.id(), metadata = ?prepare_init.report_share().metadata(), "Received report share with duplicate extensions"); aggregate_step_failure_counter.add(1, &[KeyValue::new("type", "duplicate_extension")]); - return Err(ReportShareError::UnrecognizedMessage) + return Err(PrepareError::UnrecognizedMessage) } Ok(plaintext_input_share) }); @@ -1689,16 +1721,16 @@ impl VdafOps { let input_share = plaintext_input_share.and_then(|plaintext_input_share| { A::InputShare::get_decoded_with_param(&(vdaf, Role::Helper.index().unwrap()), plaintext_input_share.payload()) .map_err(|error| { - info!(task_id = %task.id(), metadata = ?report_share.metadata(), ?error, "Couldn't decode helper's input share"); + info!(task_id = %task.id(), metadata = ?prepare_init.report_share().metadata(), ?error, "Couldn't decode helper's input share"); aggregate_step_failure_counter.add(1, &[KeyValue::new("type", "input_share_decode_failure")]); - ReportShareError::UnrecognizedMessage + PrepareError::UnrecognizedMessage }) }); - let public_share = A::PublicShare::get_decoded_with_param(vdaf, report_share.public_share()).map_err(|error|{ - info!(task_id = %task.id(), metadata = ?report_share.metadata(), ?error, "Couldn't decode public share"); + let public_share = A::PublicShare::get_decoded_with_param(vdaf, prepare_init.report_share().public_share()).map_err(|error|{ + info!(task_id = %task.id(), metadata = ?prepare_init.report_share().metadata(), ?error, "Couldn't decode public share"); aggregate_step_failure_counter.add(1, &[KeyValue::new("type", "public_share_decode_failure")]); - ReportShareError::UnrecognizedMessage + PrepareError::UnrecognizedMessage }); let shares = input_share.and_then(|input_share| Ok((public_share?, input_share))); @@ -1707,88 +1739,98 @@ impl VdafOps { // associated with the task and computes the first state transition. [...] If either // step fails, then the aggregator MUST fail with error `vdaf-prep-error`. 
(§4.4.2.2) let init_rslt = shares.and_then(|(public_share, input_share)| { - trace_span!("VDAF preparation") - .in_scope(|| { - vdaf.prepare_init( - verify_key.as_bytes(), - Role::Helper.index().unwrap(), - &agg_param, - report_share.metadata().id().as_ref(), - &public_share, - &input_share, - ) - }) + trace_span!("VDAF preparation").in_scope(|| { + vdaf.helper_initialized( + verify_key.as_bytes(), + &agg_param, + /* report ID is used as VDAF nonce */ + prepare_init.report_share().metadata().id().as_ref(), + &public_share, + &input_share, + prepare_init.message(), + ) + .and_then(|transition| transition.evaluate(vdaf)) .map_err(|error| { - info!( - task_id = %task.id(), - report_id = %report_share.metadata().id(), - ?error, - "Couldn't prepare_init report share" - ); - aggregate_step_failure_counter - .add(1, &[KeyValue::new("type", "prepare_init_failure")]); - ReportShareError::VdafPrepError + handle_ping_pong_error( + task.id(), + Role::Helper, + prepare_init.report_share().metadata().id(), + error, + aggregate_step_failure_counter, + ) }) + }) }); - report_share_data.push(match init_rslt { - Ok((prep_state, prep_share)) => { + let (report_aggregation_state, prepare_step_result) = match init_rslt { + Ok((PingPongState::Continued(prep_state), outgoing_message)) => { + // Helper is not finished. Await the next message from the Leader to advance to + // the next round. saw_continue = true; - - let encoded_prep_share = prep_share.get_encoded(); - ReportShareData::new( - report_share.clone(), - ReportAggregation::::new( - *task.id(), - *aggregation_job_id, - *report_share.metadata().id(), - *report_share.metadata().time(), - ord.try_into()?, - Some(PrepareStep::new( - *report_share.metadata().id(), - PrepareStepResult::Continued(encoded_prep_share), - )), - ReportAggregationState::::Waiting(prep_state, None), - ), + ( + ReportAggregationState::WaitingHelper(prep_state), + PrepareStepResult::Continue { + message: outgoing_message, + }, ) } + Ok((PingPongState::Finished(output_share), outgoing_message)) => { + // Helper finished. Unlike the Leader, the Helper does not wait for confirmation + // that the Leader finished before accumulating its output share. + accumulator.update( + req.batch_selector().batch_identifier(), + prepare_init.report_share().metadata().id(), + prepare_init.report_share().metadata().time(), + &output_share, + )?; + ( + ReportAggregationState::Finished, + PrepareStepResult::Continue { + message: outgoing_message, + }, + ) + } + Err(prepare_error) => ( + ReportAggregationState::Failed(prepare_error), + PrepareStepResult::Reject(prepare_error), + ), + }; - Err(err) => ReportShareData::new( - report_share.clone(), - ReportAggregation::::new( - *task.id(), - *aggregation_job_id, - *report_share.metadata().id(), - *report_share.metadata().time(), - ord.try_into()?, - Some(PrepareStep::new( - *report_share.metadata().id(), - PrepareStepResult::Failed(err), - )), - ReportAggregationState::::Failed(err), - ), + report_share_data.push(ReportShareData::new( + prepare_init.report_share().clone(), + ReportAggregation::::new( + *task.id(), + *aggregation_job_id, + *prepare_init.report_share().metadata().id(), + *prepare_init.report_share().metadata().time(), + ord.try_into()?, + Some(PrepareResp::new( + *prepare_init.report_share().metadata().id(), + prepare_step_result, + )), + report_aggregation_state, ), - }); + )); } // Store data to datastore. 
let req = Arc::new(req); let min_client_timestamp = req - .report_shares() + .prepare_inits() .iter() - .map(|report_share| report_share.metadata().time()) + .map(|prepare_init| *prepare_init.report_share().metadata().time()) .min() .ok_or_else(|| Error::EmptyAggregation(*task.id()))?; let max_client_timestamp = req - .report_shares() + .prepare_inits() .iter() - .map(|report_share| report_share.metadata().time()) + .map(|prepare_init| *prepare_init.report_share().metadata().time()) .max() .ok_or_else(|| Error::EmptyAggregation(*task.id()))?; let client_timestamp_interval = Interval::new( - *min_client_timestamp, + min_client_timestamp, max_client_timestamp - .difference(min_client_timestamp)? + .difference(&min_client_timestamp)? .add(&Duration::from_seconds(1))?, )?; let aggregation_job = Arc::new( @@ -1808,17 +1850,21 @@ impl VdafOps { .with_last_request_hash(request_hash), ); let interval_per_batch_identifier = Arc::new(interval_per_batch_identifier); + let accumulator = Arc::new(accumulator); Ok(datastore - .run_tx_with_name("aggregate_init", |tx| { + .run_tx_with_name("aggregate_init", |tx| { let vdaf = vdaf.clone(); let task = Arc::clone(&task); let req = Arc::clone(&req); let aggregation_job = Arc::clone(&aggregation_job); let mut report_share_data = report_share_data.clone(); let interval_per_batch_identifier = Arc::clone(&interval_per_batch_identifier); + let accumulator = Arc::clone(&accumulator); Box::pin(async move { + let unwritable_reports = accumulator.flush_to_datastore(tx, &vdaf).await?; + for report_share_data in &mut report_share_data { // Verify that we haven't seen this report ID and aggregation parameter // before in another aggregation job, and that the report isn't for a batch @@ -1843,27 +1889,26 @@ impl VdafOps { )?; if report_aggregation_exists { - report_share_data.report_aggregation = report_share_data - .report_aggregation - .clone() - .with_state(ReportAggregationState::Failed( - ReportShareError::ReportReplayed, - )) - .with_last_prep_step(Some(PrepareStep::new( - *report_share_data.report_share.metadata().id(), - PrepareStepResult::Failed(ReportShareError::ReportReplayed), - ))); - } else if !conflicting_aggregate_share_jobs.is_empty() { - report_share_data.report_aggregation = report_share_data - .report_aggregation - .clone() - .with_state(ReportAggregationState::Failed( - ReportShareError::BatchCollected, - )) - .with_last_prep_step(Some(PrepareStep::new( - *report_share_data.report_share.metadata().id(), - PrepareStepResult::Failed(ReportShareError::BatchCollected), - ))); + report_share_data.report_aggregation = + report_share_data.report_aggregation + .clone() + .with_state(ReportAggregationState::Failed( + PrepareError::ReportReplayed)) + .with_last_prep_resp(Some(PrepareResp::new( + *report_share_data.report_share.metadata().id(), + PrepareStepResult::Reject(PrepareError::ReportReplayed)) + )); + } else if !conflicting_aggregate_share_jobs.is_empty() || + unwritable_reports.contains(report_share_data.report_aggregation.report_id()) { + report_share_data.report_aggregation = + report_share_data.report_aggregation + .clone() + .with_state(ReportAggregationState::Failed( + PrepareError::BatchCollected)) + .with_last_prep_resp(Some(PrepareResp::new( + *report_share_data.report_share.metadata().id(), + PrepareStepResult::Reject(PrepareError::BatchCollected)) + )); } } @@ -1912,12 +1957,12 @@ impl VdafOps { .report_aggregation .clone() .with_state(ReportAggregationState::Failed( - ReportShareError::ReportReplayed, + PrepareError::ReportReplayed, )) - 
.with_last_prep_step(Some(PrepareStep::new( + .with_last_prep_resp(Some(PrepareResp::new( *rsd.report_share.metadata().id(), - PrepareStepResult::Failed( - ReportShareError::ReportReplayed, + PrepareStepResult::Reject( + PrepareError::ReportReplayed, ), ))); } @@ -1932,7 +1977,11 @@ impl VdafOps { let task = Arc::clone(&task); let aggregation_job = Arc::clone(&aggregation_job); async move { - match tx.get_batch::(task.id(), batch_identifier, aggregation_job.aggregation_parameter()).await? { + match tx.get_batch::( + task.id(), + batch_identifier, + aggregation_job.aggregation_parameter(), + ).await? { Some(batch) => { let interval = batch.client_timestamp_interval().merge(interval)?; tx.update_batch(&batch.with_client_timestamp_interval(interval)).await?; @@ -2019,7 +2068,7 @@ impl VdafOps { &Role::Helper, task.id(), &aggregation_job_id, - ), + ) )?; let helper_aggregation_job = helper_aggregation_job.ok_or_else(|| { diff --git a/aggregator/src/aggregator/accumulator.rs b/aggregator/src/aggregator/accumulator.rs index 0bdf2b9ba..374185c1f 100644 --- a/aggregator/src/aggregator/accumulator.rs +++ b/aggregator/src/aggregator/accumulator.rs @@ -40,7 +40,7 @@ pub struct Accumulator< aggregations: HashMap>, } -#[derive(Debug)] +#[derive(Clone, Debug)] struct BatchData< const SEED_SIZE: usize, Q: AccumulableQueryType, diff --git a/aggregator/src/aggregator/aggregate_init_tests.rs b/aggregator/src/aggregator/aggregate_init_tests.rs index 59ffefdbe..303263e41 100644 --- a/aggregator/src/aggregator/aggregate_init_tests.rs +++ b/aggregator/src/aggregator/aggregate_init_tests.rs @@ -1,6 +1,9 @@ use crate::aggregator::{ - http_handlers::aggregator_handler, tests::generate_helper_report_share, Config, + http_handlers::{aggregator_handler, test_util::decode_response_body}, + tests::generate_helper_report_share, + Config, }; +use assert_matches::assert_matches; use janus_aggregator_core::{ datastore::{ test_util::{ephemeral_datastore, EphemeralDatastore}, @@ -15,92 +18,156 @@ use janus_core::{ time::{Clock, MockClock, TimeExt as _}, }; use janus_messages::{ - query_type::TimeInterval, AggregationJobId, AggregationJobInitializeReq, PartialBatchSelector, - ReportMetadata, ReportShare, Role, + query_type::TimeInterval, AggregationJobId, AggregationJobInitializeReq, AggregationJobResp, + PartialBatchSelector, PrepareInit, PrepareStepResult, ReportMetadata, Role, +}; +use prio::{ + codec::Encode, + idpf::IdpfInput, + vdaf::{ + self, + poplar1::{Poplar1, Poplar1AggregationParam}, + xof::XofShake128, + }, }; -use prio::codec::Encode; use rand::random; use std::sync::Arc; use trillium::{Handler, KnownHeaderName, Status}; use trillium_testing::{prelude::put, TestConn}; -pub(super) struct ReportShareGenerator { +pub(super) struct PrepareInitGenerator +where + V: vdaf::Vdaf, +{ clock: MockClock, task: Task, - aggregation_param: dummy_vdaf::AggregationParam, - vdaf: dummy_vdaf::Vdaf, + vdaf: V, + aggregation_param: V::AggregationParam, } -impl ReportShareGenerator { +impl PrepareInitGenerator +where + V: vdaf::Vdaf + vdaf::Aggregator + vdaf::Client<16>, +{ pub(super) fn new( clock: MockClock, task: Task, - aggregation_param: dummy_vdaf::AggregationParam, + vdaf: V, + aggregation_param: V::AggregationParam, ) -> Self { Self { clock, task, + vdaf, aggregation_param, - vdaf: dummy_vdaf::Vdaf::new(), } } - fn with_vdaf(mut self, vdaf: dummy_vdaf::Vdaf) -> Self { - self.vdaf = vdaf; - self - } - - pub(super) fn next(&self) -> (ReportShare, VdafTranscript<0, dummy_vdaf::Vdaf>) { - 
self.next_with_metadata(ReportMetadata::new( - random(), - self.clock - .now() - .to_batch_interval_start(self.task.time_precision()) - .unwrap(), - )) + pub(super) fn next( + &self, + measurement: &V::Measurement, + ) -> (PrepareInit, VdafTranscript) { + self.next_with_metadata( + ReportMetadata::new( + random(), + self.clock + .now() + .to_batch_interval_start(self.task.time_precision()) + .unwrap(), + ), + measurement, + ) } pub(super) fn next_with_metadata( &self, report_metadata: ReportMetadata, - ) -> (ReportShare, VdafTranscript<0, dummy_vdaf::Vdaf>) { + measurement: &V::Measurement, + ) -> (PrepareInit, VdafTranscript) { let transcript = run_vdaf( &self.vdaf, self.task.primary_vdaf_verify_key().unwrap().as_bytes(), &self.aggregation_param, report_metadata.id(), - &(), + measurement, ); - let report_share = generate_helper_report_share::( + let report_share = generate_helper_report_share::( *self.task.id(), report_metadata, self.task.current_hpke_key().config(), &transcript.public_share, Vec::new(), - &transcript.input_shares[1], + &transcript.helper_input_share, ); - - (report_share, transcript) + ( + PrepareInit::new( + report_share, + transcript.leader_prepare_transitions[0].message.clone(), + ), + transcript, + ) } } -pub(super) struct AggregationJobInitTestCase { +pub(super) struct AggregationJobInitTestCase< + const VERIFY_KEY_SIZE: usize, + V: vdaf::Aggregator, +> { pub(super) clock: MockClock, pub(super) task: Task, - pub(super) report_share_generator: ReportShareGenerator, - pub(super) report_shares: Vec, + pub(super) prepare_init_generator: PrepareInitGenerator, + pub(super) prepare_inits: Vec, pub(super) aggregation_job_id: AggregationJobId, aggregation_job_init_req: AggregationJobInitializeReq, - pub(super) aggregation_param: dummy_vdaf::AggregationParam, + aggregation_job_init_resp: Option, + pub(super) aggregation_param: V::AggregationParam, pub(super) handler: Box, pub(super) datastore: Arc>, _ephemeral_datastore: EphemeralDatastore, } -pub(super) async fn setup_aggregate_init_test() -> AggregationJobInitTestCase { - let test_case = setup_aggregate_init_test_without_sending_request().await; +pub(super) async fn setup_aggregate_init_test() -> AggregationJobInitTestCase<0, dummy_vdaf::Vdaf> { + setup_aggregate_init_test_for_vdaf( + dummy_vdaf::Vdaf::new(), + VdafInstance::Fake, + dummy_vdaf::AggregationParam(0), + (), + ) + .await +} + +async fn setup_poplar1_aggregate_init_test( +) -> AggregationJobInitTestCase<16, Poplar1> { + let aggregation_param = + Poplar1AggregationParam::try_from_prefixes(Vec::from([IdpfInput::from_bools(&[false])])) + .unwrap(); + setup_aggregate_init_test_for_vdaf( + Poplar1::new_shake128(1), + VdafInstance::Poplar1 { bits: 1 }, + aggregation_param, + IdpfInput::from_bools(&[true]), + ) + .await +} - let response = put_aggregation_job( +async fn setup_aggregate_init_test_for_vdaf< + const VERIFY_KEY_SIZE: usize, + V: vdaf::Aggregator + vdaf::Client<16>, +>( + vdaf: V, + vdaf_instance: VdafInstance, + aggregation_param: V::AggregationParam, + measurement: V::Measurement, +) -> AggregationJobInitTestCase { + let mut test_case = setup_aggregate_init_test_without_sending_request( + vdaf, + vdaf_instance, + aggregation_param, + measurement, + ) + .await; + + let mut response = put_aggregation_job( &test_case.task, &test_case.aggregation_job_id, &test_case.aggregation_job_init_req, @@ -109,13 +176,32 @@ pub(super) async fn setup_aggregate_init_test() -> AggregationJobInitTestCase { .await; assert_eq!(response.status(), Some(Status::Ok)); + let 
aggregation_job_init_resp: AggregationJobResp = decode_response_body(&mut response).await; + assert_eq!( + aggregation_job_init_resp.prepare_resps().len(), + test_case.aggregation_job_init_req.prepare_inits().len(), + ); + assert_matches!( + aggregation_job_init_resp.prepare_resps()[0].result(), + &PrepareStepResult::Continue { .. } + ); + + test_case.aggregation_job_init_resp = Some(aggregation_job_init_resp); test_case } -async fn setup_aggregate_init_test_without_sending_request() -> AggregationJobInitTestCase { +async fn setup_aggregate_init_test_without_sending_request< + const VERIFY_KEY_SIZE: usize, + V: vdaf::Aggregator + vdaf::Client<16>, +>( + vdaf: V, + vdaf_instance: VdafInstance, + aggregation_param: V::AggregationParam, + measurement: V::Measurement, +) -> AggregationJobInitTestCase { install_test_trace_subscriber(); - let task = TaskBuilder::new(QueryType::TimeInterval, VdafInstance::Fake, Role::Helper).build(); + let task = TaskBuilder::new(QueryType::TimeInterval, vdaf_instance, Role::Helper).build(); let clock = MockClock::default(); let ephemeral_datastore = ephemeral_datastore().await; let datastore = Arc::new(ephemeral_datastore.datastore(clock.clone()).await); @@ -131,30 +217,29 @@ async fn setup_aggregate_init_test_without_sending_request() -> AggregationJobIn .await .unwrap(); - let aggregation_param = dummy_vdaf::AggregationParam(0); - - let report_share_generator = - ReportShareGenerator::new(clock.clone(), task.clone(), aggregation_param); + let prepare_init_generator = + PrepareInitGenerator::new(clock.clone(), task.clone(), vdaf, aggregation_param.clone()); - let report_shares = Vec::from([ - report_share_generator.next().0, - report_share_generator.next().0, + let prepare_inits = Vec::from([ + prepare_init_generator.next(&measurement).0, + prepare_init_generator.next(&measurement).0, ]); let aggregation_job_id = random(); let aggregation_job_init_req = AggregationJobInitializeReq::new( aggregation_param.get_encoded(), PartialBatchSelector::new_time_interval(), - report_shares.clone(), + prepare_inits.clone(), ); AggregationJobInitTestCase { clock, task, - report_shares, - report_share_generator, + prepare_inits, + prepare_init_generator, aggregation_job_id, aggregation_job_init_req, + aggregation_job_init_resp: None, aggregation_param, handler: Box::new(handler), datastore, @@ -184,7 +269,13 @@ pub(crate) async fn put_aggregation_job( #[tokio::test] async fn aggregation_job_init_authorization_dap_auth_token() { - let test_case = setup_aggregate_init_test_without_sending_request().await; + let test_case = setup_aggregate_init_test_without_sending_request( + dummy_vdaf::Vdaf::new(), + VdafInstance::Fake, + dummy_vdaf::AggregationParam(0), + (), + ) + .await; // Find a DapAuthToken among the task's aggregator auth tokens let (auth_header, auth_value) = test_case .task @@ -216,7 +307,13 @@ async fn aggregation_job_init_authorization_dap_auth_token() { #[case::not_base64("Bearer: ")] #[tokio::test] async fn aggregation_job_init_malformed_authorization_header(#[case] header_value: &'static str) { - let test_case = setup_aggregate_init_test_without_sending_request().await; + let test_case = setup_aggregate_init_test_without_sending_request( + dummy_vdaf::Vdaf::new(), + VdafInstance::Fake, + dummy_vdaf::AggregationParam(0), + (), + ) + .await; let response = put(test_case .task @@ -254,7 +351,7 @@ async fn aggregation_job_mutation_aggregation_job() { let mutated_aggregation_job_init_req = AggregationJobInitializeReq::new( dummy_vdaf::AggregationParam(1).get_encoded(), 
PartialBatchSelector::new_time_interval(), - test_case.report_shares, + test_case.prepare_inits, ); let response = put_aggregation_job( @@ -273,28 +370,28 @@ async fn aggregation_job_mutation_report_shares() { // Put the aggregation job again, mutating the associated report shares' metadata such that // uniqueness constraints on client_reports are violated - for mutated_report_shares in [ + for mutated_prepare_inits in [ // Omit a report share that was included previously - Vec::from(&test_case.report_shares[0..test_case.report_shares.len() - 1]), + Vec::from(&test_case.prepare_inits[0..test_case.prepare_inits.len() - 1]), // Include a different report share than was included previously [ - &test_case.report_shares[0..test_case.report_shares.len() - 1], - &[test_case.report_share_generator.next().0], + &test_case.prepare_inits[0..test_case.prepare_inits.len() - 1], + &[test_case.prepare_init_generator.next(&()).0], ] .concat(), // Include an extra report share than was included previously [ - test_case.report_shares.as_slice(), - &[test_case.report_share_generator.next().0], + test_case.prepare_inits.as_slice(), + &[test_case.prepare_init_generator.next(&()).0], ] .concat(), // Reverse the order of the reports - test_case.report_shares.into_iter().rev().collect(), + test_case.prepare_inits.into_iter().rev().collect(), ] { let mutated_aggregation_job_init_req = AggregationJobInitializeReq::new( test_case.aggregation_param.get_encoded(), PartialBatchSelector::new_time_interval(), - mutated_report_shares, + mutated_prepare_inits, ); let response = put_aggregation_job( &test_case.task, @@ -309,20 +406,22 @@ async fn aggregation_job_mutation_report_shares() { #[tokio::test] async fn aggregation_job_mutation_report_aggregations() { - let test_case = setup_aggregate_init_test().await; - - // Generate some new reports using the existing reports' metadata, but varying the input shares - // such that the prepare state computed during aggregation initializaton won't match the first - // aggregation job. - let mutated_report_shares_generator = test_case - .report_share_generator - .with_vdaf(dummy_vdaf::Vdaf::new().with_input_share(dummy_vdaf::InputShare(1))); - let mutated_report_shares = test_case - .report_shares + // We must run Poplar1 in this test so that the aggregation job won't finish on the first step + let test_case = setup_poplar1_aggregate_init_test().await; + + // Generate some new reports using the existing reports' metadata, but varying the measurement + // values such that the prepare state computed during aggregation initializaton won't match the + // first aggregation job. 
+ let mutated_prepare_inits = test_case + .prepare_inits .iter() .map(|s| { - mutated_report_shares_generator - .next_with_metadata(s.metadata().clone()) + test_case + .prepare_init_generator + .next_with_metadata( + s.report_share().metadata().clone(), + &IdpfInput::from_bools(&[false]), + ) .0 }) .collect(); @@ -330,8 +429,9 @@ async fn aggregation_job_mutation_report_aggregations() { let mutated_aggregation_job_init_req = AggregationJobInitializeReq::new( test_case.aggregation_param.get_encoded(), PartialBatchSelector::new_time_interval(), - mutated_report_shares, + mutated_prepare_inits, ); + let response = put_aggregation_job( &test_case.task, &test_case.aggregation_job_id, @@ -341,3 +441,24 @@ async fn aggregation_job_mutation_report_aggregations() { .await; assert_eq!(response.status(), Some(Status::Conflict)); } + +#[tokio::test] +async fn aggregation_job_init_two_round_vdaf_idempotence() { + // We must run Poplar1 in this test so that the aggregation job won't finish on the first step + let test_case = setup_poplar1_aggregate_init_test().await; + + // Send the aggregation job init request again. We should get an identical response back. + let mut response = put_aggregation_job( + &test_case.task, + &test_case.aggregation_job_id, + &test_case.aggregation_job_init_req, + &test_case.handler, + ) + .await; + + let aggregation_job_resp: AggregationJobResp = decode_response_body(&mut response).await; + assert_eq!( + aggregation_job_resp, + test_case.aggregation_job_init_resp.unwrap() + ); +} diff --git a/aggregator/src/aggregator/aggregation_job_continue.rs b/aggregator/src/aggregator/aggregation_job_continue.rs index 31bbb066b..0b7e5c3b9 100644 --- a/aggregator/src/aggregator/aggregation_job_continue.rs +++ b/aggregator/src/aggregator/aggregation_job_continue.rs @@ -1,5 +1,6 @@ //! Implements portions of aggregation job continuation for the helper. +use super::error::handle_ping_pong_error; use crate::aggregator::{accumulator::Accumulator, Error, VdafOps}; use futures::future::try_join_all; use janus_aggregator_core::{ @@ -13,16 +14,18 @@ use janus_aggregator_core::{ }; use janus_core::time::Clock; use janus_messages::{ - AggregationJobContinueReq, AggregationJobResp, PrepareStep, PrepareStepResult, ReportShareError, + AggregationJobContinueReq, AggregationJobResp, PrepareError, PrepareResp, PrepareStepResult, + Role, }; -use opentelemetry::{metrics::Counter, KeyValue}; +use opentelemetry::metrics::Counter; use prio::{ codec::{Encode, ParameterizedDecode}, - vdaf::{self, PrepareTransition}, + topology::ping_pong::{PingPongContinuedValue, PingPongState, PingPongTopology}, + vdaf, }; -use std::{io::Cursor, sync::Arc}; +use std::sync::Arc; use tokio::try_join; -use tracing::{info, trace_span}; +use tracing::trace_span; impl VdafOps { /// Step the helper's aggregation job to the next round of VDAF preparation using the round `n` @@ -69,13 +72,11 @@ impl VdafOps { if report_agg.report_id() != prep_step.report_id() { // This report was omitted by the leader because of a prior failure. Note that // the report was dropped (if it's not already in an error state) and continue. 
- if matches!(report_agg.state(), ReportAggregationState::Waiting(_, _)) { + if matches!(report_agg.state(), ReportAggregationState::WaitingHelper(_)) { *report_agg = report_agg .clone() - .with_state(ReportAggregationState::Failed( - ReportShareError::ReportDropped, - )) - .with_last_prep_step(None); + .with_state(ReportAggregationState::Failed(PrepareError::ReportDropped)) + .with_last_prep_resp(None); } continue; } @@ -94,108 +95,120 @@ impl VdafOps { if !conflicting_aggregate_share_jobs.is_empty() { *report_aggregation = report_aggregation .clone() - .with_state(ReportAggregationState::Failed( - ReportShareError::BatchCollected, - )) - .with_last_prep_step(Some(PrepareStep::new( + .with_state(ReportAggregationState::Failed(PrepareError::BatchCollected)) + .with_last_prep_resp(Some(PrepareResp::new( *prep_step.report_id(), - PrepareStepResult::Failed(ReportShareError::BatchCollected), + PrepareStepResult::Reject(PrepareError::BatchCollected), ))); continue; } let prep_state = match report_aggregation.state() { - ReportAggregationState::Waiting(prep_state, _) => prep_state, - _ => { + ReportAggregationState::WaitingHelper(prep_state) => prep_state, + ReportAggregationState::WaitingLeader(_) => { return Err(datastore::Error::User( - Error::UnrecognizedMessage( - Some(*task.id()), - "leader sent prepare step for non-WAITING report aggregation", + Error::Internal( + "helper encountered unexpected ReportAggregationState::WaitingLeader" + .to_string(), ) .into(), - )); + )) } - }; - - // Parse preparation message out of prepare step received from leader. - let prep_msg = match prep_step.result() { - PrepareStepResult::Continued(payload) => A::PrepareMessage::decode_with_param( - prep_state, - &mut Cursor::new(payload.as_ref()), - )?, _ => { return Err(datastore::Error::User( Error::UnrecognizedMessage( Some(*task.id()), - "leader sent non-Continued prepare step", + "leader sent prepare step for non-WAITING report aggregation", ) .into(), )); } }; - // Compute the next transition. 
- let prepare_step_res = trace_span!("VDAF preparation") - .in_scope(|| vdaf.prepare_next(prep_state.clone(), prep_msg)); - match prepare_step_res { - Ok(PrepareTransition::Continue(prep_state, prep_share)) => { - *report_aggregation = report_aggregation - .clone() - .with_state(ReportAggregationState::Waiting(prep_state, None)) - .with_last_prep_step(Some(PrepareStep::new( - *prep_step.report_id(), - PrepareStepResult::Continued(prep_share.get_encoded()), - ))); - } - - Ok(PrepareTransition::Finish(output_share)) => { - accumulator.update( - helper_aggregation_job.partial_batch_identifier(), - prep_step.report_id(), - report_aggregation.time(), - &output_share, - )?; - *report_aggregation = report_aggregation - .clone() - .with_state(ReportAggregationState::Finished) - .with_last_prep_step(Some(PrepareStep::new( - *prep_step.report_id(), - PrepareStepResult::Finished, - ))); - } - - Err(error) => { - info!( - task_id = %task.id(), - job_id = %helper_aggregation_job.id(), - report_id = %prep_step.report_id(), - ?error, "Prepare step failed", - ); - aggregate_step_failure_counter - .add(1, &[KeyValue::new("type", "prepare_step_failure")]); - *report_aggregation = report_aggregation - .clone() - .with_state(ReportAggregationState::Failed( - ReportShareError::VdafPrepError, - )) - .with_last_prep_step(Some(PrepareStep::new( - *prep_step.report_id(), - PrepareStepResult::Failed(ReportShareError::VdafPrepError), - ))) - } - }; + let (report_aggregation_state, prepare_step_result, output_share) = + trace_span!("VDAF preparation") + .in_scope(|| { + // Continue with the incoming message. + vdaf.helper_continued( + PingPongState::Continued(prep_state.clone()), + helper_aggregation_job.aggregation_parameter(), + prep_step.message(), + ) + .and_then( + |continued_value| match continued_value { + PingPongContinuedValue::WithMessage { transition } => { + let (new_state, message) = + transition.evaluate(vdaf.as_ref())?; + let (report_aggregation_state, output_share) = match new_state { + // Helper did not finish. Store the new state and await the + // next message from the Leader to advance preparation. + PingPongState::Continued(prep_state) => ( + ReportAggregationState::WaitingHelper(prep_state), + None, + ), + // Helper finished. Commit the output share. + PingPongState::Finished(output_share) => { + (ReportAggregationState::Finished, Some(output_share)) + } + }; + + Ok(( + report_aggregation_state, + // Helper has an outgoing message for Leader + PrepareStepResult::Continue { message }, + output_share, + )) + } + PingPongContinuedValue::FinishedNoMessage { output_share } => Ok(( + ReportAggregationState::Finished, + PrepareStepResult::Finished, + Some(output_share), + )), + }, + ) + }) + .map_err(|error| { + handle_ping_pong_error( + task.id(), + Role::Leader, + prep_step.report_id(), + error, + &aggregate_step_failure_counter, + ) + }) + .unwrap_or_else(|prepare_error| { + ( + ReportAggregationState::Failed(prepare_error), + PrepareStepResult::Reject(prepare_error), + None, + ) + }); + + *report_aggregation = report_aggregation + .clone() + .with_state(report_aggregation_state) + .with_last_prep_resp(Some(PrepareResp::new( + *prep_step.report_id(), + prepare_step_result, + ))); + if let Some(output_share) = output_share { + accumulator.update( + helper_aggregation_job.partial_batch_identifier(), + prep_step.report_id(), + report_aggregation.time(), + &output_share, + )?; + } } for report_agg in report_aggregations_iter { // This report was omitted by the leader because of a prior failure. 
Note that the // report was dropped (if it's not already in an error state) and continue. - if matches!(report_agg.state(), ReportAggregationState::Waiting(_, _)) { + if matches!(report_agg.state(), ReportAggregationState::WaitingHelper(_)) { *report_agg = report_agg .clone() - .with_state(ReportAggregationState::Failed( - ReportShareError::ReportDropped, - )) - .with_last_prep_step(None); + .with_state(ReportAggregationState::Failed(PrepareError::ReportDropped)) + .with_last_prep_resp(None); } } @@ -206,26 +219,24 @@ impl VdafOps { if unwritable_reports.contains(report_aggregation.report_id()) { *report_aggregation = report_aggregation .clone() - .with_state(ReportAggregationState::Failed( - ReportShareError::BatchCollected, - )) - .with_last_prep_step(Some(PrepareStep::new( + .with_state(ReportAggregationState::Failed(PrepareError::BatchCollected)) + .with_last_prep_resp(Some(PrepareResp::new( *report_aggregation.report_id(), - PrepareStepResult::Failed(ReportShareError::BatchCollected), + PrepareStepResult::Reject(PrepareError::BatchCollected), ))); } } let saw_continue = report_aggregations.iter().any(|report_agg| { matches!( - report_agg.last_prep_step().map(PrepareStep::result), - Some(PrepareStepResult::Continued(_)) + report_agg.last_prep_resp().map(PrepareResp::result), + Some(PrepareStepResult::Continue { .. }) ) }); let saw_finish = report_aggregations.iter().any(|report_agg| { matches!( - report_agg.last_prep_step().map(PrepareStep::result), - Some(PrepareStepResult::Finished) + report_agg.last_prep_resp().map(PrepareResp::result), + Some(PrepareStepResult::Finished { .. }) ) }); let helper_aggregation_job = helper_aggregation_job @@ -252,7 +263,7 @@ impl VdafOps { try_join_all( report_aggregations .iter() - .map(|ra| tx.update_report_aggregation(ra)) + .map(|report_agg| tx.update_report_aggregation(report_agg)), ), )?; @@ -260,16 +271,16 @@ impl VdafOps { } /// Constructs an AggregationJobResp from a given set of Helper report aggregations. 
- pub(super) fn aggregation_job_resp_for< - const SEED_SIZE: usize, - A: vdaf::Aggregator, - >( + pub(super) fn aggregation_job_resp_for( report_aggregations: impl IntoIterator>, - ) -> AggregationJobResp { + ) -> AggregationJobResp + where + A: vdaf::Aggregator, + { AggregationJobResp::new( report_aggregations .into_iter() - .filter_map(|ra| ra.last_prep_step().cloned()) + .filter_map(|v| v.last_prep_resp().cloned()) .collect(), ) } @@ -366,7 +377,7 @@ pub mod test_util { #[cfg(test)] mod tests { use crate::aggregator::{ - aggregate_init_tests::ReportShareGenerator, + aggregate_init_tests::PrepareInitGenerator, aggregation_job_continue::test_util::{ post_aggregation_job_and_decode, post_aggregation_job_expecting_error, post_aggregation_job_expecting_status, @@ -386,23 +397,33 @@ mod tests { test_util::noop_meter, }; use janus_core::{ - task::VdafInstance, - test_util::{dummy_vdaf, install_test_trace_subscriber}, + task::{VdafInstance, VERIFY_KEY_LENGTH}, + test_util::install_test_trace_subscriber, time::{IntervalExt, MockClock}, }; use janus_messages::{ query_type::TimeInterval, AggregationJobContinueReq, AggregationJobId, AggregationJobResp, - AggregationJobRound, Interval, PrepareStep, PrepareStepResult, Role, + AggregationJobRound, Interval, PrepareContinue, PrepareResp, PrepareStepResult, Role, + }; + use prio::{ + idpf::IdpfInput, + vdaf::{ + poplar1::{Poplar1, Poplar1AggregationParam}, + xof::XofShake128, + Aggregator, + }, }; - use prio::codec::Encode; use rand::random; use std::sync::Arc; use trillium::{Handler, Status}; - struct AggregationJobContinueTestCase { + struct AggregationJobContinueTestCase< + const VERIFY_KEY_LENGTH: usize, + V: Aggregator, + > { task: Task, datastore: Arc>, - report_generator: ReportShareGenerator, + prepare_init_generator: PrepareInitGenerator, aggregation_job_id: AggregationJobId, first_continue_request: AggregationJobContinueReq, first_continue_response: Option, @@ -412,61 +433,86 @@ mod tests { /// Set up a helper with an aggregation job in round 0 #[allow(clippy::unit_arg)] - async fn setup_aggregation_job_continue_test() -> AggregationJobContinueTestCase { + async fn setup_aggregation_job_continue_test( + ) -> AggregationJobContinueTestCase> { // Prepare datastore & request. 
install_test_trace_subscriber(); let aggregation_job_id = random(); - let task = - TaskBuilder::new(QueryType::TimeInterval, VdafInstance::Fake, Role::Helper).build(); + let task = TaskBuilder::new( + QueryType::TimeInterval, + VdafInstance::Poplar1 { bits: 1 }, + Role::Helper, + ) + .build(); let clock = MockClock::default(); let ephemeral_datastore = ephemeral_datastore().await; let meter = noop_meter(); let datastore = Arc::new(ephemeral_datastore.datastore(clock.clone()).await); - let report_generator = ReportShareGenerator::new( + let aggregation_param = Poplar1AggregationParam::try_from_prefixes(Vec::from([ + IdpfInput::from_bools(&[false]), + ])) + .unwrap(); + let prepare_init_generator = PrepareInitGenerator::new( clock.clone(), task.clone(), - dummy_vdaf::AggregationParam::default(), + Poplar1::new_shake128(1), + aggregation_param.clone(), ); - let report = report_generator.next(); + let (prepare_init, transcript) = + prepare_init_generator.next(&IdpfInput::from_bools(&[true])); datastore .run_tx(|tx| { - let (task, report) = (task.clone(), report.clone()); + let (task, aggregation_param, prepare_init, transcript) = ( + task.clone(), + aggregation_param.clone(), + prepare_init.clone(), + transcript.clone(), + ); Box::pin(async move { tx.put_task(&task).await.unwrap(); - tx.put_report_share(task.id(), &report.0).await.unwrap(); + tx.put_report_share(task.id(), prepare_init.report_share()) + .await + .unwrap(); - tx.put_aggregation_job( - &AggregationJob::<0, TimeInterval, dummy_vdaf::Vdaf>::new( + tx.put_aggregation_job(&AggregationJob::< + VERIFY_KEY_LENGTH, + TimeInterval, + Poplar1, + >::new( + *task.id(), + aggregation_job_id, + aggregation_param, + (), + Interval::from_time(prepare_init.report_share().metadata().time()).unwrap(), + AggregationJobState::InProgress, + AggregationJobRound::from(0), + )) + .await + .unwrap(); + + tx.put_report_aggregation::>( + &ReportAggregation::new( *task.id(), aggregation_job_id, - dummy_vdaf::AggregationParam::default(), - (), - Interval::from_time(report.0.metadata().time()).unwrap(), - AggregationJobState::InProgress, - AggregationJobRound::from(0), + *prepare_init.report_share().metadata().id(), + *prepare_init.report_share().metadata().time(), + 0, + None, + ReportAggregationState::WaitingHelper( + transcript.helper_prepare_transitions[0] + .prepare_state() + .clone(), + ), ), ) .await .unwrap(); - let (prep_state, _) = report.1.helper_prep_state(0); - tx.put_report_aggregation::<0, dummy_vdaf::Vdaf>(&ReportAggregation::new( - *task.id(), - aggregation_job_id, - *report.0.metadata().id(), - *report.0.metadata().time(), - 0, - None, - ReportAggregationState::Waiting(*prep_state, None), - )) - .await - .unwrap(); - Ok(()) }) }) @@ -475,9 +521,9 @@ mod tests { let first_continue_request = AggregationJobContinueReq::new( AggregationJobRound::from(1), - Vec::from([PrepareStep::new( - *report.0.metadata().id(), - PrepareStepResult::Continued(report.1.prepare_messages[0].get_encoded()), + Vec::from([PrepareContinue::new( + *prepare_init.report_share().metadata().id(), + transcript.leader_prepare_transitions[1].message.clone(), )]), ); @@ -494,7 +540,7 @@ mod tests { AggregationJobContinueTestCase { task, datastore, - report_generator, + prepare_init_generator, aggregation_job_id, first_continue_request, first_continue_response: None, @@ -503,10 +549,10 @@ mod tests { } } - /// Set up a helper with an aggregation job in round 1 + /// Set up a helper with an aggregation job in round 1. 
#[allow(clippy::unit_arg)] - async fn setup_aggregation_job_continue_round_recovery_test() -> AggregationJobContinueTestCase - { + async fn setup_aggregation_job_continue_round_recovery_test( + ) -> AggregationJobContinueTestCase> { let mut test_case = setup_aggregation_job_continue_test().await; let first_continue_response = post_aggregation_job_and_decode( @@ -525,7 +571,7 @@ mod tests { .first_continue_request .prepare_steps() .iter() - .map(|step| PrepareStep::new(*step.report_id(), PrepareStepResult::Finished)) + .map(|step| PrepareResp::new(*step.report_id(), PrepareStepResult::Finished)) .collect() ) ); @@ -581,23 +627,25 @@ mod tests { async fn aggregation_job_continue_round_recovery_mutate_continue_request() { let test_case = setup_aggregation_job_continue_round_recovery_test().await; - let unrelated_report = test_case.report_generator.next(); + let (unrelated_prepare_init, unrelated_transcript) = test_case + .prepare_init_generator + .next(&IdpfInput::from_bools(&[false])); let (before_aggregation_job, before_report_aggregations) = test_case .datastore .run_tx(|tx| { - let (task_id, unrelated_report, aggregation_job_id) = ( + let (task_id, unrelated_prepare_init, aggregation_job_id) = ( *test_case.task.id(), - unrelated_report.clone(), + unrelated_prepare_init.clone(), test_case.aggregation_job_id, ); Box::pin(async move { - tx.put_report_share(&task_id, &unrelated_report.0) + tx.put_report_share(&task_id, unrelated_prepare_init.report_share()) .await .unwrap(); let aggregation_job = tx - .get_aggregation_job::<0, TimeInterval, dummy_vdaf::Vdaf>( + .get_aggregation_job::>( &task_id, &aggregation_job_id, ) @@ -605,8 +653,8 @@ mod tests { .unwrap(); let report_aggregations = tx - .get_report_aggregations_for_aggregation_job::<0, dummy_vdaf::Vdaf>( - &dummy_vdaf::Vdaf::new(), + .get_report_aggregations_for_aggregation_job::>( + &Poplar1::new_shake128(1), &Role::Helper, &task_id, &aggregation_job_id, @@ -624,9 +672,11 @@ mod tests { // ID. let modified_request = AggregationJobContinueReq::new( test_case.first_continue_request.round(), - Vec::from([PrepareStep::new( - *unrelated_report.0.metadata().id(), - PrepareStepResult::Continued(unrelated_report.1.prepare_messages[0].get_encoded()), + Vec::from([PrepareContinue::new( + *unrelated_prepare_init.report_share().metadata().id(), + unrelated_transcript.leader_prepare_transitions[1] + .message + .clone(), )]), ); @@ -647,7 +697,7 @@ mod tests { (*test_case.task.id(), test_case.aggregation_job_id); Box::pin(async move { let aggregation_job = tx - .get_aggregation_job::<0, TimeInterval, dummy_vdaf::Vdaf>( + .get_aggregation_job::>( &task_id, &aggregation_job_id, ) @@ -655,8 +705,8 @@ mod tests { .unwrap(); let report_aggregations = tx - .get_report_aggregations_for_aggregation_job::<0, dummy_vdaf::Vdaf>( - &dummy_vdaf::Vdaf::new(), + .get_report_aggregations_for_aggregation_job::>( + &Poplar1::new_shake128(1), &Role::Helper, &task_id, &aggregation_job_id, @@ -689,7 +739,7 @@ mod tests { // round mismatch error instead of tripping the check for a request to continue // to round 0. 
let aggregation_job = tx - .get_aggregation_job::<0, TimeInterval, dummy_vdaf::Vdaf>( + .get_aggregation_job::>( &task_id, &aggregation_job_id, ) diff --git a/aggregator/src/aggregator/aggregation_job_creator.rs b/aggregator/src/aggregator/aggregation_job_creator.rs index fe0fbb2db..28ecc5a2b 100644 --- a/aggregator/src/aggregator/aggregation_job_creator.rs +++ b/aggregator/src/aggregator/aggregation_job_creator.rs @@ -673,7 +673,7 @@ impl AggregationJobCreator { #[cfg(test)] mod tests { use super::AggregationJobCreator; - use futures::{future::try_join_all, TryFutureExt}; + use futures::future::try_join_all; use janus_aggregator_core::{ datastore::{ models::{AggregationJob, AggregationJobState, Batch, BatchState, LeaderStoredReport}, @@ -690,6 +690,7 @@ mod tests { time::{Clock, DurationExt, IntervalExt, MockClock, TimeExt}, }; use janus_messages::{ + codec::ParameterizedDecode, query_type::{FixedSize, TimeInterval}, AggregationJobRound, Interval, ReportId, Role, TaskId, Time, }; @@ -775,13 +776,14 @@ mod tests { .run_tx(|tx| { let (leader_task, helper_task) = (leader_task.clone(), helper_task.clone()); Box::pin(async move { + let vdaf = Prio3Count::new_count(2).unwrap(); let (leader_aggregations, leader_batches) = read_aggregate_info_for_task::< VERIFY_KEY_LENGTH, TimeInterval, Prio3Count, _, - >(tx, leader_task.id()) + >(tx, &vdaf, leader_task.id()) .await; let (helper_aggregations, helper_batches) = read_aggregate_info_for_task::< @@ -789,7 +791,7 @@ mod tests { TimeInterval, Prio3Count, _, - >(tx, helper_task.id()) + >(tx, &vdaf, helper_task.id()) .await; Ok(( leader_aggregations, @@ -888,23 +890,24 @@ mod tests { .unwrap(); // Verify. - let (agg_jobs, batches) = - job_creator - .datastore - .run_tx(|tx| { - let task = task.clone(); - Box::pin(async move { - Ok(read_aggregate_info_for_task::< + let (agg_jobs, batches) = job_creator + .datastore + .run_tx(|tx| { + let task = task.clone(); + Box::pin(async move { + Ok( + read_aggregate_info_for_task::< VERIFY_KEY_LENGTH, TimeInterval, Prio3Count, _, - >(tx, task.id()) - .await) - }) + >(tx, &Prio3Count::new_count(2).unwrap(), task.id()) + .await, + ) }) - .await - .unwrap(); + }) + .await + .unwrap(); let mut seen_report_ids = HashSet::new(); for (agg_job, report_ids) in &agg_jobs { // Jobs are created in round 0 @@ -986,23 +989,24 @@ mod tests { .unwrap(); // Verify -- we haven't received enough reports yet, so we don't create anything. - let (agg_jobs, batches) = - job_creator - .datastore - .run_tx(|tx| { - let task = Arc::clone(&task); - Box::pin(async move { - Ok(read_aggregate_info_for_task::< + let (agg_jobs, batches) = job_creator + .datastore + .run_tx(|tx| { + let task = Arc::clone(&task); + Box::pin(async move { + Ok( + read_aggregate_info_for_task::< VERIFY_KEY_LENGTH, TimeInterval, Prio3Count, _, - >(tx, task.id()) - .await) - }) + >(tx, &Prio3Count::new_count(2).unwrap(), task.id()) + .await, + ) }) - .await - .unwrap(); + }) + .await + .unwrap(); assert!(agg_jobs.is_empty()); assert!(batches.is_empty()); @@ -1026,23 +1030,24 @@ mod tests { .unwrap(); // Verify -- the additional report we wrote allows an aggregation job to be created. 
- let (agg_jobs, batches) = - job_creator - .datastore - .run_tx(|tx| { - let task = Arc::clone(&task); - Box::pin(async move { - Ok(read_aggregate_info_for_task::< + let (agg_jobs, batches) = job_creator + .datastore + .run_tx(|tx| { + let task = Arc::clone(&task); + Box::pin(async move { + Ok( + read_aggregate_info_for_task::< VERIFY_KEY_LENGTH, TimeInterval, Prio3Count, _, - >(tx, task.id()) - .await) - }) + >(tx, &Prio3Count::new_count(2).unwrap(), task.id()) + .await, + ) }) - .await - .unwrap(); + }) + .await + .unwrap(); assert_eq!(agg_jobs.len(), 1); let report_ids: HashSet<_> = agg_jobs.into_iter().next().unwrap().1.into_iter().collect(); assert_eq!( @@ -1135,23 +1140,24 @@ mod tests { .unwrap(); // Verify. - let (agg_jobs, batches) = - job_creator - .datastore - .run_tx(|tx| { - let task = task.clone(); - Box::pin(async move { - Ok(read_aggregate_info_for_task::< + let (agg_jobs, batches) = job_creator + .datastore + .run_tx(|tx| { + let task = task.clone(); + Box::pin(async move { + Ok( + read_aggregate_info_for_task::< VERIFY_KEY_LENGTH, TimeInterval, Prio3Count, _, - >(tx, task.id()) - .await) - }) + >(tx, &Prio3Count::new_count(2).unwrap(), task.id()) + .await, + ) }) - .await - .unwrap(); + }) + .await + .unwrap(); let mut seen_report_ids = HashSet::new(); for (agg_job, report_ids) in &agg_jobs { // Job immediately finished since all reports are in a closed batch. @@ -1271,7 +1277,9 @@ mod tests { FixedSize, Prio3Count, _, - >(tx, task.id()) + >( + tx, &Prio3Count::new_count(2).unwrap(), task.id() + ) .await, )) }) @@ -1432,7 +1440,9 @@ mod tests { FixedSize, Prio3Count, _, - >(tx, task.id()) + >( + tx, &Prio3Count::new_count(2).unwrap(), task.id() + ) .await, )) }) @@ -1548,7 +1558,9 @@ mod tests { FixedSize, Prio3Count, _, - >(tx, task.id()) + >( + tx, &Prio3Count::new_count(2).unwrap(), task.id() + ) .await, )) }) @@ -1605,7 +1617,9 @@ mod tests { FixedSize, Prio3Count, _, - >(tx, task.id()) + >( + tx, &Prio3Count::new_count(2).unwrap(), task.id() + ) .await, )) }) @@ -1729,7 +1743,9 @@ mod tests { FixedSize, Prio3Count, _, - >(tx, task.id()) + >( + tx, &Prio3Count::new_count(2).unwrap(), task.id() + ) .await, )) }) @@ -1793,7 +1809,9 @@ mod tests { FixedSize, Prio3Count, _, - >(tx, task.id()) + >( + tx, &Prio3Count::new_count(2).unwrap(), task.id() + ) .await, )) }) @@ -1932,7 +1950,9 @@ mod tests { FixedSize, Prio3Count, _, - >(tx, task.id()) + >( + tx, &Prio3Count::new_count(2).unwrap(), task.id() + ) .await, )) }) @@ -2060,18 +2080,20 @@ mod tests { /// Test helper function that reads all aggregation jobs & batches for a given task ID, /// returning the aggregation jobs, the report IDs included in the aggregation job, and the /// batches. Report IDs are returned in the order they are included in the aggregation job. 
- async fn read_aggregate_info_for_task< - const SEED_SIZE: usize, - Q: AccumulableQueryType, - A: vdaf::Aggregator, - C: Clock, - >( + async fn read_aggregate_info_for_task( tx: &Transaction<'_, C>, + vdaf: &A, task_id: &TaskId, ) -> ( Vec<(AggregationJob, Vec)>, Vec>, - ) { + ) + where + Q: AccumulableQueryType, + A: vdaf::Aggregator, + C: Clock, + for<'a> A::PrepareState: ParameterizedDecode<(&'a A, usize)>, + { try_join!( try_join_all( tx.get_aggregation_jobs_for_task(task_id) @@ -2081,18 +2103,18 @@ mod tests { .map(|agg_job| async { let agg_job_id = *agg_job.id(); tx.get_report_aggregations_for_aggregation_job( - &dummy_vdaf::Vdaf::new(), + vdaf, &Role::Leader, task_id, &agg_job_id, ) - .map_ok(move |report_aggs| { + .await + .map(|report_aggs| { ( agg_job, report_aggs.into_iter().map(|ra| *ra.report_id()).collect(), ) }) - .await }), ), tx.get_batches_for_task(task_id), diff --git a/aggregator/src/aggregator/aggregation_job_driver.rs b/aggregator/src/aggregator/aggregation_job_driver.rs index 93bdbc056..3034add63 100644 --- a/aggregator/src/aggregator/aggregation_job_driver.rs +++ b/aggregator/src/aggregator/aggregation_job_driver.rs @@ -21,8 +21,8 @@ use janus_core::{time::Clock, vdaf_dispatch}; use janus_messages::{ query_type::{FixedSize, TimeInterval}, AggregationJobContinueReq, AggregationJobInitializeReq, AggregationJobResp, - PartialBatchSelector, PrepareStep, PrepareStepResult, ReportId, ReportShare, ReportShareError, - Role, + PartialBatchSelector, PrepareContinue, PrepareError, PrepareInit, PrepareResp, + PrepareStepResult, ReportId, ReportShare, Role, }; use opentelemetry::{ metrics::{Counter, Histogram, Meter, Unit}, @@ -30,7 +30,8 @@ use opentelemetry::{ }; use prio::{ codec::{Decode, Encode, ParameterizedDecode}, - vdaf::{self, PrepareTransition}, + topology::ping_pong::{PingPongContinuedValue, PingPongState, PingPongTopology}, + vdaf, }; use reqwest::Method; use std::{ @@ -41,6 +42,8 @@ use std::{ use tokio::try_join; use tracing::{info, trace_span, warn}; +use super::error::handle_ping_pong_error; + #[derive(Derivative)] #[derivative(Debug)] pub struct AggregationJobDriver { @@ -234,7 +237,12 @@ impl AggregationJobDriver { for report_aggregation in &report_aggregations { match report_aggregation.state() { ReportAggregationState::Start => saw_start = true, - ReportAggregationState::Waiting(_, _) => saw_waiting = true, + ReportAggregationState::WaitingLeader(_) => saw_waiting = true, + ReportAggregationState::WaitingHelper(_) => { + return Err(anyhow!( + "Leader encountered unexpected ReportAggregationState::WaitingHelper" + )); + } ReportAggregationState::Finished => saw_finished = true, ReportAggregationState::Failed(_) => (), // ignore failed aggregations } @@ -316,7 +324,7 @@ impl AggregationJobDriver { // Compute report shares to send to helper, and decrypt our input shares & initialize // preparation state. let mut report_aggregations_to_write = Vec::new(); - let mut report_shares = Vec::new(); + let mut prepare_inits = Vec::new(); let mut stepped_aggregations = Vec::new(); for report_aggregation in report_aggregations { // Look up report. 
@@ -326,9 +334,10 @@ impl AggregationJobDriver { info!(report_id = %report_aggregation.report_id(), "Attempted to aggregate missing report (most likely garbage collected)"); self.aggregate_step_failure_counter .add(1, &[KeyValue::new("type", "missing_client_report")]); - report_aggregations_to_write.push(report_aggregation.with_state( - ReportAggregationState::Failed(ReportShareError::ReportDropped), - )); + report_aggregations_to_write.push( + report_aggregation + .with_state(ReportAggregationState::Failed(PrepareError::ReportDropped)), + ); continue; }; @@ -343,44 +352,52 @@ impl AggregationJobDriver { self.aggregate_step_failure_counter .add(1, &[KeyValue::new("type", "duplicate_extension")]); report_aggregations_to_write.push(report_aggregation.with_state( - ReportAggregationState::Failed(ReportShareError::UnrecognizedMessage), + ReportAggregationState::Failed(PrepareError::UnrecognizedMessage), )); continue; } // Initialize the leader's preparation state from the input share. - let prepare_init_res = trace_span!("VDAF preparation").in_scope(|| { - vdaf.prepare_init( + match trace_span!("VDAF preparation").in_scope(|| { + vdaf.leader_initialized( verify_key.as_bytes(), - Role::Leader.index().unwrap(), aggregation_job.aggregation_parameter(), + // DAP report ID is used as VDAF nonce report.metadata().id().as_ref(), report.public_share(), report.leader_input_share(), ) - }); - let (prep_state, prep_share) = match prepare_init_res { - Ok(prep_state_and_share) => prep_state_and_share, - Err(error) => { - info!(report_id = %report_aggregation.report_id(), ?error, "Couldn't initialize leader's preparation state"); - self.aggregate_step_failure_counter - .add(1, &[KeyValue::new("type", "prepare_init_failure")]); - report_aggregations_to_write.push(report_aggregation.with_state( - ReportAggregationState::Failed(ReportShareError::VdafPrepError), + .map_err(|ping_pong_error| { + handle_ping_pong_error( + task.id(), + Role::Leader, + report.metadata().id(), + ping_pong_error, + &self.aggregate_step_failure_counter, + ) + }) + }) { + Ok((ping_pong_state, ping_pong_message)) => { + prepare_inits.push(PrepareInit::new( + ReportShare::new( + report.metadata().clone(), + report.public_share().get_encoded(), + report.helper_encrypted_input_share().clone(), + ), + ping_pong_message, )); + stepped_aggregations.push(SteppedAggregation { + report_aggregation, + leader_state: ping_pong_state, + }); + } + Err(prep_error) => { + report_aggregations_to_write.push( + report_aggregation.with_state(ReportAggregationState::Failed(prep_error)), + ); continue; } - }; - - report_shares.push(ReportShare::new( - report.metadata().clone(), - report.public_share().get_encoded(), - report.helper_encrypted_input_share().clone(), - )); - stepped_aggregations.push(SteppedAggregation { - report_aggregation, - leader_transition: PrepareTransition::Continue(prep_state, prep_share), - }); + } } // Construct request, send it to the helper, and process the response. 
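For orientation, a self-contained sketch of the Leader-side initialization the code above performs per report, using the same 1-bit Poplar1 fixture as this change's tests. It assumes prio's `Client::shard` and `PingPongTopology::leader_initialized` have the shapes exercised above; HPKE, error handling, and report plumbing are omitted, and a fixed nonce stands in for the DAP report ID.

    use prio::codec::Encode;
    use prio::idpf::IdpfInput;
    use prio::topology::ping_pong::PingPongTopology;
    use prio::vdaf::poplar1::{Poplar1, Poplar1AggregationParam};
    use prio::vdaf::Client;

    fn main() {
        let vdaf = Poplar1::new_shake128(1);
        let verify_key = [0u8; 16];
        let nonce = [0u8; 16]; // in DAP, the report ID serves as the VDAF nonce
        let agg_param = Poplar1AggregationParam::try_from_prefixes(Vec::from([
            IdpfInput::from_bools(&[true]),
        ]))
        .unwrap();

        // Client sharding (done by the DAP client in practice).
        let (public_share, input_shares) =
            vdaf.shard(&IdpfInput::from_bools(&[true]), &nonce).unwrap();

        // One call yields the Leader's ping-pong state plus the PingPongMessage
        // that rides inside the PrepareInit sent to the Helper.
        let (_leader_state, message) = vdaf
            .leader_initialized(
                &verify_key,
                &agg_param,
                &nonce,
                &public_share,
                &input_shares[0],
            )
            .unwrap();
        println!("first leader message: {} bytes", message.get_encoded().len());
    }

Note how this replaces the old two-value `prepare_init` result: state and outgoing message come back together, so the driver can build the `PrepareInit` and the `SteppedAggregation` in one arm of the match.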
@@ -389,7 +406,7 @@ impl AggregationJobDriver { let req = AggregationJobInitializeReq::::new( aggregation_job.aggregation_parameter().get_encoded(), PartialBatchSelector::new(aggregation_job.partial_batch_identifier().clone()), - report_shares, + prepare_inits, ); let resp_bytes = send_request_to_helper( @@ -411,9 +428,9 @@ impl AggregationJobDriver { lease, task, aggregation_job, - stepped_aggregations, + &stepped_aggregations, report_aggregations_to_write, - resp.prepare_steps(), + resp.prepare_resps(), ) .await } @@ -444,60 +461,43 @@ impl AggregationJobDriver { // Visit the report aggregations, ignoring any that have already failed; compute our own // next step & transitions to send to the helper. let mut report_aggregations_to_write = Vec::new(); - let mut prepare_steps = Vec::new(); + let mut prepare_continues = Vec::new(); let mut stepped_aggregations = Vec::new(); for report_aggregation in report_aggregations { - if let ReportAggregationState::Waiting(prep_state, prep_msg) = - report_aggregation.state() - { - let prep_msg = match prep_msg.as_ref() { - Some(prep_msg) => prep_msg, - None => { - // This error indicates programmer/system error (i.e. it cannot possibly be - // the fault of our co-aggregator). We still record this failure against a - // single report, rather than failing the entire request, to minimize impact - // if we ever encounter this bug. - info!(report_id = %report_aggregation.report_id(), "Report aggregation is missing prepare message"); - self.aggregate_step_failure_counter - .add(1, &[KeyValue::new("type", "missing_prepare_message")]); - report_aggregations_to_write.push(report_aggregation.with_state( - ReportAggregationState::Failed(ReportShareError::VdafPrepError), - )); - continue; - } - }; - - // Step our own state. - let prepare_step_res = trace_span!("VDAF preparation") - .in_scope(|| vdaf.prepare_next(prep_state.clone(), prep_msg.clone())); - let leader_transition = match prepare_step_res { - Ok(leader_transition) => leader_transition, + if let ReportAggregationState::WaitingLeader(transition) = report_aggregation.state() { + let (prep_state, message) = match transition.evaluate(vdaf.as_ref()) { + Ok((state, message)) => (state, message), Err(error) => { - info!(report_id = %report_aggregation.report_id(), ?error, "Prepare step failed"); - self.aggregate_step_failure_counter - .add(1, &[KeyValue::new("type", "prepare_step_failure")]); - report_aggregations_to_write.push(report_aggregation.with_state( - ReportAggregationState::Failed(ReportShareError::VdafPrepError), - )); + let prepare_error = handle_ping_pong_error( + task.id(), + Role::Leader, + report_aggregation.report_id(), + error, + &self.aggregate_step_failure_counter, + ); + report_aggregations_to_write.push( + report_aggregation + .with_state(ReportAggregationState::Failed(prepare_error)), + ); continue; } }; - prepare_steps.push(PrepareStep::new( + prepare_continues.push(PrepareContinue::new( *report_aggregation.report_id(), - PrepareStepResult::Continued(prep_msg.get_encoded()), + message, )); stepped_aggregations.push(SteppedAggregation { - report_aggregation, - leader_transition, - }) + report_aggregation: report_aggregation.clone(), + leader_state: prep_state.clone(), + }); } } // Construct request, send it to the helper, and process the response. // TODO(#235): abandon work immediately on "terminal" failures from helper, or other // unexpected cases such as unknown/unexpected content type. 
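Once the Helper replies, the response is zipped against the stepped aggregations under a strict pairing rule (enforced in the next hunk): exactly one PrepareResp per stepped report, in the same order, with matching report IDs. An illustrative, self-contained version of that check with simplified types (report IDs as plain u128s):

    fn pair_prepare_resps<S, R>(
        stepped: Vec<(u128, S)>, // (report ID, leader state)
        resps: Vec<(u128, R)>,   // (report ID, helper result)
    ) -> Result<Vec<(S, R)>, &'static str> {
        if stepped.len() != resps.len() {
            return Err("missing, duplicate, out-of-order, or unexpected prepare steps in response");
        }
        stepped
            .into_iter()
            .zip(resps)
            .map(|((stepped_id, s), (resp_id, r))| {
                if stepped_id == resp_id {
                    Ok((s, r))
                } else {
                    Err("missing, duplicate, out-of-order, or unexpected prepare steps in response")
                }
            })
            .collect()
    }

Failing the whole request on any mismatch (rather than skipping the offending entry) is deliberate: a misaligned response means the aggregators disagree about the job's contents, and no per-report recovery is safe.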
- let req = AggregationJobContinueReq::new(aggregation_job.round(), prepare_steps); + let req = AggregationJobContinueReq::new(aggregation_job.round(), prepare_continues); let resp_bytes = send_request_to_helper( &self.http_client, @@ -518,9 +518,9 @@ impl AggregationJobDriver { lease, task, aggregation_job, - stepped_aggregations, + &stepped_aggregations, report_aggregations_to_write, - resp.prepare_steps(), + resp.prepare_resps(), ) .await } @@ -538,9 +538,9 @@ impl AggregationJobDriver { lease: Arc>, task: Arc, aggregation_job: AggregationJob, - stepped_aggregations: Vec>, + stepped_aggregations: &[SteppedAggregation], mut report_aggregations_to_write: Vec>, - helper_prep_steps: &[PrepareStep], + helper_prep_resps: &[PrepareResp], ) -> Result<()> where A: 'static, @@ -552,7 +552,7 @@ impl AggregationJobDriver { A::PrepareState: Send + Sync + Encode, { // Handle response, computing the new report aggregations to be stored. - if stepped_aggregations.len() != helper_prep_steps.len() { + if stepped_aggregations.len() != helper_prep_resps.len() { return Err(anyhow!( "missing, duplicate, out-of-order, or unexpected prepare steps in response" )); @@ -562,98 +562,126 @@ impl AggregationJobDriver { self.batch_aggregation_shard_count, aggregation_job.aggregation_parameter().clone(), ); - for (stepped_aggregation, helper_prep_step) in - stepped_aggregations.into_iter().zip(helper_prep_steps) + for (stepped_aggregation, helper_prep_resp) in + stepped_aggregations.iter().zip(helper_prep_resps) { - let (report_aggregation, leader_transition) = ( - stepped_aggregation.report_aggregation, - stepped_aggregation.leader_transition, - ); - if helper_prep_step.report_id() != report_aggregation.report_id() { + if helper_prep_resp.report_id() != stepped_aggregation.report_aggregation.report_id() { return Err(anyhow!( "missing, duplicate, out-of-order, or unexpected prepare steps in response" )); } - let new_state = match helper_prep_step.result() { - PrepareStepResult::Continued(payload) => { - // If the leader continued too, combine the leader's prepare share with the - // helper's to compute next round's prepare message. Prepare to store the - // leader's new state & the prepare message. If the leader didn't continue, - // transition to INVALID. 
- if let PrepareTransition::Continue(leader_prep_state, leader_prep_share) = - leader_transition - { - let leader_prep_state = leader_prep_state.clone(); - let helper_prep_share = - A::PrepareShare::get_decoded_with_param(&leader_prep_state, payload) - .context("couldn't decode helper's prepare message"); - let prep_msg = helper_prep_share.and_then(|helper_prep_share| { - vdaf.prepare_shares_to_prepare_message( - aggregation_job.aggregation_parameter(), - [leader_prep_share.clone(), helper_prep_share], - ) - .context( - "couldn't preprocess leader & helper prepare shares into \ - prepare message", + let new_state = match helper_prep_resp.result() { + PrepareStepResult::Continue { + message: helper_prep_msg, + } => { + let state_and_message = vdaf + .leader_continued( + stepped_aggregation.leader_state.clone(), + aggregation_job.aggregation_parameter(), + helper_prep_msg, + ) + .map_err(|ping_pong_error| { + handle_ping_pong_error( + task.id(), + Role::Leader, + stepped_aggregation.report_aggregation.report_id(), + ping_pong_error, + &self.aggregate_step_failure_counter, ) }); - match prep_msg { - Ok(prep_msg) => { - ReportAggregationState::Waiting(leader_prep_state, Some(prep_msg)) - } - Err(error) => { - info!(report_id = %report_aggregation.report_id(), ?error, "Couldn't compute prepare message"); + + match state_and_message { + Ok(PingPongContinuedValue::WithMessage { transition }) => { + // Leader did not finish. Store our state and outgoing message for the + // next round. + // n.b. it's possible we finished and recovered an output share at the + // VDAF level (i.e., state may be PingPongState::Finished) but we cannot + // finish at the DAP layer and commit the output share until we get + // confirmation from the Helper that they finished, too. + ReportAggregationState::WaitingLeader(transition) + } + Ok(PingPongContinuedValue::FinishedNoMessage { output_share }) => { + // We finished and have no outgoing message, meaning the Helper was + // already finished. Commit the output share. + if let Err(err) = accumulator.update( + aggregation_job.partial_batch_identifier(), + stepped_aggregation.report_aggregation.report_id(), + stepped_aggregation.report_aggregation.time(), + &output_share, + ) { + warn!( + report_id = %stepped_aggregation.report_aggregation.report_id(), + ?err, + "Could not update batch aggregation", + ); self.aggregate_step_failure_counter - .add(1, &[KeyValue::new("type", "prepare_message_failure")]); - ReportAggregationState::Failed(ReportShareError::VdafPrepError) + .add(1, &[KeyValue::new("type", "accumulate_failure")]); + ReportAggregationState::::Failed( + PrepareError::VdafPrepError, + ) + } else { + ReportAggregationState::Finished } } - } else { - warn!(report_id = %report_aggregation.report_id(), "Helper continued but leader did not"); - self.aggregate_step_failure_counter - .add(1, &[KeyValue::new("type", "continue_mismatch")]); - ReportAggregationState::Failed(ReportShareError::VdafPrepError) + Err(prepare_error) => ReportAggregationState::Failed(prepare_error), } } PrepareStepResult::Finished => { - // If the leader finished too, we are done; prepare to store the output share. - // If the leader didn't finish too, we transition to INVALID. - if let PrepareTransition::Finish(out_share) = leader_transition { - match accumulator.update( + if let PingPongState::Finished(output_share) = &stepped_aggregation.leader_state + { + // Helper finished and we had already finished. Commit the output share. 
+ if let Err(err) = accumulator.update( aggregation_job.partial_batch_identifier(), - report_aggregation.report_id(), - report_aggregation.time(), - &out_share, + stepped_aggregation.report_aggregation.report_id(), + stepped_aggregation.report_aggregation.time(), + output_share, ) { - Ok(_) => ReportAggregationState::Finished, - Err(error) => { - warn!(report_id = %report_aggregation.report_id(), ?error, "Could not update batch aggregation"); - self.aggregate_step_failure_counter - .add(1, &[KeyValue::new("type", "accumulate_failure")]); - ReportAggregationState::Failed(ReportShareError::VdafPrepError) - } + warn!( + report_id = %stepped_aggregation.report_aggregation.report_id(), + ?err, + "Could not update batch aggregation", + ); + self.aggregate_step_failure_counter + .add(1, &[KeyValue::new("type", "accumulate_failure")]); + ReportAggregationState::::Failed( + PrepareError::VdafPrepError, + ) + } else { + ReportAggregationState::Finished } } else { - warn!(report_id = %report_aggregation.report_id(), "Helper finished but leader did not"); + warn!( + report_id = %stepped_aggregation.report_aggregation.report_id(), + "Helper finished but Leader did not", + ); self.aggregate_step_failure_counter .add(1, &[KeyValue::new("type", "finish_mismatch")]); - ReportAggregationState::Failed(ReportShareError::VdafPrepError) + ReportAggregationState::Failed(PrepareError::VdafPrepError) } } - PrepareStepResult::Failed(err) => { + PrepareStepResult::Reject(err) => { // If the helper failed, we move to FAILED immediately. // TODO(#236): is it correct to just record the transition error that the helper reports? - info!(report_id = %report_aggregation.report_id(), helper_error = ?err, "Helper couldn't step report aggregation"); + info!( + report_id = %stepped_aggregation.report_aggregation.report_id(), + helper_error = ?err, + "Helper couldn't step report aggregation", + ); self.aggregate_step_failure_counter .add(1, &[KeyValue::new("type", "helper_step_failure")]); ReportAggregationState::Failed(*err) } }; - report_aggregations_to_write.push(report_aggregation.with_state(new_state)); + report_aggregations_to_write.push( + stepped_aggregation + .report_aggregation + .clone() + .with_state(new_state), + ); } // Write everything back to storage. @@ -664,6 +692,7 @@ impl AggregationJobDriver { report_aggregations_to_write, )?; let aggregation_job_writer = Arc::new(aggregation_job_writer); + let accumulator = Arc::new(accumulator); datastore .run_tx_with_name("step_aggregation_job_2", |tx| { @@ -742,6 +771,7 @@ impl AggregationJobDriver { A::AggregateShare: Send + Sync, A::AggregationParam: Send + Sync + PartialEq + Eq, A::PrepareMessage: Send + Sync, + A::OutputShare: Send + Sync, for<'a> A::PrepareState: Send + Sync + Encode + ParameterizedDecode<(&'a A, usize)>, { let vdaf = Arc::new(vdaf); @@ -864,7 +894,7 @@ impl AggregationJobDriver { /// transition representing the next step for the leader. 
struct SteppedAggregation> { report_aggregation: ReportAggregation, - leader_transition: PrepareTransition, + leader_state: PingPongState, } #[cfg(test)] @@ -903,20 +933,22 @@ mod tests { query_type::{FixedSize, TimeInterval}, AggregationJobContinueReq, AggregationJobInitializeReq, AggregationJobResp, AggregationJobRound, Duration, Extension, ExtensionType, FixedSizeQuery, HpkeConfig, - InputShareAad, Interval, PartialBatchSelector, PlaintextInputShare, PrepareStep, - PrepareStepResult, Query, ReportIdChecksum, ReportMetadata, ReportShare, ReportShareError, - Role, TaskId, Time, + InputShareAad, Interval, PartialBatchSelector, PlaintextInputShare, PrepareContinue, + PrepareError, PrepareInit, PrepareResp, PrepareStepResult, Query, ReportIdChecksum, + ReportMetadata, ReportShare, Role, TaskId, Time, }; use prio::{ codec::Encode, + idpf::IdpfInput, vdaf::{ self, + poplar1::{Poplar1, Poplar1AggregationParam}, prio3::{Prio3, Prio3Count}, + xof::XofShake128, Aggregator, }, }; use rand::random; - use reqwest::Url; use std::{borrow::Borrow, str, sync::Arc, time::Duration as StdDuration}; use trillium_tokio::Stopper; @@ -935,13 +967,13 @@ mod tests { let mut runtime_manager = TestRuntimeManager::new(); let ephemeral_datastore = ephemeral_datastore().await; let ds = Arc::new(ephemeral_datastore.datastore(clock.clone()).await); - let vdaf = Arc::new(Prio3::new_count(2).unwrap()); + let vdaf = Arc::new(Poplar1::new_shake128(1)); let task = TaskBuilder::new( QueryType::TimeInterval, - VdafInstance::Prio3Count, + VdafInstance::Poplar1 { bits: 1 }, Role::Leader, ) - .with_helper_aggregator_endpoint(Url::parse(&server.url()).unwrap()) + .with_helper_aggregator_endpoint(server.url().parse().unwrap()) .build(); let time = clock @@ -951,31 +983,41 @@ mod tests { let batch_identifier = TimeInterval::to_batch_identifier(&task, &(), &time).unwrap(); let report_metadata = ReportMetadata::new(random(), time); let verify_key: VerifyKey = task.primary_vdaf_verify_key().unwrap(); + let measurement = IdpfInput::from_bools(&[true]); + let aggregation_param = + Poplar1AggregationParam::try_from_prefixes(Vec::from([IdpfInput::from_bools(&[true])])) + .unwrap(); let transcript = run_vdaf( vdaf.as_ref(), verify_key.as_bytes(), - &(), + &aggregation_param, report_metadata.id(), - &0, + &measurement, ); let agg_auth_token = task.primary_aggregator_auth_token().clone(); let helper_hpke_keypair = generate_test_hpke_config_and_private_key(); - let report = generate_report::( + let report = generate_report::>( *task.id(), report_metadata, helper_hpke_keypair.config(), transcript.public_share.clone(), Vec::new(), - transcript.input_shares.clone(), + &transcript.leader_input_share, + &transcript.helper_input_share, ); let aggregation_job_id = random(); let collection_job = ds .run_tx(|tx| { - let (vdaf, task, report) = (vdaf.clone(), task.clone(), report.clone()); + let (vdaf, task, report, aggregation_param) = ( + vdaf.clone(), + task.clone(), + report.clone(), + aggregation_param.clone(), + ); Box::pin(async move { tx.put_task(&task).await?; tx.put_client_report(vdaf.borrow(), &report).await?; @@ -985,11 +1027,11 @@ mod tests { tx.put_aggregation_job(&AggregationJob::< VERIFY_KEY_LENGTH, TimeInterval, - Prio3Count, + Poplar1, >::new( *task.id(), aggregation_job_id, - (), + aggregation_param.clone(), (), Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) .unwrap(), @@ -997,38 +1039,46 @@ mod tests { AggregationJobRound::from(0), )) .await?; - tx.put_report_aggregation( - &ReportAggregation::::new( - 
*task.id(), - aggregation_job_id, - *report.metadata().id(), - *report.metadata().time(), - 0, - None, - ReportAggregationState::Start, - ), - ) + tx.put_report_aggregation(&ReportAggregation::< + VERIFY_KEY_LENGTH, + Poplar1, + >::new( + *task.id(), + aggregation_job_id, + *report.metadata().id(), + *report.metadata().time(), + 0, + None, + ReportAggregationState::Start, + )) .await?; - tx.put_batch(&Batch::::new( + tx.put_batch(&Batch::< + VERIFY_KEY_LENGTH, + TimeInterval, + Poplar1, + >::new( *task.id(), batch_identifier, - (), + aggregation_param.clone(), BatchState::Closing, 1, Interval::from_time(&time).unwrap(), )) .await?; - let collection_job = - CollectionJob::::new( - *task.id(), - random(), - Query::new_time_interval(batch_identifier), - (), - batch_identifier, - CollectionJobState::Start, - ); + let collection_job = CollectionJob::< + VERIFY_KEY_LENGTH, + TimeInterval, + Poplar1, + >::new( + *task.id(), + random(), + Query::new_time_interval(batch_identifier), + aggregation_param, + batch_identifier, + CollectionJobState::Start, + ); tx.put_collection_job(&collection_job).await?; Ok(collection_job) @@ -1038,15 +1088,16 @@ mod tests { .unwrap(); // Setup: prepare mocked HTTP responses. - let (_, helper_vdaf_msg) = transcript.helper_prep_state(0); let helper_responses = Vec::from([ ( "PUT", AggregationJobInitializeReq::::MEDIA_TYPE, AggregationJobResp::MEDIA_TYPE, - AggregationJobResp::new(Vec::from([PrepareStep::new( + AggregationJobResp::new(Vec::from([PrepareResp::new( *report.metadata().id(), - PrepareStepResult::Continued(helper_vdaf_msg.get_encoded()), + PrepareStepResult::Continue { + message: transcript.helper_prepare_transitions[0].message.clone(), + }, )])) .get_encoded(), ), @@ -1054,14 +1105,14 @@ mod tests { "POST", AggregationJobContinueReq::MEDIA_TYPE, AggregationJobResp::MEDIA_TYPE, - AggregationJobResp::new(Vec::from([PrepareStep::new( + AggregationJobResp::new(Vec::from([PrepareResp::new( *report.metadata().id(), PrepareStepResult::Finished, )])) .get_encoded(), ), ]); - let mocked_aggregates = join_all(helper_responses.into_iter().map( + let mocked_aggregates = join_all(helper_responses.iter().map( |(req_method, req_content_type, resp_content_type, resp_body)| { server .mock( @@ -1074,7 +1125,7 @@ mod tests { "DAP-Auth-Token", str::from_utf8(agg_auth_token.as_ref()).unwrap(), ) - .match_header(CONTENT_TYPE.as_str(), req_content_type) + .match_header(CONTENT_TYPE.as_str(), *req_content_type) .with_status(200) .with_header(CONTENT_TYPE.as_str(), resp_content_type) .with_body(resp_body) @@ -1127,29 +1178,30 @@ mod tests { } let want_aggregation_job = - AggregationJob::::new( + AggregationJob::>::new( *task.id(), aggregation_job_id, - (), + aggregation_param.clone(), (), Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) .unwrap(), AggregationJobState::Finished, AggregationJobRound::from(2), ); - let want_report_aggregation = ReportAggregation::::new( - *task.id(), - aggregation_job_id, - *report.metadata().id(), - *report.metadata().time(), - 0, - None, - ReportAggregationState::Finished, - ); - let want_batch = Batch::::new( + let want_report_aggregation = + ReportAggregation::>::new( + *task.id(), + aggregation_job_id, + *report.metadata().id(), + *report.metadata().time(), + 0, + None, + ReportAggregationState::Finished, + ); + let want_batch = Batch::>::new( *task.id(), batch_identifier, - (), + aggregation_param.clone(), BatchState::Closed, 0, Interval::from_time(&time).unwrap(), @@ -1165,7 +1217,7 @@ mod tests { Box::pin(async move 
{ let aggregation_job = tx - .get_aggregation_job::( + .get_aggregation_job::>( task.id(), &aggregation_job_id, ) @@ -1177,12 +1229,17 @@ mod tests { &Role::Leader, task.id(), &aggregation_job_id, + aggregation_job.aggregation_parameter(), &report_id, ) .await? .unwrap(); let batch = tx - .get_batch(task.id(), &batch_identifier, &()) + .get_batch( + task.id(), + &batch_identifier, + aggregation_job.aggregation_parameter(), + ) .await? .unwrap(); let collection_job = tx @@ -1202,7 +1259,7 @@ mod tests { } #[tokio::test] - async fn step_time_interval_aggregation_job_init() { + async fn step_time_interval_aggregation_job_init_single_round() { // Setup: insert a client report and add it to a new aggregation job. install_test_trace_subscriber(); let mut server = mockito::Server::new_async().await; @@ -1216,7 +1273,7 @@ mod tests { VdafInstance::Prio3Count, Role::Leader, ) - .with_helper_aggregator_endpoint(Url::parse(&server.url()).unwrap()) + .with_helper_aggregator_endpoint(server.url().parse().unwrap()) .build(); let time = clock @@ -1243,7 +1300,8 @@ mod tests { helper_hpke_keypair.config(), transcript.public_share.clone(), Vec::new(), - transcript.input_shares.clone(), + &transcript.leader_input_share, + &transcript.helper_input_share, ); let repeated_extension_report = generate_report::( *task.id(), @@ -1254,7 +1312,8 @@ mod tests { Extension::new(ExtensionType::Tbd, Vec::new()), Extension::new(ExtensionType::Tbd, Vec::new()), ]), - transcript.input_shares.clone(), + &transcript.leader_input_share, + &transcript.helper_input_share, ); let missing_report_id = random(); let aggregation_job_id = random(); @@ -1353,16 +1412,20 @@ mod tests { let leader_request = AggregationJobInitializeReq::new( ().get_encoded(), PartialBatchSelector::new_time_interval(), - Vec::from([ReportShare::new( - report.metadata().clone(), - report.public_share().get_encoded(), - report.helper_encrypted_input_share().clone(), + Vec::from([PrepareInit::new( + ReportShare::new( + report.metadata().clone(), + report.public_share().get_encoded(), + report.helper_encrypted_input_share().clone(), + ), + transcript.leader_prepare_transitions[0].message.clone(), )]), ); - let (_, helper_vdaf_msg) = transcript.helper_prep_state(0); - let helper_response = AggregationJobResp::new(Vec::from([PrepareStep::new( + let helper_response = AggregationJobResp::new(Vec::from([PrepareResp::new( *report.metadata().id(), - PrepareStepResult::Continued(helper_vdaf_msg.get_encoded()), + PrepareStepResult::Continue { + message: transcript.helper_prepare_transitions[0].message.clone(), + }, )])); let mocked_aggregate_failure = server .mock( @@ -1432,11 +1495,9 @@ mod tests { (), Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) .unwrap(), - AggregationJobState::InProgress, + AggregationJobState::Finished, AggregationJobRound::from(1), ); - let leader_prep_state = transcript.leader_prep_state(0).clone(); - let prep_msg = transcript.prepare_messages[0].clone(); let want_report_aggregation = ReportAggregation::::new( *task.id(), aggregation_job_id, @@ -1444,7 +1505,7 @@ mod tests { *report.metadata().time(), 0, None, - ReportAggregationState::Waiting(leader_prep_state, Some(prep_msg)), + ReportAggregationState::Finished, ); let want_repeated_extension_report_aggregation = ReportAggregation::::new( @@ -1454,7 +1515,7 @@ mod tests { *repeated_extension_report.metadata().time(), 1, None, - ReportAggregationState::Failed(ReportShareError::UnrecognizedMessage), + 
ReportAggregationState::Failed(PrepareError::UnrecognizedMessage), ); let want_missing_report_report_aggregation = ReportAggregation::::new( @@ -1464,14 +1525,14 @@ mod tests { time, 2, None, - ReportAggregationState::Failed(ReportShareError::ReportDropped), + ReportAggregationState::Failed(PrepareError::ReportDropped), ); let want_batch = Batch::::new( *task.id(), batch_identifier, (), BatchState::Closing, - 1, + 0, Interval::from_time(&time).unwrap(), ); @@ -1503,6 +1564,7 @@ mod tests { &Role::Leader, task.id(), &aggregation_job_id, + aggregation_job.aggregation_parameter(), &report_id, ) .await? @@ -1513,6 +1575,7 @@ mod tests { &Role::Leader, task.id(), &aggregation_job_id, + aggregation_job.aggregation_parameter(), &repeated_extension_report_id, ) .await? @@ -1523,6 +1586,7 @@ mod tests { &Role::Leader, task.id(), &aggregation_job_id, + aggregation_job.aggregation_parameter(), &missing_report_id, ) .await? @@ -1557,98 +1621,109 @@ mod tests { } #[tokio::test] - async fn step_fixed_size_aggregation_job_init() { + async fn step_time_interval_aggregation_job_init_two_rounds() { // Setup: insert a client report and add it to a new aggregation job. install_test_trace_subscriber(); let mut server = mockito::Server::new_async().await; let clock = MockClock::default(); let ephemeral_datastore = ephemeral_datastore().await; let ds = Arc::new(ephemeral_datastore.datastore(clock.clone()).await); - let vdaf = Arc::new(Prio3::new_count(2).unwrap()); + let vdaf = Arc::new(Poplar1::new_shake128(1)); let task = TaskBuilder::new( - QueryType::FixedSize { - max_batch_size: 10, - batch_time_window_size: None, - }, - VdafInstance::Prio3Count, + QueryType::TimeInterval, + VdafInstance::Poplar1 { bits: 1 }, Role::Leader, ) - .with_helper_aggregator_endpoint(Url::parse(&server.url()).unwrap()) + .with_helper_aggregator_endpoint(server.url().parse().unwrap()) .build(); - let report_metadata = ReportMetadata::new( - random(), - clock - .now() - .to_batch_interval_start(task.time_precision()) - .unwrap(), - ); + let time = clock + .now() + .to_batch_interval_start(task.time_precision()) + .unwrap(); + let batch_identifier = TimeInterval::to_batch_identifier(&task, &(), &time).unwrap(); + let report_metadata = ReportMetadata::new(random(), time); let verify_key: VerifyKey = task.primary_vdaf_verify_key().unwrap(); + let measurement = IdpfInput::from_bools(&[true]); + let aggregation_param = + Poplar1AggregationParam::try_from_prefixes(Vec::from([IdpfInput::from_bools(&[true])])) + .unwrap(); let transcript = run_vdaf( vdaf.as_ref(), verify_key.as_bytes(), - &(), + &aggregation_param, report_metadata.id(), - &0, + &measurement, ); let agg_auth_token = task.primary_aggregator_auth_token(); let helper_hpke_keypair = generate_test_hpke_config_and_private_key(); - let report = generate_report::( + let report = generate_report::>( *task.id(), report_metadata, helper_hpke_keypair.config(), transcript.public_share.clone(), Vec::new(), - transcript.input_shares.clone(), + &transcript.leader_input_share, + &transcript.helper_input_share, ); - let batch_id = random(); let aggregation_job_id = random(); let lease = ds .run_tx(|tx| { - let (vdaf, task, report) = (vdaf.clone(), task.clone(), report.clone()); + let (vdaf, task, report, aggregation_param) = ( + vdaf.clone(), + task.clone(), + report.clone(), + aggregation_param.clone(), + ); Box::pin(async move { tx.put_task(&task).await?; tx.put_client_report(vdaf.borrow(), &report).await?; tx.put_aggregation_job(&AggregationJob::< VERIFY_KEY_LENGTH, - FixedSize, - 
Prio3Count, + TimeInterval, + Poplar1, >::new( *task.id(), aggregation_job_id, + aggregation_param.clone(), (), - batch_id, Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) .unwrap(), AggregationJobState::InProgress, AggregationJobRound::from(0), )) .await?; - tx.put_report_aggregation( - &ReportAggregation::::new( - *task.id(), - aggregation_job_id, - *report.metadata().id(), - *report.metadata().time(), - 0, - None, - ReportAggregationState::Start, - ), - ) + + tx.put_report_aggregation(&ReportAggregation::< + VERIFY_KEY_LENGTH, + Poplar1, + >::new( + *task.id(), + aggregation_job_id, + *report.metadata().id(), + *report.metadata().time(), + 0, + None, + ReportAggregationState::Start, + )) .await?; - tx.put_batch(&Batch::::new( + tx.put_batch(&Batch::< + VERIFY_KEY_LENGTH, + TimeInterval, + Poplar1, + >::new( *task.id(), - batch_id, - (), - BatchState::Open, + batch_identifier, + aggregation_param, + BatchState::Closing, 1, - Interval::from_time(report.metadata().time()).unwrap(), + Interval::from_time(&time).unwrap(), )) .await?; @@ -1668,31 +1743,23 @@ mod tests { // It would be nicer to retrieve the request bytes from the mock, then do our own parsing & // verification -- but mockito does not expose this functionality at time of writing.) let leader_request = AggregationJobInitializeReq::new( - ().get_encoded(), - PartialBatchSelector::new_fixed_size(batch_id), - Vec::from([ReportShare::new( - report.metadata().clone(), - report.public_share().get_encoded(), - report.helper_encrypted_input_share().clone(), + aggregation_param.get_encoded(), + PartialBatchSelector::new_time_interval(), + Vec::from([PrepareInit::new( + ReportShare::new( + report.metadata().clone(), + report.public_share().get_encoded(), + report.helper_encrypted_input_share().clone(), + ), + transcript.leader_prepare_transitions[0].message.clone(), )]), ); - let (_, helper_vdaf_msg) = transcript.helper_prep_state(0); - let helper_response = AggregationJobResp::new(Vec::from([PrepareStep::new( + let helper_response = AggregationJobResp::new(Vec::from([PrepareResp::new( *report.metadata().id(), - PrepareStepResult::Continued(helper_vdaf_msg.get_encoded()), + PrepareStepResult::Continue { + message: transcript.helper_prepare_transitions[0].message.clone(), + }, )])); - let mocked_aggregate_failure = server - .mock( - "PUT", - task.aggregation_job_uri(&aggregation_job_id) - .unwrap() - .path(), - ) - .with_status(500) - .with_header("Content-Type", "application/problem+json") - .with_body("{\"type\": \"urn:ietf:params:ppm:dap:error:unauthorizedRequest\"}") - .create_async() - .await; let mocked_aggregate_success = server .mock( "PUT", @@ -1706,7 +1773,7 @@ mod tests { ) .match_header( CONTENT_TYPE.as_str(), - AggregationJobInitializeReq::::MEDIA_TYPE, + AggregationJobInitializeReq::::MEDIA_TYPE, ) .match_body(leader_request.get_encoded()) .with_status(200) @@ -1721,51 +1788,555 @@ mod tests { &noop_meter(), 32, ); - let error = aggregation_job_driver - .step_aggregation_job(ds.clone(), Arc::new(lease.clone())) - .await - .unwrap_err(); - assert_matches!( - error.downcast().unwrap(), - Error::Http { problem_details, dap_problem_type } => { - assert_eq!(problem_details.status.unwrap(), StatusCode::INTERNAL_SERVER_ERROR); - assert_eq!(dap_problem_type, Some(DapProblemType::UnauthorizedRequest)); - } - ); aggregation_job_driver .step_aggregation_job(ds.clone(), Arc::new(lease)) .await .unwrap(); // Verify. 
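The expectations in this two-round test make the new state space concrete: after the initialize step the report aggregation is persisted as WaitingLeader holding the transition taken from the transcript, and only the follow-up continue drives it to Finished. As a hypothetical miniature of that state space (stand-in type parameters, not Janus's real `ReportAggregationState`):

    enum ReportAggregationStateSketch<Transition, HelperState, Error> {
        // No preparation performed yet.
        Start,
        // Leader waiting between DAP steps: a compact transition (prepare state
        // plus prepare message) is stored, and evaluating it recomputes the next
        // prepare state and outgoing message, so no prepare share is persisted.
        WaitingLeader(Transition),
        // Helper-only waiting state; a Leader encountering it is a logic error
        // (see the match in step_aggregation_job above).
        WaitingHelper(HelperState),
        // Both aggregators finished and the output share was committed.
        Finished,
        // Preparation failed with a PrepareError-like value.
        Failed(Error),
    }

Splitting the old `Waiting` variant by role is what lets each side store only what it needs: the Leader a recomputable transition, the Helper its own state plus the peer's message.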
- mocked_aggregate_failure.assert_async().await; mocked_aggregate_success.assert_async().await; - let want_aggregation_job = AggregationJob::::new( + let want_aggregation_job = + AggregationJob::>::new( + *task.id(), + aggregation_job_id, + aggregation_param.clone(), + (), + Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) + .unwrap(), + AggregationJobState::InProgress, + AggregationJobRound::from(1), + ); + let want_report_aggregation = + ReportAggregation::>::new( + *task.id(), + aggregation_job_id, + *report.metadata().id(), + *report.metadata().time(), + 0, + None, + ReportAggregationState::WaitingLeader( + transcript.leader_prepare_transitions[1] + .transition + .clone() + .unwrap(), + ), + ); + let want_batch = Batch::>::new( *task.id(), - aggregation_job_id, - (), - batch_id, - Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)).unwrap(), - AggregationJobState::InProgress, - AggregationJobRound::from(1), + batch_identifier, + aggregation_param, + BatchState::Closing, + 1, + Interval::from_time(&time).unwrap(), ); - let want_report_aggregation = ReportAggregation::::new( - *task.id(), - aggregation_job_id, + + let (got_aggregation_job, got_report_aggregation, got_batch) = ds + .run_tx(|tx| { + let (vdaf, task, report_id) = + (Arc::clone(&vdaf), task.clone(), *report.metadata().id()); + Box::pin(async move { + let aggregation_job = tx + .get_aggregation_job::>( + task.id(), + &aggregation_job_id, + ) + .await? + .unwrap(); + let report_aggregation = tx + .get_report_aggregation( + vdaf.as_ref(), + &Role::Leader, + task.id(), + &aggregation_job_id, + aggregation_job.aggregation_parameter(), + &report_id, + ) + .await? + .unwrap(); + let batch = tx + .get_batch( + task.id(), + &batch_identifier, + aggregation_job.aggregation_parameter(), + ) + .await? + .unwrap(); + Ok((aggregation_job, report_aggregation, batch)) + }) + }) + .await + .unwrap(); + + assert_eq!(want_aggregation_job, got_aggregation_job); + assert_eq!(want_report_aggregation, got_report_aggregation); + assert_eq!(want_batch, got_batch); + } + + #[tokio::test] + async fn step_fixed_size_aggregation_job_init_single_round() { + // Setup: insert a client report and add it to a new aggregation job. 
+ install_test_trace_subscriber(); + let mut server = mockito::Server::new_async().await; + let clock = MockClock::default(); + let ephemeral_datastore = ephemeral_datastore().await; + let ds = Arc::new(ephemeral_datastore.datastore(clock.clone()).await); + let vdaf = Arc::new(Prio3::new_count(2).unwrap()); + + let task = TaskBuilder::new( + QueryType::FixedSize { + max_batch_size: 10, + batch_time_window_size: None, + }, + VdafInstance::Prio3Count, + Role::Leader, + ) + .with_helper_aggregator_endpoint(server.url().parse().unwrap()) + .build(); + + let report_metadata = ReportMetadata::new( + random(), + clock + .now() + .to_batch_interval_start(task.time_precision()) + .unwrap(), + ); + let verify_key: VerifyKey = task.primary_vdaf_verify_key().unwrap(); + + let transcript = run_vdaf( + vdaf.as_ref(), + verify_key.as_bytes(), + &(), + report_metadata.id(), + &0, + ); + + let agg_auth_token = task.primary_aggregator_auth_token(); + let helper_hpke_keypair = generate_test_hpke_config_and_private_key(); + let report = generate_report::( + *task.id(), + report_metadata, + helper_hpke_keypair.config(), + transcript.public_share.clone(), + Vec::new(), + &transcript.leader_input_share, + &transcript.helper_input_share, + ); + let batch_id = random(); + let aggregation_job_id = random(); + + let lease = ds + .run_tx(|tx| { + let (vdaf, task, report) = (vdaf.clone(), task.clone(), report.clone()); + Box::pin(async move { + tx.put_task(&task).await?; + tx.put_client_report(vdaf.borrow(), &report).await?; + + tx.put_aggregation_job(&AggregationJob::< + VERIFY_KEY_LENGTH, + FixedSize, + Prio3Count, + >::new( + *task.id(), + aggregation_job_id, + (), + batch_id, + Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) + .unwrap(), + AggregationJobState::InProgress, + AggregationJobRound::from(0), + )) + .await?; + + tx.put_report_aggregation( + &ReportAggregation::::new( + *task.id(), + aggregation_job_id, + *report.metadata().id(), + *report.metadata().time(), + 0, + None, + ReportAggregationState::Start, + ), + ) + .await?; + + tx.put_batch(&Batch::::new( + *task.id(), + batch_id, + (), + BatchState::Open, + 1, + Interval::from_time(report.metadata().time()).unwrap(), + )) + .await?; + + Ok(tx + .acquire_incomplete_aggregation_jobs(&StdDuration::from_secs(60), 1) + .await? + .remove(0)) + }) + }) + .await + .unwrap(); + assert_eq!(lease.leased().task_id(), task.id()); + assert_eq!(lease.leased().aggregation_job_id(), &aggregation_job_id); + + // Setup: prepare mocked HTTP response. (first an error response, then a success) + // (This is fragile in that it expects the leader request to be deterministically encoded. + // It would be nicer to retrieve the request bytes from the mock, then do our own parsing & + // verification -- but mockito does not expose this functionality at time of writing.) 
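As the fragility comment above notes, these tests rely on the Leader's request encoding being deterministic: the mock matches the exact request bytes and serves a canned response. A trimmed sketch of the recurring mockito pattern (the path, token, and media-type strings are placeholders, and the byte vectors stand for the encoded DAP messages built in each test):

    async fn mock_helper_response(
        server: &mut mockito::Server,
        request_bytes: Vec<u8>,
        response_bytes: Vec<u8>,
    ) -> mockito::Mock {
        server
            .mock("PUT", "/tasks/some-task/aggregation_jobs/some-job")
            .match_header("DAP-Auth-Token", "auth-token")
            .match_body(request_bytes) // exact-byte match on the Leader's request
            .with_status(200)
            .with_header("Content-Type", "application/dap-aggregation-job-resp")
            .with_body(response_bytes)
            .create_async()
            .await
    }

Each test then calls `assert_async().await` on the returned mock to verify the Leader actually issued the expected request.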
+ let leader_request = AggregationJobInitializeReq::new( + ().get_encoded(), + PartialBatchSelector::new_fixed_size(batch_id), + Vec::from([PrepareInit::new( + ReportShare::new( + report.metadata().clone(), + report.public_share().get_encoded(), + report.helper_encrypted_input_share().clone(), + ), + transcript.leader_prepare_transitions[0].message.clone(), + )]), + ); + let helper_response = AggregationJobResp::new(Vec::from([PrepareResp::new( + *report.metadata().id(), + PrepareStepResult::Continue { + message: transcript.helper_prepare_transitions[0].message.clone(), + }, + )])); + let mocked_aggregate_failure = server + .mock( + "PUT", + task.aggregation_job_uri(&aggregation_job_id) + .unwrap() + .path(), + ) + .with_status(500) + .with_header("Content-Type", "application/problem+json") + .with_body("{\"type\": \"urn:ietf:params:ppm:dap:error:unauthorizedRequest\"}") + .create_async() + .await; + let mocked_aggregate_success = server + .mock( + "PUT", + task.aggregation_job_uri(&aggregation_job_id) + .unwrap() + .path(), + ) + .match_header( + "DAP-Auth-Token", + str::from_utf8(agg_auth_token.as_ref()).unwrap(), + ) + .match_header( + CONTENT_TYPE.as_str(), + AggregationJobInitializeReq::::MEDIA_TYPE, + ) + .match_body(leader_request.get_encoded()) + .with_status(200) + .with_header(CONTENT_TYPE.as_str(), AggregationJobResp::MEDIA_TYPE) + .with_body(helper_response.get_encoded()) + .create_async() + .await; + + // Run: create an aggregation job driver & try to step the aggregation we've created twice. + let aggregation_job_driver = AggregationJobDriver::new( + reqwest::Client::builder().build().unwrap(), + &noop_meter(), + 32, + ); + let error = aggregation_job_driver + .step_aggregation_job(ds.clone(), Arc::new(lease.clone())) + .await + .unwrap_err(); + assert_matches!( + error.downcast().unwrap(), + Error::Http { problem_details, dap_problem_type } => { + assert_eq!(problem_details.status.unwrap(), StatusCode::INTERNAL_SERVER_ERROR); + assert_eq!(dap_problem_type, Some(DapProblemType::UnauthorizedRequest)); + } + ); + aggregation_job_driver + .step_aggregation_job(ds.clone(), Arc::new(lease)) + .await + .unwrap(); + + // Verify. + mocked_aggregate_failure.assert_async().await; + mocked_aggregate_success.assert_async().await; + + let want_aggregation_job = AggregationJob::::new( + *task.id(), + aggregation_job_id, + (), + batch_id, + Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)).unwrap(), + AggregationJobState::Finished, + AggregationJobRound::from(1), + ); + let want_report_aggregation = ReportAggregation::::new( + *task.id(), + aggregation_job_id, *report.metadata().id(), *report.metadata().time(), 0, None, - ReportAggregationState::Waiting( - transcript.leader_prep_state(0).clone(), - Some(transcript.prepare_messages[0].clone()), - ), + ReportAggregationState::Finished, + ); + let want_batch = Batch::::new( + *task.id(), + batch_id, + (), + BatchState::Open, + 0, + Interval::from_time(report.metadata().time()).unwrap(), + ); + + let (got_aggregation_job, got_report_aggregation, got_batch) = ds + .run_tx(|tx| { + let (vdaf, task, report_id) = + (Arc::clone(&vdaf), task.clone(), *report.metadata().id()); + Box::pin(async move { + let aggregation_job = tx + .get_aggregation_job::( + task.id(), + &aggregation_job_id, + ) + .await? + .unwrap(); + let report_aggregation = tx + .get_report_aggregation( + vdaf.as_ref(), + &Role::Leader, + task.id(), + &aggregation_job_id, + aggregation_job.aggregation_parameter(), + &report_id, + ) + .await? 
+ .unwrap(); + let batch = tx.get_batch(task.id(), &batch_id, &()).await?.unwrap(); + Ok((aggregation_job, report_aggregation, batch)) + }) + }) + .await + .unwrap(); + + assert_eq!(want_aggregation_job, got_aggregation_job); + assert_eq!(want_report_aggregation, got_report_aggregation); + assert_eq!(want_batch, got_batch); + } + + #[tokio::test] + async fn step_fixed_size_aggregation_job_init_two_rounds() { + // Setup: insert a client report and add it to a new aggregation job. + install_test_trace_subscriber(); + let mut server = mockito::Server::new_async().await; + let clock = MockClock::default(); + let ephemeral_datastore = ephemeral_datastore().await; + let ds = Arc::new(ephemeral_datastore.datastore(clock.clone()).await); + let vdaf = Arc::new(Poplar1::new_shake128(1)); + + let task = TaskBuilder::new( + QueryType::FixedSize { + max_batch_size: 10, + batch_time_window_size: None, + }, + VdafInstance::Poplar1 { bits: 1 }, + Role::Leader, + ) + .with_helper_aggregator_endpoint(server.url().parse().unwrap()) + .build(); + + let report_metadata = ReportMetadata::new( + random(), + clock + .now() + .to_batch_interval_start(task.time_precision()) + .unwrap(), + ); + let verify_key: VerifyKey = task.primary_vdaf_verify_key().unwrap(); + let measurement = IdpfInput::from_bools(&[true]); + let aggregation_param = + Poplar1AggregationParam::try_from_prefixes(Vec::from([IdpfInput::from_bools(&[true])])) + .unwrap(); + + let transcript = run_vdaf( + vdaf.as_ref(), + verify_key.as_bytes(), + &aggregation_param, + report_metadata.id(), + &measurement, + ); + + let agg_auth_token = task.primary_aggregator_auth_token(); + let helper_hpke_keypair = generate_test_hpke_config_and_private_key(); + let report = generate_report::>( + *task.id(), + report_metadata, + helper_hpke_keypair.config(), + transcript.public_share.clone(), + Vec::new(), + &transcript.leader_input_share, + &transcript.helper_input_share, + ); + let batch_id = random(); + let aggregation_job_id = random(); + + let lease = ds + .run_tx(|tx| { + let (vdaf, task, report, aggregation_param) = ( + vdaf.clone(), + task.clone(), + report.clone(), + aggregation_param.clone(), + ); + Box::pin(async move { + tx.put_task(&task).await?; + tx.put_client_report(vdaf.borrow(), &report).await?; + + tx.put_aggregation_job(&AggregationJob::< + VERIFY_KEY_LENGTH, + FixedSize, + Poplar1, + >::new( + *task.id(), + aggregation_job_id, + aggregation_param.clone(), + batch_id, + Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) + .unwrap(), + AggregationJobState::InProgress, + AggregationJobRound::from(0), + )) + .await?; + + tx.put_report_aggregation(&ReportAggregation::< + VERIFY_KEY_LENGTH, + Poplar1, + >::new( + *task.id(), + aggregation_job_id, + *report.metadata().id(), + *report.metadata().time(), + 0, + None, + ReportAggregationState::Start, + )) + .await?; + + tx.put_batch(&Batch::< + VERIFY_KEY_LENGTH, + FixedSize, + Poplar1, + >::new( + *task.id(), + batch_id, + aggregation_param.clone(), + BatchState::Open, + 1, + Interval::from_time(report.metadata().time()).unwrap(), + )) + .await?; + + Ok(tx + .acquire_incomplete_aggregation_jobs(&StdDuration::from_secs(60), 1) + .await? + .remove(0)) + }) + }) + .await + .unwrap(); + assert_eq!(lease.leased().task_id(), task.id()); + assert_eq!(lease.leased().aggregation_job_id(), &aggregation_job_id); + + // Setup: prepare mocked HTTP response. (first an error response, then a success) + // (This is fragile in that it expects the leader request to be deterministically encoded. 
+ // It would be nicer to retrieve the request bytes from the mock, then do our own parsing & + // verification -- but mockito does not expose this functionality at time of writing.) + let leader_request = AggregationJobInitializeReq::new( + aggregation_param.get_encoded(), + PartialBatchSelector::new_fixed_size(batch_id), + Vec::from([PrepareInit::new( + ReportShare::new( + report.metadata().clone(), + report.public_share().get_encoded(), + report.helper_encrypted_input_share().clone(), + ), + transcript.leader_prepare_transitions[0].message.clone(), + )]), ); - let want_batch = Batch::::new( + let helper_response = AggregationJobResp::new(Vec::from([PrepareResp::new( + *report.metadata().id(), + PrepareStepResult::Continue { + message: transcript.helper_prepare_transitions[0].message.clone(), + }, + )])); + let mocked_aggregate_success = server + .mock( + "PUT", + task.aggregation_job_uri(&aggregation_job_id) + .unwrap() + .path(), + ) + .match_header( + "DAP-Auth-Token", + str::from_utf8(agg_auth_token.as_ref()).unwrap(), + ) + .match_header( + CONTENT_TYPE.as_str(), + AggregationJobInitializeReq::::MEDIA_TYPE, + ) + .match_body(leader_request.get_encoded()) + .with_status(200) + .with_header(CONTENT_TYPE.as_str(), AggregationJobResp::MEDIA_TYPE) + .with_body(helper_response.get_encoded()) + .create_async() + .await; + + // Run: create an aggregation job driver & try to step the aggregation we've created twice. + let aggregation_job_driver = AggregationJobDriver::new( + reqwest::Client::builder().build().unwrap(), + &noop_meter(), + 32, + ); + aggregation_job_driver + .step_aggregation_job(ds.clone(), Arc::new(lease)) + .await + .unwrap(); + + // Verify. + mocked_aggregate_success.assert_async().await; + + let want_aggregation_job = + AggregationJob::>::new( + *task.id(), + aggregation_job_id, + aggregation_param.clone(), + batch_id, + Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) + .unwrap(), + AggregationJobState::InProgress, + AggregationJobRound::from(1), + ); + let want_report_aggregation = + ReportAggregation::>::new( + *task.id(), + aggregation_job_id, + *report.metadata().id(), + *report.metadata().time(), + 0, + None, + ReportAggregationState::WaitingLeader( + transcript.leader_prepare_transitions[1] + .transition + .clone() + .unwrap(), + ), + ); + let want_batch = Batch::>::new( *task.id(), batch_id, - (), + aggregation_param.clone(), BatchState::Open, 1, Interval::from_time(report.metadata().time()).unwrap(), @@ -1777,7 +2348,7 @@ mod tests { (Arc::clone(&vdaf), task.clone(), *report.metadata().id()); Box::pin(async move { let aggregation_job = tx - .get_aggregation_job::( + .get_aggregation_job::>( task.id(), &aggregation_job_id, ) @@ -1789,11 +2360,19 @@ mod tests { &Role::Leader, task.id(), &aggregation_job_id, + aggregation_job.aggregation_parameter(), &report_id, ) .await? .unwrap(); - let batch = tx.get_batch(task.id(), &batch_id, &()).await?.unwrap(); + let batch = tx + .get_batch( + task.id(), + &batch_id, + aggregation_job.aggregation_parameter(), + ) + .await? 
+ .unwrap(); Ok((aggregation_job, report_aggregation, batch)) }) }) @@ -1814,14 +2393,14 @@ mod tests { let clock = MockClock::default(); let ephemeral_datastore = ephemeral_datastore().await; let ds = Arc::new(ephemeral_datastore.datastore(clock.clone()).await); - let vdaf = Arc::new(Prio3::new_count(2).unwrap()); + let vdaf = Arc::new(Poplar1::new_shake128(1)); let task = TaskBuilder::new( QueryType::TimeInterval, - VdafInstance::Prio3Count, + VdafInstance::Poplar1 { bits: 1 }, Role::Leader, ) - .with_helper_aggregator_endpoint(Url::parse(&server.url()).unwrap()) + .with_helper_aggregator_endpoint(server.url().parse().unwrap()) .build(); let time = clock .now() @@ -1844,40 +2423,43 @@ mod tests { let report_metadata = ReportMetadata::new(random(), time); let verify_key: VerifyKey = task.primary_vdaf_verify_key().unwrap(); + let aggregation_param = Poplar1AggregationParam::try_from_prefixes(Vec::from([ + IdpfInput::from_bools(&[false]), + ])) + .unwrap(); let transcript = run_vdaf( vdaf.as_ref(), verify_key.as_bytes(), - &(), + &aggregation_param, report_metadata.id(), - &0, + &IdpfInput::from_bools(&[true]), ); let agg_auth_token = task.primary_aggregator_auth_token(); let helper_hpke_keypair = generate_test_hpke_config_and_private_key(); - let report = generate_report::( + let report = generate_report::>( *task.id(), report_metadata, helper_hpke_keypair.config(), transcript.public_share.clone(), Vec::new(), - transcript.input_shares.clone(), + &transcript.leader_input_share, + &transcript.helper_input_share, ); let aggregation_job_id = random(); - let leader_prep_state = transcript.leader_prep_state(0); let leader_aggregate_share = vdaf - .aggregate(&(), [transcript.output_share(Role::Leader).clone()]) + .aggregate(&aggregation_param, [transcript.leader_output_share.clone()]) .unwrap(); - let prep_msg = &transcript.prepare_messages[0]; let (lease, want_collection_job) = ds .run_tx(|tx| { - let (vdaf, task, report, leader_prep_state, prep_msg) = ( + let (vdaf, task, aggregation_param, report, transcript) = ( vdaf.clone(), task.clone(), + aggregation_param.clone(), report.clone(), - leader_prep_state.clone(), - prep_msg.clone(), + transcript.clone(), ); Box::pin(async move { tx.put_task(&task).await?; @@ -1888,11 +2470,11 @@ mod tests { tx.put_aggregation_job(&AggregationJob::< VERIFY_KEY_LENGTH, TimeInterval, - Prio3Count, + Poplar1, >::new( *task.id(), aggregation_job_id, - (), + aggregation_param.clone(), (), Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) .unwrap(), @@ -1900,47 +2482,65 @@ mod tests { AggregationJobRound::from(1), )) .await?; - tx.put_report_aggregation( - &ReportAggregation::::new( - *task.id(), - aggregation_job_id, - *report.metadata().id(), - *report.metadata().time(), - 0, - None, - ReportAggregationState::Waiting(leader_prep_state, Some(prep_msg)), + + tx.put_report_aggregation(&ReportAggregation::< + VERIFY_KEY_LENGTH, + Poplar1, + >::new( + *task.id(), + aggregation_job_id, + *report.metadata().id(), + *report.metadata().time(), + 0, + None, + ReportAggregationState::WaitingLeader( + transcript.leader_prepare_transitions[1] + .transition + .clone() + .unwrap(), ), - ) + )) .await?; - tx.put_batch(&Batch::::new( + tx.put_batch(&Batch::< + VERIFY_KEY_LENGTH, + TimeInterval, + Poplar1, + >::new( *task.id(), active_batch_identifier, - (), + aggregation_param.clone(), BatchState::Closing, 1, Interval::from_time(report.metadata().time()).unwrap(), )) .await?; - tx.put_batch(&Batch::::new( + tx.put_batch(&Batch::< + VERIFY_KEY_LENGTH, + 
TimeInterval, + Poplar1, + >::new( *task.id(), other_batch_identifier, - (), + aggregation_param.clone(), BatchState::Closing, 1, Interval::EMPTY, )) .await?; - let collection_job = - CollectionJob::::new( - *task.id(), - random(), - Query::new_time_interval(collection_identifier), - (), - collection_identifier, - CollectionJobState::Start, - ); + let collection_job = CollectionJob::< + VERIFY_KEY_LENGTH, + TimeInterval, + Poplar1, + >::new( + *task.id(), + random(), + Query::new_time_interval(collection_identifier), + aggregation_param, + collection_identifier, + CollectionJobState::Start, + ); tx.put_collection_job(&collection_job).await?; let lease = tx @@ -1962,12 +2562,12 @@ mod tests { // verification -- but mockito does not expose this functionality at time of writing.) let leader_request = AggregationJobContinueReq::new( AggregationJobRound::from(1), - Vec::from([PrepareStep::new( + Vec::from([PrepareContinue::new( *report.metadata().id(), - PrepareStepResult::Continued(prep_msg.get_encoded()), + transcript.leader_prepare_transitions[1].message.clone(), )]), ); - let helper_response = AggregationJobResp::new(Vec::from([PrepareStep::new( + let helper_response = AggregationJobResp::new(Vec::from([PrepareResp::new( *report.metadata().id(), PrepareStepResult::Finished, )])); @@ -2029,25 +2629,27 @@ mod tests { mocked_aggregate_success.assert_async().await; let want_aggregation_job = - AggregationJob::::new( + AggregationJob::>::new( *task.id(), aggregation_job_id, - (), + aggregation_param.clone(), (), Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) .unwrap(), AggregationJobState::Finished, AggregationJobRound::from(2), ); - let want_report_aggregation = ReportAggregation::::new( - *task.id(), - aggregation_job_id, - *report.metadata().id(), - *report.metadata().time(), - 0, - None, - ReportAggregationState::Finished, - ); + let want_report_aggregation = + ReportAggregation::>::new( + *task.id(), + aggregation_job_id, + *report.metadata().id(), + *report.metadata().time(), + 0, + None, + ReportAggregationState::Finished, + ); + let batch_interval_start = report .metadata() .time() @@ -2056,11 +2658,11 @@ mod tests { let want_batch_aggregations = Vec::from([BatchAggregation::< VERIFY_KEY_LENGTH, TimeInterval, - Prio3Count, + Poplar1, >::new( *task.id(), Interval::new(batch_interval_start, *task.time_precision()).unwrap(), - (), + aggregation_param.clone(), 0, BatchAggregationState::Aggregating, Some(leader_aggregate_share), @@ -2068,22 +2670,24 @@ mod tests { Interval::from_time(report.metadata().time()).unwrap(), ReportIdChecksum::for_report_id(report.metadata().id()), )]); - let want_active_batch = Batch::::new( - *task.id(), - active_batch_identifier, - (), - BatchState::Closed, - 0, - Interval::from_time(report.metadata().time()).unwrap(), - ); - let want_other_batch = Batch::::new( - *task.id(), - other_batch_identifier, - (), - BatchState::Closing, - 1, - Interval::EMPTY, - ); + let want_active_batch = + Batch::>::new( + *task.id(), + active_batch_identifier, + aggregation_param.clone(), + BatchState::Closed, + 0, + Interval::from_time(report.metadata().time()).unwrap(), + ); + let want_other_batch = + Batch::>::new( + *task.id(), + other_batch_identifier, + aggregation_param.clone(), + BatchState::Closing, + 1, + Interval::EMPTY, + ); let ( got_aggregation_job, @@ -2094,14 +2698,16 @@ mod tests { got_collection_job, ) = ds .run_tx(|tx| { - let vdaf = Arc::clone(&vdaf); - let task = task.clone(); - let report_metadata = report.metadata().clone(); - let 
collection_job_id = *want_collection_job.id(); - + let (vdaf, task, report_metadata, aggregation_param, collection_job_id) = ( + Arc::clone(&vdaf), + task.clone(), + report.metadata().clone(), + aggregation_param.clone(), + *want_collection_job.id(), + ); Box::pin(async move { let aggregation_job = tx - .get_aggregation_job::( + .get_aggregation_job::>( task.id(), &aggregation_job_id, ) @@ -2113,6 +2719,7 @@ mod tests { &Role::Leader, task.id(), &aggregation_job_id, + aggregation_job.aggregation_parameter(), report_metadata.id(), ) .await? @@ -2120,7 +2727,7 @@ mod tests { let batch_aggregations = TimeInterval::get_batch_aggregations_for_collection_identifier::< VERIFY_KEY_LENGTH, - Prio3Count, + Poplar1, _, >( tx, @@ -2134,16 +2741,16 @@ mod tests { *task.time_precision(), ) .unwrap(), - &(), + &aggregation_param, ) .await .unwrap(); let got_active_batch = tx - .get_batch(task.id(), &active_batch_identifier, &()) + .get_batch(task.id(), &active_batch_identifier, &aggregation_param) .await? .unwrap(); let got_other_batch = tx - .get_batch(task.id(), &other_batch_identifier, &()) + .get_batch(task.id(), &other_batch_identifier, &aggregation_param) .await? .unwrap(); let got_collection_job = tx @@ -2171,7 +2778,7 @@ mod tests { BatchAggregation::new( *agg.task_id(), *agg.batch_identifier(), - (), + aggregation_param.clone(), 0, *agg.state(), agg.aggregate_share().cloned(), @@ -2199,17 +2806,17 @@ mod tests { let clock = MockClock::default(); let ephemeral_datastore = ephemeral_datastore().await; let ds = Arc::new(ephemeral_datastore.datastore(clock.clone()).await); - let vdaf = Arc::new(Prio3::new_count(2).unwrap()); + let vdaf = Arc::new(Poplar1::new_shake128(1)); let task = TaskBuilder::new( QueryType::FixedSize { max_batch_size: 10, batch_time_window_size: None, }, - VdafInstance::Prio3Count, + VdafInstance::Poplar1 { bits: 1 }, Role::Leader, ) - .with_helper_aggregator_endpoint(Url::parse(&server.url()).unwrap()) + .with_helper_aggregator_endpoint(server.url().parse().unwrap()) .build(); let report_metadata = ReportMetadata::new( random(), @@ -2220,40 +2827,43 @@ mod tests { ); let verify_key: VerifyKey = task.primary_vdaf_verify_key().unwrap(); + let aggregation_param = Poplar1AggregationParam::try_from_prefixes(Vec::from([ + IdpfInput::from_bools(&[false]), + ])) + .unwrap(); let transcript = run_vdaf( vdaf.as_ref(), verify_key.as_bytes(), - &(), + &aggregation_param, report_metadata.id(), - &0, + &IdpfInput::from_bools(&[true]), ); let agg_auth_token = task.primary_aggregator_auth_token(); let helper_hpke_keypair = generate_test_hpke_config_and_private_key(); - let report = generate_report::( + let report = generate_report::>( *task.id(), report_metadata, helper_hpke_keypair.config(), transcript.public_share.clone(), Vec::new(), - transcript.input_shares.clone(), + &transcript.leader_input_share, + &transcript.helper_input_share, ); let batch_id = random(); let aggregation_job_id = random(); - let leader_prep_state = transcript.leader_prep_state(0); let leader_aggregate_share = vdaf - .aggregate(&(), [transcript.output_share(Role::Leader).clone()]) + .aggregate(&aggregation_param, [transcript.leader_output_share.clone()]) .unwrap(); - let prep_msg = &transcript.prepare_messages[0]; let (lease, collection_job) = ds .run_tx(|tx| { - let (vdaf, task, report, leader_prep_state, prep_msg) = ( + let (vdaf, task, report, aggregation_param, transcript) = ( vdaf.clone(), task.clone(), report.clone(), - leader_prep_state.clone(), - prep_msg.clone(), + aggregation_param.clone(), + 
transcript.clone(), ); Box::pin(async move { tx.put_task(&task).await?; @@ -2262,11 +2872,11 @@ mod tests { tx.put_aggregation_job(&AggregationJob::< VERIFY_KEY_LENGTH, FixedSize, - Prio3Count, + Poplar1, >::new( *task.id(), aggregation_job_id, - (), + aggregation_param.clone(), batch_id, Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) .unwrap(), @@ -2274,38 +2884,52 @@ mod tests { AggregationJobRound::from(1), )) .await?; - tx.put_report_aggregation( - &ReportAggregation::::new( - *task.id(), - aggregation_job_id, - *report.metadata().id(), - *report.metadata().time(), - 0, - None, - ReportAggregationState::Waiting(leader_prep_state, Some(prep_msg)), + + tx.put_report_aggregation(&ReportAggregation::< + VERIFY_KEY_LENGTH, + Poplar1, + >::new( + *task.id(), + aggregation_job_id, + *report.metadata().id(), + *report.metadata().time(), + 0, + None, + ReportAggregationState::WaitingLeader( + transcript.leader_prepare_transitions[1] + .transition + .clone() + .unwrap(), ), - ) + )) .await?; - tx.put_batch(&Batch::::new( + tx.put_batch(&Batch::< + VERIFY_KEY_LENGTH, + FixedSize, + Poplar1, + >::new( *task.id(), batch_id, - (), + aggregation_param.clone(), BatchState::Closing, 1, Interval::from_time(report.metadata().time()).unwrap(), )) .await?; - let collection_job = - CollectionJob::::new( - *task.id(), - random(), - Query::new_fixed_size(FixedSizeQuery::CurrentBatch), - (), - batch_id, - CollectionJobState::Start, - ); + let collection_job = CollectionJob::< + VERIFY_KEY_LENGTH, + FixedSize, + Poplar1, + >::new( + *task.id(), + random(), + Query::new_fixed_size(FixedSizeQuery::CurrentBatch), + aggregation_param, + batch_id, + CollectionJobState::Start, + ); tx.put_collection_job(&collection_job).await?; let lease = tx @@ -2327,12 +2951,12 @@ mod tests { // verification -- but mockito does not expose this functionality at time of writing.) 
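// The round-1 continue request now carries the leader's second ping-pong message
// (leader_prepare_transitions[1]) directly, rather than a re-encoded prepare message as
// in the pre-ping-pong flow.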
let leader_request = AggregationJobContinueReq::new( AggregationJobRound::from(1), - Vec::from([PrepareStep::new( + Vec::from([PrepareContinue::new( *report.metadata().id(), - PrepareStepResult::Continued(prep_msg.get_encoded()), + transcript.leader_prepare_transitions[1].message.clone(), )]), ); - let helper_response = AggregationJobResp::new(Vec::from([PrepareStep::new( + let helper_response = AggregationJobResp::new(Vec::from([PrepareResp::new( *report.metadata().id(), PrepareStepResult::Finished, )])); @@ -2393,42 +3017,46 @@ mod tests { mocked_aggregate_failure.assert_async().await; mocked_aggregate_success.assert_async().await; - let want_aggregation_job = AggregationJob::::new( + let want_aggregation_job = + AggregationJob::>::new( + *task.id(), + aggregation_job_id, + aggregation_param.clone(), + batch_id, + Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) + .unwrap(), + AggregationJobState::Finished, + AggregationJobRound::from(2), + ); + let want_report_aggregation = + ReportAggregation::>::new( + *task.id(), + aggregation_job_id, + *report.metadata().id(), + *report.metadata().time(), + 0, + None, + ReportAggregationState::Finished, + ); + let want_batch_aggregations = Vec::from([BatchAggregation::< + VERIFY_KEY_LENGTH, + FixedSize, + Poplar1, + >::new( *task.id(), - aggregation_job_id, - (), batch_id, - Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)).unwrap(), - AggregationJobState::Finished, - AggregationJobRound::from(2), - ); - let want_report_aggregation = ReportAggregation::::new( - *task.id(), - aggregation_job_id, - *report.metadata().id(), - *report.metadata().time(), + aggregation_param.clone(), 0, - None, - ReportAggregationState::Finished, - ); - let want_batch_aggregations = - Vec::from([ - BatchAggregation::::new( - *task.id(), - batch_id, - (), - 0, - BatchAggregationState::Aggregating, - Some(leader_aggregate_share), - 1, - Interval::from_time(report.metadata().time()).unwrap(), - ReportIdChecksum::for_report_id(report.metadata().id()), - ), - ]); - let want_batch = Batch::::new( + BatchAggregationState::Aggregating, + Some(leader_aggregate_share), + 1, + Interval::from_time(report.metadata().time()).unwrap(), + ReportIdChecksum::for_report_id(report.metadata().id()), + )]); + let want_batch = Batch::>::new( *task.id(), batch_id, - (), + aggregation_param.clone(), BatchState::Closed, 0, Interval::from_time(report.metadata().time()).unwrap(), @@ -2443,14 +3071,16 @@ mod tests { got_collection_job, ) = ds .run_tx(|tx| { - let vdaf = Arc::clone(&vdaf); - let task = task.clone(); - let report_metadata = report.metadata().clone(); - let collection_job_id = *want_collection_job.id(); - + let (vdaf, task, report_metadata, aggregation_param, collection_job_id) = ( + Arc::clone(&vdaf), + task.clone(), + report.metadata().clone(), + aggregation_param.clone(), + *want_collection_job.id(), + ); Box::pin(async move { let aggregation_job = tx - .get_aggregation_job::( + .get_aggregation_job::>( task.id(), &aggregation_job_id, ) @@ -2462,6 +3092,7 @@ mod tests { &Role::Leader, task.id(), &aggregation_job_id, + aggregation_job.aggregation_parameter(), report_metadata.id(), ) .await? 
@@ -2469,11 +3100,14 @@ mod tests { let batch_aggregations = FixedSize::get_batch_aggregations_for_collection_identifier::< VERIFY_KEY_LENGTH, - Prio3Count, + Poplar1, _, - >(tx, &task, &vdaf, &batch_id, &()) + >(tx, &task, &vdaf, &batch_id, &aggregation_param) .await?; - let batch = tx.get_batch(task.id(), &batch_id, &()).await?.unwrap(); + let batch = tx + .get_batch(task.id(), &batch_id, &aggregation_param) + .await? + .unwrap(); let collection_job = tx .get_collection_job(vdaf.as_ref(), task.id(), &collection_job_id) .await? @@ -2497,7 +3131,7 @@ mod tests { BatchAggregation::new( *agg.task_id(), *agg.batch_identifier(), - (), + aggregation_param.clone(), 0, *agg.state(), agg.aggregate_share().cloned(), @@ -2553,7 +3187,8 @@ mod tests { helper_hpke_keypair.config(), transcript.public_share, Vec::new(), - transcript.input_shares, + &transcript.leader_input_share, + &transcript.helper_input_share, ); let aggregation_job_id = random(); @@ -2655,6 +3290,7 @@ mod tests { &Role::Leader, task.id(), &aggregation_job_id, + aggregation_job.aggregation_parameter(), &report_id, ) .await? @@ -2685,26 +3321,18 @@ mod tests { helper_hpke_config: &HpkeConfig, public_share: A::PublicShare, extensions: Vec, - input_shares: Vec, + leader_input_share: &A::InputShare, + helper_input_share: &A::InputShare, ) -> LeaderStoredReport where A: vdaf::Aggregator, A::InputShare: PartialEq, A::PublicShare: PartialEq, { - assert_eq!(input_shares.len(), 2); - let encrypted_helper_input_share = hpke::seal( helper_hpke_config, &HpkeApplicationInfo::new(&Label::InputShare, &Role::Client, &Role::Helper), - &PlaintextInputShare::new( - Vec::new(), - input_shares - .get(Role::Helper.index().unwrap()) - .unwrap() - .get_encoded(), - ) - .get_encoded(), + &PlaintextInputShare::new(Vec::new(), helper_input_share.get_encoded()).get_encoded(), &InputShareAad::new(task_id, report_metadata.clone(), public_share.get_encoded()) .get_encoded(), ) @@ -2715,10 +3343,7 @@ mod tests { report_metadata, public_share, extensions, - input_shares - .get(Role::Leader.index().unwrap()) - .unwrap() - .clone(), + leader_input_share.clone(), encrypted_helper_input_share, ) } @@ -2738,7 +3363,7 @@ mod tests { VdafInstance::Prio3Count, Role::Leader, ) - .with_helper_aggregator_endpoint(Url::parse(&server.url()).unwrap()) + .with_helper_aggregator_endpoint(server.url().parse().unwrap()) .build(); let agg_auth_token = task.primary_aggregator_auth_token(); let aggregation_job_id = random(); @@ -2760,7 +3385,8 @@ mod tests { helper_hpke_keypair.config(), transcript.public_share, Vec::new(), - transcript.input_shares, + &transcript.leader_input_share, + &transcript.helper_input_share, ); // Set up fixtures in the database. 
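The leader-side fixtures above persist a waiting leader as `ReportAggregationState::WaitingLeader` holding a `PingPongTransition` taken from the transcript. A minimal sketch of how a stored transition can be turned back into a prepare state and an outbound message; `PingPongTransition::evaluate` and the `PingPongState` variants are assumed names from prio's `topology::ping_pong` module, and the `(16, 16)` parameters mirror the `Poplar1` tests above:

// Hedged sketch, not part of this patch: recomputing a waiting leader's next state and
// outbound message from a stored PingPongTransition.
use prio::{
    topology::ping_pong::{PingPongError, PingPongState, PingPongTransition},
    vdaf::{poplar1::Poplar1, xof::XofShake128},
};

type TestVdaf = Poplar1<XofShake128, 16>;

fn resume_waiting_leader(
    vdaf: &TestVdaf,
    stored: &PingPongTransition<16, 16, TestVdaf>,
) -> Result<(), PingPongError> {
    // Evaluating the stored transition re-runs the deferred VDAF computation, yielding
    // the leader's next ping-pong state and the message it owes the helper.
    let (state, _outbound) = stored.evaluate(vdaf)?;
    match state {
        // Preparation continues: this is the prepare state behind WaitingLeader.
        PingPongState::Continued(_prepare_state) => (),
        // Preparation is complete: the output share can be accumulated.
        PingPongState::Finished(_output_share) => (),
    }
    Ok(())
}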
diff --git a/aggregator/src/aggregator/aggregation_job_writer.rs b/aggregator/src/aggregator/aggregation_job_writer.rs index 65ce1d387..ae2d1e41b 100644 --- a/aggregator/src/aggregator/aggregation_job_writer.rs +++ b/aggregator/src/aggregator/aggregation_job_writer.rs @@ -14,7 +14,7 @@ use janus_aggregator_core::{ task::Task, }; use janus_core::time::{Clock, IntervalExt}; -use janus_messages::{AggregationJobId, Interval, ReportId, ReportShareError}; +use janus_messages::{AggregationJobId, Interval, PrepareError, ReportId}; use prio::{codec::Encode, vdaf}; use std::{ borrow::Cow, @@ -260,7 +260,7 @@ impl, +) -> PrepareError { + let peer_role = match role { + Role::Leader => Role::Helper, + Role::Helper => Role::Leader, + // panic safety: role should be passed to this function as a literal, so passing a role that + // isn't an aggregator is a straightforward programmer error and we want to fail noisily. + _ => panic!("invalid role"), + }; + let (error_desc, value) = match ping_pong_error { + PingPongError::VdafPrepareInit(_) => ( + "Couldn't helper_initialize report share".to_string(), + "prepare_init_failure".to_string(), + ), + PingPongError::VdafPrepareSharesToPrepareMessage(_) => ( + "Couldn't compute prepare message".to_string(), + "prepare_message_failure".to_string(), + ), + PingPongError::VdafPrepareNext(_) => ( + "Prepare next failed".to_string(), + "prepare_next_failure".to_string(), + ), + PingPongError::CodecPrepShare(_) => ( + format!("Couldn't decode {peer_role} prepare share"), + format!("{peer_role}_prep_share_decode_failure"), + ), + PingPongError::CodecPrepMessage(_) => ( + format!("Couldn't decode {peer_role} prepare message"), + format!("{peer_role}_prep_message_decode_failure"), + ), + ref error @ PingPongError::HostStateMismatch { .. } => ( + format!("{error}"), + format!("{role}_ping_pong_host_state_mismatch"), + ), + ref error @ PingPongError::PeerMessageMismatch { .. 
} => ( + format!("{error}"), + format!("{peer_role}_ping_pong_message_mismatch"), + ), + PingPongError::InternalError(desc) => ( + desc.to_string(), + "vdaf_ping_pong_internal_error".to_string(), + ), + }; + + info!( + task_id = %task_id, + report_id = %report_id, + ?ping_pong_error, + error_desc, + ); + + aggregate_step_failure_counter.add(1, &[KeyValue::new("type", value)]); + + // Per DAP, any occurrence of state Rejected() from a ping-pong routime is translated to + // VdafPrepError + PrepareError::VdafPrepError +} diff --git a/aggregator/src/aggregator/http_handlers.rs b/aggregator/src/aggregator/http_handlers.rs index d2787d0be..b3ee890bf 100644 --- a/aggregator/src/aggregator/http_handlers.rs +++ b/aggregator/src/aggregator/http_handlers.rs @@ -640,7 +640,9 @@ pub mod test_util { mod tests { use crate::{ aggregator::{ - aggregate_init_tests::{put_aggregation_job, setup_aggregate_init_test}, + aggregate_init_tests::{ + put_aggregation_job, setup_aggregate_init_test, PrepareInitGenerator, + }, aggregation_job_continue::test_util::{ post_aggregation_job_and_decode, post_aggregation_job_expecting_error, }, @@ -696,15 +698,18 @@ mod tests { AggregationJobInitializeReq, AggregationJobResp, AggregationJobRound, BatchSelector, Collection, CollectionJobId, CollectionReq, Duration, Extension, ExtensionType, HpkeCiphertext, HpkeConfigId, HpkeConfigList, InputShareAad, Interval, - PartialBatchSelector, PrepareStep, PrepareStepResult, Query, Report, ReportId, - ReportIdChecksum, ReportMetadata, ReportShare, ReportShareError, Role, TaskId, Time, + PartialBatchSelector, PrepareContinue, PrepareError, PrepareInit, PrepareResp, + PrepareStepResult, Query, Report, ReportId, ReportIdChecksum, ReportMetadata, ReportShare, + Role, TaskId, Time, }; use prio::{ codec::{Decode, Encode}, - field::Field64, + idpf::IdpfInput, + topology::ping_pong::PingPongMessage, vdaf::{ - prio3::{Prio3, Prio3Count}, - AggregateShare, Aggregator, OutputShare, + poplar1::{Poplar1, Poplar1AggregationParam}, + xof::XofShake128, + Aggregator, }, }; use rand::random; @@ -1494,9 +1499,9 @@ mod tests { } #[tokio::test] - // Silence the unit_arg lint so that we can work with dummy_vdaf::Vdaf::InputShare values (whose - // type is ()). - #[allow(clippy::unit_arg)] + // Silence the unit_arg lint so that we can work with dummy_vdaf::Vdaf::{InputShare, + // Measurement} values (whose type is ()). + #[allow(clippy::unit_arg, clippy::let_unit_value)] async fn aggregate_init() { let (clock, _ephemeral_datastore, datastore, handler) = setup_http_handler_test().await; @@ -1506,55 +1511,21 @@ mod tests { let vdaf = dummy_vdaf::Vdaf::new(); let verify_key: VerifyKey<0> = task.primary_vdaf_verify_key().unwrap(); let hpke_key = task.current_hpke_key(); - - // report_share_0 is a "happy path" report. - let report_metadata_0 = ReportMetadata::new( - random(), - clock - .now() - .to_batch_interval_start(task.time_precision()) - .unwrap(), - ); - let transcript = run_vdaf( - &vdaf, - verify_key.as_bytes(), - &dummy_vdaf::AggregationParam(0), - report_metadata_0.id(), - &(), - ); - let report_share_0 = generate_helper_report_share::( - *task.id(), - report_metadata_0, - hpke_key.config(), - &transcript.public_share, - Vec::new(), - &transcript.input_shares[1], + let measurement = (); + let prep_init_generator = PrepareInitGenerator::new( + clock.clone(), + task.clone(), + vdaf.clone(), + dummy_vdaf::AggregationParam(0), ); + // prepare_init_0 is a "happy path" report. 
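+ // (PrepareInitGenerator::next, as used below, yields a PrepareInit bundling the report
+ // share with the leader's first ping-pong message, plus the full VDAF transcript so the
+ // test can reference later-round messages.)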
+ let (prepare_init_0, transcript_0) = prep_init_generator.next(&measurement); + // report_share_1 fails decryption. - let report_metadata_1 = ReportMetadata::new( - random(), - clock - .now() - .to_batch_interval_start(task.time_precision()) - .unwrap(), - ); - let transcript = run_vdaf( - &vdaf, - verify_key.as_bytes(), - &dummy_vdaf::AggregationParam(0), - report_metadata_1.id(), - &(), - ); - let report_share_1 = generate_helper_report_share::( - *task.id(), - report_metadata_1.clone(), - hpke_key.config(), - &transcript.public_share, - Vec::new(), - &transcript.input_shares[1], - ); - let encrypted_input_share = report_share_1.encrypted_input_share(); + let (prepare_init_1, transcript_1) = prep_init_generator.next(&measurement); + + let encrypted_input_share = prepare_init_1.report_share().encrypted_input_share(); let mut corrupted_payload = encrypted_input_share.payload().to_vec(); corrupted_payload[0] ^= 0xFF; let corrupted_input_share = HpkeCiphertext::new( @@ -1562,53 +1533,39 @@ mod tests { encrypted_input_share.encapsulated_key().to_vec(), corrupted_payload, ); - let encoded_public_share = transcript.public_share.get_encoded(); - let report_share_1 = ReportShare::new( - report_metadata_1, - encoded_public_share.clone(), - corrupted_input_share, - ); - // report_share_2 fails decoding due to an issue with the input share. - let report_metadata_2 = ReportMetadata::new( - random(), - clock - .now() - .to_batch_interval_start(task.time_precision()) - .unwrap(), - ); - let transcript = run_vdaf( - &vdaf, - verify_key.as_bytes(), - &dummy_vdaf::AggregationParam(0), - report_metadata_2.id(), - &(), + let prepare_init_1 = PrepareInit::new( + ReportShare::new( + prepare_init_1.report_share().metadata().clone(), + transcript_1.public_share.get_encoded(), + corrupted_input_share, + ), + prepare_init_1.message().clone(), ); - let mut input_share_bytes = transcript.input_shares[1].get_encoded(); + + // prepare_init_2 fails decoding due to an issue with the input share. + let (prepare_init_2, transcript_2) = prep_init_generator.next(&measurement); + + let mut input_share_bytes = transcript_2.helper_input_share.get_encoded(); input_share_bytes.push(0); // can no longer be decoded. let report_share_2 = generate_helper_report_share_for_plaintext( - report_metadata_2.clone(), + prepare_init_2.report_share().metadata().clone(), hpke_key.config(), - encoded_public_share.clone(), + transcript_2.public_share.get_encoded(), &input_share_bytes, - &InputShareAad::new(*task.id(), report_metadata_2, encoded_public_share).get_encoded(), + &InputShareAad::new( + *task.id(), + prepare_init_2.report_share().metadata().clone(), + transcript_2.public_share.get_encoded(), + ) + .get_encoded(), ); - // report_share_3 has an unknown HPKE config ID. - let report_metadata_3 = ReportMetadata::new( - random(), - clock - .now() - .to_batch_interval_start(task.time_precision()) - .unwrap(), - ); - let transcript = run_vdaf( - &vdaf, - verify_key.as_bytes(), - &dummy_vdaf::AggregationParam(0), - report_metadata_3.id(), - &(), - ); + let prepare_init_2 = PrepareInit::new(report_share_2, prepare_init_2.message().clone()); + + // prepare_init_3 has an unknown HPKE config ID. 
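+ // (Its report share is re-encrypted below under a config ID the task doesn't know, so
+ // the helper is expected to reject it with HpkeUnknownConfigId.)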
+ let (prepare_init_3, transcript_3) = prep_init_generator.next(&measurement); + let wrong_hpke_config = loop { let hpke_config = generate_test_hpke_config_and_private_key().config().clone(); if task.hpke_keys().contains_key(hpke_config.id()) { @@ -1616,41 +1573,23 @@ mod tests { } break hpke_config; }; + let report_share_3 = generate_helper_report_share::( *task.id(), - report_metadata_3, + prepare_init_3.report_share().metadata().clone(), &wrong_hpke_config, - &transcript.public_share, + &transcript_3.public_share, Vec::new(), - &transcript.input_shares[1], + &transcript_3.helper_input_share, ); - // report_share_4 has already been aggregated in another aggregation job, with the same + let prepare_init_3 = PrepareInit::new(report_share_3, prepare_init_3.message().clone()); + + // prepare_init_4 has already been aggregated in another aggregation job, with the same // aggregation parameter. - let report_metadata_4 = ReportMetadata::new( - random(), - clock - .now() - .to_batch_interval_start(task.time_precision()) - .unwrap(), - ); - let transcript = run_vdaf( - &vdaf, - verify_key.as_bytes(), - &dummy_vdaf::AggregationParam(0), - report_metadata_4.id(), - &(), - ); - let report_share_4 = generate_helper_report_share::( - *task.id(), - report_metadata_4, - hpke_key.config(), - &transcript.public_share, - Vec::new(), - &transcript.input_shares[1], - ); + let (prepare_init_4, _) = prep_init_generator.next(&measurement); - // report_share_5 falls into a batch that has already been collected. + // prepare_init_5 falls into a batch that has already been collected. let past_clock = MockClock::new(Time::from_seconds_since_epoch( task.time_precision().as_seconds() / 2, )); @@ -1661,23 +1600,28 @@ mod tests { .to_batch_interval_start(task.time_precision()) .unwrap(), ); - let transcript = run_vdaf( + let transcript_5 = run_vdaf( &vdaf, verify_key.as_bytes(), &dummy_vdaf::AggregationParam(0), report_metadata_5.id(), - &(), + &measurement, ); let report_share_5 = generate_helper_report_share::( *task.id(), report_metadata_5, hpke_key.config(), - &transcript.public_share, + &transcript_5.public_share, Vec::new(), - &transcript.input_shares[1], + &transcript_5.helper_input_share, + ); + + let prepare_init_5 = PrepareInit::new( + report_share_5, + transcript_5.leader_prepare_transitions[0].message.clone(), ); - // report_share_6 fails decoding due to an issue with the public share. + // prepare_init_6 fails decoding due to an issue with the public share. let public_share_6 = Vec::from([0]); let report_metadata_6 = ReportMetadata::new( random(), @@ -1686,22 +1630,27 @@ mod tests { .to_batch_interval_start(task.time_precision()) .unwrap(), ); - let transcript = run_vdaf( + let transcript_6 = run_vdaf( &vdaf, verify_key.as_bytes(), &dummy_vdaf::AggregationParam(0), report_metadata_6.id(), - &(), + &measurement, ); let report_share_6 = generate_helper_report_share_for_plaintext( report_metadata_6.clone(), hpke_key.config(), public_share_6.clone(), - &transcript.input_shares[1].get_encoded(), + &transcript_6.helper_input_share.get_encoded(), &InputShareAad::new(*task.id(), report_metadata_6, public_share_6).get_encoded(), ); - // report_share_7 fails due to having repeated extensions. + let prepare_init_6 = PrepareInit::new( + report_share_6, + transcript_6.leader_prepare_transitions[0].message.clone(), + ); + + // prepare_init_7 fails due to having repeated extensions. 
let report_metadata_7 = ReportMetadata::new( random(), clock @@ -1709,56 +1658,40 @@ mod tests { .to_batch_interval_start(task.time_precision()) .unwrap(), ); - let transcript = run_vdaf( + let transcript_7 = run_vdaf( &vdaf, verify_key.as_bytes(), &dummy_vdaf::AggregationParam(0), report_metadata_7.id(), - &(), + &measurement, ); let report_share_7 = generate_helper_report_share::( *task.id(), report_metadata_7, hpke_key.config(), - &transcript.public_share, + &transcript_7.public_share, Vec::from([ Extension::new(ExtensionType::Tbd, Vec::new()), Extension::new(ExtensionType::Tbd, Vec::new()), ]), - &transcript.input_shares[0], + &transcript_7.helper_input_share, ); - // report_share_8 has already been aggregated in another aggregation job, with a different - // aggregation parameter. - let report_metadata_8 = ReportMetadata::new( - random(), - clock - .now() - .to_batch_interval_start(task.time_precision()) - .unwrap(), - ); - let transcript = run_vdaf( - &vdaf, - verify_key.as_bytes(), - &dummy_vdaf::AggregationParam(1), - report_metadata_8.id(), - &(), - ); - let report_share_8 = generate_helper_report_share::( - *task.id(), - report_metadata_8, - hpke_key.config(), - &transcript.public_share, - Vec::new(), - &transcript.input_shares[1], + let prepare_init_7 = PrepareInit::new( + report_share_7, + transcript_7.leader_prepare_transitions[0].message.clone(), ); + // prepare_init_8 has already been aggregated in another aggregation job, with a different + // aggregation parameter. + let (prepare_init_8, transcript_8) = prep_init_generator.next(&measurement); + let (conflicting_aggregation_job, non_conflicting_aggregation_job) = datastore .run_tx(|tx| { let task = task.clone(); - let report_share_4 = report_share_4.clone(); - let report_share_5 = report_share_5.clone(); - let report_share_8 = report_share_8.clone(); + let report_share_4 = prepare_init_4.report_share().clone(); + let report_share_5 = prepare_init_5.report_share().clone(); + let report_share_8 = prepare_init_8.report_share().clone(); Box::pin(async move { tx.put_task(&task).await?; @@ -1856,15 +1789,15 @@ mod tests { dummy_vdaf::AggregationParam(0).get_encoded(), PartialBatchSelector::new_time_interval(), Vec::from([ - report_share_0.clone(), - report_share_1.clone(), - report_share_2.clone(), - report_share_3.clone(), - report_share_4.clone(), - report_share_5.clone(), - report_share_6.clone(), - report_share_7.clone(), - report_share_8.clone(), + prepare_init_0.clone(), + prepare_init_1.clone(), + prepare_init_2.clone(), + prepare_init_3.clone(), + prepare_init_4.clone(), + prepare_init_5.clone(), + prepare_init_6.clone(), + prepare_init_7.clone(), + prepare_init_8.clone(), ]), ); @@ -1881,64 +1814,95 @@ mod tests { let aggregate_resp: AggregationJobResp = decode_response_body(&mut test_conn).await; // Validate response. 
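// PrepareResps are positional: the nth response corresponds to the nth PrepareInit in
// the request.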
- assert_eq!(aggregate_resp.prepare_steps().len(), 9); + assert_eq!(aggregate_resp.prepare_resps().len(), 9); - let prepare_step_0 = aggregate_resp.prepare_steps().get(0).unwrap(); - assert_eq!(prepare_step_0.report_id(), report_share_0.metadata().id()); - assert_matches!(prepare_step_0.result(), &PrepareStepResult::Continued(..)); + let prepare_step_0 = aggregate_resp.prepare_resps().get(0).unwrap(); + assert_eq!( + prepare_step_0.report_id(), + prepare_init_0.report_share().metadata().id() + ); + assert_matches!(prepare_step_0.result(), PrepareStepResult::Continue { message } => { + assert_eq!(message, &transcript_0.helper_prepare_transitions[0].message); + }); - let prepare_step_1 = aggregate_resp.prepare_steps().get(1).unwrap(); - assert_eq!(prepare_step_1.report_id(), report_share_1.metadata().id()); + let prepare_step_1 = aggregate_resp.prepare_resps().get(1).unwrap(); + assert_eq!( + prepare_step_1.report_id(), + prepare_init_1.report_share().metadata().id() + ); assert_matches!( prepare_step_1.result(), - &PrepareStepResult::Failed(ReportShareError::HpkeDecryptError) + &PrepareStepResult::Reject(PrepareError::HpkeDecryptError) ); - let prepare_step_2 = aggregate_resp.prepare_steps().get(2).unwrap(); - assert_eq!(prepare_step_2.report_id(), report_share_2.metadata().id()); + let prepare_step_2 = aggregate_resp.prepare_resps().get(2).unwrap(); + assert_eq!( + prepare_step_2.report_id(), + prepare_init_2.report_share().metadata().id() + ); assert_matches!( prepare_step_2.result(), - &PrepareStepResult::Failed(ReportShareError::UnrecognizedMessage) + &PrepareStepResult::Reject(PrepareError::UnrecognizedMessage) ); - let prepare_step_3 = aggregate_resp.prepare_steps().get(3).unwrap(); - assert_eq!(prepare_step_3.report_id(), report_share_3.metadata().id()); + let prepare_step_3 = aggregate_resp.prepare_resps().get(3).unwrap(); + assert_eq!( + prepare_step_3.report_id(), + prepare_init_3.report_share().metadata().id() + ); assert_matches!( prepare_step_3.result(), - &PrepareStepResult::Failed(ReportShareError::HpkeUnknownConfigId) + &PrepareStepResult::Reject(PrepareError::HpkeUnknownConfigId) ); - let prepare_step_4 = aggregate_resp.prepare_steps().get(4).unwrap(); - assert_eq!(prepare_step_4.report_id(), report_share_4.metadata().id()); + let prepare_step_4 = aggregate_resp.prepare_resps().get(4).unwrap(); + assert_eq!( + prepare_step_4.report_id(), + prepare_init_4.report_share().metadata().id() + ); assert_eq!( prepare_step_4.result(), - &PrepareStepResult::Failed(ReportShareError::ReportReplayed) + &PrepareStepResult::Reject(PrepareError::ReportReplayed) ); - let prepare_step_5 = aggregate_resp.prepare_steps().get(5).unwrap(); - assert_eq!(prepare_step_5.report_id(), report_share_5.metadata().id()); + let prepare_step_5 = aggregate_resp.prepare_resps().get(5).unwrap(); + assert_eq!( + prepare_step_5.report_id(), + prepare_init_5.report_share().metadata().id() + ); assert_eq!( prepare_step_5.result(), - &PrepareStepResult::Failed(ReportShareError::BatchCollected) + &PrepareStepResult::Reject(PrepareError::BatchCollected) ); - let prepare_step_6 = aggregate_resp.prepare_steps().get(6).unwrap(); - assert_eq!(prepare_step_6.report_id(), report_share_6.metadata().id()); + let prepare_step_6 = aggregate_resp.prepare_resps().get(6).unwrap(); + assert_eq!( + prepare_step_6.report_id(), + prepare_init_6.report_share().metadata().id() + ); assert_eq!( prepare_step_6.result(), - &PrepareStepResult::Failed(ReportShareError::UnrecognizedMessage), + 
&PrepareStepResult::Reject(PrepareError::UnrecognizedMessage), ); - let prepare_step_7 = aggregate_resp.prepare_steps().get(7).unwrap(); - assert_eq!(prepare_step_7.report_id(), report_share_7.metadata().id()); + let prepare_step_7 = aggregate_resp.prepare_resps().get(7).unwrap(); + assert_eq!( + prepare_step_7.report_id(), + prepare_init_7.report_share().metadata().id() + ); assert_eq!( prepare_step_7.result(), - &PrepareStepResult::Failed(ReportShareError::UnrecognizedMessage), + &PrepareStepResult::Reject(PrepareError::UnrecognizedMessage), ); - let prepare_step_8 = aggregate_resp.prepare_steps().get(8).unwrap(); - assert_eq!(prepare_step_8.report_id(), report_share_8.metadata().id()); - assert_matches!(prepare_step_8.result(), &PrepareStepResult::Continued(..)); + let prepare_step_8 = aggregate_resp.prepare_resps().get(8).unwrap(); + assert_eq!( + prepare_step_8.report_id(), + prepare_init_8.report_share().metadata().id() + ); + assert_matches!(prepare_step_8.result(), PrepareStepResult::Continue { message } => { + assert_eq!(message, &transcript_8.helper_prepare_transitions[0].message); + }); // Check aggregation job in datastore. let aggregation_jobs = datastore @@ -1968,7 +1932,7 @@ mod tests { } else if aggregation_job.task_id().eq(task.id()) && aggregation_job.id().eq(&aggregation_job_id) && aggregation_job.partial_batch_identifier().eq(&()) - && aggregation_job.state().eq(&AggregationJobState::InProgress) + && aggregation_job.state().eq(&AggregationJobState::Finished) { saw_new_aggregation_job = true; } @@ -1988,6 +1952,10 @@ mod tests { let task = TaskBuilder::new(QueryType::TimeInterval, VdafInstance::Fake, Role::Helper).build(); datastore.put_task(&task).await.unwrap(); + let vdaf = dummy_vdaf::Vdaf::new(); + let aggregation_param = dummy_vdaf::AggregationParam(0); + let prep_init_generator = + PrepareInitGenerator::new(clock.clone(), task.clone(), vdaf.clone(), aggregation_param); // Insert some global HPKE keys. // Same ID as the task to test having both keys to choose from. @@ -2028,59 +1996,20 @@ mod tests { .await .unwrap(); - let vdaf = dummy_vdaf::Vdaf::new(); let verify_key: VerifyKey<0> = task.primary_vdaf_verify_key().unwrap(); // This report was encrypted with a global HPKE config that has the same config // ID as the task's HPKE config. - let report_metadata_same_id = ReportMetadata::new( - random(), - clock - .now() - .to_batch_interval_start(task.time_precision()) - .unwrap(), - ); - let transcript = run_vdaf( - &vdaf, - verify_key.as_bytes(), - &dummy_vdaf::AggregationParam(0), - report_metadata_same_id.id(), - &(), - ); - let report_share_same_id = generate_helper_report_share::( - *task.id(), - report_metadata_same_id, - global_hpke_keypair_same_id.config(), - &transcript.public_share, - Vec::new(), - &transcript.input_shares[1], - ); + let (prepare_init_same_id, transcript_same_id) = prep_init_generator.next(&()); // This report was encrypted with a global HPKE config that has the same config // ID as the task's HPKE config, but will fail to decrypt. 
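// The config lookup itself succeeds, so the expected rejection is HpkeDecryptError
// rather than HpkeUnknownConfigId.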
- let report_metadata_same_id_corrupted = ReportMetadata::new( - random(), - clock - .now() - .to_batch_interval_start(task.time_precision()) - .unwrap(), - ); - let transcript = run_vdaf( - &vdaf, - verify_key.as_bytes(), - &dummy_vdaf::AggregationParam(0), - report_metadata_same_id_corrupted.id(), - &(), - ); - let report_share_same_id_corrupted = generate_helper_report_share::( - *task.id(), - report_metadata_same_id_corrupted.clone(), - global_hpke_keypair_same_id.config(), - &transcript.public_share, - Vec::new(), - &transcript.input_shares[1], - ); - let encrypted_input_share = report_share_same_id_corrupted.encrypted_input_share(); + let (prepare_init_same_id_corrupted, transcript_same_id_corrupted) = + prep_init_generator.next(&()); + + let encrypted_input_share = prepare_init_same_id_corrupted + .report_share() + .encrypted_input_share(); let mut corrupted_payload = encrypted_input_share.payload().to_vec(); corrupted_payload[0] ^= 0xFF; let corrupted_input_share = HpkeCiphertext::new( @@ -2088,11 +2017,17 @@ mod tests { encrypted_input_share.encapsulated_key().to_vec(), corrupted_payload, ); - let encoded_public_share = transcript.public_share.get_encoded(); - let report_share_same_id_corrupted = ReportShare::new( - report_metadata_same_id_corrupted, - encoded_public_share.clone(), - corrupted_input_share, + + let prepare_init_same_id_corrupted = PrepareInit::new( + ReportShare::new( + prepare_init_same_id_corrupted + .report_share() + .metadata() + .clone(), + transcript_same_id_corrupted.public_share.get_encoded(), + corrupted_input_share, + ), + prepare_init_same_id_corrupted.message().clone(), ); // This report was encrypted with a global HPKE config that doesn't collide @@ -2104,7 +2039,7 @@ mod tests { .to_batch_interval_start(task.time_precision()) .unwrap(), ); - let transcript = run_vdaf( + let transcript_different_id = run_vdaf( &vdaf, verify_key.as_bytes(), &dummy_vdaf::AggregationParam(0), @@ -2115,9 +2050,16 @@ mod tests { *task.id(), report_metadata_different_id, global_hpke_keypair_different_id.config(), - &transcript.public_share, + &transcript_different_id.public_share, Vec::new(), - &transcript.input_shares[1], + &transcript_different_id.helper_input_share, + ); + + let prepare_init_different_id = PrepareInit::new( + report_share_different_id, + transcript_different_id.leader_prepare_transitions[0] + .message + .clone(), ); // This report was encrypted with a global HPKE config that doesn't collide @@ -2129,7 +2071,7 @@ mod tests { .to_batch_interval_start(task.time_precision()) .unwrap(), ); - let transcript = run_vdaf( + let transcript_different_id_corrupted = run_vdaf( &vdaf, verify_key.as_bytes(), &dummy_vdaf::AggregationParam(0), @@ -2140,9 +2082,9 @@ mod tests { *task.id(), report_metadata_different_id_corrupted.clone(), global_hpke_keypair_different_id.config(), - &transcript.public_share, + &transcript_different_id_corrupted.public_share, Vec::new(), - &transcript.input_shares[1], + &transcript_different_id_corrupted.helper_input_share, ); let encrypted_input_share = report_share_different_id_corrupted.encrypted_input_share(); let mut corrupted_payload = encrypted_input_share.payload().to_vec(); @@ -2152,11 +2094,17 @@ mod tests { encrypted_input_share.encapsulated_key().to_vec(), corrupted_payload, ); - let encoded_public_share = transcript.public_share.get_encoded(); - let report_share_different_id_corrupted = ReportShare::new( - report_metadata_different_id_corrupted, - encoded_public_share.clone(), - corrupted_input_share, + let encoded_public_share = 
transcript_different_id_corrupted.public_share.get_encoded(); + + let prepare_init_different_id_corrupted = PrepareInit::new( + ReportShare::new( + report_metadata_different_id_corrupted, + encoded_public_share.clone(), + corrupted_input_share, + ), + transcript_different_id_corrupted.leader_prepare_transitions[0] + .message + .clone(), ); let aggregation_job_id: AggregationJobId = random(); @@ -2164,10 +2112,10 @@ mod tests { dummy_vdaf::AggregationParam(0).get_encoded(), PartialBatchSelector::new_time_interval(), Vec::from([ - report_share_same_id.clone(), - report_share_different_id.clone(), - report_share_same_id_corrupted.clone(), - report_share_different_id_corrupted.clone(), + prepare_init_same_id.clone(), + prepare_init_different_id.clone(), + prepare_init_same_id_corrupted.clone(), + prepare_init_different_id_corrupted.clone(), ]), ); @@ -2177,46 +2125,53 @@ mod tests { let aggregate_resp: AggregationJobResp = decode_response_body(&mut test_conn).await; // Validate response. - assert_eq!(aggregate_resp.prepare_steps().len(), 4); + assert_eq!(aggregate_resp.prepare_resps().len(), 4); - let prepare_step_same_id = aggregate_resp.prepare_steps().get(0).unwrap(); + let prepare_step_same_id = aggregate_resp.prepare_resps().get(0).unwrap(); assert_eq!( prepare_step_same_id.report_id(), - report_share_same_id.metadata().id() - ); - assert_matches!( - prepare_step_same_id.result(), - &PrepareStepResult::Continued(..) + prepare_init_same_id.report_share().metadata().id() ); + assert_matches!(prepare_step_same_id.result(), PrepareStepResult::Continue { message } => { + assert_eq!(message, &transcript_same_id.helper_prepare_transitions[0].message); + }); - let prepare_step_different_id = aggregate_resp.prepare_steps().get(1).unwrap(); + let prepare_step_different_id = aggregate_resp.prepare_resps().get(1).unwrap(); assert_eq!( prepare_step_different_id.report_id(), - report_share_different_id.metadata().id() + prepare_init_different_id.report_share().metadata().id() ); assert_matches!( prepare_step_different_id.result(), - &PrepareStepResult::Continued(..) + PrepareStepResult::Continue { message } => { + assert_eq!(message, &transcript_different_id.helper_prepare_transitions[0].message); + } ); - let prepare_step_same_id_corrupted = aggregate_resp.prepare_steps().get(2).unwrap(); + let prepare_step_same_id_corrupted = aggregate_resp.prepare_resps().get(2).unwrap(); assert_eq!( prepare_step_same_id_corrupted.report_id(), - report_share_same_id_corrupted.metadata().id() + prepare_init_same_id_corrupted + .report_share() + .metadata() + .id(), ); assert_matches!( prepare_step_same_id_corrupted.result(), - &PrepareStepResult::Failed(ReportShareError::HpkeDecryptError) + &PrepareStepResult::Reject(PrepareError::HpkeDecryptError) ); - let prepare_step_different_id_corrupted = aggregate_resp.prepare_steps().get(3).unwrap(); + let prepare_step_different_id_corrupted = aggregate_resp.prepare_resps().get(3).unwrap(); assert_eq!( prepare_step_different_id_corrupted.report_id(), - report_share_different_id_corrupted.metadata().id() + prepare_init_different_id_corrupted + .report_share() + .metadata() + .id() ); assert_matches!( prepare_step_different_id_corrupted.result(), - &PrepareStepResult::Failed(ReportShareError::HpkeDecryptError) + &PrepareStepResult::Reject(PrepareError::HpkeDecryptError) ); } @@ -2230,24 +2185,23 @@ mod tests { // This report has the same ID as the previous one, but a different timestamp. 
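// Report IDs are scoped to the task, so reusing an ID under a mutated timestamp must
// surface as ReportReplayed rather than being accepted as a fresh report.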
let mutated_timestamp_report_metadata = ReportMetadata::new( - *test_case.report_shares[0].metadata().id(), + *test_case.prepare_inits[0].report_share().metadata().id(), test_case .clock .now() .add(test_case.task.time_precision()) .unwrap(), ); - let mutated_timestamp_report_share = test_case - .report_share_generator - .next_with_metadata(mutated_timestamp_report_metadata) - .0; + let (mutated_timestamp_prepare_init, _) = test_case + .prepare_init_generator + .next_with_metadata(mutated_timestamp_report_metadata, &()); // Send another aggregate job re-using the same report ID but with a different timestamp. It // should be flagged as a replay. let request = AggregationJobInitializeReq::new( other_aggregation_parameter.get_encoded(), PartialBatchSelector::new_time_interval(), - Vec::from([mutated_timestamp_report_share.clone()]), + Vec::from([mutated_timestamp_prepare_init.clone()]), ); let mut test_conn = @@ -2255,16 +2209,19 @@ mod tests { assert_eq!(test_conn.status(), Some(Status::Ok)); let aggregate_resp: AggregationJobResp = decode_response_body(&mut test_conn).await; - assert_eq!(aggregate_resp.prepare_steps().len(), 1); + assert_eq!(aggregate_resp.prepare_resps().len(), 1); - let prepare_step = aggregate_resp.prepare_steps().get(0).unwrap(); + let prepare_step = aggregate_resp.prepare_resps().get(0).unwrap(); assert_eq!( prepare_step.report_id(), - mutated_timestamp_report_share.metadata().id() + mutated_timestamp_prepare_init + .report_share() + .metadata() + .id(), ); assert_matches!( prepare_step.result(), - &PrepareStepResult::Failed(ReportShareError::ReportReplayed) + &PrepareStepResult::Reject(PrepareError::ReportReplayed) ); // The attempt to mutate the report share timestamp should not cause any change in the @@ -2282,8 +2239,14 @@ mod tests { .await .unwrap(); assert_eq!(client_reports.len(), 2); - assert_eq!(&client_reports[0], test_case.report_shares[0].metadata()); - assert_eq!(&client_reports[1], test_case.report_shares[1].metadata()); + assert_eq!( + &client_reports[0], + test_case.prepare_inits[0].report_share().metadata() + ); + assert_eq!( + &client_reports[1], + test_case.prepare_inits[1].report_share().metadata() + ); } #[tokio::test] @@ -2296,27 +2259,20 @@ mod tests { Role::Helper, ) .build(); + let prep_init_generator = PrepareInitGenerator::new( + clock.clone(), + task.clone(), + dummy_vdaf::Vdaf::new(), + dummy_vdaf::AggregationParam(0), + ); + datastore.put_task(&task).await.unwrap(); - let hpke_key = task.current_hpke_key(); - let report_share = generate_helper_report_share::( - *task.id(), - ReportMetadata::new( - random(), - clock - .now() - .to_batch_interval_start(task.time_precision()) - .unwrap(), - ), - hpke_key.config(), - &(), - Vec::new(), - &dummy_vdaf::InputShare::default(), - ); + let (prepare_init, _) = prep_init_generator.next(&()); let request = AggregationJobInitializeReq::new( dummy_vdaf::AggregationParam(0).get_encoded(), PartialBatchSelector::new_time_interval(), - Vec::from([report_share.clone()]), + Vec::from([prepare_init.clone()]), ); // Send request, and parse response. @@ -2331,13 +2287,16 @@ mod tests { let aggregate_resp: AggregationJobResp = decode_response_body(&mut test_conn).await; // Validate response. 
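// A VDAF-level preparation failure is reported per-report inside an otherwise
// successful AggregationJobResp, not as an HTTP-level error.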
- assert_eq!(aggregate_resp.prepare_steps().len(), 1); + assert_eq!(aggregate_resp.prepare_resps().len(), 1); - let prepare_step = aggregate_resp.prepare_steps().get(0).unwrap(); - assert_eq!(prepare_step.report_id(), report_share.metadata().id()); + let prepare_step = aggregate_resp.prepare_resps().get(0).unwrap(); + assert_eq!( + prepare_step.report_id(), + prepare_init.report_share().metadata().id() + ); assert_matches!( prepare_step.result(), - &PrepareStepResult::Failed(ReportShareError::VdafPrepError) + &PrepareStepResult::Reject(PrepareError::VdafPrepError) ); } @@ -2347,31 +2306,24 @@ mod tests { let task = TaskBuilder::new( QueryType::TimeInterval, - VdafInstance::FakeFailsPrepInit, + VdafInstance::FakeFailsPrepStep, Role::Helper, ) .build(); - let hpke_key = task.current_hpke_key(); + let prep_init_generator = PrepareInitGenerator::new( + clock.clone(), + task.clone(), + dummy_vdaf::Vdaf::new(), + dummy_vdaf::AggregationParam(0), + ); + datastore.put_task(&task).await.unwrap(); - let report_share = generate_helper_report_share::( - *task.id(), - ReportMetadata::new( - random(), - clock - .now() - .to_batch_interval_start(task.time_precision()) - .unwrap(), - ), - hpke_key.config(), - &(), - Vec::new(), - &dummy_vdaf::InputShare::default(), - ); + let (prepare_init, _) = prep_init_generator.next(&()); let request = AggregationJobInitializeReq::new( dummy_vdaf::AggregationParam(0).get_encoded(), PartialBatchSelector::new_time_interval(), - Vec::from([report_share.clone()]), + Vec::from([prepare_init.clone()]), ); let aggregation_job_id: AggregationJobId = random(); @@ -2385,46 +2337,40 @@ mod tests { let aggregate_resp: AggregationJobResp = decode_response_body(&mut test_conn).await; // Validate response. - assert_eq!(aggregate_resp.prepare_steps().len(), 1); + assert_eq!(aggregate_resp.prepare_resps().len(), 1); - let prepare_step = aggregate_resp.prepare_steps().get(0).unwrap(); - assert_eq!(prepare_step.report_id(), report_share.metadata().id()); + let prepare_step = aggregate_resp.prepare_resps().get(0).unwrap(); + assert_eq!( + prepare_step.report_id(), + prepare_init.report_share().metadata().id() + ); assert_matches!( prepare_step.result(), - &PrepareStepResult::Failed(ReportShareError::VdafPrepError) + &PrepareStepResult::Reject(PrepareError::VdafPrepError) ); } #[tokio::test] async fn aggregate_init_duplicated_report_id() { - let (_, _ephemeral_datastore, datastore, handler) = setup_http_handler_test().await; - - let task = TaskBuilder::new( - QueryType::TimeInterval, - VdafInstance::FakeFailsPrepInit, - Role::Helper, - ) - .build(); - datastore.put_task(&task).await.unwrap(); + let (clock, _ephemeral_datastore, datastore, handler) = setup_http_handler_test().await; - let report_share = ReportShare::new( - ReportMetadata::new( - ReportId::from([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]), - Time::from_seconds_since_epoch(54321), - ), - Vec::from("PUBLIC"), - HpkeCiphertext::new( - // bogus, but we never get far enough to notice - HpkeConfigId::from(42), - Vec::from("012345"), - Vec::from("543210"), - ), + let task = + TaskBuilder::new(QueryType::TimeInterval, VdafInstance::Fake, Role::Helper).build(); + let prep_init_generator = PrepareInitGenerator::new( + clock.clone(), + task.clone(), + dummy_vdaf::Vdaf::new(), + dummy_vdaf::AggregationParam(0), ); + datastore.put_task(&task).await.unwrap(); + + let (prepare_init, _) = prep_init_generator.next(&()); + let request = AggregationJobInitializeReq::new( dummy_vdaf::AggregationParam(0).get_encoded(), 
PartialBatchSelector::new_time_interval(), - Vec::from([report_share.clone(), report_share]), + Vec::from([prepare_init.clone(), prepare_init]), ); let aggregation_job_id: AggregationJobId = random(); @@ -2451,14 +2397,17 @@ mod tests { let aggregation_job_id = random(); let task = TaskBuilder::new( QueryType::TimeInterval, - VdafInstance::Prio3Count, + VdafInstance::Poplar1 { bits: 1 }, Role::Helper, ) .build(); - let vdaf = Arc::new(Prio3::new_count(2).unwrap()); + let vdaf = Arc::new(Poplar1::::new(1)); let verify_key: VerifyKey = task.primary_vdaf_verify_key().unwrap(); let hpke_key = task.current_hpke_key(); + let measurement = IdpfInput::from_bools(&[true]); + let aggregation_param = + Poplar1AggregationParam::try_from_prefixes(vec![measurement.clone()]).unwrap(); // report_share_0 is a "happy path" report. let report_metadata_0 = ReportMetadata::new( @@ -2471,19 +2420,19 @@ mod tests { let transcript_0 = run_vdaf( vdaf.as_ref(), verify_key.as_bytes(), - &(), + &aggregation_param, report_metadata_0.id(), - &0, + &measurement, ); - let (prep_state_0, _) = transcript_0.helper_prep_state(0); - let prep_msg_0 = transcript_0.prepare_messages[0].clone(); - let report_share_0 = generate_helper_report_share::( + let helper_prep_state_0 = transcript_0.helper_prepare_transitions[0].prepare_state(); + let leader_prep_message_0 = &transcript_0.leader_prepare_transitions[1].message; + let report_share_0 = generate_helper_report_share::>( *task.id(), report_metadata_0.clone(), hpke_key.config(), &transcript_0.public_share, Vec::new(), - &transcript_0.input_shares[1], + &transcript_0.helper_input_share, ); // report_share_1 is omitted by the leader's request. @@ -2497,19 +2446,19 @@ mod tests { let transcript_1 = run_vdaf( vdaf.as_ref(), verify_key.as_bytes(), - &(), + &aggregation_param, report_metadata_1.id(), - &0, + &measurement, ); - let (prep_state_1, _) = transcript_1.helper_prep_state(0); - let report_share_1 = generate_helper_report_share::( + let helper_prep_state_1 = transcript_1.helper_prepare_transitions[0].prepare_state(); + let report_share_1 = generate_helper_report_share::>( *task.id(), report_metadata_1.clone(), hpke_key.config(), &transcript_1.public_share, Vec::new(), - &transcript_1.input_shares[1], + &transcript_1.helper_input_share, ); // report_share_2 falls into a batch that has already been collected. 
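// Continuing it must be rejected with BatchCollected so that already-collected
// aggregates remain immutable.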
@@ -2526,19 +2475,19 @@ mod tests { let transcript_2 = run_vdaf( vdaf.as_ref(), verify_key.as_bytes(), - &(), + &aggregation_param, report_metadata_2.id(), - &0, + &measurement, ); - let (prep_state_2, _) = transcript_2.helper_prep_state(0); - let prep_msg_2 = transcript_2.prepare_messages[0].clone(); - let report_share_2 = generate_helper_report_share::( + let helper_prep_state_2 = transcript_2.helper_prepare_transitions[0].prepare_state(); + let leader_prep_message_2 = &transcript_2.leader_prepare_transitions[1].message; + let report_share_2 = generate_helper_report_share::>( *task.id(), report_metadata_2.clone(), hpke_key.config(), &transcript_2.public_share, Vec::new(), - &transcript_2.input_shares[1], + &transcript_2.helper_input_share, ); datastore @@ -2549,16 +2498,18 @@ mod tests { report_share_1.clone(), report_share_2.clone(), ); - let (prep_state_0, prep_state_1, prep_state_2) = ( - prep_state_0.clone(), - prep_state_1.clone(), - prep_state_2.clone(), + let (helper_prep_state_0, helper_prep_state_1, helper_prep_state_2) = ( + helper_prep_state_0.clone(), + helper_prep_state_1.clone(), + helper_prep_state_2.clone(), ); let (report_metadata_0, report_metadata_1, report_metadata_2) = ( report_metadata_0.clone(), report_metadata_1.clone(), report_metadata_2.clone(), ); + let aggregation_param = aggregation_param.clone(); + let helper_aggregate_share = transcript_0.helper_aggregate_share.clone(); Box::pin(async move { tx.put_task(&task).await?; @@ -2570,11 +2521,11 @@ mod tests { tx.put_aggregation_job(&AggregationJob::< VERIFY_KEY_LENGTH, TimeInterval, - Prio3Count, + Poplar1, >::new( *task.id(), aggregation_job_id, - (), + aggregation_param.clone(), (), Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) .unwrap(), @@ -2583,7 +2534,7 @@ mod tests { )) .await?; - tx.put_report_aggregation::( + tx.put_report_aggregation::>( &ReportAggregation::new( *task.id(), aggregation_job_id, @@ -2591,11 +2542,11 @@ mod tests { *report_metadata_0.time(), 0, None, - ReportAggregationState::Waiting(prep_state_0, None), + ReportAggregationState::WaitingHelper(helper_prep_state_0), ), ) .await?; - tx.put_report_aggregation::( + tx.put_report_aggregation::>( &ReportAggregation::new( *task.id(), aggregation_job_id, @@ -2603,11 +2554,11 @@ mod tests { *report_metadata_1.time(), 1, None, - ReportAggregationState::Waiting(prep_state_1, None), + ReportAggregationState::WaitingHelper(helper_prep_state_1), ), ) .await?; - tx.put_report_aggregation::( + tx.put_report_aggregation::>( &ReportAggregation::new( *task.id(), aggregation_job_id, @@ -2615,12 +2566,12 @@ mod tests { *report_metadata_2.time(), 2, None, - ReportAggregationState::Waiting(prep_state_2, None), + ReportAggregationState::WaitingHelper(helper_prep_state_2), ), ) .await?; - tx.put_aggregate_share_job::( + tx.put_aggregate_share_job::>( &AggregateShareJob::new( *task.id(), Interval::new( @@ -2628,8 +2579,8 @@ mod tests { *task.time_precision(), ) .unwrap(), - (), - AggregateShare::from(OutputShare::from(Vec::from([Field64::from(7)]))), + aggregation_param.clone(), + helper_aggregate_share, 0, ReportIdChecksum::default(), ), @@ -2643,14 +2594,8 @@ mod tests { let request = AggregationJobContinueReq::new( AggregationJobRound::from(1), Vec::from([ - PrepareStep::new( - *report_metadata_0.id(), - PrepareStepResult::Continued(prep_msg_0.get_encoded()), - ), - PrepareStep::new( - *report_metadata_2.id(), - PrepareStepResult::Continued(prep_msg_2.get_encoded()), - ), + PrepareContinue::new(*report_metadata_0.id(), 
leader_prep_message_0.clone()), + PrepareContinue::new(*report_metadata_2.id(), leader_prep_message_2.clone()), ]), ); @@ -2662,10 +2607,10 @@ mod tests { assert_eq!( aggregate_resp, AggregationJobResp::new(Vec::from([ - PrepareStep::new(*report_metadata_0.id(), PrepareStepResult::Finished), - PrepareStep::new( + PrepareResp::new(*report_metadata_0.id(), PrepareStepResult::Finished), + PrepareResp::new( *report_metadata_2.id(), - PrepareStepResult::Failed(ReportShareError::BatchCollected), + PrepareStepResult::Reject(PrepareError::BatchCollected), ) ])) ); @@ -2676,7 +2621,7 @@ mod tests { let (vdaf, task) = (Arc::clone(&vdaf), task.clone()); Box::pin(async move { let aggregation_job = tx - .get_aggregation_job::( + .get_aggregation_job::>( task.id(), &aggregation_job_id, ) @@ -2703,7 +2648,7 @@ mod tests { AggregationJob::new( *task.id(), aggregation_job_id, - (), + aggregation_param, (), Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) .unwrap(), @@ -2721,7 +2666,7 @@ mod tests { *report_metadata_0.id(), *report_metadata_0.time(), 0, - Some(PrepareStep::new( + Some(PrepareResp::new( *report_metadata_0.id(), PrepareStepResult::Finished )), @@ -2734,7 +2679,7 @@ mod tests { *report_metadata_1.time(), 1, None, - ReportAggregationState::Failed(ReportShareError::ReportDropped), + ReportAggregationState::Failed(PrepareError::ReportDropped), ), ReportAggregation::new( *task.id(), @@ -2742,11 +2687,11 @@ mod tests { *report_metadata_2.id(), *report_metadata_2.time(), 2, - Some(PrepareStep::new( + Some(PrepareResp::new( *report_metadata_2.id(), - PrepareStepResult::Failed(ReportShareError::BatchCollected) + PrepareStepResult::Reject(PrepareError::BatchCollected) )), - ReportAggregationState::Failed(ReportShareError::BatchCollected), + ReportAggregationState::Failed(PrepareError::BatchCollected), ) ]) ); @@ -2758,7 +2703,7 @@ mod tests { let task = TaskBuilder::new( QueryType::TimeInterval, - VdafInstance::Prio3Count, + VdafInstance::Poplar1 { bits: 1 }, Role::Helper, ) .build(); @@ -2772,9 +2717,12 @@ mod tests { .unwrap(), ); - let vdaf = Prio3::new_count(2).unwrap(); + let vdaf = Poplar1::new(1); let verify_key: VerifyKey = task.primary_vdaf_verify_key().unwrap(); let hpke_key = task.current_hpke_key(); + let measurement = IdpfInput::from_bools(&[true]); + let aggregation_param = + Poplar1AggregationParam::try_from_prefixes(vec![measurement.clone()]).unwrap(); // report_share_0 is a "happy path" report. 
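// (Throughout this test, the helper stores helper_prepare_transitions[0].prepare_state()
// while the leader's PrepareContinue carries leader_prepare_transitions[1].message, so a
// single continue step finishes preparation.)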
let report_metadata_0 = ReportMetadata::new( @@ -2787,20 +2735,19 @@ mod tests { let transcript_0 = run_vdaf( &vdaf, verify_key.as_bytes(), - &(), + &aggregation_param, report_metadata_0.id(), - &0, + &measurement, ); - let (prep_state_0, _) = transcript_0.helper_prep_state(0); - let out_share_0 = transcript_0.output_share(Role::Helper); - let prep_msg_0 = transcript_0.prepare_messages[0].clone(); - let report_share_0 = generate_helper_report_share::( + let helper_prep_state_0 = transcript_0.helper_prepare_transitions[0].prepare_state(); + let ping_pong_leader_message_0 = &transcript_0.leader_prepare_transitions[1].message; + let report_share_0 = generate_helper_report_share::>( *task.id(), report_metadata_0.clone(), hpke_key.config(), &transcript_0.public_share, Vec::new(), - &transcript_0.input_shares[1], + &transcript_0.helper_input_share, ); // report_share_1 is another "happy path" report to exercise in-memory accumulation of @@ -2815,20 +2762,19 @@ mod tests { let transcript_1 = run_vdaf( &vdaf, verify_key.as_bytes(), - &(), + &aggregation_param, report_metadata_1.id(), - &0, + &measurement, ); - let (prep_state_1, _) = transcript_1.helper_prep_state(0); - let out_share_1 = transcript_1.output_share(Role::Helper); - let prep_msg_1 = transcript_1.prepare_messages[0].clone(); - let report_share_1 = generate_helper_report_share::( + let helper_prep_state_1 = transcript_1.helper_prepare_transitions[0].prepare_state(); + let ping_pong_leader_message_1 = &transcript_1.leader_prepare_transitions[1].message; + let report_share_1 = generate_helper_report_share::>( *task.id(), report_metadata_1.clone(), hpke_key.config(), &transcript_1.public_share, Vec::new(), - &transcript_1.input_shares[1], + &transcript_1.helper_input_share, ); // report_share_2 aggregates successfully, but into a distinct batch aggregation which has @@ -2843,19 +2789,19 @@ mod tests { let transcript_2 = run_vdaf( &vdaf, verify_key.as_bytes(), - &(), + &aggregation_param, report_metadata_2.id(), - &0, + &measurement, ); - let (prep_state_2, _) = transcript_2.helper_prep_state(0); - let prep_msg_2 = transcript_2.prepare_messages[0].clone(); - let report_share_2 = generate_helper_report_share::( + let helper_prep_state_2 = transcript_2.helper_prepare_transitions[0].prepare_state(); + let ping_pong_leader_message_2 = &transcript_2.leader_prepare_transitions[1].message; + let report_share_2 = generate_helper_report_share::>( *task.id(), report_metadata_2.clone(), hpke_key.config(), &transcript_2.public_share, Vec::new(), - &transcript_2.input_shares[1], + &transcript_2.helper_input_share, ); let first_batch_identifier = Interval::new( @@ -2875,11 +2821,11 @@ mod tests { ) .unwrap(); let second_batch_want_batch_aggregations = - empty_batch_aggregations::( + empty_batch_aggregations::>( &task, BATCH_AGGREGATION_SHARD_COUNT, &second_batch_identifier, - &(), + &aggregation_param, &[], ); @@ -2891,16 +2837,17 @@ mod tests { report_share_1.clone(), report_share_2.clone(), ); - let (prep_state_0, prep_state_1, prep_state_2) = ( - prep_state_0.clone(), - prep_state_1.clone(), - prep_state_2.clone(), + let (helper_prep_state_0, helper_prep_state_1, helper_prep_state_2) = ( + helper_prep_state_0.clone(), + helper_prep_state_1.clone(), + helper_prep_state_2.clone(), ); let (report_metadata_0, report_metadata_1, report_metadata_2) = ( report_metadata_0.clone(), report_metadata_1.clone(), report_metadata_2.clone(), ); + let aggregation_param = aggregation_param.clone(); let second_batch_want_batch_aggregations = 
second_batch_want_batch_aggregations.clone(); @@ -2914,11 +2861,11 @@ mod tests { tx.put_aggregation_job(&AggregationJob::< VERIFY_KEY_LENGTH, TimeInterval, - Prio3Count, + Poplar1, >::new( *task.id(), aggregation_job_id_0, - (), + aggregation_param.clone(), (), Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) .unwrap(), @@ -2927,48 +2874,55 @@ mod tests { )) .await?; - tx.put_report_aggregation( - &ReportAggregation::::new( - *task.id(), - aggregation_job_id_0, - *report_metadata_0.id(), - *report_metadata_0.time(), - 0, - None, - ReportAggregationState::Waiting(prep_state_0, None), - ), - ) + tx.put_report_aggregation(&ReportAggregation::< + VERIFY_KEY_LENGTH, + Poplar1, + >::new( + *task.id(), + aggregation_job_id_0, + *report_metadata_0.id(), + *report_metadata_0.time(), + 0, + None, + ReportAggregationState::WaitingHelper(helper_prep_state_0), + )) .await?; - tx.put_report_aggregation( - &ReportAggregation::::new( - *task.id(), - aggregation_job_id_0, - *report_metadata_1.id(), - *report_metadata_1.time(), - 1, - None, - ReportAggregationState::Waiting(prep_state_1, None), - ), - ) + tx.put_report_aggregation(&ReportAggregation::< + VERIFY_KEY_LENGTH, + Poplar1, + >::new( + *task.id(), + aggregation_job_id_0, + *report_metadata_1.id(), + *report_metadata_1.time(), + 1, + None, + ReportAggregationState::WaitingHelper(helper_prep_state_1), + )) .await?; - tx.put_report_aggregation( - &ReportAggregation::::new( - *task.id(), - aggregation_job_id_0, - *report_metadata_2.id(), - *report_metadata_2.time(), - 2, - None, - ReportAggregationState::Waiting(prep_state_2, None), - ), - ) + tx.put_report_aggregation(&ReportAggregation::< + VERIFY_KEY_LENGTH, + Poplar1, + >::new( + *task.id(), + aggregation_job_id_0, + *report_metadata_2.id(), + *report_metadata_2.time(), + 2, + None, + ReportAggregationState::WaitingHelper(helper_prep_state_2), + )) .await?; for batch_identifier in [first_batch_identifier, second_batch_identifier] { - tx.put_batch(&Batch::::new( + tx.put_batch(&Batch::< + VERIFY_KEY_LENGTH, + TimeInterval, + Poplar1, + >::new( *task.id(), batch_identifier, - (), + aggregation_param.clone(), BatchState::Closed, 0, batch_identifier, @@ -2994,18 +2948,9 @@ mod tests { let request = AggregationJobContinueReq::new( AggregationJobRound::from(1), Vec::from([ - PrepareStep::new( - *report_metadata_0.id(), - PrepareStepResult::Continued(prep_msg_0.get_encoded()), - ), - PrepareStep::new( - *report_metadata_1.id(), - PrepareStepResult::Continued(prep_msg_1.get_encoded()), - ), - PrepareStep::new( - *report_metadata_2.id(), - PrepareStepResult::Continued(prep_msg_2.get_encoded()), - ), + PrepareContinue::new(*report_metadata_0.id(), ping_pong_leader_message_0.clone()), + PrepareContinue::new(*report_metadata_1.id(), ping_pong_leader_message_1.clone()), + PrepareContinue::new(*report_metadata_2.id(), ping_pong_leader_message_2.clone()), ]), ); @@ -3016,12 +2961,16 @@ mod tests { // Map the batch aggregation ordinal value to 0, as it may vary due to sharding. 
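        // (Batch aggregations are sharded across BATCH_AGGREGATION_SHARD_COUNT rows, and the
        // shard a particular write lands in is not deterministic, so the ordinal is normalized
        // to 0 before comparing against the expected batch aggregations below.)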
let first_batch_got_batch_aggregations: Vec<_> = datastore .run_tx(|tx| { - let (task, vdaf, report_metadata_0) = - (task.clone(), vdaf.clone(), report_metadata_0.clone()); + let (task, vdaf, report_metadata_0, aggregation_param) = ( + task.clone(), + vdaf.clone(), + report_metadata_0.clone(), + aggregation_param.clone(), + ); Box::pin(async move { TimeInterval::get_batch_aggregations_for_collection_identifier::< VERIFY_KEY_LENGTH, - Prio3Count, + Poplar1, _, >( tx, @@ -3035,7 +2984,7 @@ mod tests { *task.time_precision(), ) .unwrap(), - &(), + &aggregation_param, ) .await }) @@ -3044,10 +2993,10 @@ mod tests { .unwrap() .into_iter() .map(|agg| { - BatchAggregation::::new( + BatchAggregation::>::new( *agg.task_id(), *agg.batch_identifier(), - (), + agg.aggregation_parameter().clone(), 0, BatchAggregationState::Aggregating, agg.aggregate_share().cloned(), @@ -3059,7 +3008,13 @@ mod tests { .collect(); let aggregate_share = vdaf - .aggregate(&(), [out_share_0.clone(), out_share_1.clone()]) + .aggregate( + &aggregation_param, + [ + transcript_0.helper_output_share.clone(), + transcript_1.helper_output_share.clone(), + ], + ) .unwrap(); let checksum = ReportIdChecksum::for_report_id(report_metadata_0.id()) .updated_with(report_metadata_1.id()); @@ -3076,7 +3031,7 @@ mod tests { *task.time_precision() ) .unwrap(), - (), + aggregation_param.clone(), 0, BatchAggregationState::Aggregating, Some(aggregate_share), @@ -3088,12 +3043,16 @@ mod tests { let second_batch_got_batch_aggregations = datastore .run_tx(|tx| { - let (task, vdaf, report_metadata_2) = - (task.clone(), vdaf.clone(), report_metadata_2.clone()); + let (task, vdaf, report_metadata_2, aggregation_param) = ( + task.clone(), + vdaf.clone(), + report_metadata_2.clone(), + aggregation_param.clone(), + ); Box::pin(async move { TimeInterval::get_batch_aggregations_for_collection_identifier::< VERIFY_KEY_LENGTH, - Prio3Count, + Poplar1, _, >( tx, @@ -3107,7 +3066,7 @@ mod tests { Duration::from_seconds(task.time_precision().as_seconds()), ) .unwrap(), - &(), + &aggregation_param, ) .await }) @@ -3132,20 +3091,19 @@ mod tests { let transcript_3 = run_vdaf( &vdaf, verify_key.as_bytes(), - &(), + &aggregation_param, report_metadata_3.id(), - &0, + &measurement, ); - let (prep_state_3, _) = transcript_3.helper_prep_state(0); - let out_share_3 = transcript_3.output_share(Role::Helper); - let prep_msg_3 = transcript_3.prepare_messages[0].clone(); - let report_share_3 = generate_helper_report_share::( + let helper_prep_state_3 = transcript_3.helper_prepare_transitions[0].prepare_state(); + let ping_pong_leader_message_3 = &transcript_3.leader_prepare_transitions[1].message; + let report_share_3 = generate_helper_report_share::>( *task.id(), report_metadata_3.clone(), hpke_key.config(), &transcript_3.public_share, Vec::new(), - &transcript_3.input_shares[1], + &transcript_3.helper_input_share, ); // report_share_4 gets aggregated into the second batch interval (which has already been @@ -3160,19 +3118,19 @@ mod tests { let transcript_4 = run_vdaf( &vdaf, verify_key.as_bytes(), - &(), + &aggregation_param, report_metadata_4.id(), - &0, + &measurement, ); - let (prep_state_4, _) = transcript_4.helper_prep_state(0); - let prep_msg_4 = transcript_4.prepare_messages[0].clone(); - let report_share_4 = generate_helper_report_share::( + let helper_prep_state_4 = transcript_4.helper_prepare_transitions[0].prepare_state(); + let ping_pong_leader_message_4 = &transcript_4.leader_prepare_transitions[1].message; + let report_share_4 = 
generate_helper_report_share::>( *task.id(), report_metadata_4.clone(), hpke_key.config(), &transcript_4.public_share, Vec::new(), - &transcript_4.input_shares[1], + &transcript_4.helper_input_share, ); // report_share_5 also gets aggregated into the second batch interval (which has already @@ -3187,19 +3145,19 @@ mod tests { let transcript_5 = run_vdaf( &vdaf, verify_key.as_bytes(), - &(), + &aggregation_param, report_metadata_5.id(), - &0, + &measurement, ); - let (prep_state_5, _) = transcript_5.helper_prep_state(0); - let prep_msg_5 = transcript_5.prepare_messages[0].clone(); - let report_share_5 = generate_helper_report_share::( + let helper_prep_state_5 = transcript_5.helper_prepare_transitions[0].prepare_state(); + let ping_pong_leader_message_5 = &transcript_5.leader_prepare_transitions[1].message; + let report_share_5 = generate_helper_report_share::>( *task.id(), report_metadata_5.clone(), hpke_key.config(), &transcript_5.public_share, Vec::new(), - &transcript_5.input_shares[1], + &transcript_5.helper_input_share, ); datastore @@ -3210,16 +3168,17 @@ mod tests { report_share_4.clone(), report_share_5.clone(), ); - let (prep_state_3, prep_state_4, prep_state_5) = ( - prep_state_3.clone(), - prep_state_4.clone(), - prep_state_5.clone(), + let (helper_prep_state_3, helper_prep_state_4, helper_prep_state_5) = ( + helper_prep_state_3.clone(), + helper_prep_state_4.clone(), + helper_prep_state_5.clone(), ); let (report_metadata_3, report_metadata_4, report_metadata_5) = ( report_metadata_3.clone(), report_metadata_4.clone(), report_metadata_5.clone(), ); + let aggregation_param = aggregation_param.clone(); Box::pin(async move { tx.put_report_share(task.id(), &report_share_3).await?; @@ -3229,11 +3188,11 @@ mod tests { tx.put_aggregation_job(&AggregationJob::< VERIFY_KEY_LENGTH, TimeInterval, - Prio3Count, + Poplar1, >::new( *task.id(), aggregation_job_id_1, - (), + aggregation_param, (), Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) .unwrap(), @@ -3242,41 +3201,44 @@ mod tests { )) .await?; - tx.put_report_aggregation( - &ReportAggregation::::new( - *task.id(), - aggregation_job_id_1, - *report_metadata_3.id(), - *report_metadata_3.time(), - 3, - None, - ReportAggregationState::Waiting(prep_state_3, None), - ), - ) + tx.put_report_aggregation(&ReportAggregation::< + VERIFY_KEY_LENGTH, + Poplar1, + >::new( + *task.id(), + aggregation_job_id_1, + *report_metadata_3.id(), + *report_metadata_3.time(), + 3, + None, + ReportAggregationState::WaitingHelper(helper_prep_state_3), + )) .await?; - tx.put_report_aggregation( - &ReportAggregation::::new( - *task.id(), - aggregation_job_id_1, - *report_metadata_4.id(), - *report_metadata_4.time(), - 4, - None, - ReportAggregationState::Waiting(prep_state_4, None), - ), - ) + tx.put_report_aggregation(&ReportAggregation::< + VERIFY_KEY_LENGTH, + Poplar1, + >::new( + *task.id(), + aggregation_job_id_1, + *report_metadata_4.id(), + *report_metadata_4.time(), + 4, + None, + ReportAggregationState::WaitingHelper(helper_prep_state_4), + )) .await?; - tx.put_report_aggregation( - &ReportAggregation::::new( - *task.id(), - aggregation_job_id_1, - *report_metadata_5.id(), - *report_metadata_5.time(), - 5, - None, - ReportAggregationState::Waiting(prep_state_5, None), - ), - ) + tx.put_report_aggregation(&ReportAggregation::< + VERIFY_KEY_LENGTH, + Poplar1, + >::new( + *task.id(), + aggregation_job_id_1, + *report_metadata_5.id(), + *report_metadata_5.time(), + 5, + None, + 
ReportAggregationState::WaitingHelper(helper_prep_state_5), + )) .await?; Ok(()) @@ -3288,18 +3250,9 @@ mod tests { let request = AggregationJobContinueReq::new( AggregationJobRound::from(1), Vec::from([ - PrepareStep::new( - *report_metadata_3.id(), - PrepareStepResult::Continued(prep_msg_3.get_encoded()), - ), - PrepareStep::new( - *report_metadata_4.id(), - PrepareStepResult::Continued(prep_msg_4.get_encoded()), - ), - PrepareStep::new( - *report_metadata_5.id(), - PrepareStepResult::Continued(prep_msg_5.get_encoded()), - ), + PrepareContinue::new(*report_metadata_3.id(), ping_pong_leader_message_3.clone()), + PrepareContinue::new(*report_metadata_4.id(), ping_pong_leader_message_4.clone()), + PrepareContinue::new(*report_metadata_5.id(), ping_pong_leader_message_5.clone()), ]), ); @@ -3311,12 +3264,16 @@ mod tests { // be the same) let merged_first_batch_aggregation = datastore .run_tx(|tx| { - let (task, vdaf, report_metadata_0) = - (task.clone(), vdaf.clone(), report_metadata_0.clone()); + let (task, vdaf, report_metadata_0, aggregation_param) = ( + task.clone(), + vdaf.clone(), + report_metadata_0.clone(), + aggregation_param.clone(), + ); Box::pin(async move { TimeInterval::get_batch_aggregations_for_collection_identifier::< VERIFY_KEY_LENGTH, - Prio3Count, + Poplar1, _, >( tx, @@ -3330,7 +3287,7 @@ mod tests { Duration::from_seconds(task.time_precision().as_seconds()), ) .unwrap(), - &(), + &aggregation_param, ) .await }) @@ -3339,10 +3296,10 @@ mod tests { .unwrap() .into_iter() .map(|agg| { - BatchAggregation::::new( + BatchAggregation::>::new( *agg.task_id(), *agg.batch_identifier(), - (), + agg.aggregation_parameter().clone(), 0, BatchAggregationState::Aggregating, agg.aggregate_share().cloned(), @@ -3356,8 +3313,14 @@ mod tests { let first_aggregate_share = vdaf .aggregate( - &(), - [out_share_0, out_share_1, out_share_3].into_iter().cloned(), + &aggregation_param, + [ + &transcript_0.helper_output_share, + &transcript_1.helper_output_share, + &transcript_3.helper_output_share, + ] + .into_iter() + .cloned(), ) .unwrap(); let first_checksum = ReportIdChecksum::for_report_id(report_metadata_0.id()) @@ -3376,7 +3339,7 @@ mod tests { *task.time_precision() ) .unwrap(), - (), + aggregation_param.clone(), 0, BatchAggregationState::Aggregating, Some(first_aggregate_share), @@ -3388,12 +3351,16 @@ mod tests { let second_batch_got_batch_aggregations = datastore .run_tx(|tx| { - let (task, vdaf, report_metadata_2) = - (task.clone(), vdaf.clone(), report_metadata_2.clone()); + let (task, vdaf, report_metadata_2, aggregation_param) = ( + task.clone(), + vdaf.clone(), + report_metadata_2.clone(), + aggregation_param.clone(), + ); Box::pin(async move { TimeInterval::get_batch_aggregations_for_collection_identifier::< VERIFY_KEY_LENGTH, - Prio3Count, + Poplar1, _, >( tx, @@ -3407,7 +3374,7 @@ mod tests { Duration::from_seconds(task.time_precision().as_seconds()), ) .unwrap(), - &(), + &aggregation_param, ) .await }) @@ -3421,12 +3388,28 @@ mod tests { } #[tokio::test] - async fn aggregate_continue_leader_sends_non_continue_transition() { + async fn aggregate_continue_leader_sends_non_continue_or_finish_transition() { let (_, _ephemeral_datastore, datastore, handler) = setup_http_handler_test().await; // Prepare parameters. 
- let task = - TaskBuilder::new(QueryType::TimeInterval, VdafInstance::Fake, Role::Helper).build(); + let task = TaskBuilder::new( + QueryType::TimeInterval, + VdafInstance::Poplar1 { bits: 1 }, + Role::Helper, + ) + .build(); + let report_id = random(); + let aggregation_param = Poplar1AggregationParam::try_from_prefixes(Vec::from([ + IdpfInput::from_bools(&[false]), + ])) + .unwrap(); + let transcript = run_vdaf( + &Poplar1::new_shake128(1), + task.primary_vdaf_verify_key().unwrap().as_bytes(), + &aggregation_param, + &report_id, + &IdpfInput::from_bools(&[false]), + ); let aggregation_job_id = random(); let report_metadata = ReportMetadata::new( ReportId::from([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]), @@ -3436,14 +3419,19 @@ mod tests { // Setup datastore. datastore .run_tx(|tx| { - let (task, report_metadata) = (task.clone(), report_metadata.clone()); + let (task, aggregation_param, report_metadata, transcript) = ( + task.clone(), + aggregation_param.clone(), + report_metadata.clone(), + transcript.clone(), + ); Box::pin(async move { tx.put_task(&task).await?; tx.put_report_share( task.id(), &ReportShare::new( report_metadata.clone(), - Vec::from("Public Share"), + Vec::from("public share"), HpkeCiphertext::new( HpkeConfigId::from(42), Vec::from("012345"), @@ -3453,31 +3441,36 @@ mod tests { ) .await?; - tx.put_aggregation_job( - &AggregationJob::<0, TimeInterval, dummy_vdaf::Vdaf>::new( + tx.put_aggregation_job(&AggregationJob::< + 16, + TimeInterval, + Poplar1, + >::new( + *task.id(), + aggregation_job_id, + aggregation_param, + (), + Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) + .unwrap(), + AggregationJobState::InProgress, + AggregationJobRound::from(0), + )) + .await?; + tx.put_report_aggregation( + &ReportAggregation::<16, Poplar1>::new( *task.id(), aggregation_job_id, - dummy_vdaf::AggregationParam(0), - (), - Interval::new( - Time::from_seconds_since_epoch(0), - Duration::from_seconds(1), - ) - .unwrap(), - AggregationJobState::InProgress, - AggregationJobRound::from(0), + *report_metadata.id(), + *report_metadata.time(), + 0, + None, + ReportAggregationState::WaitingHelper( + transcript.helper_prepare_transitions[0] + .prepare_state() + .clone(), + ), ), ) - .await?; - tx.put_report_aggregation(&ReportAggregation::<0, dummy_vdaf::Vdaf>::new( - *task.id(), - aggregation_job_id, - *report_metadata.id(), - *report_metadata.time(), - 0, - None, - ReportAggregationState::Waiting(dummy_vdaf::PrepareState::default(), None), - )) .await }) }) @@ -3487,21 +3480,25 @@ mod tests { // Make request. 
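        // For reference, `prio::topology::ping_pong::PingPongMessage` has three variants
        // (a sketch; each field is an opaque encoded byte vector):
        //
        //   PingPongMessage::Initialize { prep_share }
        //   PingPongMessage::Continue { prep_msg, prep_share }
        //   PingPongMessage::Finish { prep_msg }
        //
        // Only Continue or Finish make sense in a continue request; the request below sends
        // Initialize and expects a per-report Reject(PrepareError::VdafPrepError), where the
        // old topology would have failed the whole request with unrecognizedMessage.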
let request = AggregationJobContinueReq::new( AggregationJobRound::from(1), - Vec::from([PrepareStep::new( + Vec::from([PrepareContinue::new( *report_metadata.id(), - PrepareStepResult::Finished, + // An AggregationJobContinueReq should only ever contain Continue or Finished + PingPongMessage::Initialize { + prep_share: Vec::new(), + }, )]), ); - post_aggregation_job_expecting_error( - &task, - &aggregation_job_id, - &request, - &handler, - Status::BadRequest, - "urn:ietf:params:ppm:dap:error:unrecognizedMessage", - "The message type for a response was incorrect or the payload was malformed.", - ) - .await; + + let resp = + post_aggregation_job_and_decode(&task, &aggregation_job_id, &request, &handler).await; + assert_eq!(resp.prepare_resps().len(), 1); + assert_eq!( + resp.prepare_resps()[0], + PrepareResp::new( + *report_metadata.id(), + PrepareStepResult::Reject(PrepareError::VdafPrepError), + ) + ); } #[tokio::test] @@ -3511,61 +3508,78 @@ mod tests { // Prepare parameters. let task = TaskBuilder::new( QueryType::TimeInterval, - VdafInstance::FakeFailsPrepStep, + VdafInstance::Poplar1 { bits: 1 }, Role::Helper, ) .build(); + let vdaf = Poplar1::new_shake128(1); + let report_id = random(); + let aggregation_param = Poplar1AggregationParam::try_from_prefixes(Vec::from([ + IdpfInput::from_bools(&[false]), + ])) + .unwrap(); + let transcript = run_vdaf( + &vdaf, + task.primary_vdaf_verify_key().unwrap().as_bytes(), + &aggregation_param, + &report_id, + &IdpfInput::from_bools(&[false]), + ); let aggregation_job_id = random(); - let report_metadata = ReportMetadata::new( - ReportId::from([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]), - Time::from_seconds_since_epoch(54321), + let report_metadata = ReportMetadata::new(report_id, Time::from_seconds_since_epoch(54321)); + let helper_report_share = generate_helper_report_share::>( + *task.id(), + report_metadata.clone(), + task.current_hpke_key().config(), + &transcript.public_share, + Vec::new(), + &transcript.helper_input_share, ); // Setup datastore. 
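        // (A waiting helper is stored as WaitingHelper, which carries only the VDAF prepare
        // state; the peer's next PingPongMessage is not persisted, since it arrives in the
        // continue request itself.)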
datastore .run_tx(|tx| { - let (task, report_metadata) = (task.clone(), report_metadata.clone()); + let (task, aggregation_param, report_metadata, transcript, helper_report_share) = ( + task.clone(), + aggregation_param.clone(), + report_metadata.clone(), + transcript.clone(), + helper_report_share.clone(), + ); Box::pin(async move { tx.put_task(&task).await?; - tx.put_report_share( - task.id(), - &ReportShare::new( - report_metadata.clone(), - Vec::from("public share"), - HpkeCiphertext::new( - HpkeConfigId::from(42), - Vec::from("012345"), - Vec::from("543210"), - ), - ), - ) + tx.put_report_share(task.id(), &helper_report_share).await?; + tx.put_aggregation_job(&AggregationJob::< + 16, + TimeInterval, + Poplar1, + >::new( + *task.id(), + aggregation_job_id, + aggregation_param, + (), + Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) + .unwrap(), + AggregationJobState::InProgress, + AggregationJobRound::from(0), + )) .await?; - tx.put_aggregation_job( - &AggregationJob::<0, TimeInterval, dummy_vdaf::Vdaf>::new( + tx.put_report_aggregation( + &ReportAggregation::<16, Poplar1>::new( *task.id(), aggregation_job_id, - dummy_vdaf::AggregationParam(0), - (), - Interval::new( - Time::from_seconds_since_epoch(0), - Duration::from_seconds(1), - ) - .unwrap(), - AggregationJobState::InProgress, - AggregationJobRound::from(0), + *report_metadata.id(), + *report_metadata.time(), + 0, + None, + ReportAggregationState::WaitingHelper( + transcript.helper_prepare_transitions[0] + .prepare_state() + .clone(), + ), ), ) - .await?; - tx.put_report_aggregation(&ReportAggregation::<0, dummy_vdaf::Vdaf>::new( - *task.id(), - aggregation_job_id, - *report_metadata.id(), - *report_metadata.time(), - 0, - None, - ReportAggregationState::Waiting(dummy_vdaf::PrepareState::default(), None), - )) .await }) }) @@ -3575,9 +3589,12 @@ mod tests { // Make request. let request = AggregationJobContinueReq::new( AggregationJobRound::from(1), - Vec::from([PrepareStep::new( + Vec::from([PrepareContinue::new( *report_metadata.id(), - PrepareStepResult::Continued(Vec::new()), + PingPongMessage::Continue { + prep_msg: Vec::new(), + prep_share: Vec::new(), + }, )]), ); @@ -3585,19 +3602,20 @@ mod tests { post_aggregation_job_and_decode(&task, &aggregation_job_id, &request, &handler).await; assert_eq!( aggregate_resp, - AggregationJobResp::new(Vec::from([PrepareStep::new( + AggregationJobResp::new(Vec::from([PrepareResp::new( *report_metadata.id(), - PrepareStepResult::Failed(ReportShareError::VdafPrepError), + PrepareStepResult::Reject(PrepareError::VdafPrepError), )]),) ); // Check datastore state. 
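        // (The helper keeps the most recent PrepareResp it sent for each report as
        // `last_prep_resp`, so the expected report aggregation below carries both the
        // Reject response and the terminal Failed state.)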
let (aggregation_job, report_aggregation) = datastore .run_tx(|tx| { - let (task, report_metadata) = (task.clone(), report_metadata.clone()); + let (vdaf, task, report_metadata) = + (vdaf.clone(), task.clone(), report_metadata.clone()); Box::pin(async move { let aggregation_job = tx - .get_aggregation_job::<0, TimeInterval, dummy_vdaf::Vdaf>( + .get_aggregation_job::<16, TimeInterval, Poplar1>( task.id(), &aggregation_job_id, ) @@ -3606,10 +3624,11 @@ mod tests { .unwrap(); let report_aggregation = tx .get_report_aggregation( - &dummy_vdaf::Vdaf::default(), + &vdaf, &Role::Helper, task.id(), &aggregation_job_id, + aggregation_job.aggregation_parameter(), report_metadata.id(), ) .await @@ -3626,7 +3645,7 @@ mod tests { AggregationJob::new( *task.id(), aggregation_job_id, - dummy_vdaf::AggregationParam(0), + aggregation_param, (), Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) .unwrap(), @@ -3643,11 +3662,11 @@ mod tests { *report_metadata.id(), *report_metadata.time(), 0, - Some(PrepareStep::new( + Some(PrepareResp::new( *report_metadata.id(), - PrepareStepResult::Failed(ReportShareError::VdafPrepError) + PrepareStepResult::Reject(PrepareError::VdafPrepError) )), - ReportAggregationState::Failed(ReportShareError::VdafPrepError), + ReportAggregationState::Failed(PrepareError::VdafPrepError), ) ); } @@ -3657,18 +3676,36 @@ mod tests { let (_, _ephemeral_datastore, datastore, handler) = setup_http_handler_test().await; // Prepare parameters. - let task = - TaskBuilder::new(QueryType::TimeInterval, VdafInstance::Fake, Role::Helper).build(); - let aggregation_job_id = random(); - let report_metadata = ReportMetadata::new( - ReportId::from([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]), - Time::from_seconds_since_epoch(54321), + let task = TaskBuilder::new( + QueryType::TimeInterval, + VdafInstance::Poplar1 { bits: 1 }, + Role::Helper, + ) + .build(); + let report_id = random(); + let aggregation_param = Poplar1AggregationParam::try_from_prefixes(Vec::from([ + IdpfInput::from_bools(&[false]), + ])) + .unwrap(); + let transcript = run_vdaf( + &Poplar1::new_shake128(1), + task.primary_vdaf_verify_key().unwrap().as_bytes(), + &aggregation_param, + &report_id, + &IdpfInput::from_bools(&[false]), ); + let aggregation_job_id = random(); + let report_metadata = ReportMetadata::new(report_id, Time::from_seconds_since_epoch(54321)); // Setup datastore. 
datastore .run_tx(|tx| { - let (task, report_metadata) = (task.clone(), report_metadata.clone()); + let (task, aggregation_param, report_metadata, transcript) = ( + task.clone(), + aggregation_param.clone(), + report_metadata.clone(), + transcript.clone(), + ); Box::pin(async move { tx.put_task(&task).await?; @@ -3685,31 +3722,36 @@ mod tests { ), ) .await?; - tx.put_aggregation_job( - &AggregationJob::<0, TimeInterval, dummy_vdaf::Vdaf>::new( + tx.put_aggregation_job(&AggregationJob::< + 16, + TimeInterval, + Poplar1, + >::new( + *task.id(), + aggregation_job_id, + aggregation_param, + (), + Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) + .unwrap(), + AggregationJobState::InProgress, + AggregationJobRound::from(0), + )) + .await?; + tx.put_report_aggregation( + &ReportAggregation::<16, Poplar1>::new( *task.id(), aggregation_job_id, - dummy_vdaf::AggregationParam(0), - (), - Interval::new( - Time::from_seconds_since_epoch(0), - Duration::from_seconds(1), - ) - .unwrap(), - AggregationJobState::InProgress, - AggregationJobRound::from(0), + *report_metadata.id(), + *report_metadata.time(), + 0, + None, + ReportAggregationState::WaitingHelper( + transcript.helper_prepare_transitions[0] + .prepare_state() + .clone(), + ), ), ) - .await?; - tx.put_report_aggregation(&ReportAggregation::<0, dummy_vdaf::Vdaf>::new( - *task.id(), - aggregation_job_id, - *report_metadata.id(), - *report_metadata.time(), - 0, - None, - ReportAggregationState::Waiting(dummy_vdaf::PrepareState::default(), None), - )) .await }) }) @@ -3719,11 +3761,14 @@ mod tests { // Make request. let request = AggregationJobContinueReq::new( AggregationJobRound::from(1), - Vec::from([PrepareStep::new( + Vec::from([PrepareContinue::new( ReportId::from( [16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1], // not the same as above ), - PrepareStepResult::Continued(Vec::new()), + PingPongMessage::Continue { + prep_msg: Vec::new(), + prep_share: Vec::new(), + }, )]), ); @@ -3744,25 +3789,55 @@ mod tests { let (_, _ephemeral_datastore, datastore, handler) = setup_http_handler_test().await; // Prepare parameters. - let task = - TaskBuilder::new(QueryType::TimeInterval, VdafInstance::Fake, Role::Helper).build(); + let task = TaskBuilder::new( + QueryType::TimeInterval, + VdafInstance::Poplar1 { bits: 1 }, + Role::Helper, + ) + .build(); + let report_id_0 = random(); + let aggregation_param = Poplar1AggregationParam::try_from_prefixes(Vec::from([ + IdpfInput::from_bools(&[false]), + ])) + .unwrap(); + let transcript_0 = run_vdaf( + &Poplar1::new_shake128(1), + task.primary_vdaf_verify_key().unwrap().as_bytes(), + &aggregation_param, + &report_id_0, + &IdpfInput::from_bools(&[false]), + ); + let report_metadata_0 = + ReportMetadata::new(report_id_0, Time::from_seconds_since_epoch(54321)); + let report_id_1 = random(); + let transcript_1 = run_vdaf( + &Poplar1::new_shake128(1), + task.primary_vdaf_verify_key().unwrap().as_bytes(), + &aggregation_param, + &report_id_1, + &IdpfInput::from_bools(&[false]), + ); + let report_metadata_1 = + ReportMetadata::new(report_id_1, Time::from_seconds_since_epoch(54321)); let aggregation_job_id = random(); - let report_metadata_0 = ReportMetadata::new( - ReportId::from([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]), - Time::from_seconds_since_epoch(54321), - ); - let report_metadata_1 = ReportMetadata::new( - ReportId::from([16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1]), - Time::from_seconds_since_epoch(54321), - ); // Setup datastore. 
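        // (Both reports are stored as WaitingHelper with ordinals 0 and 1; the continue
        // request that follows lists their prepare continuations in the opposite order,
        // exercising the handler's requirement that continuations arrive in the stored order.)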
datastore .run_tx(|tx| { - let (task, report_metadata_0, report_metadata_1) = ( + let ( + task, + aggregation_param, + report_metadata_0, + report_metadata_1, + transcript_0, + transcript_1, + ) = ( task.clone(), + aggregation_param.clone(), report_metadata_0.clone(), report_metadata_1.clone(), + transcript_0.clone(), + transcript_1.clone(), ); Box::pin(async move { @@ -3795,42 +3870,53 @@ mod tests { ) .await?; - tx.put_aggregation_job( - &AggregationJob::<0, TimeInterval, dummy_vdaf::Vdaf>::new( + tx.put_aggregation_job(&AggregationJob::< + 16, + TimeInterval, + Poplar1, + >::new( + *task.id(), + aggregation_job_id, + aggregation_param.clone(), + (), + Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) + .unwrap(), + AggregationJobState::InProgress, + AggregationJobRound::from(0), + )) + .await?; + + tx.put_report_aggregation( + &ReportAggregation::<16, Poplar1>::new( *task.id(), aggregation_job_id, - dummy_vdaf::AggregationParam(0), - (), - Interval::new( - Time::from_seconds_since_epoch(0), - Duration::from_seconds(1), - ) - .unwrap(), - AggregationJobState::InProgress, - AggregationJobRound::from(0), + *report_metadata_0.id(), + *report_metadata_0.time(), + 0, + None, + ReportAggregationState::WaitingHelper( + transcript_0.helper_prepare_transitions[0] + .prepare_state() + .clone(), + ), ), ) .await?; - - tx.put_report_aggregation(&ReportAggregation::<0, dummy_vdaf::Vdaf>::new( - *task.id(), - aggregation_job_id, - *report_metadata_0.id(), - *report_metadata_0.time(), - 0, - None, - ReportAggregationState::Waiting(dummy_vdaf::PrepareState::default(), None), - )) - .await?; - tx.put_report_aggregation(&ReportAggregation::<0, dummy_vdaf::Vdaf>::new( - *task.id(), - aggregation_job_id, - *report_metadata_1.id(), - *report_metadata_1.time(), - 1, - None, - ReportAggregationState::Waiting(dummy_vdaf::PrepareState::default(), None), - )) + tx.put_report_aggregation( + &ReportAggregation::<16, Poplar1>::new( + *task.id(), + aggregation_job_id, + *report_metadata_1.id(), + *report_metadata_1.time(), + 1, + None, + ReportAggregationState::WaitingHelper( + transcript_1.helper_prepare_transitions[0] + .prepare_state() + .clone(), + ), + ), + ) .await }) }) @@ -3842,13 +3928,19 @@ mod tests { AggregationJobRound::from(1), Vec::from([ // Report IDs are in opposite order to what was stored in the datastore. - PrepareStep::new( + PrepareContinue::new( *report_metadata_1.id(), - PrepareStepResult::Continued(Vec::new()), + PingPongMessage::Continue { + prep_msg: Vec::new(), + prep_share: Vec::new(), + }, ), - PrepareStep::new( + PrepareContinue::new( *report_metadata_0.id(), - PrepareStepResult::Continued(Vec::new()), + PingPongMessage::Continue { + prep_msg: Vec::new(), + prep_share: Vec::new(), + }, ), ]), ); @@ -3919,7 +4011,7 @@ mod tests { *report_metadata.time(), 0, None, - ReportAggregationState::Failed(ReportShareError::VdafPrepError), + ReportAggregationState::Failed(PrepareError::VdafPrepError), )) .await }) @@ -3930,9 +4022,12 @@ mod tests { // Make request. 
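        // (The report aggregation above is already in the terminal Failed state, so the
        // continue request that follows targets a report that can no longer be advanced
        // and should be rejected.)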
        let request = AggregationJobContinueReq::new(
            AggregationJobRound::from(1),
-            Vec::from([PrepareStep::new(
+            Vec::from([PrepareContinue::new(
                ReportId::from([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]),
-                PrepareStepResult::Continued(Vec::new()),
+                PingPongMessage::Continue {
+                    prep_msg: Vec::new(),
+                    prep_share: Vec::new(),
+                },
            )]),
        );
        post_aggregation_job_expecting_error(
diff --git a/aggregator/src/aggregator/taskprov_tests.rs b/aggregator/src/aggregator/taskprov_tests.rs
index 511d5db48..9ce14b4d2 100644
--- a/aggregator/src/aggregator/taskprov_tests.rs
+++ b/aggregator/src/aggregator/taskprov_tests.rs
@@ -44,16 +44,15 @@ use janus_messages::{
    },
    AggregateShare as AggregateShareMessage, AggregateShareAad, AggregateShareReq,
    AggregationJobContinueReq, AggregationJobId, AggregationJobInitializeReq, AggregationJobResp,
-    AggregationJobRound, BatchSelector, Duration, Interval, PartialBatchSelector, PrepareStep,
-    PrepareStepResult, ReportIdChecksum, ReportMetadata, ReportShare, Role, TaskId, Time,
+    AggregationJobRound, BatchSelector, Duration, Interval, PartialBatchSelector, PrepareContinue,
+    PrepareInit, PrepareResp, PrepareStepResult, ReportIdChecksum, ReportMetadata, ReportShare,
+    Role, TaskId, Time,
 };
 use prio::{
-    field::Field64,
-    flp::types::Count,
+    idpf::IdpfInput,
    vdaf::{
-        prio3::{Prio3, Prio3Count},
+        poplar1::{Poplar1, Poplar1AggregationParam},
        xof::XofShake128,
-        AggregateShare, OutputShare,
    },
 };
 use rand::random;
@@ -66,7 +65,7 @@ use trillium_testing::{
    prelude::{post, put},
 };
 
-type TestVdaf = Prio3<Count<Field64>, XofShake128, 16>;
+type TestVdaf = Poplar1<XofShake128, 16>;
 
 pub struct TaskprovTestCase {
    _ephemeral_datastore: EphemeralDatastore,
@@ -81,6 +80,7 @@ pub struct TaskprovTestCase {
    task: Task,
    task_config: TaskConfig,
    task_id: TaskId,
+    aggregation_param: Poplar1AggregationParam,
 }
 
 async fn setup_taskprov_test() -> TaskprovTestCase {
@@ -143,19 +143,29 @@ async fn setup_taskprov_test() -> TaskprovTestCase {
            TaskprovQuery::FixedSize { max_batch_size },
        ),
        task_expiration,
-        VdafConfig::new(DpConfig::new(DpMechanism::None), VdafType::Prio3Count).unwrap(),
+        VdafConfig::new(
+            DpConfig::new(DpMechanism::None),
+            VdafType::Poplar1 { bits: 1 },
+        )
+        .unwrap(),
    )
    .unwrap();
    let mut task_config_encoded = vec![];
    task_config.encode(&mut task_config_encoded);
 
-    // We use a real VDAF since taskprov doesn't have any allowance for a test VDAF.
-    let vdaf = Prio3Count::new_count(2).unwrap();
+    // We use a real VDAF since taskprov doesn't have any allowance for a test VDAF, and we use
+    // Poplar1 so that the VDAF will take more than one round, so we can exercise aggregation
+    // continuation.
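+    // With `bits = 1`, the aggregation parameter below contains a single candidate prefix
+    // equal to the measurement itself, so each prepared report contributes a count of one
+    // to that prefix's aggregate.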
+ let vdaf = Poplar1::new(1); let task_id = TaskId::try_from(digest(&SHA256, &task_config_encoded).as_ref()).unwrap(); let vdaf_instance = task_config.vdaf_config().vdaf_type().try_into().unwrap(); let vdaf_verify_key = peer_aggregator.derive_vdaf_verify_key(&task_id, &vdaf_instance); + let aggregation_param = + Poplar1AggregationParam::try_from_prefixes(Vec::from([IdpfInput::from_bools(&[true])])) + .unwrap(); + let measurement = IdpfInput::from_bools(&[true]); let task = janus_aggregator_core::taskprov::Task::new( task_id, @@ -187,9 +197,9 @@ async fn setup_taskprov_test() -> TaskprovTestCase { let transcript = run_vdaf( &vdaf, vdaf_verify_key.as_ref().try_into().unwrap(), - &(), + &aggregation_param, report_metadata.id(), - &1, + &measurement, ); let report_share = generate_helper_report_share::( task_id, @@ -197,7 +207,7 @@ async fn setup_taskprov_test() -> TaskprovTestCase { global_hpke_key.config(), &transcript.public_share, Vec::new(), - &transcript.input_shares[1], + &transcript.helper_input_share, ); TaskprovTestCase { @@ -213,6 +223,7 @@ async fn setup_taskprov_test() -> TaskprovTestCase { report_metadata, transcript, report_share, + aggregation_param, } } @@ -222,9 +233,14 @@ async fn taskprov_aggregate_init() { let batch_id = random(); let request = AggregationJobInitializeReq::new( - ().get_encoded(), + test.aggregation_param.get_encoded(), PartialBatchSelector::new_fixed_size(batch_id), - Vec::from([test.report_share.clone()]), + Vec::from([PrepareInit::new( + test.report_share.clone(), + test.transcript.leader_prepare_transitions[0] + .message + .clone(), + )]), ); let aggregation_job_id: AggregationJobId = random(); @@ -287,10 +303,10 @@ async fn taskprov_aggregate_init() { ); let aggregate_resp: AggregationJobResp = decode_response_body(&mut test_conn).await; - assert_eq!(aggregate_resp.prepare_steps().len(), 1); - let prepare_step = aggregate_resp.prepare_steps().get(0).unwrap(); + assert_eq!(aggregate_resp.prepare_resps().len(), 1); + let prepare_step = aggregate_resp.prepare_resps().get(0).unwrap(); assert_eq!(prepare_step.report_id(), test.report_share.metadata().id()); - assert_matches!(prepare_step.result(), &PrepareStepResult::Continued(..)); + assert_matches!(prepare_step.result(), &PrepareStepResult::Continue { .. 
}); let (aggregation_jobs, got_task) = test .datastore @@ -328,7 +344,12 @@ async fn taskprov_opt_out_task_expired() { let request = AggregationJobInitializeReq::new( ().get_encoded(), PartialBatchSelector::new_fixed_size(batch_id), - Vec::from([test.report_share.clone()]), + Vec::from([PrepareInit::new( + test.report_share.clone(), + test.transcript.leader_prepare_transitions[0] + .message + .clone(), + )]), ); let aggregation_job_id: AggregationJobId = random(); @@ -378,7 +399,12 @@ async fn taskprov_opt_out_mismatched_task_id() { let request = AggregationJobInitializeReq::new( ().get_encoded(), PartialBatchSelector::new_fixed_size(batch_id), - Vec::from([test.report_share.clone()]), + Vec::from([PrepareInit::new( + test.report_share.clone(), + test.transcript.leader_prepare_transitions[0] + .message + .clone(), + )]), ); let aggregation_job_id: AggregationJobId = random(); @@ -404,7 +430,11 @@ async fn taskprov_opt_out_mismatched_task_id() { }, ), task_expiration, - VdafConfig::new(DpConfig::new(DpMechanism::None), VdafType::Prio3Count).unwrap(), + VdafConfig::new( + DpConfig::new(DpMechanism::None), + VdafType::Poplar1 { bits: 1 }, + ) + .unwrap(), ) .unwrap(); @@ -452,7 +482,12 @@ async fn taskprov_opt_out_missing_aggregator() { let request = AggregationJobInitializeReq::new( ().get_encoded(), PartialBatchSelector::new_fixed_size(batch_id), - Vec::from([test.report_share.clone()]), + Vec::from([PrepareInit::new( + test.report_share.clone(), + test.transcript.leader_prepare_transitions[0] + .message + .clone(), + )]), ); let aggregation_job_id: AggregationJobId = random(); @@ -475,7 +510,11 @@ async fn taskprov_opt_out_missing_aggregator() { }, ), task_expiration, - VdafConfig::new(DpConfig::new(DpMechanism::None), VdafType::Prio3Count).unwrap(), + VdafConfig::new( + DpConfig::new(DpMechanism::None), + VdafType::Poplar1 { bits: 1 }, + ) + .unwrap(), ) .unwrap(); let another_task_config_encoded = another_task_config.get_encoded(); @@ -490,8 +529,7 @@ async fn taskprov_opt_out_missing_aggregator() { .request_authentication(); let mut test_conn = put(format!( - "/tasks/{another_task_id -}/aggregation_jobs/{aggregation_job_id}" + "/tasks/{another_task_id}/aggregation_jobs/{aggregation_job_id}" )) .with_request_header(auth.0, auth.1) .with_request_header( @@ -525,7 +563,12 @@ async fn taskprov_opt_out_peer_aggregator_wrong_role() { let request = AggregationJobInitializeReq::new( ().get_encoded(), PartialBatchSelector::new_fixed_size(batch_id), - Vec::from([test.report_share.clone()]), + Vec::from([PrepareInit::new( + test.report_share.clone(), + test.transcript.leader_prepare_transitions[0] + .message + .clone(), + )]), ); let aggregation_job_id: AggregationJobId = random(); @@ -551,7 +594,11 @@ async fn taskprov_opt_out_peer_aggregator_wrong_role() { }, ), task_expiration, - VdafConfig::new(DpConfig::new(DpMechanism::None), VdafType::Prio3Count).unwrap(), + VdafConfig::new( + DpConfig::new(DpMechanism::None), + VdafType::Poplar1 { bits: 1 }, + ) + .unwrap(), ) .unwrap(); let another_task_config_encoded = another_task_config.get_encoded(); @@ -566,8 +613,7 @@ async fn taskprov_opt_out_peer_aggregator_wrong_role() { .request_authentication(); let mut test_conn = put(format!( - "/tasks/{another_task_id -}/aggregation_jobs/{aggregation_job_id}" + "/tasks/{another_task_id}/aggregation_jobs/{aggregation_job_id}" )) .with_request_header(auth.0, auth.1) .with_request_header( @@ -602,7 +648,12 @@ async fn taskprov_opt_out_peer_aggregator_does_not_exist() { let request = 
AggregationJobInitializeReq::new( ().get_encoded(), PartialBatchSelector::new_fixed_size(batch_id), - Vec::from([test.report_share.clone()]), + Vec::from([PrepareInit::new( + test.report_share.clone(), + test.transcript.leader_prepare_transitions[0] + .message + .clone(), + )]), ); let aggregation_job_id: AggregationJobId = random(); @@ -628,7 +679,11 @@ async fn taskprov_opt_out_peer_aggregator_does_not_exist() { }, ), task_expiration, - VdafConfig::new(DpConfig::new(DpMechanism::None), VdafType::Prio3Count).unwrap(), + VdafConfig::new( + DpConfig::new(DpMechanism::None), + VdafType::Poplar1 { bits: 1 }, + ) + .unwrap(), ) .unwrap(); let another_task_config_encoded = another_task_config.get_encoded(); @@ -643,8 +698,7 @@ async fn taskprov_opt_out_peer_aggregator_does_not_exist() { .request_authentication(); let mut test_conn = put(format!( - "/tasks/{another_task_id -}/aggregation_jobs/{aggregation_job_id}" + "/tasks/{another_task_id}/aggregation_jobs/{aggregation_job_id}" )) .with_request_header(auth.0, auth.1) .with_request_header( @@ -678,15 +732,13 @@ async fn taskprov_aggregate_continue() { let aggregation_job_id = random(); let batch_id = random(); - let (prep_state, _) = test.transcript.helper_prep_state(0); - let prep_msg = test.transcript.prepare_messages[0].clone(); - test.datastore .run_tx(|tx| { let task = test.task.clone(); let report_share = test.report_share.clone(); - let prep_state = prep_state.clone(); let report_metadata = test.report_metadata.clone(); + let transcript = test.transcript.clone(); + let aggregation_param = test.aggregation_param.clone(); Box::pin(async move { // Aggregate continue is only possible if the task has already been inserted. @@ -698,7 +750,7 @@ async fn taskprov_aggregate_continue() { &AggregationJob::::new( *task.id(), aggregation_job_id, - (), + aggregation_param.clone(), batch_id, Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) .unwrap(), @@ -715,7 +767,11 @@ async fn taskprov_aggregate_continue() { *report_metadata.time(), 0, None, - ReportAggregationState::Waiting(prep_state, None), + ReportAggregationState::WaitingHelper( + transcript.helper_prepare_transitions[0] + .prepare_state() + .clone(), + ), )) .await?; @@ -723,8 +779,8 @@ async fn taskprov_aggregate_continue() { &AggregateShareJob::new( *task.id(), batch_id, - (), - AggregateShare::from(OutputShare::from(Vec::from([Field64::from(7)]))), + aggregation_param, + transcript.helper_aggregate_share, 0, ReportIdChecksum::default(), ), @@ -737,9 +793,11 @@ async fn taskprov_aggregate_continue() { let request = AggregationJobContinueReq::new( AggregationJobRound::from(1), - Vec::from([PrepareStep::new( + Vec::from([PrepareContinue::new( *test.report_metadata.id(), - PrepareStepResult::Continued(prep_msg.get_encoded()), + test.transcript.leader_prepare_transitions[1] + .message + .clone(), )]), ); @@ -801,11 +859,11 @@ async fn taskprov_aggregate_continue() { assert_headers!(&test_conn, "content-type" => (AggregationJobResp::MEDIA_TYPE)); let aggregate_resp: AggregationJobResp = decode_response_body(&mut test_conn).await; - // We'll only validate the response. Taskprov doesn't touch functionality beyond the authorization - // of the request. + // We'll only validate the response. Taskprov doesn't touch functionality beyond the + // authorization of the request. 
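+    // (With Poplar1 the helper is still waiting after aggregation initialization; this
+    // single continue request completes preparation, which is why the helper can already
+    // report Finished here.)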
assert_eq!( aggregate_resp, - AggregationJobResp::new(Vec::from([PrepareStep::new( + AggregationJobResp::new(Vec::from([PrepareResp::new( *test.report_metadata.id(), PrepareStepResult::Finished )])) @@ -823,6 +881,8 @@ async fn taskprov_aggregate_share() { let interval = Interval::new(Time::from_seconds_since_epoch(6000), *task.time_precision()) .unwrap(); + let aggregation_param = test.aggregation_param.clone(); + let transcript = test.transcript.clone(); Box::pin(async move { tx.put_task(&task).await?; @@ -830,7 +890,7 @@ async fn taskprov_aggregate_share() { tx.put_batch(&Batch::<16, FixedSize, TestVdaf>::new( *task.id(), batch_id, - (), + aggregation_param.clone(), BatchState::Closed, 0, interval, @@ -841,12 +901,10 @@ async fn taskprov_aggregate_share() { tx.put_batch_aggregation(&BatchAggregation::<16, FixedSize, TestVdaf>::new( *task.id(), batch_id, - (), + aggregation_param, 0, BatchAggregationState::Aggregating, - Some(AggregateShare::from(OutputShare::from(Vec::from([ - Field64::from(7), - ])))), + Some(transcript.helper_aggregate_share), 1, interval, ReportIdChecksum::get_decoded(&[3; 32]).unwrap(), @@ -861,7 +919,7 @@ async fn taskprov_aggregate_share() { let request = AggregateShareReq::new( BatchSelector::new_fixed_size(batch_id), - ().get_encoded(), + test.aggregation_param.get_encoded(), 1, ReportIdChecksum::get_decoded(&[3; 32]).unwrap(), ); @@ -910,8 +968,6 @@ async fn taskprov_aggregate_share() { .run_async(&test.handler) .await; - println!("{:?}", test_conn); - assert_eq!(test_conn.status(), Some(Status::Ok)); assert_headers!( &test_conn, @@ -943,9 +999,14 @@ async fn end_to_end() { let aggregation_job_id = random(); let aggregation_job_init_request = AggregationJobInitializeReq::new( - ().get_encoded(), + test.aggregation_param.get_encoded(), PartialBatchSelector::new_fixed_size(batch_id), - Vec::from([test.report_share.clone()]), + Vec::from([PrepareInit::new( + test.report_share.clone(), + test.transcript.leader_prepare_transitions[0] + .message + .clone(), + )]), ); let mut test_conn = put(test @@ -970,23 +1031,25 @@ async fn end_to_end() { assert_headers!(&test_conn, "content-type" => (AggregationJobResp::MEDIA_TYPE)); let aggregation_job_resp: AggregationJobResp = decode_response_body(&mut test_conn).await; - assert_eq!(aggregation_job_resp.prepare_steps().len(), 1); - let prepare_step = &aggregation_job_resp.prepare_steps()[0]; - assert_eq!(prepare_step.report_id(), test.report_metadata.id()); - let encoded_prep_share = assert_matches!( - prepare_step.result(), - PrepareStepResult::Continued(payload) => payload.clone() + assert_eq!(aggregation_job_resp.prepare_resps().len(), 1); + let prepare_resp = &aggregation_job_resp.prepare_resps()[0]; + assert_eq!(prepare_resp.report_id(), test.report_metadata.id()); + let message = assert_matches!( + prepare_resp.result(), + PrepareStepResult::Continue { message } => message.clone() ); assert_eq!( - encoded_prep_share, - test.transcript.helper_prep_state(0).1.get_encoded() + message, + test.transcript.helper_prepare_transitions[0].message, ); let aggregation_job_continue_request = AggregationJobContinueReq::new( AggregationJobRound::from(1), - Vec::from([PrepareStep::new( + Vec::from([PrepareContinue::new( *test.report_metadata.id(), - PrepareStepResult::Continued(test.transcript.prepare_messages[0].get_encoded()), + test.transcript.leader_prepare_transitions[1] + .message + .clone(), )]), ); @@ -1013,15 +1076,15 @@ async fn end_to_end() { assert_headers!(&test_conn, "content-type" => (AggregationJobResp::MEDIA_TYPE)); let 
aggregation_job_resp: AggregationJobResp = decode_response_body(&mut test_conn).await; - assert_eq!(aggregation_job_resp.prepare_steps().len(), 1); - let prepare_step = &aggregation_job_resp.prepare_steps()[0]; - assert_eq!(prepare_step.report_id(), test.report_metadata.id()); - assert_matches!(prepare_step.result(), PrepareStepResult::Finished); + assert_eq!(aggregation_job_resp.prepare_resps().len(), 1); + let prepare_resp = &aggregation_job_resp.prepare_resps()[0]; + assert_eq!(prepare_resp.report_id(), test.report_metadata.id()); + assert_matches!(prepare_resp.result(), PrepareStepResult::Finished); let checksum = ReportIdChecksum::for_report_id(test.report_metadata.id()); let aggregate_share_request = AggregateShareReq::new( BatchSelector::new_fixed_size(batch_id), - ().get_encoded(), + test.aggregation_param.get_encoded(), 1, checksum, ); @@ -1056,5 +1119,8 @@ async fn end_to_end() { .get_encoded(), ) .unwrap(); - assert_eq!(plaintext, test.transcript.aggregate_shares[1].get_encoded()); + assert_eq!( + plaintext, + test.transcript.helper_aggregate_share.get_encoded() + ); } diff --git a/aggregator_api/Cargo.toml b/aggregator_api/Cargo.toml index 3db907dbd..60ba4da6f 100644 --- a/aggregator_api/Cargo.toml +++ b/aggregator_api/Cargo.toml @@ -11,7 +11,7 @@ version.workspace = true [dependencies] anyhow.workspace = true async-trait = "0.1" -base64 = "0.21.4" +base64.workspace = true janus_aggregator_core.workspace = true janus_core.workspace = true janus_messages.workspace = true diff --git a/aggregator_core/Cargo.toml b/aggregator_core/Cargo.toml index 7b35e2369..7d8a74154 100644 --- a/aggregator_core/Cargo.toml +++ b/aggregator_core/Cargo.toml @@ -16,7 +16,7 @@ test-util = ["dep:hex", "dep:sqlx", "dep:testcontainers", "janus_core/test-util" anyhow.workspace = true async-trait = "0.1" backoff = { version = "0.4.0", features = ["tokio"] } -base64 = "0.21.4" +base64.workspace = true bytes = "1.5.0" chrono = "0.4" deadpool = { version = "0.9.5", features = ["rt_tokio_1"] } diff --git a/aggregator_core/src/datastore.rs b/aggregator_core/src/datastore.rs index b864e6a2f..6e4d54dea 100644 --- a/aggregator_core/src/datastore.rs +++ b/aggregator_core/src/datastore.rs @@ -13,7 +13,6 @@ use crate::{ taskprov::{self, PeerAggregator}, SecretBytes, }; -use anyhow::anyhow; use chrono::NaiveDateTime; use futures::future::try_join_all; use janus_core::{ @@ -24,7 +23,7 @@ use janus_core::{ use janus_messages::{ query_type::{FixedSize, QueryType, TimeInterval}, AggregationJobId, BatchId, CollectionJobId, Duration, Extension, HpkeCiphertext, HpkeConfig, - HpkeConfigId, Interval, PrepareStep, Query, ReportId, ReportIdChecksum, ReportMetadata, + HpkeConfigId, Interval, PrepareResp, Query, ReportId, ReportIdChecksum, ReportMetadata, ReportShare, Role, TaskId, Time, }; use opentelemetry::{ @@ -34,6 +33,7 @@ use opentelemetry::{ use postgres_types::{FromSql, Json, ToSql}; use prio::{ codec::{decode_u16_items, encode_u16_items, CodecError, Decode, Encode, ParameterizedDecode}, + topology::ping_pong::PingPongTransition, vdaf, }; use rand::random; @@ -2097,6 +2097,7 @@ impl Transaction<'_, C> { role: &Role, task_id: &TaskId, aggregation_job_id: &AggregationJobId, + aggregation_param: &A::AggregationParam, report_id: &ReportId, ) -> Result>, Error> where @@ -2106,9 +2107,9 @@ impl Transaction<'_, C> { .prepare_cached( "SELECT report_aggregations.client_timestamp, report_aggregations.ord, - report_aggregations.state, report_aggregations.prep_state, - report_aggregations.prep_msg, report_aggregations.error_code, - 
report_aggregations.last_prep_step + report_aggregations.state, report_aggregations.helper_prep_state, + report_aggregations.leader_prep_transition, report_aggregations.error_code, + report_aggregations.last_prep_resp FROM report_aggregations JOIN aggregation_jobs ON aggregation_jobs.id = report_aggregations.aggregation_job_id JOIN tasks ON tasks.id = aggregation_jobs.task_id @@ -2162,8 +2163,8 @@ impl Transaction<'_, C> { "SELECT report_aggregations.client_report_id, report_aggregations.client_timestamp, report_aggregations.ord, report_aggregations.state, - report_aggregations.prep_state, report_aggregations.prep_msg, - report_aggregations.error_code, report_aggregations.last_prep_step + report_aggregations.helper_prep_state, report_aggregations.leader_prep_transition, + report_aggregations.error_code, report_aggregations.last_prep_resp FROM report_aggregations JOIN aggregation_jobs ON aggregation_jobs.id = report_aggregations.aggregation_job_id JOIN tasks ON tasks.id = aggregation_jobs.task_id @@ -2216,9 +2217,9 @@ impl Transaction<'_, C> { "SELECT aggregation_jobs.aggregation_job_id, report_aggregations.client_report_id, report_aggregations.client_timestamp, report_aggregations.ord, - report_aggregations.state, report_aggregations.prep_state, - report_aggregations.prep_msg, report_aggregations.error_code, - report_aggregations.last_prep_step + report_aggregations.state, report_aggregations.helper_prep_state, + report_aggregations.leader_prep_transition, report_aggregations.error_code, + report_aggregations.last_prep_resp FROM report_aggregations JOIN aggregation_jobs ON aggregation_jobs.id = report_aggregations.aggregation_job_id JOIN tasks ON tasks.id = aggregation_jobs.task_id @@ -2262,10 +2263,8 @@ impl Transaction<'_, C> { let time = Time::from_naive_date_time(&row.get("client_timestamp")); let ord: u64 = row.get_bigint_and_convert("ord")?; let state: ReportAggregationStateCode = row.get("state"); - let prep_state_bytes: Option> = row.get("prep_state"); - let prep_msg_bytes: Option> = row.get("prep_msg"); let error_code: Option = row.get("error_code"); - let last_prep_step_bytes: Option> = row.get("last_prep_step"); + let last_prep_resp_bytes: Option> = row.get("last_prep_resp"); let error_code = match error_code { Some(c) => { @@ -2279,31 +2278,49 @@ impl Transaction<'_, C> { None => None, }; - let last_prep_step = last_prep_step_bytes - .map(|bytes| PrepareStep::get_decoded(&bytes)) + let last_prep_resp = last_prep_resp_bytes + .map(|bytes| PrepareResp::get_decoded(&bytes)) .transpose()?; let agg_state = match state { ReportAggregationStateCode::Start => ReportAggregationState::Start, ReportAggregationStateCode::Waiting => { - let agg_index = role.index().ok_or_else(|| { - Error::User(anyhow!("unexpected role: {}", role.as_str()).into()) - })?; - let prep_state = A::PrepareState::get_decoded_with_param( - &(vdaf, agg_index), - &prep_state_bytes.ok_or_else(|| { - Error::DbState( - "report aggregation in state WAITING but prep_state is NULL" - .to_string(), - ) - })?, - )?; - let prep_msg = prep_msg_bytes - .map(|bytes| A::PrepareMessage::get_decoded_with_param(&prep_state, &bytes)) - .transpose()?; - - ReportAggregationState::Waiting(prep_state, prep_msg) + match role { + Role::Leader => { + let leader_prep_transition_bytes = row + .get::<_, Option>>("leader_prep_transition") + .ok_or_else(|| { + Error::DbState( + "report aggregation in state WAITING but leader_prep_transition is NULL" + .to_string(), + ) + })?; + let ping_pong_transition = 
PingPongTransition::get_decoded_with_param( + &(vdaf, 0 /* leader */), + &leader_prep_transition_bytes, + )?; + + ReportAggregationState::WaitingLeader(ping_pong_transition) + } + Role::Helper => { + let helper_prep_state_bytes = row + .get::<_, Option>>("helper_prep_state") + .ok_or_else(|| { + Error::DbState( + "report aggregation in state WAITING but helper_prep_state is NULL" + .to_string(), + ) + })?; + let prepare_state = A::PrepareState::get_decoded_with_param( + &(vdaf, 1 /* helper */), + &helper_prep_state_bytes, + )?; + + ReportAggregationState::WaitingHelper(prepare_state) + } + _ => panic!("unexpected role"), + } } ReportAggregationStateCode::Finished => ReportAggregationState::Finished, @@ -2323,7 +2340,7 @@ impl Transaction<'_, C> { *report_id, time, ord, - last_prep_step, + last_prep_resp, agg_state, )) } @@ -2341,15 +2358,15 @@ impl Transaction<'_, C> { A::PrepareState: Encode, { let encoded_state_values = report_aggregation.state().encoded_values_from_state(); - let encoded_last_prep_step = report_aggregation - .last_prep_step() - .map(PrepareStep::get_encoded); + let encoded_last_prep_resp: Option> = report_aggregation + .last_prep_resp() + .map(PrepareResp::get_encoded); let stmt = self .prepare_cached( "INSERT INTO report_aggregations - (aggregation_job_id, client_report_id, client_timestamp, ord, state, prep_state, - prep_msg, error_code, last_prep_step) + (aggregation_job_id, client_report_id, client_timestamp, ord, state, + helper_prep_state, leader_prep_transition, error_code, last_prep_resp) SELECT aggregation_jobs.id, $3, $4, $5, $6, $7, $8, $9, $10 FROM aggregation_jobs JOIN tasks ON tasks.id = aggregation_jobs.task_id @@ -2368,10 +2385,10 @@ impl Transaction<'_, C> { /* client_timestamp */ &report_aggregation.time().as_naive_date_time()?, /* ord */ &TryInto::::try_into(report_aggregation.ord())?, /* state */ &report_aggregation.state().state_code(), - /* prep_state */ &encoded_state_values.prep_state, - /* prep_msg */ &encoded_state_values.prep_msg, - /* error_code */ &encoded_state_values.report_share_err, - /* last_prep_step */ &encoded_last_prep_step, + /* helper_prep_state */ &encoded_state_values.helper_prep_state, + /* leader_prep_transition */ &encoded_state_values.leader_prep_transition, + /* error_code */ &encoded_state_values.prepare_err, + /* last_prep_resp */ &encoded_last_prep_resp, /* now */ &self.clock.now().as_naive_date_time()?, ], ) @@ -2391,14 +2408,14 @@ impl Transaction<'_, C> { A::PrepareState: Encode, { let encoded_state_values = report_aggregation.state().encoded_values_from_state(); - let encoded_last_prep_step = report_aggregation - .last_prep_step() - .map(PrepareStep::get_encoded); + let encoded_last_prep_resp: Option> = report_aggregation + .last_prep_resp() + .map(PrepareResp::get_encoded); let stmt = self .prepare_cached( "UPDATE report_aggregations - SET state = $1, prep_state = $2, prep_msg = $3, error_code = $4, last_prep_step = $5 + SET state = $1, helper_prep_state = $2, leader_prep_transition = $3, error_code = $4, last_prep_resp = $5 FROM aggregation_jobs, tasks WHERE report_aggregations.aggregation_job_id = aggregation_jobs.id AND aggregation_jobs.task_id = tasks.id @@ -2416,10 +2433,11 @@ impl Transaction<'_, C> { &[ /* state */ &report_aggregation.state().state_code(), - /* prep_state */ &encoded_state_values.prep_state, - /* prep_msg */ &encoded_state_values.prep_msg, - /* error_code */ &encoded_state_values.report_share_err, - /* last_prep_step */ &encoded_last_prep_step, + /* helper_prep_state */ 
&encoded_state_values.helper_prep_state, + /* leader_prep_transition */ + &encoded_state_values.leader_prep_transition, + /* error_code */ &encoded_state_values.prepare_err, + /* last_prep_resp */ &encoded_last_prep_resp, /* aggregation_job_id */ &report_aggregation.aggregation_job_id().as_ref(), /* task_id */ &report_aggregation.task_id().as_ref(), @@ -4427,7 +4445,7 @@ impl Transaction<'_, C> { let stmt = self .prepare_cached( - "SELECT (SELECT p.id FROM taskprov_peer_aggregators AS p + "SELECT (SELECT p.id FROM taskprov_peer_aggregators AS p WHERE p.id = a.peer_aggregator_id) AS peer_id, ord, type, token FROM taskprov_aggregator_auth_tokens AS a ORDER BY ord ASC", @@ -4437,7 +4455,7 @@ impl Transaction<'_, C> { let stmt = self .prepare_cached( - "SELECT (SELECT p.id FROM taskprov_peer_aggregators AS p + "SELECT (SELECT p.id FROM taskprov_peer_aggregators AS p WHERE p.id = a.peer_aggregator_id) AS peer_id, ord, type, token FROM taskprov_collector_auth_tokens AS a ORDER BY ord ASC", diff --git a/aggregator_core/src/datastore/models.rs b/aggregator_core/src/datastore/models.rs index a1092c967..e11f4cf62 100644 --- a/aggregator_core/src/datastore/models.rs +++ b/aggregator_core/src/datastore/models.rs @@ -13,8 +13,8 @@ use janus_core::{ use janus_messages::{ query_type::{FixedSize, QueryType, TimeInterval}, AggregationJobId, AggregationJobRound, BatchId, CollectionJobId, Duration, Extension, - HpkeCiphertext, Interval, PrepareStep, Query, ReportId, ReportIdChecksum, ReportMetadata, - ReportShareError, Role, TaskId, Time, + HpkeCiphertext, Interval, PrepareError, PrepareResp, Query, ReportId, ReportIdChecksum, + ReportMetadata, Role, TaskId, Time, }; use postgres_protocol::types::{ range_from_sql, range_to_sql, timestamp_from_sql, timestamp_to_sql, Range, RangeBound, @@ -22,6 +22,7 @@ use postgres_protocol::types::{ use postgres_types::{accepts, to_sql_checked, FromSql, ToSql}; use prio::{ codec::Encode, + topology::ping_pong::PingPongTransition, vdaf::{self, Aggregatable}, }; use rand::{distributions::Standard, prelude::Distribution}; @@ -588,7 +589,7 @@ pub struct ReportAggregation<const SEED_SIZE: usize, A: vdaf::Aggregator<SEED_SIZE, 16>> ord: u64, - last_prep_step: Option<PrepareStep>, + last_prep_resp: Option<PrepareResp>, state: ReportAggregationState<SEED_SIZE, A>, } @@ -600,7 +601,7 @@ impl<const SEED_SIZE: usize, A: vdaf::Aggregator<SEED_SIZE, 16>> ReportAggregati report_id: ReportId, time: Time, ord: u64, - last_prep_step: Option<PrepareStep>, + last_prep_resp: Option<PrepareResp>, state: ReportAggregationState<SEED_SIZE, A>, ) -> Self { Self { @@ -609,7 +610,7 @@ impl<const SEED_SIZE: usize, A: vdaf::Aggregator<SEED_SIZE, 16>> ReportAggregati report_id, time, ord, - last_prep_step, + last_prep_resp, state, } } @@ -644,16 +645,16 @@ impl<const SEED_SIZE: usize, A: vdaf::Aggregator<SEED_SIZE, 16>> ReportAggregati self.ord } - /// Returns the last preparation step returned by the Helper, if any. - pub fn last_prep_step(&self) -> Option<&PrepareStep> { - self.last_prep_step.as_ref() + /// Returns the last preparation response returned by the Helper, if any. + pub fn last_prep_resp(&self) -> Option<&PrepareResp> { + self.last_prep_resp.as_ref() } /// Returns a new [`ReportAggregation`] corresponding to this report aggregation updated to - /// have the given last preparation step. + /// have the given last preparation response.
+ pub fn with_last_prep_resp(self, last_prep_resp: Option<PrepareResp>) -> Self { Self { - last_prep_step, + last_prep_resp, ..self } } @@ -688,7 +689,7 @@ where && self.report_id == other.report_id && self.time == other.time && self.ord == other.ord - && self.last_prep_step == other.last_prep_step + && self.last_prep_resp == other.last_prep_resp && self.state == other.state } } @@ -709,16 +710,19 @@ where /// ReportAggregationState represents the state of a single report aggregation. It corresponds /// to the REPORT_AGGREGATION_STATE enum in the schema, along with the state-specific data. -#[derive(Clone, Derivative)] -#[derivative(Debug)] +#[derive(Clone, Debug, Derivative)] pub enum ReportAggregationState<const SEED_SIZE: usize, A: vdaf::Aggregator<SEED_SIZE, 16>> { Start, - Waiting( - #[derivative(Debug = "ignore")] A::PrepareState, - #[derivative(Debug = "ignore")] Option<A::PrepareMessage>, + WaitingLeader( + /// Most recent transition for this report aggregation. + PingPongTransition<SEED_SIZE, 16, A>, + ), + WaitingHelper( + /// Helper's current preparation state. + A::PrepareState, ), Finished, - Failed(ReportShareError), + Failed(PrepareError), } impl<const SEED_SIZE: usize, A: vdaf::Aggregator<SEED_SIZE, 16>> @@ -727,7 +731,9 @@ impl<const SEED_SIZE: usize, A: vdaf::Aggregator<SEED_SIZE, 16>> pub fn state_code(&self) -> ReportAggregationStateCode { match self { ReportAggregationState::Start => ReportAggregationStateCode::Start, - ReportAggregationState::Waiting(_, _) => ReportAggregationStateCode::Waiting, + ReportAggregationState::WaitingLeader(_) | ReportAggregationState::WaitingHelper(_) => { + ReportAggregationStateCode::Waiting + } ReportAggregationState::Finished => ReportAggregationStateCode::Finished, ReportAggregationState::Failed(_) => ReportAggregationStateCode::Failed, } @@ -742,29 +748,32 @@ impl<const SEED_SIZE: usize, A: vdaf::Aggregator<SEED_SIZE, 16>> { match self { ReportAggregationState::Start => EncodedReportAggregationStateValues::default(), - ReportAggregationState::Waiting(prep_state, prep_msg) => { + ReportAggregationState::WaitingLeader(transition) => { EncodedReportAggregationStateValues { - prep_state: Some(prep_state.get_encoded()), - prep_msg: prep_msg.as_ref().map(Encode::get_encoded), + leader_prep_transition: Some(transition.get_encoded()), ..Default::default() } } - ReportAggregationState::Finished => EncodedReportAggregationStateValues::default(), - ReportAggregationState::Failed(report_share_err) => { + ReportAggregationState::WaitingHelper(prepare_state) => { EncodedReportAggregationStateValues { - report_share_err: Some(*report_share_err as i16), + helper_prep_state: Some(prepare_state.get_encoded()), ..Default::default() } } + ReportAggregationState::Finished => EncodedReportAggregationStateValues::default(), + ReportAggregationState::Failed(prepare_err) => EncodedReportAggregationStateValues { + prepare_err: Some(*prepare_err as i16), + ..Default::default() + }, } } } #[derive(Default)] pub(super) struct EncodedReportAggregationStateValues { - pub(super) prep_state: Option<Vec<u8>>, - pub(super) prep_msg: Option<Vec<u8>>, - pub(super) report_share_err: Option<i16>, + pub(super) leader_prep_transition: Option<Vec<u8>>, + pub(super) helper_prep_state: Option<Vec<u8>>, + pub(super) prepare_err: Option<i16>, } // The private ReportAggregationStateCode exists alongside the public ReportAggregationState @@ -791,19 +800,19 @@ pub enum ReportAggregationStateCode { impl<const SEED_SIZE: usize, A: vdaf::Aggregator<SEED_SIZE, 16>> PartialEq for ReportAggregationState<SEED_SIZE, A> where - A::PrepareState: PartialEq, - A::PrepareMessage: PartialEq, A::PrepareShare: PartialEq, A::OutputShare: PartialEq, { fn eq(&self, other: &Self) -> bool { match (self, other) { - ( - Self::Waiting(lhs_prep_state, lhs_prep_msg), - Self::Waiting(rhs_prep_state, rhs_prep_msg), - ) => lhs_prep_state == rhs_prep_state && lhs_prep_msg == rhs_prep_msg, -
(Self::Failed(lhs_report_share_err), Self::Failed(rhs_report_share_err)) => { - lhs_report_share_err == rhs_report_share_err + (Self::WaitingLeader(lhs_transition), Self::WaitingLeader(rhs_transition)) => { + lhs_transition == rhs_transition + } + (Self::WaitingHelper(lhs_state), Self::WaitingHelper(rhs_state)) => { + lhs_state == rhs_state + } + (Self::Failed(lhs_prepare_err), Self::Failed(rhs_prepare_err)) => { + lhs_prepare_err == rhs_prepare_err } _ => core::mem::discriminant(self) == core::mem::discriminant(other), } diff --git a/aggregator_core/src/datastore/tests.rs b/aggregator_core/src/datastore/tests.rs index 6d2702bcd..f57bc2047 100644 --- a/aggregator_core/src/datastore/tests.rs +++ b/aggregator_core/src/datastore/tests.rs @@ -34,12 +34,18 @@ use janus_messages::{ query_type::{FixedSize, QueryType, TimeInterval}, AggregateShareAad, AggregationJobId, AggregationJobRound, BatchId, BatchSelector, CollectionJobId, Duration, Extension, ExtensionType, FixedSizeQuery, HpkeCiphertext, - HpkeConfigId, Interval, PrepareStep, PrepareStepResult, Query, ReportId, ReportIdChecksum, - ReportMetadata, ReportShare, ReportShareError, Role, TaskId, Time, + HpkeConfigId, Interval, PrepareError, PrepareResp, PrepareStepResult, Query, ReportId, + ReportIdChecksum, ReportMetadata, ReportShare, Role, TaskId, Time, }; use prio::{ codec::{Decode, Encode}, - vdaf::prio3::{Prio3, Prio3Count}, + idpf::IdpfInput, + topology::ping_pong::PingPongMessage, + vdaf::{ + poplar1::{Poplar1, Poplar1AggregationParam}, + prio3::Prio3Count, + xof::XofShake128, + }, }; use rand::{distributions::Standard, random, thread_rng, Rng}; use std::{ @@ -1876,20 +1882,49 @@ async fn roundtrip_report_aggregation(ephemeral_datastore: EphemeralDatastore) { install_test_trace_subscriber(); let report_id = random(); - let vdaf = Arc::new(Prio3::new_count(2).unwrap()); + let vdaf = Arc::new(Poplar1::new_shake128(1)); let verify_key: [u8; VERIFY_KEY_LENGTH] = random(); - let vdaf_transcript = run_vdaf(vdaf.as_ref(), &verify_key, &(), &report_id, &0); - let leader_prep_state = vdaf_transcript.leader_prep_state(0); - - for (ord, state) in [ - ReportAggregationState::<VERIFY_KEY_LENGTH, Prio3Count>::Start, - ReportAggregationState::Waiting( - leader_prep_state.clone(), - Some(vdaf_transcript.prepare_messages[0].clone()), + let aggregation_param = + Poplar1AggregationParam::try_from_prefixes(Vec::from([IdpfInput::from_bools(&[false])])) .unwrap(); + let vdaf_transcript = run_vdaf( + vdaf.as_ref(), + &verify_key, + &aggregation_param, + &report_id, + &IdpfInput::from_bools(&[false]), + ); + + for (ord, (role, state)) in [ + (Role::Leader, ReportAggregationState::Start), + (Role::Helper, ReportAggregationState::Start), + ( + Role::Leader, + ReportAggregationState::WaitingLeader( + vdaf_transcript.leader_prepare_transitions[1] + .transition + .clone() + .unwrap(), + ), + ), + ( + Role::Helper, + ReportAggregationState::WaitingHelper( + vdaf_transcript.helper_prepare_transitions[0] + .prepare_state() + .clone(), + ), + ), + (Role::Leader, ReportAggregationState::Finished), + (Role::Helper, ReportAggregationState::Finished), + ( + Role::Leader, + ReportAggregationState::Failed(PrepareError::VdafPrepError), + ), + ( + Role::Helper, + ReportAggregationState::Failed(PrepareError::VdafPrepError), ), - ReportAggregationState::Waiting(leader_prep_state.clone(), None), - ReportAggregationState::Finished, - ReportAggregationState::Failed(ReportShareError::VdafPrepError), ] .into_iter() .enumerate() @@ -1899,8 +1934,8 @@ async fn roundtrip_report_aggregation(ephemeral_datastore:
EphemeralDatastore) { let task = TaskBuilder::new( task::QueryType::TimeInterval, - VdafInstance::Prio3Count, - Role::Leader, + VdafInstance::Poplar1 { bits: 1 }, + role, ) .with_report_expiry_age(Some(REPORT_EXPIRY_AGE)) .build(); @@ -1909,17 +1944,18 @@ async fn roundtrip_report_aggregation(ephemeral_datastore: EphemeralDatastore) { let want_report_aggregation = ds .run_tx(|tx| { - let (task, state) = (task.clone(), state.clone()); + let (task, state, aggregation_param) = + (task.clone(), state.clone(), aggregation_param.clone()); Box::pin(async move { tx.put_task(&task).await?; tx.put_aggregation_job(&AggregationJob::< VERIFY_KEY_LENGTH, TimeInterval, - Prio3Count, + Poplar1<XofShake128, 16>, >::new( *task.id(), aggregation_job_id, - (), + aggregation_param, (), Interval::new(OLDEST_ALLOWED_REPORT_TIMESTAMP, Duration::from_seconds(1)) .unwrap(), @@ -1947,9 +1983,14 @@ async fn roundtrip_report_aggregation(ephemeral_datastore: EphemeralDatastore) { report_id, OLDEST_ALLOWED_REPORT_TIMESTAMP, ord.try_into().unwrap(), - Some(PrepareStep::new( + Some(PrepareResp::new( report_id, - PrepareStepResult::Continued(format!("prep_msg_{ord}").into()), + PrepareStepResult::Continue { + message: PingPongMessage::Continue { + prep_msg: format!("prep_msg_{ord}").into(), + prep_share: format!("prep_share_{ord}").into(), + }, + }, )), state, ); @@ -1965,13 +2006,15 @@ async fn roundtrip_report_aggregation(ephemeral_datastore: EphemeralDatastore) { let got_report_aggregation = ds .run_tx(|tx| { - let (vdaf, task) = (Arc::clone(&vdaf), task.clone()); + let (vdaf, task, aggregation_param) = + (Arc::clone(&vdaf), task.clone(), aggregation_param.clone()); Box::pin(async move { tx.get_report_aggregation( vdaf.as_ref(), - &Role::Leader, + &role, task.id(), &aggregation_job_id, + &aggregation_param, &report_id, ) .await @@ -1989,9 +2032,14 @@ async fn roundtrip_report_aggregation(ephemeral_datastore: EphemeralDatastore) { *want_report_aggregation.report_id(), *want_report_aggregation.time(), want_report_aggregation.ord(), - Some(PrepareStep::new( + Some(PrepareResp::new( report_id, - PrepareStepResult::Continued(format!("updated_prep_msg_{ord}").into()), + PrepareStepResult::Continue { + message: PingPongMessage::Continue { + prep_msg: format!("updated_prep_msg_{ord}").into(), + prep_share: format!("updated_prep_share_{ord}").into(), + }, + }, )), want_report_aggregation.state().clone(), ); @@ -2005,13 +2053,15 @@ async fn roundtrip_report_aggregation(ephemeral_datastore: EphemeralDatastore) { let got_report_aggregation = ds .run_tx(|tx| { - let (vdaf, task) = (Arc::clone(&vdaf), task.clone()); + let (vdaf, task, aggregation_param) = + (Arc::clone(&vdaf), task.clone(), aggregation_param.clone()); Box::pin(async move { tx.get_report_aggregation( vdaf.as_ref(), - &Role::Leader, + &role, task.id(), &aggregation_job_id, + &aggregation_param, &report_id, ) .await @@ -2026,13 +2076,15 @@ async fn roundtrip_report_aggregation(ephemeral_datastore: EphemeralDatastore) { let got_report_aggregation = ds .run_tx(|tx| { - let (vdaf, task) = (Arc::clone(&vdaf), task.clone()); + let (vdaf, task, aggregation_param) = + (Arc::clone(&vdaf), task.clone(), aggregation_param.clone()); Box::pin(async move { tx.get_report_aggregation( vdaf.as_ref(), - &Role::Leader, + &role, task.id(), &aggregation_job_id, + &aggregation_param, &report_id, ) .await @@ -2214,6 +2266,7 @@ async fn report_aggregation_not_found(ephemeral_datastore: EphemeralDatastore) { &Role::Leader, &random(), &random(), + &dummy_vdaf::AggregationParam::default(), &ReportId::from([1, 2,
3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]), ) .await @@ -2233,7 +2286,7 @@ async fn report_aggregation_not_found(ephemeral_datastore: EphemeralDatastore) { Time::from_seconds_since_epoch(12345), 0, None, - ReportAggregationState::Failed(ReportShareError::VdafPrepError), + ReportAggregationState::Failed(PrepareError::VdafPrepError), )) .await }) @@ -2251,14 +2304,24 @@ async fn get_report_aggregations_for_aggregation_job(ephemeral_datastore: Epheme let ds = ephemeral_datastore.datastore(clock.clone()).await; let report_id = random(); - let vdaf = Arc::new(Prio3::new_count(2).unwrap()); + let vdaf = Arc::new(Poplar1::new_shake128(1)); let verify_key: [u8; VERIFY_KEY_LENGTH] = random(); - let vdaf_transcript = run_vdaf(vdaf.as_ref(), &verify_key, &(), &report_id, &0); + let aggregation_param = + Poplar1AggregationParam::try_from_prefixes(Vec::from([IdpfInput::from_bools(&[false])])) .unwrap(); + + let vdaf_transcript = run_vdaf( + vdaf.as_ref(), + &verify_key, + &aggregation_param, + &report_id, + &IdpfInput::from_bools(&[false]), + ); let task = TaskBuilder::new( task::QueryType::TimeInterval, - VdafInstance::Prio3Count, - Role::Leader, + VdafInstance::Poplar1 { bits: 1 }, + Role::Helper, ) .with_report_expiry_age(Some(REPORT_EXPIRY_AGE)) .build(); @@ -2266,35 +2329,41 @@ async fn get_report_aggregations_for_aggregation_job(ephemeral_datastore: Epheme let want_report_aggregations = ds .run_tx(|tx| { - let (task, prep_msg, prep_state) = ( + let (task, vdaf_transcript, aggregation_param) = ( task.clone(), - vdaf_transcript.prepare_messages[0].clone(), - vdaf_transcript.leader_prep_state(0).clone(), + vdaf_transcript.clone(), + aggregation_param.clone(), ); Box::pin(async move { - tx.put_task(&task).await?; + tx.put_task(&task).await.unwrap(); + tx.put_aggregation_job(&AggregationJob::< VERIFY_KEY_LENGTH, TimeInterval, - Prio3Count, + Poplar1<XofShake128, 16>, >::new( *task.id(), aggregation_job_id, - (), + aggregation_param, (), Interval::new(OLDEST_ALLOWED_REPORT_TIMESTAMP, Duration::from_seconds(1)) .unwrap(), AggregationJobState::InProgress, AggregationJobRound::from(0), )) - .await?; + .await + .unwrap(); let mut want_report_aggregations = Vec::new(); for (ord, state) in [ - ReportAggregationState::<VERIFY_KEY_LENGTH, Prio3Count>::Start, - ReportAggregationState::Waiting(prep_state.clone(), Some(prep_msg)), + ReportAggregationState::Start, + ReportAggregationState::WaitingHelper( + vdaf_transcript.helper_prepare_transitions[0] + .prepare_state() + .clone(), + ), ReportAggregationState::Finished, - ReportAggregationState::Failed(ReportShareError::VdafPrepError), + ReportAggregationState::Failed(PrepareError::VdafPrepError), ] .iter() .enumerate() @@ -2312,7 +2381,8 @@ async fn get_report_aggregations_for_aggregation_job(ephemeral_datastore: Epheme ), ), ) - .await?; + .await + .unwrap(); let report_aggregation = ReportAggregation::new( *task.id(), aggregation_job_id, report_id, OLDEST_ALLOWED_REPORT_TIMESTAMP, ord.try_into().unwrap(), - Some(PrepareStep::new(report_id, PrepareStepResult::Finished)), + Some(PrepareResp::new(report_id, PrepareStepResult::Finished)), state.clone(), ); - tx.put_report_aggregation(&report_aggregation).await?; + tx.put_report_aggregation(&report_aggregation) + .await + .unwrap(); want_report_aggregations.push(report_aggregation); } Ok(want_report_aggregations) @@ -2341,7 +2413,7 @@ async fn get_report_aggregations_for_aggregation_job(ephemeral_datastore: Epheme Box::pin(async move {
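// Read the report aggregations back as the Helper: a WaitingHelper report
// aggregation round-trips through the helper_prep_state column and is decoded
// with the helper's aggregator index.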
tx.get_report_aggregations_for_aggregation_job( vdaf.as_ref(), - &Role::Leader, + &Role::Helper, task.id(), &aggregation_job_id, ) @@ -2361,7 +2433,7 @@ async fn get_report_aggregations_for_aggregation_job(ephemeral_datastore: Epheme Box::pin(async move { tx.get_report_aggregations_for_aggregation_job( vdaf.as_ref(), - &Role::Leader, + &Role::Helper, task.id(), &aggregation_job_id, ) @@ -2836,7 +2908,6 @@ async fn setup_collection_job_acquire_test_case( tx.put_client_report(&dummy_vdaf::Vdaf::new(), report) .await?; } - for aggregation_job in &test_case.aggregation_jobs { tx.put_aggregation_job(aggregation_job).await?; } @@ -4766,14 +4837,27 @@ async fn roundtrip_outstanding_batch(ephemeral_datastore: EphemeralDatastore) { None, ReportAggregationState::Start, // Counted among max_size. ); + + let report_id_0_1 = random(); + let transcript = run_vdaf( + &dummy_vdaf::Vdaf::new(), + task_1.primary_vdaf_verify_key().unwrap().as_bytes(), + &dummy_vdaf::AggregationParam(0), + &report_id_0_1, + &(), + ); + let report_aggregation_0_1 = ReportAggregation::<0, dummy_vdaf::Vdaf>::new( *task_1.id(), *aggregation_job_0.id(), - random(), + report_id_0_1, clock.now(), 1, None, - ReportAggregationState::Waiting(dummy_vdaf::PrepareState::default(), Some(())), // Counted among max_size. + // Counted among max_size. + ReportAggregationState::WaitingLeader( + transcript.helper_prepare_transitions[0].transition.clone(), + ), ); let report_aggregation_0_2 = ReportAggregation::<0, dummy_vdaf::Vdaf>::new( *task_1.id(), @@ -4782,7 +4866,7 @@ async fn roundtrip_outstanding_batch(ephemeral_datastore: EphemeralDatastore) { clock.now(), 2, None, - ReportAggregationState::Failed(ReportShareError::VdafPrepError), // Not counted among min_size or max_size. + ReportAggregationState::Failed(PrepareError::VdafPrepError), // Not counted among min_size or max_size. ); let aggregation_job_1 = AggregationJob::<0, FixedSize, dummy_vdaf::Vdaf>::new( @@ -4820,7 +4904,7 @@ async fn roundtrip_outstanding_batch(ephemeral_datastore: EphemeralDatastore) { clock.now(), 2, None, - ReportAggregationState::Failed(ReportShareError::VdafPrepError), // Not counted among min_size or max_size. + ReportAggregationState::Failed(PrepareError::VdafPrepError), // Not counted among min_size or max_size. 
); let aggregation_job_2 = AggregationJob::<0, FixedSize, dummy_vdaf::Vdaf>::new( diff --git a/collector/src/lib.rs b/collector/src/lib.rs index e5ea026ce..3d2df9f3e 100644 --- a/collector/src/lib.rs +++ b/collector/src/lib.rs @@ -705,14 +705,14 @@ mod tests { hpke::seal( &parameters.hpke_config, &HpkeApplicationInfo::new(&Label::AggregateShare, &Role::Leader, &Role::Collector), - &transcript.aggregate_shares[0].get_encoded(), + &transcript.leader_aggregate_share.get_encoded(), &associated_data.get_encoded(), ) .unwrap(), hpke::seal( &parameters.hpke_config, &HpkeApplicationInfo::new(&Label::AggregateShare, &Role::Helper, &Role::Collector), - &transcript.aggregate_shares[1].get_encoded(), + &transcript.helper_aggregate_share.get_encoded(), &associated_data.get_encoded(), ) .unwrap(), @@ -733,14 +733,14 @@ mod tests { hpke::seal( &parameters.hpke_config, &HpkeApplicationInfo::new(&Label::AggregateShare, &Role::Leader, &Role::Collector), - &transcript.aggregate_shares[0].get_encoded(), + &transcript.leader_aggregate_share.get_encoded(), &associated_data.get_encoded(), ) .unwrap(), hpke::seal( &parameters.hpke_config, &HpkeApplicationInfo::new(&Label::AggregateShare, &Role::Helper, &Role::Collector), - &transcript.aggregate_shares[1].get_encoded(), + &transcript.helper_aggregate_share.get_encoded(), &associated_data.get_encoded(), ) .unwrap(), diff --git a/core/Cargo.toml b/core/Cargo.toml index d88743a1a..4cd73cf0a 100644 --- a/core/Cargo.toml +++ b/core/Cargo.toml @@ -34,7 +34,7 @@ test-util = [ anyhow.workspace = true assert_matches = { version = "1", optional = true } backoff = { version = "0.4.0", features = ["tokio"] } -base64 = "0.21.4" +base64.workspace = true chrono = { workspace = true, features = ["clock"] } derivative = "2.2.0" futures = "0.3.28" diff --git a/core/src/test_util/mod.rs b/core/src/test_util/mod.rs index 44343761f..1fff7e7dc 100644 --- a/core/src/test_util/mod.rs +++ b/core/src/test_util/mod.rs @@ -1,6 +1,12 @@ use assert_matches::assert_matches; use janus_messages::{ReportId, Role}; -use prio::vdaf::{self, PrepareTransition, VdafError}; +use prio::{ + topology::ping_pong::{ + PingPongContinuedValue, PingPongMessage, PingPongState, PingPongTopology, + PingPongTransition, + }, + vdaf, +}; use serde::{de::DeserializeOwned, Serialize}; use std::{fmt::Debug, sync::Once}; use tracing_log::LogTracer; @@ -11,126 +17,219 @@ pub mod kubernetes; pub mod runtime; pub mod testcontainers; -/// A transcript of a VDAF run. All fields are indexed by natural role index (i.e., index 0 = -/// leader, index 1 = helper). #[derive(Clone, Debug)] -pub struct VdafTranscript<const SEED_SIZE: usize, V: vdaf::Aggregator<SEED_SIZE, 16>> { - /// The public share, from the sharding algorithm. - pub public_share: V::PublicShare, - /// The measurement's input shares, from the sharding algorithm. - pub input_shares: Vec<V::InputShare>, - /// Prepare transitions sent throughout the protocol run. The outer `Vec` is indexed by - /// aggregator, and the inner `Vec`s are indexed by VDAF round. - prepare_transitions: Vec<Vec<PrepareTransition<V, SEED_SIZE, 16>>>, - /// The prepare messages broadcast to all aggregators prior to each continuation round of the - /// VDAF. - pub prepare_messages: Vec<V::PrepareMessage>, - /// The output shares computed by each aggregator. - output_shares: Vec<V::OutputShare>, - /// The aggregate shares from each aggregator. - pub aggregate_shares: Vec<V::AggregateShare>, +pub struct LeaderPrepareTransition< + const VERIFY_KEY_LENGTH: usize, + V: vdaf::Aggregator<VERIFY_KEY_LENGTH, 16>, +> { + pub transition: Option<PingPongTransition<VERIFY_KEY_LENGTH, 16, V>>, + pub state: PingPongState<VERIFY_KEY_LENGTH, 16, V>, + pub message: PingPongMessage, } -impl<const SEED_SIZE: usize, V: vdaf::Aggregator<SEED_SIZE, 16>> VdafTranscript<SEED_SIZE, V> { - /// Get the leader's preparation state at the requested round.
- pub fn leader_prep_state(&self, round: usize) -> &V::PrepareState { - assert_matches!( - &self.prepare_transitions[Role::Leader.index().unwrap()][round], - PrepareTransition::<V, SEED_SIZE, 16>::Continue(prep_state, _) => prep_state - ) - } +#[derive(Clone, Debug)] +pub struct HelperPrepareTransition< + const VERIFY_KEY_LENGTH: usize, + V: vdaf::Aggregator<VERIFY_KEY_LENGTH, 16>, +> { + pub transition: PingPongTransition<VERIFY_KEY_LENGTH, 16, V>, + pub state: PingPongState<VERIFY_KEY_LENGTH, 16, V>, + pub message: PingPongMessage, } - /// Get the helper's preparation state and prepare share at the requested round. - pub fn helper_prep_state(&self, round: usize) -> (&V::PrepareState, &V::PrepareShare) { - assert_matches!( - &self.prepare_transitions[Role::Helper.index().unwrap()][round], - PrepareTransition::<V, SEED_SIZE, 16>::Continue(prep_state, prep_share) => (prep_state, prep_share) - ) + } +impl<const VERIFY_KEY_LENGTH: usize, V: vdaf::Aggregator<VERIFY_KEY_LENGTH, 16>> + HelperPrepareTransition<VERIFY_KEY_LENGTH, V> +{ + pub fn prepare_state(&self) -> &V::PrepareState { + assert_matches!(self.state, PingPongState::Continued(ref state) => state) } +} - /// Get the output share for the specified aggregator. - pub fn output_share(&self, role: Role) -> &V::OutputShare { - &self.output_shares[role.index().unwrap()] - } +/// A transcript of a VDAF run using the ping-pong VDAF topology. +#[derive(Clone, Debug)] +pub struct VdafTranscript< + const VERIFY_KEY_LENGTH: usize, + V: vdaf::Aggregator<VERIFY_KEY_LENGTH, 16>, +> { + /// The public share, from the sharding algorithm. + pub public_share: V::PublicShare, + /// The leader's input share, from the sharding algorithm. + pub leader_input_share: V::InputShare, + + /// The helper's input share, from the sharding algorithm. + pub helper_input_share: V::InputShare, + + /// The leader's states and messages computed throughout the protocol run. Indexed by the + /// aggregation job round. + #[allow(clippy::type_complexity)] + pub leader_prepare_transitions: Vec<LeaderPrepareTransition<VERIFY_KEY_LENGTH, V>>, + + /// The helper's states and messages computed throughout the protocol run. Indexed by the + /// aggregation job round. + #[allow(clippy::type_complexity)] + pub helper_prepare_transitions: Vec<HelperPrepareTransition<VERIFY_KEY_LENGTH, V>>, + + /// The leader's computed output share. + pub leader_output_share: V::OutputShare, + + /// The helper's computed output share. + pub helper_output_share: V::OutputShare, + + /// The leader's aggregate share. + pub leader_aggregate_share: V::AggregateShare, + + /// The helper's aggregate share. + pub helper_aggregate_share: V::AggregateShare, } /// run_vdaf runs a VDAF state machine from sharding through to generating an output share, /// returning a "transcript" of all states & messages. -pub fn run_vdaf<const SEED_SIZE: usize, V: vdaf::Aggregator<SEED_SIZE, 16> + vdaf::Client<16>>( +pub fn run_vdaf< + const VERIFY_KEY_LENGTH: usize, + V: vdaf::Aggregator<VERIFY_KEY_LENGTH, 16> + vdaf::Client<16>, +>( vdaf: &V, - verify_key: &[u8; SEED_SIZE], + verify_key: &[u8; VERIFY_KEY_LENGTH], aggregation_param: &V::AggregationParam, report_id: &ReportId, measurement: &V::Measurement, -) -> VdafTranscript<SEED_SIZE, V> { +) -> VdafTranscript<VERIFY_KEY_LENGTH, V> { + let mut leader_prepare_transitions = Vec::new(); + let mut helper_prepare_transitions = Vec::new(); + // Shard inputs into input shares, and initialize the initial PrepareTransitions.
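// The exchange below follows prio's ping-pong topology: `leader_initialized`
// produces the leader's first PingPongMessage; the helper consumes it in
// `helper_initialized`, which yields a PingPongTransition; `evaluate` turns a
// transition into the resulting PingPongState plus the outbound PingPongMessage.
// The aggregators then alternate `leader_continued`/`helper_continued` until both
// reach PingPongState::Finished and hold output shares.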
let (public_share, input_shares) = vdaf.shard(measurement, report_id.as_ref()).unwrap(); - let mut prep_trans: Vec<Vec<PrepareTransition<V, SEED_SIZE, 16>>> = input_shares - .iter() - .enumerate() - .map(|(agg_id, input_share)| { - let (prep_state, prep_share) = vdaf.prepare_init( - verify_key, - agg_id, - aggregation_param, - report_id.as_ref(), - &public_share, - input_share, - )?; - Ok(Vec::from([PrepareTransition::Continue( - prep_state, prep_share, - )])) - }) - .collect::<Result<Vec<Vec<PrepareTransition<V, SEED_SIZE, 16>>>, VdafError>>() + + let (leader_state, leader_message) = vdaf + .leader_initialized( + verify_key, + aggregation_param, + report_id.as_ref(), + &public_share, + &input_shares[0], + ) + .unwrap(); + + leader_prepare_transitions.push(LeaderPrepareTransition { + transition: None, + state: leader_state, + message: leader_message.clone(), + }); + + let helper_transition = vdaf + .helper_initialized( + verify_key, + aggregation_param, + report_id.as_ref(), + &public_share, + &input_shares[1], + &leader_message, + ) .unwrap(); - let mut prep_msgs = Vec::new(); + let (helper_state, helper_message) = helper_transition.evaluate(vdaf).unwrap(); + + helper_prepare_transitions.push(HelperPrepareTransition { + transition: helper_transition, + state: helper_state, + message: helper_message, + }); - // Repeatedly step the VDAF until we reach a terminal state. + // Repeatedly step the VDAF until we reach a terminal state. + let mut leader_output_share = None; + let mut helper_output_share = None; loop { - // Gather messages from last round & combine them into next round's message; if any - // participants have reached a terminal state (Finish or Fail), we are done. - let mut prep_shares = Vec::new(); - let mut agg_shares = Vec::new(); - let mut output_shares = Vec::new(); - for pts in &prep_trans { - match pts.last().unwrap() { - PrepareTransition::<V, SEED_SIZE, 16>::Continue(_, prep_share) => { - prep_shares.push(prep_share.clone()) - } - PrepareTransition::Finish(output_share) => { - output_shares.push(output_share.clone()); - agg_shares.push( - vdaf.aggregate(aggregation_param, [output_share.clone()].into_iter()) + for role in [Role::Leader, Role::Helper] { + let (curr_state, last_peer_message) = match role { + Role::Leader => ( + leader_prepare_transitions.last().unwrap().state.clone(), + helper_prepare_transitions.last().unwrap().message.clone(), + ), + Role::Helper => ( + helper_prepare_transitions.last().unwrap().state.clone(), + leader_prepare_transitions.last().unwrap().message.clone(), + ), + _ => panic!(), + }; + + match (&curr_state, &last_peer_message) { + (curr_state @ PingPongState::Continued(_), last_peer_message) => { + let state_and_message = match role { + Role::Leader => vdaf + .leader_continued( + curr_state.clone(), + aggregation_param, + last_peer_message, + ) .unwrap(), + Role::Helper => vdaf + .helper_continued( + curr_state.clone(), + aggregation_param, + last_peer_message, + ) + .unwrap(), + _ => panic!(), + }; + + match state_and_message { + PingPongContinuedValue::WithMessage { transition } => { + let (state, message) = transition.clone().evaluate(vdaf).unwrap(); + match role { + Role::Leader => { + leader_prepare_transitions.push(LeaderPrepareTransition { + transition: Some(transition), + state, + message, + }) + } + Role::Helper => { + helper_prepare_transitions.push(HelperPrepareTransition { + transition, + state, + message, + }) + } + _ => panic!(), + } + } + PingPongContinuedValue::FinishedNoMessage { output_share } => match role { + Role::Leader => leader_output_share = Some(output_share.clone()), + Role::Helper => helper_output_share =
Some(output_share.clone()), + _ => panic!(), + }, + } } + (PingPongState::Finished(output_share), _) => match role { + Role::Leader => leader_output_share = Some(output_share.clone()), + Role::Helper => helper_output_share = Some(output_share.clone()), + _ => panic!(), + }, } } - if !agg_shares.is_empty() { - return VdafTranscript { - public_share, - input_shares, - prepare_transitions: prep_trans, - prepare_messages: prep_msgs, - output_shares, - aggregate_shares: agg_shares, - }; - } - let prep_msg = vdaf - .prepare_shares_to_prepare_message(aggregation_param, prep_shares) - .unwrap(); - prep_msgs.push(prep_msg.clone()); - - // Compute each participant's next transition. - for pts in &mut prep_trans { - let prep_state = assert_matches!( - pts.last().unwrap(), - PrepareTransition::<V, SEED_SIZE, 16>::Continue(prep_state, _) => prep_state - ) - .clone(); - pts.push(vdaf.prepare_next(prep_state, prep_msg.clone()).unwrap()); + + if leader_output_share.is_some() && helper_output_share.is_some() { + break; } } + + let leader_aggregate_share = vdaf + .aggregate(aggregation_param, [leader_output_share.clone().unwrap()]) + .unwrap(); + let helper_aggregate_share = vdaf + .aggregate(aggregation_param, [helper_output_share.clone().unwrap()]) + .unwrap(); + + VdafTranscript { + public_share, + leader_input_share: input_shares[0].clone(), + helper_input_share: input_shares[1].clone(), + leader_prepare_transitions, + helper_prepare_transitions, + leader_output_share: leader_output_share.unwrap(), + helper_output_share: helper_output_share.unwrap(), + leader_aggregate_share, + helper_aggregate_share, + } } /// Encodes the given value to YAML, then decodes it again, and checks that the diff --git a/db/00000000000001_initial_schema.up.sql b/db/00000000000001_initial_schema.up.sql index cbd5c018d..c2b31549a 100644 --- a/db/00000000000001_initial_schema.up.sql +++ b/db/00000000000001_initial_schema.up.sql @@ -206,12 +206,10 @@ CREATE TABLE report_aggregations( client_timestamp TIMESTAMP NOT NULL, -- the client timestamp this report aggregation is associated with ord BIGINT NOT NULL, -- a value used to specify the ordering of client reports in the aggregation job state REPORT_AGGREGATION_STATE NOT NULL, -- the current state of this report aggregation - prep_state BYTEA, -- the current preparation state (opaque VDAF message, only if in state WAITING) - prep_msg BYTEA, -- for the leader, the next preparation message to be sent to the helper (opaque VDAF message) -- for the helper, the next preparation share to be sent to the leader (opaque VDAF message) -- only non-NULL if in state WAITING + helper_prep_state BYTEA, -- the current VDAF prepare state (opaque VDAF message, only if in state WAITING, only populated for helper) + leader_prep_transition BYTEA, -- the current VDAF prepare transition (opaque VDAF message, only if in state WAITING, only populated for leader) error_code SMALLINT, -- error code corresponding to a DAP ReportShareError value; null if in a state other than FAILED - last_prep_step BYTEA, -- the last PreparationStep message sent to the Leader, to assist in replay (opaque VDAF message, populated for Helper only) + last_prep_resp BYTEA, -- the last PrepareResp message sent to the Leader, to assist in replay (opaque VDAF message, populated for Helper only) CONSTRAINT report_aggregations_unique_ord UNIQUE(aggregation_job_id, ord), CONSTRAINT fk_aggregation_job_id FOREIGN KEY(aggregation_job_id) REFERENCES aggregation_jobs(id) ON DELETE CASCADE diff --git a/integration_tests/Cargo.toml
b/integration_tests/Cargo.toml index 81c131b97..2d11463dc 100644 --- a/integration_tests/Cargo.toml +++ b/integration_tests/Cargo.toml @@ -14,10 +14,11 @@ in-cluster = ["dep:k8s-openapi", "dep:kube"] [dependencies] anyhow.workspace = true backoff = { version = "0.4", features = ["tokio"] } -base64 = "0.21.4" +base64.workspace = true futures = "0.3.28" hex = "0.4" http = "0.2" +itertools.workspace = true janus_aggregator = { workspace = true, features = ["test-util"] } janus_aggregator_core = { workspace = true, features = ["test-util"] } janus_client.workspace = true @@ -37,6 +38,5 @@ tokio.workspace = true url = { version = "2.4.1", features = ["serde"] } [dev-dependencies] -itertools.workspace = true janus_collector = { workspace = true, features = ["test-util"] } tempfile = "3" diff --git a/interop_binaries/Cargo.toml b/interop_binaries/Cargo.toml index c570baeb8..25bb66d11 100644 --- a/interop_binaries/Cargo.toml +++ b/interop_binaries/Cargo.toml @@ -24,7 +24,7 @@ testcontainer = [ [dependencies] anyhow.workspace = true backoff = { version = "0.4", features = ["tokio"] } -base64 = "0.21.4" +base64.workspace = true clap = "4.4.2" futures = { version = "0.3.28", optional = true } fixed = { version = "1.23", optional = true } diff --git a/messages/Cargo.toml b/messages/Cargo.toml index 2ed14114b..58cd7edae 100644 --- a/messages/Cargo.toml +++ b/messages/Cargo.toml @@ -14,13 +14,13 @@ test-util = [] [dependencies] anyhow.workspace = true -base64 = "0.21.4" +base64.workspace = true derivative = "2.2.0" hex = "0.4" num_enum = "0.7.0" # We can't pull prio in from the workspace because that would enable default features, and we do not # want prio/crypto-dependencies -prio = { version = "0.15.0", default-features = false, features = ["multithreaded"] } +prio = { version = "0.15.1", default-features = false, features = ["multithreaded", "experimental"] } rand = "0.8" serde.workspace = true thiserror.workspace = true diff --git a/messages/src/lib.rs b/messages/src/lib.rs index 27e5e831d..0e602c83b 100644 --- a/messages/src/lib.rs +++ b/messages/src/lib.rs @@ -8,9 +8,12 @@ use anyhow::anyhow; use base64::{display::Base64Display, engine::general_purpose::URL_SAFE_NO_PAD, Engine}; use derivative::Derivative; use num_enum::TryFromPrimitive; -use prio::codec::{ - decode_u16_items, decode_u32_items, encode_u16_items, encode_u32_items, CodecError, Decode, - Encode, +use prio::{ + codec::{ + decode_u16_items, decode_u32_items, encode_u16_items, encode_u32_items, CodecError, Decode, + Encode, + }, + topology::ping_pong::PingPongMessage, }; use rand::{distributions::Standard, prelude::Distribution, Rng}; use serde::{ @@ -2111,14 +2114,65 @@ impl Decode for ReportShare { } } -/// DAP protocol message representing the result of a preparation step in a VDAF evaluation. +/// DAP protocol message representing information required to initialize preparation of a report for +/// aggregation. #[derive(Clone, Debug, PartialEq, Eq)] -pub struct PrepareStep { +pub struct PrepareInit { + report_share: ReportShare, + message: PingPongMessage, +} + +impl PrepareInit { + /// Constructs a new preparation initialization message from its components. + pub fn new(report_share: ReportShare, message: PingPongMessage) -> Self { + Self { + report_share, + message, + } + } + + /// Gets the report share associated with this prep init. + pub fn report_share(&self) -> &ReportShare { + &self.report_share + } + + /// Gets the message associated with this prep init. 
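/// In the ping-pong topology this is the Leader's inaugural message for the
/// report (the round-trip tests below exercise both the `Initialize` and
/// `Finish` encodings).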
+ pub fn message(&self) -> &PingPongMessage { + &self.message + } +} + +impl Encode for PrepareInit { + fn encode(&self, bytes: &mut Vec<u8>) { + self.report_share.encode(bytes); + self.message.encode(bytes); + } + + fn encoded_len(&self) -> Option<usize> { + Some(self.report_share.encoded_len()? + self.message.encoded_len()?) + } +} + +impl Decode for PrepareInit { + fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> { + let report_share = ReportShare::decode(bytes)?; + let message = PingPongMessage::decode(bytes)?; + + Ok(Self { + report_share, + message, + }) + } +} + +/// DAP protocol message representing the response to a preparation step in a VDAF evaluation. +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct PrepareResp { report_id: ReportId, result: PrepareStepResult, } -impl PrepareStep { +impl PrepareResp { /// Constructs a new prepare step from its components. pub fn new(report_id: ReportId, result: PrepareStepResult) -> Self { Self { report_id, result } @@ -2135,7 +2189,7 @@ impl PrepareStep { } } -impl Encode for PrepareStep { +impl Encode for PrepareResp { fn encode(&self, bytes: &mut Vec<u8>) { self.report_id.encode(bytes); self.result.encode(bytes); @@ -2146,7 +2200,7 @@ impl Encode for PrepareStep { } } -impl Decode for PrepareStep { +impl Decode for PrepareResp { fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> { let report_id = ReportId::decode(bytes)?; let result = PrepareStepResult::decode(bytes)?; @@ -2156,13 +2210,16 @@ impl Decode for PrepareStep { } /// DAP protocol message representing result-type-specific data associated with a preparation step -/// in a VDAF evaluation. Included in a PrepareStep message. +/// in a VDAF evaluation. Included in a PrepareResp message. #[derive(Clone, Derivative, PartialEq, Eq)] #[derivative(Debug)] pub enum PrepareStepResult { - Continued(#[derivative(Debug = "ignore")] Vec<u8>), // content is a serialized preparation message + Continue { + #[derivative(Debug = "ignore")] + message: PingPongMessage, + }, Finished, - Failed(ReportShareError), + Reject(PrepareError), } impl Encode for PrepareStepResult { fn encode(&self, bytes: &mut Vec<u8>) { // The encoding includes an implicit discriminator byte, called PrepareStepResult in the // DAP spec.
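// For example, a Continue result wrapping PingPongMessage::Continue with
// prep_msg "012345" and prep_share "6789" encodes as (per the round-trip
// tests below):
//   00                      - PrepareStepResult discriminator (Continue)
//   01                      - PingPongMessage type (Continue)
//   00000006 303132333435   - length-prefixed prep_msg
//   00000004 36373839       - length-prefixed prep_share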
match self { - Self::Continued(vdaf_msg) => { + Self::Continue { message: prep_msg } => { 0u8.encode(bytes); - encode_u32_items(bytes, &(), vdaf_msg); + prep_msg.encode(bytes); } Self::Finished => 1u8.encode(bytes), - Self::Failed(error) => { + Self::Reject(error) => { 2u8.encode(bytes); error.encode(bytes); } @@ -2184,9 +2241,9 @@ fn encoded_len(&self) -> Option<usize> { match self { - PrepareStepResult::Continued(vdaf_msg) => Some(1 + 4 + vdaf_msg.len()), - PrepareStepResult::Finished => Some(1), - PrepareStepResult::Failed(error) => Some(1 + error.encoded_len()?), + Self::Continue { message: prep_msg } => Some(1 + prep_msg.encoded_len()?), + Self::Finished => Some(1), + Self::Reject(error) => Some(1 + error.encoded_len()?), } } } @@ -2195,9 +2252,12 @@ impl Decode for PrepareStepResult { fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> { let val = u8::decode(bytes)?; Ok(match val { - 0 => Self::Continued(decode_u32_items(&(), bytes)?), + 0 => { + let prep_msg = PingPongMessage::decode(bytes)?; + Self::Continue { message: prep_msg } + } 1 => Self::Finished, - 2 => Self::Failed(ReportShareError::decode(bytes)?), + 2 => Self::Reject(PrepareError::decode(bytes)?), _ => return Err(CodecError::UnexpectedValue), }) } @@ -2206,7 +2266,7 @@ impl Decode for PrepareStepResult { /// DAP protocol message representing an error while preparing a report share for aggregation. #[derive(Clone, Copy, Debug, PartialEq, Eq, TryFromPrimitive)] #[repr(u8)] -pub enum ReportShareError { +pub enum PrepareError { BatchCollected = 0, ReportReplayed = 1, ReportDropped = 2, @@ -2218,7 +2278,7 @@ pub enum ReportShareError { UnrecognizedMessage = 8, } -impl Encode for ReportShareError { +impl Encode for PrepareError { fn encode(&self, bytes: &mut Vec<u8>) { (*self as u8).encode(bytes); } @@ -2228,7 +2288,7 @@ } } -impl Decode for ReportShareError { +impl Decode for PrepareError { fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> { let val = u8::decode(bytes)?; Self::try_from(val).map_err(|_| { @@ -2237,6 +2297,51 @@ } } +/// DAP protocol message representing a request to continue preparation of a report share for +/// aggregation. +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct PrepareContinue { + report_id: ReportId, + message: PingPongMessage, +} + +impl PrepareContinue { + /// Constructs a new prepare continue from its components. + pub fn new(report_id: ReportId, message: PingPongMessage) -> Self { + Self { report_id, message } + } + + /// Gets the report ID associated with this prepare continue. + pub fn report_id(&self) -> &ReportId { + &self.report_id + } + + /// Gets the message associated with this prepare continue. + pub fn message(&self) -> &PingPongMessage { + &self.message + } +} + +impl Encode for PrepareContinue { + fn encode(&self, bytes: &mut Vec<u8>) { + self.report_id.encode(bytes); + self.message.encode(bytes); + } + + fn encoded_len(&self) -> Option<usize> { + Some(self.report_id.encoded_len()? + self.message.encoded_len()?) + } +} + +impl Decode for PrepareContinue { + fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> { + let report_id = ReportId::decode(bytes)?; + let message = PingPongMessage::decode(bytes)?; + + Ok(Self { report_id, message }) + } +} + /// DAP protocol message representing an identifier for an aggregation job.
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] pub struct AggregationJobId([u8; Self::LEN]); @@ -2309,7 +2414,7 @@ pub struct AggregationJobInitializeReq<Q: QueryType> { #[derivative(Debug = "ignore")] aggregation_parameter: Vec<u8>, partial_batch_selector: PartialBatchSelector<Q>, - report_shares: Vec<ReportShare>, + prepare_inits: Vec<PrepareInit>, } impl<Q: QueryType> AggregationJobInitializeReq<Q> { @@ -2320,12 +2425,12 @@ pub fn new( aggregation_parameter: Vec<u8>, partial_batch_selector: PartialBatchSelector<Q>, - report_shares: Vec<ReportShare>, + prepare_inits: Vec<PrepareInit>, ) -> Self { Self { aggregation_parameter, partial_batch_selector, - report_shares, + prepare_inits, } } @@ -2339,9 +2444,10 @@ &self.partial_batch_selector } - /// Gets the report shares associated with this aggregate initialization request. - pub fn report_shares(&self) -> &[ReportShare] { - &self.report_shares + /// Gets the preparation initialization messages associated with this aggregate initialization + /// request. + pub fn prepare_inits(&self) -> &[PrepareInit] { + &self.prepare_inits } } @@ -2349,15 +2455,15 @@ impl<Q: QueryType> Encode for AggregationJobInitializeReq<Q> { fn encode(&self, bytes: &mut Vec<u8>) { encode_u32_items(bytes, &(), &self.aggregation_parameter); self.partial_batch_selector.encode(bytes); - encode_u32_items(bytes, &(), &self.report_shares); + encode_u32_items(bytes, &(), &self.prepare_inits); } fn encoded_len(&self) -> Option<usize> { let mut length = 4 + self.aggregation_parameter.len(); length += self.partial_batch_selector.encoded_len()?; length += 4; - for report_share in self.report_shares.iter() { - length += report_share.encoded_len()?; + for prepare_init in &self.prepare_inits { + length += prepare_init.encoded_len()?; } Some(length) } @@ -2367,12 +2473,12 @@ impl<Q: QueryType> Decode for AggregationJobInitializeReq<Q> { fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> { let aggregation_parameter = decode_u32_items(&(), bytes)?; let partial_batch_selector = PartialBatchSelector::decode(bytes)?; - let report_shares = decode_u32_items(&(), bytes)?; + let prepare_inits = decode_u32_items(&(), bytes)?; Ok(Self { aggregation_parameter, partial_batch_selector, - report_shares, + prepare_inits, }) } } @@ -2438,7 +2544,7 @@ impl TryFrom for AggregationJobRound { #[derive(Clone, Debug, PartialEq, Eq)] pub struct AggregationJobContinueReq { round: AggregationJobRound, - prepare_steps: Vec<PrepareStep>, + prepare_continues: Vec<PrepareContinue>, } impl AggregationJobContinueReq { @@ -2446,10 +2552,10 @@ pub const MEDIA_TYPE: &'static str = "application/dap-aggregation-job-continue-req"; /// Constructs a new aggregate continuation request from its components. - pub fn new(round: AggregationJobRound, prepare_steps: Vec<PrepareStep>) -> Self { + pub fn new(round: AggregationJobRound, prepare_continues: Vec<PrepareContinue>) -> Self { Self { round, - prepare_steps, + prepare_continues, } } @@ -2459,22 +2565,22 @@ } /// Gets the prepare steps associated with this aggregate continuation request.
- pub fn prepare_steps(&self) -> &[PrepareStep] { - &self.prepare_steps + pub fn prepare_steps(&self) -> &[PrepareContinue] { + &self.prepare_continues } } impl Encode for AggregationJobContinueReq { fn encode(&self, bytes: &mut Vec<u8>) { self.round.encode(bytes); - encode_u32_items(bytes, &(), &self.prepare_steps); + encode_u32_items(bytes, &(), &self.prepare_continues); } fn encoded_len(&self) -> Option<usize> { let mut length = self.round.encoded_len()?; length += 4; - for prepare_step in self.prepare_steps.iter() { - length += prepare_step.encoded_len()?; + for prepare_continue in self.prepare_continues.iter() { + length += prepare_continue.encoded_len()?; } Some(length) } @@ -2483,8 +2589,8 @@ impl Decode for AggregationJobContinueReq { fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> { let round = AggregationJobRound::decode(bytes)?; - let prepare_steps = decode_u32_items(&(), bytes)?; - Ok(Self::new(round, prepare_steps)) + let prepare_continues = decode_u32_items(&(), bytes)?; + Ok(Self::new(round, prepare_continues)) } } @@ -2492,7 +2598,7 @@ /// continuation request. #[derive(Clone, Debug, PartialEq, Eq)] pub struct AggregationJobResp { - prepare_steps: Vec<PrepareStep>, + prepare_resps: Vec<PrepareResp>, } impl AggregationJobResp { @@ -2500,25 +2606,25 @@ pub const MEDIA_TYPE: &'static str = "application/dap-aggregation-job-resp"; /// Constructs a new aggregate continuation response from its components. - pub fn new(prepare_steps: Vec<PrepareStep>) -> Self { - Self { prepare_steps } + pub fn new(prepare_resps: Vec<PrepareResp>) -> Self { + Self { prepare_resps } } - /// Gets the prepare steps associated with this aggregate continuation response. - pub fn prepare_steps(&self) -> &[PrepareStep] { - &self.prepare_steps + /// Gets the prepare responses associated with this aggregate continuation response.
+ pub fn prepare_resps(&self) -> &[PrepareResp] { + &self.prepare_resps } } impl Encode for AggregationJobResp { fn encode(&self, bytes: &mut Vec<u8>) { - encode_u32_items(bytes, &(), &self.prepare_steps); + encode_u32_items(bytes, &(), &self.prepare_resps); } fn encoded_len(&self) -> Option<usize> { let mut length = 4; - for prepare_step in self.prepare_steps.iter() { - length += prepare_step.encoded_len()?; + for prepare_resp in self.prepare_resps.iter() { + length += prepare_resp.encoded_len()?; } Some(length) } @@ -2526,8 +2632,8 @@ impl Decode for AggregationJobResp { fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> { - let prepare_steps = decode_u32_items(&(), bytes)?; - Ok(Self { prepare_steps }) + let prepare_resps = decode_u32_items(&(), bytes)?; + Ok(Self { prepare_resps }) } } @@ -2766,12 +2872,15 @@ mod tests { AggregationJobRound, BatchId, BatchSelector, Collection, CollectionReq, Duration, Extension, ExtensionType, FixedSize, FixedSizeQuery, HpkeAeadId, HpkeCiphertext, HpkeConfig, HpkeConfigId, HpkeKdfId, HpkeKemId, HpkePublicKey, InputShareAad, Interval, - PartialBatchSelector, PlaintextInputShare, PrepareStep, PrepareStepResult, Query, Report, - ReportId, ReportIdChecksum, ReportMetadata, ReportShare, ReportShareError, Role, TaskId, - Time, TimeInterval, Url, + PartialBatchSelector, PlaintextInputShare, PrepareContinue, PrepareError, PrepareInit, + PrepareResp, PrepareStepResult, Query, Report, ReportId, ReportIdChecksum, ReportMetadata, + ReportShare, Role, TaskId, Time, TimeInterval, Url, }; use assert_matches::assert_matches; - use prio::codec::{CodecError, Decode, Encode}; + use prio::{ + codec::{CodecError, Decode, Encode}, + topology::ping_pong::PingPongMessage, + }; use serde_test::{assert_de_tokens_error, assert_tokens, Token}; #[test] @@ -3877,72 +3986,97 @@ } #[test] - fn roundtrip_prepare_step() { + fn roundtrip_report_share() { roundtrip_encoding(&[ ( - PrepareStep { - report_id: ReportId::from([ - 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, - ]), - result: PrepareStepResult::Continued(Vec::from("012345")), + ReportShare { + metadata: ReportMetadata::new( + ReportId::from([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]), + Time::from_seconds_since_epoch(54321), + ), + public_share: Vec::new(), + encrypted_input_share: HpkeCiphertext::new( + HpkeConfigId::from(42), + Vec::from("012345"), + Vec::from("543210"), + ), }, concat!( - "0102030405060708090A0B0C0D0E0F10", // report_id - "00", // prepare_step_result concat!( - // vdaf_msg - "00000006", // length - "303132333435", // opaque data + // metadata + "0102030405060708090A0B0C0D0E0F10", // report_id + "000000000000D431", // time + ), + concat!( + // public_share + "00000000", // length + "", // opaque data + ), + concat!( + // encrypted_input_share + "2A", // config_id + concat!( + // encapsulated_context + "0006", // length + "303132333435", // opaque data + ), + concat!( + // payload + "00000006", // length + "353433323130", // opaque data + ), ), ), ), ( - PrepareStep { - report_id: ReportId::from([ - 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, - ]), - result: PrepareStepResult::Finished, - }, - concat!( - "100F0E0D0C0B0A090807060504030201", // report_id - "01", // prepare_step_result - ), - ), - ( - PrepareStep { - report_id: ReportId::from([255; 16]), - result: PrepareStepResult::Failed(ReportShareError::VdafPrepError), + ReportShare { + metadata: ReportMetadata::new( + ReportId::from([16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1]),
Time::from_seconds_since_epoch(73542), + ), + public_share: Vec::from("0123"), + encrypted_input_share: HpkeCiphertext::new( + HpkeConfigId::from(13), + Vec::from("abce"), + Vec::from("abfd"), + ), }, concat!( - "FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF", // report_id - "02", // prepare_step_result - "05", // report_share_error + concat!( + // metadata + "100F0E0D0C0B0A090807060504030201", // report_id + "0000000000011F46", // time + ), + concat!( + // public_share + "00000004", // length + "30313233", // opaque data + ), + concat!( + // encrypted_input_share + "0D", // config_id + concat!( + // encapsulated_context + "0004", // length + "61626365", // opaque data + ), + concat!( + // payload + "00000004", // length + "61626664", // opaque data + ), + ), ), ), ]) } #[test] - fn roundtrip_report_share_error() { + fn roundtrip_prepare_init() { roundtrip_encoding(&[ - (ReportShareError::BatchCollected, "00"), - (ReportShareError::ReportReplayed, "01"), - (ReportShareError::ReportDropped, "02"), - (ReportShareError::HpkeUnknownConfigId, "03"), - (ReportShareError::HpkeDecryptError, "04"), - (ReportShareError::VdafPrepError, "05"), - ]) - } - - #[test] - fn roundtrip_aggregation_job_initialize_req() { - // TimeInterval. - roundtrip_encoding(&[( - AggregationJobInitializeReq { - aggregation_parameter: Vec::from("012345"), - partial_batch_selector: PartialBatchSelector::new_time_interval(), - report_shares: Vec::from([ - ReportShare { + ( + PrepareInit { + report_share: ReportShare { metadata: ReportMetadata::new( ReportId::from([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]), Time::from_seconds_since_epoch(54321), @@ -3954,34 +4088,13 @@ mod tests { Vec::from("543210"), ), }, - ReportShare { - metadata: ReportMetadata::new( - ReportId::from([16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1]), - Time::from_seconds_since_epoch(73542), - ), - public_share: Vec::from("0123"), - encrypted_input_share: HpkeCiphertext::new( - HpkeConfigId::from(13), - Vec::from("abce"), - Vec::from("abfd"), - ), + message: PingPongMessage::Initialize { + prep_share: Vec::from("012345"), }, - ]), - }, - concat!( - concat!( - // aggregation_parameter - "00000006", // length - "303132333435", // opaque data - ), - concat!( - // partial_batch_selector - "01", // query_type - ), + }, concat!( - // report_shares - "0000005E", // length concat!( + // report_share concat!( // metadata "0102030405060708090A0B0C0D0E0F10", // report_id @@ -4008,13 +4121,44 @@ mod tests { ), ), concat!( + // message + "00", // Message type + concat!( + "00000006", // length + "303132333435", // opaque data + ) + ) + ), + ), + ( + PrepareInit { + report_share: ReportShare { + metadata: ReportMetadata::new( + ReportId::from([16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1]), + Time::from_seconds_since_epoch(73542), + ), + public_share: Vec::from("0123"), + encrypted_input_share: HpkeCiphertext::new( + HpkeConfigId::from(13), + Vec::from("abce"), + Vec::from("abfd"), + ), + }, + message: PingPongMessage::Finish { + prep_msg: Vec::new(), + }, + }, + concat!( + concat!( + // report_share concat!( // metadata "100F0E0D0C0B0A090807060504030201", // report_id "0000000000011F46", // time ), concat!( - "00000004", // payload + // public_share + "00000004", // length "30313233", // opaque data ), concat!( @@ -4032,6 +4176,223 @@ mod tests { ), ), ), + concat!( + // message + "02", // Message type + concat!( + "00000000", // length + "" // opaque data + ) + ) + ), + ), + ]) + } + + #[test] + fn roundtrip_prepare_resp() { + roundtrip_encoding(&[ + 
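// A Continue result now carries a structured PingPongMessage (a type byte plus
// length-prefixed prep_msg/prep_share fields) rather than a single opaque
// length-prefixed VDAF blob.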
( + PrepareResp { + report_id: ReportId::from([ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, + ]), + result: PrepareStepResult::Continue { + message: PingPongMessage::Continue { + prep_msg: Vec::from("012345"), + prep_share: Vec::from("6789"), + }, + }, + }, + concat!( + "0102030405060708090A0B0C0D0E0F10", // report_id + "00", // prepare_step_result + concat!( + // message + "01", // message type + concat!( + "00000006", // length + "303132333435", // opaque data + ), + concat!( + "00000004", // length + "36373839", // opaque data + ) + ), + ), + ), + ( + PrepareResp { + report_id: ReportId::from([ + 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, + ]), + result: PrepareStepResult::Finished, + }, + concat!( + "100F0E0D0C0B0A090807060504030201", // report_id + "01", // prepare_step_result + ), + ), + ( + PrepareResp { + report_id: ReportId::from([255; 16]), + result: PrepareStepResult::Reject(PrepareError::VdafPrepError), + }, + concat!( + "FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF", // report_id + "02", // prepare_step_result + "05", // prepare_error + ), + ), + ]) + } + + #[test] + fn roundtrip_prepare_error() { + roundtrip_encoding(&[ + (PrepareError::BatchCollected, "00"), + (PrepareError::ReportReplayed, "01"), + (PrepareError::ReportDropped, "02"), + (PrepareError::HpkeUnknownConfigId, "03"), + (PrepareError::HpkeDecryptError, "04"), + (PrepareError::VdafPrepError, "05"), + ]) + } + + #[test] + fn roundtrip_aggregation_job_initialize_req() { + // TimeInterval. + roundtrip_encoding(&[( + AggregationJobInitializeReq { + aggregation_parameter: Vec::from("012345"), + partial_batch_selector: PartialBatchSelector::new_time_interval(), + prepare_inits: Vec::from([ + PrepareInit { + report_share: ReportShare { + metadata: ReportMetadata::new( + ReportId::from([ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, + ]), + Time::from_seconds_since_epoch(54321), + ), + public_share: Vec::new(), + encrypted_input_share: HpkeCiphertext::new( + HpkeConfigId::from(42), + Vec::from("012345"), + Vec::from("543210"), + ), + }, + message: PingPongMessage::Initialize { + prep_share: Vec::from("012345"), + }, + }, + PrepareInit { + report_share: ReportShare { + metadata: ReportMetadata::new( + ReportId::from([ + 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, + ]), + Time::from_seconds_since_epoch(73542), + ), + public_share: Vec::from("0123"), + encrypted_input_share: HpkeCiphertext::new( + HpkeConfigId::from(13), + Vec::from("abce"), + Vec::from("abfd"), + ), + }, + message: PingPongMessage::Finish { + prep_msg: Vec::new(), + }, + }, + ]), + }, concat!( concat!( // aggregation_parameter "00000006", // length "303132333435", // opaque data ), concat!( // partial_batch_selector "01", // query_type ), concat!( // prepare_inits "0000006E", // length concat!( concat!( // report_share concat!( // metadata "0102030405060708090A0B0C0D0E0F10", // report_id "000000000000D431", // time ), concat!( // public_share "00000000", // length "", // opaque data ), concat!( // encrypted_input_share "2A", // config_id concat!( // encapsulated_context "0006", // length "303132333435", // opaque data ), concat!( // payload "00000006", // length "353433323130", // opaque data ), ), ), concat!( // message "00", // Message type concat!( "00000006", // length
"0000000000011F46", // time + ), + concat!( + // public_share + "00000004", // length + "30313233", // opaque data + ), + concat!( + // encrypted_input_share + "0D", // config_id + concat!( + // encapsulated_context + "0004", // length + "61626365", // opaque data + ), + concat!( + // payload + "00000004", // length + "61626664", // opaque data + ), + ), + ), + concat!( + // message + "02", // Message type + concat!( + "00000000", // length + "" // opaque data + ) + ) + ), ), ), )]); @@ -4043,30 +4404,44 @@ mod tests { partial_batch_selector: PartialBatchSelector::new_fixed_size(BatchId::from( [2u8; 32], )), - report_shares: Vec::from([ - ReportShare { - metadata: ReportMetadata::new( - ReportId::from([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]), - Time::from_seconds_since_epoch(54321), - ), - public_share: Vec::new(), - encrypted_input_share: HpkeCiphertext::new( - HpkeConfigId::from(42), - Vec::from("012345"), - Vec::from("543210"), - ), + prepare_inits: Vec::from([ + PrepareInit { + report_share: ReportShare { + metadata: ReportMetadata::new( + ReportId::from([ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, + ]), + Time::from_seconds_since_epoch(54321), + ), + public_share: Vec::new(), + encrypted_input_share: HpkeCiphertext::new( + HpkeConfigId::from(42), + Vec::from("012345"), + Vec::from("543210"), + ), + }, + message: PingPongMessage::Initialize { + prep_share: Vec::from("012345"), + }, }, - ReportShare { - metadata: ReportMetadata::new( - ReportId::from([16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1]), - Time::from_seconds_since_epoch(73542), - ), - public_share: Vec::from("0123"), - encrypted_input_share: HpkeCiphertext::new( - HpkeConfigId::from(13), - Vec::from("abce"), - Vec::from("abfd"), - ), + PrepareInit { + report_share: ReportShare { + metadata: ReportMetadata::new( + ReportId::from([ + 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, + ]), + Time::from_seconds_since_epoch(73542), + ), + public_share: Vec::from("0123"), + encrypted_input_share: HpkeCiphertext::new( + HpkeConfigId::from(13), + Vec::from("abce"), + Vec::from("abfd"), + ), + }, + message: PingPongMessage::Finish { + prep_msg: Vec::new(), + }, }, ]), }, @@ -4082,58 +4457,79 @@ mod tests { "0202020202020202020202020202020202020202020202020202020202020202", // batch_id ), concat!( - // report_shares - "0000005E", // length + // prepare_inits + "0000006E", // length concat!( concat!( - // metadata - "0102030405060708090A0B0C0D0E0F10", // report_id - "000000000000D431", // time - ), - concat!( - // public_share - "00000000", // length - "", // opaque data - ), - concat!( - // encrypted_input_share - "2A", // config_id + // report_share concat!( - // encapsulated_context - "0006", // length - "303132333435", // opaque data + // metadata + "0102030405060708090A0B0C0D0E0F10", // report_id + "000000000000D431", // time ), concat!( - // payload - "00000006", // length - "353433323130", // opaque data + // public_share + "00000000", // length + "", // opaque data + ), + concat!( + // encrypted_input_share + "2A", // config_id + concat!( + // encapsulated_context + "0006", // length + "303132333435", // opaque data + ), + concat!( + // payload + "00000006", // length + "353433323130", // opaque data + ), ), - ), - ), - concat!( - concat!( - // metadata - "100F0E0D0C0B0A090807060504030201", // report_id - "0000000000011F46", // time ), concat!( - // public_share - "00000004", // length - "30313233", // opaque data + // payload + "00", // Message type + concat!( + "00000006", // length 
+ "303132333435", // opaque data + ) ), + ), + concat!( concat!( - // encrypted_input_share - "0D", // config_id concat!( - // encapsulated_context - "0004", // length - "61626365", // opaque data + // metadata + "100F0E0D0C0B0A090807060504030201", // report_id + "0000000000011F46", // time ), concat!( - // payload + // public_share "00000004", // length - "61626664", // opaque data + "30313233", // opaque data ), + concat!( + // encrypted_input_share + "0D", // config_id + concat!( + // encapsulated_context + "0004", // length + "61626365", // opaque data + ), + concat!( + // payload + "00000004", // length + "61626664", // opaque data + ), + ), + ), + concat!( + // payload + "02", // Message type + concat!( + "00000000", // length + "", // opaque data + ) ), ), ), @@ -4146,18 +4542,22 @@ mod tests { roundtrip_encoding(&[( AggregationJobContinueReq { round: AggregationJobRound(42405), - prepare_steps: Vec::from([ - PrepareStep { + prepare_continues: Vec::from([ + PrepareContinue { report_id: ReportId::from([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, ]), - result: PrepareStepResult::Continued(Vec::from("012345")), + message: PingPongMessage::Initialize { + prep_share: Vec::from("012345"), + }, }, - PrepareStep { + PrepareContinue { report_id: ReportId::from([ 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, ]), - result: PrepareStepResult::Finished, + message: PingPongMessage::Initialize { + prep_share: Vec::from("012345"), + }, }, ]), }, @@ -4165,19 +4565,28 @@ mod tests { "A5A5", // round concat!( // prepare_steps - "0000002C", // length + "00000036", // length concat!( "0102030405060708090A0B0C0D0E0F10", // report_id - "00", // prepare_step_result concat!( // payload - "00000006", // length - "303132333435", // opaque data + "00", // Message type + concat!( + "00000006", // length + "303132333435", // opaque data + ) ), ), concat!( "100F0E0D0C0B0A090807060504030201", // report_id - "01", // prepare_step_result + concat!( + // payload + "00", // Message type + concat!( + "00000006", // length + "303132333435", // opaque data + ) + ), ) ), ), @@ -4188,14 +4597,19 @@ mod tests { fn roundtrip_aggregation_job_resp() { roundtrip_encoding(&[( AggregationJobResp { - prepare_steps: Vec::from([ - PrepareStep { + prepare_resps: Vec::from([ + PrepareResp { report_id: ReportId::from([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, ]), - result: PrepareStepResult::Continued(Vec::from("012345")), + result: PrepareStepResult::Continue { + message: PingPongMessage::Continue { + prep_msg: Vec::from("01234"), + prep_share: Vec::from("56789"), + }, + }, }, - PrepareStep { + PrepareResp { report_id: ReportId::from([ 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, ]), @@ -4205,14 +4619,22 @@ mod tests { }, concat!(concat!( // prepare_steps - "0000002C", // length + "00000035", // length concat!( "0102030405060708090A0B0C0D0E0F10", // report_id "00", // prepare_step_result concat!( - // payload - "00000006", // length - "303132333435", // opaque data + "01", // Message type + concat!( + // prep_msg + "00000005", // length + "3031323334", // opaque data + ), + concat!( + // prep_share + "00000005", // length + "3536373839", // opaque data + ) ), ), concat!( diff --git a/tools/Cargo.toml b/tools/Cargo.toml index 6a8650bea..d2f957f6b 100644 --- a/tools/Cargo.toml +++ b/tools/Cargo.toml @@ -13,7 +13,7 @@ fpvec_bounded_l2 = ["dep:fixed", "janus_collector/fpvec_bounded_l2", "prio/exper [dependencies] anyhow = "1" -base64 = "0.21.4" +base64.workspace = true clap = { version = 
"4.4.2", features = ["cargo", "derive", "env"] } derivative = "2.2.0" fixed = { version = "1.23", optional = true }