From 29d1134f1e1874d52ca2f5487767e4916a34f53e Mon Sep 17 00:00:00 2001 From: Jerome Amon Date: Mon, 13 Nov 2023 01:38:59 +0100 Subject: [PATCH] feat(middlewares): update timeout middleware to set the processing timeout based on some conditions like path and method --- api.middlewares.go | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/api.middlewares.go b/api.middlewares.go index 17dc29c..18721d8 100644 --- a/api.middlewares.go +++ b/api.middlewares.go @@ -7,6 +7,7 @@ import ( "fmt" "net/http" "sync/atomic" + "time" "github.com/julienschmidt/httprouter" "go.uber.org/zap" @@ -31,8 +32,9 @@ type MiddlewareMap struct { func (api *APIHandler) StatsMiddleware(next httprouter.Handle) httprouter.Handle { return func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { logger := api.GetLoggerFromContext(r.Context()) + conn := GetConnFromContext(r.Context()) + nw := NewCustomResponseWriter(w, conn) start := api.clock.Now() - nw := NewCustomResponseWriter(w) next(nw, r, ps) logger.Info( "stats", @@ -144,8 +146,8 @@ func (api *APIHandler) TimeoutMiddleware(next httprouter.Handle) httprouter.Hand return func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { requestID := GetValueFromContext(r.Context(), ContextRequestID) logger := api.GetLoggerFromContext(r.Context()) - timeout := api.config.Server.RequestTimeout - ctx, cancel := context.WithTimeout(r.Context(), api.config.Server.RequestTimeout) + timeout := api.GetTimeout(r) + ctx, cancel := context.WithTimeout(r.Context(), timeout) defer cancel() r = r.WithContext(ctx) done := make(chan struct{}) @@ -176,6 +178,17 @@ func (api *APIHandler) TimeoutMiddleware(next httprouter.Handle) httprouter.Hand } } +// GetTimeout returns the processing timeout to use to update +// a given request context deadline based on path and method. +func (api *APIHandler) GetTimeout(r *http.Request) time.Duration { + switch { + case r.Method == "GET" && r.URL.Path == "/v1/books": + return api.config.Server.LongRequestProcessingTimeout + default: + return api.config.Server.RequestTimeout + } +} + // Chain wraps a given httprouter.Handle with a list of middlewares. // It does by starting from the last middleware from the list. func (m *Middlewares) Chain(h httprouter.Handle) httprouter.Handle {