Skip to content
This repository has been archived by the owner on Feb 8, 2024. It is now read-only.

Commit

Permalink
drop SharedTxn struct, rename dequeue_count, add histogram
Browse files Browse the repository at this point in the history
  • Loading branch information
bretthoerner committed Feb 2, 2024
1 parent daa0e07 commit 8cb66e0
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 75 deletions.
94 changes: 26 additions & 68 deletions hook-common/src/pgqueue.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
//! # PgQueue
//!
//! A job queue implementation backed by a PostgreSQL table.
use std::ops::DerefMut;
use std::time;
use std::{str::FromStr, sync::Arc};

Expand Down Expand Up @@ -320,7 +319,7 @@ pub struct PgTransactionJob<'c, J, M> {
/// The open transaction this job came from. If multiple jobs were queried at once, then this
/// transaction will be shared between them (across async tasks and threads as necessary). See
/// below for more information.
shared_txn: Arc<Mutex<SharedTxn<'c>>>,
shared_txn: Arc<Mutex<Option<sqlx::Transaction<'c, sqlx::postgres::Postgres>>>>,
}

// Container struct for a batch of PgTransactionJob. Includes a reference to the shared transaction
Expand All @@ -330,58 +329,21 @@ pub struct PgTransactionBatch<'c, J, M> {

/// The open transaction the jobs in the Vec came from. This should be used to commit or
/// rollback when all of the work is finished.
shared_txn: Arc<Mutex<SharedTxn<'c>>>,
shared_txn: Arc<Mutex<Option<sqlx::Transaction<'c, sqlx::postgres::Postgres>>>>,
}

impl<'c, J, M> PgTransactionBatch<'_, J, M> {
pub async fn commit(&mut self) -> PgQueueResult<()> {
let mut txn_guard = self.shared_txn.lock().await;
txn_guard.commit().await
}
}

/// A shared transaction. Exists as a named type so that it can have some helper methods.
#[derive(Debug)]
struct SharedTxn<'c> {
/// The actual transaction object. If a transaction error occurs (e.g. the connection drops)
/// then the transaction should be dropped so that other jobs don't each try to do DB work that
/// is bound to fail.
raw_txn: Option<sqlx::Transaction<'c, sqlx::postgres::Postgres>>,
}

impl SharedTxn<'_> {
/// Commits and then drops the transaction.
pub async fn commit(&mut self) -> PgQueueResult<()> {
if self.raw_txn.is_none() {
return Err(PgQueueError::TransactionAlreadyClosedError);
}

let txn = self.raw_txn.take().unwrap();
let txn = txn_guard.take().unwrap();
txn.commit().await.map_err(|e| PgQueueError::QueryError {
command: "COMMIT".to_owned(),
error: e,
})?;

Ok(())
}

/// See `raw_txn` above, this should be called when a transaction error occurs (e.g. the
/// connection drops) so that other jobs don't each try to do DB work that is bound to fail.
fn drop_txn(&mut self) {
self.raw_txn = None;
}
}

impl SharedTxn<'_> {
/// Helper to get the transaction reference if it exists, for use in sqlx queries.
/// If it doesn't exist, that means a previous error occurred and the transaction has been
/// dropped.
fn get_txn_ref(&mut self) -> Option<&mut sqlx::PgConnection> {
match self.raw_txn.as_mut() {
Some(txn) => Some(txn.deref_mut()),
None => None,
}
}
}

#[async_trait]
Expand All @@ -392,17 +354,17 @@ impl<'c, J: std::marker::Send, M: std::marker::Send> PgQueueJob for PgTransactio
let mut txn_guard = self.shared_txn.lock().await;

let txn_ref = txn_guard
.get_txn_ref()
.as_deref_mut()
.ok_or(PgJobError::TransactionAlreadyClosedError)?;

let completed_job = self.job.complete(txn_ref).await.map_err(|error| {
txn_guard.drop_txn();

PgJobError::QueryError {
command: "UPDATE".to_owned(),
error,
}
})?;
let completed_job =
self.job
.complete(txn_ref)
.await
.map_err(|error| PgJobError::QueryError {
command: "UPDATE".to_owned(),
error,
})?;

Ok(completed_job)
}
Expand All @@ -414,17 +376,17 @@ impl<'c, J: std::marker::Send, M: std::marker::Send> PgQueueJob for PgTransactio
let mut txn_guard = self.shared_txn.lock().await;

let txn_ref = txn_guard
.get_txn_ref()
.as_deref_mut()
.ok_or(PgJobError::TransactionAlreadyClosedError)?;

let failed_job = self.job.fail(error, txn_ref).await.map_err(|error| {
txn_guard.drop_txn();

PgJobError::QueryError {
command: "UPDATE".to_owned(),
error,
}
})?;
let failed_job =
self.job
.fail(error, txn_ref)
.await
.map_err(|error| PgJobError::QueryError {
command: "UPDATE".to_owned(),
error,
})?;

Ok(failed_job)
}
Expand All @@ -447,7 +409,7 @@ impl<'c, J: std::marker::Send, M: std::marker::Send> PgQueueJob for PgTransactio
let mut txn_guard = self.shared_txn.lock().await;

let txn_ref = txn_guard
.get_txn_ref()
.as_deref_mut()
.ok_or(PgJobError::TransactionAlreadyClosedError)?;

let retried_job = self
Expand All @@ -456,13 +418,9 @@ impl<'c, J: std::marker::Send, M: std::marker::Send> PgQueueJob for PgTransactio
.queue(queue)
.retry(error, retry_interval, txn_ref)
.await
.map_err(|error| {
txn_guard.drop_txn();

PgJobError::QueryError {
command: "UPDATE".to_owned(),
error,
}
.map_err(|error| PgJobError::QueryError {
command: "UPDATE".to_owned(),
error,
})?;

Ok(retried_job)
Expand Down Expand Up @@ -795,7 +753,7 @@ RETURNING
return Ok(None);
}

let shared_txn = Arc::new(Mutex::new(SharedTxn { raw_txn: Some(tx) }));
let shared_txn = Arc::new(Mutex::new(Some(tx)));

let pg_jobs: Vec<PgTransactionJob<J, M>> = jobs
.into_iter()
Expand Down
2 changes: 1 addition & 1 deletion hook-worker/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ pub struct Config {
pub transactional: bool,

#[envconfig(default = "10")]
pub dequeue_count: u32,
pub dequeue_batch_size: u32,
}

impl Config {
Expand Down
2 changes: 1 addition & 1 deletion hook-worker/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ async fn main() -> Result<(), WorkerError> {
let worker = WebhookWorker::new(
&config.worker_name,
&queue,
config.dequeue_count,
config.dequeue_batch_size,
config.poll_interval.0,
config.request_timeout.0,
config.max_concurrent_jobs,
Expand Down
22 changes: 17 additions & 5 deletions hook-worker/src/worker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ pub struct WebhookWorker<'p> {
/// The queue we will be dequeuing jobs from.
queue: &'p PgQueue,
/// The maximum number of jobs to dequeue in one query.
dequeue_count: u32,
dequeue_batch_size: u32,
/// The interval for polling the queue.
poll_interval: time::Duration,
/// The client used for HTTP requests.
Expand All @@ -89,7 +89,7 @@ impl<'p> WebhookWorker<'p> {
pub fn new(
name: &str,
queue: &'p PgQueue,
dequeue_count: u32,
dequeue_batch_size: u32,
poll_interval: time::Duration,
request_timeout: time::Duration,
max_concurrent_jobs: usize,
Expand All @@ -112,7 +112,7 @@ impl<'p> WebhookWorker<'p> {
Self {
name: name.to_owned(),
queue,
dequeue_count,
dequeue_batch_size,
poll_interval,
client,
max_concurrent_jobs,
Expand All @@ -129,7 +129,11 @@ impl<'p> WebhookWorker<'p> {
interval.tick().await;
self.liveness.report_healthy().await;

match self.queue.dequeue(&self.name, self.dequeue_count).await {
match self
.queue
.dequeue(&self.name, self.dequeue_batch_size)
.await
{
Ok(Some(batch)) => return batch,
Ok(None) => continue,
Err(error) => {
Expand All @@ -150,7 +154,11 @@ impl<'p> WebhookWorker<'p> {
interval.tick().await;
self.liveness.report_healthy().await;

match self.queue.dequeue_tx(&self.name, self.dequeue_count).await {
match self
.queue
.dequeue_tx(&self.name, self.dequeue_batch_size)
.await
{
Ok(Some(batch)) => return batch,
Ok(None) => continue,
Err(error) => {
Expand All @@ -169,10 +177,13 @@ impl<'p> WebhookWorker<'p> {
.set(1f64 - semaphore.available_permits() as f64 / self.max_concurrent_jobs as f64);
};

let dequeue_batch_size_histogram = metrics::histogram!("webhook_dequeue_batch_size");

if transactional {
loop {
report_semaphore_utilization();
let mut batch = self.wait_for_jobs_tx().await;
dequeue_batch_size_histogram.record(batch.jobs.len() as f64);

// Get enough permits for the jobs before spawning a task.
let permits = semaphore
Expand Down Expand Up @@ -217,6 +228,7 @@ impl<'p> WebhookWorker<'p> {
loop {
report_semaphore_utilization();
let batch = self.wait_for_jobs().await;
dequeue_batch_size_histogram.record(batch.jobs.len() as f64);

// Get enough permits for the jobs before spawning a task.
let permits = semaphore
Expand Down

0 comments on commit 8cb66e0

Please sign in to comment.