diff --git a/src/sentry/taskworker/registry.py b/src/sentry/taskworker/registry.py index 90ab9d4e8b55e4..1d1945c38f1620 100644 --- a/src/sentry/taskworker/registry.py +++ b/src/sentry/taskworker/registry.py @@ -80,6 +80,7 @@ def register( expires: int | datetime.timedelta | None = None, processing_deadline_duration: int | datetime.timedelta | None = None, at_most_once: bool = False, + wait_for_delivery: bool = False, ) -> Callable[[Callable[P, R]], Task[P, R]]: """ Register a task. @@ -102,6 +103,9 @@ def register( Enable at-most-once execution. Tasks with `at_most_once` cannot define retry policies, and use a worker side idempotency key to prevent processing deadline based retries. + wait_for_delivery: bool + If true, the task will wait for the delivery report to be received + before returning. """ def wrapped(func: Callable[P, R]) -> Task[P, R]: @@ -118,6 +122,7 @@ def wrapped(func: Callable[P, R]) -> Task[P, R]: processing_deadline_duration or self.default_processing_deadline_duration ), at_most_once=at_most_once, + wait_for_delivery=wait_for_delivery, ) # TODO(taskworker) tasks should be registered into the registry # so that we can ensure task names are globally unique @@ -126,13 +131,18 @@ def wrapped(func: Callable[P, R]) -> Task[P, R]: return wrapped - def send_task(self, activation: TaskActivation) -> None: + def send_task(self, activation: TaskActivation, wait_for_delivery: bool = False) -> None: metrics.incr("taskworker.registry.send_task", tags={"namespace": activation.namespace}) - # TODO(taskworker) producer callback handling - self.producer.produce( + + produce_future = self.producer.produce( ArroyoTopic(name=self.topic.value), KafkaPayload(key=None, value=activation.SerializeToString(), headers=[]), ) + if wait_for_delivery: + try: + produce_future.result(timeout=10) + except Exception: + logger.exception("Failed to wait for delivery") class TaskRegistry: diff --git a/src/sentry/taskworker/task.py b/src/sentry/taskworker/task.py index 920bcd727fc148..b72ed123df67a7 100644 --- a/src/sentry/taskworker/task.py +++ b/src/sentry/taskworker/task.py @@ -34,6 +34,7 @@ def __init__( expires: int | datetime.timedelta | None = None, processing_deadline_duration: int | datetime.timedelta | None = None, at_most_once: bool = False, + wait_for_delivery: bool = False, ): self.name = name self._func = func @@ -52,6 +53,7 @@ def __init__( ) self._retry = retry self.at_most_once = at_most_once + self.wait_for_delivery = wait_for_delivery update_wrapper(self, func) @property @@ -84,7 +86,9 @@ def apply_async(self, *args: P.args, **kwargs: P.kwargs) -> None: self._func(*args, **kwargs) else: # TODO(taskworker) promote parameters to headers - self._namespace.send_task(self.create_activation(*args, **kwargs)) + self._namespace.send_task( + self.create_activation(*args, **kwargs), wait_for_delivery=self.wait_for_delivery + ) def create_activation(self, *args: P.args, **kwargs: P.kwargs) -> TaskActivation: received_at = Timestamp() diff --git a/src/sentry/taskworker/tasks/examples.py b/src/sentry/taskworker/tasks/examples.py index ad623fbe1b706a..cfba8a32936e3d 100644 --- a/src/sentry/taskworker/tasks/examples.py +++ b/src/sentry/taskworker/tasks/examples.py @@ -41,6 +41,11 @@ def simple_task() -> None: logger.info("simple_task complete") +@exampletasks.register(name="examples.simple_task_wait_delivery", wait_for_delivery=True) +def simple_task_wait_delivery() -> None: + logger.info("simple_task_wait_delivery complete") + + @exampletasks.register(name="examples.retry_task", retry=Retry(times=2)) def retry_task() -> None: raise RetryError diff --git a/tests/sentry/taskworker/test_registry.py b/tests/sentry/taskworker/test_registry.py index d100441492b87e..c9bfa9119d294f 100644 --- a/tests/sentry/taskworker/test_registry.py +++ b/tests/sentry/taskworker/test_registry.py @@ -1,3 +1,4 @@ +from concurrent.futures import Future from unittest.mock import patch import pytest @@ -178,6 +179,33 @@ def simple_task() -> None: assert proto_message == activation.SerializeToString() +def test_namespace_with_wait_for_delivery_send_task() -> None: + namespace = TaskNamespace( + name="tests", + topic=Topic.TASK_WORKER, + retry=Retry(times=3), + ) + + @namespace.register(name="test.simpletask", wait_for_delivery=True) + def simple_task() -> None: + raise NotImplementedError + + activation = simple_task.create_activation() + + with patch.object(namespace, "producer") as mock_producer: + ret_value: Future[None] = Future() + ret_value.set_result(None) + mock_producer.produce.return_value = ret_value + namespace.send_task(activation, wait_for_delivery=True) + assert mock_producer.produce.call_count == 1 + + mock_call = mock_producer.produce.call_args + assert mock_call[0][0].name == "task-worker" + + proto_message = mock_call[0][1].value + assert proto_message == activation.SerializeToString() + + def test_registry_get() -> None: registry = TaskRegistry() ns = registry.create_namespace(name="tests")