Skip to content

Commit

Permalink
Add a mutex to read and delete keys to avoid concurrent map writes.
Browse files Browse the repository at this point in the history
Signed-off-by: Daniel Ortiz <[email protected]>
  • Loading branch information
taik0 committed Oct 23, 2023
1 parent f0bba6f commit bfdec20
Showing 1 changed file with 24 additions and 10 deletions.
34 changes: 24 additions & 10 deletions key_cacher.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ type MemoryKeyCacher struct {
maxKeyAge time.Duration
maxCacheSize int
keyIDGetter KeyIDGetter
mu *sync.RWMutex
}

type keyCacherEntry struct {
Expand Down Expand Up @@ -160,24 +161,20 @@ func (gkc *GMemoryKeyCacher) Get(keyID string) (*jose.JSONWebKey, error) {

// NewMemoryKeyCacher creates a new Keycacher interface with option
// to set max age of cached keys and max size of the cache.
func NewMemoryKeyCacher(maxKeyAge time.Duration, maxCacheSize int, keyIdentifyStrategy string) KeyCacher {
func NewMemoryKeyCacher(maxKeyAge time.Duration, maxCacheSize int, keyIdentifyStrategy string) *MemoryKeyCacher {
return &MemoryKeyCacher{
entries: map[string]keyCacherEntry{},
maxKeyAge: maxKeyAge,
maxCacheSize: maxCacheSize,
keyIDGetter: KeyIDGetterFactory(keyIdentifyStrategy),
mu: new(sync.RWMutex),
}
}

func NewGlobalMemoryKeyCacher(maxKeyAge time.Duration, maxCacheSize int, keyIdentifyStrategy string) KeyCacher {
func NewGlobalMemoryKeyCacher(maxKeyAge time.Duration, maxCacheSize int, keyIdentifyStrategy string) *GMemoryKeyCacher {
kc := &GMemoryKeyCacher{
MemoryKeyCacher: &MemoryKeyCacher{
entries: map[string]keyCacherEntry{},
maxKeyAge: maxKeyAge,
maxCacheSize: maxCacheSize,
keyIDGetter: KeyIDGetterFactory(keyIdentifyStrategy),
},
Global: GlobalCacher{},
MemoryKeyCacher: NewMemoryKeyCacher(maxKeyAge, maxCacheSize, keyIdentifyStrategy),
Global: GlobalCacher{},
}
if keyIdentifyStrategy == "" {
keyIdentifyStrategy = defaultStrategy
Expand All @@ -191,7 +188,9 @@ func NewGlobalMemoryKeyCacher(maxKeyAge time.Duration, maxCacheSize int, keyIden

// Get obtains a key from the cache, and checks if the key is expired
func (mkc *MemoryKeyCacher) Get(keyID string) (*jose.JSONWebKey, error) {
mkc.mu.RLock()
searchKey, ok := mkc.entries[keyID]
mkc.mu.RUnlock()
if ok {
if mkc.maxKeyAge == MaxKeyAgeNoCheck || !mkc.keyIsExpired(keyID) {
return &searchKey.JSONWebKey, nil
Expand All @@ -212,18 +211,22 @@ func (mkc *MemoryKeyCacher) Add(keyID string, downloadedKeys []jose.JSONWebKey)
addingKeyID = cacheKey
}
if mkc.maxCacheSize == -1 {
mkc.mu.Lock()
mkc.entries[cacheKey] = keyCacherEntry{
addedAt: time.Now(),
JSONWebKey: downloadedKeys[i],
}
mkc.mu.Unlock()
}
}
if addingKey.Key != nil {
if mkc.maxCacheSize != -1 {
mkc.mu.Lock()
mkc.entries[addingKeyID] = keyCacherEntry{
addedAt: time.Now(),
JSONWebKey: addingKey,
}
mkc.mu.Unlock()
mkc.handleOverflow()
}
return &addingKey, nil
Expand All @@ -233,8 +236,14 @@ func (mkc *MemoryKeyCacher) Add(keyID string, downloadedKeys []jose.JSONWebKey)

// keyIsExpired deletes the key from cache if it is expired
func (mkc *MemoryKeyCacher) keyIsExpired(keyID string) bool {
if time.Now().After(mkc.entries[keyID].addedAt.Add(mkc.maxKeyAge)) {
mkc.mu.RLock()
entry := mkc.entries[keyID].addedAt.Add(mkc.maxKeyAge)
mkc.mu.RUnlock()

if time.Now().After(entry) {
mkc.mu.Lock()
delete(mkc.entries, keyID)
mkc.mu.Unlock()
return true
}
return false
Expand All @@ -243,6 +252,7 @@ func (mkc *MemoryKeyCacher) keyIsExpired(keyID string) bool {
// handleOverflow deletes the oldest key from the cache if overflowed
func (mkc *MemoryKeyCacher) handleOverflow() {
if mkc.maxCacheSize < len(mkc.entries) {
mkc.mu.RLock()
var oldestEntryKeyID string
latestAddedTime := time.Now()
for entryKeyID := range mkc.entries {
Expand All @@ -251,6 +261,10 @@ func (mkc *MemoryKeyCacher) handleOverflow() {
oldestEntryKeyID = entryKeyID
}
}
mkc.mu.RUnlock()

mkc.mu.Lock()
delete(mkc.entries, oldestEntryKeyID)
mkc.mu.Unlock()
}
}

0 comments on commit bfdec20

Please sign in to comment.