From 39def7048f31f2527349bbfd357807a93d7e625d Mon Sep 17 00:00:00 2001 From: Mike Mason Date: Fri, 14 Jun 2024 14:49:44 -0500 Subject: [PATCH] fix error marshalling issues (#235) Errors from AuthRelationshipResponse were being marshalled but due to the error interface could not be unmarshalled. This changes the type to be a new Errors type which can handle encoding and decoding errors without resulting in unmarshalling errors. The errors returned implement the error interface however they cannot be directly compared due to the errors being dynamically generated. Signed-off-by: Mike Mason --- events/message.go | 66 +++++++++++++++++++++- events/nats_test.go | 130 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 195 insertions(+), 1 deletion(-) diff --git a/events/message.go b/events/message.go index db547aa..7a1fc73 100644 --- a/events/message.go +++ b/events/message.go @@ -18,7 +18,9 @@ package events import ( "context" "encoding/json" + "errors" "fmt" + "strings" "time" "go.opentelemetry.io/otel" @@ -303,11 +305,73 @@ func (r AuthRelationshipRelation) Validate() error { return err } +// Errors contains one or more errors and handles marshalling the errors. +// See [Errors.MarshalJSON] and [Errors.UnmarshalJSON] for details on how marshalling is done. +type Errors []error + +// MarshalJSON encodes a string of arrays with each errors Error string. +// Entries which are nil are skipped. +// If no non nil errors are provided, null is returned. +func (e Errors) MarshalJSON() ([]byte, error) { + errs := make([]string, 0, len(e)) + + for _, err := range e { + if err != nil { + errs = append(errs, err.Error()) + } + } + + if len(errs) == 0 { + return []byte("null"), nil + } + + return json.Marshal(errs) +} + +// UnmarshalJSON converts a list of string errors into new errors. +// All errors unmarshalled are new errors and cannot be compared directly to another error. +// Errors should be checked using string comparison. +func (e *Errors) UnmarshalJSON(b []byte) error { + var errs []string + + if err := json.Unmarshal(b, &errs); err != nil { + return err + } + + if len(errs) == 0 { + *e = nil + + return nil + } + + *e = make(Errors, len(errs)) + + for i, err := range errs { + (*e)[i] = errors.New(err) //nolint:goerr113 // errors are dynamically returned + } + + return nil +} + +// Error returns each error on a new line. +// Nil error are not included. +func (e Errors) Error() string { + errs := make([]string, 0, len(e)) + + for _, err := range e { + if err != nil { + errs = append(errs, err.Error()) + } + } + + return strings.Join(errs, "\n") +} + // AuthRelationshipResponse contains the data structure expected to be received from an AuthRelationshipRequest // message to write or delete an auth relationship from PermissionsAPI type AuthRelationshipResponse struct { // Errors contains any errors, if empty the request was successful - Errors []error `json:"errors"` + Errors Errors `json:"errors"` // TraceContext is a map of values used for OpenTelemetry context propagation. TraceContext map[string]string `json:"traceContext"` // TraceID is the ID of the trace for this event diff --git a/events/nats_test.go b/events/nats_test.go index 13334a7..412e1c7 100644 --- a/events/nats_test.go +++ b/events/nats_test.go @@ -2,7 +2,9 @@ package events_test import ( "context" + "encoding/json" "errors" + "os" "testing" "time" @@ -193,6 +195,134 @@ func TestNATSRequestReply(t *testing.T) { close(respGot) } +func TestNATSRequestReplyMarshalling(t *testing.T) { + testCases := []struct { + name string + input events.AuthRelationshipResponse + expectDecoded map[string]any + expectResponse events.AuthRelationshipResponse + }{ + { + "no error", + events.AuthRelationshipResponse{ + TraceID: "some-id", + TraceContext: map[string]string{}, + }, + map[string]any{ + "errors": nil, + "spanID": "", + "traceContext": map[string]any{}, + "traceID": "some-id", + }, + events.AuthRelationshipResponse{ + TraceID: "some-id", + TraceContext: map[string]string{}, + }, + }, + { + "with errors", + events.AuthRelationshipResponse{ + TraceID: "some-id", + TraceContext: map[string]string{}, + Errors: []error{ + os.ErrInvalid, + }, + }, + map[string]any{ + "errors": []any{ + os.ErrInvalid.Error(), + }, + "spanID": "", + "traceContext": map[string]any{}, + "traceID": "some-id", + }, + events.AuthRelationshipResponse{ + TraceID: "some-id", + TraceContext: map[string]string{}, + Errors: []error{ + errors.New(os.ErrInvalid.Error()), //nolint:goerr113 // ensure equals same error with text + }, + }, + }, + { + "nil errors skipped", + events.AuthRelationshipResponse{ + TraceID: "some-id", + TraceContext: map[string]string{}, + Errors: []error{ + os.ErrInvalid, + nil, + os.ErrExist, + nil, + }, + }, + map[string]any{ + "errors": []any{ + os.ErrInvalid.Error(), + os.ErrExist.Error(), + }, + "spanID": "", + "traceContext": map[string]any{}, + "traceID": "some-id", + }, + events.AuthRelationshipResponse{ + TraceID: "some-id", + TraceContext: map[string]string{}, + Errors: []error{ + errors.New(os.ErrInvalid.Error()), //nolint:goerr113 // ensure equals same error with text + errors.New(os.ErrExist.Error()), //nolint:goerr113 // ensure equals same error with text + }, + }, + }, + { + "all nil errors skipped", + events.AuthRelationshipResponse{ + TraceID: "some-id", + TraceContext: map[string]string{}, + Errors: []error{ + nil, + nil, + }, + }, + map[string]any{ + "errors": nil, + "spanID": "", + "traceContext": map[string]any{}, + "traceID": "some-id", + }, + events.AuthRelationshipResponse{ + TraceID: "some-id", + TraceContext: map[string]string{}, + Errors: nil, + }, + }, + } + + for _, tc := range testCases { + tc := tc + + t.Run(tc.name, func(t *testing.T) { + encoded, err := json.Marshal(tc.input) + require.NoError(t, err, "unexpected error marshalling input") + + decoded := map[string]any{} + + err = json.Unmarshal(encoded, &decoded) + require.NoError(t, err, "unexpected error unmarshalling encoded input into map") + + assert.Equal(t, tc.expectDecoded, decoded, "unexpected encoded response") + + var response events.AuthRelationshipResponse + + err = json.Unmarshal(encoded, &response) + require.NoError(t, err, "unexpected error unmarshalling encoded input into response") + + assert.Equal(t, tc.expectResponse, response, "unexpected response") + + assert.Equal(t, len(tc.expectResponse.Errors), len(response.Errors), "unexpected response error count") + }) + } +} func getSingleMessage[T any](messages <-chan events.Message[T], timeout time.Duration) (events.Message[T], error) { select {