From ae8ccaa6e2e8d165c7cd44b892c631bdfec9c9be Mon Sep 17 00:00:00 2001 From: Cody Littley <56973212+cody-littley@users.noreply.github.com> Date: Wed, 6 Nov 2024 08:33:25 -0600 Subject: [PATCH] Add cache wrapper that handles parallel access. (#861) Signed-off-by: Cody Littley --- relay/cached_accessor.go | 126 +++++++++++++++++ relay/cached_accessor_test.go | 256 ++++++++++++++++++++++++++++++++++ 2 files changed, 382 insertions(+) create mode 100644 relay/cached_accessor.go create mode 100644 relay/cached_accessor_test.go diff --git a/relay/cached_accessor.go b/relay/cached_accessor.go new file mode 100644 index 0000000000..c42b48af0d --- /dev/null +++ b/relay/cached_accessor.go @@ -0,0 +1,126 @@ +package relay + +import ( + lru "github.com/hashicorp/golang-lru/v2" + "sync" +) + +// CachedAccessor is an interface for accessing a resource that is cached. It assumes that cache misses +// are expensive, and prevents multiple concurrent cache misses for the same key. +type CachedAccessor[K comparable, V any] interface { + // Get returns the value for the given key. If the value is not in the cache, it will be fetched using the Accessor. + Get(key K) (*V, error) +} + +// Accessor is function capable of fetching a value from a resource. Used by CachedAccessor when there is a cache miss. +type Accessor[K comparable, V any] func(key K) (V, error) + +// accessResult is a struct that holds the result of an Accessor call. +type accessResult[V any] struct { + // wg.Wait() will block until the value is fetched. + wg sync.WaitGroup + // value is the value fetched by the Accessor, or nil if there was an error. + value *V + // err is the error returned by the Accessor, or nil if the fetch was successful. + err error +} + +var _ CachedAccessor[string, string] = &cachedAccessor[string, string]{} + +// Future work: the cache used in this implementation is suboptimal when storing items that have a large +// variance in size. 
The current implementation uses a fixed size cache, which requires the cache to be +// sized to the largest item that will be stored. This cache should be replaced with an implementation +// whose size can be specified by memory footprint in bytes. + +// cachedAccessor is an implementation of CachedAccessor. +type cachedAccessor[K comparable, V any] struct { + + // lookupsInProgress has an entry for each key that is currently being looked up via the accessor. The value + // is stored in the accessResult when it is eventually fetched. If a key is requested more than once while a + // lookup is in progress, the second (and following) requests will wait for the result of the first lookup + // to be stored in the accessResult. + lookupsInProgress map[K]*accessResult[V] + + // cache is the LRU cache used to store values fetched by the accessor. + cache *lru.Cache[K, *V] + + // cacheLock is used to protect the cache and the lookupsInProgress map. + cacheLock sync.Mutex + + // accessor is the function used to fetch values that are not in the cache. + accessor Accessor[K, *V] +} + +// NewCachedAccessor creates a new CachedAccessor. +func NewCachedAccessor[K comparable, V any](cacheSize int, accessor Accessor[K, *V]) (CachedAccessor[K, V], error) { + + cache, err := lru.New[K, *V](cacheSize) + if err != nil { + return nil, err + } + + lookupsInProgress := make(map[K]*accessResult[V]) + + return &cachedAccessor[K, V]{ + cache: cache, + accessor: accessor, + lookupsInProgress: lookupsInProgress, + }, nil +} + +func newAccessResult[V any]() *accessResult[V] { + result := &accessResult[V]{ + wg: sync.WaitGroup{}, + } + result.wg.Add(1) + return result +} + +func (c *cachedAccessor[K, V]) Get(key K) (*V, error) { + + c.cacheLock.Lock() + + // first, attempt to get the value from the cache + v, ok := c.cache.Get(key) + if ok { + c.cacheLock.Unlock() + return v, nil + } + + // if that fails, check if a lookup is already in progress. If not, start a new one. 
+ result, alreadyLoading := c.lookupsInProgress[key] + if !alreadyLoading { + result = newAccessResult[V]() + c.lookupsInProgress[key] = result + } + + c.cacheLock.Unlock() + + if alreadyLoading { + // The result is being fetched on another goroutine. Wait for it to finish. + result.wg.Wait() + return result.value, result.err + } else { + // We are the first goroutine to request this key. + value, err := c.accessor(key) + + c.cacheLock.Lock() + + // Update the cache if the fetch was successful. + if err == nil { + c.cache.Add(key, value) + } + + // Provide the result to all other goroutines that may be waiting for it. + result.err = err + result.value = value + result.wg.Done() + + // Clean up the lookupsInProgress map. + delete(c.lookupsInProgress, key) + + c.cacheLock.Unlock() + + return value, err + } +} diff --git a/relay/cached_accessor_test.go b/relay/cached_accessor_test.go new file mode 100644 index 0000000000..791214705e --- /dev/null +++ b/relay/cached_accessor_test.go @@ -0,0 +1,256 @@ +package relay + +import ( + "errors" + tu "github.com/Layr-Labs/eigenda/common/testutils" + "github.com/stretchr/testify/require" + "math/rand" + "sync" + "sync/atomic" + "testing" + "time" +) + +func TestRandomOperationsSingleThread(t *testing.T) { + tu.InitializeRandom() + + dataSize := 1024 + + baseData := make(map[int]string) + for i := 0; i < dataSize; i++ { + baseData[i] = tu.RandomString(10) + } + + accessor := func(key int) (*string, error) { + // Return an error if the key is a multiple of 17 + if key%17 == 0 { + return nil, errors.New("intentional error") + } + + str := baseData[key] + return &str, nil + } + cacheSize := rand.Intn(dataSize) + 1 + + ca, err := NewCachedAccessor(cacheSize, accessor) + require.NoError(t, err) + + for i := 0; i < dataSize; i++ { + value, err := ca.Get(i) + + if i%17 == 0 { + require.Error(t, err) + require.Nil(t, value) + } else { + require.NoError(t, err) + require.Equal(t, baseData[i], *value) + } + } + + for k, v := range baseData 
{ + value, err := ca.Get(k) + + if k%17 == 0 { + require.Error(t, err) + require.Nil(t, value) + } else { + require.NoError(t, err) + require.Equal(t, v, *value) + } + } +} + +func TestCacheMisses(t *testing.T) { + tu.InitializeRandom() + + cacheSize := rand.Intn(10) + 10 + keyCount := cacheSize + 1 + + baseData := make(map[int]string) + for i := 0; i < keyCount; i++ { + baseData[i] = tu.RandomString(10) + } + + cacheMissCount := atomic.Uint64{} + + accessor := func(key int) (*string, error) { + cacheMissCount.Add(1) + str := baseData[key] + return &str, nil + } + + ca, err := NewCachedAccessor(cacheSize, accessor) + require.NoError(t, err) + + // Get the first cacheSize keys. This should fill the cache. + expectedCacheMissCount := uint64(0) + for i := 0; i < cacheSize; i++ { + expectedCacheMissCount++ + value, err := ca.Get(i) + require.NoError(t, err) + require.Equal(t, baseData[i], *value) + require.Equal(t, expectedCacheMissCount, cacheMissCount.Load()) + } + + // Get the first cacheSize keys again. This should not increase the cache miss count. + for i := 0; i < cacheSize; i++ { + value, err := ca.Get(i) + require.NoError(t, err) + require.Equal(t, baseData[i], *value) + require.Equal(t, expectedCacheMissCount, cacheMissCount.Load()) + } + + // Read the last key. This should cause the first key to be evicted. + expectedCacheMissCount++ + value, err := ca.Get(cacheSize) + require.NoError(t, err) + require.Equal(t, baseData[cacheSize], *value) + + // Read the keys in order. Due to the order of evictions, each read should result in a cache miss. 
+ for i := 0; i < cacheSize; i++ { + expectedCacheMissCount++ + value, err := ca.Get(i) + require.NoError(t, err) + require.Equal(t, baseData[i], *value) + require.Equal(t, expectedCacheMissCount, cacheMissCount.Load()) + } +} + +func ParallelAccessTest(t *testing.T, sleepEnabled bool) { + tu.InitializeRandom() + + dataSize := 1024 + + baseData := make(map[int]string) + for i := 0; i < dataSize; i++ { + baseData[i] = tu.RandomString(10) + } + + accessorLock := sync.RWMutex{} + cacheMissCount := atomic.Uint64{} + accessor := func(key int) (*string, error) { + + // Intentionally block if accessorLock is held by the outside scope. + // Used to provoke specific race conditions. + accessorLock.Lock() + defer accessorLock.Unlock() + + cacheMissCount.Add(1) + + str := baseData[key] + return &str, nil + } + cacheSize := rand.Intn(dataSize) + 1 + + ca, err := NewCachedAccessor(cacheSize, accessor) + require.NoError(t, err) + + // Lock the accessor. This will cause all cache misses to block. + accessorLock.Lock() + + // Start several goroutines that will attempt to access the same key. + wg := sync.WaitGroup{} + wg.Add(10) + for i := 0; i < 10; i++ { + go func() { + defer wg.Done() + value, err := ca.Get(0) + require.NoError(t, err) + require.Equal(t, baseData[0], *value) + }() + } + + if sleepEnabled { + // Wait for the goroutines to start. We want to give the goroutines a chance to do naughty things if they want. + // Eliminating this sleep will not cause the test to fail, but it may cause the test not to exercise the + // desired race condition. + time.Sleep(100 * time.Millisecond) + } + + // Unlock the accessor. This will allow the goroutines to proceed. + accessorLock.Unlock() + + // Wait for the goroutines to finish. + wg.Wait() + + // Only one of the goroutines should have called into the accessor. + require.Equal(t, uint64(1), cacheMissCount.Load()) + + // Fetching the key again should not result in a cache miss. 
+ value, err := ca.Get(0) + require.NoError(t, err) + require.Equal(t, baseData[0], *value) + require.Equal(t, uint64(1), cacheMissCount.Load()) + + // The internal lookupsInProgress map should no longer contain the key. + require.Equal(t, 0, len(ca.(*cachedAccessor[int, string]).lookupsInProgress)) +} + +func TestParallelAccess(t *testing.T) { + // To show that the sleep is not necessary, we run the test twice: once with the sleep enabled and once without. + // The purpose of the sleep is to make a certain type of race condition more likely to occur. + + ParallelAccessTest(t, false) + ParallelAccessTest(t, true) +} + +func TestParallelAccessWithError(t *testing.T) { + tu.InitializeRandom() + + accessorLock := sync.RWMutex{} + cacheMissCount := atomic.Uint64{} + accessor := func(key int) (*string, error) { + // Intentionally block if accessorLock is held by the outside scope. + // Used to provoke specific race conditions. + accessorLock.Lock() + defer accessorLock.Unlock() + + cacheMissCount.Add(1) + + return nil, errors.New("intentional error") + } + cacheSize := 100 + + ca, err := NewCachedAccessor(cacheSize, accessor) + require.NoError(t, err) + + // Lock the accessor. This will cause all cache misses to block. + accessorLock.Lock() + + // Start several goroutines that will attempt to access the same key. + wg := sync.WaitGroup{} + wg.Add(10) + for i := 0; i < 10; i++ { + go func() { + defer wg.Done() + value, err := ca.Get(0) + require.Nil(t, value) + require.Equal(t, errors.New("intentional error"), err) + }() + } + + // Wait for the goroutines to start. We want to give the goroutines a chance to do naughty things if they want. + // Eliminating this sleep will not cause the test to fail, but it may cause the test not to exercise the + // desired race condition. + time.Sleep(100 * time.Millisecond) + + // Unlock the accessor. This will allow the goroutines to proceed. + accessorLock.Unlock() + + // Wait for the goroutines to finish. 
+ wg.Wait() + + // At least one of the goroutines should have called into the accessor. In theory all of them could have, + // but most likely it will be exactly one. + count := cacheMissCount.Load() + require.True(t, count >= 1) + + // Fetching the key again should result in another cache miss since the previous fetch failed. + value, err := ca.Get(0) + require.Nil(t, value) + require.Equal(t, errors.New("intentional error"), err) + require.Equal(t, count+1, cacheMissCount.Load()) + + // The internal lookupsInProgress map should no longer contain the key. + require.Equal(t, 0, len(ca.(*cachedAccessor[int, string]).lookupsInProgress)) +}