Skip to content

Commit

Permalink
feat(tools): Support of langchain StructuredTool
Browse files Browse the repository at this point in the history
  • Loading branch information
RezaRahemtola committed Dec 4, 2024
1 parent 911d21f commit 9b6d0f1
Showing 1 changed file with 36 additions and 2 deletions.
38 changes: 36 additions & 2 deletions libertai_agents/libertai_agents/interfaces/tools.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from typing import Callable, Any, TYPE_CHECKING

from pydantic.json_schema import GenerateJsonSchema, JsonSchemaMode, JsonSchemaValue
from pydantic.v1 import BaseModel
from pydantic_core import CoreSchema
from transformers.utils import get_json_schema
from transformers.utils.chat_template_utils import _convert_type_hints_to_json_schema

Expand All @@ -9,6 +11,17 @@
from langchain_core.tools import BaseTool

Check failure on line 11 in libertai_agents/libertai_agents/interfaces/tools.py

View workflow job for this annotation

GitHub Actions / Package: mypy

[mypy] reported by reviewdog 🐶 Cannot find implementation or library stub for module named "langchain_core.tools" [import-not-found] Raw Output: /home/runner/work/libertai-agents/libertai-agents/libertai_agents/libertai_agents/interfaces/tools.py:11:1: error: Cannot find implementation or library stub for module named "langchain_core.tools" [import-not-found]


class GenerateToolPropertiesJsonSchema(GenerateJsonSchema):
def generate(
self, schema: CoreSchema, mode: JsonSchemaMode = "validation"
) -> JsonSchemaValue:
json_schema = super().generate(schema, mode=mode)
for key in json_schema["properties"].keys():
json_schema["properties"][key].pop("title", None)
json_schema.pop("title", None)
return json_schema


class Tool(BaseModel):
name: str
function: Callable[..., Any]
Expand All @@ -32,8 +45,29 @@ def from_langchain(cls, langchain_tool: "BaseTool"):
)

if isinstance(langchain_tool, StructuredTool):
# TODO: handle this case
raise NotImplementedError("Langchain StructuredTool aren't supported yet")
# Particular case
structured_langchain_tool: StructuredTool = langchain_tool
function_parameters = (
structured_langchain_tool.args_schema.model_json_schema(
schema_generator=GenerateToolPropertiesJsonSchema
)
)

if structured_langchain_tool.func is None:
raise ValueError("Tool function is None, expected a Callable value")

return cls(
name=structured_langchain_tool.name,
function=structured_langchain_tool.func,
args_schema={
"type": "function",
"function": {
"name": structured_langchain_tool.name,
"description": structured_langchain_tool.description,
"parameters": function_parameters,
},
},
)

# Extracting function parameters to JSON schema
function_parameters = _convert_type_hints_to_json_schema(langchain_tool._run)
Expand Down

0 comments on commit 9b6d0f1

Please sign in to comment.