Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for verifying access token #3

Merged
merged 2 commits into from
Jan 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ includes support for telemetry (ie. OpenTelemetry), minimizing operational laten

## Examples

### ID Token Validation
### Access Token Validation

```go
import (
Expand All @@ -24,7 +24,7 @@ func main() {
v := verifier.NewVerifier(issuer, clientId)

idToken := "..."
token, err := v.VerifyIdToken(ctx, idToken)
token, err := v.VerifyAccessToken(ctx, idToken)
}

```
Expand Down
83 changes: 55 additions & 28 deletions verifier.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,10 @@ func defaultOptions(issuer string) *Options {
return opts
}

// Option for the OktaMetadataProvider
// Option for the Verifier
type Option func(*Options)

// Verifier is the implementation of the Okta JWT verification logic.
type Verifier struct {
keyfuncProvider keyfunc.Provider
issuer string
Expand All @@ -49,25 +50,68 @@ func NewVerifier(issuer string, clientId string, options ...Option) *Verifier {

// VerifyIdToken verifies an Okta ID token.
func (v *Verifier) VerifyIdToken(ctx context.Context, idToken string) (*jwt.Token, error) {
jwt, err := v.parseToken(ctx, idToken)
token, err := v.parseToken(ctx, idToken)
if err != nil {
return nil, fmt.Errorf("verifying id token: %w", err)
}

if err = v.validateCommonClaims(ctx, token); err != nil {
return nil, fmt.Errorf("validating claims: %w", err)
}

claims := token.Claims.(jwt.MapClaims)

_, exists := claims["nonce"]
if !exists {
return nil, fmt.Errorf("verifying token nonce: no nonce found")
}

return token, nil
}

// VerifyAccessToken verifies an Okta access token.
func (v *Verifier) VerifyAccessToken(ctx context.Context, accessToken string) (*jwt.Token, error) {
jwt, err := v.parseToken(ctx, accessToken)
if err != nil {
return nil, fmt.Errorf("verifying access token: %w", err)
}

if err = v.validateCommonClaims(ctx, jwt); err != nil {
return nil, fmt.Errorf("validating claims: %w", err)
}

return jwt, nil
}

func (v *Verifier) parseToken(ctx context.Context, tokenString string) (*jwt.Token, error) {
keyfunc, err := v.keyfuncProvider.GetKeyfunc(ctx)
if err != nil {
return nil, fmt.Errorf("getting key function: %w", err)
}

token, err := jwt.Parse(tokenString, keyfunc)
if err != nil {
return nil, fmt.Errorf("parsing token: %w", err)
}

return token, err
}

func (v *Verifier) validateCommonClaims(ctx context.Context, jwt *jwt.Token) error {
claims := jwt.Claims

jwtIssuer, err := claims.GetIssuer()
if err != nil {
return nil, fmt.Errorf("verifying id token issuer: %w", err)
return fmt.Errorf("verifying token issuer: %w", err)
}

if jwtIssuer != v.issuer {
return nil, fmt.Errorf("verifying id token issuer: issuer '%s' in token does not match '%s'", jwtIssuer, v.issuer)
return fmt.Errorf("verifying token issuer: issuer '%s' in token does not match '%s'", jwtIssuer, v.issuer)
}

jwtAuds, err := claims.GetAudience()
if err != nil {
return nil, fmt.Errorf("veriying id token audience: %w", err)
return fmt.Errorf("veriying token audience: %w", err)
}

matchFound := false
Expand All @@ -79,43 +123,26 @@ func (v *Verifier) VerifyIdToken(ctx context.Context, idToken string) (*jwt.Toke
}

if !matchFound {
return nil, fmt.Errorf("verifying id token audience: audience '%s' in token does not match '%s'", jwtAuds, v.clientId)
return fmt.Errorf("verifying token audience: audience '%s' in token does not match '%s'", jwtAuds, v.clientId)
}

jwtIat, err := claims.GetIssuedAt()
if err != nil {
return nil, fmt.Errorf("verifying id token issued time: %w", err)
return fmt.Errorf("verifying id token issued time: %w", err)
}

if jwtIat == nil {
return nil, fmt.Errorf("verifying id token issued time: no issued time found")
return fmt.Errorf("verifying token issued time: no issued time found")
}

jwtExp, err := claims.GetExpirationTime()
if err != nil {
return nil, fmt.Errorf("verifying id token expriation time: %w", err)
return fmt.Errorf("verifying token expriation time: %w", err)
}

if jwtExp == nil {
return nil, fmt.Errorf("verifying id token expiration time: no expiration time found")
return fmt.Errorf("verifying token expiration time: no expiration time found")
}

// FIXME: add support for nonce

return jwt, nil
}

func (v *Verifier) parseToken(ctx context.Context, tokenString string) (*jwt.Token, error) {

keyfunc, err := v.keyfuncProvider.GetKeyfunc(ctx)
if err != nil {
return nil, fmt.Errorf("getting key function: %w", err)
}

token, err := jwt.Parse(tokenString, keyfunc)
if err != nil {
return nil, fmt.Errorf("parsing token: %w", err)
}

return token, err
return nil
}
148 changes: 131 additions & 17 deletions verifier_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import (
"time"
)

func TestVerifier(t *testing.T) {
func TestVerifierVerifyIdToken(t *testing.T) {
issuer := "https://test.okta.com"
clientId := "test"

Expand All @@ -35,6 +35,120 @@ func TestVerifier(t *testing.T) {
v := NewVerifier(issuer, clientId, WithKeyfuncProvider(kp))

t.Run("verify valid id token", func(t *testing.T) {
token := jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.MapClaims{
"iss": issuer,
"aud": clientId,
"iat": time.Now().Unix(),
"exp": time.Now().Add(24 * time.Hour).Unix(),
"nonce": 456,
})
token.Header["kid"] = oktatest.KID
idToken, err := token.SignedString(pk)
require.NoError(t, err)

_, err = v.VerifyIdToken(ctx, idToken)
require.NoError(t, err)
})

t.Run("verify id token missing issuer", func(t *testing.T) {
token := jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.MapClaims{
"aud": clientId,
"iat": time.Now().Unix(),
"exp": time.Now().Add(24 * time.Hour).Unix(),
"nonce": 456,
})
token.Header["kid"] = oktatest.KID
idToken, err := token.SignedString(pk)
require.NoError(t, err)

_, err = v.VerifyIdToken(ctx, idToken)
require.ErrorContains(t, err, "verifying token issuer: issuer '' in token does not match 'https://test.okta.com'")
})

t.Run("verify id token missing audience", func(t *testing.T) {
token := jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.MapClaims{
"iss": issuer,
"iat": time.Now().Unix(),
"exp": time.Now().Add(24 * time.Hour).Unix(),
"nonce": 456,
})
token.Header["kid"] = oktatest.KID
idToken, err := token.SignedString(pk)
require.NoError(t, err)

_, err = v.VerifyIdToken(ctx, idToken)
require.ErrorContains(t, err, "verifying token audience: audience '[]' in token does not match 'test'")
})

t.Run("verify id token missing issued time", func(t *testing.T) {
token := jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.MapClaims{
"iss": issuer,
"aud": clientId,
"exp": time.Now().Add(24 * time.Hour).Unix(),
"nonce": 456,
})
token.Header["kid"] = oktatest.KID
idToken, err := token.SignedString(pk)
require.NoError(t, err)

_, err = v.VerifyIdToken(ctx, idToken)
require.ErrorContains(t, err, "verifying token issued time: no issued time found")
})

t.Run("verify id token missing expiration", func(t *testing.T) {
token := jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.MapClaims{
"iss": issuer,
"aud": clientId,
"iat": time.Now().Unix(),
"nonce": 456,
})
token.Header["kid"] = oktatest.KID
idToken, err := token.SignedString(pk)
require.NoError(t, err)

_, err = v.VerifyIdToken(ctx, idToken)
require.ErrorContains(t, err, "verifying token expiration time: no expiration time found")
})

t.Run("verify id token expired", func(t *testing.T) {
token := jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.MapClaims{
"iss": issuer,
"aud": clientId,
"iat": time.Now().Unix(),
"exp": time.Now().Unix(),
"nonce": 456,
})
token.Header["kid"] = oktatest.KID
idToken, err := token.SignedString(pk)
require.NoError(t, err)

_, err = v.VerifyIdToken(ctx, idToken)
require.ErrorContains(t, err, "verifying id token: parsing token: token has invalid claims: token is expired")
})
}

func TestVerifierVerifyAccessToken(t *testing.T) {
issuer := "https://test.okta.com"
clientId := "test"

// Generate RSA key.
pk, err := rsa.GenerateKey(rand.Reader, 2048)
require.NoError(t, err)

ctx := context.Background()

uri, _ := oktatest.ServeJwks(t, ctx, pk)

mp := &oktatest.StaticMetadataProvider{
Md: metadata.Metadata{
JwksUri: uri,
},
}

kp := okta.NewKeyfuncProvider(mp)
v := NewVerifier(issuer, clientId, WithKeyfuncProvider(kp))

t.Run("verify valid access token", func(t *testing.T) {
token := jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.MapClaims{
"iss": issuer,
"aud": clientId,
Expand All @@ -45,11 +159,11 @@ func TestVerifier(t *testing.T) {
idToken, err := token.SignedString(pk)
require.NoError(t, err)

_, err = v.VerifyIdToken(ctx, idToken)
_, err = v.VerifyAccessToken(ctx, idToken)
require.NoError(t, err)
})

t.Run("verify id token missing issuer", func(t *testing.T) {
t.Run("verify access token missing issuer", func(t *testing.T) {
token := jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.MapClaims{
"aud": clientId,
"iat": time.Now().Unix(),
Expand All @@ -59,11 +173,11 @@ func TestVerifier(t *testing.T) {
idToken, err := token.SignedString(pk)
require.NoError(t, err)

_, err = v.VerifyIdToken(ctx, idToken)
require.ErrorContains(t, err, "verifying id token issuer: issuer '' in token does not match 'https://test.okta.com'")
_, err = v.VerifyAccessToken(ctx, idToken)
require.ErrorContains(t, err, "verifying token issuer: issuer '' in token does not match 'https://test.okta.com'")
})

t.Run("verify id token missing audience", func(t *testing.T) {
t.Run("verify access token missing audience", func(t *testing.T) {
token := jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.MapClaims{
"iss": issuer,
"iat": time.Now().Unix(),
Expand All @@ -73,11 +187,11 @@ func TestVerifier(t *testing.T) {
idToken, err := token.SignedString(pk)
require.NoError(t, err)

_, err = v.VerifyIdToken(ctx, idToken)
require.ErrorContains(t, err, "verifying id token audience: audience '[]' in token does not match 'test'")
_, err = v.VerifyAccessToken(ctx, idToken)
require.ErrorContains(t, err, "verifying token audience: audience '[]' in token does not match 'test'")
})

t.Run("verify id token missing issued time", func(t *testing.T) {
t.Run("verify access token missing issued time", func(t *testing.T) {
token := jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.MapClaims{
"iss": issuer,
"aud": clientId,
Expand All @@ -87,11 +201,11 @@ func TestVerifier(t *testing.T) {
idToken, err := token.SignedString(pk)
require.NoError(t, err)

_, err = v.VerifyIdToken(ctx, idToken)
require.ErrorContains(t, err, "verifying id token issued time: no issued time found")
_, err = v.VerifyAccessToken(ctx, idToken)
require.ErrorContains(t, err, "verifying token issued time: no issued time found")
})

t.Run("verify id token missing expiration", func(t *testing.T) {
t.Run("verify access token missing expiration", func(t *testing.T) {
token := jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.MapClaims{
"iss": issuer,
"aud": clientId,
Expand All @@ -101,11 +215,11 @@ func TestVerifier(t *testing.T) {
idToken, err := token.SignedString(pk)
require.NoError(t, err)

_, err = v.VerifyIdToken(ctx, idToken)
require.ErrorContains(t, err, "verifying id token expiration time: no expiration time found")
_, err = v.VerifyAccessToken(ctx, idToken)
require.ErrorContains(t, err, "verifying token expiration time: no expiration time found")
})

t.Run("verify id token expired", func(t *testing.T) {
t.Run("verify access token expired", func(t *testing.T) {
token := jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.MapClaims{
"iss": issuer,
"aud": clientId,
Expand All @@ -116,7 +230,7 @@ func TestVerifier(t *testing.T) {
idToken, err := token.SignedString(pk)
require.NoError(t, err)

_, err = v.VerifyIdToken(ctx, idToken)
require.ErrorContains(t, err, "verifying id token: parsing token: token has invalid claims: token is expired")
_, err = v.VerifyAccessToken(ctx, idToken)
require.ErrorContains(t, err, "verifying access token: parsing token: token has invalid claims: token is expired")
})
}
Loading