Skip to content

Commit

Permalink
header strip ware
Browse files Browse the repository at this point in the history
Signed-off-by: Joe Elliott <[email protected]>
  • Loading branch information
joe-elliott committed Nov 7, 2024
1 parent b6d7289 commit 534f927
Show file tree
Hide file tree
Showing 9 changed files with 118 additions and 13 deletions.
4 changes: 2 additions & 2 deletions modules/frontend/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ type Config struct {
// A list of regexes for black listing requests, these will apply for every request regardless the endpoint
URLDenyList []string `yaml:"url_deny_list,omitempty"`

RequestWithWeights bool `yaml:"request_with_weights,omitempty"`
RetryWithWeights bool `yaml:"retry_with_weights,omitempty"`
// A list of headers allowed through the HTTP pipeline. Everything else will be stripped.
AllowedHeaders []string `yaml:"-"`
}

type SearchConfig struct {
Expand Down
8 changes: 8 additions & 0 deletions modules/frontend/frontend.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,9 +96,11 @@ func New(cfg Config, next pipeline.RoundTripper, o overrides.Interface, reader t
traceIDStatusCodeWare := pipeline.NewStatusCodeAdjustWareWithAllowedCode(http.StatusNotFound)
urlDenyListWare := pipeline.NewURLDenyListWare(cfg.URLDenyList)
queryValidatorWare := pipeline.NewQueryValidatorWare()
headerStripWare := pipeline.NewStripHeadersWare(cfg.AllowedHeaders)

tracePipeline := pipeline.Build(
[]pipeline.AsyncMiddleware[combiner.PipelineResponse]{
headerStripWare,
urlDenyListWare,
pipeline.NewWeightRequestWare(pipeline.TraceByID, cfg.Weights),
multiTenantMiddleware(cfg, logger),
Expand All @@ -109,6 +111,7 @@ func New(cfg Config, next pipeline.RoundTripper, o overrides.Interface, reader t

searchPipeline := pipeline.Build(
[]pipeline.AsyncMiddleware[combiner.PipelineResponse]{
headerStripWare,
urlDenyListWare,
queryValidatorWare,
pipeline.NewWeightRequestWare(pipeline.TraceQLSearch, cfg.Weights),
Expand All @@ -120,6 +123,7 @@ func New(cfg Config, next pipeline.RoundTripper, o overrides.Interface, reader t

searchTagsPipeline := pipeline.Build(
[]pipeline.AsyncMiddleware[combiner.PipelineResponse]{
headerStripWare,
urlDenyListWare,
pipeline.NewWeightRequestWare(pipeline.Default, cfg.Weights),
multiTenantMiddleware(cfg, logger),
Expand All @@ -130,6 +134,7 @@ func New(cfg Config, next pipeline.RoundTripper, o overrides.Interface, reader t

searchTagValuesPipeline := pipeline.Build(
[]pipeline.AsyncMiddleware[combiner.PipelineResponse]{
headerStripWare,
urlDenyListWare,
pipeline.NewWeightRequestWare(pipeline.Default, cfg.Weights),
multiTenantMiddleware(cfg, logger),
Expand All @@ -141,6 +146,7 @@ func New(cfg Config, next pipeline.RoundTripper, o overrides.Interface, reader t
// metrics summary
metricsPipeline := pipeline.Build(
[]pipeline.AsyncMiddleware[combiner.PipelineResponse]{
headerStripWare,
urlDenyListWare,
queryValidatorWare,
pipeline.NewWeightRequestWare(pipeline.Default, cfg.Weights),
Expand All @@ -152,6 +158,7 @@ func New(cfg Config, next pipeline.RoundTripper, o overrides.Interface, reader t
// traceql metrics
queryRangePipeline := pipeline.Build(
[]pipeline.AsyncMiddleware[combiner.PipelineResponse]{
headerStripWare,
urlDenyListWare,
queryValidatorWare,
pipeline.NewWeightRequestWare(pipeline.TraceQLMetrics, cfg.Weights),
Expand All @@ -163,6 +170,7 @@ func New(cfg Config, next pipeline.RoundTripper, o overrides.Interface, reader t

queryInstantPipeline := pipeline.Build(
[]pipeline.AsyncMiddleware[combiner.PipelineResponse]{
headerStripWare,
urlDenyListWare,
queryValidatorWare,
pipeline.NewWeightRequestWare(pipeline.TraceQLMetrics, cfg.Weights),
Expand Down
42 changes: 42 additions & 0 deletions modules/frontend/pipeline/async_strip_headers_middleware.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
package pipeline

import (
"github.com/grafana/tempo/modules/frontend/combiner"
)

type stripHeadersWare struct {
allowed map[string]struct{}
next AsyncRoundTripper[combiner.PipelineResponse]
}

func NewStripHeadersWare(allowList []string) AsyncMiddleware[combiner.PipelineResponse] {
// build allowed map
allowed := make(map[string]struct{}, len(allowList))
for _, header := range allowList {
allowed[header] = struct{}{}
}

return AsyncMiddlewareFunc[combiner.PipelineResponse](func(next AsyncRoundTripper[combiner.PipelineResponse]) AsyncRoundTripper[combiner.PipelineResponse] {
return &stripHeadersWare{
next: next,
allowed: allowed,
}
})
}

func (c stripHeadersWare) RoundTrip(req Request) (Responses[combiner.PipelineResponse], error) {
httpReq := req.HTTPRequest()

if len(c.allowed) == 0 {
clear(httpReq.Header)
} else {
// clear out headers not in allow list
for header := range httpReq.Header {
if _, ok := c.allowed[header]; !ok {
delete(httpReq.Header, header)
}
}
}

return c.next.RoundTrip(req.CloneFromHTTPRequest(httpReq))
}
55 changes: 55 additions & 0 deletions modules/frontend/pipeline/async_strip_headers_middleware_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
package pipeline

import (
"bytes"
"io"
"net/http"
"testing"

"github.com/grafana/tempo/modules/frontend/combiner"
"github.com/stretchr/testify/require"
)

func TestStripHeaders(t *testing.T) {
tcs := []struct {
name string
allow []string
headers map[string][]string
expected http.Header
}{
{
name: "empty allow list",
allow: []string{},
headers: map[string][]string{"header1": {"value1"}, "header2": {"value2"}},
expected: map[string][]string{},
},
{
name: "allow list with one header",
allow: []string{"header1"},
headers: map[string][]string{"header1": {"value1"}, "header2": {"value2"}},
expected: map[string][]string{"header1": {"value1"}},
},
}

for _, tc := range tcs {
t.Run(tc.name, func(t *testing.T) {
next := AsyncRoundTripperFunc[combiner.PipelineResponse](func(req Request) (Responses[combiner.PipelineResponse], error) {
actualHeaders := req.HTTPRequest().Header
require.Equal(t, tc.expected, actualHeaders)

return NewHTTPToAsyncResponse(&http.Response{
StatusCode: 200,
Body: io.NopCloser(bytes.NewReader([]byte{})),
}), nil
})

stripHeaders := NewStripHeadersWare(tc.allow).Wrap(next)

req, _ := http.NewRequest(http.MethodGet, "http://localhost:8080", nil)
req.Header = tc.headers

_, err := stripHeaders.RoundTrip(NewHTTPRequest(req))
require.NoError(t, err)
})
}
}
10 changes: 5 additions & 5 deletions modules/frontend/pipeline/pipeline.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,13 @@ var tracer = otel.Tracer("modules/frontend/pipeline")

type Request interface {
HTTPRequest() *http.Request
Context() context.Context

WithContext(context.Context)
CloneFromHTTPRequest(request *http.Request) Request

Weight() int
SetContext(context.Context)
Context() context.Context

SetWeight(int)
Weight() int

SetCacheKey(string)
CacheKey() string
Expand Down Expand Up @@ -51,7 +51,7 @@ func (r HTTPRequest) Context() context.Context {
return r.req.Context()
}

func (r *HTTPRequest) WithContext(ctx context.Context) {
func (r *HTTPRequest) SetContext(ctx context.Context) {
r.req = r.req.WithContext(ctx)
}

Expand Down
2 changes: 1 addition & 1 deletion modules/frontend/pipeline/sync_handler_retry.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ func (r retryWare) RoundTrip(req Request) (*http.Response, error) {
defer span.End()

// context propagation
req.WithContext(ctx)
req.SetContext(ctx)

tries := 0
defer func() { r.retriesCount.Observe(float64(tries)) }()
Expand Down
2 changes: 1 addition & 1 deletion modules/frontend/tag_sharder.go
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ func (s searchTagSharder) RoundTrip(pipelineRequest pipeline.Request) (pipeline.
}
ctx, span := tracer.Start(ctx, "frontend.ShardSearchTags")
defer span.End()
pipelineRequest.WithContext(ctx)
pipelineRequest.SetContext(ctx)

// calculate and enforce max search duration
maxDuration := s.maxDuration(tenantID)
Expand Down
2 changes: 1 addition & 1 deletion modules/frontend/traceid_sharder.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ func newAsyncTraceIDSharder(cfg *TraceByIDConfig, logger log.Logger) pipeline.As
func (s asyncTraceSharder) RoundTrip(pipelineRequest pipeline.Request) (pipeline.Responses[combiner.PipelineResponse], error) {
ctx, span := tracer.Start(pipelineRequest.Context(), "frontend.ShardQuery")
defer span.End()
pipelineRequest.WithContext(ctx)
pipelineRequest.SetContext(ctx)

reqs, err := s.buildShardedRequests(pipelineRequest)
if err != nil {
Expand Down
6 changes: 3 additions & 3 deletions modules/frontend/v1/request_batch_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ func TestRequestBatchContextError(t *testing.T) {

req := httptest.NewRequest("GET", "http://example.com", nil)
prequest := pipeline.NewHTTPRequest(req)
prequest.WithContext(ctx)
prequest.SetContext(ctx)

for i := 0; i < totalRequests-1; i++ {
_ = rb.add(&request{request: prequest})
Expand All @@ -61,7 +61,7 @@ func TestRequestBatchContextError(t *testing.T) {
// add a cancel context
cancelCtx, cancel := context.WithCancel(ctx)
prequest = pipeline.NewHTTPRequest(req)
prequest.WithContext(cancelCtx)
prequest.SetContext(cancelCtx)

_ = rb.add(&request{request: prequest})

Expand All @@ -83,7 +83,7 @@ func TestDoneChanCloses(_ *testing.T) {

req := httptest.NewRequest("GET", "http://example.com", nil)
prequest := pipeline.NewHTTPRequest(req)
prequest.WithContext(cancelCtx)
prequest.SetContext(cancelCtx)

for i := 0; i < totalRequests-1; i++ {
_ = rb.add(&request{request: prequest})
Expand Down

0 comments on commit 534f927

Please sign in to comment.