Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Query-Frontend: Add middleware to drop headers #4298

Merged
merged 4 commits into from
Nov 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
* [ENHANCEMENT] Collection of query-frontend changes to reduce allocs. [#4242](https://github.com/grafana/tempo/pull/4242) (@joe-elliott)
* [ENHANCEMENT] Added `insecure-skip-verify` option in tempo-cli to skip SSL certificate validation when connecting to the S3 backend. [#4259](https://github.com/grafana/tempo/pull/4259) (@faridtmammadov)
* [ENHANCEMENT] Add `invalid_utf8` to reasons spanmetrics will discard spans. [#4293](https://github.com/grafana/tempo/pull/4293) (@zalegrala)
* [ENHANCEMENT] Reduce frontend and querier allocations by dropping HTTP headers early in the pipeline. [#4298](https://github.com/grafana/tempo/pull/4298) (@joe-elliott)
* [BUGFIX] Replace hedged requests roundtrips total with a counter. [#4063](https://github.com/grafana/tempo/pull/4063) [#4078](https://github.com/grafana/tempo/pull/4078) (@galalen)
* [BUGFIX] Metrics generators: Correctly drop from the ring before stopping ingestion to reduce drops during a rollout. [#4101](https://github.com/grafana/tempo/pull/4101) (@joe-elliott)
* [BUGFIX] Correctly handle 400 Bad Request and 404 Not Found in gRPC streaming [#4144](https://github.com/grafana/tempo/pull/4144) (@mapno)
Expand Down
4 changes: 2 additions & 2 deletions modules/frontend/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ type Config struct {
// A list of regexes for black listing requests, these will apply for every request regardless the endpoint
URLDenyList []string `yaml:"url_deny_list,omitempty"`

RequestWithWeights bool `yaml:"request_with_weights,omitempty"`
RetryWithWeights bool `yaml:"retry_with_weights,omitempty"`
// A list of headers allowed through the HTTP pipeline. Everything else will be stripped.
AllowedHeaders []string `yaml:"-"`
}

type SearchConfig struct {
Expand Down
7 changes: 7 additions & 0 deletions modules/frontend/frontend.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,9 +96,11 @@ func New(cfg Config, next pipeline.RoundTripper, o overrides.Interface, reader t
traceIDStatusCodeWare := pipeline.NewStatusCodeAdjustWareWithAllowedCode(http.StatusNotFound)
urlDenyListWare := pipeline.NewURLDenyListWare(cfg.URLDenyList)
queryValidatorWare := pipeline.NewQueryValidatorWare()
headerStripWare := pipeline.NewStripHeadersWare(cfg.AllowedHeaders)

tracePipeline := pipeline.Build(
[]pipeline.AsyncMiddleware[combiner.PipelineResponse]{
headerStripWare,
urlDenyListWare,
pipeline.NewWeightRequestWare(pipeline.TraceByID, cfg.Weights),
multiTenantMiddleware(cfg, logger),
Expand All @@ -109,6 +111,7 @@ func New(cfg Config, next pipeline.RoundTripper, o overrides.Interface, reader t

searchPipeline := pipeline.Build(
[]pipeline.AsyncMiddleware[combiner.PipelineResponse]{
headerStripWare,
urlDenyListWare,
queryValidatorWare,
pipeline.NewWeightRequestWare(pipeline.TraceQLSearch, cfg.Weights),
Expand All @@ -120,6 +123,7 @@ func New(cfg Config, next pipeline.RoundTripper, o overrides.Interface, reader t

searchTagsPipeline := pipeline.Build(
[]pipeline.AsyncMiddleware[combiner.PipelineResponse]{
headerStripWare,
urlDenyListWare,
pipeline.NewWeightRequestWare(pipeline.Default, cfg.Weights),
multiTenantMiddleware(cfg, logger),
Expand All @@ -130,6 +134,7 @@ func New(cfg Config, next pipeline.RoundTripper, o overrides.Interface, reader t

searchTagValuesPipeline := pipeline.Build(
[]pipeline.AsyncMiddleware[combiner.PipelineResponse]{
headerStripWare,
urlDenyListWare,
pipeline.NewWeightRequestWare(pipeline.Default, cfg.Weights),
multiTenantMiddleware(cfg, logger),
Expand All @@ -152,6 +157,7 @@ func New(cfg Config, next pipeline.RoundTripper, o overrides.Interface, reader t
// traceql metrics
queryRangePipeline := pipeline.Build(
[]pipeline.AsyncMiddleware[combiner.PipelineResponse]{
headerStripWare,
urlDenyListWare,
queryValidatorWare,
pipeline.NewWeightRequestWare(pipeline.TraceQLMetrics, cfg.Weights),
Expand All @@ -163,6 +169,7 @@ func New(cfg Config, next pipeline.RoundTripper, o overrides.Interface, reader t

queryInstantPipeline := pipeline.Build(
[]pipeline.AsyncMiddleware[combiner.PipelineResponse]{
headerStripWare,
urlDenyListWare,
queryValidatorWare,
pipeline.NewWeightRequestWare(pipeline.TraceQLMetrics, cfg.Weights),
Expand Down
46 changes: 46 additions & 0 deletions modules/frontend/pipeline/async_strip_headers_middleware.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
package pipeline

import (
	"net/http"

	"github.com/grafana/tempo/modules/frontend/combiner"
)

// stripHeadersWare is an async middleware that removes all HTTP headers from
// a request except those present in the allowed set.
type stripHeadersWare struct {
	// allowed is the set of header names permitted to pass through. An empty
	// set means every header is stripped.
	allowed map[string]struct{}
	// next is the downstream round tripper the filtered request is handed to.
	next AsyncRoundTripper[combiner.PipelineResponse]
}

// NewStripHeadersWare creates a middleware that strips headers not in the allow list. This exists to reduce allocations further
// down the pipeline. All request headers should be handled at the Combiner/Collector levels. Once the request is in the pipeline
// nothing else needs HTTP headers. Stripping them out reduces allocations for copying, marshalling and unmarshalling them to sometimes
// 100s of thousands of subrequests.
func NewStripHeadersWare(allowList []string) AsyncMiddleware[combiner.PipelineResponse] {
	// build the allowed set. store both the configured spelling and its
	// canonical MIME form (e.g. "authorization" -> "Authorization"). keys of
	// http.Header set via Set/Add are canonicalized, so without this a
	// lowercase entry in config would silently never match and the header
	// would be stripped despite being allowed.
	allowed := make(map[string]struct{}, 2*len(allowList))
	for _, header := range allowList {
		allowed[header] = struct{}{}
		allowed[http.CanonicalHeaderKey(header)] = struct{}{}
	}

	return AsyncMiddlewareFunc[combiner.PipelineResponse](func(next AsyncRoundTripper[combiner.PipelineResponse]) AsyncRoundTripper[combiner.PipelineResponse] {
		return &stripHeadersWare{
			next:    next,
			allowed: allowed,
		}
	})
}

// RoundTrip drops every header on the request that is not present in the
// allow list, then forwards the filtered request to the next round tripper.
func (c stripHeadersWare) RoundTrip(req Request) (Responses[combiner.PipelineResponse], error) {
	httpReq := req.HTTPRequest()

	switch {
	case len(c.allowed) == 0:
		// nothing is allowed through, wipe all headers at once
		clear(httpReq.Header)
	default:
		// deleting while ranging over a map is safe in Go
		for name := range httpReq.Header {
			if _, keep := c.allowed[name]; !keep {
				delete(httpReq.Header, name)
			}
		}
	}

	return c.next.RoundTrip(req.CloneFromHTTPRequest(httpReq))
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
package pipeline

import (
"bytes"
"io"
"net/http"
"testing"

"github.com/grafana/tempo/modules/frontend/combiner"
"github.com/stretchr/testify/require"
)

// TestStripHeaders verifies that the strip-headers middleware removes every
// request header except those explicitly present in the allow list.
func TestStripHeaders(t *testing.T) {
	testCases := []struct {
		name     string
		allow    []string
		headers  map[string][]string
		expected http.Header
	}{
		{
			name:     "empty allow list",
			allow:    []string{},
			headers:  map[string][]string{"header1": {"value1"}, "header2": {"value2"}},
			expected: map[string][]string{},
		},
		{
			name:     "allow list with one header",
			allow:    []string{"header1"},
			headers:  map[string][]string{"header1": {"value1"}, "header2": {"value2"}},
			expected: map[string][]string{"header1": {"value1"}},
		},
	}

	for _, testCase := range testCases {
		tc := testCase
		t.Run(tc.name, func(t *testing.T) {
			// the downstream round tripper asserts that only the expected
			// headers survived the middleware
			verifyingNext := AsyncRoundTripperFunc[combiner.PipelineResponse](func(req Request) (Responses[combiner.PipelineResponse], error) {
				require.Equal(t, tc.expected, req.HTTPRequest().Header)

				resp := &http.Response{
					StatusCode: 200,
					Body:       io.NopCloser(bytes.NewReader([]byte{})),
				}
				return NewHTTPToAsyncResponse(resp), nil
			})

			middleware := NewStripHeadersWare(tc.allow).Wrap(verifyingNext)

			httpReq, _ := http.NewRequest(http.MethodGet, "http://localhost:8080", nil)
			httpReq.Header = tc.headers

			_, err := middleware.RoundTrip(NewHTTPRequest(httpReq))
			require.NoError(t, err)
		})
	}
}
10 changes: 5 additions & 5 deletions modules/frontend/pipeline/pipeline.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,13 @@ var tracer = otel.Tracer("modules/frontend/pipeline")

type Request interface {
HTTPRequest() *http.Request
Context() context.Context

WithContext(context.Context)
CloneFromHTTPRequest(request *http.Request) Request

Weight() int
SetContext(context.Context)
Context() context.Context

SetWeight(int)
Weight() int

SetCacheKey(string)
CacheKey() string
Expand Down Expand Up @@ -51,7 +51,7 @@ func (r HTTPRequest) Context() context.Context {
return r.req.Context()
}

func (r *HTTPRequest) WithContext(ctx context.Context) {
// SetContext replaces the context of the underlying HTTP request with ctx.
func (r *HTTPRequest) SetContext(ctx context.Context) {
	r.req = r.req.WithContext(ctx)
}

Expand Down
2 changes: 1 addition & 1 deletion modules/frontend/pipeline/sync_handler_retry.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ func (r retryWare) RoundTrip(req Request) (*http.Response, error) {
defer span.End()

// context propagation
req.WithContext(ctx)
req.SetContext(ctx)

tries := 0
defer func() { r.retriesCount.Observe(float64(tries)) }()
Expand Down
2 changes: 1 addition & 1 deletion modules/frontend/tag_sharder.go
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ func (s searchTagSharder) RoundTrip(pipelineRequest pipeline.Request) (pipeline.
}
ctx, span := tracer.Start(ctx, "frontend.ShardSearchTags")
defer span.End()
pipelineRequest.WithContext(ctx)
pipelineRequest.SetContext(ctx)

// calculate and enforce max search duration
maxDuration := s.maxDuration(tenantID)
Expand Down
2 changes: 1 addition & 1 deletion modules/frontend/traceid_sharder.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ func newAsyncTraceIDSharder(cfg *TraceByIDConfig, logger log.Logger) pipeline.As
func (s asyncTraceSharder) RoundTrip(pipelineRequest pipeline.Request) (pipeline.Responses[combiner.PipelineResponse], error) {
ctx, span := tracer.Start(pipelineRequest.Context(), "frontend.ShardQuery")
defer span.End()
pipelineRequest.WithContext(ctx)
pipelineRequest.SetContext(ctx)

reqs, err := s.buildShardedRequests(pipelineRequest)
if err != nil {
Expand Down
6 changes: 3 additions & 3 deletions modules/frontend/v1/request_batch_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ func TestRequestBatchContextError(t *testing.T) {

req := httptest.NewRequest("GET", "http://example.com", nil)
prequest := pipeline.NewHTTPRequest(req)
prequest.WithContext(ctx)
prequest.SetContext(ctx)

for i := 0; i < totalRequests-1; i++ {
_ = rb.add(&request{request: prequest})
Expand All @@ -61,7 +61,7 @@ func TestRequestBatchContextError(t *testing.T) {
// add a cancel context
cancelCtx, cancel := context.WithCancel(ctx)
prequest = pipeline.NewHTTPRequest(req)
prequest.WithContext(cancelCtx)
prequest.SetContext(cancelCtx)

_ = rb.add(&request{request: prequest})

Expand All @@ -83,7 +83,7 @@ func TestDoneChanCloses(_ *testing.T) {

req := httptest.NewRequest("GET", "http://example.com", nil)
prequest := pipeline.NewHTTPRequest(req)
prequest.WithContext(cancelCtx)
prequest.SetContext(cancelCtx)

for i := 0; i < totalRequests-1; i++ {
_ = rb.add(&request{request: prequest})
Expand Down