diff --git a/backoff/_async.py b/backoff/_async.py index c9113cd..f7c1cc3 100644 --- a/backoff/_async.py +++ b/backoff/_async.py @@ -65,6 +65,9 @@ async def retry(*args, **kwargs): wait = _init_wait_gen(wait_gen, wait_gen_kwargs) while True: tries += 1 + + ret = await target(*args, **kwargs) + elapsed = timedelta.total_seconds(datetime.datetime.now() - start) details = { "target": target, @@ -73,8 +76,6 @@ async def retry(*args, **kwargs): "tries": tries, "elapsed": elapsed, } - - ret = await target(*args, **kwargs) if predicate(ret): max_tries_exceeded = (tries == max_tries) max_time_exceeded = (max_time is not None and @@ -140,18 +141,19 @@ async def retry(*args, **kwargs): wait = _init_wait_gen(wait_gen, wait_gen_kwargs) while True: tries += 1 - elapsed = timedelta.total_seconds(datetime.datetime.now() - start) details = { "target": target, "args": args, "kwargs": kwargs, "tries": tries, - "elapsed": elapsed, } try: ret = await target(*args, **kwargs) except exception as e: + elapsed = timedelta.total_seconds(datetime.datetime.now() - start) + details["elapsed"] = elapsed + giveup_result = await giveup(e) max_tries_exceeded = (tries == max_tries) max_time_exceeded = (max_time is not None and @@ -180,6 +182,9 @@ async def retry(*args, **kwargs): # await asyncio.sleep(seconds) else: + details["elapsed"] = timedelta.total_seconds( + datetime.datetime.now() - start + ) await _call_handlers(on_success, **details) return ret diff --git a/backoff/_sync.py b/backoff/_sync.py index 151924c..0b73c9b 100644 --- a/backoff/_sync.py +++ b/backoff/_sync.py @@ -39,6 +39,9 @@ def retry(*args, **kwargs): wait = _init_wait_gen(wait_gen, wait_gen_kwargs) while True: tries += 1 + + ret = target(*args, **kwargs) + elapsed = timedelta.total_seconds(datetime.datetime.now() - start) details = { "target": target, @@ -47,8 +50,6 @@ def retry(*args, **kwargs): "tries": tries, "elapsed": elapsed, } - - ret = target(*args, **kwargs) if predicate(ret): max_tries_exceeded = (tries == max_tries) max_time_exceeded = (max_time is not None and @@ -97,18 +98,19 @@ def retry(*args, **kwargs): wait = _init_wait_gen(wait_gen, wait_gen_kwargs) while True: tries += 1 - elapsed = timedelta.total_seconds(datetime.datetime.now() - start) details = { "target": target, "args": args, "kwargs": kwargs, "tries": tries, - "elapsed": elapsed, } try: ret = target(*args, **kwargs) except exception as e: + elapsed = timedelta.total_seconds(datetime.datetime.now() - start) + details["elapsed"] = elapsed + max_tries_exceeded = (tries == max_tries) max_time_exceeded = (max_time is not None and elapsed >= max_time) @@ -127,6 +129,9 @@ def retry(*args, **kwargs): time.sleep(seconds) else: + details["elapsed"] = timedelta.total_seconds( + datetime.datetime.now() - start + ) _call_handlers(on_success, **details) return ret diff --git a/tests/test_backoff.py b/tests/test_backoff.py index 0af6cfc..a0e2053 100644 --- a/tests/test_backoff.py +++ b/tests/test_backoff.py @@ -77,6 +77,51 @@ def return_true(log, n): assert len(log) == 3 +def test_max_time(monkeypatch): + + start = datetime.datetime.now() + elapsed = datetime.timedelta() + + def patch_sleep(n): + nonlocal elapsed + elapsed += datetime.timedelta(seconds=n) + + class Datetime: + @staticmethod + def now(): + nonlocal start + nonlocal elapsed + return start + elapsed + + monkeypatch.setattr("time.sleep", patch_sleep) + monkeypatch.setattr("datetime.datetime", Datetime) + + # A good place for property-based testing + for function_runtime, max_time in itertools.product(range(10), repeat=2): + elapsed = datetime.timedelta() + + @backoff.on_exception(backoff.constant, RuntimeError, max_time=max_time) + def on_exception(): + patch_sleep(function_runtime) + raise + + try: + on_exception() + except: + pass + + assert elapsed <= datetime.timedelta(seconds=max_time + function_runtime) + + elapsed = datetime.timedelta() + + @backoff.on_predicate(backoff.constant, lambda x: False, max_time=max_time) + def on_predicate(): + patch_sleep(function_runtime) + + on_predicate() + assert elapsed <= datetime.timedelta(seconds=max_time + function_runtime) + + def test_on_exception(monkeypatch): monkeypatch.setattr('time.sleep', lambda x: None) diff --git a/tests/test_backoff_async.py b/tests/test_backoff_async.py index ca62c4a..3dbcf38 100644 --- a/tests/test_backoff_async.py +++ b/tests/test_backoff_async.py @@ -2,6 +2,8 @@ import asyncio # Python 3.5 code and syntax is allowed in this file import backoff +import datetime +import itertools import pytest import random @@ -566,6 +568,52 @@ async def exceptor(): assert len(log) == 3 +@pytest.mark.asyncio +async def test_max_time(monkeypatch): + + start = datetime.datetime.now() + elapsed = datetime.timedelta() + + async def patch_sleep(n): + nonlocal elapsed + elapsed += datetime.timedelta(seconds=n) + + class Datetime: + @staticmethod + def now(): + nonlocal start + nonlocal elapsed + return start + elapsed + + monkeypatch.setattr('asyncio.sleep', patch_sleep) + monkeypatch.setattr("datetime.datetime", Datetime) + + # A good place for property-based testing + for function_runtime, max_time in itertools.product(range(10), repeat=2): + elapsed = datetime.timedelta() + + @backoff.on_exception(backoff.constant, RuntimeError, max_time=max_time) + async def on_exception(): + await patch_sleep(function_runtime) + raise + + try: + await on_exception() + except: + pass + + assert elapsed <= datetime.timedelta(seconds=max_time + function_runtime) + + elapsed = datetime.timedelta() + + @backoff.on_predicate(backoff.constant, lambda x: False, max_time=max_time) + async def on_predicate(): + await patch_sleep(function_runtime) + + await on_predicate() + assert elapsed <= datetime.timedelta(seconds=max_time + function_runtime) + + @pytest.mark.asyncio async def test_on_exception_callable_gen_kwargs():