Skip to content

Commit

Permalink
Updated configurations
Browse files Browse the repository at this point in the history
  • Loading branch information
picaultj committed Dec 21, 2024
1 parent d705150 commit 004fac9
Show file tree
Hide file tree
Showing 14 changed files with 162 additions and 291 deletions.
75 changes: 66 additions & 9 deletions bertrend/BERTrend.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@
from sentence_transformers import SentenceTransformer

from bertrend import MODELS_DIR, CACHE_PATH
from bertrend.demos.weak_signals.messages import (
NO_GRANULARITY_WARNING,
)
from bertrend.demos.weak_signals.session_state_manager import SessionStateManager
from bertrend.topic_model import TopicModel
from bertrend.parameters import (
DEFAULT_MIN_SIMILARITY,
Expand Down Expand Up @@ -194,7 +198,6 @@ def train_topic_models(
grouped_data (Dict[pd.Timestamp, pd.DataFrame]): Dictionary of grouped data by timestamp.
embedding_model (SentenceTransformer): Sentence transformer model for embeddings.
embeddings (np.ndarray): Precomputed document embeddings.
granularity (int):
"""
# TODO from topic_modelling = train_topic_models (modulo data transformation)
# TODO rename to fit?
Expand Down Expand Up @@ -267,7 +270,6 @@ def merge_models(

# progress_bar = st.progress(0)
merge_df_size_over_time = []
# SessionStateManager.set("merge_df_size_over_time", [])

for i, (current_timestamp, next_timestamp) in enumerate(
zip(timestamps[:-1], timestamps[1:])
Expand Down Expand Up @@ -312,11 +314,7 @@ def merge_models(
merged_df_without_outliers["Topic"].max() + 1,
)
)
#
# SessionStateManager.update(
# "merge_df_size_over_time", merge_df_size_over_time
# )
#

# progress_bar.progress((i + 1) / len(timestamps))

all_merge_histories_df = pd.concat(all_merge_histories, ignore_index=True)
Expand All @@ -340,7 +338,7 @@ def calculate_signal_popularity(
Updates:
- topic_sizes (Dict[int, Dict[str, Any]]): Dictionary storing topic sizes and related information over time.
- topic_last_popularity (Dict[int, float]): Dictionary storing the last known popularity of each topic.
- topic_last_update (Dict[int, pd.Timestamp]]): Dictionary storing the last update timestamp of each topic.
- topic_last_update (Dict[int, pd.Timestamp]): Dictionary storing the last update timestamp of each topic.
Args:
all_merge_histories_df (pd.DataFrame): DataFrame containing all merge histories.
Expand Down Expand Up @@ -444,7 +442,7 @@ def save_models(self):
embedding_model = SessionStateManager.get("embedding_model")
topic_model.save(
model_dir,
serialization="safetensors",
serialization=BERTOPIC_SERIALIZATION,
save_ctfidf=False,
save_embedding_model=embedding_model,
)
Expand Down Expand Up @@ -482,6 +480,65 @@ def save_models(self):
pickle.dump(hyperparams, f)
"""

@classmethod
def restore_models(cls):
if not MODELS_DIR.exists():
raise FileNotFoundError(f"MODELS_DIR={MODELS_DIR} does not exist")

topic_models = {}
for period_dir in MODELS_DIR.iterdir():
if period_dir.is_dir():
topic_model = BERTopic.load(period_dir)

doc_info_df_file = period_dir / DOC_INFO_DF_FILE
topic_info_df_file = period_dir / TOPIC_INFO_DF_FILE
if doc_info_df_file.exists() and topic_info_df_file.exists():
topic_model.doc_info_df = pd.read_pickle(doc_info_df_file)
topic_model.topic_info_df = pd.read_pickle(topic_info_df_file)
else:
logger.warning(
f"doc_info_df or topic_info_df not found for period {period_dir.name}"
)

period = pd.Timestamp(period_dir.name.replace("_", ":"))
topic_models[period] = topic_model

SessionStateManager.set("topic_models", topic_models)

for file, key in [
(DOC_GROUPS_FILE, "doc_groups"),
(EMB_GROUPS_FILE, "emb_groups"),
]:
file_path = CACHE_PATH / file
if file_path.exists():
with open(file_path, "rb") as f:
SessionStateManager.set(key, pickle.load(f))
else:
logger.warning(f"{file} not found.")

granularity_file = CACHE_PATH / GRANULARITY_FILE
if granularity_file.exists():
with open(granularity_file, "rb") as f:
SessionStateManager.set("granularity_select", pickle.load(f))
else:
logger.warning(NO_GRANULARITY_WARNING)

# Restore the models_trained flag
models_trained_file = CACHE_PATH / MODELS_TRAINED_FILE
if models_trained_file.exists():
with open(models_trained_file, "rb") as f:
# FIXME! set bertrend first!
SessionStateManager.set("models_trained", pickle.load(f))
else:
logger.warning("Models trained flag not found.")

hyperparams_file = CACHE_PATH / HYPERPARAMS_FILE
if hyperparams_file.exists():
with open(hyperparams_file, "rb") as f:
SessionStateManager.set_multiple(**pickle.load(f))
else:
logger.warning("Hyperparameters file not found.")

#####################################################################################################
# FIXME: WIP
# def merge_models2(self):
Expand Down
2 changes: 2 additions & 0 deletions bertrend/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

# Read config
BERTREND_CONFIG = load_toml_config(BERTREND_CONFIG_PATH)
PARAMETERS_CONFIG = BERTREND_CONFIG["parameters"]
EMBEDDING_CONFIG = BERTREND_CONFIG["embedding_service"]
LLM_CONFIG = BERTREND_CONFIG["llm_service"]

Expand Down Expand Up @@ -48,4 +49,5 @@

# Create directories if they do not exist
DATA_PATH.mkdir(parents=True, exist_ok=True)
OUTPUT_PATH.mkdir(parents=True, exist_ok=True)
CACHE_PATH.mkdir(parents=True, exist_ok=True)
24 changes: 24 additions & 0 deletions bertrend/bertrend.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,27 @@
# This file is part of BERTrend.

# Configuration for BERTrend
[parameters]
# BERTopic Hyperparameters
default_umap_n_components = 5
default_umap_n_neighbors = 5
default_hdbscan_min_cluster_size = 5
default_hdbscan_min_samples = 5
default_top_n_words = 10
default_min_df = 1
default_granularity = 2
default_min_similarity = 0.7
default_zeroshot_min_similarity = 0.5
bertopic_serialization = "safetensors" # or pickle
default_mmr_diversity = 0.3
default_umap_min_dist = 0.0
outlier_reduction_strategy = "c-tf-idf" # or "embeddings"
# signal classification settings
signal_classif_lower_bound = 10
signal_classif_upper_bound = 75
# other constants
default_zeroshot_topics = "" # empty string or a default list of topics


[embedding_service]
# Indicates if BERTrend shall use local embedding service (run by BERTrend) or if the embedding service
Expand All @@ -25,3 +46,6 @@ port = 6464
api_key = "$AZURE_SE_WATTELSE_OPENAI_API_KEY_DEV"
endpoint = "$AZURE_SE_WATTELSE_OPENAI_ENDPOINT_DEV"
model = "$AZURE_SE_WATTELSE_OPENAI_DEFAULT_MODEL_NAME_DEV"
temperature = 0.1
max_tokens = 2048
system_prompt = "You are a helpful assistant, skilled in detailing topic evolution over time for the detection of emerging trends and signals."
3 changes: 2 additions & 1 deletion bertrend/demos/topic_analysis/Main_page.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
)

from bertrend.metrics.topic_metrics import get_coherence_value, get_diversity_value
from bertrend.parameters import BERTOPIC_SERIALIZATION
from bertrend.train import train_BERTopic
from bertrend.utils.data_loading import (
split_df_by_paragraphs,
Expand Down Expand Up @@ -120,7 +121,7 @@ def save_model_interface():
try:
st.session_state["topic_model"].save(
model_save_path,
serialization="safetensors",
serialization=BERTOPIC_SERIALIZATION,
save_ctfidf=True,
save_embedding_model=True,
)
Expand Down
69 changes: 6 additions & 63 deletions bertrend/demos/weak_signals/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,10 @@
import numpy as np
import pandas as pd
import plotly.graph_objects as go
from bertopic import BERTopic
from loguru import logger

from bertrend import (
DATA_PATH,
MODELS_DIR,
ZEROSHOT_TOPICS_DATA_DIR,
SIGNAL_EVOLUTION_DATA_DIR,
CACHE_PATH,
Expand All @@ -34,9 +32,8 @@
STATE_SAVED_MESSAGE,
STATE_RESTORED_MESSAGE,
MODELS_SAVED_MESSAGE,
NO_MODELS_WARNING,
NO_GRANULARITY_WARNING,
NO_DATASET_WARNING,
NO_MODELS_WARNING,
)
from bertrend.trend_analysis.weak_signals import detect_weak_signals_zeroshot

Expand Down Expand Up @@ -121,63 +118,6 @@ def restore_state():
st.warning("No saved state found.")


def restore_models():
if not MODELS_DIR.exists():
st.warning(NO_MODELS_WARNING)
return

topic_models = {}
for period_dir in MODELS_DIR.iterdir():
if period_dir.is_dir():
topic_model = BERTopic.load(period_dir)

doc_info_df_file = period_dir / DOC_INFO_DF_FILE
topic_info_df_file = period_dir / TOPIC_INFO_DF_FILE
if doc_info_df_file.exists() and topic_info_df_file.exists():
topic_model.doc_info_df = pd.read_pickle(doc_info_df_file)
topic_model.topic_info_df = pd.read_pickle(topic_info_df_file)
else:
logger.warning(
f"doc_info_df or topic_info_df not found for period {period_dir.name}"
)

period = pd.Timestamp(period_dir.name.replace("_", ":"))
topic_models[period] = topic_model

SessionStateManager.set("topic_models", topic_models)

for file, key in [(DOC_GROUPS_FILE, "doc_groups"), (EMB_GROUPS_FILE, "emb_groups")]:
file_path = CACHE_PATH / file
if file_path.exists():
with open(file_path, "rb") as f:
SessionStateManager.set(key, pickle.load(f))
else:
logger.warning(f"{file} not found.")

granularity_file = CACHE_PATH / GRANULARITY_FILE
if granularity_file.exists():
with open(granularity_file, "rb") as f:
SessionStateManager.set("granularity_select", pickle.load(f))
else:
logger.warning(NO_GRANULARITY_WARNING)

# Restore the models_trained flag
models_trained_file = CACHE_PATH / MODELS_TRAINED_FILE
if models_trained_file.exists():
with open(models_trained_file, "rb") as f:
# FIXME! set bertrend first!
SessionStateManager.set("models_trained", pickle.load(f))
else:
logger.warning("Models trained flag not found.")

hyperparams_file = CACHE_PATH / HYPERPARAMS_FILE
if hyperparams_file.exists():
with open(hyperparams_file, "rb") as f:
SessionStateManager.set_multiple(**pickle.load(f))
else:
logger.warning("Hyperparameters file not found.")


def purge_cache():
if CACHE_PATH.exists():
shutil.rmtree(CACHE_PATH)
Expand Down Expand Up @@ -206,8 +146,11 @@ def main():

if st.button("Restore Previous Run", use_container_width=True):
restore_state()
restore_models()
st.success(MODELS_RESTORED_MESSAGE)
try:
BERTrend.restore_models()
st.success(MODELS_RESTORED_MESSAGE)
except Exception as e:
st.warning(NO_MODELS_WARNING)

if st.button("Purge Cache", use_container_width=True):
purge_cache()
Expand Down
1 change: 1 addition & 0 deletions bertrend/demos/weak_signals/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,4 @@
)
NO_GRANULARITY_WARNING = "Granularity value not found."
NO_DATASET_WARNING = "Please select at least one dataset to proceed."
HTML_GENERATION_FAILED_WARNING = "HTML generation failed. Displaying markdown instead."
10 changes: 6 additions & 4 deletions bertrend/demos/weak_signals/visualizations_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
from pandas import Timestamp
from plotly import graph_objects as go

from bertrend import OUTPUT_PATH
from bertrend.demos.weak_signals.messages import HTML_GENERATION_FAILED_WARNING
from bertrend.demos.weak_signals.session_state_manager import SessionStateManager
from bertrend.parameters import MAX_WINDOW_SIZE, DEFAULT_WINDOW_SIZE
from bertrend.trend_analysis.visualizations import (
Expand Down Expand Up @@ -144,6 +146,7 @@ def display_popularity_evolution():
(with the smallest possible value being the earliest timestamp in the provided data).
The latest selectable date corresponds to the most recent topic merges, which is at most equal
to the latest timestamp in the data minus the provided granularity.""",
key="current_date",
)

granularity = SessionStateManager.get("granularity_select")
Expand Down Expand Up @@ -272,7 +275,7 @@ def display_topics_per_timestamp(topic_models: Dict[pd.Timestamp, BERTopic]) ->
st.dataframe(selected_model.topic_info_df, use_container_width=True)


def display_signal_analysis(topic_number):
def display_signal_analysis(topic_number, output_file="signal_llm.html"):
language = SessionStateManager.get("language")
bertrend = SessionStateManager.get("bertrend")
all_merge_histories_df = bertrend.all_merge_histories_df
Expand All @@ -288,16 +291,15 @@ def display_signal_analysis(topic_number):
)

# Check if the HTML file was created successfully
# FIXME: output path
output_file_path = Path(__file__).parent / "signal_llm.html"
output_file_path = OUTPUT_PATH / output_file
if output_file_path.exists():
# Read the HTML file
with open(output_file_path, "r", encoding="utf-8") as file:
html_content = file.read()
# Display the HTML content
st.html(html_content)
else:
st.warning("HTML generation failed. Displaying markdown instead.")
st.warning(HTML_GENERATION_FAILED_WARNING)
# Fallback to displaying markdown if HTML generation fails
col1, col2 = st.columns(spec=[0.5, 0.5], gap="medium")
with col1:
Expand Down
4 changes: 2 additions & 2 deletions bertrend/metrics/temporal_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,9 +138,9 @@ def __init__(
self.tts_scores_df = None
self.ttc_scores_df = None

def _topics_over_time(self) -> pd.DataFrame:
def _topics_over_time(self):
"""
Calculates and returns a DataFrame containing topics over time with their respective words and frequencies.
Calculates and sets as a property a DataFrame containing topics over time with their respective words and frequencies.
Returns:
- pd.DataFrame: Topics, their top words, frequencies, and timestamps.
Expand Down
Loading

0 comments on commit 004fac9

Please sign in to comment.