Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor app functions and tabs #113

Merged
merged 10 commits into from
Nov 16, 2023
30 changes: 30 additions & 0 deletions app/context_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# -*- coding: utf-8 -*-
"""Module for loading context data."""
import pandas as pd
import streamlit as st


@st.cache_data()
def load_context_data():
"""Import context data from excel file."""
filename = "data/context_data.xlsx"
cd = {}
cd["demand_countries"] = pd.read_excel(
filename, sheet_name="demand_countries", skiprows=1
)
cd["certification_schemes_countries"] = pd.read_excel(
filename, sheet_name="certification_schemes_countries"
)
cd["certification_schemes"] = pd.read_excel(
filename, sheet_name="certification_schemes", skiprows=1
)
cd["sustainability"] = pd.read_excel(filename, sheet_name="sustainability")
cd["supply"] = pd.read_excel(filename, sheet_name="supply", skiprows=1)
cd["literature"] = pd.read_excel(filename, sheet_name="literature")
cd["infobox"] = pd.read_excel(
filename,
sheet_name="infobox",
usecols="A:F",
skiprows=1,
).set_index("country_name")
return cd
216 changes: 216 additions & 0 deletions app/plot_functions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,216 @@
# -*- coding: utf-8 -*-
"""Functions for plotting input data and results (cost_data)."""
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import streamlit as st

from app.ptxboa_functions import remove_subregions
from ptxboa.api import PtxboaAPI


def create_world_map(api: PtxboaAPI, res_costs: pd.DataFrame):
"""Create world map."""
parameter_to_show_on_map = "Total"

# define title:
title_string = (
f"{parameter_to_show_on_map} cost of exporting"
f"{st.session_state['chain']} to "
f"{st.session_state['country']}"
)
# define color scale:
color_scale = [
(0, st.session_state["colors"][0]), # Starting color at the minimum data value
(0.5, st.session_state["colors"][6]),
(1, st.session_state["colors"][9]), # Ending color at the maximum data value
]

# remove subregions from deep dive countries (otherwise colorscale is not correct)
res_costs = remove_subregions(api, res_costs, st.session_state["country"])

# Create custom hover text:
custom_hover_data = res_costs.apply(
lambda x: f"<b>{x.name}</b><br><br>"
+ "<br>".join(
[
f"<b>{col}</b>: {x[col]:.1f}" f"{st.session_state['output_unit']}"
for col in res_costs.columns[:-1]
]
+ [
f"──────────<br><b>{res_costs.columns[-1]}</b>: "
f"{x[res_costs.columns[-1]]:.1f}"
f"{st.session_state['output_unit']}"
]
),
axis=1,
)

# Create a choropleth world map:
fig = px.choropleth(
locations=res_costs.index, # List of country codes or names
locationmode="country names", # Use country names as locations
color=res_costs[parameter_to_show_on_map], # Color values for the countries
custom_data=[custom_hover_data], # Pass custom data for hover information
color_continuous_scale=color_scale, # Choose a color scale
title=title_string, # set title
)

# update layout:
fig.update_geos(
showcountries=True, # Show country borders
showcoastlines=True, # Show coastlines
countrycolor="black", # Set default border color for other countries
countrywidth=0.2, # Set border width
coastlinewidth=0.2, # coastline width
coastlinecolor="black", # coastline color
showland=True, # show land areas
landcolor="#f3f4f5", # Set land color to light gray
oceancolor="#e3e4ea", # Optionally, set ocean color slightly darker gray
showocean=True, # show ocean areas
framewidth=0.2, # width of frame around map
)

fig.update_layout(
coloraxis_colorbar={"title": st.session_state["output_unit"]}, # colorbar
height=600, # height of figure
margin={"t": 20, "b": 20, "l": 20, "r": 20}, # reduce margin around figure
)

# Set the hover template to use the custom data
fig.update_traces(hovertemplate="%{customdata}<extra></extra>") # Custom data

# Display the map:
st.plotly_chart(fig, use_container_width=True)
return


def create_bar_chart_costs(res_costs: pd.DataFrame, current_selection: str = None):
"""Create bar plot for costs by components, and dots for total costs.

Parameters
----------
res_costs : pd.DataFrame
data for plotting
settings : dict
settings dictionary, like output from create_sidebar()
current_selection : str
bar to highlight with an arrow. must be an element of res_costs.index

Output
------
fig : plotly.graph_objects.Figure
Figure object
"""
if res_costs.empty: # nodata to plot (FIXME: migth not be required later)
return go.Figure()

fig = px.bar(
res_costs,
x=res_costs.index,
y=res_costs.columns[:-1],
height=500,
color_discrete_sequence=st.session_state["colors"],
)

# Add the dot markers for the "total" column using plotly.graph_objects
scatter_trace = go.Scatter(
x=res_costs.index,
y=res_costs["Total"],
mode="markers+text", # Display markers and text
marker={"size": 10, "color": "black"},
name="Total",
text=res_costs["Total"].apply(
lambda x: f"{x:.2f}"
), # Use 'total' column values as text labels
textposition="top center", # Position of the text label above the marker
)

fig.add_trace(scatter_trace)

# add highlight for current selection:
if current_selection is not None and current_selection in res_costs.index:
fig.add_annotation(
x=current_selection,
y=1.2 * res_costs.at[current_selection, "Total"],
text="current selection",
showarrow=True,
arrowhead=2,
arrowsize=1,
arrowwidth=2,
ax=0,
ay=-50,
)
fig.update_layout(
yaxis_title=st.session_state["output_unit"],
)
return fig


def create_box_plot(res_costs: pd.DataFrame):
"""Create a subplot with one row and one column.

Parameters
----------
res_costs : pd.DataFrame
data for plotting
settings : dict
settings dictionary, like output from create_sidebar()

Output
------
fig : plotly.graph_objects.Figure
Figure object
"""
fig = go.Figure()

# Specify the row index of the data point you want to highlight
highlighted_row_index = st.session_state["region"]
# Extract the value from the specified row and column

if highlighted_row_index:
highlighted_value = res_costs.at[highlighted_row_index, "Total"]
else:
highlighted_value = 0

# Add the box plot to the subplot
fig.add_trace(go.Box(y=res_costs["Total"], name="Cost distribution"))

# Add a scatter marker for the highlighted data point
fig.add_trace(
go.Scatter(
x=["Cost distribution"],
y=[highlighted_value],
mode="markers",
marker={"size": 10, "color": "black"},
name=highlighted_row_index,
text=f"Value: {highlighted_value}", # Add a text label
)
)

# Customize the layout as needed
fig.update_layout(
title="Cost distribution for all supply countries",
xaxis={"title": ""},
yaxis={"title": st.session_state["output_unit"]},
height=500,
)

return fig


def create_scatter_plot(df_res, settings: dict):
df_res["Country"] = "Other countries"
df_res.at[st.session_state["region"], "Country"] = st.session_state["region"]

fig = px.scatter(
df_res,
y="Total",
x="tr_dst_sd",
color="Country",
text=df_res.index,
color_discrete_sequence=["blue", "red"],
)
fig.update_traces(texttemplate="%{text}", textposition="top center")
st.plotly_chart(fig)
st.write(df_res)
Loading