Don't wait on bad reports
inahga committed Jan 24, 2024
1 parent c1b34d8 commit 48a9e0d
Showing 3 changed files with 89 additions and 36 deletions.
60 changes: 45 additions & 15 deletions aggregator/src/aggregator.rs
@@ -187,7 +187,7 @@ pub struct Aggregator<C: Clock> {
}

/// Config represents a configuration for an Aggregator.
-#[derive(Debug, PartialEq, Eq)]
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Config {
/// Defines the maximum size of a batch of uploaded reports which will be written in a single
/// transaction.
@@ -1411,7 +1411,7 @@ impl VdafOps {
let report_time = *report.metadata().time();
async move {
let rejection = ReportRejection::new(*task.id(), report_id, report_time, reason);
-report_writer.write_report(Err(rejection)).await?;
+report_writer.write_rejection(rejection).await;
Ok::<_, Arc<Error>>(Arc::new(Error::ReportRejected(rejection)))
}
};
@@ -1562,9 +1562,9 @@ impl VdafOps {
);

report_writer
-.write_report(Ok(Box::new(WritableReport::<SEED_SIZE, Q, A>::new(
+.write_report(Box::new(WritableReport::<SEED_SIZE, Q, A>::new(
vdaf, report,
-))))
+)))
.await
}
}
@@ -3305,6 +3305,7 @@ mod tests {
};
use rand::random;
use std::{collections::HashSet, iter, sync::Arc, time::Duration as StdDuration};
+use tokio::time::sleep;

pub(super) fn create_report_custom(
task: &AggregatorTask,
@@ -3526,8 +3527,9 @@
async fn upload_wrong_hpke_config_id() {
install_test_trace_subscriber();

+let config = default_aggregator_config();
let (_, aggregator, clock, task, datastore, _ephemeral_datastore) =
-setup_upload_test(default_aggregator_config()).await;
+setup_upload_test(config).await;
let leader_task = task.leader_view().unwrap();
let report = create_report(&leader_task, clock.now());

@@ -3563,6 +3565,9 @@
})
});

+// Wait out the batch write delay so the report status can be flushed to the DB.
+sleep(config.max_upload_batch_write_delay * 2).await;

let got_counters = datastore
.run_unnamed_tx(|tx| {
let task_id = *task.id();
@@ -3620,9 +3625,9 @@
#[tokio::test]
async fn upload_report_in_the_future_past_clock_skew() {
install_test_trace_subscriber();

+let config = default_aggregator_config();
let (_, aggregator, clock, task, datastore, _ephemeral_datastore) =
-setup_upload_test(default_aggregator_config()).await;
+setup_upload_test(config).await;
let report = create_report(
&task.leader_view().unwrap(),
clock
@@ -3644,6 +3649,9 @@
assert_matches!(rejection.reason(), ReportRejectionReason::TooEarly);
});

+// Wait out the batch write delay so the report status can be flushed to the DB.
+sleep(config.max_upload_batch_write_delay * 2).await;

let got_counters = datastore
.run_unnamed_tx(|tx| {
let task_id = *task.id();
@@ -3661,8 +3669,9 @@
async fn upload_report_for_collected_batch() {
install_test_trace_subscriber();

+let config = default_aggregator_config();
let (_, aggregator, clock, task, datastore, _ephemeral_datastore) =
-setup_upload_test(default_aggregator_config()).await;
+setup_upload_test(config).await;
let report = create_report(&task.leader_view().unwrap(), clock.now());

// Insert a collection job for the batch interval including our report.
@@ -3712,6 +3721,9 @@
}
);

+// Wait out the batch write delay so the report status can be flushed to the DB.
+sleep(config.max_upload_batch_write_delay * 2).await;

let got_counters = datastore
.run_unnamed_tx(|tx| {
let task_id = *task.id();
@@ -3805,9 +3817,9 @@
#[tokio::test]
async fn upload_report_task_expired() {
install_test_trace_subscriber();

+let config = default_aggregator_config();
let (_, aggregator, clock, _, datastore, _ephemeral_datastore) =
-setup_upload_test(default_aggregator_config()).await;
+setup_upload_test(config).await;

let task = TaskBuilder::new(QueryType::TimeInterval, VdafInstance::Prio3Count)
.with_task_expiration(Some(clock.now()))
@@ -3835,6 +3847,9 @@
}
);

+// Wait out the batch write delay so the report status can be flushed to the DB.
+sleep(config.max_upload_batch_write_delay * 2).await;

let got_counters = datastore
.run_unnamed_tx(|tx| {
let task_id = *task.id();
@@ -3851,9 +3866,9 @@
#[tokio::test]
async fn upload_report_report_expired() {
install_test_trace_subscriber();

+let config = default_aggregator_config();
let (_, aggregator, clock, _, datastore, _ephemeral_datastore) =
-setup_upload_test(default_aggregator_config()).await;
+setup_upload_test(config).await;

let task = TaskBuilder::new(QueryType::TimeInterval, VdafInstance::Prio3Count)
.with_report_expiry_age(Some(Duration::from_seconds(60)))
@@ -3882,6 +3897,9 @@
}
);

+// Wait out the batch write delay so the report status can be flushed to the DB.
+sleep(config.max_upload_batch_write_delay * 2).await;

let got_counters = datastore
.run_unnamed_tx(|tx| {
let task_id = *task.id();
@@ -3898,8 +3916,9 @@
#[tokio::test]
async fn upload_report_faulty_encryption() {
install_test_trace_subscriber();
+let config = default_aggregator_config();
let (_, aggregator, clock, task, datastore, _ephemeral_datastore) =
-setup_upload_test(default_aggregator_config()).await;
+setup_upload_test(config).await;

let task = task.leader_view().unwrap();

@@ -3928,6 +3947,9 @@
}
);

+// Wait out the batch write delay so the report status can be flushed to the DB.
+sleep(config.max_upload_batch_write_delay * 2).await;

let got_counters = datastore
.run_unnamed_tx(|tx| {
let task_id = *task.id();
@@ -3944,8 +3966,9 @@
#[tokio::test]
async fn upload_report_public_share_decode_failure() {
install_test_trace_subscriber();
+let config = default_aggregator_config();
let (_, aggregator, clock, task, datastore, _ephemeral_datastore) =
-setup_upload_test(default_aggregator_config()).await;
+setup_upload_test(config).await;

let task = task.leader_view().unwrap();

@@ -3973,6 +3996,9 @@
}
);

+// Wait out the batch write delay so the report status can be flushed to the DB.
+sleep(config.max_upload_batch_write_delay * 2).await;

let got_counters = datastore
.run_unnamed_tx(|tx| {
let task_id = *task.id();
@@ -3989,8 +4015,9 @@
#[tokio::test]
async fn upload_report_leader_input_share_decode_failure() {
install_test_trace_subscriber();
+let config = default_aggregator_config();
let (_, aggregator, clock, task, datastore, _ephemeral_datastore) =
-setup_upload_test(default_aggregator_config()).await;
+setup_upload_test(config).await;

let task = task.leader_view().unwrap();

@@ -4032,6 +4059,9 @@
}
);

+// Wait out the batch write delay so the report status can be flushed to the DB.
+sleep(config.max_upload_batch_write_delay * 2).await;

let got_counters = datastore
.run_unnamed_tx(|tx| {
let task_id = *task.id();
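Every test change in this file follows the same pattern: because the upload handler no longer waits for a rejection to be persisted, each test now sleeps past the configured max_upload_batch_write_delay before querying the datastore for upload counters. Below is a self-contained sketch of why that wait is needed, written in plain tokio rather than with the Janus test helpers; the counter and the delay are stand-ins, not Janus APIs.

use std::sync::{
    atomic::{AtomicU64, Ordering},
    Arc,
};
use std::time::Duration;

#[tokio::main]
async fn main() {
    // Stand-in for the upload-rejection counter, which is only updated once the
    // batcher flushes its queue to the database.
    let rejections = Arc::new(AtomicU64::new(0));
    let max_upload_batch_write_delay = Duration::from_millis(100);

    // "Upload" a bad report: the handler rejects it immediately, but the
    // bookkeeping write happens later, on the batcher's schedule.
    let counter = Arc::clone(&rejections);
    tokio::spawn(async move {
        tokio::time::sleep(max_upload_batch_write_delay).await;
        counter.fetch_add(1, Ordering::SeqCst);
    });

    // Asserting immediately would be flaky: the batched write may not have run yet.
    assert_eq!(rejections.load(Ordering::SeqCst), 0);

    // Wait out the batch write delay, as the updated tests do, then assert.
    tokio::time::sleep(max_upload_batch_write_delay * 2).await;
    assert_eq!(rejections.load(Ordering::SeqCst), 1);
}

The factor of two mirrors the tests above, which sleep for max_upload_batch_write_delay * 2 to leave headroom over the flush interval.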
63 changes: 43 additions & 20 deletions aggregator/src/aggregator/report_writer.rs
@@ -24,8 +24,8 @@ type ReportResult<C> = Result<Box<dyn ReportWriter<C>>, ReportRejection>;

type ResultSender = oneshot::Sender<Result<(), Arc<Error>>>;

-type ReportWriteBatcherSender<C> = mpsc::Sender<(ReportResult<C>, ResultSender)>;
-type ReportWriteBatcherReceiver<C> = mpsc::Receiver<(ReportResult<C>, ResultSender)>;
+type ReportWriteBatcherSender<C> = mpsc::Sender<(ReportResult<C>, Option<ResultSender>)>;
+type ReportWriteBatcherReceiver<C> = mpsc::Receiver<(ReportResult<C>, Option<ResultSender>)>;

pub struct ReportWriteBatcher<C: Clock> {
report_tx: ReportWriteBatcherSender<C>,
@@ -54,15 +54,36 @@ impl<C: Clock> ReportWriteBatcher<C> {
Self { report_tx }
}

-pub async fn write_report(&self, report_result: ReportResult<C>) -> Result<(), Arc<Error>> {
+/// Save a [`ReportRejection`] to the database.
+///
+/// This function does not wait for the result of the batch write, because we do not want
+/// clients to retry bad reports, even due to server error.
+pub async fn write_rejection(&self, report_rejection: ReportRejection) {
+// Unwrap safety: report_rx is not dropped until ReportWriteBatcher is dropped.
+self.report_tx
+.send((Err(report_rejection), None))
+.await
+.unwrap();
+}
+
+/// Save a report to the database.
+///
+/// This function waits for and returns the result of the batch write.
+pub async fn write_report(
+&self,
+report_writer: Box<dyn ReportWriter<C>>,
+) -> Result<(), Arc<Error>> {
// Send report to be written.
// Unwrap safety: report_rx is not dropped until ReportWriteBatcher is dropped.
-let (rslt_tx, rslt_rx) = oneshot::channel();
-self.report_tx.send((report_result, rslt_tx)).await.unwrap();
+let (result_tx, result_rx) = oneshot::channel();
+self.report_tx
+.send((Ok(report_writer), Some(result_tx)))
+.await
+.unwrap();

// Await the result of writing the report.
// Unwrap safety: rslt_tx is always sent on before being dropped, and is never closed.
-rslt_rx.await.unwrap()
+result_rx.await.unwrap()
}

#[tracing::instrument(name = "ReportWriteBatcher::run_upload_batcher", skip(ds, report_rx))]
@@ -121,7 +142,7 @@ impl<C: Clock> ReportWriteBatcher<C> {
async fn write_batch(
ds: Arc<Datastore<C>>,
counter_shard_count: u64,
-mut report_results: Vec<(ReportResult<C>, ResultSender)>,
+mut report_results: Vec<(ReportResult<C>, Option<ResultSender>)>,
) {
let ord = thread_rng().gen_range(0..counter_shard_count);

@@ -133,10 +154,10 @@ impl<C: Clock> ReportWriteBatcher<C> {
});

// Run all report writes concurrently.
-let (report_results, result_senders): (Vec<ReportResult<C>>, Vec<ResultSender>) =
+let (report_results, result_senders): (Vec<ReportResult<C>>, Vec<Option<ResultSender>>) =
report_results.into_iter().unzip();
let report_results = Arc::new(report_results);
-let rslts = ds
+let results = ds
.run_tx("upload", |tx| {
let report_results = Arc::clone(&report_results);
Box::pin(async move {
@@ -161,28 +182,30 @@ impl<C: Clock> ReportWriteBatcher<C> {
})
.await;

-match rslts {
-Ok(rslts) => {
+match results {
+Ok(results) => {
// Individual, per-request results.
-assert_eq!(result_senders.len(), rslts.len()); // sanity check: should be guaranteed.
-for (rslt_tx, rslt) in result_senders.into_iter().zip(rslts.into_iter()) {
-if rslt_tx.send(rslt.map_err(Arc::new)).is_err() {
-debug!(
-"ReportWriter couldn't send result to requester (request cancelled?)"
-);
+assert_eq!(result_senders.len(), results.len()); // sanity check: should be guaranteed.
+for (result_tx, result) in result_senders.into_iter().zip(results.into_iter()) {
+if let Some(result_tx) = result_tx {
+if result_tx.send(result.map_err(Arc::new)).is_err() {
+debug!(
+"ReportWriter couldn't send result to requester (request cancelled?)"
+);
+}
}
}
}
Err(err) => {
// Total-transaction failures are given to all waiting report uploaders.
let err = Arc::new(Error::from(err));
-for rslt_tx in result_senders.into_iter() {
-if rslt_tx.send(Err(Arc::clone(&err))).is_err() {
+result_senders.into_iter().flatten().for_each(|result_tx| {
+if result_tx.send(Err(Arc::clone(&err))).is_err() {
debug!(
"ReportWriter couldn't send result to requester (request cancelled?)"
);
-};
}
+})
}
};
}
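The mechanism behind "don't wait on bad reports" is the Option<ResultSender> added to the batcher's channel items: write_rejection enqueues with None and returns as soon as the item is queued, while write_report still passes a oneshot sender and awaits the outcome of the batched write. Below is a minimal, self-contained sketch of that channel protocol, with toy item and error types standing in for the Janus ones; it only models the waiting behavior of the two send paths, not the real batching transaction.

use std::time::Duration;
use tokio::sync::{mpsc, oneshot};

type WriteError = String;
type ResultSender = oneshot::Sender<Result<(), WriteError>>;

// Toy stand-in for ReportResult<C>: either a writable report or a rejection
// that is recorded purely for bookkeeping.
enum Item {
    Report(String),
    Rejection(String),
}

async fn run_batcher(mut rx: mpsc::Receiver<(Item, Option<ResultSender>)>) {
    while let Some((item, result_tx)) = rx.recv().await {
        // Pretend to batch the item and write it to the database.
        tokio::time::sleep(Duration::from_millis(50)).await;
        let outcome = match item {
            Item::Report(id) => {
                println!("wrote report {id}");
                Ok(())
            }
            Item::Rejection(id) => {
                println!("recorded rejection of {id}");
                Ok(())
            }
        };
        // Only callers that supplied a sender are waiting on the outcome;
        // rejections were enqueued with None, so nobody is blocked on them.
        if let Some(result_tx) = result_tx {
            let _ = result_tx.send(outcome);
        }
    }
}

#[tokio::main]
async fn main() {
    let (tx, rx) = mpsc::channel(16);
    tokio::spawn(run_batcher(rx));

    // Rejection: fire-and-forget. The client gets its error response right away
    // and is not asked to retry just because the batched write is slow or fails.
    tx.send((Item::Rejection("report-1".into()), None))
        .await
        .unwrap();

    // Accepted report: wait for the batched write before acknowledging the upload.
    let (result_tx, result_rx) = oneshot::channel();
    tx.send((Item::Report("report-2".into()), Some(result_tx)))
        .await
        .unwrap();
    result_rx.await.unwrap().unwrap();
}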
2 changes: 1 addition & 1 deletion aggregator/src/config.rs
@@ -112,7 +112,7 @@ fn format_database_url(url: &Url, fmt: &mut std::fmt::Formatter) -> Result<(), s
/// options are implementation-specific.
///
/// [spec]: https://datatracker.ietf.org/doc/draft-wang-ppm-dap-taskprov/
-#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Default)]
+#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize, Default)]
pub struct TaskprovConfig {
/// Whether to enable the extension or not. Enabling this changes the behavior
/// of the aggregator consistent with the taskprov [specification][spec].
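The Copy derives on Config and on TaskprovConfig (which Config contains) support the test pattern above: setup_upload_test takes the config by value, and the test still needs to read config.max_upload_batch_write_delay afterwards to size its sleep. Below is a minimal illustration of the ownership issue with a toy Config; it assumes, as appears to be the case, that all of the real struct's fields are themselves Copy.

use std::time::Duration;

#[derive(Clone, Copy)]
struct Config {
    max_upload_batch_write_delay: Duration,
}

// Takes the config by value, like setup_upload_test in the tests above.
fn setup_upload_test(_config: Config) {}

fn main() {
    let config = Config {
        max_upload_batch_write_delay: Duration::from_millis(100),
    };
    setup_upload_test(config); // copies `config` instead of moving it
    // Without the Copy derive, this read would not compile after the call
    // above (or the test would need an explicit clone).
    assert_eq!(config.max_upload_batch_write_delay, Duration::from_millis(100));
}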
