Include batch details in batch mismatch error

inahga committed Jan 12, 2024
1 parent 61d5efb commit ca17875
Showing 2 changed files with 124 additions and 63 deletions.
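For reference, the new fields surface both aggregators' views of the disputed batch in the DAP problem document: own_checksum and own_report_count are the responding aggregator's computed values, while peer_checksum and peer_report_count are the values claimed by its peer. A batchMismatch response after this change looks roughly like the following (identifier and checksum values are illustrative placeholders):

{
    "status": 400,
    "type": "urn:ietf:params:ppm:dap:error:batchMismatch",
    "title": "Leader and helper disagree on reports aggregated in a batch.",
    "taskid": "<task ID>",
    "own_checksum": "<responder's combined report ID checksum>",
    "own_report_count": 10,
    "peer_checksum": "<checksum claimed by the peer's request>",
    "peer_report_count": 20
}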
151 changes: 94 additions & 57 deletions aggregator/src/aggregator/http_handlers.rs
@@ -23,6 +23,7 @@ use prio::codec::Encode;
 use ring::digest::{digest, SHA256};
 use routefinder::Captures;
 use serde::Deserialize;
+use serde_json::json;
 use std::{borrow::Cow, time::Duration as StdDuration};
 use std::{io::Cursor, sync::Arc};
 use tracing::warn;
@@ -83,7 +84,9 @@ impl Handler for Error {
                     "job. The job is no longer collectable. Contact the server operators for ",
                     "assistance."
                 ))
-                .with_collection_job_id(collection_job_id),
+                .with_additional_fields(json!({
+                    "collection_job_id": collection_job_id.to_string(),
+                })),
             ),
             Error::UnrecognizedCollectionJob(_) => conn.with_status(Status::NotFound),
             Error::OutdatedHpkeConfig(task_id, _) => conn.with_problem_document(
@@ -107,7 +110,13 @@
             ),
             Error::BatchMismatch(inner) => conn.with_problem_document(
                 &ProblemDocument::new_dap(DapProblemType::BatchMismatch)
-                    .with_task_id(&inner.task_id),
+                    .with_task_id(&inner.task_id)
+                    .with_additional_fields(json!({
+                        "own_checksum": inner.own_checksum.to_string(),
+                        "own_report_count": inner.own_report_count,
+                        "peer_checksum": inner.peer_checksum.to_string(),
+                        "peer_report_count": inner.peer_report_count,
+                    })),
             ),
             Error::BatchQueriedTooManyTimes(task_id, _) => conn.with_problem_document(
                 &ProblemDocument::new_dap(DapProblemType::BatchQueriedTooManyTimes)
@@ -5090,6 +5099,25 @@ mod tests {
         );

         // Put some batch aggregations in the DB.
+        let interval_1 =
+            Interval::new(Time::from_seconds_since_epoch(500), *task.time_precision()).unwrap();
+        let interval_1_report_count = 5;
+        let interval_1_checksum = ReportIdChecksum::get_decoded(&[3; 32]).unwrap();
+
+        let interval_2 =
+            Interval::new(Time::from_seconds_since_epoch(1500), *task.time_precision()).unwrap();
+        let interval_2_report_count = 5;
+        let interval_2_checksum = ReportIdChecksum::get_decoded(&[2; 32]).unwrap();
+
+        let interval_3 =
+            Interval::new(Time::from_seconds_since_epoch(2000), *task.time_precision()).unwrap();
+        let interval_3_report_count = 5;
+        let interval_3_checksum = ReportIdChecksum::get_decoded(&[4; 32]).unwrap();
+
+        let interval_4 =
+            Interval::new(Time::from_seconds_since_epoch(2500), *task.time_precision()).unwrap();
+        let interval_4_report_count = 5;
+        let interval_4_checksum = ReportIdChecksum::get_decoded(&[8; 32]).unwrap();
         datastore
             .run_unnamed_tx(|tx| {
                 let task = helper_task.clone();
@@ -5098,11 +5126,6 @@
                     dummy_vdaf::AggregationParam(0),
                     dummy_vdaf::AggregationParam(1),
                 ] {
-                    let interval_1 = Interval::new(
-                        Time::from_seconds_since_epoch(500),
-                        *task.time_precision(),
-                    )
-                    .unwrap();
                     tx.put_batch(&Batch::<0, TimeInterval, dummy_vdaf::Vdaf>::new(
                         *task.id(),
                         interval_1,
@@ -5124,18 +5147,13 @@
                         0,
                         BatchAggregationState::Aggregating,
                         Some(dummy_vdaf::AggregateShare(64)),
-                        5,
+                        interval_1_report_count,
                         interval_1,
-                        ReportIdChecksum::get_decoded(&[3; 32]).unwrap(),
+                        interval_1_checksum,
                     ))
                     .await
                     .unwrap();

-                    let interval_2 = Interval::new(
-                        Time::from_seconds_since_epoch(1500),
-                        *task.time_precision(),
-                    )
-                    .unwrap();
                     tx.put_batch(&Batch::<0, TimeInterval, dummy_vdaf::Vdaf>::new(
                         *task.id(),
                         interval_2,
@@ -5157,18 +5175,13 @@
                         0,
                         BatchAggregationState::Aggregating,
                         Some(dummy_vdaf::AggregateShare(128)),
-                        5,
+                        interval_2_report_count,
                         interval_2,
-                        ReportIdChecksum::get_decoded(&[2; 32]).unwrap(),
+                        interval_2_checksum,
                     ))
                     .await
                     .unwrap();

-                    let interval_3 = Interval::new(
-                        Time::from_seconds_since_epoch(2000),
-                        *task.time_precision(),
-                    )
-                    .unwrap();
                     tx.put_batch(&Batch::<0, TimeInterval, dummy_vdaf::Vdaf>::new(
                         *task.id(),
                         interval_3,
@@ -5190,18 +5203,13 @@
                         0,
                         BatchAggregationState::Aggregating,
                         Some(dummy_vdaf::AggregateShare(256)),
-                        5,
+                        interval_3_report_count,
                         interval_3,
-                        ReportIdChecksum::get_decoded(&[4; 32]).unwrap(),
+                        interval_3_checksum,
                     ))
                     .await
                     .unwrap();

-                    let interval_4 = Interval::new(
-                        Time::from_seconds_since_epoch(2500),
-                        *task.time_precision(),
-                    )
-                    .unwrap();
                     tx.put_batch(&Batch::<0, TimeInterval, dummy_vdaf::Vdaf>::new(
                         *task.id(),
                         interval_4,
@@ -5223,9 +5231,9 @@
                         0,
                         BatchAggregationState::Aggregating,
                         Some(dummy_vdaf::AggregateShare(512)),
-                        5,
+                        interval_4_report_count,
                         interval_4,
-                        ReportIdChecksum::get_decoded(&[8; 32]).unwrap(),
+                        interval_4_checksum,
                     ))
                     .await
                     .unwrap();
@@ -5273,33 +5281,51 @@
         );

         // Make requests that will fail because the checksum or report counts don't match.
+        struct MisalignedRequestTestCase<Q: janus_messages::query_type::QueryType> {
+            name: &'static str,
+            request: AggregateShareReq<Q>,
+            expected_checksum: String,
+            expected_report_count: u64,
+        }
         for misaligned_request in [
-            // Interval is big enough, but checksum doesn't match.
-            AggregateShareReq::new(
-                BatchSelector::new_time_interval(
-                    Interval::new(
-                        Time::from_seconds_since_epoch(0),
-                        Duration::from_seconds(2000),
-                    )
-                    .unwrap(),
-                ),
-                dummy_vdaf::AggregationParam(0).get_encoded(),
-                10,
-                ReportIdChecksum::get_decoded(&[3; 32]).unwrap(),
-            ),
-            // Interval is big enough, but report count doesn't match.
-            AggregateShareReq::new(
-                BatchSelector::new_time_interval(
-                    Interval::new(
-                        Time::from_seconds_since_epoch(2000),
-                        Duration::from_seconds(2000),
-                    )
-                    .unwrap(),
-                ),
-                dummy_vdaf::AggregationParam(0).get_encoded(),
-                20,
-                ReportIdChecksum::get_decoded(&[4 ^ 8; 32]).unwrap(),
-            ),
+            MisalignedRequestTestCase {
+                name: "Interval is big enough but the checksums don't match",
+                request: AggregateShareReq::new(
+                    BatchSelector::new_time_interval(
+                        Interval::new(
+                            Time::from_seconds_since_epoch(0),
+                            Duration::from_seconds(2000),
+                        )
+                        .unwrap(),
+                    ),
+                    dummy_vdaf::AggregationParam(0).get_encoded(),
+                    10,
+                    ReportIdChecksum::get_decoded(&[3; 32]).unwrap(),
+                ),
+                expected_checksum: interval_1_checksum
+                    .combined_with(&interval_2_checksum)
+                    .to_string(),
+                expected_report_count: interval_1_report_count + interval_2_report_count,
+            },
+            MisalignedRequestTestCase {
+                name: "Interval is big enough but report count doesn't match",
+                request: AggregateShareReq::new(
+                    BatchSelector::new_time_interval(
+                        Interval::new(
+                            Time::from_seconds_since_epoch(2000),
+                            Duration::from_seconds(2000),
+                        )
+                        .unwrap(),
+                    ),
+                    dummy_vdaf::AggregationParam(0).get_encoded(),
+                    20,
+                    ReportIdChecksum::get_decoded(&[4 ^ 8; 32]).unwrap(),
+                ),
+                expected_checksum: interval_3_checksum
+                    .combined_with(&interval_4_checksum)
+                    .to_string(),
+                expected_report_count: interval_3_report_count + interval_4_report_count,
+            },
         ] {
             let (header, value) = task.aggregator_auth_token().request_authentication();
             let mut test_conn = post(task.aggregate_shares_uri().unwrap().path())
@@ -5308,19 +5334,30 @@
                     KnownHeaderName::ContentType,
                     AggregateShareReq::<TimeInterval>::MEDIA_TYPE,
                 )
-                .with_request_body(misaligned_request.get_encoded())
+                .with_request_body(misaligned_request.request.get_encoded())
                 .run_async(&handler)
                 .await;

-            assert_eq!(test_conn.status(), Some(Status::BadRequest));
+            assert_eq!(
+                test_conn.status(),
+                Some(Status::BadRequest),
+                "{}",
+                misaligned_request.name
+            );
             assert_eq!(
                 take_problem_details(&mut test_conn).await,
                 json!({
                     "status": Status::BadRequest as u16,
                     "type": "urn:ietf:params:ppm:dap:error:batchMismatch",
                     "title": "Leader and helper disagree on reports aggregated in a batch.",
                     "taskid": format!("{}", task.id()),
-                })
+                    "own_checksum": misaligned_request.expected_checksum,
+                    "own_report_count": misaligned_request.expected_report_count,
+                    "peer_checksum": misaligned_request.request.checksum().to_string(),
+                    "peer_report_count": misaligned_request.request.report_count(),
+                }),
+                "{}",
+                misaligned_request.name,
             );
         }

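Why the first mismatched request above fails: a sketch, assuming ReportIdChecksum::combined_with() combines the two 32-byte checksums by bytewise XOR (as the DAP draft specifies for batch checksums; this diff does not show that implementation):

// Hypothetical sketch, not part of this commit.
let own = ReportIdChecksum::get_decoded(&[3; 32]) // interval_1_checksum
    .unwrap()
    .combined_with(&ReportIdChecksum::get_decoded(&[2; 32]).unwrap()); // interval_2_checksum
// Bytewise, 0x03 XOR 0x02 == 0x01, so the helper's own checksum over the
// queried interval is [1; 32], while the request claims [3; 32]. The report
// counts agree (5 + 5 == 10), so only the checksums trigger the mismatch
// asserted in the test above.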
36 changes: 30 additions & 6 deletions aggregator/src/aggregator/problem_details.rs
@@ -1,5 +1,6 @@
-use janus_messages::{problem_type::DapProblemType, CollectionJobId, TaskId};
+use janus_messages::{problem_type::DapProblemType, TaskId};
 use serde::Serialize;
+use serde_json::Value;
 use trillium::{Conn, KnownHeaderName, Status};
 use trillium_api::ApiConnExt;

@@ -36,8 +37,8 @@ pub struct ProblemDocument<'a> {
     taskid: Option<String>,
     #[serde(skip_serializing_if = "Option::is_none")]
     detail: Option<&'a str>,
-    #[serde(skip_serializing_if = "Option::is_none")]
-    collection_job_id: Option<String>,
+    #[serde(skip_serializing_if = "Option::is_none", flatten)]
+    additional_fields: Option<Value>,
 }

 impl<'a> ProblemDocument<'a> {
@@ -54,7 +55,7 @@ impl<'a> ProblemDocument<'a> {
             status: status.into(),
             taskid: None,
             detail: None,
-            collection_job_id: None,
+            additional_fields: None,
         }
     }

@@ -81,9 +82,9 @@
         }
     }

-    pub fn with_collection_job_id(self, collection_job_id: &CollectionJobId) -> Self {
+    pub fn with_additional_fields(self, fields: Value) -> Self {
         Self {
-            collection_job_id: Some(collection_job_id.to_string()),
+            additional_fields: Some(fields),
             ..self
         }
     }
@@ -124,9 +125,13 @@ mod tests {
     use opentelemetry::metrics::Unit;
     use rand::random;
     use reqwest::Client;
+    use serde_json::{json, ser::to_string};
     use std::{borrow::Cow, sync::Arc};
     use trillium::Status;
     use trillium_testing::prelude::post;
+
+    use super::ProblemDocument;
+
     #[test]
     fn dap_problem_type_round_trip() {
         for problem_type in [
@@ -150,6 +155,25 @@
         assert_matches!("".parse::<DapProblemType>(), Err(DapProblemTypeParseError));
     }

+    #[test]
+    fn problem_details_additional_fields() {
+        // Without additional fields.
+        let problem_details = ProblemDocument::new("foo", "bar", Status::Ok);
+
+        assert_eq!(
+            to_string(&problem_details).unwrap(),
+            "{\"type\":\"foo\",\"title\":\"bar\",\"status\":200}"
+        );
+
+        // With additional_fields.
+        let problem_details =
+            problem_details.with_additional_fields(json!({"foo": "bar", "baz": 100}));
+        assert_eq!(
+            to_string(&problem_details).unwrap(),
+            "{\"type\":\"foo\",\"title\":\"bar\",\"status\":200,\"baz\":100,\"foo\":\"bar\"}"
+        );
+    }
+
     #[tokio::test]
     async fn problem_details_round_trip() {
         let request_histogram = noop_meter()
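The serde behavior the new additional_fields member relies on: a #[serde(flatten)] field of type Option<Value> contributes nothing when it is None and has its keys spliced into the parent JSON object when it is Some, which the new unit test above exercises. A minimal self-contained sketch of the same pattern (the Doc struct is hypothetical, not part of this commit):

use serde::Serialize;
use serde_json::{json, Value};

#[derive(Serialize)]
struct Doc {
    title: &'static str,
    // None serializes to nothing; Some(object) is merged into Doc's own object.
    #[serde(skip_serializing_if = "Option::is_none", flatten)]
    additional_fields: Option<Value>,
}

fn main() {
    let bare = Doc { title: "bar", additional_fields: None };
    assert_eq!(serde_json::to_string(&bare).unwrap(), r#"{"title":"bar"}"#);

    let extended = Doc {
        title: "bar",
        additional_fields: Some(json!({"foo": "bar", "baz": 100})),
    };
    // serde_json's default map is ordered by key, hence "baz" before "foo",
    // the same ordering the problem_details_additional_fields test expects.
    assert_eq!(
        serde_json::to_string(&extended).unwrap(),
        r#"{"title":"bar","baz":100,"foo":"bar"}"#
    );
}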
