diff --git a/agents-api/agents_api/routers/docs/create_doc.py b/agents-api/agents_api/routers/docs/create_doc.py index 0ba22c8d5..aa2ba719a 100644 --- a/agents-api/agents_api/routers/docs/create_doc.py +++ b/agents-api/agents_api/routers/docs/create_doc.py @@ -1,6 +1,7 @@ from typing import Annotated from uuid import UUID, uuid4 +from ...common.retry_policies import DEFAULT_RETRY_POLICY from fastapi import BackgroundTasks, Depends from starlette.status import HTTP_201_CREATED from temporalio.client import Client as TemporalClient @@ -41,6 +42,7 @@ async def run_embed_docs_task( embed_payload, task_queue=temporal_task_queue, id=str(job_id), + retry_policy=DEFAULT_RETRY_POLICY ) # TODO: Remove this conditional once we have a way to run workflows in diff --git a/agents-api/agents_api/workflows/demo.py b/agents-api/agents_api/workflows/demo.py index 61ad9d4a8..e5725065e 100644 --- a/agents-api/agents_api/workflows/demo.py +++ b/agents-api/agents_api/workflows/demo.py @@ -2,6 +2,8 @@ from temporalio import workflow +from ..common.retry_policies import DEFAULT_RETRY_POLICY + with workflow.unsafe.imports_passed_through(): from ..activities.demo import demo_activity @@ -14,4 +16,5 @@ async def run(self, a: int, b: int) -> int: demo_activity, args=[a, b], start_to_close_timeout=timedelta(seconds=30), + retry_policy=DEFAULT_RETRY_POLICY ) diff --git a/agents-api/agents_api/workflows/embed_docs.py b/agents-api/agents_api/workflows/embed_docs.py index 62e0e65ae..e04ebd0f5 100644 --- a/agents-api/agents_api/workflows/embed_docs.py +++ b/agents-api/agents_api/workflows/embed_docs.py @@ -8,6 +8,7 @@ with workflow.unsafe.imports_passed_through(): from ..activities.embed_docs import embed_docs from ..activities.types import EmbedDocsPayload + from ..common.retry_policies import DEFAULT_RETRY_POLICY @workflow.defn @@ -18,4 +19,5 @@ async def run(self, embed_payload: EmbedDocsPayload) -> None: embed_docs, embed_payload, schedule_to_close_timeout=timedelta(seconds=600), + retry_policy=DEFAULT_RETRY_POLICY ) diff --git a/agents-api/agents_api/workflows/mem_rating.py b/agents-api/agents_api/workflows/mem_rating.py index 4b68a7198..ffcc4bb93 100644 --- a/agents-api/agents_api/workflows/mem_rating.py +++ b/agents-api/agents_api/workflows/mem_rating.py @@ -7,6 +7,7 @@ with workflow.unsafe.imports_passed_through(): from ..activities.mem_rating import mem_rating + from ..common.retry_policies import DEFAULT_RETRY_POLICY @workflow.defn @@ -17,4 +18,5 @@ async def run(self, memory: str) -> None: mem_rating, memory, schedule_to_close_timeout=timedelta(seconds=600), + retry_policy=DEFAULT_RETRY_POLICY ) diff --git a/agents-api/agents_api/workflows/summarization.py b/agents-api/agents_api/workflows/summarization.py index 7946e9109..e117930a9 100644 --- a/agents-api/agents_api/workflows/summarization.py +++ b/agents-api/agents_api/workflows/summarization.py @@ -7,6 +7,7 @@ with workflow.unsafe.imports_passed_through(): from ..activities.summarization import summarization + from ..common.retry_policies import DEFAULT_RETRY_POLICY @workflow.defn @@ -17,4 +18,5 @@ async def run(self, session_id: str) -> None: summarization, session_id, schedule_to_close_timeout=timedelta(seconds=600), + retry_policy=DEFAULT_RETRY_POLICY ) diff --git a/agents-api/agents_api/workflows/task_execution/__init__.py b/agents-api/agents_api/workflows/task_execution/__init__.py index 80863e0e0..ea0797b11 100644 --- a/agents-api/agents_api/workflows/task_execution/__init__.py +++ b/agents-api/agents_api/workflows/task_execution/__init__.py @@ -387,6 +387,7 @@ async def run( task_steps.raise_complete_async, args=[context, output], schedule_to_close_timeout=timedelta(days=31), + retry_policy=DEFAULT_RETRY_POLICY ) state = PartialTransition(type="resume", output=result) @@ -419,6 +420,7 @@ async def run( task_steps.raise_complete_async, args=[context, tool_calls_input], schedule_to_close_timeout=timedelta(days=31), + retry_policy=DEFAULT_RETRY_POLICY ) # Feed the tool call results back to the model @@ -430,6 +432,7 @@ async def run( schedule_to_close_timeout=timedelta( seconds=30 if debug or testing else 600 ), + retry_policy=DEFAULT_RETRY_POLICY ) state = PartialTransition(output=new_response.output, type="resume") @@ -473,6 +476,7 @@ async def run( task_steps.raise_complete_async, args=[context, tool_call], schedule_to_close_timeout=timedelta(days=31), + retry_policy=DEFAULT_RETRY_POLICY ) state = PartialTransition(output=tool_call_response, type="resume") @@ -503,6 +507,7 @@ async def run( schedule_to_close_timeout=timedelta( seconds=30 if debug or testing else 600 ), + retry_policy=DEFAULT_RETRY_POLICY ) state = PartialTransition(output=tool_call_response) diff --git a/agents-api/agents_api/workflows/task_execution/helpers.py b/agents-api/agents_api/workflows/task_execution/helpers.py index cdf499783..fb3e104d1 100644 --- a/agents-api/agents_api/workflows/task_execution/helpers.py +++ b/agents-api/agents_api/workflows/task_execution/helpers.py @@ -170,6 +170,7 @@ async def execute_map_reduce_step( task_steps.base_evaluate, args=[reduce, {"results": result, "_": output}], schedule_to_close_timeout=timedelta(seconds=30), + retry_policy=DEFAULT_RETRY_POLICY ) return result @@ -245,6 +246,7 @@ async def execute_map_reduce_step_parallel( extra_lambda_strs, ], schedule_to_close_timeout=timedelta(seconds=30), + retry_policy=DEFAULT_RETRY_POLICY ) except BaseException as e: diff --git a/agents-api/agents_api/workflows/truncation.py b/agents-api/agents_api/workflows/truncation.py index d3646ccbe..2a84b2c5f 100644 --- a/agents-api/agents_api/workflows/truncation.py +++ b/agents-api/agents_api/workflows/truncation.py @@ -7,6 +7,7 @@ with workflow.unsafe.imports_passed_through(): from ..activities.truncation import truncation + from ..common.retry_policies import DEFAULT_RETRY_POLICY @workflow.defn @@ -17,4 +18,5 @@ async def run(self, session_id: str, token_count_threshold: int) -> None: truncation, args=[session_id, token_count_threshold], schedule_to_close_timeout=timedelta(seconds=600), + retry_policy=DEFAULT_RETRY_POLICY ) diff --git a/agents-api/tests/test_activities.py b/agents-api/tests/test_activities.py index 98dfc97b5..987ba5b01 100644 --- a/agents-api/tests/test_activities.py +++ b/agents-api/tests/test_activities.py @@ -7,6 +7,7 @@ from agents_api.clients import temporal from agents_api.env import temporal_task_queue from agents_api.workflows.demo import DemoWorkflow +from agents_api.workflows.task_execution.helpers import DEFAULT_RETRY_POLICY from .fixtures import ( cozo_client, @@ -49,6 +50,7 @@ async def _(): args=[1, 2], id=str(uuid4()), task_queue=temporal_task_queue, + retry_policy=DEFAULT_RETRY_POLICY ) assert result == 3