From 54f6f82e44d31ff37deea781cede21dadb5fc768 Mon Sep 17 00:00:00 2001 From: johnabass Date: Tue, 9 Jul 2024 14:46:00 -0700 Subject: [PATCH 1/4] genericize the source object where tokens are taken from --- basculehttp/accessor.go | 42 -------------------------- basculehttp/credentials.go | 60 ++++++++++++++++++++++++++++++++++++-- basculehttp/middleware.go | 27 +++++------------ credentials.go | 17 ++++++----- token.go | 29 +++++++++--------- 5 files changed, 87 insertions(+), 88 deletions(-) diff --git a/basculehttp/accessor.go b/basculehttp/accessor.go index c036e7c..3b1fa10 100644 --- a/basculehttp/accessor.go +++ b/basculehttp/accessor.go @@ -5,50 +5,8 @@ package basculehttp import ( "net/http" - "strings" ) -const ( - // DefaultAuthorizationHeader is the name of the header used by default to obtain - // the raw credentials. - DefaultAuthorizationHeader = "Authorization" -) - -// DuplicateHeaderError indicates that an HTTP header had more than one value -// when only one value was expected. -type DuplicateHeaderError struct { - // Header is the name of the duplicate header. - Header string -} - -func (err *DuplicateHeaderError) Error() string { - var o strings.Builder - o.WriteString(`Duplicate header: "`) - o.WriteString(err.Header) - o.WriteString(`"`) - return o.String() -} - -// MissingHeaderError indicates that an expected HTTP header is missing. -type MissingHeaderError struct { - // Header is the name of the missing header. - Header string -} - -func (err *MissingHeaderError) Error() string { - var o strings.Builder - o.WriteString(`Missing header: "`) - o.WriteString(err.Header) - o.WriteString(`"`) - return o.String() -} - -// StatusCode returns http.StatusUnauthorized, as the request carries -// no authorization in it. -func (err *MissingHeaderError) StatusCode() int { - return http.StatusUnauthorized -} - // Accessor is the strategy for obtaining credentials from an HTTP request. type Accessor interface { // GetCredentials returns the raw credentials from a request. diff --git a/basculehttp/credentials.go b/basculehttp/credentials.go index e654eee..c9661cd 100644 --- a/basculehttp/credentials.go +++ b/basculehttp/credentials.go @@ -4,11 +4,54 @@ package basculehttp import ( + "context" + "net/http" "strings" "github.com/xmidt-org/bascule/v1" ) +const ( + // DefaultAuthorizationHeader is the name of the header used by default to obtain + // the raw credentials. + DefaultAuthorizationHeader = "Authorization" +) + +// DuplicateHeaderError indicates that an HTTP header had more than one value +// when only one value was expected. +type DuplicateHeaderError struct { + // Header is the name of the duplicate header. + Header string +} + +func (err *DuplicateHeaderError) Error() string { + var o strings.Builder + o.WriteString(`Duplicate header: "`) + o.WriteString(err.Header) + o.WriteString(`"`) + return o.String() +} + +// MissingHeaderError indicates that an expected HTTP header is missing. +type MissingHeaderError struct { + // Header is the name of the missing header. + Header string +} + +func (err *MissingHeaderError) Error() string { + var o strings.Builder + o.WriteString(`Missing header: "`) + o.WriteString(err.Header) + o.WriteString(`"`) + return o.String() +} + +// StatusCode returns http.StatusUnauthorized, as the request carries +// no authorization in it. +func (err *MissingHeaderError) StatusCode() int { + return http.StatusUnauthorized +} + // fastIsSpace tests an ASCII byte to see if it's whitespace. // HTTP headers are restricted to US-ASCII, so we don't need // the full unicode stack. @@ -16,8 +59,19 @@ func fastIsSpace(b byte) bool { return b == ' ' || b == '\t' || b == '\n' || b == '\r' || b == '\v' || b == '\f' } -var defaultCredentialsParser bascule.CredentialsParser = bascule.CredentialsParserFunc( - func(raw string) (c bascule.Credentials, err error) { +// DefaultCredentialsParser is the default algorithm used to produce HTTP credentials +// from a request. +type DefaultCredentialsParser struct { + // HeaderName is the name of the authorization header. If unset, + // DefaultAuthorizationHeader is used. + HeaderName string +} + +func (dcp DefaultCredentialsParser) Parse(_ context.Context, source *http.Request) (bascule.Credentials, error) { +} + +var defaultCredentialsParser CredentialsParser = bascule.CredentialsParserFunc[*http.Request]( + func(ctx context.Context, source *http.Request) (c bascule.Credentials, err error) { // format is // the code is strict: it requires no leading or trailing space // and exactly one (1) space as a separator. @@ -40,6 +94,6 @@ var defaultCredentialsParser bascule.CredentialsParser = bascule.CredentialsPars // DefaultCredentialsParser returns the default strategy for parsing credentials. This // builtin strategy is very strict on whitespace. The format must correspond exactly // to the format specified in https://www.rfc-editor.org/rfc/rfc7235. -func DefaultCredentialsParser() bascule.CredentialsParser { +func DefaultCredentialsParser() bascule.CredentialsParser[*http.Request] { return defaultCredentialsParser } diff --git a/basculehttp/middleware.go b/basculehttp/middleware.go index de53faf..e5121c4 100644 --- a/basculehttp/middleware.go +++ b/basculehttp/middleware.go @@ -12,6 +12,10 @@ import ( "go.uber.org/multierr" ) +type CredentialsParser bascule.CredentialsParser[*http.Request] +type TokenParser bascule.TokenParser[*http.Request] +type TokenParsers bascule.TokenParsers[*http.Request] + // MiddlewareOption is a functional option for tailoring a Middleware. type MiddlewareOption interface { apply(*Middleware) error @@ -23,23 +27,9 @@ func (mof middlewareOptionFunc) apply(m *Middleware) error { return mof(m) } -// WithAccessor configures a credentials Accessor for this Middleware. If not supplied -// or if the supplied Accessor is nil, DefaultAccessor() is used. -func WithAccessor(a Accessor) MiddlewareOption { - return middlewareOptionFunc(func(m *Middleware) error { - if a != nil { - m.accessor = a - } else { - m.accessor = DefaultAccessor() - } - - return nil - }) -} - // WithCredentialsParser configures a credentials parser for this Middleware. If not supplied // or if the supplied CredentialsParser is nil, DefaultCredentialsParser() is used. -func WithCredentialsParser(cp bascule.CredentialsParser) MiddlewareOption { +func WithCredentialsParser(cp CredentialsParser) MiddlewareOption { return middlewareOptionFunc(func(m *Middleware) error { if cp != nil { m.credentialsParser = cp @@ -55,7 +45,7 @@ func WithCredentialsParser(cp bascule.CredentialsParser) MiddlewareOption { // already been registered, the given parser will replace that registration. // // The parser cannot be nil. -func WithTokenParser(scheme bascule.Scheme, tp bascule.TokenParser) MiddlewareOption { +func WithTokenParser(scheme bascule.Scheme, tp TokenParser) MiddlewareOption { return middlewareOptionFunc(func(m *Middleware) error { m.tokenParsers.Register(scheme, tp) return nil @@ -126,9 +116,8 @@ func WithErrorMarshaler(em ErrorMarshaler) MiddlewareOption { // Middleware is an immutable configuration that can decorate multiple handlers. type Middleware struct { - accessor Accessor - credentialsParser bascule.CredentialsParser - tokenParsers bascule.TokenParsers + credentialsParser CredentialsParser + tokenParsers TokenParsers authentication bascule.Validators authorization bascule.Authorizers[*http.Request] challenges Challenges diff --git a/credentials.go b/credentials.go index c7f689f..eeda081 100644 --- a/credentials.go +++ b/credentials.go @@ -3,6 +3,8 @@ package bascule +import "context" + // Scheme represents how a security token should be parsed. For HTTP, examples // of a scheme are "Bearer" and "Basic". type Scheme string @@ -16,16 +18,15 @@ type Credentials struct { Value string } -// CredentialsParser produces Credentials from their serialized form. -type CredentialsParser interface { - // Parse parses the raw, marshaled version of credentials and - // returns the Credentials object. - Parse(raw string) (Credentials, error) +// CredentialsParser produces Credentials from a data source. +type CredentialsParser[S any] interface { + // Parse extracts Credentials from a Source data object. + Parse(ctx context.Context, source S) (Credentials, error) } // CredentialsParserFunc is a function type that implements CredentialsParser. -type CredentialsParserFunc func(string) (Credentials, error) +type CredentialsParserFunc[S any] func(context.Context, S) (Credentials, error) -func (cpf CredentialsParserFunc) Parse(raw string) (Credentials, error) { - return cpf(raw) +func (cpf CredentialsParserFunc[S]) Parse(ctx context.Context, source S) (Credentials, error) { + return cpf(ctx, source) } diff --git a/token.go b/token.go index 9a2c2c6..884eaf0 100644 --- a/token.go +++ b/token.go @@ -15,31 +15,28 @@ type Token interface { Principal() string } -// TokenParser produces tokens from credentials. -type TokenParser interface { - // Parse turns a Credentials into a token. This method may validate parts - // of the credential's value, but should not perform any authentication itself. - // - // Some token parsers may interact with external systems, such as databases. The supplied - // context should be passed to any calls that might need to honor cancelation semantics. - Parse(context.Context, Credentials) (Token, error) +// TokenParser produces tokens from credentials. The original source S of the credentials +// are made available to the parser. +type TokenParser[S any] interface { + // Parse extracts a Token from a set of credentials. + Parse(ctx context.Context, source S, c Credentials) (Token, error) } // TokenParserFunc is a closure type that implements TokenParser. -type TokenParserFunc func(context.Context, Credentials) (Token, error) +type TokenParserFunc[S any] func(context.Context, S, Credentials) (Token, error) -func (tpf TokenParserFunc) Parse(ctx context.Context, c Credentials) (Token, error) { - return tpf(ctx, c) +func (tpf TokenParserFunc[S]) Parse(ctx context.Context, source S, c Credentials) (Token, error) { + return tpf(ctx, source, c) } // TokenParsers is a registry of parsers based on credential schemes. // The zero value of this type is valid and ready to use. -type TokenParsers map[Scheme]TokenParser +type TokenParsers[S any] map[Scheme]TokenParser[S] // Register adds or replaces the parser associated with the given scheme. -func (tp *TokenParsers) Register(scheme Scheme, p TokenParser) { +func (tp *TokenParsers[S]) Register(scheme Scheme, p TokenParser[S]) { if *tp == nil { - *tp = make(TokenParsers) + *tp = make(TokenParsers[S]) } (*tp)[scheme] = p @@ -47,9 +44,9 @@ func (tp *TokenParsers) Register(scheme Scheme, p TokenParser) { // Parse chooses a TokenParser based on the Scheme and invokes that // parser. If the credential scheme is unsupported, an error is returned. -func (tp TokenParsers) Parse(ctx context.Context, c Credentials) (t Token, err error) { +func (tp TokenParsers[S]) Parse(ctx context.Context, source S, c Credentials) (t Token, err error) { if p, ok := tp[c.Scheme]; ok { - t, err = p.Parse(ctx, c) + t, err = p.Parse(ctx, source, c) } else { err = &UnsupportedSchemeError{ Scheme: c.Scheme, From 35462933a339d8c98bf42f5be54c8a61beff5d6f Mon Sep 17 00:00:00 2001 From: johnabass Date: Tue, 9 Jul 2024 15:05:58 -0700 Subject: [PATCH 2/4] allow a generic source type; simplify the credentials parsing --- basculehttp/basic.go | 3 +- basculehttp/credentials.go | 70 ++++++++++++++++++++------------- basculehttp/credentials_test.go | 17 ++++++-- basculehttp/middleware.go | 26 ++++-------- basculehttp/token.go | 10 +++-- basculejwt/token.go | 9 +++-- credentials_test.go | 5 ++- token_test.go | 20 +++++----- 8 files changed, 90 insertions(+), 70 deletions(-) diff --git a/basculehttp/basic.go b/basculehttp/basic.go index 9dfbe15..43e69b2 100644 --- a/basculehttp/basic.go +++ b/basculehttp/basic.go @@ -6,6 +6,7 @@ package basculehttp import ( "context" "encoding/base64" + "net/http" "strings" "github.com/xmidt-org/bascule/v1" @@ -35,7 +36,7 @@ func (err *InvalidBasicAuthError) Error() string { type basicTokenParser struct{} -func (btp basicTokenParser) Parse(_ context.Context, c bascule.Credentials) (t bascule.Token, err error) { +func (btp basicTokenParser) Parse(_ context.Context, _ *http.Request, c bascule.Credentials) (t bascule.Token, err error) { var decoded []byte decoded, err = base64.StdEncoding.DecodeString(c.Value) if err != nil { diff --git a/basculehttp/credentials.go b/basculehttp/credentials.go index c9661cd..88a8f6d 100644 --- a/basculehttp/credentials.go +++ b/basculehttp/credentials.go @@ -60,40 +60,54 @@ func fastIsSpace(b byte) bool { } // DefaultCredentialsParser is the default algorithm used to produce HTTP credentials -// from a request. +// from a source request. type DefaultCredentialsParser struct { - // HeaderName is the name of the authorization header. If unset, + // Header is the name of the authorization header. If unset, // DefaultAuthorizationHeader is used. - HeaderName string -} + Header string -func (dcp DefaultCredentialsParser) Parse(_ context.Context, source *http.Request) (bascule.Credentials, error) { + // ErrorOnDuplicate controls whether an error is returned if more + // than one Header is found in the request. By default, this is false. + ErrorOnDuplicate bool } -var defaultCredentialsParser CredentialsParser = bascule.CredentialsParserFunc[*http.Request]( - func(ctx context.Context, source *http.Request) (c bascule.Credentials, err error) { - // format is - // the code is strict: it requires no leading or trailing space - // and exactly one (1) space as a separator. - scheme, value, found := strings.Cut(raw, " ") - if found && len(scheme) > 0 && !fastIsSpace(value[0]) && !fastIsSpace(value[len(value)-1]) { - c = bascule.Credentials{ - Scheme: bascule.Scheme(scheme), - Value: value, - } - } else { - err = &bascule.BadCredentialsError{ - Raw: raw, - } +func (dcp DefaultCredentialsParser) Parse(_ context.Context, source *http.Request) (c bascule.Credentials, err error) { + header := dcp.Header + if len(header) == 0 { + header = DefaultAuthorizationHeader + } + + var raw string + values := source.Header.Values(header) + switch { + case len(values) == 0: + err = &MissingHeaderError{ + Header: header, } - return - }, -) + case len(values) == 1 || !dcp.ErrorOnDuplicate: + raw = values[0] + + default: + err = &DuplicateHeaderError{ + Header: header, + } + } + + // format is + // the code is strict: it requires no leading or trailing space + // and exactly one (1) space as a separator. + scheme, credValue, found := strings.Cut(raw, " ") + if found && len(scheme) > 0 && !fastIsSpace(credValue[0]) && !fastIsSpace(credValue[len(credValue)-1]) { + c = bascule.Credentials{ + Scheme: bascule.Scheme(scheme), + Value: credValue, + } + } else { + err = &bascule.BadCredentialsError{ + Raw: raw, + } + } -// DefaultCredentialsParser returns the default strategy for parsing credentials. This -// builtin strategy is very strict on whitespace. The format must correspond exactly -// to the format specified in https://www.rfc-editor.org/rfc/rfc7235. -func DefaultCredentialsParser() bascule.CredentialsParser[*http.Request] { - return defaultCredentialsParser + return } diff --git a/basculehttp/credentials_test.go b/basculehttp/credentials_test.go index 7b2e82c..dabfde1 100644 --- a/basculehttp/credentials_test.go +++ b/basculehttp/credentials_test.go @@ -4,6 +4,9 @@ package basculehttp import ( + "context" + "net/http" + "net/http/httptest" "testing" "github.com/stretchr/testify/suite" @@ -14,6 +17,12 @@ type CredentialsTestSuite struct { suite.Suite } +func (suite *CredentialsTestSuite) newDefaultSource(value string) *http.Request { + r := httptest.NewRequest("GET", "/", nil) + r.Header.Set(DefaultAuthorizationHeader, value) + return r +} + func (suite *CredentialsTestSuite) testDefaultCredentialsParserSuccess() { const ( expectedScheme bascule.Scheme = "Test" @@ -26,10 +35,10 @@ func (suite *CredentialsTestSuite) testDefaultCredentialsParserSuccess() { for _, testCase := range testCases { suite.Run(testCase, func() { - dp := DefaultCredentialsParser() + dp := DefaultCredentialsParser{} suite.Require().NotNil(dp) - creds, err := dp.Parse(testCase) + creds, err := dp.Parse(context.Background(), suite.newDefaultSource(testCase)) suite.Require().NoError(err) suite.Equal( bascule.Credentials{ @@ -55,10 +64,10 @@ func (suite *CredentialsTestSuite) testDefaultCredentialsParserFailure() { for _, testCase := range testCases { suite.Run(testCase, func() { - dp := DefaultCredentialsParser() + dp := DefaultCredentialsParser{} suite.Require().NotNil(dp) - creds, err := dp.Parse(testCase) + creds, err := dp.Parse(context.Background(), suite.newDefaultSource(testCase)) suite.Require().Error(err) suite.Equal(bascule.Credentials{}, creds) diff --git a/basculehttp/middleware.go b/basculehttp/middleware.go index e5121c4..6354dea 100644 --- a/basculehttp/middleware.go +++ b/basculehttp/middleware.go @@ -12,10 +12,6 @@ import ( "go.uber.org/multierr" ) -type CredentialsParser bascule.CredentialsParser[*http.Request] -type TokenParser bascule.TokenParser[*http.Request] -type TokenParsers bascule.TokenParsers[*http.Request] - // MiddlewareOption is a functional option for tailoring a Middleware. type MiddlewareOption interface { apply(*Middleware) error @@ -29,12 +25,12 @@ func (mof middlewareOptionFunc) apply(m *Middleware) error { // WithCredentialsParser configures a credentials parser for this Middleware. If not supplied // or if the supplied CredentialsParser is nil, DefaultCredentialsParser() is used. -func WithCredentialsParser(cp CredentialsParser) MiddlewareOption { +func WithCredentialsParser(cp bascule.CredentialsParser[*http.Request]) MiddlewareOption { return middlewareOptionFunc(func(m *Middleware) error { if cp != nil { m.credentialsParser = cp } else { - m.credentialsParser = DefaultCredentialsParser() + m.credentialsParser = DefaultCredentialsParser{} } return nil @@ -45,7 +41,7 @@ func WithCredentialsParser(cp CredentialsParser) MiddlewareOption { // already been registered, the given parser will replace that registration. // // The parser cannot be nil. -func WithTokenParser(scheme bascule.Scheme, tp TokenParser) MiddlewareOption { +func WithTokenParser(scheme bascule.Scheme, tp bascule.TokenParser[*http.Request]) MiddlewareOption { return middlewareOptionFunc(func(m *Middleware) error { m.tokenParsers.Register(scheme, tp) return nil @@ -116,8 +112,8 @@ func WithErrorMarshaler(em ErrorMarshaler) MiddlewareOption { // Middleware is an immutable configuration that can decorate multiple handlers. type Middleware struct { - credentialsParser CredentialsParser - tokenParsers TokenParsers + credentialsParser bascule.CredentialsParser[*http.Request] + tokenParsers bascule.TokenParsers[*http.Request] authentication bascule.Validators authorization bascule.Authorizers[*http.Request] challenges Challenges @@ -130,8 +126,7 @@ type Middleware struct { // No options will result in a Middleware with default behavior. func NewMiddleware(opts ...MiddlewareOption) (m *Middleware, err error) { m = &Middleware{ - accessor: DefaultAccessor(), - credentialsParser: DefaultCredentialsParser(), + credentialsParser: DefaultCredentialsParser{}, tokenParsers: DefaultTokenParsers(), errorStatusCoder: DefaultErrorStatusCoder, errorMarshaler: DefaultErrorMarshaler, @@ -191,14 +186,9 @@ func (m *Middleware) writeError(response http.ResponseWriter, request *http.Requ } func (m *Middleware) getCredentialsAndToken(ctx context.Context, request *http.Request) (c bascule.Credentials, t bascule.Token, err error) { - var raw string - raw, err = m.accessor.GetCredentials(request) - if err == nil { - c, err = m.credentialsParser.Parse(raw) - } - + c, err = m.credentialsParser.Parse(request.Context(), request) if err == nil { - t, err = m.tokenParsers.Parse(ctx, c) + t, err = m.tokenParsers.Parse(ctx, request, c) } return diff --git a/basculehttp/token.go b/basculehttp/token.go index df708d1..52f2dc0 100644 --- a/basculehttp/token.go +++ b/basculehttp/token.go @@ -3,7 +3,11 @@ package basculehttp -import "github.com/xmidt-org/bascule/v1" +import ( + "net/http" + + "github.com/xmidt-org/bascule/v1" +) // Token is bascule's default HTTP token. type Token struct { @@ -17,8 +21,8 @@ func (t *Token) Principal() string { // DefaultTokenParsers returns the default suite of parsers supported by // bascule. This method returns a distinct instance each time it is called, // thus allowing calling code to tailor it independently of other calls. -func DefaultTokenParsers() bascule.TokenParsers { - return bascule.TokenParsers{ +func DefaultTokenParsers() bascule.TokenParsers[*http.Request] { + return bascule.TokenParsers[*http.Request]{ BasicScheme: basicTokenParser{}, } } diff --git a/basculejwt/token.go b/basculejwt/token.go index 30ae464..b53bec9 100644 --- a/basculejwt/token.go +++ b/basculejwt/token.go @@ -58,15 +58,16 @@ func (t *token) Principal() string { } // tokenParser is the canonical parser for bascule that deals with JWTs. -type tokenParser struct { +// This parser does not use the source. +type tokenParser[S any] struct { options []jwt.ParseOption } // NewTokenParser constructs a parser using the supplied set of parse options. // The returned parser will produce tokens that implement the Token interface // in this package. -func NewTokenParser(options ...jwt.ParseOption) (bascule.TokenParser, error) { - return &tokenParser{ +func NewTokenParser[S any](options ...jwt.ParseOption) (bascule.TokenParser[S], error) { + return &tokenParser[S]{ options: append( make([]jwt.ParseOption, 0, len(options)), options..., @@ -74,7 +75,7 @@ func NewTokenParser(options ...jwt.ParseOption) (bascule.TokenParser, error) { }, nil } -func (tp *tokenParser) Parse(_ context.Context, c bascule.Credentials) (bascule.Token, error) { +func (tp *tokenParser[S]) Parse(_ context.Context, _ S, c bascule.Credentials) (bascule.Token, error) { jwtToken, err := jwt.ParseString(c.Value, tp.options...) if err != nil { return nil, err diff --git a/credentials_test.go b/credentials_test.go index 5aa25bd..20956ca 100644 --- a/credentials_test.go +++ b/credentials_test.go @@ -4,6 +4,7 @@ package bascule import ( + "context" "errors" "testing" @@ -17,7 +18,7 @@ type CredentialsTestSuite struct { func (suite *CredentialsTestSuite) TestCredentialsParserFunc() { const expectedRaw = "expected raw credentials" expectedErr := errors.New("expected error") - var c CredentialsParser = CredentialsParserFunc(func(raw string) (Credentials, error) { + var c CredentialsParser[string] = CredentialsParserFunc[string](func(_ context.Context, raw string) (Credentials, error) { suite.Equal(expectedRaw, raw) return Credentials{ Scheme: Scheme("test"), @@ -25,7 +26,7 @@ func (suite *CredentialsTestSuite) TestCredentialsParserFunc() { }, expectedErr }) - creds, err := c.Parse(expectedRaw) + creds, err := c.Parse(context.Background(), expectedRaw) suite.Equal( Credentials{ Scheme: Scheme("test"), diff --git a/token_test.go b/token_test.go index 94b916f..052475e 100644 --- a/token_test.go +++ b/token_test.go @@ -23,27 +23,27 @@ func (suite *TokenParsersSuite) assertUnsupportedScheme(scheme Scheme, err error } func (suite *TokenParsersSuite) testParseEmpty() { - var tp TokenParsers + var tp TokenParsers[string] // legal, but will always fail - token, err := tp.Parse(context.Background(), suite.testCredentials()) + token, err := tp.Parse(context.Background(), "doesnotmatter", suite.testCredentials()) suite.Nil(token) suite.assertUnsupportedScheme(testScheme, err) } func (suite *TokenParsersSuite) testParseUnsupported() { - var tp TokenParsers + var tp TokenParsers[string] tp.Register( Scheme("Supported"), - TokenParserFunc( - func(context.Context, Credentials) (Token, error) { + TokenParserFunc[string]( + func(context.Context, string, Credentials) (Token, error) { suite.Fail("TokenParser should not have been called") return nil, nil }, ), ) - token, err := tp.Parse(context.Background(), suite.testCredentials()) + token, err := tp.Parse(context.Background(), "doesnotmatter", suite.testCredentials()) suite.Nil(token) suite.assertUnsupportedScheme(testScheme, err) } @@ -56,11 +56,11 @@ func (suite *TokenParsersSuite) testParseSupported() { testCredentials = suite.testCredentials() ) - var tp TokenParsers + var tp TokenParsers[string] tp.Register( testCredentials.Scheme, - TokenParserFunc( - func(ctx context.Context, c Credentials) (Token, error) { + TokenParserFunc[string]( + func(ctx context.Context, _ string, c Credentials) (Token, error) { suite.Equal(testCtx, ctx) suite.Equal(testCredentials, c) return suite.testToken(), expectedErr @@ -68,7 +68,7 @@ func (suite *TokenParsersSuite) testParseSupported() { ), ) - token, err := tp.Parse(testCtx, testCredentials) + token, err := tp.Parse(testCtx, "doesnotmatter", testCredentials) suite.Equal(suite.testToken(), token) suite.Same(expectedErr, err) } From f2b8c69e17f1908d41788a00afdaacb48e8770cb Mon Sep 17 00:00:00 2001 From: johnabass Date: Tue, 9 Jul 2024 15:47:22 -0700 Subject: [PATCH 3/4] removed accessor; patched tests; added simple middleware example --- basculehttp/accessor.go | 61 ------------------------- basculehttp/credentials.go | 26 ++++++----- basculehttp/credentials_test.go | 31 ++++++++++--- basculehttp/middleware.go | 4 +- basculehttp/middleware_examples_test.go | 42 +++++++++++++++++ 5 files changed, 83 insertions(+), 81 deletions(-) delete mode 100644 basculehttp/accessor.go create mode 100644 basculehttp/middleware_examples_test.go diff --git a/basculehttp/accessor.go b/basculehttp/accessor.go deleted file mode 100644 index 3b1fa10..0000000 --- a/basculehttp/accessor.go +++ /dev/null @@ -1,61 +0,0 @@ -// SPDX-FileCopyrightText: 2020 Comcast Cable Communications Management, LLC -// SPDX-License-Identifier: Apache-2.0 - -package basculehttp - -import ( - "net/http" -) - -// Accessor is the strategy for obtaining credentials from an HTTP request. -type Accessor interface { - // GetCredentials returns the raw credentials from a request. - GetCredentials(*http.Request) (string, error) -} - -var defaultAccessor Accessor = HeaderAccessor{} - -// DefaultAccessor returns the builtin default strategy for obtaining raw credentials -// from an HTTP request. The returned Accessor simply retrieves the Authorization header -// value if it exists. -func DefaultAccessor() Accessor { return defaultAccessor } - -// HeaderAccessor obtains the raw credentials from a specific header in -// an HTTP request. -type HeaderAccessor struct { - // Header is the name of the HTTP header to use. If not supplied, - // DefaultAuthorizationHeader is used. - // - // If no authorization header can be found in an HTTP request, - // MissingHeaderError is returned. - Header string - - // ErrorOnDuplicate controls whether an error is returned if more - // than one Header is found in the request. By default, this is false. - ErrorOnDuplicate bool -} - -func (ha HeaderAccessor) GetCredentials(r *http.Request) (raw string, err error) { - h := ha.Header - if len(h) == 0 { - h = DefaultAuthorizationHeader - } - - values := r.Header.Values(h) - switch { - case len(values) == 0: - err = &MissingHeaderError{ - Header: h, - } - - case len(values) == 1 || !ha.ErrorOnDuplicate: - raw = values[0] - - default: - err = &DuplicateHeaderError{ - Header: h, - } - } - - return -} diff --git a/basculehttp/credentials.go b/basculehttp/credentials.go index 88a8f6d..000d6ca 100644 --- a/basculehttp/credentials.go +++ b/basculehttp/credentials.go @@ -94,18 +94,20 @@ func (dcp DefaultCredentialsParser) Parse(_ context.Context, source *http.Reques } } - // format is - // the code is strict: it requires no leading or trailing space - // and exactly one (1) space as a separator. - scheme, credValue, found := strings.Cut(raw, " ") - if found && len(scheme) > 0 && !fastIsSpace(credValue[0]) && !fastIsSpace(credValue[len(credValue)-1]) { - c = bascule.Credentials{ - Scheme: bascule.Scheme(scheme), - Value: credValue, - } - } else { - err = &bascule.BadCredentialsError{ - Raw: raw, + if err == nil { + // format is + // the code is strict: it requires no leading or trailing space + // and exactly one (1) space as a separator. + scheme, credValue, found := strings.Cut(raw, " ") + if found && len(scheme) > 0 && !fastIsSpace(credValue[0]) && !fastIsSpace(credValue[len(credValue)-1]) { + c = bascule.Credentials{ + Scheme: bascule.Scheme(scheme), + Value: credValue, + } + } else { + err = &bascule.BadCredentialsError{ + Raw: raw, + } } } diff --git a/basculehttp/credentials_test.go b/basculehttp/credentials_test.go index dabfde1..89941e0 100644 --- a/basculehttp/credentials_test.go +++ b/basculehttp/credentials_test.go @@ -35,10 +35,10 @@ func (suite *CredentialsTestSuite) testDefaultCredentialsParserSuccess() { for _, testCase := range testCases { suite.Run(testCase, func() { - dp := DefaultCredentialsParser{} - suite.Require().NotNil(dp) + dcp := DefaultCredentialsParser{} + suite.Require().NotNil(dcp) - creds, err := dp.Parse(context.Background(), suite.newDefaultSource(testCase)) + creds, err := dcp.Parse(context.Background(), suite.newDefaultSource(testCase)) suite.Require().NoError(err) suite.Equal( bascule.Credentials{ @@ -64,10 +64,10 @@ func (suite *CredentialsTestSuite) testDefaultCredentialsParserFailure() { for _, testCase := range testCases { suite.Run(testCase, func() { - dp := DefaultCredentialsParser{} - suite.Require().NotNil(dp) + dcp := DefaultCredentialsParser{} + suite.Require().NotNil(dcp) - creds, err := dp.Parse(context.Background(), suite.newDefaultSource(testCase)) + creds, err := dcp.Parse(context.Background(), suite.newDefaultSource(testCase)) suite.Require().Error(err) suite.Equal(bascule.Credentials{}, creds) @@ -79,9 +79,28 @@ func (suite *CredentialsTestSuite) testDefaultCredentialsParserFailure() { } } +func (suite *CredentialsTestSuite) testDefaultCredentialsParserMissingHeader() { + dcp := DefaultCredentialsParser{} + suite.Require().NotNil(dcp) + + r := httptest.NewRequest("GET", "/", nil) + creds, err := dcp.Parse(context.Background(), r) + suite.Require().Error(err) + suite.Equal(bascule.Credentials{}, creds) + + type statusCoder interface { + StatusCode() int + } + + var sc statusCoder + suite.Require().ErrorAs(err, &sc) + suite.Equal(http.StatusUnauthorized, sc.StatusCode()) +} + func (suite *CredentialsTestSuite) TestDefaultCredentialsParser() { suite.Run("Success", suite.testDefaultCredentialsParserSuccess) suite.Run("Failure", suite.testDefaultCredentialsParserFailure) + suite.Run("MissingHeader", suite.testDefaultCredentialsParserMissingHeader) } func TestCredentials(t *testing.T) { diff --git a/basculehttp/middleware.go b/basculehttp/middleware.go index 6354dea..0c9f4b5 100644 --- a/basculehttp/middleware.go +++ b/basculehttp/middleware.go @@ -221,7 +221,7 @@ func (fd *frontDoor) ServeHTTP(response http.ResponseWriter, request *http.Reque ctx = bascule.WithCredentials(ctx, creds) err = fd.middleware.authenticate(ctx, token) - if err == nil { + if err != nil { // at this point in the workflow, the request has valid credentials. we use // StatusForbidden as the default because any failure to authenticate isn't a // case where the caller needs to supply credentials. Rather, the supplied @@ -232,7 +232,7 @@ func (fd *frontDoor) ServeHTTP(response http.ResponseWriter, request *http.Reque ctx = bascule.WithToken(ctx, token) err = fd.middleware.authorize(ctx, token, request) - if err == nil { + if err != nil { fd.middleware.writeError(response, request, http.StatusForbidden, err) return } diff --git a/basculehttp/middleware_examples_test.go b/basculehttp/middleware_examples_test.go new file mode 100644 index 0000000..7b9c54a --- /dev/null +++ b/basculehttp/middleware_examples_test.go @@ -0,0 +1,42 @@ +// SPDX-FileCopyrightText: 2020 Comcast Cable Communications Management, LLC +// SPDX-License-Identifier: Apache-2.0 + +package basculehttp + +import ( + "fmt" + "net/http" + "net/http/httptest" +) + +// ExampleMiddleware_simple illustrates how to use a basculehttp Middleware with +// just the defaults. +func ExampleMiddleware_simple() { + m, err := NewMiddleware() // all defaults + if err != nil { + panic(err) + } + + // decorate a handler that needs authorization + h := m.ThenFunc( + func(response http.ResponseWriter, request *http.Request) { + }, + ) + + // what happens when no authorization is set? + noAuth := httptest.NewRequest("GET", "/", nil) + response := httptest.NewRecorder() + h.ServeHTTP(response, noAuth) + fmt.Println("no authorization response code:", response.Code) + + // what happens when a valid Basic token is set? + withBasic := httptest.NewRequest("GET", "/", nil) + withBasic.SetBasicAuth("joe", "password") + response = httptest.NewRecorder() + h.ServeHTTP(response, withBasic) + fmt.Println("with basic auth response code:", response.Code) + + // Output: + // no authorization response code: 401 + // with basic auth response code: 200 +} From d84b6049ab59b50778674e532ba1f15bd63baa5a Mon Sep 17 00:00:00 2001 From: johnabass Date: Tue, 9 Jul 2024 16:03:07 -0700 Subject: [PATCH 4/4] demonstrate how to access the token from inside a handler --- basculehttp/middleware_examples_test.go | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/basculehttp/middleware_examples_test.go b/basculehttp/middleware_examples_test.go index 7b9c54a..04da19c 100644 --- a/basculehttp/middleware_examples_test.go +++ b/basculehttp/middleware_examples_test.go @@ -7,6 +7,8 @@ import ( "fmt" "net/http" "net/http/httptest" + + "github.com/xmidt-org/bascule/v1" ) // ExampleMiddleware_simple illustrates how to use a basculehttp Middleware with @@ -20,6 +22,12 @@ func ExampleMiddleware_simple() { // decorate a handler that needs authorization h := m.ThenFunc( func(response http.ResponseWriter, request *http.Request) { + t, ok := bascule.GetTokenFrom(request) + if !ok { + panic("no token found") + } + + fmt.Println("principal:", t.Principal()) }, ) @@ -38,5 +46,6 @@ func ExampleMiddleware_simple() { // Output: // no authorization response code: 401 + // principal: joe // with basic auth response code: 200 }