diff --git a/Cargo.lock b/Cargo.lock index da22121e4..2f59ae56a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -554,17 +554,6 @@ version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2da6da31387c7e4ef160ffab6d5e7f00c42626fe39aea70a7b0f1773f7dd6c1b" -[[package]] -name = "cmac" -version = "0.7.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8543454e3c3f5126effff9cd44d562af4e31fb8ce1cc0d3dcd8f084515dbc1aa" -dependencies = [ - "cipher", - "dbl", - "digest 0.10.7", -] - [[package]] name = "colorchoice" version = "1.0.0" @@ -830,15 +819,6 @@ dependencies = [ "syn 1.0.109", ] -[[package]] -name = "dbl" -version = "0.3.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bd2735a791158376708f9347fe8faba9667589d82427ef3aed6794a8981de3d9" -dependencies = [ - "generic-array", -] - [[package]] name = "deadpool" version = "0.9.5" @@ -1083,9 +1063,9 @@ dependencies = [ [[package]] name = "fiat-crypto" -version = "0.1.20" +version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e825f6987101665dea6ec934c09ec6d721de7bc1bf92248e1d5810c8cd636b77" +checksum = "d0870c84016d4b481be5c9f323c24f65e31e901ae618f0e80f4308fb00de1d2d" [[package]] name = "filetime" @@ -2458,6 +2438,19 @@ dependencies = [ "winapi", ] +[[package]] +name = "num-bigint" +version = "0.4.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "608e7659b5c3d7cba262d894801b9ec9d00de989e8a82bd4bef91d08da45cdc0" +dependencies = [ + "autocfg", + "num-integer", + "num-traits", + "rand", + "serde", +] + [[package]] name = "num-bigint-dig" version = "0.8.4" @@ -2496,6 +2489,18 @@ dependencies = [ "num-traits", ] +[[package]] +name = "num-rational" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0638a1c9d0a3c0914158145bc76cff373a75a627e6ecbfb71cbe6f453a5a19b0" +dependencies = [ + "autocfg", + "num-bigint", + "num-integer", + "num-traits", +] + [[package]] name = "num-traits" version = "0.2.16" @@ -2932,19 +2937,23 @@ checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" [[package]] name = "prio" -version = "0.14.1" +version = "0.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e4a65c4a557b2fecb8518c105aafadf33a86d7513a3f599bcfe542c17553cc61" +checksum = "fe7591b152d20a8a992f8b3a5daf6bc9e38e7fb347e3694ed9238eddc7e57332" dependencies = [ "aes", - "base64 0.21.4", "bitvec", "byteorder", - "cmac", "ctr", "fiat-crypto", "fixed", "getrandom", + "num-bigint", + "num-integer", + "num-iter", + "num-rational", + "num-traits", + "rand", "rand_core 0.6.4", "rayon", "serde", diff --git a/Cargo.toml b/Cargo.toml index e6e629db2..032738078 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -42,7 +42,7 @@ janus_messages = { version = "0.6", path = "messages" } k8s-openapi = { version = "0.18.0", features = ["v1_24"] } # keep this version in sync with what is referenced by the indirect dependency via `kube` kube = { version = "0.82.2", default-features = false, features = ["client", "rustls-tls"] } opentelemetry = { version = "0.20", features = ["metrics"] } -prio = { version = "0.14.1", features = ["multithreaded"] } +prio = { version = "0.15.0", features = ["multithreaded"] } serde = { version = "1.0.188", features = ["derive"] } serde_json = "1.0.106" serde_test = "1.0.175" diff --git a/aggregator/src/aggregator.rs b/aggregator/src/aggregator.rs index e639697f3..90a64ad92 100644 --- a/aggregator/src/aggregator.rs +++ b/aggregator/src/aggregator.rs @@ -68,8 +68,8 @@ use prio::{ vdaf::{ self, poplar1::Poplar1, - prg::PrgSha3, prio3::{Prio3, Prio3Count, Prio3Histogram, Prio3Sum, Prio3SumVecMultithreaded}, + xof::XofShake128, }, }; use reqwest::Client; @@ -795,7 +795,12 @@ impl TaskAggregator { } VdafInstance::Prio3CountVec { length } => { - let vdaf = Prio3::new_sum_vec_multithreaded(2, 1, *length)?; + let vdaf = Prio3::new_sum_vec_multithreaded( + 2, + 1, + *length, + VdafInstance::chunk_size(*length), + )?; let verify_key = task.primary_vdaf_verify_key()?; VdafOps::Prio3CountVec(Arc::new(vdaf), verify_key) } @@ -807,13 +812,18 @@ impl TaskAggregator { } VdafInstance::Prio3SumVec { bits, length } => { - let vdaf = Prio3::new_sum_vec_multithreaded(2, *bits, *length)?; + let vdaf = Prio3::new_sum_vec_multithreaded( + 2, + *bits, + *length, + VdafInstance::chunk_size(*bits * *length), + )?; let verify_key = task.primary_vdaf_verify_key()?; VdafOps::Prio3SumVec(Arc::new(vdaf), verify_key) } VdafInstance::Prio3Histogram { length } => { - let vdaf = Prio3::new_histogram(2, *length)?; + let vdaf = Prio3::new_histogram(2, *length, VdafInstance::chunk_size(*length))?; let verify_key = task.primary_vdaf_verify_key()?; VdafOps::Prio3Histogram(Arc::new(vdaf), verify_key) } @@ -843,7 +853,7 @@ impl TaskAggregator { } VdafInstance::Poplar1 { bits } => { - let vdaf = Poplar1::new_sha3(*bits); + let vdaf = Poplar1::new_shake128(*bits); let verify_key = task.primary_vdaf_verify_key()?; VdafOps::Poplar1(Arc::new(vdaf), verify_key) } @@ -1038,7 +1048,7 @@ enum VdafOps { Arc>>, VerifyKey, ), - Poplar1(Arc>, VerifyKey), + Poplar1(Arc>, VerifyKey), #[cfg(feature = "test-util")] Fake(Arc), @@ -1125,7 +1135,7 @@ macro_rules! vdaf_ops_dispatch { crate::aggregator::VdafOps::Poplar1(vdaf, verify_key) => { let $vdaf = vdaf; let $verify_key = verify_key; - type $Vdaf = ::prio::vdaf::poplar1::Poplar1<::prio::vdaf::prg::PrgSha3, 16>; + type $Vdaf = ::prio::vdaf::poplar1::Poplar1<::prio::vdaf::xof::XofShake128, 16>; const $VERIFY_KEY_LENGTH: usize = ::janus_core::task::VERIFY_KEY_LENGTH; $body } diff --git a/aggregator/src/aggregator/aggregation_job_continue.rs b/aggregator/src/aggregator/aggregation_job_continue.rs index 0ef675912..31bbb066b 100644 --- a/aggregator/src/aggregator/aggregation_job_continue.rs +++ b/aggregator/src/aggregator/aggregation_job_continue.rs @@ -136,7 +136,7 @@ impl VdafOps { // Compute the next transition. let prepare_step_res = trace_span!("VDAF preparation") - .in_scope(|| vdaf.prepare_step(prep_state.clone(), prep_msg)); + .in_scope(|| vdaf.prepare_next(prep_state.clone(), prep_msg)); match prepare_step_res { Ok(PrepareTransition::Continue(prep_state, prep_share)) => { *report_aggregation = report_aggregation diff --git a/aggregator/src/aggregator/aggregation_job_creator.rs b/aggregator/src/aggregator/aggregation_job_creator.rs index b370c80f6..fe0fbb2db 100644 --- a/aggregator/src/aggregator/aggregation_job_creator.rs +++ b/aggregator/src/aggregator/aggregation_job_creator.rs @@ -269,7 +269,12 @@ impl AggregationJobCreator { } (task::QueryType::TimeInterval, VdafInstance::Prio3CountVec { length }) => { - let vdaf = Arc::new(Prio3::new_sum_vec_multithreaded(2, 1, *length)?); + let vdaf = Arc::new(Prio3::new_sum_vec_multithreaded( + 2, + 1, + *length, + VdafInstance::chunk_size(*length), + )?); self.create_aggregation_jobs_for_time_interval_task_no_param::< VERIFY_KEY_LENGTH, Prio3SumVecMultithreaded @@ -283,13 +288,22 @@ impl AggregationJobCreator { } (task::QueryType::TimeInterval, VdafInstance::Prio3SumVec { bits, length }) => { - let vdaf = Arc::new(Prio3::new_sum_vec_multithreaded(2, *bits, *length)?); + let vdaf = Arc::new(Prio3::new_sum_vec_multithreaded( + 2, + *bits, + *length, + VdafInstance::chunk_size(*bits * *length), + )?); self.create_aggregation_jobs_for_time_interval_task_no_param::(task, vdaf) .await } (task::QueryType::TimeInterval, VdafInstance::Prio3Histogram { length }) => { - let vdaf = Arc::new(Prio3::new_histogram(2, *length)?); + let vdaf = Arc::new(Prio3::new_histogram( + 2, + *length, + VdafInstance::chunk_size(*length), + )?); self.create_aggregation_jobs_for_time_interval_task_no_param::(task, vdaf) .await } @@ -356,7 +370,12 @@ impl AggregationJobCreator { }, VdafInstance::Prio3CountVec { length }, ) => { - let vdaf = Arc::new(Prio3::new_sum_vec_multithreaded(2, 1, *length)?); + let vdaf = Arc::new(Prio3::new_sum_vec_multithreaded( + 2, + 1, + *length, + VdafInstance::chunk_size(*length), + )?); let max_batch_size = *max_batch_size; let batch_time_window_size = *batch_time_window_size; self.create_aggregation_jobs_for_fixed_size_task_no_param::< @@ -388,7 +407,12 @@ impl AggregationJobCreator { }, VdafInstance::Prio3SumVec { bits, length }, ) => { - let vdaf = Arc::new(Prio3::new_sum_vec_multithreaded(2, *bits, *length)?); + let vdaf = Arc::new(Prio3::new_sum_vec_multithreaded( + 2, + *bits, + *length, + VdafInstance::chunk_size(*bits * *length), + )?); let max_batch_size = *max_batch_size; let batch_time_window_size = *batch_time_window_size; self.create_aggregation_jobs_for_fixed_size_task_no_param::< @@ -404,7 +428,11 @@ impl AggregationJobCreator { }, VdafInstance::Prio3Histogram { length }, ) => { - let vdaf = Arc::new(Prio3::new_histogram(2, *length)?); + let vdaf = Arc::new(Prio3::new_histogram( + 2, + *length, + VdafInstance::chunk_size(*length), + )?); let max_batch_size = *max_batch_size; let batch_time_window_size = *batch_time_window_size; self.create_aggregation_jobs_for_fixed_size_task_no_param::< diff --git a/aggregator/src/aggregator/aggregation_job_driver.rs b/aggregator/src/aggregator/aggregation_job_driver.rs index e91896239..93bdbc056 100644 --- a/aggregator/src/aggregator/aggregation_job_driver.rs +++ b/aggregator/src/aggregator/aggregation_job_driver.rs @@ -469,7 +469,7 @@ impl AggregationJobDriver { // Step our own state. let prepare_step_res = trace_span!("VDAF preparation") - .in_scope(|| vdaf.prepare_step(prep_state.clone(), prep_msg.clone())); + .in_scope(|| vdaf.prepare_next(prep_state.clone(), prep_msg.clone())); let leader_transition = match prepare_step_res { Ok(leader_transition) => leader_transition, Err(error) => { @@ -589,11 +589,14 @@ impl AggregationJobDriver { A::PrepareShare::get_decoded_with_param(&leader_prep_state, payload) .context("couldn't decode helper's prepare message"); let prep_msg = helper_prep_share.and_then(|helper_prep_share| { - vdaf.prepare_preprocess([leader_prep_share.clone(), helper_prep_share]) - .context( - "couldn't preprocess leader & helper prepare shares into \ + vdaf.prepare_shares_to_prepare_message( + aggregation_job.aggregation_parameter(), + [leader_prep_share.clone(), helper_prep_share], + ) + .context( + "couldn't preprocess leader & helper prepare shares into \ prepare message", - ) + ) }); match prep_msg { Ok(prep_msg) => { diff --git a/aggregator/src/aggregator/taskprov_tests.rs b/aggregator/src/aggregator/taskprov_tests.rs index 3db8c8b59..511d5db48 100644 --- a/aggregator/src/aggregator/taskprov_tests.rs +++ b/aggregator/src/aggregator/taskprov_tests.rs @@ -51,8 +51,8 @@ use prio::{ field::Field64, flp::types::Count, vdaf::{ - prg::PrgSha3, prio3::{Prio3, Prio3Count}, + xof::XofShake128, AggregateShare, OutputShare, }, }; @@ -66,7 +66,7 @@ use trillium_testing::{ prelude::{post, put}, }; -type TestVdaf = Prio3, PrgSha3, 16>; +type TestVdaf = Prio3, XofShake128, 16>; pub struct TaskprovTestCase { _ephemeral_datastore: EphemeralDatastore, diff --git a/collector/src/lib.rs b/collector/src/lib.rs index 1eaa50606..e5ea026ce 100644 --- a/collector/src/lib.rs +++ b/collector/src/lib.rs @@ -955,7 +955,7 @@ mod tests { async fn successful_collect_prio3_histogram() { install_test_trace_subscriber(); let mut server = mockito::Server::new_async().await; - let vdaf = Prio3::new_histogram(2, 4).unwrap(); + let vdaf = Prio3::new_histogram(2, 4, 2).unwrap(); let transcript = run_vdaf(&vdaf, &random(), &(), &random(), &3); let collector = setup_collector(&mut server, vdaf); diff --git a/core/src/task.rs b/core/src/task.rs index 80e280e63..2adaca84f 100644 --- a/core/src/task.rs +++ b/core/src/task.rs @@ -71,6 +71,19 @@ impl VdafInstance { _ => VERIFY_KEY_LENGTH, } } + + /// Returns a suboptimal estimate of the chunk size to use in ParallelSum gadgets. See [VDAF] + /// for discussion of chunk size. + /// + /// # Bugs + /// + /// Janus should allow chunk size to be configured ([#1900][issue]). + /// + /// [VDAF]: https://datatracker.ietf.org/doc/html/draft-irtf-cfrg-vdaf-07#name-selection-of-parallelsum-ch + /// [issue]: https://github.com/divviup/janus/issues/1900 + pub fn chunk_size(measurement_length: usize) -> usize { + (measurement_length as f64).sqrt().floor() as usize + } } impl TryFrom<&taskprov::VdafType> for VdafInstance { @@ -154,7 +167,12 @@ macro_rules! vdaf_dispatch_impl_base { ::janus_core::task::VdafInstance::Prio3CountVec { length } => { // Prio3CountVec is implemented as a 1-bit sum vec - let $vdaf = ::prio::vdaf::prio3::Prio3::new_sum_vec_multithreaded(2, 1, *length)?; + let $vdaf = ::prio::vdaf::prio3::Prio3::new_sum_vec_multithreaded( + 2, + 1, + *length, + janus_core::task::VdafInstance::chunk_size(*length), + )?; type $Vdaf = ::prio::vdaf::prio3::Prio3SumVecMultithreaded; const $VERIFY_KEY_LEN: usize = ::janus_core::task::VERIFY_KEY_LENGTH; $body @@ -168,23 +186,31 @@ macro_rules! vdaf_dispatch_impl_base { } ::janus_core::task::VdafInstance::Prio3SumVec { bits, length } => { - let $vdaf = - ::prio::vdaf::prio3::Prio3::new_sum_vec_multithreaded(2, *bits, *length)?; + let $vdaf = ::prio::vdaf::prio3::Prio3::new_sum_vec_multithreaded( + 2, + *bits, + *length, + janus_core::task::VdafInstance::chunk_size(*bits * *length), + )?; type $Vdaf = ::prio::vdaf::prio3::Prio3SumVecMultithreaded; const $VERIFY_KEY_LEN: usize = ::janus_core::task::VERIFY_KEY_LENGTH; $body } ::janus_core::task::VdafInstance::Prio3Histogram { length } => { - let $vdaf = ::prio::vdaf::prio3::Prio3::new_histogram(2, *length)?; + let $vdaf = ::prio::vdaf::prio3::Prio3::new_histogram( + 2, + *length, + janus_core::task::VdafInstance::chunk_size(*length), + )?; type $Vdaf = ::prio::vdaf::prio3::Prio3Histogram; const $VERIFY_KEY_LEN: usize = ::janus_core::task::VERIFY_KEY_LENGTH; $body } ::janus_core::task::VdafInstance::Poplar1 { bits } => { - let $vdaf = ::prio::vdaf::poplar1::Poplar1::new_sha3(*bits); - type $Vdaf = ::prio::vdaf::poplar1::Poplar1<::prio::vdaf::prg::PrgSha3, 16>; + let $vdaf = ::prio::vdaf::poplar1::Poplar1::new_shake128(*bits); + type $Vdaf = ::prio::vdaf::poplar1::Poplar1<::prio::vdaf::xof::XofShake128, 16>; const $VERIFY_KEY_LEN: usize = ::janus_core::task::VERIFY_KEY_LENGTH; $body } diff --git a/core/src/test_util/dummy_vdaf.rs b/core/src/test_util/dummy_vdaf.rs index 7d1583112..dd569fb02 100644 --- a/core/src/test_util/dummy_vdaf.rs +++ b/core/src/test_util/dummy_vdaf.rs @@ -111,14 +111,15 @@ impl vdaf::Aggregator<0, 16> for Vdaf { Ok((PrepareState(input_share.0), ())) } - fn prepare_preprocess>( + fn prepare_shares_to_prepare_message>( &self, + _: &Self::AggregationParam, _: M, ) -> Result { Ok(()) } - fn prepare_step( + fn prepare_next( &self, _: Self::PrepareState, _: Self::PrepareMessage, diff --git a/core/src/test_util/mod.rs b/core/src/test_util/mod.rs index 9f5dbd517..44343761f 100644 --- a/core/src/test_util/mod.rs +++ b/core/src/test_util/mod.rs @@ -116,7 +116,9 @@ pub fn run_vdaf + vda aggregate_shares: agg_shares, }; } - let prep_msg = vdaf.prepare_preprocess(prep_shares).unwrap(); + let prep_msg = vdaf + .prepare_shares_to_prepare_message(aggregation_param, prep_shares) + .unwrap(); prep_msgs.push(prep_msg.clone()); // Compute each participant's next transition. @@ -126,7 +128,7 @@ pub fn run_vdaf + vda PrepareTransition::::Continue(prep_state, _) => prep_state ) .clone(); - pts.push(vdaf.prepare_step(prep_state, prep_msg.clone()).unwrap()); + pts.push(vdaf.prepare_next(prep_state, prep_msg.clone()).unwrap()); } } } diff --git a/integration_tests/tests/common/mod.rs b/integration_tests/tests/common/mod.rs index a218fda4b..deced10af 100644 --- a/integration_tests/tests/common/mod.rs +++ b/integration_tests/tests/common/mod.rs @@ -280,7 +280,13 @@ pub async fn submit_measurements_and_verify_aggregate( .await; } VdafInstance::Prio3SumVec { bits, length } => { - let vdaf = Prio3::new_sum_vec_multithreaded(2, *bits, *length).unwrap(); + let vdaf = Prio3::new_sum_vec_multithreaded( + 2, + *bits, + *length, + VdafInstance::chunk_size(*bits * *length), + ) + .unwrap(); let measurements = iter::repeat_with(|| { iter::repeat_with(|| (random::()) >> (128 - bits)) @@ -319,7 +325,7 @@ pub async fn submit_measurements_and_verify_aggregate( .await; } VdafInstance::Prio3Histogram { length } => { - let vdaf = Prio3::new_histogram(2, *length).unwrap(); + let vdaf = Prio3::new_histogram(2, *length, VdafInstance::chunk_size(*length)).unwrap(); let mut aggregate_result = vec![0; *length]; let measurements = iter::repeat_with(|| { @@ -350,7 +356,9 @@ pub async fn submit_measurements_and_verify_aggregate( .await; } VdafInstance::Prio3CountVec { length } => { - let vdaf = Prio3::new_sum_vec_multithreaded(2, 1, *length).unwrap(); + let vdaf = + Prio3::new_sum_vec_multithreaded(2, 1, *length, VdafInstance::chunk_size(*length)) + .unwrap(); let measurements = iter::repeat_with(|| { iter::repeat_with(|| random::() as u128) diff --git a/integration_tests/tests/daphne.rs b/integration_tests/tests/daphne.rs index fa53888a2..8d949e411 100644 --- a/integration_tests/tests/daphne.rs +++ b/integration_tests/tests/daphne.rs @@ -49,7 +49,7 @@ async fn daphne_janus() { // This test places Janus in the leader role & Daphne in the helper role. #[tokio::test(flavor = "multi_thread")] -#[ignore = "Daphne does not currently support DAP-05 (issue #1669)"] +#[ignore = "Daphne does not currently support DAP-06 (issue #1669)"] async fn janus_daphne() { install_test_trace_subscriber(); diff --git a/integration_tests/tests/divviup_ts.rs b/integration_tests/tests/divviup_ts.rs index c16cfddf7..674525482 100644 --- a/integration_tests/tests/divviup_ts.rs +++ b/integration_tests/tests/divviup_ts.rs @@ -36,7 +36,7 @@ async fn run_divviup_ts_integration_test(container_client: &Cli, vdaf: VdafInsta } #[tokio::test(flavor = "multi_thread")] -#[ignore = "divviup-ts does not currently support DAP-05 (issue #1669)"] +#[ignore = "divviup-ts does not currently support DAP-06 (issue #1669)"] async fn janus_divviup_ts_count() { install_test_trace_subscriber(); @@ -44,7 +44,7 @@ async fn janus_divviup_ts_count() { } #[tokio::test(flavor = "multi_thread")] -#[ignore = "divviup-ts does not currently support DAP-05 (issue #1669)"] +#[ignore = "divviup-ts does not currently support DAP-06 (issue #1669)"] async fn janus_divviup_ts_sum() { install_test_trace_subscriber(); @@ -52,7 +52,7 @@ async fn janus_divviup_ts_sum() { } #[tokio::test(flavor = "multi_thread")] -#[ignore = "divviup-ts does not currently support DAP-05 (issue #1669)"] +#[ignore = "divviup-ts does not currently support DAP-06 (issue #1669)"] async fn janus_divviup_ts_histogram() { install_test_trace_subscriber(); diff --git a/integration_tests/tests/in_cluster.rs b/integration_tests/tests/in_cluster.rs index 2faf7b8db..5d49b1739 100644 --- a/integration_tests/tests/in_cluster.rs +++ b/integration_tests/tests/in_cluster.rs @@ -252,7 +252,7 @@ async fn in_cluster_sum() { } #[tokio::test(flavor = "multi_thread")] -#[ignore = "divviup-api does not currently support DAP-05 (https://github.com/divviup/divviup-api/issues/410)"] +#[ignore = "divviup-api does not currently support DAP-06 (https://github.com/divviup/divviup-api/issues/410)"] async fn in_cluster_histogram() { install_test_trace_subscriber(); diff --git a/interop_binaries/src/bin/janus_interop_client.rs b/interop_binaries/src/bin/janus_interop_client.rs index c8b1dbe2d..96ad8c5c3 100644 --- a/interop_binaries/src/bin/janus_interop_client.rs +++ b/interop_binaries/src/bin/janus_interop_client.rs @@ -146,8 +146,9 @@ async fn handle_upload( VdafInstance::Prio3CountVec { length } => { let measurement = parse_vector_measurement::(request.measurement.clone())?; - let vdaf_client = Prio3::new_sum_vec_multithreaded(2, 1, length) - .context("failed to construct Prio3CountVec VDAF")?; + let vdaf_client = + Prio3::new_sum_vec_multithreaded(2, 1, length, VdafInstance::chunk_size(length)) + .context("failed to construct Prio3CountVec VDAF")?; handle_upload_generic(http_client, vdaf_client, request, measurement).await?; } @@ -160,14 +161,19 @@ async fn handle_upload( VdafInstance::Prio3SumVec { bits, length } => { let measurement = parse_vector_measurement::(request.measurement.clone())?; - let vdaf_client = Prio3::new_sum_vec_multithreaded(2, bits, length) - .context("failed to construct Prio3SumVec VDAF")?; + let vdaf_client = Prio3::new_sum_vec_multithreaded( + 2, + bits, + length, + VdafInstance::chunk_size(bits * length), + ) + .context("failed to construct Prio3SumVec VDAF")?; handle_upload_generic(http_client, vdaf_client, request, measurement).await?; } VdafInstance::Prio3Histogram { length } => { let measurement = parse_primitive_measurement::(request.measurement.clone())?; - let vdaf_client = Prio3::new_histogram(2, length) + let vdaf_client = Prio3::new_histogram(2, length, VdafInstance::chunk_size(length)) .context("failed to construct Prio3Histogram VDAF")?; handle_upload_generic(http_client, vdaf_client, request, measurement).await?; } diff --git a/interop_binaries/src/bin/janus_interop_collector.rs b/interop_binaries/src/bin/janus_interop_collector.rs index 616f8382b..436c7bab1 100644 --- a/interop_binaries/src/bin/janus_interop_collector.rs +++ b/interop_binaries/src/bin/janus_interop_collector.rs @@ -312,8 +312,9 @@ async fn handle_collection_start( } (ParsedQuery::TimeInterval(batch_interval), VdafInstance::Prio3CountVec { length }) => { - let vdaf = Prio3::new_sum_vec_multithreaded(2, 1, length) - .context("failed to construct Prio3CountVec VDAF")?; + let vdaf = + Prio3::new_sum_vec_multithreaded(2, 1, length, VdafInstance::chunk_size(length)) + .context("failed to construct Prio3CountVec VDAF")?; handle_collect_generic( http_client, collector_params, @@ -344,8 +345,13 @@ async fn handle_collection_start( } (ParsedQuery::TimeInterval(batch_interval), VdafInstance::Prio3SumVec { bits, length }) => { - let vdaf = Prio3::new_sum_vec_multithreaded(2, bits, length) - .context("failed to construct Prio3SumVec VDAF")?; + let vdaf = Prio3::new_sum_vec_multithreaded( + 2, + bits, + length, + VdafInstance::chunk_size(bits * length), + ) + .context("failed to construct Prio3SumVec VDAF")?; handle_collect_generic( http_client, collector_params, @@ -362,7 +368,7 @@ async fn handle_collection_start( } (ParsedQuery::TimeInterval(batch_interval), VdafInstance::Prio3Histogram { length }) => { - let vdaf = Prio3::new_histogram(2, length) + let vdaf = Prio3::new_histogram(2, length, VdafInstance::chunk_size(length)) .context("failed to construct Prio3Histogram VDAF")?; handle_collect_generic( http_client, @@ -463,8 +469,9 @@ async fn handle_collection_start( } (ParsedQuery::FixedSize(fixed_size_query), VdafInstance::Prio3CountVec { length }) => { - let vdaf = Prio3::new_sum_vec_multithreaded(2, 1, length) - .context("failed to construct Prio3CountVec VDAF")?; + let vdaf = + Prio3::new_sum_vec_multithreaded(2, 1, length, VdafInstance::chunk_size(length)) + .context("failed to construct Prio3CountVec VDAF")?; handle_collect_generic( http_client, collector_params, @@ -564,8 +571,13 @@ async fn handle_collection_start( } (ParsedQuery::FixedSize(fixed_size_query), VdafInstance::Prio3SumVec { bits, length }) => { - let vdaf = Prio3::new_sum_vec_multithreaded(2, bits, length) - .context("failed to construct Prio3SumVec VDAF")?; + let vdaf = Prio3::new_sum_vec_multithreaded( + 2, + bits, + length, + VdafInstance::chunk_size(bits * length), + ) + .context("failed to construct Prio3SumVec VDAF")?; handle_collect_generic( http_client, collector_params, @@ -582,7 +594,7 @@ async fn handle_collection_start( } (ParsedQuery::FixedSize(fixed_size_query), VdafInstance::Prio3Histogram { length }) => { - let vdaf = Prio3::new_histogram(2, length) + let vdaf = Prio3::new_histogram(2, length, VdafInstance::chunk_size(length)) .context("failed to construct Prio3Histogram VDAF")?; handle_collect_generic( http_client, diff --git a/messages/Cargo.toml b/messages/Cargo.toml index c3b30f9bb..2ed14114b 100644 --- a/messages/Cargo.toml +++ b/messages/Cargo.toml @@ -20,7 +20,7 @@ hex = "0.4" num_enum = "0.7.0" # We can't pull prio in from the workspace because that would enable default features, and we do not # want prio/crypto-dependencies -prio = { version = "0.14.1", features = ["multithreaded"] } +prio = { version = "0.15.0", default-features = false, features = ["multithreaded"] } rand = "0.8" serde.workspace = true thiserror.workspace = true diff --git a/tools/src/bin/collect.rs b/tools/src/bin/collect.rs index 72f9dbd85..739738d45 100644 --- a/tools/src/bin/collect.rs +++ b/tools/src/bin/collect.rs @@ -11,7 +11,10 @@ use fixed::types::extra::{U15, U31, U63}; #[cfg(feature = "fpvec_bounded_l2")] use fixed::{FixedI16, FixedI32, FixedI64}; use janus_collector::{default_http_client, AuthenticationToken, Collector, CollectorParameters}; -use janus_core::hpke::{DivviUpHpkeConfig, HpkeKeypair, HpkePrivateKey}; +use janus_core::{ + hpke::{DivviUpHpkeConfig, HpkeKeypair, HpkePrivateKey}, + task::VdafInstance, +}; use janus_messages::{ query_type::{FixedSize, QueryType, TimeInterval}, BatchId, Duration, FixedSizeQuery, HpkeConfig, Interval, PartialBatchSelector, Query, TaskId, @@ -451,7 +454,8 @@ where .map_err(|err| Error::Anyhow(err.into())) } (VdafType::CountVec, Some(length), None) => { - let vdaf = Prio3::new_sum_vec(2, 1, length).map_err(|err| Error::Anyhow(err.into()))?; + let vdaf = Prio3::new_sum_vec(2, 1, length, VdafInstance::chunk_size(length)) + .map_err(|err| Error::Anyhow(err.into()))?; run_collection_generic(parameters, vdaf, http_client, query, &()) .await .map_err(|err| Error::Anyhow(err.into())) @@ -463,14 +467,15 @@ where .map_err(|err| Error::Anyhow(err.into())) } (VdafType::SumVec, Some(length), Some(bits)) => { - let vdaf = - Prio3::new_sum_vec(2, bits, length).map_err(|err| Error::Anyhow(err.into()))?; + let vdaf = Prio3::new_sum_vec(2, bits, length, VdafInstance::chunk_size(bits * length)) + .map_err(|err| Error::Anyhow(err.into()))?; run_collection_generic(parameters, vdaf, http_client, query, &()) .await .map_err(|err| Error::Anyhow(err.into())) } (VdafType::Histogram, Some(length), None) => { - let vdaf = Prio3::new_histogram(2, length).map_err(|err| Error::Anyhow(err.into()))?; + let vdaf = Prio3::new_histogram(2, length, VdafInstance::chunk_size(length)) + .map_err(|err| Error::Anyhow(err.into()))?; run_collection_generic(parameters, vdaf, http_client, query, &()) .await .map_err(|err| Error::Anyhow(err.into()))