Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable use with staticmethods #201

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion backoff/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,13 @@ def _init_wait_gen(wait_gen, wait_gen_kwargs):
return initialized


def _get_func(func):
if isinstance(func, staticmethod):
return func.__func__
else:
return func


def _next_wait(wait, send_value, jitter, elapsed, max_time):
value = wait.send(send_value)
try:
Expand Down Expand Up @@ -88,7 +95,7 @@ def _config_handlers(
# append a single handler
handlers.append(user_handlers)

return handlers
return [_get_func(h) for h in handlers]


# Default backoff handler
Expand Down
15 changes: 10 additions & 5 deletions backoff/_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
_prepare_logger,
_config_handlers,
_log_backoff,
_log_giveup
_log_giveup,
_get_func
)
from backoff._jitter import full_jitter
from backoff import _async, _sync
Expand Down Expand Up @@ -83,6 +84,8 @@ def on_predicate(wait_gen: _WaitGenerator,
def decorate(target):
nonlocal logger, on_success, on_backoff, on_giveup

target_func = _get_func(target)

logger = _prepare_logger(logger)
on_success = _config_handlers(on_success)
on_backoff = _config_handlers(
Expand All @@ -98,13 +101,13 @@ def decorate(target):
log_level=giveup_log_level
)

if asyncio.iscoroutinefunction(target):
if asyncio.iscoroutinefunction(target_func):
retry = _async.retry_predicate
else:
retry = _sync.retry_predicate

return retry(
target,
target_func,
wait_gen,
predicate,
max_tries=max_tries,
Expand Down Expand Up @@ -198,13 +201,15 @@ def decorate(target):
log_level=giveup_log_level,
)

if asyncio.iscoroutinefunction(target):
target_func = _get_func(target)

if asyncio.iscoroutinefunction(target_func):
retry = _async.retry_exception
else:
retry = _sync.retry_exception

return retry(
target,
target_func,
wait_gen,
exception,
max_tries=max_tries,
Expand Down
90 changes: 90 additions & 0 deletions tests/test_backoff.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# coding:utf-8
import datetime
import inspect
import itertools
import logging
import random
Expand Down Expand Up @@ -29,6 +30,23 @@ def return_true(log, n):
assert 3 == len(log)


def test_on_predicate_static(monkeypatch):
monkeypatch.setattr('time.sleep', lambda x: None)

class A:
@backoff.on_predicate(backoff.expo)
@staticmethod
def return_true(log, n):
val = (len(log) == n - 1)
log.append(val)
return val

log = []
ret = A.return_true(log, 3)
assert ret is True
assert 3 == len(log)


def test_on_predicate_max_tries(monkeypatch):
monkeypatch.setattr('time.sleep', lambda x: None)

Expand Down Expand Up @@ -129,6 +147,24 @@ def keyerror_then_true(log, n):
assert 3 == len(log)


def test_on_exception_static(monkeypatch):
monkeypatch.setattr('time.sleep', lambda x: None)

class A:
@backoff.on_exception(backoff.expo, KeyError)
@staticmethod
def keyerror_then_true(log, n):
if len(log) == n:
return True
e = KeyError()
log.append(e)
raise e

log = []
assert A.keyerror_then_true(log, 3) is True
assert 3 == len(log)


def test_on_exception_tuple(monkeypatch):
monkeypatch.setattr('time.sleep', lambda x: None)

Expand Down Expand Up @@ -489,6 +525,60 @@ def emptiness(*args, **kwargs):
'value': None}


def test_on_static_predicate_iterable_handlers():

class Logger:
def __init__(self):
self.backoffs = []
self.giveups = []
self.successes = []

static_calls = []

loggers = [Logger() for _ in range(3)]

class Tester:
@staticmethod
def static_log(details):
static_calls.append(details)

@backoff.on_predicate(
backoff.constant,
on_backoff=[lg.backoffs.append for lg in loggers],
on_giveup=[lg.giveups.append for lg in loggers] + [static_log],
on_success=(lg.successes.append for lg in loggers),
max_tries=4,
jitter=None,
interval=0)
@staticmethod
def emptiness(*args, **kwargs):
pass

Tester.emptiness(1, 2, 3, foo=1, bar=2)

for logger in loggers:

assert len(logger.successes) == 0
assert len(logger.backoffs) == 3
assert len(logger.giveups) == 1

details = dict(logger.giveups[0])
print(details)
elapsed = details.pop('elapsed')
assert isinstance(elapsed, float)
# the staticmethod is instantiated individually for the different
# frames so _save_target won't save the same function. We will
# compare source code instead.
assert (
inspect.getsource(details['target']) == inspect.getsource(Tester.emptiness)
)
del details['target']
assert details == {'args': (1, 2, 3),
'kwargs': {'foo': 1, 'bar': 2},
'tries': 4,
'value': None}


# To maintain backward compatibility,
# on_predicate should support 0-argument jitter function.
def test_on_exception_success_0_arg_jitter(monkeypatch):
Expand Down