Skip to content

Commit

Permalink
Improves transaction rollback handling
Browse files Browse the repository at this point in the history
I added a libpq transaction status check to the `BracketSuccess` branch
of `withTransaction` that will ensure that the transaction has the
`TransInTrans` state before issuing the success callback. If the status
is `TransInError`, the transaction will be rolled back. If it is
anything else, the transaction will be rolled back and an
`UnexpectedTransactionStatusError` will be thrown.

Additionally, I added exception handling around the success SQL
execution that will run the rollback callback if executing the SQL fails
with a `SqlExecutionError`, which could happen if a transaction cannot
be committed due to something like a serialization error.
  • Loading branch information
jlavelle committed Dec 5, 2024
1 parent a998108 commit e0e9e3a
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 4 deletions.
53 changes: 49 additions & 4 deletions orville-postgresql/src/Orville/PostgreSQL/Execution/Transaction.hs
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,13 @@ module Orville.PostgreSQL.Execution.Transaction
( withTransaction
, inWithTransaction
, InWithTransaction (InOutermostTransaction, InSavepointTransaction)
, UnexpectedTransactionStatusError (..)
)
where

import Control.Exception (Exception, throwIO, try)
import Control.Monad.IO.Class (MonadIO, liftIO)
import qualified Database.PostgreSQL.LibPQ as LibPQ
import Numeric.Natural (Natural)

import qualified Orville.PostgreSQL.Execution.Execute as Execute
Expand All @@ -25,6 +28,7 @@ import qualified Orville.PostgreSQL.Internal.Bracket as Bracket
import qualified Orville.PostgreSQL.Internal.MonadOrville as MonadOrville
import qualified Orville.PostgreSQL.Internal.OrvilleState as OrvilleState
import qualified Orville.PostgreSQL.Monad as Monad
import qualified Orville.PostgreSQL.Raw.Connection as Connection
import qualified Orville.PostgreSQL.Raw.RawSql as RawSql

{- | Performs an action in an Orville monad within a database transaction. The transaction
Expand Down Expand Up @@ -80,10 +84,29 @@ withTransaction action =
liftIO $
case result of
Bracket.BracketSuccess -> do
let
successEvent = OrvilleState.transactionSuccessEvent transaction
executeTransactionSql (transactionEventSql state successEvent)
callback successEvent
mbTransactionStatus <- Connection.transactionStatus conn
case mbTransactionStatus of
Nothing -> do
callback OrvilleState.RollbackTransaction
throwIO $ UnexpectedTransactionStatusError Nothing
Just transactionStatus -> case transactionStatus of
LibPQ.TransInTrans -> do
let
successEvent = OrvilleState.transactionSuccessEvent transaction
eSuccess <- try $ executeTransactionSql (transactionEventSql state successEvent)
case eSuccess of
Right () ->
callback successEvent
Left ex -> do
callback OrvilleState.RollbackTransaction
throwIO (ex :: Connection.SqlExecutionError)
LibPQ.TransInError -> do
executeTransactionSql (transactionEventSql state OrvilleState.RollbackTransaction)
callback OrvilleState.RollbackTransaction
_ -> do
executeTransactionSql (transactionEventSql state OrvilleState.RollbackTransaction)
callback OrvilleState.RollbackTransaction
throwIO $ UnexpectedTransactionStatusError (Just transactionStatus)
Bracket.BracketError -> do
let
rollbackEvent = OrvilleState.rollbackTransactionEvent transaction
Expand All @@ -92,6 +115,28 @@ withTransaction action =

Bracket.bracketWithResult beginTransaction finishTransaction doAction

{- |
'withTransaction' will throw this exception if libpq reports a transaction status other
than 'LibPQ.TransInTrans' or 'LibPQ.TransInError', or if there is no connection to the
database. The latter case should be impossible, and indicates a bug in Orville if observed.
@since 1.1.0.0
-}
newtype UnexpectedTransactionStatusError = UnexpectedTransactionStatusError
{ unexpectedTransactionStatusErrorTransactionStatus :: Maybe LibPQ.TransactionStatus
}

instance Show UnexpectedTransactionStatusError where
show =
maybe
"UnexpectedTransactionStatusError: No database connection."
( \status ->
"UnexpectedTransactionStatusError: " <> show status
)
. unexpectedTransactionStatusErrorTransactionStatus

instance Exception UnexpectedTransactionStatusError

transactionEventSql ::
OrvilleState.OrvilleState ->
OrvilleState.TransactionEvent ->
Expand Down
18 changes: 18 additions & 0 deletions orville-postgresql/test/Test/Transaction.hs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ module Test.Transaction
)
where

import Control.Exception (SomeException (..), catch)
import qualified Control.Monad as Monad
import qualified Data.ByteString as BS
import qualified Data.IORef as IORef
Expand All @@ -12,6 +13,7 @@ import qualified Hedgehog as HH
import qualified Hedgehog.Gen as Gen

import qualified Orville.PostgreSQL as Orville
import qualified Orville.PostgreSQL.Execution as Execution
import qualified Orville.PostgreSQL.Expr as Expr
import qualified Orville.PostgreSQL.OrvilleState as OrvilleState
import qualified Orville.PostgreSQL.Raw.Connection as Conn
Expand All @@ -31,6 +33,7 @@ transactionTests pool =
, prop_callbacksMadeForTransactionRollback pool
, prop_usesCustomBeginTransactionSql pool
, prop_inWithTransaction pool
, prop_rollbackCallbackInInvalidTransaction pool
]

prop_transactionsWithoutExceptionsCommit :: Property.NamedDBProperty
Expand Down Expand Up @@ -177,6 +180,21 @@ prop_inWithTransaction =
outsideBefore === Nothing
outsideAfter === Nothing

prop_rollbackCallbackInInvalidTransaction :: Property.NamedDBProperty
prop_rollbackCallbackInInvalidTransaction =
Property.singletonNamedDBProperty "withTransaction triggers the rollback callback if the LibPQ transaction status is TransInError" $ \pool -> do
let
badQuery = RawSql.fromString "bad"

allEvents <- captureTransactionCallbackEvents pool $
Orville.withTransaction $ do
Orville.liftCatch
catch
(Execution.executeVoid Execution.OtherQuery badQuery)
(\(SomeException _) -> pure ())

allEvents === [Orville.BeginTransaction, Orville.RollbackTransaction]

captureTransactionCallbackEvents ::
Orville.ConnectionPool ->
Orville.Orville () ->
Expand Down

0 comments on commit e0e9e3a

Please sign in to comment.