From b63ffc2ce0402369662297a43df1e3bef2898f06 Mon Sep 17 00:00:00 2001 From: Alexis VIALARET <alexis.vialaret@artefact.com> Date: Thu, 21 Dec 2023 17:23:17 +0100 Subject: [PATCH] feat: streaming through API --- backend/chatbot.py | 81 ++++++------------- backend/main.py | 45 ++++++++--- .../rag_components/chat_message_history.py | 51 +----------- backend/rag_components/llm.py | 11 ++- backend/rag_components/prompts.py | 28 +++++++ database/database_init.sql | 1 - frontend/lib/chat.py | 47 +++++++---- 7 files changed, 125 insertions(+), 139 deletions(-) create mode 100644 backend/rag_components/prompts.py diff --git a/backend/chatbot.py b/backend/chatbot.py index 28c3988..f5e7f91 100644 --- a/backend/chatbot.py +++ b/backend/chatbot.py @@ -1,6 +1,7 @@ +import asyncio from langchain.chains import ConversationalRetrievalChain, LLMChain from langchain.chains.combine_documents.stuff import StuffDocumentsChain -from langchain.chat_models.base import HumanMessage, SystemMessage +from langchain.chat_models.base import SystemMessage from langchain.prompts import ( ChatPromptTemplate, HumanMessagePromptTemplate, @@ -8,81 +9,45 @@ ) from langchain.vectorstores import VectorStore -from backend.config_renderer import get_config -from backend.rag_components.chat_message_history import get_conversation_buffer_memory -from backend.rag_components.embedding import get_embedding_model from backend.rag_components.llm import get_llm_model -from backend.rag_components.vector_store import get_vector_store +from backend.rag_components import prompts -def get_response(answer_chain: ConversationalRetrievalChain, query: str) -> str: - """Processes the given query through the answer chain and returns the formatted response.""" - {"content": answer_chain.run(query), } - return answer_chain.run(query) +async def get_response_stream(chain: ConversationalRetrievalChain, callback_handler, query: str) -> str: + run = asyncio.create_task(chain.arun({"question": query})) + async for token in callback_handler.aiter(): + yield token -def get_answer_chain(llm, docsearch: VectorStore, memory) -> ConversationalRetrievalChain: - """Returns an instance of ConversationalRetrievalChain based on the provided parameters.""" - template = """Given the conversation history and the following question, can you rephrase the user's question in its original language so that it is self-sufficient. Make sure to avoid the use of unclear pronouns. + await run -Chat history : -{chat_history} -Question : {question} -Rephrased question : -""" - condense_question_prompt = PromptTemplate.from_template(template) - condense_question_chain = LLMChain( - llm=llm, - prompt=condense_question_prompt, - ) +def get_answer_chain(config, docsearch: VectorStore, memory) -> ConversationalRetrievalChain: + condense_question_prompt = PromptTemplate.from_template(prompts.condense_history) + condense_question_chain = LLMChain(llm=get_llm_model(config), prompt=condense_question_prompt) messages = [ - SystemMessage( - content=( - """As a chatbot assistant, your mission is to respond to user inquiries in a precise and concise manner based on the documents provided as input. It is essential to respond in the same language in which the question was asked. Responses must be written in a professional style and must demonstrate great attention to detail.""" - ) - ), - HumanMessage(content="Respond to the question taking into account the following context."), - HumanMessagePromptTemplate.from_template("{context}"), - HumanMessagePromptTemplate.from_template("Question: {question}"), + SystemMessage(content=prompts.rag_system_prompt), + HumanMessagePromptTemplate.from_template(prompts.respond_to_question), ] - system_prompt = ChatPromptTemplate(messages=messages) - qa_chain = LLMChain( - llm=llm, - prompt=system_prompt, - ) - - doc_prompt = PromptTemplate( - template="Content: {page_content}\nSource: {source}", - input_variables=["page_content", "source"], - ) + question_answering_prompt = ChatPromptTemplate(messages=messages) + streaming_llm, callback_handler = get_llm_model(config, streaming=True) + question_answering_chain = LLMChain(llm=streaming_llm, prompt=question_answering_prompt) + context_with_docs_prompt = PromptTemplate(template=prompts.document_context, input_variables=["page_content", "source"]) + final_qa_chain = StuffDocumentsChain( - llm_chain=qa_chain, + llm_chain=question_answering_chain, document_variable_name="context", - document_prompt=doc_prompt, + document_prompt=context_with_docs_prompt, ) - return ConversationalRetrievalChain( + chain = ConversationalRetrievalChain( question_generator=condense_question_chain, - retriever=docsearch.as_retriever(search_kwargs={"k": 10}), + retriever=docsearch.as_retriever(search_kwargs={"k": config["vector_store_provider"]["documents_to_retreive"]}), memory=memory, combine_docs_chain=final_qa_chain, - verbose=True, ) + return chain, callback_handler -if __name__ == "__main__": - chat_id = "test" - config = get_config() - llm = get_llm_model(config) - embeddings = get_embedding_model(config) - vector_store = get_vector_store(embeddings, config) - memory = get_conversation_buffer_memory(config, chat_id) - answer_chain = get_answer_chain(llm, vector_store, memory) - - prompt = "Give me the top 5 bilionnaires in france based on their worth in order of decreasing net worth" - response = get_response(answer_chain, prompt) - print("Prompt: ", prompt) - print("Response: ", response) diff --git a/backend/main.py b/backend/main.py index cef545a..209c7e0 100644 --- a/backend/main.py +++ b/backend/main.py @@ -3,9 +3,11 @@ from uuid import uuid4 from fastapi import Depends, FastAPI, HTTPException, status +from fastapi.responses import StreamingResponse from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm from jose import JWTError, jwt + import backend.document_store as document_store from backend.document_store import StorageBackend from backend.model import Doc, Message @@ -20,6 +22,11 @@ user_exists, ) from database.database import Database +from backend.config_renderer import get_config +from backend.rag_components.chat_message_history import get_conversation_buffer_memory +from backend.rag_components.embedding import get_embedding_model +from backend.rag_components.vector_store import get_vector_store +from backend.chatbot import get_answer_chain, get_response_stream app = FastAPI() @@ -121,22 +128,18 @@ async def chat_new(current_user: User = Depends(get_current_user)) -> dict: return {"chat_id": chat_id} -@app.post("/chat/{chat_id}/user_message") -async def chat_prompt(message: Message, current_user: User = Depends(get_current_user)) -> dict: - with Database() as connection: - connection.query( - "INSERT INTO message (id, timestamp, chat_id, sender, content) VALUES (?, ?, ?, ?, ?)", - (message.id, message.timestamp, message.chat_id, message.sender, message.content), - ) +async def streamed_llm_response(chat_id, answer_chain): + full_response = "" + async for data in answer_chain: + full_response += data + yield data.encode("utf-8") - #TODO : faire la réposne du llm - model_response = Message( id=str(uuid4()), timestamp=datetime.now().isoformat(), - chat_id=message.chat_id, + chat_id=chat_id, sender="assistant", - content=f"Unique response: {uuid4()}", + content=full_response, ) with Database() as connection: @@ -150,7 +153,25 @@ async def chat_prompt(message: Message, current_user: User = Depends(get_current model_response.content, ), ) - return {"message": model_response} + + +@app.post("/chat/{chat_id}/user_message") +async def chat_prompt(message: Message, current_user: User = Depends(get_current_user)) -> dict: + with Database() as connection: + connection.query( + "INSERT INTO message (id, timestamp, chat_id, sender, content) VALUES (?, ?, ?, ?, ?)", + (message.id, message.timestamp, message.chat_id, message.sender, message.content), + ) + + config = get_config() + embeddings = get_embedding_model(config) + vector_store = get_vector_store(embeddings, config) + memory = get_conversation_buffer_memory(config, message.chat_id) + answer_chain, callback_handler = get_answer_chain(config, vector_store, memory) + + response_stream = get_response_stream(answer_chain, callback_handler, message.content) + + return StreamingResponse(streamed_llm_response(message.chat_id, response_stream), media_type="text/event-stream") @app.post("/chat/regenerate") diff --git a/backend/rag_components/chat_message_history.py b/backend/rag_components/chat_message_history.py index e6e6774..8db09a2 100644 --- a/backend/rag_components/chat_message_history.py +++ b/backend/rag_components/chat_message_history.py @@ -1,18 +1,7 @@ -import json import os -from datetime import datetime -from typing import Any -from uuid import uuid4 from langchain.memory import ConversationBufferWindowMemory from langchain.memory.chat_message_histories import SQLChatMessageHistory -from langchain.memory.chat_message_histories.sql import DefaultMessageConverter -from langchain.schema import BaseMessage -from langchain.schema.messages import BaseMessage, _message_to_dict -from sqlalchemy import Column, DateTime, Text -from sqlalchemy.orm import declarative_base - -TABLE_NAME = "message" def get_conversation_buffer_memory(config, chat_id): @@ -23,45 +12,9 @@ def get_conversation_buffer_memory(config, chat_id): k=config["chat_message_history_config"]["window_size"], ) - def get_chat_message_history(chat_id): return SQLChatMessageHistory( session_id=chat_id, connection_string=os.environ.get("DATABASE_CONNECTION_STRING"), - table_name=TABLE_NAME, - session_id_field_name="chat_id", - custom_message_converter=CustomMessageConverter(table_name=TABLE_NAME), - ) - - -Base = declarative_base() - - -class CustomMessage(Base): - __tablename__ = TABLE_NAME - - id = Column( - Text, primary_key=True, default=lambda: str(uuid4()) - ) # default=lambda: str(uuid4()) - timestamp = Column(DateTime) - chat_id = Column(Text) - sender = Column(Text) - content = Column(Text) - message = Column(Text) - - -class CustomMessageConverter(DefaultMessageConverter): - def to_sql_model(self, message: BaseMessage, session_id: str) -> Any: - print(message.content) - sub_message = json.loads(message.content) - return CustomMessage( - id=sub_message["id"], - timestamp=datetime.strptime(sub_message["timestamp"], "%Y-%m-%d %H:%M:%S.%f"), - chat_id=session_id, - sender=message.type, - content=sub_message["content"], - message=json.dumps(_message_to_dict(message)), - ) - - def get_sql_model_class(self) -> Any: - return CustomMessage + table_name="message_history", + ) \ No newline at end of file diff --git a/backend/rag_components/llm.py b/backend/rag_components/llm.py index a4a5690..370ac59 100644 --- a/backend/rag_components/llm.py +++ b/backend/rag_components/llm.py @@ -1,10 +1,17 @@ from langchain import chat_models +from langchain.callbacks import AsyncIteratorCallbackHandler -def get_llm_model(config): +def get_llm_model(config, streaming=False): llm_spec = getattr(chat_models, config["llm_model_config"]["model_source"]) all_config_field = {**config["llm_model_config"], **config["llm_provider_config"]} kwargs = { key: value for key, value in all_config_field.items() if key in llm_spec.__fields__.keys() } - return llm_spec(**kwargs) + if streaming: + kwargs["streaming"] = streaming + callback_handler = AsyncIteratorCallbackHandler() + kwargs["callbacks"] = [callback_handler] + return llm_spec(**kwargs), callback_handler + else: + return llm_spec(**kwargs) \ No newline at end of file diff --git a/backend/rag_components/prompts.py b/backend/rag_components/prompts.py new file mode 100644 index 0000000..98af668 --- /dev/null +++ b/backend/rag_components/prompts.py @@ -0,0 +1,28 @@ +condense_history = """ +Given the conversation history and the following question, can you rephrase the user's question in its original language so that it is self-sufficient. Make sure to avoid the use of unclear pronouns. + +Chat history : +{chat_history} +Question : {question} + +Rephrased question : +""" + +rag_system_prompt = """ +As a chatbot assistant, your mission is to respond to user inquiries in a precise and concise manner based on the documents provided as input. +It is essential to respond in the same language in which the question was asked. Responses must be written in a professional style and must demonstrate great attention to detail. +""" + +respond_to_question = """ +Respond to the question taking into account the following context. + +{context} + +Question: {question} +""" + +document_context = """ +Content: {page_content} + +Source: {source} +""" \ No newline at end of file diff --git a/database/database_init.sql b/database/database_init.sql index 8a96b21..eb9f971 100644 --- a/database/database_init.sql +++ b/database/database_init.sql @@ -23,7 +23,6 @@ CREATE TABLE IF NOT EXISTS "message" ( "chat_id" TEXT, "sender" TEXT, "content" TEXT, - "message" TEXT, FOREIGN KEY ("chat_id") REFERENCES "chat" ("id") ); diff --git a/frontend/lib/chat.py b/frontend/lib/chat.py index 3d3d1d6..486151c 100644 --- a/frontend/lib/chat.py +++ b/frontend/lib/chat.py @@ -22,20 +22,35 @@ def __post_init__(self): def chat(): prompt = st.chat_input("Say something") - if prompt: - if len(st.session_state.get("messages", [])) == 0: - chat_id = new_chat() - else: - chat_id = st.session_state.get("chat_id") - - st.session_state.get("messages").append(Message("user", prompt, chat_id)) - response = send_prompt(st.session_state.get("messages")[-1]) - st.session_state.get("messages").append(Message(**response)) - with st.container(border=True): for message in st.session_state.get("messages", []): with st.chat_message(message.sender): st.write(message.content) + + if prompt: + if len(st.session_state.get("messages", [])) == 0: + chat_id = new_chat() + else: + chat_id = st.session_state.get("chat_id") + + with st.chat_message("user"): + st.write(prompt) + + user_message = Message("user", prompt, chat_id) + st.session_state["messages"].append(user_message) + + response = send_prompt(user_message) + with st.chat_message("assistant"): + placeholder = st.empty() + full_response = '' + for item in response: + full_response += item + placeholder.write(full_response) + placeholder.write(full_response) + + bot_message = Message("assistant", full_response, chat_id) + st.session_state["messages"].append(bot_message) + if ( len(st.session_state.get("messages", [])) > 0 and len(st.session_state.get("messages")) % 2 == 0 @@ -43,9 +58,7 @@ def chat(): streamlit_feedback( key=str(len(st.session_state.get("messages"))), feedback_type="thumbs", - on_submit=lambda feedback: send_feedback( - st.session_state.get("messages")[-1].id, feedback - ), + on_submit=lambda feedback: send_feedback(st.session_state.get("messages")[-1].id, feedback), ) @@ -59,10 +72,10 @@ def new_chat(): def send_prompt(message: Message): session = st.session_state.get("session") - response = session.post(f"/chat/{message.chat_id}/user_message", json=asdict(message)) - print(response.headers) - print(response.text) - return response.json()["message"] + response = session.post(f"/chat/{message.chat_id}/user_message", stream=True, json=asdict(message)) + + for line in response.iter_content(chunk_size=16, decode_unicode=True): + yield line def send_feedback(message_id: str, feedback: str):