Skip to content

Commit

Permalink
Task rewrite: TaskBuilder in aggregator pt. 1
Browse files Browse the repository at this point in the history
Adopts `NewTaskBuilder` and `AggregatorTask` across portions of the
test utilities and tests in the `janus_aggregator` module.

Part of #1524
  • Loading branch information
tgeoghegan committed Sep 29, 2023
1 parent 2fa444c commit a63a4a3
Show file tree
Hide file tree
Showing 4 changed files with 208 additions and 277 deletions.
40 changes: 19 additions & 21 deletions aggregator/src/aggregator/aggregate_init_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,10 @@ use janus_aggregator_core::{
test_util::{ephemeral_datastore, EphemeralDatastore},
Datastore,
},
task::{test_util::TaskBuilder, QueryType, Task},
task::{
test_util::{NewTaskBuilder as TaskBuilder, Task},
AggregatorTask, QueryType,
},
test_util::noop_meter,
};
use janus_core::{
Expand All @@ -24,7 +27,7 @@ use janus_core::{
};
use janus_messages::{
query_type::TimeInterval, AggregationJobId, AggregationJobInitializeReq, AggregationJobResp,
PartialBatchSelector, PrepareInit, PrepareStepResult, ReportMetadata, Role,
PartialBatchSelector, PrepareInit, PrepareStepResult, ReportMetadata,
};
use prio::{
codec::Encode,
Expand All @@ -46,7 +49,7 @@ where
V: vdaf::Vdaf,
{
clock: MockClock,
task: Task,
task: AggregatorTask,
vdaf: V,
aggregation_param: V::AggregationParam,
}
Expand All @@ -57,7 +60,7 @@ where
{
pub(super) fn new(
clock: MockClock,
task: Task,
task: AggregatorTask,
vdaf: V,
aggregation_param: V::AggregationParam,
) -> Self {
Expand Down Expand Up @@ -209,14 +212,15 @@ async fn setup_aggregate_init_test_without_sending_request<
) -> AggregationJobInitTestCase<VERIFY_KEY_SIZE, V> {
install_test_trace_subscriber();

let task = TaskBuilder::new(QueryType::TimeInterval, vdaf_instance, Role::Helper)
.with_aggregator_auth_token(Some(auth_token))
let task = TaskBuilder::new(QueryType::TimeInterval, vdaf_instance)
.with_aggregator_auth_token(auth_token)
.build();
let helper_task = task.helper_view().unwrap();
let clock = MockClock::default();
let ephemeral_datastore = ephemeral_datastore().await;
let datastore = Arc::new(ephemeral_datastore.datastore(clock.clone()).await);

datastore.put_task(&task).await.unwrap();
datastore.put_aggregator_task(&helper_task).await.unwrap();

let handler = aggregator_handler(
Arc::clone(&datastore),
Expand All @@ -227,8 +231,12 @@ async fn setup_aggregate_init_test_without_sending_request<
.await
.unwrap();

let prepare_init_generator =
PrepareInitGenerator::new(clock.clone(), task.clone(), vdaf, aggregation_param.clone());
let prepare_init_generator = PrepareInitGenerator::new(
clock.clone(),
helper_task.clone(),
vdaf,
aggregation_param.clone(),
);

let prepare_inits = Vec::from([
prepare_init_generator.next(&measurement).0,
Expand Down Expand Up @@ -263,10 +271,7 @@ pub(crate) async fn put_aggregation_job(
aggregation_job: &AggregationJobInitializeReq<TimeInterval>,
handler: &impl Handler,
) -> TestConn {
let (header, value) = task
.aggregator_auth_token()
.unwrap()
.request_authentication();
let (header, value) = task.aggregator_auth_token().request_authentication();
put(task.aggregation_job_uri(aggregation_job_id).unwrap().path())
.with_request_header(header, value)
.with_request_header(
Expand All @@ -292,7 +297,6 @@ async fn aggregation_job_init_authorization_dap_auth_token() {
let (auth_header, auth_value) = test_case
.task
.aggregator_auth_token()
.unwrap()
.request_authentication();

let response = put(test_case
Expand Down Expand Up @@ -337,12 +341,7 @@ async fn aggregation_job_init_malformed_authorization_header(#[case] header_valu
.with_request_header(KnownHeaderName::Authorization, header_value.to_string())
.with_request_header(
DAP_AUTH_HEADER,
test_case
.task
.aggregator_auth_token()
.unwrap()
.as_ref()
.to_owned(),
test_case.task.aggregator_auth_token().as_ref().to_owned(),
)
.with_request_header(
KnownHeaderName::ContentType,
Expand Down Expand Up @@ -490,7 +489,6 @@ async fn aggregation_job_init_wrong_query() {
let (header, value) = test_case
.task
.aggregator_auth_token()
.unwrap()
.request_authentication();

let mut response = put(test_case
Expand Down
27 changes: 12 additions & 15 deletions aggregator/src/aggregator/aggregation_job_continue.rs
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,7 @@ impl VdafOps {
#[cfg_attr(docsrs, doc(cfg(feature = "test-util")))]
pub mod test_util {
use crate::aggregator::http_handlers::test_util::{decode_response_body, take_problem_details};
use janus_aggregator_core::task::Task;
use janus_aggregator_core::task::test_util::Task;
use janus_messages::{AggregationJobContinueReq, AggregationJobId, AggregationJobResp};
use prio::codec::Encode;
use serde_json::json;
Expand All @@ -302,10 +302,7 @@ pub mod test_util {
request: &AggregationJobContinueReq,
handler: &impl Handler,
) -> TestConn {
let (header, value) = task
.aggregator_auth_token()
.unwrap()
.request_authentication();
let (header, value) = task.aggregator_auth_token().request_authentication();
post(task.aggregation_job_uri(aggregation_job_id).unwrap().path())
.with_request_header(header, value)
.with_request_header(
Expand Down Expand Up @@ -393,7 +390,10 @@ mod tests {
test_util::{ephemeral_datastore, EphemeralDatastore},
Datastore,
},
task::{test_util::TaskBuilder, QueryType, Task},
task::{
test_util::{NewTaskBuilder as TaskBuilder, Task},
QueryType,
},
test_util::noop_meter,
};
use janus_core::{
Expand Down Expand Up @@ -439,12 +439,9 @@ mod tests {
install_test_trace_subscriber();

let aggregation_job_id = random();
let task = TaskBuilder::new(
QueryType::TimeInterval,
VdafInstance::Poplar1 { bits: 1 },
Role::Helper,
)
.build();
let task =
TaskBuilder::new(QueryType::TimeInterval, VdafInstance::Poplar1 { bits: 1 }).build();
let helper_task = task.helper_view().unwrap();
let clock = MockClock::default();
let ephemeral_datastore = ephemeral_datastore().await;
let meter = noop_meter();
Expand All @@ -456,7 +453,7 @@ mod tests {
.unwrap();
let prepare_init_generator = PrepareInitGenerator::new(
clock.clone(),
task.clone(),
helper_task.clone(),
Poplar1::new_shake128(1),
aggregation_param.clone(),
);
Expand All @@ -467,14 +464,14 @@ mod tests {
datastore
.run_tx(|tx| {
let (task, aggregation_param, prepare_init, transcript) = (
task.clone(),
helper_task.clone(),
aggregation_param.clone(),
prepare_init.clone(),
transcript.clone(),
);

Box::pin(async move {
tx.put_task(&task).await.unwrap();
tx.put_aggregator_task(&task).await.unwrap();
tx.put_report_share(task.id(), prepare_init.report_share())
.await
.unwrap();
Expand Down
34 changes: 15 additions & 19 deletions aggregator/src/aggregator/collection_job_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,15 @@ use janus_aggregator_core::{
test_util::{ephemeral_datastore, EphemeralDatastore},
Datastore,
},
task::{test_util::TaskBuilder, QueryType, Task},
task::{
test_util::{NewTaskBuilder as TaskBuilder, Task},
QueryType,
},
test_util::noop_meter,
};
use janus_core::{
auth_tokens::AuthenticationToken,
hpke::{
self, test_util::generate_test_hpke_config_and_private_key, HpkeApplicationInfo,
HpkeKeypair, Label,
},
hpke::{self, HpkeApplicationInfo, Label},
test_util::{
dummy_vdaf::{self, AggregationParam},
install_test_trace_subscriber,
Expand All @@ -51,7 +51,6 @@ use trillium_testing::{
pub(crate) struct CollectionJobTestCase {
pub(super) task: Task,
clock: MockClock,
pub(super) collector_hpke_keypair: HpkeKeypair,
pub(super) handler: Box<dyn Handler>,
pub(super) datastore: Arc<Datastore<MockClock>>,
_ephemeral_datastore: EphemeralDatastore,
Expand Down Expand Up @@ -92,7 +91,7 @@ impl CollectionJobTestCase {
self.put_collection_job_with_auth_token(
collection_job_id,
request,
self.task.collector_auth_token(),
Some(self.task.collector_auth_token()),
)
.await
}
Expand Down Expand Up @@ -121,7 +120,7 @@ impl CollectionJobTestCase {
) -> TestConn {
self.post_collection_job_with_auth_token(
collection_job_id,
self.task.collector_auth_token(),
Some(self.task.collector_auth_token()),
)
.await
}
Expand All @@ -133,15 +132,13 @@ pub(crate) async fn setup_collection_job_test_case(
) -> CollectionJobTestCase {
install_test_trace_subscriber();

let collector_hpke_keypair = generate_test_hpke_config_and_private_key();
let task = TaskBuilder::new(query_type, VdafInstance::Fake, role)
.with_collector_hpke_config(collector_hpke_keypair.config().clone())
.build();
let task = TaskBuilder::new(query_type, VdafInstance::Fake).build();
let role_task = task.view_for_role(role).unwrap();
let clock = MockClock::default();
let ephemeral_datastore = ephemeral_datastore().await;
let datastore = Arc::new(ephemeral_datastore.datastore(clock.clone()).await);

datastore.put_task(&task).await.unwrap();
datastore.put_aggregator_task(&role_task).await.unwrap();

let handler = aggregator_handler(
Arc::clone(&datastore),
Expand All @@ -158,7 +155,6 @@ pub(crate) async fn setup_collection_job_test_case(
CollectionJobTestCase {
task,
clock,
collector_hpke_keypair,
handler: Box::new(handler),
datastore,
_ephemeral_datastore: ephemeral_datastore,
Expand Down Expand Up @@ -316,7 +312,7 @@ async fn collection_job_success_fixed_size() {
let batch_id = *collection_job.batch_identifier();

let encrypted_helper_aggregate_share = hpke::seal(
task.collector_hpke_config().unwrap(),
task.collector_hpke_keypair().config(),
&HpkeApplicationInfo::new(
&Label::AggregateShare,
&Role::Helper,
Expand Down Expand Up @@ -371,8 +367,8 @@ async fn collection_job_success_fixed_size() {
);

let decrypted_leader_aggregate_share = hpke::open(
test_case.task.collector_hpke_config().unwrap(),
test_case.collector_hpke_keypair.private_key(),
test_case.task.collector_hpke_keypair().config(),
test_case.task.collector_hpke_keypair().private_key(),
&HpkeApplicationInfo::new(&Label::AggregateShare, &Role::Leader, &Role::Collector),
collect_resp.leader_encrypted_aggregate_share(),
&AggregateShareAad::new(
Expand All @@ -390,8 +386,8 @@ async fn collection_job_success_fixed_size() {
);

let decrypted_helper_aggregate_share = hpke::open(
test_case.task.collector_hpke_config().unwrap(),
test_case.collector_hpke_keypair.private_key(),
test_case.task.collector_hpke_keypair().config(),
test_case.task.collector_hpke_keypair().private_key(),
&HpkeApplicationInfo::new(&Label::AggregateShare, &Role::Helper, &Role::Collector),
collect_resp.helper_encrypted_aggregate_share(),
&AggregateShareAad::new(
Expand Down
Loading

0 comments on commit a63a4a3

Please sign in to comment.