From 9b6d0f1710aac144080f836654717cf39e52ad7d Mon Sep 17 00:00:00 2001 From: Reza Rahemtola Date: Thu, 5 Dec 2024 04:00:52 +0900 Subject: [PATCH] feat(tools): Support of langchain StructuredTool --- .../libertai_agents/interfaces/tools.py | 38 ++++++++++++++++++- 1 file changed, 36 insertions(+), 2 deletions(-) diff --git a/libertai_agents/libertai_agents/interfaces/tools.py b/libertai_agents/libertai_agents/interfaces/tools.py index 9ed803e..58f69d3 100644 --- a/libertai_agents/libertai_agents/interfaces/tools.py +++ b/libertai_agents/libertai_agents/interfaces/tools.py @@ -1,6 +1,8 @@ from typing import Callable, Any, TYPE_CHECKING +from pydantic.json_schema import GenerateJsonSchema, JsonSchemaMode, JsonSchemaValue from pydantic.v1 import BaseModel +from pydantic_core import CoreSchema from transformers.utils import get_json_schema from transformers.utils.chat_template_utils import _convert_type_hints_to_json_schema @@ -9,6 +11,17 @@ from langchain_core.tools import BaseTool +class GenerateToolPropertiesJsonSchema(GenerateJsonSchema): + def generate( + self, schema: CoreSchema, mode: JsonSchemaMode = "validation" + ) -> JsonSchemaValue: + json_schema = super().generate(schema, mode=mode) + for key in json_schema["properties"].keys(): + json_schema["properties"][key].pop("title", None) + json_schema.pop("title", None) + return json_schema + + class Tool(BaseModel): name: str function: Callable[..., Any] @@ -32,8 +45,29 @@ def from_langchain(cls, langchain_tool: "BaseTool"): ) if isinstance(langchain_tool, StructuredTool): - # TODO: handle this case - raise NotImplementedError("Langchain StructuredTool aren't supported yet") + # Particular case + structured_langchain_tool: StructuredTool = langchain_tool + function_parameters = ( + structured_langchain_tool.args_schema.model_json_schema( + schema_generator=GenerateToolPropertiesJsonSchema + ) + ) + + if structured_langchain_tool.func is None: + raise ValueError("Tool function is None, expected a Callable value") + + return cls( + name=structured_langchain_tool.name, + function=structured_langchain_tool.func, + args_schema={ + "type": "function", + "function": { + "name": structured_langchain_tool.name, + "description": structured_langchain_tool.description, + "parameters": function_parameters, + }, + }, + ) # Extracting function parameters to JSON schema function_parameters = _convert_type_hints_to_json_schema(langchain_tool._run)