From 0ed297a82c179e131d27eab32164e7deed729024 Mon Sep 17 00:00:00 2001
From: Sarah LAUZERAL
Date: Thu, 21 Dec 2023 10:11:20 +0100
Subject: [PATCH 1/5] update comments

---
 README.md                    |  2 ++
 backend/chatbot.py           |  4 +++-
 backend/document_loader.py   |  5 ++++-
 backend/main.py              |  4 ++--
 notebooks/memory_tests.ipynb | 35 ++++++++++++-----------------------
 5 files changed, 23 insertions(+), 27 deletions(-)

diff --git a/README.md b/README.md
index 49d1697..ac80bd0 100644
--- a/README.md
+++ b/README.md
@@ -14,5 +14,7 @@ python "/Users/sarah.lauzeral/Library/CloudStorage/GoogleDrive-sarah.lauzeral@ar
 - comment lancer l'API
 - gestion de la config
 - écrire des helpers de co, pour envoyer des messages...
+- test different models
+- write code snippets explaining how to load documents into the RAG

diff --git a/backend/chatbot.py b/backend/chatbot.py
index 28c3988..ca1d6c0 100644
--- a/backend/chatbot.py
+++ b/backend/chatbot.py
@@ -17,7 +17,9 @@
 
 def get_response(answer_chain: ConversationalRetrievalChain, query: str) -> str:
     """Processes the given query through the answer chain and returns the formatted response."""
-    {"content": answer_chain.run(query), }
+    {
+        "content": answer_chain.run(query),
+    }
     return answer_chain.run(query)
 
diff --git a/backend/document_loader.py b/backend/document_loader.py
index a960a7f..f3dc061 100644
--- a/backend/document_loader.py
+++ b/backend/document_loader.py
@@ -9,11 +9,14 @@
 from langchain.vectorstores import VectorStore
 from langchain.vectorstores.utils import filter_complex_metadata
 
+# TODO rajhouter fonction level avec opssibilité de rajouter l'objet document de langchain
+# TODO rename load_dopcument en load_embeddded document
+
 
 def load_document(file_path: Path, llm: BaseChatModel, vector_store: VectorStore):
     documents = get_documents(file_path, llm)
     filtered_documents = filter_complex_metadata(documents)
-    vector_store.add_documents(documents)
+    vector_store.add_documents(filtered_documents)
 
 
 def get_documents(file_path: Path, llm: BaseChatModel):
diff --git a/backend/main.py b/backend/main.py
index cef545a..f15dde4 100644
--- a/backend/main.py
+++ b/backend/main.py
@@ -129,8 +129,8 @@ async def chat_prompt(message: Message, current_user: User = Depends(get_current
         (message.id, message.timestamp, message.chat_id, message.sender, message.content),
     )
 
-    #TODO : faire la réposne du llm
-
+    # TODO : faire la réposne du llm
+
     model_response = Message(
         id=str(uuid4()),
         timestamp=datetime.now().isoformat(),
diff --git a/notebooks/memory_tests.ipynb b/notebooks/memory_tests.ipynb
index a9e7443..520cf4d 100644
--- a/notebooks/memory_tests.ipynb
+++ b/notebooks/memory_tests.ipynb
@@ -2,7 +2,7 @@
  "cells": [
   {
    "cell_type": "code",
-   "execution_count": 1,
+   "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -12,7 +12,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 2,
+   "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -26,7 +26,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 3,
+   "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -46,7 +46,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 4,
+   "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -104,7 +104,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 5,
+   "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -116,7 +116,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 6,
+   "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -126,7 +126,7 @@
   },
   {
    "cell_type": "code",
-   
"execution_count": 8, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -142,7 +142,7 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -189,7 +189,7 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -204,7 +204,7 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -213,20 +213,9 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "[AIMessage(content='{\"content\": \"Hi\", \"timestamp\": \"2023-12-20 16:26:23.672506\", \"id\": \"764528762\"}')]" - ] - }, - "execution_count": 23, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "chat_message_history.messages" ] From bc8d6e5dfa4f69f6e4fd7eca778a2f66f0fe0f04 Mon Sep 17 00:00:00 2001 From: Sarah LAUZERAL Date: Thu, 21 Dec 2023 11:05:40 +0100 Subject: [PATCH 2/5] vO bot fonctionnel --- backend/chatbot.py | 3 -- backend/document_loader.py | 8 ++-- backend/main.py | 23 +++++++++++- .../rag_components/chat_message_history.py | 37 +------------------ 4 files changed, 25 insertions(+), 46 deletions(-) diff --git a/backend/chatbot.py b/backend/chatbot.py index ca1d6c0..937a68b 100644 --- a/backend/chatbot.py +++ b/backend/chatbot.py @@ -17,9 +17,6 @@ def get_response(answer_chain: ConversationalRetrievalChain, query: str) -> str: """Processes the given query through the answer chain and returns the formatted response.""" - { - "content": answer_chain.run(query), - } return answer_chain.run(query) diff --git a/backend/document_loader.py b/backend/document_loader.py index f3dc061..218fe4a 100644 --- a/backend/document_loader.py +++ b/backend/document_loader.py @@ -9,8 +9,8 @@ from langchain.vectorstores import VectorStore from langchain.vectorstores.utils import filter_complex_metadata -# TODO rajhouter fonction level avec opssibilité de rajouter l'objet document de langchain -# TODO rename load_dopcument en load_embeddded document +# TODO rajouter top level fonction avec possibilité de rajouter l'objet document de langchain +# TODO rename load_document en load_embedded_document def load_document(file_path: Path, llm: BaseChatModel, vector_store: VectorStore): @@ -78,9 +78,7 @@ def get_loaders() -> List[str]: vector_store = get_vector_store(embeddings) document = load_document( - file_path=Path( - "/Users/alexis.vialaret/vscode_projects/skaff-rag-accelerator/data/billionaires_csv.csv" - ), + file_path=Path(f"{Path(__file__).parent}/data/billionaires_csv.csv"), llm=llm, vector_store=vector_store, ) diff --git a/backend/main.py b/backend/main.py index f15dde4..92b9deb 100644 --- a/backend/main.py +++ b/backend/main.py @@ -1,4 +1,5 @@ from datetime import datetime, timedelta +from pathlib import Path from typing import List from uuid import uuid4 @@ -7,8 +8,15 @@ from jose import JWTError, jwt import backend.document_store as document_store +from backend.chatbot import get_answer_chain, get_response +from backend.config_renderer import get_config +from backend.document_loader import load_document from backend.document_store import StorageBackend from backend.model import Doc, Message +from backend.rag_components.chat_message_history import get_conversation_buffer_memory +from backend.rag_components.embedding import get_embedding_model +from backend.rag_components.llm import 
get_llm_model +from backend.rag_components.vector_store import get_vector_store from backend.user_management import ( ALGORITHM, SECRET_KEY, @@ -129,14 +137,25 @@ async def chat_prompt(message: Message, current_user: User = Depends(get_current (message.id, message.timestamp, message.chat_id, message.sender, message.content), ) - # TODO : faire la réposne du llm + config = get_config() + llm = get_llm_model(config) + embeddings = get_embedding_model(config) + vector_store = get_vector_store(embeddings, config) + document = load_document( + file_path=Path(f"{Path(__file__).parent.parent}/data/billionaires_csv.csv"), + llm=llm, + vector_store=vector_store, + ) + memory = get_conversation_buffer_memory(config, message.chat_id) + answer_chain = get_answer_chain(llm, vector_store, memory) + response = get_response(answer_chain, message.content) model_response = Message( id=str(uuid4()), timestamp=datetime.now().isoformat(), chat_id=message.chat_id, sender="assistant", - content=f"Unique response: {uuid4()}", + content=response, ) with Database() as connection: diff --git a/backend/rag_components/chat_message_history.py b/backend/rag_components/chat_message_history.py index e6e6774..1f86f6d 100644 --- a/backend/rag_components/chat_message_history.py +++ b/backend/rag_components/chat_message_history.py @@ -12,7 +12,7 @@ from sqlalchemy import Column, DateTime, Text from sqlalchemy.orm import declarative_base -TABLE_NAME = "message" +TABLE_NAME = "message_history" def get_conversation_buffer_memory(config, chat_id): @@ -29,39 +29,4 @@ def get_chat_message_history(chat_id): session_id=chat_id, connection_string=os.environ.get("DATABASE_CONNECTION_STRING"), table_name=TABLE_NAME, - session_id_field_name="chat_id", - custom_message_converter=CustomMessageConverter(table_name=TABLE_NAME), ) - - -Base = declarative_base() - - -class CustomMessage(Base): - __tablename__ = TABLE_NAME - - id = Column( - Text, primary_key=True, default=lambda: str(uuid4()) - ) # default=lambda: str(uuid4()) - timestamp = Column(DateTime) - chat_id = Column(Text) - sender = Column(Text) - content = Column(Text) - message = Column(Text) - - -class CustomMessageConverter(DefaultMessageConverter): - def to_sql_model(self, message: BaseMessage, session_id: str) -> Any: - print(message.content) - sub_message = json.loads(message.content) - return CustomMessage( - id=sub_message["id"], - timestamp=datetime.strptime(sub_message["timestamp"], "%Y-%m-%d %H:%M:%S.%f"), - chat_id=session_id, - sender=message.type, - content=sub_message["content"], - message=json.dumps(_message_to_dict(message)), - ) - - def get_sql_model_class(self) -> Any: - return CustomMessage From 6a3ad7ae58d7c822e8826f0d6b99d7098ffd0cac Mon Sep 17 00:00:00 2001 From: Sarah LAUZERAL Date: Thu, 21 Dec 2023 14:37:14 +0100 Subject: [PATCH 3/5] loading langchain document --- backend/_logs.py | 66 ------- backend/main.py | 38 ++--- .../rag_components/chat_message_history.py | 9 - .../{ => rag_components}/document_loader.py | 56 +++--- notebooks/docs_loader.ipynb | 119 +++++++++++++ notebooks/test_sla.ipynb | 161 ------------------ pyproject.toml | 2 +- 7 files changed, 173 insertions(+), 278 deletions(-) delete mode 100644 backend/_logs.py rename backend/{ => rag_components}/document_loader.py (55%) create mode 100644 notebooks/docs_loader.ipynb delete mode 100644 notebooks/test_sla.ipynb diff --git a/backend/_logs.py b/backend/_logs.py deleted file mode 100644 index 9451968..0000000 --- a/backend/_logs.py +++ /dev/null @@ -1,66 +0,0 @@ -import os -from typing import 
Any, Dict, List, Sequence - -import streamlit as st -from langchain.callbacks.base import BaseCallbackHandler -from langchain.schema.document import Document - - -class StreamHandler(BaseCallbackHandler): - """StreamHandler is a class that handles the streaming of text. - - It is a callback handler for a language model. \ - It displays the generated text in a Streamlit container \ - and handles the start of the language model and the generation of new tokens. - """ - - def __init__( - self, container: st.delta_generator.DeltaGenerator, initial_text: str = "" - ) -> None: - """Initialize the StreamHandler.""" - self.container = container - self.text = initial_text - self.run_id_ignore_token = None - - def on_llm_start( - self, serialized: dict, prompts: List[str], **kwargs: Dict[str, Any] # noqa: ARG002 - ) -> None: - """Handle the start of the language model.""" - if "Question reformulée :" in prompts[0]: - self.run_id_ignore_token = kwargs.get("run_id") - - def on_llm_new_token(self, token: str, **kwargs: Dict[str, Any]) -> None: - """Handle the generation of a new token by the language model.""" - if self.run_id_ignore_token == kwargs.get("run_id", False): - return - self.text += token - self.container.markdown(self.text) - - -class PrintRetrievalHandler(BaseCallbackHandler): - """PrintRetrievalHandler is a class that handles the retrieval of documents. - - It is a callback handler for a document retriever. \ - It displays the status and content of the retrieved documents in a Streamlit container. - """ - - def __init__(self, container: st.delta_generator.DeltaGenerator) -> None: - """Initialize the PrintRetrievalHandler.""" - self.status = container.status("**Context Retrieval**") - - def on_retriever_start( - self, serialized: Dict[str, Any], query: str, **kwargs: Dict[str, Any] # noqa: ARG002 - ) -> None: - """Handle the start of the document retrieval.""" - self.status.write(f"**Question:** {query}") - self.status.update(label=f"**Context Retrieval:** {query}") - - def on_retriever_end( - self, documents: Sequence[Document], **kwargs: Dict[str, Any] # noqa: ARG002 - ) -> None: - """Handle the end of the document retrieval.""" - for idx, doc in enumerate(documents): - source = os.path.basename(doc.metadata["source"]) # noqa: PTH119 - self.status.write(f"**Document {idx} from {source}**") - self.status.markdown(doc.page_content) - self.status.update(state="complete") diff --git a/backend/main.py b/backend/main.py index 92b9deb..4b67077 100644 --- a/backend/main.py +++ b/backend/main.py @@ -7,16 +7,10 @@ from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm from jose import JWTError, jwt -import backend.document_store as document_store -from backend.chatbot import get_answer_chain, get_response from backend.config_renderer import get_config -from backend.document_loader import load_document -from backend.document_store import StorageBackend +from backend.document_store import StorageBackend, store_documents from backend.model import Doc, Message -from backend.rag_components.chat_message_history import get_conversation_buffer_memory -from backend.rag_components.embedding import get_embedding_model -from backend.rag_components.llm import get_llm_model -from backend.rag_components.vector_store import get_vector_store +from backend.rag_components.document_loader import generate_response from backend.user_management import ( ALGORITHM, SECRET_KEY, @@ -138,17 +132,23 @@ async def chat_prompt(message: Message, current_user: User = Depends(get_current ) config = 
get_config() - llm = get_llm_model(config) - embeddings = get_embedding_model(config) - vector_store = get_vector_store(embeddings, config) - document = load_document( - file_path=Path(f"{Path(__file__).parent.parent}/data/billionaires_csv.csv"), - llm=llm, - vector_store=vector_store, - ) - memory = get_conversation_buffer_memory(config, message.chat_id) - answer_chain = get_answer_chain(llm, vector_store, memory) - response = get_response(answer_chain, message.content) + + rag_on_file = False + file_path_str = f"{Path(__file__).parent.parent}/data/billionaires_csv.csv" + if rag_on_file: + docs_to_store = Path(file_path_str) + else: + from langchain.document_loaders import CSVLoader + from langchain.text_splitter import RecursiveCharacterTextSplitter + + loader = CSVLoader(file_path_str) + documents = loader.load() + text_splitter = RecursiveCharacterTextSplitter( + separators=["\n\n", "\n"], chunk_size=1500, chunk_overlap=100 + ) + docs_to_store = text_splitter.split_documents(documents) + + response = generate_response(docs_to_store, config, message) model_response = Message( id=str(uuid4()), diff --git a/backend/rag_components/chat_message_history.py b/backend/rag_components/chat_message_history.py index 1f86f6d..20558b1 100644 --- a/backend/rag_components/chat_message_history.py +++ b/backend/rag_components/chat_message_history.py @@ -1,16 +1,7 @@ -import json import os -from datetime import datetime -from typing import Any -from uuid import uuid4 from langchain.memory import ConversationBufferWindowMemory from langchain.memory.chat_message_histories import SQLChatMessageHistory -from langchain.memory.chat_message_histories.sql import DefaultMessageConverter -from langchain.schema import BaseMessage -from langchain.schema.messages import BaseMessage, _message_to_dict -from sqlalchemy import Column, DateTime, Text -from sqlalchemy.orm import declarative_base TABLE_NAME = "message_history" diff --git a/backend/document_loader.py b/backend/rag_components/document_loader.py similarity index 55% rename from backend/document_loader.py rename to backend/rag_components/document_loader.py index 218fe4a..57bd191 100644 --- a/backend/document_loader.py +++ b/backend/rag_components/document_loader.py @@ -1,22 +1,41 @@ import inspect from pathlib import Path -from time import sleep -from typing import List +from typing import List, Union from langchain.chains import LLMChain from langchain.chat_models.base import BaseChatModel +from langchain.docstore.document import Document from langchain.prompts import PromptTemplate from langchain.vectorstores import VectorStore from langchain.vectorstores.utils import filter_complex_metadata -# TODO rajouter top level fonction avec possibilité de rajouter l'objet document de langchain -# TODO rename load_document en load_embedded_document +from backend.chatbot import get_answer_chain, get_response +from backend.rag_components.chat_message_history import get_conversation_buffer_memory +from backend.rag_components.embedding import get_embedding_model +from backend.rag_components.llm import get_llm_model +from backend.rag_components.vector_store import get_vector_store -def load_document(file_path: Path, llm: BaseChatModel, vector_store: VectorStore): - documents = get_documents(file_path, llm) - filtered_documents = filter_complex_metadata(documents) - vector_store.add_documents(filtered_documents) +def generate_response(file_path: Path, config, input_query): + llm = get_llm_model(config) + embeddings = get_embedding_model(config) + vector_store = 
get_vector_store(embeddings, config) + store_documents(file_path, llm, vector_store) + memory = get_conversation_buffer_memory(config, input_query.chat_id) + answer_chain = get_answer_chain(llm, vector_store, memory) + response = get_response(answer_chain, input_query.content) + return response + + +def store_documents( + data_to_store: Union[Path, Document], llm: BaseChatModel, vector_store: VectorStore +): + if isinstance(data_to_store, Path): + documents = get_documents(data_to_store, llm) + filtered_documents = filter_complex_metadata(documents) + vector_store.add_documents(filtered_documents) + else: + vector_store.add_documents(data_to_store) def get_documents(file_path: Path, llm: BaseChatModel): @@ -67,19 +86,12 @@ def get_loaders() -> List[str]: from pathlib import Path from backend.config_renderer import get_config - from backend.document_loader import get_documents - from backend.rag_components.embedding import get_embedding_model - from backend.rag_components.llm import get_llm_model - from backend.rag_components.vector_store import get_vector_store + from frontend.lib.chat import Message config = get_config() - llm = get_llm_model(config) - embeddings = get_embedding_model(config) - vector_store = get_vector_store(embeddings) - - document = load_document( - file_path=Path(f"{Path(__file__).parent}/data/billionaires_csv.csv"), - llm=llm, - vector_store=vector_store, - ) - print(document) + data_to_store = Path(f"{Path(__file__).parent.parent.parent}/data/billionaires_csv.csv") + prompt = "Quelles sont les 5 plus grandes fortunes de France ?" + chat_id = "test" + input_query = Message("user", prompt, chat_id) + response = generate_response(data_to_store, config, input_query) + print(response) diff --git a/notebooks/docs_loader.ipynb b/notebooks/docs_loader.ipynb new file mode 100644 index 0000000..ffe377f --- /dev/null +++ b/notebooks/docs_loader.ipynb @@ -0,0 +1,119 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import sys\n", + "\n", + "current_directory = os.getcwd()\n", + "parent_directory = os.path.dirname(current_directory)\n", + "sys.path.append(parent_directory)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from langchain.document_loaders import PyPDFLoader\n", + "from langchain.text_splitter import RecursiveCharacterTextSplitter\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "loader = PyPDFLoader(f\"{parent_directory}/data/Cheat sheet entretien.pdf\", extract_images=True)\n", + "documents = loader.load()\n", + "text_splitter = RecursiveCharacterTextSplitter(\n", + " separators=[\"\\n\\n\", \"\\n\", \".\"], chunk_size=1500, chunk_overlap=100\n", + ")\n", + "texts = text_splitter.split_documents(documents)\n", + "\n", + "with open('../data/local_documents.json', 'w') as f:\n", + " for doc in texts:\n", + " f.write(doc.json() + '\\n')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import json\n", + "from langchain.docstore.document import Document\n", + "\n", + "with open('../data/local_documents.json', 'r') as f:\n", + " json_data = [json.loads(line) for line in f]\n" + ] + }, + { + "cell_type": "code", + 
"execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "documents = []\n", + "\n", + "for r in json_data:\n", + " document = Document(page_content=r[\"page_content\"], metadata=r[\"metadata\"])\n", + " documents.append(document)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "documents" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "skaff-rag-accelerator", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.5" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/notebooks/test_sla.ipynb b/notebooks/test_sla.ipynb deleted file mode 100644 index 523a617..0000000 --- a/notebooks/test_sla.ipynb +++ /dev/null @@ -1,161 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "%load_ext autoreload\n", - "%autoreload 2" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import os\n", - "import sys\n", - "\n", - "current_directory = os.getcwd()\n", - "parent_directory = os.path.dirname(current_directory)\n", - "sys.path.append(parent_directory)\n", - "data_folder_path = f\"{current_directory}/data\"" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import lib.backend as utils\n", - "from dotenv import load_dotenv\n", - "\n", - "load_dotenv()\n", - "embedding_api_base = os.getenv(\"EMBEDDING_OPENAI_API_BASE\")\n", - "embedding_api_key = os.getenv(\"EMBEDDING_API_KEY\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from typing import List\n", - "\n", - "def collect_files_with_extension(folder_path: str, extension: str) -> List[str]:\n", - " \"\"\"Collects and returns a list of file names with a given extension within the specified folder.\"\"\"\n", - " files_with_extension = []\n", - "\n", - " if not extension.startswith('.'):\n", - " extension = '.' 
+ extension\n", - "\n", - " for file_name in os.listdir(folder_path):\n", - " if file_name.endswith(extension):\n", - " files_with_extension.append(file_name)\n", - "\n", - " return files_with_extension[0]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "file_extension = \".pdf\"\n", - "file_name = collect_files_with_extension(data_folder_path, file_extension)\n", - "file_path = f\"{data_folder_path}/{file_name}\"\n", - "file_path" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "documents = utils.load_documents(file_extension, file_path)\n", - "texts = utils.get_chunks(documents, chunk_size=1500, chunk_overlap=200, text_splitter_type=\"recursive\")\n", - "print(len(texts))\n", - "texts[1]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "llm = utils.get_llm(temperature=0.1, model_version=\"4\", live_streaming=True)\n", - "embeddings = utils.get_embeddings_model(embedding_api_base, embedding_api_key)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from langchain.memory import ConversationBufferMemory, ConversationBufferWindowMemory, ConversationSummaryMemory, ConversationSummaryBufferMemory\n", - "from langchain.memory.chat_message_histories import StreamlitChatMessageHistory" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "def choose_momery_type(memory_type, llm):\n", - " msgs = StreamlitChatMessageHistory(key=\"special_app_key\")\n", - " if memory_type == \"\":\n", - " memory = ConversationBufferMemory(memory_key=\"chat_history\", chat_memory=msgs, return_messages=True)\n", - " elif memory_type == \"\":\n", - " memory = ConversationBufferWindowMemory(k=2, memory_key=\"chat_history\", chat_memory=msgs, return_messages=True)\n", - " elif memory_type == \"\":\n", - " memory = ConversationSummaryMemory(llm=llm, memory_key=\"chat_history\", chat_memory=msgs, return_messages=True)\n", - " elif memory_type == \"\":\n", - " memory = ConversationSummaryBufferMemory(llm=llm, max_token_limit=100, memory_key=\"chat_history\", chat_memory=msgs, return_messages=True)\n", - " return memory\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "documents = utils.load_documents(source=\"site\")\n", - "texts = utils.get_chunks(documents, chunk_size=1500, chunk_overlap=200)\n", - "docsearch = utils.get_vector_store(texts, embeddings)\n", - "answer_chain = utils.get_answer_chain(llm, docsearch, memory)" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "skaff-rag-accelerator", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.5" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/pyproject.toml b/pyproject.toml index 0082cb2..ff39156 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,7 +34,7 @@ select = [ "PTH", "PD", ] # See: https://beta.ruff.rs/docs/rules/ -ignore = ["D100", "D203", "D213", "ANN101", "ANN102"] +ignore = ["D100", "D103", "D203", "D213", "ANN101", "ANN102"] line-length = 100 target-version = "py310" exclude = [ From 
909aba2620d1d5fc29f7aad93a94c215cd69bc96 Mon Sep 17 00:00:00 2001 From: Sarah LAUZERAL Date: Thu, 21 Dec 2023 14:41:46 +0100 Subject: [PATCH 4/5] missing file to commit --- database/database_init.sql | 1 - 1 file changed, 1 deletion(-) diff --git a/database/database_init.sql b/database/database_init.sql index 8a96b21..eb9f971 100644 --- a/database/database_init.sql +++ b/database/database_init.sql @@ -23,7 +23,6 @@ CREATE TABLE IF NOT EXISTS "message" ( "chat_id" TEXT, "sender" TEXT, "content" TEXT, - "message" TEXT, FOREIGN KEY ("chat_id") REFERENCES "chat" ("id") ); From 3343a6f5cbc32b6f1dc4a58cbef719b93401b9ec Mon Sep 17 00:00:00 2001 From: Sarah LAUZERAL Date: Thu, 21 Dec 2023 17:16:53 +0100 Subject: [PATCH 5/5] rag component --- backend/config.yaml | 2 +- backend/main.py | 17 -------- backend/rag_components/document_loader.py | 34 +--------------- backend/rag_components/main.py | 48 +++++++++++++++++++++++ 4 files changed, 51 insertions(+), 50 deletions(-) create mode 100644 backend/rag_components/main.py diff --git a/backend/config.yaml b/backend/config.yaml index 967332a..9ca450d 100644 --- a/backend/config.yaml +++ b/backend/config.yaml @@ -23,7 +23,7 @@ embedding_model_config: vector_store_provider: model_source: Chroma - persist_directory: database/ + persist_directory: vector_database/ chat_message_history_config: source: ChatMessageHistory diff --git a/backend/main.py b/backend/main.py index 4b67077..10d71e1 100644 --- a/backend/main.py +++ b/backend/main.py @@ -133,23 +133,6 @@ async def chat_prompt(message: Message, current_user: User = Depends(get_current config = get_config() - rag_on_file = False - file_path_str = f"{Path(__file__).parent.parent}/data/billionaires_csv.csv" - if rag_on_file: - docs_to_store = Path(file_path_str) - else: - from langchain.document_loaders import CSVLoader - from langchain.text_splitter import RecursiveCharacterTextSplitter - - loader = CSVLoader(file_path_str) - documents = loader.load() - text_splitter = RecursiveCharacterTextSplitter( - separators=["\n\n", "\n"], chunk_size=1500, chunk_overlap=100 - ) - docs_to_store = text_splitter.split_documents(documents) - - response = generate_response(docs_to_store, config, message) - model_response = Message( id=str(uuid4()), timestamp=datetime.now().isoformat(), diff --git a/backend/rag_components/document_loader.py b/backend/rag_components/document_loader.py index 57bd191..da4ae23 100644 --- a/backend/rag_components/document_loader.py +++ b/backend/rag_components/document_loader.py @@ -1,46 +1,16 @@ import inspect from pathlib import Path -from typing import List, Union +from typing import List from langchain.chains import LLMChain from langchain.chat_models.base import BaseChatModel -from langchain.docstore.document import Document from langchain.prompts import PromptTemplate -from langchain.vectorstores import VectorStore -from langchain.vectorstores.utils import filter_complex_metadata - -from backend.chatbot import get_answer_chain, get_response -from backend.rag_components.chat_message_history import get_conversation_buffer_memory -from backend.rag_components.embedding import get_embedding_model -from backend.rag_components.llm import get_llm_model -from backend.rag_components.vector_store import get_vector_store - - -def generate_response(file_path: Path, config, input_query): - llm = get_llm_model(config) - embeddings = get_embedding_model(config) - vector_store = get_vector_store(embeddings, config) - store_documents(file_path, llm, vector_store) - memory = 
get_conversation_buffer_memory(config, input_query.chat_id)
-    answer_chain = get_answer_chain(llm, vector_store, memory)
-    response = get_response(answer_chain, input_query.content)
-    return response
-
-
-def store_documents(
-    data_to_store: Union[Path, Document], llm: BaseChatModel, vector_store: VectorStore
-):
-    if isinstance(data_to_store, Path):
-        documents = get_documents(data_to_store, llm)
-        filtered_documents = filter_complex_metadata(documents)
-        vector_store.add_documents(filtered_documents)
-    else:
-        vector_store.add_documents(data_to_store)
 
 
 def get_documents(file_path: Path, llm: BaseChatModel):
     file_extension = file_path.suffix
     loader_class_name = get_best_loader(file_extension, llm)
+    print(f"loader selected {loader_class_name} for {file_path}")
 
     if loader_class_name == "None":
         raise Exception(f"No loader found for {file_extension} files.")
diff --git a/backend/rag_components/main.py b/backend/rag_components/main.py
new file mode 100644
index 0000000..db93076
--- /dev/null
+++ b/backend/rag_components/main.py
@@ -0,0 +1,48 @@
+from pathlib import Path
+from typing import List
+
+from langchain.docstore.document import Document
+from langchain.vectorstores.utils import filter_complex_metadata
+
+from backend.config_renderer import get_config
+from backend.rag_components.document_loader import get_documents
+from backend.rag_components.embedding import get_embedding_model
+from backend.rag_components.llm import get_llm_model
+from backend.rag_components.vector_store import get_vector_store
+
+
+class RAG:
+    def __init__(self):
+        self.config = get_config()
+        self.llm = get_llm_model(self.config)
+        self.embeddings = get_embedding_model(self.config)
+        self.vector_store = get_vector_store(self.embeddings, self.config)
+
+    def generate_response(self):
+        pass
+
+    def load_documents(self, documents: List[Document]):
+        # TODO make load_document more robust
+        # TODO langchain agent that does the get_best_loader step
+        self.vector_store.add_documents(documents)
+
+    def load_file(self, file_path: Path):
+        documents = get_documents(file_path, self.llm)
+        filtered_documents = filter_complex_metadata(documents)
+        self.vector_store.add_documents(filtered_documents)
+
+    # TODO for each file, store a hash in the database
+    # TODO before loading a file into the vector store, check whether its hash is already in our db and only append the doc if the hash is absent
+    # TODO avoid duplicating embeddings
+
+    def serve(self):
+        pass
+
+
+if __name__ == "__main__":
+    file_path = Path(__file__).parent.parent.parent / "data"
+    rag = RAG()
+
+    for file in file_path.iterdir():
+        if file.is_file():
+            rag.load_file(file)
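
Usage sketch: the README TODO in PATCH 1/5 asks for code snippets explaining how to load documents into the RAG. The example below targets the final state of the series (the RAG class from PATCH 5/5) and is illustrative rather than part of the patches: the "demo-chat" id and the data/ path are placeholders, DATABASE_CONNECTION_STRING must be set for the SQL-backed chat history, and because RAG.generate_response is still a stub, the answer chain is wired manually through the backend.chatbot helpers, mirroring the generate_response function that PATCH 5/5 removed from document_loader.py.

    from pathlib import Path

    from backend.chatbot import get_answer_chain, get_response
    from backend.rag_components.chat_message_history import get_conversation_buffer_memory
    from backend.rag_components.main import RAG

    rag = RAG()  # loads the config, LLM, embedding model and vector store

    # Index every file in the data folder; for each file extension,
    # get_documents asks the LLM to pick the best langchain loader.
    for file in Path("data").iterdir():
        if file.is_file():
            rag.load_file(file)

    # Pre-split langchain Documents can be added directly instead:
    # rag.load_documents(documents)

    # Build a conversational chain over the vector store and query it.
    memory = get_conversation_buffer_memory(rag.config, "demo-chat")
    answer_chain = get_answer_chain(rag.llm, rag.vector_store, memory)
    print(get_response(answer_chain, "Quelles sont les 5 plus grandes fortunes de France ?"))

Once RAG.generate_response is implemented, the last three lines should collapse into a single call on the class.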