From d2b81f4d0a6ee1a29eaefd6d25164b9215dcdbf9 Mon Sep 17 00:00:00 2001 From: Ahmad Haidar Date: Wed, 8 Jan 2025 12:18:08 +0300 Subject: [PATCH] fix(agents-api): Misc fixes for transitioning to error state --- .../workflows/task_execution/__init__.py | 37 +++++++------------ .../workflows/task_execution/helpers.py | 8 ---- .../workflows/task_execution/transition.py | 18 ++++----- 3 files changed, 22 insertions(+), 41 deletions(-) diff --git a/agents-api/agents_api/workflows/task_execution/__init__.py b/agents-api/agents_api/workflows/task_execution/__init__.py index d066cdb59..6e25aad49 100644 --- a/agents-api/agents_api/workflows/task_execution/__init__.py +++ b/agents-api/agents_api/workflows/task_execution/__init__.py @@ -220,7 +220,6 @@ async def run( # Handle errors (activity returns None) case step, StepOutcome(error=error) if error is not None: workflow.logger.error(f"Error in step {context.cursor.step}: {error}") - await transition(context, type="error", output=error) msg = f"Step {type(step).__name__} threw error: {error}" raise ApplicationError(msg) @@ -345,11 +344,6 @@ async def run( workflow.logger.error(f"Error step: {error}") state = PartialTransition(type="error", output=error) - await transition( - context, - state, - last_error=self.last_error, - ) msg = f"Error raised by ErrorWorkflowStep: {error}" raise ApplicationError(msg) @@ -644,11 +638,6 @@ async def run( f"Unhandled step type: {type(context.current_step).__name__}" ) state = PartialTransition(type="error", output="Not implemented") - await transition( - context, - state, - last_error=self.last_error, - ) msg = "Not implemented" raise ApplicationError(msg) @@ -679,7 +668,7 @@ async def run( if not final_state.next: msg = "No next step" raise ApplicationError(msg) - + workflow.logger.info( f"Continuing to next step: {final_state.next.workflow}.{final_state.next.step}" ) @@ -694,19 +683,19 @@ async def run( retry_policy=DEFAULT_RETRY_POLICY, heartbeat_timeout=timedelta(seconds=temporal_heartbeat_timeout), ) + + except Exception as e: + workflow.logger.error(f"Unhandled error: {e!s}") + await transition(context, type="error", output=str(e), last_error=self.last_error) + raise ApplicationError("Workflow encountered an error") from e - previous_inputs.append(final_output) + previous_inputs.append(final_output) - # Continue as a child workflow - return await continue_as_child( - context.execution_input, - start=final_state.next, - previous_inputs=previous_inputs, - user_state=state.user_state, + # Continue as a child workflow + return await continue_as_child( + context.execution_input, + start=final_state.next, + previous_inputs=previous_inputs, + user_state=state.user_state, ) - except Exception as e: - workflow.logger.error(f"Unhandled error: {e!s}") - await transition(context, type="error", output=str(e)) - msg = "Workflow encountered an error" - raise ApplicationError(msg) from e diff --git a/agents-api/agents_api/workflows/task_execution/helpers.py b/agents-api/agents_api/workflows/task_execution/helpers.py index b9e8e0a4a..6b86c184a 100644 --- a/agents-api/agents_api/workflows/task_execution/helpers.py +++ b/agents-api/agents_api/workflows/task_execution/helpers.py @@ -26,14 +26,6 @@ T = TypeVar("T") -async def handle_error(context: StepContext, error: BaseException): - workflow.logger.error(f"Error in workflow: {error!s}") - workflow.logger.error(f"Error in step {context.cursor.step}: {error}") - await transition(context, type="error", output=error) - msg = f"Step {type(context.current_step).__name__} threw error: {error}" - raise ApplicationError(msg) from error - - def validate_execution_input(execution_input: ExecutionInput) -> TaskSpecDef: """Validates and returns the task from execution input. diff --git a/agents-api/agents_api/workflows/task_execution/transition.py b/agents-api/agents_api/workflows/task_execution/transition.py index b1f5f3d43..afd4819d4 100644 --- a/agents-api/agents_api/workflows/task_execution/transition.py +++ b/agents-api/agents_api/workflows/task_execution/transition.py @@ -31,16 +31,16 @@ async def transition( if state.type is not None and state.type == "error": error_type = "error" - match context.is_last_step, context.cursor and not error_type: - case (True, TransitionTarget(workflow="main")): - state.type = "finish" - case (True, _): - state.type = "finish_branch" - case _, _: - state.type = "step" - - if error_type: + if error_type and error_type == "error": state.type = "error" + else: + match context.is_last_step, context.cursor: + case (True, TransitionTarget(workflow="main")): + state.type = "finish" + case (True, _): + state.type = "finish_branch" + case _, _: + state.type = "step" transition_request = CreateTransitionRequest( current=context.cursor,