Skip to content

Commit

Permalink
request id middleware
Browse files Browse the repository at this point in the history
  • Loading branch information
akshaymankar authored and mdimjasevic committed May 22, 2024
1 parent b8113da commit 60abca1
Show file tree
Hide file tree
Showing 15 changed files with 104 additions and 69 deletions.
29 changes: 8 additions & 21 deletions libs/wai-utilities/src/Network/Wai/Utilities/Request.hs
Original file line number Diff line number Diff line change
Expand Up @@ -29,19 +29,15 @@ import Data.ByteString qualified as B
import Data.ByteString.Lazy qualified as Lazy
import Data.Id
import Data.Text.Lazy qualified as Text
import Data.UUID qualified as UUID
import Data.UUID.V4 as UUID
import Imports
import Network.HTTP.Types.Status (status400)
import Network.HTTP.Types
import Network.Wai
import Network.Wai.Predicate
import Network.Wai.Predicate.Request
import Network.Wai.Utilities.Error qualified as Wai
import Network.Wai.Utilities.ZAuth ((.&>))
import Pipes
import Pipes.Prelude qualified as P
import System.Logger ((.=), (~~))
import System.Logger qualified as Log

readBody :: (MonadIO m, HasRequest r) => r -> m LByteString
readBody r = liftIO $ Lazy.fromChunks <$> P.toListM chunks
Expand Down Expand Up @@ -73,22 +69,13 @@ parseOptionalBody r =
nonEmptyBody "" = Nothing
nonEmptyBody ne = Just ne

lookupRequestId :: HasRequest r => r -> Maybe ByteString
lookupRequestId = lookup "Request-Id" . requestHeaders . getRequest

-- | Like 'lookupRequestId' it looks up the request ID in the request's headers.
-- In case there is no such header, a fresh ID is returned.
getRequestId :: (HasRequest r, Show r) => Log.Logger -> r -> IO RequestId
getRequestId logger req = case lookupRequestId req of
Just rid -> pure (RequestId rid)
Nothing -> do
localRid <- RequestId . UUID.toASCIIBytes <$> UUID.nextRandom
unless (rawPathInfo (getRequest req) `elem` ["/i/status", "/i/metrics"]) $
Log.info logger $
"request-id" .= localRid
~~ "request" .= (show req)
~~ Log.msg (Log.val "generated a new request id for local request")
pure localRid
lookupRequestId :: HeaderName -> Request -> Maybe ByteString
lookupRequestId reqIdHeaderName =
lookup reqIdHeaderName . requestHeaders

getRequestId :: HeaderName -> Request -> RequestId
getRequestId reqIdHeaderName req =
RequestId $ fromMaybe "N/A" $ lookupRequestId reqIdHeaderName req

----------------------------------------------------------------------------
-- Typed JSON 'Request'
Expand Down
51 changes: 36 additions & 15 deletions libs/wai-utilities/src/Network/Wai/Utilities/Server.hs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ module Network.Wai.Utilities.Server
route,

-- * Middlewares
requestIdMiddleware,
catchErrors,
catchErrorsWithRequestId,
OnErrorMetrics,
Expand All @@ -42,10 +43,13 @@ module Network.Wai.Utilities.Server
logError,
logError',
logErrorMsg,
logIO,
runHandlers,
restrict,
flushRequestBody,

-- * Constants
defaultRequestIdHeaderName,
federationRequestIdHeaderName,
)
where

Expand All @@ -62,18 +66,20 @@ import Data.Domain (domainText)
import Data.Metrics.GC (spawnGCMetricsCollector)
import Data.Metrics.Middleware
import Data.Streaming.Zlib (ZlibException (..))
import Data.Text.Encoding qualified as Text
import Data.Text.Encoding.Error (lenientDecode)
import Data.Text.Lazy qualified as LT
import Data.Text.Lazy.Encoding qualified as LT
import Data.UUID qualified as UUID
import Data.UUID.V4 qualified as UUID
import Imports
import Network.HTTP.Types.Status
import Network.HTTP.Types
import Network.Wai
import Network.Wai.Handler.Warp
import Network.Wai.Handler.Warp.Internal (TimeoutThread)
import Network.Wai.Internal qualified as WaiInt
import Network.Wai.Predicate hiding (Error, err, status)
import Network.Wai.Predicate qualified as P
import Network.Wai.Predicate.Request (HasRequest)
import Network.Wai.Routing.Route (App, Continue, Routes, Tree)
import Network.Wai.Routing.Route qualified as Route
import Network.Wai.Utilities.Error qualified as Error
Expand Down Expand Up @@ -185,8 +191,22 @@ route rt rq k = Route.routeWith (Route.Config $ errorRs' noEndpoint) rt rq (lift
--------------------------------------------------------------------------------
-- Middlewares

catchErrors :: Logger -> OnErrorMetrics -> Middleware
catchErrors l m = catchErrorsWithRequestId lookupRequestId l m
requestIdMiddleware :: Logger -> HeaderName -> Middleware
requestIdMiddleware logger reqIdHeaderName origApp req responder =
case lookup reqIdHeaderName req.requestHeaders of
Just _ -> origApp req responder
Nothing -> do
reqId <- Text.encodeUtf8 . UUID.toText <$> UUID.nextRandom
Log.info logger $
msg ("generated a new request id for local request" :: ByteString)
. field "request" reqId
. field "method" (requestMethod req)
. field "path" (rawPathInfo req)
let reqWithId = req {requestHeaders = (reqIdHeaderName, reqId) : req.requestHeaders}
origApp reqWithId responder

catchErrors :: Logger -> HeaderName -> OnErrorMetrics -> Middleware
catchErrors l reqIdHeaderName m = catchErrorsWithRequestId (lookupRequestId reqIdHeaderName) l m

-- | Create a middleware that catches exceptions and turns
-- them into appropriate 'Error' responses, thereby logging
Expand Down Expand Up @@ -258,8 +278,9 @@ heavyDebugLogging ::
((Request, LByteString) -> Maybe (Request, LByteString)) ->
Level ->
Logger ->
HeaderName ->
Middleware
heavyDebugLogging sanitizeReq lvl lgr app = \req cont -> do
heavyDebugLogging sanitizeReq lvl lgr reqIdHeaderName app = \req cont -> do
(bdy, req') <-
if lvl `elem` [Trace, Debug]
then cloneBody req
Expand All @@ -278,7 +299,7 @@ heavyDebugLogging sanitizeReq lvl lgr app = \req cont -> do
logMostlyEverything req bdy resp = Log.debug lgr logMsg
where
logMsg =
field "request" (fromMaybe "N/A" $ lookupRequestId req)
field "request" (fromMaybe "N/A" $ lookupRequestId reqIdHeaderName req)
. field "request_details" (show req)
. field "request_body" bdy
. field "response_status" (show $ responseStatus resp)
Expand Down Expand Up @@ -377,12 +398,18 @@ onError g mReqId m r k e = liftIO $ do
flushRequestBody r
k (jsonResponseToWai resp)

defaultRequestIdHeaderName :: HeaderName
defaultRequestIdHeaderName = "Request-Id"

federationRequestIdHeaderName :: HeaderName
federationRequestIdHeaderName = "Wire-Origin-Request-Id"

-- | Log an 'Error' response for debugging purposes.
--
-- It would be nice to have access to the request body here, but that's already streamed away
-- by the handler in all likelyhood. See 'heavyDebugLogging'.
logError :: (MonadIO m, HasRequest r) => Logger -> Maybe r -> Wai.Error -> m ()
logError g mr = logError' g (lookupRequestId =<< mr)
logError :: (MonadIO m) => Logger -> Maybe Request -> Wai.Error -> m ()
logError g mr = logError' g (lookupRequestId defaultRequestIdHeaderName =<< mr)

logError' :: (MonadIO m) => Logger -> Maybe ByteString -> Wai.Error -> m ()
logError' g mr e = liftIO $ doLog g (logErrorMsgWithRequest mr e)
Expand Down Expand Up @@ -421,12 +448,6 @@ logErrorMsgWithRequest :: Maybe ByteString -> Wai.Error -> Msg -> Msg
logErrorMsgWithRequest mr e =
field "request" (fromMaybe "N/A" mr) . logErrorMsg e

logIO :: (ToBytes msg, HasRequest r) => Logger -> Level -> Maybe r -> msg -> IO ()
logIO lg lv r a =
let reqId = field "request" . fromMaybe "N/A" . lookupRequestId <$> r
mesg = fromMaybe id reqId . msg a
in Log.log lg lv mesg

runHandlers :: SomeException -> [Handler m a] -> m a
runHandlers e [] = throw e
runHandlers e (Handler h : hs) = maybe (runHandlers e hs) h (fromException e)
Expand Down
5 changes: 3 additions & 2 deletions services/brig/src/Brig/Run.hs
Original file line number Diff line number Diff line change
Expand Up @@ -122,12 +122,13 @@ mkApp o = do
. Metrics.servantPrometheusMiddleware (Proxy @ServantCombinedAPI)
. GZip.gunzip
. GZip.gzip GZip.def
. catchErrors (e ^. applog) [Right $ e ^. metrics]
. catchErrors (e ^. applog) defaultRequestIdHeaderName [Right $ e ^. metrics]
. requestIdMiddleware (e ^. applog) defaultRequestIdHeaderName

-- the servant API wraps the one defined using wai-routing
servantApp :: Env -> Wai.Application
servantApp e0 req cont = do
rid <- getRequestId (e0 ^. applog) req
let rid = getRequestId defaultRequestIdHeaderName req
let e = requestId .~ rid $ e0
let localDomain = view (settings . federationDomain) e
Servant.serveWithContext
Expand Down
3 changes: 2 additions & 1 deletion services/cannon/src/Cannon/Run.hs
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,8 @@ run o = do
versionMiddleware (foldMap expandVersionExp (o ^. disabledAPIVersions))
. servantPrometheusMiddleware (Proxy @CombinedAPI)
. Gzip.gzip Gzip.def
. catchErrors g [Right m]
. catchErrors g defaultRequestIdHeaderName [Right m]
. requestIdMiddleware g defaultRequestIdHeaderName
app :: Application
app = middleware (serve (Proxy @CombinedAPI) server)
server :: Servant.Server CombinedAPI
Expand Down
5 changes: 3 additions & 2 deletions services/cannon/src/Cannon/Types.hs
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ import Data.Text.Encoding
import Imports
import Network.Wai
import Network.Wai.Utilities.Request qualified as Wai
import Network.Wai.Utilities.Server
import Servant qualified
import System.Logger qualified as Logger
import System.Logger.Class hiding (info)
Expand Down Expand Up @@ -113,8 +114,8 @@ mkEnv m external o l d p g t =

runCannon :: Env -> Cannon a -> Request -> IO a
runCannon e c r = do
rid <- Wai.getRequestId e.applog r
let e' = e {reqId = rid}
let rid = Wai.getRequestId defaultRequestIdHeaderName r
e' = e {reqId = rid}
runCannon' e' c

runCannon' :: Env -> Cannon a -> IO a
Expand Down
7 changes: 4 additions & 3 deletions services/cargohold/src/CargoHold/Run.hs
Original file line number Diff line number Diff line change
Expand Up @@ -80,11 +80,12 @@ mkApp o = Codensity $ \k ->
versionMiddleware (foldMap expandVersionExp (o ^. settings . disabledAPIVersions))
. servantPrometheusMiddleware (Proxy @CombinedAPI)
. GZip.gzip GZip.def
. catchErrors (e ^. appLogger) [Right $ e ^. metrics]
. catchErrors (e ^. appLogger) defaultRequestIdHeaderName [Right $ e ^. metrics]
. requestIdMiddleware (e ^. appLogger) defaultRequestIdHeaderName
servantApp :: Env -> Application
servantApp e0 r cont = do
rid <- getRequestId (e0 ^. appLogger) r
let e = requestId .~ rid $ e0
let rid = getRequestId defaultRequestIdHeaderName r
e = requestId .~ rid $ e0
Servant.serveWithContext
(Proxy @CombinedAPI)
((o ^. settings . federationDomain) :. Servant.EmptyContext)
Expand Down
23 changes: 18 additions & 5 deletions services/federator/src/Federator/ExternalServer.hs
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ data API mode = API
:- "federation"
:> Capture "component" Component
:> Capture "rpc" RPC
:> Header "Wire-Origin-Request-Id" RequestId
:> Header' '[Required, Strict] "Wire-Origin-Request-Id" RequestId
:> Header' '[Required, Strict] OriginDomainHeaderName Domain
:> Header' '[Required, Strict] "X-SSL-Certificate" CertHeader
:> Endpath
Expand Down Expand Up @@ -119,10 +119,10 @@ server ::
server mgr intPort interpreter =
API
{ status = Health.status mgr "internal server" intPort,
externalRequest = \component rpc mReqId remoteDomain remoteCert ->
externalRequest = \component rpc rid remoteDomain remoteCert ->
Tagged $ \req respond -> do
-- TODO: Log generated request ID
rid <- maybe (RequestId . Text.encodeUtf8 . UUID.toText <$> UUID.nextRandom) pure mReqId
-- rid <- maybe (RequestId . Text.encodeUtf8 . UUID.toText <$> UUID.nextRandom) pure mReqId
runCodensity (interpreter rid (callInward component rpc rid remoteDomain remoteCert req)) respond
}

Expand Down Expand Up @@ -193,8 +193,21 @@ callInward component (RPC rpc) rid originDomain (CertHeader cert) wreq = do
}

serveInward :: Env -> Int -> IO ()
serveInward env =
serveInward env = do
let middleware =
requestIdMiddleware
. Metrics.servantPrometheusMiddleware (Proxy :: Proxy (ToServantApi API))
serveServant
(Metrics.servantPrometheusMiddleware $ Proxy @(ToServantApi API))
middleware
(server env._httpManager env._internalPort $ runFederator env)
env

requestIdMiddleware :: Wai.Middleware
requestIdMiddleware origApp req responder =
-- TODO: extract constant
case lookup "Wire-Origin-Request-Id" req.requestHeaders of
Just _ -> origApp req responder
Nothing -> do
reqId <- Text.encodeUtf8 . UUID.toText <$> UUID.nextRandom
let reqWithId = req {Wai.requestHeaders = ("Wire-Origin-Request-Id", reqId) : req.requestHeaders}
origApp reqWithId responder
18 changes: 10 additions & 8 deletions services/federator/src/Federator/InternalServer.hs
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,6 @@ import Data.Domain
import Data.Id
import Data.Metrics.Servant qualified as Metrics
import Data.Proxy
import Data.Text.Encoding qualified as T
import Data.UUID as UUID
import Data.UUID.V4 as UUID
import Federator.Env
import Federator.Error.ServerError
import Federator.Health qualified as Health
Expand All @@ -42,6 +39,7 @@ import Imports
import Network.HTTP.Client (Manager)
import Network.HTTP.Types qualified as HTTP
import Network.Wai qualified as Wai
import Network.Wai.Utilities.Server
import Polysemy
import Polysemy.Error
import Polysemy.Input
Expand All @@ -67,7 +65,7 @@ data API mode = API
internalRequest ::
mode
:- "rpc"
:> Header "Wire-Origin-Request-Id" RequestId
:> Header' '[Required, Strict] "Wire-Origin-Request-Id" RequestId
:> Capture "domain" Domain
:> Capture "component" Component
:> Capture "rpc" RPC
Expand Down Expand Up @@ -95,10 +93,10 @@ server ::
server mgr extPort interpreter =
API
{ status = Health.status mgr "external server" extPort,
internalRequest = \mReqId remoteDomain component rpc ->
internalRequest = \rid remoteDomain component rpc ->
Tagged $ \req respond -> do
-- TODO: Log generated request ID
rid <- maybe (RequestId . T.encodeUtf8 . UUID.toText <$> UUID.nextRandom) pure mReqId
-- rid <- maybe (RequestId . T.encodeUtf8 . UUID.toText <$> UUID.nextRandom) pure mReqId
-- rid <- case
-- rid <- case mReqId of
-- Just r -> pure r
Expand Down Expand Up @@ -156,8 +154,12 @@ callOutward rid targetDomain component (RPC path) req = do
pure $ streamingResponseToWai resp

serveOutward :: Env -> Int -> IO ()
serveOutward env =
serveOutward env = do
let middleware =
-- TODO: extract constant
requestIdMiddleware env._applog federationRequestIdHeaderName
. Metrics.servantPrometheusMiddleware (Proxy :: Proxy (ToServantApi API))
serveServant
(Metrics.servantPrometheusMiddleware $ Proxy @(ToServantApi API))
middleware
(server env._httpManager env._externalPort $ runFederator env)
env
7 changes: 4 additions & 3 deletions services/galley/src/Galley/Run.hs
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,8 @@ mkApp opts =
. servantPrometheusMiddleware (Proxy @CombinedAPI)
. GZip.gunzip
. GZip.gzip GZip.def
. catchErrors logger [Right metrics]
. catchErrors logger defaultRequestIdHeaderName [Right metrics]
. requestIdMiddleware logger defaultRequestIdHeaderName
Codensity $ \k -> finally (k ()) $ do
Log.info logger $ Log.msg @Text "Galley application finished."
Log.flush logger
Expand Down Expand Up @@ -129,8 +130,8 @@ mkApp opts =

servantApp :: Env -> Application
servantApp e0 r cont = do
rid <- getRequestId (e0 ^. applog) r
let e = reqId .~ rid $ e0
let rid = getRequestId defaultRequestIdHeaderName r
e = reqId .~ rid $ e0
Servant.serveWithContext
(Proxy @CombinedAPI)
( view (options . settings . federationDomain) e
Expand Down
7 changes: 4 additions & 3 deletions services/gundeck/src/Gundeck/Run.hs
Original file line number Diff line number Diff line change
Expand Up @@ -89,14 +89,15 @@ run o = do
. waiPrometheusMiddleware sitemap
. GZip.gunzip
. GZip.gzip GZip.def
. catchErrors (e ^. applog) [Right $ e ^. monitor]
. catchErrors (e ^. applog) defaultRequestIdHeaderName [Right $ e ^. monitor]
. requestIdMiddleware (e ^. applog) defaultRequestIdHeaderName

type CombinedAPI = GundeckAPI :<|> Servant.Raw

mkApp :: Env -> Wai.Application
mkApp env0 req cont = do
rid <- getRequestId (env0 ^. applog) req
let env = reqId .~ rid $ env0
let rid = getRequestId defaultRequestIdHeaderName req
env = reqId .~ rid $ env0
Servant.serve
(Proxy @CombinedAPI)
(servantSitemap' env :<|> Servant.Tagged (runGundeckWithRoutes env))
Expand Down
3 changes: 2 additions & 1 deletion services/proxy/src/Proxy/Run.hs
Original file line number Diff line number Diff line change
Expand Up @@ -45,5 +45,6 @@ run o = do
versionMiddleware (foldMap expandVersionExp (o ^. disabledAPIVersions))
. waiPrometheusMiddleware (sitemap e)
. GZip.gunzip
. catchErrors (e ^. applog) [Right m]
. catchErrors (e ^. applog) defaultRequestIdHeaderName [Right m]
. requestIdMiddleware (e ^. applog) defaultRequestIdHeaderName
runSettingsWithShutdown s (middleware app) Nothing `finally` destroyEnv e
Loading

0 comments on commit 60abca1

Please sign in to comment.