From 06667dbeb93d4870bb257b7ddf9a19b37a7e5da5 Mon Sep 17 00:00:00 2001 From: David Cook Date: Fri, 22 Sep 2023 15:26:03 -0500 Subject: [PATCH] Handle chunk length VDAF parameters (#517) * Add tests for JSON representation of VDAF enum * Add handling of chunk_length VDAF parameters * Update OpenAPI schema * Propagate changes to client and CLI --- Cargo.lock | 20 ++- Cargo.toml | 1 + cli/src/tasks.rs | 26 ++-- client/src/task.rs | 26 +++- documentation/openapi.yml | 11 +- src/clients/aggregator_client/api_types.rs | 105 ++++++++++---- src/entity/task/new_task.rs | 11 +- src/entity/task/vdaf.rs | 155 ++++++++++++++++++++- src/entity/task/vdaf/tests/serde.rs | 89 ++++++++++++ src/routes/tasks.rs | 4 +- tests/new_task.rs | 18 +-- tests/vdaf.rs | 46 ++++-- 12 files changed, 442 insertions(+), 70 deletions(-) create mode 100644 src/entity/task/vdaf/tests/serde.rs diff --git a/Cargo.lock b/Cargo.lock index ed13f253..8722dd52 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1113,6 +1113,7 @@ dependencies = [ "oauth2", "opentelemetry", "opentelemetry-prometheus", + "prio 0.15.2", "querystrong", "rand", "regex", @@ -1900,7 +1901,7 @@ dependencies = [ "derivative", "hex", "num_enum", - "prio", + "prio 0.12.2", "rand", "serde", "thiserror", @@ -2549,6 +2550,23 @@ dependencies = [ "thiserror", ] +[[package]] +name = "prio" +version = "0.15.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "06023d4cf59f8c136ac9a11affa42d25f1ecf251a5065585be0b9d7a07c01217" +dependencies = [ + "aes", + "byteorder", + "ctr", + "getrandom", + "rand_core 0.6.4", + "serde", + "sha3", + "subtle", + "thiserror", +] + [[package]] name = "proc-macro-crate" version = "0.1.5" diff --git a/Cargo.toml b/Cargo.toml index 7400422e..707a63cb 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -32,6 +32,7 @@ opentelemetry = { version = "0.19.0", features = ["metrics"] } opentelemetry-prometheus = { version = "0.12.0", features = [ "prometheus-encoding", ] } +prio = "0.15.2" querystrong = "0.3.0" rand = "0.8.5" serde = { version = "1.0.188", features = ["derive"] } diff --git a/cli/src/tasks.rs b/cli/src/tasks.rs index 088d189b..dc87a324 100644 --- a/cli/src/tasks.rs +++ b/cli/src/tasks.rs @@ -43,6 +43,8 @@ pub enum TaskAction { length: Option, #[arg(long, required_if_eq_any([("vdaf", "sum"), ("vdaf", "sum_vec")]))] bits: Option, + #[arg(long)] + chunk_length: Option, }, /// rename a task @@ -70,26 +72,32 @@ impl TaskAction { vdaf, min_batch_size, max_batch_size, + time_precision, hpke_config_id, categorical_buckets, continuous_buckets, length, bits, - time_precision, + chunk_length, } => { let vdaf = match vdaf { VdafName::Count => Vdaf::Count, VdafName::Histogram => { match (length, categorical_buckets, continuous_buckets) { - (Some(length), None, None) => { - Vdaf::Histogram(Histogram::Length { length }) - } + (Some(length), None, None) => Vdaf::Histogram(Histogram::Length { + length, + chunk_length, + }), (None, Some(buckets), None) => { - Vdaf::Histogram(Histogram::Categorical { buckets }) - } - (None, None, Some(buckets)) => { - Vdaf::Histogram(Histogram::Continuous { buckets }) + Vdaf::Histogram(Histogram::Categorical { + buckets, + chunk_length, + }) } + (None, None, Some(buckets)) => Vdaf::Histogram(Histogram::Continuous { + buckets, + chunk_length, + }), (None, None, None) => { return Err(Error::Other("continuous-buckets, categorical-buckets, or length are required for histogram vdaf".into())); } @@ -103,10 +111,12 @@ impl TaskAction { }, VdafName::CountVec => Vdaf::CountVec { length: length.unwrap(), + chunk_length, }, VdafName::SumVec => Vdaf::SumVec { bits: bits.unwrap(), length: length.unwrap(), + chunk_length, }, }; diff --git a/client/src/task.rs b/client/src/task.rs index 53755d46..4ebcf9ab 100644 --- a/client/src/task.rs +++ b/client/src/task.rs @@ -49,16 +49,32 @@ pub enum Vdaf { Sum { bits: u8 }, #[serde(rename = "count_vec")] - CountVec { length: u64 }, + CountVec { + length: u64, + chunk_length: Option, + }, #[serde(rename = "sum_vec")] - SumVec { bits: u8, length: u64 }, + SumVec { + bits: u8, + length: u64, + chunk_length: Option, + }, } #[derive(Serialize, Deserialize, Debug, Clone, Eq, PartialEq)] #[serde(untagged)] pub enum Histogram { - Categorical { buckets: Vec }, - Continuous { buckets: Vec }, - Length { length: u64 }, + Categorical { + buckets: Vec, + chunk_length: Option, + }, + Continuous { + buckets: Vec, + chunk_length: Option, + }, + Length { + length: u64, + chunk_length: Option, + }, } diff --git a/documentation/openapi.yml b/documentation/openapi.yml index 005fa379..c526a5c8 100644 --- a/documentation/openapi.yml +++ b/documentation/openapi.yml @@ -752,8 +752,15 @@ components: bits: type: number buckets: - type: array - item: number + oneOf: + - type: array + items: + type: number + - type: array + items: + type: string + chunk_length: + type: number required: [type] ApiToken: type: object diff --git a/src/clients/aggregator_client/api_types.rs b/src/clients/aggregator_client/api_types.rs index 92d8fc63..5f73c5b4 100644 --- a/src/clients/aggregator_client/api_types.rs +++ b/src/clients/aggregator_client/api_types.rs @@ -20,10 +20,19 @@ pub use janus_messages::{ #[non_exhaustive] pub enum AggregatorVdaf { Prio3Count, - Prio3Sum { bits: u8 }, + Prio3Sum { + bits: u8, + }, Prio3Histogram(HistogramType), - Prio3CountVec { length: u64 }, - Prio3SumVec { bits: u8, length: u64 }, + Prio3CountVec { + length: u64, + chunk_length: Option, + }, + Prio3SumVec { + bits: u8, + length: u64, + chunk_length: Option, + }, } impl PartialEq for AggregatorVdaf { @@ -38,29 +47,50 @@ impl PartialEq for Vdaf { (Vdaf::Count, AggregatorVdaf::Prio3Count) => true, ( Vdaf::Histogram(histogram), - AggregatorVdaf::Prio3Histogram(HistogramType::Opaque { length }), - ) => histogram.length() == *length, + AggregatorVdaf::Prio3Histogram(HistogramType::Opaque { + length, + chunk_length, + }), + ) => histogram.length() == *length && histogram.chunk_length() == *chunk_length, ( - Vdaf::Histogram(Histogram::Continuous(ContinuousBuckets { buckets: Some(lhs) })), - AggregatorVdaf::Prio3Histogram(HistogramType::Buckets { buckets: rhs }), - ) => lhs == rhs, + Vdaf::Histogram(Histogram::Continuous(ContinuousBuckets { + buckets: Some(lhs_buckets), + chunk_length: lhs_chunk_length, + })), + AggregatorVdaf::Prio3Histogram(HistogramType::Buckets { + buckets: rhs_buckets, + chunk_length: rhs_chunk_length, + }), + ) => lhs_buckets == rhs_buckets && lhs_chunk_length == rhs_chunk_length, (Vdaf::Sum(Sum { bits: Some(lhs) }), AggregatorVdaf::Prio3Sum { bits: rhs }) => { lhs == rhs } ( - Vdaf::CountVec(CountVec { length: Some(lhs) }), - AggregatorVdaf::Prio3CountVec { length: rhs }, - ) => lhs == rhs, + Vdaf::CountVec(CountVec { + length: Some(lhs_length), + chunk_length: lhs_chunk_length, + }), + AggregatorVdaf::Prio3CountVec { + length: rhs_length, + chunk_length: rhs_chunk_length, + }, + ) => lhs_length == rhs_length && lhs_chunk_length == rhs_chunk_length, ( Vdaf::SumVec(SumVec { bits: Some(lhs_bits), length: Some(lhs_length), + chunk_length: lhs_chunk_length, }), AggregatorVdaf::Prio3SumVec { bits: rhs_bits, length: rhs_length, + chunk_length: rhs_chunk_length, }, - ) => lhs_bits == rhs_bits && lhs_length == rhs_length, + ) => { + lhs_bits == rhs_bits + && lhs_length == rhs_length + && lhs_chunk_length == rhs_chunk_length + } _ => false, } } @@ -69,8 +99,14 @@ impl PartialEq for Vdaf { #[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)] #[serde(untagged)] pub enum HistogramType { - Opaque { length: u64 }, - Buckets { buckets: Vec }, + Opaque { + length: u64, + chunk_length: Option, + }, + Buckets { + buckets: Vec, + chunk_length: Option, + }, } impl From for Vdaf { @@ -78,20 +114,35 @@ impl From for Vdaf { match value { AggregatorVdaf::Prio3Count => Self::Count, AggregatorVdaf::Prio3Sum { bits } => Self::Sum(Sum { bits: Some(bits) }), - AggregatorVdaf::Prio3Histogram(HistogramType::Buckets { buckets }) => { - Self::Histogram(Histogram::Continuous(ContinuousBuckets { - buckets: Some(buckets), - })) - } - AggregatorVdaf::Prio3Histogram(HistogramType::Opaque { length }) => { - Self::Histogram(Histogram::Opaque(BucketLength { length })) - } - AggregatorVdaf::Prio3CountVec { length } => Self::CountVec(CountVec { + AggregatorVdaf::Prio3Histogram(HistogramType::Buckets { + buckets, + chunk_length, + }) => Self::Histogram(Histogram::Continuous(ContinuousBuckets { + buckets: Some(buckets), + chunk_length, + })), + AggregatorVdaf::Prio3Histogram(HistogramType::Opaque { + length, + chunk_length, + }) => Self::Histogram(Histogram::Opaque(BucketLength { + length, + chunk_length, + })), + AggregatorVdaf::Prio3CountVec { + length, + chunk_length, + } => Self::CountVec(CountVec { length: Some(length), + chunk_length, }), - AggregatorVdaf::Prio3SumVec { bits, length } => Self::SumVec(SumVec { + AggregatorVdaf::Prio3SumVec { + bits, + length, + chunk_length, + } => Self::SumVec(SumVec { length: Some(length), bits: Some(bits), + chunk_length, }), } } @@ -283,7 +334,8 @@ mod test { }, "vdaf": { "Prio3CountVec": { - "length": 5 + "length": 5, + "chunk_length": null } }, "role": "Leader", @@ -320,7 +372,8 @@ mod test { }, "vdaf": { "Prio3CountVec": { - "length": 5 + "length": 5, + "chunk_length": null } }, "role": "Leader", diff --git a/src/entity/task/new_task.rs b/src/entity/task/new_task.rs index 1e90e6b4..56ef247e 100644 --- a/src/entity/task/new_task.rs +++ b/src/entity/task/new_task.rs @@ -229,6 +229,12 @@ impl NewTask { Some(aggregator_vdaf) } + fn populate_chunk_length(&mut self, protocol: &Protocol) { + if let Some(vdaf) = &mut self.vdaf { + vdaf.populate_chunk_length(protocol); + } + } + fn validate_query_type_is_supported( &self, leader: &Aggregator, @@ -241,8 +247,8 @@ impl NewTask { } } - pub async fn validate( - &self, + pub async fn normalize_and_validate( + &mut self, account: Account, db: &impl ConnectionTrait, ) -> Result { @@ -253,6 +259,7 @@ impl NewTask { let aggregator_vdaf = if let Some((leader, helper, protocol)) = aggregators.as_ref() { self.validate_query_type_is_supported(leader, helper, &mut errors); + self.populate_chunk_length(protocol); self.validate_vdaf_is_supported(leader, helper, protocol, &mut errors) } else { None diff --git a/src/entity/task/vdaf.rs b/src/entity/task/vdaf.rs index b5afad72..36aa8e02 100644 --- a/src/entity/task/vdaf.rs +++ b/src/entity/task/vdaf.rs @@ -2,6 +2,7 @@ use crate::{ clients::aggregator_client::api_types::{AggregatorVdaf, HistogramType}, entity::{aggregator::VdafName, Protocol}, }; +use prio::vdaf::prio3::optimal_chunk_length; use serde::{Deserialize, Serialize}; use std::{collections::HashSet, hash::Hash}; use validator::{Validate, ValidationError, ValidationErrors}; @@ -19,35 +20,64 @@ impl Histogram { match self { Histogram::Categorical(CategoricalBuckets { buckets: Some(buckets), + .. }) => buckets.len() as u64, Histogram::Continuous(ContinuousBuckets { buckets: Some(buckets), + .. }) => buckets.len() as u64 + 1, - Histogram::Opaque(BucketLength { length }) => *length, + Histogram::Opaque(BucketLength { length, .. }) => *length, _ => 0, } } + pub fn chunk_length(&self) -> Option { + match self { + Histogram::Categorical(CategoricalBuckets { chunk_length, .. }) + | Histogram::Continuous(ContinuousBuckets { chunk_length, .. }) + | Histogram::Opaque(BucketLength { chunk_length, .. }) => *chunk_length, + } + } + fn representation_for_protocol( &self, protocol: &Protocol, ) -> Result { match (protocol, self) { (Protocol::Dap07, histogram) => { - Ok(AggregatorVdaf::Prio3Histogram(HistogramType::Opaque { - length: histogram.length(), - })) + if let Some(chunk_length) = histogram.chunk_length() { + Ok(AggregatorVdaf::Prio3Histogram(HistogramType::Opaque { + length: histogram.length(), + chunk_length: Some(chunk_length), + })) + } else { + panic!("chunk_length was not populated"); + } } ( Protocol::Dap04, Self::Continuous(ContinuousBuckets { buckets: Some(buckets), + chunk_length: None, }), ) => Ok(AggregatorVdaf::Prio3Histogram(HistogramType::Buckets { buckets: buckets.clone(), + chunk_length: None, })), + ( + Protocol::Dap04, + Self::Continuous(ContinuousBuckets { + buckets: _, + chunk_length: Some(_), + }), + ) => { + let mut errors = ValidationErrors::new(); + errors.add("chunk_length", ValidationError::new("not-allowed")); + Err(errors) + } + (Protocol::Dap04, Self::Categorical(_)) => { let mut errors = ValidationErrors::new(); errors.add("buckets", ValidationError::new("must-be-numbers")); @@ -67,18 +97,27 @@ impl Histogram { pub struct ContinuousBuckets { #[validate(required, length(min = 1), custom = "increasing", custom = "unique")] pub buckets: Option>, + + #[validate(range(min = 1))] + pub chunk_length: Option, } #[derive(Serialize, Deserialize, Validate, Debug, Clone, Eq, PartialEq)] pub struct CategoricalBuckets { #[validate(required, length(min = 1), custom = "unique")] pub buckets: Option>, + + #[validate(range(min = 1))] + pub chunk_length: Option, } #[derive(Serialize, Deserialize, Validate, Debug, Clone, Eq, PartialEq, Copy)] pub struct BucketLength { #[validate(range(min = 1))] pub length: u64, + + #[validate(range(min = 1))] + pub chunk_length: Option, } fn unique(buckets: &[T]) -> Result<(), ValidationError> { @@ -114,6 +153,9 @@ pub struct Sum { pub struct CountVec { #[validate(required)] pub length: Option, + + #[validate(range(min = 1))] + pub chunk_length: Option, } #[derive(Serialize, Deserialize, Validate, Debug, Clone, Copy, Eq, PartialEq)] @@ -123,6 +165,9 @@ pub struct SumVec { #[validate(required)] pub length: Option, + + #[validate(range(min = 1))] + pub chunk_length: Option, } #[derive(Serialize, Deserialize, Debug, Clone, Eq, PartialEq)] @@ -170,16 +215,105 @@ impl Vdaf { Self::SumVec(SumVec { length: Some(length), bits: Some(bits), + chunk_length, }) => Ok(AggregatorVdaf::Prio3SumVec { bits: *bits, length: *length, + chunk_length: *chunk_length, }), Self::CountVec(CountVec { length: Some(length), - }) => Ok(AggregatorVdaf::Prio3CountVec { length: *length }), + chunk_length, + }) => Ok(AggregatorVdaf::Prio3CountVec { + length: *length, + chunk_length: *chunk_length, + }), _ => Err(ValidationErrors::new()), } } + + pub fn populate_chunk_length(&mut self, protocol: &Protocol) { + match (self, protocol) { + // Chunk length was already populated, don't change it. + ( + Self::Histogram(Histogram::Continuous(ContinuousBuckets { + chunk_length: Some(_), + .. + })), + _, + ) + | ( + Self::Histogram(Histogram::Opaque(BucketLength { + chunk_length: Some(_), + .. + })), + _, + ) + | ( + Self::Histogram(Histogram::Categorical(CategoricalBuckets { + chunk_length: Some(_), + .. + })), + _, + ) + | ( + Self::CountVec(CountVec { + chunk_length: Some(_), + .. + }), + _, + ) + | ( + Self::SumVec(SumVec { + chunk_length: Some(_), + .. + }), + _, + ) => {} + + // Select a chunk length if the protocol version needs it and it isn't set yet. + (Self::Histogram(histogram), Protocol::Dap07) => { + let length = histogram.length(); + match histogram { + Histogram::Opaque(BucketLength { chunk_length, .. }) + | Histogram::Categorical(CategoricalBuckets { chunk_length, .. }) + | Histogram::Continuous(ContinuousBuckets { chunk_length, .. }) => { + *chunk_length = Some(optimal_chunk_length(length as usize) as u64) + } + } + } + + ( + Self::CountVec(CountVec { + length: Some(length), + chunk_length: chunk_length @ None, + }), + Protocol::Dap07, + ) => *chunk_length = Some(optimal_chunk_length(*length as usize) as u64), + + ( + Self::SumVec(SumVec { + bits: Some(bits), + length: Some(length), + chunk_length: chunk_length @ None, + }), + Protocol::Dap07, + ) => { + *chunk_length = Some(optimal_chunk_length(*bits as usize * *length as usize) as u64) + } + + // Invalid, missing parameters, do nothing. + (Self::CountVec(CountVec { length: None, .. }), Protocol::Dap07) + | (Self::SumVec(SumVec { bits: None, .. }), Protocol::Dap07) + | (Self::SumVec(SumVec { length: None, .. }), Protocol::Dap07) => {} + + // Chunk length is not applicable, either due to VDAF choice or protocol version. + (Self::Count, _) + | (Self::Sum { .. }, _) + | (Self::Unrecognized, _) + | (_, Protocol::Dap04) => {} + } + } } impl Validate for Vdaf { @@ -206,10 +340,13 @@ mod tests { use super::*; use crate::test::assert_errors; + mod serde; + #[test] fn validate_continuous_histogram() { assert!(ContinuousBuckets { - buckets: Some(vec![0, 1, 2]) + buckets: Some(vec![0, 1, 2]), + chunk_length: None, } .validate() .is_ok()); @@ -217,6 +354,7 @@ mod tests { assert_errors( ContinuousBuckets { buckets: Some(vec![0, 2, 1]), + chunk_length: None, }, "buckets", &["sorted"], @@ -225,6 +363,7 @@ mod tests { assert_errors( ContinuousBuckets { buckets: Some(vec![0, 0, 2]), + chunk_length: None, }, "buckets", &["unique"], @@ -234,7 +373,8 @@ mod tests { #[test] fn validate_categorical_histogram() { assert!(CategoricalBuckets { - buckets: Some(vec!["a".into(), "b".into()]) + buckets: Some(vec!["a".into(), "b".into()]), + chunk_length: None, } .validate() .is_ok()); @@ -242,6 +382,7 @@ mod tests { assert_errors( CategoricalBuckets { buckets: Some(vec!["a".into(), "a".into()]), + chunk_length: None, }, "buckets", &["unique"], diff --git a/src/entity/task/vdaf/tests/serde.rs b/src/entity/task/vdaf/tests/serde.rs new file mode 100644 index 00000000..e75024e1 --- /dev/null +++ b/src/entity/task/vdaf/tests/serde.rs @@ -0,0 +1,89 @@ +use crate::entity::task::vdaf::{ + BucketLength, CategoricalBuckets, ContinuousBuckets, CountVec, Histogram, Sum, SumVec, Vdaf, +}; + +#[test] +fn json_vdaf() { + for (serialized, vdaf) in [ + (r#"{"type":"count"}"#, Vdaf::Count), + ( + r#"{"type":"histogram","buckets":["A","B"]}"#, + Vdaf::Histogram(Histogram::Categorical(CategoricalBuckets { + buckets: Some(Vec::from(["A".to_owned(), "B".to_owned()])), + chunk_length: None, + })), + ), + ( + r#"{"type":"histogram","buckets":["A","B"],"chunk_length":2}"#, + Vdaf::Histogram(Histogram::Categorical(CategoricalBuckets { + buckets: Some(Vec::from(["A".to_owned(), "B".to_owned()])), + chunk_length: Some(2), + })), + ), + ( + r#"{"type":"histogram","buckets":[1,10,100]}"#, + Vdaf::Histogram(Histogram::Continuous(ContinuousBuckets { + buckets: Some(Vec::from([1, 10, 100])), + chunk_length: None, + })), + ), + ( + r#"{"type":"histogram","buckets":[1,10,100],"chunk_length":2}"#, + Vdaf::Histogram(Histogram::Continuous(ContinuousBuckets { + buckets: Some(Vec::from([1, 10, 100])), + chunk_length: Some(2), + })), + ), + ( + r#"{"type":"histogram","length":5}"#, + Vdaf::Histogram(Histogram::Opaque(BucketLength { + length: 5, + chunk_length: None, + })), + ), + ( + r#"{"type":"histogram","length":5,"chunk_length":2}"#, + Vdaf::Histogram(Histogram::Opaque(BucketLength { + length: 5, + chunk_length: Some(2), + })), + ), + ( + r#"{"type":"sum","bits":8}"#, + Vdaf::Sum(Sum { bits: Some(8) }), + ), + ( + r#"{"type":"count_vec","length":5}"#, + Vdaf::CountVec(CountVec { + length: Some(5), + chunk_length: None, + }), + ), + ( + r#"{"type":"count_vec","length":5,"chunk_length":2}"#, + Vdaf::CountVec(CountVec { + length: Some(5), + chunk_length: Some(2), + }), + ), + ( + r#"{"type":"sum_vec","bits":8,"length":10}"#, + Vdaf::SumVec(SumVec { + bits: Some(8), + length: Some(10), + chunk_length: None, + }), + ), + ( + r#"{"type":"sum_vec","bits":8,"length":10,"chunk_length":12}"#, + Vdaf::SumVec(SumVec { + bits: Some(8), + length: Some(10), + chunk_length: Some(12), + }), + ), + (r#"{"type":"wrong"}"#, Vdaf::Unrecognized), + ] { + assert_eq!(serde_json::from_str::(serialized).unwrap(), vdaf); + } +} diff --git a/src/routes/tasks.rs b/src/routes/tasks.rs index 810822ef..1fd5c45c 100644 --- a/src/routes/tasks.rs +++ b/src/routes/tasks.rs @@ -47,10 +47,10 @@ impl FromConn for Task { type CreateArgs = (Account, Json, State, Db); pub async fn create( conn: &mut Conn, - (account, task, State(client), db): CreateArgs, + (account, mut task, State(client), db): CreateArgs, ) -> Result { let crypter = conn.state().unwrap(); - task.validate(account, &db) + task.normalize_and_validate(account, &db) .await? .provision(client, crypter) .await? diff --git a/tests/new_task.rs b/tests/new_task.rs index ac5f686b..7b79ffb8 100644 --- a/tests/new_task.rs +++ b/tests/new_task.rs @@ -1,11 +1,11 @@ use divviup_api::entity::aggregator::Role; use test_support::{assert_eq, test, *}; -pub async fn assert_errors(app: &DivviupApi, new_task: &NewTask, field: &str, codes: &[&str]) { +pub async fn assert_errors(app: &DivviupApi, new_task: &mut NewTask, field: &str, codes: &[&str]) { let account = fixtures::account(app).await; assert_eq!( new_task - .validate(account, app.db()) + .normalize_and_validate(account, app.db()) .await .unwrap_err() .field_errors() @@ -20,7 +20,7 @@ pub async fn assert_errors(app: &DivviupApi, new_task: &NewTask, field: &str, co async fn batch_size(app: DivviupApi) -> TestResult { assert_errors( &app, - &NewTask { + &mut NewTask { min_batch_size: Some(100), max_batch_size: Some(50), ..Default::default() @@ -32,7 +32,7 @@ async fn batch_size(app: DivviupApi) -> TestResult { assert_errors( &app, - &NewTask { + &mut NewTask { min_batch_size: Some(100), max_batch_size: Some(50), ..Default::default() @@ -58,7 +58,7 @@ async fn aggregator_roles(app: DivviupApi) -> TestResult { assert_errors( &app, - &NewTask { + &mut NewTask { leader_aggregator_id: Some(helper.id.to_string()), helper_aggregator_id: Some(either.id.to_string()), ..Default::default() @@ -70,7 +70,7 @@ async fn aggregator_roles(app: DivviupApi) -> TestResult { assert_errors( &app, - &NewTask { + &mut NewTask { helper_aggregator_id: Some(leader.id.to_string()), leader_aggregator_id: Some(either.id.to_string()), ..Default::default() @@ -80,13 +80,13 @@ async fn aggregator_roles(app: DivviupApi) -> TestResult { ) .await; - let ok_aggregators = NewTask { + let mut ok_aggregators = NewTask { helper_aggregator_id: Some(helper.id.to_string()), leader_aggregator_id: Some(leader.id.to_string()), ..Default::default() }; - assert_errors(&app, &ok_aggregators, "helper_aggregator_id", &[]).await; - assert_errors(&app, &ok_aggregators, "leader_aggregator_id", &[]).await; + assert_errors(&app, &mut ok_aggregators, "helper_aggregator_id", &[]).await; + assert_errors(&app, &mut ok_aggregators, "leader_aggregator_id", &[]).await; Ok(()) } diff --git a/tests/vdaf.rs b/tests/vdaf.rs index 14323ccd..e33e3c28 100644 --- a/tests/vdaf.rs +++ b/tests/vdaf.rs @@ -1,4 +1,4 @@ -use divviup_api::entity::task::vdaf::Vdaf; +use divviup_api::entity::task::vdaf::{BucketLength, CategoricalBuckets, Histogram, Vdaf}; use test_support::{assert_eq, test, *}; #[test] pub fn histogram_representations() { @@ -9,19 +9,29 @@ pub fn histogram_representations() { Err(json!({"buckets": [{"code": "must-be-numbers", "message": null, "params": {}}]})), ), ( - json!({"type": "histogram", "buckets": ["a", "b", "c"]}), + json!({"type": "histogram", "buckets": ["a", "b", "c"], "chunk_length": 1}), + Protocol::Dap04, + Err(json!({"buckets": [{"code": "must-be-numbers", "message": null, "params": {}}]})), + ), + ( + json!({"type": "histogram", "buckets": ["a", "b", "c"], "chunk_length": 1}), Protocol::Dap07, - Ok(json!({"Prio3Histogram": {"length": 3}})), + Ok(json!({"Prio3Histogram": {"length": 3, "chunk_length": 1}})), ), ( json!({"type": "histogram", "buckets": [1, 2, 3]}), Protocol::Dap04, - Ok(json!({"Prio3Histogram": {"buckets": [1, 2, 3]}})), + Ok(json!({"Prio3Histogram": {"buckets": [1, 2, 3], "chunk_length": null}})), ), ( - json!({"type": "histogram", "buckets": [1, 2, 3]}), + json!({"type": "histogram", "buckets": [1, 2, 3], "chunk_length": 2}), + Protocol::Dap04, + Err(json!({"chunk_length": [{"code": "not-allowed", "message": null, "params": {}}]})), + ), + ( + json!({"type": "histogram", "buckets": [1, 2, 3], "chunk_length": 2}), Protocol::Dap07, - Ok(json!({"Prio3Histogram": {"length": 4}})), + Ok(json!({"Prio3Histogram": {"length": 4, "chunk_length": 2}})), ), ( json!({"type": "histogram", "length": 3}), @@ -29,9 +39,9 @@ pub fn histogram_representations() { Err(json!({"buckets": [{"code": "required", "message": null, "params":{}}]})), ), ( - json!({"type": "histogram", "length": 3}), + json!({"type": "histogram", "length": 3, "chunk_length": 1}), Protocol::Dap07, - Ok(json!({"Prio3Histogram": {"length": 3}})), + Ok(json!({"Prio3Histogram": {"length": 3, "chunk_length": 1}})), ), ]; @@ -46,3 +56,23 @@ pub fn histogram_representations() { ); } } + +#[test] +#[should_panic] +fn histogram_representation_dap_07_no_chunk_length_1() { + let _ = Vdaf::Histogram(Histogram::Categorical(CategoricalBuckets { + buckets: Some(Vec::from(["a".to_owned(), "b".to_owned(), "c".to_owned()])), + chunk_length: None, + })) + .representation_for_protocol(&Protocol::Dap07); +} + +#[test] +#[should_panic] +fn histogram_representation_dap_07_no_chunk_length_2() { + let _ = Vdaf::Histogram(Histogram::Opaque(BucketLength { + length: 3, + chunk_length: None, + })) + .representation_for_protocol(&Protocol::Dap07); +}