From 52e22b32f79a35012b784fd49a3c8af36b1e4be4 Mon Sep 17 00:00:00 2001 From: evgenyk Date: Thu, 16 May 2024 12:10:06 +1000 Subject: [PATCH] support for providing RSA private key to parser --- jwt/jwt.go | 11 +++++--- .../authorization_code_test.go | 28 +++++++++++++++++++ 2 files changed, 35 insertions(+), 4 deletions(-) diff --git a/jwt/jwt.go b/jwt/jwt.go index f116db7..453c62e 100644 --- a/jwt/jwt.go +++ b/jwt/jwt.go @@ -1,6 +1,7 @@ package jwt import ( + "crypto/rsa" "fmt" "net/http" "slices" @@ -25,6 +26,7 @@ type Token struct { isValid bool } +// ParseFromAuthorizationHeader will parse the token from the Authorization header and validate it with the given options. func ParseFromAuthorizationHeader(r *http.Request, options ...func(*Token)) (*Token, error) { requestedToken := r.Header.Get("Authorization") splitToken := strings.Split(requestedToken, "Bearer") @@ -35,6 +37,7 @@ func ParseFromAuthorizationHeader(r *http.Request, options ...func(*Token)) (*To return ParseOAuth2Token(&oauth2.Token{AccessToken: requestedToken}, options...) } +// ParseFromString will parse the given token and validate it with the given options. func ParseFromString(rawToken string, options ...func(*Token)) (*Token, error) { return ParseOAuth2Token(&oauth2.Token{AccessToken: rawToken}, options...) } @@ -93,8 +96,8 @@ func (j *Token) IsValid() bool { return j.isValid } -// WillValidateKeys will validate the token with the given keyFunc. -func WillValidateKeys(keyFunc func(rawToken string) (interface{}, error)) func(*Token) { +// WillValidateKeys receives a token and needs to return a public rsa key to validate the token. +func WillValidateKeys(keyFunc func(rawToken string) (*rsa.PublicKey, error)) func(*Token) { return func(s *Token) { wrapped := func(token *golangjwt.Token) (interface{}, error) { return keyFunc(token.Raw) @@ -104,9 +107,9 @@ func WillValidateKeys(keyFunc func(rawToken string) (interface{}, error)) func(* } // WillValidateKeys will validate the token with the given keyFunc. -func WillValidateJWKSUrl(jwks string) func(*Token) { +func WillValidateJWKSUrl(url string) func(*Token) { return func(s *Token) { - jwks, err := keyfunc.NewDefault([]string{jwks}) + jwks, err := keyfunc.NewDefault([]string{url}) if err != nil { return } diff --git a/oauth2/authorization_code/authorization_code_test.go b/oauth2/authorization_code/authorization_code_test.go index 775f72a..4759583 100644 --- a/oauth2/authorization_code/authorization_code_test.go +++ b/oauth2/authorization_code/authorization_code_test.go @@ -2,6 +2,9 @@ package authorization_code import ( "context" + "crypto/rsa" + "crypto/x509" + "encoding/pem" "fmt" "io" "net/http" @@ -46,6 +49,17 @@ func TestAutorizationCodeFlowOnline(t *testing.T) { assert.Nil(t, err, "error parsing token") assert.True(t, parsedToken.IsValid(), "token is not valid") + parsedToken, err = jwt.ParseFromAuthorizationHeader(r, + //testing verification with provided private key instead of JWKS + jwt.WillValidateKeys(func(rawToken string) (*rsa.PublicKey, error) { + return testPublicPEM(), nil + }), + jwt.WillValidateAudience("http://my.api.com/api"), + jwt.WillValidateAlgorythm(), + ) + assert.Nil(t, err, "error parsing token") + assert.True(t, parsedToken.IsValid(), "token is not valid") + w.Write([]byte(`hello world`)) })) defer testApiServer.Close() @@ -133,7 +147,21 @@ func testJWKSPublicKeys() []byte { ]}` return []byte(key) +} +func testPublicPEM() *rsa.PublicKey { + publicKey := `-----BEGIN PUBLIC KEY----- +MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAuOaDKcdR8JR7PiVEHjRO +1dQVbLFoMRSiBio+rRlq+ljouBFJtehghnkIk0sSJlmoJY8329RdF9122IL0NYxO ++QTFJmAamSdUcmSgg4D3qI3Nc82H7L7ocad2OfhhXmBwz+O/8cxK+xYAnvKGmHf/ +tSmqVWJVbvBFG1r7sU3WBfLZPoivofFKjnhPG5jFbC2AziTFqKiQ7i2T2F0APIPT +J5Bf05zI2BpIYwyZyaP1F5EWmBEOvOP02Mr0L3Rj0lOJGQJ8gJh9uacGCt/RZAlx +0ZMiK93fk3vfszfKv0UhOpYKBcElR/5U1gJfXuDF6j10vG+8rwoorIPzCwu3wKZP +ewIDAQAB +-----END PUBLIC KEY-----` + block, _ := pem.Decode([]byte(publicKey)) + pemKey, _ := x509.ParsePKIXPublicKey(block.Bytes) + return pemKey.(*rsa.PublicKey) } func testJWKSPrivateKey() string {