Skip to content

Commit

Permalink
feat: Add a flag to wait for delivery callback in taskworker (#83167)
Browse files Browse the repository at this point in the history
Sometimes it is helpful to have the taskworker actually wait for the
kafka producer to acknowledge
that the sent task has been successfully added to Kafka. In particular
this is useful for testing
scenarios. Add a flag to the task definition that can determine whether
the worker should wait for
the produce to be acknowledged before continuing.
  • Loading branch information
evanh authored and andrewshie-sentry committed Jan 22, 2025
1 parent 02b73e1 commit 9c7ceab
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 4 deletions.
16 changes: 13 additions & 3 deletions src/sentry/taskworker/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ def register(
expires: int | datetime.timedelta | None = None,
processing_deadline_duration: int | datetime.timedelta | None = None,
at_most_once: bool = False,
wait_for_delivery: bool = False,
) -> Callable[[Callable[P, R]], Task[P, R]]:
"""
Register a task.
Expand All @@ -102,6 +103,9 @@ def register(
Enable at-most-once execution. Tasks with `at_most_once` cannot
define retry policies, and use a worker side idempotency key to
prevent processing deadline based retries.
wait_for_delivery: bool
If true, the task will wait for the delivery report to be received
before returning.
"""

def wrapped(func: Callable[P, R]) -> Task[P, R]:
Expand All @@ -118,6 +122,7 @@ def wrapped(func: Callable[P, R]) -> Task[P, R]:
processing_deadline_duration or self.default_processing_deadline_duration
),
at_most_once=at_most_once,
wait_for_delivery=wait_for_delivery,
)
# TODO(taskworker) tasks should be registered into the registry
# so that we can ensure task names are globally unique
Expand All @@ -126,13 +131,18 @@ def wrapped(func: Callable[P, R]) -> Task[P, R]:

return wrapped

def send_task(self, activation: TaskActivation) -> None:
def send_task(self, activation: TaskActivation, wait_for_delivery: bool = False) -> None:
metrics.incr("taskworker.registry.send_task", tags={"namespace": activation.namespace})
# TODO(taskworker) producer callback handling
self.producer.produce(

produce_future = self.producer.produce(
ArroyoTopic(name=self.topic.value),
KafkaPayload(key=None, value=activation.SerializeToString(), headers=[]),
)
if wait_for_delivery:
try:
produce_future.result(timeout=10)
except Exception:
logger.exception("Failed to wait for delivery")


class TaskRegistry:
Expand Down
6 changes: 5 additions & 1 deletion src/sentry/taskworker/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def __init__(
expires: int | datetime.timedelta | None = None,
processing_deadline_duration: int | datetime.timedelta | None = None,
at_most_once: bool = False,
wait_for_delivery: bool = False,
):
self.name = name
self._func = func
Expand All @@ -52,6 +53,7 @@ def __init__(
)
self._retry = retry
self.at_most_once = at_most_once
self.wait_for_delivery = wait_for_delivery
update_wrapper(self, func)

@property
Expand Down Expand Up @@ -84,7 +86,9 @@ def apply_async(self, *args: P.args, **kwargs: P.kwargs) -> None:
self._func(*args, **kwargs)
else:
# TODO(taskworker) promote parameters to headers
self._namespace.send_task(self.create_activation(*args, **kwargs))
self._namespace.send_task(
self.create_activation(*args, **kwargs), wait_for_delivery=self.wait_for_delivery
)

def create_activation(self, *args: P.args, **kwargs: P.kwargs) -> TaskActivation:
received_at = Timestamp()
Expand Down
5 changes: 5 additions & 0 deletions src/sentry/taskworker/tasks/examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,11 @@ def simple_task() -> None:
logger.info("simple_task complete")


@exampletasks.register(name="examples.simple_task_wait_delivery", wait_for_delivery=True)
def simple_task_wait_delivery() -> None:
logger.info("simple_task_wait_delivery complete")


@exampletasks.register(name="examples.retry_task", retry=Retry(times=2))
def retry_task() -> None:
raise RetryError
Expand Down
28 changes: 28 additions & 0 deletions tests/sentry/taskworker/test_registry.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from concurrent.futures import Future
from unittest.mock import patch

import pytest
Expand Down Expand Up @@ -178,6 +179,33 @@ def simple_task() -> None:
assert proto_message == activation.SerializeToString()


def test_namespace_with_wait_for_delivery_send_task() -> None:
namespace = TaskNamespace(
name="tests",
topic=Topic.TASK_WORKER,
retry=Retry(times=3),
)

@namespace.register(name="test.simpletask", wait_for_delivery=True)
def simple_task() -> None:
raise NotImplementedError

activation = simple_task.create_activation()

with patch.object(namespace, "producer") as mock_producer:
ret_value: Future[None] = Future()
ret_value.set_result(None)
mock_producer.produce.return_value = ret_value
namespace.send_task(activation, wait_for_delivery=True)
assert mock_producer.produce.call_count == 1

mock_call = mock_producer.produce.call_args
assert mock_call[0][0].name == "task-worker"

proto_message = mock_call[0][1].value
assert proto_message == activation.SerializeToString()


def test_registry_get() -> None:
registry = TaskRegistry()
ns = registry.create_namespace(name="tests")
Expand Down

0 comments on commit 9c7ceab

Please sign in to comment.