diff --git a/acquire/auth.go b/acquire/auth.go index 6c74902..ada2c5b 100644 --- a/acquire/auth.go +++ b/acquire/auth.go @@ -1,4 +1,4 @@ -// package acquire is used for getting Auths to pass in http requests. +// Package acquire is used for getting Auths to pass in http requests. package acquire import ( @@ -8,31 +8,63 @@ import ( "github.com/pkg/errors" ) +//ErrEmptyCredentials is returned whenever an Acquirer is attempted to +//be built with empty credentials +//Use DefaultAcquirer for such no-op use case +var ErrEmptyCredentials = errors.New("Empty credentials are not valid") + // Acquirer gets an Authorization value that can be added to an http request. // The format of the string returned should be the key, a space, and then the -// auth string. +// auth string: '[AuthType] [AuthValue]' type Acquirer interface { Acquire() (string, error) } -// DefaultAcquirer returns nothing. This would not be a valid Authorization. +// DefaultAcquirer is a no-op Acquirer. type DefaultAcquirer struct{} +//Acquire returns the zero values of the return types func (d *DefaultAcquirer) Acquire() (string, error) { return "", nil } -// AddAuth adds an auth value to the Authorization header of an http request. +//AddAuth adds an auth value to the Authorization header of an http request. func AddAuth(r *http.Request, acquirer Acquirer) error { if r == nil { return errors.New("can't add authorization to nil request") } + + if acquirer == nil { + return errors.New("acquirer is undefined") + } + auth, err := acquirer.Acquire() + if err != nil { return emperror.Wrap(err, "failed to acquire auth for request") } + if auth != "" { r.Header.Set("Authorization", auth) } + return nil } + +//FixedValueAcquirer implements Acquirer with a constant authorization value +type FixedValueAcquirer struct { + authValue string +} + +func (f *FixedValueAcquirer) Acquire() (string, error) { + return f.authValue, nil +} + +// NewFixedAuthAcquirer returns a FixedValueAcquirer with the given authValue +func NewFixedAuthAcquirer(authValue string) (*FixedValueAcquirer, error) { + if authValue != "" { + return &FixedValueAcquirer{ + authValue: authValue}, nil + } + return nil, ErrEmptyCredentials +} diff --git a/acquire/auth_test.go b/acquire/auth_test.go new file mode 100644 index 0000000..fe020c3 --- /dev/null +++ b/acquire/auth_test.go @@ -0,0 +1,92 @@ +package acquire + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/pkg/errors" + "github.com/stretchr/testify/assert" +) + +func TestAddAuth(t *testing.T) { + fixedAcquirer, _ := NewFixedAuthAcquirer("Basic abc==") + tests := []struct { + name string + request *http.Request + acquirer Acquirer + shouldError bool + authValue string + }{ + { + name: "RequestIsNil", + acquirer: &DefaultAcquirer{}, + shouldError: true, + }, + { + name: "AcquirerIsNil", + request: httptest.NewRequest(http.MethodGet, "/", nil), + shouldError: true, + }, + { + name: "AcquirerFails", + acquirer: &failingAcquirer{}, + shouldError: true, + }, + { + name: "HappyPath", + request: httptest.NewRequest(http.MethodGet, "/", nil), + acquirer: fixedAcquirer, + shouldError: false, + authValue: "Basic abc==", + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + assert := assert.New(t) + + if test.shouldError { + assert.NotNil(AddAuth(test.request, test.acquirer)) + } else { + assert.Nil(AddAuth(test.request, test.acquirer)) + assert.Equal(test.authValue, test.request.Header.Get("Authorization")) + } + }) + } +} + +func TestFixedAuthAcquirer(t *testing.T) { + t.Run("HappyPath", func(t *testing.T) { + assert := assert.New(t) + + acquirer, err := NewFixedAuthAcquirer("Basic xyz==") + assert.NotNil(acquirer) + assert.Nil(err) + + authValue, _ := acquirer.Acquire() + assert.Equal("Basic xyz==", authValue) + }) + + t.Run("EmptyCredentials", func(t *testing.T) { + assert := assert.New(t) + + acquirer, err := NewFixedAuthAcquirer("") + assert.Equal(ErrEmptyCredentials, err) + assert.Nil(acquirer) + }) +} + +func TestDefaultAcquirer(t *testing.T) { + assert := assert.New(t) + acquirer := &DefaultAcquirer{} + authValue, err := acquirer.Acquire() + assert.Empty(authValue) + assert.Empty(err) +} + +type failingAcquirer struct{} + +func (f *failingAcquirer) Acquire() (string, error) { + return "", errors.New("always fails") +} diff --git a/acquire/basic.go b/acquire/basic.go deleted file mode 100644 index 72790ef..0000000 --- a/acquire/basic.go +++ /dev/null @@ -1,31 +0,0 @@ -package acquire - -import ( - "encoding/base64" - "errors" -) - -var ( - errMissingCredentials = errors.New("no credentials found") -) - -// BasicAcquirer saves a basic auth upon creation and returns it whenever -// Acquire is called. -type BasicAcquirer struct { - encodedCredentials string -} - -func (b *BasicAcquirer) Acquire() (string, error) { - if b.encodedCredentials == "" { - return "", errMissingCredentials - } - return "Basic " + b.encodedCredentials, nil -} - -func NewBasicAcquirer(credentials string) *BasicAcquirer { - return &BasicAcquirer{credentials} -} - -func NewBasicAcquirerPlainText(username, password string) *BasicAcquirer { - return &BasicAcquirer{base64.StdEncoding.EncodeToString([]byte(username + ":" + password))} -} diff --git a/acquire/basic_test.go b/acquire/basic_test.go deleted file mode 100644 index dd8d84c..0000000 --- a/acquire/basic_test.go +++ /dev/null @@ -1,36 +0,0 @@ -package acquire - -import ( - "github.com/stretchr/testify/assert" - "testing" -) - -func TestBasicAcquirerSuccess(t *testing.T) { - assert := assert.New(t) - credentials := "test credentials" - expctedCredentials := "Basic test credentials" - acquirer := NewBasicAcquirer(credentials) - returnedCredentials, err := acquirer.Acquire() - assert.Nil(err) - assert.Equal(expctedCredentials, returnedCredentials) -} - -func TestBasicAcquirer(t *testing.T) { - assert := assert.New(t) - credentials := "Z29waGVyOmhlbGxv" - plainAcquirer := NewBasicAcquirerPlainText("gopher", "hello") - acquirer := NewBasicAcquirer(credentials) - returnedCredentials, err := acquirer.Acquire() - assert.Nil(err) - returnedCredentialsPlain, err := plainAcquirer.Acquire() - assert.Equal(returnedCredentialsPlain, returnedCredentials) -} - -func TestBasicAcquirerFailure(t *testing.T) { - assert := assert.New(t) - credentials := "" - acquirer := NewBasicAcquirer(credentials) - returnedCredentials, err := acquirer.Acquire() - assert.Equal(errMissingCredentials, err) - assert.Equal(credentials, returnedCredentials) -} diff --git a/acquire/bearer.go b/acquire/bearer.go new file mode 100644 index 0000000..bd6f5a4 --- /dev/null +++ b/acquire/bearer.go @@ -0,0 +1,127 @@ +package acquire + +import ( + "encoding/json" + "fmt" + "io/ioutil" + "net/http" + "time" + + "github.com/goph/emperror" +) + +//TokenParser defines the function signature of a bearer token extractor from a payload +type TokenParser func([]byte) (string, error) + +//ParseExpiration defines the function signature of a bearer token expiration date extractor +type ParseExpiration func([]byte) (time.Time, error) + +//DefaultTokenParser extracts a bearer token as defined by a SimpleBearer in a payload +func DefaultTokenParser(data []byte) (string, error) { + var bearer SimpleBearer + + if errUnmarshal := json.Unmarshal(data, &bearer); errUnmarshal != nil { + return "", emperror.Wrap(errUnmarshal, "unable to parse bearer token") + } + return bearer.Token, nil +} + +//DefaultExpirationParser extracts a bearer token expiration date as defined by a SimpleBearer in a payload +func DefaultExpirationParser(data []byte) (time.Time, error) { + var bearer SimpleBearer + + if errUnmarshal := json.Unmarshal(data, &bearer); errUnmarshal != nil { + return time.Time{}, emperror.Wrap(errUnmarshal, "unable to parse bearer token expiration") + } + return time.Now().Add(time.Duration(bearer.ExpiresInSeconds) * time.Second), nil +} + +//RemoteBearerTokenAcquirerOptions provides configuration for the RemoteBearerTokenAcquirer +type RemoteBearerTokenAcquirerOptions struct { + AuthURL string `json:"authURL"` + Timeout time.Duration `json:"timeout"` + Buffer time.Duration `json:"buffer"` + RequestHeaders map[string]string `json:"requestHeaders"` + + GetToken TokenParser + GetExpiration ParseExpiration +} + +//RemoteBearerTokenAcquirer implements Acquirer and fetches the tokens from a remote location with caching strategy +type RemoteBearerTokenAcquirer struct { + options RemoteBearerTokenAcquirerOptions + authValue string + authValueExpiration time.Time + httpClient *http.Client + nonExpiringSpecialCase time.Time +} + +//SimpleBearer defines the field name mappings used by the default bearer token and expiration parsers +type SimpleBearer struct { + ExpiresInSeconds float64 `json:"expires_in"` + Token string `json:"serviceAccessToken"` +} + +// NewRemoteBearerTokenAcquirer returns a RemoteBearerTokenAcquirer configured with the given options +func NewRemoteBearerTokenAcquirer(options RemoteBearerTokenAcquirerOptions) (*RemoteBearerTokenAcquirer, error) { + if options.GetToken == nil { + options.GetToken = DefaultTokenParser + } + + if options.GetExpiration == nil { + options.GetExpiration = DefaultExpirationParser + } + + //TODO: we should inject timeout and buffer defaults values as well + + return &RemoteBearerTokenAcquirer{ + options: options, + authValueExpiration: time.Now(), + httpClient: &http.Client{ + Timeout: options.Timeout, + }, + nonExpiringSpecialCase: time.Unix(0, 0), + }, nil +} + +func (acquirer *RemoteBearerTokenAcquirer) Acquire() (string, error) { + if time.Now().Add(acquirer.options.Buffer).Before(acquirer.authValueExpiration) { + return acquirer.authValue, nil + } + + req, err := http.NewRequest("GET", acquirer.options.AuthURL, nil) + if err != nil { + return "", emperror.Wrap(err, "failed to create new request for Bearer") + } + + for key, value := range acquirer.options.RequestHeaders { + req.Header.Set(key, value) + } + + resp, errHTTP := acquirer.httpClient.Do(req) + if errHTTP != nil { + return "", emperror.Wrapf(errHTTP, "error making request to '%v' to acquire bearer token", acquirer.options.AuthURL) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("received non 200 code acquiring Bearer: code %v", resp.Status) + } + + respBody, errRead := ioutil.ReadAll(resp.Body) + if errRead != nil { + return "", emperror.Wrap(errRead, "error reading HTTP response body") + } + + token, err := acquirer.options.GetToken(respBody) + if err != nil { + return "", emperror.Wrap(err, "error parsing bearer token from http response body") + } + expiration, err := acquirer.options.GetExpiration(respBody) + if err != nil { + return "", emperror.Wrap(err, "error parsing bearer token expiration from http response body") + } + + acquirer.authValue, acquirer.authValueExpiration = "Bearer "+token, expiration + return acquirer.authValue, nil +} diff --git a/acquire/jwt_test.go b/acquire/bearer_test.go similarity index 66% rename from acquire/jwt_test.go rename to acquire/bearer_test.go index 0407cc1..efba09d 100644 --- a/acquire/jwt_test.go +++ b/acquire/bearer_test.go @@ -3,6 +3,7 @@ package acquire import ( "encoding/json" "errors" + "fmt" "net/http" "net/http/httptest" "testing" @@ -11,11 +12,11 @@ import ( "github.com/stretchr/testify/assert" ) -func TestAuthAcquireSuccess(t *testing.T) { - goodAuth := JWTBasic{ - Token: "test token", +func TestRemoteBearerTokenAcquirer(t *testing.T) { + goodAuth := SimpleBearer{ + Token: "test-token", } - goodToken := "Bearer test token" + goodToken := "Bearer test-token" tests := []struct { description string @@ -32,18 +33,18 @@ func TestAuthAcquireSuccess(t *testing.T) { expectedErr: nil, }, { - description: "HTTP Make Request Error", + description: "HTTP Do Error", authToken: goodAuth, expectedToken: "", - authURL: "/\b", - expectedErr: errors.New("failed to create new request for JWT"), + authURL: "/", + expectedErr: errors.New("error making request to '/' to acquire bearer"), }, { - description: "HTTP Do Error", + description: "HTTP Make Request Error", authToken: goodAuth, expectedToken: "", - authURL: "/", - expectedErr: errors.New("error acquiring JWT token"), + authURL: "/\b", + expectedErr: errors.New("failed to create new request"), }, { description: "HTTP Unauthorized Error", @@ -56,7 +57,7 @@ func TestAuthAcquireSuccess(t *testing.T) { description: "Unmarshal Error", authToken: []byte("{token:5555}"), expectedToken: "", - expectedErr: errors.New("unable to read json"), + expectedErr: errors.New("unable to parse bearer token"), }, } @@ -66,6 +67,11 @@ func TestAuthAcquireSuccess(t *testing.T) { // Start a local HTTP server server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + + // Test optional headers + assert.Equal("v0", req.Header.Get("k0")) + assert.Equal("v1", req.Header.Get("k1")) + // Send response to be tested if tc.returnUnauthorized { rw.WriteHeader(http.StatusUnauthorized) @@ -84,10 +90,14 @@ func TestAuthAcquireSuccess(t *testing.T) { } // Use Client & URL from our local test server - auth := NewJWTAcquirer(JWTAcquirerOptions{ - AuthURL: url, - Timeout: time.Duration(5) * time.Second, + auth, errConstructor := NewRemoteBearerTokenAcquirer(RemoteBearerTokenAcquirerOptions{ + AuthURL: url, + Timeout: 5 * time.Second, + RequestHeaders: map[string]string{"k0": "v0", "k1": "v1"}, }) + + assert.Nil(errConstructor) + token, err := auth.Acquire() if tc.expectedErr == nil || err == nil { @@ -100,14 +110,14 @@ func TestAuthAcquireSuccess(t *testing.T) { } } -func TestAuthCaching(t *testing.T) { +func TestRemoteBearerTokenAcquirerCaching(t *testing.T) { assert := assert.New(t) count := 0 server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { - auth := JWTBasic{ - Token: "gopher+" + string(count), - Expiration: 1, + auth := SimpleBearer{ + Token: fmt.Sprintf("gopher%v", count), + ExpiresInSeconds: 3600, //1 hour } count++ @@ -118,18 +128,22 @@ func TestAuthCaching(t *testing.T) { defer server.Close() // Use Client & URL from our local test server - auth := NewJWTAcquirer(JWTAcquirerOptions{ + auth, errConstructor := NewRemoteBearerTokenAcquirer(RemoteBearerTokenAcquirerOptions{ AuthURL: server.URL, Timeout: time.Duration(5) * time.Second, Buffer: time.Microsecond, }) + assert.Nil(errConstructor) token, err := auth.Acquire() assert.Nil(err) - tokenA, err := auth.Acquire() - assert.Nil(err) - assert.Equal(token, tokenA) - time.Sleep(time.Second) - tokenA, err = auth.Acquire() + + cachedToken, err := auth.Acquire() assert.Nil(err) - assert.NotEqual(token, tokenA) + assert.Equal(token, cachedToken) + assert.Equal(1, count) +} + +type customBearer struct { + Token string `json:"token"` + ExpiresOnUnixSeconds int64 `json:"expires_on"` } diff --git a/acquire/jwt.go b/acquire/jwt.go deleted file mode 100644 index dd32316..0000000 --- a/acquire/jwt.go +++ /dev/null @@ -1,120 +0,0 @@ -package acquire - -import ( - "bytes" - "encoding/json" - "fmt" - "io/ioutil" - "net/http" - "time" - - "github.com/goph/emperror" -) - -// TODO: add a different basic JWT Acquirer that simply stores the JWT and -// rename this stuff. - -type ParseToken func([]byte) (string, error) - -func DefaultTokenParser(data []byte) (string, error) { - var jwt JWTBasic - - if errUnmarshal := json.Unmarshal(data, &jwt); errUnmarshal != nil { - return "", emperror.Wrap(errUnmarshal, "unable to read json") - } - return jwt.Token, nil -} - -type ParseExpiration func([]byte) (time.Time, error) - -func DefaultExpirationParser(data []byte) (time.Time, error) { - var jwt JWTBasic - - if errUnmarshal := json.Unmarshal(data, &jwt); errUnmarshal != nil { - return time.Time{}, emperror.Wrap(errUnmarshal, "unable to read json") - } - return time.Now().Add(time.Duration(jwt.Expiration) * time.Second), nil -} - -type JWTAcquirerOptions struct { - AuthURL string `json:"authURL"` - Timeout time.Duration `json:"timeout"` - Buffer time.Duration `json:"buffer"` - RequestHeaders map[string]string `json:"requestHeaders"` - - GetToken ParseToken - GetExpiration ParseExpiration -} - -type JWTAcquirer struct { - options JWTAcquirerOptions - - cachedAuth string - expires time.Time -} - -type JWTBasic struct { - Expiration float64 `json:"expires_in"` - Token string `json:"serviceAccessToken"` -} - -func NewJWTAcquirer(options JWTAcquirerOptions) JWTAcquirer { - if options.GetToken == nil { - options.GetToken = DefaultTokenParser - } - if options.GetExpiration == nil { - options.GetExpiration = DefaultExpirationParser - } - - return JWTAcquirer{ - options: options, - expires: time.Now(), - } -} - -func (acquire *JWTAcquirer) Acquire() (string, error) { - if time.Now().Add(acquire.options.Buffer).Before(acquire.expires) || acquire.expires == time.Unix(0, 0) { - return acquire.cachedAuth, nil - } - - jsonStr := []byte(`{}`) - httpclient := &http.Client{ - Timeout: acquire.options.Timeout, - } - req, err := http.NewRequest("GET", acquire.options.AuthURL, bytes.NewBuffer(jsonStr)) - if err != nil { - return "", emperror.Wrap(err, "failed to create new request for JWT") - } - - for key, value := range acquire.options.RequestHeaders { - req.Header.Set(key, value) - } - - resp, errHTTP := httpclient.Do(req) - if errHTTP != nil { - return "", fmt.Errorf("error acquiring JWT token: [%s]", errHTTP.Error()) - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - return "", fmt.Errorf("received non 200 code acquiring JWT: code %v", resp.Status) - } - - respBody, errRead := ioutil.ReadAll(resp.Body) - if errRead != nil { - return "", fmt.Errorf("error reading JWT token: [%s]", errRead.Error()) - } - - auth, err := acquire.options.GetToken(respBody) - if err != nil { - return "", fmt.Errorf("error parsing JWT token: [%s]", err.Error()) - } - expires, err := acquire.options.GetExpiration(respBody) - if err != nil { - return "", fmt.Errorf("error parsing JWT token: [%s]", err.Error()) - } - - acquire.cachedAuth = fmt.Sprintf("Bearer %s", auth) - acquire.expires = expires - return acquire.cachedAuth, nil -} diff --git a/examples/acquirer/acquirer.go b/examples/acquirer/acquirer.go index 36f6358..9c79859 100644 --- a/examples/acquirer/acquirer.go +++ b/examples/acquirer/acquirer.go @@ -1,6 +1,7 @@ package main import ( + "encoding/base64" "fmt" "io/ioutil" "net/http" @@ -12,7 +13,12 @@ import ( func main() { // set up acquirer and add the auth to the request - acquirer := acquire.NewBasicAcquirerPlainText("testuser", "testpass") + acquirer, err := acquire.NewFixedAuthAcquirer("Basic " + base64.StdEncoding.EncodeToString([]byte("testuser:testpass"))) + if err != nil { + fmt.Fprintf(os.Stderr, "failed to create basic auth plain text acquirer: %v\n", err.Error()) + os.Exit(1) + } + request, err := http.NewRequest(http.MethodGet, "http://localhost:6000/test", nil) if err != nil { fmt.Fprintf(os.Stderr, "failed to create request: %v\n", err.Error())