Skip to content

Commit

Permalink
Allow more flexible config of retry decorator (#289)
Browse files Browse the repository at this point in the history
Option to specify retry exceptions as a dictionary instead of a tuple.
Values should be a callable determining whether a specific exception
object should be retied or not.

__Example:__

``` python
@Retry(
   exceptions = {ValueError: lambda x: "Invalid" not in str(x)}
)
def func() -> None:
   value = some_function()

   if value is None:
       raise ValueError("Could not retrieve value")

   if not_valid(value):
       raise ValueError(f"Invalid value: {value}")
```

Add requests_exceptions template, so you could do e.g.

``` python
retry(exceptions = request_exceptions())
```

if you're using requests
  • Loading branch information
mathialo authored Feb 7, 2024
1 parent dd2cd9b commit 9bd35c5
Show file tree
Hide file tree
Showing 5 changed files with 290 additions and 20 deletions.
29 changes: 29 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,35 @@ Changes are grouped as follows
- `Fixed` for any bug fixes.
- `Security` in case of vulnerabilities.

## [6.4.0]

### Added

* Option to specify retry exceptions as a dictionary instead of a tuple. Values should be a callable determining whether a specific exception object should be retied or not. Example:
``` python
@retry(
exceptions = {ValueError: lambda x: "Invalid" not in str(x)}
)
def func() -> None:
value = some_function()

if value is None:
raise ValueError("Could not retrieve value")

if not_valid(value):
raise ValueError(f"Invalid value: {value}")
```

* Templates for common retry scenarios. For example, if you're using the `requests` library, you can do

``` python
retry(exceptions = request_exceptions())
```

### Changed

* Default parameters in `retry` has changed to be less agressive. Retries will apply backoff by default, and give up after 10 retries.

## [6.3.2]

### Added
Expand Down
2 changes: 1 addition & 1 deletion cognite/extractorutils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,5 @@
Cognite extractor utils is a Python package that simplifies the development of new extractors.
"""

__version__ = "6.3.2"
__version__ = "6.4.0"
from .base import Extractor
117 changes: 100 additions & 17 deletions cognite/extractorutils/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from functools import partial, wraps
from threading import Event, Thread
from time import time
from typing import Any, Callable, Generator, Iterable, Optional, Tuple, Type, TypeVar, Union
from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Tuple, Type, TypeVar, Union

from decorator import decorator

Expand Down Expand Up @@ -329,26 +329,41 @@ def throttled_loop(target_time: int, cancellation_token: Event) -> Generator[Non

def _retry_internal(
f: Callable[..., _T2],
cancellation_token: threading.Event = threading.Event(),
exceptions: Tuple[Type[Exception], ...] = (Exception,),
tries: int = -1,
delay: float = 0,
max_delay: Optional[float] = None,
backoff: float = 1,
jitter: Union[float, Tuple[float, float]] = 0,
cancellation_token: threading.Event,
exceptions: Union[Tuple[Type[Exception], ...], Dict[Type[Exception], Callable[[Exception], bool]]],
tries: int,
delay: float,
max_delay: Optional[float],
backoff: float,
jitter: Union[float, Tuple[float, float]],
) -> _T2:
logger = logging.getLogger(__name__)

while tries and not cancellation_token.is_set():
try:
return f()
except exceptions as e:

except Exception as e:
if isinstance(exceptions, tuple):
for ex_type in exceptions:
if isinstance(e, ex_type):
break
else:
raise e

else:
for ex_type in exceptions:
if isinstance(e, ex_type) and exceptions[ex_type](e):
break
else:
raise e

tries -= 1
if not tries:
raise e

if logger is not None:
logger.warning("%s, retrying in %s seconds...", str(e), delay)
logger.warning("%s, retrying in %.1f seconds...", str(e), delay)

cancellation_token.wait(delay)
delay *= backoff
Expand All @@ -366,12 +381,12 @@ def _retry_internal(

def retry(
cancellation_token: threading.Event = threading.Event(),
exceptions: Tuple[Type[Exception], ...] = (Exception,),
tries: int = -1,
delay: float = 0,
max_delay: Optional[float] = None,
backoff: float = 1,
jitter: Union[float, Tuple[float, float]] = 0,
exceptions: Union[Tuple[Type[Exception], ...], Dict[Type[Exception], Callable[[Any], bool]]] = (Exception,),
tries: int = 10,
delay: float = 1,
max_delay: Optional[float] = 60,
backoff: float = 2,
jitter: Union[float, Tuple[float, float]] = (0, 2),
) -> Callable[[Callable[..., _T2]], Callable[..., _T2]]:
"""
Returns a retry decorator.
Expand All @@ -380,7 +395,9 @@ def retry(
Args:
cancellation_token: a threading token that is waited on.
exceptions: an exception or a tuple of exceptions to catch. default: Exception.
exceptions: a tuple of exceptions to catch, or a dictionary from exception types to a callback determining
whether to retry the exception or not. The callback will be given the exception object as argument.
default: retry all exceptions.
tries: the maximum number of attempts. default: -1 (infinite).
delay: initial delay between attempts. default: 0.
max_delay: the maximum value of delay. default: None (no limit).
Expand Down Expand Up @@ -408,3 +425,69 @@ def retry_decorator(f: Callable[..., _T2], *fargs: Any, **fkwargs: Any) -> _T2:
)

return retry_decorator


def requests_exceptions(
status_codes: List[int] = [408, 425, 429, 500, 502, 503, 504],
) -> Dict[Type[Exception], Callable[[Any], bool]]:
"""
Retry exceptions from using the ``requests`` library. This will retry all connection and HTTP errors matching
the given status codes.
Example:
.. code-block:: python
@retry(exceptions = requests_exceptions())
def my_function() -> None:
...
"""
# types ignored, since they are not installed as we don't depend on the package
from requests.exceptions import HTTPError, RequestException # type: ignore

def handle_http_errors(exception: RequestException) -> bool:
if isinstance(exception, HTTPError):
response = exception.response
if response is None:
return True

return response.status_code in status_codes

else:
return True

return {RequestException: handle_http_errors}


def httpx_exceptions(
status_codes: List[int] = [408, 425, 429, 500, 502, 503, 504],
) -> Dict[Type[Exception], Callable[[Any], bool]]:
"""
Retry exceptions from using the ``httpx`` library. This will retry all connection and HTTP errors matching
the given status codes.
Example:
.. code-block:: python
@retry(exceptions = httpx_exceptions())
def my_function() -> None:
...
"""
# types ignored, since they are not installed as we don't depend on the package
from httpx import HTTPError, HTTPStatusError # type: ignore

def handle_http_errors(exception: HTTPError) -> bool:
if isinstance(exception, HTTPStatusError):
response = exception.response
if response is None:
return True

return response.status_code in status_codes

else:
return True

return {HTTPError: handle_http_errors}
5 changes: 4 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "cognite-extractor-utils"
version = "6.3.2"
version = "6.4.0"
description = "Utilities for easier development of extractors for CDF"
authors = ["Mathias Lohne <[email protected]>"]
license = "Apache-2.0"
Expand Down Expand Up @@ -80,6 +80,9 @@ SecretStorage = "^3.1.2"
twine = "^4.0.0"
pytest-order = "^1.0.1"
parameterized = "*"
requests = "^2.31.0"
types-requests = "^2.31.0.20240125"
httpx = "^0.26.0"

[build-system]
requires = ["poetry>=0.12"]
Expand Down
157 changes: 156 additions & 1 deletion tests/tests_unit/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,20 @@
import unittest
from unittest.mock import Mock, patch

import httpx
import requests

from cognite.client import CogniteClient
from cognite.client.data_classes import Asset, TimeSeries
from cognite.client.exceptions import CogniteNotFoundError
from cognite.extractorutils.util import EitherId, ensure_assets, ensure_time_series
from cognite.extractorutils.util import (
EitherId,
ensure_assets,
ensure_time_series,
httpx_exceptions,
requests_exceptions,
retry,
)


class TestEnsureTimeSeries(unittest.TestCase):
Expand Down Expand Up @@ -134,3 +144,148 @@ def test_hash(self):

def test_repr(self):
self.assertEqual(EitherId(externalId="extId").__repr__(), "externalId: extId")


class TestRetries(unittest.TestCase):
def test_simple_retry(self) -> None:
mock = Mock()

@retry(tries=3, delay=0, jitter=0)
def call_mock() -> None:
mock()
raise ValueError()

with self.assertRaises(ValueError):
call_mock()

self.assertEqual(len(mock.call_args_list), 3)

def test_simple_retry_specified(self) -> None:
mock = Mock()

@retry(tries=3, delay=0, jitter=0, exceptions=(ValueError,))
def call_mock() -> None:
mock()
raise ValueError()

with self.assertRaises(ValueError):
call_mock()

self.assertEqual(len(mock.call_args_list), 3)

def test_not_retry_unspecified(self) -> None:
mock = Mock()

@retry(tries=3, delay=0, jitter=0, exceptions=(TypeError,))
def call_mock() -> None:
mock()
raise ValueError()

with self.assertRaises(ValueError):
call_mock()

self.assertEqual(len(mock.call_args_list), 1)

def test_retry_conditional(self) -> None:
mock = Mock()

@retry(tries=3, delay=0, jitter=0, exceptions={ValueError: lambda x: "Invalid" not in str(x)})
def call_mock(is_none: bool) -> None:
mock()

if is_none:
raise ValueError("Could not retrieve value")
else:
raise ValueError("Invalid value: 1234")

with self.assertRaises(ValueError):
call_mock(True)

self.assertEqual(len(mock.call_args_list), 3)

mock.reset_mock()

with self.assertRaises(ValueError):
call_mock(False)

self.assertEqual(len(mock.call_args_list), 1)

def test_retry_requests(self) -> None:
mock = Mock()

@retry(tries=3, delay=0, jitter=0, exceptions=requests_exceptions())
def call_mock() -> None:
mock()
requests.get("http://localhost:1234/nope")

with self.assertRaises(requests.ConnectionError):
call_mock()

self.assertEqual(len(mock.call_args_list), 3)
mock.reset_mock()

# 404 should not be retried
@retry(tries=3, delay=0, jitter=0, exceptions=requests_exceptions())
def call_mock2() -> None:
mock()
res = requests.Response()
res.status_code = 404
res.raise_for_status()

with self.assertRaises(requests.HTTPError):
call_mock2()

self.assertEqual(len(mock.call_args_list), 1)
mock.reset_mock()

# 429 should be retried
@retry(tries=3, delay=0, jitter=0, exceptions=requests_exceptions())
def call_mock3() -> None:
mock()
res = requests.Response()
res.status_code = 429
res.raise_for_status()

with self.assertRaises(requests.HTTPError):
call_mock3()

self.assertEqual(len(mock.call_args_list), 3)

def test_httpx_requests(self) -> None:
mock = Mock()

@retry(tries=3, delay=0, jitter=0, exceptions=httpx_exceptions())
def call_mock() -> None:
mock()
httpx.get("http://localhost:1234/nope")

with self.assertRaises(httpx.ConnectError):
call_mock()

self.assertEqual(len(mock.call_args_list), 3)
mock.reset_mock()

# 404 should not be retried
@retry(tries=3, delay=0, jitter=0, exceptions=httpx_exceptions())
def call_mock2() -> None:
mock()
res = httpx.Response(404, request=httpx.Request("GET", "http://localhost/"))
res.raise_for_status()

with self.assertRaises(httpx.HTTPError):
call_mock2()

self.assertEqual(len(mock.call_args_list), 1)
mock.reset_mock()

# 429 should be retried
@retry(tries=3, delay=0, jitter=0, exceptions=httpx_exceptions())
def call_mock3() -> None:
mock()
res = httpx.Response(429, request=httpx.Request("GET", "http://localhost/"))
res.raise_for_status()

with self.assertRaises(httpx.HTTPError):
call_mock3()

self.assertEqual(len(mock.call_args_list), 3)

0 comments on commit 9bd35c5

Please sign in to comment.