From 7029084c7e1c9c04f4f6c5e280ae695c8dd37a22 Mon Sep 17 00:00:00 2001 From: Tim Geoghegan Date: Thu, 3 Aug 2023 18:14:42 -0700 Subject: [PATCH] Adopt `prio` 0.14.0 The bulk of the changes here deal with the change to the representation of `Prio3Histogram`. Since `prio` 0.14.x implements VDAF-06, taking this change will break compatibility with DAP-04. Part of #1669 --- Cargo.lock | 5 +- Cargo.toml | 2 +- aggregator/src/aggregator.rs | 2 +- .../src/aggregator/aggregation_job_creator.rs | 9 ++-- aggregator_core/src/datastore/tests.rs | 14 +---- collector/src/lib.rs | 6 +-- core/src/task.rs | 52 ++++++++----------- integration_tests/src/client.rs | 8 +-- integration_tests/src/divviup_api_client.rs | 8 +-- integration_tests/tests/common/mod.rs | 16 ++---- integration_tests/tests/divviup_ts.rs | 4 +- integration_tests/tests/in_cluster.rs | 3 +- integration_tests/tests/janus.rs | 4 +- .../src/bin/janus_interop_client.rs | 4 +- .../src/bin/janus_interop_collector.rs | 4 +- interop_binaries/src/lib.rs | 10 ++-- messages/Cargo.toml | 2 +- tools/src/bin/collect.rs | 43 ++------------- 18 files changed, 59 insertions(+), 137 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 78973618c..415c3fed3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3033,9 +3033,9 @@ checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" [[package]] name = "prio" -version = "0.12.2" +version = "0.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9028a8aba9ba6b647c6d6931c20473d1119079a68d9898c07a488c5180dccb58" +checksum = "e1139097c0aa90a7e476953f358c0cc25a627ede8ac0dd47e05594a37d665273" dependencies = [ "aes", "base64 0.21.2", @@ -3049,7 +3049,6 @@ dependencies = [ "rayon", "serde", "sha3", - "static_assertions", "subtle", "thiserror", ] diff --git a/Cargo.toml b/Cargo.toml index 75fb1e72b..2e8cdc0eb 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -42,7 +42,7 @@ janus_messages = { version = "0.5", path = "messages" } k8s-openapi = { version = "0.18.0", features = ["v1_24"] } # keep this version in sync with what is referenced by the indirect dependency via `kube` kube = { version = "0.82.2", default-features = false, features = ["client", "rustls-tls"] } opentelemetry = { version = "0.19", features = ["metrics"] } -prio = { version = "0.12.2", features = ["multithreaded"] } +prio = { version = "0.14.0", features = ["multithreaded"] } serde = { version = "1.0.183", features = ["derive"] } serde_json = "1.0.103" serde_test = "1.0.175" diff --git a/aggregator/src/aggregator.rs b/aggregator/src/aggregator.rs index 8dc356677..457a59362 100644 --- a/aggregator/src/aggregator.rs +++ b/aggregator/src/aggregator.rs @@ -788,7 +788,7 @@ impl TaskAggregator { } VdafInstance::Prio3Histogram { buckets } => { - let vdaf = Prio3::new_histogram(2, buckets)?; + let vdaf = Prio3::new_histogram(2, *buckets)?; let verify_key = task.primary_vdaf_verify_key()?; VdafOps::Prio3Histogram(Arc::new(vdaf), verify_key) } diff --git a/aggregator/src/aggregator/aggregation_job_creator.rs b/aggregator/src/aggregator/aggregation_job_creator.rs index 00014f979..7cfd6e0f2 100644 --- a/aggregator/src/aggregator/aggregation_job_creator.rs +++ b/aggregator/src/aggregator/aggregation_job_creator.rs @@ -291,7 +291,7 @@ impl AggregationJobCreator { } (task::QueryType::TimeInterval, VdafInstance::Prio3Histogram { buckets }) => { - let vdaf = Arc::new(Prio3::new_histogram(2, buckets)?); + let vdaf = Arc::new(Prio3::new_histogram(2, *buckets)?); self.create_aggregation_jobs_for_time_interval_task_no_param::(task, vdaf) .await } @@ -406,7 +406,7 @@ impl AggregationJobCreator { }, VdafInstance::Prio3Histogram { buckets }, ) => { - let vdaf = Arc::new(Prio3::new_histogram(2, buckets)?); + let vdaf = Arc::new(Prio3::new_histogram(2, *buckets)?); let max_batch_size = *max_batch_size; let batch_time_window_size = *batch_time_window_size; self.create_aggregation_jobs_for_fixed_size_task_no_param::< @@ -660,10 +660,7 @@ mod tests { }; use janus_core::{ task::{VdafInstance, PRIO3_VERIFY_KEY_LENGTH}, - test_util::{ - dummy_vdaf::{self}, - install_test_trace_subscriber, - }, + test_util::{dummy_vdaf, install_test_trace_subscriber}, time::{Clock, DurationExt, IntervalExt, MockClock, TimeExt}, }; use janus_messages::{ diff --git a/aggregator_core/src/datastore/tests.rs b/aggregator_core/src/datastore/tests.rs index 4378fdc18..53f0bc792 100644 --- a/aggregator_core/src/datastore/tests.rs +++ b/aggregator_core/src/datastore/tests.rs @@ -109,18 +109,8 @@ async fn roundtrip_task(ephemeral_datastore: EphemeralDatastore) { (VdafInstance::Prio3CountVec { length: 64 }, Role::Helper), (VdafInstance::Prio3Sum { bits: 64 }, Role::Helper), (VdafInstance::Prio3Sum { bits: 32 }, Role::Helper), - ( - VdafInstance::Prio3Histogram { - buckets: Vec::from([0, 100, 200, 400]), - }, - Role::Leader, - ), - ( - VdafInstance::Prio3Histogram { - buckets: Vec::from([0, 25, 50, 75, 100]), - }, - Role::Leader, - ), + (VdafInstance::Prio3Histogram { buckets: 4 }, Role::Leader), + (VdafInstance::Prio3Histogram { buckets: 5 }, Role::Leader), (VdafInstance::Poplar1 { bits: 8 }, Role::Helper), (VdafInstance::Poplar1 { bits: 64 }, Role::Helper), ] { diff --git a/collector/src/lib.rs b/collector/src/lib.rs index b80282e00..2b6ed7d1c 100644 --- a/collector/src/lib.rs +++ b/collector/src/lib.rs @@ -1001,8 +1001,8 @@ mod tests { async fn successful_collect_prio3_histogram() { install_test_trace_subscriber(); let mut server = mockito::Server::new_async().await; - let vdaf = Prio3::new_histogram(2, &[25, 50, 75, 100]).unwrap(); - let transcript = run_vdaf(&vdaf, &random(), &(), &random(), &80); + let vdaf = Prio3::new_histogram(2, 4).unwrap(); + let transcript = run_vdaf(&vdaf, &random(), &(), &random(), &3); let collector = setup_collector(&mut server, vdaf); let batch_interval = Interval::new( @@ -1058,7 +1058,7 @@ mod tests { ), chrono::Duration::seconds(3600), ), - Vec::from([0, 0, 0, 1, 0]) + Vec::from([0, 0, 0, 1]) ) ); diff --git a/core/src/task.rs b/core/src/task.rs index 07076163f..0228097ca 100644 --- a/core/src/task.rs +++ b/core/src/task.rs @@ -6,7 +6,7 @@ use rand::{distributions::Standard, prelude::Distribution}; use reqwest::Url; use ring::constant_time; use serde::{de::Error, Deserialize, Deserializer, Serialize}; -use std::{fmt, str}; +use std::str; /// HTTP header where auth tokens are provided in messages between participants. pub const DAP_AUTH_HEADER: &str = "DAP-Auth-Token"; @@ -31,10 +31,7 @@ pub enum VdafInstance { /// A vector of `Prio3` sums. Prio3SumVec { bits: usize, length: usize }, /// A `Prio3` histogram. - Prio3Histogram { - #[derivative(Debug(format_with = "bucket_count"))] - buckets: Vec, - }, + Prio3Histogram { buckets: usize }, /// A `Prio3` 16-bit fixed point vector sum with bounded L2 norm. #[cfg(feature = "fpvec_bounded_l2")] Prio3FixedPoint16BitBoundedL2VecSum { length: usize }, @@ -61,6 +58,22 @@ pub enum VdafInstance { FakeFailsPrepStep, } +impl VdafInstance { + /// Returns the expected length of a VDAF verification key for a VDAF of this type. + pub fn verify_key_length(&self) -> usize { + match self { + #[cfg(feature = "test-util")] + VdafInstance::Fake + | VdafInstance::FakeFailsPrepInit + | VdafInstance::FakeFailsPrepStep => 0, + + // All "real" VDAFs use a verify key of length 16 currently. (Poplar1 may not, but it's + // not yet done being specified, so choosing 16 bytes is fine for testing.) + _ => PRIO3_VERIFY_KEY_LENGTH, + } + } +} + impl TryFrom<&taskprov::VdafType> for VdafInstance { type Error = &'static str; @@ -85,22 +98,6 @@ fn bucket_count(buckets: &Vec, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "[{} buckets]", buckets.len() + 1) } -impl VdafInstance { - /// Returns the expected length of a VDAF verification key for a VDAF of this type. - pub fn verify_key_length(&self) -> usize { - match self { - #[cfg(feature = "test-util")] - VdafInstance::Fake - | VdafInstance::FakeFailsPrepInit - | VdafInstance::FakeFailsPrepStep => 0, - - // All "real" VDAFs use a verify key of length 16 currently. (Poplar1 may not, but it's - // not yet done being specified, so choosing 16 bytes is fine for testing.) - _ => PRIO3_VERIFY_KEY_LENGTH, - } - } -} - /// Internal implementation details of [`vdaf_dispatch`](crate::vdaf_dispatch). #[macro_export] macro_rules! vdaf_dispatch_impl_base { @@ -175,7 +172,7 @@ macro_rules! vdaf_dispatch_impl_base { } ::janus_core::task::VdafInstance::Prio3Histogram { buckets } => { - let $vdaf = ::prio::vdaf::prio3::Prio3::new_histogram(2, buckets)?; + let $vdaf = ::prio::vdaf::prio3::Prio3::new_histogram(2, *buckets)?; type $Vdaf = ::prio::vdaf::prio3::Prio3Histogram; const $VERIFY_KEY_LENGTH: usize = ::janus_core::task::PRIO3_VERIFY_KEY_LENGTH; $body @@ -768,9 +765,7 @@ mod tests { ], ); assert_tokens( - &VdafInstance::Prio3Histogram { - buckets: Vec::from([0, 100, 200, 400]), - }, + &VdafInstance::Prio3Histogram { buckets: 6 }, &[ Token::StructVariant { name: "VdafInstance", @@ -778,12 +773,7 @@ mod tests { len: 1, }, Token::Str("buckets"), - Token::Seq { len: Some(4) }, - Token::U64(0), - Token::U64(100), - Token::U64(200), - Token::U64(400), - Token::SeqEnd, + Token::U64(6), Token::StructVariantEnd, ], ); diff --git a/integration_tests/src/client.rs b/integration_tests/src/client.rs index 6b5fb45bd..3cd5fadfa 100644 --- a/integration_tests/src/client.rs +++ b/integration_tests/src/client.rs @@ -72,15 +72,9 @@ fn json_encode_vdaf(vdaf: &VdafInstance) -> Value { "length": format!("{length}"), }), VdafInstance::Prio3Histogram { buckets } => { - let buckets = Value::Array( - buckets - .iter() - .map(|value| Value::String(format!("{value}"))) - .collect(), - ); json!({ "type": "Prio3Histogram", - "buckets": buckets, + "buckets": format!("{buckets}"), }) } _ => panic!("VDAF {vdaf:?} is not yet supported"), diff --git a/integration_tests/src/divviup_api_client.rs b/integration_tests/src/divviup_api_client.rs index cdb323d80..b7d293239 100644 --- a/integration_tests/src/divviup_api_client.rs +++ b/integration_tests/src/divviup_api_client.rs @@ -21,7 +21,7 @@ pub enum ApiVdaf { /// Corresponds to Prio3Count Count, Histogram { - buckets: Vec, + buckets: usize, }, Sum { bits: usize, @@ -35,9 +35,9 @@ impl TryFrom<&VdafInstance> for ApiVdaf { match vdaf { VdafInstance::Prio3Count => Ok(ApiVdaf::Count), VdafInstance::Prio3Sum { bits } => Ok(ApiVdaf::Sum { bits: *bits }), - VdafInstance::Prio3Histogram { buckets } => Ok(ApiVdaf::Histogram { - buckets: buckets.clone(), - }), + VdafInstance::Prio3Histogram { buckets } => { + Ok(ApiVdaf::Histogram { buckets: *buckets }) + } _ => Err(anyhow!("unsupported VDAF: {vdaf:?}")), } } diff --git a/integration_tests/tests/common/mod.rs b/integration_tests/tests/common/mod.rs index 3542bef9e..5ee315d5e 100644 --- a/integration_tests/tests/common/mod.rs +++ b/integration_tests/tests/common/mod.rs @@ -323,20 +323,14 @@ pub async fn submit_measurements_and_verify_aggregate( .await; } VdafInstance::Prio3Histogram { buckets } => { - let vdaf = Prio3::new_histogram(2, buckets).unwrap(); + let vdaf = Prio3::new_histogram(2, *buckets).unwrap(); - let mut aggregate_result = vec![0; buckets.len() + 1]; - aggregate_result.resize(buckets.len() + 1, 0); + let mut aggregate_result = vec![0; *buckets]; + aggregate_result.resize(*buckets, 0); let measurements = iter::repeat_with(|| { - let choice = thread_rng().gen_range(0..=buckets.len()); + let choice = thread_rng().gen_range(0..*buckets); aggregate_result[choice] += 1; - let measurement = if choice == buckets.len() { - // This goes into the counter covering the range that extends to positive infinity. - buckets[buckets.len() - 1] + 1 - } else { - buckets[choice] - }; - measurement as u128 + choice }) .take(total_measurements) .collect::>(); diff --git a/integration_tests/tests/divviup_ts.rs b/integration_tests/tests/divviup_ts.rs index 3aecb6648..e7b78bcf3 100644 --- a/integration_tests/tests/divviup_ts.rs +++ b/integration_tests/tests/divviup_ts.rs @@ -55,9 +55,7 @@ async fn janus_divviup_ts_histogram() { run_divviup_ts_integration_test( &container_client(), - VdafInstance::Prio3Histogram { - buckets: Vec::from([1, 10, 100, 1000]), - }, + VdafInstance::Prio3Histogram { buckets: 4 }, ) .await; } diff --git a/integration_tests/tests/in_cluster.rs b/integration_tests/tests/in_cluster.rs index 9c4e35ab1..aa5976b48 100644 --- a/integration_tests/tests/in_cluster.rs +++ b/integration_tests/tests/in_cluster.rs @@ -256,9 +256,8 @@ async fn in_cluster_histogram() { install_test_trace_subscriber(); // Start port forwards and set up task. - let buckets = Vec::from([3, 6, 8]); let janus_pair = InClusterJanusPair::new( - VdafInstance::Prio3Histogram { buckets }, + VdafInstance::Prio3Histogram { buckets: 4 }, QueryType::TimeInterval, ) .await; diff --git a/integration_tests/tests/janus.rs b/integration_tests/tests/janus.rs index a7af9113d..5a41d0157 100644 --- a/integration_tests/tests/janus.rs +++ b/integration_tests/tests/janus.rs @@ -95,13 +95,11 @@ async fn janus_janus_sum_16() { async fn janus_janus_histogram_4_buckets() { install_test_trace_subscriber(); - let buckets = Vec::from([3, 6, 8]); - // Start servers. let container_client = container_client(); let janus_pair = JanusPair::new( &container_client, - VdafInstance::Prio3Histogram { buckets }, + VdafInstance::Prio3Histogram { buckets: 4 }, QueryType::TimeInterval, ) .await; diff --git a/interop_binaries/src/bin/janus_interop_client.rs b/interop_binaries/src/bin/janus_interop_client.rs index c22701163..21497be9d 100644 --- a/interop_binaries/src/bin/janus_interop_client.rs +++ b/interop_binaries/src/bin/janus_interop_client.rs @@ -168,8 +168,8 @@ async fn handle_upload( handle_upload_generic(http_client, vdaf_client, request, measurement).await?; } - VdafInstance::Prio3Histogram { ref buckets } => { - let measurement = parse_primitive_measurement::(request.measurement.clone())?; + VdafInstance::Prio3Histogram { buckets } => { + let measurement = parse_primitive_measurement::(request.measurement.clone())?; let vdaf_client = Prio3::new_histogram(2, buckets) .context("failed to construct Prio3Histogram VDAF")?; handle_upload_generic(http_client, vdaf_client, request, measurement).await?; diff --git a/interop_binaries/src/bin/janus_interop_collector.rs b/interop_binaries/src/bin/janus_interop_collector.rs index d25f674d5..db063f2f3 100644 --- a/interop_binaries/src/bin/janus_interop_collector.rs +++ b/interop_binaries/src/bin/janus_interop_collector.rs @@ -362,7 +362,7 @@ async fn handle_collection_start( } (ParsedQuery::TimeInterval(batch_interval), VdafInstance::Prio3Histogram { buckets }) => { - let vdaf = Prio3::new_histogram(2, &buckets) + let vdaf = Prio3::new_histogram(2, buckets) .context("failed to construct Prio3Histogram VDAF")?; handle_collect_generic( http_client, @@ -582,7 +582,7 @@ async fn handle_collection_start( } (ParsedQuery::FixedSize(fixed_size_query), VdafInstance::Prio3Histogram { buckets }) => { - let vdaf = Prio3::new_histogram(2, &buckets) + let vdaf = Prio3::new_histogram(2, buckets) .context("failed to construct Prio3Histogram VDAF")?; handle_collect_generic( http_client, diff --git a/interop_binaries/src/lib.rs b/interop_binaries/src/lib.rs index e7861934b..5b6e5ed45 100644 --- a/interop_binaries/src/lib.rs +++ b/interop_binaries/src/lib.rs @@ -117,7 +117,7 @@ pub enum VdafObject { length: NumberAsString, }, Prio3Histogram { - buckets: Vec>, + buckets: NumberAsString, }, #[cfg(feature = "fpvec_bounded_l2")] Prio3FixedPoint16BitBoundedL2VecSum { @@ -152,7 +152,7 @@ impl From for VdafObject { }, VdafInstance::Prio3Histogram { buckets } => VdafObject::Prio3Histogram { - buckets: buckets.iter().copied().map(NumberAsString).collect(), + buckets: NumberAsString(buckets), }, #[cfg(feature = "fpvec_bounded_l2")] @@ -196,9 +196,9 @@ impl From for VdafInstance { length: length.0, }, - VdafObject::Prio3Histogram { buckets } => VdafInstance::Prio3Histogram { - buckets: buckets.iter().map(|value| value.0).collect(), - }, + VdafObject::Prio3Histogram { buckets } => { + VdafInstance::Prio3Histogram { buckets: buckets.0 } + } #[cfg(feature = "fpvec_bounded_l2")] VdafObject::Prio3FixedPoint16BitBoundedL2VecSum { length } => { diff --git a/messages/Cargo.toml b/messages/Cargo.toml index bfa535a07..b5888b5c7 100644 --- a/messages/Cargo.toml +++ b/messages/Cargo.toml @@ -20,7 +20,7 @@ hex = "0.4" num_enum = "0.7.0" # We can't pull prio in from the workspace because that would enable default features, and we do not # want prio/crypto-dependencies -prio = { version = "0.12.2", default-features = false } +prio = { version = "0.14.0", default-features = false } rand = "0.8" serde.workspace = true thiserror.workspace = true diff --git a/tools/src/bin/collect.rs b/tools/src/bin/collect.rs index cdf6c5eb5..097588fe9 100644 --- a/tools/src/bin/collect.rs +++ b/tools/src/bin/collect.rs @@ -203,41 +203,6 @@ impl TypedValueParser for PrivateKeyValueParser { } } -#[derive(Debug, Clone, PartialEq, Eq)] -struct Buckets(Vec); - -#[derive(Clone)] -struct BucketsValueParser { - inner: NonEmptyStringValueParser, -} - -impl BucketsValueParser { - fn new() -> BucketsValueParser { - BucketsValueParser { - inner: NonEmptyStringValueParser::new(), - } - } -} - -impl TypedValueParser for BucketsValueParser { - type Value = Buckets; - - fn parse_ref( - &self, - cmd: &clap::Command, - arg: Option<&clap::Arg>, - value: &std::ffi::OsStr, - ) -> Result { - let input = self.inner.parse_ref(cmd, arg, value)?; - input - .split(',') - .map(|chunk| chunk.trim().parse()) - .collect::, _>>() - .map(Buckets) - .map_err(|err| clap::Error::raw(ErrorKind::ValueValidation, err)) - } -} - #[derive(Derivative, Args, PartialEq, Eq)] #[derivative(Debug)] #[group(required = true)] @@ -367,10 +332,9 @@ struct Options { required = false, num_args = 1, action = ArgAction::Set, - value_parser = BucketsValueParser::new(), help_heading = "VDAF Algorithm and Parameters" )] - buckets: Option, + buckets: Option, #[clap(flatten)] query: QueryOptions, @@ -474,9 +438,8 @@ where .await .map_err(|err| Error::Anyhow(err.into())) } - (VdafType::Histogram, None, None, Some(ref buckets)) => { - let vdaf = - Prio3::new_histogram(2, &buckets.0).map_err(|err| Error::Anyhow(err.into()))?; + (VdafType::Histogram, None, None, Some(buckets)) => { + let vdaf = Prio3::new_histogram(2, buckets).map_err(|err| Error::Anyhow(err.into()))?; run_collection_generic(parameters, vdaf, http_client, query, &()) .await .map_err(|err| Error::Anyhow(err.into()))