Commit
b63ffc2 (1 parent: 960d377)
Showing 7 changed files with 125 additions and 139 deletions.
@@ -1,88 +1,53 @@
 import asyncio
 from langchain.chains import ConversationalRetrievalChain, LLMChain
 from langchain.chains.combine_documents.stuff import StuffDocumentsChain
-from langchain.chat_models.base import HumanMessage, SystemMessage
+from langchain.chat_models.base import SystemMessage
 from langchain.prompts import (
     ChatPromptTemplate,
     HumanMessagePromptTemplate,
     PromptTemplate,
 )
 from langchain.vectorstores import VectorStore
 
 from backend.config_renderer import get_config
 from backend.rag_components.chat_message_history import get_conversation_buffer_memory
 from backend.rag_components.embedding import get_embedding_model
 from backend.rag_components.llm import get_llm_model
 from backend.rag_components.vector_store import get_vector_store
+from backend.rag_components import prompts
 
 
-def get_response(answer_chain: ConversationalRetrievalChain, query: str) -> str:
-    """Processes the given query through the answer chain and returns the formatted response."""
-    {"content": answer_chain.run(query), }
-    return answer_chain.run(query)
+async def get_response_stream(chain: ConversationalRetrievalChain, callback_handler, query: str) -> str:
+    run = asyncio.create_task(chain.arun({"question": query}))
+
+    async for token in callback_handler.aiter():
+        yield token
+
+    await run
 
 
-def get_answer_chain(llm, docsearch: VectorStore, memory) -> ConversationalRetrievalChain:
-    """Returns an instance of ConversationalRetrievalChain based on the provided parameters."""
-    template = """Given the conversation history and the following question, can you rephrase the user's question in its original language so that it is self-sufficient. Make sure to avoid the use of unclear pronouns.
-
-Chat history :
-{chat_history}
-Question : {question}
-
-Rephrased question :
-"""
-    condense_question_prompt = PromptTemplate.from_template(template)
-    condense_question_chain = LLMChain(
-        llm=llm,
-        prompt=condense_question_prompt,
-    )
+def get_answer_chain(config, docsearch: VectorStore, memory) -> ConversationalRetrievalChain:
+    condense_question_prompt = PromptTemplate.from_template(prompts.condense_history)
+    condense_question_chain = LLMChain(llm=get_llm_model(config), prompt=condense_question_prompt)
 
     messages = [
-        SystemMessage(
-            content=(
-                """As a chatbot assistant, your mission is to respond to user inquiries in a precise and concise manner based on the documents provided as input. It is essential to respond in the same language in which the question was asked. Responses must be written in a professional style and must demonstrate great attention to detail."""
-            )
-        ),
-        HumanMessage(content="Respond to the question taking into account the following context."),
-        HumanMessagePromptTemplate.from_template("{context}"),
-        HumanMessagePromptTemplate.from_template("Question: {question}"),
+        SystemMessage(content=prompts.rag_system_prompt),
+        HumanMessagePromptTemplate.from_template(prompts.respond_to_question),
     ]
-    system_prompt = ChatPromptTemplate(messages=messages)
-    qa_chain = LLMChain(
-        llm=llm,
-        prompt=system_prompt,
-    )
+    question_answering_prompt = ChatPromptTemplate(messages=messages)
+    streaming_llm, callback_handler = get_llm_model(config, streaming=True)
+    question_answering_chain = LLMChain(llm=streaming_llm, prompt=question_answering_prompt)
 
-    doc_prompt = PromptTemplate(
-        template="Content: {page_content}\nSource: {source}",
-        input_variables=["page_content", "source"],
-    )
+    context_with_docs_prompt = PromptTemplate(template=prompts.document_context, input_variables=["page_content", "source"])
 
     final_qa_chain = StuffDocumentsChain(
-        llm_chain=qa_chain,
+        llm_chain=question_answering_chain,
         document_variable_name="context",
-        document_prompt=doc_prompt,
+        document_prompt=context_with_docs_prompt,
     )
 
-    return ConversationalRetrievalChain(
+    chain = ConversationalRetrievalChain(
         question_generator=condense_question_chain,
-        retriever=docsearch.as_retriever(search_kwargs={"k": 10}),
+        retriever=docsearch.as_retriever(search_kwargs={"k": config["vector_store_provider"]["documents_to_retreive"]}),
         memory=memory,
         combine_docs_chain=final_qa_chain,
         verbose=True,
     )
 
-
-if __name__ == "__main__":
-    chat_id = "test"
-    config = get_config()
-    llm = get_llm_model(config)
-    embeddings = get_embedding_model(config)
-    vector_store = get_vector_store(embeddings, config)
-    memory = get_conversation_buffer_memory(config, chat_id)
-    answer_chain = get_answer_chain(llm, vector_store, memory)
-
-    prompt = "Give me the top 5 bilionnaires in france based on their worth in order of decreasing net worth"
-    response = get_response(answer_chain, prompt)
-    print("Prompt: ", prompt)
-    print("Response: ", response)
+    return chain, callback_handler
@@ -1,10 +1,17 @@
 from langchain import chat_models
+from langchain.callbacks import AsyncIteratorCallbackHandler
 
 
-def get_llm_model(config):
+def get_llm_model(config, streaming=False):
     llm_spec = getattr(chat_models, config["llm_model_config"]["model_source"])
     all_config_field = {**config["llm_model_config"], **config["llm_provider_config"]}
     kwargs = {
         key: value for key, value in all_config_field.items() if key in llm_spec.__fields__.keys()
     }
-    return llm_spec(**kwargs)
+    if streaming:
+        kwargs["streaming"] = streaming
+        callback_handler = AsyncIteratorCallbackHandler()
+        kwargs["callbacks"] = [callback_handler]
+        return llm_spec(**kwargs), callback_handler
+    else:
+        return llm_spec(**kwargs)
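Here get_llm_model, the factory the chain imports from backend.rag_components.llm, gains a streaming mode. Note that the flag changes the return shape: a bare model by default, a (model, AsyncIteratorCallbackHandler) pair when streaming=True, so call sites must unpack accordingly. A short sketch of the two call shapes, with config standing in for the dict produced by get_config():

# `config` is a stand-in for the rendered config dict, not a value from this commit.
llm = get_llm_model(config)  # default: just the chat model

streaming_llm, callback_handler = get_llm_model(config, streaming=True)
# callback_handler.aiter() yields tokens as the model emits them, which is
# exactly what get_response_stream consumes in the chain module above.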
@@ -0,0 +1,28 @@
+condense_history = """
+Given the conversation history and the following question, can you rephrase the user's question in its original language so that it is self-sufficient. Make sure to avoid the use of unclear pronouns.
+
+Chat history :
+{chat_history}
+Question : {question}
+
+Rephrased question :
+"""
+
+
+rag_system_prompt = """
+As a chatbot assistant, your mission is to respond to user inquiries in a precise and concise manner based on the documents provided as input.
+It is essential to respond in the same language in which the question was asked. Responses must be written in a professional style and must demonstrate great attention to detail.
+"""
+
+
+respond_to_question = """
+Respond to the question taking into account the following context.
+{context}
+Question: {question}
+"""
+
+
+document_context = """
+Content: {page_content}
+Source: {source}
+"""