From 7231945f628f2ade6a53f8969fc42f9a17cef804 Mon Sep 17 00:00:00 2001 From: Ameer Ghani Date: Tue, 16 Jan 2024 14:34:18 -0500 Subject: [PATCH] Include batch details in batch mismatch error --- aggregator/src/aggregator/http_handlers.rs | 144 +++++++++++++-------- 1 file changed, 87 insertions(+), 57 deletions(-) diff --git a/aggregator/src/aggregator/http_handlers.rs b/aggregator/src/aggregator/http_handlers.rs index a6bfe22c0e..feacd5e924 100644 --- a/aggregator/src/aggregator/http_handlers.rs +++ b/aggregator/src/aggregator/http_handlers.rs @@ -107,7 +107,8 @@ impl Handler for Error { ), Error::BatchMismatch(inner) => conn.with_problem_document( &ProblemDocument::new_dap(DapProblemType::BatchMismatch) - .with_task_id(&inner.task_id), + .with_task_id(&inner.task_id) + .with_detail(&inner.to_string()), ), Error::BatchQueriedTooManyTimes(task_id, _) => conn.with_problem_document( &ProblemDocument::new_dap(DapProblemType::BatchQueriedTooManyTimes) @@ -682,7 +683,7 @@ mod tests { }, collection_job_tests::setup_collection_job_test_case, empty_batch_aggregations, - error::ReportRejectedReason, + error::{BatchMismatch, ReportRejectedReason}, http_handlers::{ aggregator_handler, aggregator_handler_with_aggregator, test_util::{decode_response_body, take_problem_details}, @@ -5090,6 +5091,25 @@ mod tests { ); // Put some batch aggregations in the DB. + let interval_1 = + Interval::new(Time::from_seconds_since_epoch(500), *task.time_precision()).unwrap(); + let interval_1_report_count = 5; + let interval_1_checksum = ReportIdChecksum::get_decoded(&[3; 32]).unwrap(); + + let interval_2 = + Interval::new(Time::from_seconds_since_epoch(1500), *task.time_precision()).unwrap(); + let interval_2_report_count = 5; + let interval_2_checksum = ReportIdChecksum::get_decoded(&[2; 32]).unwrap(); + + let interval_3 = + Interval::new(Time::from_seconds_since_epoch(2000), *task.time_precision()).unwrap(); + let interval_3_report_count = 5; + let interval_3_checksum = ReportIdChecksum::get_decoded(&[4; 32]).unwrap(); + + let interval_4 = + Interval::new(Time::from_seconds_since_epoch(2500), *task.time_precision()).unwrap(); + let interval_4_report_count = 5; + let interval_4_checksum = ReportIdChecksum::get_decoded(&[8; 32]).unwrap(); datastore .run_unnamed_tx(|tx| { let task = helper_task.clone(); @@ -5098,11 +5118,6 @@ mod tests { dummy_vdaf::AggregationParam(0), dummy_vdaf::AggregationParam(1), ] { - let interval_1 = Interval::new( - Time::from_seconds_since_epoch(500), - *task.time_precision(), - ) - .unwrap(); tx.put_batch(&Batch::<0, TimeInterval, dummy_vdaf::Vdaf>::new( *task.id(), interval_1, @@ -5124,18 +5139,13 @@ mod tests { 0, BatchAggregationState::Aggregating, Some(dummy_vdaf::AggregateShare(64)), - 5, + interval_1_report_count, interval_1, - ReportIdChecksum::get_decoded(&[3; 32]).unwrap(), + interval_1_checksum, )) .await .unwrap(); - let interval_2 = Interval::new( - Time::from_seconds_since_epoch(1500), - *task.time_precision(), - ) - .unwrap(); tx.put_batch(&Batch::<0, TimeInterval, dummy_vdaf::Vdaf>::new( *task.id(), interval_2, @@ -5157,18 +5167,13 @@ mod tests { 0, BatchAggregationState::Aggregating, Some(dummy_vdaf::AggregateShare(128)), - 5, + interval_2_report_count, interval_2, - ReportIdChecksum::get_decoded(&[2; 32]).unwrap(), + interval_2_checksum, )) .await .unwrap(); - let interval_3 = Interval::new( - Time::from_seconds_since_epoch(2000), - *task.time_precision(), - ) - .unwrap(); tx.put_batch(&Batch::<0, TimeInterval, dummy_vdaf::Vdaf>::new( *task.id(), interval_3, @@ -5190,18 +5195,13 @@ mod tests { 0, BatchAggregationState::Aggregating, Some(dummy_vdaf::AggregateShare(256)), - 5, + interval_3_report_count, interval_3, - ReportIdChecksum::get_decoded(&[4; 32]).unwrap(), + interval_3_checksum, )) .await .unwrap(); - let interval_4 = Interval::new( - Time::from_seconds_since_epoch(2500), - *task.time_precision(), - ) - .unwrap(); tx.put_batch(&Batch::<0, TimeInterval, dummy_vdaf::Vdaf>::new( *task.id(), interval_4, @@ -5223,9 +5223,9 @@ mod tests { 0, BatchAggregationState::Aggregating, Some(dummy_vdaf::AggregateShare(512)), - 5, + interval_4_report_count, interval_4, - ReportIdChecksum::get_decoded(&[8; 32]).unwrap(), + interval_4_checksum, )) .await .unwrap(); @@ -5273,33 +5273,47 @@ mod tests { ); // Make requests that will fail because the checksum or report counts don't match. + struct MisalignedRequestTestCase { + name: &'static str, + request: AggregateShareReq, + expected_checksum: ReportIdChecksum, + expected_report_count: u64, + } for misaligned_request in [ - // Interval is big enough, but checksum doesn't match. - AggregateShareReq::new( - BatchSelector::new_time_interval( - Interval::new( - Time::from_seconds_since_epoch(0), - Duration::from_seconds(2000), - ) - .unwrap(), + MisalignedRequestTestCase { + name: "Interval is big enough but the checksums don't match", + request: AggregateShareReq::new( + BatchSelector::new_time_interval( + Interval::new( + Time::from_seconds_since_epoch(0), + Duration::from_seconds(2000), + ) + .unwrap(), + ), + dummy_vdaf::AggregationParam(0).get_encoded(), + 10, + ReportIdChecksum::get_decoded(&[3; 32]).unwrap(), ), - dummy_vdaf::AggregationParam(0).get_encoded(), - 10, - ReportIdChecksum::get_decoded(&[3; 32]).unwrap(), - ), - // Interval is big enough, but report count doesn't match. - AggregateShareReq::new( - BatchSelector::new_time_interval( - Interval::new( - Time::from_seconds_since_epoch(2000), - Duration::from_seconds(2000), - ) - .unwrap(), + expected_checksum: interval_1_checksum.combined_with(&interval_2_checksum), + expected_report_count: interval_1_report_count + interval_2_report_count, + }, + MisalignedRequestTestCase { + name: "Interval is big enough but report count doesn't match", + request: AggregateShareReq::new( + BatchSelector::new_time_interval( + Interval::new( + Time::from_seconds_since_epoch(2000), + Duration::from_seconds(2000), + ) + .unwrap(), + ), + dummy_vdaf::AggregationParam(0).get_encoded(), + 20, + ReportIdChecksum::get_decoded(&[4 ^ 8; 32]).unwrap(), ), - dummy_vdaf::AggregationParam(0).get_encoded(), - 20, - ReportIdChecksum::get_decoded(&[4 ^ 8; 32]).unwrap(), - ), + expected_checksum: interval_3_checksum.combined_with(&interval_4_checksum), + expected_report_count: interval_3_report_count + interval_4_report_count, + }, ] { let (header, value) = task.aggregator_auth_token().request_authentication(); let mut test_conn = post(task.aggregate_shares_uri().unwrap().path()) @@ -5308,11 +5322,24 @@ mod tests { KnownHeaderName::ContentType, AggregateShareReq::::MEDIA_TYPE, ) - .with_request_body(misaligned_request.get_encoded()) + .with_request_body(misaligned_request.request.get_encoded()) .run_async(&handler) .await; - assert_eq!(test_conn.status(), Some(Status::BadRequest)); + assert_eq!( + test_conn.status(), + Some(Status::BadRequest), + "{}", + misaligned_request.name + ); + + let expected_error = BatchMismatch { + task_id: *task.id(), + own_checksum: misaligned_request.expected_checksum, + own_report_count: misaligned_request.expected_report_count, + peer_checksum: *misaligned_request.request.checksum(), + peer_report_count: misaligned_request.request.report_count(), + }; assert_eq!( take_problem_details(&mut test_conn).await, json!({ @@ -5320,7 +5347,10 @@ mod tests { "type": "urn:ietf:params:ppm:dap:error:batchMismatch", "title": "Leader and helper disagree on reports aggregated in a batch.", "taskid": format!("{}", task.id()), - }) + "detail": expected_error.to_string(), + }), + "{}", + misaligned_request.name, ); }