diff --git a/ratelimits/limiter.go b/ratelimits/limiter.go index 0039a60093d..0654787b6ec 100644 --- a/ratelimits/limiter.go +++ b/ratelimits/limiter.go @@ -277,31 +277,28 @@ func (l *Limiter) BatchSpend(ctx context.Context, txns []Transaction) (*Decision batchDecision := allowedDecision newBuckets := make(map[string]time.Time) incrBuckets := make(map[string]increment) + staleBuckets := make(map[string]time.Time) txnOutcomes := make(map[Transaction]string) for _, txn := range batch { - tat, bucketExists := tats[txn.bucketKey] - if !bucketExists { - // First request from this client. - tat = l.clk.Now() - } - - d := maybeSpend(l.clk, txn, tat) + storedTAT, bucketExists := tats[txn.bucketKey] + d := maybeSpend(l.clk, txn, storedTAT) if txn.limit.isOverride() { utilization := float64(txn.limit.Burst-d.remaining) / float64(txn.limit.Burst) l.overrideUsageGauge.WithLabelValues(txn.limit.name.String(), txn.limit.overrideKey).Set(utilization) } - if d.allowed && (tat != d.newTAT) && txn.spend { - // New bucket state should be persisted. - if bucketExists { + if d.allowed && (storedTAT != d.newTAT) && txn.spend { + if !bucketExists { + newBuckets[txn.bucketKey] = d.newTAT + } else if storedTAT.After(l.clk.Now()) { incrBuckets[txn.bucketKey] = increment{ cost: time.Duration(txn.cost * txn.limit.emissionInterval), ttl: time.Duration(txn.limit.burstOffset), } } else { - newBuckets[txn.bucketKey] = d.newTAT + staleBuckets[txn.bucketKey] = d.newTAT } } @@ -319,10 +316,24 @@ func (l *Limiter) BatchSpend(ctx context.Context, txns []Transaction) (*Decision if batchDecision.allowed { if len(newBuckets) > 0 { - err = l.source.BatchSet(ctx, newBuckets) + // Use BatchSetNotExisting to create new buckets so that we detect + // if concurrent requests have created this bucket at the same time, + // which would result in overwriting if we used a plain "SET" + // command. If that happens, fall back to incrementing. 
+ alreadyExists, err := l.source.BatchSetNotExisting(ctx, newBuckets) if err != nil { return nil, fmt.Errorf("batch set for %d keys: %w", len(newBuckets), err) } + // Find the original transactions for the keys that already existed, in + // order to compute their increments and set their TTLs. + for _, txn := range batch { + if alreadyExists[txn.bucketKey] { + incrBuckets[txn.bucketKey] = increment{ + cost: time.Duration(txn.cost * txn.limit.emissionInterval), + ttl: time.Duration(txn.limit.burstOffset), + } + } + } } if len(incrBuckets) > 0 { @@ -331,6 +342,17 @@ func (l *Limiter) BatchSpend(ctx context.Context, txns []Transaction) (*Decision return nil, fmt.Errorf("batch increment for %d keys: %w", len(incrBuckets), err) } } + + if len(staleBuckets) > 0 { + // Incrementing a TAT in the past grants unintended burst capacity. + // So instead we overwrite it with a TAT of now + increment. This + // approach may cause a race condition where only the last spend is + // saved, but it's preferable to the alternative. + err = l.source.BatchSet(ctx, staleBuckets) + if err != nil { + return nil, fmt.Errorf("batch set for %d keys: %w", len(staleBuckets), err) + } + } } // Observe latency equally across all transactions in the batch. diff --git a/ratelimits/source.go b/ratelimits/source.go index c69484c95af..74f3ae6b2f4 100644 --- a/ratelimits/source.go +++ b/ratelimits/source.go @@ -20,6 +20,11 @@ type Source interface { // the underlying storage client implementation). BatchSet(ctx context.Context, bucketKeys map[string]time.Time) error + // BatchSetNotExisting attempts to set TATs for the specified bucketKeys if + // they do not already exist. Returns a map indicating which keys already + // exist. + BatchSetNotExisting(ctx context.Context, buckets map[string]time.Time) (map[string]bool, error) + + // BatchIncrement updates the TATs for the specified bucketKeys, similar to + // BatchSet.
Implementations MUST ensure non-blocking operations by either: // a) applying a deadline or timeout to the context WITHIN the method, or @@ -79,6 +84,21 @@ func (in *inmem) BatchSet(_ context.Context, bucketKeys map[string]time.Time) er return nil } +func (in *inmem) BatchSetNotExisting(_ context.Context, bucketKeys map[string]time.Time) (map[string]bool, error) { + in.Lock() + defer in.Unlock() + alreadyExists := make(map[string]bool, len(bucketKeys)) + for k, v := range bucketKeys { + _, ok := in.m[k] + if ok { + alreadyExists[k] = true + } else { + in.m[k] = v + } + } + return alreadyExists, nil +} + func (in *inmem) BatchIncrement(_ context.Context, bucketKeys map[string]increment) error { in.Lock() defer in.Unlock() diff --git a/ratelimits/source_redis.go b/ratelimits/source_redis.go index b55ee9739eb..4d32f7c2a6d 100644 --- a/ratelimits/source_redis.go +++ b/ratelimits/source_redis.go @@ -108,6 +108,43 @@ func (r *RedisSource) BatchSet(ctx context.Context, buckets map[string]time.Time return nil } +// BatchSetNotExisting attempts to set TATs for the specified bucketKeys if they +// do not already exist. Returns a map indicating which keys already existed. +func (r *RedisSource) BatchSetNotExisting(ctx context.Context, buckets map[string]time.Time) (map[string]bool, error) { + start := r.clk.Now() + + pipeline := r.client.Pipeline() + cmds := make(map[string]*redis.BoolCmd, len(buckets)) + for bucketKey, tat := range buckets { + // Set a TTL that expires the key 10 minutes after the TAT, to account + // for clock skew.
+ ttl := tat.UTC().Sub(r.clk.Now()) + 10*time.Minute + cmds[bucketKey] = pipeline.SetNX(ctx, bucketKey, tat.UTC().UnixNano(), ttl) + } + _, err := pipeline.Exec(ctx) + if err != nil { + r.observeLatency("batchsetnotexisting", r.clk.Since(start), err) + return nil, err + } + + alreadyExists := make(map[string]bool, len(buckets)) + totalLatency := r.clk.Since(start) + perSetLatency := totalLatency / time.Duration(len(buckets)) + for bucketKey, cmd := range cmds { + success, err := cmd.Result() + if err != nil { + r.observeLatency("batchsetnotexisting_entry", perSetLatency, err) + return nil, err + } + if !success { + alreadyExists[bucketKey] = true + } + r.observeLatency("batchsetnotexisting_entry", perSetLatency, nil) + } + + r.observeLatency("batchsetnotexisting", totalLatency, nil) + return alreadyExists, nil +} + // BatchIncrement updates TATs for the specified bucketKeys using a pipelined // Redis Transaction in order to reduce the number of round-trips to each Redis // shard.