From aca9f14d123a844983fa334dcad8f3ff5821c9df Mon Sep 17 00:00:00 2001 From: Tim Geoghegan Date: Fri, 22 Sep 2023 11:10:32 -0700 Subject: [PATCH] Tasks have only one aggregator, collector token (#1973) * Tasks have only one aggregator, collector token We've decided that each task has exactly one aggregator and collector auth token, and rotating those means rotating the task. This commit updates the representations of auth tokens in the database and in memory (`aggregator_core::task::Task`) accordingly. In particular, the `task_aggregator_auth_tokens` and `task_collector_auth_tokens` tables are removed and folded into table `tasks`. Several tests relied upon tasks constructed in tests having two aggregator auth tokens, one of each supported type. Those tests are fixed to explicitly construct tasks using `DAP-Auth-Token` tokens where necessary. Part of #1524, #1521 --- .../src/aggregator/aggregate_init_tests.rs | 41 +-- .../aggregator/aggregation_job_continue.rs | 9 +- .../src/aggregator/aggregation_job_driver.rs | 81 ++--- .../src/aggregator/collection_job_driver.rs | 21 +- .../src/aggregator/collection_job_tests.rs | 4 +- aggregator/src/aggregator/http_handlers.rs | 146 +++++---- aggregator/src/bin/janus_cli.rs | 16 +- aggregator_api/src/models.rs | 28 +- aggregator_api/src/routes.rs | 10 +- aggregator_api/src/tests.rs | 30 +- aggregator_core/src/datastore.rs | 290 +++++------------- aggregator_core/src/datastore/tests.rs | 82 +++++ aggregator_core/src/task.rs | 205 +++++++------ aggregator_core/src/taskprov.rs | 4 +- db/00000000000001_initial_schema.down.sql | 2 - db/00000000000001_initial_schema.up.sql | 40 +-- docs/samples/tasks.yaml | 30 +- integration_tests/tests/common/mod.rs | 7 +- .../src/bin/janus_interop_aggregator.rs | 18 +- interop_binaries/src/lib.rs | 4 +- 20 files changed, 486 insertions(+), 582 deletions(-) diff --git a/aggregator/src/aggregator/aggregate_init_tests.rs b/aggregator/src/aggregator/aggregate_init_tests.rs index 
a25ed9874..3487c0ba1 100644 --- a/aggregator/src/aggregator/aggregate_init_tests.rs +++ b/aggregator/src/aggregator/aggregate_init_tests.rs @@ -169,6 +169,7 @@ async fn setup_aggregate_init_test_for_vdaf< vdaf_instance, aggregation_param, measurement, + AuthenticationToken::Bearer(random()), ) .await; @@ -203,10 +204,13 @@ async fn setup_aggregate_init_test_without_sending_request< vdaf_instance: VdafInstance, aggregation_param: V::AggregationParam, measurement: V::Measurement, + auth_token: AuthenticationToken, ) -> AggregationJobInitTestCase { install_test_trace_subscriber(); - let task = TaskBuilder::new(QueryType::TimeInterval, vdaf_instance, Role::Helper).build(); + let task = TaskBuilder::new(QueryType::TimeInterval, vdaf_instance, Role::Helper) + .with_aggregator_auth_token(Some(auth_token)) + .build(); let clock = MockClock::default(); let ephemeral_datastore = ephemeral_datastore().await; let datastore = Arc::new(ephemeral_datastore.datastore(clock.clone()).await); @@ -258,11 +262,12 @@ pub(crate) async fn put_aggregation_job( aggregation_job: &AggregationJobInitializeReq, handler: &impl Handler, ) -> TestConn { + let (header, value) = task + .aggregator_auth_token() + .unwrap() + .request_authentication(); put(task.aggregation_job_uri(aggregation_job_id).unwrap().path()) - .with_request_header( - DAP_AUTH_HEADER, - task.primary_aggregator_auth_token().as_ref().to_owned(), - ) + .with_request_header(header, value) .with_request_header( KnownHeaderName::ContentType, AggregationJobInitializeReq::::MEDIA_TYPE, @@ -279,14 +284,13 @@ async fn aggregation_job_init_authorization_dap_auth_token() { VdafInstance::Fake, dummy_vdaf::AggregationParam(0), (), + AuthenticationToken::DapAuth(random()), ) .await; - // Find a DapAuthToken among the task's aggregator auth tokens + let (auth_header, auth_value) = test_case .task - .aggregator_auth_tokens() - .iter() - .find(|auth| matches!(auth, AuthenticationToken::DapAuth(_))) + .aggregator_auth_token() .unwrap() 
.request_authentication(); @@ -317,6 +321,7 @@ async fn aggregation_job_init_malformed_authorization_header(#[case] header_valu VdafInstance::Fake, dummy_vdaf::AggregationParam(0), (), + AuthenticationToken::Bearer(random()), ) .await; @@ -333,7 +338,8 @@ async fn aggregation_job_init_malformed_authorization_header(#[case] header_valu DAP_AUTH_HEADER, test_case .task - .primary_aggregator_auth_token() + .aggregator_auth_token() + .unwrap() .as_ref() .to_owned(), ) @@ -480,19 +486,18 @@ async fn aggregation_job_init_wrong_query() { test_case.prepare_inits, ); + let (header, value) = test_case + .task + .aggregator_auth_token() + .unwrap() + .request_authentication(); + let mut response = put(test_case .task .aggregation_job_uri(&random()) .unwrap() .path()) - .with_request_header( - DAP_AUTH_HEADER, - test_case - .task - .primary_aggregator_auth_token() - .as_ref() - .to_owned(), - ) + .with_request_header(header, value) .with_request_header( KnownHeaderName::ContentType, AggregationJobInitializeReq::::MEDIA_TYPE, diff --git a/aggregator/src/aggregator/aggregation_job_continue.rs b/aggregator/src/aggregator/aggregation_job_continue.rs index 284b776cc..aff1a5e5c 100644 --- a/aggregator/src/aggregator/aggregation_job_continue.rs +++ b/aggregator/src/aggregator/aggregation_job_continue.rs @@ -302,11 +302,12 @@ pub mod test_util { request: &AggregationJobContinueReq, handler: &impl Handler, ) -> TestConn { + let (header, value) = task + .aggregator_auth_token() + .unwrap() + .request_authentication(); post(task.aggregation_job_uri(aggregation_job_id).unwrap().path()) - .with_request_header( - "DAP-Auth-Token", - task.primary_aggregator_auth_token().as_ref().to_owned(), - ) + .with_request_header(header, value) .with_request_header( KnownHeaderName::ContentType, AggregationJobContinueReq::MEDIA_TYPE, diff --git a/aggregator/src/aggregator/aggregation_job_driver.rs b/aggregator/src/aggregator/aggregation_job_driver.rs index 1fa4c414f..af0e8730a 100644 --- 
a/aggregator/src/aggregator/aggregation_job_driver.rs +++ b/aggregator/src/aggregator/aggregation_job_driver.rs @@ -417,7 +417,10 @@ impl AggregationJobDriver { AGGREGATION_JOB_ROUTE, AggregationJobInitializeReq::::MEDIA_TYPE, req, - task.primary_aggregator_auth_token(), + // The only way a task wouldn't have an aggregator auth token in it is in the taskprov + // case, and Janus never acts as the leader with taskprov enabled. + task.aggregator_auth_token() + .ok_or_else(|| anyhow!("task has no aggregator auth token"))?, &self.http_request_duration_histogram, ) .await?; @@ -507,7 +510,10 @@ impl AggregationJobDriver { AGGREGATION_JOB_ROUTE, AggregationJobContinueReq::MEDIA_TYPE, req, - task.primary_aggregator_auth_token(), + // The only way a task wouldn't have an aggregator auth token in it is in the taskprov + // case, and Janus never acts as the leader with taskprov enabled. + task.aggregator_auth_token() + .ok_or_else(|| anyhow!("task has no aggregator auth token"))?, &self.http_request_duration_histogram, ) .await?; @@ -950,7 +956,7 @@ mod tests { }, }; use rand::random; - use std::{borrow::Borrow, str, sync::Arc, time::Duration as StdDuration}; + use std::{borrow::Borrow, sync::Arc, time::Duration as StdDuration}; use trillium_tokio::Stopper; #[tokio::test] @@ -997,7 +1003,7 @@ mod tests { &measurement, ); - let agg_auth_token = task.primary_aggregator_auth_token().clone(); + let agg_auth_token = task.aggregator_auth_token().unwrap().clone(); let helper_hpke_keypair = generate_test_hpke_config_and_private_key(); let report = generate_report::>( *task.id(), @@ -1115,6 +1121,7 @@ mod tests { ]); let mocked_aggregates = join_all(helper_responses.iter().map( |(req_method, req_content_type, resp_content_type, resp_body)| { + let (header, value) = agg_auth_token.request_authentication(); server .mock( req_method, @@ -1122,10 +1129,7 @@ mod tests { .unwrap() .path(), ) - .match_header( - "DAP-Auth-Token", - str::from_utf8(agg_auth_token.as_ref()).unwrap(), - ) + 
.match_header(header, value.as_str()) .match_header(CONTENT_TYPE.as_str(), *req_content_type) .with_status(200) .with_header(CONTENT_TYPE.as_str(), resp_content_type) @@ -1293,7 +1297,7 @@ mod tests { &0, ); - let agg_auth_token = task.primary_aggregator_auth_token(); + let agg_auth_token = task.aggregator_auth_token().unwrap(); let helper_hpke_keypair = generate_test_hpke_config_and_private_key(); let report = generate_report::( *task.id(), @@ -1440,6 +1444,7 @@ mod tests { .with_body("{\"type\": \"urn:ietf:params:ppm:dap:error:unauthorizedRequest\"}") .create_async() .await; + let (header, value) = agg_auth_token.request_authentication(); let mocked_aggregate_success = server .mock( "PUT", @@ -1447,10 +1452,7 @@ mod tests { .unwrap() .path(), ) - .match_header( - "DAP-Auth-Token", - str::from_utf8(agg_auth_token.as_ref()).unwrap(), - ) + .match_header(header, value.as_str()) .match_header( CONTENT_TYPE.as_str(), AggregationJobInitializeReq::::MEDIA_TYPE, @@ -1659,7 +1661,7 @@ mod tests { &measurement, ); - let agg_auth_token = task.primary_aggregator_auth_token(); + let agg_auth_token = task.aggregator_auth_token().unwrap(); let helper_hpke_keypair = generate_test_hpke_config_and_private_key(); let report = generate_report::>( *task.id(), @@ -1761,6 +1763,7 @@ mod tests { message: transcript.helper_prepare_transitions[0].message.clone(), }, )])); + let (header, value) = agg_auth_token.request_authentication(); let mocked_aggregate_success = server .mock( "PUT", @@ -1768,10 +1771,7 @@ mod tests { .unwrap() .path(), ) - .match_header( - "DAP-Auth-Token", - str::from_utf8(agg_auth_token.as_ref()).unwrap(), - ) + .match_header(header, value.as_str()) .match_header( CONTENT_TYPE.as_str(), AggregationJobInitializeReq::::MEDIA_TYPE, @@ -1912,7 +1912,7 @@ mod tests { &0, ); - let agg_auth_token = task.primary_aggregator_auth_token(); + let agg_auth_token = task.aggregator_auth_token().unwrap(); let helper_hpke_keypair = generate_test_hpke_config_and_private_key(); let 
report = generate_report::( *task.id(), @@ -2017,6 +2017,7 @@ mod tests { .with_body("{\"type\": \"urn:ietf:params:ppm:dap:error:unauthorizedRequest\"}") .create_async() .await; + let (header, value) = agg_auth_token.request_authentication(); let mocked_aggregate_success = server .mock( "PUT", @@ -2024,10 +2025,7 @@ mod tests { .unwrap() .path(), ) - .match_header( - "DAP-Auth-Token", - str::from_utf8(agg_auth_token.as_ref()).unwrap(), - ) + .match_header(header, value.as_str()) .match_header( CONTENT_TYPE.as_str(), AggregationJobInitializeReq::::MEDIA_TYPE, @@ -2169,7 +2167,7 @@ mod tests { &measurement, ); - let agg_auth_token = task.primary_aggregator_auth_token(); + let agg_auth_token = task.aggregator_auth_token().unwrap(); let helper_hpke_keypair = generate_test_hpke_config_and_private_key(); let report = generate_report::>( *task.id(), @@ -2272,6 +2270,7 @@ mod tests { message: transcript.helper_prepare_transitions[0].message.clone(), }, )])); + let (header, value) = agg_auth_token.request_authentication(); let mocked_aggregate_success = server .mock( "PUT", @@ -2279,10 +2278,7 @@ mod tests { .unwrap() .path(), ) - .match_header( - "DAP-Auth-Token", - str::from_utf8(agg_auth_token.as_ref()).unwrap(), - ) + .match_header(header, value.as_str()) .match_header( CONTENT_TYPE.as_str(), AggregationJobInitializeReq::::MEDIA_TYPE, @@ -2436,7 +2432,7 @@ mod tests { &IdpfInput::from_bools(&[true]), ); - let agg_auth_token = task.primary_aggregator_auth_token(); + let agg_auth_token = task.aggregator_auth_token().unwrap(); let helper_hpke_keypair = generate_test_hpke_config_and_private_key(); let report = generate_report::>( *task.id(), @@ -2584,6 +2580,7 @@ mod tests { .with_body("{\"type\": \"urn:ietf:params:ppm:dap:error:unrecognizedTask\"}") .create_async() .await; + let (header, value) = agg_auth_token.request_authentication(); let mocked_aggregate_success = server .mock( "POST", @@ -2591,10 +2588,7 @@ mod tests { .unwrap() .path(), ) - .match_header( - 
"DAP-Auth-Token", - str::from_utf8(agg_auth_token.as_ref()).unwrap(), - ) + .match_header(header, value.as_str()) .match_header(CONTENT_TYPE.as_str(), AggregationJobContinueReq::MEDIA_TYPE) .match_body(leader_request.get_encoded()) .with_status(200) @@ -2840,7 +2834,7 @@ mod tests { &IdpfInput::from_bools(&[true]), ); - let agg_auth_token = task.primary_aggregator_auth_token(); + let agg_auth_token = task.aggregator_auth_token().unwrap(); let helper_hpke_keypair = generate_test_hpke_config_and_private_key(); let report = generate_report::>( *task.id(), @@ -2973,6 +2967,7 @@ mod tests { .with_body("{\"type\": \"urn:ietf:params:ppm:dap:error:unrecognizedTask\"}") .create_async() .await; + let (header, value) = agg_auth_token.request_authentication(); let mocked_aggregate_success = server .mock( "POST", @@ -2980,10 +2975,7 @@ mod tests { .unwrap() .path(), ) - .match_header( - "DAP-Auth-Token", - str::from_utf8(agg_auth_token.as_ref()).unwrap(), - ) + .match_header(header, value.as_str()) .match_header(CONTENT_TYPE.as_str(), AggregationJobContinueReq::MEDIA_TYPE) .match_body(leader_request.get_encoded()) .with_status(200) @@ -3366,7 +3358,7 @@ mod tests { ) .with_helper_aggregator_endpoint(server.url().parse().unwrap()) .build(); - let agg_auth_token = task.primary_aggregator_auth_token(); + let agg_auth_token = task.aggregator_auth_token().unwrap(); let aggregation_job_id = random(); let verify_key: VerifyKey = task.vdaf_verify_key().unwrap(); @@ -3474,6 +3466,7 @@ mod tests { // Set up three error responses from our mock helper. These will cause errors in the // leader, because the response body is empty and cannot be decoded. 
+ let (header, value) = agg_auth_token.request_authentication(); let failure_mock = server .mock( "PUT", @@ -3481,10 +3474,7 @@ mod tests { .unwrap() .path(), ) - .match_header( - "DAP-Auth-Token", - str::from_utf8(agg_auth_token.as_ref()).unwrap(), - ) + .match_header(header, value.as_str()) .match_header( CONTENT_TYPE.as_str(), AggregationJobInitializeReq::::MEDIA_TYPE, @@ -3503,10 +3493,7 @@ mod tests { .unwrap() .path(), ) - .match_header( - "DAP-Auth-Token", - str::from_utf8(agg_auth_token.as_ref()).unwrap(), - ) + .match_header(header, value.as_str()) .match_header( CONTENT_TYPE.as_str(), AggregationJobInitializeReq::::MEDIA_TYPE, diff --git a/aggregator/src/aggregator/collection_job_driver.rs b/aggregator/src/aggregator/collection_job_driver.rs index c12a470a6..50b8bab7e 100644 --- a/aggregator/src/aggregator/collection_job_driver.rs +++ b/aggregator/src/aggregator/collection_job_driver.rs @@ -228,7 +228,10 @@ impl CollectionJobDriver { AGGREGATE_SHARES_ROUTE, AggregateShareReq::::MEDIA_TYPE, req, - task.primary_aggregator_auth_token(), + // The only way a task wouldn't have an aggregator auth token in it is in the taskprov + // case, and Janus never acts as the leader with taskprov enabled. 
+ task.aggregator_auth_token() + .ok_or_else(|| Error::InvalidConfiguration("no aggregator auth token in task"))?, &self.metrics.http_request_duration_histogram, ) .await?; @@ -558,7 +561,7 @@ mod tests { }; use prio::codec::{Decode, Encode}; use rand::random; - use std::{str, sync::Arc, time::Duration as StdDuration}; + use std::{sync::Arc, time::Duration as StdDuration}; use trillium_tokio::Stopper; async fn setup_collection_job_test_case( @@ -716,7 +719,7 @@ mod tests { .with_time_precision(time_precision) .with_min_batch_size(10) .build(); - let agg_auth_token = task.primary_aggregator_auth_token(); + let agg_auth_token = task.aggregator_auth_token().unwrap(); let batch_interval = Interval::new(clock.now(), Duration::from_seconds(2000)).unwrap(); let aggregation_param = AggregationParam(0); let report_timestamp = clock @@ -874,12 +877,10 @@ mod tests { ); // Simulate helper failing to service the aggregate share request. + let (header, value) = agg_auth_token.request_authentication(); let mocked_failed_aggregate_share = server .mock("POST", task.aggregate_shares_uri().unwrap().path()) - .match_header( - "DAP-Auth-Token", - str::from_utf8(agg_auth_token.as_ref()).unwrap(), - ) + .match_header(header, value.as_str()) .match_header( CONTENT_TYPE.as_str(), AggregateShareReq::::MEDIA_TYPE, @@ -935,12 +936,10 @@ mod tests { Vec::new(), )); + let (header, value) = agg_auth_token.request_authentication(); let mocked_aggregate_share = server .mock("POST", task.aggregate_shares_uri().unwrap().path()) - .match_header( - "DAP-Auth-Token", - str::from_utf8(agg_auth_token.as_ref()).unwrap(), - ) + .match_header(header, value.as_str()) .match_header( CONTENT_TYPE.as_str(), AggregateShareReq::::MEDIA_TYPE, diff --git a/aggregator/src/aggregator/collection_job_tests.rs b/aggregator/src/aggregator/collection_job_tests.rs index cdb4d862d..a190200dd 100644 --- a/aggregator/src/aggregator/collection_job_tests.rs +++ b/aggregator/src/aggregator/collection_job_tests.rs @@ -91,7 +91,7 
@@ impl CollectionJobTestCase { self.put_collection_job_with_auth_token( collection_job_id, request, - Some(self.task.primary_collector_auth_token()), + self.task.collector_auth_token(), ) .await } @@ -120,7 +120,7 @@ impl CollectionJobTestCase { ) -> TestConn { self.post_collection_job_with_auth_token( collection_job_id, - Some(self.task.primary_collector_auth_token()), + self.task.collector_auth_token(), ) .await } diff --git a/aggregator/src/aggregator/http_handlers.rs b/aggregator/src/aggregator/http_handlers.rs index f7fa63540..6581bf7a9 100644 --- a/aggregator/src/aggregator/http_handlers.rs +++ b/aggregator/src/aggregator/http_handlers.rs @@ -1435,11 +1435,14 @@ mod tests { async fn aggregate_wrong_agg_auth_token() { let (_, _ephemeral_datastore, datastore, handler) = setup_http_handler_test().await; + let dap_auth_token = AuthenticationToken::DapAuth(random()); + let task = TaskBuilder::new( QueryType::TimeInterval, VdafInstance::Prio3Count, Role::Helper, ) + .with_aggregator_auth_token(Some(dap_auth_token.clone())) .build(); datastore.put_task(&task).await.unwrap(); @@ -1452,14 +1455,10 @@ mod tests { let wrong_token_value = random(); - // Send the right token, but the wrong format: we find a DapAuth token in the task's - // aggregator tokens and convert it to an equivalent Bearer token, which should be rejected. - let wrong_token_format = task - .aggregator_auth_tokens() - .iter() - .find(|token| matches!(token, AuthenticationToken::DapAuth(_))) - .map(|token| AuthenticationToken::new_bearer_token_from_bytes(token.as_ref()).unwrap()) - .unwrap(); + // Send the right token, but the wrong format: convert the DAP auth token to an equivalent + // Bearer token, which should be rejected. 
+ let wrong_token_format = + AuthenticationToken::new_bearer_token_from_bytes(dap_auth_token.as_ref()).unwrap(); for auth_token in [Some(wrong_token_value), Some(wrong_token_format), None] { let mut test_conn = put(task @@ -4158,11 +4157,12 @@ mod tests { dummy_vdaf::AggregationParam::default().get_encoded(), ); + let (header, value) = task + .collector_auth_token() + .unwrap() + .request_authentication(); let mut test_conn = put(task.collection_job_uri(&collection_job_id).unwrap().path()) - .with_request_header( - "DAP-Auth-Token", - task.primary_collector_auth_token().as_ref().to_owned(), - ) + .with_request_header(header, value) .with_request_header( KnownHeaderName::ContentType, CollectionReq::::MEDIA_TYPE, @@ -4221,7 +4221,7 @@ mod tests { .put_collection_job_with_auth_token( &collection_job_id, &req, - Some(test_case.task.primary_aggregator_auth_token()), + test_case.task.aggregator_auth_token(), ) .await; @@ -4298,7 +4298,7 @@ mod tests { let mut test_conn = test_case .post_collection_job_with_auth_token( &collection_job_id, - Some(test_case.task.primary_aggregator_auth_token()), + test_case.task.aggregator_auth_token(), ) .await; @@ -4526,18 +4526,16 @@ mod tests { let no_such_collection_job_id: CollectionJobId = random(); + let (header, value) = test_case + .task + .collector_auth_token() + .unwrap() + .request_authentication(); let test_conn = post(&format!( "/tasks/{}/collection_jobs/{no_such_collection_job_id}", test_case.task.id() )) - .with_request_header( - "DAP-Auth-Token", - test_case - .task - .primary_collector_auth_token() - .as_ref() - .to_owned(), - ) + .with_request_header(header, value) .run_async(&test_case.handler) .await; assert_eq!(test_conn.status(), Some(Status::NotFound)); @@ -4694,6 +4692,12 @@ mod tests { let collection_job_id: CollectionJobId = random(); + let (header, value) = test_case + .task + .collector_auth_token() + .unwrap() + .request_authentication(); + // Try to delete a collection job that doesn't exist let test_conn = 
delete( test_case @@ -4702,14 +4706,7 @@ mod tests { .unwrap() .path(), ) - .with_request_header( - "DAP-Auth-Token", - test_case - .task - .primary_collector_auth_token() - .as_ref() - .to_owned(), - ) + .with_request_header(header, value.clone()) .run_async(&test_case.handler) .await; assert_eq!(test_conn.status(), Some(Status::NotFound)); @@ -4734,14 +4731,7 @@ mod tests { .unwrap() .path(), ) - .with_request_header( - "DAP-Auth-Token", - test_case - .task - .primary_collector_auth_token() - .as_ref() - .to_owned(), - ) + .with_request_header(header, value) .run_async(&test_case.handler) .await; assert_eq!(test_conn.status(), Some(Status::NoContent)); @@ -4769,11 +4759,13 @@ mod tests { ReportIdChecksum::default(), ); + let (header, value) = task + .aggregator_auth_token() + .unwrap() + .request_authentication(); + let mut test_conn = post(task.aggregate_shares_uri().unwrap().path()) - .with_request_header( - "DAP-Auth-Token", - task.primary_aggregator_auth_token().as_ref().to_owned(), - ) + .with_request_header(header, value) .with_request_header( KnownHeaderName::ContentType, AggregateShareReq::::MEDIA_TYPE, @@ -4819,13 +4811,15 @@ mod tests { ReportIdChecksum::default(), ); + let (header, value) = task + .aggregator_auth_token() + .unwrap() + .request_authentication(); + // Test that a request for an invalid batch fails. (Specifically, the batch interval is too // small.) let mut test_conn = post(task.aggregate_shares_uri().unwrap().path()) - .with_request_header( - "DAP-Auth-Token", - task.primary_aggregator_auth_token().as_ref().to_owned(), - ) + .with_request_header(header, value.clone()) .with_request_header( KnownHeaderName::ContentType, AggregateShareReq::::MEDIA_TYPE, @@ -4847,10 +4841,7 @@ mod tests { // Test that a request for a too-old batch fails. 
let test_conn = post(task.aggregate_shares_uri().unwrap().path()) - .with_request_header( - "DAP-Auth-Token", - task.primary_aggregator_auth_token().as_ref().to_owned(), - ) + .with_request_header(header, value) .with_request_header( KnownHeaderName::ContentType, AggregateShareReq::::MEDIA_TYPE, @@ -4896,11 +4887,13 @@ mod tests { ReportIdChecksum::default(), ); + let (header, value) = task + .aggregator_auth_token() + .unwrap() + .request_authentication(); + let mut test_conn = post(task.aggregate_shares_uri().unwrap().path()) - .with_request_header( - "DAP-Auth-Token", - task.primary_aggregator_auth_token().as_ref().to_owned(), - ) + .with_request_header(header, value) .with_request_header( KnownHeaderName::ContentType, AggregateShareReq::::MEDIA_TYPE, @@ -5081,11 +5074,12 @@ mod tests { 5, ReportIdChecksum::default(), ); + let (header, value) = task + .aggregator_auth_token() + .unwrap() + .request_authentication(); let mut test_conn = post(task.aggregate_shares_uri().unwrap().path()) - .with_request_header( - "DAP-Auth-Token", - task.primary_aggregator_auth_token().as_ref().to_owned(), - ) + .with_request_header(header, value) .with_request_header( KnownHeaderName::ContentType, AggregateShareReq::::MEDIA_TYPE, @@ -5134,11 +5128,12 @@ mod tests { ReportIdChecksum::get_decoded(&[4 ^ 8; 32]).unwrap(), ), ] { + let (header, value) = task + .aggregator_auth_token() + .unwrap() + .request_authentication(); let mut test_conn = post(task.aggregate_shares_uri().unwrap().path()) - .with_request_header( - "DAP-Auth-Token", - task.primary_aggregator_auth_token().as_ref().to_owned(), - ) + .with_request_header(header, value) .with_request_header( KnownHeaderName::ContentType, AggregateShareReq::::MEDIA_TYPE, @@ -5199,11 +5194,12 @@ mod tests { // Request the aggregate share multiple times. If the request parameters don't change, // then there is no query count violation and all requests should succeed. 
for iteration in 0..3 { + let (header, value) = task + .aggregator_auth_token() + .unwrap() + .request_authentication(); let mut test_conn = post(task.aggregate_shares_uri().unwrap().path()) - .with_request_header( - "DAP-Auth-Token", - task.primary_aggregator_auth_token().as_ref().to_owned(), - ) + .with_request_header(header, value) .with_request_header( KnownHeaderName::ContentType, AggregateShareReq::::MEDIA_TYPE, @@ -5266,11 +5262,12 @@ mod tests { 20, ReportIdChecksum::get_decoded(&[8 ^ 4 ^ 3 ^ 2; 32]).unwrap(), ); + let (header, value) = task + .aggregator_auth_token() + .unwrap() + .request_authentication(); let mut test_conn = post(task.aggregate_shares_uri().unwrap().path()) - .with_request_header( - "DAP-Auth-Token", - task.primary_aggregator_auth_token().as_ref().to_owned(), - ) + .with_request_header(header, value) .with_request_header( KnownHeaderName::ContentType, AggregateShareReq::::MEDIA_TYPE, @@ -5317,11 +5314,12 @@ mod tests { ReportIdChecksum::get_decoded(&[4 ^ 8; 32]).unwrap(), ), ] { + let (header, value) = task + .aggregator_auth_token() + .unwrap() + .request_authentication(); let mut test_conn = post(task.aggregate_shares_uri().unwrap().path()) - .with_request_header( - "DAP-Auth-Token", - task.primary_aggregator_auth_token().as_ref().to_owned(), - ) + .with_request_header(header, value) .with_request_header( KnownHeaderName::ContentType, AggregateShareReq::::MEDIA_TYPE, diff --git a/aggregator/src/bin/janus_cli.rs b/aggregator/src/bin/janus_cli.rs index 3ebf692cb..a5ac05b41 100644 --- a/aggregator/src/bin/janus_cli.rs +++ b/aggregator/src/bin/janus_cli.rs @@ -718,7 +718,7 @@ mod tests { vdaf: !Prio3Sum bits: 2 role: Leader - vdaf_verify_keys: + vdaf_verify_key: max_batch_query_count: 1 task_expiration: 9000000000 min_batch_size: 10 @@ -731,8 +731,8 @@ mod tests { kdf_id: HkdfSha256 aead_id: Aes128Gcm public_key: 8lAqZ7OfNV2Gi_9cNE6J9WRmPbO-k1UPtu2Bztd0-yc - aggregator_auth_tokens: [] - collector_auth_tokens: [] + aggregator_auth_token: + 
collector_auth_token: hpke_keys: [] - leader_aggregator_endpoint: https://leader helper_aggregator_endpoint: https://helper @@ -740,7 +740,7 @@ mod tests { vdaf: !Prio3Sum bits: 2 role: Helper - vdaf_verify_keys: + vdaf_verify_key: max_batch_query_count: 1 task_expiration: 9000000000 min_batch_size: 10 @@ -753,8 +753,8 @@ mod tests { kdf_id: HkdfSha256 aead_id: Aes128Gcm public_key: 8lAqZ7OfNV2Gi_9cNE6J9WRmPbO-k1UPtu2Bztd0-yc - aggregator_auth_tokens: [] - collector_auth_tokens: [] + aggregator_auth_token: + collector_auth_token: hpke_keys: [] "#; @@ -800,8 +800,8 @@ mod tests { for task in &got_tasks { match task.role() { - Role::Leader => assert_eq!(task.collector_auth_tokens().len(), 1), - Role::Helper => assert!(task.collector_auth_tokens().is_empty()), + Role::Leader => assert!(task.collector_auth_token().is_some()), + Role::Helper => assert!(task.collector_auth_token().is_none()), role => panic!("unexpected role {role}"), } } diff --git a/aggregator_api/src/models.rs b/aggregator_api/src/models.rs index 7415a0e6c..d7e1625af 100644 --- a/aggregator_api/src/models.rs +++ b/aggregator_api/src/models.rs @@ -108,13 +108,12 @@ pub(crate) struct TaskResp { /// How much clock skew to allow between client and aggregator. Reports from /// farther than this duration into the future will be rejected. pub(crate) tolerable_clock_skew: Duration, - /// The authentication token for inter-aggregator communication in this task. - /// If `role` is Leader, this token is used by the aggregator to authenticate requests to - /// the Helper. If `role` is Helper, this token is used by the aggregator to authenticate - /// requests from the Leader. + /// The authentication token for inter-aggregator communication in this task. If `role` is + /// Helper, this token is used by the aggregator to authenticate requests from the Leader. Not + /// set if `role` is Leader.. 
// TODO(#1509): This field will have to change as Janus helpers will only store a salted // hash of aggregator auth tokens. - pub(crate) aggregator_auth_token: AuthenticationToken, + pub(crate) aggregator_auth_token: Option, /// The authentication token used by the task's Collector to authenticate to the Leader. /// `Some` if `role` is Leader, `None` otherwise. // TODO(#1509) This field will have to change as Janus leaders will only store a salted hash @@ -143,21 +142,6 @@ impl TryFrom<&Task> for TaskResp { } .clone(); - if task.aggregator_auth_tokens().len() != 1 { - return Err("illegal number of aggregator auth tokens in task"); - } - - let collector_auth_token = match task.role() { - Role::Leader => { - if task.collector_auth_tokens().len() != 1 { - return Err("illegal number of collector auth tokens in task"); - } - Some(task.primary_collector_auth_token().clone()) - } - Role::Helper => None, - _ => return Err("illegal aggregator role in task"), - }; - let mut aggregator_hpke_configs: Vec<_> = task .hpke_keys() .values() @@ -178,8 +162,8 @@ impl TryFrom<&Task> for TaskResp { min_batch_size: task.min_batch_size(), time_precision: *task.time_precision(), tolerable_clock_skew: *task.tolerable_clock_skew(), - aggregator_auth_token: task.primary_aggregator_auth_token().clone(), - collector_auth_token, + aggregator_auth_token: task.aggregator_auth_token().cloned(), + collector_auth_token: task.collector_auth_token().cloned(), collector_hpke_config: task .collector_hpke_config() .ok_or("collector_hpke_config is required")? 
diff --git a/aggregator_api/src/routes.rs b/aggregator_api/src/routes.rs index 5265aa6f5..0d93586be 100644 --- a/aggregator_api/src/routes.rs +++ b/aggregator_api/src/routes.rs @@ -121,7 +121,7 @@ pub(super) async fn post_task( let vdaf_verify_key = SecretBytes::new(vdaf_verify_key_bytes); - let (aggregator_auth_tokens, collector_auth_tokens) = match req.role { + let (aggregator_auth_token, collector_auth_token) = match req.role { Role::Leader => { let aggregator_auth_token = req.aggregator_auth_token.ok_or_else(|| { Error::BadRequest( @@ -129,7 +129,7 @@ pub(super) async fn post_task( .to_string(), ) })?; - (Vec::from([aggregator_auth_token]), Vec::from([random()])) + (Some(aggregator_auth_token), Some(random())) } Role::Helper => { @@ -140,7 +140,7 @@ pub(super) async fn post_task( )); } - (Vec::from([random()]), Vec::new()) + (Some(random()), None) } _ => unreachable!(), @@ -173,8 +173,8 @@ pub(super) async fn post_task( /* tolerable_clock_skew */ Duration::from_seconds(60), // 1 minute, /* collector_hpke_config */ req.collector_hpke_config, - aggregator_auth_tokens, - collector_auth_tokens, + aggregator_auth_token, + collector_auth_token, hpke_keys, ) .map_err(|err| Error::BadRequest(format!("Error constructing task: {err}")))?, diff --git a/aggregator_api/src/tests.rs b/aggregator_api/src/tests.rs index 8e6c47d57..3b7e7d1b2 100644 --- a/aggregator_api/src/tests.rs +++ b/aggregator_api/src/tests.rs @@ -334,7 +334,8 @@ async fn post_task_helper_no_optional_fields() { assert_eq!(req.task_expiration.as_ref(), got_task.task_expiration()); assert_eq!(req.min_batch_size, got_task.min_batch_size()); assert_eq!(&req.time_precision, got_task.time_precision()); - assert_eq!(1, got_task.aggregator_auth_tokens().len()); + assert!(got_task.aggregator_auth_token().is_some()); + assert!(got_task.collector_auth_token().is_none()); assert_eq!( &req.collector_hpke_config, got_task.collector_hpke_config().unwrap() @@ -543,12 +544,11 @@ async fn 
post_task_leader_all_optional_fields() { &req.collector_hpke_config, got_task.collector_hpke_config().unwrap() ); - assert_eq!(1, got_task.aggregator_auth_tokens().len()); assert_eq!( aggregator_auth_token.as_ref(), - got_task.aggregator_auth_tokens()[0].as_ref() + got_task.aggregator_auth_token().unwrap().as_ref() ); - assert_eq!(1, got_task.collector_auth_tokens().len()); + assert!(got_task.collector_auth_token().is_some()); // ...and the response. assert_eq!(got_task_resp, TaskResp::try_from(&got_task).unwrap()); @@ -602,10 +602,7 @@ async fn get_task() { // Setup: write a task to the datastore. let (handler, _ephemeral_datastore, ds) = setup_api_test().await; - let task = TaskBuilder::new(QueryType::TimeInterval, VdafInstance::Fake, Role::Leader) - .with_aggregator_auth_tokens(Vec::from([random()])) - .with_collector_auth_tokens(Vec::from([random()])) - .build(); + let task = TaskBuilder::new(QueryType::TimeInterval, VdafInstance::Fake, Role::Leader).build(); ds.run_tx(|tx| { let task = task.clone(); @@ -1758,14 +1755,14 @@ fn task_resp_serialization() { HpkeAeadId::Aes128Gcm, HpkePublicKey::from([0u8; 32].to_vec()), ), - Vec::from([AuthenticationToken::new_dap_auth_token_from_string( - "Y29sbGVjdG9yLWFiY2RlZjAw", - ) - .unwrap()]), - Vec::from([AuthenticationToken::new_dap_auth_token_from_string( - "Y29sbGVjdG9yLWFiY2RlZjAw", - ) - .unwrap()]), + Some( + AuthenticationToken::new_dap_auth_token_from_string("Y29sbGVjdG9yLWFiY2RlZjAw") + .unwrap(), + ), + Some( + AuthenticationToken::new_dap_auth_token_from_string("Y29sbGVjdG9yLWFiY2RlZjAw") + .unwrap(), + ), [(HpkeKeypair::new( HpkeConfig::new( HpkeConfigId::from(13), @@ -1831,6 +1828,7 @@ fn task_resp_serialization() { Token::NewtypeStruct { name: "Duration" }, Token::U64(60), Token::Str("aggregator_auth_token"), + Token::Some, Token::Struct { name: "AuthenticationToken", len: 2, diff --git a/aggregator_core/src/datastore.rs b/aggregator_core/src/datastore.rs index 480f8d7e7..25b9eecf5 100644 --- 
a/aggregator_core/src/datastore.rs +++ b/aggregator_core/src/datastore.rs @@ -533,8 +533,12 @@ impl Transaction<'_, C> { task_id, aggregator_role, leader_aggregator_endpoint, helper_aggregator_endpoint, query_type, vdaf, max_batch_query_count, task_expiration, report_expiry_age, min_batch_size, time_precision, - tolerable_clock_skew, collector_hpke_config, vdaf_verify_key) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14) + tolerable_clock_skew, collector_hpke_config, vdaf_verify_key, + aggregator_auth_token_type, aggregator_auth_token, collector_auth_token_type, + collector_auth_token) + VALUES ( + $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18 + ) ON CONFLICT DO NOTHING", ) .await?; @@ -579,87 +583,43 @@ impl Transaction<'_, C> { "vdaf_verify_key", task.opaque_vdaf_verify_key().as_ref(), )?, + /* aggregator_auth_token_type */ + &task + .aggregator_auth_token() + .map(AuthenticationTokenType::from), + /* aggregator_auth_token */ + &task + .aggregator_auth_token() + .map(|token| { + self.crypter.encrypt( + "tasks", + task.id().as_ref(), + "aggregator_auth_token", + token.as_ref(), + ) + }) + .transpose()?, + /* collector_auth_token_type */ + &task + .collector_auth_token() + .map(AuthenticationTokenType::from), + /* collector_auth_token */ + &task + .collector_auth_token() + .map(|token| { + self.crypter.encrypt( + "tasks", + task.id().as_ref(), + "collector_auth_token", + token.as_ref(), + ) + }) + .transpose()?, ], ) .await?, )?; - // Aggregator auth tokens. 
- let mut aggregator_auth_token_ords = Vec::new(); - let mut aggregator_auth_token_types = Vec::new(); - let mut aggregator_auth_tokens = Vec::new(); - for (ord, token) in task.aggregator_auth_tokens().iter().enumerate() { - let ord = i64::try_from(ord)?; - - let mut row_id = [0; TaskId::LEN + size_of::()]; - row_id[..TaskId::LEN].copy_from_slice(task.id().as_ref()); - row_id[TaskId::LEN..].copy_from_slice(&ord.to_be_bytes()); - - let encrypted_aggregator_auth_token = self.crypter.encrypt( - "task_aggregator_auth_tokens", - &row_id, - "token", - token.as_ref(), - )?; - - aggregator_auth_token_ords.push(ord); - aggregator_auth_token_types.push(AuthenticationTokenType::from(token)); - aggregator_auth_tokens.push(encrypted_aggregator_auth_token); - } - let stmt = self - .prepare_cached( - "INSERT INTO task_aggregator_auth_tokens (task_id, ord, type, token) - SELECT - (SELECT id FROM tasks WHERE task_id = $1), - * FROM UNNEST($2::BIGINT[], $3::AUTH_TOKEN_TYPE[], $4::BYTEA[])", - ) - .await?; - let aggregator_auth_tokens_params: &[&(dyn ToSql + Sync)] = &[ - /* task_id */ &task.id().as_ref(), - /* ords */ &aggregator_auth_token_ords, - /* token_types */ &aggregator_auth_token_types, - /* tokens */ &aggregator_auth_tokens, - ]; - let aggregator_auth_tokens_future = self.execute(&stmt, aggregator_auth_tokens_params); - - // Collector auth tokens. 
- let mut collector_auth_token_ords = Vec::new(); - let mut collector_auth_token_types = Vec::new(); - let mut collector_auth_tokens = Vec::new(); - for (ord, token) in task.collector_auth_tokens().iter().enumerate() { - let ord = i64::try_from(ord)?; - - let mut row_id = [0; TaskId::LEN + size_of::()]; - row_id[..TaskId::LEN].copy_from_slice(task.id().as_ref()); - row_id[TaskId::LEN..].copy_from_slice(&ord.to_be_bytes()); - - let encrypted_collector_auth_token = self.crypter.encrypt( - "task_collector_auth_tokens", - &row_id, - "token", - token.as_ref(), - )?; - - collector_auth_token_ords.push(ord); - collector_auth_token_types.push(AuthenticationTokenType::from(token)); - collector_auth_tokens.push(encrypted_collector_auth_token); - } - let stmt = self - .prepare_cached( - "INSERT INTO task_collector_auth_tokens (task_id, ord, type, token) - SELECT - (SELECT id FROM tasks WHERE task_id = $1), - * FROM UNNEST($2::BIGINT[], $3::AUTH_TOKEN_TYPE[], $4::BYTEA[])", - ) - .await?; - let collector_auth_tokens_params: &[&(dyn ToSql + Sync)] = &[ - /* task_id */ &task.id().as_ref(), - /* ords */ &collector_auth_token_ords, - /* token_types */ &collector_auth_token_types, - /* tokens */ &collector_auth_tokens, - ]; - let collector_auth_tokens_future = self.execute(&stmt, collector_auth_tokens_params); - // HPKE keys. 
let mut hpke_config_ids: Vec = Vec::new(); let mut hpke_configs: Vec> = Vec::new(); @@ -695,13 +655,7 @@ impl Transaction<'_, C> { /* configs */ &hpke_configs, /* private_keys */ &hpke_private_keys, ]; - let hpke_configs_future = self.execute(&stmt, hpke_configs_params); - - try_join!( - aggregator_auth_tokens_future, - collector_auth_tokens_future, - hpke_configs_future, - )?; + self.execute(&stmt, hpke_configs_params).await?; Ok(()) } @@ -730,28 +684,13 @@ impl Transaction<'_, C> { "SELECT aggregator_role, leader_aggregator_endpoint, helper_aggregator_endpoint, query_type, vdaf, max_batch_query_count, task_expiration, report_expiry_age, min_batch_size, time_precision, tolerable_clock_skew, collector_hpke_config, - vdaf_verify_key + vdaf_verify_key, aggregator_auth_token_type, aggregator_auth_token, + collector_auth_token_type, collector_auth_token FROM tasks WHERE task_id = $1", ) .await?; let task_row = self.query_opt(&stmt, params); - let stmt = self - .prepare_cached( - "SELECT ord, type, token FROM task_aggregator_auth_tokens - WHERE task_id = (SELECT id FROM tasks WHERE task_id = $1) ORDER BY ord ASC", - ) - .await?; - let aggregator_auth_token_rows = self.query(&stmt, params); - - let stmt = self - .prepare_cached( - "SELECT ord, type, token FROM task_collector_auth_tokens - WHERE task_id = (SELECT id FROM tasks WHERE task_id = $1) ORDER BY ord ASC", - ) - .await?; - let collector_auth_token_rows = self.query(&stmt, params); - let stmt = self .prepare_cached( "SELECT config_id, config, private_key FROM task_hpke_keys @@ -760,22 +699,9 @@ impl Transaction<'_, C> { .await?; let hpke_key_rows = self.query(&stmt, params); - let (task_row, aggregator_auth_token_rows, collector_auth_token_rows, hpke_key_rows) = try_join!( - task_row, - aggregator_auth_token_rows, - collector_auth_token_rows, - hpke_key_rows, - )?; + let (task_row, hpke_key_rows) = try_join!(task_row, hpke_key_rows,)?; task_row - .map(|task_row| { - self.task_from_rows( - task_id, - &task_row, - 
&aggregator_auth_token_rows, - &collector_auth_token_rows, - &hpke_key_rows, - ) - }) + .map(|task_row| self.task_from_rows(task_id, &task_row, &hpke_key_rows)) .transpose() } @@ -787,30 +713,14 @@ impl Transaction<'_, C> { "SELECT task_id, aggregator_role, leader_aggregator_endpoint, helper_aggregator_endpoint, query_type, vdaf, max_batch_query_count, task_expiration, report_expiry_age, min_batch_size, time_precision, - tolerable_clock_skew, collector_hpke_config, vdaf_verify_key + tolerable_clock_skew, collector_hpke_config, vdaf_verify_key, + aggregator_auth_token_type, aggregator_auth_token, collector_auth_token_type, + collector_auth_token FROM tasks", ) .await?; let task_rows = self.query(&stmt, &[]); - let stmt = self - .prepare_cached( - "SELECT (SELECT tasks.task_id FROM tasks - WHERE tasks.id = task_aggregator_auth_tokens.task_id), - ord, type, token FROM task_aggregator_auth_tokens ORDER BY ord ASC", - ) - .await?; - let aggregator_auth_token_rows = self.query(&stmt, &[]); - - let stmt = self - .prepare_cached( - "SELECT (SELECT tasks.task_id FROM tasks - WHERE tasks.id = task_collector_auth_tokens.task_id), - ord, type, token FROM task_collector_auth_tokens ORDER BY ord ASC", - ) - .await?; - let collector_auth_token_rows = self.query(&stmt, &[]); - let stmt = self .prepare_cached( "SELECT (SELECT tasks.task_id FROM tasks WHERE tasks.id = task_hpke_keys.task_id), @@ -819,13 +729,7 @@ impl Transaction<'_, C> { .await?; let hpke_config_rows = self.query(&stmt, &[]); - let (task_rows, aggregator_auth_token_rows, collector_auth_token_rows, hpke_config_rows) = - try_join!( - task_rows, - aggregator_auth_token_rows, - collector_auth_token_rows, - hpke_config_rows, - )?; + let (task_rows, hpke_config_rows) = try_join!(task_rows, hpke_config_rows,)?; let mut task_row_by_id = Vec::new(); for row in task_rows { @@ -833,24 +737,6 @@ impl Transaction<'_, C> { task_row_by_id.push((task_id, row)); } - let mut aggregator_auth_token_rows_by_task_id: HashMap> = 
HashMap::new(); - for row in aggregator_auth_token_rows { - let task_id = TaskId::get_decoded(row.get("task_id"))?; - aggregator_auth_token_rows_by_task_id - .entry(task_id) - .or_default() - .push(row); - } - - let mut collector_auth_token_rows_by_task_id: HashMap> = HashMap::new(); - for row in collector_auth_token_rows { - let task_id = TaskId::get_decoded(row.get("task_id"))?; - collector_auth_token_rows_by_task_id - .entry(task_id) - .or_default() - .push(row); - } - let mut hpke_config_rows_by_task_id: HashMap> = HashMap::new(); for row in hpke_config_rows { let task_id = TaskId::get_decoded(row.get("task_id"))?; @@ -866,12 +752,6 @@ impl Transaction<'_, C> { self.task_from_rows( &task_id, &row, - &aggregator_auth_token_rows_by_task_id - .remove(&task_id) - .unwrap_or_default(), - &collector_auth_token_rows_by_task_id - .remove(&task_id) - .unwrap_or_default(), &hpke_config_rows_by_task_id .remove(&task_id) .unwrap_or_default(), @@ -880,16 +760,12 @@ impl Transaction<'_, C> { .collect::>() } - /// Construct a [`Task`] from the contents of the provided (tasks) `Row`, - /// `hpke_aggregator_auth_tokens` rows, and `task_hpke_keys` rows. - /// - /// agg_auth_token_rows must be sorted in ascending order by `ord`. + /// Construct a [`Task`] from the contents of the provided (tasks) `Row` and + /// `task_hpke_keys` rows. fn task_from_rows( &self, task_id: &TaskId, row: &Row, - aggregator_auth_token_rows: &[Row], - collector_auth_token_rows: &[Row], hpke_key_rows: &[Row], ) -> Result { // Scalar task parameters. @@ -927,47 +803,31 @@ impl Transaction<'_, C> { ) .map(SecretBytes::new)?; - // Aggregator authentication tokens. 
- let mut aggregator_auth_tokens = Vec::new(); - for row in aggregator_auth_token_rows { - let ord: i64 = row.get("ord"); - let auth_token_type: AuthenticationTokenType = row.get("type"); - let encrypted_aggregator_auth_token: Vec = row.get("token"); - - let mut row_id = [0u8; TaskId::LEN + size_of::()]; - row_id[..TaskId::LEN].copy_from_slice(task_id.as_ref()); - row_id[TaskId::LEN..].copy_from_slice(&ord.to_be_bytes()); - - aggregator_auth_tokens.push(auth_token_type.as_authentication( - &self.crypter.decrypt( - "task_aggregator_auth_tokens", - &row_id, - "token", - &encrypted_aggregator_auth_token, - )?, - )?); - } - - // Collector authentication tokens. - let mut collector_auth_tokens = Vec::new(); - for row in collector_auth_token_rows { - let ord: i64 = row.get("ord"); - let auth_token_type: AuthenticationTokenType = row.get("type"); - let encrypted_collector_auth_token: Vec = row.get("token"); + let aggregator_auth_token = row + .try_get::<_, Option>>("aggregator_auth_token")? + .zip(row.try_get::<_, Option>("aggregator_auth_token_type")?) + .map(|(encrypted_token, token_type)| { + token_type.as_authentication(&self.crypter.decrypt( + "tasks", + task_id.as_ref(), + "aggregator_auth_token", + &encrypted_token, + )?) + }) + .transpose()?; - let mut row_id = [0u8; TaskId::LEN + size_of::()]; - row_id[..TaskId::LEN].copy_from_slice(task_id.as_ref()); - row_id[TaskId::LEN..].copy_from_slice(&ord.to_be_bytes()); - - collector_auth_tokens.push(auth_token_type.as_authentication( - &self.crypter.decrypt( - "task_collector_auth_tokens", - &row_id, - "token", - &encrypted_collector_auth_token, - )?, - )?); - } + let collector_auth_token = row + .try_get::<_, Option>>("collector_auth_token")? + .zip(row.try_get::<_, Option>("collector_auth_token_type")?) + .map(|(encrypted_token, token_type)| { + token_type.as_authentication(&self.crypter.decrypt( + "tasks", + task_id.as_ref(), + "collector_auth_token", + &encrypted_token, + )?) + }) + .transpose()?; // HPKE keys. 
let mut hpke_keypairs = Vec::new(); @@ -1005,8 +865,8 @@ impl Transaction<'_, C> { time_precision, tolerable_clock_skew, collector_hpke_config, - aggregator_auth_tokens, - collector_auth_tokens, + aggregator_auth_token, + collector_auth_token, hpke_keypairs, ); // Trial validation through all known schemes. This is a workaround to avoid extending the diff --git a/aggregator_core/src/datastore/tests.rs b/aggregator_core/src/datastore/tests.rs index 675df0701..3779ba223 100644 --- a/aggregator_core/src/datastore/tests.rs +++ b/aggregator_core/src/datastore/tests.rs @@ -203,6 +203,88 @@ async fn roundtrip_task(ephemeral_datastore: EphemeralDatastore) { assert_eq!(want_tasks, got_tasks); } +#[rstest_reuse::apply(schema_versions_template)] +#[tokio::test] +async fn put_task_invalid_aggregator_auth_tokens(ephemeral_datastore: EphemeralDatastore) { + install_test_trace_subscriber(); + let ds = ephemeral_datastore.datastore(MockClock::default()).await; + + let task = TaskBuilder::new( + task::QueryType::TimeInterval, + VdafInstance::Prio3Count, + Role::Leader, + ) + .build(); + + ds.put_task(&task).await.unwrap(); + + for (auth_token, auth_token_type) in [("NULL", "'BEARER'"), ("'\\xDEADBEEF'::bytea", "NULL")] { + ds.run_tx(|tx| { + Box::pin(async move { + let err = tx + .query_one( + &format!( + "UPDATE tasks SET aggregator_auth_token = {auth_token}, + aggregator_auth_token_type = {auth_token_type}" + ), + &[], + ) + .await + .unwrap_err(); + + assert_eq!( + err.as_db_error().unwrap().constraint().unwrap(), + "aggregator_auth_token_null" + ); + Ok(()) + }) + }) + .await + .unwrap(); + } +} + +#[rstest_reuse::apply(schema_versions_template)] +#[tokio::test] +async fn put_task_invalid_collector_auth_tokens(ephemeral_datastore: EphemeralDatastore) { + install_test_trace_subscriber(); + let ds = ephemeral_datastore.datastore(MockClock::default()).await; + + let task = TaskBuilder::new( + task::QueryType::TimeInterval, + VdafInstance::Prio3Count, + Role::Leader, + ) + 
.build(); + + ds.put_task(&task).await.unwrap(); + + for (auth_token, auth_token_type) in [("NULL", "'BEARER'"), ("'\\xDEADBEEF'::bytea", "NULL")] { + ds.run_tx(|tx| { + Box::pin(async move { + let err = tx + .query_one( + &format!( + "UPDATE tasks SET collector_auth_token = {auth_token}, + collector_auth_token_type = {auth_token_type}" + ), + &[], + ) + .await + .unwrap_err(); + + assert_eq!( + err.as_db_error().unwrap().constraint().unwrap(), + "collector_auth_token_null" + ); + Ok(()) + }) + }) + .await + .unwrap(); + } +} + #[rstest_reuse::apply(schema_versions_template)] #[tokio::test] async fn get_task_metrics(ephemeral_datastore: EphemeralDatastore) { diff --git a/aggregator_core/src/task.rs b/aggregator_core/src/task.rs index 7e7318cc9..77755ad9c 100644 --- a/aggregator_core/src/task.rs +++ b/aggregator_core/src/task.rs @@ -129,10 +129,12 @@ pub struct Task { tolerable_clock_skew: Duration, /// HPKE configuration for the collector. collector_hpke_config: Option, - /// Tokens used to authenticate messages sent to or received from the other aggregator. - aggregator_auth_tokens: Vec, - /// Tokens used to authenticate messages sent to or received from the collector. - collector_auth_tokens: Vec, + /// Token used to authenticate messages sent to or received from the other aggregator. Only set + /// if the task was not created via taskprov. + aggregator_auth_token: Option, + /// Token used to authenticate messages sent to or received from the collector. Only set if this + /// aggregator is the leader. + collector_auth_token: Option, /// HPKE configurations & private keys used by this aggregator to decrypt client reports. 
hpke_keys: HashMap, } @@ -155,8 +157,8 @@ impl Task { time_precision: Duration, tolerable_clock_skew: Duration, collector_hpke_config: HpkeConfig, - aggregator_auth_tokens: Vec, - collector_auth_tokens: Vec, + aggregator_auth_token: Option, + collector_auth_token: Option, hpke_keys: I, ) -> Result { let task = Self::new_without_validation( @@ -174,8 +176,8 @@ impl Task { time_precision, tolerable_clock_skew, Some(collector_hpke_config), - aggregator_auth_tokens, - collector_auth_tokens, + aggregator_auth_token, + collector_auth_token, hpke_keys, ); task.validate()?; @@ -200,8 +202,8 @@ impl Task { time_precision: Duration, tolerable_clock_skew: Duration, collector_hpke_config: Option, - aggregator_auth_tokens: Vec, - collector_auth_tokens: Vec, + aggregator_auth_token: Option, + collector_auth_token: Option, hpke_keys: I, ) -> Self { // Compute hpke_configs mapping cfg.id -> (cfg, key). @@ -228,8 +230,8 @@ impl Task { time_precision, tolerable_clock_skew, collector_hpke_config, - aggregator_auth_tokens, - collector_auth_tokens, + aggregator_auth_token, + collector_auth_token, hpke_keys, } } @@ -264,13 +266,13 @@ impl Task { pub(crate) fn validate(&self) -> Result<(), Error> { self.validate_common()?; - if self.aggregator_auth_tokens.is_empty() { - return Err(Error::InvalidParameter("aggregator_auth_tokens")); + if self.aggregator_auth_token.is_none() { + return Err(Error::InvalidParameter("aggregator_auth_token")); } - if (self.role == Role::Leader) == (self.collector_auth_tokens.is_empty()) { + if (self.role == Role::Leader) == (self.collector_auth_token.is_none()) { // Collector auth tokens are allowed & required if and only if this task is in the // leader role. 
- return Err(Error::InvalidParameter("collector_auth_tokens")); + return Err(Error::InvalidParameter("collector_auth_token")); } if self.hpke_keys.is_empty() { return Err(Error::InvalidParameter("hpke_keys")); @@ -357,14 +359,14 @@ impl Task { self.collector_hpke_config.as_ref() } - /// Retrieves the aggregator authentication tokens associated with this task. - pub fn aggregator_auth_tokens(&self) -> &[AuthenticationToken] { - &self.aggregator_auth_tokens + /// Retrieves the aggregator authentication token associated with this task. + pub fn aggregator_auth_token(&self) -> Option<&AuthenticationToken> { + self.aggregator_auth_token.as_ref() } - /// Retrieves the collector authentication tokens associated with this task. - pub fn collector_auth_tokens(&self) -> &[AuthenticationToken] { - &self.collector_auth_tokens + /// Retrieves the collector authentication token associated with this task. + pub fn collector_auth_token(&self) -> Option<&AuthenticationToken> { + self.collector_auth_token.as_ref() } /// Retrieves the HPKE keys in use associated with this task. @@ -397,35 +399,22 @@ impl Task { } } - /// Returns the [`AuthenticationToken`] currently used by this aggregator to authenticate itself - /// to other aggregators. - pub fn primary_aggregator_auth_token(&self) -> &AuthenticationToken { - self.aggregator_auth_tokens.iter().next_back().unwrap() - } - /// Checks if the given aggregator authentication token is valid (i.e. matches with an /// authentication token recognized by this task). pub fn check_aggregator_auth_token(&self, auth_token: &AuthenticationToken) -> bool { - self.aggregator_auth_tokens - .iter() - .rev() - .any(|t| t == auth_token) - } - - /// Returns the [`AuthenticationToken`] currently used by the collector to authenticate itself - /// to the aggregators. 
- pub fn primary_collector_auth_token(&self) -> &AuthenticationToken { - // Unwrap safety: self.collector_auth_tokens is never empty - self.collector_auth_tokens.iter().next_back().unwrap() + match self.aggregator_auth_token { + Some(ref t) => t == auth_token, + None => false, + } } /// Checks if the given collector authentication token is valid (i.e. matches with an /// authentication token recognized by this task). pub fn check_collector_auth_token(&self, auth_token: &AuthenticationToken) -> bool { - self.collector_auth_tokens - .iter() - .rev() - .any(|t| t == auth_token) + match self.collector_auth_token { + Some(ref t) => t == auth_token, + None => false, + } } /// Returns the [`VerifyKey`] used by this aggregator to prepare report shares with other @@ -495,8 +484,8 @@ pub struct SerializedTask { time_precision: Duration, tolerable_clock_skew: Duration, collector_hpke_config: HpkeConfig, - aggregator_auth_tokens: Vec, - collector_auth_tokens: Vec, + aggregator_auth_token: Option, + collector_auth_token: Option, hpke_keys: Vec, // uses unpadded base64url } @@ -532,12 +521,12 @@ impl SerializedTask { self.vdaf_verify_key = Some(URL_SAFE_NO_PAD.encode(vdaf_verify_key.as_ref())); } - if self.aggregator_auth_tokens.is_empty() { - self.aggregator_auth_tokens = Vec::from([random()]); + if self.aggregator_auth_token.is_none() { + self.aggregator_auth_token = Some(random()); } - if self.collector_auth_tokens.is_empty() && self.role == Role::Leader { - self.collector_auth_tokens = Vec::from([random()]); + if self.collector_auth_token.is_none() && self.role == Role::Leader { + self.collector_auth_token = Some(random()); } if self.hpke_keys.is_empty() { @@ -577,8 +566,8 @@ impl Serialize for Task { .collector_hpke_config() .expect("serializable tasks must have collector_hpke_config") .clone(), - aggregator_auth_tokens: self.aggregator_auth_tokens.clone(), - collector_auth_tokens: self.collector_auth_tokens.clone(), + aggregator_auth_token: 
self.aggregator_auth_token.clone(), + collector_auth_token: self.collector_auth_token.clone(), hpke_keys, } .serialize(serializer) @@ -614,8 +603,8 @@ impl TryFrom for Task { serialized_task.time_precision, serialized_task.tolerable_clock_skew, serialized_task.collector_hpke_config, - serialized_task.aggregator_auth_tokens, - serialized_task.collector_auth_tokens, + serialized_task.aggregator_auth_token, + serialized_task.collector_auth_token, serialized_task.hpke_keys, ) } @@ -689,10 +678,10 @@ pub mod test_util { .collect(), ); - let collector_auth_tokens = if role == Role::Leader { - Vec::from([random(), AuthenticationToken::DapAuth(random())]) + let collector_auth_token = if role == Role::Leader { + Some(random()) // Create an AuthenticationToken::Bearer by default } else { - Vec::new() + None }; Self( @@ -711,8 +700,8 @@ pub mod test_util { Duration::from_hours(8).unwrap(), Duration::from_minutes(10).unwrap(), generate_test_hpke_config_and_private_key().config().clone(), - Vec::from([random(), AuthenticationToken::DapAuth(random())]), - collector_auth_tokens, + Some(random()), // Create an AuthenticationToken::Bearer by default + collector_auth_token, Vec::from([aggregator_keypair_0, aggregator_keypair_1]), ) .unwrap(), @@ -795,24 +784,42 @@ pub mod test_util { }) } - /// Associates the eventual task with the given aggregator authentication tokens. - pub fn with_aggregator_auth_tokens( + /// Associates the eventual task with the given aggregator authentication token. + pub fn with_aggregator_auth_token( self, - aggregator_auth_tokens: Vec, + aggregator_auth_token: Option, ) -> Self { Self(Task { - aggregator_auth_tokens, + aggregator_auth_token, + ..self.0 + }) + } + + /// Associates the eventual task with a random [`AuthenticationToken::DapAuth`] aggregator + /// auth token. 
+ pub fn with_dap_auth_aggregator_token(self) -> Self { + Self(Task { + aggregator_auth_token: Some(AuthenticationToken::DapAuth(random())), ..self.0 }) } - /// Sets the collector authentication tokens for the task. - pub fn with_collector_auth_tokens( + /// Associates the eventual task with the given collector authentication token. + pub fn with_collector_auth_token( self, - collector_auth_tokens: Vec, + collector_auth_token: Option, ) -> Self { Self(Task { - collector_auth_tokens, + collector_auth_token, + ..self.0 + }) + } + + /// Associates the eventual task with a random [`AuthenticationToken::DapAuth`] collector + /// auth token. + pub fn with_dap_auth_collector_token(self) -> Self { + Self(Task { + collector_auth_token: Some(AuthenticationToken::DapAuth(random())), ..self.0 }) } @@ -915,8 +922,8 @@ mod tests { Duration::from_hours(8).unwrap(), Duration::from_minutes(10).unwrap(), generate_test_hpke_config_and_private_key().config().clone(), - Vec::from([random()]), - Vec::new(), + Some(random()), + None, Vec::from([generate_test_hpke_config_and_private_key()]), ) .unwrap_err(); @@ -937,8 +944,8 @@ mod tests { Duration::from_hours(8).unwrap(), Duration::from_minutes(10).unwrap(), generate_test_hpke_config_and_private_key().config().clone(), - Vec::from([random()]), - Vec::from([random()]), + Some(random()), + Some(random()), Vec::from([generate_test_hpke_config_and_private_key()]), ) .unwrap(); @@ -959,8 +966,8 @@ mod tests { Duration::from_hours(8).unwrap(), Duration::from_minutes(10).unwrap(), generate_test_hpke_config_and_private_key().config().clone(), - Vec::from([random()]), - Vec::new(), + Some(random()), + None, Vec::from([generate_test_hpke_config_and_private_key()]), ) .unwrap(); @@ -981,8 +988,8 @@ mod tests { Duration::from_hours(8).unwrap(), Duration::from_minutes(10).unwrap(), generate_test_hpke_config_and_private_key().config().clone(), - Vec::from([random()]), - Vec::from([random()]), + Some(random()), + Some(random()), 
Vec::from([generate_test_hpke_config_and_private_key()]), ) .unwrap_err(); @@ -1005,8 +1012,8 @@ mod tests { Duration::from_hours(8).unwrap(), Duration::from_minutes(10).unwrap(), generate_test_hpke_config_and_private_key().config().clone(), - Vec::from([random()]), - Vec::from([random()]), + Some(random()), + Some(random()), Vec::from([generate_test_hpke_config_and_private_key()]), ) .unwrap(); @@ -1088,14 +1095,14 @@ mod tests { HpkeAeadId::Aes128Gcm, HpkePublicKey::from(b"collector hpke public key".to_vec()), ), - Vec::from([AuthenticationToken::new_dap_auth_token_from_string( - "YWdncmVnYXRvciB0b2tlbg", - ) - .unwrap()]), - Vec::from([AuthenticationToken::new_bearer_token_from_string( - "Y29sbGVjdG9yIHRva2Vu", - ) - .unwrap()]), + Some( + AuthenticationToken::new_dap_auth_token_from_string("YWdncmVnYXRvciB0b2tlbg") + .unwrap(), + ), + Some( + AuthenticationToken::new_bearer_token_from_string("Y29sbGVjdG9yIHRva2Vu") + .unwrap(), + ), [HpkeKeypair::new( HpkeConfig::new( HpkeConfigId::from(255), @@ -1180,8 +1187,8 @@ mod tests { Token::Str("public_key"), Token::Str("Y29sbGVjdG9yIGhwa2UgcHVibGljIGtleQ"), Token::StructEnd, - Token::Str("aggregator_auth_tokens"), - Token::Seq { len: Some(1) }, + Token::Str("aggregator_auth_token"), + Token::Some, Token::Struct { name: "AuthenticationToken", len: 2, @@ -1194,9 +1201,8 @@ mod tests { Token::Str("token"), Token::Str("YWdncmVnYXRvciB0b2tlbg"), Token::StructEnd, - Token::SeqEnd, - Token::Str("collector_auth_tokens"), - Token::Seq { len: Some(1) }, + Token::Str("collector_auth_token"), + Token::Some, Token::Struct { name: "AuthenticationToken", len: 2, @@ -1209,7 +1215,6 @@ mod tests { Token::Str("token"), Token::Str("Y29sbGVjdG9yIHRva2Vu"), Token::StructEnd, - Token::SeqEnd, Token::Str("hpke_keys"), Token::Seq { len: Some(1) }, Token::Struct { @@ -1277,11 +1282,11 @@ mod tests { HpkeAeadId::Aes128Gcm, HpkePublicKey::from(b"collector hpke public key".to_vec()), ), - 
Vec::from([AuthenticationToken::new_bearer_token_from_string( - "YWdncmVnYXRvciB0b2tlbg", - ) - .unwrap()]), - Vec::new(), + Some( + AuthenticationToken::new_bearer_token_from_string("YWdncmVnYXRvciB0b2tlbg") + .unwrap(), + ), + None, [HpkeKeypair::new( HpkeConfig::new( HpkeConfigId::from(255), @@ -1378,8 +1383,8 @@ mod tests { Token::Str("public_key"), Token::Str("Y29sbGVjdG9yIGhwa2UgcHVibGljIGtleQ"), Token::StructEnd, - Token::Str("aggregator_auth_tokens"), - Token::Seq { len: Some(1) }, + Token::Str("aggregator_auth_token"), + Token::Some, Token::Struct { name: "AuthenticationToken", len: 2, @@ -1392,10 +1397,8 @@ mod tests { Token::Str("token"), Token::Str("YWdncmVnYXRvciB0b2tlbg"), Token::StructEnd, - Token::SeqEnd, - Token::Str("collector_auth_tokens"), - Token::Seq { len: Some(0) }, - Token::SeqEnd, + Token::Str("collector_auth_token"), + Token::None, Token::Str("hpke_keys"), Token::Seq { len: Some(1) }, Token::Struct { diff --git a/aggregator_core/src/taskprov.rs b/aggregator_core/src/taskprov.rs index d4989553f..ce7c4cdee 100644 --- a/aggregator_core/src/taskprov.rs +++ b/aggregator_core/src/taskprov.rs @@ -302,8 +302,8 @@ impl Task { time_precision, tolerable_clock_skew, None, - Vec::new(), - Vec::new(), + None, + None, Vec::new(), )); task.validate()?; diff --git a/db/00000000000001_initial_schema.down.sql b/db/00000000000001_initial_schema.down.sql index 9419a7822..6a1156840 100644 --- a/db/00000000000001_initial_schema.down.sql +++ b/db/00000000000001_initial_schema.down.sql @@ -24,8 +24,6 @@ DROP INDEX client_reports_task_and_timestamp_index CASCADE; DROP INDEX client_reports_task_and_timestamp_unaggregated_index CASCADE; DROP TABLE client_reports CASCADE; DROP TABLE task_hpke_keys CASCADE; -DROP TABLE task_collector_auth_tokens CASCADE; -DROP TABLE task_aggregator_auth_tokens CASCADE; DROP INDEX task_id_index CASCADE; DROP TABLE tasks CASCADE; DROP TABLE taskprov_aggregator_auth_tokens; diff --git a/db/00000000000001_initial_schema.up.sql 
b/db/00000000000001_initial_schema.up.sql index 3834cfac5..cffe33b79 100644 --- a/db/00000000000001_initial_schema.up.sql +++ b/db/00000000000001_initial_schema.up.sql @@ -89,34 +89,24 @@ CREATE TABLE tasks( time_precision BIGINT NOT NULL, -- the duration to which clients are expected to round their report timestamps, in seconds tolerable_clock_skew BIGINT NOT NULL, -- the maximum acceptable clock skew to allow between client and aggregator, in seconds collector_hpke_config BYTEA, -- the HPKE config of the collector (encoded HpkeConfig message) - vdaf_verify_key BYTEA NOT NULL -- the VDAF verification key (encrypted) + vdaf_verify_key BYTEA NOT NULL, -- the VDAF verification key (encrypted) + + -- Authentication token used to authenticate messages to/from the other aggregator. + -- These columns are NULL if the task was provisioned by taskprov. + aggregator_auth_token_type AUTH_TOKEN_TYPE, -- the type of the authentication token + aggregator_auth_token BYTEA, -- encrypted bearer token + -- The aggregator_auth_token columns must either both be NULL or both be non-NULL + CONSTRAINT aggregator_auth_token_null CHECK ((aggregator_auth_token_type IS NULL) = (aggregator_auth_token IS NULL)), + + -- Authentication token used to authenticate messages to the leader from the collector. These + -- columns are NULL if the task was provisioned by taskprov or if the task's role is helper. + collector_auth_token_type AUTH_TOKEN_TYPE, -- the type of the authentication token + collector_auth_token BYTEA, -- encrypted bearer token + -- The collector_auth_token columns must either both be NULL or both be non-NULL + CONSTRAINT collector_auth_token_null CHECK ((collector_auth_token_type IS NULL) = (collector_auth_token IS NULL)) ); CREATE INDEX task_id_index ON tasks(task_id); --- The aggregator authentication tokens used by a given task. 
-CREATE TABLE task_aggregator_auth_tokens( - id BIGINT GENERATED ALWAYS AS IDENTITY PRIMARY KEY, -- artificial ID, internal-only - task_id BIGINT NOT NULL, -- task ID the token is associated with - ord BIGINT NOT NULL, -- a value used to specify the ordering of the authentication tokens - type AUTH_TOKEN_TYPE NOT NULL DEFAULT 'BEARER', -- the type of the authentication token - token BYTEA NOT NULL, -- bearer token used to authenticate messages to/from the other aggregator (encrypted) - - CONSTRAINT task_aggregator_auth_tokens_unique_task_id_and_ord UNIQUE(task_id, ord), - CONSTRAINT fk_task_id FOREIGN KEY(task_id) REFERENCES tasks(id) ON DELETE CASCADE -); - --- The collector authentication tokens used by a given task. -CREATE TABLE task_collector_auth_tokens( - id BIGINT GENERATED ALWAYS AS IDENTITY PRIMARY KEY, -- artificial ID, internal-only - task_id BIGINT NOT NULL, -- task ID the token is associated with - ord BIGINT NOT NULL, -- a value used to specify the ordering of the authentication tokens - type AUTH_TOKEN_TYPE NOT NULL DEFAULT 'BEARER', -- the type of the authentication token - token BYTEA NOT NULL, -- bearer token used to authenticate messages from the collector (encrypted) - - CONSTRAINT task_collector_auth_tokens_unique_task_id_and_ord UNIQUE(task_id, ord), - CONSTRAINT fk_task_id FOREIGN KEY(task_id) REFERENCES tasks(id) ON DELETE CASCADE -); - -- The HPKE public keys (aka configs) and private keys used by a given task. 
CREATE TABLE task_hpke_keys( id BIGINT GENERATED ALWAYS AS IDENTITY PRIMARY KEY, -- artificial ID, internal-only diff --git a/docs/samples/tasks.yaml b/docs/samples/tasks.yaml index c4b49efba..88d94a05a 100644 --- a/docs/samples/tasks.yaml +++ b/docs/samples/tasks.yaml @@ -59,31 +59,29 @@ aead_id: Aes128Gcm public_key: 4qiv6IY5jrjCV3xbaQXULmPIpvoIml1oJmeXm-yOuAo - # Authentication tokens shared beteween the aggregators, and used to + # Authentication token shared between the aggregators, and used to # authenticate leader-to-helper requests. In the case of a leader-role task, - # the leader will include the first token in a header when making requests to - # the helper. In the case of a helper-role task, the helper will accept - # requests with any of the listed authentication tokens. + # the leader will include the token in a header when making requests to the + # helper. In the case of a helper-role task, the helper will accept requests + # authenticated with this token. # # Each token's `type` governs how it is inserted into HTTP requests if used by # the leader to authenticate a request to the helper. - aggregator_auth_tokens: + aggregator_auth_token: # DAP-Auth-Token values are encoded in unpadded base64url, and the decoded # value is sent in an HTTP header. For example, this token's value decodes # to "aggregator-235242f99406c4fd28b820c32eab0f68". - - type: "DapAuth" + type: "DapAuth" token: "YWdncmVnYXRvci0yMzUyNDJmOTk0MDZjNGZkMjhiODIwYzMyZWFiMGY2OA" - # Bearer token values are encoded in unpadded base64url. - - type: "Bearer" - token: "YWdncmVnYXRvci04NDc1NjkwZjJmYzQzMDBmYjE0NmJiMjk1NDIzNDk1NA" - # Authentication tokens shared between the leader and the collector, and used + # Authentication token shared between the leader and the collector, and used # to authenticate collector-to-leader requests. For leader tasks, this has the # same format as `aggregator_auth_tokens` above. For helper tasks, this will # be an empty list instead. 
+ # Bearer token values are encoded in unpadded base64url. # This example decodes to "collector-abf5408e2b1601831625af3959106458". - collector_auth_tokens: - - type: "Bearer" + collector_auth_token: + type: "Bearer" token: "Y29sbGVjdG9yLWFiZjU0MDhlMmIxNjAxODMxNjI1YWYzOTU5MTA2NDU4" # This aggregator's HPKE keypairs. The first keypair's HPKE configuration will @@ -122,12 +120,12 @@ kdf_id: HkdfSha256 aead_id: Aes128Gcm public_key: KHRLcWgfWxli8cdOLPsgsZPttHXh0ho3vLVLrW-63lE - aggregator_auth_tokens: - - type: "Bearer" + aggregator_auth_token: + type: "Bearer" token: "YWdncmVnYXRvci1jZmE4NDMyZjdkMzllMjZiYjU3OGUzMzY5Mzk1MWQzNQ" - # Note that this task does not have any collector authentication tokens, since + # Note that this task does not have a collector authentication token, since # it is a helper role task. - collector_auth_tokens: [] + collector_auth_token: hpke_keys: - config: id: 37 diff --git a/integration_tests/tests/common/mod.rs b/integration_tests/tests/common/mod.rs index deced10af..644db5ae0 100644 --- a/integration_tests/tests/common/mod.rs +++ b/integration_tests/tests/common/mod.rs @@ -45,11 +45,14 @@ pub fn test_task_builders( Url::parse(&format!("http://helper-{endpoint_random_value}:8080/")).unwrap(), ) .with_min_batch_size(46) + // Force use of DAP-Auth-Tokens, as required by interop testing standard. 
+ .with_dap_auth_aggregator_token() + .with_dap_auth_collector_token() .with_collector_hpke_config(collector_keypair.config().clone()); let helper_task = leader_task .clone() .with_role(Role::Helper) - .with_collector_auth_tokens(Vec::new()); + .with_collector_auth_token(None); let temporary_task = leader_task.clone().build(); let task_parameters = TaskParameters { task_id: *temporary_task.id(), @@ -60,7 +63,7 @@ pub fn test_task_builders( time_precision: *temporary_task.time_precision(), collector_hpke_config: collector_keypair.config().clone(), collector_private_key: collector_keypair.private_key().clone(), - collector_auth_token: temporary_task.primary_collector_auth_token().clone(), + collector_auth_token: temporary_task.collector_auth_token().unwrap().clone(), }; (task_parameters, leader_task, helper_task) diff --git a/interop_binaries/src/bin/janus_interop_aggregator.rs b/interop_binaries/src/bin/janus_interop_aggregator.rs index 2fe30faac..b038a00bb 100644 --- a/interop_binaries/src/bin/janus_interop_aggregator.rs +++ b/interop_binaries/src/bin/janus_interop_aggregator.rs @@ -54,18 +54,16 @@ async fn handle_add_task( let collector_hpke_config = HpkeConfig::get_decoded(&collector_hpke_config_bytes) .context("could not parse collector HPKE configuration")?; - let collector_authentication_tokens = + let collector_authentication_token = match (request.role, request.collector_authentication_token) { (AggregatorRole::Leader, None) => { return Err(anyhow::anyhow!("collector authentication token is missing")) } - (AggregatorRole::Leader, Some(collector_authentication_token)) => { - Vec::from([AuthenticationToken::new_dap_auth_token_from_string( - collector_authentication_token, - ) - .context("invalid header value in \"collector_authentication_token\"")?]) - } - (AggregatorRole::Helper, _) => Vec::new(), + (AggregatorRole::Leader, Some(collector_authentication_token)) => Some( + AuthenticationToken::new_dap_auth_token_from_string(collector_authentication_token) + 
.context("invalid header value in \"collector_authentication_token\"")?, + ), + (AggregatorRole::Helper, _) => None, }; let hpke_keypair = keyring.lock().await.get_random_keypair(); @@ -103,8 +101,8 @@ async fn handle_add_task( // other aggregators running on the same host. Duration::from_seconds(1), collector_hpke_config, - Vec::from([leader_authentication_token]), - collector_authentication_tokens, + Some(leader_authentication_token), + collector_authentication_token, [hpke_keypair], ) .context("error constructing task")?; diff --git a/interop_binaries/src/lib.rs b/interop_binaries/src/lib.rs index c18513078..957a9440b 100644 --- a/interop_binaries/src/lib.rs +++ b/interop_binaries/src/lib.rs @@ -291,12 +291,12 @@ impl From for AggregatorAddTaskRequest { helper: task.helper_aggregator_endpoint().clone(), vdaf: task.vdaf().clone().into(), leader_authentication_token: String::from_utf8( - task.primary_aggregator_auth_token().as_ref().to_vec(), + task.aggregator_auth_token().unwrap().as_ref().to_vec(), ) .unwrap(), collector_authentication_token: if task.role() == &Role::Leader { Some( - String::from_utf8(task.primary_collector_auth_token().as_ref().to_vec()) + String::from_utf8(task.collector_auth_token().unwrap().as_ref().to_vec()) .unwrap(), ) } else {