Skip to content

Commit

Permalink
Add check task-synchronous-subtasks for task_always_eager mode (celer…
Browse files Browse the repository at this point in the history
  • Loading branch information
haos616 authored and auvipy committed Dec 5, 2017
1 parent 1207709 commit d02d260
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 10 deletions.
7 changes: 4 additions & 3 deletions celery/app/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from celery.exceptions import Ignore, MaxRetriesExceededError, Reject, Retry
from celery.five import items, python_2_unicode_compatible
from celery.local import class_property
from celery.result import EagerResult
from celery.result import EagerResult, denied_join_result
from celery.utils import abstract
from celery.utils.functional import mattrgetter, maybe_list
from celery.utils.imports import instantiate
Expand Down Expand Up @@ -521,8 +521,9 @@ def apply_async(self, args=None, kwargs=None, task_id=None, producer=None,

app = self._get_app()
if app.conf.task_always_eager:
return self.apply(args, kwargs, task_id=task_id or uuid(),
link=link, link_error=link_error, **options)
with denied_join_result():
return self.apply(args, kwargs, task_id=task_id or uuid(),
link=link, link_error=link_error, **options)
# add 'self' if this is a "task_method".
if self.__self__ is not None:
args = args if isinstance(args, tuple) else tuple(args or ())
Expand Down
36 changes: 29 additions & 7 deletions celery/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,16 @@ def allow_join_result():
_set_task_join_will_block(reset_value)


@contextmanager
def denied_join_result():
reset_value = task_join_will_block()
_set_task_join_will_block(True)
try:
yield
finally:
_set_task_join_will_block(reset_value)


class ResultBase(object):
"""Base class for results."""

Expand Down Expand Up @@ -617,7 +627,8 @@ def iterate(self, timeout=None, propagate=True, interval=0.5):
raise TimeoutError('The operation timed out')

def get(self, timeout=None, propagate=True, interval=0.5,
callback=None, no_ack=True, on_message=None):
callback=None, no_ack=True, on_message=None,
disable_sync_subtasks=True):
"""See :meth:`join`.
This is here for API compatibility with :class:`AsyncResult`,
Expand All @@ -629,11 +640,12 @@ def get(self, timeout=None, propagate=True, interval=0.5,
return (self.join_native if self.supports_native_join else self.join)(
timeout=timeout, propagate=propagate,
interval=interval, callback=callback, no_ack=no_ack,
on_message=on_message,
on_message=on_message, disable_sync_subtasks=disable_sync_subtasks
)

def join(self, timeout=None, propagate=True, interval=0.5,
callback=None, no_ack=True, on_message=None, on_interval=None):
callback=None, no_ack=True, on_message=None,
disable_sync_subtasks=True, on_interval=None):
"""Gather the results of all tasks as a list in order.
Note:
Expand Down Expand Up @@ -669,13 +681,17 @@ def join(self, timeout=None, propagate=True, interval=0.5,
no_ack (bool): Automatic message acknowledgment (Note that if this
is set to :const:`False` then the messages
*will not be acknowledged*).
disable_sync_subtasks (bool): Disable tasks to wait for sub tasks
this is the default configuration. CAUTION do not enable this
unless you must.
Raises:
celery.exceptions.TimeoutError: if ``timeout`` isn't
:const:`None` and the operation takes longer than ``timeout``
seconds.
"""
assert_will_not_block()
if disable_sync_subtasks:
assert_will_not_block()
time_start = monotonic()
remaining = None

Expand Down Expand Up @@ -723,7 +739,8 @@ def iter_native(self, timeout=None, interval=0.5, no_ack=True,

def join_native(self, timeout=None, propagate=True,
interval=0.5, callback=None, no_ack=True,
on_message=None, on_interval=None):
on_message=None, on_interval=None,
disable_sync_subtasks=True):
"""Backend optimized version of :meth:`join`.
.. versionadded:: 2.2
Expand All @@ -734,7 +751,8 @@ def join_native(self, timeout=None, propagate=True,
This is currently only supported by the amqp, Redis and cache
result backends.
"""
assert_will_not_block()
if disable_sync_subtasks:
assert_will_not_block()
order_index = None if callback else {
result.id: i for i, result in enumerate(self.results)
}
Expand Down Expand Up @@ -916,7 +934,11 @@ def __copy__(self):
def ready(self):
return True

def get(self, timeout=None, propagate=True, **kwargs):
def get(self, timeout=None, propagate=True,
disable_sync_subtasks=True, **kwargs):
if disable_sync_subtasks:
assert_will_not_block()

if self.successful():
return self.result
elif self.state in states.PROPAGATE_STATES:
Expand Down
9 changes: 9 additions & 0 deletions t/unit/tasks/test_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -872,6 +872,15 @@ def test_revoke(self):
res = self.raising.apply(args=[3, 3])
assert not res.revoke()

@patch('celery.result.task_join_will_block')
def test_get_sync_subtask_option(self, task_join_will_block):
task_join_will_block.return_value = True
tid = uuid()
res_subtask_async = EagerResult(tid, 'x', 'x', states.SUCCESS)
with pytest.raises(RuntimeError):
res_subtask_async.get()
res_subtask_async.get(disable_sync_subtasks=False)


class test_tuples:

Expand Down

0 comments on commit d02d260

Please sign in to comment.