Skip to content

Commit

Permalink
Adding middleware handling API (#1130)
Browse files Browse the repository at this point in the history
This PR adds a new API call for registering middlewares into the HTTP
FrontendConnector.
  • Loading branch information
pdelewski authored Dec 20, 2024
1 parent 7b4de20 commit c1a9b05
Show file tree
Hide file tree
Showing 5 changed files with 148 additions and 16 deletions.
38 changes: 38 additions & 0 deletions quesma/frontend_connectors/basic_http_frontend_connector.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ type BasicHTTPFrontendConnector struct {
phoneHomeClient diag.PhoneHomeClient
debugInfoCollector diag.DebugInfoCollector
logger quesma_api.QuesmaLogger
middlewares []http.Handler
}

func (h *BasicHTTPFrontendConnector) GetChildComponents() []interface{} {
Expand Down Expand Up @@ -66,6 +67,7 @@ func NewBasicHTTPFrontendConnector(endpoint string, config *config.QuesmaConfigu
responseMutator: func(w http.ResponseWriter) http.ResponseWriter {
return w
},
middlewares: make([]http.Handler, 0),
}
}

Expand All @@ -81,7 +83,39 @@ func (h *BasicHTTPFrontendConnector) GetRouter() quesma_api.Router {
return h.router
}

type ResponseWriterWithStatusCode struct {
http.ResponseWriter
statusCode int
}

func (w *ResponseWriterWithStatusCode) WriteHeader(statusCode int) {
w.statusCode = statusCode
w.ResponseWriter.WriteHeader(statusCode)
}

func (h *BasicHTTPFrontendConnector) ServeHTTP(w http.ResponseWriter, req *http.Request) {
index := 0
var runMiddleware func()

runMiddleware = func() {
if index < len(h.middlewares) {
middleware := h.middlewares[index]
index++
responseWriter := &ResponseWriterWithStatusCode{w, 0}
middleware.ServeHTTP(responseWriter, req) // Automatically proceeds to the next middleware
// Only if the middleware did not set a status code, we proceed to the next middleware
if responseWriter.statusCode == 0 {
runMiddleware()
}

} else {
h.finalHandler(w, req)
}
}
runMiddleware()
}

func (h *BasicHTTPFrontendConnector) finalHandler(w http.ResponseWriter, req *http.Request) {
reqBody, err := PeekBodyV2(req)
if err != nil {
http.Error(w, "Error reading request body", http.StatusInternalServerError)
Expand Down Expand Up @@ -144,3 +178,7 @@ func ReadRequestBody(request *http.Request) ([]byte, error) {
func (h *BasicHTTPFrontendConnector) GetRouterInstance() *RouterV2 {
return h.routerInstance
}

func (h *BasicHTTPFrontendConnector) AddMiddleware(middleware http.Handler) {
h.middlewares = append(h.middlewares, middleware)
}
83 changes: 79 additions & 4 deletions quesma/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,13 +135,13 @@ func fallbackScenario() quesma_api.QuesmaBuilder {
var ingestPipeline quesma_api.PipelineBuilder = quesma_api.NewPipeline()
ingestPipeline.AddFrontendConnector(ingestFrontendConnector)
quesmaBuilder.AddPipeline(ingestPipeline)
quesma, _ := quesmaBuilder.Build()
quesma.Start()
return quesma

return quesmaBuilder
}

func Test_fallbackScenario(t *testing.T) {
q1 := fallbackScenario()
qBuilder := fallbackScenario()
q1, _ := qBuilder.Build()
q1.Start()
stop := make(chan os.Signal, 1)
emitRequests(stop)
Expand All @@ -159,3 +159,78 @@ func Test_scenario1(t *testing.T) {
<-stop
q1.Stop(context.Background())
}

var middlewareCallCount int32 = 0

type Middleware struct {
emitError bool
}

func (m *Middleware) ServeHTTP(w http.ResponseWriter, r *http.Request) {
atomic.AddInt32(&middlewareCallCount, 1)
if m.emitError {
http.Error(w, "middleware", http.StatusInternalServerError)
}
}

type Middleware2 struct {
}

func (m *Middleware2) ServeHTTP(w http.ResponseWriter, r *http.Request) {
atomic.AddInt32(&middlewareCallCount, 1)
w.WriteHeader(200)
}

func createMiddleWareScenario(emitError bool, cfg *config.QuesmaConfiguration) quesma_api.QuesmaBuilder {
var quesmaBuilder quesma_api.QuesmaBuilder = quesma_api.NewQuesma(quesma_api.EmptyDependencies())

frontendConnector := frontend_connectors.NewBasicHTTPFrontendConnector(":8888", cfg)
HTTPRouter := quesma_api.NewPathRouter()
var fallback quesma_api.HTTPFrontendHandler = fallback
HTTPRouter.AddFallbackHandler(fallback)
frontendConnector.AddRouter(HTTPRouter)
frontendConnector.AddMiddleware(&Middleware{emitError: emitError})
frontendConnector.AddMiddleware(&Middleware2{})

var pipeline quesma_api.PipelineBuilder = quesma_api.NewPipeline()
pipeline.AddFrontendConnector(frontendConnector)
var ingestProcessor quesma_api.Processor = NewIngestProcessor()
pipeline.AddProcessor(ingestProcessor)
quesmaBuilder.AddPipeline(pipeline)
return quesmaBuilder
}

func Test_middleware(t *testing.T) {

cfg := &config.QuesmaConfiguration{
DisableAuth: true,
Elasticsearch: config.ElasticsearchConfiguration{
Url: &config.Url{Host: "localhost:9200", Scheme: "http"},
User: "",
Password: "",
},
}
{
quesmaBuilder := createMiddleWareScenario(true, cfg)
quesmaBuilder.Build()
quesmaBuilder.Start()
stop := make(chan os.Signal, 1)
emitRequests(stop)
<-stop
quesmaBuilder.Stop(context.Background())
atomic.LoadInt32(&middlewareCallCount)
assert.Equal(t, int32(4), middlewareCallCount)
}
atomic.StoreInt32(&middlewareCallCount, 0)
{
quesmaBuilder := createMiddleWareScenario(false, cfg)
quesmaBuilder.Build()
quesmaBuilder.Start()
stop := make(chan os.Signal, 1)
emitRequests(stop)
<-stop
quesmaBuilder.Stop(context.Background())
atomic.LoadInt32(&middlewareCallCount)
assert.Equal(t, int32(8), middlewareCallCount)
}
}
19 changes: 15 additions & 4 deletions quesma/quesma/auth_middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,12 @@ type authMiddleware struct {
authHeaderCache sync.Map
cacheWipeInterval time.Duration
esClient elasticsearch.SimpleClient
v2 bool
}

func NewAuthMiddleware(next http.Handler, esConf config.ElasticsearchConfiguration) http.Handler {
esClient := elasticsearch.NewSimpleClient(&esConf)
middleware := &authMiddleware{nextHttpHandler: next, esClient: *esClient, cacheWipeInterval: cacheWipeInterval}
middleware := &authMiddleware{nextHttpHandler: next, esClient: *esClient, cacheWipeInterval: cacheWipeInterval, v2: false}
go middleware.startCacheWipeScheduler()
return middleware
}
Expand All @@ -49,7 +50,9 @@ func (a *authMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
if _, ok := a.authHeaderCache.Load(auth); ok {
logger.Debug().Msgf("[AUTH] [%s] called by [%s] - credentials loaded from cache", r.URL, userName)
a.nextHttpHandler.ServeHTTP(w, r)
if !a.v2 {
a.nextHttpHandler.ServeHTTP(w, r)
}
return
}

Expand All @@ -61,8 +64,9 @@ func (a *authMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request) {
http.Error(w, "Unauthorized", http.StatusUnauthorized)
return
}

a.nextHttpHandler.ServeHTTP(w, r)
if !a.v2 {
a.nextHttpHandler.ServeHTTP(w, r)
}
}

func (a *authMiddleware) startCacheWipeScheduler() {
Expand All @@ -86,3 +90,10 @@ func (a *authMiddleware) wipeCache() {
return true
})
}

func NewAuthMiddlewareV2(esConf config.ElasticsearchConfiguration) http.Handler {
esClient := elasticsearch.NewSimpleClient(&esConf)
middleware := &authMiddleware{esClient: *esClient, cacheWipeInterval: cacheWipeInterval, v2: true}
go middleware.startCacheWipeScheduler()
return middleware
}
18 changes: 10 additions & 8 deletions quesma/quesma/dual_write_proxy_v2.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,12 @@ const concurrentClientsLimitV2 = 100 // FIXME this should be configurable

type simultaneousClientsLimiterV2 struct {
counter atomic.Int64
handler http.Handler
limit int64
}

func newSimultaneousClientsLimiterV2(handler http.Handler, limit int64) *simultaneousClientsLimiterV2 {
func newSimultaneousClientsLimiterV2(limit int64) *simultaneousClientsLimiterV2 {
return &simultaneousClientsLimiterV2{
handler: handler,
limit: limit,
limit: limit,
}
}

Expand All @@ -49,7 +47,6 @@ func (c *simultaneousClientsLimiterV2) ServeHTTP(w http.ResponseWriter, r *http.

c.counter.Add(1)
defer c.counter.Add(-1)
c.handler.ServeHTTP(w, r)
}

type dualWriteHttpProxyV2 struct {
Expand Down Expand Up @@ -101,12 +98,17 @@ func newDualWriteProxyV2(dependencies quesma_api.Dependencies, schemaLoader clic
if err != nil {
logger.Fatal().Msgf("Error building Quesma: %v", err)
}

var limitedHandler http.Handler
if config.DisableAuth {
limitedHandler = newSimultaneousClientsLimiterV2(elasticHttpIngestFrontendConnector, concurrentClientsLimitV2)
elasticHttpIngestFrontendConnector.AddMiddleware(newSimultaneousClientsLimiterV2(concurrentClientsLimitV2))
elasticHttpQueryFrontendConnector.AddMiddleware(newSimultaneousClientsLimiterV2(concurrentClientsLimitV2))
limitedHandler = elasticHttpIngestFrontendConnector
} else {
limitedHandler = newSimultaneousClientsLimiterV2(NewAuthMiddleware(elasticHttpIngestFrontendConnector, config.Elasticsearch), concurrentClientsLimitV2)
elasticHttpQueryFrontendConnector.AddMiddleware(newSimultaneousClientsLimiterV2(concurrentClientsLimitV2))
elasticHttpQueryFrontendConnector.AddMiddleware(NewAuthMiddlewareV2(config.Elasticsearch))
elasticHttpIngestFrontendConnector.AddMiddleware(newSimultaneousClientsLimiterV2(concurrentClientsLimitV2))
elasticHttpIngestFrontendConnector.AddMiddleware(NewAuthMiddlewareV2(config.Elasticsearch))
limitedHandler = elasticHttpIngestFrontendConnector
}

return &dualWriteHttpProxyV2{
Expand Down
6 changes: 6 additions & 0 deletions quesma/v2/core/quesma_apis.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package quesma_api
import (
"context"
"net"
"net/http"
)

type InstanceNamer interface {
Expand All @@ -31,8 +32,13 @@ type FrontendConnector interface {

type HTTPFrontendConnector interface {
FrontendConnector
// AddRouter adds a router to the HTTPFrontendConnector
AddRouter(router Router)
GetRouter() Router
// AddMiddleware adds a middleware to the HTTPFrontendConnector.
// The middleware chain is executed in the order it is added
// and before the router is executed.
AddMiddleware(middleware http.Handler)
}

type TCPFrontendConnector interface {
Expand Down

0 comments on commit c1a9b05

Please sign in to comment.