Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tasks have only one aggregator, collector token #1973

Merged
merged 4 commits into from
Sep 22, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 23 additions & 18 deletions aggregator/src/aggregator/aggregate_init_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@ async fn setup_aggregate_init_test_for_vdaf<
vdaf_instance,
aggregation_param,
measurement,
AuthenticationToken::Bearer(random()),
)
.await;

Expand Down Expand Up @@ -203,10 +204,13 @@ async fn setup_aggregate_init_test_without_sending_request<
vdaf_instance: VdafInstance,
aggregation_param: V::AggregationParam,
measurement: V::Measurement,
auth_token: AuthenticationToken,
) -> AggregationJobInitTestCase<VERIFY_KEY_SIZE, V> {
install_test_trace_subscriber();

let task = TaskBuilder::new(QueryType::TimeInterval, vdaf_instance, Role::Helper).build();
let task = TaskBuilder::new(QueryType::TimeInterval, vdaf_instance, Role::Helper)
.with_aggregator_auth_token(Some(auth_token))
.build();
let clock = MockClock::default();
let ephemeral_datastore = ephemeral_datastore().await;
let datastore = Arc::new(ephemeral_datastore.datastore(clock.clone()).await);
Expand Down Expand Up @@ -258,11 +262,12 @@ pub(crate) async fn put_aggregation_job(
aggregation_job: &AggregationJobInitializeReq<TimeInterval>,
handler: &impl Handler,
) -> TestConn {
let (header, value) = task
.aggregator_auth_token()
.unwrap()
.request_authentication();
put(task.aggregation_job_uri(aggregation_job_id).unwrap().path())
.with_request_header(
DAP_AUTH_HEADER,
task.primary_aggregator_auth_token().as_ref().to_owned(),
)
.with_request_header(header, value)
.with_request_header(
KnownHeaderName::ContentType,
AggregationJobInitializeReq::<TimeInterval>::MEDIA_TYPE,
Expand All @@ -279,14 +284,13 @@ async fn aggregation_job_init_authorization_dap_auth_token() {
VdafInstance::Fake,
dummy_vdaf::AggregationParam(0),
(),
AuthenticationToken::DapAuth(random()),
)
.await;
// Find a DapAuthToken among the task's aggregator auth tokens

let (auth_header, auth_value) = test_case
.task
.aggregator_auth_tokens()
.iter()
.find(|auth| matches!(auth, AuthenticationToken::DapAuth(_)))
.aggregator_auth_token()
.unwrap()
.request_authentication();

Expand Down Expand Up @@ -317,6 +321,7 @@ async fn aggregation_job_init_malformed_authorization_header(#[case] header_valu
VdafInstance::Fake,
dummy_vdaf::AggregationParam(0),
(),
AuthenticationToken::Bearer(random()),
)
.await;

Expand All @@ -333,7 +338,8 @@ async fn aggregation_job_init_malformed_authorization_header(#[case] header_valu
DAP_AUTH_HEADER,
test_case
.task
.primary_aggregator_auth_token()
.aggregator_auth_token()
.unwrap()
.as_ref()
.to_owned(),
)
Expand Down Expand Up @@ -480,19 +486,18 @@ async fn aggregation_job_init_wrong_query() {
test_case.prepare_inits,
);

let (header, value) = test_case
.task
.aggregator_auth_token()
.unwrap()
.request_authentication();

let mut response = put(test_case
.task
.aggregation_job_uri(&random())
.unwrap()
.path())
.with_request_header(
DAP_AUTH_HEADER,
test_case
.task
.primary_aggregator_auth_token()
.as_ref()
.to_owned(),
)
.with_request_header(header, value)
.with_request_header(
KnownHeaderName::ContentType,
AggregationJobInitializeReq::<TimeInterval>::MEDIA_TYPE,
Expand Down
9 changes: 5 additions & 4 deletions aggregator/src/aggregator/aggregation_job_continue.rs
Original file line number Diff line number Diff line change
Expand Up @@ -302,11 +302,12 @@ pub mod test_util {
request: &AggregationJobContinueReq,
handler: &impl Handler,
) -> TestConn {
let (header, value) = task
.aggregator_auth_token()
.unwrap()
.request_authentication();
post(task.aggregation_job_uri(aggregation_job_id).unwrap().path())
.with_request_header(
"DAP-Auth-Token",
task.primary_aggregator_auth_token().as_ref().to_owned(),
)
.with_request_header(header, value)
.with_request_header(
KnownHeaderName::ContentType,
AggregationJobContinueReq::MEDIA_TYPE,
Expand Down
81 changes: 34 additions & 47 deletions aggregator/src/aggregator/aggregation_job_driver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -417,7 +417,10 @@ impl AggregationJobDriver {
AGGREGATION_JOB_ROUTE,
AggregationJobInitializeReq::<Q>::MEDIA_TYPE,
req,
task.primary_aggregator_auth_token(),
// The only way a task wouldn't have an aggregator auth token in it is in the taskprov
// case, and Janus never acts as the leader with taskprov enabled.
task.aggregator_auth_token()
.ok_or_else(|| anyhow!("task has no aggregator auth token"))?,
&self.http_request_duration_histogram,
)
.await?;
Expand Down Expand Up @@ -507,7 +510,10 @@ impl AggregationJobDriver {
AGGREGATION_JOB_ROUTE,
AggregationJobContinueReq::MEDIA_TYPE,
req,
task.primary_aggregator_auth_token(),
// The only way a task wouldn't have an aggregator auth token in it is in the taskprov
// case, and Janus never acts as the leader with taskprov enabled.
task.aggregator_auth_token()
.ok_or_else(|| anyhow!("task has no aggregator auth token"))?,
&self.http_request_duration_histogram,
)
.await?;
Expand Down Expand Up @@ -950,7 +956,7 @@ mod tests {
},
};
use rand::random;
use std::{borrow::Borrow, str, sync::Arc, time::Duration as StdDuration};
use std::{borrow::Borrow, sync::Arc, time::Duration as StdDuration};
use trillium_tokio::Stopper;

#[tokio::test]
Expand Down Expand Up @@ -997,7 +1003,7 @@ mod tests {
&measurement,
);

let agg_auth_token = task.primary_aggregator_auth_token().clone();
let agg_auth_token = task.aggregator_auth_token().unwrap().clone();
let helper_hpke_keypair = generate_test_hpke_config_and_private_key();
let report = generate_report::<VERIFY_KEY_LENGTH, Poplar1<XofShake128, 16>>(
*task.id(),
Expand Down Expand Up @@ -1115,17 +1121,15 @@ mod tests {
]);
let mocked_aggregates = join_all(helper_responses.iter().map(
|(req_method, req_content_type, resp_content_type, resp_body)| {
let (header, value) = agg_auth_token.request_authentication();
server
.mock(
req_method,
task.aggregation_job_uri(&aggregation_job_id)
.unwrap()
.path(),
)
.match_header(
"DAP-Auth-Token",
str::from_utf8(agg_auth_token.as_ref()).unwrap(),
)
.match_header(header, value.as_str())
.match_header(CONTENT_TYPE.as_str(), *req_content_type)
.with_status(200)
.with_header(CONTENT_TYPE.as_str(), resp_content_type)
Expand Down Expand Up @@ -1293,7 +1297,7 @@ mod tests {
&0,
);

let agg_auth_token = task.primary_aggregator_auth_token();
let agg_auth_token = task.aggregator_auth_token().unwrap();
let helper_hpke_keypair = generate_test_hpke_config_and_private_key();
let report = generate_report::<VERIFY_KEY_LENGTH, Prio3Count>(
*task.id(),
Expand Down Expand Up @@ -1440,17 +1444,15 @@ mod tests {
.with_body("{\"type\": \"urn:ietf:params:ppm:dap:error:unauthorizedRequest\"}")
.create_async()
.await;
let (header, value) = agg_auth_token.request_authentication();
let mocked_aggregate_success = server
.mock(
"PUT",
task.aggregation_job_uri(&aggregation_job_id)
.unwrap()
.path(),
)
.match_header(
"DAP-Auth-Token",
str::from_utf8(agg_auth_token.as_ref()).unwrap(),
)
.match_header(header, value.as_str())
.match_header(
CONTENT_TYPE.as_str(),
AggregationJobInitializeReq::<TimeInterval>::MEDIA_TYPE,
Expand Down Expand Up @@ -1659,7 +1661,7 @@ mod tests {
&measurement,
);

let agg_auth_token = task.primary_aggregator_auth_token();
let agg_auth_token = task.aggregator_auth_token().unwrap();
let helper_hpke_keypair = generate_test_hpke_config_and_private_key();
let report = generate_report::<VERIFY_KEY_LENGTH, Poplar1<XofShake128, 16>>(
*task.id(),
Expand Down Expand Up @@ -1761,17 +1763,15 @@ mod tests {
message: transcript.helper_prepare_transitions[0].message.clone(),
},
)]));
let (header, value) = agg_auth_token.request_authentication();
let mocked_aggregate_success = server
.mock(
"PUT",
task.aggregation_job_uri(&aggregation_job_id)
.unwrap()
.path(),
)
.match_header(
"DAP-Auth-Token",
str::from_utf8(agg_auth_token.as_ref()).unwrap(),
)
.match_header(header, value.as_str())
.match_header(
CONTENT_TYPE.as_str(),
AggregationJobInitializeReq::<TimeInterval>::MEDIA_TYPE,
Expand Down Expand Up @@ -1912,7 +1912,7 @@ mod tests {
&0,
);

let agg_auth_token = task.primary_aggregator_auth_token();
let agg_auth_token = task.aggregator_auth_token().unwrap();
let helper_hpke_keypair = generate_test_hpke_config_and_private_key();
let report = generate_report::<VERIFY_KEY_LENGTH, Prio3Count>(
*task.id(),
Expand Down Expand Up @@ -2017,17 +2017,15 @@ mod tests {
.with_body("{\"type\": \"urn:ietf:params:ppm:dap:error:unauthorizedRequest\"}")
.create_async()
.await;
let (header, value) = agg_auth_token.request_authentication();
let mocked_aggregate_success = server
.mock(
"PUT",
task.aggregation_job_uri(&aggregation_job_id)
.unwrap()
.path(),
)
.match_header(
"DAP-Auth-Token",
str::from_utf8(agg_auth_token.as_ref()).unwrap(),
)
.match_header(header, value.as_str())
.match_header(
CONTENT_TYPE.as_str(),
AggregationJobInitializeReq::<FixedSize>::MEDIA_TYPE,
Expand Down Expand Up @@ -2169,7 +2167,7 @@ mod tests {
&measurement,
);

let agg_auth_token = task.primary_aggregator_auth_token();
let agg_auth_token = task.aggregator_auth_token().unwrap();
let helper_hpke_keypair = generate_test_hpke_config_and_private_key();
let report = generate_report::<VERIFY_KEY_LENGTH, Poplar1<XofShake128, 16>>(
*task.id(),
Expand Down Expand Up @@ -2272,17 +2270,15 @@ mod tests {
message: transcript.helper_prepare_transitions[0].message.clone(),
},
)]));
let (header, value) = agg_auth_token.request_authentication();
let mocked_aggregate_success = server
.mock(
"PUT",
task.aggregation_job_uri(&aggregation_job_id)
.unwrap()
.path(),
)
.match_header(
"DAP-Auth-Token",
str::from_utf8(agg_auth_token.as_ref()).unwrap(),
)
.match_header(header, value.as_str())
.match_header(
CONTENT_TYPE.as_str(),
AggregationJobInitializeReq::<FixedSize>::MEDIA_TYPE,
Expand Down Expand Up @@ -2436,7 +2432,7 @@ mod tests {
&IdpfInput::from_bools(&[true]),
);

let agg_auth_token = task.primary_aggregator_auth_token();
let agg_auth_token = task.aggregator_auth_token().unwrap();
let helper_hpke_keypair = generate_test_hpke_config_and_private_key();
let report = generate_report::<VERIFY_KEY_LENGTH, Poplar1<XofShake128, 16>>(
*task.id(),
Expand Down Expand Up @@ -2584,17 +2580,15 @@ mod tests {
.with_body("{\"type\": \"urn:ietf:params:ppm:dap:error:unrecognizedTask\"}")
.create_async()
.await;
let (header, value) = agg_auth_token.request_authentication();
let mocked_aggregate_success = server
.mock(
"POST",
task.aggregation_job_uri(&aggregation_job_id)
.unwrap()
.path(),
)
.match_header(
"DAP-Auth-Token",
str::from_utf8(agg_auth_token.as_ref()).unwrap(),
)
.match_header(header, value.as_str())
.match_header(CONTENT_TYPE.as_str(), AggregationJobContinueReq::MEDIA_TYPE)
.match_body(leader_request.get_encoded())
.with_status(200)
Expand Down Expand Up @@ -2840,7 +2834,7 @@ mod tests {
&IdpfInput::from_bools(&[true]),
);

let agg_auth_token = task.primary_aggregator_auth_token();
let agg_auth_token = task.aggregator_auth_token().unwrap();
let helper_hpke_keypair = generate_test_hpke_config_and_private_key();
let report = generate_report::<VERIFY_KEY_LENGTH, Poplar1<XofShake128, 16>>(
*task.id(),
Expand Down Expand Up @@ -2973,17 +2967,15 @@ mod tests {
.with_body("{\"type\": \"urn:ietf:params:ppm:dap:error:unrecognizedTask\"}")
.create_async()
.await;
let (header, value) = agg_auth_token.request_authentication();
let mocked_aggregate_success = server
.mock(
"POST",
task.aggregation_job_uri(&aggregation_job_id)
.unwrap()
.path(),
)
.match_header(
"DAP-Auth-Token",
str::from_utf8(agg_auth_token.as_ref()).unwrap(),
)
.match_header(header, value.as_str())
.match_header(CONTENT_TYPE.as_str(), AggregationJobContinueReq::MEDIA_TYPE)
.match_body(leader_request.get_encoded())
.with_status(200)
Expand Down Expand Up @@ -3366,7 +3358,7 @@ mod tests {
)
.with_helper_aggregator_endpoint(server.url().parse().unwrap())
.build();
let agg_auth_token = task.primary_aggregator_auth_token();
let agg_auth_token = task.aggregator_auth_token().unwrap();
let aggregation_job_id = random();
let verify_key: VerifyKey<VERIFY_KEY_LENGTH> = task.vdaf_verify_key().unwrap();

Expand Down Expand Up @@ -3474,17 +3466,15 @@ mod tests {

// Set up three error responses from our mock helper. These will cause errors in the
// leader, because the response body is empty and cannot be decoded.
let (header, value) = agg_auth_token.request_authentication();
let failure_mock = server
.mock(
"PUT",
task.aggregation_job_uri(&aggregation_job_id)
.unwrap()
.path(),
)
.match_header(
"DAP-Auth-Token",
str::from_utf8(agg_auth_token.as_ref()).unwrap(),
)
.match_header(header, value.as_str())
.match_header(
CONTENT_TYPE.as_str(),
AggregationJobInitializeReq::<TimeInterval>::MEDIA_TYPE,
Expand All @@ -3503,10 +3493,7 @@ mod tests {
.unwrap()
.path(),
)
.match_header(
"DAP-Auth-Token",
str::from_utf8(agg_auth_token.as_ref()).unwrap(),
)
.match_header(header, value.as_str())
.match_header(
CONTENT_TYPE.as_str(),
AggregationJobInitializeReq::<TimeInterval>::MEDIA_TYPE,
Expand Down
Loading