From 12eed54c9a1fcf66797a9376155cddf027732a2a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=B4me=20Picault?= Date: Fri, 20 Dec 2024 16:38:12 +0100 Subject: [PATCH] Updated configurations --- bertrend/BERTrend.py | 75 +++++++- bertrend/__init__.py | 2 + bertrend/bertrend.toml | 24 +++ bertrend/demos/topic_analysis/Main_page.py | 3 +- bertrend/demos/weak_signals/app.py | 69 +------ bertrend/demos/weak_signals/messages.py | 1 + .../weak_signals/visualizations_utils.py | 10 +- bertrend/metrics/temporal_metrics.py | 4 +- .../metrics/temporal_metrics_embedding.py | 176 +----------------- bertrend/parameters.py | 47 +++-- bertrend/services/embedding_service.py | 11 +- bertrend/trend_analysis/prompts.py | 7 +- bertrend/trend_analysis/weak_signals.py | 21 +-- bertrend_apps/newsletters/__main__.py | 3 +- 14 files changed, 162 insertions(+), 291 deletions(-) diff --git a/bertrend/BERTrend.py b/bertrend/BERTrend.py index 8977432..665a674 100644 --- a/bertrend/BERTrend.py +++ b/bertrend/BERTrend.py @@ -14,6 +14,10 @@ from sentence_transformers import SentenceTransformer from bertrend import MODELS_DIR, CACHE_PATH +from bertrend.demos.weak_signals.messages import ( + NO_GRANULARITY_WARNING, +) +from bertrend.demos.weak_signals.session_state_manager import SessionStateManager from bertrend.topic_model import TopicModel from bertrend.parameters import ( DEFAULT_MIN_SIMILARITY, @@ -194,7 +198,6 @@ def train_topic_models( grouped_data (Dict[pd.Timestamp, pd.DataFrame]): Dictionary of grouped data by timestamp. embedding_model (SentenceTransformer): Sentence transformer model for embeddings. embeddings (np.ndarray): Precomputed document embeddings. - granularity (int): """ # TODO from topic_modelling = train_topic_models (modulo data transformation) # TODO rename to fit? @@ -267,7 +270,6 @@ def merge_models( # progress_bar = st.progress(0) merge_df_size_over_time = [] - # SessionStateManager.set("merge_df_size_over_time", []) for i, (current_timestamp, next_timestamp) in enumerate( zip(timestamps[:-1], timestamps[1:]) @@ -312,11 +314,7 @@ def merge_models( merged_df_without_outliers["Topic"].max() + 1, ) ) - # - # SessionStateManager.update( - # "merge_df_size_over_time", merge_df_size_over_time - # ) - # + # progress_bar.progress((i + 1) / len(timestamps)) all_merge_histories_df = pd.concat(all_merge_histories, ignore_index=True) @@ -340,7 +338,7 @@ def calculate_signal_popularity( Updates: - topic_sizes (Dict[int, Dict[str, Any]]): Dictionary storing topic sizes and related information over time. - topic_last_popularity (Dict[int, float]): Dictionary storing the last known popularity of each topic. - - topic_last_update (Dict[int, pd.Timestamp]]): Dictionary storing the last update timestamp of each topic. + - topic_last_update (Dict[int, pd.Timestamp]): Dictionary storing the last update timestamp of each topic. Args: all_merge_histories_df (pd.DataFrame): DataFrame containing all merge histories. 
@@ -444,7 +442,7 @@ def save_models(self): embedding_model = SessionStateManager.get("embedding_model") topic_model.save( model_dir, - serialization="safetensors", + serialization=BERTOPIC_SERIALIZATION, save_ctfidf=False, save_embedding_model=embedding_model, ) @@ -482,6 +480,65 @@ def save_models(self): pickle.dump(hyperparams, f) """ + @classmethod + def restore_models(cls): + if not MODELS_DIR.exists(): + raise FileNotFoundError(f"MODELS_DIR={MODELS_DIR} does not exist") + + topic_models = {} + for period_dir in MODELS_DIR.iterdir(): + if period_dir.is_dir(): + topic_model = BERTopic.load(period_dir) + + doc_info_df_file = period_dir / DOC_INFO_DF_FILE + topic_info_df_file = period_dir / TOPIC_INFO_DF_FILE + if doc_info_df_file.exists() and topic_info_df_file.exists(): + topic_model.doc_info_df = pd.read_pickle(doc_info_df_file) + topic_model.topic_info_df = pd.read_pickle(topic_info_df_file) + else: + logger.warning( + f"doc_info_df or topic_info_df not found for period {period_dir.name}" + ) + + period = pd.Timestamp(period_dir.name.replace("_", ":")) + topic_models[period] = topic_model + + SessionStateManager.set("topic_models", topic_models) + + for file, key in [ + (DOC_GROUPS_FILE, "doc_groups"), + (EMB_GROUPS_FILE, "emb_groups"), + ]: + file_path = CACHE_PATH / file + if file_path.exists(): + with open(file_path, "rb") as f: + SessionStateManager.set(key, pickle.load(f)) + else: + logger.warning(f"{file} not found.") + + granularity_file = CACHE_PATH / GRANULARITY_FILE + if granularity_file.exists(): + with open(granularity_file, "rb") as f: + SessionStateManager.set("granularity_select", pickle.load(f)) + else: + logger.warning(NO_GRANULARITY_WARNING) + + # Restore the models_trained flag + models_trained_file = CACHE_PATH / MODELS_TRAINED_FILE + if models_trained_file.exists(): + with open(models_trained_file, "rb") as f: + # FIXME! set bertrend first! + SessionStateManager.set("models_trained", pickle.load(f)) + else: + logger.warning("Models trained flag not found.") + + hyperparams_file = CACHE_PATH / HYPERPARAMS_FILE + if hyperparams_file.exists(): + with open(hyperparams_file, "rb") as f: + SessionStateManager.set_multiple(**pickle.load(f)) + else: + logger.warning("Hyperparameters file not found.") + ##################################################################################################### # FIXME: WIP # def merge_models2(self): diff --git a/bertrend/__init__.py b/bertrend/__init__.py index e2e5e3b..b20c9b8 100644 --- a/bertrend/__init__.py +++ b/bertrend/__init__.py @@ -11,6 +11,7 @@ # Read config BERTREND_CONFIG = load_toml_config(BERTREND_CONFIG_PATH) +PARAMETERS_CONFIG = BERTREND_CONFIG["parameters"] EMBEDDING_CONFIG = BERTREND_CONFIG["embedding_service"] LLM_CONFIG = BERTREND_CONFIG["llm_service"] @@ -48,4 +49,5 @@ # Create directories if they do not exist DATA_PATH.mkdir(parents=True, exist_ok=True) +OUTPUT_PATH.mkdir(parents=True, exist_ok=True) CACHE_PATH.mkdir(parents=True, exist_ok=True) diff --git a/bertrend/bertrend.toml b/bertrend/bertrend.toml index 59cbb2d..056030d 100644 --- a/bertrend/bertrend.toml +++ b/bertrend/bertrend.toml @@ -4,6 +4,27 @@ # This file is part of BERTrend. 
# Configuration for BERTrend +[parameters] +# BERTopic Hyperparameters +default_umap_n_components = 5 +default_umap_n_neighbors = 5 +default_hdbscan_min_cluster_size = 5 +default_hdbscan_min_samples = 5 +default_top_n_words = 10 +default_min_df = 1 +default_granularity = 2 +default_min_similarity = 0.7 +default_zeroshot_min_similarity = 0.5 +bertopic_serialization = "safetensors" # or pickle +default_mmr_diversity = 0.3 +default_umap_min_dist = 0.0 +outlier_reduction_strategy = "c-tf-idf" # or "embeddings" +# signal classification settings +signal_classif_lower_bound = 10 +signal_classif_upper_bound = 75 +# other constants +default_zeroshot_topics = "" # empty string or a default list of topics + [embedding_service] # Indicates if BERTrend shall use local embedding service (run by BERTrend) or if the embedding service @@ -25,3 +46,6 @@ port = 6464 api_key = "$AZURE_SE_WATTELSE_OPENAI_API_KEY_DEV" endpoint = "$AZURE_SE_WATTELSE_OPENAI_ENDPOINT_DEV" model = "$AZURE_SE_WATTELSE_OPENAI_DEFAULT_MODEL_NAME_DEV" +temperature = 0.1 +max_tokens = 2048 +system_prompt = "You are a helpful assistant, skilled in detailing topic evolution over time for the detection of emerging trends and signals." diff --git a/bertrend/demos/topic_analysis/Main_page.py b/bertrend/demos/topic_analysis/Main_page.py index ed6769e..bcbc926 100644 --- a/bertrend/demos/topic_analysis/Main_page.py +++ b/bertrend/demos/topic_analysis/Main_page.py @@ -34,6 +34,7 @@ ) from bertrend.metrics.topic_metrics import get_coherence_value, get_diversity_value +from bertrend.parameters import BERTOPIC_SERIALIZATION from bertrend.train import train_BERTopic from bertrend.utils.data_loading import ( split_df_by_paragraphs, @@ -120,7 +121,7 @@ def save_model_interface(): try: st.session_state["topic_model"].save( model_save_path, - serialization="safetensors", + serialization=BERTOPIC_SERIALIZATION, save_ctfidf=True, save_embedding_model=True, ) diff --git a/bertrend/demos/weak_signals/app.py b/bertrend/demos/weak_signals/app.py index fa54a8a..ea5bb21 100644 --- a/bertrend/demos/weak_signals/app.py +++ b/bertrend/demos/weak_signals/app.py @@ -10,12 +10,10 @@ import numpy as np import pandas as pd import plotly.graph_objects as go -from bertopic import BERTopic from loguru import logger from bertrend import ( DATA_PATH, - MODELS_DIR, ZEROSHOT_TOPICS_DATA_DIR, SIGNAL_EVOLUTION_DATA_DIR, CACHE_PATH, @@ -34,9 +32,8 @@ STATE_SAVED_MESSAGE, STATE_RESTORED_MESSAGE, MODELS_SAVED_MESSAGE, - NO_MODELS_WARNING, - NO_GRANULARITY_WARNING, NO_DATASET_WARNING, + NO_MODELS_WARNING, ) from bertrend.trend_analysis.weak_signals import detect_weak_signals_zeroshot @@ -121,63 +118,6 @@ def restore_state(): st.warning("No saved state found.") -def restore_models(): - if not MODELS_DIR.exists(): - st.warning(NO_MODELS_WARNING) - return - - topic_models = {} - for period_dir in MODELS_DIR.iterdir(): - if period_dir.is_dir(): - topic_model = BERTopic.load(period_dir) - - doc_info_df_file = period_dir / DOC_INFO_DF_FILE - topic_info_df_file = period_dir / TOPIC_INFO_DF_FILE - if doc_info_df_file.exists() and topic_info_df_file.exists(): - topic_model.doc_info_df = pd.read_pickle(doc_info_df_file) - topic_model.topic_info_df = pd.read_pickle(topic_info_df_file) - else: - logger.warning( - f"doc_info_df or topic_info_df not found for period {period_dir.name}" - ) - - period = pd.Timestamp(period_dir.name.replace("_", ":")) - topic_models[period] = topic_model - - SessionStateManager.set("topic_models", topic_models) - - for file, key in [(DOC_GROUPS_FILE, 
"doc_groups"), (EMB_GROUPS_FILE, "emb_groups")]: - file_path = CACHE_PATH / file - if file_path.exists(): - with open(file_path, "rb") as f: - SessionStateManager.set(key, pickle.load(f)) - else: - logger.warning(f"{file} not found.") - - granularity_file = CACHE_PATH / GRANULARITY_FILE - if granularity_file.exists(): - with open(granularity_file, "rb") as f: - SessionStateManager.set("granularity_select", pickle.load(f)) - else: - logger.warning(NO_GRANULARITY_WARNING) - - # Restore the models_trained flag - models_trained_file = CACHE_PATH / MODELS_TRAINED_FILE - if models_trained_file.exists(): - with open(models_trained_file, "rb") as f: - # FIXME! set bertrend first! - SessionStateManager.set("models_trained", pickle.load(f)) - else: - logger.warning("Models trained flag not found.") - - hyperparams_file = CACHE_PATH / HYPERPARAMS_FILE - if hyperparams_file.exists(): - with open(hyperparams_file, "rb") as f: - SessionStateManager.set_multiple(**pickle.load(f)) - else: - logger.warning("Hyperparameters file not found.") - - def purge_cache(): if CACHE_PATH.exists(): shutil.rmtree(CACHE_PATH) @@ -206,8 +146,11 @@ def main(): if st.button("Restore Previous Run", use_container_width=True): restore_state() - restore_models() - st.success(MODELS_RESTORED_MESSAGE) + try: + BERTrend.restore_models() + st.success(MODELS_RESTORED_MESSAGE) + except Exception as e: + st.warning(NO_MODELS_WARNING) if st.button("Purge Cache", use_container_width=True): purge_cache() diff --git a/bertrend/demos/weak_signals/messages.py b/bertrend/demos/weak_signals/messages.py index c59ea93..6cc9b25 100644 --- a/bertrend/demos/weak_signals/messages.py +++ b/bertrend/demos/weak_signals/messages.py @@ -25,3 +25,4 @@ ) NO_GRANULARITY_WARNING = "Granularity value not found." NO_DATASET_WARNING = "Please select at least one dataset to proceed." +HTML_GENERATION_FAILED_WARNING = "HTML generation failed. Displaying markdown instead." diff --git a/bertrend/demos/weak_signals/visualizations_utils.py b/bertrend/demos/weak_signals/visualizations_utils.py index a3d36ad..5385158 100644 --- a/bertrend/demos/weak_signals/visualizations_utils.py +++ b/bertrend/demos/weak_signals/visualizations_utils.py @@ -11,6 +11,8 @@ from pandas import Timestamp from plotly import graph_objects as go +from bertrend import OUTPUT_PATH +from bertrend.demos.weak_signals.messages import HTML_GENERATION_FAILED_WARNING from bertrend.demos.weak_signals.session_state_manager import SessionStateManager from bertrend.parameters import MAX_WINDOW_SIZE, DEFAULT_WINDOW_SIZE from bertrend.trend_analysis.visualizations import ( @@ -144,6 +146,7 @@ def display_popularity_evolution(): (with the smallest possible value being the earliest timestamp in the provided data). 
The latest selectable date corresponds to the most recent topic merges, which is at most equal to the latest timestamp in the data minus the provided granularity.""",
+        key="current_date",
     )
 
     granularity = SessionStateManager.get("granularity_select")
@@ -272,7 +275,7 @@ def display_topics_per_timestamp(topic_models: Dict[pd.Timestamp, BERTopic]) ->
     st.dataframe(selected_model.topic_info_df, use_container_width=True)
 
 
-def display_signal_analysis(topic_number):
+def display_signal_analysis(topic_number, output_file="signal_llm.html"):
     language = SessionStateManager.get("language")
     bertrend = SessionStateManager.get("bertrend")
     all_merge_histories_df = bertrend.all_merge_histories_df
@@ -288,8 +291,7 @@ def display_signal_analysis(topic_number):
     )
 
     # Check if the HTML file was created successfully
-    # FIXME: output path
-    output_file_path = Path(__file__).parent / "signal_llm.html"
+    output_file_path = OUTPUT_PATH / output_file
     if output_file_path.exists():
         # Read the HTML file
         with open(output_file_path, "r", encoding="utf-8") as file:
@@ -297,7 +299,7 @@ def display_signal_analysis(topic_number):
         # Display the HTML content
         st.html(html_content)
     else:
-        st.warning("HTML generation failed. Displaying markdown instead.")
+        st.warning(HTML_GENERATION_FAILED_WARNING)
         # Fallback to displaying markdown if HTML generation fails
         col1, col2 = st.columns(spec=[0.5, 0.5], gap="medium")
         with col1:
diff --git a/bertrend/metrics/temporal_metrics.py b/bertrend/metrics/temporal_metrics.py
index 7f5dd9a..87aa8c4 100644
--- a/bertrend/metrics/temporal_metrics.py
+++ b/bertrend/metrics/temporal_metrics.py
@@ -138,9 +138,9 @@ def __init__(
         self.tts_scores_df = None
         self.ttc_scores_df = None
 
-    def _topics_over_time(self) -> pd.DataFrame:
+    def _topics_over_time(self):
         """
-        Calculates and returns a DataFrame containing topics over time with their respective words and frequencies.
+        Calculates a DataFrame containing topics over time, with their respective words and frequencies, and stores it as an attribute.
 
         Returns:
         - pd.DataFrame: Topics, their top words, frequencies, and timestamps.
diff --git a/bertrend/metrics/temporal_metrics_embedding.py b/bertrend/metrics/temporal_metrics_embedding.py
index 3ac4857..73d1460 100644
--- a/bertrend/metrics/temporal_metrics_embedding.py
+++ b/bertrend/metrics/temporal_metrics_embedding.py
@@ -447,50 +447,6 @@ def _log_failed_match(
             f.write(" ".join(doc) + "\n\n")
             f.write(f"{'#'*50}\n\n")
 
-    # def _calculate_representation_embeddings(self, double_agg: bool = True, doc_agg: str = "mean", global_agg: str = "max", window_size: int = 10):
-    #     """
-    #     Calculates embeddings for topic representations using fuzzy matching.
-
-    #     Parameters:
-    #     - double_agg: Boolean to apply double aggregation.
-    #     - doc_agg: Aggregation method for document embeddings.
-    #     - global_agg: Aggregation method for global embeddings.
-    #     - window_size: The size of the window for fuzzy matching.
- # """ - # representation_embeddings = [] - - # for _, row in self.final_df.iterrows(): - # topic_id = row['Topic'] - # timestamp = row['Timestamp'] - # representation = [phrase.lower() for phrase in row['Words'].split(', ')] - - # token_strings = row['Token_Strings'] - # token_embeddings = row['Token_Embeddings'] - - # embedding_list = [] - # updated_representation = [] - - # for phrase in representation: - # matched_phrase, embedding = self._fuzzy_match_and_embed( - # phrase, token_strings, token_embeddings, topic_id, timestamp, window_size - # ) - # if embedding is not None: - # embedding_list.append(embedding) - # updated_representation.append(matched_phrase) - # else: - # logger.warning(f"No embedding found for '{phrase}' in topic {topic_id} at timestamp {timestamp}") - - # representation_embeddings.append({ - # 'Topic ID': topic_id, - # 'Timestamp': timestamp, - # 'Representation': ', '.join(updated_representation), - # 'Representation Embeddings': embedding_list - # }) - - # self.representation_embeddings_df = pd.DataFrame(representation_embeddings) - # logger.info(f"Created representation_embeddings_df with shape {self.representation_embeddings_df.shape}") - # logger.info(f"Detailed debugging information for failed matches has been written to {self.debug_file}") - def _calculate_representation_embeddings( self, double_agg: bool = True, @@ -576,81 +532,6 @@ def _calculate_representation_embeddings( logger.warning(f"Missing topics: {missing_topics}") logger.warning(f"Extra topics: {extra_topics}") - # def calculate_temporal_representation_stability(self, window_size: int = 2, k: int = 1) -> Tuple[pd.DataFrame, float]: - # """ - # Calculates the Temporal Representation Stability (TRS) scores for each topic. - - # Parameters: - # - window_size: Size of the window for temporal analysis. - # - k: Number of nearest neighbors for stability calculation. - - # Returns: - # - Tuple containing a DataFrame with TRS scores and the average TRS score. 
- # """ - # if window_size < 2: - # raise ValueError("window_size must be 2 or above.") - - # stability_scores = [] - # grouped_topics = self.representation_embeddings_df.groupby('Topic ID') - # all_topics = set(self.final_df['Topic'].unique()) - # processed_topics = set() - - # for topic_id, group in grouped_topics: - # processed_topics.add(topic_id) - # sorted_group = group.sort_values('Timestamp') - - # for i in range(len(sorted_group) - window_size + 1): - # start_row = sorted_group.iloc[i] - # end_row = sorted_group.iloc[i + window_size - 1] - - # start_embeddings = start_row['Representation Embeddings'] - # end_embeddings = end_row['Representation Embeddings'] - - # if len(start_embeddings) == 0 or len(end_embeddings) == 0: - # logger.warning(f"Skipping topic {topic_id} for timestamps {start_row['Timestamp']} to {end_row['Timestamp']} due to empty embeddings.") - # continue - - # similarity_scores = [] - # for start_embedding in start_embeddings: - # start_embedding = np.array(start_embedding).reshape(1, -1) - # end_embeddings_2d = np.array(end_embeddings).reshape(len(end_embeddings), -1) - - # cosine_similarities = cosine_similarity(start_embedding, end_embeddings_2d)[0] - # top_k_indices = cosine_similarities.argsort()[-k:][::-1] - # top_k_similarities = cosine_similarities[top_k_indices] - # similarity_scores.extend(top_k_similarities) - - # avg_similarity = np.mean(similarity_scores) - - # stability_scores.append({ - # 'Topic ID': topic_id, - # 'Start Timestamp': start_row['Timestamp'], - # 'End Timestamp': end_row['Timestamp'], - # 'Start Representation': start_row['Representation'], - # 'End Representation': end_row['Representation'], - # 'Representation Stability Score': avg_similarity - # }) - - # # Add empty entries for topics with no representation stability scores - # missing_topics = all_topics - processed_topics - # for topic_id in missing_topics: - # logger.warning(f"Topic {topic_id} has no representation stability scores. Adding empty entry.") - # stability_scores.append({ - # 'Topic ID': topic_id, - # 'Start Timestamp': None, - # 'End Timestamp': None, - # 'Start Representation': None, - # 'End Representation': None, - # 'Representation Stability Score': 0.0 # or np.nan if you prefer - # }) - - # self.representation_stability_scores_df = pd.DataFrame(stability_scores) - # self.avg_representation_stability_score = self.representation_stability_scores_df['Representation Stability Score'].mean() - - # logger.info(f"Calculated representation stability for {len(processed_topics)} topics. {len(missing_topics)} topics had no scores.") - - # return self.representation_stability_scores_df, self.avg_representation_stability_score - def calculate_temporal_representation_stability( self, window_size: int = 2, k: int = 1 ) -> Tuple[pd.DataFrame, float]: @@ -753,46 +634,6 @@ def calculate_temporal_representation_stability( self.avg_representation_stability_score, ) - # def calculate_topic_embedding_stability(self, window_size: int = 2) -> Tuple[pd.DataFrame, float]: - # """ - # Calculates the Temporal Topic Embedding Stability (TTES) scores for each topic. - - # Parameters: - # - window_size: Size of the window for temporal analysis. - - # Returns: - # - Tuple containing a DataFrame with TTES scores and the average TTES score. 
- # """ - # if window_size < 2: - # raise ValueError("window_size must be 2 or above.") - - # stability_scores = [] - # grouped_topics = self.final_df.groupby('Topic') - - # for topic_id, group in grouped_topics: - # sorted_group = group.sort_values('Timestamp') - - # for i in range(len(sorted_group) - window_size + 1): - # start_row = sorted_group.iloc[i] - # end_row = sorted_group.iloc[i + window_size - 1] - - # start_embedding = start_row['Embedding'] - # end_embedding = end_row['Embedding'] - - # similarity = cosine_similarity([start_embedding], [end_embedding])[0][0] - - # stability_scores.append({ - # 'Topic ID': topic_id, - # 'Start Timestamp': start_row['Timestamp'], - # 'End Timestamp': end_row['Timestamp'], - # 'Topic Stability Score': similarity - # }) - - # self.topic_stability_scores_df = pd.DataFrame(stability_scores) - # self.avg_topic_stability_score = self.topic_stability_scores_df['Topic Stability Score'].mean() - - # return self.topic_stability_scores_df, self.avg_topic_stability_score - def calculate_topic_embedding_stability( self, window_size: int = 2 ) -> Tuple[pd.DataFrame, float]: @@ -874,18 +715,15 @@ def calculate_topic_embedding_stability( def calculate_overall_topic_stability( self, window_size: int = 2, k: int = 1, alpha: float = 0.5 ) -> pd.DataFrame: - """ - Calculates the Overall Topic Stability (OTS) score by combining representation stability and embedding stability. - - Parameters: - - window_size: Size of the window for temporal analysis. - - k: Number of nearest neighbors for stability calculation. - - alpha: Weight for combining representation and + """Calculates the Overall Topic Stability (OTS) score by combining representation stability and embedding stability. - embedding stability. + Parameters: + - window_size: Size of the window for temporal analysis. + - k: Number of nearest neighbors for stability calculation. + - alpha: Weight for combining representation and embedding stability. - Returns: - - DataFrame containing the overall stability scores. + Returns: + - DataFrame containing the overall stability scores. 
""" ( representation_stability_df, diff --git a/bertrend/parameters.py b/bertrend/parameters.py index 247bb54..cbf3e32 100644 --- a/bertrend/parameters.py +++ b/bertrend/parameters.py @@ -8,6 +8,8 @@ import torch +from bertrend import PARAMETERS_CONFIG + stopwords_fr_file = Path(__file__).parent / "resources" / "stopwords-fr.json" stopwords_rte_file = Path(__file__).parent / "resources" / "stopwords-rte.json" common_ngrams_file = Path(__file__).parent / "resources" / "common_ngrams.json" @@ -50,19 +52,27 @@ ] # BERTopic Hyperparameters -DEFAULT_UMAP_N_COMPONENTS = 5 -DEFAULT_UMAP_N_NEIGHBORS = 5 -DEFAULT_HDBSCAN_MIN_CLUSTER_SIZE = 5 -DEFAULT_HDBSCAN_MIN_SAMPLES = 5 -DEFAULT_TOP_N_WORDS = 10 -DEFAULT_MIN_DF = 1 -DEFAULT_GRANULARITY = 2 -DEFAULT_MIN_SIMILARITY = 0.7 -DEFAULT_ZEROSHOT_MIN_SIMILARITY = 0.5 -BERTOPIC_SERIALIZATION = "safetensors" # or pickle -DEFAULT_MMR_DIVERSITY = 0.3 -DEFAULT_UMAP_MIN_DIST = 0.0 -OUTLIER_REDUCTION_STRATEGY = "c-tf-idf" # or "embeddings" +DEFAULT_UMAP_N_COMPONENTS = PARAMETERS_CONFIG["default_umap_n_components"] +DEFAULT_UMAP_N_NEIGHBORS = PARAMETERS_CONFIG["default_umap_n_neighbors"] +DEFAULT_HDBSCAN_MIN_CLUSTER_SIZE = PARAMETERS_CONFIG["default_hdbscan_min_cluster_size"] +DEFAULT_HDBSCAN_MIN_SAMPLES = PARAMETERS_CONFIG["default_hdbscan_min_samples"] +DEFAULT_TOP_N_WORDS = PARAMETERS_CONFIG["default_top_n_words"] +DEFAULT_MIN_DF = PARAMETERS_CONFIG["default_min_df"] +DEFAULT_GRANULARITY = PARAMETERS_CONFIG["default_granularity"] +DEFAULT_MIN_SIMILARITY = PARAMETERS_CONFIG["default_min_similarity"] +DEFAULT_ZEROSHOT_MIN_SIMILARITY = PARAMETERS_CONFIG["default_zeroshot_min_similarity"] +BERTOPIC_SERIALIZATION = PARAMETERS_CONFIG["bertopic_serialization"] +DEFAULT_MMR_DIVERSITY = PARAMETERS_CONFIG["default_mmr_diversity"] +DEFAULT_UMAP_MIN_DIST = PARAMETERS_CONFIG["default_umap_min_dist"] +OUTLIER_REDUCTION_STRATEGY = PARAMETERS_CONFIG["outlier_reduction_strategy"] + +# Signal classification Settings +SIGNAL_CLASSIF_LOWER_BOUND = PARAMETERS_CONFIG["signal_classif_lower_bound"] +SIGNAL_CLASSIF_UPPER_BOUND = PARAMETERS_CONFIG["signal_classif_upper_bound"] + +# Other Constants +DEFAULT_ZEROSHOT_TOPICS = PARAMETERS_CONFIG["default_zeroshot_topics"] + # Embedding Settings EMBEDDING_DTYPES = ["float32", "float16", "bfloat16"] @@ -75,10 +85,6 @@ HDBSCAN_CLUSTER_SELECTION_METHODS = ["eom", "leaf"] VECTORIZER_NGRAM_RANGES = [(1, 2), (1, 1), (2, 2)] -# GPT Model Settings -GPT_TEMPERATURE = 0.1 -GPT_SYSTEM_MESSAGE = "You are a helpful assistant, skilled in detailing topic evolution over time for the detection of emerging trends and signals." 
-GPT_MAX_TOKENS = 2048 # Data Processing MIN_CHARS_DEFAULT = 100 @@ -90,10 +96,3 @@ # Data Analysis Settings POPULARITY_THRESHOLD = 0.1 # for weak signal detection, if applicable - -# Signal classification Settings -SIGNAL_CLASSIF_LOWER_BOUND = 10 -SIGNAL_CLASSIF_UPPER_BOUND = 75 - -# Other Constants -DEFAULT_ZEROSHOT_TOPICS = "" # Empty string or a default list of topics diff --git a/bertrend/services/embedding_service.py b/bertrend/services/embedding_service.py index 3707264..0d698dc 100644 --- a/bertrend/services/embedding_service.py +++ b/bertrend/services/embedding_service.py @@ -13,6 +13,11 @@ from tqdm import tqdm from bertrend import EMBEDDING_CONFIG +from bertrend.parameters import ( + EMBEDDING_DEVICE, + EMBEDDING_BATCH_SIZE, + EMBEDDING_MAX_SEQ_LENGTH, +) class EmbeddingService: @@ -74,9 +79,9 @@ def _local_embed_documents( texts: List[str], embedding_model_name: str, embedding_dtype: str, - embedding_device: str = "cuda" if torch.cuda.is_available() else "cpu", - batch_size: int = 5000, - max_seq_length: int = 512, + embedding_device: str = EMBEDDING_DEVICE, + batch_size: int = EMBEDDING_BATCH_SIZE, + max_seq_length: int = EMBEDDING_MAX_SEQ_LENGTH, ) -> Tuple[SentenceTransformer, np.ndarray]: """ Embed a list of documents using a Sentence Transformer model. diff --git a/bertrend/trend_analysis/prompts.py b/bertrend/trend_analysis/prompts.py index 5e07a1e..f5bca2c 100644 --- a/bertrend/trend_analysis/prompts.py +++ b/bertrend/trend_analysis/prompts.py @@ -5,6 +5,8 @@ from pathlib import Path +from bertrend import OUTPUT_PATH + # Global variables for prompts SIGNAL_INTRO = { "en": """As an elite strategic foresight analyst with extensive expertise across multiple domains and industries, your task is to conduct a comprehensive evaluation of a potential signal derived from the following topic summary: @@ -234,9 +236,8 @@ def get_prompt( return prompt -# Function to parse the model's output and save as HTML -# FIXME: default path of file! 
 def save_html_output(model_output, output_file="signal_llm.html"):
+    """Parse the model's output and save it as an HTML file."""
     # Clean the HTML content
     cleaned_html = model_output.strip()  # Remove leading/trailing whitespace
@@ -252,7 +253,7 @@ def save_html_output(model_output, output_file="signal_llm.html"):
     # Final strip to remove any remaining whitespace
     cleaned_html = cleaned_html.strip()
 
-    output_path = Path(__file__).parent / output_file
+    output_path = OUTPUT_PATH / output_file
 
     # Save the cleaned HTML
     with open(output_path, "w", encoding="utf-8") as file:
diff --git a/bertrend/trend_analysis/weak_signals.py b/bertrend/trend_analysis/weak_signals.py
index 8331b0d..ae36e99 100644
--- a/bertrend/trend_analysis/weak_signals.py
+++ b/bertrend/trend_analysis/weak_signals.py
@@ -17,9 +17,6 @@
 from bertrend.llm_utils.openai_client import OpenAI_Client
 from bertrend.parameters import (
-    GPT_TEMPERATURE,
-    GPT_SYSTEM_MESSAGE,
-    GPT_MAX_TOKENS,
     SIGNAL_CLASSIF_LOWER_BOUND,
     SIGNAL_CLASSIF_UPPER_BOUND,
 )
 
@@ -547,10 +544,10 @@ def analyze_signal(
         content_summary=content_summary,
     )
     summary = openai_client.generate(
-        system_prompt=GPT_SYSTEM_MESSAGE,
+        system_prompt=LLM_CONFIG["system_prompt"],
         user_prompt=summary_prompt,
-        temperature=GPT_TEMPERATURE,
-        max_tokens=GPT_MAX_TOKENS,
+        temperature=LLM_CONFIG["temperature"],
+        max_tokens=LLM_CONFIG["max_tokens"],
     )
 
     # Second prompt: Analyze weak signal
@@ -559,10 +556,10 @@ def analyze_signal(
         language, "weak_signal", summary_from_first_prompt=summary
     )
     weak_signal_analysis = openai_client.generate(
-        system_prompt=GPT_SYSTEM_MESSAGE,
+        system_prompt=LLM_CONFIG["system_prompt"],
         user_prompt=weak_signal_prompt,
-        temperature=GPT_TEMPERATURE,
-        max_tokens=GPT_MAX_TOKENS,
+        temperature=LLM_CONFIG["temperature"],
+        max_tokens=LLM_CONFIG["max_tokens"],
     )
 
     # Third prompt: Generate HTML format
@@ -574,10 +571,10 @@ def analyze_signal(
         weak_signal_analysis=weak_signal_analysis,
     )
     formatted_html = openai_client.generate(
-        system_prompt=GPT_SYSTEM_MESSAGE,
+        system_prompt=LLM_CONFIG["system_prompt"],
         user_prompt=html_format_prompt,
-        temperature=GPT_TEMPERATURE,
-        max_tokens=GPT_MAX_TOKENS,
+        temperature=LLM_CONFIG["temperature"],
+        max_tokens=LLM_CONFIG["max_tokens"],
     )
 
     # Save the formatted HTML
diff --git a/bertrend_apps/newsletters/__main__.py b/bertrend_apps/newsletters/__main__.py
index b06916e..a883eca 100644
--- a/bertrend_apps/newsletters/__main__.py
+++ b/bertrend_apps/newsletters/__main__.py
@@ -23,6 +23,7 @@
 from umap import UMAP
 
 from bertrend import FEED_BASE_PATH, BEST_CUDA_DEVICE, OUTPUT_PATH
+from bertrend.parameters import BERTOPIC_SERIALIZATION
 from bertrend.utils.config_utils import load_toml_config
 from bertrend.utils.data_loading import (
     split_df_by_paragraphs,
@@ -266,7 +267,7 @@ def _save_topic_model(
     # Serialization using safetensors
     topic_model.save(
         full_model_path_dir,
-        serialization="safetensors",
+        serialization=BERTOPIC_SERIALIZATION,
         save_ctfidf=True,
         save_embedding_model=embedding_model,
    )
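
---

For reviewers, the thread running through this patch: hard-coded constants (the GPT_* settings, BERTOPIC_SERIALIZATION, the signal classification bounds) move into the new [parameters] section and the extended [llm_service] section of bertrend.toml, and call sites read them via PARAMETERS_CONFIG and LLM_CONFIG. The sketch below traces that flow end to end. It is a minimal approximation, not the project's code: it assumes Python 3.11+ (tomllib), and it assumes load_toml_config — defined in bertrend.utils.config_utils, which this patch does not show — expands "$VAR" placeholders from the environment, as the $AZURE_* values in bertrend.toml suggest.

import os
import tomllib  # stdlib TOML reader, Python 3.11+
from pathlib import Path

# Path of the TOML file shipped with the package (bertrend/bertrend.toml).
BERTREND_CONFIG_PATH = Path("bertrend") / "bertrend.toml"


def load_toml_config(path: Path) -> dict:
    """Read a TOML config, expanding "$VAR" placeholders from the environment.

    Hypothetical stand-in for bertrend.utils.config_utils.load_toml_config;
    the real helper may behave differently.
    """
    with open(path, "rb") as f:
        config = tomllib.load(f)
    for section in config.values():
        if not isinstance(section, dict):
            continue
        for key, value in section.items():
            if isinstance(value, str) and value.startswith("$"):
                section[key] = os.environ.get(value[1:], "")
    return config


# Mirrors bertrend/__init__.py after this patch:
BERTREND_CONFIG = load_toml_config(BERTREND_CONFIG_PATH)
PARAMETERS_CONFIG = BERTREND_CONFIG["parameters"]
LLM_CONFIG = BERTREND_CONFIG["llm_service"]

# bertrend/parameters.py then maps TOML keys to module-level constants:
BERTOPIC_SERIALIZATION = PARAMETERS_CONFIG["bertopic_serialization"]  # "safetensors"
DEFAULT_GRANULARITY = PARAMETERS_CONFIG["default_granularity"]  # 2

# Call sites such as analyze_signal() now read LLM settings from LLM_CONFIG
# instead of the removed GPT_* constants:
#   openai_client.generate(
#       system_prompt=LLM_CONFIG["system_prompt"],
#       temperature=LLM_CONFIG["temperature"],
#       max_tokens=LLM_CONFIG["max_tokens"],
#       ...
#   )

One consequence worth keeping in mind during review: because the constants in bertrend/parameters.py are resolved at import time, any module importing it now requires bertrend.toml to be present and well-formed, where previously the values were self-contained in the source.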