From 74aa9575a47df392de7f3e47b027a0e389d64737 Mon Sep 17 00:00:00 2001 From: Tim Geoghegan Date: Wed, 16 Aug 2023 13:55:42 -0700 Subject: [PATCH] Adopt `prio` 0.14.0 (#1673) The bulk of the changes here deal with the change to the representation of `Prio3Histogram`. Since `prio` 0.14.x implements VDAF-06, taking this change will break compatibility with DAP-04. This also breaks compatibility with existing `divviup-api` versions, because it has to deal with the new histogram representation. Integration tests with Daphne and divviup-ts are disabled for the same reason. --- Cargo.lock | 5 +- Cargo.toml | 2 +- aggregator/src/aggregator.rs | 4 +- .../src/aggregator/aggregation_job_creator.rs | 13 +-- aggregator_core/src/datastore/tests.rs | 14 +-- collector/src/lib.rs | 6 +- core/src/task.rs | 67 +++++------ integration_tests/src/client.rs | 10 +- integration_tests/src/divviup_api_client.rs | 16 ++- integration_tests/tests/common/mod.rs | 17 +-- integration_tests/tests/daphne.rs | 1 + integration_tests/tests/divviup_ts.rs | 7 +- integration_tests/tests/in_cluster.rs | 4 +- integration_tests/tests/janus.rs | 4 +- .../src/bin/janus_interop_client.rs | 6 +- .../src/bin/janus_interop_collector.rs | 8 +- interop_binaries/src/lib.rs | 12 +- interop_binaries/tests/end_to_end.rs | 14 +-- messages/Cargo.toml | 2 +- messages/src/taskprov.rs | 14 +-- tools/src/bin/collect.rs | 109 +++--------------- tools/tests/cmd/collect.trycmd | 5 +- .../tests/cmd/collect_fpvec_bounded_l2.trycmd | 7 +- 23 files changed, 110 insertions(+), 237 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 23f251cce..652a64725 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3033,9 +3033,9 @@ checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" [[package]] name = "prio" -version = "0.12.2" +version = "0.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9028a8aba9ba6b647c6d6931c20473d1119079a68d9898c07a488c5180dccb58" +checksum = "e1139097c0aa90a7e476953f358c0cc25a627ede8ac0dd47e05594a37d665273" dependencies = [ "aes", "base64 0.21.2", @@ -3049,7 +3049,6 @@ dependencies = [ "rayon", "serde", "sha3", - "static_assertions", "subtle", "thiserror", ] diff --git a/Cargo.toml b/Cargo.toml index bdc3654dd..aaaf2be3c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -42,7 +42,7 @@ janus_messages = { version = "0.5", path = "messages" } k8s-openapi = { version = "0.18.0", features = ["v1_24"] } # keep this version in sync with what is referenced by the indirect dependency via `kube` kube = { version = "0.82.2", default-features = false, features = ["client", "rustls-tls"] } opentelemetry = { version = "0.19", features = ["metrics"] } -prio = { version = "0.12.2", features = ["multithreaded"] } +prio = { version = "0.14.0", features = ["multithreaded"] } serde = { version = "1.0.183", features = ["derive"] } serde_json = "1.0.103" serde_test = "1.0.175" diff --git a/aggregator/src/aggregator.rs b/aggregator/src/aggregator.rs index 8dc356677..2899685f3 100644 --- a/aggregator/src/aggregator.rs +++ b/aggregator/src/aggregator.rs @@ -787,8 +787,8 @@ impl TaskAggregator { VdafOps::Prio3SumVec(Arc::new(vdaf), verify_key) } - VdafInstance::Prio3Histogram { buckets } => { - let vdaf = Prio3::new_histogram(2, buckets)?; + VdafInstance::Prio3Histogram { length } => { + let vdaf = Prio3::new_histogram(2, *length)?; let verify_key = task.primary_vdaf_verify_key()?; VdafOps::Prio3Histogram(Arc::new(vdaf), verify_key) } diff --git a/aggregator/src/aggregator/aggregation_job_creator.rs b/aggregator/src/aggregator/aggregation_job_creator.rs index 00014f979..ba9ca1de8 100644 --- a/aggregator/src/aggregator/aggregation_job_creator.rs +++ b/aggregator/src/aggregator/aggregation_job_creator.rs @@ -290,8 +290,8 @@ impl AggregationJobCreator { .await } - (task::QueryType::TimeInterval, VdafInstance::Prio3Histogram { buckets }) => { - let vdaf = Arc::new(Prio3::new_histogram(2, buckets)?); + (task::QueryType::TimeInterval, VdafInstance::Prio3Histogram { length }) => { + let vdaf = Arc::new(Prio3::new_histogram(2, *length)?); self.create_aggregation_jobs_for_time_interval_task_no_param::(task, vdaf) .await } @@ -404,9 +404,9 @@ impl AggregationJobCreator { max_batch_size, batch_time_window_size, }, - VdafInstance::Prio3Histogram { buckets }, + VdafInstance::Prio3Histogram { length }, ) => { - let vdaf = Arc::new(Prio3::new_histogram(2, buckets)?); + let vdaf = Arc::new(Prio3::new_histogram(2, *length)?); let max_batch_size = *max_batch_size; let batch_time_window_size = *batch_time_window_size; self.create_aggregation_jobs_for_fixed_size_task_no_param::< @@ -660,10 +660,7 @@ mod tests { }; use janus_core::{ task::{VdafInstance, PRIO3_VERIFY_KEY_LENGTH}, - test_util::{ - dummy_vdaf::{self}, - install_test_trace_subscriber, - }, + test_util::{dummy_vdaf, install_test_trace_subscriber}, time::{Clock, DurationExt, IntervalExt, MockClock, TimeExt}, }; use janus_messages::{ diff --git a/aggregator_core/src/datastore/tests.rs b/aggregator_core/src/datastore/tests.rs index de42ba883..4f5b3a348 100644 --- a/aggregator_core/src/datastore/tests.rs +++ b/aggregator_core/src/datastore/tests.rs @@ -109,18 +109,8 @@ async fn roundtrip_task(ephemeral_datastore: EphemeralDatastore) { (VdafInstance::Prio3CountVec { length: 64 }, Role::Helper), (VdafInstance::Prio3Sum { bits: 64 }, Role::Helper), (VdafInstance::Prio3Sum { bits: 32 }, Role::Helper), - ( - VdafInstance::Prio3Histogram { - buckets: Vec::from([0, 100, 200, 400]), - }, - Role::Leader, - ), - ( - VdafInstance::Prio3Histogram { - buckets: Vec::from([0, 25, 50, 75, 100]), - }, - Role::Leader, - ), + (VdafInstance::Prio3Histogram { length: 4 }, Role::Leader), + (VdafInstance::Prio3Histogram { length: 5 }, Role::Leader), (VdafInstance::Poplar1 { bits: 8 }, Role::Helper), (VdafInstance::Poplar1 { bits: 64 }, Role::Helper), ] { diff --git a/collector/src/lib.rs b/collector/src/lib.rs index b80282e00..2b6ed7d1c 100644 --- a/collector/src/lib.rs +++ b/collector/src/lib.rs @@ -1001,8 +1001,8 @@ mod tests { async fn successful_collect_prio3_histogram() { install_test_trace_subscriber(); let mut server = mockito::Server::new_async().await; - let vdaf = Prio3::new_histogram(2, &[25, 50, 75, 100]).unwrap(); - let transcript = run_vdaf(&vdaf, &random(), &(), &random(), &80); + let vdaf = Prio3::new_histogram(2, 4).unwrap(); + let transcript = run_vdaf(&vdaf, &random(), &(), &random(), &3); let collector = setup_collector(&mut server, vdaf); let batch_interval = Interval::new( @@ -1058,7 +1058,7 @@ mod tests { ), chrono::Duration::seconds(3600), ), - Vec::from([0, 0, 0, 1, 0]) + Vec::from([0, 0, 0, 1]) ) ); diff --git a/core/src/task.rs b/core/src/task.rs index 07076163f..f7de447eb 100644 --- a/core/src/task.rs +++ b/core/src/task.rs @@ -6,7 +6,7 @@ use rand::{distributions::Standard, prelude::Distribution}; use reqwest::Url; use ring::constant_time; use serde::{de::Error, Deserialize, Deserializer, Serialize}; -use std::{fmt, str}; +use std::str; /// HTTP header where auth tokens are provided in messages between participants. pub const DAP_AUTH_HEADER: &str = "DAP-Auth-Token"; @@ -30,11 +30,8 @@ pub enum VdafInstance { Prio3Sum { bits: usize }, /// A vector of `Prio3` sums. Prio3SumVec { bits: usize, length: usize }, - /// A `Prio3` histogram. - Prio3Histogram { - #[derivative(Debug(format_with = "bucket_count"))] - buckets: Vec, - }, + /// A `Prio3` histogram with `length` buckets in it. + Prio3Histogram { length: usize }, /// A `Prio3` 16-bit fixed point vector sum with bounded L2 norm. #[cfg(feature = "fpvec_bounded_l2")] Prio3FixedPoint16BitBoundedL2VecSum { length: usize }, @@ -61,6 +58,22 @@ pub enum VdafInstance { FakeFailsPrepStep, } +impl VdafInstance { + /// Returns the expected length of a VDAF verification key for a VDAF of this type. + pub fn verify_key_length(&self) -> usize { + match self { + #[cfg(feature = "test-util")] + VdafInstance::Fake + | VdafInstance::FakeFailsPrepInit + | VdafInstance::FakeFailsPrepStep => 0, + + // All "real" VDAFs use a verify key of length 16 currently. (Poplar1 may not, but it's + // not yet done being specified, so choosing 16 bytes is fine for testing.) + _ => PRIO3_VERIFY_KEY_LENGTH, + } + } +} + impl TryFrom<&taskprov::VdafType> for VdafInstance { type Error = &'static str; @@ -71,7 +84,10 @@ impl TryFrom<&taskprov::VdafType> for VdafInstance { bits: *bits as usize, }), taskprov::VdafType::Prio3Histogram { buckets } => Ok(Self::Prio3Histogram { - buckets: buckets.clone(), + // taskprov does not yet deal with the VDAF-06 representation of histograms. In the + // meantime, we translate the bucket boundaries to a length that Janus understands. + // https://github.com/wangshan/draft-wang-ppm-dap-taskprov/issues/33 + length: buckets.len() + 1, // +1 to account for the top bucket extending to infinity }), taskprov::VdafType::Poplar1 { bits } => Ok(Self::Poplar1 { bits: *bits as usize, @@ -81,26 +97,6 @@ impl TryFrom<&taskprov::VdafType> for VdafInstance { } } -fn bucket_count(buckets: &Vec, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "[{} buckets]", buckets.len() + 1) -} - -impl VdafInstance { - /// Returns the expected length of a VDAF verification key for a VDAF of this type. - pub fn verify_key_length(&self) -> usize { - match self { - #[cfg(feature = "test-util")] - VdafInstance::Fake - | VdafInstance::FakeFailsPrepInit - | VdafInstance::FakeFailsPrepStep => 0, - - // All "real" VDAFs use a verify key of length 16 currently. (Poplar1 may not, but it's - // not yet done being specified, so choosing 16 bytes is fine for testing.) - _ => PRIO3_VERIFY_KEY_LENGTH, - } - } -} - /// Internal implementation details of [`vdaf_dispatch`](crate::vdaf_dispatch). #[macro_export] macro_rules! vdaf_dispatch_impl_base { @@ -174,8 +170,8 @@ macro_rules! vdaf_dispatch_impl_base { $body } - ::janus_core::task::VdafInstance::Prio3Histogram { buckets } => { - let $vdaf = ::prio::vdaf::prio3::Prio3::new_histogram(2, buckets)?; + ::janus_core::task::VdafInstance::Prio3Histogram { length } => { + let $vdaf = ::prio::vdaf::prio3::Prio3::new_histogram(2, *length)?; type $Vdaf = ::prio::vdaf::prio3::Prio3Histogram; const $VERIFY_KEY_LENGTH: usize = ::janus_core::task::PRIO3_VERIFY_KEY_LENGTH; $body @@ -768,22 +764,15 @@ mod tests { ], ); assert_tokens( - &VdafInstance::Prio3Histogram { - buckets: Vec::from([0, 100, 200, 400]), - }, + &VdafInstance::Prio3Histogram { length: 6 }, &[ Token::StructVariant { name: "VdafInstance", variant: "Prio3Histogram", len: 1, }, - Token::Str("buckets"), - Token::Seq { len: Some(4) }, - Token::U64(0), - Token::U64(100), - Token::U64(200), - Token::U64(400), - Token::SeqEnd, + Token::Str("length"), + Token::U64(6), Token::StructVariantEnd, ], ); diff --git a/integration_tests/src/client.rs b/integration_tests/src/client.rs index 6b5fb45bd..28f00218c 100644 --- a/integration_tests/src/client.rs +++ b/integration_tests/src/client.rs @@ -71,16 +71,10 @@ fn json_encode_vdaf(vdaf: &VdafInstance) -> Value { "bits": format!("{bits}"), "length": format!("{length}"), }), - VdafInstance::Prio3Histogram { buckets } => { - let buckets = Value::Array( - buckets - .iter() - .map(|value| Value::String(format!("{value}"))) - .collect(), - ); + VdafInstance::Prio3Histogram { length } => { json!({ "type": "Prio3Histogram", - "buckets": buckets, + "length": format!("{length}"), }) } _ => panic!("VDAF {vdaf:?} is not yet supported"), diff --git a/integration_tests/src/divviup_api_client.rs b/integration_tests/src/divviup_api_client.rs index cdb323d80..71622a866 100644 --- a/integration_tests/src/divviup_api_client.rs +++ b/integration_tests/src/divviup_api_client.rs @@ -1,4 +1,4 @@ -use anyhow::anyhow; +use anyhow::{anyhow, Context}; use http::{ header::{ACCEPT, CONTENT_TYPE}, Method, @@ -35,9 +35,17 @@ impl TryFrom<&VdafInstance> for ApiVdaf { match vdaf { VdafInstance::Prio3Count => Ok(ApiVdaf::Count), VdafInstance::Prio3Sum { bits } => Ok(ApiVdaf::Sum { bits: *bits }), - VdafInstance::Prio3Histogram { buckets } => Ok(ApiVdaf::Histogram { - buckets: buckets.clone(), - }), + VdafInstance::Prio3Histogram { length } => { + // divviup-api does not yet support the new Prio3Histogram representation. Until it + // does, we synthesize fake bucket boundaries that will yield the number of buckets + // we want. + // https://github.com/divviup/divviup-api/issues/410 + Ok(ApiVdaf::Histogram { + buckets: (0..*length - 1) + .map(|length| u64::try_from(length).context("cannot convert length to u64")) + .collect::, _>>()?, + }) + } _ => Err(anyhow!("unsupported VDAF: {vdaf:?}")), } } diff --git a/integration_tests/tests/common/mod.rs b/integration_tests/tests/common/mod.rs index 3542bef9e..645cb18c9 100644 --- a/integration_tests/tests/common/mod.rs +++ b/integration_tests/tests/common/mod.rs @@ -322,21 +322,14 @@ pub async fn submit_measurements_and_verify_aggregate( ) .await; } - VdafInstance::Prio3Histogram { buckets } => { - let vdaf = Prio3::new_histogram(2, buckets).unwrap(); + VdafInstance::Prio3Histogram { length } => { + let vdaf = Prio3::new_histogram(2, *length).unwrap(); - let mut aggregate_result = vec![0; buckets.len() + 1]; - aggregate_result.resize(buckets.len() + 1, 0); + let mut aggregate_result = vec![0; *length]; let measurements = iter::repeat_with(|| { - let choice = thread_rng().gen_range(0..=buckets.len()); + let choice = thread_rng().gen_range(0..*length); aggregate_result[choice] += 1; - let measurement = if choice == buckets.len() { - // This goes into the counter covering the range that extends to positive infinity. - buckets[buckets.len() - 1] + 1 - } else { - buckets[choice] - }; - measurement as u128 + choice }) .take(total_measurements) .collect::>(); diff --git a/integration_tests/tests/daphne.rs b/integration_tests/tests/daphne.rs index e351e6070..2c6ee5f28 100644 --- a/integration_tests/tests/daphne.rs +++ b/integration_tests/tests/daphne.rs @@ -49,6 +49,7 @@ async fn daphne_janus() { // This test places Janus in the leader role & Daphne in the helper role. #[tokio::test(flavor = "multi_thread")] +#[ignore = "Daphne does not currently support DAP-05 (issue #1669)"] async fn janus_daphne() { install_test_trace_subscriber(); diff --git a/integration_tests/tests/divviup_ts.rs b/integration_tests/tests/divviup_ts.rs index 3aecb6648..c16cfddf7 100644 --- a/integration_tests/tests/divviup_ts.rs +++ b/integration_tests/tests/divviup_ts.rs @@ -36,6 +36,7 @@ async fn run_divviup_ts_integration_test(container_client: &Cli, vdaf: VdafInsta } #[tokio::test(flavor = "multi_thread")] +#[ignore = "divviup-ts does not currently support DAP-05 (issue #1669)"] async fn janus_divviup_ts_count() { install_test_trace_subscriber(); @@ -43,6 +44,7 @@ async fn janus_divviup_ts_count() { } #[tokio::test(flavor = "multi_thread")] +#[ignore = "divviup-ts does not currently support DAP-05 (issue #1669)"] async fn janus_divviup_ts_sum() { install_test_trace_subscriber(); @@ -50,14 +52,13 @@ async fn janus_divviup_ts_sum() { } #[tokio::test(flavor = "multi_thread")] +#[ignore = "divviup-ts does not currently support DAP-05 (issue #1669)"] async fn janus_divviup_ts_histogram() { install_test_trace_subscriber(); run_divviup_ts_integration_test( &container_client(), - VdafInstance::Prio3Histogram { - buckets: Vec::from([1, 10, 100, 1000]), - }, + VdafInstance::Prio3Histogram { length: 4 }, ) .await; } diff --git a/integration_tests/tests/in_cluster.rs b/integration_tests/tests/in_cluster.rs index 9c4e35ab1..2faf7b8db 100644 --- a/integration_tests/tests/in_cluster.rs +++ b/integration_tests/tests/in_cluster.rs @@ -252,13 +252,13 @@ async fn in_cluster_sum() { } #[tokio::test(flavor = "multi_thread")] +#[ignore = "divviup-api does not currently support DAP-05 (https://github.com/divviup/divviup-api/issues/410)"] async fn in_cluster_histogram() { install_test_trace_subscriber(); // Start port forwards and set up task. - let buckets = Vec::from([3, 6, 8]); let janus_pair = InClusterJanusPair::new( - VdafInstance::Prio3Histogram { buckets }, + VdafInstance::Prio3Histogram { length: 4 }, QueryType::TimeInterval, ) .await; diff --git a/integration_tests/tests/janus.rs b/integration_tests/tests/janus.rs index a7af9113d..59137f2f6 100644 --- a/integration_tests/tests/janus.rs +++ b/integration_tests/tests/janus.rs @@ -95,13 +95,11 @@ async fn janus_janus_sum_16() { async fn janus_janus_histogram_4_buckets() { install_test_trace_subscriber(); - let buckets = Vec::from([3, 6, 8]); - // Start servers. let container_client = container_client(); let janus_pair = JanusPair::new( &container_client, - VdafInstance::Prio3Histogram { buckets }, + VdafInstance::Prio3Histogram { length: 4 }, QueryType::TimeInterval, ) .await; diff --git a/interop_binaries/src/bin/janus_interop_client.rs b/interop_binaries/src/bin/janus_interop_client.rs index c22701163..cca595006 100644 --- a/interop_binaries/src/bin/janus_interop_client.rs +++ b/interop_binaries/src/bin/janus_interop_client.rs @@ -168,9 +168,9 @@ async fn handle_upload( handle_upload_generic(http_client, vdaf_client, request, measurement).await?; } - VdafInstance::Prio3Histogram { ref buckets } => { - let measurement = parse_primitive_measurement::(request.measurement.clone())?; - let vdaf_client = Prio3::new_histogram(2, buckets) + VdafInstance::Prio3Histogram { length } => { + let measurement = parse_primitive_measurement::(request.measurement.clone())?; + let vdaf_client = Prio3::new_histogram(2, length) .context("failed to construct Prio3Histogram VDAF")?; handle_upload_generic(http_client, vdaf_client, request, measurement).await?; } diff --git a/interop_binaries/src/bin/janus_interop_collector.rs b/interop_binaries/src/bin/janus_interop_collector.rs index d25f674d5..616f8382b 100644 --- a/interop_binaries/src/bin/janus_interop_collector.rs +++ b/interop_binaries/src/bin/janus_interop_collector.rs @@ -361,8 +361,8 @@ async fn handle_collection_start( .await? } - (ParsedQuery::TimeInterval(batch_interval), VdafInstance::Prio3Histogram { buckets }) => { - let vdaf = Prio3::new_histogram(2, &buckets) + (ParsedQuery::TimeInterval(batch_interval), VdafInstance::Prio3Histogram { length }) => { + let vdaf = Prio3::new_histogram(2, length) .context("failed to construct Prio3Histogram VDAF")?; handle_collect_generic( http_client, @@ -581,8 +581,8 @@ async fn handle_collection_start( .await? } - (ParsedQuery::FixedSize(fixed_size_query), VdafInstance::Prio3Histogram { buckets }) => { - let vdaf = Prio3::new_histogram(2, &buckets) + (ParsedQuery::FixedSize(fixed_size_query), VdafInstance::Prio3Histogram { length }) => { + let vdaf = Prio3::new_histogram(2, length) .context("failed to construct Prio3Histogram VDAF")?; handle_collect_generic( http_client, diff --git a/interop_binaries/src/lib.rs b/interop_binaries/src/lib.rs index e7861934b..926b8535a 100644 --- a/interop_binaries/src/lib.rs +++ b/interop_binaries/src/lib.rs @@ -117,7 +117,7 @@ pub enum VdafObject { length: NumberAsString, }, Prio3Histogram { - buckets: Vec>, + length: NumberAsString, }, #[cfg(feature = "fpvec_bounded_l2")] Prio3FixedPoint16BitBoundedL2VecSum { @@ -151,8 +151,8 @@ impl From for VdafObject { length: NumberAsString(length), }, - VdafInstance::Prio3Histogram { buckets } => VdafObject::Prio3Histogram { - buckets: buckets.iter().copied().map(NumberAsString).collect(), + VdafInstance::Prio3Histogram { length } => VdafObject::Prio3Histogram { + length: NumberAsString(length), }, #[cfg(feature = "fpvec_bounded_l2")] @@ -196,9 +196,9 @@ impl From for VdafInstance { length: length.0, }, - VdafObject::Prio3Histogram { buckets } => VdafInstance::Prio3Histogram { - buckets: buckets.iter().map(|value| value.0).collect(), - }, + VdafObject::Prio3Histogram { length } => { + VdafInstance::Prio3Histogram { length: length.0 } + } #[cfg(feature = "fpvec_bounded_l2")] VdafObject::Prio3FixedPoint16BitBoundedL2VecSum { length } => { diff --git a/interop_binaries/tests/end_to_end.rs b/interop_binaries/tests/end_to_end.rs index 50cbf21fe..325db2025 100644 --- a/interop_binaries/tests/end_to_end.rs +++ b/interop_binaries/tests/end_to_end.rs @@ -642,19 +642,15 @@ async fn e2e_prio3_histogram() { QueryKind::TimeInterval, json!({ "type": "Prio3Histogram", - "buckets": ["0", "1", "10", "100", "1000", "10000", "100000"], + "length": "6", }), &[ + json!("0"), json!("1"), + json!("2"), + json!("3"), json!("4"), - json!("16"), - json!("64"), - json!("256"), - json!("1024"), - json!("4096"), - json!("16384"), - json!("65536"), - json!("262144"), + json!("5"), ], b"", ) diff --git a/messages/Cargo.toml b/messages/Cargo.toml index bfa535a07..b5888b5c7 100644 --- a/messages/Cargo.toml +++ b/messages/Cargo.toml @@ -20,7 +20,7 @@ hex = "0.4" num_enum = "0.7.0" # We can't pull prio in from the workspace because that would enable default features, and we do not # want prio/crypto-dependencies -prio = { version = "0.12.2", default-features = false } +prio = { version = "0.14.0", default-features = false } rand = "0.8" serde.workspace = true thiserror.workspace = true diff --git a/messages/src/taskprov.rs b/messages/src/taskprov.rs index 3d9a4f276..72036498c 100644 --- a/messages/src/taskprov.rs +++ b/messages/src/taskprov.rs @@ -9,10 +9,7 @@ use prio::codec::{ decode_u16_items, decode_u24_items, decode_u8_items, encode_u16_items, encode_u24_items, encode_u8_items, CodecError, Decode, Encode, }; -use std::{ - fmt::{self, Debug, Formatter}, - io::Cursor, -}; +use std::{fmt::Debug, io::Cursor}; /// Defines all parameters necessary to configure an aggregator with a new task. /// Provided by taskprov participants in all requests incident to task execution. @@ -325,8 +322,9 @@ pub enum VdafType { bits: u8, }, Prio3Histogram { - /// List of buckets. - #[derivative(Debug(format_with = "fmt_histogram"))] + /// Number of buckets in the histogram + // This may change as the taskprov draft adapts to VDAF-06 + // https://github.com/wangshan/draft-wang-ppm-dap-taskprov/issues/33 buckets: Vec, }, Poplar1 { @@ -342,10 +340,6 @@ impl VdafType { const POPLAR1: u32 = 0x00001000; } -fn fmt_histogram(buckets: &Vec, f: &mut Formatter) -> Result<(), fmt::Error> { - write!(f, "num_buckets: {}", buckets.len()) -} - impl Encode for VdafType { fn encode(&self, bytes: &mut Vec) { match self { diff --git a/tools/src/bin/collect.rs b/tools/src/bin/collect.rs index cdf6c5eb5..eaf8c57d8 100644 --- a/tools/src/bin/collect.rs +++ b/tools/src/bin/collect.rs @@ -203,41 +203,6 @@ impl TypedValueParser for PrivateKeyValueParser { } } -#[derive(Debug, Clone, PartialEq, Eq)] -struct Buckets(Vec); - -#[derive(Clone)] -struct BucketsValueParser { - inner: NonEmptyStringValueParser, -} - -impl BucketsValueParser { - fn new() -> BucketsValueParser { - BucketsValueParser { - inner: NonEmptyStringValueParser::new(), - } - } -} - -impl TypedValueParser for BucketsValueParser { - type Value = Buckets; - - fn parse_ref( - &self, - cmd: &clap::Command, - arg: Option<&clap::Arg>, - value: &std::ffi::OsStr, - ) -> Result { - let input = self.inner.parse_ref(cmd, arg, value)?; - input - .split(',') - .map(|chunk| chunk.trim().parse()) - .collect::, _>>() - .map(Buckets) - .map_err(|err| clap::Error::raw(ErrorKind::ValueValidation, err)) - } -} - #[derive(Derivative, Args, PartialEq, Eq)] #[derivative(Debug)] #[group(required = true)] @@ -355,22 +320,13 @@ struct Options { display_order = 0 )] vdaf: VdafType, - /// Number of vector elements, for use with --vdaf=countvec and --vdaf=sumvec + /// Number of vector elements, when used with --vdaf=countvec and --vdaf=sumvec or number of + /// histogram buckets, when used with --vdaf=histogram #[clap(long, help_heading = "VDAF Algorithm and Parameters")] length: Option, /// Bit length of measurements, for use with --vdaf=sum and --vdaf=sumvec #[clap(long, help_heading = "VDAF Algorithm and Parameters")] bits: Option, - /// Comma-separated list of bucket boundaries, for use with --vdaf=histogram - #[clap( - long, - required = false, - num_args = 1, - action = ArgAction::Set, - value_parser = BucketsValueParser::new(), - help_heading = "VDAF Algorithm and Parameters" - )] - buckets: Option, #[clap(flatten)] query: QueryOptions, @@ -448,41 +404,40 @@ where options.hpke_private_key.clone(), ); let http_client = default_http_client().map_err(|err| Error::Anyhow(err.into()))?; - match (options.vdaf, options.length, options.bits, options.buckets) { - (VdafType::Count, None, None, None) => { + match (options.vdaf, options.length, options.bits) { + (VdafType::Count, None, None) => { let vdaf = Prio3::new_count(2).map_err(|err| Error::Anyhow(err.into()))?; run_collection_generic(parameters, vdaf, http_client, query, &()) .await .map_err(|err| Error::Anyhow(err.into())) } - (VdafType::CountVec, Some(length), None, None) => { + (VdafType::CountVec, Some(length), None) => { let vdaf = Prio3::new_sum_vec(2, 1, length).map_err(|err| Error::Anyhow(err.into()))?; run_collection_generic(parameters, vdaf, http_client, query, &()) .await .map_err(|err| Error::Anyhow(err.into())) } - (VdafType::Sum, None, Some(bits), None) => { + (VdafType::Sum, None, Some(bits)) => { let vdaf = Prio3::new_sum(2, bits).map_err(|err| Error::Anyhow(err.into()))?; run_collection_generic(parameters, vdaf, http_client, query, &()) .await .map_err(|err| Error::Anyhow(err.into())) } - (VdafType::SumVec, Some(length), Some(bits), None) => { + (VdafType::SumVec, Some(length), Some(bits)) => { let vdaf = Prio3::new_sum_vec(2, bits, length).map_err(|err| Error::Anyhow(err.into()))?; run_collection_generic(parameters, vdaf, http_client, query, &()) .await .map_err(|err| Error::Anyhow(err.into())) } - (VdafType::Histogram, None, None, Some(ref buckets)) => { - let vdaf = - Prio3::new_histogram(2, &buckets.0).map_err(|err| Error::Anyhow(err.into()))?; + (VdafType::Histogram, Some(length), None) => { + let vdaf = Prio3::new_histogram(2, length).map_err(|err| Error::Anyhow(err.into()))?; run_collection_generic(parameters, vdaf, http_client, query, &()) .await .map_err(|err| Error::Anyhow(err.into())) } #[cfg(feature = "fpvec_bounded_l2")] - (VdafType::FixedPoint16BitBoundedL2VecSum, Some(length), None, None) => { + (VdafType::FixedPoint16BitBoundedL2VecSum, Some(length), None) => { let vdaf: Prio3FixedPointBoundedL2VecSumMultithreaded> = Prio3::new_fixedpoint_boundedl2_vec_sum_multithreaded(2, length) .map_err(|err| Error::Anyhow(err.into()))?; @@ -491,7 +446,7 @@ where .map_err(|err| Error::Anyhow(err.into())) } #[cfg(feature = "fpvec_bounded_l2")] - (VdafType::FixedPoint32BitBoundedL2VecSum, Some(length), None, None) => { + (VdafType::FixedPoint32BitBoundedL2VecSum, Some(length), None) => { let vdaf: Prio3FixedPointBoundedL2VecSumMultithreaded> = Prio3::new_fixedpoint_boundedl2_vec_sum_multithreaded(2, length) .map_err(|err| Error::Anyhow(err.into()))?; @@ -500,7 +455,7 @@ where .map_err(|err| Error::Anyhow(err.into())) } #[cfg(feature = "fpvec_bounded_l2")] - (VdafType::FixedPoint64BitBoundedL2VecSum, Some(length), None, None) => { + (VdafType::FixedPoint64BitBoundedL2VecSum, Some(length), None) => { let vdaf: Prio3FixedPointBoundedL2VecSumMultithreaded> = Prio3::new_fixedpoint_boundedl2_vec_sum_multithreaded(2, length) .map_err(|err| Error::Anyhow(err.into()))?; @@ -634,7 +589,6 @@ mod tests { vdaf: VdafType::Count, length: None, bits: None, - buckets: None, query: QueryOptions { batch_interval_start: Some(1_000_000), batch_interval_duration: Some(1_000), @@ -719,33 +673,6 @@ mod tests { "1000".to_string(), ]); - let mut bad_arguments = base_arguments.clone(); - bad_arguments.extend(["--vdaf=count".to_string(), "--buckets=1,2,3,4".to_string()]); - let bad_options = Options::try_parse_from(bad_arguments).unwrap(); - assert_matches!( - run(bad_options).await.unwrap_err(), - Error::Clap(err) => assert_eq!(err.kind(), ErrorKind::ArgumentConflict) - ); - - let mut bad_arguments = base_arguments.clone(); - bad_arguments.extend(["--vdaf=sum".to_string(), "--buckets=1,2,3,4".to_string()]); - let bad_options = Options::try_parse_from(bad_arguments).unwrap(); - assert_matches!( - run(bad_options).await.unwrap_err(), - Error::Clap(err) => assert_eq!(err.kind(), ErrorKind::ArgumentConflict) - ); - - let mut bad_arguments = base_arguments.clone(); - bad_arguments.extend([ - "--vdaf=countvec".to_string(), - "--buckets=1,2,3,4".to_string(), - ]); - let bad_options = Options::try_parse_from(bad_arguments).unwrap(); - assert_matches!( - run(bad_options).await.unwrap_err(), - Error::Clap(err) => assert_eq!(err.kind(), ErrorKind::ArgumentConflict) - ); - let mut bad_arguments = base_arguments.clone(); bad_arguments.extend(["--vdaf=countvec".to_string(), "--bits=3".to_string()]); let bad_options = Options::try_parse_from(bad_arguments).unwrap(); @@ -791,10 +718,7 @@ mod tests { } let mut bad_arguments = base_arguments.clone(); - bad_arguments.extend([ - "--vdaf=histogram".to_string(), - "--buckets=1,2,3,4,apple".to_string(), - ]); + bad_arguments.extend(["--vdaf=histogram".to_string(), "--length=apple".to_string()]); assert_eq!( Options::try_parse_from(bad_arguments).unwrap_err().kind(), ErrorKind::ValueValidation @@ -833,10 +757,7 @@ mod tests { Options::try_parse_from(good_arguments).unwrap(); let mut good_arguments = base_arguments.clone(); - good_arguments.extend([ - "--vdaf=histogram".to_string(), - "--buckets=1,2,3,4".to_string(), - ]); + good_arguments.extend(["--vdaf=histogram".to_string(), "--length=4".to_string()]); Options::try_parse_from(good_arguments).unwrap(); #[cfg(feature = "fpvec_bounded_l2")] @@ -889,7 +810,6 @@ mod tests { vdaf: VdafType::Count, length: None, bits: None, - buckets: None, query: QueryOptions { batch_interval_start: None, batch_interval_duration: None, @@ -929,7 +849,6 @@ mod tests { vdaf: VdafType::Count, length: None, bits: None, - buckets: None, query: QueryOptions { batch_interval_start: None, batch_interval_duration: None, diff --git a/tools/tests/cmd/collect.trycmd b/tools/tests/cmd/collect.trycmd index 18e911e4f..041e85c4c 100644 --- a/tools/tests/cmd/collect.trycmd +++ b/tools/tests/cmd/collect.trycmd @@ -49,14 +49,11 @@ VDAF Algorithm and Parameters: - histogram: Prio3Histogram --length - Number of vector elements, for use with --vdaf=countvec and --vdaf=sumvec + Number of vector elements, when used with --vdaf=countvec and --vdaf=sumvec or number of histogram buckets, when used with --vdaf=histogram --bits Bit length of measurements, for use with --vdaf=sum and --vdaf=sumvec - --buckets - Comma-separated list of bucket boundaries, for use with --vdaf=histogram - Collect Request Parameters (Time Interval): --batch-interval-start Start of the collection batch interval, as the number of seconds since the Unix epoch diff --git a/tools/tests/cmd/collect_fpvec_bounded_l2.trycmd b/tools/tests/cmd/collect_fpvec_bounded_l2.trycmd index 2f388204d..243c53644 100644 --- a/tools/tests/cmd/collect_fpvec_bounded_l2.trycmd +++ b/tools/tests/cmd/collect_fpvec_bounded_l2.trycmd @@ -33,7 +33,7 @@ Authorization: [env: DAP_AUTH_TOKEN=] --authorization-bearer-token - Authentication token for the "Authorization: Bearer ..." HTTP header, in base64 + Authentication token for the "Authorization: Bearer ..." HTTP header [env: AUTHORIZATION_BEARER_TOKEN=] @@ -52,14 +52,11 @@ VDAF Algorithm and Parameters: - fixedpoint64bitboundedl2vecsum: Prio3FixedPoint64BitBoundedL2VecSum --length - Number of vector elements, for use with --vdaf=countvec and --vdaf=sumvec + Number of vector elements, when used with --vdaf=countvec and --vdaf=sumvec or number of histogram buckets, when used with --vdaf=histogram --bits Bit length of measurements, for use with --vdaf=sum and --vdaf=sumvec - --buckets - Comma-separated list of bucket boundaries, for use with --vdaf=histogram - Collect Request Parameters (Time Interval): --batch-interval-start Start of the collection batch interval, as the number of seconds since the Unix epoch