From 2c8a0af060e085323bbbda9f8b389731aef707d1 Mon Sep 17 00:00:00 2001 From: Dmitry Paramonov Date: Thu, 17 Oct 2024 14:07:10 +0300 Subject: [PATCH] feat: Choose the activity based on tool type --- .../agents_api/activities/execute_system.py | 2 +- .../activities/task_steps/prompt_step.py | 35 +--- .../agents_api/common/exceptions/tasks.py | 2 + .../workflows/task_execution/__init__.py | 174 +++++++++--------- 4 files changed, 92 insertions(+), 121 deletions(-) diff --git a/agents-api/agents_api/activities/execute_system.py b/agents-api/agents_api/activities/execute_system.py index 2d3d26687..769d3eec6 100644 --- a/agents-api/agents_api/activities/execute_system.py +++ b/agents-api/agents_api/activities/execute_system.py @@ -46,7 +46,7 @@ async def execute_system( context: StepContext, system: SystemDef, ) -> Any: - arguments = system.arguments + arguments = system.arguments or {} arguments["developer_id"] = context.execution_input.developer_id # Unbox all the arguments diff --git a/agents-api/agents_api/activities/task_steps/prompt_step.py b/agents-api/agents_api/activities/task_steps/prompt_step.py index 901e86c1b..f4b405162 100644 --- a/agents-api/agents_api/activities/task_steps/prompt_step.py +++ b/agents-api/agents_api/activities/task_steps/prompt_step.py @@ -111,13 +111,6 @@ async def prompt_step(context: StepContext) -> StepOutcome: direction="desc", ) - ### [Function(...), ApiCall(...), Integration(...)] - ### -> [{"type": "function", "function": {...}}, {"type": "api_call", "api_call": {...}}, {"type": "integration", "integration": {...}}] - ### -> [{"type": "function", "function": {...}}] - ### -> openai - - # Format agent_tools for litellm - # COMMENT(oct-16): Format the tools for openai api here (api_call | integration | system) -> function formatted_agent_tools = [ func_call for tool in agent_tools if (func_call := make_function_call(tool)) ] @@ -157,30 +150,10 @@ async def prompt_step(context: StepContext) -> StepOutcome: response = choice.message.content if choice.finish_reason == "tool_calls": - choice.message.tool_calls = [ - call if isinstance(tc.function, dict) else tc.function.name - for tc in choice.message.tool_calls - if ( - call := ( - tools_mapping.get( - tc.function["name"] - if isinstance(tc.function, dict) - else tc.function.name - ) - ) - ) - ] - - ### response.choices[0].finish_reason == "tool_calls" - ### -> response.choices[0].message.tool_calls - ### -> [{"id": "call_abc", "name": "my_function", "arguments": "..."}, ...] - ### (cross-reference with original agent_tools list to get the original tool) - ### - ### -> FunctionCall(...) | ApiCall(...) | IntegrationCall(...) | SystemCall(...) - ### -> set this on response.choices[0].tool_calls - - # COMMENT(oct-16): Reference the original tool from tools passed to the activity - # if openai chooses to use a tool (finish_reason == "tool_calls") + tc = choice.message.tool_calls[0] + choice.message.tool_calls = tools_mapping.get( + tc.function["name"] if isinstance(tc.function, dict) else tc.function.name + ) return StepOutcome( output=response.model_dump() if hasattr(response, "model_dump") else response, diff --git a/agents-api/agents_api/common/exceptions/tasks.py b/agents-api/agents_api/common/exceptions/tasks.py index 19bb5b5ae..3c0d47458 100644 --- a/agents-api/agents_api/common/exceptions/tasks.py +++ b/agents-api/agents_api/common/exceptions/tasks.py @@ -5,6 +5,7 @@ """ import asyncio +from typing import cast import beartype import beartype.roar @@ -121,6 +122,7 @@ def is_non_retryable_error(error: BaseException) -> bool: # Check for specific HTTP errors (status code == 429) if isinstance(error, httpx.HTTPStatusError): + error = cast(httpx.HTTPStatusError, error) if error.response.status_code in (408, 429, 503, 504): return False diff --git a/agents-api/agents_api/workflows/task_execution/__init__.py b/agents-api/agents_api/workflows/task_execution/__init__.py index 52c4ab377..92e4a3555 100644 --- a/agents-api/agents_api/workflows/task_execution/__init__.py +++ b/agents-api/agents_api/workflows/task_execution/__init__.py @@ -115,6 +115,71 @@ # Probably can be implemented much more efficiently +async def run_api_call(tool_call: dict, context: StepContext): + call = tool_call["api_call"] + tool_name = call["name"] + arguments = call["arguments"] + apicall_spec = next((t for t in context.tools if t.name == tool_name), None) + + if apicall_spec is None: + raise ApplicationError(f"Integration {tool_name} not found") + + api_call = ApiCallDef( + method=apicall_spec.spec["method"], + url=apicall_spec.spec["url"], + headers=apicall_spec.spec["headers"], + follow_redirects=apicall_spec.spec["follow_redirects"], + ) + + if "json_" in arguments: + arguments["json"] = arguments["json_"] + del arguments["json_"] + + return await workflow.execute_activity( + execute_api_call, + args=[ + api_call, + arguments, + ], + schedule_to_close_timeout=timedelta(seconds=30 if debug or testing else 600), + ) + + +async def run_integration_call(tool_call: dict, context: StepContext): + call = tool_call["integration"] + tool_name = call["name"] + arguments = call["arguments"] + integration_spec = next((t for t in context.tools if t.name == tool_name), None) + + if integration_spec is None: + raise ApplicationError(f"Integration {tool_name} not found") + + integration = IntegrationDef( + provider=integration_spec.spec["provider"], + setup=integration_spec.spec["setup"], + method=integration_spec.spec["method"], + arguments=arguments, + ) + + return await workflow.execute_activity( + execute_integration, + args=[context, tool_name, integration, arguments], + schedule_to_close_timeout=timedelta(seconds=30 if debug or testing else 600), + retry_policy=DEFAULT_RETRY_POLICY, + ) + + +async def run_system_call(tool_call: dict, context: StepContext): + call = tool_call.get("system") + + system_call = SystemDef(**call) + return await workflow.execute_activity( + execute_system, + args=[context, system_call], + schedule_to_close_timeout=timedelta(seconds=30 if debug or testing else 600), + ) + + # Main workflow definition @workflow.defn class TaskExecutionWorkflow: @@ -382,26 +447,23 @@ async def run( message = response["choices"][0]["message"] tool_calls_input = message["tool_calls"] - ### COMMENT(oct-16): do a match-case on tool_calls_input.type - ### -> FunctionCall(...), ApiCall(...), IntegrationCall(...), SystemCall(...) - ### -> if api_call: - ### => execute_api_call(api_call) - ### -> if integration_call: - ### => execute_integration(integration_call) - ### -> if system_call: - ### => execute_system(system_call) - ### -> else: - ### => wait for input - - # Enter a wait-for-input step to ask the developer to run the tool calls - tool_calls_results = await workflow.execute_activity( - task_steps.raise_complete_async, - args=[context, tool_calls_input], - schedule_to_close_timeout=timedelta(days=31), - retry_policy=DEFAULT_RETRY_POLICY, - ) - - ### COMMENT(oct-16): Continue as usual. Feed the tool call results back to the model + if tool_calls_input.get("api_call"): + tool_calls_results = await run_api_call(tool_calls_input, context) + elif tool_calls_input.get("integration"): + tool_calls_results = await run_integration_call( + tool_calls_input, context + ) + elif tool_calls_input.get("system"): + tool_calls_results = await run_system_call( + tool_calls_input, context + ) + else: + tool_calls_results = await workflow.execute_activity( + task_steps.raise_complete_async, + args=[context, tool_calls_input], + schedule_to_close_timeout=timedelta(days=31), + retry_policy=DEFAULT_RETRY_POLICY, + ) # Feed the tool call results back to the model context.current_step.prompt.append(message) @@ -453,85 +515,19 @@ async def run( case ToolCallStep(), StepOutcome(output=tool_call) if tool_call[ "type" ] == "integration": - call = tool_call["integration"] - tool_name = call["name"] - arguments = call["arguments"] - integration_spec = next( - (t for t in context.tools if t.name == tool_name), None - ) - - if integration_spec is None: - raise ApplicationError(f"Integration {tool_name} not found") - - integration = IntegrationDef( - provider=integration_spec.spec["provider"], - setup=integration_spec.spec["setup"], - method=integration_spec.spec["method"], - arguments=arguments, - ) - - tool_call_response = await workflow.execute_activity( - execute_integration, - args=[context, tool_name, integration, arguments], - schedule_to_close_timeout=timedelta( - seconds=30 if debug or testing else 600 - ), - retry_policy=DEFAULT_RETRY_POLICY, - ) - + tool_call_response = await run_integration_call(tool_call, context) state = PartialTransition(output=tool_call_response) case ToolCallStep(), StepOutcome(output=tool_call) if tool_call[ "type" ] == "api_call": - call = tool_call["api_call"] - tool_name = call["name"] - arguments = call["arguments"] - apicall_spec = next( - (t for t in context.tools if t.name == tool_name), None - ) - - if apicall_spec is None: - raise ApplicationError(f"Integration {tool_name} not found") - - api_call = ApiCallDef( - method=apicall_spec.spec["method"], - url=apicall_spec.spec["url"], - headers=apicall_spec.spec["headers"], - follow_redirects=apicall_spec.spec["follow_redirects"], - ) - - if "json_" in arguments: - arguments["json"] = arguments["json_"] - del arguments["json_"] - - # Execute the API call using the `execute_api_call` function - tool_call_response = await workflow.execute_activity( - execute_api_call, - args=[ - api_call, - arguments, - ], - schedule_to_close_timeout=timedelta( - seconds=30 if debug or testing else 600 - ), - ) - + tool_call_response = await run_api_call(tool_call, context) state = PartialTransition(output=tool_call_response) case ToolCallStep(), StepOutcome(output=tool_call) if tool_call[ "type" ] == "system": - call = tool_call.get("system") - - system_call = SystemDef(**call) - tool_call_response = await workflow.execute_activity( - execute_system, - args=[context, system_call], - schedule_to_close_timeout=timedelta( - seconds=30 if debug or testing else 600 - ), - ) + tool_call_response = await run_system_call(tool_call, context) # FIXME: This is a hack to make the output of the system call match # the expected output format (convert uuid/datetime to strings)