Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Task rewrite: adopt AggregatorTask in datastore #2017

Merged
merged 3 commits into from
Sep 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion aggregator/src/aggregator/http_handlers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -971,7 +971,7 @@ mod tests {
let task = TaskBuilder::new(
QueryType::TimeInterval,
VdafInstance::Prio3Count,
Role::Leader,
Role::Helper,
)
.build();
let task_id = *task.id();
Expand Down
4 changes: 2 additions & 2 deletions aggregator/src/aggregator/taskprov_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,7 @@ async fn taskprov_aggregate_init() {
tx.get_aggregation_jobs_for_task::<16, FixedSize, TestVdaf>(&task_id)
.await
.unwrap(),
tx.get_task(&task_id).await.unwrap(),
tx.get_aggregator_task(&task_id).await.unwrap(),
))
})
})
Expand All @@ -333,7 +333,7 @@ async fn taskprov_aggregate_init() {
.state()
.eq(&AggregationJobState::InProgress)
);
assert_eq!(test.task, got_task.unwrap());
assert_eq!(test.task.taskprov_helper_view().unwrap(), got_task.unwrap());
}

#[tokio::test]
Expand Down
45 changes: 39 additions & 6 deletions aggregator/src/bin/janus_cli.rs
Original file line number Diff line number Diff line change
Expand Up @@ -607,8 +607,26 @@ mod tests {
.await
.unwrap(),
);
assert_eq!(want_tasks, got_tasks);
assert_eq!(want_tasks, written_tasks);
assert_eq!(
want_tasks
.iter()
.map(|(k, v)| { (*k, v.view_for_role().unwrap()) })
.collect::<HashMap<_, _>>(),
got_tasks
.iter()
.map(|(k, v)| { (*k, v.view_for_role().unwrap()) })
.collect()
);
assert_eq!(
want_tasks
.iter()
.map(|(k, v)| { (*k, v.view_for_role().unwrap()) })
.collect::<HashMap<_, _>>(),
written_tasks
.iter()
.map(|(k, v)| { (*k, v.view_for_role().unwrap()) })
.collect()
);
}

#[tokio::test]
Expand Down Expand Up @@ -703,11 +721,20 @@ mod tests {
.unwrap(),
);
let want_tasks = HashMap::from([
(*replacement_task.id(), replacement_task),
(*tasks[1].id(), tasks[1].clone()),
(
*replacement_task.id(),
replacement_task.view_for_role().unwrap(),
),
(*tasks[1].id(), tasks[1].view_for_role().unwrap()),
]);

assert_eq!(want_tasks, got_tasks);
assert_eq!(
want_tasks,
got_tasks
.iter()
.map(|(k, v)| { (*k, v.view_for_role().unwrap()) })
.collect()
);
}

#[tokio::test]
Expand Down Expand Up @@ -810,8 +837,14 @@ mod tests {
}

assert_eq!(
task_hashmap_from_slice(written_tasks),
task_hashmap_from_slice(written_tasks)
.iter()
.map(|(k, v)| { (*k, v.view_for_role().unwrap()) })
.collect::<HashMap<_, _>>(),
task_hashmap_from_slice(got_tasks)
.iter()
.map(|(k, v)| { (*k, v.view_for_role().unwrap()) })
.collect()
);
}

Expand Down
170 changes: 109 additions & 61 deletions aggregator_core/src/datastore.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ use self::models::{
};
use crate::{
query_type::{AccumulableQueryType, CollectableQueryType},
task::{self, Task},
taskprov::{self, PeerAggregator},
task::{self, AggregatorTask, AggregatorTaskParameters, Task},
taskprov::PeerAggregator,
SecretBytes,
};
use chrono::NaiveDateTime;
Expand Down Expand Up @@ -306,6 +306,7 @@ impl<C: Clock> Datastore<C> {
}

/// Write a task into the datastore.
// TODO(#1524): remove this once everything has migrated to put_aggregator_task
#[cfg(feature = "test-util")]
#[cfg_attr(docsrs, doc(cfg(feature = "test-util")))]
pub async fn put_task(&self, task: &Task) -> Result<(), Error> {
Expand All @@ -315,6 +316,17 @@ impl<C: Clock> Datastore<C> {
})
.await
}

/// Write a task into the datastore.
#[cfg(feature = "test-util")]
#[cfg_attr(docsrs, doc(cfg(feature = "test-util")))]
pub async fn put_aggregator_task(&self, task: &AggregatorTask) -> Result<(), Error> {
self.run_tx(|tx| {
let task = task.clone();
Box::pin(async move { tx.put_aggregator_task(&task).await })
})
.await
}
}

fn check_error<T>(
Expand Down Expand Up @@ -525,20 +537,34 @@ impl<C: Clock> Transaction<'_, C> {
}

/// Writes a task into the datastore.
// TODO(#1524): remove this once everything has migrated to put_aggregator_task
#[tracing::instrument(skip(self, task), fields(task_id = ?task.id()), err)]
pub async fn put_task(&self, task: &Task) -> Result<(), Error> {
let aggregator_task = match task.role() {
Role::Leader => task.leader_view()?,
Role::Helper => task
.helper_view()
.or_else(|_| task.taskprov_helper_view())?,
_ => return Err(Error::InvalidParameter("role must be aggregator")),
};

self.put_aggregator_task(&aggregator_task).await
}

/// Writes a task into the datastore.
#[tracing::instrument(skip(self, task), fields(task_id = ?task.id()), err)]
pub async fn put_aggregator_task(&self, task: &AggregatorTask) -> Result<(), Error> {
// Main task insert.
let stmt = self
.prepare_cached(
"INSERT INTO tasks (
task_id, aggregator_role, leader_aggregator_endpoint,
helper_aggregator_endpoint, query_type, vdaf, max_batch_query_count,
task_expiration, report_expiry_age, min_batch_size, time_precision,
tolerable_clock_skew, collector_hpke_config, vdaf_verify_key,
task_id, aggregator_role, peer_aggregator_endpoint, query_type, vdaf,
max_batch_query_count, task_expiration, report_expiry_age, min_batch_size,
time_precision, tolerable_clock_skew, collector_hpke_config, vdaf_verify_key,
aggregator_auth_token_type, aggregator_auth_token, collector_auth_token_type,
collector_auth_token)
VALUES (
$1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18
$1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17
)
ON CONFLICT DO NOTHING",
)
Expand All @@ -549,10 +575,8 @@ impl<C: Clock> Transaction<'_, C> {
&[
/* task_id */ &task.id().as_ref(),
/* aggregator_role */ &AggregatorRole::from_role(*task.role())?,
/* leader_aggregator_endpoint */
&task.leader_aggregator_endpoint().as_str(),
/* helper_aggregator_endpoint */
&task.helper_aggregator_endpoint().as_str(),
/* peer_aggregator_endpoint */
&task.peer_aggregator_endpoint().as_str(),
/* query_type */ &Json(task.query_type()),
/* vdaf */ &Json(task.vdaf()),
/* max_batch_query_count */
Expand All @@ -574,9 +598,7 @@ impl<C: Clock> Transaction<'_, C> {
/* tolerable_clock_skew */
&i64::try_from(task.tolerable_clock_skew().as_seconds())?,
/* collector_hpke_config */
&task
.collector_hpke_config()
.map(|config| config.get_encoded()),
&task.collector_hpke_config().map(|cfg| cfg.get_encoded()),
/* vdaf_verify_key */
&self.crypter.encrypt(
"tasks",
Expand Down Expand Up @@ -625,6 +647,7 @@ impl<C: Clock> Transaction<'_, C> {
let mut hpke_config_ids: Vec<i16> = Vec::new();
let mut hpke_configs: Vec<Vec<u8>> = Vec::new();
let mut hpke_private_keys: Vec<Vec<u8>> = Vec::new();

for hpke_keypair in task.hpke_keys().values() {
let mut row_id = [0u8; TaskId::LEN + size_of::<u8>()];
row_id[..TaskId::LEN].copy_from_slice(task.id().as_ref());
Expand Down Expand Up @@ -677,16 +700,26 @@ impl<C: Clock> Transaction<'_, C> {
}

/// Fetch the task parameters corresponing to the provided `task_id`.
// TODO(#1524): remove this once everything has migrated to get_aggregator_task
#[tracing::instrument(skip(self), err)]
pub async fn get_task(&self, task_id: &TaskId) -> Result<Option<Task>, Error> {
Ok(self.get_aggregator_task(task_id).await?.map(Task::from))
}

/// Fetch the task parameters corresponing to the provided `task_id`.
#[tracing::instrument(skip(self), err)]
pub async fn get_aggregator_task(
&self,
task_id: &TaskId,
) -> Result<Option<AggregatorTask>, Error> {
let params: &[&(dyn ToSql + Sync)] = &[&task_id.as_ref()];
let stmt = self
.prepare_cached(
"SELECT aggregator_role, leader_aggregator_endpoint, helper_aggregator_endpoint,
query_type, vdaf, max_batch_query_count, task_expiration, report_expiry_age,
min_batch_size, time_precision, tolerable_clock_skew, collector_hpke_config,
vdaf_verify_key, aggregator_auth_token_type, aggregator_auth_token,
collector_auth_token_type, collector_auth_token
"SELECT aggregator_role, peer_aggregator_endpoint, query_type, vdaf,
max_batch_query_count, task_expiration, report_expiry_age, min_batch_size,
time_precision, tolerable_clock_skew, collector_hpke_config, vdaf_verify_key,
aggregator_auth_token_type, aggregator_auth_token, collector_auth_token_type,
collector_auth_token
FROM tasks WHERE task_id = $1",
)
.await?;
Expand All @@ -707,14 +740,25 @@ impl<C: Clock> Transaction<'_, C> {
}

/// Fetch all the tasks in the database.
// TODO(#1524): remove this once everything has migrated to get_aggregator_tasks
#[tracing::instrument(skip(self), err)]
pub async fn get_tasks(&self) -> Result<Vec<Task>, Error> {
Ok(self
.get_aggregator_tasks()
.await?
.into_iter()
.map(Task::from)
.collect())
}

/// Fetch all the tasks in the database.
#[tracing::instrument(skip(self), err)]
pub async fn get_aggregator_tasks(&self) -> Result<Vec<AggregatorTask>, Error> {
let stmt = self
.prepare_cached(
"SELECT task_id, aggregator_role, leader_aggregator_endpoint,
helper_aggregator_endpoint, query_type, vdaf, max_batch_query_count,
task_expiration, report_expiry_age, min_batch_size, time_precision,
tolerable_clock_skew, collector_hpke_config, vdaf_verify_key,
"SELECT task_id, aggregator_role, peer_aggregator_endpoint, query_type, vdaf,
max_batch_query_count, task_expiration, report_expiry_age, min_batch_size,
time_precision, tolerable_clock_skew, collector_hpke_config, vdaf_verify_key,
aggregator_auth_token_type, aggregator_auth_token, collector_auth_token_type,
collector_auth_token
FROM tasks",
Expand Down Expand Up @@ -768,13 +812,10 @@ impl<C: Clock> Transaction<'_, C> {
task_id: &TaskId,
row: &Row,
hpke_key_rows: &[Row],
) -> Result<Task, Error> {
) -> Result<AggregatorTask, Error> {
// Scalar task parameters.
let aggregator_role: AggregatorRole = row.get("aggregator_role");
let leader_aggregator_endpoint =
row.get::<_, String>("leader_aggregator_endpoint").parse()?;
let helper_aggregator_endpoint =
row.get::<_, String>("helper_aggregator_endpoint").parse()?;
let peer_aggregator_endpoint = row.get::<_, String>("peer_aggregator_endpoint").parse()?;
let query_type = row.try_get::<_, Json<task::QueryType>>("query_type")?.0;
let vdaf = row.try_get::<_, Json<VdafInstance>>("vdaf")?.0;
let max_batch_query_count = row.get_bigint_and_convert("max_batch_query_count")?;
Expand Down Expand Up @@ -831,7 +872,7 @@ impl<C: Clock> Transaction<'_, C> {
.transpose()?;

// HPKE keys.
let mut hpke_keypairs = Vec::new();
let mut hpke_keys = Vec::new();
for row in hpke_key_rows {
let config_id = u8::try_from(row.get::<_, i16>("config_id"))?;
let config = HpkeConfig::get_decoded(row.get("config"))?;
Expand All @@ -848,50 +889,57 @@ impl<C: Clock> Transaction<'_, C> {
&encrypted_private_key,
)?);

hpke_keypairs.push(HpkeKeypair::new(config, private_key));
hpke_keys.push(HpkeKeypair::new(config, private_key));
}

let task = Task::new_without_validation(
let aggregator_parameters = match (
aggregator_role,
aggregator_auth_token,
collector_auth_token,
collector_hpke_config,
) {
(
AggregatorRole::Leader,
Some(aggregator_auth_token),
Some(collector_auth_token),
Some(collector_hpke_config),
) => AggregatorTaskParameters::Leader {
aggregator_auth_token,
collector_auth_token,
collector_hpke_config,
},
(
AggregatorRole::Helper,
Some(aggregator_auth_token),
None,
Some(collector_hpke_config),
) => AggregatorTaskParameters::Helper {
aggregator_auth_token,
collector_hpke_config,
},
(AggregatorRole::Helper, None, None, None) => AggregatorTaskParameters::TaskProvHelper,
values => {
return Err(Error::DbState(format!(
"found task row with unexpected combination of values {values:?}",
)));
}
};

Ok(AggregatorTask::new(
*task_id,
leader_aggregator_endpoint,
helper_aggregator_endpoint,
peer_aggregator_endpoint,
query_type,
vdaf,
aggregator_role.as_role(),
vdaf_verify_key,
max_batch_query_count,
task_expiration,
report_expiry_age,
min_batch_size,
time_precision,
tolerable_clock_skew,
collector_hpke_config,
aggregator_auth_token,
collector_auth_token,
hpke_keypairs,
);
// Trial validation through all known schemes. This is a workaround to avoid extending the
// schema to track the provenance of tasks. If we do end up implementing a task provenance
// column anyways, we can simplify this logic.
task.validate().or_else(|error| {
taskprov::Task(task.clone())
.validate()
.map_err(|taskprov_error| {
error!(
%task_id,
%error,
%taskprov_error,
?task,
"task has failed all available validation checks",
);
// Choose some error to bubble up to the caller. Either way this error
// occurring is an indication of a bug, which we'll need to go into the
// logs for.
error
})
})?;

Ok(task)
hpke_keys,
aggregator_parameters,
)?)
}

/// Retrieves report & report aggregation metrics for a given task: either a tuple
Expand Down
Loading
Loading