Commit a6030de
Gros cleaning
linogaliana committed Nov 20, 2024
1 parent 1aa62aa commit a6030de
Showing 27 changed files with 426 additions and 872 deletions.
47 changes: 11 additions & 36 deletions src/chain_building/build_chain.py
@@ -32,9 +32,7 @@ def format_docs(docs: list):
 
 
 # Define the compression function
-def compress_documents_lambda(
-    documents: Sequence[Document], query: str, k: int = 5, **kwargs: dict[str, Any]
-) -> Sequence[Document]:
+def compress_documents_lambda(documents: Sequence[Document], query: str, k: int = 5, **kwargs: dict[str, Any]) -> Sequence[Document]:
     """Compress retrieved documents given the query context."""
 
     # Initialize the retriever with the documents
@@ -58,17 +56,13 @@ def build_chain(
 
     accepted_rerankers = [None, "BM25", "Cross-encoder", "Ensemble"]
     if reranker not in accepted_rerankers:
-        raise ValueError(
-            f"Invalid reranker: {reranker}. Accepted values are: {', '.join(accepted_rerankers)}"
-        )
+        raise ValueError(f"Invalid reranker: {reranker}. Accepted values are: {', '.join(accepted_rerankers)}")
 
     # Define the retrieval reranker strategy
     if reranker is None:
         retrieval_agent = retriever
     elif reranker == "BM25":
-        retrieval_agent = RunnableParallel(
-            {"documents": retriever, "query": RunnablePassthrough()}
-        ) | RunnableLambda(
+        retrieval_agent = RunnableParallel({"documents": retriever, "query": RunnablePassthrough()}) | RunnableLambda(
             lambda r: compress_documents_lambda(
                 documents=r["documents"],
                 query=r["query"],
@@ -77,17 +71,11 @@
         )
     elif reranker == "Cross-encoder":
         model = HuggingFaceCrossEncoder(model_name=RERANKER_CROSS_ENCODER)
-        compressor = CrossEncoderReranker(
-            model=model, top_n=number_candidates_reranking
-        )
-        retrieval_agent = ContextualCompressionRetriever(
-            base_compressor=compressor, base_retriever=retriever
-        )
+        compressor = CrossEncoderReranker(model=model, top_n=number_candidates_reranking)
+        retrieval_agent = ContextualCompressionRetriever(base_compressor=compressor, base_retriever=retriever)
     elif reranker == "Ensemble":
         # BM25
-        reranker_1 = RunnableParallel(
-            {"documents": retriever, "query": RunnablePassthrough()}
-        ) | RunnableLambda(
+        reranker_1 = RunnableParallel({"documents": retriever, "query": RunnablePassthrough()}) | RunnableLambda(
             lambda r: compress_documents_lambda(
                 documents=r["documents"],
                 query=r["query"],
@@ -99,31 +87,18 @@
             model=HuggingFaceCrossEncoder(model_name=RERANKER_CROSS_ENCODER),
             top_n=number_candidates_reranking,
         )
-        reranker_2 = ContextualCompressionRetriever(
-            base_compressor=compressor, base_retriever=retriever
-        )
+        reranker_2 = ContextualCompressionRetriever(base_compressor=compressor, base_retriever=retriever)
 
-        retrieval_agent = EnsembleRetriever(
-            retrievers=[reranker_1, reranker_2], weights=[1 / 2, 1 / 2]
-        )
+        retrieval_agent = EnsembleRetriever(retrievers=[reranker_1, reranker_2], weights=[1 / 2, 1 / 2])
     else:
         raise ValueError(f"Reranking method {reranker} is not implemented.")
 
     if llm is not None:
         # Create a Langchain LLM Chain
-        chain = (
-            RunnablePassthrough.assign(context=(lambda x: format_docs(x["context"])))
-            | prompt
-            | llm
-            | StrOutputParser()
-        )
-        rag_chain_with_source = RunnableParallel(
-            {"context": retrieval_agent, "question": RunnablePassthrough()}
-        ).assign(answer=chain)
+        chain = RunnablePassthrough.assign(context=(lambda x: format_docs(x["context"]))) | prompt | llm | StrOutputParser()
+        rag_chain_with_source = RunnableParallel({"context": retrieval_agent, "question": RunnablePassthrough()}).assign(answer=chain)
     else:
         # retriever mode
-        rag_chain_with_source = RunnableParallel(
-            {"context": retrieval_agent, "question": RunnablePassthrough()}
-        )
+        rag_chain_with_source = RunnableParallel({"context": retrieval_agent, "question": RunnablePassthrough()})
 
     return rag_chain_with_source
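
For orientation, a minimal usage sketch of this factory after the cleanup. It is an assumption, not part of the commit: the keyword names are inferred from the function body, and the retriever, prompt, and llm objects are placeholders built elsewhere in the project.

    # Hypothetical call; parameter names inferred from the body of build_chain.
    rag_chain = build_chain(
        retriever=retriever,
        prompt=prompt,
        llm=llm,
        reranker="Ensemble",  # accepted: None, "BM25", "Cross-encoder", "Ensemble"
        number_candidates_reranking=5,
    )
    output = rag_chain.invoke("Quelle est la population de la France ?")
    print(output["answer"])   # generated answer
    print(output["context"])  # retrieved source documents

When llm is None, the same invocation returns only the "context" and "question" keys, matching the retriever-only branch above.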
10 changes: 2 additions & 8 deletions src/chain_building/build_chain_validator.py
@@ -48,13 +48,7 @@ def build_chain_validator(evaluator_llm=None, tokenizer=None):
     defining a chain to check if a given query is related to INSEE expertise.
     """
 
-    prompt_template = tokenizer.apply_chat_template(
-        EVAL_TEMPLATE, tokenize=False, add_generation_prompt=True
-    )
+    prompt_template = tokenizer.apply_chat_template(EVAL_TEMPLATE, tokenize=False, add_generation_prompt=True)
     prompt = PromptTemplate(template=prompt_template, input_variables=["query"])
 
-    return (
-        prompt
-        | evaluator_llm
-        | RunnableLambda(func=lambda generation: generation.lower().find("oui") != -1)
-    )
+    return prompt | evaluator_llm | RunnableLambda(func=lambda generation: generation.lower().find("oui") != -1)
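
The returned chain ends by reducing the evaluator's generation to a boolean. A self-contained sketch of just that final step, grounded in the RunnableLambda above:

    def is_relevant(generation: str) -> bool:
        # True when the evaluator answered "oui" (yes) anywhere in its
        # output, case-insensitively; mirrors the lambda in the diff.
        return generation.lower().find("oui") != -1

    assert is_relevant("Oui, cette question relève de l'INSEE.")
    assert not is_relevant("Non, hors sujet.")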
5 changes: 2 additions & 3 deletions src/data/anonymize.py
@@ -6,6 +6,7 @@
 import re
 
 import pandas as pd
+
 from utils import fs
 
 
@@ -86,9 +87,7 @@ def anonymize_insee_contact_message(message: str, message_ner: list[dict]) -> str:
         if dictionary["entity_group"] == "PER":
             message = message.replace(dictionary["word"], "[PER]")
         elif dictionary["signature"]:
-            message = message.replace(
-                dictionary["word"], f"[{dictionary['entity_group']}]"
-            )
+            message = message.replace(dictionary["word"], f"[{dictionary['entity_group']}]")
 
     # Identification of email addresses
     email_regex = r"[a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+"
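
The regex above feeds an e-mail scrubbing pass whose replacement token sits outside the shown hunk. A sketch assuming an [EMAIL] placeholder, by analogy with the [PER] tag used for person entities:

    import re

    email_regex = r"[a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+"
    message = "Contactez jean.dupont@example.fr pour plus de détails."
    # "[EMAIL]" is an assumed placeholder, not confirmed by the diff.
    print(re.sub(email_regex, "[EMAIL]", message))
    # Contactez [EMAIL] pour plus de détails.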
14 changes: 4 additions & 10 deletions src/data/create_evaluation_dataset.py
@@ -9,6 +9,7 @@
 
 import pandas as pd
 from constants import LS_ANNOTATIONS_PATH
+
 from utils import fs
 
 
@@ -28,10 +29,7 @@ def create_insee_contact_eval_dataset():
         entry_urls = []
         # Parse annotations
         for result_element in annotation_data["result"]:
-            if (
-                result_element["from_name"] == "keep_pair"
-                and "O" in result_element["value"]["choices"]
-            ):
+            if result_element["from_name"] == "keep_pair" and "O" in result_element["value"]["choices"]:
                 keep_pair = True
             elif result_element["from_name"] == "urls":
                 entry_urls += result_element["value"]["text"]
@@ -43,12 +41,8 @@ def create_insee_contact_eval_dataset():
         answers.append(answer)
         urls.append("|".join(entry_urls))
 
-    with fs.open(
-        "projet-llm-insee-open-data/data/eval_data/eval_dataset_insee_contact.csv", "w"
-    ) as f:
-        pd.DataFrame({"questions": questions, "answers": answers, "urls": urls}).to_csv(
-            f, index=False
-        )
+    with fs.open("projet-llm-insee-open-data/data/eval_data/eval_dataset_insee_contact.csv", "w") as f:
+        pd.DataFrame({"questions": questions, "answers": answers, "urls": urls}).to_csv(f, index=False)
 
 
 if __name__ == "__main__":
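
For reference, the shape of one Label Studio result element that the condensed condition accepts; the field values here are illustrative only:

    result_element = {"from_name": "keep_pair", "value": {"choices": ["O"]}}
    keep_pair = result_element["from_name"] == "keep_pair" and "O" in result_element["value"]["choices"]
    assert keep_pair  # this question/answer pair would be kept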
8 changes: 2 additions & 6 deletions src/data/insee_contact.py
@@ -12,9 +12,7 @@ def process_insee_contact_data(path: str):
     """
     Process raw Insee contact data.
     """
-    fs = s3fs.S3FileSystem(
-        client_kwargs={"endpoint_url": "https://" + os.environ["AWS_S3_ENDPOINT"]}
-    )
+    fs = s3fs.S3FileSystem(client_kwargs={"endpoint_url": "https://" + os.environ["AWS_S3_ENDPOINT"]})
 
     with fs.open(path) as f:
         df = pd.read_csv(f)
@@ -26,9 +24,7 @@ def process_insee_contact_data(path: str):
     df_eval = df.sample(200, random_state=42)
 
     # Save to s3
-    with fs.open(
-        "projet-llm-insee-open-data/data/insee_contact/data_2019_eval.csv", "w"
-    ) as f:
+    with fs.open("projet-llm-insee-open-data/data/insee_contact/data_2019_eval.csv", "w") as f:
         df_eval.to_csv(f, index=False)
 
 
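
A sketch of reading the evaluation sample back from S3. The path and the expected row count come from the code above; the endpoint environment variable is the same one used throughout these modules:

    import os

    import pandas as pd
    import s3fs

    fs = s3fs.S3FileSystem(client_kwargs={"endpoint_url": "https://" + os.environ["AWS_S3_ENDPOINT"]})
    with fs.open("projet-llm-insee-open-data/data/insee_contact/data_2019_eval.csv") as f:
        df_eval = pd.read_csv(f)
    assert len(df_eval) == 200  # df.sample(200, random_state=42) above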
80 changes: 0 additions & 80 deletions src/data/ner.py

This file was deleted.

5 changes: 2 additions & 3 deletions src/data/to_s3.py
@@ -10,6 +10,7 @@
 from anonymize import anonymize_insee_contact_message
 from constants import LS_DATA_PATH, RAW_DATA
 from ner import ner_series
+
 from utils import create_ls_task, fs
 
 
@@ -38,9 +39,7 @@ def insee_contact_to_s3():
         anonymized_answers.append(anonymize_insee_contact_message(message, ner))
 
     # Json tasks creation
-    for idx, (question, answer) in enumerate(
-        zip(anonymized_questions, anonymized_answers, strict=False)
-    ):
+    for idx, (question, answer) in enumerate(zip(anonymized_questions, anonymized_answers, strict=False)):
         ls_task = create_ls_task(question, answer)
         with fs.open(LS_DATA_PATH + f"{idx}.json", "w") as f:
             json.dump(ls_task, f)
4 changes: 1 addition & 3 deletions src/data/utils.py
@@ -6,9 +6,7 @@
 
 import s3fs
 
-fs = s3fs.S3FileSystem(
-    client_kwargs={"endpoint_url": "https://" + os.environ["AWS_S3_ENDPOINT"]}
-)
+fs = s3fs.S3FileSystem(client_kwargs={"endpoint_url": "https://" + os.environ["AWS_S3_ENDPOINT"]})
 
 
 def create_ls_task(