From 1a9fc3b6f4bd52bfc084ff24e79bba07e247e6a8 Mon Sep 17 00:00:00 2001 From: Antoine Toulme Date: Sat, 2 Nov 2024 18:24:15 -0700 Subject: [PATCH] [receiver/splunkhec] fix memory leak (#36146) #### Description Fix memory leak by changing how we run obsreports for metrics and logs. #### Link to tracking issue Fixes https://github.com/open-telemetry/opentelemetry-collector-contrib/issues/35294 --- .chloggen/receiver_ops2.yaml | 27 ++++++++ receiver/splunkhecreceiver/receiver.go | 90 ++++++++++---------------- 2 files changed, 62 insertions(+), 55 deletions(-) create mode 100644 .chloggen/receiver_ops2.yaml diff --git a/.chloggen/receiver_ops2.yaml b/.chloggen/receiver_ops2.yaml new file mode 100644 index 000000000000..e228db09d2b6 --- /dev/null +++ b/.chloggen/receiver_ops2.yaml @@ -0,0 +1,27 @@ +# Use this changelog template to create an entry for release notes. + +# One of 'breaking', 'deprecation', 'new_component', 'enhancement', 'bug_fix' +change_type: bug_fix + +# The name of the component, or a single word describing the area of concern, (e.g. filelogreceiver) +component: splunkhecreceiver + +# A brief description of the change. Surround your text with quotes ("") if it needs to start with a backtick (`). +note: Avoid a memory leak by changing how we record obsreports for logs and metrics. + +# Mandatory: One or more tracking issues related to the change. You can use the PR number here if no issue exists. +issues: [35294] + +# (Optional) One or more lines of additional information to render under the primary note. +# These lines will be padded with 2 spaces and then inserted directly into the document. +# Use pipe (|) for multiline entries. +subtext: + +# If your change doesn't affect end users or the exported elements of any package, +# you should instead start your pull request title with [chore] or use the "Skip Changelog" label. +# Optional: The change log or logs in which this entry should be included. +# e.g. '[user]' or '[user, api]' +# Include 'user' if the change is relevant to end users. +# Include 'api' if there is a change to a library API. +# Default: '[user]' +change_logs: [] \ No newline at end of file diff --git a/receiver/splunkhecreceiver/receiver.go b/receiver/splunkhecreceiver/receiver.go index 1d091bf3c55a..26043b1dfc7e 100644 --- a/receiver/splunkhecreceiver/receiver.go +++ b/receiver/splunkhecreceiver/receiver.go @@ -212,15 +212,14 @@ func (r *splunkReceiver) processSuccessResponse(resp http.ResponseWriter, bodyCo } func (r *splunkReceiver) handleAck(resp http.ResponseWriter, req *http.Request) { - ctx := req.Context() if req.Method != http.MethodPost { - r.failRequest(ctx, resp, http.StatusBadRequest, invalidMethodRespBodyPostOnly, 0, errInvalidMethod) + r.failRequest(resp, http.StatusBadRequest, invalidMethodRespBodyPostOnly, errInvalidMethod) return } // shouldn't run into this case since we only enable this handler IF ackExt exists. But we have this check just in case if r.ackExt == nil { - r.failRequest(ctx, resp, http.StatusInternalServerError, errInternalServerError, 0, errExtensionMissing) + r.failRequest(resp, http.StatusInternalServerError, errInternalServerError, errExtensionMissing) return } @@ -228,11 +227,11 @@ func (r *splunkReceiver) handleAck(resp http.ResponseWriter, req *http.Request) var extracted bool if channelID, extracted = r.extractChannel(req); extracted { if channelErr := r.validateChannelHeader(channelID); channelErr != nil { - r.failRequest(ctx, resp, http.StatusBadRequest, []byte(channelErr.Error()), 0, channelErr) + r.failRequest(resp, http.StatusBadRequest, []byte(channelErr.Error()), channelErr) return } } else { - r.failRequest(ctx, resp, http.StatusBadRequest, requiredDataChannelHeader, 0, nil) + r.failRequest(resp, http.StatusBadRequest, requiredDataChannelHeader, nil) return } @@ -241,19 +240,19 @@ func (r *splunkReceiver) handleAck(resp http.ResponseWriter, req *http.Request) err := dec.Decode(&ackRequest) if err != nil { - r.failRequest(ctx, resp, http.StatusBadRequest, invalidFormatRespBody, 0, err) + r.failRequest(resp, http.StatusBadRequest, invalidFormatRespBody, err) return } if len(ackRequest.Acks) == 0 { - r.failRequest(ctx, resp, http.StatusBadRequest, invalidFormatRespBody, 0, errors.New("request body must include at least one ackID to be queried")) + r.failRequest(resp, http.StatusBadRequest, invalidFormatRespBody, errors.New("request body must include at least one ackID to be queried")) return } queriedAcks := r.ackExt.QueryAcks(channelID, ackRequest.Acks) ackString, _ := json.Marshal(queriedAcks) if err := r.processSuccessResponse(resp, []byte(fmt.Sprintf(ackResponse, ackString))); err != nil { - r.failRequest(ctx, resp, http.StatusInternalServerError, errInternalServerError, 0, err) + r.failRequest(resp, http.StatusInternalServerError, errInternalServerError, err) } } @@ -262,13 +261,13 @@ func (r *splunkReceiver) handleRawReq(resp http.ResponseWriter, req *http.Reques ctx = r.obsrecv.StartLogsOp(ctx) if req.Method != http.MethodPost { - r.failRequest(ctx, resp, http.StatusBadRequest, invalidMethodRespBodyPostOnly, 0, errInvalidMethod) + r.failRequest(resp, http.StatusBadRequest, invalidMethodRespBodyPostOnly, errInvalidMethod) return } encoding := req.Header.Get(httpContentEncodingHeader) if encoding != "" && encoding != gzipEncoding { - r.failRequest(ctx, resp, http.StatusUnsupportedMediaType, invalidEncodingRespBody, 0, errInvalidEncoding) + r.failRequest(resp, http.StatusUnsupportedMediaType, invalidEncodingRespBody, errInvalidEncoding) return } @@ -276,14 +275,14 @@ func (r *splunkReceiver) handleRawReq(resp http.ResponseWriter, req *http.Reques var extracted bool if channelID, extracted = r.extractChannel(req); extracted { if channelErr := r.validateChannelHeader(channelID); channelErr != nil { - r.failRequest(ctx, resp, http.StatusBadRequest, []byte(channelErr.Error()), 0, channelErr) + r.failRequest(resp, http.StatusBadRequest, []byte(channelErr.Error()), channelErr) return } } if req.ContentLength == 0 { r.obsrecv.EndLogsOp(ctx, metadata.Type.String(), 0, nil) - r.failRequest(ctx, resp, http.StatusBadRequest, noDataRespBody, 0, nil) + r.failRequest(resp, http.StatusBadRequest, noDataRespBody, nil) return } @@ -293,7 +292,7 @@ func (r *splunkReceiver) handleRawReq(resp http.ResponseWriter, req *http.Reques err := reader.Reset(bodyReader) if err != nil { - r.failRequest(ctx, resp, http.StatusBadRequest, errGzipReaderRespBody, 0, err) + r.failRequest(resp, http.StatusBadRequest, errGzipReaderRespBody, err) _, _ = io.ReadAll(req.Body) _ = req.Body.Close() return @@ -311,7 +310,7 @@ func (r *splunkReceiver) handleRawReq(resp http.ResponseWriter, req *http.Reques err = errors.New("time cannot be less than 0") } if err != nil { - r.failRequest(ctx, resp, http.StatusBadRequest, invalidFormatRespBody, 0, err) + r.failRequest(resp, http.StatusBadRequest, invalidFormatRespBody, err) return } timestamp = pcommon.NewTimestampFromTime(time.Unix(t, 0)) @@ -319,7 +318,7 @@ func (r *splunkReceiver) handleRawReq(resp http.ResponseWriter, req *http.Reques ld, slLen, err := splunkHecRawToLogData(bodyReader, query, resourceCustomizer, r.config, timestamp) if err != nil { - r.failRequest(ctx, resp, http.StatusInternalServerError, errInternalServerError, slLen, err) + r.failRequest(resp, http.StatusInternalServerError, errInternalServerError, err) return } consumerErr := r.logsConsumer.ConsumeLogs(ctx, ld) @@ -327,7 +326,7 @@ func (r *splunkReceiver) handleRawReq(resp http.ResponseWriter, req *http.Reques _ = bodyReader.Close() if consumerErr != nil { - r.failRequest(ctx, resp, http.StatusInternalServerError, errInternalServerError, slLen, consumerErr) + r.failRequest(resp, http.StatusInternalServerError, errInternalServerError, consumerErr) } else { var ackErr error if len(channelID) > 0 && r.ackExt != nil { @@ -336,7 +335,7 @@ func (r *splunkReceiver) handleRawReq(resp http.ResponseWriter, req *http.Reques ackErr = r.processSuccessResponse(resp, okRespBody) } if ackErr != nil { - r.failRequest(ctx, resp, http.StatusInternalServerError, errInternalServerError, ld.LogRecordCount(), err) + r.failRequest(resp, http.StatusInternalServerError, errInternalServerError, err) } else { r.obsrecv.EndLogsOp(ctx, metadata.Type.String(), slLen, nil) } @@ -377,28 +376,22 @@ func (r *splunkReceiver) validateChannelHeader(channelID string) error { func (r *splunkReceiver) handleReq(resp http.ResponseWriter, req *http.Request) { ctx := req.Context() - if r.logsConsumer != nil { - ctx = r.obsrecv.StartLogsOp(ctx) - } - if r.metricsConsumer != nil { - ctx = r.obsrecv.StartMetricsOp(ctx) - } if req.Method != http.MethodPost { - r.failRequest(ctx, resp, http.StatusBadRequest, invalidMethodRespBodyPostOnly, 0, errInvalidMethod) + r.failRequest(resp, http.StatusBadRequest, invalidMethodRespBodyPostOnly, errInvalidMethod) return } encoding := req.Header.Get(httpContentEncodingHeader) if encoding != "" && encoding != gzipEncoding { - r.failRequest(ctx, resp, http.StatusUnsupportedMediaType, invalidEncodingRespBody, 0, errInvalidEncoding) + r.failRequest(resp, http.StatusUnsupportedMediaType, invalidEncodingRespBody, errInvalidEncoding) return } channelID, extracted := r.extractChannel(req) if extracted { if channelErr := r.validateChannelHeader(channelID); channelErr != nil { - r.failRequest(ctx, resp, http.StatusBadRequest, []byte(channelErr.Error()), 0, channelErr) + r.failRequest(resp, http.StatusBadRequest, []byte(channelErr.Error()), channelErr) return } } @@ -408,7 +401,7 @@ func (r *splunkReceiver) handleReq(resp http.ResponseWriter, req *http.Request) reader := r.gzipReaderPool.Get().(*gzip.Reader) err := reader.Reset(bodyReader) if err != nil { - r.failRequest(ctx, resp, http.StatusBadRequest, errGzipReaderRespBody, 0, err) + r.failRequest(resp, http.StatusBadRequest, errGzipReaderRespBody, err) return } bodyReader = reader @@ -416,7 +409,7 @@ func (r *splunkReceiver) handleReq(resp http.ResponseWriter, req *http.Request) } if req.ContentLength == 0 { - r.failRequest(ctx, resp, http.StatusBadRequest, noDataRespBody, 0, nil) + r.failRequest(resp, http.StatusBadRequest, noDataRespBody, nil) return } @@ -429,35 +422,35 @@ func (r *splunkReceiver) handleReq(resp http.ResponseWriter, req *http.Request) var msg splunk.Event err := dec.Decode(&msg) if err != nil { - r.failRequest(ctx, resp, http.StatusBadRequest, invalidFormatRespBody, len(events)+len(metricEvents), err) + r.failRequest(resp, http.StatusBadRequest, invalidFormatRespBody, err) return } if msg.Event == nil { - r.failRequest(ctx, resp, http.StatusBadRequest, eventRequiredRespBody, len(events)+len(metricEvents), nil) + r.failRequest(resp, http.StatusBadRequest, eventRequiredRespBody, nil) return } if msg.Event == "" { - r.failRequest(ctx, resp, http.StatusBadRequest, eventBlankRespBody, len(events)+len(metricEvents), nil) + r.failRequest(resp, http.StatusBadRequest, eventBlankRespBody, nil) return } for _, v := range msg.Fields { if !isFlatJSONField(v) { - r.failRequest(ctx, resp, http.StatusBadRequest, []byte(fmt.Sprintf(responseErrHandlingIndexedFields, len(events)+len(metricEvents))), len(events)+len(metricEvents), nil) + r.failRequest(resp, http.StatusBadRequest, []byte(fmt.Sprintf(responseErrHandlingIndexedFields, len(events)+len(metricEvents))), nil) return } } if msg.IsMetric() { if r.metricsConsumer == nil { - r.failRequest(ctx, resp, http.StatusBadRequest, errUnsupportedMetricEvent, len(metricEvents), err) + r.failRequest(resp, http.StatusBadRequest, errUnsupportedMetricEvent, err) return } metricEvents = append(metricEvents, &msg) } else { if r.logsConsumer == nil { - r.failRequest(ctx, resp, http.StatusBadRequest, errUnsupportedLogEvent, len(events), err) + r.failRequest(resp, http.StatusBadRequest, errUnsupportedLogEvent, err) return } events = append(events, &msg) @@ -468,21 +461,24 @@ func (r *splunkReceiver) handleReq(resp http.ResponseWriter, req *http.Request) if r.logsConsumer != nil && len(events) > 0 { ld, err := splunkHecToLogData(r.settings.Logger, events, resourceCustomizer, r.config) if err != nil { - r.failRequest(ctx, resp, http.StatusBadRequest, errUnmarshalBodyRespBody, len(events), err) + r.failRequest(resp, http.StatusBadRequest, errUnmarshalBodyRespBody, err) return } + ctx = r.obsrecv.StartLogsOp(ctx) decodeErr := r.logsConsumer.ConsumeLogs(ctx, ld) + r.obsrecv.EndLogsOp(ctx, metadata.Type.String(), len(events), nil) if decodeErr != nil { - r.failRequest(ctx, resp, http.StatusInternalServerError, errInternalServerError, len(events), decodeErr) + r.failRequest(resp, http.StatusInternalServerError, errInternalServerError, decodeErr) return } } if r.metricsConsumer != nil && len(metricEvents) > 0 { md, _ := splunkHecToMetricsData(r.settings.Logger, metricEvents, resourceCustomizer, r.config) - + ctx = r.obsrecv.StartMetricsOp(ctx) decodeErr := r.metricsConsumer.ConsumeMetrics(ctx, md) + r.obsrecv.EndMetricsOp(ctx, metadata.Type.String(), len(metricEvents), nil) if decodeErr != nil { - r.failRequest(ctx, resp, http.StatusInternalServerError, errInternalServerError, len(metricEvents), decodeErr) + r.failRequest(resp, http.StatusInternalServerError, errInternalServerError, decodeErr) return } } @@ -494,14 +490,7 @@ func (r *splunkReceiver) handleReq(resp http.ResponseWriter, req *http.Request) ackErr = r.processSuccessResponse(resp, okRespBody) } if ackErr != nil { - r.failRequest(ctx, resp, http.StatusInternalServerError, errInternalServerError, len(events)+len(metricEvents), ackErr) - } else { - if r.logsConsumer != nil { - r.obsrecv.EndLogsOp(ctx, metadata.Type.String(), len(events), nil) - } - if r.metricsConsumer != nil { - r.obsrecv.EndMetricsOp(ctx, metadata.Type.String(), len(metricEvents), nil) - } + r.failRequest(resp, http.StatusInternalServerError, errInternalServerError, ackErr) } } @@ -519,11 +508,9 @@ func (r *splunkReceiver) createResourceCustomizer(req *http.Request) func(resour } func (r *splunkReceiver) failRequest( - ctx context.Context, resp http.ResponseWriter, httpStatusCode int, jsonResponse []byte, - numRecordsReceived int, err error, ) { resp.WriteHeader(httpStatusCode) @@ -536,13 +523,6 @@ func (r *splunkReceiver) failRequest( } } - if r.logsConsumer != nil { - r.obsrecv.EndLogsOp(ctx, metadata.Type.String(), numRecordsReceived, err) - } - if r.metricsConsumer != nil { - r.obsrecv.EndMetricsOp(ctx, metadata.Type.String(), numRecordsReceived, err) - } - if r.settings.Logger.Core().Enabled(zap.DebugLevel) { msg := string(jsonResponse) r.settings.Logger.Debug(