Skip to content

Commit

Permalink
fix error marshalling issues (#235)
Browse files Browse the repository at this point in the history
Errors from AuthRelationshipResponse were being marshalled but due to
the error interface could not be unmarshalled. This changes the type to
be a new Errors type which can handle encoding and decoding errors
without resulting in unmarshalling errors.

The errors returned implement the error interface however they cannot be
directly compared due to the errors being dynamically generated.

Signed-off-by: Mike Mason <[email protected]>
  • Loading branch information
mikemrm authored Jun 14, 2024
1 parent e749c7e commit 39def70
Show file tree
Hide file tree
Showing 2 changed files with 195 additions and 1 deletion.
66 changes: 65 additions & 1 deletion events/message.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@ package events
import (
"context"
"encoding/json"
"errors"
"fmt"
"strings"
"time"

"go.opentelemetry.io/otel"
Expand Down Expand Up @@ -303,11 +305,73 @@ func (r AuthRelationshipRelation) Validate() error {
return err
}

// Errors contains one or more errors and handles marshalling the errors.
// See [Errors.MarshalJSON] and [Errors.UnmarshalJSON] for details on how marshalling is done.
type Errors []error

// MarshalJSON encodes a string of arrays with each errors Error string.
// Entries which are nil are skipped.
// If no non nil errors are provided, null is returned.
func (e Errors) MarshalJSON() ([]byte, error) {
errs := make([]string, 0, len(e))

for _, err := range e {
if err != nil {
errs = append(errs, err.Error())
}
}

if len(errs) == 0 {
return []byte("null"), nil
}

return json.Marshal(errs)
}

// UnmarshalJSON converts a list of string errors into new errors.
// All errors unmarshalled are new errors and cannot be compared directly to another error.
// Errors should be checked using string comparison.
func (e *Errors) UnmarshalJSON(b []byte) error {
var errs []string

if err := json.Unmarshal(b, &errs); err != nil {
return err
}

if len(errs) == 0 {
*e = nil

return nil
}

*e = make(Errors, len(errs))

for i, err := range errs {
(*e)[i] = errors.New(err) //nolint:goerr113 // errors are dynamically returned
}

return nil
}

// Error returns each error on a new line.
// Nil error are not included.
func (e Errors) Error() string {
errs := make([]string, 0, len(e))

for _, err := range e {
if err != nil {
errs = append(errs, err.Error())
}
}

return strings.Join(errs, "\n")
}

// AuthRelationshipResponse contains the data structure expected to be received from an AuthRelationshipRequest
// message to write or delete an auth relationship from PermissionsAPI
type AuthRelationshipResponse struct {
// Errors contains any errors, if empty the request was successful
Errors []error `json:"errors"`
Errors Errors `json:"errors"`
// TraceContext is a map of values used for OpenTelemetry context propagation.
TraceContext map[string]string `json:"traceContext"`
// TraceID is the ID of the trace for this event
Expand Down
130 changes: 130 additions & 0 deletions events/nats_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ package events_test

import (
"context"
"encoding/json"
"errors"
"os"
"testing"
"time"

Expand Down Expand Up @@ -193,6 +195,134 @@ func TestNATSRequestReply(t *testing.T) {

close(respGot)
}
func TestNATSRequestReplyMarshalling(t *testing.T) {
testCases := []struct {
name string
input events.AuthRelationshipResponse
expectDecoded map[string]any
expectResponse events.AuthRelationshipResponse
}{
{
"no error",
events.AuthRelationshipResponse{
TraceID: "some-id",
TraceContext: map[string]string{},
},
map[string]any{
"errors": nil,
"spanID": "",
"traceContext": map[string]any{},
"traceID": "some-id",
},
events.AuthRelationshipResponse{
TraceID: "some-id",
TraceContext: map[string]string{},
},
},
{
"with errors",
events.AuthRelationshipResponse{
TraceID: "some-id",
TraceContext: map[string]string{},
Errors: []error{
os.ErrInvalid,
},
},
map[string]any{
"errors": []any{
os.ErrInvalid.Error(),
},
"spanID": "",
"traceContext": map[string]any{},
"traceID": "some-id",
},
events.AuthRelationshipResponse{
TraceID: "some-id",
TraceContext: map[string]string{},
Errors: []error{
errors.New(os.ErrInvalid.Error()), //nolint:goerr113 // ensure equals same error with text
},
},
},
{
"nil errors skipped",
events.AuthRelationshipResponse{
TraceID: "some-id",
TraceContext: map[string]string{},
Errors: []error{
os.ErrInvalid,
nil,
os.ErrExist,
nil,
},
},
map[string]any{
"errors": []any{
os.ErrInvalid.Error(),
os.ErrExist.Error(),
},
"spanID": "",
"traceContext": map[string]any{},
"traceID": "some-id",
},
events.AuthRelationshipResponse{
TraceID: "some-id",
TraceContext: map[string]string{},
Errors: []error{
errors.New(os.ErrInvalid.Error()), //nolint:goerr113 // ensure equals same error with text
errors.New(os.ErrExist.Error()), //nolint:goerr113 // ensure equals same error with text
},
},
},
{
"all nil errors skipped",
events.AuthRelationshipResponse{
TraceID: "some-id",
TraceContext: map[string]string{},
Errors: []error{
nil,
nil,
},
},
map[string]any{
"errors": nil,
"spanID": "",
"traceContext": map[string]any{},
"traceID": "some-id",
},
events.AuthRelationshipResponse{
TraceID: "some-id",
TraceContext: map[string]string{},
Errors: nil,
},
},
}

for _, tc := range testCases {
tc := tc

t.Run(tc.name, func(t *testing.T) {
encoded, err := json.Marshal(tc.input)
require.NoError(t, err, "unexpected error marshalling input")

decoded := map[string]any{}

err = json.Unmarshal(encoded, &decoded)
require.NoError(t, err, "unexpected error unmarshalling encoded input into map")

assert.Equal(t, tc.expectDecoded, decoded, "unexpected encoded response")

var response events.AuthRelationshipResponse

err = json.Unmarshal(encoded, &response)
require.NoError(t, err, "unexpected error unmarshalling encoded input into response")

assert.Equal(t, tc.expectResponse, response, "unexpected response")

assert.Equal(t, len(tc.expectResponse.Errors), len(response.Errors), "unexpected response error count")
})
}
}

func getSingleMessage[T any](messages <-chan events.Message[T], timeout time.Duration) (events.Message[T], error) {
select {
Expand Down

0 comments on commit 39def70

Please sign in to comment.