Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement threadsafe version of exec_within_threshold #1

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 51 additions & 9 deletions lib/ratelimit.rb
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def initialize(key, options = {})
# @return [Integer] The counter value
def add(subject, count = 1)
bucket = get_bucket
subject = "#{@key}:#{subject}"
subject = get_key_for_subject(subject)
redis.multi do
redis.hincrby(subject, bucket, count)
redis.hdel(subject, (bucket + 1) % @bucket_count)
Expand All @@ -55,14 +55,8 @@ def add(subject, count = 1)
# @param [Integer] interval How far back (in seconds) to retrieve activity.
def count(subject, interval)
bucket = get_bucket
interval = [[interval, @bucket_interval].max, @bucket_span].min
count = (interval / @bucket_interval).floor
subject = "#{@key}:#{subject}"

keys = (0..count - 1).map do |i|
(bucket - i) % @bucket_count
end
return redis.hmget(subject, *keys).inject(0) {|a, i| a + i.to_i}
keys = get_bucket_keys_for_interval(bucket, interval)
return redis.hmget(get_key_for_subject(subject), *keys).inject(0) {|a, i| a + i.to_i}
end

# Check if the rate limit has been exceeded.
Expand Down Expand Up @@ -108,12 +102,60 @@ def exec_within_threshold(subject, options = {}, &block)
yield(self)
end

# Execute a block and increment the count once the rate limit is within bounds.
# This fixes the concurrency issue found in exec_within_threshold
# *WARNING* This will block the current thread until the rate limit is within bounds.
#
# @param [String] subject Subject for this rate limit
# @param [Hash] options Options hash
# @option options [Integer] :interval How far back to retrieve activity.
# @option options [Integer] :threshold Maximum number of actions
# @option options [Integer] :increment
# @yield The block to be run
#
# @example Send an email as long as we haven't send 5 in the last 10 minutes
# ratelimit.exec_with_threshold(email, [:threshold => 5, :interval => 600, :increment => 1]) do
# send_another_email
# end
def exec_and_increment_within_threshold(subject, options = {}, &block)
options[:threshold] ||= 30
options[:interval] ||= 30
options[:increment] ||= 1
until count_incremented_within_threshold(subject, options)
sleep @bucket_interval
end
yield(self)
end

private

def get_bucket(time = Time.now.to_i)
((time % @bucket_span) / @bucket_interval).floor
end

def get_bucket_keys_for_interval(bucket, interval)
return [] if interval.nil?
interval = [[interval, @bucket_interval].max, @bucket_span].min
count = (interval / @bucket_interval).floor
(0..count - 1).map do |i|
(bucket - i) % @bucket_count
end
end

def get_key_for_subject(subject)
"#{@key}:#{subject}"
end

def count_incremented_within_threshold(subject, options)
bucket = get_bucket
keys = get_bucket_keys_for_interval(bucket, options[:interval])
burstKeys = get_bucket_keys_for_interval(bucket, options[:burst_interval])
evalScript = 'local a=KEYS[1]local b=tonumber(ARGV[1])local c=tonumber(ARGV[b+2])local d=b+c;local e=tonumber(ARGV[d+3])local f=tonumber(ARGV[d+4])local g=tonumber(ARGV[d+5])local h=tonumber(ARGV[d+6])local i=tonumber(ARGV[d+7])local j=tonumber(ARGV[d+8])or 0;local k=false;local l=false;local m=false;local n=0;if c>0 then local o=redis.call("HMGET",a,unpack(ARGV,b+3,d+2))for p,q in ipairs(o)do n=n+(tonumber(q)or 0)end;if n<j then l=true end end;local r=0;local s=redis.call("HMGET",a,unpack(ARGV,2,b+1))for p,q in ipairs(s)do r=r+(tonumber(q)or 0)end;if r<h then m=true end;if m or l then redis.call("HINCRBY",a,e,i)redis.call("HDEL",a,(e+1)%f)redis.call("HDEL",a,(e+2)%f)redis.call("EXPIRE",a,g)k=true end;return k'
evalKeys = [get_key_for_subject(subject)]
evalArgs = [keys.length, *keys, burstKeys.length, *burstKeys, bucket, @bucket_count, @bucket_expiry, options[:threshold], options[:increment], options[:burst_threshold]]
redis.eval(evalScript, evalKeys, evalArgs)
end

def redis
@redis ||= Redis::Namespace.new(:ratelimit, redis: @raw_redis || Redis.new)
end
Expand Down
45 changes: 45 additions & 0 deletions scripts/count_and_increment_within_threshold.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
local subject = KEYS[1];
local numKeys = tonumber(ARGV[1]);
local numBurstKeys = tonumber(ARGV[numKeys + 2]);
local totalKeys = numKeys + numBurstKeys;
local bucket = tonumber(ARGV[totalKeys + 3]);
local bucketCount = tonumber(ARGV[totalKeys + 4]);
local bucketExpiry = tonumber(ARGV[totalKeys + 5]);
local threshold = tonumber(ARGV[totalKeys + 6]);
local increment = tonumber(ARGV[totalKeys + 7]);
local burstThreshold = tonumber(ARGV[totalKeys + 8]) or 0;
local success = false;
local withinBurstThreshold = false;
local withinRegularThreshold = false;
local burstCount = 0;

if numBurstKeys > 0 then
local burstCounts = redis.call("HMGET", subject, unpack(ARGV, numKeys + 3, totalKeys + 2 ));
for key, value in ipairs(burstCounts) do
burstCount = burstCount + (tonumber(value) or 0)
end;

if burstCount < burstThreshold then
withinBurstThreshold = true;
end
end

local count = 0;
local counts = redis.call("HMGET", subject, unpack(ARGV, 2, numKeys + 1));
for key, value in ipairs(counts) do
count = count + (tonumber(value) or 0)
end;

if count < threshold then
withinRegularThreshold = true;
end

if withinRegularThreshold or withinBurstThreshold then
redis.call("HINCRBY", subject, bucket, increment);
redis.call("HDEL", subject, (bucket + 1) % bucketCount);
redis.call("HDEL", subject, (bucket + 2) % bucketCount);
redis.call("EXPIRE", subject, bucketExpiry);
success = true;
end

return success;
1 change: 1 addition & 0 deletions scripts/count_and_increment_within_threshold.min.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
local a=KEYS[1]local b=tonumber(ARGV[1])local c=tonumber(ARGV[b+2])local d=b+c;local e=tonumber(ARGV[d+3])local f=tonumber(ARGV[d+4])local g=tonumber(ARGV[d+5])local h=tonumber(ARGV[d+6])local i=tonumber(ARGV[d+7])local j=tonumber(ARGV[d+8])or 0;local k=false;local l=false;local m=false;local n=0;if c>0 then local o=redis.call("HMGET",a,unpack(ARGV,b+3,d+2))for p,q in ipairs(o)do n=n+(tonumber(q)or 0)end;if n<j then l=true end end;local r=0;local s=redis.call("HMGET",a,unpack(ARGV,2,b+1))for p,q in ipairs(s)do r=r+(tonumber(q)or 0)end;if r<h then m=true end;if m or l then redis.call("HINCRBY",a,e,i)redis.call("HDEL",a,(e+1)%f)redis.call("HDEL",a,(e+2)%f)redis.call("EXPIRE",a,g)k=true end;return k