diff --git a/agents-api/agents_api/common/exceptions/tasks.py b/agents-api/agents_api/common/exceptions/tasks.py
new file mode 100644
index 000000000..336cd3f90
--- /dev/null
+++ b/agents-api/agents_api/common/exceptions/tasks.py
@@ -0,0 +1,84 @@
+"""
+This module defines non-retryable error types and provides a function to check
+if a given error is non-retryable. These are used in conjunction with custom
+Temporal interceptors to prevent unnecessary retries of certain error types.
+"""
+
+import temporalio.exceptions
+import fastapi
+import httpx
+import asyncio
+import jinja2
+import jsonschema.exceptions
+import pydantic
+import requests
+
+# List of error types that should not be retried
+NON_RETRYABLE_ERROR_TYPES = [
+    # Temporal-specific errors
+    temporalio.exceptions.WorkflowAlreadyStartedError,
+    temporalio.exceptions.TerminatedError,
+    temporalio.exceptions.CancelledError,
+    #
+    # Built-in Python exceptions
+    TypeError,
+    AssertionError,
+    SyntaxError,
+    ValueError,
+    ZeroDivisionError,
+    IndexError,
+    AttributeError,
+    LookupError,
+    BufferError,
+    ArithmeticError,
+    KeyError,
+    NameError,
+    NotImplementedError,
+    RecursionError,
+    RuntimeError,
+    StopIteration,
+    StopAsyncIteration,
+    IndentationError,
+    TabError,
+    #
+    # Unicode-related errors
+    UnicodeError,
+    UnicodeEncodeError,
+    UnicodeDecodeError,
+    UnicodeTranslateError,
+    #
+    # HTTP and API-related errors
+    fastapi.exceptions.HTTPException,
+    fastapi.exceptions.RequestValidationError,
+    httpx.RequestError,
+    httpx.HTTPStatusError,
+    #
+    # Asynchronous programming errors
+    asyncio.CancelledError,
+    asyncio.InvalidStateError,
+    GeneratorExit,
+    #
+    # Third-party library exceptions
+    jinja2.exceptions.TemplateSyntaxError,
+    jinja2.exceptions.TemplateNotFound,
+    jsonschema.exceptions.ValidationError,
+    pydantic.ValidationError,
+    requests.exceptions.InvalidURL,
+    requests.exceptions.MissingSchema,
+]
+
+
+def is_non_retryable_error(error: Exception) -> bool:
+    """
+    Determines if the given error is non-retryable.
+
+    This function checks if the error is an instance of any of the error types
+    defined in NON_RETRYABLE_ERROR_TYPES.
+
+    Args:
+        error (Exception): The error to check.
+
+    Returns:
+        bool: True if the error is non-retryable, False otherwise.
+    """
+    return isinstance(error, tuple(NON_RETRYABLE_ERROR_TYPES))
diff --git a/agents-api/agents_api/common/interceptors.py b/agents-api/agents_api/common/interceptors.py
new file mode 100644
index 000000000..35182f3a1
--- /dev/null
+++ b/agents-api/agents_api/common/interceptors.py
@@ -0,0 +1,93 @@
+"""
+This module defines custom interceptors for Temporal activities and workflows.
+The main purpose of these interceptors is to handle errors and prevent retrying
+certain types of errors that are known to be non-retryable.
+"""
+
+from typing import Optional, Type
+
+from temporalio.exceptions import ApplicationError
+from temporalio.worker import (
+    ActivityInboundInterceptor,
+    WorkflowInboundInterceptor,
+    ExecuteActivityInput,
+    ExecuteWorkflowInput,
+    Interceptor,
+    WorkflowInterceptorClassInput,
+)
+from .exceptions.tasks import is_non_retryable_error
+
+
+class CustomActivityInterceptor(ActivityInboundInterceptor):
+    """
+    Custom interceptor for Temporal activities.
+
+    This interceptor catches exceptions during activity execution and
+    raises them as non-retryable ApplicationErrors if they are identified
+    as non-retryable errors.
+    """
+
+    async def execute_activity(self, input: ExecuteActivityInput):
+        try:
+            return await super().execute_activity(input)
+        except Exception as e:
+            if is_non_retryable_error(e):
+                # Chain the original exception so it is kept as the explicit cause.
+                raise ApplicationError(
+                    str(e),
+                    type=type(e).__name__,
+                    non_retryable=True,
+                ) from e
+            raise
+
+
+class CustomWorkflowInterceptor(WorkflowInboundInterceptor):
+    """
+    Custom interceptor for Temporal workflows.
+
+    This interceptor catches exceptions during workflow execution and
+    raises them as non-retryable ApplicationErrors if they are identified
+    as non-retryable errors.
+    """
+
+    async def execute_workflow(self, input: ExecuteWorkflowInput):
+        try:
+            return await super().execute_workflow(input)
+        except Exception as e:
+            if is_non_retryable_error(e):
+                # Chain the original exception so it is kept as the explicit cause.
+                raise ApplicationError(
+                    str(e),
+                    type=type(e).__name__,
+                    non_retryable=True,
+                ) from e
+            raise
+
+
+class CustomInterceptor(Interceptor):
+    """
+    Custom Interceptor that combines both activity and workflow interceptors.
+
+    This class is responsible for creating and returning the custom
+    interceptors for both activities and workflows.
+    """
+
+    def intercept_activity(
+        self, next: ActivityInboundInterceptor
+    ) -> ActivityInboundInterceptor:
+        """
+        Creates and returns a CustomActivityInterceptor.
+
+        This method is called by Temporal to intercept activity executions.
+        """
+        return CustomActivityInterceptor(super().intercept_activity(next))
+
+    def workflow_interceptor_class(
+        self, input: WorkflowInterceptorClassInput
+    ) -> Optional[Type[WorkflowInboundInterceptor]]:
+        """
+        Returns the CustomWorkflowInterceptor class.
+
+        This method is called by Temporal to get the workflow interceptor class.
+        """
+        return CustomWorkflowInterceptor
diff --git a/agents-api/agents_api/common/retry_policies.py b/agents-api/agents_api/common/retry_policies.py
index fc343553c..cb629fe7c 100644
--- a/agents-api/agents_api/common/retry_policies.py
+++ b/agents-api/agents_api/common/retry_policies.py
@@ -1,63 +1,14 @@
-from datetime import timedelta
+# from datetime import timedelta
 
-from temporalio.common import RetryPolicy
+# from temporalio.common import RetryPolicy
 
-DEFAULT_RETRY_POLICY = RetryPolicy(
-    initial_interval=timedelta(seconds=1),
-    backoff_coefficient=2,
-    maximum_attempts=25,
-    maximum_interval=timedelta(seconds=300),
-    non_retryable_error_types=[
-        # Temporal-specific errors
-        "WorkflowExecutionAlreadyStarted",
-        "temporalio.exceptions.TerminalFailure",
-        "temporalio.exceptions.CanceledError",
-        #
-        # Built-in Python exceptions
-        "TypeError",
-        "AssertionError",
-        "SyntaxError",
-        "ValueError",
-        "ZeroDivisionError",
-        "IndexError",
-        "AttributeError",
-        "LookupError",
-        "BufferError",
-        "ArithmeticError",
-        "KeyError",
-        "NameError",
-        "NotImplementedError",
-        "RecursionError",
-        "RuntimeError",
-        "StopIteration",
-        "StopAsyncIteration",
-        "IndentationError",
-        "TabError",
-        #
-        # Unicode-related errors
-        "UnicodeError",
-        "UnicodeEncodeError",
-        "UnicodeDecodeError",
-        "UnicodeTranslateError",
-        #
-        # HTTP and API-related errors
-        "HTTPException",
-        "fastapi.exceptions.HTTPException",
-        "fastapi.exceptions.RequestValidationError",
-        "httpx.RequestError",
-        "httpx.HTTPStatusError",
-        #
-        # Asynchronous programming errors
-        "asyncio.CancelledError",
-        "asyncio.InvalidStateError",
-        "GeneratorExit",
-        #
-        # Third-party library exceptions
-        "jinja2.exceptions.TemplateSyntaxError",
-        "jinja2.exceptions.TemplateNotFound",
-        "jsonschema.exceptions.ValidationError",
-        "pydantic.ValidationError",
-        "requests.exceptions.InvalidURL",
-        "requests.exceptions.MissingSchema",
-    ],
-)
+# DEFAULT_RETRY_POLICY = RetryPolicy(
+#     initial_interval=timedelta(seconds=1),
+#     backoff_coefficient=2,
+#     maximum_attempts=25,
+#     maximum_interval=timedelta(seconds=300),
+# )
+
+# FIXME: Adding both interceptors and retry policy (even with `non_retryable_errors` not set)
+# is causing the errors to be retried. We need to find a workaround for this.
+DEFAULT_RETRY_POLICY = None
diff --git a/agents-api/agents_api/worker/worker.py b/agents-api/agents_api/worker/worker.py
index dc02cb4a7..4d4eef08b 100644
--- a/agents-api/agents_api/worker/worker.py
+++ b/agents-api/agents_api/worker/worker.py
@@ -32,6 +32,7 @@ def create_worker(client: Client) -> Any:
     from ..workflows.summarization import SummarizationWorkflow
     from ..workflows.task_execution import TaskExecutionWorkflow
     from ..workflows.truncation import TruncationWorkflow
+    from ..common.interceptors import CustomInterceptor
 
     task_activity_names, task_activities = zip(*getmembers(task_steps, isfunction))
 
@@ -61,6 +62,7 @@ def create_worker(client: Client) -> Any:
             summarization,
             truncation,
         ],
+        interceptors=[CustomInterceptor()],
     )
 
     return worker