diff --git a/aggregator/src/aggregator.rs b/aggregator/src/aggregator.rs
index dc783963c..6fdbe510a 100644
--- a/aggregator/src/aggregator.rs
+++ b/aggregator/src/aggregator.rs
@@ -140,7 +140,7 @@ pub struct Aggregator<C: Clock> {
     /// Report writer, with support for batching.
     report_writer: Arc<ReportWriteBatcher<C>>,
     /// Cache of task aggregators.
-    task_aggregators: Mutex<HashMap<TaskId, Arc<TaskAggregator<C>>>>,
+    task_aggregators: TaskAggregatorCache<C>,
 
     /// Metrics.
     metrics: AggregatorMetrics,
@@ -151,6 +151,9 @@ pub struct Aggregator<C: Clock> {
     peer_aggregators: PeerAggregatorCache,
 }
 
+type TaskAggregatorCache<C> =
+    SyncMutex<HashMap<TaskId, Arc<Mutex<Option<Arc<TaskAggregator<C>>>>>>>;
+
 #[derive(Clone)]
 struct AggregatorMetrics {
     /// Counter tracking the number of failed decryptions while handling the
@@ -271,7 +274,7 @@ impl<C: Clock> Aggregator<C> {
             clock,
             cfg,
             report_writer,
-            task_aggregators: Mutex::new(HashMap::new()),
+            task_aggregators: SyncMutex::new(HashMap::new()),
             metrics: AggregatorMetrics {
                 upload_decrypt_failure_counter,
                 upload_decode_failure_counter,
@@ -676,45 +679,55 @@ impl<C: Clock> Aggregator<C> {
         &self,
         task_id: &TaskId,
     ) -> Result<Option<Arc<TaskAggregator<C>>>, Error> {
-        // TODO(#238): don't cache forever (decide on & implement some cache eviction policy).
-        // This is important both to avoid ever-growing resource usage, and to allow aggregators to
+        // TODO(#238): don't cache forever (decide on & implement some cache eviction policy). This
+        // is important both to avoid ever-growing resource usage, and to allow aggregators to
         // notice when a task changes (e.g. due to key rotation).
 
-        // Fast path: grab an existing task aggregator if one exists for this task.
-        {
-            let task_aggs = self.task_aggregators.lock().await;
-            if let Some(task_agg) = task_aggs.get(task_id) {
-                return Ok(Some(Arc::clone(task_agg)));
-            }
-        }
-
-        // TODO(#1639): not holding the lock while querying means that multiple tokio::tasks could
-        // enter this section and redundantly query the database. This could be costly at high QPS.
+        // Step one: grab the existing entry for this task, if one exists. If there is no existing
+        // entry, write a new (empty) entry.
+        let cache_entry = {
+            // Unwrap safety: mutex poisoning.
+            let mut task_aggs = self.task_aggregators.lock().unwrap();
+            Arc::clone(
+                task_aggs
+                    .entry(*task_id)
+                    .or_insert_with(|| Arc::new(Mutex::default())),
+            )
+        };
 
-        // Slow path: retrieve task, create a task aggregator, store it to the cache, then return it.
-        let task_opt = self
-            .datastore
-            .run_tx("task_aggregator_get_task", |tx| {
-                let task_id = *task_id;
-                Box::pin(async move { tx.get_aggregator_task(&task_id).await })
-            })
-            .await?;
-        match task_opt {
-            Some(task) => {
-                let task_agg =
-                    Arc::new(TaskAggregator::new(task, Arc::clone(&self.report_writer))?);
-                {
-                    let mut task_aggs = self.task_aggregators.lock().await;
-                    Ok(Some(Arc::clone(
-                        task_aggs.entry(*task_id).or_insert(task_agg),
-                    )))
-                }
+        // Step two: if the entry is empty, fill it via a database query. Concurrent callers
+        // requesting the same task will contend over this lock while awaiting the result of the DB
+        // query, ensuring that in the common case only a single query will be made for each task.
+        let task_aggregator = {
+            let mut cache_entry = cache_entry.lock().await;
+            if cache_entry.is_none() {
+                *cache_entry = self
+                    .datastore
+                    .run_tx("task_aggregator_get_task", |tx| {
+                        let task_id = *task_id;
+                        Box::pin(async move { tx.get_aggregator_task(&task_id).await })
+                    })
+                    .await?
+                    .map(|task| TaskAggregator::new(task, Arc::clone(&self.report_writer)))
+                    .transpose()?
+                    .map(Arc::new);
             }
-            // Avoid caching None, in case a previously non-existent task is provisioned while the
-            // system is live. Note that for #238, if we're improving this cache to indeed cache
-            // None, we must provide some mechanism for taskprov tasks to force a cache refresh.
-            None => Ok(None),
+            cache_entry.as_ref().map(Arc::clone)
+        };
+
+        // If the task doesn't exist, remove the task entry from the cache to avoid caching a
+        // negative result. Then return the result.
+        //
+        // TODO(#238): once cache eviction is implemented, we can likely remove this step. We only
+        // need to do this to avoid trivial DoS via a requestor spraying many nonexistent task IDs.
+        // However, we need to consider the taskprov case, where an aggregator can add a task and
+        // expect it to be immediately visible.
+        if task_aggregator.is_none() {
+            // Unwrap safety: mutex poisoning.
+            let mut task_aggs = self.task_aggregators.lock().unwrap();
+            task_aggs.remove(task_id);
         }
+        Ok(task_aggregator)
     }
 
     /// Opts in or out of a taskprov task.
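
For illustration, here is a minimal, self-contained sketch of the two-level locking pattern this change introduces: an outer synchronous mutex guards only the map of per-key entries (and is never held across an `.await`), while each entry's async mutex is held for the duration of the load, so concurrent callers for the same key wait on one in-flight query instead of issuing redundant ones. The `Cache` type and `get_or_try_load` helper are hypothetical names for this sketch, not part of the codebase.

```rust
use std::{
    collections::HashMap,
    future::Future,
    hash::Hash,
    sync::{Arc, Mutex as SyncMutex},
};

use tokio::sync::Mutex;

/// Hypothetical illustration of the pattern above: the outer std mutex guards
/// only map lookups/insertions; the inner tokio mutex is held across the slow
/// load so that concurrent callers for the same key share a single load.
struct Cache<K, V> {
    entries: SyncMutex<HashMap<K, Arc<Mutex<Option<Arc<V>>>>>>,
}

impl<K: Hash + Eq + Copy, V> Cache<K, V> {
    fn new() -> Self {
        Self {
            entries: SyncMutex::new(HashMap::new()),
        }
    }

    /// Returns the cached value for `key`, running `load` to fill the entry if
    /// it is empty. A `None` result is not cached, mirroring the
    /// nonexistent-task handling above.
    async fn get_or_try_load<E, F, Fut>(&self, key: K, load: F) -> Result<Option<Arc<V>>, E>
    where
        F: FnOnce() -> Fut,
        Fut: Future<Output = Result<Option<V>, E>>,
    {
        // Step one: grab (or create) this key's entry. The sync guard is a
        // temporary dropped at the end of this statement, before any await.
        let entry = Arc::clone(
            self.entries
                .lock()
                .unwrap() // unwrap safety: mutex poisoning
                .entry(key)
                .or_insert_with(|| Arc::new(Mutex::default())),
        );

        // Step two: fill the entry if it is still empty. Concurrent callers
        // for the same key block here rather than each running `load`.
        let value = {
            let mut entry = entry.lock().await;
            if entry.is_none() {
                *entry = load().await?.map(Arc::new);
            }
            entry.as_ref().map(Arc::clone)
        };

        // Avoid caching negative results, so a key that comes into existence
        // later is picked up by the next call.
        if value.is_none() {
            self.entries.lock().unwrap().remove(&key);
        }
        Ok(value)
    }
}
```

Under this shape, `task_aggregator_for` would roughly correspond to `cache.get_or_try_load(*task_id, || async { /* run the DB query */ })`. Splitting the locks this way is what lets the patch delete the TODO(#1639) comment: only the first caller for a given task performs the query, while later callers either hit the filled entry or wait on the per-entry async mutex rather than stampeding the database.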