Skip to content

Commit

Permalink
Hotfix/datarace (#55)
Browse files Browse the repository at this point in the history
* small comment changes

* fixed data race and added a test that catches it

* updated changelog

* updated test

* added whitespace
  • Loading branch information
kristinapathak authored Feb 25, 2020
1 parent 271bc0e commit 1016ff7
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 15 deletions.
6 changes: 5 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.

## [Unreleased]

## [v0.8.1]
- fixed data race in RemoteBearerTokenAcquirer [#55](https://github.com/xmidt-org/bascule/pull/55)

## [v0.8.0]
- Add support for key paths in token attribute getters [#52](https://github.com/xmidt-org/bascule/pull/52)

Expand Down Expand Up @@ -71,7 +74,8 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
- Added constructor, enforcer, and listener alice decorators
- Basic code and structure established

[Unreleased]: https://github.com/xmidt-org/bascule/compare/v0.8.0...HEAD
[Unreleased]: https://github.com/xmidt-org/bascule/compare/v0.8.1...HEAD
[v0.8.1]: https://github.com/xmidt-org/bascule/compare/v0.8.0...v0.8.1
[v0.8.0]: https://github.com/xmidt-org/bascule/compare/v0.7.0...v0.8.0
[v0.7.0]: https://github.com/xmidt-org/bascule/compare/v0.6.0...v0.7.0
[v0.6.0]: https://github.com/xmidt-org/bascule/compare/v0.5.0...v0.6.0
Expand Down
10 changes: 5 additions & 5 deletions acquire/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ import (
)

//ErrEmptyCredentials is returned whenever an Acquirer is attempted to
//be built with empty credentials
//Use DefaultAcquirer for such no-op use case
//be built with empty credentials.
//Use DefaultAcquirer for such no-op use case.
var ErrEmptyCredentials = errors.New("Empty credentials are not valid")

// Acquirer gets an Authorization value that can be added to an http request.
Expand All @@ -23,7 +23,7 @@ type Acquirer interface {
// DefaultAcquirer is a no-op Acquirer.
type DefaultAcquirer struct{}

//Acquire returns the zero values of the return types
//Acquire returns the zero values of the return types.
func (d *DefaultAcquirer) Acquire() (string, error) {
return "", nil
}
Expand Down Expand Up @@ -51,7 +51,7 @@ func AddAuth(r *http.Request, acquirer Acquirer) error {
return nil
}

//FixedValueAcquirer implements Acquirer with a constant authorization value
//FixedValueAcquirer implements Acquirer with a constant authorization value.
type FixedValueAcquirer struct {
authValue string
}
Expand All @@ -60,7 +60,7 @@ func (f *FixedValueAcquirer) Acquire() (string, error) {
return f.authValue, nil
}

// NewFixedAuthAcquirer returns a FixedValueAcquirer with the given authValue
// NewFixedAuthAcquirer returns a FixedValueAcquirer with the given authValue.
func NewFixedAuthAcquirer(authValue string) (*FixedValueAcquirer, error) {
if authValue != "" {
return &FixedValueAcquirer{
Expand Down
27 changes: 18 additions & 9 deletions acquire/bearer.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,19 @@ import (
"fmt"
"io/ioutil"
"net/http"
"sync"
"time"

"github.com/goph/emperror"
)

//TokenParser defines the function signature of a bearer token extractor from a payload
// TokenParser defines the function signature of a bearer token extractor from a payload.
type TokenParser func([]byte) (string, error)

//ParseExpiration defines the function signature of a bearer token expiration date extractor
// ParseExpiration defines the function signature of a bearer token expiration date extractor.
type ParseExpiration func([]byte) (time.Time, error)

//DefaultTokenParser extracts a bearer token as defined by a SimpleBearer in a payload
// DefaultTokenParser extracts a bearer token as defined by a SimpleBearer in a payload.
func DefaultTokenParser(data []byte) (string, error) {
var bearer SimpleBearer

Expand All @@ -26,7 +27,7 @@ func DefaultTokenParser(data []byte) (string, error) {
return bearer.Token, nil
}

//DefaultExpirationParser extracts a bearer token expiration date as defined by a SimpleBearer in a payload
// DefaultExpirationParser extracts a bearer token expiration date as defined by a SimpleBearer in a payload.
func DefaultExpirationParser(data []byte) (time.Time, error) {
var bearer SimpleBearer

Expand All @@ -36,7 +37,7 @@ func DefaultExpirationParser(data []byte) (time.Time, error) {
return time.Now().Add(time.Duration(bearer.ExpiresInSeconds) * time.Second), nil
}

//RemoteBearerTokenAcquirerOptions provides configuration for the RemoteBearerTokenAcquirer
// RemoteBearerTokenAcquirerOptions provides configuration for the RemoteBearerTokenAcquirer.
type RemoteBearerTokenAcquirerOptions struct {
AuthURL string `json:"authURL"`
Timeout time.Duration `json:"timeout"`
Expand All @@ -47,22 +48,23 @@ type RemoteBearerTokenAcquirerOptions struct {
GetExpiration ParseExpiration
}

//RemoteBearerTokenAcquirer implements Acquirer and fetches the tokens from a remote location with caching strategy
// RemoteBearerTokenAcquirer implements Acquirer and fetches the tokens from a remote location with caching strategy.
type RemoteBearerTokenAcquirer struct {
options RemoteBearerTokenAcquirerOptions
authValue string
authValueExpiration time.Time
httpClient *http.Client
nonExpiringSpecialCase time.Time
lock sync.RWMutex
}

//SimpleBearer defines the field name mappings used by the default bearer token and expiration parsers
// SimpleBearer defines the field name mappings used by the default bearer token and expiration parsers.
type SimpleBearer struct {
ExpiresInSeconds float64 `json:"expires_in"`
Token string `json:"serviceAccessToken"`
}

// NewRemoteBearerTokenAcquirer returns a RemoteBearerTokenAcquirer configured with the given options
// NewRemoteBearerTokenAcquirer returns a RemoteBearerTokenAcquirer configured with the given options.
func NewRemoteBearerTokenAcquirer(options RemoteBearerTokenAcquirerOptions) (*RemoteBearerTokenAcquirer, error) {
if options.GetToken == nil {
options.GetToken = DefaultTokenParser
Expand All @@ -72,7 +74,7 @@ func NewRemoteBearerTokenAcquirer(options RemoteBearerTokenAcquirerOptions) (*Re
options.GetExpiration = DefaultExpirationParser
}

//TODO: we should inject timeout and buffer defaults values as well
// TODO: we should inject timeout and buffer defaults values as well.

return &RemoteBearerTokenAcquirer{
options: options,
Expand All @@ -84,10 +86,17 @@ func NewRemoteBearerTokenAcquirer(options RemoteBearerTokenAcquirerOptions) (*Re
}, nil
}

// Acquire provides the cached token or, if it's near its expiry time, contacts
// the server for a new token to cache.
func (acquirer *RemoteBearerTokenAcquirer) Acquire() (string, error) {
acquirer.lock.RLock()
if time.Now().Add(acquirer.options.Buffer).Before(acquirer.authValueExpiration) {
defer acquirer.lock.RUnlock()
return acquirer.authValue, nil
}
acquirer.lock.RUnlock()
acquirer.lock.Lock()
defer acquirer.lock.Unlock()

req, err := http.NewRequest("GET", acquirer.options.AuthURL, nil)
if err != nil {
Expand Down
42 changes: 42 additions & 0 deletions acquire/bearer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"fmt"
"net/http"
"net/http/httptest"
"sync"
"testing"
"time"

Expand Down Expand Up @@ -143,6 +144,47 @@ func TestRemoteBearerTokenAcquirerCaching(t *testing.T) {
assert.Equal(1, count)
}

func TestRemoteBearerTokenAcquirerExiting(t *testing.T) {
assert := assert.New(t)

count := 0
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
auth := SimpleBearer{
Token: fmt.Sprintf("gopher%v", count),
ExpiresInSeconds: 1, //1 second
}
count++

marshaledAuth, err := json.Marshal(&auth)
assert.Nil(err)
rw.Write(marshaledAuth)
}))
defer server.Close()

// Use Client & URL from our local test server
auth, errConstructor := NewRemoteBearerTokenAcquirer(RemoteBearerTokenAcquirerOptions{
AuthURL: server.URL,
Timeout: time.Duration(5) * time.Second,
Buffer: time.Second,
})
assert.Nil(errConstructor)
token, err := auth.Acquire()
assert.Nil(err)
var wg sync.WaitGroup
for i := 0; i < 10; i++ {
wg.Add(1)
go func() {
_, err := auth.Acquire()
assert.Nil(err)
wg.Done()
}()
}
wg.Wait()
cachedToken, err := auth.Acquire()
assert.Nil(err)
assert.NotEqual(token, cachedToken)
}

type customBearer struct {
Token string `json:"token"`
ExpiresOnUnixSeconds int64 `json:"expires_on"`
Expand Down

0 comments on commit 1016ff7

Please sign in to comment.