diff --git a/tasktiger/worker.py b/tasktiger/worker.py index e6d1252..b99482d 100644 --- a/tasktiger/worker.py +++ b/tasktiger/worker.py @@ -771,8 +771,12 @@ def _mark_done() -> None: has_job_timeout = True if execution and execution.get("retry"): + # Prefer retry method from the execution, then the task, then + # default. if "retry_method" in execution: retry_func, retry_args = execution["retry_method"] + elif task.retry_method: + retry_func, retry_args = task.retry_method else: # We expect the serialized method here. retry_func, retry_args = serialize_retry_method( diff --git a/tests/tasks.py b/tests/tasks.py index 6923c14..9c46599 100644 --- a/tests/tasks.py +++ b/tests/tasks.py @@ -135,6 +135,11 @@ def retry_task_2(): raise RetryException(method=fixed(DELAY, 1), log_error=False) +@tiger.task(retry_method=fixed(DELAY, 1)) +def retry_task_3(): + raise RetryException(log_error=False) + + def verify_current_task(): with redis.Redis( host=REDIS_HOST, db=TEST_DB, decode_responses=True diff --git a/tests/test_base.py b/tests/test_base.py index d9ba466..29a7bd1 100644 --- a/tests/test_base.py +++ b/tests/test_base.py @@ -39,6 +39,7 @@ non_batch_task, retry_task, retry_task_2, + retry_task_3, simple_task, sleep_task, task_on_other_queue, @@ -617,6 +618,23 @@ def test_retry_exception_2(self): pytest.raises(TaskNotFound, task.n_executions) + def test_retry_exception_3(self): + task = self.tiger.delay(retry_task_3) + self._ensure_queues(queued={"default": 1}) + assert task.n_executions() == 0 + + Worker(self.tiger).run(once=True) + self._ensure_queues(scheduled={"default": 1}) + assert task.n_executions() == 1 + + time.sleep(DELAY) + + Worker(self.tiger).run(once=True) + Worker(self.tiger).run(once=True) + self._ensure_queues() + + pytest.raises(TaskNotFound, task.n_executions) + @pytest.mark.parametrize("count", [1, 3, 7]) def test_retry_executions_count(self, count): task = self.tiger.delay(exception_task, retry_method=fixed(DELAY, 20))