Skip to content

Commit

Permalink
feat: improved caching and rate limiter
Browse files Browse the repository at this point in the history
  • Loading branch information
hokamsingh committed Aug 28, 2024
1 parent f8383f6 commit 7a96e28
Show file tree
Hide file tree
Showing 4 changed files with 321 additions and 61 deletions.
47 changes: 30 additions & 17 deletions internal/core/middleware/cacher.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
package middleware

import (
"bytes"
"context"
"io"
"log"
"net/http"
"time"
Expand All @@ -10,11 +12,12 @@ import (
)

type Caching struct {
client *redis.Client
ttl time.Duration
client *redis.Client
ttl time.Duration
cacheControl bool
}

func NewCaching(redisAddr string, ttl time.Duration) *Caching {
func NewCaching(redisAddr string, ttl time.Duration, cacheControl bool) *Caching {
ctx := context.Background()
client := redis.NewClient(&redis.Options{
Addr: redisAddr, // e.g., "localhost:6379"
Expand All @@ -24,36 +27,41 @@ func NewCaching(redisAddr string, ttl time.Duration) *Caching {
log.Fatalf("Could not connect to Redis: %v", err)
}
return &Caching{
client: client,
ttl: ttl,
client: client,
ttl: ttl,
cacheControl: cacheControl,
}
}

func (c *Caching) Handle(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := context.Background()

// Respect Cache-Control: no-store
if c.cacheControl && r.Header.Get("Cache-Control") == "no-store" {
next.ServeHTTP(w, r)
return
}

if r.Method == http.MethodGet {
// Try to get the cached response from Redis
data, err := c.client.Get(ctx, r.RequestURI).Result()
if err == nil {
// If found in cache, write it directly to the response
// Cache hit
w.Header().Set("X-Cache-Hit", "true")
w.Write([]byte(data))
io.WriteString(w, data)
return
} else if err != redis.Nil {
// Log any errors retrieving from Redis
log.Printf("Error retrieving from cache: %v", err)
}
}

// Create a response writer to capture the response
rec := &responseRecorder{ResponseWriter: w, statusCode: http.StatusOK}
// Capture response
rec := &responseRecorder{ResponseWriter: w, statusCode: http.StatusOK, body: new(bytes.Buffer)}
next.ServeHTTP(rec, r)

if r.Method == http.MethodGet {
// Cache the response in Redis
err := c.client.Set(ctx, r.RequestURI, rec.body, c.ttl).Err()
// Cache only successful responses (status code 200)
if r.Method == http.MethodGet && rec.statusCode == http.StatusOK {
err := c.client.Set(ctx, r.RequestURI, rec.body.String(), c.ttl).Err()
if err != nil {
log.Printf("Error setting cache: %v", err)
}
Expand All @@ -64,12 +72,17 @@ func (c *Caching) Handle(next http.Handler) http.Handler {
type responseRecorder struct {
http.ResponseWriter
statusCode int
body []byte
body *bytes.Buffer
}

func (rec *responseRecorder) Write(p []byte) (int, error) {
rec.body = append(rec.body, p...)
return rec.ResponseWriter.Write(p)
rec.body.Write(p) // Write to the buffer
return rec.ResponseWriter.Write(p) // Stream response to client
}

func (rec *responseRecorder) WriteHeader(statusCode int) {
rec.statusCode = statusCode
rec.ResponseWriter.WriteHeader(statusCode)
}

// Implement the Flush method
Expand Down
Loading

0 comments on commit 7a96e28

Please sign in to comment.