Skip to content

Commit

Permalink
Propagate raster nodata to missing results in output
Browse files Browse the repository at this point in the history
  • Loading branch information
dbaston committed Dec 20, 2023
1 parent 0d5e887 commit 74bdf2e
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 14 deletions.
35 changes: 26 additions & 9 deletions src/operation.h
Original file line number Diff line number Diff line change
Expand Up @@ -110,9 +110,25 @@ class Operation
virtual void set_result(const StatsRegistry& reg, const Feature& f_in, Feature& f_out) const
{
static const StatsRegistry::RasterStatsVariant empty_stats = RasterStats<double>();

constexpr bool write_if_missing = true; // should we set attribute values if the feature did not intersect the raster?
if (!write_if_missing && !reg.contains(f_in, *this)) {
return;
}

const auto& stats = reg.contains(f_in, *this) ? reg.stats(f_in, *this) : empty_stats;

auto missing = std::numeric_limits<double>::quiet_NaN();
// FIXME don't read an empty box every time, maybe cache this in the source?
auto empty_rast = values->read_box(Box::make_empty());

auto missing = std::visit([](const auto& r) {
std::variant<std::int32_t, std::int64_t, float, double> ret = std::numeric_limits<double>::quiet_NaN();
if (r->has_nodata()) {
ret = r->nodata();
}
return ret;
},
empty_rast);

if (stat == "mean") {
std::visit([&f_out, this](const auto& x) { f_out.set(m_field_names[0], x.mean()); }, stats);
Expand All @@ -125,13 +141,13 @@ class Operation
} else if (stat == "weighted_sum") {
std::visit([&f_out, this](const auto& x) { f_out.set(m_field_names[0], x.weighted_sum()); }, stats);
} else if (stat == "min") {
std::visit([&f_out, &missing, this](const auto& x) { f_out.set(m_field_names[0], x.min().value_or(missing)); }, stats);
std::visit([&f_out, this](const auto& x, const auto& m) { f_out.set(m_field_names[0], x.min().value_or(m)); }, stats, missing);
} else if (stat == "max") {
std::visit([&f_out, &missing, this](const auto& x) { f_out.set(m_field_names[0], x.max().value_or(missing)); }, stats);
std::visit([&f_out, this](const auto& x, const auto& m) { f_out.set(m_field_names[0], x.max().value_or(m)); }, stats, missing);
} else if (stat == "majority" || stat == "mode") {
std::visit([&f_out, &missing, this](const auto& x) { f_out.set(m_field_names[0], x.mode().value_or(missing)); }, stats);
std::visit([&f_out, this](const auto& x, const auto& m) { f_out.set(m_field_names[0], x.mode().value_or(m)); }, stats, missing);
} else if (stat == "minority") {
std::visit([&f_out, &missing, this](const auto& x) { f_out.set(m_field_names[0], x.minority().value_or(missing)); }, stats);
std::visit([&f_out, this](const auto& x, const auto& m) { f_out.set(m_field_names[0], x.minority().value_or(m)); }, stats, missing);
} else if (stat == "variety") {
std::visit([&f_out, this](const auto& x) { f_out.set(m_field_names[0], x.variety()); }, stats);
} else if (stat == "stdev") {
Expand All @@ -145,14 +161,15 @@ class Operation
} else if (stat == "coefficient_of_variation") {
std::visit([&f_out, this](const auto& x) { f_out.set(m_field_names[0], x.coefficient_of_variation()); }, stats);
} else if (stat == "median") {
std::visit([&f_out, &missing, this](const auto& x) { f_out.set(m_field_names[0], x.quantile(0.5).value_or(missing)); }, stats);
std::visit([&f_out, this](const auto& x, const auto& m) { f_out.set(m_field_names[0], x.quantile(0.5).value_or(m)); }, stats, missing);
} else if (stat == "quantile") {
std::visit([&f_out, &missing, this](const auto& x) {
std::visit([&f_out, this](const auto& x, const auto& m) {
for (std::size_t i = 0; i < m_quantiles.size(); i++) {
f_out.set(m_field_names[i], x.quantile(m_quantiles[i]).value_or(missing));
f_out.set(m_field_names[i], x.quantile(m_quantiles[i]).value_or(m));
}
},
stats);
stats,
missing);
} else if (stat == "frac") {
std::visit([&f_out](const auto& x) {
for (const auto& value : x) {
Expand Down
3 changes: 1 addition & 2 deletions src/raster_stats.h
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2018-2022 ISciences, LLC.
// Copyright (c) 2018-2023 ISciences, LLC.
// All rights reserved.
//
// This software is licensed under the Apache License, Version 2.0 (the "License").
Expand All @@ -15,7 +15,6 @@
#define EXACTEXTRACT_RASTER_STATS_H

#include <algorithm>
#include <iostream>
#include <limits>
#include <optional>
#include <unordered_map>
Expand Down
35 changes: 32 additions & 3 deletions test/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,6 @@ def runner(*args, **kwargs):

cmd = [str(x) for x in ["./exactextract", "-o", output_fname] + arglist]

# print(' '.join(cmd))

subprocess.run(cmd, check=True)

with open(output_fname, "r") as f:
Expand All @@ -45,7 +43,7 @@ def write_raster(tmp_path):

files = []

def writer(data):
def writer(data, nodata=None):

fname = str(tmp_path / f"raster{len(files)}.tif")

Expand All @@ -65,6 +63,9 @@ def writer(data):

ds.GetRasterBand(1).WriteArray(data)

if nodata:
ds.GetRasterBand(1).SetNoDataValue(nodata)

return fname

return writer
Expand Down Expand Up @@ -166,12 +167,40 @@ def test_feature_not_intersecting_raster(strategy, run, write_raster, write_feat
fid="id",
raster=f"value:{write_raster(data)}",
stat=["count(value)", "mean(value)"],
strategy=strategy,
)

assert len(rows) == 1
assert rows[0] == {"id": "1", "value_count": "0", "value_mean": "nan"}


@pytest.mark.parametrize("strategy", ("feature-sequential", "raster-sequential"))
@pytest.mark.parametrize("dtype,nodata", [(np.float32, None), (np.int32, -999)])
def test_feature_intersecting_nodata(
    strategy, run, write_raster, write_features, dtype, nodata
):
    """Check the stats reported for a polygon that covers only missing cells.

    The raster is filled entirely with its nodata value (or NaN when no nodata
    value is declared). The expected output row has a count of 0, a mean of
    "nan", and a mode equal to the declared nodata value, falling back to
    "nan" when there is none.
    """

    # Compare against None explicitly: a truthiness test would treat a
    # declared nodata value of 0 as "no nodata".
    has_nodata = nodata is not None

    fill_value = nodata if has_nodata else np.nan
    data = np.full((4, 3), fill_value, dtype=dtype)

    rows = run(
        polygons=write_features(
            {"id": 1, "geom": "POLYGON ((0.5 0.5, 2.5 0.5, 2.5 2, 0.5 2, 0.5 0.5))"}
        ),
        fid="id",
        raster=f"metric:{write_raster(data, nodata)}",
        stat=["count(metric)", "mean(metric)", "mode(metric)"],
        strategy=strategy,
    )

    assert len(rows) == 1
    assert rows[0] == {
        "id": "1",
        "metric_count": "0",
        "metric_mean": "nan",
        "metric_mode": str(nodata) if has_nodata else "nan",
    }


@pytest.mark.parametrize("strategy", ("feature-sequential", "raster-sequential"))
def test_include_cols(strategy, run, write_raster, write_features):

Expand Down

0 comments on commit 74bdf2e

Please sign in to comment.