From f5cc55a856c51a87e515c50088dc3f9c17ac0f97 Mon Sep 17 00:00:00 2001 From: Kraig Amador <508403+bigkraig@users.noreply.github.com> Date: Tue, 3 May 2022 07:50:34 -0700 Subject: [PATCH] Pass context to client requests (#9) * pass context to client requests * fix up jose imports --- client.go | 31 ++++++++++++++++--------------- client_mock.go | 9 +++++---- client_test.go | 9 ++++++--- go.mod | 3 +-- go.sum | 6 ++---- source.go | 18 +++++++++++++----- source_dummy.go | 5 +++-- 7 files changed, 46 insertions(+), 35 deletions(-) diff --git a/client.go b/client.go index 7c6ac60..278d5fe 100644 --- a/client.go +++ b/client.go @@ -1,17 +1,18 @@ package jwks import ( + "context" "fmt" "time" - "github.com/square/go-jose" "golang.org/x/sync/semaphore" + "gopkg.in/square/go-jose.v2" ) type JWKSClient interface { - GetKey(keyId string, use string) (*jose.JSONWebKey, error) - GetEncryptionKey(keyId string) (*jose.JSONWebKey, error) - GetSignatureKey(keyId string) (*jose.JSONWebKey, error) + GetKey(ctx context.Context, keyId string, use string) (*jose.JSONWebKey, error) + GetEncryptionKey(ctx context.Context, keyId string) (*jose.JSONWebKey, error) + GetSignatureKey(ctx context.Context, keyId string) (*jose.JSONWebKey, error) } type jWKSClient struct { @@ -46,34 +47,34 @@ func NewClient(source JWKSSource, cache Cache, refresh time.Duration) JWKSClient } } -func (c *jWKSClient) GetSignatureKey(keyId string) (*jose.JSONWebKey, error) { - return c.GetKey(keyId, "sig") +func (c *jWKSClient) GetSignatureKey(ctx context.Context, keyId string) (*jose.JSONWebKey, error) { + return c.GetKey(ctx, keyId, "sig") } -func (c *jWKSClient) GetEncryptionKey(keyId string) (*jose.JSONWebKey, error) { - return c.GetKey(keyId, "enc") +func (c *jWKSClient) GetEncryptionKey(ctx context.Context, keyId string) (*jose.JSONWebKey, error) { + return c.GetKey(ctx, keyId, "enc") } -func (c *jWKSClient) GetKey(keyId string, use string) (jwk *jose.JSONWebKey, err error) { +func (c *jWKSClient) GetKey(ctx context.Context, keyId string, use string) (jwk *jose.JSONWebKey, err error) { val, found := c.cache.Get(keyId) if found { entry := val.(*cacheEntry) if time.Now().After(time.Unix(entry.refresh, 0)) && c.sem.TryAcquire(1) { go func() { defer c.sem.Release(1) - if _, err := c.refreshKey(keyId, use); err != nil { + if _, err := c.refreshKey(ctx, keyId, use); err != nil { logger.Printf("unable to refresh key: %v", err) } }() } return entry.jwk, nil } else { - return c.refreshKey(keyId, use) + return c.refreshKey(ctx, keyId, use) } } -func (c *jWKSClient) refreshKey(keyId string, use string) (*jose.JSONWebKey, error) { - jwk, err := c.fetchJSONWebKey(keyId, use) +func (c *jWKSClient) refreshKey(ctx context.Context, keyId string, use string) (*jose.JSONWebKey, error) { + jwk, err := c.fetchJSONWebKey(ctx, keyId, use) if err != nil { return nil, err } @@ -89,8 +90,8 @@ func (c *jWKSClient) save(keyId string, jwk *jose.JSONWebKey) { }) } -func (c *jWKSClient) fetchJSONWebKey(keyId string, use string) (*jose.JSONWebKey, error) { - jsonWebKeySet, err := c.source.JSONWebKeySet() +func (c *jWKSClient) fetchJSONWebKey(ctx context.Context, keyId string, use string) (*jose.JSONWebKey, error) { + jsonWebKeySet, err := c.source.JSONWebKeySet(ctx) if err != nil { return nil, err } diff --git a/client_mock.go b/client_mock.go index 1a0014f..c0c5ea3 100644 --- a/client_mock.go +++ b/client_mock.go @@ -1,7 +1,8 @@ package jwks import ( - "github.com/square/go-jose" + "context" + "gopkg.in/square/go-jose.v2" ) type jWKSClientMock struct { @@ -14,15 +15,15 @@ func NewMockClient(secret string) JWKSClient { } } -func (c *jWKSClientMock) GetSignatureKey(keyId string) (*jose.JSONWebKey, error) { +func (c *jWKSClientMock) GetSignatureKey(ctx context.Context, keyId string) (*jose.JSONWebKey, error) { return mockKey(c.secret), nil } -func (c *jWKSClientMock) GetEncryptionKey(keyId string) (*jose.JSONWebKey, error) { +func (c *jWKSClientMock) GetEncryptionKey(ctx context.Context, keyId string) (*jose.JSONWebKey, error) { return mockKey(c.secret), nil } -func (c *jWKSClientMock) GetKey(keyId string, use string) (*jose.JSONWebKey, error) { +func (c *jWKSClientMock) GetKey(ctx context.Context, keyId string, use string) (*jose.JSONWebKey, error) { return mockKey(c.secret), nil } diff --git a/client_test.go b/client_test.go index 22d2418..60c76ed 100644 --- a/client_test.go +++ b/client_test.go @@ -1,10 +1,11 @@ package jwks import ( + "context" "testing" "time" - "github.com/square/go-jose" + "gopkg.in/square/go-jose.v2" ) func TestJWKSClient_GetKey(t *testing.T) { @@ -13,10 +14,11 @@ func TestJWKSClient_GetKey(t *testing.T) { KeyID: keyId, }}}) cacheMock := NewMockCache() + ctx := context.TODO() client := NewClient(sourceMock, cacheMock, time.Minute) - jwk, err := client.GetKey(keyId, "sig") + jwk, err := client.GetKey(ctx, keyId, "sig") if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -45,10 +47,11 @@ func TestJWKSClient_GetKeyWithPrefetch(t *testing.T) { }, time.Unix(0, 0), ) + ctx := context.TODO() client := NewClient(sourceMock, cacheMock, time.Minute) - key1, err := client.GetKey(keyId, "sig") + key1, err := client.GetKey(ctx, keyId, "sig") time.Sleep(time.Millisecond * 5) if err != nil { t.Fatalf("unexpected error: %v", err) diff --git a/go.mod b/go.mod index e8c59d7..e7cff7a 100644 --- a/go.mod +++ b/go.mod @@ -5,10 +5,9 @@ go 1.12 require ( github.com/google/go-cmp v0.5.4 // indirect github.com/patrickmn/go-cache v2.1.0+incompatible - github.com/square/go-jose v2.6.0+incompatible github.com/stretchr/testify v1.6.1 // indirect golang.org/x/crypto v0.0.0-20180621125126-a49355c7e3f8 // indirect golang.org/x/net v0.0.0-20180729183719-c4299a1a0d85 // indirect golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f - gopkg.in/square/go-jose.v2 v2.3.1 // indirect + gopkg.in/square/go-jose.v2 v2.6.0 ) diff --git a/go.sum b/go.sum index b52c28f..d4cd164 100644 --- a/go.sum +++ b/go.sum @@ -6,8 +6,6 @@ github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaR github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/square/go-jose v2.6.0+incompatible h1:X2EdV1z4PViZnxy2B+nY+HizcJKWuHZ0vN8Ju3R2No8= -github.com/square/go-jose v2.6.0+incompatible/go.mod h1:7MxpAF/1WTVUu8Am+T5kNy+t0902CaLWM4Z745MkOa8= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= @@ -21,7 +19,7 @@ golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IV golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/square/go-jose.v2 v2.3.1 h1:SK5KegNXmKmqE342YYN2qPHEnUYeoMiXXl1poUlI+o4= -gopkg.in/square/go-jose.v2 v2.3.1/go.mod h1:M9dMgbHiYLoDGQrXy7OpJDJWiKiU//h+vD76mk0e1AI= +gopkg.in/square/go-jose.v2 v2.6.0 h1:NGk74WTnPKBNUhNzQX7PYcTLUjoq7mzKk2OKbvwk2iI= +gopkg.in/square/go-jose.v2 v2.6.0/go.mod h1:M9dMgbHiYLoDGQrXy7OpJDJWiKiU//h+vD76mk0e1AI= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/source.go b/source.go index ec232af..50f4b34 100644 --- a/source.go +++ b/source.go @@ -1,15 +1,15 @@ package jwks import ( + "context" "encoding/json" "fmt" + "gopkg.in/square/go-jose.v2" "net/http" - - "github.com/square/go-jose" ) type JWKSSource interface { - JSONWebKeySet() (*jose.JSONWebKeySet, error) + JSONWebKeySet(ctx context.Context) (*jose.JSONWebKeySet, error) } type WebSource struct { @@ -28,9 +28,17 @@ func NewWebSource(jwksUri string, client *http.Client) *WebSource { } } -func (s *WebSource) JSONWebKeySet() (*jose.JSONWebKeySet, error) { +func (s *WebSource) JSONWebKeySet(ctx context.Context) (*jose.JSONWebKeySet, error) { logger.Printf("Fetching JWKS from %s", s.jwksUri) - resp, err := s.client.Get(s.jwksUri) + + req, err := http.NewRequest("GET", s.jwksUri, nil) + if err != nil { + return nil, err + } + + req = req.WithContext(ctx) + + resp, err := s.client.Do(req) if err != nil { return nil, err } diff --git a/source_dummy.go b/source_dummy.go index 4cf01e2..4ed9ce4 100644 --- a/source_dummy.go +++ b/source_dummy.go @@ -1,7 +1,8 @@ package jwks import ( - "github.com/square/go-jose" + "context" + "gopkg.in/square/go-jose.v2" ) type DummySource struct { @@ -12,6 +13,6 @@ func NewDummySource(jwks *jose.JSONWebKeySet) *DummySource { return &DummySource{Jwks: jwks} } -func (s *DummySource) JSONWebKeySet() (*jose.JSONWebKeySet, error) { +func (s *DummySource) JSONWebKeySet(ctx context.Context) (*jose.JSONWebKeySet, error) { return s.Jwks, nil }