diff --git a/bertrend/demos/topic_analysis/demo_pages/explore_topics.py b/bertrend/demos/topic_analysis/demo_pages/explore_topics.py index 5af2993..dc716e1 100644 --- a/bertrend/demos/topic_analysis/demo_pages/explore_topics.py +++ b/bertrend/demos/topic_analysis/demo_pages/explore_topics.py @@ -283,6 +283,27 @@ def create_topic_documents( return folder_name, documents +def _display_topic_description(filtered_df): + # GPT description button + if st.button( + "Generate a short description of the topic", + type="primary", + use_container_width=True, + ): + with st.spinner("Génération de la description en cours..."): + language_code = ( + "fr" if SessionStateManager.get("language") == "French" else "en" + ) + gpt_description = generate_topic_description( + st.session_state["topic_model"], + st.session_state["selected_topic_number"], + filtered_df, + language_code=language_code, + ) + with st.container(border=True): + st.markdown(gpt_description) + + def main(): """Main function to run the Streamlit topic_analysis.""" check_model_and_prepare_topics() @@ -291,8 +312,11 @@ def main(): if "selected_topic_number" not in st.session_state: st.stop() - display_topic_info() - plot_topic_over_time() + col1, col2 = st.columns([0.4, 0.6]) + with col1: + display_topic_info() + with col2: + plot_topic_over_time() st.divider() @@ -322,18 +346,8 @@ def main(): default=["All"], ) - """ - # Filter the dataframe based on selected sources - if "All" not in selected_sources: - filtered_df = representative_df[ - representative_df[URL_COLUMN].apply(get_website_name).isin(selected_sources) - ] - else: - filtered_df = representative_df - """ - # Create two columns - col1, col2 = st.columns([0.5, 0.5]) + col1, col2 = st.columns([0.4, 0.6]) with col1: # Pass the full representative_df to display_source_distribution @@ -355,24 +369,8 @@ def main(): st.divider() - # GPT description button - if st.button( - "Generate a short description of the topic", - type="primary", - use_container_width=True, - ): - with st.spinner("Génération de la description en cours..."): - language_code = ( - "fr" if SessionStateManager.get("language") == "French" else "en" - ) - gpt_description = generate_topic_description( - st.session_state["topic_model"], - st.session_state["selected_topic_number"], - filtered_df, - language_code=language_code, - ) - with st.container(border=True): - st.markdown(gpt_description) + # GPT generated topic description + _display_topic_description(filtered_df) st.divider()