Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: AbortPipeline error propagated to the last step task result of pipeline #20

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions taskiq_pipelines/middleware.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from logging import getLogger
from typing import Any, List
from typing import Any, List, Optional

import pydantic
from taskiq import TaskiqMessage, TaskiqMiddleware, TaskiqResult
Expand Down Expand Up @@ -108,23 +108,28 @@ async def on_error(
return
if current_step_num == len(steps) - 1:
return
await self.fail_pipeline(steps[-1].task_id)
await self.fail_pipeline(steps[-1].task_id, result.error)

async def fail_pipeline(self, last_task_id: str) -> None:
async def fail_pipeline(
self,
last_task_id: str,
abort: Optional[BaseException] = None,
) -> None:
"""
This function aborts pipeline.

This is done by setting error result for
the last task in the pipeline.

:param last_task_id: id of the last task.
:param abort: caught earlier exception or default
"""
await self.broker.result_backend.set_result(
last_task_id,
TaskiqResult(
is_err=True,
return_value=None, # type: ignore
error=AbortPipeline("Execution aborted."),
error=abort or AbortPipeline("Execution aborted."),
execution_time=0,
log="Error found while executing pipeline.",
),
Expand Down
31 changes: 30 additions & 1 deletion tests/test_steps.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import pytest
from taskiq import InMemoryBroker

from taskiq_pipelines import Pipeline, PipelineMiddleware
from taskiq_pipelines import AbortPipeline, Pipeline, PipelineMiddleware


@pytest.mark.anyio
Expand Down Expand Up @@ -42,3 +42,32 @@ def double(i: int) -> int:
sent = await pipe.kiq(4)
res = await sent.wait_result()
assert res.return_value == list(map(double, ranger(4)))


@pytest.mark.anyio
async def test_abort_pipeline() -> None:
"""Test AbortPipeline."""
broker = InMemoryBroker().with_middlewares(PipelineMiddleware())
text = "task was aborted"

@broker.task
def normal_task(i: bool) -> bool:
return i

@broker.task
def aborting_task(i: int) -> bool:
if i:
raise AbortPipeline(text)
return True

pipe = Pipeline(broker, aborting_task).call_next(normal_task)
sent = await pipe.kiq(0)
res = await sent.wait_result()
assert res.is_err is False
assert res.return_value is True
assert res.error is None
sent = await pipe.kiq(1)
res = await sent.wait_result()
assert res.is_err is True
assert res.return_value is None
assert res.error.args[0] == text
Loading