diff --git a/aggregator/src/aggregator.rs b/aggregator/src/aggregator.rs index 0c464da0f..cd98364f8 100644 --- a/aggregator/src/aggregator.rs +++ b/aggregator/src/aggregator.rs @@ -1776,13 +1776,7 @@ impl VdafOps { &input_share, prepare_init.message(), ) - .and_then(|transition| { - transition - .evaluate(vdaf) - .map(|(ping_pong_state, outgoing_message)| { - (transition, ping_pong_state, outgoing_message) - }) - }) + .and_then(|transition| transition.evaluate(vdaf)) .map_err(|error| { handle_ping_pong_error( task.id(), @@ -1796,18 +1790,18 @@ impl VdafOps { }); let (report_aggregation_state, prepare_step_result) = match init_rslt { - Ok((transition, PingPongState::Continued(_), outgoing_message)) => { + Ok((PingPongState::Continued(prep_state), outgoing_message)) => { // Helper is not finished. Await the next message from the Leader to advance to // the next round. saw_continue = true; ( - ReportAggregationState::Waiting(transition), + ReportAggregationState::WaitingHelper(prep_state), PrepareStepResult::Continue { message: outgoing_message, }, ) } - Ok((_, PingPongState::Finished(output_share), outgoing_message)) => { + Ok((PingPongState::Finished(output_share), outgoing_message)) => { // Helper finished. Unlike the Leader, the Helper does not wait for confirmation // that the Leader finished before accumulating its output share. 
accumulator.update( diff --git a/aggregator/src/aggregator/aggregate_init_tests.rs b/aggregator/src/aggregator/aggregate_init_tests.rs index 4beaf8cc9..9bc6e7c9b 100644 --- a/aggregator/src/aggregator/aggregate_init_tests.rs +++ b/aggregator/src/aggregator/aggregate_init_tests.rs @@ -1,6 +1,9 @@ use crate::aggregator::{ - http_handlers::aggregator_handler, tests::generate_helper_report_share, Config, + http_handlers::{aggregator_handler, test_util::decode_response_body}, + tests::generate_helper_report_share, + Config, }; +use assert_matches::assert_matches; use janus_aggregator_core::{ datastore::{ test_util::{ephemeral_datastore, EphemeralDatastore}, @@ -15,8 +18,8 @@ use janus_core::{ time::{Clock, MockClock, TimeExt as _}, }; use janus_messages::{ - query_type::TimeInterval, AggregationJobId, AggregationJobInitializeReq, PartialBatchSelector, - PrepareInit, ReportMetadata, Role, + query_type::TimeInterval, AggregationJobId, AggregationJobInitializeReq, AggregationJobResp, + PartialBatchSelector, PrepareInit, PrepareStepResult, ReportMetadata, Role, }; use prio::{ codec::Encode, @@ -24,6 +27,7 @@ use prio::{ vdaf::{ self, poplar1::{Poplar1, Poplar1AggregationParam}, + prg::PrgSha3, }, }; use rand::random; @@ -115,6 +119,7 @@ pub(super) struct AggregationJobInitTestCase< pub(super) prepare_inits: Vec, pub(super) aggregation_job_id: AggregationJobId, aggregation_job_init_req: AggregationJobInitializeReq, + aggregation_job_init_resp: Option, pub(super) aggregation_param: V::AggregationParam, pub(super) handler: Box, pub(super) datastore: Arc>, @@ -131,6 +136,20 @@ pub(super) async fn setup_aggregate_init_test() -> AggregationJobInitTestCase<0, .await } +async fn setup_poplar1_aggregate_init_test() -> AggregationJobInitTestCase<16, Poplar1> +{ + let aggregation_param = + Poplar1AggregationParam::try_from_prefixes(Vec::from([IdpfInput::from_bools(&[false])])) + .unwrap(); + setup_aggregate_init_test_for_vdaf( + Poplar1::new_sha3(1), + VdafInstance::Poplar1 { bits: 1 
}, + aggregation_param, + IdpfInput::from_bools(&[true]), + ) + .await +} + async fn setup_aggregate_init_test_for_vdaf< const VERIFY_KEY_SIZE: usize, V: vdaf::Aggregator + vdaf::Client<16>, @@ -140,7 +159,7 @@ async fn setup_aggregate_init_test_for_vdaf< aggregation_param: V::AggregationParam, measurement: V::Measurement, ) -> AggregationJobInitTestCase { - let test_case = setup_aggregate_init_test_without_sending_request( + let mut test_case = setup_aggregate_init_test_without_sending_request( vdaf, vdaf_instance, aggregation_param, @@ -148,7 +167,7 @@ async fn setup_aggregate_init_test_for_vdaf< ) .await; - let response = put_aggregation_job( + let mut response = put_aggregation_job( &test_case.task, &test_case.aggregation_job_id, &test_case.aggregation_job_init_req, @@ -157,6 +176,17 @@ async fn setup_aggregate_init_test_for_vdaf< .await; assert_eq!(response.status(), Some(Status::Ok)); + let aggregation_job_init_resp: AggregationJobResp = decode_response_body(&mut response).await; + assert_eq!( + aggregation_job_init_resp.prepare_resps().len(), + test_case.aggregation_job_init_req.prepare_inits().len(), + ); + assert_matches!( + aggregation_job_init_resp.prepare_resps()[0].result(), + &PrepareStepResult::Continue { .. 
} + ); + + test_case.aggregation_job_init_resp = Some(aggregation_job_init_resp); test_case } @@ -209,6 +239,7 @@ async fn setup_aggregate_init_test_without_sending_request< prepare_init_generator, aggregation_job_id, aggregation_job_init_req, + aggregation_job_init_resp: None, aggregation_param, handler: Box::new(handler), datastore, @@ -376,17 +407,7 @@ async fn aggregation_job_mutation_report_shares() { #[tokio::test] async fn aggregation_job_mutation_report_aggregations() { // We must run Poplar1 in this test so that the aggregation job won't finish on the first step - - let aggregation_param = - Poplar1AggregationParam::try_from_prefixes(Vec::from([IdpfInput::from_bools(&[false])])) - .unwrap(); - let test_case = setup_aggregate_init_test_for_vdaf( - Poplar1::new_sha3(1), - VdafInstance::Poplar1 { bits: 1 }, - aggregation_param, - IdpfInput::from_bools(&[true]), - ) - .await; + let test_case = setup_poplar1_aggregate_init_test().await; // Generate some new reports using the existing reports' metadata, but varying the measurement // values such that the prepare state computed during aggregation initializaton won't match the @@ -420,3 +441,24 @@ async fn aggregation_job_mutation_report_aggregations() { .await; assert_eq!(response.status(), Some(Status::Conflict)); } + +#[tokio::test] +async fn aggregation_job_init_two_round_vdaf_idempotence() { + // We must run Poplar1 in this test so that the aggregation job won't finish on the first step + let test_case = setup_poplar1_aggregate_init_test().await; + + // Send the aggregation job init request again. We should get an identical response back. 
+ let mut response = put_aggregation_job( + &test_case.task, + &test_case.aggregation_job_id, + &test_case.aggregation_job_init_req, + &test_case.handler, + ) + .await; + + let aggregation_job_resp: AggregationJobResp = decode_response_body(&mut response).await; + assert_eq!( + aggregation_job_resp, + test_case.aggregation_job_init_resp.unwrap() + ); +} diff --git a/aggregator/src/aggregator/aggregation_job_continue.rs b/aggregator/src/aggregator/aggregation_job_continue.rs index 897039e34..c0fd4ebd8 100644 --- a/aggregator/src/aggregator/aggregation_job_continue.rs +++ b/aggregator/src/aggregator/aggregation_job_continue.rs @@ -72,7 +72,7 @@ impl VdafOps { if report_agg.report_id() != prep_step.report_id() { // This report was omitted by the leader because of a prior failure. Note that // the report was dropped (if it's not already in an error state) and continue. - if matches!(report_agg.state(), ReportAggregationState::Waiting(_)) { + if matches!(report_agg.state(), ReportAggregationState::WaitingHelper(_)) { *report_agg = report_agg .clone() .with_state(ReportAggregationState::Failed(PrepareError::ReportDropped)) @@ -103,8 +103,8 @@ impl VdafOps { continue; } - let transition = match report_aggregation.state() { - ReportAggregationState::Waiting(transition) => transition, + let prep_state = match report_aggregation.state() { + ReportAggregationState::WaitingHelper(prep_state) => prep_state, _ => { return Err(datastore::Error::User( Error::UnrecognizedMessage( @@ -119,25 +119,23 @@ impl VdafOps { let (report_aggregation_state, prepare_step_result, output_share) = trace_span!("VDAF preparation") .in_scope(|| { - // Evaluate the stored transition to recover our current state. - transition - .evaluate(vdaf.as_ref()) - .and_then(|(state, _)| { - // Then continue with the incoming message. 
- vdaf.helper_continued(state, prep_step.message()) - }) - .and_then(|continued_value| match continued_value { - PingPongContinuedValue::WithMessage { - transition: new_transition, - } => { + // Continue with the incoming message. + vdaf.helper_continued( + PingPongState::Continued(prep_state.clone()), + prep_step.message(), + ) + .and_then( + |continued_value| match continued_value { + PingPongContinuedValue::WithMessage { transition } => { let (new_state, message) = - new_transition.evaluate(vdaf.as_ref())?; + transition.evaluate(vdaf.as_ref())?; let (report_aggregation_state, output_share) = match new_state { - // Helper did not finish. Store the new transition and await the next message - // from the Leader to advance preparation. - PingPongState::Continued(_) => { - (ReportAggregationState::Waiting(new_transition), None) - } + // Helper did not finish. Store the new state and await the + // next message from the Leader to advance preparation. + PingPongState::Continued(prep_state) => ( + ReportAggregationState::WaitingHelper(prep_state), + None, + ), // Helper finished. Commit the output share. PingPongState::Finished(output_share) => { (ReportAggregationState::Finished, Some(output_share)) @@ -156,7 +154,8 @@ impl VdafOps { PrepareStepResult::Finished, Some(output_share), )), - }) + }, + ) }) .map_err(|error| { handle_ping_pong_error( @@ -195,7 +194,7 @@ impl VdafOps { for report_agg in report_aggregations_iter { // This report was omitted by the leader because of a prior failure. Note that the // report was dropped (if it's not already in an error state) and continue. 
- if matches!(report_agg.state(), ReportAggregationState::Waiting(_)) { + if matches!(report_agg.state(), ReportAggregationState::WaitingHelper(_)) { *report_agg = report_agg .clone() .with_state(ReportAggregationState::Failed(PrepareError::ReportDropped)) @@ -494,8 +493,10 @@ mod tests { *prepare_init.report_share().metadata().time(), 0, None, - ReportAggregationState::Waiting( - transcript.helper_prepare_transitions[0].transition.clone(), + ReportAggregationState::WaitingHelper( + transcript.helper_prepare_transitions[0] + .prepare_state() + .clone(), ), ), ) diff --git a/aggregator/src/aggregator/aggregation_job_driver.rs b/aggregator/src/aggregator/aggregation_job_driver.rs index 9abb4d1d3..2fe590997 100644 --- a/aggregator/src/aggregator/aggregation_job_driver.rs +++ b/aggregator/src/aggregator/aggregation_job_driver.rs @@ -236,9 +236,9 @@ impl AggregationJobDriver { for report_aggregation in &report_aggregations { match report_aggregation.state() { ReportAggregationState::Start => saw_start = true, - ReportAggregationState::Waiting(_) => saw_waiting = true, + ReportAggregationState::WaitingLeader(_) => saw_waiting = true, ReportAggregationState::Finished => saw_finished = true, - ReportAggregationState::Failed(_) => (), // ignore failed aggregations + _ => (), // ignore failed aggregations } } match (saw_start, saw_waiting, saw_finished) { @@ -458,7 +458,7 @@ impl AggregationJobDriver { let mut prepare_continues = Vec::new(); let mut stepped_aggregations = Vec::new(); for report_aggregation in report_aggregations { - if let ReportAggregationState::Waiting(transition) = report_aggregation.state() { + if let ReportAggregationState::WaitingLeader(transition) = report_aggregation.state() { let (prep_state, message) = match transition.evaluate(vdaf.as_ref()) { Ok((state, message)) => (state, message), Err(error) => { @@ -589,7 +589,7 @@ impl AggregationJobDriver { // VDAF level (i.e., state may be PingPongState::Finished) but we cannot // finish at the DAP layer 
and commit the output share until we get // confirmation from the Helper that they finished, too. - ReportAggregationState::Waiting(transition) + ReportAggregationState::WaitingLeader(transition) } Ok(PingPongContinuedValue::FinishedNoMessage { output_share }) => { // We finished and have no outgoing message, meaning the Helper was @@ -1804,7 +1804,7 @@ mod tests { *report.metadata().time(), 0, None, - ReportAggregationState::Waiting( + ReportAggregationState::WaitingLeader( transcript.leader_prepare_transitions[1] .transition .clone() @@ -2313,7 +2313,7 @@ mod tests { *report.metadata().time(), 0, None, - ReportAggregationState::Waiting( + ReportAggregationState::WaitingLeader( transcript.leader_prepare_transitions[1] .transition .clone() @@ -2480,7 +2480,7 @@ mod tests { *report.metadata().time(), 0, None, - ReportAggregationState::Waiting( + ReportAggregationState::WaitingLeader( transcript.leader_prepare_transitions[1] .transition .clone() @@ -2879,7 +2879,7 @@ mod tests { *report.metadata().time(), 0, None, - ReportAggregationState::Waiting( + ReportAggregationState::WaitingLeader( transcript.leader_prepare_transitions[1] .transition .clone() diff --git a/aggregator/src/aggregator/http_handlers.rs b/aggregator/src/aggregator/http_handlers.rs index a51c3e085..9c59ca34e 100644 --- a/aggregator/src/aggregator/http_handlers.rs +++ b/aggregator/src/aggregator/http_handlers.rs @@ -2424,7 +2424,7 @@ mod tests { report_metadata_0.id(), &measurement, ); - let ping_pong_transition_0 = &transcript_0.helper_prepare_transitions[0].transition; + let helper_prep_state_0 = transcript_0.helper_prepare_transitions[0].prepare_state(); let leader_prep_message_0 = &transcript_0.leader_prepare_transitions[1].message; let report_share_0 = generate_helper_report_share::>( *task.id(), @@ -2451,7 +2451,7 @@ mod tests { &measurement, ); - let ping_pong_transition_1 = &transcript_1.helper_prepare_transitions[0].transition; + let helper_prep_state_1 = 
transcript_1.helper_prepare_transitions[0].prepare_state(); let report_share_1 = generate_helper_report_share::>( *task.id(), report_metadata_1.clone(), @@ -2479,7 +2479,7 @@ mod tests { report_metadata_2.id(), &measurement, ); - let ping_pong_transition_2 = &transcript_2.helper_prepare_transitions[0].transition; + let helper_prep_state_2 = transcript_2.helper_prepare_transitions[0].prepare_state(); let leader_prep_message_2 = &transcript_2.leader_prepare_transitions[1].message; let report_share_2 = generate_helper_report_share::>( *task.id(), @@ -2498,10 +2498,10 @@ mod tests { report_share_1.clone(), report_share_2.clone(), ); - let (ping_pong_transition_0, ping_pong_transition_1, ping_pong_transition_2) = ( - ping_pong_transition_0.clone(), - ping_pong_transition_1.clone(), - ping_pong_transition_2.clone(), + let (helper_prep_state_0, helper_prep_state_1, helper_prep_state_2) = ( + helper_prep_state_0.clone(), + helper_prep_state_1.clone(), + helper_prep_state_2.clone(), ); let (report_metadata_0, report_metadata_1, report_metadata_2) = ( report_metadata_0.clone(), @@ -2542,7 +2542,7 @@ mod tests { *report_metadata_0.time(), 0, None, - ReportAggregationState::Waiting(ping_pong_transition_0), + ReportAggregationState::WaitingHelper(helper_prep_state_0), ), ) .await?; @@ -2554,7 +2554,7 @@ mod tests { *report_metadata_1.time(), 1, None, - ReportAggregationState::Waiting(ping_pong_transition_1), + ReportAggregationState::WaitingHelper(helper_prep_state_1), ), ) .await?; @@ -2566,7 +2566,7 @@ mod tests { *report_metadata_2.time(), 2, None, - ReportAggregationState::Waiting(ping_pong_transition_2), + ReportAggregationState::WaitingHelper(helper_prep_state_2), ), ) .await?; @@ -2739,7 +2739,7 @@ mod tests { report_metadata_0.id(), &measurement, ); - let ping_pong_transition_0 = &transcript_0.helper_prepare_transitions[0].transition; + let helper_prep_state_0 = transcript_0.helper_prepare_transitions[0].prepare_state(); let ping_pong_leader_message_0 = 
&transcript_0.leader_prepare_transitions[1].message; let report_share_0 = generate_helper_report_share::>( *task.id(), @@ -2766,7 +2766,7 @@ mod tests { report_metadata_1.id(), &measurement, ); - let ping_pong_transition_1 = &transcript_1.helper_prepare_transitions[0].transition; + let helper_prep_state_1 = transcript_1.helper_prepare_transitions[0].prepare_state(); let ping_pong_leader_message_1 = &transcript_1.leader_prepare_transitions[1].message; let report_share_1 = generate_helper_report_share::>( *task.id(), @@ -2793,7 +2793,7 @@ mod tests { report_metadata_2.id(), &measurement, ); - let ping_pong_transition_2 = &transcript_2.helper_prepare_transitions[0].transition; + let helper_prep_state_2 = transcript_2.helper_prepare_transitions[0].prepare_state(); let ping_pong_leader_message_2 = &transcript_2.leader_prepare_transitions[1].message; let report_share_2 = generate_helper_report_share::>( *task.id(), @@ -2837,10 +2837,10 @@ mod tests { report_share_1.clone(), report_share_2.clone(), ); - let (ping_pong_transition_0, ping_pong_transition_1, ping_pong_transition_2) = ( - ping_pong_transition_0.clone(), - ping_pong_transition_1.clone(), - ping_pong_transition_2.clone(), + let (helper_prep_state_0, helper_prep_state_1, helper_prep_state_2) = ( + helper_prep_state_0.clone(), + helper_prep_state_1.clone(), + helper_prep_state_2.clone(), ); let (report_metadata_0, report_metadata_1, report_metadata_2) = ( report_metadata_0.clone(), @@ -2884,7 +2884,7 @@ mod tests { *report_metadata_0.time(), 0, None, - ReportAggregationState::Waiting(ping_pong_transition_0), + ReportAggregationState::WaitingHelper(helper_prep_state_0), )) .await?; tx.put_report_aggregation(&ReportAggregation::< @@ -2897,7 +2897,7 @@ mod tests { *report_metadata_1.time(), 1, None, - ReportAggregationState::Waiting(ping_pong_transition_1), + ReportAggregationState::WaitingHelper(helper_prep_state_1), )) .await?; tx.put_report_aggregation(&ReportAggregation::< @@ -2910,7 +2910,7 @@ mod tests { 
*report_metadata_2.time(), 2, None, - ReportAggregationState::Waiting(ping_pong_transition_2), + ReportAggregationState::WaitingHelper(helper_prep_state_2), )) .await?; @@ -3095,7 +3095,7 @@ mod tests { report_metadata_3.id(), &measurement, ); - let ping_pong_transition_3 = &transcript_3.helper_prepare_transitions[0].transition; + let helper_prep_state_3 = transcript_3.helper_prepare_transitions[0].prepare_state(); let ping_pong_leader_message_3 = &transcript_3.leader_prepare_transitions[1].message; let report_share_3 = generate_helper_report_share::>( *task.id(), @@ -3122,7 +3122,7 @@ mod tests { report_metadata_4.id(), &measurement, ); - let ping_pong_transition_4 = &transcript_4.helper_prepare_transitions[0].transition; + let helper_prep_state_4 = transcript_4.helper_prepare_transitions[0].prepare_state(); let ping_pong_leader_message_4 = &transcript_4.leader_prepare_transitions[1].message; let report_share_4 = generate_helper_report_share::>( *task.id(), @@ -3149,7 +3149,7 @@ mod tests { report_metadata_5.id(), &measurement, ); - let ping_pong_transition_5 = &transcript_5.helper_prepare_transitions[0].transition; + let helper_prep_state_5 = transcript_5.helper_prepare_transitions[0].prepare_state(); let ping_pong_leader_message_5 = &transcript_5.leader_prepare_transitions[1].message; let report_share_5 = generate_helper_report_share::>( *task.id(), @@ -3168,10 +3168,10 @@ mod tests { report_share_4.clone(), report_share_5.clone(), ); - let (ping_pong_transition_3, ping_pong_transition_4, ping_pong_transition_5) = ( - ping_pong_transition_3.clone(), - ping_pong_transition_4.clone(), - ping_pong_transition_5.clone(), + let (helper_prep_state_3, helper_prep_state_4, helper_prep_state_5) = ( + helper_prep_state_3.clone(), + helper_prep_state_4.clone(), + helper_prep_state_5.clone(), ); let (report_metadata_3, report_metadata_4, report_metadata_5) = ( report_metadata_3.clone(), @@ -3211,7 +3211,7 @@ mod tests { *report_metadata_3.time(), 3, None, - 
ReportAggregationState::Waiting(ping_pong_transition_3), + ReportAggregationState::WaitingHelper(helper_prep_state_3), )) .await?; tx.put_report_aggregation(&ReportAggregation::< @@ -3224,7 +3224,7 @@ mod tests { *report_metadata_4.time(), 4, None, - ReportAggregationState::Waiting(ping_pong_transition_4), + ReportAggregationState::WaitingHelper(helper_prep_state_4), )) .await?; tx.put_report_aggregation(&ReportAggregation::< @@ -3237,7 +3237,7 @@ mod tests { *report_metadata_5.time(), 5, None, - ReportAggregationState::Waiting(ping_pong_transition_5), + ReportAggregationState::WaitingHelper(helper_prep_state_5), )) .await?; @@ -3392,15 +3392,23 @@ mod tests { let (_, _ephemeral_datastore, datastore, handler) = setup_http_handler_test().await; // Prepare parameters. - let task = - TaskBuilder::new(QueryType::TimeInterval, VdafInstance::Fake, Role::Helper).build(); + let task = TaskBuilder::new( + QueryType::TimeInterval, + VdafInstance::Poplar1 { bits: 1 }, + Role::Helper, + ) + .build(); let report_id = random(); + let aggregation_param = Poplar1AggregationParam::try_from_prefixes(Vec::from([ + IdpfInput::from_bools(&[false]), + ])) + .unwrap(); let transcript = run_vdaf( - &dummy_vdaf::Vdaf::new(), + &Poplar1::new_sha3(1), task.primary_vdaf_verify_key().unwrap().as_bytes(), - &dummy_vdaf::AggregationParam(0), + &aggregation_param, &report_id, - &(), + &IdpfInput::from_bools(&[false]), ); let aggregation_job_id = random(); let report_metadata = ReportMetadata::new( @@ -3411,8 +3419,12 @@ mod tests { // Setup datastore. 
datastore .run_tx(|tx| { - let (task, report_metadata, transcript) = - (task.clone(), report_metadata.clone(), transcript.clone()); + let (task, aggregation_param, report_metadata, transcript) = ( + task.clone(), + aggregation_param.clone(), + report_metadata.clone(), + transcript.clone(), + ); Box::pin(async move { tx.put_task(&task).await?; tx.put_report_share( @@ -3430,10 +3442,10 @@ mod tests { .await?; tx.put_aggregation_job( - &AggregationJob::<0, TimeInterval, dummy_vdaf::Vdaf>::new( + &AggregationJob::<16, TimeInterval, Poplar1>::new( *task.id(), aggregation_job_id, - dummy_vdaf::AggregationParam(0), + aggregation_param, (), Interval::new( Time::from_seconds_since_epoch(0), @@ -3445,15 +3457,17 @@ mod tests { ), ) .await?; - tx.put_report_aggregation(&ReportAggregation::<0, dummy_vdaf::Vdaf>::new( + tx.put_report_aggregation(&ReportAggregation::<16, Poplar1>::new( *task.id(), aggregation_job_id, *report_metadata.id(), *report_metadata.time(), 0, None, - ReportAggregationState::Waiting( - transcript.helper_prepare_transitions[0].transition.clone(), + ReportAggregationState::WaitingHelper( + transcript.helper_prepare_transitions[0] + .prepare_state() + .clone(), ), )) .await @@ -3493,47 +3507,53 @@ mod tests { // Prepare parameters. 
let task = TaskBuilder::new( QueryType::TimeInterval, - VdafInstance::FakeFailsPrepStep, + VdafInstance::Poplar1 { bits: 1 }, Role::Helper, ) .build(); + let vdaf = Poplar1::new_sha3(1); let report_id = random(); + let aggregation_param = Poplar1AggregationParam::try_from_prefixes(Vec::from([ + IdpfInput::from_bools(&[false]), + ])) + .unwrap(); let transcript = run_vdaf( - &dummy_vdaf::Vdaf::new(), + &vdaf, task.primary_vdaf_verify_key().unwrap().as_bytes(), - &dummy_vdaf::AggregationParam(0), + &aggregation_param, &report_id, - &(), + &IdpfInput::from_bools(&[false]), ); let aggregation_job_id = random(); let report_metadata = ReportMetadata::new(report_id, Time::from_seconds_since_epoch(54321)); + let helper_report_share = generate_helper_report_share::>( + task.id().clone(), + report_metadata.clone(), + task.current_hpke_key().config(), + &transcript.public_share, + Vec::new(), + &transcript.helper_input_share, + ); // Setup datastore. datastore .run_tx(|tx| { - let (task, report_metadata, transcript) = - (task.clone(), report_metadata.clone(), transcript.clone()); + let (task, aggregation_param, report_metadata, transcript, helper_report_share) = ( + task.clone(), + aggregation_param.clone(), + report_metadata.clone(), + transcript.clone(), + helper_report_share.clone(), + ); Box::pin(async move { tx.put_task(&task).await?; - tx.put_report_share( - task.id(), - &ReportShare::new( - report_metadata.clone(), - Vec::from("public share"), - HpkeCiphertext::new( - HpkeConfigId::from(42), - Vec::from("012345"), - Vec::from("543210"), - ), - ), - ) - .await?; + tx.put_report_share(task.id(), &helper_report_share).await?; tx.put_aggregation_job( - &AggregationJob::<0, TimeInterval, dummy_vdaf::Vdaf>::new( + &AggregationJob::<16, TimeInterval, Poplar1>::new( *task.id(), aggregation_job_id, - dummy_vdaf::AggregationParam(0), + aggregation_param, (), Interval::new( Time::from_seconds_since_epoch(0), @@ -3545,15 +3565,17 @@ mod tests { ), ) .await?; - 
tx.put_report_aggregation(&ReportAggregation::<0, dummy_vdaf::Vdaf>::new( + tx.put_report_aggregation(&ReportAggregation::<16, Poplar1>::new( *task.id(), aggregation_job_id, *report_metadata.id(), *report_metadata.time(), 0, None, - ReportAggregationState::Waiting( - transcript.helper_prepare_transitions[0].transition.clone(), + ReportAggregationState::WaitingHelper( + transcript.helper_prepare_transitions[0] + .prepare_state() + .clone(), ), )) .await @@ -3587,10 +3609,11 @@ mod tests { // Check datastore state. let (aggregation_job, report_aggregation) = datastore .run_tx(|tx| { - let (task, report_metadata) = (task.clone(), report_metadata.clone()); + let (vdaf, task, report_metadata) = + (vdaf.clone(), task.clone(), report_metadata.clone()); Box::pin(async move { let aggregation_job = tx - .get_aggregation_job::<0, TimeInterval, dummy_vdaf::Vdaf>( + .get_aggregation_job::<16, TimeInterval, Poplar1>( task.id(), &aggregation_job_id, ) @@ -3599,7 +3622,7 @@ mod tests { .unwrap(); let report_aggregation = tx .get_report_aggregation( - &dummy_vdaf::Vdaf::default(), + &vdaf, &Role::Helper, task.id(), &aggregation_job_id, @@ -3620,7 +3643,7 @@ mod tests { AggregationJob::new( *task.id(), aggregation_job_id, - dummy_vdaf::AggregationParam(0), + aggregation_param, (), Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) .unwrap(), @@ -3651,15 +3674,23 @@ mod tests { let (_, _ephemeral_datastore, datastore, handler) = setup_http_handler_test().await; // Prepare parameters. 
- let task = - TaskBuilder::new(QueryType::TimeInterval, VdafInstance::Fake, Role::Helper).build(); + let task = TaskBuilder::new( + QueryType::TimeInterval, + VdafInstance::Poplar1 { bits: 1 }, + Role::Helper, + ) + .build(); let report_id = random(); + let aggregation_param = Poplar1AggregationParam::try_from_prefixes(Vec::from([ + IdpfInput::from_bools(&[false]), + ])) + .unwrap(); let transcript = run_vdaf( - &dummy_vdaf::Vdaf::new(), + &Poplar1::new_sha3(1), task.primary_vdaf_verify_key().unwrap().as_bytes(), - &dummy_vdaf::AggregationParam(0), + &aggregation_param, &report_id, - &(), + &IdpfInput::from_bools(&[false]), ); let aggregation_job_id = random(); let report_metadata = ReportMetadata::new(report_id, Time::from_seconds_since_epoch(54321)); @@ -3667,8 +3698,12 @@ mod tests { // Setup datastore. datastore .run_tx(|tx| { - let (task, report_metadata, transcript) = - (task.clone(), report_metadata.clone(), transcript.clone()); + let (task, aggregation_param, report_metadata, transcript) = ( + task.clone(), + aggregation_param.clone(), + report_metadata.clone(), + transcript.clone(), + ); Box::pin(async move { tx.put_task(&task).await?; @@ -3686,10 +3721,10 @@ mod tests { ) .await?; tx.put_aggregation_job( - &AggregationJob::<0, TimeInterval, dummy_vdaf::Vdaf>::new( + &AggregationJob::<16, TimeInterval, Poplar1>::new( *task.id(), aggregation_job_id, - dummy_vdaf::AggregationParam(0), + aggregation_param, (), Interval::new( Time::from_seconds_since_epoch(0), @@ -3701,15 +3736,17 @@ mod tests { ), ) .await?; - tx.put_report_aggregation(&ReportAggregation::<0, dummy_vdaf::Vdaf>::new( + tx.put_report_aggregation(&ReportAggregation::<16, Poplar1>::new( *task.id(), aggregation_job_id, *report_metadata.id(), *report_metadata.time(), 0, None, - ReportAggregationState::Waiting( - transcript.helper_prepare_transitions[0].transition.clone(), + ReportAggregationState::WaitingHelper( + transcript.helper_prepare_transitions[0] + .prepare_state() + .clone(), ), )) .await 
@@ -3749,25 +3786,33 @@ mod tests { let (_, _ephemeral_datastore, datastore, handler) = setup_http_handler_test().await; // Prepare parameters. - let task = - TaskBuilder::new(QueryType::TimeInterval, VdafInstance::Fake, Role::Helper).build(); + let task = TaskBuilder::new( + QueryType::TimeInterval, + VdafInstance::Poplar1 { bits: 1 }, + Role::Helper, + ) + .build(); let report_id_0 = random(); + let aggregation_param = Poplar1AggregationParam::try_from_prefixes(Vec::from([ + IdpfInput::from_bools(&[false]), + ])) + .unwrap(); let transcript_0 = run_vdaf( - &dummy_vdaf::Vdaf::new(), + &Poplar1::new_sha3(1), task.primary_vdaf_verify_key().unwrap().as_bytes(), - &dummy_vdaf::AggregationParam(0), + &aggregation_param, &report_id_0, - &(), + &IdpfInput::from_bools(&[false]), ); let report_metadata_0 = ReportMetadata::new(report_id_0, Time::from_seconds_since_epoch(54321)); let report_id_1 = random(); let transcript_1 = run_vdaf( - &dummy_vdaf::Vdaf::new(), + &Poplar1::new_sha3(1), task.primary_vdaf_verify_key().unwrap().as_bytes(), - &dummy_vdaf::AggregationParam(0), + &aggregation_param, &report_id_1, - &(), + &IdpfInput::from_bools(&[false]), ); let report_metadata_1 = ReportMetadata::new(report_id_1, Time::from_seconds_since_epoch(54321)); @@ -3776,8 +3821,16 @@ mod tests { // Setup datastore. 
datastore .run_tx(|tx| { - let (task, report_metadata_0, report_metadata_1, transcript_0, transcript_1) = ( + let ( + task, + aggregation_param, + report_metadata_0, + report_metadata_1, + transcript_0, + transcript_1, + ) = ( task.clone(), + aggregation_param.clone(), report_metadata_0.clone(), report_metadata_1.clone(), transcript_0.clone(), @@ -3815,10 +3868,10 @@ mod tests { .await?; tx.put_aggregation_job( - &AggregationJob::<0, TimeInterval, dummy_vdaf::Vdaf>::new( + &AggregationJob::<16, TimeInterval, Poplar1>::new( *task.id(), aggregation_job_id, - dummy_vdaf::AggregationParam(0), + aggregation_param.clone(), (), Interval::new( Time::from_seconds_since_epoch(0), @@ -3831,30 +3884,30 @@ mod tests { ) .await?; - tx.put_report_aggregation(&ReportAggregation::<0, dummy_vdaf::Vdaf>::new( + tx.put_report_aggregation(&ReportAggregation::<16, Poplar1>::new( *task.id(), aggregation_job_id, *report_metadata_0.id(), *report_metadata_0.time(), 0, None, - ReportAggregationState::Waiting( + ReportAggregationState::WaitingHelper( transcript_0.helper_prepare_transitions[0] - .transition + .prepare_state() .clone(), ), )) .await?; - tx.put_report_aggregation(&ReportAggregation::<0, dummy_vdaf::Vdaf>::new( + tx.put_report_aggregation(&ReportAggregation::<16, Poplar1>::new( *task.id(), aggregation_job_id, *report_metadata_1.id(), *report_metadata_1.time(), 1, None, - ReportAggregationState::Waiting( + ReportAggregationState::WaitingHelper( transcript_1.helper_prepare_transitions[0] - .transition + .prepare_state() .clone(), ), )) diff --git a/aggregator/src/aggregator/taskprov_tests.rs b/aggregator/src/aggregator/taskprov_tests.rs index 2a9298a88..ff3ac2ac6 100644 --- a/aggregator/src/aggregator/taskprov_tests.rs +++ b/aggregator/src/aggregator/taskprov_tests.rs @@ -767,8 +767,10 @@ async fn taskprov_aggregate_continue() { *report_metadata.time(), 0, None, - ReportAggregationState::Waiting( - transcript.helper_prepare_transitions[0].transition.clone(), + 
ReportAggregationState::WaitingHelper( + transcript.helper_prepare_transitions[0] + .prepare_state() + .clone(), ), )) .await?; diff --git a/aggregator_core/src/datastore.rs b/aggregator_core/src/datastore.rs index 8a166a811..8162a1450 100644 --- a/aggregator_core/src/datastore.rs +++ b/aggregator_core/src/datastore.rs @@ -13,7 +13,6 @@ use crate::{ taskprov::{self, PeerAggregator}, SecretBytes, }; -use anyhow::anyhow; use chrono::NaiveDateTime; use futures::future::try_join_all; use janus_core::{ @@ -2108,7 +2107,7 @@ impl Transaction<'_, C> { .prepare_cached( "SELECT report_aggregations.client_timestamp, report_aggregations.ord, - report_aggregations.state, report_aggregations.prep_transition, + report_aggregations.state, report_aggregations.prep_state, report_aggregations.error_code, report_aggregations.last_prep_resp FROM report_aggregations JOIN aggregation_jobs ON aggregation_jobs.id = report_aggregations.aggregation_job_id @@ -2163,7 +2162,7 @@ impl Transaction<'_, C> { "SELECT report_aggregations.client_report_id, report_aggregations.client_timestamp, report_aggregations.ord, report_aggregations.state, - report_aggregations.prep_transition, report_aggregations.error_code, + report_aggregations.prep_state, report_aggregations.error_code, report_aggregations.last_prep_resp FROM report_aggregations JOIN aggregation_jobs ON aggregation_jobs.id = report_aggregations.aggregation_job_id @@ -2217,7 +2216,7 @@ impl Transaction<'_, C> { "SELECT aggregation_jobs.aggregation_job_id, report_aggregations.client_report_id, report_aggregations.client_timestamp, report_aggregations.ord, - report_aggregations.state, report_aggregations.prep_transition, + report_aggregations.state, report_aggregations.prep_state, report_aggregations.error_code, report_aggregations.last_prep_resp FROM report_aggregations JOIN aggregation_jobs ON aggregation_jobs.id = report_aggregations.aggregation_job_id @@ -2262,7 +2261,6 @@ impl Transaction<'_, C> { let time = 
Time::from_naive_date_time(&row.get("client_timestamp")); let ord: u64 = row.get_bigint_and_convert("ord")?; let state: ReportAggregationStateCode = row.get("state"); - let prep_transition_bytes: Option> = row.get("prep_transition"); let error_code: Option = row.get("error_code"); let last_prep_resp_bytes: Option> = row.get("last_prep_resp"); @@ -2286,20 +2284,32 @@ impl Transaction<'_, C> { ReportAggregationStateCode::Start => ReportAggregationState::Start, ReportAggregationStateCode::Waiting => { - let agg_index = role.index().ok_or_else(|| { - Error::User(anyhow!("unexpected role: {}", role.as_str()).into()) - })?; - let ping_pong_transition = PingPongTransition::get_decoded_with_param( - &(vdaf, agg_index), - &prep_transition_bytes.ok_or_else(|| { + let prep_state_bytes = + row.get::<_, Option>>("prep_state").ok_or_else(|| { Error::DbState( - "report aggregation in state WAITING but prep_transition is NULL" + "report aggregation in state WAITING but prep_state is NULL" .to_string(), ) - })?, - )?; + })?; + match role { + Role::Leader => { + let ping_pong_transition = PingPongTransition::get_decoded_with_param( + &(vdaf, 0 /* leader */), + &prep_state_bytes, + )?; + + ReportAggregationState::WaitingLeader(ping_pong_transition) + } + Role::Helper => { + let prepare_state = A::PrepareState::get_decoded_with_param( + &(vdaf, 1 /* helper */), + &prep_state_bytes, + )?; - ReportAggregationState::Waiting(ping_pong_transition) + ReportAggregationState::WaitingHelper(prepare_state) + } + _ => panic!("unexpected role"), + } } ReportAggregationStateCode::Finished => ReportAggregationState::Finished, @@ -2337,14 +2347,14 @@ impl Transaction<'_, C> { A::PrepareState: Encode, { let encoded_state_values = report_aggregation.state().encoded_values_from_state(); - let encoded_last_prep_resp = report_aggregation + let encoded_last_prep_resp: Option> = report_aggregation .last_prep_resp() .map(PrepareResp::get_encoded); let stmt = self .prepare_cached( "INSERT INTO 
report_aggregations - (aggregation_job_id, client_report_id, client_timestamp, ord, state, prep_transition, + (aggregation_job_id, client_report_id, client_timestamp, ord, state, prep_state, error_code, last_prep_resp) SELECT aggregation_jobs.id, $3, $4, $5, $6, $7, $8, $9 FROM aggregation_jobs @@ -2364,7 +2374,7 @@ impl Transaction<'_, C> { /* client_timestamp */ &report_aggregation.time().as_naive_date_time()?, /* ord */ &TryInto::::try_into(report_aggregation.ord())?, /* state */ &report_aggregation.state().state_code(), - /* prep_transition */ &encoded_state_values.transition, + /* prep_state */ &encoded_state_values.prep_state, /* error_code */ &encoded_state_values.prepare_err, /* last_prep_resp */ &encoded_last_prep_resp, /* now */ &self.clock.now().as_naive_date_time()?, @@ -2386,14 +2396,14 @@ impl Transaction<'_, C> { A::PrepareState: Encode, { let encoded_state_values = report_aggregation.state().encoded_values_from_state(); - let encoded_last_prep_resp = report_aggregation + let encoded_last_prep_resp: Option> = report_aggregation .last_prep_resp() .map(PrepareResp::get_encoded); let stmt = self .prepare_cached( "UPDATE report_aggregations - SET state = $1, prep_transition = $2, error_code = $3, last_prep_resp = $4 + SET state = $1, prep_state = $2, error_code = $3, last_prep_resp = $4 FROM aggregation_jobs, tasks WHERE report_aggregations.aggregation_job_id = aggregation_jobs.id AND aggregation_jobs.task_id = tasks.id @@ -2411,7 +2421,7 @@ impl Transaction<'_, C> { &[ /* state */ &report_aggregation.state().state_code(), - /* prep_transition */ &encoded_state_values.transition, + /* prep_state */ &encoded_state_values.prep_state, /* error_code */ &encoded_state_values.prepare_err, /* last_prep_resp */ &encoded_last_prep_resp, /* aggregation_job_id */ diff --git a/aggregator_core/src/datastore/models.rs b/aggregator_core/src/datastore/models.rs index d43fe117c..48c117247 100644 --- a/aggregator_core/src/datastore/models.rs +++ 
b/aggregator_core/src/datastore/models.rs @@ -704,10 +704,14 @@ where #[derive(Clone, Debug, Derivative)] pub enum ReportAggregationState> { Start, - Waiting( + WaitingLeader( /// Most recent transition for this report aggregation. PingPongTransition, ), + WaitingHelper( + /// Helper's current preparation state + A::PrepareState, + ), Finished, Failed(PrepareError), } @@ -718,7 +722,9 @@ impl> pub fn state_code(&self) -> ReportAggregationStateCode { match self { ReportAggregationState::Start => ReportAggregationStateCode::Start, - ReportAggregationState::Waiting(_) => ReportAggregationStateCode::Waiting, + ReportAggregationState::WaitingLeader(_) | ReportAggregationState::WaitingHelper(_) => { + ReportAggregationStateCode::Waiting + } ReportAggregationState::Finished => ReportAggregationStateCode::Finished, ReportAggregationState::Failed(_) => ReportAggregationStateCode::Failed, } @@ -733,10 +739,18 @@ impl> { match self { ReportAggregationState::Start => EncodedReportAggregationStateValues::default(), - ReportAggregationState::Waiting(transition) => EncodedReportAggregationStateValues { - transition: Some(transition.get_encoded()), - ..Default::default() - }, + ReportAggregationState::WaitingLeader(transition) => { + EncodedReportAggregationStateValues { + prep_state: Some(transition.get_encoded()), + ..Default::default() + } + } + ReportAggregationState::WaitingHelper(prepare_state) => { + EncodedReportAggregationStateValues { + prep_state: Some(prepare_state.get_encoded()), + ..Default::default() + } + } ReportAggregationState::Finished => EncodedReportAggregationStateValues::default(), ReportAggregationState::Failed(prepare_err) => EncodedReportAggregationStateValues { prepare_err: Some(*prepare_err as i16), @@ -748,7 +762,9 @@ impl> #[derive(Default)] pub(super) struct EncodedReportAggregationStateValues { - pub(super) transition: Option>, + // For the leader, prep_state is an encoded PingPongTransition. For the helper, it is an encoded + // PrepareState. 
+ pub(super) prep_state: Option>, pub(super) prepare_err: Option, } @@ -777,9 +793,12 @@ where { fn eq(&self, other: &Self) -> bool { match (self, other) { - (Self::Waiting(lhs_transition), Self::Waiting(rhs_transition)) => { + (Self::WaitingLeader(lhs_transition), Self::WaitingLeader(rhs_transition)) => { lhs_transition == rhs_transition } + (Self::WaitingHelper(lhs_state), Self::WaitingHelper(rhs_state)) => { + lhs_state == rhs_state + } (Self::Failed(lhs_prepare_err), Self::Failed(rhs_prepare_err)) => { lhs_prepare_err == rhs_prepare_err } diff --git a/aggregator_core/src/datastore/tests.rs b/aggregator_core/src/datastore/tests.rs index eee112b1d..a08785a40 100644 --- a/aggregator_core/src/datastore/tests.rs +++ b/aggregator_core/src/datastore/tests.rs @@ -44,8 +44,13 @@ use janus_messages::{ }; use prio::{ codec::{Decode, Encode}, + idpf::IdpfInput, topology::ping_pong::PingPongMessage, - vdaf::prio3::{Prio3, Prio3Count}, + vdaf::{ + poplar1::{Poplar1, Poplar1AggregationParam}, + prg::PrgSha3, + prio3::Prio3Count, + }, }; use rand::{distributions::Standard, random, thread_rng, Rng}; use std::{ @@ -1879,30 +1884,61 @@ async fn roundtrip_report_aggregation(ephemeral_datastore: EphemeralDatastore) { install_test_trace_subscriber(); let report_id = random(); - let vdaf = Arc::new(Prio3::new_count(2).unwrap()); + let vdaf = Arc::new(Poplar1::new_sha3(1)); let verify_key: [u8; VERIFY_KEY_LENGTH] = random(); - let vdaf_transcript = run_vdaf(vdaf.as_ref(), &verify_key, &(), &report_id, &0); - - for (ord, state) in [ - ReportAggregationState::::Start, - ReportAggregationState::Waiting( - vdaf_transcript.helper_prepare_transitions[0] - .transition - .clone(), + let aggregation_param = + Poplar1AggregationParam::try_from_prefixes(Vec::from([IdpfInput::from_bools(&[false])])) + .unwrap(); + let vdaf_transcript = run_vdaf( + vdaf.as_ref(), + &verify_key, + &aggregation_param, + &report_id, + &IdpfInput::from_bools(&[false]), + ); + + for (ord, (role, state)) in [ + 
(Role::Leader, ReportAggregationState::Start), + (Role::Helper, ReportAggregationState::Start), + ( + Role::Leader, + ReportAggregationState::WaitingLeader( + vdaf_transcript.leader_prepare_transitions[1] + .transition + .clone() + .unwrap(), + ), + ), + ( + Role::Helper, + ReportAggregationState::WaitingHelper( + vdaf_transcript.helper_prepare_transitions[0] + .prepare_state() + .clone(), + ), + ), + (Role::Leader, ReportAggregationState::Finished), + (Role::Helper, ReportAggregationState::Finished), + ( + Role::Leader, + ReportAggregationState::Failed(PrepareError::VdafPrepError), + ), + ( + Role::Helper, + ReportAggregationState::Failed(PrepareError::VdafPrepError), ), - ReportAggregationState::Finished, - ReportAggregationState::Failed(PrepareError::VdafPrepError), ] .into_iter() .enumerate() { + println!("case {role:?} {state:?}"); let clock = MockClock::new(OLDEST_ALLOWED_REPORT_TIMESTAMP); let ds = ephemeral_datastore.datastore(clock.clone()).await; let task = TaskBuilder::new( task::QueryType::TimeInterval, - VdafInstance::Prio3Count, - Role::Helper, + VdafInstance::Poplar1 { bits: 1 }, + role, ) .with_report_expiry_age(Some(REPORT_EXPIRY_AGE)) .build(); @@ -1911,17 +1947,18 @@ async fn roundtrip_report_aggregation(ephemeral_datastore: EphemeralDatastore) { let want_report_aggregation = ds .run_tx(|tx| { - let (task, state) = (task.clone(), state.clone()); + let (task, state, aggregation_param) = + (task.clone(), state.clone(), aggregation_param.clone()); Box::pin(async move { tx.put_task(&task).await?; tx.put_aggregation_job(&AggregationJob::< VERIFY_KEY_LENGTH, TimeInterval, - Prio3Count, + Poplar1, >::new( *task.id(), aggregation_job_id, - (), + aggregation_param, (), Interval::new(OLDEST_ALLOWED_REPORT_TIMESTAMP, Duration::from_seconds(1)) .unwrap(), @@ -1972,14 +2009,15 @@ async fn roundtrip_report_aggregation(ephemeral_datastore: EphemeralDatastore) { let got_report_aggregation = ds .run_tx(|tx| { - let (vdaf, task) = (Arc::clone(&vdaf), 
task.clone()); + let (vdaf, task, aggregation_param) = + (Arc::clone(&vdaf), task.clone(), aggregation_param.clone()); Box::pin(async move { tx.get_report_aggregation( vdaf.as_ref(), - &Role::Helper, + &role, task.id(), &aggregation_job_id, - &(), + &aggregation_param, &report_id, ) .await @@ -2018,14 +2056,15 @@ async fn roundtrip_report_aggregation(ephemeral_datastore: EphemeralDatastore) { let got_report_aggregation = ds .run_tx(|tx| { - let (vdaf, task) = (Arc::clone(&vdaf), task.clone()); + let (vdaf, task, aggregation_param) = + (Arc::clone(&vdaf), task.clone(), aggregation_param.clone()); Box::pin(async move { tx.get_report_aggregation( vdaf.as_ref(), - &Role::Helper, + &role, task.id(), &aggregation_job_id, - &(), + &aggregation_param, &report_id, ) .await @@ -2040,14 +2079,15 @@ async fn roundtrip_report_aggregation(ephemeral_datastore: EphemeralDatastore) { let got_report_aggregation = ds .run_tx(|tx| { - let (vdaf, task) = (Arc::clone(&vdaf), task.clone()); + let (vdaf, task, aggregation_param) = + (Arc::clone(&vdaf), task.clone(), aggregation_param.clone()); Box::pin(async move { tx.get_report_aggregation( vdaf.as_ref(), - &Role::Helper, + &role, task.id(), &aggregation_job_id, - &(), + &aggregation_param, &report_id, ) .await @@ -2267,13 +2307,23 @@ async fn get_report_aggregations_for_aggregation_job(ephemeral_datastore: Epheme let ds = ephemeral_datastore.datastore(clock.clone()).await; let report_id = random(); - let vdaf = Arc::new(Prio3::new_count(2).unwrap()); + let vdaf = Arc::new(Poplar1::new_sha3(1)); let verify_key: [u8; VERIFY_KEY_LENGTH] = random(); - let vdaf_transcript = run_vdaf(vdaf.as_ref(), &verify_key, &(), &report_id, &0); + let aggregation_param = + Poplar1AggregationParam::try_from_prefixes(Vec::from([IdpfInput::from_bools(&[false])])) + .unwrap(); + + let vdaf_transcript = run_vdaf( + vdaf.as_ref(), + &verify_key, + &aggregation_param, + &report_id, + &IdpfInput::from_bools(&[false]), + ); let task = TaskBuilder::new( 
task::QueryType::TimeInterval, - VdafInstance::Prio3Count, + VdafInstance::Poplar1 { bits: 1 }, Role::Helper, ) .with_report_expiry_age(Some(REPORT_EXPIRY_AGE)) @@ -2282,18 +2332,22 @@ async fn get_report_aggregations_for_aggregation_job(ephemeral_datastore: Epheme let want_report_aggregations = ds .run_tx(|tx| { - let (task, vdaf_transcript) = (task.clone(), vdaf_transcript.clone()); + let (task, vdaf_transcript, aggregation_param) = ( + task.clone(), + vdaf_transcript.clone(), + aggregation_param.clone(), + ); Box::pin(async move { tx.put_task(&task).await.unwrap(); tx.put_aggregation_job(&AggregationJob::< VERIFY_KEY_LENGTH, TimeInterval, - Prio3Count, + Poplar1, >::new( *task.id(), aggregation_job_id, - (), + aggregation_param, (), Interval::new(OLDEST_ALLOWED_REPORT_TIMESTAMP, Duration::from_seconds(1)) .unwrap(), @@ -2305,10 +2359,10 @@ async fn get_report_aggregations_for_aggregation_job(ephemeral_datastore: Epheme let mut want_report_aggregations = Vec::new(); for (ord, state) in [ - ReportAggregationState::::Start, - ReportAggregationState::Waiting( + ReportAggregationState::Start, + ReportAggregationState::WaitingHelper( vdaf_transcript.helper_prepare_transitions[0] - .transition + .prepare_state() .clone(), ), ReportAggregationState::Finished, @@ -4700,7 +4754,7 @@ async fn roundtrip_outstanding_batch(ephemeral_datastore: EphemeralDatastore) { 1, None, // Counted among max_size. 
- ReportAggregationState::Waiting( + ReportAggregationState::WaitingLeader( transcript.helper_prepare_transitions[0].transition.clone(), ), ); diff --git a/core/src/test_util/mod.rs b/core/src/test_util/mod.rs index 88abeb210..de79efdfe 100644 --- a/core/src/test_util/mod.rs +++ b/core/src/test_util/mod.rs @@ -1,3 +1,4 @@ +use assert_matches::assert_matches; use janus_messages::{ReportId, Role}; use prio::{ topology::ping_pong::{ @@ -36,6 +37,14 @@ pub struct HelperPrepareTransition< pub message: PingPongMessage, } +impl> + HelperPrepareTransition +{ + pub fn prepare_state(&self) -> &V::PrepareState { + assert_matches!(self.state, PingPongState::Continued(ref state) => state) + } +} + /// A transcript of a VDAF run using the ping-pong VDAF topology. #[derive(Clone, Debug)] pub struct VdafTranscript< diff --git a/db/00000000000001_initial_schema.up.sql b/db/00000000000001_initial_schema.up.sql index 8ebd99840..1d632baf2 100644 --- a/db/00000000000001_initial_schema.up.sql +++ b/db/00000000000001_initial_schema.up.sql @@ -207,7 +207,8 @@ CREATE TABLE report_aggregations( client_timestamp TIMESTAMP NOT NULL, -- the client timestamp this report aggregation is associated with ord BIGINT NOT NULL, -- a value used to specify the ordering of client reports in the aggregation job state REPORT_AGGREGATION_STATE NOT NULL, -- the current state of this report aggregation - prep_transition BYTEA, -- the current preparation transition (opaque VDAF message, only if in state WAITING) + prep_state BYTEA, -- the current preparation state (opaque VDAF message, only if in state WAITING) + -- value is an encoded PingPongTransition for the leader or an encoded PrepareState for the helper error_code SMALLINT, -- error code corresponding to a DAP ReportShareError value; null if in a state other than FAILED last_prep_resp BYTEA, -- the last PrepareResp message sent to the Leader, to assist in replay (opaque VDAF message, populated for Helper only)