Commit

Updated save_models() and restore_models()
picaultj committed Dec 22, 2024
1 parent 12eed54 commit 8a916af
Showing 24 changed files with 396 additions and 381 deletions.
156 changes: 51 additions & 105 deletions bertrend/BERTrend.py
@@ -5,6 +5,7 @@
import pickle
import shutil
from collections import defaultdict
from pathlib import Path
from typing import Dict, Tuple, List, Any

import numpy as np
@@ -14,10 +15,7 @@
from sentence_transformers import SentenceTransformer

from bertrend import MODELS_DIR, CACHE_PATH
from bertrend.demos.weak_signals.messages import (
NO_GRANULARITY_WARNING,
)
from bertrend.demos.weak_signals.session_state_manager import SessionStateManager

from bertrend.topic_model import TopicModel
from bertrend.parameters import (
DEFAULT_MIN_SIMILARITY,
@@ -29,6 +27,7 @@
EMB_GROUPS_FILE,
GRANULARITY_FILE,
HYPERPARAMS_FILE,
BERTOPIC_SERIALIZATION,
)
from bertrend.trend_analysis.topic_modeling import preprocess_model, merge_models
from bertrend.trend_analysis.weak_signals import (
@@ -62,6 +61,7 @@ def __init__(
)
self.zeroshot_topic_list = zeroshot_topic_list
self.zeroshot_min_similarity = zeroshot_min_similarity
self.granularity = DEFAULT_GRANULARITY

# State variables of BERTrend
self._is_fitted = False
@@ -429,17 +429,16 @@ def calculate_signal_popularity(
self.topic_last_popularity = topic_last_popularity
self.topic_last_update = topic_last_update

def save_models(self):
if MODELS_DIR.exists():
shutil.rmtree(MODELS_DIR)
MODELS_DIR.mkdir(parents=True, exist_ok=True)
def save_models(self, models_path: Path = MODELS_DIR):
if models_path.exists():
shutil.rmtree(models_path)
models_path.mkdir(parents=True, exist_ok=True)

# TODO
"""
# Save topic models using the selected serialization type
for period, topic_model in self.topic_models.items():
model_dir = MODELS_DIR / period.strftime("%Y-%m-%d")
model_dir = models_path / period.strftime("%Y-%m-%d")
model_dir.mkdir(exist_ok=True)
embedding_model = SessionStateManager.get("embedding_model")
embedding_model = topic_model.embedding_model
topic_model.save(
model_dir,
serialization=BERTOPIC_SERIALIZATION,
@@ -449,44 +448,54 @@ def save_models(self):

topic_model.doc_info_df.to_pickle(model_dir / DOC_INFO_DF_FILE)
topic_model.topic_info_df.to_pickle(model_dir / TOPIC_INFO_DF_FILE)
"""

# Save topic model parameters
with open(CACHE_PATH / HYPERPARAMS_FILE, "wb") as f:
pickle.dump(self.topic_model_parameters, f)
# Save granularity file
with open(CACHE_PATH / GRANULARITY_FILE, "wb") as f:
pickle.dump(self.granularity, f)
# Save doc_groups file
with open(CACHE_PATH / DOC_GROUPS_FILE, "wb") as f:
pickle.dump(self.doc_groups, f)
# Save emb_groups file
with open(CACHE_PATH / EMB_GROUPS_FILE, "wb") as f:
pickle.dump(self.emb_groups, f)

# FIXME: granularity currently not set at this stage
# with open(CACHE_PATH / GRANULARITY_FILE, "wb") as f:
# pickle.dump(self.granularity)

# Save the models_trained flag
with open(CACHE_PATH / MODELS_TRAINED_FILE, "wb") as f:
pickle.dump(self._is_fitted, f)

# TODO!
"""
hyperparams = SessionStateManager.get_multiple(
"umap_n_components",
"umap_n_neighbors",
"hdbscan_min_cluster_size",
"hdbscan_min_samples",
"hdbscan_cluster_selection_method",
"top_n_words",
"vectorizer_ngram_range",
"min_df",
)
with open(CACHE_PATH / HYPERPARAMS_FILE, "wb") as f:
pickle.dump(hyperparams, f)
"""
logger.info(f"Models saved to: {models_path}")

@classmethod
def restore_models(cls):
if not MODELS_DIR.exists():
raise FileNotFoundError(f"MODELS_DIR={MODELS_DIR} does not exist")

def restore_models(cls, models_path: Path = MODELS_DIR) -> "BERTrend":
if not models_path.exists():
raise FileNotFoundError(f"models_path={models_path} does not exist")

logger.info(f"Loading models from: {models_path}")

# Create BERTrend object
bertrend = cls()

# load topic model parameters
with open(CACHE_PATH / HYPERPARAMS_FILE, "rb") as f:
bertrend.topic_model_parameters = pickle.load(f)
# load granularity file
with open(CACHE_PATH / GRANULARITY_FILE, "rb") as f:
bertrend.granularity = pickle.load(f)
# load doc_groups file
with open(CACHE_PATH / DOC_GROUPS_FILE, "rb") as f:
bertrend.doc_groups = pickle.load(f)
# load emb_groups file
with open(CACHE_PATH / EMB_GROUPS_FILE, "rb") as f:
bertrend.emb_groups = pickle.load(f)
# load the models_trained flag
with open(CACHE_PATH / MODELS_TRAINED_FILE, "rb") as f:
bertrend._is_fitted = pickle.load(f)

# Restore topic models using the selected serialization type
topic_models = {}
for period_dir in MODELS_DIR.iterdir():
for period_dir in models_path.iterdir():
if period_dir.is_dir():
topic_model = BERTopic.load(period_dir)

@@ -502,72 +511,9 @@ def restore_models(cls):

period = pd.Timestamp(period_dir.name.replace("_", ":"))
topic_models[period] = topic_model
bertrend.topic_models = topic_models

SessionStateManager.set("topic_models", topic_models)
return bertrend

for file, key in [
(DOC_GROUPS_FILE, "doc_groups"),
(EMB_GROUPS_FILE, "emb_groups"),
]:
file_path = CACHE_PATH / file
if file_path.exists():
with open(file_path, "rb") as f:
SessionStateManager.set(key, pickle.load(f))
else:
logger.warning(f"{file} not found.")

granularity_file = CACHE_PATH / GRANULARITY_FILE
if granularity_file.exists():
with open(granularity_file, "rb") as f:
SessionStateManager.set("granularity_select", pickle.load(f))
else:
logger.warning(NO_GRANULARITY_WARNING)

# Restore the models_trained flag
models_trained_file = CACHE_PATH / MODELS_TRAINED_FILE
if models_trained_file.exists():
with open(models_trained_file, "rb") as f:
# FIXME! set bertrend first!
SessionStateManager.set("models_trained", pickle.load(f))
else:
logger.warning("Models trained flag not found.")

hyperparams_file = CACHE_PATH / HYPERPARAMS_FILE
if hyperparams_file.exists():
with open(hyperparams_file, "rb") as f:
SessionStateManager.set_multiple(**pickle.load(f))
else:
logger.warning("Hyperparameters file not found.")

#####################################################################################################
# FIXME: WIP
# def merge_models2(self):
# if not self._is_fitted:
# raise RuntimeError("You must fit the BERTrend model before merging models.")
#
# merged_data = self._initialize_merge_data()
#
# logger.info("Merging models...")
# for timestamp, model in self.topic_models.items():
# if not merged_data:
# merged_data = self._process_first_model(model)
# else:
# merged_data = self._merge_with_existing_data(
# merged_data, model, timestamp
# )
#
# self.merged_topics = merged_data
#
# def _merge_with_existing_data(
# self, merged_data: Dict, model: BERTopic, timestamp: pd.Timestamp
# ) -> Dict:
# # Extract topics and embeddings
#
# # Compute similarity between current model's topics and the merged ones
#
# # Update merged_data with this model's data based on computed similarities
# # Implement business logic to handle merging decisions
# # This can involve thresholding, updating topic IDs, and merging document and metadata entries
#
# # return merged_data # Return the updated merged data
# pass

# TODO: methods for prospective analysis (handle topic generation step by step)
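
In short, save_models() and restore_models() now take an explicit models_path argument (defaulting to MODELS_DIR) instead of relying on MODELS_DIR and Streamlit session state, and restore_models() is a classmethod that returns a rebuilt BERTrend instance. A minimal usage sketch follows; the already-fitted bertrend instance and the target directory are hypothetical, and the auxiliary pickles still go to CACHE_PATH as in the diff above.

from pathlib import Path

from bertrend.BERTrend import BERTrend

# Hypothetical target directory for this example; any writable path works.
models_path = Path("/tmp/bertrend_models")

# Persist the shared state (hyperparameters, granularity, doc/emb groups, fitted flag)
# to CACHE_PATH; per-period BERTopic saving is still commented out as a TODO in this commit.
bertrend.save_models(models_path=models_path)

# Rebuild a BERTrend object: reads the CACHE_PATH pickles and loads one BERTopic
# model per period subdirectory found under models_path.
restored = BERTrend.restore_models(models_path=models_path)
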
4 changes: 4 additions & 0 deletions bertrend/demos/demos_utils/__init__.py
@@ -0,0 +1,4 @@
# Copyright (c) 2024, RTE (https://www.rte-france.com)
# See AUTHORS.txt
# SPDX-License-Identifier: MPL-2.0
# This file is part of BERTrend.
109 changes: 109 additions & 0 deletions bertrend/demos/demos_utils/data_loading_component.py
@@ -0,0 +1,109 @@
# Copyright (c) 2024, RTE (https://www.rte-france.com)
# See AUTHORS.txt
# SPDX-License-Identifier: MPL-2.0
# This file is part of BERTrend.
import pandas as pd
import streamlit as st

from bertrend import DATA_PATH
from bertrend.demos.demos_utils.session_state_manager import SessionStateManager
from bertrend.parameters import MIN_CHARS_DEFAULT, SAMPLE_SIZE_DEFAULT
from bertrend.utils.data_loading import (
find_compatible_files,
load_and_preprocess_data,
TEXT_COLUMN,
)

NO_DATASET_WARNING = "Please select at least one dataset to proceed."


def display_data_loading_component():
# Find files in the current directory and subdirectories
compatible_extensions = ["csv", "parquet", "json", "jsonl"]
selected_files = st.multiselect(
"Select one or more datasets",
find_compatible_files(DATA_PATH, compatible_extensions),
default=SessionStateManager.get("selected_files", []),
key="selected_files",
)

if not selected_files:
st.warning(NO_DATASET_WARNING)
return

# Display number input and checkbox for preprocessing options
col1, col2 = st.columns(2)
with col1:
min_chars = st.number_input(
"Minimum Characters",
value=MIN_CHARS_DEFAULT,
min_value=0,
max_value=1000,
key="min_chars",
)
with col2:
split_by_paragraph = st.checkbox(
"Split text by paragraphs", value=False, key="split_by_paragraph"
)

# Load and preprocess each selected file, then concatenate them
dfs = []
for selected_file, ext in selected_files:
file_path = DATA_PATH / selected_file
df = load_and_preprocess_data(
(file_path, ext),
st.session_state["language"],
min_chars,
split_by_paragraph,
)
dfs.append(df)

if not dfs:
st.warning(
"No data available after preprocessing. Please check the selected files and preprocessing options."
)
else:
df = pd.concat(dfs, ignore_index=True)

# Deduplicate using all columns
df = df.drop_duplicates()

# Select timeframe
min_date, max_date = df["timestamp"].dt.date.agg(["min", "max"])
start_date, end_date = st.slider(
"Select Timeframe",
min_value=min_date,
max_value=max_date,
value=(min_date, max_date),
key="timeframe_slider",
)

# Filter and sample the DataFrame
df_filtered = df[
(df["timestamp"].dt.date >= start_date)
& (df["timestamp"].dt.date <= end_date)
]
df_filtered = df_filtered.sort_values(by="timestamp").reset_index(drop=True)

sample_size = st.number_input(
"Sample Size",
value=SAMPLE_SIZE_DEFAULT or len(df_filtered),
min_value=1,
max_value=len(df_filtered),
key="sample_size",
)
if sample_size < len(df_filtered):
df_filtered = df_filtered.sample(n=sample_size, random_state=42)

df_filtered = df_filtered.sort_values(by="timestamp").reset_index(drop=True)

SessionStateManager.set("timefiltered_df", df_filtered)
st.write(
f"Number of documents in selected timeframe: {len(SessionStateManager.get_dataframe('timefiltered_df'))}"
)
st.dataframe(
SessionStateManager.get_dataframe("timefiltered_df")[
[TEXT_COLUMN, "timestamp"]
],
use_container_width=True,
)
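
For context, a sketch of how this new component might be wired into a Streamlit demo page; the page title, the default language, and the post-processing of the filtered DataFrame are illustrative assumptions, not part of this commit.

import streamlit as st

from bertrend.demos.demos_utils.data_loading_component import (
    display_data_loading_component,
)
from bertrend.demos.demos_utils.session_state_manager import SessionStateManager

st.title("Data selection")  # hypothetical page title

# The component reads st.session_state["language"], so set it before calling it.
if "language" not in st.session_state:
    st.session_state["language"] = "en"  # assumed default

display_data_loading_component()

# After the user selects files and a timeframe, the filtered data sits in session state.
df = SessionStateManager.get("timefiltered_df", None)
if df is not None:
    st.success(f"{len(df)} documents ready for topic modeling")
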
