From daa0e0718fa2e9150663977f382239a242d7a065 Mon Sep 17 00:00:00 2001 From: Brett Hoerner Date: Thu, 18 Jan 2024 14:16:55 -0700 Subject: [PATCH] Dequeue multiple items at a time --- hook-common/src/pgqueue.rs | 426 +++++++++++++++++++++++++++-------- hook-janitor/src/webhooks.rs | 14 +- hook-worker/src/config.rs | 3 + hook-worker/src/main.rs | 1 + hook-worker/src/worker.rs | 149 +++++++----- 5 files changed, 436 insertions(+), 157 deletions(-) diff --git a/hook-common/src/pgqueue.rs b/hook-common/src/pgqueue.rs index 4c22612..29b8d6f 100644 --- a/hook-common/src/pgqueue.rs +++ b/hook-common/src/pgqueue.rs @@ -1,14 +1,17 @@ //! # PgQueue //! //! A job queue implementation backed by a PostgreSQL table. -use std::str::FromStr; +use std::ops::DerefMut; use std::time; +use std::{str::FromStr, sync::Arc}; use async_trait::async_trait; use chrono; use serde; use sqlx::postgres::{PgPool, PgPoolOptions}; use thiserror::Error; +use tokio::sync::Mutex; +use tracing::error; /// Enumeration of errors for operations with PgQueue. /// Errors that can originate from sqlx and are wrapped by us to provide additional context. @@ -24,6 +27,8 @@ pub enum PgQueueError { ParseJobStatusError(String), #[error("{0} is not a valid HttpMethod")] ParseHttpMethodError(String), + #[error("transaction was already closed")] + TransactionAlreadyClosedError, } #[derive(Error, Debug)] @@ -36,6 +41,8 @@ pub enum PgJobError { QueryError { command: String, error: sqlx::Error }, #[error("transaction {command} failed with: {error}")] TransactionError { command: String, error: sqlx::Error }, + #[error("transaction was already closed")] + TransactionAlreadyClosedError, } /// Enumeration of possible statuses for a Job. @@ -222,6 +229,11 @@ pub struct PgJob { pub pool: PgPool, } +// Container struct for a batch of PgJobs. +pub struct PgBatch { + pub jobs: Vec>, +} + impl PgJob { async fn acquire_conn( &mut self, @@ -304,7 +316,72 @@ impl PgQueueJob for PgJob { #[derive(Debug)] pub struct PgTransactionJob<'c, J, M> { pub job: Job, - pub transaction: sqlx::Transaction<'c, sqlx::postgres::Postgres>, + + /// The open transaction this job came from. If multiple jobs were queried at once, then this + /// transaction will be shared between them (across async tasks and threads as necessary). See + /// below for more information. + shared_txn: Arc>>, +} + +// Container struct for a batch of PgTransactionJob. Includes a reference to the shared transaction +// for committing the work when all of the jobs are finished. +pub struct PgTransactionBatch<'c, J, M> { + pub jobs: Vec>, + + /// The open transaction the jobs in the Vec came from. This should be used to commit or + /// rollback when all of the work is finished. + shared_txn: Arc>>, +} + +impl<'c, J, M> PgTransactionBatch<'_, J, M> { + pub async fn commit(&mut self) -> PgQueueResult<()> { + let mut txn_guard = self.shared_txn.lock().await; + txn_guard.commit().await + } +} + +/// A shared transaction. Exists as a named type so that it can have some helper methods. +#[derive(Debug)] +struct SharedTxn<'c> { + /// The actual transaction object. If a transaction error occurs (e.g. the connection drops) + /// then the transaction should be dropped so that other jobs don't each try to do DB work that + /// is bound to fail. + raw_txn: Option>, +} + +impl SharedTxn<'_> { + /// Commits and then drops the transaction. + pub async fn commit(&mut self) -> PgQueueResult<()> { + if self.raw_txn.is_none() { + return Err(PgQueueError::TransactionAlreadyClosedError); + } + + let txn = self.raw_txn.take().unwrap(); + txn.commit().await.map_err(|e| PgQueueError::QueryError { + command: "COMMIT".to_owned(), + error: e, + })?; + + Ok(()) + } + + /// See `raw_txn` above, this should be called when a transaction error occurs (e.g. the + /// connection drops) so that other jobs don't each try to do DB work that is bound to fail. + fn drop_txn(&mut self) { + self.raw_txn = None; + } +} + +impl SharedTxn<'_> { + /// Helper to get the transaction reference if it exists, for use in sqlx queries. + /// If it doesn't exist, that means a previous error occurred and the transaction has been + /// dropped. + fn get_txn_ref(&mut self) -> Option<&mut sqlx::PgConnection> { + match self.raw_txn.as_mut() { + Some(txn) => Some(txn.deref_mut()), + None => None, + } + } } #[async_trait] @@ -312,22 +389,20 @@ impl<'c, J: std::marker::Send, M: std::marker::Send> PgQueueJob for PgTransactio async fn complete( mut self, ) -> Result>>> { - let completed_job = self - .job - .complete(&mut *self.transaction) - .await - .map_err(|error| PgJobError::QueryError { - command: "UPDATE".to_owned(), - error, - })?; + let mut txn_guard = self.shared_txn.lock().await; - self.transaction - .commit() - .await - .map_err(|error| PgJobError::TransactionError { - command: "COMMIT".to_owned(), + let txn_ref = txn_guard + .get_txn_ref() + .ok_or(PgJobError::TransactionAlreadyClosedError)?; + + let completed_job = self.job.complete(txn_ref).await.map_err(|error| { + txn_guard.drop_txn(); + + PgJobError::QueryError { + command: "UPDATE".to_owned(), error, - })?; + } + })?; Ok(completed_job) } @@ -336,22 +411,20 @@ impl<'c, J: std::marker::Send, M: std::marker::Send> PgQueueJob for PgTransactio mut self, error: S, ) -> Result, PgJobError>>> { - let failed_job = self - .job - .fail(error, &mut *self.transaction) - .await - .map_err(|error| PgJobError::QueryError { - command: "UPDATE".to_owned(), - error, - })?; + let mut txn_guard = self.shared_txn.lock().await; - self.transaction - .commit() - .await - .map_err(|error| PgJobError::TransactionError { - command: "COMMIT".to_owned(), + let txn_ref = txn_guard + .get_txn_ref() + .ok_or(PgJobError::TransactionAlreadyClosedError)?; + + let failed_job = self.job.fail(error, txn_ref).await.map_err(|error| { + txn_guard.drop_txn(); + + PgJobError::QueryError { + command: "UPDATE".to_owned(), error, - })?; + } + })?; Ok(failed_job) } @@ -371,23 +444,25 @@ impl<'c, J: std::marker::Send, M: std::marker::Send> PgQueueJob for PgTransactio }); } + let mut txn_guard = self.shared_txn.lock().await; + + let txn_ref = txn_guard + .get_txn_ref() + .ok_or(PgJobError::TransactionAlreadyClosedError)?; + let retried_job = self .job .retryable() .queue(queue) - .retry(error, retry_interval, &mut *self.transaction) + .retry(error, retry_interval, txn_ref) .await - .map_err(|error| PgJobError::QueryError { - command: "UPDATE".to_owned(), - error, - })?; + .map_err(|error| { + txn_guard.drop_txn(); - self.transaction - .commit() - .await - .map_err(|error| PgJobError::TransactionError { - command: "COMMIT".to_owned(), - error, + PgJobError::QueryError { + command: "UPDATE".to_owned(), + error, + } })?; Ok(retried_job) @@ -565,15 +640,16 @@ impl PgQueue { Ok(Self { name, pool }) } - /// Dequeue a `Job` from this `PgQueue`. - /// The `Job` will be updated to `'running'` status, so any other `dequeue` calls will skip it. + /// Dequeue up to `limit` `Job`s from this `PgQueue`. + /// The `Job`s will be updated to `'running'` status, so any other `dequeue` calls will skip it. pub async fn dequeue< J: for<'d> serde::Deserialize<'d> + std::marker::Send + std::marker::Unpin + 'static, M: for<'d> serde::Deserialize<'d> + std::marker::Send + std::marker::Unpin + 'static, >( &self, attempted_by: &str, - ) -> PgQueueResult>> { + limit: u32, + ) -> PgQueueResult>> { let mut connection = self .pool .acquire() @@ -595,7 +671,7 @@ WITH available_in_queue AS ( ORDER BY attempt, scheduled_at - LIMIT 1 + LIMIT $2 FOR UPDATE SKIP LOCKED ) UPDATE @@ -604,7 +680,7 @@ SET attempted_at = NOW(), status = 'running'::job_status, attempt = attempt + 1, - attempted_by = array_append(attempted_by, $2::text) + attempted_by = array_append(attempted_by, $3::text) FROM available_in_queue WHERE @@ -613,17 +689,29 @@ RETURNING job_queue.* "#; - let query_result: Result, sqlx::Error> = sqlx::query_as(base_query) + let query_result: Result>, sqlx::Error> = sqlx::query_as(base_query) .bind(&self.name) + .bind(limit as i64) .bind(attempted_by) - .fetch_one(&mut *connection) + .fetch_all(&mut *connection) .await; match query_result { - Ok(job) => Ok(Some(PgJob { - job, - pool: self.pool.clone(), - })), + Ok(jobs) => { + if jobs.is_empty() { + return Ok(None); + } + + let pg_jobs: Vec> = jobs + .into_iter() + .map(|job| PgJob { + job, + pool: self.pool.clone(), + }) + .collect(); + + Ok(Some(PgBatch { jobs: pg_jobs })) + } // Although connection would be closed once it goes out of scope, sqlx recommends explicitly calling close(). // See: https://docs.rs/sqlx/latest/sqlx/postgres/any/trait.AnyConnectionBackend.html#tymethod.close. @@ -631,6 +719,7 @@ RETURNING let _ = connection.close().await; Ok(None) } + Err(e) => { let _ = connection.close().await; Err(PgQueueError::QueryError { @@ -641,9 +730,10 @@ RETURNING } } - /// Dequeue a `Job` from this `PgQueue` and hold the transaction. - /// Any other `dequeue_tx` calls will skip rows locked, so by holding a transaction we ensure only one worker can dequeue a job. - /// Holding a transaction open can have performance implications, but it means no `'running'` state is required. + /// Dequeue up to `limit` `Job`s from this `PgQueue` and hold the transaction. Any other + /// `dequeue_tx` calls will skip rows locked, so by holding a transaction we ensure only one + /// worker can dequeue a job. Holding a transaction open can have performance implications, but + /// it means no `'running'` state is required. pub async fn dequeue_tx< 'a, J: for<'d> serde::Deserialize<'d> + std::marker::Send + std::marker::Unpin + 'static, @@ -651,7 +741,8 @@ RETURNING >( &self, attempted_by: &str, - ) -> PgQueueResult>> { + limit: u32, + ) -> PgQueueResult>> { let mut tx = self .pool .begin() @@ -673,7 +764,7 @@ WITH available_in_queue AS ( ORDER BY attempt, scheduled_at - LIMIT 1 + LIMIT $2 FOR UPDATE SKIP LOCKED ) UPDATE @@ -682,7 +773,7 @@ SET attempted_at = NOW(), status = 'running'::job_status, attempt = attempt + 1, - attempted_by = array_append(attempted_by, $2::text) + attempted_by = array_append(attempted_by, $3::text) FROM available_in_queue WHERE @@ -691,20 +782,38 @@ RETURNING job_queue.* "#; - let query_result: Result, sqlx::Error> = sqlx::query_as(base_query) + let query_result: Result>, sqlx::Error> = sqlx::query_as(base_query) .bind(&self.name) + .bind(limit as i64) .bind(attempted_by) - .fetch_one(&mut *tx) + .fetch_all(&mut *tx) .await; match query_result { - Ok(job) => Ok(Some(PgTransactionJob { - job, - transaction: tx, - })), + Ok(jobs) => { + if jobs.is_empty() { + return Ok(None); + } + + let shared_txn = Arc::new(Mutex::new(SharedTxn { raw_txn: Some(tx) })); + + let pg_jobs: Vec> = jobs + .into_iter() + .map(|job| PgTransactionJob { + job, + shared_txn: shared_txn.clone(), + }) + .collect(); + + Ok(Some(PgTransactionBatch { + jobs: pg_jobs, + shared_txn: shared_txn.clone(), + })) + } - // Transaction is rolledback on drop. + // Transaction is rolled back on drop. Err(sqlx::Error::RowNotFound) => Ok(None), + Err(e) => Err(PgQueueError::QueryError { command: "UPDATE".to_owned(), error: e, @@ -751,7 +860,7 @@ mod tests { use super::*; use crate::retry::RetryPolicy; - #[derive(serde::Serialize, serde::Deserialize, PartialEq, Debug)] + #[derive(serde::Serialize, serde::Deserialize, PartialEq, Debug, Clone)] struct JobMetadata { team_id: u32, plugin_config_id: i32, @@ -768,7 +877,7 @@ mod tests { } } - #[derive(serde::Serialize, serde::Deserialize, PartialEq, Debug)] + #[derive(serde::Serialize, serde::Deserialize, PartialEq, Debug, Clone)] struct JobParameters { method: String, body: String, @@ -810,10 +919,13 @@ mod tests { queue.enqueue(new_job).await.expect("failed to enqueue job"); let pg_job: PgJob = queue - .dequeue(&worker_id) + .dequeue(&worker_id, 1) .await - .expect("failed to dequeue job") - .expect("didn't find a job to dequeue"); + .expect("failed to dequeue jobs") + .expect("didn't find any jobs to dequeue") + .jobs + .pop() + .unwrap(); assert_eq!(pg_job.job.attempt, 1); assert!(pg_job.job.attempted_by.contains(&worker_id)); @@ -831,12 +943,62 @@ mod tests { .await .expect("failed to connect to local test postgresql database"); - let pg_job: Option> = queue - .dequeue(&worker_id) + let pg_jobs: Option> = queue + .dequeue(&worker_id, 1) .await - .expect("failed to dequeue job"); + .expect("failed to dequeue jobs"); + + assert!(pg_jobs.is_none()); + } + + #[sqlx::test(migrations = "../migrations")] + async fn test_can_dequeue_multiple_jobs(db: PgPool) { + let job_target = job_target(); + let job_metadata = JobMetadata::default(); + let job_parameters = JobParameters::default(); + let worker_id = worker_id(); + + let queue = PgQueue::new_from_pool("test_can_dequeue_multiple_jobs", db) + .await + .expect("failed to connect to local test postgresql database"); + + for _ in 0..5 { + queue + .enqueue(NewJob::new( + 1, + job_metadata.clone(), + job_parameters.clone(), + &job_target, + )) + .await + .expect("failed to enqueue job"); + } + + // Only get 4 jobs, leaving one in the queue. + let limit = 4; + let batch: PgBatch = queue + .dequeue(&worker_id, limit) + .await + .expect("failed to dequeue job") + .expect("didn't find a job to dequeue"); + + // Complete those 4. + assert_eq!(batch.jobs.len(), limit as usize); + for job in batch.jobs { + job.complete().await.expect("failed to complete job"); + } + + // Try to get up to 4 jobs, but only 1 remains. + let batch: PgBatch = queue + .dequeue(&worker_id, limit) + .await + .expect("failed to dequeue job") + .expect("didn't find a job to dequeue"); - assert!(pg_job.is_none()); + assert_eq!(batch.jobs.len(), 1); // Only one job should have been left in the queue. + for job in batch.jobs { + job.complete().await.expect("failed to complete job"); + } } #[sqlx::test(migrations = "../migrations")] @@ -845,19 +1007,21 @@ mod tests { let job_metadata = JobMetadata::default(); let job_parameters = JobParameters::default(); let worker_id = worker_id(); - let new_job = NewJob::new(1, job_metadata, job_parameters, &job_target); let queue = PgQueue::new_from_pool("test_can_dequeue_tx_job", db) .await .expect("failed to connect to local test postgresql database"); + let new_job = NewJob::new(1, job_metadata, job_parameters, &job_target); queue.enqueue(new_job).await.expect("failed to enqueue job"); - let tx_job: PgTransactionJob<'_, JobParameters, JobMetadata> = queue - .dequeue_tx(&worker_id) + let mut batch: PgTransactionBatch<'_, JobParameters, JobMetadata> = queue + .dequeue_tx(&worker_id, 1) .await - .expect("failed to dequeue job") - .expect("didn't find a job to dequeue"); + .expect("failed to dequeue jobs") + .expect("didn't find any jobs to dequeue"); + + let tx_job = batch.jobs.pop().unwrap(); assert_eq!(tx_job.job.attempt, 1); assert!(tx_job.job.attempted_by.contains(&worker_id)); @@ -867,6 +1031,65 @@ mod tests { assert_eq!(*tx_job.job.parameters.as_ref(), JobParameters::default()); assert_eq!(tx_job.job.status, JobStatus::Running); assert_eq!(tx_job.job.target, job_target); + + // Transactional jobs must be completed, failed or retried before being dropped. This is + // to prevent logic bugs when using the shared txn. + tx_job.complete().await.expect("failed to complete job"); + + batch.commit().await.expect("failed to commit transaction"); + } + + #[sqlx::test(migrations = "../migrations")] + async fn test_can_dequeue_multiple_tx_jobs(db: PgPool) { + let job_target = job_target(); + let job_metadata = JobMetadata::default(); + let job_parameters = JobParameters::default(); + let worker_id = worker_id(); + + let queue = PgQueue::new_from_pool("test_can_dequeue_multiple_tx_jobs", db) + .await + .expect("failed to connect to local test postgresql database"); + + for _ in 0..5 { + queue + .enqueue(NewJob::new( + 1, + job_metadata.clone(), + job_parameters.clone(), + &job_target, + )) + .await + .expect("failed to enqueue job"); + } + + // Only get 4 jobs, leaving one in the queue. + let limit = 4; + let mut batch: PgTransactionBatch<'_, JobParameters, JobMetadata> = queue + .dequeue_tx(&worker_id, limit) + .await + .expect("failed to dequeue job") + .expect("didn't find a job to dequeue"); + + assert_eq!(batch.jobs.len(), limit as usize); + + // Complete those 4 and commit. + for job in std::mem::take(&mut batch.jobs) { + job.complete().await.expect("failed to complete job"); + } + batch.commit().await.expect("failed to commit transaction"); + + // Try to get up to 4 jobs, but only 1 remains. + let mut batch: PgTransactionBatch<'_, JobParameters, JobMetadata> = queue + .dequeue_tx(&worker_id, limit) + .await + .expect("failed to dequeue job") + .expect("didn't find a job to dequeue"); + assert_eq!(batch.jobs.len(), 1); // Only one job should have been left in the queue. + + for job in std::mem::take(&mut batch.jobs) { + job.complete().await.expect("failed to complete job"); + } + batch.commit().await.expect("failed to commit transaction"); } #[sqlx::test(migrations = "../migrations")] @@ -876,12 +1099,12 @@ mod tests { .await .expect("failed to connect to local test postgresql database"); - let tx_job: Option> = queue - .dequeue_tx(&worker_id) + let batch: Option> = queue + .dequeue_tx(&worker_id, 1) .await .expect("failed to dequeue job"); - assert!(tx_job.is_none()); + assert!(batch.is_none()); } #[sqlx::test(migrations = "../migrations")] @@ -903,10 +1126,13 @@ mod tests { queue.enqueue(new_job).await.expect("failed to enqueue job"); let job: PgJob = queue - .dequeue(&worker_id) + .dequeue(&worker_id, 1) .await .expect("failed to dequeue job") - .expect("didn't find a job to dequeue"); + .expect("didn't find a job to dequeue") + .jobs + .pop() + .unwrap(); let retry_interval = retry_policy.retry_interval(job.job.attempt as u32, None); let retry_queue = retry_policy.retry_queue(&job.job.queue).to_owned(); @@ -920,10 +1146,13 @@ mod tests { .expect("failed to retry job"); let retried_job: PgJob = queue - .dequeue(&worker_id) + .dequeue(&worker_id, 1) .await .expect("failed to dequeue job") - .expect("didn't find retried job to dequeue"); + .expect("didn't find retried job to dequeue") + .jobs + .pop() + .unwrap(); assert_eq!(retried_job.job.attempt, 2); assert!(retried_job.job.attempted_by.contains(&worker_id)); @@ -957,10 +1186,13 @@ mod tests { queue.enqueue(new_job).await.expect("failed to enqueue job"); let job: PgJob = queue - .dequeue(&worker_id) + .dequeue(&worker_id, 1) .await .expect("failed to dequeue job") - .expect("didn't find a job to dequeue"); + .expect("didn't find a job to dequeue") + .jobs + .pop() + .unwrap(); let retry_interval = retry_policy.retry_interval(job.job.attempt as u32, None); let retry_queue = retry_policy.retry_queue(&job.job.queue).to_owned(); @@ -973,8 +1205,8 @@ mod tests { .await .expect("failed to retry job"); - let retried_job_not_found: Option> = queue - .dequeue(&worker_id) + let retried_job_not_found: Option> = queue + .dequeue(&worker_id, 1) .await .expect("failed to dequeue job"); @@ -985,10 +1217,13 @@ mod tests { .expect("failed to connect to retry queue in local test postgresql database"); let retried_job: PgJob = queue - .dequeue(&worker_id) + .dequeue(&worker_id, 1) .await .expect("failed to dequeue job") - .expect("job not found in retry queue"); + .expect("job not found in retry queue") + .jobs + .pop() + .unwrap(); assert_eq!(retried_job.job.attempt, 2); assert!(retried_job.job.attempted_by.contains(&worker_id)); @@ -1019,10 +1254,13 @@ mod tests { queue.enqueue(new_job).await.expect("failed to enqueue job"); let job: PgJob = queue - .dequeue(&worker_id) + .dequeue(&worker_id, 1) .await .expect("failed to dequeue job") - .expect("didn't find a job to dequeue"); + .expect("didn't find a job to dequeue") + .jobs + .pop() + .unwrap(); let retry_interval = retry_policy.retry_interval(job.job.attempt as u32, None); diff --git a/hook-janitor/src/webhooks.rs b/hook-janitor/src/webhooks.rs index ee8ff43..b8283e0 100644 --- a/hook-janitor/src/webhooks.rs +++ b/hook-janitor/src/webhooks.rs @@ -867,10 +867,13 @@ mod tests { { // The fixtures include an available job, so let's complete it while the txn is open. let webhook_job: PgJob = queue - .dequeue(&"worker_id") + .dequeue(&"worker_id", 1) .await .expect("failed to dequeue job") - .expect("didn't find a job to dequeue"); + .expect("didn't find a job to dequeue") + .jobs + .pop() + .unwrap(); webhook_job .complete() .await @@ -893,10 +896,13 @@ mod tests { let new_job = NewJob::new(1, job_metadata, job_parameters, &"target"); queue.enqueue(new_job).await.expect("failed to enqueue job"); let webhook_job: PgJob = queue - .dequeue(&"worker_id") + .dequeue(&"worker_id", 1) .await .expect("failed to dequeue job") - .expect("didn't find a job to dequeue"); + .expect("didn't find a job to dequeue") + .jobs + .pop() + .unwrap(); webhook_job .complete() .await diff --git a/hook-worker/src/config.rs b/hook-worker/src/config.rs index 74342f7..a0ea213 100644 --- a/hook-worker/src/config.rs +++ b/hook-worker/src/config.rs @@ -34,6 +34,9 @@ pub struct Config { #[envconfig(default = "true")] pub transactional: bool, + + #[envconfig(default = "10")] + pub dequeue_count: u32, } impl Config { diff --git a/hook-worker/src/main.rs b/hook-worker/src/main.rs index 345fa3d..775f815 100644 --- a/hook-worker/src/main.rs +++ b/hook-worker/src/main.rs @@ -42,6 +42,7 @@ async fn main() -> Result<(), WorkerError> { let worker = WebhookWorker::new( &config.worker_name, &queue, + config.dequeue_count, config.poll_interval.0, config.request_timeout.0, config.max_concurrent_jobs, diff --git a/hook-worker/src/worker.rs b/hook-worker/src/worker.rs index c526c3f..835bb9f 100644 --- a/hook-worker/src/worker.rs +++ b/hook-worker/src/worker.rs @@ -2,7 +2,9 @@ use std::collections; use std::sync::Arc; use std::time; +use futures::future::join_all; use hook_common::health::HealthHandle; +use hook_common::pgqueue::{PgBatch, PgTransactionBatch}; use hook_common::{ pgqueue::{Job, PgJob, PgJobError, PgQueue, PgQueueError, PgQueueJob, PgTransactionJob}, retry::RetryPolicy, @@ -68,6 +70,8 @@ pub struct WebhookWorker<'p> { name: String, /// The queue we will be dequeuing jobs from. queue: &'p PgQueue, + /// The maximum number of jobs to dequeue in one query. + dequeue_count: u32, /// The interval for polling the queue. poll_interval: time::Duration, /// The client used for HTTP requests. @@ -81,9 +85,11 @@ pub struct WebhookWorker<'p> { } impl<'p> WebhookWorker<'p> { + #[allow(clippy::too_many_arguments)] pub fn new( name: &str, queue: &'p PgQueue, + dequeue_count: u32, poll_interval: time::Duration, request_timeout: time::Duration, max_concurrent_jobs: usize, @@ -106,6 +112,7 @@ impl<'p> WebhookWorker<'p> { Self { name: name.to_owned(), queue, + dequeue_count, poll_interval, client, max_concurrent_jobs, @@ -114,16 +121,16 @@ impl<'p> WebhookWorker<'p> { } } - /// Wait until a job becomes available in our queue. - async fn wait_for_job<'a>(&self) -> PgJob { + /// Wait until at least one job becomes available in our queue. + async fn wait_for_jobs<'a>(&self) -> PgBatch { let mut interval = tokio::time::interval(self.poll_interval); loop { interval.tick().await; self.liveness.report_healthy().await; - match self.queue.dequeue(&self.name).await { - Ok(Some(job)) => return job, + match self.queue.dequeue(&self.name, self.dequeue_count).await { + Ok(Some(batch)) => return batch, Ok(None) => continue, Err(error) => { error!("error while trying to dequeue job: {}", error); @@ -133,18 +140,18 @@ impl<'p> WebhookWorker<'p> { } } - /// Wait until a job becomes available in our queue in transactional mode. - async fn wait_for_job_tx<'a>( + /// Wait until at least one job becomes available in our queue in transactional mode. + async fn wait_for_jobs_tx<'a>( &self, - ) -> PgTransactionJob<'a, WebhookJobParameters, WebhookJobMetadata> { + ) -> PgTransactionBatch<'a, WebhookJobParameters, WebhookJobMetadata> { let mut interval = tokio::time::interval(self.poll_interval); loop { interval.tick().await; self.liveness.report_healthy().await; - match self.queue.dequeue_tx(&self.name).await { - Ok(Some(job)) => return job, + match self.queue.dequeue_tx(&self.name, self.dequeue_count).await { + Ok(Some(batch)) => return batch, Ok(None) => continue, Err(error) => { error!("error while trying to dequeue_tx job: {}", error); @@ -165,65 +172,87 @@ impl<'p> WebhookWorker<'p> { if transactional { loop { report_semaphore_utilization(); - let webhook_job = self.wait_for_job_tx().await; - spawn_webhook_job_processing_task( - self.client.clone(), - semaphore.clone(), - self.retry_policy.clone(), - webhook_job, - ) - .await; + let mut batch = self.wait_for_jobs_tx().await; + + // Get enough permits for the jobs before spawning a task. + let permits = semaphore + .clone() + .acquire_many_owned(batch.jobs.len() as u32) + .await + .expect("semaphore has been closed"); + + let client = self.client.clone(); + let retry_policy = self.retry_policy.clone(); + + tokio::spawn(async move { + let mut futures = Vec::new(); + + // We have to `take` the Vec of jobs from the batch to avoid a borrow checker + // error below when we commit. + for job in std::mem::take(&mut batch.jobs) { + let client = client.clone(); + let retry_policy = retry_policy.clone(); + + let future = + async move { process_webhook_job(client, job, &retry_policy).await }; + + futures.push(future); + } + + let results = join_all(futures).await; + for result in results { + if let Err(e) = result { + error!("error processing webhook job: {}", e); + } + } + + let _ = batch.commit().await.map_err(|e| { + error!("error committing transactional batch: {}", e); + }); + + drop(permits); + }); } } else { loop { report_semaphore_utilization(); - let webhook_job = self.wait_for_job().await; - spawn_webhook_job_processing_task( - self.client.clone(), - semaphore.clone(), - self.retry_policy.clone(), - webhook_job, - ) - .await; - } - } - } -} + let batch = self.wait_for_jobs().await; -/// Spawn a Tokio task to process a Webhook Job once we successfully acquire a permit. -/// -/// # Arguments -/// -/// * `client`: An HTTP client to execute the webhook job request. -/// * `semaphore`: A semaphore used for rate limiting purposes. This function will panic if this semaphore is closed. -/// * `retry_policy`: The retry policy used to set retry parameters if a job fails and has remaining attempts. -/// * `webhook_job`: The webhook job to process as dequeued from `hook_common::pgqueue::PgQueue`. -async fn spawn_webhook_job_processing_task( - client: reqwest::Client, - semaphore: Arc, - retry_policy: RetryPolicy, - webhook_job: W, -) -> tokio::task::JoinHandle> { - let permit = semaphore - .acquire_owned() - .await - .expect("semaphore has been closed"); + // Get enough permits for the jobs before spawning a task. + let permits = semaphore + .clone() + .acquire_many_owned(batch.jobs.len() as u32) + .await + .expect("semaphore has been closed"); - let labels = [("queue", webhook_job.queue())]; + let client = self.client.clone(); + let retry_policy = self.retry_policy.clone(); - metrics::counter!("webhook_jobs_total", &labels).increment(1); + tokio::spawn(async move { + let mut futures = Vec::new(); + + for job in batch.jobs { + let client = client.clone(); + let retry_policy = retry_policy.clone(); + + let future = + async move { process_webhook_job(client, job, &retry_policy).await }; - tokio::spawn(async move { - let result = process_webhook_job(client, webhook_job, &retry_policy).await; - drop(permit); - match result { - Ok(_) => Ok(()), - Err(error) => { - error!("failed to process webhook job: {}", error); - Err(error) + futures.push(future); + } + + let results = join_all(futures).await; + for result in results { + if let Err(e) = result { + error!("error processing webhook job: {}", e); + } + } + + drop(permits); + }); } } - }) + } } /// Process a webhook job by transitioning it to its appropriate state after its request is sent. @@ -248,6 +277,7 @@ async fn process_webhook_job( let parameters = webhook_job.parameters(); let labels = [("queue", webhook_job.queue())]; + metrics::counter!("webhook_jobs_total", &labels).increment(1); let now = tokio::time::Instant::now(); @@ -543,6 +573,7 @@ mod tests { let worker = WebhookWorker::new( &worker_id, &queue, + 1, time::Duration::from_millis(100), time::Duration::from_millis(5000), 10, @@ -550,7 +581,7 @@ mod tests { liveness, ); - let consumed_job = worker.wait_for_job().await; + let consumed_job = worker.wait_for_jobs().await.jobs.pop().unwrap(); assert_eq!(consumed_job.job.attempt, 1); assert!(consumed_job.job.attempted_by.contains(&worker_id));