From f037fdaee5697aba108b12510eff6ef1f856f06e Mon Sep 17 00:00:00 2001 From: karupanerura Date: Tue, 28 Feb 2023 14:43:27 +0900 Subject: [PATCH] fix: race-condition problem on testhelper (#5) (#8) --- internal/config/config.go | 23 ++++++++++++++++++++ internal/logger/logger.go | 14 ++++++++++-- internal/logger/logger_test.go | 5 +++-- logz_test.go | 2 +- middleware/nethttp_test.go | 17 ++++++++------- testhelper/testhelper.go | 39 ++++++++++++++++++++++------------ 6 files changed, 73 insertions(+), 27 deletions(-) diff --git a/internal/config/config.go b/internal/config/config.go index d75e966..81b3d8a 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -1,6 +1,7 @@ package config import ( + "context" "io" "os" @@ -28,3 +29,25 @@ func init() { ApplicationLogOut = os.Stdout AccessLogOut = os.Stderr } + +type ContextConfig struct { + // ApplicationLogOut is io.Writer object for application log + ApplicationLogOut io.Writer + // AccessLogOut is io.Writer object for access log + AccessLogOut io.Writer +} + +type contextKey struct{} + +var contextConfigKey = &contextKey{} + +// GetContextConfig sets the ContextConfig instance to context +func SetContextConfig(ctx context.Context, cs *ContextConfig) context.Context { + return context.WithValue(ctx, contextConfigKey, cs) +} + +// GetContextConfig gets the ContextSeverity instance from context +func GetContextConfig(ctx context.Context) *ContextConfig { + v, _ := ctx.Value(contextConfigKey).(*ContextConfig) + return v +} diff --git a/internal/logger/logger.go b/internal/logger/logger.go index 80b03f1..c753acb 100644 --- a/internal/logger/logger.go +++ b/internal/logger/logger.go @@ -53,7 +53,12 @@ func WriteApplicationLog(ctx context.Context, s severity.Severity, format string TraceSampled: sc.TraceSampled, } - if err := json.NewEncoder(config.ApplicationLogOut).Encode(ety); err != nil { + w := config.ApplicationLogOut + if cc := config.GetContextConfig(ctx); cc != nil && cc.ApplicationLogOut != nil { + w = cc.ApplicationLogOut + } + + if err := json.NewEncoder(w).Encode(ety); err != nil { fmt.Printf("failed to write log: %v", err) } } @@ -86,7 +91,12 @@ func WriteAccessLog(ctx context.Context, req types.HTTPRequest) { HTTPRequest: req, } - if err := json.NewEncoder(config.AccessLogOut).Encode(ety); err != nil { + w := config.AccessLogOut + if cc := config.GetContextConfig(ctx); cc != nil && cc.AccessLogOut != nil { + w = cc.AccessLogOut + } + + if err := json.NewEncoder(w).Encode(ety); err != nil { fmt.Printf("failed to write log: %v", err) } } diff --git a/internal/logger/logger_test.go b/internal/logger/logger_test.go index daa464e..e0d75ec 100644 --- a/internal/logger/logger_test.go +++ b/internal/logger/logger_test.go @@ -48,7 +48,7 @@ func TestLoggerWriteApplicationLog(t *testing.T) { }() t.Run("Tests WriteApplicationLog function", func(t *testing.T) { - got := testhelper.ExtractApplicationLogOut(t, func() { + got := testhelper.ExtractApplicationLogOut(t, ctx, func(ctx context.Context) { // tests the function logger.WriteApplicationLog(ctx, severity.Info, "writes %s log", "info") }) @@ -99,7 +99,7 @@ func TestLoggerWriteAccessLog(t *testing.T) { t.Run("Tests WriteAccessLog function", func(t *testing.T) { - got := testhelper.ExtractAccessLogOut(t, func() { + got := testhelper.ExtractAccessLogOut(t, ctx, func(ctx context.Context) { // Tests the function httpReq := httptest.NewRequest(http.MethodGet, "/test1", nil) req := types.MakeHTTPRequest(*httpReq, 200, 333, time.Duration(100)) @@ -112,4 +112,5 @@ func TestLoggerWriteAccessLog(t *testing.T) { t.Errorf("failed log info test: %v", diff) } }) + } diff --git a/logz_test.go b/logz_test.go index 577d4e2..8ab51fe 100644 --- a/logz_test.go +++ b/logz_test.go @@ -130,7 +130,7 @@ func TestLogzWriteLog(t *testing.T) { sc := spancontext.Extract(ctx) - got := testhelper.ExtractApplicationLogOut(t, func() { + got := testhelper.ExtractApplicationLogOut(t, ctx, func(ctx context.Context) { logz.Infof(ctx, "writes %s log", "info") }) diff --git a/middleware/nethttp_test.go b/middleware/nethttp_test.go index 53dce52..3b844d0 100644 --- a/middleware/nethttp_test.go +++ b/middleware/nethttp_test.go @@ -1,6 +1,7 @@ package middleware_test import ( + "context" "fmt" "net/http" "net/http/httptest" @@ -65,13 +66,13 @@ func TestNetHTTPMaxSeverity(t *testing.T) { mid := middleware.NetHTTP("test/component")(mux) rec := httptest.NewRecorder() - got := testhelper.ExtractAccessLogOut(t, func() { + got := testhelper.ExtractAccessLogOut(t, context.Background(), func(ctx context.Context) { req1 := httptest.NewRequest(http.MethodGet, "/test1", nil) - mid.ServeHTTP(rec, req1) + mid.ServeHTTP(rec, req1.WithContext(ctx)) req2 := httptest.NewRequest(http.MethodGet, "/test2", nil) - mid.ServeHTTP(rec, req2) + mid.ServeHTTP(rec, req2.WithContext(ctx)) req3 := httptest.NewRequest(http.MethodGet, "/test3", nil) - mid.ServeHTTP(rec, req3) + mid.ServeHTTP(rec, req3.WithContext(ctx)) }) if !strings.Contains(got, `"severity":"ERROR"`) { @@ -107,13 +108,13 @@ func TestNetHTTPMaxSeverityNoLog(t *testing.T) { mid := middleware.NetHTTP("test/component")(mux) rec := httptest.NewRecorder() - got := testhelper.ExtractAccessLogOut(t, func() { + got := testhelper.ExtractAccessLogOut(t, context.Background(), func(ctx context.Context) { req1 := httptest.NewRequest(http.MethodGet, "/test1", nil) - mid.ServeHTTP(rec, req1) + mid.ServeHTTP(rec, req1.WithContext(ctx)) req2 := httptest.NewRequest(http.MethodGet, "/test2", nil) - mid.ServeHTTP(rec, req2) + mid.ServeHTTP(rec, req2.WithContext(ctx)) req3 := httptest.NewRequest(http.MethodGet, "/test3", nil) - mid.ServeHTTP(rec, req3) + mid.ServeHTTP(rec, req3.WithContext(ctx)) }) if !strings.Contains(got, `"severity":"ERROR"`) { diff --git a/testhelper/testhelper.go b/testhelper/testhelper.go index dccaca1..4aaa9c3 100644 --- a/testhelper/testhelper.go +++ b/testhelper/testhelper.go @@ -2,7 +2,8 @@ package testhelper import ( "bytes" - "os" + "context" + "io" "strings" "testing" @@ -10,31 +11,41 @@ import ( ) // ExtractStdout extracts string from stdout -func ExtractApplicationLogOut(t *testing.T, fnc func()) string { +func ExtractApplicationLogOut(t *testing.T, ctx context.Context, fnc func(ctx context.Context)) string { t.Helper() var buf bytes.Buffer - config.ApplicationLogOut = &buf - defer func() { - config.ApplicationLogOut = os.Stdout - }() - - fnc() + if cc := config.GetContextConfig(ctx); cc != nil { + cc.ApplicationLogOut = &buf + } else { + ctx = config.SetContextConfig(ctx, &config.ContextConfig{ApplicationLogOut: &buf}) + } + fnc(ctx) return strings.TrimRight(buf.String(), "\n") } // ExtractStdout extracts string from stderr -func ExtractAccessLogOut(t *testing.T, fnc func()) string { +func ExtractAccessLogOut(t *testing.T, ctx context.Context, fnc func(ctx context.Context)) string { t.Helper() var buf bytes.Buffer - config.AccessLogOut = &buf - defer func() { - config.AccessLogOut = os.Stdout - }() + if cc := config.GetContextConfig(ctx); cc != nil { + cc.AccessLogOut = &buf + } else { + ctx = config.SetContextConfig(ctx, &config.ContextConfig{AccessLogOut: &buf}) + } - fnc() + fnc(ctx) return strings.TrimRight(buf.String(), "\n") } + +// OverrideLogOutContext override log I/O in the context +func OverrideLogOutContext(t *testing.T, ctx context.Context, appLogOut, accessLogOut io.Writer) context.Context { + t.Helper() + return config.SetContextConfig(ctx, &config.ContextConfig{ + ApplicationLogOut: appLogOut, + AccessLogOut: accessLogOut, + }) +}