feat(tools): Basic langchain integration
RezaRahemtola committed Dec 4, 2024
1 parent 79d1690 commit 911d21f
Showing 9 changed files with 1,163 additions and 447 deletions.
140 changes: 101 additions & 39 deletions libertai_agents/libertai_agents/agents.py
@@ -1,17 +1,28 @@
import asyncio
import inspect
import json
from http import HTTPStatus
from typing import Callable, Awaitable, Any, AsyncIterable
from typing import Awaitable, Any, AsyncIterable

import aiohttp
from aiohttp import ClientSession
from fastapi import APIRouter, FastAPI
from starlette.responses import StreamingResponse

from libertai_agents.interfaces.llamacpp import CustomizableLlamaCppParams, LlamaCppParams
from libertai_agents.interfaces.messages import Message, MessageRoleEnum, MessageToolCall, ToolCallFunction, \
ToolCallMessage, ToolResponseMessage
from libertai_agents.interfaces.llamacpp import (
CustomizableLlamaCppParams,
LlamaCppParams,
)
from libertai_agents.interfaces.messages import (
Message,
MessageRoleEnum,
MessageToolCall,
ToolCallFunction,
ToolCallMessage,
ToolResponseMessage,
)
from libertai_agents.interfaces.models import ModelInformation
from libertai_agents.interfaces.tools import Tool
from libertai_agents.models import Model
from libertai_agents.utils import find

@@ -21,14 +32,18 @@
class ChatAgent:
model: Model
system_prompt: str | None
tools: list[Callable[..., Awaitable[Any]]]
tools: list[Tool]
llamacpp_params: CustomizableLlamaCppParams
app: FastAPI | None

def __init__(self, model: Model, system_prompt: str | None = None,
tools: list[Callable[..., Awaitable[Any]]] | None = None,
llamacpp_params: CustomizableLlamaCppParams = CustomizableLlamaCppParams(),
expose_api: bool = True):
def __init__(
self,
model: Model,
system_prompt: str | None = None,
tools: list[Tool] | None = None,
llamacpp_params: CustomizableLlamaCppParams = CustomizableLlamaCppParams(),
expose_api: bool = True,
):
"""
Create a LibertAI chatbot agent that can answer messages from users
@@ -41,7 +56,7 @@ def __init__(self, model: Model, system_prompt: str | None = None,
if tools is None:
tools = []

if len(set(map(lambda x: x.__name__, tools))) != len(tools):
if len(set(map(lambda x: x.name, tools))) != len(tools):
raise ValueError("Tool functions must have different names")
self.model = model
self.system_prompt = system_prompt
@@ -51,8 +66,12 @@ def __init__(self, model: Model, system_prompt: str | None = None,
if expose_api:
# Define API routes
router = APIRouter()
router.add_api_route("/generate-answer", self.__api_generate_answer, methods=["POST"],
summary="Generate Answer")
router.add_api_route(
"/generate-answer",
self.__api_generate_answer,
methods=["POST"],
summary="Generate Answer",
)
router.add_api_route("/model", self.get_model_information, methods=["GET"])

self.app = FastAPI(title="LibertAI ChatAgent")
@@ -62,9 +81,13 @@ def get_model_information(self) -> ModelInformation:
"""
Get information about the model powering this agent
"""
return ModelInformation(id=self.model.model_id, context_length=self.model.context_length)
return ModelInformation(
id=self.model.model_id, context_length=self.model.context_length
)

async def generate_answer(self, messages: list[Message], only_final_answer: bool = True) -> AsyncIterable[Message]:
async def generate_answer(
self, messages: list[Message], only_final_answer: bool = True
) -> AsyncIterable[Message]:
"""
Generate an answer based on a conversation
@@ -78,7 +101,9 @@ async def generate_answer(self, messages: list[Message], only_final_answer: bool
raise ValueError("Last message is not from the user or a tool response")

for _ in range(MAX_TOOL_CALLS_DEPTH):
prompt = self.model.generate_prompt(messages, self.tools, system_prompt=self.system_prompt)
prompt = self.model.generate_prompt(
messages, self.tools, system_prompt=self.system_prompt
)
async with aiohttp.ClientSession() as session:
response = await self.__call_model(session, prompt)

@@ -97,44 +122,63 @@ async def generate_answer(self, messages: list[Message], only_final_answer: bool
if not only_final_answer:
yield tool_calls_message

executed_calls = self.__execute_tool_calls(tool_calls_message.tool_calls)
executed_calls = self.__execute_tool_calls(
tool_calls_message.tool_calls
)
results = await asyncio.gather(*executed_calls)
tool_results_messages: list[Message] = [
ToolResponseMessage(role=MessageRoleEnum.tool, name=call.function.name, tool_call_id=call.id,
content=str(results[i])) for i, call in
enumerate(tool_calls_message.tool_calls)]
ToolResponseMessage(
role=MessageRoleEnum.tool,
name=call.function.name,
tool_call_id=call.id,
content=str(results[i]),
)
for i, call in enumerate(tool_calls_message.tool_calls)
]
if not only_final_answer:
for tool_result_message in tool_results_messages:
yield tool_result_message
# Continue the loop with the tool results so the model can make further tool calls or give a final answer
messages = messages + tool_results_messages

async def __api_generate_answer(self, messages: list[Message], stream: bool = False,
only_final_answer: bool = True):
async def __api_generate_answer(
self,
messages: list[Message],
stream: bool = False,
only_final_answer: bool = True,
):
"""
Generate an answer based on an existing conversation.
The response messages can be streamed or sent in a single block.
"""
if stream:
return StreamingResponse(
self.__dump_api_generate_streamed_answer(messages, only_final_answer=only_final_answer),
media_type='text/event-stream')
self.__dump_api_generate_streamed_answer(
messages, only_final_answer=only_final_answer
),
media_type="text/event-stream",
)

response_messages: list[Message] = []
async for message in self.generate_answer(messages, only_final_answer=only_final_answer):
async for message in self.generate_answer(
messages, only_final_answer=only_final_answer
):
response_messages.append(message)
return response_messages

async def __dump_api_generate_streamed_answer(self, messages: list[Message], only_final_answer: bool) -> \
AsyncIterable[str]:
async def __dump_api_generate_streamed_answer(
self, messages: list[Message], only_final_answer: bool
) -> AsyncIterable[str]:
"""
Dump the generate_answer iterable to JSON
:param messages: Messages to pass to generate_answer
:param only_final_answer: Param to pass to generate_answer
:return: Iterable of each message from generate_answer dumped to JSON
"""
async for message in self.generate_answer(messages, only_final_answer=only_final_answer):
async for message in self.generate_answer(
messages, only_final_answer=only_final_answer
):
yield json.dumps(message.dict(), indent=4)

async def __call_model(self, session: ClientSession, prompt: str) -> str | None:
@@ -154,7 +198,9 @@ async def __call_model(self, session: ClientSession, prompt: str) -> str | None:
return response_data["content"]
return None

def __execute_tool_calls(self, tool_calls: list[MessageToolCall]) -> list[Awaitable[Any]]:
def __execute_tool_calls(
self, tool_calls: list[MessageToolCall]
) -> list[Awaitable[Any]]:
"""
Execute the given tool calls (without waiting for completion)
@@ -164,26 +210,42 @@ def __execute_tool_calls(self, tool_calls: list[MessageToolCall]) -> list[Awaita
executed_calls: list[Awaitable[Any]] = []
for call in tool_calls:
function_name = call.function.name
function_to_call = find(lambda x: x.__name__ == function_name, self.tools)
if function_to_call is None:
tool = find(lambda x: x.name == function_name, self.tools)
if tool is None:
# TODO: handle error
continue
function_response = function_to_call(*call.function.arguments.values())

function_to_call = tool.function
if inspect.iscoroutinefunction(function_to_call):
# Call async function directly
function_response = function_to_call(*call.function.arguments.values())
else:
# Wrap sync function in asyncio.to_thread to make it awaitable
function_response = asyncio.to_thread(
function_to_call, *call.function.arguments.values()
)

executed_calls.append(function_response)

return executed_calls

def __create_tool_calls_message(self, tool_calls: list[ToolCallFunction]) -> ToolCallMessage:
def __create_tool_calls_message(
self, tool_calls: list[ToolCallFunction]
) -> ToolCallMessage:
"""
Craft a tool call message
:param tool_calls: Tool calls to include in the message
:return: Crafted tool call message
"""
return ToolCallMessage(role=MessageRoleEnum.assistant,
tool_calls=[MessageToolCall(type="function",
id=self.model.generate_tool_call_id(),
function=ToolCallFunction(name=call.name,
arguments=call.arguments)) for
call in
tool_calls])
return ToolCallMessage(
role=MessageRoleEnum.assistant,
tool_calls=[
MessageToolCall(
type="function",
id=self.model.generate_tool_call_id(),
function=ToolCallFunction(name=call.name, arguments=call.arguments),
)
for call in tool_calls
],
)
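
For context, here is a minimal usage sketch of the reworked agents.py API above. It is an illustration, not part of the commit: get_model and the model name are assumptions, and the tool body is a stub. Note that Tool.from_function relies on transformers' get_json_schema, which expects a Google-style docstring with an Args section.

import asyncio

from libertai_agents.agents import ChatAgent
from libertai_agents.interfaces.messages import Message, MessageRoleEnum
from libertai_agents.interfaces.tools import Tool
from libertai_agents.models import get_model  # assumed helper, not shown in this diff


def add(a: int, b: int) -> int:
    """
    Add two integers.

    Args:
        a: The first integer
        b: The second integer
    """
    return a + b


agent = ChatAgent(
    model=get_model("NousResearch/Hermes-3-Llama-3.1-8B"),  # assumed model id
    system_prompt="You are a helpful assistant",
    # Sync tools like `add` are wrapped in asyncio.to_thread by
    # __execute_tool_calls; coroutine functions are awaited directly.
    tools=[Tool.from_function(add)],
    expose_api=False,
)


async def main() -> None:
    async for message in agent.generate_answer(
        [Message(role=MessageRoleEnum.user, content="What is 2 + 3?")]
    ):
        print(message.content)


asyncio.run(main())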
8 changes: 4 additions & 4 deletions libertai_agents/libertai_agents/interfaces/messages.py
@@ -5,10 +5,10 @@


class MessageRoleEnum(str, Enum):
user = 'user'
assistant = 'assistant'
system = 'system'
tool = 'tool'
user = "user"
assistant = "assistant"
system = "system"
tool = "tool"


class ToolCallFunction(BaseModel):
55 changes: 55 additions & 0 deletions libertai_agents/libertai_agents/interfaces/tools.py
@@ -0,0 +1,55 @@
from typing import Callable, Any, TYPE_CHECKING

from pydantic.v1 import BaseModel
from transformers.utils import get_json_schema
from transformers.utils.chat_template_utils import _convert_type_hints_to_json_schema

if TYPE_CHECKING:
# Importing only for type hinting purposes.
from langchain_core.tools import BaseTool

Check failure on line 9 in libertai_agents/libertai_agents/interfaces/tools.py (GitHub Actions / Package: mypy): Cannot find implementation or library stub for module named "langchain_core.tools" [import-not-found]


class Tool(BaseModel):
name: str
function: Callable[..., Any]
args_schema: dict

@classmethod
def from_function(cls, function: Callable[..., Any]):
return cls(
name=function.__name__,
function=function,
args_schema=get_json_schema(function),
)

@classmethod
def from_langchain(cls, langchain_tool: "BaseTool"):
try:
from langchain_core.tools import StructuredTool
except ImportError:
raise RuntimeError(
"langchain_core is required for this functionality. Install with: libertai-agents[langchain]"
)

if isinstance(langchain_tool, StructuredTool):
# TODO: handle this case
raise NotImplementedError("Langchain StructuredTool isn't supported yet")

# Extracting function parameters to JSON schema
function_parameters = _convert_type_hints_to_json_schema(langchain_tool._run)
# Removing langchain-specific parameters
function_parameters["properties"].pop("run_manager", None)
function_parameters["properties"].pop("return", None)

return cls(
name=langchain_tool.name,
function=langchain_tool._run,
args_schema={
"type": "function",
"function": {
"name": langchain_tool.name,
"description": langchain_tool.description,
"parameters": function_parameters,
},
},
)
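
A short sketch of how the new LangChain bridge might be used, assuming the optional libertai-agents[langchain] extra (and langchain_community for the example tool) is installed. DuckDuckGoSearchRun is an illustrative tool, not part of this commit:

from langchain_community.tools import DuckDuckGoSearchRun

from libertai_agents.interfaces.tools import Tool

# DuckDuckGoSearchRun subclasses BaseTool (not StructuredTool), so it takes
# the supported path through from_langchain.
search_tool = Tool.from_langchain(DuckDuckGoSearchRun())
print(search_tool.name)  # expected: "duckduckgo_search"
print(search_tool.args_schema["function"]["description"])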
