Add aggregation parameter to AggregateShareAad
Relevant to #1669
tgeoghegan committed Sep 14, 2023
1 parent 475edb3 commit 6e312e1
Showing 7 changed files with 78 additions and 16 deletions.
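In short, every call site that encrypts or decrypts an aggregate share now passes the encoded aggregation parameter into the AAD. Below is a minimal sketch of the new construction with placeholder values; the janus_messages and prio::codec import paths are assumptions, while the three-argument constructor signature comes from messages/src/lib.rs in this diff.

use janus_messages::{query_type::FixedSize, AggregateShareAad, BatchId, BatchSelector, TaskId};
use prio::codec::Encode;

fn main() {
    // Placeholder inputs; in the aggregator these come from the task, the
    // collection or aggregate-share job, and the request being served.
    let task_id = TaskId::from([12u8; 32]);
    let batch_id = BatchId::from([7u8; 32]);
    let encoded_agg_param = Vec::from([0, 1, 2, 3]); // e.g. AggregationParam::get_encoded()

    // The AAD now binds the encoded aggregation parameter alongside the task ID
    // and batch selector; the middle argument is the field added by this commit.
    let aad = AggregateShareAad::<FixedSize>::new(
        task_id,
        encoded_agg_param,
        BatchSelector::new_fixed_size(batch_id),
    );
    let _aad_bytes = aad.get_encoded(); // used as the HPKE AAD when sealing/opening shares
}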
9 changes: 7 additions & 2 deletions aggregator/src/aggregator.rs
@@ -2624,6 +2624,7 @@ impl VdafOps {
&leader_aggregate_share.get_encoded(),
&AggregateShareAad::new(
*collection_job.task_id(),
collection_job.aggregation_parameter().get_encoded(),
BatchSelector::<Q>::new(collection_job.batch_identifier().clone()),
)
.get_encoded(),
@@ -2965,8 +2966,12 @@ impl VdafOps {
collector_hpke_config,
&HpkeApplicationInfo::new(&Label::AggregateShare, &Role::Helper, &Role::Collector),
&aggregate_share_job.helper_aggregate_share().get_encoded(),
&AggregateShareAad::new(*task.id(), aggregate_share_req.batch_selector().clone())
.get_encoded(),
&AggregateShareAad::new(
*task.id(),
aggregate_share_job.aggregation_parameter().get_encoded(),
aggregate_share_req.batch_selector().clone(),
)
.get_encoded(),
)?;

Ok(AggregateShare::new(encrypted_aggregate_share))
6 changes: 5 additions & 1 deletion aggregator/src/aggregator/collection_job_tests.rs
@@ -278,9 +278,10 @@ async fn collection_job_success_fixed_size() {
let vdaf = dummy_vdaf::Vdaf::new();
let leader_aggregate_share = dummy_vdaf::AggregateShare(0);
let helper_aggregate_share = dummy_vdaf::AggregateShare(1);
let aggregation_param = dummy_vdaf::AggregationParam::default();
let request = CollectionReq::new(
Query::new_fixed_size(FixedSizeQuery::CurrentBatch),
AggregationParam::default().get_encoded(),
aggregation_param.get_encoded(),
);

for _ in 0..2 {
@@ -323,6 +324,7 @@ async fn collection_job_success_fixed_size() {
&helper_aggregate_share_bytes,
&AggregateShareAad::new(
*task.id(),
aggregation_param.get_encoded(),
BatchSelector::new_fixed_size(batch_id),
)
.get_encoded(),
@@ -374,6 +376,7 @@ async fn collection_job_success_fixed_size() {
collect_resp.leader_encrypted_aggregate_share(),
&AggregateShareAad::new(
*test_case.task.id(),
aggregation_param.get_encoded(),
BatchSelector::new_fixed_size(batch_id),
)
.get_encoded(),
@@ -392,6 +395,7 @@ async fn collection_job_success_fixed_size() {
collect_resp.helper_encrypted_aggregate_share(),
&AggregateShareAad::new(
*test_case.task.id(),
aggregation_param.get_encoded(),
BatchSelector::new_fixed_size(batch_id),
)
.get_encoded(),
11 changes: 9 additions & 2 deletions aggregator/src/aggregator/http_handlers.rs
@@ -4443,6 +4443,7 @@ mod tests {
&helper_aggregate_share_bytes,
&AggregateShareAad::new(
*task.id(),
aggregation_param.get_encoded(),
BatchSelector::new_time_interval(batch_interval),
)
.get_encoded(),
@@ -4492,6 +4493,7 @@ mod tests {
collect_resp.leader_encrypted_aggregate_share(),
&AggregateShareAad::new(
*test_case.task.id(),
aggregation_param.get_encoded(),
BatchSelector::new_time_interval(batch_interval),
)
.get_encoded(),
@@ -4510,6 +4512,7 @@ mod tests {
collect_resp.helper_encrypted_aggregate_share(),
&AggregateShareAad::new(
*test_case.task.id(),
aggregation_param.get_encoded(),
BatchSelector::new_time_interval(batch_interval),
)
.get_encoded(),
@@ -5235,8 +5238,12 @@ mod tests {
&Role::Collector,
),
aggregate_share_resp.encrypted_aggregate_share(),
&AggregateShareAad::new(*task.id(), request.batch_selector().clone())
.get_encoded(),
&AggregateShareAad::new(
*task.id(),
dummy_vdaf::AggregationParam(0).get_encoded(),
request.batch_selector().clone(),
)
.get_encoded(),
)
.unwrap();

8 changes: 7 additions & 1 deletion aggregator/src/aggregator/taskprov_tests.rs
@@ -980,7 +980,12 @@ async fn taskprov_aggregate_share() {
test.collector_hpke_keypair.private_key(),
&HpkeApplicationInfo::new(&Label::AggregateShare, &Role::Helper, &Role::Collector),
aggregate_share_resp.encrypted_aggregate_share(),
&AggregateShareAad::new(test.task_id, request.batch_selector().clone()).get_encoded(),
&AggregateShareAad::new(
test.task_id,
test.aggregation_param.get_encoded(),
request.batch_selector().clone(),
)
.get_encoded(),
)
.unwrap();
}
@@ -1114,6 +1119,7 @@ async fn end_to_end() {
aggregate_share_resp.encrypted_aggregate_share(),
&AggregateShareAad::new(
test.task_id,
test.aggregation_param.get_encoded(),
aggregate_share_request.batch_selector().clone(),
)
.get_encoded(),
1 change: 1 addition & 0 deletions aggregator_core/src/datastore/tests.rs
@@ -2580,6 +2580,7 @@ async fn get_collection_job(ephemeral_datastore: EphemeralDatastore) {
&[0, 1, 2, 3, 4, 5],
&AggregateShareAad::new(
*task.id(),
().get_encoded(),
BatchSelector::new_time_interval(first_batch_interval),
)
.get_encoded(),
24 changes: 16 additions & 8 deletions collector/src/lib.rs
@@ -516,6 +516,7 @@ impl<V: vdaf::Collector> Collector<V> {
encrypted_aggregate_share,
&AggregateShareAad::new(
self.parameters.task_id,
job.aggregation_parameter.get_encoded(),
BatchSelector::<Q>::new(Q::batch_identifier_for_collection(
&job.query,
&collect_response,
@@ -692,10 +693,12 @@ mod tests {
fn build_collect_response_time<const SEED_SIZE: usize, V: vdaf::Aggregator<SEED_SIZE, 16>>(
transcript: &VdafTranscript<SEED_SIZE, V>,
parameters: &CollectorParameters,
aggregation_parameter: &V::AggregationParam,
batch_interval: Interval,
) -> CollectionMessage<TimeInterval> {
let associated_data = AggregateShareAad::new(
parameters.task_id,
aggregation_parameter.get_encoded(),
BatchSelector::new_time_interval(batch_interval),
);
CollectionMessage::new(
@@ -722,10 +725,14 @@
fn build_collect_response_fixed<const SEED_SIZE: usize, V: vdaf::Aggregator<SEED_SIZE, 16>>(
transcript: &VdafTranscript<SEED_SIZE, V>,
parameters: &CollectorParameters,
aggregation_parameter: &V::AggregationParam,
batch_id: BatchId,
) -> CollectionMessage<FixedSize> {
let associated_data =
AggregateShareAad::new(parameters.task_id, BatchSelector::new_fixed_size(batch_id));
let associated_data = AggregateShareAad::new(
parameters.task_id,
aggregation_parameter.get_encoded(),
BatchSelector::new_fixed_size(batch_id),
);
CollectionMessage::new(
PartialBatchSelector::new_fixed_size(batch_id),
1,
@@ -793,7 +800,7 @@ mod tests {
)
.unwrap();
let collect_resp =
build_collect_response_time(&transcript, &collector.parameters, batch_interval);
build_collect_response_time(&transcript, &collector.parameters, &(), batch_interval);
let matcher = collection_uri_regex_matcher(&collector.parameters.task_id);

let mocked_collect_start_error = server
@@ -895,7 +902,7 @@ mod tests {
)
.unwrap();
let collect_resp =
build_collect_response_time(&transcript, &collector.parameters, batch_interval);
build_collect_response_time(&transcript, &collector.parameters, &(), batch_interval);
let matcher = collection_uri_regex_matcher(&collector.parameters.task_id);

let mocked_collect_start_success = server
@@ -965,7 +972,7 @@ mod tests {
)
.unwrap();
let collect_resp =
build_collect_response_time(&transcript, &collector.parameters, batch_interval);
build_collect_response_time(&transcript, &collector.parameters, &(), batch_interval);
let matcher = collection_uri_regex_matcher(&collector.parameters.task_id);

let mocked_collect_start_success = server
@@ -1045,7 +1052,7 @@ mod tests {
)
.unwrap();
let collect_resp =
build_collect_response_time(&transcript, &collector.parameters, batch_interval);
build_collect_response_time(&transcript, &collector.parameters, &(), batch_interval);
let matcher = collection_uri_regex_matcher(&collector.parameters.task_id);

let mocked_collect_start_success = server
@@ -1112,7 +1119,7 @@ mod tests {

let batch_id = random();
let collect_resp =
build_collect_response_fixed(&transcript, &collector.parameters, batch_id);
build_collect_response_fixed(&transcript, &collector.parameters, &(), batch_id);
let matcher = collection_uri_regex_matcher(&collector.parameters.task_id);

let mocked_collect_start_success = server
@@ -1198,7 +1205,7 @@ mod tests {
)
.unwrap();
let collect_resp =
build_collect_response_time(&transcript, &collector.parameters, batch_interval);
build_collect_response_time(&transcript, &collector.parameters, &(), batch_interval);
let matcher = collection_uri_regex_matcher(&collector.parameters.task_id);

let mocked_collect_start_success = server
@@ -1491,6 +1498,7 @@ mod tests {

let associated_data = AggregateShareAad::new(
collector.parameters.task_id,
().get_encoded(),
BatchSelector::new_time_interval(batch_interval),
);
let collect_resp = CollectionMessage::new(
35 changes: 33 additions & 2 deletions messages/src/lib.rs
@@ -1845,14 +1845,20 @@ impl Decode for InputShareAad {
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct AggregateShareAad<Q: QueryType> {
task_id: TaskId,
aggregation_parameter: Vec<u8>,
batch_selector: BatchSelector<Q>,
}

impl<Q: QueryType> AggregateShareAad<Q> {
/// Constructs a new aggregate share AAD.
pub fn new(task_id: TaskId, batch_selector: BatchSelector<Q>) -> Self {
pub fn new(
task_id: TaskId,
aggregation_parameter: Vec<u8>,
batch_selector: BatchSelector<Q>,
) -> Self {
Self {
task_id,
aggregation_parameter,
batch_selector,
}
}
@@ -1862,6 +1868,11 @@ impl<Q: QueryType> AggregateShareAad<Q> {
&self.task_id
}

/// Retrieves the aggregation parameter associated with this aggregate share AAD.
pub fn aggregation_parameter(&self) -> &[u8] {
&self.aggregation_parameter
}

/// Retrieves the batch selector associated with this aggregate share AAD.
pub fn batch_selector(&self) -> &BatchSelector<Q> {
&self.batch_selector
@@ -1871,21 +1882,29 @@ impl<Q: QueryType> AggregateShareAad<Q> {
impl<Q: QueryType> Encode for AggregateShareAad<Q> {
fn encode(&self, bytes: &mut Vec<u8>) {
self.task_id.encode(bytes);
encode_u32_items(bytes, &(), &self.aggregation_parameter);
self.batch_selector.encode(bytes);
}

fn encoded_len(&self) -> Option<usize> {
Some(self.task_id.encoded_len()? + self.batch_selector.encoded_len()?)
Some(
self.task_id.encoded_len()?
+ 4
+ self.aggregation_parameter.len()
+ self.batch_selector.encoded_len()?,
)
}
}

impl<Q: QueryType> Decode for AggregateShareAad<Q> {
fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> {
let task_id = TaskId::decode(bytes)?;
let aggregation_parameter = decode_u32_items(&(), bytes)?;
let batch_selector = BatchSelector::decode(bytes)?;

Ok(Self {
task_id,
aggregation_parameter,
batch_selector,
})
}
@@ -4916,6 +4935,7 @@ mod tests {
roundtrip_encoding(&[(
AggregateShareAad::<TimeInterval> {
task_id: TaskId::from([12u8; 32]),
aggregation_parameter: Vec::from([0, 1, 2, 3]),
batch_selector: BatchSelector {
batch_identifier: Interval::new(
Time::from_seconds_since_epoch(54321),
Expand All @@ -4926,6 +4946,11 @@ mod tests {
},
concat!(
"0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C", // task_id
concat!(
// aggregation_parameter
"00000004", // length
"00010203", //opaque data
),
concat!(
// batch_selector
"01", // query_type
Expand All @@ -4942,12 +4967,18 @@ mod tests {
roundtrip_encoding(&[(
AggregateShareAad::<FixedSize> {
task_id: TaskId::from([u8::MIN; 32]),
aggregation_parameter: Vec::from([3, 2, 1, 0]),
batch_selector: BatchSelector {
batch_identifier: BatchId::from([7u8; 32]),
},
},
concat!(
"0000000000000000000000000000000000000000000000000000000000000000", // task_id
concat!(
// aggregation_parameter
"00000004", // length
"03020100", //opaque data
),
concat!(
// batch_selector
"02", // query_type
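For reference, here is a rough sketch of the resulting wire layout and a decode round-trip, inferred from the Encode/Decode changes and the test vectors above; the module paths and the get_decoded convenience are assumptions about the janus_messages and prio::codec APIs.

use janus_messages::{query_type::FixedSize, AggregateShareAad, BatchId, BatchSelector, TaskId};
use prio::codec::{Decode, Encode};

fn main() {
    let aad = AggregateShareAad::<FixedSize>::new(
        TaskId::from([0u8; 32]),
        Vec::from([3, 2, 1, 0]), // opaque encoded aggregation parameter
        BatchSelector::new_fixed_size(BatchId::from([7u8; 32])),
    );
    let bytes = aad.get_encoded();

    // Layout: 32-byte task_id, then a u32 length prefix plus the opaque
    // aggregation parameter bytes, then the encoded batch selector.
    assert_eq!(bytes[32..36], [0, 0, 0, 4]);
    assert_eq!(bytes[36..40], [3, 2, 1, 0]);
    assert_eq!(AggregateShareAad::<FixedSize>::get_decoded(&bytes).unwrap(), aad);
}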
