Skip to content

Commit

Permalink
modifs memory
Browse files Browse the repository at this point in the history
  • Loading branch information
sarah-lauzeral committed Dec 20, 2023
1 parent fe4944d commit 960d377
Show file tree
Hide file tree
Showing 15 changed files with 421 additions and 59 deletions.
5 changes: 5 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,9 @@ export PYTHONPATH="/Users/sarah.lauzeral/Library/CloudStorage/GoogleDrive-sarah.
python "/Users/sarah.lauzeral/Library/CloudStorage/[email protected]/Mon Drive/internal_projects/skaff-rag-accelerator/backend/main.py"
```

- comment mettre des docs dans le chatbot
- comment lancer l'API
- gestion de la config
- écrire des helpers de co, pour envoyer des messages...

</div>
18 changes: 12 additions & 6 deletions backend/chatbot.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,23 @@
from langchain.chains import ConversationalRetrievalChain, LLMChain
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
from langchain.chat_models.base import SystemMessage, HumanMessage
from langchain.prompts import PromptTemplate, ChatPromptTemplate, HumanMessagePromptTemplate
from langchain.chat_models.base import HumanMessage, SystemMessage
from langchain.prompts import (
ChatPromptTemplate,
HumanMessagePromptTemplate,
PromptTemplate,
)
from langchain.vectorstores import VectorStore

from backend.config_renderer import get_config
from backend.rag_components.chat_message_history import get_conversation_buffer_memory
from backend.rag_components.embedding import get_embedding_model
from backend.rag_components.llm import get_llm_model
from backend.rag_components.vector_store import get_vector_store
from backend.rag_components.chat_message_history import get_conversation_buffer_memory


def get_response(answer_chain: ConversationalRetrievalChain, query: str) -> str:
    """Run the query through the answer chain and return its response.

    Args:
        answer_chain: Chain object exposing ``run(query)``.
        query: The user's question.

    Returns:
        The value returned by ``answer_chain.run(query)``.
    """
    # Call the chain exactly once: every call triggers an LLM request (and a
    # memory write when the chain has memory), so a second, discarded call
    # would double the cost and duplicate the stored conversation turn.
    return answer_chain.run(query)


Expand Down Expand Up @@ -69,14 +74,15 @@ def get_answer_chain(llm, docsearch: VectorStore, memory) -> ConversationalRetri


if __name__ == "__main__":
    # Manual smoke test: wire the whole RAG pipeline together and ask a
    # single question using a fixed chat id.
    chat_id = "test"
    config = get_config()
    llm = get_llm_model(config)
    embeddings = get_embedding_model(config)
    # Both factories take their extra argument explicitly (config / chat id).
    vector_store = get_vector_store(embeddings, config)
    memory = get_conversation_buffer_memory(config, chat_id)
    answer_chain = get_answer_chain(llm, vector_store, memory)

    prompt = "Give me the top 5 bilionnaires in france based on their worth in order of decreasing net worth"
    response = get_response(answer_chain, prompt)
    print("Prompt: ", prompt)
    print("Response: ", response)
31 changes: 21 additions & 10 deletions backend/document_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,18 @@
from typing import List

from langchain.chains import LLMChain
from langchain.chat_models.base import BaseChatModel
from langchain.prompts import PromptTemplate
from langchain.vectorstores import VectorStore
from langchain.vectorstores.utils import filter_complex_metadata
from langchain.chat_models.base import BaseChatModel


def load_document(file_path: Path, llm: BaseChatModel, vector_store: VectorStore):
    """Load a file into documents and index them in the vector store.

    Args:
        file_path: Path of the file to ingest.
        llm: Chat model passed to ``get_documents`` to pick a loader.
        vector_store: Store that receives the loaded documents.

    Returns:
        The metadata-filtered documents that were added to the store.
    """
    documents = get_documents(file_path, llm)
    # Index the filtered documents, not the raw ones: the filtering step
    # exists so the store only receives metadata it can handle.
    filtered_documents = filter_complex_metadata(documents)
    vector_store.add_documents(filtered_documents)
    return filtered_documents


def get_documents(file_path: Path, llm: BaseChatModel):
file_extension = file_path.suffix
loader_class_name = get_best_loader(file_extension, llm)
Expand All @@ -25,19 +27,24 @@ def get_documents(file_path: Path, llm: BaseChatModel):
loader = loader_class(str(file_path))
return loader.load()


def get_loader_class(loader_class_name: str):
    """Return the class named ``loader_class_name`` from ``langchain.document_loaders``.

    Raises:
        AttributeError: If the module exposes no attribute with that name.
    """
    # Imported lazily so the loaders package is only loaded when needed.
    from langchain import document_loaders as loaders_module

    return getattr(loaders_module, loader_class_name)


def get_best_loader(file_extension: str, llm: BaseChatModel):
    """Ask the LLM which available loader class best fits a file extension.

    Args:
        file_extension: Extension of the file to load, as given by the caller.
        llm: Chat model used to answer the question.

    Returns:
        The chain's ``loader_class_name`` output. The prompt asks the model
        to answer with a single class name, or "None" when nothing fits.
    """
    loaders = get_loaders()
    prompt = PromptTemplate(
        input_variables=["file_extension", "loaders"],
        template="""
Among the following loaders, which is the best to load a "{file_extension}" file? Only give me one the class name without any other special characters. If no relevant loader is found, respond "None".
Loaders: {loaders}
""",
    )
    chain = LLMChain(llm=llm, prompt=prompt, output_key="loader_class_name")

    return chain({"file_extension": file_extension, "loaders": loaders})["loader_class_name"]
Expand All @@ -52,22 +59,26 @@ def get_loaders() -> List[str]:
loaders.append(obj.__name__)
return loaders


if __name__ == "__main__":
    # Manual smoke test: ingest one local CSV file into the vector store.
    from pathlib import Path

    from backend.config_renderer import get_config
    from backend.document_loader import get_documents
    from backend.rag_components.embedding import get_embedding_model
    from backend.rag_components.llm import get_llm_model
    from backend.rag_components.vector_store import get_vector_store

    config = get_config()
    llm = get_llm_model(config)
    embeddings = get_embedding_model(config)
    # get_vector_store takes the config explicitly as its second argument.
    vector_store = get_vector_store(embeddings, config)

    document = load_document(
        file_path=Path(
            "/Users/alexis.vialaret/vscode_projects/skaff-rag-accelerator/data/billionaires_csv.csv"
        ),
        llm=llm,
        vector_store=vector_store,
    )
    print(document)
14 changes: 12 additions & 2 deletions backend/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from jose import JWTError, jwt

import backend.document_store as document_store
from database.database import Database
from backend.document_store import StorageBackend
from backend.model import Doc, Message
from backend.user_management import (
Expand All @@ -20,6 +19,7 @@
get_user,
user_exists,
)
from database.database import Database

app = FastAPI()

Expand Down Expand Up @@ -107,6 +107,7 @@ async def user_me(current_user: User = Depends(get_current_user)) -> User:
### Chat ###
############################################


@app.post("/chat/new")
async def chat_new(current_user: User = Depends(get_current_user)) -> dict:
chat_id = str(uuid4())
Expand All @@ -119,13 +120,16 @@ async def chat_new(current_user: User = Depends(get_current_user)) -> dict:
)
return {"chat_id": chat_id}


@app.post("/chat/{chat_id}/user_message")
async def chat_prompt(message: Message, current_user: User = Depends(get_current_user)) -> dict:
with Database() as connection:
connection.query(
"INSERT INTO message (id, timestamp, chat_id, sender, content) VALUES (?, ?, ?, ?, ?)",
(message.id, message.timestamp, message.chat_id, message.sender, message.content),
)

# TODO: generate the LLM response

model_response = Message(
id=str(uuid4()),
Expand All @@ -138,7 +142,13 @@ async def chat_prompt(message: Message, current_user: User = Depends(get_current
with Database() as connection:
connection.query(
"INSERT INTO message (id, timestamp, chat_id, sender, content) VALUES (?, ?, ?, ?, ?)",
(model_response.id, model_response.timestamp, model_response.chat_id, model_response.sender, model_response.content),
(
model_response.id,
model_response.timestamp,
model_response.chat_id,
model_response.sender,
model_response.content,
),
)
return {"message": model_response}

Expand Down
3 changes: 3 additions & 0 deletions backend/model.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
from datetime import datetime
from uuid import uuid4

from langchain.docstore.document import Document
from pydantic import BaseModel


class Message(BaseModel):
    """A single chat message, validated by pydantic.

    All fields are plain strings; no format is enforced at this level.
    """

    id: str  # identifier of the message
    timestamp: str  # time of the message, carried as a string
    chat_id: str  # identifier of the chat the message belongs to
    sender: str  # who sent the message
    content: str  # text of the message


class Doc(BaseModel):
"""Represents a document with content and associated metadata."""

Expand Down
74 changes: 62 additions & 12 deletions backend/rag_components/chat_message_history.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,67 @@
from langchain.memory import chat_message_histories
import json
import os
from datetime import datetime
from typing import Any
from uuid import uuid4

from langchain.memory import ConversationBufferWindowMemory
from langchain.memory.chat_message_histories import SQLChatMessageHistory
from langchain.memory.chat_message_histories.sql import DefaultMessageConverter
from langchain.schema import BaseMessage
from langchain.schema.messages import BaseMessage, _message_to_dict
from sqlalchemy import Column, DateTime, Text
from sqlalchemy.orm import declarative_base

TABLE_NAME = "message"

def get_chat_message_history(config):
    """Instantiate the chat message history class named in the config.

    The class is looked up on ``chat_message_histories`` using
    ``config["chat_message_history_config"]["source"]``. Only config entries
    whose key is one of the class's declared ``__fields__`` are forwarded as
    constructor arguments.

    NOTE(review): a function with the same name taking ``chat_id`` is defined
    further down in this file — confirm which one is meant to remain.
    """
    spec = getattr(chat_message_histories, config["chat_message_history_config"]["source"])
    # Drop config keys the target class does not declare as fields.
    kwargs = {
        key: value for key, value in config["chat_message_history_config"].items() if key in spec.__fields__.keys()
    }
    return spec(**kwargs)

def get_conversation_buffer_memory(config, chat_id):
    """Build a windowed conversation memory for one chat.

    Args:
        config: Configuration mapping; the window length is read from
            ``config["chat_message_history_config"]["window_size"]``.
        chat_id: Identifier of the chat whose history backs the memory.

    Returns:
        A ``ConversationBufferWindowMemory`` exposing the history under the
        ``"chat_history"`` key as message objects.
    """
    return ConversationBufferWindowMemory(
        memory_key="chat_history",
        chat_memory=get_chat_message_history(chat_id),
        return_messages=True,
        k=config["chat_message_history_config"]["window_size"],
    )


def get_chat_message_history(chat_id):
    """Return the SQL-backed message history for one chat.

    Args:
        chat_id: Identifier of the chat; used as the history's session id and
            stored in the ``chat_id`` column.

    Returns:
        A ``SQLChatMessageHistory`` bound to the ``TABLE_NAME`` table and
        using ``CustomMessageConverter`` to map messages to rows.

    Raises:
        RuntimeError: If the ``DATABASE_CONNECTION_STRING`` environment
            variable is missing or empty.
    """
    connection_string = os.environ.get("DATABASE_CONNECTION_STRING")
    if not connection_string:
        # Fail early with an actionable message instead of handing ``None``
        # to the database layer.
        raise RuntimeError(
            "DATABASE_CONNECTION_STRING environment variable is not set"
        )
    return SQLChatMessageHistory(
        session_id=chat_id,
        connection_string=connection_string,
        table_name=TABLE_NAME,
        session_id_field_name="chat_id",
        custom_message_converter=CustomMessageConverter(table_name=TABLE_NAME),
    )


Base = declarative_base()


class CustomMessage(Base):
    """SQLAlchemy model for rows of the table named by ``TABLE_NAME``."""

    __tablename__ = TABLE_NAME

    id = Column(
        Text, primary_key=True, default=lambda: str(uuid4())
    )  # primary key; defaults to a random UUID string when none is given
    timestamp = Column(DateTime)  # time of the message
    chat_id = Column(Text)  # chat (session) the message belongs to
    sender = Column(Text)  # filled with the message type by the converter
    content = Column(Text)  # text content of the message
    message = Column(Text)  # JSON dump of the full message dict


class CustomMessageConverter(DefaultMessageConverter):
    """Message converter that stores messages as ``CustomMessage`` rows."""

    def to_sql_model(self, message: BaseMessage, session_id: str) -> Any:
        """Convert a message into a ``CustomMessage`` row.

        ``message.content`` is expected to be a JSON object with the keys
        ``"id"``, ``"timestamp"`` (formatted ``%Y-%m-%d %H:%M:%S.%f``) and
        ``"content"``.

        Args:
            message: Message to persist.
            session_id: Chat identifier, stored in the ``chat_id`` column.

        Returns:
            The populated ``CustomMessage`` instance.

        Raises:
            json.JSONDecodeError: If ``message.content`` is not valid JSON.
            KeyError: If an expected key is missing from the decoded content.
            ValueError: If the timestamp does not match the expected format.
        """
        # No debug printing here: message content is user data and must not
        # be echoed to stdout on every write.
        sub_message = json.loads(message.content)
        return CustomMessage(
            id=sub_message["id"],
            timestamp=datetime.strptime(sub_message["timestamp"], "%Y-%m-%d %H:%M:%S.%f"),
            chat_id=session_id,
            sender=message.type,
            content=sub_message["content"],
            # Keep the whole serialized message alongside the split-out columns.
            message=json.dumps(_message_to_dict(message)),
        )

    def get_sql_model_class(self) -> Any:
        """Return the SQLAlchemy model class used for storage."""
        return CustomMessage
4 changes: 1 addition & 3 deletions backend/rag_components/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@ def get_embedding_model(config):
spec = getattr(embeddings, config["embedding_model_config"]["model_source"])
all_config_field = {**config["embedding_model_config"], **config["embedding_provider_config"]}
kwargs = {
key: value
for key, value in all_config_field.items()
if key in spec.__fields__.keys()
key: value for key, value in all_config_field.items() if key in spec.__fields__.keys()
}
return spec(**kwargs)
4 changes: 1 addition & 3 deletions backend/rag_components/vector_store.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
import inspect

from config_renderer import get_config
from langchain import vectorstores


def get_vector_store(embedding_model):
config = get_config()
def get_vector_store(embedding_model, config):
vector_store_spec = getattr(vectorstores, config["vector_store_provider"]["model_source"])
all_config_field = config["vector_store_provider"]

Expand Down
3 changes: 2 additions & 1 deletion database/database_init.sql
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,11 @@ CREATE TABLE IF NOT EXISTS "chat" (

-- Chat messages; each row belongs to one chat.
CREATE TABLE IF NOT EXISTS "message" (
    "id" TEXT PRIMARY KEY,
    -- Declared once, as DATETIME, matching the DateTime column of the ORM model.
    "timestamp" DATETIME,
    "chat_id" TEXT,
    "sender" TEXT,
    "content" TEXT,
    -- JSON dump of the full message.
    "message" TEXT,
    FOREIGN KEY ("chat_id") REFERENCES "chat" ("id")
);

Expand Down
1 change: 0 additions & 1 deletion frontend/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,4 +37,3 @@
)

chat()

Loading

0 comments on commit 960d377

Please sign in to comment.