diff --git a/Data/SBV/Control/Utils.hs b/Data/SBV/Control/Utils.hs index 3b9ef5a99..5dec2eb85 100644 --- a/Data/SBV/Control/Utils.hs +++ b/Data/SBV/Control/Utils.hs @@ -1898,7 +1898,17 @@ executeQuery queryContext (QueryT userQuery) = do liftIO $ writeIORef (runMode st) $ SMTMode qc IRun isSAT cfg - lift $ join $ liftIO $ backend cfg' st (show pgm) $ extractIO . runReaderT userQuery + let terminateSolver maybeForwardedException = do + qs <- readIORef $ rQueryState st + case qs of + Nothing -> return () + Just QueryState{queryTerminate} -> queryTerminate maybeForwardedException + + lift $ join $ liftIO $ C.mask $ \restore -> do + r <- restore (extractIO $ join $ liftIO $ backend cfg' st (show pgm) $ extractIO . runReaderT userQuery) `C.catch` \e -> + terminateSolver (Just e) >> C.throwIO (e :: C.SomeException) + terminateSolver Nothing + return r -- Already in a query, in theory we can just continue, but that causes use-case issues -- so we reject it. TODO: Review if we should actually support this. The issue arises with diff --git a/Data/SBV/Core/Symbolic.hs b/Data/SBV/Core/Symbolic.hs index e40e077b4..bef51e493 100644 --- a/Data/SBV/Core/Symbolic.hs +++ b/Data/SBV/Core/Symbolic.hs @@ -89,6 +89,7 @@ import GHC.Stack import GHC.Stack.Types import GHC.Generics (Generic) +import qualified Control.Exception as C import qualified Control.Monad.State.Lazy as LS import qualified Control.Monad.State.Strict as SS import qualified Control.Monad.Writer.Lazy as LW @@ -778,7 +779,7 @@ data QueryState = QueryState { queryAsk :: Maybe Int -> String - , querySend :: Maybe Int -> String -> IO () , queryRetrieveResponse :: Maybe Int -> IO String , queryConfig :: SMTConfig - , queryTerminate :: IO () + , queryTerminate :: Maybe C.SomeException -> IO () , queryTimeOutValue :: Maybe Int , queryAssertionStackDepth :: Int } @@ -1968,12 +1969,6 @@ runSymbolicInState st (SymbolicT c) = do mapM_ check $ nub $ G.universeBi res - -- Clean-up after ourselves - qs <- liftIO $ readIORef $ rQueryState st - case qs of - Nothing -> return () - Just QueryState{queryTerminate} -> liftIO queryTerminate - return (r, res) -- | Grab the program from a running symbolic simulation state. diff --git a/Data/SBV/SMT/SMT.hs b/Data/SBV/SMT/SMT.hs index cbf9c3bf9..8638a3853 100644 --- a/Data/SBV/SMT/SMT.hs +++ b/Data/SBV/SMT/SMT.hs @@ -862,7 +862,7 @@ runSolver cfg ctx execPath opts pgm continuation ex <- waitForProcess pid `C.catch` (\(e :: C.SomeException) -> handleAsync e (return (ExitFailure (-999)))) return (out, err, ex) - cleanUp + cleanUp maybeForwardedException = do (out, err, ex) <- terminateSolver msg $ [ "Solver : " ++ nm @@ -874,8 +874,9 @@ runSolver cfg ctx execPath opts pgm continuation finalizeTranscript (transcript cfg) ex recordEndTime cfg ctx - case ex of - ExitSuccess -> return () + case (ex, maybeForwardedException) of + (_, Just forwardedException) -> C.throwIO forwardedException + (ExitSuccess, _) -> return () _ -> if ignoreExitCode cfg then msg ["Ignoring non-zero exit code of " ++ show ex ++ " per user request!"] else C.throwIO SBVException { sbvExceptionDescription = "Failed to complete the call to " ++ nm diff --git a/Data/SBV/Utils/ExtractIO.hs b/Data/SBV/Utils/ExtractIO.hs index 7afafb838..c23f3a2bc 100644 --- a/Data/SBV/Utils/ExtractIO.hs +++ b/Data/SBV/Utils/ExtractIO.hs @@ -32,7 +32,7 @@ class MonadIO m => ExtractIO m where -- | Trivial IO extraction for 'IO'. instance ExtractIO IO where - extractIO = pure + extractIO = fmap pure -- | IO extraction for 'MaybeT'. instance ExtractIO m => ExtractIO (MaybeT m) where