diff --git a/pkg/kubehound/storage/graphdb/janusgraph_edge_writer.go b/pkg/kubehound/storage/graphdb/janusgraph_edge_writer.go
index 994d09b1..9d5eddcd 100644
--- a/pkg/kubehound/storage/graphdb/janusgraph_edge_writer.go
+++ b/pkg/kubehound/storage/graphdb/janusgraph_edge_writer.go
@@ -30,16 +30,13 @@ type JanusGraphEdgeWriter struct {
     gremlin types.EdgeTraversal // Gremlin traversal generator function
     drc *gremlingo.DriverRemoteConnection // Gremlin driver remote connection
     traversalSource *gremlingo.GraphTraversalSource // Transacted graph traversal source
-    inserts []any // Object data to be inserted in the graph
-    mu sync.Mutex // Mutex protecting access to the inserts array
-    consumerChan chan batchItem // Channel consuming inserts for async writing
     writingInFlight *sync.WaitGroup // Wait group tracking current unfinished writes
-    batchSize int // Batchsize of graph DB inserts
     qcounter int32 // Track items queued
     wcounter int32 // Track items writtn
     tags []string // Telemetry tags
     writerTimeout time.Duration // Timeout for the writer
     maxRetry int // Maximum number of retries for failed writes
+    mb *microBatcher // Micro batcher to batch writes
 }

 // NewJanusGraphAsyncEdgeWriter creates a new bulk edge writer instance.
@@ -47,8 +44,9 @@ func NewJanusGraphAsyncEdgeWriter(ctx context.Context, drc *gremlingo.DriverRemo
     e edge.Builder, opts ...WriterOption,
 ) (*JanusGraphEdgeWriter, error) {
     options := &writerOptions{
-        WriterTimeout: defaultWriterTimeout,
-        MaxRetry:      defaultMaxRetry,
+        WriterTimeout:     defaultWriterTimeout,
+        MaxRetry:          defaultMaxRetry,
+        WriterWorkerCount: defaultWriterWorkerCount,
     }
     for _, opt := range opts {
         opt(options)
     }
@@ -59,101 +57,91 @@ func NewJanusGraphAsyncEdgeWriter(ctx context.Context, drc *gremlingo.DriverRemo
         builder: builder,
         gremlin: e.Traversal(),
         drc: drc,
-        inserts: make([]any, 0, e.BatchSize()),
         traversalSource: gremlingo.Traversal_().WithRemote(drc),
-        batchSize: e.BatchSize(),
         writingInFlight: &sync.WaitGroup{},
-        consumerChan: make(chan batchItem, e.BatchSize()*channelSizeBatchFactor),
         tags: append(options.Tags, tag.Label(e.Label()), tag.Builder(builder)),
         writerTimeout: options.WriterTimeout,
         maxRetry: options.MaxRetry,
     }

-    jw.startBackgroundWriter(ctx)
-
-    return &jw, nil
-}
-
-// startBackgroundWriter starts a background go routine
-func (jgv *JanusGraphEdgeWriter) startBackgroundWriter(ctx context.Context) {
-    go func() {
-        for {
-            select {
-            case batch, ok := <-jgv.consumerChan:
-                // If the channel is closed, return.
-                if !ok {
-                    log.Trace(ctx).Info("Closed background janusgraph worker on channel close")
-
-                    return
-                }
-
-                // If the batch is empty, return.
-                if len(batch.data) == 0 {
-                    log.Trace(ctx).Warn("Empty edge batch received in background janusgraph worker, skipping")
-
-                    return
-                }
-
-                _ = statsd.Count(ctx, metric.BackgroundWriterCall, 1, jgv.tags, 1)
-                err := jgv.batchWrite(ctx, batch.data)
-                if err != nil {
-                    var e *batchWriterError
-                    if errors.As(err, &e) {
-                        // If the error is retryable, retry the write operation with a smaller batch.
-                        if e.retryable && batch.retryCount < jgv.maxRetry {
-                            jgv.retrySplitAndRequeue(ctx, &batch, e)
-
-                            continue
-                        }
-
-                        log.Trace(ctx).Errorf("Retry limit exceeded for write operation: %v", err)
-                    }
-
-                    log.Trace(ctx).Errorf("write data in background batch writer: %v", err)
-                }
-
-                _ = statsd.Decr(ctx, metric.QueueSize, jgv.tags, 1)
-            case <-ctx.Done():
-                log.Trace(ctx).Info("Closed background janusgraph worker on context cancel")
-
-                return
+    // Create a new micro batcher to batch the inserts with split and retry logic.
+    jw.mb = newMicroBatcher(log.Trace(ctx), e.BatchSize(), options.WriterWorkerCount, func(ctx context.Context, a []any) error {
+        // Try to write the batch to the graph DB.
+        if err := jw.batchWrite(ctx, a); err != nil {
+            var bwe *batchWriterError
+            if errors.As(err, &bwe) && bwe.retryable {
+                // If the write operation failed and is retryable, split the batch and retry.
+                return jw.splitAndRetry(ctx, 0, a)
             }
+
+            return err
         }
-    }()
+
+        return nil
+    })
+    jw.mb.Start(ctx)
+
+    return &jw, nil
 }

 // retrySplitAndRequeue will split the batch into smaller chunks and requeue them for writing.
-func (jgv *JanusGraphEdgeWriter) retrySplitAndRequeue(ctx context.Context, batch *batchItem, e *batchWriterError) {
+func (jgv *JanusGraphEdgeWriter) splitAndRetry(ctx context.Context, retryCount int, payload []any) error {
     _ = statsd.Count(ctx, metric.RetryWriterCall, 1, jgv.tags, 1)
+    // If we have reached the maximum number of retries, return an error.
+    if retryCount >= jgv.maxRetry {
+        return fmt.Errorf("max retry count reached: %d", retryCount)
+    }
+
     // Compute the new batch size.
-    newBatchSize := len(batch.data) / 2
-    batch.retryCount++
+    newBatchSize := len(payload) / 2
+
+    log.Trace(ctx).Warnf("Retrying write operation with smaller edge batch (n:%d -> %d, r:%d)", len(payload), newBatchSize, retryCount)

-    log.Trace(ctx).Warnf("Retrying write operation with smaller edge batch (n:%d -> %d, r:%d): %v", len(batch.data), newBatchSize, batch.retryCount, e.Unwrap())
+    var leftErr, rightErr error

-    // Split the batch into smaller chunks and requeue them.
-    if len(batch.data[:newBatchSize]) > 0 {
-        jgv.consumerChan <- batchItem{
-            data: batch.data[:newBatchSize],
-            retryCount: batch.retryCount,
+    // Split the batch into smaller chunks and retry them.
+    if len(payload[:newBatchSize]) > 0 {
+        if leftErr = jgv.batchWrite(ctx, payload[:newBatchSize]); leftErr == nil {
+            var bwe *batchWriterError
+            if errors.As(leftErr, &bwe) && bwe.retryable {
+                return jgv.splitAndRetry(ctx, retryCount+1, payload[:newBatchSize])
+            }
         }
     }
-    if len(batch.data[newBatchSize:]) > 0 {
-        jgv.consumerChan <- batchItem{
-            data: batch.data[newBatchSize:],
-            retryCount: batch.retryCount,
+
+    // Process the right side of the batch.
+    if len(payload[newBatchSize:]) > 0 {
+        if rightErr = jgv.batchWrite(ctx, payload[newBatchSize:]); rightErr != nil {
+            var bwe *batchWriterError
+            if errors.As(rightErr, &bwe) && bwe.retryable {
+                return jgv.splitAndRetry(ctx, retryCount+1, payload[newBatchSize:])
+            }
         }
     }
+
+    // Return the first error encountered.
+    switch {
+    case leftErr != nil && rightErr != nil:
+        return fmt.Errorf("left: %w, right: %w", leftErr, rightErr)
+    case leftErr != nil:
+        return leftErr
+    case rightErr != nil:
+        return rightErr
+    }
+
+    return nil
 }

 // batchWrite will write a batch of entries into the graph DB and block until the write completes.
-// Callers are responsible for doing an Add(1) to the writingInFlight wait group to ensure proper synchronization.
 func (jgv *JanusGraphEdgeWriter) batchWrite(ctx context.Context, data []any) error {
     span, ctx := span.SpanRunFromContext(ctx, span.JanusGraphBatchWrite)
     span.SetTag(tag.LabelTag, jgv.builder)
     var err error
     defer func() { span.Finish(tracer.WithError(err)) }()
+
+    // Increment the writingInFlight wait group to track the number of writes in progress.
+    jgv.writingInFlight.Add(1)
     defer jgv.writingInFlight.Done()

     datalen := len(data)
@@ -185,8 +173,6 @@ func (jgv *JanusGraphEdgeWriter) batchWrite(ctx context.Context, data []any) err
 }

 func (jgv *JanusGraphEdgeWriter) Close(ctx context.Context) error {
-    close(jgv.consumerChan)
-
     return nil
 }

@@ -198,29 +184,17 @@ func (jgv *JanusGraphEdgeWriter) Flush(ctx context.Context) error {
     var err error
     defer func() { span.Finish(tracer.WithError(err)) }()

-    jgv.mu.Lock()
-    defer jgv.mu.Unlock()
-
     if jgv.traversalSource == nil {
         return errors.New("janusGraph traversalSource is not initialized")
     }

-    if len(jgv.inserts) != 0 {
-        _ = statsd.Incr(ctx, metric.FlushWriterCall, jgv.tags, 1)
-
-        jgv.writingInFlight.Add(1)
-        err = jgv.batchWrite(ctx, jgv.inserts)
-        if err != nil {
-            log.Trace(ctx).Errorf("batch write %s: %+v", jgv.builder, err)
-            jgv.writingInFlight.Wait()
-
-            return err
-        }
-
-        log.Trace(ctx).Debugf("Done flushing %s writes. clearing the queue", jgv.builder)
-        jgv.inserts = nil
+    // Flush the micro batcher.
+    err = jgv.mb.Flush(ctx)
+    if err != nil {
+        return fmt.Errorf("micro batcher flush: %w", err)
     }

+    // Wait for all writes to complete.
     jgv.writingInFlight.Wait()

     log.Trace(ctx).Debugf("Edge writer %d %s queued", jgv.qcounter, jgv.builder)
@@ -230,26 +204,5 @@ func (jgv *JanusGraphEdgeWriter) Flush(ctx context.Context) error {
 }

 func (jgv *JanusGraphEdgeWriter) Queue(ctx context.Context, v any) error {
-    jgv.mu.Lock()
-    defer jgv.mu.Unlock()
-
-    atomic.AddInt32(&jgv.qcounter, 1)
-    jgv.inserts = append(jgv.inserts, v)
-
-    if len(jgv.inserts) > jgv.batchSize {
-        copied := make([]any, len(jgv.inserts))
-        copy(copied, jgv.inserts)
-
-        jgv.writingInFlight.Add(1)
-        jgv.consumerChan <- batchItem{
-            data: copied,
-            retryCount: 0,
-        }
-        _ = statsd.Incr(ctx, metric.QueueSize, jgv.tags, 1)
-
-        // cleanup the ops array after we have copied it to the channel
-        jgv.inserts = nil
-    }
-
-    return nil
+    return jgv.mb.Enqueue(ctx, v)
 }
diff --git a/pkg/kubehound/storage/graphdb/janusgraph_vertex_writer.go b/pkg/kubehound/storage/graphdb/janusgraph_vertex_writer.go
index c176c0bc..ce44eeb2 100644
--- a/pkg/kubehound/storage/graphdb/janusgraph_vertex_writer.go
+++ b/pkg/kubehound/storage/graphdb/janusgraph_vertex_writer.go
@@ -28,24 +28,14 @@ type JanusGraphVertexWriter struct {
     gremlin types.VertexTraversal // Gremlin traversal generator function
     drc *gremlin.DriverRemoteConnection // Gremlin driver remote connection
     traversalSource *gremlin.GraphTraversalSource // Transacted graph traversal source
-    inserts []any // Object data to be inserted in the graph
-    mu sync.Mutex // Mutex protecting access to the inserts array
-    consumerChan chan batchItem // Channel consuming inserts for async writing
     writingInFlight *sync.WaitGroup // Wait group tracking current unfinished writes
-    batchSize int // Batchsize of graph DB inserts
     qcounter int32 // Track items queued
     wcounter int32 // Track items writtn
     tags []string // Telemetry tags
     cache cache.AsyncWriter // Cache writer to cache store id -> vertex id mappings
     writerTimeout time.Duration // Timeout for the writer
     maxRetry int // Maximum number of retries for failed writes
-}
-
-// batchItem is a single item in the batch writer queue that contains the data
-// to be written and the number of retries.
-type batchItem struct {
-    data []any
-    retryCount int
+    mb *microBatcher // Micro batcher to batch writes
 }

 // NewJanusGraphAsyncVertexWriter creates a new bulk vertex writer instance.
@@ -53,8 +43,9 @@ func NewJanusGraphAsyncVertexWriter(ctx context.Context, drc *gremlin.DriverRemo
     v vertex.Builder, c cache.CacheProvider, opts ...WriterOption,
 ) (*JanusGraphVertexWriter, error) {
     options := &writerOptions{
-        WriterTimeout: defaultWriterTimeout,
-        MaxRetry:      defaultMaxRetry,
+        WriterTimeout:     defaultWriterTimeout,
+        MaxRetry:          defaultMaxRetry,
+        WriterWorkerCount: defaultWriterWorkerCount,
     }
     for _, opt := range opts {
         opt(options)
     }
@@ -69,93 +60,32 @@ func NewJanusGraphAsyncVertexWriter(ctx context.Context, drc *gremlin.DriverRemo
         builder: v.Label(),
         gremlin: v.Traversal(),
         drc: drc,
-        inserts: make([]any, 0, v.BatchSize()),
         traversalSource: gremlin.Traversal_().WithRemote(drc),
-        batchSize: v.BatchSize(),
         writingInFlight: &sync.WaitGroup{},
-        consumerChan: make(chan batchItem, v.BatchSize()*channelSizeBatchFactor),
         tags: append(options.Tags, tag.Label(v.Label()), tag.Builder(v.Label())),
         cache: cw,
         writerTimeout: options.WriterTimeout,
         maxRetry: options.MaxRetry,
     }

-    jw.startBackgroundWriter(ctx)
-
-    return &jw, nil
-}
-
-// startBackgroundWriter starts a background go routine
-func (jgv *JanusGraphVertexWriter) startBackgroundWriter(ctx context.Context) {
-    go func() {
-        for {
-            select {
-            case batch, ok := <-jgv.consumerChan:
-                // If the channel is closed, return.
-                if !ok {
-                    log.Trace(ctx).Info("Closed background janusgraph worker on channel close")
-
-                    return
-                }
-
-                // If the batch is empty, return.
-                if len(batch.data) == 0 {
-                    log.Trace(ctx).Warn("Empty vertex batch received in background janusgraph worker, skipping")
-
-                    return
-                }
-
-                _ = statsd.Count(ctx, metric.BackgroundWriterCall, 1, jgv.tags, 1)
-                err := jgv.batchWrite(ctx, batch.data)
-                if err != nil {
-                    var e *batchWriterError
-                    if errors.As(err, &e) {
-                        // If the error is retryable, retry the write operation with a smaller batch.
-                        if e.retryable && batch.retryCount < jgv.maxRetry {
-                            jgv.retrySplitAndRequeue(ctx, &batch, e)
-
-                            continue
-                        }
-
-                        log.Trace(ctx).Errorf("Retry limit exceeded for write operation: %v", err)
-                    }
-
-                    log.Trace(ctx).Errorf("Write data in background batch writer, data will be lost: %v", err)
-                }
-
-                _ = statsd.Decr(ctx, metric.QueueSize, jgv.tags, 1)
-            case <-ctx.Done():
-                log.Trace(ctx).Info("Closed background janusgraph worker on context cancel")
-
-                return
+    // Create a new micro batcher to batch the inserts with split and retry logic.
+    jw.mb = newMicroBatcher(log.Trace(ctx), v.BatchSize(), options.WriterWorkerCount, func(ctx context.Context, a []any) error {
+        // Try to write the batch to the graph DB.
+        if err := jw.batchWrite(ctx, a); err != nil {
+            var bwe *batchWriterError
+            if errors.As(err, &bwe) && bwe.retryable {
+                // If the write operation failed and is retryable, split the batch and retry.
+                return jw.splitAndRetry(ctx, 0, a)
             }
-        }
-    }()
-}
-
-// retrySplitAndRequeue will split the batch into smaller chunks and requeue them for writing.
-func (jgv *JanusGraphVertexWriter) retrySplitAndRequeue(ctx context.Context, batch *batchItem, e *batchWriterError) {
-    _ = statsd.Count(ctx, metric.RetryWriterCall, 1, jgv.tags, 1)
-    // Compute the new batch size.
-    newBatchSize := len(batch.data) / 2
-    batch.retryCount++

+            return err
+        }

-    log.Trace(ctx).Warnf("Retrying write operation with smaller vertex batch (n:%d -> %d, r:%d): %v", len(batch.data), newBatchSize, batch.retryCount, e.Unwrap())

+        return nil
+    })
+    jw.mb.Start(ctx)
-    // Split the batch into smaller chunks and requeue them.
-    if len(batch.data[:newBatchSize]) > 0 {
-        jgv.consumerChan <- batchItem{
-            data: batch.data[:newBatchSize],
-            retryCount: batch.retryCount,
-        }
-    }
-    if len(batch.data[newBatchSize:]) > 0 {
-        jgv.consumerChan <- batchItem{
-            data: batch.data[newBatchSize:],
-            retryCount: batch.retryCount,
-        }
-    }
+
+    return &jw, nil
 }

 func (jgv *JanusGraphVertexWriter) cacheIds(ctx context.Context, idMap []*gremlin.Result) error {
@@ -182,12 +112,16 @@ func (jgv *JanusGraphVertexWriter) cacheIds(ctx context.Context, idMap []*gremli
 }

 // batchWrite will write a batch of entries into the graph DB and block until the write completes.
-// Callers are responsible for doing an Add(1) to the writingInFlight wait group to ensure proper synchronization.
 func (jgv *JanusGraphVertexWriter) batchWrite(ctx context.Context, data []any) error {
+    _ = statsd.Count(ctx, metric.BackgroundWriterCall, 1, jgv.tags, 1)
+
     span, ctx := span.SpanRunFromContext(ctx, span.JanusGraphBatchWrite)
     span.SetTag(tag.LabelTag, jgv.builder)
     var err error
     defer func() { span.Finish(tracer.WithError(err)) }()
+
+    // Increment the writingInFlight wait group to track the number of writes in progress.
+    jgv.writingInFlight.Add(1)
     defer jgv.writingInFlight.Done()

     datalen := len(data)
@@ -246,10 +180,63 @@ func (jgv *JanusGraphVertexWriter) batchWrite(ctx context.Context, data []any) e
     return nil
 }

+// retrySplitAndRequeue will split the batch into smaller chunks and requeue them for writing.
+func (jgv *JanusGraphVertexWriter) splitAndRetry(ctx context.Context, retryCount int, payload []any) error {
+    _ = statsd.Count(ctx, metric.RetryWriterCall, 1, jgv.tags, 1)
+
+    // If we have reached the maximum number of retries, return an error.
+    if retryCount >= jgv.maxRetry {
+        return fmt.Errorf("max retry count reached: %d", retryCount)
+    }
+
+    // Compute the new batch size.
+    newBatchSize := len(payload) / 2
+
+    log.Trace(ctx).Warnf("Retrying write operation with smaller vertex batch (n:%d -> %d, r:%d)", len(payload), newBatchSize, retryCount)
+
+    var leftErr, rightErr error
+
+    // Split the batch into smaller chunks and retry them.
+    if len(payload[:newBatchSize]) > 0 {
+        if leftErr = jgv.batchWrite(ctx, payload[:newBatchSize]); leftErr == nil {
+            var bwe *batchWriterError
+            if errors.As(leftErr, &bwe) && bwe.retryable {
+                return jgv.splitAndRetry(ctx, retryCount+1, payload[:newBatchSize])
+            }
+        }
+    }
+
+    // Process the right side of the batch.
+    if len(payload[newBatchSize:]) > 0 {
+        if rightErr = jgv.batchWrite(ctx, payload[newBatchSize:]); rightErr != nil {
+            var bwe *batchWriterError
+            if errors.As(rightErr, &bwe) && bwe.retryable {
+                return jgv.splitAndRetry(ctx, retryCount+1, payload[newBatchSize:])
+            }
+        }
+    }
+
+    // Return the first error encountered.
+    switch {
+    case leftErr != nil && rightErr != nil:
+        return fmt.Errorf("left: %w, right: %w", leftErr, rightErr)
+    case leftErr != nil:
+        return leftErr
+    case rightErr != nil:
+        return rightErr
+    }
+
+    return nil
+}
+
 func (jgv *JanusGraphVertexWriter) Close(ctx context.Context) error {
-    close(jgv.consumerChan)
+    if jgv.cache != nil {
+        if err := jgv.cache.Close(ctx); err != nil {
+            return fmt.Errorf("closing cache: %w", err)
+        }
+    }

-    return jgv.cache.Close(ctx)
+    return nil
 }

 // Flush triggers writes of any remaining items in the queue.
@@ -260,29 +247,17 @@ func (jgv *JanusGraphVertexWriter) Flush(ctx context.Context) error {
     var err error
     defer func() { span.Finish(tracer.WithError(err)) }()

-    jgv.mu.Lock()
-    defer jgv.mu.Unlock()
-
     if jgv.traversalSource == nil {
         return errors.New("janusGraph traversalSource is not initialized")
     }

-    if len(jgv.inserts) != 0 {
-        _ = statsd.Incr(ctx, metric.FlushWriterCall, jgv.tags, 1)
-
-        jgv.writingInFlight.Add(1)
-        err = jgv.batchWrite(ctx, jgv.inserts)
-        if err != nil {
-            log.Trace(ctx).Errorf("batch write %s: %+v", jgv.builder, err)
-            jgv.writingInFlight.Wait()
-
-            return err
-        }
-
-        log.Trace(ctx).Debugf("Done flushing %s writes. clearing the queue", jgv.builder)
-        jgv.inserts = nil
+    // Flush the micro batcher.
+    err = jgv.mb.Flush(ctx)
+    if err != nil {
+        return fmt.Errorf("micro batcher flush: %w", err)
     }

+    // Wait for all writes to complete.
     jgv.writingInFlight.Wait()

     err = jgv.cache.Flush(ctx)
@@ -297,26 +272,5 @@ func (jgv *JanusGraphVertexWriter) Flush(ctx context.Context) error {
 }

 func (jgv *JanusGraphVertexWriter) Queue(ctx context.Context, v any) error {
-    jgv.mu.Lock()
-    defer jgv.mu.Unlock()
-
-    atomic.AddInt32(&jgv.qcounter, 1)
-    jgv.inserts = append(jgv.inserts, v)
-
-    if len(jgv.inserts) > jgv.batchSize {
-        copied := make([]any, len(jgv.inserts))
-        copy(copied, jgv.inserts)
-
-        jgv.writingInFlight.Add(1)
-        jgv.consumerChan <- batchItem{
-            data: copied,
-            retryCount: 0,
-        }
-        _ = statsd.Incr(ctx, metric.QueueSize, jgv.tags, 1)
-
-        // cleanup the ops array after we have copied it to the channel
-        jgv.inserts = nil
-    }
-
-    return nil
+    return jgv.mb.Enqueue(ctx, v)
 }
diff --git a/pkg/kubehound/storage/graphdb/microbatcher.go b/pkg/kubehound/storage/graphdb/microbatcher.go
new file mode 100644
index 00000000..889622ac
--- /dev/null
+++ b/pkg/kubehound/storage/graphdb/microbatcher.go
@@ -0,0 +1,183 @@
+package graphdb
+
+import (
+    "context"
+    "errors"
+    "sync"
+    "sync/atomic"
+
+    "github.com/DataDog/KubeHound/pkg/telemetry/log"
+)
+
+// batchItem is a single item in the batch writer queue that contains the data
+// to be written and the number of retries.
+type batchItem struct {
+    data       []any
+    retryCount int
+}
+
+// microBatcher is a utility to batch items and flush them when the batch is full.
+type microBatcher struct {
+    // batchSize is the maximum number of items to batch.
+    batchSize int
+    // items is the current item accumulator for the batch. This is reset after
+    // the batch is flushed.
+    items []any
+    // flush is the function to call to flush the batch.
+    flushFunc func(context.Context, []any) error
+    // itemChan is the channel to receive items to batch.
+    itemChan chan any
+    // batchChan is the channel to send batches to.
+    batchChan chan batchItem
+    // workerCount is the number of workers to process the batch.
+    workerCount int
+    // workerGroup is the worker group to wait for the workers to finish.
+    workerGroup *sync.WaitGroup
+    // shuttingDown is a flag to indicate if the batcher is shutting down.
+    shuttingDown atomic.Bool
+    // logger is the logger to use for logging.
+    logger log.LoggerI
+}
+
+// NewMicroBatcher creates a new micro batcher.
+func newMicroBatcher(logger log.LoggerI, batchSize int, workerCount int, flushFunc func(context.Context, []any) error) *microBatcher {
+    return &microBatcher{
+        logger:      logger,
+        batchSize:   batchSize,
+        items:       make([]any, 0, batchSize),
+        flushFunc:   flushFunc,
+        itemChan:    make(chan any, batchSize),
+        batchChan:   make(chan batchItem, batchSize),
+        workerCount: workerCount,
+        workerGroup: nil, // Set in Start.
+    }
+}
+
+// Flush flushes the current batch and waits for the batch writer to finish.
+func (mb *microBatcher) Flush(_ context.Context) error {
+    // Set the shutting down flag to true.
+    if !mb.shuttingDown.CompareAndSwap(false, true) {
+        return errors.New("batcher is already shutting down")
+    }
+
+    // Closing the item channel to signal the accumulator to stop and flush the batch.
+    close(mb.itemChan)
+
+    // Wait for the workers to finish.
+    if mb.workerGroup != nil {
+        mb.workerGroup.Wait()
+    }
+
+    return nil
+}
+
+// Enqueue adds an item to the batch processor.
+func (mb *microBatcher) Enqueue(ctx context.Context, item any) error {
+    // If the batcher is shutting down, return an error immediately.
+    if mb.shuttingDown.Load() {
+        return errors.New("batcher is shutting down")
+    }
+
+    select {
+    case <-ctx.Done():
+        // If the context is cancelled, return.
+        return ctx.Err()
+    case mb.itemChan <- item:
+    }
+
+    return nil
+}
+
+// Start starts the batch processor.
+func (mb *microBatcher) Start(ctx context.Context) {
+    if mb.workerGroup != nil {
+        // If the worker group is already set, return.
+        return
+    }
+
+    var wg sync.WaitGroup
+
+    // Start the workers.
+    for i := 0; i < mb.workerCount; i++ {
+        wg.Add(1)
+        go func() {
+            defer wg.Done()
+            if err := mb.worker(ctx, mb.batchChan); err != nil {
+                mb.logger.Errorf("worker: %v", err)
+            }
+        }()
+    }
+
+    // Start the item accumulator.
+    wg.Add(1)
+    go func() {
+        defer wg.Done()
+        if err := mb.runItemBatcher(ctx); err != nil {
+            mb.logger.Errorf("run item batcher: %v", err)
+        }
+
+        // Close the batch channel to signal the workers to stop.
+        close(mb.batchChan)
+    }()
+
+    // Set the worker group to wait for the workers to finish.
+    mb.workerGroup = &wg
+}
+
+// startItemBatcher starts the item accumulator to batch items.
+func (mb *microBatcher) runItemBatcher(ctx context.Context) error {
+    for {
+        select {
+        case <-ctx.Done():
+            return ctx.Err()
+
+        case item, ok := <-mb.itemChan:
+            if !ok {
+                // If the item channel is closed, send the current batch and return.
+                mb.batchChan <- batchItem{
+                    data:       mb.items,
+                    retryCount: 0,
+                }
+
+                // End the accumulator.
+                return nil
+            }
+
+            // Add the item to the batch.
+            mb.items = append(mb.items, item)
+
+            // If the batch is full, send it.
+            if len(mb.items) == mb.batchSize {
+                // Send the batch to the processor.
+                mb.batchChan <- batchItem{
+                    data:       mb.items,
+                    retryCount: 0,
+                }
+
+                // Reset the batch.
+                mb.items = mb.items[len(mb.items):]
+            }
+        }
+    }
+}
+
+// startWorkers starts the workers to process the batches.
+func (mb *microBatcher) worker(ctx context.Context, batchQueue <-chan batchItem) error {
+    for {
+        select {
+        case <-ctx.Done():
+            return ctx.Err()
+        case batch, ok := <-batchQueue:
+            if !ok {
+                return nil
+            }
+
+            // Send the batch to the processor.
+            if len(batch.data) > 0 && mb.flushFunc != nil {
+                if err := mb.flushFunc(ctx, batch.data); err != nil {
+                    mb.logger.Errorf("flush data in background batch writer: %v", err)
+                }
+            }
+        }
+    }
+}
diff --git a/pkg/kubehound/storage/graphdb/microbatcher_test.go b/pkg/kubehound/storage/graphdb/microbatcher_test.go
new file mode 100644
index 00000000..b455c1b1
--- /dev/null
+++ b/pkg/kubehound/storage/graphdb/microbatcher_test.go
@@ -0,0 +1,63 @@
+package graphdb
+
+import (
+    "context"
+    "sync/atomic"
+    "testing"
+
+    "github.com/DataDog/KubeHound/pkg/telemetry/log"
+    "github.com/stretchr/testify/assert"
+)
+
+func microBatcherTestInstance(t *testing.T) (*microBatcher, *atomic.Int32) {
+    t.Helper()
+
+    var (
+        writerFuncCalledCount atomic.Int32
+    )
+
+    underTest := newMicroBatcher(log.DefaultLogger(), 5, 1,
+        func(_ context.Context, _ []any) error {
+            writerFuncCalledCount.Add(1)
+
+            return nil
+        })
+
+    return underTest, &writerFuncCalledCount
+}
+
+func TestMicroBatcher_AfterBatchSize(t *testing.T) {
+    t.Parallel()
+
+    underTest, writerFuncCalledCount := microBatcherTestInstance(t)
+
+    ctx, cancel := context.WithCancel(context.Background())
+    t.Cleanup(cancel)
+    underTest.Start(ctx)
+
+    for i := 0; i < 10; i++ {
+        assert.NoError(t, underTest.Enqueue(ctx, i))
+    }
+
+    assert.NoError(t, underTest.Flush(ctx))
+
+    assert.Equal(t, int32(2), writerFuncCalledCount.Load())
+}
+
+func TestMicroBatcher_AfterFlush(t *testing.T) {
+    t.Parallel()
+
+    underTest, writerFuncCalledCount := microBatcherTestInstance(t)
+
+    ctx, cancel := context.WithCancel(context.Background())
+    t.Cleanup(cancel)
+    underTest.Start(ctx)
+
+    for i := 0; i < 11; i++ {
+        assert.NoError(t, underTest.Enqueue(ctx, i))
+    }
+
+    assert.NoError(t, underTest.Flush(ctx))
+
+    assert.Equal(t, int32(3), writerFuncCalledCount.Load())
+}
diff --git a/pkg/kubehound/storage/graphdb/provider.go b/pkg/kubehound/storage/graphdb/provider.go
index 120800bf..fbbe9b98 100644
--- a/pkg/kubehound/storage/graphdb/provider.go
+++ b/pkg/kubehound/storage/graphdb/provider.go
@@ -13,14 +13,16 @@ import (
 )

 const (
-    defaultWriterTimeout = 60 * time.Second
-    defaultMaxRetry      = 3
+    defaultWriterTimeout     = 60 * time.Second
+    defaultMaxRetry          = 3
+    defaultWriterWorkerCount = 1
 )

 type writerOptions struct {
-    Tags          []string
-    WriterTimeout time.Duration
-    MaxRetry      int
+    Tags              []string
+    WriterWorkerCount int
+    WriterTimeout     time.Duration
+    MaxRetry          int
 }

 type WriterOption func(*writerOptions)
@@ -45,6 +47,13 @@ func WithWriterMaxRetry(maxRetry int) WriterOption {
     }
 }

+// WithWriterWorkerCount sets the number of workers to process the batch.
+func WithWriterWorkerCount(workerCount int) WriterOption {
+    return func(wo *writerOptions) {
+        wo.WriterWorkerCount = workerCount
+    }
+}
+
 // Provider defines the interface for implementations of the graphdb provider for storage of the calculated K8s attack graph.
 //
 //go:generate mockery --name Provider --output mocks --case underscore --filename graph_provider.go --with-expecter
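For readers tracing the new flow: both writers now delegate their batching to the microBatcher, so Queue reduces to mb.Enqueue, the hand-rolled background goroutine becomes the batcher's worker pool, and Flush drains the accumulator before waiting on writingInFlight. A minimal sketch of that lifecycle, written as a hypothetical package-internal snippet in the spirit of microbatcher_test.go (the function name and the numbers are illustrative only, not part of the patch):

    func exampleMicroBatcherFlow(ctx context.Context) error {
        // Batches of 3 items, flushed by 2 background workers.
        mb := newMicroBatcher(log.DefaultLogger(), 3, 2, func(_ context.Context, batch []any) error {
            // In the writers this callback wraps batchWrite and, for retryable
            // failures, splitAndRetry.
            return nil
        })
        mb.Start(ctx)

        for i := 0; i < 7; i++ {
            if err := mb.Enqueue(ctx, i); err != nil {
                return err
            }
        }

        // Two full batches have already been handed to the workers; Flush sends
        // the trailing single-item batch and waits for the workers to drain.
        return mb.Flush(ctx)
    }

With a batch size of 3 and 7 enqueued items the flush callback runs three times, which matches the counts the new tests assert: two calls for 10 items and three calls for 11 items at a batch size of 5.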