Skip to content

Commit

Permalink
Query Frontend Refactor: Fix goroutine leaks (#3495)
Browse files Browse the repository at this point in the history
* drain channels

Signed-off-by: Joe Elliott <[email protected]>

* remove duplicate test

Signed-off-by: Joe Elliott <[email protected]>

* lint

Signed-off-by: Joe Elliott <[email protected]>

* fix

Signed-off-by: Joe Elliott <[email protected]>

* remove handlers test

Signed-off-by: Joe Elliott <[email protected]>

* simplify + tests pass

Signed-off-by: Joe Elliott <[email protected]>

* lint

Signed-off-by: Joe Elliott <[email protected]>

* drain async for proper cancellation

Signed-off-by: Joe Elliott <[email protected]>

* Add context to send

Signed-off-by: Joe Elliott <[email protected]>

---------

Signed-off-by: Joe Elliott <[email protected]>
  • Loading branch information
joe-elliott authored Mar 20, 2024
1 parent bc2e006 commit 85cb5de
Show file tree
Hide file tree
Showing 8 changed files with 249 additions and 23 deletions.
2 changes: 1 addition & 1 deletion modules/frontend/pipeline/async_handler_multitenant.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ func (t *tenantRoundTripper) RoundTrip(req *http.Request) (Responses[*http.Respo
// join tenants for logger because list value type is unsupported.
_ = level.Debug(t.logger).Log("msg", "handling multi-tenant query", "tenants", strings.Join(tenants, ","))

return NewAsyncSharderFunc(0, len(tenants), func(tenantIdx int) *http.Request {
return NewAsyncSharderFunc(req.Context(), 0, len(tenants), func(tenantIdx int) *http.Request {
if tenantIdx >= len(tenants) {
return nil
}
Expand Down
16 changes: 8 additions & 8 deletions modules/frontend/pipeline/async_sharding.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package pipeline

import (
"context"
"net/http"
"sync"

Expand All @@ -15,7 +16,7 @@ type waitGroup interface {

// NewAsyncSharderFunc creates a new AsyncResponse that shards requests to the next AsyncRoundTripper[*http.Response]. It creates one
// goroutine per concurrent request.
func NewAsyncSharderFunc(concurrentReqs, totalReqs int, reqFn func(i int) *http.Request, next AsyncRoundTripper[*http.Response]) Responses[*http.Response] {
func NewAsyncSharderFunc(ctx context.Context, concurrentReqs, totalReqs int, reqFn func(i int) *http.Request, next AsyncRoundTripper[*http.Response]) Responses[*http.Response] {
var wg waitGroup
if concurrentReqs <= 0 {
wg = &sync.WaitGroup{}
Expand All @@ -26,7 +27,7 @@ func NewAsyncSharderFunc(concurrentReqs, totalReqs int, reqFn func(i int) *http.
asyncResp := newAsyncResponse()

go func() {
defer asyncResp.done()
defer asyncResp.SendComplete()

for i := 0; i < totalReqs; i++ {
req := reqFn(i)
Expand All @@ -50,7 +51,7 @@ func NewAsyncSharderFunc(concurrentReqs, totalReqs int, reqFn func(i int) *http.
return
}

asyncResp.Send(resp)
asyncResp.Send(ctx, resp)
}(req)
}

Expand All @@ -61,7 +62,7 @@ func NewAsyncSharderFunc(concurrentReqs, totalReqs int, reqFn func(i int) *http.
}

// NewAsyncSharderChan creates a new AsyncResponse that shards requests to the next AsyncRoundTripper[*http.Response] using a limited number of goroutines.
func NewAsyncSharderChan(concurrentReqs int, reqs <-chan *http.Request, resps Responses[*http.Response], next AsyncRoundTripper[*http.Response]) Responses[*http.Response] {
func NewAsyncSharderChan(ctx context.Context, concurrentReqs int, reqs <-chan *http.Request, resps Responses[*http.Response], next AsyncRoundTripper[*http.Response]) Responses[*http.Response] {
if concurrentReqs == 0 {
panic("NewAsyncSharderChan: concurrentReqs must be greater than 0")
}
Expand All @@ -85,20 +86,19 @@ func NewAsyncSharderChan(concurrentReqs int, reqs <-chan *http.Request, resps Re
continue
}

asyncResp.Send(resp)
asyncResp.Send(ctx, resp)
}
}()
}

go func() {
// send any responses back the caller would like to send
if resps != nil {
asyncResp.Send(resps)
asyncResp.Send(ctx, resps)
}

// and wait for all the workers to finish
wg.Wait()
asyncResp.done()
asyncResp.SendComplete()
}()

return asyncResp
Expand Down
6 changes: 3 additions & 3 deletions modules/frontend/pipeline/async_sharding_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ func TestAsyncSharders(t *testing.T) {
{
name: "AsyncSharder",
responseFn: func(next AsyncRoundTripper[*http.Response]) *asyncResponse {
return NewAsyncSharderFunc(10, expectedRequestCount, func(i int) *http.Request {
return NewAsyncSharderFunc(context.Background(), 10, expectedRequestCount, func(i int) *http.Request {
if i >= expectedRequestCount {
return nil
}
Expand All @@ -32,7 +32,7 @@ func TestAsyncSharders(t *testing.T) {
{
name: "AsyncSharder - no limit",
responseFn: func(next AsyncRoundTripper[*http.Response]) *asyncResponse {
return NewAsyncSharderFunc(0, expectedRequestCount, func(i int) *http.Request {
return NewAsyncSharderFunc(context.Background(), 0, expectedRequestCount, func(i int) *http.Request {
if i >= expectedRequestCount {
return nil
}
Expand All @@ -51,7 +51,7 @@ func TestAsyncSharders(t *testing.T) {
close(reqChan)
}()

return NewAsyncSharderChan(10, reqChan, nil, next).(*asyncResponse)
return NewAsyncSharderChan(context.Background(), 10, reqChan, nil, next).(*asyncResponse)
},
},
}
Expand Down
1 change: 1 addition & 0 deletions modules/frontend/pipeline/collector_grpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ func (c GRPCCollector[T]) RoundTrip(req *http.Request) error {
if err != nil {
return grpcError(err)
}

return nil
}

Expand Down
1 change: 1 addition & 0 deletions modules/frontend/pipeline/collector_http.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,5 +58,6 @@ func (r httpCollector) RoundTrip(req *http.Request) (*http.Response, error) {
}

resp, err := r.combiner.HTTPFinal()

return resp, err
}
11 changes: 8 additions & 3 deletions modules/frontend/pipeline/responses.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
)

type Responses[T any] interface {
// Next returns the next response or an error if one is available. It always prefers an error over a response.
Next(context.Context) (T, bool, error) // bool = done
}

Expand Down Expand Up @@ -72,8 +73,11 @@ func newAsyncResponse() *asyncResponse {
}
}

func (a *asyncResponse) Send(r Responses[*http.Response]) {
a.respChan <- r
func (a *asyncResponse) Send(ctx context.Context, r Responses[*http.Response]) {
select {
case <-ctx.Done():
case a.respChan <- r:
}
}

// SendError sends an error to the asyncResponse. This will cause the asyncResponse to return the error on the next call to Next.
Expand All @@ -87,7 +91,8 @@ func (a *asyncResponse) SendError(err error) {
}
}

func (a *asyncResponse) done() {
// SendComplete indicates the sender is done. We close the channel to give a clear signal to the consumer
func (a *asyncResponse) SendComplete() {
close(a.respChan)
}

Expand Down
Loading

0 comments on commit 85cb5de

Please sign in to comment.