Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add a flag to wait for delivery callback in taskworker #83167

Merged
merged 4 commits into from
Jan 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 13 additions & 3 deletions src/sentry/taskworker/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ def register(
expires: int | datetime.timedelta | None = None,
processing_deadline_duration: int | datetime.timedelta | None = None,
at_most_once: bool = False,
wait_for_delivery: bool = False,
) -> Callable[[Callable[P, R]], Task[P, R]]:
"""
Register a task.
Expand All @@ -102,6 +103,9 @@ def register(
Enable at-most-once execution. Tasks with `at_most_once` cannot
define retry policies, and use a worker side idempotency key to
prevent processing deadline based retries.
wait_for_delivery: bool
If true, the task will wait for the delivery report to be received
before returning.
"""

def wrapped(func: Callable[P, R]) -> Task[P, R]:
Expand All @@ -118,6 +122,7 @@ def wrapped(func: Callable[P, R]) -> Task[P, R]:
processing_deadline_duration or self.default_processing_deadline_duration
),
at_most_once=at_most_once,
wait_for_delivery=wait_for_delivery,
)
# TODO(taskworker) tasks should be registered into the registry
# so that we can ensure task names are globally unique
Expand All @@ -126,13 +131,18 @@ def wrapped(func: Callable[P, R]) -> Task[P, R]:

return wrapped

def send_task(self, activation: TaskActivation) -> None:
def send_task(self, activation: TaskActivation, wait_for_delivery: bool = False) -> None:
metrics.incr("taskworker.registry.send_task", tags={"namespace": activation.namespace})
# TODO(taskworker) producer callback handling
self.producer.produce(

produce_future = self.producer.produce(
ArroyoTopic(name=self.topic.value),
KafkaPayload(key=None, value=activation.SerializeToString(), headers=[]),
)
if wait_for_delivery:
try:
produce_future.result(timeout=10)
except Exception:
logger.exception("Failed to wait for delivery")


class TaskRegistry:
Expand Down
6 changes: 5 additions & 1 deletion src/sentry/taskworker/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def __init__(
expires: int | datetime.timedelta | None = None,
processing_deadline_duration: int | datetime.timedelta | None = None,
at_most_once: bool = False,
wait_for_delivery: bool = False,
):
self.name = name
self._func = func
Expand All @@ -52,6 +53,7 @@ def __init__(
)
self._retry = retry
self.at_most_once = at_most_once
self.wait_for_delivery = wait_for_delivery
update_wrapper(self, func)

@property
Expand Down Expand Up @@ -84,7 +86,9 @@ def apply_async(self, *args: P.args, **kwargs: P.kwargs) -> None:
self._func(*args, **kwargs)
else:
# TODO(taskworker) promote parameters to headers
self._namespace.send_task(self.create_activation(*args, **kwargs))
self._namespace.send_task(
self.create_activation(*args, **kwargs), wait_for_delivery=self.wait_for_delivery
)

def create_activation(self, *args: P.args, **kwargs: P.kwargs) -> TaskActivation:
received_at = Timestamp()
Expand Down
5 changes: 5 additions & 0 deletions src/sentry/taskworker/tasks/examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,11 @@ def simple_task() -> None:
logger.info("simple_task complete")


@exampletasks.register(name="examples.simple_task_wait_delivery", wait_for_delivery=True)
def simple_task_wait_delivery() -> None:
logger.info("simple_task_wait_delivery complete")


@exampletasks.register(name="examples.retry_task", retry=Retry(times=2))
def retry_task() -> None:
raise RetryError
Expand Down
28 changes: 28 additions & 0 deletions tests/sentry/taskworker/test_registry.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from concurrent.futures import Future
from unittest.mock import patch

import pytest
Expand Down Expand Up @@ -178,6 +179,33 @@ def simple_task() -> None:
assert proto_message == activation.SerializeToString()


def test_namespace_with_wait_for_delivery_send_task() -> None:
namespace = TaskNamespace(
name="tests",
topic=Topic.TASK_WORKER,
retry=Retry(times=3),
)

@namespace.register(name="test.simpletask", wait_for_delivery=True)
def simple_task() -> None:
raise NotImplementedError

activation = simple_task.create_activation()

with patch.object(namespace, "producer") as mock_producer:
evanh marked this conversation as resolved.
Show resolved Hide resolved
ret_value: Future[None] = Future()
ret_value.set_result(None)
mock_producer.produce.return_value = ret_value
namespace.send_task(activation, wait_for_delivery=True)
assert mock_producer.produce.call_count == 1

mock_call = mock_producer.produce.call_args
assert mock_call[0][0].name == "task-worker"

proto_message = mock_call[0][1].value
assert proto_message == activation.SerializeToString()


def test_registry_get() -> None:
registry = TaskRegistry()
ns = registry.create_namespace(name="tests")
Expand Down
Loading