diff --git a/agents-api/agents_api/clients/temporal.py b/agents-api/agents_api/clients/temporal.py index 5f14b84f6..4bb25cbc9 100644 --- a/agents-api/agents_api/clients/temporal.py +++ b/agents-api/agents_api/clients/temporal.py @@ -5,6 +5,7 @@ from ..autogen.openapi_model import TransitionTarget from ..common.protocol.tasks import ExecutionInput +from ..common.retry_policies import DEFAULT_RETRY_POLICY from ..env import ( temporal_client_cert, temporal_namespace, @@ -54,6 +55,7 @@ async def run_task_execution_workflow( task_queue=temporal_task_queue, id=str(job_id), run_timeout=timedelta(days=31), + retry_policy=DEFAULT_RETRY_POLICY, # TODO: Should add search_attributes for queryability ) diff --git a/agents-api/agents_api/common/retry_policies.py b/agents-api/agents_api/common/retry_policies.py new file mode 100644 index 000000000..fc343553c --- /dev/null +++ b/agents-api/agents_api/common/retry_policies.py @@ -0,0 +1,63 @@ +from datetime import timedelta + +from temporalio.common import RetryPolicy + +DEFAULT_RETRY_POLICY = RetryPolicy( + initial_interval=timedelta(seconds=1), + backoff_coefficient=2, + maximum_attempts=25, + maximum_interval=timedelta(seconds=300), + non_retryable_error_types=[ + # Temporal-specific errors + "WorkflowExecutionAlreadyStarted", + "temporalio.exceptions.TerminalFailure", + "temporalio.exceptions.CanceledError", + # + # Built-in Python exceptions + "TypeError", + "AssertionError", + "SyntaxError", + "ValueError", + "ZeroDivisionError", + "IndexError", + "AttributeError", + "LookupError", + "BufferError", + "ArithmeticError", + "KeyError", + "NameError", + "NotImplementedError", + "RecursionError", + "RuntimeError", + "StopIteration", + "StopAsyncIteration", + "IndentationError", + "TabError", + # + # Unicode-related errors + "UnicodeError", + "UnicodeEncodeError", + "UnicodeDecodeError", + "UnicodeTranslateError", + # + # HTTP and API-related errors + "HTTPException", + "fastapi.exceptions.HTTPException", + "fastapi.exceptions.RequestValidationError", + "httpx.RequestError", + "httpx.HTTPStatusError", + # + # Asynchronous programming errors + "asyncio.CancelledError", + "asyncio.InvalidStateError", + "GeneratorExit", + # + # Third-party library exceptions + "jinja2.exceptions.TemplateSyntaxError", + "jinja2.exceptions.TemplateNotFound", + "jsonschema.exceptions.ValidationError", + "pydantic.ValidationError", + "requests.exceptions.InvalidURL", + "requests.exceptions.MissingSchema", + ], +) diff --git a/agents-api/agents_api/routers/docs/create_doc.py b/agents-api/agents_api/routers/docs/create_doc.py index 0ba22c8d5..16754dbe3 100644 --- a/agents-api/agents_api/routers/docs/create_doc.py +++ b/agents-api/agents_api/routers/docs/create_doc.py @@ -8,6 +8,7 @@ from ...activities.types import EmbedDocsPayload from ...autogen.openapi_model import CreateDocRequest, ResourceCreatedResponse from ...clients import temporal +from ...common.retry_policies import DEFAULT_RETRY_POLICY from ...dependencies.developer_id import get_developer_id from ...env import temporal_task_queue, testing from ...models.docs.create_doc import create_doc as create_doc_query @@ -41,6 +42,7 @@ async def run_embed_docs_task( embed_payload, task_queue=temporal_task_queue, id=str(job_id), + retry_policy=DEFAULT_RETRY_POLICY, ) # TODO: Remove this conditional once we have a way to run workflows in diff --git a/agents-api/agents_api/workflows/demo.py b/agents-api/agents_api/workflows/demo.py index 61ad9d4a8..0599a4392 100644 --- a/agents-api/agents_api/workflows/demo.py +++ b/agents-api/agents_api/workflows/demo.py @@ -2,6 +2,8 @@ from temporalio import workflow +from ..common.retry_policies import DEFAULT_RETRY_POLICY + with workflow.unsafe.imports_passed_through(): from ..activities.demo import demo_activity @@ -14,4 +16,5 @@ async def run(self, a: int, b: int) -> int: demo_activity, args=[a, b], start_to_close_timeout=timedelta(seconds=30), + retry_policy=DEFAULT_RETRY_POLICY, ) diff --git a/agents-api/agents_api/workflows/embed_docs.py b/agents-api/agents_api/workflows/embed_docs.py index 62e0e65ae..83eefe907 100644 --- a/agents-api/agents_api/workflows/embed_docs.py +++ b/agents-api/agents_api/workflows/embed_docs.py @@ -8,6 +8,7 @@ with workflow.unsafe.imports_passed_through(): from ..activities.embed_docs import embed_docs from ..activities.types import EmbedDocsPayload + from ..common.retry_policies import DEFAULT_RETRY_POLICY @workflow.defn @@ -18,4 +19,5 @@ async def run(self, embed_payload: EmbedDocsPayload) -> None: embed_docs, embed_payload, schedule_to_close_timeout=timedelta(seconds=600), + retry_policy=DEFAULT_RETRY_POLICY, ) diff --git a/agents-api/agents_api/workflows/mem_rating.py b/agents-api/agents_api/workflows/mem_rating.py index 4b68a7198..0a87fd787 100644 --- a/agents-api/agents_api/workflows/mem_rating.py +++ b/agents-api/agents_api/workflows/mem_rating.py @@ -7,6 +7,7 @@ with workflow.unsafe.imports_passed_through(): from ..activities.mem_rating import mem_rating + from ..common.retry_policies import DEFAULT_RETRY_POLICY @workflow.defn @@ -17,4 +18,5 @@ async def run(self, memory: str) -> None: mem_rating, memory, schedule_to_close_timeout=timedelta(seconds=600), + retry_policy=DEFAULT_RETRY_POLICY, ) diff --git a/agents-api/agents_api/workflows/summarization.py b/agents-api/agents_api/workflows/summarization.py index 7946e9109..96ce4c460 100644 --- a/agents-api/agents_api/workflows/summarization.py +++ b/agents-api/agents_api/workflows/summarization.py @@ -7,6 +7,7 @@ with workflow.unsafe.imports_passed_through(): from ..activities.summarization import summarization + from ..common.retry_policies import DEFAULT_RETRY_POLICY @workflow.defn @@ -17,4 +18,5 @@ async def run(self, session_id: str) -> None: summarization, session_id, schedule_to_close_timeout=timedelta(seconds=600), + retry_policy=DEFAULT_RETRY_POLICY, ) diff --git a/agents-api/agents_api/workflows/task_execution/__init__.py b/agents-api/agents_api/workflows/task_execution/__init__.py index 3c6106267..d26a3d999 100644 --- a/agents-api/agents_api/workflows/task_execution/__init__.py +++ b/agents-api/agents_api/workflows/task_execution/__init__.py @@ -47,6 +47,7 @@ StepContext, StepOutcome, ) + from ...common.retry_policies import DEFAULT_RETRY_POLICY from ...env import debug, testing from .helpers import ( continue_as_child, @@ -58,6 +59,7 @@ ) from .transition import transition + # Supported steps # --------------- @@ -204,6 +206,7 @@ async def run( schedule_to_close_timeout=timedelta( seconds=30 if debug or testing else 600 ), + retry_policy=DEFAULT_RETRY_POLICY, ) workflow.logger.debug( f"Step {context.cursor.step} completed successfully" @@ -389,6 +392,7 @@ async def run( task_steps.raise_complete_async, args=[context, output], schedule_to_close_timeout=timedelta(days=31), + retry_policy=DEFAULT_RETRY_POLICY, ) state = PartialTransition(type="resume", output=result) @@ -421,6 +425,7 @@ async def run( task_steps.raise_complete_async, args=[context, tool_calls_input], schedule_to_close_timeout=timedelta(days=31), + retry_policy=DEFAULT_RETRY_POLICY, ) # Feed the tool call results back to the model @@ -432,6 +437,7 @@ async def run( schedule_to_close_timeout=timedelta( seconds=30 if debug or testing else 600 ), + retry_policy=DEFAULT_RETRY_POLICY, ) state = PartialTransition(output=new_response.output, type="resume") @@ -475,6 +481,7 @@ async def run( task_steps.raise_complete_async, args=[context, tool_call], schedule_to_close_timeout=timedelta(days=31), + retry_policy=DEFAULT_RETRY_POLICY, ) state = PartialTransition(output=tool_call_response, type="resume") @@ -505,6 +512,7 @@ async def run( schedule_to_close_timeout=timedelta( seconds=30 if debug or testing else 600 ), + retry_policy=DEFAULT_RETRY_POLICY, ) state = PartialTransition(output=tool_call_response) diff --git a/agents-api/agents_api/workflows/task_execution/helpers.py b/agents-api/agents_api/workflows/task_execution/helpers.py index 88828b31b..04449db58 100644 --- a/agents-api/agents_api/workflows/task_execution/helpers.py +++ b/agents-api/agents_api/workflows/task_execution/helpers.py @@ -5,6 +5,8 @@ from temporalio import workflow from temporalio.exceptions import ApplicationError +from ...common.retry_policies import DEFAULT_RETRY_POLICY + with workflow.unsafe.imports_passed_through(): from ...activities import task_steps from ...autogen.openapi_model import ( @@ -33,6 +35,7 @@ async def continue_as_child( previous_inputs, user_state, ], + retry_policy=DEFAULT_RETRY_POLICY, ) @@ -169,6 +172,7 @@ async def execute_map_reduce_step( task_steps.base_evaluate, args=[reduce, {"results": result, "_": output}], schedule_to_close_timeout=timedelta(seconds=30), + retry_policy=DEFAULT_RETRY_POLICY, ) return result @@ -244,6 +248,7 @@ async def execute_map_reduce_step_parallel( extra_lambda_strs, ], schedule_to_close_timeout=timedelta(seconds=30), + retry_policy=DEFAULT_RETRY_POLICY, ) except BaseException as e: diff --git a/agents-api/agents_api/workflows/task_execution/transition.py b/agents-api/agents_api/workflows/task_execution/transition.py index dbcd776e4..035322dad 100644 --- a/agents-api/agents_api/workflows/task_execution/transition.py +++ b/agents-api/agents_api/workflows/task_execution/transition.py @@ -10,6 +10,7 @@ TransitionTarget, ) from ...common.protocol.tasks import PartialTransition, StepContext +from ...common.retry_policies import DEFAULT_RETRY_POLICY async def transition( @@ -44,6 +45,7 @@ async def transition( task_steps.transition_step, args=[context, transition_request], schedule_to_close_timeout=timedelta(seconds=30), + retry_policy=DEFAULT_RETRY_POLICY, ) except Exception as e: diff --git a/agents-api/agents_api/workflows/truncation.py b/agents-api/agents_api/workflows/truncation.py index d3646ccbe..d12a186b9 100644 --- a/agents-api/agents_api/workflows/truncation.py +++ b/agents-api/agents_api/workflows/truncation.py @@ -7,6 +7,7 @@ with workflow.unsafe.imports_passed_through(): from ..activities.truncation import truncation + from ..common.retry_policies import DEFAULT_RETRY_POLICY @workflow.defn @@ -17,4 +18,5 @@ async def run(self, session_id: str, token_count_threshold: int) -> None: truncation, args=[session_id, token_count_threshold], schedule_to_close_timeout=timedelta(seconds=600), + retry_policy=DEFAULT_RETRY_POLICY, ) diff --git a/agents-api/tests/test_activities.py b/agents-api/tests/test_activities.py index a2f15d179..6f65cd034 100644 --- a/agents-api/tests/test_activities.py +++ b/agents-api/tests/test_activities.py @@ -7,8 +7,14 @@ from agents_api.clients import temporal from agents_api.env import temporal_task_queue from agents_api.workflows.demo import DemoWorkflow -from tests.fixtures import cozo_client, test_developer_id, test_doc -from tests.utils import patch_testing_temporal +from agents_api.workflows.task_execution.helpers import DEFAULT_RETRY_POLICY + +from .fixtures import ( + cozo_client, + test_developer_id, + test_doc, +) +from .utils import patch_testing_temporal @test("activity: call direct embed_docs") @@ -44,6 +50,7 @@ async def _(): args=[1, 2], id=str(uuid4()), task_queue=temporal_task_queue, + retry_policy=DEFAULT_RETRY_POLICY, ) assert result == 3