Task rewrite: janus_cli (#2028)
Adopts `AggregatorTask` in `janus_cli`. There's a functional change
here: the YAML task accepted by `janus_cli` is now an `AggregatorTask`
(via `SerializedAggregatorTask`), representing a single aggregator's
view of a task, so it carries only the peer aggregator's endpoint URL
rather than both aggregators' endpoints. In the near future, the task
YAML format will change further so that, for instance, leader tasks get
only the collector auth token _hash_ and helper tasks get only the
aggregator auth token _hash_.

Part of #1524
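
As a sketch of the new ingestion path (not code from this commit), the following mirrors the `provision_tasks` change below using only types and calls that appear in this diff; the unconditional call to `generate_missing_fields` and the simplified error handling are assumptions for brevity:

use anyhow::Context;
use janus_aggregator_core::task::{AggregatorTask, SerializedAggregatorTask};

// Illustrative sketch: parse the new single-endpoint task YAML into
// SerializedAggregatorTask values, fill in omitted parameters, and
// validate each one into an AggregatorTask.
fn load_tasks(yaml: &str) -> anyhow::Result<Vec<AggregatorTask>> {
    let tasks: Vec<SerializedAggregatorTask> =
        serde_yaml::from_str(yaml).context("couldn't parse tasks")?;
    tasks
        .into_iter()
        .map(|mut task| {
            // janus_cli does this only when asked to generate missing
            // parameters; it is done unconditionally here for brevity.
            task.generate_missing_fields();
            Ok(AggregatorTask::try_from(task)?)
        })
        .collect()
}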
tgeoghegan authored Sep 30, 2023
1 parent a992268 commit c9fb9a3
Showing 3 changed files with 53 additions and 95 deletions.
137 changes: 48 additions & 89 deletions aggregator/src/bin/janus_cli.rs
@@ -9,7 +9,7 @@ use janus_aggregator::{
 };
 use janus_aggregator_core::{
     datastore::{self, Datastore},
-    task::{SerializedTask, Task},
+    task::{AggregatorTask, SerializedAggregatorTask},
 };
 use janus_core::time::{Clock, RealClock};
 use k8s_openapi::api::core::v1::Secret;
@@ -155,24 +155,24 @@ async fn provision_tasks<C: Clock>(
     tasks_file: &Path,
     generate_missing_parameters: bool,
     dry_run: bool,
-) -> Result<Vec<Task>> {
+) -> Result<Vec<AggregatorTask>> {
     // Read tasks file.
-    let tasks: Vec<SerializedTask> = {
+    let tasks: Vec<SerializedAggregatorTask> = {
         let task_file_contents = fs::read_to_string(tasks_file)
             .await
             .with_context(|| format!("couldn't read tasks file {tasks_file:?}"))?;
         serde_yaml::from_str(&task_file_contents)
             .with_context(|| format!("couldn't parse tasks file {tasks_file:?}"))?
     };

-    let tasks: Vec<Task> = tasks
+    let tasks: Vec<AggregatorTask> = tasks
         .into_iter()
         .map(|mut task| {
             if generate_missing_parameters {
                 task.generate_missing_fields();
             }

-            Task::try_from(task)
+            AggregatorTask::try_from(task)
         })
         .collect::<Result<_, _>>()?;

@@ -201,7 +201,7 @@
             err => err?,
         }

-        tx.put_task(task).await?;
+        tx.put_aggregator_task(task).await?;

         written_tasks.push(task.clone());
     }
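
For reference, a hedged sketch of a caller of this write path, reusing the `ds.run_tx(|tx| Box::pin(...))` pattern the tests below use for reads; the `Arc` cloning (the closure may run more than once on retry) is an assumption, not code from this commit:

use std::sync::Arc;

use janus_aggregator_core::{
    datastore::{self, Datastore},
    task::AggregatorTask,
};
use janus_core::time::Clock;

// Illustrative: provision one AggregatorTask in its own transaction.
async fn put_one_task<C: Clock>(
    ds: &Datastore<C>,
    task: AggregatorTask,
) -> Result<(), datastore::Error> {
    let task = Arc::new(task);
    ds.run_tx(|tx| {
        let task = Arc::clone(&task);
        Box::pin(async move { tx.put_aggregator_task(&task).await })
    })
    .await
}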
@@ -464,7 +464,7 @@ mod tests {
     };
     use janus_aggregator_core::{
         datastore::{test_util::ephemeral_datastore, Datastore},
-        task::{test_util::TaskBuilder, QueryType, Task},
+        task::{test_util::NewTaskBuilder as TaskBuilder, AggregatorTask, QueryType},
     };
     use janus_core::{
         test_util::{kubernetes, roundtrip_encoding},
@@ -555,15 +555,15 @@
             .unwrap_err();
     }

-    fn task_hashmap_from_slice(tasks: Vec<Task>) -> HashMap<TaskId, Task> {
+    fn task_hashmap_from_slice(tasks: Vec<AggregatorTask>) -> HashMap<TaskId, AggregatorTask> {
         tasks.into_iter().map(|task| (*task.id(), task)).collect()
     }

     async fn run_provision_tasks_testcase(
         ds: &Datastore<RealClock>,
-        tasks: &[Task],
+        tasks: &[AggregatorTask],
         dry_run: bool,
-    ) -> Vec<Task> {
+    ) -> Vec<AggregatorTask> {
         // Write tasks to a temporary file.
         let mut tasks_file = NamedTempFile::new().unwrap();
         tasks_file
@@ -583,18 +583,14 @@
         let ds = ephemeral_datastore.datastore(RealClock::default()).await;

         let tasks = Vec::from([
-            TaskBuilder::new(
-                QueryType::TimeInterval,
-                VdafInstance::Prio3Count,
-                Role::Leader,
-            )
-            .build(),
-            TaskBuilder::new(
-                QueryType::TimeInterval,
-                VdafInstance::Prio3Sum { bits: 64 },
-                Role::Helper,
-            )
-            .build(),
+            TaskBuilder::new(QueryType::TimeInterval, VdafInstance::Prio3Count)
+                .build()
+                .leader_view()
+                .unwrap(),
+            TaskBuilder::new(QueryType::TimeInterval, VdafInstance::Prio3Sum { bits: 64 })
+                .build()
+                .helper_view()
+                .unwrap(),
         ]);

         let written_tasks = run_provision_tasks_testcase(&ds, &tasks, false).await;
@@ -603,51 +599,34 @@
         let want_tasks = task_hashmap_from_slice(tasks);
         let written_tasks = task_hashmap_from_slice(written_tasks);
         let got_tasks = task_hashmap_from_slice(
-            ds.run_tx(|tx| Box::pin(async move { tx.get_tasks().await }))
+            ds.run_tx(|tx| Box::pin(async move { tx.get_aggregator_tasks().await }))
                 .await
                 .unwrap(),
         );
-        assert_eq!(
-            want_tasks
-                .iter()
-                .map(|(k, v)| { (*k, v.view_for_role().unwrap()) })
-                .collect::<HashMap<_, _>>(),
-            got_tasks
-                .iter()
-                .map(|(k, v)| { (*k, v.view_for_role().unwrap()) })
-                .collect()
-        );
-        assert_eq!(
-            want_tasks
-                .iter()
-                .map(|(k, v)| { (*k, v.view_for_role().unwrap()) })
-                .collect::<HashMap<_, _>>(),
-            written_tasks
-                .iter()
-                .map(|(k, v)| { (*k, v.view_for_role().unwrap()) })
-                .collect()
-        );
+        assert_eq!(want_tasks, got_tasks);
+        assert_eq!(want_tasks, written_tasks);
     }

     #[tokio::test]
     async fn provision_task_dry_run() {
         let ephemeral_datastore = ephemeral_datastore().await;
         let ds = ephemeral_datastore.datastore(RealClock::default()).await;

-        let tasks = Vec::from([TaskBuilder::new(
-            QueryType::TimeInterval,
-            VdafInstance::Prio3Count,
-            Role::Leader,
-        )
-        .build()]);
+        let tasks =
+            Vec::from([
+                TaskBuilder::new(QueryType::TimeInterval, VdafInstance::Prio3Count)
+                    .build()
+                    .leader_view()
+                    .unwrap(),
+            ]);

         let written_tasks = run_provision_tasks_testcase(&ds, &tasks, true).await;

         let want_tasks = task_hashmap_from_slice(tasks);
         let written_tasks = task_hashmap_from_slice(written_tasks);
         assert_eq!(want_tasks, written_tasks);
         let got_tasks = task_hashmap_from_slice(
-            ds.run_tx(|tx| Box::pin(async move { tx.get_tasks().await }))
+            ds.run_tx(|tx| Box::pin(async move { tx.get_aggregator_tasks().await }))
                 .await
                 .unwrap(),
         );
@@ -657,18 +636,14 @@
     #[tokio::test]
     async fn replace_task() {
         let tasks = Vec::from([
-            TaskBuilder::new(
-                QueryType::TimeInterval,
-                VdafInstance::Prio3Count,
-                Role::Leader,
-            )
-            .build(),
-            TaskBuilder::new(
-                QueryType::TimeInterval,
-                VdafInstance::Prio3Sum { bits: 64 },
-                Role::Helper,
-            )
-            .build(),
+            TaskBuilder::new(QueryType::TimeInterval, VdafInstance::Prio3Count)
+                .build()
+                .leader_view()
+                .unwrap(),
+            TaskBuilder::new(QueryType::TimeInterval, VdafInstance::Prio3Sum { bits: 64 })
+                .build()
+                .leader_view()
+                .unwrap(),
         ]);

         let ephemeral_datastore = ephemeral_datastore().await;
@@ -693,10 +668,11 @@
                     length: 4,
                     chunk_length: 2,
                 },
-                Role::Leader,
             )
             .with_id(*tasks[0].id())
-            .build();
+            .build()
+            .leader_view()
+            .unwrap();

         let mut replacement_tasks_file = NamedTempFile::new().unwrap();
         replacement_tasks_file
@@ -716,34 +692,24 @@

         // Verify that the expected tasks were written.
         let got_tasks = task_hashmap_from_slice(
-            ds.run_tx(|tx| Box::pin(async move { tx.get_tasks().await }))
+            ds.run_tx(|tx| Box::pin(async move { tx.get_aggregator_tasks().await }))
                 .await
                 .unwrap(),
         );
         let want_tasks = HashMap::from([
-            (
-                *replacement_task.id(),
-                replacement_task.view_for_role().unwrap(),
-            ),
-            (*tasks[1].id(), tasks[1].view_for_role().unwrap()),
+            (*replacement_task.id(), replacement_task),
+            (*tasks[1].id(), tasks[1].clone()),
         ]);

-        assert_eq!(
-            want_tasks,
-            got_tasks
-                .iter()
-                .map(|(k, v)| { (*k, v.view_for_role().unwrap()) })
-                .collect()
-        );
+        assert_eq!(want_tasks, got_tasks);
     }

     #[tokio::test]
     async fn provision_task_with_generated_values() {
         // YAML contains no task ID, VDAF verify keys, aggregator auth tokens, collector auth tokens
         // or HPKE keys.
         let serialized_task_yaml = r#"
-        - leader_aggregator_endpoint: https://leader
-          helper_aggregator_endpoint: https://helper
+        - peer_aggregator_endpoint: https://helper
           query_type: TimeInterval
           vdaf: !Prio3Sum
             bits: 2
@@ -764,8 +730,7 @@
           aggregator_auth_token:
           collector_auth_token:
           hpke_keys: []
-        - leader_aggregator_endpoint: https://leader
-          helper_aggregator_endpoint: https://helper
+        - peer_aggregator_endpoint: https://leader
           query_type: TimeInterval
           vdaf: !Prio3Sum
             bits: 2
@@ -822,7 +787,7 @@

         // Verify that the expected tasks were written.
         let got_tasks = ds
-            .run_tx(|tx| Box::pin(async move { tx.get_tasks().await }))
+            .run_tx(|tx| Box::pin(async move { tx.get_aggregator_tasks().await }))
             .await
             .unwrap();

@@ -837,14 +802,8 @@
         }

         assert_eq!(
-            task_hashmap_from_slice(written_tasks)
-                .iter()
-                .map(|(k, v)| { (*k, v.view_for_role().unwrap()) })
-                .collect::<HashMap<_, _>>(),
+            task_hashmap_from_slice(written_tasks),
             task_hashmap_from_slice(got_tasks)
-                .iter()
-                .map(|(k, v)| { (*k, v.view_for_role().unwrap()) })
-                .collect()
         );
     }
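
The test updates above all reduce to one pattern, condensed here as a sketch; the builder is the `NewTaskBuilder as TaskBuilder` import from this diff, and the `VdafInstance` import path is an assumption:

use janus_aggregator_core::task::{
    test_util::NewTaskBuilder as TaskBuilder, AggregatorTask, QueryType,
};
use janus_core::task::VdafInstance; // import path assumed

// Build a complete two-aggregator task, then narrow it to a single
// role's view; the view keeps only that role's parameters, which is
// why the serialized form names just the peer aggregator's endpoint.
fn example_views() -> (AggregatorTask, AggregatorTask) {
    let leader = TaskBuilder::new(QueryType::TimeInterval, VdafInstance::Prio3Count)
        .build()
        .leader_view()
        .unwrap();
    let helper = TaskBuilder::new(QueryType::TimeInterval, VdafInstance::Prio3Count)
        .build()
        .helper_view()
        .unwrap();
    (leader, helper)
}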

3 changes: 2 additions & 1 deletion aggregator_core/src/task.rs
@@ -2257,7 +2257,8 @@ mod tests {

     #[test]
     fn deserialize_docs_sample_tasks() {
-        serde_yaml::from_str::<Vec<Task>>(include_str!("../../docs/samples/tasks.yaml")).unwrap();
+        serde_yaml::from_str::<Vec<AggregatorTask>>(include_str!("../../docs/samples/tasks.yaml"))
+            .unwrap();
     }

     #[test]
8 changes: 3 additions & 5 deletions docs/samples/tasks.yaml
@@ -6,9 +6,8 @@
 # DAP's recommendation.
 task_id: "G9YKXjoEjfoU7M_fi_o2H0wmzavRb2sBFHeykeRhDMk"

-# HTTPS endpoints of the leader and helper aggregators.
-leader_aggregator_endpoint: "https://example.com/"
-helper_aggregator_endpoint: "https://example.net/"
+# HTTPS endpoint of the peer aggregator.
+peer_aggregator_endpoint: "https://example.com/"

 # The DAP query type. See below for an example of a fixed-size task
 query_type: TimeInterval
@@ -98,8 +97,7 @@
     private_key: wFRYwiypcHC-mkGP1u3XQgIvtnlkQlUfZjgtM_zRsnI

 - task_id: "D-hCKPuqL2oTf7ZVRVyMP5VGt43EAEA8q34mDf6p1JE"
-  leader_aggregator_endpoint: "https://example.org/"
-  helper_aggregator_endpoint: "https://example.com/"
+  peer_aggregator_endpoint: "https://example.org/"
   # For tasks using the fixed size query type, an additional `max_batch_size`
   # parameter must be provided.
   query_type: !FixedSize