diff --git a/backend/config.py b/backend/config.py
index 7f6208e..8ecdb86 100644
--- a/backend/config.py
+++ b/backend/config.py
@@ -65,6 +65,7 @@ class RagConfig:
     database: DatabaseConfig = field(default_factory=DatabaseConfig)
     chat_history_window_size: int = 5
     max_tokens_limit: int = 3000
+    response_mode: str = "normal"
 
     @classmethod
     def from_yaml(cls, yaml_path: Path, env: dict = None):
diff --git a/backend/config.yaml b/backend/config.yaml
index 0f2c74c..bababf7 100644
--- a/backend/config.yaml
+++ b/backend/config.yaml
@@ -4,6 +4,14 @@
 #   model_name: gemini-pro
 #   temperature: 0.1
 
+# LLMConfig: &LLMConfig
+#   source: GPT4All
+#   source_config:
+#     model: /Users/alexis.vialaret/vscode_projects/skaff-rag-accelerator/data/mistral-7b-openorca.Q4_0.gguf
+#     n_ctx: 1024
+#     backend: gptj
+#   verbose: false
+
 LLMConfig: &LLMConfig
   source: AzureChatOpenAI
   source_config:
@@ -43,4 +51,5 @@ RagConfig:
   embedding_model: *EmbeddingModelConfig
   database: *DatabaseConfig
   chat_history_window_size: 5
-  max_tokens_limit: 3000
\ No newline at end of file
+  max_tokens_limit: 3000
+  response_mode: stream
\ No newline at end of file
diff --git a/backend/database.py b/backend/database.py
index 04db5a1..6edd3a2 100644
--- a/backend/database.py
+++ b/backend/database.py
@@ -75,6 +75,7 @@ def initialize_schema(self):
     def _create_pool(self) -> PooledDB:
         if self.connection_string.startswith("sqlite:///"):
             import sqlite3
+            Path(self.connection_string.replace("sqlite:///", "")).parent.mkdir(parents=True, exist_ok=True)
             return PooledDB(creator=sqlite3, database=self.connection_string.replace("sqlite:///", ""), maxconnections=5)
         elif self.connection_string.startswith("postgres://"):
             import psycopg2
diff --git a/backend/main.py b/backend/main.py
index d821136..f073318 100644
--- a/backend/main.py
+++ b/backend/main.py
@@ -1,5 +1,8 @@
+import asyncio
 from datetime import datetime, timedelta
+import inspect
 from pathlib import Path
+import traceback
 from typing import List
 from uuid import uuid4
 
@@ -8,9 +11,9 @@
 from fastapi.responses import StreamingResponse
 from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
 from jose import JWTError, jwt
-from backend.logger import get_logger
-
+from langchain_core.messages.ai import AIMessage
 
+from backend.logger import get_logger
 from backend.model import Message
 from backend.rag_components.rag import RAG
 from backend.user_management import (
@@ -87,9 +90,8 @@ async def chat_prompt(message: Message, current_user: User = Depends(get_current
         "timestamp": message.timestamp,
     }
     rag = RAG(config=Path(__file__).parent / "config.yaml", logger=logger, context=context)
-    response = rag.async_generate_response(message)
-
-    return StreamingResponse(async_llm_response(message.chat_id, response), media_type="text/event-stream")
+    response = rag.generate_response(message)
+    return StreamingResponse(stream_response(message.chat_id, response), media_type="text/event-stream")
 
 
 @app.post("/chat/regenerate")
@@ -130,36 +132,35 @@ async def chat(chat_id: str, current_user: User = Depends(get_current_user)) ->
     return {"chat_id": chat_id, "messages": [message.model_dump() for message in messages]}
 
 
-async def async_llm_response(chat_id, answer_chain):
+async def stream_response(chat_id: str, response):
     full_response = ""
     response_id = str(uuid4())
     try:
-        async for data in answer_chain:
-            full_response += data
-            yield data.encode("utf-8")
+        if type(response) is AIMessage:
+            full_response = response.content
+            yield full_response.encode("utf-8")
+        elif inspect.isasyncgen(response):
+            async for data in response:
+                full_response += data.content
+                yield data.content.encode("utf-8")
+        else:
+            for part in response:
+                full_response += part.content
+                yield part.content.encode("utf-8")
+                await asyncio.sleep(0)
     except Exception as e:
-        logger.error(f"Error generating response for chat {chat_id}: {e}")
-        full_response = f"Sorry, there was an error generating a response. Please contact an administrator and tell them the following error code: {response_id}, and message: {str(e)}"
+        logger.error(f"Error generating response for chat {chat_id}: {e}", exc_info=True)
+        full_response = f"Sorry, there was an error generating a response. Please contact an administrator and provide them with the following error code: {response_id} \n\n {traceback.format_exc()}"
         yield full_response.encode("utf-8")
+    finally:
+        await log_response_to_db(chat_id, full_response)
 
-    model_response = Message(
-        id=response_id,
-        timestamp=datetime.now().isoformat(),
-        chat_id=chat_id,
-        sender="assistant",
-        content=full_response,
-    )
-
+
+async def log_response_to_db(chat_id: str, full_response: str):
+    response_id = str(uuid4())
     with Database() as connection:
         connection.execute(
             "INSERT INTO message (id, timestamp, chat_id, sender, content) VALUES (?, ?, ?, ?, ?)",
-            (
-                model_response.id,
-                model_response.timestamp,
-                model_response.chat_id,
-                model_response.sender,
-                model_response.content,
-            ),
+            (response_id, datetime.now().isoformat(), chat_id, "assistant", full_response),
         )
diff --git a/backend/rag_components/chain.py b/backend/rag_components/chain.py
index f959f8d..cb9b3b1 100644
--- a/backend/rag_components/chain.py
+++ b/backend/rag_components/chain.py
@@ -1,74 +1,45 @@
-import asyncio
-from threading import Thread
-from time import sleep
+from operator import itemgetter
+
+from langchain.vectorstores.base import VectorStore
+from langchain_core.runnables import RunnablePassthrough, RunnableParallel
+from langchain.prompts import ChatPromptTemplate, PromptTemplate
+from langchain_core.output_parsers import StrOutputParser
+from langchain.schema import format_document
+from langchain.memory import ConversationBufferWindowMemory
+
 
-from langchain.chains import ConversationalRetrievalChain, LLMChain
-from langchain.chains.combine_documents.stuff import StuffDocumentsChain
-from langchain.prompts import (
-    ChatPromptTemplate,
-    HumanMessagePromptTemplate,
-    PromptTemplate,
-)
-from langchain.vectorstores import VectorStore
 from backend.config import RagConfig
-from backend.rag_components.llm import get_llm_model
 from backend.rag_components import prompts
+from backend.rag_components.llm import get_llm_model
 from backend.rag_components.logging_callback_handler import LoggingCallbackHandler
 
+DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template(template="{page_content}")
+
 
+def get_answer_chain(config: RagConfig, vector_store: VectorStore, memory: ConversationBufferWindowMemory, logging_callback_handler: LoggingCallbackHandler = None):
+    llm_callbacks = [logging_callback_handler] if logging_callback_handler is not None else []
+    llm = get_llm_model(config, callbacks=llm_callbacks)
 
-async def async_get_response(chain: ConversationalRetrievalChain, query: str, streaming_callback_handler) -> str:
-    run = asyncio.create_task(chain.arun({"question": query}))
-
-    async for token in streaming_callback_handler.aiter():
-        yield token
-
-    await run
-
-
-def stream_get_response(chain: ConversationalRetrievalChain, query: str, streaming_callback_handler) -> str:
-    thread = Thread(target=lambda chain, query: chain.run({"question": query}), args=(chain, query))
-    thread.start()
-
-    while thread.is_alive() or not streaming_callback_handler.queue.empty():
-        if not streaming_callback_handler.queue.empty():
-            yield streaming_callback_handler.queue.get()
-        else:
-            sleep(0.1)
-
-    thread.join()
-
-def get_answer_chain(config: RagConfig, docsearch: VectorStore, memory, streaming_callback_handler = None, logging_callback_handler: LoggingCallbackHandler = None) -> ConversationalRetrievalChain:
-    callbacks = [logging_callback_handler] if logging_callback_handler is not None else []
-    streaming_callback = [streaming_callback_handler] if streaming_callback_handler is not None else []
+    retriever = vector_store.as_retriever(search_type=config.vector_store.retreiver_search_type, search_kwargs=config.vector_store.retreiver_config)
 
     condense_question_prompt = PromptTemplate.from_template(prompts.condense_history)
-    condense_question_chain = LLMChain(llm=get_llm_model(config), prompt=condense_question_prompt, callbacks=callbacks)
+    question_answering_prompt = ChatPromptTemplate.from_template(prompts.respond_to_question)
 
-    messages = [
-        HumanMessagePromptTemplate.from_template(prompts.respond_to_question),
-    ]
-    question_answering_prompt = ChatPromptTemplate(messages=messages)
-    streaming_llm = get_llm_model(config, callbacks=streaming_callback + callbacks)
-    question_answering_chain = LLMChain(llm=streaming_llm, prompt=question_answering_prompt, callbacks=callbacks)
-    context_with_docs_prompt = PromptTemplate(template=prompts.document_context, input_variables=["page_content", "source"])
-
-    stuffed_qa_chain = StuffDocumentsChain(
-        llm_chain=question_answering_chain,
-        document_variable_name="context",
-        document_prompt=context_with_docs_prompt,
-        callbacks=callbacks
+    _inputs = RunnableParallel(
+        standalone_question=RunnablePassthrough.assign(chat_history=lambda _: memory.buffer_as_str)
+        | condense_question_prompt
+        | llm
+        | StrOutputParser(),
     )
 
+    _context = {
+        "context": itemgetter("standalone_question") | retriever | _combine_documents,
+        "question": lambda x: x["standalone_question"],
+    }
 
+    conversational_qa_chain = _inputs | _context | question_answering_prompt | llm
+    return conversational_qa_chain
 
-    chain = ConversationalRetrievalChain(
-        question_generator=condense_question_chain,
-        retriever=docsearch.as_retriever(search_type=config.vector_store.retreiver_search_type, search_kwargs=config.vector_store.retreiver_config),
-        memory=memory,
-        max_tokens_limit=config.max_tokens_limit,
-        combine_docs_chain=stuffed_qa_chain,
-        callbacks=callbacks
-    )
-    return chain
+
+def _combine_documents(docs, document_prompt=DEFAULT_DOCUMENT_PROMPT, document_separator="\n\n"):
+    doc_strings = [format_document(doc, document_prompt) for doc in docs]
+    return document_separator.join(doc_strings)
\ No newline at end of file
diff --git a/backend/rag_components/chat_message_history.py b/backend/rag_components/chat_message_history.py
index 7d57135..e1035d4 100644
--- a/backend/rag_components/chat_message_history.py
+++ b/backend/rag_components/chat_message_history.py
@@ -8,7 +8,7 @@
 
 TABLE_NAME = "message_history"
 
-def get_conversation_buffer_memory(config: RagConfig, chat_id):
+def get_conversation_buffer_memory(config: RagConfig, chat_id) -> ConversationBufferWindowMemory:
     return ConversationBufferWindowMemory(
         memory_key="chat_history",
         chat_memory=get_chat_message_history(config, chat_id),
diff --git a/backend/rag_components/llm.py b/backend/rag_components/llm.py
index f767b11..9b1dd4c 100644
--- a/backend/rag_components/llm.py
+++ b/backend/rag_components/llm.py
@@ -12,4 +12,5 @@ def get_llm_model(config: RagConfig, callbacks: List[BaseCallbackHandler] = []):
     }
     kwargs["streaming"] = True
     kwargs["callbacks"] = callbacks
+
     return llm_spec(**kwargs)
diff --git a/backend/rag_components/rag.py b/backend/rag_components/rag.py
index 26d1809..788d545 100644
--- a/backend/rag_components/rag.py
+++ b/backend/rag_components/rag.py
@@ -1,13 +1,11 @@
 import asyncio
 from logging import Logger
 from pathlib import Path
-from typing import AsyncIterator, List, Union
+from typing import List, Union
 
 from langchain.docstore.document import Document
 from langchain.vectorstores.utils import filter_complex_metadata
-from langchain.callbacks import AsyncIteratorCallbackHandler
-from backend.rag_components.chain import get_answer_chain, async_get_response, stream_get_response
 from langchain.indexes import SQLRecordManager, index
 from langchain.chat_models.base import BaseChatModel
 from langchain.vectorstores import VectorStore
@@ -16,12 +14,12 @@
 
 from backend.config import RagConfig
 from backend.model import Message
+from backend.rag_components.chain import get_answer_chain
 from backend.rag_components.chat_message_history import get_conversation_buffer_memory
 from backend.rag_components.document_loader import get_documents
 from backend.rag_components.embedding import get_embedding_model
 from backend.rag_components.llm import get_llm_model
 from backend.rag_components.logging_callback_handler import LoggingCallbackHandler
-from backend.rag_components.streaming_callback_handler import StreamingCallbackHandler
 from backend.rag_components.vector_store import get_vector_store
@@ -40,26 +38,17 @@ def __init__(self, config: Union[Path, RagConfig], logger: Logger = None, contex
         self.vector_store: VectorStore = get_vector_store(self.embeddings, self.config)
 
     def generate_response(self, message: Message) -> str:
-        loop = asyncio.get_event_loop()
-        response_stream = self.async_generate_response(message)
-        responses = loop.run_until_complete(self._collect_responses(response_stream))
-        return "".join([str(response) for response in responses])
-
-    def stream_generate_response(self, message: Message) -> AsyncIterator[str]:
-        memory = get_conversation_buffer_memory(self.config, message.chat_id)
-        streaming_callback_handler = StreamingCallbackHandler()
-        logging_callback_handler = LoggingCallbackHandler(self.logger, context=self.context)
-        answer_chain = get_answer_chain(self.config, self.vector_store, memory, streaming_callback_handler=streaming_callback_handler, logging_callback_handler=logging_callback_handler)
-        response_stream = stream_get_response(answer_chain, message.content, streaming_callback_handler)
-        return response_stream
-
-    def async_generate_response(self, message: Message) -> AsyncIterator[str]:
         memory = get_conversation_buffer_memory(self.config, message.chat_id)
-        streaming_callback_handler = AsyncIteratorCallbackHandler()
         logging_callback_handler = LoggingCallbackHandler(self.logger, context=self.context)
-        answer_chain = get_answer_chain(self.config, self.vector_store, memory, streaming_callback_handler=streaming_callback_handler, logging_callback_handler=logging_callback_handler)
-        response_stream = async_get_response(answer_chain, message.content, streaming_callback_handler)
-        return response_stream
+        answer_chain = get_answer_chain(self.config, self.vector_store, memory, logging_callback_handler=logging_callback_handler)
+
+        if self.config.response_mode == "async":
+            response = answer_chain.astream({"question": message.content})
+        elif self.config.response_mode == "stream":
+            response = answer_chain.stream({"question": message.content})
+        else:
+            response = answer_chain.invoke({"question": message.content})
+        return response
 
     def load_file(self, file_path: Path) -> List[Document]:
         documents = get_documents(file_path, self.llm)
@@ -81,9 +70,3 @@ def load_documents(self, documents: List[Document], insertion_mode: str = None):
             source_id_key="source",
         )
         self.logger.info({"event": "load_documents", **indexing_output})
-
-    async def _collect_responses(self, response_stream):
-        responses = []
-        async for response in response_stream:
-            responses.append(response)
-        return responses