Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RE-5863] Resolve potential race condition in log fields #185

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 27 additions & 9 deletions utils/log/loggers/fields.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package loggers

import (
"context"
"sync"
)

type logsContext string
Expand All @@ -12,6 +13,10 @@ var (

//LogFields contains all fields that have to be added to logs
type LogFields map[string]interface{}
type protectedLogFields struct {
content LogFields
mtx sync.RWMutex
achichen marked this conversation as resolved.
Show resolved Hide resolved
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As @ankurs suggested, sync.Map is indeed a better option.

According to https://pkg.go.dev/sync#Map, it optimizes for our use case:

when the entry for a given key is only ever written once but read many times, as in caches that only grow

}

// Add or modify log fields
func (o LogFields) Add(key string, value interface{}) {
Expand All @@ -28,26 +33,39 @@ func (o LogFields) Del(key string) {
//AddToLogContext adds log fields to context.
// Any info added here will be added to all logs using this context
func AddToLogContext(ctx context.Context, key string, value interface{}) context.Context {
data := FromContext(ctx)
data := fromContext(ctx)
//Initialize if key doesn't exist
if data == nil {
ctx = context.WithValue(ctx, contextKey, make(LogFields))
data = FromContext(ctx)
}
m := ctx.Value(contextKey)
if data, ok := m.(LogFields); ok {
data.Add(key, value)
data = &protectedLogFields{content: make(LogFields)}
ctx = context.WithValue(ctx, contextKey, data)
}
data.mtx.Lock()
defer data.mtx.Unlock()
data.content.Add(key, value)
return ctx
}

//FromContext fetchs log fields from provided context
func FromContext(ctx context.Context) LogFields {
if plf := fromContext(ctx); plf != nil {
plf.mtx.RLock()
defer plf.mtx.RUnlock()
content := make(LogFields)
for k, v := range plf.content {
content[k] = v
}
return content
}
return nil
}

func fromContext(ctx context.Context) *protectedLogFields {
if ctx == nil {
return nil
}
if h := ctx.Value(contextKey); h != nil {
if logData, ok := h.(LogFields); ok {
return logData
if plf, ok := h.(*protectedLogFields); ok {
return plf
}
}
return nil
Expand Down
27 changes: 27 additions & 0 deletions utils/log/loggers/test/benchmark_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
//go test -v -bench=. -run=none .
package loggers_test

import (
"context"
"fmt"
"testing"

s "github.com/carousell/Orion/utils/log/loggers"
)

func BenchmarkFromContext(b *testing.B) {
ctx := context.Background()
for i := 0; i < 10000; i++ {
s.AddToLogContext(ctx, fmt.Sprintf("key%d", i), "good value")
}
for i := 0; i < b.N; i++ {
s.FromContext(ctx)
}
}

func BenchmarkFromAddToLogContext(b *testing.B) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

function name should be BenchmarkAddToLogContext

ctx := context.Background()
for i := 0; i < b.N; i++ {
s.AddToLogContext(ctx, fmt.Sprintf("key%d", i), "good value")
}
}
19 changes: 19 additions & 0 deletions utils/log/loggers/test/example_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package loggers_test

import (
"context"
"fmt"

s "github.com/carousell/Orion/utils/log/loggers"
)

func ExampleFromContext() {
ctx := context.Background()
ctx = s.AddToLogContext(ctx, "indespensable", "amazing data")
ctx = s.AddToLogContext(ctx, "preciousData", "valuable key")
lf := s.FromContext(ctx)
fmt.Println(lf)

// Output:
// map[indespensable:amazing data preciousData:valuable key]
}
101 changes: 101 additions & 0 deletions utils/log/loggers/test/parallelism_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
//go test -race
package loggers_test

import (
"context"
"fmt"
"math/rand"
"sync"
"testing"
"time"

s "github.com/carousell/Orion/utils/log/loggers"
"github.com/stretchr/testify/assert"
)

const readWorkerCount = 50
const writeWorkerCount = 50

func readWorker(idx int, ctx context.Context) {
s.FromContext(ctx)
// simulate reading task
time.Sleep(time.Millisecond * 250)
}

func writeWorker(idx int, ctx context.Context) context.Context {
key := fmt.Sprintf("key%d", idx)
val := fmt.Sprintf("val%d", rand.Intn(10000))
ctx = s.AddToLogContext(ctx, key, val)
time.Sleep(time.Millisecond * 250)
return ctx
}

func TestParallelRead(t *testing.T) {
// LogContext init, non-paralel
ctx := context.Background()
ctx = s.AddToLogContext(ctx, "k1", "v1")
ctx = s.AddToLogContext(ctx, "k2", "v2")

var wg sync.WaitGroup
for i := 1; i <= readWorkerCount; i++ {
wg.Add(1)
go func(j int) {
defer wg.Done()
readWorker(j, ctx)
}(i)
}
wg.Wait()
}

func TestParallelWrite(t *testing.T) {
ctx := context.Background()
ctx = s.AddToLogContext(ctx, "test-key", "test-value")

var wg sync.WaitGroup
for i := 1; i <= writeWorkerCount; i++ {
wg.Add(1)
go func(j int) {
defer wg.Done()
writeWorker(j, ctx)
}(i)
}
wg.Wait()

lf := s.FromContext(ctx)

assert.Contains(t, lf, "test-key")
for i := 1; i <= writeWorkerCount; i++ {
key := fmt.Sprintf("key%d", i)
assert.Contains(t, lf, key)
}
}

func TestParallelReadAndWrite(t *testing.T) {
ctx := context.Background()
ctx = s.AddToLogContext(ctx, "test-key", "test-value")

var wg sync.WaitGroup
for i := 1; i <= readWorkerCount; i++ {
wg.Add(1)
go func(j int) {
defer wg.Done()
readWorker(j, ctx)
}(i)
}
for i := 1; i <= writeWorkerCount; i++ {
wg.Add(1)
go func(j int) {
defer wg.Done()
writeWorker(j, ctx)
}(i)
}
wg.Wait()

lf := s.FromContext(ctx)

assert.Contains(t, lf, "test-key")
for i := 1; i <= writeWorkerCount; i++ {
key := fmt.Sprintf("key%d", i)
assert.Contains(t, lf, key)
}
}