Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adopt prio 0.15.0 #1916

Merged
merged 4 commits into from
Sep 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 35 additions & 26 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ janus_messages = { version = "0.6", path = "messages" }
k8s-openapi = { version = "0.18.0", features = ["v1_24"] } # keep this version in sync with what is referenced by the indirect dependency via `kube`
kube = { version = "0.82.2", default-features = false, features = ["client", "rustls-tls"] }
opentelemetry = { version = "0.20", features = ["metrics"] }
prio = { version = "0.14.1", features = ["multithreaded"] }
prio = { version = "0.15.0", features = ["multithreaded"] }
serde = { version = "1.0.188", features = ["derive"] }
serde_json = "1.0.106"
serde_test = "1.0.175"
Expand Down
24 changes: 17 additions & 7 deletions aggregator/src/aggregator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,8 @@ use prio::{
vdaf::{
self,
poplar1::Poplar1,
prg::PrgSha3,
prio3::{Prio3, Prio3Count, Prio3Histogram, Prio3Sum, Prio3SumVecMultithreaded},
xof::XofShake128,
},
};
use reqwest::Client;
Expand Down Expand Up @@ -795,7 +795,12 @@ impl<C: Clock> TaskAggregator<C> {
}

VdafInstance::Prio3CountVec { length } => {
let vdaf = Prio3::new_sum_vec_multithreaded(2, 1, *length)?;
let vdaf = Prio3::new_sum_vec_multithreaded(
2,
1,
*length,
VdafInstance::chunk_size(*length),
)?;
let verify_key = task.primary_vdaf_verify_key()?;
VdafOps::Prio3CountVec(Arc::new(vdaf), verify_key)
}
Expand All @@ -807,13 +812,18 @@ impl<C: Clock> TaskAggregator<C> {
}

VdafInstance::Prio3SumVec { bits, length } => {
let vdaf = Prio3::new_sum_vec_multithreaded(2, *bits, *length)?;
let vdaf = Prio3::new_sum_vec_multithreaded(
2,
*bits,
*length,
VdafInstance::chunk_size(*bits * *length),
)?;
let verify_key = task.primary_vdaf_verify_key()?;
VdafOps::Prio3SumVec(Arc::new(vdaf), verify_key)
}

VdafInstance::Prio3Histogram { length } => {
let vdaf = Prio3::new_histogram(2, *length)?;
let vdaf = Prio3::new_histogram(2, *length, VdafInstance::chunk_size(*length))?;
let verify_key = task.primary_vdaf_verify_key()?;
VdafOps::Prio3Histogram(Arc::new(vdaf), verify_key)
}
Expand Down Expand Up @@ -843,7 +853,7 @@ impl<C: Clock> TaskAggregator<C> {
}

VdafInstance::Poplar1 { bits } => {
let vdaf = Poplar1::new_sha3(*bits);
let vdaf = Poplar1::new_shake128(*bits);
let verify_key = task.primary_vdaf_verify_key()?;
VdafOps::Poplar1(Arc::new(vdaf), verify_key)
}
Expand Down Expand Up @@ -1038,7 +1048,7 @@ enum VdafOps {
Arc<Prio3FixedPointBoundedL2VecSumMultithreaded<FixedI64<U63>>>,
VerifyKey<VERIFY_KEY_LENGTH>,
),
Poplar1(Arc<Poplar1<PrgSha3, 16>>, VerifyKey<VERIFY_KEY_LENGTH>),
Poplar1(Arc<Poplar1<XofShake128, 16>>, VerifyKey<VERIFY_KEY_LENGTH>),

#[cfg(feature = "test-util")]
Fake(Arc<dummy_vdaf::Vdaf>),
Expand Down Expand Up @@ -1125,7 +1135,7 @@ macro_rules! vdaf_ops_dispatch {
crate::aggregator::VdafOps::Poplar1(vdaf, verify_key) => {
let $vdaf = vdaf;
let $verify_key = verify_key;
type $Vdaf = ::prio::vdaf::poplar1::Poplar1<::prio::vdaf::prg::PrgSha3, 16>;
type $Vdaf = ::prio::vdaf::poplar1::Poplar1<::prio::vdaf::xof::XofShake128, 16>;
const $VERIFY_KEY_LENGTH: usize = ::janus_core::task::VERIFY_KEY_LENGTH;
$body
}
Expand Down
2 changes: 1 addition & 1 deletion aggregator/src/aggregator/aggregation_job_continue.rs
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ impl VdafOps {

// Compute the next transition.
let prepare_step_res = trace_span!("VDAF preparation")
.in_scope(|| vdaf.prepare_step(prep_state.clone(), prep_msg));
.in_scope(|| vdaf.prepare_next(prep_state.clone(), prep_msg));
match prepare_step_res {
Ok(PrepareTransition::Continue(prep_state, prep_share)) => {
*report_aggregation = report_aggregation
Expand Down
40 changes: 34 additions & 6 deletions aggregator/src/aggregator/aggregation_job_creator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,12 @@ impl<C: Clock + 'static> AggregationJobCreator<C> {
}

(task::QueryType::TimeInterval, VdafInstance::Prio3CountVec { length }) => {
let vdaf = Arc::new(Prio3::new_sum_vec_multithreaded(2, 1, *length)?);
let vdaf = Arc::new(Prio3::new_sum_vec_multithreaded(
2,
1,
*length,
VdafInstance::chunk_size(*length),
)?);
self.create_aggregation_jobs_for_time_interval_task_no_param::<
VERIFY_KEY_LENGTH,
Prio3SumVecMultithreaded
Expand All @@ -283,13 +288,22 @@ impl<C: Clock + 'static> AggregationJobCreator<C> {
}

(task::QueryType::TimeInterval, VdafInstance::Prio3SumVec { bits, length }) => {
let vdaf = Arc::new(Prio3::new_sum_vec_multithreaded(2, *bits, *length)?);
let vdaf = Arc::new(Prio3::new_sum_vec_multithreaded(
2,
*bits,
*length,
VdafInstance::chunk_size(*bits * *length),
)?);
self.create_aggregation_jobs_for_time_interval_task_no_param::<VERIFY_KEY_LENGTH, Prio3SumVecMultithreaded>(task, vdaf)
.await
}

(task::QueryType::TimeInterval, VdafInstance::Prio3Histogram { length }) => {
let vdaf = Arc::new(Prio3::new_histogram(2, *length)?);
let vdaf = Arc::new(Prio3::new_histogram(
2,
*length,
VdafInstance::chunk_size(*length),
)?);
self.create_aggregation_jobs_for_time_interval_task_no_param::<VERIFY_KEY_LENGTH, Prio3Histogram>(task, vdaf)
.await
}
Expand Down Expand Up @@ -356,7 +370,12 @@ impl<C: Clock + 'static> AggregationJobCreator<C> {
},
VdafInstance::Prio3CountVec { length },
) => {
let vdaf = Arc::new(Prio3::new_sum_vec_multithreaded(2, 1, *length)?);
let vdaf = Arc::new(Prio3::new_sum_vec_multithreaded(
2,
1,
*length,
VdafInstance::chunk_size(*length),
)?);
let max_batch_size = *max_batch_size;
let batch_time_window_size = *batch_time_window_size;
self.create_aggregation_jobs_for_fixed_size_task_no_param::<
Expand Down Expand Up @@ -388,7 +407,12 @@ impl<C: Clock + 'static> AggregationJobCreator<C> {
},
VdafInstance::Prio3SumVec { bits, length },
) => {
let vdaf = Arc::new(Prio3::new_sum_vec_multithreaded(2, *bits, *length)?);
let vdaf = Arc::new(Prio3::new_sum_vec_multithreaded(
2,
*bits,
*length,
VdafInstance::chunk_size(*bits * *length),
)?);
let max_batch_size = *max_batch_size;
let batch_time_window_size = *batch_time_window_size;
self.create_aggregation_jobs_for_fixed_size_task_no_param::<
Expand All @@ -404,7 +428,11 @@ impl<C: Clock + 'static> AggregationJobCreator<C> {
},
VdafInstance::Prio3Histogram { length },
) => {
let vdaf = Arc::new(Prio3::new_histogram(2, *length)?);
let vdaf = Arc::new(Prio3::new_histogram(
2,
*length,
VdafInstance::chunk_size(*length),
)?);
let max_batch_size = *max_batch_size;
let batch_time_window_size = *batch_time_window_size;
self.create_aggregation_jobs_for_fixed_size_task_no_param::<
Expand Down
13 changes: 8 additions & 5 deletions aggregator/src/aggregator/aggregation_job_driver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -469,7 +469,7 @@ impl AggregationJobDriver {

// Step our own state.
let prepare_step_res = trace_span!("VDAF preparation")
.in_scope(|| vdaf.prepare_step(prep_state.clone(), prep_msg.clone()));
.in_scope(|| vdaf.prepare_next(prep_state.clone(), prep_msg.clone()));
let leader_transition = match prepare_step_res {
Ok(leader_transition) => leader_transition,
Err(error) => {
Expand Down Expand Up @@ -589,11 +589,14 @@ impl AggregationJobDriver {
A::PrepareShare::get_decoded_with_param(&leader_prep_state, payload)
.context("couldn't decode helper's prepare message");
let prep_msg = helper_prep_share.and_then(|helper_prep_share| {
vdaf.prepare_preprocess([leader_prep_share.clone(), helper_prep_share])
.context(
"couldn't preprocess leader & helper prepare shares into \
vdaf.prepare_shares_to_prepare_message(
aggregation_job.aggregation_parameter(),
[leader_prep_share.clone(), helper_prep_share],
)
.context(
"couldn't preprocess leader & helper prepare shares into \
prepare message",
)
)
});
match prep_msg {
Ok(prep_msg) => {
Expand Down
4 changes: 2 additions & 2 deletions aggregator/src/aggregator/taskprov_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@ use prio::{
field::Field64,
flp::types::Count,
vdaf::{
prg::PrgSha3,
prio3::{Prio3, Prio3Count},
xof::XofShake128,
AggregateShare, OutputShare,
},
};
Expand All @@ -66,7 +66,7 @@ use trillium_testing::{
prelude::{post, put},
};

type TestVdaf = Prio3<Count<Field64>, PrgSha3, 16>;
type TestVdaf = Prio3<Count<Field64>, XofShake128, 16>;

pub struct TaskprovTestCase {
_ephemeral_datastore: EphemeralDatastore,
Expand Down
2 changes: 1 addition & 1 deletion collector/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -955,7 +955,7 @@ mod tests {
async fn successful_collect_prio3_histogram() {
install_test_trace_subscriber();
let mut server = mockito::Server::new_async().await;
let vdaf = Prio3::new_histogram(2, 4).unwrap();
let vdaf = Prio3::new_histogram(2, 4, 2).unwrap();
let transcript = run_vdaf(&vdaf, &random(), &(), &random(), &3);
let collector = setup_collector(&mut server, vdaf);

Expand Down
Loading