diff --git a/key_cacher.go b/key_cacher.go index cc088f9..2a5e00e 100644 --- a/key_cacher.go +++ b/key_cacher.go @@ -120,6 +120,7 @@ type MemoryKeyCacher struct { maxKeyAge time.Duration maxCacheSize int keyIDGetter KeyIDGetter + mu *sync.RWMutex } type keyCacherEntry struct { @@ -160,24 +161,20 @@ func (gkc *GMemoryKeyCacher) Get(keyID string) (*jose.JSONWebKey, error) { // NewMemoryKeyCacher creates a new Keycacher interface with option // to set max age of cached keys and max size of the cache. -func NewMemoryKeyCacher(maxKeyAge time.Duration, maxCacheSize int, keyIdentifyStrategy string) KeyCacher { +func NewMemoryKeyCacher(maxKeyAge time.Duration, maxCacheSize int, keyIdentifyStrategy string) *MemoryKeyCacher { return &MemoryKeyCacher{ entries: map[string]keyCacherEntry{}, maxKeyAge: maxKeyAge, maxCacheSize: maxCacheSize, keyIDGetter: KeyIDGetterFactory(keyIdentifyStrategy), + mu: new(sync.RWMutex), } } -func NewGlobalMemoryKeyCacher(maxKeyAge time.Duration, maxCacheSize int, keyIdentifyStrategy string) KeyCacher { +func NewGlobalMemoryKeyCacher(maxKeyAge time.Duration, maxCacheSize int, keyIdentifyStrategy string) *GMemoryKeyCacher { kc := &GMemoryKeyCacher{ - MemoryKeyCacher: &MemoryKeyCacher{ - entries: map[string]keyCacherEntry{}, - maxKeyAge: maxKeyAge, - maxCacheSize: maxCacheSize, - keyIDGetter: KeyIDGetterFactory(keyIdentifyStrategy), - }, - Global: GlobalCacher{}, + MemoryKeyCacher: NewMemoryKeyCacher(maxKeyAge, maxCacheSize, keyIdentifyStrategy), + Global: GlobalCacher{}, } if keyIdentifyStrategy == "" { keyIdentifyStrategy = defaultStrategy @@ -191,7 +188,9 @@ func NewGlobalMemoryKeyCacher(maxKeyAge time.Duration, maxCacheSize int, keyIden // Get obtains a key from the cache, and checks if the key is expired func (mkc *MemoryKeyCacher) Get(keyID string) (*jose.JSONWebKey, error) { + mkc.mu.RLock() searchKey, ok := mkc.entries[keyID] + mkc.mu.RUnlock() if ok { if mkc.maxKeyAge == MaxKeyAgeNoCheck || !mkc.keyIsExpired(keyID) { return &searchKey.JSONWebKey, nil @@ -212,18 +211,22 @@ func (mkc *MemoryKeyCacher) Add(keyID string, downloadedKeys []jose.JSONWebKey) addingKeyID = cacheKey } if mkc.maxCacheSize == -1 { + mkc.mu.Lock() mkc.entries[cacheKey] = keyCacherEntry{ addedAt: time.Now(), JSONWebKey: downloadedKeys[i], } + mkc.mu.Unlock() } } if addingKey.Key != nil { if mkc.maxCacheSize != -1 { + mkc.mu.Lock() mkc.entries[addingKeyID] = keyCacherEntry{ addedAt: time.Now(), JSONWebKey: addingKey, } + mkc.mu.Unlock() mkc.handleOverflow() } return &addingKey, nil @@ -233,8 +236,14 @@ func (mkc *MemoryKeyCacher) Add(keyID string, downloadedKeys []jose.JSONWebKey) // keyIsExpired deletes the key from cache if it is expired func (mkc *MemoryKeyCacher) keyIsExpired(keyID string) bool { - if time.Now().After(mkc.entries[keyID].addedAt.Add(mkc.maxKeyAge)) { + mkc.mu.RLock() + entry := mkc.entries[keyID].addedAt.Add(mkc.maxKeyAge) + mkc.mu.RUnlock() + + if time.Now().After(entry) { + mkc.mu.Lock() delete(mkc.entries, keyID) + mkc.mu.Unlock() return true } return false @@ -243,6 +252,7 @@ func (mkc *MemoryKeyCacher) keyIsExpired(keyID string) bool { // handleOverflow deletes the oldest key from the cache if overflowed func (mkc *MemoryKeyCacher) handleOverflow() { if mkc.maxCacheSize < len(mkc.entries) { + mkc.mu.RLock() var oldestEntryKeyID string latestAddedTime := time.Now() for entryKeyID := range mkc.entries { @@ -251,6 +261,10 @@ func (mkc *MemoryKeyCacher) handleOverflow() { oldestEntryKeyID = entryKeyID } } + mkc.mu.RUnlock() + + mkc.mu.Lock() delete(mkc.entries, oldestEntryKeyID) + mkc.mu.Unlock() } }