From 608de17b4b3063372e4eb6d34d061ce55052474e Mon Sep 17 00:00:00 2001 From: David Cook Date: Tue, 19 Sep 2023 16:55:40 -0500 Subject: [PATCH] Handle chunk_length VDAF parameter --- Cargo.lock | 2 +- aggregator/src/aggregator.rs | 32 ++--- .../src/aggregator/aggregation_job_creator.rs | 61 +++++--- aggregator/src/bin/janus_cli.rs | 5 +- aggregator_api/src/tests.rs | 27 +++- aggregator_core/src/datastore/tests.rs | 32 ++++- aggregator_core/src/task.rs | 9 +- core/src/task.rs | 78 ++++++----- integration_tests/Cargo.toml | 2 +- integration_tests/src/client.rs | 17 ++- integration_tests/tests/common/mod.rs | 30 ++-- integration_tests/tests/divviup_ts.rs | 5 +- integration_tests/tests/in_cluster.rs | 24 +++- integration_tests/tests/janus.rs | 29 +--- .../src/bin/janus_interop_client.rs | 30 ++-- .../src/bin/janus_interop_collector.rs | 80 +++++------ interop_binaries/src/lib.rs | 43 +++--- interop_binaries/tests/end_to_end.rs | 30 +--- messages/src/taskprov.rs | 130 ++++++++++-------- tools/src/bin/collect.rs | 23 ++-- 20 files changed, 393 insertions(+), 296 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 0d6c9a277..11e223146 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -951,7 +951,7 @@ dependencies = [ [[package]] name = "divviup-client" version = "0.0.1" -source = "git+https://github.com/divviup/divviup-api?tag=0.0.25#8eda807aff7a6ed20d7ea8873b4d36b2421cdb95" +source = "git+https://github.com/divviup/divviup-api?rev=06667dbeb93d4870bb257b7ddf9a19b37a7e5da5#06667dbeb93d4870bb257b7ddf9a19b37a7e5da5" dependencies = [ "base64 0.21.4", "email_address", diff --git a/aggregator/src/aggregator.rs b/aggregator/src/aggregator.rs index 8fdd29858..e4f1dd222 100644 --- a/aggregator/src/aggregator.rs +++ b/aggregator/src/aggregator.rs @@ -801,13 +801,11 @@ impl TaskAggregator { VdafOps::Prio3Count(Arc::new(vdaf), verify_key) } - VdafInstance::Prio3CountVec { length } => { - let vdaf = Prio3::new_sum_vec_multithreaded( - 2, - 1, - *length, - VdafInstance::chunk_size(*length), - )?; + VdafInstance::Prio3CountVec { + length, + chunk_length, + } => { + let vdaf = Prio3::new_sum_vec_multithreaded(2, 1, *length, *chunk_length)?; let verify_key = task.vdaf_verify_key()?; VdafOps::Prio3CountVec(Arc::new(vdaf), verify_key) } @@ -818,19 +816,21 @@ impl TaskAggregator { VdafOps::Prio3Sum(Arc::new(vdaf), verify_key) } - VdafInstance::Prio3SumVec { bits, length } => { - let vdaf = Prio3::new_sum_vec_multithreaded( - 2, - *bits, - *length, - VdafInstance::chunk_size(*bits * *length), - )?; + VdafInstance::Prio3SumVec { + bits, + length, + chunk_length, + } => { + let vdaf = Prio3::new_sum_vec_multithreaded(2, *bits, *length, *chunk_length)?; let verify_key = task.vdaf_verify_key()?; VdafOps::Prio3SumVec(Arc::new(vdaf), verify_key) } - VdafInstance::Prio3Histogram { length } => { - let vdaf = Prio3::new_histogram(2, *length, VdafInstance::chunk_size(*length))?; + VdafInstance::Prio3Histogram { + length, + chunk_length, + } => { + let vdaf = Prio3::new_histogram(2, *length, *chunk_length)?; let verify_key = task.vdaf_verify_key()?; VdafOps::Prio3Histogram(Arc::new(vdaf), verify_key) } diff --git a/aggregator/src/aggregator/aggregation_job_creator.rs b/aggregator/src/aggregator/aggregation_job_creator.rs index c763cb6fe..cf13736e4 100644 --- a/aggregator/src/aggregator/aggregation_job_creator.rs +++ b/aggregator/src/aggregator/aggregation_job_creator.rs @@ -268,12 +268,18 @@ impl AggregationJobCreator { .await } - (task::QueryType::TimeInterval, VdafInstance::Prio3CountVec { length }) => { + ( + task::QueryType::TimeInterval, + VdafInstance::Prio3CountVec { + length, + chunk_length, + }, + ) => { let vdaf = Arc::new(Prio3::new_sum_vec_multithreaded( 2, 1, *length, - VdafInstance::chunk_size(*length), + *chunk_length, )?); self.create_aggregation_jobs_for_time_interval_task_no_param::< VERIFY_KEY_LENGTH, @@ -287,23 +293,32 @@ impl AggregationJobCreator { .await } - (task::QueryType::TimeInterval, VdafInstance::Prio3SumVec { bits, length }) => { + ( + task::QueryType::TimeInterval, + VdafInstance::Prio3SumVec { + bits, + length, + chunk_length, + }, + ) => { let vdaf = Arc::new(Prio3::new_sum_vec_multithreaded( 2, *bits, *length, - VdafInstance::chunk_size(*bits * *length), + *chunk_length, )?); self.create_aggregation_jobs_for_time_interval_task_no_param::(task, vdaf) .await } - (task::QueryType::TimeInterval, VdafInstance::Prio3Histogram { length }) => { - let vdaf = Arc::new(Prio3::new_histogram( - 2, - *length, - VdafInstance::chunk_size(*length), - )?); + ( + task::QueryType::TimeInterval, + VdafInstance::Prio3Histogram { + length, + chunk_length, + }, + ) => { + let vdaf = Arc::new(Prio3::new_histogram(2, *length, *chunk_length)?); self.create_aggregation_jobs_for_time_interval_task_no_param::(task, vdaf) .await } @@ -368,13 +383,16 @@ impl AggregationJobCreator { max_batch_size, batch_time_window_size, }, - VdafInstance::Prio3CountVec { length }, + VdafInstance::Prio3CountVec { + length, + chunk_length, + }, ) => { let vdaf = Arc::new(Prio3::new_sum_vec_multithreaded( 2, 1, *length, - VdafInstance::chunk_size(*length), + *chunk_length, )?); let max_batch_size = *max_batch_size; let batch_time_window_size = *batch_time_window_size; @@ -405,13 +423,17 @@ impl AggregationJobCreator { max_batch_size, batch_time_window_size, }, - VdafInstance::Prio3SumVec { bits, length }, + VdafInstance::Prio3SumVec { + bits, + length, + chunk_length, + }, ) => { let vdaf = Arc::new(Prio3::new_sum_vec_multithreaded( 2, *bits, *length, - VdafInstance::chunk_size(*bits * *length), + *chunk_length, )?); let max_batch_size = *max_batch_size; let batch_time_window_size = *batch_time_window_size; @@ -426,13 +448,12 @@ impl AggregationJobCreator { max_batch_size, batch_time_window_size, }, - VdafInstance::Prio3Histogram { length }, + VdafInstance::Prio3Histogram { + length, + chunk_length, + }, ) => { - let vdaf = Arc::new(Prio3::new_histogram( - 2, - *length, - VdafInstance::chunk_size(*length), - )?); + let vdaf = Arc::new(Prio3::new_histogram(2, *length, *chunk_length)?); let max_batch_size = *max_batch_size; let batch_time_window_size = *batch_time_window_size; self.create_aggregation_jobs_for_fixed_size_task_no_param::< diff --git a/aggregator/src/bin/janus_cli.rs b/aggregator/src/bin/janus_cli.rs index a5ac05b41..f69597d8b 100644 --- a/aggregator/src/bin/janus_cli.rs +++ b/aggregator/src/bin/janus_cli.rs @@ -671,7 +671,10 @@ mod tests { max_batch_size: 100, batch_time_window_size: None, }, - VdafInstance::Prio3CountVec { length: 4 }, + VdafInstance::Prio3CountVec { + length: 4, + chunk_length: 2, + }, Role::Leader, ) .with_id(*tasks[0].id()) diff --git a/aggregator_api/src/tests.rs b/aggregator_api/src/tests.rs index 3b7e7d1b2..baf5ee0a6 100644 --- a/aggregator_api/src/tests.rs +++ b/aggregator_api/src/tests.rs @@ -1518,7 +1518,10 @@ fn post_task_req_serialization() { max_batch_size: 999, batch_time_window_size: None, }, - vdaf: VdafInstance::Prio3CountVec { length: 5 }, + vdaf: VdafInstance::Prio3CountVec { + length: 5, + chunk_length: 2, + }, role: Role::Helper, vdaf_verify_key: "encoded".to_owned(), max_batch_query_count: 1, @@ -1556,10 +1559,12 @@ fn post_task_req_serialization() { Token::StructVariant { name: "VdafInstance", variant: "Prio3CountVec", - len: 1, + len: 2, }, Token::Str("length"), Token::U64(5), + Token::Str("chunk_length"), + Token::U64(2), Token::StructVariantEnd, Token::Str("role"), Token::UnitVariant { @@ -1619,7 +1624,10 @@ fn post_task_req_serialization() { max_batch_size: 999, batch_time_window_size: None, }, - vdaf: VdafInstance::Prio3CountVec { length: 5 }, + vdaf: VdafInstance::Prio3CountVec { + length: 5, + chunk_length: 2, + }, role: Role::Leader, vdaf_verify_key: "encoded".to_owned(), max_batch_query_count: 1, @@ -1659,10 +1667,12 @@ fn post_task_req_serialization() { Token::StructVariant { name: "VdafInstance", variant: "Prio3CountVec", - len: 1, + len: 2, }, Token::Str("length"), Token::U64(5), + Token::Str("chunk_length"), + Token::U64(2), Token::StructVariantEnd, Token::Str("role"), Token::UnitVariant { @@ -1739,7 +1749,10 @@ fn task_resp_serialization() { max_batch_size: 999, batch_time_window_size: None, }, - VdafInstance::Prio3CountVec { length: 5 }, + VdafInstance::Prio3CountVec { + length: 5, + chunk_length: 2, + }, Role::Leader, SecretBytes::new(b"vdaf verify key!".to_vec()), 1, @@ -1801,10 +1814,12 @@ fn task_resp_serialization() { Token::StructVariant { name: "VdafInstance", variant: "Prio3CountVec", - len: 1, + len: 2, }, Token::Str("length"), Token::U64(5), + Token::Str("chunk_length"), + Token::U64(2), Token::StructVariantEnd, Token::Str("role"), Token::UnitVariant { diff --git a/aggregator_core/src/datastore/tests.rs b/aggregator_core/src/datastore/tests.rs index 3779ba223..bcd88abef 100644 --- a/aggregator_core/src/datastore/tests.rs +++ b/aggregator_core/src/datastore/tests.rs @@ -110,12 +110,36 @@ async fn roundtrip_task(ephemeral_datastore: EphemeralDatastore) { let mut want_tasks = HashMap::new(); for (vdaf, role) in [ (VdafInstance::Prio3Count, Role::Leader), - (VdafInstance::Prio3CountVec { length: 8 }, Role::Leader), - (VdafInstance::Prio3CountVec { length: 64 }, Role::Helper), + ( + VdafInstance::Prio3CountVec { + length: 8, + chunk_length: 3, + }, + Role::Leader, + ), + ( + VdafInstance::Prio3CountVec { + length: 64, + chunk_length: 10, + }, + Role::Helper, + ), (VdafInstance::Prio3Sum { bits: 64 }, Role::Helper), (VdafInstance::Prio3Sum { bits: 32 }, Role::Helper), - (VdafInstance::Prio3Histogram { length: 4 }, Role::Leader), - (VdafInstance::Prio3Histogram { length: 5 }, Role::Leader), + ( + VdafInstance::Prio3Histogram { + length: 4, + chunk_length: 2, + }, + Role::Leader, + ), + ( + VdafInstance::Prio3Histogram { + length: 5, + chunk_length: 2, + }, + Role::Leader, + ), (VdafInstance::Poplar1 { bits: 8 }, Role::Helper), (VdafInstance::Poplar1 { bits: 64 }, Role::Helper), ] { diff --git a/aggregator_core/src/task.rs b/aggregator_core/src/task.rs index 77755ad9c..ff2230c33 100644 --- a/aggregator_core/src/task.rs +++ b/aggregator_core/src/task.rs @@ -1266,7 +1266,10 @@ mod tests { max_batch_size: 10, batch_time_window_size: None, }, - VdafInstance::Prio3CountVec { length: 8 }, + VdafInstance::Prio3CountVec { + length: 8, + chunk_length: 3, + }, Role::Helper, SecretBytes::new(b"1234567812345678".to_vec()), 1, @@ -1326,10 +1329,12 @@ mod tests { Token::StructVariant { name: "VdafInstance", variant: "Prio3CountVec", - len: 1, + len: 2, }, Token::Str("length"), Token::U64(8), + Token::Str("chunk_length"), + Token::U64(3), Token::StructVariantEnd, Token::Str("role"), Token::UnitVariant { diff --git a/core/src/task.rs b/core/src/task.rs index 2adaca84f..24cf2eb92 100644 --- a/core/src/task.rs +++ b/core/src/task.rs @@ -25,13 +25,17 @@ pub enum VdafInstance { /// A `Prio3` counter. Prio3Count, /// A vector of `Prio3` counters. - Prio3CountVec { length: usize }, + Prio3CountVec { length: usize, chunk_length: usize }, /// A `Prio3` sum. Prio3Sum { bits: usize }, /// A vector of `Prio3` sums. - Prio3SumVec { bits: usize, length: usize }, + Prio3SumVec { + bits: usize, + length: usize, + chunk_length: usize, + }, /// A `Prio3` histogram with `length` buckets in it. - Prio3Histogram { length: usize }, + Prio3Histogram { length: usize, chunk_length: usize }, /// A `Prio3` 16-bit fixed point vector sum with bounded L2 norm. #[cfg(feature = "fpvec_bounded_l2")] Prio3FixedPoint16BitBoundedL2VecSum { length: usize }, @@ -71,19 +75,6 @@ impl VdafInstance { _ => VERIFY_KEY_LENGTH, } } - - /// Returns a suboptimal estimate of the chunk size to use in ParallelSum gadgets. See [VDAF] - /// for discussion of chunk size. - /// - /// # Bugs - /// - /// Janus should allow chunk size to be configured ([#1900][issue]). - /// - /// [VDAF]: https://datatracker.ietf.org/doc/html/draft-irtf-cfrg-vdaf-07#name-selection-of-parallelsum-ch - /// [issue]: https://github.com/divviup/janus/issues/1900 - pub fn chunk_size(measurement_length: usize) -> usize { - (measurement_length as f64).sqrt().floor() as usize - } } impl TryFrom<&taskprov::VdafType> for VdafInstance { @@ -95,11 +86,12 @@ impl TryFrom<&taskprov::VdafType> for VdafInstance { taskprov::VdafType::Prio3Sum { bits } => Ok(Self::Prio3Sum { bits: *bits as usize, }), - taskprov::VdafType::Prio3Histogram { buckets } => Ok(Self::Prio3Histogram { - // taskprov does not yet deal with the VDAF-06 representation of histograms. In the - // meantime, we translate the bucket boundaries to a length that Janus understands. - // https://github.com/wangshan/draft-wang-ppm-dap-taskprov/issues/33 - length: buckets.len() + 1, // +1 to account for the top bucket extending to infinity + taskprov::VdafType::Prio3Histogram { + length, + chunk_length, + } => Ok(Self::Prio3Histogram { + length: *length as usize, + chunk_length: *chunk_length as usize, }), taskprov::VdafType::Poplar1 { bits } => Ok(Self::Poplar1 { bits: *bits as usize, @@ -165,13 +157,16 @@ macro_rules! vdaf_dispatch_impl_base { $body } - ::janus_core::task::VdafInstance::Prio3CountVec { length } => { + ::janus_core::task::VdafInstance::Prio3CountVec { + length, + chunk_length, + } => { // Prio3CountVec is implemented as a 1-bit sum vec let $vdaf = ::prio::vdaf::prio3::Prio3::new_sum_vec_multithreaded( 2, 1, *length, - janus_core::task::VdafInstance::chunk_size(*length), + *chunk_length, )?; type $Vdaf = ::prio::vdaf::prio3::Prio3SumVecMultithreaded; const $VERIFY_KEY_LEN: usize = ::janus_core::task::VERIFY_KEY_LENGTH; @@ -185,24 +180,27 @@ macro_rules! vdaf_dispatch_impl_base { $body } - ::janus_core::task::VdafInstance::Prio3SumVec { bits, length } => { + ::janus_core::task::VdafInstance::Prio3SumVec { + bits, + length, + chunk_length, + } => { let $vdaf = ::prio::vdaf::prio3::Prio3::new_sum_vec_multithreaded( 2, *bits, *length, - janus_core::task::VdafInstance::chunk_size(*bits * *length), + *chunk_length, )?; type $Vdaf = ::prio::vdaf::prio3::Prio3SumVecMultithreaded; const $VERIFY_KEY_LEN: usize = ::janus_core::task::VERIFY_KEY_LENGTH; $body } - ::janus_core::task::VdafInstance::Prio3Histogram { length } => { - let $vdaf = ::prio::vdaf::prio3::Prio3::new_histogram( - 2, - *length, - janus_core::task::VdafInstance::chunk_size(*length), - )?; + ::janus_core::task::VdafInstance::Prio3Histogram { + length, + chunk_length, + } => { + let $vdaf = ::prio::vdaf::prio3::Prio3::new_histogram(2, *length, *chunk_length)?; type $Vdaf = ::prio::vdaf::prio3::Prio3Histogram; const $VERIFY_KEY_LEN: usize = ::janus_core::task::VERIFY_KEY_LENGTH; $body @@ -789,15 +787,20 @@ mod tests { }], ); assert_tokens( - &VdafInstance::Prio3CountVec { length: 8 }, + &VdafInstance::Prio3CountVec { + length: 8, + chunk_length: 3, + }, &[ Token::StructVariant { name: "VdafInstance", variant: "Prio3CountVec", - len: 1, + len: 2, }, Token::Str("length"), Token::U64(8), + Token::Str("chunk_length"), + Token::U64(3), Token::StructVariantEnd, ], ); @@ -815,15 +818,20 @@ mod tests { ], ); assert_tokens( - &VdafInstance::Prio3Histogram { length: 6 }, + &VdafInstance::Prio3Histogram { + length: 6, + chunk_length: 2, + }, &[ Token::StructVariant { name: "VdafInstance", variant: "Prio3Histogram", - len: 1, + len: 2, }, Token::Str("length"), Token::U64(6), + Token::Str("chunk_length"), + Token::U64(2), Token::StructVariantEnd, ], ); diff --git a/integration_tests/Cargo.toml b/integration_tests/Cargo.toml index 19c92d9b6..99b15a2c5 100644 --- a/integration_tests/Cargo.toml +++ b/integration_tests/Cargo.toml @@ -15,7 +15,7 @@ in-cluster = ["dep:k8s-openapi", "dep:kube"] anyhow.workspace = true backoff = { version = "0.4", features = ["tokio"] } base64.workspace = true -divviup-client = { git = "https://github.com/divviup/divviup-api", features = ["admin"], tag = "0.0.25" } +divviup-client = { git = "https://github.com/divviup/divviup-api", features = ["admin"], rev = "06667dbeb93d4870bb257b7ddf9a19b37a7e5da5" } futures = "0.3.28" hex = "0.4" http = "0.2" diff --git a/integration_tests/src/client.rs b/integration_tests/src/client.rs index 86832978b..07154f266 100644 --- a/integration_tests/src/client.rs +++ b/integration_tests/src/client.rs @@ -58,23 +58,28 @@ fn json_encode_vdaf(vdaf: &VdafInstance) -> Value { VdafInstance::Prio3Count => json!({ "type": "Prio3Count" }), - VdafInstance::Prio3CountVec { length } => json!({ - "type": "Prio3CountVec", - "length": format!("{length}"), - }), VdafInstance::Prio3Sum { bits } => json!({ "type": "Prio3Sum", "bits": format!("{bits}"), }), - VdafInstance::Prio3SumVec { bits, length } => json!({ + VdafInstance::Prio3SumVec { + bits, + length, + chunk_length, + } => json!({ "type": "Prio3SumVec", "bits": format!("{bits}"), "length": format!("{length}"), + "chunk_length": format!("{chunk_length}"), }), - VdafInstance::Prio3Histogram { length } => { + VdafInstance::Prio3Histogram { + length, + chunk_length, + } => { json!({ "type": "Prio3Histogram", "length": format!("{length}"), + "chunk_length": format!("{chunk_length}"), }) } _ => panic!("VDAF {vdaf:?} is not yet supported"), diff --git a/integration_tests/tests/common/mod.rs b/integration_tests/tests/common/mod.rs index 644db5ae0..a4a8b63c2 100644 --- a/integration_tests/tests/common/mod.rs +++ b/integration_tests/tests/common/mod.rs @@ -282,14 +282,12 @@ pub async fn submit_measurements_and_verify_aggregate( ) .await; } - VdafInstance::Prio3SumVec { bits, length } => { - let vdaf = Prio3::new_sum_vec_multithreaded( - 2, - *bits, - *length, - VdafInstance::chunk_size(*bits * *length), - ) - .unwrap(); + VdafInstance::Prio3SumVec { + bits, + length, + chunk_length, + } => { + let vdaf = Prio3::new_sum_vec_multithreaded(2, *bits, *length, *chunk_length).unwrap(); let measurements = iter::repeat_with(|| { iter::repeat_with(|| (random::()) >> (128 - bits)) @@ -327,8 +325,11 @@ pub async fn submit_measurements_and_verify_aggregate( ) .await; } - VdafInstance::Prio3Histogram { length } => { - let vdaf = Prio3::new_histogram(2, *length, VdafInstance::chunk_size(*length)).unwrap(); + VdafInstance::Prio3Histogram { + length, + chunk_length, + } => { + let vdaf = Prio3::new_histogram(2, *length, *chunk_length).unwrap(); let mut aggregate_result = vec![0; *length]; let measurements = iter::repeat_with(|| { @@ -358,10 +359,11 @@ pub async fn submit_measurements_and_verify_aggregate( ) .await; } - VdafInstance::Prio3CountVec { length } => { - let vdaf = - Prio3::new_sum_vec_multithreaded(2, 1, *length, VdafInstance::chunk_size(*length)) - .unwrap(); + VdafInstance::Prio3CountVec { + length, + chunk_length, + } => { + let vdaf = Prio3::new_sum_vec_multithreaded(2, 1, *length, *chunk_length).unwrap(); let measurements = iter::repeat_with(|| { iter::repeat_with(|| random::() as u128) diff --git a/integration_tests/tests/divviup_ts.rs b/integration_tests/tests/divviup_ts.rs index 50ad73478..9066db970 100644 --- a/integration_tests/tests/divviup_ts.rs +++ b/integration_tests/tests/divviup_ts.rs @@ -58,7 +58,10 @@ async fn janus_divviup_ts_histogram() { run_divviup_ts_integration_test( &container_client(), - VdafInstance::Prio3Histogram { length: 4 }, + VdafInstance::Prio3Histogram { + length: 4, + chunk_length: 2, + }, ) .await; } diff --git a/integration_tests/tests/in_cluster.rs b/integration_tests/tests/in_cluster.rs index 8ece2e2d1..23a8dbd39 100644 --- a/integration_tests/tests/in_cluster.rs +++ b/integration_tests/tests/in_cluster.rs @@ -169,15 +169,28 @@ impl InClusterJanusPair { VdafInstance::Prio3Sum { bits } => Vdaf::Sum { bits: bits.try_into().unwrap(), }, - VdafInstance::Prio3SumVec { bits, length } => Vdaf::SumVec { + VdafInstance::Prio3SumVec { + bits, + length, + chunk_length, + } => Vdaf::SumVec { bits: bits.try_into().unwrap(), length: length.try_into().unwrap(), + chunk_length: Some(chunk_length.try_into().unwrap()), }, - VdafInstance::Prio3Histogram { length } => Vdaf::Histogram(Histogram::Length { + VdafInstance::Prio3Histogram { + length, + chunk_length, + } => Vdaf::Histogram(Histogram::Length { length: length.try_into().unwrap(), + chunk_length: Some(chunk_length.try_into().unwrap()), }), - VdafInstance::Prio3CountVec { length } => Vdaf::CountVec { + VdafInstance::Prio3CountVec { + length, + chunk_length, + } => Vdaf::CountVec { length: length.try_into().unwrap(), + chunk_length: Some(chunk_length.try_into().unwrap()), }, other => panic!("unsupported vdaf {other:?}"), }, @@ -285,7 +298,10 @@ async fn in_cluster_histogram() { // Start port forwards and set up task. let janus_pair = InClusterJanusPair::new( - VdafInstance::Prio3Histogram { length: 4 }, + VdafInstance::Prio3Histogram { + length: 4, + chunk_length: 2, + }, QueryType::TimeInterval, ) .await; diff --git a/integration_tests/tests/janus.rs b/integration_tests/tests/janus.rs index 59137f2f6..0ff5171b9 100644 --- a/integration_tests/tests/janus.rs +++ b/integration_tests/tests/janus.rs @@ -99,30 +99,10 @@ async fn janus_janus_histogram_4_buckets() { let container_client = container_client(); let janus_pair = JanusPair::new( &container_client, - VdafInstance::Prio3Histogram { length: 4 }, - QueryType::TimeInterval, - ) - .await; - - // Run the behavioral test. - submit_measurements_and_verify_aggregate( - &janus_pair.task_parameters, - (janus_pair.leader.port(), janus_pair.helper.port()), - &ClientBackend::InProcess, - ) - .await; -} - -/// This test exercises Prio3CountVec with Janus as both the leader and the helper. -#[tokio::test(flavor = "multi_thread")] -async fn janus_janus_count_vec_15() { - install_test_trace_subscriber(); - - // Start servers. - let container_client = container_client(); - let janus_pair = JanusPair::new( - &container_client, - VdafInstance::Prio3CountVec { length: 15 }, + VdafInstance::Prio3Histogram { + length: 4, + chunk_length: 2, + }, QueryType::TimeInterval, ) .await; @@ -173,6 +153,7 @@ async fn janus_janus_sum_vec() { VdafInstance::Prio3SumVec { bits: 16, length: 15, + chunk_length: 16, }, QueryType::TimeInterval, ) diff --git a/interop_binaries/src/bin/janus_interop_client.rs b/interop_binaries/src/bin/janus_interop_client.rs index 96ad8c5c3..d418f6969 100644 --- a/interop_binaries/src/bin/janus_interop_client.rs +++ b/interop_binaries/src/bin/janus_interop_client.rs @@ -144,14 +144,6 @@ async fn handle_upload( handle_upload_generic(http_client, vdaf_client, request, measurement).await?; } - VdafInstance::Prio3CountVec { length } => { - let measurement = parse_vector_measurement::(request.measurement.clone())?; - let vdaf_client = - Prio3::new_sum_vec_multithreaded(2, 1, length, VdafInstance::chunk_size(length)) - .context("failed to construct Prio3CountVec VDAF")?; - handle_upload_generic(http_client, vdaf_client, request, measurement).await?; - } - VdafInstance::Prio3Sum { bits } => { let measurement = parse_primitive_measurement::(request.measurement.clone())?; let vdaf_client = @@ -159,21 +151,23 @@ async fn handle_upload( handle_upload_generic(http_client, vdaf_client, request, measurement).await?; } - VdafInstance::Prio3SumVec { bits, length } => { + VdafInstance::Prio3SumVec { + bits, + length, + chunk_length, + } => { let measurement = parse_vector_measurement::(request.measurement.clone())?; - let vdaf_client = Prio3::new_sum_vec_multithreaded( - 2, - bits, - length, - VdafInstance::chunk_size(bits * length), - ) - .context("failed to construct Prio3SumVec VDAF")?; + let vdaf_client = Prio3::new_sum_vec_multithreaded(2, bits, length, chunk_length) + .context("failed to construct Prio3SumVec VDAF")?; handle_upload_generic(http_client, vdaf_client, request, measurement).await?; } - VdafInstance::Prio3Histogram { length } => { + VdafInstance::Prio3Histogram { + length, + chunk_length, + } => { let measurement = parse_primitive_measurement::(request.measurement.clone())?; - let vdaf_client = Prio3::new_histogram(2, length, VdafInstance::chunk_size(length)) + let vdaf_client = Prio3::new_histogram(2, length, chunk_length) .context("failed to construct Prio3Histogram VDAF")?; handle_upload_generic(http_client, vdaf_client, request, measurement).await?; } diff --git a/interop_binaries/src/bin/janus_interop_collector.rs b/interop_binaries/src/bin/janus_interop_collector.rs index 436c7bab1..e7b971f97 100644 --- a/interop_binaries/src/bin/janus_interop_collector.rs +++ b/interop_binaries/src/bin/janus_interop_collector.rs @@ -311,25 +311,6 @@ async fn handle_collection_start( .await? } - (ParsedQuery::TimeInterval(batch_interval), VdafInstance::Prio3CountVec { length }) => { - let vdaf = - Prio3::new_sum_vec_multithreaded(2, 1, length, VdafInstance::chunk_size(length)) - .context("failed to construct Prio3CountVec VDAF")?; - handle_collect_generic( - http_client, - collector_params, - Query::new_time_interval(batch_interval), - vdaf, - &agg_param, - |_| None, - |result| { - let converted = result.iter().cloned().map(NumberAsString).collect(); - AggregationResult::NumberVec(converted) - }, - ) - .await? - } - (ParsedQuery::TimeInterval(batch_interval), VdafInstance::Prio3Sum { bits }) => { let vdaf = Prio3::new_sum(2, bits).context("failed to construct Prio3Sum VDAF")?; handle_collect_generic( @@ -344,14 +325,16 @@ async fn handle_collection_start( .await? } - (ParsedQuery::TimeInterval(batch_interval), VdafInstance::Prio3SumVec { bits, length }) => { - let vdaf = Prio3::new_sum_vec_multithreaded( - 2, + ( + ParsedQuery::TimeInterval(batch_interval), + VdafInstance::Prio3SumVec { bits, length, - VdafInstance::chunk_size(bits * length), - ) - .context("failed to construct Prio3SumVec VDAF")?; + chunk_length, + }, + ) => { + let vdaf = Prio3::new_sum_vec_multithreaded(2, bits, length, chunk_length) + .context("failed to construct Prio3SumVec VDAF")?; handle_collect_generic( http_client, collector_params, @@ -367,8 +350,14 @@ async fn handle_collection_start( .await? } - (ParsedQuery::TimeInterval(batch_interval), VdafInstance::Prio3Histogram { length }) => { - let vdaf = Prio3::new_histogram(2, length, VdafInstance::chunk_size(length)) + ( + ParsedQuery::TimeInterval(batch_interval), + VdafInstance::Prio3Histogram { + length, + chunk_length, + }, + ) => { + let vdaf = Prio3::new_histogram(2, length, chunk_length) .context("failed to construct Prio3Histogram VDAF")?; handle_collect_generic( http_client, @@ -468,10 +457,15 @@ async fn handle_collection_start( .await? } - (ParsedQuery::FixedSize(fixed_size_query), VdafInstance::Prio3CountVec { length }) => { - let vdaf = - Prio3::new_sum_vec_multithreaded(2, 1, length, VdafInstance::chunk_size(length)) - .context("failed to construct Prio3CountVec VDAF")?; + ( + ParsedQuery::FixedSize(fixed_size_query), + VdafInstance::Prio3CountVec { + length, + chunk_length, + }, + ) => { + let vdaf = Prio3::new_sum_vec_multithreaded(2, 1, length, chunk_length) + .context("failed to construct Prio3CountVec VDAF")?; handle_collect_generic( http_client, collector_params, @@ -570,14 +564,16 @@ async fn handle_collection_start( .await? } - (ParsedQuery::FixedSize(fixed_size_query), VdafInstance::Prio3SumVec { bits, length }) => { - let vdaf = Prio3::new_sum_vec_multithreaded( - 2, + ( + ParsedQuery::FixedSize(fixed_size_query), + VdafInstance::Prio3SumVec { bits, length, - VdafInstance::chunk_size(bits * length), - ) - .context("failed to construct Prio3SumVec VDAF")?; + chunk_length, + }, + ) => { + let vdaf = Prio3::new_sum_vec_multithreaded(2, bits, length, chunk_length) + .context("failed to construct Prio3SumVec VDAF")?; handle_collect_generic( http_client, collector_params, @@ -593,8 +589,14 @@ async fn handle_collection_start( .await? } - (ParsedQuery::FixedSize(fixed_size_query), VdafInstance::Prio3Histogram { length }) => { - let vdaf = Prio3::new_histogram(2, length, VdafInstance::chunk_size(length)) + ( + ParsedQuery::FixedSize(fixed_size_query), + VdafInstance::Prio3Histogram { + length, + chunk_length, + }, + ) => { + let vdaf = Prio3::new_histogram(2, length, chunk_length) .context("failed to construct Prio3Histogram VDAF")?; handle_collect_generic( http_client, diff --git a/interop_binaries/src/lib.rs b/interop_binaries/src/lib.rs index 957a9440b..38c40fe85 100644 --- a/interop_binaries/src/lib.rs +++ b/interop_binaries/src/lib.rs @@ -106,18 +106,17 @@ where #[serde(tag = "type")] pub enum VdafObject { Prio3Count, - Prio3CountVec { - length: NumberAsString, - }, Prio3Sum { bits: NumberAsString, }, Prio3SumVec { bits: NumberAsString, length: NumberAsString, + chunk_length: NumberAsString, }, Prio3Histogram { length: NumberAsString, + chunk_length: NumberAsString, }, #[cfg(feature = "fpvec_bounded_l2")] Prio3FixedPoint16BitBoundedL2VecSum { @@ -138,21 +137,26 @@ impl From for VdafObject { match vdaf { VdafInstance::Prio3Count => VdafObject::Prio3Count, - VdafInstance::Prio3CountVec { length } => VdafObject::Prio3CountVec { - length: NumberAsString(length), - }, - VdafInstance::Prio3Sum { bits } => VdafObject::Prio3Sum { bits: NumberAsString(bits), }, - VdafInstance::Prio3SumVec { bits, length } => VdafObject::Prio3SumVec { + VdafInstance::Prio3SumVec { + bits, + length, + chunk_length, + } => VdafObject::Prio3SumVec { bits: NumberAsString(bits), length: NumberAsString(length), + chunk_length: NumberAsString(chunk_length), }, - VdafInstance::Prio3Histogram { length } => VdafObject::Prio3Histogram { + VdafInstance::Prio3Histogram { + length, + chunk_length, + } => VdafObject::Prio3Histogram { length: NumberAsString(length), + chunk_length: NumberAsString(chunk_length), }, #[cfg(feature = "fpvec_bounded_l2")] @@ -185,20 +189,25 @@ impl From for VdafInstance { match vdaf { VdafObject::Prio3Count => VdafInstance::Prio3Count, - VdafObject::Prio3CountVec { length } => { - VdafInstance::Prio3CountVec { length: length.0 } - } - VdafObject::Prio3Sum { bits } => VdafInstance::Prio3Sum { bits: bits.0 }, - VdafObject::Prio3SumVec { bits, length } => VdafInstance::Prio3SumVec { + VdafObject::Prio3SumVec { + bits, + length, + chunk_length, + } => VdafInstance::Prio3SumVec { bits: bits.0, length: length.0, + chunk_length: chunk_length.0, }, - VdafObject::Prio3Histogram { length } => { - VdafInstance::Prio3Histogram { length: length.0 } - } + VdafObject::Prio3Histogram { + length, + chunk_length, + } => VdafInstance::Prio3Histogram { + length: length.0, + chunk_length: chunk_length.0, + }, #[cfg(feature = "fpvec_bounded_l2")] VdafObject::Prio3FixedPoint16BitBoundedL2VecSum { length } => { diff --git a/interop_binaries/tests/end_to_end.rs b/interop_binaries/tests/end_to_end.rs index 1c72310be..3d57c0518 100644 --- a/interop_binaries/tests/end_to_end.rs +++ b/interop_binaries/tests/end_to_end.rs @@ -621,7 +621,12 @@ async fn e2e_prio3_sum() { async fn e2e_prio3_sum_vec() { let result = run( QueryKind::TimeInterval, - json!({"type": "Prio3SumVec", "bits": "64", "length": "4"}), + json!({ + "type": "Prio3SumVec", + "bits": "64", + "length": "4", + "chunk_length": "18", + }), &[ json!(["0", "0", "0", "10"]), json!(["0", "0", "10", "0"]), @@ -643,6 +648,7 @@ async fn e2e_prio3_histogram() { json!({ "type": "Prio3Histogram", "length": "6", + "chunk_length": "2", }), &[ json!("0"), @@ -663,28 +669,6 @@ async fn e2e_prio3_histogram() { } } -#[tokio::test] -async fn e2e_prio3_count_vec() { - let result = run( - QueryKind::TimeInterval, - json!({"type": "Prio3CountVec", "length": "4"}), - &[ - json!(["0", "0", "0", "1"]), - json!(["0", "0", "1", "0"]), - json!(["0", "1", "0", "0"]), - json!(["1", "0", "0", "0"]), - ], - b"", - ) - .await; - for element in result - .as_array() - .expect("CountVec result should be an array") - { - assert!(element.is_string()); - } -} - #[tokio::test] async fn e2e_prio3_fixed16vec() { let fp16_4_inv = fixed!(0.25: I1F15); diff --git a/messages/src/taskprov.rs b/messages/src/taskprov.rs index 72036498c..c8bbea0fa 100644 --- a/messages/src/taskprov.rs +++ b/messages/src/taskprov.rs @@ -6,8 +6,8 @@ use crate::{Duration, Error, Role, Time, Url}; use anyhow::anyhow; use derivative::Derivative; use prio::codec::{ - decode_u16_items, decode_u24_items, decode_u8_items, encode_u16_items, encode_u24_items, - encode_u8_items, CodecError, Decode, Encode, + decode_u16_items, decode_u8_items, encode_u16_items, encode_u8_items, CodecError, Decode, + Encode, }; use std::{fmt::Debug, io::Cursor}; @@ -261,13 +261,6 @@ pub struct VdafConfig { impl VdafConfig { pub fn new(dp_config: DpConfig, vdaf_type: VdafType) -> Result { - if let VdafType::Prio3Histogram { buckets } = &vdaf_type { - if buckets.is_empty() { - return Err(Error::InvalidParameter( - "buckets must not be empty for Prio3Histogram", - )); - } - } Ok(Self { dp_config, vdaf_type, @@ -300,13 +293,6 @@ impl Decode for VdafConfig { dp_config: DpConfig::decode(bytes)?, vdaf_type: VdafType::decode(bytes)?, }; - if let VdafType::Prio3Histogram { buckets } = &ret.vdaf_type { - if buckets.is_empty() { - return Err(CodecError::Other( - anyhow!("buckets must not be empty for Prio3Histogram").into(), - )); - } - } Ok(ret) } } @@ -321,11 +307,19 @@ pub enum VdafType { /// Bit length of the summand. bits: u8, }, + Prio3SumVec { + /// Bit length of each summand. + bits: u8, + /// Number of summands. + length: u32, + /// Size of each proof chunk. + chunk_length: u32, + }, Prio3Histogram { - /// Number of buckets in the histogram - // This may change as the taskprov draft adapts to VDAF-06 - // https://github.com/wangshan/draft-wang-ppm-dap-taskprov/issues/33 - buckets: Vec, + /// Number of buckets. + length: u32, + /// Size of each proof chunk. + chunk_length: u32, }, Poplar1 { /// Bit length of the input string. @@ -336,7 +330,8 @@ pub enum VdafType { impl VdafType { const PRIO3COUNT: u32 = 0x00000000; const PRIO3SUM: u32 = 0x00000001; - const PRIO3HISTOGRAM: u32 = 0x00000002; + const PRIO3SUMVEC: u32 = 0x00000002; + const PRIO3HISTOGRAM: u32 = 0x00000003; const POPLAR1: u32 = 0x00001000; } @@ -348,9 +343,23 @@ impl Encode for VdafType { Self::PRIO3SUM.encode(bytes); bits.encode(bytes); } - Self::Prio3Histogram { buckets } => { + Self::Prio3SumVec { + bits, + length, + chunk_length, + } => { + Self::PRIO3SUMVEC.encode(bytes); + bits.encode(bytes); + length.encode(bytes); + chunk_length.encode(bytes); + } + Self::Prio3Histogram { + length, + chunk_length, + } => { Self::PRIO3HISTOGRAM.encode(bytes); - encode_u24_items(bytes, &(), buckets); + length.encode(bytes); + chunk_length.encode(bytes); } Self::Poplar1 { bits } => { Self::POPLAR1.encode(bytes); @@ -363,9 +372,10 @@ impl Encode for VdafType { Some( 4 + match self { Self::Prio3Count => 0, - Self::Prio3Sum { bits } => bits.encoded_len()?, - Self::Prio3Histogram { buckets } => 3 + buckets.len() * 0u64.encoded_len()?, - Self::Poplar1 { bits } => bits.encoded_len()?, + Self::Prio3Sum { .. } => 1, + Self::Prio3SumVec { .. } => 9, + Self::Prio3Histogram { .. } => 8, + Self::Poplar1 { .. } => 2, }, ) } @@ -378,8 +388,14 @@ impl Decode for VdafType { Self::PRIO3SUM => Ok(Self::Prio3Sum { bits: u8::decode(bytes)?, }), + Self::PRIO3SUMVEC => Ok(Self::Prio3SumVec { + bits: u8::decode(bytes)?, + length: u32::decode(bytes)?, + chunk_length: u32::decode(bytes)?, + }), Self::PRIO3HISTOGRAM => Ok(Self::Prio3Histogram { - buckets: decode_u24_items(&(), bytes)?, + length: u32::decode(bytes)?, + chunk_length: u32::decode(bytes)?, }), Self::POPLAR1 => Ok(Self::Poplar1 { bits: u16::decode(bytes)?, @@ -496,26 +512,27 @@ mod tests { concat!("00000001", "FF"), ), ( - VdafType::Prio3Histogram { - buckets: vec![0x00ABCDEF, 0x40404040, 0xDEADBEEF], + VdafType::Prio3SumVec { + bits: 8, + length: 12, + chunk_length: 14, }, concat!( - "00000002", - "000018", // length - "0000000000ABCDEF", - "0000000040404040", - "00000000DEADBEEF", + "00000002", // algorithm ID + "08", // bits + "0000000C", // length + "0000000E" // chunk_length ), ), ( VdafType::Prio3Histogram { - buckets: vec![u64::MIN, u64::MAX], + length: 256, + chunk_length: 18, }, concat!( - "00000002", - "000010", // length - "0000000000000000", - "FFFFFFFFFFFFFFFF", + "00000003", // algorithm ID + "00000100", // length + "00000012", // chunk_length ), ), ( @@ -548,25 +565,30 @@ mod tests { .unwrap(), concat!("01", concat!("00000001", "42")), ), + ( + VdafConfig::new( + DpConfig::new(DpMechanism::None), + VdafType::Prio3SumVec { + bits: 8, + length: 12, + chunk_length: 14, + }, + ) + .unwrap(), + concat!("01", concat!("00000002", "08", "0000000C", "0000000E")), + ), ( VdafConfig::new( DpConfig::new(DpMechanism::None), VdafType::Prio3Histogram { - buckets: vec![0xAAAAAAAA], + length: 10, + chunk_length: 4, }, ) .unwrap(), - concat!("01", concat!("00000002", "000008", "00000000AAAAAAAA")), + concat!("01", concat!("00000003", "0000000A", "00000004")), ), ]); - - // Empty Prio3Histogram buckets. - assert_matches!( - VdafConfig::get_decoded( - &hex::decode(concat!("01", concat!("00000002", "000000"))).unwrap() - ), - Err(CodecError::Other(_)) - ); } #[test] @@ -711,7 +733,8 @@ mod tests { VdafConfig::new( DpConfig::new(DpMechanism::None), VdafType::Prio3Histogram { - buckets: vec![0xFFFF], + length: 10, + chunk_length: 4, }, ) .unwrap(), @@ -742,11 +765,10 @@ mod tests { concat!( // vdaf_config "01", // dp_config - "00000002", // vdaf_type + "00000003", // vdaf_type concat!( - // buckets - "000008", // length - "000000000000FFFF" // bucket + "0000000A", // length + "00000004" // chunk_length ) ), ), diff --git a/tools/src/bin/collect.rs b/tools/src/bin/collect.rs index 739738d45..2c397276f 100644 --- a/tools/src/bin/collect.rs +++ b/tools/src/bin/collect.rs @@ -11,10 +11,7 @@ use fixed::types::extra::{U15, U31, U63}; #[cfg(feature = "fpvec_bounded_l2")] use fixed::{FixedI16, FixedI32, FixedI64}; use janus_collector::{default_http_client, AuthenticationToken, Collector, CollectorParameters}; -use janus_core::{ - hpke::{DivviUpHpkeConfig, HpkeKeypair, HpkePrivateKey}, - task::VdafInstance, -}; +use janus_core::hpke::{DivviUpHpkeConfig, HpkeKeypair, HpkePrivateKey}; use janus_messages::{ query_type::{FixedSize, QueryType, TimeInterval}, BatchId, Duration, FixedSizeQuery, HpkeConfig, Interval, PartialBatchSelector, Query, TaskId, @@ -454,8 +451,10 @@ where .map_err(|err| Error::Anyhow(err.into())) } (VdafType::CountVec, Some(length), None) => { - let vdaf = Prio3::new_sum_vec(2, 1, length, VdafInstance::chunk_size(length)) - .map_err(|err| Error::Anyhow(err.into()))?; + // We can take advantage of the fact that Prio3SumVec unsharding does not use the + // chunk_length parameter and avoid asking the user for it. + let vdaf = + Prio3::new_sum_vec(2, 1, length, 1).map_err(|err| Error::Anyhow(err.into()))?; run_collection_generic(parameters, vdaf, http_client, query, &()) .await .map_err(|err| Error::Anyhow(err.into())) @@ -467,15 +466,19 @@ where .map_err(|err| Error::Anyhow(err.into())) } (VdafType::SumVec, Some(length), Some(bits)) => { - let vdaf = Prio3::new_sum_vec(2, bits, length, VdafInstance::chunk_size(bits * length)) - .map_err(|err| Error::Anyhow(err.into()))?; + // We can take advantage of the fact that Prio3SumVec unsharding does not use the + // chunk_length parameter and avoid asking the user for it. + let vdaf = + Prio3::new_sum_vec(2, bits, length, 1).map_err(|err| Error::Anyhow(err.into()))?; run_collection_generic(parameters, vdaf, http_client, query, &()) .await .map_err(|err| Error::Anyhow(err.into())) } (VdafType::Histogram, Some(length), None) => { - let vdaf = Prio3::new_histogram(2, length, VdafInstance::chunk_size(length)) - .map_err(|err| Error::Anyhow(err.into()))?; + // We can take advantage of the fact that Prio3Histogram unsharding does not use the + // chunk_length parameter and avoid asking the user for it. + let vdaf = + Prio3::new_histogram(2, length, 1).map_err(|err| Error::Anyhow(err.into()))?; run_collection_generic(parameters, vdaf, http_client, query, &()) .await .map_err(|err| Error::Anyhow(err.into()))