Commit aeb3c59 (1 parent: c8e3bc4)
Showing 8 changed files with 82 additions and 115 deletions.
@@ -1,74 +1,45 @@
-import asyncio
-from threading import Thread
-from time import sleep
+from operator import itemgetter
 
-from langchain.chains import ConversationalRetrievalChain, LLMChain
-from langchain.chains.combine_documents.stuff import StuffDocumentsChain
-from langchain.prompts import (
-    ChatPromptTemplate,
-    HumanMessagePromptTemplate,
-    PromptTemplate,
-)
-from langchain.vectorstores import VectorStore
+from langchain.vectorstores.base import VectorStore
+from langchain_core.runnables import RunnablePassthrough, RunnableParallel
+from langchain.prompts import ChatPromptTemplate, PromptTemplate
+from langchain_core.output_parsers import StrOutputParser
+from langchain.schema import format_document
+from langchain.memory import ConversationBufferWindowMemory
 
 from backend.config import RagConfig
+from backend.rag_components.llm import get_llm_model
 from backend.rag_components import prompts
-from backend.rag_components.llm import get_llm_model
 from backend.rag_components.logging_callback_handler import LoggingCallbackHandler
 
+DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template(template="{page_content}")
 
-async def async_get_response(chain: ConversationalRetrievalChain, query: str, streaming_callback_handler) -> str:
-    run = asyncio.create_task(chain.arun({"question": query}))
-
-    async for token in streaming_callback_handler.aiter():
-        yield token
-
-    await run
-
-
-def stream_get_response(chain: ConversationalRetrievalChain, query: str, streaming_callback_handler) -> str:
-    thread = Thread(target=lambda chain, query: chain.run({"question": query}), args=(chain, query))
-    thread.start()
-
-    while thread.is_alive() or not streaming_callback_handler.queue.empty():
-        if not streaming_callback_handler.queue.empty():
-            yield streaming_callback_handler.queue.get()
-        else:
-            sleep(0.1)
-
-    thread.join()
-
-
-def get_answer_chain(config: RagConfig, docsearch: VectorStore, memory, streaming_callback_handler = None, logging_callback_handler: LoggingCallbackHandler = None) -> ConversationalRetrievalChain:
-    callbacks = [logging_callback_handler] if logging_callback_handler is not None else []
-    streaming_callback = [streaming_callback_handler] if streaming_callback_handler is not None else []
+def get_answer_chain(config: RagConfig, vector_store: VectorStore, memory: ConversationBufferWindowMemory, logging_callback_handler: LoggingCallbackHandler = None):
+    llm_callbacks = [logging_callback_handler] if logging_callback_handler is not None else []
+    llm = get_llm_model(config, callbacks=llm_callbacks)
+
+    retriever = vector_store.as_retriever(search_type=config.vector_store.retreiver_search_type, search_kwargs=config.vector_store.retreiver_config)
 
     condense_question_prompt = PromptTemplate.from_template(prompts.condense_history)
-    condense_question_chain = LLMChain(llm=get_llm_model(config), prompt=condense_question_prompt, callbacks=callbacks)
-
-    messages = [
-        HumanMessagePromptTemplate.from_template(prompts.respond_to_question),
-    ]
-    question_answering_prompt = ChatPromptTemplate(messages=messages)
-    streaming_llm = get_llm_model(config, callbacks=streaming_callback + callbacks)
-    question_answering_chain = LLMChain(llm=streaming_llm, prompt=question_answering_prompt, callbacks=callbacks)
-
-    context_with_docs_prompt = PromptTemplate(template=prompts.document_context, input_variables=["page_content", "source"])
-
-    stuffed_qa_chain = StuffDocumentsChain(
-        llm_chain=question_answering_chain,
-        document_variable_name="context",
-        document_prompt=context_with_docs_prompt,
-        callbacks=callbacks
-    )
-
-    chain = ConversationalRetrievalChain(
-        question_generator=condense_question_chain,
-        retriever=docsearch.as_retriever(search_type=config.vector_store.retreiver_search_type, search_kwargs=config.vector_store.retreiver_config),
-        memory=memory,
-        max_tokens_limit=config.max_tokens_limit,
-        combine_docs_chain=stuffed_qa_chain,
-        callbacks=callbacks
-    )
-
-    return chain
+    question_answering_prompt = ChatPromptTemplate.from_template(prompts.respond_to_question)
+
+    _inputs = RunnableParallel(
+        standalone_question=RunnablePassthrough.assign(chat_history=lambda _: memory.buffer_as_str)
+        | condense_question_prompt
+        | llm
+        | StrOutputParser(),
+    )
+    _context = {
+        "context": itemgetter("standalone_question") | retriever | _combine_documents,
+        "question": lambda x: x["standalone_question"],
+    }
+    conversational_qa_chain = _inputs | _context | question_answering_prompt | llm
+    return conversational_qa_chain
+
+
+def _combine_documents(docs, document_prompt=DEFAULT_DOCUMENT_PROMPT, document_separator="\n\n"):
+    doc_strings = [format_document(doc, document_prompt) for doc in docs]
+    return document_separator.join(doc_strings)
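For reviewers new to the LCEL style this commit adopts: get_answer_chain now returns a composed Runnable rather than a ConversationalRetrievalChain, so callers consume it through the standard Runnable interface (invoke/stream) instead of chain.run. Below is a minimal sketch of the calling side. The import path backend.rag_components.chain, the helper name answer, and the memory construction are illustrative assumptions, not part of this commit.

    from langchain.memory import ConversationBufferWindowMemory
    from langchain.vectorstores.base import VectorStore

    from backend.config import RagConfig
    from backend.rag_components.chain import get_answer_chain  # assumed module path


    def answer(config: RagConfig, vector_store: VectorStore, question: str) -> str:
        # The chain reads memory.buffer_as_str while condensing the question,
        # so any prior turns must already be in the memory object.
        memory = ConversationBufferWindowMemory(memory_key="chat_history")
        chain = get_answer_chain(config, vector_store, memory)

        # Same {"question": ...} input the old ConversationalRetrievalChain took.
        # Because the pipeline ends with the LLM itself, stream() yields model
        # chunks: chat models emit message chunks with a .content attribute,
        # plain LLMs emit strings, so handle both.
        parts = []
        for chunk in chain.stream({"question": question}):
            parts.append(getattr(chunk, "content", chunk))
        answer_text = "".join(parts)

        # Unlike ConversationalRetrievalChain(memory=...), the LCEL chain does
        # not persist the turn itself; the caller saves it explicitly.
        memory.save_context({"question": question}, {"answer": answer_text})
        return answer_text

One behavioural consequence visible in the diff: the old chain was built with memory=memory, so ConversationalRetrievalChain wrote each turn back into memory itself; the new pipeline only reads memory.buffer_as_str when condensing the question, so persisting the exchange (save_context above) now falls to the caller.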