Skip to content

Commit

Permalink
Add support for leeway
Browse files Browse the repository at this point in the history
  • Loading branch information
Sovietaced committed Jan 7, 2024
1 parent 1f8f5e1 commit 167fb94
Show file tree
Hide file tree
Showing 3 changed files with 173 additions and 45 deletions.
13 changes: 11 additions & 2 deletions metadata/okta/okta.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@ import (
"time"
)

const (
DefaultCacheTtl = 5 * time.Minute
)

// Options are configurable options for the MetadataProvider.
type Options struct {
httpClient *http.Client
Expand Down Expand Up @@ -42,7 +46,7 @@ func defaultOptions() *Options {
opts := &Options{}
WithHttpClient(http.DefaultClient)(opts)
withClock(clock.New())(opts)
WithCacheTtl(5 * time.Minute)(opts)
WithCacheTtl(DefaultCacheTtl)(opts)
return opts
}

Expand Down Expand Up @@ -78,7 +82,12 @@ func NewMetadataProvider(issuer string, options ...Option) *MetadataProvider {
}

metadataUrl := fmt.Sprintf("%s%s", issuer, "/.well-known/openid-configuration")
return &MetadataProvider{metadataUrl: metadataUrl, httpClient: opts.httpClient, clock: opts.clock, cacheTtl: opts.cacheTtl}
return &MetadataProvider{
metadataUrl: metadataUrl,
httpClient: opts.httpClient,
clock: opts.clock,
cacheTtl: opts.cacheTtl,
}
}

// GetMetadata gets metadata for the specified Okta issuer.
Expand Down
63 changes: 26 additions & 37 deletions verifier.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,17 @@ import (
"github.com/sovietaced/okta-jwt-verifier/keyfunc"
"github.com/sovietaced/okta-jwt-verifier/keyfunc/okta"
oktametadata "github.com/sovietaced/okta-jwt-verifier/metadata/okta"
"time"
)

const (
DefaultLeeway = 0 // Default leeway that is configured for JWT validation
)

// Options are configurable options for the Verifier.
type Options struct {
keyfuncProvider keyfunc.Provider
leeway time.Duration
}

// WithKeyfuncProvider allows for a configurable keyfunc.Provider, which may be useful if you want to customize
Expand All @@ -22,9 +28,17 @@ func WithKeyfuncProvider(keyfuncProvider keyfunc.Provider) Option {
}
}

// WithLeeway adds leeway to all time related validations.
func WithLeeway(leeway time.Duration) Option {
return func(mo *Options) {
mo.leeway = leeway
}
}

func defaultOptions(issuer string) *Options {
opts := &Options{}
WithKeyfuncProvider(okta.NewKeyfuncProvider(oktametadata.NewMetadataProvider(issuer)))(opts)
WithLeeway(DefaultLeeway)(opts)
return opts
}

Expand All @@ -44,6 +58,7 @@ func newJwtFromToken(token *jwt.Token) *Jwt {

// Verifier is the implementation of the Okta JWT verification logic.
type Verifier struct {
parser *jwt.Parser
keyfuncProvider keyfunc.Provider
issuer string
clientId string
Expand All @@ -56,7 +71,15 @@ func NewVerifier(issuer string, clientId string, options ...Option) *Verifier {
option(opts)
}

return &Verifier{issuer: issuer, clientId: clientId, keyfuncProvider: opts.keyfuncProvider}
// Configure JWT parser
parser := jwt.NewParser(
jwt.WithLeeway(opts.leeway),
jwt.WithIssuer(issuer),
jwt.WithAudience(clientId),
jwt.WithExpirationRequired(),
)

return &Verifier{issuer: issuer, clientId: clientId, keyfuncProvider: opts.keyfuncProvider, parser: parser}
}

// VerifyIdToken verifies an Okta ID token.
Expand Down Expand Up @@ -100,43 +123,18 @@ func (v *Verifier) parseToken(ctx context.Context, tokenString string) (*jwt.Tok
return nil, fmt.Errorf("getting key function: %w", err)
}

token, err := jwt.Parse(tokenString, keyfunc)
token, err := v.parser.Parse(tokenString, keyfunc)
if err != nil {
return nil, fmt.Errorf("parsing token: %w", err)
}

return token, err
}

// validateCommonClaims validates claims that aren't validated natively by jwt.Parser
func (v *Verifier) validateCommonClaims(ctx context.Context, jwt *jwt.Token) error {
claims := jwt.Claims

jwtIssuer, err := claims.GetIssuer()
if err != nil {
return fmt.Errorf("verifying token issuer: %w", err)
}

if jwtIssuer != v.issuer {
return fmt.Errorf("verifying token issuer: issuer '%s' in token does not match '%s'", jwtIssuer, v.issuer)
}

jwtAuds, err := claims.GetAudience()
if err != nil {
return fmt.Errorf("veriying token audience: %w", err)
}

matchFound := false
for _, jwtAud := range jwtAuds {
if jwtAud == v.clientId {
matchFound = true
break
}
}

if !matchFound {
return fmt.Errorf("verifying token audience: audience '%s' in token does not match '%s'", jwtAuds, v.clientId)
}

jwtIat, err := claims.GetIssuedAt()
if err != nil {
return fmt.Errorf("verifying id token issued time: %w", err)
Expand All @@ -146,14 +144,5 @@ func (v *Verifier) validateCommonClaims(ctx context.Context, jwt *jwt.Token) err
return fmt.Errorf("verifying token issued time: no issued time found")
}

jwtExp, err := claims.GetExpirationTime()
if err != nil {
return fmt.Errorf("verifying token expriation time: %w", err)
}

if jwtExp == nil {
return fmt.Errorf("verifying token expiration time: no expiration time found")
}

return nil
}
142 changes: 136 additions & 6 deletions verifier_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,23 @@ func TestVerifierVerifyIdToken(t *testing.T) {
require.NoError(t, err)

_, err = v.VerifyIdToken(ctx, idToken)
require.ErrorContains(t, err, "verifying token issuer: issuer '' in token does not match 'https://test.okta.com'")
require.ErrorContains(t, err, "verifying id token: parsing token: token has invalid claims: token is missing required claim: iss claim is required")
})

t.Run("verify id token wrong issuer", func(t *testing.T) {
token := jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.MapClaims{
"iss": "wrong",
"aud": clientId,
"iat": time.Now().Unix(),
"exp": time.Now().Add(24 * time.Hour).Unix(),
"nonce": 456,
})
token.Header["kid"] = oktatest.KID
idToken, err := token.SignedString(pk)
require.NoError(t, err)

_, err = v.VerifyIdToken(ctx, idToken)
require.ErrorContains(t, err, "verifying id token: parsing token: token has invalid claims: token has invalid issuer")
})

t.Run("verify id token missing audience", func(t *testing.T) {
Expand All @@ -77,7 +93,23 @@ func TestVerifierVerifyIdToken(t *testing.T) {
require.NoError(t, err)

_, err = v.VerifyIdToken(ctx, idToken)
require.ErrorContains(t, err, "verifying token audience: audience '[]' in token does not match 'test'")
require.ErrorContains(t, err, "verifying id token: parsing token: token has invalid claims: token is missing required claim: aud claim is required")
})

t.Run("verify id token wrong audience", func(t *testing.T) {
token := jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.MapClaims{
"iss": issuer,
"aud": "wrong",
"iat": time.Now().Unix(),
"exp": time.Now().Add(24 * time.Hour).Unix(),
"nonce": 456,
})
token.Header["kid"] = oktatest.KID
idToken, err := token.SignedString(pk)
require.NoError(t, err)

_, err = v.VerifyIdToken(ctx, idToken)
require.ErrorContains(t, err, "verifying id token: parsing token: token has invalid claims: token has invalid audience")
})

t.Run("verify id token missing issued time", func(t *testing.T) {
Expand Down Expand Up @@ -107,7 +139,7 @@ func TestVerifierVerifyIdToken(t *testing.T) {
require.NoError(t, err)

_, err = v.VerifyIdToken(ctx, idToken)
require.ErrorContains(t, err, "verifying token expiration time: no expiration time found")
require.ErrorContains(t, err, "verifying id token: parsing token: token has invalid claims: token is missing required claim: exp claim is required")
})

t.Run("verify id token expired", func(t *testing.T) {
Expand All @@ -125,6 +157,39 @@ func TestVerifierVerifyIdToken(t *testing.T) {
_, err = v.VerifyIdToken(ctx, idToken)
require.ErrorContains(t, err, "verifying id token: parsing token: token has invalid claims: token is expired")
})

t.Run("verify id token expiration with leeway", func(t *testing.T) {

lv := NewVerifier(issuer, clientId, WithKeyfuncProvider(kp), WithLeeway(time.Minute))

token := jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.MapClaims{
"iss": issuer,
"aud": clientId,
"iat": time.Now().Unix(),
"exp": time.Now().Add(-30 * time.Second).Unix(),
"nonce": 456,
})
token.Header["kid"] = oktatest.KID
idToken, err := token.SignedString(pk)
require.NoError(t, err)

_, err = lv.VerifyIdToken(ctx, idToken)
require.NoError(t, err)

token = jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.MapClaims{
"iss": issuer,
"aud": clientId,
"iat": time.Now().Unix(),
"exp": time.Now().Add(-2 * time.Minute).Unix(),
"nonce": 456,
})
token.Header["kid"] = oktatest.KID
idToken, err = token.SignedString(pk)
require.NoError(t, err)

_, err = lv.VerifyIdToken(ctx, idToken)
require.ErrorContains(t, err, "verifying id token: parsing token: token has invalid claims: token is expired")
})
}

func TestVerifierVerifyAccessToken(t *testing.T) {
Expand Down Expand Up @@ -201,7 +266,23 @@ func TestVerifierVerifyAccessToken(t *testing.T) {
require.NoError(t, err)

_, err = v.VerifyAccessToken(ctx, idToken)
require.ErrorContains(t, err, "verifying token issuer: issuer '' in token does not match 'https://test.okta.com'")
require.ErrorContains(t, err, "verifying access token: parsing token: token has invalid claims: token is missing required claim: iss claim is required")
})

t.Run("verify access token wrong issuer", func(t *testing.T) {
token := jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.MapClaims{
"iss": "wrong",
"aud": clientId,
"iat": time.Now().Unix(),
"exp": time.Now().Add(24 * time.Hour).Unix(),
"nonce": 456,
})
token.Header["kid"] = oktatest.KID
idToken, err := token.SignedString(pk)
require.NoError(t, err)

_, err = v.VerifyAccessToken(ctx, idToken)
require.ErrorContains(t, err, "verifying access token: parsing token: token has invalid claims: token has invalid issuer")
})

t.Run("verify access token missing audience", func(t *testing.T) {
Expand All @@ -215,7 +296,23 @@ func TestVerifierVerifyAccessToken(t *testing.T) {
require.NoError(t, err)

_, err = v.VerifyAccessToken(ctx, idToken)
require.ErrorContains(t, err, "verifying token audience: audience '[]' in token does not match 'test'")
require.ErrorContains(t, err, "verifying access token: parsing token: token has invalid claims: token is missing required claim: aud claim is required")
})

t.Run("verify access token wrong audience", func(t *testing.T) {
token := jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.MapClaims{
"iss": issuer,
"aud": "wrong",
"iat": time.Now().Unix(),
"exp": time.Now().Add(24 * time.Hour).Unix(),
"nonce": 456,
})
token.Header["kid"] = oktatest.KID
idToken, err := token.SignedString(pk)
require.NoError(t, err)

_, err = v.VerifyAccessToken(ctx, idToken)
require.ErrorContains(t, err, "verifying access token: parsing token: token has invalid claims: token has invalid audience")
})

t.Run("verify access token missing issued time", func(t *testing.T) {
Expand Down Expand Up @@ -243,7 +340,7 @@ func TestVerifierVerifyAccessToken(t *testing.T) {
require.NoError(t, err)

_, err = v.VerifyAccessToken(ctx, idToken)
require.ErrorContains(t, err, "verifying token expiration time: no expiration time found")
require.ErrorContains(t, err, "verifying access token: parsing token: token has invalid claims: token is missing required claim: exp claim is required")
})

t.Run("verify access token expired", func(t *testing.T) {
Expand All @@ -260,4 +357,37 @@ func TestVerifierVerifyAccessToken(t *testing.T) {
_, err = v.VerifyAccessToken(ctx, idToken)
require.ErrorContains(t, err, "verifying access token: parsing token: token has invalid claims: token is expired")
})

t.Run("verify access token expiration with leeway", func(t *testing.T) {

lv := NewVerifier(issuer, clientId, WithKeyfuncProvider(kp), WithLeeway(time.Minute))

token := jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.MapClaims{
"iss": issuer,
"aud": clientId,
"iat": time.Now().Unix(),
"exp": time.Now().Add(-30 * time.Second).Unix(),
"nonce": 456,
})
token.Header["kid"] = oktatest.KID
idToken, err := token.SignedString(pk)
require.NoError(t, err)

_, err = lv.VerifyAccessToken(ctx, idToken)
require.NoError(t, err)

token = jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.MapClaims{
"iss": issuer,
"aud": clientId,
"iat": time.Now().Unix(),
"exp": time.Now().Add(-2 * time.Minute).Unix(),
"nonce": 456,
})
token.Header["kid"] = oktatest.KID
idToken, err = token.SignedString(pk)
require.NoError(t, err)

_, err = lv.VerifyAccessToken(ctx, idToken)
require.ErrorContains(t, err, "verifying access token: parsing token: token has invalid claims: token is expired")
})
}

0 comments on commit 167fb94

Please sign in to comment.