Skip to content

Commit

Permalink
Added result parsing on return.
Browse files Browse the repository at this point in the history
  • Loading branch information
s3rius committed Mar 2, 2025
1 parent 826c31e commit 9235a09
Show file tree
Hide file tree
Showing 5 changed files with 33 additions and 2 deletions.
8 changes: 8 additions & 0 deletions taskiq/abc/broker.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,12 @@
Optional,
TypeVar,
Union,
get_type_hints,
overload,
)
from uuid import uuid4

from pydantic import TypeAdapter
from typing_extensions import ParamSpec, Self, TypeAlias

from taskiq.abc.middleware import TaskiqMiddleware
Expand Down Expand Up @@ -326,12 +328,18 @@ def inner(
inner_task_name = f"{fmodule}:{fname}"
wrapper = wraps(func)

sign = get_type_hints(func)
return_type = None
if "return" in sign:
return_type = TypeAdapter(sign["return"])

decorated_task = wrapper(
self.decorator_class(
broker=self,
original_func=func,
labels=inner_labels,
task_name=inner_task_name,
return_type=return_type,
),
)

Expand Down
1 change: 1 addition & 0 deletions taskiq/brokers/shared_broker.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def kicker(self) -> AsyncKicker[_Params, _ReturnType]:
task_name=self.task_name,
broker=broker,
labels=self.labels,
return_type=self.return_type,
)


Expand Down
5 changes: 5 additions & 0 deletions taskiq/decor.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,13 @@
Callable,
Dict,
Generic,
Optional,
TypeVar,
Union,
overload,
)

from pydantic import TypeAdapter
from typing_extensions import ParamSpec

from taskiq.kicker import AsyncKicker
Expand Down Expand Up @@ -50,11 +52,13 @@ def __init__(
task_name: str,
original_func: Callable[_FuncParams, _ReturnType],
labels: Dict[str, Any],
return_type: Optional[TypeAdapter[_ReturnType]] = None,
) -> None:
self.broker = broker
self.task_name = task_name
self.original_func = original_func
self.labels = labels
self.return_type = return_type

# Docs for this method are omitted in order to help
# your IDE resolve correct docs for it.
Expand Down Expand Up @@ -172,6 +176,7 @@ def kicker(self) -> AsyncKicker[_FuncParams, _ReturnType]:
task_name=self.task_name,
broker=self.broker,
labels=self.labels,
return_type=self.return_type,
)

def __repr__(self) -> str:
Expand Down
5 changes: 4 additions & 1 deletion taskiq/kicker.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
overload,
)

from pydantic import BaseModel
from pydantic import BaseModel, TypeAdapter
from typing_extensions import ParamSpec

from taskiq.abc.middleware import TaskiqMiddleware
Expand Down Expand Up @@ -46,12 +46,14 @@ def __init__(
task_name: str,
broker: "AsyncBroker",
labels: Dict[str, Any],
return_type: Optional[TypeAdapter[_ReturnType]] = None,
) -> None:
self.task_name = task_name
self.broker = broker
self.labels = labels
self.custom_task_id: Optional[str] = None
self.custom_schedule_id: Optional[str] = None
self.return_type = return_type

def with_labels(
self,
Expand Down Expand Up @@ -169,6 +171,7 @@ async def kiq(
return AsyncTaskiqTask(
task_id=message.task_id,
result_backend=self.broker.result_backend,
return_type=self.return_type, # type: ignore # (pyright issue)
)

async def schedule_by_cron(
Expand Down
16 changes: 15 additions & 1 deletion taskiq/task.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import asyncio
from logging import getLogger
from time import time
from typing import TYPE_CHECKING, Any, Generic, Optional

from pydantic import TypeAdapter
from typing_extensions import TypeVar

from taskiq.exceptions import (
Expand All @@ -15,6 +17,8 @@
from taskiq.depends.progress_tracker import TaskProgress
from taskiq.result import TaskiqResult

logger = getLogger("taskiq.task")

_ReturnType = TypeVar("_ReturnType")


Expand All @@ -25,9 +29,11 @@ def __init__(
self,
task_id: str,
result_backend: "AsyncResultBackend[_ReturnType]",
return_type: Optional[TypeAdapter[_ReturnType]] = None,
) -> None:
self.task_id = task_id
self.result_backend = result_backend
self.return_type = return_type

async def is_ready(self) -> bool:
"""
Expand All @@ -53,10 +59,18 @@ async def get_result(self, with_logs: bool = False) -> "TaskiqResult[_ReturnType
:return: task's return value.
"""
try:
return await self.result_backend.get_result(
res = await self.result_backend.get_result(
self.task_id,
with_logs=with_logs,
)
if self.return_type is not None:
try:
res.return_value = self.return_type.validate_python(
res.return_value,
)
except ValueError:
logger.warning("Cannot parse return type into %s", self.return_type)
return res
except Exception as exc:
raise ResultGetError from exc

Expand Down

0 comments on commit 9235a09

Please sign in to comment.