Skip to content

Commit

Permalink
Pass context to client requests (#9)
Browse files Browse the repository at this point in the history
* pass context to client requests
* fix up jose imports
  • Loading branch information
bigkraig authored May 3, 2022
1 parent 47b1356 commit f5cc55a
Show file tree
Hide file tree
Showing 7 changed files with 46 additions and 35 deletions.
31 changes: 16 additions & 15 deletions client.go
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
package jwks

import (
"context"
"fmt"
"time"

"github.com/square/go-jose"
"golang.org/x/sync/semaphore"
"gopkg.in/square/go-jose.v2"
)

type JWKSClient interface {
GetKey(keyId string, use string) (*jose.JSONWebKey, error)
GetEncryptionKey(keyId string) (*jose.JSONWebKey, error)
GetSignatureKey(keyId string) (*jose.JSONWebKey, error)
GetKey(ctx context.Context, keyId string, use string) (*jose.JSONWebKey, error)
GetEncryptionKey(ctx context.Context, keyId string) (*jose.JSONWebKey, error)
GetSignatureKey(ctx context.Context, keyId string) (*jose.JSONWebKey, error)
}

type jWKSClient struct {
Expand Down Expand Up @@ -46,34 +47,34 @@ func NewClient(source JWKSSource, cache Cache, refresh time.Duration) JWKSClient
}
}

func (c *jWKSClient) GetSignatureKey(keyId string) (*jose.JSONWebKey, error) {
return c.GetKey(keyId, "sig")
func (c *jWKSClient) GetSignatureKey(ctx context.Context, keyId string) (*jose.JSONWebKey, error) {
return c.GetKey(ctx, keyId, "sig")
}

func (c *jWKSClient) GetEncryptionKey(keyId string) (*jose.JSONWebKey, error) {
return c.GetKey(keyId, "enc")
func (c *jWKSClient) GetEncryptionKey(ctx context.Context, keyId string) (*jose.JSONWebKey, error) {
return c.GetKey(ctx, keyId, "enc")
}

func (c *jWKSClient) GetKey(keyId string, use string) (jwk *jose.JSONWebKey, err error) {
func (c *jWKSClient) GetKey(ctx context.Context, keyId string, use string) (jwk *jose.JSONWebKey, err error) {
val, found := c.cache.Get(keyId)
if found {
entry := val.(*cacheEntry)
if time.Now().After(time.Unix(entry.refresh, 0)) && c.sem.TryAcquire(1) {
go func() {
defer c.sem.Release(1)
if _, err := c.refreshKey(keyId, use); err != nil {
if _, err := c.refreshKey(ctx, keyId, use); err != nil {
logger.Printf("unable to refresh key: %v", err)
}
}()
}
return entry.jwk, nil
} else {
return c.refreshKey(keyId, use)
return c.refreshKey(ctx, keyId, use)
}
}

func (c *jWKSClient) refreshKey(keyId string, use string) (*jose.JSONWebKey, error) {
jwk, err := c.fetchJSONWebKey(keyId, use)
func (c *jWKSClient) refreshKey(ctx context.Context, keyId string, use string) (*jose.JSONWebKey, error) {
jwk, err := c.fetchJSONWebKey(ctx, keyId, use)
if err != nil {
return nil, err
}
Expand All @@ -89,8 +90,8 @@ func (c *jWKSClient) save(keyId string, jwk *jose.JSONWebKey) {
})
}

func (c *jWKSClient) fetchJSONWebKey(keyId string, use string) (*jose.JSONWebKey, error) {
jsonWebKeySet, err := c.source.JSONWebKeySet()
func (c *jWKSClient) fetchJSONWebKey(ctx context.Context, keyId string, use string) (*jose.JSONWebKey, error) {
jsonWebKeySet, err := c.source.JSONWebKeySet(ctx)
if err != nil {
return nil, err
}
Expand Down
9 changes: 5 additions & 4 deletions client_mock.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
package jwks

import (
"github.com/square/go-jose"
"context"
"gopkg.in/square/go-jose.v2"
)

type jWKSClientMock struct {
Expand All @@ -14,15 +15,15 @@ func NewMockClient(secret string) JWKSClient {
}
}

func (c *jWKSClientMock) GetSignatureKey(keyId string) (*jose.JSONWebKey, error) {
func (c *jWKSClientMock) GetSignatureKey(ctx context.Context, keyId string) (*jose.JSONWebKey, error) {
return mockKey(c.secret), nil
}

func (c *jWKSClientMock) GetEncryptionKey(keyId string) (*jose.JSONWebKey, error) {
func (c *jWKSClientMock) GetEncryptionKey(ctx context.Context, keyId string) (*jose.JSONWebKey, error) {
return mockKey(c.secret), nil
}

func (c *jWKSClientMock) GetKey(keyId string, use string) (*jose.JSONWebKey, error) {
func (c *jWKSClientMock) GetKey(ctx context.Context, keyId string, use string) (*jose.JSONWebKey, error) {
return mockKey(c.secret), nil
}

Expand Down
9 changes: 6 additions & 3 deletions client_test.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
package jwks

import (
"context"
"testing"
"time"

"github.com/square/go-jose"
"gopkg.in/square/go-jose.v2"
)

func TestJWKSClient_GetKey(t *testing.T) {
Expand All @@ -13,10 +14,11 @@ func TestJWKSClient_GetKey(t *testing.T) {
KeyID: keyId,
}}})
cacheMock := NewMockCache()
ctx := context.TODO()

client := NewClient(sourceMock, cacheMock, time.Minute)

jwk, err := client.GetKey(keyId, "sig")
jwk, err := client.GetKey(ctx, keyId, "sig")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
Expand Down Expand Up @@ -45,10 +47,11 @@ func TestJWKSClient_GetKeyWithPrefetch(t *testing.T) {
},
time.Unix(0, 0),
)
ctx := context.TODO()

client := NewClient(sourceMock, cacheMock, time.Minute)

key1, err := client.GetKey(keyId, "sig")
key1, err := client.GetKey(ctx, keyId, "sig")
time.Sleep(time.Millisecond * 5)
if err != nil {
t.Fatalf("unexpected error: %v", err)
Expand Down
3 changes: 1 addition & 2 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,9 @@ go 1.12
require (
github.com/google/go-cmp v0.5.4 // indirect
github.com/patrickmn/go-cache v2.1.0+incompatible
github.com/square/go-jose v2.6.0+incompatible
github.com/stretchr/testify v1.6.1 // indirect
golang.org/x/crypto v0.0.0-20180621125126-a49355c7e3f8 // indirect
golang.org/x/net v0.0.0-20180729183719-c4299a1a0d85 // indirect
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f
gopkg.in/square/go-jose.v2 v2.3.1 // indirect
gopkg.in/square/go-jose.v2 v2.6.0
)
6 changes: 2 additions & 4 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@ github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaR
github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/square/go-jose v2.6.0+incompatible h1:X2EdV1z4PViZnxy2B+nY+HizcJKWuHZ0vN8Ju3R2No8=
github.com/square/go-jose v2.6.0+incompatible/go.mod h1:7MxpAF/1WTVUu8Am+T5kNy+t0902CaLWM4Z745MkOa8=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0=
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
Expand All @@ -21,7 +19,7 @@ golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IV
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/square/go-jose.v2 v2.3.1 h1:SK5KegNXmKmqE342YYN2qPHEnUYeoMiXXl1poUlI+o4=
gopkg.in/square/go-jose.v2 v2.3.1/go.mod h1:M9dMgbHiYLoDGQrXy7OpJDJWiKiU//h+vD76mk0e1AI=
gopkg.in/square/go-jose.v2 v2.6.0 h1:NGk74WTnPKBNUhNzQX7PYcTLUjoq7mzKk2OKbvwk2iI=
gopkg.in/square/go-jose.v2 v2.6.0/go.mod h1:M9dMgbHiYLoDGQrXy7OpJDJWiKiU//h+vD76mk0e1AI=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
18 changes: 13 additions & 5 deletions source.go
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
package jwks

import (
"context"
"encoding/json"
"fmt"
"gopkg.in/square/go-jose.v2"
"net/http"

"github.com/square/go-jose"
)

type JWKSSource interface {
JSONWebKeySet() (*jose.JSONWebKeySet, error)
JSONWebKeySet(ctx context.Context) (*jose.JSONWebKeySet, error)
}

type WebSource struct {
Expand All @@ -28,9 +28,17 @@ func NewWebSource(jwksUri string, client *http.Client) *WebSource {
}
}

func (s *WebSource) JSONWebKeySet() (*jose.JSONWebKeySet, error) {
func (s *WebSource) JSONWebKeySet(ctx context.Context) (*jose.JSONWebKeySet, error) {
logger.Printf("Fetching JWKS from %s", s.jwksUri)
resp, err := s.client.Get(s.jwksUri)

req, err := http.NewRequest("GET", s.jwksUri, nil)
if err != nil {
return nil, err
}

req = req.WithContext(ctx)

resp, err := s.client.Do(req)
if err != nil {
return nil, err
}
Expand Down
5 changes: 3 additions & 2 deletions source_dummy.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
package jwks

import (
"github.com/square/go-jose"
"context"
"gopkg.in/square/go-jose.v2"
)

type DummySource struct {
Expand All @@ -12,6 +13,6 @@ func NewDummySource(jwks *jose.JSONWebKeySet) *DummySource {
return &DummySource{Jwks: jwks}
}

func (s *DummySource) JSONWebKeySet() (*jose.JSONWebKeySet, error) {
func (s *DummySource) JSONWebKeySet(ctx context.Context) (*jose.JSONWebKeySet, error) {
return s.Jwks, nil
}

0 comments on commit f5cc55a

Please sign in to comment.