diff --git a/bertrend/BERTrend.py b/bertrend/BERTrend.py index 665a674..79e0c10 100644 --- a/bertrend/BERTrend.py +++ b/bertrend/BERTrend.py @@ -5,6 +5,7 @@ import pickle import shutil from collections import defaultdict +from pathlib import Path from typing import Dict, Tuple, List, Any import numpy as np @@ -14,10 +15,7 @@ from sentence_transformers import SentenceTransformer from bertrend import MODELS_DIR, CACHE_PATH -from bertrend.demos.weak_signals.messages import ( - NO_GRANULARITY_WARNING, -) -from bertrend.demos.weak_signals.session_state_manager import SessionStateManager + from bertrend.topic_model import TopicModel from bertrend.parameters import ( DEFAULT_MIN_SIMILARITY, @@ -29,6 +27,7 @@ EMB_GROUPS_FILE, GRANULARITY_FILE, HYPERPARAMS_FILE, + BERTOPIC_SERIALIZATION, ) from bertrend.trend_analysis.topic_modeling import preprocess_model, merge_models from bertrend.trend_analysis.weak_signals import ( @@ -62,6 +61,7 @@ def __init__( ) self.zeroshot_topic_list = zeroshot_topic_list self.zeroshot_min_similarity = zeroshot_min_similarity + self.granularity = DEFAULT_GRANULARITY # State variables of BERTrend self._is_fitted = False @@ -429,17 +429,16 @@ def calculate_signal_popularity( self.topic_last_popularity = topic_last_popularity self.topic_last_update = topic_last_update - def save_models(self): - if MODELS_DIR.exists(): - shutil.rmtree(MODELS_DIR) - MODELS_DIR.mkdir(parents=True, exist_ok=True) + def save_models(self, models_path: Path = MODELS_DIR): + if models_path.exists(): + shutil.rmtree(models_path) + models_path.mkdir(parents=True, exist_ok=True) - # TODO - """ + # Save topic models using the selected serialization type for period, topic_model in self.topic_models.items(): - model_dir = MODELS_DIR / period.strftime("%Y-%m-%d") + model_dir = models_path / period.strftime("%Y-%m-%d") model_dir.mkdir(exist_ok=True) - embedding_model = SessionStateManager.get("embedding_model") + embedding_model = topic_model.embedding_model topic_model.save( model_dir, serialization=BERTOPIC_SERIALIZATION, @@ -449,44 +448,54 @@ def save_models(self): topic_model.doc_info_df.to_pickle(model_dir / DOC_INFO_DF_FILE) topic_model.topic_info_df.to_pickle(model_dir / TOPIC_INFO_DF_FILE) - """ + # Save topic model parameters + with open(CACHE_PATH / HYPERPARAMS_FILE, "wb") as f: + pickle.dump(self.topic_model_parameters, f) + # Save granularity file + with open(CACHE_PATH / GRANULARITY_FILE, "wb") as f: + pickle.dump(self.granularity, f) + # Save doc_groups file with open(CACHE_PATH / DOC_GROUPS_FILE, "wb") as f: pickle.dump(self.doc_groups, f) + # Save emb_groups file with open(CACHE_PATH / EMB_GROUPS_FILE, "wb") as f: pickle.dump(self.emb_groups, f) - - # FIXME: granularity currently not set at this stage - # with open(CACHE_PATH / GRANULARITY_FILE, "wb") as f: - # pickle.dump(self.granularity) - # Save the models_trained flag with open(CACHE_PATH / MODELS_TRAINED_FILE, "wb") as f: pickle.dump(self._is_fitted, f) - # TODO! 
- """ - hyperparams = SessionStateManager.get_multiple( - "umap_n_components", - "umap_n_neighbors", - "hdbscan_min_cluster_size", - "hdbscan_min_samples", - "hdbscan_cluster_selection_method", - "top_n_words", - "vectorizer_ngram_range", - "min_df", - ) - with open(CACHE_PATH / HYPERPARAMS_FILE, "wb") as f: - pickle.dump(hyperparams, f) - """ + logger.info(f"Models saved to: {models_path}") @classmethod - def restore_models(cls): - if not MODELS_DIR.exists(): - raise FileNotFoundError(f"MODELS_DIR={MODELS_DIR} does not exist") - + def restore_models(cls, models_path: Path = MODELS_DIR) -> "BERTrend": + if not models_path.exists(): + raise FileNotFoundError(f"models_path={models_path} does not exist") + + logger.info(f"Loading models from: {models_path}") + + # Create BERTrend object + bertrend = cls() + + # load topic model parameters + with open(CACHE_PATH / HYPERPARAMS_FILE, "rb") as f: + bertrend.topic_model_parameters = pickle.load(f) + # load granularity file + with open(CACHE_PATH / GRANULARITY_FILE, "rb") as f: + bertrend.granularity = pickle.load(f) + # load doc_groups file + with open(CACHE_PATH / DOC_GROUPS_FILE, "rb") as f: + bertrend.doc_groups = pickle.load(f) + # load emb_groups file + with open(CACHE_PATH / EMB_GROUPS_FILE, "rb") as f: + bertrend.emb_groups = pickle.load(f) + # load the models_trained flag + with open(CACHE_PATH / MODELS_TRAINED_FILE, "rb") as f: + bertrend._is_fitted = pickle.load(f) + + # Restore topic models using the selected serialization type topic_models = {} - for period_dir in MODELS_DIR.iterdir(): + for period_dir in models_path.iterdir(): if period_dir.is_dir(): topic_model = BERTopic.load(period_dir) @@ -502,72 +511,9 @@ def restore_models(cls): period = pd.Timestamp(period_dir.name.replace("_", ":")) topic_models[period] = topic_model + bertrend.topic_models = topic_models - SessionStateManager.set("topic_models", topic_models) + return bertrend - for file, key in [ - (DOC_GROUPS_FILE, "doc_groups"), - (EMB_GROUPS_FILE, "emb_groups"), - ]: - file_path = CACHE_PATH / file - if file_path.exists(): - with open(file_path, "rb") as f: - SessionStateManager.set(key, pickle.load(f)) - else: - logger.warning(f"{file} not found.") - - granularity_file = CACHE_PATH / GRANULARITY_FILE - if granularity_file.exists(): - with open(granularity_file, "rb") as f: - SessionStateManager.set("granularity_select", pickle.load(f)) - else: - logger.warning(NO_GRANULARITY_WARNING) - - # Restore the models_trained flag - models_trained_file = CACHE_PATH / MODELS_TRAINED_FILE - if models_trained_file.exists(): - with open(models_trained_file, "rb") as f: - # FIXME! set bertrend first! 
- SessionStateManager.set("models_trained", pickle.load(f)) - else: - logger.warning("Models trained flag not found.") - - hyperparams_file = CACHE_PATH / HYPERPARAMS_FILE - if hyperparams_file.exists(): - with open(hyperparams_file, "rb") as f: - SessionStateManager.set_multiple(**pickle.load(f)) - else: - logger.warning("Hyperparameters file not found.") - - ##################################################################################################### - # FIXME: WIP - # def merge_models2(self): - # if not self._is_fitted: - # raise RuntimeError("You must fit the BERTrend model before merging models.") - # - # merged_data = self._initialize_merge_data() - # - # logger.info("Merging models...") - # for timestamp, model in self.topic_models.items(): - # if not merged_data: - # merged_data = self._process_first_model(model) - # else: - # merged_data = self._merge_with_existing_data( - # merged_data, model, timestamp - # ) - # - # self.merged_topics = merged_data - # - # def _merge_with_existing_data( - # self, merged_data: Dict, model: BERTopic, timestamp: pd.Timestamp - # ) -> Dict: - # # Extract topics and embeddings - # - # # Compute similarity between current model's topics and the merged ones - # - # # Update merged_data with this model's data based on computed similarities - # # Implement business logic to handle merging decisions - # # This can involve thresholding, updating topic IDs, and merging document and metadata entries - # - # # return merged_data # Return the updated merged data - # pass + +# TODO: methods for prospective analysis (handle topic generation step by step) diff --git a/bertrend/demos/demos_utils/__init__.py b/bertrend/demos/demos_utils/__init__.py new file mode 100644 index 0000000..ae6e745 --- /dev/null +++ b/bertrend/demos/demos_utils/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) 2024, RTE (https://www.rte-france.com) +# See AUTHORS.txt +# SPDX-License-Identifier: MPL-2.0 +# This file is part of BERTrend. diff --git a/bertrend/demos/demos_utils/data_loading_component.py b/bertrend/demos/demos_utils/data_loading_component.py new file mode 100644 index 0000000..00b3280 --- /dev/null +++ b/bertrend/demos/demos_utils/data_loading_component.py @@ -0,0 +1,109 @@ +# Copyright (c) 2024, RTE (https://www.rte-france.com) +# See AUTHORS.txt +# SPDX-License-Identifier: MPL-2.0 +# This file is part of BERTrend. +import pandas as pd +import streamlit as st + +from bertrend import DATA_PATH +from bertrend.demos.demos_utils.session_state_manager import SessionStateManager +from bertrend.parameters import MIN_CHARS_DEFAULT, SAMPLE_SIZE_DEFAULT +from bertrend.utils.data_loading import ( + find_compatible_files, + load_and_preprocess_data, + TEXT_COLUMN, +) + +NO_DATASET_WARNING = "Please select at least one dataset to proceed." 
+ + +def display_data_loading_component(): + # Find files in the current directory and subdirectories + compatible_extensions = ["csv", "parquet", "json", "jsonl"] + selected_files = st.multiselect( + "Select one or more datasets", + find_compatible_files(DATA_PATH, compatible_extensions), + default=SessionStateManager.get("selected_files", []), + key="selected_files", + ) + + if not selected_files: + st.warning(NO_DATASET_WARNING) + return + + # Display number input and checkbox for preprocessing options + col1, col2 = st.columns(2) + with col1: + min_chars = st.number_input( + "Minimum Characters", + value=MIN_CHARS_DEFAULT, + min_value=0, + max_value=1000, + key="min_chars", + ) + with col2: + split_by_paragraph = st.checkbox( + "Split text by paragraphs", value=False, key="split_by_paragraph" + ) + + # Load and preprocess each selected file, then concatenate them + dfs = [] + for selected_file, ext in selected_files: + file_path = DATA_PATH / selected_file + df = load_and_preprocess_data( + (file_path, ext), + st.session_state["language"], + min_chars, + split_by_paragraph, + ) + dfs.append(df) + + if not dfs: + st.warning( + "No data available after preprocessing. Please check the selected files and preprocessing options." + ) + else: + df = pd.concat(dfs, ignore_index=True) + + # Deduplicate using all columns + df = df.drop_duplicates() + + # Select timeframe + min_date, max_date = df["timestamp"].dt.date.agg(["min", "max"]) + start_date, end_date = st.slider( + "Select Timeframe", + min_value=min_date, + max_value=max_date, + value=(min_date, max_date), + key="timeframe_slider", + ) + + # Filter and sample the DataFrame + df_filtered = df[ + (df["timestamp"].dt.date >= start_date) + & (df["timestamp"].dt.date <= end_date) + ] + df_filtered = df_filtered.sort_values(by="timestamp").reset_index(drop=True) + + sample_size = st.number_input( + "Sample Size", + value=SAMPLE_SIZE_DEFAULT or len(df_filtered), + min_value=1, + max_value=len(df_filtered), + key="sample_size", + ) + if sample_size < len(df_filtered): + df_filtered = df_filtered.sample(n=sample_size, random_state=42) + + df_filtered = df_filtered.sort_values(by="timestamp").reset_index(drop=True) + + SessionStateManager.set("timefiltered_df", df_filtered) + st.write( + f"Number of documents in selected timeframe: {len(SessionStateManager.get_dataframe('timefiltered_df'))}" + ) + st.dataframe( + SessionStateManager.get_dataframe("timefiltered_df")[ + [TEXT_COLUMN, "timestamp"] + ], + use_container_width=True, + ) diff --git a/bertrend/demos/demos_utils/parameters_component.py b/bertrend/demos/demos_utils/parameters_component.py new file mode 100644 index 0000000..33f4411 --- /dev/null +++ b/bertrend/demos/demos_utils/parameters_component.py @@ -0,0 +1,128 @@ +# Copyright (c) 2024, RTE (https://www.rte-france.com) +# See AUTHORS.txt +# SPDX-License-Identifier: MPL-2.0 +# This file is part of BERTrend. 
+import streamlit as st
+from bertrend.parameters import (
+    DEFAULT_UMAP_N_COMPONENTS,
+    DEFAULT_UMAP_N_NEIGHBORS,
+    DEFAULT_HDBSCAN_MIN_CLUSTER_SIZE,
+    DEFAULT_HDBSCAN_MIN_SAMPLES,
+    DEFAULT_TOP_N_WORDS,
+    DEFAULT_MIN_DF,
+    DEFAULT_MIN_SIMILARITY,
+    VECTORIZER_NGRAM_RANGES,
+    DEFAULT_ZEROSHOT_MIN_SIMILARITY,
+    HDBSCAN_CLUSTER_SELECTION_METHODS,
+    EMBEDDING_DTYPES,
+    LANGUAGES,
+    ENGLISH_EMBEDDING_MODELS,
+    FRENCH_EMBEDDING_MODELS,
+)
+
+
+def display_bertopic_hyperparameters():
+    with st.expander("Embedding Model Settings", expanded=False):
+        language = st.selectbox("Select Language", LANGUAGES, key="language")
+        embedding_dtype = st.selectbox(
+            "Embedding Dtype", EMBEDDING_DTYPES, key="embedding_dtype"
+        )
+
+        embedding_models = (
+            ENGLISH_EMBEDDING_MODELS
+            if language == "English"
+            else FRENCH_EMBEDDING_MODELS
+        )
+        embedding_model_name = st.selectbox(
+            "Embedding Model", embedding_models, key="embedding_model_name"
+        )
+
+    for expander, params in [
+        (
+            "UMAP Hyperparameters",
+            [
+                (
+                    "umap_n_components",
+                    "UMAP n_components",
+                    DEFAULT_UMAP_N_COMPONENTS,
+                    2,
+                    100,
+                ),
+                (
+                    "umap_n_neighbors",
+                    "UMAP n_neighbors",
+                    DEFAULT_UMAP_N_NEIGHBORS,
+                    2,
+                    100,
+                ),
+            ],
+        ),
+        (
+            "HDBSCAN Hyperparameters",
+            [
+                (
+                    "hdbscan_min_cluster_size",
+                    "HDBSCAN min_cluster_size",
+                    DEFAULT_HDBSCAN_MIN_CLUSTER_SIZE,
+                    2,
+                    100,
+                ),
+                (
+                    "hdbscan_min_samples",
+                    "HDBSCAN min_samples",
+                    DEFAULT_HDBSCAN_MIN_SAMPLES,
+                    1,
+                    100,
+                ),
+            ],
+        ),
+        (
+            "Vectorizer Hyperparameters",
+            [
+                ("top_n_words", "Top N Words", DEFAULT_TOP_N_WORDS, 1, 50),
+                ("min_df", "min_df", DEFAULT_MIN_DF, 1, 50),
+            ],
+        ),
+    ]:
+        with st.expander(expander, expanded=False):
+            for key, label, default, min_val, max_val in params:
+                st.number_input(
+                    label,
+                    value=default,
+                    min_value=min_val,
+                    max_value=max_val,
+                    key=key,
+                )
+
+            if expander == "HDBSCAN Hyperparameters":
+                st.selectbox(
+                    "Cluster Selection Method",
+                    HDBSCAN_CLUSTER_SELECTION_METHODS,
+                    key="hdbscan_cluster_selection_method",
+                )
+            elif expander == "Vectorizer Hyperparameters":
+                st.selectbox(
+                    "N-Gram range",
+                    VECTORIZER_NGRAM_RANGES,
+                    key="vectorizer_ngram_range",
+                )
+
+    with st.expander("Merging Hyperparameters", expanded=False):
+        st.slider(
+            "Minimum Similarity for Merging",
+            0.0,
+            1.0,
+            DEFAULT_MIN_SIMILARITY,
+            0.01,
+            key="min_similarity",
+        )
+
+    with st.expander("Zero-shot Parameters", expanded=False):
+        st.slider(
+            "Zeroshot Minimum Similarity",
+            0.0,
+            1.0,
+            DEFAULT_ZEROSHOT_MIN_SIMILARITY,
+            0.01,
+            key="zeroshot_min_similarity",
+        )
diff --git a/bertrend/demos/weak_signals/session_state_manager.py b/bertrend/demos/demos_utils/session_state_manager.py
similarity index 100%
rename from bertrend/demos/weak_signals/session_state_manager.py
rename to bertrend/demos/demos_utils/session_state_manager.py
diff --git a/bertrend/demos/topic_analysis/state_utils.py b/bertrend/demos/demos_utils/state_utils.py
similarity index 100%
rename from bertrend/demos/topic_analysis/state_utils.py
rename to bertrend/demos/demos_utils/state_utils.py
diff --git a/bertrend/demos/topic_analysis/Main_page.py b/bertrend/demos/topic_analysis/Main_page.py
index bcbc926..18aad68 100644
--- a/bertrend/demos/topic_analysis/Main_page.py
+++ b/bertrend/demos/topic_analysis/Main_page.py
@@ -6,7 +6,6 @@
 import ast
 import datetime
 import re
-from pathlib import Path
 
 import pandas as pd
 import streamlit as st
@@ -14,7 +13,7 @@
 from sentence_transformers import SentenceTransformer
 from transformers import AutoTokenizer
 
-from bertrend import DATA_PATH
+from bertrend import DATA_PATH, OUTPUT_PATH
 
 from bertrend.demos.topic_analysis.app_utils import (
     embedding_model_options,
@@ -27,7 +26,7 @@
     load_data_wrapper,
 )
 from bertrend.demos.topic_analysis.data_utils import data_overview, choose_data
-from bertrend.demos.topic_analysis.state_utils import (
+from bertrend.demos.demos_utils.state_utils import (
     register_widget,
     save_widget_state,
     restore_widget_state,
@@ -112,9 +111,7 @@ def save_model_interface():
             dynamic_model_name = generate_model_name(
                 base_model_name if base_model_name else "topic_model"
            )
-            model_save_path = (
-                Path(__file__).parent / "saved_models" / dynamic_model_name
-            )
+            model_save_path = OUTPUT_PATH / "saved_models" / dynamic_model_name
             logger.debug(
                 f"Saving the model in the following directory: {model_save_path}"
             )
diff --git a/bertrend/demos/topic_analysis/app_utils.py b/bertrend/demos/topic_analysis/app_utils.py
index 4c8a53a..33f803f 100644
--- a/bertrend/demos/topic_analysis/app_utils.py
+++ b/bertrend/demos/topic_analysis/app_utils.py
@@ -10,7 +10,7 @@
 import plotly.express as px
 import streamlit as st
 
-from bertrend.demos.topic_analysis.state_utils import register_widget
+from bertrend.demos.demos_utils.state_utils import register_widget
 from bertrend.demos.weak_signals.visualizations_utils import PLOTLY_BUTTON_SAVE_CONFIG
 from bertrend.utils.data_loading import (
     load_data,
diff --git a/bertrend/demos/topic_analysis/data_utils.py b/bertrend/demos/topic_analysis/data_utils.py
index a4a5cda..15ab772 100644
--- a/bertrend/demos/topic_analysis/data_utils.py
+++ b/bertrend/demos/topic_analysis/data_utils.py
@@ -11,7 +11,7 @@
 from pathlib import Path
 
 from bertrend.demos.topic_analysis.app_utils import plot_docs_reparition_over_time
-from bertrend.demos.topic_analysis.state_utils import save_widget_state
+from bertrend.demos.demos_utils.state_utils import save_widget_state
 from bertrend.utils.data_loading import TEXT_COLUMN, TIMESTAMP_COLUMN
 
diff --git a/bertrend/demos/topic_analysis/pages/1_Explore_Topics.py b/bertrend/demos/topic_analysis/pages/1_Explore_Topics.py
index b10e66d..8436dc9 100644
--- a/bertrend/demos/topic_analysis/pages/1_Explore_Topics.py
+++ b/bertrend/demos/topic_analysis/pages/1_Explore_Topics.py
@@ -19,7 +19,7 @@
 from loguru import logger
 
 from bertrend import LLM_CONFIG
-from bertrend.demos.topic_analysis.state_utils import restore_widget_state
+from bertrend.demos.demos_utils.state_utils import restore_widget_state
 from bertrend.llm_utils.openai_client import OpenAI_Client
 from bertrend.demos.weak_signals.visualizations_utils import PLOTLY_BUTTON_SAVE_CONFIG
 from bertrend.utils.data_loading import TIMESTAMP_COLUMN, TEXT_COLUMN
diff --git a/bertrend/demos/topic_analysis/pages/4_Generate_Newsletters.py b/bertrend/demos/topic_analysis/pages/4_Generate_Newsletters.py
index f3b2931..2dca423 100644
--- a/bertrend/demos/topic_analysis/pages/4_Generate_Newsletters.py
+++ b/bertrend/demos/topic_analysis/pages/4_Generate_Newsletters.py
@@ -7,7 +7,7 @@
 import streamlit as st
 from pathlib import Path
 
-from bertrend.demos.topic_analysis.state_utils import (
+from bertrend.demos.demos_utils.state_utils import (
     restore_widget_state,
     register_widget,
     save_widget_state,
diff --git a/bertrend/demos/topic_analysis/pages/6_Visualizations.py b/bertrend/demos/topic_analysis/pages/6_Visualizations.py
index 8d8d8aa..941b6f6 100644
--- a/bertrend/demos/topic_analysis/pages/6_Visualizations.py
+++ b/bertrend/demos/topic_analysis/pages/6_Visualizations.py
@@ -16,10 +16,11 @@
 from loguru import logger
 from umap import UMAP
+from bertrend import OUTPUT_PATH from bertrend.demos.topic_analysis.app_utils import ( plot_2d_topics, ) -from bertrend.demos.topic_analysis.state_utils import restore_widget_state +from bertrend.demos.demos_utils.state_utils import restore_widget_state from bertrend.demos.weak_signals.visualizations_utils import PLOTLY_BUTTON_SAVE_CONFIG from bertrend.utils.data_loading import TEXT_COLUMN @@ -207,7 +208,7 @@ def create_datamap(include_outliers): logo_width=100, ) - save_path = Path(__file__).parent.parent / "datamapplot.html" + save_path = OUTPUT_PATH / "datamapplot.html" with open(save_path, "wb") as f: f.write(plot._html_str.encode(encoding="UTF-8", errors="replace")) diff --git a/bertrend/demos/topic_analysis/pages/7_Temporal_Visualizations.py b/bertrend/demos/topic_analysis/pages/7_Temporal_Visualizations.py index 1fc03bf..ed480fd 100644 --- a/bertrend/demos/topic_analysis/pages/7_Temporal_Visualizations.py +++ b/bertrend/demos/topic_analysis/pages/7_Temporal_Visualizations.py @@ -18,7 +18,7 @@ plot_topics_over_time, compute_topics_over_time, ) -from bertrend.demos.topic_analysis.state_utils import ( +from bertrend.demos.demos_utils.state_utils import ( register_widget, save_widget_state, restore_widget_state, diff --git a/bertrend/demos/topic_analysis/pages/hidden/2_Topics_Emergence_Map.py b/bertrend/demos/topic_analysis/pages/hidden/2_Topics_Emergence_Map.py index d846165..357fdca 100644 --- a/bertrend/demos/topic_analysis/pages/hidden/2_Topics_Emergence_Map.py +++ b/bertrend/demos/topic_analysis/pages/hidden/2_Topics_Emergence_Map.py @@ -7,7 +7,7 @@ from statistics import StatisticsError from bertrend.demos.topic_analysis.app_utils import compute_topics_over_time -from bertrend.demos.topic_analysis.state_utils import ( +from bertrend.demos.demos_utils.state_utils import ( restore_widget_state, register_widget, save_widget_state, diff --git a/bertrend/demos/topic_analysis/pages/hidden/3_Simulation_of_new_data.py b/bertrend/demos/topic_analysis/pages/hidden/3_Simulation_of_new_data.py index c78d2ca..eb70dc2 100644 --- a/bertrend/demos/topic_analysis/pages/hidden/3_Simulation_of_new_data.py +++ b/bertrend/demos/topic_analysis/pages/hidden/3_Simulation_of_new_data.py @@ -14,7 +14,7 @@ compute_topics_over_time, plot_topics_over_time, ) -from bertrend.demos.topic_analysis.state_utils import ( +from bertrend.demos.demos_utils.state_utils import ( restore_widget_state, register_widget, save_widget_state, diff --git a/bertrend/demos/topic_analysis/pages/hidden/5_Merge_models.py b/bertrend/demos/topic_analysis/pages/hidden/5_Merge_models.py index 5f23150..8bc152e 100644 --- a/bertrend/demos/topic_analysis/pages/hidden/5_Merge_models.py +++ b/bertrend/demos/topic_analysis/pages/hidden/5_Merge_models.py @@ -8,7 +8,7 @@ from bertopic import BERTopic from typing import List, Optional, Union -from bertrend.demos.topic_analysis.state_utils import restore_widget_state +from bertrend.demos.demos_utils.state_utils import restore_widget_state def list_saved_models(saved_models_dir: Union[str, Path]) -> List[Path]: diff --git a/bertrend/demos/topic_analysis/pages/hidden/7_Temporal_Visualizations.py b/bertrend/demos/topic_analysis/pages/hidden/7_Temporal_Visualizations.py index c92eeb6..f94a202 100644 --- a/bertrend/demos/topic_analysis/pages/hidden/7_Temporal_Visualizations.py +++ b/bertrend/demos/topic_analysis/pages/hidden/7_Temporal_Visualizations.py @@ -16,7 +16,7 @@ import pandas as pd -from bertrend.demos.topic_analysis.state_utils import ( +from bertrend.demos.demos_utils.state_utils import ( 
restore_widget_state, register_widget, save_widget_state, diff --git a/bertrend/demos/weak_signals/app.py b/bertrend/demos/weak_signals/app.py index ea5bb21..92b9268 100644 --- a/bertrend/demos/weak_signals/app.py +++ b/bertrend/demos/weak_signals/app.py @@ -13,12 +13,16 @@ from loguru import logger from bertrend import ( - DATA_PATH, ZEROSHOT_TOPICS_DATA_DIR, - SIGNAL_EVOLUTION_DATA_DIR, CACHE_PATH, ) from bertrend.BERTrend import BERTrend +from bertrend.demos.demos_utils.data_loading_component import ( + display_data_loading_component, +) +from bertrend.demos.demos_utils.parameters_component import ( + display_bertopic_hyperparameters, +) from bertrend.services.embedding_service import EmbeddingService from bertrend.topic_model import TopicModel from bertrend.demos.weak_signals.messages import ( @@ -32,19 +36,16 @@ STATE_SAVED_MESSAGE, STATE_RESTORED_MESSAGE, MODELS_SAVED_MESSAGE, - NO_DATASET_WARNING, NO_MODELS_WARNING, ) from bertrend.trend_analysis.weak_signals import detect_weak_signals_zeroshot from bertrend.utils.data_loading import ( - load_and_preprocess_data, group_by_days, - find_compatible_files, TEXT_COLUMN, ) from bertrend.parameters import * -from session_state_manager import SessionStateManager +from bertrend.demos.demos_utils.session_state_manager import SessionStateManager from bertrend.trend_analysis.visualizations import ( plot_size_outliers, plot_num_topics, @@ -57,6 +58,7 @@ display_popularity_evolution, save_signal_evolution, display_signal_analysis, + retrieve_topic_counts, ) # UI Settings @@ -147,7 +149,7 @@ def main(): if st.button("Restore Previous Run", use_container_width=True): restore_state() try: - BERTrend.restore_models() + SessionStateManager.set("bertrend", BERTrend.restore_models()) st.success(MODELS_RESTORED_MESSAGE) except Exception as e: st.warning(NO_MODELS_WARNING) @@ -160,110 +162,7 @@ def main(): # BERTopic Hyperparameters st.subheader("BERTopic Hyperparameters") - with st.expander("Embedding Model Settings", expanded=False): - language = st.selectbox("Select Language", LANGUAGES, key="language") - embedding_dtype = st.selectbox( - "Embedding Dtype", EMBEDDING_DTYPES, key="embedding_dtype" - ) - - embedding_models = ( - ENGLISH_EMBEDDING_MODELS - if language == "English" - else FRENCH_EMBEDDING_MODELS - ) - embedding_model_name = st.selectbox( - "Embedding Model", embedding_models, key="embedding_model_name" - ) - - for expander, params in [ - ( - "UMAP Hyperparameters", - [ - ( - "umap_n_components", - "UMAP n_components", - DEFAULT_UMAP_N_COMPONENTS, - 2, - 100, - ), - ( - "umap_n_neighbors", - "UMAP n_neighbors", - DEFAULT_UMAP_N_NEIGHBORS, - 2, - 100, - ), - ], - ), - ( - "HDBSCAN Hyperparameters", - [ - ( - "hdbscan_min_cluster_size", - "HDBSCAN min_cluster_size", - DEFAULT_HDBSCAN_MIN_CLUSTER_SIZE, - 2, - 100, - ), - ( - "hdbscan_min_samples", - "HDBSCAN min_sample", - DEFAULT_HDBSCAN_MIN_SAMPLES, - 1, - 100, - ), - ], - ), - ( - "Vectorizer Hyperparameters", - [ - ("top_n_words", "Top N Words", DEFAULT_TOP_N_WORDS, 1, 50), - ("min_df", "min_df", DEFAULT_MIN_DF, 1, 50), - ], - ), - ]: - with st.expander(expander, expanded=False): - for key, label, default, min_val, max_val in params: - st.number_input( - label, - value=default, - min_value=min_val, - max_value=max_val, - key=key, - ) - - if expander == "HDBSCAN Hyperparameters": - st.selectbox( - "Cluster Selection Method", - HDBSCAN_CLUSTER_SELECTION_METHODS, - key="hdbscan_cluster_selection_method", - ) - elif expander == "Vectorizer Hyperparameters": - st.selectbox( - "N-Gram range", - 
VECTORIZER_NGRAM_RANGES, - key="vectorizer_ngram_range", - ) - - with st.expander("Merging Hyperparameters", expanded=False): - st.slider( - "Minimum Similarity for Merging", - 0.0, - 1.0, - DEFAULT_MIN_SIMILARITY, - 0.01, - key="min_similarity", - ) - - with st.expander("Zero-shot Parameters", expanded=False): - st.slider( - "Zeroshot Minimum Similarity", - 0.0, - 1.0, - DEFAULT_ZEROSHOT_MIN_SIMILARITY, - 0.01, - key="zeroshot_min_similarity", - ) + display_bertopic_hyperparameters() # Main content tab1, tab2, tab3 = st.tabs(["Data Loading", "Model Training", "Results Analysis"]) @@ -271,102 +170,18 @@ def main(): with tab1: st.header("Data Loading and Preprocessing") - # Find files in the current directory and subdirectories - compatible_extensions = ["csv", "parquet", "json", "jsonl"] - selected_files = st.multiselect( - "Select one or more datasets", - find_compatible_files(DATA_PATH, compatible_extensions), - default=SessionStateManager.get("selected_files", []), - key="selected_files", - ) - - if not selected_files: - st.warning(NO_DATASET_WARNING) - return - - # Display number input and checkbox for preprocessing options - col1, col2 = st.columns(2) - with col1: - min_chars = st.number_input( - "Minimum Characters", - value=MIN_CHARS_DEFAULT, - min_value=0, - max_value=1000, - key="min_chars", - ) - with col2: - split_by_paragraph = st.checkbox( - "Split text by paragraphs", value=False, key="split_by_paragraph" - ) - - # Load and preprocess each selected file, then concatenate them - dfs = [] - for selected_file, ext in selected_files: - file_path = DATA_PATH / selected_file - df = load_and_preprocess_data( - (file_path, ext), language, min_chars, split_by_paragraph - ) - dfs.append(df) - - if not dfs: - st.warning( - "No data available after preprocessing. Please check the selected files and preprocessing options." 
-            )
-        else:
-            df = pd.concat(dfs, ignore_index=True)
-
-            # Deduplicate using all columns
-            df = df.drop_duplicates()
-
-            # Select timeframe
-            min_date, max_date = df["timestamp"].dt.date.agg(["min", "max"])
-            start_date, end_date = st.slider(
-                "Select Timeframe",
-                min_value=min_date,
-                max_value=max_date,
-                value=(min_date, max_date),
-                key="timeframe_slider",
-            )
-
-            # Filter and sample the DataFrame
-            df_filtered = df[
-                (df["timestamp"].dt.date >= start_date)
-                & (df["timestamp"].dt.date <= end_date)
-            ]
-            df_filtered = df_filtered.sort_values(by="timestamp").reset_index(drop=True)
-
-            sample_size = st.number_input(
-                "Sample Size",
-                value=SAMPLE_SIZE_DEFAULT or len(df_filtered),
-                min_value=1,
-                max_value=len(df_filtered),
-                key="sample_size",
-            )
-            if sample_size < len(df_filtered):
-                df_filtered = df_filtered.sample(n=sample_size, random_state=42)
+        display_data_loading_component()

-            df_filtered = df_filtered.sort_values(by="timestamp").reset_index(drop=True)
-
-            SessionStateManager.set("timefiltered_df", df_filtered)
-            st.write(
-                f"Number of documents in selected timeframe: {len(SessionStateManager.get_dataframe('timefiltered_df'))}"
-            )
-            st.dataframe(
-                SessionStateManager.get_dataframe("timefiltered_df")[
-                    [TEXT_COLUMN, "timestamp"]
-                ],
-                use_container_width=True,
-            )
+        if "timefiltered_df" in st.session_state:

             # Embed documents
             if st.button("Embed Documents"):
-                embedding_service = EmbeddingService()
-
                 with st.spinner("Embedding documents..."):
                     embedding_dtype = SessionStateManager.get("embedding_dtype")
                     embedding_model_name = SessionStateManager.get(
                         "embedding_model_name"
                     )
+                    embedding_service = EmbeddingService()

                     texts = SessionStateManager.get_dataframe("timefiltered_df")[
                         TEXT_COLUMN
@@ -454,6 +269,7 @@ def main():

             logger.debug(SessionStateManager.get("language"))

+            # Initialize topic model
             topic_model = TopicModel(
                 umap_n_components=SessionStateManager.get("umap_n_components"),
                 umap_n_neighbors=SessionStateManager.get("umap_n_neighbors"),
@@ -474,6 +290,7 @@ def main():
                 language=SessionStateManager.get("language"),
             )

+            # Create BERTrend object
             bertrend = BERTrend(
                 topic_model=topic_model,
                 zeroshot_topic_list=zeroshot_topic_list,
@@ -481,20 +298,15 @@ def main():
                     "zeroshot_min_similarity"
                 ),
             )
+            # Train topic models on data
             bertrend.train_topic_models(
                 grouped_data=grouped_data,
                 embedding_model=SessionStateManager.get("embedding_model"),
                 embeddings=SessionStateManager.get_embeddings(),
             )
-
-            # TODO: A supprimer / adapter - cf save/restore
-            SessionStateManager.set_multiple(
-                doc_groups=bertrend.doc_groups,
-                emb_groups=bertrend.emb_groups,
-            )
-
             st.success(MODEL_TRAINING_COMPLETE_MESSAGE)
+            # Save trained models
             bertrend.save_models()
             st.success(MODELS_SAVED_MESSAGE)
@@ -509,12 +321,12 @@ def main():
         else:
             if st.button("Merge Models"):
                 with st.spinner("Merging models..."):
-                    # TODO: encapsulate into a merging function
-                    SessionStateManager.get("bertrend").merge_models(
+                    bertrend = SessionStateManager.get("bertrend")
+                    bertrend.merge_models(
                         min_similarity=SessionStateManager.get("min_similarity"),
                     )

-                    SessionStateManager.get("bertrend").calculate_signal_popularity(
+                    bertrend.calculate_signal_popularity(
                         granularity=SessionStateManager.get("granularity_select"),
                     )
@@ -676,48 +488,7 @@ def main():
             if st.button("Retrieve Topic Counts"):
                 with st.spinner("Retrieving topic counts..."):
                     # Number of topics per individual topic model
-                    individual_model_topic_counts = [
-                        (timestamp, model.topic_info_df["Topic"].max() + 1)
-                        for timestamp, model in topic_models.items()
-                    ]
-                    df_individual_models = pd.DataFrame(
-                        individual_model_topic_counts,
-                        columns=["timestamp", "num_topics"],
-                    )
-
-                    # Number of topics per cumulative merged model
-                    cumulative_merged_topic_counts = SessionStateManager.get(
-                        "merge_df_size_over_time", []
-                    )
-                    df_cumulative_merged = pd.DataFrame(
-                        cumulative_merged_topic_counts,
-                        columns=["timestamp", "num_topics"],
-                    )
-
-                    # Convert to JSON
-                    json_individual_models = df_individual_models.to_json(
-                        orient="records", date_format="iso", indent=4
-                    )
-                    json_cumulative_merged = df_cumulative_merged.to_json(
-                        orient="records", date_format="iso", indent=4
-                    )
-
-                    # Save individual model topic counts
-                    json_file_path = (
-                        SIGNAL_EVOLUTION_DATA_DIR
-                        / f"retrospective_{SessionStateManager.get('window_size')}_days"
-                    )
-                    json_file_path.mkdir(parents=True, exist_ok=True)
-
-                    (
-                        json_file_path / INDIVIDUAL_MODEL_TOPIC_COUNTS_FILE
-                    ).write_text(json_individual_models)
-
-                    # Save cumulative merged model topic counts
-                    (
-                        json_file_path / CUMULATIVE_MERGED_TOPIC_COUNTS_FILE
-                    ).write_text(json_cumulative_merged)
-
-                    st.success(
-                        f"Topic counts for individual and cumulative merged models saved to {json_file_path}"
-                    )
+                    retrieve_topic_counts(topic_models)
+                    st.success(
+                        "Topic counts for individual and cumulative merged models saved."
+                    )
diff --git a/bertrend/demos/weak_signals/messages.py b/bertrend/demos/weak_signals/messages.py
index 6cc9b25..2dfd1e5 100644
--- a/bertrend/demos/weak_signals/messages.py
+++ b/bertrend/demos/weak_signals/messages.py
@@ -24,5 +24,4 @@
     "Topic {topic_number} not found in the merge histories within the specified window."
 )
 NO_GRANULARITY_WARNING = "Granularity value not found."
-NO_DATASET_WARNING = "Please select at least one dataset to proceed."
 HTML_GENERATION_FAILED_WARNING = "HTML generation failed. Displaying markdown instead."
diff --git a/bertrend/demos/weak_signals/visualizations_utils.py b/bertrend/demos/weak_signals/visualizations_utils.py
index 5385158..551e19e 100644
--- a/bertrend/demos/weak_signals/visualizations_utils.py
+++ b/bertrend/demos/weak_signals/visualizations_utils.py
@@ -2,8 +2,7 @@
 # See AUTHORS.txt
 # SPDX-License-Identifier: MPL-2.0
 # This file is part of BERTrend.
-from pathlib import Path
-from typing import Dict, Tuple
+from typing import Dict

 import pandas as pd
 import streamlit as st
@@ -11,10 +10,15 @@
 from pandas import Timestamp
 from plotly import graph_objects as go

-from bertrend import OUTPUT_PATH
+from bertrend import OUTPUT_PATH, SIGNAL_EVOLUTION_DATA_DIR
 from bertrend.demos.weak_signals.messages import HTML_GENERATION_FAILED_WARNING
-from bertrend.demos.weak_signals.session_state_manager import SessionStateManager
-from bertrend.parameters import MAX_WINDOW_SIZE, DEFAULT_WINDOW_SIZE
+from bertrend.demos.demos_utils.session_state_manager import SessionStateManager
+from bertrend.parameters import (
+    MAX_WINDOW_SIZE,
+    DEFAULT_WINDOW_SIZE,
+    INDIVIDUAL_MODEL_TOPIC_COUNTS_FILE,
+    CUMULATIVE_MERGED_TOPIC_COUNTS_FILE,
+)
 from bertrend.trend_analysis.visualizations import (
     create_sankey_diagram_plotly,
     plot_newly_emerged_topics,
@@ -253,7 +257,8 @@ def display_topics_per_timestamp(topic_models: Dict[pd.Timestamp, BERTopic]) ->
     Plot the topics discussed per source for each timestamp.

     Args:
-        topic_models (Dict[pd.Timestamp, BERTopic]): A dictionary of BERTopic models, where the key is the timestamp and the value is the corresponding model.
+        topic_models (Dict[pd.Timestamp, BERTopic]): A dictionary of BERTopic models, where the key is the timestamp
+            and the value is the corresponding model.
""" with st.expander("Explore topic models"): model_periods = sorted(topic_models.keys()) @@ -306,3 +311,47 @@ def display_signal_analysis(topic_number, output_file="signal_llm.html"): st.markdown(summary) with col2: st.markdown(analysis) + + +def retrieve_topic_counts(topic_models: Dict[pd.Timestamp, BERTopic]) -> None: + individual_model_topic_counts = [ + (timestamp, model.topic_info_df["Topic"].max() + 1) + for timestamp, model in topic_models.items() + ] + df_individual_models = pd.DataFrame( + individual_model_topic_counts, + columns=["timestamp", "num_topics"], + ) + + # Number of topics per cumulative merged model + cumulative_merged_topic_counts = SessionStateManager.get( + "merge_df_size_over_time", [] + ) + df_cumulative_merged = pd.DataFrame( + cumulative_merged_topic_counts, + columns=["timestamp", "num_topics"], + ) + + # Convert to JSON + json_individual_models = df_individual_models.to_json( + orient="records", date_format="iso", indent=4 + ) + json_cumulative_merged = df_cumulative_merged.to_json( + orient="records", date_format="iso", indent=4 + ) + + # Save individual model topic counts + json_file_path = ( + SIGNAL_EVOLUTION_DATA_DIR + / f"retrospective_{SessionStateManager.get('window_size')}_days" + ) + json_file_path.mkdir(parents=True, exist_ok=True) + + (json_file_path / INDIVIDUAL_MODEL_TOPIC_COUNTS_FILE).write_text( + json_individual_models + ) + + # Save cumulative merged model topic counts + (json_file_path / CUMULATIVE_MERGED_TOPIC_COUNTS_FILE).write_text( + json_cumulative_merged + ) diff --git a/bertrend/metrics/temporal_metrics_embedding.py b/bertrend/metrics/temporal_metrics_embedding.py index 73d1460..0651583 100644 --- a/bertrend/metrics/temporal_metrics_embedding.py +++ b/bertrend/metrics/temporal_metrics_embedding.py @@ -60,6 +60,8 @@ from typing import List, Union, Tuple import re +from bertrend import OUTPUT_PATH + class TempTopic: def __init__( @@ -124,7 +126,7 @@ def __init__( self.representation_embeddings_df = None self.stemmer = PorterStemmer() - self.debug_file = Path(__file__).parent.parent / "match_debugging.txt" + self.debug_file = OUTPUT_PATH / "match_debugging.txt" open(self.debug_file, "w").close() def fit( diff --git a/bertrend/parameters.py b/bertrend/parameters.py index cbf3e32..29db1ac 100644 --- a/bertrend/parameters.py +++ b/bertrend/parameters.py @@ -10,9 +10,12 @@ from bertrend import PARAMETERS_CONFIG +stopwords_en_file = Path(__file__).parent / "resources" / "stopwords-en.json" stopwords_fr_file = Path(__file__).parent / "resources" / "stopwords-fr.json" stopwords_rte_file = Path(__file__).parent / "resources" / "stopwords-rte.json" common_ngrams_file = Path(__file__).parent / "resources" / "common_ngrams.json" +with open(stopwords_en_file, "r", encoding="utf-8") as file: + ENGLISH_STOPWORDS = json.load(file) with open(stopwords_fr_file, "r", encoding="utf-8") as file: FRENCH_STOPWORDS = json.load(file) with open(stopwords_rte_file, "r", encoding="utf-8") as file: diff --git a/bertrend/resources/stopwords-en.json b/bertrend/resources/stopwords-en.json new file mode 100644 index 0000000..c44dc44 --- /dev/null +++ b/bertrend/resources/stopwords-en.json @@ -0,0 +1,3 @@ +[ + +] \ No newline at end of file diff --git a/bertrend/topic_model.py b/bertrend/topic_model.py index 6778ec4..a2aeed2 100644 --- a/bertrend/topic_model.py +++ b/bertrend/topic_model.py @@ -29,6 +29,7 @@ STOPWORDS, DEFAULT_MMR_DIVERSITY, OUTLIER_REDUCTION_STRATEGY, + ENGLISH_STOPWORDS, ) @@ -94,7 +95,9 @@ def _initialize_models(self): prediction_data=True, ) - 
self.stopword_set = STOPWORDS if self.language == "French" else "english" + self.stopword_set = ( + STOPWORDS if self.language == "French" else ENGLISH_STOPWORDS + ) self.vectorizer_model = CountVectorizer( stop_words=self.stopword_set, min_df=self.min_df,
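
Note on the refactored persistence API: `save_models()` / `restore_models()` no longer go through `SessionStateManager`, so BERTrend state can now be saved and reloaded outside a Streamlit session. Below is a minimal sketch of the round-trip, not runnable end to end: `grouped_data`, `embedding_model`, and `embeddings` are placeholders for inputs prepared as in the weak-signals demo app, and `TopicModel()` is assumed to be constructible with its defaults.

```python
from pathlib import Path

from bertrend.BERTrend import BERTrend
from bertrend.topic_model import TopicModel

# Train as in the demo app; grouped_data / embedding_model / embeddings
# are placeholders for data prepared upstream (not defined in this diff).
bertrend = BERTrend(topic_model=TopicModel())
bertrend.train_topic_models(
    grouped_data=grouped_data,
    embedding_model=embedding_model,
    embeddings=embeddings,
)

# Persist one BERTopic model per period plus the pickled metadata
# (hyperparameters, granularity, doc/emb groups, fitted flag).
# models_path defaults to MODELS_DIR.
bertrend.save_models(models_path=Path("/tmp/bertrend_models"))

# Later, possibly in another process: restore_models() is now a classmethod
# that returns a populated BERTrend instance instead of mutating session state.
restored = BERTrend.restore_models(models_path=Path("/tmp/bertrend_models"))
assert restored.topic_models  # {pd.Timestamp: BERTopic}
```

One caveat visible in the diff: the per-period BERTopic models honor `models_path`, but the pickled metadata is still read from and written to `CACHE_PATH`, so a save/restore pair must run against the same cache location.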