Skip to content

Commit

Permalink
implement custom Drop for AsyncPoolClient so we can drop it in sync c…
Browse files Browse the repository at this point in the history
…ontext
  • Loading branch information
syphar committed Apr 4, 2024
1 parent 7485f33 commit 42f2c84
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 8 deletions.
2 changes: 1 addition & 1 deletion src/bin/cratesfyi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -882,7 +882,7 @@ impl Context for BinContext {
.get_or_try_init::<_, Error>(|| {
Ok(Pool::new(
&*self.config()?,
&*self.runtime()?,
self.runtime()?,
self.instance_metrics()?,
)?)
})?
Expand Down
50 changes: 44 additions & 6 deletions src/db/pool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,15 @@ use futures_util::{future::BoxFuture, stream::BoxStream};
use postgres::{Client, NoTls};
use r2d2_postgres::PostgresConnectionManager;
use sqlx::{postgres::PgPoolOptions, Executor};
use std::{sync::Arc, time::Duration};
use std::{
ops::{Deref, DerefMut},
sync::Arc,
time::Duration,
};
use tokio::runtime::Runtime;
use tracing::debug;

pub type PoolClient = r2d2::PooledConnection<PostgresConnectionManager<NoTls>>;
pub type AsyncPoolClient = sqlx::pool::PoolConnection<sqlx::postgres::Postgres>;

const DEFAULT_SCHEMA: &str = "public";

Expand All @@ -20,14 +23,15 @@ pub struct Pool {
#[cfg(not(test))]
pool: r2d2::Pool<PostgresConnectionManager<NoTls>>,
async_pool: sqlx::PgPool,
runtime: Arc<Runtime>,
metrics: Arc<InstanceMetrics>,
max_size: u32,
}

impl Pool {
pub fn new(
config: &Config,
runtime: &Runtime,
runtime: Arc<Runtime>,
metrics: Arc<InstanceMetrics>,
) -> Result<Pool, PoolError> {
debug!(
Expand All @@ -39,7 +43,7 @@ impl Pool {
#[cfg(test)]
pub(crate) fn new_with_schema(
config: &Config,
runtime: &Runtime,
runtime: Arc<Runtime>,
metrics: Arc<InstanceMetrics>,
schema: &str,
) -> Result<Pool, PoolError> {
Expand All @@ -48,7 +52,7 @@ impl Pool {

fn new_inner(
config: &Config,
runtime: &Runtime,
runtime: Arc<Runtime>,
metrics: Arc<InstanceMetrics>,
schema: &str,
) -> Result<Pool, PoolError> {
Expand Down Expand Up @@ -109,6 +113,7 @@ impl Pool {
pool,
async_pool,
metrics,
runtime,
max_size: config.max_legacy_pool_size + config.max_pool_size,
})
}
Expand Down Expand Up @@ -139,7 +144,10 @@ impl Pool {

pub async fn get_async(&self) -> Result<AsyncPoolClient, PoolError> {
match self.async_pool.acquire().await {
Ok(conn) => Ok(conn),
Ok(conn) => Ok(AsyncPoolClient {
inner: Some(conn),
runtime: self.runtime.clone(),
}),
Err(err) => {
self.metrics.failed_db_connections.inc();
Err(PoolError::AsyncClientError(err))
Expand Down Expand Up @@ -222,6 +230,36 @@ where
}
}

/// we wrap `sqlx::PoolConnection` so we can drop it in a sync context
/// and enter the runtime.
/// Otherwise dropping the PoolConnection will panic because it can't spawn a task.
#[derive(Debug)]
pub struct AsyncPoolClient {
inner: Option<sqlx::pool::PoolConnection<sqlx::postgres::Postgres>>,
runtime: Arc<Runtime>,
}

impl Deref for AsyncPoolClient {
type Target = sqlx::PgConnection;

fn deref(&self) -> &Self::Target {
self.inner.as_ref().unwrap()
}
}

impl DerefMut for AsyncPoolClient {
fn deref_mut(&mut self) -> &mut Self::Target {
self.inner.as_mut().unwrap()
}
}

impl Drop for AsyncPoolClient {
fn drop(&mut self) {
let _guard = self.runtime.enter();
drop(self.inner.take())
}
}

#[derive(Debug)]
struct SetSchema {
schema: String,
Expand Down
2 changes: 1 addition & 1 deletion src/test/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -590,7 +590,7 @@ impl TestDatabase {
// test to create a fresh instance of the database to run within.
let schema = format!("docs_rs_test_schema_{}", rand::random::<u64>());

let pool = Pool::new_with_schema(config, &runtime, metrics, &schema)?;
let pool = Pool::new_with_schema(config, runtime.clone(), metrics, &schema)?;

runtime.block_on({
let schema = schema.clone();
Expand Down

0 comments on commit 42f2c84

Please sign in to comment.