Skip to content

Commit

Permalink
Include batch details in batch mismatch error (#2483)
Browse files Browse the repository at this point in the history
  • Loading branch information
inahga authored Jan 23, 2024
1 parent 050df75 commit 2b64b2a
Showing 1 changed file with 87 additions and 57 deletions.
144 changes: 87 additions & 57 deletions aggregator/src/aggregator/http_handlers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,8 @@ impl Handler for Error {
),
Error::BatchMismatch(inner) => conn.with_problem_document(
&ProblemDocument::new_dap(DapProblemType::BatchMismatch)
.with_task_id(&inner.task_id),
.with_task_id(&inner.task_id)
.with_detail(&inner.to_string()),
),
Error::BatchQueriedTooManyTimes(task_id, _) => conn.with_problem_document(
&ProblemDocument::new_dap(DapProblemType::BatchQueriedTooManyTimes)
Expand Down Expand Up @@ -687,7 +688,7 @@ mod tests {
},
collection_job_tests::setup_collection_job_test_case,
empty_batch_aggregations,
error::ReportRejectedReason,
error::{BatchMismatch, ReportRejectedReason},
http_handlers::{
aggregator_handler, aggregator_handler_with_aggregator,
test_util::{decode_response_body, take_problem_details},
Expand Down Expand Up @@ -5127,6 +5128,25 @@ mod tests {
);

// Put some batch aggregations in the DB.
let interval_1 =
Interval::new(Time::from_seconds_since_epoch(500), *task.time_precision()).unwrap();
let interval_1_report_count = 5;
let interval_1_checksum = ReportIdChecksum::get_decoded(&[3; 32]).unwrap();

let interval_2 =
Interval::new(Time::from_seconds_since_epoch(1500), *task.time_precision()).unwrap();
let interval_2_report_count = 5;
let interval_2_checksum = ReportIdChecksum::get_decoded(&[2; 32]).unwrap();

let interval_3 =
Interval::new(Time::from_seconds_since_epoch(2000), *task.time_precision()).unwrap();
let interval_3_report_count = 5;
let interval_3_checksum = ReportIdChecksum::get_decoded(&[4; 32]).unwrap();

let interval_4 =
Interval::new(Time::from_seconds_since_epoch(2500), *task.time_precision()).unwrap();
let interval_4_report_count = 5;
let interval_4_checksum = ReportIdChecksum::get_decoded(&[8; 32]).unwrap();
datastore
.run_unnamed_tx(|tx| {
let task = helper_task.clone();
Expand All @@ -5135,11 +5155,6 @@ mod tests {
dummy_vdaf::AggregationParam(0),
dummy_vdaf::AggregationParam(1),
] {
let interval_1 = Interval::new(
Time::from_seconds_since_epoch(500),
*task.time_precision(),
)
.unwrap();
tx.put_batch(&Batch::<0, TimeInterval, dummy_vdaf::Vdaf>::new(
*task.id(),
interval_1,
Expand All @@ -5161,18 +5176,13 @@ mod tests {
0,
BatchAggregationState::Aggregating,
Some(dummy_vdaf::AggregateShare(64)),
5,
interval_1_report_count,
interval_1,
ReportIdChecksum::get_decoded(&[3; 32]).unwrap(),
interval_1_checksum,
))
.await
.unwrap();

let interval_2 = Interval::new(
Time::from_seconds_since_epoch(1500),
*task.time_precision(),
)
.unwrap();
tx.put_batch(&Batch::<0, TimeInterval, dummy_vdaf::Vdaf>::new(
*task.id(),
interval_2,
Expand All @@ -5194,18 +5204,13 @@ mod tests {
0,
BatchAggregationState::Aggregating,
Some(dummy_vdaf::AggregateShare(128)),
5,
interval_2_report_count,
interval_2,
ReportIdChecksum::get_decoded(&[2; 32]).unwrap(),
interval_2_checksum,
))
.await
.unwrap();

let interval_3 = Interval::new(
Time::from_seconds_since_epoch(2000),
*task.time_precision(),
)
.unwrap();
tx.put_batch(&Batch::<0, TimeInterval, dummy_vdaf::Vdaf>::new(
*task.id(),
interval_3,
Expand All @@ -5227,18 +5232,13 @@ mod tests {
0,
BatchAggregationState::Aggregating,
Some(dummy_vdaf::AggregateShare(256)),
5,
interval_3_report_count,
interval_3,
ReportIdChecksum::get_decoded(&[4; 32]).unwrap(),
interval_3_checksum,
))
.await
.unwrap();

let interval_4 = Interval::new(
Time::from_seconds_since_epoch(2500),
*task.time_precision(),
)
.unwrap();
tx.put_batch(&Batch::<0, TimeInterval, dummy_vdaf::Vdaf>::new(
*task.id(),
interval_4,
Expand All @@ -5260,9 +5260,9 @@ mod tests {
0,
BatchAggregationState::Aggregating,
Some(dummy_vdaf::AggregateShare(512)),
5,
interval_4_report_count,
interval_4,
ReportIdChecksum::get_decoded(&[8; 32]).unwrap(),
interval_4_checksum,
))
.await
.unwrap();
Expand Down Expand Up @@ -5310,33 +5310,47 @@ mod tests {
);

// Make requests that will fail because the checksum or report counts don't match.
struct MisalignedRequestTestCase<Q: janus_messages::query_type::QueryType> {
name: &'static str,
request: AggregateShareReq<Q>,
expected_checksum: ReportIdChecksum,
expected_report_count: u64,
}
for misaligned_request in [
// Interval is big enough, but checksum doesn't match.
AggregateShareReq::new(
BatchSelector::new_time_interval(
Interval::new(
Time::from_seconds_since_epoch(0),
Duration::from_seconds(2000),
)
.unwrap(),
MisalignedRequestTestCase {
name: "Interval is big enough but the checksums don't match",
request: AggregateShareReq::new(
BatchSelector::new_time_interval(
Interval::new(
Time::from_seconds_since_epoch(0),
Duration::from_seconds(2000),
)
.unwrap(),
),
dummy_vdaf::AggregationParam(0).get_encoded().unwrap(),
10,
ReportIdChecksum::get_decoded(&[3; 32]).unwrap(),
),
dummy_vdaf::AggregationParam(0).get_encoded().unwrap(),
10,
ReportIdChecksum::get_decoded(&[3; 32]).unwrap(),
),
// Interval is big enough, but report count doesn't match.
AggregateShareReq::new(
BatchSelector::new_time_interval(
Interval::new(
Time::from_seconds_since_epoch(2000),
Duration::from_seconds(2000),
)
.unwrap(),
expected_checksum: interval_1_checksum.combined_with(&interval_2_checksum),
expected_report_count: interval_1_report_count + interval_2_report_count,
},
MisalignedRequestTestCase {
name: "Interval is big enough but report count doesn't match",
request: AggregateShareReq::new(
BatchSelector::new_time_interval(
Interval::new(
Time::from_seconds_since_epoch(2000),
Duration::from_seconds(2000),
)
.unwrap(),
),
dummy_vdaf::AggregationParam(0).get_encoded().unwrap(),
20,
ReportIdChecksum::get_decoded(&[4 ^ 8; 32]).unwrap(),
),
dummy_vdaf::AggregationParam(0).get_encoded().unwrap(),
20,
ReportIdChecksum::get_decoded(&[4 ^ 8; 32]).unwrap(),
),
expected_checksum: interval_3_checksum.combined_with(&interval_4_checksum),
expected_report_count: interval_3_report_count + interval_4_report_count,
},
] {
let (header, value) = task.aggregator_auth_token().request_authentication();
let mut test_conn = post(task.aggregate_shares_uri().unwrap().path())
Expand All @@ -5345,19 +5359,35 @@ mod tests {
KnownHeaderName::ContentType,
AggregateShareReq::<TimeInterval>::MEDIA_TYPE,
)
.with_request_body(misaligned_request.get_encoded().unwrap())
.with_request_body(misaligned_request.request.get_encoded().unwrap())
.run_async(&handler)
.await;

assert_eq!(test_conn.status(), Some(Status::BadRequest));
assert_eq!(
test_conn.status(),
Some(Status::BadRequest),
"{}",
misaligned_request.name
);

let expected_error = BatchMismatch {
task_id: *task.id(),
own_checksum: misaligned_request.expected_checksum,
own_report_count: misaligned_request.expected_report_count,
peer_checksum: *misaligned_request.request.checksum(),
peer_report_count: misaligned_request.request.report_count(),
};
assert_eq!(
take_problem_details(&mut test_conn).await,
json!({
"status": Status::BadRequest as u16,
"type": "urn:ietf:params:ppm:dap:error:batchMismatch",
"title": "Leader and helper disagree on reports aggregated in a batch.",
"taskid": format!("{}", task.id()),
})
"detail": expected_error.to_string(),
}),
"{}",
misaligned_request.name,
);
}

Expand Down

0 comments on commit 2b64b2a

Please sign in to comment.