Skip to content

Commit

Permalink
Refactor: temptopic visualizations
Browse files Browse the repository at this point in the history
  • Loading branch information
picaultj committed Dec 31, 2024
1 parent 020a994 commit d9c1664
Show file tree
Hide file tree
Showing 3 changed files with 282 additions and 265 deletions.
31 changes: 25 additions & 6 deletions bertrend/demos/demos_utils/data_loading_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
TEXT_COLUMN,
load_data,
split_data,
TIMESTAMP_COLUMN,
)

NO_DATASET_WARNING = "Please select at least one dataset to proceed."
Expand Down Expand Up @@ -82,7 +83,18 @@ def _load_files(
def display_data_loading_component():
"""
Component for a Streamlit app about topic modelling. It lets the user choose which data to load and preprocesses that data.
The final dataframe is stored inside the Streamlit state variable "time_filtered_df"
Preprocessing of data includes:
- concatenation of data from different files
- adding potentially missing columns to make all datasets homogeneous
- removal of duplicates
- splitting the dataset by paragraphs to avoid overly long texts
- filtering based on timestamp range
- filtering based on a minimum number of characters
The initial dataframe (one line per document) is stored after filtering of bad data inside a Streamlit state
variable "initial_df". After splitting by paragraph, it is stored in "split_df". The final dataframe (possibly
split by paragraph from initial documents and in all cases filtered by dates) is stored inside the Streamlit
state variable "time_filtered_df".
"""
# Find files in the current directory and subdirectories
tab1, tab2 = st.tabs(["Data from local storage", "Data from server data"])
Expand Down Expand Up @@ -173,10 +185,13 @@ def display_data_loading_component():
# Deduplicate using all columns
df = df.drop_duplicates()

# Save state of split dataframe (before time-based filtering)
st.session_state["split_df"] = df.copy()

col1, col2 = st.columns([0.8, 0.2])
with col1:
# Select timeframe
min_date, max_date = df["timestamp"].dt.date.agg(["min", "max"])
min_date, max_date = df[TIMESTAMP_COLUMN].dt.date.agg(["min", "max"])
register_widget("timeframe_slider")
start_date, end_date = st.slider(
"Select Timeframe",
Expand All @@ -189,10 +204,12 @@ def display_data_loading_component():

# Filter and sample the DataFrame
df_filtered = df[
(df["timestamp"].dt.date >= start_date)
& (df["timestamp"].dt.date <= end_date)
(df[TIMESTAMP_COLUMN].dt.date >= start_date)
& (df[TIMESTAMP_COLUMN].dt.date <= end_date)
]
df_filtered = df_filtered.sort_values(by="timestamp").reset_index(drop=True)
df_filtered = df_filtered.sort_values(by=TIMESTAMP_COLUMN).reset_index(
drop=True
)

with col2:
register_widget("sample_size")
Expand All @@ -208,7 +225,9 @@ def display_data_loading_component():
if sample_size < len(df_filtered):
df_filtered = df_filtered.sample(n=sample_size, random_state=42)

df_filtered = df_filtered.sort_values(by="timestamp").reset_index(drop=True)
df_filtered = df_filtered.sort_values(by=TIMESTAMP_COLUMN).reset_index(
drop=True
)

SessionStateManager.set("time_filtered_df", df_filtered)
st.write(
Expand Down
Loading

0 comments on commit d9c1664

Please sign in to comment.