Skip to content

Commit

Permalink
Merge pull request #282 from xmidt-org/feature/middleware-tests
Browse files Browse the repository at this point in the history
Feature/middleware tests
  • Loading branch information
johnabass authored Aug 19, 2024
2 parents dcb46b0 + 336432b commit 5927844
Show file tree
Hide file tree
Showing 13 changed files with 740 additions and 90 deletions.
10 changes: 10 additions & 0 deletions approver.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,16 @@ func (as Approvers[R]) Append(more ...Approver[R]) Approvers[R] {
return append(as, more...)
}

// AppendFunc is a closure variant of Append that makes working with
// approvers that are functions a little easier.
func (as Approvers[R]) AppendFunc(more ...ApproverFunc[R]) Approvers[R] {
for _, m := range more {
as = append(as, m)
}

return as
}

// Approve requires all approvers in this sequence to allow access. This
// method supplies a logical AND.
//
Expand Down
77 changes: 54 additions & 23 deletions approver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,32 +53,63 @@ func (suite *ApproversTestSuite) TestAuthorize() {
},
}

for _, testCase := range testCases {
suite.Run(testCase.name, func() {
var (
testCtx = suite.testContext()
testToken = suite.testToken()
as Approvers[string]
)
suite.Run("Append", func() {
for _, testCase := range testCases {
suite.Run(testCase.name, func() {
var (
testCtx = suite.testContext()
testToken = suite.testToken()
as Approvers[string]
)

for _, err := range testCase.results {
err := err
as = as.Append(
ApproverFunc[string](func(ctx context.Context, resource string, token Token) error {
suite.Same(testCtx, ctx)
suite.Equal(testToken, token)
suite.Equal(placeholderResource, resource)
return err
}),
for _, err := range testCase.results {
err := err
as = as.Append(
ApproverFunc[string](func(ctx context.Context, resource string, token Token) error {
suite.Same(testCtx, ctx)
suite.Equal(testToken, token)
suite.Equal(placeholderResource, resource)
return err
}),
)
}

suite.Equal(
testCase.expectedErr,
as.Approve(testCtx, placeholderResource, testToken),
)
})
}
})

suite.Run("AppendFunc", func() {
for _, testCase := range testCases {
suite.Run(testCase.name, func() {
var (
testCtx = suite.testContext()
testToken = suite.testToken()
as Approvers[string]
)
}

suite.Equal(
testCase.expectedErr,
as.Approve(testCtx, placeholderResource, testToken),
)
})
}
for _, err := range testCase.results {
err := err
as = as.AppendFunc(
func(ctx context.Context, resource string, token Token) error {
suite.Same(testCtx, ctx)
suite.Equal(testToken, token)
suite.Equal(placeholderResource, resource)
return err
},
)
}

suite.Equal(
testCase.expectedErr,
as.Approve(testCtx, placeholderResource, testToken),
)
})
}
})
}

func (suite *ApproversTestSuite) TestAny() {
Expand Down
11 changes: 11 additions & 0 deletions authorizer.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,17 @@ func WithApprovers[R any](more ...Approver[R]) AuthorizerOption[R] {
)
}

// WithApproverFuncs is a closure variant of WithApprovers that eases the
// syntactical pain of dealing with approvers that are functions.
func WithApproverFuncs[R any](more ...ApproverFunc[R]) AuthorizerOption[R] {
return authorizerOptionFunc[R](
func(a *Authorizer[R]) error {
a.approvers = a.approvers.AppendFunc(more...)
return nil
},
)
}

// NewAuthorizer constructs an Authorizer workflow using the supplied options.
//
// If no options are supplied, the returned Authorizer will authorize all tokens
Expand Down
5 changes: 5 additions & 0 deletions authorizer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,12 +61,14 @@ func (suite *AuthorizerTestSuite) TestFullSuccess() {

approver1 = new(mockApprover[string])
approver2 = new(mockApprover[string])
approver3 = new(mockApprover[string])

listener1 = new(mockAuthorizeListener[string])
listener2 = new(mockAuthorizeListener[string])

a = suite.newAuthorizer(
WithApprovers(approver1, approver2),
WithApproverFuncs(approver3.Approve),
WithAuthorizeListeners(listener1),
WithAuthorizeListenerFuncs(listener2.OnEvent),
)
Expand All @@ -76,6 +78,8 @@ func (suite *AuthorizerTestSuite) TestFullSuccess() {
Return(nil).Once()
approver2.ExpectApprove(expectedCtx, expectedResource, expectedToken).
Return(nil).Once()
approver3.ExpectApprove(expectedCtx, expectedResource, expectedToken).
Return(nil).Once()

listener1.ExpectOnEvent(AuthorizeEvent[string]{
Resource: expectedResource,
Expand All @@ -96,6 +100,7 @@ func (suite *AuthorizerTestSuite) TestFullSuccess() {
listener2.AssertExpectations(suite.T())
approver1.AssertExpectations(suite.T())
approver2.AssertExpectations(suite.T())
approver3.AssertExpectations(suite.T())
}

func (suite *AuthorizerTestSuite) TestFullFirstApproverFail() {
Expand Down
14 changes: 14 additions & 0 deletions basculehttp/authorization.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,13 +56,18 @@ type authorizationParserOptionFunc func(*AuthorizationParser) error

func (apof authorizationParserOptionFunc) apply(ap *AuthorizationParser) error { return apof(ap) }

// WithAuthorizationHeader changes the name of the header holding the token. By default,
// the header used is DefaultAuthorizationHeader.
func WithAuthorizationHeader(header string) AuthorizationParserOption {
return authorizationParserOptionFunc(func(ap *AuthorizationParser) error {
ap.header = header
return nil
})
}

// WithScheme registers a string-based token parser that handles a
// specific authorization scheme. Invocations to this option are cumulative
// and will overwrite any existing registration.
func WithScheme(scheme Scheme, parser bascule.TokenParser[string]) AuthorizationParserOption {
return authorizationParserOptionFunc(func(ap *AuthorizationParser) error {
// we want case-insensitive matches, so lowercase everything
Expand All @@ -71,11 +76,20 @@ func WithScheme(scheme Scheme, parser bascule.TokenParser[string]) Authorization
})
}

// WithBasic is a shorthand for WithScheme that registers basic token parsing using
// the default scheme.
func WithBasic() AuthorizationParserOption {
return WithScheme(SchemeBasic, BasicTokenParser{})
}

// AuthorizationParsers is a bascule.TokenParser that handles the Authorization header.
type AuthorizationParser struct {
header string
parsers map[Scheme]bascule.TokenParser[string]
}

// NewAuthorizationParser constructs an Authorization parser from a set
// of configuration options.
func NewAuthorizationParser(opts ...AuthorizationParserOption) (*AuthorizationParser, error) {
ap := &AuthorizationParser{
parsers: make(map[Scheme]bascule.TokenParser[string]),
Expand Down
3 changes: 0 additions & 3 deletions basculehttp/basic.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,6 @@ func (bt basicToken) Password() string {

// BasicTokenParser is a string-based bascule.TokenParser that produces
// BasicToken instances from strings.
//
// An instance of this parser may be passed to WithScheme in order to
// configure an AuthorizationParser.
type BasicTokenParser struct{}

// Parse assumes that value is of the format required by https://datatracker.ietf.org/doc/html/rfc7617.
Expand Down
6 changes: 3 additions & 3 deletions basculehttp/challenge.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@ import (
)

const (
// WWWAuthenticateHeaderName is the HTTP header used for StatusUnauthorized challenges
// WWWAuthenticateHeader is the HTTP header used for StatusUnauthorized challenges
// when encountered by the Middleware.
//
// This value is used by default when no header is supplied to Challenges.WriteHeader.
WWWAuthenticateHeaderName = "WWW-Authenticate"
WWWAuthenticateHeader = "WWW-Authenticate"
)

var (
Expand Down Expand Up @@ -216,7 +216,7 @@ func (chs Challenges) Append(ch ...Challenge) Challenges {
// halted and that error is returned.
func (chs Challenges) WriteHeader(name string, h http.Header) error {
if len(name) == 0 {
name = WWWAuthenticateHeaderName
name = WWWAuthenticateHeader
}

var o strings.Builder
Expand Down
2 changes: 1 addition & 1 deletion basculehttp/challenge_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ func (suite *ChallengeTestSuite) testChallengesValid() {
suite.Run("DefaultHeader", func() {
header := make(http.Header)
suite.NoError(testCase.challenges.WriteHeader("", header))
suite.ElementsMatch(testCase.expected, header.Values(WWWAuthenticateHeaderName))
suite.ElementsMatch(testCase.expected, header.Values(WWWAuthenticateHeader))
})

suite.Run("CustomHeader", func() {
Expand Down
50 changes: 12 additions & 38 deletions basculehttp/error.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@
package basculehttp

import (
"encoding"
"encoding/json"
"errors"
"net/http"

Expand Down Expand Up @@ -39,14 +37,16 @@ func DefaultErrorStatusCoder(_ *http.Request, err error) int {
var sc statusCoder

switch {
// check if it's a status coder first, so that we can
// override status codes for built-in errors.
case errors.As(err, &sc):
return sc.StatusCode()

case errors.Is(err, bascule.ErrMissingCredentials):
return http.StatusUnauthorized

case errors.Is(err, bascule.ErrInvalidCredentials):
return http.StatusBadRequest

case errors.As(err, &sc):
return sc.StatusCode()
}

return 0
Expand All @@ -56,40 +56,10 @@ func DefaultErrorStatusCoder(_ *http.Request, err error) int {
// be used in an HTTP response body.
type ErrorMarshaler func(request *http.Request, err error) (contentType string, content []byte, marshalErr error)

// DefaultErrorMarshaler examines the error for several standard marshalers. The supported marshalers
// together with the returned content types are as follows, in order:
//
// - json.Marshaler "application/json"
// - encoding.TextMarshaler "text/plain; charset=utf-8"
// - encoding.BinaryMarshaler "application/octet-stream"
//
// If the error or any of its wrapped errors does not implement a supported marshaler interface,
// the error's Error() text is used with a content type of "text/plain; charset=utf-8".
// DefaultErrorMarshaler returns a plaintext representation of the error.
func DefaultErrorMarshaler(_ *http.Request, err error) (contentType string, content []byte, marshalErr error) {
// walk the wrapped errors manually, since that's way more efficient
// that walking the error tree once for each desired type
for wrapped := err; wrapped != nil && len(content) == 0 && marshalErr == nil; wrapped = errors.Unwrap(wrapped) {
switch m := wrapped.(type) { //nolint: errorlint
case json.Marshaler:
contentType = "application/json"
content, marshalErr = m.MarshalJSON()

case encoding.TextMarshaler:
contentType = "text/plain; charset=utf-8"
content, marshalErr = m.MarshalText()

case encoding.BinaryMarshaler:
contentType = "application/octet-stream"
content, marshalErr = m.MarshalBinary()
}
}

if len(content) == 0 && marshalErr == nil {
// fallback
contentType = "text/plain; charset=utf-8"
content = []byte(err.Error())
}

contentType = "text/plain; charset=utf-8"
content = []byte(err.Error())
return
}

Expand All @@ -98,6 +68,10 @@ type statusCodeError struct {
statusCode int
}

func (err *statusCodeError) Unwrap() error {
return err.error
}

func (err *statusCodeError) StatusCode() int {
return err.statusCode
}
Expand Down
Loading

0 comments on commit 5927844

Please sign in to comment.