upd: migrated to lcel
AlexisVLRT committed Jan 3, 2024
1 parent c8e3bc4 commit aeb3c59
Showing 8 changed files with 82 additions and 115 deletions.
1 change: 1 addition & 0 deletions backend/config.py
@@ -65,6 +65,7 @@ class RagConfig:
database: DatabaseConfig = field(default_factory=DatabaseConfig)
chat_history_window_size: int = 5
max_tokens_limit: int = 3000
response_mode: str = "normal"

@classmethod
def from_yaml(cls, yaml_path: Path, env: dict = None):
11 changes: 10 additions & 1 deletion backend/config.yaml
@@ -4,6 +4,14 @@
# model_name: gemini-pro
# temperature: 0.1

# LLMConfig: &LLMConfig
# source: GPT4All
# source_config:
# model: /Users/alexis.vialaret/vscode_projects/skaff-rag-accelerator/data/mistral-7b-openorca.Q4_0.gguf
# n_ctx: 1024
# backend: gptj
# verbose: false

LLMConfig: &LLMConfig
source: AzureChatOpenAI
source_config:
@@ -43,4 +51,5 @@ RagConfig:
embedding_model: *EmbeddingModelConfig
database: *DatabaseConfig
chat_history_window_size: 5
max_tokens_limit: 3000
max_tokens_limit: 3000
response_mode: stream
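
The new response_mode key is what backend/rag_components/rag.py (further down in this commit) switches on: "normal" calls the chain's invoke, "stream" iterates chain.stream, and "async" iterates chain.astream. A minimal sketch of reading the value back, assuming this config.yaml lives at backend/config.yaml and the working directory is the repo root:

```python
from pathlib import Path

from backend.config import RagConfig

# Sketch: load the YAML above and inspect the new field. The dataclass
# default is "normal"; this particular config overrides it to "stream".
config = RagConfig.from_yaml(Path("backend/config.yaml"))
print(config.response_mode)  # -> "stream"
```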
1 change: 1 addition & 0 deletions backend/database.py
@@ -75,6 +75,7 @@ def initialize_schema(self):
def _create_pool(self) -> PooledDB:
if self.connection_string.startswith("sqlite:///"):
import sqlite3
Path(self.connection_string.replace("sqlite:///", "")).parent.mkdir(parents=True, exist_ok=True)
return PooledDB(creator=sqlite3, database=self.connection_string.replace("sqlite:///", ""), maxconnections=5)
elif self.connection_string.startswith("postgres://"):
import psycopg2
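
The one-line fix above matters because sqlite3 does not create missing parent directories for a new database file. A standalone sketch of the same pattern, assuming PooledDB comes from the DBUtils package (the project's actual import is not shown in this hunk, and the path below is a placeholder):

```python
import sqlite3
from pathlib import Path

from dbutils.pooled_db import PooledDB  # assumed import path for PooledDB

# sqlite3.connect() raises "unable to open database file" when the parent
# directory of the target path does not exist, so it is created first.
db_path = Path("data/db/app.sqlite")  # hypothetical location
db_path.parent.mkdir(parents=True, exist_ok=True)
pool = PooledDB(creator=sqlite3, database=str(db_path), maxconnections=5)
connection = pool.connection()
```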
53 changes: 27 additions & 26 deletions backend/main.py
@@ -1,5 +1,8 @@
import asyncio
from datetime import datetime, timedelta
import inspect
from pathlib import Path
import traceback
from typing import List
from uuid import uuid4

@@ -8,9 +8,9 @@
from fastapi.responses import StreamingResponse
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from jose import JWTError, jwt
from backend.logger import get_logger

from langchain_core.messages.ai import AIMessage

from backend.logger import get_logger
from backend.model import Message
from backend.rag_components.rag import RAG
from backend.user_management import (
@@ -87,9 +90,8 @@ async def chat_prompt(message: Message, current_user: User = Depends(get_current
"timestamp": message.timestamp,
}
rag = RAG(config=Path(__file__).parent / "config.yaml", logger=logger, context=context)
response = rag.async_generate_response(message)

return StreamingResponse(async_llm_response(message.chat_id, response), media_type="text/event-stream")
response = rag.generate_response(message)
return StreamingResponse(stream_response(message.chat_id, response), media_type="text/event-stream")


@app.post("/chat/regenerate")
@@ -130,36 +132,35 @@ async def chat(chat_id: str, current_user: User = Depends(get_current_user)) ->
return {"chat_id": chat_id, "messages": [message.model_dump() for message in messages]}


async def async_llm_response(chat_id, answer_chain):
async def stream_response(chat_id: str, response):
full_response = ""
response_id = str(uuid4())
try:
async for data in answer_chain:
full_response += data
yield data.encode("utf-8")
if type(response) is AIMessage:
full_response = response.content
yield full_response.encode("utf-8")
elif inspect.isasyncgen(response):
async for data in response:
full_response += data.content
yield data.content.encode("utf-8")
else:
for part in response:
full_response += part.content
yield part.content.encode("utf-8")
await asyncio.sleep(0)
except Exception as e:
logger.error(f"Error generating response for chat {chat_id}: {e}")
full_response = f"Sorry, there was an error generating a response. Please contact an administrator and tell them the following error code: {response_id}, and message: {str(e)}"
logger.error(f"Error generating response for chat {chat_id}: {e}", exc_info=True)
full_response = f"Sorry, there was an error generating a response. Please contact an administrator and provide them with the following error code: {response_id} \n\n {traceback.format_exc()}"
yield full_response.encode("utf-8")
finally:
await log_response_to_db(chat_id, full_response)

model_response = Message(
id=response_id,
timestamp=datetime.now().isoformat(),
chat_id=chat_id,
sender="assistant",
content=full_response,
)

async def log_response_to_db(chat_id: str, full_response: str):
response_id = str(uuid4())
with Database() as connection:
connection.execute(
"INSERT INTO message (id, timestamp, chat_id, sender, content) VALUES (?, ?, ?, ?, ?)",
(
model_response.id,
model_response.timestamp,
model_response.chat_id,
model_response.sender,
model_response.content,
),
(response_id, datetime.now().isoformat(), chat_id, "assistant", full_response),
)


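
For reference, a client-side sketch of consuming the StreamingResponse returned by chat_prompt. The route, payload shape, and token are assumptions (they are defined outside this hunk); only the "text/event-stream" media type and the chunked byte stream come from the code above.

```python
import requests

# Hypothetical route and credentials: substitute the real endpoint that
# decorates chat_prompt and a valid OAuth2 bearer token.
url = "http://localhost:8000/chat/new_message"
headers = {"Authorization": "Bearer <token>"}
payload = {"chat_id": "1234", "content": "What does this commit change?"}

with requests.post(url, json=payload, headers=headers, stream=True) as response:
    response.raise_for_status()
    # Chunks arrive as they are yielded by stream_response() on the server.
    for chunk in response.iter_content(chunk_size=None, decode_unicode=True):
        print(chunk, end="", flush=True)
```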
89 changes: 30 additions & 59 deletions backend/rag_components/chain.py
@@ -1,74 +1,45 @@
import asyncio
from threading import Thread
from time import sleep
from operator import itemgetter

from langchain.vectorstores.base import VectorStore
from langchain_core.runnables import RunnablePassthrough, RunnableParallel
from langchain.prompts import ChatPromptTemplate, PromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain.schema import format_document
from langchain.memory import ConversationBufferWindowMemory


from langchain.chains import ConversationalRetrievalChain, LLMChain
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
from langchain.prompts import (
ChatPromptTemplate,
HumanMessagePromptTemplate,
PromptTemplate,
)
from langchain.vectorstores import VectorStore

from backend.config import RagConfig
from backend.rag_components.llm import get_llm_model
from backend.rag_components import prompts
from backend.rag_components.llm import get_llm_model
from backend.rag_components.logging_callback_handler import LoggingCallbackHandler

DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template(template="{page_content}")

def get_answer_chain(config: RagConfig, vector_store: VectorStore, memory: ConversationBufferWindowMemory, logging_callback_handler: LoggingCallbackHandler = None):
llm_callbacks = [logging_callback_handler] if logging_callback_handler is not None else []
llm = get_llm_model(config, callbacks=llm_callbacks)

async def async_get_response(chain: ConversationalRetrievalChain, query: str, streaming_callback_handler) -> str:
run = asyncio.create_task(chain.arun({"question": query}))

async for token in streaming_callback_handler.aiter():
yield token

await run


def stream_get_response(chain: ConversationalRetrievalChain, query: str, streaming_callback_handler) -> str:
thread = Thread(target=lambda chain, query: chain.run({"question": query}), args=(chain, query))
thread.start()

while thread.is_alive() or not streaming_callback_handler.queue.empty():
if not streaming_callback_handler.queue.empty():
yield streaming_callback_handler.queue.get()
else:
sleep(0.1)

thread.join()

def get_answer_chain(config: RagConfig, docsearch: VectorStore, memory, streaming_callback_handler = None, logging_callback_handler: LoggingCallbackHandler = None) -> ConversationalRetrievalChain:
callbacks = [logging_callback_handler] if logging_callback_handler is not None else []
streaming_callback = [streaming_callback_handler] if streaming_callback_handler is not None else []
retriever = vector_store.as_retriever(search_type=config.vector_store.retreiver_search_type, search_kwargs=config.vector_store.retreiver_config)

condense_question_prompt = PromptTemplate.from_template(prompts.condense_history)
condense_question_chain = LLMChain(llm=get_llm_model(config), prompt=condense_question_prompt, callbacks=callbacks)
question_answering_prompt = ChatPromptTemplate.from_template(prompts.respond_to_question)

messages = [
HumanMessagePromptTemplate.from_template(prompts.respond_to_question),
]
question_answering_prompt = ChatPromptTemplate(messages=messages)
streaming_llm = get_llm_model(config, callbacks=streaming_callback + callbacks)
question_answering_chain = LLMChain(llm=streaming_llm, prompt=question_answering_prompt, callbacks=callbacks)

context_with_docs_prompt = PromptTemplate(template=prompts.document_context, input_variables=["page_content", "source"])

stuffed_qa_chain = StuffDocumentsChain(
llm_chain=question_answering_chain,
document_variable_name="context",
document_prompt=context_with_docs_prompt,
callbacks=callbacks
_inputs = RunnableParallel(
standalone_question=RunnablePassthrough.assign(chat_history=lambda _: memory.buffer_as_str)
| condense_question_prompt
| llm
| StrOutputParser(),
)
_context = {
"context": itemgetter("standalone_question") | retriever | _combine_documents,
"question": lambda x: x["standalone_question"],
}
conversational_qa_chain = _inputs | _context | question_answering_prompt | llm
return conversational_qa_chain

chain = ConversationalRetrievalChain(
question_generator=condense_question_chain,
retriever=docsearch.as_retriever(search_type=config.vector_store.retreiver_search_type, search_kwargs=config.vector_store.retreiver_config),
memory=memory,
max_tokens_limit=config.max_tokens_limit,
combine_docs_chain=stuffed_qa_chain,
callbacks=callbacks
)

return chain
def _combine_documents(docs, document_prompt=DEFAULT_DOCUMENT_PROMPT, document_separator="\n\n"):
doc_strings = [format_document(doc, document_prompt) for doc in docs]
return document_separator.join(doc_strings)
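
Pieced together from the added lines above, the post-migration get_answer_chain boils down to the following LCEL pipeline (a reconstruction for readability, not the verbatim file contents):

```python
from operator import itemgetter

from langchain.memory import ConversationBufferWindowMemory
from langchain.prompts import ChatPromptTemplate, PromptTemplate
from langchain.schema import format_document
from langchain.vectorstores.base import VectorStore
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnableParallel, RunnablePassthrough

from backend.config import RagConfig
from backend.rag_components import prompts
from backend.rag_components.llm import get_llm_model
from backend.rag_components.logging_callback_handler import LoggingCallbackHandler

DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template(template="{page_content}")


def get_answer_chain(config: RagConfig, vector_store: VectorStore, memory: ConversationBufferWindowMemory, logging_callback_handler: LoggingCallbackHandler = None):
    llm_callbacks = [logging_callback_handler] if logging_callback_handler is not None else []
    llm = get_llm_model(config, callbacks=llm_callbacks)
    retriever = vector_store.as_retriever(
        search_type=config.vector_store.retreiver_search_type,
        search_kwargs=config.vector_store.retreiver_config,
    )

    condense_question_prompt = PromptTemplate.from_template(prompts.condense_history)
    question_answering_prompt = ChatPromptTemplate.from_template(prompts.respond_to_question)

    # Step 1: rewrite the follow-up question into a standalone question,
    # injecting the windowed chat history from memory.
    _inputs = RunnableParallel(
        standalone_question=RunnablePassthrough.assign(chat_history=lambda _: memory.buffer_as_str)
        | condense_question_prompt
        | llm
        | StrOutputParser(),
    )
    # Step 2: retrieve documents for the standalone question and stuff them
    # into the answering prompt as "context".
    _context = {
        "context": itemgetter("standalone_question") | retriever | _combine_documents,
        "question": lambda x: x["standalone_question"],
    }
    return _inputs | _context | question_answering_prompt | llm


def _combine_documents(docs, document_prompt=DEFAULT_DOCUMENT_PROMPT, document_separator="\n\n"):
    doc_strings = [format_document(doc, document_prompt) for doc in docs]
    return document_separator.join(doc_strings)
```

The LLMChain, StuffDocumentsChain, and ConversationalRetrievalChain wrappers from the previous version, along with the thread- and callback-based streaming helpers, are replaced by calling invoke, stream, or astream directly on this runnable.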
2 changes: 1 addition & 1 deletion backend/rag_components/chat_message_history.py
@@ -8,7 +8,7 @@
TABLE_NAME = "message_history"


def get_conversation_buffer_memory(config: RagConfig, chat_id):
def get_conversation_buffer_memory(config: RagConfig, chat_id) -> ConversationBufferWindowMemory:
return ConversationBufferWindowMemory(
memory_key="chat_history",
chat_memory=get_chat_message_history(config, chat_id),
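
The annotated return type is what feeds chat history into the new LCEL chain, which reads it via memory.buffer_as_str (see chain.py above). A small sketch, with the config path and chat_id as placeholders:

```python
from pathlib import Path

from backend.config import RagConfig
from backend.rag_components.chat_message_history import get_conversation_buffer_memory

config = RagConfig.from_yaml(Path("backend/config.yaml"))
memory = get_conversation_buffer_memory(config, chat_id="1234")  # hypothetical chat_id
# The last chat_history_window_size exchanges, serialised for the condense-question prompt.
print(memory.buffer_as_str)
```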
1 change: 1 addition & 0 deletions backend/rag_components/llm.py
@@ -12,4 +12,5 @@ def get_llm_model(config: RagConfig, callbacks: List[BaseCallbackHandler] = []):
}
kwargs["streaming"] = True
kwargs["callbacks"] = callbacks

return llm_spec(**kwargs)
39 changes: 11 additions & 28 deletions backend/rag_components/rag.py
@@ -1,13 +1,11 @@
import asyncio
from logging import Logger
from pathlib import Path
from typing import AsyncIterator, List, Union
from typing import List, Union


from langchain.docstore.document import Document
from langchain.vectorstores.utils import filter_complex_metadata
from langchain.callbacks import AsyncIteratorCallbackHandler
from backend.rag_components.chain import get_answer_chain, async_get_response, stream_get_response
from langchain.indexes import SQLRecordManager, index
from langchain.chat_models.base import BaseChatModel
from langchain.vectorstores import VectorStore
@@ -16,12 +14,12 @@

from backend.config import RagConfig
from backend.model import Message
from backend.rag_components.chain import get_answer_chain
from backend.rag_components.chat_message_history import get_conversation_buffer_memory
from backend.rag_components.document_loader import get_documents
from backend.rag_components.embedding import get_embedding_model
from backend.rag_components.llm import get_llm_model
from backend.rag_components.logging_callback_handler import LoggingCallbackHandler
from backend.rag_components.streaming_callback_handler import StreamingCallbackHandler
from backend.rag_components.vector_store import get_vector_store


@@ -40,26 +38,17 @@ def __init__(self, config: Union[Path, RagConfig], logger: Logger = None, contex
self.vector_store: VectorStore = get_vector_store(self.embeddings, self.config)

def generate_response(self, message: Message) -> str:
loop = asyncio.get_event_loop()
response_stream = self.async_generate_response(message)
responses = loop.run_until_complete(self._collect_responses(response_stream))
return "".join([str(response) for response in responses])

def stream_generate_response(self, message: Message) -> AsyncIterator[str]:
memory = get_conversation_buffer_memory(self.config, message.chat_id)
streaming_callback_handler = StreamingCallbackHandler()
logging_callback_handler = LoggingCallbackHandler(self.logger, context=self.context)
answer_chain = get_answer_chain(self.config, self.vector_store, memory, streaming_callback_handler=streaming_callback_handler, logging_callback_handler=logging_callback_handler)
response_stream = stream_get_response(answer_chain, message.content, streaming_callback_handler)
return response_stream

def async_generate_response(self, message: Message) -> AsyncIterator[str]:
memory = get_conversation_buffer_memory(self.config, message.chat_id)
streaming_callback_handler = AsyncIteratorCallbackHandler()
logging_callback_handler = LoggingCallbackHandler(self.logger, context=self.context)
answer_chain = get_answer_chain(self.config, self.vector_store, memory, streaming_callback_handler=streaming_callback_handler, logging_callback_handler=logging_callback_handler)
response_stream = async_get_response(answer_chain, message.content, streaming_callback_handler)
return response_stream
answer_chain = get_answer_chain(self.config, self.vector_store, memory, logging_callback_handler=logging_callback_handler)

if self.config.response_mode == "async":
response = answer_chain.astream({"question": message.content})
elif self.config.response_mode == "stream":
response = answer_chain.stream({"question": message.content})
else:
response = answer_chain.invoke({"question": message.content})
return response

def load_file(self, file_path: Path) -> List[Document]:
documents = get_documents(file_path, self.llm)
@@ -81,9 +70,3 @@ def load_documents(self, documents: List[Document], insertion_mode: str = None):
source_id_key="source",
)
self.logger.info({"event": "load_documents", **indexing_output})

async def _collect_responses(self, response_stream):
responses = []
async for response in response_stream:
responses.append(response)
return responses
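
A sketch of how a caller handles the three return shapes generate_response can now produce: a single AIMessage from invoke, a sync chunk generator from stream, or an async chunk generator from astream. The Message field values and config path are placeholders:

```python
import asyncio
from datetime import datetime
from pathlib import Path
from uuid import uuid4

from backend.model import Message
from backend.rag_components.rag import RAG

rag = RAG(config=Path("backend/config.yaml"))  # placeholder path, repo root assumed
message = Message(
    id=str(uuid4()),
    timestamp=datetime.now().isoformat(),
    chat_id=str(uuid4()),  # placeholder chat id
    sender="user",
    content="What does migrating to LCEL change?",
)

response = rag.generate_response(message)

if rag.config.response_mode == "async":
    async def consume() -> None:
        async for chunk in response:  # chunks from answer_chain.astream(...)
            print(chunk.content, end="", flush=True)

    asyncio.run(consume())
elif rag.config.response_mode == "stream":
    for chunk in response:  # chunks from answer_chain.stream(...)
        print(chunk.content, end="", flush=True)
else:
    print(response.content)  # a single AIMessage from answer_chain.invoke(...)
```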
