Merge pull request #168 from nismod/feature/per_sample_outputs
Chunk analysis on sample
thomas-fred authored Dec 12, 2023
2 parents d97d2d6 + 28b4171 commit ed0fc14
Showing 57 changed files with 402 additions and 232 deletions.
3 changes: 2 additions & 1 deletion src/open_gira/geodesic.py
@@ -23,7 +23,8 @@ def forward_azimuth(
     Reference: https://www.movable-type.co.uk/scripts/latlong.html
     Args:
-        Δλ: Difference in longitudes (radians) φ1: Start latitudes (radians)
+        Δλ: Difference in longitudes (radians)
+        φ1: Start latitudes (radians)
         φ2: End latitudes (radians)
     Returns: Initial headings from start towards end points (radians)
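
The docstring above describes the standard initial-bearing formula from the cited movable-type reference. As a sanity check, a minimal numpy sketch of that formula (an illustration, not necessarily the repository's exact implementation):

import numpy as np

def forward_azimuth_sketch(d_lon, lat1, lat2):
    # standard initial-bearing formula: theta = atan2(sin dlon * cos lat2,
    #                                                 cos lat1 * sin lat2 - sin lat1 * cos lat2 * cos dlon)
    y = np.sin(d_lon) * np.cos(lat2)
    x = np.cos(lat1) * np.sin(lat2) - np.sin(lat1) * np.cos(lat2) * np.cos(d_lon)
    return np.arctan2(y, x)

# a point slightly due east on the equator lies at a heading of pi/2 (90 degrees)
assert np.isclose(forward_azimuth_sketch(np.radians(1.0), 0.0, 0.0), np.pi / 2)
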
32 binary files not shown.
39 changes: 16 additions & 23 deletions tests/integration/runner.py
@@ -172,26 +172,18 @@ def compare_files(self, generated_file: Path, expected_file: Path) -> None:
         """
 
         printerr(f">>> Compare files:\n{generated_file}\n{expected_file}")
-        if re.search(r"\.(geo)?parquet$", str(generated_file), re.IGNORECASE):
-            if re.search(r"\.geoparquet$", str(generated_file), re.IGNORECASE):
-                """
-                NOTE: This test will **fail** if the geoparquet file does not contain geography data columns.
-                This can happen where the convert_to_geoparquet job does not find any roads to write.
-                We leave this failure in because it is usually unintentional that you're testing with a
-                dataset where _none_ of the slices have road data, and these tests should be targeted at
-                slices that _do_ have road data.
-                """
-                read = gpd.read_parquet
-            elif re.search(r"\.parquet$", str(generated_file), re.IGNORECASE):
-                read = pd.read_parquet
-            else:
-                raise RuntimeError(f"couldn't identify read function for {generated_file}")
-
-            generated = read(generated_file)
-            expected = read(expected_file)
+        file_ext = str(generated_file).split(".")[-1]
+
+        # PARQUET
+        if any([ext == file_ext for ext in ("pq", "parq", "parquet")]):
+            generated = pd.read_parquet(generated_file)
+            expected = pd.read_parquet(expected_file)
             self.compare_dataframes(generated, expected)
+
+        # GEOPARQUET
+        elif any([ext == file_ext for ext in ("gpq", "geoparq", "geoparquet")]):
+            generated = gpd.read_parquet(generated_file)
+            expected = gpd.read_parquet(expected_file)
+            self.compare_dataframes(generated, expected)
 
         # JSON
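
A note on the new dispatch: the old regexes only matched full ".parquet"/".geoparquet" suffixes, whereas the tuple membership test also accepts the short "pq"/"gpq" extensions this PR's outputs use. A quick illustrative check (hypothetical paths):

file_ext = "results/exposure/EAE_admin-level-1.gpq".split(".")[-1]
assert file_ext == "gpq"  # taken by the GEOPARQUET branch
assert "0_length_m_by_event.pq".split(".")[-1] in ("pq", "parq", "parquet")  # PARQUET branch
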
@@ -283,21 +275,22 @@ def compare_dataframes(generated: pd.DataFrame, expected: pd.DataFrame) -> None:
             printerr(f"{col=} {unequal_only_where_null=}")
 
             # let's try and find failing rows by converting to str
-            MAX_FAILURES_TO_PRINT = 5
+            MAX_FAILURES_TO_PRINT = 20
             failures = 0
             for row in range(len(generated)):
                 gen_str = str(generated[col][row: row + 1].values)
                 exp_str = str(expected[col][row: row + 1].values)
                 if gen_str != exp_str:
                     failures += 1
-                    if failures > MAX_FAILURES_TO_PRINT:
-                        printerr(f"Failures truncated after {MAX_FAILURES_TO_PRINT}")
-                        continue
-                    else:
+                    if failures < MAX_FAILURES_TO_PRINT:
                         printerr(f">>> FAILURE at {col=}, {row=}: {gen_str} != {exp_str}")
+                    elif failures == MAX_FAILURES_TO_PRINT:
+                        printerr(f"Failures truncated after {MAX_FAILURES_TO_PRINT}...")
+                    else:
+                        continue
 
             if failures > 0:
-                raise ValueError(f"{failures} row mismatch(es) between tables")
+                raise ValueError(f"Found {failures} row mismatch(es) between tables")
 
         else:
             # None != None according to pandas, and this is responsible for the apparent mismatch
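
The reworked reporting prints the first mismatches, then a single truncation notice, then counts the rest silently. A self-contained sketch of that behaviour, assuming 25 artificial mismatches:

MAX_FAILURES_TO_PRINT = 20
failures = 0
for row in range(25):  # pretend every row mismatches
    failures += 1
    if failures < MAX_FAILURES_TO_PRINT:
        print(f">>> FAILURE at {row=}")  # rows 0-18: nineteen lines printed
    elif failures == MAX_FAILURES_TO_PRINT:
        print(f"Failures truncated after {MAX_FAILURES_TO_PRINT}...")  # printed once
    else:
        continue  # rows 20-24 are counted but not printed
if failures > 0:
    raise ValueError(f"Found {failures} row mismatch(es) between tables")  # reports all 25
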
11 changes: 11 additions & 0 deletions tests/integration/test_aggregate_exposure_within_sample.py
@@ -0,0 +1,11 @@
from . import runner


def test_aggregate_exposure_within_sample():
    runner.run_snakemake_test(
        "aggregate_exposure_within_sample",
        (
            "results/power/by_country/PRI/exposure/IBTrACS/0_length_m_by_event.pq",
            "results/power/by_country/PRI/exposure/IBTrACS/0_length_m_by_edge.pq",
        )
    )
10 changes: 0 additions & 10 deletions tests/integration/test_concat_exposure_by_event.py

This file was deleted.

3 changes: 2 additions & 1 deletion tests/integration/test_electricity_grid_damages.py
@@ -5,6 +5,7 @@ def test_electricity_grid_damages():
     runner.run_snakemake_test(
         "electricity_grid_damages",
         (
-            "results/power/by_country/PRI/exposure/IBTrACS/by_storm/2017242N16333.nc",
+            "results/power/by_country/PRI/exposure/IBTrACS/0/2017242N16333.nc",
+            "results/power/by_country/PRI/exposure/IBTrACS/0/2017242N16333.nc",
         )
     )
2 changes: 1 addition & 1 deletion tests/integration/test_exposure_by_admin_region.py
@@ -5,6 +5,6 @@ def test_exposure_by_admin_region():
     runner.run_snakemake_test(
         "exposure_by_admin_region",
         (
-            "results/power/by_country/PRI/exposure/IBTrACS/admin-level-1.geoparquet",
+            "results/power/by_country/PRI/exposure/IBTrACS/EAE_admin-level-1.gpq",
         )
     )
1 change: 0 additions & 1 deletion workflow/Snakefile
@@ -92,7 +92,6 @@ include: "rules/download/IRIS.smk"
 include: "rules/download/IBTrACS.smk"
 include: "rules/download/gadm.smk"
 include: "rules/download/gridfinder.smk"
-include: "rules/download/worldpop-population.smk"
 include: "rules/download/ghsl-pop.smk"
 include: "rules/download/hazards.smk"
 include: "rules/download/dryad-gdp.smk"
34 changes: 0 additions & 34 deletions workflow/rules/download/worldpop-population.smk

This file was deleted.

129 changes: 116 additions & 13 deletions workflow/rules/exposure/electricity_grid/disruption.smk
@@ -143,26 +143,132 @@ def disruption_by_storm_for_country_for_storm_set(wildcards):
    )


def disruption_by_storm_for_country_for_storm_set(wildcards):
    """
    Given STORM_SET as a wildcard, look up the storms in the set impacting given COUNTRY_ISO_A3.
    Return a list of the relevant disruption netCDF file paths.
    """

    json_file = checkpoints.countries_intersecting_storm_set.get(**wildcards).output.storm_set_by_country
    storm_set_by_country = cached_json_file_read(json_file)

    storms = storm_set_by_country[wildcards.COUNTRY_ISO_A3]

    return expand(
        "{OUTPUT_DIR}/power/by_country/{COUNTRY_ISO_A3}/disruption/{STORM_SET}/{SAMPLE}/{STORM_ID}.nc",
        OUTPUT_DIR=wildcards.OUTPUT_DIR,  # str
        COUNTRY_ISO_A3=wildcards.COUNTRY_ISO_A3,  # str
        STORM_SET=wildcards.STORM_SET,  # str
        SAMPLE=wildcards.SAMPLE,  # str
        STORM_ID=storms  # list of str
    )
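
For reference, Snakemake's expand() fans the pattern out over the keyword arguments. A minimal sketch with hypothetical wildcard values and a single storm from the set:

from snakemake.io import expand

paths = expand(
    "{OUTPUT_DIR}/power/by_country/{COUNTRY_ISO_A3}/disruption/{STORM_SET}/{SAMPLE}/{STORM_ID}.nc",
    OUTPUT_DIR="results",
    COUNTRY_ISO_A3="PRI",
    STORM_SET="IBTrACS",
    SAMPLE="0",
    STORM_ID=["2017242N16333"],  # hypothetical storm list for the set
)
assert paths == ["results/power/by_country/PRI/disruption/IBTrACS/0/2017242N16333.nc"]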

rule aggregate_disruption_within_sample:
    """
    Take per-event disruption files with per-target rows (for all of a storm set
    sample) and aggregate into a per-target file and a per-event file.
    """
    input:
        disruption_by_event = disruption_by_storm_for_country_for_storm_set
    params:
        thresholds = config["transmission_windspeed_failure"]
    output:
        by_event = directory("{OUTPUT_DIR}/power/by_country/{COUNTRY_ISO_A3}/disruption/{STORM_SET}/{SAMPLE}_pop_affected_by_event.pq"),
        by_target = directory("{OUTPUT_DIR}/power/by_country/{COUNTRY_ISO_A3}/disruption/{STORM_SET}/{SAMPLE}_pop_affected_by_target.pq"),
    script:
        "../../../scripts/exposure/aggregate_grid_disruption.py"

"""
Test with:
snakemake --cores 1 -- results/power/by_country/PRI/disruption/IBTrACS/0_pop_affected_by_event.pq
"""


def disruption_per_event_sample_files(wildcards) -> list[str]:
    """
    Return a list of paths, one for each sample.
    """
    dataset_name = wildcards.STORM_SET.split("-")[0]
    return expand(
        rules.aggregate_disruption_within_sample.output.by_event,
        OUTPUT_DIR=wildcards.OUTPUT_DIR,
        COUNTRY_ISO_A3=wildcards.COUNTRY_ISO_A3,
        STORM_SET=wildcards.STORM_SET,
        SAMPLE=range(0, SAMPLES_PER_TRACKSET[dataset_name]),
    )
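
The dataset name is recovered from the storm set wildcard so the number of samples can be looked up. A small sketch, with a hypothetical SAMPLES_PER_TRACKSET mapping and a hypothetical suffixed storm set name:

SAMPLES_PER_TRACKSET = {"IBTrACS": 1, "STORM": 10}  # hypothetical values

dataset_name = "STORM-constant".split("-")[0]  # storm set names may carry a suffix after a hyphen
assert dataset_name == "STORM"
assert list(range(0, SAMPLES_PER_TRACKSET[dataset_name])) == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]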

rule aggregate_per_event_disruption_across_samples:
    """
    Take the per-sample customers affected files and combine them.
    """
    input:
        per_sample = disruption_per_event_sample_files,
    output:
        all_samples = "{OUTPUT_DIR}/power/by_country/{COUNTRY_ISO_A3}/disruption/{STORM_SET}/pop_affected_by_event.pq",
    run:
        import pandas as pd

        df = pd.concat([pd.read_parquet(file_path) for file_path in input.per_sample])
        df.to_parquet(output.all_samples)

"""
Test with:
snakemake --cores 1 -- results/power/by_country/PRI/disruption/IBTrACS/pop_affected_by_event.pq
"""


def disruption_per_target_sample_files(wildcards) -> list[str]:
    """
    Return a list of paths, one for each sample.
    """
    dataset_name = wildcards.STORM_SET.split("-")[0]
    return expand(
        rules.aggregate_disruption_within_sample.output.by_target,
        OUTPUT_DIR=wildcards.OUTPUT_DIR,
        COUNTRY_ISO_A3=wildcards.COUNTRY_ISO_A3,
        STORM_SET=wildcards.STORM_SET,
        SAMPLE=range(0, SAMPLES_PER_TRACKSET[dataset_name]),
    )

rule aggregate_per_target_disruption_across_samples:
    """
    Take the per-sample customers affected files and combine them.
    """
    input:
        per_sample = disruption_per_target_sample_files,
    output:
        all_samples = "{OUTPUT_DIR}/power/by_country/{COUNTRY_ISO_A3}/disruption/{STORM_SET}/pop_affected_by_target.pq",
    run:
        import pandas as pd

        df = pd.concat([pd.read_parquet(file_path) for file_path in input.per_sample])
        df.groupby("target").sum().to_parquet(output.all_samples)

"""
Test with:
snakemake --cores 1 -- results/power/by_country/PRI/disruption/IBTrACS/pop_affected_by_target.pq
"""


rule disruption_by_admin_region:
    """
    Calculate expected annual population affected at given admin level.
    """
    input:
        tracks = storm_tracks_file_from_storm_set,
-       disruption = disruption_by_storm_for_country_for_storm_set,
+       disruption_by_target = rules.aggregate_per_target_disruption_across_samples.output.all_samples,
+       disruption_by_event = rules.aggregate_per_event_disruption_across_samples.output.all_samples,
        targets = "{OUTPUT_DIR}/power/by_country/{COUNTRY_ISO_A3}/network/targets.geoparquet",
        admin_areas = "{OUTPUT_DIR}/input/admin-boundaries/{ADMIN_SLUG}.geoparquet",
    threads: 8  # read exposure files in parallel
    output:
-       total_disruption_by_region = "{OUTPUT_DIR}/power/by_country/{COUNTRY_ISO_A3}/disruption/{STORM_SET}/{ADMIN_SLUG}.geoparquet",
-       # TODO: per region event distributions
-       # disruption_event_distribution_by_region = dir("{OUTPUT_DIR}/power/by_country/{COUNTRY_ISO_A3}/disruption/{STORM_SET}/{ADMIN_SLUG}/")
+       expected_annual_disruption = "{OUTPUT_DIR}/power/by_country/{COUNTRY_ISO_A3}/disruption/{STORM_SET}/EAPA_{ADMIN_SLUG}.gpq",
    script:
        "../../../scripts/exposure/grid_disruption_by_admin_region.py"

"""
Test with:
-snakemake -c1 -- results/power/by_country/PRI/disruption/IBTrACS/admin-level-1.geoparquet
+snakemake -c1 -- results/power/by_country/PRI/disruption/IBTrACS/EAPA_admin-level-1.gpq
"""


@@ -176,7 +282,7 @@ def disruption_summaries_for_storm_set(wildcards):
    country_set = cached_json_file_read(json_file)

    return expand(
-       "{OUTPUT_DIR}/power/by_country/{COUNTRY_ISO_A3}/disruption/{STORM_SET}/{ADMIN_LEVEL}.geoparquet",
+       "{OUTPUT_DIR}/power/by_country/{COUNTRY_ISO_A3}/disruption/{STORM_SET}/EAPA_{ADMIN_LEVEL}.gpq",
        OUTPUT_DIR=wildcards.OUTPUT_DIR,  # str
        COUNTRY_ISO_A3=country_set,  # list of str
        STORM_SET=wildcards.STORM_SET,  # str

rule disruption_by_admin_region_for_storm_set:
    """
-   A target rule to generate the exposure and disruption netCDFs for all
-   targets affected (across multiple countries) for each storm.
-   Concatenates the regional summaries for expected annual population disruption together.
+   Concatenate the regional summaries for expected annual population affected.
    """
    input:
        disruption = disruption_summaries_for_storm_set
    output:
-       storm_set_disruption = "{OUTPUT_DIR}/power/by_storm_set/{STORM_SET}/disruption/{ADMIN_LEVEL}.geoparquet"
+       storm_set_disruption = "{OUTPUT_DIR}/power/by_storm_set/{STORM_SET}/disruption/EAPA_{ADMIN_LEVEL}.gpq"
    run:
        import geopandas as gpd
        import pandas as pd
@@ -207,5 +310,5 @@ rule disruption_by_admin_region_for_storm_set:

"""
Test with:
-snakemake --cores 1 -- results/power/by_storm_set/IBTrACS/disruption/admin-level-2.geoparquet
+snakemake --cores 1 -- results/power/by_storm_set/IBTrACS/disruption/EAPA_admin-level-2.gpq
"""