Skip to content

Commit

Permalink
improve/remove test workarounds
Browse files Browse the repository at this point in the history
  • Loading branch information
tgeoghegan committed Sep 29, 2023
1 parent 52595ab commit 0bf9101
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 12 deletions.
6 changes: 2 additions & 4 deletions aggregator/src/aggregator/taskprov_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,7 @@ async fn taskprov_aggregate_init() {
tx.get_aggregation_jobs_for_task::<16, FixedSize, TestVdaf>(&task_id)
.await
.unwrap(),
tx.get_task(&task_id).await.unwrap(),
tx.get_aggregator_task(&task_id).await.unwrap(),
))
})
})
Expand All @@ -333,9 +333,7 @@ async fn taskprov_aggregate_init() {
.state()
.eq(&AggregationJobState::InProgress)
);
// TODO(#1524): This assertion temporarily just checks the task ID because of the lossy
// conversion between task::Task and task::AggregatorTask.
assert_eq!(test.task.id(), got_task.unwrap().id());
assert_eq!(test.task.taskprov_helper_view().unwrap(), got_task.unwrap());
}

#[tokio::test]
Expand Down
45 changes: 39 additions & 6 deletions aggregator/src/bin/janus_cli.rs
Original file line number Diff line number Diff line change
Expand Up @@ -607,8 +607,26 @@ mod tests {
.await
.unwrap(),
);
assert_eq!(want_tasks, got_tasks);
assert_eq!(want_tasks, written_tasks);
assert_eq!(
want_tasks
.iter()
.map(|(k, v)| { (*k, v.view_for_role().unwrap()) })
.collect::<HashMap<_, _>>(),
got_tasks
.iter()
.map(|(k, v)| { (*k, v.view_for_role().unwrap()) })
.collect()
);
assert_eq!(
want_tasks
.iter()
.map(|(k, v)| { (*k, v.view_for_role().unwrap()) })
.collect::<HashMap<_, _>>(),
written_tasks
.iter()
.map(|(k, v)| { (*k, v.view_for_role().unwrap()) })
.collect()
);
}

#[tokio::test]
Expand Down Expand Up @@ -703,11 +721,20 @@ mod tests {
.unwrap(),
);
let want_tasks = HashMap::from([
(*replacement_task.id(), replacement_task),
(*tasks[1].id(), tasks[1].clone()),
(
*replacement_task.id(),
replacement_task.view_for_role().unwrap(),
),
(*tasks[1].id(), tasks[1].view_for_role().unwrap()),
]);

assert_eq!(want_tasks, got_tasks);
assert_eq!(
want_tasks,
got_tasks
.iter()
.map(|(k, v)| { (*k, v.view_for_role().unwrap()) })
.collect()
);
}

#[tokio::test]
Expand Down Expand Up @@ -810,8 +837,14 @@ mod tests {
}

assert_eq!(
task_hashmap_from_slice(written_tasks),
task_hashmap_from_slice(written_tasks)
.iter()
.map(|(k, v)| { (*k, v.view_for_role().unwrap()) })
.collect::<HashMap<_, _>>(),
task_hashmap_from_slice(got_tasks)
.iter()
.map(|(k, v)| { (*k, v.view_for_role().unwrap()) })
.collect()
);
}

Expand Down
17 changes: 15 additions & 2 deletions aggregator_core/src/task.rs
Original file line number Diff line number Diff line change
Expand Up @@ -516,7 +516,7 @@ impl Task {
pub fn helper_view(&self) -> Result<AggregatorTask, Error> {
AggregatorTask::new(
self.task_id,
self.helper_aggregator_endpoint.clone(),
self.leader_aggregator_endpoint.clone(),
self.query_type,
self.vdaf.clone(),
self.vdaf_verify_key.clone(),
Expand Down Expand Up @@ -544,7 +544,7 @@ impl Task {
pub fn taskprov_helper_view(&self) -> Result<AggregatorTask, Error> {
AggregatorTask::new(
self.task_id,
self.helper_aggregator_endpoint.clone(),
self.leader_aggregator_endpoint.clone(),
self.query_type,
self.vdaf.clone(),
self.vdaf_verify_key.clone(),
Expand All @@ -558,6 +558,19 @@ impl Task {
AggregatorTaskParameters::TaskProvHelper,
)
}

/// Render the view of the specified aggregator of this task.
///
/// # Errors
///
/// Returns an error if `self.role` is not an aggregator role.
pub fn view_for_role(&self) -> Result<AggregatorTask, Error> {
match self.role {
Role::Leader => self.leader_view(),
Role::Helper => self.helper_view().or_else(|_| self.taskprov_helper_view()),
_ => Err(Error::InvalidParameter("role is not an aggregator")),
}
}
}

impl From<AggregatorTask> for Task {
Expand Down

0 comments on commit 0bf9101

Please sign in to comment.