diff --git a/handler/openid/strategy_jwt_test.go b/handler/openid/strategy_jwt_test.go index 889cb7d5d..af23b92d0 100644 --- a/handler/openid/strategy_jwt_test.go +++ b/handler/openid/strategy_jwt_test.go @@ -10,8 +10,10 @@ import ( "time" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/ory/fosite" + "github.com/ory/fosite/internal/gen" "github.com/ory/fosite/token/jwt" ) @@ -283,3 +285,86 @@ func TestJWTStrategy_GenerateIDToken(t *testing.T) { }) } } + +func TestJWTStrategy_DecodeIDToken(t *testing.T) { + var j = &DefaultStrategy{ + Signer: &jwt.DefaultSigner{ + GetPrivateKey: func(_ context.Context) (interface{}, error) { + return key, nil + }}, + Config: &fosite.Config{ + MinParameterEntropy: fosite.MinParameterEntropy, + }, + } + + var anotherKey = gen.MustRSAKey() + + var genIDToken = func(c jwt.IDTokenClaims) string { + s, _, err := j.Generate(context.TODO(), c.ToMapClaims(), jwt.NewHeaders()) + require.NoError(t, err) + return s + } + + var token string + var decoder *DefaultStrategy + for k, c := range []struct { + description string + setup func() + expectErr bool + }{ + { + description: "should pass with valid token", + setup: func() { + token = genIDToken(jwt.IDTokenClaims{ + Subject: "peter", + RequestedAt: time.Now(), + ExpiresAt: time.Now().Add(time.Hour), + }) + decoder = j + }, + expectErr: false, + }, + { + description: "should pass even though token is expired", + setup: func() { + token = genIDToken(jwt.IDTokenClaims{ + Subject: "peter", + RequestedAt: time.Now(), + ExpiresAt: time.Now().Add(-time.Hour), + }) + decoder = j + }, + expectErr: false, + }, + { + description: "should fail because token is decoded with wrong key", + setup: func() { + token = genIDToken(jwt.IDTokenClaims{ + Subject: "peter", + RequestedAt: time.Now(), + ExpiresAt: time.Now().Add(time.Hour), + }) + decoder = &DefaultStrategy{ + Signer: &jwt.DefaultSigner{ + GetPrivateKey: func(_ context.Context) (interface{}, error) { + return anotherKey, nil + }}, + Config: &fosite.Config{ + MinParameterEntropy: fosite.MinParameterEntropy, + }, + } + }, + expectErr: true, + }, + } { + t.Run(fmt.Sprintf("case=%d/description=%s", k, c.description), func(t *testing.T) { + c.setup() + req := fosite.NewAccessRequest(&DefaultSession{}) + idtoken, err := decoder.DecodeIDToken(context.Background(), req, token) + assert.Equal(t, c.expectErr, err != nil, "%d: %+v", k, err) + if !c.expectErr { + assert.NotNil(t, idtoken) + } + }) + } +}