Skip to content

Commit

Permalink
Merge pull request #118 from agoenergy/feat/deep_dive_country_maps
Browse files Browse the repository at this point in the history
deep dive country maps
  • Loading branch information
markushal authored Nov 20, 2023
2 parents 1dfadb4 + 02725cc commit 6f80d17
Show file tree
Hide file tree
Showing 7 changed files with 261 additions and 104 deletions.
122 changes: 102 additions & 20 deletions app/plot_functions.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# -*- coding: utf-8 -*-
"""Functions for plotting input data and results (cost_data)."""
import json
from pathlib import Path
from typing import Literal

import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
Expand All @@ -9,13 +13,33 @@
from ptxboa.api import PtxboaAPI


def create_world_map(api: PtxboaAPI, res_costs: pd.DataFrame):
"""Create world map."""
parameter_to_show_on_map = "Total"
def plot_costs_on_map(
api: PtxboaAPI,
res_costs: pd.DataFrame,
scope: Literal["world", "Argentina", "Morocco", "South Africa"] = "world",
cost_component: str = "Total",
) -> go.Figure:
"""
Create map for cost result data.
Parameters
----------
api : PtxboaAPI
res_costs : pd.DataFrame
result obtained with :func:`ptxboa_functions.calculate_results_list`
scope : Literal["world", "Argentina", "Morocco", "South Africa"], optional
either world or a deep dive country, by default "world"
cost_component : str, optional
one of the columns in 'res_costs', by default "Total"
Returns
-------
go.Figure
"""
# define title:
title_string = (
f"{parameter_to_show_on_map} cost of exporting"
f"{cost_component} cost of exporting "
f"{st.session_state['chain']} to "
f"{st.session_state['country']}"
)
Expand All @@ -26,8 +50,14 @@ def create_world_map(api: PtxboaAPI, res_costs: pd.DataFrame):
(1, st.session_state["colors"][9]), # Ending color at the maximum data value
]

# remove subregions from deep dive countries (otherwise colorscale is not correct)
res_costs = remove_subregions(api, res_costs, st.session_state["country"])
if scope == "world":
# remove subregions from deep dive countries (otherwise colorscale is not
# correct)
res_costs = remove_subregions(api, res_costs, st.session_state["country"])
else:
res_costs = res_costs.copy().loc[
res_costs.index.str.startswith(f"{scope} ("), :
]

# Create custom hover text:
custom_hover_data = res_costs.apply(
Expand All @@ -46,15 +76,30 @@ def create_world_map(api: PtxboaAPI, res_costs: pd.DataFrame):
axis=1,
)

# Create a choropleth world map:
fig = px.choropleth(
locations=res_costs.index, # List of country codes or names
locationmode="country names", # Use country names as locations
color=res_costs[parameter_to_show_on_map], # Color values for the countries
custom_data=[custom_hover_data], # Pass custom data for hover information
color_continuous_scale=color_scale, # Choose a color scale
title=title_string, # set title
)
if scope == "world":
# Create a choropleth world map:
fig = px.choropleth(
locations=res_costs.index, # List of country codes or names
locationmode="country names", # Use country names as locations
color=res_costs[cost_component], # Color values for the countries
custom_data=[custom_hover_data], # Pass custom data for hover information
color_continuous_scale=color_scale, # Choose a color scale
title=title_string,
)
else:
fig = _choropleth_map_deep_dive_country(
api,
res_costs,
scope,
color=cost_component,
custom_data=[custom_hover_data],
color_continuous_scale=color_scale,
title=title_string,
)
fig.update_geos(
fitbounds="locations",
visible=True,
)

# update layout:
fig.update_geos(
Expand All @@ -72,17 +117,54 @@ def create_world_map(api: PtxboaAPI, res_costs: pd.DataFrame):
)

fig.update_layout(
coloraxis_colorbar={"title": st.session_state["output_unit"]}, # colorbar
height=600, # height of figure
coloraxis_colorbar={
"title": st.session_state["output_unit"],
"len": 0.5,
}, # colorbar
margin={"t": 20, "b": 20, "l": 20, "r": 20}, # reduce margin around figure
)

# Set the hover template to use the custom data
fig.update_traces(hovertemplate="%{customdata}<extra></extra>") # Custom data

# Display the map:
st.plotly_chart(fig, use_container_width=True)
return
return fig


def _choropleth_map_deep_dive_country(
api,
res_costs_subset,
scope_country,
color,
custom_data,
color_continuous_scale,
title,
):
# get dataframe with info about iso 3166-2 codes and map them to res_costs
scope_info = api.get_dimension("region").loc[
api.get_dimension("region")["region_name"].str.startswith(f"{scope_country} (")
]
res_costs_subset["iso3166_code"] = res_costs_subset.index.map(
pd.Series(scope_info["iso3166_code"], index=scope_info["region_name"])
)

geojson_file = (
Path(__file__).parent.parent
/ "data"
/ f"{scope_country.lower().replace(' ', '_')}_subregions.geojson"
)
with geojson_file.open("r", encoding="utf-8") as f:
subregion_shapes = json.load(f)

fig = px.choropleth(
locations=res_costs_subset["iso3166_code"],
featureidkey="properties.iso_3166_2",
color=res_costs_subset[color],
geojson=subregion_shapes,
custom_data=custom_data,
color_continuous_scale=color_continuous_scale,
title=title,
)
return fig


def create_bar_chart_costs(res_costs: pd.DataFrame, current_selection: str = None):
Expand Down
11 changes: 9 additions & 2 deletions app/tab_dashboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,11 @@
import streamlit as st
from plotly.subplots import make_subplots

from app.plot_functions import create_bar_chart_costs, create_box_plot, create_world_map
from app.plot_functions import (
create_bar_chart_costs,
create_box_plot,
plot_costs_on_map,
)


def _create_infobox(context_data: dict):
Expand Down Expand Up @@ -43,7 +47,10 @@ def content_dashboard(api, res_costs: dict, context_data: dict):
c_1, c_2 = st.columns([2, 1])

with c_1:
create_world_map(api, res_costs)
fig_map = plot_costs_on_map(
api, res_costs, scope="world", cost_component="Total"
)
st.plotly_chart(fig_map, use_container_width=True)

with c_2:
# create box plot and bar plot:
Expand Down
6 changes: 4 additions & 2 deletions app/tab_deep_dive_countries.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import plotly.express as px
import streamlit as st

from app.plot_functions import plot_costs_on_map
from app.ptxboa_functions import display_and_edit_data_table
from ptxboa.api import PtxboaAPI

Expand Down Expand Up @@ -36,12 +37,13 @@ def content_deep_dive_countries(api: PtxboaAPI, res_costs: pd.DataFrame) -> None
"""
)

st.markdown("TODO: add country map")

ddc = st.radio(
"Select country:", ["Argentina", "Morocco", "South Africa"], horizontal=True
)

fig_map = plot_costs_on_map(api, res_costs, scope=ddc, cost_component="Total")
st.plotly_chart(fig_map, use_container_width=True)

# get input data:

input_data = api.get_input_data(st.session_state["scenario"])
Expand Down
31 changes: 31 additions & 0 deletions data/argentina_subregions.geojson

Large diffs are not rendered by default.

Loading

0 comments on commit 6f80d17

Please sign in to comment.