Commit: ruff

linogaliana committed Nov 20, 2024
1 parent a6030de commit b815a5c

Showing 6 changed files with 43,409 additions and 221 deletions.
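The commit message and the shape of every hunk below point to an automated formatting pass. As an illustration assembled from the first hunks of this diff, ruff's formatter (typically invoked as `ruff format .`) prefers double quotes and collapses short multi-line calls onto one line; `dict` stands in for `cl.Message` here so the snippet runs without Chainlit:

```python
# Before formatting: single-quoted string, short call split over three lines.
CHROMA_DB_LOCAL_DIRECTORY = 'data/chroma_db'
message = dict(
    content="Bienvenue sur le ChatBot de l'INSEE!"
)

# After `ruff format`: double quotes, one-line call, identical behavior.
CHROMA_DB_LOCAL_DIRECTORY = "data/chroma_db"
message = dict(content="Bienvenue sur le ChatBot de l'INSEE!")
```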
63 changes: 19 additions & 44 deletions app-minimal.py
@@ -31,9 +31,9 @@
max_new_tokens = 10
model_temperature = 1
embedding = os.getenv("EMB_MODEL_NAME", EMB_MODEL_NAME)
-CHROMA_DB_LOCAL_DIRECTORY = 'data/chroma_db'
+CHROMA_DB_LOCAL_DIRECTORY = "data/chroma_db"

-os.environ['MLFLOW_TRACKING_URI'] = "https://projet-llm-insee-open-data-mlflow.user.lab.sspcloud.fr/"
+os.environ["MLFLOW_TRACKING_URI"] = "https://projet-llm-insee-open-data-mlflow.user.lab.sspcloud.fr/"
fs = s3fs.S3FileSystem(client_kwargs={"endpoint_url": f"""https://{os.environ["AWS_S3_ENDPOINT"]}"""})

# APPLICATION -----------------------------------------
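The S3 client above targets a non-AWS endpoint through `client_kwargs`, the standard s3fs pattern for S3-compatible stores such as the SSP Cloud's MinIO. A minimal sketch, assuming credentials come from the usual AWS_* environment variables; the bucket name is hypothetical:

```python
import os
import s3fs

# Point s3fs at any S3-compatible endpoint; credentials are picked up from
# AWS_ACCESS_KEY_ID / AWS_SECRET_ACCESS_KEY (and AWS_SESSION_TOKEN if set).
fs = s3fs.S3FileSystem(
    client_kwargs={"endpoint_url": f"https://{os.environ['AWS_S3_ENDPOINT']}"}
)
print(fs.ls("my-bucket"))  # hypothetical bucket name
```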
@@ -42,9 +42,7 @@
@cl.on_chat_start
async def on_chat_start():
# Initial message
-init_msg = cl.Message(
-    content="Bienvenue sur le ChatBot de l'INSEE!"
-)
+init_msg = cl.Message(content="Bienvenue sur le ChatBot de l'INSEE!")
await init_msg.send()

# Logging configuration
@@ -59,27 +57,18 @@ async def on_chat_start():
],
).send()
if res and res.get("value") == "log":
-await cl.Message(
-    content="Vous avez choisi de partager vos interactions."
-).send()
+await cl.Message(content="Vous avez choisi de partager vos interactions.").send()
if res and res.get("value") == "no_log":
IS_LOGGING_ON = False
-await cl.Message(
-    content="Vous avez choisi de garder vos interactions avec le ChatBot privées."
-).send()
+await cl.Message(content="Vous avez choisi de garder vos interactions avec le ChatBot privées.").send()
cl.user_session.set("IS_LOGGING_ON", IS_LOGGING_ON)


# Build chat model
RETRIEVER_ONLY = str_to_bool(os.getenv("RETRIEVER_ONLY", "false"))
cl.user_session.set("RETRIEVER_ONLY", RETRIEVER_ONLY)
logging.info("------ chatbot mode : RAG")


-db = await cl.make_async(load_vector_database)(
-    filesystem=fs,
-    database_run_id="32d4150a14fa40d49b9512e1f3ff9e8c"
-)
+db = await cl.make_async(load_vector_database)(filesystem=fs, database_run_id="32d4150a14fa40d49b9512e1f3ff9e8c")

llm, tokenizer = await cl.make_async(build_llm_model)(
model_name=model,
@@ -101,19 +90,15 @@ async def on_chat_start():
cl.user_session.set("validator", validator)
logging.info("------validator loaded")

-RAG_PROMPT_TEMPLATE = tokenizer.apply_chat_template(
-    CHATBOT_TEMPLATE, tokenize=False, add_generation_prompt=True
-)
-prompt = PromptTemplate(
-    input_variables=["context", "question"], template=RAG_PROMPT_TEMPLATE
-)
+RAG_PROMPT_TEMPLATE = tokenizer.apply_chat_template(CHATBOT_TEMPLATE, tokenize=False, add_generation_prompt=True)
+prompt = PromptTemplate(input_variables=["context", "question"], template=RAG_PROMPT_TEMPLATE)
logging.info("------prompt loaded")
retriever, vectorstore = await cl.make_async(load_retriever)(
-    emb_model_name="OrdalieTech/Solon-embeddings-large-0.1",
-    vectorstore=db,
-    persist_directory=CHROMA_DB_LOCAL_DIRECTORY,
-    retriever_params={"search_type": "similarity", "search_kwargs": {"k": 30}},
-)
+    emb_model_name="OrdalieTech/Solon-embeddings-large-0.1",
+    vectorstore=db,
+    persist_directory=CHROMA_DB_LOCAL_DIRECTORY,
+    retriever_params={"search_type": "similarity", "search_kwargs": {"k": 30}},
+)
logging.info("------retriever loaded")
logging.info(f"----- {len(vectorstore.get()['documents'])} documents")
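`load_retriever` is project code that this diff does not show. A plausible minimal sketch of a helper with this signature, assuming LangChain's `langchain_community` Chroma wrapper and HuggingFace embeddings; the repo's real implementation may differ:

```python
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import Chroma

def load_retriever(emb_model_name, vectorstore=None, persist_directory=None, retriever_params=None):
    # Query-time embeddings must use the same model that built the collection.
    embeddings = HuggingFaceEmbeddings(model_name=emb_model_name)
    if vectorstore is None:
        # Reopen a Chroma collection persisted on disk.
        vectorstore = Chroma(persist_directory=persist_directory, embedding_function=embeddings)
    # e.g. retriever_params={"search_type": "similarity", "search_kwargs": {"k": 30}}
    retriever = vectorstore.as_retriever(**(retriever_params or {}))
    return retriever, vectorstore
```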

@@ -140,9 +125,7 @@ async def on_message(message: cl.Message):
Handle incoming messages and process the response using the RAG chain.
"""
validator = cl.user_session.get("validator")
-test_relevancy = await check_query_relevance(
-    validator=validator, query=message.content
-)
+test_relevancy = await check_query_relevance(validator=validator, query=message.content)
if test_relevancy:
# Retrieve the chain from the user session
chain = cl.user_session.get("chain")
@@ -155,9 +138,7 @@ async def on_message(message: cl.Message):
# Generate ChatBot's answer
async for chunk in chain.astream(
message.content,
-config=RunnableConfig(
-    callbacks=[cl.AsyncLangchainCallbackHandler(stream_final_answer=True)]
-),
+config=RunnableConfig(callbacks=[cl.AsyncLangchainCallbackHandler(stream_final_answer=True)]),
):
if "answer" in chunk:
await answer_msg.stream_token(chunk["answer"])
@@ -188,24 +169,18 @@ async def on_message(message: cl.Message):
thread_id=message.thread_id,
message_id=sources_msg.id,
user_query=message.content,
-generated_answer=(
-    None if cl.user_session.get("RETRIEVER_ONLY") else generated_answer
-),
+generated_answer=(None if cl.user_session.get("RETRIEVER_ONLY") else generated_answer),
retrieved_documents=docs,
embedding_model_name=embedding_model_name,
LLM_name=None if cl.user_session.get("RETRIEVER_ONLY") else LLM_name,
reranker=reranker,
)
else:
-await cl.Message(
-    content=f"Votre requête '{message.content}' ne concerne pas les domaines d'expertise de l'INSEE."
-).send()
+await cl.Message(content=f"Votre requête '{message.content}' ne concerne pas les domaines d'expertise de l'INSEE.").send()


async def check_query_relevance(validator, query):
-result = await validator.ainvoke(
-    query, config=RunnableConfig(callbacks=[cl.AsyncLangchainCallbackHandler()])
-)
+result = await validator.ainvoke(query, config=RunnableConfig(callbacks=[cl.AsyncLangchainCallbackHandler()]))
return result
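The `validator` runnable is built elsewhere in the repo; all this handler assumes is that `ainvoke(query)` returns something truthy for in-scope questions. One plausible shape for such a chain, sketched in the same LCEL style the app uses, with hypothetical prompt wording:

```python
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnablePassthrough

def build_validator(llm):
    # Ask the model for a OUI/NON verdict, then map the text to a bool.
    prompt = PromptTemplate.from_template(
        "Réponds uniquement par OUI ou NON. La question suivante relève-t-elle "
        "des domaines d'expertise de l'INSEE ?\n\nQuestion : {query}"
    )
    chain = {"query": RunnablePassthrough()} | prompt | llm | StrOutputParser()
    return chain | (lambda verdict: verdict.strip().upper().startswith("OUI"))
```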


@@ -220,4 +195,4 @@ async def upsert_feedback(self, feedback: cl_data.Feedback) -> str:


# Enable data persistence for human feedbacks
-cl_data._data_layer = CustomDataLayer()
\ No newline at end of file
+cl_data._data_layer = CustomDataLayer()
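`CustomDataLayer` plugs into Chainlit's data layer so that human feedback is persisted. A minimal sketch of the feedback hook whose signature appears in the hunk above; a complete subclass would implement the rest of the `BaseDataLayer` interface, and the real persistence target here is not shown in the diff:

```python
import logging

import chainlit.data as cl_data

class CustomDataLayer(cl_data.BaseDataLayer):
    async def upsert_feedback(self, feedback: cl_data.Feedback) -> str:
        # Persist wherever suits the deployment (S3, a database, ...);
        # this sketch only logs the rating and the message it refers to.
        logging.info("Feedback %s on message %s", feedback.value, feedback.forId)
        return feedback.id or ""
```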
49 changes: 10 additions & 39 deletions app.py
@@ -18,9 +18,10 @@

from src.config import CHATBOT_TEMPLATE, EMB_MODEL_NAME
from utils import (
-    format_docs, create_prompt_from_instructions,
-    retrieve_llm_from_cache,
-    retrieve_db_from_cache
+    format_docs,
+    create_prompt_from_instructions,
+    # retrieve_llm_from_cache,
+    retrieve_db_from_cache,
)

# Logging configuration
@@ -32,12 +33,8 @@
)

# Remote file configuration
-os.environ["MLFLOW_TRACKING_URI"] = (
-    "https://projet-llm-insee-open-data-mlflow.user.lab.sspcloud.fr/"
-)
-fs = s3fs.S3FileSystem(
-    client_kwargs={"endpoint_url": f"""https://{os.environ["AWS_S3_ENDPOINT"]}"""}
-)
+os.environ["MLFLOW_TRACKING_URI"] = "https://projet-llm-insee-open-data-mlflow.user.lab.sspcloud.fr/"
+fs = s3fs.S3FileSystem(client_kwargs={"endpoint_url": f"""https://{os.environ["AWS_S3_ENDPOINT"]}"""})

# PARAMETERS --------------------------------------

@@ -92,28 +89,14 @@
"""


-prompt = create_prompt_from_instructions(
-    system_instructions, question_instructions
-)
+prompt = create_prompt_from_instructions(system_instructions, question_instructions)


# CHAT START -------------------------------

-def format_docs2(docs: list):
-    return "\n\n".join(
-        [
-            f"""
-            Doc {i + 1}:\nTitle: {doc.metadata.get("Header 1")}\n
-            Source: {doc.metadata.get("url")}\n
-            Content:\n{doc.page_content}
-            """
-            for i, doc in enumerate(docs)
-        ]
-    )

@cl.on_chat_start
async def on_chat_start():
-
# Initial message
init_msg = cl.Message(content="Bienvenue sur le ChatBot de l'INSEE!")
await init_msg.send()
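`create_prompt_from_instructions` is a project helper imported from `utils`. Judging from its call site and from the `{context}`/`{question}` variables the RAG chain feeds into the prompt, a plausible sketch (an assumption, not the repo's actual code):

```python
from langchain_core.prompts import PromptTemplate

def create_prompt_from_instructions(system_instructions: str, question_instructions: str) -> PromptTemplate:
    # Concatenate the two instruction blocks into one template exposing the
    # {context} and {question} slots filled by the RAG chain below.
    template = f"{system_instructions}\n\n{question_instructions}"
    return PromptTemplate(input_variables=["context", "question"], template=template)
```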
@@ -124,10 +107,7 @@ async def on_chat_start():

logging.info("------ database loaded")

-db = retrieve_db_from_cache(
-    filesystem=fs,
-    run_id=DATABASE_RUN_ID
-)
+db = retrieve_db_from_cache(filesystem=fs, run_id=DATABASE_RUN_ID)

logging.info("------ database loaded")
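`retrieve_db_from_cache` pulls a vector database logged as an MLflow run artifact, identified by `DATABASE_RUN_ID`, from the tracking server configured above. The core of such a helper can be sketched with MLflow's public artifact API; the run id shown is the one app-minimal.py uses, and the caching logic is omitted:

```python
import mlflow

# MLFLOW_TRACKING_URI is read from the environment (set earlier in app.py).
# Downloads the run's logged artifacts, e.g. a persisted Chroma directory,
# and returns the local path they were materialized to.
local_path = mlflow.artifacts.download_artifacts(run_id="32d4150a14fa40d49b9512e1f3ff9e8c")
print(local_path)
```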

@@ -153,12 +133,7 @@ async def on_chat_start():

logging.info("------ VLLM object ready")

-rag_chain = (
-    {"context": retriever | format_docs, "question": RunnablePassthrough()}
-    | prompt
-    | llm
-    | StrOutputParser()
-)
+rag_chain = {"context": retriever | format_docs, "question": RunnablePassthrough()} | prompt | llm | StrOutputParser()

cl.user_session.set("rag_chain", rag_chain)
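The reformatted `rag_chain` is a standard LCEL composition: the dict fans the incoming question out to a `context` branch (retriever, then `format_docs`) and a `question` passthrough, and the merged dict flows through prompt, model, and output parser. A self-contained toy version, with stand-ins in place of the real retriever and LLM:

```python
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnableLambda, RunnablePassthrough

# Stubs so the pipeline shape is visible without a model or a vector store.
retriever = RunnableLambda(lambda query: [f"(document about {query})"])

def format_docs(docs):
    return "\n\n".join(docs)

llm = RunnableLambda(lambda prompt_value: f"Réponse fondée sur : {prompt_value.text[:60]}")

prompt = PromptTemplate.from_template("Context:\n{context}\n\nQuestion: {question}")
rag_chain = {"context": retriever | format_docs, "question": RunnablePassthrough()} | prompt | llm | StrOutputParser()

print(rag_chain.invoke("population légale de la France"))
```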

@@ -168,16 +143,12 @@ async def on_chat_start():

@cl.on_message
async def on_message(message: cl.Message):
-
rag_chain = cl.user_session.get("rag_chain")

answer_msg = cl.Message(content="")

async for chunk in rag_chain.astream(
-    message.content,
-    config=RunnableConfig(
-        callbacks=[cl.AsyncLangchainCallbackHandler(stream_final_answer=True)]
-    )
+    message.content, config=RunnableConfig(callbacks=[cl.AsyncLangchainCallbackHandler(stream_final_answer=True)])
):
await answer_msg.send()
await cl.sleep(1)
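Note that this loop calls `send()` on every chunk; the app-minimal.py handler earlier in this commit streams tokens into a single message instead, the usual Chainlit pattern. A sketch of that variant, placed inside the same handler (the chain ends in `StrOutputParser`, so chunks arrive as plain strings):

```python
answer_msg = cl.Message(content="")
async for chunk in rag_chain.astream(
    message.content,
    config=RunnableConfig(callbacks=[cl.AsyncLangchainCallbackHandler(stream_final_answer=True)]),
):
    await answer_msg.stream_token(chunk)
await answer_msg.send()
```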
[Diffs for the remaining changed files did not load.]