From 95753d3c2f21fa0a32aa63d1ab68666eba5fa5e0 Mon Sep 17 00:00:00 2001 From: Thibault Normand Date: Thu, 28 Nov 2024 14:55:28 +0100 Subject: [PATCH] feat(graphdb): add deadline/retry/split behavioural patterns to batch writer. --- .../graphdb/janusgraph_vertex_writer.go | 150 +++++++++++++++--- pkg/kubehound/storage/graphdb/provider.go | 19 ++- pkg/telemetry/metric/metrics.go | 1 + 3 files changed, 146 insertions(+), 24 deletions(-) diff --git a/pkg/kubehound/storage/graphdb/janusgraph_vertex_writer.go b/pkg/kubehound/storage/graphdb/janusgraph_vertex_writer.go index 43396746..85c2e72e 100644 --- a/pkg/kubehound/storage/graphdb/janusgraph_vertex_writer.go +++ b/pkg/kubehound/storage/graphdb/janusgraph_vertex_writer.go @@ -6,6 +6,7 @@ import ( "fmt" "sync" "sync/atomic" + "time" "github.com/DataDog/KubeHound/pkg/kubehound/graph/types" "github.com/DataDog/KubeHound/pkg/kubehound/graph/vertex" @@ -29,20 +30,51 @@ type JanusGraphVertexWriter struct { traversalSource *gremlin.GraphTraversalSource // Transacted graph traversal source inserts []any // Object data to be inserted in the graph mu sync.Mutex // Mutex protecting access to the inserts array - consumerChan chan []any // Channel consuming inserts for async writing + consumerChan chan batchItem // Channel consuming inserts for async writing writingInFlight *sync.WaitGroup // Wait group tracking current unfinished writes batchSize int // Batchsize of graph DB inserts qcounter int32 // Track items queued wcounter int32 // Track items writtn tags []string // Telemetry tags cache cache.AsyncWriter // Cache writer to cache store id -> vertex id mappings + writerTimeout time.Duration // Timeout for the writer + maxRetry int // Maximum number of retries for failed writes +} + +// batchItem is a single item in the batch writer queue that contains the data +// to be written and the number of retries. 
+type batchItem struct { + data []any + retryCount int +} + +// errBatchWriter is an error type that wraps an error and indicates whether the +// error is retryable. +type errBatchWriter struct { + err error + retryable bool +} + +func (e errBatchWriter) Error() string { + if e.err == nil { + return fmt.Sprintf("batch writer error (retriable:%v)", e.retryable) + } + + return fmt.Sprintf("batch writer error (retriable:%v): %v", e.retryable, e.err.Error()) +} + +func (e errBatchWriter) Unwrap() error { + return e.err } // NewJanusGraphAsyncVertexWriter creates a new bulk vertex writer instance. func NewJanusGraphAsyncVertexWriter(ctx context.Context, drc *gremlin.DriverRemoteConnection, - v vertex.Builder, c cache.CacheProvider, opts ...WriterOption) (*JanusGraphVertexWriter, error) { - - options := &writerOptions{} + v vertex.Builder, c cache.CacheProvider, opts ...WriterOption, +) (*JanusGraphVertexWriter, error) { + options := &writerOptions{ + WriterTimeout: 60 * time.Second, + MaxRetry: 3, + } for _, opt := range opts { opt(options) } @@ -60,9 +92,11 @@ func NewJanusGraphAsyncVertexWriter(ctx context.Context, drc *gremlin.DriverRemo traversalSource: gremlin.Traversal_().WithRemote(drc), batchSize: v.BatchSize(), writingInFlight: &sync.WaitGroup{}, - consumerChan: make(chan []any, v.BatchSize()*channelSizeBatchFactor), + consumerChan: make(chan batchItem, v.BatchSize()*channelSizeBatchFactor), tags: append(options.Tags, tag.Label(v.Label()), tag.Builder(v.Label())), cache: cw, + writerTimeout: options.WriterTimeout, + maxRetry: options.MaxRetry, } jw.startBackgroundWriter(ctx) @@ -75,16 +109,52 @@ func (jgv *JanusGraphVertexWriter) startBackgroundWriter(ctx context.Context) { go func() { for { select { - case data := <-jgv.consumerChan: - // closing the channel shoud stop the go routine - if data == nil { + case batch, ok := <-jgv.consumerChan: + // If the channel is closed, return. 
+ if !ok { + log.Trace(ctx).Info("Closed background janusgraph worker on channel close") + return + } + + // If the batch is empty, return. + if len(batch.data) == 0 { + log.Trace(ctx).Warn("Empty batch received in background janusgraph worker, skipping") return } _ = statsd.Count(ctx, metric.BackgroundWriterCall, 1, jgv.tags, 1) - err := jgv.batchWrite(ctx, data) + err := jgv.batchWrite(ctx, batch.data) if err != nil { - log.Trace(ctx).Errorf("Write data in background batch writer: %v", err) + var e *errBatchWriter + if errors.As(err, &e) && e.retryable { + // If the write operation timed out (retryable error), retry it with a smaller batch. + if batch.retryCount < jgv.maxRetry { + // Compute the new batch size. + newBatchSize := len(batch.data) / 2 + batch.retryCount++ + + log.Trace(ctx).Warnf("Retrying write operation with smaller batch (n:%d -> %d, r:%d): %v", len(batch.data), newBatchSize, batch.retryCount, e.Unwrap()) + + // Split the batch into smaller chunks and requeue them. + if len(batch.data[:newBatchSize]) > 0 { + jgv.consumerChan <- batchItem{ + data: batch.data[:newBatchSize], + retryCount: batch.retryCount, + } + } + if len(batch.data[newBatchSize:]) > 0 { + jgv.consumerChan <- batchItem{ + data: batch.data[newBatchSize:], + retryCount: batch.retryCount, + } + } + continue + } + + log.Trace(ctx).Errorf("Retry limit exceeded for write operation: %v", err) + } + + log.Trace(ctx).Errorf("Write data in background batch writer, data will be lost: %v", err) } _ = statsd.Decr(ctx, metric.QueueSize, jgv.tags, 1) @@ -134,19 +204,50 @@ func (jgv *JanusGraphVertexWriter) batchWrite(ctx context.Context, data []any) e log.Trace(ctx).Debugf("Batch write JanusGraphVertexWriter with %d elements", datalen) atomic.AddInt32(&jgv.wcounter, int32(datalen)) //nolint:gosec // disable G115 - op := jgv.gremlin(jgv.traversalSource, data) - raw, err := op.Project("id", "storeID"). - By(gremlin.T.Id). - By("storeID"). 
- ToList() - if err != nil { - return fmt.Errorf("%s vertex insert: %w", jgv.builder, err) - } + // Create a channel to signal the completion of the write operation. + errChan := make(chan error, 1) + + // We need to ensure that the write operation is completed within a certain + // time frame to avoid blocking the writer indefinitely if the backend + // is unresponsive. + go func() { + // Create a new gremlin operation to insert the data into the graph. + op := jgv.gremlin(jgv.traversalSource, data) + raw, err := op.Project("id", "storeID"). + By(gremlin.T.Id). + By("storeID"). + ToList() + if err != nil { + errChan <- fmt.Errorf("%s vertex insert: %w", jgv.builder, err) + return + } + + // Gremlin will return a list of maps containing a vertex id and store + // id values for each vertex inserted. + // We need to parse each map entry and add to our cache. + if err = jgv.cacheIds(ctx, raw); err != nil { + errChan <- fmt.Errorf("cache ids: %w", err) + return + } + + errChan <- nil + }() - // Gremlin will return a list of maps containing and vertex id and store id values for each vertex inserted. - // We need to parse each map entry and add to our cache. - if err = jgv.cacheIds(ctx, raw); err != nil { - return err + // Wait for the write operation to complete or timeout. + select { + case <-ctx.Done(): + // If the context is cancelled, return the error. + return ctx.Err() + case <-time.After(jgv.writerTimeout): + // If the write operation takes too long, return an error. 
+ return &errBatchWriter{ + err: errors.New("write operation timed out"), + retryable: true, + } + case err = <-errChan: + if err != nil { + return fmt.Errorf("janusgraph batch write: %w", err) + } } return nil @@ -214,7 +315,10 @@ func (jgv *JanusGraphVertexWriter) Queue(ctx context.Context, v any) error { copy(copied, jgv.inserts) jgv.writingInFlight.Add(1) - jgv.consumerChan <- copied + jgv.consumerChan <- batchItem{ + data: copied, + retryCount: 0, + } _ = statsd.Incr(ctx, metric.QueueSize, jgv.tags, 1) // cleanup the ops array after we have copied it to the channel diff --git a/pkg/kubehound/storage/graphdb/provider.go b/pkg/kubehound/storage/graphdb/provider.go index bd82a7e1..03c80536 100644 --- a/pkg/kubehound/storage/graphdb/provider.go +++ b/pkg/kubehound/storage/graphdb/provider.go @@ -2,6 +2,7 @@ package graphdb import ( "context" + "time" "github.com/DataDog/KubeHound/pkg/config" "github.com/DataDog/KubeHound/pkg/kubehound/graph/edge" @@ -12,7 +13,9 @@ import ( ) type writerOptions struct { - Tags []string + Tags []string + WriterTimeout time.Duration + MaxRetry int } type WriterOption func(*writerOptions) @@ -23,6 +26,20 @@ func WithTags(tags []string) WriterOption { } } +// WithWriterTimeout sets the timeout for the writer to complete the write operation. +func WithWriterTimeout(timeout time.Duration) WriterOption { + return func(wo *writerOptions) { + wo.WriterTimeout = timeout + } +} + +// WithWriterMaxRetry sets the maximum number of retries for failed writes. +func WithWriterMaxRetry(maxRetry int) WriterOption { + return func(wo *writerOptions) { + wo.MaxRetry = maxRetry + } +} + // Provider defines the interface for implementations of the graphdb provider for storage of the calculated K8s attack graph. 
// //go:generate mockery --name Provider --output mocks --case underscore --filename graph_provider.go --with-expecter diff --git a/pkg/telemetry/metric/metrics.go b/pkg/telemetry/metric/metrics.go index afb92770..1a7e97b5 100644 --- a/pkg/telemetry/metric/metrics.go +++ b/pkg/telemetry/metric/metrics.go @@ -28,6 +28,7 @@ var ( QueueSize = "kubehound.storage.queue.size" BackgroundWriterCall = "kubehound.storage.writer.background" FlushWriterCall = "kubehound.storage.writer.flush" + RetryWriterCall = "kubehound.storage.writer.retry" ) // Cache metrics