From 81fe05ef506404bda8b02c54d2fd6cb5ffc26a7d Mon Sep 17 00:00:00 2001 From: johnabass Date: Mon, 12 Aug 2024 15:48:55 -0700 Subject: [PATCH] avoid embedding jwt.Token; unit tests --- basculejwt/token.go | 36 ++++++++- basculejwt/token_test.go | 153 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 185 insertions(+), 4 deletions(-) create mode 100644 basculejwt/token_test.go diff --git a/basculejwt/token.go b/basculejwt/token.go index 3920516..1cc33d6 100644 --- a/basculejwt/token.go +++ b/basculejwt/token.go @@ -40,11 +40,39 @@ type Claims interface { // token is the internal implementation of the JWT Token interface. It fronts // a lestrrat-go Token. type token struct { - jwt.Token + jwt jwt.Token } -func (t *token) Principal() string { - return t.Token.Subject() +func (t token) Audience() []string { + return t.jwt.Audience() +} + +func (t token) Expiration() time.Time { + return t.jwt.Expiration() +} + +func (t token) IssuedAt() time.Time { + return t.jwt.IssuedAt() +} + +func (t token) Issuer() string { + return t.jwt.Issuer() +} + +func (t token) JwtID() string { + return t.jwt.JwtID() +} + +func (t token) NotBefore() time.Time { + return t.jwt.NotBefore() +} + +func (t token) Subject() string { + return t.jwt.Subject() +} + +func (t token) Principal() string { + return t.jwt.Subject() } // tokenParser is the canonical parser for bascule that deals with JWTs. @@ -72,6 +100,6 @@ func (tp *tokenParser) Parse(ctx context.Context, value string) (bascule.Token, } return &token{ - Token: jwtToken, + jwt: jwtToken, }, nil } diff --git a/basculejwt/token_test.go b/basculejwt/token_test.go new file mode 100644 index 0000000..388a380 --- /dev/null +++ b/basculejwt/token_test.go @@ -0,0 +1,153 @@ +// SPDX-FileCopyrightText: 2024 Comcast Cable Communications Management, LLC +// SPDX-License-Identifier: Apache-2.0 + +package basculejwt + +import ( + "context" + "testing" + "time" + + "github.com/lestrrat-go/jwx/v2/jwa" + "github.com/lestrrat-go/jwx/v2/jwk" + "github.com/lestrrat-go/jwx/v2/jwt" + "github.com/stretchr/testify/suite" +) + +type TokenTestSuite struct { + suite.Suite + + audience []string + jwtID string + issuer string + + expiration time.Time + issuedAt time.Time + notBefore time.Time + subject string + + capabilities []string + allowedResources map[string]any + version string + + testKey jwk.Key + testKeySet jwk.Set + + testJWT jwt.Token + signedJWT []byte +} + +func (suite *TokenTestSuite) initializeKey() { + var err error + suite.testKey, err = jwk.ParseKey([]byte(`{ + "p": "7HMYtb-1dKyDp1OkdKc9WDdVMw3vtiiKDyuyRwnnwMOoYLPYxqE0CUMzw8_zXuzq7WJAmGiFd5q7oVzkbHzrtQ", + "kty": "RSA", + "q": "5253lCAgBLr8SR_VzzDtk_3XTHVmVIgniajMl7XM-ttrUONV86DoIm9VBx6ywEKpj5Xv3USBRNlpf8OXqWVhPw", + "d": "G7RLbBiCkiZuepbu46G0P8J7vn5l8G6U78gcMRdEhEsaXGZz_ZnbqjW6u8KI_3akrBT__GDPf8Hx8HBNKX5T9jNQW0WtJg1XnwHOK_OJefZl2fnx-85h3tfPD4zI3m54fydce_2kDVvqTOx_XXdNJD7v5TIAgvCymQv7qvzQ0VE", + "e": "AQAB", + "use": "sig", + "kid": "test", + "qi": "a_6YlMdA9b6piRodA0MR7DwjbALlMan19wj_VkgZ8Xoilq68sGaV2CQDoAdsTW9Mjt5PpCxvJawz0AMr6LIk9w", + "dp": "s55HgiGs_YHjzSOsBXXaEv6NuWf31l_7aMTf_DkZFYVMjpFwtotVFUg4taJuFYlSeZwux9h2s0IXEOCZIZTQFQ", + "alg": "RS256", + "dq": "M79xoX9laWleDAPATSnFlbfGsmP106T2IkPKK4oNIXJ6loWerHEoNrrqKkNk-LRvMZn3HmS4-uoaOuVDPi9bBQ", + "n": "1cHjMu7H10hKxnoq3-PJT9R25bkgVX1b39faqfecC82RMcD2DkgCiKGxkCmdUzuebpmXCZuxp-rVVbjrnrI5phAdjshZlkHwV0tyJOcerXsPgu4uk_VIJgtLdvgUAtVEd8-ZF4Y9YNOAKtf2AHAoRdP0ZVH7iVWbE6qU-IN2los" +}`)) + + suite.Require().NoError(err) + + suite.testKeySet = jwk.NewSet() + err = suite.testKeySet.AddKey(suite.testKey) + suite.Require().NoError(err) +} + +func (suite *TokenTestSuite) initializeClaims() { + suite.audience = []string{"test-audience"} + suite.jwtID = "test-jwt" + suite.issuer = "test-issuer" + + // time fields in the JOSE spec are in seconds + // generate an issuedAt in the recent past, so that validation can work + suite.issuedAt = time.Now().Add(-time.Second).Round(time.Second).UTC() + suite.expiration = suite.issuedAt.Add(time.Hour) + suite.notBefore = suite.issuedAt.Add(-time.Hour) + + suite.subject = "test-subject" + + suite.capabilities = []string{ + "x1:webpa:api:.*:all", + "x1:webpa:api:device/.*/config\\b:all", + } + + suite.allowedResources = make(map[string]any) + suite.allowedResources["allowedPartners"] = []string{"comcast"} + + suite.version = "2.0" +} + +func (suite *TokenTestSuite) createJWT() { + var err error + suite.testJWT, err = jwt.NewBuilder(). + Audience(suite.audience). + Subject(suite.subject). + IssuedAt(suite.issuedAt). + Expiration(suite.expiration). + NotBefore(suite.notBefore). + JwtID(suite.jwtID). + Issuer(suite.issuer). + Claim("capabilities", suite.capabilities). + Claim("allowedResources", suite.allowedResources). + Claim("version", suite.version). + Build() + + suite.Require().NoError(err) + + suite.signedJWT, err = jwt.Sign(suite.testJWT, jwt.WithKey(jwa.RS256, suite.testKey)) + suite.Require().NoError(err) +} + +func (suite *TokenTestSuite) SetupSuite() { + suite.initializeKey() + suite.initializeClaims() + suite.createJWT() + + suite.T().Log("using signed JWT", string(suite.signedJWT)) +} + +func (suite *TokenTestSuite) TestTokenParser() { + suite.Run("Success", func() { + tp, err := NewTokenParser(jwt.WithKeySet(suite.testKeySet)) + suite.Require().NoError(err) + suite.Require().NotNil(tp) + + token, err := tp.Parse(context.Background(), string(suite.signedJWT)) + suite.Require().NoError(err) + suite.Require().NotNil(token) + + suite.Equal(suite.subject, token.Principal()) + + suite.Require().Implements((*Claims)(nil), token) + claims := token.(Claims) + suite.Equal(suite.audience, claims.Audience()) + suite.Equal(suite.subject, claims.Subject()) + suite.Equal(suite.issuer, claims.Issuer()) + suite.Equal(suite.expiration, claims.Expiration()) + suite.Equal(suite.issuedAt, claims.IssuedAt()) + suite.Equal(suite.notBefore, claims.NotBefore()) + suite.Equal(suite.jwtID, claims.JwtID()) + }) + + suite.Run("NoOptions", func() { + tp, err := NewTokenParser() + suite.Require().NoError(err) + suite.Require().NotNil(tp) + + token, err := tp.Parse(context.Background(), string(suite.signedJWT)) + suite.Error(err) + suite.Nil(token) + }) +} + +func TestToken(t *testing.T) { + suite.Run(t, new(TokenTestSuite)) +}