From 960d3772e60a42c3b2b4b80b8486a267041c5393 Mon Sep 17 00:00:00 2001
From: Sarah LAUZERAL
Date: Wed, 20 Dec 2023 18:14:21 +0100
Subject: [PATCH] memory changes

---
 README.md                                      |   5 +
 backend/chatbot.py                             |  18 +-
 backend/document_loader.py                     |  31 ++-
 backend/main.py                                |  14 +-
 backend/model.py                               |   3 +
 .../rag_components/chat_message_history.py     |  74 ++++-
 backend/rag_components/embedding.py            |   4 +-
 backend/rag_components/vector_store.py         |   4 +-
 database/database_init.sql                     |   3 +-
 frontend/app.py                                |   1 -
 frontend/lib/auth.py                           |  32 ++-
 frontend/lib/chat.py                           |  24 +-
 notebooks/memory_tests.ipynb                   | 263 ++++++++++++++++++
 tests/test_feedback.py                         |   2 +-
 tests/test_users.py                            |   2 +-
 15 files changed, 421 insertions(+), 59 deletions(-)
 create mode 100644 notebooks/memory_tests.ipynb

diff --git a/README.md b/README.md
index aed26bc..49d1697 100644
--- a/README.md
+++ b/README.md
@@ -10,4 +10,9 @@ export PYTHONPATH="/Users/sarah.lauzeral/Library/CloudStorage/GoogleDrive-sarah.
 python "/Users/sarah.lauzeral/Library/CloudStorage/GoogleDrive-sarah.lauzeral@artefact.com/Mon Drive/internal_projects/skaff-rag-accelerator/backend/main.py"
 ```
 
+- how to add documents to the chatbot
+- how to launch the API
+- config management
+- write connection helpers, e.g. for sending messages...
+
diff --git a/backend/chatbot.py b/backend/chatbot.py
index da03c70..28c3988 100644
--- a/backend/chatbot.py
+++ b/backend/chatbot.py
@@ -1,18 +1,22 @@
 from langchain.chains import ConversationalRetrievalChain, LLMChain
 from langchain.chains.combine_documents.stuff import StuffDocumentsChain
-from langchain.chat_models.base import SystemMessage, HumanMessage
-from langchain.prompts import PromptTemplate, ChatPromptTemplate, HumanMessagePromptTemplate
+from langchain.chat_models.base import HumanMessage, SystemMessage
+from langchain.prompts import (
+    ChatPromptTemplate,
+    HumanMessagePromptTemplate,
+    PromptTemplate,
+)
 from langchain.vectorstores import VectorStore
 
 from backend.config_renderer import get_config
+from backend.rag_components.chat_message_history import get_conversation_buffer_memory
 from backend.rag_components.embedding import get_embedding_model
 from backend.rag_components.llm import get_llm_model
 from backend.rag_components.vector_store import get_vector_store
-from backend.rag_components.chat_message_history import get_conversation_buffer_memory
 
 
 def get_response(answer_chain: ConversationalRetrievalChain, query: str) -> str:
     """Processes the given query through the answer chain and returns the formatted response."""
     return answer_chain.run(query)
@@ -69,14 +73,15 @@ def get_answer_chain(llm, docsearch: VectorStore, memory) -> ConversationalRetri
 
 
 if __name__ == "__main__":
+    chat_id = "test"
     config = get_config()
     llm = get_llm_model(config)
     embeddings = get_embedding_model(config)
-    vector_store = get_vector_store(embeddings)
-    memory = get_conversation_buffer_memory(config)
+    vector_store = get_vector_store(embeddings, config)
+    memory = get_conversation_buffer_memory(config, chat_id)
     answer_chain = get_answer_chain(llm, vector_store, memory)
 
     prompt = "Give me the top 5 billionaires in France based on their worth, in order of decreasing net worth"
     response = get_response(answer_chain, prompt)
-    print("Prompt :", prompt)
+    print("Prompt: ", prompt)
     print("Response: ", response)
diff --git a/backend/document_loader.py b/backend/document_loader.py
index 68964c0..a960a7f 100644
--- a/backend/document_loader.py
+++ b/backend/document_loader.py
@@ -4,16 +4,18 @@ from typing import List
 from langchain.chains import LLMChain
+from langchain.chat_models.base import BaseChatModel
 from langchain.prompts import PromptTemplate
 from langchain.vectorstores import VectorStore
 from langchain.vectorstores.utils import filter_complex_metadata
-from langchain.chat_models.base import BaseChatModel
+
 
 def load_document(file_path: Path, llm: BaseChatModel, vector_store: VectorStore):
     documents = get_documents(file_path, llm)
     filtered_documents = filter_complex_metadata(documents)
     vector_store.add_documents(filtered_documents)
 
+
 def get_documents(file_path: Path, llm: BaseChatModel):
     file_extension = file_path.suffix
     loader_class_name = get_best_loader(file_extension, llm)
@@ -25,19 +27,24 @@ def get_documents(file_path: Path, llm: BaseChatModel):
     loader = loader_class(str(file_path))
     return loader.load()
 
+
 def get_loader_class(loader_class_name: str):
     import langchain.document_loaders
 
     loader_class = getattr(langchain.document_loaders, loader_class_name)
     return loader_class
 
+
 def get_best_loader(file_extension: str, llm: BaseChatModel):
     loaders = get_loaders()
-    prompt = PromptTemplate(input_variables=["file_extension", "loaders"], template="""
+    prompt = PromptTemplate(
+        input_variables=["file_extension", "loaders"],
+        template="""
     Among the following loaders, which is the best to load a "{file_extension}" file? Only give me the class name without any other special characters. If no relevant loader is found, respond "None".
 
     Loaders: {loaders}
-    """)
+    """,
+    )
 
     chain = LLMChain(llm=llm, prompt=prompt, output_key="loader_class_name")
     return chain({"file_extension": file_extension, "loaders": loaders})["loader_class_name"]
@@ -52,22 +59,26 @@ def get_loaders() -> List[str]:
             loaders.append(obj.__name__)
     return loaders
 
+
 if __name__ == "__main__":
     from pathlib import Path
+
     from backend.config_renderer import get_config
-    from backend.rag_components.llm import get_llm_model
+    from backend.document_loader import get_documents
     from backend.rag_components.embedding import get_embedding_model
+    from backend.rag_components.llm import get_llm_model
     from backend.rag_components.vector_store import get_vector_store
-    from backend.document_loader import get_documents
 
     config = get_config()
     llm = get_llm_model(config)
     embeddings = get_embedding_model(config)
-    vector_store = get_vector_store(embeddings)
-    
+    vector_store = get_vector_store(embeddings, config)
+
     document = load_document(
-        file_path=Path("/Users/alexis.vialaret/vscode_projects/skaff-rag-accelerator/data/billionaires_csv.csv"),
-        llm=llm,
-        vector_store=vector_store
+        file_path=Path(
+            "/Users/alexis.vialaret/vscode_projects/skaff-rag-accelerator/data/billionaires_csv.csv"
+        ),
+        llm=llm,
+        vector_store=vector_store,
     )
-    print(document)
\ No newline at end of file
+    print(document)
diff --git a/backend/main.py b/backend/main.py
index e6187ed..cef545a 100644
--- a/backend/main.py
+++ b/backend/main.py
@@ -7,7 +7,6 @@
 from jose import JWTError, jwt
 
 import backend.document_store as document_store
-from database.database import Database
 from backend.document_store import StorageBackend
 from backend.model import Doc, Message
 from backend.user_management import (
@@ -20,6 +19,7 @@
     get_user,
     user_exists,
 )
+from database.database import Database
 
 app = FastAPI()
 
@@ -107,6 +107,7 @@ async def user_me(current_user: User = Depends(get_current_user)) -> User:
 ### Chat ###
 ############################################
 
+
 @app.post("/chat/new")
 async def chat_new(current_user: User = Depends(get_current_user)) -> dict:
@@ -119,6 +120,7 @@ async def chat_new(current_user: User = Depends(get_current_user)) -> dict:
     )
     return {"chat_id": chat_id}
 
+
 @app.post("/chat/{chat_id}/user_message")
 async def chat_prompt(message: Message, current_user: User = Depends(get_current_user)) -> dict:
     with Database() as connection:
@@ -126,6 +128,8 @@ async def chat_prompt(message: Message, current_user: User = Depends(get_current
             "INSERT INTO message (id, timestamp, chat_id, sender, content) VALUES (?, ?, ?, ?, ?)",
             (message.id, message.timestamp, message.chat_id, message.sender, message.content),
         )
+
+    # TODO: generate the LLM's response here
     model_response = Message(
         id=str(uuid4()),
@@ -138,7 +142,13 @@ async def chat_prompt(message: Message, current_user: User = Depends(get_current
     with Database() as connection:
         connection.query(
             "INSERT INTO message (id, timestamp, chat_id, sender, content) VALUES (?, ?, ?, ?, ?)",
-            (model_response.id, model_response.timestamp, model_response.chat_id, model_response.sender, model_response.content),
+            (
+                model_response.id,
+                model_response.timestamp,
+                model_response.chat_id,
+                model_response.sender,
+                model_response.content,
+            ),
         )
     return {"message": model_response}
diff --git a/backend/model.py b/backend/model.py
index 6c49ec1..5e24a4f 100644
--- a/backend/model.py
+++ b/backend/model.py
@@ -1,8 +1,10 @@
 from datetime import datetime
 from uuid import uuid4
+
 from langchain.docstore.document import Document
 from pydantic import BaseModel
 
+
 class Message(BaseModel):
     id: str
     timestamp: str
@@ -10,6 +12,7 @@ class Message(BaseModel):
     sender: str
     content: str
 
+
 class Doc(BaseModel):
     """Represents a document with content and associated metadata."""
 
diff --git a/backend/rag_components/chat_message_history.py b/backend/rag_components/chat_message_history.py
index a04fb78..e6e6774 100644
--- a/backend/rag_components/chat_message_history.py
+++ b/backend/rag_components/chat_message_history.py
@@ -1,17 +1,65 @@
-from langchain.memory import chat_message_histories
+import json
+import os
+from datetime import datetime
+from typing import Any
+from uuid import uuid4
+
 from langchain.memory import ConversationBufferWindowMemory
+from langchain.memory.chat_message_histories import SQLChatMessageHistory
+from langchain.memory.chat_message_histories.sql import DefaultMessageConverter
+from langchain.schema.messages import BaseMessage, _message_to_dict
+from sqlalchemy import Column, DateTime, Text
+from sqlalchemy.orm import declarative_base
+
+TABLE_NAME = "message"
 
-def get_chat_message_history(config):
-    spec = getattr(chat_message_histories, config["chat_message_history_config"]["source"])
-    kwargs = {
-        key: value for key, value in config["chat_message_history_config"].items() if key in spec.__fields__.keys()
-    }
-    return spec(**kwargs)
 
-def get_conversation_buffer_memory(config):
+def get_conversation_buffer_memory(config, chat_id):
     return ConversationBufferWindowMemory(
-        memory_key="chat_history", 
-        chat_memory=get_chat_message_history(config), 
+        memory_key="chat_history",
+        chat_memory=get_chat_message_history(chat_id),
         return_messages=True,
-        k=config["chat_message_history_config"]["window_size"]
-    )
\ No newline at end of file
+        k=config["chat_message_history_config"]["window_size"],
+    )
+
+
+def get_chat_message_history(chat_id):
+    return SQLChatMessageHistory(
+        session_id=chat_id,
+        connection_string=os.environ.get("DATABASE_CONNECTION_STRING"),
+        table_name=TABLE_NAME,
+        session_id_field_name="chat_id",
+        custom_message_converter=CustomMessageConverter(table_name=TABLE_NAME),
+    )
+
+
+Base = declarative_base()
+
+
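+# CustomMessage mirrors the "message" table from database/database_init.sql: parsed fields
+# become queryable columns, while `message` keeps the raw serialized LangChain payload as JSON.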
+class CustomMessage(Base):
+    __tablename__ = TABLE_NAME
+
+    id = Column(Text, primary_key=True, default=lambda: str(uuid4()))
+    timestamp = Column(DateTime)
+    chat_id = Column(Text)
+    sender = Column(Text)
+    content = Column(Text)
+    message = Column(Text)
+
+
+class CustomMessageConverter(DefaultMessageConverter):
+    def to_sql_model(self, message: BaseMessage, session_id: str) -> Any:
+        sub_message = json.loads(message.content)
+        return CustomMessage(
+            id=sub_message["id"],
+            timestamp=datetime.strptime(sub_message["timestamp"], "%Y-%m-%d %H:%M:%S.%f"),
+            chat_id=session_id,
+            sender=message.type,
+            content=sub_message["content"],
+            message=json.dumps(_message_to_dict(message)),
+        )
+
+    def get_sql_model_class(self) -> Any:
+        return CustomMessage
diff --git a/backend/rag_components/embedding.py b/backend/rag_components/embedding.py
index 346d2b7..dab7715 100644
--- a/backend/rag_components/embedding.py
+++ b/backend/rag_components/embedding.py
@@ -5,8 +5,6 @@ def get_embedding_model(config):
     spec = getattr(embeddings, config["embedding_model_config"]["model_source"])
     all_config_field = {**config["embedding_model_config"], **config["embedding_provider_config"]}
     kwargs = {
-        key: value
-        for key, value in all_config_field.items()
-        if key in spec.__fields__.keys()
+        key: value for key, value in all_config_field.items() if key in spec.__fields__.keys()
     }
     return spec(**kwargs)
diff --git a/backend/rag_components/vector_store.py b/backend/rag_components/vector_store.py
index 823bbf0..63db85a 100644
--- a/backend/rag_components/vector_store.py
+++ b/backend/rag_components/vector_store.py
@@ -1,11 +1,9 @@
 import inspect
 
-from config_renderer import get_config
 from langchain import vectorstores
 
 
-def get_vector_store(embedding_model):
-    config = get_config()
+def get_vector_store(embedding_model, config):
     vector_store_spec = getattr(vectorstores, config["vector_store_provider"]["model_source"])
 
     all_config_field = config["vector_store_provider"]
diff --git a/database/database_init.sql b/database/database_init.sql
index d3f17c9..8a96b21 100644
--- a/database/database_init.sql
+++ b/database/database_init.sql
@@ -19,10 +19,11 @@ CREATE TABLE IF NOT EXISTS "chat" (
 
 CREATE TABLE IF NOT EXISTS "message" (
     "id" TEXT PRIMARY KEY,
-    "timestamp" TEXT,
+    "timestamp" DATETIME,
     "chat_id" TEXT,
     "sender" TEXT,
     "content" TEXT,
+    "message" TEXT,
     FOREIGN KEY ("chat_id") REFERENCES "chat" ("id")
 );
 
diff --git a/frontend/app.py b/frontend/app.py
index 6381684..bad3bda 100644
--- a/frontend/app.py
+++ b/frontend/app.py
@@ -37,4 +37,3 @@
 )
 
 chat()
-
diff --git a/frontend/lib/auth.py b/frontend/lib/auth.py
index bb23909..8880bcc 100644
--- a/frontend/lib/auth.py
+++ b/frontend/lib/auth.py
@@ -1,20 +1,22 @@
 import os
 from typing import Optional
-import requests
-from requests.sessions import Session
 from urllib.parse import urljoin
-
-import streamlit as st
+
 import extra_streamlit_components as stx
+import requests
+import streamlit as st
+from requests.sessions import Session
 
 FASTAPI_URL = os.getenv("FASTAPI_URL", "http://localhost:8000/")
 
+
 def auth() -> Optional[str]:
-    tab = stx.tab_bar(data=[
+    tab = stx.tab_bar(
+        data=[
             stx.TabBarItemData(id="Login", title="Login", description=""),
-            stx.TabBarItemData(id="Signup", title="Signup", description="")
-        ], default="Login"
+            stx.TabBarItemData(id="Signup", title="Signup", description=""),
+        ],
+        default="Login",
     )
     if tab == "Login":
         return login_form()
@@ -24,6 +26,7 @@ def auth() -> Optional[str]:
         st.error("Invalid auth mode")
         return None
 
+
 def login_form() -> tuple[bool, Optional[str]]:
     with st.form("Login"):
         username = st.text_input("Username", key="username")
@@ -68,7 +71,7 @@ def get_token(username: str, password: str) -> Optional[str]:
         return response.json()["access_token"]
     else:
         return None
-    
+
 
 def sign_up(username: str, password: str) -> bool:
     session = create_session()
@@ -77,24 +80,27 @@ def sign_up(username: str, password: str) -> bool:
         return True
     else:
         return False
-    
+
+
 def create_session() -> requests.Session:
     session = BaseUrlSession(FASTAPI_URL)
     return session
 
+
 def authenticate_session(session, bearer_token: str) -> requests.Session:
     session.headers.update({"Authorization": f"Bearer {bearer_token}"})
     return session
 
+
 class BaseUrlSession(Session):
     def __init__(self, base_url):
         super().__init__()
         self.base_url = base_url
 
     def request(self, method, url, *args, **kwargs):
-        if not self.base_url.endswith('/'):
-            self.base_url += '/'
-        if url.startswith('/'):
+        if not self.base_url.endswith("/"):
+            self.base_url += "/"
+        if url.startswith("/"):
             url = url[1:]
         url = urljoin(self.base_url, url)
-        return super().request(method, url, *args, **kwargs)
\ No newline at end of file
+        return super().request(method, url, *args, **kwargs)
diff --git a/frontend/lib/chat.py b/frontend/lib/chat.py
index 47d1c73..3d3d1d6 100644
--- a/frontend/lib/chat.py
+++ b/frontend/lib/chat.py
@@ -1,11 +1,11 @@
-from uuid import uuid4
+from dataclasses import asdict, dataclass
 from datetime import datetime
+from uuid import uuid4
 
 import streamlit as st
-
-from dataclasses import dataclass, asdict
 from streamlit_feedback import streamlit_feedback
 
+
 @dataclass
 class Message:
     sender: str
@@ -18,6 +18,7 @@ class Message:
         self.id = str(uuid4()) if self.id is None else self.id
         self.timestamp = datetime.now().isoformat() if self.timestamp is None else self.timestamp
 
+
 def chat():
     prompt = st.chat_input("Say something")
 
@@ -35,8 +36,17 @@ def chat():
     for message in st.session_state.get("messages", []):
         with st.chat_message(message.sender):
             st.write(message.content)
-    if len(st.session_state.get("messages", [])) > 0 and len(st.session_state.get("messages")) % 2 == 0:
-        streamlit_feedback(key=str(len(st.session_state.get("messages"))), feedback_type="thumbs", on_submit=lambda feedback: send_feedback(st.session_state.get("messages")[-1].id, feedback))
+    if (
+        len(st.session_state.get("messages", [])) > 0
+        and len(st.session_state.get("messages")) % 2 == 0
+    ):
+        streamlit_feedback(
+            key=str(len(st.session_state.get("messages"))),
+            feedback_type="thumbs",
+            on_submit=lambda feedback: send_feedback(
+                st.session_state.get("messages")[-1].id, feedback
+            ),
+        )
 
 
 def new_chat():
@@ -46,6 +56,7 @@ def new_chat():
     st.session_state["messages"] = []
     return response.json()["chat_id"]
 
+
 def send_prompt(message: Message):
     session = st.session_state.get("session")
     response = session.post(f"/chat/{message.chat_id}/user_message", json=asdict(message))
@@ -53,8 +64,9 @@ def send_prompt(message: Message):
     print(response.text)
     return response.json()["message"]
 
+
 def send_feedback(message_id: str, feedback: str):
     feedback = "thumbs_up" if feedback["score"] == "👍" else "thumbs_down"
     session = st.session_state.get("session")
     response = session.post(f"/feedback/{message_id}/{feedback}")
-    return response.text
\ No newline at end of file
+    return response.text
diff --git a/notebooks/memory_tests.ipynb b/notebooks/memory_tests.ipynb
new file mode 100644
index 0000000..a9e7443
--- /dev/null
+++ b/notebooks/memory_tests.ipynb
@@ -0,0 +1,263 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "%load_ext autoreload\n",
+    "%autoreload 2"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import os\n",
+    "import sys\n",
+    "\n",
+    "current_directory = os.getcwd()\n",
+    "parent_directory = os.path.dirname(current_directory)\n",
+    "sys.path.append(parent_directory)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from langchain.chains import ConversationalRetrievalChain, LLMChain\n",
+    "from langchain.chains.combine_documents.stuff import StuffDocumentsChain\n",
+    "from langchain.chat_models.base import SystemMessage, HumanMessage\n",
+    "from langchain.prompts import PromptTemplate, ChatPromptTemplate, HumanMessagePromptTemplate\n",
+    "from langchain.vectorstores import VectorStore\n",
+    "from langchain.memory.chat_message_histories import SQLChatMessageHistory\n",
+    "\n",
+    "from backend.config_renderer import get_config\n",
+    "from backend.rag_components.embedding import get_embedding_model\n",
+    "from backend.rag_components.llm import get_llm_model\n",
+    "from backend.rag_components.vector_store import get_vector_store\n",
+    "import frontend.lib.auth as auth"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def get_answer_chain(llm, docsearch: VectorStore, memory) -> ConversationalRetrievalChain:\n",
+    "    \"\"\"Returns an instance of ConversationalRetrievalChain based on the provided parameters.\"\"\"\n",
+    "    template = \"\"\"Given the conversation history and the following question, can you rephrase the user's question in its original language so that it is self-sufficient. Make sure to avoid the use of unclear pronouns.\n",
+    "\n",
+    "Chat history :\n",
+    "{chat_history}\n",
+    "Question : {question}\n",
+    "\n",
+    "Rephrased question :\n",
+    "\"\"\"\n",
+    "    condense_question_prompt = PromptTemplate.from_template(template)\n",
+    "    condense_question_chain = LLMChain(\n",
+    "        llm=llm,\n",
+    "        prompt=condense_question_prompt,\n",
+    "    )\n",
+    "\n",
+    "    messages = [\n",
+    "        SystemMessage(\n",
+    "            content=(\n",
+    "                \"\"\"As a chatbot assistant, your mission is to respond to user inquiries in a precise and concise manner based on the documents provided as input. It is essential to respond in the same language in which the question was asked. Responses must be written in a professional style and must demonstrate great attention to detail.\"\"\"\n",
+    "            )\n",
+    "        ),\n",
+    "        HumanMessage(content=\"Respond to the question taking into account the following context.\"),\n",
+    "        HumanMessagePromptTemplate.from_template(\"{context}\"),\n",
+    "        HumanMessagePromptTemplate.from_template(\"Question: {question}\"),\n",
+    "    ]\n",
+    "    system_prompt = ChatPromptTemplate(messages=messages)\n",
+    "    qa_chain = LLMChain(\n",
+    "        llm=llm,\n",
+    "        prompt=system_prompt,\n",
+    "    )\n",
+    "\n",
+    "    doc_prompt = PromptTemplate(\n",
+    "        template=\"Content: {page_content}\\nSource: {source}\",\n",
+    "        input_variables=[\"page_content\", \"source\"],\n",
+    "    )\n",
+    "\n",
+    "    final_qa_chain = StuffDocumentsChain(\n",
+    "        llm_chain=qa_chain,\n",
+    "        document_variable_name=\"context\",\n",
+    "        document_prompt=doc_prompt,\n",
+    "    )\n",
+    "\n",
+    "    return ConversationalRetrievalChain(\n",
+    "        question_generator=condense_question_chain,\n",
+    "        retriever=docsearch.as_retriever(search_kwargs={\"k\": 10}),\n",
+    "        memory=memory,\n",
+    "        combine_docs_chain=final_qa_chain,\n",
+    "        verbose=True,\n",
+    "    )\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "config = get_config()\n",
+    "llm = get_llm_model(config)\n",
+    "embeddings = get_embedding_model(config)\n",
+    "vector_store = get_vector_store(embeddings, config)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "username = \"slauzeral\"\n",
+    "password = \"test\""
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "success = auth.sign_up(username, password)\n",
+    "if success:\n",
+    "    token = auth.get_token(username, password)\n",
+    "    session = auth.create_session()\n",
+    "    auth_session = auth.authenticate_session(session, token)\n",
+    "\n",
+    "response = auth_session.post(\"/chat/new\")\n",
+    "chat_id = response.json()[\"chat_id\"]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 20,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from datetime import datetime\n",
+    "from uuid import uuid4\n",
+    "from typing import Any\n",
+    "\n",
+    "from langchain.memory.chat_message_histories.sql import DefaultMessageConverter\n",
+    "from langchain.schema import AIMessage, BaseMessage, HumanMessage, SystemMessage\n",
+    "from sqlalchemy import Column, DateTime, Integer, Text\n",
+    "from sqlalchemy.orm import declarative_base\n",
+    "from langchain.schema.messages import BaseMessage, _message_to_dict, messages_from_dict\n",
+    "import json\n",
+    "\n",
+    "Base = declarative_base()\n",
+    "\n",
+    "class CustomMessage(Base):\n",
+    "    __tablename__ = \"message_test\"\n",
+    "\n",
+    "    id = Column(Text, primary_key=True, default=lambda: str(uuid4()))  # default=lambda: str(uuid4())\n",
+    "    timestamp = Column(DateTime)\n",
+    "    chat_id = Column(Text)\n",
+    "    sender = Column(Text)\n",
+    "    content = Column(Text)\n",
+    "    message = Column(Text)\n",
+    "\n",
+    "\n",
+    "class CustomMessageConverter(DefaultMessageConverter):\n",
+    "\n",
+    "    def to_sql_model(self, message: BaseMessage, session_id: str) -> Any:\n",
+    "        sub_message = json.loads(message.content)\n",
+    "        return CustomMessage(\n",
+    "            id = sub_message[\"id\"],\n",
+    "            timestamp = datetime.strptime(sub_message[\"timestamp\"], \"%Y-%m-%d %H:%M:%S.%f\"),\n",
+    "            chat_id = session_id,\n",
+    "            sender = message.type,\n",
+    "            content = sub_message[\"content\"],\n",
+    "            message = json.dumps(_message_to_dict(message)),\n",
+    "        )\n",
+    "\n",
+    "    def get_sql_model_class(self) -> Any:\n",
+    "        return CustomMessage\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 21,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "chat_message_history = SQLChatMessageHistory(\n",
+    "    session_id=chat_id,\n",
+    "    connection_string=\"sqlite:////Users/sarah.lauzeral/Library/CloudStorage/GoogleDrive-sarah.lauzeral@artefact.com/Mon Drive/internal_projects/skaff-rag-accelerator/database/database.sqlite\",\n",
+    "    table_name=\"message_test\",\n",
+    "    session_id_field_name=\"chat_id\",\n",
+    "    custom_message_converter=CustomMessageConverter(table_name=\"message_test\"),\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 22,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "chat_message_history.add_ai_message(json.dumps({\"content\":\"Hi\", \"timestamp\":f\"{datetime.utcnow()}\", \"id\":\"764528762\"}))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 23,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "[AIMessage(content='{\"content\": \"Hi\", \"timestamp\": \"2023-12-20 16:26:23.672506\", \"id\": \"764528762\"}')]"
+      ]
+     },
+     "execution_count": 23,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "chat_message_history.messages"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "skaff-rag-accelerator",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.11.5"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/tests/test_feedback.py b/tests/test_feedback.py
index 46695ac..9993026 100644
--- a/tests/test_feedback.py
+++ b/tests/test_feedback.py
@@ -4,9 +4,9 @@
 import pytest
 from fastapi.testclient import TestClient
 
+from lib.main import app
 from database.database import Database
-from lib.main import app
 
 os.environ["TESTING"] = "True"
 client = TestClient(app)
diff --git a/tests/test_users.py b/tests/test_users.py
index 76920a7..f39e802 100644
--- a/tests/test_users.py
+++ b/tests/test_users.py
@@ -3,9 +3,9 @@
 import pytest
 from fastapi.testclient import TestClient
 
+from lib.main import app
 from database.database import Database
-from lib.main import app
 
 os.environ["TESTING"] = "True"
 client = TestClient(app)
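--
Usage sketch (illustrative, not part of the patch): a minimal example of how the
SQL-backed chat memory introduced above can be exercised. It assumes the schema
from database/database_init.sql has been applied to the database behind
DATABASE_CONNECTION_STRING; the connection string and chat id below are
placeholders.

    import json
    import os
    from datetime import datetime
    from uuid import uuid4

    # Placeholder connection string; any SQLAlchemy URL pointing at the
    # initialized database works.
    os.environ["DATABASE_CONNECTION_STRING"] = "sqlite:///database/database.sqlite"

    from backend.rag_components.chat_message_history import get_chat_message_history

    history = get_chat_message_history("demo-chat")

    # CustomMessageConverter parses message.content as JSON, so every message
    # stored through this history must carry id/timestamp/content fields.
    history.add_ai_message(
        json.dumps(
            {
                "id": str(uuid4()),
                "timestamp": str(datetime.utcnow()),
                "content": "Hi",
            }
        )
    )
    print(history.messages)

The same history object is what get_conversation_buffer_memory(config, chat_id)
wraps in a ConversationBufferWindowMemory for the answer chain.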