diff --git a/go.mod b/go.mod index 37ee9ea..895f237 100644 --- a/go.mod +++ b/go.mod @@ -7,12 +7,12 @@ toolchain go1.22.5 require ( github.com/alecthomas/kong v1.4.0 github.com/foxcpp/go-mockdns v1.1.0 - github.com/golang-jwt/jwt/v5 v5.2.1 github.com/google/uuid v1.6.0 github.com/goschtalt/goschtalt v0.25.0 github.com/goschtalt/properties-decoder v0.1.0 github.com/goschtalt/yaml-decoder v0.0.1 github.com/goschtalt/yaml-encoder v0.0.3 + github.com/lestrrat-go/jwx/v2 v2.1.2 github.com/stretchr/testify v1.9.0 github.com/ugorji/go/codec v1.2.12 github.com/xmidt-org/arrange v0.5.1 @@ -29,21 +29,30 @@ require ( require ( github.com/Microsoft/go-winio v0.6.2 // indirect github.com/davecgh/go-spew v1.1.1 // indirect + github.com/decred/dcrd/dcrec/secp256k1/v4 v4.3.0 // indirect + github.com/goccy/go-json v0.10.3 // indirect github.com/gorilla/websocket v1.5.1 // indirect github.com/goschtalt/approx v1.0.0 // indirect github.com/leodido/go-urn v1.4.0 // indirect + github.com/lestrrat-go/blackmagic v1.0.2 // indirect + github.com/lestrrat-go/httpcc v1.0.1 // indirect + github.com/lestrrat-go/httprc v1.0.6 // indirect + github.com/lestrrat-go/iter v1.0.2 // indirect + github.com/lestrrat-go/option v1.0.1 // indirect github.com/magiconair/properties v1.8.7 // indirect github.com/miekg/dns v1.1.59 // indirect github.com/mitchellh/mapstructure v1.5.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/segmentio/asm v1.2.0 // indirect github.com/stretchr/objx v0.5.2 // indirect github.com/xmidt-org/httpaux v0.4.1 // indirect go.uber.org/dig v1.18.0 // indirect go.uber.org/multierr v1.11.0 // indirect + golang.org/x/crypto v0.28.0 // indirect golang.org/x/mod v0.17.0 // indirect golang.org/x/net v0.25.0 // indirect golang.org/x/sync v0.7.0 // indirect - golang.org/x/sys v0.25.0 // indirect + golang.org/x/sys v0.26.0 // indirect golang.org/x/tools v0.21.0 // indirect gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect diff --git a/go.sum b/go.sum index a71d312..9c45b02 100644 --- a/go.sum +++ b/go.sum @@ -10,11 +10,13 @@ github.com/alecthomas/repr v0.4.0/go.mod h1:Fr0507jx4eOXV7AlPV6AVZLYrLIuIeSOWtW5 github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/decred/dcrd/dcrec/secp256k1/v4 v4.3.0 h1:rpfIENRNNilwHwZeG5+P150SMrnNEcHYvcCuK6dPZSg= +github.com/decred/dcrd/dcrec/secp256k1/v4 v4.3.0/go.mod h1:v57UDF4pDQJcEfFUCRop3lJL149eHGSe9Jvczhzjo/0= github.com/foxcpp/go-mockdns v1.1.0 h1:jI0rD8M0wuYAxL7r/ynTrCQQq0BVqfB99Vgk7DlmewI= github.com/foxcpp/go-mockdns v1.1.0/go.mod h1:IhLeSFGed3mJIAXPH2aiRQB+kqz7oqu8ld2qVbOu7Wk= github.com/gdamore/optopia v0.2.0/go.mod h1:YKYEwo5C1Pa617H7NlPcmQXl+vG6YnSSNB44n8dNL0Q= -github.com/golang-jwt/jwt/v5 v5.2.1 h1:OuVbFODueb089Lh128TAcimifWaLhJwVflnrgM17wHk= -github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= +github.com/goccy/go-json v0.10.3 h1:KZ5WoDbxAIgm2HNbYckL0se1fHD6rz5j4ywS6ebzDqA= +github.com/goccy/go-json v0.10.3/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= @@ -41,6 +43,18 @@ github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/leodido/go-urn v1.1.0/go.mod h1:+cyI34gQWZcE1eQU7NVgKkkzdXDQHr1dBMtdAPozLkw= github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ= github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI= +github.com/lestrrat-go/blackmagic v1.0.2 h1:Cg2gVSc9h7sz9NOByczrbUvLopQmXrfFx//N+AkAr5k= +github.com/lestrrat-go/blackmagic v1.0.2/go.mod h1:UrEqBzIR2U6CnzVyUtfM6oZNMt/7O7Vohk2J0OGSAtU= +github.com/lestrrat-go/httpcc v1.0.1 h1:ydWCStUeJLkpYyjLDHihupbn2tYmZ7m22BGkcvZZrIE= +github.com/lestrrat-go/httpcc v1.0.1/go.mod h1:qiltp3Mt56+55GPVCbTdM9MlqhvzyuL6W/NMDA8vA5E= +github.com/lestrrat-go/httprc v1.0.6 h1:qgmgIRhpvBqexMJjA/PmwSvhNk679oqD1RbovdCGW8k= +github.com/lestrrat-go/httprc v1.0.6/go.mod h1:mwwz3JMTPBjHUkkDv/IGJ39aALInZLrhBp0X7KGUZlo= +github.com/lestrrat-go/iter v1.0.2 h1:gMXo1q4c2pHmC3dn8LzRhJfP1ceCbgSiT9lUydIzltI= +github.com/lestrrat-go/iter v1.0.2/go.mod h1:Momfcq3AnRlRjI5b5O8/G5/BvpzrhoFTZcn06fEOPt4= +github.com/lestrrat-go/jwx/v2 v2.1.2 h1:6poete4MPsO8+LAEVhpdrNI4Xp2xdiafgl2RD89moBc= +github.com/lestrrat-go/jwx/v2 v2.1.2/go.mod h1:pO+Gz9whn7MPdbsqSJzG8TlEpMZCwQDXnFJ+zsUVh8Y= +github.com/lestrrat-go/option v1.0.1 h1:oAzP2fvZGQKWkvHa1/SAcFolBEca1oN+mQ7eooNBEYU= +github.com/lestrrat-go/option v1.0.1/go.mod h1:5ZHFbivi4xwXxhxY9XHDe2FHo6/Z7WWmtT7T5nBBp3I= github.com/magiconair/properties v1.8.7 h1:IeQXZAiQcpL9mgcAe1Nu6cX9LLw6ExEHKjN0VQdvPDY= github.com/magiconair/properties v1.8.7/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0= github.com/miekg/dns v1.1.57/go.mod h1:uqRjCRUuEAA6qsOiJvDd+CFo/vW+y5WR6SNmHE55hZk= @@ -54,12 +68,16 @@ github.com/psanford/memfs v0.0.0-20210214183328-a001468d78ef h1:NKxTG6GVGbfMXc2m github.com/psanford/memfs v0.0.0-20210214183328-a001468d78ef/go.mod h1:tcaRap0jS3eifrEEllL6ZMd9dg8IlDpi2S1oARrQ+NI= github.com/rogpeppe/go-internal v1.8.1 h1:geMPLpDpQOgVyCg5z5GoRwLHepNdb71NXb67XFkP+Eg= github.com/rogpeppe/go-internal v1.8.1/go.mod h1:JeRgkft04UBgHMgCIwADu4Pn6Mtm5d4nPKWu0nJ5d+o= +github.com/segmentio/asm v1.2.0 h1:9BQrFxC+YOHJlTlHGkTrFWf59nbL3XnCoFLTwDCI7ys= +github.com/segmentio/asm v1.2.0/go.mod h1:BqMnlJP91P8d+4ibuonYZw9mfnzI9HfxselHZr5aAcs= github.com/sirupsen/logrus v1.7.0/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE= @@ -94,6 +112,8 @@ golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5y golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliYc= golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4= golang.org/x/crypto v0.15.0/go.mod h1:4ChreQoLWfG3xLDer1WdlH5NdlQ3+mwnQq1YTKY+72g= +golang.org/x/crypto v0.28.0 h1:GBDwsMXVQi34v5CCYUm2jkJvu4cbtru2U4TN2PSyQnw= +golang.org/x/crypto v0.28.0/go.mod h1:rmgy+3RHxRZMyY0jjAJShp2zgEdOqj2AO7U0pYmeQ7U= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= @@ -130,8 +150,8 @@ golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.14.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/sys v0.25.0 h1:r+8e+loiHxRqhXVl6ML1nO3l1+oFoWbnlu2Ehimmi34= -golang.org/x/sys v0.25.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.26.0 h1:KHjCJyddX0LoSTb3J+vWpupP9p0oznkqVk/IfjymZbo= +golang.org/x/sys v0.26.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= @@ -164,5 +184,6 @@ gopkg.in/dealancer/validate.v2 v2.1.0 h1:XY95SZhVH1rBe8uwtnQEsOO79rv8GPwK+P3VWhQ gopkg.in/dealancer/validate.v2 v2.1.0/go.mod h1:EipWMj8hVO2/dPXVlYRe9yKcgVd5OttpQDiM1/wZ0DE= gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc= gopkg.in/natefinch/lumberjack.v2 v2.2.1/go.mod h1:YD8tP3GAjkrDg1eZH7EGmyESg/lsYskCTPBJVb9jqSc= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/internal/credentials/cmd/example/main.go b/internal/credentials/cmd/example/main.go index 45e931b..750a343 100644 --- a/internal/credentials/cmd/example/main.go +++ b/internal/credentials/cmd/example/main.go @@ -7,13 +7,15 @@ import ( "context" "crypto/tls" "crypto/x509" + "encoding/base64" "fmt" "net/http" "os" "time" "github.com/alecthomas/kong" - "github.com/golang-jwt/jwt/v5" + "github.com/lestrrat-go/jwx/v2/jws" + "github.com/lestrrat-go/jwx/v2/jwt" "github.com/xmidt-org/wrp-go/v3" cred "github.com/xmidt-org/xmidt-agent/internal/credentials" "github.com/xmidt-org/xmidt-agent/internal/credentials/event" @@ -120,41 +122,44 @@ func main() { defer cancel() credentials.WaitUntilFetched(ctx) - token, expires, err := credentials.Credentials() + tokenString, expires, err := credentials.Credentials() if err != nil { panic(err) } - fmt.Printf("JWT: %s\n", token) + fmt.Printf("JWT: %s\n", tokenString) fmt.Printf("Expires: %s\n", expires.Format(time.RFC3339)) - claims := jwt.RegisteredClaims{} - parser := jwt.NewParser() - _, parts, err := parser.ParseUnverified(token, &claims) + token, err := jwt.ParseString(tokenString) if err != nil { panic(err) } fmt.Println("Claims:") - fmt.Printf(" ID: %s\n", claims.ID) - fmt.Printf(" ExpirationTime: %s\n", claims.ExpiresAt) - fmt.Printf(" IssuedAt: %s\n", claims.IssuedAt) - fmt.Printf(" NotBefore: %s\n", claims.NotBefore) - fmt.Printf(" Issuer: %s\n", claims.Issuer) - fmt.Printf(" Subject: %s\n", claims.Subject) - fmt.Printf(" Audience: %s\n", claims.Audience) - - header, err := parser.DecodeSegment(parts[0]) + fmt.Printf(" ID: %s\n", token.JwtID()) + fmt.Printf(" Expiration: %s\n", token.Expiration()) + fmt.Printf(" IssuedAt: %s\n", token.IssuedAt()) + fmt.Printf(" NotBefore: %s\n", token.NotBefore()) + fmt.Printf(" Issuer: %s\n", token.Issuer()) + fmt.Printf(" Subject: %s\n", token.Subject()) + fmt.Printf(" Audience: %v\n", token.Audience()) + + header, body, _, err := jws.SplitCompactString(tokenString) if err != nil { panic(err) } - body, err := parser.DecodeSegment(parts[1]) + decodedHeader, err := base64.RawURLEncoding.DecodeString(string(header)) + if err != nil { + panic(err) + } + + decodedBody, err := base64.RawURLEncoding.DecodeString(string(body)) if err != nil { panic(err) } fmt.Println("Parts:") - fmt.Printf(" Header: %s\n", header) - fmt.Printf(" Body: %s\n", body) + fmt.Printf(" Header: %s\n", decodedHeader) + fmt.Printf(" Body: %s\n", decodedBody) } diff --git a/internal/jwtxt/options.go b/internal/jwtxt/options.go index adf4bbe..5fa8a3c 100644 --- a/internal/jwtxt/options.go +++ b/internal/jwtxt/options.go @@ -8,7 +8,9 @@ import ( "net/url" "time" - "github.com/golang-jwt/jwt/v5" + "github.com/lestrrat-go/jwx/v2/jwa" + "github.com/lestrrat-go/jwx/v2/jwk" + "github.com/lestrrat-go/jwx/v2/jws" "github.com/xmidt-org/wrp-go/v3" "github.com/xmidt-org/xmidt-agent/internal/jwtxt/event" ) @@ -18,6 +20,20 @@ type Option interface { apply(*Instructions) error } +func errorOptionFn(err error) Option { + return errorOption{ + err: err, + } +} + +type errorOption struct { + err error +} + +func (e errorOption) apply(*Instructions) error { + return e.err +} + // WithFetchListener adds a listener for fetch events. func WithFetchListener(listener event.FetchListener) Option { return &fetchListener{ @@ -66,31 +82,45 @@ func (u useNowFunc) apply(ins *Instructions) error { return nil } +var allowedSigningAlgorithms = map[string]jwa.SignatureAlgorithm{ + "EdDSA": jwa.EdDSA, + "ES256": jwa.ES256, + "ES384": jwa.ES384, + "ES512": jwa.ES512, + "PS256": jwa.PS256, + "PS384": jwa.PS384, + "PS512": jwa.PS512, + "RS256": jwa.RS256, + "RS384": jwa.RS384, + "RS512": jwa.RS512, +} + // Algorithms sets the algorithms to use for verification. Valid algorithms // are "EdDSA", "ES256", "ES384", "ES512", "PS256", "PS384", "PS512", // "RS256", "RS384", and "RS512". func Algorithms(algs ...string) Option { + allowed := make([]jwa.SignatureAlgorithm, 0, len(algs)) + + for _, alg := range algs { + got, found := allowedSigningAlgorithms[alg] + if !found { + return errorOptionFn(fmt.Errorf("%w '%s'", ErrUnspportedAlg, alg)) + } + allowed = append(allowed, got) + } + return &algorithms{ - algs: algs, + algs: allowed, } } type algorithms struct { - algs []string + algs []jwa.SignatureAlgorithm } func (a algorithms) apply(ins *Instructions) error { for _, alg := range a.algs { - switch alg { - case "EdDSA", - "ES256", "ES384", "ES512", - "PS256", "PS384", "PS512", - "RS256", "RS384", "RS512": - default: - return fmt.Errorf("%w '%s'", ErrUnspportedAlg, alg) - } - ins.jwtOptions = append(ins.jwtOptions, jwt.WithValidMethods([]string{alg})) - ins.algorithms = append(ins.algorithms, alg) + ins.algorithms[alg] = struct{}{} } return nil } @@ -130,25 +160,24 @@ type pemOption struct { } func (p pemOption) apply(ins *Instructions) error { - for _, pem := range p.pems { - var key jwt.VerificationKey - var err error - - key, err = jwt.ParseECPublicKeyFromPEM(pem) - + for _, single := range p.pems { + key, err := jwk.ParseKey(single, jwk.WithPEM(true)) if err != nil { - key, err = jwt.ParseRSAPublicKeyFromPEM(pem) + return fmt.Errorf("%w: invalid pem", ErrInvalidInput) } + algs, err := jws.AlgorithmsForKey(key) if err != nil { - key, err = jwt.ParseEdPublicKeyFromPEM(pem) + return fmt.Errorf("%w: invalid key", ErrInvalidInput) } - if err != nil { - return fmt.Errorf("%w: invalid pem", ErrInvalidInput) + for _, a := range algs { + if _, ok := allowedSigningAlgorithms[a.String()]; !ok { + return fmt.Errorf("%w: algorithm not allowed %s", ErrInvalidInput, a) + } } - ins.publicKeys.Keys = append(ins.publicKeys.Keys, key) + ins.publicKeys = append(ins.publicKeys, key) } return nil @@ -194,3 +223,85 @@ func (i idOption) apply(ins *Instructions) error { ins.id = id.ID() return nil } + +// -- validation options ------------------------------------------------------- + +func validateAlgs() Option { + return &validateAlgorithms{} +} + +type validateAlgorithms struct{} + +func (validateAlgorithms) apply(ins *Instructions) error { + if len(ins.algorithms) == 0 { + return fmt.Errorf("%w: zero provided algorithms", ErrInvalidInput) + } + + if len(ins.publicKeys) == 0 { + return fmt.Errorf("%w: zero provided public keys", ErrInvalidInput) + } + + // Ensure each key passed in provides an allowed algorithm. + for _, k := range ins.publicKeys { + var allowed bool + + algs, _ := jws.AlgorithmsForKey(k) + for _, alg := range algs { + if _, ok := ins.algorithms[alg]; ok { + allowed = true + break + } + } + + if !allowed { + return fmt.Errorf("%w: provided pem does not support allowed algorithm", ErrInvalidInput) + } + } + + return nil +} + +func validateBase() Option { + return &validateBaseURL{} +} + +type validateBaseURL struct{} + +func (validateBaseURL) apply(ins *Instructions) error { + if ins.baseURL == "" { + return fmt.Errorf("%w: baseURL must be set", ErrInvalidInput) + } + + return nil +} + +func validateTheID() Option { + return &validateID{} +} + +type validateID struct{} + +func (validateID) apply(ins *Instructions) error { + if ins.id == "" { + return fmt.Errorf("%w: id must be set", ErrInvalidInput) + } + + return nil +} + +// -- internal options --------------------------------------------------------- + +func makeSet() Option { + return &makeSetOption{} +} + +type makeSetOption struct{} + +func (makeSetOption) apply(ins *Instructions) error { + ins.set = jwk.NewSet() + + for _, k := range ins.publicKeys { + ins.set.AddKey(k) + } + return nil +} diff --git a/internal/jwtxt/token.go b/internal/jwtxt/token.go index 033f9e7..c02ed2f 100644 --- a/internal/jwtxt/token.go +++ b/internal/jwtxt/token.go @@ -12,7 +12,10 @@ import ( "sync" "time" - "github.com/golang-jwt/jwt/v5" + "github.com/lestrrat-go/jwx/v2/jwa" + "github.com/lestrrat-go/jwx/v2/jwk" + "github.com/lestrrat-go/jwx/v2/jws" + "github.com/lestrrat-go/jwx/v2/jwt" "github.com/xmidt-org/eventor" "github.com/xmidt-org/xmidt-agent/internal/jwtxt/event" ) @@ -50,16 +53,19 @@ type Instructions struct { fqdn string // jwtOptions allows for normal and test configurations. - jwtOptions []jwt.ParserOption + jwtOptions []jwt.ParseOption // timeout is the timeout for the DNS query. timeout time.Duration // algorithms is the list of algorithms allowed for JWT validation. - algorithms []string + algorithms map[jwa.SignatureAlgorithm]struct{} // publicKeys is the collection of keys split out by supported algorithm. - publicKeys jwt.VerificationKeySet + publicKeys []jwk.Key + + // The useable set of keys to use for validation. + set jwk.Set // now is used to supply the current time that is needed for expiration. // it's here just for testing support. @@ -93,10 +99,17 @@ func New(opts ...Option) (*Instructions, error) { now: time.Now, resolver: net.DefaultResolver, timeout: DefaultTimeout, - algorithms: []string{}, + algorithms: map[jwa.SignatureAlgorithm]struct{}{}, } - for _, opt := range opts { + full := append(opts, + validateAlgs(), + validateBase(), + validateTheID(), + makeSet(), + ) + + for _, opt := range full { if opt != nil { err := opt.apply(&ins) if err != nil { @@ -105,12 +118,7 @@ func New(opts ...Option) (*Instructions, error) { } } - if ins.baseURL == "" || ins.id == "" { - return nil, fmt.Errorf("%w: baseURL and id must be set", ErrInvalidInput) - } - ins.fqdn = ins.id + "." + ins.baseURL - ins.jwtOptions = []jwt.ParserOption{jwt.WithValidMethods(ins.algorithms)} return &ins, nil } @@ -240,39 +248,29 @@ func (ins *Instructions) reassemble(lines []string) string { // If it is valid, the information is saved in the Instruction object for // use along with when the information is no longer valid after. func (ins *Instructions) validate(input string) error { - parser := jwt.NewParser(ins.jwtOptions...) - - token, err := parser.ParseWithClaims(input, &customClaims{}, - func(t *jwt.Token) (any, error) { - return ins.publicKeys, nil - }) - if err != nil { - return err - } - - until, err := token.Claims.GetExpirationTime() + token, err := jwt.ParseString(input, + jwt.WithKeySet(ins.set, + jws.WithRequireKid(false), + jws.WithInferAlgorithmFromKey(true), + ), + jwt.WithClock(jwt.ClockFunc(ins.now)), + jwt.WithRequiredClaim("endpoint"), + jwt.WithValidate(true), + ) if err != nil { - return err + return errors.Join(err, ErrInvalidJWT) } - _, parts, err := parser.ParseUnverified(input, &customClaims{}) + msg, err := jws.ParseString(input) if err != nil { - return err + return errors.Join(err, ErrInvalidJWT) } - payload, err := parser.DecodeSegment(parts[1]) - if err != nil { - return err - } + ins.payload = msg.Payload() + ins.validUntil = token.Expiration().Local() - ins.endpoint = token.Claims.(*customClaims).Endpoint - ins.payload = payload - ins.validUntil = (*until).Time + ep, _ := token.Get("endpoint") + ins.endpoint = ep.(string) return nil } - -type customClaims struct { - Endpoint string `json:"endpoint"` - jwt.RegisteredClaims -} diff --git a/internal/jwtxt/token_test.go b/internal/jwtxt/token_test.go index 63eff3d..a530d80 100644 --- a/internal/jwtxt/token_test.go +++ b/internal/jwtxt/token_test.go @@ -11,12 +11,23 @@ import ( "time" "github.com/foxcpp/go-mockdns" - "github.com/golang-jwt/jwt/v5" + "github.com/lestrrat-go/jwx/v2/jwt" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/xmidt-org/xmidt-agent/internal/jwtxt/event" ) +// The orignal JWT: +// { +// "alg": "ES256", +// "typ": "JWT" +// } +// { +// "endpoint": "fabric.xmidt.example.org", +// "exp": 1690000000 +// } +// The JWT is signed with the private key and the public key is used to verify it. + func randomResolver() Option { return UseResolver(&mockdns.Resolver{ Zones: map[string]mockdns.Zone{ @@ -159,7 +170,7 @@ func TestInstructions_EndToEnd(t *testing.T) { }, { description: "missing segments", times: []int64{1680000000}, - expectedEndpointErr: jwt.ErrTokenMalformed, + expectedEndpointErr: ErrInvalidJWT, opts: []Option{ BaseURL("https://fabric.random.example.org"), DeviceID("mac:112233445566"), @@ -181,7 +192,7 @@ func TestInstructions_EndToEnd(t *testing.T) { }, { description: "expired token", times: []int64{1700000000}, - expectedEndpointErr: jwt.ErrTokenExpired, + expectedEndpointErr: jwt.ErrTokenExpired(), opts: []Option{ BaseURL("https://fabric.random.example.org"), DeviceID("mac:112233445566"), @@ -197,7 +208,7 @@ func TestInstructions_EndToEnd(t *testing.T) { assert.Equal(time.Time{}, fe.Expiration) assert.False(fe.TemporaryErr) assert.Equal("", fe.Endpoint) - assert.ErrorIs(fe.Err, jwt.ErrTokenExpired) + assert.ErrorIs(fe.Err, jwt.ErrTokenExpired()) }, }, { description: "times out with nice resolver", @@ -244,9 +255,9 @@ func TestInstructions_EndToEnd(t *testing.T) { assert.Error(fe.Err) }, }, { - description: "no algorithms", - times: []int64{1680000000}, - expectedEndpointErr: jwt.ErrTokenSignatureInvalid, + description: "no algorithms", + times: []int64{1680000000}, + expectedNewErr: ErrInvalidInput, opts: []Option{ BaseURL("https://fabric.random.example.org"), DeviceID("mac:112233445566"), @@ -254,9 +265,9 @@ func TestInstructions_EndToEnd(t *testing.T) { randomResolver(), }, }, { - description: "no keys", - times: []int64{1680000000}, - expectedEndpointErr: jwt.ErrTokenUnverifiable, + description: "no keys", + times: []int64{1680000000}, + expectedNewErr: ErrInvalidInput, opts: []Option{ BaseURL("https://fabric.random.example.org"), DeviceID("mac:112233445566"), @@ -340,7 +351,7 @@ func TestInstructions_EndToEnd(t *testing.T) { return } - when := jwt.WithTimeFunc(then) + when := jwt.WithClock(jwt.ClockFunc(then)) obj.jwtOptions = append(obj.jwtOptions, when) ctx, cancel := context.WithTimeout(context.Background(), time.Second*2)