diff --git a/cache_test.go b/cache_test.go index 161f0a4..cea110b 100644 --- a/cache_test.go +++ b/cache_test.go @@ -4,16 +4,45 @@ import ( "bytes" "crypto/rand" "encoding/binary" + "errors" "fmt" "log" mrand "math/rand" "strconv" "strings" "sync" + "sync/atomic" "testing" "time" ) +// mockTimer is a mock for Timer contract. +type mockTimer struct { + nowCallsCnt uint32 // stores the number of times Now() was called + nowCallback func() uint32 // callback to be executed inside Now() +} + +// Now mock logic. +func (mock *mockTimer) Now() uint32 { + atomic.AddUint32(&mock.nowCallsCnt, 1) + if mock.nowCallback != nil { + return mock.nowCallback() + } + + return uint32(time.Now().Unix()) +} + +// SetNowCallback sets the callback to be executed inside Now(). +// You can control the return value this way. +func (mock *mockTimer) SetNowCallback(callback func() uint32) { + mock.nowCallback = callback +} + +// nowCallsCount returns the number of times Now() was called. +func (mock *mockTimer) NowCallsCount() int { + return int(atomic.LoadUint32(&mock.nowCallsCnt)) +} + func TestFreeCache(t *testing.T) { cache := NewCache(1024) if cache.HitRate() != 0 { @@ -229,20 +258,135 @@ func TestExpire(t *testing.T) { } func TestTTL(t *testing.T) { - cache := NewCache(1024) - key := []byte("abcd") - val := []byte("efgh") - err := cache.Set(key, val, 2) + t.Run("with no expire key", testTTLWithNoExpireKey) + t.Run("with expire key, not yet expired", testTTLWithNotYetExpiredKey) + t.Run("with expire key, expired", testTTLWithExpiredKey) + t.Run("with not found key", testTTLWithNotFoundKey) +} + +func testTTLWithNoExpireKey(t *testing.T) { + t.Parallel() + + // arrange + var now uint32 = 1659954367 + timer := new(mockTimer) + timer.SetNowCallback(func() uint32 { + return now + }) + cache := NewCacheCustomTimer(512*1024, timer) + key := []byte("test-key") + value := []byte("this key does not expire") + expireSeconds := 0 + if err := cache.Set(key, value, expireSeconds); err != nil { + t.Fatalf("prerequisite failed: could not set the key to query ttl for: %v", err) + } + + // act + ttl, err := cache.TTL(key) + + // assert if err != nil { - t.Error("err should be nil", err.Error()) + t.Errorf("expected nil, but got %v", err) } - time.Sleep(time.Second) + if ttl != uint32(expireSeconds) { + t.Errorf("expected %d, but got %d ", expireSeconds, ttl) + } + if timer.NowCallsCount() != 1 { + t.Errorf("expected %d, but got %d ", 1, timer.NowCallsCount()) + } +} + +func testTTLWithNotYetExpiredKey(t *testing.T) { + t.Parallel() + + // arrange + var now uint32 = 1659954368 + timer := new(mockTimer) + timer.SetNowCallback(func() uint32 { + return now + }) + cache := NewCacheCustomTimer(512*1024, timer) + key := []byte("test-key") + value := []byte("this key expires, but is not expired") + expireSeconds := 300 + if err := cache.Set(key, value, expireSeconds); err != nil { + t.Fatalf("prerequisite failed: could not set the key to query ttl for: %v", err) + } + + // act ttl, err := cache.TTL(key) + + // assert if err != nil { - t.Error("err should be nil", err.Error()) + t.Errorf("expected nil, but got %v", err) } - if ttl != 1 { - t.Fatalf("ttl should be 1, but %d returned", ttl) + if ttl != uint32(expireSeconds) { + t.Errorf("expected %d, but got %d ", expireSeconds, ttl) + } + if timer.NowCallsCount() != 2 { // one call from set, one from ttl + t.Errorf("expected %d, but got %d ", 2, timer.NowCallsCount()) + } +} + +func testTTLWithExpiredKey(t *testing.T) { + t.Parallel() + + // arrange + var now uint32 = 1659954369 + expireSeconds := 600 + timer := new(mockTimer) + timer.SetNowCallback(func() uint32 { + switch timer.NowCallsCount() { + case 1: + return now + case 2: + return now + uint32(expireSeconds) + } + + return now + }) + cache := NewCacheCustomTimer(512*1024, timer) + key := []byte("test-key") + value := []byte("this key is expired") + if err := cache.Set(key, value, expireSeconds); err != nil { + t.Fatalf("prerequisite failed: could not set the key to query ttl for: %v", err) + } + + // act + ttl, err := cache.TTL(key) + + // assert + if !errors.Is(err, ErrNotFound) { + t.Errorf("expected %v, but got %v", ErrNotFound, err) + } + if ttl != 0 { + t.Errorf("expected %d, but got %d ", 0, ttl) + } + if timer.NowCallsCount() != 2 { // one call from set, one from ttl + t.Errorf("expected %d, but got %d ", 2, timer.NowCallsCount()) + } +} + +func testTTLWithNotFoundKey(t *testing.T) { + t.Parallel() + + // arrange + timer := new(mockTimer) + cache := NewCacheCustomTimer(512*1024, timer) + key := []byte("test-not-found-key") + + // act + ttl, err := cache.TTL(key) + + // assert + if !errors.Is(err, ErrNotFound) { + t.Errorf("expected %v, but got %v", ErrNotFound, err) + } + if ttl != 0 { + t.Errorf("expected %d, but got %d ", 0, ttl) + } + if timer.NowCallsCount() != 0 { + t.Errorf("expected %d, but got %d ", 0, timer.NowCallsCount()) } } @@ -804,6 +948,35 @@ func BenchmarkHashFunc(b *testing.B) { } } +func benchmarkTTL(expireSeconds int) func(b *testing.B) { + return func(b *testing.B) { + cache := NewCache(512 * 1024) + key := []byte("bench-ttl-key") + value := []byte("bench-ttl-value") + if err := cache.Set(key, value, expireSeconds); err != nil { + b.Fatalf("prerequisite failed: could not set the key to query TTL for: %v", err) + } + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + _, err := cache.TTL(key) + if err != nil { + b.Error(err) + } + } + } +} + +func BenchmarkTTL_withKeyThatDoesNotExpire(b *testing.B) { + benchmarkTTL(0)(b) +} + +func BenchmarkTTL_withKeyThatDoesExpire(b *testing.B) { + benchmarkTTL(30)(b) +} + func TestConcurrentGetTTL(t *testing.T) { cache := NewCache(256 * 1024 * 1024) primaryKey := []byte("hello") diff --git a/segment.go b/segment.go index 9ebea9b..bcc90a9 100644 --- a/segment.go +++ b/segment.go @@ -170,7 +170,7 @@ func (seg *segment) touch(key []byte, hashVal uint64, expireSeconds int) (err er hdr := (*entryHdr)(unsafe.Pointer(&hdrBuf[0])) now := seg.timer.Now() - if hdr.expireAt != 0 && hdr.expireAt <= now { + if isExpired(hdr.expireAt, now) { seg.delEntryPtr(slotId, slot, idx) atomic.AddInt64(&seg.totalExpired, 1) err = ErrNotFound @@ -208,7 +208,7 @@ func (seg *segment) evacuate(entryLen int64, slotId uint8, now uint32) (slotModi seg.vacuumLen += oldEntryLen continue } - expired := oldHdr.expireAt != 0 && oldHdr.expireAt < now + expired := isExpired(oldHdr.expireAt, now) leastRecentUsed := int64(oldHdr.accessTime)*atomic.LoadInt64(&seg.totalCount) <= atomic.LoadInt64(&seg.totalTime) if expired || leastRecentUsed || consecutiveEvacuate > 5 { seg.delEntryPtrByOffset(oldHdr.slotId, oldHdr.hash16, oldOff) @@ -292,7 +292,7 @@ func (seg *segment) locate(key []byte, hashVal uint64, peek bool) (hdr *entryHdr hdr = (*entryHdr)(unsafe.Pointer(&hdrBuf[0])) if !peek { now := seg.timer.Now() - if hdr.expireAt != 0 && hdr.expireAt <= now { + if isExpired(hdr.expireAt, now) { seg.delEntryPtr(slotId, slot, idx) atomic.AddInt64(&seg.totalExpired, 1) err = ErrNotFound @@ -328,18 +328,19 @@ func (seg *segment) ttl(key []byte, hashVal uint64) (timeLeft uint32, err error) return } ptr := &slot[idx] - now := seg.timer.Now() var hdrBuf [ENTRY_HDR_SIZE]byte seg.rb.ReadAt(hdrBuf[:], ptr.offset) hdr := (*entryHdr)(unsafe.Pointer(&hdrBuf[0])) if hdr.expireAt == 0 { - timeLeft = 0 - return - } else if hdr.expireAt != 0 && hdr.expireAt >= now { - timeLeft = hdr.expireAt - now return + } else { + now := seg.timer.Now() + if !isExpired(hdr.expireAt, now) { + timeLeft = hdr.expireAt - now + return + } } err = ErrNotFound return @@ -477,3 +478,8 @@ func (seg *segment) getSlot(slotId uint8) []entryPtr { slotOff := int32(slotId) * seg.slotCap return seg.slotsData[slotOff : slotOff+seg.slotLens[slotId] : slotOff+seg.slotCap] } + +// isExpired checks if a key is expired. +func isExpired(keyExpireAt, now uint32) bool { + return keyExpireAt != 0 && keyExpireAt <= now +}