Skip to content

Commit

Permalink
Add argument raise_on_exception_class to Retrier
Browse files Browse the repository at this point in the history
  • Loading branch information
albireox committed Nov 27, 2024
1 parent 795054c commit 20049c2
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 1 deletion.
7 changes: 7 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
# Changelog

## Next version

### ✨ Improved

* Add option to `Retrier` to immediately raise an exception if the exception class matches a given list of exceptions.


## 0.4.1 - November 27, 2024

### 🚀 New
Expand Down
10 changes: 9 additions & 1 deletion src/lvmopstools/retrier.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import asyncio
import inspect
import time
from dataclasses import dataclass
from dataclasses import dataclass, field
from functools import wraps

from typing import (
Expand Down Expand Up @@ -75,6 +75,9 @@ async def test_function(x, y):
on_retry
A function that will be called when a retry is attempted. The function
should accept an exception as its only argument.
raise_on_exception_class
A list of exception classes that will cause an exception to be raised
without retrying.
"""

Expand All @@ -84,6 +87,7 @@ async def test_function(x, y):
exponential_backoff_base: float = 2
max_delay: float = 32.0
on_retry: Callable[[Exception], None] | None = None
raise_on_exception_class: list[type[Exception]] = field(default_factory=list)

def calculate_delay(self, attempt: int) -> float:
"""Calculates the delay for a given attempt."""
Expand Down Expand Up @@ -129,6 +133,8 @@ async def async_wrapper(*args: P.args, **kwargs: P.kwargs):
attempt += 1
if attempt >= self.max_attempts:
raise ee
elif isinstance(ee, tuple(self.raise_on_exception_class)):
raise ee
else:
if self.on_retry:
self.on_retry(ee)
Expand All @@ -148,6 +154,8 @@ def wrapper(*args: P.args, **kwargs: P.kwargs):
attempt += 1
if attempt >= self.max_attempts:
raise ee
elif isinstance(ee, tuple(self.raise_on_exception_class)):
raise ee
else:
if self.on_retry:
self.on_retry(ee)
Expand Down
28 changes: 28 additions & 0 deletions tests/test_retrier.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,3 +126,31 @@ async def test_retrier_async(
assert on_retry_mock.call_count == 1
else:
assert on_retry_mock.call_count == 0


def test_retier_raise_on_exception_class():
def raise_runtime_error():
raise RuntimeError()

on_retry_mock.reset_mock()
retrier = Retrier(raise_on_exception_class=[RuntimeError], on_retry=on_retry_mock)
test_function = retrier(raise_runtime_error)

with pytest.raises(RuntimeError):
test_function()

on_retry_mock.assert_not_called()


async def test_retier_raise_on_exception_class_async():
async def raise_runtime_error():
raise RuntimeError()

on_retry_mock.reset_mock()
retrier = Retrier(raise_on_exception_class=[RuntimeError], on_retry=on_retry_mock)
test_function = retrier(raise_runtime_error)

with pytest.raises(RuntimeError):
await test_function()

on_retry_mock.assert_not_called()

0 comments on commit 20049c2

Please sign in to comment.