Merge pull request #5 from artefactory/sla/feat-load-document

Sla/feat load document

AlexisVLRT authored Dec 21, 2023
2 parents 960d377 + 3343a6f commit fa9c948
Showing 13 changed files with 199 additions and 329 deletions.
2 changes: 2 additions & 0 deletions README.md
@@ -14,5 +14,7 @@ python "/Users/sarah.lauzeral/Library/CloudStorage/GoogleDrive-sarah.lauzeral@ar
- how to launch the API
- config management
- write connection helpers, e.g. for sending messages...
- test different models
- write code snippets to explain how to load documents into the RAG (see the sketch just below)

</div>
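
One such snippet, sketched from the RAG class this PR introduces in backend/rag_components/main.py (the CSV path is illustrative):

from pathlib import Path

from backend.rag_components.main import RAG

rag = RAG()  # reads the config and wires up the LLM, embeddings and vector store
rag.load_file(Path("data/billionaires_csv.csv"))  # selects a loader, then embeds the documents
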
66 changes: 0 additions & 66 deletions backend/_logs.py

This file was deleted.

1 change: 0 additions & 1 deletion backend/chatbot.py
@@ -17,7 +17,6 @@

def get_response(answer_chain: ConversationalRetrievalChain, query: str) -> str:
"""Processes the given query through the answer chain and returns the formatted response."""
{"content": answer_chain.run(query), }
return answer_chain.run(query)
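
A hypothetical call, assuming a ConversationalRetrievalChain that has already been assembled with memory attached (as get_conversation_buffer_memory elsewhere in this PR suggests):

answer = get_response(answer_chain, "What can you tell me about this dataset?")
print(answer)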


2 changes: 1 addition & 1 deletion backend/config.yaml
@@ -23,7 +23,7 @@ embedding_model_config:

vector_store_provider:
model_source: Chroma
persist_directory: database/
persist_directory: vector_database/

chat_message_history_config:
source: ChatMessageHistory
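
How these keys are consumed is outside this hunk; a minimal sketch of a factory reading them, assuming LangChain's Chroma wrapper (the shape mirrors the get_vector_store(embeddings, config) call used in backend/rag_components/main.py):

from langchain.vectorstores import Chroma

def get_vector_store(embeddings, config):
    # Hypothetical: reads the vector_store_provider block shown above.
    provider = config["vector_store_provider"]
    return Chroma(
        embedding_function=embeddings,
        persist_directory=provider["persist_directory"],
    )
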
12 changes: 7 additions & 5 deletions backend/main.py
@@ -1,14 +1,16 @@
from datetime import datetime, timedelta
from pathlib import Path
from typing import List
from uuid import uuid4

from fastapi import Depends, FastAPI, HTTPException, status
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from jose import JWTError, jwt

import backend.document_store as document_store
from backend.document_store import StorageBackend
from backend.config_renderer import get_config
from backend.document_store import StorageBackend, store_documents
from backend.model import Doc, Message
from backend.rag_components.document_loader import generate_response
from backend.user_management import (
ALGORITHM,
SECRET_KEY,
@@ -129,14 +131,14 @@ async def chat_prompt(message: Message, current_user: User = Depends(get_current
(message.id, message.timestamp, message.chat_id, message.sender, message.content),
)

# TODO: generate the LLM response
config = get_config()

model_response = Message(
id=str(uuid4()),
timestamp=datetime.now().isoformat(),
chat_id=message.chat_id,
sender="assistant",
content=f"Unique response: {uuid4()}",
content=response,
)

with Database() as connection:
46 changes: 1 addition & 45 deletions backend/rag_components/chat_message_history.py
@@ -1,18 +1,9 @@
import json
import os
from datetime import datetime
from typing import Any
from uuid import uuid4

from langchain.memory import ConversationBufferWindowMemory
from langchain.memory.chat_message_histories import SQLChatMessageHistory
from langchain.memory.chat_message_histories.sql import DefaultMessageConverter
from langchain.schema import BaseMessage
from langchain.schema.messages import BaseMessage, _message_to_dict
from sqlalchemy import Column, DateTime, Text
from sqlalchemy.orm import declarative_base

TABLE_NAME = "message"
TABLE_NAME = "message_history"


def get_conversation_buffer_memory(config, chat_id):
@@ -29,39 +20,4 @@ def get_chat_message_history(chat_id):
session_id=chat_id,
connection_string=os.environ.get("DATABASE_CONNECTION_STRING"),
table_name=TABLE_NAME,
session_id_field_name="chat_id",
custom_message_converter=CustomMessageConverter(table_name=TABLE_NAME),
)


Base = declarative_base()


class CustomMessage(Base):
__tablename__ = TABLE_NAME

id = Column(
Text, primary_key=True, default=lambda: str(uuid4())
) # default=lambda: str(uuid4())
timestamp = Column(DateTime)
chat_id = Column(Text)
sender = Column(Text)
content = Column(Text)
message = Column(Text)


class CustomMessageConverter(DefaultMessageConverter):
def to_sql_model(self, message: BaseMessage, session_id: str) -> Any:
print(message.content)
sub_message = json.loads(message.content)
return CustomMessage(
id=sub_message["id"],
timestamp=datetime.strptime(sub_message["timestamp"], "%Y-%m-%d %H:%M:%S.%f"),
chat_id=session_id,
sender=message.type,
content=sub_message["content"],
message=json.dumps(_message_to_dict(message)),
)

def get_sql_model_class(self) -> Any:
return CustomMessage
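
With the custom converter removed, history persistence falls back to LangChain's default SQL schema. A hypothetical round-trip, assuming DATABASE_CONNECTION_STRING points at the application database (the DSN below is a placeholder):

import os

os.environ.setdefault("DATABASE_CONNECTION_STRING", "sqlite:///database/app.db")  # placeholder DSN

history = get_chat_message_history("chat-42")
history.add_user_message("Hello")  # persisted to the message_history table
history.add_ai_message("Hi, how can I help?")
print(history.messages)  # rehydrated as LangChain message objects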
backend/rag_components/document_loader.py
@@ -1,24 +1,16 @@
import inspect
from pathlib import Path
from time import sleep
from typing import List

from langchain.chains import LLMChain
from langchain.chat_models.base import BaseChatModel
from langchain.prompts import PromptTemplate
from langchain.vectorstores import VectorStore
from langchain.vectorstores.utils import filter_complex_metadata


def load_document(file_path: Path, llm: BaseChatModel, vector_store: VectorStore):
documents = get_documents(file_path, llm)
filtered_documents = filter_complex_metadata(documents)
vector_store.add_documents(filtered_documents)


def get_documents(file_path: Path, llm: BaseChatModel):
file_extension = file_path.suffix
loader_class_name = get_best_loader(file_extension, llm)
print(f"loader selected {loader_class_name} for {file_path}")

if loader_class_name == "None":
raise Exception(f"No loader found for {file_extension} files.")
@@ -64,21 +56,12 @@ def get_loaders() -> List[str]:
from pathlib import Path

from backend.config_renderer import get_config
from backend.document_loader import get_documents
from backend.rag_components.embedding import get_embedding_model
from backend.rag_components.llm import get_llm_model
from backend.rag_components.vector_store import get_vector_store
from frontend.lib.chat import Message

config = get_config()
llm = get_llm_model(config)
embeddings = get_embedding_model(config)
vector_store = get_vector_store(embeddings)

document = load_document(
file_path=Path(
"/Users/alexis.vialaret/vscode_projects/skaff-rag-accelerator/data/billionaires_csv.csv"
),
llm=llm,
vector_store=vector_store,
)
print(document)
data_to_store = Path(f"{Path(__file__).parent.parent.parent}/data/billionaires_csv.csv")
prompt = "Quelles sont les 5 plus grandes fortunes de France ?"
chat_id = "test"
input_query = Message("user", prompt, chat_id)
response = generate_response(data_to_store, config, input_query)
print(response)
48 changes: 48 additions & 0 deletions backend/rag_components/main.py
@@ -0,0 +1,48 @@
from pathlib import Path
from typing import List

from langchain.docstore.document import Document
from langchain.vectorstores.utils import filter_complex_metadata

from backend.config_renderer import get_config
from backend.rag_components.document_loader import get_documents
from backend.rag_components.embedding import get_embedding_model
from backend.rag_components.llm import get_llm_model
from backend.rag_components.vector_store import get_vector_store


class RAG:
def __init__(self):
self.config = get_config()
self.llm = get_llm_model(self.config)
self.embeddings = get_embedding_model(self.config)
self.vector_store = get_vector_store(self.embeddings, self.config)

    def generate_response(self):
        pass

def load_documents(self, documents: List[Document]):
        # TODO make load_document more robust
        # TODO LangChain agent that performs the get_best_loader step
self.vector_store.add_documents(documents)

def load_file(self, file_path: Path):
documents = get_documents(file_path, self.llm)
filtered_documents = filter_complex_metadata(documents)
self.vector_store.add_documents(filtered_documents)

    # TODO for each file -> store a hash in the database
    # TODO before loading a file into the vector store, check whether its hash is already in our db, and only add the doc to the vector store if the hash is absent
    # TODO avoid duplicating embeddings

    def serve(self):
        pass


if __name__ == "__main__":
file_path = Path(__file__).parent.parent.parent / "data"
rag = RAG()

for file in file_path.iterdir():
if file.is_file():
rag.load_file(file)
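
The de-duplication TODOs above amount to: fingerprint each file, persist the fingerprint, and only embed when the fingerprint is new. A minimal sketch, assuming the known hashes are kept in a plain set (a database table would play the same role):

import hashlib
from pathlib import Path

def file_fingerprint(file_path: Path) -> str:
    # Content hash, so renamed but otherwise identical files are still recognised.
    return hashlib.sha256(file_path.read_bytes()).hexdigest()

def load_file_if_new(rag: RAG, file_path: Path, known_hashes: set) -> bool:
    fingerprint = file_fingerprint(file_path)
    if fingerprint in known_hashes:
        return False  # already embedded; skip to avoid duplicate vectors
    rag.load_file(file_path)
    known_hashes.add(fingerprint)
    return True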
1 change: 0 additions & 1 deletion database/database_init.sql
@@ -23,7 +23,6 @@ CREATE TABLE IF NOT EXISTS "message" (
"chat_id" TEXT,
"sender" TEXT,
"content" TEXT,
"message" TEXT,
FOREIGN KEY ("chat_id") REFERENCES "chat" ("id")
);
