diff --git a/distributed.go b/distributed.go index 8e297fe..bbe286a 100644 --- a/distributed.go +++ b/distributed.go @@ -189,7 +189,7 @@ func (h Handler) syncDistributedRead(ctx context.Context) error { // distributedRateLimiting enforces limiter (keyed by rlKey) in consideration of all other instances in the cluster. // If the limit is exceeded, the response is prepared and the relevant error is returned. Otherwise, a reservation // is made in the local limiter and no error is returned. -func (h Handler) distributedRateLimiting(w http.ResponseWriter, repl *caddy.Replacer, limiter *ringBufferRateLimiter, rlKey, zoneName string) error { +func (h Handler) distributedRateLimiting(w http.ResponseWriter, r *http.Request, repl *caddy.Replacer, limiter *ringBufferRateLimiter, rlKey, zoneName string) error { maxAllowed := limiter.MaxEvents() window := limiter.Window() @@ -215,7 +215,7 @@ func (h Handler) distributedRateLimiting(w http.ResponseWriter, repl *caddy.Repl // no point in counting more if we're already over if totalCount >= maxAllowed { - return h.rateLimitExceeded(w, repl, zoneName, oldestEvent.Add(window).Sub(now())) + return h.rateLimitExceeded(w, r, repl, zoneName, oldestEvent.Add(window).Sub(now())) } } } @@ -237,7 +237,7 @@ func (h Handler) distributedRateLimiting(w http.ResponseWriter, repl *caddy.Repl limiter.mu.Unlock() // otherwise, it appears limit has been exceeded - return h.rateLimitExceeded(w, repl, zoneName, oldestEvent.Add(window).Sub(now())) + return h.rateLimitExceeded(w, r, repl, zoneName, oldestEvent.Add(window).Sub(now())) } type rlStateValue struct { diff --git a/handler.go b/handler.go index f60d014..a7ea0c9 100644 --- a/handler.go +++ b/handler.go @@ -19,6 +19,7 @@ import ( "encoding/json" "fmt" weakrand "math/rand" + "net" "net/http" "sort" "strconv" @@ -184,11 +185,11 @@ func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyhtt if h.Distributed == nil { // internal rate limiter only if dur := limiter.When(); dur > 0 { - return h.rateLimitExceeded(w, repl, rl.zoneName, dur) + return h.rateLimitExceeded(w, r, repl, rl.zoneName, dur) } } else { // distributed rate limiting; add last known state of other instances - if err := h.distributedRateLimiting(w, repl, limiter, key, rl.zoneName); err != nil { + if err := h.distributedRateLimiting(w, r, repl, limiter, key, rl.zoneName); err != nil { return err } } @@ -197,7 +198,7 @@ func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyhtt return next.ServeHTTP(w, r) } -func (h *Handler) rateLimitExceeded(w http.ResponseWriter, repl *caddy.Replacer, zoneName string, wait time.Duration) error { +func (h *Handler) rateLimitExceeded(w http.ResponseWriter, r *http.Request, repl *caddy.Replacer, zoneName string, wait time.Duration) error { // add jitter, if configured if h.random != nil { jitter := h.randomFloatInRange(0, float64(wait)*h.Jitter) @@ -207,6 +208,17 @@ func (h *Handler) rateLimitExceeded(w http.ResponseWriter, repl *caddy.Replacer, // add 0.5 to ceil() instead of round() which FormatFloat() does automatically w.Header().Set("Retry-After", strconv.FormatFloat(wait.Seconds()+0.5, 'f', 0, 64)) + // emit log about exceeding rate limit (see #37) + remoteIP, _, err := net.SplitHostPort(r.RemoteAddr) + if err != nil { + remoteIP = r.RemoteAddr // assume there was no port, I guess + } + h.logger.Info("rate limit exceeded", + zap.String("zone", zoneName), + zap.Duration("wait", wait), + zap.String("remote_ip", remoteIP), + ) + // make some information about this rate limit available repl.Set("http.rate_limit.exceeded.name", zoneName)