diff --git a/authorizer_test.go b/authorizer_test.go index 3f5c564..a878a55 100644 --- a/authorizer_test.go +++ b/authorizer_test.go @@ -66,7 +66,7 @@ func (suite *AuthorizersTestSuite) TestAuthorize() { as = as.Append( AuthorizerFunc[string](func(ctx context.Context, resource string, token Token) error { suite.Same(testCtx, ctx) - suite.Same(testToken, token) + suite.Equal(testToken, token) suite.Equal(placeholderResource, resource) return err }), @@ -126,7 +126,7 @@ func (suite *AuthorizersTestSuite) TestAny() { as = as.Append( AuthorizerFunc[string](func(ctx context.Context, resource string, token Token) error { suite.Same(testCtx, ctx) - suite.Same(testToken, token) + suite.Equal(testToken, token) suite.Equal(placeholderResource, resource) return err }), diff --git a/mocks_test.go b/mocks_test.go new file mode 100644 index 0000000..00d0e3b --- /dev/null +++ b/mocks_test.go @@ -0,0 +1,36 @@ +// SPDX-FileCopyrightText: 2020 Comcast Cable Communications Management, LLC +// SPDX-License-Identifier: Apache-2.0 + +package bascule + +import ( + "context" + + "github.com/stretchr/testify/mock" +) + +type testToken string + +func (tt testToken) Principal() string { return string(tt) } + +type mockValidator[S any] struct { + mock.Mock +} + +func (m *mockValidator[S]) Validate(ctx context.Context, source S, token Token) (Token, error) { + args := m.Called(ctx, source, token) + t, _ := args.Get(0).(Token) + return t, args.Error(1) +} + +func (m *mockValidator[S]) ExpectValidate(ctx context.Context, source S, token Token) *mock.Call { + return m.On("Validate", ctx, source, token) +} + +func assertValidators[S any](t mock.TestingT, vs ...Validator[S]) (passed bool) { + for _, v := range vs { + passed = v.(*mockValidator[S]).AssertExpectations(t) && passed + } + + return +} diff --git a/testSuite_test.go b/testSuite_test.go index 554ef02..987cf63 100644 --- a/testSuite_test.go +++ b/testSuite_test.go @@ -14,14 +14,6 @@ const ( testScheme Scheme = "Test" ) -type testToken struct { - principal string -} - -func (tt *testToken) Principal() string { - return tt.principal -} - // TestSuite holds generally useful functionality for testing bascule. type TestSuite struct { suite.Suite @@ -43,9 +35,7 @@ func (suite *TestSuite) testCredentials() Credentials { } func (suite *TestSuite) testToken() Token { - return &testToken{ - principal: "test", - } + return testToken("test") } func (suite *TestSuite) contexter(ctx context.Context) Contexter { diff --git a/validator.go b/validator.go index 5aaa32b..c0bea98 100644 --- a/validator.go +++ b/validator.go @@ -84,15 +84,20 @@ type ValidatorFunc[S any] interface { // and uncurry a closure. type validatorFunc[S any] func(context.Context, S, Token) (Token, error) -func (vf validatorFunc[S]) Validate(ctx context.Context, source S, t Token) (Token, error) { - return vf(ctx, source, t) +func (vf validatorFunc[S]) Validate(ctx context.Context, source S, t Token) (next Token, err error) { + next, err = vf(ctx, source, t) + if next == nil { + next = t + } + + return } var ( - tokenReturnsError = reflect.TypeOf((func(Token) error)(nil)) - tokenReturnsTokenError = reflect.TypeOf((func(Token) (Token, error))(nil)) - contextTokenReturnsError = reflect.TypeOf((func(context.Context, Token) error)(nil)) - contextTokenReturnsTokenError = reflect.TypeOf((func(context.Context, Token) (Token, error))(nil)) + tokenReturnError = reflect.TypeOf((func(Token) error)(nil)) + tokenReturnTokenAndError = reflect.TypeOf((func(Token) (Token, error))(nil)) + contextTokenReturnError = reflect.TypeOf((func(context.Context, Token) error)(nil)) + contextTokenReturnTokenError = reflect.TypeOf((func(context.Context, Token) (Token, error))(nil)) ) // asValidatorSimple tries simple conversions on f. This function will not catch @@ -102,14 +107,14 @@ func asValidatorSimple[S any, F ValidatorFunc[S]](f F) (v Validator[S]) { case func(Token) error: v = validatorFunc[S]( func(ctx context.Context, source S, t Token) (Token, error) { - return nil, vf(t) + return t, vf(t) }, ) case func(S, Token) error: v = validatorFunc[S]( func(ctx context.Context, source S, t Token) (Token, error) { - return nil, vf(source, t) + return t, vf(source, t) }, ) @@ -140,14 +145,14 @@ func asValidatorSimple[S any, F ValidatorFunc[S]](f F) (v Validator[S]) { case func(context.Context, Token) error: v = validatorFunc[S]( func(ctx context.Context, source S, t Token) (Token, error) { - return nil, vf(ctx, t) + return t, vf(ctx, t) }, ) case func(context.Context, S, Token) error: v = validatorFunc[S]( func(ctx context.Context, source S, t Token) (Token, error) { - return nil, vf(ctx, source, t) + return t, vf(ctx, source, t) }, ) @@ -171,7 +176,8 @@ func asValidatorSimple[S any, F ValidatorFunc[S]](f F) (v Validator[S]) { } // AsValidator takes a ValidatorFunc closure and returns a Validator instance that -// executes that closure. +// executes that closure. This function can also convert custom types which can +// be converted to any of the closure signatures. func AsValidator[S any, F ValidatorFunc[S]](f F) Validator[S] { // first, try the simple way: if v := asValidatorSimple[S](f); v != nil { @@ -182,24 +188,24 @@ func AsValidator[S any, F ValidatorFunc[S]](f F) Validator[S] { // require the source type. fVal := reflect.ValueOf(f) switch { - case fVal.CanConvert(tokenReturnsError): + case fVal.CanConvert(tokenReturnError): return asValidatorSimple[S]( - fVal.Convert(tokenReturnsError).Interface().(func(Token) error), + fVal.Convert(tokenReturnError).Interface().(func(Token) error), ) - case fVal.CanConvert(tokenReturnsTokenError): + case fVal.CanConvert(tokenReturnTokenAndError): return asValidatorSimple[S]( - fVal.Convert(tokenReturnsError).Interface().(func(Token) (Token, error)), + fVal.Convert(tokenReturnTokenAndError).Interface().(func(Token) (Token, error)), ) - case fVal.CanConvert(contextTokenReturnsError): + case fVal.CanConvert(contextTokenReturnError): return asValidatorSimple[S]( - fVal.Convert(contextTokenReturnsError).Interface().(func(context.Context, Token) error), + fVal.Convert(contextTokenReturnError).Interface().(func(context.Context, Token) error), ) - case fVal.CanConvert(contextTokenReturnsTokenError): + case fVal.CanConvert(contextTokenReturnTokenError): return asValidatorSimple[S]( - fVal.Convert(contextTokenReturnsError).Interface().(func(context.Context, Token) (Token, error)), + fVal.Convert(contextTokenReturnTokenError).Interface().(func(context.Context, Token) (Token, error)), ) } @@ -219,6 +225,7 @@ func AsValidator[S any, F ValidatorFunc[S]](f F) Validator[S] { ) } else { // we know this can be converted to this final type + ft := reflect.TypeOf((func(context.Context, S, Token) (Token, error))(nil)) return asValidatorSimple[S]( fVal.Convert(ft).Interface().(func(context.Context, S, Token) (Token, error)), ) diff --git a/validator_test.go b/validator_test.go index 94fd939..af45572 100644 --- a/validator_test.go +++ b/validator_test.go @@ -4,6 +4,9 @@ package bascule import ( + "context" + "errors" + "fmt" "testing" "github.com/stretchr/testify/suite" @@ -11,6 +14,357 @@ import ( type ValidatorsTestSuite struct { TestSuite + + expectedCtx context.Context + expectedSource int + inputToken Token + outputToken Token + expectedErr error +} + +func (suite *ValidatorsTestSuite) SetupSuite() { + type contextKey struct{} + suite.expectedCtx = context.WithValue( + context.Background(), + contextKey{}, + "value", + ) + + suite.expectedSource = 123 + suite.inputToken = testToken("input token") + suite.outputToken = testToken("output token") + suite.expectedErr = errors.New("expected validator error") +} + +// assertNoTransform verifies that the validator returns the same token as the input token. +func (suite *ValidatorsTestSuite) assertNoTransform(v Validator[int]) { + suite.Require().NotNil(v) + actualToken, actualErr := v.Validate(suite.expectedCtx, suite.expectedSource, suite.inputToken) + suite.Equal(suite.inputToken, actualToken) + suite.ErrorIs(suite.expectedErr, actualErr) +} + +// assertTransform verifies a validator that returns a different token than the input token. +func (suite *ValidatorsTestSuite) assertTransform(v Validator[int]) { + suite.Require().NotNil(v) + actualToken, actualErr := v.Validate(suite.expectedCtx, suite.expectedSource, suite.inputToken) + suite.Equal(suite.outputToken, actualToken) + suite.ErrorIs(suite.expectedErr, actualErr) +} + +// validateToken is a ValidatorFunc of the signature func(Token) error +func (suite *ValidatorsTestSuite) validateToken(actualToken Token) error { + suite.Equal(suite.inputToken, actualToken) + return suite.expectedErr +} + +// validateSourceToken is a ValidatorFunc of the signature func(, Token) error +func (suite *ValidatorsTestSuite) validateSourceToken(actualSource int, actualToken Token) error { + suite.Equal(suite.expectedSource, actualSource) + suite.Equal(suite.inputToken, actualToken) + return suite.expectedErr +} + +// validateContextToken is a ValidatorFunc of the signature func(context.Context, Token) error +func (suite *ValidatorsTestSuite) validateContextToken(actualCtx context.Context, actualToken Token) error { + suite.Equal(suite.expectedCtx, actualCtx) + suite.Equal(suite.inputToken, actualToken) + return suite.expectedErr +} + +// validateContextSourceToken is a ValidatorFunc of the signature func(context.Context, , Token) error +func (suite *ValidatorsTestSuite) validateContextSourceToken(actualCtx context.Context, actualSource int, actualToken Token) error { + suite.Equal(suite.expectedCtx, actualCtx) + suite.Equal(suite.expectedSource, actualSource) + suite.Equal(suite.inputToken, actualToken) + return suite.expectedErr +} + +// transformToken is a ValidatorFunc of the signature func(Token) (Token, error). +// This variant returns suite.outputToken. +func (suite *ValidatorsTestSuite) transformToken(actualToken Token) (Token, error) { + suite.Equal(suite.inputToken, actualToken) + return suite.outputToken, suite.expectedErr +} + +// transformTokenToNil is a ValidatorFunc of the signature func(Token) (Token, error). +// This variant returns a nil Token, indicating that the original token is unchanged. +func (suite *ValidatorsTestSuite) transformTokenToNil(actualToken Token) (Token, error) { + suite.Equal(suite.inputToken, actualToken) + return nil, suite.expectedErr +} + +// transformSourceToken is a ValidatorFunc of the signature func(, Token) (Token, error) +// This variant returns suite.outputToken. +func (suite *ValidatorsTestSuite) transformSourceToken(actualSource int, actualToken Token) (Token, error) { + suite.Equal(suite.expectedSource, actualSource) + suite.Equal(suite.inputToken, actualToken) + return suite.outputToken, suite.expectedErr +} + +// transformSourceTokenToNil is a ValidatorFunc of the signature func(, Token) (Token, error) +// This variant returns a nil Token, indicating that the original token is unchanged. +func (suite *ValidatorsTestSuite) transformSourceTokenToNil(actualSource int, actualToken Token) (Token, error) { + suite.Equal(suite.expectedSource, actualSource) + suite.Equal(suite.inputToken, actualToken) + return nil, suite.expectedErr +} + +// transformContextToken is a ValidatorFunc of the signature func(context.context, Token) (Token, error) +// This variant returns suite.outputToken. +func (suite *ValidatorsTestSuite) transformContextToken(actualCtx context.Context, actualToken Token) (Token, error) { + suite.Equal(suite.expectedCtx, actualCtx) + suite.Equal(suite.inputToken, actualToken) + return suite.outputToken, suite.expectedErr +} + +// transformContextTokenToNil is a ValidatorFunc of the signature func(context.context, Token) (Token, error) +// This variant returns a nil Token, indicating that the original token is unchanged. +func (suite *ValidatorsTestSuite) transformContextTokenToNil(actualCtx context.Context, actualToken Token) (Token, error) { + suite.Equal(suite.expectedCtx, actualCtx) + suite.Equal(suite.inputToken, actualToken) + return nil, suite.expectedErr +} + +// transformContextSourceToken is a ValidatorFunc of the signature func(context.Context, , Token) (Token, error) +// This variant returns suite.outputToken. +func (suite *ValidatorsTestSuite) transformContextSourceToken(actualCtx context.Context, actualSource int, actualToken Token) (Token, error) { + suite.Equal(suite.expectedCtx, actualCtx) + suite.Equal(suite.expectedSource, actualSource) + suite.Equal(suite.inputToken, actualToken) + return suite.outputToken, suite.expectedErr +} + +// transformContextSourceToken is a ValidatorFunc of the signature func(context.Context, , Token) (Token, error) +// This variant returns a nil Token, indicating that the original token is unchanged. +func (suite *ValidatorsTestSuite) transformContextSourceTokenToNil(actualCtx context.Context, actualSource int, actualToken Token) (Token, error) { + suite.Equal(suite.expectedCtx, actualCtx) + suite.Equal(suite.expectedSource, actualSource) + suite.Equal(suite.inputToken, actualToken) + return nil, suite.expectedErr +} + +func (suite *ValidatorsTestSuite) testAsValidatorToken() { + suite.Run("ReturnError", func() { + suite.Run("Simple", func() { + v := AsValidator[int](suite.validateToken) + suite.assertNoTransform(v) + }) + + suite.Run("CustomType", func() { + type Custom func(Token) error + f := Custom(suite.validateToken) + v := AsValidator[int](f) + suite.assertNoTransform(v) + }) + }) + + suite.Run("ReturnTokenError", func() { + suite.Run("Simple", func() { + v := AsValidator[int](suite.transformToken) + suite.assertTransform(v) + }) + + suite.Run("NilOutputToken", func() { + v := AsValidator[int](suite.transformTokenToNil) + suite.assertNoTransform(v) + }) + + suite.Run("CustomType", func() { + type Custom func(Token) (Token, error) + f := Custom(suite.transformToken) + v := AsValidator[int](f) + suite.assertTransform(v) + }) + }) +} + +func (suite *ValidatorsTestSuite) testAsValidatorSourceToken() { + suite.Run("ReturnError", func() { + suite.Run("Simple", func() { + v := AsValidator[int](suite.validateSourceToken) + suite.assertNoTransform(v) + }) + + suite.Run("CustomType", func() { + type Custom func(int, Token) error + f := Custom(suite.validateSourceToken) + v := AsValidator[int](f) + suite.assertNoTransform(v) + }) + }) + + suite.Run("ReturnTokenError", func() { + suite.Run("Simple", func() { + v := AsValidator[int](suite.transformSourceToken) + suite.assertTransform(v) + }) + + suite.Run("NilOutputToken", func() { + v := AsValidator[int](suite.transformSourceTokenToNil) + suite.assertNoTransform(v) + }) + + suite.Run("CustomType", func() { + type Custom func(int, Token) (Token, error) + f := Custom(suite.transformSourceToken) + v := AsValidator[int](f) + suite.assertTransform(v) + }) + }) +} + +func (suite *ValidatorsTestSuite) testAsValidatorContextToken() { + suite.Run("ReturnError", func() { + suite.Run("Simple", func() { + v := AsValidator[int](suite.validateContextToken) + suite.assertNoTransform(v) + }) + + suite.Run("CustomType", func() { + type Custom func(context.Context, Token) error + f := Custom(suite.validateContextToken) + v := AsValidator[int](f) + suite.assertNoTransform(v) + }) + }) + + suite.Run("ReturnTokenError", func() { + suite.Run("Simple", func() { + v := AsValidator[int](suite.transformContextToken) + suite.assertTransform(v) + }) + + suite.Run("NilOutputToken", func() { + v := AsValidator[int](suite.transformContextTokenToNil) + suite.assertNoTransform(v) + }) + + suite.Run("CustomType", func() { + type Custom func(context.Context, Token) (Token, error) + f := Custom(suite.transformContextToken) + v := AsValidator[int](f) + suite.assertTransform(v) + }) + }) +} + +func (suite *ValidatorsTestSuite) testAsValidatorContextSourceToken() { + suite.Run("ReturnError", func() { + suite.Run("Simple", func() { + v := AsValidator[int](suite.validateContextSourceToken) + suite.assertNoTransform(v) + }) + + suite.Run("CustomType", func() { + type Custom func(context.Context, int, Token) error + f := Custom(suite.validateContextSourceToken) + v := AsValidator[int](f) + suite.assertNoTransform(v) + }) + }) + + suite.Run("ReturnTokenError", func() { + suite.Run("Simple", func() { + v := AsValidator[int](suite.transformContextSourceToken) + suite.assertTransform(v) + }) + + suite.Run("NilOutputToken", func() { + v := AsValidator[int](suite.transformContextSourceTokenToNil) + suite.assertNoTransform(v) + }) + + suite.Run("CustomType", func() { + type Custom func(context.Context, int, Token) (Token, error) + f := Custom(suite.transformContextSourceToken) + v := AsValidator[int](f) + suite.assertTransform(v) + }) + }) +} + +func (suite *ValidatorsTestSuite) TestAsValidator() { + suite.Run("Token", suite.testAsValidatorToken) + suite.Run("SourceToken", suite.testAsValidatorSourceToken) + suite.Run("ContextToken", suite.testAsValidatorContextToken) + suite.Run("ContextSourceToken", suite.testAsValidatorContextSourceToken) +} + +// newValidators constructs an array of validators that can only be called once +// and which successfully validate the suite's input token. +func (suite *ValidatorsTestSuite) newValidators(count int) (vs []Validator[int]) { + vs = make([]Validator[int], 0, count) + for len(vs) < cap(vs) { + v := new(mockValidator[int]) + v.ExpectValidate(suite.expectedCtx, suite.expectedSource, suite.inputToken). + Return(nil, nil).Once() + + vs = append(vs, v) + } + + return +} + +func (suite *ValidatorsTestSuite) TestValidate() { + suite.Run("NoValidators", func() { + outputToken, err := Validate[int](suite.expectedCtx, suite.expectedSource, suite.inputToken) + suite.Equal(suite.inputToken, outputToken) + suite.NoError(err) + }) + + suite.Run("NilOutputToken", func() { + for _, count := range []int{1, 2, 5} { + suite.Run(fmt.Sprintf("count=%d", count), func() { + vs := suite.newValidators(count) + actualToken, actualErr := Validate(suite.expectedCtx, suite.expectedSource, suite.inputToken, vs...) + suite.Equal(suite.inputToken, actualToken) + suite.NoError(actualErr) + assertValidators(suite.T(), vs...) + }) + } + }) +} + +func (suite *ValidatorsTestSuite) TestCompositeValidators() { + suite.Run("Empty", func() { + var vs Validators[int] + outputToken, err := vs.Validate(suite.expectedCtx, suite.expectedSource, suite.inputToken) + suite.Equal(suite.inputToken, outputToken) + suite.NoError(err) + }) + + suite.Run("NotEmpty", func() { + suite.Run("len=1", func() { + v := new(mockValidator[int]) + v.ExpectValidate(suite.expectedCtx, suite.expectedSource, suite.inputToken). + Return(suite.outputToken, nil).Once() + + var vs Validators[int] + vs = vs.Append(v) + actualToken, actualErr := vs.Validate(suite.expectedCtx, suite.expectedSource, suite.inputToken) + suite.Equal(suite.outputToken, actualToken) + suite.NoError(actualErr) + assertValidators(suite.T(), v) + }) + + suite.Run("len=2", func() { + v1 := new(mockValidator[int]) + v1.ExpectValidate(suite.expectedCtx, suite.expectedSource, suite.inputToken). + Return(nil, nil).Once() + + v2 := new(mockValidator[int]) + v2.ExpectValidate(suite.expectedCtx, suite.expectedSource, suite.inputToken). + Return(nil, nil).Once() + + var vs Validators[int] + vs = vs.Append(v1, v2) + actualToken, actualErr := vs.Validate(suite.expectedCtx, suite.expectedSource, suite.inputToken) + suite.Equal(suite.inputToken, actualToken) // the token should be unchanged + suite.NoError(actualErr) + assertValidators(suite.T(), v1, v2) + }) + }) } func TestValidators(t *testing.T) {