From 534f927217b3d0a9f5b30d9fd14ec636966edb22 Mon Sep 17 00:00:00 2001 From: Joe Elliott Date: Thu, 7 Nov 2024 16:03:22 -0500 Subject: [PATCH] header strip ware Signed-off-by: Joe Elliott --- modules/frontend/config.go | 4 +- modules/frontend/frontend.go | 8 +++ .../async_strip_headers_middleware.go | 42 ++++++++++++++ .../async_strip_headers_middleware_test.go | 55 +++++++++++++++++++ modules/frontend/pipeline/pipeline.go | 10 ++-- .../frontend/pipeline/sync_handler_retry.go | 2 +- modules/frontend/tag_sharder.go | 2 +- modules/frontend/traceid_sharder.go | 2 +- modules/frontend/v1/request_batch_test.go | 6 +- 9 files changed, 118 insertions(+), 13 deletions(-) create mode 100644 modules/frontend/pipeline/async_strip_headers_middleware.go create mode 100644 modules/frontend/pipeline/async_strip_headers_middleware_test.go diff --git a/modules/frontend/config.go b/modules/frontend/config.go index 0ba194df4b2..1619042b362 100644 --- a/modules/frontend/config.go +++ b/modules/frontend/config.go @@ -32,8 +32,8 @@ type Config struct { // A list of regexes for black listing requests, these will apply for every request regardless the endpoint URLDenyList []string `yaml:"url_deny_list,omitempty"` - RequestWithWeights bool `yaml:"request_with_weights,omitempty"` - RetryWithWeights bool `yaml:"retry_with_weights,omitempty"` + // A list of headers allowed through the HTTP pipeline. Everything else will be stripped. + AllowedHeaders []string `yaml:"-"` } type SearchConfig struct { diff --git a/modules/frontend/frontend.go b/modules/frontend/frontend.go index 9b6a472f6b7..1ffe0ad714c 100644 --- a/modules/frontend/frontend.go +++ b/modules/frontend/frontend.go @@ -96,9 +96,11 @@ func New(cfg Config, next pipeline.RoundTripper, o overrides.Interface, reader t traceIDStatusCodeWare := pipeline.NewStatusCodeAdjustWareWithAllowedCode(http.StatusNotFound) urlDenyListWare := pipeline.NewURLDenyListWare(cfg.URLDenyList) queryValidatorWare := pipeline.NewQueryValidatorWare() + headerStripWare := pipeline.NewStripHeadersWare(cfg.AllowedHeaders) tracePipeline := pipeline.Build( []pipeline.AsyncMiddleware[combiner.PipelineResponse]{ + headerStripWare, urlDenyListWare, pipeline.NewWeightRequestWare(pipeline.TraceByID, cfg.Weights), multiTenantMiddleware(cfg, logger), @@ -109,6 +111,7 @@ func New(cfg Config, next pipeline.RoundTripper, o overrides.Interface, reader t searchPipeline := pipeline.Build( []pipeline.AsyncMiddleware[combiner.PipelineResponse]{ + headerStripWare, urlDenyListWare, queryValidatorWare, pipeline.NewWeightRequestWare(pipeline.TraceQLSearch, cfg.Weights), @@ -120,6 +123,7 @@ func New(cfg Config, next pipeline.RoundTripper, o overrides.Interface, reader t searchTagsPipeline := pipeline.Build( []pipeline.AsyncMiddleware[combiner.PipelineResponse]{ + headerStripWare, urlDenyListWare, pipeline.NewWeightRequestWare(pipeline.Default, cfg.Weights), multiTenantMiddleware(cfg, logger), @@ -130,6 +134,7 @@ func New(cfg Config, next pipeline.RoundTripper, o overrides.Interface, reader t searchTagValuesPipeline := pipeline.Build( []pipeline.AsyncMiddleware[combiner.PipelineResponse]{ + headerStripWare, urlDenyListWare, pipeline.NewWeightRequestWare(pipeline.Default, cfg.Weights), multiTenantMiddleware(cfg, logger), @@ -141,6 +146,7 @@ func New(cfg Config, next pipeline.RoundTripper, o overrides.Interface, reader t // metrics summary metricsPipeline := pipeline.Build( []pipeline.AsyncMiddleware[combiner.PipelineResponse]{ + headerStripWare, urlDenyListWare, queryValidatorWare, pipeline.NewWeightRequestWare(pipeline.Default, cfg.Weights), @@ -152,6 +158,7 @@ func New(cfg Config, next pipeline.RoundTripper, o overrides.Interface, reader t // traceql metrics queryRangePipeline := pipeline.Build( []pipeline.AsyncMiddleware[combiner.PipelineResponse]{ + headerStripWare, urlDenyListWare, queryValidatorWare, pipeline.NewWeightRequestWare(pipeline.TraceQLMetrics, cfg.Weights), @@ -163,6 +170,7 @@ func New(cfg Config, next pipeline.RoundTripper, o overrides.Interface, reader t queryInstantPipeline := pipeline.Build( []pipeline.AsyncMiddleware[combiner.PipelineResponse]{ + headerStripWare, urlDenyListWare, queryValidatorWare, pipeline.NewWeightRequestWare(pipeline.TraceQLMetrics, cfg.Weights), diff --git a/modules/frontend/pipeline/async_strip_headers_middleware.go b/modules/frontend/pipeline/async_strip_headers_middleware.go new file mode 100644 index 00000000000..bf066c561d7 --- /dev/null +++ b/modules/frontend/pipeline/async_strip_headers_middleware.go @@ -0,0 +1,42 @@ +package pipeline + +import ( + "github.com/grafana/tempo/modules/frontend/combiner" +) + +type stripHeadersWare struct { + allowed map[string]struct{} + next AsyncRoundTripper[combiner.PipelineResponse] +} + +func NewStripHeadersWare(allowList []string) AsyncMiddleware[combiner.PipelineResponse] { + // build allowed map + allowed := make(map[string]struct{}, len(allowList)) + for _, header := range allowList { + allowed[header] = struct{}{} + } + + return AsyncMiddlewareFunc[combiner.PipelineResponse](func(next AsyncRoundTripper[combiner.PipelineResponse]) AsyncRoundTripper[combiner.PipelineResponse] { + return &stripHeadersWare{ + next: next, + allowed: allowed, + } + }) +} + +func (c stripHeadersWare) RoundTrip(req Request) (Responses[combiner.PipelineResponse], error) { + httpReq := req.HTTPRequest() + + if len(c.allowed) == 0 { + clear(httpReq.Header) + } else { + // clear out headers not in allow list + for header := range httpReq.Header { + if _, ok := c.allowed[header]; !ok { + delete(httpReq.Header, header) + } + } + } + + return c.next.RoundTrip(req.CloneFromHTTPRequest(httpReq)) +} diff --git a/modules/frontend/pipeline/async_strip_headers_middleware_test.go b/modules/frontend/pipeline/async_strip_headers_middleware_test.go new file mode 100644 index 00000000000..cac02d02a4f --- /dev/null +++ b/modules/frontend/pipeline/async_strip_headers_middleware_test.go @@ -0,0 +1,55 @@ +package pipeline + +import ( + "bytes" + "io" + "net/http" + "testing" + + "github.com/grafana/tempo/modules/frontend/combiner" + "github.com/stretchr/testify/require" +) + +func TestStripHeaders(t *testing.T) { + tcs := []struct { + name string + allow []string + headers map[string][]string + expected http.Header + }{ + { + name: "empty allow list", + allow: []string{}, + headers: map[string][]string{"header1": {"value1"}, "header2": {"value2"}}, + expected: map[string][]string{}, + }, + { + name: "allow list with one header", + allow: []string{"header1"}, + headers: map[string][]string{"header1": {"value1"}, "header2": {"value2"}}, + expected: map[string][]string{"header1": {"value1"}}, + }, + } + + for _, tc := range tcs { + t.Run(tc.name, func(t *testing.T) { + next := AsyncRoundTripperFunc[combiner.PipelineResponse](func(req Request) (Responses[combiner.PipelineResponse], error) { + actualHeaders := req.HTTPRequest().Header + require.Equal(t, tc.expected, actualHeaders) + + return NewHTTPToAsyncResponse(&http.Response{ + StatusCode: 200, + Body: io.NopCloser(bytes.NewReader([]byte{})), + }), nil + }) + + stripHeaders := NewStripHeadersWare(tc.allow).Wrap(next) + + req, _ := http.NewRequest(http.MethodGet, "http://localhost:8080", nil) + req.Header = tc.headers + + _, err := stripHeaders.RoundTrip(NewHTTPRequest(req)) + require.NoError(t, err) + }) + } +} diff --git a/modules/frontend/pipeline/pipeline.go b/modules/frontend/pipeline/pipeline.go index 0dc9bae839e..4a344f464ae 100644 --- a/modules/frontend/pipeline/pipeline.go +++ b/modules/frontend/pipeline/pipeline.go @@ -12,13 +12,13 @@ var tracer = otel.Tracer("modules/frontend/pipeline") type Request interface { HTTPRequest() *http.Request - Context() context.Context - - WithContext(context.Context) CloneFromHTTPRequest(request *http.Request) Request - Weight() int + SetContext(context.Context) + Context() context.Context + SetWeight(int) + Weight() int SetCacheKey(string) CacheKey() string @@ -51,7 +51,7 @@ func (r HTTPRequest) Context() context.Context { return r.req.Context() } -func (r *HTTPRequest) WithContext(ctx context.Context) { +func (r *HTTPRequest) SetContext(ctx context.Context) { r.req = r.req.WithContext(ctx) } diff --git a/modules/frontend/pipeline/sync_handler_retry.go b/modules/frontend/pipeline/sync_handler_retry.go index b8ad87fed5a..0fe7484c650 100644 --- a/modules/frontend/pipeline/sync_handler_retry.go +++ b/modules/frontend/pipeline/sync_handler_retry.go @@ -51,7 +51,7 @@ func (r retryWare) RoundTrip(req Request) (*http.Response, error) { defer span.End() // context propagation - req.WithContext(ctx) + req.SetContext(ctx) tries := 0 defer func() { r.retriesCount.Observe(float64(tries)) }() diff --git a/modules/frontend/tag_sharder.go b/modules/frontend/tag_sharder.go index 55f1f1e7942..6e0a2892d3b 100644 --- a/modules/frontend/tag_sharder.go +++ b/modules/frontend/tag_sharder.go @@ -203,7 +203,7 @@ func (s searchTagSharder) RoundTrip(pipelineRequest pipeline.Request) (pipeline. } ctx, span := tracer.Start(ctx, "frontend.ShardSearchTags") defer span.End() - pipelineRequest.WithContext(ctx) + pipelineRequest.SetContext(ctx) // calculate and enforce max search duration maxDuration := s.maxDuration(tenantID) diff --git a/modules/frontend/traceid_sharder.go b/modules/frontend/traceid_sharder.go index cb17c09c12a..08339f40442 100644 --- a/modules/frontend/traceid_sharder.go +++ b/modules/frontend/traceid_sharder.go @@ -41,7 +41,7 @@ func newAsyncTraceIDSharder(cfg *TraceByIDConfig, logger log.Logger) pipeline.As func (s asyncTraceSharder) RoundTrip(pipelineRequest pipeline.Request) (pipeline.Responses[combiner.PipelineResponse], error) { ctx, span := tracer.Start(pipelineRequest.Context(), "frontend.ShardQuery") defer span.End() - pipelineRequest.WithContext(ctx) + pipelineRequest.SetContext(ctx) reqs, err := s.buildShardedRequests(pipelineRequest) if err != nil { diff --git a/modules/frontend/v1/request_batch_test.go b/modules/frontend/v1/request_batch_test.go index 931abae26c3..7b55377648d 100644 --- a/modules/frontend/v1/request_batch_test.go +++ b/modules/frontend/v1/request_batch_test.go @@ -52,7 +52,7 @@ func TestRequestBatchContextError(t *testing.T) { req := httptest.NewRequest("GET", "http://example.com", nil) prequest := pipeline.NewHTTPRequest(req) - prequest.WithContext(ctx) + prequest.SetContext(ctx) for i := 0; i < totalRequests-1; i++ { _ = rb.add(&request{request: prequest}) @@ -61,7 +61,7 @@ func TestRequestBatchContextError(t *testing.T) { // add a cancel context cancelCtx, cancel := context.WithCancel(ctx) prequest = pipeline.NewHTTPRequest(req) - prequest.WithContext(cancelCtx) + prequest.SetContext(cancelCtx) _ = rb.add(&request{request: prequest}) @@ -83,7 +83,7 @@ func TestDoneChanCloses(_ *testing.T) { req := httptest.NewRequest("GET", "http://example.com", nil) prequest := pipeline.NewHTTPRequest(req) - prequest.WithContext(cancelCtx) + prequest.SetContext(cancelCtx) for i := 0; i < totalRequests-1; i++ { _ = rb.add(&request{request: prequest})