Skip to content

Commit

Permalink
Refresh keys
Browse files Browse the repository at this point in the history
  • Loading branch information
s12v committed Jul 30, 2018
1 parent aebe6a7 commit babccf4
Show file tree
Hide file tree
Showing 7 changed files with 90 additions and 86 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
*.iml
vendor/*
coverage.txt
example
8 changes: 4 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ A Go library to retrieve RSA public keys from a JWKS (JSON Web Key Set) endpoint
## Installation

```bash
dep ensure --add "github.com/s12v/go-jwks"
dep ensure --add "github.com/s12v/go-jwks@v0.0.1"
```

## Dependencies
Expand All @@ -30,11 +30,11 @@ import (
)

func main() {
jwksSource := jwks.NewWebSource("https://www.googleapis.com/oauth2/v3/certs")
jwksSource := jwks.NewWebSource("https://www.googleapis.com/oauth2/v3/certs")
jwksClient := jwks.NewDefaultClient(
jwksSource,
time.Hour, // Cache keys for 1 hour
10 * time.Minute, // Prefetch key 10 minutes before expiration
time.Hour, // Refresh keys every 1 hour
12*time.Hour, // Expire keys after 12 hours
)

var jwk *jose.JSONWebKey
Expand Down
8 changes: 4 additions & 4 deletions cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@ import (
)

type Cache interface {
// Get an item from the cache and itsexpiration time.
// Get an item from the cache
// Returns the item or nil, and a bool indicating whether the key was found
GetWithExpiration(k string) (interface{}, time.Time, bool)
Get(k string) (interface{}, bool)
// Add an item to the cache, replacing any existing item.
Set(k string, x interface{})
}
Expand All @@ -22,8 +22,8 @@ func (c *defaultCache) Set(k string, x interface{}) {
c.cache.Set(k, x, cache.DefaultExpiration)
}

func (c *defaultCache) GetWithExpiration(k string) (interface{}, time.Time, bool) {
return c.cache.GetWithExpiration(k)
func (c *defaultCache) Get(k string) (interface{}, bool) {
return c.cache.Get(k)
}

func DefaultCache(ttl time.Duration) Cache {
Expand Down
18 changes: 5 additions & 13 deletions cache_mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,30 +3,22 @@ package jwks
import "time"

type mockCache struct {
m map[string]pair
}

type pair struct {
val interface{}
exp time.Time
m map[string]interface{}
}

func (c *mockCache) Set(k string, x interface{}) {
c.SetWithExpiration(k, x, time.Now())
}

func (c *mockCache) SetWithExpiration(k string, x interface{}, exp time.Time) {
c.m[k] = pair {
val: x,
exp: exp,
}
c.m[k] = x
}

func (c *mockCache) GetWithExpiration(k string) (interface{}, time.Time, bool) {
func (c *mockCache) Get(k string) (interface{}, bool) {
v, exists := c.m[k]
return v.val, v.exp, exists
return v, exists
}

func NewMockCache() *mockCache {
return &mockCache{make(map[string]pair)}
return &mockCache{make(map[string]interface{})}
}
32 changes: 2 additions & 30 deletions cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,13 @@ import (
"time"
)

func TestDefaultCache_GetWithExpiration(t *testing.T) {
func TestDefaultCache_Get(t *testing.T) {
c := &defaultCache{
cache.New(time.Minute, 0),
}

c.Set("key", "val")
val, expTime, found := c.GetWithExpiration("key")

if expTime.Before(time.Now()) {
t.Fatalf("expTime should be after now: %v", expTime)
}
val, found := c.Get("key")

if !found {
t.Fatal("should be found")
Expand All @@ -27,28 +23,6 @@ func TestDefaultCache_GetWithExpiration(t *testing.T) {
}
}

func TestDefaultCache_GetWithExpiration_Expired(t *testing.T) {
c := &defaultCache{
cache.New(time.Nanosecond, 0),
}

c.Set("key", "val")
time.Sleep(10 * time.Millisecond)
val, expTime, found := c.GetWithExpiration("key")

if expTime.After(time.Now()) {
t.Fatalf("expTime should be before now: %v", expTime)
}

if found {
t.Fatal("should be not found")
}

if val != nil {
t.Fatalf("val should be nil, got %v instead", val)
}
}

func TestDefaultCache(t *testing.T) {
DefaultCache(time.Hour)
DefaultCache(0)
Expand All @@ -63,5 +37,3 @@ func TestDefaultCache_InvalidTtl(t *testing.T) {
}()
DefaultCache(-2)
}


74 changes: 43 additions & 31 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,26 +14,34 @@ type JWKSClient interface {
}

type jWKSClient struct {
source JWKSSource
cache Cache
prefetch time.Duration
sem *semaphore.Weighted
source JWKSSource
cache Cache
refresh time.Duration
sem *semaphore.Weighted
}

type cacheEntry struct {
jwk *jose.JSONWebKey
refresh int64
}

// Creates a new client with default cache implementation
func NewDefaultClient(source JWKSSource, ttl time.Duration, prefetch time.Duration) JWKSClient {
if prefetch >= ttl {
panic(fmt.Sprintf("invalid prefetch: %v greater or eaquals to ttl: %v", prefetch, ttl))
func NewDefaultClient(source JWKSSource, refresh time.Duration, ttl time.Duration) JWKSClient {
if refresh >= ttl {
panic(fmt.Sprintf("invalid refresh: %v greater or eaquals to ttl: %v", refresh, ttl))
}
if refresh < 0 {
panic(fmt.Sprintf("invalid refresh: %v", refresh))
}
return NewClient(source, DefaultCache(ttl), prefetch)
return NewClient(source, DefaultCache(ttl), refresh)
}

func NewClient(source JWKSSource, cache Cache, prefetch time.Duration) JWKSClient {
func NewClient(source JWKSSource, cache Cache, refresh time.Duration) JWKSClient {
return &jWKSClient{
source: source,
cache: cache,
prefetch: prefetch,
sem: semaphore.NewWeighted(1),
source: source,
cache: cache,
refresh: refresh,
sem: semaphore.NewWeighted(1),
}
}

Expand All @@ -45,25 +53,22 @@ func (c *jWKSClient) GetEncryptionKey(keyId string) (*jose.JSONWebKey, error) {
return c.GetKey(keyId, "enc")
}

func (c *jWKSClient) GetKey(keyId string, use string) (*jose.JSONWebKey, error) {
jwk, expiration, found := c.cache.GetWithExpiration(keyId)
if ! found {
var err error
if jwk, err = c.refreshKey(keyId, use); err != nil {
return nil, err
func (c *jWKSClient) GetKey(keyId string, use string) (jwk *jose.JSONWebKey, err error) {
val, found := c.cache.Get(keyId)
if found {
entry := val.(*cacheEntry)
if time.Now().After(time.Unix(entry.refresh, 0)) {
if c.sem.TryAcquire(1) {
go func() {
defer c.sem.Release(1)
c.refreshKey(keyId, use)
}()
}
}
return entry.jwk, nil
} else {
return c.refreshKey(keyId, use)
}

if time.Until(expiration) <= c.prefetch {
if c.sem.TryAcquire(1) {
go func () {
defer c.sem.Release(1)
c.refreshKey(keyId, use)
}()
}
}

return jwk.(*jose.JSONWebKey), nil
}

func (c *jWKSClient) refreshKey(keyId string, use string) (*jose.JSONWebKey, error) {
Expand All @@ -72,10 +77,17 @@ func (c *jWKSClient) refreshKey(keyId string, use string) (*jose.JSONWebKey, err
return nil, err
}

c.cache.Set(keyId, jwk)
c.save(keyId, jwk)
return jwk, nil
}

func (c *jWKSClient) save(keyId string, jwk *jose.JSONWebKey) {
c.cache.Set(keyId, &cacheEntry{
jwk: jwk,
refresh: time.Now().Add(c.refresh).Unix(),
})
}

func (c *jWKSClient) fetchJSONWebKey(keyId string, use string) (*jose.JSONWebKey, error) {
jsonWebKeySet, err := c.source.JSONWebKeySet()
if err != nil {
Expand Down
35 changes: 31 additions & 4 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,21 +36,48 @@ func TestJWKSClient_GetKeyWithPrefetch(t *testing.T) {
Use: "enc",
}}})
cacheMock := NewMockCache()
cacheMock.SetWithExpiration(keyId, &mockJwk, time.Unix(0, 0))
cacheMock.SetWithExpiration(
keyId,
&cacheEntry{
refresh: 0,
jwk: &mockJwk,
},
time.Unix(0, 0),
)

client := NewClient(sourceMock, cacheMock, time.Minute)

key1, err := client.GetKey(keyId, "sig")
time.Sleep(time.Millisecond * 20)
time.Sleep(time.Millisecond * 5)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if key1.Use != "sig" {
t.Fatalf("unexpected Use: %v", key1.Use)
}

key2, _, _ := cacheMock.GetWithExpiration(keyId)
if key2.(*jose.JSONWebKey).Use != "enc" {
key2, _ := cacheMock.Get(keyId)
if key2.(*cacheEntry).jwk.Use != "enc" {
t.Fatal("key should be updated in cache")
}
}

func TestNewDefaultClient_InvalidNegativeRefresh(t *testing.T) {
defer func() {
if r := recover(); r == nil {
t.Errorf("expected a panic")
}
}()
sourceMock := NewDummySource(&jose.JSONWebKeySet{})
NewDefaultClient(sourceMock, time.Second, -1)
}

func TestNewDefaultClient_InvalidRefreshBiggerThanTtl(t *testing.T) {
defer func() {
if r := recover(); r == nil {
t.Errorf("expected a panic")
}
}()
sourceMock := NewDummySource(&jose.JSONWebKeySet{})
NewDefaultClient(sourceMock, time.Minute, time.Second)
}

0 comments on commit babccf4

Please sign in to comment.