From 167fb94852b460d02188e72ff943f14c77c14aa4 Mon Sep 17 00:00:00 2001 From: Jason Parraga Date: Sat, 6 Jan 2024 23:06:19 -0800 Subject: [PATCH] Add support for leeway --- metadata/okta/okta.go | 13 +++- verifier.go | 63 ++++++++----------- verifier_test.go | 142 ++++++++++++++++++++++++++++++++++++++++-- 3 files changed, 173 insertions(+), 45 deletions(-) diff --git a/metadata/okta/okta.go b/metadata/okta/okta.go index 28d2a67..28106d7 100644 --- a/metadata/okta/okta.go +++ b/metadata/okta/okta.go @@ -11,6 +11,10 @@ import ( "time" ) +const ( + DefaultCacheTtl = 5 * time.Minute +) + // Options are configurable options for the MetadataProvider. type Options struct { httpClient *http.Client @@ -42,7 +46,7 @@ func defaultOptions() *Options { opts := &Options{} WithHttpClient(http.DefaultClient)(opts) withClock(clock.New())(opts) - WithCacheTtl(5 * time.Minute)(opts) + WithCacheTtl(DefaultCacheTtl)(opts) return opts } @@ -78,7 +82,12 @@ func NewMetadataProvider(issuer string, options ...Option) *MetadataProvider { } metadataUrl := fmt.Sprintf("%s%s", issuer, "/.well-known/openid-configuration") - return &MetadataProvider{metadataUrl: metadataUrl, httpClient: opts.httpClient, clock: opts.clock, cacheTtl: opts.cacheTtl} + return &MetadataProvider{ + metadataUrl: metadataUrl, + httpClient: opts.httpClient, + clock: opts.clock, + cacheTtl: opts.cacheTtl, + } } // GetMetadata gets metadata for the specified Okta issuer. diff --git a/verifier.go b/verifier.go index e428975..46528da 100644 --- a/verifier.go +++ b/verifier.go @@ -7,11 +7,17 @@ import ( "github.com/sovietaced/okta-jwt-verifier/keyfunc" "github.com/sovietaced/okta-jwt-verifier/keyfunc/okta" oktametadata "github.com/sovietaced/okta-jwt-verifier/metadata/okta" + "time" +) + +const ( + DefaultLeeway = 0 // Default leeway that is configured for JWT validation ) // Options are configurable options for the Verifier. type Options struct { keyfuncProvider keyfunc.Provider + leeway time.Duration } // WithKeyfuncProvider allows for a configurable keyfunc.Provider, which may be useful if you want to customize @@ -22,9 +28,17 @@ func WithKeyfuncProvider(keyfuncProvider keyfunc.Provider) Option { } } +// WithLeeway adds leeway to all time related validations. +func WithLeeway(leeway time.Duration) Option { + return func(mo *Options) { + mo.leeway = leeway + } +} + func defaultOptions(issuer string) *Options { opts := &Options{} WithKeyfuncProvider(okta.NewKeyfuncProvider(oktametadata.NewMetadataProvider(issuer)))(opts) + WithLeeway(DefaultLeeway)(opts) return opts } @@ -44,6 +58,7 @@ func newJwtFromToken(token *jwt.Token) *Jwt { // Verifier is the implementation of the Okta JWT verification logic. type Verifier struct { + parser *jwt.Parser keyfuncProvider keyfunc.Provider issuer string clientId string @@ -56,7 +71,15 @@ func NewVerifier(issuer string, clientId string, options ...Option) *Verifier { option(opts) } - return &Verifier{issuer: issuer, clientId: clientId, keyfuncProvider: opts.keyfuncProvider} + // Configure JWT parser + parser := jwt.NewParser( + jwt.WithLeeway(opts.leeway), + jwt.WithIssuer(issuer), + jwt.WithAudience(clientId), + jwt.WithExpirationRequired(), + ) + + return &Verifier{issuer: issuer, clientId: clientId, keyfuncProvider: opts.keyfuncProvider, parser: parser} } // VerifyIdToken verifies an Okta ID token. @@ -100,7 +123,7 @@ func (v *Verifier) parseToken(ctx context.Context, tokenString string) (*jwt.Tok return nil, fmt.Errorf("getting key function: %w", err) } - token, err := jwt.Parse(tokenString, keyfunc) + token, err := v.parser.Parse(tokenString, keyfunc) if err != nil { return nil, fmt.Errorf("parsing token: %w", err) } @@ -108,35 +131,10 @@ func (v *Verifier) parseToken(ctx context.Context, tokenString string) (*jwt.Tok return token, err } +// validateCommonClaims validates claims that aren't validated natively by jwt.Parser func (v *Verifier) validateCommonClaims(ctx context.Context, jwt *jwt.Token) error { claims := jwt.Claims - jwtIssuer, err := claims.GetIssuer() - if err != nil { - return fmt.Errorf("verifying token issuer: %w", err) - } - - if jwtIssuer != v.issuer { - return fmt.Errorf("verifying token issuer: issuer '%s' in token does not match '%s'", jwtIssuer, v.issuer) - } - - jwtAuds, err := claims.GetAudience() - if err != nil { - return fmt.Errorf("veriying token audience: %w", err) - } - - matchFound := false - for _, jwtAud := range jwtAuds { - if jwtAud == v.clientId { - matchFound = true - break - } - } - - if !matchFound { - return fmt.Errorf("verifying token audience: audience '%s' in token does not match '%s'", jwtAuds, v.clientId) - } - jwtIat, err := claims.GetIssuedAt() if err != nil { return fmt.Errorf("verifying id token issued time: %w", err) @@ -146,14 +144,5 @@ func (v *Verifier) validateCommonClaims(ctx context.Context, jwt *jwt.Token) err return fmt.Errorf("verifying token issued time: no issued time found") } - jwtExp, err := claims.GetExpirationTime() - if err != nil { - return fmt.Errorf("verifying token expriation time: %w", err) - } - - if jwtExp == nil { - return fmt.Errorf("verifying token expiration time: no expiration time found") - } - return nil } diff --git a/verifier_test.go b/verifier_test.go index c57be79..12c78f9 100644 --- a/verifier_test.go +++ b/verifier_test.go @@ -62,7 +62,23 @@ func TestVerifierVerifyIdToken(t *testing.T) { require.NoError(t, err) _, err = v.VerifyIdToken(ctx, idToken) - require.ErrorContains(t, err, "verifying token issuer: issuer '' in token does not match 'https://test.okta.com'") + require.ErrorContains(t, err, "verifying id token: parsing token: token has invalid claims: token is missing required claim: iss claim is required") + }) + + t.Run("verify id token wrong issuer", func(t *testing.T) { + token := jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.MapClaims{ + "iss": "wrong", + "aud": clientId, + "iat": time.Now().Unix(), + "exp": time.Now().Add(24 * time.Hour).Unix(), + "nonce": 456, + }) + token.Header["kid"] = oktatest.KID + idToken, err := token.SignedString(pk) + require.NoError(t, err) + + _, err = v.VerifyIdToken(ctx, idToken) + require.ErrorContains(t, err, "verifying id token: parsing token: token has invalid claims: token has invalid issuer") }) t.Run("verify id token missing audience", func(t *testing.T) { @@ -77,7 +93,23 @@ func TestVerifierVerifyIdToken(t *testing.T) { require.NoError(t, err) _, err = v.VerifyIdToken(ctx, idToken) - require.ErrorContains(t, err, "verifying token audience: audience '[]' in token does not match 'test'") + require.ErrorContains(t, err, "verifying id token: parsing token: token has invalid claims: token is missing required claim: aud claim is required") + }) + + t.Run("verify id token wrong audience", func(t *testing.T) { + token := jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.MapClaims{ + "iss": issuer, + "aud": "wrong", + "iat": time.Now().Unix(), + "exp": time.Now().Add(24 * time.Hour).Unix(), + "nonce": 456, + }) + token.Header["kid"] = oktatest.KID + idToken, err := token.SignedString(pk) + require.NoError(t, err) + + _, err = v.VerifyIdToken(ctx, idToken) + require.ErrorContains(t, err, "verifying id token: parsing token: token has invalid claims: token has invalid audience") }) t.Run("verify id token missing issued time", func(t *testing.T) { @@ -107,7 +139,7 @@ func TestVerifierVerifyIdToken(t *testing.T) { require.NoError(t, err) _, err = v.VerifyIdToken(ctx, idToken) - require.ErrorContains(t, err, "verifying token expiration time: no expiration time found") + require.ErrorContains(t, err, "verifying id token: parsing token: token has invalid claims: token is missing required claim: exp claim is required") }) t.Run("verify id token expired", func(t *testing.T) { @@ -125,6 +157,39 @@ func TestVerifierVerifyIdToken(t *testing.T) { _, err = v.VerifyIdToken(ctx, idToken) require.ErrorContains(t, err, "verifying id token: parsing token: token has invalid claims: token is expired") }) + + t.Run("verify id token expiration with leeway", func(t *testing.T) { + + lv := NewVerifier(issuer, clientId, WithKeyfuncProvider(kp), WithLeeway(time.Minute)) + + token := jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.MapClaims{ + "iss": issuer, + "aud": clientId, + "iat": time.Now().Unix(), + "exp": time.Now().Add(-30 * time.Second).Unix(), + "nonce": 456, + }) + token.Header["kid"] = oktatest.KID + idToken, err := token.SignedString(pk) + require.NoError(t, err) + + _, err = lv.VerifyIdToken(ctx, idToken) + require.NoError(t, err) + + token = jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.MapClaims{ + "iss": issuer, + "aud": clientId, + "iat": time.Now().Unix(), + "exp": time.Now().Add(-2 * time.Minute).Unix(), + "nonce": 456, + }) + token.Header["kid"] = oktatest.KID + idToken, err = token.SignedString(pk) + require.NoError(t, err) + + _, err = lv.VerifyIdToken(ctx, idToken) + require.ErrorContains(t, err, "verifying id token: parsing token: token has invalid claims: token is expired") + }) } func TestVerifierVerifyAccessToken(t *testing.T) { @@ -201,7 +266,23 @@ func TestVerifierVerifyAccessToken(t *testing.T) { require.NoError(t, err) _, err = v.VerifyAccessToken(ctx, idToken) - require.ErrorContains(t, err, "verifying token issuer: issuer '' in token does not match 'https://test.okta.com'") + require.ErrorContains(t, err, "verifying access token: parsing token: token has invalid claims: token is missing required claim: iss claim is required") + }) + + t.Run("verify access token wrong issuer", func(t *testing.T) { + token := jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.MapClaims{ + "iss": "wrong", + "aud": clientId, + "iat": time.Now().Unix(), + "exp": time.Now().Add(24 * time.Hour).Unix(), + "nonce": 456, + }) + token.Header["kid"] = oktatest.KID + idToken, err := token.SignedString(pk) + require.NoError(t, err) + + _, err = v.VerifyAccessToken(ctx, idToken) + require.ErrorContains(t, err, "verifying access token: parsing token: token has invalid claims: token has invalid issuer") }) t.Run("verify access token missing audience", func(t *testing.T) { @@ -215,7 +296,23 @@ func TestVerifierVerifyAccessToken(t *testing.T) { require.NoError(t, err) _, err = v.VerifyAccessToken(ctx, idToken) - require.ErrorContains(t, err, "verifying token audience: audience '[]' in token does not match 'test'") + require.ErrorContains(t, err, "verifying access token: parsing token: token has invalid claims: token is missing required claim: aud claim is required") + }) + + t.Run("verify access token wrong audience", func(t *testing.T) { + token := jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.MapClaims{ + "iss": issuer, + "aud": "wrong", + "iat": time.Now().Unix(), + "exp": time.Now().Add(24 * time.Hour).Unix(), + "nonce": 456, + }) + token.Header["kid"] = oktatest.KID + idToken, err := token.SignedString(pk) + require.NoError(t, err) + + _, err = v.VerifyAccessToken(ctx, idToken) + require.ErrorContains(t, err, "verifying access token: parsing token: token has invalid claims: token has invalid audience") }) t.Run("verify access token missing issued time", func(t *testing.T) { @@ -243,7 +340,7 @@ func TestVerifierVerifyAccessToken(t *testing.T) { require.NoError(t, err) _, err = v.VerifyAccessToken(ctx, idToken) - require.ErrorContains(t, err, "verifying token expiration time: no expiration time found") + require.ErrorContains(t, err, "verifying access token: parsing token: token has invalid claims: token is missing required claim: exp claim is required") }) t.Run("verify access token expired", func(t *testing.T) { @@ -260,4 +357,37 @@ func TestVerifierVerifyAccessToken(t *testing.T) { _, err = v.VerifyAccessToken(ctx, idToken) require.ErrorContains(t, err, "verifying access token: parsing token: token has invalid claims: token is expired") }) + + t.Run("verify access token expiration with leeway", func(t *testing.T) { + + lv := NewVerifier(issuer, clientId, WithKeyfuncProvider(kp), WithLeeway(time.Minute)) + + token := jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.MapClaims{ + "iss": issuer, + "aud": clientId, + "iat": time.Now().Unix(), + "exp": time.Now().Add(-30 * time.Second).Unix(), + "nonce": 456, + }) + token.Header["kid"] = oktatest.KID + idToken, err := token.SignedString(pk) + require.NoError(t, err) + + _, err = lv.VerifyAccessToken(ctx, idToken) + require.NoError(t, err) + + token = jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.MapClaims{ + "iss": issuer, + "aud": clientId, + "iat": time.Now().Unix(), + "exp": time.Now().Add(-2 * time.Minute).Unix(), + "nonce": 456, + }) + token.Header["kid"] = oktatest.KID + idToken, err = token.SignedString(pk) + require.NoError(t, err) + + _, err = lv.VerifyAccessToken(ctx, idToken) + require.ErrorContains(t, err, "verifying access token: parsing token: token has invalid claims: token is expired") + }) }