diff --git a/agents-api/agents_api/activities/task_steps/tool_call_step.py b/agents-api/agents_api/activities/task_steps/tool_call_step.py index 95d89d948..a93e5bb66 100644 --- a/agents-api/agents_api/activities/task_steps/tool_call_step.py +++ b/agents-api/agents_api/activities/task_steps/tool_call_step.py @@ -1,22 +1,23 @@ -from agents_api.autogen.Tasks import ToolCallStep +import base64 +import secrets + from beartype import beartype from temporalio import activity +from agents_api.activities.task_steps import base_evaluate +from agents_api.autogen.Tasks import ToolCallStep + from ...common.protocol.tasks import ( StepContext, StepOutcome, ) -import secrets -import base64 - def generate_call_id(): # Generate 18 random bytes (which will result in 24 base64 characters) random_bytes = secrets.token_bytes(18) # Encode to base64 and remove padding - base64_string = base64.urlsafe_b64encode( - random_bytes).decode('ascii').rstrip('=') + base64_string = base64.urlsafe_b64encode(random_bytes).decode("ascii").rstrip("=") # Add the "call_" prefix return f"call_{base64_string}" @@ -27,27 +28,23 @@ async def tool_call_step(context: StepContext) -> StepOutcome: assert isinstance(context.current_step, ToolCallStep) tool_type, tool_name = context.current_step.tool.split(".") - arguments = context.current_step.arguments + arguments = await base_evaluate( + context.current_step.arguments, context.model_dump() + ) tools = context.execution_input.tools - assert tool_name in [ - tool.name for tool in tools], f"Tool {tool_name} not found" + assert tool_name in [tool.name for tool in tools], f"Tool {tool_name} not found" call_id = generate_call_id() tool_call = { tool_type: { - 'arguments': arguments, - 'name': tool_name, + "arguments": arguments, + "name": tool_name, }, - 'id': call_id, - 'type': tool_type + "id": call_id, + "type": tool_type, } - print("tool_call") - print(tool_call) - - return StepOutcome( - output=tool_call - ) + return StepOutcome(output=tool_call) diff --git a/agents-api/agents_api/workflows/task_execution/__init__.py b/agents-api/agents_api/workflows/task_execution/__init__.py index 17aa1035a..d1f13585f 100644 --- a/agents-api/agents_api/workflows/task_execution/__init__.py +++ b/agents-api/agents_api/workflows/task_execution/__init__.py @@ -212,7 +212,7 @@ async def run( # 3. Then, based on the outcome and step type, decide what to do next workflow.logger.info(f"Processing outcome for step {context.cursor.step}") - + match context.current_step, outcome: # Handle errors (activity returns None) case step, StepOutcome(error=error) if error is not None: @@ -389,9 +389,7 @@ async def run( state = PartialTransition(type="resume", output=result) - case PromptStep(), StepOutcome( - output=response - ): + case PromptStep(), StepOutcome(output=response): workflow.logger.debug(f"Prompt step: Received response: {response}") if response["choices"][0]["finish_reason"] != "tool_calls": workflow.logger.debug("Prompt step: Received response") @@ -420,7 +418,6 @@ async def run( ) state = PartialTransition(output=new_response.output, type="resume") - case SetStep(), StepOutcome(output=evaluated_output): workflow.logger.info("Set step: Updating user state") self.update_user_state(evaluated_output) @@ -461,8 +458,7 @@ async def run( schedule_to_close_timeout=timedelta(days=31), ) - state = PartialTransition( - output=tool_call_response, type="resume") + state = PartialTransition(output=tool_call_response, type="resume") case _: workflow.logger.error(