Skip to content

Commit

Permalink
Minor layout changes
Browse files Browse the repository at this point in the history
  • Loading branch information
picaultj committed Dec 31, 2024
1 parent e52c943 commit 020a994
Show file tree
Hide file tree
Showing 6 changed files with 136 additions and 112 deletions.
58 changes: 31 additions & 27 deletions bertrend/demos/demos_utils/data_loading_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,34 +173,38 @@ def display_data_loading_component():
# Deduplicate using all columns
df = df.drop_duplicates()

# Select timeframe
min_date, max_date = df["timestamp"].dt.date.agg(["min", "max"])
register_widget("timeframe_slider")
start_date, end_date = st.slider(
"Select Timeframe",
min_value=min_date,
max_value=max_date,
value=(min_date, max_date),
key="timeframe_slider",
on_change=save_widget_state,
)

# Filter and sample the DataFrame
df_filtered = df[
(df["timestamp"].dt.date >= start_date)
& (df["timestamp"].dt.date <= end_date)
]
df_filtered = df_filtered.sort_values(by="timestamp").reset_index(drop=True)
col1, col2 = st.columns([0.8, 0.2])
with col1:
# Select timeframe
min_date, max_date = df["timestamp"].dt.date.agg(["min", "max"])
register_widget("timeframe_slider")
start_date, end_date = st.slider(
"Select Timeframe",
min_value=min_date,
max_value=max_date,
value=(min_date, max_date),
key="timeframe_slider",
on_change=save_widget_state,
)

# Filter and sample the DataFrame
df_filtered = df[
(df["timestamp"].dt.date >= start_date)
& (df["timestamp"].dt.date <= end_date)
]
df_filtered = df_filtered.sort_values(by="timestamp").reset_index(drop=True)

with col2:
register_widget("sample_size")
sample_size = st.number_input(
"Sample Size",
value=SAMPLE_SIZE_DEFAULT or len(df_filtered),
min_value=1,
max_value=len(df_filtered),
key="sample_size",
on_change=save_widget_state,
)

register_widget("sample_size")
sample_size = st.number_input(
"Sample Size",
value=SAMPLE_SIZE_DEFAULT or len(df_filtered),
min_value=1,
max_value=len(df_filtered),
key="sample_size",
on_change=save_widget_state,
)
if sample_size < len(df_filtered):
df_filtered = df_filtered.sample(n=sample_size, random_state=42)

Expand Down
5 changes: 5 additions & 0 deletions bertrend/demos/demos_utils/state_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,11 @@ def register_widget(key):
st.session_state[STATE_KEYS].append(key)


def register_multiple_widget(*keys):
    """Register several widget keys in a single call.

    Every positional argument is forwarded, in the order given, to
    ``register_widget``. Calling this with no arguments does nothing.
    """
    for widget_key in keys:
        register_widget(widget_key)


def save_widget_state():
if STATE_KEYS in st.session_state.keys():
st.session_state[WIDGET_STATE] = {
Expand Down
20 changes: 7 additions & 13 deletions bertrend/demos/topic_analysis/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,11 @@
from bertrend.utils.data_loading import TEXT_COLUMN, TIMESTAMP_COLUMN


def data_overview(df: pd.DataFrame):
with st.container(border=True):
col1, col2 = st.columns([0.4, 0.6])
def data_distribution(df: pd.DataFrame):
with st.expander(
label="Data distribution",
expanded=False,
):
freq = st.select_slider(
"Time aggregation",
options=(
Expand All @@ -34,16 +36,8 @@ def data_overview(df: pd.DataFrame):
),
value="1M",
)
with col1:
fig = plot_docs_repartition_over_time(df, freq)
st.plotly_chart(
fig, config=PLOTLY_BUTTON_SAVE_CONFIG, use_container_width=True
)
with col2:
st.dataframe(
st.session_state["time_filtered_df"][[TEXT_COLUMN, TIMESTAMP_COLUMN]],
use_container_width=True,
)
fig = plot_docs_repartition_over_time(df, freq)
st.plotly_chart(fig, config=PLOTLY_BUTTON_SAVE_CONFIG, use_container_width=True)


def choose_data(base_dir: Path, filters: List[str]):
Expand Down
122 changes: 61 additions & 61 deletions bertrend/demos/topic_analysis/demo_pages/explore_topics.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,9 +104,9 @@ def display_topic_info():
]["Representation"]

st.write(
f"# Topic {st.session_state['selected_topic_number']} : {topic_docs_number} documents"
f"## Topic {st.session_state['selected_topic_number']} : {topic_docs_number} documents"
)
st.markdown(f"## #{' #'.join(topic_words)}")
st.markdown(f"### #{' #'.join(topic_words)}")


def plot_topic_over_time():
Expand Down Expand Up @@ -232,7 +232,6 @@ def display_representative_documents(filtered_df):
st.link_button(content, doc.url)
else:
st.markdown(content)
st.divider()


def display_new_documents():
Expand Down Expand Up @@ -283,57 +282,76 @@ def create_topic_documents(
return folder_name, documents


def _display_topic_description(filtered_df):
    """Render a button that generates and displays a short topic description.

    Only the button is drawn until it is clicked. On click, a description of
    the topic stored under ``st.session_state["selected_topic_number"]`` is
    requested from ``generate_topic_description`` (using the model stored
    under ``st.session_state["topic_model"]`` and ``filtered_df``), and the
    result is rendered as markdown inside a bordered container.
    """
    clicked = st.button(
        "Generate a short description of the topic",
        type="primary",
        use_container_width=True,
    )
    if not clicked:
        return

    with st.spinner("Génération de la description en cours..."):
        # The description language follows the "language" session setting:
        # "French" maps to "fr", anything else falls back to "en".
        if SessionStateManager.get("language") == "French":
            language_code = "fr"
        else:
            language_code = "en"
        gpt_description = generate_topic_description(
            st.session_state["topic_model"],
            st.session_state["selected_topic_number"],
            filtered_df,
            language_code=language_code,
        )

    # NOTE(review): the container is assumed to be drawn once the spinner
    # block has finished -- confirm against the original indentation.
    with st.container(border=True):
        st.markdown(gpt_description)


def main():
"""Main function to run the Streamlit topic_analysis."""
check_model_and_prepare_topics()

st.title("Topics exploration")

display_sidebar()

if "selected_topic_number" not in st.session_state:
st.stop()

display_topic_info()
plot_topic_over_time()

st.divider()

# Number of articles to display
top_n_docs = st.number_input(
"Number of articles to display",
min_value=1,
max_value=st.session_state["topics_info"].iloc[
st.session_state["selected_topic_number"]
]["Count"],
value=st.session_state["topics_info"].iloc[
st.session_state["selected_topic_number"]
]["Count"],
step=1,
)

representative_df = get_representative_documents(top_n_docs)
representative_df = representative_df.sort_values(by="timestamp", ascending=False)
col1, col2 = st.columns([0.3, 0.7])
with col1:
display_topic_info()
with col2:
plot_topic_over_time()

# Get unique sources
sources = representative_df[URL_COLUMN].apply(get_website_name).unique()
col1, col2 = st.columns(2)
with col1:
# Number of articles to display
top_n_docs = st.number_input(
"Number of articles to display",
min_value=1,
max_value=st.session_state["topics_info"].iloc[
st.session_state["selected_topic_number"]
]["Count"],
value=st.session_state["topics_info"].iloc[
st.session_state["selected_topic_number"]
]["Count"],
step=1,
)
with col2:
representative_df = get_representative_documents(top_n_docs)
representative_df = representative_df.sort_values(
by="timestamp", ascending=False
)

# Multi-select for sources
selected_sources = st.multiselect(
"Select the sources to display",
options=["All"] + list(sources),
default=["All"],
)
# Get unique sources
sources = representative_df[URL_COLUMN].apply(get_website_name).unique()

"""
# Filter the dataframe based on selected sources
if "All" not in selected_sources:
filtered_df = representative_df[
representative_df[URL_COLUMN].apply(get_website_name).isin(selected_sources)
]
else:
filtered_df = representative_df
"""
# Multi-select for sources
selected_sources = st.multiselect(
"Select the sources to display",
options=["All"] + list(sources),
default=["All"],
)

# Create two columns
col1, col2 = st.columns([0.5, 0.5])
col1, col2 = st.columns([0.3, 0.7])

with col1:
# Pass the full representative_df to display_source_distribution
Expand All @@ -353,26 +371,8 @@ def main():

display_new_documents()

st.divider()

# GPT description button
if st.button(
"Generate a short description of the topic",
type="primary",
use_container_width=True,
):
with st.spinner("Génération de la description en cours..."):
language_code = (
"fr" if SessionStateManager.get("language") == "French" else "en"
)
gpt_description = generate_topic_description(
st.session_state["topic_model"],
st.session_state["selected_topic_number"],
filtered_df,
language_code=language_code,
)
with st.container(border=True):
st.markdown(gpt_description)
# GPT generated topic description
_display_topic_description(filtered_df)

st.divider()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
register_widget,
save_widget_state,
restore_widget_state,
register_multiple_widget,
)
from bertrend.demos.weak_signals.visualizations_utils import PLOTLY_BUTTON_SAVE_CONFIG
from bertrend.utils.data_loading import TIMESTAMP_COLUMN, TEXT_COLUMN
Expand Down Expand Up @@ -411,30 +412,48 @@ def format_timedelta(td):

def select_time_granularity(max_granularity):
"""Allow user to select custom time granularity within limits."""
st.write("Select custom time granularity:")
st.write("Select custom time granularity")
col1, col2, col3, col4 = st.columns(4)

max_days = max_granularity.days

register_multiple_widget(
"granularity_days", "granularity_hours", "granularity_minutes"
)
with col1:
days = st.number_input(
days = st.slider(
"Days",
min_value=0,
max_value=max_days,
value=min(1, max_days),
key="granularity_days",
on_change=save_widget_state,
)
with col2:
hours = st.number_input(
"Hours", min_value=0, max_value=23, value=0, key="granularity_hours"
hours = st.slider(
"Hours",
min_value=0,
max_value=23,
value=0,
key="granularity_hours",
on_change=save_widget_state,
)
with col3:
minutes = st.number_input(
"Minutes", min_value=0, max_value=59, value=0, key="granularity_minutes"
minutes = st.slider(
"Minutes",
min_value=0,
max_value=59,
value=0,
key="granularity_minutes",
on_change=save_widget_state,
)
with col4:
seconds = st.number_input(
"Seconds", min_value=0, max_value=59, value=0, key="granularity_seconds"
seconds = st.slider(
"Seconds",
min_value=0,
max_value=59,
value=0,
key="granularity_seconds",
on_change=save_widget_state,
)

selected_granularity = timedelta(
Expand Down Expand Up @@ -743,6 +762,8 @@ def main():
# Check if model is trained
check_model_trained()

st.title("Temporal visualizations of topics")

# Display sidebar
display_sidebar()

Expand Down
4 changes: 2 additions & 2 deletions bertrend/demos/topic_analysis/demo_pages/training_page.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
display_bertopic_hyperparameters,
)
from bertrend.demos.demos_utils.session_state_manager import SessionStateManager
from bertrend.demos.topic_analysis.data_utils import data_overview
from bertrend.demos.topic_analysis.data_utils import data_distribution
from bertrend.metrics.topic_metrics import compute_cluster_metrics
from bertrend.parameters import BERTOPIC_SERIALIZATION
from bertrend.topic_model import TopicModel
Expand Down Expand Up @@ -172,7 +172,7 @@ def main():
# Data overview
if "time_filtered_df" not in st.session_state:
st.stop()
data_overview(st.session_state["time_filtered_df"])
data_distribution(st.session_state["time_filtered_df"])
SessionStateManager.set("split_type", st.session_state["split_by_paragraph"])

# Embed documents
Expand Down

0 comments on commit 020a994

Please sign in to comment.