From da773a95d872bd8b3de8470717c09441bdda3d8b Mon Sep 17 00:00:00 2001 From: "j.aschauer" Date: Wed, 15 Nov 2023 18:52:38 +0100 Subject: [PATCH] move plot functions to individual module --- app/plot_functions.py | 216 +++++++++++++++++++++++++++++ app/ptxboa_functions.py | 286 ++++++--------------------------------- app/tab_compare_costs.py | 2 +- app/tab_dashboard.py | 6 +- 4 files changed, 256 insertions(+), 254 deletions(-) create mode 100644 app/plot_functions.py diff --git a/app/plot_functions.py b/app/plot_functions.py new file mode 100644 index 00000000..951a3277 --- /dev/null +++ b/app/plot_functions.py @@ -0,0 +1,216 @@ +# -*- coding: utf-8 -*- +"""Functions for plotting input data and results (cost_data).""" +import pandas as pd +import plotly.express as px +import plotly.graph_objects as go +import streamlit as st + +from app.ptxboa_functions import remove_subregions +from ptxboa.api import PtxboaAPI + + +def create_world_map(api: PtxboaAPI, res_costs: pd.DataFrame): + """Create world map.""" + parameter_to_show_on_map = "Total" + + # define title: + title_string = ( + f"{parameter_to_show_on_map} cost of exporting" + f"{st.session_state['chain']} to " + f"{st.session_state['country']}" + ) + # define color scale: + color_scale = [ + (0, st.session_state["colors"][0]), # Starting color at the minimum data value + (0.5, st.session_state["colors"][6]), + (1, st.session_state["colors"][9]), # Ending color at the maximum data value + ] + + # remove subregions from deep dive countries (otherwise colorscale is not correct) + res_costs = remove_subregions(api, res_costs, st.session_state["country"]) + + # Create custom hover text: + custom_hover_data = res_costs.apply( + lambda x: f"{x.name}

" + + "
".join( + [ + f"{col}: {x[col]:.1f}" f"{st.session_state['output_unit']}" + for col in res_costs.columns[:-1] + ] + + [ + f"──────────
{res_costs.columns[-1]}: " + f"{x[res_costs.columns[-1]]:.1f}" + f"{st.session_state['output_unit']}" + ] + ), + axis=1, + ) + + # Create a choropleth world map: + fig = px.choropleth( + locations=res_costs.index, # List of country codes or names + locationmode="country names", # Use country names as locations + color=res_costs[parameter_to_show_on_map], # Color values for the countries + custom_data=[custom_hover_data], # Pass custom data for hover information + color_continuous_scale=color_scale, # Choose a color scale + title=title_string, # set title + ) + + # update layout: + fig.update_geos( + showcountries=True, # Show country borders + showcoastlines=True, # Show coastlines + countrycolor="black", # Set default border color for other countries + countrywidth=0.2, # Set border width + coastlinewidth=0.2, # coastline width + coastlinecolor="black", # coastline color + showland=True, # show land areas + landcolor="#f3f4f5", # Set land color to light gray + oceancolor="#e3e4ea", # Optionally, set ocean color slightly darker gray + showocean=True, # show ocean areas + framewidth=0.2, # width of frame around map + ) + + fig.update_layout( + coloraxis_colorbar={"title": st.session_state["output_unit"]}, # colorbar + height=600, # height of figure + margin={"t": 20, "b": 20, "l": 20, "r": 20}, # reduce margin around figure + ) + + # Set the hover template to use the custom data + fig.update_traces(hovertemplate="%{customdata}") # Custom data + + # Display the map: + st.plotly_chart(fig, use_container_width=True) + return + + +def create_bar_chart_costs(res_costs: pd.DataFrame, current_selection: str = None): + """Create bar plot for costs by components, and dots for total costs. + + Parameters + ---------- + res_costs : pd.DataFrame + data for plotting + settings : dict + settings dictionary, like output from create_sidebar() + current_selection : str + bar to highlight with an arrow. must be an element of res_costs.index + + Output + ------ + fig : plotly.graph_objects.Figure + Figure object + """ + if res_costs.empty: # nodata to plot (FIXME: migth not be required later) + return go.Figure() + + fig = px.bar( + res_costs, + x=res_costs.index, + y=res_costs.columns[:-1], + height=500, + color_discrete_sequence=st.session_state["colors"], + ) + + # Add the dot markers for the "total" column using plotly.graph_objects + scatter_trace = go.Scatter( + x=res_costs.index, + y=res_costs["Total"], + mode="markers+text", # Display markers and text + marker={"size": 10, "color": "black"}, + name="Total", + text=res_costs["Total"].apply( + lambda x: f"{x:.2f}" + ), # Use 'total' column values as text labels + textposition="top center", # Position of the text label above the marker + ) + + fig.add_trace(scatter_trace) + + # add highlight for current selection: + if current_selection is not None and current_selection in res_costs.index: + fig.add_annotation( + x=current_selection, + y=1.2 * res_costs.at[current_selection, "Total"], + text="current selection", + showarrow=True, + arrowhead=2, + arrowsize=1, + arrowwidth=2, + ax=0, + ay=-50, + ) + fig.update_layout( + yaxis_title=st.session_state["output_unit"], + ) + return fig + + +def create_box_plot(res_costs: pd.DataFrame): + """Create a subplot with one row and one column. + + Parameters + ---------- + res_costs : pd.DataFrame + data for plotting + settings : dict + settings dictionary, like output from create_sidebar() + + Output + ------ + fig : plotly.graph_objects.Figure + Figure object + """ + fig = go.Figure() + + # Specify the row index of the data point you want to highlight + highlighted_row_index = st.session_state["region"] + # Extract the value from the specified row and column + + if highlighted_row_index: + highlighted_value = res_costs.at[highlighted_row_index, "Total"] + else: + highlighted_value = 0 + + # Add the box plot to the subplot + fig.add_trace(go.Box(y=res_costs["Total"], name="Cost distribution")) + + # Add a scatter marker for the highlighted data point + fig.add_trace( + go.Scatter( + x=["Cost distribution"], + y=[highlighted_value], + mode="markers", + marker={"size": 10, "color": "black"}, + name=highlighted_row_index, + text=f"Value: {highlighted_value}", # Add a text label + ) + ) + + # Customize the layout as needed + fig.update_layout( + title="Cost distribution for all supply countries", + xaxis={"title": ""}, + yaxis={"title": st.session_state["output_unit"]}, + height=500, + ) + + return fig + + +def create_scatter_plot(df_res, settings: dict): + df_res["Country"] = "Other countries" + df_res.at[st.session_state["region"], "Country"] = st.session_state["region"] + + fig = px.scatter( + df_res, + y="Total", + x="tr_dst_sd", + color="Country", + text=df_res.index, + color_discrete_sequence=["blue", "red"], + ) + fig.update_traces(texttemplate="%{text}", textposition="top center") + st.plotly_chart(fig) + st.write(df_res) diff --git a/app/ptxboa_functions.py b/app/ptxboa_functions.py index d7147b0b..14ddb74d 100644 --- a/app/ptxboa_functions.py +++ b/app/ptxboa_functions.py @@ -2,8 +2,6 @@ """Utility functions for streamlit app.""" import pandas as pd -import plotly.express as px -import plotly.graph_objects as go import streamlit as st from ptxboa.api import PtxboaAPI @@ -105,211 +103,52 @@ def aggregate_costs(res_details: pd.DataFrame) -> pd.DataFrame: return res -def create_world_map(api: PtxboaAPI, res_costs: pd.DataFrame): - """Create world map.""" - parameter_to_show_on_map = "Total" - - # define title: - title_string = ( - f"{parameter_to_show_on_map} cost of exporting" - f"{st.session_state['chain']} to " - f"{st.session_state['country']}" - ) - # define color scale: - color_scale = [ - (0, st.session_state["colors"][0]), # Starting color at the minimum data value - (0.5, st.session_state["colors"][6]), - (1, st.session_state["colors"][9]), # Ending color at the maximum data value - ] - - # remove subregions from deep dive countries (otherwise colorscale is not correct) - res_costs = remove_subregions(api, res_costs, st.session_state["country"]) - - # Create custom hover text: - custom_hover_data = res_costs.apply( - lambda x: f"{x.name}

" - + "
".join( - [ - f"{col}: {x[col]:.1f}" f"{st.session_state['output_unit']}" - for col in res_costs.columns[:-1] - ] - + [ - f"──────────
{res_costs.columns[-1]}: " - f"{x[res_costs.columns[-1]]:.1f}" - f"{st.session_state['output_unit']}" - ] - ), - axis=1, - ) - - # Create a choropleth world map: - fig = px.choropleth( - locations=res_costs.index, # List of country codes or names - locationmode="country names", # Use country names as locations - color=res_costs[parameter_to_show_on_map], # Color values for the countries - custom_data=[custom_hover_data], # Pass custom data for hover information - color_continuous_scale=color_scale, # Choose a color scale - title=title_string, # set title - ) - - # update layout: - fig.update_geos( - showcountries=True, # Show country borders - showcoastlines=True, # Show coastlines - countrycolor="black", # Set default border color for other countries - countrywidth=0.2, # Set border width - coastlinewidth=0.2, # coastline width - coastlinecolor="black", # coastline color - showland=True, # show land areas - landcolor="#f3f4f5", # Set land color to light gray - oceancolor="#e3e4ea", # Optionally, set ocean color slightly darker gray - showocean=True, # show ocean areas - framewidth=0.2, # width of frame around map - ) - - fig.update_layout( - coloraxis_colorbar={"title": st.session_state["output_unit"]}, # colorbar - height=600, # height of figure - margin={"t": 20, "b": 20, "l": 20, "r": 20}, # reduce margin around figure - ) - - # Set the hover template to use the custom data - fig.update_traces(hovertemplate="%{customdata}") # Custom data - - # Display the map: - st.plotly_chart(fig, use_container_width=True) - return - - -def create_bar_chart_costs(res_costs: pd.DataFrame, current_selection: str = None): - """Create bar plot for costs by components, and dots for total costs. - - Parameters - ---------- - res_costs : pd.DataFrame - data for plotting - settings : dict - settings dictionary, like output from create_sidebar() - current_selection : str - bar to highlight with an arrow. must be an element of res_costs.index - - Output - ------ - fig : plotly.graph_objects.Figure - Figure object +def subset_and_pivot_input_data( + input_data: pd.DataFrame, + source_region_code: list = None, + parameter_code: list = None, + process_code: list = None, + index: str = "source_region_code", + columns: str = "process_code", + values: str = "value", +): """ - if res_costs.empty: # nodata to plot (FIXME: migth not be required later) - return go.Figure() - - fig = px.bar( - res_costs, - x=res_costs.index, - y=res_costs.columns[:-1], - height=500, - color_discrete_sequence=st.session_state["colors"], - ) - - # Add the dot markers for the "total" column using plotly.graph_objects - scatter_trace = go.Scatter( - x=res_costs.index, - y=res_costs["Total"], - mode="markers+text", # Display markers and text - marker={"size": 10, "color": "black"}, - name="Total", - text=res_costs["Total"].apply( - lambda x: f"{x:.2f}" - ), # Use 'total' column values as text labels - textposition="top center", # Position of the text label above the marker - ) - - fig.add_trace(scatter_trace) - - # add highlight for current selection: - if current_selection is not None and current_selection in res_costs.index: - fig.add_annotation( - x=current_selection, - y=1.2 * res_costs.at[current_selection, "Total"], - text="current selection", - showarrow=True, - arrowhead=2, - arrowsize=1, - arrowwidth=2, - ax=0, - ay=-50, - ) - fig.update_layout( - yaxis_title=st.session_state["output_unit"], - ) - return fig - - -def create_box_plot(res_costs: pd.DataFrame): - """Create a subplot with one row and one column. + Reshapes and subsets input data. Parameters ---------- - res_costs : pd.DataFrame - data for plotting - settings : dict - settings dictionary, like output from create_sidebar() + input_data : pd.DataFrame + obtained with :meth:`~ptxboa.api.PtxboaAPI.get_input_data` + source_region_code : list, optional + list for subsetting source regions, by default None + parameter_code : list, optional + list for subsetting parameter_codes, by default None + process_code : list, optional + list for subsetting process_codes, by default None + index : str, optional + index for `pivot_table()`, by default "source_region_code" + columns : str, optional + column for generating new columns in pivot_table, by default "process_code" + values : str, optional + values for `pivot_table()` , by default "value" - Output - ------ - fig : plotly.graph_objects.Figure - Figure object + Returns + ------- + : pd.DataFrame """ - fig = go.Figure() - - # Specify the row index of the data point you want to highlight - highlighted_row_index = st.session_state["region"] - # Extract the value from the specified row and column - - if highlighted_row_index: - highlighted_value = res_costs.at[highlighted_row_index, "Total"] - else: - highlighted_value = 0 - - # Add the box plot to the subplot - fig.add_trace(go.Box(y=res_costs["Total"], name="Cost distribution")) - - # Add a scatter marker for the highlighted data point - fig.add_trace( - go.Scatter( - x=["Cost distribution"], - y=[highlighted_value], - mode="markers", - marker={"size": 10, "color": "black"}, - name=highlighted_row_index, - text=f"Value: {highlighted_value}", # Add a text label - ) - ) - - # Customize the layout as needed - fig.update_layout( - title="Cost distribution for all supply countries", - xaxis={"title": ""}, - yaxis={"title": st.session_state["output_unit"]}, - height=500, - ) - - return fig - - -def create_scatter_plot(df_res, settings: dict): - df_res["Country"] = "Other countries" - df_res.at[st.session_state["region"], "Country"] = st.session_state["region"] + if source_region_code is not None: + input_data = input_data.loc[ + input_data["source_region_code"].isin(source_region_code) + ] + if parameter_code is not None: + input_data = input_data.loc[input_data["parameter_code"].isin(parameter_code)] + if process_code is not None: + input_data = input_data.loc[input_data["process_code"].isin(process_code)] - fig = px.scatter( - df_res, - y="Total", - x="tr_dst_sd", - color="Country", - text=df_res.index, - color_discrete_sequence=["blue", "red"], + reshaped = input_data.pivot_table( + index=index, columns=columns, values=values, aggfunc="sum" ) - fig.update_traces(texttemplate="%{text}", textposition="top center") - st.plotly_chart(fig) - st.write(df_res) + return reshaped def remove_subregions(api: PtxboaAPI, df: pd.DataFrame, country_name: str): @@ -428,55 +267,6 @@ def display_and_edit_data_table( return df_tab -def subset_and_pivot_input_data( - input_data: pd.DataFrame, - source_region_code: list = None, - parameter_code: list = None, - process_code: list = None, - index: str = "source_region_code", - columns: str = "process_code", - values: str = "value", -): - """ - Reshapes and subsets input data. - - Parameters - ---------- - input_data : pd.DataFrame - obtained with :meth:`~ptxboa.api.PtxboaAPI.get_input_data` - source_region_code : list, optional - list for subsetting source regions, by default None - parameter_code : list, optional - list for subsetting parameter_codes, by default None - process_code : list, optional - list for subsetting process_codes, by default None - index : str, optional - index for `pivot_table()`, by default "source_region_code" - columns : str, optional - column for generating new columns in pivot_table, by default "process_code" - values : str, optional - values for `pivot_table()` , by default "value" - - Returns - ------- - _type_ - _description_ - """ - if source_region_code is not None: - input_data = input_data.loc[ - input_data["source_region_code"].isin(source_region_code) - ] - if parameter_code is not None: - input_data = input_data.loc[input_data["parameter_code"].isin(parameter_code)] - if process_code is not None: - input_data = input_data.loc[input_data["process_code"].isin(process_code)] - - reshaped = input_data.pivot_table( - index=index, columns=columns, values=values, aggfunc="sum" - ) - return reshaped - - def register_user_changes( missing_index_name: str, missing_index_value: str, diff --git a/app/tab_compare_costs.py b/app/tab_compare_costs.py index 5c64bdfa..34689d54 100644 --- a/app/tab_compare_costs.py +++ b/app/tab_compare_costs.py @@ -3,10 +3,10 @@ import pandas as pd import streamlit as st +from app.plot_functions import create_bar_chart_costs from app.ptxboa_functions import ( calculate_results_list, config_number_columns, - create_bar_chart_costs, remove_subregions, ) from ptxboa.api import PtxboaAPI diff --git a/app/tab_dashboard.py b/app/tab_dashboard.py index e0caa550..8da08689 100644 --- a/app/tab_dashboard.py +++ b/app/tab_dashboard.py @@ -3,11 +3,7 @@ import streamlit as st from plotly.subplots import make_subplots -from app.ptxboa_functions import ( - create_bar_chart_costs, - create_box_plot, - create_world_map, -) +from app.plot_functions import create_bar_chart_costs, create_box_plot, create_world_map def _create_infobox(context_data: dict):