From c9fb9a3f843e657eaa7c3aa46cc1eceda89a3bbb Mon Sep 17 00:00:00 2001 From: Tim Geoghegan Date: Fri, 29 Sep 2023 18:52:39 -0700 Subject: [PATCH] Task rewrite: `janus_cli` (#2028) Adopts `AggregatorTask` in `janus_cli`. There's a functional change here: the YAML task accepted by `janus_cli` is now an `AggregatorTask` (via `SerializedAggregatorTask`), containing an aggregator's specific view of a task, meaning it contains just the peer aggregator's endpoint URL. In the near future, there will be more changes to the task YAML format so that for instance leader tasks only get the collector auth token _hash_ and helper tasks only get the aggregator auth token _hash_. Part of #1524 --- aggregator/src/bin/janus_cli.rs | 137 +++++++++++--------------------- aggregator_core/src/task.rs | 3 +- docs/samples/tasks.yaml | 8 +- 3 files changed, 53 insertions(+), 95 deletions(-) diff --git a/aggregator/src/bin/janus_cli.rs b/aggregator/src/bin/janus_cli.rs index 3bcd86ee7..e783e9977 100644 --- a/aggregator/src/bin/janus_cli.rs +++ b/aggregator/src/bin/janus_cli.rs @@ -9,7 +9,7 @@ use janus_aggregator::{ }; use janus_aggregator_core::{ datastore::{self, Datastore}, - task::{SerializedTask, Task}, + task::{AggregatorTask, SerializedAggregatorTask}, }; use janus_core::time::{Clock, RealClock}; use k8s_openapi::api::core::v1::Secret; @@ -155,9 +155,9 @@ async fn provision_tasks( tasks_file: &Path, generate_missing_parameters: bool, dry_run: bool, -) -> Result> { +) -> Result> { // Read tasks file. - let tasks: Vec = { + let tasks: Vec = { let task_file_contents = fs::read_to_string(tasks_file) .await .with_context(|| format!("couldn't read tasks file {tasks_file:?}"))?; @@ -165,14 +165,14 @@ async fn provision_tasks( .with_context(|| format!("couldn't parse tasks file {tasks_file:?}"))? }; - let tasks: Vec = tasks + let tasks: Vec = tasks .into_iter() .map(|mut task| { if generate_missing_parameters { task.generate_missing_fields(); } - Task::try_from(task) + AggregatorTask::try_from(task) }) .collect::>()?; @@ -201,7 +201,7 @@ async fn provision_tasks( err => err?, } - tx.put_task(task).await?; + tx.put_aggregator_task(task).await?; written_tasks.push(task.clone()); } @@ -464,7 +464,7 @@ mod tests { }; use janus_aggregator_core::{ datastore::{test_util::ephemeral_datastore, Datastore}, - task::{test_util::TaskBuilder, QueryType, Task}, + task::{test_util::NewTaskBuilder as TaskBuilder, AggregatorTask, QueryType}, }; use janus_core::{ test_util::{kubernetes, roundtrip_encoding}, @@ -555,15 +555,15 @@ mod tests { .unwrap_err(); } - fn task_hashmap_from_slice(tasks: Vec) -> HashMap { + fn task_hashmap_from_slice(tasks: Vec) -> HashMap { tasks.into_iter().map(|task| (*task.id(), task)).collect() } async fn run_provision_tasks_testcase( ds: &Datastore, - tasks: &[Task], + tasks: &[AggregatorTask], dry_run: bool, - ) -> Vec { + ) -> Vec { // Write tasks to a temporary file. let mut tasks_file = NamedTempFile::new().unwrap(); tasks_file @@ -583,18 +583,14 @@ mod tests { let ds = ephemeral_datastore.datastore(RealClock::default()).await; let tasks = Vec::from([ - TaskBuilder::new( - QueryType::TimeInterval, - VdafInstance::Prio3Count, - Role::Leader, - ) - .build(), - TaskBuilder::new( - QueryType::TimeInterval, - VdafInstance::Prio3Sum { bits: 64 }, - Role::Helper, - ) - .build(), + TaskBuilder::new(QueryType::TimeInterval, VdafInstance::Prio3Count) + .build() + .leader_view() + .unwrap(), + TaskBuilder::new(QueryType::TimeInterval, VdafInstance::Prio3Sum { bits: 64 }) + .build() + .helper_view() + .unwrap(), ]); let written_tasks = run_provision_tasks_testcase(&ds, &tasks, false).await; @@ -603,30 +599,12 @@ mod tests { let want_tasks = task_hashmap_from_slice(tasks); let written_tasks = task_hashmap_from_slice(written_tasks); let got_tasks = task_hashmap_from_slice( - ds.run_tx(|tx| Box::pin(async move { tx.get_tasks().await })) + ds.run_tx(|tx| Box::pin(async move { tx.get_aggregator_tasks().await })) .await .unwrap(), ); - assert_eq!( - want_tasks - .iter() - .map(|(k, v)| { (*k, v.view_for_role().unwrap()) }) - .collect::>(), - got_tasks - .iter() - .map(|(k, v)| { (*k, v.view_for_role().unwrap()) }) - .collect() - ); - assert_eq!( - want_tasks - .iter() - .map(|(k, v)| { (*k, v.view_for_role().unwrap()) }) - .collect::>(), - written_tasks - .iter() - .map(|(k, v)| { (*k, v.view_for_role().unwrap()) }) - .collect() - ); + assert_eq!(want_tasks, got_tasks); + assert_eq!(want_tasks, written_tasks); } #[tokio::test] @@ -634,12 +612,13 @@ mod tests { let ephemeral_datastore = ephemeral_datastore().await; let ds = ephemeral_datastore.datastore(RealClock::default()).await; - let tasks = Vec::from([TaskBuilder::new( - QueryType::TimeInterval, - VdafInstance::Prio3Count, - Role::Leader, - ) - .build()]); + let tasks = + Vec::from([ + TaskBuilder::new(QueryType::TimeInterval, VdafInstance::Prio3Count) + .build() + .leader_view() + .unwrap(), + ]); let written_tasks = run_provision_tasks_testcase(&ds, &tasks, true).await; @@ -647,7 +626,7 @@ mod tests { let written_tasks = task_hashmap_from_slice(written_tasks); assert_eq!(want_tasks, written_tasks); let got_tasks = task_hashmap_from_slice( - ds.run_tx(|tx| Box::pin(async move { tx.get_tasks().await })) + ds.run_tx(|tx| Box::pin(async move { tx.get_aggregator_tasks().await })) .await .unwrap(), ); @@ -657,18 +636,14 @@ mod tests { #[tokio::test] async fn replace_task() { let tasks = Vec::from([ - TaskBuilder::new( - QueryType::TimeInterval, - VdafInstance::Prio3Count, - Role::Leader, - ) - .build(), - TaskBuilder::new( - QueryType::TimeInterval, - VdafInstance::Prio3Sum { bits: 64 }, - Role::Helper, - ) - .build(), + TaskBuilder::new(QueryType::TimeInterval, VdafInstance::Prio3Count) + .build() + .leader_view() + .unwrap(), + TaskBuilder::new(QueryType::TimeInterval, VdafInstance::Prio3Sum { bits: 64 }) + .build() + .leader_view() + .unwrap(), ]); let ephemeral_datastore = ephemeral_datastore().await; @@ -693,10 +668,11 @@ mod tests { length: 4, chunk_length: 2, }, - Role::Leader, ) .with_id(*tasks[0].id()) - .build(); + .build() + .leader_view() + .unwrap(); let mut replacement_tasks_file = NamedTempFile::new().unwrap(); replacement_tasks_file @@ -716,25 +692,16 @@ mod tests { // Verify that the expected tasks were written. let got_tasks = task_hashmap_from_slice( - ds.run_tx(|tx| Box::pin(async move { tx.get_tasks().await })) + ds.run_tx(|tx| Box::pin(async move { tx.get_aggregator_tasks().await })) .await .unwrap(), ); let want_tasks = HashMap::from([ - ( - *replacement_task.id(), - replacement_task.view_for_role().unwrap(), - ), - (*tasks[1].id(), tasks[1].view_for_role().unwrap()), + (*replacement_task.id(), replacement_task), + (*tasks[1].id(), tasks[1].clone()), ]); - assert_eq!( - want_tasks, - got_tasks - .iter() - .map(|(k, v)| { (*k, v.view_for_role().unwrap()) }) - .collect() - ); + assert_eq!(want_tasks, got_tasks); } #[tokio::test] @@ -742,8 +709,7 @@ mod tests { // YAML contains no task ID, VDAF verify keys, aggregator auth tokens, collector auth tokens // or HPKE keys. let serialized_task_yaml = r#" -- leader_aggregator_endpoint: https://leader - helper_aggregator_endpoint: https://helper +- peer_aggregator_endpoint: https://helper query_type: TimeInterval vdaf: !Prio3Sum bits: 2 @@ -764,8 +730,7 @@ mod tests { aggregator_auth_token: collector_auth_token: hpke_keys: [] -- leader_aggregator_endpoint: https://leader - helper_aggregator_endpoint: https://helper +- peer_aggregator_endpoint: https://leader query_type: TimeInterval vdaf: !Prio3Sum bits: 2 @@ -822,7 +787,7 @@ mod tests { // Verify that the expected tasks were written. let got_tasks = ds - .run_tx(|tx| Box::pin(async move { tx.get_tasks().await })) + .run_tx(|tx| Box::pin(async move { tx.get_aggregator_tasks().await })) .await .unwrap(); @@ -837,14 +802,8 @@ mod tests { } assert_eq!( - task_hashmap_from_slice(written_tasks) - .iter() - .map(|(k, v)| { (*k, v.view_for_role().unwrap()) }) - .collect::>(), + task_hashmap_from_slice(written_tasks), task_hashmap_from_slice(got_tasks) - .iter() - .map(|(k, v)| { (*k, v.view_for_role().unwrap()) }) - .collect() ); } diff --git a/aggregator_core/src/task.rs b/aggregator_core/src/task.rs index 8768c481d..d3a5b3e99 100644 --- a/aggregator_core/src/task.rs +++ b/aggregator_core/src/task.rs @@ -2257,7 +2257,8 @@ mod tests { #[test] fn deserialize_docs_sample_tasks() { - serde_yaml::from_str::>(include_str!("../../docs/samples/tasks.yaml")).unwrap(); + serde_yaml::from_str::>(include_str!("../../docs/samples/tasks.yaml")) + .unwrap(); } #[test] diff --git a/docs/samples/tasks.yaml b/docs/samples/tasks.yaml index 88d94a05a..44fe5ebff 100644 --- a/docs/samples/tasks.yaml +++ b/docs/samples/tasks.yaml @@ -6,9 +6,8 @@ # DAP's recommendation. task_id: "G9YKXjoEjfoU7M_fi_o2H0wmzavRb2sBFHeykeRhDMk" - # HTTPS endpoints of the leader and helper aggregators. - leader_aggregator_endpoint: "https://example.com/" - helper_aggregator_endpoint: "https://example.net/" + # HTTPS endpoint of the peer aggregator. + peer_aggregator_endpoint: "https://example.com/" # The DAP query type. See below for an example of a fixed-size task query_type: TimeInterval @@ -98,8 +97,7 @@ private_key: wFRYwiypcHC-mkGP1u3XQgIvtnlkQlUfZjgtM_zRsnI - task_id: "D-hCKPuqL2oTf7ZVRVyMP5VGt43EAEA8q34mDf6p1JE" - leader_aggregator_endpoint: "https://example.org/" - helper_aggregator_endpoint: "https://example.com/" + peer_aggregator_endpoint: "https://example.org/" # For tasks using the fixed size query type, an additional `max_batch_size` # parameter must be provided. query_type: !FixedSize