Skip to content

Commit

Permalink
Resolve TypeError on .get from nested groups (celery#4432)
Browse files Browse the repository at this point in the history
* Accept and pass along the `on_interval` in ResultSet.get

Otherwise, calls to .get or .join on ResultSets fail on nested groups.
Fixes celery#4274

* Add a unit test that verifies the fixed behavior

Verified that the unit test fails on master, but passes on the patched version. The
nested structure of results was borrowed from celery#4274

* Wrap long lines

* Add integration test for celery#4274 use case

* Switch to a simpler, group-only-based integration test

* Flatten expected integration test result

* Added back testcase from celery#4274 and skip it if the backend under test does not support native joins.

* Fix lint.

* Enable only if chords are allowed.

* Fix access to message.
  • Loading branch information
myw authored and Omer Katz committed Dec 11, 2017
1 parent ebd98fa commit dd2cdd9
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 3 deletions.
5 changes: 3 additions & 2 deletions celery/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -628,7 +628,7 @@ def iterate(self, timeout=None, propagate=True, interval=0.5):

def get(self, timeout=None, propagate=True, interval=0.5,
callback=None, no_ack=True, on_message=None,
disable_sync_subtasks=True):
disable_sync_subtasks=True, on_interval=None):
"""See :meth:`join`.
This is here for API compatibility with :class:`AsyncResult`,
Expand All @@ -640,7 +640,8 @@ def get(self, timeout=None, propagate=True, interval=0.5,
return (self.join_native if self.supports_native_join else self.join)(
timeout=timeout, propagate=propagate,
interval=interval, callback=callback, no_ack=no_ack,
on_message=on_message, disable_sync_subtasks=disable_sync_subtasks
on_message=on_message, disable_sync_subtasks=disable_sync_subtasks,
on_interval=on_interval,
)

def join(self, timeout=None, propagate=True, interval=0.5,
Expand Down
44 changes: 44 additions & 0 deletions t/integration/test_canvas.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,24 @@ def test_parent_ids(self, manager):
assert parent_id == expected_parent_id
assert value == i + 2

@flaky
def test_nested_group(self, manager):
assert manager.inspect().ping()

c = group(
add.si(1, 10),
group(
add.si(1, 100),
group(
add.si(1, 1000),
add.si(1, 2000),
),
),
)
res = c()

assert res.get(timeout=TIMEOUT) == [11, 101, 1001, 2001]


def assert_ids(r, expected_value, expected_root_id, expected_parent_id):
root_id, parent_id, value = r.get(timeout=TIMEOUT)
Expand All @@ -164,6 +182,32 @@ def test_group_chain(self, manager):
res = c()
assert res.get(timeout=TIMEOUT) == [12, 13, 14, 15]

@flaky
def test_nested_group_chain(self, manager):
try:
manager.app.backend.ensure_chords_allowed()
except NotImplementedError as e:
raise pytest.skip(e.args[0])

if not manager.app.backend.supports_native_join:
raise pytest.skip('Requires native join support.')
c = chain(
add.si(1, 0),
group(
add.si(1, 100),
chain(
add.si(1, 200),
group(
add.si(1, 1000),
add.si(1, 2000),
),
),
),
add.si(1, 10),
)
res = c()
assert res.get(timeout=TIMEOUT) == 11

@flaky
def test_parent_ids(self, manager):
if not manager.app.conf.result_backend.startswith('redis'):
Expand Down
37 changes: 36 additions & 1 deletion t/unit/tasks/test_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -519,12 +519,16 @@ def get(self, propagate=True, **kwargs):
class MockAsyncResultSuccess(AsyncResult):
forgotten = False

def __init__(self, *args, **kwargs):
self._result = kwargs.pop('result', 42)
super(MockAsyncResultSuccess, self).__init__(*args, **kwargs)

def forget(self):
self.forgotten = True

@property
def result(self):
return 42
return self._result

@property
def state(self):
Expand Down Expand Up @@ -622,6 +626,37 @@ def test_forget(self):
for sub in subs:
assert sub.forgotten

def test_get_nested_without_native_join(self):
backend = SimpleBackend()
backend.supports_native_join = False
ts = self.app.GroupResult(uuid(), [
MockAsyncResultSuccess(uuid(), result='1.1',
app=self.app, backend=backend),
self.app.GroupResult(uuid(), [
MockAsyncResultSuccess(uuid(), result='2.1',
app=self.app, backend=backend),
self.app.GroupResult(uuid(), [
MockAsyncResultSuccess(uuid(), result='3.1',
app=self.app, backend=backend),
MockAsyncResultSuccess(uuid(), result='3.2',
app=self.app, backend=backend),
]),
]),
])
ts.app.backend = backend

vals = ts.get()
assert vals == [
'1.1',
[
'2.1',
[
'3.1',
'3.2',
]
],
]

def test_getitem(self):
subs = [MockAsyncResultSuccess(uuid(), app=self.app),
MockAsyncResultSuccess(uuid(), app=self.app)]
Expand Down

0 comments on commit dd2cdd9

Please sign in to comment.