Skip to content

Commit

Permalink
Added tool factory to api reference
Browse files Browse the repository at this point in the history
  • Loading branch information
VRSEN committed Feb 29, 2024
1 parent 03a77fc commit 8c32043
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 20 deletions.
72 changes: 53 additions & 19 deletions agency_swarm/tools/ToolFactory.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,15 @@
class ToolFactory:

@staticmethod
def from_langchain_tools(tools: List):
def from_langchain_tools(tools: List) -> List[Type[BaseTool]]:
"""
Converts a list of langchain tools into a list of BaseTools.
:param tools: A list of langchain tools.
:return: A list of BaseTools.
Parameters:
tools: The langchain tools to convert.
Returns:
A list of BaseTools.
"""
converted_tools = []
for tool in tools:
Expand All @@ -29,11 +33,15 @@ def from_langchain_tools(tools: List):
return converted_tools

@staticmethod
def from_langchain_tool(tool):
def from_langchain_tool(tool) -> Type[BaseTool]:
"""
Converts a langchain tool into a BaseTool.
:param tool: A langchain tool.
:return: A BaseTool.
Parameters:
tool: The langchain tool to convert.
Returns:
A BaseTool.
"""
try:
from langchain.tools import format_tool_to_openai_function
Expand Down Expand Up @@ -61,12 +69,16 @@ def callback(self):


@staticmethod
def from_openai_schema(schema: Dict[str, Any], callback: Any):
def from_openai_schema(schema: Dict[str, Any], callback: Any) -> Type[BaseTool]:
"""
Converts an OpenAI schema into a BaseTool. Nested propoerties without refs are not supported yet.
:param schema:
:param callback:
:return:
Parameters:
schema: The OpenAI schema to convert.
callback: The function to run when the tool is called.
Returns:
A BaseTool.
"""
def resolve_ref(ref: str, defs: Dict[str, Any]) -> Any:
# Extract the key from the reference
Expand Down Expand Up @@ -167,7 +179,19 @@ def create_fields(schema: Dict[str, Any], type_mapping: Dict[str, Type[Any]], re
return tool

@staticmethod
def from_openapi_schema(schema: Union[str, dict], headers: Dict[str, str] = None, params: Dict[str, Any] = None):
def from_openapi_schema(schema: Union[str, dict], headers: Dict[str, str] = None, params: Dict[str, Any] = None) \
-> List[Type[BaseTool]]:
"""
Converts an OpenAPI schema into a list of BaseTools.
Parameters:
schema: The OpenAPI schema to convert.
headers: The headers to use for requests.
params: The parameters to use for requests.
Returns:
A list of BaseTools.
"""
if isinstance(schema, dict):
openapi_spec = schema
openapi_spec = jsonref.JsonRef.replace_refs(openapi_spec)
Expand Down Expand Up @@ -260,8 +284,16 @@ def callback(self):
return tools

@staticmethod
def from_file(file_path: str):
"""Dynamically imports a class from a Python file, ensuring BaseTool itself is not imported."""
def from_file(file_path: str) -> Type[BaseTool]:
"""Dynamically imports a BaseTool from a Python file. The file must be named the same as the class.
Parameters:
file_path: The file path to the Python file containing the BaseTool class.
Returns:
The BaseTool class from the given file path.
"""
# Extract class name from file path (assuming class name matches file name without .py extension)
class_name = os.path.basename(file_path)
if class_name.endswith('.py'):
Expand All @@ -283,16 +315,18 @@ def from_file(file_path: str):

@staticmethod
def get_openapi_schema(tools: List[Type[BaseTool]], url: str, title="Agent Tools",
description="A collection of tools."):
description="A collection of tools.") -> str:
"""
Generates an OpenAPI schema from a list of BaseTools.
:param tools: BaseTools to generate the schema from.
:param url: The base URL for the schema.
:param title: The title of the schema.
:param description: The description of the schema.
Parameters:
tools: BaseTools to generate the schema from.
url: The base URL for the schema.
title: The title of the schema.
description: The description of the schema.
:return: A JSON string representing the OpenAPI schema with all the tools combined as separate endpoints.
Returns:
A JSON string representing the OpenAPI schema with all the tools combined as separate endpoints.
"""
schema = {
"openapi": "3.1.0",
Expand Down
4 changes: 3 additions & 1 deletion docs/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,6 @@

::: agency_swarm.agents.agent

::: agency_swarm.agency.agency
::: agency_swarm.agency

::: agency_swarm.tools.ToolFactory

0 comments on commit 8c32043

Please sign in to comment.