From 1f8f5e15d99259e3e6a514d5baebaa4fdea9f7d3 Mon Sep 17 00:00:00 2001 From: Jason Parraga Date: Tue, 2 Jan 2024 12:03:10 -0800 Subject: [PATCH] Return implementation agnostic Jwt struct when verifying --- verifier.go | 23 +++++++++++++++++------ verifier_test.go | 27 +++++++++++++++++++++++++++ 2 files changed, 44 insertions(+), 6 deletions(-) diff --git a/verifier.go b/verifier.go index 8a502fc..e428975 100644 --- a/verifier.go +++ b/verifier.go @@ -31,6 +31,17 @@ func defaultOptions(issuer string) *Options { // Option for the Verifier type Option func(*Options) +// Jwt is an implementation independent representation of a JWT that is returned to consumers of our APIs. +type Jwt struct { + Claims map[string]any +} + +// newJwtFromToken creates our Jwt struct from a jwt.Token. +func newJwtFromToken(token *jwt.Token) *Jwt { + claims := token.Claims.(jwt.MapClaims) + return &Jwt{Claims: claims} +} + // Verifier is the implementation of the Okta JWT verification logic. type Verifier struct { keyfuncProvider keyfunc.Provider @@ -49,7 +60,7 @@ func NewVerifier(issuer string, clientId string, options ...Option) *Verifier { } // VerifyIdToken verifies an Okta ID token. -func (v *Verifier) VerifyIdToken(ctx context.Context, idToken string) (*jwt.Token, error) { +func (v *Verifier) VerifyIdToken(ctx context.Context, idToken string) (*Jwt, error) { token, err := v.parseToken(ctx, idToken) if err != nil { return nil, fmt.Errorf("verifying id token: %w", err) @@ -66,21 +77,21 @@ func (v *Verifier) VerifyIdToken(ctx context.Context, idToken string) (*jwt.Toke return nil, fmt.Errorf("verifying token nonce: no nonce found") } - return token, nil + return newJwtFromToken(token), nil } // VerifyAccessToken verifies an Okta access token. -func (v *Verifier) VerifyAccessToken(ctx context.Context, accessToken string) (*jwt.Token, error) { - jwt, err := v.parseToken(ctx, accessToken) +func (v *Verifier) VerifyAccessToken(ctx context.Context, accessToken string) (*Jwt, error) { + token, err := v.parseToken(ctx, accessToken) if err != nil { return nil, fmt.Errorf("verifying access token: %w", err) } - if err = v.validateCommonClaims(ctx, jwt); err != nil { + if err = v.validateCommonClaims(ctx, token); err != nil { return nil, fmt.Errorf("validating claims: %w", err) } - return jwt, nil + return newJwtFromToken(token), nil } func (v *Verifier) parseToken(ctx context.Context, tokenString string) (*jwt.Token, error) { diff --git a/verifier_test.go b/verifier_test.go index bbbdd10..c57be79 100644 --- a/verifier_test.go +++ b/verifier_test.go @@ -163,6 +163,33 @@ func TestVerifierVerifyAccessToken(t *testing.T) { require.NoError(t, err) }) + t.Run("verify valid access with groups claims", func(t *testing.T) { + token := jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.MapClaims{ + "iss": issuer, + "aud": clientId, + "iat": time.Now().Unix(), + "exp": time.Now().Add(24 * time.Hour).Unix(), + "groups": []string{"test1", "test2"}, + }) + token.Header["kid"] = oktatest.KID + idToken, err := token.SignedString(pk) + require.NoError(t, err) + + result, err := v.VerifyAccessToken(ctx, idToken) + require.NoError(t, err) + + var groups []string + groupIntfSlice, ok := result.Claims["groups"].([]interface{}) + require.True(t, ok) + for _, groupIntf := range groupIntfSlice { + group, ok := groupIntf.(string) + require.True(t, ok) + groups = append(groups, group) + } + + require.Equal(t, []string{"test1", "test2"}, groups) + }) + t.Run("verify access token missing issuer", func(t *testing.T) { token := jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.MapClaims{ "aud": clientId,