Skip to content

Commit

Permalink
merged develop and fixed conflicts
Browse files Browse the repository at this point in the history
  • Loading branch information
wingechr committed Nov 21, 2023
2 parents 3e5314f + 7335c97 commit 8ab1600
Show file tree
Hide file tree
Showing 47 changed files with 41,922 additions and 1,428 deletions.
37 changes: 37 additions & 0 deletions .github/workflows/test.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# This workflow will install Python dependencies, run tests and lint with a single version of Python
# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python

name: tests

on:
push:
branches: [ "main", "develop" ]
pull_request:
branches: [ "main", "develop" ]

permissions:
contents: read

jobs:

test:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ['3.10', '3.11']

steps:
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v3
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -r requirements-dev.txt
pip install flake8-bandit flake8-bugbear flake8-builtins flake8-comprehensions flake8-docstrings flake8-eradicate flake8-isort
- name: Test with pytest
run: |
# run tests with pytest (unittest tests will be collected as well)
pytest -vv
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ repos:

# ==================== BLACK ====================
- repo: https://github.com/psf/black
rev: 23.10.1
rev: 23.11.0
hooks:
- id: black
args: ["--line-length", "88"]
Expand Down
6 changes: 6 additions & 0 deletions .streamlit/config.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
[theme]
primaryColor="#1A667B"
backgroundColor="#f8f8f8"
secondaryBackgroundColor="#F0F2F6"
textColor="#262730"
font="sans serif"
30 changes: 30 additions & 0 deletions app/context_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# -*- coding: utf-8 -*-
"""Module for loading context data."""
import pandas as pd
import streamlit as st


@st.cache_data()
def load_context_data():
"""Import context data from excel file."""
filename = "data/context_data.xlsx"
cd = {}
cd["demand_countries"] = pd.read_excel(
filename, sheet_name="demand_countries", skiprows=1
)
cd["certification_schemes_countries"] = pd.read_excel(
filename, sheet_name="certification_schemes_countries"
)
cd["certification_schemes"] = pd.read_excel(
filename, sheet_name="certification_schemes", skiprows=1
)
cd["sustainability"] = pd.read_excel(filename, sheet_name="sustainability")
cd["supply"] = pd.read_excel(filename, sheet_name="supply", skiprows=1)
cd["literature"] = pd.read_excel(filename, sheet_name="literature")
cd["infobox"] = pd.read_excel(
filename,
sheet_name="infobox",
usecols="A:F",
skiprows=1,
).set_index("country_name")
return cd
298 changes: 298 additions & 0 deletions app/plot_functions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,298 @@
# -*- coding: utf-8 -*-
"""Functions for plotting input data and results (cost_data)."""
import json
from pathlib import Path
from typing import Literal

import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import streamlit as st

from app.ptxboa_functions import remove_subregions
from ptxboa.api import PtxboaAPI


def plot_costs_on_map(
api: PtxboaAPI,
res_costs: pd.DataFrame,
scope: Literal["world", "Argentina", "Morocco", "South Africa"] = "world",
cost_component: str = "Total",
) -> go.Figure:
"""
Create map for cost result data.
Parameters
----------
api : PtxboaAPI
res_costs : pd.DataFrame
result obtained with :func:`ptxboa_functions.calculate_results_list`
scope : Literal["world", "Argentina", "Morocco", "South Africa"], optional
either world or a deep dive country, by default "world"
cost_component : str, optional
one of the columns in 'res_costs', by default "Total"
Returns
-------
go.Figure
"""
# define title:
title_string = (
f"{cost_component} cost of exporting "
f"{st.session_state['chain']} to "
f"{st.session_state['country']}"
)
# define color scale:
color_scale = [
(0, st.session_state["colors"][0]), # Starting color at the minimum data value
(0.5, st.session_state["colors"][6]),
(1, st.session_state["colors"][9]), # Ending color at the maximum data value
]

if scope == "world":
# remove subregions from deep dive countries (otherwise colorscale is not
# correct)
res_costs = remove_subregions(api, res_costs, st.session_state["country"])
else:
res_costs = res_costs.copy().loc[
res_costs.index.str.startswith(f"{scope} ("), :
]

# Create custom hover text:
custom_hover_data = res_costs.apply(
lambda x: f"<b>{x.name}</b><br><br>"
+ "<br>".join(
[
f"<b>{col}</b>: {x[col]:.1f}" f"{st.session_state['output_unit']}"
for col in res_costs.columns[:-1]
]
+ [
f"──────────<br><b>{res_costs.columns[-1]}</b>: "
f"{x[res_costs.columns[-1]]:.1f}"
f"{st.session_state['output_unit']}"
]
),
axis=1,
)

if scope == "world":
# Create a choropleth world map:
fig = px.choropleth(
locations=res_costs.index, # List of country codes or names
locationmode="country names", # Use country names as locations
color=res_costs[cost_component], # Color values for the countries
custom_data=[custom_hover_data], # Pass custom data for hover information
color_continuous_scale=color_scale, # Choose a color scale
title=title_string,
)
else:
fig = _choropleth_map_deep_dive_country(
api,
res_costs,
scope,
color=cost_component,
custom_data=[custom_hover_data],
color_continuous_scale=color_scale,
title=title_string,
)
fig.update_geos(
fitbounds="locations",
visible=True,
)

# update layout:
fig.update_geos(
showcountries=True, # Show country borders
showcoastlines=True, # Show coastlines
countrycolor="black", # Set default border color for other countries
countrywidth=0.2, # Set border width
coastlinewidth=0.2, # coastline width
coastlinecolor="black", # coastline color
showland=True, # show land areas
landcolor="#f3f4f5", # Set land color to light gray
oceancolor="#e3e4ea", # Optionally, set ocean color slightly darker gray
showocean=True, # show ocean areas
framewidth=0.2, # width of frame around map
)

fig.update_layout(
coloraxis_colorbar={
"title": st.session_state["output_unit"],
"len": 0.5,
}, # colorbar
margin={"t": 20, "b": 20, "l": 20, "r": 20}, # reduce margin around figure
)

# Set the hover template to use the custom data
fig.update_traces(hovertemplate="%{customdata}<extra></extra>") # Custom data

return fig


def _choropleth_map_deep_dive_country(
api,
res_costs_subset,
scope_country,
color,
custom_data,
color_continuous_scale,
title,
):
# get dataframe with info about iso 3166-2 codes and map them to res_costs
scope_info = api.get_dimension("region").loc[
api.get_dimension("region")["region_name"].str.startswith(f"{scope_country} (")
]
res_costs_subset["iso3166_code"] = res_costs_subset.index.map(
pd.Series(scope_info["iso3166_code"], index=scope_info["region_name"])
)

geojson_file = (
Path(__file__).parent.parent
/ "data"
/ f"{scope_country.lower().replace(' ', '_')}_subregions.geojson"
)
with geojson_file.open("r", encoding="utf-8") as f:
subregion_shapes = json.load(f)

fig = px.choropleth(
locations=res_costs_subset["iso3166_code"],
featureidkey="properties.iso_3166_2",
color=res_costs_subset[color],
geojson=subregion_shapes,
custom_data=custom_data,
color_continuous_scale=color_continuous_scale,
title=title,
)
return fig


def create_bar_chart_costs(res_costs: pd.DataFrame, current_selection: str = None):
"""Create bar plot for costs by components, and dots for total costs.
Parameters
----------
res_costs : pd.DataFrame
data for plotting
settings : dict
settings dictionary, like output from create_sidebar()
current_selection : str
bar to highlight with an arrow. must be an element of res_costs.index
Output
------
fig : plotly.graph_objects.Figure
Figure object
"""
if res_costs.empty: # nodata to plot (FIXME: migth not be required later)
return go.Figure()

fig = px.bar(
res_costs,
x=res_costs.index,
y=res_costs.columns[:-1],
height=500,
color_discrete_sequence=st.session_state["colors"],
)

# Add the dot markers for the "total" column using plotly.graph_objects
scatter_trace = go.Scatter(
x=res_costs.index,
y=res_costs["Total"],
mode="markers+text", # Display markers and text
marker={"size": 10, "color": "black"},
name="Total",
text=res_costs["Total"].apply(
lambda x: f"{x:.2f}"
), # Use 'total' column values as text labels
textposition="top center", # Position of the text label above the marker
)

fig.add_trace(scatter_trace)

# add highlight for current selection:
if current_selection is not None and current_selection in res_costs.index:
fig.add_annotation(
x=current_selection,
y=1.2 * res_costs.at[current_selection, "Total"],
text="current selection",
showarrow=True,
arrowhead=2,
arrowsize=1,
arrowwidth=2,
ax=0,
ay=-50,
)
fig.update_layout(
yaxis_title=st.session_state["output_unit"],
)
return fig


def create_box_plot(res_costs: pd.DataFrame):
"""Create a subplot with one row and one column.
Parameters
----------
res_costs : pd.DataFrame
data for plotting
settings : dict
settings dictionary, like output from create_sidebar()
Output
------
fig : plotly.graph_objects.Figure
Figure object
"""
fig = go.Figure()

# Specify the row index of the data point you want to highlight
highlighted_row_index = st.session_state["region"]
# Extract the value from the specified row and column

if highlighted_row_index:
highlighted_value = res_costs.at[highlighted_row_index, "Total"]
else:
highlighted_value = 0

# Add the box plot to the subplot
fig.add_trace(go.Box(y=res_costs["Total"], name="Cost distribution"))

# Add a scatter marker for the highlighted data point
fig.add_trace(
go.Scatter(
x=["Cost distribution"],
y=[highlighted_value],
mode="markers",
marker={"size": 10, "color": "black"},
name=highlighted_row_index,
text=f"Value: {highlighted_value}", # Add a text label
)
)

# Customize the layout as needed
fig.update_layout(
title="Cost distribution for all supply countries",
xaxis={"title": ""},
yaxis={"title": st.session_state["output_unit"]},
height=500,
)

return fig


def create_scatter_plot(df_res, settings: dict):
df_res["Country"] = "Other countries"
df_res.at[st.session_state["region"], "Country"] = st.session_state["region"]

fig = px.scatter(
df_res,
y="Total",
x="tr_dst_sd",
color="Country",
text=df_res.index,
color_discrete_sequence=["blue", "red"],
)
fig.update_traces(texttemplate="%{text}", textposition="top center")
st.plotly_chart(fig)
st.write(df_res)
Loading

0 comments on commit 8ab1600

Please sign in to comment.