Skip to content

Commit

Permalink
Merge pull request #155 from tzvatot/retry-access-token
Browse files Browse the repository at this point in the history
Support retry when fetching access token
  • Loading branch information
tzvatot authored Aug 10, 2020
2 parents 5716b8f + 0c5385a commit 323cd42
Show file tree
Hide file tree
Showing 5 changed files with 145 additions and 22 deletions.
12 changes: 6 additions & 6 deletions connection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -167,27 +167,27 @@ var _ = Describe("Connection", func() {
Expect(err).ToNot(HaveOccurred())
defer connection.Close()
_, _, err = connection.Tokens()
Expect(transport.called).To(BeTrue())
Expect(transport.called > 2).To(BeTrue()) // it means the retry was called
Expect(err).To(HaveOccurred())
})
})

type TestTransport struct {
called bool
called int
}

func (t *TestTransport) RoundTrip(request *http.Request) (response *http.Response, err error) {
t.called = true
t.called++
header := http.Header{}
header.Add("Content-type", "application/json")
response = &http.Response{
StatusCode: 401,
StatusCode: http.StatusInternalServerError,
Header: header,
Body: gbytes.NewBuffer(),
Body: gbytes.BufferWithBytes([]byte("{}")),
}
return response, nil
}

func NewTestTransport() *TestTransport {
return &TestTransport{called: false}
return &TestTransport{called: 0}
}
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ module github.com/openshift-online/ocm-sdk-go
go 1.12

require (
github.com/cenkalti/backoff/v4 v4.0.0
github.com/dgrijalva/jwt-go v3.2.0+incompatible
github.com/ghodss/yaml v1.0.0
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRF
github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q=
github.com/beorn7/perks v1.0.0 h1:HWo1m869IqiPhD389kmkxeTalrjNbbJTC8LXupb+sl0=
github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8=
github.com/cenkalti/backoff/v4 v4.0.0 h1:6VeaLF9aI+MAUQ95106HwWzYZgJJpZ4stumjj6RFYAU=
github.com/cenkalti/backoff/v4 v4.0.0/go.mod h1:eEew/i+1Q6OrCDZh3WiXYv3+nJwBASZ8Bog/87DQnVg=
github.com/cespare/xxhash v1.1.0/go.mod h1:XrSqR1VqqWfGrhpAt58auRo0WTKS1nRRg3ghfAqPWnc=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
Expand Down
35 changes: 30 additions & 5 deletions token.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ import (
"strings"
"time"

"github.com/cenkalti/backoff/v4"
"github.com/dgrijalva/jwt-go"

"github.com/openshift-online/ocm-sdk-go/internal"
)

Expand All @@ -49,7 +49,30 @@ func (c *Connection) Tokens() (access, refresh string, err error) {
// TokensContext returns the access and refresh tokens that is currently in use by the connection.
// If it is necessary to request a new token because it wasn't requested yet, or because it is
// expired, this method will do it and will return an error if it fails.
// The function will retry the operation in an exponential-backoff method.
func (c *Connection) TokensContext(ctx context.Context) (access, refresh string, err error) {
operation := func() error {
c.logger.Debug(ctx, "trying to get tokens")
var code int
code, access, refresh, err = c.tokensContext(ctx)
if err != nil {
if code >= http.StatusInternalServerError {
c.logger.Error(ctx, "failed to get tokens, got http code %d, will attempt to retry. err: %v", code, err)
return err
}
c.logger.Debug(ctx, "failed to get tokens, got http code %d, will not attempt to retry. err: %v", code, err)
return nil
}
return nil
}
backoffMethod := backoff.NewExponentialBackOff()
backoffMethod.MaxElapsedTime = time.Second * 15
// nolint
backoff.Retry(operation, backoffMethod)
return access, refresh, err
}

func (c *Connection) tokensContext(ctx context.Context) (code int, access, refresh string, err error) {
// We need to make sure that this method isn't execute concurrently, as we will be updating
// multiple attributes of the connection:
c.tokenMutex.Lock()
Expand Down Expand Up @@ -88,7 +111,7 @@ func (c *Connection) TokensContext(ctx context.Context) (access, refresh string,
// At this point we know that the access token is unavailable, expired or about to expire.
// So we need to check if we can use the refresh token to request a new one.
if c.refreshToken != nil && (!refreshExpires || refreshLeft >= 1*time.Minute) {
_, _, err = c.sendRefreshTokenForm(ctx)
code, _, err = c.sendRefreshTokenForm(ctx)
if err != nil {
return
}
Expand All @@ -100,7 +123,7 @@ func (c *Connection) TokensContext(ctx context.Context) (access, refresh string,
// expire. So we need to check if we have other credentials that can be used to request a
// new token, and use them.
if c.haveCredentials() {
_, _, err = c.sendRequestTokenForm(ctx)
code, _, err = c.sendRequestTokenForm(ctx)
if err != nil {
return
}
Expand All @@ -118,7 +141,7 @@ func (c *Connection) TokensContext(ctx context.Context) (access, refresh string,
"obtain a new token, so will try to use it anyhow",
refreshLeft,
)
_, _, err = c.sendRefreshTokenForm(ctx)
code, _, err = c.sendRefreshTokenForm(ctx)
if err != nil {
return
}
Expand Down Expand Up @@ -282,6 +305,8 @@ func (c *Connection) sendTokenFormTimed(ctx context.Context, form url.Values) (c
}
defer response.Body.Close()

code = response.StatusCode

// Check that the response content type is JSON:
err = c.checkContentType(response)
if err != nil {
Expand Down Expand Up @@ -311,7 +336,7 @@ func (c *Connection) sendTokenFormTimed(ctx context.Context, form url.Values) (c
return
}
if response.StatusCode != http.StatusOK {
err = fmt.Errorf("token response status is: %s", response.Status)
err = fmt.Errorf("token response status code is '%d'", response.StatusCode)
return
}
if result.TokenType != nil && *result.TokenType != "bearer" {
Expand Down
117 changes: 106 additions & 11 deletions token_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -247,17 +247,19 @@ var _ = Describe("Tokens", func() {
refreshToken := DefaultToken("Refresh", 10*time.Hour)

// Configure the server:
oidServer.AppendHandlers(
ghttp.RespondWith(
http.StatusServiceUnavailable,
`Service unavailable`,
http.Header{
"Content-Type": []string{
"text/plain",
for i := 0; i < 100; i++ { // there are going to be several retries
oidServer.AppendHandlers(
ghttp.RespondWith(
http.StatusServiceUnavailable,
`Service unavailable`,
http.Header{
"Content-Type": []string{
"text/plain",
},
},
},
),
)
),
)
}

// Create the connection:
connection, err := NewConnectionBuilder().
Expand Down Expand Up @@ -287,7 +289,7 @@ var _ = Describe("Tokens", func() {
// Configure the server:
oidServer.AppendHandlers(
ghttp.RespondWith(
http.StatusServiceUnavailable,
http.StatusBadRequest,
content,
http.Header{
"Content-Type": []string{
Expand Down Expand Up @@ -913,6 +915,99 @@ var _ = Describe("Tokens", func() {
Expect(err).ToNot(HaveOccurred())
})
})

Describe("Test retry for getting access token", func() {
It("Return access token after a few retries", func() {
// Generate tokens:
refreshToken := DefaultToken("Refresh", 10*time.Hour)
accessToken := DefaultToken("Bearer", 5*time.Minute)

oidServer.AppendHandlers(
ghttp.RespondWith(
http.StatusInternalServerError,
`Internal Server Error`,
http.Header{
"Content-Type": []string{
"text/plain",
},
},
),
ghttp.RespondWith(
http.StatusBadGateway,
`Bad Gateway`,
http.Header{
"Content-Type": []string{
"text/plain",
},
},
),
ghttp.CombineHandlers(
VerifyRefreshGrant(refreshToken),
RespondWithTokens(accessToken, refreshToken),
),
)

// Create the connection:
connection, err := NewConnectionBuilder().
Logger(logger).
TokenURL(oidServer.URL()).
URL(apiServer.URL()).
Tokens(refreshToken).
Build()
Expect(err).ToNot(HaveOccurred())
defer connection.Close()

// Get the tokens:
returnedAccess, returnedRefresh, err := connection.Tokens()
Expect(err).ToNot(HaveOccurred())
Expect(returnedAccess).ToNot(BeEmpty())
Expect(returnedRefresh).ToNot(BeEmpty())
})
It("Test no retry when status is not http 5xx", func() {
// Generate tokens:
refreshToken := DefaultToken("Refresh", 10*time.Hour)
accessToken := DefaultToken("Bearer", 5*time.Minute)

oidServer.AppendHandlers(
ghttp.RespondWith(
http.StatusInternalServerError,
`Internal Server Error`,
http.Header{
"Content-Type": []string{
"text/plain",
},
},
),
ghttp.RespondWith(
http.StatusForbidden,
`{}`,
http.Header{
"Content-Type": []string{
"application/json",
},
},
),
ghttp.CombineHandlers(
VerifyRefreshGrant(refreshToken),
RespondWithTokens(accessToken, refreshToken),
),
)

// Create the connection:
connection, err := NewConnectionBuilder().
Logger(logger).
TokenURL(oidServer.URL()).
URL(apiServer.URL()).
Tokens(refreshToken).
Build()
Expect(err).ToNot(HaveOccurred())
defer connection.Close()

// Get the tokens:
_, _, err = connection.Tokens()
Expect(err).To(HaveOccurred())
})
})
})

func VerifyPasswordGrant(user, password string) http.HandlerFunc {
Expand Down

0 comments on commit 323cd42

Please sign in to comment.