From e55cb258ca571955f0adfc796499ad312eaaade7 Mon Sep 17 00:00:00 2001 From: Tim Geoghegan Date: Thu, 28 Sep 2023 15:51:51 -0700 Subject: [PATCH 1/3] Task rewrite: adopt `AggregatorTask` in datastore Adopts `janus_aggregator_core::task::test_util::NewTaskBuilder` and `janus_aggregator_core::task::AggregatorTask` in the `janus_aggregator_core::datastore` module. Much as the previous change provides two kinds of `Task` structure, we now provide two sets of methods for reading and writing tasks: one that deals in the new `AggregatorTask` and the other which deals in the old `Task`. We add routines for converting between `task::Task` and `task::AggregatorTask` to make it easier for these two paths through the datastore to co-exist. This conversion is lossy because `AggregatorTask` only retains one of the aggregator endpoints, but this doesn't cause substantial problems in Janus, and we can live it transitionally. Finally, the SQL schema for tasks is changed so that only the peer aggregator's endpoint is stored. Part of #1524 --- aggregator/src/aggregator/http_handlers.rs | 2 +- aggregator/src/aggregator/taskprov_tests.rs | 4 +- aggregator_core/src/datastore.rs | 170 ++++--- aggregator_core/src/datastore/tests.rs | 533 ++++++++++---------- aggregator_core/src/task.rs | 149 ++++++ db/00000000000001_initial_schema.up.sql | 3 +- 6 files changed, 518 insertions(+), 343 deletions(-) diff --git a/aggregator/src/aggregator/http_handlers.rs b/aggregator/src/aggregator/http_handlers.rs index 716a2355c..5c3c6b208 100644 --- a/aggregator/src/aggregator/http_handlers.rs +++ b/aggregator/src/aggregator/http_handlers.rs @@ -971,7 +971,7 @@ mod tests { let task = TaskBuilder::new( QueryType::TimeInterval, VdafInstance::Prio3Count, - Role::Leader, + Role::Helper, ) .build(); let task_id = *task.id(); diff --git a/aggregator/src/aggregator/taskprov_tests.rs b/aggregator/src/aggregator/taskprov_tests.rs index 772623eb7..39dcdb778 100644 --- a/aggregator/src/aggregator/taskprov_tests.rs +++ b/aggregator/src/aggregator/taskprov_tests.rs @@ -333,7 +333,9 @@ async fn taskprov_aggregate_init() { .state() .eq(&AggregationJobState::InProgress) ); - assert_eq!(test.task, got_task.unwrap()); + // TODO(#1524): This assertion temporarily just checks the task ID because of the lossy + // conversion between task::Task and task::AggregatorTask. + assert_eq!(test.task.id(), got_task.unwrap().id()); } #[tokio::test] diff --git a/aggregator_core/src/datastore.rs b/aggregator_core/src/datastore.rs index 8ac957579..a6ae9ee9d 100644 --- a/aggregator_core/src/datastore.rs +++ b/aggregator_core/src/datastore.rs @@ -9,8 +9,8 @@ use self::models::{ }; use crate::{ query_type::{AccumulableQueryType, CollectableQueryType}, - task::{self, Task}, - taskprov::{self, PeerAggregator}, + task::{self, AggregatorTask, AggregatorTaskParameters, Task}, + taskprov::PeerAggregator, SecretBytes, }; use chrono::NaiveDateTime; @@ -306,6 +306,7 @@ impl Datastore { } /// Write a task into the datastore. + // TODO(#1524): remove this once everything has migrated to put_aggregator_task #[cfg(feature = "test-util")] #[cfg_attr(docsrs, doc(cfg(feature = "test-util")))] pub async fn put_task(&self, task: &Task) -> Result<(), Error> { @@ -315,6 +316,17 @@ impl Datastore { }) .await } + + /// Write a task into the datastore. + #[cfg(feature = "test-util")] + #[cfg_attr(docsrs, doc(cfg(feature = "test-util")))] + pub async fn put_aggregator_task(&self, task: &AggregatorTask) -> Result<(), Error> { + self.run_tx(|tx| { + let task = task.clone(); + Box::pin(async move { tx.put_aggregator_task(&task).await }) + }) + .await + } } fn check_error( @@ -525,20 +537,34 @@ impl Transaction<'_, C> { } /// Writes a task into the datastore. + // TODO(#1524): remove this once everything has migrated to put_aggregator_task #[tracing::instrument(skip(self, task), fields(task_id = ?task.id()), err)] pub async fn put_task(&self, task: &Task) -> Result<(), Error> { + let aggregator_task = match task.role() { + Role::Leader => task.leader_view()?, + Role::Helper => task + .helper_view() + .or_else(|_| task.taskprov_helper_view())?, + _ => return Err(Error::InvalidParameter("role must be aggregator")), + }; + + self.put_aggregator_task(&aggregator_task).await + } + + /// Writes a task into the datastore. + #[tracing::instrument(skip(self, task), fields(task_id = ?task.id()), err)] + pub async fn put_aggregator_task(&self, task: &AggregatorTask) -> Result<(), Error> { // Main task insert. let stmt = self .prepare_cached( "INSERT INTO tasks ( - task_id, aggregator_role, leader_aggregator_endpoint, - helper_aggregator_endpoint, query_type, vdaf, max_batch_query_count, - task_expiration, report_expiry_age, min_batch_size, time_precision, - tolerable_clock_skew, collector_hpke_config, vdaf_verify_key, + task_id, aggregator_role, peer_aggregator_endpoint, query_type, vdaf, + max_batch_query_count, task_expiration, report_expiry_age, min_batch_size, + time_precision, tolerable_clock_skew, collector_hpke_config, vdaf_verify_key, aggregator_auth_token_type, aggregator_auth_token, collector_auth_token_type, collector_auth_token) VALUES ( - $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18 + $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17 ) ON CONFLICT DO NOTHING", ) @@ -549,10 +575,8 @@ impl Transaction<'_, C> { &[ /* task_id */ &task.id().as_ref(), /* aggregator_role */ &AggregatorRole::from_role(*task.role())?, - /* leader_aggregator_endpoint */ - &task.leader_aggregator_endpoint().as_str(), - /* helper_aggregator_endpoint */ - &task.helper_aggregator_endpoint().as_str(), + /* peer_aggregator_endpoint */ + &task.peer_aggregator_endpoint().as_str(), /* query_type */ &Json(task.query_type()), /* vdaf */ &Json(task.vdaf()), /* max_batch_query_count */ @@ -574,9 +598,7 @@ impl Transaction<'_, C> { /* tolerable_clock_skew */ &i64::try_from(task.tolerable_clock_skew().as_seconds())?, /* collector_hpke_config */ - &task - .collector_hpke_config() - .map(|config| config.get_encoded()), + &task.collector_hpke_config().map(|cfg| cfg.get_encoded()), /* vdaf_verify_key */ &self.crypter.encrypt( "tasks", @@ -625,6 +647,7 @@ impl Transaction<'_, C> { let mut hpke_config_ids: Vec = Vec::new(); let mut hpke_configs: Vec> = Vec::new(); let mut hpke_private_keys: Vec> = Vec::new(); + for hpke_keypair in task.hpke_keys().values() { let mut row_id = [0u8; TaskId::LEN + size_of::()]; row_id[..TaskId::LEN].copy_from_slice(task.id().as_ref()); @@ -677,16 +700,26 @@ impl Transaction<'_, C> { } /// Fetch the task parameters corresponing to the provided `task_id`. + // TODO(#1524): remove this once everything has migrated to get_aggregator_task #[tracing::instrument(skip(self), err)] pub async fn get_task(&self, task_id: &TaskId) -> Result, Error> { + Ok(self.get_aggregator_task(task_id).await?.map(Task::from)) + } + + /// Fetch the task parameters corresponing to the provided `task_id`. + #[tracing::instrument(skip(self), err)] + pub async fn get_aggregator_task( + &self, + task_id: &TaskId, + ) -> Result, Error> { let params: &[&(dyn ToSql + Sync)] = &[&task_id.as_ref()]; let stmt = self .prepare_cached( - "SELECT aggregator_role, leader_aggregator_endpoint, helper_aggregator_endpoint, - query_type, vdaf, max_batch_query_count, task_expiration, report_expiry_age, - min_batch_size, time_precision, tolerable_clock_skew, collector_hpke_config, - vdaf_verify_key, aggregator_auth_token_type, aggregator_auth_token, - collector_auth_token_type, collector_auth_token + "SELECT aggregator_role, peer_aggregator_endpoint, query_type, vdaf, + max_batch_query_count, task_expiration, report_expiry_age, min_batch_size, + time_precision, tolerable_clock_skew, collector_hpke_config, vdaf_verify_key, + aggregator_auth_token_type, aggregator_auth_token, collector_auth_token_type, + collector_auth_token FROM tasks WHERE task_id = $1", ) .await?; @@ -707,14 +740,25 @@ impl Transaction<'_, C> { } /// Fetch all the tasks in the database. + // TODO(#1524): remove this once everything has migrated to get_aggregator_tasks #[tracing::instrument(skip(self), err)] pub async fn get_tasks(&self) -> Result, Error> { + Ok(self + .get_aggregator_tasks() + .await? + .into_iter() + .map(Task::from) + .collect()) + } + + /// Fetch all the tasks in the database. + #[tracing::instrument(skip(self), err)] + pub async fn get_aggregator_tasks(&self) -> Result, Error> { let stmt = self .prepare_cached( - "SELECT task_id, aggregator_role, leader_aggregator_endpoint, - helper_aggregator_endpoint, query_type, vdaf, max_batch_query_count, - task_expiration, report_expiry_age, min_batch_size, time_precision, - tolerable_clock_skew, collector_hpke_config, vdaf_verify_key, + "SELECT task_id, aggregator_role, peer_aggregator_endpoint, query_type, vdaf, + max_batch_query_count, task_expiration, report_expiry_age, min_batch_size, + time_precision, tolerable_clock_skew, collector_hpke_config, vdaf_verify_key, aggregator_auth_token_type, aggregator_auth_token, collector_auth_token_type, collector_auth_token FROM tasks", @@ -768,13 +812,10 @@ impl Transaction<'_, C> { task_id: &TaskId, row: &Row, hpke_key_rows: &[Row], - ) -> Result { + ) -> Result { // Scalar task parameters. let aggregator_role: AggregatorRole = row.get("aggregator_role"); - let leader_aggregator_endpoint = - row.get::<_, String>("leader_aggregator_endpoint").parse()?; - let helper_aggregator_endpoint = - row.get::<_, String>("helper_aggregator_endpoint").parse()?; + let peer_aggregator_endpoint = row.get::<_, String>("peer_aggregator_endpoint").parse()?; let query_type = row.try_get::<_, Json>("query_type")?.0; let vdaf = row.try_get::<_, Json>("vdaf")?.0; let max_batch_query_count = row.get_bigint_and_convert("max_batch_query_count")?; @@ -831,7 +872,7 @@ impl Transaction<'_, C> { .transpose()?; // HPKE keys. - let mut hpke_keypairs = Vec::new(); + let mut hpke_keys = Vec::new(); for row in hpke_key_rows { let config_id = u8::try_from(row.get::<_, i16>("config_id"))?; let config = HpkeConfig::get_decoded(row.get("config"))?; @@ -848,16 +889,47 @@ impl Transaction<'_, C> { &encrypted_private_key, )?); - hpke_keypairs.push(HpkeKeypair::new(config, private_key)); + hpke_keys.push(HpkeKeypair::new(config, private_key)); } - let task = Task::new_without_validation( + let aggregator_parameters = match ( + aggregator_role, + aggregator_auth_token, + collector_auth_token, + collector_hpke_config, + ) { + ( + AggregatorRole::Leader, + Some(aggregator_auth_token), + Some(collector_auth_token), + Some(collector_hpke_config), + ) => AggregatorTaskParameters::Leader { + aggregator_auth_token, + collector_auth_token, + collector_hpke_config, + }, + ( + AggregatorRole::Helper, + Some(aggregator_auth_token), + None, + Some(collector_hpke_config), + ) => AggregatorTaskParameters::Helper { + aggregator_auth_token, + collector_hpke_config, + }, + (AggregatorRole::Helper, None, None, None) => AggregatorTaskParameters::TaskProvHelper, + values => { + return Err(Error::DbState(format!( + "found task row with unexpected combination of values {values:?}", + ))); + } + }; + + Ok(AggregatorTask::new( *task_id, - leader_aggregator_endpoint, - helper_aggregator_endpoint, + peer_aggregator_endpoint, query_type, vdaf, - aggregator_role.as_role(), vdaf_verify_key, max_batch_query_count, task_expiration, @@ -865,33 +937,9 @@ impl Transaction<'_, C> { min_batch_size, time_precision, tolerable_clock_skew, - collector_hpke_config, - aggregator_auth_token, - collector_auth_token, - hpke_keypairs, - ); - // Trial validation through all known schemes. This is a workaround to avoid extending the - // schema to track the provenance of tasks. If we do end up implementing a task provenance - // column anyways, we can simplify this logic. - task.validate().or_else(|error| { - taskprov::Task(task.clone()) - .validate() - .map_err(|taskprov_error| { - error!( - %task_id, - %error, - %taskprov_error, - ?task, - "task has failed all available validation checks", - ); - // Choose some error to bubble up to the caller. Either way this error - // occurring is an indication of a bug, which we'll need to go into the - // logs for. - error - }) - })?; - - Ok(task) + hpke_keys, + aggregator_parameters, + )?) } /// Retrieves report & report aggregation metrics for a given task: either a tuple diff --git a/aggregator_core/src/datastore/tests.rs b/aggregator_core/src/datastore/tests.rs index e03ddd5ae..aea662d65 100644 --- a/aggregator_core/src/datastore/tests.rs +++ b/aggregator_core/src/datastore/tests.rs @@ -11,7 +11,7 @@ use crate::{ Crypter, Datastore, Error, Transaction, SUPPORTED_SCHEMA_VERSIONS, }, query_type::CollectableQueryType, - task::{self, test_util::TaskBuilder, Task}, + task::{self, test_util::NewTaskBuilder as TaskBuilder, AggregatorTask, Task}, taskprov::test_util::PeerAggregatorBuilder, test_util::noop_meter, }; @@ -143,9 +143,11 @@ async fn roundtrip_task(ephemeral_datastore: EphemeralDatastore) { (VdafInstance::Poplar1 { bits: 8 }, Role::Helper), (VdafInstance::Poplar1 { bits: 64 }, Role::Helper), ] { - let task = TaskBuilder::new(task::QueryType::TimeInterval, vdaf, role) + let task = TaskBuilder::new(task::QueryType::TimeInterval, vdaf) .with_report_expiry_age(Some(Duration::from_seconds(3600))) - .build(); + .build() + .view_for_role(role) + .unwrap(); want_tasks.insert(*task.id(), task.clone()); let err = ds @@ -160,18 +162,18 @@ async fn roundtrip_task(ephemeral_datastore: EphemeralDatastore) { let retrieved_task = ds .run_tx(|tx| { let task = task.clone(); - Box::pin(async move { tx.get_task(task.id()).await }) + Box::pin(async move { tx.get_aggregator_task(task.id()).await }) }) .await .unwrap(); assert_eq!(None, retrieved_task); - ds.put_task(&task).await.unwrap(); + ds.put_aggregator_task(&task).await.unwrap(); let retrieved_task = ds .run_tx(|tx| { let task = task.clone(); - Box::pin(async move { tx.get_task(task.id()).await }) + Box::pin(async move { tx.get_aggregator_task(task.id()).await }) }) .await .unwrap(); @@ -187,7 +189,7 @@ async fn roundtrip_task(ephemeral_datastore: EphemeralDatastore) { let retrieved_task = ds .run_tx(|tx| { let task = task.clone(); - Box::pin(async move { tx.get_task(task.id()).await }) + Box::pin(async move { tx.get_aggregator_task(task.id()).await }) }) .await .unwrap(); @@ -205,20 +207,20 @@ async fn roundtrip_task(ephemeral_datastore: EphemeralDatastore) { // Rewrite & retrieve the task again, to test that the delete is "clean" in the sense // that it deletes all task-related data (& therefore does not conflict with a later // write to the same task_id). - ds.put_task(&task).await.unwrap(); + ds.put_aggregator_task(&task).await.unwrap(); let retrieved_task = ds .run_tx(|tx| { let task = task.clone(); - Box::pin(async move { tx.get_task(task.id()).await }) + Box::pin(async move { tx.get_aggregator_task(task.id()).await }) }) .await .unwrap(); assert_eq!(Some(task), retrieved_task); } - let got_tasks: HashMap = ds - .run_tx(|tx| Box::pin(async move { tx.get_tasks().await })) + let got_tasks: HashMap = ds + .run_tx(|tx| Box::pin(async move { tx.get_aggregator_tasks().await })) .await .unwrap() .into_iter() @@ -233,14 +235,12 @@ async fn put_task_invalid_aggregator_auth_tokens(ephemeral_datastore: EphemeralD install_test_trace_subscriber(); let ds = ephemeral_datastore.datastore(MockClock::default()).await; - let task = TaskBuilder::new( - task::QueryType::TimeInterval, - VdafInstance::Prio3Count, - Role::Leader, - ) - .build(); + let task = TaskBuilder::new(task::QueryType::TimeInterval, VdafInstance::Prio3Count) + .build() + .leader_view() + .unwrap(); - ds.put_task(&task).await.unwrap(); + ds.put_aggregator_task(&task).await.unwrap(); for (auth_token, auth_token_type) in [("NULL", "'BEARER'"), ("'\\xDEADBEEF'::bytea", "NULL")] { ds.run_tx(|tx| { @@ -274,14 +274,12 @@ async fn put_task_invalid_collector_auth_tokens(ephemeral_datastore: EphemeralDa install_test_trace_subscriber(); let ds = ephemeral_datastore.datastore(MockClock::default()).await; - let task = TaskBuilder::new( - task::QueryType::TimeInterval, - VdafInstance::Prio3Count, - Role::Leader, - ) - .build(); + let task = TaskBuilder::new(task::QueryType::TimeInterval, VdafInstance::Prio3Count) + .build() + .leader_view() + .unwrap(); - ds.put_task(&task).await.unwrap(); + ds.put_aggregator_task(&task).await.unwrap(); for (auth_token, auth_token_type) in [("NULL", "'BEARER'"), ("'\\xDEADBEEF'::bytea", "NULL")] { ds.run_tx(|tx| { @@ -323,19 +321,16 @@ async fn get_task_metrics(ephemeral_datastore: EphemeralDatastore) { let task_id = ds .run_tx(|tx| { Box::pin(async move { - let task = TaskBuilder::new( - task::QueryType::TimeInterval, - VdafInstance::Fake, - Role::Leader, - ) - .with_report_expiry_age(Some(REPORT_EXPIRY_AGE)) - .build(); - let other_task = TaskBuilder::new( - task::QueryType::TimeInterval, - VdafInstance::Fake, - Role::Leader, - ) - .build(); + let task = TaskBuilder::new(task::QueryType::TimeInterval, VdafInstance::Fake) + .with_report_expiry_age(Some(REPORT_EXPIRY_AGE)) + .build() + .leader_view() + .unwrap(); + let other_task = + TaskBuilder::new(task::QueryType::TimeInterval, VdafInstance::Fake) + .build() + .leader_view() + .unwrap(); let reports: Vec<_> = iter::repeat_with(|| { LeaderStoredReport::new_dummy(*task.id(), OLDEST_ALLOWED_REPORT_TIMESTAMP) @@ -458,8 +453,8 @@ async fn get_task_metrics(ephemeral_datastore: EphemeralDatastore) { }) .collect(); - tx.put_task(&task).await?; - tx.put_task(&other_task).await?; + tx.put_aggregator_task(&task).await?; + tx.put_aggregator_task(&other_task).await?; try_join_all( reports .iter() @@ -524,20 +519,18 @@ async fn get_task_ids(ephemeral_datastore: EphemeralDatastore) { Box::pin(async move { const TOTAL_TASK_ID_COUNT: usize = 20; let tasks: Vec<_> = iter::repeat_with(|| { - TaskBuilder::new( - task::QueryType::TimeInterval, - VdafInstance::Fake, - Role::Leader, - ) - .build() + TaskBuilder::new(task::QueryType::TimeInterval, VdafInstance::Fake) + .build() + .leader_view() + .unwrap() }) .take(TOTAL_TASK_ID_COUNT) .collect(); - let mut task_ids: Vec<_> = tasks.iter().map(Task::id).cloned().collect(); + let mut task_ids: Vec<_> = tasks.iter().map(AggregatorTask::id).cloned().collect(); task_ids.sort(); - try_join_all(tasks.iter().map(|task| tx.put_task(task))).await?; + try_join_all(tasks.iter().map(|task| tx.put_aggregator_task(task))).await?; for (i, lower_bound) in iter::once(None) .chain(task_ids.iter().cloned().map(Some)) @@ -565,17 +558,15 @@ async fn roundtrip_report(ephemeral_datastore: EphemeralDatastore) { .difference(&OLDEST_ALLOWED_REPORT_TIMESTAMP) .unwrap(); - let task = TaskBuilder::new( - task::QueryType::TimeInterval, - VdafInstance::Fake, - Role::Leader, - ) - .with_report_expiry_age(Some(report_expiry_age)) - .build(); + let task = TaskBuilder::new(task::QueryType::TimeInterval, VdafInstance::Fake) + .with_report_expiry_age(Some(report_expiry_age)) + .build() + .leader_view() + .unwrap(); ds.run_tx(|tx| { let task = task.clone(); - Box::pin(async move { tx.put_task(&task).await }) + Box::pin(async move { tx.put_aggregator_task(&task).await }) }) .await .unwrap(); @@ -711,19 +702,15 @@ async fn get_unaggregated_client_report_ids_for_task(ephemeral_datastore: Epheme Duration::from_seconds(2), ) .unwrap(); - let task = TaskBuilder::new( - task::QueryType::TimeInterval, - VdafInstance::Prio3Count, - Role::Leader, - ) - .with_report_expiry_age(Some(REPORT_EXPIRY_AGE)) - .build(); - let unrelated_task = TaskBuilder::new( - task::QueryType::TimeInterval, - VdafInstance::Prio3Count, - Role::Leader, - ) - .build(); + let task = TaskBuilder::new(task::QueryType::TimeInterval, VdafInstance::Prio3Count) + .with_report_expiry_age(Some(REPORT_EXPIRY_AGE)) + .build() + .leader_view() + .unwrap(); + let unrelated_task = TaskBuilder::new(task::QueryType::TimeInterval, VdafInstance::Prio3Count) + .build() + .leader_view() + .unwrap(); let first_unaggregated_report = LeaderStoredReport::new_dummy(*task.id(), OLDEST_ALLOWED_REPORT_TIMESTAMP); @@ -751,8 +738,8 @@ async fn get_unaggregated_client_report_ids_for_task(ephemeral_datastore: Epheme let unrelated_report = unrelated_report.clone(); Box::pin(async move { - tx.put_task(&task).await?; - tx.put_task(&unrelated_task).await?; + tx.put_aggregator_task(&task).await?; + tx.put_aggregator_task(&unrelated_task).await?; tx.put_client_report(&dummy_vdaf::Vdaf::new(), &first_unaggregated_report) .await?; @@ -879,25 +866,19 @@ async fn count_client_reports_for_interval(ephemeral_datastore: EphemeralDatasto let clock = MockClock::new(OLDEST_ALLOWED_REPORT_TIMESTAMP); let ds = ephemeral_datastore.datastore(clock.clone()).await; - let task = TaskBuilder::new( - task::QueryType::TimeInterval, - VdafInstance::Fake, - Role::Leader, - ) - .with_report_expiry_age(Some(REPORT_EXPIRY_AGE)) - .build(); - let unrelated_task = TaskBuilder::new( - task::QueryType::TimeInterval, - VdafInstance::Fake, - Role::Leader, - ) - .build(); - let no_reports_task = TaskBuilder::new( - task::QueryType::TimeInterval, - VdafInstance::Fake, - Role::Leader, - ) - .build(); + let task = TaskBuilder::new(task::QueryType::TimeInterval, VdafInstance::Fake) + .with_report_expiry_age(Some(REPORT_EXPIRY_AGE)) + .build() + .leader_view() + .unwrap(); + let unrelated_task = TaskBuilder::new(task::QueryType::TimeInterval, VdafInstance::Fake) + .build() + .leader_view() + .unwrap(); + let no_reports_task = TaskBuilder::new(task::QueryType::TimeInterval, VdafInstance::Fake) + .build() + .leader_view() + .unwrap(); let expired_report_in_interval = LeaderStoredReport::new_dummy( *task.id(), @@ -934,9 +915,9 @@ async fn count_client_reports_for_interval(ephemeral_datastore: EphemeralDatasto let report_for_other_task = report_for_other_task.clone(); Box::pin(async move { - tx.put_task(&task).await?; - tx.put_task(&unrelated_task).await?; - tx.put_task(&no_reports_task).await?; + tx.put_aggregator_task(&task).await?; + tx.put_aggregator_task(&unrelated_task).await?; + tx.put_aggregator_task(&no_reports_task).await?; tx.put_client_report(&dummy_vdaf::Vdaf::new(), &expired_report_in_interval) .await?; @@ -1011,19 +992,21 @@ async fn count_client_reports_for_batch_id(ephemeral_datastore: EphemeralDatasto batch_time_window_size: None, }, VdafInstance::Fake, - Role::Leader, ) .with_report_expiry_age(Some(REPORT_EXPIRY_AGE)) - .build(); + .build() + .leader_view() + .unwrap(); let unrelated_task = TaskBuilder::new( task::QueryType::FixedSize { max_batch_size: 10, batch_time_window_size: None, }, VdafInstance::Fake, - Role::Leader, ) - .build(); + .build() + .leader_view() + .unwrap(); // Set up state. let batch_id = ds @@ -1031,8 +1014,8 @@ async fn count_client_reports_for_batch_id(ephemeral_datastore: EphemeralDatasto let (task, unrelated_task) = (task.clone(), unrelated_task.clone()); Box::pin(async move { - tx.put_task(&task).await?; - tx.put_task(&unrelated_task).await?; + tx.put_aggregator_task(&task).await?; + tx.put_aggregator_task(&unrelated_task).await?; // Create a batch for the first task containing two reports, which has started // aggregation twice with two different aggregation parameters. @@ -1190,12 +1173,10 @@ async fn roundtrip_report_share(ephemeral_datastore: EphemeralDatastore) { install_test_trace_subscriber(); let ds = ephemeral_datastore.datastore(MockClock::default()).await; - let task = TaskBuilder::new( - task::QueryType::TimeInterval, - VdafInstance::Prio3Count, - Role::Leader, - ) - .build(); + let task = TaskBuilder::new(task::QueryType::TimeInterval, VdafInstance::Prio3Count) + .build() + .leader_view() + .unwrap(); let report_share = ReportShare::new( ReportMetadata::new( ReportId::from([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]), @@ -1212,7 +1193,7 @@ async fn roundtrip_report_share(ephemeral_datastore: EphemeralDatastore) { ds.run_tx(|tx| { let (task, report_share) = (task.clone(), report_share.clone()); Box::pin(async move { - tx.put_task(&task).await?; + tx.put_aggregator_task(&task).await?; tx.put_report_share(task.id(), &report_share).await?; Ok(()) @@ -1296,10 +1277,11 @@ async fn roundtrip_aggregation_job(ephemeral_datastore: EphemeralDatastore) { batch_time_window_size: None, }, VdafInstance::Fake, - Role::Leader, ) .with_report_expiry_age(Some(REPORT_EXPIRY_AGE)) - .build(); + .build() + .leader_view() + .unwrap(); let batch_id = random(); let leader_aggregation_job = AggregationJob::<0, FixedSize, dummy_vdaf::Vdaf>::new( *task.id(), @@ -1327,7 +1309,7 @@ async fn roundtrip_aggregation_job(ephemeral_datastore: EphemeralDatastore) { helper_aggregation_job.clone(), ); Box::pin(async move { - tx.put_task(&task).await.unwrap(); + tx.put_aggregator_task(&task).await.unwrap(); tx.put_aggregation_job(&leader_aggregation_job) .await .unwrap(); @@ -1507,13 +1489,11 @@ async fn aggregation_job_acquire_release(ephemeral_datastore: EphemeralDatastore let ds = Arc::new(ephemeral_datastore.datastore(clock.clone()).await); const AGGREGATION_JOB_COUNT: usize = 10; - let task = TaskBuilder::new( - task::QueryType::TimeInterval, - VdafInstance::Prio3Count, - Role::Leader, - ) - .with_report_expiry_age(Some(REPORT_EXPIRY_AGE)) - .build(); + let task = TaskBuilder::new(task::QueryType::TimeInterval, VdafInstance::Prio3Count) + .with_report_expiry_age(Some(REPORT_EXPIRY_AGE)) + .build() + .leader_view() + .unwrap(); let mut aggregation_job_ids: Vec<_> = thread_rng() .sample_iter(Standard) .take(AGGREGATION_JOB_COUNT) @@ -1525,7 +1505,7 @@ async fn aggregation_job_acquire_release(ephemeral_datastore: EphemeralDatastore Box::pin(async move { // Write a few aggregation jobs we expect to be able to retrieve with // acquire_incomplete_aggregation_jobs(). - tx.put_task(&task).await?; + tx.put_aggregator_task(&task).await?; try_join_all(aggregation_job_ids.into_iter().map(|aggregation_job_id| { let task_id = *task.id(); async move { @@ -1587,13 +1567,12 @@ async fn aggregation_job_acquire_release(ephemeral_datastore: EphemeralDatastore // Write an aggregation job for a task that we are taking on the helper role for. // We don't want to retrieve this one, either. - let helper_task = TaskBuilder::new( - task::QueryType::TimeInterval, - VdafInstance::Prio3Count, - Role::Helper, - ) - .build(); - tx.put_task(&helper_task).await?; + let helper_task = + TaskBuilder::new(task::QueryType::TimeInterval, VdafInstance::Prio3Count) + .build() + .helper_view() + .unwrap(); + tx.put_aggregator_task(&helper_task).await?; tx.put_aggregation_job( &AggregationJob::::new( *helper_task.id(), @@ -1892,9 +1871,10 @@ async fn get_aggregation_jobs_for_task(ephemeral_datastore: EphemeralDatastore) batch_time_window_size: None, }, VdafInstance::Fake, - Role::Leader, ) - .build(); + .build() + .leader_view() + .unwrap(); let first_aggregation_job = AggregationJob::<0, FixedSize, dummy_vdaf::Vdaf>::new( *task.id(), random(), @@ -1933,7 +1913,7 @@ async fn get_aggregation_jobs_for_task(ephemeral_datastore: EphemeralDatastore) ds.run_tx(|tx| { let (task, want_agg_jobs) = (task.clone(), want_agg_jobs.clone()); Box::pin(async move { - tx.put_task(&task).await?; + tx.put_aggregator_task(&task).await?; for agg_job in want_agg_jobs { tx.put_aggregation_job(&agg_job).await.unwrap(); @@ -1947,10 +1927,11 @@ async fn get_aggregation_jobs_for_task(ephemeral_datastore: EphemeralDatastore) batch_time_window_size: None, }, VdafInstance::Fake, - Role::Leader, ) - .build(); - tx.put_task(&unrelated_task).await?; + .build() + .leader_view() + .unwrap(); + tx.put_aggregator_task(&unrelated_task).await?; tx.put_aggregation_job(&AggregationJob::<0, FixedSize, dummy_vdaf::Vdaf>::new( *unrelated_task.id(), random(), @@ -2041,10 +2022,11 @@ async fn roundtrip_report_aggregation(ephemeral_datastore: EphemeralDatastore) { let task = TaskBuilder::new( task::QueryType::TimeInterval, VdafInstance::Poplar1 { bits: 1 }, - role, ) .with_report_expiry_age(Some(REPORT_EXPIRY_AGE)) - .build(); + .build() + .view_for_role(role) + .unwrap(); let aggregation_job_id = random(); let report_id = random(); @@ -2053,7 +2035,7 @@ async fn roundtrip_report_aggregation(ephemeral_datastore: EphemeralDatastore) { let (task, state, aggregation_param) = (task.clone(), state.clone(), aggregation_param.clone()); Box::pin(async move { - tx.put_task(&task).await?; + tx.put_aggregator_task(&task).await?; tx.put_aggregation_job(&AggregationJob::< VERIFY_KEY_LENGTH, TimeInterval, @@ -2210,15 +2192,13 @@ async fn check_other_report_aggregation_exists(ephemeral_datastore: EphemeralDat let clock = MockClock::new(OLDEST_ALLOWED_REPORT_TIMESTAMP); let ds = ephemeral_datastore.datastore(clock.clone()).await; - let task = TaskBuilder::new( - task::QueryType::TimeInterval, - VdafInstance::Fake, - Role::Helper, - ) - .with_report_expiry_age(Some(REPORT_EXPIRY_AGE)) - .build(); + let task = TaskBuilder::new(task::QueryType::TimeInterval, VdafInstance::Fake) + .with_report_expiry_age(Some(REPORT_EXPIRY_AGE)) + .build() + .helper_view() + .unwrap(); - ds.put_task(&task).await.unwrap(); + ds.put_aggregator_task(&task).await.unwrap(); let aggregation_job_id = random(); let report_id = random(); @@ -2427,10 +2407,11 @@ async fn get_report_aggregations_for_aggregation_job(ephemeral_datastore: Epheme let task = TaskBuilder::new( task::QueryType::TimeInterval, VdafInstance::Poplar1 { bits: 1 }, - Role::Helper, ) .with_report_expiry_age(Some(REPORT_EXPIRY_AGE)) - .build(); + .build() + .helper_view() + .unwrap(); let aggregation_job_id = random(); let want_report_aggregations = ds @@ -2441,7 +2422,7 @@ async fn get_report_aggregations_for_aggregation_job(ephemeral_datastore: Epheme aggregation_param.clone(), ); Box::pin(async move { - tx.put_task(&task).await.unwrap(); + tx.put_aggregator_task(&task).await.unwrap(); tx.put_aggregation_job(&AggregationJob::< VERIFY_KEY_LENGTH, @@ -2598,13 +2579,11 @@ async fn get_collection_job(ephemeral_datastore: EphemeralDatastore) { let clock = MockClock::new(OLDEST_ALLOWED_REPORT_TIMESTAMP); let ds = ephemeral_datastore.datastore(clock.clone()).await; - let task = TaskBuilder::new( - task::QueryType::TimeInterval, - VdafInstance::Fake, - Role::Leader, - ) - .with_report_expiry_age(Some(REPORT_EXPIRY_AGE)) - .build(); + let task = TaskBuilder::new(task::QueryType::TimeInterval, VdafInstance::Fake) + .with_report_expiry_age(Some(REPORT_EXPIRY_AGE)) + .build() + .leader_view() + .unwrap(); let first_batch_interval = Interval::new(OLDEST_ALLOWED_REPORT_TIMESTAMP, Duration::from_seconds(100)).unwrap(); let second_batch_interval = Interval::new( @@ -2620,7 +2599,7 @@ async fn get_collection_job(ephemeral_datastore: EphemeralDatastore) { .run_tx(|tx| { let task = task.clone(); Box::pin(async move { - tx.put_task(&task).await.unwrap(); + tx.put_aggregator_task(&task).await.unwrap(); let first_collection_job = CollectionJob::<0, TimeInterval, dummy_vdaf::Vdaf>::new( *task.id(), @@ -2767,12 +2746,10 @@ async fn update_collection_jobs(ephemeral_datastore: EphemeralDatastore) { let ds = ephemeral_datastore.datastore(MockClock::default()).await; - let task = TaskBuilder::new( - task::QueryType::TimeInterval, - VdafInstance::Fake, - Role::Leader, - ) - .build(); + let task = TaskBuilder::new(task::QueryType::TimeInterval, VdafInstance::Fake) + .build() + .leader_view() + .unwrap(); let abandoned_batch_interval = Interval::new( Time::from_seconds_since_epoch(100), Duration::from_seconds(100), @@ -2787,7 +2764,7 @@ async fn update_collection_jobs(ephemeral_datastore: EphemeralDatastore) { ds.run_tx(|tx| { let task = task.clone(); Box::pin(async move { - tx.put_task(&task).await?; + tx.put_aggregator_task(&task).await?; let vdaf = dummy_vdaf::Vdaf::new(); let aggregation_param = AggregationParam(10); @@ -3003,10 +2980,12 @@ async fn setup_collection_job_acquire_test_case( let mut test_case = test_case.clone(); Box::pin(async move { for task_id in &test_case.task_ids { - tx.put_task( - &TaskBuilder::new(test_case.query_type, VdafInstance::Fake, Role::Leader) + tx.put_aggregator_task( + &TaskBuilder::new(test_case.query_type, VdafInstance::Fake) .with_id(*task_id) - .build(), + .build() + .leader_view() + .unwrap(), ) .await?; } @@ -3988,20 +3967,16 @@ async fn roundtrip_batch_aggregation_time_interval(ephemeral_datastore: Ephemera let ds = ephemeral_datastore.datastore(clock.clone()).await; let time_precision = Duration::from_seconds(100); - let task = TaskBuilder::new( - task::QueryType::TimeInterval, - VdafInstance::Fake, - Role::Leader, - ) - .with_time_precision(time_precision) - .with_report_expiry_age(Some(REPORT_EXPIRY_AGE)) - .build(); - let other_task = TaskBuilder::new( - task::QueryType::TimeInterval, - VdafInstance::Fake, - Role::Leader, - ) - .build(); + let task = TaskBuilder::new(task::QueryType::TimeInterval, VdafInstance::Fake) + .with_time_precision(time_precision) + .with_report_expiry_age(Some(REPORT_EXPIRY_AGE)) + .build() + .leader_view() + .unwrap(); + let other_task = TaskBuilder::new(task::QueryType::TimeInterval, VdafInstance::Fake) + .build() + .leader_view() + .unwrap(); let aggregate_share = AggregateShare(23); let aggregation_param = AggregationParam(12); @@ -4011,8 +3986,8 @@ async fn roundtrip_batch_aggregation_time_interval(ephemeral_datastore: Ephemera let other_task = other_task.clone(); Box::pin(async move { - tx.put_task(&task).await?; - tx.put_task(&other_task).await?; + tx.put_aggregator_task(&task).await?; + tx.put_aggregator_task(&other_task).await?; for when in [1000, 1100, 1200, 1300, 1400] { tx.put_batch(&Batch::<0, TimeInterval, dummy_vdaf::Vdaf>::new( @@ -4203,7 +4178,7 @@ async fn roundtrip_batch_aggregation_time_interval(ephemeral_datastore: Ephemera _, >( tx, - &task, + &Task::from(task.clone()), &vdaf, &Interval::new( Time::from_seconds_since_epoch(1100), @@ -4248,7 +4223,7 @@ async fn roundtrip_batch_aggregation_time_interval(ephemeral_datastore: Ephemera _, >( tx, - &task, + &Task::from(task), &vdaf, &Interval::new( Time::from_seconds_since_epoch(1100), @@ -4292,7 +4267,7 @@ async fn roundtrip_batch_aggregation_time_interval(ephemeral_datastore: Ephemera _, >( tx, - &task, + &Task::from(task), &vdaf, &Interval::new( Time::from_seconds_since_epoch(1100), @@ -4326,10 +4301,11 @@ async fn roundtrip_batch_aggregation_fixed_size(ephemeral_datastore: EphemeralDa batch_time_window_size: None, }, VdafInstance::Fake, - Role::Leader, ) .with_report_expiry_age(Some(REPORT_EXPIRY_AGE)) - .build(); + .build() + .leader_view() + .unwrap(); let batch_id = random(); let aggregate_share = AggregateShare(23); let aggregation_param = AggregationParam(12); @@ -4343,12 +4319,13 @@ async fn roundtrip_batch_aggregation_fixed_size(ephemeral_datastore: EphemeralDa batch_time_window_size: None, }, VdafInstance::Fake, - Role::Leader, ) - .build(); + .build() + .leader_view() + .unwrap(); - tx.put_task(&task).await?; - tx.put_task(&other_task).await?; + tx.put_aggregator_task(&task).await?; + tx.put_aggregator_task(&other_task).await?; tx.put_batch(&Batch::<0, FixedSize, dummy_vdaf::Vdaf>::new( *task.id(), @@ -4542,14 +4519,12 @@ async fn roundtrip_aggregate_share_job_time_interval(ephemeral_datastore: Epheme let aggregate_share_job = ds .run_tx(|tx| { Box::pin(async move { - let task = TaskBuilder::new( - task::QueryType::TimeInterval, - VdafInstance::Fake, - Role::Helper, - ) - .with_report_expiry_age(Some(REPORT_EXPIRY_AGE)) - .build(); - tx.put_task(&task).await?; + let task = TaskBuilder::new(task::QueryType::TimeInterval, VdafInstance::Fake) + .with_report_expiry_age(Some(REPORT_EXPIRY_AGE)) + .build() + .helper_view() + .unwrap(); + tx.put_aggregator_task(&task).await?; tx.put_batch(&Batch::<0, TimeInterval, dummy_vdaf::Vdaf>::new( *task.id(), @@ -4729,11 +4704,12 @@ async fn roundtrip_aggregate_share_job_fixed_size(ephemeral_datastore: Ephemeral batch_time_window_size: None, }, VdafInstance::Fake, - Role::Helper, ) .with_report_expiry_age(Some(REPORT_EXPIRY_AGE)) - .build(); - tx.put_task(&task).await?; + .build() + .helper_view() + .unwrap(); + tx.put_aggregator_task(&task).await?; let batch_id = random(); tx.put_batch(&Batch::<0, FixedSize, dummy_vdaf::Vdaf>::new( @@ -4877,11 +4853,12 @@ async fn roundtrip_outstanding_batch(ephemeral_datastore: EphemeralDatastore) { batch_time_window_size: None, }, VdafInstance::Fake, - Role::Leader, ) .with_report_expiry_age(Some(REPORT_EXPIRY_AGE)) - .build(); - tx.put_task(&task_1).await?; + .build() + .leader_view() + .unwrap(); + tx.put_aggregator_task(&task_1).await?; let batch_id_1 = random(); tx.put_batch(&Batch::<0, FixedSize, dummy_vdaf::Vdaf>::new( @@ -4903,11 +4880,12 @@ async fn roundtrip_outstanding_batch(ephemeral_datastore: EphemeralDatastore) { batch_time_window_size: Some(batch_time_window_size), }, VdafInstance::Fake, - Role::Leader, ) .with_report_expiry_age(Some(REPORT_EXPIRY_AGE)) - .build(); - tx.put_task(&task_2).await?; + .build() + .leader_view() + .unwrap(); + tx.put_aggregator_task(&task_2).await?; let batch_id_2 = random(); tx.put_batch(&Batch::<0, FixedSize, dummy_vdaf::Vdaf>::new( @@ -5209,18 +5187,19 @@ async fn roundtrip_batch(ephemeral_datastore: EphemeralDatastore) { ds.run_tx(|tx| { let want_batch = want_batch.clone(); Box::pin(async move { - tx.put_task( + tx.put_aggregator_task( &TaskBuilder::new( task::QueryType::FixedSize { max_batch_size: 10, batch_time_window_size: None, }, VdafInstance::Fake, - Role::Leader, ) .with_id(*want_batch.task_id()) .with_report_expiry_age(Some(REPORT_EXPIRY_AGE)) - .build(), + .build() + .leader_view() + .unwrap(), ) .await?; tx.put_batch(&want_batch).await?; @@ -5341,21 +5320,18 @@ async fn delete_expired_client_reports(ephemeral_datastore: EphemeralDatastore) let (task_id, new_report_id, other_task_id, other_task_report_id) = ds .run_tx(|tx| { Box::pin(async move { - let task = TaskBuilder::new( - task::QueryType::TimeInterval, - VdafInstance::Fake, - Role::Leader, - ) - .with_report_expiry_age(Some(report_expiry_age)) - .build(); - let other_task = TaskBuilder::new( - task::QueryType::TimeInterval, - VdafInstance::Fake, - Role::Leader, - ) - .build(); - tx.put_task(&task).await?; - tx.put_task(&other_task).await?; + let task = TaskBuilder::new(task::QueryType::TimeInterval, VdafInstance::Fake) + .with_report_expiry_age(Some(report_expiry_age)) + .build() + .leader_view() + .unwrap(); + let other_task = + TaskBuilder::new(task::QueryType::TimeInterval, VdafInstance::Fake) + .build() + .leader_view() + .unwrap(); + tx.put_aggregator_task(&task).await?; + tx.put_aggregator_task(&other_task).await?; let old_report = LeaderStoredReport::new_dummy( *task.id(), @@ -5510,44 +5486,44 @@ async fn delete_expired_aggregation_artifacts(ephemeral_datastore: EphemeralData ) = ds .run_tx(|tx| { Box::pin(async move { - let leader_time_interval_task = TaskBuilder::new( - task::QueryType::TimeInterval, - VdafInstance::Fake, - Role::Leader, - ) - .with_report_expiry_age(Some(REPORT_EXPIRY_AGE)) - .build(); - let helper_time_interval_task = TaskBuilder::new( - task::QueryType::TimeInterval, - VdafInstance::Fake, - Role::Helper, - ) - .with_report_expiry_age(Some(REPORT_EXPIRY_AGE)) - .build(); + let leader_time_interval_task = + TaskBuilder::new(task::QueryType::TimeInterval, VdafInstance::Fake) + .with_report_expiry_age(Some(REPORT_EXPIRY_AGE)) + .build() + .leader_view() + .unwrap(); + let helper_time_interval_task = + TaskBuilder::new(task::QueryType::TimeInterval, VdafInstance::Fake) + .with_report_expiry_age(Some(REPORT_EXPIRY_AGE)) + .build() + .helper_view() + .unwrap(); let leader_fixed_size_task = TaskBuilder::new( task::QueryType::FixedSize { max_batch_size: 10, batch_time_window_size: None, }, VdafInstance::Fake, - Role::Leader, ) .with_report_expiry_age(Some(REPORT_EXPIRY_AGE)) - .build(); + .build() + .helper_view() + .unwrap(); let helper_fixed_size_task = TaskBuilder::new( task::QueryType::FixedSize { max_batch_size: 10, batch_time_window_size: None, }, VdafInstance::Fake, - Role::Helper, ) .with_report_expiry_age(Some(REPORT_EXPIRY_AGE)) - .build(); - tx.put_task(&leader_time_interval_task).await?; - tx.put_task(&helper_time_interval_task).await?; - tx.put_task(&leader_fixed_size_task).await?; - tx.put_task(&helper_fixed_size_task).await?; + .build() + .helper_view() + .unwrap(); + tx.put_aggregator_task(&leader_time_interval_task).await?; + tx.put_aggregator_task(&helper_time_interval_task).await?; + tx.put_aggregator_task(&leader_fixed_size_task).await?; + tx.put_aggregator_task(&helper_fixed_size_task).await?; let mut aggregation_job_ids = HashSet::new(); let mut all_report_ids = HashSet::new(); @@ -5900,7 +5876,7 @@ async fn delete_expired_collection_artifacts(ephemeral_datastore: EphemeralDatas // Setup. async fn write_collect_artifacts( tx: &Transaction<'_, MockClock>, - task: &Task, + task: &AggregatorTask, client_timestamps: &[Time], ) -> ( Option, // collection job ID @@ -6028,63 +6004,64 @@ async fn delete_expired_collection_artifacts(ephemeral_datastore: EphemeralDatas ) = ds .run_tx(|tx| { Box::pin(async move { - let leader_time_interval_task = TaskBuilder::new( - task::QueryType::TimeInterval, - VdafInstance::Fake, - Role::Leader, - ) - .with_report_expiry_age(Some(REPORT_EXPIRY_AGE)) - .build(); - let helper_time_interval_task = TaskBuilder::new( - task::QueryType::TimeInterval, - VdafInstance::Fake, - Role::Helper, - ) - .with_report_expiry_age(Some(REPORT_EXPIRY_AGE)) - .build(); + let leader_time_interval_task = + TaskBuilder::new(task::QueryType::TimeInterval, VdafInstance::Fake) + .with_report_expiry_age(Some(REPORT_EXPIRY_AGE)) + .build() + .leader_view() + .unwrap(); + let helper_time_interval_task = + TaskBuilder::new(task::QueryType::TimeInterval, VdafInstance::Fake) + .with_report_expiry_age(Some(REPORT_EXPIRY_AGE)) + .build() + .helper_view() + .unwrap(); let leader_fixed_size_task = TaskBuilder::new( task::QueryType::FixedSize { max_batch_size: 10, batch_time_window_size: None, }, VdafInstance::Fake, - Role::Leader, ) .with_report_expiry_age(Some(REPORT_EXPIRY_AGE)) - .build(); + .build() + .leader_view() + .unwrap(); let helper_fixed_size_task = TaskBuilder::new( task::QueryType::FixedSize { max_batch_size: 10, batch_time_window_size: None, }, VdafInstance::Fake, - Role::Helper, ) .with_report_expiry_age(Some(REPORT_EXPIRY_AGE)) - .build(); + .build() + .helper_view() + .unwrap(); let leader_fixed_size_time_bucketed_task = TaskBuilder::new( task::QueryType::FixedSize { max_batch_size: 10, batch_time_window_size: Some(Duration::from_hours(24)?), }, VdafInstance::Fake, - Role::Leader, - ) - .with_report_expiry_age(Some(REPORT_EXPIRY_AGE)) - .build(); - let other_task = TaskBuilder::new( - task::QueryType::TimeInterval, - VdafInstance::Fake, - Role::Leader, ) .with_report_expiry_age(Some(REPORT_EXPIRY_AGE)) - .build(); - tx.put_task(&leader_time_interval_task).await?; - tx.put_task(&helper_time_interval_task).await?; - tx.put_task(&leader_fixed_size_task).await?; - tx.put_task(&helper_fixed_size_task).await?; - tx.put_task(&leader_fixed_size_time_bucketed_task).await?; - tx.put_task(&other_task).await?; + .build() + .leader_view() + .unwrap(); + let other_task = + TaskBuilder::new(task::QueryType::TimeInterval, VdafInstance::Fake) + .with_report_expiry_age(Some(REPORT_EXPIRY_AGE)) + .build() + .leader_view() + .unwrap(); + tx.put_aggregator_task(&leader_time_interval_task).await?; + tx.put_aggregator_task(&helper_time_interval_task).await?; + tx.put_aggregator_task(&leader_fixed_size_task).await?; + tx.put_aggregator_task(&helper_fixed_size_task).await?; + tx.put_aggregator_task(&leader_fixed_size_time_bucketed_task) + .await?; + tx.put_aggregator_task(&other_task).await?; let mut collection_job_ids = HashSet::new(); let mut aggregate_share_job_ids = HashSet::new(); diff --git a/aggregator_core/src/task.rs b/aggregator_core/src/task.rs index ec781e517..6252e2712 100644 --- a/aggregator_core/src/task.rs +++ b/aggregator_core/src/task.rs @@ -479,6 +479,155 @@ impl Task { self.tasks_path() ))?) } + + /// Render the leader aggregator's view of this task. + pub fn leader_view(&self) -> Result { + AggregatorTask::new( + self.task_id, + self.helper_aggregator_endpoint.clone(), + self.query_type, + self.vdaf.clone(), + self.vdaf_verify_key.clone(), + self.max_batch_query_count, + self.task_expiration, + self.report_expiry_age, + self.min_batch_size, + self.time_precision, + self.tolerable_clock_skew, + self.hpke_keys.values().cloned().collect::>(), + AggregatorTaskParameters::Leader { + aggregator_auth_token: self + .aggregator_auth_token + .clone() + .ok_or_else(|| Error::InvalidParameter("no aggregator auth token in task"))?, + collector_auth_token: self + .collector_auth_token + .clone() + .ok_or_else(|| Error::InvalidParameter("no collector auth token in task"))?, + collector_hpke_config: self + .collector_hpke_config + .clone() + .ok_or_else(|| Error::InvalidParameter("no collector HPKE config in task"))?, + }, + ) + } + + /// Render the helper aggregator's view of this task. + pub fn helper_view(&self) -> Result { + AggregatorTask::new( + self.task_id, + self.helper_aggregator_endpoint.clone(), + self.query_type, + self.vdaf.clone(), + self.vdaf_verify_key.clone(), + self.max_batch_query_count, + self.task_expiration, + self.report_expiry_age, + self.min_batch_size, + self.time_precision, + self.tolerable_clock_skew, + self.hpke_keys.values().cloned().collect::>(), + AggregatorTaskParameters::Helper { + aggregator_auth_token: self + .aggregator_auth_token + .clone() + .ok_or_else(|| Error::InvalidParameter("no aggregator auth token in task"))?, + collector_hpke_config: self + .collector_hpke_config + .clone() + .ok_or_else(|| Error::InvalidParameter("no collector HPKE config in task"))?, + }, + ) + } + + /// Render a taskprov helper aggregator's view of this task. + pub fn taskprov_helper_view(&self) -> Result { + AggregatorTask::new( + self.task_id, + self.helper_aggregator_endpoint.clone(), + self.query_type, + self.vdaf.clone(), + self.vdaf_verify_key.clone(), + self.max_batch_query_count, + self.task_expiration, + self.report_expiry_age, + self.min_batch_size, + self.time_precision, + self.tolerable_clock_skew, + self.hpke_keys.values().cloned().collect::>(), + AggregatorTaskParameters::TaskProvHelper, + ) + } +} + +impl From for Task { + fn from(aggregator_task: AggregatorTask) -> Self { + // An `AggregatorTask` only contains the other aggregator's URL, which means we can't + // accurately set the own endpoint URL. However that value is never used, and in most cases + // will have been the fake value set in the aggregator API. + // unwrap safety: we know this URL is valid + let fake_aggregator_url = Url::parse("http://never-used.example.com").unwrap(); + let ( + role, + leader_aggregator_endpoint, + helper_aggregator_endpoint, + aggregator_auth_token, + collector_auth_token, + collector_hpke_config, + ) = match aggregator_task.aggregator_parameters() { + AggregatorTaskParameters::Leader { + aggregator_auth_token, + collector_auth_token, + collector_hpke_config, + } => ( + Role::Leader, + fake_aggregator_url, + aggregator_task.peer_aggregator_endpoint.clone(), + Some(aggregator_auth_token.clone()), + Some(collector_auth_token.clone()), + Some(collector_hpke_config.clone()), + ), + AggregatorTaskParameters::Helper { + aggregator_auth_token, + collector_hpke_config, + } => ( + Role::Helper, + aggregator_task.peer_aggregator_endpoint.clone(), + fake_aggregator_url, + Some(aggregator_auth_token.clone()), + None, + Some(collector_hpke_config.clone()), + ), + AggregatorTaskParameters::TaskProvHelper => ( + Role::Helper, + aggregator_task.peer_aggregator_endpoint.clone(), + fake_aggregator_url, + None, + None, + None, + ), + }; + + Self { + task_id: *aggregator_task.id(), + leader_aggregator_endpoint, + helper_aggregator_endpoint, + query_type: aggregator_task.query_type().clone(), + vdaf: aggregator_task.vdaf().clone(), + role, + vdaf_verify_key: aggregator_task.opaque_vdaf_verify_key().clone(), + max_batch_query_count: aggregator_task.max_batch_query_count(), + task_expiration: aggregator_task.task_expiration().cloned(), + report_expiry_age: aggregator_task.report_expiry_age().cloned(), + min_batch_size: aggregator_task.min_batch_size(), + time_precision: *aggregator_task.time_precision(), + tolerable_clock_skew: *aggregator_task.tolerable_clock_skew(), + collector_hpke_config, + aggregator_auth_token, + collector_auth_token, + hpke_keys: aggregator_task.hpke_keys().clone(), + } + } } /// Task parameters common to all views of a DAP task. diff --git a/db/00000000000001_initial_schema.up.sql b/db/00000000000001_initial_schema.up.sql index cffe33b79..dad14e3b4 100644 --- a/db/00000000000001_initial_schema.up.sql +++ b/db/00000000000001_initial_schema.up.sql @@ -78,8 +78,7 @@ CREATE TABLE tasks( id BIGINT GENERATED ALWAYS AS IDENTITY PRIMARY KEY, -- artificial ID, internal-only task_id BYTEA UNIQUE NOT NULL, -- 32-byte TaskID as defined by the DAP specification aggregator_role AGGREGATOR_ROLE NOT NULL, -- the role of this aggregator for this task - leader_aggregator_endpoint TEXT NOT NULL, -- Leader's API endpoint - helper_aggregator_endpoint TEXT NOT NULL, -- Helper's API endpoint + peer_aggregator_endpoint TEXT NOT NULL, -- peer aggregator's API endpoint query_type JSONB NOT NULL, -- the query type in use for this task, along with its parameters vdaf JSON NOT NULL, -- the VDAF instance in use for this task, along with its parameters max_batch_query_count BIGINT NOT NULL, -- the maximum number of times a given batch may be collected From 52595aba7fcef97a00135ab09f6e21b90173647e Mon Sep 17 00:00:00 2001 From: Tim Geoghegan Date: Thu, 28 Sep 2023 16:21:13 -0700 Subject: [PATCH 2/3] clippy --- aggregator_core/src/task.rs | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/aggregator_core/src/task.rs b/aggregator_core/src/task.rs index 6252e2712..784fe129b 100644 --- a/aggregator_core/src/task.rs +++ b/aggregator_core/src/task.rs @@ -499,15 +499,15 @@ impl Task { aggregator_auth_token: self .aggregator_auth_token .clone() - .ok_or_else(|| Error::InvalidParameter("no aggregator auth token in task"))?, + .ok_or(Error::InvalidParameter("no aggregator auth token in task"))?, collector_auth_token: self .collector_auth_token .clone() - .ok_or_else(|| Error::InvalidParameter("no collector auth token in task"))?, + .ok_or(Error::InvalidParameter("no collector auth token in task"))?, collector_hpke_config: self .collector_hpke_config .clone() - .ok_or_else(|| Error::InvalidParameter("no collector HPKE config in task"))?, + .ok_or(Error::InvalidParameter("no collector HPKE config in task"))?, }, ) } @@ -531,11 +531,11 @@ impl Task { aggregator_auth_token: self .aggregator_auth_token .clone() - .ok_or_else(|| Error::InvalidParameter("no aggregator auth token in task"))?, + .ok_or(Error::InvalidParameter("no aggregator auth token in task"))?, collector_hpke_config: self .collector_hpke_config .clone() - .ok_or_else(|| Error::InvalidParameter("no collector HPKE config in task"))?, + .ok_or(Error::InvalidParameter("no collector HPKE config in task"))?, }, ) } @@ -612,7 +612,7 @@ impl From for Task { task_id: *aggregator_task.id(), leader_aggregator_endpoint, helper_aggregator_endpoint, - query_type: aggregator_task.query_type().clone(), + query_type: *aggregator_task.query_type(), vdaf: aggregator_task.vdaf().clone(), role, vdaf_verify_key: aggregator_task.opaque_vdaf_verify_key().clone(), From 0bf9101f40149a2870e2891a42d412d2083cfce2 Mon Sep 17 00:00:00 2001 From: Tim Geoghegan Date: Thu, 28 Sep 2023 17:35:09 -0700 Subject: [PATCH 3/3] improve/remove test workarounds --- aggregator/src/aggregator/taskprov_tests.rs | 6 +-- aggregator/src/bin/janus_cli.rs | 45 ++++++++++++++++++--- aggregator_core/src/task.rs | 17 +++++++- 3 files changed, 56 insertions(+), 12 deletions(-) diff --git a/aggregator/src/aggregator/taskprov_tests.rs b/aggregator/src/aggregator/taskprov_tests.rs index 39dcdb778..62738ac3d 100644 --- a/aggregator/src/aggregator/taskprov_tests.rs +++ b/aggregator/src/aggregator/taskprov_tests.rs @@ -317,7 +317,7 @@ async fn taskprov_aggregate_init() { tx.get_aggregation_jobs_for_task::<16, FixedSize, TestVdaf>(&task_id) .await .unwrap(), - tx.get_task(&task_id).await.unwrap(), + tx.get_aggregator_task(&task_id).await.unwrap(), )) }) }) @@ -333,9 +333,7 @@ async fn taskprov_aggregate_init() { .state() .eq(&AggregationJobState::InProgress) ); - // TODO(#1524): This assertion temporarily just checks the task ID because of the lossy - // conversion between task::Task and task::AggregatorTask. - assert_eq!(test.task.id(), got_task.unwrap().id()); + assert_eq!(test.task.taskprov_helper_view().unwrap(), got_task.unwrap()); } #[tokio::test] diff --git a/aggregator/src/bin/janus_cli.rs b/aggregator/src/bin/janus_cli.rs index c8fa7fd44..3bcd86ee7 100644 --- a/aggregator/src/bin/janus_cli.rs +++ b/aggregator/src/bin/janus_cli.rs @@ -607,8 +607,26 @@ mod tests { .await .unwrap(), ); - assert_eq!(want_tasks, got_tasks); - assert_eq!(want_tasks, written_tasks); + assert_eq!( + want_tasks + .iter() + .map(|(k, v)| { (*k, v.view_for_role().unwrap()) }) + .collect::>(), + got_tasks + .iter() + .map(|(k, v)| { (*k, v.view_for_role().unwrap()) }) + .collect() + ); + assert_eq!( + want_tasks + .iter() + .map(|(k, v)| { (*k, v.view_for_role().unwrap()) }) + .collect::>(), + written_tasks + .iter() + .map(|(k, v)| { (*k, v.view_for_role().unwrap()) }) + .collect() + ); } #[tokio::test] @@ -703,11 +721,20 @@ mod tests { .unwrap(), ); let want_tasks = HashMap::from([ - (*replacement_task.id(), replacement_task), - (*tasks[1].id(), tasks[1].clone()), + ( + *replacement_task.id(), + replacement_task.view_for_role().unwrap(), + ), + (*tasks[1].id(), tasks[1].view_for_role().unwrap()), ]); - assert_eq!(want_tasks, got_tasks); + assert_eq!( + want_tasks, + got_tasks + .iter() + .map(|(k, v)| { (*k, v.view_for_role().unwrap()) }) + .collect() + ); } #[tokio::test] @@ -810,8 +837,14 @@ mod tests { } assert_eq!( - task_hashmap_from_slice(written_tasks), + task_hashmap_from_slice(written_tasks) + .iter() + .map(|(k, v)| { (*k, v.view_for_role().unwrap()) }) + .collect::>(), task_hashmap_from_slice(got_tasks) + .iter() + .map(|(k, v)| { (*k, v.view_for_role().unwrap()) }) + .collect() ); } diff --git a/aggregator_core/src/task.rs b/aggregator_core/src/task.rs index 784fe129b..f77a6b4f6 100644 --- a/aggregator_core/src/task.rs +++ b/aggregator_core/src/task.rs @@ -516,7 +516,7 @@ impl Task { pub fn helper_view(&self) -> Result { AggregatorTask::new( self.task_id, - self.helper_aggregator_endpoint.clone(), + self.leader_aggregator_endpoint.clone(), self.query_type, self.vdaf.clone(), self.vdaf_verify_key.clone(), @@ -544,7 +544,7 @@ impl Task { pub fn taskprov_helper_view(&self) -> Result { AggregatorTask::new( self.task_id, - self.helper_aggregator_endpoint.clone(), + self.leader_aggregator_endpoint.clone(), self.query_type, self.vdaf.clone(), self.vdaf_verify_key.clone(), @@ -558,6 +558,19 @@ impl Task { AggregatorTaskParameters::TaskProvHelper, ) } + + /// Render the view of the specified aggregator of this task. + /// + /// # Errors + /// + /// Returns an error if `self.role` is not an aggregator role. + pub fn view_for_role(&self) -> Result { + match self.role { + Role::Leader => self.leader_view(), + Role::Helper => self.helper_view().or_else(|_| self.taskprov_helper_view()), + _ => Err(Error::InvalidParameter("role is not an aggregator")), + } + } } impl From for Task {