Skip to content

Commit

Permalink
feat(tools): Support of langchain StructuredTool
Browse files Browse the repository at this point in the history
  • Loading branch information
RezaRahemtola committed Dec 4, 2024
1 parent 911d21f commit f207e64
Showing 1 changed file with 37 additions and 3 deletions.
40 changes: 37 additions & 3 deletions libertai_agents/libertai_agents/interfaces/tools.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,25 @@
from typing import Callable, Any, TYPE_CHECKING

from pydantic.json_schema import GenerateJsonSchema, JsonSchemaMode, JsonSchemaValue
from pydantic.v1 import BaseModel
from pydantic_core import CoreSchema
from transformers.utils import get_json_schema
from transformers.utils.chat_template_utils import _convert_type_hints_to_json_schema

if TYPE_CHECKING:
# Importing only for type hinting purposes.
from langchain_core.tools import BaseTool
from langchain_core.tools import BaseTool # type: ignore


class GenerateToolPropertiesJsonSchema(GenerateJsonSchema):
def generate(
self, schema: CoreSchema, mode: JsonSchemaMode = "validation"
) -> JsonSchemaValue:
json_schema = super().generate(schema, mode=mode)
for key in json_schema["properties"].keys():
json_schema["properties"][key].pop("title", None)
json_schema.pop("title", None)
return json_schema


class Tool(BaseModel):
Expand All @@ -32,8 +45,29 @@ def from_langchain(cls, langchain_tool: "BaseTool"):
)

if isinstance(langchain_tool, StructuredTool):
# TODO: handle this case
raise NotImplementedError("Langchain StructuredTool aren't supported yet")
# Particular case
structured_langchain_tool: StructuredTool = langchain_tool
function_parameters = (
structured_langchain_tool.args_schema.model_json_schema(
schema_generator=GenerateToolPropertiesJsonSchema
)
)

if structured_langchain_tool.func is None:
raise ValueError("Tool function is None, expected a Callable value")

return cls(
name=structured_langchain_tool.name,
function=structured_langchain_tool.func,
args_schema={
"type": "function",
"function": {
"name": structured_langchain_tool.name,
"description": structured_langchain_tool.description,
"parameters": function_parameters,
},
},
)

# Extracting function parameters to JSON schema
function_parameters = _convert_type_hints_to_json_schema(langchain_tool._run)
Expand Down

0 comments on commit f207e64

Please sign in to comment.