From 1e41ca57984bc01b5967911baaaeadbe148eb65b Mon Sep 17 00:00:00 2001 From: Dmitry Paramonov Date: Thu, 14 Nov 2024 16:47:38 +0300 Subject: [PATCH] fix: Fix integrations to tools conversion --- agents-api/agents_api/activities/utils.py | 153 ++++++-- agents-api/tests/test_activities_utils.py | 427 +++++++++++++++++++++- 2 files changed, 553 insertions(+), 27 deletions(-) diff --git a/agents-api/agents_api/activities/utils.py b/agents-api/agents_api/activities/utils.py index 24f849822..fcdd091c5 100644 --- a/agents-api/agents_api/activities/utils.py +++ b/agents-api/agents_api/activities/utils.py @@ -8,8 +8,18 @@ import statistics import string import time +import types import urllib.parse -from typing import Any, Callable, Literal, ParamSpec, TypeVar, get_origin +from typing import ( + Annotated, + Any, + Callable, + Literal, + ParamSpec, + TypeVar, + get_args, + get_origin, +) import re2 import zoneinfo @@ -74,6 +84,101 @@ "match_regex": lambda pattern, string: bool(re2.fullmatch(pattern, string)), } +_args_desc_map = { + BraveSearchArguments: { + "query": "The search query for searching with Brave", + }, + EmailArguments: { + "to": "The email address to send the email to", + "from_": "The email address to send the email from", + "subject": "The subject of the email", + "body": "The body of the email", + }, + SpiderFetchArguments: { + "url": "The URL to fetch data from", + "mode": "The type of crawler to use", + "params": "Additional parameters for the Spider API", + }, + WikipediaSearchArguments: { + "query": "The search query string", + "load_max_docs": "Maximum number of documents to load", + }, + WeatherGetArguments: { + "location": "The location for which to fetch weather data", + }, + BrowserbaseContextArguments: { + "project_id": "The Project ID. Can be found in Settings.", + }, + BrowserbaseExtensionArguments: { + "repository_name": "The GitHub repository name.", + "ref": "Ref to install from a branch or tag.", + }, + BrowserbaseListSessionsArguments: { + "status": "The status of the sessions to list (Available options: RUNNING, ERROR, TIMED_OUT, COMPLETED)", + }, + BrowserbaseCreateSessionArguments: { + "project_id": "The Project ID. Can be found in Settings.", + "extension_id": "The installed Extension ID. See Install Extension from GitHub.", + "browser_settings": "Browser settings", + "timeout": "Duration in seconds after which the session will automatically end. Defaults to the Project's defaultTimeout.", + "keep_alive": "Set to true to keep the session alive even after disconnections. This is available on the Startup plan only.", + "proxies": "Proxy configuration. Can be true for default proxy, or an array of proxy configurations.", + }, + BrowserbaseGetSessionArguments: { + "id": "Session ID", + }, + BrowserbaseCompleteSessionArguments: { + "id": "Session ID", + "status": "Session status", + }, + BrowserbaseGetSessionLiveUrlsArguments: { + "id": "Session ID", + }, + BrowserbaseGetSessionConnectUrlArguments: { + "id": "Session ID", + }, + RemoteBrowserArguments: { + "connect_url": "The connection URL for the remote browser", + "action": "The action to perform", + "text": "The text", + "coordinate": "The coordinate to move the mouse to", + }, +} + +_providers_map = { + "brave": BraveSearchArguments, + "email": EmailArguments, + "spider": SpiderFetchArguments, + "wikipedia": WikipediaSearchArguments, + "weather": WeatherGetArguments, + "browserbase": { + "create_context": BrowserbaseContextArguments, + "install_extension_from_github": BrowserbaseExtensionArguments, + "list_sessions": BrowserbaseListSessionsArguments, + "create_session": BrowserbaseCreateSessionArguments, + "get_session": BrowserbaseGetSessionArguments, + "complete_session": BrowserbaseCompleteSessionArguments, + "get_live_urls": BrowserbaseGetSessionLiveUrlsArguments, + "get_connect_url": BrowserbaseGetSessionConnectUrlArguments, + }, + "remote_browser": RemoteBrowserArguments, +} + + +_arg_types_map = { + BrowserbaseCreateSessionArguments: { + "proxies": { + "type": "boolean | array", + }, + }, + BrowserbaseListSessionsArguments: { + "status": { + "type": "string", + "enum": "RUNNING,ERROR,TIMED_OUT,COMPLETED", + }, + }, +} + class stdlib_re: fullmatch = re2.fullmatch @@ -398,7 +503,9 @@ def get_handler(system: SystemDef) -> Callable: ) -def _annotation_to_type(annotation: type) -> dict[str, str]: +def _annotation_to_type( + annotation: type, args_model: type[BaseModel], fld_name: str +) -> dict[str, str]: type_, enum = None, None if get_origin(annotation) is Literal: type_ = "string" @@ -413,8 +520,19 @@ def _annotation_to_type(annotation: type) -> dict[str, str]: type_ = "boolean" elif annotation == type(None): type_ = "null" - else: + elif get_origin(annotation) is types.UnionType: + args = [arg for arg in get_args(annotation) if arg is not types.NoneType] + if len(args): + return _annotation_to_type(args[0], args_model, fld_name) + else: + type_ = "null" + elif annotation is dict: type_ = "object" + else: + type_ = _arg_types_map.get(args_model, {fld_name: {"type": "object"}}).get( + fld_name, {"type": "object"} + )["type"] + enum = _arg_types_map.get(args_model, {}).get(fld_name, {}).get("enum") result = { "type": type_, @@ -426,25 +544,6 @@ def _annotation_to_type(annotation: type) -> dict[str, str]: def get_integration_arguments(tool: Tool): - providers_map = { - "brave": BraveSearchArguments, - # "dummy": DummyIntegrationDef, - "email": EmailArguments, - "spider": SpiderFetchArguments, - "wikipedia": WikipediaSearchArguments, - "weather": WeatherGetArguments, - "browserbase": { - "create_context": BrowserbaseContextArguments, - "install_extension_from_github": BrowserbaseExtensionArguments, - "list_sessions": BrowserbaseListSessionsArguments, - "create_session": BrowserbaseCreateSessionArguments, - "get_session": BrowserbaseGetSessionArguments, - "complete_session": BrowserbaseCompleteSessionArguments, - "get_live_urls": BrowserbaseGetSessionLiveUrlsArguments, - "get_connect_url": BrowserbaseGetSessionConnectUrlArguments, - }, - "remote_browser": RemoteBrowserArguments, - } properties = { "type": "object", "properties": {}, @@ -452,7 +551,7 @@ def get_integration_arguments(tool: Tool): } integration_args: type[BaseModel] | dict[str, type[BaseModel]] | None = ( - providers_map.get(tool.integration.provider) + _providers_map.get(tool.integration.provider) ) if integration_args is None: @@ -467,10 +566,12 @@ def get_integration_arguments(tool: Tool): return properties for fld_name, fld_annotation in integration_args.model_fields.items(): - tp = _annotation_to_type(fld_annotation.annotation) - tp["description"] = fld_name + tp = _annotation_to_type(fld_annotation.annotation, integration_args, fld_name) + tp["description"] = _args_desc_map.get(integration_args, fld_name).get( + fld_name, fld_name + ) properties["properties"][fld_name] = tp - if fld_annotation.is_required: + if fld_annotation.is_required(): properties["required"].append(fld_name) return properties diff --git a/agents-api/tests/test_activities_utils.py b/agents-api/tests/test_activities_utils.py index d7a83c34b..7e4c74de9 100644 --- a/agents-api/tests/test_activities_utils.py +++ b/agents-api/tests/test_activities_utils.py @@ -4,7 +4,25 @@ from ward import test from agents_api.activities.utils import get_integration_arguments -from agents_api.autogen.Tools import DummyIntegrationDef, Tool +from agents_api.autogen.Tools import ( + BraveIntegrationDef, + BrowserbaseCompleteSessionIntegrationDef, + BrowserbaseContextIntegrationDef, + BrowserbaseCreateSessionIntegrationDef, + BrowserbaseExtensionIntegrationDef, + BrowserbaseGetSessionConnectUrlIntegrationDef, + BrowserbaseGetSessionIntegrationDef, + BrowserbaseGetSessionLiveUrlsIntegrationDef, + BrowserbaseListSessionsIntegrationDef, + DummyIntegrationDef, + EmailIntegrationDef, + RemoteBrowserIntegrationDef, + RemoteBrowserSetup, + SpiderIntegrationDef, + Tool, + WeatherIntegrationDef, + WikipediaIntegrationDef, +) @test("get_integration_arguments: dummy search") @@ -24,3 +42,410 @@ async def _(): "properties": {}, "required": [], } + + +@test("get_integration_arguments: brave search") +async def _(): + tool = Tool( + id=uuid4(), + name="tool1", + type="integration", + integration=BraveIntegrationDef(), + created_at=datetime.now(timezone.utc), + updated_at=datetime.now(timezone.utc), + ) + result = get_integration_arguments(tool) + + assert result == { + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "The search query for searching with Brave", + } + }, + "required": ["query"], + } + + +@test("get_integration_arguments: email search") +async def _(): + tool = Tool( + id=uuid4(), + name="tool1", + type="integration", + integration=EmailIntegrationDef(), + created_at=datetime.now(timezone.utc), + updated_at=datetime.now(timezone.utc), + ) + result = get_integration_arguments(tool) + + assert result == { + "type": "object", + "properties": { + "to": { + "type": "string", + "description": "The email address to send the email to", + }, + "from_": { + "type": "string", + "description": "The email address to send the email from", + }, + "subject": { + "type": "string", + "description": "The subject of the email", + }, + "body": { + "type": "string", + "description": "The body of the email", + }, + }, + "required": ["to", "from_", "subject", "body"], + } + + +@test("get_integration_arguments: spider fetch") +async def _(): + tool = Tool( + id=uuid4(), + name="tool1", + type="integration", + integration=SpiderIntegrationDef(), + created_at=datetime.now(timezone.utc), + updated_at=datetime.now(timezone.utc), + ) + result = get_integration_arguments(tool) + + assert result == { + "type": "object", + "properties": { + "url": { + "type": "object", + "description": "The URL to fetch data from", + }, + "mode": { + "type": "string", + "description": "The type of crawler to use", + "enum": "scrape", + }, + "params": { + "type": "object", + "description": "Additional parameters for the Spider API", + }, + }, + "required": ["url"], + } + + +@test("get_integration_arguments: wikipedia integration") +async def _(): + tool = Tool( + id=uuid4(), + name="tool1", + type="integration", + integration=WikipediaIntegrationDef(), + created_at=datetime.now(timezone.utc), + updated_at=datetime.now(timezone.utc), + ) + result = get_integration_arguments(tool) + + assert result == { + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "The search query string", + }, + "load_max_docs": { + "type": "number", + "description": "Maximum number of documents to load", + }, + }, + "required": ["query"], + } + + +@test("get_integration_arguments: weather integration") +async def _(): + tool = Tool( + id=uuid4(), + name="tool1", + type="integration", + integration=WeatherIntegrationDef(), + created_at=datetime.now(timezone.utc), + updated_at=datetime.now(timezone.utc), + ) + result = get_integration_arguments(tool) + + assert result == { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The location for which to fetch weather data", + } + }, + "required": ["location"], + } + + +@test("get_integration_arguments: browserbase context") +async def _(): + tool = Tool( + id=uuid4(), + name="tool1", + type="integration", + integration=BrowserbaseContextIntegrationDef(), + created_at=datetime.now(timezone.utc), + updated_at=datetime.now(timezone.utc), + ) + result = get_integration_arguments(tool) + + assert result == { + "type": "object", + "properties": { + "project_id": { + "type": "string", + "description": "The Project ID. Can be found in Settings.", + }, + }, + "required": ["project_id"], + } + + +@test("get_integration_arguments: browserbase extension") +async def _(): + tool = Tool( + id=uuid4(), + name="tool1", + type="integration", + integration=BrowserbaseExtensionIntegrationDef( + method="install_extension_from_github" + ), + created_at=datetime.now(timezone.utc), + updated_at=datetime.now(timezone.utc), + ) + result = get_integration_arguments(tool) + + assert result == { + "type": "object", + "properties": { + "repository_name": { + "type": "string", + "description": "The GitHub repository name.", + }, + "ref": { + "type": "string", + "description": "Ref to install from a branch or tag.", + }, + }, + "required": ["repository_name"], + } + + +@test("get_integration_arguments: browserbase list sessions") +async def _(): + tool = Tool( + id=uuid4(), + name="tool1", + type="integration", + integration=BrowserbaseListSessionsIntegrationDef(), + created_at=datetime.now(timezone.utc), + updated_at=datetime.now(timezone.utc), + ) + result = get_integration_arguments(tool) + + assert result == { + "type": "object", + "properties": { + "status": { + "type": "string", + "description": "The status of the sessions to list (Available options: RUNNING, ERROR, TIMED_OUT, COMPLETED)", + "enum": "RUNNING,ERROR,TIMED_OUT,COMPLETED", + } + }, + "required": [], + } + + +@test("get_integration_arguments: browserbase create session") +async def _(): + tool = Tool( + id=uuid4(), + name="tool1", + type="integration", + integration=BrowserbaseCreateSessionIntegrationDef(), + created_at=datetime.now(timezone.utc), + updated_at=datetime.now(timezone.utc), + ) + + result = get_integration_arguments(tool) + + assert result == { + "type": "object", + "properties": { + "project_id": { + "type": "string", + "description": "The Project ID. Can be found in Settings.", + }, + "extension_id": { + "type": "string", + "description": "The installed Extension ID. See Install Extension from GitHub.", + }, + "browser_settings": { + "type": "object", + "description": "Browser settings", + }, + "timeout": { + "type": "number", + "description": "Duration in seconds after which the session will automatically end. Defaults to the Project's defaultTimeout.", + }, + "keep_alive": { + "type": "boolean", + "description": "Set to true to keep the session alive even after disconnections. This is available on the Startup plan only.", + }, + "proxies": { + "type": "boolean | array", + "description": "Proxy configuration. Can be true for default proxy, or an array of proxy configurations.", + }, + }, + "required": [], + } + + +@test("get_integration_arguments: browserbase get session") +async def _(): + tool = Tool( + id=uuid4(), + name="tool1", + type="integration", + integration=BrowserbaseGetSessionIntegrationDef(), + created_at=datetime.now(timezone.utc), + updated_at=datetime.now(timezone.utc), + ) + result = get_integration_arguments(tool) + + assert result == { + "type": "object", + "properties": { + "id": { + "type": "string", + "description": "Session ID", + }, + }, + "required": ["id"], + } + + +@test("get_integration_arguments: browserbase complete session") +async def _(): + tool = Tool( + id=uuid4(), + name="tool1", + type="integration", + integration=BrowserbaseCompleteSessionIntegrationDef(), + created_at=datetime.now(timezone.utc), + updated_at=datetime.now(timezone.utc), + ) + result = get_integration_arguments(tool) + + assert result == { + "type": "object", + "properties": { + "id": { + "type": "string", + "description": "Session ID", + }, + "status": { + "type": "string", + "description": "Session status", + "enum": "REQUEST_RELEASE", + }, + }, + "required": ["id"], + } + + +@test("get_integration_arguments: browserbase get session live urls") +async def _(): + tool = Tool( + id=uuid4(), + name="tool1", + type="integration", + integration=BrowserbaseGetSessionLiveUrlsIntegrationDef(), + created_at=datetime.now(timezone.utc), + updated_at=datetime.now(timezone.utc), + ) + result = get_integration_arguments(tool) + + assert result == { + "type": "object", + "properties": { + "id": { + "type": "string", + "description": "Session ID", + }, + }, + "required": ["id"], + } + + +@test("get_integration_arguments: browserbase get session connect url") +async def _(): + tool = Tool( + id=uuid4(), + name="tool1", + type="integration", + integration=BrowserbaseGetSessionConnectUrlIntegrationDef(), + created_at=datetime.now(timezone.utc), + updated_at=datetime.now(timezone.utc), + ) + result = get_integration_arguments(tool) + + assert result == { + "type": "object", + "properties": { + "id": { + "type": "string", + "description": "Session ID", + }, + }, + "required": ["id"], + } + + +@test("get_integration_arguments: remote browser") +async def _(): + tool = Tool( + id=uuid4(), + name="tool1", + type="integration", + integration=RemoteBrowserIntegrationDef(setup=RemoteBrowserSetup()), + created_at=datetime.now(timezone.utc), + updated_at=datetime.now(timezone.utc), + ) + result = get_integration_arguments(tool) + + assert result == { + "type": "object", + "properties": { + "connect_url": { + "type": "string", + "description": "The connection URL for the remote browser", + }, + "action": { + "type": "string", + "description": "The action to perform", + "enum": "key,type,mouse_move,left_click,left_click_drag,right_click,middle_click,double_click,screenshot,cursor_position,navigate,refresh", + }, + "text": { + "type": "string", + "description": "The text", + }, + "coordinate": { + "type": "array", + "description": "The coordinate to move the mouse to", + }, + }, + "required": ["action"], + }