From 1016ff75b5bf9d2a25183d4468b8bbfcd9e0ef6a Mon Sep 17 00:00:00 2001 From: kristinaspring Date: Mon, 24 Feb 2020 16:26:40 -0800 Subject: [PATCH] Hotfix/datarace (#55) * small comment changes * fixed data race and added a test that catches it * updated changelog * updated test * added whitespace --- CHANGELOG.md | 6 +++++- acquire/auth.go | 10 +++++----- acquire/bearer.go | 27 ++++++++++++++++++--------- acquire/bearer_test.go | 42 ++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 70 insertions(+), 15 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 562426c..c3e02c2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,9 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. ## [Unreleased] +## [v0.8.1] +- fixed data race in RemoteBearerTokenAcquirer [#55](https://github.com/xmidt-org/bascule/pull/55) + ## [v0.8.0] - Add support for key paths in token attribute getters [#52](https://github.com/xmidt-org/bascule/pull/52) @@ -71,7 +74,8 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. - Added constructor, enforcer, and listener alice decorators - Basic code and structure established -[Unreleased]: https://github.com/xmidt-org/bascule/compare/v0.8.0...HEAD +[Unreleased]: https://github.com/xmidt-org/bascule/compare/v0.8.1...HEAD +[v0.8.1]: https://github.com/xmidt-org/bascule/compare/v0.8.0...v0.8.1 [v0.8.0]: https://github.com/xmidt-org/bascule/compare/v0.7.0...v0.8.0 [v0.7.0]: https://github.com/xmidt-org/bascule/compare/v0.6.0...v0.7.0 [v0.6.0]: https://github.com/xmidt-org/bascule/compare/v0.5.0...v0.6.0 diff --git a/acquire/auth.go b/acquire/auth.go index ada2c5b..34e0749 100644 --- a/acquire/auth.go +++ b/acquire/auth.go @@ -9,8 +9,8 @@ import ( ) //ErrEmptyCredentials is returned whenever an Acquirer is attempted to -//be built with empty credentials -//Use DefaultAcquirer for such no-op use case +//be built with empty credentials. +//Use DefaultAcquirer for such no-op use case. var ErrEmptyCredentials = errors.New("Empty credentials are not valid") // Acquirer gets an Authorization value that can be added to an http request. @@ -23,7 +23,7 @@ type Acquirer interface { // DefaultAcquirer is a no-op Acquirer. type DefaultAcquirer struct{} -//Acquire returns the zero values of the return types +//Acquire returns the zero values of the return types. func (d *DefaultAcquirer) Acquire() (string, error) { return "", nil } @@ -51,7 +51,7 @@ func AddAuth(r *http.Request, acquirer Acquirer) error { return nil } -//FixedValueAcquirer implements Acquirer with a constant authorization value +//FixedValueAcquirer implements Acquirer with a constant authorization value. type FixedValueAcquirer struct { authValue string } @@ -60,7 +60,7 @@ func (f *FixedValueAcquirer) Acquire() (string, error) { return f.authValue, nil } -// NewFixedAuthAcquirer returns a FixedValueAcquirer with the given authValue +// NewFixedAuthAcquirer returns a FixedValueAcquirer with the given authValue. func NewFixedAuthAcquirer(authValue string) (*FixedValueAcquirer, error) { if authValue != "" { return &FixedValueAcquirer{ diff --git a/acquire/bearer.go b/acquire/bearer.go index bd6f5a4..b0cae2b 100644 --- a/acquire/bearer.go +++ b/acquire/bearer.go @@ -5,18 +5,19 @@ import ( "fmt" "io/ioutil" "net/http" + "sync" "time" "github.com/goph/emperror" ) -//TokenParser defines the function signature of a bearer token extractor from a payload +// TokenParser defines the function signature of a bearer token extractor from a payload. type TokenParser func([]byte) (string, error) -//ParseExpiration defines the function signature of a bearer token expiration date extractor +// ParseExpiration defines the function signature of a bearer token expiration date extractor. type ParseExpiration func([]byte) (time.Time, error) -//DefaultTokenParser extracts a bearer token as defined by a SimpleBearer in a payload +// DefaultTokenParser extracts a bearer token as defined by a SimpleBearer in a payload. func DefaultTokenParser(data []byte) (string, error) { var bearer SimpleBearer @@ -26,7 +27,7 @@ func DefaultTokenParser(data []byte) (string, error) { return bearer.Token, nil } -//DefaultExpirationParser extracts a bearer token expiration date as defined by a SimpleBearer in a payload +// DefaultExpirationParser extracts a bearer token expiration date as defined by a SimpleBearer in a payload. func DefaultExpirationParser(data []byte) (time.Time, error) { var bearer SimpleBearer @@ -36,7 +37,7 @@ func DefaultExpirationParser(data []byte) (time.Time, error) { return time.Now().Add(time.Duration(bearer.ExpiresInSeconds) * time.Second), nil } -//RemoteBearerTokenAcquirerOptions provides configuration for the RemoteBearerTokenAcquirer +// RemoteBearerTokenAcquirerOptions provides configuration for the RemoteBearerTokenAcquirer. type RemoteBearerTokenAcquirerOptions struct { AuthURL string `json:"authURL"` Timeout time.Duration `json:"timeout"` @@ -47,22 +48,23 @@ type RemoteBearerTokenAcquirerOptions struct { GetExpiration ParseExpiration } -//RemoteBearerTokenAcquirer implements Acquirer and fetches the tokens from a remote location with caching strategy +// RemoteBearerTokenAcquirer implements Acquirer and fetches the tokens from a remote location with caching strategy. type RemoteBearerTokenAcquirer struct { options RemoteBearerTokenAcquirerOptions authValue string authValueExpiration time.Time httpClient *http.Client nonExpiringSpecialCase time.Time + lock sync.RWMutex } -//SimpleBearer defines the field name mappings used by the default bearer token and expiration parsers +// SimpleBearer defines the field name mappings used by the default bearer token and expiration parsers. type SimpleBearer struct { ExpiresInSeconds float64 `json:"expires_in"` Token string `json:"serviceAccessToken"` } -// NewRemoteBearerTokenAcquirer returns a RemoteBearerTokenAcquirer configured with the given options +// NewRemoteBearerTokenAcquirer returns a RemoteBearerTokenAcquirer configured with the given options. func NewRemoteBearerTokenAcquirer(options RemoteBearerTokenAcquirerOptions) (*RemoteBearerTokenAcquirer, error) { if options.GetToken == nil { options.GetToken = DefaultTokenParser @@ -72,7 +74,7 @@ func NewRemoteBearerTokenAcquirer(options RemoteBearerTokenAcquirerOptions) (*Re options.GetExpiration = DefaultExpirationParser } - //TODO: we should inject timeout and buffer defaults values as well + // TODO: we should inject timeout and buffer defaults values as well. return &RemoteBearerTokenAcquirer{ options: options, @@ -84,10 +86,17 @@ func NewRemoteBearerTokenAcquirer(options RemoteBearerTokenAcquirerOptions) (*Re }, nil } +// Acquire provides the cached token or, if it's near its expiry time, contacts +// the server for a new token to cache. func (acquirer *RemoteBearerTokenAcquirer) Acquire() (string, error) { + acquirer.lock.RLock() if time.Now().Add(acquirer.options.Buffer).Before(acquirer.authValueExpiration) { + defer acquirer.lock.RUnlock() return acquirer.authValue, nil } + acquirer.lock.RUnlock() + acquirer.lock.Lock() + defer acquirer.lock.Unlock() req, err := http.NewRequest("GET", acquirer.options.AuthURL, nil) if err != nil { diff --git a/acquire/bearer_test.go b/acquire/bearer_test.go index efba09d..7025267 100644 --- a/acquire/bearer_test.go +++ b/acquire/bearer_test.go @@ -6,6 +6,7 @@ import ( "fmt" "net/http" "net/http/httptest" + "sync" "testing" "time" @@ -143,6 +144,47 @@ func TestRemoteBearerTokenAcquirerCaching(t *testing.T) { assert.Equal(1, count) } +func TestRemoteBearerTokenAcquirerExiting(t *testing.T) { + assert := assert.New(t) + + count := 0 + server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + auth := SimpleBearer{ + Token: fmt.Sprintf("gopher%v", count), + ExpiresInSeconds: 1, //1 second + } + count++ + + marshaledAuth, err := json.Marshal(&auth) + assert.Nil(err) + rw.Write(marshaledAuth) + })) + defer server.Close() + + // Use Client & URL from our local test server + auth, errConstructor := NewRemoteBearerTokenAcquirer(RemoteBearerTokenAcquirerOptions{ + AuthURL: server.URL, + Timeout: time.Duration(5) * time.Second, + Buffer: time.Second, + }) + assert.Nil(errConstructor) + token, err := auth.Acquire() + assert.Nil(err) + var wg sync.WaitGroup + for i := 0; i < 10; i++ { + wg.Add(1) + go func() { + _, err := auth.Acquire() + assert.Nil(err) + wg.Done() + }() + } + wg.Wait() + cachedToken, err := auth.Acquire() + assert.Nil(err) + assert.NotEqual(token, cachedToken) +} + type customBearer struct { Token string `json:"token"` ExpiresOnUnixSeconds int64 `json:"expires_on"`