Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for leeway #7

Merged
merged 1 commit into from
Jan 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions metadata/okta/okta.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@ import (
"time"
)

const (
DefaultCacheTtl = 5 * time.Minute
)

// Options are configurable options for the MetadataProvider.
type Options struct {
httpClient *http.Client
Expand Down Expand Up @@ -42,7 +46,7 @@ func defaultOptions() *Options {
opts := &Options{}
WithHttpClient(http.DefaultClient)(opts)
withClock(clock.New())(opts)
WithCacheTtl(5 * time.Minute)(opts)
WithCacheTtl(DefaultCacheTtl)(opts)
return opts
}

Expand Down Expand Up @@ -78,7 +82,12 @@ func NewMetadataProvider(issuer string, options ...Option) *MetadataProvider {
}

metadataUrl := fmt.Sprintf("%s%s", issuer, "/.well-known/openid-configuration")
return &MetadataProvider{metadataUrl: metadataUrl, httpClient: opts.httpClient, clock: opts.clock, cacheTtl: opts.cacheTtl}
return &MetadataProvider{
metadataUrl: metadataUrl,
httpClient: opts.httpClient,
clock: opts.clock,
cacheTtl: opts.cacheTtl,
}
}

// GetMetadata gets metadata for the specified Okta issuer.
Expand Down
63 changes: 26 additions & 37 deletions verifier.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,17 @@ import (
"github.com/sovietaced/okta-jwt-verifier/keyfunc"
"github.com/sovietaced/okta-jwt-verifier/keyfunc/okta"
oktametadata "github.com/sovietaced/okta-jwt-verifier/metadata/okta"
"time"
)

const (
DefaultLeeway = 0 // Default leeway that is configured for JWT validation
)

// Options are configurable options for the Verifier.
type Options struct {
keyfuncProvider keyfunc.Provider
leeway time.Duration
}

// WithKeyfuncProvider allows for a configurable keyfunc.Provider, which may be useful if you want to customize
Expand All @@ -22,9 +28,17 @@ func WithKeyfuncProvider(keyfuncProvider keyfunc.Provider) Option {
}
}

// WithLeeway adds leeway to all time related validations.
func WithLeeway(leeway time.Duration) Option {
return func(mo *Options) {
mo.leeway = leeway
}
}

func defaultOptions(issuer string) *Options {
opts := &Options{}
WithKeyfuncProvider(okta.NewKeyfuncProvider(oktametadata.NewMetadataProvider(issuer)))(opts)
WithLeeway(DefaultLeeway)(opts)
return opts
}

Expand All @@ -44,6 +58,7 @@ func newJwtFromToken(token *jwt.Token) *Jwt {

// Verifier is the implementation of the Okta JWT verification logic.
type Verifier struct {
parser *jwt.Parser
keyfuncProvider keyfunc.Provider
issuer string
clientId string
Expand All @@ -56,7 +71,15 @@ func NewVerifier(issuer string, clientId string, options ...Option) *Verifier {
option(opts)
}

return &Verifier{issuer: issuer, clientId: clientId, keyfuncProvider: opts.keyfuncProvider}
// Configure JWT parser
parser := jwt.NewParser(
jwt.WithLeeway(opts.leeway),
jwt.WithIssuer(issuer),
jwt.WithAudience(clientId),
jwt.WithExpirationRequired(),
)

return &Verifier{issuer: issuer, clientId: clientId, keyfuncProvider: opts.keyfuncProvider, parser: parser}
}

// VerifyIdToken verifies an Okta ID token.
Expand Down Expand Up @@ -100,43 +123,18 @@ func (v *Verifier) parseToken(ctx context.Context, tokenString string) (*jwt.Tok
return nil, fmt.Errorf("getting key function: %w", err)
}

token, err := jwt.Parse(tokenString, keyfunc)
token, err := v.parser.Parse(tokenString, keyfunc)
if err != nil {
return nil, fmt.Errorf("parsing token: %w", err)
}

return token, err
}

// validateCommonClaims validates claims that aren't validated natively by jwt.Parser
func (v *Verifier) validateCommonClaims(ctx context.Context, jwt *jwt.Token) error {
claims := jwt.Claims

jwtIssuer, err := claims.GetIssuer()
if err != nil {
return fmt.Errorf("verifying token issuer: %w", err)
}

if jwtIssuer != v.issuer {
return fmt.Errorf("verifying token issuer: issuer '%s' in token does not match '%s'", jwtIssuer, v.issuer)
}

jwtAuds, err := claims.GetAudience()
if err != nil {
return fmt.Errorf("veriying token audience: %w", err)
}

matchFound := false
for _, jwtAud := range jwtAuds {
if jwtAud == v.clientId {
matchFound = true
break
}
}

if !matchFound {
return fmt.Errorf("verifying token audience: audience '%s' in token does not match '%s'", jwtAuds, v.clientId)
}

jwtIat, err := claims.GetIssuedAt()
if err != nil {
return fmt.Errorf("verifying id token issued time: %w", err)
Expand All @@ -146,14 +144,5 @@ func (v *Verifier) validateCommonClaims(ctx context.Context, jwt *jwt.Token) err
return fmt.Errorf("verifying token issued time: no issued time found")
}

jwtExp, err := claims.GetExpirationTime()
if err != nil {
return fmt.Errorf("verifying token expriation time: %w", err)
}

if jwtExp == nil {
return fmt.Errorf("verifying token expiration time: no expiration time found")
}

return nil
}
142 changes: 136 additions & 6 deletions verifier_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,23 @@ func TestVerifierVerifyIdToken(t *testing.T) {
require.NoError(t, err)

_, err = v.VerifyIdToken(ctx, idToken)
require.ErrorContains(t, err, "verifying token issuer: issuer '' in token does not match 'https://test.okta.com'")
require.ErrorContains(t, err, "verifying id token: parsing token: token has invalid claims: token is missing required claim: iss claim is required")
})

t.Run("verify id token wrong issuer", func(t *testing.T) {
token := jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.MapClaims{
"iss": "wrong",
"aud": clientId,
"iat": time.Now().Unix(),
"exp": time.Now().Add(24 * time.Hour).Unix(),
"nonce": 456,
})
token.Header["kid"] = oktatest.KID
idToken, err := token.SignedString(pk)
require.NoError(t, err)

_, err = v.VerifyIdToken(ctx, idToken)
require.ErrorContains(t, err, "verifying id token: parsing token: token has invalid claims: token has invalid issuer")
})

t.Run("verify id token missing audience", func(t *testing.T) {
Expand All @@ -77,7 +93,23 @@ func TestVerifierVerifyIdToken(t *testing.T) {
require.NoError(t, err)

_, err = v.VerifyIdToken(ctx, idToken)
require.ErrorContains(t, err, "verifying token audience: audience '[]' in token does not match 'test'")
require.ErrorContains(t, err, "verifying id token: parsing token: token has invalid claims: token is missing required claim: aud claim is required")
})

t.Run("verify id token wrong audience", func(t *testing.T) {
token := jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.MapClaims{
"iss": issuer,
"aud": "wrong",
"iat": time.Now().Unix(),
"exp": time.Now().Add(24 * time.Hour).Unix(),
"nonce": 456,
})
token.Header["kid"] = oktatest.KID
idToken, err := token.SignedString(pk)
require.NoError(t, err)

_, err = v.VerifyIdToken(ctx, idToken)
require.ErrorContains(t, err, "verifying id token: parsing token: token has invalid claims: token has invalid audience")
})

t.Run("verify id token missing issued time", func(t *testing.T) {
Expand Down Expand Up @@ -107,7 +139,7 @@ func TestVerifierVerifyIdToken(t *testing.T) {
require.NoError(t, err)

_, err = v.VerifyIdToken(ctx, idToken)
require.ErrorContains(t, err, "verifying token expiration time: no expiration time found")
require.ErrorContains(t, err, "verifying id token: parsing token: token has invalid claims: token is missing required claim: exp claim is required")
})

t.Run("verify id token expired", func(t *testing.T) {
Expand All @@ -125,6 +157,39 @@ func TestVerifierVerifyIdToken(t *testing.T) {
_, err = v.VerifyIdToken(ctx, idToken)
require.ErrorContains(t, err, "verifying id token: parsing token: token has invalid claims: token is expired")
})

t.Run("verify id token expiration with leeway", func(t *testing.T) {

lv := NewVerifier(issuer, clientId, WithKeyfuncProvider(kp), WithLeeway(time.Minute))

token := jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.MapClaims{
"iss": issuer,
"aud": clientId,
"iat": time.Now().Unix(),
"exp": time.Now().Add(-30 * time.Second).Unix(),
"nonce": 456,
})
token.Header["kid"] = oktatest.KID
idToken, err := token.SignedString(pk)
require.NoError(t, err)

_, err = lv.VerifyIdToken(ctx, idToken)
require.NoError(t, err)

token = jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.MapClaims{
"iss": issuer,
"aud": clientId,
"iat": time.Now().Unix(),
"exp": time.Now().Add(-2 * time.Minute).Unix(),
"nonce": 456,
})
token.Header["kid"] = oktatest.KID
idToken, err = token.SignedString(pk)
require.NoError(t, err)

_, err = lv.VerifyIdToken(ctx, idToken)
require.ErrorContains(t, err, "verifying id token: parsing token: token has invalid claims: token is expired")
})
}

func TestVerifierVerifyAccessToken(t *testing.T) {
Expand Down Expand Up @@ -201,7 +266,23 @@ func TestVerifierVerifyAccessToken(t *testing.T) {
require.NoError(t, err)

_, err = v.VerifyAccessToken(ctx, idToken)
require.ErrorContains(t, err, "verifying token issuer: issuer '' in token does not match 'https://test.okta.com'")
require.ErrorContains(t, err, "verifying access token: parsing token: token has invalid claims: token is missing required claim: iss claim is required")
})

t.Run("verify access token wrong issuer", func(t *testing.T) {
token := jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.MapClaims{
"iss": "wrong",
"aud": clientId,
"iat": time.Now().Unix(),
"exp": time.Now().Add(24 * time.Hour).Unix(),
"nonce": 456,
})
token.Header["kid"] = oktatest.KID
idToken, err := token.SignedString(pk)
require.NoError(t, err)

_, err = v.VerifyAccessToken(ctx, idToken)
require.ErrorContains(t, err, "verifying access token: parsing token: token has invalid claims: token has invalid issuer")
})

t.Run("verify access token missing audience", func(t *testing.T) {
Expand All @@ -215,7 +296,23 @@ func TestVerifierVerifyAccessToken(t *testing.T) {
require.NoError(t, err)

_, err = v.VerifyAccessToken(ctx, idToken)
require.ErrorContains(t, err, "verifying token audience: audience '[]' in token does not match 'test'")
require.ErrorContains(t, err, "verifying access token: parsing token: token has invalid claims: token is missing required claim: aud claim is required")
})

t.Run("verify access token wrong audience", func(t *testing.T) {
token := jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.MapClaims{
"iss": issuer,
"aud": "wrong",
"iat": time.Now().Unix(),
"exp": time.Now().Add(24 * time.Hour).Unix(),
"nonce": 456,
})
token.Header["kid"] = oktatest.KID
idToken, err := token.SignedString(pk)
require.NoError(t, err)

_, err = v.VerifyAccessToken(ctx, idToken)
require.ErrorContains(t, err, "verifying access token: parsing token: token has invalid claims: token has invalid audience")
})

t.Run("verify access token missing issued time", func(t *testing.T) {
Expand Down Expand Up @@ -243,7 +340,7 @@ func TestVerifierVerifyAccessToken(t *testing.T) {
require.NoError(t, err)

_, err = v.VerifyAccessToken(ctx, idToken)
require.ErrorContains(t, err, "verifying token expiration time: no expiration time found")
require.ErrorContains(t, err, "verifying access token: parsing token: token has invalid claims: token is missing required claim: exp claim is required")
})

t.Run("verify access token expired", func(t *testing.T) {
Expand All @@ -260,4 +357,37 @@ func TestVerifierVerifyAccessToken(t *testing.T) {
_, err = v.VerifyAccessToken(ctx, idToken)
require.ErrorContains(t, err, "verifying access token: parsing token: token has invalid claims: token is expired")
})

t.Run("verify access token expiration with leeway", func(t *testing.T) {

lv := NewVerifier(issuer, clientId, WithKeyfuncProvider(kp), WithLeeway(time.Minute))

token := jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.MapClaims{
"iss": issuer,
"aud": clientId,
"iat": time.Now().Unix(),
"exp": time.Now().Add(-30 * time.Second).Unix(),
"nonce": 456,
})
token.Header["kid"] = oktatest.KID
idToken, err := token.SignedString(pk)
require.NoError(t, err)

_, err = lv.VerifyAccessToken(ctx, idToken)
require.NoError(t, err)

token = jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.MapClaims{
"iss": issuer,
"aud": clientId,
"iat": time.Now().Unix(),
"exp": time.Now().Add(-2 * time.Minute).Unix(),
"nonce": 456,
})
token.Header["kid"] = oktatest.KID
idToken, err = token.SignedString(pk)
require.NoError(t, err)

_, err = lv.VerifyAccessToken(ctx, idToken)
require.ErrorContains(t, err, "verifying access token: parsing token: token has invalid claims: token is expired")
})
}
Loading