Skip to content

Commit

Permalink
Integrate middleware changes with internalization of SqlBackend
Browse files Browse the repository at this point in the history
  • Loading branch information
iand675 committed Apr 8, 2021
1 parent 79725c1 commit c6f90a2
Show file tree
Hide file tree
Showing 12 changed files with 132 additions and 75 deletions.
4 changes: 2 additions & 2 deletions persistent-mysql/Database/Persist/MySQL.hs
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ open' :: MySQL.ConnectInfo -> LogFunc -> IO SqlBackend
open' ci logFunc = do
conn <- MySQL.connect ci
MySQLBase.autocommit conn False -- disable autocommit!
smap <- makeSimpleStatementCache
smap <- mkSimpleStatementCache
return $
setConnPutManySql putManySql $
setConnRepsertManySql repsertManySql $
Expand Down Expand Up @@ -1242,7 +1242,7 @@ mockMigrate _connectInfo allDefs _getter val = do
-- the actual database isn't already present in the system.
mockMigration :: Migration -> IO ()
mockMigration mig = do
smap <- makeSimpleStatementCache
smap <- mkSimpleStatementCache
let sqlbackend =
mkSqlBackend MkSqlBackendArgs
{ connPrepare = \_ -> do
Expand Down
15 changes: 7 additions & 8 deletions persistent-postgresql/Database/Persist/Postgresql.hs
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ import System.Environment (getEnvironment)

import Database.Persist.Sql
import Database.Persist.SqlBackend
import Database.Persist.SqlBackend.StatementCache
import qualified Database.Persist.Sql.Util as Util

-- | A @libpq@ connection string. A simple example of connection
Expand Down Expand Up @@ -327,22 +328,22 @@ openSimpleConn = openSimpleConnWithVersion getServerVersion
-- @since 2.9.1
openSimpleConnWithVersion :: (PG.Connection -> IO (Maybe Double)) -> LogFunc -> PG.Connection -> IO SqlBackend
openSimpleConnWithVersion getVerDouble logFunc conn = do
smap <- makeSimpleStatementCache
smap <- mkSimpleStatementCache
serverVersion <- oldGetVersionToNew getVerDouble conn
return $ createBackend logFunc serverVersion smap conn

-- | Create the backend given a logging function, server version, mutable statement cell,
-- and connection.
createBackend :: LogFunc -> NonEmpty Word
-> StatementCache -> PG.Connection -> SqlBackend
-> MkStatementCache -> PG.Connection -> SqlBackend
createBackend logFunc serverVersion smap conn =
maybe id setConnPutManySql (upsertFunction putManySql serverVersion) $
maybe id setConnUpsertSql (upsertFunction upsertSql' serverVersion) $
setConnInsertManySql insertManySql' $
maybe id setConnRepsertManySql (upsertFunction repsertManySql serverVersion) $
mkSqlBackend MkSqlBackendArgs
{ connPrepare = prepare' conn
, connStmtMap = smap
, connStmtMap = mkStatementCache smap
, connInsertSql = insertSql'
, connClose = PG.close conn
, connMigrateSql = migrate'
Expand All @@ -362,7 +363,6 @@ createBackend logFunc serverVersion smap conn =
, connRDBMS = "postgresql"
, connLimitOffset = decorateSQLWithLimitOffset "LIMIT ALL"
, connLogFunc = logFunc
, connStatementMiddleware = const pure
}

prepare' :: PG.Connection -> Text -> IO Statement
Expand Down Expand Up @@ -1603,7 +1603,7 @@ data PostgresConfHooks = PostgresConfHooks
-- The default implementation does nothing.
--
-- @since 2.11.0
, pgConfHooksCreateStatementCache :: IO StatementCache
, pgConfHooksCreateStatementCache :: IO MkStatementCache
}

-- | Default settings for 'PostgresConfHooks'. See the individual fields of 'PostgresConfHooks' for the default values.
Expand All @@ -1613,7 +1613,7 @@ defaultPostgresConfHooks :: PostgresConfHooks
defaultPostgresConfHooks = PostgresConfHooks
{ pgConfHooksGetServerVersion = getServerVersionNonEmpty
, pgConfHooksAfterCreate = const $ pure ()
, pgConfHooksCreateStatementCache = makeSimpleStatementCache
, pgConfHooksCreateStatementCache = mkSimpleStatementCache
}


Expand Down Expand Up @@ -1695,7 +1695,7 @@ mockMigrate allDefs _ entity = fmap (fmap $ map showAlterDb) $ do
-- with the difference that an actual database is not needed.
mockMigration :: Migration -> IO ()
mockMigration mig = do
smap <- makeSimpleStatementCache
smap <- mkStatementCache <$> mkSimpleStatementCache
let sqlbackend =
mkSqlBackend MkSqlBackendArgs
{ connPrepare = \_ -> do
Expand All @@ -1719,7 +1719,6 @@ mockMigration mig = do
, connRDBMS = undefined
, connLimitOffset = undefined
, connLogFunc = undefined
, connStatementMiddleware = const pure
}
result = runReaderT $ runWriterT $ runWriterT mig
resp <- result sqlbackend
Expand Down
9 changes: 3 additions & 6 deletions persistent-sqlite/Database/Persist/Sqlite.hs
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,6 @@ import qualified Data.Conduit.Combinators as C
import qualified Data.Conduit.List as CL
import qualified Data.HashMap.Lazy as HashMap
import Data.Int (Int64)
import Data.IORef
import qualified Data.Map as Map
import Data.Monoid ((<>))
import Data.Pool (Pool)
import Data.Text (Text)
Expand All @@ -91,6 +89,7 @@ import Database.Persist.Compatible
#endif
import Database.Persist.Sql
import Database.Persist.SqlBackend
import Database.Persist.SqlBackend.StatementCache
import qualified Database.Persist.Sql.Util as Util
import qualified Database.Sqlite as Sqlite

Expand Down Expand Up @@ -267,7 +266,7 @@ wrapConnectionInfo connInfo conn logFunc = do
Sqlite.reset conn stmt
Sqlite.finalize stmt

smap <- makeSimpleStatementCache
smap <- mkStatementCache <$> mkSimpleStatementCache
return $
setConnMaxParams 999 $
setConnPutManySql putManySql $
Expand All @@ -288,7 +287,6 @@ wrapConnectionInfo connInfo conn logFunc = do
, connRDBMS = "sqlite"
, connLimitOffset = decorateSQLWithLimitOffset "LIMIT -1"
, connLogFunc = logFunc
, connStatementMiddleware = const pure
}
where
helper t getter = do
Expand Down Expand Up @@ -455,7 +453,7 @@ migrate' allDefs getter val = do
-- with the difference that an actual database isn't needed for it.
mockMigration :: Migration -> IO ()
mockMigration mig = do
smap <- makeSimpleStatementCache
smap <- mkStatementCache <$> mkSimpleStatementCache
let sqlbackend =
setConnMaxParams 999 $
mkSqlBackend MkSqlBackendArgs
Expand All @@ -480,7 +478,6 @@ mockMigration mig = do
, connRDBMS = "sqlite"
, connLimitOffset = decorateSQLWithLimitOffset "LIMIT -1"
, connLogFunc = undefined
, connStatementMiddleware = const pure
}
result = runReaderT . runWriterT . runWriterT $ mig
resp <- result sqlbackend
Expand Down
8 changes: 4 additions & 4 deletions persistent/Database/Persist/Sql/Raw.hs
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,8 @@ import qualified Data.Text as T
import Database.Persist
import Database.Persist.Sql.Types
import Database.Persist.Sql.Types.Internal
import Database.Persist.SqlBackend.Internal
import Database.Persist.Sql.Class
import Database.Persist.Sql.Types.Internal (statementCacheLookup, StatementCache (statementCacheInsert))
import Database.Persist.SqlBackend.Internal.StatementCache

rawQuery :: (MonadResource m, MonadReader env m, BackendCompatible SqlBackend env)
=> Text
Expand Down Expand Up @@ -76,7 +75,8 @@ getStmt sql = do

getStmtConn :: SqlBackend -> Text -> IO Statement
getStmtConn conn sql = do
smap <- liftIO $ statementCacheLookup (connStmtMap conn) sql
let cacheKey = mkCacheKeyFromQuery sql
smap <- liftIO $ statementCacheLookup (connStmtMap conn) cacheKey
case smap of
Just stmt -> connStatementMiddleware conn sql stmt
Nothing -> do
Expand All @@ -101,7 +101,7 @@ getStmtConn conn sql = do
then stmtQuery stmt' x
else liftIO $ throwIO $ StatementAlreadyFinalized sql
}
liftIO $ statementCacheInsert (connStmtMap conn) sql stmt
liftIO $ statementCacheInsert (connStmtMap conn) cacheKey stmt
connStatementMiddleware conn sql stmt

-- | Execute a raw SQL statement and return its results as a
Expand Down
1 change: 1 addition & 0 deletions persistent/Database/Persist/Sql/Run.hs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import Database.Persist.Class.PersistStore
import Database.Persist.Sql.Types
import Database.Persist.Sql.Types.Internal
import Database.Persist.Sql.Raw
import Database.Persist.SqlBackend.Internal.StatementCache

-- | Get a connection from the pool, run the given action, and then return the
-- connection to the pool.
Expand Down
2 changes: 0 additions & 2 deletions persistent/Database/Persist/Sql/Types.hs
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@ module Database.Persist.Sql.Types
, SqlBackendCanRead, SqlBackendCanWrite, SqlReadT, SqlWriteT, IsSqlBackend
, OverflowNatural(..)
, ConnectionPoolConfig(..)
, StatementCache(..)
, makeSimpleStatementCache
) where

import Database.Persist.Types.Base (FieldCascade)
Expand Down
15 changes: 15 additions & 0 deletions persistent/Database/Persist/SqlBackend.hs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ module Database.Persist.SqlBackend
, setConnInsertManySql
, setConnUpsertSql
, setConnPutManySql
, setConnStatementMiddleware
) where

import Control.Monad.Reader
Expand All @@ -29,6 +30,7 @@ import qualified Database.Persist.SqlBackend.Internal as SqlBackend
import Database.Persist.SqlBackend.Internal.MkSqlBackend as Mk (MkSqlBackendArgs(..))
import Database.Persist.Types.Base
import Database.Persist.SqlBackend.Internal.InsertSqlResult
import Database.Persist.SqlBackend.Internal.Statement
import Data.List.NonEmpty (NonEmpty)

-- $utilities
Expand Down Expand Up @@ -158,3 +160,16 @@ setConnPutManySql
-> SqlBackend
setConnPutManySql mkQuery sb =
sb { connPutManySql = Just mkQuery }

-- | Set the 'connPutManySql field on the 'SqlBackend'. This can be used to
-- locally alter the statement prior to the statement being queried or executed.
-- If this is not set, it will have no effect.
--
-- @since 2.13.0.0
setConnStatementMiddleware
:: (Text -> Statement -> IO Statement)
-> SqlBackend
-> SqlBackend
setConnStatementMiddleware middleware sb =
sb { connStatementMiddleware = middleware }

7 changes: 6 additions & 1 deletion persistent/Database/Persist/SqlBackend/Internal.hs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import Database.Persist.Types.Base
import Data.Int
import Data.IORef
import Control.Monad.Reader
import Database.Persist.SqlBackend.StatementCache
import Database.Persist.SqlBackend.Internal.MkSqlBackend
import Database.Persist.SqlBackend.Internal.Statement
import Database.Persist.SqlBackend.Internal.InsertSqlResult
Expand Down Expand Up @@ -74,7 +75,7 @@ data SqlBackend = SqlBackend
-- When left as 'Nothing', we default to using 'defaultPutMany'.
--
-- @since 2.8.1
, connStmtMap :: IORef (Map Text Statement)
, connStmtMap :: StatementCache
-- ^ A reference to the cache of statements. 'Statement's are keyed by
-- the 'Text' queries that generated them.
, connClose :: IO ()
Expand Down Expand Up @@ -137,6 +138,9 @@ data SqlBackend = SqlBackend
-- When left as 'Nothing', we default to using 'defaultRepsertMany'.
--
-- @since 2.9.0
, connStatementMiddleware :: Text -> Statement -> IO Statement
-- ^ Provide facilities for injecting middleware into statements
-- to allow for instrumenting queries.
}

-- | A function for creating a value of the 'SqlBackend' type. You should prefer
Expand All @@ -153,6 +157,7 @@ mkSqlBackend MkSqlBackendArgs {..} =
, connPutManySql = Nothing
, connUpsertSql = Nothing
, connInsertManySql = Nothing
, connStatementMiddleware = const pure
, ..
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ import Data.Map (Map)
import Data.String
import Data.Text (Text)
import Database.Persist.Class.PersistStore
import Database.Persist.SqlBackend.StatementCache
import Database.Persist.SqlBackend.Internal.Statement
import Database.Persist.SqlBackend.Internal.StatementCache
import Database.Persist.SqlBackend.Internal.InsertSqlResult
import Database.Persist.SqlBackend.Internal.IsolationLevel
import Database.Persist.Types.Base
Expand All @@ -35,7 +35,7 @@ data MkSqlBackendArgs = MkSqlBackendArgs
, connInsertSql :: EntityDef -> [PersistValue] -> InsertSqlResult
-- ^ This function generates the SQL and values necessary for
-- performing an insert against the database.
, connStmtMap :: InternalStatementCache
, connStmtMap :: StatementCache
-- ^ A reference to the cache of statements. 'Statement's are keyed by
-- the 'Text' queries that generated them.
, connClose :: IO ()
Expand Down Expand Up @@ -81,9 +81,6 @@ data MkSqlBackendArgs = MkSqlBackendArgs
-- queries are the superior way to offer pagination.
, connLogFunc :: LogFunc
-- ^ A log function for the 'SqlBackend' to use.
, connStatementMiddleware :: Text -> Statement -> IO Statement
-- ^ Provide facilities for injecting middleware into statements
-- to allow for instrumenting queries.
}

type LogFunc = Loc -> LogSource -> LogLevel -> LogStr -> IO ()
63 changes: 16 additions & 47 deletions persistent/Database/Persist/SqlBackend/Internal/StatementCache.hs
Original file line number Diff line number Diff line change
@@ -1,54 +1,23 @@
module Database.Persist.SqlBackend.Internal.StatementCache
( StatementCache(..)
, InternalStatementCache
, makeSimpleStatementCache
, internalizeStatementCache
) where
module Database.Persist.SqlBackend.Internal.StatementCache where

import Data.Foldable
import Data.IORef
import qualified Data.Map as Map
import Data.Text (Text)
import Database.Persist.SqlBackend.Internal.Statement

class StatementCache c where
statementCacheLookup :: c -> Text -> IO (Maybe Statement)
statementCacheInsert :: c -> Text -> Statement -> IO ()
statementCacheClear :: c -> IO ()
statementCacheSize :: c -> IO Int

data InternalStatementCache = InternalStatementCache
{ _statementCacheLookup :: Text -> IO (Maybe Statement)
, _statementCacheInsert :: Text -> Statement -> IO ()
, _statementCacheClear :: IO ()
, _statementCacheSize :: IO Int
-- | A statement cache used to lookup statements that have already been prepared
-- for a given query.
--
-- @since 2.13.0
data StatementCache = StatementCache
{ statementCacheLookup :: StatementCacheKey -> IO (Maybe Statement)
, statementCacheInsert :: StatementCacheKey -> Statement -> IO ()
, statementCacheClear :: IO ()
, statementCacheSize :: IO Int
}

instance StatementCache InternalStatementCache where
statementCacheLookup = _statementCacheLookup
statementCacheInsert = _statementCacheInsert
statementCacheClear = _statementCacheClear
statementCacheSize = _statementCacheSize


internalizeStatementCache :: StatementCache c => c -> InternalStatementCache
internalizeStatementCache c = InternalStatementCache
{ _statementCacheLookup = statementCacheLookup c
, _statementCacheInsert = statementCacheInsert c
, _statementCacheClear = statementCacheClear c
, _statementCacheSize = statementCacheSize c
}

makeSimpleStatementCache :: IO InternalStatementCache
makeSimpleStatementCache = do
stmtMap <- newIORef Map.empty
pure $ InternalStatementCache
{ _statementCacheLookup = \sql -> Map.lookup sql <$> readIORef stmtMap
, _statementCacheInsert = \sql stmt ->
modifyIORef' stmtMap (Map.insert sql stmt)
, _statementCacheClear = do
oldStatements <- atomicModifyIORef' stmtMap (\oldStatements -> (Map.empty, oldStatements))
traverse_ stmtFinalize oldStatements
, _statementCacheSize = Map.size <$> readIORef stmtMap
}
newtype StatementCacheKey = StatementCacheKey { cacheKey :: Text }
-- Wrapping around this to allow for more efficient keying mechanisms
-- in the future, perhaps.

-- | Construct a `StatementCacheKey` from a raw SQL query.
mkCacheKeyFromQuery :: Text -> StatementCacheKey
mkCacheKeyFromQuery = StatementCacheKey
Loading

0 comments on commit c6f90a2

Please sign in to comment.