Skip to content

Commit

Permalink
Add per-task configuration for Daphne compatibility (#818)
Browse files Browse the repository at this point in the history
  • Loading branch information
divergentdave authored Dec 7, 2022
1 parent eee8e1e commit bd9a09f
Show file tree
Hide file tree
Showing 14 changed files with 207 additions and 40 deletions.
22 changes: 17 additions & 5 deletions aggregator/src/aggregator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1323,6 +1323,7 @@ impl VdafOps {
report.task_id(),
report.metadata(),
report.public_share(),
task.input_share_aad_public_share_length_prefix(),
),
) {
Ok(leader_decrypted_input_share) => leader_decrypted_input_share,
Expand Down Expand Up @@ -1501,6 +1502,7 @@ impl VdafOps {
task.id(),
report_share.metadata(),
report_share.public_share(),
task.input_share_aad_public_share_length_prefix(),
),
)
.map_err(|error| {
Expand Down Expand Up @@ -3500,6 +3502,7 @@ mod tests {
task.id(),
&report_metadata,
&public_share.get_encoded(),
false,
);

let leader_ciphertext = hpke::seal(
Expand Down Expand Up @@ -4318,8 +4321,12 @@ mod tests {
);
let mut input_share_bytes = input_share.get_encoded();
input_share_bytes.push(0); // can no longer be decoded.
let aad =
associated_data_for_report_share(task.id(), &report_metadata_2, &encoded_public_share);
let aad = associated_data_for_report_share(
task.id(),
&report_metadata_2,
&encoded_public_share,
false,
);
let report_share_2 = generate_helper_report_share_for_plaintext(
report_metadata_2,
&hpke_key.0,
Expand Down Expand Up @@ -4415,7 +4422,8 @@ mod tests {
.unwrap(),
Vec::new(),
);
let aad = associated_data_for_report_share(task.id(), &report_metadata_6, &public_share_6);
let aad =
associated_data_for_report_share(task.id(), &report_metadata_6, &public_share_6, false);
let report_share_6 = generate_helper_report_share_for_plaintext(
report_metadata_6,
&hpke_key.0,
Expand Down Expand Up @@ -7940,8 +7948,12 @@ mod tests {
for<'a> &'a V::AggregateShare: Into<Vec<u8>>,
{
let encoded_public_share = public_share.get_encoded();
let associated_data =
associated_data_for_report_share(task_id, report_metadata, &encoded_public_share);
let associated_data = associated_data_for_report_share(
task_id,
report_metadata,
&encoded_public_share,
false,
);
generate_helper_report_share_for_plaintext(
report_metadata.clone(),
cfg,
Expand Down
1 change: 1 addition & 0 deletions aggregator/src/aggregator/aggregation_job_driver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2199,6 +2199,7 @@ mod tests {
task_id,
report_metadata,
&public_share.get_encoded(),
false,
),
)
.unwrap();
Expand Down
16 changes: 12 additions & 4 deletions aggregator/src/datastore.rs
Original file line number Diff line number Diff line change
Expand Up @@ -291,8 +291,9 @@ impl<C: Clock> Transaction<'_, C> {
.prepare_cached(
"INSERT INTO tasks (task_id, aggregator_role, aggregator_endpoints, query_type,
vdaf, max_batch_query_count, task_expiration, min_batch_size, time_precision,
tolerable_clock_skew, collector_hpke_config)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)",
tolerable_clock_skew, collector_hpke_config,
input_share_aad_public_share_length_prefix)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)",
)
.await?;
self.tx
Expand All @@ -313,6 +314,8 @@ impl<C: Clock> Transaction<'_, C> {
/* tolerable_clock_skew */
&i64::try_from(task.tolerable_clock_skew().as_seconds())?,
/* collector_hpke_config */ &task.collector_hpke_config().get_encoded(),
/* input_share_aad_public_share_length_prefix */
&task.input_share_aad_public_share_length_prefix(),
],
)
.await?;
Expand Down Expand Up @@ -517,7 +520,8 @@ impl<C: Clock> Transaction<'_, C> {
.prepare_cached(
"SELECT aggregator_role, aggregator_endpoints, query_type, vdaf,
max_batch_query_count, task_expiration, min_batch_size, time_precision,
tolerable_clock_skew, collector_hpke_config
tolerable_clock_skew, collector_hpke_config,
input_share_aad_public_share_length_prefix
FROM tasks WHERE task_id = $1",
)
.await?;
Expand Down Expand Up @@ -594,7 +598,8 @@ impl<C: Clock> Transaction<'_, C> {
.prepare_cached(
"SELECT task_id, aggregator_role, aggregator_endpoints, query_type, vdaf,
max_batch_query_count, task_expiration, min_batch_size, time_precision,
tolerable_clock_skew, collector_hpke_config
tolerable_clock_skew, collector_hpke_config,
input_share_aad_public_share_length_prefix
FROM tasks",
)
.await?;
Expand Down Expand Up @@ -741,6 +746,8 @@ impl<C: Clock> Transaction<'_, C> {
let tolerable_clock_skew =
Duration::from_seconds(row.get_bigint_and_convert("tolerable_clock_skew")?);
let collector_hpke_config = HpkeConfig::get_decoded(row.get("collector_hpke_config"))?;
let input_share_aad_public_share_length_prefix =
row.get("input_share_aad_public_share_length_prefix");

// Aggregator authentication tokens.
let mut aggregator_auth_tokens = Vec::new();
Expand Down Expand Up @@ -827,6 +834,7 @@ impl<C: Clock> Transaction<'_, C> {
aggregator_auth_tokens,
collector_auth_tokens,
hpke_configs,
input_share_aad_public_share_length_prefix,
)?)
}

Expand Down
48 changes: 40 additions & 8 deletions aggregator/src/task.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,8 @@ pub struct Task {
collector_auth_tokens: Vec<AuthenticationToken>,
/// HPKE configurations & private keys used by this aggregator to decrypt client reports.
hpke_keys: HashMap<HpkeConfigId, (HpkeConfig, HpkePrivateKey)>,
/// Configuration option to add a length prefix for the public share in the input share AAD.
input_share_aad_public_share_length_prefix: bool,
}

impl Task {
Expand All @@ -131,6 +133,7 @@ impl Task {
aggregator_auth_tokens: Vec<AuthenticationToken>,
collector_auth_tokens: Vec<AuthenticationToken>,
hpke_keys: I,
input_share_aad_public_share_length_prefix: bool,
) -> Result<Self, Error> {
// Ensure provided aggregator endpoints end with a slash, as we will be joining additional
// path segments into these endpoints & the Url::join implementation is persnickety about
Expand Down Expand Up @@ -161,6 +164,7 @@ impl Task {
aggregator_auth_tokens,
collector_auth_tokens,
hpke_keys,
input_share_aad_public_share_length_prefix,
};
task.validate()?;
Ok(task)
Expand Down Expand Up @@ -331,6 +335,12 @@ impl Task {
let secret_bytes = self.vdaf_verify_keys.first().unwrap();
VerifyKey::try_from(secret_bytes).map_err(|_| Error::AggregatorVerifyKeySize)
}

/// Fetch the configuration setting specifying whether an additional length prefix should be
/// added to the input share AAD, before the public share.
pub fn input_share_aad_public_share_length_prefix(&self) -> bool {
self.input_share_aad_public_share_length_prefix
}
}

fn fmt_vector_of_urls(urls: &Vec<Url>, f: &mut Formatter<'_>) -> fmt::Result {
Expand Down Expand Up @@ -360,6 +370,7 @@ struct SerializedTask {
aggregator_auth_tokens: Vec<String>, // in unpadded base64url
collector_auth_tokens: Vec<String>, // in unpadded base64url
hpke_keys: Vec<SerializedHpkeKeypair>, // in unpadded base64url
input_share_aad_public_share_length_prefix: bool,
}

impl Serialize for Task {
Expand Down Expand Up @@ -402,6 +413,8 @@ impl Serialize for Task {
aggregator_auth_tokens,
collector_auth_tokens,
hpke_keys,
input_share_aad_public_share_length_prefix: self
.input_share_aad_public_share_length_prefix,
}
.serialize(serializer)
}
Expand Down Expand Up @@ -482,6 +495,7 @@ impl<'de> Deserialize<'de> for Task {
aggregator_auth_tokens,
collector_auth_tokens,
hpke_keys,
serialized_task.input_share_aad_public_share_length_prefix,
)
.map_err(D::Error::custom)
}
Expand Down Expand Up @@ -635,6 +649,7 @@ pub mod test_util {
(aggregator_config_0, aggregator_private_key_0),
(aggregator_config_1, aggregator_private_key_1),
]),
false,
)
.unwrap(),
)
Expand Down Expand Up @@ -741,6 +756,17 @@ pub mod test_util {
})
}

/// Selects the input share AAD format.
pub fn with_input_share_aad_public_share_length_prefix(
self,
input_share_aad_public_share_length_prefix: bool,
) -> Self {
Self(Task {
input_share_aad_public_share_length_prefix,
..self.0
})
}

/// Consumes this task builder & produces a [`Task`] with the given specifications.
pub fn build(self) -> Task {
self.0.validate().unwrap();
Expand Down Expand Up @@ -772,14 +798,15 @@ mod tests {

#[test]
fn task_serialization() {
roundtrip_encoding(
TaskBuilder::new(
QueryType::TimeInterval,
VdafInstance::Prio3Aes128Count,
Role::Leader,
)
.build(),
);
let mut task = TaskBuilder::new(
QueryType::TimeInterval,
VdafInstance::Prio3Aes128Count,
Role::Leader,
)
.build();
roundtrip_encoding(task.clone());
task.input_share_aad_public_share_length_prefix = true;
roundtrip_encoding(task);
}

#[test]
Expand All @@ -804,6 +831,7 @@ mod tests {
Vec::from([generate_auth_token()]),
Vec::new(),
Vec::from([generate_test_hpke_config_and_private_key()]),
false,
)
.unwrap_err();

Expand All @@ -827,6 +855,7 @@ mod tests {
Vec::from([generate_auth_token()]),
Vec::from([generate_auth_token()]),
Vec::from([generate_test_hpke_config_and_private_key()]),
false,
)
.unwrap();

Expand All @@ -850,6 +879,7 @@ mod tests {
Vec::from([generate_auth_token()]),
Vec::new(),
Vec::from([generate_test_hpke_config_and_private_key()]),
false,
)
.unwrap();

Expand All @@ -873,6 +903,7 @@ mod tests {
Vec::from([generate_auth_token()]),
Vec::from([generate_auth_token()]),
Vec::from([generate_test_hpke_config_and_private_key()]),
false,
)
.unwrap_err();
}
Expand All @@ -898,6 +929,7 @@ mod tests {
Vec::from([generate_auth_token()]),
Vec::from([generate_auth_token()]),
Vec::from([generate_test_hpke_config_and_private_key()]),
false,
)
.unwrap();

Expand Down
18 changes: 16 additions & 2 deletions client/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,15 +65,24 @@ pub struct ClientParameters {
time_precision: Duration,
/// Parameters to use when retrying HTTP requests.
http_request_retry_parameters: ExponentialBackoff,
/// Configuration setting to add an additional length prefix to the input share AAD, before
/// the public share.
input_share_aad_public_share_length_prefix: bool,
}

impl ClientParameters {
/// Creates a new set of client task parameters.
pub fn new(task_id: TaskId, aggregator_endpoints: Vec<Url>, time_precision: Duration) -> Self {
pub fn new(
task_id: TaskId,
aggregator_endpoints: Vec<Url>,
time_precision: Duration,
input_share_aad_public_share_length_prefix: bool,
) -> Self {
Self::new_with_backoff(
task_id,
aggregator_endpoints,
time_precision,
input_share_aad_public_share_length_prefix,
http_request_exponential_backoff(),
)
}
Expand All @@ -83,6 +92,7 @@ impl ClientParameters {
task_id: TaskId,
mut aggregator_endpoints: Vec<Url>,
time_precision: Duration,
input_share_aad_public_share_length_prefix: bool,
http_request_retry_parameters: ExponentialBackoff,
) -> Self {
// Ensure provided aggregator endpoints end with a slash, as we will be joining additional
Expand All @@ -97,6 +107,7 @@ impl ClientParameters {
aggregator_endpoints,
time_precision,
http_request_retry_parameters,
input_share_aad_public_share_length_prefix,
}
}

Expand Down Expand Up @@ -222,6 +233,7 @@ where
&self.parameters.task_id,
&report_metadata,
&public_share,
self.parameters.input_share_aad_public_share_length_prefix,
);

let encrypted_input_shares: Vec<HpkeCiphertext> = [
Expand Down Expand Up @@ -305,6 +317,7 @@ mod tests {
random(),
Vec::from([server_url.clone(), server_url]),
Duration::from_seconds(1),
false,
test_http_request_exponential_backoff(),
),
vdaf_client,
Expand All @@ -324,6 +337,7 @@ mod tests {
"http://helper_endpoint".parse().unwrap(),
]),
Duration::from_seconds(1),
false,
);

assert_eq!(
Expand Down Expand Up @@ -424,7 +438,7 @@ mod tests {
install_test_trace_subscriber();

let client_parameters =
ClientParameters::new(random(), Vec::new(), Duration::from_seconds(0));
ClientParameters::new(random(), Vec::new(), Duration::from_seconds(0), false);
let client = Client::new(
client_parameters,
Prio3::new_aes128_count(2).unwrap(),
Expand Down
Loading

0 comments on commit bd9a09f

Please sign in to comment.