From 3343a6f5cbc32b6f1dc4a58cbef719b93401b9ec Mon Sep 17 00:00:00 2001
From: Sarah LAUZERAL
Date: Thu, 21 Dec 2023 17:16:53 +0100
Subject: [PATCH] rag component

---
 backend/config.yaml                       |  2 +-
 backend/main.py                           | 17 --------
 backend/rag_components/document_loader.py | 34 +---------------
 backend/rag_components/main.py            | 48 ++++++++++++++++++++++
 4 files changed, 51 insertions(+), 50 deletions(-)
 create mode 100644 backend/rag_components/main.py

diff --git a/backend/config.yaml b/backend/config.yaml
index 967332a..9ca450d 100644
--- a/backend/config.yaml
+++ b/backend/config.yaml
@@ -23,7 +23,7 @@ embedding_model_config:
 
 vector_store_provider:
   model_source: Chroma
-  persist_directory: database/
+  persist_directory: vector_database/
 
 chat_message_history_config:
   source: ChatMessageHistory

diff --git a/backend/main.py b/backend/main.py
index 4b67077..10d71e1 100644
--- a/backend/main.py
+++ b/backend/main.py
@@ -133,23 +133,6 @@ async def chat_prompt(message: Message, current_user: User = Depends(get_current
     config = get_config()
 
-    rag_on_file = False
-    file_path_str = f"{Path(__file__).parent.parent}/data/billionaires_csv.csv"
-    if rag_on_file:
-        docs_to_store = Path(file_path_str)
-    else:
-        from langchain.document_loaders import CSVLoader
-        from langchain.text_splitter import RecursiveCharacterTextSplitter
-
-        loader = CSVLoader(file_path_str)
-        documents = loader.load()
-        text_splitter = RecursiveCharacterTextSplitter(
-            separators=["\n\n", "\n"], chunk_size=1500, chunk_overlap=100
-        )
-        docs_to_store = text_splitter.split_documents(documents)
-
-    response = generate_response(docs_to_store, config, message)
-
     model_response = Message(
         id=str(uuid4()),
         timestamp=datetime.now().isoformat(),

diff --git a/backend/rag_components/document_loader.py b/backend/rag_components/document_loader.py
index 57bd191..da4ae23 100644
--- a/backend/rag_components/document_loader.py
+++ b/backend/rag_components/document_loader.py
@@ -1,46 +1,16 @@
 import inspect
 from pathlib import Path
-from typing import List, Union
+from typing import List
 
 from langchain.chains import LLMChain
 from langchain.chat_models.base import BaseChatModel
-from langchain.docstore.document import Document
 from langchain.prompts import PromptTemplate
-from langchain.vectorstores import VectorStore
-from langchain.vectorstores.utils import filter_complex_metadata
-
-from backend.chatbot import get_answer_chain, get_response
-from backend.rag_components.chat_message_history import get_conversation_buffer_memory
-from backend.rag_components.embedding import get_embedding_model
-from backend.rag_components.llm import get_llm_model
-from backend.rag_components.vector_store import get_vector_store
-
-
-def generate_response(file_path: Path, config, input_query):
-    llm = get_llm_model(config)
-    embeddings = get_embedding_model(config)
-    vector_store = get_vector_store(embeddings, config)
-    store_documents(file_path, llm, vector_store)
-    memory = get_conversation_buffer_memory(config, input_query.chat_id)
-    answer_chain = get_answer_chain(llm, vector_store, memory)
-    response = get_response(answer_chain, input_query.content)
-    return response
-
-
-def store_documents(
-    data_to_store: Union[Path, Document], llm: BaseChatModel, vector_store: VectorStore
-):
-    if isinstance(data_to_store, Path):
-        documents = get_documents(data_to_store, llm)
-        filtered_documents = filter_complex_metadata(documents)
-        vector_store.add_documents(filtered_documents)
-    else:
-        vector_store.add_documents(data_to_store)
 
 
 def get_documents(file_path: Path, llm: BaseChatModel):
     file_extension = file_path.suffix
     loader_class_name = get_best_loader(file_extension, llm)
+    print(f"loader selected {loader_class_name} for {file_path}")
 
     if loader_class_name == "None":
         raise Exception(f"No loader found for {file_extension} files.")

diff --git a/backend/rag_components/main.py b/backend/rag_components/main.py
new file mode 100644
index 0000000..db93076
--- /dev/null
+++ b/backend/rag_components/main.py
@@ -0,0 +1,48 @@
+from pathlib import Path
+from typing import List
+
+from langchain.docstore.document import Document
+from langchain.vectorstores.utils import filter_complex_metadata
+
+from backend.config_renderer import get_config
+from backend.rag_components.document_loader import get_documents
+from backend.rag_components.embedding import get_embedding_model
+from backend.rag_components.llm import get_llm_model
+from backend.rag_components.vector_store import get_vector_store
+
+
+class RAG:
+    def __init__(self):
+        self.config = get_config()
+        self.llm = get_llm_model(self.config)
+        self.embeddings = get_embedding_model(self.config)
+        self.vector_store = get_vector_store(self.embeddings, self.config)
+
+    def generate_response(self):
+        pass
+
+    def load_documents(self, documents: List[Document]):
+        # TODO: make load_documents more robust
+        # TODO: use a langchain agent that calls get_best_loader
+        self.vector_store.add_documents(documents)
+
+    def load_file(self, file_path: Path):
+        documents = get_documents(file_path, self.llm)
+        filtered_documents = filter_complex_metadata(documents)
+        self.vector_store.add_documents(filtered_documents)
+
+    # TODO: store a hash of each loaded file in a database
+    # TODO: before loading a file into the vector store, only append its documents if the hash is not already in our database
+    # TODO: avoid duplicating embeddings
+
+    def serve(self):
+        pass
+
+
+if __name__ == "__main__":
+    file_path = Path(__file__).parent.parent.parent / "data"
+    rag = RAG()
+
+    for file in file_path.iterdir():
+        if file.is_file():
+            rag.load_file(file)
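
Note: RAG.generate_response is left as a stub in this patch. A minimal sketch of how it could be filled in, reusing the flow of the generate_response helper the patch removes from document_loader.py (get_answer_chain, get_response, and get_conversation_buffer_memory all appear in the removed code, as does the input_query shape with chat_id and content; wiring them up as a method here is an assumption, not part of the patch):

    from backend.chatbot import get_answer_chain, get_response
    from backend.rag_components.chat_message_history import get_conversation_buffer_memory

    def generate_response(self, input_query):
        # Same flow as the removed helper, but using the instance's llm and vector_store.
        memory = get_conversation_buffer_memory(self.config, input_query.chat_id)
        answer_chain = get_answer_chain(self.llm, self.vector_store, memory)
        return get_response(answer_chain, input_query.content)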
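
The TODOs in backend/rag_components/main.py describe hash-based deduplication: store a hash per loaded file and only add a file's documents when its hash is unknown. A minimal sketch of that idea using only the standard library (the processed_hashes.txt path and both helper names are hypothetical stand-ins for the database the TODOs mention):

    import hashlib
    from pathlib import Path

    HASH_DB = Path("processed_hashes.txt")  # hypothetical stand-in for a real database

    def file_hash(file_path: Path) -> str:
        # Hash the file contents so an unchanged file always maps to the same digest.
        return hashlib.sha256(file_path.read_bytes()).hexdigest()

    def load_file_if_new(rag, file_path: Path) -> bool:
        # Per the TODOs: skip files whose hash is already recorded, avoiding duplicate embeddings.
        seen = set(HASH_DB.read_text().splitlines()) if HASH_DB.exists() else set()
        digest = file_hash(file_path)
        if digest in seen:
            return False
        rag.load_file(file_path)
        with HASH_DB.open("a") as handle:
            handle.write(digest + "\n")
        return True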