From 57b0cc79b5d05c8aa655ecfef0c5f62eb7aa84db Mon Sep 17 00:00:00 2001 From: Aldo Fuster Turpin Date: Wed, 17 Apr 2024 15:35:49 +0000 Subject: [PATCH] add unit tests for MiddlewareLoggingPostMux function (to handle CorrelationData) --- frontend/middleware_logging.go | 71 ++++--- frontend/middleware_logging_test.go | 273 +++++++++++++++++++++++++++ internal/api/arm/correlation_test.go | 52 +++++ 3 files changed, 372 insertions(+), 24 deletions(-) create mode 100644 frontend/middleware_logging_test.go create mode 100644 internal/api/arm/correlation_test.go diff --git a/frontend/middleware_logging.go b/frontend/middleware_logging.go index b353fa8b0..951041e7a 100644 --- a/frontend/middleware_logging.go +++ b/frontend/middleware_logging.go @@ -76,52 +76,75 @@ func MiddlewareLogging(w http.ResponseWriter, r *http.Request, next http.Handler } func MiddlewareLoggingPostMux(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) { - var pathValue string - ctx := r.Context() correlationData := arm.NewCorrelationData(r) ctx = ContextWithCorrelationData(ctx, correlationData) + setHeaders(w, r, correlationData) + + attrs := getLogAttrs(correlationData, r) + + logger, err := LoggerFromContext(ctx) + if err != nil { + DefaultLogger().Error(err.Error()) + arm.WriteInternalServerError(w) + return + } + + handler := logger.Handler() + loggerWithAttrs := slog.New(handler.WithAttrs(attrs)) + ctx = ContextWithLogger(ctx, loggerWithAttrs) + + reqWithContext := r.WithContext(ctx) + + next(w, reqWithContext) +} + +// setHeaders writes the appropriate headers in the response writer +// based on the request and the correlation data. +func setHeaders(w http.ResponseWriter, r *http.Request, correlationData *arm.CorrelationData) { + if correlationData == nil { + return + } + w.Header().Set(arm.HeaderNameRequestID, correlationData.RequestID.String()) - if strings.EqualFold(r.Header.Get(arm.HeaderNameReturnClientRequestID), "true") { + returnClientRequestId := r.Header.Get(arm.HeaderNameReturnClientRequestID) + if strings.EqualFold(returnClientRequestId, "true") { w.Header().Set(arm.HeaderNameClientRequestID, correlationData.ClientRequestID) } +} +// getLogAttrs returns the appropiate Logging Attributes based on correlationData and a request. +func getLogAttrs(correlationData *arm.CorrelationData, r *http.Request) []slog.Attr { attrs := []slog.Attr{ slog.String("request_id", correlationData.RequestID.String()), slog.String("client_request_id", correlationData.ClientRequestID), slog.String("correlation_request_id", correlationData.CorrelationRequestID), } - if pathValue = r.PathValue(PathSegmentSubscriptionID); pathValue != "" { - attrs = append(attrs, slog.String("subscription_id", pathValue)) + subscriptionID := r.PathValue(PathSegmentSubscriptionID) + if subscriptionID != "" { + attrs = append(attrs, slog.String("subscription_id", subscriptionID)) } - if pathValue = r.PathValue(PathSegmentResourceGroupName); pathValue != "" { - attrs = append(attrs, slog.String("resource_group", pathValue)) + resourceGroup := r.PathValue(PathSegmentResourceGroupName) + if resourceGroup != "" { + attrs = append(attrs, slog.String("resource_group", resourceGroup)) } - if pathValue = r.PathValue(PathSegmentResourceName); pathValue != "" { - attrs = append(attrs, slog.String("resource_name", pathValue)) - resource_id := fmt.Sprintf("/subscriptions/%s/resourcegroups/%s/providers/%s/%s", - r.PathValue(PathSegmentSubscriptionID), - r.PathValue(PathSegmentResourceGroupName), - api.ResourceType, - pathValue) - attrs = append(attrs, slog.String("resource_id", resource_id)) + resourceName := r.PathValue(PathSegmentResourceName) + if resourceName != "" { + attrs = append(attrs, slog.String("resource_name", resourceName)) } - logger, err := LoggerFromContext(ctx) - if err != nil { - DefaultLogger().Error(err.Error()) - arm.WriteInternalServerError(w) - return + wholePath := subscriptionID != "" && resourceGroup != "" && resourceName != "" + if wholePath { + format := "/subscriptions/%s/resourcegroups/%s/providers/%s/%s" + resource_id := fmt.Sprintf(format, subscriptionID, resourceGroup, api.ResourceType, resourceName) + attrs = append(attrs, slog.String("resource_id", resource_id)) } - handler := logger.Handler() - ctx = ContextWithLogger(ctx, slog.New(handler.WithAttrs(attrs))) - - next(w, r.WithContext(ctx)) + return attrs } diff --git a/frontend/middleware_logging_test.go b/frontend/middleware_logging_test.go new file mode 100644 index 000000000..6b5c7ad36 --- /dev/null +++ b/frontend/middleware_logging_test.go @@ -0,0 +1,273 @@ +package main + +import ( + "fmt" + "log/slog" + "net/http" + "net/http/httptest" + "reflect" + "testing" + "time" + + "github.com/Azure/ARO-HCP/internal/api" + "github.com/Azure/ARO-HCP/internal/api/arm" + "github.com/google/uuid" +) + +const ( + client_request_id = "random_client_request_id" + correlation_request_id string = "random_correlation_request_id" +) + +func TestMiddlewareLoggingPostMux(t *testing.T) { + type testCase struct { + name string + header http.Header + } + + tt := testCase{ + name: "is able to process and forward the values from request's header to context", + header: http.Header{ + arm.HeaderNameClientRequestID: []string{client_request_id}, + arm.HeaderNameCorrelationRequestID: []string{correlation_request_id}, + arm.HeaderNameRequestID: []string{uuid.NewString()}, + }, + } + + t.Run(tt.name, func(t *testing.T) { + request, err := http.NewRequest(http.MethodGet, "", nil) + if err != nil { + t.Fatal(err) + } + + request.Header = tt.header + + // we assume the request carries a logger, we set it explicitly to not fail + ctx := ContextWithLogger(request.Context(), DefaultLogger()) + request = request.WithContext(ctx) + + next := func(w http.ResponseWriter, r *http.Request) { + request = r // capture modified request + w.WriteHeader(http.StatusOK) + } + + writer := httptest.NewRecorder() + MiddlewareLoggingPostMux(writer, request, next) + + result, err := CorrelationDataFromContext(request.Context()) + if err != nil { + t.Fatal(err) + } + + if result.ClientRequestID != client_request_id { + t.Fatalf("ClientRequestID from header was not propperly propagated to requestcontext, expected %v, but got %v", + client_request_id, + result.ClientRequestID) + } + }) + +} + +// ReqPathModifier is an alias to a function that receives a request +// and it should modify its Path value as needed, for testing purposes. +type ReqPathModifier func(req *http.Request) + +// noModifyReqfunc is a function that receives a request and does not modify it. +func noModifyReqfunc(req *http.Request) { + // empty on purpose +} + +func Test_getLogAttrs(t *testing.T) { + var expectedRequestID = uuid.New() + + fakeSubscriptionId := "the_subscription_id" + fakeResourceGroupName := "the_resource_group_name" + fakeResourceName := "the_resource_name" + + sampleCorrelationData := &arm.CorrelationData{ + RequestID: expectedRequestID, + ClientRequestID: client_request_id, + CorrelationRequestID: correlation_request_id, + RequestTime: time.Now(), + } + + commonAttrs := []slog.Attr{ + slog.String("request_id", expectedRequestID.String()), + slog.String("client_request_id", client_request_id), + slog.String("correlation_request_id", correlation_request_id), + } + + type testCase struct { + name string + correlationData *arm.CorrelationData + req *http.Request + want []slog.Attr + setReqPathValue ReqPathModifier + } + + tests := []testCase{ + { + name: "handles the common logging attributes", + correlationData: sampleCorrelationData, + req: &http.Request{}, + want: commonAttrs, + setReqPathValue: noModifyReqfunc, + }, + { + name: "handles the common attributes and the attributes for the subscription_id segment path", + correlationData: sampleCorrelationData, + req: &http.Request{}, + want: append(commonAttrs, slog.String("subscription_id", fakeSubscriptionId)), + setReqPathValue: func(req *http.Request) { + req.SetPathValue(PathSegmentSubscriptionID, fakeSubscriptionId) + }, + }, + { + name: "handles the common attributes and the attributes for the resourcegroupname path", + correlationData: sampleCorrelationData, + req: &http.Request{}, + want: append(commonAttrs, slog.String("resource_group", fakeResourceGroupName)), + setReqPathValue: func(req *http.Request) { + req.SetPathValue(PathSegmentResourceGroupName, fakeResourceGroupName) + }, + }, + { + name: "handles the common attributes and the attributes for the resourcegroupname path", + correlationData: sampleCorrelationData, + req: &http.Request{}, + want: append(commonAttrs, slog.String("resource_group", fakeResourceGroupName)), + setReqPathValue: func(req *http.Request) { + req.SetPathValue(PathSegmentResourceGroupName, fakeResourceGroupName) + }, + }, + { + name: "handles the common attributes and the attributes for the resourcename path, and produces the correct resourceID attribute", + correlationData: sampleCorrelationData, + req: &http.Request{}, + want: append( + commonAttrs, + slog.String("subscription_id", fakeSubscriptionId), + slog.String("resource_group", fakeResourceGroupName), + slog.String("resource_name", fakeResourceName), + slog.String( + "resource_id", + fmt.Sprintf( + "/subscriptions/%s/resourcegroups/%s/providers/%s/%s", + fakeSubscriptionId, + fakeResourceGroupName, + api.ResourceType, + fakeResourceName)), + ), + setReqPathValue: func(req *http.Request) { + // assuming the PathSegmentResourceName is present in the Path + req.SetPathValue(PathSegmentResourceName, fakeResourceName) + + // assuming the PathSegmentSubscriptionID is present in the Path + req.SetPathValue(PathSegmentSubscriptionID, fakeSubscriptionId) + + // assuming the PathSegmentResourceGroupName is present in the Path + req.SetPathValue(PathSegmentResourceGroupName, fakeResourceGroupName) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.setReqPathValue(tt.req) + got := getLogAttrs(tt.correlationData, tt.req) + if !reflect.DeepEqual(tt.want, got) { + t.Errorf("want %v, but got %v", tt.want, got) + } + }) + } +} + +func Test_setHeaders(t *testing.T) { + var expectedRequestId = uuid.New() + const expectedClientRequestId = "the_client_request_id" + + type testCase struct { + name string + w http.ResponseWriter + r *http.Request + correlationData *arm.CorrelationData + expectedHeaders http.Header + } + + tests := []testCase{ + { + name: "should set the requestId header to the value of correlation data", + w: &httptest.ResponseRecorder{}, + r: &http.Request{}, + correlationData: &arm.CorrelationData{RequestID: expectedRequestId}, + expectedHeaders: http.Header{ + arm.HeaderNameRequestID: []string{expectedRequestId.String()}, + }, + }, + { + name: "should set the clientRequestId header to the value of correlation data when the 'should return client request id' header is true", + w: &httptest.ResponseRecorder{}, + r: &http.Request{ + Header: http.Header{ + arm.HeaderNameReturnClientRequestID: []string{"true"}, + }, + }, + correlationData: &arm.CorrelationData{ + RequestID: expectedRequestId, + ClientRequestID: expectedClientRequestId, + }, + expectedHeaders: http.Header{ + arm.HeaderNameRequestID: []string{expectedRequestId.String()}, + arm.HeaderNameClientRequestID: []string{expectedClientRequestId}, + }, + }, + { + name: "should not set the clientRequestId header to the value of correlation data when the 'should return client request id' header is false", + w: &httptest.ResponseRecorder{}, + r: &http.Request{ + Header: http.Header{ + arm.HeaderNameReturnClientRequestID: []string{"false"}, + }, + }, + correlationData: &arm.CorrelationData{ + RequestID: expectedRequestId, + ClientRequestID: expectedClientRequestId, + }, + expectedHeaders: http.Header{ + arm.HeaderNameRequestID: []string{expectedRequestId.String()}, + }, + }, + { + name: "should not set the clientRequestId header to the value from correlation data when header is empty", + w: &httptest.ResponseRecorder{}, + r: &http.Request{}, + correlationData: &arm.CorrelationData{ + RequestID: expectedRequestId, + ClientRequestID: expectedClientRequestId, + }, + expectedHeaders: http.Header{ + arm.HeaderNameRequestID: []string{expectedRequestId.String()}, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + setHeaders(tt.w, tt.r, tt.correlationData) + assertAllHeadersAreWritten(t, tt.expectedHeaders, tt.w) + }) + } +} + +// assertAllHeadersAreWritten asserts that all the headers h are written in w +func assertAllHeadersAreWritten(t *testing.T, h http.Header, w http.ResponseWriter) { + for expectedKey, expectedValues := range h { + valueInHeader := w.Header().Get(expectedKey) + if valueInHeader == "" { + t.Fatalf("header with key %v is not present in response writer\n", expectedKey) + } + + if valueInHeader != expectedValues[0] { + t.Fatalf("header with key %v and value %v is different than expected value %v in response writer\n", expectedKey, valueInHeader, expectedValues[0]) + } + } +} diff --git a/internal/api/arm/correlation_test.go b/internal/api/arm/correlation_test.go new file mode 100644 index 000000000..79b19afea --- /dev/null +++ b/internal/api/arm/correlation_test.go @@ -0,0 +1,52 @@ +package arm + +import ( + "net/http" + "testing" + + "github.com/google/uuid" +) + +func TestNewCorrelationData(t *testing.T) { + const ( + client_request_id = "random_client_request_id" + correlation_request_id string = "random_correlation_request_id" + ) + + tests := []struct { + name string + request *http.Request + want *CorrelationData + }{ + { + name: "NewCorrelationData returns the appropriate correlation data from request", + request: &http.Request{ + Header: http.Header{ + HeaderNameClientRequestID: []string{client_request_id}, + HeaderNameCorrelationRequestID: []string{correlation_request_id}, + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + correlationData := NewCorrelationData(tt.request) + + if correlationData.RequestID == uuid.Nil { + t.Fatalf("correlationData.RequestID is nil") + } + + if correlationData.ClientRequestID != client_request_id { + t.Errorf("got %v, but want %v", correlationData.ClientRequestID, client_request_id) + } + + if correlationData.CorrelationRequestID != correlation_request_id { + t.Errorf("got %v, but want %v", correlationData.CorrelationRequestID, correlation_request_id) + } + + if correlationData.RequestTime.IsZero() { + t.Fatalf("correlationData.RequestTime was not initialized") + } + }) + } +}