Skip to content

Commit

Permalink
Merge pull request #266 from xmidt-org/feature/generic-source
Browse files Browse the repository at this point in the history
Feature/generic source
  • Loading branch information
johnabass authored Jul 10, 2024
2 parents 17fe87e + d84b604 commit 5d8a265
Show file tree
Hide file tree
Showing 12 changed files with 228 additions and 198 deletions.
103 changes: 0 additions & 103 deletions basculehttp/accessor.go

This file was deleted.

3 changes: 2 additions & 1 deletion basculehttp/basic.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package basculehttp
import (
"context"
"encoding/base64"
"net/http"
"strings"

"github.com/xmidt-org/bascule/v1"
Expand Down Expand Up @@ -35,7 +36,7 @@ func (err *InvalidBasicAuthError) Error() string {

type basicTokenParser struct{}

func (btp basicTokenParser) Parse(_ context.Context, c bascule.Credentials) (t bascule.Token, err error) {
func (btp basicTokenParser) Parse(_ context.Context, _ *http.Request, c bascule.Credentials) (t bascule.Token, err error) {
var decoded []byte
decoded, err = base64.StdEncoding.DecodeString(c.Value)
if err != nil {
Expand Down
98 changes: 84 additions & 14 deletions basculehttp/credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,42 +4,112 @@
package basculehttp

import (
"context"
"net/http"
"strings"

"github.com/xmidt-org/bascule/v1"
)

const (
// DefaultAuthorizationHeader is the name of the header used by default to obtain
// the raw credentials.
DefaultAuthorizationHeader = "Authorization"
)

// DuplicateHeaderError indicates that an HTTP header had more than one value
// when only one value was expected.
type DuplicateHeaderError struct {
// Header is the name of the duplicate header.
Header string
}

func (err *DuplicateHeaderError) Error() string {
var o strings.Builder
o.WriteString(`Duplicate header: "`)
o.WriteString(err.Header)
o.WriteString(`"`)
return o.String()
}

// MissingHeaderError indicates that an expected HTTP header is missing.
type MissingHeaderError struct {
// Header is the name of the missing header.
Header string
}

func (err *MissingHeaderError) Error() string {
var o strings.Builder
o.WriteString(`Missing header: "`)
o.WriteString(err.Header)
o.WriteString(`"`)
return o.String()
}

// StatusCode returns http.StatusUnauthorized, as the request carries
// no authorization in it.
func (err *MissingHeaderError) StatusCode() int {
return http.StatusUnauthorized
}

// fastIsSpace tests an ASCII byte to see if it's whitespace.
// HTTP headers are restricted to US-ASCII, so we don't need
// the full unicode stack.
func fastIsSpace(b byte) bool {
return b == ' ' || b == '\t' || b == '\n' || b == '\r' || b == '\v' || b == '\f'
}

var defaultCredentialsParser bascule.CredentialsParser = bascule.CredentialsParserFunc(
func(raw string) (c bascule.Credentials, err error) {
// DefaultCredentialsParser is the default algorithm used to produce HTTP credentials
// from a source request.
type DefaultCredentialsParser struct {
// Header is the name of the authorization header. If unset,
// DefaultAuthorizationHeader is used.
Header string

// ErrorOnDuplicate controls whether an error is returned if more
// than one Header is found in the request. By default, this is false.
ErrorOnDuplicate bool
}

func (dcp DefaultCredentialsParser) Parse(_ context.Context, source *http.Request) (c bascule.Credentials, err error) {
header := dcp.Header
if len(header) == 0 {
header = DefaultAuthorizationHeader
}

var raw string
values := source.Header.Values(header)
switch {
case len(values) == 0:
err = &MissingHeaderError{
Header: header,
}

case len(values) == 1 || !dcp.ErrorOnDuplicate:
raw = values[0]

default:
err = &DuplicateHeaderError{
Header: header,
}
}

if err == nil {
// format is <scheme><single space><credential value>
// the code is strict: it requires no leading or trailing space
// and exactly one (1) space as a separator.
scheme, value, found := strings.Cut(raw, " ")
if found && len(scheme) > 0 && !fastIsSpace(value[0]) && !fastIsSpace(value[len(value)-1]) {
scheme, credValue, found := strings.Cut(raw, " ")
if found && len(scheme) > 0 && !fastIsSpace(credValue[0]) && !fastIsSpace(credValue[len(credValue)-1]) {
c = bascule.Credentials{
Scheme: bascule.Scheme(scheme),
Value: value,
Value: credValue,
}
} else {
err = &bascule.BadCredentialsError{
Raw: raw,
}
}
}

return
},
)

// DefaultCredentialsParser returns the default strategy for parsing credentials. This
// builtin strategy is very strict on whitespace. The format must correspond exactly
// to the format specified in https://www.rfc-editor.org/rfc/rfc7235.
func DefaultCredentialsParser() bascule.CredentialsParser {
return defaultCredentialsParser
return
}
40 changes: 34 additions & 6 deletions basculehttp/credentials_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
package basculehttp

import (
"context"
"net/http"
"net/http/httptest"
"testing"

"github.com/stretchr/testify/suite"
Expand All @@ -14,6 +17,12 @@ type CredentialsTestSuite struct {
suite.Suite
}

func (suite *CredentialsTestSuite) newDefaultSource(value string) *http.Request {
r := httptest.NewRequest("GET", "/", nil)
r.Header.Set(DefaultAuthorizationHeader, value)
return r
}

func (suite *CredentialsTestSuite) testDefaultCredentialsParserSuccess() {
const (
expectedScheme bascule.Scheme = "Test"
Expand All @@ -26,10 +35,10 @@ func (suite *CredentialsTestSuite) testDefaultCredentialsParserSuccess() {

for _, testCase := range testCases {
suite.Run(testCase, func() {
dp := DefaultCredentialsParser()
suite.Require().NotNil(dp)
dcp := DefaultCredentialsParser{}
suite.Require().NotNil(dcp)

creds, err := dp.Parse(testCase)
creds, err := dcp.Parse(context.Background(), suite.newDefaultSource(testCase))
suite.Require().NoError(err)
suite.Equal(
bascule.Credentials{
Expand All @@ -55,10 +64,10 @@ func (suite *CredentialsTestSuite) testDefaultCredentialsParserFailure() {

for _, testCase := range testCases {
suite.Run(testCase, func() {
dp := DefaultCredentialsParser()
suite.Require().NotNil(dp)
dcp := DefaultCredentialsParser{}
suite.Require().NotNil(dcp)

creds, err := dp.Parse(testCase)
creds, err := dcp.Parse(context.Background(), suite.newDefaultSource(testCase))
suite.Require().Error(err)
suite.Equal(bascule.Credentials{}, creds)

Expand All @@ -70,9 +79,28 @@ func (suite *CredentialsTestSuite) testDefaultCredentialsParserFailure() {
}
}

func (suite *CredentialsTestSuite) testDefaultCredentialsParserMissingHeader() {
dcp := DefaultCredentialsParser{}
suite.Require().NotNil(dcp)

r := httptest.NewRequest("GET", "/", nil)
creds, err := dcp.Parse(context.Background(), r)
suite.Require().Error(err)
suite.Equal(bascule.Credentials{}, creds)

type statusCoder interface {
StatusCode() int
}

var sc statusCoder
suite.Require().ErrorAs(err, &sc)
suite.Equal(http.StatusUnauthorized, sc.StatusCode())
}

func (suite *CredentialsTestSuite) TestDefaultCredentialsParser() {
suite.Run("Success", suite.testDefaultCredentialsParserSuccess)
suite.Run("Failure", suite.testDefaultCredentialsParserFailure)
suite.Run("MissingHeader", suite.testDefaultCredentialsParserMissingHeader)
}

func TestCredentials(t *testing.T) {
Expand Down
Loading

0 comments on commit 5d8a265

Please sign in to comment.