From d02d260a192f8923dfa331ae2b503af57534654a Mon Sep 17 00:00:00 2001 From: Denis Podlesniy Date: Tue, 5 Dec 2017 07:53:01 +0200 Subject: [PATCH] Add check task-synchronous-subtasks for task_always_eager mode (#4322) --- celery/app/task.py | 7 ++++--- celery/result.py | 36 +++++++++++++++++++++++++++++------- t/unit/tasks/test_result.py | 9 +++++++++ 3 files changed, 42 insertions(+), 10 deletions(-) diff --git a/celery/app/task.py b/celery/app/task.py index 6e3b8ad1732..691c668530d 100644 --- a/celery/app/task.py +++ b/celery/app/task.py @@ -14,7 +14,7 @@ from celery.exceptions import Ignore, MaxRetriesExceededError, Reject, Retry from celery.five import items, python_2_unicode_compatible from celery.local import class_property -from celery.result import EagerResult +from celery.result import EagerResult, denied_join_result from celery.utils import abstract from celery.utils.functional import mattrgetter, maybe_list from celery.utils.imports import instantiate @@ -521,8 +521,9 @@ def apply_async(self, args=None, kwargs=None, task_id=None, producer=None, app = self._get_app() if app.conf.task_always_eager: - return self.apply(args, kwargs, task_id=task_id or uuid(), - link=link, link_error=link_error, **options) + with denied_join_result(): + return self.apply(args, kwargs, task_id=task_id or uuid(), + link=link, link_error=link_error, **options) # add 'self' if this is a "task_method". if self.__self__ is not None: args = args if isinstance(args, tuple) else tuple(args or ()) diff --git a/celery/result.py b/celery/result.py index 39c13aea7ea..bfb6af82c21 100644 --- a/celery/result.py +++ b/celery/result.py @@ -51,6 +51,16 @@ def allow_join_result(): _set_task_join_will_block(reset_value) +@contextmanager +def denied_join_result(): + reset_value = task_join_will_block() + _set_task_join_will_block(True) + try: + yield + finally: + _set_task_join_will_block(reset_value) + + class ResultBase(object): """Base class for results.""" @@ -617,7 +627,8 @@ def iterate(self, timeout=None, propagate=True, interval=0.5): raise TimeoutError('The operation timed out') def get(self, timeout=None, propagate=True, interval=0.5, - callback=None, no_ack=True, on_message=None): + callback=None, no_ack=True, on_message=None, + disable_sync_subtasks=True): """See :meth:`join`. This is here for API compatibility with :class:`AsyncResult`, @@ -629,11 +640,12 @@ def get(self, timeout=None, propagate=True, interval=0.5, return (self.join_native if self.supports_native_join else self.join)( timeout=timeout, propagate=propagate, interval=interval, callback=callback, no_ack=no_ack, - on_message=on_message, + on_message=on_message, disable_sync_subtasks=disable_sync_subtasks ) def join(self, timeout=None, propagate=True, interval=0.5, - callback=None, no_ack=True, on_message=None, on_interval=None): + callback=None, no_ack=True, on_message=None, + disable_sync_subtasks=True, on_interval=None): """Gather the results of all tasks as a list in order. Note: @@ -669,13 +681,17 @@ def join(self, timeout=None, propagate=True, interval=0.5, no_ack (bool): Automatic message acknowledgment (Note that if this is set to :const:`False` then the messages *will not be acknowledged*). + disable_sync_subtasks (bool): Disable tasks to wait for sub tasks + this is the default configuration. CAUTION do not enable this + unless you must. Raises: celery.exceptions.TimeoutError: if ``timeout`` isn't :const:`None` and the operation takes longer than ``timeout`` seconds. """ - assert_will_not_block() + if disable_sync_subtasks: + assert_will_not_block() time_start = monotonic() remaining = None @@ -723,7 +739,8 @@ def iter_native(self, timeout=None, interval=0.5, no_ack=True, def join_native(self, timeout=None, propagate=True, interval=0.5, callback=None, no_ack=True, - on_message=None, on_interval=None): + on_message=None, on_interval=None, + disable_sync_subtasks=True): """Backend optimized version of :meth:`join`. .. versionadded:: 2.2 @@ -734,7 +751,8 @@ def join_native(self, timeout=None, propagate=True, This is currently only supported by the amqp, Redis and cache result backends. """ - assert_will_not_block() + if disable_sync_subtasks: + assert_will_not_block() order_index = None if callback else { result.id: i for i, result in enumerate(self.results) } @@ -916,7 +934,11 @@ def __copy__(self): def ready(self): return True - def get(self, timeout=None, propagate=True, **kwargs): + def get(self, timeout=None, propagate=True, + disable_sync_subtasks=True, **kwargs): + if disable_sync_subtasks: + assert_will_not_block() + if self.successful(): return self.result elif self.state in states.PROPAGATE_STATES: diff --git a/t/unit/tasks/test_result.py b/t/unit/tasks/test_result.py index 02a23a5bc3e..8499819ed4a 100644 --- a/t/unit/tasks/test_result.py +++ b/t/unit/tasks/test_result.py @@ -872,6 +872,15 @@ def test_revoke(self): res = self.raising.apply(args=[3, 3]) assert not res.revoke() + @patch('celery.result.task_join_will_block') + def test_get_sync_subtask_option(self, task_join_will_block): + task_join_will_block.return_value = True + tid = uuid() + res_subtask_async = EagerResult(tid, 'x', 'x', states.SUCCESS) + with pytest.raises(RuntimeError): + res_subtask_async.get() + res_subtask_async.get(disable_sync_subtasks=False) + class test_tuples: