From ec91da1d53d22e2b0e1780548b51ba92e01617df Mon Sep 17 00:00:00 2001
From: Reza Rahemtola
Date: Tue, 27 Aug 2024 16:13:01 +0900
Subject: [PATCH] feat: Allow answer streaming and only final answer

---
 libertai_agents/agents.py | 81 ++++++++++++++++++++++++++++-----------
 main.py                   |  8 ++++
 2 files changed, 66 insertions(+), 23 deletions(-)

diff --git a/libertai_agents/agents.py b/libertai_agents/agents.py
index 5fa4436..bdc840f 100644
--- a/libertai_agents/agents.py
+++ b/libertai_agents/agents.py
@@ -1,10 +1,11 @@
 import asyncio
 from http import HTTPStatus
-from typing import Callable, Awaitable, Any
+from typing import Callable, Awaitable, Any, AsyncIterable
 
 import aiohttp
 from aiohttp import ClientSession
 from fastapi import APIRouter, FastAPI
+from starlette.responses import StreamingResponse
 
 from libertai_agents.interfaces.llamacpp import CustomizableLlamaCppParams, LlamaCppParams
 from libertai_agents.interfaces.messages import Message, MessageRoleEnum, MessageToolCall, ToolCallFunction, \
@@ -46,7 +47,8 @@ def __init__(self, model: Model, system_prompt: str, tools: list[Callable[..., A
         if expose_api:
             # Define API routes
             router = APIRouter()
-            router.add_api_route("/generate-answer", self.generate_answer, methods=["POST"])
+            router.add_api_route("/generate-answer", self.__api_generate_answer, methods=["POST"],
+                                 summary="Generate Answer")
             router.add_api_route("/model", self.get_model_information, methods=["GET"])
 
             self.app = FastAPI(title="LibertAI ChatAgent")
@@ -58,11 +60,33 @@ def get_model_information(self) -> ModelInformation:
         """
         return ModelInformation(id=self.model.model_id, context_length=self.model.context_length)
 
-    async def generate_answer(self, messages: list[Message]) -> str:
+    async def __api_generate_answer(self, messages: list[Message], stream: bool = False,
+                                    only_final_answer: bool = True):
+        """
+        Generate an answer based on an existing conversation.
+        The response messages can be streamed or sent in a single block.
+        """
+        if stream:
+            return StreamingResponse(
+                self.__dump_api_generate_streamed_answer(messages, only_final_answer=only_final_answer),
+                media_type='text/event-stream')
+
+        response_messages: list[Message] = []
+        async for message in self.generate_answer(messages, only_final_answer=only_final_answer):
+            response_messages.append(message)
+        return response_messages
+
+    async def __dump_api_generate_streamed_answer(self, messages: list[Message], only_final_answer: bool) -> \
+            AsyncIterable[str]:
+        async for message in self.generate_answer(messages, only_final_answer=only_final_answer):
+            yield message.model_dump_json()
+
+    async def generate_answer(self, messages: list[Message], only_final_answer: bool = True) -> AsyncIterable[Message]:
         """
         Generate an answer based on a conversation
 
         :param messages: List of messages previously sent in this conversation
+        :param only_final_answer: Only yields the final answer, without including the thought process (tool calls and their responses)
         :return: The string response of the agent
         """
         if len(messages) == 0:
@@ -70,26 +94,36 @@ async def generate_answer(self, messages: list[Message]) -> str:
         if messages[-1].role not in [MessageRoleEnum.user, MessageRoleEnum.tool]:
             raise ValueError("Last message is not from the user or a tool response")
 
-        prompt = self.model.generate_prompt(messages, self.system_prompt, self.tools)
-        async with aiohttp.ClientSession() as session:
-            response = await self.__call_model(session, prompt)
-
-            tool_calls = self.model.extract_tool_calls_from_response(response)
-            if len(tool_calls) == 0:
-                return response
-
-            tool_calls_message = self.__create_tool_calls_message(tool_calls)
-            messages.append(tool_calls_message)
-            executed_calls = self.__execute_tool_calls(tool_calls_message.tool_calls)
-            results = await asyncio.gather(*executed_calls)
-            tool_results_messages: list[Message] = [
-                ToolResponseMessage(role=MessageRoleEnum.tool, name=call.function.name, tool_call_id=call.id,
-                                    content=str(results[i])) for i, call in enumerate(tool_calls_message.tool_calls)]
-
-            return await self.generate_answer(messages + tool_results_messages)
-
-    async def __call_model(self, session: ClientSession, prompt: str):
-        # TODO: support streaming - detect tools calls to avoid sending them as response
+        while True:
+            prompt = self.model.generate_prompt(messages, self.system_prompt, self.tools)
+            async with aiohttp.ClientSession() as session:
+                response = await self.__call_model(session, prompt)
+
+                if response is None:
+                    # TODO: handle error correctly
+                    raise ValueError("Model didn't respond")
+
+                tool_calls = self.model.extract_tool_calls_from_response(response)
+                if len(tool_calls) == 0:
+                    yield Message(role=MessageRoleEnum.assistant, content=response)
+                    return
+
+                tool_calls_message = self.__create_tool_calls_message(tool_calls)
+                messages.append(tool_calls_message)
+                if not only_final_answer:
+                    yield tool_calls_message
+                executed_calls = self.__execute_tool_calls(tool_calls_message.tool_calls)
+                results = await asyncio.gather(*executed_calls)
+                tool_results_messages: list[Message] = [
+                    ToolResponseMessage(role=MessageRoleEnum.tool, name=call.function.name, tool_call_id=call.id,
+                                        content=str(results[i])) for i, call in
+                    enumerate(tool_calls_message.tool_calls)]
+                if not only_final_answer:
+                    for tool_result_message in tool_results_messages:
+                        yield tool_result_message
+                messages = messages + tool_results_messages
+
+    async def __call_model(self, session: ClientSession, prompt: str) -> str | None:
         params = LlamaCppParams(prompt=prompt, **self.llamacpp_params.model_dump())
 
         async with session.post(self.model.vm_url, json=params.model_dump()) as response:
@@ -97,6 +131,7 @@ async def __call_model(self, session: ClientSession, prompt: str):
             if response.status == HTTPStatus.OK:
                 response_data = await response.json()
                 return response_data["content"]
+            return None
 
     def __execute_tool_calls(self, tool_calls: list[MessageToolCall]) -> list[Awaitable[Any]]:
         executed_calls: list[Awaitable[Any]] = []
diff --git a/main.py b/main.py
index e451bb1..869c9ac 100644
--- a/main.py
+++ b/main.py
@@ -20,3 +20,11 @@ async def get_current_temperature(location: str, unit: str) -> float:
                  tools=[get_current_temperature])
 
 app = agent.app
+
+# async def main():
+#     async for message in agent.generate_answer(
+#             [Message(role=MessageRoleEnum.user, content="What is the temperature in Paris and in Lyon?")]):
+#         print(message)
+#
+#
+# asyncio.run(main())
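
Usage note (not part of the patch): a minimal sketch of how a client could consume the new streaming mode of the /generate-answer route added above. The base URL and port are assumptions (a local uvicorn deployment of agent.app), as is the detail that FastAPI exposes stream and only_final_answer as query parameters given the __api_generate_answer signature; each streamed chunk is produced server-side by message.model_dump_json().

import asyncio

import aiohttp

# Assumed local deployment, e.g. `uvicorn main:app`
AGENT_URL = "http://localhost:8000/generate-answer"


async def stream_answer() -> None:
    # Request body mirrors list[Message]; role/content field names are assumed
    # to match the Message model's serialization
    payload = [{"role": "user", "content": "What is the temperature in Paris and in Lyon?"}]
    params = {"stream": "true", "only_final_answer": "false"}
    async with aiohttp.ClientSession() as session:
        async with session.post(AGENT_URL, json=payload, params=params) as response:
            # Chunks arrive as they are yielded by __dump_api_generate_streamed_answer;
            # each item written by the server is the JSON dump of one Message
            async for chunk in response.content.iter_any():
                print(chunk.decode())


asyncio.run(stream_answer())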