Skip to content

Commit

Permalink
Feature/auth refactor (#37)
Browse files Browse the repository at this point in the history
* refactor auth acquirers

* Small refactor of acquire submodule

* simplify acquirers + tests

* update acquirer example

* bring back un-expiring token logic with tests

* tiny name improvement/consistency

* Small refactor from code reviews

* wrap errors and fix tests
  • Loading branch information
joe94 authored Aug 8, 2019
1 parent 57fa1c4 commit fb9769c
Show file tree
Hide file tree
Showing 8 changed files with 301 additions and 217 deletions.
40 changes: 36 additions & 4 deletions acquire/auth.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// package acquire is used for getting Auths to pass in http requests.
// Package acquire is used for getting Auths to pass in http requests.
package acquire

import (
Expand All @@ -8,31 +8,63 @@ import (
"github.com/pkg/errors"
)

//ErrEmptyCredentials is returned whenever an Acquirer is attempted to
//be built with empty credentials
//Use DefaultAcquirer for such no-op use case
var ErrEmptyCredentials = errors.New("Empty credentials are not valid")

// Acquirer gets an Authorization value that can be added to an http request.
// The format of the string returned should be the key, a space, and then the
// auth string.
// auth string: '[AuthType] [AuthValue]'
type Acquirer interface {
Acquire() (string, error)
}

// DefaultAcquirer returns nothing. This would not be a valid Authorization.
// DefaultAcquirer is a no-op Acquirer.
type DefaultAcquirer struct{}

//Acquire returns the zero values of the return types
func (d *DefaultAcquirer) Acquire() (string, error) {
return "", nil
}

// AddAuth adds an auth value to the Authorization header of an http request.
//AddAuth adds an auth value to the Authorization header of an http request.
func AddAuth(r *http.Request, acquirer Acquirer) error {
if r == nil {
return errors.New("can't add authorization to nil request")
}

if acquirer == nil {
return errors.New("acquirer is undefined")
}

auth, err := acquirer.Acquire()

if err != nil {
return emperror.Wrap(err, "failed to acquire auth for request")
}

if auth != "" {
r.Header.Set("Authorization", auth)
}

return nil
}

//FixedValueAcquirer implements Acquirer with a constant authorization value
type FixedValueAcquirer struct {
authValue string
}

func (f *FixedValueAcquirer) Acquire() (string, error) {
return f.authValue, nil
}

// NewFixedAuthAcquirer returns a FixedValueAcquirer with the given authValue
func NewFixedAuthAcquirer(authValue string) (*FixedValueAcquirer, error) {
if authValue != "" {
return &FixedValueAcquirer{
authValue: authValue}, nil
}
return nil, ErrEmptyCredentials
}
92 changes: 92 additions & 0 deletions acquire/auth_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
package acquire

import (
"net/http"
"net/http/httptest"
"testing"

"github.com/pkg/errors"
"github.com/stretchr/testify/assert"
)

func TestAddAuth(t *testing.T) {
fixedAcquirer, _ := NewFixedAuthAcquirer("Basic abc==")
tests := []struct {
name string
request *http.Request
acquirer Acquirer
shouldError bool
authValue string
}{
{
name: "RequestIsNil",
acquirer: &DefaultAcquirer{},
shouldError: true,
},
{
name: "AcquirerIsNil",
request: httptest.NewRequest(http.MethodGet, "/", nil),
shouldError: true,
},
{
name: "AcquirerFails",
acquirer: &failingAcquirer{},
shouldError: true,
},
{
name: "HappyPath",
request: httptest.NewRequest(http.MethodGet, "/", nil),
acquirer: fixedAcquirer,
shouldError: false,
authValue: "Basic abc==",
},
}

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
assert := assert.New(t)

if test.shouldError {
assert.NotNil(AddAuth(test.request, test.acquirer))
} else {
assert.Nil(AddAuth(test.request, test.acquirer))
assert.Equal(test.authValue, test.request.Header.Get("Authorization"))
}
})
}
}

func TestFixedAuthAcquirer(t *testing.T) {
t.Run("HappyPath", func(t *testing.T) {
assert := assert.New(t)

acquirer, err := NewFixedAuthAcquirer("Basic xyz==")
assert.NotNil(acquirer)
assert.Nil(err)

authValue, _ := acquirer.Acquire()
assert.Equal("Basic xyz==", authValue)
})

t.Run("EmptyCredentials", func(t *testing.T) {
assert := assert.New(t)

acquirer, err := NewFixedAuthAcquirer("")
assert.Equal(ErrEmptyCredentials, err)
assert.Nil(acquirer)
})
}

func TestDefaultAcquirer(t *testing.T) {
assert := assert.New(t)
acquirer := &DefaultAcquirer{}
authValue, err := acquirer.Acquire()
assert.Empty(authValue)
assert.Empty(err)
}

type failingAcquirer struct{}

func (f *failingAcquirer) Acquire() (string, error) {
return "", errors.New("always fails")
}
31 changes: 0 additions & 31 deletions acquire/basic.go

This file was deleted.

36 changes: 0 additions & 36 deletions acquire/basic_test.go

This file was deleted.

127 changes: 127 additions & 0 deletions acquire/bearer.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
package acquire

import (
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
"time"

"github.com/goph/emperror"
)

//TokenParser defines the function signature of a bearer token extractor from a payload
type TokenParser func([]byte) (string, error)

//ParseExpiration defines the function signature of a bearer token expiration date extractor
type ParseExpiration func([]byte) (time.Time, error)

//DefaultTokenParser extracts a bearer token as defined by a SimpleBearer in a payload
func DefaultTokenParser(data []byte) (string, error) {
var bearer SimpleBearer

if errUnmarshal := json.Unmarshal(data, &bearer); errUnmarshal != nil {
return "", emperror.Wrap(errUnmarshal, "unable to parse bearer token")
}
return bearer.Token, nil
}

//DefaultExpirationParser extracts a bearer token expiration date as defined by a SimpleBearer in a payload
func DefaultExpirationParser(data []byte) (time.Time, error) {
var bearer SimpleBearer

if errUnmarshal := json.Unmarshal(data, &bearer); errUnmarshal != nil {
return time.Time{}, emperror.Wrap(errUnmarshal, "unable to parse bearer token expiration")
}
return time.Now().Add(time.Duration(bearer.ExpiresInSeconds) * time.Second), nil
}

//RemoteBearerTokenAcquirerOptions provides configuration for the RemoteBearerTokenAcquirer
type RemoteBearerTokenAcquirerOptions struct {
AuthURL string `json:"authURL"`
Timeout time.Duration `json:"timeout"`
Buffer time.Duration `json:"buffer"`
RequestHeaders map[string]string `json:"requestHeaders"`

GetToken TokenParser
GetExpiration ParseExpiration
}

//RemoteBearerTokenAcquirer implements Acquirer and fetches the tokens from a remote location with caching strategy
type RemoteBearerTokenAcquirer struct {
options RemoteBearerTokenAcquirerOptions
authValue string
authValueExpiration time.Time
httpClient *http.Client
nonExpiringSpecialCase time.Time
}

//SimpleBearer defines the field name mappings used by the default bearer token and expiration parsers
type SimpleBearer struct {
ExpiresInSeconds float64 `json:"expires_in"`
Token string `json:"serviceAccessToken"`
}

// NewRemoteBearerTokenAcquirer returns a RemoteBearerTokenAcquirer configured with the given options
func NewRemoteBearerTokenAcquirer(options RemoteBearerTokenAcquirerOptions) (*RemoteBearerTokenAcquirer, error) {
if options.GetToken == nil {
options.GetToken = DefaultTokenParser
}

if options.GetExpiration == nil {
options.GetExpiration = DefaultExpirationParser
}

//TODO: we should inject timeout and buffer defaults values as well

return &RemoteBearerTokenAcquirer{
options: options,
authValueExpiration: time.Now(),
httpClient: &http.Client{
Timeout: options.Timeout,
},
nonExpiringSpecialCase: time.Unix(0, 0),
}, nil
}

func (acquirer *RemoteBearerTokenAcquirer) Acquire() (string, error) {
if time.Now().Add(acquirer.options.Buffer).Before(acquirer.authValueExpiration) {
return acquirer.authValue, nil
}

req, err := http.NewRequest("GET", acquirer.options.AuthURL, nil)
if err != nil {
return "", emperror.Wrap(err, "failed to create new request for Bearer")
}

for key, value := range acquirer.options.RequestHeaders {
req.Header.Set(key, value)
}

resp, errHTTP := acquirer.httpClient.Do(req)
if errHTTP != nil {
return "", emperror.Wrapf(errHTTP, "error making request to '%v' to acquire bearer token", acquirer.options.AuthURL)
}
defer resp.Body.Close()

if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("received non 200 code acquiring Bearer: code %v", resp.Status)
}

respBody, errRead := ioutil.ReadAll(resp.Body)
if errRead != nil {
return "", emperror.Wrap(errRead, "error reading HTTP response body")
}

token, err := acquirer.options.GetToken(respBody)
if err != nil {
return "", emperror.Wrap(err, "error parsing bearer token from http response body")
}
expiration, err := acquirer.options.GetExpiration(respBody)
if err != nil {
return "", emperror.Wrap(err, "error parsing bearer token expiration from http response body")
}

acquirer.authValue, acquirer.authValueExpiration = "Bearer "+token, expiration
return acquirer.authValue, nil
}
Loading

0 comments on commit fb9769c

Please sign in to comment.