-
Notifications
You must be signed in to change notification settings - Fork 186
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add cache wrapper that handles parallel access. (#861)
Signed-off-by: Cody Littley <[email protected]>
- Loading branch information
1 parent
14c53d8
commit ae8ccaa
Showing
2 changed files
with
382 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,126 @@ | ||
package relay | ||
|
||
import (
	"fmt"
	"sync"

	lru "github.com/hashicorp/golang-lru/v2"
)
|
||
// CachedAccessor is an interface for accessing a resource through a cache. It assumes that cache misses
// are expensive, and prevents multiple concurrent cache misses for the same key: while a fetch for a key
// is in flight, other callers asking for that key wait for its result rather than fetching again.
type CachedAccessor[K comparable, V any] interface {
	// Get returns the value for the given key. If the value is not in the cache, it will be fetched
	// using the Accessor. A non-nil error means the fetch failed.
	Get(key K) (*V, error)
}
|
||
// Accessor is a function capable of fetching a value from a resource. It is called by CachedAccessor
// when there is a cache miss for the requested key.
type Accessor[K comparable, V any] func(key K) (V, error)
|
||
// accessResult holds the result of a single Accessor call so that it can be handed to every goroutine
// that requested the same key while that call was in progress.
type accessResult[V any] struct {
	// wg.Wait() will block until the value is fetched.
	wg sync.WaitGroup
	// value is the value returned by the Accessor (expected to be nil when err is non-nil).
	value *V
	// err is the error returned by the Accessor, or nil if the fetch was successful.
	err error
}
|
||
var _ CachedAccessor[string, string] = &cachedAccessor[string, string]{} | ||
|
||
// Future work: the cache used in this implementation is suboptimal when storing items that have a large
// variance in size. The current implementation uses a fixed size cache, which requires the cache to be
// sized to the largest item that will be stored. This cache should be replaced with an implementation
// whose size can be specified by memory footprint in bytes.
|
||
// cachedAccessor is an implementation of CachedAccessor. It places an LRU cache in front of an accessor
// function and tracks in-flight lookups so that concurrent requests for the same missing key share one
// accessor call.
type cachedAccessor[K comparable, V any] struct {
	// lookupsInProgress has an entry for each key that is currently being looked up via the accessor.
	// The goroutine performing the lookup stores the outcome in the entry and then releases the entry's
	// WaitGroup. If a key is requested more than once while a lookup is in progress, the second (and
	// following) requests wait on that WaitGroup and then read the outcome of the first lookup from
	// the entry.
	lookupsInProgress map[K]*accessResult[V]

	// cache is the LRU cache used to store values fetched by the accessor.
	cache *lru.Cache[K, *V]

	// cacheLock is used to protect the cache and lookupsInProgress map.
	cacheLock sync.Mutex

	// accessor is the function used to fetch values that are not in the cache.
	accessor Accessor[K, *V]
}
|
||
// NewCachedAccessor creates a new CachedAccessor. | ||
func NewCachedAccessor[K comparable, V any](cacheSize int, accessor Accessor[K, *V]) (CachedAccessor[K, V], error) { | ||
|
||
cache, err := lru.New[K, *V](cacheSize) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
lookupsInProgress := make(map[K]*accessResult[V]) | ||
|
||
return &cachedAccessor[K, V]{ | ||
cache: cache, | ||
accessor: accessor, | ||
lookupsInProgress: lookupsInProgress, | ||
}, nil | ||
} | ||
|
||
func newAccessResult[V any]() *accessResult[V] { | ||
result := &accessResult[V]{ | ||
wg: sync.WaitGroup{}, | ||
} | ||
result.wg.Add(1) | ||
return result | ||
} | ||
|
||
func (c *cachedAccessor[K, V]) Get(key K) (*V, error) { | ||
|
||
c.cacheLock.Lock() | ||
|
||
// first, attempt to get the value from the cache | ||
v, ok := c.cache.Get(key) | ||
if ok { | ||
c.cacheLock.Unlock() | ||
return v, nil | ||
} | ||
|
||
// if that fails, check if a lookup is already in progress. If not, start a new one. | ||
result, alreadyLoading := c.lookupsInProgress[key] | ||
if !alreadyLoading { | ||
result = newAccessResult[V]() | ||
c.lookupsInProgress[key] = result | ||
} | ||
|
||
c.cacheLock.Unlock() | ||
|
||
if alreadyLoading { | ||
// The result is being fetched on another goroutine. Wait for it to finish. | ||
result.wg.Wait() | ||
return result.value, result.err | ||
} else { | ||
// We are the first goroutine to request this key. | ||
value, err := c.accessor(key) | ||
|
||
c.cacheLock.Lock() | ||
|
||
// Update the cache if the fetch was successful. | ||
if err == nil { | ||
c.cache.Add(key, value) | ||
} | ||
|
||
// Provide the result to all other goroutines that may be waiting for it. | ||
result.err = err | ||
result.value = value | ||
result.wg.Done() | ||
|
||
// Clean up the lookupInProgress map. | ||
delete(c.lookupsInProgress, key) | ||
|
||
c.cacheLock.Unlock() | ||
|
||
return value, err | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,256 @@ | ||
package relay | ||
|
||
import ( | ||
"errors" | ||
tu "github.com/Layr-Labs/eigenda/common/testutils" | ||
"github.com/stretchr/testify/require" | ||
"math/rand" | ||
"sync" | ||
"sync/atomic" | ||
"testing" | ||
"time" | ||
) | ||
|
||
func TestRandomOperationsSingleThread(t *testing.T) { | ||
tu.InitializeRandom() | ||
|
||
dataSize := 1024 | ||
|
||
baseData := make(map[int]string) | ||
for i := 0; i < dataSize; i++ { | ||
baseData[i] = tu.RandomString(10) | ||
} | ||
|
||
accessor := func(key int) (*string, error) { | ||
// Return an error if the key is a multiple of 17 | ||
if key%17 == 0 { | ||
return nil, errors.New("intentional error") | ||
} | ||
|
||
str := baseData[key] | ||
return &str, nil | ||
} | ||
cacheSize := rand.Intn(dataSize) + 1 | ||
|
||
ca, err := NewCachedAccessor(cacheSize, accessor) | ||
require.NoError(t, err) | ||
|
||
for i := 0; i < dataSize; i++ { | ||
value, err := ca.Get(i) | ||
|
||
if i%17 == 0 { | ||
require.Error(t, err) | ||
require.Nil(t, value) | ||
} else { | ||
require.NoError(t, err) | ||
require.Equal(t, baseData[i], *value) | ||
} | ||
} | ||
|
||
for k, v := range baseData { | ||
value, err := ca.Get(k) | ||
|
||
if k%17 == 0 { | ||
require.Error(t, err) | ||
require.Nil(t, value) | ||
} else { | ||
require.NoError(t, err) | ||
require.Equal(t, v, *value) | ||
} | ||
} | ||
} | ||
|
||
func TestCacheMisses(t *testing.T) { | ||
tu.InitializeRandom() | ||
|
||
cacheSize := rand.Intn(10) + 10 | ||
keyCount := cacheSize + 1 | ||
|
||
baseData := make(map[int]string) | ||
for i := 0; i < keyCount; i++ { | ||
baseData[i] = tu.RandomString(10) | ||
} | ||
|
||
cacheMissCount := atomic.Uint64{} | ||
|
||
accessor := func(key int) (*string, error) { | ||
cacheMissCount.Add(1) | ||
str := baseData[key] | ||
return &str, nil | ||
} | ||
|
||
ca, err := NewCachedAccessor(cacheSize, accessor) | ||
require.NoError(t, err) | ||
|
||
// Get the first cacheSize keys. This should fill the cache. | ||
expectedCacheMissCount := uint64(0) | ||
for i := 0; i < cacheSize; i++ { | ||
expectedCacheMissCount++ | ||
value, err := ca.Get(i) | ||
require.NoError(t, err) | ||
require.Equal(t, baseData[i], *value) | ||
require.Equal(t, expectedCacheMissCount, cacheMissCount.Load()) | ||
} | ||
|
||
// Get the first cacheSize keys again. This should not increase the cache miss count. | ||
for i := 0; i < cacheSize; i++ { | ||
value, err := ca.Get(i) | ||
require.NoError(t, err) | ||
require.Equal(t, baseData[i], *value) | ||
require.Equal(t, expectedCacheMissCount, cacheMissCount.Load()) | ||
} | ||
|
||
// Read the last key. This should cause the first key to be evicted. | ||
expectedCacheMissCount++ | ||
value, err := ca.Get(cacheSize) | ||
require.NoError(t, err) | ||
require.Equal(t, baseData[cacheSize], *value) | ||
|
||
// Read the keys in order. Due to the order of evictions, each read should result in a cache miss. | ||
for i := 0; i < cacheSize; i++ { | ||
expectedCacheMissCount++ | ||
value, err := ca.Get(i) | ||
require.NoError(t, err) | ||
require.Equal(t, baseData[i], *value) | ||
require.Equal(t, expectedCacheMissCount, cacheMissCount.Load()) | ||
} | ||
} | ||
|
||
func ParallelAccessTest(t *testing.T, sleepEnabled bool) { | ||
tu.InitializeRandom() | ||
|
||
dataSize := 1024 | ||
|
||
baseData := make(map[int]string) | ||
for i := 0; i < dataSize; i++ { | ||
baseData[i] = tu.RandomString(10) | ||
} | ||
|
||
accessorLock := sync.RWMutex{} | ||
cacheMissCount := atomic.Uint64{} | ||
accessor := func(key int) (*string, error) { | ||
|
||
// Intentionally block if accessorLock is held by the outside scope. | ||
// Used to provoke specific race conditions. | ||
accessorLock.Lock() | ||
defer accessorLock.Unlock() | ||
|
||
cacheMissCount.Add(1) | ||
|
||
str := baseData[key] | ||
return &str, nil | ||
} | ||
cacheSize := rand.Intn(dataSize) + 1 | ||
|
||
ca, err := NewCachedAccessor(cacheSize, accessor) | ||
require.NoError(t, err) | ||
|
||
// Lock the accessor. This will cause all cache misses to block. | ||
accessorLock.Lock() | ||
|
||
// Start several goroutines that will attempt to access the same key. | ||
wg := sync.WaitGroup{} | ||
wg.Add(10) | ||
for i := 0; i < 10; i++ { | ||
go func() { | ||
defer wg.Done() | ||
value, err := ca.Get(0) | ||
require.NoError(t, err) | ||
require.Equal(t, baseData[0], *value) | ||
}() | ||
} | ||
|
||
if sleepEnabled { | ||
// Wait for the goroutines to start. We want to give the goroutines a chance to do naughty things if they want. | ||
// Eliminating this sleep will not cause the test to fail, but it may cause the test not to exercise the | ||
// desired race condition. | ||
time.Sleep(100 * time.Millisecond) | ||
} | ||
|
||
// Unlock the accessor. This will allow the goroutines to proceed. | ||
accessorLock.Unlock() | ||
|
||
// Wait for the goroutines to finish. | ||
wg.Wait() | ||
|
||
// Only one of the goroutines should have called into the accessor. | ||
require.Equal(t, uint64(1), cacheMissCount.Load()) | ||
|
||
// Fetching the key again should not result in a cache miss. | ||
value, err := ca.Get(0) | ||
require.NoError(t, err) | ||
require.Equal(t, baseData[0], *value) | ||
require.Equal(t, uint64(1), cacheMissCount.Load()) | ||
|
||
// The internal lookupsInProgress map should no longer contain the key. | ||
require.Equal(t, 0, len(ca.(*cachedAccessor[int, string]).lookupsInProgress)) | ||
} | ||
|
||
func TestParallelAccess(t *testing.T) { | ||
// To show that the sleep is not necessary, we run the test twice: once with the sleep enabled and once without. | ||
// The purpose of the sleep is to make a certain type of race condition more likely to occur. | ||
|
||
ParallelAccessTest(t, false) | ||
ParallelAccessTest(t, true) | ||
} | ||
|
||
func TestParallelAccessWithError(t *testing.T) { | ||
tu.InitializeRandom() | ||
|
||
accessorLock := sync.RWMutex{} | ||
cacheMissCount := atomic.Uint64{} | ||
accessor := func(key int) (*string, error) { | ||
// Intentionally block if accessorLock is held by the outside scope. | ||
// Used to provoke specific race conditions. | ||
accessorLock.Lock() | ||
defer accessorLock.Unlock() | ||
|
||
cacheMissCount.Add(1) | ||
|
||
return nil, errors.New("intentional error") | ||
} | ||
cacheSize := 100 | ||
|
||
ca, err := NewCachedAccessor(cacheSize, accessor) | ||
require.NoError(t, err) | ||
|
||
// Lock the accessor. This will cause all cache misses to block. | ||
accessorLock.Lock() | ||
|
||
// Start several goroutines that will attempt to access the same key. | ||
wg := sync.WaitGroup{} | ||
wg.Add(10) | ||
for i := 0; i < 10; i++ { | ||
go func() { | ||
defer wg.Done() | ||
value, err := ca.Get(0) | ||
require.Nil(t, value) | ||
require.Equal(t, errors.New("intentional error"), err) | ||
}() | ||
} | ||
|
||
// Wait for the goroutines to start. We want to give the goroutines a chance to do naughty things if they want. | ||
// Eliminating this sleep will not cause the test to fail, but it may cause the test not to exercise the | ||
// desired race condition. | ||
time.Sleep(100 * time.Millisecond) | ||
|
||
// Unlock the accessor. This will allow the goroutines to proceed. | ||
accessorLock.Unlock() | ||
|
||
// Wait for the goroutines to finish. | ||
wg.Wait() | ||
|
||
// At least one of the goroutines should have called into the accessor. In theory all of them could have, | ||
// but most likely it will be exactly one. | ||
count := cacheMissCount.Load() | ||
require.True(t, count >= 1) | ||
|
||
// Fetching the key again should result in another cache miss since the previous fetch failed. | ||
value, err := ca.Get(0) | ||
require.Nil(t, value) | ||
require.Equal(t, errors.New("intentional error"), err) | ||
require.Equal(t, count+1, cacheMissCount.Load()) | ||
|
||
// The internal lookupsInProgress map should no longer contain the key. | ||
require.Equal(t, 0, len(ca.(*cachedAccessor[int, string]).lookupsInProgress)) | ||
} |