
Commit 377c80f

Finalized refactoring.
gferey committed Nov 21, 2024
1 parent 9b6992e commit 377c80f
Showing 15 changed files with 117 additions and 111 deletions.
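
The thread running through the commit: replace untyped, mapping-style configuration access with a single typed config object. The recurring pattern, assembled from the hunks below (FullConfig, DefaultFullConfig and process_args all come from src/config):

# Before: raw mapping loaded from a config-file section
config = load_config()["DEFAULT"]
model_name = config["llm_model"]

# After: CLI/env arguments processed once, then one typed config object
process_args()
config: FullConfig = DefaultFullConfig()
model_name = config.llm_model  # validated, attribute-style access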
19 changes: 7 additions & 12 deletions app-minimal.py
@@ -8,16 +8,17 @@

 from src.chain_building.build_chain import build_chain
 from src.chain_building.build_chain_validator import build_chain_validator
-from src.config import load_config
+from src.config import DefaultFullConfig, FullConfig, process_args
 from src.db_building import load_retriever, load_vector_database
 from src.model_building import build_llm_model
 from src.results_logging.log_conversations import log_qa_to_s3
 from src.utils.formatting_utilities import add_sources_to_messages, get_chatbot_template, str_to_bool

 # Configuration and initial setup
-config = load_config()["DEFAULT"]
+process_args()
+config: FullConfig = DefaultFullConfig()
 logger = logging.getLogger(__name__)
-fs = s3fs.S3FileSystem(endpoint_url=config["s3_endpoint_url"])
+fs = s3fs.S3FileSystem(endpoint_url=config.s3_endpoint_url)

# APPLICATION -----------------------------------------

@@ -53,7 +54,7 @@ async def on_chat_start():

     db = await cl.make_async(load_vector_database)(filesystem=fs, config=config)
     llm, tokenizer = await cl.make_async(build_llm_model)(
-        model_name=config["llm_model"],
+        model_name=config.llm_model,
         streaming=False,
         config=config,
     )
@@ -79,11 +80,7 @@ async def on_chat_start():

     # Build chain
     chain = build_chain(
-        retriever=retriever,
-        prompt=prompt,
-        llm=llm,
-        bool_log=IS_LOGGING_ON,
-        reranker=config.get("RERANKING_METHOD") or None,
+        retriever=retriever, prompt=prompt, llm=llm, bool_log=IS_LOGGING_ON, reranker=config.reranking_method or None
     )
     cl.user_session.set("chain", chain)
     logger.info("------chain built")
@@ -138,9 +135,7 @@ async def on_message(message: cl.Message):
             user_query=message.content,
             generated_answer=None if cl.user_session.get("RETRIEVER_ONLY") else generated_answer,
             retrieved_documents=docs,
-            embedding_model_name=config["emb_model"],
-            LLM_name=None if cl.user_session.get("RETRIEVER_ONLY") else config["LLM_MODEL_NAME"],
-            reranker=config.get("RERANKING_METHOD"),
+            LLM_name=None if cl.user_session.get("RETRIEVER_ONLY") else config.llm_model,
             config=config,
         )
     else:
15 changes: 5 additions & 10 deletions app2.py
@@ -19,12 +19,13 @@

 # Logging, configuration and S3
 args = process_args()
+config = DefaultFullConfig()
 logger = logging.getLogger(__name__)
-fs = s3fs.S3FileSystem(endpoint_url=DefaultFullConfig().s3_endpoint_url)
+fs = s3fs.S3FileSystem(endpoint_url=config.s3_endpoint_url)

 # PARAMETERS --------------------------------------

-CLI_MESSAGE_SEPARATOR = (DefaultFullConfig().cli_message_separator_length * "-") + " \n"
+CLI_MESSAGE_SEPARATOR = (config.cli_message_separator_length * "-") + " \n"

 # APPLICATION -----------------------------------------

@@ -35,7 +36,7 @@ def retrieve_model_tokenizer_and_db(filesystem=fs, with_db=True) -> tuple[Huggin

     # Load LLM in session
     llm, tokenizer = build_llm_model(
-        model_name=DefaultFullConfig().llm_model,
+        model_name=config.llm_model,
         streaming=False,
     )
     return (
@@ -171,20 +172,14 @@ async def on_message(message: cl.Message):

     # Log Q/A
     if cl.user_session.get("IS_LOGGING_ON"):
-        embedding_model_name = os.getenv("EMB_MODEL_NAME")
-        LLM_name = os.getenv("LLM_MODEL_NAME")
-        reranker = os.getenv("RERANKING_METHOD")
-
         log_qa_to_s3(
             filesystem=fs,
             thread_id=message.thread_id,
             message_id=sources_msg.id,
             user_query=message.content,
             generated_answer=(None if cl.user_session.get("RETRIEVER_ONLY") else generated_answer),
             retrieved_documents=docs,
-            embedding_model_name=embedding_model_name,
-            LLM_name=None if cl.user_session.get("RETRIEVER_ONLY") else LLM_name,
-            reranker=reranker,
+            llm_name=None if cl.user_session.get("RETRIEVER_ONLY") else config.llm_model,
         )
     else:
         await cl.Message(
8 changes: 4 additions & 4 deletions notebooks/retrieval_eval_exp_data_size.ipynb
@@ -211,7 +211,7 @@
" elif mode==\"top\":\n",
" sample_langchain_docs = get_top_n_documents_with_largest_content(langchain_docs, n=top_n) \n",
"\n",
" autokenizer, chunk_size, chunk_overlap = compute_autokonenizer_chunk_size(config.get(\"embedding_model_name\"))\n",
" autokenizer, chunk_size, chunk_overlap = compute_autokonenizer_chunk_size(config.emb_model)\n",
" \n",
" text_splitter = RecursiveCharacterTextSplitter.from_huggingface_tokenizer(\n",
" autokenizer,\n",
@@ -230,8 +230,8 @@
" docs_processed_unique.append(doc)\n",
"\n",
" embedding_model = HuggingFaceEmbeddings( # load from sentence transformers\n",
" model_name=config.get(\"embedding_model_name\"),\n",
" model_kwargs={\"device\": EMB_DEVICE, \"trust_remote_code\": True},\n",
" model_name=config.emb_model,\n",
" model_kwargs={\"device\": config.emb_device, \"trust_remote_code\": True},\n",
" encode_kwargs={\"normalize_embeddings\": True}, # set True for cosine similarity\n",
" show_progress=False,\n",
" )\n",
@@ -242,7 +242,7 @@
" for i in range(0, len(docs_processed_unique), max_batch_size):\n",
" batch_docs = docs_processed_unique[i:i + max_batch_size]\n",
" db = Chroma.from_documents(\n",
" collection_name=config.get(\"collection\"),\n",
" collection_name=config.collection_name,\n",
" documents=batch_docs,\n",
" persist_directory=DB_DIR_LOCAL,\n",
" embedding=embedding_model,\n",
2 changes: 1 addition & 1 deletion run_build_database.py
@@ -36,7 +36,7 @@ def run_build_database(config: FullConfig = DefaultFullConfig()) -> None:
     df, all_splits = build_or_load_document_database(filesystem, config)

     # Try to simply load the vector database
-    db = load_vector_database(filesystem, config)
+    db = None if config.force_rebuild else load_vector_database(filesystem, config)
     if db is None:
         # If no cached database found: rebuild from documents
         db = build_vector_database(
18 changes: 8 additions & 10 deletions run_evaluation.py
@@ -98,11 +98,9 @@ def run_evaluation(filesystem: s3fs.S3FileSystem, config: FullConfig = DefaultFu
     # ------------------------
     # IV - RERANKER

-    reranking_method = config.get("reranking_method")
-
-    if reranking_method is not None:
+    if config.reranking_method is not None:
         logger.info(f"Applying reranking {80*'='}")
-        logger.info(f"Selected method: {reranking_method}")
+        logger.info(f"Selected method: {config.reranking_method}")

         # Define a langchain prompt template
         RAG_PROMPT_TEMPLATE_RERANKER = tokenizer.apply_chat_template(
@@ -116,7 +114,7 @@ def run_evaluation(filesystem: s3fs.S3FileSystem, config: FullConfig = DefaultFu
             retriever=retriever,
             prompt=prompt,
             llm=llm,
-            reranker=reranking_method,
+            reranker=config.reranking_method,
         )
     else:
         logger.info(f"Skipping reranking since value is None {80*'='}")
@@ -126,9 +124,9 @@ def run_evaluation(filesystem: s3fs.S3FileSystem, config: FullConfig = DefaultFu

logger.info(f"Evaluating model performance against expectations {80*'='}")

if reranking_method is None:
if config.reranking_method is None:
answers_bot = answer_faq_by_bot(retriever, faq)
eval_reponses_bot, answers_bot_topk = transform_answers_bot(answers_bot, k=int(config.get("topk_stats")))
eval_reponses_bot, answers_bot_topk = transform_answers_bot(answers_bot, k=config.topk_stats)
else:
answers_bot_before_reranker = answer_faq_by_bot(retriever, faq)
eval_reponses_bot_before_reranker, answers_bot_topk_before_reranker = transform_answers_bot(
@@ -147,12 +145,12 @@ def run_evaluation(filesystem: s3fs.S3FileSystem, config: FullConfig = DefaultFu
     document_among_topk = answers_bot_topk["cumsum_url_expected"].max()
     document_is_top = answers_bot_topk["cumsum_url_expected"].min()
     # Also compute model performance before reranking when relevant
-    if reranking_method is not None:
+    if config.reranking_method is not None:
         document_among_topk_before_reranker = answers_bot_topk_before_reranker["cumsum_url_expected"].max()
         document_is_top_before_reranker = answers_bot_topk_before_reranker["cumsum_url_expected"].min()

     # Store FAQ
-    mlflow_faq_raw = mlflow.data.from_pandas(faq, source=config["faq_s3_uri"], name="FAQ_data")
+    mlflow_faq_raw = mlflow.data.from_pandas(faq, source=config.faq_s3_uri, name="FAQ_data")
     mlflow.log_input(mlflow_faq_raw, context="faq-raw")
     mlflow.log_table(data=faq, artifact_file="faq_data.json")

@@ -168,7 +166,7 @@ def run_evaluation(filesystem: s3fs.S3FileSystem, config: FullConfig = DefaultFu
     mlflow.log_table(data=eval_reponses_bot, artifact_file="output/eval_reponses_bot.json")

     # If we used reranking, we also store performance before reranking
-    if reranking_method is not None:
+    if config.reranking_method is not None:
         mlflow.log_metric("document_is_first_before_reranker", 100 * document_is_top_before_reranker)
         mlflow.log_metric("document_among_topk_before_reranker", 100 * document_among_topk_before_reranker)
         mlflow.log_metrics(
64 changes: 38 additions & 26 deletions src/config/config.py
@@ -1,6 +1,7 @@
 import inspect
 import os
 from dataclasses import dataclass
+from functools import wraps

 import mlflow
 import toml
@@ -111,45 +112,56 @@ def custom_config(defaults: dict | None = None, overrides: dict | None = None):

 class Configurable:
     """
-    Decorator for function with a special "configuration" argument.
+    Decorator for functions with a special "configuration" argument.
     The configuration argument must be a keyword (named) argument (after the / special argument)
     """

-    # The decorator is initialised with the configuration argument name
-    def __init__(self, config_param: str = "config"):
-        self.config_param = config_param
+    def __init__(self, config_arg_name: str = "config"):
+        """Decorator initialised with the configuration argument name"""
+        self.config_arg_name = config_arg_name

     def __call__(self, f):
+        sig = inspect.signature(f)
         # The original arguments from the annotated function
-        declared_parameters = inspect.signature(f).parameters
+        declared_parameters = sig.parameters
         # The configuration argument's annotation (required) and default value (optional)
-        config_default = declared_parameters[self.config_param].default
-        config_class = declared_parameters[self.config_param].annotation
+        config_default = declared_parameters[self.config_arg_name].default
+        config_class = declared_parameters[self.config_arg_name].annotation
         # All fields from the config's class
         config_parameters = config_class.model_fields.keys()
         # All config fields without the ones already explicitly specified in the decorated function's arguments
         overridable_params = set(config_parameters) - set(declared_parameters.keys())

-        # The returned function
-        def new_f(*args, **kwargs):
-            # Get original config argument from kwargs (and remove it) or use default
-            orig_config = kwargs.pop(self.config_param, config_default)
-            # TODO: do not hcange anything if no special overriding argument were provided!
-
-            # Dict with the original config overriden with keyword args (which are removed from the kwargs dict)
-            new_config_params = {
-                k: kwargs.pop(k) if k in kwargs and k in overridable_params else getattr(orig_config, k)
-                for k in config_parameters
-            }
-            # Updated config object buily by passing overriden parameters to the config_class constructor
-            # NOTE: this only works if the class annotation of config can be built this way (e.g. pydantic BaseModel)
-            new_config = config_class(**new_config_params)
-            # Set the config_param to the updated object
-            kwargs[self.config_param] = new_config
-            # Call the original function with the updated config object
-            return f(*args, **kwargs)
-
-        return new_f
+        @wraps(f)
+        def wrapped_f(*args, **kwargs):
+            # Map args and kwargs to a single dict of named variables
+            ba = sig.bind(*args, **kwargs)
+            override_args = [k for k in ba.arguments if k in overridable_params]
+            if not override_args:
+                # If no overriding argument is provided, simply call f
+                return f(*args, **kwargs)
+            else:
+                # Get original config argument from kwargs (and remove it) or use default
+                orig_config = ba.arguments.get(self.config_arg_name, config_default)

+                # Dict with the original config overridden with keyword args (which are removed from the kwargs dict)
+                new_config_params = {
+                    k: ba.arguments[k] if k in override_args else getattr(orig_config, k) for k in config_parameters
+                }
+                for k in override_args:
+                    del ba.arguments[k]
+                # Updated config object built by passing overridden parameters to the config_class constructor
+                # NOTE: this only works if the class annotation of config can be built this way
+                # (e.g. pydantic BaseModel)
+                new_config = config_class(**new_config_params)
+                # Set the config argument to the updated object
+                ba.arguments[self.config_arg_name] = new_config
+                # Call the original function with the updated config object
+                return f(*ba.args, **ba.kwargs)

+        return wrapped_f


# Example:
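
The "# Example:" block above is truncated in this view. As a stand-in, a minimal self-contained sketch of the same override pattern, using a plain dataclass rather than the repo's pydantic-based FullConfig (all names below are illustrative, not the project's API):

from dataclasses import dataclass, replace

@dataclass
class Config:
    llm_model: str = "default-model"
    topk_stats: int = 3

def configurable(f):
    """Simplified stand-in for Configurable: config fields passed as extra
    keyword arguments rebuild the config object before calling f."""
    fields = Config.__dataclass_fields__.keys()
    def wrapper(*args, config: Config = Config(), **kwargs):
        # Pull out any keyword args that name a config field...
        overrides = {k: kwargs.pop(k) for k in list(kwargs) if k in fields}
        # ...and hand f a copy of the config with those fields replaced
        return f(*args, config=replace(config, **overrides), **kwargs)
    return wrapper

@configurable
def run(prompt: str, *, config: Config):
    return f"{prompt} -> {config.llm_model}"

run("hello")                        # default config, untouched
run("hello", llm_model="my-model")  # config rebuilt with the override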
1 change: 1 addition & 0 deletions src/config/default.toml
@@ -73,6 +73,7 @@ max_new_tokens = 2000

 faq_s3_path = "data/FAQ_site/faq.parquet"
 faq_s3_uri = "s3://{s3_bucket}/{faq_s3_path}" # (templated)
+topk_stats = 3

# INSTRUCTION PROMPT ----------------------------------------------------------

4 changes: 3 additions & 1 deletion src/config/models.py
@@ -73,6 +73,7 @@ class FullConfig(BaseConfig, metaclass=BaseConfigMetaclass):
     # EVALUATION
     faq_s3_path: str
     faq_s3_uri: str # (Templated)
+    topk_stats: int

# INSTRUCTION PROMPT
BASIC_RAG_PROMPT_TEMPLATE: str
@@ -98,7 +99,8 @@ class FullConfig(BaseConfig, metaclass=BaseConfigMetaclass):
     reranking_method: str | None = None
     retriever_only: bool | None = None

-    # Allow the 'None' string to represent None values for all optional parameters (MLFlow import requires this)
+    # Allow the 'None' string to represent None values for all optional parameters
+    # (importing from MLFlow requires this)
     @validator("chunk_size", "chunk_overlap", "max_pages", "reranking_method", "retriever_only", pre=True)
     def allow_none(cls, data: Any) -> int | None:
         return None if data == "None" else data
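
For context, a standalone sketch of what this pre-validator does (pydantic v1-style, matching the validator import used here; the "BM25" value is purely illustrative):

from pydantic import BaseModel, validator

class Example(BaseModel):
    reranking_method: str | None = None

    # Same idea as FullConfig.allow_none: parameters round-tripped through
    # MLflow come back as the string "None"; map that back to a real None
    @validator("reranking_method", pre=True)
    def allow_none(cls, v):
        return None if v == "None" else v

assert Example(reranking_method="None").reranking_method is None
assert Example(reranking_method="BM25").reranking_method == "BM25"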
4 changes: 2 additions & 2 deletions src/db_building/build_database.py
@@ -95,8 +95,8 @@ def build_vector_database(
     # Loop through the chunks and build the Chroma database
     try:
         db = Chroma(
-            collection_name=config["collection_name"],
-            persist_directory=config["chroma_db_local_path"],
+            collection_name=config.collection_name,
+            persist_directory=config.chroma_db_local_path,
             embedding_function=emb_model,
             client_settings=Settings(anonymized_telemetry=False, is_persistent=True),
         )
