From c1a9b053f7523186be1eaf949188863ece30f66f Mon Sep 17 00:00:00 2001 From: Przemyslaw Delewski <102958445+pdelewski@users.noreply.github.com> Date: Fri, 20 Dec 2024 15:02:43 +0100 Subject: [PATCH] Adding middleware handling API (#1130) This PR adds a new API call for registering middlewares into the HTTP FrontendConnector. --- .../basic_http_frontend_connector.go | 38 +++++++++ quesma/main_test.go | 83 ++++++++++++++++++- quesma/quesma/auth_middleware.go | 19 ++++- quesma/quesma/dual_write_proxy_v2.go | 18 ++-- quesma/v2/core/quesma_apis.go | 6 ++ 5 files changed, 148 insertions(+), 16 deletions(-) diff --git a/quesma/frontend_connectors/basic_http_frontend_connector.go b/quesma/frontend_connectors/basic_http_frontend_connector.go index a008eb65c..51888c92a 100644 --- a/quesma/frontend_connectors/basic_http_frontend_connector.go +++ b/quesma/frontend_connectors/basic_http_frontend_connector.go @@ -30,6 +30,7 @@ type BasicHTTPFrontendConnector struct { phoneHomeClient diag.PhoneHomeClient debugInfoCollector diag.DebugInfoCollector logger quesma_api.QuesmaLogger + middlewares []http.Handler } func (h *BasicHTTPFrontendConnector) GetChildComponents() []interface{} { @@ -66,6 +67,7 @@ func NewBasicHTTPFrontendConnector(endpoint string, config *config.QuesmaConfigu responseMutator: func(w http.ResponseWriter) http.ResponseWriter { return w }, + middlewares: make([]http.Handler, 0), } } @@ -81,7 +83,39 @@ func (h *BasicHTTPFrontendConnector) GetRouter() quesma_api.Router { return h.router } +type ResponseWriterWithStatusCode struct { + http.ResponseWriter + statusCode int +} + +func (w *ResponseWriterWithStatusCode) WriteHeader(statusCode int) { + w.statusCode = statusCode + w.ResponseWriter.WriteHeader(statusCode) +} + func (h *BasicHTTPFrontendConnector) ServeHTTP(w http.ResponseWriter, req *http.Request) { + index := 0 + var runMiddleware func() + + runMiddleware = func() { + if index < len(h.middlewares) { + middleware := h.middlewares[index] + index++ + responseWriter := &ResponseWriterWithStatusCode{w, 0} + middleware.ServeHTTP(responseWriter, req) // Automatically proceeds to the next middleware + // Only if the middleware did not set a status code, we proceed to the next middleware + if responseWriter.statusCode == 0 { + runMiddleware() + } + + } else { + h.finalHandler(w, req) + } + } + runMiddleware() +} + +func (h *BasicHTTPFrontendConnector) finalHandler(w http.ResponseWriter, req *http.Request) { reqBody, err := PeekBodyV2(req) if err != nil { http.Error(w, "Error reading request body", http.StatusInternalServerError) @@ -144,3 +178,7 @@ func ReadRequestBody(request *http.Request) ([]byte, error) { func (h *BasicHTTPFrontendConnector) GetRouterInstance() *RouterV2 { return h.routerInstance } + +func (h *BasicHTTPFrontendConnector) AddMiddleware(middleware http.Handler) { + h.middlewares = append(h.middlewares, middleware) +} diff --git a/quesma/main_test.go b/quesma/main_test.go index d7eff8d68..03c6e507b 100644 --- a/quesma/main_test.go +++ b/quesma/main_test.go @@ -135,13 +135,13 @@ func fallbackScenario() quesma_api.QuesmaBuilder { var ingestPipeline quesma_api.PipelineBuilder = quesma_api.NewPipeline() ingestPipeline.AddFrontendConnector(ingestFrontendConnector) quesmaBuilder.AddPipeline(ingestPipeline) - quesma, _ := quesmaBuilder.Build() - quesma.Start() - return quesma + + return quesmaBuilder } func Test_fallbackScenario(t *testing.T) { - q1 := fallbackScenario() + qBuilder := fallbackScenario() + q1, _ := qBuilder.Build() q1.Start() stop := make(chan os.Signal, 1) emitRequests(stop) @@ -159,3 +159,78 @@ func Test_scenario1(t *testing.T) { <-stop q1.Stop(context.Background()) } + +var middlewareCallCount int32 = 0 + +type Middleware struct { + emitError bool +} + +func (m *Middleware) ServeHTTP(w http.ResponseWriter, r *http.Request) { + atomic.AddInt32(&middlewareCallCount, 1) + if m.emitError { + http.Error(w, "middleware", http.StatusInternalServerError) + } +} + +type Middleware2 struct { +} + +func (m *Middleware2) ServeHTTP(w http.ResponseWriter, r *http.Request) { + atomic.AddInt32(&middlewareCallCount, 1) + w.WriteHeader(200) +} + +func createMiddleWareScenario(emitError bool, cfg *config.QuesmaConfiguration) quesma_api.QuesmaBuilder { + var quesmaBuilder quesma_api.QuesmaBuilder = quesma_api.NewQuesma(quesma_api.EmptyDependencies()) + + frontendConnector := frontend_connectors.NewBasicHTTPFrontendConnector(":8888", cfg) + HTTPRouter := quesma_api.NewPathRouter() + var fallback quesma_api.HTTPFrontendHandler = fallback + HTTPRouter.AddFallbackHandler(fallback) + frontendConnector.AddRouter(HTTPRouter) + frontendConnector.AddMiddleware(&Middleware{emitError: emitError}) + frontendConnector.AddMiddleware(&Middleware2{}) + + var pipeline quesma_api.PipelineBuilder = quesma_api.NewPipeline() + pipeline.AddFrontendConnector(frontendConnector) + var ingestProcessor quesma_api.Processor = NewIngestProcessor() + pipeline.AddProcessor(ingestProcessor) + quesmaBuilder.AddPipeline(pipeline) + return quesmaBuilder +} + +func Test_middleware(t *testing.T) { + + cfg := &config.QuesmaConfiguration{ + DisableAuth: true, + Elasticsearch: config.ElasticsearchConfiguration{ + Url: &config.Url{Host: "localhost:9200", Scheme: "http"}, + User: "", + Password: "", + }, + } + { + quesmaBuilder := createMiddleWareScenario(true, cfg) + quesmaBuilder.Build() + quesmaBuilder.Start() + stop := make(chan os.Signal, 1) + emitRequests(stop) + <-stop + quesmaBuilder.Stop(context.Background()) + atomic.LoadInt32(&middlewareCallCount) + assert.Equal(t, int32(4), middlewareCallCount) + } + atomic.StoreInt32(&middlewareCallCount, 0) + { + quesmaBuilder := createMiddleWareScenario(false, cfg) + quesmaBuilder.Build() + quesmaBuilder.Start() + stop := make(chan os.Signal, 1) + emitRequests(stop) + <-stop + quesmaBuilder.Stop(context.Background()) + atomic.LoadInt32(&middlewareCallCount) + assert.Equal(t, int32(8), middlewareCallCount) + } +} diff --git a/quesma/quesma/auth_middleware.go b/quesma/quesma/auth_middleware.go index 80aa96084..01636af8e 100644 --- a/quesma/quesma/auth_middleware.go +++ b/quesma/quesma/auth_middleware.go @@ -24,11 +24,12 @@ type authMiddleware struct { authHeaderCache sync.Map cacheWipeInterval time.Duration esClient elasticsearch.SimpleClient + v2 bool } func NewAuthMiddleware(next http.Handler, esConf config.ElasticsearchConfiguration) http.Handler { esClient := elasticsearch.NewSimpleClient(&esConf) - middleware := &authMiddleware{nextHttpHandler: next, esClient: *esClient, cacheWipeInterval: cacheWipeInterval} + middleware := &authMiddleware{nextHttpHandler: next, esClient: *esClient, cacheWipeInterval: cacheWipeInterval, v2: false} go middleware.startCacheWipeScheduler() return middleware } @@ -49,7 +50,9 @@ func (a *authMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request) { } if _, ok := a.authHeaderCache.Load(auth); ok { logger.Debug().Msgf("[AUTH] [%s] called by [%s] - credentials loaded from cache", r.URL, userName) - a.nextHttpHandler.ServeHTTP(w, r) + if !a.v2 { + a.nextHttpHandler.ServeHTTP(w, r) + } return } @@ -61,8 +64,9 @@ func (a *authMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request) { http.Error(w, "Unauthorized", http.StatusUnauthorized) return } - - a.nextHttpHandler.ServeHTTP(w, r) + if !a.v2 { + a.nextHttpHandler.ServeHTTP(w, r) + } } func (a *authMiddleware) startCacheWipeScheduler() { @@ -86,3 +90,10 @@ func (a *authMiddleware) wipeCache() { return true }) } + +func NewAuthMiddlewareV2(esConf config.ElasticsearchConfiguration) http.Handler { + esClient := elasticsearch.NewSimpleClient(&esConf) + middleware := &authMiddleware{esClient: *esClient, cacheWipeInterval: cacheWipeInterval, v2: true} + go middleware.startCacheWipeScheduler() + return middleware +} diff --git a/quesma/quesma/dual_write_proxy_v2.go b/quesma/quesma/dual_write_proxy_v2.go index b29c28f8e..32594025a 100644 --- a/quesma/quesma/dual_write_proxy_v2.go +++ b/quesma/quesma/dual_write_proxy_v2.go @@ -26,14 +26,12 @@ const concurrentClientsLimitV2 = 100 // FIXME this should be configurable type simultaneousClientsLimiterV2 struct { counter atomic.Int64 - handler http.Handler limit int64 } -func newSimultaneousClientsLimiterV2(handler http.Handler, limit int64) *simultaneousClientsLimiterV2 { +func newSimultaneousClientsLimiterV2(limit int64) *simultaneousClientsLimiterV2 { return &simultaneousClientsLimiterV2{ - handler: handler, - limit: limit, + limit: limit, } } @@ -49,7 +47,6 @@ func (c *simultaneousClientsLimiterV2) ServeHTTP(w http.ResponseWriter, r *http. c.counter.Add(1) defer c.counter.Add(-1) - c.handler.ServeHTTP(w, r) } type dualWriteHttpProxyV2 struct { @@ -101,12 +98,17 @@ func newDualWriteProxyV2(dependencies quesma_api.Dependencies, schemaLoader clic if err != nil { logger.Fatal().Msgf("Error building Quesma: %v", err) } - var limitedHandler http.Handler if config.DisableAuth { - limitedHandler = newSimultaneousClientsLimiterV2(elasticHttpIngestFrontendConnector, concurrentClientsLimitV2) + elasticHttpIngestFrontendConnector.AddMiddleware(newSimultaneousClientsLimiterV2(concurrentClientsLimitV2)) + elasticHttpQueryFrontendConnector.AddMiddleware(newSimultaneousClientsLimiterV2(concurrentClientsLimitV2)) + limitedHandler = elasticHttpIngestFrontendConnector } else { - limitedHandler = newSimultaneousClientsLimiterV2(NewAuthMiddleware(elasticHttpIngestFrontendConnector, config.Elasticsearch), concurrentClientsLimitV2) + elasticHttpQueryFrontendConnector.AddMiddleware(newSimultaneousClientsLimiterV2(concurrentClientsLimitV2)) + elasticHttpQueryFrontendConnector.AddMiddleware(NewAuthMiddlewareV2(config.Elasticsearch)) + elasticHttpIngestFrontendConnector.AddMiddleware(newSimultaneousClientsLimiterV2(concurrentClientsLimitV2)) + elasticHttpIngestFrontendConnector.AddMiddleware(NewAuthMiddlewareV2(config.Elasticsearch)) + limitedHandler = elasticHttpIngestFrontendConnector } return &dualWriteHttpProxyV2{ diff --git a/quesma/v2/core/quesma_apis.go b/quesma/v2/core/quesma_apis.go index 0aca32cc0..9b59af592 100644 --- a/quesma/v2/core/quesma_apis.go +++ b/quesma/v2/core/quesma_apis.go @@ -5,6 +5,7 @@ package quesma_api import ( "context" "net" + "net/http" ) type InstanceNamer interface { @@ -31,8 +32,13 @@ type FrontendConnector interface { type HTTPFrontendConnector interface { FrontendConnector + // AddRouter adds a router to the HTTPFrontendConnector AddRouter(router Router) GetRouter() Router + // AddMiddleware adds a middleware to the HTTPFrontendConnector. + // The middleware chain is executed in the order it is added + // and before the router is executed. + AddMiddleware(middleware http.Handler) } type TCPFrontendConnector interface {