Commit d042167
upd: basic model picking
AlexisVLRT committed Dec 19, 2023
1 parent df1ca03 commit d042167
Showing 27 changed files with 2,959 additions and 347 deletions.
File renamed without changes.
File renamed without changes.
5 changes: 2 additions & 3 deletions lib/backend.py → backend/backend_.py
@@ -3,7 +3,6 @@
import streamlit as st
from langchain.chat_models import AzureChatOpenAI
from langchain.document_loaders import (
BaseLoader,
CSVLoader,
Docx2txtLoader,
PyPDFLoader,
@@ -31,7 +30,7 @@ def get_llm(
"""Returns an instance of AzureChatOpenAI based on the provided parameters."""
if model_version == "4":
llm = AzureChatOpenAI(
deployment_name="gpt-4",
deployment_name="gpt4v",
temperature=temperature,
openai_api_version="2023-07-01-preview",
streaming=live_streaming,
@@ -59,7 +58,7 @@ def get_embeddings_model(embedding_api_base: str, embedding_api_key: str) -> OpenAIEmbeddings:
)


def load_documents(file_extension: str, file_path: str) -> BaseLoader:
def load_documents(file_extension: str, file_path: str):
"""Loads documents based on the file extension and path provided."""
if file_extension == ".pdf":
loader = PyPDFLoader(file_path)
2,641 changes: 2,641 additions & 0 deletions backend/billionaires_csv.csv

Large diffs are not rendered by default.

180 changes: 82 additions & 98 deletions interface/lib/chatbot.py → backend/chatbot.py
@@ -1,7 +1,7 @@
from pathlib import Path
from typing import List, Optional, Tuple, Union

import pandas as pd
import streamlit as st
from langchain.chains import ConversationalRetrievalChain, LLMChain
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
from langchain.chat_models import AzureChatOpenAI
@@ -19,7 +19,7 @@
ConversationSummaryBufferMemory,
ConversationSummaryMemory,
)
from langchain.memory.chat_message_histories import StreamlitChatMessageHistory
from langchain.memory.chat_message_histories import ChatMessageHistory
from langchain.prompts import (
ChatPromptTemplate,
HumanMessagePromptTemplate,
@@ -33,34 +33,73 @@
)
from langchain.vectorstores import Chroma

from interface.lib.logs import StreamHandler


@st.cache_resource
def get_llm(
temperature: float, model_version: str, live_streaming: bool = False
) -> AzureChatOpenAI:
"""Returns an instance of AzureChatOpenAI based on the provided parameters."""
if model_version == "4":
llm = AzureChatOpenAI(
deployment_name="gpt-4",
temperature=temperature,
openai_api_version="2023-07-01-preview",
streaming=live_streaming,
verbose=live_streaming,
)
elif model_version == "3.5":
llm = AzureChatOpenAI(
deployment_name="gpt-35-turbo",
temperature=temperature,
openai_api_version="2023-03-15-preview",
streaming=live_streaming,
verbose=live_streaming,
)
return llm
from backend.llm import get_model_instance

def get_response(answer_chain: ConversationalRetrievalChain, query: str) -> str:
"""Processes the given query through the answer chain and returns the formatted response."""
return answer_chain.run(query)

def get_answer_chain(
llm, docsearch: Chroma, memory: ConversationBufferMemory
) -> ConversationalRetrievalChain:
"""Returns an instance of ConversationalRetrievalChain based on the provided parameters."""
template = """Étant donné l'historique de conversation et la question suivante, \
pouvez-vous reformuler dans sa langue d'origine la question de l'utilisateur \
pour qu'elle soit auto porteuse. Assurez-vous d'éviter l'utilisation de pronoms peu clairs.
Historique de chat :
{chat_history}
Question complémentaire : {question}
Question reformulée :
"""
condense_question_prompt = PromptTemplate.from_template(template)
condense_question_chain = LLMChain(
llm=llm,
prompt=condense_question_prompt,
)

messages = [
SystemMessage(
content=(
"""En tant qu'assistant chatbot, votre mission est de répondre de manière \
précise et concise aux interrogations des utilisateurs à partir des documents donnés en input.
Il est essentiel de répondre dans la même langue que celle utilisée pour poser la question.
Les réponses doivent être rédigées dans un style professionnel et doivent faire preuve \
d'une grande attention aux détails.
"""
)
),
HumanMessage(content="Répondez à la question en prenant en compte le contexte suivant"),
HumanMessagePromptTemplate.from_template("{context}"),
HumanMessagePromptTemplate.from_template("Question: {question}"),
]
system_prompt = ChatPromptTemplate(messages=messages)
qa_chain = LLMChain(
llm=llm,
prompt=system_prompt,
)

doc_prompt = PromptTemplate(
template="Content: {page_content}\nSource: {source}",
input_variables=["page_content", "source"],
)

final_qa_chain = StuffDocumentsChain(
llm_chain=qa_chain,
document_variable_name="context",
document_prompt=doc_prompt,
)

return ConversationalRetrievalChain(
question_generator=condense_question_chain,
retriever=docsearch.as_retriever(),
memory=memory,
combine_docs_chain=final_qa_chain,
verbose=False,
)


@st.cache_resource
def get_embeddings_model(embedding_api_base: str, embedding_api_key: str) -> OpenAIEmbeddings:
"""Returns an instance of OpenAIEmbeddings based on the provided parameters."""
return OpenAIEmbeddings(
@@ -84,7 +123,6 @@ def get_documents(data: pd.DataFrame) -> List[Document]:
return documents


@st.cache_data
def load_documents(file_extension: str, file_path: str):
"""Loads documents based on the file extension and path provided."""
if file_extension == ".pdf":
@@ -98,12 +136,11 @@ def load_documents(file_extension: str, file_path: str):
elif file_extension in [".docx"]:
loader = Docx2txtLoader(file_path)
else:
st.error("Unsupported file type!")
raise Exception("Unsupported file type!")

return loader.load()


@st.cache_data
def get_chunks(
_documents: List[str], chunk_size: int, chunk_overlap: int, text_splitter_type: int
) -> List[str]:
@@ -121,17 +158,15 @@ def get_chunks(
return text_splitter.split_documents(_documents)


@st.cache_resource
def get_vector_store(_texts: List[str], _embeddings: OpenAIEmbeddings) -> Chroma:
"""Returns an instance of Chroma based on the provided parameters."""
return Chroma.from_documents(_texts, _embeddings)


@st.cache_data
def choose_memory_type(
memory_type: str, _llm: Optional[AzureChatOpenAI] = None
) -> Tuple[
StreamlitChatMessageHistory,
ChatMessageHistory,
Union[
ConversationBufferMemory,
ConversationBufferWindowMemory,
Expand All @@ -140,7 +175,7 @@ def choose_memory_type(
],
]:
"""Chooses the memory type for the conversation based on the provided memory_type string."""
msgs = StreamlitChatMessageHistory(key="special_app_key")
msgs = ChatMessageHistory(key="special_app_key")
if memory_type == "buffer":
memory = ConversationBufferMemory(
memory_key="chat_history", chat_memory=msgs, return_messages=True
Expand All @@ -163,69 +198,18 @@ def choose_memory_type(
)
return msgs, memory

if __name__ == "__main__":
llm = get_model_instance()

def get_answer_chain(
llm: AzureChatOpenAI, docsearch: Chroma, memory: ConversationBufferMemory
) -> ConversationalRetrievalChain:
"""Returns an instance of ConversationalRetrievalChain based on the provided parameters."""
template = """Étant donné l'historique de conversation et la question suivante, \
pouvez-vous reformuler dans sa langue d'origine la question de l'utilisateur \
pour qu'elle soit auto porteuse. Assurez-vous d'éviter l'utilisation de pronoms peu clairs.
Historique de chat :
{chat_history}
Question complémentaire : {question}
Question reformulée :
"""
condense_question_prompt = PromptTemplate.from_template(template)
condense_question_chain = LLMChain(
llm=llm,
prompt=condense_question_prompt,
)

messages = [
SystemMessage(
content=(
"""En tant qu'assistant chatbot, votre mission est de répondre de manière \
précise et concise aux interrogations des utilisateurs à partir des documents donnés en input.
Il est essentiel de répondre dans la même langue que celle utilisée pour poser la question.
Les réponses doivent être rédigées dans un style professionnel et doivent faire preuve \
d'une grande attention aux détails.
"""
)
),
HumanMessage(content="Répondez à la question en prenant en compte le contexte suivant"),
HumanMessagePromptTemplate.from_template("{context}"),
HumanMessagePromptTemplate.from_template("Question: {question}"),
]
system_prompt = ChatPromptTemplate(messages=messages)
qa_chain = LLMChain(
llm=llm,
prompt=system_prompt,
)

doc_prompt = PromptTemplate(
template="Content: {page_content}\nSource: {source}",
input_variables=["page_content", "source"],
)

final_qa_chain = StuffDocumentsChain(
llm_chain=qa_chain,
document_variable_name="context",
document_prompt=doc_prompt,
)
embeddings = get_embeddings_model("https://poc-openai-artefact.openai.azure.com/", "")

return ConversationalRetrievalChain(
question_generator=condense_question_chain,
retriever=docsearch.as_retriever(),
memory=memory,
combine_docs_chain=final_qa_chain,
verbose=False,
)
documents = load_documents(".csv", str(Path(__file__).parent / "billionaires_csv.csv"))
texts = get_chunks(documents, chunk_size=1500, chunk_overlap=200, text_splitter_type="recursive")
docsearch = get_vector_store(texts, embeddings)
msgs, memory = choose_memory_type(memory_type="buffer")
answer_chain = get_answer_chain(llm, docsearch, memory)


def get_response(answer_chain: ConversationalRetrievalChain, query: str) -> str:
"""Processes the given query through the answer chain and returns the formatted response."""
stream_handler = StreamHandler(st.empty())
return answer_chain.run(query, callbacks=[stream_handler])
prompt = "Qui sont les 3 personnes les plus riches en france ?"
response = get_response(answer_chain, prompt)
print("Prompt :", prompt)
print("Response: ", response)
File renamed without changes.
17 changes: 17 additions & 0 deletions backend/llm.py
@@ -0,0 +1,17 @@
from langchain import chat_models
from pathlib import Path
import yaml

def get_model_instance():
config = load_models_config()
llm_spec = getattr(chat_models, config["llm_model_config"]["model_source"])
all_config_field = {**config["llm_model_config"], **config["llm_provider_config"]}
kwargs = {key: value for key, value in all_config_field.items() if key in llm_spec.__fields__.keys()}
return llm_spec(**kwargs)


def load_models_config():
with open(Path(__file__).parent / "models_config.yaml", "r") as file:
return yaml.safe_load(file)


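The new backend/llm.py is the core of the "model picking" change: load_models_config reads backend/models_config.yaml, and get_model_instance looks up the class named by model_source in langchain.chat_models, merges the llm_model_config and llm_provider_config sections, and keeps only the keys that match the class's pydantic fields before instantiating it. A minimal usage sketch (not part of the commit; it assumes the YAML sits next to llm.py, that openai_api_key has been filled in, and that the installed langchain version still exposes the predict() helper on chat models):

# Illustrative only -- not part of the commit. Assumes backend/models_config.yaml
# is populated (openai_api_key is empty in the committed file).
from backend.llm import get_model_instance, load_models_config

config = load_models_config()
print(config["llm_model_config"]["model_source"])  # "AzureChatOpenAI"

llm = get_model_instance()  # builds the chat model described by the YAML
print(llm.predict("Say hello in one short sentence."))

Because the merged kwargs are filtered against llm_spec.__fields__, provider-level keys such as openai_api_base and model-level keys such as deployment_name can live in separate YAML sections, and non-constructor keys like model_source are silently dropped instead of raising an unknown-argument error.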
17 changes: 9 additions & 8 deletions lib/main.py → backend/main.py
@@ -6,11 +6,11 @@
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from jose import JWTError, jwt

import lib.document_store as document_store
import backend.document_store as document_store
from database.database import Database
from lib.document_store import StorageBackend
from lib.model import Doc
from lib.user_management import (
from backend.document_store import StorageBackend
from backend.model import Doc
from backend.user_management import (
ALGORITHM,
SECRET_KEY,
User,
@@ -116,8 +116,11 @@ async def chat_new(current_user: User = Depends(get_current_user)) -> dict:
# P1
@app.post("/chat/user_message")
async def chat_prompt(current_user: User = Depends(get_current_user)) -> dict:
"""Send a message in a chat session."""
pass
# TODO: Log message to db
# TODO: Get response from model
# TODO: Log response to db
# TODO: Return response
return {"message": f"Unique response: {uuid4()}"}


@app.post("/chat/regenerate")
@@ -147,7 +150,6 @@ async def chat(chat_id: str, current_user: User = Depends(get_current_user)) ->
async def feedback_thumbs_up(
message_id: str, current_user: User = Depends(get_current_user)
) -> None:
"""Record a 'thumbs up' feedback for a message."""
with Database() as connection:
connection.query(
"INSERT INTO feedback (id, message_id, feedback) VALUES (?, ?, ?)",
@@ -159,7 +161,6 @@ def feedback_thumbs_up(
async def feedback_thumbs_down(
message_id: str, current_user: User = Depends(get_current_user)
) -> None:
"""Record a 'thumbs down' feedback for a message."""
with Database() as connection:
connection.query(
"INSERT INTO feedback (id, message_id, feedback) VALUES (?, ?, ?)",
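The chat_prompt endpoint above is still a placeholder: it returns a unique string and leaves logging and the model call as TODOs. One possible fill-in, purely a sketch, reusing the Database pattern from the feedback endpoints and get_response from backend.chatbot; the request model, the message table and its columns, and the module-level answer_chain are all hypothetical names introduced here for illustration:

# Hypothetical sketch of the TODOs in chat_prompt -- not in this commit.
# ChatMessage and the `message` table/columns are invented for illustration;
# `answer_chain` is assumed to be built once at application startup.
from uuid import uuid4

from pydantic import BaseModel


class ChatMessage(BaseModel):
    chat_id: str
    content: str


@app.post("/chat/user_message")
async def chat_prompt(
    message: ChatMessage, current_user: User = Depends(get_current_user)
) -> dict:
    """Send a message in a chat session and return the model's answer."""
    with Database() as connection:
        connection.query(  # TODO from the diff: log message to db
            "INSERT INTO message (id, chat_id, sender, content) VALUES (?, ?, ?, ?)",
            (str(uuid4()), message.chat_id, "user", message.content),
        )
    response = get_response(answer_chain, message.content)  # get response from model
    with Database() as connection:
        connection.query(  # log response to db
            "INSERT INTO message (id, chat_id, sender, content) VALUES (?, ?, ?, ?)",
            (str(uuid4()), message.chat_id, "assistant", response),
        )
    return {"message": response}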
File renamed without changes.
21 changes: 21 additions & 0 deletions backend/models_config.yaml
@@ -0,0 +1,21 @@
llm_provider_config:
  openai_api_type: azure
  openai_api_base: https://poc-genai-gpt4.openai.azure.com/
  openai_api_version: 2023-07-01-preview
  openai_api_key:

llm_model_config:
  model_source: AzureChatOpenAI
  deployment_name: gpt4v
  temperature: 0.1
  streaming: true
  verbose: true

embedding_provider_config:
  openai_api_base: "https://poc-openai-artefact.openai.azure.com/"
  openai_api_key:

embedding_model_config:
  model_source: OpenAIEmbeddings
  deployment: text-embedding-ada-002
  chunk_size: 16
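Only the llm_* sections are consumed in this commit (by get_model_instance); the embedding_* sections are declared but not yet wired up. The same lookup-and-filter pattern would extend to them naturally. A sketch only, not part of the change, assuming langchain.embeddings exposes the class named in model_source:

# Sketch of a companion loader for the embedding_* sections -- not in this commit.
from langchain import embeddings

from backend.llm import load_models_config


def get_embedding_model_instance():
    config = load_models_config()
    spec = getattr(embeddings, config["embedding_model_config"]["model_source"])
    fields = {**config["embedding_model_config"], **config["embedding_provider_config"]}
    kwargs = {key: value for key, value in fields.items() if key in spec.__fields__}
    return spec(**kwargs)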
File renamed without changes.
Empty file removed client/main.py
Binary file modified database/database.sqlite
Binary file not shown.