Skip to content

Commit

Permalink
Cache JWKS and keep them in an LRU
Browse files Browse the repository at this point in the history
  • Loading branch information
joecorall committed Feb 1, 2025
1 parent 5340c99 commit 00b515e
Show file tree
Hide file tree
Showing 5 changed files with 64 additions and 25 deletions.
12 changes: 8 additions & 4 deletions go.mod
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
module github.com/lehigh-university-libraries/scyllaridae

go 1.22.2
go 1.22.6

toolchain go1.23.4

require (
github.com/go-stomp/stomp/v3 v3.1.0
github.com/golang-jwt/jwt v3.2.2+incompatible
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510
github.com/gorilla/mux v1.8.1
github.com/lestrrat-go/jwx v1.2.30
github.com/hashicorp/golang-lru/v2 v2.0.7
github.com/lestrrat-go/jwx/v2 v2.1.3
github.com/stretchr/testify v1.9.0
golang.org/x/text v0.21.0
gopkg.in/yaml.v3 v3.0.1
Expand All @@ -18,13 +21,14 @@ require (
github.com/decred/dcrd/dcrec/secp256k1/v4 v4.3.0 // indirect
github.com/goccy/go-json v0.10.3 // indirect
github.com/kr/pretty v0.3.1 // indirect
github.com/lestrrat-go/backoff/v2 v2.0.8 // indirect
github.com/lestrrat-go/blackmagic v1.0.2 // indirect
github.com/lestrrat-go/httpcc v1.0.1 // indirect
github.com/lestrrat-go/httprc v1.0.6 // indirect
github.com/lestrrat-go/iter v1.0.2 // indirect
github.com/lestrrat-go/option v1.0.1 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/rogpeppe/go-internal v1.12.0 // indirect
github.com/segmentio/asm v1.2.0 // indirect
golang.org/x/crypto v0.31.0 // indirect
golang.org/x/sys v0.28.0 // indirect
)
17 changes: 10 additions & 7 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -16,34 +16,35 @@ github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 h1:El6M4kTTCOh6aBiKaU
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510/go.mod h1:pupxD2MaaD3pAXIBCelhxNneeOaAeabZDe5s4K6zSpQ=
github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY=
github.com/gorilla/mux v1.8.1/go.mod h1:AKf9I4AEqPTmMytcMc0KkNouC66V3BtZ4qD5fmWSiMQ=
github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k=
github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM=
github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/lestrrat-go/backoff/v2 v2.0.8 h1:oNb5E5isby2kiro9AgdHLv5N5tint1AnDVVf2E2un5A=
github.com/lestrrat-go/backoff/v2 v2.0.8/go.mod h1:rHP/q/r9aT27n24JQLa7JhSQZCKBBOiM/uP402WwN8Y=
github.com/lestrrat-go/blackmagic v1.0.2 h1:Cg2gVSc9h7sz9NOByczrbUvLopQmXrfFx//N+AkAr5k=
github.com/lestrrat-go/blackmagic v1.0.2/go.mod h1:UrEqBzIR2U6CnzVyUtfM6oZNMt/7O7Vohk2J0OGSAtU=
github.com/lestrrat-go/httpcc v1.0.1 h1:ydWCStUeJLkpYyjLDHihupbn2tYmZ7m22BGkcvZZrIE=
github.com/lestrrat-go/httpcc v1.0.1/go.mod h1:qiltp3Mt56+55GPVCbTdM9MlqhvzyuL6W/NMDA8vA5E=
github.com/lestrrat-go/httprc v1.0.6 h1:qgmgIRhpvBqexMJjA/PmwSvhNk679oqD1RbovdCGW8k=
github.com/lestrrat-go/httprc v1.0.6/go.mod h1:mwwz3JMTPBjHUkkDv/IGJ39aALInZLrhBp0X7KGUZlo=
github.com/lestrrat-go/iter v1.0.2 h1:gMXo1q4c2pHmC3dn8LzRhJfP1ceCbgSiT9lUydIzltI=
github.com/lestrrat-go/iter v1.0.2/go.mod h1:Momfcq3AnRlRjI5b5O8/G5/BvpzrhoFTZcn06fEOPt4=
github.com/lestrrat-go/jwx v1.2.30 h1:VKIFrmjYn0z2J51iLPadqoHIVLzvWNa1kCsTqNDHYPA=
github.com/lestrrat-go/jwx v1.2.30/go.mod h1:vMxrwFhunGZ3qddmfmEm2+uced8MSI6QFWGTKygjSzQ=
github.com/lestrrat-go/option v1.0.0/go.mod h1:5ZHFbivi4xwXxhxY9XHDe2FHo6/Z7WWmtT7T5nBBp3I=
github.com/lestrrat-go/jwx/v2 v2.1.3 h1:Ud4lb2QuxRClYAmRleF50KrbKIoM1TddXgBrneT5/Jo=
github.com/lestrrat-go/jwx/v2 v2.1.3/go.mod h1:q6uFgbgZfEmQrfJfrCo90QcQOcXFMfbI/fO0NqRtvZo=
github.com/lestrrat-go/option v1.0.1 h1:oAzP2fvZGQKWkvHa1/SAcFolBEca1oN+mQ7eooNBEYU=
github.com/lestrrat-go/option v1.0.1/go.mod h1:5ZHFbivi4xwXxhxY9XHDe2FHo6/Z7WWmtT7T5nBBp3I=
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs=
github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8=
github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4=
github.com/segmentio/asm v1.2.0 h1:9BQrFxC+YOHJlTlHGkTrFWf59nbL3XnCoFLTwDCI7ys=
github.com/segmentio/asm v1.2.0/go.mod h1:BqMnlJP91P8d+4ibuonYZw9mfnzI9HfxselHZr5aAcs=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
Expand All @@ -65,6 +66,8 @@ golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA=
golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
Expand Down
4 changes: 3 additions & 1 deletion main.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@ func main() {
if len(config.QueueMiddlewares) > 0 {
runStompSubscribers(config)
} else {
server := &Server{Config: config}
server := &Server{
Config: config,
}
runHTTPServer(server)
}
}
41 changes: 30 additions & 11 deletions middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@ import (
scyllaridae "github.com/lehigh-university-libraries/scyllaridae/internal/config"
"github.com/lehigh-university-libraries/scyllaridae/pkg/api"

"github.com/lestrrat-go/jwx/jwa"
"github.com/lestrrat-go/jwx/jwk"
"github.com/lestrrat-go/jwx/jwt"
"github.com/lestrrat-go/jwx/v2/jwa"
"github.com/lestrrat-go/jwx/v2/jwk"
"github.com/lestrrat-go/jwx/v2/jwt"
)

type contextKey string
Expand Down Expand Up @@ -77,7 +77,7 @@ func (s *Server) LoggingMiddleware(next http.Handler) http.Handler {
}

// JWTAuthMiddleware validates a JWT token and adds claims to the context
func JWTAuthMiddleware(next http.Handler) http.Handler {
func (s *Server) JWTAuthMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
a := r.Header.Get("Authorization")
if a == "" || len(a) <= 7 || !strings.EqualFold(a[:7], "bearer ") {
Expand All @@ -90,7 +90,7 @@ func JWTAuthMiddleware(next http.Handler) http.Handler {
if os.Getenv("SKIP_JWT_VERIFY") != "true" {
tokenString := a[7:]
message := r.Context().Value(msgKey).(api.Payload)
err := verifyJWT(tokenString, message)
err := s.verifyJWT(tokenString, message)
if err != nil {
slog.Error("JWT verification failed", "err", err)
http.Error(w, "Unauthorized", http.StatusUnauthorized)
Expand All @@ -102,19 +102,22 @@ func JWTAuthMiddleware(next http.Handler) http.Handler {
})
}

func verifyJWT(tokenString string, message api.Payload) error {
keySet, err := fetchJWKS(message)
func (s *Server) verifyJWT(tokenString string, message api.Payload) error {
keySet, err := s.fetchJWKS(message)
if err != nil {
return fmt.Errorf("unable to fetch JWKS: %v", err)
}
key, ok := keySet.Key(0)
if !ok {
return fmt.Errorf("no key in key set")
}

ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()

token, err := jwt.Parse([]byte(tokenString),
jwt.WithKeySet(keySet),
jwt.WithKey(jwa.RS256, key),
jwt.WithContext(ctx),
jwt.WithVerify(jwa.RS256, keySet),
)
if err != nil {
return fmt.Errorf("unable to parse token: %v", err)
Expand All @@ -129,7 +132,7 @@ func verifyJWT(tokenString string, message api.Payload) error {
}

// fetchJWKS fetches the JSON Web Key Set (JWKS) from the given URI
func fetchJWKS(message api.Payload) (jwk.Set, error) {
func (s *Server) fetchJWKS(message api.Payload) (jwk.Set, error) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()

Expand All @@ -144,6 +147,22 @@ func fetchJWKS(message api.Payload) (jwk.Set, error) {

jwksURI = fmt.Sprintf("%s://%s/oauth/discovery/keys", parsedURL.Scheme, parsedURL.Host)
}
ks, ok := s.KeySets.Get(jwksURI)
if ok {
return ks, nil
}
c := jwk.NewCache(ctx)
c.Register(jwksURI, jwk.WithMinRefreshInterval(15*time.Minute))

Check failure on line 155 in middleware.go

View workflow job for this annotation

GitHub Actions / lint-test

Error return value of `c.Register` is not checked (errcheck)
_, err := c.Refresh(ctx, jwksURI)
if err != nil {
return nil, err
}

cached := jwk.NewCachedSet(c, jwksURI)
evicted := s.KeySets.Add(jwksURI, cached)
if evicted {
slog.Warn("server jwks cache is too small")
}

return jwk.Fetch(ctx, jwksURI)
return cached, nil
}
15 changes: 13 additions & 2 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,28 @@ import (
"os/exec"

"github.com/gorilla/mux"
lru "github.com/hashicorp/golang-lru/v2"
scyllaridae "github.com/lehigh-university-libraries/scyllaridae/internal/config"
"github.com/lehigh-university-libraries/scyllaridae/pkg/api"
"github.com/lestrrat-go/jwx/v2/jwk"
"golang.org/x/text/cases"
"golang.org/x/text/language"
)

type Server struct {
Config *scyllaridae.ServerConfig
Config *scyllaridae.ServerConfig
KeySets *lru.Cache[string, jwk.Set]
}

func (server *Server) SetupRouter() *mux.Router {
var err error

server.KeySets, err = lru.New[string, jwk.Set](25)
if err != nil {
slog.Error("Unable to create LRU cache for JWKS sets", "err", err)
os.Exit(1)
}

r := mux.NewRouter()
r.HandleFunc("/healthcheck", func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
Expand All @@ -28,7 +39,7 @@ func (server *Server) SetupRouter() *mux.Router {

// create the main route with logging and JWT auth middleware
authRouter := r.PathPrefix("/").Subrouter()
authRouter.Use(server.LoggingMiddleware, JWTAuthMiddleware)
authRouter.Use(server.LoggingMiddleware, server.JWTAuthMiddleware)
authRouter.HandleFunc("/", server.MessageHandler).Methods("GET", "POST")

// make sure 404s get logged
Expand Down

0 comments on commit 00b515e

Please sign in to comment.