From d8aa3b26b28835082522ab7365588cd6a4ff1af9 Mon Sep 17 00:00:00 2001 From: hokamsingh Date: Mon, 26 Aug 2024 18:43:26 +0530 Subject: [PATCH] feat: added csrf, xss and caching middlewares --- internal/core/middleware/ratelimiter.go | 54 +++++++++++++++++++------ internal/core/router/router.go | 4 +- pkg/lessgo/less.go | 4 +- 3 files changed, 46 insertions(+), 16 deletions(-) diff --git a/internal/core/middleware/ratelimiter.go b/internal/core/middleware/ratelimiter.go index 5688e07..bf90361 100644 --- a/internal/core/middleware/ratelimiter.go +++ b/internal/core/middleware/ratelimiter.go @@ -7,17 +7,19 @@ import ( ) type RateLimiter struct { - requests map[string]int - mu sync.Mutex - limit int - interval time.Duration + requests map[string][]time.Time + mu sync.Mutex + limit int + interval time.Duration + cleanupInterval time.Duration } -func NewRateLimiter(limit int, interval time.Duration) *RateLimiter { +func NewRateLimiter(limit int, interval, cleanupInterval time.Duration) *RateLimiter { rl := &RateLimiter{ - requests: make(map[string]int), - limit: limit, - interval: interval, + requests: make(map[string][]time.Time), + limit: limit, + interval: interval, + cleanupInterval: cleanupInterval, } go rl.cleanup() return rl @@ -27,22 +29,50 @@ func (rl *RateLimiter) Handle(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { rl.mu.Lock() key := r.RemoteAddr - if rl.requests[key] >= rl.limit { + now := time.Now() + + // Filter out expired timestamps + requests := rl.requests[key] + var newRequests []time.Time + for _, reqTime := range requests { + if now.Sub(reqTime) < rl.interval { + newRequests = append(newRequests, reqTime) + } + } + rl.requests[key] = newRequests + + if len(newRequests) >= rl.limit { http.Error(w, "Rate limit exceeded", http.StatusTooManyRequests) rl.mu.Unlock() return } - rl.requests[key]++ + + // Add current request timestamp + rl.requests[key] = append(rl.requests[key], now) rl.mu.Unlock() + next.ServeHTTP(w, r) }) } func (rl *RateLimiter) cleanup() { for { - time.Sleep(rl.interval) + time.Sleep(rl.cleanupInterval) rl.mu.Lock() - rl.requests = make(map[string]int) + now := time.Now() + for key, timestamps := range rl.requests { + var validTimestamps []time.Time + for _, reqTime := range timestamps { + if now.Sub(reqTime) < rl.interval { + validTimestamps = append(validTimestamps, reqTime) + } + } + if len(validTimestamps) > 0 { + rl.requests[key] = validTimestamps + } else { + delete(rl.requests, key) + } + } rl.mu.Unlock() } } diff --git a/internal/core/router/router.go b/internal/core/router/router.go index 6378b85..f69bc8f 100644 --- a/internal/core/router/router.go +++ b/internal/core/router/router.go @@ -113,9 +113,9 @@ func WithCORS(options middleware.CORSOptions) Option { // Example usage: // // r := router.NewRouter(router.WithRateLimiter(100, time.Minute)) -func WithRateLimiter(limit int, interval time.Duration) Option { +func WithRateLimiter(limit int, interval, cleanupInterval time.Duration) Option { return func(r *Router) { - rateLimiter := middleware.NewRateLimiter(limit, interval) + rateLimiter := middleware.NewRateLimiter(limit, interval, cleanupInterval) r.Use(rateLimiter) } } diff --git a/pkg/lessgo/less.go b/pkg/lessgo/less.go index 60dcb62..e0327f7 100644 --- a/pkg/lessgo/less.go +++ b/pkg/lessgo/less.go @@ -144,8 +144,8 @@ func WithCORS(options middleware.CORSOptions) router.Option { return router.WithCORS(options) } -func WithRateLimiter(limit int, interval time.Duration) router.Option { - return router.WithRateLimiter(limit, interval) +func WithRateLimiter(limit int, interval, cleanupInterval time.Duration) router.Option { + return router.WithRateLimiter(limit, interval, cleanupInterval) } type ParserOptions = middleware.ParserOptions