Skip to content

Commit

Permalink
[receiver/splunkhec] fix memory leak (#36146)
Browse files Browse the repository at this point in the history
<!--Ex. Fixing a bug - Describe the bug and how this fixes the issue.
Ex. Adding a feature - Explain what this achieves.-->
#### Description
Fix memory leak by changing how we run obsreports for metrics and logs.

<!-- Issue number (e.g. #1234) or full URL to issue, if applicable. -->
#### Link to tracking issue
Fixes
#35294
  • Loading branch information
atoulme authored Nov 3, 2024
1 parent 806a4bd commit 1a9fc3b
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 55 deletions.
27 changes: 27 additions & 0 deletions .chloggen/receiver_ops2.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# Use this changelog template to create an entry for release notes.

# One of 'breaking', 'deprecation', 'new_component', 'enhancement', 'bug_fix'
change_type: bug_fix

# The name of the component, or a single word describing the area of concern, (e.g. filelogreceiver)
component: splunkhecreceiver

# A brief description of the change. Surround your text with quotes ("") if it needs to start with a backtick (`).
note: Avoid a memory leak by changing how we record obsreports for logs and metrics.

# Mandatory: One or more tracking issues related to the change. You can use the PR number here if no issue exists.
issues: [35294]

# (Optional) One or more lines of additional information to render under the primary note.
# These lines will be padded with 2 spaces and then inserted directly into the document.
# Use pipe (|) for multiline entries.
subtext:

# If your change doesn't affect end users or the exported elements of any package,
# you should instead start your pull request title with [chore] or use the "Skip Changelog" label.
# Optional: The change log or logs in which this entry should be included.
# e.g. '[user]' or '[user, api]'
# Include 'user' if the change is relevant to end users.
# Include 'api' if there is a change to a library API.
# Default: '[user]'
change_logs: []
90 changes: 35 additions & 55 deletions receiver/splunkhecreceiver/receiver.go
Original file line number Diff line number Diff line change
Expand Up @@ -212,27 +212,26 @@ func (r *splunkReceiver) processSuccessResponse(resp http.ResponseWriter, bodyCo
}

func (r *splunkReceiver) handleAck(resp http.ResponseWriter, req *http.Request) {
ctx := req.Context()
if req.Method != http.MethodPost {
r.failRequest(ctx, resp, http.StatusBadRequest, invalidMethodRespBodyPostOnly, 0, errInvalidMethod)
r.failRequest(resp, http.StatusBadRequest, invalidMethodRespBodyPostOnly, errInvalidMethod)
return
}

// shouldn't run into this case since we only enable this handler IF ackExt exists. But we have this check just in case
if r.ackExt == nil {
r.failRequest(ctx, resp, http.StatusInternalServerError, errInternalServerError, 0, errExtensionMissing)
r.failRequest(resp, http.StatusInternalServerError, errInternalServerError, errExtensionMissing)
return
}

var channelID string
var extracted bool
if channelID, extracted = r.extractChannel(req); extracted {
if channelErr := r.validateChannelHeader(channelID); channelErr != nil {
r.failRequest(ctx, resp, http.StatusBadRequest, []byte(channelErr.Error()), 0, channelErr)
r.failRequest(resp, http.StatusBadRequest, []byte(channelErr.Error()), channelErr)
return
}
} else {
r.failRequest(ctx, resp, http.StatusBadRequest, requiredDataChannelHeader, 0, nil)
r.failRequest(resp, http.StatusBadRequest, requiredDataChannelHeader, nil)
return
}

Expand All @@ -241,19 +240,19 @@ func (r *splunkReceiver) handleAck(resp http.ResponseWriter, req *http.Request)

err := dec.Decode(&ackRequest)
if err != nil {
r.failRequest(ctx, resp, http.StatusBadRequest, invalidFormatRespBody, 0, err)
r.failRequest(resp, http.StatusBadRequest, invalidFormatRespBody, err)
return
}

if len(ackRequest.Acks) == 0 {
r.failRequest(ctx, resp, http.StatusBadRequest, invalidFormatRespBody, 0, errors.New("request body must include at least one ackID to be queried"))
r.failRequest(resp, http.StatusBadRequest, invalidFormatRespBody, errors.New("request body must include at least one ackID to be queried"))
return
}

queriedAcks := r.ackExt.QueryAcks(channelID, ackRequest.Acks)
ackString, _ := json.Marshal(queriedAcks)
if err := r.processSuccessResponse(resp, []byte(fmt.Sprintf(ackResponse, ackString))); err != nil {
r.failRequest(ctx, resp, http.StatusInternalServerError, errInternalServerError, 0, err)
r.failRequest(resp, http.StatusInternalServerError, errInternalServerError, err)
}
}

Expand All @@ -262,28 +261,28 @@ func (r *splunkReceiver) handleRawReq(resp http.ResponseWriter, req *http.Reques
ctx = r.obsrecv.StartLogsOp(ctx)

if req.Method != http.MethodPost {
r.failRequest(ctx, resp, http.StatusBadRequest, invalidMethodRespBodyPostOnly, 0, errInvalidMethod)
r.failRequest(resp, http.StatusBadRequest, invalidMethodRespBodyPostOnly, errInvalidMethod)
return
}

encoding := req.Header.Get(httpContentEncodingHeader)
if encoding != "" && encoding != gzipEncoding {
r.failRequest(ctx, resp, http.StatusUnsupportedMediaType, invalidEncodingRespBody, 0, errInvalidEncoding)
r.failRequest(resp, http.StatusUnsupportedMediaType, invalidEncodingRespBody, errInvalidEncoding)
return
}

var channelID string
var extracted bool
if channelID, extracted = r.extractChannel(req); extracted {
if channelErr := r.validateChannelHeader(channelID); channelErr != nil {
r.failRequest(ctx, resp, http.StatusBadRequest, []byte(channelErr.Error()), 0, channelErr)
r.failRequest(resp, http.StatusBadRequest, []byte(channelErr.Error()), channelErr)
return
}
}

if req.ContentLength == 0 {
r.obsrecv.EndLogsOp(ctx, metadata.Type.String(), 0, nil)
r.failRequest(ctx, resp, http.StatusBadRequest, noDataRespBody, 0, nil)
r.failRequest(resp, http.StatusBadRequest, noDataRespBody, nil)
return
}

Expand All @@ -293,7 +292,7 @@ func (r *splunkReceiver) handleRawReq(resp http.ResponseWriter, req *http.Reques
err := reader.Reset(bodyReader)

if err != nil {
r.failRequest(ctx, resp, http.StatusBadRequest, errGzipReaderRespBody, 0, err)
r.failRequest(resp, http.StatusBadRequest, errGzipReaderRespBody, err)
_, _ = io.ReadAll(req.Body)
_ = req.Body.Close()
return
Expand All @@ -311,23 +310,23 @@ func (r *splunkReceiver) handleRawReq(resp http.ResponseWriter, req *http.Reques
err = errors.New("time cannot be less than 0")
}
if err != nil {
r.failRequest(ctx, resp, http.StatusBadRequest, invalidFormatRespBody, 0, err)
r.failRequest(resp, http.StatusBadRequest, invalidFormatRespBody, err)
return
}
timestamp = pcommon.NewTimestampFromTime(time.Unix(t, 0))
}

ld, slLen, err := splunkHecRawToLogData(bodyReader, query, resourceCustomizer, r.config, timestamp)
if err != nil {
r.failRequest(ctx, resp, http.StatusInternalServerError, errInternalServerError, slLen, err)
r.failRequest(resp, http.StatusInternalServerError, errInternalServerError, err)
return
}
consumerErr := r.logsConsumer.ConsumeLogs(ctx, ld)

_ = bodyReader.Close()

if consumerErr != nil {
r.failRequest(ctx, resp, http.StatusInternalServerError, errInternalServerError, slLen, consumerErr)
r.failRequest(resp, http.StatusInternalServerError, errInternalServerError, consumerErr)
} else {
var ackErr error
if len(channelID) > 0 && r.ackExt != nil {
Expand All @@ -336,7 +335,7 @@ func (r *splunkReceiver) handleRawReq(resp http.ResponseWriter, req *http.Reques
ackErr = r.processSuccessResponse(resp, okRespBody)
}
if ackErr != nil {
r.failRequest(ctx, resp, http.StatusInternalServerError, errInternalServerError, ld.LogRecordCount(), err)
r.failRequest(resp, http.StatusInternalServerError, errInternalServerError, err)
} else {
r.obsrecv.EndLogsOp(ctx, metadata.Type.String(), slLen, nil)
}
Expand Down Expand Up @@ -377,28 +376,22 @@ func (r *splunkReceiver) validateChannelHeader(channelID string) error {

func (r *splunkReceiver) handleReq(resp http.ResponseWriter, req *http.Request) {
ctx := req.Context()
if r.logsConsumer != nil {
ctx = r.obsrecv.StartLogsOp(ctx)
}
if r.metricsConsumer != nil {
ctx = r.obsrecv.StartMetricsOp(ctx)
}

if req.Method != http.MethodPost {
r.failRequest(ctx, resp, http.StatusBadRequest, invalidMethodRespBodyPostOnly, 0, errInvalidMethod)
r.failRequest(resp, http.StatusBadRequest, invalidMethodRespBodyPostOnly, errInvalidMethod)
return
}

encoding := req.Header.Get(httpContentEncodingHeader)
if encoding != "" && encoding != gzipEncoding {
r.failRequest(ctx, resp, http.StatusUnsupportedMediaType, invalidEncodingRespBody, 0, errInvalidEncoding)
r.failRequest(resp, http.StatusUnsupportedMediaType, invalidEncodingRespBody, errInvalidEncoding)
return
}

channelID, extracted := r.extractChannel(req)
if extracted {
if channelErr := r.validateChannelHeader(channelID); channelErr != nil {
r.failRequest(ctx, resp, http.StatusBadRequest, []byte(channelErr.Error()), 0, channelErr)
r.failRequest(resp, http.StatusBadRequest, []byte(channelErr.Error()), channelErr)
return
}
}
Expand All @@ -408,15 +401,15 @@ func (r *splunkReceiver) handleReq(resp http.ResponseWriter, req *http.Request)
reader := r.gzipReaderPool.Get().(*gzip.Reader)
err := reader.Reset(bodyReader)
if err != nil {
r.failRequest(ctx, resp, http.StatusBadRequest, errGzipReaderRespBody, 0, err)
r.failRequest(resp, http.StatusBadRequest, errGzipReaderRespBody, err)
return
}
bodyReader = reader
defer r.gzipReaderPool.Put(reader)
}

if req.ContentLength == 0 {
r.failRequest(ctx, resp, http.StatusBadRequest, noDataRespBody, 0, nil)
r.failRequest(resp, http.StatusBadRequest, noDataRespBody, nil)
return
}

Expand All @@ -429,35 +422,35 @@ func (r *splunkReceiver) handleReq(resp http.ResponseWriter, req *http.Request)
var msg splunk.Event
err := dec.Decode(&msg)
if err != nil {
r.failRequest(ctx, resp, http.StatusBadRequest, invalidFormatRespBody, len(events)+len(metricEvents), err)
r.failRequest(resp, http.StatusBadRequest, invalidFormatRespBody, err)
return
}

if msg.Event == nil {
r.failRequest(ctx, resp, http.StatusBadRequest, eventRequiredRespBody, len(events)+len(metricEvents), nil)
r.failRequest(resp, http.StatusBadRequest, eventRequiredRespBody, nil)
return
}

if msg.Event == "" {
r.failRequest(ctx, resp, http.StatusBadRequest, eventBlankRespBody, len(events)+len(metricEvents), nil)
r.failRequest(resp, http.StatusBadRequest, eventBlankRespBody, nil)
return
}

for _, v := range msg.Fields {
if !isFlatJSONField(v) {
r.failRequest(ctx, resp, http.StatusBadRequest, []byte(fmt.Sprintf(responseErrHandlingIndexedFields, len(events)+len(metricEvents))), len(events)+len(metricEvents), nil)
r.failRequest(resp, http.StatusBadRequest, []byte(fmt.Sprintf(responseErrHandlingIndexedFields, len(events)+len(metricEvents))), nil)
return
}
}
if msg.IsMetric() {
if r.metricsConsumer == nil {
r.failRequest(ctx, resp, http.StatusBadRequest, errUnsupportedMetricEvent, len(metricEvents), err)
r.failRequest(resp, http.StatusBadRequest, errUnsupportedMetricEvent, err)
return
}
metricEvents = append(metricEvents, &msg)
} else {
if r.logsConsumer == nil {
r.failRequest(ctx, resp, http.StatusBadRequest, errUnsupportedLogEvent, len(events), err)
r.failRequest(resp, http.StatusBadRequest, errUnsupportedLogEvent, err)
return
}
events = append(events, &msg)
Expand All @@ -468,21 +461,24 @@ func (r *splunkReceiver) handleReq(resp http.ResponseWriter, req *http.Request)
if r.logsConsumer != nil && len(events) > 0 {
ld, err := splunkHecToLogData(r.settings.Logger, events, resourceCustomizer, r.config)
if err != nil {
r.failRequest(ctx, resp, http.StatusBadRequest, errUnmarshalBodyRespBody, len(events), err)
r.failRequest(resp, http.StatusBadRequest, errUnmarshalBodyRespBody, err)
return
}
ctx = r.obsrecv.StartLogsOp(ctx)
decodeErr := r.logsConsumer.ConsumeLogs(ctx, ld)
r.obsrecv.EndLogsOp(ctx, metadata.Type.String(), len(events), nil)
if decodeErr != nil {
r.failRequest(ctx, resp, http.StatusInternalServerError, errInternalServerError, len(events), decodeErr)
r.failRequest(resp, http.StatusInternalServerError, errInternalServerError, decodeErr)
return
}
}
if r.metricsConsumer != nil && len(metricEvents) > 0 {
md, _ := splunkHecToMetricsData(r.settings.Logger, metricEvents, resourceCustomizer, r.config)

ctx = r.obsrecv.StartMetricsOp(ctx)
decodeErr := r.metricsConsumer.ConsumeMetrics(ctx, md)
r.obsrecv.EndMetricsOp(ctx, metadata.Type.String(), len(metricEvents), nil)
if decodeErr != nil {
r.failRequest(ctx, resp, http.StatusInternalServerError, errInternalServerError, len(metricEvents), decodeErr)
r.failRequest(resp, http.StatusInternalServerError, errInternalServerError, decodeErr)
return
}
}
Expand All @@ -494,14 +490,7 @@ func (r *splunkReceiver) handleReq(resp http.ResponseWriter, req *http.Request)
ackErr = r.processSuccessResponse(resp, okRespBody)
}
if ackErr != nil {
r.failRequest(ctx, resp, http.StatusInternalServerError, errInternalServerError, len(events)+len(metricEvents), ackErr)
} else {
if r.logsConsumer != nil {
r.obsrecv.EndLogsOp(ctx, metadata.Type.String(), len(events), nil)
}
if r.metricsConsumer != nil {
r.obsrecv.EndMetricsOp(ctx, metadata.Type.String(), len(metricEvents), nil)
}
r.failRequest(resp, http.StatusInternalServerError, errInternalServerError, ackErr)
}
}

Expand All @@ -519,11 +508,9 @@ func (r *splunkReceiver) createResourceCustomizer(req *http.Request) func(resour
}

func (r *splunkReceiver) failRequest(
ctx context.Context,
resp http.ResponseWriter,
httpStatusCode int,
jsonResponse []byte,
numRecordsReceived int,
err error,
) {
resp.WriteHeader(httpStatusCode)
Expand All @@ -536,13 +523,6 @@ func (r *splunkReceiver) failRequest(
}
}

if r.logsConsumer != nil {
r.obsrecv.EndLogsOp(ctx, metadata.Type.String(), numRecordsReceived, err)
}
if r.metricsConsumer != nil {
r.obsrecv.EndMetricsOp(ctx, metadata.Type.String(), numRecordsReceived, err)
}

if r.settings.Logger.Core().Enabled(zap.DebugLevel) {
msg := string(jsonResponse)
r.settings.Logger.Debug(
Expand Down

0 comments on commit 1a9fc3b

Please sign in to comment.