upd: differentiated async and streaming
AlexisVLRT committed Jan 2, 2024
1 parent f48fdfc commit de0c08a
Showing 10 changed files with 80 additions and 660 deletions.
33 changes: 21 additions & 12 deletions backend/config.yaml
@@ -1,28 +1,37 @@
# LLMConfig: &LLMConfig
# source: ChatVertexAI
# source_config:
# model_name: gemini-pro
# temperature: 0.1

LLMConfig: &LLMConfig
source: "ChatVertexAI"
source: AzureChatOpenAI
source_config:
model_name: google/gemini-pro
temperature: 0.1
openai_api_type: azure
openai_api_key: {{ OPENAI_API_KEY }}
openai_api_base: https://genai-ds.openai.azure.com/
openai_api_version: 2023-07-01-preview
deployment_name: gpt4

VectorStoreConfig: &VectorStoreConfig
source: "Chroma"
source: Chroma
source_config:
persist_directory: "vector_database/"
persist_directory: vector_database/
collection_metadata:
hnsw:space: "cosine"
retreiver_search_type: "similarity"
hnsw:space: cosine
retreiver_search_type: similarity
retreiver_config:
top_k: 20
score_threshold: 0.5
insertion_mode: "full"
insertion_mode: full

EmbeddingModelConfig: &EmbeddingModelConfig
source: "OpenAIEmbeddings"
source: OpenAIEmbeddings
source_config:
openai_api_type: "azure"
openai_api_type: azure
openai_api_key: {{ EMBEDDING_API_KEY }}
openai_api_base: "https://poc-openai-artefact.openai.azure.com/"
deployment: "embeddings"
openai_api_base: https://poc-openai-artefact.openai.azure.com/
deployment: embeddings
chunk_size: 500

DatabaseConfig: &DatabaseConfig
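
The `{{ OPENAI_API_KEY }}` and `{{ EMBEDDING_API_KEY }}` values look like template placeholders. A minimal sketch of one way such a file could be rendered from environment variables and parsed (assuming Jinja2-style substitution and PyYAML; the `load_config` helper and rendering step are assumptions, not necessarily the repo's actual loader):

```python
import os
from pathlib import Path

import yaml
from jinja2 import Template


def load_config(path: Path) -> dict:
    """Render {{ ... }} placeholders from the environment, then parse the YAML."""
    raw = path.read_text()
    rendered = Template(raw).render(os.environ)  # fills {{ OPENAI_API_KEY }}, etc.
    return yaml.safe_load(rendered)


config = load_config(Path("backend/config.yaml"))
print(config["LLMConfig"]["source"])  # -> "AzureChatOpenAI" with the config above
```
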
18 changes: 12 additions & 6 deletions backend/main.py
@@ -89,7 +89,7 @@ async def chat_prompt(message: Message, current_user: User = Depends(get_current
rag = RAG(config=Path(__file__).parent / "config.yaml", logger=logger, context=context)
response = rag.async_generate_response(message)

return StreamingResponse(streamed_llm_response(message.chat_id, response), media_type="text/event-stream")
return StreamingResponse(async_llm_response(message.chat_id, response), media_type="text/event-stream")


@app.post("/chat/regenerate")
@@ -130,14 +130,20 @@ async def chat(chat_id: str, current_user: User = Depends(get_current_user)) ->
return {"chat_id": chat_id, "messages": [message.model_dump() for message in messages]}


async def streamed_llm_response(chat_id, answer_chain):
async def async_llm_response(chat_id, answer_chain):
full_response = ""
async for data in answer_chain:
full_response += data
yield data.encode("utf-8")
response_id = str(uuid4())
try:
async for data in answer_chain:
full_response += data
yield data.encode("utf-8")
except Exception as e:
logger.error(f"Error generating response for chat {chat_id}: {e}")
full_response = f"Sorry, there was an error generating a response. Please contact an administrator and tell them the following error code: {response_id}, and message: {str(e)}"
yield full_response.encode("utf-8")

model_response = Message(
id=str(uuid4()),
id=response_id,
timestamp=datetime.now().isoformat(),
chat_id=chat_id,
sender="assistant",
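
The new `async_llm_response` wraps the token stream in a try/except so the client still receives a terminal message (tagged with the response id) when generation fails mid-stream. A dependency-free sketch of that pattern, with hypothetical names rather than the endpoint itself:

```python
import asyncio
from typing import AsyncIterator
from uuid import uuid4


async def faulty_source() -> AsyncIterator[str]:
    # Stand-in for the answer chain: streams a bit, then fails.
    yield "partial answer "
    raise RuntimeError("model unavailable")


async def guarded_stream(source: AsyncIterator[str]) -> AsyncIterator[bytes]:
    response_id = str(uuid4())
    try:
        async for chunk in source:
            yield chunk.encode("utf-8")
    except Exception as e:
        # Emit one final chunk carrying the error code instead of breaking the stream.
        yield f"Sorry, something went wrong (error code: {response_id}): {e}".encode("utf-8")


async def main() -> None:
    async for chunk in guarded_stream(faulty_source()):
        print(chunk.decode("utf-8"))


asyncio.run(main())
```
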
27 changes: 20 additions & 7 deletions backend/rag_components/chain.py
@@ -1,8 +1,9 @@
import asyncio
from threading import Thread
from time import sleep

from langchain.chains import ConversationalRetrievalChain, LLMChain
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
from langchain.chains.combine_documents.reduce import ReduceDocumentsChain
from langchain.chat_models.base import SystemMessage
from langchain.prompts import (
ChatPromptTemplate,
HumanMessagePromptTemplate,
@@ -17,7 +18,7 @@



async def get_response_stream(chain: ConversationalRetrievalChain, query: str, streaming_callback_handler) -> str:
async def async_get_response(chain: ConversationalRetrievalChain, query: str, streaming_callback_handler) -> str:
run = asyncio.create_task(chain.arun({"question": query}))

async for token in streaming_callback_handler.aiter():
@@ -26,18 +27,30 @@ async def get_response_stream(chain: ConversationalRetrievalChain, query: str, streaming_callback_handler) -> str:
await run


def get_answer_chain(config: RagConfig, docsearch: VectorStore, memory, streaming_callback_handler, logging_callback_handler: LoggingCallbackHandler = None) -> ConversationalRetrievalChain:
def stream_get_response(chain: ConversationalRetrievalChain, query: str, streaming_callback_handler) -> str:
thread = Thread(target=lambda chain, query: chain.run({"question": query}), args=(chain, query))
thread.start()

while thread.is_alive() or not streaming_callback_handler.queue.empty():
if not streaming_callback_handler.queue.empty():
yield streaming_callback_handler.queue.get()
else:
sleep(0.1)

thread.join()

def get_answer_chain(config: RagConfig, docsearch: VectorStore, memory, streaming_callback_handler = None, logging_callback_handler: LoggingCallbackHandler = None) -> ConversationalRetrievalChain:
callbacks = [logging_callback_handler] if logging_callback_handler is not None else []

streaming_callback = [streaming_callback_handler] if streaming_callback_handler is not None else []

condense_question_prompt = PromptTemplate.from_template(prompts.condense_history)
condense_question_chain = LLMChain(llm=get_llm_model(config), prompt=condense_question_prompt, callbacks=callbacks)

messages = [
SystemMessage(content=prompts.rag_system_prompt),
HumanMessagePromptTemplate.from_template(prompts.respond_to_question),
]
question_answering_prompt = ChatPromptTemplate(messages=messages)
streaming_llm = get_llm_model(config, callbacks=[streaming_callback_handler] + callbacks)
streaming_llm = get_llm_model(config, callbacks=streaming_callback + callbacks)
question_answering_chain = LLMChain(llm=streaming_llm, prompt=question_answering_prompt, callbacks=callbacks)

context_with_docs_prompt = PromptTemplate(template=prompts.document_context, input_variables=["page_content", "source"])
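
The new `stream_get_response` runs the chain in a worker thread and drains a queue while the thread is alive, whereas `async_get_response` relies on an async iterator callback. A dependency-free sketch of the thread-plus-queue half (standard-library `queue.Queue` and made-up tokens here; the repo's handler below uses `multiprocessing.Queue`):

```python
from queue import Empty, Queue
from threading import Thread
from time import sleep
from typing import Iterator


def produce(tokens_out: Queue) -> None:
    # Stand-in for chain.run(...): the streaming callback would push tokens here.
    for token in ["Once", " upon", " a", " time"]:
        sleep(0.05)
        tokens_out.put(token)


def stream_tokens() -> Iterator[str]:
    tokens: Queue = Queue()
    worker = Thread(target=produce, args=(tokens,))
    worker.start()
    # Same loop shape as stream_get_response: drain until the worker is done
    # and the queue is empty.
    while worker.is_alive() or not tokens.empty():
        try:
            yield tokens.get(timeout=0.1)
        except Empty:
            continue
    worker.join()


for token in stream_tokens():
    print(token, end="", flush=True)
print()
```
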
4 changes: 4 additions & 0 deletions backend/rag_components/prompts.py
@@ -14,6 +14,10 @@
"""

respond_to_question = """
As a chatbot assistant, your mission is to respond to user inquiries in a precise and concise manner based on the documents provided as input.
It is essential to respond in the same language in which the question was asked. Responses must be written in a professional style and must demonstrate great attention to detail.
Respond to the question taking into account the following context.
{context}
15 changes: 12 additions & 3 deletions backend/rag_components/rag.py
@@ -7,7 +7,7 @@
from langchain.docstore.document import Document
from langchain.vectorstores.utils import filter_complex_metadata
from langchain.callbacks import AsyncIteratorCallbackHandler
from backend.rag_components.chain import get_answer_chain, get_response_stream
from backend.rag_components.chain import get_answer_chain, async_get_response, stream_get_response
from langchain.indexes import SQLRecordManager, index
from langchain.chat_models.base import BaseChatModel
from langchain.vectorstores import VectorStore
@@ -21,6 +21,7 @@
from backend.rag_components.embedding import get_embedding_model
from backend.rag_components.llm import get_llm_model
from backend.rag_components.logging_callback_handler import LoggingCallbackHandler
from backend.rag_components.streaming_callback_handler import StreamingCallbackHandler
from backend.rag_components.vector_store import get_vector_store


@@ -42,14 +43,22 @@ def generate_response(self, message: Message) -> str:
loop = asyncio.get_event_loop()
response_stream = self.async_generate_response(message)
responses = loop.run_until_complete(self._collect_responses(response_stream))
return "".join(responses)
return "".join([str(response) for response in responses])

def stream_generate_response(self, message: Message) -> AsyncIterator[str]:
memory = get_conversation_buffer_memory(self.config, message.chat_id)
streaming_callback_handler = StreamingCallbackHandler()
logging_callback_handler = LoggingCallbackHandler(self.logger, context=self.context)
answer_chain = get_answer_chain(self.config, self.vector_store, memory, streaming_callback_handler=streaming_callback_handler, logging_callback_handler=logging_callback_handler)
response_stream = stream_get_response(answer_chain, message.content, streaming_callback_handler)
return response_stream

def async_generate_response(self, message: Message) -> AsyncIterator[str]:
memory = get_conversation_buffer_memory(self.config, message.chat_id)
streaming_callback_handler = AsyncIteratorCallbackHandler()
logging_callback_handler = LoggingCallbackHandler(self.logger, context=self.context)
answer_chain = get_answer_chain(self.config, self.vector_store, memory, streaming_callback_handler=streaming_callback_handler, logging_callback_handler=logging_callback_handler)
response_stream = get_response_stream(answer_chain, message.content, streaming_callback_handler)
response_stream = async_get_response(answer_chain, message.content, streaming_callback_handler)
return response_stream

def load_file(self, file_path: Path) -> List[Document]:
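
rag.py now exposes three entry points: `generate_response` (blocking, collects the whole answer), `stream_generate_response` (a synchronous generator fed by a worker thread), and `async_generate_response` (an async iterator). A dependency-free sketch of the blocking-versus-async split, using stand-in functions rather than the repo's classes:

```python
import asyncio
from typing import AsyncIterator, List


async def token_source() -> AsyncIterator[str]:
    # Stand-in for the answer chain's async token stream.
    for token in ["stream", "ing ", "works"]:
        await asyncio.sleep(0.01)
        yield token


async def collect(stream: AsyncIterator[str]) -> List[str]:
    return [token async for token in stream]


def generate_blocking() -> str:
    # Mirrors generate_response: drive the async stream to completion, then join.
    loop = asyncio.new_event_loop()
    try:
        return "".join(loop.run_until_complete(collect(token_source())))
    finally:
        loop.close()


async def consume_async() -> None:
    # Mirrors consumers of async_generate_response: handle tokens as they arrive.
    async for token in token_source():
        print(token, end="", flush=True)


print(generate_blocking())
asyncio.run(consume_async())
print()
```
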
11 changes: 11 additions & 0 deletions backend/rag_components/streaming_callback_handler.py
@@ -0,0 +1,11 @@
from multiprocessing import Queue
from typing import AnyStr
from langchain_core.callbacks.base import BaseCallbackHandler

class StreamingCallbackHandler(BaseCallbackHandler):
queue = Queue()

def on_llm_new_token(self, token: str, **kwargs: AnyStr) -> None:
"""Run on new LLM token. Only available when streaming is enabled."""
if token is not None and token != "":
self.queue.put_nowait(token)
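
A hypothetical usage sketch of the new handler: tokens are pushed in by the LLM callback during streaming and drained by a consumer such as `stream_get_response`. Note that `queue` is declared at class level, so it is shared by every `StreamingCallbackHandler` instance. This assumes the repo package is importable; the token values are made up.

```python
from backend.rag_components.streaming_callback_handler import StreamingCallbackHandler

handler = StreamingCallbackHandler()
for token in ["Hel", "lo", "!"]:
    handler.on_llm_new_token(token)

# Drain exactly the three tokens we pushed; get() blocks until each is available.
for _ in range(3):
    print(handler.queue.get(), end="")
print()
```
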
16 changes: 0 additions & 16 deletions lel.py

This file was deleted.

(Diffs for the remaining 3 changed files were not loaded.)
