Skip to content

Commit

Permalink
Fix writing aggregation jobs touching GC'ed batches. (#2467)
Browse files Browse the repository at this point in the history
This issue should only exist in the time-interval query type, as
fixed-size is arranged such that aggregation jobs touching a given batch
must be GC'ed before the batch. I include a guard to ensure that the new
codepath is only taken in the expected case of an already-GC'ed batch
for a time-interval query, as otherwise we might drop batch writes if we
fell into it unexpectedly.
  • Loading branch information
branlwyd authored Jan 10, 2024
1 parent 16a0fbc commit f257d1e
Show file tree
Hide file tree
Showing 3 changed files with 369 additions and 6 deletions.
320 changes: 320 additions & 0 deletions aggregator/src/aggregator/aggregation_job_driver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1895,6 +1895,326 @@ mod tests {
assert_eq!(want_batch, got_batch);
}

#[tokio::test]
async fn step_time_interval_aggregation_job_init_partially_garbage_collected() {
    // This is a regression test for https://github.com/divviup/janus/issues/2464.
    //
    // Scenario: an in-progress time-interval aggregation job references two reports in two
    // different batches. Before the job is stepped, the clock advances far enough that one
    // report (and its batch) ages out of the report-expiry window and is garbage-collected,
    // while the other remains live. Stepping the job must still succeed: the GC'ed report is
    // dropped, the live report completes, and only the live batch's state is written back.
    const OLDEST_ALLOWED_REPORT_TIMESTAMP: Time = Time::from_seconds_since_epoch(1000);
    const REPORT_EXPIRY_AGE: Duration = Duration::from_seconds(500);
    const TIME_PRECISION: Duration = Duration::from_seconds(10);

    // Setup: insert an "old" and "new" client report, and add them to a new aggregation job.
    install_test_trace_subscriber();
    let mut server = mockito::Server::new_async().await;
    let clock = MockClock::new(OLDEST_ALLOWED_REPORT_TIMESTAMP);
    let ephemeral_datastore = ephemeral_datastore().await;
    let ds = Arc::new(ephemeral_datastore.datastore(clock.clone()).await);
    let vdaf = Arc::new(Prio3::new_count(2).unwrap());

    let task = TaskBuilder::new(QueryType::TimeInterval, VdafInstance::Prio3Count)
        .with_helper_aggregator_endpoint(server.url().parse().unwrap())
        .with_report_expiry_age(Some(REPORT_EXPIRY_AGE))
        .with_time_precision(TIME_PRECISION)
        .build();

    let leader_task = task.leader_view().unwrap();

    // A timestamp three time-precision units *before* the oldest allowed timestamp: once the
    // clock advances by REPORT_EXPIRY_AGE below, reports at this time fall outside the expiry
    // window and become eligible for garbage collection.
    let gc_eligible_time = OLDEST_ALLOWED_REPORT_TIMESTAMP
        .sub(&Duration::from_seconds(3 * TIME_PRECISION.as_seconds()))
        .unwrap()
        .to_batch_interval_start(&TIME_PRECISION)
        .unwrap();
    let gc_eligible_batch_identifier =
        TimeInterval::to_batch_identifier(&leader_task, &(), &gc_eligible_time).unwrap();
    let gc_eligible_report_metadata = ReportMetadata::new(random(), gc_eligible_time);

    // A timestamp three time-precision units *after* the oldest allowed timestamp: this report
    // stays inside the expiry window even after the clock advance, so it is not GC'ed.
    let gc_uneligible_time = OLDEST_ALLOWED_REPORT_TIMESTAMP
        .add(&Duration::from_seconds(3 * TIME_PRECISION.as_seconds()))
        .unwrap()
        .to_batch_interval_start(&TIME_PRECISION)
        .unwrap();
    let gc_uneligible_batch_identifier =
        TimeInterval::to_batch_identifier(&leader_task, &(), &gc_uneligible_time).unwrap();
    let gc_uneligible_report_metadata = ReportMetadata::new(random(), gc_uneligible_time);

    let verify_key: VerifyKey<VERIFY_KEY_LENGTH> = task.vdaf_verify_key().unwrap();

    // Run the VDAF locally for both reports so we know the exact prepare messages the
    // leader/helper exchange and can mock the helper's response precisely.
    let gc_eligible_transcript = run_vdaf(
        vdaf.as_ref(),
        verify_key.as_bytes(),
        &(),
        gc_eligible_report_metadata.id(),
        &0,
    );
    let gc_uneligible_transcript = run_vdaf(
        vdaf.as_ref(),
        verify_key.as_bytes(),
        &(),
        gc_uneligible_report_metadata.id(),
        &0,
    );

    let agg_auth_token = task.aggregator_auth_token();
    let helper_hpke_keypair = generate_test_hpke_config_and_private_key();
    let gc_eligible_report = generate_report::<VERIFY_KEY_LENGTH, Prio3Count>(
        *task.id(),
        gc_eligible_report_metadata,
        helper_hpke_keypair.config(),
        gc_eligible_transcript.public_share.clone(),
        Vec::new(),
        &gc_eligible_transcript.leader_input_share,
        &gc_eligible_transcript.helper_input_share,
    );
    let gc_uneligible_report = generate_report::<VERIFY_KEY_LENGTH, Prio3Count>(
        *task.id(),
        gc_uneligible_report_metadata,
        helper_hpke_keypair.config(),
        gc_uneligible_transcript.public_share.clone(),
        Vec::new(),
        &gc_uneligible_transcript.leader_input_share,
        &gc_uneligible_transcript.helper_input_share,
    );

    let aggregation_job_id = random();

    // Seed the datastore: the task, both reports, a single aggregation job spanning both
    // reports' times, one Start-state report aggregation per report, and one Closing batch per
    // time interval. Then acquire the job lease the driver will step.
    let lease = ds
        .run_unnamed_tx(|tx| {
            let vdaf = Arc::clone(&vdaf);
            let leader_task = leader_task.clone();
            let gc_eligible_report = gc_eligible_report.clone();
            let gc_uneligible_report = gc_uneligible_report.clone();

            Box::pin(async move {
                tx.put_aggregator_task(&leader_task).await.unwrap();
                tx.put_client_report(vdaf.borrow(), &gc_eligible_report)
                    .await
                    .unwrap();
                tx.put_client_report(vdaf.borrow(), &gc_uneligible_report)
                    .await
                    .unwrap();

                tx.put_aggregation_job(&AggregationJob::<
                    VERIFY_KEY_LENGTH,
                    TimeInterval,
                    Prio3Count,
                >::new(
                    *leader_task.id(),
                    aggregation_job_id,
                    (),
                    (),
                    // Client-timestamp interval covering both reports.
                    Interval::new(
                        gc_eligible_time,
                        gc_uneligible_time.difference(&gc_eligible_time).unwrap(),
                    )
                    .unwrap(),
                    AggregationJobState::InProgress,
                    AggregationJobStep::from(0),
                ))
                .await
                .unwrap();
                tx.put_report_aggregation(
                    &ReportAggregation::<VERIFY_KEY_LENGTH, Prio3Count>::new(
                        *leader_task.id(),
                        aggregation_job_id,
                        *gc_eligible_report.metadata().id(),
                        *gc_eligible_report.metadata().time(),
                        0,
                        None,
                        ReportAggregationState::Start,
                    ),
                )
                .await
                .unwrap();
                tx.put_report_aggregation(
                    &ReportAggregation::<VERIFY_KEY_LENGTH, Prio3Count>::new(
                        *leader_task.id(),
                        aggregation_job_id,
                        *gc_uneligible_report.metadata().id(),
                        *gc_uneligible_report.metadata().time(),
                        1,
                        None,
                        ReportAggregationState::Start,
                    ),
                )
                .await
                .unwrap();

                // One batch per report's time interval, each in Closing state with an
                // outstanding aggregation job count of 1 (this job).
                tx.put_batch(&Batch::<VERIFY_KEY_LENGTH, TimeInterval, Prio3Count>::new(
                    *leader_task.id(),
                    gc_eligible_batch_identifier,
                    (),
                    BatchState::Closing,
                    1,
                    Interval::from_time(&gc_eligible_time).unwrap(),
                ))
                .await
                .unwrap();
                tx.put_batch(&Batch::<VERIFY_KEY_LENGTH, TimeInterval, Prio3Count>::new(
                    *leader_task.id(),
                    gc_uneligible_batch_identifier,
                    (),
                    BatchState::Closing,
                    1,
                    Interval::from_time(&gc_uneligible_time).unwrap(),
                ))
                .await
                .unwrap();

                Ok(tx
                    .acquire_incomplete_aggregation_jobs(&StdDuration::from_secs(60), 1)
                    .await
                    .unwrap()
                    .remove(0))
            })
        })
        .await
        .unwrap();
    assert_eq!(lease.leased().task_id(), task.id());
    assert_eq!(lease.leased().aggregation_job_id(), &aggregation_job_id);

    // Advance the clock to "enable" report expiry.
    clock.advance(&REPORT_EXPIRY_AGE);

    // Setup: prepare mocked HTTP response.
    // Note the expected init request contains only the GC-uneligible report: the GC-eligible
    // report has expired, so the leader must not forward it to the helper.
    let leader_request = AggregationJobInitializeReq::new(
        ().get_encoded(),
        PartialBatchSelector::new_time_interval(),
        Vec::from([PrepareInit::new(
            ReportShare::new(
                gc_uneligible_report.metadata().clone(),
                gc_uneligible_report.public_share().get_encoded(),
                gc_uneligible_report.helper_encrypted_input_share().clone(),
            ),
            gc_uneligible_transcript.leader_prepare_transitions[0]
                .message
                .clone(),
        )]),
    );
    let helper_response = AggregationJobResp::new(Vec::from([PrepareResp::new(
        *gc_uneligible_report.metadata().id(),
        PrepareStepResult::Continue {
            message: gc_uneligible_transcript.helper_prepare_transitions[0]
                .message
                .clone(),
        },
    )]));
    let (header, value) = agg_auth_token.request_authentication();
    // The mock matches headers and the exact encoded request body, so a request containing the
    // expired report (or missing auth) would fail the assert below.
    let mocked_aggregate_init = server
        .mock(
            "PUT",
            task.aggregation_job_uri(&aggregation_job_id)
                .unwrap()
                .path(),
        )
        .match_header(header, value.as_str())
        .match_header(
            CONTENT_TYPE.as_str(),
            AggregationJobInitializeReq::<TimeInterval>::MEDIA_TYPE,
        )
        .match_body(leader_request.get_encoded())
        .with_status(200)
        .with_header(CONTENT_TYPE.as_str(), AggregationJobResp::MEDIA_TYPE)
        .with_body(helper_response.get_encoded())
        .create_async()
        .await;

    // Run: create an aggregation job driver & try to step the aggregation we've created.
    let aggregation_job_driver = AggregationJobDriver::new(
        reqwest::Client::builder().build().unwrap(),
        &noop_meter(),
        32,
    );
    aggregation_job_driver
        .step_aggregation_job(ds.clone(), Arc::new(lease))
        .await
        .unwrap();

    // Verify.
    mocked_aggregate_init.assert_async().await;

    // The job should finish in a single step (Prio3Count is a one-round VDAF here).
    let want_aggregation_job =
        AggregationJob::<VERIFY_KEY_LENGTH, TimeInterval, Prio3Count>::new(
            *task.id(),
            aggregation_job_id,
            (),
            (),
            Interval::new(
                gc_eligible_time,
                gc_uneligible_time.difference(&gc_eligible_time).unwrap(),
            )
            .unwrap(),
            AggregationJobState::Finished,
            AggregationJobStep::from(1),
        );

    // The expired report is dropped rather than aggregated...
    let want_gc_eligible_report_aggregation =
        ReportAggregation::<VERIFY_KEY_LENGTH, Prio3Count>::new(
            *task.id(),
            aggregation_job_id,
            *gc_eligible_report.metadata().id(),
            *gc_eligible_report.metadata().time(),
            0,
            None,
            ReportAggregationState::Failed(PrepareError::ReportDropped),
        );
    // ...while the live report completes successfully.
    let want_uneligible_report_aggregation =
        ReportAggregation::<VERIFY_KEY_LENGTH, Prio3Count>::new(
            *task.id(),
            aggregation_job_id,
            *gc_uneligible_report.metadata().id(),
            *gc_uneligible_report.metadata().time(),
            1,
            None,
            ReportAggregationState::Finished,
        );
    let want_report_aggregations = Vec::from([
        want_gc_eligible_report_aggregation,
        want_uneligible_report_aggregation,
    ]);

    // Only the uneligible batch should remain (the eligible one was GC'ed); stepping the job
    // decrements its outstanding aggregation-job count to 0 without resurrecting the GC'ed
    // batch -- the behavior fixed by this change.
    let want_batch = Batch::<VERIFY_KEY_LENGTH, TimeInterval, Prio3Count>::new(
        *task.id(),
        gc_uneligible_batch_identifier,
        (),
        BatchState::Closing,
        0,
        Interval::from_time(&gc_uneligible_time).unwrap(),
    );
    let want_batches = Vec::from([want_batch]);

    // Read back the post-step state of the job, its report aggregations, and all batches.
    let (got_aggregation_job, got_report_aggregations, got_batches) = ds
        .run_unnamed_tx(|tx| {
            let vdaf = Arc::clone(&vdaf);
            let task = task.clone();
            Box::pin(async move {
                let aggregation_job = tx
                    .get_aggregation_job::<VERIFY_KEY_LENGTH, TimeInterval, Prio3Count>(
                        task.id(),
                        &aggregation_job_id,
                    )
                    .await
                    .unwrap()
                    .unwrap();
                let report_aggregations = tx
                    .get_report_aggregations_for_aggregation_job(
                        vdaf.as_ref(),
                        &Role::Leader,
                        task.id(),
                        &aggregation_job_id,
                    )
                    .await
                    .unwrap();
                let batches = tx.get_batches_for_task(task.id()).await.unwrap();
                Ok((aggregation_job, report_aggregations, batches))
            })
        })
        .await
        .unwrap();

    assert_eq!(want_aggregation_job, got_aggregation_job);
    assert_eq!(want_report_aggregations, got_report_aggregations);
    assert_eq!(want_batches, got_batches);
}

#[tokio::test]
async fn step_fixed_size_aggregation_job_init_single_step() {
// Setup: insert a client report and add it to a new aggregation job.
Expand Down
Loading

0 comments on commit f257d1e

Please sign in to comment.