From de0c08a2b988fe7c66e2924e5f887fdc12049189 Mon Sep 17 00:00:00 2001
From: Alexis VIALARET
Date: Tue, 2 Jan 2024 17:44:18 +0100
Subject: [PATCH] upd: differentiated async and streaming
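
Differentiate the fully-async response path from a new thread-based
streaming path:

* rename get_response_stream to async_get_response, and add
  stream_get_response, which runs the chain in a worker thread and yields
  the tokens drained from a StreamingCallbackHandler queue
* add RAG.stream_generate_response alongside RAG.async_generate_response
* give streamed responses a stable message id, and stream a user-facing
  error message when generation fails
* switch the default LLM config from ChatVertexAI to AzureChatOpenAI
* fold the old SystemMessage content into the respond_to_question prompt
* drop the scratch script lel.py and three stale notebooks

A minimal sketch of how the new synchronous streaming path can be driven
(illustrative only; it assumes a RAG instance can be built from the repo
config alone, the way the deleted lel.py script did):

    from datetime import datetime
    from pathlib import Path
    from uuid import uuid4

    from backend.model import Message
    from backend.rag_components.rag import RAG

    rag = RAG(Path("backend/config.yaml"))
    message = Message(
        id=str(uuid4()),
        timestamp=datetime.now().isoformat(),
        chat_id=str(uuid4()),
        sender="user",
        content="Hello, how are you?",
    )

    # Tokens are yielded as the worker thread produces them, so they can be
    # printed or forwarded to an HTTP response incrementally.
    for token in rag.stream_generate_response(message):
        print(token, end="", flush=True)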
---
 backend/config.yaml                           |  33 ++-
 backend/main.py                               |  18 +-
 backend/rag_components/chain.py               |  29 +-
 backend/rag_components/prompts.py             |   4 +
 backend/rag_components/rag.py                 |  15 +-
 .../streaming_callback_handler.py             |  14 +
 lel.py                                        |  16 --
 notebooks/docs_loader.ipynb                   | 239 -----------------
 notebooks/memory_tests.ipynb                  | 252 ------------------
 notebooks/test_auth.ipynb                     | 125 ---------
 10 files changed, 85 insertions(+), 660 deletions(-)
 create mode 100644 backend/rag_components/streaming_callback_handler.py
 delete mode 100644 lel.py
 delete mode 100644 notebooks/docs_loader.ipynb
 delete mode 100644 notebooks/memory_tests.ipynb
 delete mode 100644 notebooks/test_auth.ipynb

diff --git a/backend/config.yaml b/backend/config.yaml
index abf0fe7..0f2c74c 100644
--- a/backend/config.yaml
+++ b/backend/config.yaml
@@ -1,28 +1,37 @@
+# LLMConfig: &LLMConfig
+#   source: ChatVertexAI
+#   source_config:
+#     model_name: gemini-pro
+#     temperature: 0.1
+
 LLMConfig: &LLMConfig
-  source: "ChatVertexAI"
+  source: AzureChatOpenAI
   source_config:
-    model_name: google/gemini-pro
-    temperature: 0.1
+    openai_api_type: azure
+    openai_api_key: {{ OPENAI_API_KEY }}
+    openai_api_base: https://genai-ds.openai.azure.com/
+    openai_api_version: 2023-07-01-preview
+    deployment_name: gpt4
 
 VectorStoreConfig: &VectorStoreConfig
-  source: "Chroma"
+  source: Chroma
   source_config:
-    persist_directory: "vector_database/"
+    persist_directory: vector_database/
     collection_metadata:
-      hnsw:space: "cosine"
-  retreiver_search_type: "similarity"
+      hnsw:space: cosine
+  retreiver_search_type: similarity
   retreiver_config:
     top_k: 20
     score_threshold: 0.5
-  insertion_mode: "full"
+  insertion_mode: full
 
 EmbeddingModelConfig: &EmbeddingModelConfig
-  source: "OpenAIEmbeddings"
+  source: OpenAIEmbeddings
   source_config:
-    openai_api_type: "azure"
+    openai_api_type: azure
     openai_api_key: {{ EMBEDDING_API_KEY }}
-    openai_api_base: "https://poc-openai-artefact.openai.azure.com/"
-    deployment: "embeddings"
+    openai_api_base: https://poc-openai-artefact.openai.azure.com/
+    deployment: embeddings
     chunk_size: 500
 
 DatabaseConfig: &DatabaseConfig
diff --git a/backend/main.py b/backend/main.py
index 9d3e53b..d821136 100644
--- a/backend/main.py
+++ b/backend/main.py
@@ -89,7 +89,7 @@ async def chat_prompt(message: Message, current_user: User = Depends(get_current
     rag = RAG(config=Path(__file__).parent / "config.yaml", logger=logger, context=context)
     response = rag.async_generate_response(message)
 
-    return StreamingResponse(streamed_llm_response(message.chat_id, response), media_type="text/event-stream")
+    return StreamingResponse(async_llm_response(message.chat_id, response), media_type="text/event-stream")
 
 
 @app.post("/chat/regenerate")
@@ -130,14 +130,20 @@ async def chat(chat_id: str, current_user: User = Depends(get_current_user)) ->
     return {"chat_id": chat_id, "messages": [message.model_dump() for message in messages]}
 
 
-async def streamed_llm_response(chat_id, answer_chain):
+async def async_llm_response(chat_id, answer_chain):
     full_response = ""
-    async for data in answer_chain:
-        full_response += data
-        yield data.encode("utf-8")
+    response_id = str(uuid4())
+    try:
+        async for data in answer_chain:
+            full_response += data
+            yield data.encode("utf-8")
+    except Exception as e:
+        logger.error(f"Error generating response for chat {chat_id}: {e}")
+        full_response = f"Sorry, there was an error generating a response. Please contact an administrator and tell them the following error code: {response_id}, and message: {str(e)}"
+        yield full_response.encode("utf-8")
 
     model_response = Message(
-        id=str(uuid4()),
+        id=response_id,
         timestamp=datetime.now().isoformat(),
         chat_id=chat_id,
         sender="assistant",
diff --git a/backend/rag_components/chain.py b/backend/rag_components/chain.py
index a946672..f959f8d 100644
--- a/backend/rag_components/chain.py
+++ b/backend/rag_components/chain.py
@@ -1,8 +1,9 @@
 import asyncio
+from threading import Thread
+from time import sleep
+
 from langchain.chains import ConversationalRetrievalChain, LLMChain
 from langchain.chains.combine_documents.stuff import StuffDocumentsChain
-from langchain.chains.combine_documents.reduce import ReduceDocumentsChain
-from langchain.chat_models.base import SystemMessage
 from langchain.prompts import (
     ChatPromptTemplate,
     HumanMessagePromptTemplate,
@@ -17,7 +18,7 @@
 
 
-async def get_response_stream(chain: ConversationalRetrievalChain, query: str, streaming_callback_handler) -> str:
+async def async_get_response(chain: ConversationalRetrievalChain, query: str, streaming_callback_handler) -> str:
     run = asyncio.create_task(chain.arun({"question": query}))
 
     async for token in streaming_callback_handler.aiter():
@@ -26,18 +27,32 @@
     await run
 
 
-def get_answer_chain(config: RagConfig, docsearch: VectorStore, memory, streaming_callback_handler, logging_callback_handler: LoggingCallbackHandler = None) -> ConversationalRetrievalChain:
+def stream_get_response(chain: ConversationalRetrievalChain, query: str, streaming_callback_handler) -> str:
+    thread = Thread(target=lambda chain, query: chain.run({"question": query}), args=(chain, query))
+    thread.start()
+
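+    # Drain tokens that on_llm_new_token pushes onto the handler's queue until
+    # the worker thread has finished and the queue is empty (poll every 100 ms).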
+    while thread.is_alive() or not streaming_callback_handler.queue.empty():
+        if not streaming_callback_handler.queue.empty():
+            yield streaming_callback_handler.queue.get()
+        else:
+            sleep(0.1)
+
+    thread.join()
+
+def get_answer_chain(config: RagConfig, docsearch: VectorStore, memory, streaming_callback_handler = None, logging_callback_handler: LoggingCallbackHandler = None) -> ConversationalRetrievalChain:
     callbacks = [logging_callback_handler] if logging_callback_handler is not None else []
-    
+    streaming_callback = [streaming_callback_handler] if streaming_callback_handler is not None else []
+
     condense_question_prompt = PromptTemplate.from_template(prompts.condense_history)
     condense_question_chain = LLMChain(llm=get_llm_model(config), prompt=condense_question_prompt, callbacks=callbacks)
 
     messages = [
-        SystemMessage(content=prompts.rag_system_prompt),
         HumanMessagePromptTemplate.from_template(prompts.respond_to_question),
     ]
     question_answering_prompt = ChatPromptTemplate(messages=messages)
-    streaming_llm = get_llm_model(config, callbacks=[streaming_callback_handler] + callbacks)
+    streaming_llm = get_llm_model(config, callbacks=streaming_callback + callbacks)
     question_answering_chain = LLMChain(llm=streaming_llm, prompt=question_answering_prompt, callbacks=callbacks)
 
     context_with_docs_prompt = PromptTemplate(template=prompts.document_context, input_variables=["page_content", "source"])
diff --git a/backend/rag_components/prompts.py b/backend/rag_components/prompts.py
index 98af668..1d14cd4 100644
--- a/backend/rag_components/prompts.py
+++ b/backend/rag_components/prompts.py
@@ -14,6 +14,10 @@
 """
 
 respond_to_question = """
+As a chatbot assistant, your mission is to respond to user inquiries in a precise and concise manner based on the documents provided as input.
+It is essential to respond in the same language in which the question was asked. Responses must be written in a professional style and must demonstrate great attention to detail.
+
+
 Respond to the question taking into account the following context.
 {context}
diff --git a/backend/rag_components/rag.py b/backend/rag_components/rag.py
index 420bd2f..26d1809 100644
--- a/backend/rag_components/rag.py
+++ b/backend/rag_components/rag.py
@@ -7,7 +7,7 @@
 from langchain.docstore.document import Document
 from langchain.vectorstores.utils import filter_complex_metadata
 from langchain.callbacks import AsyncIteratorCallbackHandler
-from backend.rag_components.chain import get_answer_chain, get_response_stream
+from backend.rag_components.chain import get_answer_chain, async_get_response, stream_get_response
 from langchain.indexes import SQLRecordManager, index
 from langchain.chat_models.base import BaseChatModel
 from langchain.vectorstores import VectorStore
@@ -21,6 +21,7 @@
 from backend.rag_components.embedding import get_embedding_model
 from backend.rag_components.llm import get_llm_model
 from backend.rag_components.logging_callback_handler import LoggingCallbackHandler
+from backend.rag_components.streaming_callback_handler import StreamingCallbackHandler
 from backend.rag_components.vector_store import get_vector_store
 
 
@@ -42,14 +43,22 @@ def generate_response(self, message: Message) -> str:
         loop = asyncio.get_event_loop()
         response_stream = self.async_generate_response(message)
         responses = loop.run_until_complete(self._collect_responses(response_stream))
-        return "".join(responses)
+        return "".join([str(response) for response in responses])
+
+    def stream_generate_response(self, message: Message) -> AsyncIterator[str]:
+        memory = get_conversation_buffer_memory(self.config, message.chat_id)
+        streaming_callback_handler = StreamingCallbackHandler()
+        logging_callback_handler = LoggingCallbackHandler(self.logger, context=self.context)
+        answer_chain = get_answer_chain(self.config, self.vector_store, memory, streaming_callback_handler=streaming_callback_handler, logging_callback_handler=logging_callback_handler)
+        response_stream = stream_get_response(answer_chain, message.content, streaming_callback_handler)
+        return response_stream
 
     def async_generate_response(self, message: Message) -> AsyncIterator[str]:
         memory = get_conversation_buffer_memory(self.config, message.chat_id)
         streaming_callback_handler = AsyncIteratorCallbackHandler()
         logging_callback_handler = LoggingCallbackHandler(self.logger, context=self.context)
         answer_chain = get_answer_chain(self.config, self.vector_store, memory, streaming_callback_handler=streaming_callback_handler, logging_callback_handler=logging_callback_handler)
-        response_stream = get_response_stream(answer_chain, message.content, streaming_callback_handler)
+        response_stream = async_get_response(answer_chain, message.content, streaming_callback_handler)
         return response_stream
 
     def load_file(self, file_path: Path) -> List[Document]:
diff --git a/backend/rag_components/streaming_callback_handler.py b/backend/rag_components/streaming_callback_handler.py
new file mode 100644
index 0000000..69b615a
--- /dev/null
+++ b/backend/rag_components/streaming_callback_handler.py
@@ -0,0 +1,14 @@
+from queue import Queue
+from typing import Any
+from langchain_core.callbacks.base import BaseCallbackHandler
+
+class StreamingCallbackHandler(BaseCallbackHandler):
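+    # One queue per handler instance: a class-level queue would be shared by
+    # every handler and interleave tokens from concurrent requests.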
+    def __init__(self) -> None:
+        self.queue: Queue = Queue()
+
+    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
+        """Run on new LLM token. Only available when streaming is enabled."""
+        if token is not None and token != "":
+            self.queue.put_nowait(token)
diff --git a/lel.py b/lel.py
deleted file mode 100644
index 5594aa9..0000000
--- a/lel.py
+++ /dev/null
@@ -1,16 +0,0 @@
-from pathlib import Path
-from backend.rag_components.rag import RAG
-from backend.model import Message
-
-config_directory = Path("backend/config.yaml")
-rag = RAG(config_directory)
-
-message = Message(
-    id="123",
-    timestamp="2021-06-01T12:00:00",
-    chat_id="123",
-    sender="user",
-    content="Hello, how are you?",
-)
-response = rag.generate_response(message)
-print(response)
\ No newline at end of file
diff --git a/notebooks/docs_loader.ipynb b/notebooks/docs_loader.ipynb
deleted file mode 100644
index 7c08f4e..0000000
--- a/notebooks/docs_loader.ipynb
+++ /dev/null
@@ -1,239 +0,0 @@
-{
- "cells": [
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "%load_ext autoreload\n",
-    "%autoreload 2"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "import os\n",
-    "import sys\n",
-    "\n",
-    "current_directory = os.getcwd()\n",
-    "parent_directory = os.path.dirname(current_directory)\n",
-    "sys.path.append(parent_directory)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "from backend.rag_components.main import RAG\n",
-    "from langchain.indexes import SQLRecordManager, index\n",
-    "from langchain.document_loaders.csv_loader import CSVLoader\n",
-    "from langchain.text_splitter import RecursiveCharacterTextSplitter\n",
-    "from dotenv import load_dotenv\n",
-    "load_dotenv()"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "rag = RAG()\n",
-    "rag.vector_store"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "namespace = f\"chromadb/my_docs\"\n",
-    "record_manager = SQLRecordManager(\n",
-    "    namespace, db_url=os.environ.get(\"DATABASE_CONNECTION_STRING\")\n",
-    ")\n",
-    "# point the record_manager at a table in the sql db \n",
-    "record_manager.create_schema()"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "loader = CSVLoader(f\"{parent_directory}/data/billionaires_csv.csv\")\n",
-    "documents = loader.load()\n",
-    "text_splitter = RecursiveCharacterTextSplitter(\n",
-    "    separators=[\"\\n\\n\", \"\\n\"], chunk_size=1500, chunk_overlap=100\n",
-    ")\n",
-    "texts = text_splitter.split_documents(documents)\n",
-    "texts[:5]"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "loader = CSVLoader(f\"{parent_directory}/data/billionaires_csv_bis.csv\")\n",
-    "documents = loader.load()\n",
-    "text_splitter = RecursiveCharacterTextSplitter(\n",
-    "    separators=[\"\\n\\n\", \"\\n\"], chunk_size=1500, chunk_overlap=100\n",
-    ")\n",
-    "texts_bis = text_splitter.split_documents(documents)\n",
-    "texts_bis[:5]"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "index(\n",
-    "    [],\n",
-    "    record_manager,\n",
-    "    rag.vector_store,\n",
-    "    cleanup=\"full\", #incremental\n",
-    "    source_id_key=\"source\",\n",
-    ")"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
- "outputs": [], - "source": [ - "index(\n", - " texts[:100],\n", - " record_manager,\n", - " rag.vector_store,\n", - " cleanup=\"incremental\", #incremental\n", - " source_id_key=\"source\",\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "index(\n", - " texts_bis[50:100],\n", - " record_manager,\n", - " rag.vector_store,\n", - " cleanup=\"incremental\",\n", - " source_id_key=\"source\",\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import os\n", - "\n", - "# print(os.environ.get(\"APIFY_API_TOKEN\"))\n", - "\n", - "from langchain.document_loaders.base import Document\n", - "from langchain.utilities import ApifyWrapper\n", - "from dotenv import load_dotenv\n", - "load_dotenv()\n", - "\n", - "apify = ApifyWrapper()\n", - "\n", - "loader = apify.call_actor(\n", - " actor_id=\"apify/website-content-crawler\",\n", - " run_input={\"startUrls\": [{\"url\": \"https://python.langchain.com/en/latest/modules/indexes/document_loaders.html\"}]},\n", - " dataset_mapping_function=lambda item: Document(\n", - " page_content=item[\"text\"] or \"\", metadata={\"source\": item[\"url\"]}\n", - " ),\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "loader #.apify_client()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from apify_client import ApifyClient\n", - "\n", - "apify_client = loader.apify_client\n", - "\n", - "len(apify_client.dataset(loader.dataset_id).list_items().items)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "index(\n", - " [loader],\n", - " record_manager,\n", - " rag.vector_store,\n", - " cleanup=\"incremental\",\n", - " source_id_key=\"source\",\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "skaff-rag-accelerator", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.5" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/notebooks/memory_tests.ipynb b/notebooks/memory_tests.ipynb deleted file mode 100644 index 520cf4d..0000000 --- a/notebooks/memory_tests.ipynb +++ /dev/null @@ -1,252 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "%load_ext autoreload\n", - "%autoreload 2" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import os\n", - "import sys\n", - "\n", - "current_directory = os.getcwd()\n", - "parent_directory = os.path.dirname(current_directory)\n", - "sys.path.append(parent_directory)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from langchain.chains import ConversationalRetrievalChain, LLMChain\n", - "from langchain.chains.combine_documents.stuff import StuffDocumentsChain\n", - "from 
langchain.chat_models.base import SystemMessage, HumanMessage\n", - "from langchain.prompts import PromptTemplate, ChatPromptTemplate, HumanMessagePromptTemplate\n", - "from langchain.vectorstores import VectorStore\n", - "from langchain.memory.chat_message_histories import SQLChatMessageHistory\n", - "\n", - "from backend.config_renderer import get_config\n", - "from backend.rag_components.embedding import get_embedding_model\n", - "from backend.rag_components.llm import get_llm_model\n", - "from backend.rag_components.vector_store import get_vector_store\n", - "import frontend.lib.auth as auth" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "def get_answer_chain(llm, docsearch: VectorStore, memory) -> ConversationalRetrievalChain:\n", - " \"\"\"Returns an instance of ConversationalRetrievalChain based on the provided parameters.\"\"\"\n", - " template = \"\"\"Given the conversation history and the following question, can you rephrase the user's question in its original language so that it is self-sufficient. Make sure to avoid the use of unclear pronouns.\n", - "\n", - "Chat history :\n", - "{chat_history}\n", - "Question : {question}\n", - "\n", - "Rephrased question :\n", - "\"\"\"\n", - " condense_question_prompt = PromptTemplate.from_template(template)\n", - " condense_question_chain = LLMChain(\n", - " llm=llm,\n", - " prompt=condense_question_prompt,\n", - " )\n", - "\n", - " messages = [\n", - " SystemMessage(\n", - " content=(\n", - " \"\"\"As a chatbot assistant, your mission is to respond to user inquiries in a precise and concise manner based on the documents provided as input. It is essential to respond in the same language in which the question was asked. Responses must be written in a professional style and must demonstrate great attention to detail.\"\"\"\n", - " )\n", - " ),\n", - " HumanMessage(content=\"Respond to the question taking into account the following context.\"),\n", - " HumanMessagePromptTemplate.from_template(\"{context}\"),\n", - " HumanMessagePromptTemplate.from_template(\"Question: {question}\"),\n", - " ]\n", - " system_prompt = ChatPromptTemplate(messages=messages)\n", - " qa_chain = LLMChain(\n", - " llm=llm,\n", - " prompt=system_prompt,\n", - " )\n", - "\n", - " doc_prompt = PromptTemplate(\n", - " template=\"Content: {page_content}\\nSource: {source}\",\n", - " input_variables=[\"page_content\", \"source\"],\n", - " )\n", - "\n", - " final_qa_chain = StuffDocumentsChain(\n", - " llm_chain=qa_chain,\n", - " document_variable_name=\"context\",\n", - " document_prompt=doc_prompt,\n", - " )\n", - "\n", - " return ConversationalRetrievalChain(\n", - " question_generator=condense_question_chain,\n", - " retriever=docsearch.as_retriever(search_kwargs={\"k\": 10}),\n", - " memory=memory,\n", - " combine_docs_chain=final_qa_chain,\n", - " verbose=True,\n", - " )\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "config = get_config()\n", - "llm = get_llm_model(config)\n", - "embeddings = get_embedding_model(config)\n", - "vector_store = get_vector_store(embeddings, config)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "username = \"slauzeral\"\n", - "password = \"test\"" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "success = auth.sign_up(username, password)\n", - "if success:\n", - " token 
= auth.get_token(username, password)\n", - " session = auth.create_session()\n", - " auth_session = auth.authenticate_session(session, token)\n", - "\n", - "response = auth_session.post(\"/chat/new\")\n", - "chat_id = response.json()[\"chat_id\"]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from datetime import datetime\n", - "from uuid import uuid4\n", - "from typing import Any\n", - "\n", - "from langchain.memory.chat_message_histories.sql import DefaultMessageConverter\n", - "from langchain.schema import AIMessage, BaseMessage, HumanMessage, SystemMessage\n", - "from sqlalchemy import Column, DateTime, Integer, Text\n", - "from sqlalchemy.orm import declarative_base\n", - "from langchain.schema.messages import BaseMessage, _message_to_dict, messages_from_dict\n", - "import json\n", - "\n", - "Base = declarative_base()\n", - "\n", - "class CustomMessage(Base):\n", - " __tablename__ = \"message_test\"\n", - "\n", - " id = Column(Text, primary_key=True, default=lambda: str(uuid4())) # default=lambda: str(uuid4())\n", - " timestamp = Column(DateTime)\n", - " chat_id = Column(Text)\n", - " sender = Column(Text)\n", - " content = Column(Text)\n", - " message = Column(Text)\n", - "\n", - "\n", - "class CustomMessageConverter(DefaultMessageConverter):\n", - "\n", - " def to_sql_model(self, message: BaseMessage, session_id: str) -> Any:\n", - " sub_message = json.loads(message.content)\n", - " return CustomMessage(\n", - " id = sub_message[\"id\"],\n", - " timestamp = datetime.strptime(sub_message[\"timestamp\"], \"%Y-%m-%d %H:%M:%S.%f\"),\n", - " chat_id = session_id,\n", - " sender = message.type,\n", - " content = sub_message[\"content\"],\n", - " message = json.dumps(_message_to_dict(message)),\n", - " )\n", - "\n", - " def get_sql_model_class(self) -> Any:\n", - " return CustomMessage\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "chat_message_history = SQLChatMessageHistory(\n", - " session_id=chat_id,\n", - " connection_string=\"sqlite:////Users/sarah.lauzeral/Library/CloudStorage/GoogleDrive-sarah.lauzeral@artefact.com/Mon Drive/internal_projects/skaff-rag-accelerator/database/database.sqlite\",\n", - " table_name=\"message_test\",\n", - " session_id_field_name=\"chat_id\",\n", - " custom_message_converter=CustomMessageConverter(table_name=\"message_test\"),\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "chat_message_history.add_ai_message(json.dumps({\"content\":\"Hi\", \"timestamp\":f\"{datetime.utcnow()}\", \"id\":\"764528762\"}))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "chat_message_history.messages" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "skaff-rag-accelerator", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.5" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/notebooks/test_auth.ipynb b/notebooks/test_auth.ipynb deleted file mode 100644 index 4980182..0000000 --- a/notebooks/test_auth.ipynb +++ 
/dev/null @@ -1,125 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "%load_ext autoreload\n", - "%autoreload 2" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import os\n", - "import sys\n", - "\n", - "current_directory = os.getcwd()\n", - "parent_directory = os.path.dirname(current_directory)\n", - "sys.path.append(parent_directory)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from typing import NoReturn\n", - "from fastapi.testclient import TestClient\n", - "from lib.main import app\n", - "\n", - "import streamlit as st\n", - "import requests\n", - "\n", - "client = TestClient(app)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "def log_in(username: str, password: str) -> Optional[str]:\n", - " response = client.post(\n", - " \"/user/login\", data={\"username\": username, \"password\": password}\n", - " )\n", - " if response.status_code == 200 and \"access_token\" in response.json():\n", - " return response.json()[\"access_token\"]\n", - " else:\n", - " return None\n", - "\n", - "def sign_up(username: str, password: str) -> str:\n", - " response = client.post(\n", - " \"/user/signup\", json={\"username\": username, \"password\": password}\n", - " )\n", - " if response.status_code == 200 and \"email\" in response.json():\n", - " return f\"User {username} registered successfully.\"\n", - " else:\n", - " return \"Registration failed.\"\n", - "\n", - "def reset_pwd(username: str) -> str:\n", - " # Assuming there's an endpoint to request a password reset\n", - " response = client.post(\n", - " \"/user/reset-password\", json={\"username\": username}\n", - " )\n", - " if response.status_code == 200:\n", - " return \"Password reset link sent.\"\n", - " else:\n", - " return \"Failed to send password reset link.\"\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "sign_up(\"sarah.lauzeral@artefact.com\", \"test_pwd\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "sign_up(\"test@example.com\", \"test_pwd\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "skaff-rag-accelerator", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.5" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -}