diff --git a/app/plot_functions.py b/app/plot_functions.py
new file mode 100644
index 00000000..951a3277
--- /dev/null
+++ b/app/plot_functions.py
@@ -0,0 +1,216 @@
+# -*- coding: utf-8 -*-
+"""Functions for plotting input data and results (cost_data)."""
+import pandas as pd
+import plotly.express as px
+import plotly.graph_objects as go
+import streamlit as st
+
+from app.ptxboa_functions import remove_subregions
+from ptxboa.api import PtxboaAPI
+
+
+def create_world_map(api: PtxboaAPI, res_costs: pd.DataFrame):
+ """Create world map."""
+ parameter_to_show_on_map = "Total"
+
+ # define title:
+ title_string = (
+ f"{parameter_to_show_on_map} cost of exporting"
+ f"{st.session_state['chain']} to "
+ f"{st.session_state['country']}"
+ )
+ # define color scale:
+ color_scale = [
+ (0, st.session_state["colors"][0]), # Starting color at the minimum data value
+ (0.5, st.session_state["colors"][6]),
+ (1, st.session_state["colors"][9]), # Ending color at the maximum data value
+ ]
+
+ # remove subregions from deep dive countries (otherwise colorscale is not correct)
+ res_costs = remove_subregions(api, res_costs, st.session_state["country"])
+
+ # Create custom hover text:
+ custom_hover_data = res_costs.apply(
+ lambda x: f"{x.name}
"
+ + "
".join(
+ [
+ f"{col}: {x[col]:.1f}" f"{st.session_state['output_unit']}"
+ for col in res_costs.columns[:-1]
+ ]
+ + [
+ f"──────────
{res_costs.columns[-1]}: "
+ f"{x[res_costs.columns[-1]]:.1f}"
+ f"{st.session_state['output_unit']}"
+ ]
+ ),
+ axis=1,
+ )
+
+ # Create a choropleth world map:
+ fig = px.choropleth(
+ locations=res_costs.index, # List of country codes or names
+ locationmode="country names", # Use country names as locations
+ color=res_costs[parameter_to_show_on_map], # Color values for the countries
+ custom_data=[custom_hover_data], # Pass custom data for hover information
+ color_continuous_scale=color_scale, # Choose a color scale
+ title=title_string, # set title
+ )
+
+ # update layout:
+ fig.update_geos(
+ showcountries=True, # Show country borders
+ showcoastlines=True, # Show coastlines
+ countrycolor="black", # Set default border color for other countries
+ countrywidth=0.2, # Set border width
+ coastlinewidth=0.2, # coastline width
+ coastlinecolor="black", # coastline color
+ showland=True, # show land areas
+ landcolor="#f3f4f5", # Set land color to light gray
+ oceancolor="#e3e4ea", # Optionally, set ocean color slightly darker gray
+ showocean=True, # show ocean areas
+ framewidth=0.2, # width of frame around map
+ )
+
+ fig.update_layout(
+ coloraxis_colorbar={"title": st.session_state["output_unit"]}, # colorbar
+ height=600, # height of figure
+ margin={"t": 20, "b": 20, "l": 20, "r": 20}, # reduce margin around figure
+ )
+
+ # Set the hover template to use the custom data
+ fig.update_traces(hovertemplate="%{customdata}") # Custom data
+
+ # Display the map:
+ st.plotly_chart(fig, use_container_width=True)
+ return
+
+
+def create_bar_chart_costs(res_costs: pd.DataFrame, current_selection: str = None):
+ """Create bar plot for costs by components, and dots for total costs.
+
+ Parameters
+ ----------
+ res_costs : pd.DataFrame
+ data for plotting
+ settings : dict
+ settings dictionary, like output from create_sidebar()
+ current_selection : str
+ bar to highlight with an arrow. must be an element of res_costs.index
+
+ Output
+ ------
+ fig : plotly.graph_objects.Figure
+ Figure object
+ """
+ if res_costs.empty: # nodata to plot (FIXME: migth not be required later)
+ return go.Figure()
+
+ fig = px.bar(
+ res_costs,
+ x=res_costs.index,
+ y=res_costs.columns[:-1],
+ height=500,
+ color_discrete_sequence=st.session_state["colors"],
+ )
+
+ # Add the dot markers for the "total" column using plotly.graph_objects
+ scatter_trace = go.Scatter(
+ x=res_costs.index,
+ y=res_costs["Total"],
+ mode="markers+text", # Display markers and text
+ marker={"size": 10, "color": "black"},
+ name="Total",
+ text=res_costs["Total"].apply(
+ lambda x: f"{x:.2f}"
+ ), # Use 'total' column values as text labels
+ textposition="top center", # Position of the text label above the marker
+ )
+
+ fig.add_trace(scatter_trace)
+
+ # add highlight for current selection:
+ if current_selection is not None and current_selection in res_costs.index:
+ fig.add_annotation(
+ x=current_selection,
+ y=1.2 * res_costs.at[current_selection, "Total"],
+ text="current selection",
+ showarrow=True,
+ arrowhead=2,
+ arrowsize=1,
+ arrowwidth=2,
+ ax=0,
+ ay=-50,
+ )
+ fig.update_layout(
+ yaxis_title=st.session_state["output_unit"],
+ )
+ return fig
+
+
+def create_box_plot(res_costs: pd.DataFrame):
+ """Create a subplot with one row and one column.
+
+ Parameters
+ ----------
+ res_costs : pd.DataFrame
+ data for plotting
+ settings : dict
+ settings dictionary, like output from create_sidebar()
+
+ Output
+ ------
+ fig : plotly.graph_objects.Figure
+ Figure object
+ """
+ fig = go.Figure()
+
+ # Specify the row index of the data point you want to highlight
+ highlighted_row_index = st.session_state["region"]
+ # Extract the value from the specified row and column
+
+ if highlighted_row_index:
+ highlighted_value = res_costs.at[highlighted_row_index, "Total"]
+ else:
+ highlighted_value = 0
+
+ # Add the box plot to the subplot
+ fig.add_trace(go.Box(y=res_costs["Total"], name="Cost distribution"))
+
+ # Add a scatter marker for the highlighted data point
+ fig.add_trace(
+ go.Scatter(
+ x=["Cost distribution"],
+ y=[highlighted_value],
+ mode="markers",
+ marker={"size": 10, "color": "black"},
+ name=highlighted_row_index,
+ text=f"Value: {highlighted_value}", # Add a text label
+ )
+ )
+
+ # Customize the layout as needed
+ fig.update_layout(
+ title="Cost distribution for all supply countries",
+ xaxis={"title": ""},
+ yaxis={"title": st.session_state["output_unit"]},
+ height=500,
+ )
+
+ return fig
+
+
+def create_scatter_plot(df_res, settings: dict):
+ df_res["Country"] = "Other countries"
+ df_res.at[st.session_state["region"], "Country"] = st.session_state["region"]
+
+ fig = px.scatter(
+ df_res,
+ y="Total",
+ x="tr_dst_sd",
+ color="Country",
+ text=df_res.index,
+ color_discrete_sequence=["blue", "red"],
+ )
+ fig.update_traces(texttemplate="%{text}", textposition="top center")
+ st.plotly_chart(fig)
+ st.write(df_res)
diff --git a/app/ptxboa_functions.py b/app/ptxboa_functions.py
index d7147b0b..14ddb74d 100644
--- a/app/ptxboa_functions.py
+++ b/app/ptxboa_functions.py
@@ -2,8 +2,6 @@
"""Utility functions for streamlit app."""
import pandas as pd
-import plotly.express as px
-import plotly.graph_objects as go
import streamlit as st
from ptxboa.api import PtxboaAPI
@@ -105,211 +103,52 @@ def aggregate_costs(res_details: pd.DataFrame) -> pd.DataFrame:
return res
-def create_world_map(api: PtxboaAPI, res_costs: pd.DataFrame):
- """Create world map."""
- parameter_to_show_on_map = "Total"
-
- # define title:
- title_string = (
- f"{parameter_to_show_on_map} cost of exporting"
- f"{st.session_state['chain']} to "
- f"{st.session_state['country']}"
- )
- # define color scale:
- color_scale = [
- (0, st.session_state["colors"][0]), # Starting color at the minimum data value
- (0.5, st.session_state["colors"][6]),
- (1, st.session_state["colors"][9]), # Ending color at the maximum data value
- ]
-
- # remove subregions from deep dive countries (otherwise colorscale is not correct)
- res_costs = remove_subregions(api, res_costs, st.session_state["country"])
-
- # Create custom hover text:
- custom_hover_data = res_costs.apply(
- lambda x: f"{x.name}
"
- + "
".join(
- [
- f"{col}: {x[col]:.1f}" f"{st.session_state['output_unit']}"
- for col in res_costs.columns[:-1]
- ]
- + [
- f"──────────
{res_costs.columns[-1]}: "
- f"{x[res_costs.columns[-1]]:.1f}"
- f"{st.session_state['output_unit']}"
- ]
- ),
- axis=1,
- )
-
- # Create a choropleth world map:
- fig = px.choropleth(
- locations=res_costs.index, # List of country codes or names
- locationmode="country names", # Use country names as locations
- color=res_costs[parameter_to_show_on_map], # Color values for the countries
- custom_data=[custom_hover_data], # Pass custom data for hover information
- color_continuous_scale=color_scale, # Choose a color scale
- title=title_string, # set title
- )
-
- # update layout:
- fig.update_geos(
- showcountries=True, # Show country borders
- showcoastlines=True, # Show coastlines
- countrycolor="black", # Set default border color for other countries
- countrywidth=0.2, # Set border width
- coastlinewidth=0.2, # coastline width
- coastlinecolor="black", # coastline color
- showland=True, # show land areas
- landcolor="#f3f4f5", # Set land color to light gray
- oceancolor="#e3e4ea", # Optionally, set ocean color slightly darker gray
- showocean=True, # show ocean areas
- framewidth=0.2, # width of frame around map
- )
-
- fig.update_layout(
- coloraxis_colorbar={"title": st.session_state["output_unit"]}, # colorbar
- height=600, # height of figure
- margin={"t": 20, "b": 20, "l": 20, "r": 20}, # reduce margin around figure
- )
-
- # Set the hover template to use the custom data
- fig.update_traces(hovertemplate="%{customdata}") # Custom data
-
- # Display the map:
- st.plotly_chart(fig, use_container_width=True)
- return
-
-
-def create_bar_chart_costs(res_costs: pd.DataFrame, current_selection: str = None):
- """Create bar plot for costs by components, and dots for total costs.
-
- Parameters
- ----------
- res_costs : pd.DataFrame
- data for plotting
- settings : dict
- settings dictionary, like output from create_sidebar()
- current_selection : str
- bar to highlight with an arrow. must be an element of res_costs.index
-
- Output
- ------
- fig : plotly.graph_objects.Figure
- Figure object
+def subset_and_pivot_input_data(
+ input_data: pd.DataFrame,
+ source_region_code: list = None,
+ parameter_code: list = None,
+ process_code: list = None,
+ index: str = "source_region_code",
+ columns: str = "process_code",
+ values: str = "value",
+):
"""
- if res_costs.empty: # nodata to plot (FIXME: migth not be required later)
- return go.Figure()
-
- fig = px.bar(
- res_costs,
- x=res_costs.index,
- y=res_costs.columns[:-1],
- height=500,
- color_discrete_sequence=st.session_state["colors"],
- )
-
- # Add the dot markers for the "total" column using plotly.graph_objects
- scatter_trace = go.Scatter(
- x=res_costs.index,
- y=res_costs["Total"],
- mode="markers+text", # Display markers and text
- marker={"size": 10, "color": "black"},
- name="Total",
- text=res_costs["Total"].apply(
- lambda x: f"{x:.2f}"
- ), # Use 'total' column values as text labels
- textposition="top center", # Position of the text label above the marker
- )
-
- fig.add_trace(scatter_trace)
-
- # add highlight for current selection:
- if current_selection is not None and current_selection in res_costs.index:
- fig.add_annotation(
- x=current_selection,
- y=1.2 * res_costs.at[current_selection, "Total"],
- text="current selection",
- showarrow=True,
- arrowhead=2,
- arrowsize=1,
- arrowwidth=2,
- ax=0,
- ay=-50,
- )
- fig.update_layout(
- yaxis_title=st.session_state["output_unit"],
- )
- return fig
-
-
-def create_box_plot(res_costs: pd.DataFrame):
- """Create a subplot with one row and one column.
+ Reshapes and subsets input data.
Parameters
----------
- res_costs : pd.DataFrame
- data for plotting
- settings : dict
- settings dictionary, like output from create_sidebar()
+ input_data : pd.DataFrame
+ obtained with :meth:`~ptxboa.api.PtxboaAPI.get_input_data`
+ source_region_code : list, optional
+ list for subsetting source regions, by default None
+ parameter_code : list, optional
+ list for subsetting parameter_codes, by default None
+ process_code : list, optional
+ list for subsetting process_codes, by default None
+ index : str, optional
+ index for `pivot_table()`, by default "source_region_code"
+ columns : str, optional
+ column for generating new columns in pivot_table, by default "process_code"
+ values : str, optional
+ values for `pivot_table()` , by default "value"
- Output
- ------
- fig : plotly.graph_objects.Figure
- Figure object
+ Returns
+ -------
+ : pd.DataFrame
"""
- fig = go.Figure()
-
- # Specify the row index of the data point you want to highlight
- highlighted_row_index = st.session_state["region"]
- # Extract the value from the specified row and column
-
- if highlighted_row_index:
- highlighted_value = res_costs.at[highlighted_row_index, "Total"]
- else:
- highlighted_value = 0
-
- # Add the box plot to the subplot
- fig.add_trace(go.Box(y=res_costs["Total"], name="Cost distribution"))
-
- # Add a scatter marker for the highlighted data point
- fig.add_trace(
- go.Scatter(
- x=["Cost distribution"],
- y=[highlighted_value],
- mode="markers",
- marker={"size": 10, "color": "black"},
- name=highlighted_row_index,
- text=f"Value: {highlighted_value}", # Add a text label
- )
- )
-
- # Customize the layout as needed
- fig.update_layout(
- title="Cost distribution for all supply countries",
- xaxis={"title": ""},
- yaxis={"title": st.session_state["output_unit"]},
- height=500,
- )
-
- return fig
-
-
-def create_scatter_plot(df_res, settings: dict):
- df_res["Country"] = "Other countries"
- df_res.at[st.session_state["region"], "Country"] = st.session_state["region"]
+ if source_region_code is not None:
+ input_data = input_data.loc[
+ input_data["source_region_code"].isin(source_region_code)
+ ]
+ if parameter_code is not None:
+ input_data = input_data.loc[input_data["parameter_code"].isin(parameter_code)]
+ if process_code is not None:
+ input_data = input_data.loc[input_data["process_code"].isin(process_code)]
- fig = px.scatter(
- df_res,
- y="Total",
- x="tr_dst_sd",
- color="Country",
- text=df_res.index,
- color_discrete_sequence=["blue", "red"],
+ reshaped = input_data.pivot_table(
+ index=index, columns=columns, values=values, aggfunc="sum"
)
- fig.update_traces(texttemplate="%{text}", textposition="top center")
- st.plotly_chart(fig)
- st.write(df_res)
+ return reshaped
def remove_subregions(api: PtxboaAPI, df: pd.DataFrame, country_name: str):
@@ -428,55 +267,6 @@ def display_and_edit_data_table(
return df_tab
-def subset_and_pivot_input_data(
- input_data: pd.DataFrame,
- source_region_code: list = None,
- parameter_code: list = None,
- process_code: list = None,
- index: str = "source_region_code",
- columns: str = "process_code",
- values: str = "value",
-):
- """
- Reshapes and subsets input data.
-
- Parameters
- ----------
- input_data : pd.DataFrame
- obtained with :meth:`~ptxboa.api.PtxboaAPI.get_input_data`
- source_region_code : list, optional
- list for subsetting source regions, by default None
- parameter_code : list, optional
- list for subsetting parameter_codes, by default None
- process_code : list, optional
- list for subsetting process_codes, by default None
- index : str, optional
- index for `pivot_table()`, by default "source_region_code"
- columns : str, optional
- column for generating new columns in pivot_table, by default "process_code"
- values : str, optional
- values for `pivot_table()` , by default "value"
-
- Returns
- -------
- _type_
- _description_
- """
- if source_region_code is not None:
- input_data = input_data.loc[
- input_data["source_region_code"].isin(source_region_code)
- ]
- if parameter_code is not None:
- input_data = input_data.loc[input_data["parameter_code"].isin(parameter_code)]
- if process_code is not None:
- input_data = input_data.loc[input_data["process_code"].isin(process_code)]
-
- reshaped = input_data.pivot_table(
- index=index, columns=columns, values=values, aggfunc="sum"
- )
- return reshaped
-
-
def register_user_changes(
missing_index_name: str,
missing_index_value: str,
diff --git a/app/tab_compare_costs.py b/app/tab_compare_costs.py
index 5c64bdfa..34689d54 100644
--- a/app/tab_compare_costs.py
+++ b/app/tab_compare_costs.py
@@ -3,10 +3,10 @@
import pandas as pd
import streamlit as st
+from app.plot_functions import create_bar_chart_costs
from app.ptxboa_functions import (
calculate_results_list,
config_number_columns,
- create_bar_chart_costs,
remove_subregions,
)
from ptxboa.api import PtxboaAPI
diff --git a/app/tab_dashboard.py b/app/tab_dashboard.py
index e0caa550..8da08689 100644
--- a/app/tab_dashboard.py
+++ b/app/tab_dashboard.py
@@ -3,11 +3,7 @@
import streamlit as st
from plotly.subplots import make_subplots
-from app.ptxboa_functions import (
- create_bar_chart_costs,
- create_box_plot,
- create_world_map,
-)
+from app.plot_functions import create_bar_chart_costs, create_box_plot, create_world_map
def _create_infobox(context_data: dict):