diff --git a/backoff/_common.py b/backoff/_common.py index 2b2e54e..b37d4ab 100644 --- a/backoff/_common.py +++ b/backoff/_common.py @@ -31,6 +31,13 @@ def _init_wait_gen(wait_gen, wait_gen_kwargs): return initialized +def _get_func(func): + if isinstance(func, staticmethod): + return func.__func__ + else: + return func + + def _next_wait(wait, send_value, jitter, elapsed, max_time): value = wait.send(send_value) try: @@ -88,7 +95,7 @@ def _config_handlers( # append a single handler handlers.append(user_handlers) - return handlers + return [_get_func(h) for h in handlers] # Default backoff handler diff --git a/backoff/_decorator.py b/backoff/_decorator.py index 77ed8c2..71420bd 100644 --- a/backoff/_decorator.py +++ b/backoff/_decorator.py @@ -8,7 +8,8 @@ _prepare_logger, _config_handlers, _log_backoff, - _log_giveup + _log_giveup, + _get_func ) from backoff._jitter import full_jitter from backoff import _async, _sync @@ -83,6 +84,8 @@ def on_predicate(wait_gen: _WaitGenerator, def decorate(target): nonlocal logger, on_success, on_backoff, on_giveup + target_func = _get_func(target) + logger = _prepare_logger(logger) on_success = _config_handlers(on_success) on_backoff = _config_handlers( @@ -98,13 +101,13 @@ def decorate(target): log_level=giveup_log_level ) - if asyncio.iscoroutinefunction(target): + if asyncio.iscoroutinefunction(target_func): retry = _async.retry_predicate else: retry = _sync.retry_predicate return retry( - target, + target_func, wait_gen, predicate, max_tries=max_tries, @@ -198,13 +201,15 @@ def decorate(target): log_level=giveup_log_level, ) - if asyncio.iscoroutinefunction(target): + target_func = _get_func(target) + + if asyncio.iscoroutinefunction(target_func): retry = _async.retry_exception else: retry = _sync.retry_exception return retry( - target, + target_func, wait_gen, exception, max_tries=max_tries, diff --git a/tests/test_backoff.py b/tests/test_backoff.py index cd33b63..9bbe003 100644 --- a/tests/test_backoff.py +++ b/tests/test_backoff.py @@ -1,5 +1,6 @@ # coding:utf-8 import datetime +import inspect import itertools import logging import random @@ -29,6 +30,23 @@ def return_true(log, n): assert 3 == len(log) +def test_on_predicate_static(monkeypatch): + monkeypatch.setattr('time.sleep', lambda x: None) + + class A: + @backoff.on_predicate(backoff.expo) + @staticmethod + def return_true(log, n): + val = (len(log) == n - 1) + log.append(val) + return val + + log = [] + ret = A.return_true(log, 3) + assert ret is True + assert 3 == len(log) + + def test_on_predicate_max_tries(monkeypatch): monkeypatch.setattr('time.sleep', lambda x: None) @@ -129,6 +147,24 @@ def keyerror_then_true(log, n): assert 3 == len(log) +def test_on_exception_static(monkeypatch): + monkeypatch.setattr('time.sleep', lambda x: None) + + class A: + @backoff.on_exception(backoff.expo, KeyError) + @staticmethod + def keyerror_then_true(log, n): + if len(log) == n: + return True + e = KeyError() + log.append(e) + raise e + + log = [] + assert A.keyerror_then_true(log, 3) is True + assert 3 == len(log) + + def test_on_exception_tuple(monkeypatch): monkeypatch.setattr('time.sleep', lambda x: None) @@ -489,6 +525,60 @@ def emptiness(*args, **kwargs): 'value': None} +def test_on_static_predicate_iterable_handlers(): + + class Logger: + def __init__(self): + self.backoffs = [] + self.giveups = [] + self.successes = [] + + static_calls = [] + + loggers = [Logger() for _ in range(3)] + + class Tester: + @staticmethod + def static_log(details): + static_calls.append(details) + + @backoff.on_predicate( + backoff.constant, + on_backoff=[lg.backoffs.append for lg in loggers], + on_giveup=[lg.giveups.append for lg in loggers] + [static_log], + on_success=(lg.successes.append for lg in loggers), + max_tries=4, + jitter=None, + interval=0) + @staticmethod + def emptiness(*args, **kwargs): + pass + + Tester.emptiness(1, 2, 3, foo=1, bar=2) + + for logger in loggers: + + assert len(logger.successes) == 0 + assert len(logger.backoffs) == 3 + assert len(logger.giveups) == 1 + + details = dict(logger.giveups[0]) + print(details) + elapsed = details.pop('elapsed') + assert isinstance(elapsed, float) + # the staticmethod is instantiated individually for the different + # frames so _save_target won't save the same function. We will + # compare source code instead. + assert ( + inspect.getsource(details['target']) == inspect.getsource(Tester.emptiness) + ) + del details['target'] + assert details == {'args': (1, 2, 3), + 'kwargs': {'foo': 1, 'bar': 2}, + 'tries': 4, + 'value': None} + + # To maintain backward compatibility, # on_predicate should support 0-argument jitter function. def test_on_exception_success_0_arg_jitter(monkeypatch):