Add cache wrapper that handles parallel access. (#861)
Signed-off-by: Cody Littley <[email protected]>
cody-littley authored Nov 6, 2024
1 parent 14c53d8 commit ae8ccaa
Showing 2 changed files with 382 additions and 0 deletions.
126 changes: 126 additions & 0 deletions relay/cached_accessor.go
@@ -0,0 +1,126 @@
package relay

import (
	lru "github.com/hashicorp/golang-lru/v2"
	"sync"
)

// CachedAccessor is an interface for accessing a resource that is cached. It assumes that cache misses
// are expensive, and prevents multiple concurrent cache misses for the same key.
type CachedAccessor[K comparable, V any] interface {
	// Get returns the value for the given key. If the value is not in the cache, it will be fetched using the Accessor.
	Get(key K) (*V, error)
}

// Accessor is a function capable of fetching a value from a resource. Used by CachedAccessor when there is a cache miss.
type Accessor[K comparable, V any] func(key K) (V, error)

// accessResult is a struct that holds the result of an Accessor call.
type accessResult[V any] struct {
	// wg.Wait() will block until the value is fetched.
	wg sync.WaitGroup
	// value is the value fetched by the Accessor, or nil if there was an error.
	value *V
	// err is the error returned by the Accessor, or nil if the fetch was successful.
	err error
}

var _ CachedAccessor[string, string] = &cachedAccessor[string, string]{}

// Future work: the cache used in this implementation is suboptimal when storing items that have a large
// variance in size. The current implementation uses a fixed-size cache, which requires the cache to be
// sized for the largest item that will be stored. This cache should be replaced with an implementation
// whose size can be specified by memory footprint in bytes.

// cachedAccessor is an implementation of CachedAccessor.
type cachedAccessor[K comparable, V any] struct {

	// lookupsInProgress has an entry for each key that is currently being looked up via the accessor. The value
	// is written into the accessResult when it is eventually fetched. If a key is requested more than once while
	// a lookup is in progress, the second (and following) requests will wait for the result of the first lookup.
	lookupsInProgress map[K]*accessResult[V]

	// cache is the LRU cache used to store values fetched by the accessor.
	cache *lru.Cache[K, *V]

	// cacheLock is used to protect the cache and the lookupsInProgress map.
	cacheLock sync.Mutex

	// accessor is the function used to fetch values that are not in the cache.
	accessor Accessor[K, *V]
}

// NewCachedAccessor creates a new CachedAccessor.
func NewCachedAccessor[K comparable, V any](cacheSize int, accessor Accessor[K, *V]) (CachedAccessor[K, V], error) {

	cache, err := lru.New[K, *V](cacheSize)
	if err != nil {
		return nil, err
	}

	lookupsInProgress := make(map[K]*accessResult[V])

	return &cachedAccessor[K, V]{
		cache:             cache,
		accessor:          accessor,
		lookupsInProgress: lookupsInProgress,
	}, nil
}

func newAccessResult[V any]() *accessResult[V] {
	result := &accessResult[V]{
		wg: sync.WaitGroup{},
	}
	result.wg.Add(1)
	return result
}

func (c *cachedAccessor[K, V]) Get(key K) (*V, error) {

	c.cacheLock.Lock()

	// first, attempt to get the value from the cache
	v, ok := c.cache.Get(key)
	if ok {
		c.cacheLock.Unlock()
		return v, nil
	}

	// if that fails, check if a lookup is already in progress. If not, start a new one.
	result, alreadyLoading := c.lookupsInProgress[key]
	if !alreadyLoading {
		result = newAccessResult[V]()
		c.lookupsInProgress[key] = result
	}

	c.cacheLock.Unlock()

	if alreadyLoading {
		// The result is being fetched on another goroutine. Wait for it to finish.
		result.wg.Wait()
		return result.value, result.err
	} else {
		// We are the first goroutine to request this key.
		value, err := c.accessor(key)

		c.cacheLock.Lock()

		// Update the cache if the fetch was successful.
		if err == nil {
			c.cache.Add(key, value)
		}

		// Provide the result to all other goroutines that may be waiting for it.
		result.err = err
		result.value = value
		result.wg.Done()

		// Clean up the lookupsInProgress map.
		delete(c.lookupsInProgress, key)

		c.cacheLock.Unlock()

		return value, err
	}
}
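For context, here is a minimal usage sketch of the API added above. The string-keyed accessor is hypothetical and stands in for an expensive fetch (e.g. a database or network call); the import path is inferred from the repository layout visible in this diff.

package main

import (
	"fmt"

	"github.com/Layr-Labs/eigenda/relay"
)

func main() {
	// Hypothetical accessor: stands in for an expensive lookup.
	fetch := func(key string) (*string, error) {
		value := "payload for " + key
		return &value, nil
	}

	// Wrap the accessor in an LRU cache holding up to 128 values.
	ca, err := relay.NewCachedAccessor[string, string](128, fetch)
	if err != nil {
		panic(err)
	}

	// The first Get for a key invokes the accessor; concurrent Gets for the
	// same key share that single fetch, and later Gets are served from the cache.
	value, err := ca.Get("blob-1")
	if err != nil {
		panic(err)
	}
	fmt.Println(*value)
}

Note that errors are deliberately not cached: a failed fetch is retried on the next Get for that key.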
256 changes: 256 additions & 0 deletions relay/cached_accessor_test.go
@@ -0,0 +1,256 @@
package relay

import (
	"errors"
	tu "github.com/Layr-Labs/eigenda/common/testutils"
	"github.com/stretchr/testify/require"
	"math/rand"
	"sync"
	"sync/atomic"
	"testing"
	"time"
)

func TestRandomOperationsSingleThread(t *testing.T) {
	tu.InitializeRandom()

	dataSize := 1024

	baseData := make(map[int]string)
	for i := 0; i < dataSize; i++ {
		baseData[i] = tu.RandomString(10)
	}

	accessor := func(key int) (*string, error) {
		// Return an error if the key is a multiple of 17
		if key%17 == 0 {
			return nil, errors.New("intentional error")
		}

		str := baseData[key]
		return &str, nil
	}
	cacheSize := rand.Intn(dataSize) + 1

	ca, err := NewCachedAccessor(cacheSize, accessor)
	require.NoError(t, err)

	for i := 0; i < dataSize; i++ {
		value, err := ca.Get(i)

		if i%17 == 0 {
			require.Error(t, err)
			require.Nil(t, value)
		} else {
			require.NoError(t, err)
			require.Equal(t, baseData[i], *value)
		}
	}

	for k, v := range baseData {
		value, err := ca.Get(k)

		if k%17 == 0 {
			require.Error(t, err)
			require.Nil(t, value)
		} else {
			require.NoError(t, err)
			require.Equal(t, v, *value)
		}
	}
}

func TestCacheMisses(t *testing.T) {
	tu.InitializeRandom()

	cacheSize := rand.Intn(10) + 10
	keyCount := cacheSize + 1

	baseData := make(map[int]string)
	for i := 0; i < keyCount; i++ {
		baseData[i] = tu.RandomString(10)
	}

	cacheMissCount := atomic.Uint64{}

	accessor := func(key int) (*string, error) {
		cacheMissCount.Add(1)
		str := baseData[key]
		return &str, nil
	}

	ca, err := NewCachedAccessor(cacheSize, accessor)
	require.NoError(t, err)

	// Get the first cacheSize keys. This should fill the cache.
	expectedCacheMissCount := uint64(0)
	for i := 0; i < cacheSize; i++ {
		expectedCacheMissCount++
		value, err := ca.Get(i)
		require.NoError(t, err)
		require.Equal(t, baseData[i], *value)
		require.Equal(t, expectedCacheMissCount, cacheMissCount.Load())
	}

	// Get the first cacheSize keys again. This should not increase the cache miss count.
	for i := 0; i < cacheSize; i++ {
		value, err := ca.Get(i)
		require.NoError(t, err)
		require.Equal(t, baseData[i], *value)
		require.Equal(t, expectedCacheMissCount, cacheMissCount.Load())
	}

	// Read the last key. This should cause the first key to be evicted.
	expectedCacheMissCount++
	value, err := ca.Get(cacheSize)
	require.NoError(t, err)
	require.Equal(t, baseData[cacheSize], *value)

	// Read the keys in order. Due to the order of evictions, each read should result in a cache miss.
	for i := 0; i < cacheSize; i++ {
		expectedCacheMissCount++
		value, err := ca.Get(i)
		require.NoError(t, err)
		require.Equal(t, baseData[i], *value)
		require.Equal(t, expectedCacheMissCount, cacheMissCount.Load())
	}
}

func ParallelAccessTest(t *testing.T, sleepEnabled bool) {
	tu.InitializeRandom()

	dataSize := 1024

	baseData := make(map[int]string)
	for i := 0; i < dataSize; i++ {
		baseData[i] = tu.RandomString(10)
	}

	accessorLock := sync.RWMutex{}
	cacheMissCount := atomic.Uint64{}
	accessor := func(key int) (*string, error) {

		// Intentionally block if accessorLock is held by the outside scope.
		// Used to provoke specific race conditions.
		accessorLock.Lock()
		defer accessorLock.Unlock()

		cacheMissCount.Add(1)

		str := baseData[key]
		return &str, nil
	}
	cacheSize := rand.Intn(dataSize) + 1

	ca, err := NewCachedAccessor(cacheSize, accessor)
	require.NoError(t, err)

	// Lock the accessor. This will cause all cache misses to block.
	accessorLock.Lock()

	// Start several goroutines that will attempt to access the same key.
	wg := sync.WaitGroup{}
	wg.Add(10)
	for i := 0; i < 10; i++ {
		go func() {
			defer wg.Done()
			value, err := ca.Get(0)
			require.NoError(t, err)
			require.Equal(t, baseData[0], *value)
		}()
	}

	if sleepEnabled {
		// Wait for the goroutines to start. We want to give the goroutines a chance to do naughty things if they want.
		// Eliminating this sleep will not cause the test to fail, but it may cause the test not to exercise the
		// desired race condition.
		time.Sleep(100 * time.Millisecond)
	}

	// Unlock the accessor. This will allow the goroutines to proceed.
	accessorLock.Unlock()

	// Wait for the goroutines to finish.
	wg.Wait()

	// Only one of the goroutines should have called into the accessor.
	require.Equal(t, uint64(1), cacheMissCount.Load())

	// Fetching the key again should not result in a cache miss.
	value, err := ca.Get(0)
	require.NoError(t, err)
	require.Equal(t, baseData[0], *value)
	require.Equal(t, uint64(1), cacheMissCount.Load())

	// The internal lookupsInProgress map should no longer contain the key.
	require.Equal(t, 0, len(ca.(*cachedAccessor[int, string]).lookupsInProgress))
}

func TestParallelAccess(t *testing.T) {
	// To show that the sleep is not necessary, we run the test twice: once with the sleep enabled and once without.
	// The purpose of the sleep is to make a certain type of race condition more likely to occur.

	ParallelAccessTest(t, false)
	ParallelAccessTest(t, true)
}

func TestParallelAccessWithError(t *testing.T) {
	tu.InitializeRandom()

	accessorLock := sync.RWMutex{}
	cacheMissCount := atomic.Uint64{}
	accessor := func(key int) (*string, error) {
		// Intentionally block if accessorLock is held by the outside scope.
		// Used to provoke specific race conditions.
		accessorLock.Lock()
		defer accessorLock.Unlock()

		cacheMissCount.Add(1)

		return nil, errors.New("intentional error")
	}
	cacheSize := 100

	ca, err := NewCachedAccessor(cacheSize, accessor)
	require.NoError(t, err)

	// Lock the accessor. This will cause all cache misses to block.
	accessorLock.Lock()

	// Start several goroutines that will attempt to access the same key.
	wg := sync.WaitGroup{}
	wg.Add(10)
	for i := 0; i < 10; i++ {
		go func() {
			defer wg.Done()
			value, err := ca.Get(0)
			require.Nil(t, value)
			require.Equal(t, errors.New("intentional error"), err)
		}()
	}

	// Wait for the goroutines to start. We want to give the goroutines a chance to do naughty things if they want.
	// Eliminating this sleep will not cause the test to fail, but it may cause the test not to exercise the
	// desired race condition.
	time.Sleep(100 * time.Millisecond)

	// Unlock the accessor. This will allow the goroutines to proceed.
	accessorLock.Unlock()

	// Wait for the goroutines to finish.
	wg.Wait()

	// At least one of the goroutines should have called into the accessor. In theory all of them could have,
	// but most likely it will be exactly one.
	count := cacheMissCount.Load()
	require.True(t, count >= 1)

	// Fetching the key again should result in another cache miss since the previous fetch failed.
	value, err := ca.Get(0)
	require.Nil(t, value)
	require.Equal(t, errors.New("intentional error"), err)
	require.Equal(t, count+1, cacheMissCount.Load())

	// The internal lookupsInProgress map should no longer contain the key.
	require.Equal(t, 0, len(ca.(*cachedAccessor[int, string]).lookupsInProgress))
}
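Since these tests exercise cross-goroutine behavior, they are most informative when run under Go's race detector. A sketch of an invocation, assuming the package path implied by the file names in this diff:

go test -race ./relay/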
