Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix #198 #201

Merged
merged 4 commits into from
Nov 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 14 additions & 14 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ jobs:
matrix:
rust: ["stable", "beta", "nightly"]
backend: ["postgres", "mysql", "sqlite"]
os: [ubuntu-latest, macos-13, macos-14, windows-2019]
os: [ubuntu-latest, macos-13, macos-15, windows-2019]
runs-on: ${{ matrix.os }}
steps:
- name: Checkout sources
Expand Down Expand Up @@ -121,7 +121,7 @@ jobs:
echo "DATABASE_URL=postgres://postgres@localhost/" >> $GITHUB_ENV

- name: Install postgres (MacOS M1)
if: matrix.os == 'macos-14' && matrix.backend == 'postgres'
if: matrix.os == 'macos-15' && matrix.backend == 'postgres'
run: |
brew install postgresql@14
brew services start postgresql@14
Expand All @@ -138,24 +138,24 @@ jobs:
- name: Install mysql (MacOS)
if: matrix.os == 'macos-13' && matrix.backend == 'mysql'
run: |
brew install mariadb@11.2
/usr/local/opt/mariadb@11.2/bin/mysql_install_db
/usr/local/opt/mariadb@11.2/bin/mysql.server start
brew install mariadb@11.4
/usr/local/opt/mariadb@11.4/bin/mysql_install_db
/usr/local/opt/mariadb@11.4/bin/mysql.server start
sleep 3
/usr/local/opt/mariadb@11.2/bin/mysqladmin -u runner password diesel
/usr/local/opt/mariadb@11.2/bin/mysql -e "create database diesel_test; create database diesel_unit_test; grant all on \`diesel_%\`.* to 'runner'@'localhost';" -urunner
/usr/local/opt/mariadb@11.4/bin/mysqladmin -u runner password diesel
/usr/local/opt/mariadb@11.4/bin/mysql -e "create database diesel_test; create database diesel_unit_test; grant all on \`diesel_%\`.* to 'runner'@'localhost';" -urunner
echo "DATABASE_URL=mysql://runner:diesel@localhost/diesel_test" >> $GITHUB_ENV

- name: Install mysql (MacOS M1)
if: matrix.os == 'macos-14' && matrix.backend == 'mysql'
if: matrix.os == 'macos-15' && matrix.backend == 'mysql'
run: |
brew install mariadb@11.2
ls /opt/homebrew/opt/mariadb@11.2
/opt/homebrew/opt/mariadb@11.2/bin/mysql_install_db
/opt/homebrew/opt/mariadb@11.2/bin/mysql.server start
brew install mariadb@11.4
ls /opt/homebrew/opt/mariadb@11.4
/opt/homebrew/opt/mariadb@11.4/bin/mysql_install_db
/opt/homebrew/opt/mariadb@11.4/bin/mysql.server start
sleep 3
/opt/homebrew/opt/mariadb@11.2/bin/mysqladmin -u runner password diesel
/opt/homebrew/opt/mariadb@11.2/bin/mysql -e "create database diesel_test; create database diesel_unit_test; grant all on \`diesel_%\`.* to 'runner'@'localhost';" -urunner
/opt/homebrew/opt/mariadb@11.4/bin/mysqladmin -u runner password diesel
/opt/homebrew/opt/mariadb@11.4/bin/mysql -e "create database diesel_test; create database diesel_unit_test; grant all on \`diesel_%\`.* to 'runner'@'localhost';" -urunner
echo "DATABASE_URL=mysql://runner:diesel@localhost/diesel_test" >> $GITHUB_ENV

- name: Install postgres (Windows)
Expand Down
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@ for Rust libraries in [RFC #1105](https://github.com/rust-lang/rfcs/blob/master/

## [Unreleased]

## [0.5.2] - 2024-11-26

* Fixed an issue around transaction cancellation that could lead to connection pools containing connections with dangling transactions

## [0.5.1] - 2024-11-01

* Add crate feature `pool` for extending connection pool implements through external crate
Expand Down
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "diesel-async"
version = "0.5.1"
version = "0.5.2"
authors = ["Georg Semmler <[email protected]>"]
edition = "2021"
autotests = false
Expand Down
249 changes: 87 additions & 162 deletions src/transaction_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ use diesel::QueryResult;
use scoped_futures::ScopedBoxFuture;
use std::borrow::Cow;
use std::num::NonZeroU32;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;

use crate::AsyncConnection;
// TODO: refactor this to share more code with diesel
Expand Down Expand Up @@ -88,24 +90,31 @@ pub trait TransactionManager<Conn: AsyncConnection>: Send {
/// in an error state.
#[doc(hidden)]
fn is_broken_transaction_manager(conn: &mut Conn) -> bool {
match Self::transaction_manager_status_mut(conn).transaction_state() {
// all transactions are closed
// so we don't consider this connection broken
Ok(ValidTransactionManagerStatus {
in_transaction: None,
..
}) => false,
// The transaction manager is in an error state
// Therefore we consider this connection broken
Err(_) => true,
// The transaction manager contains a open transaction
// we do consider this connection broken
// if that transaction was not opened by `begin_test_transaction`
Ok(ValidTransactionManagerStatus {
in_transaction: Some(s),
..
}) => !s.test_transaction,
}
check_broken_transaction_state(conn)
}
}

fn check_broken_transaction_state<Conn>(conn: &mut Conn) -> bool
where
Conn: AsyncConnection,
{
match Conn::TransactionManager::transaction_manager_status_mut(conn).transaction_state() {
// all transactions are closed
// so we don't consider this connection broken
Ok(ValidTransactionManagerStatus {
in_transaction: None,
..
}) => false,
// The transaction manager is in an error state
// Therefore we consider this connection broken
Err(_) => true,
// The transaction manager contains a open transaction
// we do consider this connection broken
// if that transaction was not opened by `begin_test_transaction`
Ok(ValidTransactionManagerStatus {
in_transaction: Some(s),
..
}) => !s.test_transaction,
}
}

Expand All @@ -114,147 +123,23 @@ pub trait TransactionManager<Conn: AsyncConnection>: Send {
#[derive(Default, Debug)]
pub struct AnsiTransactionManager {
pub(crate) status: TransactionManagerStatus,
// this boolean flag tracks whether we are currently in the process
// of executing any transaction releated SQL (BEGIN, COMMIT, ROLLBACK)
// if we ever encounter a situation where this flag is set
// while the connection is returned to a pool
// that means the connection is broken as someone dropped the
// transaction future while these commands where executed
// and we cannot know the connection state anymore
//
// We ensure this by wrapping all calls to `.await`
// into `AnsiTransactionManager::critical_transaction_block`
// below
//
// See https://github.com/weiznich/diesel_async/issues/198 for
// details
pub(crate) is_broken: Arc<AtomicBool>,
}

// /// Status of the transaction manager
// #[derive(Debug)]
// pub enum TransactionManagerStatus {
// /// Valid status, the manager can run operations
// Valid(ValidTransactionManagerStatus),
// /// Error status, probably following a broken connection. The manager will no longer run operations
// InError,
// }

// impl Default for TransactionManagerStatus {
// fn default() -> Self {
// TransactionManagerStatus::Valid(ValidTransactionManagerStatus::default())
// }
// }

// impl TransactionManagerStatus {
// /// Returns the transaction depth if the transaction manager's status is valid, or returns
// /// [`Error::BrokenTransactionManager`] if the transaction manager is in error.
// pub fn transaction_depth(&self) -> QueryResult<Option<NonZeroU32>> {
// match self {
// TransactionManagerStatus::Valid(valid_status) => Ok(valid_status.transaction_depth()),
// TransactionManagerStatus::InError => Err(Error::BrokenTransactionManager),
// }
// }

// /// If in transaction and transaction manager is not broken, registers that the
// /// connection can not be used anymore until top-level transaction is rolled back
// pub(crate) fn set_top_level_transaction_requires_rollback(&mut self) {
// if let TransactionManagerStatus::Valid(ValidTransactionManagerStatus {
// in_transaction:
// Some(InTransactionStatus {
// top_level_transaction_requires_rollback,
// ..
// }),
// }) = self
// {
// *top_level_transaction_requires_rollback = true;
// }
// }

// /// Sets the transaction manager status to InError
// ///
// /// Subsequent attempts to use transaction-related features will result in a
// /// [`Error::BrokenTransactionManager`] error
// pub fn set_in_error(&mut self) {
// *self = TransactionManagerStatus::InError
// }

// fn transaction_state(&mut self) -> QueryResult<&mut ValidTransactionManagerStatus> {
// match self {
// TransactionManagerStatus::Valid(valid_status) => Ok(valid_status),
// TransactionManagerStatus::InError => Err(Error::BrokenTransactionManager),
// }
// }

// pub(crate) fn set_test_transaction_flag(&mut self) {
// if let TransactionManagerStatus::Valid(ValidTransactionManagerStatus {
// in_transaction: Some(s),
// }) = self
// {
// s.test_transaction = true;
// }
// }
// }

// /// Valid transaction status for the manager. Can return the current transaction depth
// #[allow(missing_copy_implementations)]
// #[derive(Debug, Default)]
// pub struct ValidTransactionManagerStatus {
// in_transaction: Option<InTransactionStatus>,
// }

// #[allow(missing_copy_implementations)]
// #[derive(Debug)]
// struct InTransactionStatus {
// transaction_depth: NonZeroU32,
// top_level_transaction_requires_rollback: bool,
// test_transaction: bool,
// }

// impl ValidTransactionManagerStatus {
// /// Return the current transaction depth
// ///
// /// This value is `None` if no current transaction is running
// /// otherwise the number of nested transactions is returned.
// pub fn transaction_depth(&self) -> Option<NonZeroU32> {
// self.in_transaction.as_ref().map(|it| it.transaction_depth)
// }

// /// Update the transaction depth by adding the value of the `transaction_depth_change` parameter if the `query` is
// /// `Ok(())`
// pub fn change_transaction_depth(
// &mut self,
// transaction_depth_change: TransactionDepthChange,
// ) -> QueryResult<()> {
// match (&mut self.in_transaction, transaction_depth_change) {
// (Some(in_transaction), TransactionDepthChange::IncreaseDepth) => {
// // Can be replaced with saturating_add directly on NonZeroU32 once
// // <https://github.com/rust-lang/rust/issues/84186> is stable
// in_transaction.transaction_depth =
// NonZeroU32::new(in_transaction.transaction_depth.get().saturating_add(1))
// .expect("nz + nz is always non-zero");
// Ok(())
// }
// (Some(in_transaction), TransactionDepthChange::DecreaseDepth) => {
// // This sets `transaction_depth` to `None` as soon as we reach zero
// match NonZeroU32::new(in_transaction.transaction_depth.get() - 1) {
// Some(depth) => in_transaction.transaction_depth = depth,
// None => self.in_transaction = None,
// }
// Ok(())
// }
// (None, TransactionDepthChange::IncreaseDepth) => {
// self.in_transaction = Some(InTransactionStatus {
// transaction_depth: NonZeroU32::new(1).expect("1 is non-zero"),
// top_level_transaction_requires_rollback: false,
// test_transaction: false,
// });
// Ok(())
// }
// (None, TransactionDepthChange::DecreaseDepth) => {
// // We screwed up something somewhere
// // we cannot decrease the transaction count if
// // we are not inside a transaction
// Err(Error::NotInTransaction)
// }
// }
// }
// }

// /// Represents a change to apply to the depth of a transaction
// #[derive(Debug, Clone, Copy)]
// pub enum TransactionDepthChange {
// /// Increase the depth of the transaction (corresponds to `BEGIN` or `SAVEPOINT`)
// IncreaseDepth,
// /// Decreases the depth of the transaction (corresponds to `COMMIT`/`RELEASE SAVEPOINT` or `ROLLBACK`)
// DecreaseDepth,
// }

impl AnsiTransactionManager {
fn get_transaction_state<Conn>(
conn: &mut Conn,
Expand All @@ -274,17 +159,38 @@ impl AnsiTransactionManager {
where
Conn: AsyncConnection<TransactionManager = Self>,
{
let is_broken = conn.transaction_state().is_broken.clone();
let state = Self::get_transaction_state(conn)?;
match state.transaction_depth() {
None => {
conn.batch_execute(sql).await?;
Self::critical_transaction_block(&is_broken, conn.batch_execute(sql)).await?;
Self::get_transaction_state(conn)?
.change_transaction_depth(TransactionDepthChange::IncreaseDepth)?;
Ok(())
}
Some(_depth) => Err(Error::AlreadyInTransaction),
}
}

// This function should be used to await any connection
// related future in our transaction manager implementation
//
// It takes care of tracking entering and exiting executing the future
// which in turn is used to determine if it's safe to still use
// the connection in the event of a canceled transaction execution
async fn critical_transaction_block<F>(is_broken: &AtomicBool, f: F) -> F::Output
where
F: std::future::Future,
{
let was_broken = is_broken.swap(true, Ordering::Relaxed);
debug_assert!(
!was_broken,
"Tried to execute a transaction SQL on transaction manager that was previously cancled"
);
let res = f.await;
is_broken.store(false, Ordering::Relaxed);
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
is_broken.store(false, Ordering::Relaxed);
let was_broken = is_broken.swap(false, Ordering::Relaxed);
debug_assert!(was_broken);

res
}
}

#[async_trait::async_trait]
Expand All @@ -308,7 +214,11 @@ where
.unwrap_or(NonZeroU32::new(1).expect("It's not 0"));
conn.instrumentation()
.on_connection_event(InstrumentationEvent::begin_transaction(depth));
conn.batch_execute(&start_transaction_sql).await?;
Self::critical_transaction_block(
&conn.transaction_state().is_broken.clone(),
conn.batch_execute(&start_transaction_sql),
)
.await?;
Self::get_transaction_state(conn)?
.change_transaction_depth(TransactionDepthChange::IncreaseDepth)?;

Expand Down Expand Up @@ -344,7 +254,10 @@ where
conn.instrumentation()
.on_connection_event(InstrumentationEvent::rollback_transaction(depth));

match conn.batch_execute(&rollback_sql).await {
let is_broken = conn.transaction_state().is_broken.clone();

match Self::critical_transaction_block(&is_broken, conn.batch_execute(&rollback_sql)).await
{
Ok(()) => {
match Self::get_transaction_state(conn)?
.change_transaction_depth(TransactionDepthChange::DecreaseDepth)
Expand Down Expand Up @@ -429,7 +342,9 @@ where
conn.instrumentation()
.on_connection_event(InstrumentationEvent::commit_transaction(depth));

match conn.batch_execute(&commit_sql).await {
let is_broken = conn.transaction_state().is_broken.clone();

match Self::critical_transaction_block(&is_broken, conn.batch_execute(&commit_sql)).await {
Ok(()) => {
match Self::get_transaction_state(conn)?
.change_transaction_depth(TransactionDepthChange::DecreaseDepth)
Expand All @@ -453,7 +368,12 @@ where
..
}) = conn.transaction_state().status
{
match Self::rollback_transaction(conn).await {
match Self::critical_transaction_block(
&is_broken,
Self::rollback_transaction(conn),
)
.await
{
Ok(()) => {}
Err(rollback_error) => {
conn.transaction_state().status.set_in_error();
Expand All @@ -472,4 +392,9 @@ where
fn transaction_manager_status_mut(conn: &mut Conn) -> &mut TransactionManagerStatus {
&mut conn.transaction_state().status
}

fn is_broken_transaction_manager(conn: &mut Conn) -> bool {
conn.transaction_state().is_broken.load(Ordering::Relaxed)
|| check_broken_transaction_state(conn)
}
}
Loading