diff --git a/utils/log/loggers/fields.go b/utils/log/loggers/fields.go
index 41105977..25d214dd 100644
--- a/utils/log/loggers/fields.go
+++ b/utils/log/loggers/fields.go
@@ -2,6 +2,7 @@ package loggers
 
 import (
 	"context"
+	"sync"
 )
 
 type logsContext string
@@ -12,6 +13,10 @@ var (
 //LogFields contains all fields that have to be added to logs
 type LogFields map[string]interface{}
 
+type protectedLogFields struct {
+	content LogFields
+	mtx     sync.RWMutex
+}
 
 // Add or modify log fields
 func (o LogFields) Add(key string, value interface{}) {
@@ -28,26 +33,39 @@ func (o LogFields) Del(key string) {
 //AddToLogContext adds log fields to context.
 // Any info added here will be added to all logs using this context
 func AddToLogContext(ctx context.Context, key string, value interface{}) context.Context {
-	data := FromContext(ctx)
+	data := fromContext(ctx)
+	//Initialize if key doesn't exist
 	if data == nil {
-		ctx = context.WithValue(ctx, contextKey, make(LogFields))
-		data = FromContext(ctx)
-	}
-	m := ctx.Value(contextKey)
-	if data, ok := m.(LogFields); ok {
-		data.Add(key, value)
+		data = &protectedLogFields{content: make(LogFields)}
+		ctx = context.WithValue(ctx, contextKey, data)
 	}
+	data.mtx.Lock()
+	defer data.mtx.Unlock()
+	data.content.Add(key, value)
 	return ctx
 }
 
 //FromContext fetchs log fields from provided context
 func FromContext(ctx context.Context) LogFields {
+	if plf := fromContext(ctx); plf != nil {
+		plf.mtx.RLock()
+		defer plf.mtx.RUnlock()
+		content := make(LogFields)
+		for k, v := range plf.content {
+			content[k] = v
+		}
+		return content
+	}
+	return nil
+}
+
+func fromContext(ctx context.Context) *protectedLogFields {
 	if ctx == nil {
 		return nil
 	}
 	if h := ctx.Value(contextKey); h != nil {
-		if logData, ok := h.(LogFields); ok {
-			return logData
+		if plf, ok := h.(*protectedLogFields); ok {
+			return plf
 		}
 	}
 	return nil
 }
diff --git a/utils/log/loggers/test/benchmark_test.go b/utils/log/loggers/test/benchmark_test.go
new file mode 100644
index 00000000..935ce3dc
--- /dev/null
+++ b/utils/log/loggers/test/benchmark_test.go
@@ -0,0 +1,27 @@
+//go test -v -bench=. -run=none .
+package loggers_test
+
+import (
+	"context"
+	"fmt"
+	"testing"
+
+	s "github.com/carousell/Orion/utils/log/loggers"
+)
+
+func BenchmarkFromContext(b *testing.B) {
+	ctx := context.Background()
+	for i := 0; i < 10000; i++ {
+		ctx = s.AddToLogContext(ctx, fmt.Sprintf("key%d", i), "good value")
+	}
+	for i := 0; i < b.N; i++ {
+		s.FromContext(ctx)
+	}
+}
+
+func BenchmarkFromAddToLogContext(b *testing.B) {
+	ctx := context.Background()
+	for i := 0; i < b.N; i++ {
+		s.AddToLogContext(ctx, fmt.Sprintf("key%d", i), "good value")
+	}
+}
diff --git a/utils/log/loggers/test/example_test.go b/utils/log/loggers/test/example_test.go
new file mode 100644
index 00000000..82ea3999
--- /dev/null
+++ b/utils/log/loggers/test/example_test.go
@@ -0,0 +1,19 @@
+package loggers_test
+
+import (
+	"context"
+	"fmt"
+
+	s "github.com/carousell/Orion/utils/log/loggers"
+)
+
+func ExampleFromContext() {
+	ctx := context.Background()
+	ctx = s.AddToLogContext(ctx, "indispensable", "amazing data")
+	ctx = s.AddToLogContext(ctx, "preciousData", "valuable key")
+	lf := s.FromContext(ctx)
+	fmt.Println(lf)
+
+	// Output:
+	// map[indispensable:amazing data preciousData:valuable key]
+}
diff --git a/utils/log/loggers/test/parallelism_test.go b/utils/log/loggers/test/parallelism_test.go
new file mode 100644
index 00000000..4f5a8a20
--- /dev/null
+++ b/utils/log/loggers/test/parallelism_test.go
@@ -0,0 +1,101 @@
+//go test -race
+package loggers_test
+
+import (
+	"context"
+	"fmt"
+	"math/rand"
+	"sync"
+	"testing"
+	"time"
+
+	s "github.com/carousell/Orion/utils/log/loggers"
+	"github.com/stretchr/testify/assert"
+)
+
+const readWorkerCount = 50
+const writeWorkerCount = 50
+
+func readWorker(idx int, ctx context.Context) {
+	s.FromContext(ctx)
+	// simulate reading task
+	time.Sleep(time.Millisecond * 250)
+}
+
+func writeWorker(idx int, ctx context.Context) context.Context {
+	key := fmt.Sprintf("key%d", idx)
+	val := fmt.Sprintf("val%d", rand.Intn(10000))
+	ctx = s.AddToLogContext(ctx, key, val)
+	time.Sleep(time.Millisecond * 250)
+	return ctx
+}
+
+func TestParallelRead(t *testing.T) {
+	// LogContext init, non-parallel
+	ctx := context.Background()
+	ctx = s.AddToLogContext(ctx, "k1", "v1")
+	ctx = s.AddToLogContext(ctx, "k2", "v2")
+
+	var wg sync.WaitGroup
+	for i := 1; i <= readWorkerCount; i++ {
+		wg.Add(1)
+		go func(j int) {
+			defer wg.Done()
+			readWorker(j, ctx)
+		}(i)
+	}
+	wg.Wait()
+}
+
+func TestParallelWrite(t *testing.T) {
+	ctx := context.Background()
+	ctx = s.AddToLogContext(ctx, "test-key", "test-value")
+
+	var wg sync.WaitGroup
+	for i := 1; i <= writeWorkerCount; i++ {
+		wg.Add(1)
+		go func(j int) {
+			defer wg.Done()
+			writeWorker(j, ctx)
+		}(i)
+	}
+	wg.Wait()
+
+	lf := s.FromContext(ctx)
+
+	assert.Contains(t, lf, "test-key")
+	for i := 1; i <= writeWorkerCount; i++ {
+		key := fmt.Sprintf("key%d", i)
+		assert.Contains(t, lf, key)
+	}
+}
+
+func TestParallelReadAndWrite(t *testing.T) {
+	ctx := context.Background()
+	ctx = s.AddToLogContext(ctx, "test-key", "test-value")
+
+	var wg sync.WaitGroup
+	for i := 1; i <= readWorkerCount; i++ {
+		wg.Add(1)
+		go func(j int) {
+			defer wg.Done()
+			readWorker(j, ctx)
+		}(i)
+	}
+	for i := 1; i <= writeWorkerCount; i++ {
+		wg.Add(1)
+		go func(j int) {
+			defer wg.Done()
+			writeWorker(j, ctx)
+		}(i)
+	}
+	wg.Wait()
+
+	lf := s.FromContext(ctx)
+
+	assert.Contains(t, lf, "test-key")
+	for i := 1; i <= writeWorkerCount; i++ {
+		key := fmt.Sprintf("key%d", i)
+		assert.Contains(t, lf, key)
+	}
+}
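For reference, a minimal usage sketch of the behavior this patch introduces (not part of the diff itself): writes through AddToLogContext are serialized behind the mutex, while FromContext now returns a copy of the fields, so callers can mutate the returned map without affecting other readers of the same context. The key names and the standalone main package below are illustrative; only the import path and the AddToLogContext/FromContext functions come from the patch.

```go
package main

import (
	"context"
	"fmt"
	"sync"

	s "github.com/carousell/Orion/utils/log/loggers"
)

func main() {
	ctx := context.Background()
	// The first call installs the protected field map into the context.
	ctx = s.AddToLogContext(ctx, "requestID", "abc-123")

	var wg sync.WaitGroup
	for i := 0; i < 10; i++ {
		wg.Add(1)
		go func(i int) {
			defer wg.Done()
			// Later writes mutate the shared, mutex-protected map in place.
			s.AddToLogContext(ctx, fmt.Sprintf("worker%d", i), "done")
		}(i)
	}
	wg.Wait()

	// FromContext hands back a copy; adding to it does not touch the context.
	snapshot := s.FromContext(ctx)
	snapshot.Add("local-only", "not visible to other readers")
	fmt.Println(len(snapshot), len(s.FromContext(ctx))) // 12 11
}
```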