Skip to content

Commit

Permalink
Build query instead of brute forcing each possible one
Browse files Browse the repository at this point in the history
  • Loading branch information
inahga committed Jan 23, 2024
1 parent 1137832 commit c1b34d8
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 57 deletions.
72 changes: 15 additions & 57 deletions aggregator_core/src/datastore.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4769,63 +4769,21 @@ impl<C: Clock> Transaction<'_, C> {
ord: u64,
incrementor: &TaskUploadIncrementor,
) -> Result<(), Error> {
// Brute force each possible query. We cannot parameterize column names in prepared
// statements and we want to avoid the hazards of string interpolation into SQL.
let stmt = self
.prepare_cached(match incrementor {
TaskUploadIncrementor::IntervalCollected => {
"INSERT INTO task_upload_counters (task_id, ord, interval_collected)
VALUES ((SELECT id FROM tasks WHERE task_id = $1), $2, 1)
ON CONFLICT (task_id, ord) DO UPDATE
SET interval_collected = task_upload_counters.interval_collected + 1"
}
TaskUploadIncrementor::ReportDecodeFailure => {
"INSERT INTO task_upload_counters (task_id, ord, report_decode_failure)
VALUES ((SELECT id FROM tasks WHERE task_id = $1), $2, 1)
ON CONFLICT (task_id, ord) DO UPDATE
SET report_decode_failure = task_upload_counters.report_decode_failure + 1"
}
TaskUploadIncrementor::ReportDecryptFailure => {
"INSERT INTO task_upload_counters (task_id, ord, report_decrypt_failure)
VALUES ((SELECT id FROM tasks WHERE task_id = $1), $2, 1)
ON CONFLICT (task_id, ord) DO UPDATE
SET report_decrypt_failure = task_upload_counters.report_decrypt_failure + 1"
}
TaskUploadIncrementor::ReportExpired => {
"INSERT INTO task_upload_counters (task_id, ord, report_expired)
VALUES ((SELECT id FROM tasks WHERE task_id = $1), $2, 1)
ON CONFLICT (task_id, ord) DO UPDATE
SET report_expired = task_upload_counters.report_expired + 1"
}
TaskUploadIncrementor::ReportOutdatedKey => {
"INSERT INTO task_upload_counters (task_id, ord, report_outdated_key)
VALUES ((SELECT id FROM tasks WHERE task_id = $1), $2, 1)
ON CONFLICT (task_id, ord) DO UPDATE
SET report_outdated_key = task_upload_counters.report_outdated_key + 1"
}
TaskUploadIncrementor::ReportSuccess => {
"INSERT INTO task_upload_counters (task_id, ord, report_success)
VALUES ((SELECT id FROM tasks WHERE task_id = $1), $2, 1)
ON CONFLICT (task_id, ord) DO UPDATE
SET report_success = task_upload_counters.report_success + 1"
}
TaskUploadIncrementor::ReportTooEarly => {
"INSERT INTO task_upload_counters (task_id, ord, report_too_early)
VALUES ((SELECT id FROM tasks WHERE task_id = $1), $2, 1)
ON CONFLICT (task_id, ord) DO UPDATE
SET report_too_early = task_upload_counters.report_too_early + 1"
}
TaskUploadIncrementor::TaskExpired => {
"INSERT INTO task_upload_counters (task_id, ord, task_expired)
VALUES ((SELECT id FROM tasks WHERE task_id = $1), $2, 1)
ON CONFLICT (task_id, ord) DO UPDATE
SET task_expired = task_upload_counters.task_expired + 1"
}
})
.await?;
let params: &[&(dyn ToSql + Sync)] = &[task_id.as_ref(), &i64::try_from(ord)?];

check_single_row_mutation(self.execute(&stmt, params).await?)
// SQL injection safety: The possible inputs of TaskUploadIncrementor and the resulting
// .column() are constrained to values that are known to be safe for interpolation. The
// calling function cannot supply arbitrary strings.
let column = incrementor.column();
let stmt = format!(
"INSERT INTO task_upload_counters (task_id, ord, {column})
VALUES ((SELECT id FROM tasks WHERE task_id = $1), $2, 1)
ON CONFLICT (task_id, ord) DO UPDATE
SET {column} = task_upload_counters.{column} + 1"
);
let stmt = self.prepare_cached(&stmt).await?;
check_single_row_mutation(
self.execute(&stmt, &[task_id.as_ref(), &i64::try_from(ord)?])
.await?,
)
}
}

Expand Down
15 changes: 15 additions & 0 deletions aggregator_core/src/datastore/models.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1915,3 +1915,18 @@ pub enum TaskUploadIncrementor {
/// A report was submitted to the task after the task's expiry.
TaskExpired,
}

impl TaskUploadIncrementor {
pub(crate) fn column(&self) -> &'static str {
match self {
TaskUploadIncrementor::IntervalCollected => "interval_collected",
TaskUploadIncrementor::ReportDecodeFailure => "report_decode_failure",
TaskUploadIncrementor::ReportDecryptFailure => "report_decrypt_failure",
TaskUploadIncrementor::ReportExpired => "report_expired",
TaskUploadIncrementor::ReportOutdatedKey => "report_outdated_key",
TaskUploadIncrementor::ReportSuccess => "report_success",
TaskUploadIncrementor::ReportTooEarly => "report_too_early",
TaskUploadIncrementor::TaskExpired => "task_expired",
}
}
}

0 comments on commit c1b34d8

Please sign in to comment.