From 9c7ceabc2a8861eaaf178a86ab2aa87f9accd331 Mon Sep 17 00:00:00 2001 From: Evan Hicks Date: Fri, 10 Jan 2025 15:27:40 -0500 Subject: [PATCH] feat: Add a flag to wait for delivery callback in taskworker (#83167) Sometimes it is helpful to have the taskworker actually wait for the kafka producer to acknowledge that the sent task has been successfully added to Kafka. In particular this is useful for testing scenarios. Add a flag to the task definition that can determine whether the worker should wait for the produce to be acknowledged before continuing. --- src/sentry/taskworker/registry.py | 16 +++++++++++--- src/sentry/taskworker/task.py | 6 ++++- src/sentry/taskworker/tasks/examples.py | 5 +++++ tests/sentry/taskworker/test_registry.py | 28 ++++++++++++++++++++++++ 4 files changed, 51 insertions(+), 4 deletions(-) diff --git a/src/sentry/taskworker/registry.py b/src/sentry/taskworker/registry.py index 90ab9d4e8b55e4..1d1945c38f1620 100644 --- a/src/sentry/taskworker/registry.py +++ b/src/sentry/taskworker/registry.py @@ -80,6 +80,7 @@ def register( expires: int | datetime.timedelta | None = None, processing_deadline_duration: int | datetime.timedelta | None = None, at_most_once: bool = False, + wait_for_delivery: bool = False, ) -> Callable[[Callable[P, R]], Task[P, R]]: """ Register a task. @@ -102,6 +103,9 @@ def register( Enable at-most-once execution. Tasks with `at_most_once` cannot define retry policies, and use a worker side idempotency key to prevent processing deadline based retries. + wait_for_delivery: bool + If true, the task will wait for the delivery report to be received + before returning. """ def wrapped(func: Callable[P, R]) -> Task[P, R]: @@ -118,6 +122,7 @@ def wrapped(func: Callable[P, R]) -> Task[P, R]: processing_deadline_duration or self.default_processing_deadline_duration ), at_most_once=at_most_once, + wait_for_delivery=wait_for_delivery, ) # TODO(taskworker) tasks should be registered into the registry # so that we can ensure task names are globally unique @@ -126,13 +131,18 @@ def wrapped(func: Callable[P, R]) -> Task[P, R]: return wrapped - def send_task(self, activation: TaskActivation) -> None: + def send_task(self, activation: TaskActivation, wait_for_delivery: bool = False) -> None: metrics.incr("taskworker.registry.send_task", tags={"namespace": activation.namespace}) - # TODO(taskworker) producer callback handling - self.producer.produce( + + produce_future = self.producer.produce( ArroyoTopic(name=self.topic.value), KafkaPayload(key=None, value=activation.SerializeToString(), headers=[]), ) + if wait_for_delivery: + try: + produce_future.result(timeout=10) + except Exception: + logger.exception("Failed to wait for delivery") class TaskRegistry: diff --git a/src/sentry/taskworker/task.py b/src/sentry/taskworker/task.py index 920bcd727fc148..b72ed123df67a7 100644 --- a/src/sentry/taskworker/task.py +++ b/src/sentry/taskworker/task.py @@ -34,6 +34,7 @@ def __init__( expires: int | datetime.timedelta | None = None, processing_deadline_duration: int | datetime.timedelta | None = None, at_most_once: bool = False, + wait_for_delivery: bool = False, ): self.name = name self._func = func @@ -52,6 +53,7 @@ def __init__( ) self._retry = retry self.at_most_once = at_most_once + self.wait_for_delivery = wait_for_delivery update_wrapper(self, func) @property @@ -84,7 +86,9 @@ def apply_async(self, *args: P.args, **kwargs: P.kwargs) -> None: self._func(*args, **kwargs) else: # TODO(taskworker) promote parameters to headers - self._namespace.send_task(self.create_activation(*args, **kwargs)) + self._namespace.send_task( + self.create_activation(*args, **kwargs), wait_for_delivery=self.wait_for_delivery + ) def create_activation(self, *args: P.args, **kwargs: P.kwargs) -> TaskActivation: received_at = Timestamp() diff --git a/src/sentry/taskworker/tasks/examples.py b/src/sentry/taskworker/tasks/examples.py index ad623fbe1b706a..cfba8a32936e3d 100644 --- a/src/sentry/taskworker/tasks/examples.py +++ b/src/sentry/taskworker/tasks/examples.py @@ -41,6 +41,11 @@ def simple_task() -> None: logger.info("simple_task complete") +@exampletasks.register(name="examples.simple_task_wait_delivery", wait_for_delivery=True) +def simple_task_wait_delivery() -> None: + logger.info("simple_task_wait_delivery complete") + + @exampletasks.register(name="examples.retry_task", retry=Retry(times=2)) def retry_task() -> None: raise RetryError diff --git a/tests/sentry/taskworker/test_registry.py b/tests/sentry/taskworker/test_registry.py index d100441492b87e..c9bfa9119d294f 100644 --- a/tests/sentry/taskworker/test_registry.py +++ b/tests/sentry/taskworker/test_registry.py @@ -1,3 +1,4 @@ +from concurrent.futures import Future from unittest.mock import patch import pytest @@ -178,6 +179,33 @@ def simple_task() -> None: assert proto_message == activation.SerializeToString() +def test_namespace_with_wait_for_delivery_send_task() -> None: + namespace = TaskNamespace( + name="tests", + topic=Topic.TASK_WORKER, + retry=Retry(times=3), + ) + + @namespace.register(name="test.simpletask", wait_for_delivery=True) + def simple_task() -> None: + raise NotImplementedError + + activation = simple_task.create_activation() + + with patch.object(namespace, "producer") as mock_producer: + ret_value: Future[None] = Future() + ret_value.set_result(None) + mock_producer.produce.return_value = ret_value + namespace.send_task(activation, wait_for_delivery=True) + assert mock_producer.produce.call_count == 1 + + mock_call = mock_producer.produce.call_args + assert mock_call[0][0].name == "task-worker" + + proto_message = mock_call[0][1].value + assert proto_message == activation.SerializeToString() + + def test_registry_get() -> None: registry = TaskRegistry() ns = registry.create_namespace(name="tests")