diff --git a/agents-api/agents_api/activities/task_steps/prompt_step.py b/agents-api/agents_api/activities/task_steps/prompt_step.py index ad16ed6bd..bbae66dfb 100644 --- a/agents-api/agents_api/activities/task_steps/prompt_step.py +++ b/agents-api/agents_api/activities/task_steps/prompt_step.py @@ -1,5 +1,5 @@ from datetime import datetime -from typing import Callable +from typing import Any, Callable from anthropic import AsyncAnthropic # Import AsyncAnthropic client from anthropic.types.beta.beta_message import BetaMessage @@ -8,6 +8,7 @@ from langchain_core.tools.convert import tool as tool_decorator from litellm import ChatCompletionMessageToolCall, Function, Message from litellm.types.utils import Choices, ModelResponse +from pydantic import BaseModel from temporalio import activity from temporalio.exceptions import ApplicationError @@ -19,7 +20,7 @@ from ...common.storage_handler import auto_blob_store from ...common.utils.template import render_template from ...env import anthropic_api_key, debug -from ..utils import get_handler_with_filtered_params +from ..utils import get_handler_with_filtered_params, get_integration_arguments from .base_evaluate import base_evaluate COMPUTER_USE_BETA_FLAG = "computer-use-2024-10-22" @@ -70,13 +71,11 @@ def format_tool(tool: Tool) -> dict: formatted["function"]["parameters"] = json_schema - # # FIXME: Implement integration tools - # elif tool.type == "integration": - # raise NotImplementedError("Integration tools are not supported") + elif tool.type == "integration" and tool.integration: + formatted["function"]["parameters"] = get_integration_arguments(tool) - # # FIXME: Implement API call tools - # elif tool.type == "api_call": - # raise NotImplementedError("API call tools are not supported") + elif tool.type == "api_call" and tool.api_call: + formatted["function"]["parameters"] = tool.api_call.schema_ return formatted @@ -146,7 +145,9 @@ async def prompt_step(context: StepContext) -> StepOutcome: # Get passed settings 
passed_settings: dict = context.current_step.model_dump( - exclude=excluded_keys, exclude_unset=True + # TODO: Should we exclude unset? + exclude=excluded_keys, + exclude_unset=True, ) passed_settings.update(passed_settings.pop("settings", {})) @@ -251,8 +252,6 @@ async def prompt_step(context: StepContext) -> StepOutcome: ) else: - # FIXME: hardcoded tool to a None value as the tool calls are not implemented yet - formatted_tools = None # Use litellm for other models completion_data: dict = { "model": agent_model, diff --git a/agents-api/agents_api/activities/utils.py b/agents-api/agents_api/activities/utils.py index 3997104db..fcdd091c5 100644 --- a/agents-api/agents_api/activities/utils.py +++ b/agents-api/agents_api/activities/utils.py @@ -8,15 +8,43 @@ import statistics import string import time +import types import urllib.parse -from typing import Any, Callable, ParamSpec, TypeVar +from typing import ( + Annotated, + Any, + Callable, + Literal, + ParamSpec, + TypeVar, + get_args, + get_origin, +) import re2 import zoneinfo from beartype import beartype +from pydantic import BaseModel from simpleeval import EvalWithCompoundTypes, SimpleEval from ..autogen.openapi_model import SystemDef +from ..autogen.Tools import ( + BraveSearchArguments, + BrowserbaseCompleteSessionArguments, + BrowserbaseContextArguments, + BrowserbaseCreateSessionArguments, + BrowserbaseExtensionArguments, + BrowserbaseGetSessionArguments, + BrowserbaseGetSessionConnectUrlArguments, + BrowserbaseGetSessionLiveUrlsArguments, + BrowserbaseListSessionsArguments, + EmailArguments, + RemoteBrowserArguments, + SpiderFetchArguments, + Tool, + WeatherGetArguments, + WikipediaSearchArguments, +) from ..common.utils import yaml T = TypeVar("T") @@ -56,6 +84,101 @@ "match_regex": lambda pattern, string: bool(re2.fullmatch(pattern, string)), } +_args_desc_map = { + BraveSearchArguments: { + "query": "The search query for searching with Brave", + }, + EmailArguments: { + "to": "The email address to 
send the email to", + "from_": "The email address to send the email from", + "subject": "The subject of the email", + "body": "The body of the email", + }, + SpiderFetchArguments: { + "url": "The URL to fetch data from", + "mode": "The type of crawler to use", + "params": "Additional parameters for the Spider API", + }, + WikipediaSearchArguments: { + "query": "The search query string", + "load_max_docs": "Maximum number of documents to load", + }, + WeatherGetArguments: { + "location": "The location for which to fetch weather data", + }, + BrowserbaseContextArguments: { + "project_id": "The Project ID. Can be found in Settings.", + }, + BrowserbaseExtensionArguments: { + "repository_name": "The GitHub repository name.", + "ref": "Ref to install from a branch or tag.", + }, + BrowserbaseListSessionsArguments: { + "status": "The status of the sessions to list (Available options: RUNNING, ERROR, TIMED_OUT, COMPLETED)", + }, + BrowserbaseCreateSessionArguments: { + "project_id": "The Project ID. Can be found in Settings.", + "extension_id": "The installed Extension ID. See Install Extension from GitHub.", + "browser_settings": "Browser settings", + "timeout": "Duration in seconds after which the session will automatically end. Defaults to the Project's defaultTimeout.", + "keep_alive": "Set to true to keep the session alive even after disconnections. This is available on the Startup plan only.", + "proxies": "Proxy configuration. 
Can be true for default proxy, or an array of proxy configurations.", + }, + BrowserbaseGetSessionArguments: { + "id": "Session ID", + }, + BrowserbaseCompleteSessionArguments: { + "id": "Session ID", + "status": "Session status", + }, + BrowserbaseGetSessionLiveUrlsArguments: { + "id": "Session ID", + }, + BrowserbaseGetSessionConnectUrlArguments: { + "id": "Session ID", + }, + RemoteBrowserArguments: { + "connect_url": "The connection URL for the remote browser", + "action": "The action to perform", + "text": "The text", + "coordinate": "The coordinate to move the mouse to", + }, +} + +_providers_map = { + "brave": BraveSearchArguments, + "email": EmailArguments, + "spider": SpiderFetchArguments, + "wikipedia": WikipediaSearchArguments, + "weather": WeatherGetArguments, + "browserbase": { + "create_context": BrowserbaseContextArguments, + "install_extension_from_github": BrowserbaseExtensionArguments, + "list_sessions": BrowserbaseListSessionsArguments, + "create_session": BrowserbaseCreateSessionArguments, + "get_session": BrowserbaseGetSessionArguments, + "complete_session": BrowserbaseCompleteSessionArguments, + "get_live_urls": BrowserbaseGetSessionLiveUrlsArguments, + "get_connect_url": BrowserbaseGetSessionConnectUrlArguments, + }, + "remote_browser": RemoteBrowserArguments, +} + + +_arg_types_map = { + BrowserbaseCreateSessionArguments: { + "proxies": { + "type": "boolean | array", + }, + }, + BrowserbaseListSessionsArguments: { + "status": { + "type": "string", + "enum": "RUNNING,ERROR,TIMED_OUT,COMPLETED", + }, + }, +} + class stdlib_re: fullmatch = re2.fullmatch @@ -378,3 +501,77 @@ def get_handler(system: SystemDef) -> Callable: raise NotImplementedError( f"System call not implemented for {system.resource}.{system.operation}" ) + + +def _annotation_to_type( + annotation: type, args_model: type[BaseModel], fld_name: str +) -> dict[str, str]: + type_, enum = None, None + if get_origin(annotation) is Literal: + type_ = "string" + enum = 
",".join(annotation.__args__) + elif annotation is str: + type_ = "string" + elif annotation in (int, float): + type_ = "number" + elif annotation is list: + type_ = "array" + elif annotation is bool: + type_ = "boolean" + elif annotation == type(None): + type_ = "null" + elif get_origin(annotation) is types.UnionType: + args = [arg for arg in get_args(annotation) if arg is not types.NoneType] + if len(args): + return _annotation_to_type(args[0], args_model, fld_name) + else: + type_ = "null" + elif annotation is dict: + type_ = "object" + else: + type_ = _arg_types_map.get(args_model, {fld_name: {"type": "object"}}).get( + fld_name, {"type": "object"} + )["type"] + enum = _arg_types_map.get(args_model, {}).get(fld_name, {}).get("enum") + + result = { + "type": type_, + } + if enum is not None: + result.update({"enum": enum}) + + return result + + +def get_integration_arguments(tool: Tool): + properties = { + "type": "object", + "properties": {}, + "required": [], + } + + integration_args: type[BaseModel] | dict[str, type[BaseModel]] | None = ( + _providers_map.get(tool.integration.provider) + ) + + if integration_args is None: + return properties + + if isinstance(integration_args, dict): + integration_args: type[BaseModel] | None = integration_args.get( + tool.integration.method + ) + + if integration_args is None: + return properties + + for fld_name, fld_annotation in integration_args.model_fields.items(): + tp = _annotation_to_type(fld_annotation.annotation, integration_args, fld_name) + tp["description"] = _args_desc_map.get(integration_args, fld_name).get( + fld_name, fld_name + ) + properties["properties"][fld_name] = tp + if fld_annotation.is_required(): + properties["required"].append(fld_name) + + return properties diff --git a/agents-api/agents_api/workflows/task_execution/__init__.py b/agents-api/agents_api/workflows/task_execution/__init__.py index 05f4ce795..fa224fdc9 100644 --- a/agents-api/agents_api/workflows/task_execution/__init__.py +++ 
b/agents-api/agents_api/workflows/task_execution/__init__.py @@ -437,7 +437,7 @@ async def run( ]["type"] == "integration": workflow.logger.debug("Prompt step: Received INTEGRATION tool call") - # FIXME: Implement integration tool calls + # TODO: Implement integration tool calls # See: MANUAL TOOL CALL INTEGRATION (below) raise NotImplementedError("Integration tool calls not yet supported") @@ -452,7 +452,7 @@ async def run( ]["type"] == "api_call": workflow.logger.debug("Prompt step: Received API_CALL tool call") - # FIXME: Implement API_CALL tool calls + # TODO: Implement API_CALL tool calls # See: MANUAL TOOL CALL API_CALL (below) raise NotImplementedError("API_CALL tool calls not yet supported") @@ -467,7 +467,7 @@ async def run( ]["type"] == "system": workflow.logger.debug("Prompt step: Received SYSTEM tool call") - # FIXME: Implement SYSTEM tool calls + # TODO: Implement SYSTEM tool calls # See: MANUAL TOOL CALL SYSTEM (below) raise NotImplementedError("SYSTEM tool calls not yet supported") diff --git a/agents-api/tests/test_activities_utils.py b/agents-api/tests/test_activities_utils.py new file mode 100644 index 000000000..7e4c74de9 --- /dev/null +++ b/agents-api/tests/test_activities_utils.py @@ -0,0 +1,451 @@ +from datetime import datetime, timezone +from uuid import uuid4 + +from ward import test + +from agents_api.activities.utils import get_integration_arguments +from agents_api.autogen.Tools import ( + BraveIntegrationDef, + BrowserbaseCompleteSessionIntegrationDef, + BrowserbaseContextIntegrationDef, + BrowserbaseCreateSessionIntegrationDef, + BrowserbaseExtensionIntegrationDef, + BrowserbaseGetSessionConnectUrlIntegrationDef, + BrowserbaseGetSessionIntegrationDef, + BrowserbaseGetSessionLiveUrlsIntegrationDef, + BrowserbaseListSessionsIntegrationDef, + DummyIntegrationDef, + EmailIntegrationDef, + RemoteBrowserIntegrationDef, + RemoteBrowserSetup, + SpiderIntegrationDef, + Tool, + WeatherIntegrationDef, + WikipediaIntegrationDef, +) + + 
def _integration_tool(integration):
    """Wrap an integration definition in a minimal ``Tool`` of type "integration"."""
    return Tool(
        id=uuid4(),
        name="tool1",
        type="integration",
        integration=integration,
        created_at=datetime.now(timezone.utc),
        updated_at=datetime.now(timezone.utc),
    )


def _prop(type_, description, enum=None):
    """Build the expected descriptor of a single argument."""
    descriptor = {"type": type_, "description": description}
    if enum is not None:
        descriptor["enum"] = enum
    return descriptor


def _schema(properties, required):
    """Build the expected top-level parameters schema."""
    return {"type": "object", "properties": properties, "required": required}


@test("get_integration_arguments: dummy search")
async def _():
    actual = get_integration_arguments(_integration_tool(DummyIntegrationDef()))

    assert actual == _schema({}, [])


@test("get_integration_arguments: brave search")
async def _():
    actual = get_integration_arguments(_integration_tool(BraveIntegrationDef()))

    assert actual == _schema(
        {"query": _prop("string", "The search query for searching with Brave")},
        ["query"],
    )


@test("get_integration_arguments: email search")
async def _():
    actual = get_integration_arguments(_integration_tool(EmailIntegrationDef()))

    assert actual == _schema(
        {
            "to": _prop("string", "The email address to send the email to"),
            "from_": _prop("string", "The email address to send the email from"),
            "subject": _prop("string", "The subject of the email"),
            "body": _prop("string", "The body of the email"),
        },
        ["to", "from_", "subject", "body"],
    )


@test("get_integration_arguments: spider fetch")
async def _():
    actual = get_integration_arguments(_integration_tool(SpiderIntegrationDef()))

    assert actual == _schema(
        {
            "url": _prop("object", "The URL to fetch data from"),
            "mode": _prop("string", "The type of crawler to use", enum="scrape"),
            "params": _prop("object", "Additional parameters for the Spider API"),
        },
        ["url"],
    )


@test("get_integration_arguments: wikipedia integration")
async def _():
    actual = get_integration_arguments(_integration_tool(WikipediaIntegrationDef()))

    assert actual == _schema(
        {
            "query": _prop("string", "The search query string"),
            "load_max_docs": _prop("number", "Maximum number of documents to load"),
        },
        ["query"],
    )


@test("get_integration_arguments: weather integration")
async def _():
    actual = get_integration_arguments(_integration_tool(WeatherIntegrationDef()))

    assert actual == _schema(
        {
            "location": _prop(
                "string", "The location for which to fetch weather data"
            ),
        },
        ["location"],
    )


@test("get_integration_arguments: browserbase context")
async def _():
    actual = get_integration_arguments(
        _integration_tool(BrowserbaseContextIntegrationDef())
    )

    assert actual == _schema(
        {"project_id": _prop("string", "The Project ID. Can be found in Settings.")},
        ["project_id"],
    )


@test("get_integration_arguments: browserbase extension")
async def _():
    integration = BrowserbaseExtensionIntegrationDef(
        method="install_extension_from_github"
    )
    actual = get_integration_arguments(_integration_tool(integration))

    assert actual == _schema(
        {
            "repository_name": _prop("string", "The GitHub repository name."),
            "ref": _prop("string", "Ref to install from a branch or tag."),
        },
        ["repository_name"],
    )


@test("get_integration_arguments: browserbase list sessions")
async def _():
    actual = get_integration_arguments(
        _integration_tool(BrowserbaseListSessionsIntegrationDef())
    )

    assert actual == _schema(
        {
            "status": _prop(
                "string",
                "The status of the sessions to list (Available options: RUNNING, ERROR, TIMED_OUT, COMPLETED)",
                enum="RUNNING,ERROR,TIMED_OUT,COMPLETED",
            ),
        },
        [],
    )


@test("get_integration_arguments: browserbase create session")
async def _():
    actual = get_integration_arguments(
        _integration_tool(BrowserbaseCreateSessionIntegrationDef())
    )

    assert actual == _schema(
        {
            "project_id": _prop(
                "string", "The Project ID. Can be found in Settings."
            ),
            "extension_id": _prop(
                "string",
                "The installed Extension ID. See Install Extension from GitHub.",
            ),
            "browser_settings": _prop("object", "Browser settings"),
            "timeout": _prop(
                "number",
                "Duration in seconds after which the session will automatically end. Defaults to the Project's defaultTimeout.",
            ),
            "keep_alive": _prop(
                "boolean",
                "Set to true to keep the session alive even after disconnections. This is available on the Startup plan only.",
            ),
            "proxies": _prop(
                "boolean | array",
                "Proxy configuration. Can be true for default proxy, or an array of proxy configurations.",
            ),
        },
        [],
    )


@test("get_integration_arguments: browserbase get session")
async def _():
    actual = get_integration_arguments(
        _integration_tool(BrowserbaseGetSessionIntegrationDef())
    )

    assert actual == _schema({"id": _prop("string", "Session ID")}, ["id"])


@test("get_integration_arguments: browserbase complete session")
async def _():
    actual = get_integration_arguments(
        _integration_tool(BrowserbaseCompleteSessionIntegrationDef())
    )

    assert actual == _schema(
        {
            "id": _prop("string", "Session ID"),
            "status": _prop("string", "Session status", enum="REQUEST_RELEASE"),
        },
        ["id"],
    )


@test("get_integration_arguments: browserbase get session live urls")
async def _():
    actual = get_integration_arguments(
        _integration_tool(BrowserbaseGetSessionLiveUrlsIntegrationDef())
    )

    assert actual == _schema({"id": _prop("string", "Session ID")}, ["id"])


@test("get_integration_arguments: browserbase get session connect url")
async def _():
    actual = get_integration_arguments(
        _integration_tool(BrowserbaseGetSessionConnectUrlIntegrationDef())
    )

    assert actual == _schema({"id": _prop("string", "Session ID")}, ["id"])


@test("get_integration_arguments: remote browser")
async def _():
    integration = RemoteBrowserIntegrationDef(setup=RemoteBrowserSetup())
    actual = get_integration_arguments(_integration_tool(integration))

    assert actual == _schema(
        {
            "connect_url": _prop(
                "string", "The connection URL for the remote browser"
            ),
            "action": _prop(
                "string",
                "The action to perform",
                enum="key,type,mouse_move,left_click,left_click_drag,right_click,middle_click,double_click,screenshot,cursor_position,navigate,refresh",
            ),
            "text": _prop("string", "The text"),
            "coordinate": _prop("array", "The coordinate to move the mouse to"),
        },
        ["action"],
    )
TODO: Add integration, system, api_call integration?: never; system?: never; api_call?: never; @@ -288,6 +289,7 @@ model BaseChosenToolCall { type: ToolType; function?: FunctionCallOption; + // TODO: Add integration, system, api_call integration?: unknown; // ChosenIntegrationCall system?: unknown; // ChosenSystemCall api_call?: unknown; // ChosenApiCall