diff --git a/src/bin/cratesfyi.rs b/src/bin/cratesfyi.rs index a2e6920cf..0d8b46ca6 100644 --- a/src/bin/cratesfyi.rs +++ b/src/bin/cratesfyi.rs @@ -882,7 +882,7 @@ impl Context for BinContext { .get_or_try_init::<_, Error>(|| { Ok(Pool::new( &*self.config()?, - &*self.runtime()?, + self.runtime()?, self.instance_metrics()?, )?) })? diff --git a/src/db/pool.rs b/src/db/pool.rs index 20f6850fd..2a1a43f54 100644 --- a/src/db/pool.rs +++ b/src/db/pool.rs @@ -4,12 +4,15 @@ use futures_util::{future::BoxFuture, stream::BoxStream}; use postgres::{Client, NoTls}; use r2d2_postgres::PostgresConnectionManager; use sqlx::{postgres::PgPoolOptions, Executor}; -use std::{sync::Arc, time::Duration}; +use std::{ + ops::{Deref, DerefMut}, + sync::Arc, + time::Duration, +}; use tokio::runtime::Runtime; use tracing::debug; pub type PoolClient = r2d2::PooledConnection>; -pub type AsyncPoolClient = sqlx::pool::PoolConnection; const DEFAULT_SCHEMA: &str = "public"; @@ -20,6 +23,7 @@ pub struct Pool { #[cfg(not(test))] pool: r2d2::Pool>, async_pool: sqlx::PgPool, + runtime: Arc, metrics: Arc, max_size: u32, } @@ -27,7 +31,7 @@ pub struct Pool { impl Pool { pub fn new( config: &Config, - runtime: &Runtime, + runtime: Arc, metrics: Arc, ) -> Result { debug!( @@ -39,7 +43,7 @@ impl Pool { #[cfg(test)] pub(crate) fn new_with_schema( config: &Config, - runtime: &Runtime, + runtime: Arc, metrics: Arc, schema: &str, ) -> Result { @@ -48,7 +52,7 @@ impl Pool { fn new_inner( config: &Config, - runtime: &Runtime, + runtime: Arc, metrics: Arc, schema: &str, ) -> Result { @@ -109,6 +113,7 @@ impl Pool { pool, async_pool, metrics, + runtime, max_size: config.max_legacy_pool_size + config.max_pool_size, }) } @@ -139,7 +144,10 @@ impl Pool { pub async fn get_async(&self) -> Result { match self.async_pool.acquire().await { - Ok(conn) => Ok(conn), + Ok(conn) => Ok(AsyncPoolClient { + inner: Some(conn), + runtime: self.runtime.clone(), + }), Err(err) => { self.metrics.failed_db_connections.inc(); Err(PoolError::AsyncClientError(err)) @@ -222,6 +230,36 @@ where } } +/// we wrap `sqlx::PoolConnection` so we can drop it in a sync context +/// and enter the runtime. +/// Otherwise dropping the PoolConnection will panic because it can't spawn a task. +#[derive(Debug)] +pub struct AsyncPoolClient { + inner: Option>, + runtime: Arc, +} + +impl Deref for AsyncPoolClient { + type Target = sqlx::PgConnection; + + fn deref(&self) -> &Self::Target { + self.inner.as_ref().unwrap() + } +} + +impl DerefMut for AsyncPoolClient { + fn deref_mut(&mut self) -> &mut Self::Target { + self.inner.as_mut().unwrap() + } +} + +impl Drop for AsyncPoolClient { + fn drop(&mut self) { + let _guard = self.runtime.enter(); + drop(self.inner.take()) + } +} + #[derive(Debug)] struct SetSchema { schema: String, diff --git a/src/test/mod.rs b/src/test/mod.rs index ae1409d95..55447c69e 100644 --- a/src/test/mod.rs +++ b/src/test/mod.rs @@ -590,7 +590,7 @@ impl TestDatabase { // test to create a fresh instance of the database to run within. let schema = format!("docs_rs_test_schema_{}", rand::random::()); - let pool = Pool::new_with_schema(config, &runtime, metrics, &schema)?; + let pool = Pool::new_with_schema(config, runtime.clone(), metrics, &schema)?; runtime.block_on({ let schema = schema.clone();