diff --git a/aggregator/src/aggregator/aggregate_init_tests.rs b/aggregator/src/aggregator/aggregate_init_tests.rs index a25ed9874..3487c0ba1 100644 --- a/aggregator/src/aggregator/aggregate_init_tests.rs +++ b/aggregator/src/aggregator/aggregate_init_tests.rs @@ -169,6 +169,7 @@ async fn setup_aggregate_init_test_for_vdaf< vdaf_instance, aggregation_param, measurement, + AuthenticationToken::Bearer(random()), ) .await; @@ -203,10 +204,13 @@ async fn setup_aggregate_init_test_without_sending_request< vdaf_instance: VdafInstance, aggregation_param: V::AggregationParam, measurement: V::Measurement, + auth_token: AuthenticationToken, ) -> AggregationJobInitTestCase { install_test_trace_subscriber(); - let task = TaskBuilder::new(QueryType::TimeInterval, vdaf_instance, Role::Helper).build(); + let task = TaskBuilder::new(QueryType::TimeInterval, vdaf_instance, Role::Helper) + .with_aggregator_auth_token(Some(auth_token)) + .build(); let clock = MockClock::default(); let ephemeral_datastore = ephemeral_datastore().await; let datastore = Arc::new(ephemeral_datastore.datastore(clock.clone()).await); @@ -258,11 +262,12 @@ pub(crate) async fn put_aggregation_job( aggregation_job: &AggregationJobInitializeReq, handler: &impl Handler, ) -> TestConn { + let (header, value) = task + .aggregator_auth_token() + .unwrap() + .request_authentication(); put(task.aggregation_job_uri(aggregation_job_id).unwrap().path()) - .with_request_header( - DAP_AUTH_HEADER, - task.primary_aggregator_auth_token().as_ref().to_owned(), - ) + .with_request_header(header, value) .with_request_header( KnownHeaderName::ContentType, AggregationJobInitializeReq::::MEDIA_TYPE, @@ -279,14 +284,13 @@ async fn aggregation_job_init_authorization_dap_auth_token() { VdafInstance::Fake, dummy_vdaf::AggregationParam(0), (), + AuthenticationToken::DapAuth(random()), ) .await; - // Find a DapAuthToken among the task's aggregator auth tokens + let (auth_header, auth_value) = test_case .task - .aggregator_auth_tokens() - .iter() - .find(|auth| matches!(auth, AuthenticationToken::DapAuth(_))) + .aggregator_auth_token() .unwrap() .request_authentication(); @@ -317,6 +321,7 @@ async fn aggregation_job_init_malformed_authorization_header(#[case] header_valu VdafInstance::Fake, dummy_vdaf::AggregationParam(0), (), + AuthenticationToken::Bearer(random()), ) .await; @@ -333,7 +338,8 @@ async fn aggregation_job_init_malformed_authorization_header(#[case] header_valu DAP_AUTH_HEADER, test_case .task - .primary_aggregator_auth_token() + .aggregator_auth_token() + .unwrap() .as_ref() .to_owned(), ) @@ -480,19 +486,18 @@ async fn aggregation_job_init_wrong_query() { test_case.prepare_inits, ); + let (header, value) = test_case + .task + .aggregator_auth_token() + .unwrap() + .request_authentication(); + let mut response = put(test_case .task .aggregation_job_uri(&random()) .unwrap() .path()) - .with_request_header( - DAP_AUTH_HEADER, - test_case - .task - .primary_aggregator_auth_token() - .as_ref() - .to_owned(), - ) + .with_request_header(header, value) .with_request_header( KnownHeaderName::ContentType, AggregationJobInitializeReq::::MEDIA_TYPE, diff --git a/aggregator/src/aggregator/aggregation_job_continue.rs b/aggregator/src/aggregator/aggregation_job_continue.rs index 284b776cc..aff1a5e5c 100644 --- a/aggregator/src/aggregator/aggregation_job_continue.rs +++ b/aggregator/src/aggregator/aggregation_job_continue.rs @@ -302,11 +302,12 @@ pub mod test_util { request: &AggregationJobContinueReq, handler: &impl Handler, ) -> TestConn { + let (header, value) = task + .aggregator_auth_token() + .unwrap() + .request_authentication(); post(task.aggregation_job_uri(aggregation_job_id).unwrap().path()) - .with_request_header( - "DAP-Auth-Token", - task.primary_aggregator_auth_token().as_ref().to_owned(), - ) + .with_request_header(header, value) .with_request_header( KnownHeaderName::ContentType, AggregationJobContinueReq::MEDIA_TYPE, diff --git a/aggregator/src/aggregator/aggregation_job_driver.rs b/aggregator/src/aggregator/aggregation_job_driver.rs index 1fa4c414f..af0e8730a 100644 --- a/aggregator/src/aggregator/aggregation_job_driver.rs +++ b/aggregator/src/aggregator/aggregation_job_driver.rs @@ -417,7 +417,10 @@ impl AggregationJobDriver { AGGREGATION_JOB_ROUTE, AggregationJobInitializeReq::::MEDIA_TYPE, req, - task.primary_aggregator_auth_token(), + // The only way a task wouldn't have an aggregator auth token in it is in the taskprov + // case, and Janus never acts as the leader with taskprov enabled. + task.aggregator_auth_token() + .ok_or_else(|| anyhow!("task has no aggregator auth token"))?, &self.http_request_duration_histogram, ) .await?; @@ -507,7 +510,10 @@ impl AggregationJobDriver { AGGREGATION_JOB_ROUTE, AggregationJobContinueReq::MEDIA_TYPE, req, - task.primary_aggregator_auth_token(), + // The only way a task wouldn't have an aggregator auth token in it is in the taskprov + // case, and Janus never acts as the leader with taskprov enabled. + task.aggregator_auth_token() + .ok_or_else(|| anyhow!("task has no aggregator auth token"))?, &self.http_request_duration_histogram, ) .await?; @@ -950,7 +956,7 @@ mod tests { }, }; use rand::random; - use std::{borrow::Borrow, str, sync::Arc, time::Duration as StdDuration}; + use std::{borrow::Borrow, sync::Arc, time::Duration as StdDuration}; use trillium_tokio::Stopper; #[tokio::test] @@ -997,7 +1003,7 @@ mod tests { &measurement, ); - let agg_auth_token = task.primary_aggregator_auth_token().clone(); + let agg_auth_token = task.aggregator_auth_token().unwrap().clone(); let helper_hpke_keypair = generate_test_hpke_config_and_private_key(); let report = generate_report::>( *task.id(), @@ -1115,6 +1121,7 @@ mod tests { ]); let mocked_aggregates = join_all(helper_responses.iter().map( |(req_method, req_content_type, resp_content_type, resp_body)| { + let (header, value) = agg_auth_token.request_authentication(); server .mock( req_method, @@ -1122,10 +1129,7 @@ mod tests { .unwrap() .path(), ) - .match_header( - "DAP-Auth-Token", - str::from_utf8(agg_auth_token.as_ref()).unwrap(), - ) + .match_header(header, value.as_str()) .match_header(CONTENT_TYPE.as_str(), *req_content_type) .with_status(200) .with_header(CONTENT_TYPE.as_str(), resp_content_type) @@ -1293,7 +1297,7 @@ mod tests { &0, ); - let agg_auth_token = task.primary_aggregator_auth_token(); + let agg_auth_token = task.aggregator_auth_token().unwrap(); let helper_hpke_keypair = generate_test_hpke_config_and_private_key(); let report = generate_report::( *task.id(), @@ -1440,6 +1444,7 @@ mod tests { .with_body("{\"type\": \"urn:ietf:params:ppm:dap:error:unauthorizedRequest\"}") .create_async() .await; + let (header, value) = agg_auth_token.request_authentication(); let mocked_aggregate_success = server .mock( "PUT", @@ -1447,10 +1452,7 @@ mod tests { .unwrap() .path(), ) - .match_header( - "DAP-Auth-Token", - str::from_utf8(agg_auth_token.as_ref()).unwrap(), - ) + .match_header(header, value.as_str()) .match_header( CONTENT_TYPE.as_str(), AggregationJobInitializeReq::::MEDIA_TYPE, @@ -1659,7 +1661,7 @@ mod tests { &measurement, ); - let agg_auth_token = task.primary_aggregator_auth_token(); + let agg_auth_token = task.aggregator_auth_token().unwrap(); let helper_hpke_keypair = generate_test_hpke_config_and_private_key(); let report = generate_report::>( *task.id(), @@ -1761,6 +1763,7 @@ mod tests { message: transcript.helper_prepare_transitions[0].message.clone(), }, )])); + let (header, value) = agg_auth_token.request_authentication(); let mocked_aggregate_success = server .mock( "PUT", @@ -1768,10 +1771,7 @@ mod tests { .unwrap() .path(), ) - .match_header( - "DAP-Auth-Token", - str::from_utf8(agg_auth_token.as_ref()).unwrap(), - ) + .match_header(header, value.as_str()) .match_header( CONTENT_TYPE.as_str(), AggregationJobInitializeReq::::MEDIA_TYPE, @@ -1912,7 +1912,7 @@ mod tests { &0, ); - let agg_auth_token = task.primary_aggregator_auth_token(); + let agg_auth_token = task.aggregator_auth_token().unwrap(); let helper_hpke_keypair = generate_test_hpke_config_and_private_key(); let report = generate_report::( *task.id(), @@ -2017,6 +2017,7 @@ mod tests { .with_body("{\"type\": \"urn:ietf:params:ppm:dap:error:unauthorizedRequest\"}") .create_async() .await; + let (header, value) = agg_auth_token.request_authentication(); let mocked_aggregate_success = server .mock( "PUT", @@ -2024,10 +2025,7 @@ mod tests { .unwrap() .path(), ) - .match_header( - "DAP-Auth-Token", - str::from_utf8(agg_auth_token.as_ref()).unwrap(), - ) + .match_header(header, value.as_str()) .match_header( CONTENT_TYPE.as_str(), AggregationJobInitializeReq::::MEDIA_TYPE, @@ -2169,7 +2167,7 @@ mod tests { &measurement, ); - let agg_auth_token = task.primary_aggregator_auth_token(); + let agg_auth_token = task.aggregator_auth_token().unwrap(); let helper_hpke_keypair = generate_test_hpke_config_and_private_key(); let report = generate_report::>( *task.id(), @@ -2272,6 +2270,7 @@ mod tests { message: transcript.helper_prepare_transitions[0].message.clone(), }, )])); + let (header, value) = agg_auth_token.request_authentication(); let mocked_aggregate_success = server .mock( "PUT", @@ -2279,10 +2278,7 @@ mod tests { .unwrap() .path(), ) - .match_header( - "DAP-Auth-Token", - str::from_utf8(agg_auth_token.as_ref()).unwrap(), - ) + .match_header(header, value.as_str()) .match_header( CONTENT_TYPE.as_str(), AggregationJobInitializeReq::::MEDIA_TYPE, @@ -2436,7 +2432,7 @@ mod tests { &IdpfInput::from_bools(&[true]), ); - let agg_auth_token = task.primary_aggregator_auth_token(); + let agg_auth_token = task.aggregator_auth_token().unwrap(); let helper_hpke_keypair = generate_test_hpke_config_and_private_key(); let report = generate_report::>( *task.id(), @@ -2584,6 +2580,7 @@ mod tests { .with_body("{\"type\": \"urn:ietf:params:ppm:dap:error:unrecognizedTask\"}") .create_async() .await; + let (header, value) = agg_auth_token.request_authentication(); let mocked_aggregate_success = server .mock( "POST", @@ -2591,10 +2588,7 @@ mod tests { .unwrap() .path(), ) - .match_header( - "DAP-Auth-Token", - str::from_utf8(agg_auth_token.as_ref()).unwrap(), - ) + .match_header(header, value.as_str()) .match_header(CONTENT_TYPE.as_str(), AggregationJobContinueReq::MEDIA_TYPE) .match_body(leader_request.get_encoded()) .with_status(200) @@ -2840,7 +2834,7 @@ mod tests { &IdpfInput::from_bools(&[true]), ); - let agg_auth_token = task.primary_aggregator_auth_token(); + let agg_auth_token = task.aggregator_auth_token().unwrap(); let helper_hpke_keypair = generate_test_hpke_config_and_private_key(); let report = generate_report::>( *task.id(), @@ -2973,6 +2967,7 @@ mod tests { .with_body("{\"type\": \"urn:ietf:params:ppm:dap:error:unrecognizedTask\"}") .create_async() .await; + let (header, value) = agg_auth_token.request_authentication(); let mocked_aggregate_success = server .mock( "POST", @@ -2980,10 +2975,7 @@ mod tests { .unwrap() .path(), ) - .match_header( - "DAP-Auth-Token", - str::from_utf8(agg_auth_token.as_ref()).unwrap(), - ) + .match_header(header, value.as_str()) .match_header(CONTENT_TYPE.as_str(), AggregationJobContinueReq::MEDIA_TYPE) .match_body(leader_request.get_encoded()) .with_status(200) @@ -3366,7 +3358,7 @@ mod tests { ) .with_helper_aggregator_endpoint(server.url().parse().unwrap()) .build(); - let agg_auth_token = task.primary_aggregator_auth_token(); + let agg_auth_token = task.aggregator_auth_token().unwrap(); let aggregation_job_id = random(); let verify_key: VerifyKey = task.vdaf_verify_key().unwrap(); @@ -3474,6 +3466,7 @@ mod tests { // Set up three error responses from our mock helper. These will cause errors in the // leader, because the response body is empty and cannot be decoded. + let (header, value) = agg_auth_token.request_authentication(); let failure_mock = server .mock( "PUT", @@ -3481,10 +3474,7 @@ mod tests { .unwrap() .path(), ) - .match_header( - "DAP-Auth-Token", - str::from_utf8(agg_auth_token.as_ref()).unwrap(), - ) + .match_header(header, value.as_str()) .match_header( CONTENT_TYPE.as_str(), AggregationJobInitializeReq::::MEDIA_TYPE, @@ -3503,10 +3493,7 @@ mod tests { .unwrap() .path(), ) - .match_header( - "DAP-Auth-Token", - str::from_utf8(agg_auth_token.as_ref()).unwrap(), - ) + .match_header(header, value.as_str()) .match_header( CONTENT_TYPE.as_str(), AggregationJobInitializeReq::::MEDIA_TYPE, diff --git a/aggregator/src/aggregator/collection_job_driver.rs b/aggregator/src/aggregator/collection_job_driver.rs index c12a470a6..50b8bab7e 100644 --- a/aggregator/src/aggregator/collection_job_driver.rs +++ b/aggregator/src/aggregator/collection_job_driver.rs @@ -228,7 +228,10 @@ impl CollectionJobDriver { AGGREGATE_SHARES_ROUTE, AggregateShareReq::::MEDIA_TYPE, req, - task.primary_aggregator_auth_token(), + // The only way a task wouldn't have an aggregator auth token in it is in the taskprov + // case, and Janus never acts as the leader with taskprov enabled. + task.aggregator_auth_token() + .ok_or_else(|| Error::InvalidConfiguration("no aggregator auth token in task"))?, &self.metrics.http_request_duration_histogram, ) .await?; @@ -558,7 +561,7 @@ mod tests { }; use prio::codec::{Decode, Encode}; use rand::random; - use std::{str, sync::Arc, time::Duration as StdDuration}; + use std::{sync::Arc, time::Duration as StdDuration}; use trillium_tokio::Stopper; async fn setup_collection_job_test_case( @@ -716,7 +719,7 @@ mod tests { .with_time_precision(time_precision) .with_min_batch_size(10) .build(); - let agg_auth_token = task.primary_aggregator_auth_token(); + let agg_auth_token = task.aggregator_auth_token().unwrap(); let batch_interval = Interval::new(clock.now(), Duration::from_seconds(2000)).unwrap(); let aggregation_param = AggregationParam(0); let report_timestamp = clock @@ -874,12 +877,10 @@ mod tests { ); // Simulate helper failing to service the aggregate share request. + let (header, value) = agg_auth_token.request_authentication(); let mocked_failed_aggregate_share = server .mock("POST", task.aggregate_shares_uri().unwrap().path()) - .match_header( - "DAP-Auth-Token", - str::from_utf8(agg_auth_token.as_ref()).unwrap(), - ) + .match_header(header, value.as_str()) .match_header( CONTENT_TYPE.as_str(), AggregateShareReq::::MEDIA_TYPE, @@ -935,12 +936,10 @@ mod tests { Vec::new(), )); + let (header, value) = agg_auth_token.request_authentication(); let mocked_aggregate_share = server .mock("POST", task.aggregate_shares_uri().unwrap().path()) - .match_header( - "DAP-Auth-Token", - str::from_utf8(agg_auth_token.as_ref()).unwrap(), - ) + .match_header(header, value.as_str()) .match_header( CONTENT_TYPE.as_str(), AggregateShareReq::::MEDIA_TYPE, diff --git a/aggregator/src/aggregator/collection_job_tests.rs b/aggregator/src/aggregator/collection_job_tests.rs index cdb4d862d..a190200dd 100644 --- a/aggregator/src/aggregator/collection_job_tests.rs +++ b/aggregator/src/aggregator/collection_job_tests.rs @@ -91,7 +91,7 @@ impl CollectionJobTestCase { self.put_collection_job_with_auth_token( collection_job_id, request, - Some(self.task.primary_collector_auth_token()), + self.task.collector_auth_token(), ) .await } @@ -120,7 +120,7 @@ impl CollectionJobTestCase { ) -> TestConn { self.post_collection_job_with_auth_token( collection_job_id, - Some(self.task.primary_collector_auth_token()), + self.task.collector_auth_token(), ) .await } diff --git a/aggregator/src/aggregator/http_handlers.rs b/aggregator/src/aggregator/http_handlers.rs index f7fa63540..6581bf7a9 100644 --- a/aggregator/src/aggregator/http_handlers.rs +++ b/aggregator/src/aggregator/http_handlers.rs @@ -1435,11 +1435,14 @@ mod tests { async fn aggregate_wrong_agg_auth_token() { let (_, _ephemeral_datastore, datastore, handler) = setup_http_handler_test().await; + let dap_auth_token = AuthenticationToken::DapAuth(random()); + let task = TaskBuilder::new( QueryType::TimeInterval, VdafInstance::Prio3Count, Role::Helper, ) + .with_aggregator_auth_token(Some(dap_auth_token.clone())) .build(); datastore.put_task(&task).await.unwrap(); @@ -1452,14 +1455,10 @@ mod tests { let wrong_token_value = random(); - // Send the right token, but the wrong format: we find a DapAuth token in the task's - // aggregator tokens and convert it to an equivalent Bearer token, which should be rejected. - let wrong_token_format = task - .aggregator_auth_tokens() - .iter() - .find(|token| matches!(token, AuthenticationToken::DapAuth(_))) - .map(|token| AuthenticationToken::new_bearer_token_from_bytes(token.as_ref()).unwrap()) - .unwrap(); + // Send the right token, but the wrong format: convert the DAP auth token to an equivalent + // Bearer token, which should be rejected. + let wrong_token_format = + AuthenticationToken::new_bearer_token_from_bytes(dap_auth_token.as_ref()).unwrap(); for auth_token in [Some(wrong_token_value), Some(wrong_token_format), None] { let mut test_conn = put(task @@ -4158,11 +4157,12 @@ mod tests { dummy_vdaf::AggregationParam::default().get_encoded(), ); + let (header, value) = task + .collector_auth_token() + .unwrap() + .request_authentication(); let mut test_conn = put(task.collection_job_uri(&collection_job_id).unwrap().path()) - .with_request_header( - "DAP-Auth-Token", - task.primary_collector_auth_token().as_ref().to_owned(), - ) + .with_request_header(header, value) .with_request_header( KnownHeaderName::ContentType, CollectionReq::::MEDIA_TYPE, @@ -4221,7 +4221,7 @@ mod tests { .put_collection_job_with_auth_token( &collection_job_id, &req, - Some(test_case.task.primary_aggregator_auth_token()), + test_case.task.aggregator_auth_token(), ) .await; @@ -4298,7 +4298,7 @@ mod tests { let mut test_conn = test_case .post_collection_job_with_auth_token( &collection_job_id, - Some(test_case.task.primary_aggregator_auth_token()), + test_case.task.aggregator_auth_token(), ) .await; @@ -4526,18 +4526,16 @@ mod tests { let no_such_collection_job_id: CollectionJobId = random(); + let (header, value) = test_case + .task + .collector_auth_token() + .unwrap() + .request_authentication(); let test_conn = post(&format!( "/tasks/{}/collection_jobs/{no_such_collection_job_id}", test_case.task.id() )) - .with_request_header( - "DAP-Auth-Token", - test_case - .task - .primary_collector_auth_token() - .as_ref() - .to_owned(), - ) + .with_request_header(header, value) .run_async(&test_case.handler) .await; assert_eq!(test_conn.status(), Some(Status::NotFound)); @@ -4694,6 +4692,12 @@ mod tests { let collection_job_id: CollectionJobId = random(); + let (header, value) = test_case + .task + .collector_auth_token() + .unwrap() + .request_authentication(); + // Try to delete a collection job that doesn't exist let test_conn = delete( test_case @@ -4702,14 +4706,7 @@ mod tests { .unwrap() .path(), ) - .with_request_header( - "DAP-Auth-Token", - test_case - .task - .primary_collector_auth_token() - .as_ref() - .to_owned(), - ) + .with_request_header(header, value.clone()) .run_async(&test_case.handler) .await; assert_eq!(test_conn.status(), Some(Status::NotFound)); @@ -4734,14 +4731,7 @@ mod tests { .unwrap() .path(), ) - .with_request_header( - "DAP-Auth-Token", - test_case - .task - .primary_collector_auth_token() - .as_ref() - .to_owned(), - ) + .with_request_header(header, value) .run_async(&test_case.handler) .await; assert_eq!(test_conn.status(), Some(Status::NoContent)); @@ -4769,11 +4759,13 @@ mod tests { ReportIdChecksum::default(), ); + let (header, value) = task + .aggregator_auth_token() + .unwrap() + .request_authentication(); + let mut test_conn = post(task.aggregate_shares_uri().unwrap().path()) - .with_request_header( - "DAP-Auth-Token", - task.primary_aggregator_auth_token().as_ref().to_owned(), - ) + .with_request_header(header, value) .with_request_header( KnownHeaderName::ContentType, AggregateShareReq::::MEDIA_TYPE, @@ -4819,13 +4811,15 @@ mod tests { ReportIdChecksum::default(), ); + let (header, value) = task + .aggregator_auth_token() + .unwrap() + .request_authentication(); + // Test that a request for an invalid batch fails. (Specifically, the batch interval is too // small.) let mut test_conn = post(task.aggregate_shares_uri().unwrap().path()) - .with_request_header( - "DAP-Auth-Token", - task.primary_aggregator_auth_token().as_ref().to_owned(), - ) + .with_request_header(header, value.clone()) .with_request_header( KnownHeaderName::ContentType, AggregateShareReq::::MEDIA_TYPE, @@ -4847,10 +4841,7 @@ mod tests { // Test that a request for a too-old batch fails. let test_conn = post(task.aggregate_shares_uri().unwrap().path()) - .with_request_header( - "DAP-Auth-Token", - task.primary_aggregator_auth_token().as_ref().to_owned(), - ) + .with_request_header(header, value) .with_request_header( KnownHeaderName::ContentType, AggregateShareReq::::MEDIA_TYPE, @@ -4896,11 +4887,13 @@ mod tests { ReportIdChecksum::default(), ); + let (header, value) = task + .aggregator_auth_token() + .unwrap() + .request_authentication(); + let mut test_conn = post(task.aggregate_shares_uri().unwrap().path()) - .with_request_header( - "DAP-Auth-Token", - task.primary_aggregator_auth_token().as_ref().to_owned(), - ) + .with_request_header(header, value) .with_request_header( KnownHeaderName::ContentType, AggregateShareReq::::MEDIA_TYPE, @@ -5081,11 +5074,12 @@ mod tests { 5, ReportIdChecksum::default(), ); + let (header, value) = task + .aggregator_auth_token() + .unwrap() + .request_authentication(); let mut test_conn = post(task.aggregate_shares_uri().unwrap().path()) - .with_request_header( - "DAP-Auth-Token", - task.primary_aggregator_auth_token().as_ref().to_owned(), - ) + .with_request_header(header, value) .with_request_header( KnownHeaderName::ContentType, AggregateShareReq::::MEDIA_TYPE, @@ -5134,11 +5128,12 @@ mod tests { ReportIdChecksum::get_decoded(&[4 ^ 8; 32]).unwrap(), ), ] { + let (header, value) = task + .aggregator_auth_token() + .unwrap() + .request_authentication(); let mut test_conn = post(task.aggregate_shares_uri().unwrap().path()) - .with_request_header( - "DAP-Auth-Token", - task.primary_aggregator_auth_token().as_ref().to_owned(), - ) + .with_request_header(header, value) .with_request_header( KnownHeaderName::ContentType, AggregateShareReq::::MEDIA_TYPE, @@ -5199,11 +5194,12 @@ mod tests { // Request the aggregate share multiple times. If the request parameters don't change, // then there is no query count violation and all requests should succeed. for iteration in 0..3 { + let (header, value) = task + .aggregator_auth_token() + .unwrap() + .request_authentication(); let mut test_conn = post(task.aggregate_shares_uri().unwrap().path()) - .with_request_header( - "DAP-Auth-Token", - task.primary_aggregator_auth_token().as_ref().to_owned(), - ) + .with_request_header(header, value) .with_request_header( KnownHeaderName::ContentType, AggregateShareReq::::MEDIA_TYPE, @@ -5266,11 +5262,12 @@ mod tests { 20, ReportIdChecksum::get_decoded(&[8 ^ 4 ^ 3 ^ 2; 32]).unwrap(), ); + let (header, value) = task + .aggregator_auth_token() + .unwrap() + .request_authentication(); let mut test_conn = post(task.aggregate_shares_uri().unwrap().path()) - .with_request_header( - "DAP-Auth-Token", - task.primary_aggregator_auth_token().as_ref().to_owned(), - ) + .with_request_header(header, value) .with_request_header( KnownHeaderName::ContentType, AggregateShareReq::::MEDIA_TYPE, @@ -5317,11 +5314,12 @@ mod tests { ReportIdChecksum::get_decoded(&[4 ^ 8; 32]).unwrap(), ), ] { + let (header, value) = task + .aggregator_auth_token() + .unwrap() + .request_authentication(); let mut test_conn = post(task.aggregate_shares_uri().unwrap().path()) - .with_request_header( - "DAP-Auth-Token", - task.primary_aggregator_auth_token().as_ref().to_owned(), - ) + .with_request_header(header, value) .with_request_header( KnownHeaderName::ContentType, AggregateShareReq::::MEDIA_TYPE, diff --git a/aggregator/src/bin/janus_cli.rs b/aggregator/src/bin/janus_cli.rs index 3ebf692cb..a5ac05b41 100644 --- a/aggregator/src/bin/janus_cli.rs +++ b/aggregator/src/bin/janus_cli.rs @@ -718,7 +718,7 @@ mod tests { vdaf: !Prio3Sum bits: 2 role: Leader - vdaf_verify_keys: + vdaf_verify_key: max_batch_query_count: 1 task_expiration: 9000000000 min_batch_size: 10 @@ -731,8 +731,8 @@ mod tests { kdf_id: HkdfSha256 aead_id: Aes128Gcm public_key: 8lAqZ7OfNV2Gi_9cNE6J9WRmPbO-k1UPtu2Bztd0-yc - aggregator_auth_tokens: [] - collector_auth_tokens: [] + aggregator_auth_token: + collector_auth_token: hpke_keys: [] - leader_aggregator_endpoint: https://leader helper_aggregator_endpoint: https://helper @@ -740,7 +740,7 @@ mod tests { vdaf: !Prio3Sum bits: 2 role: Helper - vdaf_verify_keys: + vdaf_verify_key: max_batch_query_count: 1 task_expiration: 9000000000 min_batch_size: 10 @@ -753,8 +753,8 @@ mod tests { kdf_id: HkdfSha256 aead_id: Aes128Gcm public_key: 8lAqZ7OfNV2Gi_9cNE6J9WRmPbO-k1UPtu2Bztd0-yc - aggregator_auth_tokens: [] - collector_auth_tokens: [] + aggregator_auth_token: + collector_auth_token: hpke_keys: [] "#; @@ -800,8 +800,8 @@ mod tests { for task in &got_tasks { match task.role() { - Role::Leader => assert_eq!(task.collector_auth_tokens().len(), 1), - Role::Helper => assert!(task.collector_auth_tokens().is_empty()), + Role::Leader => assert!(task.collector_auth_token().is_some()), + Role::Helper => assert!(task.collector_auth_token().is_none()), role => panic!("unexpected role {role}"), } } diff --git a/aggregator_api/src/models.rs b/aggregator_api/src/models.rs index 7415a0e6c..d7e1625af 100644 --- a/aggregator_api/src/models.rs +++ b/aggregator_api/src/models.rs @@ -108,13 +108,12 @@ pub(crate) struct TaskResp { /// How much clock skew to allow between client and aggregator. Reports from /// farther than this duration into the future will be rejected. pub(crate) tolerable_clock_skew: Duration, - /// The authentication token for inter-aggregator communication in this task. - /// If `role` is Leader, this token is used by the aggregator to authenticate requests to - /// the Helper. If `role` is Helper, this token is used by the aggregator to authenticate - /// requests from the Leader. + /// The authentication token for inter-aggregator communication in this task. If `role` is + /// Helper, this token is used by the aggregator to authenticate requests from the Leader. Not + /// set if `role` is Leader.. // TODO(#1509): This field will have to change as Janus helpers will only store a salted // hash of aggregator auth tokens. - pub(crate) aggregator_auth_token: AuthenticationToken, + pub(crate) aggregator_auth_token: Option, /// The authentication token used by the task's Collector to authenticate to the Leader. /// `Some` if `role` is Leader, `None` otherwise. // TODO(#1509) This field will have to change as Janus leaders will only store a salted hash @@ -143,21 +142,6 @@ impl TryFrom<&Task> for TaskResp { } .clone(); - if task.aggregator_auth_tokens().len() != 1 { - return Err("illegal number of aggregator auth tokens in task"); - } - - let collector_auth_token = match task.role() { - Role::Leader => { - if task.collector_auth_tokens().len() != 1 { - return Err("illegal number of collector auth tokens in task"); - } - Some(task.primary_collector_auth_token().clone()) - } - Role::Helper => None, - _ => return Err("illegal aggregator role in task"), - }; - let mut aggregator_hpke_configs: Vec<_> = task .hpke_keys() .values() @@ -178,8 +162,8 @@ impl TryFrom<&Task> for TaskResp { min_batch_size: task.min_batch_size(), time_precision: *task.time_precision(), tolerable_clock_skew: *task.tolerable_clock_skew(), - aggregator_auth_token: task.primary_aggregator_auth_token().clone(), - collector_auth_token, + aggregator_auth_token: task.aggregator_auth_token().cloned(), + collector_auth_token: task.collector_auth_token().cloned(), collector_hpke_config: task .collector_hpke_config() .ok_or("collector_hpke_config is required")? diff --git a/aggregator_api/src/routes.rs b/aggregator_api/src/routes.rs index 5265aa6f5..0d93586be 100644 --- a/aggregator_api/src/routes.rs +++ b/aggregator_api/src/routes.rs @@ -121,7 +121,7 @@ pub(super) async fn post_task( let vdaf_verify_key = SecretBytes::new(vdaf_verify_key_bytes); - let (aggregator_auth_tokens, collector_auth_tokens) = match req.role { + let (aggregator_auth_token, collector_auth_token) = match req.role { Role::Leader => { let aggregator_auth_token = req.aggregator_auth_token.ok_or_else(|| { Error::BadRequest( @@ -129,7 +129,7 @@ pub(super) async fn post_task( .to_string(), ) })?; - (Vec::from([aggregator_auth_token]), Vec::from([random()])) + (Some(aggregator_auth_token), Some(random())) } Role::Helper => { @@ -140,7 +140,7 @@ pub(super) async fn post_task( )); } - (Vec::from([random()]), Vec::new()) + (Some(random()), None) } _ => unreachable!(), @@ -173,8 +173,8 @@ pub(super) async fn post_task( /* tolerable_clock_skew */ Duration::from_seconds(60), // 1 minute, /* collector_hpke_config */ req.collector_hpke_config, - aggregator_auth_tokens, - collector_auth_tokens, + aggregator_auth_token, + collector_auth_token, hpke_keys, ) .map_err(|err| Error::BadRequest(format!("Error constructing task: {err}")))?, diff --git a/aggregator_api/src/tests.rs b/aggregator_api/src/tests.rs index 8e6c47d57..3b7e7d1b2 100644 --- a/aggregator_api/src/tests.rs +++ b/aggregator_api/src/tests.rs @@ -334,7 +334,8 @@ async fn post_task_helper_no_optional_fields() { assert_eq!(req.task_expiration.as_ref(), got_task.task_expiration()); assert_eq!(req.min_batch_size, got_task.min_batch_size()); assert_eq!(&req.time_precision, got_task.time_precision()); - assert_eq!(1, got_task.aggregator_auth_tokens().len()); + assert!(got_task.aggregator_auth_token().is_some()); + assert!(got_task.collector_auth_token().is_none()); assert_eq!( &req.collector_hpke_config, got_task.collector_hpke_config().unwrap() @@ -543,12 +544,11 @@ async fn post_task_leader_all_optional_fields() { &req.collector_hpke_config, got_task.collector_hpke_config().unwrap() ); - assert_eq!(1, got_task.aggregator_auth_tokens().len()); assert_eq!( aggregator_auth_token.as_ref(), - got_task.aggregator_auth_tokens()[0].as_ref() + got_task.aggregator_auth_token().unwrap().as_ref() ); - assert_eq!(1, got_task.collector_auth_tokens().len()); + assert!(got_task.collector_auth_token().is_some()); // ...and the response. assert_eq!(got_task_resp, TaskResp::try_from(&got_task).unwrap()); @@ -602,10 +602,7 @@ async fn get_task() { // Setup: write a task to the datastore. let (handler, _ephemeral_datastore, ds) = setup_api_test().await; - let task = TaskBuilder::new(QueryType::TimeInterval, VdafInstance::Fake, Role::Leader) - .with_aggregator_auth_tokens(Vec::from([random()])) - .with_collector_auth_tokens(Vec::from([random()])) - .build(); + let task = TaskBuilder::new(QueryType::TimeInterval, VdafInstance::Fake, Role::Leader).build(); ds.run_tx(|tx| { let task = task.clone(); @@ -1758,14 +1755,14 @@ fn task_resp_serialization() { HpkeAeadId::Aes128Gcm, HpkePublicKey::from([0u8; 32].to_vec()), ), - Vec::from([AuthenticationToken::new_dap_auth_token_from_string( - "Y29sbGVjdG9yLWFiY2RlZjAw", - ) - .unwrap()]), - Vec::from([AuthenticationToken::new_dap_auth_token_from_string( - "Y29sbGVjdG9yLWFiY2RlZjAw", - ) - .unwrap()]), + Some( + AuthenticationToken::new_dap_auth_token_from_string("Y29sbGVjdG9yLWFiY2RlZjAw") + .unwrap(), + ), + Some( + AuthenticationToken::new_dap_auth_token_from_string("Y29sbGVjdG9yLWFiY2RlZjAw") + .unwrap(), + ), [(HpkeKeypair::new( HpkeConfig::new( HpkeConfigId::from(13), @@ -1831,6 +1828,7 @@ fn task_resp_serialization() { Token::NewtypeStruct { name: "Duration" }, Token::U64(60), Token::Str("aggregator_auth_token"), + Token::Some, Token::Struct { name: "AuthenticationToken", len: 2, diff --git a/aggregator_core/src/datastore.rs b/aggregator_core/src/datastore.rs index 480f8d7e7..7ecb63448 100644 --- a/aggregator_core/src/datastore.rs +++ b/aggregator_core/src/datastore.rs @@ -14,7 +14,7 @@ use crate::{ SecretBytes, }; use chrono::NaiveDateTime; -use futures::future::try_join_all; +use futures::future::{try_join_all, Either}; use janus_core::{ hpke::{HpkeKeypair, HpkePrivateKey}, task::{AuthenticationToken, VdafInstance}, @@ -584,81 +584,66 @@ impl Transaction<'_, C> { .await?, )?; - // Aggregator auth tokens. - let mut aggregator_auth_token_ords = Vec::new(); - let mut aggregator_auth_token_types = Vec::new(); - let mut aggregator_auth_tokens = Vec::new(); - for (ord, token) in task.aggregator_auth_tokens().iter().enumerate() { - let ord = i64::try_from(ord)?; - - let mut row_id = [0; TaskId::LEN + size_of::()]; - row_id[..TaskId::LEN].copy_from_slice(task.id().as_ref()); - row_id[TaskId::LEN..].copy_from_slice(&ord.to_be_bytes()); - + // Aggregator auth token. + let aggregator_auth_token_future = if let Some(token) = task.aggregator_auth_token() { let encrypted_aggregator_auth_token = self.crypter.encrypt( "task_aggregator_auth_tokens", - &row_id, + task.id().as_ref(), "token", token.as_ref(), )?; + let aggregator_auth_token_stmt = self + .prepare_cached( + "INSERT INTO task_aggregator_auth_tokens (task_id, type, token) + VALUES ((SELECT id FROM tasks WHERE task_id = $1), $2, $3)", + ) + .await?; - aggregator_auth_token_ords.push(ord); - aggregator_auth_token_types.push(AuthenticationTokenType::from(token)); - aggregator_auth_tokens.push(encrypted_aggregator_auth_token); - } - let stmt = self - .prepare_cached( - "INSERT INTO task_aggregator_auth_tokens (task_id, ord, type, token) - SELECT - (SELECT id FROM tasks WHERE task_id = $1), - * FROM UNNEST($2::BIGINT[], $3::AUTH_TOKEN_TYPE[], $4::BYTEA[])", - ) - .await?; - let aggregator_auth_tokens_params: &[&(dyn ToSql + Sync)] = &[ - /* task_id */ &task.id().as_ref(), - /* ords */ &aggregator_auth_token_ords, - /* token_types */ &aggregator_auth_token_types, - /* tokens */ &aggregator_auth_tokens, - ]; - let aggregator_auth_tokens_future = self.execute(&stmt, aggregator_auth_tokens_params); - - // Collector auth tokens. - let mut collector_auth_token_ords = Vec::new(); - let mut collector_auth_token_types = Vec::new(); - let mut collector_auth_tokens = Vec::new(); - for (ord, token) in task.collector_auth_tokens().iter().enumerate() { - let ord = i64::try_from(ord)?; - - let mut row_id = [0; TaskId::LEN + size_of::()]; - row_id[..TaskId::LEN].copy_from_slice(task.id().as_ref()); - row_id[TaskId::LEN..].copy_from_slice(&ord.to_be_bytes()); + Either::Left(async move { + self.execute( + &aggregator_auth_token_stmt, + &[ + /* task_id */ &task.id().as_ref(), + /* token_type */ &AuthenticationTokenType::from(token), + /* token */ &encrypted_aggregator_auth_token, + ], + ) + .await + }) + } else { + // no-op future so we can unconditionally pass to `try_join`, below. + Either::Right(futures::future::ok(0)) + }; + // Collector auth token. + let collector_auth_token_future = if let Some(token) = task.collector_auth_token() { let encrypted_collector_auth_token = self.crypter.encrypt( "task_collector_auth_tokens", - &row_id, + task.id().as_ref(), "token", token.as_ref(), )?; - - collector_auth_token_ords.push(ord); - collector_auth_token_types.push(AuthenticationTokenType::from(token)); - collector_auth_tokens.push(encrypted_collector_auth_token); - } - let stmt = self - .prepare_cached( - "INSERT INTO task_collector_auth_tokens (task_id, ord, type, token) - SELECT - (SELECT id FROM tasks WHERE task_id = $1), - * FROM UNNEST($2::BIGINT[], $3::AUTH_TOKEN_TYPE[], $4::BYTEA[])", - ) - .await?; - let collector_auth_tokens_params: &[&(dyn ToSql + Sync)] = &[ - /* task_id */ &task.id().as_ref(), - /* ords */ &collector_auth_token_ords, - /* token_types */ &collector_auth_token_types, - /* tokens */ &collector_auth_tokens, - ]; - let collector_auth_tokens_future = self.execute(&stmt, collector_auth_tokens_params); + let collector_auth_token_stmt = self + .prepare_cached( + "INSERT INTO task_collector_auth_tokens (task_id, type, token) + VALUES ((SELECT id FROM tasks WHERE task_id = $1), $2, $3)", + ) + .await?; + Either::Left(async move { + self.execute( + &collector_auth_token_stmt, + &[ + /* task_id */ &task.id().as_ref(), + /* token_type */ &AuthenticationTokenType::from(token), + /* token */ &encrypted_collector_auth_token, + ], + ) + .await + }) + } else { + // no-op future so we can unconditionally pass to `try_join`, below. + Either::Right(futures::future::ok(0)) + }; // HPKE keys. let mut hpke_config_ids: Vec = Vec::new(); @@ -698,9 +683,9 @@ impl Transaction<'_, C> { let hpke_configs_future = self.execute(&stmt, hpke_configs_params); try_join!( - aggregator_auth_tokens_future, - collector_auth_tokens_future, - hpke_configs_future, + aggregator_auth_token_future, + collector_auth_token_future, + hpke_configs_future )?; Ok(()) @@ -738,19 +723,19 @@ impl Transaction<'_, C> { let stmt = self .prepare_cached( - "SELECT ord, type, token FROM task_aggregator_auth_tokens - WHERE task_id = (SELECT id FROM tasks WHERE task_id = $1) ORDER BY ord ASC", + "SELECT type, token FROM task_aggregator_auth_tokens + WHERE task_id = (SELECT id FROM tasks WHERE task_id = $1)", ) .await?; - let aggregator_auth_token_rows = self.query(&stmt, params); + let aggregator_auth_token_row = self.query_opt(&stmt, params); let stmt = self .prepare_cached( - "SELECT ord, type, token FROM task_collector_auth_tokens - WHERE task_id = (SELECT id FROM tasks WHERE task_id = $1) ORDER BY ord ASC", + "SELECT type, token FROM task_collector_auth_tokens + WHERE task_id = (SELECT id FROM tasks WHERE task_id = $1)", ) .await?; - let collector_auth_token_rows = self.query(&stmt, params); + let collector_auth_token_row = self.query_opt(&stmt, params); let stmt = self .prepare_cached( @@ -760,10 +745,10 @@ impl Transaction<'_, C> { .await?; let hpke_key_rows = self.query(&stmt, params); - let (task_row, aggregator_auth_token_rows, collector_auth_token_rows, hpke_key_rows) = try_join!( + let (task_row, aggregator_auth_token_row, collector_auth_token_row, hpke_key_rows) = try_join!( task_row, - aggregator_auth_token_rows, - collector_auth_token_rows, + aggregator_auth_token_row, + collector_auth_token_row, hpke_key_rows, )?; task_row @@ -771,8 +756,8 @@ impl Transaction<'_, C> { self.task_from_rows( task_id, &task_row, - &aggregator_auth_token_rows, - &collector_auth_token_rows, + aggregator_auth_token_row.as_ref(), + collector_auth_token_row.as_ref(), &hpke_key_rows, ) }) @@ -797,7 +782,7 @@ impl Transaction<'_, C> { .prepare_cached( "SELECT (SELECT tasks.task_id FROM tasks WHERE tasks.id = task_aggregator_auth_tokens.task_id), - ord, type, token FROM task_aggregator_auth_tokens ORDER BY ord ASC", + type, token FROM task_aggregator_auth_tokens", ) .await?; let aggregator_auth_token_rows = self.query(&stmt, &[]); @@ -806,7 +791,7 @@ impl Transaction<'_, C> { .prepare_cached( "SELECT (SELECT tasks.task_id FROM tasks WHERE tasks.id = task_collector_auth_tokens.task_id), - ord, type, token FROM task_collector_auth_tokens ORDER BY ord ASC", + type, token FROM task_collector_auth_tokens", ) .await?; let collector_auth_token_rows = self.query(&stmt, &[]); @@ -833,22 +818,16 @@ impl Transaction<'_, C> { task_row_by_id.push((task_id, row)); } - let mut aggregator_auth_token_rows_by_task_id: HashMap> = HashMap::new(); + let mut aggregator_auth_token_rows_by_task_id: HashMap = HashMap::new(); for row in aggregator_auth_token_rows { let task_id = TaskId::get_decoded(row.get("task_id"))?; - aggregator_auth_token_rows_by_task_id - .entry(task_id) - .or_default() - .push(row); + aggregator_auth_token_rows_by_task_id.insert(task_id, row); } - let mut collector_auth_token_rows_by_task_id: HashMap> = HashMap::new(); + let mut collector_auth_token_rows_by_task_id: HashMap = HashMap::new(); for row in collector_auth_token_rows { let task_id = TaskId::get_decoded(row.get("task_id"))?; - collector_auth_token_rows_by_task_id - .entry(task_id) - .or_default() - .push(row); + collector_auth_token_rows_by_task_id.insert(task_id, row); } let mut hpke_config_rows_by_task_id: HashMap> = HashMap::new(); @@ -866,12 +845,12 @@ impl Transaction<'_, C> { self.task_from_rows( &task_id, &row, - &aggregator_auth_token_rows_by_task_id + aggregator_auth_token_rows_by_task_id .remove(&task_id) - .unwrap_or_default(), - &collector_auth_token_rows_by_task_id + .as_ref(), + collector_auth_token_rows_by_task_id .remove(&task_id) - .unwrap_or_default(), + .as_ref(), &hpke_config_rows_by_task_id .remove(&task_id) .unwrap_or_default(), @@ -882,14 +861,12 @@ impl Transaction<'_, C> { /// Construct a [`Task`] from the contents of the provided (tasks) `Row`, /// `hpke_aggregator_auth_tokens` rows, and `task_hpke_keys` rows. - /// - /// agg_auth_token_rows must be sorted in ascending order by `ord`. fn task_from_rows( &self, task_id: &TaskId, row: &Row, - aggregator_auth_token_rows: &[Row], - collector_auth_token_rows: &[Row], + aggregator_auth_token_row: Option<&Row>, + collector_auth_token_row: Option<&Row>, hpke_key_rows: &[Row], ) -> Result { // Scalar task parameters. @@ -927,47 +904,33 @@ impl Transaction<'_, C> { ) .map(SecretBytes::new)?; - // Aggregator authentication tokens. - let mut aggregator_auth_tokens = Vec::new(); - for row in aggregator_auth_token_rows { - let ord: i64 = row.get("ord"); - let auth_token_type: AuthenticationTokenType = row.get("type"); - let encrypted_aggregator_auth_token: Vec = row.get("token"); + let aggregator_auth_token = if let Some(row) = aggregator_auth_token_row { + let auth_token_type: AuthenticationTokenType = row.try_get("type")?; + let encrypted_aggregator_auth_token: Vec = row.try_get("token")?; - let mut row_id = [0u8; TaskId::LEN + size_of::()]; - row_id[..TaskId::LEN].copy_from_slice(task_id.as_ref()); - row_id[TaskId::LEN..].copy_from_slice(&ord.to_be_bytes()); - - aggregator_auth_tokens.push(auth_token_type.as_authentication( - &self.crypter.decrypt( - "task_aggregator_auth_tokens", - &row_id, - "token", - &encrypted_aggregator_auth_token, - )?, - )?); - } + Some(auth_token_type.as_authentication(&self.crypter.decrypt( + "task_aggregator_auth_tokens", + task_id.as_ref(), + "token", + &encrypted_aggregator_auth_token, + )?)?) + } else { + None + }; - // Collector authentication tokens. - let mut collector_auth_tokens = Vec::new(); - for row in collector_auth_token_rows { - let ord: i64 = row.get("ord"); - let auth_token_type: AuthenticationTokenType = row.get("type"); - let encrypted_collector_auth_token: Vec = row.get("token"); + let collector_auth_token = if let Some(row) = collector_auth_token_row { + let auth_token_type: AuthenticationTokenType = row.try_get("type")?; + let encrypted_collector_auth_token: Vec = row.try_get("token")?; - let mut row_id = [0u8; TaskId::LEN + size_of::()]; - row_id[..TaskId::LEN].copy_from_slice(task_id.as_ref()); - row_id[TaskId::LEN..].copy_from_slice(&ord.to_be_bytes()); - - collector_auth_tokens.push(auth_token_type.as_authentication( - &self.crypter.decrypt( - "task_collector_auth_tokens", - &row_id, - "token", - &encrypted_collector_auth_token, - )?, - )?); - } + Some(auth_token_type.as_authentication(&self.crypter.decrypt( + "task_collector_auth_tokens", + task_id.as_ref(), + "token", + &encrypted_collector_auth_token, + )?)?) + } else { + None + }; // HPKE keys. let mut hpke_keypairs = Vec::new(); @@ -1005,8 +968,8 @@ impl Transaction<'_, C> { time_precision, tolerable_clock_skew, collector_hpke_config, - aggregator_auth_tokens, - collector_auth_tokens, + aggregator_auth_token, + collector_auth_token, hpke_keypairs, ); // Trial validation through all known schemes. This is a workaround to avoid extending the diff --git a/aggregator_core/src/task.rs b/aggregator_core/src/task.rs index 7e7318cc9..63d2fe180 100644 --- a/aggregator_core/src/task.rs +++ b/aggregator_core/src/task.rs @@ -129,10 +129,12 @@ pub struct Task { tolerable_clock_skew: Duration, /// HPKE configuration for the collector. collector_hpke_config: Option, - /// Tokens used to authenticate messages sent to or received from the other aggregator. - aggregator_auth_tokens: Vec, - /// Tokens used to authenticate messages sent to or received from the collector. - collector_auth_tokens: Vec, + /// Token used to authenticate messages sent to or received from the other aggregator. Only set + /// if the task was not created via taskprov. + aggregator_auth_token: Option, + /// Token used to authenticate messages sent to received from the collector. Only set if this + /// aggregator is the leader. + collector_auth_token: Option, /// HPKE configurations & private keys used by this aggregator to decrypt client reports. hpke_keys: HashMap, } @@ -155,8 +157,8 @@ impl Task { time_precision: Duration, tolerable_clock_skew: Duration, collector_hpke_config: HpkeConfig, - aggregator_auth_tokens: Vec, - collector_auth_tokens: Vec, + aggregator_auth_token: Option, + collector_auth_token: Option, hpke_keys: I, ) -> Result { let task = Self::new_without_validation( @@ -174,8 +176,8 @@ impl Task { time_precision, tolerable_clock_skew, Some(collector_hpke_config), - aggregator_auth_tokens, - collector_auth_tokens, + aggregator_auth_token, + collector_auth_token, hpke_keys, ); task.validate()?; @@ -200,8 +202,8 @@ impl Task { time_precision: Duration, tolerable_clock_skew: Duration, collector_hpke_config: Option, - aggregator_auth_tokens: Vec, - collector_auth_tokens: Vec, + aggregator_auth_token: Option, + collector_auth_token: Option, hpke_keys: I, ) -> Self { // Compute hpke_configs mapping cfg.id -> (cfg, key). @@ -228,8 +230,8 @@ impl Task { time_precision, tolerable_clock_skew, collector_hpke_config, - aggregator_auth_tokens, - collector_auth_tokens, + aggregator_auth_token, + collector_auth_token, hpke_keys, } } @@ -264,13 +266,13 @@ impl Task { pub(crate) fn validate(&self) -> Result<(), Error> { self.validate_common()?; - if self.aggregator_auth_tokens.is_empty() { - return Err(Error::InvalidParameter("aggregator_auth_tokens")); + if self.aggregator_auth_token.is_none() { + return Err(Error::InvalidParameter("aggregator_auth_token")); } - if (self.role == Role::Leader) == (self.collector_auth_tokens.is_empty()) { + if (self.role == Role::Leader) == (self.collector_auth_token.is_none()) { // Collector auth tokens are allowed & required if and only if this task is in the // leader role. - return Err(Error::InvalidParameter("collector_auth_tokens")); + return Err(Error::InvalidParameter("collector_auth_token")); } if self.hpke_keys.is_empty() { return Err(Error::InvalidParameter("hpke_keys")); @@ -357,14 +359,14 @@ impl Task { self.collector_hpke_config.as_ref() } - /// Retrieves the aggregator authentication tokens associated with this task. - pub fn aggregator_auth_tokens(&self) -> &[AuthenticationToken] { - &self.aggregator_auth_tokens + /// Retrieves the aggregator authentication token associated with this task. + pub fn aggregator_auth_token(&self) -> Option<&AuthenticationToken> { + self.aggregator_auth_token.as_ref() } - /// Retrieves the collector authentication tokens associated with this task. - pub fn collector_auth_tokens(&self) -> &[AuthenticationToken] { - &self.collector_auth_tokens + /// Retrieves the collector authentication token associated with this task. + pub fn collector_auth_token(&self) -> Option<&AuthenticationToken> { + self.collector_auth_token.as_ref() } /// Retrieves the HPKE keys in use associated with this task. @@ -397,35 +399,22 @@ impl Task { } } - /// Returns the [`AuthenticationToken`] currently used by this aggregator to authenticate itself - /// to other aggregators. - pub fn primary_aggregator_auth_token(&self) -> &AuthenticationToken { - self.aggregator_auth_tokens.iter().next_back().unwrap() - } - /// Checks if the given aggregator authentication token is valid (i.e. matches with an /// authentication token recognized by this task). pub fn check_aggregator_auth_token(&self, auth_token: &AuthenticationToken) -> bool { - self.aggregator_auth_tokens - .iter() - .rev() - .any(|t| t == auth_token) - } - - /// Returns the [`AuthenticationToken`] currently used by the collector to authenticate itself - /// to the aggregators. - pub fn primary_collector_auth_token(&self) -> &AuthenticationToken { - // Unwrap safety: self.collector_auth_tokens is never empty - self.collector_auth_tokens.iter().next_back().unwrap() + match self.aggregator_auth_token { + Some(ref t) => t == auth_token, + None => false, + } } /// Checks if the given collector authentication token is valid (i.e. matches with an /// authentication token recognized by this task). pub fn check_collector_auth_token(&self, auth_token: &AuthenticationToken) -> bool { - self.collector_auth_tokens - .iter() - .rev() - .any(|t| t == auth_token) + match self.collector_auth_token { + Some(ref t) => t == auth_token, + None => false, + } } /// Returns the [`VerifyKey`] used by this aggregator to prepare report shares with other @@ -495,8 +484,8 @@ pub struct SerializedTask { time_precision: Duration, tolerable_clock_skew: Duration, collector_hpke_config: HpkeConfig, - aggregator_auth_tokens: Vec, - collector_auth_tokens: Vec, + aggregator_auth_token: Option, + collector_auth_token: Option, hpke_keys: Vec, // uses unpadded base64url } @@ -532,12 +521,12 @@ impl SerializedTask { self.vdaf_verify_key = Some(URL_SAFE_NO_PAD.encode(vdaf_verify_key.as_ref())); } - if self.aggregator_auth_tokens.is_empty() { - self.aggregator_auth_tokens = Vec::from([random()]); + if self.aggregator_auth_token.is_none() { + self.aggregator_auth_token = Some(random()); } - if self.collector_auth_tokens.is_empty() && self.role == Role::Leader { - self.collector_auth_tokens = Vec::from([random()]); + if self.collector_auth_token.is_none() && self.role == Role::Leader { + self.collector_auth_token = Some(random()); } if self.hpke_keys.is_empty() { @@ -577,8 +566,8 @@ impl Serialize for Task { .collector_hpke_config() .expect("serializable tasks must have collector_hpke_config") .clone(), - aggregator_auth_tokens: self.aggregator_auth_tokens.clone(), - collector_auth_tokens: self.collector_auth_tokens.clone(), + aggregator_auth_token: self.aggregator_auth_token.clone(), + collector_auth_token: self.collector_auth_token.clone(), hpke_keys, } .serialize(serializer) @@ -614,8 +603,8 @@ impl TryFrom for Task { serialized_task.time_precision, serialized_task.tolerable_clock_skew, serialized_task.collector_hpke_config, - serialized_task.aggregator_auth_tokens, - serialized_task.collector_auth_tokens, + serialized_task.aggregator_auth_token, + serialized_task.collector_auth_token, serialized_task.hpke_keys, ) } @@ -689,10 +678,10 @@ pub mod test_util { .collect(), ); - let collector_auth_tokens = if role == Role::Leader { - Vec::from([random(), AuthenticationToken::DapAuth(random())]) + let collector_auth_token = if role == Role::Leader { + Some(random()) // Create an AuthenticationToken::Bearer by default } else { - Vec::new() + None }; Self( @@ -711,8 +700,8 @@ pub mod test_util { Duration::from_hours(8).unwrap(), Duration::from_minutes(10).unwrap(), generate_test_hpke_config_and_private_key().config().clone(), - Vec::from([random(), AuthenticationToken::DapAuth(random())]), - collector_auth_tokens, + Some(random()), // Create an AuthenticationToken::Bearer by default + collector_auth_token, Vec::from([aggregator_keypair_0, aggregator_keypair_1]), ) .unwrap(), @@ -795,24 +784,42 @@ pub mod test_util { }) } - /// Associates the eventual task with the given aggregator authentication tokens. - pub fn with_aggregator_auth_tokens( + /// Associates the eventual task with the given aggregator authentication token. + pub fn with_aggregator_auth_token( self, - aggregator_auth_tokens: Vec, + aggregator_auth_token: Option, ) -> Self { Self(Task { - aggregator_auth_tokens, + aggregator_auth_token, + ..self.0 + }) + } + + /// Associates the eventual task with a random [`AuthenticationToken::DapAuth`] aggregator + /// auth token. + pub fn with_dap_auth_aggregator_token(self) -> Self { + Self(Task { + aggregator_auth_token: Some(AuthenticationToken::DapAuth(random())), ..self.0 }) } - /// Sets the collector authentication tokens for the task. - pub fn with_collector_auth_tokens( + /// Associates the eventual task with the given collector authentication token. + pub fn with_collector_auth_token( self, - collector_auth_tokens: Vec, + collector_auth_token: Option, ) -> Self { Self(Task { - collector_auth_tokens, + collector_auth_token, + ..self.0 + }) + } + + /// Associates the eventual task with a random [`AuthenticationToken::DapAuth`] collector + /// auth token. + pub fn with_dap_auth_collector_token(self) -> Self { + Self(Task { + collector_auth_token: Some(AuthenticationToken::DapAuth(random())), ..self.0 }) } @@ -915,8 +922,8 @@ mod tests { Duration::from_hours(8).unwrap(), Duration::from_minutes(10).unwrap(), generate_test_hpke_config_and_private_key().config().clone(), - Vec::from([random()]), - Vec::new(), + Some(random()), + None, Vec::from([generate_test_hpke_config_and_private_key()]), ) .unwrap_err(); @@ -937,8 +944,8 @@ mod tests { Duration::from_hours(8).unwrap(), Duration::from_minutes(10).unwrap(), generate_test_hpke_config_and_private_key().config().clone(), - Vec::from([random()]), - Vec::from([random()]), + Some(random()), + Some(random()), Vec::from([generate_test_hpke_config_and_private_key()]), ) .unwrap(); @@ -959,8 +966,8 @@ mod tests { Duration::from_hours(8).unwrap(), Duration::from_minutes(10).unwrap(), generate_test_hpke_config_and_private_key().config().clone(), - Vec::from([random()]), - Vec::new(), + Some(random()), + None, Vec::from([generate_test_hpke_config_and_private_key()]), ) .unwrap(); @@ -981,8 +988,8 @@ mod tests { Duration::from_hours(8).unwrap(), Duration::from_minutes(10).unwrap(), generate_test_hpke_config_and_private_key().config().clone(), - Vec::from([random()]), - Vec::from([random()]), + Some(random()), + Some(random()), Vec::from([generate_test_hpke_config_and_private_key()]), ) .unwrap_err(); @@ -1005,8 +1012,8 @@ mod tests { Duration::from_hours(8).unwrap(), Duration::from_minutes(10).unwrap(), generate_test_hpke_config_and_private_key().config().clone(), - Vec::from([random()]), - Vec::from([random()]), + Some(random()), + Some(random()), Vec::from([generate_test_hpke_config_and_private_key()]), ) .unwrap(); @@ -1088,14 +1095,14 @@ mod tests { HpkeAeadId::Aes128Gcm, HpkePublicKey::from(b"collector hpke public key".to_vec()), ), - Vec::from([AuthenticationToken::new_dap_auth_token_from_string( - "YWdncmVnYXRvciB0b2tlbg", - ) - .unwrap()]), - Vec::from([AuthenticationToken::new_bearer_token_from_string( - "Y29sbGVjdG9yIHRva2Vu", - ) - .unwrap()]), + Some( + AuthenticationToken::new_dap_auth_token_from_string("YWdncmVnYXRvciB0b2tlbg") + .unwrap(), + ), + Some( + AuthenticationToken::new_bearer_token_from_string("Y29sbGVjdG9yIHRva2Vu") + .unwrap(), + ), [HpkeKeypair::new( HpkeConfig::new( HpkeConfigId::from(255), @@ -1180,8 +1187,8 @@ mod tests { Token::Str("public_key"), Token::Str("Y29sbGVjdG9yIGhwa2UgcHVibGljIGtleQ"), Token::StructEnd, - Token::Str("aggregator_auth_tokens"), - Token::Seq { len: Some(1) }, + Token::Str("aggregator_auth_token"), + Token::Some, Token::Struct { name: "AuthenticationToken", len: 2, @@ -1194,9 +1201,8 @@ mod tests { Token::Str("token"), Token::Str("YWdncmVnYXRvciB0b2tlbg"), Token::StructEnd, - Token::SeqEnd, - Token::Str("collector_auth_tokens"), - Token::Seq { len: Some(1) }, + Token::Str("collector_auth_token"), + Token::Some, Token::Struct { name: "AuthenticationToken", len: 2, @@ -1209,7 +1215,6 @@ mod tests { Token::Str("token"), Token::Str("Y29sbGVjdG9yIHRva2Vu"), Token::StructEnd, - Token::SeqEnd, Token::Str("hpke_keys"), Token::Seq { len: Some(1) }, Token::Struct { @@ -1277,11 +1282,11 @@ mod tests { HpkeAeadId::Aes128Gcm, HpkePublicKey::from(b"collector hpke public key".to_vec()), ), - Vec::from([AuthenticationToken::new_bearer_token_from_string( - "YWdncmVnYXRvciB0b2tlbg", - ) - .unwrap()]), - Vec::new(), + Some( + AuthenticationToken::new_bearer_token_from_string("YWdncmVnYXRvciB0b2tlbg") + .unwrap(), + ), + None, [HpkeKeypair::new( HpkeConfig::new( HpkeConfigId::from(255), @@ -1378,8 +1383,8 @@ mod tests { Token::Str("public_key"), Token::Str("Y29sbGVjdG9yIGhwa2UgcHVibGljIGtleQ"), Token::StructEnd, - Token::Str("aggregator_auth_tokens"), - Token::Seq { len: Some(1) }, + Token::Str("aggregator_auth_token"), + Token::Some, Token::Struct { name: "AuthenticationToken", len: 2, @@ -1392,10 +1397,8 @@ mod tests { Token::Str("token"), Token::Str("YWdncmVnYXRvciB0b2tlbg"), Token::StructEnd, - Token::SeqEnd, - Token::Str("collector_auth_tokens"), - Token::Seq { len: Some(0) }, - Token::SeqEnd, + Token::Str("collector_auth_token"), + Token::None, Token::Str("hpke_keys"), Token::Seq { len: Some(1) }, Token::Struct { diff --git a/aggregator_core/src/taskprov.rs b/aggregator_core/src/taskprov.rs index d4989553f..ce7c4cdee 100644 --- a/aggregator_core/src/taskprov.rs +++ b/aggregator_core/src/taskprov.rs @@ -302,8 +302,8 @@ impl Task { time_precision, tolerable_clock_skew, None, - Vec::new(), - Vec::new(), + None, + None, Vec::new(), )); task.validate()?; diff --git a/db/00000000000001_initial_schema.up.sql b/db/00000000000001_initial_schema.up.sql index 3834cfac5..a2cfd8eb5 100644 --- a/db/00000000000001_initial_schema.up.sql +++ b/db/00000000000001_initial_schema.up.sql @@ -97,11 +97,10 @@ CREATE INDEX task_id_index ON tasks(task_id); CREATE TABLE task_aggregator_auth_tokens( id BIGINT GENERATED ALWAYS AS IDENTITY PRIMARY KEY, -- artificial ID, internal-only task_id BIGINT NOT NULL, -- task ID the token is associated with - ord BIGINT NOT NULL, -- a value used to specify the ordering of the authentication tokens type AUTH_TOKEN_TYPE NOT NULL DEFAULT 'BEARER', -- the type of the authentication token token BYTEA NOT NULL, -- bearer token used to authenticate messages to/from the other aggregator (encrypted) - CONSTRAINT task_aggregator_auth_tokens_unique_task_id_and_ord UNIQUE(task_id, ord), + CONSTRAINT task_aggregator_auth_tokens_unique_task_id UNIQUE(task_id), CONSTRAINT fk_task_id FOREIGN KEY(task_id) REFERENCES tasks(id) ON DELETE CASCADE ); @@ -109,11 +108,10 @@ CREATE TABLE task_aggregator_auth_tokens( CREATE TABLE task_collector_auth_tokens( id BIGINT GENERATED ALWAYS AS IDENTITY PRIMARY KEY, -- artificial ID, internal-only task_id BIGINT NOT NULL, -- task ID the token is associated with - ord BIGINT NOT NULL, -- a value used to specify the ordering of the authentication tokens type AUTH_TOKEN_TYPE NOT NULL DEFAULT 'BEARER', -- the type of the authentication token token BYTEA NOT NULL, -- bearer token used to authenticate messages from the collector (encrypted) - CONSTRAINT task_collector_auth_tokens_unique_task_id_and_ord UNIQUE(task_id, ord), + CONSTRAINT task_collector_auth_tokens_unique_task_id UNIQUE(task_id), CONSTRAINT fk_task_id FOREIGN KEY(task_id) REFERENCES tasks(id) ON DELETE CASCADE ); diff --git a/docs/samples/tasks.yaml b/docs/samples/tasks.yaml index c4b49efba..88d94a05a 100644 --- a/docs/samples/tasks.yaml +++ b/docs/samples/tasks.yaml @@ -59,31 +59,29 @@ aead_id: Aes128Gcm public_key: 4qiv6IY5jrjCV3xbaQXULmPIpvoIml1oJmeXm-yOuAo - # Authentication tokens shared beteween the aggregators, and used to + # Authentication token shared beteween the aggregators, and used to # authenticate leader-to-helper requests. In the case of a leader-role task, - # the leader will include the first token in a header when making requests to - # the helper. In the case of a helper-role task, the helper will accept - # requests with any of the listed authentication tokens. + # the leader will include the token in a header when making requests to the + # helper. In the case of a helper-role task, the helper will accept requests + # requests with authentication tokens. # # Each token's `type` governs how it is inserted into HTTP requests if used by # the leader to authenticate a request to the helper. - aggregator_auth_tokens: + aggregator_auth_token: # DAP-Auth-Token values are encoded in unpadded base64url, and the decoded # value is sent in an HTTP header. For example, this token's value decodes # to "aggregator-235242f99406c4fd28b820c32eab0f68". - - type: "DapAuth" + type: "DapAuth" token: "YWdncmVnYXRvci0yMzUyNDJmOTk0MDZjNGZkMjhiODIwYzMyZWFiMGY2OA" - # Bearer token values are encoded in unpadded base64url. - - type: "Bearer" - token: "YWdncmVnYXRvci04NDc1NjkwZjJmYzQzMDBmYjE0NmJiMjk1NDIzNDk1NA" - # Authentication tokens shared between the leader and the collector, and used + # Authentication token shared between the leader and the collector, and used # to authenticate collector-to-leader requests. For leader tasks, this has the # same format as `aggregator_auth_tokens` above. For helper tasks, this will # be an empty list instead. + # Bearer token values are encoded in unpadded base64url. # This example decodes to "collector-abf5408e2b1601831625af3959106458". - collector_auth_tokens: - - type: "Bearer" + collector_auth_token: + type: "Bearer" token: "Y29sbGVjdG9yLWFiZjU0MDhlMmIxNjAxODMxNjI1YWYzOTU5MTA2NDU4" # This aggregator's HPKE keypairs. The first keypair's HPKE configuration will @@ -122,12 +120,12 @@ kdf_id: HkdfSha256 aead_id: Aes128Gcm public_key: KHRLcWgfWxli8cdOLPsgsZPttHXh0ho3vLVLrW-63lE - aggregator_auth_tokens: - - type: "Bearer" + aggregator_auth_token: + type: "Bearer" token: "YWdncmVnYXRvci1jZmE4NDMyZjdkMzllMjZiYjU3OGUzMzY5Mzk1MWQzNQ" - # Note that this task does not have any collector authentication tokens, since + # Note that this task does not have a collector authentication token, since # it is a helper role task. - collector_auth_tokens: [] + collector_auth_token: hpke_keys: - config: id: 37 diff --git a/integration_tests/tests/common/mod.rs b/integration_tests/tests/common/mod.rs index deced10af..644db5ae0 100644 --- a/integration_tests/tests/common/mod.rs +++ b/integration_tests/tests/common/mod.rs @@ -45,11 +45,14 @@ pub fn test_task_builders( Url::parse(&format!("http://helper-{endpoint_random_value}:8080/")).unwrap(), ) .with_min_batch_size(46) + // Force use of DAP-Auth-Tokens, as required by interop testing standard. + .with_dap_auth_aggregator_token() + .with_dap_auth_collector_token() .with_collector_hpke_config(collector_keypair.config().clone()); let helper_task = leader_task .clone() .with_role(Role::Helper) - .with_collector_auth_tokens(Vec::new()); + .with_collector_auth_token(None); let temporary_task = leader_task.clone().build(); let task_parameters = TaskParameters { task_id: *temporary_task.id(), @@ -60,7 +63,7 @@ pub fn test_task_builders( time_precision: *temporary_task.time_precision(), collector_hpke_config: collector_keypair.config().clone(), collector_private_key: collector_keypair.private_key().clone(), - collector_auth_token: temporary_task.primary_collector_auth_token().clone(), + collector_auth_token: temporary_task.collector_auth_token().unwrap().clone(), }; (task_parameters, leader_task, helper_task) diff --git a/interop_binaries/src/bin/janus_interop_aggregator.rs b/interop_binaries/src/bin/janus_interop_aggregator.rs index 2fe30faac..b038a00bb 100644 --- a/interop_binaries/src/bin/janus_interop_aggregator.rs +++ b/interop_binaries/src/bin/janus_interop_aggregator.rs @@ -54,18 +54,16 @@ async fn handle_add_task( let collector_hpke_config = HpkeConfig::get_decoded(&collector_hpke_config_bytes) .context("could not parse collector HPKE configuration")?; - let collector_authentication_tokens = + let collector_authentication_token = match (request.role, request.collector_authentication_token) { (AggregatorRole::Leader, None) => { return Err(anyhow::anyhow!("collector authentication token is missing")) } - (AggregatorRole::Leader, Some(collector_authentication_token)) => { - Vec::from([AuthenticationToken::new_dap_auth_token_from_string( - collector_authentication_token, - ) - .context("invalid header value in \"collector_authentication_token\"")?]) - } - (AggregatorRole::Helper, _) => Vec::new(), + (AggregatorRole::Leader, Some(collector_authentication_token)) => Some( + AuthenticationToken::new_dap_auth_token_from_string(collector_authentication_token) + .context("invalid header value in \"collector_authentication_token\"")?, + ), + (AggregatorRole::Helper, _) => None, }; let hpke_keypair = keyring.lock().await.get_random_keypair(); @@ -103,8 +101,8 @@ async fn handle_add_task( // other aggregators running on the same host. Duration::from_seconds(1), collector_hpke_config, - Vec::from([leader_authentication_token]), - collector_authentication_tokens, + Some(leader_authentication_token), + collector_authentication_token, [hpke_keypair], ) .context("error constructing task")?; diff --git a/interop_binaries/src/lib.rs b/interop_binaries/src/lib.rs index c18513078..957a9440b 100644 --- a/interop_binaries/src/lib.rs +++ b/interop_binaries/src/lib.rs @@ -291,12 +291,12 @@ impl From for AggregatorAddTaskRequest { helper: task.helper_aggregator_endpoint().clone(), vdaf: task.vdaf().clone().into(), leader_authentication_token: String::from_utf8( - task.primary_aggregator_auth_token().as_ref().to_vec(), + task.aggregator_auth_token().unwrap().as_ref().to_vec(), ) .unwrap(), collector_authentication_token: if task.role() == &Role::Leader { Some( - String::from_utf8(task.primary_collector_auth_token().as_ref().to_vec()) + String::from_utf8(task.collector_auth_token().unwrap().as_ref().to_vec()) .unwrap(), ) } else {