feat: Allow answer streaming and only final answer
RezaRahemtola committed Aug 27, 2024
1 parent 5f9d2c8 commit ec91da1
Showing 2 changed files with 66 additions and 23 deletions.
81 changes: 58 additions & 23 deletions libertai_agents/agents.py
@@ -1,10 +1,11 @@
import asyncio
from http import HTTPStatus
from typing import Callable, Awaitable, Any
from typing import Callable, Awaitable, Any, AsyncIterable

import aiohttp
from aiohttp import ClientSession
from fastapi import APIRouter, FastAPI
from starlette.responses import StreamingResponse

from libertai_agents.interfaces.llamacpp import CustomizableLlamaCppParams, LlamaCppParams
from libertai_agents.interfaces.messages import Message, MessageRoleEnum, MessageToolCall, ToolCallFunction, \
@@ -46,7 +47,8 @@ def __init__(self, model: Model, system_prompt: str, tools: list[Callable[..., A
if expose_api:
# Define API routes
router = APIRouter()
router.add_api_route("/generate-answer", self.generate_answer, methods=["POST"])
router.add_api_route("/generate-answer", self.__api_generate_answer, methods=["POST"],
summary="Generate Answer")
router.add_api_route("/model", self.get_model_information, methods=["GET"])

self.app = FastAPI(title="LibertAI ChatAgent")
@@ -58,45 +60,78 @@ def get_model_information(self) -> ModelInformation:
"""
return ModelInformation(id=self.model.model_id, context_length=self.model.context_length)

async def generate_answer(self, messages: list[Message]) -> str:
async def __api_generate_answer(self, messages: list[Message], stream: bool = False,
only_final_answer: bool = True):
"""
Generate an answer based on an existing conversation.
The response messages can be streamed or sent in a single block
"""
if stream:
return StreamingResponse(
self.__dump_api_generate_streamed_answer(messages, only_final_answer=only_final_answer),
media_type='text/event-stream')

response_messages: list[Message] = []
async for message in self.generate_answer(messages, only_final_answer=only_final_answer):
response_messages.append(message)
return response_messages

async def __dump_api_generate_streamed_answer(self, messages: list[Message], only_final_answer: bool) -> \
AsyncIterable[str]:
async for message in self.generate_answer(messages, only_final_answer=only_final_answer):
yield message.model_dump_json()
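
A minimal client-side sketch of consuming this streamed endpoint, assuming the FastAPI app is served locally on port 8000 and that stream and only_final_answer end up as query parameters (FastAPI's usual handling of non-body scalar arguments); the URL, port, and example question are illustrative assumptions:

import asyncio

import aiohttp


async def stream_answer() -> None:
    # Hypothetical local deployment of the agent's FastAPI app
    url = "http://localhost:8000/generate-answer"
    payload = [{"role": "user", "content": "What is the temperature in Paris?"}]
    async with aiohttp.ClientSession() as session:
        async with session.post(url, json=payload,
                                params={"stream": "true", "only_final_answer": "false"}) as response:
            # Each chunk carries JSON produced by model_dump_json(); chunk boundaries
            # are not guaranteed to line up with message boundaries.
            async for chunk in response.content.iter_any():
                print(chunk.decode())


# asyncio.run(stream_answer())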

async def generate_answer(self, messages: list[Message], only_final_answer: bool = True) -> AsyncIterable[Message]:
"""
Generate an answer based on a conversation
:param messages: List of messages previously sent in this conversation
:param only_final_answer: Only yield the final answer, without including the thought process (tool calls and their responses)
:return: The agent's response messages, yielded as they are generated
"""
if len(messages) == 0:
raise ValueError("No previous message to respond to")
if messages[-1].role not in [MessageRoleEnum.user, MessageRoleEnum.tool]:
raise ValueError("Last message is not from the user or a tool response")

prompt = self.model.generate_prompt(messages, self.system_prompt, self.tools)
async with aiohttp.ClientSession() as session:
response = await self.__call_model(session, prompt)

tool_calls = self.model.extract_tool_calls_from_response(response)
if len(tool_calls) == 0:
return response

tool_calls_message = self.__create_tool_calls_message(tool_calls)
messages.append(tool_calls_message)
executed_calls = self.__execute_tool_calls(tool_calls_message.tool_calls)
results = await asyncio.gather(*executed_calls)
tool_results_messages: list[Message] = [
ToolResponseMessage(role=MessageRoleEnum.tool, name=call.function.name, tool_call_id=call.id,
content=str(results[i])) for i, call in enumerate(tool_calls_message.tool_calls)]

return await self.generate_answer(messages + tool_results_messages)

async def __call_model(self, session: ClientSession, prompt: str):
# TODO: support streaming - detect tool calls to avoid sending them as a response
while True:
prompt = self.model.generate_prompt(messages, self.system_prompt, self.tools)
async with aiohttp.ClientSession() as session:
response = await self.__call_model(session, prompt)

if response is None:
# TODO: handle error correctly
raise ValueError("Model didn't respond")

tool_calls = self.model.extract_tool_calls_from_response(response)
if len(tool_calls) == 0:
yield Message(role=MessageRoleEnum.assistant, content=response)
return

tool_calls_message = self.__create_tool_calls_message(tool_calls)
messages.append(tool_calls_message)
if not only_final_answer:
yield tool_calls_message
executed_calls = self.__execute_tool_calls(tool_calls_message.tool_calls)
results = await asyncio.gather(*executed_calls)
tool_results_messages: list[Message] = [
ToolResponseMessage(role=MessageRoleEnum.tool, name=call.function.name, tool_call_id=call.id,
content=str(results[i])) for i, call in
enumerate(tool_calls_message.tool_calls)]
if not only_final_answer:
for tool_result_message in tool_results_messages:
yield tool_result_message
messages = messages + tool_results_messages
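
As an illustration of the new only_final_answer flag, a small sketch of a caller that also surfaces the intermediate messages (the agent instance and the example question are assumptions; Message and MessageRoleEnum are the classes imported above):

import asyncio

from libertai_agents.interfaces.messages import Message, MessageRoleEnum


async def print_thought_process(agent) -> None:
    # With only_final_answer=False the tool-call message and every tool response
    # are yielded before the final assistant answer.
    conversation = [Message(role=MessageRoleEnum.user, content="What is the temperature in Paris?")]
    async for message in agent.generate_answer(conversation, only_final_answer=False):
        print(message)


# asyncio.run(print_thought_process(my_agent))  # my_agent: an instance of this agent class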

async def __call_model(self, session: ClientSession, prompt: str) -> str | None:
params = LlamaCppParams(prompt=prompt, **self.llamacpp_params.model_dump())

async with session.post(self.model.vm_url, json=params.model_dump()) as response:
# TODO: handle errors and retries
if response.status == HTTPStatus.OK:
response_data = await response.json()
return response_data["content"]
return None

def __execute_tool_calls(self, tool_calls: list[MessageToolCall]) -> list[Awaitable[Any]]:
executed_calls: list[Awaitable[Any]] = []
8 changes: 8 additions & 0 deletions main.py
@@ -20,3 +20,11 @@ async def get_current_temperature(location: str, unit: str) -> float:
tools=[get_current_temperature])

app = agent.app

# async def main():
# async for message in agent.generate_answer(
# [Message(role=MessageRoleEnum.user, content="What is the temperature in Paris and in Lyon?")]):
# print(message)
#
#
# asyncio.run(main())
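
A runnable variant of the commented snippet above, assuming the agent instance created earlier in main.py and the message interfaces used elsewhere in the package:

import asyncio

from libertai_agents.interfaces.messages import Message, MessageRoleEnum


async def main():
    async for message in agent.generate_answer(
            [Message(role=MessageRoleEnum.user, content="What is the temperature in Paris and in Lyon?")]):
        print(message)


if __name__ == "__main__":
    asyncio.run(main())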
