Skip to content

Commit

Permalink
Merge branch 'main' into av/streaming
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexisVLRT committed Dec 21, 2023
2 parents b63ffc2 + fa9c948 commit 4451414
Show file tree
Hide file tree
Showing 11 changed files with 197 additions and 280 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,7 @@ python "/Users/sarah.lauzeral/Library/CloudStorage/GoogleDrive-sarah.lauzeral@ar
- comment lancer l'API
- gestion de la config
- écrire des helpers de co, pour envoyer des messages...
- tester différents modèles
- écrire des snippets de code pour expliquer comment charger les docs dans le RAG

</div>
66 changes: 0 additions & 66 deletions backend/_logs.py

This file was deleted.

2 changes: 1 addition & 1 deletion backend/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ embedding_model_config:

vector_store_provider:
model_source: Chroma
persist_directory: database/
persist_directory: vector_database/

chat_message_history_config:
source: ChatMessageHistory
Expand Down
3 changes: 2 additions & 1 deletion backend/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
from jose import JWTError, jwt


import backend.document_store as document_store

from backend.config_renderer import get_config
from backend.document_store import StorageBackend
from backend.model import Doc, Message
from backend.user_management import (
Expand Down
6 changes: 4 additions & 2 deletions backend/rag_components/chat_message_history.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from langchain.memory import ConversationBufferWindowMemory
from langchain.memory.chat_message_histories import SQLChatMessageHistory

TABLE_NAME = "message_history"


def get_conversation_buffer_memory(config, chat_id):
return ConversationBufferWindowMemory(
Expand All @@ -16,5 +18,5 @@ def get_chat_message_history(chat_id):
return SQLChatMessageHistory(
session_id=chat_id,
connection_string=os.environ.get("DATABASE_CONNECTION_STRING"),
table_name="message_history",
)
table_name=TABLE_NAME,
)
Original file line number Diff line number Diff line change
@@ -1,24 +1,16 @@
import inspect
from pathlib import Path
from time import sleep
from typing import List

from langchain.chains import LLMChain
from langchain.chat_models.base import BaseChatModel
from langchain.prompts import PromptTemplate
from langchain.vectorstores import VectorStore
from langchain.vectorstores.utils import filter_complex_metadata


def load_document(file_path: Path, llm: BaseChatModel, vector_store: VectorStore):
    """Load a single file into the vector store.

    Splits the file into documents with the loader selected by
    ``get_documents``, strips metadata entries the store cannot persist,
    then indexes the result.

    Args:
        file_path: Path of the file to ingest.
        llm: Chat model used by ``get_documents`` to pick the best loader.
        vector_store: Destination store receiving the documents.
    """
    documents = get_documents(file_path, llm)
    filtered_documents = filter_complex_metadata(documents)
    # Bug fix: the original added the unfiltered `documents`, silently
    # discarding the result of filter_complex_metadata. Index the
    # filtered list instead (matching RAG.load_file elsewhere in the repo).
    vector_store.add_documents(filtered_documents)


def get_documents(file_path: Path, llm: BaseChatModel):
file_extension = file_path.suffix
loader_class_name = get_best_loader(file_extension, llm)
print(f"loader selected {loader_class_name} for {file_path}")

if loader_class_name == "None":
raise Exception(f"No loader found for {file_extension} files.")
Expand Down Expand Up @@ -64,21 +56,12 @@ def get_loaders() -> List[str]:
from pathlib import Path

from backend.config_renderer import get_config
from backend.document_loader import get_documents
from backend.rag_components.embedding import get_embedding_model
from backend.rag_components.llm import get_llm_model
from backend.rag_components.vector_store import get_vector_store
from frontend.lib.chat import Message

config = get_config()
llm = get_llm_model(config)
embeddings = get_embedding_model(config)
vector_store = get_vector_store(embeddings)

document = load_document(
file_path=Path(
"/Users/alexis.vialaret/vscode_projects/skaff-rag-accelerator/data/billionaires_csv.csv"
),
llm=llm,
vector_store=vector_store,
)
print(document)
data_to_store = Path(f"{Path(__file__).parent.parent.parent}/data/billionaires_csv.csv")
prompt = "Quelles sont les 5 plus grandes fortunes de France ?"
chat_id = "test"
input_query = Message("user", prompt, chat_id)
response = generate_response(data_to_store, config, input_query)
print(response)
48 changes: 48 additions & 0 deletions backend/rag_components/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
from pathlib import Path
from typing import List

from langchain.docstore.document import Document
from langchain.vectorstores.utils import filter_complex_metadata

from backend.config_renderer import get_config
from backend.rag_components.document_loader import get_documents
from backend.rag_components.embedding import get_embedding_model
from backend.rag_components.llm import get_llm_model
from backend.rag_components.vector_store import get_vector_store


class RAG:
    """Wire together the config, LLM, embedding model and vector store
    used to ingest documents and (eventually) answer queries."""

    def __init__(self):
        self.config = get_config()
        self.llm = get_llm_model(self.config)
        self.embeddings = get_embedding_model(self.config)
        self.vector_store = get_vector_store(self.embeddings, self.config)

    def generate_response(self):
        # Bug fix: was declared without `self`, so calling it on an
        # instance raised TypeError. Still a stub to be implemented.
        pass

    def load_documents(self, documents: List[Document]):
        """Add pre-built Document objects directly to the vector store."""
        # TODO: make document loading more robust
        # TODO: langchain agent that performs the get_best_loader step
        self.vector_store.add_documents(documents)

    def load_file(self, file_path: Path):
        """Ingest one file: select a loader, split into documents, strip
        metadata the store can't handle, then index the result."""
        documents = get_documents(file_path, self.llm)
        filtered_documents = filter_complex_metadata(documents)
        self.vector_store.add_documents(filtered_documents)

    # TODO: store a hash of each ingested file in the database
    # TODO: before loading a file into the vector store, append the doc
    #       only if its hash is not already present in our database
    # TODO: avoid duplicating embeddings

    def serve(self):
        # Bug fix: was declared without `self`, so calling it on an
        # instance raised TypeError. Still a stub to be implemented.
        pass


if __name__ == "__main__":
file_path = Path(__file__).parent.parent.parent / "data"
rag = RAG()

for file in file_path.iterdir():
if file.is_file():
rag.load_file(file)
119 changes: 119 additions & 0 deletions notebooks/docs_loader.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import sys\n",
"\n",
"current_directory = os.getcwd()\n",
"parent_directory = os.path.dirname(current_directory)\n",
"sys.path.append(parent_directory)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from langchain.document_loaders import PyPDFLoader\n",
"from langchain.text_splitter import RecursiveCharacterTextSplitter\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"loader = PyPDFLoader(f\"{parent_directory}/data/Cheat sheet entretien.pdf\", extract_images=True)\n",
"documents = loader.load()\n",
"text_splitter = RecursiveCharacterTextSplitter(\n",
" separators=[\"\\n\\n\", \"\\n\", \".\"], chunk_size=1500, chunk_overlap=100\n",
")\n",
"texts = text_splitter.split_documents(documents)\n",
"\n",
"with open('../data/local_documents.json', 'w') as f:\n",
" for doc in texts:\n",
" f.write(doc.json() + '\\n')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"from langchain.docstore.document import Document\n",
"\n",
"with open('../data/local_documents.json', 'r') as f:\n",
" json_data = [json.loads(line) for line in f]\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"documents = []\n",
"\n",
"for r in json_data:\n",
" document = Document(page_content=r[\"page_content\"], metadata=r[\"metadata\"])\n",
" documents.append(document)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"documents"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "skaff-rag-accelerator",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Loading

0 comments on commit 4451414

Please sign in to comment.