diff --git a/Cargo.lock b/Cargo.lock
index 243f30295..0712f692e 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -2485,6 +2485,19 @@ dependencies = [
  "winapi",
 ]
 
+[[package]]
+name = "num-bigint"
+version = "0.4.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "608e7659b5c3d7cba262d894801b9ec9d00de989e8a82bd4bef91d08da45cdc0"
+dependencies = [
+ "autocfg",
+ "num-integer",
+ "num-traits",
+ "rand",
+ "serde",
+]
+
 [[package]]
 name = "num-bigint-dig"
 version = "0.8.4"
@@ -2523,6 +2536,18 @@ dependencies = [
  "num-traits",
 ]
 
+[[package]]
+name = "num-rational"
+version = "0.4.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "0638a1c9d0a3c0914158145bc76cff373a75a627e6ecbfb71cbe6f453a5a19b0"
+dependencies = [
+ "autocfg",
+ "num-bigint",
+ "num-integer",
+ "num-traits",
+]
+
 [[package]]
 name = "num-traits"
 version = "0.2.16"
@@ -2981,8 +3006,7 @@ checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de"
 [[package]]
 name = "prio"
 version = "0.14.1"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "e4a65c4a557b2fecb8518c105aafadf33a86d7513a3f599bcfe542c17553cc61"
+source = "git+https://github.com/divviup/libprio-rs?branch=timg/ping-pong-topology#bc06133044998457734afc8e96dad0c68a205be3"
 dependencies = [
  "aes",
  "base64 0.21.2",
@@ -2993,6 +3017,12 @@ dependencies = [
  "fiat-crypto",
  "fixed",
  "getrandom",
+ "num-bigint",
+ "num-integer",
+ "num-iter",
+ "num-rational",
+ "num-traits",
+ "rand",
  "rand_core 0.6.4",
  "rayon",
  "serde",
diff --git a/Cargo.toml b/Cargo.toml
index 4f629b880..28e785c05 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -42,7 +42,9 @@ janus_messages = { version = "0.5", path = "messages" }
 k8s-openapi = { version = "0.18.0", features = ["v1_24"] }  # keep this version in sync with what is referenced by the indirect dependency via `kube`
 kube = { version = "0.82.2", default-features = false, features = ["client", "rustls-tls"] }
 opentelemetry = { version = "0.20", features = ["metrics"] }
-prio = { version = "0.14.1", features = ["multithreaded"] }
+# TODO(timg): go back to a released version of prio
+#prio = { version = "0.14.1", features = ["multithreaded"] }
+prio = { git = "https://github.com/divviup/libprio-rs", branch = "timg/ping-pong-topology", features = ["multithreaded", "experimental"] }
 serde = { version = "1.0.185", features = ["derive"] }
 serde_json = "1.0.105"
 serde_test = "1.0.175"
diff --git a/aggregator/src/aggregator.rs b/aggregator/src/aggregator.rs
index 8517f8d1d..e86a0a053 100644
--- a/aggregator/src/aggregator.rs
+++ b/aggregator/src/aggregator.rs
@@ -52,7 +52,7 @@ use janus_messages::{
     AggregationJobId, AggregationJobInitializeReq, AggregationJobResp, AggregationJobRound,
     BatchSelector, Collection, CollectionJobId, CollectionReq, Duration, HpkeConfig,
     HpkeConfigList, InputShareAad, Interval, PartialBatchSelector, PlaintextInputShare,
-    PrepareStep, PrepareStepResult, Report, ReportIdChecksum, ReportShare, ReportShareError, Role,
+    PrepareError, PrepareResp, PrepareStepResult, Report, ReportIdChecksum, ReportShare, Role,
     TaskId,
 };
 use opentelemetry::{
@@ -65,6 +65,7 @@ use prio::vdaf::prio3::Prio3FixedPointBoundedL2VecSumMultithreaded;
 use prio::vdaf::{PrepareTransition, VdafError};
 use prio::{
     codec::{Decode, Encode, ParameterizedDecode},
+    topology::ping_pong::{self, PingPongTopology},
     vdaf::{
         self,
         poplar1::Poplar1,
@@ -87,6 +88,8 @@ use tokio::{sync::Mutex, try_join};
 use tracing::{debug, info, trace_span, warn};
 use url::Url;
 
+use self::{accumulator::Accumulator, error::handle_ping_pong_error};
+
 pub mod accumulator;
 #[cfg(test)]
 mod aggregate_init_tests;
@@ -131,6 +134,9 @@
         "decrypt_failure",
         "input_share_decode_failure",
         "public_share_decode_failure",
+        "prepare_message_decode_failure",
+        "leader_prep_share_decode_failure",
+        "helper_prep_share_decode_failure",
         "continue_mismatch",
         "accumulate_failure",
         "finish_mismatch",
@@ -138,6 +144,7 @@
         "plaintext_input_share_decode_failure",
         "duplicate_extension",
         "missing_client_report",
+        "missing_prepare_message",
     ] {
         aggregate_step_failure_counter.add(0, &[KeyValue::new("type", failure_type)]);
     }
@@ -387,6 +394,7 @@ impl<C: Clock> Aggregator<C> {
                 &self.datastore,
                 &self.global_hpke_keypairs,
                 &self.aggregate_step_failure_counter,
+                self.cfg.batch_aggregation_shard_count,
                 aggregation_job_id,
                 req_bytes,
             )
@@ -863,7 +871,7 @@ impl TaskAggregator {
             #[cfg(feature = "test-util")]
             VdafInstance::FakeFailsPrepStep => {
                 VdafOps::Fake(Arc::new(dummy_vdaf::Vdaf::new().with_prep_step_fn(
-                    || -> Result<PrepareTransition<dummy_vdaf::Vdaf, 0, 16>, VdafError> {
+                    |_| -> Result<PrepareTransition<dummy_vdaf::Vdaf, 0, 16>, VdafError> {
                         Err(VdafError::Uncategorized(
                             "FakeFailsPrepStep failed at prep_step".to_string(),
                         ))
@@ -921,6 +929,7 @@ impl TaskAggregator {
         datastore: &Datastore<C>,
         global_hpke_keypairs: &GlobalHpkeKeypairCache,
         aggregate_step_failure_counter: &Counter<u64>,
+        batch_aggregation_shard_count: u64,
        aggregation_job_id: &AggregationJobId,
         req_bytes: &[u8],
     ) -> Result<AggregationJobResp, Error> {
@@ -930,6 +939,7 @@ impl TaskAggregator {
                 global_hpke_keypairs,
                 aggregate_step_failure_counter,
                 Arc::clone(&self.task),
+                batch_aggregation_shard_count,
                 aggregation_job_id,
                 req_bytes,
             )
@@ -1213,6 +1223,7 @@ impl VdafOps {
         global_hpke_keypairs: &GlobalHpkeKeypairCache,
         aggregate_step_failure_counter: &Counter<u64>,
         task: Arc<Task>,
+        batch_aggregation_shard_count: u64,
         aggregation_job_id: &AggregationJobId,
         req_bytes: &[u8],
     ) -> Result<AggregationJobResp, Error> {
@@ -1225,6 +1236,7 @@ impl VdafOps {
                         vdaf,
                         aggregate_step_failure_counter,
                         task,
+                        batch_aggregation_shard_count,
                         aggregation_job_id,
                         verify_key,
                         req_bytes,
@@ -1240,6 +1252,7 @@ impl VdafOps {
                         vdaf,
                         aggregate_step_failure_counter,
                         task,
+                        batch_aggregation_shard_count,
                         aggregation_job_id,
                         verify_key,
                         req_bytes,
@@ -1531,6 +1544,7 @@ impl VdafOps {
         });
 
         if !existing_aggregation_job.eq(incoming_aggregation_job) {
+            tracing::info!("jobs don't match");
             return Ok(false);
         }
 
@@ -1542,10 +1556,12 @@ impl VdafOps {
                 &Role::Helper,
                 task.id(),
                 incoming_aggregation_job.id(),
+                incoming_aggregation_job.aggregation_parameter(),
             )
             .await?;
 
         if existing_report_aggregations.len() != incoming_report_share_data.len() {
+            tracing::info!("wrong count of report aggregations");
             return Ok(false);
         }
 
@@ -1580,6 +1596,7 @@ impl VdafOps {
         vdaf: &A,
         aggregate_step_failure_counter: &Counter<u64>,
         task: Arc<Task>,
+        batch_aggregation_shard_count: u64,
         aggregation_job_id: &AggregationJobId,
         verify_key: &VerifyKey<SEED_SIZE>,
         req_bytes: &[u8],
@@ -1600,9 +1617,9 @@ impl VdafOps {
 
         // If two ReportShare messages have the same report ID, then the helper MUST abort with
        // error "unrecognizedMessage". (§4.4.4.1)
-        let mut seen_report_ids = HashSet::with_capacity(req.report_shares().len());
-        for share in req.report_shares() {
-            if !seen_report_ids.insert(share.metadata().id()) {
+        let mut seen_report_ids = HashSet::with_capacity(req.prepare_inits().len());
+        for prepare_init in req.prepare_inits() {
+            if !seen_report_ids.insert(*prepare_init.report_share().metadata().id()) {
                 return Err(Error::UnrecognizedMessage(
                     Some(*task.id()),
                     "aggregate request contains duplicate report IDs",
@@ -1616,53 +1633,71 @@ impl VdafOps {
         let mut interval_per_batch_identifier: HashMap<Q::BatchIdentifier, Interval> =
             HashMap::new();
         let agg_param = A::AggregationParam::get_decoded(req.aggregation_parameter())?;
-        for (ord, report_share) in req.report_shares().iter().enumerate() {
+
+        let mut accumulator = Accumulator::<SEED_SIZE, Q, A>::new(
+            Arc::clone(&task),
+            batch_aggregation_shard_count,
+            agg_param.clone(),
+        );
+
+        for (ord, prepare_init) in req.prepare_inits().iter().enumerate() {
             // Compute intervals for each batch identifier included in this aggregation job.
             let batch_identifier = Q::to_batch_identifier(
                 &task,
                 req.batch_selector().batch_identifier(),
-                report_share.metadata().time(),
+                prepare_init.report_share().metadata().time(),
             )?;
             match interval_per_batch_identifier.entry(batch_identifier) {
                 Entry::Occupied(mut entry) => {
-                    *entry.get_mut() = entry.get().merged_with(report_share.metadata().time())?;
+                    *entry.get_mut() = entry
+                        .get()
+                        .merged_with(prepare_init.report_share().metadata().time())?;
                 }
                 Entry::Vacant(entry) => {
-                    entry.insert(Interval::from_time(report_share.metadata().time())?);
+                    entry.insert(Interval::from_time(
+                        prepare_init.report_share().metadata().time(),
+                    )?);
                 }
             }
 
-            let task_hpke_keypair = task
-                .hpke_keys()
-                .get(report_share.encrypted_input_share().config_id());
-
-            let global_hpke_keypair =
-                global_hpke_keypairs.keypair(report_share.encrypted_input_share().config_id());
-
            // If decryption fails, then the aggregator MUST fail with error `hpke-decrypt-error`. (§4.4.2.2)
             let try_hpke_open = |hpke_keypair: &HpkeKeypair| {
                 hpke::open(
                     hpke_keypair.config(),
                     hpke_keypair.private_key(),
                     &HpkeApplicationInfo::new(&Label::InputShare, &Role::Client, &Role::Helper),
-                    report_share.encrypted_input_share(),
+                    prepare_init.report_share().encrypted_input_share(),
                     &InputShareAad::new(
                         *task.id(),
-                        report_share.metadata().clone(),
-                        report_share.public_share().to_vec(),
+                        prepare_init.report_share().metadata().clone(),
+                        prepare_init.report_share().public_share().to_vec(),
                     )
                     .get_encoded(),
                 )
             };
 
+            let global_hpke_keypair = global_hpke_keypairs.keypair(
+                prepare_init
+                    .report_share()
+                    .encrypted_input_share()
+                    .config_id(),
+            );
+
+            let task_hpke_keypair = task.hpke_keys().get(
+                prepare_init
+                    .report_share()
+                    .encrypted_input_share()
+                    .config_id(),
+            );
+
             let check_keypairs = if task_hpke_keypair.is_none() && global_hpke_keypair.is_none() {
                 info!(
-                    config_id = %report_share.encrypted_input_share().config_id(),
+                    config_id = %prepare_init.report_share().encrypted_input_share().config_id(),
                     "Helper encrypted input share references unknown HPKE config ID"
                 );
                 aggregate_step_failure_counter
                     .add(1, &[KeyValue::new("type", "unknown_hpke_config_id")]);
-                Err(ReportShareError::HpkeUnknownConfigId)
+                Err(PrepareError::HpkeUnknownConfigId)
             } else {
                 Ok(())
             };
@@ -1684,21 +1719,21 @@ impl VdafOps {
                     .map_err(|error| {
                         info!(
                             task_id = %task.id(),
-                            metadata = ?report_share.metadata(),
+                            metadata = ?prepare_init.report_share().metadata(),
                             ?error,
                             "Couldn't decrypt helper's report share"
                         );
                         aggregate_step_failure_counter
                             .add(1, &[KeyValue::new("type", "decrypt_failure")]);
-                        ReportShareError::HpkeDecryptError
+                        PrepareError::HpkeDecryptError
                     })
             });
 
             let plaintext_input_share = plaintext.and_then(|plaintext| {
                 let plaintext_input_share = PlaintextInputShare::get_decoded(&plaintext).map_err(|error| {
-                    info!(task_id = %task.id(), metadata = ?report_share.metadata(), ?error, "Couldn't decode helper's plaintext input share");
+                    info!(task_id = %task.id(), metadata = ?prepare_init.report_share().metadata(), ?error, "Couldn't decode helper's plaintext input share");
                     aggregate_step_failure_counter.add(1, &[KeyValue::new("type", "plaintext_input_share_decode_failure")]);
-                    ReportShareError::UnrecognizedMessage
+                    PrepareError::UnrecognizedMessage
                 })?;
                // Check for repeated extensions.
                let mut extension_types = HashSet::new();
@@ -1706,9 +1741,9 @@ impl VdafOps {
                     .extensions()
                     .iter()
                     .all(|extension| extension_types.insert(extension.extension_type())) {
-                    info!(task_id = %task.id(), metadata = ?report_share.metadata(), "Received report share with duplicate extensions");
+                    info!(task_id = %task.id(), metadata = ?prepare_init.report_share().metadata(), "Received report share with duplicate extensions");
                     aggregate_step_failure_counter.add(1, &[KeyValue::new("type", "duplicate_extension")]);
-                    return Err(ReportShareError::UnrecognizedMessage)
+                    return Err(PrepareError::UnrecognizedMessage)
                 }
                 Ok(plaintext_input_share)
             });
@@ -1716,16 +1751,16 @@ impl VdafOps {
             let input_share = plaintext_input_share.and_then(|plaintext_input_share| {
                 A::InputShare::get_decoded_with_param(&(vdaf, Role::Helper.index().unwrap()), plaintext_input_share.payload())
                     .map_err(|error| {
-                        info!(task_id = %task.id(), metadata = ?report_share.metadata(), ?error, "Couldn't decode helper's input share");
+                        info!(task_id = %task.id(), metadata = ?prepare_init.report_share().metadata(), ?error, "Couldn't decode helper's input share");
                         aggregate_step_failure_counter.add(1, &[KeyValue::new("type", "input_share_decode_failure")]);
-                        ReportShareError::UnrecognizedMessage
+                        PrepareError::UnrecognizedMessage
                     })
             });
 
-            let public_share = A::PublicShare::get_decoded_with_param(vdaf, report_share.public_share()).map_err(|error|{
-                info!(task_id = %task.id(), metadata = ?report_share.metadata(), ?error, "Couldn't decode public share");
+            let public_share = A::PublicShare::get_decoded_with_param(vdaf, prepare_init.report_share().public_share()).map_err(|error|{
+                info!(task_id = %task.id(), metadata = ?prepare_init.report_share().metadata(), ?error, "Couldn't decode public share");
                 aggregate_step_failure_counter.add(1, &[KeyValue::new("type", "public_share_decode_failure")]);
-                ReportShareError::UnrecognizedMessage
+                PrepareError::UnrecognizedMessage
             });
 
             let shares = input_share.and_then(|input_share| Ok((public_share?, input_share)));
@@ -1734,88 +1769,104 @@ impl VdafOps {
             // associated with the task and computes the first state transition. [...] If either
            // step fails, then the aggregator MUST fail with error `vdaf-prep-error`. (§4.4.2.2)
             let init_rslt = shares.and_then(|(public_share, input_share)| {
-                trace_span!("VDAF preparation")
-                    .in_scope(|| {
-                        vdaf.prepare_init(
-                            verify_key.as_bytes(),
-                            Role::Helper.index().unwrap(),
-                            &agg_param,
-                            report_share.metadata().id().as_ref(),
-                            &public_share,
-                            &input_share,
-                        )
+                trace_span!("VDAF preparation").in_scope(|| {
+                    vdaf.helper_initialize(
+                        verify_key.as_bytes(),
+                        &agg_param,
+                        /* report ID is used as VDAF nonce */
+                        prepare_init.report_share().metadata().id().as_ref(),
+                        &public_share,
+                        &input_share,
+                        prepare_init.message(),
+                    )
+                    .and_then(|transition| {
+                        transition
+                            .evaluate(vdaf)
+                            .map(|(ping_pong_state, outgoing_message)| {
+                                (transition, ping_pong_state, outgoing_message)
+                            })
                     })
                     .map_err(|error| {
-                        info!(
-                            task_id = %task.id(),
-                            report_id = %report_share.metadata().id(),
-                            ?error,
-                            "Couldn't prepare_init report share"
-                        );
-                        aggregate_step_failure_counter
-                            .add(1, &[KeyValue::new("type", "prepare_init_failure")]);
-                        ReportShareError::VdafPrepError
+                        handle_ping_pong_error(
+                            task.id(),
+                            Role::Leader,
+                            prepare_init.report_share().metadata().id(),
+                            error,
+                            aggregate_step_failure_counter,
+                        )
                     })
+                })
             });
 
-            report_share_data.push(match init_rslt {
-                Ok((prep_state, prep_share)) => {
+            let (report_aggregation_state, prepare_step_result) = match init_rslt {
+                Ok((transition, ping_pong::State::Continued(_), outgoing_message)) => {
+                    // Helper is not finished. Await the next message from the Leader to advance to
+                    // the next round.
                     saw_continue = true;
-
-                    let encoded_prep_share = prep_share.get_encoded();
-                    ReportShareData::new(
-                        report_share.clone(),
-                        ReportAggregation::<SEED_SIZE, A>::new(
-                            *task.id(),
-                            *aggregation_job_id,
-                            *report_share.metadata().id(),
-                            *report_share.metadata().time(),
-                            ord.try_into()?,
-                            Some(PrepareStep::new(
-                                *report_share.metadata().id(),
-                                PrepareStepResult::Continued(encoded_prep_share),
-                            )),
-                            ReportAggregationState::<SEED_SIZE, A>::Waiting(prep_state, None),
-                        ),
+                    (
+                        ReportAggregationState::Waiting(transition),
+                        PrepareStepResult::Continue {
+                            message: outgoing_message,
+                        },
                     )
                 }
+                Ok((_, ping_pong::State::Finished(output_share), outgoing_message)) => {
+                    // Helper finished. Unlike the Leader, the Helper does not wait for confirmation
+                    // that the Leader finished before accumulating its output share.
+                    accumulator.update(
+                        req.batch_selector().batch_identifier(),
+                        prepare_init.report_share().metadata().id(),
+                        prepare_init.report_share().metadata().time(),
+                        &output_share,
+                    )?;
+                    (
+                        ReportAggregationState::Finished,
+                        PrepareStepResult::Continue {
+                            message: outgoing_message,
+                        },
+                    )
+                }
+                Err(prepare_error) => (
+                    ReportAggregationState::Failed(prepare_error),
+                    PrepareStepResult::Reject(prepare_error),
+                ),
+            };
 
-                Err(err) => ReportShareData::new(
-                    report_share.clone(),
-                    ReportAggregation::<SEED_SIZE, A>::new(
-                        *task.id(),
-                        *aggregation_job_id,
-                        *report_share.metadata().id(),
-                        *report_share.metadata().time(),
-                        ord.try_into()?,
-                        Some(PrepareStep::new(
-                            *report_share.metadata().id(),
-                            PrepareStepResult::Failed(err),
-                        )),
-                        ReportAggregationState::<SEED_SIZE, A>::Failed(err),
-                    ),
-                ),
-            });
+            report_share_data.push(ReportShareData::new(
+                prepare_init.report_share().clone(),
+                ReportAggregation::<SEED_SIZE, A>::new(
+                    *task.id(),
+                    *aggregation_job_id,
+                    *prepare_init.report_share().metadata().id(),
+                    *prepare_init.report_share().metadata().time(),
+                    ord.try_into()?,
+                    Some(PrepareResp::new(
+                        *prepare_init.report_share().metadata().id(),
+                        prepare_step_result,
+                    )),
+                    report_aggregation_state,
+                ),
+            ));
         }
 
        // Store data to datastore.
        let req = Arc::new(req);
         let min_client_timestamp = req
-            .report_shares()
+            .prepare_inits()
             .iter()
-            .map(|report_share| report_share.metadata().time())
+            .map(|prepare_init| *prepare_init.report_share().metadata().time())
             .min()
             .ok_or_else(|| Error::EmptyAggregation(*task.id()))?;
         let max_client_timestamp = req
-            .report_shares()
+            .prepare_inits()
             .iter()
-            .map(|report_share| report_share.metadata().time())
+            .map(|prepare_init| *prepare_init.report_share().metadata().time())
             .max()
             .ok_or_else(|| Error::EmptyAggregation(*task.id()))?;
         let client_timestamp_interval = Interval::new(
-            *min_client_timestamp,
+            min_client_timestamp,
             max_client_timestamp
-                .difference(min_client_timestamp)?
+                .difference(&min_client_timestamp)?
                 .add(&Duration::from_seconds(1))?,
         )?;
         let aggregation_job = Arc::new(AggregationJob::<SEED_SIZE, Q, A>::new(
@@ -1841,8 +1892,11 @@ impl VdafOps {
             let aggregation_job = Arc::clone(&aggregation_job);
             let mut report_share_data = report_share_data.clone();
             let interval_per_batch_identifier = Arc::clone(&interval_per_batch_identifier);
+            let accumulator = accumulator.clone();
 
             Box::pin(async move {
+                let unwritable_reports = accumulator.flush_to_datastore(tx, &vdaf).await?;
+
                 for report_share_data in &mut report_share_data {
                     // Verify that we haven't seen this report ID and aggregation parameter
                     // before in another aggregation job, and that the report isn't for a batch
@@ -1867,34 +1921,35 @@ impl VdafOps {
                     )?;
 
                     if report_aggregation_exists {
-                        report_share_data.report_aggregation = report_share_data
-                            .report_aggregation
-                            .clone()
-                            .with_state(ReportAggregationState::Failed(
-                                ReportShareError::ReportReplayed,
-                            ))
-                            .with_last_prep_step(Some(PrepareStep::new(
-                                *report_share_data.report_share.metadata().id(),
-                                PrepareStepResult::Failed(ReportShareError::ReportReplayed),
-                            )));
-                    } else if !conflicting_aggregate_share_jobs.is_empty() {
-                        report_share_data.report_aggregation = report_share_data
-                            .report_aggregation
-                            .clone()
-                            .with_state(ReportAggregationState::Failed(
-                                ReportShareError::BatchCollected,
-                            ))
-                            .with_last_prep_step(Some(PrepareStep::new(
-                                *report_share_data.report_share.metadata().id(),
-                                PrepareStepResult::Failed(ReportShareError::BatchCollected),
-                            )));
+                        report_share_data.report_aggregation =
+                            report_share_data.report_aggregation
+                                .clone()
+                                .with_state(ReportAggregationState::Failed(
+                                    PrepareError::ReportReplayed))
+                                .with_last_prep_step(Some(PrepareResp::new(
+                                    *report_share_data.report_share.metadata().id(),
+                                    PrepareStepResult::Reject(PrepareError::ReportReplayed))
+                                ));
+                    } else if !conflicting_aggregate_share_jobs.is_empty() ||
+                        unwritable_reports.contains(report_share_data.report_aggregation.report_id()) {
+                        report_share_data.report_aggregation =
+                            report_share_data.report_aggregation
+                                .clone()
+                                .with_state(ReportAggregationState::Failed(
+                                    PrepareError::BatchCollected))
+                                .with_last_prep_step(Some(PrepareResp::new(
+                                    *report_share_data.report_share.metadata().id(),
+                                    PrepareStepResult::Reject(PrepareError::BatchCollected))
+                                ));
                     }
                 }
 
                 // Write aggregation job.
+                tracing::info!("putting aggregation job");
                 let replayed_request = match tx.put_aggregation_job(&aggregation_job).await {
                     Ok(_) => false,
                     Err(datastore::Error::MutationTargetAlreadyExists) => {
+                        tracing::info!("detected mutation!");
                         // Slow path: this request is writing an aggregation job that already
                        // exists in the datastore. PUT to an aggregation job is idempotent, so
                        // that's OK, provided the current request is equivalent to what's in
@@ -1938,12 +1993,12 @@ impl VdafOps {
                                     .report_aggregation
                                     .clone()
                                     .with_state(ReportAggregationState::Failed(
-                                        ReportShareError::ReportReplayed,
+                                        PrepareError::ReportReplayed,
                                     ))
-                                    .with_last_prep_step(Some(PrepareStep::new(
+                                    .with_last_prep_step(Some(PrepareResp::new(
                                         *rsd.report_share.metadata().id(),
-                                        PrepareStepResult::Failed(
-                                            ReportShareError::ReportReplayed,
+                                        PrepareStepResult::Reject(
+                                            PrepareError::ReportReplayed,
                                         ),
                                     )));
                             }
@@ -1958,7 +2013,11 @@ impl VdafOps {
                     let task = Arc::clone(&task);
                     let aggregation_job = Arc::clone(&aggregation_job);
                     async move {
-                        match tx.get_batch::<SEED_SIZE, Q, A>(task.id(), batch_identifier, aggregation_job.aggregation_parameter()).await? {
+                        match tx.get_batch::<SEED_SIZE, Q, A>(
+                            task.id(),
+                            batch_identifier,
+                            aggregation_job.aggregation_parameter(),
+                        ).await? {
                             Some(batch) => {
                                 let interval = batch.client_timestamp_interval().merge(interval)?;
                                 tx.update_batch(&batch.with_client_timestamp_interval(interval)).await?;
@@ -2038,22 +2097,25 @@ impl VdafOps {
                 Box::pin(async move {
                     // Read existing state.
-                    let (helper_aggregation_job, report_aggregations) = try_join!(
-                        tx.get_aggregation_job::<SEED_SIZE, Q, A>(task.id(), &aggregation_job_id),
-                        tx.get_report_aggregations_for_aggregation_job(
+                    let helper_aggregation_job = tx
+                        .get_aggregation_job::<SEED_SIZE, Q, A>(task.id(), &aggregation_job_id)
+                        .await?
+                        .ok_or_else(|| {
+                            datastore::Error::User(
+                                Error::UnrecognizedAggregationJob(*task.id(), aggregation_job_id)
+                                    .into(),
+                            )
+                        })?;
+
+                    let report_aggregations = tx
+                        .get_report_aggregations_for_aggregation_job(
                             vdaf.as_ref(),
                             &Role::Helper,
                             task.id(),
                             &aggregation_job_id,
-                        ),
-                    )?;
-
-                    let helper_aggregation_job = helper_aggregation_job.ok_or_else(|| {
-                        datastore::Error::User(
-                            Error::UnrecognizedAggregationJob(*task.id(), aggregation_job_id)
-                                .into(),
+                            helper_aggregation_job.aggregation_parameter(),
                         )
-                    })?;
+                        .await?;
 
                     // If the leader's request is on the same round as our stored aggregation job,
                     // then we probably have already received this message and computed this round,
@@ -3080,7 +3142,7 @@ mod tests {
     };
     use prio::{
         codec::Encode,
-        vdaf::{self, prio3::Prio3Count, Client as _},
+        vdaf::{self, prio3::Prio3Count, Client as VdafClient},
     };
     use rand::random;
     use std::{collections::HashSet, iter, sync::Arc, time::Duration as StdDuration};
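The aggregator.rs hunks above replace the helper's direct `prepare_init`/`prepare_step` calls with libprio's DAP ping-pong topology. As a rough sketch of the helper's round-0 flow — assuming the `PingPongTopology` API of the pinned `timg/ping-pong-topology` branch, with names mirroring the call sites in this patch rather than a compiled program:

```rust
// Sketch only; `vdaf`, `verify_key`, `agg_param`, `report_id`, `public_share`,
// `input_share`, and `leader_message` stand in for values the handler already has.
use prio::topology::ping_pong::{self, PingPongTopology};

// The helper consumes the leader's first message (carried in PrepareInit) and
// produces a transition, which janus now stores in ReportAggregationState::Waiting.
let transition = vdaf.helper_initialize(
    verify_key.as_bytes(),
    &agg_param,
    report_id.as_ref(), // the report ID doubles as the VDAF nonce
    &public_share,
    &input_share,
    &leader_message,
)?;

// Evaluating the transition yields the helper's new state plus its outgoing message.
let (state, outgoing_message) = transition.evaluate(&vdaf)?;
match state {
    // Not finished: persist `transition` and await the leader's next message.
    ping_pong::State::Continued(_) => { /* ReportAggregationState::Waiting(transition) */ }
    // Finished: the helper accumulates its output share immediately, without
    // waiting for confirmation that the leader also finished.
    ping_pong::State::Finished(output_share) => { /* accumulator.update(..., &output_share) */ }
}
```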
diff --git a/aggregator/src/aggregator/accumulator.rs b/aggregator/src/aggregator/accumulator.rs
index 0bdf2b9ba..959a6f75e 100644
--- a/aggregator/src/aggregator/accumulator.rs
+++ b/aggregator/src/aggregator/accumulator.rs
@@ -26,7 +26,7 @@ use std::{
 /// Accumulates output shares in memory and eventually flushes accumulations to a datastore. We
 /// accumulate output shares into a [`HashMap`] mapping the batch identifier at which the batch
/// interval begins to the accumulated aggregate share, report count and checksum.
-#[derive(Derivative)]
+#[derive(Clone, Derivative)]
 #[derivative(Debug)]
 pub struct Accumulator<
     const SEED_SIZE: usize,
@@ -40,7 +40,7 @@ pub struct Accumulator<
     aggregations: HashMap<Q::BatchIdentifier, BatchData<SEED_SIZE, Q, A>>,
 }
 
-#[derive(Debug)]
+#[derive(Clone, Debug)]
 struct BatchData<
     const SEED_SIZE: usize,
     Q: AccumulableQueryType,
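`Accumulator` and `BatchData` pick up `Clone` because the helper now builds the accumulator before entering its datastore transaction, and janus transactions are retryable closures whose captured state must be cloneable. A minimal sketch of the resulting pattern, mirroring the aggregator.rs hunks above (illustrative, not the exact janus code):

```rust
// Each transaction attempt gets its own copy of the accumulator and flushes it
// atomically with the rest of the aggregation job's writes.
let accumulator = accumulator.clone();
Box::pin(async move {
    // flush_to_datastore returns the reports whose batches were already
    // collected, so they can be marked BatchCollected afterwards.
    let unwritable_reports = accumulator.flush_to_datastore(tx, &vdaf).await?;
    // ... write report aggregations and the aggregation job ...
    Ok(unwritable_reports)
})
```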
diff --git a/aggregator/src/aggregator/aggregate_init_tests.rs b/aggregator/src/aggregator/aggregate_init_tests.rs
index 59ffefdbe..69aedf036 100644
--- a/aggregator/src/aggregator/aggregate_init_tests.rs
+++ b/aggregator/src/aggregator/aggregate_init_tests.rs
@@ -16,89 +16,137 @@ use janus_core::{
 };
 use janus_messages::{
     query_type::TimeInterval, AggregationJobId, AggregationJobInitializeReq, PartialBatchSelector,
-    ReportMetadata, ReportShare, Role,
+    PrepareInit, ReportMetadata, Role,
+};
+use prio::{
+    codec::Encode,
+    idpf::IdpfInput,
+    vdaf::{
+        self,
+        poplar1::{Poplar1, Poplar1AggregationParam},
+    },
 };
-use prio::codec::Encode;
 use rand::random;
 use std::sync::Arc;
 use trillium::{Handler, KnownHeaderName, Status};
 use trillium_testing::{prelude::put, TestConn};
 
-pub(super) struct ReportShareGenerator {
+pub(super) struct PrepareInitGenerator<const VERIFY_KEY_SIZE: usize, V>
+where
+    V: vdaf::Vdaf,
+{
     clock: MockClock,
     task: Task,
-    aggregation_param: dummy_vdaf::AggregationParam,
-    vdaf: dummy_vdaf::Vdaf,
+    vdaf: V,
+    aggregation_param: V::AggregationParam,
 }
 
-impl ReportShareGenerator {
+impl<const VERIFY_KEY_SIZE: usize, V> PrepareInitGenerator<VERIFY_KEY_SIZE, V>
+where
+    V: vdaf::Vdaf + vdaf::Aggregator<VERIFY_KEY_SIZE, 16> + vdaf::Client<16>,
+{
     pub(super) fn new(
         clock: MockClock,
         task: Task,
-        aggregation_param: dummy_vdaf::AggregationParam,
+        vdaf: V,
+        aggregation_param: V::AggregationParam,
     ) -> Self {
         Self {
             clock,
             task,
+            vdaf,
             aggregation_param,
-            vdaf: dummy_vdaf::Vdaf::new(),
         }
     }
 
-    fn with_vdaf(mut self, vdaf: dummy_vdaf::Vdaf) -> Self {
-        self.vdaf = vdaf;
-        self
-    }
-
-    pub(super) fn next(&self) -> (ReportShare, VdafTranscript<0, dummy_vdaf::Vdaf>) {
-        self.next_with_metadata(ReportMetadata::new(
-            random(),
-            self.clock
-                .now()
-                .to_batch_interval_start(self.task.time_precision())
-                .unwrap(),
-        ))
+    pub(super) fn next(
+        &self,
+        measurement: &V::Measurement,
+    ) -> (PrepareInit, VdafTranscript<VERIFY_KEY_SIZE, V>) {
+        self.next_with_metadata(
+            ReportMetadata::new(
+                random(),
+                self.clock
+                    .now()
+                    .to_batch_interval_start(self.task.time_precision())
+                    .unwrap(),
+            ),
+            measurement,
+        )
     }
 
     pub(super) fn next_with_metadata(
         &self,
         report_metadata: ReportMetadata,
-    ) -> (ReportShare, VdafTranscript<0, dummy_vdaf::Vdaf>) {
+        measurement: &V::Measurement,
+    ) -> (PrepareInit, VdafTranscript<VERIFY_KEY_SIZE, V>) {
         let transcript = run_vdaf(
             &self.vdaf,
             self.task.primary_vdaf_verify_key().unwrap().as_bytes(),
             &self.aggregation_param,
             report_metadata.id(),
-            &(),
+            measurement,
         );
-        let report_share = generate_helper_report_share::<dummy_vdaf::Vdaf>(
+        let report_share = generate_helper_report_share::<V>(
             *self.task.id(),
             report_metadata,
             self.task.current_hpke_key().config(),
             &transcript.public_share,
             Vec::new(),
-            &transcript.input_shares[1],
+            &transcript.helper_input_share,
         );
-
-        (report_share, transcript)
+        (
+            PrepareInit::new(
+                report_share,
+                transcript.leader_prepare_transitions[0].message.clone(),
+            ),
+            transcript,
+        )
     }
 }
 
-pub(super) struct AggregationJobInitTestCase {
+pub(super) struct AggregationJobInitTestCase<
+    const VERIFY_KEY_SIZE: usize,
+    V: vdaf::Aggregator<VERIFY_KEY_SIZE, 16>,
+> {
     pub(super) clock: MockClock,
     pub(super) task: Task,
-    pub(super) report_share_generator: ReportShareGenerator,
-    pub(super) report_shares: Vec<ReportShare>,
+    pub(super) prepare_init_generator: PrepareInitGenerator<VERIFY_KEY_SIZE, V>,
+    pub(super) prepare_inits: Vec<PrepareInit>,
    pub(super) aggregation_job_id: AggregationJobId,
     aggregation_job_init_req: AggregationJobInitializeReq<TimeInterval>,
-    pub(super) aggregation_param: dummy_vdaf::AggregationParam,
+    pub(super) aggregation_param: V::AggregationParam,
     pub(super) handler: Box<dyn Handler>,
     pub(super) datastore: Arc<Datastore<MockClock>>,
     _ephemeral_datastore: EphemeralDatastore,
 }
 
-pub(super) async fn setup_aggregate_init_test() -> AggregationJobInitTestCase {
-    let test_case = setup_aggregate_init_test_without_sending_request().await;
+pub(super) async fn setup_aggregate_init_test() -> AggregationJobInitTestCase<0, dummy_vdaf::Vdaf> {
+    setup_aggregate_init_test_for_vdaf(
+        dummy_vdaf::Vdaf::new(),
+        VdafInstance::Fake,
+        dummy_vdaf::AggregationParam(0),
+        0,
+    )
+    .await
+}
+
+async fn setup_aggregate_init_test_for_vdaf<
+    const VERIFY_KEY_SIZE: usize,
+    V: vdaf::Aggregator<VERIFY_KEY_SIZE, 16> + vdaf::Client<16>,
+>(
+    vdaf: V,
+    vdaf_instance: VdafInstance,
+    aggregation_param: V::AggregationParam,
+    measurement: V::Measurement,
+) -> AggregationJobInitTestCase<VERIFY_KEY_SIZE, V> {
+    let test_case = setup_aggregate_init_test_without_sending_request(
+        vdaf,
+        vdaf_instance,
+        aggregation_param,
+        measurement,
+    )
+    .await;
 
     let response = put_aggregation_job(
         &test_case.task,
@@ -112,10 +160,18 @@
     test_case
 }
 
-async fn setup_aggregate_init_test_without_sending_request() -> AggregationJobInitTestCase {
+async fn setup_aggregate_init_test_without_sending_request<
+    const VERIFY_KEY_SIZE: usize,
+    V: vdaf::Aggregator<VERIFY_KEY_SIZE, 16> + vdaf::Client<16>,
+>(
+    vdaf: V,
+    vdaf_instance: VdafInstance,
+    aggregation_param: V::AggregationParam,
+    measurement: V::Measurement,
+) -> AggregationJobInitTestCase<VERIFY_KEY_SIZE, V> {
     install_test_trace_subscriber();
 
-    let task = TaskBuilder::new(QueryType::TimeInterval, VdafInstance::Fake, Role::Helper).build();
+    let task = TaskBuilder::new(QueryType::TimeInterval, vdaf_instance, Role::Helper).build();
     let clock = MockClock::default();
     let ephemeral_datastore = ephemeral_datastore().await;
     let datastore = Arc::new(ephemeral_datastore.datastore(clock.clone()).await);
@@ -131,28 +187,26 @@ async fn setup_aggregate_init_test_without_sending_request() -> AggregationJobIn
         .await
         .unwrap();
 
-    let aggregation_param = dummy_vdaf::AggregationParam(0);
-
-    let report_share_generator =
-        ReportShareGenerator::new(clock.clone(), task.clone(), aggregation_param);
+    let prepare_init_generator =
+        PrepareInitGenerator::new(clock.clone(), task.clone(), vdaf, aggregation_param.clone());
 
-    let report_shares = Vec::from([
-        report_share_generator.next().0,
-        report_share_generator.next().0,
+    let prepare_inits = Vec::from([
+        prepare_init_generator.next(&measurement).0,
+        prepare_init_generator.next(&measurement).0,
     ]);
 
     let aggregation_job_id = random();
     let aggregation_job_init_req = AggregationJobInitializeReq::new(
         aggregation_param.get_encoded(),
         PartialBatchSelector::new_time_interval(),
-        report_shares.clone(),
+        prepare_inits.clone(),
     );
 
     AggregationJobInitTestCase {
         clock,
         task,
-        report_shares,
-        report_share_generator,
+        prepare_inits,
+        prepare_init_generator,
         aggregation_job_id,
         aggregation_job_init_req,
         aggregation_param,
@@ -184,7 +238,13 @@ pub(crate) async fn put_aggregation_job(
 
 #[tokio::test]
 async fn aggregation_job_init_authorization_dap_auth_token() {
-    let test_case = setup_aggregate_init_test_without_sending_request().await;
+    let test_case = setup_aggregate_init_test_without_sending_request(
+        dummy_vdaf::Vdaf::new(),
+        VdafInstance::Fake,
+        dummy_vdaf::AggregationParam(0),
+        0,
+    )
+    .await;
    // Find a DapAuthToken among the task's aggregator auth tokens
     let (auth_header, auth_value) = test_case
         .task
@@ -216,7 +276,13 @@
 #[case::not_base64("Bearer: ")]
 #[tokio::test]
 async fn aggregation_job_init_malformed_authorization_header(#[case] header_value: &'static str) {
-    let test_case = setup_aggregate_init_test_without_sending_request().await;
+    let test_case = setup_aggregate_init_test_without_sending_request(
+        dummy_vdaf::Vdaf::new(),
+        VdafInstance::Fake,
+        dummy_vdaf::AggregationParam(0),
+        0,
+    )
+    .await;
 
     let response = put(test_case
         .task
@@ -254,7 +320,7 @@ async fn aggregation_job_mutation_aggregation_job() {
     let mutated_aggregation_job_init_req = AggregationJobInitializeReq::new(
         dummy_vdaf::AggregationParam(1).get_encoded(),
         PartialBatchSelector::new_time_interval(),
-        test_case.report_shares,
+        test_case.prepare_inits,
     );
 
     let response = put_aggregation_job(
@@ -273,28 +339,28 @@ async fn aggregation_job_mutation_report_shares() {
 
     // Put the aggregation job again, mutating the associated report shares' metadata such that
     // uniqueness constraints on client_reports are violated
-    for mutated_report_shares in [
+    for mutated_prepare_inits in [
         // Omit a report share that was included previously
-        Vec::from(&test_case.report_shares[0..test_case.report_shares.len() - 1]),
+        Vec::from(&test_case.prepare_inits[0..test_case.prepare_inits.len() - 1]),
         // Include a different report share than was included previously
         [
-            &test_case.report_shares[0..test_case.report_shares.len() - 1],
-            &[test_case.report_share_generator.next().0],
+            &test_case.prepare_inits[0..test_case.prepare_inits.len() - 1],
+            &[test_case.prepare_init_generator.next(&0).0],
         ]
         .concat(),
         // Include an extra report share than was included previously
         [
-            test_case.report_shares.as_slice(),
-            &[test_case.report_share_generator.next().0],
+            test_case.prepare_inits.as_slice(),
+            &[test_case.prepare_init_generator.next(&0).0],
         ]
         .concat(),
         // Reverse the order of the reports
-        test_case.report_shares.into_iter().rev().collect(),
+        test_case.prepare_inits.into_iter().rev().collect(),
     ] {
         let mutated_aggregation_job_init_req = AggregationJobInitializeReq::new(
             test_case.aggregation_param.get_encoded(),
             PartialBatchSelector::new_time_interval(),
-            mutated_report_shares,
+            mutated_prepare_inits,
         );
         let response = put_aggregation_job(
             &test_case.task,
@@ -309,20 +375,32 @@ async fn aggregation_job_mutation_report_shares() {
 
 #[tokio::test]
 async fn aggregation_job_mutation_report_aggregations() {
-    let test_case = setup_aggregate_init_test().await;
+    // We must run Poplar1 in this test so that the aggregation job won't finish on the first step
+
+    let aggregation_param =
+        Poplar1AggregationParam::try_from_prefixes(Vec::from([IdpfInput::from_bools(&[false])]))
+            .unwrap();
+    let test_case = setup_aggregate_init_test_for_vdaf(
+        Poplar1::new_sha3(1),
+        VdafInstance::Poplar1 { bits: 1 },
+        aggregation_param,
+        IdpfInput::from_bools(&[true]),
+    )
+    .await;
 
-    // Generate some new reports using the existing reports' metadata, but varying the input shares
-    // such that the prepare state computed during aggregation initializaton won't match the first
-    // aggregation job.
-    let mutated_report_shares_generator = test_case
-        .report_share_generator
-        .with_vdaf(dummy_vdaf::Vdaf::new().with_input_share(dummy_vdaf::InputShare(1)));
-    let mutated_report_shares = test_case
-        .report_shares
+    // Generate some new reports using the existing reports' metadata, but varying the measurement
+    // values such that the prepare state computed during aggregation initialization won't match
+    // the first aggregation job.
+    let mutated_prepare_inits = test_case
+        .prepare_inits
         .iter()
         .map(|s| {
-            mutated_report_shares_generator
-                .next_with_metadata(s.metadata().clone())
+            test_case
+                .prepare_init_generator
+                .next_with_metadata(
+                    s.report_share().metadata().clone(),
+                    &IdpfInput::from_bools(&[false]),
+                )
                 .0
         })
         .collect();
@@ -330,8 +408,10 @@ async fn aggregation_job_mutation_report_aggregations() {
     let mutated_aggregation_job_init_req = AggregationJobInitializeReq::new(
         test_case.aggregation_param.get_encoded(),
         PartialBatchSelector::new_time_interval(),
-        mutated_report_shares,
+        mutated_prepare_inits,
     );
+
+    tracing::info!("putting mutated agg job");
     let response = put_aggregation_job(
         &test_case.task,
         &test_case.aggregation_job_id,
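The test helpers above are now generic over the VDAF: `PrepareInitGenerator` emits a `PrepareInit` (report share plus the leader's first ping-pong message) together with the full VDAF transcript. A usage sketch, reusing the same Poplar1 parameters these tests pass (illustrative only; `clock` and `task` are assumed to exist):

```rust
let aggregation_param = Poplar1AggregationParam::try_from_prefixes(Vec::from([
    IdpfInput::from_bools(&[false]),
]))
.unwrap();
let generator = PrepareInitGenerator::new(
    clock.clone(),
    task.clone(),
    Poplar1::new_sha3(1), // 1-bit Poplar1, so preparation takes more than one step
    aggregation_param,
);
// Each call runs the VDAF end to end: the PrepareInit goes into an
// AggregationJobInitializeReq, and the transcript lets a test replay the
// leader's message for any later round.
let (prepare_init, transcript) = generator.next(&IdpfInput::from_bools(&[true]));
```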
diff --git a/aggregator/src/aggregator/aggregation_job_continue.rs b/aggregator/src/aggregator/aggregation_job_continue.rs
index 72c79ccd4..f4d4b5800 100644
--- a/aggregator/src/aggregator/aggregation_job_continue.rs
+++ b/aggregator/src/aggregator/aggregation_job_continue.rs
@@ -1,5 +1,6 @@
 //! Implements portions of aggregation job continuation for the helper.
 
+use super::error::handle_ping_pong_error;
 use crate::aggregator::{accumulator::Accumulator, Error, VdafOps};
 use futures::future::try_join_all;
 use janus_aggregator_core::{
@@ -13,16 +14,18 @@ use janus_aggregator_core::{
 };
 use janus_core::time::Clock;
 use janus_messages::{
-    AggregationJobContinueReq, AggregationJobResp, PrepareStep, PrepareStepResult, ReportShareError,
+    AggregationJobContinueReq, AggregationJobResp, PrepareError, PrepareResp, PrepareStepResult,
+    Role,
 };
-use opentelemetry::{metrics::Counter, KeyValue};
+use opentelemetry::metrics::Counter;
 use prio::{
     codec::{Encode, ParameterizedDecode},
-    vdaf::{self, PrepareTransition},
+    topology::ping_pong::{self, PingPongTopology},
+    vdaf,
 };
-use std::{io::Cursor, sync::Arc};
+use std::sync::Arc;
 use tokio::try_join;
-use tracing::{info, trace_span};
+use tracing::trace_span;
 
 impl VdafOps {
     /// Step the helper's aggregation job to the next round of VDAF preparation using the round `n`
@@ -69,12 +72,10 @@ impl VdafOps {
                 if report_agg.report_id() != prep_step.report_id() {
                     // This report was omitted by the leader because of a prior failure. Note that
                    // the report was dropped (if it's not already in an error state) and continue.
-                    if matches!(report_agg.state(), ReportAggregationState::Waiting(_, _)) {
+                    if matches!(report_agg.state(), ReportAggregationState::Waiting(_)) {
                         *report_agg = report_agg
                             .clone()
-                            .with_state(ReportAggregationState::Failed(
-                                ReportShareError::ReportDropped,
-                            ))
+                            .with_state(ReportAggregationState::Failed(PrepareError::ReportDropped))
                             .with_last_prep_step(None);
                     }
                     continue;
@@ -94,18 +95,16 @@ impl VdafOps {
                 if !conflicting_aggregate_share_jobs.is_empty() {
                     *report_aggregation = report_aggregation
                         .clone()
-                        .with_state(ReportAggregationState::Failed(
-                            ReportShareError::BatchCollected,
-                        ))
-                        .with_last_prep_step(Some(PrepareStep::new(
+                        .with_state(ReportAggregationState::Failed(PrepareError::BatchCollected))
+                        .with_last_prep_step(Some(PrepareResp::new(
                             *prep_step.report_id(),
-                            PrepareStepResult::Failed(ReportShareError::BatchCollected),
+                            PrepareStepResult::Reject(PrepareError::BatchCollected),
                         )));
                     continue;
                 }
 
-                let prep_state = match report_aggregation.state() {
-                    ReportAggregationState::Waiting(prep_state, _) => prep_state,
+                let transition = match report_aggregation.state() {
+                    ReportAggregationState::Waiting(transition) => transition,
                     _ => {
                         return Err(datastore::Error::User(
                             Error::UnrecognizedMessage(
@@ -117,84 +116,91 @@ impl VdafOps {
                     }
                 };
 
-                // Parse preparation message out of prepare step received from leader.
-                let prep_msg = match prep_step.result() {
-                    PrepareStepResult::Continued(payload) => A::PrepareMessage::decode_with_param(
-                        prep_state,
-                        &mut Cursor::new(payload.as_ref()),
-                    )?,
-                    _ => {
-                        return Err(datastore::Error::User(
-                            Error::UnrecognizedMessage(
-                                Some(*task.id()),
-                                "leader sent non-Continued prepare step",
-                            )
-                            .into(),
-                        ));
-                    }
-                };
-
-                // Compute the next transition.
-                let prepare_step_res = trace_span!("VDAF preparation")
-                    .in_scope(|| vdaf.prepare_step(prep_state.clone(), prep_msg));
-                match prepare_step_res {
-                    Ok(PrepareTransition::Continue(prep_state, prep_share)) => {
-                        *report_aggregation = report_aggregation
-                            .clone()
-                            .with_state(ReportAggregationState::Waiting(prep_state, None))
-                            .with_last_prep_step(Some(PrepareStep::new(
-                                *prep_step.report_id(),
-                                PrepareStepResult::Continued(prep_share.get_encoded()),
-                            )));
-                    }
-
-                    Ok(PrepareTransition::Finish(output_share)) => {
-                        accumulator.update(
-                            helper_aggregation_job.partial_batch_identifier(),
-                            prep_step.report_id(),
-                            report_aggregation.time(),
-                            &output_share,
-                        )?;
-                        *report_aggregation = report_aggregation
-                            .clone()
-                            .with_state(ReportAggregationState::Finished)
-                            .with_last_prep_step(Some(PrepareStep::new(
-                                *prep_step.report_id(),
-                                PrepareStepResult::Finished,
-                            )));
-                    }
-
-                    Err(error) => {
-                        info!(
-                            task_id = %task.id(),
-                            job_id = %helper_aggregation_job.id(),
-                            report_id = %prep_step.report_id(),
-                            ?error, "Prepare step failed",
-                        );
-                        aggregate_step_failure_counter
-                            .add(1, &[KeyValue::new("type", "prepare_step_failure")]);
-                        *report_aggregation = report_aggregation
-                            .clone()
-                            .with_state(ReportAggregationState::Failed(
-                                ReportShareError::VdafPrepError,
-                            ))
-                            .with_last_prep_step(Some(PrepareStep::new(
-                                *prep_step.report_id(),
-                                PrepareStepResult::Failed(ReportShareError::VdafPrepError),
-                            )))
-                    }
-                };
+                let (report_aggregation_state, prepare_step_result, output_share) =
+                    trace_span!("VDAF preparation")
+                        .in_scope(|| {
+                            // Evaluate the stored transition to recover our current state.
+                            transition
+                                .evaluate(vdaf.as_ref())
+                                .and_then(|(state, _)| {
+                                    // Then continue with the incoming message.
+                                    vdaf.continued(ping_pong::Role::Helper, state, prep_step.message())
+                                })
+                                .and_then(|continued_value| match continued_value {
+                                    ping_pong::ContinuedValue::WithMessage {
+                                        transition: new_transition,
+                                    } => {
+                                        let (new_state, message) =
+                                            new_transition.evaluate(vdaf.as_ref())?;
+                                        let (report_aggregation_state, output_share) = match new_state {
+                                            // Helper did not finish. Store the new transition and await the next message
+                                            // from the Leader to advance preparation.
+                                            ping_pong::State::Continued(_) => {
+                                                (ReportAggregationState::Waiting(new_transition), None)
+                                            }
+                                            // Helper finished. Commit the output share.
+                                            ping_pong::State::Finished(output_share) => {
+                                                (ReportAggregationState::Finished, Some(output_share))
+                                            }
+                                        };
+
+                                        Ok((
+                                            report_aggregation_state,
+                                            // Helper has an outgoing message for Leader
+                                            PrepareStepResult::Continue { message },
+                                            output_share,
+                                        ))
+                                    }
+                                    ping_pong::ContinuedValue::FinishedNoMessage { output_share } => {
+                                        Ok((
+                                            ReportAggregationState::Finished,
+                                            PrepareStepResult::Finished,
+                                            Some(output_share),
+                                        ))
+                                    }
+                                })
+                        })
+                        .map_err(|error| {
+                            handle_ping_pong_error(
+                                task.id(),
+                                Role::Leader,
+                                prep_step.report_id(),
+                                error,
+                                &aggregate_step_failure_counter,
+                            )
+                        })
+                        .unwrap_or_else(|prepare_error| {
+                            (
+                                ReportAggregationState::Failed(prepare_error),
+                                PrepareStepResult::Reject(prepare_error),
+                                None,
+                            )
+                        });
+
+                *report_aggregation = report_aggregation
+                    .clone()
+                    .with_state(report_aggregation_state)
+                    .with_last_prep_step(Some(PrepareResp::new(
+                        *prep_step.report_id(),
+                        prepare_step_result,
+                    )));
+                if let Some(output_share) = output_share {
+                    accumulator.update(
+                        helper_aggregation_job.partial_batch_identifier(),
+                        prep_step.report_id(),
+                        report_aggregation.time(),
+                        &output_share,
+                    )?;
+                }
             }
 
             for report_agg in report_aggregations_iter {
                 // This report was omitted by the leader because of a prior failure. Note that the
                 // report was dropped (if it's not already in an error state) and continue.
-                if matches!(report_agg.state(), ReportAggregationState::Waiting(_, _)) {
+                if matches!(report_agg.state(), ReportAggregationState::Waiting(_)) {
                     *report_agg = report_agg
                         .clone()
-                        .with_state(ReportAggregationState::Failed(
-                            ReportShareError::ReportDropped,
-                        ))
+                        .with_state(ReportAggregationState::Failed(PrepareError::ReportDropped))
                         .with_last_prep_step(None);
                 }
             }
@@ -206,26 +212,24 @@ impl VdafOps {
                 if unwritable_reports.contains(report_aggregation.report_id()) {
                     *report_aggregation = report_aggregation
                         .clone()
-                        .with_state(ReportAggregationState::Failed(
-                            ReportShareError::BatchCollected,
-                        ))
-                        .with_last_prep_step(Some(PrepareStep::new(
+                        .with_state(ReportAggregationState::Failed(PrepareError::BatchCollected))
+                        .with_last_prep_step(Some(PrepareResp::new(
                             *report_aggregation.report_id(),
-                            PrepareStepResult::Failed(ReportShareError::BatchCollected),
+                            PrepareStepResult::Reject(PrepareError::BatchCollected),
                         )));
                 }
             }
 
             let saw_continue = report_aggregations.iter().any(|report_agg| {
                 matches!(
-                    report_agg.last_prep_step().map(PrepareStep::result),
-                    Some(PrepareStepResult::Continued(_))
+                    report_agg.last_prep_step().map(PrepareResp::result),
+                    Some(PrepareStepResult::Continue { .. })
                 )
             });
             let saw_finish = report_aggregations.iter().any(|report_agg| {
                 matches!(
-                    report_agg.last_prep_step().map(PrepareStep::result),
-                    Some(PrepareStepResult::Finished)
+                    report_agg.last_prep_step().map(PrepareResp::result),
+                    Some(PrepareStepResult::Finished { .. })
                 )
             });
             let helper_aggregation_job = helper_aggregation_job
@@ -252,7 +256,7 @@ impl VdafOps {
                 try_join_all(
                     report_aggregations
                         .iter()
-                        .map(|ra| tx.update_report_aggregation(ra))
+                        .map(|report_agg| tx.update_report_aggregation(report_agg)),
                 ),
             )?;
 
@@ -260,16 +264,16 @@ impl VdafOps {
     }
 
     /// Constructs an AggregationJobResp from a given set of Helper report aggregations.
-    pub(super) fn aggregation_job_resp_for<
-        const SEED_SIZE: usize,
-        A: vdaf::Aggregator<SEED_SIZE, 16>,
-    >(
+    pub(super) fn aggregation_job_resp_for<const SEED_SIZE: usize, A>(
         report_aggregations: impl IntoIterator<Item = ReportAggregation<SEED_SIZE, A>>,
-    ) -> AggregationJobResp {
+    ) -> AggregationJobResp
+    where
+        A: vdaf::Aggregator<SEED_SIZE, 16>,
+    {
         AggregationJobResp::new(
             report_aggregations
                 .into_iter()
-                .filter_map(|ra| ra.last_prep_step().cloned())
+                .filter_map(|v| v.last_prep_step().cloned())
                 .collect(),
         )
     }
@@ -366,7 +370,7 @@ pub mod test_util {
 #[cfg(test)]
 mod tests {
     use crate::aggregator::{
-        aggregate_init_tests::ReportShareGenerator,
+        aggregate_init_tests::PrepareInitGenerator,
         aggregation_job_continue::test_util::{
             post_aggregation_job_and_decode, post_aggregation_job_expecting_error,
             post_aggregation_job_expecting_status,
@@ -386,87 +390,121 @@ mod tests {
         test_util::noop_meter,
     };
     use janus_core::{
-        task::VdafInstance,
-        test_util::{dummy_vdaf, install_test_trace_subscriber},
+        task::{VdafInstance, VERIFY_KEY_LENGTH},
+        test_util::install_test_trace_subscriber,
         time::{IntervalExt, MockClock},
     };
     use janus_messages::{
         query_type::TimeInterval, AggregationJobContinueReq, AggregationJobId, AggregationJobResp,
-        AggregationJobRound, Interval, PrepareStep, PrepareStepResult, Role,
+        AggregationJobRound, Interval, PrepareContinue, PrepareResp, PrepareStepResult, Role,
+    };
+    use prio::{
+        idpf::IdpfInput,
+        vdaf::{
+            poplar1::{Poplar1, Poplar1AggregationParam},
+            prg::PrgSha3,
+            Aggregator,
+        },
     };
-    use prio::codec::Encode;
     use rand::random;
     use std::sync::Arc;
     use trillium::{Handler, Status};
 
-    struct AggregationJobContinueTestCase {
+    struct AggregationJobContinueTestCase<
+        const VERIFY_KEY_LENGTH: usize,
+        V: Aggregator<VERIFY_KEY_LENGTH, 16>,
+    > {
         task: Task,
         datastore: Arc<Datastore<MockClock>>,
-        report_generator: ReportShareGenerator,
+        prepare_init_generator: PrepareInitGenerator<VERIFY_KEY_LENGTH, V>,
         aggregation_job_id: AggregationJobId,
         first_continue_request: AggregationJobContinueReq,
         first_continue_response: Option<AggregationJobResp>,
+        aggregation_param: V::AggregationParam,
         handler: Box<dyn Handler>,
         _ephemeral_datastore: EphemeralDatastore,
     }
 
     /// Set up a helper with an aggregation job in round 0
     #[allow(clippy::unit_arg)]
-    async fn setup_aggregation_job_continue_test() -> AggregationJobContinueTestCase {
+    async fn setup_aggregation_job_continue_test(
+    ) -> AggregationJobContinueTestCase<VERIFY_KEY_LENGTH, Poplar1<PrgSha3, 16>> {
        // Prepare datastore & request.
        install_test_trace_subscriber();
 
         let aggregation_job_id = random();
-        let task =
-            TaskBuilder::new(QueryType::TimeInterval, VdafInstance::Fake, Role::Helper).build();
+        let task = TaskBuilder::new(
+            QueryType::TimeInterval,
+            VdafInstance::Poplar1 { bits: 1 },
+            Role::Helper,
+        )
+        .build();
         let clock = MockClock::default();
         let ephemeral_datastore = ephemeral_datastore().await;
         let meter = noop_meter();
         let datastore = Arc::new(ephemeral_datastore.datastore(clock.clone()).await);
 
-        let report_generator = ReportShareGenerator::new(
+        let aggregation_param = Poplar1AggregationParam::try_from_prefixes(Vec::from([
+            IdpfInput::from_bools(&[false]),
+        ]))
+        .unwrap();
+        let prepare_init_generator = PrepareInitGenerator::new(
             clock.clone(),
             task.clone(),
-            dummy_vdaf::AggregationParam::default(),
+            Poplar1::new_sha3(1),
+            aggregation_param.clone(),
         );
 
-        let report = report_generator.next();
+        let (prepare_init, transcript) =
+            prepare_init_generator.next(&IdpfInput::from_bools(&[true]));
 
         datastore
             .run_tx(|tx| {
-                let (task, report) = (task.clone(), report.clone());
+                let (task, aggregation_param, prepare_init, transcript) = (
+                    task.clone(),
+                    aggregation_param.clone(),
+                    prepare_init.clone(),
+                    transcript.clone(),
+                );
 
                 Box::pin(async move {
                     tx.put_task(&task).await.unwrap();
-                    tx.put_report_share(task.id(), &report.0).await.unwrap();
+                    tx.put_report_share(task.id(), prepare_init.report_share())
+                        .await
+                        .unwrap();
+
+                    tx.put_aggregation_job(&AggregationJob::<
+                        VERIFY_KEY_LENGTH,
+                        TimeInterval,
+                        Poplar1<PrgSha3, 16>,
+                    >::new(
+                        *task.id(),
+                        aggregation_job_id,
+                        aggregation_param,
+                        (),
+                        Interval::from_time(prepare_init.report_share().metadata().time()).unwrap(),
+                        AggregationJobState::InProgress,
+                        AggregationJobRound::from(0),
+                    ))
+                    .await
+                    .unwrap();
 
-                    tx.put_aggregation_job(
-                        &AggregationJob::<0, TimeInterval, dummy_vdaf::Vdaf>::new(
+                    tx.put_report_aggregation::<VERIFY_KEY_LENGTH, Poplar1<PrgSha3, 16>>(
+                        &ReportAggregation::new(
                             *task.id(),
                             aggregation_job_id,
-                            dummy_vdaf::AggregationParam::default(),
-                            (),
-                            Interval::from_time(report.0.metadata().time()).unwrap(),
-                            AggregationJobState::InProgress,
-                            AggregationJobRound::from(0),
+                            *prepare_init.report_share().metadata().id(),
+                            *prepare_init.report_share().metadata().time(),
+                            0,
+                            None,
+                            ReportAggregationState::Waiting(
+                                transcript.helper_prepare_transitions[0].transition.clone(),
+                            ),
                         ),
                     )
                     .await
                     .unwrap();
 
-                    let (prep_state, _) = report.1.helper_prep_state(0);
-                    tx.put_report_aggregation::<0, dummy_vdaf::Vdaf>(&ReportAggregation::new(
-                        *task.id(),
-                        aggregation_job_id,
-                        *report.0.metadata().id(),
-                        *report.0.metadata().time(),
-                        0,
-                        None,
-                        ReportAggregationState::Waiting(*prep_state, None),
-                    ))
-                    .await
-                    .unwrap();
-
                     Ok(())
                 })
             })
@@ -475,9 +513,9 @@ mod tests {
 
         let first_continue_request = AggregationJobContinueReq::new(
             AggregationJobRound::from(1),
-            Vec::from([PrepareStep::new(
-                *report.0.metadata().id(),
-                PrepareStepResult::Continued(report.1.prepare_messages[0].get_encoded()),
+            Vec::from([PrepareContinue::new(
+                *prepare_init.report_share().metadata().id(),
+                transcript.leader_prepare_transitions[1].message.clone(),
             )]),
         );
 
@@ -494,8 +532,9 @@ mod tests {
         AggregationJobContinueTestCase {
             task,
             datastore,
-            report_generator,
+            prepare_init_generator,
             aggregation_job_id,
+            aggregation_param,
             first_continue_request,
             first_continue_response: None,
             handler: Box::new(handler),
@@ -503,10 +542,10 @@ mod tests {
         }
     }
 
-    /// Set up a helper with an aggregation job in round 1
+    /// Set up a helper with an aggregation job in round 1.
    #[allow(clippy::unit_arg)]
-    async fn setup_aggregation_job_continue_round_recovery_test() -> AggregationJobContinueTestCase
-    {
+    async fn setup_aggregation_job_continue_round_recovery_test(
+    ) -> AggregationJobContinueTestCase<VERIFY_KEY_LENGTH, Poplar1<PrgSha3, 16>> {
         let mut test_case = setup_aggregation_job_continue_test().await;
 
         let first_continue_response = post_aggregation_job_and_decode(
@@ -525,7 +564,7 @@ mod tests {
                 .first_continue_request
                 .prepare_steps()
                 .iter()
-                .map(|step| PrepareStep::new(*step.report_id(), PrepareStepResult::Finished))
+                .map(|step| PrepareResp::new(*step.report_id(), PrepareStepResult::Finished))
                 .collect()
             )
         );
@@ -581,23 +620,26 @@ mod tests {
     async fn aggregation_job_continue_round_recovery_mutate_continue_request() {
         let test_case = setup_aggregation_job_continue_round_recovery_test().await;
 
-        let unrelated_report = test_case.report_generator.next();
+        let (unrelated_prepare_init, unrelated_transcript) = test_case
+            .prepare_init_generator
+            .next(&IdpfInput::from_bools(&[false]));
 
         let (before_aggregation_job, before_report_aggregations) = test_case
             .datastore
             .run_tx(|tx| {
-                let (task_id, unrelated_report, aggregation_job_id) = (
+                let (task_id, unrelated_prepare_init, aggregation_job_id, aggregation_param) = (
                     *test_case.task.id(),
-                    unrelated_report.clone(),
+                    unrelated_prepare_init.clone(),
                     test_case.aggregation_job_id,
+                    test_case.aggregation_param.clone(),
                 );
                 Box::pin(async move {
-                    tx.put_report_share(&task_id, &unrelated_report.0)
+                    tx.put_report_share(&task_id, unrelated_prepare_init.report_share())
                         .await
                         .unwrap();
 
                     let aggregation_job = tx
-                        .get_aggregation_job::<0, TimeInterval, dummy_vdaf::Vdaf>(
+                        .get_aggregation_job::<VERIFY_KEY_LENGTH, TimeInterval, Poplar1<PrgSha3, 16>>(
                             &task_id,
                             &aggregation_job_id,
                         )
@@ -605,11 +647,12 @@ mod tests {
                         .unwrap();
 
                     let report_aggregations = tx
-                        .get_report_aggregations_for_aggregation_job::<0, dummy_vdaf::Vdaf>(
-                            &dummy_vdaf::Vdaf::new(),
+                        .get_report_aggregations_for_aggregation_job::<VERIFY_KEY_LENGTH, Poplar1<PrgSha3, 16>>(
+                            &Poplar1::new_sha3(1),
                             &Role::Helper,
                             &task_id,
                             &aggregation_job_id,
+                            &aggregation_param,
                         )
                         .await
                         .unwrap();
@@ -624,9 +667,11 @@ mod tests {
         // ID.
         let modified_request = AggregationJobContinueReq::new(
             test_case.first_continue_request.round(),
-            Vec::from([PrepareStep::new(
-                *unrelated_report.0.metadata().id(),
-                PrepareStepResult::Continued(unrelated_report.1.prepare_messages[0].get_encoded()),
+            Vec::from([PrepareContinue::new(
+                *unrelated_prepare_init.report_share().metadata().id(),
+                unrelated_transcript.leader_prepare_transitions[1]
+                    .message
+                    .clone(),
             )]),
         );
 
@@ -643,11 +688,11 @@ mod tests {
         let (after_aggregation_job, after_report_aggregations) = test_case
             .datastore
             .run_tx(|tx| {
-                let (task_id, aggregation_job_id) =
-                    (*test_case.task.id(), test_case.aggregation_job_id);
+                let (task_id, aggregation_job_id, aggregation_param) =
+                    (*test_case.task.id(), test_case.aggregation_job_id, test_case.aggregation_param.clone());
                 Box::pin(async move {
                     let aggregation_job = tx
-                        .get_aggregation_job::<0, TimeInterval, dummy_vdaf::Vdaf>(
+                        .get_aggregation_job::<VERIFY_KEY_LENGTH, TimeInterval, Poplar1<PrgSha3, 16>>(
                             &task_id,
                             &aggregation_job_id,
                         )
@@ -655,11 +700,12 @@ mod tests {
                         .unwrap();
 
                     let report_aggregations = tx
-                        .get_report_aggregations_for_aggregation_job::<0, dummy_vdaf::Vdaf>(
-                            &dummy_vdaf::Vdaf::new(),
+                        .get_report_aggregations_for_aggregation_job::<VERIFY_KEY_LENGTH, Poplar1<PrgSha3, 16>>(
+                            &Poplar1::new_sha3(1),
                             &Role::Helper,
                             &task_id,
                             &aggregation_job_id,
+                            &aggregation_param,
                         )
                         .await
                         .unwrap();
@@ -689,7 +735,7 @@ mod tests {
                     // round mismatch error instead of tripping the check for a request to continue
                    // to round 0.
                    let aggregation_job = tx
-                        .get_aggregation_job::<0, TimeInterval, dummy_vdaf::Vdaf>(
+                        .get_aggregation_job::<VERIFY_KEY_LENGTH, TimeInterval, Poplar1<PrgSha3, 16>>(
                             &task_id,
                             &aggregation_job_id,
                        )
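For continuation, the helper no longer decodes a prepare message and calls `prepare_step`; it re-evaluates the stored transition and hands the leader's message to `continued`. A condensed sketch of the flow implemented above (same hedges as before — names come from this diff's call sites, not a released libprio API; `stored_transition` is what `ReportAggregationState::Waiting` now holds):

```rust
// Recover the helper's current state from the stored transition, then apply
// the leader's incoming message.
let (state, _) = stored_transition.evaluate(&vdaf)?;
match vdaf.continued(ping_pong::Role::Helper, state, leader_message)? {
    // The helper has a reply: evaluate the new transition to learn whether it
    // is still waiting or has finished with an output share.
    ping_pong::ContinuedValue::WithMessage { transition } => {
        let (new_state, message) = transition.evaluate(&vdaf)?;
        // store Waiting(transition) or Finished (+ accumulate the output
        // share), and reply with PrepareStepResult::Continue { message }
    }
    // No further message: the helper is done; just commit the output share
    // and reply PrepareStepResult::Finished.
    ping_pong::ContinuedValue::FinishedNoMessage { output_share } => {
        // accumulator.update(..., &output_share)
    }
}
```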
- let (agg_jobs, batches) = - job_creator - .datastore - .run_tx(|tx| { - let task = Arc::clone(&task); - Box::pin(async move { - Ok(read_aggregate_info_for_task::< + let (agg_jobs, batches) = job_creator + .datastore + .run_tx(|tx| { + let task = Arc::clone(&task); + Box::pin(async move { + Ok( + read_aggregate_info_for_task::< VERIFY_KEY_LENGTH, TimeInterval, Prio3Count, _, - >(tx, task.id()) - .await) - }) + >(tx, task.id(), &Prio3Count::new_count(2).unwrap()) + .await, + ) }) - .await - .unwrap(); + }) + .await + .unwrap(); assert_eq!(agg_jobs.len(), 1); let report_ids: HashSet<_> = agg_jobs.into_iter().next().unwrap().1.into_iter().collect(); assert_eq!( @@ -1107,23 +1112,24 @@ mod tests { .unwrap(); // Verify. - let (agg_jobs, batches) = - job_creator - .datastore - .run_tx(|tx| { - let task = task.clone(); - Box::pin(async move { - Ok(read_aggregate_info_for_task::< + let (agg_jobs, batches) = job_creator + .datastore + .run_tx(|tx| { + let task = task.clone(); + Box::pin(async move { + Ok( + read_aggregate_info_for_task::< VERIFY_KEY_LENGTH, TimeInterval, Prio3Count, _, - >(tx, task.id()) - .await) - }) + >(tx, task.id(), &Prio3Count::new_count(2).unwrap()) + .await, + ) }) - .await - .unwrap(); + }) + .await + .unwrap(); let mut seen_report_ids = HashSet::new(); for (agg_job, report_ids) in &agg_jobs { // Job immediately finished since all reports are in a closed batch. @@ -1243,7 +1249,9 @@ mod tests { FixedSize, Prio3Count, _, - >(tx, task.id()) + >( + tx, task.id(), &Prio3Count::new_count(2).unwrap() + ) .await, )) }) @@ -1404,7 +1412,9 @@ mod tests { FixedSize, Prio3Count, _, - >(tx, task.id()) + >( + tx, task.id(), &Prio3Count::new_count(2).unwrap() + ) .await, )) }) @@ -1520,7 +1530,9 @@ mod tests { FixedSize, Prio3Count, _, - >(tx, task.id()) + >( + tx, task.id(), &Prio3Count::new_count(2).unwrap() + ) .await, )) }) @@ -1577,7 +1589,9 @@ mod tests { FixedSize, Prio3Count, _, - >(tx, task.id()) + >( + tx, task.id(), &Prio3Count::new_count(2).unwrap() + ) .await, )) }) @@ -1701,7 +1715,9 @@ mod tests { FixedSize, Prio3Count, _, - >(tx, task.id()) + >( + tx, task.id(), &Prio3Count::new_count(2).unwrap() + ) .await, )) }) @@ -1765,7 +1781,9 @@ mod tests { FixedSize, Prio3Count, _, - >(tx, task.id()) + >( + tx, task.id(), &Prio3Count::new_count(2).unwrap() + ) .await, )) }) @@ -1904,7 +1922,9 @@ mod tests { FixedSize, Prio3Count, _, - >(tx, task.id()) + >( + tx, task.id(), &Prio3Count::new_count(2).unwrap() + ) .await, )) }) @@ -2032,18 +2052,20 @@ mod tests { /// Test helper function that reads all aggregation jobs & batches for a given task ID, /// returning the aggregation jobs, the report IDs included in the aggregation job, and the /// batches. Report IDs are returned in the order they are included in the aggregation job. 
- async fn read_aggregate_info_for_task< - const SEED_SIZE: usize, - Q: AccumulableQueryType, - A: vdaf::Aggregator, - C: Clock, - >( + async fn read_aggregate_info_for_task( tx: &Transaction<'_, C>, task_id: &TaskId, + vdaf: &A, ) -> ( Vec<(AggregationJob, Vec)>, Vec>, - ) { + ) + where + Q: AccumulableQueryType, + A: vdaf::Aggregator, + C: Clock, + for<'a> A::PrepareState: ParameterizedDecode<(&'a A, usize)>, + { try_join!( try_join_all( tx.get_aggregation_jobs_for_task(task_id) @@ -2053,18 +2075,19 @@ mod tests { .map(|agg_job| async { let agg_job_id = *agg_job.id(); tx.get_report_aggregations_for_aggregation_job( - &dummy_vdaf::Vdaf::new(), + vdaf, &Role::Leader, task_id, &agg_job_id, + agg_job.aggregation_parameter(), ) - .map_ok(move |report_aggs| { + .await + .map(|report_aggs| { ( agg_job, report_aggs.into_iter().map(|ra| *ra.report_id()).collect(), ) }) - .await }), ), tx.get_batches_for_task(task_id), diff --git a/aggregator/src/aggregator/aggregation_job_driver.rs b/aggregator/src/aggregator/aggregation_job_driver.rs index 8db01020a..aa1e5eaaa 100644 --- a/aggregator/src/aggregator/aggregation_job_driver.rs +++ b/aggregator/src/aggregator/aggregation_job_driver.rs @@ -21,8 +21,8 @@ use janus_core::{time::Clock, vdaf_dispatch}; use janus_messages::{ query_type::{FixedSize, TimeInterval}, AggregationJobContinueReq, AggregationJobInitializeReq, AggregationJobResp, - PartialBatchSelector, PrepareStep, PrepareStepResult, ReportId, ReportShare, ReportShareError, - Role, + PartialBatchSelector, PrepareContinue, PrepareError, PrepareInit, PrepareResp, + PrepareStepResult, ReportId, ReportShare, Role, }; use opentelemetry::{ metrics::{Counter, Histogram, Meter, Unit}, @@ -30,7 +30,8 @@ use opentelemetry::{ }; use prio::{ codec::{Decode, Encode, ParameterizedDecode}, - vdaf::{self, PrepareTransition}, + topology::ping_pong::{self, PingPongTopology}, + vdaf, }; use reqwest::Method; use std::{ @@ -41,6 +42,8 @@ use std::{ use tokio::try_join; use tracing::{info, trace_span, warn}; +use super::error::handle_ping_pong_error; + #[derive(Derivative)] #[derivative(Debug)] pub struct AggregationJobDriver { @@ -157,30 +160,31 @@ impl AggregationJobDriver { ) })?; - let aggregation_job_future = tx.get_aggregation_job::( - lease.leased().task_id(), - lease.leased().aggregation_job_id(), - ); - let report_aggregations_future = tx + let aggregation_job = tx + .get_aggregation_job::( + lease.leased().task_id(), + lease.leased().aggregation_job_id(), + ) + .await? + .ok_or_else(|| { + datastore::Error::User( + anyhow!( + "couldn't find aggregation job {} for task {}", + *lease.leased().aggregation_job_id(), + *lease.leased().task_id(), + ) + .into(), + ) + })?; + let report_aggregations = tx .get_report_aggregations_for_aggregation_job( vdaf.as_ref(), &Role::Leader, lease.leased().task_id(), lease.leased().aggregation_job_id(), - ); - - let (aggregation_job, report_aggregations) = - try_join!(aggregation_job_future, report_aggregations_future)?; - let aggregation_job = aggregation_job.ok_or_else(|| { - datastore::Error::User( - anyhow!( - "couldn't find aggregation job {} for task {}", - *lease.leased().aggregation_job_id(), - *lease.leased().task_id(), - ) - .into(), + aggregation_job.aggregation_parameter(), ) - })?; + .await?; // Read client reports, but only for report aggregations in state START. 
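Both read paths above now thread the VDAF instance (and, in the driver, the aggregation parameter) down to the datastore. The reason: a report aggregation in the Waiting state now stores a ping-pong transition wrapping an A::PrepareState, and VDAF prepare states have no self-describing wire encoding, so decoding one needs the VDAF and the aggregator index as context, hence the new for<'a> A::PrepareState: ParameterizedDecode<(&'a A, usize)> bound. This is also why the tests above construct a real Prio3Count instead of reading report aggregations back with dummy_vdaf. A minimal sketch of the decode-with-parameter pattern using prio's codec traits; ToyVdaf and ToyPrepareState are illustrative stand-ins, not Janus types:

    use prio::codec::{CodecError, ParameterizedDecode};
    use std::io::{Cursor, Read};

    // Stand-in for a VDAF whose prepare-state encoding length depends on the
    // instance (as with Prio3 proof shares or Poplar1 sketches).
    struct ToyVdaf {
        prepare_state_len: usize,
    }

    struct ToyPrepareState {
        bytes: Vec<u8>,
    }

    impl<'a> ParameterizedDecode<(&'a ToyVdaf, usize)> for ToyPrepareState {
        fn decode_with_param(
            (vdaf, _aggregator_id): &(&'a ToyVdaf, usize),
            bytes: &mut Cursor<&[u8]>,
        ) -> Result<Self, CodecError> {
            // Without the VDAF there is no way to know how many bytes to read:
            // prepare states are not length-prefixed on the wire.
            let mut buf = vec![0u8; vdaf.prepare_state_len];
            bytes.read_exact(&mut buf)?;
            Ok(Self { bytes: buf })
        }
    }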
// TODO(#224): create "get_client_reports_for_aggregation_job" datastore @@ -233,7 +237,7 @@ impl AggregationJobDriver { for report_aggregation in &report_aggregations { match report_aggregation.state() { ReportAggregationState::Start => saw_start = true, - ReportAggregationState::Waiting(_, _) => saw_waiting = true, + ReportAggregationState::Waiting(_) => saw_waiting = true, ReportAggregationState::Finished => saw_finished = true, ReportAggregationState::Failed(_) => (), // ignore failed aggregations } @@ -315,7 +319,7 @@ impl AggregationJobDriver { // Compute report shares to send to helper, and decrypt our input shares & initialize // preparation state. let mut report_aggregations_to_write = Vec::new(); - let mut report_shares = Vec::new(); + let mut prepare_inits = Vec::new(); let mut stepped_aggregations = Vec::new(); for report_aggregation in report_aggregations { // Look up report. @@ -325,9 +329,10 @@ impl AggregationJobDriver { info!(report_id = %report_aggregation.report_id(), "Attempted to aggregate missing report (most likely garbage collected)"); self.aggregate_step_failure_counter .add(1, &[KeyValue::new("type", "missing_client_report")]); - report_aggregations_to_write.push(report_aggregation.with_state( - ReportAggregationState::Failed(ReportShareError::ReportDropped), - )); + report_aggregations_to_write.push( + report_aggregation + .with_state(ReportAggregationState::Failed(PrepareError::ReportDropped)), + ); continue; }; @@ -342,44 +347,52 @@ impl AggregationJobDriver { self.aggregate_step_failure_counter .add(1, &[KeyValue::new("type", "duplicate_extension")]); report_aggregations_to_write.push(report_aggregation.with_state( - ReportAggregationState::Failed(ReportShareError::UnrecognizedMessage), + ReportAggregationState::Failed(PrepareError::UnrecognizedMessage), )); continue; } // Initialize the leader's preparation state from the input share. 
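In the hunk that follows, vdaf.prepare_init(...) gives way to vdaf.leader_initialize(...) from the branch's PingPongTopology extension trait: initialization now yields both the leader's ping-pong state, parked in a SteppedAggregation, and the first ping-pong message, which rides to the helper inside the new PrepareInit next to the report share. A simplified, self-contained model of that round-0 step; the types and the body are illustrative only (the real trait lives in prio::topology::ping_pong on the timg/ping-pong-topology branch):

    // Toy model of the leader's round-0 ping-pong step; names and shapes are
    // illustrative, not the branch's exact API.
    enum PingPongState {
        // Prepare state the leader retains locally; more rounds to come.
        Continued(Vec<u8>),
        // Output share recovered; commit is deferred until the helper confirms.
        Finished(Vec<u8>),
    }

    struct PingPongMessage(Vec<u8>);

    fn leader_initialize(
        verify_key: &[u8],
        nonce: &[u8], // DAP reuses the report ID as the VDAF nonce
        leader_input_share: &[u8],
    ) -> (PingPongState, PingPongMessage) {
        // A real implementation runs the VDAF prep-init for aggregator 0 here;
        // this toy just derives an opaque state and echoes a first message.
        let state: Vec<u8> = verify_key
            .iter()
            .chain(nonce)
            .chain(leader_input_share)
            .copied()
            .collect();
        (PingPongState::Continued(state.clone()), PingPongMessage(state))
    }

Note the nonce: this is why the real call passes report.metadata().id().as_ref(), per the new comment in the hunk.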
- let prepare_init_res = trace_span!("VDAF preparation").in_scope(|| { - vdaf.prepare_init( + match trace_span!("VDAF preparation").in_scope(|| { + vdaf.leader_initialize( verify_key.as_bytes(), - Role::Leader.index().unwrap(), aggregation_job.aggregation_parameter(), + // DAP report ID is used as VDAF nonce report.metadata().id().as_ref(), report.public_share(), report.leader_input_share(), ) - }); - let (prep_state, prep_share) = match prepare_init_res { - Ok(prep_state_and_share) => prep_state_and_share, - Err(error) => { - info!(report_id = %report_aggregation.report_id(), ?error, "Couldn't initialize leader's preparation state"); - self.aggregate_step_failure_counter - .add(1, &[KeyValue::new("type", "prepare_init_failure")]); - report_aggregations_to_write.push(report_aggregation.with_state( - ReportAggregationState::Failed(ReportShareError::VdafPrepError), + .map_err(|ping_pong_error| { + handle_ping_pong_error( + task.id(), + Role::Helper, + report.metadata().id(), + ping_pong_error, + &self.aggregate_step_failure_counter, + ) + }) + }) { + Ok((ping_pong_state, ping_pong_message)) => { + prepare_inits.push(PrepareInit::new( + ReportShare::new( + report.metadata().clone(), + report.public_share().get_encoded(), + report.helper_encrypted_input_share().clone(), + ), + ping_pong_message.clone(), )); + stepped_aggregations.push(SteppedAggregation { + report_aggregation, + leader_state: ping_pong_state, + }); + } + Err(prep_error) => { + report_aggregations_to_write.push( + report_aggregation.with_state(ReportAggregationState::Failed(prep_error)), + ); continue; } - }; - - report_shares.push(ReportShare::new( - report.metadata().clone(), - report.public_share().get_encoded(), - report.helper_encrypted_input_share().clone(), - )); - stepped_aggregations.push(SteppedAggregation { - report_aggregation, - leader_transition: PrepareTransition::Continue(prep_state, prep_share), - }); + } } // Construct request, send it to the helper, and process the response. @@ -388,7 +401,7 @@ impl AggregationJobDriver { let req = AggregationJobInitializeReq::::new( aggregation_job.aggregation_parameter().get_encoded(), PartialBatchSelector::new(aggregation_job.partial_batch_identifier().clone()), - report_shares, + prepare_inits, ); let resp_bytes = send_request_to_helper( @@ -410,9 +423,9 @@ impl AggregationJobDriver { lease, task, aggregation_job, - stepped_aggregations, + &stepped_aggregations, report_aggregations_to_write, - resp.prepare_steps(), + resp.prepare_resps(), ) .await } @@ -443,60 +456,47 @@ impl AggregationJobDriver { // Visit the report aggregations, ignoring any that have already failed; compute our own // next step & transitions to send to the helper. let mut report_aggregations_to_write = Vec::new(); - let mut prepare_steps = Vec::new(); + let mut prepare_continues = Vec::new(); let mut stepped_aggregations = Vec::new(); for report_aggregation in report_aggregations { - if let ReportAggregationState::Waiting(prep_state, prep_msg) = - report_aggregation.state() - { - let prep_msg = match prep_msg.as_ref() { - Some(prep_msg) => prep_msg, - None => { - // This error indicates programmer/system error (i.e. it cannot possibly be - // the fault of our co-aggregator). We still record this failure against a - // single report, rather than failing the entire request, to minimize impact - // if we ever encounter this bug. 
- info!(report_id = %report_aggregation.report_id(), "Report aggregation is missing prepare message"); - self.aggregate_step_failure_counter - .add(1, &[KeyValue::new("type", "missing_prepare_message")]); - report_aggregations_to_write.push(report_aggregation.with_state( - ReportAggregationState::Failed(ReportShareError::VdafPrepError), - )); - continue; - } - }; - - // Step our own state. - let prepare_step_res = trace_span!("VDAF preparation") - .in_scope(|| vdaf.prepare_step(prep_state.clone(), prep_msg.clone())); - let leader_transition = match prepare_step_res { - Ok(leader_transition) => leader_transition, + if let ReportAggregationState::Waiting(transition) = report_aggregation.state() { + let (prep_state, message) = match transition.evaluate(vdaf.as_ref()) { + Ok((state, message)) => (state, message), Err(error) => { - info!(report_id = %report_aggregation.report_id(), ?error, "Prepare step failed"); + // This shouldn't ever happen, because we'd never store a transition that + // can't be evaluated. Most likely this indicates a programmer error (e.g., + // using the wrong VDAF) but we still record this failure against a single + // report, rather than failing the entire request, to minimize impact if we + // ever encounter this bug. + info!( + report_id = %report_aggregation.report_id(), + error = ?error, + "Stored transition cannot be evaluated", + ); self.aggregate_step_failure_counter - .add(1, &[KeyValue::new("type", "prepare_step_failure")]); + .add(1, &[KeyValue::new("type", "invalid_ping_pong_transition")]); report_aggregations_to_write.push(report_aggregation.with_state( - ReportAggregationState::Failed(ReportShareError::VdafPrepError), + ReportAggregationState::Failed(PrepareError::VdafPrepError), )); continue; } }; - prepare_steps.push(PrepareStep::new( + prepare_continues.push(PrepareContinue::new( *report_aggregation.report_id(), - PrepareStepResult::Continued(prep_msg.get_encoded()), + message, )); stepped_aggregations.push(SteppedAggregation { - report_aggregation, - leader_transition, - }) + report_aggregation: report_aggregation.clone(), + leader_state: prep_state.clone(), + }); } } // Construct request, send it to the helper, and process the response. // TODO(#235): abandon work immediately on "terminal" failures from helper, or other // unexpected cases such as unknown/unexpected content type. - let req = AggregationJobContinueReq::new(aggregation_job.round(), prepare_steps); + let req = AggregationJobContinueReq::new(aggregation_job.round(), prepare_continues); let resp_bytes = send_request_to_helper( &self.http_client, @@ -517,9 +517,9 @@ impl AggregationJobDriver { lease, task, aggregation_job, - stepped_aggregations, + &stepped_aggregations, report_aggregations_to_write, - resp.prepare_steps(), + resp.prepare_resps(), ) .await } @@ -537,9 +537,9 @@ impl AggregationJobDriver { lease: Arc>, task: Arc, aggregation_job: AggregationJob, - stepped_aggregations: Vec>, + stepped_aggregations: &[SteppedAggregation], mut report_aggregations_to_write: Vec>, - helper_prep_steps: &[PrepareStep], + helper_prep_resps: &[PrepareResp], ) -> Result<()> where A: 'static, @@ -551,7 +551,7 @@ impl AggregationJobDriver { A::PrepareState: Send + Sync + Encode, { // Handle response, computing the new report aggregations to be stored. 
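A consequence of the ping-pong change visible in the continue path above: ReportAggregationState::Waiting no longer holds a (prepare state, Option<prepare message>) pair, it holds a single ping-pong transition that is evaluated lazily against the VDAF to recover the next state and the outbound message. Evaluation can only fail if the stored bytes do not match the VDAF being used, a programmer error, which is why failure increments invalid_ping_pong_transition and fails only the affected report. A toy sketch of the shape (illustrative, not the branch's API):

    // Toy shape of the transition now stored in ReportAggregationState::Waiting.
    struct Transition {
        // Enough captured state to replay one prepare step deterministically:
        // the previous prepare state plus the peer's inbound message.
        prev_state: Vec<u8>,
        inbound_message: Vec<u8>,
    }

    struct EvaluateError(&'static str);

    impl Transition {
        // Recover (next prepare state, outbound message). A mismatched VDAF is
        // the only way this fails: the stored bytes no longer decode.
        fn evaluate(
            &self,
            expected_state_len: usize,
        ) -> Result<(Vec<u8>, Vec<u8>), EvaluateError> {
            if self.prev_state.len() != expected_state_len {
                return Err(EvaluateError("stored transition does not match this VDAF"));
            }
            Ok((self.prev_state.clone(), self.inbound_message.clone()))
        }
    }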
- if stepped_aggregations.len() != helper_prep_steps.len() { + if stepped_aggregations.len() != helper_prep_resps.len() { return Err(anyhow!( "missing, duplicate, out-of-order, or unexpected prepare steps in response" )); @@ -561,95 +561,123 @@ impl AggregationJobDriver { self.batch_aggregation_shard_count, aggregation_job.aggregation_parameter().clone(), ); - for (stepped_aggregation, helper_prep_step) in - stepped_aggregations.into_iter().zip(helper_prep_steps) + for (stepped_aggregation, helper_prep_resp) in + stepped_aggregations.iter().zip(helper_prep_resps) { - let (report_aggregation, leader_transition) = ( - stepped_aggregation.report_aggregation, - stepped_aggregation.leader_transition, - ); - if helper_prep_step.report_id() != report_aggregation.report_id() { + if helper_prep_resp.report_id() != stepped_aggregation.report_aggregation.report_id() { return Err(anyhow!( "missing, duplicate, out-of-order, or unexpected prepare steps in response" )); } - let new_state = match helper_prep_step.result() { - PrepareStepResult::Continued(payload) => { - // If the leader continued too, combine the leader's prepare share with the - // helper's to compute next round's prepare message. Prepare to store the - // leader's new state & the prepare message. If the leader didn't continue, - // transition to INVALID. - if let PrepareTransition::Continue(leader_prep_state, leader_prep_share) = - leader_transition - { - let leader_prep_state = leader_prep_state.clone(); - let helper_prep_share = - A::PrepareShare::get_decoded_with_param(&leader_prep_state, payload) - .context("couldn't decode helper's prepare message"); - let prep_msg = helper_prep_share.and_then(|helper_prep_share| { - vdaf.prepare_preprocess([leader_prep_share.clone(), helper_prep_share]) - .context( - "couldn't preprocess leader & helper prepare shares into \ - prepare message", - ) + let new_state = (|| match helper_prep_resp.result() { + PrepareStepResult::Continue { + message: helper_prep_msg, + } => { + let state_and_message = vdaf + .continued( + ping_pong::Role::Leader, + stepped_aggregation.leader_state.clone(), + helper_prep_msg, + ) + .map_err(|ping_pong_error| { + handle_ping_pong_error( + task.id(), + Role::Helper, + stepped_aggregation.report_aggregation.report_id(), + ping_pong_error, + &self.aggregate_step_failure_counter, + ) }); - match prep_msg { - Ok(prep_msg) => { - ReportAggregationState::Waiting(leader_prep_state, Some(prep_msg)) - } - Err(error) => { - info!(report_id = %report_aggregation.report_id(), ?error, "Couldn't compute prepare message"); + + match state_and_message { + Ok(ping_pong::ContinuedValue::WithMessage { transition }) => { + // Leader did not finish. Store our state and outgoing message for the + // next round. + // n.b. it's possible we finished and recovered an output share at the + // VDAF level (i.e., state may be ping_pong::State::Finished) but we + // cannot finish at the DAP layer and commit the output share until we + // get confirmation from the Helper that they finished, too. + ReportAggregationState::Waiting(transition) + } + Ok(ping_pong::ContinuedValue::FinishedNoMessage { output_share }) => { + // We finished and have no outgoing message, meaning the Helper was + // already finished. Commit the output share. 
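The Continue branch above funnels the helper's message through vdaf.continued(...), which resolves one of two ways: WithMessage { transition } means the exchange is still in flight, so the transition is parked in Waiting for the next round (even if the leader already holds a VDAF-level output share inside State::Finished, it must not be committed until the helper confirms), while FinishedNoMessage { output_share } means both sides are done and the share can be accumulated immediately. A toy decision function capturing that rule (illustrative types):

    // Toy decision logic mirroring the Continue branch.
    enum ContinuedValue<T, O> {
        WithMessage { transition: T },
        FinishedNoMessage { output_share: O },
    }

    enum NewState<T> {
        Waiting(T), // park the transition for the next round
        Finished,   // output share committed to the accumulator
        Failed(&'static str),
    }

    fn on_helper_continue<T, O>(
        continued: Result<ContinuedValue<T, O>, &'static str>,
        mut accumulate: impl FnMut(O) -> Result<(), &'static str>,
    ) -> NewState<T> {
        match continued {
            Ok(ContinuedValue::WithMessage { transition }) => NewState::Waiting(transition),
            Ok(ContinuedValue::FinishedNoMessage { output_share }) => {
                match accumulate(output_share) {
                    Ok(()) => NewState::Finished,
                    Err(e) => NewState::Failed(e), // maps to accumulate_failure above
                }
            }
            Err(e) => NewState::Failed(e),
        }
    }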
+ if let Err(err) = accumulator.update( + aggregation_job.partial_batch_identifier(), + stepped_aggregation.report_aggregation.report_id(), + stepped_aggregation.report_aggregation.time(), + &output_share, + ) { + warn!( + report_id = %stepped_aggregation.report_aggregation.report_id(), + ?err, + "Could not update batch aggregation", + ); self.aggregate_step_failure_counter - .add(1, &[KeyValue::new("type", "prepare_message_failure")]); - ReportAggregationState::Failed(ReportShareError::VdafPrepError) + .add(1, &[KeyValue::new("type", "accumulate_failure")]); + return ReportAggregationState::Failed(PrepareError::VdafPrepError); } + + ReportAggregationState::Finished } - } else { - warn!(report_id = %report_aggregation.report_id(), "Helper continued but leader did not"); - self.aggregate_step_failure_counter - .add(1, &[KeyValue::new("type", "continue_mismatch")]); - ReportAggregationState::Failed(ReportShareError::VdafPrepError) + Err(prepare_error) => ReportAggregationState::Failed(prepare_error), } } PrepareStepResult::Finished => { - // If the leader finished too, we are done; prepare to store the output share. - // If the leader didn't finish too, we transition to INVALID. - if let PrepareTransition::Finish(out_share) = leader_transition { - match accumulator.update( + if let ping_pong::State::Finished(output_share) = + &stepped_aggregation.leader_state + { + // Helper finished and we had already finished. Commit the output share. + if let Err(err) = accumulator.update( aggregation_job.partial_batch_identifier(), - report_aggregation.report_id(), - report_aggregation.time(), - &out_share, + stepped_aggregation.report_aggregation.report_id(), + stepped_aggregation.report_aggregation.time(), + output_share, ) { - Ok(_) => ReportAggregationState::Finished, - Err(error) => { - warn!(report_id = %report_aggregation.report_id(), ?error, "Could not update batch aggregation"); - self.aggregate_step_failure_counter - .add(1, &[KeyValue::new("type", "accumulate_failure")]); - ReportAggregationState::Failed(ReportShareError::VdafPrepError) - } + warn!( + report_id = %stepped_aggregation.report_aggregation.report_id(), + ?err, + "Could not update batch aggregation", + ); + self.aggregate_step_failure_counter + .add(1, &[KeyValue::new("type", "accumulate_failure")]); + return ReportAggregationState::Failed(PrepareError::VdafPrepError); } + + ReportAggregationState::Finished } else { - warn!(report_id = %report_aggregation.report_id(), "Helper finished but leader did not"); + warn!( + report_id = %stepped_aggregation.report_aggregation.report_id(), + "Helper finished but Leader did not", + ); self.aggregate_step_failure_counter .add(1, &[KeyValue::new("type", "finish_mismatch")]); - ReportAggregationState::Failed(ReportShareError::VdafPrepError) + ReportAggregationState::Failed(PrepareError::VdafPrepError) } } - PrepareStepResult::Failed(err) => { + PrepareStepResult::Reject(err) => { // If the helper failed, we move to FAILED immediately. // TODO(#236): is it correct to just record the transition error that the helper reports? 
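The Finished branch is the mirror image: the helper claiming completion is only coherent if the leader's retained ping-pong state is already State::Finished(output_share); anything else is recorded as finish_mismatch. For Reject, the helper's PrepareError is adopted verbatim, which is exactly what the TODO above questions. A compact sketch of the Finished-side check (illustrative types):

    // Sketch of the Finished branch's coherence check.
    enum State<O> {
        Continued(Vec<u8>),
        Finished(O),
    }

    fn on_helper_finished<O: Clone>(
        leader_state: &State<O>,
        mut accumulate: impl FnMut(O) -> Result<(), &'static str>,
    ) -> Result<(), &'static str> {
        match leader_state {
            // Both aggregators are done; the output share may be committed.
            State::Finished(output_share) => accumulate(output_share.clone()),
            // The helper finished but the leader did not: finish_mismatch.
            State::Continued(_) => Err("helper finished but leader did not"),
        }
    }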
- info!(report_id = %report_aggregation.report_id(), helper_error = ?err, "Helper couldn't step report aggregation"); + info!( + report_id = %stepped_aggregation.report_aggregation.report_id(), + helper_error = ?err, + "Helper couldn't step report aggregation", + ); self.aggregate_step_failure_counter .add(1, &[KeyValue::new("type", "helper_step_failure")]); ReportAggregationState::Failed(*err) } - }; + })(); - report_aggregations_to_write.push(report_aggregation.with_state(new_state)); + report_aggregations_to_write.push( + stepped_aggregation + .report_aggregation + .clone() + .with_state(new_state), + ); } // Write everything back to storage. @@ -660,6 +688,7 @@ impl AggregationJobDriver { report_aggregations_to_write, )?; let aggregation_job_writer = Arc::new(aggregation_job_writer); + let accumulator = Arc::new(accumulator); datastore .run_tx_with_name("step_aggregation_job_2", |tx| { @@ -738,6 +767,7 @@ impl AggregationJobDriver { A::AggregateShare: Send + Sync, A::AggregationParam: Send + Sync + PartialEq + Eq, A::PrepareMessage: Send + Sync, + A::OutputShare: Send + Sync, for<'a> A::PrepareState: Send + Sync + Encode + ParameterizedDecode<(&'a A, usize)>, { let vdaf = Arc::new(vdaf); @@ -784,6 +814,7 @@ impl AggregationJobDriver { &Role::Leader, lease.leased().task_id(), lease.leased().aggregation_job_id(), + aggregation_job.aggregation_parameter(), ) .await?; @@ -860,7 +891,7 @@ impl AggregationJobDriver { /// transition representing the next step for the leader. struct SteppedAggregation> { report_aggregation: ReportAggregation, - leader_transition: PrepareTransition, + leader_state: ping_pong::State, } #[cfg(test)] @@ -899,19 +930,22 @@ mod tests { query_type::{FixedSize, TimeInterval}, AggregationJobContinueReq, AggregationJobInitializeReq, AggregationJobResp, AggregationJobRound, Duration, Extension, ExtensionType, HpkeConfig, InputShareAad, - Interval, PartialBatchSelector, PlaintextInputShare, PrepareStep, PrepareStepResult, - ReportIdChecksum, ReportMetadata, ReportShare, ReportShareError, Role, TaskId, Time, + Interval, PartialBatchSelector, PlaintextInputShare, PrepareContinue, PrepareError, + PrepareInit, PrepareResp, PrepareStepResult, ReportIdChecksum, ReportMetadata, ReportShare, + Role, TaskId, Time, }; use prio::{ codec::Encode, + idpf::IdpfInput, vdaf::{ self, + poplar1::{Poplar1, Poplar1AggregationParam}, + prg::PrgSha3, prio3::{Prio3, Prio3Count}, Aggregator, }, }; use rand::random; - use reqwest::Url; use std::{borrow::Borrow, str, sync::Arc, time::Duration as StdDuration}; use trillium_tokio::Stopper; @@ -930,13 +964,13 @@ mod tests { let mut runtime_manager = TestRuntimeManager::new(); let ephemeral_datastore = ephemeral_datastore().await; let ds = Arc::new(ephemeral_datastore.datastore(clock.clone()).await); - let vdaf = Arc::new(Prio3::new_count(2).unwrap()); + let vdaf = Arc::new(Poplar1::new_sha3(1)); let task = TaskBuilder::new( QueryType::TimeInterval, - VdafInstance::Prio3Count, + VdafInstance::Poplar1 { bits: 1 }, Role::Leader, ) - .with_helper_aggregator_endpoint(Url::parse(&server.url()).unwrap()) + .with_helper_aggregator_endpoint(server.url().parse().unwrap()) .build(); let time = clock @@ -946,31 +980,41 @@ mod tests { let batch_identifier = TimeInterval::to_batch_identifier(&task, &(), &time).unwrap(); let report_metadata = ReportMetadata::new(random(), time); let verify_key: VerifyKey = task.primary_vdaf_verify_key().unwrap(); + let measurement = IdpfInput::from_bools(&[true]); + let aggregation_param = + 
Poplar1AggregationParam::try_from_prefixes(Vec::from([IdpfInput::from_bools(&[true])])) + .unwrap(); let transcript = run_vdaf( vdaf.as_ref(), verify_key.as_bytes(), - &(), + &aggregation_param, report_metadata.id(), - &0, + &measurement, ); let agg_auth_token = task.primary_aggregator_auth_token().clone(); let helper_hpke_keypair = generate_test_hpke_config_and_private_key(); - let report = generate_report::( + let report = generate_report::>( *task.id(), report_metadata, helper_hpke_keypair.config(), transcript.public_share.clone(), Vec::new(), - transcript.input_shares.clone(), + &transcript.leader_input_share, + &transcript.helper_input_share, ); let aggregation_job_id = random(); let collection_job = ds .run_tx(|tx| { - let (vdaf, task, report) = (vdaf.clone(), task.clone(), report.clone()); + let (vdaf, task, report, aggregation_param) = ( + vdaf.clone(), + task.clone(), + report.clone(), + aggregation_param.clone(), + ); Box::pin(async move { tx.put_task(&task).await?; tx.put_client_report(vdaf.borrow(), &report).await?; @@ -980,11 +1024,11 @@ mod tests { tx.put_aggregation_job(&AggregationJob::< VERIFY_KEY_LENGTH, TimeInterval, - Prio3Count, + Poplar1, >::new( *task.id(), aggregation_job_id, - (), + aggregation_param.clone(), (), Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) .unwrap(), @@ -992,37 +1036,45 @@ mod tests { AggregationJobRound::from(0), )) .await?; - tx.put_report_aggregation( - &ReportAggregation::::new( - *task.id(), - aggregation_job_id, - *report.metadata().id(), - *report.metadata().time(), - 0, - None, - ReportAggregationState::Start, - ), - ) + tx.put_report_aggregation(&ReportAggregation::< + VERIFY_KEY_LENGTH, + Poplar1, + >::new( + *task.id(), + aggregation_job_id, + *report.metadata().id(), + *report.metadata().time(), + 0, + None, + ReportAggregationState::Start, + )) .await?; - tx.put_batch(&Batch::::new( + tx.put_batch(&Batch::< + VERIFY_KEY_LENGTH, + TimeInterval, + Poplar1, + >::new( *task.id(), batch_identifier, - (), + aggregation_param.clone(), BatchState::Closing, 1, Interval::from_time(&time).unwrap(), )) .await?; - let collection_job = - CollectionJob::::new( - *task.id(), - random(), - batch_identifier, - (), - CollectionJobState::Start, - ); + let collection_job = CollectionJob::< + VERIFY_KEY_LENGTH, + TimeInterval, + Poplar1, + >::new( + *task.id(), + random(), + batch_identifier, + aggregation_param.clone(), + CollectionJobState::Start, + ); tx.put_collection_job(&collection_job).await?; Ok(collection_job) @@ -1032,15 +1084,16 @@ mod tests { .unwrap(); // Setup: prepare mocked HTTP responses. 
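The test above switches from Prio3Count to Poplar1 deliberately: Poplar1 is the two-round VDAF behind libprio's "experimental" cargo feature, so it exercises the new Waiting(transition) and continue machinery, and its aggregation parameter, a set of candidate IDPF prefixes, is real data that the datastore, batches, and collection jobs must now round-trip instead of (). A minimal sketch of the setup, assuming the crate APIs as used in this diff:

    use prio::{
        idpf::IdpfInput,
        vdaf::poplar1::{Poplar1, Poplar1AggregationParam},
    };

    fn main() {
        // Poplar1 over 1-bit inputs, as in the two-round tests.
        let vdaf = Poplar1::new_sha3(1);
        let measurement = IdpfInput::from_bools(&[true]);
        // The aggregation parameter is the set of candidate prefixes to
        // evaluate; unlike Prio3's (), it must be persisted and round-tripped.
        let aggregation_param = Poplar1AggregationParam::try_from_prefixes(Vec::from([
            IdpfInput::from_bools(&[true]),
        ]))
        .unwrap();
        let _ = (vdaf, measurement, aggregation_param);
    }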
- let (_, helper_vdaf_msg) = transcript.helper_prep_state(0); let helper_responses = Vec::from([ ( "PUT", AggregationJobInitializeReq::::MEDIA_TYPE, AggregationJobResp::MEDIA_TYPE, - AggregationJobResp::new(Vec::from([PrepareStep::new( + AggregationJobResp::new(Vec::from([PrepareResp::new( *report.metadata().id(), - PrepareStepResult::Continued(helper_vdaf_msg.get_encoded()), + PrepareStepResult::Continue { + message: transcript.helper_prepare_transitions[0].message.clone(), + }, )])) .get_encoded(), ), @@ -1048,14 +1101,14 @@ mod tests { "POST", AggregationJobContinueReq::MEDIA_TYPE, AggregationJobResp::MEDIA_TYPE, - AggregationJobResp::new(Vec::from([PrepareStep::new( + AggregationJobResp::new(Vec::from([PrepareResp::new( *report.metadata().id(), PrepareStepResult::Finished, )])) .get_encoded(), ), ]); - let mocked_aggregates = join_all(helper_responses.into_iter().map( + let mocked_aggregates = join_all(helper_responses.iter().map( |(req_method, req_content_type, resp_content_type, resp_body)| { server .mock( @@ -1068,7 +1121,7 @@ mod tests { "DAP-Auth-Token", str::from_utf8(agg_auth_token.as_ref()).unwrap(), ) - .match_header(CONTENT_TYPE.as_str(), req_content_type) + .match_header(CONTENT_TYPE.as_str(), *req_content_type) .with_status(200) .with_header(CONTENT_TYPE.as_str(), resp_content_type) .with_body(resp_body) @@ -1121,29 +1174,30 @@ mod tests { } let want_aggregation_job = - AggregationJob::::new( + AggregationJob::>::new( *task.id(), aggregation_job_id, - (), + aggregation_param.clone(), (), Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) .unwrap(), AggregationJobState::Finished, AggregationJobRound::from(2), ); - let want_report_aggregation = ReportAggregation::::new( - *task.id(), - aggregation_job_id, - *report.metadata().id(), - *report.metadata().time(), - 0, - None, - ReportAggregationState::Finished, - ); - let want_batch = Batch::::new( + let want_report_aggregation = + ReportAggregation::>::new( + *task.id(), + aggregation_job_id, + *report.metadata().id(), + *report.metadata().time(), + 0, + None, + ReportAggregationState::Finished, + ); + let want_batch = Batch::>::new( *task.id(), batch_identifier, - (), + aggregation_param.clone(), BatchState::Closed, 0, Interval::from_time(&time).unwrap(), @@ -1159,7 +1213,7 @@ mod tests { Box::pin(async move { let aggregation_job = tx - .get_aggregation_job::( + .get_aggregation_job::>( task.id(), &aggregation_job_id, ) @@ -1171,12 +1225,17 @@ mod tests { &Role::Leader, task.id(), &aggregation_job_id, + aggregation_job.aggregation_parameter(), &report_id, ) .await? .unwrap(); let batch = tx - .get_batch(task.id(), &batch_identifier, &()) + .get_batch( + task.id(), + &batch_identifier, + aggregation_job.aggregation_parameter(), + ) .await? .unwrap(); let collection_job = tx @@ -1196,7 +1255,7 @@ mod tests { } #[tokio::test] - async fn step_time_interval_aggregation_job_init() { + async fn step_time_interval_aggregation_job_init_single_round() { // Setup: insert a client report and add it to a new aggregation job. 
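The rename to _single_round reflects the new reality: under ping-pong, Prio3Count completes in one HTTP round trip (the helper's Continue message lets the leader finish via continued), so the old init test becomes the single-round variant and new _two_rounds tests cover Poplar1. The mocks and assertions now pull messages straight from the reworked test transcript rather than hand-encoding prepare shares; roughly, the transcript exposes per-role, per-round transitions (the shape below is illustrative, with field names as used in these tests and concrete types elided):

    // Illustrative shape of the reworked test transcript.
    struct PrepareTransition<T, M> {
        // Present in rounds where this aggregator must wait on its peer.
        transition: Option<T>,
        // The ping-pong message this aggregator emits for the round.
        message: M,
    }

    struct Transcript<P, I, T, M, O> {
        public_share: P,
        leader_input_share: I,
        helper_input_share: I,
        leader_prepare_transitions: Vec<PrepareTransition<T, M>>,
        helper_prepare_transitions: Vec<PrepareTransition<T, M>>,
        leader_output_share: O,
    }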
install_test_trace_subscriber(); let mut server = mockito::Server::new_async().await; @@ -1210,7 +1269,7 @@ mod tests { VdafInstance::Prio3Count, Role::Leader, ) - .with_helper_aggregator_endpoint(Url::parse(&server.url()).unwrap()) + .with_helper_aggregator_endpoint(server.url().parse().unwrap()) .build(); let time = clock @@ -1237,7 +1296,8 @@ mod tests { helper_hpke_keypair.config(), transcript.public_share.clone(), Vec::new(), - transcript.input_shares.clone(), + &transcript.leader_input_share, + &transcript.helper_input_share, ); let repeated_extension_report = generate_report::( *task.id(), @@ -1248,7 +1308,8 @@ mod tests { Extension::new(ExtensionType::Tbd, Vec::new()), Extension::new(ExtensionType::Tbd, Vec::new()), ]), - transcript.input_shares.clone(), + &transcript.leader_input_share, + &transcript.helper_input_share, ); let missing_report_id = random(); let aggregation_job_id = random(); @@ -1347,16 +1408,20 @@ mod tests { let leader_request = AggregationJobInitializeReq::new( ().get_encoded(), PartialBatchSelector::new_time_interval(), - Vec::from([ReportShare::new( - report.metadata().clone(), - report.public_share().get_encoded(), - report.helper_encrypted_input_share().clone(), + Vec::from([PrepareInit::new( + ReportShare::new( + report.metadata().clone(), + report.public_share().get_encoded(), + report.helper_encrypted_input_share().clone(), + ), + transcript.leader_prepare_transitions[0].message.clone(), )]), ); - let (_, helper_vdaf_msg) = transcript.helper_prep_state(0); - let helper_response = AggregationJobResp::new(Vec::from([PrepareStep::new( + let helper_response = AggregationJobResp::new(Vec::from([PrepareResp::new( *report.metadata().id(), - PrepareStepResult::Continued(helper_vdaf_msg.get_encoded()), + PrepareStepResult::Continue { + message: transcript.helper_prepare_transitions[0].message.clone(), + }, )])); let mocked_aggregate_failure = server .mock( @@ -1426,11 +1491,9 @@ mod tests { (), Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) .unwrap(), - AggregationJobState::InProgress, + AggregationJobState::Finished, AggregationJobRound::from(1), ); - let leader_prep_state = transcript.leader_prep_state(0).clone(); - let prep_msg = transcript.prepare_messages[0].clone(); let want_report_aggregation = ReportAggregation::::new( *task.id(), aggregation_job_id, @@ -1438,7 +1501,7 @@ mod tests { *report.metadata().time(), 0, None, - ReportAggregationState::Waiting(leader_prep_state, Some(prep_msg)), + ReportAggregationState::Finished, ); let want_repeated_extension_report_aggregation = ReportAggregation::::new( @@ -1448,7 +1511,7 @@ mod tests { *repeated_extension_report.metadata().time(), 1, None, - ReportAggregationState::Failed(ReportShareError::UnrecognizedMessage), + ReportAggregationState::Failed(PrepareError::UnrecognizedMessage), ); let want_missing_report_report_aggregation = ReportAggregation::::new( @@ -1458,14 +1521,14 @@ mod tests { time, 2, None, - ReportAggregationState::Failed(ReportShareError::ReportDropped), + ReportAggregationState::Failed(PrepareError::ReportDropped), ); let want_batch = Batch::::new( *task.id(), batch_identifier, (), BatchState::Closing, - 1, + 0, Interval::from_time(&time).unwrap(), ); @@ -1497,6 +1560,7 @@ mod tests { &Role::Leader, task.id(), &aggregation_job_id, + aggregation_job.aggregation_parameter(), &report_id, ) .await? @@ -1507,6 +1571,7 @@ mod tests { &Role::Leader, task.id(), &aggregation_job_id, + aggregation_job.aggregation_parameter(), &repeated_extension_report_id, ) .await? 
@@ -1517,6 +1582,7 @@ mod tests { &Role::Leader, task.id(), &aggregation_job_id, + aggregation_job.aggregation_parameter(), &missing_report_id, ) .await? @@ -1551,98 +1617,109 @@ mod tests { } #[tokio::test] - async fn step_fixed_size_aggregation_job_init() { + async fn step_time_interval_aggregation_job_init_two_rounds() { // Setup: insert a client report and add it to a new aggregation job. install_test_trace_subscriber(); let mut server = mockito::Server::new_async().await; let clock = MockClock::default(); let ephemeral_datastore = ephemeral_datastore().await; let ds = Arc::new(ephemeral_datastore.datastore(clock.clone()).await); - let vdaf = Arc::new(Prio3::new_count(2).unwrap()); + let vdaf = Arc::new(Poplar1::new_sha3(1)); let task = TaskBuilder::new( - QueryType::FixedSize { - max_batch_size: 10, - batch_time_window_size: None, - }, - VdafInstance::Prio3Count, + QueryType::TimeInterval, + VdafInstance::Poplar1 { bits: 1 }, Role::Leader, ) - .with_helper_aggregator_endpoint(Url::parse(&server.url()).unwrap()) + .with_helper_aggregator_endpoint(server.url().parse().unwrap()) .build(); - let report_metadata = ReportMetadata::new( - random(), - clock - .now() - .to_batch_interval_start(task.time_precision()) - .unwrap(), - ); + let time = clock + .now() + .to_batch_interval_start(task.time_precision()) + .unwrap(); + let batch_identifier = TimeInterval::to_batch_identifier(&task, &(), &time).unwrap(); + let report_metadata = ReportMetadata::new(random(), time); let verify_key: VerifyKey = task.primary_vdaf_verify_key().unwrap(); + let measurement = IdpfInput::from_bools(&[true]); + let aggregation_param = + Poplar1AggregationParam::try_from_prefixes(Vec::from([IdpfInput::from_bools(&[true])])) + .unwrap(); let transcript = run_vdaf( vdaf.as_ref(), verify_key.as_bytes(), - &(), + &aggregation_param, report_metadata.id(), - &0, + &measurement, ); let agg_auth_token = task.primary_aggregator_auth_token(); let helper_hpke_keypair = generate_test_hpke_config_and_private_key(); - let report = generate_report::( + let report = generate_report::>( *task.id(), report_metadata, helper_hpke_keypair.config(), transcript.public_share.clone(), Vec::new(), - transcript.input_shares.clone(), + &transcript.leader_input_share, + &transcript.helper_input_share, ); - let batch_id = random(); let aggregation_job_id = random(); let lease = ds .run_tx(|tx| { - let (vdaf, task, report) = (vdaf.clone(), task.clone(), report.clone()); + let (vdaf, task, report, aggregation_param) = ( + vdaf.clone(), + task.clone(), + report.clone(), + aggregation_param.clone(), + ); Box::pin(async move { tx.put_task(&task).await?; tx.put_client_report(vdaf.borrow(), &report).await?; tx.put_aggregation_job(&AggregationJob::< VERIFY_KEY_LENGTH, - FixedSize, - Prio3Count, + TimeInterval, + Poplar1, >::new( *task.id(), aggregation_job_id, + aggregation_param.clone(), (), - batch_id, Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) .unwrap(), AggregationJobState::InProgress, AggregationJobRound::from(0), )) .await?; - tx.put_report_aggregation( - &ReportAggregation::::new( - *task.id(), - aggregation_job_id, - *report.metadata().id(), - *report.metadata().time(), - 0, - None, - ReportAggregationState::Start, - ), - ) + + tx.put_report_aggregation(&ReportAggregation::< + VERIFY_KEY_LENGTH, + Poplar1, + >::new( + *task.id(), + aggregation_job_id, + *report.metadata().id(), + *report.metadata().time(), + 0, + None, + ReportAggregationState::Start, + )) .await?; - tx.put_batch(&Batch::::new( + 
tx.put_batch(&Batch::< + VERIFY_KEY_LENGTH, + TimeInterval, + Poplar1, + >::new( *task.id(), - batch_id, - (), - BatchState::Open, + batch_identifier, + aggregation_param, + BatchState::Closing, 1, - Interval::from_time(report.metadata().time()).unwrap(), + Interval::from_time(&time).unwrap(), )) .await?; @@ -1662,31 +1739,23 @@ mod tests { // It would be nicer to retrieve the request bytes from the mock, then do our own parsing & // verification -- but mockito does not expose this functionality at time of writing.) let leader_request = AggregationJobInitializeReq::new( - ().get_encoded(), - PartialBatchSelector::new_fixed_size(batch_id), - Vec::from([ReportShare::new( - report.metadata().clone(), - report.public_share().get_encoded(), - report.helper_encrypted_input_share().clone(), + aggregation_param.get_encoded(), + PartialBatchSelector::new_time_interval(), + Vec::from([PrepareInit::new( + ReportShare::new( + report.metadata().clone(), + report.public_share().get_encoded(), + report.helper_encrypted_input_share().clone(), + ), + transcript.leader_prepare_transitions[0].message.clone(), )]), ); - let (_, helper_vdaf_msg) = transcript.helper_prep_state(0); - let helper_response = AggregationJobResp::new(Vec::from([PrepareStep::new( + let helper_response = AggregationJobResp::new(Vec::from([PrepareResp::new( *report.metadata().id(), - PrepareStepResult::Continued(helper_vdaf_msg.get_encoded()), + PrepareStepResult::Continue { + message: transcript.helper_prepare_transitions[0].message.clone(), + }, )])); - let mocked_aggregate_failure = server - .mock( - "PUT", - task.aggregation_job_uri(&aggregation_job_id) - .unwrap() - .path(), - ) - .with_status(500) - .with_header("Content-Type", "application/problem+json") - .with_body("{\"type\": \"urn:ietf:params:ppm:dap:error:unauthorizedRequest\"}") - .create_async() - .await; let mocked_aggregate_success = server .mock( "PUT", @@ -1700,7 +1769,7 @@ mod tests { ) .match_header( CONTENT_TYPE.as_str(), - AggregationJobInitializeReq::::MEDIA_TYPE, + AggregationJobInitializeReq::::MEDIA_TYPE, ) .match_body(leader_request.get_encoded()) .with_status(200) @@ -1715,51 +1784,553 @@ mod tests { &noop_meter(), 32, ); - let error = aggregation_job_driver - .step_aggregation_job(ds.clone(), Arc::new(lease.clone())) - .await - .unwrap_err(); - assert_matches!( - error.downcast().unwrap(), - Error::Http { problem_details, dap_problem_type } => { - assert_eq!(problem_details.status.unwrap(), StatusCode::INTERNAL_SERVER_ERROR); - assert_eq!(dap_problem_type, Some(DapProblemType::UnauthorizedRequest)); - } - ); aggregation_job_driver .step_aggregation_job(ds.clone(), Arc::new(lease)) .await .unwrap(); // Verify. 
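What the verification below pins down: after the single init round trip, a two-round job is still InProgress at round 1, the report aggregation is Waiting on leader_prepare_transitions[1].transition, and the batch's outstanding-aggregation count is unchanged; a separate continue request is needed to finish. Read off these tests, the request count follows a simple rule (a back-of-the-envelope helper, not an API):

    // Ping-pong exchanges rounds + 1 prepare messages, and each leader request
    // carries one of them, with the helper's reply carrying the next.
    fn expected_leader_requests(vdaf_rounds: u32) -> u32 {
        (vdaf_rounds + 2) / 2 // == ceil((vdaf_rounds + 1) / 2)
    }

    fn main() {
        assert_eq!(expected_leader_requests(1), 1); // Prio3Count: init only
        assert_eq!(expected_leader_requests(2), 2); // Poplar1: init + continue
    }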
- mocked_aggregate_failure.assert_async().await; mocked_aggregate_success.assert_async().await; - let want_aggregation_job = AggregationJob::::new( - *task.id(), - aggregation_job_id, - (), - batch_id, - Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)).unwrap(), - AggregationJobState::InProgress, - AggregationJobRound::from(1), - ); - let want_report_aggregation = ReportAggregation::::new( + let want_aggregation_job = + AggregationJob::>::new( + *task.id(), + aggregation_job_id, + aggregation_param.clone(), + (), + Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) + .unwrap(), + AggregationJobState::InProgress, + AggregationJobRound::from(1), + ); + let want_report_aggregation = + ReportAggregation::>::new( + *task.id(), + aggregation_job_id, + *report.metadata().id(), + *report.metadata().time(), + 0, + None, + ReportAggregationState::Waiting( + transcript.leader_prepare_transitions[1] + .transition + .clone() + .unwrap(), + ), + ); + let want_batch = Batch::>::new( + *task.id(), + batch_identifier, + aggregation_param, + BatchState::Closing, + 1, + Interval::from_time(&time).unwrap(), + ); + + let (got_aggregation_job, got_report_aggregation, got_batch) = ds + .run_tx(|tx| { + let (vdaf, task, report_id) = + (Arc::clone(&vdaf), task.clone(), *report.metadata().id()); + Box::pin(async move { + let aggregation_job = tx + .get_aggregation_job::>( + task.id(), + &aggregation_job_id, + ) + .await? + .unwrap(); + let report_aggregation = tx + .get_report_aggregation( + vdaf.as_ref(), + &Role::Leader, + task.id(), + &aggregation_job_id, + aggregation_job.aggregation_parameter(), + &report_id, + ) + .await? + .unwrap(); + let batch = tx + .get_batch( + task.id(), + &batch_identifier, + aggregation_job.aggregation_parameter(), + ) + .await? + .unwrap(); + Ok((aggregation_job, report_aggregation, batch)) + }) + }) + .await + .unwrap(); + + assert_eq!(want_aggregation_job, got_aggregation_job); + assert_eq!(want_report_aggregation, got_report_aggregation); + assert_eq!(want_batch, got_batch); + } + + #[tokio::test] + async fn step_fixed_size_aggregation_job_init_single_round() { + // Setup: insert a client report and add it to a new aggregation job. 
+ install_test_trace_subscriber(); + let mut server = mockito::Server::new_async().await; + let clock = MockClock::default(); + let ephemeral_datastore = ephemeral_datastore().await; + let ds = Arc::new(ephemeral_datastore.datastore(clock.clone()).await); + let vdaf = Arc::new(Prio3::new_count(2).unwrap()); + + let task = TaskBuilder::new( + QueryType::FixedSize { + max_batch_size: 10, + batch_time_window_size: None, + }, + VdafInstance::Prio3Count, + Role::Leader, + ) + .with_helper_aggregator_endpoint(server.url().parse().unwrap()) + .build(); + + let report_metadata = ReportMetadata::new( + random(), + clock + .now() + .to_batch_interval_start(task.time_precision()) + .unwrap(), + ); + let verify_key: VerifyKey = task.primary_vdaf_verify_key().unwrap(); + + let transcript = run_vdaf( + vdaf.as_ref(), + verify_key.as_bytes(), + &(), + report_metadata.id(), + &0, + ); + + let agg_auth_token = task.primary_aggregator_auth_token(); + let helper_hpke_keypair = generate_test_hpke_config_and_private_key(); + let report = generate_report::( + *task.id(), + report_metadata, + helper_hpke_keypair.config(), + transcript.public_share.clone(), + Vec::new(), + &transcript.leader_input_share, + &transcript.helper_input_share, + ); + let batch_id = random(); + let aggregation_job_id = random(); + + let lease = ds + .run_tx(|tx| { + let (vdaf, task, report) = (vdaf.clone(), task.clone(), report.clone()); + Box::pin(async move { + tx.put_task(&task).await?; + tx.put_client_report(vdaf.borrow(), &report).await?; + + tx.put_aggregation_job(&AggregationJob::< + VERIFY_KEY_LENGTH, + FixedSize, + Prio3Count, + >::new( + *task.id(), + aggregation_job_id, + (), + batch_id, + Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) + .unwrap(), + AggregationJobState::InProgress, + AggregationJobRound::from(0), + )) + .await?; + + tx.put_report_aggregation( + &ReportAggregation::::new( + *task.id(), + aggregation_job_id, + *report.metadata().id(), + *report.metadata().time(), + 0, + None, + ReportAggregationState::Start, + ), + ) + .await?; + + tx.put_batch(&Batch::::new( + *task.id(), + batch_id, + (), + BatchState::Open, + 1, + Interval::from_time(report.metadata().time()).unwrap(), + )) + .await?; + + Ok(tx + .acquire_incomplete_aggregation_jobs(&StdDuration::from_secs(60), 1) + .await? + .remove(0)) + }) + }) + .await + .unwrap(); + assert_eq!(lease.leased().task_id(), task.id()); + assert_eq!(lease.leased().aggregation_job_id(), &aggregation_job_id); + + // Setup: prepare mocked HTTP response. (first an error response, then a success) + // (This is fragile in that it expects the leader request to be deterministically encoded. + // It would be nicer to retrieve the request bytes from the mock, then do our own parsing & + // verification -- but mockito does not expose this functionality at time of writing.) 
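This test also keeps the failure-then-success mock sequence: the first PUT returns HTTP 500 with an application/problem+json body, which step_aggregation_job surfaces as Error::Http carrying DapProblemType::UnauthorizedRequest, and only the retry against the second mock succeeds. A minimal sketch of pulling the DAP problem type out of such a body (a hypothetical helper, not Janus's actual error plumbing):

    use serde::Deserialize;

    // The DAP problem type is carried in the "type" member of the
    // application/problem+json document that the failure mock returns.
    #[derive(Deserialize)]
    struct ProblemDocument {
        #[serde(rename = "type")]
        type_uri: String,
    }

    fn main() -> Result<(), serde_json::Error> {
        let body = r#"{"type": "urn:ietf:params:ppm:dap:error:unauthorizedRequest"}"#;
        let problem: ProblemDocument = serde_json::from_str(body)?;
        assert_eq!(
            problem.type_uri,
            "urn:ietf:params:ppm:dap:error:unauthorizedRequest"
        );
        Ok(())
    }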
+ let leader_request = AggregationJobInitializeReq::new( + ().get_encoded(), + PartialBatchSelector::new_fixed_size(batch_id), + Vec::from([PrepareInit::new( + ReportShare::new( + report.metadata().clone(), + report.public_share().get_encoded(), + report.helper_encrypted_input_share().clone(), + ), + transcript.leader_prepare_transitions[0].message.clone(), + )]), + ); + let helper_response = AggregationJobResp::new(Vec::from([PrepareResp::new( + *report.metadata().id(), + PrepareStepResult::Continue { + message: transcript.helper_prepare_transitions[0].message.clone(), + }, + )])); + let mocked_aggregate_failure = server + .mock( + "PUT", + task.aggregation_job_uri(&aggregation_job_id) + .unwrap() + .path(), + ) + .with_status(500) + .with_header("Content-Type", "application/problem+json") + .with_body("{\"type\": \"urn:ietf:params:ppm:dap:error:unauthorizedRequest\"}") + .create_async() + .await; + let mocked_aggregate_success = server + .mock( + "PUT", + task.aggregation_job_uri(&aggregation_job_id) + .unwrap() + .path(), + ) + .match_header( + "DAP-Auth-Token", + str::from_utf8(agg_auth_token.as_ref()).unwrap(), + ) + .match_header( + CONTENT_TYPE.as_str(), + AggregationJobInitializeReq::::MEDIA_TYPE, + ) + .match_body(leader_request.get_encoded()) + .with_status(200) + .with_header(CONTENT_TYPE.as_str(), AggregationJobResp::MEDIA_TYPE) + .with_body(helper_response.get_encoded()) + .create_async() + .await; + + // Run: create an aggregation job driver & try to step the aggregation we've created twice. + let aggregation_job_driver = AggregationJobDriver::new( + reqwest::Client::builder().build().unwrap(), + &noop_meter(), + 32, + ); + let error = aggregation_job_driver + .step_aggregation_job(ds.clone(), Arc::new(lease.clone())) + .await + .unwrap_err(); + assert_matches!( + error.downcast().unwrap(), + Error::Http { problem_details, dap_problem_type } => { + assert_eq!(problem_details.status.unwrap(), StatusCode::INTERNAL_SERVER_ERROR); + assert_eq!(dap_problem_type, Some(DapProblemType::UnauthorizedRequest)); + } + ); + aggregation_job_driver + .step_aggregation_job(ds.clone(), Arc::new(lease)) + .await + .unwrap(); + + // Verify. + mocked_aggregate_failure.assert_async().await; + mocked_aggregate_success.assert_async().await; + + let want_aggregation_job = AggregationJob::::new( + *task.id(), + aggregation_job_id, + (), + batch_id, + Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)).unwrap(), + AggregationJobState::Finished, + AggregationJobRound::from(1), + ); + let want_report_aggregation = ReportAggregation::::new( *task.id(), aggregation_job_id, *report.metadata().id(), - *report.metadata().time(), - 0, - None, - ReportAggregationState::Waiting( - transcript.leader_prep_state(0).clone(), - Some(transcript.prepare_messages[0].clone()), - ), + *report.metadata().time(), + 0, + None, + ReportAggregationState::Finished, + ); + let want_batch = Batch::::new( + *task.id(), + batch_id, + (), + BatchState::Open, + 0, + Interval::from_time(report.metadata().time()).unwrap(), + ); + + let (got_aggregation_job, got_report_aggregation, got_batch) = ds + .run_tx(|tx| { + let (vdaf, task, report_id) = + (Arc::clone(&vdaf), task.clone(), *report.metadata().id()); + Box::pin(async move { + let aggregation_job = tx + .get_aggregation_job::( + task.id(), + &aggregation_job_id, + ) + .await? 
+ .unwrap(); + let report_aggregation = tx + .get_report_aggregation( + vdaf.as_ref(), + &Role::Leader, + task.id(), + &aggregation_job_id, + aggregation_job.aggregation_parameter(), + &report_id, + ) + .await? + .unwrap(); + let batch = tx.get_batch(task.id(), &batch_id, &()).await?.unwrap(); + Ok((aggregation_job, report_aggregation, batch)) + }) + }) + .await + .unwrap(); + + assert_eq!(want_aggregation_job, got_aggregation_job); + assert_eq!(want_report_aggregation, got_report_aggregation); + assert_eq!(want_batch, got_batch); + } + + #[tokio::test] + async fn step_fixed_size_aggregation_job_init_two_rounds() { + // Setup: insert a client report and add it to a new aggregation job. + install_test_trace_subscriber(); + let mut server = mockito::Server::new_async().await; + let clock = MockClock::default(); + let ephemeral_datastore = ephemeral_datastore().await; + let ds = Arc::new(ephemeral_datastore.datastore(clock.clone()).await); + let vdaf = Arc::new(Poplar1::new_sha3(1)); + + let task = TaskBuilder::new( + QueryType::FixedSize { + max_batch_size: 10, + batch_time_window_size: None, + }, + VdafInstance::Poplar1 { bits: 1 }, + Role::Leader, + ) + .with_helper_aggregator_endpoint(server.url().parse().unwrap()) + .build(); + + let report_metadata = ReportMetadata::new( + random(), + clock + .now() + .to_batch_interval_start(task.time_precision()) + .unwrap(), + ); + let verify_key: VerifyKey = task.primary_vdaf_verify_key().unwrap(); + let measurement = IdpfInput::from_bools(&[true]); + let aggregation_param = + Poplar1AggregationParam::try_from_prefixes(Vec::from([IdpfInput::from_bools(&[true])])) + .unwrap(); + + let transcript = run_vdaf( + vdaf.as_ref(), + verify_key.as_bytes(), + &aggregation_param, + report_metadata.id(), + &measurement, + ); + + let agg_auth_token = task.primary_aggregator_auth_token(); + let helper_hpke_keypair = generate_test_hpke_config_and_private_key(); + let report = generate_report::>( + *task.id(), + report_metadata, + helper_hpke_keypair.config(), + transcript.public_share.clone(), + Vec::new(), + &transcript.leader_input_share, + &transcript.helper_input_share, + ); + let batch_id = random(); + let aggregation_job_id = random(); + + let lease = ds + .run_tx(|tx| { + let (vdaf, task, report, aggregation_param) = ( + vdaf.clone(), + task.clone(), + report.clone(), + aggregation_param.clone(), + ); + Box::pin(async move { + tx.put_task(&task).await?; + tx.put_client_report(vdaf.borrow(), &report).await?; + + tx.put_aggregation_job(&AggregationJob::< + VERIFY_KEY_LENGTH, + FixedSize, + Poplar1, + >::new( + *task.id(), + aggregation_job_id, + aggregation_param.clone(), + batch_id, + Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) + .unwrap(), + AggregationJobState::InProgress, + AggregationJobRound::from(0), + )) + .await?; + + tx.put_report_aggregation(&ReportAggregation::< + VERIFY_KEY_LENGTH, + Poplar1, + >::new( + *task.id(), + aggregation_job_id, + *report.metadata().id(), + *report.metadata().time(), + 0, + None, + ReportAggregationState::Start, + )) + .await?; + + tx.put_batch( + &Batch::>::new( + *task.id(), + batch_id, + aggregation_param.clone(), + BatchState::Open, + 1, + Interval::from_time(report.metadata().time()).unwrap(), + ), + ) + .await?; + + Ok(tx + .acquire_incomplete_aggregation_jobs(&StdDuration::from_secs(60), 1) + .await? 
+ .remove(0)) + }) + }) + .await + .unwrap(); + assert_eq!(lease.leased().task_id(), task.id()); + assert_eq!(lease.leased().aggregation_job_id(), &aggregation_job_id); + + // Setup: prepare mocked HTTP response. (first an error response, then a success) + // (This is fragile in that it expects the leader request to be deterministically encoded. + // It would be nicer to retrieve the request bytes from the mock, then do our own parsing & + // verification -- but mockito does not expose this functionality at time of writing.) + let leader_request = AggregationJobInitializeReq::new( + aggregation_param.get_encoded(), + PartialBatchSelector::new_fixed_size(batch_id), + Vec::from([PrepareInit::new( + ReportShare::new( + report.metadata().clone(), + report.public_share().get_encoded(), + report.helper_encrypted_input_share().clone(), + ), + transcript.leader_prepare_transitions[0].message.clone(), + )]), + ); + let helper_response = AggregationJobResp::new(Vec::from([PrepareResp::new( + *report.metadata().id(), + PrepareStepResult::Continue { + message: transcript.helper_prepare_transitions[0].message.clone(), + }, + )])); + let mocked_aggregate_success = server + .mock( + "PUT", + task.aggregation_job_uri(&aggregation_job_id) + .unwrap() + .path(), + ) + .match_header( + "DAP-Auth-Token", + str::from_utf8(agg_auth_token.as_ref()).unwrap(), + ) + .match_header( + CONTENT_TYPE.as_str(), + AggregationJobInitializeReq::::MEDIA_TYPE, + ) + .match_body(leader_request.get_encoded()) + .with_status(200) + .with_header(CONTENT_TYPE.as_str(), AggregationJobResp::MEDIA_TYPE) + .with_body(helper_response.get_encoded()) + .create_async() + .await; + + // Run: create an aggregation job driver & try to step the aggregation we've created twice. + let aggregation_job_driver = AggregationJobDriver::new( + reqwest::Client::builder().build().unwrap(), + &noop_meter(), + 32, ); - let want_batch = Batch::::new( + aggregation_job_driver + .step_aggregation_job(ds.clone(), Arc::new(lease)) + .await + .unwrap(); + + // Verify. + mocked_aggregate_success.assert_async().await; + + let want_aggregation_job = + AggregationJob::>::new( + *task.id(), + aggregation_job_id, + aggregation_param.clone(), + batch_id, + Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) + .unwrap(), + AggregationJobState::InProgress, + AggregationJobRound::from(1), + ); + let want_report_aggregation = + ReportAggregation::>::new( + *task.id(), + aggregation_job_id, + *report.metadata().id(), + *report.metadata().time(), + 0, + None, + ReportAggregationState::Waiting( + transcript.leader_prepare_transitions[1] + .transition + .clone() + .unwrap(), + ), + ); + let want_batch = Batch::>::new( *task.id(), batch_id, - (), + aggregation_param.clone(), BatchState::Open, 1, Interval::from_time(report.metadata().time()).unwrap(), @@ -1771,7 +2342,7 @@ mod tests { (Arc::clone(&vdaf), task.clone(), *report.metadata().id()); Box::pin(async move { let aggregation_job = tx - .get_aggregation_job::( + .get_aggregation_job::>( task.id(), &aggregation_job_id, ) @@ -1783,11 +2354,19 @@ mod tests { &Role::Leader, task.id(), &aggregation_job_id, + aggregation_job.aggregation_parameter(), &report_id, ) .await? .unwrap(); - let batch = tx.get_batch(task.id(), &batch_id, &()).await?.unwrap(); + let batch = tx + .get_batch( + task.id(), + &batch_id, + aggregation_job.aggregation_parameter(), + ) + .await? 
+ .unwrap(); Ok((aggregation_job, report_aggregation, batch)) }) }) @@ -1808,14 +2387,14 @@ mod tests { let clock = MockClock::default(); let ephemeral_datastore = ephemeral_datastore().await; let ds = Arc::new(ephemeral_datastore.datastore(clock.clone()).await); - let vdaf = Arc::new(Prio3::new_count(2).unwrap()); + let vdaf = Arc::new(Poplar1::new_sha3(1)); let task = TaskBuilder::new( QueryType::TimeInterval, - VdafInstance::Prio3Count, + VdafInstance::Poplar1 { bits: 1 }, Role::Leader, ) - .with_helper_aggregator_endpoint(Url::parse(&server.url()).unwrap()) + .with_helper_aggregator_endpoint(server.url().parse().unwrap()) .build(); let time = clock .now() @@ -1838,40 +2417,43 @@ mod tests { let report_metadata = ReportMetadata::new(random(), time); let verify_key: VerifyKey = task.primary_vdaf_verify_key().unwrap(); + let aggregation_param = Poplar1AggregationParam::try_from_prefixes(Vec::from([ + IdpfInput::from_bools(&[false]), + ])) + .unwrap(); let transcript = run_vdaf( vdaf.as_ref(), verify_key.as_bytes(), - &(), + &aggregation_param, report_metadata.id(), - &0, + &IdpfInput::from_bools(&[true]), ); let agg_auth_token = task.primary_aggregator_auth_token(); let helper_hpke_keypair = generate_test_hpke_config_and_private_key(); - let report = generate_report::( + let report = generate_report::>( *task.id(), report_metadata, helper_hpke_keypair.config(), transcript.public_share.clone(), Vec::new(), - transcript.input_shares.clone(), + &transcript.leader_input_share, + &transcript.helper_input_share, ); let aggregation_job_id = random(); - let leader_prep_state = transcript.leader_prep_state(0); let leader_aggregate_share = vdaf - .aggregate(&(), [transcript.output_share(Role::Leader).clone()]) + .aggregate(&aggregation_param, [transcript.leader_output_share.clone()]) .unwrap(); - let prep_msg = &transcript.prepare_messages[0]; let (lease, want_collection_job) = ds .run_tx(|tx| { - let (vdaf, task, report, leader_prep_state, prep_msg) = ( + let (vdaf, task, aggregation_param, report, transcript) = ( vdaf.clone(), task.clone(), + aggregation_param.clone(), report.clone(), - leader_prep_state.clone(), - prep_msg.clone(), + transcript.clone(), ); Box::pin(async move { tx.put_task(&task).await?; @@ -1882,11 +2464,11 @@ mod tests { tx.put_aggregation_job(&AggregationJob::< VERIFY_KEY_LENGTH, TimeInterval, - Prio3Count, + Poplar1, >::new( *task.id(), aggregation_job_id, - (), + aggregation_param.clone(), (), Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) .unwrap(), @@ -1894,46 +2476,64 @@ mod tests { AggregationJobRound::from(1), )) .await?; - tx.put_report_aggregation( - &ReportAggregation::::new( - *task.id(), - aggregation_job_id, - *report.metadata().id(), - *report.metadata().time(), - 0, - None, - ReportAggregationState::Waiting(leader_prep_state, Some(prep_msg)), + + tx.put_report_aggregation(&ReportAggregation::< + VERIFY_KEY_LENGTH, + Poplar1, + >::new( + *task.id(), + aggregation_job_id, + *report.metadata().id(), + *report.metadata().time(), + 0, + None, + ReportAggregationState::Waiting( + transcript.leader_prepare_transitions[1] + .transition + .clone() + .unwrap(), ), - ) + )) .await?; - tx.put_batch(&Batch::::new( + tx.put_batch(&Batch::< + VERIFY_KEY_LENGTH, + TimeInterval, + Poplar1, + >::new( *task.id(), active_batch_identifier, - (), + aggregation_param.clone(), BatchState::Closing, 1, Interval::from_time(report.metadata().time()).unwrap(), )) .await?; - tx.put_batch(&Batch::::new( + tx.put_batch(&Batch::< + VERIFY_KEY_LENGTH, + 
TimeInterval, + Poplar1, + >::new( *task.id(), other_batch_identifier, - (), + aggregation_param.clone(), BatchState::Closing, 1, Interval::EMPTY, )) .await?; - let collection_job = - CollectionJob::::new( - *task.id(), - random(), - collection_identifier, - (), - CollectionJobState::Start, - ); + let collection_job = CollectionJob::< + VERIFY_KEY_LENGTH, + TimeInterval, + Poplar1, + >::new( + *task.id(), + random(), + collection_identifier, + aggregation_param, + CollectionJobState::Start, + ); tx.put_collection_job(&collection_job).await?; let lease = tx @@ -1955,12 +2555,12 @@ mod tests { // verification -- but mockito does not expose this functionality at time of writing.) let leader_request = AggregationJobContinueReq::new( AggregationJobRound::from(1), - Vec::from([PrepareStep::new( + Vec::from([PrepareContinue::new( *report.metadata().id(), - PrepareStepResult::Continued(prep_msg.get_encoded()), + transcript.leader_prepare_transitions[1].message.clone(), )]), ); - let helper_response = AggregationJobResp::new(Vec::from([PrepareStep::new( + let helper_response = AggregationJobResp::new(Vec::from([PrepareResp::new( *report.metadata().id(), PrepareStepResult::Finished, )])); @@ -2022,25 +2622,27 @@ mod tests { mocked_aggregate_success.assert_async().await; let want_aggregation_job = - AggregationJob::::new( + AggregationJob::>::new( *task.id(), aggregation_job_id, - (), + aggregation_param.clone(), (), Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) .unwrap(), AggregationJobState::Finished, AggregationJobRound::from(2), ); - let want_report_aggregation = ReportAggregation::::new( - *task.id(), - aggregation_job_id, - *report.metadata().id(), - *report.metadata().time(), - 0, - None, - ReportAggregationState::Finished, - ); + let want_report_aggregation = + ReportAggregation::>::new( + *task.id(), + aggregation_job_id, + *report.metadata().id(), + *report.metadata().time(), + 0, + None, + ReportAggregationState::Finished, + ); + let batch_interval_start = report .metadata() .time() @@ -2049,11 +2651,11 @@ mod tests { let want_batch_aggregations = Vec::from([BatchAggregation::< VERIFY_KEY_LENGTH, TimeInterval, - Prio3Count, + Poplar1, >::new( *task.id(), Interval::new(batch_interval_start, *task.time_precision()).unwrap(), - (), + aggregation_param.clone(), 0, BatchAggregationState::Aggregating, Some(leader_aggregate_share), @@ -2061,18 +2663,18 @@ mod tests { Interval::from_time(report.metadata().time()).unwrap(), ReportIdChecksum::for_report_id(report.metadata().id()), )]); - let want_active_batch = Batch::::new( + let want_active_batch = Batch::>::new( *task.id(), active_batch_identifier, - (), + aggregation_param.clone(), BatchState::Closed, 0, Interval::from_time(report.metadata().time()).unwrap(), ); - let want_other_batch = Batch::::new( + let want_other_batch = Batch::>::new( *task.id(), other_batch_identifier, - (), + aggregation_param.clone(), BatchState::Closing, 1, Interval::EMPTY, @@ -2087,14 +2689,16 @@ mod tests { got_collection_job, ) = ds .run_tx(|tx| { - let vdaf = Arc::clone(&vdaf); - let task = task.clone(); - let report_metadata = report.metadata().clone(); - let collection_job_id = *want_collection_job.id(); - + let (vdaf, task, report_metadata, aggregation_param, collection_job_id) = ( + Arc::clone(&vdaf), + task.clone(), + report.metadata().clone(), + aggregation_param.clone(), + *want_collection_job.id(), + ); Box::pin(async move { let aggregation_job = tx - .get_aggregation_job::( + .get_aggregation_job::>( task.id(), 
&aggregation_job_id, ) @@ -2106,6 +2710,7 @@ mod tests { &Role::Leader, task.id(), &aggregation_job_id, + aggregation_job.aggregation_parameter(), report_metadata.id(), ) .await? @@ -2113,7 +2718,7 @@ mod tests { let batch_aggregations = TimeInterval::get_batch_aggregations_for_collection_identifier::< VERIFY_KEY_LENGTH, - Prio3Count, + Poplar1, _, >( tx, @@ -2127,16 +2732,16 @@ mod tests { *task.time_precision(), ) .unwrap(), - &(), + &aggregation_param, ) .await .unwrap(); let got_active_batch = tx - .get_batch(task.id(), &active_batch_identifier, &()) + .get_batch(task.id(), &active_batch_identifier, &aggregation_param) .await? .unwrap(); let got_other_batch = tx - .get_batch(task.id(), &other_batch_identifier, &()) + .get_batch(task.id(), &other_batch_identifier, &aggregation_param) .await? .unwrap(); let got_collection_job = tx @@ -2164,7 +2769,7 @@ mod tests { BatchAggregation::new( *agg.task_id(), *agg.batch_identifier(), - (), + aggregation_param.clone(), 0, *agg.state(), agg.aggregate_share().cloned(), @@ -2192,17 +2797,17 @@ mod tests { let clock = MockClock::default(); let ephemeral_datastore = ephemeral_datastore().await; let ds = Arc::new(ephemeral_datastore.datastore(clock.clone()).await); - let vdaf = Arc::new(Prio3::new_count(2).unwrap()); + let vdaf = Arc::new(Poplar1::new_sha3(1)); let task = TaskBuilder::new( QueryType::FixedSize { max_batch_size: 10, batch_time_window_size: None, }, - VdafInstance::Prio3Count, + VdafInstance::Poplar1 { bits: 1 }, Role::Leader, ) - .with_helper_aggregator_endpoint(Url::parse(&server.url()).unwrap()) + .with_helper_aggregator_endpoint(server.url().parse().unwrap()) .build(); let report_metadata = ReportMetadata::new( random(), @@ -2213,40 +2818,43 @@ mod tests { ); let verify_key: VerifyKey = task.primary_vdaf_verify_key().unwrap(); + let aggregation_param = Poplar1AggregationParam::try_from_prefixes(Vec::from([ + IdpfInput::from_bools(&[false]), + ])) + .unwrap(); let transcript = run_vdaf( vdaf.as_ref(), verify_key.as_bytes(), - &(), + &aggregation_param, report_metadata.id(), - &0, + &IdpfInput::from_bools(&[true]), ); let agg_auth_token = task.primary_aggregator_auth_token(); let helper_hpke_keypair = generate_test_hpke_config_and_private_key(); - let report = generate_report::( + let report = generate_report::>( *task.id(), report_metadata, helper_hpke_keypair.config(), transcript.public_share.clone(), Vec::new(), - transcript.input_shares.clone(), + &transcript.leader_input_share, + &transcript.helper_input_share, ); let batch_id = random(); let aggregation_job_id = random(); - let leader_prep_state = transcript.leader_prep_state(0); let leader_aggregate_share = vdaf - .aggregate(&(), [transcript.output_share(Role::Leader).clone()]) + .aggregate(&aggregation_param, [transcript.leader_output_share.clone()]) .unwrap(); - let prep_msg = &transcript.prepare_messages[0]; let (lease, collection_job) = ds .run_tx(|tx| { - let (vdaf, task, report, leader_prep_state, prep_msg) = ( + let (vdaf, task, report, aggregation_param, transcript) = ( vdaf.clone(), task.clone(), report.clone(), - leader_prep_state.clone(), - prep_msg.clone(), + aggregation_param.clone(), + transcript.clone(), ); Box::pin(async move { tx.put_task(&task).await?; @@ -2255,11 +2863,11 @@ mod tests { tx.put_aggregation_job(&AggregationJob::< VERIFY_KEY_LENGTH, FixedSize, - Prio3Count, + Poplar1, >::new( *task.id(), aggregation_job_id, - (), + aggregation_param.clone(), batch_id, Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) .unwrap(), @@ 
-2267,35 +2875,44 @@ mod tests { AggregationJobRound::from(1), )) .await?; - tx.put_report_aggregation( - &ReportAggregation::::new( - *task.id(), - aggregation_job_id, - *report.metadata().id(), - *report.metadata().time(), - 0, - None, - ReportAggregationState::Waiting(leader_prep_state, Some(prep_msg)), - ), - ) - .await?; - tx.put_batch(&Batch::::new( + tx.put_report_aggregation(&ReportAggregation::< + VERIFY_KEY_LENGTH, + Poplar1, + >::new( *task.id(), - batch_id, - (), - BatchState::Closing, - 1, - Interval::from_time(report.metadata().time()).unwrap(), + aggregation_job_id, + *report.metadata().id(), + *report.metadata().time(), + 0, + None, + ReportAggregationState::Waiting( + transcript.leader_prepare_transitions[1] + .transition + .clone() + .unwrap(), + ), )) .await?; + tx.put_batch( + &Batch::>::new( + *task.id(), + batch_id, + aggregation_param.clone(), + BatchState::Closing, + 1, + Interval::from_time(report.metadata().time()).unwrap(), + ), + ) + .await?; + let collection_job = - CollectionJob::::new( + CollectionJob::>::new( *task.id(), random(), batch_id, - (), + aggregation_param, CollectionJobState::Start, ); tx.put_collection_job(&collection_job).await?; @@ -2319,12 +2936,12 @@ mod tests { // verification -- but mockito does not expose this functionality at time of writing.) let leader_request = AggregationJobContinueReq::new( AggregationJobRound::from(1), - Vec::from([PrepareStep::new( + Vec::from([PrepareContinue::new( *report.metadata().id(), - PrepareStepResult::Continued(prep_msg.get_encoded()), + transcript.leader_prepare_transitions[1].message.clone(), )]), ); - let helper_response = AggregationJobResp::new(Vec::from([PrepareStep::new( + let helper_response = AggregationJobResp::new(Vec::from([PrepareResp::new( *report.metadata().id(), PrepareStepResult::Finished, )])); @@ -2385,42 +3002,46 @@ mod tests { mocked_aggregate_failure.assert_async().await; mocked_aggregate_success.assert_async().await; - let want_aggregation_job = AggregationJob::::new( + let want_aggregation_job = + AggregationJob::>::new( + *task.id(), + aggregation_job_id, + aggregation_param.clone(), + batch_id, + Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) + .unwrap(), + AggregationJobState::Finished, + AggregationJobRound::from(2), + ); + let want_report_aggregation = + ReportAggregation::>::new( + *task.id(), + aggregation_job_id, + *report.metadata().id(), + *report.metadata().time(), + 0, + None, + ReportAggregationState::Finished, + ); + let want_batch_aggregations = Vec::from([BatchAggregation::< + VERIFY_KEY_LENGTH, + FixedSize, + Poplar1, + >::new( *task.id(), - aggregation_job_id, - (), batch_id, - Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)).unwrap(), - AggregationJobState::Finished, - AggregationJobRound::from(2), - ); - let want_report_aggregation = ReportAggregation::::new( - *task.id(), - aggregation_job_id, - *report.metadata().id(), - *report.metadata().time(), + aggregation_param.clone(), 0, - None, - ReportAggregationState::Finished, - ); - let want_batch_aggregations = - Vec::from([ - BatchAggregation::::new( - *task.id(), - batch_id, - (), - 0, - BatchAggregationState::Aggregating, - Some(leader_aggregate_share), - 1, - Interval::from_time(report.metadata().time()).unwrap(), - ReportIdChecksum::for_report_id(report.metadata().id()), - ), - ]); - let want_batch = Batch::::new( + BatchAggregationState::Aggregating, + Some(leader_aggregate_share), + 1, + Interval::from_time(report.metadata().time()).unwrap(), + 
ReportIdChecksum::for_report_id(report.metadata().id()), + )]); + let want_batch = Batch::>::new( *task.id(), batch_id, - (), + aggregation_param.clone(), BatchState::Closed, 0, Interval::from_time(report.metadata().time()).unwrap(), @@ -2435,14 +3056,16 @@ mod tests { got_collection_job, ) = ds .run_tx(|tx| { - let vdaf = Arc::clone(&vdaf); - let task = task.clone(); - let report_metadata = report.metadata().clone(); - let collection_job_id = *want_collection_job.id(); - + let (vdaf, task, report_metadata, aggregation_param, collection_job_id) = ( + Arc::clone(&vdaf), + task.clone(), + report.metadata().clone(), + aggregation_param.clone(), + *want_collection_job.id(), + ); Box::pin(async move { let aggregation_job = tx - .get_aggregation_job::( + .get_aggregation_job::>( task.id(), &aggregation_job_id, ) @@ -2454,6 +3077,7 @@ mod tests { &Role::Leader, task.id(), &aggregation_job_id, + aggregation_job.aggregation_parameter(), report_metadata.id(), ) .await? @@ -2461,11 +3085,14 @@ mod tests { let batch_aggregations = FixedSize::get_batch_aggregations_for_collection_identifier::< VERIFY_KEY_LENGTH, - Prio3Count, + Poplar1, _, - >(tx, &task, &vdaf, &batch_id, &()) + >(tx, &task, &vdaf, &batch_id, &aggregation_param) .await?; - let batch = tx.get_batch(task.id(), &batch_id, &()).await?.unwrap(); + let batch = tx + .get_batch(task.id(), &batch_id, &aggregation_param) + .await? + .unwrap(); let collection_job = tx .get_collection_job(vdaf.as_ref(), &collection_job_id) .await? @@ -2489,7 +3116,7 @@ mod tests { BatchAggregation::new( *agg.task_id(), *agg.batch_identifier(), - (), + aggregation_param.clone(), 0, *agg.state(), agg.aggregate_share().cloned(), @@ -2545,7 +3172,8 @@ mod tests { helper_hpke_keypair.config(), transcript.public_share, Vec::new(), - transcript.input_shares, + &transcript.leader_input_share, + &transcript.helper_input_share, ); let aggregation_job_id = random(); @@ -2647,6 +3275,7 @@ mod tests { &Role::Leader, task.id(), &aggregation_job_id, + aggregation_job.aggregation_parameter(), &report_id, ) .await? 
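+                    // Editorial note (not part of the original patch): report
+                    // aggregation lookups now take the aggregation parameter as
+                    // an extra argument. Under the ping-pong topology the stored
+                    // ReportAggregationState::Waiting(transition) can, for VDAFs
+                    // like Poplar1, only be decoded with the aggregation
+                    // parameter in hand, so every reader has to thread it
+                    // through, as the call above now does.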
@@ -2677,26 +3306,18 @@ mod tests { helper_hpke_config: &HpkeConfig, public_share: A::PublicShare, extensions: Vec, - input_shares: Vec, + leader_input_share: &A::InputShare, + helper_input_share: &A::InputShare, ) -> LeaderStoredReport where A: vdaf::Aggregator, A::InputShare: PartialEq, A::PublicShare: PartialEq, { - assert_eq!(input_shares.len(), 2); - let encrypted_helper_input_share = hpke::seal( helper_hpke_config, &HpkeApplicationInfo::new(&Label::InputShare, &Role::Client, &Role::Helper), - &PlaintextInputShare::new( - Vec::new(), - input_shares - .get(Role::Helper.index().unwrap()) - .unwrap() - .get_encoded(), - ) - .get_encoded(), + &PlaintextInputShare::new(Vec::new(), helper_input_share.get_encoded()).get_encoded(), &InputShareAad::new(task_id, report_metadata.clone(), public_share.get_encoded()) .get_encoded(), ) @@ -2707,10 +3328,7 @@ mod tests { report_metadata, public_share, extensions, - input_shares - .get(Role::Leader.index().unwrap()) - .unwrap() - .clone(), + leader_input_share.clone(), encrypted_helper_input_share, ) } @@ -2730,7 +3348,7 @@ mod tests { VdafInstance::Prio3Count, Role::Leader, ) - .with_helper_aggregator_endpoint(Url::parse(&server.url()).unwrap()) + .with_helper_aggregator_endpoint(server.url().parse().unwrap()) .build(); let agg_auth_token = task.primary_aggregator_auth_token(); let aggregation_job_id = random(); @@ -2752,7 +3370,8 @@ mod tests { helper_hpke_keypair.config(), transcript.public_share, Vec::new(), - transcript.input_shares, + &transcript.leader_input_share, + &transcript.helper_input_share, ); // Set up fixtures in the database. diff --git a/aggregator/src/aggregator/aggregation_job_writer.rs b/aggregator/src/aggregator/aggregation_job_writer.rs index 65ce1d387..ae2d1e41b 100644 --- a/aggregator/src/aggregator/aggregation_job_writer.rs +++ b/aggregator/src/aggregator/aggregation_job_writer.rs @@ -14,7 +14,7 @@ use janus_aggregator_core::{ task::Task, }; use janus_core::time::{Clock, IntervalExt}; -use janus_messages::{AggregationJobId, Interval, ReportId, ReportShareError}; +use janus_messages::{AggregationJobId, Interval, PrepareError, ReportId}; use prio::{codec::Encode, vdaf}; use std::{ borrow::Cow, @@ -260,7 +260,7 @@ impl, +) -> PrepareError { + let (error_desc, value, prepare_error) = match ping_pong_error { + PingPongError::VdafPrepareInit(_) => ( + "Couldn't helper_initialize report share".to_string(), + "prepare_init_failure".to_string(), + PrepareError::VdafPrepError, + ), + PingPongError::VdafPreparePreprocess(_) => ( + "Couldn't compute prepare message".to_string(), + "prepare_message_failure".to_string(), + PrepareError::VdafPrepError, + ), + PingPongError::VdafPrepareStep(_) => ( + "Prepare step failed".to_string(), + "prepare_step_failure".to_string(), + PrepareError::VdafPrepError, + ), + PingPongError::CodecPrepShare(_) => ( + format!("Couldn't decode {peer_role} prepare share"), + format!("{peer_role}_prep_share_decode_failure"), + PrepareError::UnrecognizedMessage, + ), + PingPongError::CodecPrepMessage(_) => ( + format!("Couldn't decode {peer_role} prepare message"), + format!("{peer_role}_prep_message_decode_failure"), + PrepareError::UnrecognizedMessage, + ), + ref error @ PingPongError::StateMismatch(_, _) => ( + format!("{error}"), + format!("{peer_role}_ping_pong_message_state_mismatch"), + // TODO(timg): is this the right error if state mismatch? 
+ PrepareError::VdafPrepError, + ), + PingPongError::InternalError(desc) => ( + desc.to_string(), + "vdaf_ping_pong_internal_error".to_string(), + PrepareError::VdafPrepError, + ), + }; + + info!( + task_id = %task_id, + report_id = %report_id, + ?ping_pong_error, + error_desc, + ); + + aggregate_step_failure_counter.add(1, &[KeyValue::new("type", value)]); + + prepare_error +} diff --git a/aggregator/src/aggregator/http_handlers.rs b/aggregator/src/aggregator/http_handlers.rs index e89ccb2c0..c28390be7 100644 --- a/aggregator/src/aggregator/http_handlers.rs +++ b/aggregator/src/aggregator/http_handlers.rs @@ -640,7 +640,9 @@ pub mod test_util { mod tests { use crate::{ aggregator::{ - aggregate_init_tests::{put_aggregation_job, setup_aggregate_init_test}, + aggregate_init_tests::{ + put_aggregation_job, setup_aggregate_init_test, PrepareInitGenerator, + }, aggregation_job_continue::test_util::{ post_aggregation_job_and_decode, post_aggregation_job_expecting_error, }, @@ -696,15 +698,18 @@ mod tests { AggregationJobInitializeReq, AggregationJobResp, AggregationJobRound, BatchSelector, Collection, CollectionJobId, CollectionReq, Duration, Extension, ExtensionType, HpkeCiphertext, HpkeConfigId, HpkeConfigList, InputShareAad, Interval, - PartialBatchSelector, PrepareStep, PrepareStepResult, Query, Report, ReportId, - ReportIdChecksum, ReportMetadata, ReportShare, ReportShareError, Role, TaskId, Time, + PartialBatchSelector, PrepareContinue, PrepareError, PrepareInit, PrepareResp, + PrepareStepResult, Query, Report, ReportId, ReportIdChecksum, ReportMetadata, ReportShare, + Role, TaskId, Time, }; use prio::{ codec::{Decode, Encode}, - field::Field64, + idpf::IdpfInput, + topology::ping_pong, vdaf::{ - prio3::{Prio3, Prio3Count}, - AggregateShare, Aggregator, OutputShare, + poplar1::{Poplar1, Poplar1AggregationParam}, + prg::PrgSha3, + Aggregator, }, }; use rand::random; @@ -1506,55 +1511,21 @@ mod tests { let vdaf = dummy_vdaf::Vdaf::new(); let verify_key: VerifyKey<0> = task.primary_vdaf_verify_key().unwrap(); let hpke_key = task.current_hpke_key(); - - // report_share_0 is a "happy path" report. - let report_metadata_0 = ReportMetadata::new( - random(), - clock - .now() - .to_batch_interval_start(task.time_precision()) - .unwrap(), - ); - let transcript = run_vdaf( - &vdaf, - verify_key.as_bytes(), - &dummy_vdaf::AggregationParam(0), - report_metadata_0.id(), - &(), - ); - let report_share_0 = generate_helper_report_share::( - *task.id(), - report_metadata_0, - hpke_key.config(), - &transcript.public_share, - Vec::new(), - &transcript.input_shares[1], + let measurement = 1; + let prep_init_generator = PrepareInitGenerator::new( + clock.clone(), + task.clone(), + vdaf.clone(), + dummy_vdaf::AggregationParam(0), ); + // prepare_init_0 is a "happy path" report. + let (prepare_init_0, transcript_0) = prep_init_generator.next(&measurement); + // report_share_1 fails decryption. 
- let report_metadata_1 = ReportMetadata::new( - random(), - clock - .now() - .to_batch_interval_start(task.time_precision()) - .unwrap(), - ); - let transcript = run_vdaf( - &vdaf, - verify_key.as_bytes(), - &dummy_vdaf::AggregationParam(0), - report_metadata_1.id(), - &(), - ); - let report_share_1 = generate_helper_report_share::( - *task.id(), - report_metadata_1.clone(), - hpke_key.config(), - &transcript.public_share, - Vec::new(), - &transcript.input_shares[1], - ); - let encrypted_input_share = report_share_1.encrypted_input_share(); + let (prepare_init_1, transcript_1) = prep_init_generator.next(&measurement); + + let encrypted_input_share = prepare_init_1.report_share().encrypted_input_share(); let mut corrupted_payload = encrypted_input_share.payload().to_vec(); corrupted_payload[0] ^= 0xFF; let corrupted_input_share = HpkeCiphertext::new( @@ -1562,53 +1533,39 @@ mod tests { encrypted_input_share.encapsulated_key().to_vec(), corrupted_payload, ); - let encoded_public_share = transcript.public_share.get_encoded(); - let report_share_1 = ReportShare::new( - report_metadata_1, - encoded_public_share.clone(), - corrupted_input_share, - ); - // report_share_2 fails decoding due to an issue with the input share. - let report_metadata_2 = ReportMetadata::new( - random(), - clock - .now() - .to_batch_interval_start(task.time_precision()) - .unwrap(), - ); - let transcript = run_vdaf( - &vdaf, - verify_key.as_bytes(), - &dummy_vdaf::AggregationParam(0), - report_metadata_2.id(), - &(), + let prepare_init_1 = PrepareInit::new( + ReportShare::new( + prepare_init_1.report_share().metadata().clone(), + transcript_1.public_share.get_encoded(), + corrupted_input_share, + ), + prepare_init_1.message().clone(), ); - let mut input_share_bytes = transcript.input_shares[1].get_encoded(); + + // prepare_init_2 fails decoding due to an issue with the input share. + let (prepare_init_2, transcript_2) = prep_init_generator.next(&measurement); + + let mut input_share_bytes = transcript_2.helper_input_share.get_encoded(); input_share_bytes.push(0); // can no longer be decoded. let report_share_2 = generate_helper_report_share_for_plaintext( - report_metadata_2.clone(), + prepare_init_2.report_share().metadata().clone(), hpke_key.config(), - encoded_public_share.clone(), + transcript_2.public_share.get_encoded(), &input_share_bytes, - &InputShareAad::new(*task.id(), report_metadata_2, encoded_public_share).get_encoded(), + &InputShareAad::new( + *task.id(), + prepare_init_2.report_share().metadata().clone(), + transcript_2.public_share.get_encoded(), + ) + .get_encoded(), ); - // report_share_3 has an unknown HPKE config ID. - let report_metadata_3 = ReportMetadata::new( - random(), - clock - .now() - .to_batch_interval_start(task.time_precision()) - .unwrap(), - ); - let transcript = run_vdaf( - &vdaf, - verify_key.as_bytes(), - &dummy_vdaf::AggregationParam(0), - report_metadata_3.id(), - &(), - ); + let prepare_init_2 = PrepareInit::new(report_share_2, prepare_init_2.message().clone()); + + // prepare_init_3 has an unknown HPKE config ID. 
+ let (prepare_init_3, transcript_3) = prep_init_generator.next(&measurement); + let wrong_hpke_config = loop { let hpke_config = generate_test_hpke_config_and_private_key().config().clone(); if task.hpke_keys().contains_key(hpke_config.id()) { @@ -1616,41 +1573,23 @@ mod tests { } break hpke_config; }; + let report_share_3 = generate_helper_report_share::( *task.id(), - report_metadata_3, + prepare_init_3.report_share().metadata().clone(), &wrong_hpke_config, - &transcript.public_share, + &transcript_3.public_share, Vec::new(), - &transcript.input_shares[1], + &transcript_3.helper_input_share, ); - // report_share_4 has already been aggregated in another aggregation job, with the same + let prepare_init_3 = PrepareInit::new(report_share_3, prepare_init_3.message().clone()); + + // prepare_init_4 has already been aggregated in another aggregation job, with the same // aggregation parameter. - let report_metadata_4 = ReportMetadata::new( - random(), - clock - .now() - .to_batch_interval_start(task.time_precision()) - .unwrap(), - ); - let transcript = run_vdaf( - &vdaf, - verify_key.as_bytes(), - &dummy_vdaf::AggregationParam(0), - report_metadata_4.id(), - &(), - ); - let report_share_4 = generate_helper_report_share::( - *task.id(), - report_metadata_4, - hpke_key.config(), - &transcript.public_share, - Vec::new(), - &transcript.input_shares[1], - ); + let (prepare_init_4, _) = prep_init_generator.next(&measurement); - // report_share_5 falls into a batch that has already been collected. + // prepare_init_5 falls into a batch that has already been collected. let past_clock = MockClock::new(Time::from_seconds_since_epoch( task.time_precision().as_seconds() / 2, )); @@ -1661,23 +1600,28 @@ mod tests { .to_batch_interval_start(task.time_precision()) .unwrap(), ); - let transcript = run_vdaf( + let transcript_5 = run_vdaf( &vdaf, verify_key.as_bytes(), &dummy_vdaf::AggregationParam(0), report_metadata_5.id(), - &(), + &measurement, ); let report_share_5 = generate_helper_report_share::( *task.id(), report_metadata_5, hpke_key.config(), - &transcript.public_share, + &transcript_5.public_share, Vec::new(), - &transcript.input_shares[1], + &transcript_5.helper_input_share, + ); + + let prepare_init_5 = PrepareInit::new( + report_share_5, + transcript_5.leader_prepare_transitions[0].message.clone(), ); - // report_share_6 fails decoding due to an issue with the public share. + // prepare_init_6 fails decoding due to an issue with the public share. let public_share_6 = Vec::from([0]); let report_metadata_6 = ReportMetadata::new( random(), @@ -1686,22 +1630,27 @@ mod tests { .to_batch_interval_start(task.time_precision()) .unwrap(), ); - let transcript = run_vdaf( + let transcript_6 = run_vdaf( &vdaf, verify_key.as_bytes(), &dummy_vdaf::AggregationParam(0), report_metadata_6.id(), - &(), + &measurement, ); let report_share_6 = generate_helper_report_share_for_plaintext( report_metadata_6.clone(), hpke_key.config(), public_share_6.clone(), - &transcript.input_shares[1].get_encoded(), + &transcript_6.helper_input_share.get_encoded(), &InputShareAad::new(*task.id(), report_metadata_6, public_share_6).get_encoded(), ); - // report_share_7 fails due to having repeated extensions. + let prepare_init_6 = PrepareInit::new( + report_share_6, + transcript_6.leader_prepare_transitions[0].message.clone(), + ); + + // prepare_init_7 fails due to having repeated extensions. 
let report_metadata_7 = ReportMetadata::new( random(), clock @@ -1709,56 +1658,40 @@ mod tests { .to_batch_interval_start(task.time_precision()) .unwrap(), ); - let transcript = run_vdaf( + let transcript_7 = run_vdaf( &vdaf, verify_key.as_bytes(), &dummy_vdaf::AggregationParam(0), report_metadata_7.id(), - &(), + &measurement, ); let report_share_7 = generate_helper_report_share::( *task.id(), report_metadata_7, hpke_key.config(), - &transcript.public_share, + &transcript_7.public_share, Vec::from([ Extension::new(ExtensionType::Tbd, Vec::new()), Extension::new(ExtensionType::Tbd, Vec::new()), ]), - &transcript.input_shares[0], + &transcript_7.helper_input_share, ); - // report_share_8 has already been aggregated in another aggregation job, with a different - // aggregation parameter. - let report_metadata_8 = ReportMetadata::new( - random(), - clock - .now() - .to_batch_interval_start(task.time_precision()) - .unwrap(), - ); - let transcript = run_vdaf( - &vdaf, - verify_key.as_bytes(), - &dummy_vdaf::AggregationParam(1), - report_metadata_8.id(), - &(), - ); - let report_share_8 = generate_helper_report_share::( - *task.id(), - report_metadata_8, - hpke_key.config(), - &transcript.public_share, - Vec::new(), - &transcript.input_shares[1], + let prepare_init_7 = PrepareInit::new( + report_share_7, + transcript_7.leader_prepare_transitions[0].message.clone(), ); + // prepare_init_8 has already been aggregated in another aggregation job, with a different + // aggregation parameter. + let (prepare_init_8, transcript_8) = prep_init_generator.next(&measurement); + let (conflicting_aggregation_job, non_conflicting_aggregation_job) = datastore .run_tx(|tx| { let task = task.clone(); - let report_share_4 = report_share_4.clone(); - let report_share_5 = report_share_5.clone(); - let report_share_8 = report_share_8.clone(); + let report_share_4 = prepare_init_4.report_share().clone(); + let report_share_5 = prepare_init_5.report_share().clone(); + let report_share_8 = prepare_init_8.report_share().clone(); Box::pin(async move { tx.put_task(&task).await?; @@ -1856,15 +1789,15 @@ mod tests { dummy_vdaf::AggregationParam(0).get_encoded(), PartialBatchSelector::new_time_interval(), Vec::from([ - report_share_0.clone(), - report_share_1.clone(), - report_share_2.clone(), - report_share_3.clone(), - report_share_4.clone(), - report_share_5.clone(), - report_share_6.clone(), - report_share_7.clone(), - report_share_8.clone(), + prepare_init_0.clone(), + prepare_init_1.clone(), + prepare_init_2.clone(), + prepare_init_3.clone(), + prepare_init_4.clone(), + prepare_init_5.clone(), + prepare_init_6.clone(), + prepare_init_7.clone(), + prepare_init_8.clone(), ]), ); @@ -1881,64 +1814,95 @@ mod tests { let aggregate_resp: AggregationJobResp = decode_response_body(&mut test_conn).await; // Validate response. 
- assert_eq!(aggregate_resp.prepare_steps().len(), 9); + assert_eq!(aggregate_resp.prepare_resps().len(), 9); - let prepare_step_0 = aggregate_resp.prepare_steps().get(0).unwrap(); - assert_eq!(prepare_step_0.report_id(), report_share_0.metadata().id()); - assert_matches!(prepare_step_0.result(), &PrepareStepResult::Continued(..)); + let prepare_step_0 = aggregate_resp.prepare_resps().get(0).unwrap(); + assert_eq!( + prepare_step_0.report_id(), + prepare_init_0.report_share().metadata().id() + ); + assert_matches!(prepare_step_0.result(), PrepareStepResult::Continue { message } => { + assert_eq!(message, &transcript_0.helper_prepare_transitions[0].message); + }); - let prepare_step_1 = aggregate_resp.prepare_steps().get(1).unwrap(); - assert_eq!(prepare_step_1.report_id(), report_share_1.metadata().id()); + let prepare_step_1 = aggregate_resp.prepare_resps().get(1).unwrap(); + assert_eq!( + prepare_step_1.report_id(), + prepare_init_1.report_share().metadata().id() + ); assert_matches!( prepare_step_1.result(), - &PrepareStepResult::Failed(ReportShareError::HpkeDecryptError) + &PrepareStepResult::Reject(PrepareError::HpkeDecryptError) ); - let prepare_step_2 = aggregate_resp.prepare_steps().get(2).unwrap(); - assert_eq!(prepare_step_2.report_id(), report_share_2.metadata().id()); + let prepare_step_2 = aggregate_resp.prepare_resps().get(2).unwrap(); + assert_eq!( + prepare_step_2.report_id(), + prepare_init_2.report_share().metadata().id() + ); assert_matches!( prepare_step_2.result(), - &PrepareStepResult::Failed(ReportShareError::UnrecognizedMessage) + &PrepareStepResult::Reject(PrepareError::UnrecognizedMessage) ); - let prepare_step_3 = aggregate_resp.prepare_steps().get(3).unwrap(); - assert_eq!(prepare_step_3.report_id(), report_share_3.metadata().id()); + let prepare_step_3 = aggregate_resp.prepare_resps().get(3).unwrap(); + assert_eq!( + prepare_step_3.report_id(), + prepare_init_3.report_share().metadata().id() + ); assert_matches!( prepare_step_3.result(), - &PrepareStepResult::Failed(ReportShareError::HpkeUnknownConfigId) + &PrepareStepResult::Reject(PrepareError::HpkeUnknownConfigId) ); - let prepare_step_4 = aggregate_resp.prepare_steps().get(4).unwrap(); - assert_eq!(prepare_step_4.report_id(), report_share_4.metadata().id()); + let prepare_step_4 = aggregate_resp.prepare_resps().get(4).unwrap(); + assert_eq!( + prepare_step_4.report_id(), + prepare_init_4.report_share().metadata().id() + ); assert_eq!( prepare_step_4.result(), - &PrepareStepResult::Failed(ReportShareError::ReportReplayed) + &PrepareStepResult::Reject(PrepareError::ReportReplayed) ); - let prepare_step_5 = aggregate_resp.prepare_steps().get(5).unwrap(); - assert_eq!(prepare_step_5.report_id(), report_share_5.metadata().id()); + let prepare_step_5 = aggregate_resp.prepare_resps().get(5).unwrap(); + assert_eq!( + prepare_step_5.report_id(), + prepare_init_5.report_share().metadata().id() + ); assert_eq!( prepare_step_5.result(), - &PrepareStepResult::Failed(ReportShareError::BatchCollected) + &PrepareStepResult::Reject(PrepareError::BatchCollected) ); - let prepare_step_6 = aggregate_resp.prepare_steps().get(6).unwrap(); - assert_eq!(prepare_step_6.report_id(), report_share_6.metadata().id()); + let prepare_step_6 = aggregate_resp.prepare_resps().get(6).unwrap(); + assert_eq!( + prepare_step_6.report_id(), + prepare_init_6.report_share().metadata().id() + ); assert_eq!( prepare_step_6.result(), - &PrepareStepResult::Failed(ReportShareError::UnrecognizedMessage), + 
&PrepareStepResult::Reject(PrepareError::UnrecognizedMessage), ); - let prepare_step_7 = aggregate_resp.prepare_steps().get(7).unwrap(); - assert_eq!(prepare_step_7.report_id(), report_share_7.metadata().id()); + let prepare_step_7 = aggregate_resp.prepare_resps().get(7).unwrap(); + assert_eq!( + prepare_step_7.report_id(), + prepare_init_7.report_share().metadata().id() + ); assert_eq!( prepare_step_7.result(), - &PrepareStepResult::Failed(ReportShareError::UnrecognizedMessage), + &PrepareStepResult::Reject(PrepareError::UnrecognizedMessage), ); - let prepare_step_8 = aggregate_resp.prepare_steps().get(8).unwrap(); - assert_eq!(prepare_step_8.report_id(), report_share_8.metadata().id()); - assert_matches!(prepare_step_8.result(), &PrepareStepResult::Continued(..)); + let prepare_step_8 = aggregate_resp.prepare_resps().get(8).unwrap(); + assert_eq!( + prepare_step_8.report_id(), + prepare_init_8.report_share().metadata().id() + ); + assert_matches!(prepare_step_8.result(), PrepareStepResult::Continue { message } => { + assert_eq!(message, &transcript_8.helper_prepare_transitions[0].message); + }); // Check aggregation job in datastore. let aggregation_jobs = datastore @@ -1968,7 +1932,7 @@ mod tests { } else if aggregation_job.task_id().eq(task.id()) && aggregation_job.id().eq(&aggregation_job_id) && aggregation_job.partial_batch_identifier().eq(&()) - && aggregation_job.state().eq(&AggregationJobState::InProgress) + && aggregation_job.state().eq(&AggregationJobState::Finished) { saw_new_aggregation_job = true; } @@ -1988,6 +1952,10 @@ mod tests { let task = TaskBuilder::new(QueryType::TimeInterval, VdafInstance::Fake, Role::Helper).build(); datastore.put_task(&task).await.unwrap(); + let vdaf = dummy_vdaf::Vdaf::new(); + let aggregation_param = dummy_vdaf::AggregationParam(0); + let prep_init_generator = + PrepareInitGenerator::new(clock.clone(), task.clone(), vdaf.clone(), aggregation_param); // Insert some global HPKE keys. // Same ID as the task to test having both keys to choose from. @@ -2028,59 +1996,20 @@ mod tests { .await .unwrap(); - let vdaf = dummy_vdaf::Vdaf::new(); let verify_key: VerifyKey<0> = task.primary_vdaf_verify_key().unwrap(); // This report was encrypted with a global HPKE config that has the same config // ID as the task's HPKE config. - let report_metadata_same_id = ReportMetadata::new( - random(), - clock - .now() - .to_batch_interval_start(task.time_precision()) - .unwrap(), - ); - let transcript = run_vdaf( - &vdaf, - verify_key.as_bytes(), - &dummy_vdaf::AggregationParam(0), - report_metadata_same_id.id(), - &(), - ); - let report_share_same_id = generate_helper_report_share::( - *task.id(), - report_metadata_same_id, - global_hpke_keypair_same_id.config(), - &transcript.public_share, - Vec::new(), - &transcript.input_shares[1], - ); + let (prepare_init_same_id, transcript_same_id) = prep_init_generator.next(&1); // This report was encrypted with a global HPKE config that has the same config // ID as the task's HPKE config, but will fail to decrypt. 
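+        // Editorial note (an assumption, not part of the original patch): since
+        // the config ID collides with one of the task's own keys, the helper
+        // may attempt decryption with both the task key and the global key;
+        // the payload corruption below makes every attempt fail AEAD
+        // authentication, so the expected outcome is Reject(HpkeDecryptError)
+        // rather than HpkeUnknownConfigId.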
- let report_metadata_same_id_corrupted = ReportMetadata::new( - random(), - clock - .now() - .to_batch_interval_start(task.time_precision()) - .unwrap(), - ); - let transcript = run_vdaf( - &vdaf, - verify_key.as_bytes(), - &dummy_vdaf::AggregationParam(0), - report_metadata_same_id_corrupted.id(), - &(), - ); - let report_share_same_id_corrupted = generate_helper_report_share::( - *task.id(), - report_metadata_same_id_corrupted.clone(), - global_hpke_keypair_same_id.config(), - &transcript.public_share, - Vec::new(), - &transcript.input_shares[1], - ); - let encrypted_input_share = report_share_same_id_corrupted.encrypted_input_share(); + let (prepare_init_same_id_corrupted, transcript_same_id_corrupted) = + prep_init_generator.next(&1); + + let encrypted_input_share = prepare_init_same_id_corrupted + .report_share() + .encrypted_input_share(); let mut corrupted_payload = encrypted_input_share.payload().to_vec(); corrupted_payload[0] ^= 0xFF; let corrupted_input_share = HpkeCiphertext::new( @@ -2088,11 +2017,17 @@ mod tests { encrypted_input_share.encapsulated_key().to_vec(), corrupted_payload, ); - let encoded_public_share = transcript.public_share.get_encoded(); - let report_share_same_id_corrupted = ReportShare::new( - report_metadata_same_id_corrupted, - encoded_public_share.clone(), - corrupted_input_share, + + let prepare_init_same_id_corrupted = PrepareInit::new( + ReportShare::new( + prepare_init_same_id_corrupted + .report_share() + .metadata() + .clone(), + transcript_same_id_corrupted.public_share.get_encoded(), + corrupted_input_share, + ), + prepare_init_same_id_corrupted.message().clone(), ); // This report was encrypted with a global HPKE config that doesn't collide @@ -2104,20 +2039,27 @@ mod tests { .to_batch_interval_start(task.time_precision()) .unwrap(), ); - let transcript = run_vdaf( + let transcript_different_id = run_vdaf( &vdaf, verify_key.as_bytes(), &dummy_vdaf::AggregationParam(0), report_metadata_different_id.id(), - &(), + &1, ); let report_share_different_id = generate_helper_report_share::( *task.id(), report_metadata_different_id, global_hpke_keypair_different_id.config(), - &transcript.public_share, + &transcript_different_id.public_share, Vec::new(), - &transcript.input_shares[1], + &transcript_different_id.helper_input_share, + ); + + let prepare_init_different_id = PrepareInit::new( + report_share_different_id, + transcript_different_id.leader_prepare_transitions[0] + .message + .clone(), ); // This report was encrypted with a global HPKE config that doesn't collide @@ -2129,20 +2071,20 @@ mod tests { .to_batch_interval_start(task.time_precision()) .unwrap(), ); - let transcript = run_vdaf( + let transcript_different_id_corrupted = run_vdaf( &vdaf, verify_key.as_bytes(), &dummy_vdaf::AggregationParam(0), report_metadata_different_id_corrupted.id(), - &(), + &1, ); let report_share_different_id_corrupted = generate_helper_report_share::( *task.id(), report_metadata_different_id_corrupted.clone(), global_hpke_keypair_different_id.config(), - &transcript.public_share, + &transcript_different_id_corrupted.public_share, Vec::new(), - &transcript.input_shares[1], + &transcript_different_id_corrupted.helper_input_share, ); let encrypted_input_share = report_share_different_id_corrupted.encrypted_input_share(); let mut corrupted_payload = encrypted_input_share.payload().to_vec(); @@ -2152,11 +2094,17 @@ mod tests { encrypted_input_share.encapsulated_key().to_vec(), corrupted_payload, ); - let encoded_public_share = transcript.public_share.get_encoded(); - let 
report_share_different_id_corrupted = ReportShare::new( - report_metadata_different_id_corrupted, - encoded_public_share.clone(), - corrupted_input_share, + let encoded_public_share = transcript_different_id_corrupted.public_share.get_encoded(); + + let prepare_init_different_id_corrupted = PrepareInit::new( + ReportShare::new( + report_metadata_different_id_corrupted, + encoded_public_share.clone(), + corrupted_input_share, + ), + transcript_different_id_corrupted.leader_prepare_transitions[0] + .message + .clone(), ); let aggregation_job_id: AggregationJobId = random(); @@ -2164,10 +2112,10 @@ mod tests { dummy_vdaf::AggregationParam(0).get_encoded(), PartialBatchSelector::new_time_interval(), Vec::from([ - report_share_same_id.clone(), - report_share_different_id.clone(), - report_share_same_id_corrupted.clone(), - report_share_different_id_corrupted.clone(), + prepare_init_same_id.clone(), + prepare_init_different_id.clone(), + prepare_init_same_id_corrupted.clone(), + prepare_init_different_id_corrupted.clone(), ]), ); @@ -2177,46 +2125,53 @@ mod tests { let aggregate_resp: AggregationJobResp = decode_response_body(&mut test_conn).await; // Validate response. - assert_eq!(aggregate_resp.prepare_steps().len(), 4); + assert_eq!(aggregate_resp.prepare_resps().len(), 4); - let prepare_step_same_id = aggregate_resp.prepare_steps().get(0).unwrap(); + let prepare_step_same_id = aggregate_resp.prepare_resps().get(0).unwrap(); assert_eq!( prepare_step_same_id.report_id(), - report_share_same_id.metadata().id() - ); - assert_matches!( - prepare_step_same_id.result(), - &PrepareStepResult::Continued(..) + prepare_init_same_id.report_share().metadata().id() ); + assert_matches!(prepare_step_same_id.result(), PrepareStepResult::Continue { message } => { + assert_eq!(message, &transcript_same_id.helper_prepare_transitions[0].message); + }); - let prepare_step_different_id = aggregate_resp.prepare_steps().get(1).unwrap(); + let prepare_step_different_id = aggregate_resp.prepare_resps().get(1).unwrap(); assert_eq!( prepare_step_different_id.report_id(), - report_share_different_id.metadata().id() + prepare_init_different_id.report_share().metadata().id() ); assert_matches!( prepare_step_different_id.result(), - &PrepareStepResult::Continued(..) 
+ PrepareStepResult::Continue { message } => { + assert_eq!(message, &transcript_different_id.helper_prepare_transitions[0].message); + } ); - let prepare_step_same_id_corrupted = aggregate_resp.prepare_steps().get(2).unwrap(); + let prepare_step_same_id_corrupted = aggregate_resp.prepare_resps().get(2).unwrap(); assert_eq!( prepare_step_same_id_corrupted.report_id(), - report_share_same_id_corrupted.metadata().id() + prepare_init_same_id_corrupted + .report_share() + .metadata() + .id(), ); assert_matches!( prepare_step_same_id_corrupted.result(), - &PrepareStepResult::Failed(ReportShareError::HpkeDecryptError) + &PrepareStepResult::Reject(PrepareError::HpkeDecryptError) ); - let prepare_step_different_id_corrupted = aggregate_resp.prepare_steps().get(3).unwrap(); + let prepare_step_different_id_corrupted = aggregate_resp.prepare_resps().get(3).unwrap(); assert_eq!( prepare_step_different_id_corrupted.report_id(), - report_share_different_id_corrupted.metadata().id() + prepare_init_different_id_corrupted + .report_share() + .metadata() + .id() ); assert_matches!( prepare_step_different_id_corrupted.result(), - &PrepareStepResult::Failed(ReportShareError::HpkeDecryptError) + &PrepareStepResult::Reject(PrepareError::HpkeDecryptError) ); } @@ -2230,24 +2185,23 @@ mod tests { // This report has the same ID as the previous one, but a different timestamp. let mutated_timestamp_report_metadata = ReportMetadata::new( - *test_case.report_shares[0].metadata().id(), + *test_case.prepare_inits[0].report_share().metadata().id(), test_case .clock .now() .add(test_case.task.time_precision()) .unwrap(), ); - let mutated_timestamp_report_share = test_case - .report_share_generator - .next_with_metadata(mutated_timestamp_report_metadata) - .0; + let (mutated_timestamp_prepare_init, _) = test_case + .prepare_init_generator + .next_with_metadata(mutated_timestamp_report_metadata, &0); // Send another aggregate job re-using the same report ID but with a different timestamp. It // should be flagged as a replay. 
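+        // Editorial note (not part of the original patch): the helper
+        // deduplicates by report ID alone, so re-using an ID with a mutated
+        // timestamp (and even a different aggregation parameter) must yield
+        // Reject(ReportReplayed), and the originally stored timestamps must be
+        // left untouched; both properties are asserted below.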
let request = AggregationJobInitializeReq::new( other_aggregation_parameter.get_encoded(), PartialBatchSelector::new_time_interval(), - Vec::from([mutated_timestamp_report_share.clone()]), + Vec::from([mutated_timestamp_prepare_init.clone()]), ); let mut test_conn = @@ -2255,16 +2209,19 @@ mod tests { assert_eq!(test_conn.status(), Some(Status::Ok)); let aggregate_resp: AggregationJobResp = decode_response_body(&mut test_conn).await; - assert_eq!(aggregate_resp.prepare_steps().len(), 1); + assert_eq!(aggregate_resp.prepare_resps().len(), 1); - let prepare_step = aggregate_resp.prepare_steps().get(0).unwrap(); + let prepare_step = aggregate_resp.prepare_resps().get(0).unwrap(); assert_eq!( prepare_step.report_id(), - mutated_timestamp_report_share.metadata().id() + mutated_timestamp_prepare_init + .report_share() + .metadata() + .id(), ); assert_matches!( prepare_step.result(), - &PrepareStepResult::Failed(ReportShareError::ReportReplayed) + &PrepareStepResult::Reject(PrepareError::ReportReplayed) ); // The attempt to mutate the report share timestamp should not cause any change in the @@ -2282,8 +2239,14 @@ mod tests { .await .unwrap(); assert_eq!(client_reports.len(), 2); - assert_eq!(&client_reports[0], test_case.report_shares[0].metadata()); - assert_eq!(&client_reports[1], test_case.report_shares[1].metadata()); + assert_eq!( + &client_reports[0], + test_case.prepare_inits[0].report_share().metadata() + ); + assert_eq!( + &client_reports[1], + test_case.prepare_inits[1].report_share().metadata() + ); } #[tokio::test] @@ -2296,27 +2259,20 @@ mod tests { Role::Helper, ) .build(); + let prep_init_generator = PrepareInitGenerator::new( + clock.clone(), + task.clone(), + dummy_vdaf::Vdaf::new(), + dummy_vdaf::AggregationParam(0), + ); + datastore.put_task(&task).await.unwrap(); - let hpke_key = task.current_hpke_key(); - let report_share = generate_helper_report_share::( - *task.id(), - ReportMetadata::new( - random(), - clock - .now() - .to_batch_interval_start(task.time_precision()) - .unwrap(), - ), - hpke_key.config(), - &(), - Vec::new(), - &dummy_vdaf::InputShare::default(), - ); + let (prepare_init, _) = prep_init_generator.next(&0); let request = AggregationJobInitializeReq::new( dummy_vdaf::AggregationParam(0).get_encoded(), PartialBatchSelector::new_time_interval(), - Vec::from([report_share.clone()]), + Vec::from([prepare_init.clone()]), ); // Send request, and parse response. @@ -2331,13 +2287,16 @@ mod tests { let aggregate_resp: AggregationJobResp = decode_response_body(&mut test_conn).await; // Validate response. 
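+        // Editorial note (not part of the original patch): this task uses the
+        // dummy VDAF variant configured to fail when preparation is
+        // initialized, so the single report must come back as
+        // Reject(PrepareError::VdafPrepError); the ping-pong error mapping is
+        // what should turn the underlying VdafError into that PrepareError.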
- assert_eq!(aggregate_resp.prepare_steps().len(), 1); + assert_eq!(aggregate_resp.prepare_resps().len(), 1); - let prepare_step = aggregate_resp.prepare_steps().get(0).unwrap(); - assert_eq!(prepare_step.report_id(), report_share.metadata().id()); + let prepare_step = aggregate_resp.prepare_resps().get(0).unwrap(); + assert_eq!( + prepare_step.report_id(), + prepare_init.report_share().metadata().id() + ); assert_matches!( prepare_step.result(), - &PrepareStepResult::Failed(ReportShareError::VdafPrepError) + &PrepareStepResult::Reject(PrepareError::VdafPrepError) ); } @@ -2347,31 +2306,24 @@ mod tests { let task = TaskBuilder::new( QueryType::TimeInterval, - VdafInstance::FakeFailsPrepInit, + VdafInstance::FakeFailsPrepStep, Role::Helper, ) .build(); - let hpke_key = task.current_hpke_key(); + let prep_init_generator = PrepareInitGenerator::new( + clock.clone(), + task.clone(), + dummy_vdaf::Vdaf::new(), + dummy_vdaf::AggregationParam(0), + ); + datastore.put_task(&task).await.unwrap(); - let report_share = generate_helper_report_share::( - *task.id(), - ReportMetadata::new( - random(), - clock - .now() - .to_batch_interval_start(task.time_precision()) - .unwrap(), - ), - hpke_key.config(), - &(), - Vec::new(), - &dummy_vdaf::InputShare::default(), - ); + let (prepare_init, _) = prep_init_generator.next(&0); let request = AggregationJobInitializeReq::new( dummy_vdaf::AggregationParam(0).get_encoded(), PartialBatchSelector::new_time_interval(), - Vec::from([report_share.clone()]), + Vec::from([prepare_init.clone()]), ); let aggregation_job_id: AggregationJobId = random(); @@ -2385,46 +2337,40 @@ mod tests { let aggregate_resp: AggregationJobResp = decode_response_body(&mut test_conn).await; // Validate response. - assert_eq!(aggregate_resp.prepare_steps().len(), 1); + assert_eq!(aggregate_resp.prepare_resps().len(), 1); - let prepare_step = aggregate_resp.prepare_steps().get(0).unwrap(); - assert_eq!(prepare_step.report_id(), report_share.metadata().id()); + let prepare_step = aggregate_resp.prepare_resps().get(0).unwrap(); + assert_eq!( + prepare_step.report_id(), + prepare_init.report_share().metadata().id() + ); assert_matches!( prepare_step.result(), - &PrepareStepResult::Failed(ReportShareError::VdafPrepError) + &PrepareStepResult::Reject(PrepareError::VdafPrepError) ); } #[tokio::test] async fn aggregate_init_duplicated_report_id() { - let (_, _ephemeral_datastore, datastore, handler) = setup_http_handler_test().await; + let (clock, _ephemeral_datastore, datastore, handler) = setup_http_handler_test().await; + + let task = + TaskBuilder::new(QueryType::TimeInterval, VdafInstance::Fake, Role::Helper).build(); + let prep_init_generator = PrepareInitGenerator::new( + clock.clone(), + task.clone(), + dummy_vdaf::Vdaf::new(), + dummy_vdaf::AggregationParam(0), + ); - let task = TaskBuilder::new( - QueryType::TimeInterval, - VdafInstance::FakeFailsPrepInit, - Role::Helper, - ) - .build(); datastore.put_task(&task).await.unwrap(); - let report_share = ReportShare::new( - ReportMetadata::new( - ReportId::from([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]), - Time::from_seconds_since_epoch(54321), - ), - Vec::from("PUBLIC"), - HpkeCiphertext::new( - // bogus, but we never get far enough to notice - HpkeConfigId::from(42), - Vec::from("012345"), - Vec::from("543210"), - ), - ); + let (prepare_init, _) = prep_init_generator.next(&0); let request = AggregationJobInitializeReq::new( dummy_vdaf::AggregationParam(0).get_encoded(), PartialBatchSelector::new_time_interval(), - 
Vec::from([report_share.clone(), report_share]), + Vec::from([prepare_init.clone(), prepare_init]), ); let aggregation_job_id: AggregationJobId = random(); @@ -2451,14 +2397,17 @@ mod tests { let aggregation_job_id = random(); let task = TaskBuilder::new( QueryType::TimeInterval, - VdafInstance::Prio3Count, + VdafInstance::Poplar1 { bits: 1 }, Role::Helper, ) .build(); - let vdaf = Arc::new(Prio3::new_count(2).unwrap()); + let vdaf = Arc::new(Poplar1::::new(1)); let verify_key: VerifyKey = task.primary_vdaf_verify_key().unwrap(); let hpke_key = task.current_hpke_key(); + let measurement = IdpfInput::from_bools(&[true]); + let aggregation_param = + Poplar1AggregationParam::try_from_prefixes(vec![measurement.clone()]).unwrap(); // report_share_0 is a "happy path" report. let report_metadata_0 = ReportMetadata::new( @@ -2471,19 +2420,19 @@ mod tests { let transcript_0 = run_vdaf( vdaf.as_ref(), verify_key.as_bytes(), - &(), + &aggregation_param, report_metadata_0.id(), - &0, + &measurement, ); - let (prep_state_0, _) = transcript_0.helper_prep_state(0); - let prep_msg_0 = transcript_0.prepare_messages[0].clone(); - let report_share_0 = generate_helper_report_share::( + let ping_pong_transition_0 = &transcript_0.helper_prepare_transitions[0].transition; + let leader_prep_message_0 = &transcript_0.leader_prepare_transitions[1].message; + let report_share_0 = generate_helper_report_share::>( *task.id(), report_metadata_0.clone(), hpke_key.config(), &transcript_0.public_share, Vec::new(), - &transcript_0.input_shares[1], + &transcript_0.helper_input_share, ); // report_share_1 is omitted by the leader's request. @@ -2497,19 +2446,19 @@ mod tests { let transcript_1 = run_vdaf( vdaf.as_ref(), verify_key.as_bytes(), - &(), + &aggregation_param, report_metadata_1.id(), - &0, + &measurement, ); - let (prep_state_1, _) = transcript_1.helper_prep_state(0); - let report_share_1 = generate_helper_report_share::( + let ping_pong_transition_1 = &transcript_1.helper_prepare_transitions[0].transition; + let report_share_1 = generate_helper_report_share::>( *task.id(), report_metadata_1.clone(), hpke_key.config(), &transcript_1.public_share, Vec::new(), - &transcript_1.input_shares[1], + &transcript_1.helper_input_share, ); // report_share_2 falls into a batch that has already been collected. 
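+        // Editorial note (not part of the original patch): "already been
+        // collected" is modeled below by inserting an AggregateShareJob that
+        // covers this report's batch interval; continuing aggregation for a
+        // collected batch must fail with Reject(BatchCollected), as asserted
+        // further down.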
@@ -2526,19 +2475,19 @@ mod tests { let transcript_2 = run_vdaf( vdaf.as_ref(), verify_key.as_bytes(), - &(), + &aggregation_param, report_metadata_2.id(), - &0, + &measurement, ); - let (prep_state_2, _) = transcript_2.helper_prep_state(0); - let prep_msg_2 = transcript_2.prepare_messages[0].clone(); - let report_share_2 = generate_helper_report_share::( + let ping_pong_transition_2 = &transcript_2.helper_prepare_transitions[0].transition; + let leader_prep_message_2 = &transcript_2.leader_prepare_transitions[1].message; + let report_share_2 = generate_helper_report_share::>( *task.id(), report_metadata_2.clone(), hpke_key.config(), &transcript_2.public_share, Vec::new(), - &transcript_2.input_shares[1], + &transcript_2.helper_input_share, ); datastore @@ -2549,16 +2498,18 @@ mod tests { report_share_1.clone(), report_share_2.clone(), ); - let (prep_state_0, prep_state_1, prep_state_2) = ( - prep_state_0.clone(), - prep_state_1.clone(), - prep_state_2.clone(), + let (ping_pong_transition_0, ping_pong_transition_1, ping_pong_transition_2) = ( + ping_pong_transition_0.clone(), + ping_pong_transition_1.clone(), + ping_pong_transition_2.clone(), ); let (report_metadata_0, report_metadata_1, report_metadata_2) = ( report_metadata_0.clone(), report_metadata_1.clone(), report_metadata_2.clone(), ); + let aggregation_param = aggregation_param.clone(); + let helper_aggregate_share = transcript_0.helper_aggregate_share.clone(); Box::pin(async move { tx.put_task(&task).await?; @@ -2570,11 +2521,11 @@ mod tests { tx.put_aggregation_job(&AggregationJob::< VERIFY_KEY_LENGTH, TimeInterval, - Prio3Count, + Poplar1, >::new( *task.id(), aggregation_job_id, - (), + aggregation_param.clone(), (), Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) .unwrap(), @@ -2583,7 +2534,7 @@ mod tests { )) .await?; - tx.put_report_aggregation::( + tx.put_report_aggregation::>( &ReportAggregation::new( *task.id(), aggregation_job_id, @@ -2591,11 +2542,11 @@ mod tests { *report_metadata_0.time(), 0, None, - ReportAggregationState::Waiting(prep_state_0, None), + ReportAggregationState::Waiting(ping_pong_transition_0), ), ) .await?; - tx.put_report_aggregation::( + tx.put_report_aggregation::>( &ReportAggregation::new( *task.id(), aggregation_job_id, @@ -2603,11 +2554,11 @@ mod tests { *report_metadata_1.time(), 1, None, - ReportAggregationState::Waiting(prep_state_1, None), + ReportAggregationState::Waiting(ping_pong_transition_1), ), ) .await?; - tx.put_report_aggregation::( + tx.put_report_aggregation::>( &ReportAggregation::new( *task.id(), aggregation_job_id, @@ -2615,12 +2566,12 @@ mod tests { *report_metadata_2.time(), 2, None, - ReportAggregationState::Waiting(prep_state_2, None), + ReportAggregationState::Waiting(ping_pong_transition_2), ), ) .await?; - tx.put_aggregate_share_job::( + tx.put_aggregate_share_job::>( &AggregateShareJob::new( *task.id(), Interval::new( @@ -2628,8 +2579,8 @@ mod tests { *task.time_precision(), ) .unwrap(), - (), - AggregateShare::from(OutputShare::from(Vec::from([Field64::from(7)]))), + aggregation_param.clone(), + helper_aggregate_share, 0, ReportIdChecksum::default(), ), @@ -2643,14 +2594,8 @@ mod tests { let request = AggregationJobContinueReq::new( AggregationJobRound::from(1), Vec::from([ - PrepareStep::new( - *report_metadata_0.id(), - PrepareStepResult::Continued(prep_msg_0.get_encoded()), - ), - PrepareStep::new( - *report_metadata_2.id(), - PrepareStepResult::Continued(prep_msg_2.get_encoded()), - ), + PrepareContinue::new(*report_metadata_0.id(), 
leader_prep_message_0.clone()), + PrepareContinue::new(*report_metadata_2.id(), leader_prep_message_2.clone()), ]), ); @@ -2662,10 +2607,10 @@ mod tests { assert_eq!( aggregate_resp, AggregationJobResp::new(Vec::from([ - PrepareStep::new(*report_metadata_0.id(), PrepareStepResult::Finished), - PrepareStep::new( + PrepareResp::new(*report_metadata_0.id(), PrepareStepResult::Finished), + PrepareResp::new( *report_metadata_2.id(), - PrepareStepResult::Failed(ReportShareError::BatchCollected), + PrepareStepResult::Reject(PrepareError::BatchCollected), ) ])) ); @@ -2673,10 +2618,11 @@ mod tests { // Validate datastore. let (aggregation_job, report_aggregations) = datastore .run_tx(|tx| { - let (vdaf, task) = (Arc::clone(&vdaf), task.clone()); + let (vdaf, task, aggregation_param) = + (Arc::clone(&vdaf), task.clone(), aggregation_param.clone()); Box::pin(async move { let aggregation_job = tx - .get_aggregation_job::( + .get_aggregation_job::>( task.id(), &aggregation_job_id, ) @@ -2689,6 +2635,7 @@ mod tests { &Role::Helper, task.id(), &aggregation_job_id, + &aggregation_param, ) .await .unwrap(); @@ -2703,7 +2650,7 @@ mod tests { AggregationJob::new( *task.id(), aggregation_job_id, - (), + aggregation_param, (), Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) .unwrap(), @@ -2721,7 +2668,7 @@ mod tests { *report_metadata_0.id(), *report_metadata_0.time(), 0, - Some(PrepareStep::new( + Some(PrepareResp::new( *report_metadata_0.id(), PrepareStepResult::Finished )), @@ -2734,7 +2681,7 @@ mod tests { *report_metadata_1.time(), 1, None, - ReportAggregationState::Failed(ReportShareError::ReportDropped), + ReportAggregationState::Failed(PrepareError::ReportDropped), ), ReportAggregation::new( *task.id(), @@ -2742,11 +2689,11 @@ mod tests { *report_metadata_2.id(), *report_metadata_2.time(), 2, - Some(PrepareStep::new( + Some(PrepareResp::new( *report_metadata_2.id(), - PrepareStepResult::Failed(ReportShareError::BatchCollected) + PrepareStepResult::Reject(PrepareError::BatchCollected) )), - ReportAggregationState::Failed(ReportShareError::BatchCollected), + ReportAggregationState::Failed(PrepareError::BatchCollected), ) ]) ); @@ -2758,7 +2705,7 @@ mod tests { let task = TaskBuilder::new( QueryType::TimeInterval, - VdafInstance::Prio3Count, + VdafInstance::Poplar1 { bits: 1 }, Role::Helper, ) .build(); @@ -2772,9 +2719,12 @@ mod tests { .unwrap(), ); - let vdaf = Prio3::new_count(2).unwrap(); + let vdaf = Poplar1::new(1); let verify_key: VerifyKey = task.primary_vdaf_verify_key().unwrap(); let hpke_key = task.current_hpke_key(); + let measurement = IdpfInput::from_bools(&[true]); + let aggregation_param = + Poplar1AggregationParam::try_from_prefixes(vec![measurement.clone()]).unwrap(); // report_share_0 is a "happy path" report. 
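+        // Editorial note (not part of the original patch): in the ping-pong
+        // transcript, helper_prepare_transitions[0].transition is the state the
+        // helper persists as ReportAggregationState::Waiting after the init
+        // round, and leader_prepare_transitions[1].message is the leader's
+        // round-1 message carried by the AggregationJobContinueReq; evaluating
+        // the stored transition against that message should finish preparation
+        // and yield the helper's output share.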
let report_metadata_0 = ReportMetadata::new( @@ -2787,20 +2737,19 @@ mod tests { let transcript_0 = run_vdaf( &vdaf, verify_key.as_bytes(), - &(), + &aggregation_param, report_metadata_0.id(), - &0, + &measurement, ); - let (prep_state_0, _) = transcript_0.helper_prep_state(0); - let out_share_0 = transcript_0.output_share(Role::Helper); - let prep_msg_0 = transcript_0.prepare_messages[0].clone(); - let report_share_0 = generate_helper_report_share::( + let ping_pong_transition_0 = &transcript_0.helper_prepare_transitions[0].transition; + let ping_pong_leader_message_0 = &transcript_0.leader_prepare_transitions[1].message; + let report_share_0 = generate_helper_report_share::>( *task.id(), report_metadata_0.clone(), hpke_key.config(), &transcript_0.public_share, Vec::new(), - &transcript_0.input_shares[1], + &transcript_0.helper_input_share, ); // report_share_1 is another "happy path" report to exercise in-memory accumulation of @@ -2815,20 +2764,19 @@ mod tests { let transcript_1 = run_vdaf( &vdaf, verify_key.as_bytes(), - &(), + &aggregation_param, report_metadata_1.id(), - &0, + &measurement, ); - let (prep_state_1, _) = transcript_1.helper_prep_state(0); - let out_share_1 = transcript_1.output_share(Role::Helper); - let prep_msg_1 = transcript_1.prepare_messages[0].clone(); - let report_share_1 = generate_helper_report_share::( + let ping_pong_transition_1 = &transcript_1.helper_prepare_transitions[0].transition; + let ping_pong_leader_message_1 = &transcript_1.leader_prepare_transitions[1].message; + let report_share_1 = generate_helper_report_share::>( *task.id(), report_metadata_1.clone(), hpke_key.config(), &transcript_1.public_share, Vec::new(), - &transcript_1.input_shares[1], + &transcript_1.helper_input_share, ); // report_share_2 aggregates successfully, but into a distinct batch aggregation which has @@ -2843,19 +2791,19 @@ mod tests { let transcript_2 = run_vdaf( &vdaf, verify_key.as_bytes(), - &(), + &aggregation_param, report_metadata_2.id(), - &0, + &measurement, ); - let (prep_state_2, _) = transcript_2.helper_prep_state(0); - let prep_msg_2 = transcript_2.prepare_messages[0].clone(); - let report_share_2 = generate_helper_report_share::( + let ping_pong_transition_2 = &transcript_2.helper_prepare_transitions[0].transition; + let ping_pong_leader_message_2 = &transcript_2.leader_prepare_transitions[1].message; + let report_share_2 = generate_helper_report_share::>( *task.id(), report_metadata_2.clone(), hpke_key.config(), &transcript_2.public_share, Vec::new(), - &transcript_2.input_shares[1], + &transcript_2.helper_input_share, ); let first_batch_identifier = Interval::new( @@ -2875,11 +2823,11 @@ mod tests { ) .unwrap(); let second_batch_want_batch_aggregations = - empty_batch_aggregations::( + empty_batch_aggregations::>( &task, BATCH_AGGREGATION_SHARD_COUNT, &second_batch_identifier, - &(), + &aggregation_param, &[], ); @@ -2891,16 +2839,17 @@ mod tests { report_share_1.clone(), report_share_2.clone(), ); - let (prep_state_0, prep_state_1, prep_state_2) = ( - prep_state_0.clone(), - prep_state_1.clone(), - prep_state_2.clone(), + let (ping_pong_transition_0, ping_pong_transition_1, ping_pong_transition_2) = ( + ping_pong_transition_0.clone(), + ping_pong_transition_1.clone(), + ping_pong_transition_2.clone(), ); let (report_metadata_0, report_metadata_1, report_metadata_2) = ( report_metadata_0.clone(), report_metadata_1.clone(), report_metadata_2.clone(), ); + let aggregation_param = aggregation_param.clone(); let second_batch_want_batch_aggregations = 
second_batch_want_batch_aggregations.clone(); @@ -2914,11 +2863,11 @@ mod tests { tx.put_aggregation_job(&AggregationJob::< VERIFY_KEY_LENGTH, TimeInterval, - Prio3Count, + Poplar1, >::new( *task.id(), aggregation_job_id_0, - (), + aggregation_param.clone(), (), Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) .unwrap(), @@ -2927,48 +2876,55 @@ mod tests { )) .await?; - tx.put_report_aggregation( - &ReportAggregation::::new( - *task.id(), - aggregation_job_id_0, - *report_metadata_0.id(), - *report_metadata_0.time(), - 0, - None, - ReportAggregationState::Waiting(prep_state_0, None), - ), - ) + tx.put_report_aggregation(&ReportAggregation::< + VERIFY_KEY_LENGTH, + Poplar1, + >::new( + *task.id(), + aggregation_job_id_0, + *report_metadata_0.id(), + *report_metadata_0.time(), + 0, + None, + ReportAggregationState::Waiting(ping_pong_transition_0), + )) .await?; - tx.put_report_aggregation( - &ReportAggregation::::new( - *task.id(), - aggregation_job_id_0, - *report_metadata_1.id(), - *report_metadata_1.time(), - 1, - None, - ReportAggregationState::Waiting(prep_state_1, None), - ), - ) + tx.put_report_aggregation(&ReportAggregation::< + VERIFY_KEY_LENGTH, + Poplar1, + >::new( + *task.id(), + aggregation_job_id_0, + *report_metadata_1.id(), + *report_metadata_1.time(), + 1, + None, + ReportAggregationState::Waiting(ping_pong_transition_1), + )) .await?; - tx.put_report_aggregation( - &ReportAggregation::::new( - *task.id(), - aggregation_job_id_0, - *report_metadata_2.id(), - *report_metadata_2.time(), - 2, - None, - ReportAggregationState::Waiting(prep_state_2, None), - ), - ) + tx.put_report_aggregation(&ReportAggregation::< + VERIFY_KEY_LENGTH, + Poplar1, + >::new( + *task.id(), + aggregation_job_id_0, + *report_metadata_2.id(), + *report_metadata_2.time(), + 2, + None, + ReportAggregationState::Waiting(ping_pong_transition_2), + )) .await?; for batch_identifier in [first_batch_identifier, second_batch_identifier] { - tx.put_batch(&Batch::::new( + tx.put_batch(&Batch::< + VERIFY_KEY_LENGTH, + TimeInterval, + Poplar1, + >::new( *task.id(), batch_identifier, - (), + aggregation_param.clone(), BatchState::Closed, 0, batch_identifier, @@ -2994,18 +2950,9 @@ mod tests { let request = AggregationJobContinueReq::new( AggregationJobRound::from(1), Vec::from([ - PrepareStep::new( - *report_metadata_0.id(), - PrepareStepResult::Continued(prep_msg_0.get_encoded()), - ), - PrepareStep::new( - *report_metadata_1.id(), - PrepareStepResult::Continued(prep_msg_1.get_encoded()), - ), - PrepareStep::new( - *report_metadata_2.id(), - PrepareStepResult::Continued(prep_msg_2.get_encoded()), - ), + PrepareContinue::new(*report_metadata_0.id(), ping_pong_leader_message_0.clone()), + PrepareContinue::new(*report_metadata_1.id(), ping_pong_leader_message_1.clone()), + PrepareContinue::new(*report_metadata_2.id(), ping_pong_leader_message_2.clone()), ]), ); @@ -3016,12 +2963,16 @@ mod tests { // Map the batch aggregation ordinal value to 0, as it may vary due to sharding. 
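(Aside: the normalization in the next hunk exists because batch aggregations are sharded. Each accumulation writes to one of BATCH_AGGREGATION_SHARD_COUNT rows chosen at random to reduce write contention, so the shard ordinal a given report lands in is nondeterministic. A hedged sketch of the idea; the pick_shard helper below is illustrative only, not Janus code:)

    use rand::{thread_rng, Rng};

    // Choose which batch-aggregation shard a report's contribution is
    // accumulated into. Tests cannot predict this value, hence the
    // "map ord to 0" rewrite before comparing against expected rows.
    fn pick_shard(shard_count: u64) -> u64 {
        thread_rng().gen_range(0..shard_count)
    }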
let first_batch_got_batch_aggregations: Vec<_> = datastore .run_tx(|tx| { - let (task, vdaf, report_metadata_0) = - (task.clone(), vdaf.clone(), report_metadata_0.clone()); + let (task, vdaf, report_metadata_0, aggregation_param) = ( + task.clone(), + vdaf.clone(), + report_metadata_0.clone(), + aggregation_param.clone(), + ); Box::pin(async move { TimeInterval::get_batch_aggregations_for_collection_identifier::< VERIFY_KEY_LENGTH, - Prio3Count, + Poplar1, _, >( tx, @@ -3035,7 +2986,7 @@ mod tests { *task.time_precision(), ) .unwrap(), - &(), + &aggregation_param, ) .await }) @@ -3044,10 +2995,10 @@ mod tests { .unwrap() .into_iter() .map(|agg| { - BatchAggregation::::new( + BatchAggregation::>::new( *agg.task_id(), *agg.batch_identifier(), - (), + agg.aggregation_parameter().clone(), 0, BatchAggregationState::Aggregating, agg.aggregate_share().cloned(), @@ -3059,7 +3010,13 @@ mod tests { .collect(); let aggregate_share = vdaf - .aggregate(&(), [out_share_0.clone(), out_share_1.clone()]) + .aggregate( + &aggregation_param, + [ + transcript_0.helper_output_share.clone(), + transcript_1.helper_output_share.clone(), + ], + ) .unwrap(); let checksum = ReportIdChecksum::for_report_id(report_metadata_0.id()) .updated_with(report_metadata_1.id()); @@ -3076,7 +3033,7 @@ mod tests { *task.time_precision() ) .unwrap(), - (), + aggregation_param.clone(), 0, BatchAggregationState::Aggregating, Some(aggregate_share), @@ -3088,12 +3045,16 @@ mod tests { let second_batch_got_batch_aggregations = datastore .run_tx(|tx| { - let (task, vdaf, report_metadata_2) = - (task.clone(), vdaf.clone(), report_metadata_2.clone()); + let (task, vdaf, report_metadata_2, aggregation_param) = ( + task.clone(), + vdaf.clone(), + report_metadata_2.clone(), + aggregation_param.clone(), + ); Box::pin(async move { TimeInterval::get_batch_aggregations_for_collection_identifier::< VERIFY_KEY_LENGTH, - Prio3Count, + Poplar1, _, >( tx, @@ -3107,7 +3068,7 @@ mod tests { Duration::from_seconds(task.time_precision().as_seconds()), ) .unwrap(), - &(), + &aggregation_param, ) .await }) @@ -3132,20 +3093,19 @@ mod tests { let transcript_3 = run_vdaf( &vdaf, verify_key.as_bytes(), - &(), + &aggregation_param, report_metadata_3.id(), - &0, + &measurement, ); - let (prep_state_3, _) = transcript_3.helper_prep_state(0); - let out_share_3 = transcript_3.output_share(Role::Helper); - let prep_msg_3 = transcript_3.prepare_messages[0].clone(); - let report_share_3 = generate_helper_report_share::( + let ping_pong_transition_3 = &transcript_3.helper_prepare_transitions[0].transition; + let ping_pong_leader_message_3 = &transcript_3.leader_prepare_transitions[1].message; + let report_share_3 = generate_helper_report_share::>( *task.id(), report_metadata_3.clone(), hpke_key.config(), &transcript_3.public_share, Vec::new(), - &transcript_3.input_shares[1], + &transcript_3.helper_input_share, ); // report_share_4 gets aggregated into the second batch interval (which has already been @@ -3160,19 +3120,19 @@ mod tests { let transcript_4 = run_vdaf( &vdaf, verify_key.as_bytes(), - &(), + &aggregation_param, report_metadata_4.id(), - &0, + &measurement, ); - let (prep_state_4, _) = transcript_4.helper_prep_state(0); - let prep_msg_4 = transcript_4.prepare_messages[0].clone(); - let report_share_4 = generate_helper_report_share::( + let ping_pong_transition_4 = &transcript_4.helper_prepare_transitions[0].transition; + let ping_pong_leader_message_4 = &transcript_4.leader_prepare_transitions[1].message; + let report_share_4 = 
generate_helper_report_share::>( *task.id(), report_metadata_4.clone(), hpke_key.config(), &transcript_4.public_share, Vec::new(), - &transcript_4.input_shares[1], + &transcript_4.helper_input_share, ); // report_share_5 also gets aggregated into the second batch interval (which has already @@ -3187,19 +3147,19 @@ mod tests { let transcript_5 = run_vdaf( &vdaf, verify_key.as_bytes(), - &(), + &aggregation_param, report_metadata_5.id(), - &0, + &measurement, ); - let (prep_state_5, _) = transcript_5.helper_prep_state(0); - let prep_msg_5 = transcript_5.prepare_messages[0].clone(); - let report_share_5 = generate_helper_report_share::( + let ping_pong_transition_5 = &transcript_5.helper_prepare_transitions[0].transition; + let ping_pong_leader_message_5 = &transcript_5.leader_prepare_transitions[1].message; + let report_share_5 = generate_helper_report_share::>( *task.id(), report_metadata_5.clone(), hpke_key.config(), &transcript_5.public_share, Vec::new(), - &transcript_5.input_shares[1], + &transcript_5.helper_input_share, ); datastore @@ -3210,16 +3170,17 @@ mod tests { report_share_4.clone(), report_share_5.clone(), ); - let (prep_state_3, prep_state_4, prep_state_5) = ( - prep_state_3.clone(), - prep_state_4.clone(), - prep_state_5.clone(), + let (ping_pong_transition_3, ping_pong_transition_4, ping_pong_transition_5) = ( + ping_pong_transition_3.clone(), + ping_pong_transition_4.clone(), + ping_pong_transition_5.clone(), ); let (report_metadata_3, report_metadata_4, report_metadata_5) = ( report_metadata_3.clone(), report_metadata_4.clone(), report_metadata_5.clone(), ); + let aggregation_param = aggregation_param.clone(); Box::pin(async move { tx.put_report_share(task.id(), &report_share_3).await?; @@ -3229,11 +3190,11 @@ mod tests { tx.put_aggregation_job(&AggregationJob::< VERIFY_KEY_LENGTH, TimeInterval, - Prio3Count, + Poplar1, >::new( *task.id(), aggregation_job_id_1, - (), + aggregation_param, (), Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) .unwrap(), @@ -3242,41 +3203,44 @@ mod tests { )) .await?; - tx.put_report_aggregation( - &ReportAggregation::::new( - *task.id(), - aggregation_job_id_1, - *report_metadata_3.id(), - *report_metadata_3.time(), - 3, - None, - ReportAggregationState::Waiting(prep_state_3, None), - ), - ) + tx.put_report_aggregation(&ReportAggregation::< + VERIFY_KEY_LENGTH, + Poplar1, + >::new( + *task.id(), + aggregation_job_id_1, + *report_metadata_3.id(), + *report_metadata_3.time(), + 3, + None, + ReportAggregationState::Waiting(ping_pong_transition_3), + )) .await?; - tx.put_report_aggregation( - &ReportAggregation::::new( - *task.id(), - aggregation_job_id_1, - *report_metadata_4.id(), - *report_metadata_4.time(), - 4, - None, - ReportAggregationState::Waiting(prep_state_4, None), - ), - ) + tx.put_report_aggregation(&ReportAggregation::< + VERIFY_KEY_LENGTH, + Poplar1, + >::new( + *task.id(), + aggregation_job_id_1, + *report_metadata_4.id(), + *report_metadata_4.time(), + 4, + None, + ReportAggregationState::Waiting(ping_pong_transition_4), + )) .await?; - tx.put_report_aggregation( - &ReportAggregation::::new( - *task.id(), - aggregation_job_id_1, - *report_metadata_5.id(), - *report_metadata_5.time(), - 5, - None, - ReportAggregationState::Waiting(prep_state_5, None), - ), - ) + tx.put_report_aggregation(&ReportAggregation::< + VERIFY_KEY_LENGTH, + Poplar1, + >::new( + *task.id(), + aggregation_job_id_1, + *report_metadata_5.id(), + *report_metadata_5.time(), + 5, + None, + 
ReportAggregationState::Waiting(ping_pong_transition_5), + )) .await?; Ok(()) @@ -3288,18 +3252,9 @@ mod tests { let request = AggregationJobContinueReq::new( AggregationJobRound::from(1), Vec::from([ - PrepareStep::new( - *report_metadata_3.id(), - PrepareStepResult::Continued(prep_msg_3.get_encoded()), - ), - PrepareStep::new( - *report_metadata_4.id(), - PrepareStepResult::Continued(prep_msg_4.get_encoded()), - ), - PrepareStep::new( - *report_metadata_5.id(), - PrepareStepResult::Continued(prep_msg_5.get_encoded()), - ), + PrepareContinue::new(*report_metadata_3.id(), ping_pong_leader_message_3.clone()), + PrepareContinue::new(*report_metadata_4.id(), ping_pong_leader_message_4.clone()), + PrepareContinue::new(*report_metadata_5.id(), ping_pong_leader_message_5.clone()), ]), ); @@ -3311,12 +3266,16 @@ mod tests { // be the same) let merged_first_batch_aggregation = datastore .run_tx(|tx| { - let (task, vdaf, report_metadata_0) = - (task.clone(), vdaf.clone(), report_metadata_0.clone()); + let (task, vdaf, report_metadata_0, aggregation_param) = ( + task.clone(), + vdaf.clone(), + report_metadata_0.clone(), + aggregation_param.clone(), + ); Box::pin(async move { TimeInterval::get_batch_aggregations_for_collection_identifier::< VERIFY_KEY_LENGTH, - Prio3Count, + Poplar1, _, >( tx, @@ -3330,7 +3289,7 @@ mod tests { Duration::from_seconds(task.time_precision().as_seconds()), ) .unwrap(), - &(), + &aggregation_param, ) .await }) @@ -3339,10 +3298,10 @@ mod tests { .unwrap() .into_iter() .map(|agg| { - BatchAggregation::::new( + BatchAggregation::>::new( *agg.task_id(), *agg.batch_identifier(), - (), + agg.aggregation_parameter().clone(), 0, BatchAggregationState::Aggregating, agg.aggregate_share().cloned(), @@ -3356,8 +3315,14 @@ mod tests { let first_aggregate_share = vdaf .aggregate( - &(), - [out_share_0, out_share_1, out_share_3].into_iter().cloned(), + &aggregation_param, + [ + &transcript_0.helper_output_share, + &transcript_1.helper_output_share, + &transcript_3.helper_output_share, + ] + .into_iter() + .cloned(), ) .unwrap(); let first_checksum = ReportIdChecksum::for_report_id(report_metadata_0.id()) @@ -3376,7 +3341,7 @@ mod tests { *task.time_precision() ) .unwrap(), - (), + aggregation_param.clone(), 0, BatchAggregationState::Aggregating, Some(first_aggregate_share), @@ -3388,12 +3353,16 @@ mod tests { let second_batch_got_batch_aggregations = datastore .run_tx(|tx| { - let (task, vdaf, report_metadata_2) = - (task.clone(), vdaf.clone(), report_metadata_2.clone()); + let (task, vdaf, report_metadata_2, aggregation_param) = ( + task.clone(), + vdaf.clone(), + report_metadata_2.clone(), + aggregation_param.clone(), + ); Box::pin(async move { TimeInterval::get_batch_aggregations_for_collection_identifier::< VERIFY_KEY_LENGTH, - Prio3Count, + Poplar1, _, >( tx, @@ -3407,7 +3376,7 @@ mod tests { Duration::from_seconds(task.time_precision().as_seconds()), ) .unwrap(), - &(), + &aggregation_param, ) .await }) @@ -3421,7 +3390,7 @@ mod tests { } #[tokio::test] - async fn aggregate_continue_leader_sends_non_continue_transition() { + async fn aggregate_continue_leader_sends_non_continue_or_finish_transition() { let (_, _ephemeral_datastore, datastore, handler) = setup_http_handler_test().await; // Prepare parameters. 
@@ -3476,7 +3445,7 @@ mod tests { *report_metadata.time(), 0, None, - ReportAggregationState::Waiting(dummy_vdaf::PrepareState::default(), None), + ReportAggregationState::Waiting(ping_pong::Transition::default()), )) .await }) @@ -3487,21 +3456,25 @@ mod tests { // Make request. let request = AggregationJobContinueReq::new( AggregationJobRound::from(1), - Vec::from([PrepareStep::new( + Vec::from([PrepareContinue::new( *report_metadata.id(), - PrepareStepResult::Finished, + // An AggregationJobContinueReq should only ever contain Continue or Finished + ping_pong::Message::Initialize { + prep_share: Vec::new(), + }, )]), ); - post_aggregation_job_expecting_error( - &task, - &aggregation_job_id, - &request, - &handler, - Status::BadRequest, - "urn:ietf:params:ppm:dap:error:unrecognizedMessage", - "The message type for a response was incorrect or the payload was malformed.", - ) - .await; + + let resp = + post_aggregation_job_and_decode(&task, &aggregation_job_id, &request, &handler).await; + assert_eq!(resp.prepare_resps().len(), 1); + assert_eq!( + resp.prepare_resps()[0], + PrepareResp::new( + *report_metadata.id(), + PrepareStepResult::Reject(PrepareError::VdafPrepError), + ) + ); } #[tokio::test] @@ -3564,7 +3537,7 @@ mod tests { *report_metadata.time(), 0, None, - ReportAggregationState::Waiting(dummy_vdaf::PrepareState::default(), None), + ReportAggregationState::Waiting(ping_pong::Transition::default()), )) .await }) @@ -3575,9 +3548,12 @@ mod tests { // Make request. let request = AggregationJobContinueReq::new( AggregationJobRound::from(1), - Vec::from([PrepareStep::new( + Vec::from([PrepareContinue::new( *report_metadata.id(), - PrepareStepResult::Continued(Vec::new()), + ping_pong::Message::Continue { + prep_msg: Vec::new(), + prep_share: Vec::new(), + }, )]), ); @@ -3585,9 +3561,9 @@ mod tests { post_aggregation_job_and_decode(&task, &aggregation_job_id, &request, &handler).await; assert_eq!( aggregate_resp, - AggregationJobResp::new(Vec::from([PrepareStep::new( + AggregationJobResp::new(Vec::from([PrepareResp::new( *report_metadata.id(), - PrepareStepResult::Failed(ReportShareError::VdafPrepError), + PrepareStepResult::Reject(PrepareError::VdafPrepError), )]),) ); @@ -3610,6 +3586,7 @@ mod tests { &Role::Helper, task.id(), &aggregation_job_id, + aggregation_job.aggregation_parameter(), report_metadata.id(), ) .await @@ -3643,11 +3620,11 @@ mod tests { *report_metadata.id(), *report_metadata.time(), 0, - Some(PrepareStep::new( + Some(PrepareResp::new( *report_metadata.id(), - PrepareStepResult::Failed(ReportShareError::VdafPrepError) + PrepareStepResult::Reject(PrepareError::VdafPrepError) )), - ReportAggregationState::Failed(ReportShareError::VdafPrepError), + ReportAggregationState::Failed(PrepareError::VdafPrepError), ) ); } @@ -3708,7 +3685,7 @@ mod tests { *report_metadata.time(), 0, None, - ReportAggregationState::Waiting(dummy_vdaf::PrepareState::default(), None), + ReportAggregationState::Waiting(ping_pong::Transition::default()), )) .await }) @@ -3719,11 +3696,14 @@ mod tests { // Make request. 
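(Aside: the per-report rejection asserted above falls out of the ping-pong message shapes. An initialize message is only legal in the first request of an aggregation job; continue requests must carry continue or finish messages, and anything else now surfaces as a PrepareError::VdafPrepError for that report rather than failing the whole request. A sketch of the two variants these tests exercise, with empty payloads standing in for encoded prepare shares and messages:)

    use prio::topology::ping_pong;

    // Legal only in an AggregationJobInitializeReq: the leader's first
    // prepare share for the report.
    let init = ping_pong::Message::Initialize {
        prep_share: Vec::new(),
    };

    // Legal in an AggregationJobContinueReq: the combined prepare message
    // from the previous round plus the leader's next prepare share.
    let cont = ping_pong::Message::Continue {
        prep_msg: Vec::new(),
        prep_share: Vec::new(),
    };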
let request = AggregationJobContinueReq::new( AggregationJobRound::from(1), - Vec::from([PrepareStep::new( + Vec::from([PrepareContinue::new( ReportId::from( [16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1], // not the same as above ), - PrepareStepResult::Continued(Vec::new()), + ping_pong::Message::Continue { + prep_msg: Vec::new(), + prep_share: Vec::new(), + }, )]), ); @@ -3819,7 +3799,7 @@ mod tests { *report_metadata_0.time(), 0, None, - ReportAggregationState::Waiting(dummy_vdaf::PrepareState::default(), None), + ReportAggregationState::Waiting(ping_pong::Transition::default()), )) .await?; tx.put_report_aggregation(&ReportAggregation::<0, dummy_vdaf::Vdaf>::new( @@ -3829,7 +3809,7 @@ mod tests { *report_metadata_1.time(), 1, None, - ReportAggregationState::Waiting(dummy_vdaf::PrepareState::default(), None), + ReportAggregationState::Waiting(ping_pong::Transition::default()), )) .await }) @@ -3842,13 +3822,19 @@ mod tests { AggregationJobRound::from(1), Vec::from([ // Report IDs are in opposite order to what was stored in the datastore. - PrepareStep::new( + PrepareContinue::new( *report_metadata_1.id(), - PrepareStepResult::Continued(Vec::new()), + ping_pong::Message::Continue { + prep_msg: Vec::new(), + prep_share: Vec::new(), + }, ), - PrepareStep::new( + PrepareContinue::new( *report_metadata_0.id(), - PrepareStepResult::Continued(Vec::new()), + ping_pong::Message::Continue { + prep_msg: Vec::new(), + prep_share: Vec::new(), + }, ), ]), ); @@ -3919,7 +3905,7 @@ mod tests { *report_metadata.time(), 0, None, - ReportAggregationState::Failed(ReportShareError::VdafPrepError), + ReportAggregationState::Failed(PrepareError::VdafPrepError), )) .await }) @@ -3930,9 +3916,12 @@ mod tests { // Make request. let request = AggregationJobContinueReq::new( AggregationJobRound::from(1), - Vec::from([PrepareStep::new( + Vec::from([PrepareContinue::new( ReportId::from([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]), - PrepareStepResult::Continued(Vec::new()), + ping_pong::Message::Continue { + prep_msg: Vec::new(), + prep_share: Vec::new(), + }, )]), ); post_aggregation_job_expecting_error( diff --git a/aggregator/src/aggregator/taskprov_tests.rs b/aggregator/src/aggregator/taskprov_tests.rs index 3db8c8b59..2a9298a88 100644 --- a/aggregator/src/aggregator/taskprov_tests.rs +++ b/aggregator/src/aggregator/taskprov_tests.rs @@ -44,16 +44,15 @@ use janus_messages::{ }, AggregateShare as AggregateShareMessage, AggregateShareAad, AggregateShareReq, AggregationJobContinueReq, AggregationJobId, AggregationJobInitializeReq, AggregationJobResp, - AggregationJobRound, BatchSelector, Duration, Interval, PartialBatchSelector, PrepareStep, - PrepareStepResult, ReportIdChecksum, ReportMetadata, ReportShare, Role, TaskId, Time, + AggregationJobRound, BatchSelector, Duration, Interval, PartialBatchSelector, PrepareContinue, + PrepareInit, PrepareResp, PrepareStepResult, ReportIdChecksum, ReportMetadata, ReportShare, + Role, TaskId, Time, }; use prio::{ - field::Field64, - flp::types::Count, + idpf::IdpfInput, vdaf::{ + poplar1::{Poplar1, Poplar1AggregationParam}, prg::PrgSha3, - prio3::{Prio3, Prio3Count}, - AggregateShare, OutputShare, }, }; use rand::random; @@ -66,7 +65,7 @@ use trillium_testing::{ prelude::{post, put}, }; -type TestVdaf = Prio3<Count<Field64>, PrgSha3, 16>; +type TestVdaf = Poplar1<PrgSha3, 16>; pub struct TaskprovTestCase { _ephemeral_datastore: EphemeralDatastore, @@ -81,6 +80,7 @@ pub struct TaskprovTestCase { task: Task, task_config: TaskConfig, task_id: TaskId, + aggregation_param:
Poplar1AggregationParam, } async fn setup_taskprov_test() -> TaskprovTestCase { @@ -143,19 +143,29 @@ async fn setup_taskprov_test() -> TaskprovTestCase { TaskprovQuery::FixedSize { max_batch_size }, ), task_expiration, - VdafConfig::new(DpConfig::new(DpMechanism::None), VdafType::Prio3Count).unwrap(), + VdafConfig::new( + DpConfig::new(DpMechanism::None), + VdafType::Poplar1 { bits: 1 }, + ) + .unwrap(), ) .unwrap(); let mut task_config_encoded = vec![]; task_config.encode(&mut task_config_encoded); - // We use a real VDAF since taskprov doesn't have any allowance for a test VDAF. - let vdaf = Prio3Count::new_count(2).unwrap(); + // We use a real VDAF since taskprov doesn't have any allowance for a test VDAF, and we use + // Poplar1 so that the VDAF will take more than one round, letting us exercise aggregation + // continuation. + let vdaf = Poplar1::new(1); let task_id = TaskId::try_from(digest(&SHA256, &task_config_encoded).as_ref()).unwrap(); let vdaf_instance = task_config.vdaf_config().vdaf_type().try_into().unwrap(); let vdaf_verify_key = peer_aggregator.derive_vdaf_verify_key(&task_id, &vdaf_instance); + let aggregation_param = + Poplar1AggregationParam::try_from_prefixes(Vec::from([IdpfInput::from_bools(&[true])])) + .unwrap(); + let measurement = IdpfInput::from_bools(&[true]); let task = janus_aggregator_core::taskprov::Task::new( task_id, @@ -187,9 +197,9 @@ async fn setup_taskprov_test() -> TaskprovTestCase { let transcript = run_vdaf( &vdaf, vdaf_verify_key.as_ref().try_into().unwrap(), - &(), + &aggregation_param, report_metadata.id(), - &1, + &measurement, ); let report_share = generate_helper_report_share::<TestVdaf>( task_id, @@ -197,7 +207,7 @@ async fn setup_taskprov_test() -> TaskprovTestCase { global_hpke_key.config(), &transcript.public_share, Vec::new(), - &transcript.input_shares[1], + &transcript.helper_input_share, ); TaskprovTestCase { @@ -213,6 +223,7 @@ async fn setup_taskprov_test() -> TaskprovTestCase { report_metadata, transcript, report_share, + aggregation_param, } } @@ -222,9 +233,14 @@ async fn taskprov_aggregate_init() { let batch_id = random(); let request = AggregationJobInitializeReq::new( - ().get_encoded(), + test.aggregation_param.get_encoded(), PartialBatchSelector::new_fixed_size(batch_id), - Vec::from([test.report_share.clone()]), + Vec::from([PrepareInit::new( + test.report_share.clone(), + test.transcript.leader_prepare_transitions[0] + .message + .clone(), + )]), ); let aggregation_job_id: AggregationJobId = random(); @@ -287,10 +303,10 @@ async fn taskprov_aggregate_init() { ); let aggregate_resp: AggregationJobResp = decode_response_body(&mut test_conn).await; - assert_eq!(aggregate_resp.prepare_steps().len(), 1); - let prepare_step = aggregate_resp.prepare_steps().get(0).unwrap(); + assert_eq!(aggregate_resp.prepare_resps().len(), 1); + let prepare_step = aggregate_resp.prepare_resps().get(0).unwrap(); assert_eq!(prepare_step.report_id(), test.report_share.metadata().id()); - assert_matches!(prepare_step.result(), &PrepareStepResult::Continued(..)); + assert_matches!(prepare_step.result(), &PrepareStepResult::Continue { ..
}); let (aggregation_jobs, got_task) = test .datastore @@ -328,7 +344,12 @@ async fn taskprov_opt_out_task_expired() { let request = AggregationJobInitializeReq::new( ().get_encoded(), PartialBatchSelector::new_fixed_size(batch_id), - Vec::from([test.report_share.clone()]), + Vec::from([PrepareInit::new( + test.report_share.clone(), + test.transcript.leader_prepare_transitions[0] + .message + .clone(), + )]), ); let aggregation_job_id: AggregationJobId = random(); @@ -378,7 +399,12 @@ async fn taskprov_opt_out_mismatched_task_id() { let request = AggregationJobInitializeReq::new( ().get_encoded(), PartialBatchSelector::new_fixed_size(batch_id), - Vec::from([test.report_share.clone()]), + Vec::from([PrepareInit::new( + test.report_share.clone(), + test.transcript.leader_prepare_transitions[0] + .message + .clone(), + )]), ); let aggregation_job_id: AggregationJobId = random(); @@ -404,7 +430,11 @@ async fn taskprov_opt_out_mismatched_task_id() { }, ), task_expiration, - VdafConfig::new(DpConfig::new(DpMechanism::None), VdafType::Prio3Count).unwrap(), + VdafConfig::new( + DpConfig::new(DpMechanism::None), + VdafType::Poplar1 { bits: 1 }, + ) + .unwrap(), ) .unwrap(); @@ -452,7 +482,12 @@ async fn taskprov_opt_out_missing_aggregator() { let request = AggregationJobInitializeReq::new( ().get_encoded(), PartialBatchSelector::new_fixed_size(batch_id), - Vec::from([test.report_share.clone()]), + Vec::from([PrepareInit::new( + test.report_share.clone(), + test.transcript.leader_prepare_transitions[0] + .message + .clone(), + )]), ); let aggregation_job_id: AggregationJobId = random(); @@ -475,7 +510,11 @@ async fn taskprov_opt_out_missing_aggregator() { }, ), task_expiration, - VdafConfig::new(DpConfig::new(DpMechanism::None), VdafType::Prio3Count).unwrap(), + VdafConfig::new( + DpConfig::new(DpMechanism::None), + VdafType::Poplar1 { bits: 1 }, + ) + .unwrap(), ) .unwrap(); let another_task_config_encoded = another_task_config.get_encoded(); @@ -490,8 +529,7 @@ async fn taskprov_opt_out_missing_aggregator() { .request_authentication(); let mut test_conn = put(format!( - "/tasks/{another_task_id -}/aggregation_jobs/{aggregation_job_id}" + "/tasks/{another_task_id}/aggregation_jobs/{aggregation_job_id}" )) .with_request_header(auth.0, auth.1) .with_request_header( @@ -525,7 +563,12 @@ async fn taskprov_opt_out_peer_aggregator_wrong_role() { let request = AggregationJobInitializeReq::new( ().get_encoded(), PartialBatchSelector::new_fixed_size(batch_id), - Vec::from([test.report_share.clone()]), + Vec::from([PrepareInit::new( + test.report_share.clone(), + test.transcript.leader_prepare_transitions[0] + .message + .clone(), + )]), ); let aggregation_job_id: AggregationJobId = random(); @@ -551,7 +594,11 @@ async fn taskprov_opt_out_peer_aggregator_wrong_role() { }, ), task_expiration, - VdafConfig::new(DpConfig::new(DpMechanism::None), VdafType::Prio3Count).unwrap(), + VdafConfig::new( + DpConfig::new(DpMechanism::None), + VdafType::Poplar1 { bits: 1 }, + ) + .unwrap(), ) .unwrap(); let another_task_config_encoded = another_task_config.get_encoded(); @@ -566,8 +613,7 @@ async fn taskprov_opt_out_peer_aggregator_wrong_role() { .request_authentication(); let mut test_conn = put(format!( - "/tasks/{another_task_id -}/aggregation_jobs/{aggregation_job_id}" + "/tasks/{another_task_id}/aggregation_jobs/{aggregation_job_id}" )) .with_request_header(auth.0, auth.1) .with_request_header( @@ -602,7 +648,12 @@ async fn taskprov_opt_out_peer_aggregator_does_not_exist() { let request = 
AggregationJobInitializeReq::new( ().get_encoded(), PartialBatchSelector::new_fixed_size(batch_id), - Vec::from([test.report_share.clone()]), + Vec::from([PrepareInit::new( + test.report_share.clone(), + test.transcript.leader_prepare_transitions[0] + .message + .clone(), + )]), ); let aggregation_job_id: AggregationJobId = random(); @@ -628,7 +679,11 @@ async fn taskprov_opt_out_peer_aggregator_does_not_exist() { }, ), task_expiration, - VdafConfig::new(DpConfig::new(DpMechanism::None), VdafType::Prio3Count).unwrap(), + VdafConfig::new( + DpConfig::new(DpMechanism::None), + VdafType::Poplar1 { bits: 1 }, + ) + .unwrap(), ) .unwrap(); let another_task_config_encoded = another_task_config.get_encoded(); @@ -643,8 +698,7 @@ async fn taskprov_opt_out_peer_aggregator_does_not_exist() { .request_authentication(); let mut test_conn = put(format!( - "/tasks/{another_task_id -}/aggregation_jobs/{aggregation_job_id}" + "/tasks/{another_task_id}/aggregation_jobs/{aggregation_job_id}" )) .with_request_header(auth.0, auth.1) .with_request_header( @@ -678,15 +732,13 @@ async fn taskprov_aggregate_continue() { let aggregation_job_id = random(); let batch_id = random(); - let (prep_state, _) = test.transcript.helper_prep_state(0); - let prep_msg = test.transcript.prepare_messages[0].clone(); - test.datastore .run_tx(|tx| { let task = test.task.clone(); let report_share = test.report_share.clone(); - let prep_state = prep_state.clone(); let report_metadata = test.report_metadata.clone(); + let transcript = test.transcript.clone(); + let aggregation_param = test.aggregation_param.clone(); Box::pin(async move { // Aggregate continue is only possible if the task has already been inserted. @@ -698,7 +750,7 @@ async fn taskprov_aggregate_continue() { &AggregationJob::::new( *task.id(), aggregation_job_id, - (), + aggregation_param.clone(), batch_id, Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) .unwrap(), @@ -715,7 +767,9 @@ async fn taskprov_aggregate_continue() { *report_metadata.time(), 0, None, - ReportAggregationState::Waiting(prep_state, None), + ReportAggregationState::Waiting( + transcript.helper_prepare_transitions[0].transition.clone(), + ), )) .await?; @@ -723,8 +777,8 @@ async fn taskprov_aggregate_continue() { &AggregateShareJob::new( *task.id(), batch_id, - (), - AggregateShare::from(OutputShare::from(Vec::from([Field64::from(7)]))), + aggregation_param, + transcript.helper_aggregate_share, 0, ReportIdChecksum::default(), ), @@ -737,9 +791,11 @@ async fn taskprov_aggregate_continue() { let request = AggregationJobContinueReq::new( AggregationJobRound::from(1), - Vec::from([PrepareStep::new( + Vec::from([PrepareContinue::new( *test.report_metadata.id(), - PrepareStepResult::Continued(prep_msg.get_encoded()), + test.transcript.leader_prepare_transitions[1] + .message + .clone(), )]), ); @@ -801,11 +857,11 @@ async fn taskprov_aggregate_continue() { assert_headers!(&test_conn, "content-type" => (AggregationJobResp::MEDIA_TYPE)); let aggregate_resp: AggregationJobResp = decode_response_body(&mut test_conn).await; - // We'll only validate the response. Taskprov doesn't touch functionality beyond the authorization - // of the request. + // We'll only validate the response. Taskprov doesn't touch functionality beyond the + // authorization of the request. 
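(Aside: the taskprov continue test reads most naturally against the transcript indexing convention used throughout this diff: leader_prepare_transitions[0].message is the leader's initial ping-pong message, carried in a PrepareInit, and leader_prepare_transitions[1].message is its round-one continuation, carried in a PrepareContinue. A sketch of building the round-one request from a transcript, mirroring the constructions above:)

    use janus_messages::{AggregationJobContinueReq, AggregationJobRound, PrepareContinue};

    // Round 1: forward the leader's second ping-pong message for this report.
    let request = AggregationJobContinueReq::new(
        AggregationJobRound::from(1),
        Vec::from([PrepareContinue::new(
            *test.report_metadata.id(),
            test.transcript.leader_prepare_transitions[1]
                .message
                .clone(),
        )]),
    );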
assert_eq!( aggregate_resp, - AggregationJobResp::new(Vec::from([PrepareStep::new( + AggregationJobResp::new(Vec::from([PrepareResp::new( *test.report_metadata.id(), PrepareStepResult::Finished )])) @@ -823,6 +879,8 @@ async fn taskprov_aggregate_share() { let interval = Interval::new(Time::from_seconds_since_epoch(6000), *task.time_precision()) .unwrap(); + let aggregation_param = test.aggregation_param.clone(); + let transcript = test.transcript.clone(); Box::pin(async move { tx.put_task(&task).await?; @@ -830,7 +888,7 @@ async fn taskprov_aggregate_share() { tx.put_batch(&Batch::<16, FixedSize, TestVdaf>::new( *task.id(), batch_id, - (), + aggregation_param.clone(), BatchState::Closed, 0, interval, @@ -841,12 +899,10 @@ async fn taskprov_aggregate_share() { tx.put_batch_aggregation(&BatchAggregation::<16, FixedSize, TestVdaf>::new( *task.id(), batch_id, - (), + aggregation_param, 0, BatchAggregationState::Aggregating, - Some(AggregateShare::from(OutputShare::from(Vec::from([ - Field64::from(7), - ])))), + Some(transcript.helper_aggregate_share), 1, interval, ReportIdChecksum::get_decoded(&[3; 32]).unwrap(), @@ -861,7 +917,7 @@ async fn taskprov_aggregate_share() { let request = AggregateShareReq::new( BatchSelector::new_fixed_size(batch_id), - ().get_encoded(), + test.aggregation_param.get_encoded(), 1, ReportIdChecksum::get_decoded(&[3; 32]).unwrap(), ); @@ -910,8 +966,6 @@ async fn taskprov_aggregate_share() { .run_async(&test.handler) .await; - println!("{:?}", test_conn); - assert_eq!(test_conn.status(), Some(Status::Ok)); assert_headers!( &test_conn, @@ -943,9 +997,14 @@ async fn end_to_end() { let aggregation_job_id = random(); let aggregation_job_init_request = AggregationJobInitializeReq::new( - ().get_encoded(), + test.aggregation_param.get_encoded(), PartialBatchSelector::new_fixed_size(batch_id), - Vec::from([test.report_share.clone()]), + Vec::from([PrepareInit::new( + test.report_share.clone(), + test.transcript.leader_prepare_transitions[0] + .message + .clone(), + )]), ); let mut test_conn = put(test @@ -970,23 +1029,25 @@ async fn end_to_end() { assert_headers!(&test_conn, "content-type" => (AggregationJobResp::MEDIA_TYPE)); let aggregation_job_resp: AggregationJobResp = decode_response_body(&mut test_conn).await; - assert_eq!(aggregation_job_resp.prepare_steps().len(), 1); - let prepare_step = &aggregation_job_resp.prepare_steps()[0]; - assert_eq!(prepare_step.report_id(), test.report_metadata.id()); - let encoded_prep_share = assert_matches!( - prepare_step.result(), - PrepareStepResult::Continued(payload) => payload.clone() + assert_eq!(aggregation_job_resp.prepare_resps().len(), 1); + let prepare_resp = &aggregation_job_resp.prepare_resps()[0]; + assert_eq!(prepare_resp.report_id(), test.report_metadata.id()); + let message = assert_matches!( + prepare_resp.result(), + PrepareStepResult::Continue { message } => message.clone() ); assert_eq!( - encoded_prep_share, - test.transcript.helper_prep_state(0).1.get_encoded() + message, + test.transcript.helper_prepare_transitions[0].message, ); let aggregation_job_continue_request = AggregationJobContinueReq::new( AggregationJobRound::from(1), - Vec::from([PrepareStep::new( + Vec::from([PrepareContinue::new( *test.report_metadata.id(), - PrepareStepResult::Continued(test.transcript.prepare_messages[0].get_encoded()), + test.transcript.leader_prepare_transitions[1] + .message + .clone(), )]), ); @@ -1013,15 +1074,15 @@ async fn end_to_end() { assert_headers!(&test_conn, "content-type" => (AggregationJobResp::MEDIA_TYPE)); let 
aggregation_job_resp: AggregationJobResp = decode_response_body(&mut test_conn).await; - assert_eq!(aggregation_job_resp.prepare_steps().len(), 1); - let prepare_step = &aggregation_job_resp.prepare_steps()[0]; - assert_eq!(prepare_step.report_id(), test.report_metadata.id()); - assert_matches!(prepare_step.result(), PrepareStepResult::Finished); + assert_eq!(aggregation_job_resp.prepare_resps().len(), 1); + let prepare_resp = &aggregation_job_resp.prepare_resps()[0]; + assert_eq!(prepare_resp.report_id(), test.report_metadata.id()); + assert_matches!(prepare_resp.result(), PrepareStepResult::Finished); let checksum = ReportIdChecksum::for_report_id(test.report_metadata.id()); let aggregate_share_request = AggregateShareReq::new( BatchSelector::new_fixed_size(batch_id), - ().get_encoded(), + test.aggregation_param.get_encoded(), 1, checksum, ); @@ -1056,5 +1117,8 @@ async fn end_to_end() { .get_encoded(), ) .unwrap(); - assert_eq!(plaintext, test.transcript.aggregate_shares[1].get_encoded()); + assert_eq!( + plaintext, + test.transcript.helper_aggregate_share.get_encoded() + ); } diff --git a/aggregator_core/src/datastore.rs b/aggregator_core/src/datastore.rs index 387001a68..14c3f5c7f 100644 --- a/aggregator_core/src/datastore.rs +++ b/aggregator_core/src/datastore.rs @@ -24,7 +24,7 @@ use janus_core::{ use janus_messages::{ query_type::{FixedSize, QueryType, TimeInterval}, AggregationJobId, BatchId, CollectionJobId, Duration, Extension, HpkeCiphertext, HpkeConfig, - HpkeConfigId, Interval, PrepareStep, ReportId, ReportIdChecksum, ReportMetadata, ReportShare, + HpkeConfigId, Interval, PrepareResp, ReportId, ReportIdChecksum, ReportMetadata, ReportShare, Role, TaskId, Time, }; use opentelemetry::{ @@ -34,6 +34,7 @@ use opentelemetry::{ use postgres_types::{FromSql, Json, ToSql}; use prio::{ codec::{decode_u16_items, encode_u16_items, CodecError, Decode, Encode, ParameterizedDecode}, + topology::ping_pong, vdaf, }; use rand::random; @@ -2097,6 +2098,7 @@ impl Transaction<'_, C> { role: &Role, task_id: &TaskId, aggregation_job_id: &AggregationJobId, + aggregation_param: &A::AggregationParam, report_id: &ReportId, ) -> Result>, Error> where @@ -2107,8 +2109,7 @@ impl Transaction<'_, C> { "SELECT report_aggregations.client_timestamp, report_aggregations.ord, report_aggregations.state, report_aggregations.prep_state, - report_aggregations.prep_msg, report_aggregations.error_code, - report_aggregations.last_prep_step + report_aggregations.error_code, report_aggregations.last_prep_step FROM report_aggregations JOIN aggregation_jobs ON aggregation_jobs.id = report_aggregations.aggregation_job_id JOIN tasks ON tasks.id = aggregation_jobs.task_id @@ -2153,6 +2154,7 @@ impl Transaction<'_, C> { role: &Role, task_id: &TaskId, aggregation_job_id: &AggregationJobId, + aggregation_param: &A::AggregationParam, ) -> Result>, Error> where for<'a> A::PrepareState: ParameterizedDecode<(&'a A, usize)>, @@ -2162,8 +2164,8 @@ impl Transaction<'_, C> { "SELECT report_aggregations.client_report_id, report_aggregations.client_timestamp, report_aggregations.ord, report_aggregations.state, - report_aggregations.prep_state, report_aggregations.prep_msg, - report_aggregations.error_code, report_aggregations.last_prep_step + report_aggregations.prep_state, report_aggregations.error_code, + report_aggregations.last_prep_step FROM report_aggregations JOIN aggregation_jobs ON aggregation_jobs.id = report_aggregations.aggregation_job_id JOIN tasks ON tasks.id = aggregation_jobs.task_id @@ -2217,8 +2219,7 @@ impl 
Transaction<'_, C> { aggregation_jobs.aggregation_job_id, report_aggregations.client_report_id, report_aggregations.client_timestamp, report_aggregations.ord, report_aggregations.state, report_aggregations.prep_state, - report_aggregations.prep_msg, report_aggregations.error_code, - report_aggregations.last_prep_step + report_aggregations.error_code, report_aggregations.last_prep_step FROM report_aggregations JOIN aggregation_jobs ON aggregation_jobs.id = report_aggregations.aggregation_job_id JOIN tasks ON tasks.id = aggregation_jobs.task_id @@ -2263,7 +2264,6 @@ impl Transaction<'_, C> { let ord: u64 = row.get_bigint_and_convert("ord")?; let state: ReportAggregationStateCode = row.get("state"); let prep_state_bytes: Option> = row.get("prep_state"); - let prep_msg_bytes: Option> = row.get("prep_msg"); let error_code: Option = row.get("error_code"); let last_prep_step_bytes: Option> = row.get("last_prep_step"); @@ -2280,7 +2280,7 @@ impl Transaction<'_, C> { }; let last_prep_step = last_prep_step_bytes - .map(|bytes| PrepareStep::get_decoded(&bytes)) + .map(|bytes| PrepareResp::get_decoded(&bytes)) .transpose()?; let agg_state = match state { @@ -2290,7 +2290,7 @@ impl Transaction<'_, C> { let agg_index = role.index().ok_or_else(|| { Error::User(anyhow!("unexpected role: {}", role.as_str()).into()) })?; - let prep_state = A::PrepareState::get_decoded_with_param( + let ping_pong_transition = ping_pong::Transition::get_decoded_with_param( &(vdaf, agg_index), &prep_state_bytes.ok_or_else(|| { Error::DbState( @@ -2299,11 +2299,8 @@ impl Transaction<'_, C> { ) })?, )?; - let prep_msg = prep_msg_bytes - .map(|bytes| A::PrepareMessage::get_decoded_with_param(&prep_state, &bytes)) - .transpose()?; - ReportAggregationState::Waiting(prep_state, prep_msg) + ReportAggregationState::Waiting(ping_pong_transition) } ReportAggregationStateCode::Finished => ReportAggregationState::Finished, @@ -2343,19 +2340,19 @@ impl Transaction<'_, C> { let encoded_state_values = report_aggregation.state().encoded_values_from_state(); let encoded_last_prep_step = report_aggregation .last_prep_step() - .map(PrepareStep::get_encoded); + .map(PrepareResp::get_encoded); let stmt = self .prepare_cached( "INSERT INTO report_aggregations (aggregation_job_id, client_report_id, client_timestamp, ord, state, prep_state, - prep_msg, error_code, last_prep_step) - SELECT aggregation_jobs.id, $3, $4, $5, $6, $7, $8, $9, $10 + error_code, last_prep_step) + SELECT aggregation_jobs.id, $3, $4, $5, $6, $7, $8, $9 FROM aggregation_jobs JOIN tasks ON tasks.id = aggregation_jobs.task_id WHERE tasks.task_id = $1 AND aggregation_job_id = $2 - AND UPPER(aggregation_jobs.client_timestamp_interval) >= COALESCE($11::TIMESTAMP - tasks.report_expiry_age * '1 second'::INTERVAL, '-infinity'::TIMESTAMP) + AND UPPER(aggregation_jobs.client_timestamp_interval) >= COALESCE($10::TIMESTAMP - tasks.report_expiry_age * '1 second'::INTERVAL, '-infinity'::TIMESTAMP) ON CONFLICT DO NOTHING", ) .await?; @@ -2369,7 +2366,6 @@ impl Transaction<'_, C> { /* ord */ &TryInto::::try_into(report_aggregation.ord())?, /* state */ &report_aggregation.state().state_code(), /* prep_state */ &encoded_state_values.prep_state, - /* prep_msg */ &encoded_state_values.prep_msg, /* error_code */ &encoded_state_values.report_share_err, /* last_prep_step */ &encoded_last_prep_step, /* now */ &self.clock.now().as_naive_date_time()?, @@ -2393,21 +2389,21 @@ impl Transaction<'_, C> { let encoded_state_values = report_aggregation.state().encoded_values_from_state(); let 
encoded_last_prep_step = report_aggregation .last_prep_step() - .map(PrepareStep::get_encoded); + .map(PrepareResp::get_encoded); let stmt = self .prepare_cached( "UPDATE report_aggregations - SET state = $1, prep_state = $2, prep_msg = $3, error_code = $4, last_prep_step = $5 + SET state = $1, prep_state = $2, error_code = $3, last_prep_step = $4 FROM aggregation_jobs, tasks WHERE report_aggregations.aggregation_job_id = aggregation_jobs.id AND aggregation_jobs.task_id = tasks.id - AND aggregation_jobs.aggregation_job_id = $6 - AND tasks.task_id = $7 - AND report_aggregations.client_report_id = $8 - AND report_aggregations.client_timestamp = $9 - AND report_aggregations.ord = $10 - AND UPPER(aggregation_jobs.client_timestamp_interval) >= COALESCE($11::TIMESTAMP - tasks.report_expiry_age * '1 second'::INTERVAL, '-infinity'::TIMESTAMP)", + AND aggregation_jobs.aggregation_job_id = $5 + AND tasks.task_id = $6 + AND report_aggregations.client_report_id = $7 + AND report_aggregations.client_timestamp = $8 + AND report_aggregations.ord = $9 + AND UPPER(aggregation_jobs.client_timestamp_interval) >= COALESCE($10::TIMESTAMP - tasks.report_expiry_age * '1 second'::INTERVAL, '-infinity'::TIMESTAMP)", ) .await?; check_single_row_mutation( @@ -2417,7 +2413,6 @@ impl Transaction<'_, C> { /* state */ &report_aggregation.state().state_code(), /* prep_state */ &encoded_state_values.prep_state, - /* prep_msg */ &encoded_state_values.prep_msg, /* error_code */ &encoded_state_values.report_share_err, /* last_prep_step */ &encoded_last_prep_step, /* aggregation_job_id */ @@ -4411,7 +4406,7 @@ impl Transaction<'_, C> { let stmt = self .prepare_cached( - "SELECT (SELECT p.id FROM taskprov_peer_aggregators AS p + "SELECT (SELECT p.id FROM taskprov_peer_aggregators AS p WHERE p.id = a.peer_aggregator_id) AS peer_id, ord, type, token FROM taskprov_aggregator_auth_tokens AS a ORDER BY ord ASC", @@ -4421,7 +4416,7 @@ impl Transaction<'_, C> { let stmt = self .prepare_cached( - "SELECT (SELECT p.id FROM taskprov_peer_aggregators AS p + "SELECT (SELECT p.id FROM taskprov_peer_aggregators AS p WHERE p.id = a.peer_aggregator_id) AS peer_id, ord, type, token FROM taskprov_collector_auth_tokens AS a ORDER BY ord ASC", diff --git a/aggregator_core/src/datastore/models.rs b/aggregator_core/src/datastore/models.rs index 17e40838d..f58e7ec73 100644 --- a/aggregator_core/src/datastore/models.rs +++ b/aggregator_core/src/datastore/models.rs @@ -13,8 +13,8 @@ use janus_core::{ use janus_messages::{ query_type::{FixedSize, QueryType, TimeInterval}, AggregationJobId, AggregationJobRound, BatchId, CollectionJobId, Duration, Extension, - HpkeCiphertext, Interval, PrepareStep, ReportId, ReportIdChecksum, ReportMetadata, - ReportShareError, Role, TaskId, Time, + HpkeCiphertext, Interval, PrepareError, PrepareResp, ReportId, ReportIdChecksum, + ReportMetadata, Role, TaskId, Time, }; use postgres_protocol::types::{ range_from_sql, range_to_sql, timestamp_from_sql, timestamp_to_sql, Range, RangeBound, @@ -22,6 +22,7 @@ use postgres_protocol::types::{ use postgres_types::{accepts, to_sql_checked, FromSql, ToSql}; use prio::{ codec::Encode, + topology::ping_pong, vdaf::{self, Aggregatable}, }; use rand::{distributions::Standard, prelude::Distribution}; @@ -587,7 +588,7 @@ pub struct ReportAggregation, + last_prep_step: Option, state: ReportAggregationState, } @@ -599,7 +600,7 @@ impl> ReportAggregati report_id: ReportId, time: Time, ord: u64, - last_prep_step: Option, + last_prep_step: Option, state: ReportAggregationState, ) -> Self { 
Self { @@ -644,13 +645,13 @@ impl> ReportAggregati } /// Returns the last preparation step returned by the Helper, if any. - pub fn last_prep_step(&self) -> Option<&PrepareStep> { + pub fn last_prep_step(&self) -> Option<&PrepareResp> { self.last_prep_step.as_ref() } /// Returns a new [`ReportAggregation`] corresponding to this report aggregation updated to /// have the given last preparation step. - pub fn with_last_prep_step(self, last_prep_step: Option) -> Self { + pub fn with_last_prep_step(self, last_prep_step: Option) -> Self { Self { last_prep_step, ..self @@ -700,16 +701,15 @@ where /// ReportAggregationState represents the state of a single report aggregation. It corresponds /// to the REPORT_AGGREGATION_STATE enum in the schema, along with the state-specific data. -#[derive(Clone, Derivative)] -#[derivative(Debug)] +#[derive(Clone, Debug, Derivative)] pub enum ReportAggregationState> { Start, Waiting( - #[derivative(Debug = "ignore")] A::PrepareState, - #[derivative(Debug = "ignore")] Option, + /// Most recent transition for this report aggregation. + ping_pong::Transition, ), Finished, - Failed(ReportShareError), + Failed(PrepareError), } impl> @@ -718,7 +718,7 @@ impl> pub fn state_code(&self) -> ReportAggregationStateCode { match self { ReportAggregationState::Start => ReportAggregationStateCode::Start, - ReportAggregationState::Waiting(_, _) => ReportAggregationStateCode::Waiting, + ReportAggregationState::Waiting(_) => ReportAggregationStateCode::Waiting, ReportAggregationState::Finished => ReportAggregationStateCode::Finished, ReportAggregationState::Failed(_) => ReportAggregationStateCode::Failed, } @@ -733,13 +733,10 @@ impl> { match self { ReportAggregationState::Start => EncodedReportAggregationStateValues::default(), - ReportAggregationState::Waiting(prep_state, prep_msg) => { - EncodedReportAggregationStateValues { - prep_state: Some(prep_state.get_encoded()), - prep_msg: prep_msg.as_ref().map(Encode::get_encoded), - ..Default::default() - } - } + ReportAggregationState::Waiting(prep_state) => EncodedReportAggregationStateValues { + prep_state: Some(prep_state.get_encoded()), + ..Default::default() + }, ReportAggregationState::Finished => EncodedReportAggregationStateValues::default(), ReportAggregationState::Failed(report_share_err) => { EncodedReportAggregationStateValues { @@ -754,7 +751,6 @@ impl> #[derive(Default)] pub(super) struct EncodedReportAggregationStateValues { pub(super) prep_state: Option>, - pub(super) prep_msg: Option>, pub(super) report_share_err: Option, } @@ -778,17 +774,14 @@ pub enum ReportAggregationStateCode { impl> PartialEq for ReportAggregationState where - A::PrepareState: PartialEq, - A::PrepareMessage: PartialEq, A::PrepareShare: PartialEq, A::OutputShare: PartialEq, { fn eq(&self, other: &Self) -> bool { match (self, other) { - ( - Self::Waiting(lhs_prep_state, lhs_prep_msg), - Self::Waiting(rhs_prep_state, rhs_prep_msg), - ) => lhs_prep_state == rhs_prep_state && lhs_prep_msg == rhs_prep_msg, + (Self::Waiting(lhs_prep_state), Self::Waiting(rhs_prep_state)) => { + lhs_prep_state == rhs_prep_state + } (Self::Failed(lhs_report_share_err), Self::Failed(rhs_report_share_err)) => { lhs_report_share_err == rhs_report_share_err } diff --git a/aggregator_core/src/datastore/tests.rs b/aggregator_core/src/datastore/tests.rs index 7307b2b6a..4d9b4aa8e 100644 --- a/aggregator_core/src/datastore/tests.rs +++ b/aggregator_core/src/datastore/tests.rs @@ -39,11 +39,12 @@ use janus_messages::{ query_type::{FixedSize, QueryType, TimeInterval}, 
AggregateShareAad, AggregationJobId, AggregationJobRound, BatchId, BatchSelector, CollectionJobId, Duration, Extension, ExtensionType, HpkeCiphertext, HpkeConfigId, Interval, - PrepareStep, PrepareStepResult, ReportId, ReportIdChecksum, ReportMetadata, ReportShare, - ReportShareError, Role, TaskId, Time, + PrepareError, PrepareResp, PrepareStepResult, ReportId, ReportIdChecksum, ReportMetadata, + ReportShare, Role, TaskId, Time, }; use prio::{ codec::{Decode, Encode}, + topology::ping_pong, vdaf::prio3::{Prio3, Prio3Count}, }; use rand::{distributions::Standard, random, thread_rng, Rng}; @@ -1881,17 +1882,16 @@ async fn roundtrip_report_aggregation(ephemeral_datastore: EphemeralDatastore) { let vdaf = Arc::new(Prio3::new_count(2).unwrap()); let verify_key: [u8; VERIFY_KEY_LENGTH] = random(); let vdaf_transcript = run_vdaf(vdaf.as_ref(), &verify_key, &(), &report_id, &0); - let leader_prep_state = vdaf_transcript.leader_prep_state(0); for (ord, state) in [ ReportAggregationState::::Start, ReportAggregationState::Waiting( - leader_prep_state.clone(), - Some(vdaf_transcript.prepare_messages[0].clone()), + vdaf_transcript.helper_prepare_transitions[0] + .transition + .clone(), ), - ReportAggregationState::Waiting(leader_prep_state.clone(), None), ReportAggregationState::Finished, - ReportAggregationState::Failed(ReportShareError::VdafPrepError), + ReportAggregationState::Failed(PrepareError::VdafPrepError), ] .into_iter() .enumerate() @@ -1902,7 +1902,7 @@ async fn roundtrip_report_aggregation(ephemeral_datastore: EphemeralDatastore) { let task = TaskBuilder::new( task::QueryType::TimeInterval, VdafInstance::Prio3Count, - Role::Leader, + Role::Helper, ) .with_report_expiry_age(Some(REPORT_EXPIRY_AGE)) .build(); @@ -1949,9 +1949,14 @@ async fn roundtrip_report_aggregation(ephemeral_datastore: EphemeralDatastore) { report_id, OLDEST_ALLOWED_REPORT_TIMESTAMP, ord.try_into().unwrap(), - Some(PrepareStep::new( + Some(PrepareResp::new( report_id, - PrepareStepResult::Continued(format!("prep_msg_{ord}").into()), + PrepareStepResult::Continue { + message: ping_pong::Message::Continue { + prep_msg: format!("prep_msg_{ord}").into(), + prep_share: format!("prep_share_{ord}").into(), + }, + }, )), state, ); @@ -1971,9 +1976,10 @@ async fn roundtrip_report_aggregation(ephemeral_datastore: EphemeralDatastore) { Box::pin(async move { tx.get_report_aggregation( vdaf.as_ref(), - &Role::Leader, + &Role::Helper, task.id(), &aggregation_job_id, + &(), &report_id, ) .await @@ -1991,9 +1997,14 @@ async fn roundtrip_report_aggregation(ephemeral_datastore: EphemeralDatastore) { *want_report_aggregation.report_id(), *want_report_aggregation.time(), want_report_aggregation.ord(), - Some(PrepareStep::new( + Some(PrepareResp::new( report_id, - PrepareStepResult::Continued(format!("updated_prep_msg_{ord}").into()), + PrepareStepResult::Continue { + message: ping_pong::Message::Continue { + prep_msg: format!("updated_prep_msg_{ord}").into(), + prep_share: format!("updated_prep_share_{ord}").into(), + }, + }, )), want_report_aggregation.state().clone(), ); @@ -2011,9 +2022,10 @@ async fn roundtrip_report_aggregation(ephemeral_datastore: EphemeralDatastore) { Box::pin(async move { tx.get_report_aggregation( vdaf.as_ref(), - &Role::Leader, + &Role::Helper, task.id(), &aggregation_job_id, + &(), &report_id, ) .await @@ -2032,9 +2044,10 @@ async fn roundtrip_report_aggregation(ephemeral_datastore: EphemeralDatastore) { Box::pin(async move { tx.get_report_aggregation( vdaf.as_ref(), - &Role::Leader, + &Role::Helper, task.id(), 
&aggregation_job_id, + &(), &report_id, ) .await @@ -2216,6 +2229,7 @@ async fn report_aggregation_not_found(ephemeral_datastore: EphemeralDatastore) { &Role::Leader, &random(), &random(), + &dummy_vdaf::AggregationParam::default(), &ReportId::from([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]), ) .await @@ -2235,7 +2249,7 @@ async fn report_aggregation_not_found(ephemeral_datastore: EphemeralDatastore) { Time::from_seconds_since_epoch(12345), 0, None, - ReportAggregationState::Failed(ReportShareError::VdafPrepError), + ReportAggregationState::Failed(PrepareError::VdafPrepError), )) .await }) @@ -2260,7 +2274,7 @@ async fn get_report_aggregations_for_aggregation_job(ephemeral_datastore: Epheme let task = TaskBuilder::new( task::QueryType::TimeInterval, VdafInstance::Prio3Count, - Role::Leader, + Role::Helper, ) .with_report_expiry_age(Some(REPORT_EXPIRY_AGE)) .build(); @@ -2268,13 +2282,10 @@ async fn get_report_aggregations_for_aggregation_job(ephemeral_datastore: Epheme let want_report_aggregations = ds .run_tx(|tx| { - let (task, prep_msg, prep_state) = ( - task.clone(), - vdaf_transcript.prepare_messages[0].clone(), - vdaf_transcript.leader_prep_state(0).clone(), - ); + let (task, vdaf_transcript) = (task.clone(), vdaf_transcript.clone()); Box::pin(async move { - tx.put_task(&task).await?; + tx.put_task(&task).await.unwrap(); + tx.put_aggregation_job(&AggregationJob::< VERIFY_KEY_LENGTH, TimeInterval, @@ -2289,14 +2300,19 @@ async fn get_report_aggregations_for_aggregation_job(ephemeral_datastore: Epheme AggregationJobState::InProgress, AggregationJobRound::from(0), )) - .await?; + .await + .unwrap(); let mut want_report_aggregations = Vec::new(); for (ord, state) in [ ReportAggregationState::::Start, - ReportAggregationState::Waiting(prep_state.clone(), Some(prep_msg)), + ReportAggregationState::Waiting( + vdaf_transcript.helper_prepare_transitions[0] + .transition + .clone(), + ), ReportAggregationState::Finished, - ReportAggregationState::Failed(ReportShareError::VdafPrepError), + ReportAggregationState::Failed(PrepareError::VdafPrepError), ] .iter() .enumerate() @@ -2314,7 +2330,8 @@ async fn get_report_aggregations_for_aggregation_job(ephemeral_datastore: Epheme ), ), ) - .await?; + .await + .unwrap(); let report_aggregation = ReportAggregation::new( *task.id(), @@ -2322,10 +2339,12 @@ async fn get_report_aggregations_for_aggregation_job(ephemeral_datastore: Epheme report_id, OLDEST_ALLOWED_REPORT_TIMESTAMP, ord.try_into().unwrap(), - Some(PrepareStep::new(report_id, PrepareStepResult::Finished)), + Some(PrepareResp::new(report_id, PrepareStepResult::Finished)), state.clone(), ); - tx.put_report_aggregation(&report_aggregation).await?; + tx.put_report_aggregation(&report_aggregation) + .await + .unwrap(); want_report_aggregations.push(report_aggregation); } Ok(want_report_aggregations) @@ -2343,9 +2362,10 @@ async fn get_report_aggregations_for_aggregation_job(ephemeral_datastore: Epheme Box::pin(async move { tx.get_report_aggregations_for_aggregation_job( vdaf.as_ref(), - &Role::Leader, + &Role::Helper, task.id(), &aggregation_job_id, + &(), ) .await }) @@ -2363,9 +2383,10 @@ async fn get_report_aggregations_for_aggregation_job(ephemeral_datastore: Epheme Box::pin(async move { tx.get_report_aggregations_for_aggregation_job( vdaf.as_ref(), - &Role::Leader, + &Role::Helper, task.id(), &aggregation_job_id, + &(), ) .await }) @@ -4670,7 +4691,7 @@ async fn roundtrip_outstanding_batch(ephemeral_datastore: EphemeralDatastore) { clock.now(), 1, None, - 
ReportAggregationState::Waiting(dummy_vdaf::PrepareState::default(), Some(())), // Counted among max_size. + ReportAggregationState::Waiting(ping_pong::Transition::default()), // Counted among max_size. ); let report_aggregation_0_2 = ReportAggregation::<0, dummy_vdaf::Vdaf>::new( *task_1.id(), @@ -4679,7 +4700,7 @@ async fn roundtrip_outstanding_batch(ephemeral_datastore: EphemeralDatastore) { clock.now(), 2, None, - ReportAggregationState::Failed(ReportShareError::VdafPrepError), // Not counted among min_size or max_size. + ReportAggregationState::Failed(PrepareError::VdafPrepError), // Not counted among min_size or max_size. ); let aggregation_job_1 = AggregationJob::<0, FixedSize, dummy_vdaf::Vdaf>::new( @@ -4717,7 +4738,7 @@ async fn roundtrip_outstanding_batch(ephemeral_datastore: EphemeralDatastore) { clock.now(), 2, None, - ReportAggregationState::Failed(ReportShareError::VdafPrepError), // Not counted among min_size or max_size. + ReportAggregationState::Failed(PrepareError::VdafPrepError), // Not counted among min_size or max_size. ); let aggregation_job_2 = AggregationJob::<0, FixedSize, dummy_vdaf::Vdaf>::new( diff --git a/collector/src/lib.rs b/collector/src/lib.rs index 63d2fedc1..63cd015ed 100644 --- a/collector/src/lib.rs +++ b/collector/src/lib.rs @@ -731,14 +731,14 @@ mod tests { hpke::seal( ¶meters.hpke_config, &HpkeApplicationInfo::new(&Label::AggregateShare, &Role::Leader, &Role::Collector), - &transcript.aggregate_shares[0].get_encoded(), + &transcript.leader_aggregate_share.get_encoded(), &associated_data.get_encoded(), ) .unwrap(), hpke::seal( ¶meters.hpke_config, &HpkeApplicationInfo::new(&Label::AggregateShare, &Role::Helper, &Role::Collector), - &transcript.aggregate_shares[1].get_encoded(), + &transcript.helper_aggregate_share.get_encoded(), &associated_data.get_encoded(), ) .unwrap(), @@ -759,14 +759,14 @@ mod tests { hpke::seal( ¶meters.hpke_config, &HpkeApplicationInfo::new(&Label::AggregateShare, &Role::Leader, &Role::Collector), - &transcript.aggregate_shares[0].get_encoded(), + &transcript.leader_aggregate_share.get_encoded(), &associated_data.get_encoded(), ) .unwrap(), hpke::seal( ¶meters.hpke_config, &HpkeApplicationInfo::new(&Label::AggregateShare, &Role::Helper, &Role::Collector), - &transcript.aggregate_shares[1].get_encoded(), + &transcript.helper_aggregate_share.get_encoded(), &associated_data.get_encoded(), ) .unwrap(), diff --git a/core/src/task.rs b/core/src/task.rs index 80e280e63..43dd30689 100644 --- a/core/src/task.rs +++ b/core/src/task.rs @@ -327,19 +327,19 @@ macro_rules! 
vdaf_dispatch_impl_test_util { ::janus_core::task::VdafInstance::FakeFailsPrepStep => { let $vdaf = ::janus_core::test_util::dummy_vdaf::Vdaf::new().with_prep_step_fn( - || -> Result< - ::prio::vdaf::PrepareTransition< - ::janus_core::test_util::dummy_vdaf::Vdaf, - 0, - 16, - >, - ::prio::vdaf::VdafError, - > { - ::std::result::Result::Err(::prio::vdaf::VdafError::Uncategorized( - "FakeFailsPrepStep failed at prep_step".to_string(), - )) - }, - ); + |_| -> Result< + ::prio::vdaf::PrepareTransition< + ::janus_core::test_util::dummy_vdaf::Vdaf, + 0, + 16, + >, + ::prio::vdaf::VdafError, + > { + ::std::result::Result::Err(::prio::vdaf::VdafError::Uncategorized( + "FakeFailsPrepStep failed at prep_step".to_string(), + )) + }, + ); type $Vdaf = ::janus_core::test_util::dummy_vdaf::Vdaf; const $VERIFY_KEY_LEN: usize = 0; $body diff --git a/core/src/test_util/dummy_vdaf.rs b/core/src/test_util/dummy_vdaf.rs index 7d1583112..30734fce3 100644 --- a/core/src/test_util/dummy_vdaf.rs +++ b/core/src/test_util/dummy_vdaf.rs @@ -4,20 +4,24 @@ use prio::{ codec::{CodecError, Decode, Encode}, vdaf::{self, Aggregatable, PrepareTransition, VdafError}, }; +use rand::random; use std::fmt::Debug; use std::io::Cursor; use std::sync::Arc; type ArcPrepInitFn = Arc Result<(), VdafError> + 'static + Send + Sync>; -type ArcPrepStepFn = - Arc Result, VdafError> + 'static + Send + Sync>; +type ArcPrepStepFn = Arc< + dyn Fn(&PrepareState) -> Result, VdafError> + + 'static + + Send + + Sync, +>; #[derive(Clone)] pub struct Vdaf { prep_init_fn: ArcPrepInitFn, prep_step_fn: ArcPrepStepFn, - input_share: InputShare, } impl Debug for Vdaf { @@ -36,10 +40,11 @@ impl Vdaf { pub fn new() -> Self { Self { prep_init_fn: Arc::new(|_| -> Result<(), VdafError> { Ok(()) }), - prep_step_fn: Arc::new(|| -> Result, VdafError> { - Ok(PrepareTransition::Finish(OutputShare())) - }), - input_share: InputShare::default(), + prep_step_fn: Arc::new( + |state| -> Result, VdafError> { + Ok(PrepareTransition::Finish(OutputShare(state.0))) + }, + ), } } @@ -54,7 +59,9 @@ impl Vdaf { self } - pub fn with_prep_step_fn Result, VdafError>>( + pub fn with_prep_step_fn< + F: Fn(&PrepareState) -> Result, VdafError>, + >( mut self, f: F, ) -> Self @@ -64,11 +71,6 @@ impl Vdaf { self.prep_step_fn = Arc::new(f); self } - - pub fn with_input_share(mut self, input_share: InputShare) -> Self { - self.input_share = input_share; - self - } } impl Default for Vdaf { @@ -80,8 +82,8 @@ impl Default for Vdaf { impl vdaf::Vdaf for Vdaf { const ID: u32 = 0xFFFF0000; - type Measurement = (); - type AggregateResult = (); + type Measurement = u8; + type AggregateResult = u8; type AggregationParam = AggregationParam; type PublicShare = (); type InputShare = InputShare; @@ -120,10 +122,10 @@ impl vdaf::Aggregator<0, 16> for Vdaf { fn prepare_step( &self, - _: Self::PrepareState, + state: Self::PrepareState, _: Self::PrepareMessage, ) -> Result, VdafError> { - (self.prep_step_fn)() + (self.prep_step_fn)(&state) } fn aggregate>( @@ -142,10 +144,18 @@ impl vdaf::Aggregator<0, 16> for Vdaf { impl vdaf::Client<16> for Vdaf { fn shard( &self, - _measurement: &Self::Measurement, + measurement: &Self::Measurement, _nonce: &[u8; 16], ) -> Result<(Self::PublicShare, Vec), VdafError> { - Ok(((), Vec::from([self.input_share, self.input_share]))) + let first_input_share = random(); + let (second_input_share, _) = measurement.overflowing_sub(first_input_share); + Ok(( + (), + Vec::from([ + InputShare(first_input_share), + InputShare(second_input_share), + ]), + )) } } @@ -173,7 +183,7 
@@ pub struct AggregationParam(pub u8); impl Encode for AggregationParam { fn encode(&self, bytes: &mut Vec) { - self.0.encode(bytes); + self.0.encode(bytes) } fn encoded_len(&self) -> Option { @@ -188,19 +198,21 @@ impl Decode for AggregationParam { } #[derive(Clone, Copy, Debug, PartialEq, Eq)] -pub struct OutputShare(); +pub struct OutputShare(pub u8); impl Decode for OutputShare { - fn decode(_: &mut Cursor<&[u8]>) -> Result { - Ok(Self()) + fn decode(bytes: &mut Cursor<&[u8]>) -> Result { + Ok(Self(u8::decode(bytes)?)) } } impl Encode for OutputShare { - fn encode(&self, _: &mut Vec) {} + fn encode(&self, bytes: &mut Vec) { + self.0.encode(bytes); + } fn encoded_len(&self) -> Option { - Some(0) + self.0.encoded_len() } } @@ -234,15 +246,15 @@ impl Aggregatable for AggregateShare { Ok(()) } - fn accumulate(&mut self, _: &Self::OutputShare) -> Result<(), VdafError> { - self.0 += 1; + fn accumulate(&mut self, out_share: &Self::OutputShare) -> Result<(), VdafError> { + self.0 += u64::from(out_share.0); Ok(()) } } impl From for AggregateShare { - fn from(_: OutputShare) -> Self { - Self(1) + fn from(out_share: OutputShare) -> Self { + Self(u64::from(out_share.0)) } } diff --git a/core/src/test_util/mod.rs b/core/src/test_util/mod.rs index 9f5dbd517..5c9332f3e 100644 --- a/core/src/test_util/mod.rs +++ b/core/src/test_util/mod.rs @@ -1,6 +1,8 @@ -use assert_matches::assert_matches; -use janus_messages::{ReportId, Role}; -use prio::vdaf::{self, PrepareTransition, VdafError}; +use janus_messages::ReportId; +use prio::{ + topology::ping_pong::{self, PingPongTopology}, + vdaf, +}; use serde::{de::DeserializeOwned, Serialize}; use std::{fmt::Debug, sync::Once}; use tracing_log::LogTracer; @@ -11,124 +13,199 @@ pub mod kubernetes; pub mod runtime; pub mod testcontainers; -/// A transcript of a VDAF run. All fields are indexed by natural role index (i.e., index 0 = -/// leader, index 1 = helper). #[derive(Clone, Debug)] -pub struct VdafTranscript> { +pub struct LeaderPrepareTransition< + const VERIFY_KEY_LENGTH: usize, + V: vdaf::Aggregator, +> { + pub transition: Option>, + pub state: ping_pong::State, + pub message: ping_pong::Message, +} + +#[derive(Clone, Debug)] +pub struct HelperPrepareTransition< + const VERIFY_KEY_LENGTH: usize, + V: vdaf::Aggregator, +> { + pub transition: ping_pong::Transition, + pub state: ping_pong::State, + pub message: ping_pong::Message, +} + +/// A transcript of a VDAF run using the ping-pong VDAF topology. +#[derive(Clone, Debug)] +pub struct VdafTranscript< + const VERIFY_KEY_LENGTH: usize, + V: vdaf::Aggregator, +> { /// The public share, from the sharding algorithm. pub public_share: V::PublicShare, - /// The measurement's input shares, from the sharding algorithm. - pub input_shares: Vec, - /// Prepare transitions sent throughout the protocol run. The outer `Vec` is indexed by - /// aggregator, and the inner `Vec`s are indexed by VDAF round. - prepare_transitions: Vec>>, - /// The prepare messages broadcast to all aggregators prior to each continuation round of the - /// VDAF. - pub prepare_messages: Vec, - /// The output shares computed by each aggregator. - output_shares: Vec, - /// The aggregate shares from each aggregator. - pub aggregate_shares: Vec, -} + /// The leader's input share, from the sharding algorithm. + pub leader_input_share: V::InputShare, -impl> VdafTranscript { - /// Get the leader's preparation state at the requested round. 
- pub fn leader_prep_state(&self, round: usize) -> &V::PrepareState { - assert_matches!( - &self.prepare_transitions[Role::Leader.index().unwrap()][round], - PrepareTransition::::Continue(prep_state, _) => prep_state - ) - } + /// The helper's input share, from the sharding algorithm. + pub helper_input_share: V::InputShare, - /// Get the helper's preparation state and prepare share at the requested round. - pub fn helper_prep_state(&self, round: usize) -> (&V::PrepareState, &V::PrepareShare) { - assert_matches!( - &self.prepare_transitions[Role::Helper.index().unwrap()][round], - PrepareTransition::::Continue(prep_state, prep_share) => (prep_state, prep_share) - ) - } + /// The leader's states and messages computed throughout the protocol run. Indexed by the + /// aggregation job round. + #[allow(clippy::type_complexity)] + pub leader_prepare_transitions: Vec>, - /// Get the output share for the specified aggregator. - pub fn output_share(&self, role: Role) -> &V::OutputShare { - &self.output_shares[role.index().unwrap()] - } + /// The helper's states and messages computed throughout the protocol run. Indexed by the + /// aggregation job round. + #[allow(clippy::type_complexity)] + pub helper_prepare_transitions: Vec>, + + /// The leader's computed output share. + pub leader_output_share: V::OutputShare, + + /// The helper's computed output share. + pub helper_output_share: V::OutputShare, + + /// The leader's aggregate share. + pub leader_aggregate_share: V::AggregateShare, + + /// The helper's aggregate share. + pub helper_aggregate_share: V::AggregateShare, } /// run_vdaf runs a VDAF state machine from sharding through to generating an output share, /// returning a "transcript" of all states & messages. -pub fn run_vdaf + vdaf::Client<16>>( +pub fn run_vdaf< + const VERIFY_KEY_LENGTH: usize, + V: vdaf::Aggregator + vdaf::Client<16>, +>( vdaf: &V, - verify_key: &[u8; SEED_SIZE], + verify_key: &[u8; VERIFY_KEY_LENGTH], aggregation_param: &V::AggregationParam, report_id: &ReportId, measurement: &V::Measurement, -) -> VdafTranscript { +) -> VdafTranscript { + let mut leader_prepare_transitions = Vec::new(); + let mut helper_prepare_transitions = Vec::new(); + // Shard inputs into input shares, and initialize the initial PrepareTransitions. let (public_share, input_shares) = vdaf.shard(measurement, report_id.as_ref()).unwrap(); - let mut prep_trans: Vec>> = input_shares - .iter() - .enumerate() - .map(|(agg_id, input_share)| { - let (prep_state, prep_share) = vdaf.prepare_init( - verify_key, - agg_id, - aggregation_param, - report_id.as_ref(), - &public_share, - input_share, - )?; - Ok(Vec::from([PrepareTransition::Continue( - prep_state, prep_share, - )])) - }) - .collect::>>, VdafError>>() + + let (leader_state, leader_message) = vdaf + .leader_initialize( + verify_key, + aggregation_param, + report_id.as_ref(), + &public_share, + &input_shares[0], + ) + .unwrap(); + + leader_prepare_transitions.push(LeaderPrepareTransition { + transition: None, + state: leader_state, + message: leader_message.clone(), + }); + + let helper_transition = vdaf + .helper_initialize( + verify_key, + aggregation_param, + report_id.as_ref(), + &public_share, + &input_shares[1], + &leader_message, + ) .unwrap(); - let mut prep_msgs = Vec::new(); + let (helper_state, helper_message) = helper_transition.evaluate(vdaf).unwrap(); - // Repeatedly step the VDAF until we reach a terminal state. 
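+    // Record the helper's first transition. Evaluating the transition yields the
+    // helper's next ping-pong state together with its outgoing message for the
+    // leader, so the transcript captures all three for later assertions.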
+ helper_prepare_transitions.push(HelperPrepareTransition { + transition: helper_transition, + state: helper_state, + message: helper_message, + }); + + // Repeatedly step the VDAF until we reach a terminal state + let mut leader_output_share = None; + let mut helper_output_share = None; loop { - // Gather messages from last round & combine them into next round's message; if any - // participants have reached a terminal state (Finish or Fail), we are done. - let mut prep_shares = Vec::new(); - let mut agg_shares = Vec::new(); - let mut output_shares = Vec::new(); - for pts in &prep_trans { - match pts.last().unwrap() { - PrepareTransition::::Continue(_, prep_share) => { - prep_shares.push(prep_share.clone()) - } - PrepareTransition::Finish(output_share) => { - output_shares.push(output_share.clone()); - agg_shares.push( - vdaf.aggregate(aggregation_param, [output_share.clone()].into_iter()) - .unwrap(), - ); + for ping_pong_role in [ping_pong::Role::Leader, ping_pong::Role::Helper] { + let (curr_state, last_peer_message) = match ping_pong_role { + ping_pong::Role::Leader => ( + leader_prepare_transitions.last().unwrap().state.clone(), + helper_prepare_transitions.last().unwrap().message.clone(), + ), + ping_pong::Role::Helper => ( + helper_prepare_transitions.last().unwrap().state.clone(), + leader_prepare_transitions.last().unwrap().message.clone(), + ), + }; + + match (&curr_state, &last_peer_message) { + (curr_state @ ping_pong::State::Continued(_), last_peer_message) => { + let state_and_message = vdaf + .continued(ping_pong_role, curr_state.clone(), last_peer_message) + .unwrap(); + + match state_and_message { + ping_pong::ContinuedValue::WithMessage { transition } => { + let (state, message) = transition.clone().evaluate(vdaf).unwrap(); + match ping_pong_role { + ping_pong::Role::Leader => { + leader_prepare_transitions.push(LeaderPrepareTransition { + transition: Some(transition), + state, + message, + }) + } + ping_pong::Role::Helper => { + helper_prepare_transitions.push(HelperPrepareTransition { + transition, + state, + message, + }) + } + } + } + ping_pong::ContinuedValue::FinishedNoMessage { output_share } => { + match ping_pong_role { + ping_pong::Role::Leader => { + leader_output_share = Some(output_share.clone()) + } + ping_pong::Role::Helper => { + helper_output_share = Some(output_share.clone()) + } + } + } + } } + (ping_pong::State::Finished(output_share), _) => match ping_pong_role { + ping_pong::Role::Leader => leader_output_share = Some(output_share.clone()), + ping_pong::Role::Helper => helper_output_share = Some(output_share.clone()), + }, } } - if !agg_shares.is_empty() { - return VdafTranscript { - public_share, - input_shares, - prepare_transitions: prep_trans, - prepare_messages: prep_msgs, - output_shares, - aggregate_shares: agg_shares, - }; - } - let prep_msg = vdaf.prepare_preprocess(prep_shares).unwrap(); - prep_msgs.push(prep_msg.clone()); - - // Compute each participant's next transition. 
- for pts in &mut prep_trans { - let prep_state = assert_matches!( - pts.last().unwrap(), - PrepareTransition::::Continue(prep_state, _) => prep_state - ) - .clone(); - pts.push(vdaf.prepare_step(prep_state, prep_msg.clone()).unwrap()); + + if leader_output_share.is_some() && helper_output_share.is_some() { + break; } } + + let leader_aggregate_share = vdaf + .aggregate(aggregation_param, [leader_output_share.clone().unwrap()]) + .unwrap(); + let helper_aggregate_share = vdaf + .aggregate(aggregation_param, [helper_output_share.clone().unwrap()]) + .unwrap(); + + VdafTranscript { + public_share, + leader_input_share: input_shares[0].clone(), + helper_input_share: input_shares[1].clone(), + leader_prepare_transitions, + helper_prepare_transitions, + leader_output_share: leader_output_share.unwrap(), + helper_output_share: helper_output_share.unwrap(), + leader_aggregate_share, + helper_aggregate_share, + } } /// Encodes the given value to YAML, then decodes it again, and checks that the diff --git a/db/00000000000001_initial_schema.up.sql b/db/00000000000001_initial_schema.up.sql index a358a5cd1..601de44d5 100644 --- a/db/00000000000001_initial_schema.up.sql +++ b/db/00000000000001_initial_schema.up.sql @@ -208,9 +208,6 @@ CREATE TABLE report_aggregations( ord BIGINT NOT NULL, -- a value used to specify the ordering of client reports in the aggregation job state REPORT_AGGREGATION_STATE NOT NULL, -- the current state of this report aggregation prep_state BYTEA, -- the current preparation state (opaque VDAF message, only if in state WAITING) - prep_msg BYTEA, -- for the leader, the next preparation message to be sent to the helper (opaque VDAF message) - -- for the helper, the next preparation share to be sent to the leader (opaque VDAF message) - -- only non-NULL if in state WAITING error_code SMALLINT, -- error code corresponding to a DAP ReportShareError value; null if in a state other than FAILED last_prep_step BYTEA, -- the last PreparationStep message sent to the Leader, to assist in replay (opaque VDAF message, populated for Helper only) diff --git a/integration_tests/Cargo.toml b/integration_tests/Cargo.toml index a06e0a0b6..3c6669c9e 100644 --- a/integration_tests/Cargo.toml +++ b/integration_tests/Cargo.toml @@ -18,6 +18,7 @@ base64 = "0.21.2" futures = "0.3.28" hex = "0.4" http = "0.2" +itertools.workspace = true janus_aggregator = { workspace = true, features = ["test-util"] } janus_aggregator_core = { workspace = true, features = ["test-util"] } janus_client.workspace = true @@ -37,6 +38,5 @@ tokio.workspace = true url = { version = "2.4.0", features = ["serde"] } [dev-dependencies] -itertools.workspace = true janus_collector = { workspace = true, features = ["test-util"] } tempfile = "3" diff --git a/messages/Cargo.toml b/messages/Cargo.toml index d3c3e2397..4b2e5e1e3 100644 --- a/messages/Cargo.toml +++ b/messages/Cargo.toml @@ -20,7 +20,9 @@ hex = "0.4" num_enum = "0.7.0" # We can't pull prio in from the workspace because that would enable default features, and we do not # want prio/crypto-dependencies -prio = { version = "0.14.1", features = ["multithreaded"] } +# TODO(timg): go back to a released version of prio +#prio = { version = "0.14.1", features = ["multithreaded"] } +prio = { git = "https://github.com/divviup/libprio-rs", branch = "timg/ping-pong-topology", features = ["multithreaded", "experimental"] } rand = "0.8" serde.workspace = true thiserror.workspace = true diff --git a/messages/src/lib.rs b/messages/src/lib.rs index 816565b1f..34adf35c3 100644 --- 
a/messages/src/lib.rs +++ b/messages/src/lib.rs @@ -8,9 +8,12 @@ use anyhow::anyhow; use base64::{display::Base64Display, engine::general_purpose::URL_SAFE_NO_PAD, Engine}; use derivative::Derivative; use num_enum::TryFromPrimitive; -use prio::codec::{ - decode_u16_items, decode_u32_items, encode_u16_items, encode_u32_items, CodecError, Decode, - Encode, +use prio::{ + codec::{ + decode_u16_items, decode_u32_items, encode_u16_items, encode_u32_items, CodecError, Decode, + Encode, + }, + topology::ping_pong, }; use rand::{distributions::Standard, prelude::Distribution, Rng}; use serde::{ @@ -2111,14 +2114,65 @@ impl Decode for ReportShare { } } -/// DAP protocol message representing the result of a preparation step in a VDAF evaluation. +/// DAP protocol message representing information required to initialize preparation of a report for +/// aggregation. #[derive(Clone, Debug, PartialEq, Eq)] -pub struct PrepareStep { +pub struct PrepareInit { + report_share: ReportShare, + message: ping_pong::Message, +} + +impl PrepareInit { + /// Constructs a new preparation initialization message from its components. + pub fn new(report_share: ReportShare, message: ping_pong::Message) -> Self { + Self { + report_share, + message, + } + } + + /// Gets the report share associated with this prep init. + pub fn report_share(&self) -> &ReportShare { + &self.report_share + } + + /// Gets the message associated with this prep init. + pub fn message(&self) -> &ping_pong::Message { + &self.message + } +} + +impl Encode for PrepareInit { + fn encode(&self, bytes: &mut Vec) { + self.report_share.encode(bytes); + self.message.encode(bytes); + } + + fn encoded_len(&self) -> Option { + Some(self.report_share.encoded_len()? + self.message.encoded_len()?) + } +} + +impl Decode for PrepareInit { + fn decode(bytes: &mut Cursor<&[u8]>) -> Result { + let report_share = ReportShare::decode(bytes)?; + let message = ping_pong::Message::decode(bytes)?; + + Ok(Self { + report_share, + message, + }) + } +} + +/// DAP protocol message representing the response to a preparation step in a VDAF evaluation. +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct PrepareResp { report_id: ReportId, result: PrepareStepResult, } -impl PrepareStep { +impl PrepareResp { /// Constructs a new prepare step from its components. pub fn new(report_id: ReportId, result: PrepareStepResult) -> Self { Self { report_id, result } @@ -2135,7 +2189,7 @@ impl PrepareStep { } } -impl Encode for PrepareStep { +impl Encode for PrepareResp { fn encode(&self, bytes: &mut Vec) { self.report_id.encode(bytes); self.result.encode(bytes); @@ -2146,7 +2200,7 @@ impl Encode for PrepareStep { } } -impl Decode for PrepareStep { +impl Decode for PrepareResp { fn decode(bytes: &mut Cursor<&[u8]>) -> Result { let report_id = ReportId::decode(bytes)?; let result = PrepareStepResult::decode(bytes)?; @@ -2156,13 +2210,16 @@ impl Decode for PrepareStep { } /// DAP protocol message representing result-type-specific data associated with a preparation step -/// in a VDAF evaluation. Included in a PrepareStep message. +/// in a VDAF evaluation. Included in a PrepareResp message. 
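+///
+/// `Continue` carries the next ping-pong message for the peer, `Finished` indicates that
+/// preparation completed with no further message needed, and `Reject` carries a
+/// `PrepareError` saying why the report share was dropped.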
#[derive(Clone, Derivative, PartialEq, Eq)] #[derivative(Debug)] pub enum PrepareStepResult { - Continued(#[derivative(Debug = "ignore")] Vec), // content is a serialized preparation message + Continue { + #[derivative(Debug = "ignore")] + message: ping_pong::Message, + }, Finished, - Failed(ReportShareError), + Reject(PrepareError), } impl Encode for PrepareStepResult { @@ -2170,12 +2227,12 @@ impl Encode for PrepareStepResult { // The encoding includes an implicit discriminator byte, called PrepareStepResult in the // DAP spec. match self { - Self::Continued(vdaf_msg) => { + Self::Continue { message: prep_msg } => { 0u8.encode(bytes); - encode_u32_items(bytes, &(), vdaf_msg); + prep_msg.encode(bytes); } Self::Finished => 1u8.encode(bytes), - Self::Failed(error) => { + Self::Reject(error) => { 2u8.encode(bytes); error.encode(bytes); } @@ -2184,9 +2241,9 @@ impl Encode for PrepareStepResult { fn encoded_len(&self) -> Option { match self { - PrepareStepResult::Continued(vdaf_msg) => Some(1 + 4 + vdaf_msg.len()), - PrepareStepResult::Finished => Some(1), - PrepareStepResult::Failed(error) => Some(1 + error.encoded_len()?), + Self::Continue { message: prep_msg } => Some(1 + prep_msg.encoded_len()?), + Self::Finished => Some(1), + Self::Reject(error) => Some(1 + error.encoded_len()?), } } } @@ -2195,9 +2252,12 @@ impl Decode for PrepareStepResult { fn decode(bytes: &mut Cursor<&[u8]>) -> Result { let val = u8::decode(bytes)?; Ok(match val { - 0 => Self::Continued(decode_u32_items(&(), bytes)?), + 0 => { + let prep_msg = ping_pong::Message::decode(bytes)?; + Self::Continue { message: prep_msg } + } 1 => Self::Finished, - 2 => Self::Failed(ReportShareError::decode(bytes)?), + 2 => Self::Reject(PrepareError::decode(bytes)?), _ => return Err(CodecError::UnexpectedValue), }) } @@ -2206,7 +2266,7 @@ impl Decode for PrepareStepResult { /// DAP protocol message representing an error while preparing a report share for aggregation. #[derive(Clone, Copy, Debug, PartialEq, Eq, TryFromPrimitive)] #[repr(u8)] -pub enum ReportShareError { +pub enum PrepareError { BatchCollected = 0, ReportReplayed = 1, ReportDropped = 2, @@ -2218,7 +2278,7 @@ pub enum ReportShareError { UnrecognizedMessage = 8, } -impl Encode for ReportShareError { +impl Encode for PrepareError { fn encode(&self, bytes: &mut Vec) { (*self as u8).encode(bytes); } @@ -2228,7 +2288,7 @@ impl Encode for ReportShareError { } } -impl Decode for ReportShareError { +impl Decode for PrepareError { fn decode(bytes: &mut Cursor<&[u8]>) -> Result { let val = u8::decode(bytes)?; Self::try_from(val).map_err(|_| { @@ -2237,6 +2297,51 @@ impl Decode for ReportShareError { } } +/// DAP protocol message representing a request to continue preparation of a report share for +/// aggregation. +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct PrepareContinue { + report_id: ReportId, + message: ping_pong::Message, +} + +impl PrepareContinue { + /// Constructs a new prepare continue from its components. + pub fn new(report_id: ReportId, message: ping_pong::Message) -> Self { + Self { report_id, message } + } + + /// Gets the report ID associated with this prepare continue. + pub fn report_id(&self) -> &ReportId { + &self.report_id + } + + /// Gets the message associated with this prepare continue. 
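+    ///
+    /// The embedded message is defined by the VDAF ping-pong topology rather than by
+    /// DAP itself.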
+ pub fn message(&self) -> &ping_pong::Message { + &self.message + } +} + +impl Encode for PrepareContinue { + fn encode(&self, bytes: &mut Vec) { + self.report_id.encode(bytes); + self.message.encode(bytes); + } + + fn encoded_len(&self) -> Option { + Some(self.report_id.encoded_len()? + self.message.encoded_len()?) + } +} + +impl Decode for PrepareContinue { + fn decode(bytes: &mut Cursor<&[u8]>) -> Result { + let report_id = ReportId::decode(bytes)?; + let message = ping_pong::Message::decode(bytes)?; + + Ok(Self { report_id, message }) + } +} + /// DAP protocol message representing an identifier for an aggregation job. #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] pub struct AggregationJobId([u8; Self::LEN]); @@ -2309,7 +2414,7 @@ pub struct AggregationJobInitializeReq { #[derivative(Debug = "ignore")] aggregation_parameter: Vec, partial_batch_selector: PartialBatchSelector, - report_shares: Vec, + prepare_inits: Vec, } impl AggregationJobInitializeReq { @@ -2320,12 +2425,12 @@ impl AggregationJobInitializeReq { pub fn new( aggregation_parameter: Vec, partial_batch_selector: PartialBatchSelector, - report_shares: Vec, + prepare_inits: Vec, ) -> Self { Self { aggregation_parameter, partial_batch_selector, - report_shares, + prepare_inits, } } @@ -2339,9 +2444,10 @@ impl AggregationJobInitializeReq { &self.partial_batch_selector } - /// Gets the report shares associated with this aggregate initialization request. - pub fn report_shares(&self) -> &[ReportShare] { - &self.report_shares + /// Gets the preparation initialization messages associated with this aggregate initialization + /// request. + pub fn prepare_inits(&self) -> &[PrepareInit] { + &self.prepare_inits } } @@ -2349,15 +2455,15 @@ impl Encode for AggregationJobInitializeReq { fn encode(&self, bytes: &mut Vec) { encode_u32_items(bytes, &(), &self.aggregation_parameter); self.partial_batch_selector.encode(bytes); - encode_u32_items(bytes, &(), &self.report_shares); + encode_u32_items(bytes, &(), &self.prepare_inits); } fn encoded_len(&self) -> Option { let mut length = 4 + self.aggregation_parameter.len(); length += self.partial_batch_selector.encoded_len()?; length += 4; - for report_share in self.report_shares.iter() { - length += report_share.encoded_len()?; + for prepare_init in &self.prepare_inits { + length += prepare_init.encoded_len()?; } Some(length) } @@ -2367,12 +2473,12 @@ impl Decode for AggregationJobInitializeReq { fn decode(bytes: &mut Cursor<&[u8]>) -> Result { let aggregation_parameter = decode_u32_items(&(), bytes)?; let partial_batch_selector = PartialBatchSelector::decode(bytes)?; - let report_shares = decode_u32_items(&(), bytes)?; + let prepare_inits = decode_u32_items(&(), bytes)?; Ok(Self { aggregation_parameter, partial_batch_selector, - report_shares, + prepare_inits, }) } } @@ -2438,7 +2544,7 @@ impl TryFrom for AggregationJobRound { #[derive(Clone, Debug, PartialEq, Eq)] pub struct AggregationJobContinueReq { round: AggregationJobRound, - prepare_steps: Vec, + prepare_continues: Vec, } impl AggregationJobContinueReq { @@ -2446,10 +2552,10 @@ impl AggregationJobContinueReq { pub const MEDIA_TYPE: &'static str = "application/dap-aggregation-job-continue-req"; /// Constructs a new aggregate continuation response from its components. 
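+    /// `prepare_continues` carries one ping-pong message per report the leader is still
+    /// preparing; reports it has already finished or failed are omitted.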
- pub fn new(round: AggregationJobRound, prepare_steps: Vec) -> Self { + pub fn new(round: AggregationJobRound, prepare_continues: Vec) -> Self { Self { round, - prepare_steps, + prepare_continues, } } @@ -2459,22 +2565,22 @@ impl AggregationJobContinueReq { } /// Gets the prepare steps associated with this aggregate continuation response. - pub fn prepare_steps(&self) -> &[PrepareStep] { - &self.prepare_steps + pub fn prepare_steps(&self) -> &[PrepareContinue] { + &self.prepare_continues } } impl Encode for AggregationJobContinueReq { fn encode(&self, bytes: &mut Vec) { self.round.encode(bytes); - encode_u32_items(bytes, &(), &self.prepare_steps); + encode_u32_items(bytes, &(), &self.prepare_continues); } fn encoded_len(&self) -> Option { let mut length = self.round.encoded_len()?; length += 4; - for prepare_step in self.prepare_steps.iter() { - length += prepare_step.encoded_len()?; + for prepare_continue in self.prepare_continues.iter() { + length += prepare_continue.encoded_len()?; } Some(length) } @@ -2483,8 +2589,8 @@ impl Encode for AggregationJobContinueReq { impl Decode for AggregationJobContinueReq { fn decode(bytes: &mut Cursor<&[u8]>) -> Result { let round = AggregationJobRound::decode(bytes)?; - let prepare_steps = decode_u32_items(&(), bytes)?; - Ok(Self::new(round, prepare_steps)) + let prepare_continues = decode_u32_items(&(), bytes)?; + Ok(Self::new(round, prepare_continues)) } } @@ -2492,7 +2598,7 @@ impl Decode for AggregationJobContinueReq { /// continuation request. #[derive(Clone, Debug, PartialEq, Eq)] pub struct AggregationJobResp { - prepare_steps: Vec, + prepare_resps: Vec, } impl AggregationJobResp { @@ -2500,25 +2606,25 @@ impl AggregationJobResp { pub const MEDIA_TYPE: &'static str = "application/dap-aggregation-job-resp"; /// Constructs a new aggregate continuation response from its components. - pub fn new(prepare_steps: Vec) -> Self { - Self { prepare_steps } + pub fn new(prepare_resps: Vec) -> Self { + Self { prepare_resps } } - /// Gets the prepare steps associated with this aggregate continuation response. - pub fn prepare_steps(&self) -> &[PrepareStep] { - &self.prepare_steps + /// Gets the prepare responses associated with this aggregate continuation response. 
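+    ///
+    /// Responses correspond one-to-one, in order, with the prepare steps of the request
+    /// that produced this response.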
+ pub fn prepare_resps(&self) -> &[PrepareResp] { + &self.prepare_resps } } impl Encode for AggregationJobResp { fn encode(&self, bytes: &mut Vec) { - encode_u32_items(bytes, &(), &self.prepare_steps); + encode_u32_items(bytes, &(), &self.prepare_resps); } fn encoded_len(&self) -> Option { let mut length = 4; - for prepare_step in self.prepare_steps.iter() { - length += prepare_step.encoded_len()?; + for prepare_resp in self.prepare_resps.iter() { + length += prepare_resp.encoded_len()?; } Some(length) } @@ -2526,8 +2632,8 @@ impl Encode for AggregationJobResp { impl Decode for AggregationJobResp { fn decode(bytes: &mut Cursor<&[u8]>) -> Result { - let prepare_steps = decode_u32_items(&(), bytes)?; - Ok(Self { prepare_steps }) + let prepare_resps = decode_u32_items(&(), bytes)?; + Ok(Self { prepare_resps }) } } @@ -2766,12 +2872,15 @@ mod tests { AggregationJobRound, BatchId, BatchSelector, Collection, CollectionReq, Duration, Extension, ExtensionType, FixedSize, FixedSizeQuery, HpkeAeadId, HpkeCiphertext, HpkeConfig, HpkeConfigId, HpkeKdfId, HpkeKemId, HpkePublicKey, InputShareAad, Interval, - PartialBatchSelector, PlaintextInputShare, PrepareStep, PrepareStepResult, Query, Report, - ReportId, ReportIdChecksum, ReportMetadata, ReportShare, ReportShareError, Role, TaskId, - Time, TimeInterval, Url, + PartialBatchSelector, PlaintextInputShare, PrepareContinue, PrepareError, PrepareInit, + PrepareResp, PrepareStepResult, Query, Report, ReportId, ReportIdChecksum, ReportMetadata, + ReportShare, Role, TaskId, Time, TimeInterval, Url, }; use assert_matches::assert_matches; - use prio::codec::{CodecError, Decode, Encode}; + use prio::{ + codec::{CodecError, Decode, Encode}, + topology::ping_pong, + }; use serde_test::{assert_de_tokens_error, assert_tokens, Token}; #[test] @@ -3877,72 +3986,97 @@ mod tests { } #[test] - fn roundtrip_prepare_step() { + fn roundtrip_report_share() { roundtrip_encoding(&[ ( - PrepareStep { - report_id: ReportId::from([ - 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, - ]), - result: PrepareStepResult::Continued(Vec::from("012345")), + ReportShare { + metadata: ReportMetadata::new( + ReportId::from([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]), + Time::from_seconds_since_epoch(54321), + ), + public_share: Vec::new(), + encrypted_input_share: HpkeCiphertext::new( + HpkeConfigId::from(42), + Vec::from("012345"), + Vec::from("543210"), + ), }, concat!( - "0102030405060708090A0B0C0D0E0F10", // report_id - "00", // prepare_step_result concat!( - // vdaf_msg - "00000006", // length - "303132333435", // opaque data + // metadata + "0102030405060708090A0B0C0D0E0F10", // report_id + "000000000000D431", // time + ), + concat!( + // public_share + "00000000", // length + "", // opaque data + ), + concat!( + // encrypted_input_share + "2A", // config_id + concat!( + // encapsulated_context + "0006", // length + "303132333435", // opaque data + ), + concat!( + // payload + "00000006", // length + "353433323130", // opaque data + ), ), ), ), ( - PrepareStep { - report_id: ReportId::from([ - 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, - ]), - result: PrepareStepResult::Finished, - }, - concat!( - "100F0E0D0C0B0A090807060504030201", // report_id - "01", // prepare_step_result - ), - ), - ( - PrepareStep { - report_id: ReportId::from([255; 16]), - result: PrepareStepResult::Failed(ReportShareError::VdafPrepError), + ReportShare { + metadata: ReportMetadata::new( + ReportId::from([16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1]), + 
Time::from_seconds_since_epoch(73542), + ), + public_share: Vec::from("0123"), + encrypted_input_share: HpkeCiphertext::new( + HpkeConfigId::from(13), + Vec::from("abce"), + Vec::from("abfd"), + ), }, concat!( - "FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF", // report_id - "02", // prepare_step_result - "05", // report_share_error + concat!( + // metadata + "100F0E0D0C0B0A090807060504030201", // report_id + "0000000000011F46", // time + ), + concat!( + // public_share + "00000004", // length + "30313233", // opaque data + ), + concat!( + // encrypted_input_share + "0D", // config_id + concat!( + // encapsulated_context + "0004", // length + "61626365", // opaque data + ), + concat!( + // payload + "00000004", // length + "61626664", // opaque data + ), + ), ), ), ]) } #[test] - fn roundtrip_report_share_error() { + fn roundtrip_prepare_init() { roundtrip_encoding(&[ - (ReportShareError::BatchCollected, "00"), - (ReportShareError::ReportReplayed, "01"), - (ReportShareError::ReportDropped, "02"), - (ReportShareError::HpkeUnknownConfigId, "03"), - (ReportShareError::HpkeDecryptError, "04"), - (ReportShareError::VdafPrepError, "05"), - ]) - } - - #[test] - fn roundtrip_aggregation_job_initialize_req() { - // TimeInterval. - roundtrip_encoding(&[( - AggregationJobInitializeReq { - aggregation_parameter: Vec::from("012345"), - partial_batch_selector: PartialBatchSelector::new_time_interval(), - report_shares: Vec::from([ - ReportShare { + ( + PrepareInit { + report_share: ReportShare { metadata: ReportMetadata::new( ReportId::from([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]), Time::from_seconds_since_epoch(54321), @@ -3954,34 +4088,13 @@ mod tests { Vec::from("543210"), ), }, - ReportShare { - metadata: ReportMetadata::new( - ReportId::from([16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1]), - Time::from_seconds_since_epoch(73542), - ), - public_share: Vec::from("0123"), - encrypted_input_share: HpkeCiphertext::new( - HpkeConfigId::from(13), - Vec::from("abce"), - Vec::from("abfd"), - ), + message: ping_pong::Message::Initialize { + prep_share: Vec::from("012345"), }, - ]), - }, - concat!( - concat!( - // aggregation_parameter - "00000006", // length - "303132333435", // opaque data - ), - concat!( - // partial_batch_selector - "01", // query_type - ), + }, concat!( - // report_shares - "0000005E", // length concat!( + // report_share concat!( // metadata "0102030405060708090A0B0C0D0E0F10", // report_id @@ -4008,13 +4121,44 @@ mod tests { ), ), concat!( + // message + "00", // Message type + concat!( + "00000006", // length + "303132333435", // opaque data + ) + ) + ), + ), + ( + PrepareInit { + report_share: ReportShare { + metadata: ReportMetadata::new( + ReportId::from([16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1]), + Time::from_seconds_since_epoch(73542), + ), + public_share: Vec::from("0123"), + encrypted_input_share: HpkeCiphertext::new( + HpkeConfigId::from(13), + Vec::from("abce"), + Vec::from("abfd"), + ), + }, + message: ping_pong::Message::Finish { + prep_msg: Vec::new(), + }, + }, + concat!( + concat!( + // report_share concat!( // metadata "100F0E0D0C0B0A090807060504030201", // report_id "0000000000011F46", // time ), concat!( - "00000004", // payload + // public_share + "00000004", // length "30313233", // opaque data ), concat!( @@ -4032,6 +4176,223 @@ mod tests { ), ), ), + concat!( + // message + "02", // Message type + concat!( + "00000000", // length + "" // opaque data + ) + ) + ), + ), + ]) + } + + #[test] + fn roundtrip_prepare_resp() { + 
roundtrip_encoding(&[
+            (
+                PrepareResp {
+                    report_id: ReportId::from([
+                        1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
+                    ]),
+                    result: PrepareStepResult::Continue {
+                        message: ping_pong::Message::Continue {
+                            prep_msg: Vec::from("012345"),
+                            prep_share: Vec::from("6789"),
+                        },
+                    },
+                },
+                concat!(
+                    "0102030405060708090A0B0C0D0E0F10", // report_id
+                    "00", // prepare_step_result
+                    concat!(
+                        // message
+                        "01", // message type
+                        concat!(
+                            "00000006", // length
+                            "303132333435", // opaque data
+                        ),
+                        concat!(
+                            "00000004", // length
+                            "36373839", // opaque data
+                        )
+                    ),
+                ),
+            ),
+            (
+                PrepareResp {
+                    report_id: ReportId::from([
+                        16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1,
+                    ]),
+                    result: PrepareStepResult::Finished,
+                },
+                concat!(
+                    "100F0E0D0C0B0A090807060504030201", // report_id
+                    "01", // prepare_step_result
+                ),
+            ),
+            (
+                PrepareResp {
+                    report_id: ReportId::from([255; 16]),
+                    result: PrepareStepResult::Reject(PrepareError::VdafPrepError),
+                },
+                concat!(
+                    "FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF", // report_id
+                    "02", // prepare_step_result
+                    "05", // prepare_error
+                ),
+            ),
+        ])
+    }
+
+    #[test]
+    fn roundtrip_prepare_error() {
+        roundtrip_encoding(&[
+            (PrepareError::BatchCollected, "00"),
+            (PrepareError::ReportReplayed, "01"),
+            (PrepareError::ReportDropped, "02"),
+            (PrepareError::HpkeUnknownConfigId, "03"),
+            (PrepareError::HpkeDecryptError, "04"),
+            (PrepareError::VdafPrepError, "05"),
+        ])
+    }
+
+    #[test]
+    fn roundtrip_aggregation_job_initialize_req() {
+        // TimeInterval.
+        roundtrip_encoding(&[(
+            AggregationJobInitializeReq {
+                aggregation_parameter: Vec::from("012345"),
+                partial_batch_selector: PartialBatchSelector::new_time_interval(),
+                prepare_inits: Vec::from([
+                    PrepareInit {
+                        report_share: ReportShare {
+                            metadata: ReportMetadata::new(
+                                ReportId::from([
+                                    1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
+                                ]),
+                                Time::from_seconds_since_epoch(54321),
+                            ),
+                            public_share: Vec::new(),
+                            encrypted_input_share: HpkeCiphertext::new(
+                                HpkeConfigId::from(42),
+                                Vec::from("012345"),
+                                Vec::from("543210"),
+                            ),
+                        },
+                        message: ping_pong::Message::Initialize {
+                            prep_share: Vec::from("012345"),
+                        },
+                    },
+                    PrepareInit {
+                        report_share: ReportShare {
+                            metadata: ReportMetadata::new(
+                                ReportId::from([
+                                    16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1,
+                                ]),
+                                Time::from_seconds_since_epoch(73542),
+                            ),
+                            public_share: Vec::from("0123"),
+                            encrypted_input_share: HpkeCiphertext::new(
+                                HpkeConfigId::from(13),
+                                Vec::from("abce"),
+                                Vec::from("abfd"),
+                            ),
+                        },
+                        message: ping_pong::Message::Finish {
+                            prep_msg: Vec::new(),
+                        },
+                    },
+                ]),
+            },
+            concat!(
+                concat!(
+                    // aggregation_parameter
+                    "00000006", // length
+                    "303132333435", // opaque data
+                ),
+                concat!(
+                    // partial_batch_selector
+                    "01", // query_type
+                ),
+                concat!(
+                    // prepare_inits
+                    "0000006E", // length
+                    concat!(
+                        concat!(
+                            // report_share
+                            concat!(
+                                // metadata
+                                "0102030405060708090A0B0C0D0E0F10", // report_id
+                                "000000000000D431", // time
+                            ),
+                            concat!(
+                                // public_share
+                                "00000000", // length
+                                "", // opaque data
+                            ),
+                            concat!(
+                                // encrypted_input_share
+                                "2A", // config_id
+                                concat!(
+                                    // encapsulated_context
+                                    "0006", // length
+                                    "303132333435", // opaque data
+                                ),
+                                concat!(
+                                    // payload
+                                    "00000006", // length
+                                    "353433323130", // opaque data
+                                ),
+                            ),
+                        ),
+                        concat!(
+                            // message
+                            "00", // Message type
+                            concat!(
+                                "00000006", // length
+                                "303132333435", // opaque data
+                            ),
+                        )
+                    ),
+                    concat!(
+                        concat!(
+                            concat!(
+                                // metadata
"100F0E0D0C0B0A090807060504030201", // report_id + "0000000000011F46", // time + ), + concat!( + // public_share + "00000004", // length + "30313233", // opaque data + ), + concat!( + // encrypted_input_share + "0D", // config_id + concat!( + // encapsulated_context + "0004", // length + "61626365", // opaque data + ), + concat!( + // payload + "00000004", // length + "61626664", // opaque data + ), + ), + ), + concat!( + // message + "02", // Message type + concat!( + "00000000", // length + "" // opaque data + ) + ) + ), ), ), )]); @@ -4043,30 +4404,44 @@ mod tests { partial_batch_selector: PartialBatchSelector::new_fixed_size(BatchId::from( [2u8; 32], )), - report_shares: Vec::from([ - ReportShare { - metadata: ReportMetadata::new( - ReportId::from([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]), - Time::from_seconds_since_epoch(54321), - ), - public_share: Vec::new(), - encrypted_input_share: HpkeCiphertext::new( - HpkeConfigId::from(42), - Vec::from("012345"), - Vec::from("543210"), - ), + prepare_inits: Vec::from([ + PrepareInit { + report_share: ReportShare { + metadata: ReportMetadata::new( + ReportId::from([ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, + ]), + Time::from_seconds_since_epoch(54321), + ), + public_share: Vec::new(), + encrypted_input_share: HpkeCiphertext::new( + HpkeConfigId::from(42), + Vec::from("012345"), + Vec::from("543210"), + ), + }, + message: ping_pong::Message::Initialize { + prep_share: Vec::from("012345"), + }, }, - ReportShare { - metadata: ReportMetadata::new( - ReportId::from([16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1]), - Time::from_seconds_since_epoch(73542), - ), - public_share: Vec::from("0123"), - encrypted_input_share: HpkeCiphertext::new( - HpkeConfigId::from(13), - Vec::from("abce"), - Vec::from("abfd"), - ), + PrepareInit { + report_share: ReportShare { + metadata: ReportMetadata::new( + ReportId::from([ + 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, + ]), + Time::from_seconds_since_epoch(73542), + ), + public_share: Vec::from("0123"), + encrypted_input_share: HpkeCiphertext::new( + HpkeConfigId::from(13), + Vec::from("abce"), + Vec::from("abfd"), + ), + }, + message: ping_pong::Message::Finish { + prep_msg: Vec::new(), + }, }, ]), }, @@ -4082,58 +4457,79 @@ mod tests { "0202020202020202020202020202020202020202020202020202020202020202", // batch_id ), concat!( - // report_shares - "0000005E", // length + // prepare_inits + "0000006E", // length concat!( concat!( - // metadata - "0102030405060708090A0B0C0D0E0F10", // report_id - "000000000000D431", // time - ), - concat!( - // public_share - "00000000", // length - "", // opaque data - ), - concat!( - // encrypted_input_share - "2A", // config_id + // report_share concat!( - // encapsulated_context - "0006", // length - "303132333435", // opaque data + // metadata + "0102030405060708090A0B0C0D0E0F10", // report_id + "000000000000D431", // time ), concat!( - // payload - "00000006", // length - "353433323130", // opaque data + // public_share + "00000000", // length + "", // opaque data + ), + concat!( + // encrypted_input_share + "2A", // config_id + concat!( + // encapsulated_context + "0006", // length + "303132333435", // opaque data + ), + concat!( + // payload + "00000006", // length + "353433323130", // opaque data + ), ), - ), - ), - concat!( - concat!( - // metadata - "100F0E0D0C0B0A090807060504030201", // report_id - "0000000000011F46", // time ), concat!( - // public_share - "00000004", // length - "30313233", // opaque data + // payload + 
"00", // Message type + concat!( + "00000006", // length + "303132333435", // opaque data + ) ), + ), + concat!( concat!( - // encrypted_input_share - "0D", // config_id concat!( - // encapsulated_context - "0004", // length - "61626365", // opaque data + // metadata + "100F0E0D0C0B0A090807060504030201", // report_id + "0000000000011F46", // time ), concat!( - // payload + // public_share "00000004", // length - "61626664", // opaque data + "30313233", // opaque data ), + concat!( + // encrypted_input_share + "0D", // config_id + concat!( + // encapsulated_context + "0004", // length + "61626365", // opaque data + ), + concat!( + // payload + "00000004", // length + "61626664", // opaque data + ), + ), + ), + concat!( + // payload + "02", // Message type + concat!( + "00000000", // length + "", // opaque data + ) ), ), ), @@ -4146,18 +4542,22 @@ mod tests { roundtrip_encoding(&[( AggregationJobContinueReq { round: AggregationJobRound(42405), - prepare_steps: Vec::from([ - PrepareStep { + prepare_continues: Vec::from([ + PrepareContinue { report_id: ReportId::from([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, ]), - result: PrepareStepResult::Continued(Vec::from("012345")), + message: ping_pong::Message::Initialize { + prep_share: Vec::from("012345"), + }, }, - PrepareStep { + PrepareContinue { report_id: ReportId::from([ 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, ]), - result: PrepareStepResult::Finished, + message: ping_pong::Message::Initialize { + prep_share: Vec::from("012345"), + }, }, ]), }, @@ -4165,19 +4565,28 @@ mod tests { "A5A5", // round concat!( // prepare_steps - "0000002C", // length + "00000036", // length concat!( "0102030405060708090A0B0C0D0E0F10", // report_id - "00", // prepare_step_result concat!( // payload - "00000006", // length - "303132333435", // opaque data + "00", // Message type + concat!( + "00000006", // length + "303132333435", // opaque data + ) ), ), concat!( "100F0E0D0C0B0A090807060504030201", // report_id - "01", // prepare_step_result + concat!( + // payload + "00", // Message type + concat!( + "00000006", // length + "303132333435", // opaque data + ) + ), ) ), ), @@ -4188,14 +4597,19 @@ mod tests { fn roundtrip_aggregation_job_resp() { roundtrip_encoding(&[( AggregationJobResp { - prepare_steps: Vec::from([ - PrepareStep { + prepare_resps: Vec::from([ + PrepareResp { report_id: ReportId::from([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, ]), - result: PrepareStepResult::Continued(Vec::from("012345")), + result: PrepareStepResult::Continue { + message: ping_pong::Message::Continue { + prep_msg: Vec::from("01234"), + prep_share: Vec::from("56789"), + }, + }, }, - PrepareStep { + PrepareResp { report_id: ReportId::from([ 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, ]), @@ -4205,14 +4619,22 @@ mod tests { }, concat!(concat!( // prepare_steps - "0000002C", // length + "00000035", // length concat!( "0102030405060708090A0B0C0D0E0F10", // report_id "00", // prepare_step_result concat!( - // payload - "00000006", // length - "303132333435", // opaque data + "01", // Message type + concat!( + // prep_msg + "00000005", // length + "3031323334", // opaque data + ), + concat!( + // prep_share + "00000005", // length + "3536373839", // opaque data + ) ), ), concat!(