Skip to content

Commit

Permalink
Refactor query frontend to return prometheus error response (#5811)
Browse files Browse the repository at this point in the history
* refactor query frontend to return prometheus error response

Signed-off-by: Ben Ye <[email protected]>

* stash

Signed-off-by: Ben Ye <[email protected]>

* add test

Signed-off-by: Ben Ye <[email protected]>

* fix lint

Signed-off-by: Ben Ye <[email protected]>

---------

Signed-off-by: Ben Ye <[email protected]>
  • Loading branch information
yeya24 authored Apr 22, 2024
1 parent e39eace commit 6bca2d5
Show file tree
Hide file tree
Showing 16 changed files with 505 additions and 130 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
* [ENHANCEMENT] Ruler: Improve GetRules response time by refactoring mutexes and introducing a temporary rules cache in `ruler/manager.go`. #5805
* [ENHANCEMENT] Querier: Add context error check when merging slices from ingesters for GetLabel operations. #5837
* [ENHANCEMENT] Ring: Add experimental `-ingester.tokens-generator-strategy=minimize-spread` flag to enable the new minimize spread token generator strategy. #5855
* [ENHANCEMENT] Query Frontend: Ensure error response returned by Query Frontend follows Prometheus API error response format. #5811
* [BUGFIX] Distributor: Do not use label with empty values for sharding #5717
* [BUGFIX] Query Frontend: queries with negative offset should check whether it is cacheable or not. #5719
* [BUGFIX] Redis Cache: pass `cache_size` config correctly. #5734
Expand Down
29 changes: 11 additions & 18 deletions pkg/frontend/transport/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,13 @@ import (
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
"github.com/weaveworks/common/httpgrpc"
"github.com/weaveworks/common/httpgrpc/server"
"google.golang.org/grpc/status"

querier_stats "github.com/cortexproject/cortex/pkg/querier/stats"
"github.com/cortexproject/cortex/pkg/querier/tripperware"
"github.com/cortexproject/cortex/pkg/tenant"
"github.com/cortexproject/cortex/pkg/util"
util_api "github.com/cortexproject/cortex/pkg/util/api"
util_log "github.com/cortexproject/cortex/pkg/util/log"
)

Expand Down Expand Up @@ -239,8 +239,9 @@ func (f *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
writeServiceTimingHeader(queryResponseTime, hs, stats)
}

logger := util_log.WithContext(r.Context(), f.log)
if err != nil {
writeError(w, err, hs)
writeError(logger, w, err, hs)
return
}

Expand All @@ -252,7 +253,7 @@ func (f *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// Log any error from copying the response body so that we know about it even though a success status code has already been returned.
bytesCopied, err := io.Copy(w, resp.Body)
if err != nil && !errors.Is(err, syscall.EPIPE) {
level.Error(util_log.WithContext(r.Context(), f.log)).Log("msg", "write response body error", "bytesCopied", bytesCopied, "err", err)
level.Error(logger).Log("msg", "write response body error", "bytesCopied", bytesCopied, "err", err)
}
}

Expand Down Expand Up @@ -441,7 +442,7 @@ func formatQueryString(queryString url.Values) (fields []interface{}) {
return fields
}

func writeError(w http.ResponseWriter, err error, additionalHeaders http.Header) {
func writeError(logger log.Logger, w http.ResponseWriter, err error, additionalHeaders http.Header) {
switch err {
case context.Canceled:
err = errCanceled
Expand All @@ -453,21 +454,13 @@ func writeError(w http.ResponseWriter, err error, additionalHeaders http.Header)
}
}

resp, ok := httpgrpc.HTTPResponseFromError(err)
if ok {
for k, values := range additionalHeaders {
resp.Headers = append(resp.Headers, &httpgrpc.Header{Key: k, Values: values})
}
_ = server.WriteResponse(w, resp)
} else {
headers := w.Header()
for k, values := range additionalHeaders {
for _, value := range values {
headers.Set(k, value)
}
headers := w.Header()
for k, values := range additionalHeaders {
for _, value := range values {
headers.Set(k, value)
}
http.Error(w, err.Error(), http.StatusInternalServerError)
}
util_api.RespondFromGRPCError(logger, w, err)
}

func writeServiceTimingHeader(queryResponseTime time.Duration, headers http.Header, stats *querier_stats.QueryStats) {
Expand All @@ -488,7 +481,7 @@ func statsValue(name string, d time.Duration) string {
func getStatusCodeFromError(err error) int {
switch err {
case context.Canceled:
return StatusClientClosedRequest
return util_api.StatusClientClosedRequest
case context.DeadlineExceeded:
return http.StatusGatewayTimeout
default:
Expand Down
127 changes: 118 additions & 9 deletions pkg/frontend/transport/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package transport
import (
"bytes"
"context"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
Expand All @@ -13,14 +14,18 @@ import (

"github.com/go-kit/log"
"github.com/pkg/errors"
v1 "github.com/prometheus/client_golang/api/prometheus/v1"
"github.com/prometheus/client_golang/prometheus"
promtest "github.com/prometheus/client_golang/prometheus/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/weaveworks/common/httpgrpc"
"github.com/weaveworks/common/user"
"google.golang.org/grpc/codes"

querier_stats "github.com/cortexproject/cortex/pkg/querier/stats"
util_api "github.com/cortexproject/cortex/pkg/util/api"
util_log "github.com/cortexproject/cortex/pkg/util/log"
)

type roundTripperFunc func(*http.Request) (*http.Response, error)
Expand All @@ -34,19 +39,111 @@ func TestWriteError(t *testing.T) {
status int
err error
additionalHeaders http.Header
expectedErrResp util_api.Response
}{
{http.StatusInternalServerError, errors.New("unknown"), http.Header{"User-Agent": []string{"Golang"}}},
{http.StatusInternalServerError, errors.New("unknown"), nil},
{http.StatusGatewayTimeout, context.DeadlineExceeded, nil},
{StatusClientClosedRequest, context.Canceled, nil},
{StatusClientClosedRequest, context.Canceled, http.Header{"User-Agent": []string{"Golang"}}},
{StatusClientClosedRequest, context.Canceled, http.Header{"User-Agent": []string{"Golang"}, "Content-Type": []string{"application/json"}}},
{http.StatusBadRequest, httpgrpc.Errorf(http.StatusBadRequest, ""), http.Header{}},
{http.StatusRequestEntityTooLarge, errors.New("http: request body too large"), http.Header{}},
{
http.StatusInternalServerError,
errors.New("unknown"),
http.Header{"User-Agent": []string{"Golang"}},
util_api.Response{
Status: "error",
ErrorType: v1.ErrServer,
Error: "unknown",
},
},
{
http.StatusInternalServerError,
errors.New("unknown"),
nil,
util_api.Response{
Status: "error",
ErrorType: v1.ErrServer,
Error: "unknown",
},
},
{
http.StatusGatewayTimeout,
context.DeadlineExceeded,
nil,
util_api.Response{
Status: "error",
ErrorType: v1.ErrTimeout,
Error: "",
},
},
{
StatusClientClosedRequest,
context.Canceled,
nil,
util_api.Response{
Status: "error",
ErrorType: v1.ErrCanceled,
Error: "",
},
},
{
StatusClientClosedRequest,
context.Canceled,
http.Header{"User-Agent": []string{"Golang"}},
util_api.Response{
Status: "error",
ErrorType: v1.ErrCanceled,
Error: "",
},
},
{
StatusClientClosedRequest,
context.Canceled,
http.Header{"User-Agent": []string{"Golang"}, "Content-Type": []string{"application/json"}},
util_api.Response{
Status: "error",
ErrorType: v1.ErrCanceled,
Error: "",
},
},
{http.StatusBadRequest,
httpgrpc.Errorf(http.StatusBadRequest, ""),
http.Header{},
util_api.Response{
Status: "error",
ErrorType: v1.ErrBadData,
Error: "",
},
},
{
http.StatusRequestEntityTooLarge,
errors.New("http: request body too large"),
http.Header{},
util_api.Response{
Status: "error",
ErrorType: v1.ErrBadData,
Error: "http: request body too large",
},
},
{
http.StatusUnprocessableEntity,
httpgrpc.Errorf(http.StatusUnprocessableEntity, "limit hit"),
http.Header{},
util_api.Response{
Status: "error",
ErrorType: v1.ErrExec,
Error: "limit hit",
},
},
{
http.StatusUnprocessableEntity,
httpgrpc.Errorf(int(codes.PermissionDenied), "permission denied"),
http.Header{},
util_api.Response{
Status: "error",
ErrorType: v1.ErrBadData,
Error: "permission denied",
},
},
} {
t.Run(test.err.Error(), func(t *testing.T) {
w := httptest.NewRecorder()
writeError(w, test.err, test.additionalHeaders)
writeError(util_log.Logger, w, test.err, test.additionalHeaders)
require.Equal(t, test.status, w.Result().StatusCode)
expectedAdditionalHeaders := test.additionalHeaders
if expectedAdditionalHeaders != nil {
Expand All @@ -56,6 +153,18 @@ func TestWriteError(t *testing.T) {
}
}
}
data, err := io.ReadAll(w.Result().Body)
require.NoError(t, err)
var res util_api.Response
err = json.Unmarshal(data, &res)
require.NoError(t, err)
resp, ok := httpgrpc.HTTPResponseFromError(test.err)
if ok {
require.Equal(t, string(resp.Body), res.Error)
} else {
require.Equal(t, test.err.Error(), res.Error)

}
})
}
}
Expand Down
3 changes: 1 addition & 2 deletions pkg/querier/tripperware/instantquery/limits.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,7 @@ func (l limitsMiddleware) Do(ctx context.Context, r tripperware.Request) (trippe
if maxQueryLength := validation.SmallestPositiveNonZeroDurationPerTenant(tenantIDs, l.MaxQueryLength); maxQueryLength > 0 {
expr, err := parser.ParseExpr(r.GetQuery())
if err != nil {
// Let Querier propagates the parsing error.
return l.next.Do(ctx, r)
return nil, httpgrpc.Errorf(http.StatusBadRequest, err.Error())
}

// Enforce query length across all selectors in the query.
Expand Down
7 changes: 7 additions & 0 deletions pkg/querier/tripperware/instantquery/limits_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,15 @@ package instantquery

import (
"context"
"net/http"
"testing"
"time"

"github.com/prometheus/prometheus/promql/parser"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"github.com/weaveworks/common/httpgrpc"
"github.com/weaveworks/common/user"

"github.com/cortexproject/cortex/pkg/querier/tripperware"
Expand All @@ -20,6 +23,9 @@ func TestLimitsMiddleware_MaxQueryLength(t *testing.T) {
thirtyDays = 30 * 24 * time.Hour
)

wrongQuery := `up[`
_, parserErr := parser.ParseExpr(wrongQuery)

tests := map[string]struct {
maxQueryLength time.Duration
query string
Expand All @@ -31,6 +37,7 @@ func TestLimitsMiddleware_MaxQueryLength(t *testing.T) {
"even though failed to parse expression, should return no error since request will pass to next middleware": {
query: `up[`,
maxQueryLength: thirtyDays,
expectedErr: httpgrpc.Errorf(http.StatusBadRequest, parserErr.Error()).Error(),
},
"should succeed on a query not exceeding time range": {
query: `up`,
Expand Down
3 changes: 1 addition & 2 deletions pkg/querier/tripperware/queryrange/limits.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,7 @@ func (l limitsMiddleware) Do(ctx context.Context, r tripperware.Request) (trippe

expr, err := parser.ParseExpr(r.GetQuery())
if err != nil {
// Let Querier propagates the parsing error.
return l.next.Do(ctx, r)
return nil, httpgrpc.Errorf(http.StatusBadRequest, err.Error())
}

// Enforce query length across all selectors in the query.
Expand Down
7 changes: 7 additions & 0 deletions pkg/querier/tripperware/queryrange/limits_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,15 @@ package queryrange

import (
"context"
"net/http"
"testing"
"time"

"github.com/prometheus/prometheus/promql/parser"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"github.com/weaveworks/common/httpgrpc"
"github.com/weaveworks/common/user"

"github.com/cortexproject/cortex/pkg/querier/tripperware"
Expand Down Expand Up @@ -115,6 +118,9 @@ func TestLimitsMiddleware_MaxQueryLength(t *testing.T) {

now := time.Now()

wrongQuery := `up[`
_, parserErr := parser.ParseExpr(wrongQuery)

tests := map[string]struct {
maxQueryLength time.Duration
query string
Expand All @@ -132,6 +138,7 @@ func TestLimitsMiddleware_MaxQueryLength(t *testing.T) {
reqStartTime: now.Add(-time.Hour),
reqEndTime: now,
maxQueryLength: thirtyDays,
expectedErr: httpgrpc.Errorf(http.StatusBadRequest, parserErr.Error()).Error(),
},
"should succeed on a query on short time range, ending now": {
maxQueryLength: thirtyDays,
Expand Down
7 changes: 1 addition & 6 deletions pkg/querier/tripperware/queryrange/split_by_interval.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,7 @@ func (s splitByInterval) Do(ctx context.Context, r tripperware.Request) (tripper
// to line up the boundaries with step.
reqs, err := splitQuery(r, s.interval(r))
if err != nil {
// If the query itself is bad, we don't return error but send the query
// to querier to return the expected error message. This is not very efficient
// but should be okay for now.
// TODO(yeya24): query frontend can reuse the Prometheus API handler and return
// expected error message locally without passing it to the querier through network.
return s.next.Do(ctx, r)
return nil, httpgrpc.Errorf(http.StatusBadRequest, err.Error())
}
s.splitByCounter.Add(float64(len(reqs)))

Expand Down
5 changes: 2 additions & 3 deletions pkg/querier/tripperware/roundtrip.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ func NewQueryTripperware(
tenantIDs, err := tenant.TenantIDs(r.Context())
// This should never happen anyway because the auth middleware runs before this.
if err != nil {
return nil, err
return nil, httpgrpc.Errorf(http.StatusBadRequest, err.Error())
}
now := time.Now()
userStr := tenant.JoinTenantIDs(tenantIDs)
Expand All @@ -161,8 +161,7 @@ func NewQueryTripperware(

expr, err := parser.ParseExpr(query)
if err != nil {
// If query is invalid, no need to go through tripperwares for further splitting.
return next.RoundTrip(r)
return nil, httpgrpc.Errorf(http.StatusBadRequest, err.Error())
}

reqStats := stats.FromContext(r.Context())
Expand Down
3 changes: 2 additions & 1 deletion pkg/querier/tripperware/shard_by.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ func (s shardBy) Do(ctx context.Context, r Request) (Response, error) {
analysis, err := s.analyzer.Analyze(r.GetQuery())
if err != nil {
level.Warn(logger).Log("msg", "error analyzing query", "q", r.GetQuery(), "err", err)
return nil, httpgrpc.Errorf(http.StatusBadRequest, err.Error())
}

stats.AddExtraFields(
Expand All @@ -63,7 +64,7 @@ func (s shardBy) Do(ctx context.Context, r Request) (Response, error) {
"shard_by.sharding_labels", analysis.ShardingLabels(),
)

if err != nil || !analysis.IsShardable() {
if !analysis.IsShardable() {
return s.next.Do(ctx, r)
}

Expand Down
Loading

0 comments on commit 6bca2d5

Please sign in to comment.