Skip to content

Commit

Permalink
rag component
Browse files Browse the repository at this point in the history
  • Loading branch information
sarah-lauzeral committed Dec 21, 2023
1 parent 909aba2 commit 3343a6f
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 50 deletions.
2 changes: 1 addition & 1 deletion backend/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ embedding_model_config:

vector_store_provider:
model_source: Chroma
persist_directory: database/
persist_directory: vector_database/

chat_message_history_config:
source: ChatMessageHistory
Expand Down
17 changes: 0 additions & 17 deletions backend/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,23 +133,6 @@ async def chat_prompt(message: Message, current_user: User = Depends(get_current

config = get_config()

rag_on_file = False
file_path_str = f"{Path(__file__).parent.parent}/data/billionaires_csv.csv"
if rag_on_file:
docs_to_store = Path(file_path_str)
else:
from langchain.document_loaders import CSVLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter

loader = CSVLoader(file_path_str)
documents = loader.load()
text_splitter = RecursiveCharacterTextSplitter(
separators=["\n\n", "\n"], chunk_size=1500, chunk_overlap=100
)
docs_to_store = text_splitter.split_documents(documents)

response = generate_response(docs_to_store, config, message)

model_response = Message(
id=str(uuid4()),
timestamp=datetime.now().isoformat(),
Expand Down
34 changes: 2 additions & 32 deletions backend/rag_components/document_loader.py
Original file line number Diff line number Diff line change
@@ -1,46 +1,16 @@
import inspect
from pathlib import Path
from typing import List, Union
from typing import List

from langchain.chains import LLMChain
from langchain.chat_models.base import BaseChatModel
from langchain.docstore.document import Document
from langchain.prompts import PromptTemplate
from langchain.vectorstores import VectorStore
from langchain.vectorstores.utils import filter_complex_metadata

from backend.chatbot import get_answer_chain, get_response
from backend.rag_components.chat_message_history import get_conversation_buffer_memory
from backend.rag_components.embedding import get_embedding_model
from backend.rag_components.llm import get_llm_model
from backend.rag_components.vector_store import get_vector_store


def generate_response(file_path: Path, config, input_query):
    """Ingest *file_path* into the vector store and answer *input_query*.

    Builds the LLM, embedding model and vector store from *config*,
    stores the file's documents, then runs the retrieval chain — with
    conversation memory keyed by ``input_query.chat_id`` — against
    ``input_query.content`` and returns the chain's answer.
    """
    model = get_llm_model(config)
    embedding_model = get_embedding_model(config)
    store = get_vector_store(embedding_model, config)
    store_documents(file_path, model, store)
    buffer_memory = get_conversation_buffer_memory(config, input_query.chat_id)
    chain = get_answer_chain(model, store, buffer_memory)
    return get_response(chain, input_query.content)


def store_documents(
    data_to_store: Union[Path, Document], llm: BaseChatModel, vector_store: VectorStore
):
    """Add *data_to_store* to *vector_store*.

    A ``Path`` argument is first loaded into documents (loader chosen
    via *llm*) and stripped of complex metadata before storage; any
    other argument is assumed to be pre-built documents and is stored
    as-is.
    """
    if not isinstance(data_to_store, Path):
        vector_store.add_documents(data_to_store)
        return
    docs = get_documents(data_to_store, llm)
    vector_store.add_documents(filter_complex_metadata(docs))


def get_documents(file_path: Path, llm: BaseChatModel):
file_extension = file_path.suffix
loader_class_name = get_best_loader(file_extension, llm)
print(f"loader selected {loader_class_name} for {file_path}")

if loader_class_name == "None":
raise Exception(f"No loader found for {file_extension} files.")
Expand Down
48 changes: 48 additions & 0 deletions backend/rag_components/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
from pathlib import Path
from typing import List

from langchain.docstore.document import Document
from langchain.vectorstores.utils import filter_complex_metadata

from backend.config_renderer import get_config
from backend.rag_components.document_loader import get_documents
from backend.rag_components.embedding import get_embedding_model
from backend.rag_components.llm import get_llm_model
from backend.rag_components.vector_store import get_vector_store


class RAG:
    """Facade wiring the config-driven LLM, embeddings and vector store.

    On construction, reads the application config and builds the LLM,
    the embedding model and the vector store from it; documents and
    files can then be ingested into the vector store.
    """

    def __init__(self):
        self.config = get_config()
        self.llm = get_llm_model(self.config)
        self.embeddings = get_embedding_model(self.config)
        self.vector_store = get_vector_store(self.embeddings, self.config)

    def generate_response(self):
        # Fixed: was declared without `self`, so calling it on an
        # instance raised TypeError. Still an unimplemented stub.
        pass

    def load_documents(self, documents: List[Document]):
        """Store pre-built documents directly in the vector store."""
        # TODO: make document loading more robust
        # TODO: langchain agent that performs get_best_loader
        self.vector_store.add_documents(documents)

    def load_file(self, file_path: Path):
        """Load one file, drop complex metadata, and store the documents."""
        documents = get_documents(file_path, self.llm)
        filtered_documents = filter_complex_metadata(documents)
        self.vector_store.add_documents(filtered_documents)

    # TODO: store a hash in a database for each ingested file
    # TODO: before loading a file into the vector store, only append its
    #       documents if the hash is not already present in our db
    # TODO: avoid duplicating embeddings

    def serve(self):
        # Fixed: was declared without `self`, so calling it on an
        # instance raised TypeError. Still an unimplemented stub.
        pass


if __name__ == "__main__":
    # Ingest every regular file found in the repository-level "data"
    # directory (three levels above this module).
    data_dir = Path(__file__).parent.parent.parent / "data"
    rag = RAG()
    for entry in data_dir.iterdir():
        if not entry.is_file():
            continue
        rag.load_file(entry)

0 comments on commit 3343a6f

Please sign in to comment.