Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(cache): improve caching system and add expiring cache #408

Merged
merged 10 commits into from
Jan 13, 2025
185 changes: 173 additions & 12 deletions cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,67 +6,124 @@ import (
"io"
"os"
"path/filepath"
"strconv"
"strings"
"sync"
"time"

openai "github.com/sashabaranov/go-openai"
)

// CacheType represents the type of cache being used.
type CacheType string

// Cache types for different purposes.
const (
ConversationCache CacheType = "conversations"
TemporaryCache CacheType = "temp"
)

const cacheExt = ".gob"

var errInvalidID = errors.New("invalid id")

type convoCache struct {
dir string
// Cache is a generic cache implementation that stores data in files.
type Cache[T any] struct {
baseDir string
cType CacheType
}

func newCache(dir string) *convoCache {
return &convoCache{dir}
// NewCache creates a new cache instance with the specified base directory and cache type.
func NewCache[T any](baseDir string, cacheType CacheType) (*Cache[T], error) {
cacheDir := filepath.Join(baseDir, string(cacheType))
if err := os.MkdirAll(cacheDir, os.ModePerm); err != nil {
return nil, fmt.Errorf("create cache directory: %w", err)
}
return &Cache[T]{
baseDir: baseDir,
cType: cacheType,
}, nil
}

func (c *convoCache) read(id string, messages *[]openai.ChatCompletionMessage) error {
func (c *Cache[T]) cacheDir() string {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
func (c *Cache[T]) cacheDir() string {
func (c *Cache[T]) dir() string {

no nead to repeat cache I think

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done on main

return filepath.Join(c.baseDir, string(c.cType))
}

func (c *Cache[T]) Read(id string, readFn func(io.Reader) error) error {
if id == "" {
return fmt.Errorf("read: %w", errInvalidID)
}
file, err := os.Open(filepath.Join(c.dir, id+cacheExt))
file, err := os.Open(filepath.Join(c.cacheDir(), id+cacheExt))
if err != nil {
return fmt.Errorf("read: %w", err)
}
defer file.Close() //nolint:errcheck

if err := decode(file, messages); err != nil {
if err := readFn(file); err != nil {
return fmt.Errorf("read: %w", err)
}
return nil
}

func (c *convoCache) write(id string, messages *[]openai.ChatCompletionMessage) error {
func (c *Cache[T]) Write(id string, writeFn func(io.Writer) error) error {
if id == "" {
return fmt.Errorf("write: %w", errInvalidID)
}

file, err := os.Create(filepath.Join(c.dir, id+cacheExt))
file, err := os.Create(filepath.Join(c.cacheDir(), id+cacheExt))
if err != nil {
return fmt.Errorf("write: %w", err)
}
defer file.Close() //nolint:errcheck

if err := encode(file, messages); err != nil {
if err := writeFn(file); err != nil {
return fmt.Errorf("write: %w", err)
}

return nil
}

func (c *convoCache) delete(id string) error {
// Delete removes a cached item by its ID.
func (c *Cache[T]) Delete(id string) error {
if id == "" {
return fmt.Errorf("delete: %w", errInvalidID)
}
if err := os.Remove(filepath.Join(c.dir, id+cacheExt)); err != nil {
if err := os.Remove(filepath.Join(c.cacheDir(), id+cacheExt)); err != nil {
return fmt.Errorf("delete: %w", err)
}
return nil
}

type convoCache struct {
cache *Cache[[]openai.ChatCompletionMessage]
}

func newCache(dir string) *convoCache {
cache, err := NewCache[[]openai.ChatCompletionMessage](dir, ConversationCache)
if err != nil {
return nil
}
return &convoCache{
cache: cache,
}
}

func (c *convoCache) read(id string, messages *[]openai.ChatCompletionMessage) error {
return c.cache.Read(id, func(r io.Reader) error {
return decode(r, messages)
})
}

func (c *convoCache) write(id string, messages *[]openai.ChatCompletionMessage) error {
return c.cache.Write(id, func(w io.Writer) error {
return encode(w, messages)
})
}

func (c *convoCache) delete(id string) error {
return c.cache.Delete(id)
}

var _ chatCompletionReceiver = &cachedCompletionStream{}

type cachedCompletionStream struct {
Expand All @@ -76,6 +133,7 @@ type cachedCompletionStream struct {
}

func (c *cachedCompletionStream) Close() error { return nil }

func (c *cachedCompletionStream) Recv() (openai.ChatCompletionStreamResponse, error) {
c.m.Lock()
defer c.m.Unlock()
Expand All @@ -101,6 +159,7 @@ func (c *cachedCompletionStream) Recv() (openai.ChatCompletionStreamResponse, er
}

c.read++

return openai.ChatCompletionStreamResponse{
Choices: []openai.ChatCompletionStreamChoice{
{
Expand All @@ -112,3 +171,105 @@ func (c *cachedCompletionStream) Recv() (openai.ChatCompletionStreamResponse, er
},
}, nil
}

// ExpiringCache is a cache implementation that supports expiration of cached items.
type ExpiringCache[T any] struct {
cache *Cache[T]
}

// NewExpiringCache creates a new cache instance that supports item expiration.
func NewExpiringCache[T any]() (*ExpiringCache[T], error) {
cache, err := NewCache[T](config.CachePath, TemporaryCache)
if err != nil {
return nil, fmt.Errorf("create expiring cache: %w", err)
}
return &ExpiringCache[T]{cache: cache}, nil
}

func (c *ExpiringCache[T]) getCacheFilename(id string, expiresAt int64) string {
return fmt.Sprintf("%s.%d", id, expiresAt)
}

func (c *ExpiringCache[T]) Read(id string, readFn func(io.Reader) error) error {
pattern := fmt.Sprintf("%s.*", id)
matches, err := filepath.Glob(filepath.Join(c.cache.cacheDir(), pattern))
if err != nil {
return fmt.Errorf("failed to read read expiring cache: %w", err)
}

if len(matches) == 0 {
return fmt.Errorf("item not found")
}

filename := filepath.Base(matches[0])
parts := strings.Split(filename, ".")
expectedFilenameParts := 2 // name and expiration timestamp

if len(parts) != expectedFilenameParts {
return fmt.Errorf("invalid cache filename")
}

expiresAt, err := strconv.ParseInt(parts[1], 10, 64)
if err != nil {
return fmt.Errorf("invalid expiration timestamp")
}

if expiresAt < time.Now().Unix() {
if err := os.Remove(matches[0]); err != nil {
return fmt.Errorf("failed to remove expired cache file: %w", err)
}
return os.ErrNotExist
}

file, err := os.Open(matches[0])
if err != nil {
return fmt.Errorf("failed to open expiring cache file: %w", err)
}
defer func() {
if cerr := file.Close(); cerr != nil {
err = cerr
}
}()

return readFn(file)
}

func (c *ExpiringCache[T]) Write(id string, expiresAt int64, writeFn func(io.Writer) error) error {
pattern := fmt.Sprintf("%s.*", id)
oldFiles, _ := filepath.Glob(filepath.Join(c.cache.cacheDir(), pattern))
for _, file := range oldFiles {
if err := os.Remove(file); err != nil {
return fmt.Errorf("failed to remove old cache file: %w", err)
}
}

filename := c.getCacheFilename(id, expiresAt)
file, err := os.Create(filepath.Join(c.cache.cacheDir(), filename))
if err != nil {
return fmt.Errorf("failed to create expiring cache file: %w", err)
}
defer func() {
if cerr := file.Close(); cerr != nil {
err = cerr
}
}()

return writeFn(file)
}

// Delete removes an expired cached item by its ID.
func (c *ExpiringCache[T]) Delete(id string) error {
pattern := fmt.Sprintf("%s.*", id)
matches, err := filepath.Glob(filepath.Join(c.cache.cacheDir(), pattern))
if err != nil {
return fmt.Errorf("failed to delete expiring cache: %w", err)
}

for _, match := range matches {
if err := os.Remove(match); err != nil {
return fmt.Errorf("failed to delete expiring cache file: %w", err)
}
}

return nil
}
90 changes: 90 additions & 0 deletions cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"path/filepath"
"strings"
"testing"
"time"

"github.com/sashabaranov/go-openai"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -124,3 +125,92 @@ func TestCachedCompletionStream(t *testing.T) {

require.Equal(t, string(bytes.ReplaceAll(bts, []byte("\r\n"), []byte("\n"))), content)
}

func TestExpiringCache(t *testing.T) {
t.Run("write and read", func(t *testing.T) {
config.CachePath = t.TempDir()
cache, err := NewExpiringCache[string]()
require.NoError(t, err)

// Write a value with expiry
data := "test data"
expiresAt := time.Now().Add(time.Hour).Unix()
err = cache.Write("test", expiresAt, func(w io.Writer) error {
_, err := w.Write([]byte(data))
return err
})
require.NoError(t, err)

// Read it back
var result string
err = cache.Read("test", func(r io.Reader) error {
b, err := io.ReadAll(r)
if err != nil {
return err
}
result = string(b)
return nil
})
require.NoError(t, err)
require.Equal(t, data, result)
})

t.Run("expired token", func(t *testing.T) {
config.CachePath = t.TempDir()
cache, err := NewExpiringCache[string]()
require.NoError(t, err)

// Write a value that's already expired
data := "test data"
expiresAt := time.Now().Add(-time.Hour).Unix() // expired 1 hour ago
err = cache.Write("test", expiresAt, func(w io.Writer) error {
_, err := w.Write([]byte(data))
return err
})
require.NoError(t, err)

// Try to read it
err = cache.Read("test", func(r io.Reader) error {
return nil
})
require.Error(t, err)
require.True(t, os.IsNotExist(err))
})

t.Run("overwrite token", func(t *testing.T) {
config.CachePath = t.TempDir()
cache, err := NewExpiringCache[string]()
require.NoError(t, err)

// Write initial value
data1 := "test data 1"
expiresAt1 := time.Now().Add(time.Hour).Unix()
err = cache.Write("test", expiresAt1, func(w io.Writer) error {
_, err := w.Write([]byte(data1))
return err
})
require.NoError(t, err)

// Write new value
data2 := "test data 2"
expiresAt2 := time.Now().Add(2 * time.Hour).Unix()
err = cache.Write("test", expiresAt2, func(w io.Writer) error {
_, err := w.Write([]byte(data2))
return err
})
require.NoError(t, err)

// Read it back - should get the new value
var result string
err = cache.Read("test", func(r io.Reader) error {
b, err := io.ReadAll(r)
if err != nil {
return err
}
result = string(b)
return nil
})
require.NoError(t, err)
require.Equal(t, data2, result)
})
}
2 changes: 1 addition & 1 deletion config.go
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ func ensureConfig() (Config, error) {
}

if c.CachePath == "" {
c.CachePath = filepath.Join(xdg.DataHome, "mods", "conversations")
c.CachePath = filepath.Join(xdg.DataHome, "mods")
}

if err := os.MkdirAll(c.CachePath, 0o700); err != nil { //nolint:mnd
Expand Down
Loading
Loading