Skip to content

Commit

Permalink
Fix retry-after value in distributed case (#34)
Browse files Browse the repository at this point in the history
* Fix `retry-after` value in distributed case

The basic approach in #30 -- to sync around the timestamp of the oldest
event -- is sound, but the implementation was flawed, because the spot
before the cursor is not always the oldest event in the ring buffer. We
now correctly compute that value while counting events in the window
(which we had to do in order to sync event counts anyway). Additionally,
this commit adds tests for the distributed rate limiter which simulate a
peer by constructing a `ringBufferRateLimiter`, writing it out to
storage, and then starting up a `caddytest.Tester` on that storage.

* Add test for ringbuffer

* create AppDataDir

* debugging: print storage path to see what's up on windows

* escape storage path on windows

* strip out debug print

---------

Co-authored-by: Matt Holt <[email protected]>
  • Loading branch information
tgeoghegan and mholt authored Dec 19, 2023
1 parent 81d4916 commit 8aeaea3
Show file tree
Hide file tree
Showing 6 changed files with 327 additions and 34 deletions.
54 changes: 36 additions & 18 deletions distributed.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"time"

"github.com/caddyserver/caddy/v2"
"github.com/caddyserver/certmagic"
"go.uber.org/zap"
)

Expand Down Expand Up @@ -91,26 +92,38 @@ func (h Handler) syncDistributedWrite(ctx context.Context) error {
zoneNameStr := zoneName.(string)
zoneLimiters := value.(*sync.Map)

state.Zones[zoneNameStr] = make(map[string]rlStateValue)
state.Zones[zoneNameStr] = rlStateForZone(zoneLimiters, state.Timestamp)

// iterate all limiters within zone
zoneLimiters.Range(func(key, value interface{}) bool {
if value == nil {
return true
}
rl := value.(*ringBufferRateLimiter)
return true
})

state.Zones[zoneNameStr][key.(string)] = rlStateValue{
Count: rl.Count(state.Timestamp),
OldestEvent: rl.OldestEvent(),
}
return writeRateLimitState(ctx, state, h.Distributed.instanceID, h.storage)
}

func rlStateForZone(zoneLimiters *sync.Map, timestamp time.Time) map[string]rlStateValue {
state := make(map[string]rlStateValue)

// iterate all limiters within zone
zoneLimiters.Range(func(key, value interface{}) bool {
if value == nil {
return true
})
}
rl := value.(*ringBufferRateLimiter)

count, oldestEvent := rl.Count(timestamp)

state[key.(string)] = rlStateValue{
Count: count,
OldestEvent: oldestEvent,
}

return true
})

return state
}

func writeRateLimitState(ctx context.Context, state rlState, instanceID string, storage certmagic.Storage) error {
buf := gobBufPool.Get().(*bytes.Buffer)
buf.Reset()
defer gobBufPool.Put(buf)
Expand All @@ -119,7 +132,8 @@ func (h Handler) syncDistributedWrite(ctx context.Context) error {
if err != nil {
return err
}
err = h.storage.Store(ctx, path.Join(storagePrefix, h.Distributed.instanceID+".rlstate"), buf.Bytes())

err = storage.Store(ctx, path.Join(storagePrefix, instanceID+".rlstate"), buf.Bytes())
if err != nil {
return err
}
Expand Down Expand Up @@ -180,7 +194,7 @@ func (h Handler) distributedRateLimiting(w http.ResponseWriter, repl *caddy.Repl
window := limiter.Window()

var totalCount int
oldestEvent := limiter.OldestEvent()
oldestEvent := now()

h.Distributed.otherStatesMu.RLock()
defer h.Distributed.otherStatesMu.RUnlock()
Expand All @@ -195,13 +209,13 @@ func (h Handler) distributedRateLimiting(w http.ResponseWriter, repl *caddy.Repl
if zone, ok := otherInstanceState.Zones[zoneName]; ok {
// TODO: could probably skew the numbers here based on timestamp and window... perhaps try to predict a better updated count
totalCount += zone[rlKey].Count
if zone[rlKey].OldestEvent.Before(oldestEvent) {
if zone[rlKey].OldestEvent.Before(oldestEvent) && zone[rlKey].OldestEvent.After(now().Add(-window)) {
oldestEvent = zone[rlKey].OldestEvent
}

// no point in counting more if we're already over
if totalCount >= maxAllowed {
return h.rateLimitExceeded(w, repl, zoneName, time.Until(oldestEvent.Add(window)))
return h.rateLimitExceeded(w, repl, zoneName, oldestEvent.Add(window).Sub(now()))
}
}
}
Expand All @@ -210,7 +224,11 @@ func (h Handler) distributedRateLimiting(w http.ResponseWriter, repl *caddy.Repl
// so the critical section over this limiter's lock is smaller), and make the
// reservation if we're within the limit
limiter.mu.Lock()
totalCount += limiter.countUnsynced(now())
count, oldestLocalEvent := limiter.countUnsynced(now())
totalCount += count
if oldestLocalEvent.Before(oldestEvent) && oldestLocalEvent.After(now().Add(-window)) {
oldestEvent = oldestLocalEvent
}
if totalCount < maxAllowed {
limiter.reserve()
limiter.mu.Unlock()
Expand All @@ -219,7 +237,7 @@ func (h Handler) distributedRateLimiting(w http.ResponseWriter, repl *caddy.Repl
limiter.mu.Unlock()

// otherwise, it appears limit has been exceeded
return h.rateLimitExceeded(w, repl, zoneName, time.Until(oldestEvent.Add(window)))
return h.rateLimitExceeded(w, repl, zoneName, oldestEvent.Add(window).Sub(now()))
}

type rlStateValue struct {
Expand Down
186 changes: 186 additions & 0 deletions distributed_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
// Copyright 2023 Matthew Holt

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at

// http://www.apache.org/licenses/LICENSE-2.0

// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package caddyrl

import (
"context"
"fmt"
"os"
"strings"
"sync"
"testing"
"time"

"github.com/caddyserver/caddy/v2"
"github.com/caddyserver/caddy/v2/caddytest"
"github.com/caddyserver/certmagic"
"github.com/google/uuid"
)

func TestDistributed(t *testing.T) {
initTime()
window := 60
maxEvents := 10

// Make sure AppDataDir exists, because otherwise the caddytest.Tester won't
// be able to generate an instance ID
if err := os.MkdirAll(caddy.AppDataDir(), 0700); err != nil {
t.Fatalf("failed to create app data dir %s: %s", caddy.AppDataDir(), err)
}

testCases := []struct {
name string
peerRequests int
peerStateTimeStamp time.Time
localRequests int
rateLimited bool
}{
// Request should be refused because a peer used up the rate limit
{
name: "peer-usage-in-window",
peerRequests: maxEvents,
peerStateTimeStamp: now(),
localRequests: 0,
rateLimited: true,
},
// Request should be allowed because while lots of requests are in the
// peer state, the timestamp is outside the window
{
name: "peer-usage-before-window",
peerStateTimeStamp: now().Add(-time.Duration(window + 1)),
localRequests: 0,
rateLimited: false,
},
// Request should be refused because local usage exceeds rate limit
{
name: "local-usage",
peerRequests: 0,
peerStateTimeStamp: now(),
localRequests: maxEvents,
rateLimited: true,
},
// Request should be refused because usage in peer and locally sum up to
// exceed rate limit
{
name: "both-usage",
peerRequests: maxEvents / 2,
peerStateTimeStamp: now(),
localRequests: maxEvents / 2,
rateLimited: true,
},
}

for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
storageDir := t.TempDir()
// Use a random UUID as the zone so that rate limits from multiple test runs
// collide with each other
zone := uuid.New().String()

// To simulate a peer in a rate limiting cluster, constuct a
// ringBufferRateLimiter, record a bunch of events in it, and then sync that
// state to storage.
parsedDuration, err := time.ParseDuration(fmt.Sprintf("%ds", window))
if err != nil {
t.Fatal("failed to parse duration")
}
var simulatedPeer ringBufferRateLimiter
simulatedPeer.initialize(maxEvents, parsedDuration)

for i := 0; i < testCase.peerRequests; i++ {
if when := simulatedPeer.When(); when != 0 {
t.Fatalf("event should be allowed")
}
}

zoneLimiters := new(sync.Map)
zoneLimiters.Store("static", &simulatedPeer)

rlState := rlState{
Timestamp: testCase.peerStateTimeStamp,
Zones: map[string]map[string]rlStateValue{
zone: rlStateForZone(zoneLimiters, now()),
},
}

storage := certmagic.FileStorage{
Path: storageDir,
}

if err := writeRateLimitState(context.Background(), rlState, "f92a00f1-050c-4353-83b1-8ccc2337c25b", &storage); err != nil {
t.Fatalf("failed to write state to storage: %s", err)
}

// For Windows, escape \ in storage path.
storageDir = strings.ReplaceAll(storageDir, `\`, `\\`)

// Run a caddytest.Tester that uses the same storage we just wrote to, so it
// will treat the generated state as a peer to sync from.
configString := `{
"admin": {"listen": "localhost:2999"},
"storage": {
"module": "file_system",
"root": "%s"
},
"apps": {
"http": {
"servers": {
"one": {
"listen": [":8080"],
"routes": [{
"handle": [
{
"handler": "rate_limit",
"rate_limits": {
"%s": {
"match": [{"method": ["GET"]}],
"key": "static",
"window": "%ds",
"max_events": %d
}
},
"distributed": {
"write_interval": "3600s",
"read_interval": "3600s"
}
},
{
"handler": "static_response",
"status_code": 200
}
]
}]
}
}
}
}
}`

testerConfig := fmt.Sprintf(configString, storageDir, zone, window, maxEvents)
tester := caddytest.NewTester(t)
tester.InitServer(testerConfig, "json")

for i := 0; i < testCase.localRequests; i++ {
tester.AssertGetResponse("http://localhost:8080", 200, "")
}

if testCase.rateLimited {
assert429Response(t, tester, int64(window))
} else {
tester.AssertGetResponse("http://localhost:8080", 200, "")
}
})
}
}
1 change: 1 addition & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,7 @@ github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7a
github.com/gogo/protobuf v1.2.0/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ=
github.com/gogo/protobuf v1.2.1/go.mod h1:hp+jE20tsWTFYpLwKvXlhS1hjn+gTNwPg2I6zVXpSg4=
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q=
github.com/golang/glog v1.0.0/go.mod h1:EWib/APOK0SL3dFbYqvxE3UYd8E6s1ouQ7iEp/0LWV4=
github.com/golang/glog v1.1.2 h1:DVjP2PbBOzHyzA+dn3WhHIq4NdVu3Q+pvivFICf/7fo=
github.com/golang/glog v1.1.2/go.mod h1:zR+okUeTbrL6EL3xHUDxZuEtGv04p5shwip1+mL/rLQ=
github.com/golang/groupcache v0.0.0-20160516000752-02826c3e7903/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
Expand Down
6 changes: 3 additions & 3 deletions handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ func initTime() {
}
}

func setTime(seconds int) {
func advanceTime(seconds int) {
now = func() time.Time {
return time.Unix(referenceTime+int64(seconds), 0)
}
Expand Down Expand Up @@ -104,13 +104,13 @@ func TestRateLimits(t *testing.T) {

// After advancing time by half the window, the retry-after value should
// change accordingly
setTime(window / 2)
advanceTime(window / 2)

assert429Response(t, tester, int64(window/2))

// Advance time beyond the window where the events occurred. We should now
// be able to make requests again.
setTime(window)
advanceTime(window)

tester.AssertGetResponse("http://localhost:8080", 200, "")
}
Expand Down
25 changes: 12 additions & 13 deletions ringbuffer.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,14 +62,6 @@ func (r *ringBufferRateLimiter) When() time.Duration {
return r.ring[r.cursor].Add(r.window).Sub(now())
}

// OldestEvent returns the time at which the oldest recorded event current in
// the ring buffer occurred.
func (r *ringBufferRateLimiter) OldestEvent() time.Time {
r.mu.Lock()
defer r.mu.Unlock()
return r.ring[r.cursor]
}

// allowed returns true if the event is allowed to happen right now.
// It does not wait. If the event is allowed, a reservation is made.
// It is NOT safe for concurrent use, so it must be called inside a
Expand Down Expand Up @@ -179,8 +171,10 @@ func (r *ringBufferRateLimiter) SetWindow(window time.Duration) {
r.mu.Unlock()
}

// Count counts how many events are in the window from the reference time.
func (r *ringBufferRateLimiter) Count(ref time.Time) int {
// Count counts how many events are in the window from the reference time and
// returns that value and the oldest event in the buffer (the zero value of
// time.Time if there are no events in the window).
func (r *ringBufferRateLimiter) Count(ref time.Time) (int, time.Time) {
r.mu.Lock()
defer r.mu.Unlock()
return r.countUnsynced(ref)
Expand All @@ -189,7 +183,8 @@ func (r *ringBufferRateLimiter) Count(ref time.Time) int {
// countUnsycned counts how many events are in the window from the reference time.
// It is NOT safe to use without a lock on r.mu.
// TODO: this is currently O(n) but could probably become O(log n) if we switch to some weird, custom binary search modulo ring length around the cursor.
func (r *ringBufferRateLimiter) countUnsynced(ref time.Time) int {
func (r *ringBufferRateLimiter) countUnsynced(ref time.Time) (int, time.Time) {
var zeroTime time.Time
beginningOfWindow := ref.Add(-r.window)

// This loop is a little gnarly, I know. We start at one before the cursor because that's
Expand All @@ -204,12 +199,16 @@ func (r *ringBufferRateLimiter) countUnsynced(ref time.Time) int {
// modulus the ring length to wrap around if necessary
i := (r.cursor + (len(r.ring) - eventsInWindow - 1)) % len(r.ring)
if r.ring[i].Before(beginningOfWindow) {
return eventsInWindow
if eventsInWindow == 0 {
return eventsInWindow, zeroTime
} else {
return eventsInWindow, r.ring[(i+1)%len(r.ring)]
}
}
}

// if we looped the entire ring, all events are within the window
return len(r.ring)
return len(r.ring), r.ring[r.cursor]
}

// Current time function, to be substituted by tests
Expand Down
Loading

0 comments on commit 8aeaea3

Please sign in to comment.