diff --git a/approver.go b/approver.go
index 0f96929..598f0ff 100644
--- a/approver.go
+++ b/approver.go
@@ -39,6 +39,16 @@ func (as Approvers[R]) Append(more ...Approver[R]) Approvers[R] {
return append(as, more...)
}
+// AppendFunc is a closure variant of Append that makes working with
+// approvers that are functions a little easier.
+func (as Approvers[R]) AppendFunc(more ...ApproverFunc[R]) Approvers[R] {
+ for _, m := range more {
+ as = append(as, m)
+ }
+
+ return as
+}
+
// Approve requires all approvers in this sequence to allow access. This
// method supplies a logical AND.
//
diff --git a/approver_test.go b/approver_test.go
index 184f839..b2808ef 100644
--- a/approver_test.go
+++ b/approver_test.go
@@ -53,32 +53,63 @@ func (suite *ApproversTestSuite) TestAuthorize() {
},
}
- for _, testCase := range testCases {
- suite.Run(testCase.name, func() {
- var (
- testCtx = suite.testContext()
- testToken = suite.testToken()
- as Approvers[string]
- )
+ suite.Run("Append", func() {
+ for _, testCase := range testCases {
+ suite.Run(testCase.name, func() {
+ var (
+ testCtx = suite.testContext()
+ testToken = suite.testToken()
+ as Approvers[string]
+ )
- for _, err := range testCase.results {
- err := err
- as = as.Append(
- ApproverFunc[string](func(ctx context.Context, resource string, token Token) error {
- suite.Same(testCtx, ctx)
- suite.Equal(testToken, token)
- suite.Equal(placeholderResource, resource)
- return err
- }),
+ for _, err := range testCase.results {
+ err := err
+ as = as.Append(
+ ApproverFunc[string](func(ctx context.Context, resource string, token Token) error {
+ suite.Same(testCtx, ctx)
+ suite.Equal(testToken, token)
+ suite.Equal(placeholderResource, resource)
+ return err
+ }),
+ )
+ }
+
+ suite.Equal(
+ testCase.expectedErr,
+ as.Approve(testCtx, placeholderResource, testToken),
+ )
+ })
+ }
+ })
+
+ suite.Run("AppendFunc", func() {
+ for _, testCase := range testCases {
+ suite.Run(testCase.name, func() {
+ var (
+ testCtx = suite.testContext()
+ testToken = suite.testToken()
+ as Approvers[string]
)
- }
- suite.Equal(
- testCase.expectedErr,
- as.Approve(testCtx, placeholderResource, testToken),
- )
- })
- }
+ for _, err := range testCase.results {
+ err := err
+ as = as.AppendFunc(
+ func(ctx context.Context, resource string, token Token) error {
+ suite.Same(testCtx, ctx)
+ suite.Equal(testToken, token)
+ suite.Equal(placeholderResource, resource)
+ return err
+ },
+ )
+ }
+
+ suite.Equal(
+ testCase.expectedErr,
+ as.Approve(testCtx, placeholderResource, testToken),
+ )
+ })
+ }
+ })
}
func (suite *ApproversTestSuite) TestAny() {
diff --git a/authorizer.go b/authorizer.go
index 977c319..6cad639 100644
--- a/authorizer.go
+++ b/authorizer.go
@@ -62,6 +62,17 @@ func WithApprovers[R any](more ...Approver[R]) AuthorizerOption[R] {
)
}
+// WithApproverFuncs is a closure variant of WithApprovers that eases the
+// syntactical pain of dealing with approvers that are functions.
+func WithApproverFuncs[R any](more ...ApproverFunc[R]) AuthorizerOption[R] {
+ return authorizerOptionFunc[R](
+ func(a *Authorizer[R]) error {
+ a.approvers = a.approvers.AppendFunc(more...)
+ return nil
+ },
+ )
+}
+
// NewAuthorizer constructs an Authorizer workflow using the supplied options.
//
// If no options are supplied, the returned Authorizer will authorize all tokens
diff --git a/authorizer_test.go b/authorizer_test.go
index bfa688e..6e376ad 100644
--- a/authorizer_test.go
+++ b/authorizer_test.go
@@ -61,12 +61,14 @@ func (suite *AuthorizerTestSuite) TestFullSuccess() {
approver1 = new(mockApprover[string])
approver2 = new(mockApprover[string])
+ approver3 = new(mockApprover[string])
listener1 = new(mockAuthorizeListener[string])
listener2 = new(mockAuthorizeListener[string])
a = suite.newAuthorizer(
WithApprovers(approver1, approver2),
+ WithApproverFuncs(approver3.Approve),
WithAuthorizeListeners(listener1),
WithAuthorizeListenerFuncs(listener2.OnEvent),
)
@@ -76,6 +78,8 @@ func (suite *AuthorizerTestSuite) TestFullSuccess() {
Return(nil).Once()
approver2.ExpectApprove(expectedCtx, expectedResource, expectedToken).
Return(nil).Once()
+ approver3.ExpectApprove(expectedCtx, expectedResource, expectedToken).
+ Return(nil).Once()
listener1.ExpectOnEvent(AuthorizeEvent[string]{
Resource: expectedResource,
@@ -96,6 +100,7 @@ func (suite *AuthorizerTestSuite) TestFullSuccess() {
listener2.AssertExpectations(suite.T())
approver1.AssertExpectations(suite.T())
approver2.AssertExpectations(suite.T())
+ approver3.AssertExpectations(suite.T())
}
func (suite *AuthorizerTestSuite) TestFullFirstApproverFail() {
diff --git a/basculehttp/authorization.go b/basculehttp/authorization.go
index f76c789..2ab4805 100644
--- a/basculehttp/authorization.go
+++ b/basculehttp/authorization.go
@@ -56,6 +56,8 @@ type authorizationParserOptionFunc func(*AuthorizationParser) error
func (apof authorizationParserOptionFunc) apply(ap *AuthorizationParser) error { return apof(ap) }
+// WithAuthorizationHeader changes the name of the header holding the token. By default,
+// the header used is DefaultAuthorizationHeader.
func WithAuthorizationHeader(header string) AuthorizationParserOption {
return authorizationParserOptionFunc(func(ap *AuthorizationParser) error {
ap.header = header
@@ -63,6 +65,9 @@ func WithAuthorizationHeader(header string) AuthorizationParserOption {
})
}
+// WithScheme registers a string-based token parser that handles a
+// specific authorization scheme. Invocations to this option are cumulative
+// and will overwrite any existing registration.
func WithScheme(scheme Scheme, parser bascule.TokenParser[string]) AuthorizationParserOption {
return authorizationParserOptionFunc(func(ap *AuthorizationParser) error {
// we want case-insensitive matches, so lowercase everything
@@ -71,11 +76,20 @@ func WithScheme(scheme Scheme, parser bascule.TokenParser[string]) Authorization
})
}
+// WithBasic is a shorthand for WithScheme that registers basic token parsing using
+// the default scheme.
+func WithBasic() AuthorizationParserOption {
+ return WithScheme(SchemeBasic, BasicTokenParser{})
+}
+
+// AuthorizationParsers is a bascule.TokenParser that handles the Authorization header.
type AuthorizationParser struct {
header string
parsers map[Scheme]bascule.TokenParser[string]
}
+// NewAuthorizationParser constructs an Authorization parser from a set
+// of configuration options.
func NewAuthorizationParser(opts ...AuthorizationParserOption) (*AuthorizationParser, error) {
ap := &AuthorizationParser{
parsers: make(map[Scheme]bascule.TokenParser[string]),
diff --git a/basculehttp/basic.go b/basculehttp/basic.go
index 9db199c..5dbfcd7 100644
--- a/basculehttp/basic.go
+++ b/basculehttp/basic.go
@@ -38,9 +38,6 @@ func (bt basicToken) Password() string {
// BasicTokenParser is a string-based bascule.TokenParser that produces
// BasicToken instances from strings.
-//
-// An instance of this parser may be passed to WithScheme in order to
-// configure an AuthorizationParser.
type BasicTokenParser struct{}
// Parse assumes that value is of the format required by https://datatracker.ietf.org/doc/html/rfc7617.
diff --git a/basculehttp/challenge.go b/basculehttp/challenge.go
index 73a1b24..35fc336 100644
--- a/basculehttp/challenge.go
+++ b/basculehttp/challenge.go
@@ -10,11 +10,11 @@ import (
)
const (
- // WWWAuthenticateHeaderName is the HTTP header used for StatusUnauthorized challenges
+ // WWWAuthenticateHeader is the HTTP header used for StatusUnauthorized challenges
// when encountered by the Middleware.
//
// This value is used by default when no header is supplied to Challenges.WriteHeader.
- WWWAuthenticateHeaderName = "WWW-Authenticate"
+ WWWAuthenticateHeader = "WWW-Authenticate"
)
var (
@@ -216,7 +216,7 @@ func (chs Challenges) Append(ch ...Challenge) Challenges {
// halted and that error is returned.
func (chs Challenges) WriteHeader(name string, h http.Header) error {
if len(name) == 0 {
- name = WWWAuthenticateHeaderName
+ name = WWWAuthenticateHeader
}
var o strings.Builder
diff --git a/basculehttp/challenge_test.go b/basculehttp/challenge_test.go
index b49e494..93bb231 100644
--- a/basculehttp/challenge_test.go
+++ b/basculehttp/challenge_test.go
@@ -192,7 +192,7 @@ func (suite *ChallengeTestSuite) testChallengesValid() {
suite.Run("DefaultHeader", func() {
header := make(http.Header)
suite.NoError(testCase.challenges.WriteHeader("", header))
- suite.ElementsMatch(testCase.expected, header.Values(WWWAuthenticateHeaderName))
+ suite.ElementsMatch(testCase.expected, header.Values(WWWAuthenticateHeader))
})
suite.Run("CustomHeader", func() {
diff --git a/basculehttp/error.go b/basculehttp/error.go
index 2919150..d48a735 100644
--- a/basculehttp/error.go
+++ b/basculehttp/error.go
@@ -4,8 +4,6 @@
package basculehttp
import (
- "encoding"
- "encoding/json"
"errors"
"net/http"
@@ -39,14 +37,16 @@ func DefaultErrorStatusCoder(_ *http.Request, err error) int {
var sc statusCoder
switch {
+ // check if it's a status coder first, so that we can
+ // override status codes for built-in errors.
+ case errors.As(err, &sc):
+ return sc.StatusCode()
+
case errors.Is(err, bascule.ErrMissingCredentials):
return http.StatusUnauthorized
case errors.Is(err, bascule.ErrInvalidCredentials):
return http.StatusBadRequest
-
- case errors.As(err, &sc):
- return sc.StatusCode()
}
return 0
@@ -56,40 +56,10 @@ func DefaultErrorStatusCoder(_ *http.Request, err error) int {
// be used in an HTTP response body.
type ErrorMarshaler func(request *http.Request, err error) (contentType string, content []byte, marshalErr error)
-// DefaultErrorMarshaler examines the error for several standard marshalers. The supported marshalers
-// together with the returned content types are as follows, in order:
-//
-// - json.Marshaler "application/json"
-// - encoding.TextMarshaler "text/plain; charset=utf-8"
-// - encoding.BinaryMarshaler "application/octet-stream"
-//
-// If the error or any of its wrapped errors does not implement a supported marshaler interface,
-// the error's Error() text is used with a content type of "text/plain; charset=utf-8".
+// DefaultErrorMarshaler returns a plaintext representation of the error.
func DefaultErrorMarshaler(_ *http.Request, err error) (contentType string, content []byte, marshalErr error) {
- // walk the wrapped errors manually, since that's way more efficient
- // that walking the error tree once for each desired type
- for wrapped := err; wrapped != nil && len(content) == 0 && marshalErr == nil; wrapped = errors.Unwrap(wrapped) {
- switch m := wrapped.(type) { //nolint: errorlint
- case json.Marshaler:
- contentType = "application/json"
- content, marshalErr = m.MarshalJSON()
-
- case encoding.TextMarshaler:
- contentType = "text/plain; charset=utf-8"
- content, marshalErr = m.MarshalText()
-
- case encoding.BinaryMarshaler:
- contentType = "application/octet-stream"
- content, marshalErr = m.MarshalBinary()
- }
- }
-
- if len(content) == 0 && marshalErr == nil {
- // fallback
- contentType = "text/plain; charset=utf-8"
- content = []byte(err.Error())
- }
-
+ contentType = "text/plain; charset=utf-8"
+ content = []byte(err.Error())
return
}
@@ -98,6 +68,10 @@ type statusCodeError struct {
statusCode int
}
+func (err *statusCodeError) Unwrap() error {
+ return err.error
+}
+
func (err *statusCodeError) StatusCode() int {
return err.statusCode
}
diff --git a/basculehttp/error_test.go b/basculehttp/error_test.go
new file mode 100644
index 0000000..c486643
--- /dev/null
+++ b/basculehttp/error_test.go
@@ -0,0 +1,97 @@
+// SPDX-FileCopyrightText: 2024 Comcast Cable Communications Management, LLC
+// SPDX-License-Identifier: Apache-2.0
+
+package basculehttp
+
+import (
+ "errors"
+ "mime"
+ "net/http"
+ "testing"
+
+ "github.com/stretchr/testify/suite"
+ "github.com/xmidt-org/bascule/v1"
+)
+
+type ErrorTestSuite struct {
+ suite.Suite
+}
+
+func (suite *ErrorTestSuite) TestDefaultErrorStatusCoder() {
+ suite.Run("ErrMissingCredentials", func() {
+ suite.Equal(
+ http.StatusUnauthorized,
+ DefaultErrorStatusCoder(nil, bascule.ErrMissingCredentials),
+ )
+ })
+
+ suite.Run("ErrInvalidCredentials", func() {
+ suite.Equal(
+ http.StatusBadRequest,
+ DefaultErrorStatusCoder(nil, bascule.ErrInvalidCredentials),
+ )
+ })
+
+ suite.Run("StatusCoder", func() {
+ suite.Equal(
+ 317,
+ DefaultErrorStatusCoder(
+ nil,
+ UseStatusCode(317, errors.New("unrecognized")),
+ ),
+ )
+ })
+
+ suite.Run("OverrideStatusCode", func() {
+ suite.Equal(
+ http.StatusNotFound,
+ DefaultErrorStatusCoder(
+ nil,
+ UseStatusCode(http.StatusNotFound, bascule.ErrMissingCredentials),
+ ),
+ )
+ })
+
+ suite.Run("Unrecognized", func() {
+ suite.Equal(
+ 0,
+ DefaultErrorStatusCoder(nil, errors.New("unrecognized error")),
+ )
+ })
+}
+
+func (suite *ErrorTestSuite) TestDefaultErrorMarshaler() {
+ contentType, content, marshalErr := DefaultErrorMarshaler(
+ nil,
+ bascule.ErrMissingCredentials,
+ )
+
+ suite.Require().NoError(marshalErr)
+ suite.Equal(bascule.ErrMissingCredentials.Error(), string(content))
+
+ mediaType, _, err := mime.ParseMediaType(contentType)
+ suite.Require().NoError(err)
+ suite.Equal("text/plain", mediaType)
+}
+
+func (suite *ErrorTestSuite) TestUseStatusCode() {
+ var (
+ err = errors.New("an error")
+ wrapperErr = UseStatusCode(511, err)
+ )
+
+ suite.Error(wrapperErr)
+ suite.ErrorIs(wrapperErr, err)
+
+ type statusCoder interface {
+ StatusCode() int
+ }
+
+ var sc statusCoder
+ suite.Require().ErrorAs(wrapperErr, &sc)
+ suite.Equal(511, sc.StatusCode())
+}
+
+func TestError(t *testing.T) {
+ suite.Run(t, new(ErrorTestSuite))
+}
diff --git a/basculehttp/middleware.go b/basculehttp/middleware.go
index 4529cba..df4ed74 100644
--- a/basculehttp/middleware.go
+++ b/basculehttp/middleware.go
@@ -4,6 +4,7 @@
package basculehttp
import (
+ "errors"
"net/http"
"strconv"
@@ -11,6 +12,12 @@ import (
"go.uber.org/multierr"
)
+var (
+ // ErrNoAuthenticator is returned by NewMiddleware to indicate that an Authorizer
+ // was configured without an Authenticator.
+ ErrNoAuthenticator = errors.New("An Authenticator is required if an Authorizer is configured")
+)
+
// MiddlewareOption is a functional option for tailoring a Middleware.
type MiddlewareOption interface {
apply(*Middleware) error
@@ -23,8 +30,6 @@ func (mof middlewareOptionFunc) apply(m *Middleware) error {
}
// WithAuthenticator supplies the Authenticator workflow for the middleware.
-//
-// Note: If no authenticator is supplied, NewMiddeware returns an error.
func WithAuthenticator(authenticator *bascule.Authenticator[*http.Request]) MiddlewareOption {
return UseAuthenticator(authenticator, nil)
}
@@ -81,12 +86,7 @@ func WithChallenges(ch ...Challenge) MiddlewareOption {
// option is omitted or if esc is nil, DefaultErrorStatusCoder is used.
func WithErrorStatusCoder(esc ErrorStatusCoder) MiddlewareOption {
return middlewareOptionFunc(func(m *Middleware) error {
- if esc != nil {
- m.errorStatusCoder = esc
- } else {
- m.errorStatusCoder = DefaultErrorStatusCoder
- }
-
+ m.errorStatusCoder = esc
return nil
})
}
@@ -95,17 +95,33 @@ func WithErrorStatusCoder(esc ErrorStatusCoder) MiddlewareOption {
// option is omitted or if esc is nil, DefaultErrorMarshaler is used.
func WithErrorMarshaler(em ErrorMarshaler) MiddlewareOption {
return middlewareOptionFunc(func(m *Middleware) error {
- if em != nil {
- m.errorMarshaler = em
- } else {
- m.errorMarshaler = DefaultErrorMarshaler
- }
-
+ m.errorMarshaler = em
return nil
})
}
-// Middleware is an immutable configuration that can decorate multiple handlers.
+// Middleware is an immutable HTTP workflow that can decorate multiple handlers.
+//
+// A Middleware can have either or both of an Authenticator, which creates
+// tokens from HTTP requests, and an Authorizer, which approves access to
+// the resource identified by the request. The behavior of a Middleware
+// depends mostly on these two components.
+//
+// If both an authenticator and an authorizer are supplied, the full bascule
+// workflow, including events, is implemented.
+//
+// If an authenticator is supplied without an authorizer, only token creation
+// is implemented. Without an authorizer, it is assumed that all tokens have
+// access to all requests.
+//
+// If no authenticator is supplied, but an authorizer IS supplied, then
+// NewMiddleware returns an error. An authenticator is required in order to
+// create tokens.
+//
+// Finally, if neither an authenticator or an authorizer is supplied,
+// then this Middleware is a noop. Any attempt to decorate handlers will
+// result in those handlers being returned as is. This allows a Middleware
+// to be turned off via configuration.
type Middleware struct {
authenticator *bascule.Authenticator[*http.Request]
authorizer *bascule.Authorizer[*http.Request]
@@ -117,16 +133,37 @@ type Middleware struct {
// NewMiddleware creates an immutable Middleware instance from a supplied set of options.
// No options will result in a Middleware with default behavior.
+//
+// If no authenticator is configured, but an authorizer is, this function returns
+// ErrNoAuthenticator.
+//
+// Note that if no workflow components are configured, i.e. neither an authenticator nor
+// an authorizer are supplied, then the returned Middleware is a noop.
func NewMiddleware(opts ...MiddlewareOption) (m *Middleware, err error) {
- m = &Middleware{
- errorStatusCoder: DefaultErrorStatusCoder,
- errorMarshaler: DefaultErrorMarshaler,
- }
-
+ m = new(Middleware)
for _, o := range opts {
err = multierr.Append(err, o.apply(m))
}
+ switch {
+ case err != nil:
+ m = nil
+
+ case m.authenticator == nil && m.authorizer != nil:
+ err = multierr.Append(err, ErrNoAuthenticator)
+ m = nil
+
+ default:
+ // cleanup after the options run
+ if m.errorStatusCoder == nil {
+ m.errorStatusCoder = DefaultErrorStatusCoder
+ }
+
+ if m.errorMarshaler == nil {
+ m.errorMarshaler = DefaultErrorMarshaler
+ }
+ }
+
return
}
@@ -137,6 +174,12 @@ func (m *Middleware) Then(protected http.Handler) http.Handler {
protected = http.DefaultServeMux
}
+ // no point in decorating if there's no workflow
+ // this also allows a Middleware to be turned off via configuration
+ if m.authenticator == nil && m.authorizer == nil {
+ return protected
+ }
+
return &frontDoor{
Middleware: m,
protected: protected,
@@ -156,7 +199,7 @@ func (m *Middleware) ThenFunc(protected http.HandlerFunc) http.Handler {
// The response is always a text/plain representation of the error.
func (m *Middleware) writeRawError(response http.ResponseWriter, err error) {
response.WriteHeader(http.StatusInternalServerError)
- response.Header().Set("Content-Type", "text/plain")
+ response.Header().Set("Content-Type", "text/plain; charset=utf-8")
errBody := []byte(err.Error())
response.Header().Set("Content-Length", strconv.Itoa(len(errBody)))
@@ -210,6 +253,9 @@ type frontDoor struct {
// ServeHTTP implements the bascule workflow, using the configured middleware.
func (fd *frontDoor) ServeHTTP(response http.ResponseWriter, request *http.Request) {
ctx := request.Context()
+
+ // an authenticator is is required if we are decorating
+ // if the authenticator was nil, a frontDoor won't get created
token, err := fd.authenticator.Authenticate(ctx, request)
if err != nil {
// by default, failing to parse a token is a malformed request
diff --git a/basculehttp/middleware_examples_test.go b/basculehttp/middleware_examples_test.go
index 742844c..689e861 100644
--- a/basculehttp/middleware_examples_test.go
+++ b/basculehttp/middleware_examples_test.go
@@ -15,7 +15,7 @@ import (
// just basic auth.
func ExampleMiddleware_basicauth() {
tp, _ := NewAuthorizationParser(
- WithScheme(SchemeBasic, BasicTokenParser{}),
+ WithBasic(),
)
m, _ := NewMiddleware(
diff --git a/basculehttp/middleware_test.go b/basculehttp/middleware_test.go
new file mode 100644
index 0000000..486b103
--- /dev/null
+++ b/basculehttp/middleware_test.go
@@ -0,0 +1,465 @@
+// SPDX-FileCopyrightText: 2024 Comcast Cable Communications Management, LLC
+// SPDX-License-Identifier: Apache-2.0
+
+package basculehttp
+
+import (
+ "context"
+ "errors"
+ "mime"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+
+ "github.com/stretchr/testify/suite"
+ "github.com/xmidt-org/bascule/v1"
+)
+
+type MiddlewareTestSuite struct {
+ suite.Suite
+
+ expectedPrincipal string
+ expectedPassword string
+ expectedToken bascule.Token
+}
+
+func (suite *MiddlewareTestSuite) SetupSuite() {
+ suite.expectedPrincipal = "testPrincipal"
+ suite.expectedPassword = "test_password"
+ suite.expectedToken = basicToken{
+ userName: suite.expectedPrincipal,
+ password: suite.expectedPassword,
+ }
+}
+
+// newRequest creates a standardized test request, devoid of any authorization.
+func (suite *MiddlewareTestSuite) newRequest() *http.Request {
+ return httptest.NewRequest("GET", "/test", nil)
+}
+
+// newBasicAuthRequest creates a new test request configured with valid basic auth.
+func (suite *MiddlewareTestSuite) newBasicAuthRequest() *http.Request {
+ request := suite.newRequest()
+ request.SetBasicAuth(suite.expectedPrincipal, suite.expectedPassword)
+ return request
+}
+
+// assertBasicAuthRequest asserts that the given request matches this suite's expectations.
+func (suite *MiddlewareTestSuite) assertBasicAuthRequest(request *http.Request) {
+ suite.Require().NotNil(request)
+ suite.Equal("GET", request.Method)
+ suite.Equal("/test", request.URL.String())
+}
+
+// assertBasicAuthToken asserts that the token matches this suite's expectations.
+func (suite *MiddlewareTestSuite) assertBasicAuthToken(token bascule.Token) {
+ suite.Require().NotNil(token)
+ suite.Equal(suite.expectedPrincipal, token.Principal())
+ suite.Require().Implements((*BasicToken)(nil), token)
+ suite.Equal(suite.expectedPrincipal, token.(BasicToken).UserName())
+ suite.Equal(suite.expectedPassword, token.(BasicToken).Password())
+}
+
+// newAuthorizationParser creates an AuthorizationParser that is expected to be valid.
+// Assertions as to validity are made prior to returning.
+func (suite *MiddlewareTestSuite) newAuthorizationParser(opts ...AuthorizationParserOption) *AuthorizationParser {
+ ap, err := NewAuthorizationParser(opts...)
+ suite.Require().NoError(err)
+ suite.Require().NotNil(ap)
+ return ap
+}
+
+// newAuthenticator creates a bascule.Authenticator that is expected to be valid.
+// Assertions as to validity are made prior to returning.
+func (suite *MiddlewareTestSuite) newAuthenticator(opts ...bascule.AuthenticatorOption[*http.Request]) *bascule.Authenticator[*http.Request] {
+ a, err := NewAuthenticator(opts...)
+ suite.Require().NoError(err)
+ suite.Require().NotNil(a)
+ return a
+}
+
+// newAuthorizer creates a bascule.Authorizer that is expected to be valid.
+// Assertions as to validity are made prior to returning.
+func (suite *MiddlewareTestSuite) newAuthorizer(opts ...bascule.AuthorizerOption[*http.Request]) *bascule.Authorizer[*http.Request] {
+ a, err := NewAuthorizer(opts...)
+ suite.Require().NoError(err)
+ suite.Require().NotNil(a)
+ return a
+}
+
+// newMiddleware creates a Middleware that is expected to be valid.
+// Assertions as to validity are made prior to returning.
+func (suite *MiddlewareTestSuite) newMiddleware(opts ...MiddlewareOption) *Middleware {
+ m, err := NewMiddleware(opts...)
+ suite.Require().NoError(err)
+ suite.Require().NotNil(m)
+ return m
+}
+
+// serveHTTPFunc is a standard, non-error handler function that sets the normalResponseCode.
+func (suite *MiddlewareTestSuite) serveHTTPFunc(response http.ResponseWriter, _ *http.Request) {
+ response.Header().Set("Content-Type", "application/octet-stream")
+ response.WriteHeader(299)
+ response.Write([]byte("normal response"))
+}
+
+// assertNormalResponse asserts that the Middleware allowed the response from serveHTTPFunc.
+func (suite *MiddlewareTestSuite) assertNormalResponse(response *httptest.ResponseRecorder) {
+ suite.Equal(299, response.Code)
+ suite.Equal("application/octet-stream", response.HeaderMap.Get("Content-Type"))
+ suite.Equal("normal response", response.Body.String())
+}
+
+// serveHTTPNoCall is a handler function that should be blocked by the Middleware.
+func (suite *MiddlewareTestSuite) serveHTTPNoCall(http.ResponseWriter, *http.Request) {
+ suite.Fail("The handler should not have been called")
+}
+
+func (suite *MiddlewareTestSuite) assertChallenge(c Challenge, err error) Challenge {
+ suite.Require().NoError(err)
+ return c
+}
+
+func (suite *MiddlewareTestSuite) TestUseAuthenticatorError() {
+ m, err := NewMiddleware(
+ UseAuthenticator(
+ bascule.NewAuthenticator[*http.Request](), // no token parsers
+ ),
+ )
+
+ suite.ErrorIs(err, bascule.ErrNoTokenParsers)
+ suite.Nil(m)
+}
+
+func (suite *MiddlewareTestSuite) TestUseAuthorizerError() {
+ expectedErr := errors.New("expected")
+ m, err := NewMiddleware(
+ UseAuthenticator(
+ NewAuthenticator(
+ bascule.WithTokenParsers(
+ suite.newAuthorizationParser(WithBasic()),
+ ),
+ ),
+ ),
+ UseAuthorizer(nil, expectedErr),
+ )
+
+ suite.ErrorIs(err, expectedErr)
+ suite.Nil(m)
+}
+
+func (suite *MiddlewareTestSuite) TestNoAuthenticatorWithAuthorizer() {
+ m, err := NewMiddleware(
+ WithAuthorizer(
+ suite.newAuthorizer(),
+ ),
+ )
+
+ suite.Nil(m)
+ suite.ErrorIs(err, ErrNoAuthenticator)
+}
+
+func (suite *MiddlewareTestSuite) TestThen() {
+ suite.Run("NilHandler", func() {
+ var (
+ m = suite.newMiddleware(
+ WithAuthenticator(
+ suite.newAuthenticator(
+ bascule.WithTokenParsers(
+ suite.newAuthorizationParser(WithBasic()),
+ ),
+ ),
+ ),
+ )
+
+ h = m.Then(nil)
+
+ response = httptest.NewRecorder()
+ request = suite.newBasicAuthRequest()
+ )
+
+ h.ServeHTTP(response, request)
+ suite.Equal(http.StatusNotFound, response.Code) // use the unconfigured http.DefaultServeMux
+ })
+
+ suite.Run("NoDecoration", func() {
+ var (
+ m = suite.newMiddleware()
+ h = m.Then(http.HandlerFunc(
+ suite.serveHTTPFunc,
+ ))
+
+ response = httptest.NewRecorder()
+ request = suite.newRequest()
+ )
+
+ h.ServeHTTP(response, request)
+ suite.assertNormalResponse(response)
+ })
+}
+
+func (suite *MiddlewareTestSuite) TestThenFunc() {
+ suite.Run("NilHandlerFunc", func() {
+ var (
+ m = suite.newMiddleware(
+ WithAuthenticator(
+ suite.newAuthenticator(
+ bascule.WithTokenParsers(
+ suite.newAuthorizationParser(WithBasic()),
+ ),
+ ),
+ ),
+ )
+
+ h = m.ThenFunc(nil)
+
+ response = httptest.NewRecorder()
+ request = suite.newBasicAuthRequest()
+ )
+
+ h.ServeHTTP(response, request)
+ suite.Equal(http.StatusNotFound, response.Code) // use the unconfigured http.DefaultServeMux
+ })
+}
+
+func (suite *MiddlewareTestSuite) TestCustomErrorRendering() {
+ var (
+ m = suite.newMiddleware(
+ WithAuthenticator(
+ suite.newAuthenticator(
+ bascule.WithTokenParsers(
+ suite.newAuthorizationParser(WithBasic()),
+ ),
+ ),
+ ),
+ WithErrorStatusCoder(
+ func(request *http.Request, err error) int {
+ suite.Equal(request.URL.String(), "/test")
+ return 567
+ },
+ ),
+ WithErrorMarshaler(
+ func(request *http.Request, err error) (contentType string, content []byte, marshalErr error) {
+ contentType = "text/xml"
+ content = []byte("")
+ return
+ },
+ ),
+ )
+
+ response = httptest.NewRecorder()
+ request = suite.newRequest()
+
+ h = m.ThenFunc(suite.serveHTTPNoCall)
+ )
+
+ h.ServeHTTP(response, request)
+ suite.Equal(567, response.Code)
+ suite.Equal("text/xml", response.HeaderMap.Get("Content-Type"))
+ suite.Equal("", response.Body.String())
+}
+
+func (suite *MiddlewareTestSuite) TestMarshalError() {
+ var (
+ marshalErr = errors.New("expected marshal error")
+
+ m = suite.newMiddleware(
+ WithAuthenticator(
+ suite.newAuthenticator(
+ bascule.WithTokenParsers(
+ suite.newAuthorizationParser(WithBasic()),
+ ),
+ ),
+ ),
+ WithErrorStatusCoder(
+ func(request *http.Request, err error) int {
+ suite.Equal(request.URL.String(), "/test")
+ return 567
+ },
+ ),
+ WithErrorMarshaler(
+ func(request *http.Request, err error) (string, []byte, error) {
+ return "", nil, marshalErr
+ },
+ ),
+ )
+
+ response = httptest.NewRecorder()
+ request = suite.newRequest()
+
+ h = m.ThenFunc(suite.serveHTTPNoCall)
+ )
+
+ h.ServeHTTP(response, request)
+ suite.Equal(http.StatusInternalServerError, response.Code)
+
+ mediaType, _, err := mime.ParseMediaType(response.HeaderMap.Get("Content-Type"))
+ suite.Require().NoError(err)
+ suite.Equal("text/plain", mediaType)
+ suite.Equal(marshalErr.Error(), response.Body.String())
+}
+
+func (suite *MiddlewareTestSuite) testBasicAuthSuccess() {
+ var (
+ authenticateEvent = false
+ authorizeEvent = false
+
+ m = suite.newMiddleware(
+ WithAuthenticator(
+ suite.newAuthenticator(
+ bascule.WithTokenParsers(
+ suite.newAuthorizationParser(WithBasic()),
+ ),
+ bascule.WithAuthenticateListenerFuncs(
+ func(e bascule.AuthenticateEvent[*http.Request]) {
+ suite.Equal("/test", e.Source.URL.String())
+ suite.Require().NotNil(e.Token)
+ suite.NoError(e.Err)
+ authenticateEvent = true
+ },
+ ),
+ ),
+ ),
+ WithAuthorizer(
+ suite.newAuthorizer(
+ bascule.WithApproverFuncs(
+ func(_ context.Context, request *http.Request, token bascule.Token) error {
+ suite.assertBasicAuthRequest(request)
+ suite.assertBasicAuthToken(token)
+ return nil
+ },
+ ),
+ bascule.WithAuthorizeListenerFuncs(
+ func(e bascule.AuthorizeEvent[*http.Request]) {
+ suite.assertBasicAuthRequest(e.Resource)
+ suite.assertBasicAuthToken(e.Token)
+ authorizeEvent = true
+ },
+ ),
+ ),
+ ),
+ )
+
+ response = httptest.NewRecorder()
+ request = suite.newBasicAuthRequest()
+
+ h = m.ThenFunc(suite.serveHTTPFunc)
+ )
+
+ h.ServeHTTP(response, request)
+ suite.assertNormalResponse(response)
+ suite.True(authenticateEvent)
+ suite.True(authorizeEvent)
+}
+
+func (suite *MiddlewareTestSuite) testBasicAuthChallenge() {
+ var (
+ authenticateEvent = false
+
+ m = suite.newMiddleware(
+ WithAuthenticator(
+ suite.newAuthenticator(
+ bascule.WithTokenParsers(
+ suite.newAuthorizationParser(WithBasic()),
+ ),
+ bascule.WithAuthenticateListenerFuncs(
+ func(e bascule.AuthenticateEvent[*http.Request]) {
+ suite.assertBasicAuthRequest(e.Source)
+ suite.ErrorIs(bascule.ErrMissingCredentials, e.Err)
+ suite.Nil(e.Token)
+ authenticateEvent = true
+ },
+ ),
+ ),
+ ),
+ WithChallenges(
+ suite.assertChallenge(NewBasicChallenge("test", true)),
+ ),
+ )
+
+ response = httptest.NewRecorder()
+ request = suite.newRequest()
+
+ h = m.ThenFunc(suite.serveHTTPNoCall)
+ )
+
+ h.ServeHTTP(response, request)
+ suite.Equal(http.StatusUnauthorized, response.Code)
+
+ suite.Equal(
+ `Basic realm="test" charset="UTF-8"`,
+ response.HeaderMap.Get(WWWAuthenticateHeader),
+ )
+
+ suite.True(authenticateEvent)
+}
+
+func (suite *MiddlewareTestSuite) testBasicAuthInvalid() {
+ var (
+ m = suite.newMiddleware(
+ WithAuthenticator(
+ suite.newAuthenticator(
+ bascule.WithTokenParsers(
+ suite.newAuthorizationParser(WithBasic()),
+ ),
+ ),
+ ),
+ )
+
+ response = httptest.NewRecorder()
+ request = suite.newRequest()
+
+ h = m.ThenFunc(suite.serveHTTPNoCall)
+ )
+
+ request.Header.Set("Authorization", "Basic this is most definitely not a valid basic auth string")
+ h.ServeHTTP(response, request)
+ suite.Equal(http.StatusBadRequest, response.Code)
+}
+
+func (suite *MiddlewareTestSuite) testBasicAuthAuthorizerError() {
+ var (
+ expectedErr = errors.New("expected error")
+
+ m = suite.newMiddleware(
+ WithAuthenticator(
+ suite.newAuthenticator(
+ bascule.WithTokenParsers(
+ suite.newAuthorizationParser(WithBasic()),
+ ),
+ ),
+ ),
+ WithAuthorizer(
+ suite.newAuthorizer(
+ bascule.WithApproverFuncs(
+ func(_ context.Context, resource *http.Request, token bascule.Token) error {
+ suite.assertBasicAuthRequest(resource)
+ suite.assertBasicAuthToken(token)
+ return expectedErr
+ },
+ ),
+ ),
+ ),
+ )
+
+ response = httptest.NewRecorder()
+ request = suite.newBasicAuthRequest()
+
+ h = m.ThenFunc(suite.serveHTTPNoCall)
+ )
+
+ h.ServeHTTP(response, request)
+ suite.Equal(http.StatusForbidden, response.Code)
+ suite.Equal(expectedErr.Error(), response.Body.String())
+}
+
+func (suite *MiddlewareTestSuite) TestBasicAuth() {
+ suite.Run("Success", suite.testBasicAuthSuccess)
+ suite.Run("Challenge", suite.testBasicAuthChallenge)
+ suite.Run("Invalid", suite.testBasicAuthInvalid)
+ suite.Run("AuthorizerError", suite.testBasicAuthAuthorizerError)
+}
+
+func TestMiddleware(t *testing.T) {
+ suite.Run(t, new(MiddlewareTestSuite))
+}