diff --git a/common.go b/common.go
index bbd2704..7eea3bc 100644
--- a/common.go
+++ b/common.go
@@ -8,9 +8,11 @@ import (
 )
 
-// SafeUnlock safely unlock mutex
-func SafeUnlock(mutex *redsync.Mutex) {
-	if mutex != nil {
-		_, _ = mutex.Unlock()
+// SafeUnlock safely unlocks all given mutexes, skipping nil entries
+func SafeUnlock(mutex ...*redsync.Mutex) {
+	for _, m := range mutex {
+		if m != nil {
+			_, _ = m.Unlock()
+		}
 	}
 }
 
diff --git a/keeper.go b/keeper.go
index 7c67f8e..0af7285 100644
--- a/keeper.go
+++ b/keeper.go
@@ -4,6 +4,7 @@ import (
 	"encoding/json"
 	"errors"
 	"fmt"
+	"sync"
 	"time"
 
 	redigo "github.com/gomodule/redigo/redis"
@@ -32,6 +33,9 @@ type (
 		Get(key string) (any, error)
 		GetOrLock(key string) (any, *redsync.Mutex, error)
 		GetOrSet(key string, fn GetterFn, opts ...func(Item)) ([]byte, error)
+		GetMultiple(keys []string) ([]any, error)
+		GetMultipleTX(keys []string) ([]any, error)
+		GetMultipleOrLock(keys []string) ([]any, []*redsync.Mutex, error)
 		Store(*redsync.Mutex, Item) error
 		StoreWithoutBlocking(Item) error
 		StoreMultiWithoutBlocking([]Item) error
@@ -155,6 +159,211 @@ func (k *keeper) Get(key string) (cachedItem any, err error) {
 	return nil, nil
 }
 
+// GetMultipleTX :nodoc:
+func (k *keeper) GetMultipleTX(keys []string) (cachedItems []any, err error) {
+	if k.disableCaching {
+		return
+	}
+	c := k.connPool.Get()
+	defer func() {
+		_ = c.Close()
+	}()
+
+	err = c.Send("MULTI")
+	if err != nil {
+		return nil, err
+	}
+
+	for _, key := range keys {
+		err = c.Send("GET", key)
+		if err != nil {
+			return
+		}
+	}
+
+	r, err := c.Do("EXEC")
+	if err != nil {
+		return nil, err
+	}
+
+	return redigo.Values(r, err)
+}
+
+// GetMultiple :nodoc:
+func (k *keeper) GetMultiple(keys []string) (cachedItems []any, err error) {
+	if k.disableCaching {
+		return
+	}
+	c := k.connPool.Get()
+	defer func() {
+		_ = c.Close()
+	}()
+
+	for _, key := range keys {
+		err = c.Send("GET", key)
+		if err != nil {
+			return
+		}
+	}
+
+	err = c.Flush()
+	if err != nil {
+		return
+	}
+
+	for range keys {
+		rep, err := redigo.Bytes(c.Receive())
+		if err != nil {
+			if err == redigo.ErrNil {
+				// keep the result aligned with keys: a missing key yields a nil entry
+				cachedItems = append(cachedItems, nil)
+				continue
+			}
+			return nil, err
+		}
+		cachedItems = append(cachedItems, rep)
+	}
+
+	return
+}
+
+// GetMultipleOrLock gets multiple keys at once and acquires locks for the keys that do not exist in Redis.
+// Returned cached items are in the same order as the provided keys; when a key has no cached value, its
+// entry in the returned slice is nil and a lock for that key is returned instead.
+func (k *keeper) GetMultipleOrLock(keys []string) (cachedItems []any, mutexes []*redsync.Mutex, err error) {
+	if k.disableCaching {
+		return
+	}
+
+	c := k.connPool.Get()
+	defer func() {
+		_ = c.Close()
+	}()
+
+	for _, key := range keys {
+		err = c.Send("GET", key)
+		if err != nil {
+			return
+		}
+	}
+
+	err = c.Flush()
+	if err != nil {
+		return
+	}
+
+	var (
+		keysToLock     []string
+		cachedItemsBuf = make(map[string]any)
+		mutexesBuf     = make(map[string]*redsync.Mutex)
+	)
+	for _, key := range keys {
+		rep, err := redigo.Bytes(c.Receive())
+		if err != nil && err != redigo.ErrNil {
+			return nil, nil, err
+		}
+		if rep == nil {
+			keysToLock = append(keysToLock, key)
+			continue
+		}
+		cachedItemsBuf[key] = rep
+	}
+
+	type itemWithKey struct {
+		Key  string
+		Item any
+	}
+
+	type mutexWithKey struct {
+		Key   string
+		Mutex *redsync.Mutex
+	}
+
+	var (
+		itemCh  = make(chan *itemWithKey, len(keysToLock))
+		errCh   = make(chan error, len(keysToLock))
+		mutexCh = make(chan *mutexWithKey, len(keysToLock))
+	)
+
+	for _, key := range keysToLock {
+		go func(key string) {
+			mutex, err := k.AcquireLock(key)
+			if err == nil {
+				mutexCh <- &mutexWithKey{Mutex: mutex, Key: key}
+				return
+			}
+			start := time.Now()
+			b := &backoff.Backoff{
+				Jitter: true,
+				Min:    20 * time.Millisecond,
+				Max:    200 * time.Millisecond,
+			}
+
+			for {
+				if !k.isLocked(key) {
+					cachedItem, err := get(k.connPool.Get(), key)
+					if err != nil {
+						if err == ErrKeyNotExist {
+							mutex, err = k.AcquireLock(key)
+							if err == nil {
+								mutexCh <- &mutexWithKey{Mutex: mutex, Key: key}
+								return
+							}
+							goto Wait
+						}
+						errCh <- err
+						return
+					}
+					itemCh <- &itemWithKey{Item: cachedItem, Key: key}
+					return
+				}
+
+			Wait:
+				elapsed := time.Since(start)
+				if elapsed >= k.waitTime {
+					errCh <- ErrWaitTooLong
+					return
+				}
+				time.Sleep(b.Duration())
+			}
+		}(key)
+	}
+
+	wg := sync.WaitGroup{}
+	wg.Add(1)
+	go func() {
+		defer wg.Done()
+		// collect exactly one result per key to lock, either a cached item
+		// or an acquired lock, and stop early on the first error
+		for counter := 0; counter < len(keysToLock); counter++ {
+			select {
+			case i := <-itemCh:
+				cachedItemsBuf[i.Key] = i.Item
+			case err = <-errCh:
+				return
+			case m := <-mutexCh:
+				mutexesBuf[m.Key] = m.Mutex
+			}
+		}
+	}()
+
+	wg.Wait()
+	if err != nil {
+		return
+	}
+
+	for _, key := range keys {
+		if v, ok := cachedItemsBuf[key]; ok {
+			cachedItems = append(cachedItems, v)
+		} else if m, ok := mutexesBuf[key]; ok {
+			mutexes = append(mutexes, m)
+			cachedItems = append(cachedItems, nil)
+		}
+	}
+
+	return
+}
+
 // GetOrLock :nodoc:
 func (k *keeper) GetOrLock(key string) (cachedItem any, mutex *redsync.Mutex, err error) {
 	if k.disableCaching {
diff --git a/keeper_test.go b/keeper_test.go
index 2897924..c787249 100644
--- a/keeper_test.go
+++ b/keeper_test.go
@@ -7,10 +7,9 @@ import (
 	"testing"
 	"time"
 
-	"github.com/stretchr/testify/require"
-
 	redigo "github.com/gomodule/redigo/redis"
 	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
 
 	"github.com/alicebob/miniredis/v2"
 )
@@ -1159,3 +1158,156 @@ func TestHashScan_Empty(t *testing.T) {
 	assert.Empty(t, result)
 	assert.EqualValues(t, 0, cursor)
 }
+
+func TestGetMultiple(t *testing.T) {
+	t.Run("success", func(t *testing.T) {
+		k := NewKeeper()
+		m, err := miniredis.Run()
+		assert.NoError(t, err)
+		r := newRedisConn(m.Addr())
+		k.SetConnectionPool(r)
+		k.SetLockConnectionPool(r)
+		k.SetWaitTime(1 * time.Second) // override wait time to 1 second
+
+		keys := []string{"a", "b", "c"}
+		items := map[string]string{"a": "A", "b": "B", "c": "C"}
+		for key, val := range items {
+			_ = k.StoreWithoutBlocking(NewItem(key, val))
+		}
+		res, err := k.GetMultiple(keys)
+		assert.NoError(t, err)
+		for i, key := range keys {
+			assert.EqualValues(t, items[key], res[i])
+		}
+	})
+
+	t.Run("success with missing cache", func(t *testing.T) {
+		k := NewKeeper()
+		m, err := miniredis.Run()
+		assert.NoError(t, err)
+		r := newRedisConn(m.Addr())
+		k.SetConnectionPool(r)
+		k.SetLockConnectionPool(r)
+		k.SetWaitTime(1 * time.Second) // override wait time to 1 second
+
+		keys := []string{"d", "b", "a", "o", "c"}
+		items := map[string]string{"b": "B", "o": "O"}
+		for key, val := range items {
+			_ = k.StoreWithoutBlocking(NewItem(key, val))
+		}
+
+		res, err := k.GetMultiple(keys)
+		assert.NoError(t, err)
+		for i, key := range keys {
+			if _, ok := items[key]; !ok {
+				assert.Nil(t, res[i])
+				continue
+			}
+			assert.EqualValues(t, items[key], res[i])
+		}
+	})
+}
+
+func TestGetMultipleOrLock(t *testing.T) {
+	t.Run("success get all locks", func(t *testing.T) {
+		k := NewKeeper()
+		m, err := miniredis.Run()
+		assert.NoError(t, err)
+		r := newRedisConn(m.Addr())
+		k.SetConnectionPool(r)
+		k.SetLockConnectionPool(r)
+
+		keys := []string{"key1", "key2", "key3"}
+		items, mutexes, err := k.GetMultipleOrLock(keys)
+
+		assert.NotNil(t, mutexes)
+		assert.Equal(t, len(keys), len(mutexes))
+		assert.Equal(t, len(keys), len(items))
+	})
+
+	t.Run("success get locks for non existing items", func(t *testing.T) {
+		k := NewKeeper()
+		m, err := miniredis.Run()
+		assert.NoError(t, err)
+		r := newRedisConn(m.Addr())
+		k.SetConnectionPool(r)
+		k.SetLockConnectionPool(r)
+		k.SetDefaultTTL(time.Minute)
+
+		keys := []string{"key1", "key2", "key3"}
+
+		err = k.StoreMultiWithoutBlocking([]Item{NewItem("key2", "key2")})
+		assert.NoError(t, err)
+
+		items, mutexes, err := k.GetMultipleOrLock(keys)
+		assert.NoError(t, err)
+		assert.Equal(t, 2, len(mutexes))
+		assert.Equal(t, len(keys), len(items))
+	})
+
+	t.Run("success get all cached items", func(t *testing.T) {
+		k := NewKeeper()
+		m, err := miniredis.Run()
+		assert.NoError(t, err)
+		r := newRedisConn(m.Addr())
+		k.SetConnectionPool(r)
+		k.SetLockConnectionPool(r)
+		k.SetDefaultTTL(time.Minute)
+
+		keys := []string{"key1", "key2", "key3"}
+		items := []Item{
+			NewItem("key1", "key1"),
+			NewItem("key2", "key2"),
+			NewItem("key3", "key3"),
+		}
+
+		err = k.StoreMultiWithoutBlocking(items)
+		assert.NoError(t, err)
+
+		resp, mutexes, err := k.GetMultipleOrLock(keys)
+		assert.NoError(t, err)
+		assert.Nil(t, mutexes)
+		assert.Equal(t, len(keys), len(resp))
+	})
+
+	t.Run("success after waiting for cache keys to exist", func(t *testing.T) {
+		k := NewKeeper()
+		m, err := miniredis.Run()
+		assert.NoError(t, err)
+		r := newRedisConn(m.Addr())
+		k.SetConnectionPool(r)
+		k.SetLockConnectionPool(r)
+		k.SetDefaultTTL(time.Minute)
+
+		keys := []string{"key1", "key2", "key3"}
+
+		_, mutexes, err := k.GetMultipleOrLock(keys)
+		assert.NoError(t, err)
+		assert.Equal(t, len(keys), len(mutexes))
+
+		items := map[string]string{
+			"key1": "val1",
+			"key2": "val2",
+			"key3": "val3",
+		}
+
+		// store the items asynchronously so the next call to GetMultipleOrLock gets the results
+		go func() {
+			defer SafeUnlock(mutexes...)
+			time.Sleep(1 * time.Second)
+			var cacheItems []Item
+			for k, v := range items {
+				cacheItems = append(cacheItems, NewItem(k, v))
+			}
+			err := k.StoreMultiWithoutBlocking(cacheItems)
+			assert.NoError(t, err)
+		}()
+
+		resp2, mutexes2, err := k.GetMultipleOrLock(keys)
+		assert.NoError(t, err)
+		assert.Nil(t, mutexes2)
+		for i, k := range keys {
+			assert.EqualValues(t, items[k], resp2[i])
+		}
+	})
+}
diff --git a/keeper_with_failover_test.go b/keeper_with_failover_test.go
index f469bfe..2f83ad3 100644
--- a/keeper_with_failover_test.go
+++ b/keeper_with_failover_test.go
@@ -352,5 +352,5 @@ func Test_keeperWithFailover_DeleteHashMember(t *testing.T) {
 	assert.True(t, m.Exists(identifier) && mFO.Exists(identifier))
 	err = k.DeleteHashMember(identifier, "key")
 	assert.NoError(t, err)
-	assert.False(t, m.Exists(identifier))
+	assert.False(t, m.Exists(identifier) || mFO.Exists(identifier))
 }
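
Below is a possible caller-side pattern for the new bulk API, shown as a minimal sketch rather than part of the change itself. It assumes the code lives in the same package as keeper.go, that the interface extended in the type block above is the exported Keeper returned by NewKeeper, and that loadFromSource is a hypothetical loader for the source of truth; GetMultipleOrLock, NewItem, StoreMultiWithoutBlocking, and the variadic SafeUnlock are taken from the diff above.

// fetchThroughCache is a minimal sketch, not part of this change.
// Keeper and loadFromSource are assumptions described in the note above.
func fetchThroughCache(k Keeper, keys []string, loadFromSource func(key string) (any, error)) (map[string]any, error) {
	// cachedItems is aligned with keys; a key that was not cached comes back
	// as a nil entry plus a lock that this caller is now responsible for.
	cachedItems, mutexes, err := k.GetMultipleOrLock(keys)
	if err != nil {
		return nil, err
	}
	// The variadic SafeUnlock releases every acquired lock in one call; deferring it
	// means the locks are released only after the cache has been backfilled below.
	defer SafeUnlock(mutexes...)

	result := make(map[string]any, len(keys))
	var toStore []Item
	for i, key := range keys {
		if cachedItems[i] != nil {
			result[key] = cachedItems[i]
			continue
		}
		// Cache miss: this caller holds the lock for the key, so it loads the
		// value from the source of truth and backfills the cache.
		val, err := loadFromSource(key)
		if err != nil {
			return nil, err
		}
		result[key] = val
		toStore = append(toStore, NewItem(key, val))
	}

	if len(toStore) > 0 {
		if err := k.StoreMultiWithoutBlocking(toStore); err != nil {
			return nil, err
		}
	}
	return result, nil
}

Because the locks are released only after StoreMultiWithoutBlocking, other callers blocked inside GetMultipleOrLock observe the keys as cached once the locks are freed, which is the behaviour exercised by the last subtest above.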
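A smaller sketch for the plain read path, under the same package and Keeper assumptions: GetMultiple pipelines one GET per key (Send/Flush/Receive), GetMultipleTX issues the same GETs inside MULTI/EXEC, and both return one entry per requested key in key order, with nil entries for keys that have no cached value.

// fetchCachedStrings is a minimal sketch of reading back string values written
// with StoreWithoutBlocking, as in TestGetMultiple above; it is not part of this change.
func fetchCachedStrings(k Keeper, keys []string) (map[string]string, error) {
	raw, err := k.GetMultiple(keys)
	if err != nil {
		return nil, err
	}

	out := make(map[string]string, len(keys))
	for i, key := range keys {
		if raw[i] == nil {
			continue // cache miss: leave the key out and let the caller decide how to handle it
		}
		if b, ok := raw[i].([]byte); ok {
			out[key] = string(b) // GET replies come back as the raw stored bytes
		}
	}
	return out, nil
}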