Commit

feat: streaming through API
AlexisVLRT committed Dec 21, 2023
1 parent 960d377 commit b63ffc2
Showing 7 changed files with 125 additions and 139 deletions.
81 changes: 23 additions & 58 deletions backend/chatbot.py
@@ -1,88 +1,53 @@
import asyncio
from langchain.chains import ConversationalRetrievalChain, LLMChain
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
from langchain.chat_models.base import HumanMessage, SystemMessage
from langchain.chat_models.base import SystemMessage
from langchain.prompts import (
ChatPromptTemplate,
HumanMessagePromptTemplate,
PromptTemplate,
)
from langchain.vectorstores import VectorStore

from backend.config_renderer import get_config
from backend.rag_components.chat_message_history import get_conversation_buffer_memory
from backend.rag_components.embedding import get_embedding_model
from backend.rag_components.llm import get_llm_model
from backend.rag_components.vector_store import get_vector_store
from backend.rag_components import prompts


def get_response(answer_chain: ConversationalRetrievalChain, query: str) -> str:
"""Processes the given query through the answer chain and returns the formatted response."""
{"content": answer_chain.run(query), }
return answer_chain.run(query)
async def get_response_stream(chain: ConversationalRetrievalChain, callback_handler, query: str) -> str:
run = asyncio.create_task(chain.arun({"question": query}))

async for token in callback_handler.aiter():
yield token

def get_answer_chain(llm, docsearch: VectorStore, memory) -> ConversationalRetrievalChain:
"""Returns an instance of ConversationalRetrievalChain based on the provided parameters."""
template = """Given the conversation history and the following question, can you rephrase the user's question in its original language so that it is self-sufficient. Make sure to avoid the use of unclear pronouns.
await run

Chat history :
{chat_history}
Question : {question}

Rephrased question :
"""
condense_question_prompt = PromptTemplate.from_template(template)
condense_question_chain = LLMChain(
llm=llm,
prompt=condense_question_prompt,
)
def get_answer_chain(config, docsearch: VectorStore, memory) -> ConversationalRetrievalChain:
condense_question_prompt = PromptTemplate.from_template(prompts.condense_history)
condense_question_chain = LLMChain(llm=get_llm_model(config), prompt=condense_question_prompt)

messages = [
SystemMessage(
content=(
"""As a chatbot assistant, your mission is to respond to user inquiries in a precise and concise manner based on the documents provided as input. It is essential to respond in the same language in which the question was asked. Responses must be written in a professional style and must demonstrate great attention to detail."""
)
),
HumanMessage(content="Respond to the question taking into account the following context."),
HumanMessagePromptTemplate.from_template("{context}"),
HumanMessagePromptTemplate.from_template("Question: {question}"),
SystemMessage(content=prompts.rag_system_prompt),
HumanMessagePromptTemplate.from_template(prompts.respond_to_question),
]
system_prompt = ChatPromptTemplate(messages=messages)
qa_chain = LLMChain(
llm=llm,
prompt=system_prompt,
)

doc_prompt = PromptTemplate(
template="Content: {page_content}\nSource: {source}",
input_variables=["page_content", "source"],
)
question_answering_prompt = ChatPromptTemplate(messages=messages)
streaming_llm, callback_handler = get_llm_model(config, streaming=True)
question_answering_chain = LLMChain(llm=streaming_llm, prompt=question_answering_prompt)

context_with_docs_prompt = PromptTemplate(template=prompts.document_context, input_variables=["page_content", "source"])

final_qa_chain = StuffDocumentsChain(
llm_chain=qa_chain,
llm_chain=question_answering_chain,
document_variable_name="context",
document_prompt=doc_prompt,
document_prompt=context_with_docs_prompt,
)

return ConversationalRetrievalChain(
chain = ConversationalRetrievalChain(
question_generator=condense_question_chain,
retriever=docsearch.as_retriever(search_kwargs={"k": 10}),
retriever=docsearch.as_retriever(search_kwargs={"k": config["vector_store_provider"]["documents_to_retreive"]}),
memory=memory,
combine_docs_chain=final_qa_chain,
verbose=True,
)

return chain, callback_handler

if __name__ == "__main__":
chat_id = "test"
config = get_config()
llm = get_llm_model(config)
embeddings = get_embedding_model(config)
vector_store = get_vector_store(embeddings, config)
memory = get_conversation_buffer_memory(config, chat_id)
answer_chain = get_answer_chain(llm, vector_store, memory)

prompt = "Give me the top 5 bilionnaires in france based on their worth in order of decreasing net worth"
response = get_response(answer_chain, prompt)
print("Prompt: ", prompt)
print("Response: ", response)
45 changes: 33 additions & 12 deletions backend/main.py
@@ -3,9 +3,11 @@
from uuid import uuid4

from fastapi import Depends, FastAPI, HTTPException, status
from fastapi.responses import StreamingResponse
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from jose import JWTError, jwt


import backend.document_store as document_store
from backend.document_store import StorageBackend
from backend.model import Doc, Message
@@ -20,6 +22,11 @@
user_exists,
)
from database.database import Database
from backend.config_renderer import get_config
from backend.rag_components.chat_message_history import get_conversation_buffer_memory
from backend.rag_components.embedding import get_embedding_model
from backend.rag_components.vector_store import get_vector_store
from backend.chatbot import get_answer_chain, get_response_stream

app = FastAPI()

@@ -121,22 +128,18 @@ async def chat_new(current_user: User = Depends(get_current_user)) -> dict:
return {"chat_id": chat_id}


@app.post("/chat/{chat_id}/user_message")
async def chat_prompt(message: Message, current_user: User = Depends(get_current_user)) -> dict:
with Database() as connection:
connection.query(
"INSERT INTO message (id, timestamp, chat_id, sender, content) VALUES (?, ?, ?, ?, ?)",
(message.id, message.timestamp, message.chat_id, message.sender, message.content),
)
async def streamed_llm_response(chat_id, answer_chain):
full_response = ""
async for data in answer_chain:
full_response += data
yield data.encode("utf-8")

# TODO: generate the LLM response

model_response = Message(
id=str(uuid4()),
timestamp=datetime.now().isoformat(),
chat_id=message.chat_id,
chat_id=chat_id,
sender="assistant",
content=f"Unique response: {uuid4()}",
content=full_response,
)

with Database() as connection:
@@ -150,7 +153,25 @@ async def chat_prompt(message: Message, current_user: User = Depends(get_current
model_response.content,
),
)
return {"message": model_response}


@app.post("/chat/{chat_id}/user_message")
async def chat_prompt(message: Message, current_user: User = Depends(get_current_user)) -> dict:
with Database() as connection:
connection.query(
"INSERT INTO message (id, timestamp, chat_id, sender, content) VALUES (?, ?, ?, ?, ?)",
(message.id, message.timestamp, message.chat_id, message.sender, message.content),
)

config = get_config()
embeddings = get_embedding_model(config)
vector_store = get_vector_store(embeddings, config)
memory = get_conversation_buffer_memory(config, message.chat_id)
answer_chain, callback_handler = get_answer_chain(config, vector_store, memory)

response_stream = get_response_stream(answer_chain, callback_handler, message.content)

return StreamingResponse(streamed_llm_response(message.chat_id, response_stream), media_type="text/event-stream")


@app.post("/chat/regenerate")
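A hedged sketch of a client for the streaming endpoint above. Only the route, the JSON body, and the streamed response follow from the diff; the base URL and authenticated session are assumptions. Note that although the endpoint declares media_type="text/event-stream", the body is raw token text rather than SSE-framed "data:" lines, so a plain chunk iterator is enough.

import requests

def stream_user_message(session: requests.Session, base_url: str, chat_id: str, message: dict):
    # stream=True prevents requests from buffering the whole response body.
    response = session.post(f"{base_url}/chat/{chat_id}/user_message", json=message, stream=True)
    response.raise_for_status()
    # chunk_size=None yields data as soon as the server flushes it.
    for chunk in response.iter_content(chunk_size=None, decode_unicode=True):
        yield chunk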
51 changes: 2 additions & 49 deletions backend/rag_components/chat_message_history.py
@@ -1,18 +1,7 @@
import json
import os
from datetime import datetime
from typing import Any
from uuid import uuid4

from langchain.memory import ConversationBufferWindowMemory
from langchain.memory.chat_message_histories import SQLChatMessageHistory
from langchain.memory.chat_message_histories.sql import DefaultMessageConverter
from langchain.schema import BaseMessage
from langchain.schema.messages import BaseMessage, _message_to_dict
from sqlalchemy import Column, DateTime, Text
from sqlalchemy.orm import declarative_base

TABLE_NAME = "message"


def get_conversation_buffer_memory(config, chat_id):
@@ -23,45 +12,9 @@ def get_conversation_buffer_memory(config, chat_id):
k=config["chat_message_history_config"]["window_size"],
)


def get_chat_message_history(chat_id):
return SQLChatMessageHistory(
session_id=chat_id,
connection_string=os.environ.get("DATABASE_CONNECTION_STRING"),
table_name=TABLE_NAME,
session_id_field_name="chat_id",
custom_message_converter=CustomMessageConverter(table_name=TABLE_NAME),
)


Base = declarative_base()


class CustomMessage(Base):
__tablename__ = TABLE_NAME

id = Column(
Text, primary_key=True, default=lambda: str(uuid4())
) # default=lambda: str(uuid4())
timestamp = Column(DateTime)
chat_id = Column(Text)
sender = Column(Text)
content = Column(Text)
message = Column(Text)


class CustomMessageConverter(DefaultMessageConverter):
def to_sql_model(self, message: BaseMessage, session_id: str) -> Any:
print(message.content)
sub_message = json.loads(message.content)
return CustomMessage(
id=sub_message["id"],
timestamp=datetime.strptime(sub_message["timestamp"], "%Y-%m-%d %H:%M:%S.%f"),
chat_id=session_id,
sender=message.type,
content=sub_message["content"],
message=json.dumps(_message_to_dict(message)),
)

def get_sql_model_class(self) -> Any:
return CustomMessage
table_name="message_history",
)
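Only the tail of get_conversation_buffer_memory is visible in the diff; a sketch of the full wiring under stated assumptions (memory_key and return_messages are guesses, and only the k=... line is confirmed above):

from langchain.memory import ConversationBufferWindowMemory

def get_conversation_buffer_memory(config, chat_id):
    return ConversationBufferWindowMemory(
        memory_key="chat_history",  # assumed; must match {chat_history} in the condense prompt
        chat_memory=get_chat_message_history(chat_id),
        return_messages=True,  # assumed
        k=config["chat_message_history_config"]["window_size"],  # confirmed by the diff
    )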
11 changes: 9 additions & 2 deletions backend/rag_components/llm.py
@@ -1,10 +1,17 @@
from langchain import chat_models
from langchain.callbacks import AsyncIteratorCallbackHandler


def get_llm_model(config):
def get_llm_model(config, streaming=False):
llm_spec = getattr(chat_models, config["llm_model_config"]["model_source"])
all_config_field = {**config["llm_model_config"], **config["llm_provider_config"]}
kwargs = {
key: value for key, value in all_config_field.items() if key in llm_spec.__fields__.keys()
}
return llm_spec(**kwargs)
if streaming:
kwargs["streaming"] = streaming
callback_handler = AsyncIteratorCallbackHandler()
kwargs["callbacks"] = [callback_handler]
return llm_spec(**kwargs), callback_handler
else:
return llm_spec(**kwargs)
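For intuition, a simplified stand-in for what AsyncIteratorCallbackHandler does (this is not LangChain's actual implementation): on_llm_new_token pushes tokens into an asyncio queue, and aiter() drains that queue until generation ends. Note the design choice above: get_llm_model returns a (model, handler) tuple when streaming=True and a bare model otherwise, so callers must know which shape to expect.

import asyncio

class IteratorCallbackSketch:
    """Simplified illustration of the token-queue pattern, not the real handler."""

    def __init__(self):
        self.queue: asyncio.Queue = asyncio.Queue()
        self.done = asyncio.Event()

    async def on_llm_new_token(self, token: str, **kwargs) -> None:
        # Called by the LLM for every generated token.
        self.queue.put_nowait(token)

    async def on_llm_end(self, *args, **kwargs) -> None:
        self.done.set()

    async def aiter(self):
        # Yield queued tokens until generation has finished and the queue is empty.
        while not (self.done.is_set() and self.queue.empty()):
            try:
                yield await asyncio.wait_for(self.queue.get(), timeout=0.1)
            except asyncio.TimeoutError:
                continue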
28 changes: 28 additions & 0 deletions backend/rag_components/prompts.py
@@ -0,0 +1,28 @@
condense_history = """
Given the conversation history and the following question, can you rephrase the user's question in its original language so that it is self-sufficient. Make sure to avoid the use of unclear pronouns.
Chat history :
{chat_history}
Question : {question}
Rephrased question :
"""

rag_system_prompt = """
As a chatbot assistant, your mission is to respond to user inquiries in a precise and concise manner based on the documents provided as input.
It is essential to respond in the same language in which the question was asked. Responses must be written in a professional style and must demonstrate great attention to detail.
"""

respond_to_question = """
Respond to the question taking into account the following context.
{context}
Question: {question}
"""

document_context = """
Content: {page_content}
Source: {source}
"""
1 change: 0 additions & 1 deletion database/database_init.sql
@@ -23,7 +23,6 @@ CREATE TABLE IF NOT EXISTS "message" (
"chat_id" TEXT,
"sender" TEXT,
"content" TEXT,
"message" TEXT,
FOREIGN KEY ("chat_id") REFERENCES "chat" ("id")
);

47 changes: 30 additions & 17 deletions frontend/lib/chat.py
@@ -22,30 +22,43 @@ def __post_init__(self):
def chat():
prompt = st.chat_input("Say something")

if prompt:
if len(st.session_state.get("messages", [])) == 0:
chat_id = new_chat()
else:
chat_id = st.session_state.get("chat_id")

st.session_state.get("messages").append(Message("user", prompt, chat_id))
response = send_prompt(st.session_state.get("messages")[-1])
st.session_state.get("messages").append(Message(**response))

with st.container(border=True):
for message in st.session_state.get("messages", []):
with st.chat_message(message.sender):
st.write(message.content)

if prompt:
if len(st.session_state.get("messages", [])) == 0:
chat_id = new_chat()
else:
chat_id = st.session_state.get("chat_id")

with st.chat_message("user"):
st.write(prompt)

user_message = Message("user", prompt, chat_id)
st.session_state["messages"].append(user_message)

response = send_prompt(user_message)
with st.chat_message("assistant"):
placeholder = st.empty()
full_response = ''
for item in response:
full_response += item
placeholder.write(full_response)
placeholder.write(full_response)

bot_message = Message("assistant", full_response, chat_id)
st.session_state["messages"].append(bot_message)

if (
len(st.session_state.get("messages", [])) > 0
and len(st.session_state.get("messages")) % 2 == 0
):
streamlit_feedback(
key=str(len(st.session_state.get("messages"))),
feedback_type="thumbs",
on_submit=lambda feedback: send_feedback(
st.session_state.get("messages")[-1].id, feedback
),
on_submit=lambda feedback: send_feedback(st.session_state.get("messages")[-1].id, feedback),
)


@@ -59,10 +72,10 @@ def new_chat():

def send_prompt(message: Message):
session = st.session_state.get("session")
response = session.post(f"/chat/{message.chat_id}/user_message", json=asdict(message))
print(response.headers)
print(response.text)
return response.json()["message"]
response = session.post(f"/chat/{message.chat_id}/user_message", stream=True, json=asdict(message))

for line in response.iter_content(chunk_size=16, decode_unicode=True):
yield line


def send_feedback(message_id: str, feedback: str):
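A self-contained sketch of the incremental-render pattern used above, with a stub generator standing in for the real API stream. The chunk_size=16 in send_prompt trades throughput for latency; requests decodes incrementally, so multi-byte characters split across chunks should still be handled correctly.

import time
import streamlit as st

def fake_stream():
    # Stand-in for send_prompt(); yields text chunks with a small delay.
    for word in "streamed tokens appear as they arrive".split():
        yield word + " "
        time.sleep(0.05)

placeholder = st.empty()
full_response = ""
for chunk in fake_stream():
    full_response += chunk
    placeholder.write(full_response)
placeholder.write(full_response)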
