diff --git a/docs/util.rst b/docs/util.rst index 560bbf43..5f18a0cb 100644 --- a/docs/util.rst +++ b/docs/util.rst @@ -138,11 +138,15 @@ Example limiter = RateLimiter('MyRateLimiterName') # define limits, duplicate limits of the same algorithm will only be added once + # These lines all define the same limit so it'll result in only one limiter added limiter.add_limit(5, 60) # add limits explicitly limiter.parse_limits('5 per minute').parse_limits('5 in 60s', '5/60seconds') # add limits through text # add additional limit with leaky bucket algorithm - limiter.add_limit(10, 120, algorithm='leaky_bucket') + limiter.add_limit(10, 100, algorithm='leaky_bucket') + + # add additional limit with fixed window elastic expiry algorithm + limiter.add_limit(10, 100, algorithm='fixed_window_elastic_expiry') # Test the limit without increasing the hits for _ in range(100): diff --git a/src/HABApp/util/rate_limiter/limiter.py b/src/HABApp/util/rate_limiter/limiter.py index 487e0c1a..1d743b89 100644 --- a/src/HABApp/util/rate_limiter/limiter.py +++ b/src/HABApp/util/rate_limiter/limiter.py @@ -1,17 +1,10 @@ from dataclasses import dataclass -from typing import Final, List, Literal, Tuple, Union - -from HABApp.core.const.const import PYTHON_310, StrEnum - -from .limits import ( - BaseRateLimit, - FixedWindowElasticExpiryLimit, - FixedWindowElasticExpiryLimitInfo, - LeakyBucketLimit, - LeakyBucketLimitInfo, -) -from .parser import parse_limit +from typing import Final, List, Literal, Tuple, Union, get_args +from HABApp.core.const.const import PYTHON_310 +from HABApp.util.rate_limiter.limits import BaseRateLimit, FixedWindowElasticExpiryLimit, \ + FixedWindowElasticExpiryLimitInfo, LeakyBucketLimit, LeakyBucketLimitInfo +from HABApp.util.rate_limiter.parser import parse_limit if PYTHON_310: from typing import TypeAlias @@ -19,12 +12,16 @@ from typing_extensions import TypeAlias -class LimitTypeEnum(StrEnum): - LEAKY_BUCKET = 'leaky_bucket' - FIXED_WINDOW_ELASTIC_EXPIRY = 'fixed_window_elastic_expiry' +_LITERAL_LEAKY_BUCKET = Literal['leaky_bucket'] +_LITERAL_FIXED_WINDOW_ELASTIC_EXPIRY = Literal['fixed_window_elastic_expiry'] + +LIMITER_ALGORITHM_HINT: TypeAlias = Literal[_LITERAL_LEAKY_BUCKET, _LITERAL_FIXED_WINDOW_ELASTIC_EXPIRY] -LIMITER_ALGORITHM_HINT: TypeAlias = Literal[LimitTypeEnum.LEAKY_BUCKET, LimitTypeEnum.FIXED_WINDOW_ELASTIC_EXPIRY] +def _check_arg(name: str, value, allow_0=False): + if not isinstance(value, int) or ((value <= 0) if not allow_0 else (value < 0)): + msg = f'Parameter {name:s} must be an int >{"=" if allow_0 else ""} 0, is {value} ({type(value)})' + raise ValueError(msg) class Limiter: @@ -36,66 +33,70 @@ def __init__(self, name: str): def __repr__(self): return f'<{self.__class__.__name__} {self._name:s}>' - def add_limit(self, allowed: int, interval: int, - algorithm: LIMITER_ALGORITHM_HINT = LimitTypeEnum.FIXED_WINDOW_ELASTIC_EXPIRY) -> 'Limiter': + def add_limit(self, allowed: int, interval: int, *, + hits: int = 0, + algorithm: LIMITER_ALGORITHM_HINT = 'leaky_bucket') -> 'Limiter': """Add a new rate limit :param allowed: How many hits are allowed :param interval: Interval in seconds + :param hits: How many hits the limit already has when it gets initially created :param algorithm: Which algorithm should this limit use """ - if allowed <= 0 or not isinstance(allowed, int): - msg = f'Allowed must be an int >= 0, is {allowed} ({type(allowed)})' - raise ValueError(msg) - - if interval <= 0 or not isinstance(interval, int): - msg = f'Expire time must be an int >= 0, is {interval} ({type(interval)})' + _check_arg('allowed', allowed) + _check_arg('interval', interval) + _check_arg('hits', hits, allow_0=True) + if not hits <= allowed: + msg = f'Parameter hits must be <= parameter allowed! {hits:d} <= {allowed:d}!' raise ValueError(msg) - algo = LimitTypeEnum(algorithm) - if algo is LimitTypeEnum.FIXED_WINDOW_ELASTIC_EXPIRY: - cls = FixedWindowElasticExpiryLimit - elif algo is LimitTypeEnum.LEAKY_BUCKET: + if algorithm == get_args(_LITERAL_LEAKY_BUCKET)[0]: cls = LeakyBucketLimit + elif algorithm == get_args(_LITERAL_FIXED_WINDOW_ELASTIC_EXPIRY)[0]: + cls = FixedWindowElasticExpiryLimit else: - raise ValueError() + msg = f'Unknown algorithm "{algorithm}"' + raise ValueError(msg) # Check if we have already added an algorithm with these parameters for window in self._limits: if isinstance(window, cls) and window.allowed == allowed and window.interval == interval: return self - limit = cls(allowed, interval) + limit = cls(allowed, interval, hits=hits) self._limits = tuple(sorted([*self._limits, limit], key=lambda x: x.interval)) return self def parse_limits(self, *text: str, - algorithm: LIMITER_ALGORITHM_HINT = LimitTypeEnum.FIXED_WINDOW_ELASTIC_EXPIRY) -> 'Limiter': + hits: int = 0, + algorithm: LIMITER_ALGORITHM_HINT = 'leaky_bucket') -> 'Limiter': """Add one or more limits in textual form, e.g. ``5 in 60s``, ``10 per hour`` or ``10/15 mins``. If the limit does already exist it will not be added again. :param text: textual description of limit + :param hits: How many hits the limit already has when it gets initially created :param algorithm: Which algorithm should these limits use """ for limit in [parse_limit(t) for t in text]: - self.add_limit(*limit, algorithm=algorithm) + self.add_limit(*limit, hits=hits, algorithm=algorithm) return self def allow(self) -> bool: - """Test the limit. + """Test the limit(s). :return: ``True`` if allowed, ``False`` if forbidden """ - allow = True - clear_skipped = True - if not self._limits: msg = 'No limits defined!' raise ValueError(msg) + clear_skipped = True + for limit in self._limits: if not limit.allow(): - allow = False + self._skips += 1 + return False + # allow increments hits, if it's now 1 it was 0 before if limit.hits != 1: clear_skipped = False @@ -103,32 +104,30 @@ def allow(self) -> bool: if clear_skipped: self._skips = 0 - if not allow: - self._skips += 1 - - return allow + return True def test_allow(self) -> bool: - """Test the limit without hitting it. Calling this will not increase the hit counter. + """Test the limit(s) without hitting it. Calling this will not increase the hit counter. :return: ``True`` if allowed, ``False`` if forbidden """ - allow = True - clear_skipped = True if not self._limits: msg = 'No limits defined!' raise ValueError(msg) + clear_skipped = True + for limit in self._limits: if not limit.test_allow(): - allow = False + return False + if limit.hits != 0: clear_skipped = False if clear_skipped: self._skips = 0 - return allow + return True def info(self) -> 'LimiterInfo': """Get some info about the limiter and the defined windows diff --git a/src/HABApp/util/rate_limiter/limits/base.py b/src/HABApp/util/rate_limiter/limits/base.py index ba68c156..ef8ae87d 100644 --- a/src/HABApp/util/rate_limiter/limits/base.py +++ b/src/HABApp/util/rate_limiter/limits/base.py @@ -14,15 +14,16 @@ def hits_remaining(self) -> int: class BaseRateLimit: - def __init__(self, allowed: int, interval: int): + def __init__(self, allowed: int, interval: int, hits: int = 0): super().__init__() assert allowed > 0, allowed assert interval > 0, interval + assert 0 <= hits <= allowed self.interval: Final = interval self.allowed: Final = allowed - self.hits: int = 0 + self.hits: int = hits self.skips: int = 0 def repr_text(self) -> str: diff --git a/src/HABApp/util/rate_limiter/limits/fixed_window.py b/src/HABApp/util/rate_limiter/limits/fixed_window.py index 48b15ba7..95cc4328 100644 --- a/src/HABApp/util/rate_limiter/limits/fixed_window.py +++ b/src/HABApp/util/rate_limiter/limits/fixed_window.py @@ -10,8 +10,8 @@ class FixedWindowElasticExpiryLimitInfo(BaseRateLimitInfo): class FixedWindowElasticExpiryLimit(BaseRateLimit): - def __init__(self, allowed: int, interval: int): - super().__init__(allowed, interval) + def __init__(self, allowed: int, interval: int, hits: int = 0): + super().__init__(allowed, interval, hits) self.start: float = -1.0 self.stop: float = -1.0 diff --git a/src/HABApp/util/rate_limiter/limits/leaky_bucket.py b/src/HABApp/util/rate_limiter/limits/leaky_bucket.py index 9f61ed7b..ed91eeb9 100644 --- a/src/HABApp/util/rate_limiter/limits/leaky_bucket.py +++ b/src/HABApp/util/rate_limiter/limits/leaky_bucket.py @@ -11,8 +11,8 @@ class LeakyBucketLimitInfo(BaseRateLimitInfo): class LeakyBucketLimit(BaseRateLimit): - def __init__(self, allowed: int, interval: int): - super().__init__(allowed, interval) + def __init__(self, allowed: int, interval: int, hits: int = 0): + super().__init__(allowed, interval, hits) self.drop_interval: Final = interval / allowed self.next_drop: float = -1.0 diff --git a/tests/test_utils/test_rate_limiter.py b/tests/test_utils/test_rate_limiter.py index 87c68de0..b68aa401 100644 --- a/tests/test_utils/test_rate_limiter.py +++ b/tests/test_utils/test_rate_limiter.py @@ -59,7 +59,7 @@ def test_parse(unit: str, factor: int): assert str(e.value) == 'Invalid limit string: "asdf"' -def test_regex_all_units(): +def test_parse_regex_all_units(): m = re.search(r'\(([^)]+)\)s\?', LIMIT_REGEX.pattern) values = m.group(1) @@ -158,6 +158,22 @@ def test_limiter_add(time): limiter.add_limit(3, 5).add_limit(3, 5).parse_limits('3 in 5s') assert len(limiter._limits) == 1 + with pytest.raises(ValueError) as e: + limiter.add_limit(0, 5) + assert str(e.value) == "Parameter allowed must be an int > 0, is 0 ()" + + with pytest.raises(ValueError) as e: + limiter.add_limit(1, 0.5) + assert str(e.value) == "Parameter interval must be an int > 0, is 0.5 ()" + + with pytest.raises(ValueError) as e: + limiter.add_limit(3, 5, hits=-1) + assert str(e.value) == "Parameter hits must be an int >= 0, is -1 ()" + + with pytest.raises(ValueError) as e: + limiter.add_limit(3, 5, hits=5) + assert str(e.value) == "Parameter hits must be <= parameter allowed! 5 <= 3!" + def test_fixed_window_info(time): limit = FixedWindowElasticExpiryLimit(5, 3) @@ -208,7 +224,8 @@ def test_limiter(time): with pytest.raises(ValueError): limiter.allow() - limiter.add_limit(2, 1).add_limit(2, 2) + limiter.add_limit( + 2, 1, algorithm='fixed_window_elastic_expiry').add_limit(2, 2, algorithm='fixed_window_elastic_expiry') assert limiter.allow() assert limiter.allow()