From 9719eb223be7e4b5828cfd33236706026e0e68ad Mon Sep 17 00:00:00 2001 From: Tim Geoghegan <timg@divviup.org> Date: Tue, 23 Jan 2024 17:32:20 -0800 Subject: [PATCH] refactor test to use setup_cancel_aggregation_job_test --- .../src/aggregator/aggregation_job_driver.rs | 161 +++++------------- 1 file changed, 40 insertions(+), 121 deletions(-) diff --git a/aggregator/src/aggregator/aggregation_job_driver.rs b/aggregator/src/aggregator/aggregation_job_driver.rs index e7661adc1..bf5588ae0 100644 --- a/aggregator/src/aggregator/aggregation_job_driver.rs +++ b/aggregator/src/aggregator/aggregation_job_driver.rs @@ -418,9 +418,7 @@ impl AggregationJobDriver { Method::PUT, task.aggregation_job_uri(aggregation_job.id())? .ok_or_else(|| { - Error::InvalidConfiguration( - "task is not leader and has no aggregate share URI", - ) + Error::InvalidConfiguration("task is leader and has no aggregate share URI") })?, AGGREGATION_JOB_ROUTE, Some(RequestBody { @@ -1052,11 +1050,11 @@ mod tests { use janus_messages::{ problem_type::DapProblemType, query_type::{FixedSize, TimeInterval}, - AggregationJobContinueReq, AggregationJobId, AggregationJobInitializeReq, - AggregationJobResp, AggregationJobStep, Duration, Extension, ExtensionType, FixedSizeQuery, - HpkeConfig, InputShareAad, Interval, PartialBatchSelector, PlaintextInputShare, - PrepareContinue, PrepareError, PrepareInit, PrepareResp, PrepareStepResult, Query, - ReportIdChecksum, ReportMetadata, ReportShare, Role, TaskId, Time, + AggregationJobContinueReq, AggregationJobInitializeReq, AggregationJobResp, + AggregationJobStep, Duration, Extension, ExtensionType, FixedSizeQuery, HpkeConfig, + InputShareAad, Interval, PartialBatchSelector, PlaintextInputShare, PrepareContinue, + PrepareError, PrepareInit, PrepareResp, PrepareStepResult, Query, ReportIdChecksum, + ReportMetadata, ReportShare, Role, TaskId, Time, }; use mockito::ServerGuard; use prio::{ @@ -3740,7 +3738,10 @@ mod tests { struct CancelAggregationJobTestCase { task: AggregatorTask, - aggregation_job_id: AggregationJobId, + vdaf: Arc<Prio3Count>, + aggregation_job: AggregationJob<VERIFY_KEY_LENGTH, TimeInterval, Prio3Count>, + batch_identifier: Interval, + report_aggregation: ReportAggregation<VERIFY_KEY_LENGTH, Prio3Count>, _ephemeral_datastore: EphemeralDatastore, datastore: Arc<Datastore<MockClock>>, lease: Lease<AcquiredAggregationJob>, @@ -3846,7 +3847,10 @@ mod tests { CancelAggregationJobTestCase { task, - aggregation_job_id, + vdaf, + batch_identifier, + aggregation_job, + report_aggregation, _ephemeral_datastore: ephemeral_datastore, datastore, lease, @@ -3856,109 +3860,18 @@ mod tests { #[tokio::test] async fn cancel_aggregation_job() { - // Setup: insert a client report and add it to a new aggregation job. - install_test_trace_subscriber(); - let clock = MockClock::default(); - let ephemeral_datastore = ephemeral_datastore().await; - let ds = Arc::new(ephemeral_datastore.datastore(clock.clone()).await); - let vdaf = Arc::new(Prio3::new_count(2).unwrap()); - let mut mock_helper = mockito::Server::new_async().await; - - let task = TaskBuilder::new(QueryType::TimeInterval, VdafInstance::Prio3Count) - .with_helper_aggregator_endpoint(mock_helper.url().parse().unwrap()) - .build() - .leader_view() - .unwrap(); - let time = clock - .now() - .to_batch_interval_start(task.time_precision()) - .unwrap(); - let batch_identifier = TimeInterval::to_batch_identifier(&task, &(), &time).unwrap(); - let report_metadata = ReportMetadata::new(random(), time); - let verify_key: VerifyKey<VERIFY_KEY_LENGTH> = task.vdaf_verify_key().unwrap(); - - let transcript = run_vdaf( - vdaf.as_ref(), - verify_key.as_bytes(), - &(), - report_metadata.id(), - &false, - ); - - let helper_hpke_keypair = generate_test_hpke_config_and_private_key(); - let report = generate_report::<VERIFY_KEY_LENGTH, Prio3Count>( - *task.id(), - report_metadata, - helper_hpke_keypair.config(), - transcript.public_share, - Vec::new(), - &transcript.leader_input_share, - &transcript.helper_input_share, - ); - let aggregation_job_id = random(); - - let aggregation_job = AggregationJob::<VERIFY_KEY_LENGTH, TimeInterval, Prio3Count>::new( - *task.id(), - aggregation_job_id, - (), - (), - Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)).unwrap(), - AggregationJobState::InProgress, - AggregationJobStep::from(0), - ); - let report_aggregation = ReportAggregation::<VERIFY_KEY_LENGTH, Prio3Count>::new( - *task.id(), - aggregation_job_id, - *report.metadata().id(), - *report.metadata().time(), - 0, - None, - ReportAggregationState::Start, - ); - - let lease = ds - .run_unnamed_tx(|tx| { - let (vdaf, task, report, aggregation_job, report_aggregation) = ( - vdaf.clone(), - task.clone(), - report.clone(), - aggregation_job.clone(), - report_aggregation.clone(), - ); - Box::pin(async move { - tx.put_aggregator_task(&task).await?; - tx.put_client_report(vdaf.borrow(), &report).await?; - tx.put_aggregation_job(&aggregation_job).await?; - tx.put_report_aggregation(&report_aggregation).await?; - - tx.put_batch(&Batch::<VERIFY_KEY_LENGTH, TimeInterval, Prio3Count>::new( - *task.id(), - batch_identifier, - (), - BatchState::Open, - 1, - Interval::from_time(report.metadata().time()).unwrap(), - )) - .await?; - - Ok(tx - .acquire_incomplete_aggregation_jobs(&StdDuration::from_secs(60), 1) - .await? - .remove(0)) - }) - }) - .await - .unwrap(); - assert_eq!(lease.leased().task_id(), task.id()); - assert_eq!(lease.leased().aggregation_job_id(), &aggregation_job_id); + let mut test_case = setup_cancel_aggregation_job_test().await; // Run: create an aggregation job driver & cancel the aggregation job. Mock the helper to // verify that we instruct it to delete the aggregation job. // https://datatracker.ietf.org/doc/html/draft-ietf-ppm-dap-09#section-4.5.2.2-20 - let mocked_aggregation_job_delete = mock_helper + let mocked_aggregation_job_delete = test_case + .mock_helper .mock( "DELETE", - task.aggregation_job_uri(&aggregation_job_id) + test_case + .task + .aggregation_job_uri(test_case.aggregation_job.id()) .unwrap() .unwrap() .path(), @@ -3973,7 +3886,7 @@ mod tests { 32, ); aggregation_job_driver - .abandon_aggregation_job(Arc::clone(&ds), Arc::new(lease)) + .abandon_aggregation_job(Arc::clone(&test_case.datastore), Arc::new(test_case.lease)) .await .unwrap(); @@ -3982,26 +3895,32 @@ mod tests { // Verify: check that the datastore state is updated as expected (the aggregation job is // abandoned, the report aggregation is untouched) and sanity-check that the job can no // longer be acquired. - let want_aggregation_job = aggregation_job.with_state(AggregationJobState::Abandoned); - let want_report_aggregation = report_aggregation; + let want_aggregation_job = test_case + .aggregation_job + .with_state(AggregationJobState::Abandoned); let want_batch = Batch::<VERIFY_KEY_LENGTH, TimeInterval, Prio3Count>::new( - *task.id(), - batch_identifier, + *test_case.task.id(), + test_case.batch_identifier, (), BatchState::Open, 0, - Interval::from_time(report.metadata().time()).unwrap(), + Interval::from_time(test_case.report_aggregation.report_metadata().time()).unwrap(), ); - let (got_aggregation_job, got_report_aggregation, got_batch, got_leases) = ds + let (got_aggregation_job, got_report_aggregation, got_batch, got_leases) = test_case + .datastore .run_unnamed_tx(|tx| { - let (vdaf, task, report_id) = - (Arc::clone(&vdaf), task.clone(), *report.metadata().id()); + let (vdaf, task, report_id, aggregation_job) = ( + Arc::clone(&test_case.vdaf), + test_case.task.clone(), + *test_case.report_aggregation.report_metadata().id(), + want_aggregation_job.clone(), + ); Box::pin(async move { let aggregation_job = tx .get_aggregation_job::<VERIFY_KEY_LENGTH, TimeInterval, Prio3Count>( task.id(), - &aggregation_job_id, + aggregation_job.id(), ) .await? .unwrap(); @@ -4010,14 +3929,14 @@ mod tests { vdaf.as_ref(), &Role::Leader, task.id(), - &aggregation_job_id, + aggregation_job.id(), aggregation_job.aggregation_parameter(), &report_id, ) .await? .unwrap(); let batch = tx - .get_batch(task.id(), &batch_identifier, &()) + .get_batch(task.id(), &test_case.batch_identifier, &()) .await? .unwrap(); let leases = tx @@ -4029,7 +3948,7 @@ mod tests { .await .unwrap(); assert_eq!(want_aggregation_job, got_aggregation_job); - assert_eq!(want_report_aggregation, got_report_aggregation); + assert_eq!(test_case.report_aggregation, got_report_aggregation); assert_eq!(want_batch, got_batch); assert!(got_leases.is_empty()); } @@ -4049,7 +3968,7 @@ mod tests { "DELETE", test_case .task - .aggregation_job_uri(&test_case.aggregation_job_id) + .aggregation_job_uri(test_case.aggregation_job.id()) .unwrap() .unwrap() .path(),