diff --git a/CHANGELOG.md b/CHANGELOG.md index 826c66c..641d48e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,12 @@ # Changelog +## Next version + +### ✨ Improved + +* Add option to `Retrier` to immediately raise an exception if the exception class matches a given list of exceptions. + + ## 0.4.1 - November 27, 2024 ### 🚀 New diff --git a/src/lvmopstools/retrier.py b/src/lvmopstools/retrier.py index 27108ce..2a22815 100644 --- a/src/lvmopstools/retrier.py +++ b/src/lvmopstools/retrier.py @@ -11,7 +11,7 @@ import asyncio import inspect import time -from dataclasses import dataclass +from dataclasses import dataclass, field from functools import wraps from typing import ( @@ -75,6 +75,9 @@ async def test_function(x, y): on_retry A function that will be called when a retry is attempted. The function should accept an exception as its only argument. + raise_on_exception_class + A list of exception classes that will cause an exception to be raised + without retrying. """ @@ -84,6 +87,7 @@ async def test_function(x, y): exponential_backoff_base: float = 2 max_delay: float = 32.0 on_retry: Callable[[Exception], None] | None = None + raise_on_exception_class: list[type[Exception]] = field(default_factory=list) def calculate_delay(self, attempt: int) -> float: """Calculates the delay for a given attempt.""" @@ -129,6 +133,8 @@ async def async_wrapper(*args: P.args, **kwargs: P.kwargs): attempt += 1 if attempt >= self.max_attempts: raise ee + elif isinstance(ee, tuple(self.raise_on_exception_class)): + raise ee else: if self.on_retry: self.on_retry(ee) @@ -148,6 +154,8 @@ def wrapper(*args: P.args, **kwargs: P.kwargs): attempt += 1 if attempt >= self.max_attempts: raise ee + elif isinstance(ee, tuple(self.raise_on_exception_class)): + raise ee else: if self.on_retry: self.on_retry(ee) diff --git a/tests/test_retrier.py b/tests/test_retrier.py index a0ca262..a75eeaf 100644 --- a/tests/test_retrier.py +++ b/tests/test_retrier.py @@ -126,3 +126,31 @@ async def test_retrier_async( assert on_retry_mock.call_count == 1 else: assert on_retry_mock.call_count == 0 + + +def test_retier_raise_on_exception_class(): + def raise_runtime_error(): + raise RuntimeError() + + on_retry_mock.reset_mock() + retrier = Retrier(raise_on_exception_class=[RuntimeError], on_retry=on_retry_mock) + test_function = retrier(raise_runtime_error) + + with pytest.raises(RuntimeError): + test_function() + + on_retry_mock.assert_not_called() + + +async def test_retier_raise_on_exception_class_async(): + async def raise_runtime_error(): + raise RuntimeError() + + on_retry_mock.reset_mock() + retrier = Retrier(raise_on_exception_class=[RuntimeError], on_retry=on_retry_mock) + test_function = retrier(raise_runtime_error) + + with pytest.raises(RuntimeError): + await test_function() + + on_retry_mock.assert_not_called()