Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/middleware tests #282

Merged
merged 5 commits into from
Aug 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions approver.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,16 @@ func (as Approvers[R]) Append(more ...Approver[R]) Approvers[R] {
return append(as, more...)
}

// AppendFunc is a closure variant of Append that makes working with
// approvers that are functions a little easier.
func (as Approvers[R]) AppendFunc(more ...ApproverFunc[R]) Approvers[R] {
for _, m := range more {
as = append(as, m)
}

return as
}

// Approve requires all approvers in this sequence to allow access. This
// method supplies a logical AND.
//
Expand Down
77 changes: 54 additions & 23 deletions approver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,32 +53,63 @@ func (suite *ApproversTestSuite) TestAuthorize() {
},
}

for _, testCase := range testCases {
suite.Run(testCase.name, func() {
var (
testCtx = suite.testContext()
testToken = suite.testToken()
as Approvers[string]
)
suite.Run("Append", func() {
for _, testCase := range testCases {
suite.Run(testCase.name, func() {
var (
testCtx = suite.testContext()
testToken = suite.testToken()
as Approvers[string]
)

for _, err := range testCase.results {
err := err
as = as.Append(
ApproverFunc[string](func(ctx context.Context, resource string, token Token) error {
suite.Same(testCtx, ctx)
suite.Equal(testToken, token)
suite.Equal(placeholderResource, resource)
return err
}),
for _, err := range testCase.results {
err := err
as = as.Append(
ApproverFunc[string](func(ctx context.Context, resource string, token Token) error {
suite.Same(testCtx, ctx)
suite.Equal(testToken, token)
suite.Equal(placeholderResource, resource)
return err
}),
)
}

suite.Equal(
testCase.expectedErr,
as.Approve(testCtx, placeholderResource, testToken),
)
})
}
})

suite.Run("AppendFunc", func() {
for _, testCase := range testCases {
suite.Run(testCase.name, func() {
var (
testCtx = suite.testContext()
testToken = suite.testToken()
as Approvers[string]
)
}

suite.Equal(
testCase.expectedErr,
as.Approve(testCtx, placeholderResource, testToken),
)
})
}
for _, err := range testCase.results {
err := err
as = as.AppendFunc(
func(ctx context.Context, resource string, token Token) error {
suite.Same(testCtx, ctx)
suite.Equal(testToken, token)
suite.Equal(placeholderResource, resource)
return err
},
)
}

suite.Equal(
testCase.expectedErr,
as.Approve(testCtx, placeholderResource, testToken),
)
})
}
})
}

func (suite *ApproversTestSuite) TestAny() {
Expand Down
11 changes: 11 additions & 0 deletions authorizer.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,17 @@ func WithApprovers[R any](more ...Approver[R]) AuthorizerOption[R] {
)
}

// WithApproverFuncs is a closure variant of WithApprovers that eases the
// syntactical pain of dealing with approvers that are functions.
func WithApproverFuncs[R any](more ...ApproverFunc[R]) AuthorizerOption[R] {
return authorizerOptionFunc[R](
func(a *Authorizer[R]) error {
a.approvers = a.approvers.AppendFunc(more...)
return nil
},
)
}

// NewAuthorizer constructs an Authorizer workflow using the supplied options.
//
// If no options are supplied, the returned Authorizer will authorize all tokens
Expand Down
5 changes: 5 additions & 0 deletions authorizer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,12 +61,14 @@ func (suite *AuthorizerTestSuite) TestFullSuccess() {

approver1 = new(mockApprover[string])
approver2 = new(mockApprover[string])
approver3 = new(mockApprover[string])

listener1 = new(mockAuthorizeListener[string])
listener2 = new(mockAuthorizeListener[string])

a = suite.newAuthorizer(
WithApprovers(approver1, approver2),
WithApproverFuncs(approver3.Approve),
WithAuthorizeListeners(listener1),
WithAuthorizeListenerFuncs(listener2.OnEvent),
)
Expand All @@ -76,6 +78,8 @@ func (suite *AuthorizerTestSuite) TestFullSuccess() {
Return(nil).Once()
approver2.ExpectApprove(expectedCtx, expectedResource, expectedToken).
Return(nil).Once()
approver3.ExpectApprove(expectedCtx, expectedResource, expectedToken).
Return(nil).Once()

listener1.ExpectOnEvent(AuthorizeEvent[string]{
Resource: expectedResource,
Expand All @@ -96,6 +100,7 @@ func (suite *AuthorizerTestSuite) TestFullSuccess() {
listener2.AssertExpectations(suite.T())
approver1.AssertExpectations(suite.T())
approver2.AssertExpectations(suite.T())
approver3.AssertExpectations(suite.T())
}

func (suite *AuthorizerTestSuite) TestFullFirstApproverFail() {
Expand Down
14 changes: 14 additions & 0 deletions basculehttp/authorization.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,13 +56,18 @@ type authorizationParserOptionFunc func(*AuthorizationParser) error

func (apof authorizationParserOptionFunc) apply(ap *AuthorizationParser) error { return apof(ap) }

// WithAuthorizationHeader changes the name of the header holding the token. By default,
// the header used is DefaultAuthorizationHeader.
func WithAuthorizationHeader(header string) AuthorizationParserOption {
return authorizationParserOptionFunc(func(ap *AuthorizationParser) error {
ap.header = header
return nil
})
}

// WithScheme registers a string-based token parser that handles a
// specific authorization scheme. Invocations to this option are cumulative
// and will overwrite any existing registration.
func WithScheme(scheme Scheme, parser bascule.TokenParser[string]) AuthorizationParserOption {
return authorizationParserOptionFunc(func(ap *AuthorizationParser) error {
// we want case-insensitive matches, so lowercase everything
Expand All @@ -71,11 +76,20 @@ func WithScheme(scheme Scheme, parser bascule.TokenParser[string]) Authorization
})
}

// WithBasic is a shorthand for WithScheme that registers basic token parsing using
// the default scheme.
func WithBasic() AuthorizationParserOption {
return WithScheme(SchemeBasic, BasicTokenParser{})
}

// AuthorizationParsers is a bascule.TokenParser that handles the Authorization header.
type AuthorizationParser struct {
header string
parsers map[Scheme]bascule.TokenParser[string]
}

// NewAuthorizationParser constructs an Authorization parser from a set
// of configuration options.
func NewAuthorizationParser(opts ...AuthorizationParserOption) (*AuthorizationParser, error) {
ap := &AuthorizationParser{
parsers: make(map[Scheme]bascule.TokenParser[string]),
Expand Down
3 changes: 0 additions & 3 deletions basculehttp/basic.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,6 @@ func (bt basicToken) Password() string {

// BasicTokenParser is a string-based bascule.TokenParser that produces
// BasicToken instances from strings.
//
// An instance of this parser may be passed to WithScheme in order to
// configure an AuthorizationParser.
type BasicTokenParser struct{}

// Parse assumes that value is of the format required by https://datatracker.ietf.org/doc/html/rfc7617.
Expand Down
6 changes: 3 additions & 3 deletions basculehttp/challenge.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@ import (
)

const (
// WWWAuthenticateHeaderName is the HTTP header used for StatusUnauthorized challenges
// WWWAuthenticateHeader is the HTTP header used for StatusUnauthorized challenges
// when encountered by the Middleware.
//
// This value is used by default when no header is supplied to Challenges.WriteHeader.
WWWAuthenticateHeaderName = "WWW-Authenticate"
WWWAuthenticateHeader = "WWW-Authenticate"
)

var (
Expand Down Expand Up @@ -216,7 +216,7 @@ func (chs Challenges) Append(ch ...Challenge) Challenges {
// halted and that error is returned.
func (chs Challenges) WriteHeader(name string, h http.Header) error {
if len(name) == 0 {
name = WWWAuthenticateHeaderName
name = WWWAuthenticateHeader
}

var o strings.Builder
Expand Down
2 changes: 1 addition & 1 deletion basculehttp/challenge_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ func (suite *ChallengeTestSuite) testChallengesValid() {
suite.Run("DefaultHeader", func() {
header := make(http.Header)
suite.NoError(testCase.challenges.WriteHeader("", header))
suite.ElementsMatch(testCase.expected, header.Values(WWWAuthenticateHeaderName))
suite.ElementsMatch(testCase.expected, header.Values(WWWAuthenticateHeader))
})

suite.Run("CustomHeader", func() {
Expand Down
50 changes: 12 additions & 38 deletions basculehttp/error.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@
package basculehttp

import (
"encoding"
"encoding/json"
"errors"
"net/http"

Expand Down Expand Up @@ -39,14 +37,16 @@ func DefaultErrorStatusCoder(_ *http.Request, err error) int {
var sc statusCoder

switch {
// check if it's a status coder first, so that we can
// override status codes for built-in errors.
case errors.As(err, &sc):
return sc.StatusCode()

case errors.Is(err, bascule.ErrMissingCredentials):
return http.StatusUnauthorized

case errors.Is(err, bascule.ErrInvalidCredentials):
return http.StatusBadRequest

case errors.As(err, &sc):
return sc.StatusCode()
}

return 0
Expand All @@ -56,40 +56,10 @@ func DefaultErrorStatusCoder(_ *http.Request, err error) int {
// be used in an HTTP response body.
type ErrorMarshaler func(request *http.Request, err error) (contentType string, content []byte, marshalErr error)

// DefaultErrorMarshaler examines the error for several standard marshalers. The supported marshalers
// together with the returned content types are as follows, in order:
//
// - json.Marshaler "application/json"
// - encoding.TextMarshaler "text/plain; charset=utf-8"
// - encoding.BinaryMarshaler "application/octet-stream"
//
// If the error or any of its wrapped errors does not implement a supported marshaler interface,
// the error's Error() text is used with a content type of "text/plain; charset=utf-8".
// DefaultErrorMarshaler returns a plaintext representation of the error.
func DefaultErrorMarshaler(_ *http.Request, err error) (contentType string, content []byte, marshalErr error) {
// walk the wrapped errors manually, since that's way more efficient
// that walking the error tree once for each desired type
for wrapped := err; wrapped != nil && len(content) == 0 && marshalErr == nil; wrapped = errors.Unwrap(wrapped) {
switch m := wrapped.(type) { //nolint: errorlint
case json.Marshaler:
contentType = "application/json"
content, marshalErr = m.MarshalJSON()

case encoding.TextMarshaler:
contentType = "text/plain; charset=utf-8"
content, marshalErr = m.MarshalText()

case encoding.BinaryMarshaler:
contentType = "application/octet-stream"
content, marshalErr = m.MarshalBinary()
}
}

if len(content) == 0 && marshalErr == nil {
// fallback
contentType = "text/plain; charset=utf-8"
content = []byte(err.Error())
}

contentType = "text/plain; charset=utf-8"
content = []byte(err.Error())
return
}

Expand All @@ -98,6 +68,10 @@ type statusCodeError struct {
statusCode int
}

func (err *statusCodeError) Unwrap() error {
return err.error
}

func (err *statusCodeError) StatusCode() int {
return err.statusCode
}
Expand Down
Loading
Loading