Skip to content

Commit

Permalink
feat: added csrf, xss and caching middlewares
Browse files Browse the repository at this point in the history
  • Loading branch information
hokamsingh committed Aug 26, 2024
1 parent 8ac168d commit d8aa3b2
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 16 deletions.
54 changes: 42 additions & 12 deletions internal/core/middleware/ratelimiter.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,19 @@ import (
)

type RateLimiter struct {
requests map[string]int
mu sync.Mutex
limit int
interval time.Duration
requests map[string][]time.Time
mu sync.Mutex
limit int
interval time.Duration
cleanupInterval time.Duration
}

func NewRateLimiter(limit int, interval time.Duration) *RateLimiter {
func NewRateLimiter(limit int, interval, cleanupInterval time.Duration) *RateLimiter {
rl := &RateLimiter{
requests: make(map[string]int),
limit: limit,
interval: interval,
requests: make(map[string][]time.Time),
limit: limit,
interval: interval,
cleanupInterval: cleanupInterval,
}
go rl.cleanup()
return rl
Expand All @@ -27,22 +29,50 @@ func (rl *RateLimiter) Handle(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
rl.mu.Lock()
key := r.RemoteAddr
if rl.requests[key] >= rl.limit {
now := time.Now()

// Filter out expired timestamps
requests := rl.requests[key]
var newRequests []time.Time
for _, reqTime := range requests {
if now.Sub(reqTime) < rl.interval {
newRequests = append(newRequests, reqTime)
}
}
rl.requests[key] = newRequests

if len(newRequests) >= rl.limit {
http.Error(w, "Rate limit exceeded", http.StatusTooManyRequests)
rl.mu.Unlock()
return
}
rl.requests[key]++

// Add current request timestamp
rl.requests[key] = append(rl.requests[key], now)
rl.mu.Unlock()

next.ServeHTTP(w, r)
})
}

func (rl *RateLimiter) cleanup() {
for {
time.Sleep(rl.interval)
time.Sleep(rl.cleanupInterval)
rl.mu.Lock()
rl.requests = make(map[string]int)
now := time.Now()
for key, timestamps := range rl.requests {
var validTimestamps []time.Time
for _, reqTime := range timestamps {
if now.Sub(reqTime) < rl.interval {
validTimestamps = append(validTimestamps, reqTime)
}
}
if len(validTimestamps) > 0 {
rl.requests[key] = validTimestamps
} else {
delete(rl.requests, key)
}
}
rl.mu.Unlock()
}
}
4 changes: 2 additions & 2 deletions internal/core/router/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,9 +113,9 @@ func WithCORS(options middleware.CORSOptions) Option {
// Example usage:
//
// r := router.NewRouter(router.WithRateLimiter(100, time.Minute))
func WithRateLimiter(limit int, interval time.Duration) Option {
func WithRateLimiter(limit int, interval, cleanupInterval time.Duration) Option {
return func(r *Router) {
rateLimiter := middleware.NewRateLimiter(limit, interval)
rateLimiter := middleware.NewRateLimiter(limit, interval, cleanupInterval)
r.Use(rateLimiter)
}
}
Expand Down
4 changes: 2 additions & 2 deletions pkg/lessgo/less.go
Original file line number Diff line number Diff line change
Expand Up @@ -144,8 +144,8 @@ func WithCORS(options middleware.CORSOptions) router.Option {
return router.WithCORS(options)
}

func WithRateLimiter(limit int, interval time.Duration) router.Option {
return router.WithRateLimiter(limit, interval)
func WithRateLimiter(limit int, interval, cleanupInterval time.Duration) router.Option {
return router.WithRateLimiter(limit, interval, cleanupInterval)
}

type ParserOptions = middleware.ParserOptions
Expand Down

0 comments on commit d8aa3b2

Please sign in to comment.