diff --git a/aggregator/src/aggregator/aggregation_job_driver.rs b/aggregator/src/aggregator/aggregation_job_driver.rs
index b2de59476..ca9ed6ec4 100644
--- a/aggregator/src/aggregator/aggregation_job_driver.rs
+++ b/aggregator/src/aggregator/aggregation_job_driver.rs
@@ -1895,6 +1895,326 @@ mod tests {
         assert_eq!(want_batch, got_batch);
     }
 
+    #[tokio::test]
+    async fn step_time_interval_aggregation_job_init_partially_garbage_collected() {
+        // This is a regression test for https://github.com/divviup/janus/issues/2464.
+        const OLDEST_ALLOWED_REPORT_TIMESTAMP: Time = Time::from_seconds_since_epoch(1000);
+        const REPORT_EXPIRY_AGE: Duration = Duration::from_seconds(500);
+        const TIME_PRECISION: Duration = Duration::from_seconds(10);
+
+        // Setup: insert an "old" and "new" client report, and add them to a new aggregation job.
+        install_test_trace_subscriber();
+        let mut server = mockito::Server::new_async().await;
+        let clock = MockClock::new(OLDEST_ALLOWED_REPORT_TIMESTAMP);
+        let ephemeral_datastore = ephemeral_datastore().await;
+        let ds = Arc::new(ephemeral_datastore.datastore(clock.clone()).await);
+        let vdaf = Arc::new(Prio3::new_count(2).unwrap());
+
+        let task = TaskBuilder::new(QueryType::TimeInterval, VdafInstance::Prio3Count)
+            .with_helper_aggregator_endpoint(server.url().parse().unwrap())
+            .with_report_expiry_age(Some(REPORT_EXPIRY_AGE))
+            .with_time_precision(TIME_PRECISION)
+            .build();
+
+        let leader_task = task.leader_view().unwrap();
+
+        let gc_eligible_time = OLDEST_ALLOWED_REPORT_TIMESTAMP
+            .sub(&Duration::from_seconds(3 * TIME_PRECISION.as_seconds()))
+            .unwrap()
+            .to_batch_interval_start(&TIME_PRECISION)
+            .unwrap();
+        let gc_eligible_batch_identifier =
+            TimeInterval::to_batch_identifier(&leader_task, &(), &gc_eligible_time).unwrap();
+        let gc_eligible_report_metadata = ReportMetadata::new(random(), gc_eligible_time);
+
+        let gc_uneligible_time = OLDEST_ALLOWED_REPORT_TIMESTAMP
+            .add(&Duration::from_seconds(3 * TIME_PRECISION.as_seconds()))
+            .unwrap()
+            .to_batch_interval_start(&TIME_PRECISION)
+            .unwrap();
+        let gc_uneligible_batch_identifier =
+            TimeInterval::to_batch_identifier(&leader_task, &(), &gc_uneligible_time).unwrap();
+        let gc_uneligible_report_metadata = ReportMetadata::new(random(), gc_uneligible_time);
+
+        let verify_key: VerifyKey<VERIFY_KEY_LENGTH> = task.vdaf_verify_key().unwrap();
+
+        let gc_eligible_transcript = run_vdaf(
+            vdaf.as_ref(),
+            verify_key.as_bytes(),
+            &(),
+            gc_eligible_report_metadata.id(),
+            &0,
+        );
+        let gc_uneligible_transcript = run_vdaf(
+            vdaf.as_ref(),
+            verify_key.as_bytes(),
+            &(),
+            gc_uneligible_report_metadata.id(),
+            &0,
+        );
+
+        let agg_auth_token = task.aggregator_auth_token();
+        let helper_hpke_keypair = generate_test_hpke_config_and_private_key();
+        let gc_eligible_report = generate_report::<VERIFY_KEY_LENGTH, Prio3Count>(
+            *task.id(),
+            gc_eligible_report_metadata,
+            helper_hpke_keypair.config(),
+            gc_eligible_transcript.public_share.clone(),
+            Vec::new(),
+            &gc_eligible_transcript.leader_input_share,
+            &gc_eligible_transcript.helper_input_share,
+        );
+        let gc_uneligible_report = generate_report::<VERIFY_KEY_LENGTH, Prio3Count>(
+            *task.id(),
+            gc_uneligible_report_metadata,
+            helper_hpke_keypair.config(),
+            gc_uneligible_transcript.public_share.clone(),
+            Vec::new(),
+            &gc_uneligible_transcript.leader_input_share,
+            &gc_uneligible_transcript.helper_input_share,
+        );
+
+        let aggregation_job_id = random();
+
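+        // Write the task, both reports, the aggregation job (with one report aggregation
+        // per report), and one batch per report, then lease the job as the driver would.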
+        let lease = ds
+            .run_unnamed_tx(|tx| {
+                let vdaf = Arc::clone(&vdaf);
+                let leader_task = leader_task.clone();
+                let gc_eligible_report = gc_eligible_report.clone();
+                let gc_uneligible_report = gc_uneligible_report.clone();
+
+                Box::pin(async move {
+                    tx.put_aggregator_task(&leader_task).await.unwrap();
+                    tx.put_client_report(vdaf.borrow(), &gc_eligible_report)
+                        .await
+                        .unwrap();
+                    tx.put_client_report(vdaf.borrow(), &gc_uneligible_report)
+                        .await
+                        .unwrap();
+
+                    tx.put_aggregation_job(&AggregationJob::<
+                        VERIFY_KEY_LENGTH,
+                        TimeInterval,
+                        Prio3Count,
+                    >::new(
+                        *leader_task.id(),
+                        aggregation_job_id,
+                        (),
+                        (),
+                        Interval::new(
+                            gc_eligible_time,
+                            gc_uneligible_time.difference(&gc_eligible_time).unwrap(),
+                        )
+                        .unwrap(),
+                        AggregationJobState::InProgress,
+                        AggregationJobStep::from(0),
+                    ))
+                    .await
+                    .unwrap();
+                    tx.put_report_aggregation(
+                        &ReportAggregation::<VERIFY_KEY_LENGTH, Prio3Count>::new(
+                            *leader_task.id(),
+                            aggregation_job_id,
+                            *gc_eligible_report.metadata().id(),
+                            *gc_eligible_report.metadata().time(),
+                            0,
+                            None,
+                            ReportAggregationState::Start,
+                        ),
+                    )
+                    .await
+                    .unwrap();
+                    tx.put_report_aggregation(
+                        &ReportAggregation::<VERIFY_KEY_LENGTH, Prio3Count>::new(
+                            *leader_task.id(),
+                            aggregation_job_id,
+                            *gc_uneligible_report.metadata().id(),
+                            *gc_uneligible_report.metadata().time(),
+                            1,
+                            None,
+                            ReportAggregationState::Start,
+                        ),
+                    )
+                    .await
+                    .unwrap();
+
+                    tx.put_batch(&Batch::<VERIFY_KEY_LENGTH, TimeInterval, Prio3Count>::new(
+                        *leader_task.id(),
+                        gc_eligible_batch_identifier,
+                        (),
+                        BatchState::Closing,
+                        1,
+                        Interval::from_time(&gc_eligible_time).unwrap(),
+                    ))
+                    .await
+                    .unwrap();
+                    tx.put_batch(&Batch::<VERIFY_KEY_LENGTH, TimeInterval, Prio3Count>::new(
+                        *leader_task.id(),
+                        gc_uneligible_batch_identifier,
+                        (),
+                        BatchState::Closing,
+                        1,
+                        Interval::from_time(&gc_uneligible_time).unwrap(),
+                    ))
+                    .await
+                    .unwrap();
+
+                    Ok(tx
+                        .acquire_incomplete_aggregation_jobs(&StdDuration::from_secs(60), 1)
+                        .await
+                        .unwrap()
+                        .remove(0))
+                })
+            })
+            .await
+            .unwrap();
+        assert_eq!(lease.leased().task_id(), task.id());
+        assert_eq!(lease.leased().aggregation_job_id(), &aggregation_job_id);
+
+        // Advance the clock to "enable" report expiry.
+        clock.advance(&REPORT_EXPIRY_AGE);
+
+        // Setup: prepare mocked HTTP response.
+        let leader_request = AggregationJobInitializeReq::new(
+            ().get_encoded(),
+            PartialBatchSelector::new_time_interval(),
+            Vec::from([PrepareInit::new(
+                ReportShare::new(
+                    gc_uneligible_report.metadata().clone(),
+                    gc_uneligible_report.public_share().get_encoded(),
+                    gc_uneligible_report.helper_encrypted_input_share().clone(),
+                ),
+                gc_uneligible_transcript.leader_prepare_transitions[0]
+                    .message
+                    .clone(),
+            )]),
+        );
+        let helper_response = AggregationJobResp::new(Vec::from([PrepareResp::new(
+            *gc_uneligible_report.metadata().id(),
+            PrepareStepResult::Continue {
+                message: gc_uneligible_transcript.helper_prepare_transitions[0]
+                    .message
+                    .clone(),
+            },
+        )]));
+        let (header, value) = agg_auth_token.request_authentication();
+        let mocked_aggregate_init = server
+            .mock(
+                "PUT",
+                task.aggregation_job_uri(&aggregation_job_id)
+                    .unwrap()
+                    .path(),
+            )
+            .match_header(header, value.as_str())
+            .match_header(
+                CONTENT_TYPE.as_str(),
+                AggregationJobInitializeReq::<TimeInterval>::MEDIA_TYPE,
+            )
+            .match_body(leader_request.get_encoded())
+            .with_status(200)
+            .with_header(CONTENT_TYPE.as_str(), AggregationJobResp::MEDIA_TYPE)
+            .with_body(helper_response.get_encoded())
+            .create_async()
+            .await;
+
+        // Run: create an aggregation job driver & try to step the aggregation we've created.
+        let aggregation_job_driver = AggregationJobDriver::new(
+            reqwest::Client::builder().build().unwrap(),
+            &noop_meter(),
+            32,
+        );
+        aggregation_job_driver
+            .step_aggregation_job(ds.clone(), Arc::new(lease))
+            .await
+            .unwrap();
+
+        // Verify.
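+        // Only the GC-ineligible report should have reached the helper; the GC-eligible
+        // report's aggregation is expected to fail with PrepareError::ReportDropped.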
+        mocked_aggregate_init.assert_async().await;
+
+        let want_aggregation_job =
+            AggregationJob::<VERIFY_KEY_LENGTH, TimeInterval, Prio3Count>::new(
+                *task.id(),
+                aggregation_job_id,
+                (),
+                (),
+                Interval::new(
+                    gc_eligible_time,
+                    gc_uneligible_time.difference(&gc_eligible_time).unwrap(),
+                )
+                .unwrap(),
+                AggregationJobState::Finished,
+                AggregationJobStep::from(1),
+            );
+
+        let want_gc_eligible_report_aggregation =
+            ReportAggregation::<VERIFY_KEY_LENGTH, Prio3Count>::new(
+                *task.id(),
+                aggregation_job_id,
+                *gc_eligible_report.metadata().id(),
+                *gc_eligible_report.metadata().time(),
+                0,
+                None,
+                ReportAggregationState::Failed(PrepareError::ReportDropped),
+            );
+        let want_uneligible_report_aggregation =
+            ReportAggregation::<VERIFY_KEY_LENGTH, Prio3Count>::new(
+                *task.id(),
+                aggregation_job_id,
+                *gc_uneligible_report.metadata().id(),
+                *gc_uneligible_report.metadata().time(),
+                1,
+                None,
+                ReportAggregationState::Finished,
+            );
+        let want_report_aggregations = Vec::from([
+            want_gc_eligible_report_aggregation,
+            want_uneligible_report_aggregation,
+        ]);
+
+        let want_batch = Batch::<VERIFY_KEY_LENGTH, TimeInterval, Prio3Count>::new(
+            *task.id(),
+            gc_uneligible_batch_identifier,
+            (),
+            BatchState::Closing,
+            0,
+            Interval::from_time(&gc_uneligible_time).unwrap(),
+        );
+        let want_batches = Vec::from([want_batch]);
+
+        let (got_aggregation_job, got_report_aggregations, got_batches) = ds
+            .run_unnamed_tx(|tx| {
+                let vdaf = Arc::clone(&vdaf);
+                let task = task.clone();
+                Box::pin(async move {
+                    let aggregation_job = tx
+                        .get_aggregation_job::<VERIFY_KEY_LENGTH, TimeInterval, Prio3Count>(
+                            task.id(),
+                            &aggregation_job_id,
+                        )
+                        .await
+                        .unwrap()
+                        .unwrap();
+                    let report_aggregations = tx
+                        .get_report_aggregations_for_aggregation_job(
+                            vdaf.as_ref(),
+                            &Role::Leader,
+                            task.id(),
+                            &aggregation_job_id,
+                        )
+                        .await
+                        .unwrap();
+                    let batches = tx.get_batches_for_task(task.id()).await.unwrap();
+                    Ok((aggregation_job, report_aggregations, batches))
+                })
+            })
+            .await
+            .unwrap();
+
+        assert_eq!(want_aggregation_job, got_aggregation_job);
+        assert_eq!(want_report_aggregations, got_report_aggregations);
+        assert_eq!(want_batches, got_batches);
+    }
+
     #[tokio::test]
     async fn step_fixed_size_aggregation_job_init_single_step() {
         // Setup: insert a client report and add it to a new aggregation job.
diff --git a/aggregator/src/aggregator/aggregation_job_writer.rs b/aggregator/src/aggregator/aggregation_job_writer.rs
index 5f9def79b..fb8559dfe 100644
--- a/aggregator/src/aggregator/aggregation_job_writer.rs
+++ b/aggregator/src/aggregator/aggregation_job_writer.rs
@@ -11,7 +11,7 @@ use janus_aggregator_core::{
         },
         Error, Transaction,
     },
-    task::AggregatorTask,
+    task::{AggregatorTask, QueryType},
 };
 use janus_core::time::{Clock, IntervalExt};
 use janus_messages::{AggregationJobId, Interval, PrepareError, ReportId};
@@ -22,6 +22,7 @@ use std::{
     sync::{Arc, Mutex},
 };
 use tokio::try_join;
+use tracing::{debug, error};
 
 /// AggregationJobWriter contains the logic used to write aggregation jobs, both initially &
 /// on updates. It is used only by the Leader.
@@ -300,7 +301,7 @@ impl
                     Some(batch) => (Operation::Update, batch),
                     None => (
                         Operation::Put,
diff --git a/aggregator_core/src/datastore.rs b/aggregator_core/src/datastore.rs
--- a/aggregator_core/src/datastore.rs
+++ b/aggregator_core/src/datastore.rs
@@ -332,15 +333,52 @@ impl<C: Clock> Transaction<'_, C> {
         Ok((version, description))
     }
 
+    /// Returns the clock used by this transaction.
+    pub fn clock(&self) -> &C {
+        self.clock
+    }
+
     /// Writes a task into the datastore.
     #[tracing::instrument(skip(self, task), fields(task_id = ?task.id()), err)]
     pub async fn put_aggregator_task(&self, task: &AggregatorTask) -> Result<(), Error> {