Skip to content

Commit

Permalink
refactor test to use setup_cancel_aggregation_job_test
Browse files Browse the repository at this point in the history
  • Loading branch information
tgeoghegan committed Jan 24, 2024
1 parent 7a9b9c0 commit 9719eb2
Showing 1 changed file with 40 additions and 121 deletions.
161 changes: 40 additions & 121 deletions aggregator/src/aggregator/aggregation_job_driver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -418,9 +418,7 @@ impl AggregationJobDriver {
Method::PUT,
task.aggregation_job_uri(aggregation_job.id())?
.ok_or_else(|| {
Error::InvalidConfiguration(
"task is not leader and has no aggregate share URI",
)
Error::InvalidConfiguration("task is leader and has no aggregate share URI")
})?,
AGGREGATION_JOB_ROUTE,
Some(RequestBody {
Expand Down Expand Up @@ -1052,11 +1050,11 @@ mod tests {
use janus_messages::{
problem_type::DapProblemType,
query_type::{FixedSize, TimeInterval},
AggregationJobContinueReq, AggregationJobId, AggregationJobInitializeReq,
AggregationJobResp, AggregationJobStep, Duration, Extension, ExtensionType, FixedSizeQuery,
HpkeConfig, InputShareAad, Interval, PartialBatchSelector, PlaintextInputShare,
PrepareContinue, PrepareError, PrepareInit, PrepareResp, PrepareStepResult, Query,
ReportIdChecksum, ReportMetadata, ReportShare, Role, TaskId, Time,
AggregationJobContinueReq, AggregationJobInitializeReq, AggregationJobResp,
AggregationJobStep, Duration, Extension, ExtensionType, FixedSizeQuery, HpkeConfig,
InputShareAad, Interval, PartialBatchSelector, PlaintextInputShare, PrepareContinue,
PrepareError, PrepareInit, PrepareResp, PrepareStepResult, Query, ReportIdChecksum,
ReportMetadata, ReportShare, Role, TaskId, Time,
};
use mockito::ServerGuard;
use prio::{
Expand Down Expand Up @@ -3740,7 +3738,10 @@ mod tests {

struct CancelAggregationJobTestCase {
task: AggregatorTask,
aggregation_job_id: AggregationJobId,
vdaf: Arc<Prio3Count>,
aggregation_job: AggregationJob<VERIFY_KEY_LENGTH, TimeInterval, Prio3Count>,
batch_identifier: Interval,
report_aggregation: ReportAggregation<VERIFY_KEY_LENGTH, Prio3Count>,
_ephemeral_datastore: EphemeralDatastore,
datastore: Arc<Datastore<MockClock>>,
lease: Lease<AcquiredAggregationJob>,
Expand Down Expand Up @@ -3846,7 +3847,10 @@ mod tests {

CancelAggregationJobTestCase {
task,
aggregation_job_id,
vdaf,
batch_identifier,
aggregation_job,
report_aggregation,
_ephemeral_datastore: ephemeral_datastore,
datastore,
lease,
Expand All @@ -3856,109 +3860,18 @@ mod tests {

#[tokio::test]
async fn cancel_aggregation_job() {
// Setup: insert a client report and add it to a new aggregation job.
install_test_trace_subscriber();
let clock = MockClock::default();
let ephemeral_datastore = ephemeral_datastore().await;
let ds = Arc::new(ephemeral_datastore.datastore(clock.clone()).await);
let vdaf = Arc::new(Prio3::new_count(2).unwrap());
let mut mock_helper = mockito::Server::new_async().await;

let task = TaskBuilder::new(QueryType::TimeInterval, VdafInstance::Prio3Count)
.with_helper_aggregator_endpoint(mock_helper.url().parse().unwrap())
.build()
.leader_view()
.unwrap();
let time = clock
.now()
.to_batch_interval_start(task.time_precision())
.unwrap();
let batch_identifier = TimeInterval::to_batch_identifier(&task, &(), &time).unwrap();
let report_metadata = ReportMetadata::new(random(), time);
let verify_key: VerifyKey<VERIFY_KEY_LENGTH> = task.vdaf_verify_key().unwrap();

let transcript = run_vdaf(
vdaf.as_ref(),
verify_key.as_bytes(),
&(),
report_metadata.id(),
&false,
);

let helper_hpke_keypair = generate_test_hpke_config_and_private_key();
let report = generate_report::<VERIFY_KEY_LENGTH, Prio3Count>(
*task.id(),
report_metadata,
helper_hpke_keypair.config(),
transcript.public_share,
Vec::new(),
&transcript.leader_input_share,
&transcript.helper_input_share,
);
let aggregation_job_id = random();

let aggregation_job = AggregationJob::<VERIFY_KEY_LENGTH, TimeInterval, Prio3Count>::new(
*task.id(),
aggregation_job_id,
(),
(),
Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)).unwrap(),
AggregationJobState::InProgress,
AggregationJobStep::from(0),
);
let report_aggregation = ReportAggregation::<VERIFY_KEY_LENGTH, Prio3Count>::new(
*task.id(),
aggregation_job_id,
*report.metadata().id(),
*report.metadata().time(),
0,
None,
ReportAggregationState::Start,
);

let lease = ds
.run_unnamed_tx(|tx| {
let (vdaf, task, report, aggregation_job, report_aggregation) = (
vdaf.clone(),
task.clone(),
report.clone(),
aggregation_job.clone(),
report_aggregation.clone(),
);
Box::pin(async move {
tx.put_aggregator_task(&task).await?;
tx.put_client_report(vdaf.borrow(), &report).await?;
tx.put_aggregation_job(&aggregation_job).await?;
tx.put_report_aggregation(&report_aggregation).await?;

tx.put_batch(&Batch::<VERIFY_KEY_LENGTH, TimeInterval, Prio3Count>::new(
*task.id(),
batch_identifier,
(),
BatchState::Open,
1,
Interval::from_time(report.metadata().time()).unwrap(),
))
.await?;

Ok(tx
.acquire_incomplete_aggregation_jobs(&StdDuration::from_secs(60), 1)
.await?
.remove(0))
})
})
.await
.unwrap();
assert_eq!(lease.leased().task_id(), task.id());
assert_eq!(lease.leased().aggregation_job_id(), &aggregation_job_id);
let mut test_case = setup_cancel_aggregation_job_test().await;

// Run: create an aggregation job driver & cancel the aggregation job. Mock the helper to
// verify that we instruct it to delete the aggregation job.
// https://datatracker.ietf.org/doc/html/draft-ietf-ppm-dap-09#section-4.5.2.2-20
let mocked_aggregation_job_delete = mock_helper
let mocked_aggregation_job_delete = test_case
.mock_helper
.mock(
"DELETE",
task.aggregation_job_uri(&aggregation_job_id)
test_case
.task
.aggregation_job_uri(test_case.aggregation_job.id())
.unwrap()
.unwrap()
.path(),
Expand All @@ -3973,7 +3886,7 @@ mod tests {
32,
);
aggregation_job_driver
.abandon_aggregation_job(Arc::clone(&ds), Arc::new(lease))
.abandon_aggregation_job(Arc::clone(&test_case.datastore), Arc::new(test_case.lease))
.await
.unwrap();

Expand All @@ -3982,26 +3895,32 @@ mod tests {
// Verify: check that the datastore state is updated as expected (the aggregation job is
// abandoned, the report aggregation is untouched) and sanity-check that the job can no
// longer be acquired.
let want_aggregation_job = aggregation_job.with_state(AggregationJobState::Abandoned);
let want_report_aggregation = report_aggregation;
let want_aggregation_job = test_case
.aggregation_job
.with_state(AggregationJobState::Abandoned);
let want_batch = Batch::<VERIFY_KEY_LENGTH, TimeInterval, Prio3Count>::new(
*task.id(),
batch_identifier,
*test_case.task.id(),
test_case.batch_identifier,
(),
BatchState::Open,
0,
Interval::from_time(report.metadata().time()).unwrap(),
Interval::from_time(test_case.report_aggregation.report_metadata().time()).unwrap(),
);

let (got_aggregation_job, got_report_aggregation, got_batch, got_leases) = ds
let (got_aggregation_job, got_report_aggregation, got_batch, got_leases) = test_case
.datastore
.run_unnamed_tx(|tx| {
let (vdaf, task, report_id) =
(Arc::clone(&vdaf), task.clone(), *report.metadata().id());
let (vdaf, task, report_id, aggregation_job) = (
Arc::clone(&test_case.vdaf),
test_case.task.clone(),
*test_case.report_aggregation.report_metadata().id(),
want_aggregation_job.clone(),
);
Box::pin(async move {
let aggregation_job = tx
.get_aggregation_job::<VERIFY_KEY_LENGTH, TimeInterval, Prio3Count>(
task.id(),
&aggregation_job_id,
aggregation_job.id(),
)
.await?
.unwrap();
Expand All @@ -4010,14 +3929,14 @@ mod tests {
vdaf.as_ref(),
&Role::Leader,
task.id(),
&aggregation_job_id,
aggregation_job.id(),
aggregation_job.aggregation_parameter(),
&report_id,
)
.await?
.unwrap();
let batch = tx
.get_batch(task.id(), &batch_identifier, &())
.get_batch(task.id(), &test_case.batch_identifier, &())
.await?
.unwrap();
let leases = tx
Expand All @@ -4029,7 +3948,7 @@ mod tests {
.await
.unwrap();
assert_eq!(want_aggregation_job, got_aggregation_job);
assert_eq!(want_report_aggregation, got_report_aggregation);
assert_eq!(test_case.report_aggregation, got_report_aggregation);
assert_eq!(want_batch, got_batch);
assert!(got_leases.is_empty());
}
Expand All @@ -4049,7 +3968,7 @@ mod tests {
"DELETE",
test_case
.task
.aggregation_job_uri(&test_case.aggregation_job_id)
.aggregation_job_uri(test_case.aggregation_job.id())
.unwrap()
.unwrap()
.path(),
Expand Down

0 comments on commit 9719eb2

Please sign in to comment.