From dd2cdd9c4f8688f965d7b5658fa4956d083a7b8b Mon Sep 17 00:00:00 2001 From: Misha Wolfson Date: Mon, 11 Dec 2017 07:08:31 -0500 Subject: [PATCH] Resolve TypeError on `.get` from nested groups (#4432) * Accept and pass along the `on_interval` in ResultSet.get Otherwise, calls to .get or .join on ResultSets fail on nested groups. Fixes #4274 * Add a unit test that verifies the fixed behavior Verified that the unit test fails on master, but passes on the patched version. The nested structure of results was borrowed from #4274 * Wrap long lines * Add integration test for #4274 use case * Switch to a simpler, group-only-based integration test * Flatten expected integration test result * Added back testcase from #4274 and skip it if the backend under test does not support native joins. * Fix lint. * Enable only if chords are allowed. * Fix access to message. --- celery/result.py | 5 ++-- t/integration/test_canvas.py | 44 ++++++++++++++++++++++++++++++++++++ t/unit/tasks/test_result.py | 37 +++++++++++++++++++++++++++++- 3 files changed, 83 insertions(+), 3 deletions(-) diff --git a/celery/result.py b/celery/result.py index d7b3d288f68..e0443006c29 100644 --- a/celery/result.py +++ b/celery/result.py @@ -628,7 +628,7 @@ def iterate(self, timeout=None, propagate=True, interval=0.5): def get(self, timeout=None, propagate=True, interval=0.5, callback=None, no_ack=True, on_message=None, - disable_sync_subtasks=True): + disable_sync_subtasks=True, on_interval=None): """See :meth:`join`. This is here for API compatibility with :class:`AsyncResult`, @@ -640,7 +640,8 @@ def get(self, timeout=None, propagate=True, interval=0.5, return (self.join_native if self.supports_native_join else self.join)( timeout=timeout, propagate=propagate, interval=interval, callback=callback, no_ack=no_ack, - on_message=on_message, disable_sync_subtasks=disable_sync_subtasks + on_message=on_message, disable_sync_subtasks=disable_sync_subtasks, + on_interval=on_interval, ) def join(self, timeout=None, propagate=True, interval=0.5, diff --git a/t/integration/test_canvas.py b/t/integration/test_canvas.py index a1a641eb0ba..c35075989c5 100644 --- a/t/integration/test_canvas.py +++ b/t/integration/test_canvas.py @@ -142,6 +142,24 @@ def test_parent_ids(self, manager): assert parent_id == expected_parent_id assert value == i + 2 + @flaky + def test_nested_group(self, manager): + assert manager.inspect().ping() + + c = group( + add.si(1, 10), + group( + add.si(1, 100), + group( + add.si(1, 1000), + add.si(1, 2000), + ), + ), + ) + res = c() + + assert res.get(timeout=TIMEOUT) == [11, 101, 1001, 2001] + def assert_ids(r, expected_value, expected_root_id, expected_parent_id): root_id, parent_id, value = r.get(timeout=TIMEOUT) @@ -164,6 +182,32 @@ def test_group_chain(self, manager): res = c() assert res.get(timeout=TIMEOUT) == [12, 13, 14, 15] + @flaky + def test_nested_group_chain(self, manager): + try: + manager.app.backend.ensure_chords_allowed() + except NotImplementedError as e: + raise pytest.skip(e.args[0]) + + if not manager.app.backend.supports_native_join: + raise pytest.skip('Requires native join support.') + c = chain( + add.si(1, 0), + group( + add.si(1, 100), + chain( + add.si(1, 200), + group( + add.si(1, 1000), + add.si(1, 2000), + ), + ), + ), + add.si(1, 10), + ) + res = c() + assert res.get(timeout=TIMEOUT) == 11 + @flaky def test_parent_ids(self, manager): if not manager.app.conf.result_backend.startswith('redis'): diff --git a/t/unit/tasks/test_result.py b/t/unit/tasks/test_result.py index f21007ec988..98e9aa1cd98 100644 --- a/t/unit/tasks/test_result.py +++ b/t/unit/tasks/test_result.py @@ -519,12 +519,16 @@ def get(self, propagate=True, **kwargs): class MockAsyncResultSuccess(AsyncResult): forgotten = False + def __init__(self, *args, **kwargs): + self._result = kwargs.pop('result', 42) + super(MockAsyncResultSuccess, self).__init__(*args, **kwargs) + def forget(self): self.forgotten = True @property def result(self): - return 42 + return self._result @property def state(self): @@ -622,6 +626,37 @@ def test_forget(self): for sub in subs: assert sub.forgotten + def test_get_nested_without_native_join(self): + backend = SimpleBackend() + backend.supports_native_join = False + ts = self.app.GroupResult(uuid(), [ + MockAsyncResultSuccess(uuid(), result='1.1', + app=self.app, backend=backend), + self.app.GroupResult(uuid(), [ + MockAsyncResultSuccess(uuid(), result='2.1', + app=self.app, backend=backend), + self.app.GroupResult(uuid(), [ + MockAsyncResultSuccess(uuid(), result='3.1', + app=self.app, backend=backend), + MockAsyncResultSuccess(uuid(), result='3.2', + app=self.app, backend=backend), + ]), + ]), + ]) + ts.app.backend = backend + + vals = ts.get() + assert vals == [ + '1.1', + [ + '2.1', + [ + '3.1', + '3.2', + ] + ], + ] + def test_getitem(self): subs = [MockAsyncResultSuccess(uuid(), app=self.app), MockAsyncResultSuccess(uuid(), app=self.app)]