Skip to content

Commit

Permalink
fix: Fix integrations to tools conversion
Browse files Browse the repository at this point in the history
  • Loading branch information
whiterabbit1983 committed Nov 14, 2024
1 parent 193c764 commit 4e6eb22
Show file tree
Hide file tree
Showing 2 changed files with 553 additions and 27 deletions.
153 changes: 127 additions & 26 deletions agents-api/agents_api/activities/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,18 @@
import statistics
import string
import time
import types
import urllib.parse
from typing import Any, Callable, Literal, ParamSpec, TypeVar, get_origin
from typing import (
Annotated,
Any,
Callable,
Literal,
ParamSpec,
TypeVar,
get_args,
get_origin,
)

import re2
import zoneinfo
Expand Down Expand Up @@ -74,6 +84,101 @@
"match_regex": lambda pattern, string: bool(re2.fullmatch(pattern, string)),
}

_args_desc_map = {
BraveSearchArguments: {
"query": "The search query for searching with Brave",
},
EmailArguments: {
"to": "The email address to send the email to",
"from_": "The email address to send the email from",
"subject": "The subject of the email",
"body": "The body of the email",
},
SpiderFetchArguments: {
"url": "The URL to fetch data from",
"mode": "The type of crawler to use",
"params": "Additional parameters for the Spider API",
},
WikipediaSearchArguments: {
"query": "The search query string",
"load_max_docs": "Maximum number of documents to load",
},
WeatherGetArguments: {
"location": "The location for which to fetch weather data",
},
BrowserbaseContextArguments: {
"project_id": "The Project ID. Can be found in Settings.",
},
BrowserbaseExtensionArguments: {
"repository_name": "The GitHub repository name.",
"ref": "Ref to install from a branch or tag.",
},
BrowserbaseListSessionsArguments: {
"status": "The status of the sessions to list (Available options: RUNNING, ERROR, TIMED_OUT, COMPLETED)",
},
BrowserbaseCreateSessionArguments: {
"project_id": "The Project ID. Can be found in Settings.",
"extension_id": "The installed Extension ID. See Install Extension from GitHub.",
"browser_settings": "Browser settings",
"timeout": "Duration in seconds after which the session will automatically end. Defaults to the Project's defaultTimeout.",
"keep_alive": "Set to true to keep the session alive even after disconnections. This is available on the Startup plan only.",
"proxies": "Proxy configuration. Can be true for default proxy, or an array of proxy configurations.",
},
BrowserbaseGetSessionArguments: {
"id": "Session ID",
},
BrowserbaseCompleteSessionArguments: {
"id": "Session ID",
"status": "Session status",
},
BrowserbaseGetSessionLiveUrlsArguments: {
"id": "Session ID",
},
BrowserbaseGetSessionConnectUrlArguments: {
"id": "Session ID",
},
RemoteBrowserArguments: {
"connect_url": "The connection URL for the remote browser",
"action": "The action to perform",
"text": "The text",
"coordinate": "The coordinate to move the mouse to",
},
}

_providers_map = {
"brave": BraveSearchArguments,
"email": EmailArguments,
"spider": SpiderFetchArguments,
"wikipedia": WikipediaSearchArguments,
"weather": WeatherGetArguments,
"browserbase": {
"create_context": BrowserbaseContextArguments,
"install_extension_from_github": BrowserbaseExtensionArguments,
"list_sessions": BrowserbaseListSessionsArguments,
"create_session": BrowserbaseCreateSessionArguments,
"get_session": BrowserbaseGetSessionArguments,
"complete_session": BrowserbaseCompleteSessionArguments,
"get_live_urls": BrowserbaseGetSessionLiveUrlsArguments,
"get_connect_url": BrowserbaseGetSessionConnectUrlArguments,
},
"remote_browser": RemoteBrowserArguments,
}


_arg_types_map = {
BrowserbaseCreateSessionArguments: {
"proxies": {
"type": "boolean | array",
},
},
BrowserbaseListSessionsArguments: {
"status": {
"type": "string",
"enum": "RUNNING,ERROR,TIMED_OUT,COMPLETED",
},
},
}


class stdlib_re:
fullmatch = re2.fullmatch
Expand Down Expand Up @@ -398,7 +503,9 @@ def get_handler(system: SystemDef) -> Callable:
)


def _annotation_to_type(annotation: type) -> dict[str, str]:
def _annotation_to_type(
annotation: type, args_model: type[BaseModel], fld_name: str
) -> dict[str, str]:
type_, enum = None, None
if get_origin(annotation) is Literal:
type_ = "string"
Expand All @@ -413,8 +520,19 @@ def _annotation_to_type(annotation: type) -> dict[str, str]:
type_ = "boolean"
elif annotation == type(None):
type_ = "null"
else:
elif get_origin(annotation) is types.UnionType:
args = [arg for arg in get_args(annotation) if arg is not types.NoneType]
if len(args):
return _annotation_to_type(args[0], args_model, fld_name)
else:
type_ = "null"
elif annotation is dict:
type_ = "object"
else:
type_ = _arg_types_map.get(args_model, {fld_name: {"type": "object"}}).get(
fld_name, {"type": "object"}
)["type"]
enum = _arg_types_map.get(args_model, {}).get(fld_name, {}).get("enum")

result = {
"type": type_,
Expand All @@ -426,33 +544,14 @@ def _annotation_to_type(annotation: type) -> dict[str, str]:


def get_integration_arguments(tool: Tool):
providers_map = {
"brave": BraveSearchArguments,
# "dummy": DummyIntegrationDef,
"email": EmailArguments,
"spider": SpiderFetchArguments,
"wikipedia": WikipediaSearchArguments,
"weather": WeatherGetArguments,
"browserbase": {
"create_context": BrowserbaseContextArguments,
"install_extension_from_github": BrowserbaseExtensionArguments,
"list_sessions": BrowserbaseListSessionsArguments,
"create_session": BrowserbaseCreateSessionArguments,
"get_session": BrowserbaseGetSessionArguments,
"complete_session": BrowserbaseCompleteSessionArguments,
"get_live_urls": BrowserbaseGetSessionLiveUrlsArguments,
"get_connect_url": BrowserbaseGetSessionConnectUrlArguments,
},
"remote_browser": RemoteBrowserArguments,
}
properties = {
"type": "object",
"properties": {},
"required": [],
}

integration_args: type[BaseModel] | dict[str, type[BaseModel]] | None = (
providers_map.get(tool.integration.provider)
_providers_map.get(tool.integration.provider)
)

if integration_args is None:
Expand All @@ -467,10 +566,12 @@ def get_integration_arguments(tool: Tool):
return properties

for fld_name, fld_annotation in integration_args.model_fields.items():
tp = _annotation_to_type(fld_annotation.annotation)
tp["description"] = fld_name
tp = _annotation_to_type(fld_annotation.annotation, integration_args, fld_name)
tp["description"] = _args_desc_map.get(integration_args, fld_name).get(
fld_name, fld_name
)
properties["properties"][fld_name] = tp
if fld_annotation.is_required:
if fld_annotation.is_required():
properties["required"].append(fld_name)

return properties
Loading

0 comments on commit 4e6eb22

Please sign in to comment.