Skip to content

Commit

Permalink
Tasks have only one aggregator, collector token
Browse files Browse the repository at this point in the history
We've decided that each task has exactly one aggregator and collector
auth token, and rotating those means rotating the task. This commit
updates the representations of auth tokens in the database and in memory
(`aggregator_core::task::Task`) accordingly.

Several tests relied upon tasks constructed in tests having two
aggregator auth tokens, one of each supported type. Those tests are
fixed to explicitly construct tasks using `DAP-Auth-Token` tokens where
necessary.

Given that there is now a single aggregator or collector auth token per
task, we could remove the `task_aggregator_auth_tokens` and
`task_collector_auth_tokens` tables and instead add columns
`aggregator_auth_token`, `aggregator_auth_token_type`,
`collector_auth_token` and `collector_auth_token_type` to `tasks`.

I chose not to do this because validating the correctness of those
columns would be tricky. We couldn't make them all `NOT NULL`, because a
helper task doesn't have a collector auth token and a task provisioned
via taskprov won't have either token (instead it'll use tokens from the
`taskprov_*_auth_tokens` tables). So we would have to write constraints
on the `tasks` table to ensure pairs of columns are either `NULL` or
`NOT NULL` in tandem, which I think are expressed more clearly in the
independent `task_*_auth_tokens` tables. Plus, if we ever _do_ decide to
support auth token rotation, it'll be easier to do with these tables in
place.

Part of #1524, #1521
  • Loading branch information
tgeoghegan committed Sep 21, 2023
1 parent 9cb721a commit a647d83
Show file tree
Hide file tree
Showing 18 changed files with 417 additions and 482 deletions.
41 changes: 23 additions & 18 deletions aggregator/src/aggregator/aggregate_init_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@ async fn setup_aggregate_init_test_for_vdaf<
vdaf_instance,
aggregation_param,
measurement,
AuthenticationToken::Bearer(random()),
)
.await;

Expand Down Expand Up @@ -203,10 +204,13 @@ async fn setup_aggregate_init_test_without_sending_request<
vdaf_instance: VdafInstance,
aggregation_param: V::AggregationParam,
measurement: V::Measurement,
auth_token: AuthenticationToken,
) -> AggregationJobInitTestCase<VERIFY_KEY_SIZE, V> {
install_test_trace_subscriber();

let task = TaskBuilder::new(QueryType::TimeInterval, vdaf_instance, Role::Helper).build();
let task = TaskBuilder::new(QueryType::TimeInterval, vdaf_instance, Role::Helper)
.with_aggregator_auth_token(Some(auth_token))
.build();
let clock = MockClock::default();
let ephemeral_datastore = ephemeral_datastore().await;
let datastore = Arc::new(ephemeral_datastore.datastore(clock.clone()).await);
Expand Down Expand Up @@ -258,11 +262,12 @@ pub(crate) async fn put_aggregation_job(
aggregation_job: &AggregationJobInitializeReq<TimeInterval>,
handler: &impl Handler,
) -> TestConn {
let (header, value) = task
.aggregator_auth_token()
.unwrap()
.request_authentication();
put(task.aggregation_job_uri(aggregation_job_id).unwrap().path())
.with_request_header(
DAP_AUTH_HEADER,
task.primary_aggregator_auth_token().as_ref().to_owned(),
)
.with_request_header(header, value)
.with_request_header(
KnownHeaderName::ContentType,
AggregationJobInitializeReq::<TimeInterval>::MEDIA_TYPE,
Expand All @@ -279,14 +284,13 @@ async fn aggregation_job_init_authorization_dap_auth_token() {
VdafInstance::Fake,
dummy_vdaf::AggregationParam(0),
(),
AuthenticationToken::DapAuth(random()),
)
.await;
// Find a DapAuthToken among the task's aggregator auth tokens

let (auth_header, auth_value) = test_case
.task
.aggregator_auth_tokens()
.iter()
.find(|auth| matches!(auth, AuthenticationToken::DapAuth(_)))
.aggregator_auth_token()
.unwrap()
.request_authentication();

Expand Down Expand Up @@ -317,6 +321,7 @@ async fn aggregation_job_init_malformed_authorization_header(#[case] header_valu
VdafInstance::Fake,
dummy_vdaf::AggregationParam(0),
(),
AuthenticationToken::Bearer(random()),
)
.await;

Expand All @@ -333,7 +338,8 @@ async fn aggregation_job_init_malformed_authorization_header(#[case] header_valu
DAP_AUTH_HEADER,
test_case
.task
.primary_aggregator_auth_token()
.aggregator_auth_token()
.unwrap()
.as_ref()
.to_owned(),
)
Expand Down Expand Up @@ -480,19 +486,18 @@ async fn aggregation_job_init_wrong_query() {
test_case.prepare_inits,
);

let (header, value) = test_case
.task
.aggregator_auth_token()
.unwrap()
.request_authentication();

let mut response = put(test_case
.task
.aggregation_job_uri(&random())
.unwrap()
.path())
.with_request_header(
DAP_AUTH_HEADER,
test_case
.task
.primary_aggregator_auth_token()
.as_ref()
.to_owned(),
)
.with_request_header(header, value)
.with_request_header(
KnownHeaderName::ContentType,
AggregationJobInitializeReq::<TimeInterval>::MEDIA_TYPE,
Expand Down
9 changes: 5 additions & 4 deletions aggregator/src/aggregator/aggregation_job_continue.rs
Original file line number Diff line number Diff line change
Expand Up @@ -302,11 +302,12 @@ pub mod test_util {
request: &AggregationJobContinueReq,
handler: &impl Handler,
) -> TestConn {
let (header, value) = task
.aggregator_auth_token()
.unwrap()
.request_authentication();
post(task.aggregation_job_uri(aggregation_job_id).unwrap().path())
.with_request_header(
"DAP-Auth-Token",
task.primary_aggregator_auth_token().as_ref().to_owned(),
)
.with_request_header(header, value)
.with_request_header(
KnownHeaderName::ContentType,
AggregationJobContinueReq::MEDIA_TYPE,
Expand Down
81 changes: 34 additions & 47 deletions aggregator/src/aggregator/aggregation_job_driver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -417,7 +417,10 @@ impl AggregationJobDriver {
AGGREGATION_JOB_ROUTE,
AggregationJobInitializeReq::<Q>::MEDIA_TYPE,
req,
task.primary_aggregator_auth_token(),
// The only way a task wouldn't have an aggregator auth token in it is in the taskprov
// case, and Janus never acts as the leader with taskprov enabled.
task.aggregator_auth_token()
.ok_or_else(|| anyhow!("task has no aggregator auth token"))?,
&self.http_request_duration_histogram,
)
.await?;
Expand Down Expand Up @@ -507,7 +510,10 @@ impl AggregationJobDriver {
AGGREGATION_JOB_ROUTE,
AggregationJobContinueReq::MEDIA_TYPE,
req,
task.primary_aggregator_auth_token(),
// The only way a task wouldn't have an aggregator auth token in it is in the taskprov
// case, and Janus never acts as the leader with taskprov enabled.
task.aggregator_auth_token()
.ok_or_else(|| anyhow!("task has no aggregator auth token"))?,
&self.http_request_duration_histogram,
)
.await?;
Expand Down Expand Up @@ -950,7 +956,7 @@ mod tests {
},
};
use rand::random;
use std::{borrow::Borrow, str, sync::Arc, time::Duration as StdDuration};
use std::{borrow::Borrow, sync::Arc, time::Duration as StdDuration};
use trillium_tokio::Stopper;

#[tokio::test]
Expand Down Expand Up @@ -997,7 +1003,7 @@ mod tests {
&measurement,
);

let agg_auth_token = task.primary_aggregator_auth_token().clone();
let agg_auth_token = task.aggregator_auth_token().unwrap().clone();
let helper_hpke_keypair = generate_test_hpke_config_and_private_key();
let report = generate_report::<VERIFY_KEY_LENGTH, Poplar1<XofShake128, 16>>(
*task.id(),
Expand Down Expand Up @@ -1115,17 +1121,15 @@ mod tests {
]);
let mocked_aggregates = join_all(helper_responses.iter().map(
|(req_method, req_content_type, resp_content_type, resp_body)| {
let (header, value) = agg_auth_token.request_authentication();
server
.mock(
req_method,
task.aggregation_job_uri(&aggregation_job_id)
.unwrap()
.path(),
)
.match_header(
"DAP-Auth-Token",
str::from_utf8(agg_auth_token.as_ref()).unwrap(),
)
.match_header(header, value.as_str())
.match_header(CONTENT_TYPE.as_str(), *req_content_type)
.with_status(200)
.with_header(CONTENT_TYPE.as_str(), resp_content_type)
Expand Down Expand Up @@ -1293,7 +1297,7 @@ mod tests {
&0,
);

let agg_auth_token = task.primary_aggregator_auth_token();
let agg_auth_token = task.aggregator_auth_token().unwrap();
let helper_hpke_keypair = generate_test_hpke_config_and_private_key();
let report = generate_report::<VERIFY_KEY_LENGTH, Prio3Count>(
*task.id(),
Expand Down Expand Up @@ -1440,17 +1444,15 @@ mod tests {
.with_body("{\"type\": \"urn:ietf:params:ppm:dap:error:unauthorizedRequest\"}")
.create_async()
.await;
let (header, value) = agg_auth_token.request_authentication();
let mocked_aggregate_success = server
.mock(
"PUT",
task.aggregation_job_uri(&aggregation_job_id)
.unwrap()
.path(),
)
.match_header(
"DAP-Auth-Token",
str::from_utf8(agg_auth_token.as_ref()).unwrap(),
)
.match_header(header, value.as_str())
.match_header(
CONTENT_TYPE.as_str(),
AggregationJobInitializeReq::<TimeInterval>::MEDIA_TYPE,
Expand Down Expand Up @@ -1659,7 +1661,7 @@ mod tests {
&measurement,
);

let agg_auth_token = task.primary_aggregator_auth_token();
let agg_auth_token = task.aggregator_auth_token().unwrap();
let helper_hpke_keypair = generate_test_hpke_config_and_private_key();
let report = generate_report::<VERIFY_KEY_LENGTH, Poplar1<XofShake128, 16>>(
*task.id(),
Expand Down Expand Up @@ -1761,17 +1763,15 @@ mod tests {
message: transcript.helper_prepare_transitions[0].message.clone(),
},
)]));
let (header, value) = agg_auth_token.request_authentication();
let mocked_aggregate_success = server
.mock(
"PUT",
task.aggregation_job_uri(&aggregation_job_id)
.unwrap()
.path(),
)
.match_header(
"DAP-Auth-Token",
str::from_utf8(agg_auth_token.as_ref()).unwrap(),
)
.match_header(header, value.as_str())
.match_header(
CONTENT_TYPE.as_str(),
AggregationJobInitializeReq::<TimeInterval>::MEDIA_TYPE,
Expand Down Expand Up @@ -1912,7 +1912,7 @@ mod tests {
&0,
);

let agg_auth_token = task.primary_aggregator_auth_token();
let agg_auth_token = task.aggregator_auth_token().unwrap();
let helper_hpke_keypair = generate_test_hpke_config_and_private_key();
let report = generate_report::<VERIFY_KEY_LENGTH, Prio3Count>(
*task.id(),
Expand Down Expand Up @@ -2017,17 +2017,15 @@ mod tests {
.with_body("{\"type\": \"urn:ietf:params:ppm:dap:error:unauthorizedRequest\"}")
.create_async()
.await;
let (header, value) = agg_auth_token.request_authentication();
let mocked_aggregate_success = server
.mock(
"PUT",
task.aggregation_job_uri(&aggregation_job_id)
.unwrap()
.path(),
)
.match_header(
"DAP-Auth-Token",
str::from_utf8(agg_auth_token.as_ref()).unwrap(),
)
.match_header(header, value.as_str())
.match_header(
CONTENT_TYPE.as_str(),
AggregationJobInitializeReq::<FixedSize>::MEDIA_TYPE,
Expand Down Expand Up @@ -2169,7 +2167,7 @@ mod tests {
&measurement,
);

let agg_auth_token = task.primary_aggregator_auth_token();
let agg_auth_token = task.aggregator_auth_token().unwrap();
let helper_hpke_keypair = generate_test_hpke_config_and_private_key();
let report = generate_report::<VERIFY_KEY_LENGTH, Poplar1<XofShake128, 16>>(
*task.id(),
Expand Down Expand Up @@ -2272,17 +2270,15 @@ mod tests {
message: transcript.helper_prepare_transitions[0].message.clone(),
},
)]));
let (header, value) = agg_auth_token.request_authentication();
let mocked_aggregate_success = server
.mock(
"PUT",
task.aggregation_job_uri(&aggregation_job_id)
.unwrap()
.path(),
)
.match_header(
"DAP-Auth-Token",
str::from_utf8(agg_auth_token.as_ref()).unwrap(),
)
.match_header(header, value.as_str())
.match_header(
CONTENT_TYPE.as_str(),
AggregationJobInitializeReq::<FixedSize>::MEDIA_TYPE,
Expand Down Expand Up @@ -2436,7 +2432,7 @@ mod tests {
&IdpfInput::from_bools(&[true]),
);

let agg_auth_token = task.primary_aggregator_auth_token();
let agg_auth_token = task.aggregator_auth_token().unwrap();
let helper_hpke_keypair = generate_test_hpke_config_and_private_key();
let report = generate_report::<VERIFY_KEY_LENGTH, Poplar1<XofShake128, 16>>(
*task.id(),
Expand Down Expand Up @@ -2584,17 +2580,15 @@ mod tests {
.with_body("{\"type\": \"urn:ietf:params:ppm:dap:error:unrecognizedTask\"}")
.create_async()
.await;
let (header, value) = agg_auth_token.request_authentication();
let mocked_aggregate_success = server
.mock(
"POST",
task.aggregation_job_uri(&aggregation_job_id)
.unwrap()
.path(),
)
.match_header(
"DAP-Auth-Token",
str::from_utf8(agg_auth_token.as_ref()).unwrap(),
)
.match_header(header, value.as_str())
.match_header(CONTENT_TYPE.as_str(), AggregationJobContinueReq::MEDIA_TYPE)
.match_body(leader_request.get_encoded())
.with_status(200)
Expand Down Expand Up @@ -2840,7 +2834,7 @@ mod tests {
&IdpfInput::from_bools(&[true]),
);

let agg_auth_token = task.primary_aggregator_auth_token();
let agg_auth_token = task.aggregator_auth_token().unwrap();
let helper_hpke_keypair = generate_test_hpke_config_and_private_key();
let report = generate_report::<VERIFY_KEY_LENGTH, Poplar1<XofShake128, 16>>(
*task.id(),
Expand Down Expand Up @@ -2973,17 +2967,15 @@ mod tests {
.with_body("{\"type\": \"urn:ietf:params:ppm:dap:error:unrecognizedTask\"}")
.create_async()
.await;
let (header, value) = agg_auth_token.request_authentication();
let mocked_aggregate_success = server
.mock(
"POST",
task.aggregation_job_uri(&aggregation_job_id)
.unwrap()
.path(),
)
.match_header(
"DAP-Auth-Token",
str::from_utf8(agg_auth_token.as_ref()).unwrap(),
)
.match_header(header, value.as_str())
.match_header(CONTENT_TYPE.as_str(), AggregationJobContinueReq::MEDIA_TYPE)
.match_body(leader_request.get_encoded())
.with_status(200)
Expand Down Expand Up @@ -3366,7 +3358,7 @@ mod tests {
)
.with_helper_aggregator_endpoint(server.url().parse().unwrap())
.build();
let agg_auth_token = task.primary_aggregator_auth_token();
let agg_auth_token = task.aggregator_auth_token().unwrap();
let aggregation_job_id = random();
let verify_key: VerifyKey<VERIFY_KEY_LENGTH> = task.vdaf_verify_key().unwrap();

Expand Down Expand Up @@ -3474,17 +3466,15 @@ mod tests {

// Set up three error responses from our mock helper. These will cause errors in the
// leader, because the response body is empty and cannot be decoded.
let (header, value) = agg_auth_token.request_authentication();
let failure_mock = server
.mock(
"PUT",
task.aggregation_job_uri(&aggregation_job_id)
.unwrap()
.path(),
)
.match_header(
"DAP-Auth-Token",
str::from_utf8(agg_auth_token.as_ref()).unwrap(),
)
.match_header(header, value.as_str())
.match_header(
CONTENT_TYPE.as_str(),
AggregationJobInitializeReq::<TimeInterval>::MEDIA_TYPE,
Expand All @@ -3503,10 +3493,7 @@ mod tests {
.unwrap()
.path(),
)
.match_header(
"DAP-Auth-Token",
str::from_utf8(agg_auth_token.as_ref()).unwrap(),
)
.match_header(header, value.as_str())
.match_header(
CONTENT_TYPE.as_str(),
AggregationJobInitializeReq::<TimeInterval>::MEDIA_TYPE,
Expand Down
Loading

0 comments on commit a647d83

Please sign in to comment.