From 61b466884b6882775f66bf7f3522d5f3398da18c Mon Sep 17 00:00:00 2001 From: Ameer Ghani Date: Thu, 21 Sep 2023 16:49:21 -0400 Subject: [PATCH] aggregator auth tokens --- aggregator/src/aggregator.rs | 53 ++++- aggregator/src/aggregator/http_handlers.rs | 52 ++++- aggregator/src/aggregator/taskprov_tests.rs | 211 ++----------------- aggregator/src/bin/aggregator.rs | 25 ++- docs/samples/advanced_config/aggregator.yaml | 10 + docs/samples/basic_config/aggregator.yaml | 9 + 6 files changed, 148 insertions(+), 212 deletions(-) diff --git a/aggregator/src/aggregator.rs b/aggregator/src/aggregator.rs index eb43de03f..0cc1d68be 100644 --- a/aggregator/src/aggregator.rs +++ b/aggregator/src/aggregator.rs @@ -193,8 +193,11 @@ pub struct Config { /// New tasks will have this tolerable clock skew. pub tolerable_clock_skew: Duration, - /// Defines the key used to deterministically derive the VDAF verify key for new tasks. + /// Defines the key used to derive the VDAF verify key for new tasks. pub verify_key_init: VerifyKeyInit, + + /// Authentication tokens used for requests from the leader. + pub auth_tokens: Vec, } // subscriber-01 only: the config now has mandatory fields, so default only makes sense as a helper @@ -218,6 +221,7 @@ impl Default for Config { report_expiry_age: None, tolerable_clock_skew: Duration::from_minutes(60).unwrap(), verify_key_init: random(), + auth_tokens: Vec::new(), } } } @@ -322,8 +326,13 @@ impl Aggregator { task_aggregator } None if taskprov_task_config.is_some() => { - self.taskprov_opt_in(&Role::Leader, task_id, taskprov_task_config.unwrap()) - .await?; + self.taskprov_opt_in( + &Role::Leader, + task_id, + taskprov_task_config.unwrap(), + auth_token.as_ref(), + ) + .await?; // Retry fetching the aggregator, since the last function would have just inserted // its task. @@ -366,8 +375,13 @@ impl Aggregator { } if taskprov_task_config.is_some() { - self.taskprov_authorize_request(&Role::Leader, task_id, taskprov_task_config.unwrap()) - .await?; + self.taskprov_authorize_request( + &Role::Leader, + task_id, + taskprov_task_config.unwrap(), + auth_token.as_ref(), + ) + .await?; } else if !auth_token .map(|t| task_aggregator.task.check_aggregator_auth_token(&t)) .unwrap_or(false) @@ -492,8 +506,13 @@ impl Aggregator { // Authorize the request and retrieve the collector's HPKE config. If this is a taskprov task, we // have to use the peer aggregator's collector config rather than the main task. let collector_hpke_config = if taskprov_task_config.is_some() { - self.taskprov_authorize_request(&Role::Leader, task_id, taskprov_task_config.unwrap()) - .await?; + self.taskprov_authorize_request( + &Role::Leader, + task_id, + taskprov_task_config.unwrap(), + auth_token.as_ref(), + ) + .await?; &self.cfg.collector_hpke_config } else { if !auth_token @@ -574,8 +593,9 @@ impl Aggregator { peer_role: &Role, task_id: &TaskId, task_config: &TaskConfig, + aggregator_auth_token: Option<&AuthenticationToken>, ) -> Result<(), Error> { - self.taskprov_authorize_request(peer_role, task_id, task_config) + self.taskprov_authorize_request(peer_role, task_id, task_config, aggregator_auth_token) .await?; let aggregator_urls = task_config @@ -583,6 +603,12 @@ impl Aggregator { .iter() .map(|url| url.try_into()) .collect::, _>>()?; + if aggregator_urls.len() < 2 { + return Err(Error::UnrecognizedMessage( + Some(*task_id), + "taskprov configuration is missing one or both aggregators", + )); + } // TODO(#1647): Check whether task config parameters are acceptable for privacy and // availability of the system. @@ -662,7 +688,18 @@ impl Aggregator { peer_role: &Role, task_id: &TaskId, task_config: &TaskConfig, + aggregator_auth_token: Option<&AuthenticationToken>, ) -> Result<(), Error> { + let request_token = aggregator_auth_token.ok_or(Error::UnauthorizedRequest(*task_id))?; + if !self + .cfg + .auth_tokens + .iter() + .any(|token| token == request_token) + { + return Err(Error::UnauthorizedRequest(*task_id)); + } + if self.clock.now() > *task_config.task_expiration() { return Err(Error::InvalidTask(*task_id, OptOutReason::TaskExpired)); } diff --git a/aggregator/src/aggregator/http_handlers.rs b/aggregator/src/aggregator/http_handlers.rs index 16ebc7762..c9cf46290 100644 --- a/aggregator/src/aggregator/http_handlers.rs +++ b/aggregator/src/aggregator/http_handlers.rs @@ -245,10 +245,10 @@ async fn aggregator_handler_with_aggregator( "hpke_config", hpke_config_cors_preflight, ) - .post("upload", instrumented(api(upload::))) - .with_route(trillium::Method::Options, "upload", upload_cors_preflight) .post("aggregate", instrumented(api(aggregate::))) .post("collect", instrumented(api(collect_post::))) + .post("aggregate_share", instrumented(api(aggregate_share::))) + // TODO(#1728): remove these unnecessary routes, subscriber-01 is helper-only. .get( "collect/:task_id/:collection_job_id", instrumented(api(collect_get::)), @@ -257,11 +257,57 @@ async fn aggregator_handler_with_aggregator( "collect/:task_id/:collection_job_id", instrumented(api(collect_delete::)), ) - .post("aggregate_share", instrumented(api(aggregate_share::))), + .post("upload", instrumented(api(upload::))) + .with_route(trillium::Method::Options, "upload", upload_cors_preflight), StatusCounter::new(meter), )) } +// pub fn authenticated(handler: H) -> impl Handler { +// AuthenticatedHandler(handler, PhantomData::) +// } + +// #[derive(Handler)] +// struct AuthenticatedHandler(#[handler(except = [run])] H, PhantomData); + +// impl AuthenticatedHandler { +// async fn run(&self, mut conn: Conn) -> Conn { +// let aggregator: Arc> = conn.take_state().unwrap(); + +// let request_auth = { +// let bearer_token = match extract_bearer_token(&conn) { +// Ok(bearer_token) => bearer_token, +// Err(_) => todo!("bad request"), +// }; + +// match bearer_token { +// Some(bearer_token) => bearer_token, +// None => match conn.request_headers().get(DAP_AUTH_HEADER) { +// Some(dap_auth) => { +// match AuthenticationToken::new_dap_auth_token_from_bytes(dap_auth.as_ref()) +// { +// Ok(dap_auth) => dap_auth, +// Err(_) => todo!("bad request"), +// } +// } +// None => todo!("unauthorized"), +// }, +// } +// }; + +// if aggregator +// .cfg +// .auth_tokens +// .iter() +// .any(|token| *token == request_auth) +// { +// self.0.run(conn).await +// } else { +// conn.with_status(Status::Unauthorized).halt() +// } +// } +// } + /// API handler for the "/hpke_config" GET endpoint. async fn hpke_config( conn: &mut Conn, diff --git a/aggregator/src/aggregator/taskprov_tests.rs b/aggregator/src/aggregator/taskprov_tests.rs index 56a335c90..d828872a2 100644 --- a/aggregator/src/aggregator/taskprov_tests.rs +++ b/aggregator/src/aggregator/taskprov_tests.rs @@ -26,7 +26,7 @@ use janus_core::{ HpkeApplicationInfo, HpkeKeypair, Label, }, report_id::ReportIdChecksumExt, - task::PRIO3_VERIFY_KEY_LENGTH, + task::{AuthenticationToken, PRIO3_VERIFY_KEY_LENGTH}, taskprov::TASKPROV_HEADER, test_util::{install_test_trace_subscriber, run_vdaf, VdafTranscript}, time::{Clock, DurationExt, MockClock, TimeExt}, @@ -71,6 +71,7 @@ pub struct TaskprovTestCase { task: Task, task_config: TaskConfig, task_id: TaskId, + aggregator_auth_token: AuthenticationToken, } async fn setup_taskprov_test() -> TaskprovTestCase { @@ -82,6 +83,7 @@ async fn setup_taskprov_test() -> TaskprovTestCase { let global_hpke_key = generate_test_hpke_config_and_private_key(); let collector_hpke_keypair = generate_test_hpke_config_and_private_key(); + let aggregator_auth_token: AuthenticationToken = random(); datastore .run_tx(|tx| { @@ -94,9 +96,12 @@ async fn setup_taskprov_test() -> TaskprovTestCase { .await .unwrap(); + let tolerable_clock_skew = Duration::from_seconds(60); let config = Config { collector_hpke_config: collector_hpke_keypair.config().clone(), verify_key_init: random(), + auth_tokens: vec![aggregator_auth_token.clone()], + tolerable_clock_skew, ..Default::default() }; @@ -161,7 +166,7 @@ async fn setup_taskprov_test() -> TaskprovTestCase { config.report_expiry_age.clone(), min_batch_size as u64, Duration::from_seconds(1), - Duration::from_seconds(1), + tolerable_clock_skew, ) .unwrap(); @@ -200,6 +205,7 @@ async fn setup_taskprov_test() -> TaskprovTestCase { report_metadata, transcript, report_share, + aggregator_auth_token, } } @@ -218,10 +224,7 @@ async fn taskprov_aggregate_init() { Vec::from([test.report_share.clone()]), ); - let auth = test - .peer_aggregator - .primary_aggregator_auth_token() - .request_authentication(); + let auth = test.aggregator_auth_token.request_authentication(); let mut test_conn = post(test.task.aggregation_job_uri().unwrap().path()) .with_request_header(auth.0, "Bearer invalid_token") @@ -317,10 +320,7 @@ async fn taskprov_opt_out_task_expired() { Vec::from([test.report_share.clone()]), ); - let auth = test - .peer_aggregator - .primary_aggregator_auth_token() - .request_authentication(); + let auth = test.aggregator_auth_token.request_authentication(); // Advance clock past task expiry. test.clock.advance(&Duration::from_hours(48).unwrap()); @@ -390,10 +390,7 @@ async fn taskprov_opt_out_mismatched_task_id() { ) .unwrap(); - let auth = test - .peer_aggregator - .primary_aggregator_auth_token() - .request_authentication(); + let auth = test.aggregator_auth_token.request_authentication(); let mut test_conn = post( test @@ -470,10 +467,7 @@ async fn taskprov_opt_out_missing_aggregator() { Vec::from([test.report_share.clone()]), ); - let auth = test - .peer_aggregator - .primary_aggregator_auth_token() - .request_authentication(); + let auth = test.aggregator_auth_token.request_authentication(); let mut test_conn = post( test @@ -507,172 +501,6 @@ async fn taskprov_opt_out_missing_aggregator() { ); } -#[tokio::test] -async fn taskprov_opt_out_peer_aggregator_wrong_role() { - let test = setup_taskprov_test().await; - - let batch_id = random(); - let aggregation_job_id: AggregationJobId = random(); - - let task_expiration = test - .clock - .now() - .add(&Duration::from_hours(24).unwrap()) - .unwrap(); - let another_task_config = TaskConfig::new( - Vec::from("foobar".as_bytes()), - // Attempt to configure leader as a helper. - Vec::from([ - "https://helper.example.com/".as_bytes().try_into().unwrap(), - "https://leader.example.com/".as_bytes().try_into().unwrap(), - ]), - QueryConfig::new( - Duration::from_seconds(1), - 100, - 100, - TaskprovQuery::FixedSize { - max_batch_size: 100, - }, - ), - task_expiration, - VdafConfig::new(DpConfig::new(DpMechanism::None), VdafType::Prio3Aes128Count).unwrap(), - ) - .unwrap(); - let another_task_config_encoded = another_task_config.get_encoded(); - let another_task_id: TaskId = digest(&SHA256, &another_task_config_encoded) - .as_ref() - .try_into() - .unwrap(); - - let request = AggregateInitializeReq::new( - another_task_id, - aggregation_job_id, - ().get_encoded(), - PartialBatchSelector::new_fixed_size(batch_id), - Vec::from([test.report_share.clone()]), - ); - - let auth = test - .peer_aggregator - .primary_aggregator_auth_token() - .request_authentication(); - - let mut test_conn = post( - test - // Use the test case task's ID. - .task - .aggregation_job_uri() - .unwrap() - .path(), - ) - .with_request_header(auth.0, auth.1) - .with_request_header( - KnownHeaderName::ContentType, - AggregateInitializeReq::::MEDIA_TYPE, - ) - .with_request_header( - TASKPROV_HEADER, - URL_SAFE_NO_PAD.encode(another_task_config_encoded), - ) - .with_request_body(request.get_encoded()) - .run_async(&test.handler) - .await; - assert_eq!(test_conn.status(), Some(Status::BadRequest)); - assert_eq!( - take_problem_details(&mut test_conn).await, - json!({ - "status": Status::BadRequest as u16, - "type": "urn:ietf:params:ppm:dap:error:invalidTask", - "title": "Aggregator has opted out of the indicated task.", - "taskid": format!("{}", another_task_id - ), - }) - ); -} - -#[tokio::test] -async fn taskprov_opt_out_peer_aggregator_does_not_exist() { - let test = setup_taskprov_test().await; - - let batch_id = random(); - let aggregation_job_id: AggregationJobId = random(); - - let task_expiration = test - .clock - .now() - .add(&Duration::from_hours(24).unwrap()) - .unwrap(); - let another_task_config = TaskConfig::new( - Vec::from("foobar".as_bytes()), - Vec::from([ - // Some non-existent aggregator. - "https://foobar.example.com/".as_bytes().try_into().unwrap(), - "https://leader.example.com/".as_bytes().try_into().unwrap(), - ]), - QueryConfig::new( - Duration::from_seconds(1), - 100, - 100, - TaskprovQuery::FixedSize { - max_batch_size: 100, - }, - ), - task_expiration, - VdafConfig::new(DpConfig::new(DpMechanism::None), VdafType::Prio3Aes128Count).unwrap(), - ) - .unwrap(); - let another_task_config_encoded = another_task_config.get_encoded(); - let another_task_id: TaskId = digest(&SHA256, &another_task_config_encoded) - .as_ref() - .try_into() - .unwrap(); - - let request = AggregateInitializeReq::new( - another_task_id, - aggregation_job_id, - ().get_encoded(), - PartialBatchSelector::new_fixed_size(batch_id), - Vec::from([test.report_share.clone()]), - ); - - let auth = test - .peer_aggregator - .primary_aggregator_auth_token() - .request_authentication(); - - let mut test_conn = post( - test - // Use the test case task's ID. - .task - .aggregation_job_uri() - .unwrap() - .path(), - ) - .with_request_header(auth.0, auth.1) - .with_request_header( - KnownHeaderName::ContentType, - AggregateInitializeReq::::MEDIA_TYPE, - ) - .with_request_header( - TASKPROV_HEADER, - URL_SAFE_NO_PAD.encode(another_task_config_encoded), - ) - .with_request_body(request.get_encoded()) - .run_async(&test.handler) - .await; - assert_eq!(test_conn.status(), Some(Status::BadRequest)); - assert_eq!( - take_problem_details(&mut test_conn).await, - json!({ - "status": Status::BadRequest as u16, - "type": "urn:ietf:params:ppm:dap:error:invalidTask", - "title": "Aggregator has opted out of the indicated task.", - "taskid": format!("{}", another_task_id - ), - }) - ); -} - #[tokio::test] async fn taskprov_aggregate_continue() { let test = setup_taskprov_test().await; @@ -748,10 +576,7 @@ async fn taskprov_aggregate_continue() { )]), ); - let auth = test - .peer_aggregator - .primary_aggregator_auth_token() - .request_authentication(); + let auth = test.aggregator_auth_token.request_authentication(); // Attempt using the wrong credentials, should reject. let mut test_conn = post(test.task.aggregation_job_uri().unwrap().path()) @@ -869,10 +694,7 @@ async fn taskprov_aggregate_share() { ReportIdChecksum::get_decoded(&[3; 32]).unwrap(), ); - let auth = test - .peer_aggregator - .primary_aggregator_auth_token() - .request_authentication(); + let auth = test.aggregator_auth_token.request_authentication(); // Attempt using the wrong credentials, should reject. let mut test_conn = post(test.task.aggregate_shares_uri().unwrap().path()) @@ -938,10 +760,7 @@ async fn taskprov_aggregate_share() { #[tokio::test] async fn end_to_end() { let test = setup_taskprov_test().await; - let (auth_header_name, auth_header_value) = test - .peer_aggregator - .primary_aggregator_auth_token() - .request_authentication(); + let (auth_header_name, auth_header_value) = test.aggregator_auth_token.request_authentication(); let batch_id = random(); let aggregation_job_id = random(); diff --git a/aggregator/src/bin/aggregator.rs b/aggregator/src/bin/aggregator.rs index f96c72895..55bbc2605 100644 --- a/aggregator/src/bin/aggregator.rs +++ b/aggregator/src/bin/aggregator.rs @@ -45,12 +45,23 @@ async fn main() -> Result<()> { .response_headers() .context("failed to parse response headers")?; + // inahga: refactor auth token logic + let auth_tokens = options + .aggregator_api_auth_tokens + .iter() + .filter(|token| !token.is_empty()) + .map(|token| { + AuthenticationToken::new_bearer_token_from_string(token) + .context("invalid aggregator auth token") + }) + .collect::>>()?; + let mut handlers = ( aggregator_handler( Arc::clone(&datastore), clock, &meter, - config.aggregator_config(options.verify_key_init), + config.aggregator_config(options.verify_key_init, auth_tokens), ) .await?, None, @@ -393,7 +404,11 @@ impl Config { .collect() } - fn aggregator_config(&self, verify_key_init: VerifyKeyInit) -> aggregator::Config { + fn aggregator_config( + &self, + verify_key_init: VerifyKeyInit, + auth_tokens: Vec, + ) -> aggregator::Config { aggregator::Config { max_upload_batch_size: self.max_upload_batch_size, max_upload_batch_write_delay: Duration::from_millis( @@ -408,6 +423,7 @@ impl Config { report_expiry_age: self.taskprov_config.report_expiry_age, tolerable_clock_skew: self.taskprov_config.tolerable_clock_skew, verify_key_init, + auth_tokens, } } } @@ -441,7 +457,6 @@ mod tests { use janus_aggregator_core::taskprov::VerifyKeyInit; use janus_core::{ hpke::test_util::generate_test_hpke_config_and_private_key, test_util::roundtrip_encoding, - time::DurationExt, }; use janus_messages::{ HpkeAeadId, HpkeConfig, HpkeConfigId, HpkeKdfId, HpkeKemId, HpkePublicKey, @@ -703,7 +718,7 @@ mod tests { "# ) .unwrap() - .aggregator_config(verify_key_init.clone()), + .aggregator_config(verify_key_init.clone(), Vec::new()), aggregator::Config { max_upload_batch_size: 100, max_upload_batch_write_delay: Duration::from_millis(250), @@ -720,7 +735,7 @@ mod tests { ), ), report_expiry_age: None, - tolerable_clock_skew: janus_messages::Duration::from_minutes(60).unwrap(), + tolerable_clock_skew: janus_messages::Duration::from_seconds(60), verify_key_init, ..Default::default() } diff --git a/docs/samples/advanced_config/aggregator.yaml b/docs/samples/advanced_config/aggregator.yaml index 1b06f69ed..69fd2e5ca 100644 --- a/docs/samples/advanced_config/aggregator.yaml +++ b/docs/samples/advanced_config/aggregator.yaml @@ -104,3 +104,13 @@ garbage_collection: # The maximum number of collection jobs (& related artifacts), per task, to delete in a single run # of the garbage collector. collection_limit: 50 + +taskprov_config: + collector_hpke_config: + id: 183 + kem_id: X25519HkdfSha256 + kdf_id: HkdfSha256 + aead_id: Aes128Gcm + public_key: 4qiv6IY5jrjCV3xbaQXULmPIpvoIml1oJmeXm-yOuAo + tolerable_clock_skew: 60 + report_expiry_age: 2592000 diff --git a/docs/samples/basic_config/aggregator.yaml b/docs/samples/basic_config/aggregator.yaml index 37c5932d4..8e6368982 100644 --- a/docs/samples/basic_config/aggregator.yaml +++ b/docs/samples/basic_config/aggregator.yaml @@ -21,3 +21,12 @@ max_upload_batch_write_delay_ms: 250 # Number of sharded database records per batch aggregation. Must not be greater # than the equivalent setting in the collection job driver. (required) batch_aggregation_shard_count: 32 + +taskprov_config: + collector_hpke_config: + id: 183 + kem_id: X25519HkdfSha256 + kdf_id: HkdfSha256 + aead_id: Aes128Gcm + public_key: 4qiv6IY5jrjCV3xbaQXULmPIpvoIml1oJmeXm-yOuAo + tolerable_clock_skew: 60