diff --git a/agents-api/agents_api/activities/task_steps/prompt_step.py b/agents-api/agents_api/activities/task_steps/prompt_step.py index d56305959..901e86c1b 100644 --- a/agents-api/agents_api/activities/task_steps/prompt_step.py +++ b/agents-api/agents_api/activities/task_steps/prompt_step.py @@ -1,8 +1,9 @@ from beartype import beartype +from litellm.types.utils import Function from temporalio import activity from temporalio.exceptions import ApplicationError -from ...autogen.Tools import Tool +from ...autogen.Tools import ApiCallDef, Tool from ...clients import ( litellm, # We dont directly import `acompletion` so we can mock it ) @@ -11,20 +12,69 @@ from ...models.tools.list_tools import list_tools -# FIXME: This shouldn't be here. -def format_agent_tool(tool: Tool) -> dict: +def make_function_call(tool: Tool) -> dict | None: + result = {"type": "function"} + if tool.function: - return { - "type": "function", - "function": { - "name": tool.name, - "description": tool.description, - "parameters": tool.function.parameters, - }, + result.update( + { + "function": { + "name": tool.name, + "description": tool.description, + "parameters": tool.function.parameters, + }, + } + ) + elif tool.api_call: + result.update( + { + "function": { + "name": tool.name, + "description": tool.description, + "parameters": { + k.rstrip("_"): getattr(tool.api_call, k) + for k in ApiCallDef.model_fields.keys() + }, + }, + } + ) + elif tool.system: + parameters = { + "resource_id": tool.system.resource_id, + "subresource": tool.system.subresource, } - # TODO: Add integration | system | api_call tool types - else: - return {} + if tool.system.arguments: + parameters.update({"arguments": tool.system.arguments}) + + result.update( + { + "function": { + "name": f"{tool.system.resource}.{tool.system.operation}", + "description": f"{tool.system.operation} a {tool.system.resource}", + "parameters": parameters, + } + } + ) + elif tool.integration: + parameters = {} + if tool.integration.method: + parameters.update({"method": tool.integration.method}) + if tool.integration.setup: + parameters.update({"setup": tool.integration.setup}) + if tool.integration.arguments: + parameters.update({"arguments": tool.integration.arguments}) + + result.update( + { + "function": { + "name": tool.name, + "description": f"{tool.integration.provider} integration", + "parameters": parameters, + } + } + ) + + return result if result.get("function") else None @activity.defn @@ -69,8 +119,12 @@ async def prompt_step(context: StepContext) -> StepOutcome: # Format agent_tools for litellm # COMMENT(oct-16): Format the tools for openai api here (api_call | integration | system) -> function formatted_agent_tools = [ - format_agent_tool(tool) for tool in agent_tools if format_agent_tool(tool) + func_call for tool in agent_tools if (func_call := make_function_call(tool)) ] + tools_mapping = { + fmt_tool["function"]["name"]: orig_tool + for fmt_tool, orig_tool in zip(formatted_agent_tools, agent_tools) + } if context.current_step.settings: passed_settings: dict = context.current_step.settings.model_dump( @@ -95,11 +149,27 @@ async def prompt_step(context: StepContext) -> StepOutcome: **completion_data, ) + choice = response.choices[0] if context.current_step.unwrap: - if response.choices[0].finish_reason == "tool_calls": + if choice.finish_reason == "tool_calls": raise ApplicationError("Tool calls cannot be unwrapped") - response = response.choices[0].message.content + response = choice.message.content + + if choice.finish_reason == "tool_calls": + choice.message.tool_calls = [ + call if isinstance(tc.function, dict) else tc.function.name + for tc in choice.message.tool_calls + if ( + call := ( + tools_mapping.get( + tc.function["name"] + if isinstance(tc.function, dict) + else tc.function.name + ) + ) + ) + ] ### response.choices[0].finish_reason == "tool_calls" ### -> response.choices[0].message.tool_calls