diff --git a/approver.go b/approver.go index 0f96929..598f0ff 100644 --- a/approver.go +++ b/approver.go @@ -39,6 +39,16 @@ func (as Approvers[R]) Append(more ...Approver[R]) Approvers[R] { return append(as, more...) } +// AppendFunc is a closure variant of Append that makes working with +// approvers that are functions a little easier. +func (as Approvers[R]) AppendFunc(more ...ApproverFunc[R]) Approvers[R] { + for _, m := range more { + as = append(as, m) + } + + return as +} + // Approve requires all approvers in this sequence to allow access. This // method supplies a logical AND. // diff --git a/approver_test.go b/approver_test.go index 184f839..b2808ef 100644 --- a/approver_test.go +++ b/approver_test.go @@ -53,32 +53,63 @@ func (suite *ApproversTestSuite) TestAuthorize() { }, } - for _, testCase := range testCases { - suite.Run(testCase.name, func() { - var ( - testCtx = suite.testContext() - testToken = suite.testToken() - as Approvers[string] - ) + suite.Run("Append", func() { + for _, testCase := range testCases { + suite.Run(testCase.name, func() { + var ( + testCtx = suite.testContext() + testToken = suite.testToken() + as Approvers[string] + ) - for _, err := range testCase.results { - err := err - as = as.Append( - ApproverFunc[string](func(ctx context.Context, resource string, token Token) error { - suite.Same(testCtx, ctx) - suite.Equal(testToken, token) - suite.Equal(placeholderResource, resource) - return err - }), + for _, err := range testCase.results { + err := err + as = as.Append( + ApproverFunc[string](func(ctx context.Context, resource string, token Token) error { + suite.Same(testCtx, ctx) + suite.Equal(testToken, token) + suite.Equal(placeholderResource, resource) + return err + }), + ) + } + + suite.Equal( + testCase.expectedErr, + as.Approve(testCtx, placeholderResource, testToken), + ) + }) + } + }) + + suite.Run("AppendFunc", func() { + for _, testCase := range testCases { + suite.Run(testCase.name, func() { + var ( + testCtx = suite.testContext() + testToken = suite.testToken() + as Approvers[string] ) - } - suite.Equal( - testCase.expectedErr, - as.Approve(testCtx, placeholderResource, testToken), - ) - }) - } + for _, err := range testCase.results { + err := err + as = as.AppendFunc( + func(ctx context.Context, resource string, token Token) error { + suite.Same(testCtx, ctx) + suite.Equal(testToken, token) + suite.Equal(placeholderResource, resource) + return err + }, + ) + } + + suite.Equal( + testCase.expectedErr, + as.Approve(testCtx, placeholderResource, testToken), + ) + }) + } + }) } func (suite *ApproversTestSuite) TestAny() { diff --git a/authorizer.go b/authorizer.go index 977c319..6cad639 100644 --- a/authorizer.go +++ b/authorizer.go @@ -62,6 +62,17 @@ func WithApprovers[R any](more ...Approver[R]) AuthorizerOption[R] { ) } +// WithApproverFuncs is a closure variant of WithApprovers that eases the +// syntactical pain of dealing with approvers that are functions. +func WithApproverFuncs[R any](more ...ApproverFunc[R]) AuthorizerOption[R] { + return authorizerOptionFunc[R]( + func(a *Authorizer[R]) error { + a.approvers = a.approvers.AppendFunc(more...) + return nil + }, + ) +} + // NewAuthorizer constructs an Authorizer workflow using the supplied options. // // If no options are supplied, the returned Authorizer will authorize all tokens diff --git a/authorizer_test.go b/authorizer_test.go index bfa688e..6e376ad 100644 --- a/authorizer_test.go +++ b/authorizer_test.go @@ -61,12 +61,14 @@ func (suite *AuthorizerTestSuite) TestFullSuccess() { approver1 = new(mockApprover[string]) approver2 = new(mockApprover[string]) + approver3 = new(mockApprover[string]) listener1 = new(mockAuthorizeListener[string]) listener2 = new(mockAuthorizeListener[string]) a = suite.newAuthorizer( WithApprovers(approver1, approver2), + WithApproverFuncs(approver3.Approve), WithAuthorizeListeners(listener1), WithAuthorizeListenerFuncs(listener2.OnEvent), ) @@ -76,6 +78,8 @@ func (suite *AuthorizerTestSuite) TestFullSuccess() { Return(nil).Once() approver2.ExpectApprove(expectedCtx, expectedResource, expectedToken). Return(nil).Once() + approver3.ExpectApprove(expectedCtx, expectedResource, expectedToken). + Return(nil).Once() listener1.ExpectOnEvent(AuthorizeEvent[string]{ Resource: expectedResource, @@ -96,6 +100,7 @@ func (suite *AuthorizerTestSuite) TestFullSuccess() { listener2.AssertExpectations(suite.T()) approver1.AssertExpectations(suite.T()) approver2.AssertExpectations(suite.T()) + approver3.AssertExpectations(suite.T()) } func (suite *AuthorizerTestSuite) TestFullFirstApproverFail() { diff --git a/basculehttp/authorization.go b/basculehttp/authorization.go index f76c789..2ab4805 100644 --- a/basculehttp/authorization.go +++ b/basculehttp/authorization.go @@ -56,6 +56,8 @@ type authorizationParserOptionFunc func(*AuthorizationParser) error func (apof authorizationParserOptionFunc) apply(ap *AuthorizationParser) error { return apof(ap) } +// WithAuthorizationHeader changes the name of the header holding the token. By default, +// the header used is DefaultAuthorizationHeader. func WithAuthorizationHeader(header string) AuthorizationParserOption { return authorizationParserOptionFunc(func(ap *AuthorizationParser) error { ap.header = header @@ -63,6 +65,9 @@ func WithAuthorizationHeader(header string) AuthorizationParserOption { }) } +// WithScheme registers a string-based token parser that handles a +// specific authorization scheme. Invocations to this option are cumulative +// and will overwrite any existing registration. func WithScheme(scheme Scheme, parser bascule.TokenParser[string]) AuthorizationParserOption { return authorizationParserOptionFunc(func(ap *AuthorizationParser) error { // we want case-insensitive matches, so lowercase everything @@ -71,11 +76,20 @@ func WithScheme(scheme Scheme, parser bascule.TokenParser[string]) Authorization }) } +// WithBasic is a shorthand for WithScheme that registers basic token parsing using +// the default scheme. +func WithBasic() AuthorizationParserOption { + return WithScheme(SchemeBasic, BasicTokenParser{}) +} + +// AuthorizationParsers is a bascule.TokenParser that handles the Authorization header. type AuthorizationParser struct { header string parsers map[Scheme]bascule.TokenParser[string] } +// NewAuthorizationParser constructs an Authorization parser from a set +// of configuration options. func NewAuthorizationParser(opts ...AuthorizationParserOption) (*AuthorizationParser, error) { ap := &AuthorizationParser{ parsers: make(map[Scheme]bascule.TokenParser[string]), diff --git a/basculehttp/basic.go b/basculehttp/basic.go index 9db199c..5dbfcd7 100644 --- a/basculehttp/basic.go +++ b/basculehttp/basic.go @@ -38,9 +38,6 @@ func (bt basicToken) Password() string { // BasicTokenParser is a string-based bascule.TokenParser that produces // BasicToken instances from strings. -// -// An instance of this parser may be passed to WithScheme in order to -// configure an AuthorizationParser. type BasicTokenParser struct{} // Parse assumes that value is of the format required by https://datatracker.ietf.org/doc/html/rfc7617. diff --git a/basculehttp/challenge.go b/basculehttp/challenge.go index 73a1b24..35fc336 100644 --- a/basculehttp/challenge.go +++ b/basculehttp/challenge.go @@ -10,11 +10,11 @@ import ( ) const ( - // WWWAuthenticateHeaderName is the HTTP header used for StatusUnauthorized challenges + // WWWAuthenticateHeader is the HTTP header used for StatusUnauthorized challenges // when encountered by the Middleware. // // This value is used by default when no header is supplied to Challenges.WriteHeader. - WWWAuthenticateHeaderName = "WWW-Authenticate" + WWWAuthenticateHeader = "WWW-Authenticate" ) var ( @@ -216,7 +216,7 @@ func (chs Challenges) Append(ch ...Challenge) Challenges { // halted and that error is returned. func (chs Challenges) WriteHeader(name string, h http.Header) error { if len(name) == 0 { - name = WWWAuthenticateHeaderName + name = WWWAuthenticateHeader } var o strings.Builder diff --git a/basculehttp/challenge_test.go b/basculehttp/challenge_test.go index b49e494..93bb231 100644 --- a/basculehttp/challenge_test.go +++ b/basculehttp/challenge_test.go @@ -192,7 +192,7 @@ func (suite *ChallengeTestSuite) testChallengesValid() { suite.Run("DefaultHeader", func() { header := make(http.Header) suite.NoError(testCase.challenges.WriteHeader("", header)) - suite.ElementsMatch(testCase.expected, header.Values(WWWAuthenticateHeaderName)) + suite.ElementsMatch(testCase.expected, header.Values(WWWAuthenticateHeader)) }) suite.Run("CustomHeader", func() { diff --git a/basculehttp/error.go b/basculehttp/error.go index 2919150..d48a735 100644 --- a/basculehttp/error.go +++ b/basculehttp/error.go @@ -4,8 +4,6 @@ package basculehttp import ( - "encoding" - "encoding/json" "errors" "net/http" @@ -39,14 +37,16 @@ func DefaultErrorStatusCoder(_ *http.Request, err error) int { var sc statusCoder switch { + // check if it's a status coder first, so that we can + // override status codes for built-in errors. + case errors.As(err, &sc): + return sc.StatusCode() + case errors.Is(err, bascule.ErrMissingCredentials): return http.StatusUnauthorized case errors.Is(err, bascule.ErrInvalidCredentials): return http.StatusBadRequest - - case errors.As(err, &sc): - return sc.StatusCode() } return 0 @@ -56,40 +56,10 @@ func DefaultErrorStatusCoder(_ *http.Request, err error) int { // be used in an HTTP response body. type ErrorMarshaler func(request *http.Request, err error) (contentType string, content []byte, marshalErr error) -// DefaultErrorMarshaler examines the error for several standard marshalers. The supported marshalers -// together with the returned content types are as follows, in order: -// -// - json.Marshaler "application/json" -// - encoding.TextMarshaler "text/plain; charset=utf-8" -// - encoding.BinaryMarshaler "application/octet-stream" -// -// If the error or any of its wrapped errors does not implement a supported marshaler interface, -// the error's Error() text is used with a content type of "text/plain; charset=utf-8". +// DefaultErrorMarshaler returns a plaintext representation of the error. func DefaultErrorMarshaler(_ *http.Request, err error) (contentType string, content []byte, marshalErr error) { - // walk the wrapped errors manually, since that's way more efficient - // that walking the error tree once for each desired type - for wrapped := err; wrapped != nil && len(content) == 0 && marshalErr == nil; wrapped = errors.Unwrap(wrapped) { - switch m := wrapped.(type) { //nolint: errorlint - case json.Marshaler: - contentType = "application/json" - content, marshalErr = m.MarshalJSON() - - case encoding.TextMarshaler: - contentType = "text/plain; charset=utf-8" - content, marshalErr = m.MarshalText() - - case encoding.BinaryMarshaler: - contentType = "application/octet-stream" - content, marshalErr = m.MarshalBinary() - } - } - - if len(content) == 0 && marshalErr == nil { - // fallback - contentType = "text/plain; charset=utf-8" - content = []byte(err.Error()) - } - + contentType = "text/plain; charset=utf-8" + content = []byte(err.Error()) return } @@ -98,6 +68,10 @@ type statusCodeError struct { statusCode int } +func (err *statusCodeError) Unwrap() error { + return err.error +} + func (err *statusCodeError) StatusCode() int { return err.statusCode } diff --git a/basculehttp/error_test.go b/basculehttp/error_test.go new file mode 100644 index 0000000..c486643 --- /dev/null +++ b/basculehttp/error_test.go @@ -0,0 +1,97 @@ +// SPDX-FileCopyrightText: 2024 Comcast Cable Communications Management, LLC +// SPDX-License-Identifier: Apache-2.0 + +package basculehttp + +import ( + "errors" + "mime" + "net/http" + "testing" + + "github.com/stretchr/testify/suite" + "github.com/xmidt-org/bascule/v1" +) + +type ErrorTestSuite struct { + suite.Suite +} + +func (suite *ErrorTestSuite) TestDefaultErrorStatusCoder() { + suite.Run("ErrMissingCredentials", func() { + suite.Equal( + http.StatusUnauthorized, + DefaultErrorStatusCoder(nil, bascule.ErrMissingCredentials), + ) + }) + + suite.Run("ErrInvalidCredentials", func() { + suite.Equal( + http.StatusBadRequest, + DefaultErrorStatusCoder(nil, bascule.ErrInvalidCredentials), + ) + }) + + suite.Run("StatusCoder", func() { + suite.Equal( + 317, + DefaultErrorStatusCoder( + nil, + UseStatusCode(317, errors.New("unrecognized")), + ), + ) + }) + + suite.Run("OverrideStatusCode", func() { + suite.Equal( + http.StatusNotFound, + DefaultErrorStatusCoder( + nil, + UseStatusCode(http.StatusNotFound, bascule.ErrMissingCredentials), + ), + ) + }) + + suite.Run("Unrecognized", func() { + suite.Equal( + 0, + DefaultErrorStatusCoder(nil, errors.New("unrecognized error")), + ) + }) +} + +func (suite *ErrorTestSuite) TestDefaultErrorMarshaler() { + contentType, content, marshalErr := DefaultErrorMarshaler( + nil, + bascule.ErrMissingCredentials, + ) + + suite.Require().NoError(marshalErr) + suite.Equal(bascule.ErrMissingCredentials.Error(), string(content)) + + mediaType, _, err := mime.ParseMediaType(contentType) + suite.Require().NoError(err) + suite.Equal("text/plain", mediaType) +} + +func (suite *ErrorTestSuite) TestUseStatusCode() { + var ( + err = errors.New("an error") + wrapperErr = UseStatusCode(511, err) + ) + + suite.Error(wrapperErr) + suite.ErrorIs(wrapperErr, err) + + type statusCoder interface { + StatusCode() int + } + + var sc statusCoder + suite.Require().ErrorAs(wrapperErr, &sc) + suite.Equal(511, sc.StatusCode()) +} + +func TestError(t *testing.T) { + suite.Run(t, new(ErrorTestSuite)) +} diff --git a/basculehttp/middleware.go b/basculehttp/middleware.go index 4529cba..df4ed74 100644 --- a/basculehttp/middleware.go +++ b/basculehttp/middleware.go @@ -4,6 +4,7 @@ package basculehttp import ( + "errors" "net/http" "strconv" @@ -11,6 +12,12 @@ import ( "go.uber.org/multierr" ) +var ( + // ErrNoAuthenticator is returned by NewMiddleware to indicate that an Authorizer + // was configured without an Authenticator. + ErrNoAuthenticator = errors.New("An Authenticator is required if an Authorizer is configured") +) + // MiddlewareOption is a functional option for tailoring a Middleware. type MiddlewareOption interface { apply(*Middleware) error @@ -23,8 +30,6 @@ func (mof middlewareOptionFunc) apply(m *Middleware) error { } // WithAuthenticator supplies the Authenticator workflow for the middleware. -// -// Note: If no authenticator is supplied, NewMiddeware returns an error. func WithAuthenticator(authenticator *bascule.Authenticator[*http.Request]) MiddlewareOption { return UseAuthenticator(authenticator, nil) } @@ -81,12 +86,7 @@ func WithChallenges(ch ...Challenge) MiddlewareOption { // option is omitted or if esc is nil, DefaultErrorStatusCoder is used. func WithErrorStatusCoder(esc ErrorStatusCoder) MiddlewareOption { return middlewareOptionFunc(func(m *Middleware) error { - if esc != nil { - m.errorStatusCoder = esc - } else { - m.errorStatusCoder = DefaultErrorStatusCoder - } - + m.errorStatusCoder = esc return nil }) } @@ -95,17 +95,33 @@ func WithErrorStatusCoder(esc ErrorStatusCoder) MiddlewareOption { // option is omitted or if esc is nil, DefaultErrorMarshaler is used. func WithErrorMarshaler(em ErrorMarshaler) MiddlewareOption { return middlewareOptionFunc(func(m *Middleware) error { - if em != nil { - m.errorMarshaler = em - } else { - m.errorMarshaler = DefaultErrorMarshaler - } - + m.errorMarshaler = em return nil }) } -// Middleware is an immutable configuration that can decorate multiple handlers. +// Middleware is an immutable HTTP workflow that can decorate multiple handlers. +// +// A Middleware can have either or both of an Authenticator, which creates +// tokens from HTTP requests, and an Authorizer, which approves access to +// the resource identified by the request. The behavior of a Middleware +// depends mostly on these two components. +// +// If both an authenticator and an authorizer are supplied, the full bascule +// workflow, including events, is implemented. +// +// If an authenticator is supplied without an authorizer, only token creation +// is implemented. Without an authorizer, it is assumed that all tokens have +// access to all requests. +// +// If no authenticator is supplied, but an authorizer IS supplied, then +// NewMiddleware returns an error. An authenticator is required in order to +// create tokens. +// +// Finally, if neither an authenticator or an authorizer is supplied, +// then this Middleware is a noop. Any attempt to decorate handlers will +// result in those handlers being returned as is. This allows a Middleware +// to be turned off via configuration. type Middleware struct { authenticator *bascule.Authenticator[*http.Request] authorizer *bascule.Authorizer[*http.Request] @@ -117,16 +133,37 @@ type Middleware struct { // NewMiddleware creates an immutable Middleware instance from a supplied set of options. // No options will result in a Middleware with default behavior. +// +// If no authenticator is configured, but an authorizer is, this function returns +// ErrNoAuthenticator. +// +// Note that if no workflow components are configured, i.e. neither an authenticator nor +// an authorizer are supplied, then the returned Middleware is a noop. func NewMiddleware(opts ...MiddlewareOption) (m *Middleware, err error) { - m = &Middleware{ - errorStatusCoder: DefaultErrorStatusCoder, - errorMarshaler: DefaultErrorMarshaler, - } - + m = new(Middleware) for _, o := range opts { err = multierr.Append(err, o.apply(m)) } + switch { + case err != nil: + m = nil + + case m.authenticator == nil && m.authorizer != nil: + err = multierr.Append(err, ErrNoAuthenticator) + m = nil + + default: + // cleanup after the options run + if m.errorStatusCoder == nil { + m.errorStatusCoder = DefaultErrorStatusCoder + } + + if m.errorMarshaler == nil { + m.errorMarshaler = DefaultErrorMarshaler + } + } + return } @@ -137,6 +174,12 @@ func (m *Middleware) Then(protected http.Handler) http.Handler { protected = http.DefaultServeMux } + // no point in decorating if there's no workflow + // this also allows a Middleware to be turned off via configuration + if m.authenticator == nil && m.authorizer == nil { + return protected + } + return &frontDoor{ Middleware: m, protected: protected, @@ -156,7 +199,7 @@ func (m *Middleware) ThenFunc(protected http.HandlerFunc) http.Handler { // The response is always a text/plain representation of the error. func (m *Middleware) writeRawError(response http.ResponseWriter, err error) { response.WriteHeader(http.StatusInternalServerError) - response.Header().Set("Content-Type", "text/plain") + response.Header().Set("Content-Type", "text/plain; charset=utf-8") errBody := []byte(err.Error()) response.Header().Set("Content-Length", strconv.Itoa(len(errBody))) @@ -210,6 +253,9 @@ type frontDoor struct { // ServeHTTP implements the bascule workflow, using the configured middleware. func (fd *frontDoor) ServeHTTP(response http.ResponseWriter, request *http.Request) { ctx := request.Context() + + // an authenticator is is required if we are decorating + // if the authenticator was nil, a frontDoor won't get created token, err := fd.authenticator.Authenticate(ctx, request) if err != nil { // by default, failing to parse a token is a malformed request diff --git a/basculehttp/middleware_examples_test.go b/basculehttp/middleware_examples_test.go index 742844c..689e861 100644 --- a/basculehttp/middleware_examples_test.go +++ b/basculehttp/middleware_examples_test.go @@ -15,7 +15,7 @@ import ( // just basic auth. func ExampleMiddleware_basicauth() { tp, _ := NewAuthorizationParser( - WithScheme(SchemeBasic, BasicTokenParser{}), + WithBasic(), ) m, _ := NewMiddleware( diff --git a/basculehttp/middleware_test.go b/basculehttp/middleware_test.go new file mode 100644 index 0000000..486b103 --- /dev/null +++ b/basculehttp/middleware_test.go @@ -0,0 +1,465 @@ +// SPDX-FileCopyrightText: 2024 Comcast Cable Communications Management, LLC +// SPDX-License-Identifier: Apache-2.0 + +package basculehttp + +import ( + "context" + "errors" + "mime" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/suite" + "github.com/xmidt-org/bascule/v1" +) + +type MiddlewareTestSuite struct { + suite.Suite + + expectedPrincipal string + expectedPassword string + expectedToken bascule.Token +} + +func (suite *MiddlewareTestSuite) SetupSuite() { + suite.expectedPrincipal = "testPrincipal" + suite.expectedPassword = "test_password" + suite.expectedToken = basicToken{ + userName: suite.expectedPrincipal, + password: suite.expectedPassword, + } +} + +// newRequest creates a standardized test request, devoid of any authorization. +func (suite *MiddlewareTestSuite) newRequest() *http.Request { + return httptest.NewRequest("GET", "/test", nil) +} + +// newBasicAuthRequest creates a new test request configured with valid basic auth. +func (suite *MiddlewareTestSuite) newBasicAuthRequest() *http.Request { + request := suite.newRequest() + request.SetBasicAuth(suite.expectedPrincipal, suite.expectedPassword) + return request +} + +// assertBasicAuthRequest asserts that the given request matches this suite's expectations. +func (suite *MiddlewareTestSuite) assertBasicAuthRequest(request *http.Request) { + suite.Require().NotNil(request) + suite.Equal("GET", request.Method) + suite.Equal("/test", request.URL.String()) +} + +// assertBasicAuthToken asserts that the token matches this suite's expectations. +func (suite *MiddlewareTestSuite) assertBasicAuthToken(token bascule.Token) { + suite.Require().NotNil(token) + suite.Equal(suite.expectedPrincipal, token.Principal()) + suite.Require().Implements((*BasicToken)(nil), token) + suite.Equal(suite.expectedPrincipal, token.(BasicToken).UserName()) + suite.Equal(suite.expectedPassword, token.(BasicToken).Password()) +} + +// newAuthorizationParser creates an AuthorizationParser that is expected to be valid. +// Assertions as to validity are made prior to returning. +func (suite *MiddlewareTestSuite) newAuthorizationParser(opts ...AuthorizationParserOption) *AuthorizationParser { + ap, err := NewAuthorizationParser(opts...) + suite.Require().NoError(err) + suite.Require().NotNil(ap) + return ap +} + +// newAuthenticator creates a bascule.Authenticator that is expected to be valid. +// Assertions as to validity are made prior to returning. +func (suite *MiddlewareTestSuite) newAuthenticator(opts ...bascule.AuthenticatorOption[*http.Request]) *bascule.Authenticator[*http.Request] { + a, err := NewAuthenticator(opts...) + suite.Require().NoError(err) + suite.Require().NotNil(a) + return a +} + +// newAuthorizer creates a bascule.Authorizer that is expected to be valid. +// Assertions as to validity are made prior to returning. +func (suite *MiddlewareTestSuite) newAuthorizer(opts ...bascule.AuthorizerOption[*http.Request]) *bascule.Authorizer[*http.Request] { + a, err := NewAuthorizer(opts...) + suite.Require().NoError(err) + suite.Require().NotNil(a) + return a +} + +// newMiddleware creates a Middleware that is expected to be valid. +// Assertions as to validity are made prior to returning. +func (suite *MiddlewareTestSuite) newMiddleware(opts ...MiddlewareOption) *Middleware { + m, err := NewMiddleware(opts...) + suite.Require().NoError(err) + suite.Require().NotNil(m) + return m +} + +// serveHTTPFunc is a standard, non-error handler function that sets the normalResponseCode. +func (suite *MiddlewareTestSuite) serveHTTPFunc(response http.ResponseWriter, _ *http.Request) { + response.Header().Set("Content-Type", "application/octet-stream") + response.WriteHeader(299) + response.Write([]byte("normal response")) +} + +// assertNormalResponse asserts that the Middleware allowed the response from serveHTTPFunc. +func (suite *MiddlewareTestSuite) assertNormalResponse(response *httptest.ResponseRecorder) { + suite.Equal(299, response.Code) + suite.Equal("application/octet-stream", response.HeaderMap.Get("Content-Type")) + suite.Equal("normal response", response.Body.String()) +} + +// serveHTTPNoCall is a handler function that should be blocked by the Middleware. +func (suite *MiddlewareTestSuite) serveHTTPNoCall(http.ResponseWriter, *http.Request) { + suite.Fail("The handler should not have been called") +} + +func (suite *MiddlewareTestSuite) assertChallenge(c Challenge, err error) Challenge { + suite.Require().NoError(err) + return c +} + +func (suite *MiddlewareTestSuite) TestUseAuthenticatorError() { + m, err := NewMiddleware( + UseAuthenticator( + bascule.NewAuthenticator[*http.Request](), // no token parsers + ), + ) + + suite.ErrorIs(err, bascule.ErrNoTokenParsers) + suite.Nil(m) +} + +func (suite *MiddlewareTestSuite) TestUseAuthorizerError() { + expectedErr := errors.New("expected") + m, err := NewMiddleware( + UseAuthenticator( + NewAuthenticator( + bascule.WithTokenParsers( + suite.newAuthorizationParser(WithBasic()), + ), + ), + ), + UseAuthorizer(nil, expectedErr), + ) + + suite.ErrorIs(err, expectedErr) + suite.Nil(m) +} + +func (suite *MiddlewareTestSuite) TestNoAuthenticatorWithAuthorizer() { + m, err := NewMiddleware( + WithAuthorizer( + suite.newAuthorizer(), + ), + ) + + suite.Nil(m) + suite.ErrorIs(err, ErrNoAuthenticator) +} + +func (suite *MiddlewareTestSuite) TestThen() { + suite.Run("NilHandler", func() { + var ( + m = suite.newMiddleware( + WithAuthenticator( + suite.newAuthenticator( + bascule.WithTokenParsers( + suite.newAuthorizationParser(WithBasic()), + ), + ), + ), + ) + + h = m.Then(nil) + + response = httptest.NewRecorder() + request = suite.newBasicAuthRequest() + ) + + h.ServeHTTP(response, request) + suite.Equal(http.StatusNotFound, response.Code) // use the unconfigured http.DefaultServeMux + }) + + suite.Run("NoDecoration", func() { + var ( + m = suite.newMiddleware() + h = m.Then(http.HandlerFunc( + suite.serveHTTPFunc, + )) + + response = httptest.NewRecorder() + request = suite.newRequest() + ) + + h.ServeHTTP(response, request) + suite.assertNormalResponse(response) + }) +} + +func (suite *MiddlewareTestSuite) TestThenFunc() { + suite.Run("NilHandlerFunc", func() { + var ( + m = suite.newMiddleware( + WithAuthenticator( + suite.newAuthenticator( + bascule.WithTokenParsers( + suite.newAuthorizationParser(WithBasic()), + ), + ), + ), + ) + + h = m.ThenFunc(nil) + + response = httptest.NewRecorder() + request = suite.newBasicAuthRequest() + ) + + h.ServeHTTP(response, request) + suite.Equal(http.StatusNotFound, response.Code) // use the unconfigured http.DefaultServeMux + }) +} + +func (suite *MiddlewareTestSuite) TestCustomErrorRendering() { + var ( + m = suite.newMiddleware( + WithAuthenticator( + suite.newAuthenticator( + bascule.WithTokenParsers( + suite.newAuthorizationParser(WithBasic()), + ), + ), + ), + WithErrorStatusCoder( + func(request *http.Request, err error) int { + suite.Equal(request.URL.String(), "/test") + return 567 + }, + ), + WithErrorMarshaler( + func(request *http.Request, err error) (contentType string, content []byte, marshalErr error) { + contentType = "text/xml" + content = []byte("") + return + }, + ), + ) + + response = httptest.NewRecorder() + request = suite.newRequest() + + h = m.ThenFunc(suite.serveHTTPNoCall) + ) + + h.ServeHTTP(response, request) + suite.Equal(567, response.Code) + suite.Equal("text/xml", response.HeaderMap.Get("Content-Type")) + suite.Equal("", response.Body.String()) +} + +func (suite *MiddlewareTestSuite) TestMarshalError() { + var ( + marshalErr = errors.New("expected marshal error") + + m = suite.newMiddleware( + WithAuthenticator( + suite.newAuthenticator( + bascule.WithTokenParsers( + suite.newAuthorizationParser(WithBasic()), + ), + ), + ), + WithErrorStatusCoder( + func(request *http.Request, err error) int { + suite.Equal(request.URL.String(), "/test") + return 567 + }, + ), + WithErrorMarshaler( + func(request *http.Request, err error) (string, []byte, error) { + return "", nil, marshalErr + }, + ), + ) + + response = httptest.NewRecorder() + request = suite.newRequest() + + h = m.ThenFunc(suite.serveHTTPNoCall) + ) + + h.ServeHTTP(response, request) + suite.Equal(http.StatusInternalServerError, response.Code) + + mediaType, _, err := mime.ParseMediaType(response.HeaderMap.Get("Content-Type")) + suite.Require().NoError(err) + suite.Equal("text/plain", mediaType) + suite.Equal(marshalErr.Error(), response.Body.String()) +} + +func (suite *MiddlewareTestSuite) testBasicAuthSuccess() { + var ( + authenticateEvent = false + authorizeEvent = false + + m = suite.newMiddleware( + WithAuthenticator( + suite.newAuthenticator( + bascule.WithTokenParsers( + suite.newAuthorizationParser(WithBasic()), + ), + bascule.WithAuthenticateListenerFuncs( + func(e bascule.AuthenticateEvent[*http.Request]) { + suite.Equal("/test", e.Source.URL.String()) + suite.Require().NotNil(e.Token) + suite.NoError(e.Err) + authenticateEvent = true + }, + ), + ), + ), + WithAuthorizer( + suite.newAuthorizer( + bascule.WithApproverFuncs( + func(_ context.Context, request *http.Request, token bascule.Token) error { + suite.assertBasicAuthRequest(request) + suite.assertBasicAuthToken(token) + return nil + }, + ), + bascule.WithAuthorizeListenerFuncs( + func(e bascule.AuthorizeEvent[*http.Request]) { + suite.assertBasicAuthRequest(e.Resource) + suite.assertBasicAuthToken(e.Token) + authorizeEvent = true + }, + ), + ), + ), + ) + + response = httptest.NewRecorder() + request = suite.newBasicAuthRequest() + + h = m.ThenFunc(suite.serveHTTPFunc) + ) + + h.ServeHTTP(response, request) + suite.assertNormalResponse(response) + suite.True(authenticateEvent) + suite.True(authorizeEvent) +} + +func (suite *MiddlewareTestSuite) testBasicAuthChallenge() { + var ( + authenticateEvent = false + + m = suite.newMiddleware( + WithAuthenticator( + suite.newAuthenticator( + bascule.WithTokenParsers( + suite.newAuthorizationParser(WithBasic()), + ), + bascule.WithAuthenticateListenerFuncs( + func(e bascule.AuthenticateEvent[*http.Request]) { + suite.assertBasicAuthRequest(e.Source) + suite.ErrorIs(bascule.ErrMissingCredentials, e.Err) + suite.Nil(e.Token) + authenticateEvent = true + }, + ), + ), + ), + WithChallenges( + suite.assertChallenge(NewBasicChallenge("test", true)), + ), + ) + + response = httptest.NewRecorder() + request = suite.newRequest() + + h = m.ThenFunc(suite.serveHTTPNoCall) + ) + + h.ServeHTTP(response, request) + suite.Equal(http.StatusUnauthorized, response.Code) + + suite.Equal( + `Basic realm="test" charset="UTF-8"`, + response.HeaderMap.Get(WWWAuthenticateHeader), + ) + + suite.True(authenticateEvent) +} + +func (suite *MiddlewareTestSuite) testBasicAuthInvalid() { + var ( + m = suite.newMiddleware( + WithAuthenticator( + suite.newAuthenticator( + bascule.WithTokenParsers( + suite.newAuthorizationParser(WithBasic()), + ), + ), + ), + ) + + response = httptest.NewRecorder() + request = suite.newRequest() + + h = m.ThenFunc(suite.serveHTTPNoCall) + ) + + request.Header.Set("Authorization", "Basic this is most definitely not a valid basic auth string") + h.ServeHTTP(response, request) + suite.Equal(http.StatusBadRequest, response.Code) +} + +func (suite *MiddlewareTestSuite) testBasicAuthAuthorizerError() { + var ( + expectedErr = errors.New("expected error") + + m = suite.newMiddleware( + WithAuthenticator( + suite.newAuthenticator( + bascule.WithTokenParsers( + suite.newAuthorizationParser(WithBasic()), + ), + ), + ), + WithAuthorizer( + suite.newAuthorizer( + bascule.WithApproverFuncs( + func(_ context.Context, resource *http.Request, token bascule.Token) error { + suite.assertBasicAuthRequest(resource) + suite.assertBasicAuthToken(token) + return expectedErr + }, + ), + ), + ), + ) + + response = httptest.NewRecorder() + request = suite.newBasicAuthRequest() + + h = m.ThenFunc(suite.serveHTTPNoCall) + ) + + h.ServeHTTP(response, request) + suite.Equal(http.StatusForbidden, response.Code) + suite.Equal(expectedErr.Error(), response.Body.String()) +} + +func (suite *MiddlewareTestSuite) TestBasicAuth() { + suite.Run("Success", suite.testBasicAuthSuccess) + suite.Run("Challenge", suite.testBasicAuthChallenge) + suite.Run("Invalid", suite.testBasicAuthInvalid) + suite.Run("AuthorizerError", suite.testBasicAuthAuthorizerError) +} + +func TestMiddleware(t *testing.T) { + suite.Run(t, new(MiddlewareTestSuite)) +}