From 17fc8db2091f683d46346b23fb71dd748fd621a5 Mon Sep 17 00:00:00 2001 From: Edsko de Vries Date: Thu, 25 Jul 2024 11:27:00 +0200 Subject: [PATCH 01/10] Remove XIO leftover --- src/Network/GRPC/Internal/XIO.hs | 8 -------- 1 file changed, 8 deletions(-) delete mode 100644 src/Network/GRPC/Internal/XIO.hs diff --git a/src/Network/GRPC/Internal/XIO.hs b/src/Network/GRPC/Internal/XIO.hs deleted file mode 100644 index 498a9acb..00000000 --- a/src/Network/GRPC/Internal/XIO.hs +++ /dev/null @@ -1,8 +0,0 @@ --- | Temporary re-export of @Control.Monad.XIO@ --- --- TODO: Eventually @Control.Monad.XIO@ should move to its own package. -module Network.GRPC.Internal.XIO ( - module Control.Monad.XIO - ) where - -import Control.Monad.XIO From 8122cef5060b396433263f0db5da51cbddb36689 Mon Sep 17 00:00:00 2001 From: Edsko de Vries Date: Thu, 25 Jul 2024 08:28:18 +0200 Subject: [PATCH 02/10] Generalize the HKD infra slightly --- src/Network/GRPC/Spec/Headers/Request.hs | 25 ++++++++------- src/Network/GRPC/Spec/Headers/Response.hs | 35 ++++++++++---------- util/Network/GRPC/Util/HKD.hs | 39 +++++++++++++++++++++-- 3 files changed, 69 insertions(+), 30 deletions(-) diff --git a/src/Network/GRPC/Spec/Headers/Request.hs b/src/Network/GRPC/Spec/Headers/Request.hs index fe9505fe..bf370a0a 100644 --- a/src/Network/GRPC/Spec/Headers/Request.hs +++ b/src/Network/GRPC/Spec/Headers/Request.hs @@ -118,6 +118,7 @@ data RequestHeaders_ f = RequestHeaders { -- | Unrecognized headers , requestUnrecognized :: HKD f () } + deriving anyclass (HKD.Coerce) -- | Request headers (without allowing for invalid headers) -- @@ -153,19 +154,19 @@ deriving stock instance Eq RequestHeaders' -- about the instance for RequestHeaders for some reason. instance HKD.Traversable RequestHeaders_ where - sequence x = + traverse f x = RequestHeaders - <$> requestTimeout x - <*> requestCompression x - <*> requestAcceptCompression x - <*> requestContentType x - <*> requestMessageType x - <*> requestUserAgent x - <*> requestIncludeTE x - <*> requestTraceContext x - <*> requestPreviousRpcAttempts x - <*> pure (requestMetadata x) - <*> requestUnrecognized x + <$> (f $ requestTimeout x) + <*> (f $ requestCompression x) + <*> (f $ requestAcceptCompression x) + <*> (f $ requestContentType x) + <*> (f $ requestMessageType x) + <*> (f $ requestUserAgent x) + <*> (f $ requestIncludeTE x) + <*> (f $ requestTraceContext x) + <*> (f $ requestPreviousRpcAttempts x) + <*> (pure $ requestMetadata x) + <*> (f $ requestUnrecognized x) {------------------------------------------------------------------------------- Construction diff --git a/src/Network/GRPC/Spec/Headers/Response.hs b/src/Network/GRPC/Spec/Headers/Response.hs index e24ccba1..baa725bd 100644 --- a/src/Network/GRPC/Spec/Headers/Response.hs +++ b/src/Network/GRPC/Spec/Headers/Response.hs @@ -108,6 +108,7 @@ data ResponseHeaders_ f = ResponseHeaders { -- | Unrecognized headers , responseUnrecognized :: HKD f () } + deriving anyclass (HKD.Coerce) -- | Response headers (without allowing for invalid headers) -- @@ -127,13 +128,13 @@ deriving stock instance Show ResponseHeaders' deriving stock instance Eq ResponseHeaders' instance HKD.Traversable ResponseHeaders_ where - sequence x = + traverse f x = ResponseHeaders - <$> responseCompression x - <*> responseAcceptCompression x - <*> responseContentType x - <*> pure (responseMetadata x) - <*> responseUnrecognized x + <$> (f $ responseCompression x) + <*> (f $ responseAcceptCompression x) + <*> (f $ responseContentType x) + <*> (pure $ responseMetadata x) + <*> (f $ responseUnrecognized x) -- | Information sent by the peer after the final output -- @@ -170,6 +171,7 @@ data ProperTrailers_ f = ProperTrailers { -- | Unrecognized trailers , properTrailersUnrecognized :: HKD f () } + deriving anyclass (HKD.Coerce) -- | Default constructor for 'ProperTrailers' simpleProperTrailers :: forall f. @@ -201,14 +203,14 @@ deriving stock instance Show ProperTrailers' deriving stock instance Eq ProperTrailers' instance HKD.Traversable ProperTrailers_ where - sequence x = + traverse f x = ProperTrailers - <$> properTrailersGrpcStatus x - <*> properTrailersGrpcMessage x - <*> properTrailersPushback x - <*> properTrailersOrcaLoadReport x - <*> pure (properTrailersMetadata x) - <*> properTrailersUnrecognized x + <$> (f $ properTrailersGrpcStatus x) + <*> (f $ properTrailersGrpcMessage x) + <*> (f $ properTrailersPushback x) + <*> (f $ properTrailersOrcaLoadReport x) + <*> (pure $ properTrailersMetadata x) + <*> (f $ properTrailersUnrecognized x) -- | Trailers sent in the gRPC Trailers-Only case -- @@ -222,6 +224,7 @@ data TrailersOnly_ f = TrailersOnly { -- | All regular trailers can also appear in the Trailers-Only case , trailersOnlyProper :: ProperTrailers_ f } + deriving anyclass (HKD.Coerce) -- | Trailers for the Trailers-Only case (without allowing for invalid trailers) type TrailersOnly = TrailersOnly_ Undecorated @@ -237,10 +240,10 @@ deriving stock instance Show TrailersOnly' deriving stock instance Eq TrailersOnly' instance HKD.Traversable TrailersOnly_ where - sequence x = + traverse f x = TrailersOnly - <$> trailersOnlyContentType x - <*> HKD.sequence (trailersOnlyProper x) + <$> (f $ trailersOnlyContentType x) + <*> (HKD.traverse f $ trailersOnlyProper x) -- | 'ProperTrailers' is a subset of 'TrailersOnly' properTrailersToTrailersOnly :: diff --git a/util/Network/GRPC/Util/HKD.hs b/util/Network/GRPC/Util/HKD.hs index 54bd6c3e..3ea903a1 100644 --- a/util/Network/GRPC/Util/HKD.hs +++ b/util/Network/GRPC/Util/HKD.hs @@ -10,7 +10,9 @@ module Network.GRPC.Util.HKD ( , DecoratedWith , Undecorated -- * Dealing with HKD records + , Coerce(..) , Traversable(..) + , sequence , sequenceThrow -- * Dealing with HKD fields , ValidDecoration @@ -21,8 +23,10 @@ import Prelude hiding (Traversable(..), pure) import Prelude qualified import Control.Monad.Except (MonadError, throwError) +import Data.Functor.Identity import Data.Kind import Data.Proxy +import Unsafe.Coerce (unsafeCoerce) {------------------------------------------------------------------------------- Definition @@ -40,8 +44,39 @@ type instance HKD (DecoratedWith f) x = f x Dealing with HKD records -------------------------------------------------------------------------------} -class Traversable t where - sequence :: Applicative f => t (DecoratedWith f) -> f (t Undecorated) +class Coerce t where + -- | Drop decoration + -- + -- /NOTE/: The default instance is valid only for datatypes that are morally + -- have a "higher order representative role"; that is, the type of every field + -- of @t (DecoratedWith Identity)@ must be representationally equal to the + -- corresponding type of @t Undecorated@. In the typical case of + -- + -- > data SomeRecord f = MkSomeRecord { + -- > field1 :: HKD f a1 + -- > , field2 :: HKD f a2 + -- > .. + -- > , fieldN :: aN + -- > , .. + -- > , fieldM :: HKD f aM + -- > } + -- + -- where every field either has type @HKD f a@ or @a@ (not mentioning @f@ at + -- all), this will automatically be the case. + undecorate :: t (DecoratedWith Identity) -> t Undecorated + undecorate = unsafeCoerce + +class Coerce t => Traversable t where + traverse :: + Applicative m + => (forall a. f a -> m (g a)) + -> t (DecoratedWith f) + -> m (t (DecoratedWith g)) + +sequence :: + (Traversable t, Applicative m) + => t (DecoratedWith m) -> m (t Undecorated) +sequence = fmap undecorate . traverse (fmap Identity) sequenceThrow :: (MonadError e m, Traversable t) From f0479a5e5d01a0f3b614b88530aaf24ec3c2ef75 Mon Sep 17 00:00:00 2001 From: Edsko de Vries Date: Wed, 24 Jul 2024 13:13:15 +0200 Subject: [PATCH 03/10] Refactor `processResponseHeaders` slightly --- src/Network/GRPC/Client/Session.hs | 40 +++++++++++++++++++++++------- 1 file changed, 31 insertions(+), 9 deletions(-) diff --git a/src/Network/GRPC/Client/Session.hs b/src/Network/GRPC/Client/Session.hs index 4e0e2a25..8ba7e24b 100644 --- a/src/Network/GRPC/Client/Session.hs +++ b/src/Network/GRPC/Client/Session.hs @@ -117,6 +117,29 @@ instance SupportsClientRpc rpc => InitiateSession (ClientSession rpc) where instance NoTrailers (ClientSession rpc) where noTrailers _ = NoMetadata +{------------------------------------------------------------------------------- + Process response headers +-------------------------------------------------------------------------------} + +data RequiredHeaders = RequiredHeaders { + requiredCompression :: Maybe CompressionId + } + +-- | Validate /all/ headers, and then extract the required +validateAll :: ResponseHeaders' -> Either InvalidHeaders RequiredHeaders +validateAll = fmap go . HKD.sequence + where + go :: ResponseHeaders -> RequiredHeaders + go responseHeaders = RequiredHeaders { + requiredCompression = responseCompression responseHeaders + } + +-- | Validate only the required headers +validateRequired :: ResponseHeaders' -> Either InvalidHeaders RequiredHeaders +validateRequired responseHeaders' = + RequiredHeaders + <$> responseCompression responseHeaders' + -- | Process response headers -- -- This is the client equivalent of @@ -127,15 +150,11 @@ processResponseHeaders :: -> IO Compression processResponseHeaders (ClientSession conn) responseHeaders' = do Connection.updateConnectionMeta conn responseHeaders' - - if connVerifyHeaders connParams then - case HKD.sequence responseHeaders' of - Left err -> throwIO $ CallSetupInvalidResponseHeaders err - Right hdrs -> getCompression $ responseCompression hdrs - else - case responseCompression responseHeaders' of - Left err -> throwIO $ CallSetupInvalidResponseHeaders err - Right mcid -> getCompression mcid + required <- either invalid return $ + if connVerifyHeaders connParams + then validateAll responseHeaders' + else validateRequired responseHeaders' + getCompression (requiredCompression required) where connParams :: ConnParams connParams = Connection.connParams conn @@ -147,6 +166,9 @@ processResponseHeaders (ClientSession conn) responseHeaders' = do Just compr -> return compr Nothing -> throwIO $ CallSetupUnsupportedCompression cid + invalid :: forall x. InvalidHeaders -> IO x + invalid = throwIO . CallSetupInvalidResponseHeaders + {------------------------------------------------------------------------------- Exceptions -------------------------------------------------------------------------------} From f483c0e0f397c4ef45c42932c105591e4d4f8c5f Mon Sep 17 00:00:00 2001 From: Edsko de Vries Date: Wed, 24 Jul 2024 13:45:17 +0200 Subject: [PATCH 04/10] Move `GrpcException` to `.Status` --- src/Network/GRPC/Spec.hs | 11 +++++---- src/Network/GRPC/Spec/Headers/Response.hs | 25 ++------------------ src/Network/GRPC/Spec/Status.hs | 28 +++++++++++++++++++++++ 3 files changed, 36 insertions(+), 28 deletions(-) diff --git a/src/Network/GRPC/Spec.hs b/src/Network/GRPC/Spec.hs index 2d997486..b860f5d5 100644 --- a/src/Network/GRPC/Spec.hs +++ b/src/Network/GRPC/Spec.hs @@ -122,6 +122,10 @@ module Network.GRPC.Spec ( , TrailersOnly' , Pushback(..) , simpleProperTrailers + -- ** Termination + , GrpcNormalTermination(..) + , grpcExceptionToTrailers + , grpcClassifyTermination -- ** Serialization , parseProperTrailers , parseProperTrailers' @@ -133,17 +137,14 @@ module Network.GRPC.Spec ( , buildPushback , properTrailersToTrailersOnly , trailersOnlyToProperTrailers - -- *** Status + -- * Status , GrpcStatus(..) , GrpcError(..) , fromGrpcStatus , toGrpcStatus - -- *** gRPC termination + -- ** Exceptions , GrpcException(..) , throwGrpcError - , GrpcNormalTermination(..) - , grpcExceptionToTrailers - , grpcClassifyTermination -- * Metadata , CustomMetadata(CustomMetadata) , customMetadataName diff --git a/src/Network/GRPC/Spec/Headers/Response.hs b/src/Network/GRPC/Spec/Headers/Response.hs index baa725bd..b806bb8c 100644 --- a/src/Network/GRPC/Spec/Headers/Response.hs +++ b/src/Network/GRPC/Spec/Headers/Response.hs @@ -22,9 +22,7 @@ module Network.GRPC.Spec.Headers.Response ( , trailersOnlyToProperTrailers , properTrailersToTrailersOnly , classifyServerResponse - -- * gRPC termination - , GrpcException(..) - , throwGrpcError + -- * Termination , GrpcNormalTermination(..) , grpcClassifyTermination , grpcExceptionToTrailers @@ -427,28 +425,9 @@ parsePushback bs = return DoNotRetry {------------------------------------------------------------------------------- - gRPC exceptions + Termination -------------------------------------------------------------------------------} --- | Server indicated a gRPC error --- --- For the common case where you just want to set 'grpcError', you can use --- 'throwGrpcError'. -data GrpcException = GrpcException { - grpcError :: GrpcError - , grpcErrorMessage :: Maybe Text - , grpcErrorMetadata :: [CustomMetadata] - } - deriving stock (Show) - deriving anyclass (Exception) - -throwGrpcError :: GrpcError -> IO a -throwGrpcError grpcError = throwIO $ GrpcException { - grpcError - , grpcErrorMessage = Nothing - , grpcErrorMetadata = [] - } - -- | Server indicated normal termination -- -- This is only an exception if the client tries to send any further messages. diff --git a/src/Network/GRPC/Spec/Status.hs b/src/Network/GRPC/Spec/Status.hs index 41e59a41..22f7cbe0 100644 --- a/src/Network/GRPC/Spec/Status.hs +++ b/src/Network/GRPC/Spec/Status.hs @@ -4,11 +4,17 @@ module Network.GRPC.Spec.Status ( , GrpcError(..) , fromGrpcStatus , toGrpcStatus + -- * Exceptions + , GrpcException(..) + , throwGrpcError ) where import Control.Exception +import Data.Text (Text) import GHC.Generics (Generic) +import Network.GRPC.Spec.CustomMetadata.Raw (CustomMetadata) + {------------------------------------------------------------------------------- gRPC status -------------------------------------------------------------------------------} @@ -209,3 +215,25 @@ toGrpcStatus 15 = Just $ GrpcError $ GrpcDataLoss toGrpcStatus 16 = Just $ GrpcError $ GrpcUnauthenticated toGrpcStatus _ = Nothing +{------------------------------------------------------------------------------- + gRPC exceptions +-------------------------------------------------------------------------------} + +-- | Server indicated a gRPC error +-- +-- For the common case where you just want to set 'grpcError', you can use +-- 'throwGrpcError'. +data GrpcException = GrpcException { + grpcError :: GrpcError + , grpcErrorMessage :: Maybe Text + , grpcErrorMetadata :: [CustomMetadata] + } + deriving stock (Show) + deriving anyclass (Exception) + +throwGrpcError :: GrpcError -> IO a +throwGrpcError grpcError = throwIO $ GrpcException { + grpcError + , grpcErrorMessage = Nothing + , grpcErrorMetadata = [] + } From e7796dd8ff39ad06b2b829168c04b45584a1a011 Mon Sep 17 00:00:00 2001 From: Edsko de Vries Date: Wed, 24 Jul 2024 14:03:37 +0200 Subject: [PATCH 05/10] Move all serialization code to separate modules This avoids some cyclic module dependencies. --- grapesy.cabal | 14 +- interop/Interop/Client/Common.hs | 7 +- .../TestCase/ClientCompressedStreaming.hs | 13 +- .../Client/TestCase/ClientCompressedUnary.hs | 14 +- interop/Interop/Client/TestCase/EmptyUnary.hs | 9 +- .../TestCase/ServerCompressedStreaming.hs | 5 +- .../Client/TestCase/ServerCompressedUnary.hs | 17 +- .../Client/TestCase/ServerStreaming.hs | 2 +- .../Client/TestCase/SpecialStatusMessage.hs | 4 +- .../Client/TestCase/StatusCodeAndMessage.hs | 4 +- interop/Interop/Server/Common.hs | 10 +- .../Server/TestService/StreamingInputCall.hs | 9 +- .../Server/TestService/StreamingOutputCall.hs | 7 +- .../Interop/Server/TestService/UnaryCall.hs | 13 +- src/Network/GRPC/Client.hs | 13 +- src/Network/GRPC/Client/Call.hs | 36 +- src/Network/GRPC/Client/Session.hs | 5 +- src/Network/GRPC/Common.hs | 7 +- src/Network/GRPC/Server.hs | 6 +- src/Network/GRPC/Server/Call.hs | 33 +- src/Network/GRPC/Server/RequestHandler.hs | 1 + src/Network/GRPC/Server/Session.hs | 5 +- src/Network/GRPC/Server/StreamType.hs | 2 +- src/Network/GRPC/Spec.hs | 56 +- src/Network/GRPC/Spec/CustomMetadata/Raw.hs | 110 +--- src/Network/GRPC/Spec/Headers/Common.hs | 204 +------ .../GRPC/Spec/Headers/PseudoHeaders.hs | 59 +- src/Network/GRPC/Spec/Headers/Request.hs | 280 +--------- src/Network/GRPC/Spec/Headers/Response.hs | 493 ----------------- src/Network/GRPC/Spec/MessageMeta.hs | 40 ++ src/Network/GRPC/Spec/Serialization.hs | 61 +++ .../GRPC/Spec/{ => Serialization}/Base64.hs | 2 +- .../GRPC/Spec/Serialization/CustomMetadata.hs | 134 +++++ .../GRPC/Spec/Serialization/Headers/Common.hs | 209 +++++++ .../Serialization/Headers/PseudoHeaders.hs | 61 +++ .../Spec/Serialization/Headers/Request.hs | 275 ++++++++++ .../Spec/Serialization/Headers/Response.hs | 513 ++++++++++++++++++ .../{ => Serialization}/LengthPrefixed.hs | 63 +-- src/Network/GRPC/Spec/Serialization/Status.hs | 50 ++ .../GRPC/Spec/Serialization/Timeout.hs | 76 +++ .../GRPC/Spec/Serialization/TraceContext.hs | 128 +++++ src/Network/GRPC/Spec/Status.hs | 41 -- src/Network/GRPC/Spec/Timeout.hs | 72 --- src/Network/GRPC/Spec/TraceContext.hs | 119 ---- test-grapesy/Test/Driver/ClientServer.hs | 5 +- test-grapesy/Test/Prop/Serialization.hs | 1 + test-grapesy/Test/Sanity/Interop.hs | 37 +- .../Test/Sanity/StreamingType/NonStreaming.hs | 2 +- 48 files changed, 1740 insertions(+), 1587 deletions(-) create mode 100644 src/Network/GRPC/Spec/MessageMeta.hs create mode 100644 src/Network/GRPC/Spec/Serialization.hs rename src/Network/GRPC/Spec/{ => Serialization}/Base64.hs (98%) create mode 100644 src/Network/GRPC/Spec/Serialization/CustomMetadata.hs create mode 100644 src/Network/GRPC/Spec/Serialization/Headers/Common.hs create mode 100644 src/Network/GRPC/Spec/Serialization/Headers/PseudoHeaders.hs create mode 100644 src/Network/GRPC/Spec/Serialization/Headers/Request.hs create mode 100644 src/Network/GRPC/Spec/Serialization/Headers/Response.hs rename src/Network/GRPC/Spec/{ => Serialization}/LengthPrefixed.hs (76%) create mode 100644 src/Network/GRPC/Spec/Serialization/Status.hs create mode 100644 src/Network/GRPC/Spec/Serialization/Timeout.hs create mode 100644 src/Network/GRPC/Spec/Serialization/TraceContext.hs diff --git a/grapesy.cabal b/grapesy.cabal index 8600a8d3..2ba2b4bf 100644 --- a/grapesy.cabal +++ b/grapesy.cabal @@ -119,6 +119,7 @@ library Network.GRPC.Server.StreamType Network.GRPC.Server.StreamType.Binary Network.GRPC.Spec + Network.GRPC.Spec.Serialization other-modules: Network.GRPC.Client.Call Network.GRPC.Client.Connection @@ -132,7 +133,6 @@ library Network.GRPC.Server.RequestHandler Network.GRPC.Server.RequestHandler.API Network.GRPC.Server.Session - Network.GRPC.Spec.Base64 Network.GRPC.Spec.Call Network.GRPC.Spec.Compression Network.GRPC.Spec.CustomMetadata.Map @@ -144,7 +144,7 @@ library Network.GRPC.Spec.Headers.PseudoHeaders Network.GRPC.Spec.Headers.Request Network.GRPC.Spec.Headers.Response - Network.GRPC.Spec.LengthPrefixed + Network.GRPC.Spec.MessageMeta Network.GRPC.Spec.OrcaLoadReport Network.GRPC.Spec.PercentEncoding Network.GRPC.Spec.RPC @@ -153,6 +153,16 @@ library Network.GRPC.Spec.RPC.Raw Network.GRPC.Spec.RPC.StreamType Network.GRPC.Spec.RPC.Unknown + Network.GRPC.Spec.Serialization.Base64 + Network.GRPC.Spec.Serialization.CustomMetadata + Network.GRPC.Spec.Serialization.Headers.Common + Network.GRPC.Spec.Serialization.Headers.PseudoHeaders + Network.GRPC.Spec.Serialization.Headers.Request + Network.GRPC.Spec.Serialization.Headers.Response + Network.GRPC.Spec.Serialization.LengthPrefixed + Network.GRPC.Spec.Serialization.Status + Network.GRPC.Spec.Serialization.Timeout + Network.GRPC.Spec.Serialization.TraceContext Network.GRPC.Spec.Status Network.GRPC.Spec.Timeout Network.GRPC.Spec.TraceContext diff --git a/interop/Interop/Client/Common.hs b/interop/Interop/Client/Common.hs index a6b4d856..94af891f 100644 --- a/interop/Interop/Client/Common.hs +++ b/interop/Interop/Client/Common.hs @@ -20,7 +20,6 @@ import Network.GRPC.Client import Network.GRPC.Common import Network.GRPC.Common.Compression qualified as Compr import Network.GRPC.Common.Protobuf -import Network.GRPC.Spec import Interop.Util.Exceptions import Interop.Util.Messages @@ -110,13 +109,13 @@ verifyStreamingOutputs :: forall rpc. HasCallStack => Call rpc -> (ProperTrailers' -> IO ()) -- ^ Verify trailers - -> [(InboundEnvelope, Output rpc) -> IO ()] -- ^ Verifier per expected output + -> [(InboundMeta, Output rpc) -> IO ()] -- ^ Verifier per expected output -> IO () verifyStreamingOutputs call verifyTrailers = go where - go :: [(InboundEnvelope, Output rpc) -> IO ()] -> IO () + go :: [(InboundMeta, Output rpc) -> IO ()] -> IO () go verifiers = do - mResp <- recvOutputWithEnvelope call + mResp <- recvOutputWithMeta call case (mResp, verifiers) of (NoMoreElems trailers, []) -> verifyTrailers trailers (StreamElem{}, []) -> assertFailure "Too many outputs" diff --git a/interop/Interop/Client/TestCase/ClientCompressedStreaming.hs b/interop/Interop/Client/TestCase/ClientCompressedStreaming.hs index fec43731..9a10ff89 100644 --- a/interop/Interop/Client/TestCase/ClientCompressedStreaming.hs +++ b/interop/Interop/Client/TestCase/ClientCompressedStreaming.hs @@ -3,7 +3,6 @@ module Interop.Client.TestCase.ClientCompressedStreaming (runTest) where import Network.GRPC.Client import Network.GRPC.Common import Network.GRPC.Common.Protobuf -import Network.GRPC.Spec import Interop.Client.Common import Interop.Client.Connect @@ -20,21 +19,21 @@ runTest cmdline = checkServerSupportsCompressedRequest conn withRPC conn def (Proxy @StreamingInputCall) $ \call -> do - sendInputWithEnvelope call $ StreamElem compressed - sendInputWithEnvelope call $ FinalElem uncompressed NoMetadata + sendInputWithMeta call $ StreamElem compressed + sendInputWithMeta call $ FinalElem uncompressed NoMetadata (resp, _metadata) <- recvFinalOutput call assertEqual 73086 $ resp ^. #aggregatedPayloadSize where -- Expect compressed, and /is/ compressed - compressed :: (OutboundEnvelope, Proto StreamingInputCallRequest) + compressed :: (OutboundMeta, Proto StreamingInputCallRequest) compressed = ( def { outboundEnableCompression = True } , mkStreamingInputCallRequest True 27182 ) -- Expect uncompressed, and /is/ uncompressed - uncompressed :: (OutboundEnvelope, Proto StreamingInputCallRequest) + uncompressed :: (OutboundMeta, Proto StreamingInputCallRequest) uncompressed = ( def { outboundEnableCompression = False } , mkStreamingInputCallRequest False 45904 @@ -51,11 +50,11 @@ runTest cmdline = checkServerSupportsCompressedRequest :: Connection -> IO () checkServerSupportsCompressedRequest conn = withRPC conn def (Proxy @StreamingInputCall) $ \call -> do - sendInputWithEnvelope call $ FinalElem featureProbe NoMetadata + sendInputWithMeta call $ FinalElem featureProbe NoMetadata expectInvalidArgument $ recvFinalOutput call where -- Expect compressed, but is /not/ actually compressed - featureProbe :: (OutboundEnvelope, Proto StreamingInputCallRequest) + featureProbe :: (OutboundMeta, Proto StreamingInputCallRequest) featureProbe = ( def { outboundEnableCompression = False } , mkStreamingInputCallRequest True 27182 diff --git a/interop/Interop/Client/TestCase/ClientCompressedUnary.hs b/interop/Interop/Client/TestCase/ClientCompressedUnary.hs index 26b394bb..d85f711d 100644 --- a/interop/Interop/Client/TestCase/ClientCompressedUnary.hs +++ b/interop/Interop/Client/TestCase/ClientCompressedUnary.hs @@ -4,7 +4,7 @@ module Interop.Client.TestCase.ClientCompressedUnary (runTest) where import Network.GRPC.Client import Network.GRPC.Common -import Network.GRPC.Spec +import Network.GRPC.Common.Protobuf import Interop.Client.Common import Interop.Client.Connect @@ -21,25 +21,25 @@ runTest cmdline = -- 2. Call UnaryCall with the compressed message withRPC conn def (Proxy @UnaryCall) $ \call -> do - sendInputWithEnvelope call $ FinalElem compressed NoMetadata + sendInputWithMeta call $ FinalElem compressed NoMetadata (resp, _metadata) <- recvFinalOutput call verifySimpleResponse resp -- 3. Call UnaryCall with the uncompressed message withRPC conn def (Proxy @UnaryCall) $ \call -> do - sendInputWithEnvelope call $ FinalElem uncompressed NoMetadata + sendInputWithMeta call $ FinalElem uncompressed NoMetadata (resp, _metadata) <- recvFinalOutput call verifySimpleResponse resp where -- Expect compressed, and /is/ compressed - compressed :: (OutboundEnvelope, Proto SimpleRequest) + compressed :: (OutboundMeta, Proto SimpleRequest) compressed = ( def { outboundEnableCompression = True } , mkSimpleRequest True ) -- Expect uncompressed, and /is/ uncompressed - uncompressed :: (OutboundEnvelope, Proto SimpleRequest) + uncompressed :: (OutboundMeta, Proto SimpleRequest) uncompressed = ( def { outboundEnableCompression = False } , mkSimpleRequest False @@ -51,11 +51,11 @@ runTest cmdline = checkServerSupportsCompressedRequest :: Connection -> IO () checkServerSupportsCompressedRequest conn = withRPC conn def (Proxy @UnaryCall) $ \call -> do - sendInputWithEnvelope call $ FinalElem featureProbe NoMetadata + sendInputWithMeta call $ FinalElem featureProbe NoMetadata expectInvalidArgument $ recvFinalOutput call where -- Expect compressed, but is /not/ actually compressed - featureProbe :: (OutboundEnvelope, Proto SimpleRequest) + featureProbe :: (OutboundMeta, Proto SimpleRequest) featureProbe = ( def { outboundEnableCompression = False } , mkSimpleRequest True diff --git a/interop/Interop/Client/TestCase/EmptyUnary.hs b/interop/Interop/Client/TestCase/EmptyUnary.hs index 724f48b1..6a8aa7ce 100644 --- a/interop/Interop/Client/TestCase/EmptyUnary.hs +++ b/interop/Interop/Client/TestCase/EmptyUnary.hs @@ -6,7 +6,6 @@ import Network.GRPC.Client import Network.GRPC.Common import Network.GRPC.Common.Protobuf import Network.GRPC.Common.StreamElem qualified as StreamElem -import Network.GRPC.Spec import Interop.Client.Connect import Interop.Cmdline @@ -20,8 +19,8 @@ runTest cmdline = withConnection def (testServer cmdline) $ \conn -> withRPC conn def (Proxy @EmptyCall) $ \call -> do sendFinalInput call empty - streamElem :: StreamElem ProperTrailers' (InboundEnvelope, Proto Empty) - <- recvOutputWithEnvelope call + streamElem :: StreamElem ProperTrailers' (InboundMeta, Proto Empty) + <- recvOutputWithMeta call -- The test description asks us to also verify the size of the /outgoing/ -- message if possible. This information is not readily available in @@ -29,9 +28,9 @@ runTest cmdline = -- interop client against the @grapesy@ interop server. case StreamElem.value streamElem of - Just (envelope, resp) -> do + Just (meta, resp) -> do assertEqual empty $ resp - assertEqual 0 $ inboundUncompressedSize envelope + assertEqual 0 $ inboundUncompressedSize meta Nothing -> assertFailure "Expected response" where diff --git a/interop/Interop/Client/TestCase/ServerCompressedStreaming.hs b/interop/Interop/Client/TestCase/ServerCompressedStreaming.hs index 6e0f5512..29c9f7ea 100644 --- a/interop/Interop/Client/TestCase/ServerCompressedStreaming.hs +++ b/interop/Interop/Client/TestCase/ServerCompressedStreaming.hs @@ -4,7 +4,6 @@ import Data.Maybe (isJust) import Network.GRPC.Client import Network.GRPC.Common -import Network.GRPC.Spec import Interop.Client.Common import Interop.Client.Connect @@ -20,8 +19,8 @@ runTest cmdline = do withRPC conn def (Proxy @StreamingOutputCall) $ \call -> do sendFinalInput call $ mkStreamingOutputCallRequest expected Nothing verifyStreamingOutputs call (\_ -> return ()) $ [ - \(envelope, resp) -> do - assertEqual compressed $ isJust (inboundCompressedSize envelope) + \(meta, resp) -> do + assertEqual compressed $ isJust (inboundCompressedSize meta) verifyStreamingOutputCallResponse sz resp | (compressed, sz) <- expected ] diff --git a/interop/Interop/Client/TestCase/ServerCompressedUnary.hs b/interop/Interop/Client/TestCase/ServerCompressedUnary.hs index 0fee743e..c901dc24 100644 --- a/interop/Interop/Client/TestCase/ServerCompressedUnary.hs +++ b/interop/Interop/Client/TestCase/ServerCompressedUnary.hs @@ -6,7 +6,6 @@ import Network.GRPC.Client import Network.GRPC.Common import Network.GRPC.Common.Protobuf import Network.GRPC.Common.StreamElem qualified as StreamElem -import Network.GRPC.Spec import Interop.Client.Common import Interop.Client.Connect @@ -21,18 +20,18 @@ runTest :: Cmdline -> IO () runTest cmdline = withConnection def (testServer cmdline) $ \conn -> do withRPC conn def (Proxy @UnaryCall) $ \call -> do - sendInputWithEnvelope call $ FinalElem (request True) NoMetadata - resp <- recvOutputWithEnvelope call + sendInputWithMeta call $ FinalElem (request True) NoMetadata + resp <- recvOutputWithMeta call verifyResponse True (StreamElem.value resp) withRPC conn def (Proxy @UnaryCall) $ \call -> do - sendInputWithEnvelope call $ FinalElem (request False) NoMetadata - resp <- recvOutputWithEnvelope call + sendInputWithMeta call $ FinalElem (request False) NoMetadata + resp <- recvOutputWithMeta call verifyResponse False (StreamElem.value resp) where -- To keep the test simple, we disable /outbound/ compression -- (this test is testing /inbound/ compression) - request :: Bool -> (OutboundEnvelope, Proto SimpleRequest) + request :: Bool -> (OutboundMeta, Proto SimpleRequest) request expectCompressed = ( def { outboundEnableCompression = False } , mkSimpleRequest False @@ -41,10 +40,10 @@ runTest cmdline = verifyResponse :: HasCallStack - => Bool -> Maybe (InboundEnvelope, Proto SimpleResponse) -> IO () + => Bool -> Maybe (InboundMeta, Proto SimpleResponse) -> IO () verifyResponse _expectCompressed Nothing = assertFailure "Expected response" -verifyResponse expectCompressed (Just (envelope, resp)) = do - assertEqual expectCompressed $ isJust (inboundCompressedSize envelope) +verifyResponse expectCompressed (Just (meta, resp)) = do + assertEqual expectCompressed $ isJust (inboundCompressedSize meta) verifySimpleResponse resp diff --git a/interop/Interop/Client/TestCase/ServerStreaming.hs b/interop/Interop/Client/TestCase/ServerStreaming.hs index b37c6c3d..53321e31 100644 --- a/interop/Interop/Client/TestCase/ServerStreaming.hs +++ b/interop/Interop/Client/TestCase/ServerStreaming.hs @@ -16,7 +16,7 @@ runTest cmdline = withRPC conn def (Proxy @StreamingOutputCall) $ \call -> do sendFinalInput call $ mkStreamingOutputCallRequest expected Nothing verifyStreamingOutputs call (\_ -> return ()) $ [ - \(_envelope, resp) -> verifyStreamingOutputCallResponse sz resp + \(_meta, resp) -> verifyStreamingOutputCallResponse sz resp | (_compressed, sz) <- expected ] where diff --git a/interop/Interop/Client/TestCase/SpecialStatusMessage.hs b/interop/Interop/Client/TestCase/SpecialStatusMessage.hs index c9e03824..e54f76c4 100644 --- a/interop/Interop/Client/TestCase/SpecialStatusMessage.hs +++ b/interop/Interop/Client/TestCase/SpecialStatusMessage.hs @@ -7,7 +7,7 @@ import Data.Text (Text) import Network.GRPC.Client import Network.GRPC.Common import Network.GRPC.Common.Protobuf -import Network.GRPC.Spec +import Network.GRPC.Spec.Serialization (buildGrpcStatus) import Interop.Client.Connect import Interop.Cmdline @@ -30,7 +30,7 @@ runTest cmdline = do echoStatus :: Proto EchoStatus echoStatus = defMessage - & #code .~ fromIntegral (fromGrpcStatus $ GrpcError GrpcUnknown) + & #code .~ fromIntegral (buildGrpcStatus $ GrpcError GrpcUnknown) & #message .~ statusMessage statusMessage :: Text diff --git a/interop/Interop/Client/TestCase/StatusCodeAndMessage.hs b/interop/Interop/Client/TestCase/StatusCodeAndMessage.hs index ae9e05e3..77e3daac 100644 --- a/interop/Interop/Client/TestCase/StatusCodeAndMessage.hs +++ b/interop/Interop/Client/TestCase/StatusCodeAndMessage.hs @@ -7,7 +7,7 @@ import Data.Text (Text) import Network.GRPC.Client import Network.GRPC.Common import Network.GRPC.Common.Protobuf -import Network.GRPC.Spec +import Network.GRPC.Spec.Serialization (buildGrpcStatus) import Interop.Client.Connect import Interop.Cmdline @@ -39,7 +39,7 @@ runTest cmdline = do echoStatus :: Proto EchoStatus echoStatus = defMessage - & #code .~ fromIntegral (fromGrpcStatus $ GrpcError GrpcUnknown) + & #code .~ fromIntegral (buildGrpcStatus $ GrpcError GrpcUnknown) & #message .~ statusMessage statusMessage :: Text diff --git a/interop/Interop/Server/Common.hs b/interop/Interop/Server/Common.hs index aa89027e..9aa73127 100644 --- a/interop/Interop/Server/Common.hs +++ b/interop/Interop/Server/Common.hs @@ -12,7 +12,7 @@ import Control.Exception import Network.GRPC.Common import Network.GRPC.Common.Protobuf import Network.GRPC.Server -import Network.GRPC.Spec +import Network.GRPC.Spec.Serialization (parseGrpcStatus) import Interop.Util.Exceptions @@ -54,7 +54,7 @@ constructResponseMetadata call = do -- See echoStatus :: Proto EchoStatus -> IO () echoStatus status = - case toGrpcStatus code of + case parseGrpcStatus code of Just GrpcOk -> return () Just (GrpcError err) -> @@ -69,9 +69,9 @@ echoStatus status = code :: Word code = fromIntegral $ status ^. #code -checkInboundCompression :: Bool -> InboundEnvelope -> IO () -checkInboundCompression expectCompressed envelope = - case (expectCompressed, inboundCompressedSize envelope) of +checkInboundCompression :: Bool -> InboundMeta -> IO () +checkInboundCompression expectCompressed meta = + case (expectCompressed, inboundCompressedSize meta) of (True, Just _) -> return () (False, Nothing) -> diff --git a/interop/Interop/Server/TestService/StreamingInputCall.hs b/interop/Interop/Server/TestService/StreamingInputCall.hs index d501774a..30997469 100644 --- a/interop/Interop/Server/TestService/StreamingInputCall.hs +++ b/interop/Interop/Server/TestService/StreamingInputCall.hs @@ -7,7 +7,6 @@ import Data.ByteString qualified as BS.Strict import Network.GRPC.Common import Network.GRPC.Common.Protobuf import Network.GRPC.Server -import Network.GRPC.Spec import Interop.Server.Common @@ -29,15 +28,15 @@ handle call = do -- Returns the sum of all request payload bodies received. loop :: Int -> IO Int loop !acc = do - streamElem <- recvInputWithEnvelope call + streamElem <- recvInputWithMeta call case streamElem of StreamElem r -> handleRequest r >>= \sz -> loop (acc + sz) FinalElem r _ -> handleRequest r >>= \sz -> return $ acc + sz NoMoreElems _ -> return acc - handleRequest :: (InboundEnvelope, Proto StreamingInputCallRequest) -> IO Int - handleRequest (envelope, request) = do - checkInboundCompression expectCompressed envelope + handleRequest :: (InboundMeta, Proto StreamingInputCallRequest) -> IO Int + handleRequest (meta, request) = do + checkInboundCompression expectCompressed meta return $ BS.Strict.length (request ^. #payload ^. #body) where expectCompressed :: Bool diff --git a/interop/Interop/Server/TestService/StreamingOutputCall.hs b/interop/Interop/Server/TestService/StreamingOutputCall.hs index b468d9f2..f624b838 100644 --- a/interop/Interop/Server/TestService/StreamingOutputCall.hs +++ b/interop/Interop/Server/TestService/StreamingOutputCall.hs @@ -9,7 +9,6 @@ import Control.Monad import Network.GRPC.Common import Network.GRPC.Common.Protobuf import Network.GRPC.Server -import Network.GRPC.Spec import Interop.Util.Messages @@ -43,12 +42,12 @@ handleRequest call request = payload <- payloadOfType (Proto COMPRESSABLE) size - let envelope :: OutboundEnvelope - envelope = def { outboundEnableCompression = shouldCompress } + let meta :: OutboundMeta + meta = def { outboundEnableCompression = shouldCompress } response :: Proto StreamingOutputCallResponse response = defMessage & #payload .~ payload - sendOutputWithEnvelope call $ StreamElem (envelope, response) + sendOutputWithMeta call $ StreamElem (meta, response) threadDelay intervalUs diff --git a/interop/Interop/Server/TestService/UnaryCall.hs b/interop/Interop/Server/TestService/UnaryCall.hs index 4e7aff1e..d3f9b3d9 100644 --- a/interop/Interop/Server/TestService/UnaryCall.hs +++ b/interop/Interop/Server/TestService/UnaryCall.hs @@ -6,7 +6,6 @@ import Network.GRPC.Common import Network.GRPC.Common.Protobuf import Network.GRPC.Common.StreamElem qualified as StreamElem import Network.GRPC.Server -import Network.GRPC.Spec import Interop.Server.Common import Interop.Util.Messages @@ -25,8 +24,8 @@ handle call = do trailers <- constructResponseMetadata call -- Wait for the request - (inboundEnvelope, request) <- do - streamElem <- recvInputWithEnvelope call + (inboundMeta, request) <- do + streamElem <- recvInputWithMeta call case StreamElem.value streamElem of Nothing -> fail "Expected element" Just x -> return x @@ -35,13 +34,13 @@ handle call = do -- let expectCompressed :: Bool expectCompressed = request ^. #expectCompressed . #value - checkInboundCompression expectCompressed inboundEnvelope + checkInboundCompression expectCompressed inboundMeta -- Send response payload <- payloadOfType (request ^. #responseType) (request ^. #responseSize) - let outboundEnvelope :: OutboundEnvelope - outboundEnvelope = def { + let outboundMeta :: OutboundMeta + outboundMeta = def { outboundEnableCompression = request ^. #responseCompressed . #value } @@ -49,7 +48,7 @@ handle call = do response :: Proto SimpleResponse response = defMessage & #payload .~ payload - sendOutputWithEnvelope call $ StreamElem (outboundEnvelope, response) + sendOutputWithMeta call $ StreamElem (outboundMeta, response) -- Send status and trailers echoStatus (request ^. #responseStatus) diff --git a/src/Network/GRPC/Client.hs b/src/Network/GRPC/Client.hs index 01c31f9f..271733fc 100644 --- a/src/Network/GRPC/Client.hs +++ b/src/Network/GRPC/Client.hs @@ -57,10 +57,19 @@ module Network.GRPC.Client ( , recvAllOutputs -- ** Low-level\/specialized API - , sendInputWithEnvelope + , ResponseHeaders_(..) + , ResponseHeaders + , ResponseHeaders' + , ProperTrailers_(..) + , ProperTrailers + , ProperTrailers' + , TrailersOnly_(..) + , TrailersOnly + , TrailersOnly' , recvNextOutputElem - , recvOutputWithEnvelope , recvInitialResponse + , recvOutputWithMeta + , sendInputWithMeta -- * Communication patterns , rpc diff --git a/src/Network/GRPC/Client/Call.hs b/src/Network/GRPC/Client/Call.hs index 32ea5d6d..02fd3ef0 100644 --- a/src/Network/GRPC/Client/Call.hs +++ b/src/Network/GRPC/Client/Call.hs @@ -26,9 +26,9 @@ module Network.GRPC.Client.Call ( , recvAllOutputs -- ** Low-level\/specialized API - , sendInputWithEnvelope + , sendInputWithMeta , recvNextOutputElem - , recvOutputWithEnvelope + , recvOutputWithMeta , recvInitialResponse ) where @@ -289,19 +289,19 @@ sendInput :: => Call rpc -> StreamElem NoMetadata (Input rpc) -> m () -sendInput call = sendInputWithEnvelope call . fmap (def,) +sendInput call = sendInputWithMeta call . fmap (def,) -- | Generalization of 'sendInput', providing additional control -- --- See also 'Network.GRPC.Server.sendOutputWithEnvelope'. +-- See also 'Network.GRPC.Server.sendOutputWithMeta'. -- -- Most applications will never need to use this function. -sendInputWithEnvelope :: +sendInputWithMeta :: (HasCallStack, MonadIO m) => Call rpc - -> StreamElem NoMetadata (OutboundEnvelope, Input rpc) + -> StreamElem NoMetadata (OutboundMeta, Input rpc) -> m () -sendInputWithEnvelope Call{callChannel} msg = liftIO $ do +sendInputWithMeta Call{callChannel} msg = liftIO $ do Session.send callChannel msg -- This should be called before exiting the scope of 'withRPC'. @@ -320,7 +320,7 @@ recvOutput :: forall m rpc. => Call rpc -> m (StreamElem (ResponseTrailingMetadata rpc) (Output rpc)) recvOutput call@Call{} = liftIO $ do - streamElem <- recvOutputWithEnvelope call + streamElem <- recvOutputWithMeta call bitraverse (responseTrailingMetadata call) (return . snd) streamElem -- | Receive an output from the peer, if one exists @@ -343,12 +343,12 @@ recvNextOutputElem = -- -- Most applications will never need to use this function. -- --- See also 'Network.GRPC.Server.recvInputWithEnvelope'. -recvOutputWithEnvelope :: forall rpc m. +-- See also 'Network.GRPC.Server.recvInputWithMeta'. +recvOutputWithMeta :: forall rpc m. (MonadIO m, HasCallStack) => Call rpc - -> m (StreamElem ProperTrailers' (InboundEnvelope, Output rpc)) -recvOutputWithEnvelope = recvBoth + -> m (StreamElem ProperTrailers' (InboundMeta, Output rpc)) +recvOutputWithMeta = recvBoth -- | The initial metadata that was included in the response headers -- @@ -540,7 +540,7 @@ recvAllOutputs call processOutput = loop recvBoth :: forall m rpc. (HasCallStack, MonadIO m) => Call rpc - -> m (StreamElem ProperTrailers' (InboundEnvelope, Output rpc)) + -> m (StreamElem ProperTrailers' (InboundMeta, Output rpc)) recvBoth Call{callChannel} = liftIO $ flatten <$> Session.recvBoth callChannel where @@ -548,8 +548,8 @@ recvBoth Call{callChannel} = liftIO $ flatten :: Either TrailersOnly' - (StreamElem ProperTrailers' (InboundEnvelope, Output rpc)) - -> StreamElem ProperTrailers' (InboundEnvelope, Output rpc) + (StreamElem ProperTrailers' (InboundMeta, Output rpc)) + -> StreamElem ProperTrailers' (InboundMeta, Output rpc) flatten (Left trailersOnly) = NoMoreElems $ fst $ trailersOnlyToProperTrailers trailersOnly flatten (Right streamElem) = @@ -558,15 +558,15 @@ recvBoth Call{callChannel} = liftIO $ recvEither :: forall m rpc. (HasCallStack, MonadIO m) => Call rpc - -> m (Either ProperTrailers' (InboundEnvelope, Output rpc)) + -> m (Either ProperTrailers' (InboundMeta, Output rpc)) recvEither Call{callChannel} = liftIO $ flatten <$> Session.recvEither callChannel where flatten :: Either TrailersOnly' - (Either ProperTrailers' (InboundEnvelope, Output rpc)) - -> Either ProperTrailers' (InboundEnvelope, Output rpc) + (Either ProperTrailers' (InboundMeta, Output rpc)) + -> Either ProperTrailers' (InboundMeta, Output rpc) flatten (Left trailersOnly) = Left $ fst $ trailersOnlyToProperTrailers trailersOnly flatten (Right (Left properTrailers)) = diff --git a/src/Network/GRPC/Client/Session.hs b/src/Network/GRPC/Client/Session.hs index 8ba7e24b..707fac44 100644 --- a/src/Network/GRPC/Client/Session.hs +++ b/src/Network/GRPC/Client/Session.hs @@ -20,6 +20,7 @@ import Network.GRPC.Client.Connection qualified as Connection import Network.GRPC.Common import Network.GRPC.Common.Compression qualified as Compr import Network.GRPC.Spec +import Network.GRPC.Spec.Serialization import Network.GRPC.Util.Session import Network.GRPC.Util.HKD qualified as HKD @@ -45,7 +46,7 @@ instance IsRPC rpc => DataFlow (ClientInbound rpc) where } deriving (Show) - type Message (ClientInbound rpc) = (InboundEnvelope, Output rpc) + type Message (ClientInbound rpc) = (InboundMeta, Output rpc) type Trailers (ClientInbound rpc) = ProperTrailers' type NoMessages (ClientInbound rpc) = TrailersOnly' @@ -56,7 +57,7 @@ instance IsRPC rpc => DataFlow (ClientOutbound rpc) where } deriving (Show) - type Message (ClientOutbound rpc) = (OutboundEnvelope, Input rpc) + type Message (ClientOutbound rpc) = (OutboundMeta, Input rpc) type Trailers (ClientOutbound rpc) = NoMetadata -- gRPC does not support a Trailers-Only case for requests diff --git a/src/Network/GRPC/Common.hs b/src/Network/GRPC/Common.hs index 6402d385..380a79a5 100644 --- a/src/Network/GRPC/Common.hs +++ b/src/Network/GRPC/Common.hs @@ -49,9 +49,14 @@ module Network.GRPC.Common ( , defaultSecurePort , defaultHTTP2Settings + -- * Message metadata + , OutboundMeta(..) + , InboundMeta(..) + -- * Exceptions - -- ** gRPC exceptions + -- ** gRPC status and exceptions + , GrpcStatus(..) , GrpcError(..) , GrpcException(..) , throwGrpcError diff --git a/src/Network/GRPC/Server.hs b/src/Network/GRPC/Server.hs index 80fe0aad..585ef694 100644 --- a/src/Network/GRPC/Server.hs +++ b/src/Network/GRPC/Server.hs @@ -5,6 +5,7 @@ module Network.GRPC.Server ( -- ** Configuration , ServerParams(..) , RequestHandler + , ContentType(..) -- * Handlers , Call -- opaque @@ -37,8 +38,8 @@ module Network.GRPC.Server ( -- ** Low-level\/specialized API , initiateResponse , sendTrailersOnly - , recvInputWithEnvelope - , sendOutputWithEnvelope + , recvInputWithMeta + , sendOutputWithMeta , getRequestHeaders -- * Exceptions @@ -57,6 +58,7 @@ import Network.GRPC.Server.HandlerMap (HandlerMap) import Network.GRPC.Server.HandlerMap qualified as HandlerMap import Network.GRPC.Server.RequestHandler import Network.GRPC.Server.Session (CallSetupFailure(..)) +import Network.GRPC.Spec import Network.GRPC.Util.HTTP2.Stream (ClientDisconnected(..)) {------------------------------------------------------------------------------- diff --git a/src/Network/GRPC/Server/Call.hs b/src/Network/GRPC/Server/Call.hs index 7e3449c0..48cbd813 100644 --- a/src/Network/GRPC/Server/Call.hs +++ b/src/Network/GRPC/Server/Call.hs @@ -28,8 +28,8 @@ module Network.GRPC.Server.Call ( , initiateResponse , sendTrailersOnly , recvNextInputElem - , recvInputWithEnvelope - , sendOutputWithEnvelope + , recvInputWithMeta + , sendOutputWithMeta , getRequestHeaders -- ** Internal API @@ -59,6 +59,7 @@ import Network.GRPC.Common.StreamElem qualified as StreamElem import Network.GRPC.Server.Context import Network.GRPC.Server.Session import Network.GRPC.Spec +import Network.GRPC.Spec.Serialization import Network.GRPC.Util.HTTP2 (fromHeaderTable) import Network.GRPC.Util.Session qualified as Session import Network.GRPC.Util.Session.Server qualified as Server @@ -312,7 +313,7 @@ serverExceptionToClientError params err -- We do not return trailers, since gRPC does not support sending trailers from -- the client to the server (only from the server to the client). recvInput :: HasCallStack => Call rpc -> IO (StreamElem NoMetadata (Input rpc)) -recvInput = fmap (fmap snd) . recvInputWithEnvelope +recvInput = fmap (fmap snd) . recvInputWithMeta -- | Receive RPC input from the client, if one exists -- @@ -355,11 +356,11 @@ recvNextInputElem = fmap (fmap snd) . recvEither -- as its compressed and uncompressed size. -- -- Most applications will never need to use this function. -recvInputWithEnvelope :: forall rpc. +recvInputWithMeta :: forall rpc. HasCallStack => Call rpc - -> IO (StreamElem NoMetadata (InboundEnvelope, Input rpc)) -recvInputWithEnvelope = recvBoth + -> IO (StreamElem NoMetadata (InboundMeta, Input rpc)) +recvInputWithMeta = recvBoth -- | Send RPC output to the client -- @@ -373,7 +374,7 @@ sendOutput :: HasCallStack => Call rpc -> StreamElem (ResponseTrailingMetadata rpc) (Output rpc) -> IO () -sendOutput call = sendOutputWithEnvelope call . fmap (def,) +sendOutput call = sendOutputWithMeta call . fmap (def,) -- | Generalization of 'sendOutput' with additional control -- @@ -381,12 +382,12 @@ sendOutput call = sendOutputWithEnvelope call . fmap (def,) -- messages. -- -- Most applications will never need to use this function. -sendOutputWithEnvelope :: forall rpc. +sendOutputWithMeta :: forall rpc. HasCallStack => Call rpc - -> StreamElem (ResponseTrailingMetadata rpc) (OutboundEnvelope, Output rpc) + -> StreamElem (ResponseTrailingMetadata rpc) (OutboundMeta, Output rpc) -> IO () -sendOutputWithEnvelope call@Call{callChannel} msg = do +sendOutputWithMeta call@Call{callChannel} msg = do _updated <- initiateResponse call msg' <- bitraverse mkTrailers return msg Session.send callChannel msg' @@ -649,7 +650,7 @@ sendProperTrailers Call{callContext, callResponseKickoff, callChannel} recvBoth :: forall rpc. HasCallStack => Call rpc - -> IO (StreamElem NoMetadata (InboundEnvelope, Input rpc)) + -> IO (StreamElem NoMetadata (InboundMeta, Input rpc)) recvBoth Call{callChannel} = flatten <$> Session.recvBoth callChannel where @@ -657,15 +658,15 @@ recvBoth Call{callChannel} = flatten :: Either Void - (StreamElem NoMetadata (InboundEnvelope, Input rpc)) - -> StreamElem NoMetadata (InboundEnvelope, Input rpc) + (StreamElem NoMetadata (InboundMeta, Input rpc)) + -> StreamElem NoMetadata (InboundMeta, Input rpc) flatten (Left impossible) = absurd impossible flatten (Right streamElem) = streamElem recvEither :: forall rpc. HasCallStack => Call rpc - -> IO (NextElem (InboundEnvelope, Input rpc)) + -> IO (NextElem (InboundMeta, Input rpc)) recvEither Call{callChannel} = flatten <$> Session.recvEither callChannel where @@ -674,8 +675,8 @@ recvEither Call{callChannel} = flatten :: Either Void - (Either NoMetadata (InboundEnvelope, Input rpc)) - -> NextElem (InboundEnvelope, Input rpc) + (Either NoMetadata (InboundMeta, Input rpc)) + -> NextElem (InboundMeta, Input rpc) flatten (Left impossible) = absurd impossible flatten (Right (Left NoMetadata)) = NoNextElem flatten (Right (Right msg)) = NextElem msg diff --git a/src/Network/GRPC/Server/RequestHandler.hs b/src/Network/GRPC/Server/RequestHandler.hs index 2019b5e0..17131c0e 100644 --- a/src/Network/GRPC/Server/RequestHandler.hs +++ b/src/Network/GRPC/Server/RequestHandler.hs @@ -34,6 +34,7 @@ import Network.GRPC.Server.HandlerMap qualified as HandlerMap import Network.GRPC.Server.RequestHandler.API import Network.GRPC.Server.Session (CallSetupFailure(..)) import Network.GRPC.Spec +import Network.GRPC.Spec.Serialization import Network.GRPC.Util.GHC import Network.GRPC.Util.HKD qualified as HKD import Network.GRPC.Util.Session.Server diff --git a/src/Network/GRPC/Server/Session.hs b/src/Network/GRPC/Server/Session.hs index 984574b7..b1d386d5 100644 --- a/src/Network/GRPC/Server/Session.hs +++ b/src/Network/GRPC/Server/Session.hs @@ -13,6 +13,7 @@ import Data.Void import Network.GRPC.Server.Context import Network.GRPC.Spec +import Network.GRPC.Spec.Serialization import Network.GRPC.Util.Session {------------------------------------------------------------------------------- @@ -37,7 +38,7 @@ instance IsRPC rpc => DataFlow (ServerInbound rpc) where } deriving (Show) - type Message (ServerInbound rpc) = (InboundEnvelope, Input rpc) + type Message (ServerInbound rpc) = (InboundMeta, Input rpc) type Trailers (ServerInbound rpc) = NoMetadata -- gRPC does not support request trailers @@ -50,7 +51,7 @@ instance IsRPC rpc => DataFlow (ServerOutbound rpc) where } deriving (Show) - type Message (ServerOutbound rpc) = (OutboundEnvelope, Output rpc) + type Message (ServerOutbound rpc) = (OutboundMeta, Output rpc) type Trailers (ServerOutbound rpc) = ProperTrailers type NoMessages (ServerOutbound rpc) = TrailersOnly diff --git a/src/Network/GRPC/Server/StreamType.hs b/src/Network/GRPC/Server/StreamType.hs index f16c6239..fa4eff7d 100644 --- a/src/Network/GRPC/Server/StreamType.hs +++ b/src/Network/GRPC/Server/StreamType.hs @@ -3,7 +3,7 @@ -- | Server handlers module Network.GRPC.Server.StreamType ( -- * Handler type - ServerHandler' -- opaque + ServerHandler'(..) , ServerHandler -- * Construct server handler , mkNonStreaming diff --git a/src/Network/GRPC/Spec.hs b/src/Network/GRPC/Spec.hs index b860f5d5..0975c0b2 100644 --- a/src/Network/GRPC/Spec.hs +++ b/src/Network/GRPC/Spec.hs @@ -27,15 +27,6 @@ module Network.GRPC.Spec ( , RawRpc -- *** Unknown , UnknownRpc - -- ** Messages - -- *** Parsing - , InboundEnvelope(..) - , parseInput - , parseOutput - -- *** Construction - , OutboundEnvelope(..) - , buildInput - , buildOutput -- * Streaming types , StreamingType(..) , SStreamingType(..) @@ -70,6 +61,9 @@ module Network.GRPC.Spec ( , snappy #endif , allSupportedCompression + -- * Message metadata + , OutboundMeta(..) + , InboundMeta(..) -- * Requests , RequestHeaders_(..) , RequestHeaders @@ -86,33 +80,17 @@ module Network.GRPC.Spec ( , Scheme(..) , Method(..) , rpcPath - -- ** Serialization - , RawResourceHeaders(..) - , InvalidResourceHeaders(..) - , buildResourceHeaders - , parseResourceHeaders - -- ** Headers - , buildRequestHeaders - , parseRequestHeaders - , parseRequestHeaders' -- ** Timeouts , Timeout(..) , TimeoutValue(..) , TimeoutUnit(..) , timeoutToMicro , isValidTimeoutValue - -- ** Serialization - , buildTimeout - , parseTimeout -- * Responses -- ** Headers , ResponseHeaders_(..) , ResponseHeaders , ResponseHeaders' - , buildResponseHeaders - , parseResponseHeaders - , parseResponseHeaders' - , classifyServerResponse -- ** Trailers , ProperTrailers_(..) , ProperTrailers @@ -126,22 +104,11 @@ module Network.GRPC.Spec ( , GrpcNormalTermination(..) , grpcExceptionToTrailers , grpcClassifyTermination - -- ** Serialization - , parseProperTrailers - , parseProperTrailers' - , parseTrailersOnly - , parseTrailersOnly' - , parsePushback - , buildProperTrailers - , buildTrailersOnly - , buildPushback , properTrailersToTrailersOnly , trailersOnlyToProperTrailers -- * Status , GrpcStatus(..) , GrpcError(..) - , fromGrpcStatus - , toGrpcStatus -- ** Exceptions , GrpcException(..) , throwGrpcError @@ -152,6 +119,7 @@ module Network.GRPC.Spec ( , safeCustomMetadata , HeaderName(BinaryHeader, AsciiHeader) , safeHeaderName + , isValidAsciiValue , NoMetadata(..) , UnexpectedMetadata(..) -- ** Handling of duplicate metadata entries @@ -159,11 +127,6 @@ module Network.GRPC.Spec ( , customMetadataMapFromList , customMetadataMapToList , customMetadataMapInsert - -- ** Serialization - , buildBinaryValue - , parseBinaryValue - , parseCustomMetadata - , buildCustomMetadata -- ** Typed , RequestMetadata , ResponseInitialMetadata @@ -175,10 +138,15 @@ module Network.GRPC.Spec ( , ParseMetadata(..) , StaticMetadata(..) , buildMetadataIO - -- * Common infrastructure to all headers + -- * Invalid headers , InvalidHeaders(..) , InvalidHeader(..) , prettyInvalidHeaders + , invalidHeader + , missingHeader + , unexpectedHeader + , throwInvalidHeader + -- * Common infrastructure to all headers , ContentType(..) , MessageType(..) -- * OpenTelemetry @@ -186,8 +154,6 @@ module Network.GRPC.Spec ( , TraceId(..) , SpanId(..) , TraceOptions(..) - , buildTraceContext - , parseTraceContext -- * ORCA , OrcaLoadReport ) where @@ -203,7 +169,7 @@ import Network.GRPC.Spec.Headers.Invalid import Network.GRPC.Spec.Headers.PseudoHeaders import Network.GRPC.Spec.Headers.Request import Network.GRPC.Spec.Headers.Response -import Network.GRPC.Spec.LengthPrefixed +import Network.GRPC.Spec.MessageMeta import Network.GRPC.Spec.OrcaLoadReport import Network.GRPC.Spec.RPC import Network.GRPC.Spec.RPC.JSON diff --git a/src/Network/GRPC/Spec/CustomMetadata/Raw.hs b/src/Network/GRPC/Spec/CustomMetadata/Raw.hs index 120c95e8..e8d3e1db 100644 --- a/src/Network/GRPC/Spec/CustomMetadata/Raw.hs +++ b/src/Network/GRPC/Spec/CustomMetadata/Raw.hs @@ -13,25 +13,13 @@ module Network.GRPC.Spec.CustomMetadata.Raw ( , safeCustomMetadata , HeaderName(BinaryHeader, AsciiHeader) , safeHeaderName - -- * Serialization - , buildHeaderName - , buildAsciiValue - , buildBinaryValue - , buildCustomMetadata - , parseHeaderName - , parseAsciiValue - , parseBinaryValue - , parseCustomMetadata + , isValidAsciiValue ) where import Control.DeepSeq (NFData) import Control.Monad -import Control.Monad.Except (MonadError(throwError)) import Data.ByteString qualified as BS.Strict import Data.ByteString qualified as Strict (ByteString) -import Data.CaseInsensitive (CI) -import Data.CaseInsensitive qualified as CI -import Data.List (intersperse) import Data.List qualified as List import Data.Maybe (fromMaybe) import Data.Set (Set) @@ -41,10 +29,7 @@ import Data.Word import GHC.Generics (Generic) import GHC.Show import GHC.Stack -import Network.HTTP.Types qualified as HTTP -import Network.GRPC.Spec.Base64 -import Network.GRPC.Spec.Headers.Invalid import Network.GRPC.Util.ByteString (strip, ascii) {------------------------------------------------------------------------------- @@ -247,99 +232,6 @@ instance Show HeaderName where show (UnsafeBinaryHeader name) = show name show (UnsafeAsciiHeader name) = show name -{------------------------------------------------------------------------------- - Serialization --------------------------------------------------------------------------------} - -buildHeaderName :: HeaderName -> CI Strict.ByteString -buildHeaderName name = - case name of - UnsafeBinaryHeader name' -> CI.mk name' - UnsafeAsciiHeader name' -> CI.mk name' - -parseHeaderName :: MonadError String m => CI Strict.ByteString -> m HeaderName -parseHeaderName name = - case safeHeaderName (CI.foldedCase name) of - Nothing -> throwError $ "Invalid header name: " ++ show name - Just name' -> return name' - -buildAsciiValue :: Strict.ByteString -> Strict.ByteString -buildAsciiValue = id - -parseAsciiValue :: - MonadError String m - => Strict.ByteString -> m Strict.ByteString -parseAsciiValue bs = do - unless (isValidAsciiValue bs) $ - throwError $ "Invalid ASCII header: " ++ show bs - return bs - -buildBinaryValue :: Strict.ByteString -> Strict.ByteString -buildBinaryValue = encodeBase64 - --- | Parse binary value --- --- The presence of duplicate headers makes this a bit subtle. Let's consider an --- example. Suppose we have two duplicate headers --- --- > foo-bin: YWJj -- encoding of "abc" --- > foo-bin: ZGVm -- encoding of "def" --- --- The spec says --- --- > Custom-Metadata header order is not guaranteed to be preserved except for --- > values with duplicate header names. Duplicate header names may have their --- > values joined with "," as the delimiter and be considered semantically --- > equivalent. --- --- In @grapesy@ we will do the decoding of both headers /prior/ to joining --- duplicate headers, and so the value we will reconstruct for @foo-bin@ is --- \"abc,def\". --- --- However, suppose we deal with a (non-compliant) peer which is unaware of --- binary headers and has applied the joining rule /without/ decoding: --- --- > foo-bin: YWJj,ZGVm --- --- The spec is a bit vague about this case, saying only: --- --- > Implementations must split Binary-Headers on "," before decoding the --- > Base64-encoded values. --- --- Here we assume that this case must be treated the same way as if the headers --- /had/ been decoded prior to joining. Therefore, we split the input on commas, --- decode each result separately, and join the results with commas again. -parseBinaryValue :: forall m. - MonadError String m - => Strict.ByteString -> m Strict.ByteString -parseBinaryValue bs = do - let chunks = BS.Strict.split (ascii ',') bs - decoded <- mapM decode chunks - return $ mconcat $ intersperse "," decoded - where - decode :: Strict.ByteString -> m Strict.ByteString - decode chunk = - case decodeBase64 chunk of - Left err -> throwError err - Right val -> return val - -buildCustomMetadata :: CustomMetadata -> HTTP.Header -buildCustomMetadata (CustomMetadata name value) = - case name of - UnsafeBinaryHeader _ -> (buildHeaderName name, buildBinaryValue value) - UnsafeAsciiHeader _ -> (buildHeaderName name, buildAsciiValue value) - -parseCustomMetadata :: - MonadError InvalidHeaders m - => HTTP.Header -> m CustomMetadata -parseCustomMetadata hdr@(name, value) = throwInvalidHeader hdr $ do - name' <- parseHeaderName name - value' <- case name' of - UnsafeAsciiHeader _ -> parseAsciiValue value - UnsafeBinaryHeader _ -> parseBinaryValue value - -- If parsing succeeds, that justifies the use of 'UnsafeCustomMetadata' - return $ UnsafeCustomMetadata name' value' - {------------------------------------------------------------------------------- Internal auxiliary -------------------------------------------------------------------------------} diff --git a/src/Network/GRPC/Spec/Headers/Common.hs b/src/Network/GRPC/Spec/Headers/Common.hs index ce29f385..54906f5e 100644 --- a/src/Network/GRPC/Spec/Headers/Common.hs +++ b/src/Network/GRPC/Spec/Headers/Common.hs @@ -1,5 +1,3 @@ -{-# LANGUAGE OverloadedStrings #-} - -- | Functionality shared between requests and responses -- -- The following headers are used both in requests and in responses: @@ -15,48 +13,17 @@ -- -- Intended for unqualified import. module Network.GRPC.Spec.Headers.Common ( - -- * Content type + -- * Definition ContentType(..) - , buildContentType - , parseContentType - -- * Message type , MessageType(..) - , buildMessageType - , parseMessageType - -- * Message encoding - , buildMessageEncoding - , buildMessageAcceptEncoding - , parseMessageEncoding - , parseMessageAcceptEncoding - -- * Utilities - , trim ) where -import Control.Monad -import Control.Monad.Except -import Data.ByteString qualified as BS.Strict import Data.ByteString qualified as Strict (ByteString) -import Data.ByteString.Char8 qualified as BS.Strict.C8 import Data.Default -import Data.Foldable (toList) -import Data.List (intersperse) -import Data.List.NonEmpty (NonEmpty(..)) -import Data.Proxy -import Data.Word import GHC.Generics (Generic) -import Network.HTTP.Types qualified as HTTP - -import Network.GRPC.Spec.Compression -import Network.GRPC.Spec.Headers.Invalid -import Network.GRPC.Spec.RPC -import Network.GRPC.Spec.RPC.Unknown -import Network.GRPC.Util.ByteString {------------------------------------------------------------------------------- - > Content-Type → - > "content-type" - > "application/grpc" - > [("+proto" / "+json" / {custom})] + ContentType -------------------------------------------------------------------------------} -- | Content type @@ -80,74 +47,8 @@ data ContentType = instance Default ContentType where def = ContentTypeDefault -buildContentType :: - IsRPC rpc - => Proxy rpc - -> ContentType - -> HTTP.Header -buildContentType proxy contentType = ( - "content-type" - , case contentType of - ContentTypeDefault -> defaultContentType - ContentTypeOverride x -> x - ) - where - defaultContentType :: Strict.ByteString - defaultContentType = rpcContentType proxy - -parseContentType :: forall m rpc. - (MonadError InvalidHeaders m, IsRPC rpc) - => Proxy rpc - -> HTTP.Header - -> m ContentType -parseContentType proxy hdr@(_name, value) = do - if value == rpcContentType proxy then - return ContentTypeDefault - else do - -- Headers must be ASCII, justifying the use of BS.Strict.C8. - -- See . - -- The gRPC spec does not allow for quoted strings. - withoutPrefix <- - case BS.Strict.C8.stripPrefix "application/grpc" value of - Nothing -> err "Missing \"application/grpc\" prefix." - Just remainder -> return remainder - - -- The gRPC spec does not allow for any parameters. - when (';' `BS.Strict.C8.elem` withoutPrefix) $ - err "Unexpected parameter." - - -- Check format - -- - -- The only @format@ we should allow is @serializationFormat proxy@. - -- However, some non-conforming proxies use formats such as - -- @application/grpc+octet-stream@. We therefore ignore @format@ here. - if BS.Strict.C8.null withoutPrefix then - -- Accept "application/grpc" - return $ ContentTypeOverride value - else - case BS.Strict.C8.stripPrefix "+" withoutPrefix of - Just _format -> - -- Accept "application/grpc+" - return $ ContentTypeOverride value - Nothing -> - err "Invalid subtype." - where - err :: String -> m a - err reason = throwError $ invalidHeader hdr $ concat [ - reason - , " Expected \"" - , BS.Strict.C8.unpack $ - rpcContentType (Proxy @(UnknownRpc Nothing Nothing)) - , "\" or \"" - , BS.Strict.C8.unpack $ - rpcContentType proxy - , "\", with \"" - , "application/grpc+{other_format}" - , "\" also accepted." - ] - {------------------------------------------------------------------------------- - > Message-Type → "grpc-message-type" {type name for message schema} + MessageType -------------------------------------------------------------------------------} -- | Message type @@ -166,102 +67,3 @@ data MessageType = instance Default MessageType where def = MessageTypeDefault -buildMessageType :: - IsRPC rpc - => Proxy rpc - -> MessageType - -> Maybe HTTP.Header -buildMessageType proxy messageType = - case messageType of - MessageTypeDefault -> mkHeader <$> defaultMessageType - MessageTypeOverride x -> Just $ mkHeader x - where - defaultMessageType :: Maybe Strict.ByteString - defaultMessageType = rpcMessageType proxy - - mkHeader :: Strict.ByteString -> HTTP.Header - mkHeader = ("grpc-message-type",) - --- | Parse message type --- --- We do not need the @grpc-message-type@ header in order to know the message --- type, because the /path/ determines the service and method, and that in turn --- determines the message type. Therefore, if the value is not what we expect, --- we merely record this fact ('MessageTypeOverride') but don't otherwise do --- anything differently. -parseMessageType :: forall rpc. - IsRPC rpc - => Proxy rpc - -> HTTP.Header - -> MessageType -parseMessageType proxy (_name, given) = - case rpcMessageType proxy of - Nothing -> - -- We expected no message type at all, but did get one - MessageTypeOverride given - Just expected -> - if expected == given - then MessageTypeDefault - else MessageTypeOverride given - -{------------------------------------------------------------------------------- - > Message-Encoding → "grpc-encoding" Content-Coding - > Content-Coding → "identity" / "gzip" / "deflate" / "snappy" / {custom} --------------------------------------------------------------------------------} - -buildMessageEncoding :: CompressionId -> HTTP.Header -buildMessageEncoding compr = ( - "grpc-encoding" - , serializeCompressionId compr - ) - -parseMessageEncoding :: - MonadError InvalidHeaders m - => HTTP.Header - -> m CompressionId -parseMessageEncoding (_name, value) = - return $ deserializeCompressionId value - -{------------------------------------------------------------------------------- - > Message-Accept-Encoding → - > "grpc-accept-encoding" Content-Coding *("," Content-Coding) --------------------------------------------------------------------------------} - -buildMessageAcceptEncoding :: NonEmpty CompressionId -> HTTP.Header -buildMessageAcceptEncoding compr = ( - "grpc-accept-encoding" - , mconcat . intersperse "," . map serializeCompressionId $ toList compr - ) - -parseMessageAcceptEncoding :: forall m. - MonadError InvalidHeaders m - => HTTP.Header - -> m (NonEmpty CompressionId) -parseMessageAcceptEncoding hdr@(_name, value) = - atLeastOne - . map (deserializeCompressionId . strip) - . BS.Strict.splitWith (== ascii ',') - $ value - where - atLeastOne :: forall a. [a] -> m (NonEmpty a) - atLeastOne (x : xs) = return (x :| xs) - atLeastOne [] = throwError $ invalidHeader hdr $ - "Expected at least one compresion ID" - -{------------------------------------------------------------------------------- - Utilities --------------------------------------------------------------------------------} - --- | Trim leading or trailing whitespace --- --- We only allow for space and tab, based on --- . -trim :: Strict.ByteString -> Strict.ByteString -trim = ltrim . rtrim - where - ltrim, rtrim :: Strict.ByteString -> Strict.ByteString - ltrim = BS.Strict.dropWhile isSpace - rtrim = BS.Strict.dropWhileEnd isSpace - - isSpace :: Word8 -> Bool - isSpace x = x == 32 || x == 9 diff --git a/src/Network/GRPC/Spec/Headers/PseudoHeaders.hs b/src/Network/GRPC/Spec/Headers/PseudoHeaders.hs index 997bf2c6..79af2b18 100644 --- a/src/Network/GRPC/Spec/Headers/PseudoHeaders.hs +++ b/src/Network/GRPC/Spec/Headers/PseudoHeaders.hs @@ -1,5 +1,3 @@ -{-# LANGUAGE OverloadedStrings #-} - -- | Part of the gRPC spec that maps to HTTP2 pseudo-headers -- -- Intended for unqualified import. @@ -13,25 +11,15 @@ module Network.GRPC.Spec.Headers.PseudoHeaders ( , Scheme(..) , Address(..) , Path(..) - -- * Building and parsing resource headers - , RawResourceHeaders(..) - -- ** Building - , buildResourceHeaders , rpcPath - -- ** Parsing - , InvalidResourceHeaders(..) - , parseResourceHeaders ) where -import Control.Monad.Except -import Data.ByteString qualified as BS.Strict import Data.ByteString qualified as Strict (ByteString) import Data.Hashable import Data.Proxy import Network.Socket (HostName, PortNumber) import Network.GRPC.Spec.RPC -import Network.GRPC.Util.ByteString {------------------------------------------------------------------------------- Definition @@ -137,51 +125,6 @@ instance Hashable Path where hashWithSalt salt Path{pathService, pathMethod} = hashWithSalt salt (pathService, pathMethod) -{------------------------------------------------------------------------------- - Building and parsing resource headers --------------------------------------------------------------------------------} - -data RawResourceHeaders = RawResourceHeaders { - rawPath :: Strict.ByteString - , rawMethod :: Strict.ByteString - } - deriving (Show) - -buildResourceHeaders :: ResourceHeaders -> RawResourceHeaders -buildResourceHeaders ResourceHeaders{resourcePath, resourceMethod} = - RawResourceHeaders { - rawMethod = case resourceMethod of Post -> "POST" - , rawPath = mconcat [ - "/" - , pathService resourcePath - , "/" - , pathMethod resourcePath - ] - } - +-- | Construct path rpcPath :: IsRPC rpc => Proxy rpc -> Path rpcPath proxy = Path (rpcServiceName proxy) (rpcMethodName proxy) - -data InvalidResourceHeaders = - InvalidMethod Strict.ByteString - | InvalidPath Strict.ByteString - deriving stock (Show) - --- | Parse pseudo headers -parseResourceHeaders :: - RawResourceHeaders - -> Either InvalidResourceHeaders ResourceHeaders -parseResourceHeaders RawResourceHeaders{rawMethod, rawPath} = do - resourceMethod <- - case rawMethod of - "POST" -> return Post - _otherwise -> throwError $ InvalidMethod rawMethod - - resourcePath <- - case BS.Strict.split (ascii '/') rawPath of - ["", service, method] -> - return $ Path service method - _otherwise -> - throwError $ InvalidPath rawPath - - return ResourceHeaders{resourceMethod, resourcePath} diff --git a/src/Network/GRPC/Spec/Headers/Request.hs b/src/Network/GRPC/Spec/Headers/Request.hs index bf370a0a..e68c147b 100644 --- a/src/Network/GRPC/Spec/Headers/Request.hs +++ b/src/Network/GRPC/Spec/Headers/Request.hs @@ -1,5 +1,3 @@ -{-# LANGUAGE OverloadedStrings #-} - -- | Construct HTTP2 requests -- -- Intended for qualified import. @@ -11,33 +9,17 @@ module Network.GRPC.Spec.Headers.Request ( , RequestHeaders , RequestHeaders' , InvalidRequestHeaders - -- * Serialization - , buildRequestHeaders - , parseRequestHeaders - , parseRequestHeaders' ) where -import Control.Monad -import Control.Monad.Except (MonadError(throwError)) -import Control.Monad.State (State, execState, modify) -import Data.Bifunctor import Data.ByteString qualified as Strict (ByteString) -import Data.ByteString.Char8 qualified as BS.Strict.C8 -import Data.Functor (($>)) -import Data.List (intercalate) import Data.List.NonEmpty (NonEmpty) -import Data.Maybe (catMaybes) -import Data.Proxy import GHC.Generics (Generic) import Network.HTTP.Types qualified as HTTP -import Text.Read (readMaybe) import Network.GRPC.Spec.Compression (CompressionId) import Network.GRPC.Spec.CustomMetadata.Map -import Network.GRPC.Spec.CustomMetadata.Raw import Network.GRPC.Spec.Headers.Common import Network.GRPC.Spec.Headers.Invalid -import Network.GRPC.Spec.RPC import Network.GRPC.Spec.Timeout import Network.GRPC.Spec.TraceContext import Network.GRPC.Util.HKD (HKD, Undecorated, DecoratedWith) @@ -144,6 +126,11 @@ type RequestHeaders = RequestHeaders_ Undecorated type RequestHeaders' = RequestHeaders_ (DecoratedWith (Either InvalidRequestHeaders)) +-- | Invalid request headers +-- +-- For certain types of failures the gRPC spec mandates a specific HTTP status. +type InvalidRequestHeaders = (HTTP.Status, InvalidHeaders) + deriving stock instance Show RequestHeaders deriving stock instance Eq RequestHeaders deriving stock instance Generic RequestHeaders @@ -167,260 +154,3 @@ instance HKD.Traversable RequestHeaders_ where <*> (f $ requestPreviousRpcAttempts x) <*> (pure $ requestMetadata x) <*> (f $ requestUnrecognized x) - -{------------------------------------------------------------------------------- - Construction --------------------------------------------------------------------------------} - --- | Request headers --- --- > Request-Headers → --- > Call-Definition --- > *Custom-Metadata -buildRequestHeaders :: - IsRPC rpc - => Proxy rpc -> RequestHeaders -> [HTTP.Header] -buildRequestHeaders proxy callParams@RequestHeaders{requestMetadata} = concat [ - callDefinition proxy callParams - , map buildCustomMetadata $ customMetadataMapToList requestMetadata - ] - --- | Call definition --- --- > Call-Definition → --- > Method --- > Scheme --- > Path --- > TE --- > [Authority] --- > [Timeout] --- > Content-Type --- > [Message-Type] --- > [Message-Encoding] --- > [Message-Accept-Encoding] --- > [User-Agent] --- --- However, the spec additionally mandates that --- --- HTTP2 requires that reserved headers, ones starting with ":" appear --- before all other headers. Additionally implementations should send --- Timeout immediately after the reserved headers and they should send the --- Call-Definition headers before sending Custom-Metadata. --- --- (Relevant part of the HTTP2 spec: --- .) This means --- @TE@ should come /after/ @Authority@ (if using). However, we will not include --- the reserved headers here /at all/, as they are automatically added by --- `http2`. -callDefinition :: forall rpc. - IsRPC rpc - => Proxy rpc -> RequestHeaders -> [HTTP.Header] -callDefinition proxy = \hdrs -> catMaybes [ - hdrTimeout <$> requestTimeout hdrs - , guard (requestIncludeTE hdrs) $> buildTe - , buildContentType proxy <$> requestContentType hdrs - , join $ buildMessageType proxy <$> requestMessageType hdrs - , buildMessageEncoding <$> requestCompression hdrs - , buildMessageAcceptEncoding <$> requestAcceptCompression hdrs - , buildUserAgent <$> requestUserAgent hdrs - , buildGrpcTraceBin <$> requestTraceContext hdrs - , buildPreviousRpcAttempts <$> requestPreviousRpcAttempts hdrs - ] - where - hdrTimeout :: Timeout -> HTTP.Header - hdrTimeout t = ("grpc-timeout", buildTimeout t) - - -- > TE → "te" "trailers" # Used to detect incompatible proxies - buildTe :: HTTP.Header - buildTe = ("te", "trailers") - - -- > User-Agent → "user-agent" {structured user-agent string} - -- - -- The spec says: - -- - -- While the protocol does not require a user-agent to function it is - -- recommended that clients provide a structured user-agent string that - -- provides a basic description of the calling library, version & platform - -- to facilitate issue diagnosis in heterogeneous environments. The - -- following structure is recommended to library developers - -- - -- > User-Agent → - -- > "grpc-" - -- > Language - -- > ?("-" Variant) - -- > "/" - -- > Version - -- > ?( " (" *(AdditionalProperty ";") ")" ) - buildUserAgent :: Strict.ByteString -> HTTP.Header - buildUserAgent userAgent = ( - "user-agent" - , userAgent - ) - - buildGrpcTraceBin :: TraceContext -> HTTP.Header - buildGrpcTraceBin ctxt = ( - "grpc-trace-bin" - , buildBinaryValue $ buildTraceContext ctxt - ) - - buildPreviousRpcAttempts :: Int -> HTTP.Header - buildPreviousRpcAttempts n = ( - "grpc-previous-rpc-attempts" - , BS.Strict.C8.pack $ show n - ) - -{------------------------------------------------------------------------------- - Invalid headers --------------------------------------------------------------------------------} - --- | Invalid request headers --- --- For certain types of failures the gRPC spec mandates a specific HTTP status. -type InvalidRequestHeaders = (HTTP.Status, InvalidHeaders) - -{------------------------------------------------------------------------------- - Parsing --------------------------------------------------------------------------------} - -parseRequestHeaders :: forall rpc m. - (IsRPC rpc, MonadError InvalidRequestHeaders m) - => Proxy rpc - -> [HTTP.Header] -> m RequestHeaders -parseRequestHeaders proxy = HKD.sequenceThrow . parseRequestHeaders' proxy - --- | Parse request headers --- --- This can report invalid headers on a per-header basis; see also --- 'parseRequestHeaders'. -parseRequestHeaders' :: forall rpc. - IsRPC rpc - => Proxy rpc - -> [HTTP.Header] -> RequestHeaders' -parseRequestHeaders' proxy = - flip execState uninitRequestHeaders - . mapM_ (parseHeader . second trim) - where - parseHeader :: HTTP.Header -> State RequestHeaders' () - parseHeader hdr@(name, value) - | name == "user-agent" - = modify $ \x -> x { - requestUserAgent = return (Just value) - } - - | name == "grpc-timeout" - = modify $ \x -> x { - requestTimeout = fmap Just $ - httpError hdr HTTP.badRequest400 $ - parseTimeout value - } - - | name == "grpc-encoding" - = modify $ \x -> x { - requestCompression = fmap Just $ - first (HTTP.badRequest400,) $ - parseMessageEncoding hdr - } - - | name == "grpc-accept-encoding" - = modify $ \x -> x { - requestAcceptCompression = fmap Just $ - first (HTTP.badRequest400,) $ - parseMessageAcceptEncoding hdr - } - - | name == "grpc-trace-bin" - = modify $ \x -> x { - requestTraceContext = fmap Just $ - httpError hdr HTTP.badRequest400 $ - parseBinaryValue value >>= parseTraceContext - } - - | name == "content-type" - = modify $ \x -> x { - requestContentType = fmap Just $ - first (HTTP.unsupportedMediaType415,) $ - parseContentType proxy hdr - } - - | name == "grpc-message-type" - = modify $ \x -> x { - requestMessageType = return . Just $ - parseMessageType proxy hdr - } - - | name == "te" - = modify $ \x -> x { - requestIncludeTE = do - first (HTTP.badRequest400,) $ - expectHeaderValue hdr ["trailers"] - return True - } - - | name == "grpc-previous-rpc-attempts" - = modify $ \x -> x { - requestPreviousRpcAttempts = do - httpError hdr HTTP.badRequest400 $ - maybe - (Left $ "grpc-previous-rpc-attempts: invalid " ++ show value) - (Right . Just) - (readMaybe $ BS.Strict.C8.unpack value) - } - - | otherwise - = modify $ \x -> - case parseCustomMetadata hdr of - Left invalid -> x { - requestUnrecognized = Left $ - case requestUnrecognized x of - Left (status, invalid') -> (status, invalid <> invalid') - Right () -> (HTTP.badRequest400, invalid) - } - Right md -> x { - requestMetadata = customMetadataMapInsert md $ requestMetadata x - } - - uninitRequestHeaders :: RequestHeaders' - uninitRequestHeaders = RequestHeaders { - requestTimeout = return Nothing - , requestCompression = return Nothing - , requestAcceptCompression = return Nothing - , requestContentType = return Nothing - , requestIncludeTE = return False - , requestUserAgent = return Nothing - , requestTraceContext = return Nothing - , requestPreviousRpcAttempts = return Nothing - , requestMessageType = - -- If the default is that this header should be absent, then /start/ - -- with 'MessageTypeDefault'; if it happens to present, parse it as - -- an override. - case rpcMessageType proxy of - Nothing -> return $ Just MessageTypeDefault - Just _ -> return $ Nothing - , requestMetadata = mempty - , requestUnrecognized = return () - } - - httpError :: - MonadError InvalidRequestHeaders m' - => HTTP.Header -> HTTP.Status -> Either String a -> m' a - httpError _ _ (Right a) = return a - httpError hdr status (Left err) = throwError (status, invalidHeader hdr err) - -{------------------------------------------------------------------------------- - Internal auxiliary --------------------------------------------------------------------------------} - -expectHeaderValue :: - MonadError InvalidHeaders m - => HTTP.Header -> [Strict.ByteString] -> m () -expectHeaderValue hdr@(_name, actual) expected = - unless (actual `elem` expected) $ - throwError $ invalidHeader hdr err - where - err :: String - err = concat [ - "Expected " - , intercalate " or " $ - map (\e -> "\"" ++ BS.Strict.C8.unpack e ++ "\"") expected - , "." - ] diff --git a/src/Network/GRPC/Spec/Headers/Response.hs b/src/Network/GRPC/Spec/Headers/Response.hs index b806bb8c..bc0020b1 100644 --- a/src/Network/GRPC/Spec/Headers/Response.hs +++ b/src/Network/GRPC/Spec/Headers/Response.hs @@ -1,4 +1,3 @@ -{-# LANGUAGE CPP #-} {-# LANGUAGE OverloadedStrings #-} -- | Deal with HTTP2 responses @@ -21,64 +20,27 @@ module Network.GRPC.Spec.Headers.Response ( , simpleProperTrailers , trailersOnlyToProperTrailers , properTrailersToTrailersOnly - , classifyServerResponse -- * Termination , GrpcNormalTermination(..) , grpcClassifyTermination , grpcExceptionToTrailers - -- * Serialization - -- ** Construction - , buildResponseHeaders - , buildProperTrailers - , buildTrailersOnly - , buildPushback - -- ** Parsing - , parseResponseHeaders - , parseResponseHeaders' - , parseProperTrailers - , parseProperTrailers' - , parseTrailersOnly - , parseTrailersOnly' - , parsePushback ) where import Control.Exception -import Control.Monad.Except -import Control.Monad.State -import Data.Bifunctor -import Data.ByteString qualified as BS.Strict -import Data.ByteString qualified as Strict (ByteString) -import Data.ByteString.Char8 qualified as BS.Strict.C8 -import Data.ByteString.Lazy qualified as BS.Lazy -import Data.ByteString.Lazy qualified as Lazy (ByteString) -import Data.CaseInsensitive qualified as CI import Data.List.NonEmpty (NonEmpty) -import Data.Maybe (isJust) import Data.Proxy import Data.Text (Text) -import Data.Text qualified as Text -import Data.Text.Encoding qualified as Text import GHC.Generics (Generic) -import Network.HTTP.Types qualified as HTTP -import Text.Read (readMaybe) - -#if !MIN_VERSION_text(2,0,0) -import Data.Text.Encoding.Error qualified as Text -#endif import Network.GRPC.Spec.Compression (CompressionId) import Network.GRPC.Spec.CustomMetadata.Map import Network.GRPC.Spec.CustomMetadata.Raw -import Network.GRPC.Spec.CustomMetadata.Typed import Network.GRPC.Spec.Headers.Common import Network.GRPC.Spec.Headers.Invalid import Network.GRPC.Spec.OrcaLoadReport -import Network.GRPC.Spec.PercentEncoding qualified as PercentEncoding -import Network.GRPC.Spec.RPC import Network.GRPC.Spec.Status import Network.GRPC.Util.HKD (HKD, Undecorated, DecoratedWith) import Network.GRPC.Util.HKD qualified as HKD -import Network.GRPC.Util.Protobuf qualified as Protobuf {------------------------------------------------------------------------------- Outputs (messages received from the peer) @@ -264,125 +226,6 @@ trailersOnlyToProperTrailers TrailersOnly{ , trailersOnlyContentType ) -{------------------------------------------------------------------------------- - Classify server response --------------------------------------------------------------------------------} - --- | Classify server response --- --- gRPC servers are supposed to respond with HTTP status @200 OK@ no matter --- whether the call was successful or not; if not successful, the information --- about the failure should be reported using @grpc-status@ and related headers --- (@grpc-message@, @grpc-status-details-bin@). --- --- The gRPC spec mandates that if we get a non-200 status from a broken --- deployment, we synthesize a gRPC exception with an appropriate status and --- status message. The spec itself does not provide any guidance on what such an --- appropriate status would look like, but the official gRPC repo does provide a --- partial mapping between HTTP status codes and gRPC status codes at --- . --- This is the mapping we implement here. -classifyServerResponse :: forall rpc. - IsRPC rpc - => Proxy rpc - -> HTTP.Status -- ^ HTTP status - -> [HTTP.Header] -- ^ Headers - -> Maybe Lazy.ByteString -- ^ Response body, if known (used for errors only) - -> Either TrailersOnly' ResponseHeaders' -classifyServerResponse rpc status headers mBody - -- The "HTTP to gRPC Status Code Mapping" is explicit: - -- - -- > (..) to be used only for clients that received a response that did not - -- > include grpc-status. If grpc-status was provided, it must be used. - -- - -- Therefore if @grpc-status@ is present, we ignore the HTTP status. - | hasGrpcStatus headers - = Left $ parseTrailersOnly' rpc headers - - | 200 <- statusCode - = Right $ parseResponseHeaders' rpc headers - - | otherwise - = Left $ - case statusCode of - 400 -> synthesize GrpcInternal -- Bad request - 401 -> synthesize GrpcUnauthenticated -- Unauthorized - 403 -> synthesize GrpcPermissionDenied -- Forbidden - 404 -> synthesize GrpcUnimplemented -- Not found - 429 -> synthesize GrpcUnavailable -- Too many requests - 502 -> synthesize GrpcUnavailable -- Bad gateway - 503 -> synthesize GrpcUnavailable -- Service unavailable - 504 -> synthesize GrpcUnavailable -- Gateway timeout - _ -> synthesize GrpcUnknown - where - HTTP.Status{statusCode, statusMessage} = status - - -- The @grpc-status@ header not present, and HTTP status not @200 OK@. - -- We classify the response as an error response (hence 'TrailersOnly''): - -- - -- * We set 'properTrailersGrpcStatus' based on the HTTP status. - -- * We leave 'properTrailersGrpcMessage' alone if @grpc-message@ present - -- and valid, and replace it with a default message otherwise. - -- - -- The resulting 'TrailersOnly'' cannot contain any parse errors - -- (only @grpc-status@ is required, and only @grpc-message@ can fail). - synthesize :: GrpcError -> TrailersOnly' - synthesize err = parsed { - trailersOnlyProper = parsedTrailers { - properTrailersGrpcStatus = Right $ - GrpcError err - , properTrailersGrpcMessage = Right $ - case properTrailersGrpcMessage parsedTrailers of - Right (Just msg) -> Just msg - _otherwise -> Just defaultMsg - } - } - - where - parsed :: TrailersOnly' - parsed = parseTrailersOnly' rpc headers - - parsedTrailers :: ProperTrailers' - parsedTrailers = trailersOnlyProper parsed - - defaultMsg :: Text - defaultMsg = mconcat [ - "Unexpected HTTP status code " - , Text.pack (show statusCode) - , if not (BS.Strict.null statusMessage) - then " (" <> decodeUtf8Lenient statusMessage <> ")" - else mempty - , case mBody of - Just body | not (BS.Lazy.null body) -> mconcat [ - "\nResponse body:\n" - , decodeUtf8Lenient (BS.Lazy.toStrict body) - ] - _otherwise -> - mempty - ] - --- | Is the @grpc-status@ header set? --- --- We use this as a proxy to determine if we are in the Trailers-Only case. --- --- It might be tempting to use the HTTP @Content-Length@ header instead, but --- this is doubly wrong: --- --- * There might be servers who use the Trailers-Only case but do not set the --- @Content-Length@ header (although such a server would not conform to the --- HTTP spec: "An origin server SHOULD send a @Content-Length@ header field --- when the content size is known prior to sending the complete header --- section"; see --- ). --- * Conversely, there might be servers or proxies who /do/ set @Content-Length@ --- header even when it's /not/ the Trailers-Only case (e.g., see --- or --- ). --- --- We therefore check for the presence of the @grpc-status@ header instead. -hasGrpcStatus :: [HTTP.Header] -> Bool -hasGrpcStatus = isJust . lookup "grpc-status" - {------------------------------------------------------------------------------- Pushback -------------------------------------------------------------------------------} @@ -400,30 +243,6 @@ data Pushback = | DoNotRetry deriving (Show, Eq, Generic) -buildPushback :: Pushback -> Strict.ByteString -buildPushback (RetryAfter n) = BS.Strict.C8.pack $ show n -buildPushback DoNotRetry = "-1" - --- | Parse 'Pushback' --- --- Parsing a pushback cannot fail; the spec mandates: --- --- > If the value for pushback is negative or unparseble, then it will be seen --- > as the server asking the client not to retry at all. --- --- We therefore only require @Monad m@, not @MonadError m@ (having the @Monad@ --- constraint at all keeps the type signature consistent with other parsing --- functions). -parsePushback :: Monad m => Strict.ByteString -> m Pushback -parsePushback bs = - case readMaybe (BS.Strict.C8.unpack bs) of - Just (n :: Int) -> - -- The @Read@ instance for @Word@ /does/ allow for signs - -- - return $ if n < 0 then DoNotRetry else RetryAfter (fromIntegral n) - Nothing -> - return DoNotRetry - {------------------------------------------------------------------------------- Termination -------------------------------------------------------------------------------} @@ -482,315 +301,3 @@ grpcExceptionToTrailers GrpcException{ (GrpcError grpcError) grpcErrorMessage (customMetadataMapFromList grpcErrorMetadata) - -{------------------------------------------------------------------------------- - > Response-Headers → - > HTTP-Status - > [Message-Encoding] - > [Message-Accept-Encoding] - > Content-Type - > *Custom-Metadata - - We do not deal with @HTTP-Status@ here; @http2@ deals this separately. --------------------------------------------------------------------------------} - --- | Build response headers -buildResponseHeaders :: forall rpc. - SupportsServerRpc rpc - => Proxy rpc -> ResponseHeaders -> [HTTP.Header] -buildResponseHeaders proxy - ResponseHeaders{ responseCompression - , responseAcceptCompression - , responseMetadata - , responseContentType - } = concat [ - [ buildContentType proxy x - | Just x <- [responseContentType] - ] - , [ buildMessageEncoding x - | Just x <- [responseCompression] - ] - , [ buildMessageAcceptEncoding x - | Just x <- [responseAcceptCompression] - ] - , [ buildTrailer proxy ] - , [ buildCustomMetadata x - | x <- customMetadataMapToList responseMetadata - ] - ] - --- | Parse response headers -parseResponseHeaders :: forall rpc m. - (IsRPC rpc, MonadError InvalidHeaders m) - => Proxy rpc -> [HTTP.Header] -> m ResponseHeaders -parseResponseHeaders proxy = HKD.sequenceThrow . parseResponseHeaders' proxy - -parseResponseHeaders' :: forall rpc. - IsRPC rpc - => Proxy rpc -> [HTTP.Header] -> ResponseHeaders' -parseResponseHeaders' proxy = - flip execState uninitResponseHeaders - . mapM_ (parseHeader . second trim) - where - -- HTTP2 header names are always lowercase, and must be ASCII. - -- - parseHeader :: HTTP.Header -> State ResponseHeaders' () - parseHeader hdr@(name, _value) - | name == "content-type" - = modify $ \x -> x { - responseContentType = Just <$> parseContentType proxy hdr - } - - | name == "grpc-encoding" - = modify $ \x -> x { - responseCompression = Just <$> parseMessageEncoding hdr - } - - | name == "grpc-accept-encoding" - = modify $ \x -> x { - responseAcceptCompression = Just <$> parseMessageAcceptEncoding hdr - } - - | name == "trailer" - = return () -- ignore the HTTP trailer header - - | otherwise - = modify $ \x -> - case parseCustomMetadata hdr of - Left invalid -> x{ - responseUnrecognized = Left $ mconcat [ - invalid - , otherInvalid $ responseUnrecognized x - ] - } - Right md -> x{ - responseMetadata = - customMetadataMapInsert md $ responseMetadata x - } - - uninitResponseHeaders :: ResponseHeaders' - uninitResponseHeaders = ResponseHeaders { - responseCompression = return Nothing - , responseAcceptCompression = return Nothing - , responseContentType = return Nothing - , responseMetadata = mempty - , responseUnrecognized = return () - } - -{------------------------------------------------------------------------------- - > Trailers → Status [Status-Message] *Custom-Metadata --------------------------------------------------------------------------------} - --- | Construct the HTTP 'Trailer' header --- --- This lists all headers that /might/ be present in the trailers. --- --- See --- --- * --- * -buildTrailer :: forall rpc. SupportsServerRpc rpc => Proxy rpc -> HTTP.Header -buildTrailer _ = ( - "Trailer" - , BS.Strict.intercalate ", " allPotentialTrailers - ) - where - allPotentialTrailers :: [Strict.ByteString] - allPotentialTrailers = concat [ - reservedTrailers - , map (CI.original . buildHeaderName) $ - metadataHeaderNames (Proxy @(ResponseTrailingMetadata rpc)) - ] - - -- These cannot be 'HeaderName' (which disallow reserved names) - -- - -- This list must match the names used by 'buildProperTrailers' - -- and recognized by 'parseProperTrailers'. - reservedTrailers :: [Strict.ByteString] - reservedTrailers = [ - "grpc-status" - , "grpc-message" - , "grpc-retry-pushback-ms" - , "endpoint-load-metrics-bin" - ] - --- | Build trailers (see 'buildTrailersOnly' for the Trailers-Only case) --- --- NOTE: If we add additional (reserved) headers here, we also need to add them --- to 'buildTrailer'. -buildProperTrailers :: ProperTrailers -> [HTTP.Header] -buildProperTrailers ProperTrailers{ - properTrailersGrpcStatus - , properTrailersGrpcMessage - , properTrailersMetadata - , properTrailersPushback - , properTrailersOrcaLoadReport - } = concat [ - [ ( "grpc-status" - , BS.Strict.C8.pack $ show $ fromGrpcStatus properTrailersGrpcStatus - ) - ] - , [ ("grpc-message", PercentEncoding.encode x) - | Just x <- [properTrailersGrpcMessage] - ] - , [ ( "grpc-retry-pushback-ms" - , buildPushback x - ) - | Just x <- [properTrailersPushback] - ] - , [ ( "endpoint-load-metrics-bin" - , buildBinaryValue $ Protobuf.buildStrict x - ) - | Just x <- [properTrailersOrcaLoadReport] - ] - , [ buildCustomMetadata x - | x <- customMetadataMapToList properTrailersMetadata - ] - ] - --- | Build trailers for the Trailers-Only case -buildTrailersOnly :: IsRPC rpc => Proxy rpc -> TrailersOnly -> [HTTP.Header] -buildTrailersOnly proxy TrailersOnly{ - trailersOnlyContentType - , trailersOnlyProper - } = concat [ - [ buildContentType proxy x - | Just x <- [trailersOnlyContentType] - ] - , buildProperTrailers trailersOnlyProper - ] - --- | Parse response trailers --- --- The gRPC spec defines: --- --- > Trailers → Status [Status-Message] *Custom-Metadata --- > Trailers-Only → HTTP-Status Content-Type Trailers --- --- This means that Trailers-Only is a superset of the Trailers; we make use of --- this here, and error out if we get an unexpected @Content-Type@ override. -parseProperTrailers :: forall rpc m. - (IsRPC rpc, MonadError InvalidHeaders m) - => Proxy rpc -> [HTTP.Header] -> m ProperTrailers -parseProperTrailers proxy = HKD.sequenceThrow . parseProperTrailers' proxy - -parseProperTrailers' :: forall rpc. - IsRPC rpc - => Proxy rpc -> [HTTP.Header] -> ProperTrailers' -parseProperTrailers' proxy hdrs = - case trailersOnlyToProperTrailers trailersOnly of - (properTrailers, Right Nothing) -> - properTrailers - (properTrailers, Right (Just _ct)) -> - properTrailers { - properTrailersUnrecognized = Left $ mconcat [ - unexpectedHeader "content-type" - , otherInvalid $ properTrailersUnrecognized properTrailers - ] - } - (properTrailers, Left invalid) -> - -- The @content-type@ header is present, /and/ invalid! - properTrailers { - properTrailersUnrecognized = Left $ mconcat [ - unexpectedHeader "content-type" - , invalid - , otherInvalid $ properTrailersUnrecognized properTrailers - ] - } - where - trailersOnly :: TrailersOnly' - trailersOnly = parseTrailersOnly' proxy hdrs - -parseTrailersOnly :: forall m rpc. - (IsRPC rpc, MonadError InvalidHeaders m) - => Proxy rpc -> [HTTP.Header] -> m TrailersOnly -parseTrailersOnly proxy = HKD.sequenceThrow . parseTrailersOnly' proxy - -parseTrailersOnly' :: forall rpc. - IsRPC rpc - => Proxy rpc -> [HTTP.Header] -> TrailersOnly' -parseTrailersOnly' proxy = - flip execState uninitTrailersOnly - . mapM_ (parseHeader . second trim) - where - parseHeader :: HTTP.Header -> State TrailersOnly' () - parseHeader hdr@(name, value) - | name == "content-type" - = modify $ \x -> x { - trailersOnlyContentType = Just <$> parseContentType proxy hdr - } - - | name == "grpc-status" - = modify $ liftProperTrailers $ \x -> x{ - properTrailersGrpcStatus = throwInvalidHeader hdr $ - case toGrpcStatus =<< readMaybe (BS.Strict.C8.unpack value) of - Nothing -> throwError $ "Invalid status: " ++ show value - Just v -> return v - } - - | name == "grpc-message" - = modify $ liftProperTrailers $ \x -> x{ - properTrailersGrpcMessage = throwInvalidHeader hdr $ - case PercentEncoding.decode value of - Left err -> throwError $ show err - Right msg -> return (Just msg) - } - - | name == "grpc-retry-pushback-ms" - = modify $ liftProperTrailers $ \x -> x{ - properTrailersPushback = - Just <$> parsePushback value - } - - | name == "endpoint-load-metrics-bin" - = modify $ liftProperTrailers $ \x -> x{ - properTrailersOrcaLoadReport = throwInvalidHeader hdr $ do - value' <- parseBinaryValue value - case Protobuf.parseStrict value' of - Left err -> throwError err - Right report -> return $ Just report - } - - | otherwise - = modify $ liftProperTrailers $ \x -> - case parseCustomMetadata hdr of - Left invalid -> x{ - properTrailersUnrecognized = Left $ mconcat [ - invalid - , otherInvalid $ properTrailersUnrecognized x - ] - } - Right md -> x{ - properTrailersMetadata = - customMetadataMapInsert md $ properTrailersMetadata x - } - - uninitTrailersOnly :: TrailersOnly' - uninitTrailersOnly = TrailersOnly { - trailersOnlyContentType = return Nothing - , trailersOnlyProper = simpleProperTrailers - (throwError $ missingHeader "grpc-status") - (return Nothing) - mempty - } - - liftProperTrailers :: - (ProperTrailers_ f -> ProperTrailers_ f) - -> TrailersOnly_ f -> TrailersOnly_ f - liftProperTrailers f trailersOnly = trailersOnly{ - trailersOnlyProper = f (trailersOnlyProper trailersOnly) - } - -{------------------------------------------------------------------------------- - Internal auxiliary --------------------------------------------------------------------------------} - -otherInvalid :: Either InvalidHeaders () -> InvalidHeaders -otherInvalid = either id (\() -> mempty) - -decodeUtf8Lenient :: BS.Strict.C8.ByteString -> Text -#if MIN_VERSION_text(2,0,0) -decodeUtf8Lenient = Text.decodeUtf8Lenient -#else -decodeUtf8Lenient = Text.decodeUtf8With Text.lenientDecode -#endif diff --git a/src/Network/GRPC/Spec/MessageMeta.hs b/src/Network/GRPC/Spec/MessageMeta.hs new file mode 100644 index 00000000..98fe27ad --- /dev/null +++ b/src/Network/GRPC/Spec/MessageMeta.hs @@ -0,0 +1,40 @@ +-- | Information about messages +module Network.GRPC.Spec.MessageMeta ( + OutboundMeta(..) + , InboundMeta(..) + ) where + +import Data.Default +import Data.Word + +{------------------------------------------------------------------------------- + Outbound messages +-------------------------------------------------------------------------------} + +data OutboundMeta = OutboundMeta { + -- | Enable compression for this message + -- + -- Even if enabled, compression will only be used if this results in a + -- smaller message. + outboundEnableCompression :: Bool + } + deriving stock (Show) + +instance Default OutboundMeta where + def = OutboundMeta { + outboundEnableCompression = True + } + +{------------------------------------------------------------------------------- + Inbound messages +-------------------------------------------------------------------------------} + +data InboundMeta = InboundMeta { + -- | Size of the message in compressed form, /if/ it was compressed + inboundCompressedSize :: Maybe Word32 + + -- | Size of the message in uncompressed (but still serialized) form + , inboundUncompressedSize :: Word32 + } + deriving stock (Show) + diff --git a/src/Network/GRPC/Spec/Serialization.hs b/src/Network/GRPC/Spec/Serialization.hs new file mode 100644 index 00000000..3e0095cd --- /dev/null +++ b/src/Network/GRPC/Spec/Serialization.hs @@ -0,0 +1,61 @@ +module Network.GRPC.Spec.Serialization ( + -- * Messages + -- ** Inputs + buildInput + , parseInput + -- ** Outputs + , buildOutput + , parseOutput + -- ** Inbound + -- * Headers + -- ** Status + , buildGrpcStatus + , parseGrpcStatus + -- ** Pseudoheaders + , RawResourceHeaders(..) + , InvalidResourceHeaders(..) + , buildResourceHeaders + , parseResourceHeaders + -- ** RequestHeaders + , buildRequestHeaders + , parseRequestHeaders + , parseRequestHeaders' + -- *** Timeouts + , buildTimeout + , parseTimeout + -- *** OpenTelemetry + , buildTraceContext + , parseTraceContext + -- ** ResponseHeaders + , buildResponseHeaders + , parseResponseHeaders + , parseResponseHeaders' + -- *** Pushback + , buildPushback + , parsePushback + -- ** ProperTrailers + , buildProperTrailers + , parseProperTrailers + , parseProperTrailers' + -- ** TrailersOnly + , buildTrailersOnly + , parseTrailersOnly + , parseTrailersOnly' + -- ** Classify server response + , classifyServerResponse + -- ** Custom metadata + , parseCustomMetadata + , buildCustomMetadata + -- *** Binary values + , buildBinaryValue + , parseBinaryValue + ) where + +import Network.GRPC.Spec.Serialization.CustomMetadata +import Network.GRPC.Spec.Serialization.Headers.PseudoHeaders +import Network.GRPC.Spec.Serialization.Headers.Request +import Network.GRPC.Spec.Serialization.Headers.Response +import Network.GRPC.Spec.Serialization.LengthPrefixed +import Network.GRPC.Spec.Serialization.Status +import Network.GRPC.Spec.Serialization.Timeout +import Network.GRPC.Spec.Serialization.TraceContext diff --git a/src/Network/GRPC/Spec/Base64.hs b/src/Network/GRPC/Spec/Serialization/Base64.hs similarity index 98% rename from src/Network/GRPC/Spec/Base64.hs rename to src/Network/GRPC/Spec/Serialization/Base64.hs index 50e0f0a6..f7845742 100644 --- a/src/Network/GRPC/Spec/Base64.hs +++ b/src/Network/GRPC/Spec/Serialization/Base64.hs @@ -3,7 +3,7 @@ -- The gRPC specification mandates standard Base64-encoding for binary headers -- , /but/ without -- padding. -module Network.GRPC.Spec.Base64 ( +module Network.GRPC.Spec.Serialization.Base64 ( encodeBase64 , decodeBase64 ) where diff --git a/src/Network/GRPC/Spec/Serialization/CustomMetadata.hs b/src/Network/GRPC/Spec/Serialization/CustomMetadata.hs new file mode 100644 index 00000000..5e643340 --- /dev/null +++ b/src/Network/GRPC/Spec/Serialization/CustomMetadata.hs @@ -0,0 +1,134 @@ +{-# LANGUAGE OverloadedStrings #-} + +module Network.GRPC.Spec.Serialization.CustomMetadata ( + -- * HeaderName + buildHeaderName + , parseHeaderName + -- * AsciiValue + , buildAsciiValue + , parseAsciiValue + -- * BinaryValue + , buildBinaryValue + , parseBinaryValue + -- * CustomMetadata + , buildCustomMetadata + , parseCustomMetadata + ) where + +import Control.Monad +import Control.Monad.Except (MonadError(throwError)) +import Data.ByteString qualified as BS.Strict +import Data.ByteString qualified as Strict (ByteString) +import Data.CaseInsensitive (CI) +import Data.CaseInsensitive qualified as CI +import Data.List (intersperse) +import Network.HTTP.Types qualified as HTTP + +import Network.GRPC.Spec +import Network.GRPC.Spec.Serialization.Base64 +import Network.GRPC.Util.ByteString (ascii) + +{------------------------------------------------------------------------------- + HeaderName +-------------------------------------------------------------------------------} + +buildHeaderName :: HeaderName -> CI Strict.ByteString +buildHeaderName name = + case name of + BinaryHeader name' -> CI.mk name' + AsciiHeader name' -> CI.mk name' + +parseHeaderName :: MonadError String m => CI Strict.ByteString -> m HeaderName +parseHeaderName name = + case safeHeaderName (CI.foldedCase name) of + Nothing -> throwError $ "Invalid header name: " ++ show name + Just name' -> return name' + +{------------------------------------------------------------------------------- + AsciiValue +-------------------------------------------------------------------------------} + +buildAsciiValue :: Strict.ByteString -> Strict.ByteString +buildAsciiValue = id + +parseAsciiValue :: + MonadError String m + => Strict.ByteString -> m Strict.ByteString +parseAsciiValue bs = do + unless (isValidAsciiValue bs) $ + throwError $ "Invalid ASCII header: " ++ show bs + return bs + +{------------------------------------------------------------------------------- + BinaryValue +-------------------------------------------------------------------------------} + +buildBinaryValue :: Strict.ByteString -> Strict.ByteString +buildBinaryValue = encodeBase64 + +-- | Parse binary value +-- +-- The presence of duplicate headers makes this a bit subtle. Let's consider an +-- example. Suppose we have two duplicate headers +-- +-- > foo-bin: YWJj -- encoding of "abc" +-- > foo-bin: ZGVm -- encoding of "def" +-- +-- The spec says +-- +-- > Custom-Metadata header order is not guaranteed to be preserved except for +-- > values with duplicate header names. Duplicate header names may have their +-- > values joined with "," as the delimiter and be considered semantically +-- > equivalent. +-- +-- In @grapesy@ we will do the decoding of both headers /prior/ to joining +-- duplicate headers, and so the value we will reconstruct for @foo-bin@ is +-- \"abc,def\". +-- +-- However, suppose we deal with a (non-compliant) peer which is unaware of +-- binary headers and has applied the joining rule /without/ decoding: +-- +-- > foo-bin: YWJj,ZGVm +-- +-- The spec is a bit vague about this case, saying only: +-- +-- > Implementations must split Binary-Headers on "," before decoding the +-- > Base64-encoded values. +-- +-- Here we assume that this case must be treated the same way as if the headers +-- /had/ been decoded prior to joining. Therefore, we split the input on commas, +-- decode each result separately, and join the results with commas again. +parseBinaryValue :: forall m. + MonadError String m + => Strict.ByteString -> m Strict.ByteString +parseBinaryValue bs = do + let chunks = BS.Strict.split (ascii ',') bs + decoded <- mapM decode chunks + return $ mconcat $ intersperse "," decoded + where + decode :: Strict.ByteString -> m Strict.ByteString + decode chunk = + case decodeBase64 chunk of + Left err -> throwError err + Right val -> return val + +{------------------------------------------------------------------------------- + CustomMetadata +-------------------------------------------------------------------------------} + +buildCustomMetadata :: CustomMetadata -> HTTP.Header +buildCustomMetadata (CustomMetadata name value) = + case name of + BinaryHeader _ -> (buildHeaderName name, buildBinaryValue value) + AsciiHeader _ -> (buildHeaderName name, buildAsciiValue value) + +parseCustomMetadata :: + MonadError InvalidHeaders m + => HTTP.Header -> m CustomMetadata +parseCustomMetadata hdr@(name, value) = throwInvalidHeader hdr $ do + name' <- parseHeaderName name + value' <- case name' of + AsciiHeader _ -> parseAsciiValue value + BinaryHeader _ -> parseBinaryValue value + -- If parsing succeeds, that justifies the use of 'UnsafeCustomMetadata' + return $ CustomMetadata name' value' diff --git a/src/Network/GRPC/Spec/Serialization/Headers/Common.hs b/src/Network/GRPC/Spec/Serialization/Headers/Common.hs new file mode 100644 index 00000000..8ee96c67 --- /dev/null +++ b/src/Network/GRPC/Spec/Serialization/Headers/Common.hs @@ -0,0 +1,209 @@ +{-# LANGUAGE OverloadedStrings #-} + +module Network.GRPC.Spec.Serialization.Headers.Common ( + -- * Content type + buildContentType + , parseContentType + -- * Message type + , buildMessageType + , parseMessageType + -- * Message encoding + , buildMessageEncoding + , buildMessageAcceptEncoding + , parseMessageEncoding + , parseMessageAcceptEncoding + -- * Utilities + , trim + ) where + +import Control.Monad +import Control.Monad.Except +import Data.ByteString qualified as BS.Strict +import Data.ByteString qualified as Strict (ByteString) +import Data.ByteString.Char8 qualified as BS.Strict.C8 +import Data.Foldable (toList) +import Data.List (intersperse) +import Data.List.NonEmpty (NonEmpty(..)) +import Data.Proxy +import Data.Word +import Network.HTTP.Types qualified as HTTP + +import Network.GRPC.Spec +import Network.GRPC.Util.ByteString + +{------------------------------------------------------------------------------- + > Content-Type → + > "content-type" + > "application/grpc" + > [("+proto" / "+json" / {custom})] +-------------------------------------------------------------------------------} + +buildContentType :: + IsRPC rpc + => Proxy rpc + -> ContentType + -> HTTP.Header +buildContentType proxy contentType = ( + "content-type" + , case contentType of + ContentTypeDefault -> defaultContentType + ContentTypeOverride x -> x + ) + where + defaultContentType :: Strict.ByteString + defaultContentType = rpcContentType proxy + +parseContentType :: forall m rpc. + (MonadError InvalidHeaders m, IsRPC rpc) + => Proxy rpc + -> HTTP.Header + -> m ContentType +parseContentType proxy hdr@(_name, value) = do + if value == rpcContentType proxy then + return ContentTypeDefault + else do + -- Headers must be ASCII, justifying the use of BS.Strict.C8. + -- See . + -- The gRPC spec does not allow for quoted strings. + withoutPrefix <- + case BS.Strict.C8.stripPrefix "application/grpc" value of + Nothing -> err "Missing \"application/grpc\" prefix." + Just remainder -> return remainder + + -- The gRPC spec does not allow for any parameters. + when (';' `BS.Strict.C8.elem` withoutPrefix) $ + err "Unexpected parameter." + + -- Check format + -- + -- The only @format@ we should allow is @serializationFormat proxy@. + -- However, some non-conforming proxies use formats such as + -- @application/grpc+octet-stream@. We therefore ignore @format@ here. + if BS.Strict.C8.null withoutPrefix then + -- Accept "application/grpc" + return $ ContentTypeOverride value + else + case BS.Strict.C8.stripPrefix "+" withoutPrefix of + Just _format -> + -- Accept "application/grpc+" + return $ ContentTypeOverride value + Nothing -> + err "Invalid subtype." + where + err :: String -> m a + err reason = throwError $ invalidHeader hdr $ concat [ + reason + , " Expected \"" + , BS.Strict.C8.unpack $ + rpcContentType (Proxy @(UnknownRpc Nothing Nothing)) + , "\" or \"" + , BS.Strict.C8.unpack $ + rpcContentType proxy + , "\", with \"" + , "application/grpc+{other_format}" + , "\" also accepted." + ] + +{------------------------------------------------------------------------------- + > Message-Type → "grpc-message-type" {type name for message schema} +-------------------------------------------------------------------------------} + +buildMessageType :: + IsRPC rpc + => Proxy rpc + -> MessageType + -> Maybe HTTP.Header +buildMessageType proxy messageType = + case messageType of + MessageTypeDefault -> mkHeader <$> defaultMessageType + MessageTypeOverride x -> Just $ mkHeader x + where + defaultMessageType :: Maybe Strict.ByteString + defaultMessageType = rpcMessageType proxy + + mkHeader :: Strict.ByteString -> HTTP.Header + mkHeader = ("grpc-message-type",) + +-- | Parse message type +-- +-- We do not need the @grpc-message-type@ header in order to know the message +-- type, because the /path/ determines the service and method, and that in turn +-- determines the message type. Therefore, if the value is not what we expect, +-- we merely record this fact ('MessageTypeOverride') but don't otherwise do +-- anything differently. +parseMessageType :: forall rpc. + IsRPC rpc + => Proxy rpc + -> HTTP.Header + -> MessageType +parseMessageType proxy (_name, given) = + case rpcMessageType proxy of + Nothing -> + -- We expected no message type at all, but did get one + MessageTypeOverride given + Just expected -> + if expected == given + then MessageTypeDefault + else MessageTypeOverride given + +{------------------------------------------------------------------------------- + > Message-Encoding → "grpc-encoding" Content-Coding + > Content-Coding → "identity" / "gzip" / "deflate" / "snappy" / {custom} +-------------------------------------------------------------------------------} + +buildMessageEncoding :: CompressionId -> HTTP.Header +buildMessageEncoding compr = ( + "grpc-encoding" + , serializeCompressionId compr + ) + +parseMessageEncoding :: + MonadError InvalidHeaders m + => HTTP.Header + -> m CompressionId +parseMessageEncoding (_name, value) = + return $ deserializeCompressionId value + +{------------------------------------------------------------------------------- + > Message-Accept-Encoding → + > "grpc-accept-encoding" Content-Coding *("," Content-Coding) +-------------------------------------------------------------------------------} + +buildMessageAcceptEncoding :: NonEmpty CompressionId -> HTTP.Header +buildMessageAcceptEncoding compr = ( + "grpc-accept-encoding" + , mconcat . intersperse "," . map serializeCompressionId $ toList compr + ) + +parseMessageAcceptEncoding :: forall m. + MonadError InvalidHeaders m + => HTTP.Header + -> m (NonEmpty CompressionId) +parseMessageAcceptEncoding hdr@(_name, value) = + atLeastOne + . map (deserializeCompressionId . strip) + . BS.Strict.splitWith (== ascii ',') + $ value + where + atLeastOne :: forall a. [a] -> m (NonEmpty a) + atLeastOne (x : xs) = return (x :| xs) + atLeastOne [] = throwError $ invalidHeader hdr $ + "Expected at least one compresion ID" + +{------------------------------------------------------------------------------- + Utilities +-------------------------------------------------------------------------------} + +-- | Trim leading or trailing whitespace +-- +-- We only allow for space and tab, based on +-- . +trim :: Strict.ByteString -> Strict.ByteString +trim = ltrim . rtrim + where + ltrim, rtrim :: Strict.ByteString -> Strict.ByteString + ltrim = BS.Strict.dropWhile isSpace + rtrim = BS.Strict.dropWhileEnd isSpace + + isSpace :: Word8 -> Bool + isSpace x = x == 32 || x == 9 diff --git a/src/Network/GRPC/Spec/Serialization/Headers/PseudoHeaders.hs b/src/Network/GRPC/Spec/Serialization/Headers/PseudoHeaders.hs new file mode 100644 index 00000000..2fa555b7 --- /dev/null +++ b/src/Network/GRPC/Spec/Serialization/Headers/PseudoHeaders.hs @@ -0,0 +1,61 @@ +{-# LANGUAGE OverloadedStrings #-} + +module Network.GRPC.Spec.Serialization.Headers.PseudoHeaders ( + RawResourceHeaders(..) + , InvalidResourceHeaders(..) + , buildResourceHeaders + , parseResourceHeaders + ) where + +import Control.Monad.Except +import Data.ByteString qualified as BS.Strict +import Data.ByteString qualified as Strict (ByteString) + +import Network.GRPC.Spec +import Network.GRPC.Util.ByteString + +{------------------------------------------------------------------------------- + Serialization +-------------------------------------------------------------------------------} + +data RawResourceHeaders = RawResourceHeaders { + rawPath :: Strict.ByteString + , rawMethod :: Strict.ByteString + } + deriving (Show) + +data InvalidResourceHeaders = + InvalidMethod Strict.ByteString + | InvalidPath Strict.ByteString + deriving stock (Show) + +buildResourceHeaders :: ResourceHeaders -> RawResourceHeaders +buildResourceHeaders ResourceHeaders{resourcePath, resourceMethod} = + RawResourceHeaders { + rawMethod = case resourceMethod of Post -> "POST" + , rawPath = mconcat [ + "/" + , pathService resourcePath + , "/" + , pathMethod resourcePath + ] + } + +-- | Parse pseudo headers +parseResourceHeaders :: + RawResourceHeaders + -> Either InvalidResourceHeaders ResourceHeaders +parseResourceHeaders RawResourceHeaders{rawMethod, rawPath} = do + resourceMethod <- + case rawMethod of + "POST" -> return Post + _otherwise -> throwError $ InvalidMethod rawMethod + + resourcePath <- + case BS.Strict.split (ascii '/') rawPath of + ["", service, method] -> + return $ Path service method + _otherwise -> + throwError $ InvalidPath rawPath + + return ResourceHeaders{resourceMethod, resourcePath} diff --git a/src/Network/GRPC/Spec/Serialization/Headers/Request.hs b/src/Network/GRPC/Spec/Serialization/Headers/Request.hs new file mode 100644 index 00000000..f1a537f6 --- /dev/null +++ b/src/Network/GRPC/Spec/Serialization/Headers/Request.hs @@ -0,0 +1,275 @@ +{-# LANGUAGE OverloadedStrings #-} + +module Network.GRPC.Spec.Serialization.Headers.Request ( + buildRequestHeaders + , parseRequestHeaders + , parseRequestHeaders' + ) where + +import Control.Monad +import Control.Monad.Except (MonadError(throwError)) +import Control.Monad.State (State, execState, modify) +import Data.Bifunctor +import Data.ByteString qualified as Strict (ByteString) +import Data.ByteString.Char8 qualified as BS.Strict.C8 +import Data.Functor (($>)) +import Data.List (intercalate) +import Data.Maybe (catMaybes) +import Data.Proxy +import Network.HTTP.Types qualified as HTTP +import Text.Read (readMaybe) + +import Network.GRPC.Spec +import Network.GRPC.Spec.Serialization.CustomMetadata +import Network.GRPC.Spec.Serialization.Headers.Common +import Network.GRPC.Spec.Serialization.Timeout +import Network.GRPC.Spec.Serialization.TraceContext +import Network.GRPC.Util.HKD qualified as HKD + +{------------------------------------------------------------------------------- + Construction +-------------------------------------------------------------------------------} + +-- | Request headers +-- +-- > Request-Headers → +-- > Call-Definition +-- > *Custom-Metadata +buildRequestHeaders :: + IsRPC rpc + => Proxy rpc -> RequestHeaders -> [HTTP.Header] +buildRequestHeaders proxy callParams@RequestHeaders{requestMetadata} = concat [ + callDefinition proxy callParams + , map buildCustomMetadata $ customMetadataMapToList requestMetadata + ] + +-- | Call definition +-- +-- > Call-Definition → +-- > Method +-- > Scheme +-- > Path +-- > TE +-- > [Authority] +-- > [Timeout] +-- > Content-Type +-- > [Message-Type] +-- > [Message-Encoding] +-- > [Message-Accept-Encoding] +-- > [User-Agent] +-- +-- However, the spec additionally mandates that +-- +-- HTTP2 requires that reserved headers, ones starting with ":" appear +-- before all other headers. Additionally implementations should send +-- Timeout immediately after the reserved headers and they should send the +-- Call-Definition headers before sending Custom-Metadata. +-- +-- (Relevant part of the HTTP2 spec: +-- .) This means +-- @TE@ should come /after/ @Authority@ (if using). However, we will not include +-- the reserved headers here /at all/, as they are automatically added by +-- `http2`. +callDefinition :: forall rpc. + IsRPC rpc + => Proxy rpc -> RequestHeaders -> [HTTP.Header] +callDefinition proxy = \hdrs -> catMaybes [ + hdrTimeout <$> requestTimeout hdrs + , guard (requestIncludeTE hdrs) $> buildTe + , buildContentType proxy <$> requestContentType hdrs + , join $ buildMessageType proxy <$> requestMessageType hdrs + , buildMessageEncoding <$> requestCompression hdrs + , buildMessageAcceptEncoding <$> requestAcceptCompression hdrs + , buildUserAgent <$> requestUserAgent hdrs + , buildGrpcTraceBin <$> requestTraceContext hdrs + , buildPreviousRpcAttempts <$> requestPreviousRpcAttempts hdrs + ] + where + hdrTimeout :: Timeout -> HTTP.Header + hdrTimeout t = ("grpc-timeout", buildTimeout t) + + -- > TE → "te" "trailers" # Used to detect incompatible proxies + buildTe :: HTTP.Header + buildTe = ("te", "trailers") + + -- > User-Agent → "user-agent" {structured user-agent string} + -- + -- The spec says: + -- + -- While the protocol does not require a user-agent to function it is + -- recommended that clients provide a structured user-agent string that + -- provides a basic description of the calling library, version & platform + -- to facilitate issue diagnosis in heterogeneous environments. The + -- following structure is recommended to library developers + -- + -- > User-Agent → + -- > "grpc-" + -- > Language + -- > ?("-" Variant) + -- > "/" + -- > Version + -- > ?( " (" *(AdditionalProperty ";") ")" ) + buildUserAgent :: Strict.ByteString -> HTTP.Header + buildUserAgent userAgent = ( + "user-agent" + , userAgent + ) + + buildGrpcTraceBin :: TraceContext -> HTTP.Header + buildGrpcTraceBin ctxt = ( + "grpc-trace-bin" + , buildBinaryValue $ buildTraceContext ctxt + ) + + buildPreviousRpcAttempts :: Int -> HTTP.Header + buildPreviousRpcAttempts n = ( + "grpc-previous-rpc-attempts" + , BS.Strict.C8.pack $ show n + ) + +{------------------------------------------------------------------------------- + Parsing +-------------------------------------------------------------------------------} + +parseRequestHeaders :: forall rpc m. + (IsRPC rpc, MonadError InvalidRequestHeaders m) + => Proxy rpc + -> [HTTP.Header] -> m RequestHeaders +parseRequestHeaders proxy = HKD.sequenceThrow . parseRequestHeaders' proxy + +-- | Parse request headers +-- +-- This can report invalid headers on a per-header basis; see also +-- 'parseRequestHeaders'. +parseRequestHeaders' :: forall rpc. + IsRPC rpc + => Proxy rpc + -> [HTTP.Header] -> RequestHeaders' +parseRequestHeaders' proxy = + flip execState uninitRequestHeaders + . mapM_ (parseHeader . second trim) + where + parseHeader :: HTTP.Header -> State RequestHeaders' () + parseHeader hdr@(name, value) + | name == "user-agent" + = modify $ \x -> x { + requestUserAgent = return (Just value) + } + + | name == "grpc-timeout" + = modify $ \x -> x { + requestTimeout = fmap Just $ + httpError hdr HTTP.badRequest400 $ + parseTimeout value + } + + | name == "grpc-encoding" + = modify $ \x -> x { + requestCompression = fmap Just $ + first (HTTP.badRequest400,) $ + parseMessageEncoding hdr + } + + | name == "grpc-accept-encoding" + = modify $ \x -> x { + requestAcceptCompression = fmap Just $ + first (HTTP.badRequest400,) $ + parseMessageAcceptEncoding hdr + } + + | name == "grpc-trace-bin" + = modify $ \x -> x { + requestTraceContext = fmap Just $ + httpError hdr HTTP.badRequest400 $ + parseBinaryValue value >>= parseTraceContext + } + + | name == "content-type" + = modify $ \x -> x { + requestContentType = fmap Just $ + first (HTTP.unsupportedMediaType415,) $ + parseContentType proxy hdr + } + + | name == "grpc-message-type" + = modify $ \x -> x { + requestMessageType = return . Just $ + parseMessageType proxy hdr + } + + | name == "te" + = modify $ \x -> x { + requestIncludeTE = do + first (HTTP.badRequest400,) $ + expectHeaderValue hdr ["trailers"] + return True + } + + | name == "grpc-previous-rpc-attempts" + = modify $ \x -> x { + requestPreviousRpcAttempts = do + httpError hdr HTTP.badRequest400 $ + maybe + (Left $ "grpc-previous-rpc-attempts: invalid " ++ show value) + (Right . Just) + (readMaybe $ BS.Strict.C8.unpack value) + } + + | otherwise + = modify $ \x -> + case parseCustomMetadata hdr of + Left invalid -> x { + requestUnrecognized = Left $ + case requestUnrecognized x of + Left (status, invalid') -> (status, invalid <> invalid') + Right () -> (HTTP.badRequest400, invalid) + } + Right md -> x { + requestMetadata = customMetadataMapInsert md $ requestMetadata x + } + + uninitRequestHeaders :: RequestHeaders' + uninitRequestHeaders = RequestHeaders { + requestTimeout = return Nothing + , requestCompression = return Nothing + , requestAcceptCompression = return Nothing + , requestContentType = return Nothing + , requestIncludeTE = return False + , requestUserAgent = return Nothing + , requestTraceContext = return Nothing + , requestPreviousRpcAttempts = return Nothing + , requestMessageType = + -- If the default is that this header should be absent, then /start/ + -- with 'MessageTypeDefault'; if it happens to present, parse it as + -- an override. + case rpcMessageType proxy of + Nothing -> return $ Just MessageTypeDefault + Just _ -> return $ Nothing + , requestMetadata = mempty + , requestUnrecognized = return () + } + + httpError :: + MonadError InvalidRequestHeaders m' + => HTTP.Header -> HTTP.Status -> Either String a -> m' a + httpError _ _ (Right a) = return a + httpError hdr status (Left err) = throwError (status, invalidHeader hdr err) + +{------------------------------------------------------------------------------- + Internal auxiliary +-------------------------------------------------------------------------------} + +expectHeaderValue :: + MonadError InvalidHeaders m + => HTTP.Header -> [Strict.ByteString] -> m () +expectHeaderValue hdr@(_name, actual) expected = + unless (actual `elem` expected) $ + throwError $ invalidHeader hdr err + where + err :: String + err = concat [ + "Expected " + , intercalate " or " $ + map (\e -> "\"" ++ BS.Strict.C8.unpack e ++ "\"") expected + , "." + ] diff --git a/src/Network/GRPC/Spec/Serialization/Headers/Response.hs b/src/Network/GRPC/Spec/Serialization/Headers/Response.hs new file mode 100644 index 00000000..a220bca9 --- /dev/null +++ b/src/Network/GRPC/Spec/Serialization/Headers/Response.hs @@ -0,0 +1,513 @@ +{-# LANGUAGE CPP #-} +{-# LANGUAGE OverloadedStrings #-} + +-- | Deal with HTTP2 responses +-- +-- Intended for unqualified import. +module Network.GRPC.Spec.Serialization.Headers.Response ( + -- * ResponseHeaders + buildResponseHeaders + , parseResponseHeaders + , parseResponseHeaders' + -- * ProperTrailers + , buildProperTrailers + , parseProperTrailers + , parseProperTrailers' + -- * TrailersOnly + , buildTrailersOnly + , parseTrailersOnly + , parseTrailersOnly' + -- * Classify server response + , classifyServerResponse + -- * Pushback + , buildPushback + , parsePushback + ) where + +import Control.Monad.Except +import Control.Monad.State +import Data.Bifunctor +import Data.ByteString qualified as BS.Strict +import Data.ByteString qualified as Strict (ByteString) +import Data.ByteString.Char8 qualified as BS.Strict.C8 +import Data.ByteString.Lazy qualified as BS.Lazy +import Data.ByteString.Lazy qualified as Lazy (ByteString) +import Data.CaseInsensitive qualified as CI +import Data.Maybe (isJust) +import Data.Proxy +import Data.Text (Text) +import Data.Text qualified as Text +import Data.Text.Encoding qualified as Text +import Network.HTTP.Types qualified as HTTP +import Text.Read (readMaybe) + +#if !MIN_VERSION_text(2,0,0) +import Data.Text.Encoding.Error qualified as Text +#endif + +import Network.GRPC.Spec +import Network.GRPC.Spec.PercentEncoding qualified as PercentEncoding +import Network.GRPC.Spec.Serialization.CustomMetadata +import Network.GRPC.Spec.Serialization.Headers.Common +import Network.GRPC.Spec.Serialization.Status +import Network.GRPC.Util.HKD qualified as HKD +import Network.GRPC.Util.Protobuf qualified as Protobuf + +{------------------------------------------------------------------------------- + Classify server response +-------------------------------------------------------------------------------} + +-- | Classify server response +-- +-- gRPC servers are supposed to respond with HTTP status @200 OK@ no matter +-- whether the call was successful or not; if not successful, the information +-- about the failure should be reported using @grpc-status@ and related headers +-- (@grpc-message@, @grpc-status-details-bin@). +-- +-- The gRPC spec mandates that if we get a non-200 status from a broken +-- deployment, we synthesize a gRPC exception with an appropriate status and +-- status message. The spec itself does not provide any guidance on what such an +-- appropriate status would look like, but the official gRPC repo does provide a +-- partial mapping between HTTP status codes and gRPC status codes at +-- . +-- This is the mapping we implement here. +classifyServerResponse :: forall rpc. + IsRPC rpc + => Proxy rpc + -> HTTP.Status -- ^ HTTP status + -> [HTTP.Header] -- ^ Headers + -> Maybe Lazy.ByteString -- ^ Response body, if known (used for errors only) + -> Either TrailersOnly' ResponseHeaders' +classifyServerResponse rpc status headers mBody + -- The "HTTP to gRPC Status Code Mapping" is explicit: + -- + -- > (..) to be used only for clients that received a response that did not + -- > include grpc-status. If grpc-status was provided, it must be used. + -- + -- Therefore if @grpc-status@ is present, we ignore the HTTP status. + | hasGrpcStatus headers + = Left $ parseTrailersOnly' rpc headers + + | 200 <- statusCode + = Right $ parseResponseHeaders' rpc headers + + | otherwise + = Left $ + case statusCode of + 400 -> synthesize GrpcInternal -- Bad request + 401 -> synthesize GrpcUnauthenticated -- Unauthorized + 403 -> synthesize GrpcPermissionDenied -- Forbidden + 404 -> synthesize GrpcUnimplemented -- Not found + 429 -> synthesize GrpcUnavailable -- Too many requests + 502 -> synthesize GrpcUnavailable -- Bad gateway + 503 -> synthesize GrpcUnavailable -- Service unavailable + 504 -> synthesize GrpcUnavailable -- Gateway timeout + _ -> synthesize GrpcUnknown + where + HTTP.Status{statusCode, statusMessage} = status + + -- The @grpc-status@ header not present, and HTTP status not @200 OK@. + -- We classify the response as an error response (hence 'TrailersOnly''): + -- + -- * We set 'properTrailersGrpcStatus' based on the HTTP status. + -- * We leave 'properTrailersGrpcMessage' alone if @grpc-message@ present + -- and valid, and replace it with a default message otherwise. + -- + -- The resulting 'TrailersOnly'' cannot contain any parse errors + -- (only @grpc-status@ is required, and only @grpc-message@ can fail). + synthesize :: GrpcError -> TrailersOnly' + synthesize err = parsed { + trailersOnlyProper = parsedTrailers { + properTrailersGrpcStatus = Right $ + GrpcError err + , properTrailersGrpcMessage = Right $ + case properTrailersGrpcMessage parsedTrailers of + Right (Just msg) -> Just msg + _otherwise -> Just defaultMsg + } + } + + where + parsed :: TrailersOnly' + parsed = parseTrailersOnly' rpc headers + + parsedTrailers :: ProperTrailers' + parsedTrailers = trailersOnlyProper parsed + + defaultMsg :: Text + defaultMsg = mconcat [ + "Unexpected HTTP status code " + , Text.pack (show statusCode) + , if not (BS.Strict.null statusMessage) + then " (" <> decodeUtf8Lenient statusMessage <> ")" + else mempty + , case mBody of + Just body | not (BS.Lazy.null body) -> mconcat [ + "\nResponse body:\n" + , decodeUtf8Lenient (BS.Lazy.toStrict body) + ] + _otherwise -> + mempty + ] + +-- | Is the @grpc-status@ header set? +-- +-- We use this as a proxy to determine if we are in the Trailers-Only case. +-- +-- It might be tempting to use the HTTP @Content-Length@ header instead, but +-- this is doubly wrong: +-- +-- * There might be servers who use the Trailers-Only case but do not set the +-- @Content-Length@ header (although such a server would not conform to the +-- HTTP spec: "An origin server SHOULD send a @Content-Length@ header field +-- when the content size is known prior to sending the complete header +-- section"; see +-- ). +-- * Conversely, there might be servers or proxies who /do/ set @Content-Length@ +-- header even when it's /not/ the Trailers-Only case (e.g., see +-- or +-- ). +-- +-- We therefore check for the presence of the @grpc-status@ header instead. +hasGrpcStatus :: [HTTP.Header] -> Bool +hasGrpcStatus = isJust . lookup "grpc-status" + +{------------------------------------------------------------------------------- + > Response-Headers → + > HTTP-Status + > [Message-Encoding] + > [Message-Accept-Encoding] + > Content-Type + > *Custom-Metadata + + We do not deal with @HTTP-Status@ here; @http2@ deals this separately. +-------------------------------------------------------------------------------} + +-- | Build response headers +buildResponseHeaders :: forall rpc. + SupportsServerRpc rpc + => Proxy rpc -> ResponseHeaders -> [HTTP.Header] +buildResponseHeaders proxy + ResponseHeaders{ responseCompression + , responseAcceptCompression + , responseMetadata + , responseContentType + } = concat [ + [ buildContentType proxy x + | Just x <- [responseContentType] + ] + , [ buildMessageEncoding x + | Just x <- [responseCompression] + ] + , [ buildMessageAcceptEncoding x + | Just x <- [responseAcceptCompression] + ] + , [ buildTrailer proxy ] + , [ buildCustomMetadata x + | x <- customMetadataMapToList responseMetadata + ] + ] + +-- | Parse response headers +parseResponseHeaders :: forall rpc m. + (IsRPC rpc, MonadError InvalidHeaders m) + => Proxy rpc -> [HTTP.Header] -> m ResponseHeaders +parseResponseHeaders proxy = HKD.sequenceThrow . parseResponseHeaders' proxy + +parseResponseHeaders' :: forall rpc. + IsRPC rpc + => Proxy rpc -> [HTTP.Header] -> ResponseHeaders' +parseResponseHeaders' proxy = + flip execState uninitResponseHeaders + . mapM_ (parseHeader . second trim) + where + -- HTTP2 header names are always lowercase, and must be ASCII. + -- + parseHeader :: HTTP.Header -> State ResponseHeaders' () + parseHeader hdr@(name, _value) + | name == "content-type" + = modify $ \x -> x { + responseContentType = Just <$> parseContentType proxy hdr + } + + | name == "grpc-encoding" + = modify $ \x -> x { + responseCompression = Just <$> parseMessageEncoding hdr + } + + | name == "grpc-accept-encoding" + = modify $ \x -> x { + responseAcceptCompression = Just <$> parseMessageAcceptEncoding hdr + } + + | name == "trailer" + = return () -- ignore the HTTP trailer header + + | otherwise + = modify $ \x -> + case parseCustomMetadata hdr of + Left invalid -> x{ + responseUnrecognized = Left $ mconcat [ + invalid + , otherInvalid $ responseUnrecognized x + ] + } + Right md -> x{ + responseMetadata = + customMetadataMapInsert md $ responseMetadata x + } + + uninitResponseHeaders :: ResponseHeaders' + uninitResponseHeaders = ResponseHeaders { + responseCompression = return Nothing + , responseAcceptCompression = return Nothing + , responseContentType = return Nothing + , responseMetadata = mempty + , responseUnrecognized = return () + } + +{------------------------------------------------------------------------------- + > Trailers → Status [Status-Message] *Custom-Metadata +-------------------------------------------------------------------------------} + +-- | Construct the HTTP 'Trailer' header +-- +-- This lists all headers that /might/ be present in the trailers. +-- +-- See +-- +-- * +-- * +buildTrailer :: forall rpc. SupportsServerRpc rpc => Proxy rpc -> HTTP.Header +buildTrailer _ = ( + "Trailer" + , BS.Strict.intercalate ", " allPotentialTrailers + ) + where + allPotentialTrailers :: [Strict.ByteString] + allPotentialTrailers = concat [ + reservedTrailers + , map (CI.original . buildHeaderName) $ + metadataHeaderNames (Proxy @(ResponseTrailingMetadata rpc)) + ] + + -- These cannot be 'HeaderName' (which disallow reserved names) + -- + -- This list must match the names used by 'buildProperTrailers' + -- and recognized by 'parseProperTrailers'. + reservedTrailers :: [Strict.ByteString] + reservedTrailers = [ + "grpc-status" + , "grpc-message" + , "grpc-retry-pushback-ms" + , "endpoint-load-metrics-bin" + ] + +-- | Build trailers (see 'buildTrailersOnly' for the Trailers-Only case) +-- +-- NOTE: If we add additional (reserved) headers here, we also need to add them +-- to 'buildTrailer'. +buildProperTrailers :: ProperTrailers -> [HTTP.Header] +buildProperTrailers ProperTrailers{ + properTrailersGrpcStatus + , properTrailersGrpcMessage + , properTrailersMetadata + , properTrailersPushback + , properTrailersOrcaLoadReport + } = concat [ + [ ( "grpc-status" + , BS.Strict.C8.pack $ show $ buildGrpcStatus properTrailersGrpcStatus + ) + ] + , [ ("grpc-message", PercentEncoding.encode x) + | Just x <- [properTrailersGrpcMessage] + ] + , [ ( "grpc-retry-pushback-ms" + , buildPushback x + ) + | Just x <- [properTrailersPushback] + ] + , [ ( "endpoint-load-metrics-bin" + , buildBinaryValue $ Protobuf.buildStrict x + ) + | Just x <- [properTrailersOrcaLoadReport] + ] + , [ buildCustomMetadata x + | x <- customMetadataMapToList properTrailersMetadata + ] + ] + +-- | Build trailers for the Trailers-Only case +buildTrailersOnly :: IsRPC rpc => Proxy rpc -> TrailersOnly -> [HTTP.Header] +buildTrailersOnly proxy TrailersOnly{ + trailersOnlyContentType + , trailersOnlyProper + } = concat [ + [ buildContentType proxy x + | Just x <- [trailersOnlyContentType] + ] + , buildProperTrailers trailersOnlyProper + ] + +-- | Parse response trailers +-- +-- The gRPC spec defines: +-- +-- > Trailers → Status [Status-Message] *Custom-Metadata +-- > Trailers-Only → HTTP-Status Content-Type Trailers +-- +-- This means that Trailers-Only is a superset of the Trailers; we make use of +-- this here, and error out if we get an unexpected @Content-Type@ override. +parseProperTrailers :: forall rpc m. + (IsRPC rpc, MonadError InvalidHeaders m) + => Proxy rpc -> [HTTP.Header] -> m ProperTrailers +parseProperTrailers proxy = HKD.sequenceThrow . parseProperTrailers' proxy + +parseProperTrailers' :: forall rpc. + IsRPC rpc + => Proxy rpc -> [HTTP.Header] -> ProperTrailers' +parseProperTrailers' proxy hdrs = + case trailersOnlyToProperTrailers trailersOnly of + (properTrailers, Right Nothing) -> + properTrailers + (properTrailers, Right (Just _ct)) -> + properTrailers { + properTrailersUnrecognized = Left $ mconcat [ + unexpectedHeader "content-type" + , otherInvalid $ properTrailersUnrecognized properTrailers + ] + } + (properTrailers, Left invalid) -> + -- The @content-type@ header is present, /and/ invalid! + properTrailers { + properTrailersUnrecognized = Left $ mconcat [ + unexpectedHeader "content-type" + , invalid + , otherInvalid $ properTrailersUnrecognized properTrailers + ] + } + where + trailersOnly :: TrailersOnly' + trailersOnly = parseTrailersOnly' proxy hdrs + +parseTrailersOnly :: forall m rpc. + (IsRPC rpc, MonadError InvalidHeaders m) + => Proxy rpc -> [HTTP.Header] -> m TrailersOnly +parseTrailersOnly proxy = HKD.sequenceThrow . parseTrailersOnly' proxy + +parseTrailersOnly' :: forall rpc. + IsRPC rpc + => Proxy rpc -> [HTTP.Header] -> TrailersOnly' +parseTrailersOnly' proxy = + flip execState uninitTrailersOnly + . mapM_ (parseHeader . second trim) + where + parseHeader :: HTTP.Header -> State TrailersOnly' () + parseHeader hdr@(name, value) + | name == "content-type" + = modify $ \x -> x { + trailersOnlyContentType = Just <$> parseContentType proxy hdr + } + + | name == "grpc-status" + = modify $ liftProperTrailers $ \x -> x{ + properTrailersGrpcStatus = throwInvalidHeader hdr $ + case parseGrpcStatus =<< readMaybe (BS.Strict.C8.unpack value) of + Nothing -> throwError $ "Invalid status: " ++ show value + Just v -> return v + } + + | name == "grpc-message" + = modify $ liftProperTrailers $ \x -> x{ + properTrailersGrpcMessage = throwInvalidHeader hdr $ + case PercentEncoding.decode value of + Left err -> throwError $ show err + Right msg -> return (Just msg) + } + + | name == "grpc-retry-pushback-ms" + = modify $ liftProperTrailers $ \x -> x{ + properTrailersPushback = + Just <$> parsePushback value + } + + | name == "endpoint-load-metrics-bin" + = modify $ liftProperTrailers $ \x -> x{ + properTrailersOrcaLoadReport = throwInvalidHeader hdr $ do + value' <- parseBinaryValue value + case Protobuf.parseStrict value' of + Left err -> throwError err + Right report -> return $ Just report + } + + | otherwise + = modify $ liftProperTrailers $ \x -> + case parseCustomMetadata hdr of + Left invalid -> x{ + properTrailersUnrecognized = Left $ mconcat [ + invalid + , otherInvalid $ properTrailersUnrecognized x + ] + } + Right md -> x{ + properTrailersMetadata = + customMetadataMapInsert md $ properTrailersMetadata x + } + + uninitTrailersOnly :: TrailersOnly' + uninitTrailersOnly = TrailersOnly { + trailersOnlyContentType = return Nothing + , trailersOnlyProper = simpleProperTrailers + (throwError $ missingHeader "grpc-status") + (return Nothing) + mempty + } + + liftProperTrailers :: + (ProperTrailers_ f -> ProperTrailers_ f) + -> TrailersOnly_ f -> TrailersOnly_ f + liftProperTrailers f trailersOnly = trailersOnly{ + trailersOnlyProper = f (trailersOnlyProper trailersOnly) + } + +{------------------------------------------------------------------------------- + Pushback +-------------------------------------------------------------------------------} + +buildPushback :: Pushback -> Strict.ByteString +buildPushback (RetryAfter n) = BS.Strict.C8.pack $ show n +buildPushback DoNotRetry = "-1" + +-- | Parse 'Pushback' +-- +-- Parsing a pushback cannot fail; the spec mandates: +-- +-- > If the value for pushback is negative or unparseble, then it will be seen +-- > as the server asking the client not to retry at all. +-- +-- We therefore only require @Monad m@, not @MonadError m@ (having the @Monad@ +-- constraint at all keeps the type signature consistent with other parsing +-- functions). +parsePushback :: Monad m => Strict.ByteString -> m Pushback +parsePushback bs = + case readMaybe (BS.Strict.C8.unpack bs) of + Just (n :: Int) -> + -- The @Read@ instance for @Word@ /does/ allow for signs + -- + return $ if n < 0 then DoNotRetry else RetryAfter (fromIntegral n) + Nothing -> + return DoNotRetry + +{------------------------------------------------------------------------------- + Internal auxiliary +-------------------------------------------------------------------------------} + +otherInvalid :: Either InvalidHeaders () -> InvalidHeaders +otherInvalid = either id (\() -> mempty) + +decodeUtf8Lenient :: BS.Strict.C8.ByteString -> Text +#if MIN_VERSION_text(2,0,0) +decodeUtf8Lenient = Text.decodeUtf8Lenient +#else +decodeUtf8Lenient = Text.decodeUtf8With Text.lenientDecode +#endif diff --git a/src/Network/GRPC/Spec/LengthPrefixed.hs b/src/Network/GRPC/Spec/Serialization/LengthPrefixed.hs similarity index 76% rename from src/Network/GRPC/Spec/LengthPrefixed.hs rename to src/Network/GRPC/Spec/Serialization/LengthPrefixed.hs index 4e82b707..aea3e838 100644 --- a/src/Network/GRPC/Spec/LengthPrefixed.hs +++ b/src/Network/GRPC/Spec/Serialization/LengthPrefixed.hs @@ -1,16 +1,16 @@ -- | Length-prefixed messages -- -- These are used both for inputs and outputs. -module Network.GRPC.Spec.LengthPrefixed ( +module Network.GRPC.Spec.Serialization.LengthPrefixed ( -- * Message prefix MessagePrefix(..) -- * Length-prefixex messages -- ** Construction - , OutboundEnvelope(..) + , OutboundMeta(..) , buildInput , buildOutput -- ** Parsing - , InboundEnvelope(..) + , InboundMeta(..) , parseInput , parseOutput ) where @@ -21,12 +21,10 @@ import Data.ByteString.Builder (Builder) import Data.ByteString.Builder qualified as Builder import Data.ByteString.Lazy qualified as BS.Lazy import Data.ByteString.Lazy qualified as Lazy (ByteString) -import Data.Default import Data.Proxy import Data.Word -import Network.GRPC.Spec.Compression -import Network.GRPC.Spec.RPC +import Network.GRPC.Spec import Network.GRPC.Util.Parser (Parser) import Network.GRPC.Util.Parser qualified as Parser @@ -59,20 +57,6 @@ getMessagePrefix = do Construction -------------------------------------------------------------------------------} -data OutboundEnvelope = OutboundEnvelope { - -- | Enable compression for this message - -- - -- Even if enabled, compression will only be used if this results in a - -- smaller message. - outboundEnableCompression :: Bool - } - deriving stock (Show) - -instance Default OutboundEnvelope where - def = OutboundEnvelope { - outboundEnableCompression = True - } - -- | Serialize RPC input -- -- > Length-Prefixed-Message → Compressed-Flag Message-Length Message @@ -86,7 +70,7 @@ buildInput :: SupportsClientRpc rpc => Proxy rpc -> Compression - -> (OutboundEnvelope, Input rpc) + -> (OutboundMeta, Input rpc) -> Builder buildInput = buildMsg . rpcSerializeInput @@ -95,7 +79,7 @@ buildOutput :: SupportsServerRpc rpc => Proxy rpc -> Compression - -> (OutboundEnvelope, Output rpc) + -> (OutboundMeta, Output rpc) -> Builder buildOutput = buildMsg . rpcSerializeOutput @@ -103,9 +87,9 @@ buildOutput = buildMsg . rpcSerializeOutput buildMsg :: (x -> Lazy.ByteString) -> Compression - -> (OutboundEnvelope, x) + -> (OutboundMeta, x) -> Builder -buildMsg build compr (envelope, x) = mconcat [ +buildMsg build compr (meta, x) = mconcat [ buildMessagePrefix prefix , Builder.lazyByteString $ if shouldCompress @@ -120,7 +104,7 @@ buildMsg build compr (envelope, x) = mconcat [ shouldCompress :: Bool shouldCompress = and [ uncompressedSizeThreshold compr uncompressedLength - , outboundEnableCompression envelope + , outboundEnableCompression meta , compressedLength < uncompressedLength ] where @@ -145,55 +129,46 @@ buildMsg build compr (envelope, x) = mconcat [ Parsing -------------------------------------------------------------------------------} -data InboundEnvelope = InboundEnvelope { - -- | Size of the message in compressed form, /if/ it was compressed - inboundCompressedSize :: Maybe Word32 - - -- | Size of the message in uncompressed (but still serialized) form - , inboundUncompressedSize :: Word32 - } - deriving stock (Show) - parseInput :: SupportsServerRpc rpc => Proxy rpc -> Compression - -> Parser String (InboundEnvelope, Input rpc) + -> Parser String (InboundMeta, Input rpc) parseInput = parseMsg . rpcDeserializeInput parseOutput :: SupportsClientRpc rpc => Proxy rpc -> Compression - -> Parser String (InboundEnvelope, Output rpc) + -> Parser String (InboundMeta, Output rpc) parseOutput = parseMsg . rpcDeserializeOutput parseMsg :: forall x. (Lazy.ByteString -> Either String x) -> Compression - -> Parser String (InboundEnvelope, x) + -> Parser String (InboundMeta, x) parseMsg parse compr = do prefix <- Parser.getExactly 5 getMessagePrefix Parser.consumeExactly (fromIntegral $ msgLength prefix) $ parseBody (msgIsCompressed prefix) where - parseBody :: Bool -> Lazy.ByteString -> Either String (InboundEnvelope, x) + parseBody :: Bool -> Lazy.ByteString -> Either String (InboundMeta, x) parseBody False body = - (envelope,) <$> parse body + (meta,) <$> parse body where - envelope :: InboundEnvelope - envelope = InboundEnvelope { + meta :: InboundMeta + meta = InboundMeta { inboundCompressedSize = Nothing , inboundUncompressedSize = lengthOf body } parseBody True compressed = - (envelope,) <$> parse uncompressed + (meta,) <$> parse uncompressed where uncompressed :: Lazy.ByteString uncompressed = decompress compr compressed - envelope :: InboundEnvelope - envelope = InboundEnvelope { + meta :: InboundMeta + meta = InboundMeta { inboundCompressedSize = Just (lengthOf compressed) , inboundUncompressedSize = lengthOf uncompressed } diff --git a/src/Network/GRPC/Spec/Serialization/Status.hs b/src/Network/GRPC/Spec/Serialization/Status.hs new file mode 100644 index 00000000..8a2a09d5 --- /dev/null +++ b/src/Network/GRPC/Spec/Serialization/Status.hs @@ -0,0 +1,50 @@ +module Network.GRPC.Spec.Serialization.Status ( + buildGrpcStatus + , parseGrpcStatus + ) where + +import Network.GRPC.Spec + +{------------------------------------------------------------------------------- + Serialization +-------------------------------------------------------------------------------} + +buildGrpcStatus :: GrpcStatus -> Word +buildGrpcStatus GrpcOk = 0 +buildGrpcStatus (GrpcError GrpcCancelled) = 1 +buildGrpcStatus (GrpcError GrpcUnknown) = 2 +buildGrpcStatus (GrpcError GrpcInvalidArgument) = 3 +buildGrpcStatus (GrpcError GrpcDeadlineExceeded) = 4 +buildGrpcStatus (GrpcError GrpcNotFound) = 5 +buildGrpcStatus (GrpcError GrpcAlreadyExists) = 6 +buildGrpcStatus (GrpcError GrpcPermissionDenied) = 7 +buildGrpcStatus (GrpcError GrpcResourceExhausted) = 8 +buildGrpcStatus (GrpcError GrpcFailedPrecondition) = 9 +buildGrpcStatus (GrpcError GrpcAborted) = 10 +buildGrpcStatus (GrpcError GrpcOutOfRange) = 11 +buildGrpcStatus (GrpcError GrpcUnimplemented) = 12 +buildGrpcStatus (GrpcError GrpcInternal) = 13 +buildGrpcStatus (GrpcError GrpcUnavailable) = 14 +buildGrpcStatus (GrpcError GrpcDataLoss) = 15 +buildGrpcStatus (GrpcError GrpcUnauthenticated) = 16 + +parseGrpcStatus :: Word -> Maybe GrpcStatus +parseGrpcStatus 0 = Just $ GrpcOk +parseGrpcStatus 1 = Just $ GrpcError $ GrpcCancelled +parseGrpcStatus 2 = Just $ GrpcError $ GrpcUnknown +parseGrpcStatus 3 = Just $ GrpcError $ GrpcInvalidArgument +parseGrpcStatus 4 = Just $ GrpcError $ GrpcDeadlineExceeded +parseGrpcStatus 5 = Just $ GrpcError $ GrpcNotFound +parseGrpcStatus 6 = Just $ GrpcError $ GrpcAlreadyExists +parseGrpcStatus 7 = Just $ GrpcError $ GrpcPermissionDenied +parseGrpcStatus 8 = Just $ GrpcError $ GrpcResourceExhausted +parseGrpcStatus 9 = Just $ GrpcError $ GrpcFailedPrecondition +parseGrpcStatus 10 = Just $ GrpcError $ GrpcAborted +parseGrpcStatus 11 = Just $ GrpcError $ GrpcOutOfRange +parseGrpcStatus 12 = Just $ GrpcError $ GrpcUnimplemented +parseGrpcStatus 13 = Just $ GrpcError $ GrpcInternal +parseGrpcStatus 14 = Just $ GrpcError $ GrpcUnavailable +parseGrpcStatus 15 = Just $ GrpcError $ GrpcDataLoss +parseGrpcStatus 16 = Just $ GrpcError $ GrpcUnauthenticated +parseGrpcStatus _ = Nothing + diff --git a/src/Network/GRPC/Spec/Serialization/Timeout.hs b/src/Network/GRPC/Spec/Serialization/Timeout.hs new file mode 100644 index 00000000..26108b28 --- /dev/null +++ b/src/Network/GRPC/Spec/Serialization/Timeout.hs @@ -0,0 +1,76 @@ +{-# LANGUAGE OverloadedStrings #-} + +module Network.GRPC.Spec.Serialization.Timeout ( + buildTimeout + , parseTimeout + ) where + +import Control.Monad.Except +import Data.ByteString qualified as BS.Strict +import Data.ByteString qualified as Strict (ByteString) +import Data.ByteString.Char8 qualified as BS.Strict.C8 +import Data.Char (isDigit) + +import Network.GRPC.Spec + +{------------------------------------------------------------------------------- + Serialization + + > Timeout → "grpc-timeout" TimeoutValue TimeoutUnit + > TimeoutValue → {positive integer as ASCII string of at most 8 digits} + > TimeoutUnit → Hour / Minute / Second / Millisecond / Microsecond / Nanosecond + > Hour → "H" + > Minute → "M" + > Second → "S" + > Millisecond → "m" + > Microsecond → "u" + > Nanosecond → "n" +-------------------------------------------------------------------------------} + +buildTimeout :: Timeout -> Strict.ByteString +buildTimeout (Timeout unit val) = mconcat [ + BS.Strict.C8.pack $ show $ getTimeoutValue val + , case unit of + Hour -> "H" + Minute -> "M" + Second -> "S" + Millisecond -> "m" + Microsecond -> "u" + Nanosecond -> "n" + ] + +parseTimeout :: forall m. + MonadError String m + => Strict.ByteString -> m Timeout +parseTimeout bs = do + let (bsVal, bsUnit) = BS.Strict.C8.span isDigit bs + + val <- + if BS.Strict.length bsVal < 1 || BS.Strict.length bsVal > 8 + then invalid + else return . TimeoutValue $ read (BS.Strict.C8.unpack bsVal) + + charUnit <- + case BS.Strict.C8.uncons bsUnit of + Nothing -> + invalid + Just (u, remainder) -> + if BS.Strict.null remainder + then return u + else invalid + + unit <- + case charUnit of + 'H' -> return Hour + 'M' -> return Minute + 'S' -> return Second + 'm' -> return Millisecond + 'u' -> return Microsecond + 'n' -> return Nanosecond + _ -> invalid + + return $ Timeout unit val + where + invalid :: m a + invalid = throwError $ "Could not parse timeout " ++ show bs + diff --git a/src/Network/GRPC/Spec/Serialization/TraceContext.hs b/src/Network/GRPC/Spec/Serialization/TraceContext.hs new file mode 100644 index 00000000..9872e8dc --- /dev/null +++ b/src/Network/GRPC/Spec/Serialization/TraceContext.hs @@ -0,0 +1,128 @@ +{-# OPTIONS_GHC -Wno-orphans #-} + +module Network.GRPC.Spec.Serialization.TraceContext ( + buildTraceContext + , parseTraceContext + ) where + +import Control.Applicative (many) +import Control.Monad.Except +import Data.Binary (Binary(..)) +import Data.Binary qualified as Binary +import Data.Binary.Get qualified as Get +import Data.Binary.Put qualified as Put +import Data.ByteString qualified as Strict (ByteString) +import Data.ByteString.Lazy qualified as BS.Lazy +import Data.Default +import Data.Maybe (maybeToList) +import Data.Word + +import Network.GRPC.Spec + +{------------------------------------------------------------------------------- + Serialization +-------------------------------------------------------------------------------} + +buildTraceContext :: TraceContext -> Strict.ByteString +buildTraceContext = BS.Lazy.toStrict . Binary.encode + +parseTraceContext :: MonadError String m => Strict.ByteString -> m TraceContext +parseTraceContext bs = + case Binary.decodeOrFail (BS.Lazy.fromStrict bs) of + Right (_, _, ctxt) -> return ctxt + Left (_, _, err) -> throwError err + +{------------------------------------------------------------------------------- + Internal auxiliary: parsing + + +-------------------------------------------------------------------------------} + +instance Binary TraceId where + put = Put.putByteString . getTraceId + get = TraceId <$> Get.getByteString 16 + +instance Binary SpanId where + put = Put.putByteString . getSpanId + get = SpanId <$> Get.getByteString 8 + +instance Binary TraceOptions where + put = Put.putWord8 . traceOptionsToWord8 + get = traceOptionsFromWord8 =<< Get.getWord8 + +instance Binary Field where + put (FieldTraceId tid) = Put.putWord8 0 <> put tid + put (FieldSpanId sid) = Put.putWord8 1 <> put sid + put (FieldOptions opts) = Put.putWord8 2 <> put opts + + get = do + fieldId <- Get.getWord8 + case fieldId of + 0 -> FieldTraceId <$> get + 1 -> FieldSpanId <$> get + 2 -> FieldOptions <$> get + _ -> fail $ "Invalid fieldId " ++ show fieldId + +instance Binary TraceContext where + put ctxt = mconcat [ + Put.putWord8 0 -- Version 0 + , foldMap put (traceContextToFields ctxt) + ] + + get = do + version <- Get.getWord8 + case version of + 0 -> traceContextFromFields =<< many get + _ -> fail $ "Invalid version " ++ show version + +{------------------------------------------------------------------------------- + Internal: fields +-------------------------------------------------------------------------------} + +data Field = + FieldTraceId TraceId + | FieldSpanId SpanId + | FieldOptions TraceOptions + +traceContextToFields :: TraceContext -> [Field] +traceContextToFields (TraceContext tid sid opts) = concat [ + FieldTraceId <$> maybeToList tid + , FieldSpanId <$> maybeToList sid + , FieldOptions <$> maybeToList opts + ] + +traceContextFromFields :: forall m. MonadFail m => [Field] -> m TraceContext +traceContextFromFields = flip go def + where + go :: [Field] -> TraceContext -> m TraceContext + go [] acc = return acc + go (f:fs) acc = + case f of + FieldTraceId tid -> + case traceContextTraceId acc of + Nothing -> go fs $ acc{traceContextTraceId = Just tid} + Just _ -> fail "Multiple TraceId fields" + FieldSpanId sid -> + case traceContextSpanId acc of + Nothing -> go fs $ acc{traceContextSpanId = Just sid} + Just _ -> fail "Multiple SpanId fields" + FieldOptions opts -> + case traceContextOptions acc of + Nothing -> go fs $ acc{traceContextOptions = Just opts} + Just _ -> fail "Multiple TraceOptions fields" + +{------------------------------------------------------------------------------- + Internal: dealing with 'TraceOptions' + + We take advantage of the fact that currently only a single option is defined. + Once we have more than one, this code will be a bit more complicated. +-------------------------------------------------------------------------------} + +traceOptionsToWord8 :: TraceOptions -> Word8 +traceOptionsToWord8 (TraceOptions False) = 0 +traceOptionsToWord8 (TraceOptions True) = 1 + +traceOptionsFromWord8 :: MonadFail m => Word8 -> m TraceOptions +traceOptionsFromWord8 0 = return $ TraceOptions False +traceOptionsFromWord8 1 = return $ TraceOptions True +traceOptionsFromWord8 n = fail $ "Invalid TraceOptions " ++ show n diff --git a/src/Network/GRPC/Spec/Status.hs b/src/Network/GRPC/Spec/Status.hs index 22f7cbe0..3e1ef882 100644 --- a/src/Network/GRPC/Spec/Status.hs +++ b/src/Network/GRPC/Spec/Status.hs @@ -2,8 +2,6 @@ module Network.GRPC.Spec.Status ( -- * GRPC status GrpcStatus(..) , GrpcError(..) - , fromGrpcStatus - , toGrpcStatus -- * Exceptions , GrpcException(..) , throwGrpcError @@ -176,45 +174,6 @@ data GrpcError = deriving stock (Show, Eq, Generic) deriving anyclass (Exception) -fromGrpcStatus :: GrpcStatus -> Word -fromGrpcStatus GrpcOk = 0 -fromGrpcStatus (GrpcError GrpcCancelled) = 1 -fromGrpcStatus (GrpcError GrpcUnknown) = 2 -fromGrpcStatus (GrpcError GrpcInvalidArgument) = 3 -fromGrpcStatus (GrpcError GrpcDeadlineExceeded) = 4 -fromGrpcStatus (GrpcError GrpcNotFound) = 5 -fromGrpcStatus (GrpcError GrpcAlreadyExists) = 6 -fromGrpcStatus (GrpcError GrpcPermissionDenied) = 7 -fromGrpcStatus (GrpcError GrpcResourceExhausted) = 8 -fromGrpcStatus (GrpcError GrpcFailedPrecondition) = 9 -fromGrpcStatus (GrpcError GrpcAborted) = 10 -fromGrpcStatus (GrpcError GrpcOutOfRange) = 11 -fromGrpcStatus (GrpcError GrpcUnimplemented) = 12 -fromGrpcStatus (GrpcError GrpcInternal) = 13 -fromGrpcStatus (GrpcError GrpcUnavailable) = 14 -fromGrpcStatus (GrpcError GrpcDataLoss) = 15 -fromGrpcStatus (GrpcError GrpcUnauthenticated) = 16 - -toGrpcStatus :: Word -> Maybe GrpcStatus -toGrpcStatus 0 = Just $ GrpcOk -toGrpcStatus 1 = Just $ GrpcError $ GrpcCancelled -toGrpcStatus 2 = Just $ GrpcError $ GrpcUnknown -toGrpcStatus 3 = Just $ GrpcError $ GrpcInvalidArgument -toGrpcStatus 4 = Just $ GrpcError $ GrpcDeadlineExceeded -toGrpcStatus 5 = Just $ GrpcError $ GrpcNotFound -toGrpcStatus 6 = Just $ GrpcError $ GrpcAlreadyExists -toGrpcStatus 7 = Just $ GrpcError $ GrpcPermissionDenied -toGrpcStatus 8 = Just $ GrpcError $ GrpcResourceExhausted -toGrpcStatus 9 = Just $ GrpcError $ GrpcFailedPrecondition -toGrpcStatus 10 = Just $ GrpcError $ GrpcAborted -toGrpcStatus 11 = Just $ GrpcError $ GrpcOutOfRange -toGrpcStatus 12 = Just $ GrpcError $ GrpcUnimplemented -toGrpcStatus 13 = Just $ GrpcError $ GrpcInternal -toGrpcStatus 14 = Just $ GrpcError $ GrpcUnavailable -toGrpcStatus 15 = Just $ GrpcError $ GrpcDataLoss -toGrpcStatus 16 = Just $ GrpcError $ GrpcUnauthenticated -toGrpcStatus _ = Nothing - {------------------------------------------------------------------------------- gRPC exceptions -------------------------------------------------------------------------------} diff --git a/src/Network/GRPC/Spec/Timeout.hs b/src/Network/GRPC/Spec/Timeout.hs index acd8c229..5d1248e6 100644 --- a/src/Network/GRPC/Spec/Timeout.hs +++ b/src/Network/GRPC/Spec/Timeout.hs @@ -1,5 +1,3 @@ -{-# LANGUAGE OverloadedStrings #-} - module Network.GRPC.Spec.Timeout ( -- * Timeouts Timeout(..) @@ -8,16 +6,8 @@ module Network.GRPC.Spec.Timeout ( , isValidTimeoutValue -- * Translation , timeoutToMicro - -- * Serialization - , buildTimeout - , parseTimeout ) where -import Control.Monad.Except -import Data.ByteString qualified as BS.Strict -import Data.ByteString qualified as Strict (ByteString) -import Data.ByteString.Char8 qualified as BS.Strict.C8 -import Data.Char (isDigit) import GHC.Generics (Generic) import GHC.Show @@ -92,65 +82,3 @@ timeoutToMicro = \case mu + if n' == 0 then 0 else 1 where (mu, n') = divMod n 1_000 - -{------------------------------------------------------------------------------- - Serialization - - > Timeout → "grpc-timeout" TimeoutValue TimeoutUnit - > TimeoutValue → {positive integer as ASCII string of at most 8 digits} - > TimeoutUnit → Hour / Minute / Second / Millisecond / Microsecond / Nanosecond - > Hour → "H" - > Minute → "M" - > Second → "S" - > Millisecond → "m" - > Microsecond → "u" - > Nanosecond → "n" --------------------------------------------------------------------------------} - -buildTimeout :: Timeout -> Strict.ByteString -buildTimeout (Timeout unit val) = mconcat [ - BS.Strict.C8.pack $ show $ getTimeoutValue val - , case unit of - Hour -> "H" - Minute -> "M" - Second -> "S" - Millisecond -> "m" - Microsecond -> "u" - Nanosecond -> "n" - ] - -parseTimeout :: forall m. - MonadError String m - => Strict.ByteString -> m Timeout -parseTimeout bs = do - let (bsVal, bsUnit) = BS.Strict.C8.span isDigit bs - - val <- - if BS.Strict.length bsVal < 1 || BS.Strict.length bsVal > 8 - then invalid - else return . TimeoutValue $ read (BS.Strict.C8.unpack bsVal) - - charUnit <- - case BS.Strict.C8.uncons bsUnit of - Nothing -> - invalid - Just (u, remainder) -> - if BS.Strict.null remainder - then return u - else invalid - - unit <- - case charUnit of - 'H' -> return Hour - 'M' -> return Minute - 'S' -> return Second - 'm' -> return Millisecond - 'u' -> return Microsecond - 'n' -> return Nanosecond - _ -> invalid - - return $ Timeout unit val - where - invalid :: m a - invalid = throwError $ "Could not parse timeout " ++ show bs - diff --git a/src/Network/GRPC/Spec/TraceContext.hs b/src/Network/GRPC/Spec/TraceContext.hs index e3978fe8..8ffae7fe 100644 --- a/src/Network/GRPC/Spec/TraceContext.hs +++ b/src/Network/GRPC/Spec/TraceContext.hs @@ -7,25 +7,13 @@ module Network.GRPC.Spec.TraceContext ( , TraceId(..) , SpanId(..) , TraceOptions(..) - -- ** Serialization - , buildTraceContext - , parseTraceContext ) where -import Control.Applicative (many) -import Control.Monad.Except -import Data.Binary (Binary(..)) -import Data.Binary qualified as Binary -import Data.Binary.Get qualified as Get -import Data.Binary.Put qualified as Put import Data.ByteString qualified as Strict (ByteString) import Data.ByteString.Base16 qualified as BS.Strict.Base16 import Data.ByteString.Char8 qualified as BS.Strict.Char8 -import Data.ByteString.Lazy qualified as BS.Lazy import Data.Default -import Data.Maybe (maybeToList) import Data.String -import Data.Word import GHC.Generics (Generic) {------------------------------------------------------------------------------- @@ -139,110 +127,3 @@ instance IsString SpanId where Left err -> error err Right tid -> SpanId tid -{------------------------------------------------------------------------------- - Parsing --------------------------------------------------------------------------------} - -buildTraceContext :: TraceContext -> Strict.ByteString -buildTraceContext = BS.Lazy.toStrict . Binary.encode - -parseTraceContext :: MonadError String m => Strict.ByteString -> m TraceContext -parseTraceContext bs = - case Binary.decodeOrFail (BS.Lazy.fromStrict bs) of - Right (_, _, ctxt) -> return ctxt - Left (_, _, err) -> throwError err - -{------------------------------------------------------------------------------- - Internal auxiliary: parsing - - --------------------------------------------------------------------------------} - -instance Binary TraceId where - put = Put.putByteString . getTraceId - get = TraceId <$> Get.getByteString 16 - -instance Binary SpanId where - put = Put.putByteString . getSpanId - get = SpanId <$> Get.getByteString 8 - -instance Binary TraceOptions where - put = Put.putWord8 . traceOptionsToWord8 - get = traceOptionsFromWord8 =<< Get.getWord8 - -instance Binary Field where - put (FieldTraceId tid) = Put.putWord8 0 <> put tid - put (FieldSpanId sid) = Put.putWord8 1 <> put sid - put (FieldOptions opts) = Put.putWord8 2 <> put opts - - get = do - fieldId <- Get.getWord8 - case fieldId of - 0 -> FieldTraceId <$> get - 1 -> FieldSpanId <$> get - 2 -> FieldOptions <$> get - _ -> fail $ "Invalid fieldId " ++ show fieldId - -instance Binary TraceContext where - put ctxt = mconcat [ - Put.putWord8 0 -- Version 0 - , foldMap put (traceContextToFields ctxt) - ] - - get = do - version <- Get.getWord8 - case version of - 0 -> traceContextFromFields =<< many get - _ -> fail $ "Invalid version " ++ show version - -{------------------------------------------------------------------------------- - Internal: fields --------------------------------------------------------------------------------} - -data Field = - FieldTraceId TraceId - | FieldSpanId SpanId - | FieldOptions TraceOptions - -traceContextToFields :: TraceContext -> [Field] -traceContextToFields (TraceContext tid sid opts) = concat [ - FieldTraceId <$> maybeToList tid - , FieldSpanId <$> maybeToList sid - , FieldOptions <$> maybeToList opts - ] - -traceContextFromFields :: forall m. MonadFail m => [Field] -> m TraceContext -traceContextFromFields = flip go def - where - go :: [Field] -> TraceContext -> m TraceContext - go [] acc = return acc - go (f:fs) acc = - case f of - FieldTraceId tid -> - case traceContextTraceId acc of - Nothing -> go fs $ acc{traceContextTraceId = Just tid} - Just _ -> fail "Multiple TraceId fields" - FieldSpanId sid -> - case traceContextSpanId acc of - Nothing -> go fs $ acc{traceContextSpanId = Just sid} - Just _ -> fail "Multiple SpanId fields" - FieldOptions opts -> - case traceContextOptions acc of - Nothing -> go fs $ acc{traceContextOptions = Just opts} - Just _ -> fail "Multiple TraceOptions fields" - -{------------------------------------------------------------------------------- - Internal: dealing with 'TraceOptions' - - We take advantage of the fact that currently only a single option is defined. - Once we have more than one, this code will be a bit more complicated. --------------------------------------------------------------------------------} - -traceOptionsToWord8 :: TraceOptions -> Word8 -traceOptionsToWord8 (TraceOptions False) = 0 -traceOptionsToWord8 (TraceOptions True) = 1 - -traceOptionsFromWord8 :: MonadFail m => Word8 -> m TraceOptions -traceOptionsFromWord8 0 = return $ TraceOptions False -traceOptionsFromWord8 1 = return $ TraceOptions True -traceOptionsFromWord8 n = fail $ "Invalid TraceOptions " ++ show n diff --git a/test-grapesy/Test/Driver/ClientServer.hs b/test-grapesy/Test/Driver/ClientServer.hs index ffa1cd87..319192be 100644 --- a/test-grapesy/Test/Driver/ClientServer.hs +++ b/test-grapesy/Test/Driver/ClientServer.hs @@ -40,7 +40,6 @@ import Network.GRPC.Common import Network.GRPC.Common.Compression qualified as Compr import Network.GRPC.Server qualified as Server import Network.GRPC.Server.Run qualified as Server -import Network.GRPC.Spec import Paths_grapesy @@ -109,10 +108,10 @@ data ContentTypeOverride = -- -- It is the responsibility of the test to make sure that this content-type -- is in fact valid. - | ValidOverride ContentType + | ValidOverride Server.ContentType -- | Override with an invalid content-type - | InvalidOverride ContentType + | InvalidOverride Server.ContentType instance Default ClientServerConfig where def = ClientServerConfig { diff --git a/test-grapesy/Test/Prop/Serialization.hs b/test-grapesy/Test/Prop/Serialization.hs index d57c1a80..aad61cf8 100644 --- a/test-grapesy/Test/Prop/Serialization.hs +++ b/test-grapesy/Test/Prop/Serialization.hs @@ -33,6 +33,7 @@ import Network.GRPC.Common import Network.GRPC.Common.Compression qualified as Compr import Network.GRPC.Common.Protobuf import Network.GRPC.Spec +import Network.GRPC.Spec.Serialization import Test.Util.Awkward import Test.Util.Orphans () diff --git a/test-grapesy/Test/Sanity/Interop.hs b/test-grapesy/Test/Sanity/Interop.hs index c94fe9d2..17949426 100644 --- a/test-grapesy/Test/Sanity/Interop.hs +++ b/test-grapesy/Test/Sanity/Interop.hs @@ -17,12 +17,13 @@ import Network.GRPC.Client qualified as Client import Network.GRPC.Client.Binary qualified as Client.Binary import Network.GRPC.Client.StreamType.IO qualified as Client import Network.GRPC.Common +import Network.GRPC.Common.Binary (RawRpc) import Network.GRPC.Common.Protobuf import Network.GRPC.Common.StreamElem qualified as StreamElem import Network.GRPC.Server qualified as Server import Network.GRPC.Server.Binary qualified as Server.Binary +import Network.GRPC.Server.StreamType (ServerHandler'(..)) import Network.GRPC.Server.StreamType qualified as Server -import Network.GRPC.Spec import Proto.API.Interop import Proto.API.Ping @@ -108,10 +109,10 @@ test_emptyUnary = , client = simpleTestClient $ \conn -> Client.withRPC conn def (Proxy @EmptyCall) $ \call -> do Client.sendFinalInput call defMessage - streamElem <- Client.recvOutputWithEnvelope call + streamElem <- Client.recvOutputWithMeta call case StreamElem.value streamElem of - Nothing -> fail "Expected answer" - Just (envelope, _x) -> verifyEnvelope envelope + Nothing -> fail "Expected answer" + Just (meta, _x) -> verifyMeta meta , server = [ Server.fromMethod @EmptyCall $ ServerHandler $ \_empty -> return defMessage @@ -121,10 +122,10 @@ test_emptyUnary = -- We don't /expect/ the empty message to be compressed, due to the overhead -- mentioned above. However, /if/ it is compressed, perhaps using a custom -- zero-overhead compression algorithm, it's size should be zero. - verifyEnvelope :: InboundEnvelope -> IO () - verifyEnvelope envelope = do - assertEqual "uncompressed size" (inboundUncompressedSize envelope) 0 - case inboundCompressedSize envelope of + verifyMeta :: InboundMeta -> IO () + verifyMeta meta = do + assertEqual "uncompressed size" (inboundUncompressedSize meta) 0 + case inboundCompressedSize meta of Nothing -> return () Just size -> assertEqual "compressed size" size 0 @@ -149,8 +150,8 @@ test_serverCompressedStreaming = & #compressed .~ (defMessage & #value .~ False) & #size .~ 92653 ] - output1 <- Client.recvOutputWithEnvelope call - output2 <- Client.recvOutputWithEnvelope call + output1 <- Client.recvOutputWithMeta call + output2 <- Client.recvOutputWithMeta call verifyOutputs (StreamElem.value output1, StreamElem.value output2) , server = [ Server.someRpcHandler $ @@ -173,8 +174,8 @@ test_serverCompressedStreaming = size :: Int size = fromIntegral $ responseParams ^. #size - envelope :: OutboundEnvelope - envelope = def { outboundEnableCompression = shouldCompress } + meta :: OutboundMeta + meta = def { outboundEnableCompression = shouldCompress } -- Payload matters for the test, because for messages that are too -- small no compression is used even when enabled. @@ -184,22 +185,22 @@ test_serverCompressedStreaming = response :: Proto StreamingOutputCallResponse response = defMessage & #payload .~ payload - Server.sendOutputWithEnvelope call $ StreamElem (envelope, response) + Server.sendOutputWithMeta call $ StreamElem (meta, response) -- No further output Server.sendTrailers call def verifyOutputs :: - ( Maybe (InboundEnvelope, Proto StreamingOutputCallResponse) - , Maybe (InboundEnvelope, Proto StreamingOutputCallResponse) + ( Maybe (InboundMeta, Proto StreamingOutputCallResponse) + , Maybe (InboundMeta, Proto StreamingOutputCallResponse) ) -> IO () verifyOutputs = \case - (Just (envelope1, _), Just (envelope2, _)) -> do - case inboundCompressedSize envelope1 of + (Just (meta1, _), Just (meta2, _)) -> do + case inboundCompressedSize meta1 of Nothing -> assertFailure "First output should be compressed" Just _ -> return () - case inboundCompressedSize envelope2 of + case inboundCompressedSize meta2 of Nothing -> return () Just _ -> assertFailure "First output should not be compressed" _otherwise -> diff --git a/test-grapesy/Test/Sanity/StreamingType/NonStreaming.hs b/test-grapesy/Test/Sanity/StreamingType/NonStreaming.hs index 89f826f1..049f4aed 100644 --- a/test-grapesy/Test/Sanity/StreamingType/NonStreaming.hs +++ b/test-grapesy/Test/Sanity/StreamingType/NonStreaming.hs @@ -13,9 +13,9 @@ import Network.GRPC.Client.Binary qualified as Binary import Network.GRPC.Common import Network.GRPC.Common.Binary (RawRpc) import Network.GRPC.Common.Compression qualified as Compr +import Network.GRPC.Server (ContentType(..)) import Network.GRPC.Server.StreamType qualified as Server import Network.GRPC.Server.StreamType.Binary qualified as Binary -import Network.GRPC.Spec (ContentType(ContentTypeOverride)) import Test.Driver.ClientServer From 7c6fb3416b038c62dbe5277037cbc3b0fa3b2713 Mon Sep 17 00:00:00 2001 From: Edsko de Vries Date: Wed, 24 Jul 2024 17:36:37 +0200 Subject: [PATCH 06/10] Remove duplication We were processing the request headers _twice_. --- src/Network/GRPC/Client/Session.hs | 8 +- src/Network/GRPC/Server/Call.hs | 117 +++++++++++++++------- src/Network/GRPC/Server/RequestHandler.hs | 31 +----- 3 files changed, 91 insertions(+), 65 deletions(-) diff --git a/src/Network/GRPC/Client/Session.hs b/src/Network/GRPC/Client/Session.hs index 707fac44..fda1eae8 100644 --- a/src/Network/GRPC/Client/Session.hs +++ b/src/Network/GRPC/Client/Session.hs @@ -122,6 +122,10 @@ instance NoTrailers (ClientSession rpc) where Process response headers -------------------------------------------------------------------------------} +-- | Required response headers +-- +-- If any of these headers are missing or invalid, we throw an exception, +-- independent of 'connVerifyHeaders' data RequiredHeaders = RequiredHeaders { requiredCompression :: Maybe CompressionId } @@ -132,8 +136,8 @@ validateAll = fmap go . HKD.sequence where go :: ResponseHeaders -> RequiredHeaders go responseHeaders = RequiredHeaders { - requiredCompression = responseCompression responseHeaders - } + requiredCompression = responseCompression responseHeaders + } -- | Validate only the required headers validateRequired :: ResponseHeaders' -> Either InvalidHeaders RequiredHeaders diff --git a/src/Network/GRPC/Server/Call.hs b/src/Network/GRPC/Server/Call.hs index 48cbd813..082173d9 100644 --- a/src/Network/GRPC/Server/Call.hs +++ b/src/Network/GRPC/Server/Call.hs @@ -60,6 +60,7 @@ import Network.GRPC.Server.Context import Network.GRPC.Server.Session import Network.GRPC.Spec import Network.GRPC.Spec.Serialization +import Network.GRPC.Util.HKD qualified as HKD import Network.GRPC.Util.HTTP2 (fromHeaderTable) import Network.GRPC.Util.Session qualified as Session import Network.GRPC.Util.Session.Server qualified as Server @@ -134,12 +135,12 @@ setupCall :: forall rpc. SupportsServerRpc rpc => Server.ConnectionToClient -> ServerContext - -> IO (Call rpc) + -> IO (Call rpc, Maybe Timeout) setupCall conn callContext@ServerContext{serverParams} = do callResponseMetadata <- newTVarIO Nothing callResponseKickoff <- newEmptyTMVarIO - inboundHeaders <- determineInbound callSession req + (inboundHeaders, timeout) <- determineInbound callSession req let callRequestHeaders = inbHeaders inboundHeaders -- Technically compression is only relevant in the 'KickoffRegular' case @@ -162,14 +163,17 @@ setupCall conn callContext@ServerContext{serverParams} = do cOut ) - return Call{ - callContext - , callSession - , callRequestHeaders - , callResponseMetadata - , callResponseKickoff - , callChannel - } + return ( + Call{ + callContext + , callSession + , callRequestHeaders + , callResponseMetadata + , callResponseKickoff + , callChannel + } + , timeout + ) where callSession :: ServerSession rpc callSession = ServerSession { @@ -184,13 +188,16 @@ determineInbound :: forall rpc. SupportsServerRpc rpc => ServerSession rpc -> HTTP2.Request - -> IO (Headers (ServerInbound rpc)) + -> IO (Headers (ServerInbound rpc), Maybe Timeout) determineInbound session req = do - cIn <- getInboundCompression session $ requestCompression requestHeaders' - return InboundHeaders { - inbHeaders = requestHeaders' - , inbCompression = cIn - } + (cIn, timeout) <- processRequestHeaders session requestHeaders' + return ( + InboundHeaders { + inbHeaders = requestHeaders' + , inbCompression = cIn + } + , timeout + ) where requestHeaders' :: RequestHeaders' requestHeaders' = parseRequestHeaders' (Proxy @rpc) $ @@ -260,24 +267,6 @@ startOutbound serverParams metadataVar kickoffVar cOut = do , responseBody = Nothing } --- | Determine compression used for messages from the peer -getInboundCompression :: - ServerSession rpc - -> Either InvalidRequestHeaders (Maybe CompressionId) - -> IO Compression -getInboundCompression session = \case - Left err -> throwIO $ CallSetupInvalidRequestHeaders err - Right Nothing -> return noCompression - Right (Just cid) -> - case Compr.getSupported serverCompression cid of - Just compr -> return compr - Nothing -> throwIO $ CallSetupUnsupportedCompression cid - where - ServerSession{serverSessionContext} = session - ServerContext{serverParams} = serverSessionContext - ServerParams{serverCompression} = serverParams - - -- | Determine compression to be used for messages to the peer -- -- In the case that we fail to parse the @grpc-accept-encoding@ header, we @@ -304,6 +293,66 @@ serverExceptionToClientError params err mMsg <- serverExceptionToClient params err return $ simpleProperTrailers (GrpcError GrpcUnknown) mMsg mempty +{------------------------------------------------------------------------------- + Process request headers +-------------------------------------------------------------------------------} + +-- | Required request headers +-- +-- If any of these headers are missing or invalid, we throw an exception, +-- independent of 'serverVerifyHeaders' +data RequiredHeaders = RequiredHeaders { + requiredCompression :: Maybe CompressionId + , requiredTimeout :: Maybe Timeout + } + +-- | Validate /all/ headers, and then extract the required +validateAll :: + RequestHeaders' + -> Either InvalidRequestHeaders RequiredHeaders +validateAll = fmap go . HKD.sequence + where + go :: RequestHeaders -> RequiredHeaders + go requestHeaders = RequiredHeaders { + requiredCompression = requestCompression requestHeaders + , requiredTimeout = requestTimeout requestHeaders + } + +-- | Validate only the required headers +validateRequired :: + RequestHeaders' + -> Either InvalidRequestHeaders RequiredHeaders +validateRequired requestHeaders' = + RequiredHeaders + <$> requestCompression requestHeaders' + <*> requestTimeout requestHeaders' + +processRequestHeaders :: + ServerSession rpc + -> RequestHeaders' + -> IO (Compression, Maybe Timeout) +processRequestHeaders session requestHeaders' = do + required <- either invalid return $ + if serverVerifyHeaders serverParams + then validateAll requestHeaders' + else validateRequired requestHeaders' + cIn <- getCompression (requiredCompression required) + return (cIn, requiredTimeout required) + where + ServerSession{serverSessionContext} = session + ServerContext{serverParams} = serverSessionContext + + -- this replaces getInboundCompression + getCompression :: Maybe CompressionId -> IO Compression + getCompression Nothing = return noCompression + getCompression (Just cid) = + case Compr.getSupported (serverCompression serverParams) cid of + Just compr -> return compr + Nothing -> throwIO $ CallSetupUnsupportedCompression cid + + invalid :: forall x. InvalidRequestHeaders -> IO x + invalid = throwIO . CallSetupInvalidRequestHeaders + {------------------------------------------------------------------------------- Open (ongoing) call -------------------------------------------------------------------------------} diff --git a/src/Network/GRPC/Server/RequestHandler.hs b/src/Network/GRPC/Server/RequestHandler.hs index 17131c0e..3dc007c0 100644 --- a/src/Network/GRPC/Server/RequestHandler.hs +++ b/src/Network/GRPC/Server/RequestHandler.hs @@ -27,7 +27,7 @@ import Network.HTTP.Types qualified as HTTP import Network.HTTP2.Server qualified as HTTP2 import Network.GRPC.Server.Call -import Network.GRPC.Server.Context (ServerContext (..), ServerParams (..)) +import Network.GRPC.Server.Context (ServerContext (..)) import Network.GRPC.Server.Handler import Network.GRPC.Server.HandlerMap (HandlerMap) import Network.GRPC.Server.HandlerMap qualified as HandlerMap @@ -36,7 +36,6 @@ import Network.GRPC.Server.Session (CallSetupFailure(..)) import Network.GRPC.Spec import Network.GRPC.Spec.Serialization import Network.GRPC.Util.GHC -import Network.GRPC.Util.HKD qualified as HKD import Network.GRPC.Util.Session.Server {------------------------------------------------------------------------------- @@ -50,10 +49,8 @@ requestHandler handlers ctxt unmask request respond = do SomeRpcHandler (_ :: Proxy rpc) handler <- findHandler handlers request `catch` setupFailure respond - call :: Call rpc <- + (call :: Call rpc, mTimeout :: Maybe Timeout) <- setupCall connectionToClient ctxt `catch` setupFailure respond - mTimeout :: Maybe Timeout <- - processRequestHeaders ctxt call `catch` setupFailure respond imposeTimeout mTimeout $ runHandler unmask call handler @@ -103,30 +100,6 @@ findHandler handlers req = do , rawMethod = fromMaybe "" $ HTTP2.requestMethod req } --- | Process request headers --- --- In strict mode we verify /all/ headers; otherwise, we only verify those --- headers we need to setup the call. --- --- Throws 'CallSetupFailure' if any (validated) headers were invalid. -processRequestHeaders :: - ServerContext - -> Call rpc - -> IO (Maybe Timeout) -processRequestHeaders ctxt call = do - requestHeaders' <- getRequestHeaders call - if serverVerifyHeaders then - case HKD.sequence requestHeaders' of - Left err -> throwM $ CallSetupInvalidRequestHeaders err - Right requestHeaders -> return $ requestTimeout requestHeaders - else - case requestTimeout requestHeaders' of - Left err -> throwM $ CallSetupInvalidRequestHeaders err - Right mTimeout -> return mTimeout - where - ServerContext{serverParams} = ctxt - ServerParams{serverVerifyHeaders} = serverParams - -- | Call setup failure -- -- Something went wrong during call setup. No response has been sent to the From fb05518fca1b4166cd3267b38ed0c9e02cbcf221 Mon Sep 17 00:00:00 2001 From: Edsko de Vries Date: Thu, 25 Jul 2024 12:17:10 +0200 Subject: [PATCH 07/10] Make request headers more uniform with the rest This paves the way for a next refactoring, where we provide some general infra for dealing with errors in headers. --- src/Network/GRPC/Server/Call.hs | 8 ++-- src/Network/GRPC/Server/RequestHandler.hs | 5 ++- src/Network/GRPC/Server/Session.hs | 2 +- src/Network/GRPC/Spec.hs | 7 +++- src/Network/GRPC/Spec/Headers/Invalid.hs | 41 ++++++++++++++++--- src/Network/GRPC/Spec/Headers/Request.hs | 9 +--- .../GRPC/Spec/Serialization/Headers/Common.hs | 26 ++++++------ .../Spec/Serialization/Headers/Request.hs | 32 +++++++-------- 8 files changed, 78 insertions(+), 52 deletions(-) diff --git a/src/Network/GRPC/Server/Call.hs b/src/Network/GRPC/Server/Call.hs index 082173d9..a99975c0 100644 --- a/src/Network/GRPC/Server/Call.hs +++ b/src/Network/GRPC/Server/Call.hs @@ -273,7 +273,7 @@ startOutbound serverParams metadataVar kickoffVar cOut = do -- simply use no compression. getOutboundCompression :: ServerSession rpc - -> Either InvalidRequestHeaders (Maybe (NonEmpty CompressionId)) + -> Either InvalidHeaders (Maybe (NonEmpty CompressionId)) -> Compression getOutboundCompression session = \case Left _invalidHeader -> noCompression @@ -309,7 +309,7 @@ data RequiredHeaders = RequiredHeaders { -- | Validate /all/ headers, and then extract the required validateAll :: RequestHeaders' - -> Either InvalidRequestHeaders RequiredHeaders + -> Either InvalidHeaders RequiredHeaders validateAll = fmap go . HKD.sequence where go :: RequestHeaders -> RequiredHeaders @@ -321,7 +321,7 @@ validateAll = fmap go . HKD.sequence -- | Validate only the required headers validateRequired :: RequestHeaders' - -> Either InvalidRequestHeaders RequiredHeaders + -> Either InvalidHeaders RequiredHeaders validateRequired requestHeaders' = RequiredHeaders <$> requestCompression requestHeaders' @@ -350,7 +350,7 @@ processRequestHeaders session requestHeaders' = do Just compr -> return compr Nothing -> throwIO $ CallSetupUnsupportedCompression cid - invalid :: forall x. InvalidRequestHeaders -> IO x + invalid :: forall x. InvalidHeaders -> IO x invalid = throwIO . CallSetupInvalidRequestHeaders {------------------------------------------------------------------------------- diff --git a/src/Network/GRPC/Server/RequestHandler.hs b/src/Network/GRPC/Server/RequestHandler.hs index 3dc007c0..01355666 100644 --- a/src/Network/GRPC/Server/RequestHandler.hs +++ b/src/Network/GRPC/Server/RequestHandler.hs @@ -147,8 +147,9 @@ failureResponse (CallSetupInvalidResourceHeaders (InvalidMethod method)) = failureResponse (CallSetupInvalidResourceHeaders (InvalidPath path)) = HTTP2.responseBuilder HTTP.badRequest400 [] . Builder.byteString $ "Invalid path " <> path -failureResponse (CallSetupInvalidRequestHeaders (status, invalid)) = - HTTP2.responseBuilder status [] $ prettyInvalidHeaders invalid +failureResponse (CallSetupInvalidRequestHeaders invalid) = + HTTP2.responseBuilder (statusInvalidHeaders invalid) [] $ + prettyInvalidHeaders invalid failureResponse (CallSetupUnsupportedCompression cid) = HTTP2.responseBuilder HTTP.badRequest400 [] . Builder.byteString $ "Unsupported compression: " <> BS.UTF8.fromString (show cid) diff --git a/src/Network/GRPC/Server/Session.hs b/src/Network/GRPC/Server/Session.hs index b1d386d5..a14be744 100644 --- a/src/Network/GRPC/Server/Session.hs +++ b/src/Network/GRPC/Server/Session.hs @@ -79,7 +79,7 @@ data CallSetupFailure = -- 'CallSetupInvalidResourceHeaders' refers to an invalid method (anything -- other than POST) or an invalid path; 'CallSetupInvalidRequestHeaders' -- means we could not parse the HTTP headers according to the gRPC spec. - | CallSetupInvalidRequestHeaders InvalidRequestHeaders + | CallSetupInvalidRequestHeaders InvalidHeaders -- | Client chose unsupported compression algorithm -- diff --git a/src/Network/GRPC/Spec.hs b/src/Network/GRPC/Spec.hs index 0975c0b2..26f34678 100644 --- a/src/Network/GRPC/Spec.hs +++ b/src/Network/GRPC/Spec.hs @@ -68,7 +68,6 @@ module Network.GRPC.Spec ( , RequestHeaders_(..) , RequestHeaders , RequestHeaders' - , InvalidRequestHeaders -- ** Parameters , CallParams(..) -- ** Pseudo-headers @@ -141,11 +140,15 @@ module Network.GRPC.Spec ( -- * Invalid headers , InvalidHeaders(..) , InvalidHeader(..) - , prettyInvalidHeaders + -- ** Construction , invalidHeader + , invalidHeaderWith , missingHeader , unexpectedHeader , throwInvalidHeader + -- ** Use + , prettyInvalidHeaders + , statusInvalidHeaders -- * Common infrastructure to all headers , ContentType(..) , MessageType(..) diff --git a/src/Network/GRPC/Spec/Headers/Invalid.hs b/src/Network/GRPC/Spec/Headers/Invalid.hs index 61f44208..d7760212 100644 --- a/src/Network/GRPC/Spec/Headers/Invalid.hs +++ b/src/Network/GRPC/Spec/Headers/Invalid.hs @@ -8,11 +8,13 @@ module Network.GRPC.Spec.Headers.Invalid ( , InvalidHeader(..) -- * Construction , invalidHeader + , invalidHeaderWith , missingHeader , unexpectedHeader , throwInvalidHeader -- * Utility , prettyInvalidHeaders + , statusInvalidHeaders ) where import Control.Monad.Except @@ -20,7 +22,9 @@ import Data.ByteString.Builder qualified as Builder import Data.ByteString.Builder qualified as ByteString (Builder) import Data.ByteString.UTF8 qualified as BS.UTF8 import Data.CaseInsensitive qualified as CI +import Data.Maybe (fromMaybe) import Network.HTTP.Types qualified as HTTP +import Control.Applicative {------------------------------------------------------------------------------- Definition @@ -50,7 +54,10 @@ data InvalidHeader = -- | We failed to parse this header -- -- We record the original header and the reason parsing failed. - InvalidHeader HTTP.Header String + -- + -- For some invalid headers the gRPC spec mandates a specific HTTP status; + -- if this status is not specified, then we use 400 Bad Request. + InvalidHeader (Maybe HTTP.Status) HTTP.Header String -- | Missing header (header that should have been present but was not) | MissingHeader HTTP.HeaderName @@ -64,13 +71,16 @@ data InvalidHeader = -------------------------------------------------------------------------------} invalidHeader :: HTTP.Header -> String -> InvalidHeaders -invalidHeader hdr err = InvalidHeaders [InvalidHeader hdr err] +invalidHeader hdr err = wrapOne $ InvalidHeader Nothing hdr err + +invalidHeaderWith :: HTTP.Status -> HTTP.Header -> String -> InvalidHeaders +invalidHeaderWith status hdr err = wrapOne $ InvalidHeader (Just status) hdr err missingHeader :: HTTP.HeaderName -> InvalidHeaders -missingHeader name = InvalidHeaders [MissingHeader name] +missingHeader name = wrapOne $ MissingHeader name unexpectedHeader :: HTTP.HeaderName -> InvalidHeaders -unexpectedHeader name = InvalidHeaders [UnexpectedHeader name] +unexpectedHeader name = wrapOne $ UnexpectedHeader name throwInvalidHeader :: MonadError InvalidHeaders m @@ -88,7 +98,7 @@ prettyInvalidHeaders :: InvalidHeaders -> ByteString.Builder prettyInvalidHeaders = mconcat . map go . getInvalidHeaders where go :: InvalidHeader -> ByteString.Builder - go (InvalidHeader (name, value) err) = mconcat [ + go (InvalidHeader _status (name, value) err) = mconcat [ "Invalid header '" , Builder.byteString (CI.original name) , "' with value '" @@ -107,3 +117,24 @@ prettyInvalidHeaders = mconcat . map go . getInvalidHeaders , Builder.byteString (CI.original name) , "'\n" ] + +-- | HTTP status to report +-- +-- If there are multiple headers, each of which with a mandated status, we +-- just use the first; the spec is essentially ambiguous in this case. +statusInvalidHeaders :: InvalidHeaders -> HTTP.Status +statusInvalidHeaders (InvalidHeaders hs) = + fromMaybe HTTP.badRequest400 $ asum $ map getStatus hs + where + getStatus :: InvalidHeader -> Maybe HTTP.Status + getStatus (InvalidHeader status _ _) = status + getStatus MissingHeader{} = Nothing + getStatus UnexpectedHeader{} = Nothing + +{------------------------------------------------------------------------------- + Internal auxiliary +-------------------------------------------------------------------------------} + +wrapOne :: InvalidHeader -> InvalidHeaders +wrapOne = InvalidHeaders . (:[]) + diff --git a/src/Network/GRPC/Spec/Headers/Request.hs b/src/Network/GRPC/Spec/Headers/Request.hs index e68c147b..12ea6391 100644 --- a/src/Network/GRPC/Spec/Headers/Request.hs +++ b/src/Network/GRPC/Spec/Headers/Request.hs @@ -8,13 +8,11 @@ module Network.GRPC.Spec.Headers.Request ( RequestHeaders_(..) , RequestHeaders , RequestHeaders' - , InvalidRequestHeaders ) where import Data.ByteString qualified as Strict (ByteString) import Data.List.NonEmpty (NonEmpty) import GHC.Generics (Generic) -import Network.HTTP.Types qualified as HTTP import Network.GRPC.Spec.Compression (CompressionId) import Network.GRPC.Spec.CustomMetadata.Map @@ -124,12 +122,7 @@ type RequestHeaders = RequestHeaders_ Undecorated -- -- (i.e., either valid or invalid). type RequestHeaders' = - RequestHeaders_ (DecoratedWith (Either InvalidRequestHeaders)) - --- | Invalid request headers --- --- For certain types of failures the gRPC spec mandates a specific HTTP status. -type InvalidRequestHeaders = (HTTP.Status, InvalidHeaders) + RequestHeaders_ (DecoratedWith (Either InvalidHeaders)) deriving stock instance Show RequestHeaders deriving stock instance Eq RequestHeaders diff --git a/src/Network/GRPC/Spec/Serialization/Headers/Common.hs b/src/Network/GRPC/Spec/Serialization/Headers/Common.hs index 8ee96c67..2a9b2a46 100644 --- a/src/Network/GRPC/Spec/Serialization/Headers/Common.hs +++ b/src/Network/GRPC/Spec/Serialization/Headers/Common.hs @@ -91,18 +91,20 @@ parseContentType proxy hdr@(_name, value) = do err "Invalid subtype." where err :: String -> m a - err reason = throwError $ invalidHeader hdr $ concat [ - reason - , " Expected \"" - , BS.Strict.C8.unpack $ - rpcContentType (Proxy @(UnknownRpc Nothing Nothing)) - , "\" or \"" - , BS.Strict.C8.unpack $ - rpcContentType proxy - , "\", with \"" - , "application/grpc+{other_format}" - , "\" also accepted." - ] + err reason = + throwError . invalidHeaderWith HTTP.unsupportedMediaType415 hdr $ + concat [ + reason + , " Expected \"" + , BS.Strict.C8.unpack $ + rpcContentType (Proxy @(UnknownRpc Nothing Nothing)) + , "\" or \"" + , BS.Strict.C8.unpack $ + rpcContentType proxy + , "\", with \"" + , "application/grpc+{other_format}" + , "\" also accepted." + ] {------------------------------------------------------------------------------- > Message-Type → "grpc-message-type" {type name for message schema} diff --git a/src/Network/GRPC/Spec/Serialization/Headers/Request.hs b/src/Network/GRPC/Spec/Serialization/Headers/Request.hs index f1a537f6..8e9b59fb 100644 --- a/src/Network/GRPC/Spec/Serialization/Headers/Request.hs +++ b/src/Network/GRPC/Spec/Serialization/Headers/Request.hs @@ -132,7 +132,7 @@ callDefinition proxy = \hdrs -> catMaybes [ -------------------------------------------------------------------------------} parseRequestHeaders :: forall rpc m. - (IsRPC rpc, MonadError InvalidRequestHeaders m) + (IsRPC rpc, MonadError InvalidHeaders m) => Proxy rpc -> [HTTP.Header] -> m RequestHeaders parseRequestHeaders proxy = HKD.sequenceThrow . parseRequestHeaders' proxy @@ -159,36 +159,33 @@ parseRequestHeaders' proxy = | name == "grpc-timeout" = modify $ \x -> x { requestTimeout = fmap Just $ - httpError hdr HTTP.badRequest400 $ + httpError hdr $ parseTimeout value } | name == "grpc-encoding" = modify $ \x -> x { requestCompression = fmap Just $ - first (HTTP.badRequest400,) $ - parseMessageEncoding hdr + parseMessageEncoding hdr } | name == "grpc-accept-encoding" = modify $ \x -> x { requestAcceptCompression = fmap Just $ - first (HTTP.badRequest400,) $ - parseMessageAcceptEncoding hdr + parseMessageAcceptEncoding hdr } | name == "grpc-trace-bin" = modify $ \x -> x { requestTraceContext = fmap Just $ - httpError hdr HTTP.badRequest400 $ + httpError hdr $ parseBinaryValue value >>= parseTraceContext } | name == "content-type" = modify $ \x -> x { requestContentType = fmap Just $ - first (HTTP.unsupportedMediaType415,) $ - parseContentType proxy hdr + parseContentType proxy hdr } | name == "grpc-message-type" @@ -200,15 +197,14 @@ parseRequestHeaders' proxy = | name == "te" = modify $ \x -> x { requestIncludeTE = do - first (HTTP.badRequest400,) $ - expectHeaderValue hdr ["trailers"] + expectHeaderValue hdr ["trailers"] return True } | name == "grpc-previous-rpc-attempts" = modify $ \x -> x { requestPreviousRpcAttempts = do - httpError hdr HTTP.badRequest400 $ + httpError hdr $ maybe (Left $ "grpc-previous-rpc-attempts: invalid " ++ show value) (Right . Just) @@ -221,8 +217,8 @@ parseRequestHeaders' proxy = Left invalid -> x { requestUnrecognized = Left $ case requestUnrecognized x of - Left (status, invalid') -> (status, invalid <> invalid') - Right () -> (HTTP.badRequest400, invalid) + Left invalid' -> invalid <> invalid' + Right () -> invalid } Right md -> x { requestMetadata = customMetadataMapInsert md $ requestMetadata x @@ -250,10 +246,10 @@ parseRequestHeaders' proxy = } httpError :: - MonadError InvalidRequestHeaders m' - => HTTP.Header -> HTTP.Status -> Either String a -> m' a - httpError _ _ (Right a) = return a - httpError hdr status (Left err) = throwError (status, invalidHeader hdr err) + MonadError InvalidHeaders m' + => HTTP.Header -> Either String a -> m' a + httpError _ (Right a) = return a + httpError hdr (Left err) = throwError $ invalidHeader hdr err {------------------------------------------------------------------------------- Internal auxiliary From ff5002b8f44b741315d22e42d52f4007a83d8713 Mon Sep 17 00:00:00 2001 From: Edsko de Vries Date: Thu, 25 Jul 2024 12:29:30 +0200 Subject: [PATCH 08/10] Minor cleanup of `HKD` --- src/Network/GRPC/Spec/Headers/Request.hs | 9 ++++---- src/Network/GRPC/Spec/Headers/Response.hs | 10 ++++----- .../Spec/Serialization/Headers/Request.hs | 2 +- .../Spec/Serialization/Headers/Response.hs | 6 ++--- util/Network/GRPC/Util/HKD.hs | 22 +++++++++++++------ 5 files changed, 28 insertions(+), 21 deletions(-) diff --git a/src/Network/GRPC/Spec/Headers/Request.hs b/src/Network/GRPC/Spec/Headers/Request.hs index 12ea6391..f823de5c 100644 --- a/src/Network/GRPC/Spec/Headers/Request.hs +++ b/src/Network/GRPC/Spec/Headers/Request.hs @@ -20,7 +20,7 @@ import Network.GRPC.Spec.Headers.Common import Network.GRPC.Spec.Headers.Invalid import Network.GRPC.Spec.Timeout import Network.GRPC.Spec.TraceContext -import Network.GRPC.Util.HKD (HKD, Undecorated, DecoratedWith) +import Network.GRPC.Util.HKD (HKD, Undecorated, Checked) import Network.GRPC.Util.HKD qualified as HKD {------------------------------------------------------------------------------- @@ -114,15 +114,14 @@ type RequestHeaders = RequestHeaders_ Undecorated -- -- NOTE: The HKD type -- --- > RequestHeaders_ (DecoratedWith (Either InvalidRequestHeaders)) +-- > RequestHeaders_ (Checked InvalidHeaders) -- -- means that each field of type @HKD f a@ is of type -- --- > Either InvalidRequestHeaders a +-- > Either InvalidHeaders a -- -- (i.e., either valid or invalid). -type RequestHeaders' = - RequestHeaders_ (DecoratedWith (Either InvalidHeaders)) +type RequestHeaders' = RequestHeaders_ (Checked InvalidHeaders) deriving stock instance Show RequestHeaders deriving stock instance Eq RequestHeaders diff --git a/src/Network/GRPC/Spec/Headers/Response.hs b/src/Network/GRPC/Spec/Headers/Response.hs index bc0020b1..0c147d97 100644 --- a/src/Network/GRPC/Spec/Headers/Response.hs +++ b/src/Network/GRPC/Spec/Headers/Response.hs @@ -39,7 +39,7 @@ import Network.GRPC.Spec.Headers.Common import Network.GRPC.Spec.Headers.Invalid import Network.GRPC.Spec.OrcaLoadReport import Network.GRPC.Spec.Status -import Network.GRPC.Util.HKD (HKD, Undecorated, DecoratedWith) +import Network.GRPC.Util.HKD (HKD, Undecorated, Checked) import Network.GRPC.Util.HKD qualified as HKD {------------------------------------------------------------------------------- @@ -77,8 +77,8 @@ type ResponseHeaders = ResponseHeaders_ Undecorated -- | Response headers allowing for invalid headers -- --- See 'RequestHeaders'' for an explanation of @DecoratedWith@. -type ResponseHeaders' = ResponseHeaders_ (DecoratedWith (Either InvalidHeaders)) +-- See 'RequestHeaders'' for an explanation of @Checked@. +type ResponseHeaders' = ResponseHeaders_ (Checked InvalidHeaders) deriving stock instance Show ResponseHeaders deriving stock instance Eq ResponseHeaders @@ -153,7 +153,7 @@ simpleProperTrailers status msg metadata = ProperTrailers { type ProperTrailers = ProperTrailers_ Undecorated -- | Trailers sent after the response, allowing for invalid trailers -type ProperTrailers' = ProperTrailers_ (DecoratedWith (Either InvalidHeaders)) +type ProperTrailers' = ProperTrailers_ (Checked InvalidHeaders) deriving stock instance Show ProperTrailers deriving stock instance Eq ProperTrailers @@ -190,7 +190,7 @@ data TrailersOnly_ f = TrailersOnly { type TrailersOnly = TrailersOnly_ Undecorated -- | Trailers for the Trailers-Only case, allowing for invalid headers -type TrailersOnly' = TrailersOnly_ (DecoratedWith (Either InvalidHeaders)) +type TrailersOnly' = TrailersOnly_ (Checked InvalidHeaders) deriving stock instance Show TrailersOnly deriving stock instance Eq TrailersOnly diff --git a/src/Network/GRPC/Spec/Serialization/Headers/Request.hs b/src/Network/GRPC/Spec/Serialization/Headers/Request.hs index 8e9b59fb..06e0c990 100644 --- a/src/Network/GRPC/Spec/Serialization/Headers/Request.hs +++ b/src/Network/GRPC/Spec/Serialization/Headers/Request.hs @@ -135,7 +135,7 @@ parseRequestHeaders :: forall rpc m. (IsRPC rpc, MonadError InvalidHeaders m) => Proxy rpc -> [HTTP.Header] -> m RequestHeaders -parseRequestHeaders proxy = HKD.sequenceThrow . parseRequestHeaders' proxy +parseRequestHeaders proxy = HKD.sequenceChecked . parseRequestHeaders' proxy -- | Parse request headers -- diff --git a/src/Network/GRPC/Spec/Serialization/Headers/Response.hs b/src/Network/GRPC/Spec/Serialization/Headers/Response.hs index a220bca9..960b7f5a 100644 --- a/src/Network/GRPC/Spec/Serialization/Headers/Response.hs +++ b/src/Network/GRPC/Spec/Serialization/Headers/Response.hs @@ -212,7 +212,7 @@ buildResponseHeaders proxy parseResponseHeaders :: forall rpc m. (IsRPC rpc, MonadError InvalidHeaders m) => Proxy rpc -> [HTTP.Header] -> m ResponseHeaders -parseResponseHeaders proxy = HKD.sequenceThrow . parseResponseHeaders' proxy +parseResponseHeaders proxy = HKD.sequenceChecked . parseResponseHeaders' proxy parseResponseHeaders' :: forall rpc. IsRPC rpc @@ -361,7 +361,7 @@ buildTrailersOnly proxy TrailersOnly{ parseProperTrailers :: forall rpc m. (IsRPC rpc, MonadError InvalidHeaders m) => Proxy rpc -> [HTTP.Header] -> m ProperTrailers -parseProperTrailers proxy = HKD.sequenceThrow . parseProperTrailers' proxy +parseProperTrailers proxy = HKD.sequenceChecked . parseProperTrailers' proxy parseProperTrailers' :: forall rpc. IsRPC rpc @@ -393,7 +393,7 @@ parseProperTrailers' proxy hdrs = parseTrailersOnly :: forall m rpc. (IsRPC rpc, MonadError InvalidHeaders m) => Proxy rpc -> [HTTP.Header] -> m TrailersOnly -parseTrailersOnly proxy = HKD.sequenceThrow . parseTrailersOnly' proxy +parseTrailersOnly proxy = HKD.sequenceChecked . parseTrailersOnly' proxy parseTrailersOnly' :: forall rpc. IsRPC rpc diff --git a/util/Network/GRPC/Util/HKD.hs b/util/Network/GRPC/Util/HKD.hs index 3ea903a1..bbfadfe9 100644 --- a/util/Network/GRPC/Util/HKD.hs +++ b/util/Network/GRPC/Util/HKD.hs @@ -13,10 +13,12 @@ module Network.GRPC.Util.HKD ( , Coerce(..) , Traversable(..) , sequence - , sequenceThrow -- * Dealing with HKD fields , ValidDecoration , pure + -- * Error decorations + , Checked + , sequenceChecked ) where import Prelude hiding (Traversable(..), pure) @@ -78,12 +80,6 @@ sequence :: => t (DecoratedWith m) -> m (t Undecorated) sequence = fmap undecorate . traverse (fmap Identity) -sequenceThrow :: - (MonadError e m, Traversable t) - => t (DecoratedWith (Either e)) - -> m (t Undecorated) -sequenceThrow = either throwError return . sequence - {------------------------------------------------------------------------------- Dealing with HKD fields -------------------------------------------------------------------------------} @@ -119,3 +115,15 @@ pure _ = case validDecoration :: IsValidDecoration Applicative f of ValidDecoratedWith -> Prelude.pure ValidUndecorated -> id + +{------------------------------------------------------------------------------- + Error decorations +-------------------------------------------------------------------------------} + +type Checked e = DecoratedWith (Either e) + +sequenceChecked :: + (MonadError e m, Traversable t) + => t (Checked e) -> m (t Undecorated) +sequenceChecked = either throwError return . sequence + From 8b393fc6c5ec03bd78dc6b2646b3257a139077e2 Mon Sep 17 00:00:00 2001 From: Edsko de Vries Date: Thu, 25 Jul 2024 13:08:26 +0200 Subject: [PATCH 09/10] Uniform treatment of header validation --- grapesy.cabal | 3 +- src/Network/GRPC/Client/Session.hs | 91 ++++++++----------------- src/Network/GRPC/Common/Headers.hs | 104 +++++++++++++++++++++++++++++ src/Network/GRPC/Server/Call.hs | 101 +++++++++------------------- util/Network/GRPC/Util/HKD.hs | 16 ++++- 5 files changed, 179 insertions(+), 136 deletions(-) create mode 100644 src/Network/GRPC/Common/Headers.hs diff --git a/grapesy.cabal b/grapesy.cabal index 2ba2b4bf..cc5c873b 100644 --- a/grapesy.cabal +++ b/grapesy.cabal @@ -106,12 +106,13 @@ library Network.GRPC.Common Network.GRPC.Common.Binary Network.GRPC.Common.Compression + Network.GRPC.Common.Headers + Network.GRPC.Common.HTTP2Settings Network.GRPC.Common.JSON Network.GRPC.Common.NextElem Network.GRPC.Common.Protobuf Network.GRPC.Common.StreamElem Network.GRPC.Common.StreamType - Network.GRPC.Common.HTTP2Settings Network.GRPC.Server Network.GRPC.Server.Binary Network.GRPC.Server.Protobuf diff --git a/src/Network/GRPC/Client/Session.hs b/src/Network/GRPC/Client/Session.hs index fda1eae8..cacebb2d 100644 --- a/src/Network/GRPC/Client/Session.hs +++ b/src/Network/GRPC/Client/Session.hs @@ -19,10 +19,10 @@ import Network.GRPC.Client.Connection (Connection, ConnParams(..)) import Network.GRPC.Client.Connection qualified as Connection import Network.GRPC.Common import Network.GRPC.Common.Compression qualified as Compr +import Network.GRPC.Common.Headers import Network.GRPC.Spec import Network.GRPC.Spec.Serialization import Network.GRPC.Util.Session -import Network.GRPC.Util.HKD qualified as HKD {------------------------------------------------------------------------------- Definition @@ -85,20 +85,37 @@ instance SupportsClientRpc rpc => IsSession (ClientSession rpc) where buildMsg _ = buildInput (Proxy @rpc) . outCompression instance SupportsClientRpc rpc => InitiateSession (ClientSession rpc) where - parseResponse session (ResponseInfo status headers body) = + parseResponse (ClientSession conn) (ResponseInfo status headers body) = case classifyServerResponse (Proxy @rpc) status headers body of Left trailersOnly -> -- We classify the response as Trailers-Only if the grpc-status header -- is present, or when the HTTP status is anything other than 200 OK -- (which we treat, as per the spec, as an implicit grpc-status). -- The 'CallClosedWithoutTrailers' case is therefore not relevant. - return $ FlowStartNoMessages trailersOnly - Right responseHeaders' -> do - cIn <- processResponseHeaders session responseHeaders' - return $ FlowStartRegular $ InboundHeaders { - inbHeaders = responseHeaders' - , inbCompression = cIn - } + case verifyAllIf connVerifyHeaders trailersOnly of + Left err -> throwIO $ CallSetupInvalidResponseHeaders err + Right _hdrs -> return $ FlowStartNoMessages trailersOnly + Right responseHeaders -> do + case verifyAllIf connVerifyHeaders responseHeaders of + Left err -> throwIO $ CallSetupInvalidResponseHeaders err + Right hdrs -> do + cIn <- getCompression $ requiredResponseCompression hdrs + return $ FlowStartRegular $ InboundHeaders { + inbHeaders = responseHeaders + , inbCompression = cIn + } + where + ConnParams{ + connCompression + , connVerifyHeaders + } = Connection.connParams conn + + getCompression :: Maybe CompressionId -> IO Compression + getCompression Nothing = return noCompression + getCompression (Just cid) = + case Compr.getSupported connCompression cid of + Just compr -> return compr + Nothing -> throwIO $ CallSetupUnsupportedCompression cid buildRequestInfo _ start = RequestInfo { requestMethod = rawMethod resourceHeaders @@ -118,62 +135,6 @@ instance SupportsClientRpc rpc => InitiateSession (ClientSession rpc) where instance NoTrailers (ClientSession rpc) where noTrailers _ = NoMetadata -{------------------------------------------------------------------------------- - Process response headers --------------------------------------------------------------------------------} - --- | Required response headers --- --- If any of these headers are missing or invalid, we throw an exception, --- independent of 'connVerifyHeaders' -data RequiredHeaders = RequiredHeaders { - requiredCompression :: Maybe CompressionId - } - --- | Validate /all/ headers, and then extract the required -validateAll :: ResponseHeaders' -> Either InvalidHeaders RequiredHeaders -validateAll = fmap go . HKD.sequence - where - go :: ResponseHeaders -> RequiredHeaders - go responseHeaders = RequiredHeaders { - requiredCompression = responseCompression responseHeaders - } - --- | Validate only the required headers -validateRequired :: ResponseHeaders' -> Either InvalidHeaders RequiredHeaders -validateRequired responseHeaders' = - RequiredHeaders - <$> responseCompression responseHeaders' - --- | Process response headers --- --- This is the client equivalent of --- 'Network.GRPC.Server.RequestHandler.processRequestHeaders'. -processResponseHeaders :: - ClientSession rpc - -> ResponseHeaders' - -> IO Compression -processResponseHeaders (ClientSession conn) responseHeaders' = do - Connection.updateConnectionMeta conn responseHeaders' - required <- either invalid return $ - if connVerifyHeaders connParams - then validateAll responseHeaders' - else validateRequired responseHeaders' - getCompression (requiredCompression required) - where - connParams :: ConnParams - connParams = Connection.connParams conn - - getCompression :: Maybe CompressionId -> IO Compression - getCompression Nothing = return noCompression - getCompression (Just cid) = - case Compr.getSupported (connCompression connParams) cid of - Just compr -> return compr - Nothing -> throwIO $ CallSetupUnsupportedCompression cid - - invalid :: forall x. InvalidHeaders -> IO x - invalid = throwIO . CallSetupInvalidResponseHeaders - {------------------------------------------------------------------------------- Exceptions -------------------------------------------------------------------------------} diff --git a/src/Network/GRPC/Common/Headers.hs b/src/Network/GRPC/Common/Headers.hs new file mode 100644 index 00000000..31b11f12 --- /dev/null +++ b/src/Network/GRPC/Common/Headers.hs @@ -0,0 +1,104 @@ +-- | Utilities for working with headers +module Network.GRPC.Common.Headers ( + HasRequiredHeaders(..) + , RequiredHeaders(..) + , verifyRequired + , verifyAll + , verifyAllIf + ) where + +import Data.Functor.Identity +import Data.Kind +import Data.Void + +import Network.GRPC.Spec +import Network.GRPC.Util.HKD (Undecorated, Checked) +import Network.GRPC.Util.HKD qualified as HKD + +{------------------------------------------------------------------------------- + Definition +-------------------------------------------------------------------------------} + +-- | Required headers +-- +-- Required headers are headers that @grapesy@ needs to know in order to +-- function. For example, we /need/ to know which compression algorithm the peer +-- is using for their messages to us. +class HKD.Traversable h => HasRequiredHeaders h where + data RequiredHeaders h :: Type + requiredHeaders :: h (Checked e) -> Either e (RequiredHeaders h) + +-- | Like 'requiredHeaders', but for already verified headers +requiredHeadersVerified :: + HasRequiredHeaders h + => h Undecorated -> RequiredHeaders h +requiredHeadersVerified = + either absurd id . requiredHeaders . HKD.map noError . HKD.decorate + where + noError :: Identity a -> Either Void a + noError = Right . runIdentity + +-- | Validate only the required headers +-- +-- By default, we only check those headers @grapesy@ needs to function. +verifyRequired :: + HasRequiredHeaders h + => h (Checked e) -> Either e (RequiredHeaders h) +verifyRequired = requiredHeaders + +-- | Validate /all/ headers +-- +-- Validate all headers; we do this only if +-- 'Network.GRPC.Client.connVerifyHeaders' (on the client) or +-- 'Network.GRPC.Server.serverVerifyHeaders' (on the server) is enabled. +verifyAll :: forall h e. + HasRequiredHeaders h + => h (Checked e) -> Either e (h Undecorated, RequiredHeaders h) +verifyAll = fmap aux . HKD.sequence + where + aux :: h Undecorated -> (h Undecorated, RequiredHeaders h) + aux verifyd = (verifyd, requiredHeadersVerified verifyd) + +-- | Convenience wrapper, conditionally verifying all headers +verifyAllIf :: + HasRequiredHeaders h + => Bool -> h (Checked e) -> Either e (RequiredHeaders h) +verifyAllIf False = verifyRequired +verifyAllIf True = fmap snd . verifyAll + +{------------------------------------------------------------------------------- + Request +-------------------------------------------------------------------------------} + +instance HasRequiredHeaders RequestHeaders_ where + data RequiredHeaders RequestHeaders_ = RequiredRequestHeaders { + requiredRequestCompression :: Maybe CompressionId + , requiredRequestTimeout :: Maybe Timeout + } + + requiredHeaders requestHeaders = + RequiredRequestHeaders + <$> requestCompression requestHeaders + <*> requestTimeout requestHeaders + +{------------------------------------------------------------------------------- + Response +-------------------------------------------------------------------------------} + +instance HasRequiredHeaders ResponseHeaders_ where + data RequiredHeaders ResponseHeaders_ = RequiredResponseHeaders { + requiredResponseCompression :: Maybe CompressionId + } + + requiredHeaders responseHeaders = + RequiredResponseHeaders + <$> responseCompression responseHeaders + +{------------------------------------------------------------------------------- + Trailers-Only +-------------------------------------------------------------------------------} + +instance HasRequiredHeaders TrailersOnly_ where + data RequiredHeaders TrailersOnly_ = NoRequiredTrailers + requiredHeaders _ = pure NoRequiredTrailers + diff --git a/src/Network/GRPC/Server/Call.hs b/src/Network/GRPC/Server/Call.hs index a99975c0..936db62f 100644 --- a/src/Network/GRPC/Server/Call.hs +++ b/src/Network/GRPC/Server/Call.hs @@ -55,12 +55,12 @@ import Network.HTTP2.Server qualified as HTTP2 import Network.GRPC.Common import Network.GRPC.Common.Compression qualified as Compr +import Network.GRPC.Common.Headers import Network.GRPC.Common.StreamElem qualified as StreamElem import Network.GRPC.Server.Context import Network.GRPC.Server.Session import Network.GRPC.Spec import Network.GRPC.Spec.Serialization -import Network.GRPC.Util.HKD qualified as HKD import Network.GRPC.Util.HTTP2 (fromHeaderTable) import Network.GRPC.Util.Session qualified as Session import Network.GRPC.Util.Session.Server qualified as Server @@ -190,15 +190,22 @@ determineInbound :: forall rpc. -> HTTP2.Request -> IO (Headers (ServerInbound rpc), Maybe Timeout) determineInbound session req = do - (cIn, timeout) <- processRequestHeaders session requestHeaders' - return ( - InboundHeaders { - inbHeaders = requestHeaders' - , inbCompression = cIn - } - , timeout - ) + case verifyAllIf serverVerifyHeaders requestHeaders' of + Left err -> throwIO $ CallSetupInvalidRequestHeaders err + Right hdrs -> do + cIn <- getInboundCompression session (requiredRequestCompression hdrs) + return ( + InboundHeaders { + inbHeaders = requestHeaders' + , inbCompression = cIn + } + , requiredRequestTimeout hdrs + ) where + ServerSession{serverSessionContext} = session + ServerContext{serverParams} = serverSessionContext + ServerParams{serverVerifyHeaders} = serverParams + requestHeaders' :: RequestHeaders' requestHeaders' = parseRequestHeaders' (Proxy @rpc) $ fromHeaderTable $ HTTP2.requestHeaders req @@ -267,6 +274,22 @@ startOutbound serverParams metadataVar kickoffVar cOut = do , responseBody = Nothing } +-- | Determine compression used by the peer for messages to us +getInboundCompression :: + ServerSession rpc + -> Maybe CompressionId + -> IO Compression +getInboundCompression session = \case + Nothing -> return noCompression + Just cid -> + case Compr.getSupported serverCompression cid of + Just compr -> return compr + Nothing -> throwIO $ CallSetupUnsupportedCompression cid + where + ServerSession{serverSessionContext} = session + ServerContext{serverParams} = serverSessionContext + ServerParams{serverCompression} = serverParams + -- | Determine compression to be used for messages to the peer -- -- In the case that we fail to parse the @grpc-accept-encoding@ header, we @@ -293,66 +316,6 @@ serverExceptionToClientError params err mMsg <- serverExceptionToClient params err return $ simpleProperTrailers (GrpcError GrpcUnknown) mMsg mempty -{------------------------------------------------------------------------------- - Process request headers --------------------------------------------------------------------------------} - --- | Required request headers --- --- If any of these headers are missing or invalid, we throw an exception, --- independent of 'serverVerifyHeaders' -data RequiredHeaders = RequiredHeaders { - requiredCompression :: Maybe CompressionId - , requiredTimeout :: Maybe Timeout - } - --- | Validate /all/ headers, and then extract the required -validateAll :: - RequestHeaders' - -> Either InvalidHeaders RequiredHeaders -validateAll = fmap go . HKD.sequence - where - go :: RequestHeaders -> RequiredHeaders - go requestHeaders = RequiredHeaders { - requiredCompression = requestCompression requestHeaders - , requiredTimeout = requestTimeout requestHeaders - } - --- | Validate only the required headers -validateRequired :: - RequestHeaders' - -> Either InvalidHeaders RequiredHeaders -validateRequired requestHeaders' = - RequiredHeaders - <$> requestCompression requestHeaders' - <*> requestTimeout requestHeaders' - -processRequestHeaders :: - ServerSession rpc - -> RequestHeaders' - -> IO (Compression, Maybe Timeout) -processRequestHeaders session requestHeaders' = do - required <- either invalid return $ - if serverVerifyHeaders serverParams - then validateAll requestHeaders' - else validateRequired requestHeaders' - cIn <- getCompression (requiredCompression required) - return (cIn, requiredTimeout required) - where - ServerSession{serverSessionContext} = session - ServerContext{serverParams} = serverSessionContext - - -- this replaces getInboundCompression - getCompression :: Maybe CompressionId -> IO Compression - getCompression Nothing = return noCompression - getCompression (Just cid) = - case Compr.getSupported (serverCompression serverParams) cid of - Just compr -> return compr - Nothing -> throwIO $ CallSetupUnsupportedCompression cid - - invalid :: forall x. InvalidHeaders -> IO x - invalid = throwIO . CallSetupInvalidRequestHeaders - {------------------------------------------------------------------------------- Open (ongoing) call -------------------------------------------------------------------------------} diff --git a/util/Network/GRPC/Util/HKD.hs b/util/Network/GRPC/Util/HKD.hs index bbfadfe9..66519be0 100644 --- a/util/Network/GRPC/Util/HKD.hs +++ b/util/Network/GRPC/Util/HKD.hs @@ -13,6 +13,7 @@ module Network.GRPC.Util.HKD ( , Coerce(..) , Traversable(..) , sequence + , map -- * Dealing with HKD fields , ValidDecoration , pure @@ -21,7 +22,7 @@ module Network.GRPC.Util.HKD ( , sequenceChecked ) where -import Prelude hiding (Traversable(..), pure) +import Prelude hiding (Traversable(..), pure, map) import Prelude qualified import Control.Monad.Except (MonadError, throwError) @@ -68,6 +69,12 @@ class Coerce t where undecorate :: t (DecoratedWith Identity) -> t Undecorated undecorate = unsafeCoerce + -- | Introduce trivial decoration + -- + -- See 'undecorate' for discussion of the validity of the default definitino. + decorate :: t Undecorated -> t (DecoratedWith Identity) + decorate = unsafeCoerce + class Coerce t => Traversable t where traverse :: Applicative m @@ -80,6 +87,13 @@ sequence :: => t (DecoratedWith m) -> m (t Undecorated) sequence = fmap undecorate . traverse (fmap Identity) +map :: + Traversable t + => (forall a. f a -> g a) + -> t (DecoratedWith f) + -> t (DecoratedWith g) +map f = runIdentity . traverse (Identity . f) + {------------------------------------------------------------------------------- Dealing with HKD fields -------------------------------------------------------------------------------} From 9381c6c2863660b696ef750dfb04c9c7282b8648 Mon Sep 17 00:00:00 2001 From: Edsko de Vries Date: Thu, 25 Jul 2024 13:31:04 +0200 Subject: [PATCH 10/10] Fix build for ghc 8.10 --- src/Network/GRPC/Spec/Headers/Invalid.hs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Network/GRPC/Spec/Headers/Invalid.hs b/src/Network/GRPC/Spec/Headers/Invalid.hs index d7760212..3524a204 100644 --- a/src/Network/GRPC/Spec/Headers/Invalid.hs +++ b/src/Network/GRPC/Spec/Headers/Invalid.hs @@ -22,9 +22,9 @@ import Data.ByteString.Builder qualified as Builder import Data.ByteString.Builder qualified as ByteString (Builder) import Data.ByteString.UTF8 qualified as BS.UTF8 import Data.CaseInsensitive qualified as CI +import Data.Foldable (asum) import Data.Maybe (fromMaybe) import Network.HTTP.Types qualified as HTTP -import Control.Applicative {------------------------------------------------------------------------------- Definition