diff --git a/agents-api/agents_api/activities/task_steps/prompt_step.py b/agents-api/agents_api/activities/task_steps/prompt_step.py index 5bbc0a1ae..26b973f82 100644 --- a/agents-api/agents_api/activities/task_steps/prompt_step.py +++ b/agents-api/agents_api/activities/task_steps/prompt_step.py @@ -12,25 +12,7 @@ from temporalio import activity from temporalio.exceptions import ApplicationError -from ...autogen.Tools import ( - BaseIntegrationDef, - BraveIntegrationDef, - BrowserbaseCompleteSessionIntegrationDef, - BrowserbaseContextIntegrationDef, - BrowserbaseCreateSessionIntegrationDef, - BrowserbaseExtensionIntegrationDef, - BrowserbaseGetSessionConnectUrlIntegrationDef, - BrowserbaseGetSessionIntegrationDef, - BrowserbaseGetSessionLiveUrlsIntegrationDef, - BrowserbaseListSessionsIntegrationDef, - DummyIntegrationDef, - EmailIntegrationDef, - RemoteBrowserIntegrationDef, - SpiderIntegrationDef, - Tool, - WeatherIntegrationDef, - WikipediaIntegrationDef, -) +from ...autogen.Tools import Tool from ...clients import ( litellm, # We dont directly import `acompletion` so we can mock it ) @@ -38,64 +20,12 @@ from ...common.storage_handler import auto_blob_store from ...common.utils.template import render_template from ...env import anthropic_api_key, debug -from ..utils import get_handler_with_filtered_params +from ..utils import get_handler_with_filtered_params, get_integration_arguments from .base_evaluate import base_evaluate COMPUTER_USE_BETA_FLAG = "computer-use-2024-10-22" -def _get_integration_arguments(tool: Tool): - providers_map = { - "brave": BraveIntegrationDef, - "dummy": DummyIntegrationDef, - "email": EmailIntegrationDef, - "spider": SpiderIntegrationDef, - "wikipedia": WikipediaIntegrationDef, - "weather": WeatherIntegrationDef, - "browserbase": { - "create_context": BrowserbaseContextIntegrationDef, - "install_extension_from_github": BrowserbaseExtensionIntegrationDef, - "list_sessions": BrowserbaseListSessionsIntegrationDef, - "create_session": BrowserbaseCreateSessionIntegrationDef, - "get_session": BrowserbaseGetSessionIntegrationDef, - "complete_session": BrowserbaseCompleteSessionIntegrationDef, - "get_live_urls": BrowserbaseGetSessionLiveUrlsIntegrationDef, - "get_connect_url": BrowserbaseGetSessionConnectUrlIntegrationDef, - }, - "remote_browser": RemoteBrowserIntegrationDef, - } - - integration: BaseIntegrationDef | dict[str, BaseIntegrationDef] = providers_map.get( - tool.integration.provider - ) - if isinstance(integration, dict): - integration: BaseIntegrationDef = integration.get(tool.integration.method) - - properties = { - "type": "object", - "properties": {}, - "required": [], - } - - arguments: BaseModel | Any | None = integration.arguments - if not arguments: - return properties - - if isinstance(arguments, BaseModel): - for fld_name, fld_annotation in arguments.model_fields.items(): - properties["properties"][fld_name] = { - "type": fld_annotation.annotation, - "description": fld_name, - } - if fld_annotation.is_required: - properties["required"].append(fld_name) - - elif isinstance(arguments, dict): - properties["properties"] = arguments - - return properties - - def format_tool(tool: Tool) -> dict: if tool.type == "computer_20241022": return { @@ -142,7 +72,7 @@ def format_tool(tool: Tool) -> dict: formatted["function"]["parameters"] = json_schema elif tool.type == "integration" and tool.integration: - formatted["function"]["parameters"] = _get_integration_arguments(tool) + formatted["function"]["parameters"] = get_integration_arguments(tool) elif tool.type == "api_call" and tool.api_call: formatted["function"]["parameters"] = tool.api_call.schema_ diff --git a/agents-api/agents_api/activities/utils.py b/agents-api/agents_api/activities/utils.py index 3997104db..24f849822 100644 --- a/agents-api/agents_api/activities/utils.py +++ b/agents-api/agents_api/activities/utils.py @@ -9,14 +9,32 @@ import string import time import urllib.parse -from typing import Any, Callable, ParamSpec, TypeVar +from typing import Any, Callable, Literal, ParamSpec, TypeVar, get_origin import re2 import zoneinfo from beartype import beartype +from pydantic import BaseModel from simpleeval import EvalWithCompoundTypes, SimpleEval from ..autogen.openapi_model import SystemDef +from ..autogen.Tools import ( + BraveSearchArguments, + BrowserbaseCompleteSessionArguments, + BrowserbaseContextArguments, + BrowserbaseCreateSessionArguments, + BrowserbaseExtensionArguments, + BrowserbaseGetSessionArguments, + BrowserbaseGetSessionConnectUrlArguments, + BrowserbaseGetSessionLiveUrlsArguments, + BrowserbaseListSessionsArguments, + EmailArguments, + RemoteBrowserArguments, + SpiderFetchArguments, + Tool, + WeatherGetArguments, + WikipediaSearchArguments, +) from ..common.utils import yaml T = TypeVar("T") @@ -378,3 +396,81 @@ def get_handler(system: SystemDef) -> Callable: raise NotImplementedError( f"System call not implemented for {system.resource}.{system.operation}" ) + + +def _annotation_to_type(annotation: type) -> dict[str, str]: + type_, enum = None, None + if get_origin(annotation) is Literal: + type_ = "string" + enum = ",".join(annotation.__args__) + elif annotation is str: + type_ = "string" + elif annotation in (int, float): + type_ = "number" + elif annotation is list: + type_ = "array" + elif annotation is bool: + type_ = "boolean" + elif annotation == type(None): + type_ = "null" + else: + type_ = "object" + + result = { + "type": type_, + } + if enum is not None: + result.update({"enum": enum}) + + return result + + +def get_integration_arguments(tool: Tool): + providers_map = { + "brave": BraveSearchArguments, + # "dummy": DummyIntegrationDef, + "email": EmailArguments, + "spider": SpiderFetchArguments, + "wikipedia": WikipediaSearchArguments, + "weather": WeatherGetArguments, + "browserbase": { + "create_context": BrowserbaseContextArguments, + "install_extension_from_github": BrowserbaseExtensionArguments, + "list_sessions": BrowserbaseListSessionsArguments, + "create_session": BrowserbaseCreateSessionArguments, + "get_session": BrowserbaseGetSessionArguments, + "complete_session": BrowserbaseCompleteSessionArguments, + "get_live_urls": BrowserbaseGetSessionLiveUrlsArguments, + "get_connect_url": BrowserbaseGetSessionConnectUrlArguments, + }, + "remote_browser": RemoteBrowserArguments, + } + properties = { + "type": "object", + "properties": {}, + "required": [], + } + + integration_args: type[BaseModel] | dict[str, type[BaseModel]] | None = ( + providers_map.get(tool.integration.provider) + ) + + if integration_args is None: + return properties + + if isinstance(integration_args, dict): + integration_args: type[BaseModel] | None = integration_args.get( + tool.integration.method + ) + + if integration_args is None: + return properties + + for fld_name, fld_annotation in integration_args.model_fields.items(): + tp = _annotation_to_type(fld_annotation.annotation) + tp["description"] = fld_name + properties["properties"][fld_name] = tp + if fld_annotation.is_required: + properties["required"].append(fld_name) + + return properties diff --git a/agents-api/tests/test_activities_utils.py b/agents-api/tests/test_activities_utils.py new file mode 100644 index 000000000..d7a83c34b --- /dev/null +++ b/agents-api/tests/test_activities_utils.py @@ -0,0 +1,26 @@ +from datetime import datetime, timezone +from uuid import uuid4 + +from ward import test + +from agents_api.activities.utils import get_integration_arguments +from agents_api.autogen.Tools import DummyIntegrationDef, Tool + + +@test("get_integration_arguments: dummy search") +async def _(): + tool = Tool( + id=uuid4(), + name="tool1", + type="integration", + integration=DummyIntegrationDef(), + created_at=datetime.now(timezone.utc), + updated_at=datetime.now(timezone.utc), + ) + result = get_integration_arguments(tool) + + assert result == { + "type": "object", + "properties": {}, + "required": [], + }