diff --git a/verifier.go b/verifier.go index 1ba6447..8a502fc 100644 --- a/verifier.go +++ b/verifier.go @@ -28,9 +28,10 @@ func defaultOptions(issuer string) *Options { return opts } -// Option for the OktaMetadataProvider +// Option for the Verifier type Option func(*Options) +// Verifier is the implementation of the Okta JWT verification logic. type Verifier struct { keyfuncProvider keyfunc.Provider issuer string @@ -49,25 +50,68 @@ func NewVerifier(issuer string, clientId string, options ...Option) *Verifier { // VerifyIdToken verifies an Okta ID token. func (v *Verifier) VerifyIdToken(ctx context.Context, idToken string) (*jwt.Token, error) { - jwt, err := v.parseToken(ctx, idToken) + token, err := v.parseToken(ctx, idToken) if err != nil { return nil, fmt.Errorf("verifying id token: %w", err) } + if err = v.validateCommonClaims(ctx, token); err != nil { + return nil, fmt.Errorf("validating claims: %w", err) + } + + claims := token.Claims.(jwt.MapClaims) + + _, exists := claims["nonce"] + if !exists { + return nil, fmt.Errorf("verifying token nonce: no nonce found") + } + + return token, nil +} + +// VerifyAccessToken verifies an Okta access token. +func (v *Verifier) VerifyAccessToken(ctx context.Context, accessToken string) (*jwt.Token, error) { + jwt, err := v.parseToken(ctx, accessToken) + if err != nil { + return nil, fmt.Errorf("verifying access token: %w", err) + } + + if err = v.validateCommonClaims(ctx, jwt); err != nil { + return nil, fmt.Errorf("validating claims: %w", err) + } + + return jwt, nil +} + +func (v *Verifier) parseToken(ctx context.Context, tokenString string) (*jwt.Token, error) { + keyfunc, err := v.keyfuncProvider.GetKeyfunc(ctx) + if err != nil { + return nil, fmt.Errorf("getting key function: %w", err) + } + + token, err := jwt.Parse(tokenString, keyfunc) + if err != nil { + return nil, fmt.Errorf("parsing token: %w", err) + } + + return token, err +} + +func (v *Verifier) validateCommonClaims(ctx context.Context, jwt *jwt.Token) error { claims := jwt.Claims jwtIssuer, err := claims.GetIssuer() if err != nil { - return nil, fmt.Errorf("verifying id token issuer: %w", err) + return fmt.Errorf("verifying token issuer: %w", err) } if jwtIssuer != v.issuer { - return nil, fmt.Errorf("verifying id token issuer: issuer '%s' in token does not match '%s'", jwtIssuer, v.issuer) + return fmt.Errorf("verifying token issuer: issuer '%s' in token does not match '%s'", jwtIssuer, v.issuer) } jwtAuds, err := claims.GetAudience() if err != nil { - return nil, fmt.Errorf("veriying id token audience: %w", err) + return fmt.Errorf("veriying token audience: %w", err) } matchFound := false @@ -79,43 +123,26 @@ func (v *Verifier) VerifyIdToken(ctx context.Context, idToken string) (*jwt.Toke } if !matchFound { - return nil, fmt.Errorf("verifying id token audience: audience '%s' in token does not match '%s'", jwtAuds, v.clientId) + return fmt.Errorf("verifying token audience: audience '%s' in token does not match '%s'", jwtAuds, v.clientId) } jwtIat, err := claims.GetIssuedAt() if err != nil { - return nil, fmt.Errorf("verifying id token issued time: %w", err) + return fmt.Errorf("verifying id token issued time: %w", err) } if jwtIat == nil { - return nil, fmt.Errorf("verifying id token issued time: no issued time found") + return fmt.Errorf("verifying token issued time: no issued time found") } jwtExp, err := claims.GetExpirationTime() if err != nil { - return nil, fmt.Errorf("verifying id token expriation time: %w", err) + return fmt.Errorf("verifying token expriation time: %w", err) } if jwtExp == nil { - return nil, fmt.Errorf("verifying id token expiration time: no expiration time found") + return fmt.Errorf("verifying token expiration time: no expiration time found") } - // FIXME: add support for nonce - - return jwt, nil -} - -func (v *Verifier) parseToken(ctx context.Context, tokenString string) (*jwt.Token, error) { - - keyfunc, err := v.keyfuncProvider.GetKeyfunc(ctx) - if err != nil { - return nil, fmt.Errorf("getting key function: %w", err) - } - - token, err := jwt.Parse(tokenString, keyfunc) - if err != nil { - return nil, fmt.Errorf("parsing token: %w", err) - } - - return token, err + return nil } diff --git a/verifier_test.go b/verifier_test.go index 318bf64..bbbdd10 100644 --- a/verifier_test.go +++ b/verifier_test.go @@ -13,7 +13,7 @@ import ( "time" ) -func TestVerifier(t *testing.T) { +func TestVerifierVerifyIdToken(t *testing.T) { issuer := "https://test.okta.com" clientId := "test" @@ -35,6 +35,120 @@ func TestVerifier(t *testing.T) { v := NewVerifier(issuer, clientId, WithKeyfuncProvider(kp)) t.Run("verify valid id token", func(t *testing.T) { + token := jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.MapClaims{ + "iss": issuer, + "aud": clientId, + "iat": time.Now().Unix(), + "exp": time.Now().Add(24 * time.Hour).Unix(), + "nonce": 456, + }) + token.Header["kid"] = oktatest.KID + idToken, err := token.SignedString(pk) + require.NoError(t, err) + + _, err = v.VerifyIdToken(ctx, idToken) + require.NoError(t, err) + }) + + t.Run("verify id token missing issuer", func(t *testing.T) { + token := jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.MapClaims{ + "aud": clientId, + "iat": time.Now().Unix(), + "exp": time.Now().Add(24 * time.Hour).Unix(), + "nonce": 456, + }) + token.Header["kid"] = oktatest.KID + idToken, err := token.SignedString(pk) + require.NoError(t, err) + + _, err = v.VerifyIdToken(ctx, idToken) + require.ErrorContains(t, err, "verifying token issuer: issuer '' in token does not match 'https://test.okta.com'") + }) + + t.Run("verify id token missing audience", func(t *testing.T) { + token := jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.MapClaims{ + "iss": issuer, + "iat": time.Now().Unix(), + "exp": time.Now().Add(24 * time.Hour).Unix(), + "nonce": 456, + }) + token.Header["kid"] = oktatest.KID + idToken, err := token.SignedString(pk) + require.NoError(t, err) + + _, err = v.VerifyIdToken(ctx, idToken) + require.ErrorContains(t, err, "verifying token audience: audience '[]' in token does not match 'test'") + }) + + t.Run("verify id token missing issued time", func(t *testing.T) { + token := jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.MapClaims{ + "iss": issuer, + "aud": clientId, + "exp": time.Now().Add(24 * time.Hour).Unix(), + "nonce": 456, + }) + token.Header["kid"] = oktatest.KID + idToken, err := token.SignedString(pk) + require.NoError(t, err) + + _, err = v.VerifyIdToken(ctx, idToken) + require.ErrorContains(t, err, "verifying token issued time: no issued time found") + }) + + t.Run("verify id token missing expiration", func(t *testing.T) { + token := jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.MapClaims{ + "iss": issuer, + "aud": clientId, + "iat": time.Now().Unix(), + "nonce": 456, + }) + token.Header["kid"] = oktatest.KID + idToken, err := token.SignedString(pk) + require.NoError(t, err) + + _, err = v.VerifyIdToken(ctx, idToken) + require.ErrorContains(t, err, "verifying token expiration time: no expiration time found") + }) + + t.Run("verify id token expired", func(t *testing.T) { + token := jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.MapClaims{ + "iss": issuer, + "aud": clientId, + "iat": time.Now().Unix(), + "exp": time.Now().Unix(), + "nonce": 456, + }) + token.Header["kid"] = oktatest.KID + idToken, err := token.SignedString(pk) + require.NoError(t, err) + + _, err = v.VerifyIdToken(ctx, idToken) + require.ErrorContains(t, err, "verifying id token: parsing token: token has invalid claims: token is expired") + }) +} + +func TestVerifierVerifyAccessToken(t *testing.T) { + issuer := "https://test.okta.com" + clientId := "test" + + // Generate RSA key. + pk, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) + + ctx := context.Background() + + uri, _ := oktatest.ServeJwks(t, ctx, pk) + + mp := &oktatest.StaticMetadataProvider{ + Md: metadata.Metadata{ + JwksUri: uri, + }, + } + + kp := okta.NewKeyfuncProvider(mp) + v := NewVerifier(issuer, clientId, WithKeyfuncProvider(kp)) + + t.Run("verify valid access token", func(t *testing.T) { token := jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.MapClaims{ "iss": issuer, "aud": clientId, @@ -45,11 +159,11 @@ func TestVerifier(t *testing.T) { idToken, err := token.SignedString(pk) require.NoError(t, err) - _, err = v.VerifyIdToken(ctx, idToken) + _, err = v.VerifyAccessToken(ctx, idToken) require.NoError(t, err) }) - t.Run("verify id token missing issuer", func(t *testing.T) { + t.Run("verify access token missing issuer", func(t *testing.T) { token := jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.MapClaims{ "aud": clientId, "iat": time.Now().Unix(), @@ -59,11 +173,11 @@ func TestVerifier(t *testing.T) { idToken, err := token.SignedString(pk) require.NoError(t, err) - _, err = v.VerifyIdToken(ctx, idToken) - require.ErrorContains(t, err, "verifying id token issuer: issuer '' in token does not match 'https://test.okta.com'") + _, err = v.VerifyAccessToken(ctx, idToken) + require.ErrorContains(t, err, "verifying token issuer: issuer '' in token does not match 'https://test.okta.com'") }) - t.Run("verify id token missing audience", func(t *testing.T) { + t.Run("verify access token missing audience", func(t *testing.T) { token := jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.MapClaims{ "iss": issuer, "iat": time.Now().Unix(), @@ -73,11 +187,11 @@ func TestVerifier(t *testing.T) { idToken, err := token.SignedString(pk) require.NoError(t, err) - _, err = v.VerifyIdToken(ctx, idToken) - require.ErrorContains(t, err, "verifying id token audience: audience '[]' in token does not match 'test'") + _, err = v.VerifyAccessToken(ctx, idToken) + require.ErrorContains(t, err, "verifying token audience: audience '[]' in token does not match 'test'") }) - t.Run("verify id token missing issued time", func(t *testing.T) { + t.Run("verify access token missing issued time", func(t *testing.T) { token := jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.MapClaims{ "iss": issuer, "aud": clientId, @@ -87,11 +201,11 @@ func TestVerifier(t *testing.T) { idToken, err := token.SignedString(pk) require.NoError(t, err) - _, err = v.VerifyIdToken(ctx, idToken) - require.ErrorContains(t, err, "verifying id token issued time: no issued time found") + _, err = v.VerifyAccessToken(ctx, idToken) + require.ErrorContains(t, err, "verifying token issued time: no issued time found") }) - t.Run("verify id token missing expiration", func(t *testing.T) { + t.Run("verify access token missing expiration", func(t *testing.T) { token := jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.MapClaims{ "iss": issuer, "aud": clientId, @@ -101,11 +215,11 @@ func TestVerifier(t *testing.T) { idToken, err := token.SignedString(pk) require.NoError(t, err) - _, err = v.VerifyIdToken(ctx, idToken) - require.ErrorContains(t, err, "verifying id token expiration time: no expiration time found") + _, err = v.VerifyAccessToken(ctx, idToken) + require.ErrorContains(t, err, "verifying token expiration time: no expiration time found") }) - t.Run("verify id token expired", func(t *testing.T) { + t.Run("verify access token expired", func(t *testing.T) { token := jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.MapClaims{ "iss": issuer, "aud": clientId, @@ -116,7 +230,7 @@ func TestVerifier(t *testing.T) { idToken, err := token.SignedString(pk) require.NoError(t, err) - _, err = v.VerifyIdToken(ctx, idToken) - require.ErrorContains(t, err, "verifying id token: parsing token: token has invalid claims: token is expired") + _, err = v.VerifyAccessToken(ctx, idToken) + require.ErrorContains(t, err, "verifying access token: parsing token: token has invalid claims: token is expired") }) }