From 9719eb223be7e4b5828cfd33236706026e0e68ad Mon Sep 17 00:00:00 2001
From: Tim Geoghegan <timg@divviup.org>
Date: Tue, 23 Jan 2024 17:32:20 -0800
Subject: [PATCH] refactor test to use setup_cancel_aggregation_job_test

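Replace the hand-rolled setup in cancel_aggregation_job with the shared
setup_cancel_aggregation_job_test fixture. CancelAggregationJobTestCase now
carries the VDAF, aggregation job, batch identifier, and report aggregation
(rather than just the aggregation job ID), so the cancellation tests can
reference them through the test case when mocking the helper's DELETE
endpoint and checking datastore state.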
---
 .../src/aggregator/aggregation_job_driver.rs  | 161 +++++-------------
 1 file changed, 40 insertions(+), 121 deletions(-)

diff --git a/aggregator/src/aggregator/aggregation_job_driver.rs b/aggregator/src/aggregator/aggregation_job_driver.rs
index e7661adc1..bf5588ae0 100644
--- a/aggregator/src/aggregator/aggregation_job_driver.rs
+++ b/aggregator/src/aggregator/aggregation_job_driver.rs
@@ -418,9 +418,7 @@ impl AggregationJobDriver {
                 Method::PUT,
                 task.aggregation_job_uri(aggregation_job.id())?
                     .ok_or_else(|| {
-                        Error::InvalidConfiguration(
-                            "task is not leader and has no aggregate share URI",
-                        )
+                        Error::InvalidConfiguration("task is leader and has no aggregate share URI")
                     })?,
                 AGGREGATION_JOB_ROUTE,
                 Some(RequestBody {
@@ -1052,11 +1050,11 @@ mod tests {
     use janus_messages::{
         problem_type::DapProblemType,
         query_type::{FixedSize, TimeInterval},
-        AggregationJobContinueReq, AggregationJobId, AggregationJobInitializeReq,
-        AggregationJobResp, AggregationJobStep, Duration, Extension, ExtensionType, FixedSizeQuery,
-        HpkeConfig, InputShareAad, Interval, PartialBatchSelector, PlaintextInputShare,
-        PrepareContinue, PrepareError, PrepareInit, PrepareResp, PrepareStepResult, Query,
-        ReportIdChecksum, ReportMetadata, ReportShare, Role, TaskId, Time,
+        AggregationJobContinueReq, AggregationJobInitializeReq, AggregationJobResp,
+        AggregationJobStep, Duration, Extension, ExtensionType, FixedSizeQuery, HpkeConfig,
+        InputShareAad, Interval, PartialBatchSelector, PlaintextInputShare, PrepareContinue,
+        PrepareError, PrepareInit, PrepareResp, PrepareStepResult, Query, ReportIdChecksum,
+        ReportMetadata, ReportShare, Role, TaskId, Time,
     };
     use mockito::ServerGuard;
     use prio::{
@@ -3740,7 +3738,10 @@ mod tests {
 
     struct CancelAggregationJobTestCase {
         task: AggregatorTask,
-        aggregation_job_id: AggregationJobId,
+        vdaf: Arc<Prio3Count>,
+        aggregation_job: AggregationJob<VERIFY_KEY_LENGTH, TimeInterval, Prio3Count>,
+        batch_identifier: Interval,
+        report_aggregation: ReportAggregation<VERIFY_KEY_LENGTH, Prio3Count>,
         _ephemeral_datastore: EphemeralDatastore,
         datastore: Arc<Datastore<MockClock>>,
         lease: Lease<AcquiredAggregationJob>,
@@ -3846,7 +3847,10 @@ mod tests {
 
         CancelAggregationJobTestCase {
             task,
-            aggregation_job_id,
+            vdaf,
+            batch_identifier,
+            aggregation_job,
+            report_aggregation,
             _ephemeral_datastore: ephemeral_datastore,
             datastore,
             lease,
@@ -3856,109 +3860,18 @@ mod tests {
 
     #[tokio::test]
     async fn cancel_aggregation_job() {
-        // Setup: insert a client report and add it to a new aggregation job.
-        install_test_trace_subscriber();
-        let clock = MockClock::default();
-        let ephemeral_datastore = ephemeral_datastore().await;
-        let ds = Arc::new(ephemeral_datastore.datastore(clock.clone()).await);
-        let vdaf = Arc::new(Prio3::new_count(2).unwrap());
-        let mut mock_helper = mockito::Server::new_async().await;
-
-        let task = TaskBuilder::new(QueryType::TimeInterval, VdafInstance::Prio3Count)
-            .with_helper_aggregator_endpoint(mock_helper.url().parse().unwrap())
-            .build()
-            .leader_view()
-            .unwrap();
-        let time = clock
-            .now()
-            .to_batch_interval_start(task.time_precision())
-            .unwrap();
-        let batch_identifier = TimeInterval::to_batch_identifier(&task, &(), &time).unwrap();
-        let report_metadata = ReportMetadata::new(random(), time);
-        let verify_key: VerifyKey<VERIFY_KEY_LENGTH> = task.vdaf_verify_key().unwrap();
-
-        let transcript = run_vdaf(
-            vdaf.as_ref(),
-            verify_key.as_bytes(),
-            &(),
-            report_metadata.id(),
-            &false,
-        );
-
-        let helper_hpke_keypair = generate_test_hpke_config_and_private_key();
-        let report = generate_report::<VERIFY_KEY_LENGTH, Prio3Count>(
-            *task.id(),
-            report_metadata,
-            helper_hpke_keypair.config(),
-            transcript.public_share,
-            Vec::new(),
-            &transcript.leader_input_share,
-            &transcript.helper_input_share,
-        );
-        let aggregation_job_id = random();
-
-        let aggregation_job = AggregationJob::<VERIFY_KEY_LENGTH, TimeInterval, Prio3Count>::new(
-            *task.id(),
-            aggregation_job_id,
-            (),
-            (),
-            Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)).unwrap(),
-            AggregationJobState::InProgress,
-            AggregationJobStep::from(0),
-        );
-        let report_aggregation = ReportAggregation::<VERIFY_KEY_LENGTH, Prio3Count>::new(
-            *task.id(),
-            aggregation_job_id,
-            *report.metadata().id(),
-            *report.metadata().time(),
-            0,
-            None,
-            ReportAggregationState::Start,
-        );
-
-        let lease = ds
-            .run_unnamed_tx(|tx| {
-                let (vdaf, task, report, aggregation_job, report_aggregation) = (
-                    vdaf.clone(),
-                    task.clone(),
-                    report.clone(),
-                    aggregation_job.clone(),
-                    report_aggregation.clone(),
-                );
-                Box::pin(async move {
-                    tx.put_aggregator_task(&task).await?;
-                    tx.put_client_report(vdaf.borrow(), &report).await?;
-                    tx.put_aggregation_job(&aggregation_job).await?;
-                    tx.put_report_aggregation(&report_aggregation).await?;
-
-                    tx.put_batch(&Batch::<VERIFY_KEY_LENGTH, TimeInterval, Prio3Count>::new(
-                        *task.id(),
-                        batch_identifier,
-                        (),
-                        BatchState::Open,
-                        1,
-                        Interval::from_time(report.metadata().time()).unwrap(),
-                    ))
-                    .await?;
-
-                    Ok(tx
-                        .acquire_incomplete_aggregation_jobs(&StdDuration::from_secs(60), 1)
-                        .await?
-                        .remove(0))
-                })
-            })
-            .await
-            .unwrap();
-        assert_eq!(lease.leased().task_id(), task.id());
-        assert_eq!(lease.leased().aggregation_job_id(), &aggregation_job_id);
+        let mut test_case = setup_cancel_aggregation_job_test().await;
 
         // Run: create an aggregation job driver & cancel the aggregation job. Mock the helper to
         // verify that we instruct it to delete the aggregation job.
         // https://datatracker.ietf.org/doc/html/draft-ietf-ppm-dap-09#section-4.5.2.2-20
-        let mocked_aggregation_job_delete = mock_helper
+        let mocked_aggregation_job_delete = test_case
+            .mock_helper
             .mock(
                 "DELETE",
-                task.aggregation_job_uri(&aggregation_job_id)
+                test_case
+                    .task
+                    .aggregation_job_uri(test_case.aggregation_job.id())
                     .unwrap()
                     .unwrap()
                     .path(),
@@ -3973,7 +3886,7 @@ mod tests {
             32,
         );
         aggregation_job_driver
-            .abandon_aggregation_job(Arc::clone(&ds), Arc::new(lease))
+            .abandon_aggregation_job(Arc::clone(&test_case.datastore), Arc::new(test_case.lease))
             .await
             .unwrap();
 
@@ -3982,26 +3895,32 @@ mod tests {
         // Verify: check that the datastore state is updated as expected (the aggregation job is
         // abandoned, the report aggregation is untouched) and sanity-check that the job can no
         // longer be acquired.
-        let want_aggregation_job = aggregation_job.with_state(AggregationJobState::Abandoned);
-        let want_report_aggregation = report_aggregation;
+        let want_aggregation_job = test_case
+            .aggregation_job
+            .with_state(AggregationJobState::Abandoned);
         let want_batch = Batch::<VERIFY_KEY_LENGTH, TimeInterval, Prio3Count>::new(
-            *task.id(),
-            batch_identifier,
+            *test_case.task.id(),
+            test_case.batch_identifier,
             (),
             BatchState::Open,
             0,
-            Interval::from_time(report.metadata().time()).unwrap(),
+            Interval::from_time(test_case.report_aggregation.report_metadata().time()).unwrap(),
         );
 
-        let (got_aggregation_job, got_report_aggregation, got_batch, got_leases) = ds
+        let (got_aggregation_job, got_report_aggregation, got_batch, got_leases) = test_case
+            .datastore
             .run_unnamed_tx(|tx| {
-                let (vdaf, task, report_id) =
-                    (Arc::clone(&vdaf), task.clone(), *report.metadata().id());
+                let (vdaf, task, report_id, aggregation_job) = (
+                    Arc::clone(&test_case.vdaf),
+                    test_case.task.clone(),
+                    *test_case.report_aggregation.report_metadata().id(),
+                    want_aggregation_job.clone(),
+                );
                 Box::pin(async move {
                     let aggregation_job = tx
                         .get_aggregation_job::<VERIFY_KEY_LENGTH, TimeInterval, Prio3Count>(
                             task.id(),
-                            &aggregation_job_id,
+                            aggregation_job.id(),
                         )
                         .await?
                         .unwrap();
@@ -4010,14 +3929,14 @@ mod tests {
                             vdaf.as_ref(),
                             &Role::Leader,
                             task.id(),
-                            &aggregation_job_id,
+                            aggregation_job.id(),
                             aggregation_job.aggregation_parameter(),
                             &report_id,
                         )
                         .await?
                         .unwrap();
                     let batch = tx
-                        .get_batch(task.id(), &batch_identifier, &())
+                        .get_batch(task.id(), &test_case.batch_identifier, &())
                         .await?
                         .unwrap();
                     let leases = tx
@@ -4029,7 +3948,7 @@ mod tests {
             .await
             .unwrap();
         assert_eq!(want_aggregation_job, got_aggregation_job);
-        assert_eq!(want_report_aggregation, got_report_aggregation);
+        assert_eq!(test_case.report_aggregation, got_report_aggregation);
         assert_eq!(want_batch, got_batch);
         assert!(got_leases.is_empty());
     }
@@ -4049,7 +3968,7 @@ mod tests {
                 "DELETE",
                 test_case
                     .task
-                    .aggregation_job_uri(&test_case.aggregation_job_id)
+                    .aggregation_job_uri(test_case.aggregation_job.id())
                     .unwrap()
                     .unwrap()
                     .path(),