diff --git a/agents-api/agents_api/common/exceptions/tasks.py b/agents-api/agents_api/common/exceptions/tasks.py index 336cd3f90..d9e64382b 100644 --- a/agents-api/agents_api/common/exceptions/tasks.py +++ b/agents-api/agents_api/common/exceptions/tasks.py @@ -4,14 +4,15 @@ Temporal interceptors to prevent unnecessary retries of certain error types. """ -import temporalio.exceptions +import asyncio + import fastapi import httpx -import asyncio import jinja2 import jsonschema.exceptions import pydantic import requests +import temporalio.exceptions # List of error types that should not be retried NON_RETRYABLE_ERROR_TYPES = [ diff --git a/agents-api/agents_api/common/interceptors.py b/agents-api/agents_api/common/interceptors.py index 35182f3a1..2fb077c45 100644 --- a/agents-api/agents_api/common/interceptors.py +++ b/agents-api/agents_api/common/interceptors.py @@ -9,19 +9,20 @@ from temporalio.exceptions import ApplicationError from temporalio.worker import ( ActivityInboundInterceptor, - WorkflowInboundInterceptor, ExecuteActivityInput, ExecuteWorkflowInput, Interceptor, - WorkflowInterceptorClassInput + WorkflowInboundInterceptor, + WorkflowInterceptorClassInput, ) + from .exceptions.tasks import is_non_retryable_error class CustomActivityInterceptor(ActivityInboundInterceptor): """ Custom interceptor for Temporal activities. - + This interceptor catches exceptions during activity execution and raises them as non-retryable ApplicationErrors if they are identified as non-retryable errors. @@ -43,7 +44,7 @@ async def execute_activity(self, input: ExecuteActivityInput): class CustomWorkflowInterceptor(WorkflowInboundInterceptor): """ Custom interceptor for Temporal workflows. - + This interceptor catches exceptions during workflow execution and raises them as non-retryable ApplicationErrors if they are identified as non-retryable errors. @@ -65,7 +66,7 @@ async def execute_workflow(self, input: ExecuteWorkflowInput): class CustomInterceptor(Interceptor): """ Custom Interceptor that combines both activity and workflow interceptors. - + This class is responsible for creating and returning the custom interceptors for both activities and workflows. """ @@ -75,7 +76,7 @@ def intercept_activity( ) -> ActivityInboundInterceptor: """ Creates and returns a CustomActivityInterceptor. - + This method is called by Temporal to intercept activity executions. """ return CustomActivityInterceptor(super().intercept_activity(next)) @@ -85,7 +86,7 @@ def workflow_interceptor_class( ) -> Optional[Type[WorkflowInboundInterceptor]]: """ Returns the CustomWorkflowInterceptor class. - + This method is called by Temporal to get the workflow interceptor class. """ return CustomWorkflowInterceptor diff --git a/agents-api/agents_api/common/retry_policies.py b/agents-api/agents_api/common/retry_policies.py index cb629fe7c..c6c1c362c 100644 --- a/agents-api/agents_api/common/retry_policies.py +++ b/agents-api/agents_api/common/retry_policies.py @@ -11,4 +11,4 @@ # FIXME: Adding both interceptors and retry policy (even with `non_retryable_errors` not set) # is causing the errors to be retried. We need to find a workaround for this. -DEFAULT_RETRY_POLICY = None \ No newline at end of file +DEFAULT_RETRY_POLICY = None diff --git a/agents-api/agents_api/worker/worker.py b/agents-api/agents_api/worker/worker.py index 4d4eef08b..54f2bcdd5 100644 --- a/agents-api/agents_api/worker/worker.py +++ b/agents-api/agents_api/worker/worker.py @@ -22,6 +22,7 @@ def create_worker(client: Client) -> Any: from ..activities.mem_rating import mem_rating from ..activities.summarization import summarization from ..activities.truncation import truncation + from ..common.interceptors import CustomInterceptor from ..env import ( temporal_task_queue, ) @@ -32,7 +33,6 @@ def create_worker(client: Client) -> Any: from ..workflows.summarization import SummarizationWorkflow from ..workflows.task_execution import TaskExecutionWorkflow from ..workflows.truncation import TruncationWorkflow - from ..common.interceptors import CustomInterceptor task_activity_names, task_activities = zip(*getmembers(task_steps, isfunction)) @@ -62,7 +62,7 @@ def create_worker(client: Client) -> Any: summarization, truncation, ], - interceptors=[CustomInterceptor()] + interceptors=[CustomInterceptor()], ) return worker