Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

protocol-determined vdaf representation #419

Merged
merged 5 commits into from
Aug 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 27 additions & 8 deletions cli/src/tasks.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::{CliResult, DetermineAccountId, Output};
use crate::{CliResult, DetermineAccountId, Error, Output};
use clap::Subcommand;
use divviup_client::{DivviupClient, NewTask, Uuid, Vdaf};
use divviup_client::{DivviupClient, Histogram, NewTask, Uuid, Vdaf};
use humantime::{Duration, Timestamp};
use time::{OffsetDateTime, UtcOffset};

Expand Down Expand Up @@ -38,8 +38,10 @@ pub enum TaskAction {
time_precision: Duration,
#[arg(long)]
hpke_config_id: Uuid,
#[arg(long, required_if_eq("vdaf", "histogram"), value_delimiter = ',')]
buckets: Option<Vec<u64>>,
#[arg(long, value_delimiter = ',')]
categorical_buckets: Option<Vec<String>>,
#[arg(long, value_delimiter = ',')]
continuous_buckets: Option<Vec<u64>>,
#[arg(long, required_if_eq_any([("vdaf", "count_vec"), ("vdaf", "sum_vec")]))]
length: Option<u64>,
#[arg(long, required_if_eq_any([("vdaf", "sum"), ("vdaf", "sum_vec")]))]
Expand Down Expand Up @@ -73,16 +75,33 @@ impl TaskAction {
max_batch_size,
expiration,
hpke_config_id,
buckets,
categorical_buckets,
continuous_buckets,
length,
bits,
time_precision,
} => {
let vdaf = match vdaf {
VdafName::Count => Vdaf::Count,
VdafName::Histogram => Vdaf::Histogram {
buckets: buckets.unwrap(),
},
VdafName::Histogram => {
match (length, categorical_buckets, continuous_buckets) {
(Some(length), None, None) => {
Vdaf::Histogram(Histogram::Length { length })
}
(None, Some(buckets), None) => {
Vdaf::Histogram(Histogram::Categorical { buckets })
}
(None, None, Some(buckets)) => {
Vdaf::Histogram(Histogram::Continuous { buckets })
}
(None, None, None) => {
return Err(Error::Other("continuous-buckets, categorical-buckets, or length are required for histogram vdaf".into()));
}
_ => {
return Err(Error::Other("continuous-buckets, categorical-buckets, and length are mutually exclusive".into()));
}
}
}
VdafName::Sum => Vdaf::Sum {
bits: bits.unwrap(),
},
Expand Down
2 changes: 2 additions & 0 deletions client/src/aggregator.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use crate::Protocol;
use serde::{Deserialize, Serialize};
use time::OffsetDateTime;
use url::Url;
Expand Down Expand Up @@ -29,6 +30,7 @@ pub struct Aggregator {
pub is_first_party: bool,
pub vdafs: Vec<String>,
pub query_types: Vec<String>,
pub protocol: Protocol,
}

#[derive(Serialize, Deserialize, Debug, Clone, Eq, PartialEq)]
Expand Down
4 changes: 3 additions & 1 deletion client/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ mod aggregator;
mod api_token;
mod hpke_configs;
mod membership;
mod protocol;
mod task;
mod validation_errors;

Expand All @@ -34,7 +35,8 @@ pub use janus_messages::{
HpkeConfig as HpkeConfigContents, HpkePublicKey,
};
pub use membership::Membership;
pub use task::{NewTask, Task, Vdaf};
pub use protocol::Protocol;
pub use task::{Histogram, NewTask, Task, Vdaf};
pub use time::OffsetDateTime;
pub use trillium_client;
pub use trillium_client::Client;
Expand Down
42 changes: 42 additions & 0 deletions client/src/protocol.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
use serde::{Deserialize, Serialize};
use std::{
error::Error,
fmt::{self, Display, Formatter},
str::FromStr,
};

#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum Protocol {
#[serde(rename = "DAP-04")]
Dap04,
#[serde(rename = "DAP-05")]
Dap05,
}

impl AsRef<str> for Protocol {
fn as_ref(&self) -> &str {
match self {
Self::Dap04 => "DAP-04",
Self::Dap05 => "DAP-05",
}
}
}

#[derive(Debug)]
pub struct UnrecognizedProtocol(String);
impl Display for UnrecognizedProtocol {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
f.write_fmt(format_args!("{} was not a recognized protocol", self.0))
}
}
impl Error for UnrecognizedProtocol {}
impl FromStr for Protocol {
type Err = UnrecognizedProtocol;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match &*s.to_lowercase() {
"dap-04" => Ok(Self::Dap04),
"dap-05" => Ok(Self::Dap05),
unrecognized => Err(UnrecognizedProtocol(unrecognized.to_string())),
}
}
}
10 changes: 9 additions & 1 deletion client/src/task.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ pub enum Vdaf {
Count,

#[serde(rename = "histogram")]
Histogram { buckets: Vec<u64> },
Histogram(Histogram),

#[serde(rename = "sum")]
Sum { bits: u8 },
Expand All @@ -56,3 +56,11 @@ pub enum Vdaf {
#[serde(rename = "sum_vec")]
SumVec { bits: u8, length: u64 },
}

#[derive(Serialize, Deserialize, Debug, Clone, Eq, PartialEq)]
#[serde(untagged)]
pub enum Histogram {
Categorical { buckets: Vec<String> },
Continuous { buckets: Vec<u64> },
Length { length: u64 },
}
2 changes: 2 additions & 0 deletions migration/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ mod m20230703_201332_add_additional_fields_to_api_tokens;
mod m20230725_220134_add_vdafs_and_query_types_to_aggregators;
mod m20230731_181722_rename_aggregator_bearer_token;
mod m20230808_204859_create_hpke_config;
mod m20230817_192017_add_protocol_to_aggregators;

pub struct Migrator;

Expand All @@ -41,6 +42,7 @@ impl MigratorTrait for Migrator {
Box::new(m20230725_220134_add_vdafs_and_query_types_to_aggregators::Migration),
Box::new(m20230731_181722_rename_aggregator_bearer_token::Migration),
Box::new(m20230808_204859_create_hpke_config::Migration),
Box::new(m20230817_192017_add_protocol_to_aggregators::Migration),
]
}
}
52 changes: 52 additions & 0 deletions migration/src/m20230817_192017_add_protocol_to_aggregators.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
use sea_orm_migration::prelude::*;

#[derive(DeriveMigrationName)]
pub struct Migration;

#[async_trait::async_trait]
impl MigrationTrait for Migration {
async fn up(&self, db: &SchemaManager) -> Result<(), DbErr> {
db.alter_table(
TableAlterStatement::new()
.table(Aggregator::Table)
.add_column(ColumnDef::new(Aggregator::Protocol).string().null())
.to_owned(),
)
.await?;

db.exec_stmt(
Query::update()
.table(Aggregator::Table)
.value(Aggregator::Protocol, "DAP-04")
.to_owned(),
)
.await?;

db.alter_table(
TableAlterStatement::new()
.table(Aggregator::Table)
.modify_column(ColumnDef::new(Aggregator::Protocol).not_null())
.to_owned(),
)
.await?;

Ok(())
}

async fn down(&self, db: &SchemaManager) -> Result<(), DbErr> {
db.alter_table(
TableAlterStatement::new()
.table(Aggregator::Table)
.drop_column(Aggregator::Protocol)
.to_owned(),
)
.await?;
Ok(())
}
}

#[derive(DeriveIden)]
enum Aggregator {
Table,
Protocol,
}
9 changes: 5 additions & 4 deletions src/api_mocks/aggregator_api.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use super::random_chars;
use crate::clients::aggregator_client::api_types::{
AggregatorApiConfig, AuthenticationToken, HpkeAeadId, HpkeConfig, HpkeKdfId, HpkeKemId,
HpkePublicKey, JanusDuration, QueryType, Role, TaskCreate, TaskId, TaskIds, TaskMetrics,
TaskResponse, VdafInstance,
AggregatorApiConfig, AggregatorVdaf, AuthenticationToken, HpkeAeadId, HpkeConfig, HpkeKdfId,
HpkeKemId, HpkePublicKey, JanusDuration, QueryType, Role, TaskCreate, TaskId, TaskIds,
TaskMetrics, TaskResponse,
};
use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
use querystrong::QueryStrong;
Expand Down Expand Up @@ -30,6 +30,7 @@ pub fn mock() -> impl Handler {
role: random(),
vdafs: Default::default(),
query_types: Default::default(),
protocol: random(),
}),
)
.post("/tasks", api(post_task))
Expand Down Expand Up @@ -70,7 +71,7 @@ async fn get_task(conn: &mut Conn, (): ()) -> Json<TaskResponse> {
task_id: task_id.parse().unwrap(),
peer_aggregator_endpoint: "https://_".parse().unwrap(),
query_type: QueryType::TimeInterval,
vdaf: VdafInstance::Prio3Count,
vdaf: AggregatorVdaf::Prio3Count,
role: Role::Leader,
vdaf_verify_key: random_chars(10),
max_batch_query_count: 100,
Expand Down
Loading