Skip to content

Commit

Permalink
Task rewrite: adopt AggregatorTask in taskprov (#2019)
Browse files Browse the repository at this point in the history
Adopts `janus_aggregator_core::task::AggregatorTask` in the `taskprov`
module and in the `janus_aggregator` code that uses it.

Of particular interest is that we do away with
`janus_aggregator_core::taskprov::Task` and instead represent taskprov
tasks as a `janus_aggregator_core::task::AggregatorTask` with
`AggregatorTaskParameters::TaskProvHelper`. Hopefully we'll be able to
further unify handling of taskprov and regular tasks in the future.

Part of #1524
  • Loading branch information
tgeoghegan authored Sep 29, 2023
1 parent 41e2677 commit b2fc735
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 137 deletions.
23 changes: 7 additions & 16 deletions aggregator/src/aggregator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ use janus_aggregator_core::{
},
query_type::AccumulableQueryType,
task::{self, AggregatorTask, Task, VerifyKey},
taskprov::{self, PeerAggregator},
taskprov::PeerAggregator,
};
#[cfg(feature = "test-util")]
use janus_core::test_util::dummy_vdaf;
Expand Down Expand Up @@ -645,7 +645,7 @@ impl<C: Clock> Aggregator<C> {
task_config: &TaskConfig,
aggregator_auth_token: Option<&AuthenticationToken>,
) -> Result<(), Error> {
let (peer_aggregator, leader_url, helper_url) = self
let (peer_aggregator, leader_url, _) = self
.taskprov_authorize_request(peer_role, task_id, task_config, aggregator_auth_token)
.await?;

Expand All @@ -661,38 +661,29 @@ impl<C: Clock> Aggregator<C> {
Error::InvalidTask(*task_id, OptOutReason::InvalidParameter(err.to_string()))
})?;

let our_role = match peer_role {
Role::Leader => Role::Helper,
Role::Helper => Role::Leader,
_ => {
return Err(Error::Internal(
"role should have only been Helper or Leader".to_string(),
))
}
};

let vdaf_verify_key = peer_aggregator.derive_vdaf_verify_key(task_id, &vdaf_instance);

let task = taskprov::Task::new(
let task = AggregatorTask::new(
*task_id,
leader_url,
helper_url,
task_config.query_config().query().try_into()?,
vdaf_instance,
our_role,
vdaf_verify_key,
task_config.query_config().max_batch_query_count() as u64,
Some(*task_config.task_expiration()),
peer_aggregator.report_expiry_age().cloned(),
task_config.query_config().min_batch_size() as u64,
*task_config.query_config().time_precision(),
*peer_aggregator.tolerable_clock_skew(),
// Taskprov task has no per-task HPKE keys
[],
task::AggregatorTaskParameters::TaskProvHelper,
)
.map_err(|err| Error::InvalidTask(*task_id, OptOutReason::TaskParameters(err)))?;
self.datastore
.run_tx_with_name("taskprov_put_task", |tx| {
let task = task.clone();
Box::pin(async move { tx.put_task(&task.into()).await })
Box::pin(async move { tx.put_aggregator_task(&task).await })
})
.await
.or_else(|error| -> Result<(), Error> {
Expand Down
38 changes: 11 additions & 27 deletions aggregator/src/aggregator/http_handlers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -674,8 +674,10 @@ mod tests {
Datastore,
},
query_type::{AccumulableQueryType, CollectableQueryType},
task::{test_util::TaskBuilder, QueryType, VerifyKey},
taskprov,
task::{
test_util::{NewTaskBuilder, TaskBuilder},
QueryType, VerifyKey,
},
test_util::noop_meter,
};
use janus_core::{
Expand Down Expand Up @@ -968,30 +970,12 @@ mod tests {
.unwrap();

// Insert a taskprov task. This task won't have its task-specific HPKE key.
let task = TaskBuilder::new(
QueryType::TimeInterval,
VdafInstance::Prio3Count,
Role::Helper,
)
.build();
let task_id = *task.id();
let task = taskprov::Task::new(
task_id,
task.leader_aggregator_endpoint().clone(),
task.helper_aggregator_endpoint().clone(),
*task.query_type(),
task.vdaf().clone(),
*task.role(),
task.opaque_vdaf_verify_key().clone(),
task.max_batch_query_count(),
task.task_expiration().cloned(),
task.report_expiry_age().cloned(),
task.min_batch_size(),
*task.time_precision(),
*task.tolerable_clock_skew(),
)
.unwrap();
datastore.put_task(&task.into()).await.unwrap();
let task = NewTaskBuilder::new(QueryType::TimeInterval, VdafInstance::Prio3Count).build();
let taskprov_helper_task = task.taskprov_helper_view().unwrap();
datastore
.put_aggregator_task(&taskprov_helper_task)
.await
.unwrap();

let cfg = Config {
taskprov_config: TaskprovConfig { enabled: true },
Expand All @@ -1012,7 +996,7 @@ mod tests {
.await
.unwrap();

let mut test_conn = get(&format!("/hpke_config?task_id={}", task_id))
let mut test_conn = get(&format!("/hpke_config?task_id={}", task.id()))
.run_async(&handler)
.await;
assert_eq!(test_conn.status(), Some(Status::Ok));
Expand Down
41 changes: 23 additions & 18 deletions aggregator/src/aggregator/taskprov_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,10 @@ use janus_aggregator_core::{
test_util::{ephemeral_datastore, EphemeralDatastore},
Datastore,
},
task::{QueryType, Task},
task::{
test_util::{NewTaskBuilder as TaskBuilder, Task},
QueryType,
},
taskprov::{test_util::PeerAggregatorBuilder, PeerAggregator},
test_util::noop_meter,
};
Expand Down Expand Up @@ -64,6 +67,7 @@ use trillium_testing::{
assert_headers,
prelude::{post, put},
};
use url::Url;

type TestVdaf = Poplar1<XofShake128, 16>;

Expand Down Expand Up @@ -167,31 +171,30 @@ async fn setup_taskprov_test() -> TaskprovTestCase {
.unwrap();
let measurement = IdpfInput::from_bools(&[true]);

let task = janus_aggregator_core::taskprov::Task::new(
task_id,
url::Url::parse("https://leader.example.com/").unwrap(),
url::Url::parse("https://helper.example.com/").unwrap(),
let task = TaskBuilder::new(
QueryType::FixedSize {
max_batch_size: max_batch_size as u64,
batch_time_window_size: None,
},
vdaf_instance,
Role::Helper,
vdaf_verify_key.clone(),
max_batch_query_count as u64,
Some(task_expiration),
peer_aggregator.report_expiry_age().copied(),
min_batch_size as u64,
Duration::from_seconds(1),
Duration::from_seconds(1),
)
.unwrap();
.with_id(task_id)
.with_leader_aggregator_endpoint(Url::parse("https://leader.example.com/").unwrap())
.with_helper_aggregator_endpoint(Url::parse("https://helper.example.com/").unwrap())
.with_vdaf_verify_key(vdaf_verify_key.clone())
.with_max_batch_query_count(max_batch_query_count as u64)
.with_task_expiration(Some(task_expiration))
.with_report_expiry_age(peer_aggregator.report_expiry_age().copied())
.with_min_batch_size(min_batch_size as u64)
.with_time_precision(Duration::from_seconds(1))
.with_tolerable_clock_skew(Duration::from_seconds(1))
.build();

let report_metadata = ReportMetadata::new(
random(),
clock
.now()
.to_batch_interval_start(task.task().time_precision())
.to_batch_interval_start(task.time_precision())
.unwrap(),
);
let transcript = run_vdaf(
Expand All @@ -217,7 +220,7 @@ async fn setup_taskprov_test() -> TaskprovTestCase {
datastore,
handler: Box::new(handler),
peer_aggregator,
task: task.into(),
task,
task_config,
task_id,
report_metadata,
Expand Down Expand Up @@ -740,7 +743,8 @@ async fn taskprov_aggregate_continue() {

Box::pin(async move {
// Aggregate continue is only possible if the task has already been inserted.
tx.put_task(&task).await?;
tx.put_aggregator_task(&task.taskprov_helper_view().unwrap())
.await?;

tx.put_report_share(task.id(), &report_share).await?;

Expand Down Expand Up @@ -883,7 +887,8 @@ async fn taskprov_aggregate_share() {
let transcript = test.transcript.clone();

Box::pin(async move {
tx.put_task(&task).await?;
tx.put_aggregator_task(&task.taskprov_helper_view().unwrap())
.await?;

tx.put_batch(&Batch::<16, FixedSize, TestVdaf>::new(
*task.id(),
Expand Down
78 changes: 2 additions & 76 deletions aggregator_core/src/taskprov.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
use crate::{
task::{self, Error, QueryType},
SecretBytes,
};
use crate::{task::Error, SecretBytes};
use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
use derivative::Derivative;
use janus_core::{auth_tokens::AuthenticationToken, vdaf::VdafInstance};
use janus_messages::{Duration, HpkeConfig, Role, TaskId, Time};
use janus_messages::{Duration, HpkeConfig, Role, TaskId};
use rand::{distributions::Standard, prelude::Distribution};
use ring::hkdf::{KeyType, Salt, HKDF_SHA256};
use serde::{
Expand Down Expand Up @@ -266,77 +263,6 @@ impl KeyType for VdafVerifyKeyLength {
}
}

/// Newtype for [`task::Task`], which omits certain fields that aren't required for taskprov tasks.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct Task(pub(super) task::Task);

impl Task {
#[allow(clippy::too_many_arguments)]
pub fn new(
task_id: TaskId,
leader_aggregator_endpoint: Url,
helper_aggregator_endpoint: Url,
query_type: QueryType,
vdaf: VdafInstance,
role: Role,
vdaf_verify_key: SecretBytes,
max_batch_query_count: u64,
task_expiration: Option<Time>,
report_expiry_age: Option<Duration>,
min_batch_size: u64,
time_precision: Duration,
tolerable_clock_skew: Duration,
) -> Result<Self, Error> {
let task = Self(task::Task::new_without_validation(
task_id,
leader_aggregator_endpoint,
helper_aggregator_endpoint,
query_type,
vdaf,
role,
vdaf_verify_key,
max_batch_query_count,
task_expiration,
report_expiry_age,
min_batch_size,
time_precision,
tolerable_clock_skew,
None,
None,
None,
Vec::new(),
));
task.validate()?;
Ok(task)
}

pub(super) fn validate(&self) -> Result<(), Error> {
self.0.validate_common()?;
if let QueryType::FixedSize {
batch_time_window_size,
..
} = self.0.query_type()
{
if batch_time_window_size.is_some() {
return Err(Error::InvalidParameter(
"batch_time_window_size is not supported for taskprov",
));
}
}
Ok(())
}

pub fn task(&self) -> &task::Task {
&self.0
}
}

impl From<Task> for task::Task {
fn from(value: Task) -> Self {
value.0
}
}

#[cfg(feature = "test-util")]
#[cfg_attr(docsrs, doc(cfg(feature = "test-util")))]
pub mod test_util {
Expand Down

0 comments on commit b2fc735

Please sign in to comment.