From 2431db75f002977768c1440f18360404c80ea7c3 Mon Sep 17 00:00:00 2001 From: Ahmad Haidar Date: Wed, 8 Jan 2025 15:39:03 +0300 Subject: [PATCH] feat(agents-api): Add transitioning to ``cancelled`` state --- .../workflows/task_execution/__init__.py | 34 ++++++++++++++----- .../workflows/task_execution/transition.py | 8 +++-- 2 files changed, 31 insertions(+), 11 deletions(-) diff --git a/agents-api/agents_api/workflows/task_execution/__init__.py b/agents-api/agents_api/workflows/task_execution/__init__.py index 2d2bfddba..96ddd0e7f 100644 --- a/agents-api/agents_api/workflows/task_execution/__init__.py +++ b/agents-api/agents_api/workflows/task_execution/__init__.py @@ -10,7 +10,7 @@ # Import necessary modules and types with workflow.unsafe.imports_passed_through(): from pydantic import RootModel - + from temporalio.exceptions import ActivityError, CancelledError from ...activities import task_steps from ...activities.excecute_api_call import execute_api_call from ...activities.execute_integration import execute_integration @@ -202,10 +202,18 @@ async def run( workflow.logger.debug(f"Step {context.cursor.step} completed successfully") except Exception as e: - workflow.logger.error(f"Error in step {context.cursor.step}: {e!s}") - await transition(context, type="error", output=str(e)) - msg = f"Activity {activity} threw error: {e}" - raise ApplicationError(msg) from e + if isinstance(e, CancelledError) or ( + isinstance(e, ActivityError) + and isinstance(e.__cause__, CancelledError) + ): + workflow.logger.info("Workflow cancelled") + await transition(context, type="cancelled", output=None) + raise + else: + workflow.logger.error(f"Error in step {context.cursor.step}: {e!s}") + await transition(context, type="error", output=str(e)) + msg = f"Activity {activity} threw error: {e}" + raise ApplicationError(msg) from e # --- @@ -647,6 +655,7 @@ async def run( # The returned value is the transition finally created state = state or PartialTransition(type="error", output="Not implemented") + if state.output and isinstance(state.output, StepOutcome) and state.output.error: state = PartialTransition(type="error", output=state.output.error) final_state = await transition( @@ -685,10 +694,17 @@ async def run( ) except Exception as e: - workflow.logger.error(f"Unhandled error: {e!s}") - await transition(context, type="error", output=str(e), last_error=self.last_error) - msg = "Workflow encountered an error" - raise ApplicationError(msg) from e + if isinstance(e, CancelledError) or ( + isinstance(e, ActivityError) + and isinstance(e.__cause__, CancelledError) + ): + workflow.logger.info("Workflow cancelled") + await transition(context, type="cancelled", output=None) + raise + else: + workflow.logger.error(f"Unhandled error: {e!s}") + await transition(context, type="error", output=str(e), last_error=self.last_error) + raise ApplicationError("Workflow encountered an error") from e previous_inputs.append(final_output) diff --git a/agents-api/agents_api/workflows/task_execution/transition.py b/agents-api/agents_api/workflows/task_execution/transition.py index afd4819d4..e7b1e86f5 100644 --- a/agents-api/agents_api/workflows/task_execution/transition.py +++ b/agents-api/agents_api/workflows/task_execution/transition.py @@ -30,8 +30,12 @@ async def transition( if state.type is not None and state.type == "error": error_type = "error" + elif state.type is not None and state.type == "cancelled": + error_type = "cancelled" - if error_type and error_type == "error": + if error_type and error_type == "cancelled": + state.type = "cancelled" + elif error_type and error_type == "error": state.type = "error" else: match context.is_last_step, context.cursor: @@ -71,4 +75,4 @@ async def transition( except Exception as e: workflow.logger.error(f"Error in transition: {e!s}") msg = f"Error in transition: {e}" - raise ApplicationError(msg) from e + raise