Skip to content

Commit

Permalink
Adopt prio 0.14.0
Browse files Browse the repository at this point in the history
The bulk of the changes here deal with the change to the representation
of `Prio3Histogram`. Since `prio` 0.14.x implements VDAF-06, taking this
change will break compatibility with DAP-04.

Part of #1669
  • Loading branch information
tgeoghegan committed Aug 14, 2023
1 parent 8b0f571 commit 7029084
Show file tree
Hide file tree
Showing 18 changed files with 59 additions and 137 deletions.
5 changes: 2 additions & 3 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ janus_messages = { version = "0.5", path = "messages" }
k8s-openapi = { version = "0.18.0", features = ["v1_24"] } # keep this version in sync with what is referenced by the indirect dependency via `kube`
kube = { version = "0.82.2", default-features = false, features = ["client", "rustls-tls"] }
opentelemetry = { version = "0.19", features = ["metrics"] }
prio = { version = "0.12.2", features = ["multithreaded"] }
prio = { version = "0.14.0", features = ["multithreaded"] }
serde = { version = "1.0.183", features = ["derive"] }
serde_json = "1.0.103"
serde_test = "1.0.175"
Expand Down
2 changes: 1 addition & 1 deletion aggregator/src/aggregator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -788,7 +788,7 @@ impl<C: Clock> TaskAggregator<C> {
}

VdafInstance::Prio3Histogram { buckets } => {
let vdaf = Prio3::new_histogram(2, buckets)?;
let vdaf = Prio3::new_histogram(2, *buckets)?;
let verify_key = task.primary_vdaf_verify_key()?;
VdafOps::Prio3Histogram(Arc::new(vdaf), verify_key)
}
Expand Down
9 changes: 3 additions & 6 deletions aggregator/src/aggregator/aggregation_job_creator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,7 @@ impl<C: Clock + 'static> AggregationJobCreator<C> {
}

(task::QueryType::TimeInterval, VdafInstance::Prio3Histogram { buckets }) => {
let vdaf = Arc::new(Prio3::new_histogram(2, buckets)?);
let vdaf = Arc::new(Prio3::new_histogram(2, *buckets)?);
self.create_aggregation_jobs_for_time_interval_task_no_param::<PRIO3_VERIFY_KEY_LENGTH, Prio3Histogram>(task, vdaf)
.await
}
Expand Down Expand Up @@ -406,7 +406,7 @@ impl<C: Clock + 'static> AggregationJobCreator<C> {
},
VdafInstance::Prio3Histogram { buckets },
) => {
let vdaf = Arc::new(Prio3::new_histogram(2, buckets)?);
let vdaf = Arc::new(Prio3::new_histogram(2, *buckets)?);
let max_batch_size = *max_batch_size;
let batch_time_window_size = *batch_time_window_size;
self.create_aggregation_jobs_for_fixed_size_task_no_param::<
Expand Down Expand Up @@ -660,10 +660,7 @@ mod tests {
};
use janus_core::{
task::{VdafInstance, PRIO3_VERIFY_KEY_LENGTH},
test_util::{
dummy_vdaf::{self},
install_test_trace_subscriber,
},
test_util::{dummy_vdaf, install_test_trace_subscriber},
time::{Clock, DurationExt, IntervalExt, MockClock, TimeExt},
};
use janus_messages::{
Expand Down
14 changes: 2 additions & 12 deletions aggregator_core/src/datastore/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,18 +109,8 @@ async fn roundtrip_task(ephemeral_datastore: EphemeralDatastore) {
(VdafInstance::Prio3CountVec { length: 64 }, Role::Helper),
(VdafInstance::Prio3Sum { bits: 64 }, Role::Helper),
(VdafInstance::Prio3Sum { bits: 32 }, Role::Helper),
(
VdafInstance::Prio3Histogram {
buckets: Vec::from([0, 100, 200, 400]),
},
Role::Leader,
),
(
VdafInstance::Prio3Histogram {
buckets: Vec::from([0, 25, 50, 75, 100]),
},
Role::Leader,
),
(VdafInstance::Prio3Histogram { buckets: 4 }, Role::Leader),
(VdafInstance::Prio3Histogram { buckets: 5 }, Role::Leader),
(VdafInstance::Poplar1 { bits: 8 }, Role::Helper),
(VdafInstance::Poplar1 { bits: 64 }, Role::Helper),
] {
Expand Down
6 changes: 3 additions & 3 deletions collector/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1001,8 +1001,8 @@ mod tests {
async fn successful_collect_prio3_histogram() {
install_test_trace_subscriber();
let mut server = mockito::Server::new_async().await;
let vdaf = Prio3::new_histogram(2, &[25, 50, 75, 100]).unwrap();
let transcript = run_vdaf(&vdaf, &random(), &(), &random(), &80);
let vdaf = Prio3::new_histogram(2, 4).unwrap();
let transcript = run_vdaf(&vdaf, &random(), &(), &random(), &3);
let collector = setup_collector(&mut server, vdaf);

let batch_interval = Interval::new(
Expand Down Expand Up @@ -1058,7 +1058,7 @@ mod tests {
),
chrono::Duration::seconds(3600),
),
Vec::from([0, 0, 0, 1, 0])
Vec::from([0, 0, 0, 1])
)
);

Expand Down
52 changes: 21 additions & 31 deletions core/src/task.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use rand::{distributions::Standard, prelude::Distribution};
use reqwest::Url;
use ring::constant_time;
use serde::{de::Error, Deserialize, Deserializer, Serialize};
use std::{fmt, str};
use std::str;

/// HTTP header where auth tokens are provided in messages between participants.
pub const DAP_AUTH_HEADER: &str = "DAP-Auth-Token";
Expand All @@ -31,10 +31,7 @@ pub enum VdafInstance {
/// A vector of `Prio3` sums.
Prio3SumVec { bits: usize, length: usize },
/// A `Prio3` histogram.
Prio3Histogram {
#[derivative(Debug(format_with = "bucket_count"))]
buckets: Vec<u64>,
},
Prio3Histogram { buckets: usize },
/// A `Prio3` 16-bit fixed point vector sum with bounded L2 norm.
#[cfg(feature = "fpvec_bounded_l2")]
Prio3FixedPoint16BitBoundedL2VecSum { length: usize },
Expand All @@ -61,6 +58,22 @@ pub enum VdafInstance {
FakeFailsPrepStep,
}

impl VdafInstance {
/// Returns the expected length of a VDAF verification key for a VDAF of this type.
pub fn verify_key_length(&self) -> usize {
match self {
#[cfg(feature = "test-util")]
VdafInstance::Fake
| VdafInstance::FakeFailsPrepInit
| VdafInstance::FakeFailsPrepStep => 0,

// All "real" VDAFs use a verify key of length 16 currently. (Poplar1 may not, but it's
// not yet done being specified, so choosing 16 bytes is fine for testing.)
_ => PRIO3_VERIFY_KEY_LENGTH,
}
}
}

impl TryFrom<&taskprov::VdafType> for VdafInstance {
type Error = &'static str;

Expand All @@ -85,22 +98,6 @@ fn bucket_count(buckets: &Vec<u64>, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "[{} buckets]", buckets.len() + 1)
}

impl VdafInstance {
/// Returns the expected length of a VDAF verification key for a VDAF of this type.
pub fn verify_key_length(&self) -> usize {
match self {
#[cfg(feature = "test-util")]
VdafInstance::Fake
| VdafInstance::FakeFailsPrepInit
| VdafInstance::FakeFailsPrepStep => 0,

// All "real" VDAFs use a verify key of length 16 currently. (Poplar1 may not, but it's
// not yet done being specified, so choosing 16 bytes is fine for testing.)
_ => PRIO3_VERIFY_KEY_LENGTH,
}
}
}

/// Internal implementation details of [`vdaf_dispatch`](crate::vdaf_dispatch).
#[macro_export]
macro_rules! vdaf_dispatch_impl_base {
Expand Down Expand Up @@ -175,7 +172,7 @@ macro_rules! vdaf_dispatch_impl_base {
}

::janus_core::task::VdafInstance::Prio3Histogram { buckets } => {
let $vdaf = ::prio::vdaf::prio3::Prio3::new_histogram(2, buckets)?;
let $vdaf = ::prio::vdaf::prio3::Prio3::new_histogram(2, *buckets)?;
type $Vdaf = ::prio::vdaf::prio3::Prio3Histogram;
const $VERIFY_KEY_LENGTH: usize = ::janus_core::task::PRIO3_VERIFY_KEY_LENGTH;
$body
Expand Down Expand Up @@ -768,22 +765,15 @@ mod tests {
],
);
assert_tokens(
&VdafInstance::Prio3Histogram {
buckets: Vec::from([0, 100, 200, 400]),
},
&VdafInstance::Prio3Histogram { buckets: 6 },
&[
Token::StructVariant {
name: "VdafInstance",
variant: "Prio3Histogram",
len: 1,
},
Token::Str("buckets"),
Token::Seq { len: Some(4) },
Token::U64(0),
Token::U64(100),
Token::U64(200),
Token::U64(400),
Token::SeqEnd,
Token::U64(6),
Token::StructVariantEnd,
],
);
Expand Down
8 changes: 1 addition & 7 deletions integration_tests/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,15 +72,9 @@ fn json_encode_vdaf(vdaf: &VdafInstance) -> Value {
"length": format!("{length}"),
}),
VdafInstance::Prio3Histogram { buckets } => {
let buckets = Value::Array(
buckets
.iter()
.map(|value| Value::String(format!("{value}")))
.collect(),
);
json!({
"type": "Prio3Histogram",
"buckets": buckets,
"buckets": format!("{buckets}"),
})
}
_ => panic!("VDAF {vdaf:?} is not yet supported"),
Expand Down
8 changes: 4 additions & 4 deletions integration_tests/src/divviup_api_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ pub enum ApiVdaf {
/// Corresponds to Prio3Count
Count,
Histogram {
buckets: Vec<u64>,
buckets: usize,
},
Sum {
bits: usize,
Expand All @@ -35,9 +35,9 @@ impl TryFrom<&VdafInstance> for ApiVdaf {
match vdaf {
VdafInstance::Prio3Count => Ok(ApiVdaf::Count),
VdafInstance::Prio3Sum { bits } => Ok(ApiVdaf::Sum { bits: *bits }),
VdafInstance::Prio3Histogram { buckets } => Ok(ApiVdaf::Histogram {
buckets: buckets.clone(),
}),
VdafInstance::Prio3Histogram { buckets } => {
Ok(ApiVdaf::Histogram { buckets: *buckets })
}
_ => Err(anyhow!("unsupported VDAF: {vdaf:?}")),
}
}
Expand Down
16 changes: 5 additions & 11 deletions integration_tests/tests/common/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -323,20 +323,14 @@ pub async fn submit_measurements_and_verify_aggregate(
.await;
}
VdafInstance::Prio3Histogram { buckets } => {
let vdaf = Prio3::new_histogram(2, buckets).unwrap();
let vdaf = Prio3::new_histogram(2, *buckets).unwrap();

let mut aggregate_result = vec![0; buckets.len() + 1];
aggregate_result.resize(buckets.len() + 1, 0);
let mut aggregate_result = vec![0; *buckets];
aggregate_result.resize(*buckets, 0);
let measurements = iter::repeat_with(|| {
let choice = thread_rng().gen_range(0..=buckets.len());
let choice = thread_rng().gen_range(0..*buckets);
aggregate_result[choice] += 1;
let measurement = if choice == buckets.len() {
// This goes into the counter covering the range that extends to positive infinity.
buckets[buckets.len() - 1] + 1
} else {
buckets[choice]
};
measurement as u128
choice
})
.take(total_measurements)
.collect::<Vec<_>>();
Expand Down
4 changes: 1 addition & 3 deletions integration_tests/tests/divviup_ts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,7 @@ async fn janus_divviup_ts_histogram() {

run_divviup_ts_integration_test(
&container_client(),
VdafInstance::Prio3Histogram {
buckets: Vec::from([1, 10, 100, 1000]),
},
VdafInstance::Prio3Histogram { buckets: 4 },
)
.await;
}
Expand Down
3 changes: 1 addition & 2 deletions integration_tests/tests/in_cluster.rs
Original file line number Diff line number Diff line change
Expand Up @@ -256,9 +256,8 @@ async fn in_cluster_histogram() {
install_test_trace_subscriber();

// Start port forwards and set up task.
let buckets = Vec::from([3, 6, 8]);
let janus_pair = InClusterJanusPair::new(
VdafInstance::Prio3Histogram { buckets },
VdafInstance::Prio3Histogram { buckets: 4 },
QueryType::TimeInterval,
)
.await;
Expand Down
4 changes: 1 addition & 3 deletions integration_tests/tests/janus.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,13 +95,11 @@ async fn janus_janus_sum_16() {
async fn janus_janus_histogram_4_buckets() {
install_test_trace_subscriber();

let buckets = Vec::from([3, 6, 8]);

// Start servers.
let container_client = container_client();
let janus_pair = JanusPair::new(
&container_client,
VdafInstance::Prio3Histogram { buckets },
VdafInstance::Prio3Histogram { buckets: 4 },
QueryType::TimeInterval,
)
.await;
Expand Down
4 changes: 2 additions & 2 deletions interop_binaries/src/bin/janus_interop_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -168,8 +168,8 @@ async fn handle_upload(
handle_upload_generic(http_client, vdaf_client, request, measurement).await?;
}

VdafInstance::Prio3Histogram { ref buckets } => {
let measurement = parse_primitive_measurement::<u128>(request.measurement.clone())?;
VdafInstance::Prio3Histogram { buckets } => {
let measurement = parse_primitive_measurement::<usize>(request.measurement.clone())?;
let vdaf_client = Prio3::new_histogram(2, buckets)
.context("failed to construct Prio3Histogram VDAF")?;
handle_upload_generic(http_client, vdaf_client, request, measurement).await?;
Expand Down
4 changes: 2 additions & 2 deletions interop_binaries/src/bin/janus_interop_collector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,7 @@ async fn handle_collection_start(
}

(ParsedQuery::TimeInterval(batch_interval), VdafInstance::Prio3Histogram { buckets }) => {
let vdaf = Prio3::new_histogram(2, &buckets)
let vdaf = Prio3::new_histogram(2, buckets)
.context("failed to construct Prio3Histogram VDAF")?;
handle_collect_generic(
http_client,
Expand Down Expand Up @@ -582,7 +582,7 @@ async fn handle_collection_start(
}

(ParsedQuery::FixedSize(fixed_size_query), VdafInstance::Prio3Histogram { buckets }) => {
let vdaf = Prio3::new_histogram(2, &buckets)
let vdaf = Prio3::new_histogram(2, buckets)
.context("failed to construct Prio3Histogram VDAF")?;
handle_collect_generic(
http_client,
Expand Down
10 changes: 5 additions & 5 deletions interop_binaries/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ pub enum VdafObject {
length: NumberAsString<usize>,
},
Prio3Histogram {
buckets: Vec<NumberAsString<u64>>,
buckets: NumberAsString<usize>,
},
#[cfg(feature = "fpvec_bounded_l2")]
Prio3FixedPoint16BitBoundedL2VecSum {
Expand Down Expand Up @@ -152,7 +152,7 @@ impl From<VdafInstance> for VdafObject {
},

VdafInstance::Prio3Histogram { buckets } => VdafObject::Prio3Histogram {
buckets: buckets.iter().copied().map(NumberAsString).collect(),
buckets: NumberAsString(buckets),
},

#[cfg(feature = "fpvec_bounded_l2")]
Expand Down Expand Up @@ -196,9 +196,9 @@ impl From<VdafObject> for VdafInstance {
length: length.0,
},

VdafObject::Prio3Histogram { buckets } => VdafInstance::Prio3Histogram {
buckets: buckets.iter().map(|value| value.0).collect(),
},
VdafObject::Prio3Histogram { buckets } => {
VdafInstance::Prio3Histogram { buckets: buckets.0 }
}

#[cfg(feature = "fpvec_bounded_l2")]
VdafObject::Prio3FixedPoint16BitBoundedL2VecSum { length } => {
Expand Down
2 changes: 1 addition & 1 deletion messages/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ hex = "0.4"
num_enum = "0.7.0"
# We can't pull prio in from the workspace because that would enable default features, and we do not
# want prio/crypto-dependencies
prio = { version = "0.12.2", default-features = false }
prio = { version = "0.14.0", default-features = false }
rand = "0.8"
serde.workspace = true
thiserror.workspace = true
Expand Down
Loading

0 comments on commit 7029084

Please sign in to comment.