From 8d6e1adce71e63c5aef91bca8f08ef4facd462d0 Mon Sep 17 00:00:00 2001 From: Markus Haller Date: Wed, 15 Nov 2023 12:20:54 +0100 Subject: [PATCH 1/6] move ``settings`` to ``st.session_state[settings]`` #105 --- app/ptxboa_functions.py | 180 ++++++++++++++++----------------- ptxboa_streamlit.py | 18 ++-- tests/test_ptxboa_functions.py | 18 ++-- 3 files changed, 106 insertions(+), 110 deletions(-) diff --git a/app/ptxboa_functions.py b/app/ptxboa_functions.py index a31b1bb7..36e7500a 100644 --- a/app/ptxboa_functions.py +++ b/app/ptxboa_functions.py @@ -38,9 +38,7 @@ def calculate_results_single( return res -def calculate_results( - api: PtxboaAPI, settings: dict, region_list: list = None -) -> pd.DataFrame: +def calculate_results(api: PtxboaAPI, region_list: list = None) -> pd.DataFrame: """Calculate results for source regions and one selected target country. TODO: This function will eventually be replaced by ``calculate_results_list()``. @@ -49,9 +47,6 @@ def calculate_results( ---------- api : :class:`~ptxboa.api.PtxboaAPI` an instance of the api class - settings : dict - settings from the streamlit app. An example can be obtained with the - return value from :func:`ptxboa_functions.create_sidebar`. region_list : list or None The regions for which the results are calculated. If None, all regions available in the API will be used. @@ -67,7 +62,7 @@ def calculate_results( region_list = api.get_dimension("region")["region_name"] for region in region_list: - settings2 = settings.copy() + settings2 = st.session_state["settings"].copy() settings2["region"] = region res_single = calculate_results_single(api, settings2) res_list.append(res_single) @@ -77,7 +72,6 @@ def calculate_results( def calculate_results_list( api: PtxboaAPI, - settings: dict, parameter_to_change: str, parameter_list: list = None, user_data: pd.DataFrame | None = None, @@ -108,7 +102,7 @@ def calculate_results_list( parameter_list = api.get_dimension(parameter_to_change).index for parameter in parameter_list: - settings2 = settings.copy() + settings2 = st.session_state["settings"].copy() settings2[parameter_to_change] = parameter res_single = calculate_results_single(api, settings2, user_data=user_data) res_list.append(res_single) @@ -147,8 +141,10 @@ def aggregate_costs(res_details: pd.DataFrame) -> pd.DataFrame: # Settings: def create_sidebar(api: PtxboaAPI): + if "settings" not in st.session_state: + st.session_state["settings"] = {} + st.sidebar.subheader("Main settings:") - settings = {} include_subregions = False if include_subregions: region_list = api.get_dimension("region").index @@ -159,7 +155,7 @@ def create_sidebar(api: PtxboaAPI): .index ) - settings["region"] = st.sidebar.selectbox( + st.session_state["settings"]["region"] = st.sidebar.selectbox( "Supply country / region:", region_list, help=( @@ -177,7 +173,7 @@ def create_sidebar(api: PtxboaAPI): "if you want to chose one of these subregions as a supply region. " ), ) - settings["country"] = st.sidebar.selectbox( + st.session_state["settings"]["country"] = st.sidebar.selectbox( "Demand country:", api.get_dimension("country").index, help=( @@ -222,11 +218,13 @@ def create_sidebar(api: PtxboaAPI): else: use_reconversion = False - settings["chain"] = f"{product} ({ely})" + st.session_state["settings"]["chain"] = f"{product} ({ely})" if use_reconversion: - settings["chain"] = f"{settings['chain']} + reconv. to H2" + st.session_state["settings"][ + "chain" + ] = f"{st.session_state['settings']['chain']} + reconv. to H2" - settings["res_gen"] = st.sidebar.selectbox( + st.session_state["settings"]["res_gen"] = st.sidebar.selectbox( "Renewable electricity source (for selected supply region):", api.get_dimension("res_gen").index, help=( @@ -260,33 +258,33 @@ def create_sidebar(api: PtxboaAPI): ), horizontal=True, ) - settings["scenario"] = f"{data_year} ({cost_scenario})" + st.session_state["settings"]["scenario"] = f"{data_year} ({cost_scenario})" st.sidebar.subheader("Additional settings:") - settings["secproc_co2"] = st.sidebar.radio( + st.session_state["settings"]["secproc_co2"] = st.sidebar.radio( "Carbon source:", api.get_dimension("secproc_co2").index, horizontal=True, help="Help text", ) - settings["secproc_water"] = st.sidebar.radio( + st.session_state["settings"]["secproc_water"] = st.sidebar.radio( "Water source:", api.get_dimension("secproc_water").index, horizontal=True, help="Help text", ) - settings["transport"] = st.sidebar.radio( + st.session_state["settings"]["transport"] = st.sidebar.radio( "Mode of transportation (for selected supply country):", api.get_dimension("transport").index, horizontal=True, help="Help text", ) - if settings["transport"] == "Ship": - settings["ship_own_fuel"] = st.sidebar.toggle( + if st.session_state["settings"]["transport"] == "Ship": + st.session_state["settings"]["ship_own_fuel"] = st.sidebar.toggle( "For shipping option: Use the product as own fuel?", help="Help text", ) - settings["output_unit"] = st.sidebar.radio( + st.session_state["settings"]["output_unit"] = st.sidebar.radio( "Unit for delivered costs:", api.get_dimension("output_unit").index, horizontal=True, @@ -311,17 +309,18 @@ def create_sidebar(api: PtxboaAPI): if "colors" not in st.session_state: colors = pd.read_csv("data/Agora_Industry_Colours.csv") st.session_state["colors"] = colors["Hex Code"].to_list() - return settings + return -def create_world_map(api: PtxboaAPI, settings: dict, res_costs: pd.DataFrame): +def create_world_map(api: PtxboaAPI, res_costs: pd.DataFrame): """Create world map.""" parameter_to_show_on_map = "Total" # define title: title_string = ( - f"{parameter_to_show_on_map} cost of exporting {settings['chain']} to " - f"{settings['country']}" + f"{parameter_to_show_on_map} cost of exporting" + f"{st.session_state['settings']['chain']} to " + f"{st.session_state['settings']['country']}" ) # define color scale: color_scale = [ @@ -331,19 +330,21 @@ def create_world_map(api: PtxboaAPI, settings: dict, res_costs: pd.DataFrame): ] # remove subregions from deep dive countries (otherwise colorscale is not correct) - res_costs = remove_subregions(api, res_costs, settings) + res_costs = remove_subregions(api, res_costs) # Create custom hover text: custom_hover_data = res_costs.apply( lambda x: f"{x.name}

" + "
".join( [ - f"{col}: {x[col]:.1f} {settings['output_unit']}" + f"{col}: {x[col]:.1f}" + f"{st.session_state['settings']['output_unit']}" for col in res_costs.columns[:-1] ] + [ f"──────────
{res_costs.columns[-1]}: " - f"{x[res_costs.columns[-1]]:.1f} {settings['output_unit']}" + f"{x[res_costs.columns[-1]]:.1f}" + f"{st.session_state['settings']['output_unit']}" ] ), axis=1, @@ -375,7 +376,9 @@ def create_world_map(api: PtxboaAPI, settings: dict, res_costs: pd.DataFrame): ) fig.update_layout( - coloraxis_colorbar={"title": settings["output_unit"]}, # colorbar + coloraxis_colorbar={ + "title": st.session_state["settings"]["output_unit"] + }, # colorbar height=600, # height of figure margin={"t": 20, "b": 20, "l": 20, "r": 20}, # reduce margin around figure ) @@ -388,9 +391,7 @@ def create_world_map(api: PtxboaAPI, settings: dict, res_costs: pd.DataFrame): return -def create_bar_chart_costs( - res_costs: pd.DataFrame, settings: dict, current_selection: str = None -): +def create_bar_chart_costs(res_costs: pd.DataFrame, current_selection: str = None): """Create bar plot for costs by components, and dots for total costs. Parameters @@ -447,12 +448,12 @@ def create_bar_chart_costs( ay=-50, ) fig.update_layout( - yaxis_title=settings["output_unit"], + yaxis_title=st.session_state["settings"]["output_unit"], ) return fig -def create_box_plot(res_costs: pd.DataFrame, settings: dict): +def create_box_plot(res_costs: pd.DataFrame): """Create a subplot with one row and one column. Parameters @@ -470,7 +471,7 @@ def create_box_plot(res_costs: pd.DataFrame, settings: dict): fig = go.Figure() # Specify the row index of the data point you want to highlight - highlighted_row_index = settings["region"] + highlighted_row_index = st.session_state["settings"]["region"] # Extract the value from the specified row and column if highlighted_row_index: @@ -497,7 +498,7 @@ def create_box_plot(res_costs: pd.DataFrame, settings: dict): fig.update_layout( title="Cost distribution for all supply countries", xaxis={"title": ""}, - yaxis={"title": settings["output_unit"]}, + yaxis={"title": st.session_state["settings"]["output_unit"]}, height=500, ) @@ -506,7 +507,9 @@ def create_box_plot(res_costs: pd.DataFrame, settings: dict): def create_scatter_plot(df_res, settings: dict): df_res["Country"] = "Other countries" - df_res.at[settings["region"], "Country"] = settings["region"] + df_res.at[st.session_state["settings"]["region"], "Country"] = st.session_state[ + "settings" + ]["region"] fig = px.scatter( df_res, @@ -521,7 +524,7 @@ def create_scatter_plot(df_res, settings: dict): st.write(df_res) -def content_dashboard(api, res_costs: dict, context_data: dict, settings: pd.DataFrame): +def content_dashboard(api, res_costs: dict, context_data: dict): with st.expander("What is this?"): st.markdown( """ @@ -538,13 +541,15 @@ def content_dashboard(api, res_costs: dict, context_data: dict, settings: pd.Dat c_1, c_2 = st.columns([2, 1]) with c_1: - create_world_map(api, settings, res_costs) + create_world_map(api, res_costs) with c_2: # create box plot and bar plot: - fig1 = create_box_plot(res_costs, settings) - filtered_data = res_costs[res_costs.index == settings["region"]] - fig2 = create_bar_chart_costs(filtered_data, settings) + fig1 = create_box_plot(res_costs) + filtered_data = res_costs[ + res_costs.index == st.session_state["settings"]["region"] + ] + fig2 = create_bar_chart_costs(filtered_data) doublefig = make_subplots(rows=1, cols=2, shared_yaxes=True) for trace in fig1.data: @@ -557,27 +562,22 @@ def content_dashboard(api, res_costs: dict, context_data: dict, settings: pd.Dat doublefig.update_layout(title_text="Cost distribution and details:") st.plotly_chart(doublefig, use_container_width=True) - create_infobox(context_data, settings) + create_infobox(context_data) st.write("Chosen settings:") - st.write(settings) + st.write(st.session_state["settings"]) st.write("res_cost") st.write(res_costs) -def content_market_scanning( - api: PtxboaAPI, res_costs: pd.DataFrame, settings: dict -) -> None: +def content_market_scanning(api: PtxboaAPI, res_costs: pd.DataFrame) -> None: """Create content for the "market scanning" sheet. Parameters ---------- api : :class:`~ptxboa.api.PtxboaAPI` an instance of the api class - settings : dict - settings from the streamlit app. An example can be obtained with the - return value from :func:`ptxboa_functions.create_sidebar`. res_costs : pd.DataFrame Results. """ @@ -594,12 +594,14 @@ def content_market_scanning( ) # get input data: - input_data = api.get_input_data(settings["scenario"]) + input_data = api.get_input_data(st.session_state["settings"]["scenario"]) # filter shipping and pipeline distances: distances = input_data.loc[ (input_data["parameter_code"].isin(["shipping distance", "pipeline distance"])) - & (input_data["target_country_code"] == settings["country"]), + & ( + input_data["target_country_code"] == st.session_state["settings"]["country"] + ), ["source_region_code", "parameter_code", "value"], ] distances = distances.pivot_table( @@ -615,7 +617,7 @@ def content_market_scanning( df_plot = df_plot.merge(distances, left_index=True, right_index=True) # do not show subregions: - df_plot = remove_subregions(api, df_plot, settings) + df_plot = remove_subregions(api, df_plot) # create plot: [c1, c2] = st.columns([1, 5]) @@ -648,7 +650,7 @@ def content_market_scanning( st.dataframe(df_plot, use_container_width=True, column_config=column_config) -def remove_subregions(api: PtxboaAPI, df: pd.DataFrame, settings: dict): +def remove_subregions(api: PtxboaAPI, df: pd.DataFrame): """Remove subregions from a dataframe. Parameters @@ -670,17 +672,15 @@ def remove_subregions(api: PtxboaAPI, df: pd.DataFrame, settings: dict): ) # ensure that target country is not in list of regions: - if settings["country"] in region_list_without_subregions: - region_list_without_subregions.remove(settings["country"]) + if st.session_state["settings"]["country"] in region_list_without_subregions: + region_list_without_subregions.remove(st.session_state["settings"]["country"]) df = df.loc[region_list_without_subregions] return df -def content_compare_costs( - api: PtxboaAPI, res_costs: pd.DataFrame, settings: dict -) -> None: +def content_compare_costs(api: PtxboaAPI, res_costs: pd.DataFrame) -> None: """Create content for the "compare costs" sheet. Parameters @@ -703,9 +703,7 @@ def content_compare_costs( """ ) - def display_costs( - df_costs: pd.DataFrame, key: str, titlestring: str, settings: dict - ): + def display_costs(df_costs: pd.DataFrame, key: str, titlestring: str): """Display costs as table and bar chart.""" st.subheader(titlestring) c1, c2 = st.columns([1, 5]) @@ -740,24 +738,25 @@ def display_costs( with c2: # create graph: fig = create_bar_chart_costs( - df_res, settings, current_selection=settings[key] + df_res, + current_selection=st.session_state["settings"][key], ) st.plotly_chart(fig, use_container_width=True) with st.expander("**Data**"): column_config = config_number_columns( - df_res, format=f"%.1f {settings['output_unit']}" + df_res, format=f"%.1f {st.session_state['settings']['output_unit']}" ) st.dataframe(df_res, use_container_width=True, column_config=column_config) - res_costs_without_subregions = remove_subregions(api, res_costs, settings) - display_costs(res_costs_without_subregions, "region", "Costs by region:", settings) + res_costs_without_subregions = remove_subregions(api, res_costs) + display_costs(res_costs_without_subregions, "region", "Costs by region:") # Display costs by scenario: res_scenario = calculate_results_list( - api, settings, "scenario", user_data=st.session_state["user_changes_df"] + api, "scenario", user_data=st.session_state["user_changes_df"] ) - display_costs(res_scenario, "scenario", "Costs by data scenario:", settings) + display_costs(res_scenario, "scenario", "Costs by data scenario:") # Display costs by RE generation: # TODO: remove PV tracking manually, this needs to be fixed in data @@ -765,21 +764,16 @@ def display_costs( list_res_gen.remove("PV tracking") res_res_gen = calculate_results_list( api, - settings, "res_gen", parameter_list=list_res_gen, user_data=st.session_state["user_changes_df"], ) - display_costs( - res_res_gen, "res_gen", "Costs by renewable electricity source:", settings - ) + display_costs(res_res_gen, "res_gen", "Costs by renewable electricity source:") # TODO: display costs by chain -def content_deep_dive_countries( - api: PtxboaAPI, res_costs: pd.DataFrame, settings: dict -) -> None: +def content_deep_dive_countries(api: PtxboaAPI, res_costs: pd.DataFrame) -> None: """Create content for the "costs by region" sheet. Parameters @@ -815,7 +809,7 @@ def content_deep_dive_countries( # get input data: - input_data = api.get_input_data(settings["scenario"]) + input_data = api.get_input_data(st.session_state["settings"]["scenario"]) # filter data: # get list of subregions: @@ -875,16 +869,13 @@ def content_deep_dive_countries( st.plotly_chart(fig, use_container_width=True) -def content_input_data(api: PtxboaAPI, settings: dict) -> None: +def content_input_data(api: PtxboaAPI) -> None: """Create content for the "input data" sheet. Parameters ---------- api : :class:`~ptxboa.api.PtxboaAPI` an instance of the api class - settings : dict - settings from the streamlit app. An example can be obtained with the - return value from :func:`ptxboa_functions.create_sidebar`. Output ------ @@ -907,7 +898,8 @@ def content_input_data(api: PtxboaAPI, settings: dict) -> None: st.subheader("Region specific data:") # get input data: input_data = api.get_input_data( - settings["scenario"], user_data=st.session_state["user_changes_df"] + st.session_state["settings"]["scenario"], + user_data=st.session_state["user_changes_df"], ) # filter data: @@ -1160,14 +1152,16 @@ def register_user_changes( ) -def create_infobox(context_data: dict, settings: dict): +def create_infobox(context_data: dict): data = context_data["infobox"] - st.markdown(f"**Key information on {settings['country']}:**") - demand = data.at[settings["country"], "Projected H2 demand [2030]"] - info1 = data.at[settings["country"], "key_info_1"] - info2 = data.at[settings["country"], "key_info_2"] - info3 = data.at[settings["country"], "key_info_3"] - info4 = data.at[settings["country"], "key_info_4"] + st.markdown(f"**Key information on {st.session_state['settings']['country']}:**") + demand = data.at[ + st.session_state["settings"]["country"], "Projected H2 demand [2030]" + ] + info1 = data.at[st.session_state["settings"]["country"], "key_info_1"] + info2 = data.at[st.session_state["settings"]["country"], "key_info_2"] + info3 = data.at[st.session_state["settings"]["country"], "key_info_3"] + info4 = data.at[st.session_state["settings"]["country"], "key_info_4"] st.markdown(f"* Projected H2 demand in 2030: {demand}") def write_info(info): @@ -1205,7 +1199,9 @@ def import_context_data(): return cd -def create_fact_sheet_demand_country(context_data: dict, country_name: str): +def create_fact_sheet_demand_country(context_data: dict): + # select country: + country_name = st.session_state["settings"]["country"] with st.expander("What is this?"): st.markdown( """ @@ -1291,8 +1287,10 @@ def create_fact_sheet_demand_country(context_data: dict, country_name: str): st.markdown(f"*Source: {data['source_certification_info']}*") -def create_fact_sheet_supply_country(context_data: dict, country_name: str): +def create_fact_sheet_supply_country(context_data: dict): """Display information on a chosen supply country.""" + # select country: + country_name = st.session_state["settings"]["region"] df = context_data["supply"] data = df.loc[df["country_name"] == country_name].iloc[0].to_dict() diff --git a/ptxboa_streamlit.py b/ptxboa_streamlit.py index bbddec99..72ca181c 100644 --- a/ptxboa_streamlit.py +++ b/ptxboa_streamlit.py @@ -46,11 +46,11 @@ api = st.cache_resource(PtxboaAPI)() # create sidebar: -settings = pf.create_sidebar(api) +pf.create_sidebar(api) # calculate results: res_costs = pf.calculate_results_list( - api, settings, "region", user_data=st.session_state["user_changes_df"] + api, "region", user_data=st.session_state["user_changes_df"] ) # import context data: @@ -58,24 +58,24 @@ # dashboard: with t_dashboard: - pf.content_dashboard(api, res_costs, cd, settings) + pf.content_dashboard(api, res_costs, cd) with t_market_scanning: - pf.content_market_scanning(api, res_costs, settings) + pf.content_market_scanning(api, res_costs) with t_compare_costs: - pf.content_compare_costs(api, res_costs, settings) + pf.content_compare_costs(api, res_costs) with t_input_data: - pf.content_input_data(api, settings) + pf.content_input_data(api) with t_deep_dive_countries: - pf.content_deep_dive_countries(api, res_costs, settings) + pf.content_deep_dive_countries(api, res_costs) with t_country_fact_sheets: - pf.create_fact_sheet_demand_country(cd, settings["country"]) + pf.create_fact_sheet_demand_country(cd) st.divider() - pf.create_fact_sheet_supply_country(cd, settings["region"]) + pf.create_fact_sheet_supply_country(cd) with t_certification_schemes: pf.create_fact_sheet_certification_schemes(cd) diff --git a/tests/test_ptxboa_functions.py b/tests/test_ptxboa_functions.py index ec395552..bb68c783 100644 --- a/tests/test_ptxboa_functions.py +++ b/tests/test_ptxboa_functions.py @@ -5,6 +5,7 @@ import unittest import pandas as pd +import streamlit as st import app.ptxboa_functions as pf from ptxboa.api import PtxboaAPI @@ -19,7 +20,7 @@ class TestPtxboaFunctions(unittest.TestCase): def test_remove_subregions(self): """Test remove_subregions function.""" - settings = { + st.session_state["settings"] = { "region": "United Arab Emirates", "country": "Germany", "chain": "Methane (AEL)", @@ -37,7 +38,7 @@ def test_remove_subregions(self): # regions including subregions: 79 self.assertEqual(len(df_in), 79) - df_out = pf.remove_subregions(api, df_in, settings) + df_out = pf.remove_subregions(api, df_in) # output is dataframe: self.assertIsInstance(df_out, pd.DataFrame) @@ -52,14 +53,14 @@ def test_remove_subregions(self): # if target country is also a source region, it needs to be removed # from the source region list: - settings["country"] = "China" - df_out = pf.remove_subregions(api, df_in, settings) + st.session_state["settings"]["country"] = "China" + df_out = pf.remove_subregions(api, df_in) self.assertEqual(len(df_out), 33) self.assertFalse("China" in df_out["region_name"]) def test_calculate_results_list(self): """Test calculate_results_list function.""" - settings = { + st.session_state["settings"] = { "region": "United Arab Emirates", "country": "Germany", "chain": "Methane (AEL)", @@ -74,14 +75,11 @@ def test_calculate_results_list(self): api = PtxboaAPI() # old way of calculating results: - res_details = pf.calculate_results( - api, - settings, - ) + res_details = pf.calculate_results(api) res_costs = pf.aggregate_costs(res_details) # new way of calculating results: - res_by_region = pf.calculate_results_list(api, settings, "region") + res_by_region = pf.calculate_results_list(api, "region") # assert that both ways yield identical results: pd.testing.assert_frame_equal(res_costs, res_by_region) From ff1c76f3f8ce11f4d817b3e8b91f8a4b2fc93f0e Mon Sep 17 00:00:00 2001 From: Markus Haller Date: Wed, 15 Nov 2023 12:28:55 +0100 Subject: [PATCH 2/6] do not use ``st.session_state`` in ``remove_subregions`` to make it testable --- app/ptxboa_functions.py | 22 +++++++++++++++------- tests/test_ptxboa_functions.py | 8 ++++---- 2 files changed, 19 insertions(+), 11 deletions(-) diff --git a/app/ptxboa_functions.py b/app/ptxboa_functions.py index 36e7500a..2387a67a 100644 --- a/app/ptxboa_functions.py +++ b/app/ptxboa_functions.py @@ -330,7 +330,9 @@ def create_world_map(api: PtxboaAPI, res_costs: pd.DataFrame): ] # remove subregions from deep dive countries (otherwise colorscale is not correct) - res_costs = remove_subregions(api, res_costs) + res_costs = remove_subregions( + api, res_costs, st.session_state["settings"]["country"] + ) # Create custom hover text: custom_hover_data = res_costs.apply( @@ -617,7 +619,7 @@ def content_market_scanning(api: PtxboaAPI, res_costs: pd.DataFrame) -> None: df_plot = df_plot.merge(distances, left_index=True, right_index=True) # do not show subregions: - df_plot = remove_subregions(api, df_plot) + df_plot = remove_subregions(api, df_plot, st.session_state["settings"]["country"]) # create plot: [c1, c2] = st.columns([1, 5]) @@ -650,7 +652,7 @@ def content_market_scanning(api: PtxboaAPI, res_costs: pd.DataFrame) -> None: st.dataframe(df_plot, use_container_width=True, column_config=column_config) -def remove_subregions(api: PtxboaAPI, df: pd.DataFrame): +def remove_subregions(api: PtxboaAPI, df: pd.DataFrame, country_name: str): """Remove subregions from a dataframe. Parameters @@ -658,7 +660,11 @@ def remove_subregions(api: PtxboaAPI, df: pd.DataFrame): api : :class:`~ptxboa.api.PtxboaAPI` an instance of the api class - df : pandas DataFrame with list of regions as index. + df : pd.DataFrame + pandas DataFrame with list of regions as index. + + country_name : str + name of target country. Is removed from region list if it is also in there. Returns ------- @@ -672,8 +678,8 @@ def remove_subregions(api: PtxboaAPI, df: pd.DataFrame): ) # ensure that target country is not in list of regions: - if st.session_state["settings"]["country"] in region_list_without_subregions: - region_list_without_subregions.remove(st.session_state["settings"]["country"]) + if country_name in region_list_without_subregions: + region_list_without_subregions.remove(country_name) df = df.loc[region_list_without_subregions] @@ -749,7 +755,9 @@ def display_costs(df_costs: pd.DataFrame, key: str, titlestring: str): ) st.dataframe(df_res, use_container_width=True, column_config=column_config) - res_costs_without_subregions = remove_subregions(api, res_costs) + res_costs_without_subregions = remove_subregions( + api, res_costs, st.session_state["settings"]["country"] + ) display_costs(res_costs_without_subregions, "region", "Costs by region:") # Display costs by scenario: diff --git a/tests/test_ptxboa_functions.py b/tests/test_ptxboa_functions.py index bb68c783..6d99ace9 100644 --- a/tests/test_ptxboa_functions.py +++ b/tests/test_ptxboa_functions.py @@ -20,7 +20,7 @@ class TestPtxboaFunctions(unittest.TestCase): def test_remove_subregions(self): """Test remove_subregions function.""" - st.session_state["settings"] = { + settings = { "region": "United Arab Emirates", "country": "Germany", "chain": "Methane (AEL)", @@ -38,7 +38,7 @@ def test_remove_subregions(self): # regions including subregions: 79 self.assertEqual(len(df_in), 79) - df_out = pf.remove_subregions(api, df_in) + df_out = pf.remove_subregions(api, df_in, settings["country"]) # output is dataframe: self.assertIsInstance(df_out, pd.DataFrame) @@ -53,8 +53,8 @@ def test_remove_subregions(self): # if target country is also a source region, it needs to be removed # from the source region list: - st.session_state["settings"]["country"] = "China" - df_out = pf.remove_subregions(api, df_in) + settings["country"] = "China" + df_out = pf.remove_subregions(api, df_in, settings["country"]) self.assertEqual(len(df_out), 33) self.assertFalse("China" in df_out["region_name"]) From 257a38881185344506cd628b7bd43ede7cddd4a4 Mon Sep 17 00:00:00 2001 From: Markus Haller Date: Wed, 15 Nov 2023 12:31:56 +0100 Subject: [PATCH 3/6] remove obsolete function ``calculate_results`` and associated test --- app/ptxboa_functions.py | 32 -------------------------------- tests/test_ptxboa_functions.py | 27 --------------------------- 2 files changed, 59 deletions(-) diff --git a/app/ptxboa_functions.py b/app/ptxboa_functions.py index 2387a67a..7d5f5cf7 100644 --- a/app/ptxboa_functions.py +++ b/app/ptxboa_functions.py @@ -38,38 +38,6 @@ def calculate_results_single( return res -def calculate_results(api: PtxboaAPI, region_list: list = None) -> pd.DataFrame: - """Calculate results for source regions and one selected target country. - - TODO: This function will eventually be replaced by ``calculate_results_list()``. - - Parameters - ---------- - api : :class:`~ptxboa.api.PtxboaAPI` - an instance of the api class - region_list : list or None - The regions for which the results are calculated. If None, all regions - available in the API will be used. - - Returns - ------- - pd.DataFrame - same format as for :meth:`~ptxboa.api.PtxboaAPI.calculate()` - """ - res_list = [] - - if region_list is None: - region_list = api.get_dimension("region")["region_name"] - - for region in region_list: - settings2 = st.session_state["settings"].copy() - settings2["region"] = region - res_single = calculate_results_single(api, settings2) - res_list.append(res_single) - res = pd.concat(res_list) - return res - - def calculate_results_list( api: PtxboaAPI, parameter_to_change: str, diff --git a/tests/test_ptxboa_functions.py b/tests/test_ptxboa_functions.py index 6d99ace9..1ea64f51 100644 --- a/tests/test_ptxboa_functions.py +++ b/tests/test_ptxboa_functions.py @@ -5,7 +5,6 @@ import unittest import pandas as pd -import streamlit as st import app.ptxboa_functions as pf from ptxboa.api import PtxboaAPI @@ -57,29 +56,3 @@ def test_remove_subregions(self): df_out = pf.remove_subregions(api, df_in, settings["country"]) self.assertEqual(len(df_out), 33) self.assertFalse("China" in df_out["region_name"]) - - def test_calculate_results_list(self): - """Test calculate_results_list function.""" - st.session_state["settings"] = { - "region": "United Arab Emirates", - "country": "Germany", - "chain": "Methane (AEL)", - "res_gen": "PV tilted", - "scenario": "2040 (medium)", - "secproc_co2": "Direct Air Capture", - "secproc_water": "Sea Water desalination", - "transport": "Ship", - "ship_own_fuel": False, - "output_unit": "USD/t", - } - api = PtxboaAPI() - - # old way of calculating results: - res_details = pf.calculate_results(api) - res_costs = pf.aggregate_costs(res_details) - - # new way of calculating results: - res_by_region = pf.calculate_results_list(api, "region") - - # assert that both ways yield identical results: - pd.testing.assert_frame_equal(res_costs, res_by_region) From dd86f9b638bdb1971ddf38af0eda962913d8f985 Mon Sep 17 00:00:00 2001 From: Markus Haller Date: Wed, 15 Nov 2023 15:20:10 +0100 Subject: [PATCH 4/6] wrote settings directliy into st.session_state https://github.com/agoenergy/ptx-boa/pull/111#discussion_r1394136506 --- app/ptxboa_functions.py | 99 ++++++++++++++++++++--------------------- 1 file changed, 48 insertions(+), 51 deletions(-) diff --git a/app/ptxboa_functions.py b/app/ptxboa_functions.py index 7d5f5cf7..e939e4ff 100644 --- a/app/ptxboa_functions.py +++ b/app/ptxboa_functions.py @@ -33,7 +33,19 @@ def calculate_results_single( pd.DataFrame same format as for :meth:`~ptxboa.api.PtxboaAPI.calculate()` """ - res = _api.calculate(user_data=user_data, **settings) + res = _api.calculate( + user_data=user_data, + chain=settings["chain"], + country=settings["country"], + output_unit=settings["output_unit"], + region=settings["region"], + res_gen=settings["res_gen"], + scenario=settings["scenario"], + secproc_co2=settings["secproc_co2"], + secproc_water=settings["secproc_water"], + ship_own_fuel=settings["ship_own_fuel"], + transport=settings["transport"], + ) return res @@ -109,9 +121,6 @@ def aggregate_costs(res_details: pd.DataFrame) -> pd.DataFrame: # Settings: def create_sidebar(api: PtxboaAPI): - if "settings" not in st.session_state: - st.session_state["settings"] = {} - st.sidebar.subheader("Main settings:") include_subregions = False if include_subregions: @@ -123,7 +132,7 @@ def create_sidebar(api: PtxboaAPI): .index ) - st.session_state["settings"]["region"] = st.sidebar.selectbox( + st.session_state["region"] = st.sidebar.selectbox( "Supply country / region:", region_list, help=( @@ -141,7 +150,7 @@ def create_sidebar(api: PtxboaAPI): "if you want to chose one of these subregions as a supply region. " ), ) - st.session_state["settings"]["country"] = st.sidebar.selectbox( + st.session_state["country"] = st.sidebar.selectbox( "Demand country:", api.get_dimension("country").index, help=( @@ -186,13 +195,13 @@ def create_sidebar(api: PtxboaAPI): else: use_reconversion = False - st.session_state["settings"]["chain"] = f"{product} ({ely})" + st.session_state["chain"] = f"{product} ({ely})" if use_reconversion: - st.session_state["settings"][ + st.session_state[ "chain" ] = f"{st.session_state['settings']['chain']} + reconv. to H2" - st.session_state["settings"]["res_gen"] = st.sidebar.selectbox( + st.session_state["res_gen"] = st.sidebar.selectbox( "Renewable electricity source (for selected supply region):", api.get_dimension("res_gen").index, help=( @@ -226,33 +235,33 @@ def create_sidebar(api: PtxboaAPI): ), horizontal=True, ) - st.session_state["settings"]["scenario"] = f"{data_year} ({cost_scenario})" + st.session_state["scenario"] = f"{data_year} ({cost_scenario})" st.sidebar.subheader("Additional settings:") - st.session_state["settings"]["secproc_co2"] = st.sidebar.radio( + st.session_state["secproc_co2"] = st.sidebar.radio( "Carbon source:", api.get_dimension("secproc_co2").index, horizontal=True, help="Help text", ) - st.session_state["settings"]["secproc_water"] = st.sidebar.radio( + st.session_state["secproc_water"] = st.sidebar.radio( "Water source:", api.get_dimension("secproc_water").index, horizontal=True, help="Help text", ) - st.session_state["settings"]["transport"] = st.sidebar.radio( + st.session_state["transport"] = st.sidebar.radio( "Mode of transportation (for selected supply country):", api.get_dimension("transport").index, horizontal=True, help="Help text", ) - if st.session_state["settings"]["transport"] == "Ship": - st.session_state["settings"]["ship_own_fuel"] = st.sidebar.toggle( + if st.session_state["transport"] == "Ship": + st.session_state["ship_own_fuel"] = st.sidebar.toggle( "For shipping option: Use the product as own fuel?", help="Help text", ) - st.session_state["settings"]["output_unit"] = st.sidebar.radio( + st.session_state["output_unit"] = st.sidebar.radio( "Unit for delivered costs:", api.get_dimension("output_unit").index, horizontal=True, @@ -298,9 +307,7 @@ def create_world_map(api: PtxboaAPI, res_costs: pd.DataFrame): ] # remove subregions from deep dive countries (otherwise colorscale is not correct) - res_costs = remove_subregions( - api, res_costs, st.session_state["settings"]["country"] - ) + res_costs = remove_subregions(api, res_costs, st.session_state["country"]) # Create custom hover text: custom_hover_data = res_costs.apply( @@ -346,9 +353,7 @@ def create_world_map(api: PtxboaAPI, res_costs: pd.DataFrame): ) fig.update_layout( - coloraxis_colorbar={ - "title": st.session_state["settings"]["output_unit"] - }, # colorbar + coloraxis_colorbar={"title": st.session_state["output_unit"]}, # colorbar height=600, # height of figure margin={"t": 20, "b": 20, "l": 20, "r": 20}, # reduce margin around figure ) @@ -418,7 +423,7 @@ def create_bar_chart_costs(res_costs: pd.DataFrame, current_selection: str = Non ay=-50, ) fig.update_layout( - yaxis_title=st.session_state["settings"]["output_unit"], + yaxis_title=st.session_state["output_unit"], ) return fig @@ -441,7 +446,7 @@ def create_box_plot(res_costs: pd.DataFrame): fig = go.Figure() # Specify the row index of the data point you want to highlight - highlighted_row_index = st.session_state["settings"]["region"] + highlighted_row_index = st.session_state["region"] # Extract the value from the specified row and column if highlighted_row_index: @@ -468,7 +473,7 @@ def create_box_plot(res_costs: pd.DataFrame): fig.update_layout( title="Cost distribution for all supply countries", xaxis={"title": ""}, - yaxis={"title": st.session_state["settings"]["output_unit"]}, + yaxis={"title": st.session_state["output_unit"]}, height=500, ) @@ -477,9 +482,7 @@ def create_box_plot(res_costs: pd.DataFrame): def create_scatter_plot(df_res, settings: dict): df_res["Country"] = "Other countries" - df_res.at[st.session_state["settings"]["region"], "Country"] = st.session_state[ - "settings" - ]["region"] + df_res.at[st.session_state["region"], "Country"] = st.session_state["region"] fig = px.scatter( df_res, @@ -516,9 +519,7 @@ def content_dashboard(api, res_costs: dict, context_data: dict): with c_2: # create box plot and bar plot: fig1 = create_box_plot(res_costs) - filtered_data = res_costs[ - res_costs.index == st.session_state["settings"]["region"] - ] + filtered_data = res_costs[res_costs.index == st.session_state["region"]] fig2 = create_bar_chart_costs(filtered_data) doublefig = make_subplots(rows=1, cols=2, shared_yaxes=True) @@ -535,7 +536,7 @@ def content_dashboard(api, res_costs: dict, context_data: dict): create_infobox(context_data) st.write("Chosen settings:") - st.write(st.session_state["settings"]) + st.write(st.session_state) st.write("res_cost") st.write(res_costs) @@ -564,14 +565,12 @@ def content_market_scanning(api: PtxboaAPI, res_costs: pd.DataFrame) -> None: ) # get input data: - input_data = api.get_input_data(st.session_state["settings"]["scenario"]) + input_data = api.get_input_data(st.session_state["scenario"]) # filter shipping and pipeline distances: distances = input_data.loc[ (input_data["parameter_code"].isin(["shipping distance", "pipeline distance"])) - & ( - input_data["target_country_code"] == st.session_state["settings"]["country"] - ), + & (input_data["target_country_code"] == st.session_state["country"]), ["source_region_code", "parameter_code", "value"], ] distances = distances.pivot_table( @@ -587,9 +586,9 @@ def content_market_scanning(api: PtxboaAPI, res_costs: pd.DataFrame) -> None: df_plot = df_plot.merge(distances, left_index=True, right_index=True) # do not show subregions: - df_plot = remove_subregions(api, df_plot, st.session_state["settings"]["country"]) + df_plot = remove_subregions(api, df_plot, st.session_state["country"]) - # create plot: + # create plot:st.session_state [c1, c2] = st.columns([1, 5]) with c1: # select which distance to show: @@ -713,7 +712,7 @@ def display_costs(df_costs: pd.DataFrame, key: str, titlestring: str): # create graph: fig = create_bar_chart_costs( df_res, - current_selection=st.session_state["settings"][key], + current_selection=st.session_state[key], ) st.plotly_chart(fig, use_container_width=True) @@ -724,7 +723,7 @@ def display_costs(df_costs: pd.DataFrame, key: str, titlestring: str): st.dataframe(df_res, use_container_width=True, column_config=column_config) res_costs_without_subregions = remove_subregions( - api, res_costs, st.session_state["settings"]["country"] + api, res_costs, st.session_state["country"] ) display_costs(res_costs_without_subregions, "region", "Costs by region:") @@ -785,7 +784,7 @@ def content_deep_dive_countries(api: PtxboaAPI, res_costs: pd.DataFrame) -> None # get input data: - input_data = api.get_input_data(st.session_state["settings"]["scenario"]) + input_data = api.get_input_data(st.session_state["scenario"]) # filter data: # get list of subregions: @@ -874,7 +873,7 @@ def content_input_data(api: PtxboaAPI) -> None: st.subheader("Region specific data:") # get input data: input_data = api.get_input_data( - st.session_state["settings"]["scenario"], + st.session_state["scenario"], user_data=st.session_state["user_changes_df"], ) @@ -1131,13 +1130,11 @@ def register_user_changes( def create_infobox(context_data: dict): data = context_data["infobox"] st.markdown(f"**Key information on {st.session_state['settings']['country']}:**") - demand = data.at[ - st.session_state["settings"]["country"], "Projected H2 demand [2030]" - ] - info1 = data.at[st.session_state["settings"]["country"], "key_info_1"] - info2 = data.at[st.session_state["settings"]["country"], "key_info_2"] - info3 = data.at[st.session_state["settings"]["country"], "key_info_3"] - info4 = data.at[st.session_state["settings"]["country"], "key_info_4"] + demand = data.at[st.session_state["country"], "Projected H2 demand [2030]"] + info1 = data.at[st.session_state["country"], "key_info_1"] + info2 = data.at[st.session_state["country"], "key_info_2"] + info3 = data.at[st.session_state["country"], "key_info_3"] + info4 = data.at[st.session_state["country"], "key_info_4"] st.markdown(f"* Projected H2 demand in 2030: {demand}") def write_info(info): @@ -1177,7 +1174,7 @@ def import_context_data(): def create_fact_sheet_demand_country(context_data: dict): # select country: - country_name = st.session_state["settings"]["country"] + country_name = st.session_state["country"] with st.expander("What is this?"): st.markdown( """ @@ -1266,7 +1263,7 @@ def create_fact_sheet_demand_country(context_data: dict): def create_fact_sheet_supply_country(context_data: dict): """Display information on a chosen supply country.""" # select country: - country_name = st.session_state["settings"]["region"] + country_name = st.session_state["region"] df = context_data["supply"] data = df.loc[df["country_name"] == country_name].iloc[0].to_dict() From 8d80a971b0a963e7296772e1330e34f39f3eb786 Mon Sep 17 00:00:00 2001 From: Markus Haller Date: Wed, 15 Nov 2023 15:59:18 +0100 Subject: [PATCH 5/6] fixed last occurences of st.session_state[settings] --- app/ptxboa_functions.py | 49 +++++++++++++++++++++-------------------- 1 file changed, 25 insertions(+), 24 deletions(-) diff --git a/app/ptxboa_functions.py b/app/ptxboa_functions.py index e939e4ff..1a7cef7e 100644 --- a/app/ptxboa_functions.py +++ b/app/ptxboa_functions.py @@ -33,19 +33,7 @@ def calculate_results_single( pd.DataFrame same format as for :meth:`~ptxboa.api.PtxboaAPI.calculate()` """ - res = _api.calculate( - user_data=user_data, - chain=settings["chain"], - country=settings["country"], - output_unit=settings["output_unit"], - region=settings["region"], - res_gen=settings["res_gen"], - scenario=settings["scenario"], - secproc_co2=settings["secproc_co2"], - secproc_water=settings["secproc_water"], - ship_own_fuel=settings["ship_own_fuel"], - transport=settings["transport"], - ) + res = _api.calculate(user_data=user_data, **settings) return res @@ -81,8 +69,24 @@ def calculate_results_list( if parameter_list is None: parameter_list = api.get_dimension(parameter_to_change).index + # copy settings from session_state: + settings = {} + for key in [ + "chain", + "country", + "output_unit", + "region", + "res_gen", + "scenario", + "secproc_co2", + "secproc_water", + "ship_own_fuel", + "transport", + ]: + settings[key] = st.session_state[key] + for parameter in parameter_list: - settings2 = st.session_state["settings"].copy() + settings2 = settings.copy() settings2[parameter_to_change] = parameter res_single = calculate_results_single(api, settings2, user_data=user_data) res_list.append(res_single) @@ -197,9 +201,7 @@ def create_sidebar(api: PtxboaAPI): st.session_state["chain"] = f"{product} ({ely})" if use_reconversion: - st.session_state[ - "chain" - ] = f"{st.session_state['settings']['chain']} + reconv. to H2" + st.session_state["chain"] = f"{st.session_state['chain']} + reconv. to H2" st.session_state["res_gen"] = st.sidebar.selectbox( "Renewable electricity source (for selected supply region):", @@ -296,8 +298,8 @@ def create_world_map(api: PtxboaAPI, res_costs: pd.DataFrame): # define title: title_string = ( f"{parameter_to_show_on_map} cost of exporting" - f"{st.session_state['settings']['chain']} to " - f"{st.session_state['settings']['country']}" + f"{st.session_state['chain']} to " + f"{st.session_state['country']}" ) # define color scale: color_scale = [ @@ -314,14 +316,13 @@ def create_world_map(api: PtxboaAPI, res_costs: pd.DataFrame): lambda x: f"{x.name}

" + "
".join( [ - f"{col}: {x[col]:.1f}" - f"{st.session_state['settings']['output_unit']}" + f"{col}: {x[col]:.1f}" f"{st.session_state['output_unit']}" for col in res_costs.columns[:-1] ] + [ f"──────────
{res_costs.columns[-1]}: " f"{x[res_costs.columns[-1]]:.1f}" - f"{st.session_state['settings']['output_unit']}" + f"{st.session_state['output_unit']}" ] ), axis=1, @@ -718,7 +719,7 @@ def display_costs(df_costs: pd.DataFrame, key: str, titlestring: str): with st.expander("**Data**"): column_config = config_number_columns( - df_res, format=f"%.1f {st.session_state['settings']['output_unit']}" + df_res, format=f"%.1f {st.session_state['output_unit']}" ) st.dataframe(df_res, use_container_width=True, column_config=column_config) @@ -1129,7 +1130,7 @@ def register_user_changes( def create_infobox(context_data: dict): data = context_data["infobox"] - st.markdown(f"**Key information on {st.session_state['settings']['country']}:**") + st.markdown(f"**Key information on {st.session_state['country']}:**") demand = data.at[st.session_state["country"], "Projected H2 demand [2030]"] info1 = data.at[st.session_state["country"], "key_info_1"] info2 = data.at[st.session_state["country"], "key_info_2"] From 0a5004a3b25bc3267fe81400de15d3dd838c8e92 Mon Sep 17 00:00:00 2001 From: Markus Haller Date: Wed, 15 Nov 2023 16:01:28 +0100 Subject: [PATCH 6/6] removed debugging output --- app/ptxboa_functions.py | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/app/ptxboa_functions.py b/app/ptxboa_functions.py index 1a7cef7e..c583a087 100644 --- a/app/ptxboa_functions.py +++ b/app/ptxboa_functions.py @@ -536,12 +536,6 @@ def content_dashboard(api, res_costs: dict, context_data: dict): create_infobox(context_data) - st.write("Chosen settings:") - st.write(st.session_state) - - st.write("res_cost") - st.write(res_costs) - def content_market_scanning(api: PtxboaAPI, res_costs: pd.DataFrame) -> None: """Create content for the "market scanning" sheet. @@ -995,11 +989,6 @@ def content_input_data(api: PtxboaAPI) -> None: # If there are user changes, display them: display_user_changes() - st.write("**Debug: Session state**") - st.write( - st.session_state - ) # FIXME this is debugging output, remove when it is not needed anymore - def reset_user_changes(): """Reset all user changes."""