diff --git a/README.md b/README.md index 49d1697..ac80bd0 100644 --- a/README.md +++ b/README.md @@ -14,5 +14,7 @@ python "/Users/sarah.lauzeral/Library/CloudStorage/GoogleDrive-sarah.lauzeral@ar - comment lancer l'API - gestion de la config - écrire des helpers de co, pour envoyer des messages... +- tester différents modèles +- écrire des snippets de code pour éxpliquer comment charger les docs dans le RAG diff --git a/backend/_logs.py b/backend/_logs.py deleted file mode 100644 index 9451968..0000000 --- a/backend/_logs.py +++ /dev/null @@ -1,66 +0,0 @@ -import os -from typing import Any, Dict, List, Sequence - -import streamlit as st -from langchain.callbacks.base import BaseCallbackHandler -from langchain.schema.document import Document - - -class StreamHandler(BaseCallbackHandler): - """StreamHandler is a class that handles the streaming of text. - - It is a callback handler for a language model. \ - It displays the generated text in a Streamlit container \ - and handles the start of the language model and the generation of new tokens. - """ - - def __init__( - self, container: st.delta_generator.DeltaGenerator, initial_text: str = "" - ) -> None: - """Initialize the StreamHandler.""" - self.container = container - self.text = initial_text - self.run_id_ignore_token = None - - def on_llm_start( - self, serialized: dict, prompts: List[str], **kwargs: Dict[str, Any] # noqa: ARG002 - ) -> None: - """Handle the start of the language model.""" - if "Question reformulée :" in prompts[0]: - self.run_id_ignore_token = kwargs.get("run_id") - - def on_llm_new_token(self, token: str, **kwargs: Dict[str, Any]) -> None: - """Handle the generation of a new token by the language model.""" - if self.run_id_ignore_token == kwargs.get("run_id", False): - return - self.text += token - self.container.markdown(self.text) - - -class PrintRetrievalHandler(BaseCallbackHandler): - """PrintRetrievalHandler is a class that handles the retrieval of documents. - - It is a callback handler for a document retriever. \ - It displays the status and content of the retrieved documents in a Streamlit container. - """ - - def __init__(self, container: st.delta_generator.DeltaGenerator) -> None: - """Initialize the PrintRetrievalHandler.""" - self.status = container.status("**Context Retrieval**") - - def on_retriever_start( - self, serialized: Dict[str, Any], query: str, **kwargs: Dict[str, Any] # noqa: ARG002 - ) -> None: - """Handle the start of the document retrieval.""" - self.status.write(f"**Question:** {query}") - self.status.update(label=f"**Context Retrieval:** {query}") - - def on_retriever_end( - self, documents: Sequence[Document], **kwargs: Dict[str, Any] # noqa: ARG002 - ) -> None: - """Handle the end of the document retrieval.""" - for idx, doc in enumerate(documents): - source = os.path.basename(doc.metadata["source"]) # noqa: PTH119 - self.status.write(f"**Document {idx} from {source}**") - self.status.markdown(doc.page_content) - self.status.update(state="complete") diff --git a/backend/chatbot.py b/backend/chatbot.py index 28c3988..937a68b 100644 --- a/backend/chatbot.py +++ b/backend/chatbot.py @@ -17,7 +17,6 @@ def get_response(answer_chain: ConversationalRetrievalChain, query: str) -> str: """Processes the given query through the answer chain and returns the formatted response.""" - {"content": answer_chain.run(query), } return answer_chain.run(query) diff --git a/backend/config.yaml b/backend/config.yaml index 967332a..9ca450d 100644 --- a/backend/config.yaml +++ b/backend/config.yaml @@ -23,7 +23,7 @@ embedding_model_config: vector_store_provider: model_source: Chroma - persist_directory: database/ + persist_directory: vector_database/ chat_message_history_config: source: ChatMessageHistory diff --git a/backend/main.py b/backend/main.py index cef545a..10d71e1 100644 --- a/backend/main.py +++ b/backend/main.py @@ -1,4 +1,5 @@ from datetime import datetime, timedelta +from pathlib import Path from typing import List from uuid import uuid4 @@ -6,9 +7,10 @@ from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm from jose import JWTError, jwt -import backend.document_store as document_store -from backend.document_store import StorageBackend +from backend.config_renderer import get_config +from backend.document_store import StorageBackend, store_documents from backend.model import Doc, Message +from backend.rag_components.document_loader import generate_response from backend.user_management import ( ALGORITHM, SECRET_KEY, @@ -129,14 +131,14 @@ async def chat_prompt(message: Message, current_user: User = Depends(get_current (message.id, message.timestamp, message.chat_id, message.sender, message.content), ) - #TODO : faire la réposne du llm - + config = get_config() + model_response = Message( id=str(uuid4()), timestamp=datetime.now().isoformat(), chat_id=message.chat_id, sender="assistant", - content=f"Unique response: {uuid4()}", + content=response, ) with Database() as connection: diff --git a/backend/rag_components/chat_message_history.py b/backend/rag_components/chat_message_history.py index e6e6774..20558b1 100644 --- a/backend/rag_components/chat_message_history.py +++ b/backend/rag_components/chat_message_history.py @@ -1,18 +1,9 @@ -import json import os -from datetime import datetime -from typing import Any -from uuid import uuid4 from langchain.memory import ConversationBufferWindowMemory from langchain.memory.chat_message_histories import SQLChatMessageHistory -from langchain.memory.chat_message_histories.sql import DefaultMessageConverter -from langchain.schema import BaseMessage -from langchain.schema.messages import BaseMessage, _message_to_dict -from sqlalchemy import Column, DateTime, Text -from sqlalchemy.orm import declarative_base -TABLE_NAME = "message" +TABLE_NAME = "message_history" def get_conversation_buffer_memory(config, chat_id): @@ -29,39 +20,4 @@ def get_chat_message_history(chat_id): session_id=chat_id, connection_string=os.environ.get("DATABASE_CONNECTION_STRING"), table_name=TABLE_NAME, - session_id_field_name="chat_id", - custom_message_converter=CustomMessageConverter(table_name=TABLE_NAME), ) - - -Base = declarative_base() - - -class CustomMessage(Base): - __tablename__ = TABLE_NAME - - id = Column( - Text, primary_key=True, default=lambda: str(uuid4()) - ) # default=lambda: str(uuid4()) - timestamp = Column(DateTime) - chat_id = Column(Text) - sender = Column(Text) - content = Column(Text) - message = Column(Text) - - -class CustomMessageConverter(DefaultMessageConverter): - def to_sql_model(self, message: BaseMessage, session_id: str) -> Any: - print(message.content) - sub_message = json.loads(message.content) - return CustomMessage( - id=sub_message["id"], - timestamp=datetime.strptime(sub_message["timestamp"], "%Y-%m-%d %H:%M:%S.%f"), - chat_id=session_id, - sender=message.type, - content=sub_message["content"], - message=json.dumps(_message_to_dict(message)), - ) - - def get_sql_model_class(self) -> Any: - return CustomMessage diff --git a/backend/document_loader.py b/backend/rag_components/document_loader.py similarity index 64% rename from backend/document_loader.py rename to backend/rag_components/document_loader.py index a960a7f..da4ae23 100644 --- a/backend/document_loader.py +++ b/backend/rag_components/document_loader.py @@ -1,24 +1,16 @@ import inspect from pathlib import Path -from time import sleep from typing import List from langchain.chains import LLMChain from langchain.chat_models.base import BaseChatModel from langchain.prompts import PromptTemplate -from langchain.vectorstores import VectorStore -from langchain.vectorstores.utils import filter_complex_metadata - - -def load_document(file_path: Path, llm: BaseChatModel, vector_store: VectorStore): - documents = get_documents(file_path, llm) - filtered_documents = filter_complex_metadata(documents) - vector_store.add_documents(documents) def get_documents(file_path: Path, llm: BaseChatModel): file_extension = file_path.suffix loader_class_name = get_best_loader(file_extension, llm) + print(f"loader selected {loader_class_name} for {file_path}") if loader_class_name == "None": raise Exception(f"No loader found for {file_extension} files.") @@ -64,21 +56,12 @@ def get_loaders() -> List[str]: from pathlib import Path from backend.config_renderer import get_config - from backend.document_loader import get_documents - from backend.rag_components.embedding import get_embedding_model - from backend.rag_components.llm import get_llm_model - from backend.rag_components.vector_store import get_vector_store + from frontend.lib.chat import Message config = get_config() - llm = get_llm_model(config) - embeddings = get_embedding_model(config) - vector_store = get_vector_store(embeddings) - - document = load_document( - file_path=Path( - "/Users/alexis.vialaret/vscode_projects/skaff-rag-accelerator/data/billionaires_csv.csv" - ), - llm=llm, - vector_store=vector_store, - ) - print(document) + data_to_store = Path(f"{Path(__file__).parent.parent.parent}/data/billionaires_csv.csv") + prompt = "Quelles sont les 5 plus grandes fortunes de France ?" + chat_id = "test" + input_query = Message("user", prompt, chat_id) + response = generate_response(data_to_store, config, input_query) + print(response) diff --git a/backend/rag_components/main.py b/backend/rag_components/main.py new file mode 100644 index 0000000..db93076 --- /dev/null +++ b/backend/rag_components/main.py @@ -0,0 +1,48 @@ +from pathlib import Path +from typing import List + +from langchain.docstore.document import Document +from langchain.vectorstores.utils import filter_complex_metadata + +from backend.config_renderer import get_config +from backend.rag_components.document_loader import get_documents +from backend.rag_components.embedding import get_embedding_model +from backend.rag_components.llm import get_llm_model +from backend.rag_components.vector_store import get_vector_store + + +class RAG: + def __init__(self): + self.config = get_config() + self.llm = get_llm_model(self.config) + self.embeddings = get_embedding_model(self.config) + self.vector_store = get_vector_store(self.embeddings, self.config) + + def generate_response(): + pass + + def load_documents(self, documents: List[Document]): + # TODO améliorer la robustesse du load_document + # TODO agent langchain qui fait le get_best_loader + self.vector_store.add_documents(documents) + + def load_file(self, file_path: Path): + documents = get_documents(file_path, self.llm) + filtered_documents = filter_complex_metadata(documents) + self.vector_store.add_documents(filtered_documents) + + # TODO pour chaque fichier -> stocker un hash en base + # TODO avant de loader un fichier dans le vector store si le hash est dans notre db est append le doc dans le vector store que si le hash est inexistant + # TODO éviter de dupliquer les embeddings + + def serve(): + pass + + +if __name__ == "__main__": + file_path = Path(__file__).parent.parent.parent / "data" + rag = RAG() + + for file in file_path.iterdir(): + if file.is_file(): + rag.load_file(file) diff --git a/database/database_init.sql b/database/database_init.sql index 8a96b21..eb9f971 100644 --- a/database/database_init.sql +++ b/database/database_init.sql @@ -23,7 +23,6 @@ CREATE TABLE IF NOT EXISTS "message" ( "chat_id" TEXT, "sender" TEXT, "content" TEXT, - "message" TEXT, FOREIGN KEY ("chat_id") REFERENCES "chat" ("id") ); diff --git a/notebooks/docs_loader.ipynb b/notebooks/docs_loader.ipynb new file mode 100644 index 0000000..ffe377f --- /dev/null +++ b/notebooks/docs_loader.ipynb @@ -0,0 +1,119 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import sys\n", + "\n", + "current_directory = os.getcwd()\n", + "parent_directory = os.path.dirname(current_directory)\n", + "sys.path.append(parent_directory)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from langchain.document_loaders import PyPDFLoader\n", + "from langchain.text_splitter import RecursiveCharacterTextSplitter\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "loader = PyPDFLoader(f\"{parent_directory}/data/Cheat sheet entretien.pdf\", extract_images=True)\n", + "documents = loader.load()\n", + "text_splitter = RecursiveCharacterTextSplitter(\n", + " separators=[\"\\n\\n\", \"\\n\", \".\"], chunk_size=1500, chunk_overlap=100\n", + ")\n", + "texts = text_splitter.split_documents(documents)\n", + "\n", + "with open('../data/local_documents.json', 'w') as f:\n", + " for doc in texts:\n", + " f.write(doc.json() + '\\n')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import json\n", + "from langchain.docstore.document import Document\n", + "\n", + "with open('../data/local_documents.json', 'r') as f:\n", + " json_data = [json.loads(line) for line in f]\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "documents = []\n", + "\n", + "for r in json_data:\n", + " document = Document(page_content=r[\"page_content\"], metadata=r[\"metadata\"])\n", + " documents.append(document)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "documents" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "skaff-rag-accelerator", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.5" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/notebooks/memory_tests.ipynb b/notebooks/memory_tests.ipynb index a9e7443..520cf4d 100644 --- a/notebooks/memory_tests.ipynb +++ b/notebooks/memory_tests.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -12,7 +12,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -26,7 +26,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -46,7 +46,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -104,7 +104,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -116,7 +116,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -126,7 +126,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -142,7 +142,7 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -189,7 +189,7 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -204,7 +204,7 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -213,20 +213,9 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "[AIMessage(content='{\"content\": \"Hi\", \"timestamp\": \"2023-12-20 16:26:23.672506\", \"id\": \"764528762\"}')]" - ] - }, - "execution_count": 23, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "chat_message_history.messages" ] diff --git a/notebooks/test_sla.ipynb b/notebooks/test_sla.ipynb deleted file mode 100644 index 523a617..0000000 --- a/notebooks/test_sla.ipynb +++ /dev/null @@ -1,161 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "%load_ext autoreload\n", - "%autoreload 2" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import os\n", - "import sys\n", - "\n", - "current_directory = os.getcwd()\n", - "parent_directory = os.path.dirname(current_directory)\n", - "sys.path.append(parent_directory)\n", - "data_folder_path = f\"{current_directory}/data\"" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import lib.backend as utils\n", - "from dotenv import load_dotenv\n", - "\n", - "load_dotenv()\n", - "embedding_api_base = os.getenv(\"EMBEDDING_OPENAI_API_BASE\")\n", - "embedding_api_key = os.getenv(\"EMBEDDING_API_KEY\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from typing import List\n", - "\n", - "def collect_files_with_extension(folder_path: str, extension: str) -> List[str]:\n", - " \"\"\"Collects and returns a list of file names with a given extension within the specified folder.\"\"\"\n", - " files_with_extension = []\n", - "\n", - " if not extension.startswith('.'):\n", - " extension = '.' + extension\n", - "\n", - " for file_name in os.listdir(folder_path):\n", - " if file_name.endswith(extension):\n", - " files_with_extension.append(file_name)\n", - "\n", - " return files_with_extension[0]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "file_extension = \".pdf\"\n", - "file_name = collect_files_with_extension(data_folder_path, file_extension)\n", - "file_path = f\"{data_folder_path}/{file_name}\"\n", - "file_path" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "documents = utils.load_documents(file_extension, file_path)\n", - "texts = utils.get_chunks(documents, chunk_size=1500, chunk_overlap=200, text_splitter_type=\"recursive\")\n", - "print(len(texts))\n", - "texts[1]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "llm = utils.get_llm(temperature=0.1, model_version=\"4\", live_streaming=True)\n", - "embeddings = utils.get_embeddings_model(embedding_api_base, embedding_api_key)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from langchain.memory import ConversationBufferMemory, ConversationBufferWindowMemory, ConversationSummaryMemory, ConversationSummaryBufferMemory\n", - "from langchain.memory.chat_message_histories import StreamlitChatMessageHistory" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "def choose_momery_type(memory_type, llm):\n", - " msgs = StreamlitChatMessageHistory(key=\"special_app_key\")\n", - " if memory_type == \"\":\n", - " memory = ConversationBufferMemory(memory_key=\"chat_history\", chat_memory=msgs, return_messages=True)\n", - " elif memory_type == \"\":\n", - " memory = ConversationBufferWindowMemory(k=2, memory_key=\"chat_history\", chat_memory=msgs, return_messages=True)\n", - " elif memory_type == \"\":\n", - " memory = ConversationSummaryMemory(llm=llm, memory_key=\"chat_history\", chat_memory=msgs, return_messages=True)\n", - " elif memory_type == \"\":\n", - " memory = ConversationSummaryBufferMemory(llm=llm, max_token_limit=100, memory_key=\"chat_history\", chat_memory=msgs, return_messages=True)\n", - " return memory\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "documents = utils.load_documents(source=\"site\")\n", - "texts = utils.get_chunks(documents, chunk_size=1500, chunk_overlap=200)\n", - "docsearch = utils.get_vector_store(texts, embeddings)\n", - "answer_chain = utils.get_answer_chain(llm, docsearch, memory)" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "skaff-rag-accelerator", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.5" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/pyproject.toml b/pyproject.toml index 0082cb2..ff39156 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,7 +34,7 @@ select = [ "PTH", "PD", ] # See: https://beta.ruff.rs/docs/rules/ -ignore = ["D100", "D203", "D213", "ANN101", "ANN102"] +ignore = ["D100", "D103", "D203", "D213", "ANN101", "ANN102"] line-length = 100 target-version = "py310" exclude = [