diff --git a/aggregator/src/aggregator.rs b/aggregator/src/aggregator.rs index 2899685f3..d5a9c5f5a 100644 --- a/aggregator/src/aggregator.rs +++ b/aggregator/src/aggregator.rs @@ -264,33 +264,51 @@ impl Aggregator { &self, task_id_base64: Option<&[u8]>, ) -> Result { - match task_id_base64 { - Some(task_id_base64) => { - let task_id_bytes = URL_SAFE_NO_PAD - .decode(task_id_base64) - .map_err(|_| Error::UnrecognizedMessage(None, "task_id"))?; - let task_id = TaskId::get_decoded(&task_id_bytes) - .map_err(|_| Error::UnrecognizedMessage(None, "task_id"))?; - let task_aggregator = self - .task_aggregator_for(&task_id) - .await? - .ok_or(Error::UnrecognizedTask(task_id))?; - Ok(task_aggregator.handle_hpke_config()) + // If we're running in taskprov mode, unconditionally provide the global keys and ignore + // the task_id parameter. + if self.cfg.taskprov_config.enabled { + let configs = self.global_hpke_keypairs.configs(); + if configs.is_empty() { + Err(Error::Internal( + "this server is missing its global HPKE config".into(), + )) + } else { + Ok(HpkeConfigList::new(configs.to_vec())) } - None => { - let configs = self.global_hpke_keypairs.configs(); - if configs.is_empty() { - if self.cfg.taskprov_config.enabled { - // A global HPKE configuration is only _required_ when taskprov - // is enabled. - Err(Error::Internal( - "this server is missing its global HPKE config".into(), - )) - } else { + } else { + // Otherwise, try to get the task-specific key. + match task_id_base64 { + Some(task_id_base64) => { + let task_id_bytes = URL_SAFE_NO_PAD + .decode(task_id_base64) + .map_err(|_| Error::UnrecognizedMessage(None, "task_id"))?; + let task_id = TaskId::get_decoded(&task_id_bytes) + .map_err(|_| Error::UnrecognizedMessage(None, "task_id"))?; + let task_aggregator = self + .task_aggregator_for(&task_id) + .await? 
+ .ok_or(Error::UnrecognizedTask(task_id))?; + + match task_aggregator.handle_hpke_config() { + Some(hpke_config_list) => Ok(hpke_config_list), + // Assuming something hasn't gone horribly wrong with the database, this + // should only happen in the case where the system has been moved from taskprov + // mode to non-taskprov mode. Thus there are still taskprov tasks in the database. + // This isn't a supported use case, so the operator needs to delete these tasks + // or move the system back into taskprov mode. + None => Err(Error::Internal("task has no HPKE configs".to_string())), + } + } + // No task ID present, try to fall back to a global config. + None => { + let configs = self.global_hpke_keypairs.configs(); + if configs.is_empty() { + // This server isn't configured to provide global HPKE keys, the client + // should have given us a task ID. Err(Error::MissingTaskId) + } else { + Ok(HpkeConfigList::new(configs.to_vec())) } - } else { - Ok(HpkeConfigList::new(configs.to_vec())) } } } @@ -850,19 +868,18 @@ impl TaskAggregator { }) } - fn handle_hpke_config(&self) -> HpkeConfigList { + fn handle_hpke_config(&self) -> Option { // TODO(#239): consider deciding a better way to determine "primary" (e.g. most-recent) HPKE // config/key -- right now it's the one with the maximal config ID, but that will run into // trouble if we ever need to wrap-around, which we may since config IDs are effectively a u8. - HpkeConfigList::new(Vec::from([self + Some(HpkeConfigList::new(Vec::from([self .task .hpke_keys() .iter() - .max_by_key(|(&id, _)| id) - .unwrap() + .max_by_key(|(&id, _)| id)? 
.1 .config() - .clone()])) + .clone()]))) } async fn handle_upload( diff --git a/aggregator/src/aggregator/http_handlers.rs b/aggregator/src/aggregator/http_handlers.rs index 22c8b6457..4d3f24457 100644 --- a/aggregator/src/aggregator/http_handlers.rs +++ b/aggregator/src/aggregator/http_handlers.rs @@ -632,23 +632,26 @@ pub mod test_util { #[cfg(test)] mod tests { - use crate::aggregator::{ - aggregate_init_tests::{put_aggregation_job, setup_aggregate_init_test}, - aggregation_job_continue::test_util::{ - post_aggregation_job_and_decode, post_aggregation_job_expecting_error, - }, - collection_job_tests::setup_collection_job_test_case, - empty_batch_aggregations, - http_handlers::{ - aggregator_handler, aggregator_handler_with_aggregator, - test_util::{take_problem_details, take_response_body}, - }, - tests::{ - create_report, create_report_custom, default_aggregator_config, - generate_helper_report_share, generate_helper_report_share_for_plaintext, - BATCH_AGGREGATION_SHARD_COUNT, + use crate::{ + aggregator::{ + aggregate_init_tests::{put_aggregation_job, setup_aggregate_init_test}, + aggregation_job_continue::test_util::{ + post_aggregation_job_and_decode, post_aggregation_job_expecting_error, + }, + collection_job_tests::setup_collection_job_test_case, + empty_batch_aggregations, + http_handlers::{ + aggregator_handler, aggregator_handler_with_aggregator, + test_util::{take_problem_details, take_response_body}, + }, + tests::{ + create_report, create_report_custom, default_aggregator_config, + generate_helper_report_share, generate_helper_report_share_for_plaintext, + BATCH_AGGREGATION_SHARD_COUNT, + }, + Config, }, - Config, + config::TaskprovConfig, }; use assert_matches::assert_matches; use futures::future::try_join_all; @@ -663,6 +666,7 @@ mod tests { }, query_type::{AccumulableQueryType, CollectableQueryType}, task::{test_util::TaskBuilder, QueryType, VerifyKey}, + taskprov, test_util::noop_meter, }; use janus_core::{ @@ -945,6 +949,86 @@ mod tests { 
assert_eq!(test_conn.status(), Some(Status::BadRequest)); } + #[tokio::test] + async fn global_hpke_config_with_taskprov() { + install_test_trace_subscriber(); + let clock = MockClock::default(); + let ephemeral_datastore = ephemeral_datastore().await; + let datastore = Arc::new(ephemeral_datastore.datastore(clock.clone()).await); + + // Insert an HPKE config, i.e. start the application with a keypair already + // in the database. + let first_hpke_keypair = generate_test_hpke_config_and_private_key_with_id(1); + datastore + .run_tx(|tx| { + let keypair = first_hpke_keypair.clone(); + Box::pin(async move { + tx.put_global_hpke_keypair(&keypair).await?; + tx.set_global_hpke_keypair_state(keypair.config().id(), &HpkeKeyState::Active) + .await?; + Ok(()) + }) + }) + .await + .unwrap(); + + // Insert a taskprov task. This task won't have its task-specific HPKE key. + let task = TaskBuilder::new( + QueryType::TimeInterval, + VdafInstance::Prio3Count, + Role::Leader, + ) + .build(); + let task_id = *task.id(); + let task = taskprov::Task::new( + task_id, + task.aggregator_endpoints().to_vec(), + *task.query_type(), + task.vdaf().clone(), + *task.role(), + task.vdaf_verify_keys().to_vec(), + task.max_batch_query_count(), + task.task_expiration().cloned(), + task.report_expiry_age().cloned(), + task.min_batch_size(), + *task.time_precision(), + *task.tolerable_clock_skew(), + ) + .unwrap(); + datastore.put_task(&task.into()).await.unwrap(); + + let cfg = Config { + taskprov_config: TaskprovConfig { enabled: true }, + ..Default::default() + }; + + let aggregator = Arc::new( + crate::aggregator::Aggregator::new( + datastore.clone(), + clock.clone(), + &noop_meter(), + cfg, + ) + .await + .unwrap(), + ); + let handler = aggregator_handler_with_aggregator(aggregator.clone(), &noop_meter()) + .await + .unwrap(); + + let mut test_conn = get(&format!("/hpke_config?task_id={}", task_id)) + .run_async(&handler) + .await; + assert_eq!(test_conn.status(), Some(Status::Ok)); + let 
bytes = take_response_body(&mut test_conn).await; + let hpke_config_list = HpkeConfigList::decode(&mut Cursor::new(&bytes)).unwrap(); + assert_eq!( + hpke_config_list.hpke_configs(), + &[first_hpke_keypair.config().clone()] + ); + check_hpke_config_is_usable(&hpke_config_list, &first_hpke_keypair); + } + fn check_hpke_config_is_usable(hpke_config_list: &HpkeConfigList, hpke_keypair: &HpkeKeypair) { let application_info = HpkeApplicationInfo::new(&Label::InputShare, &Role::Client, &Role::Leader);