Skip to content

Commit

Permalink
Rework interpn to support all methods, propagate NaNs consistently an…
Browse files Browse the repository at this point in the history
…d add a lot of tests
  • Loading branch information
rhugonnet committed Jun 8, 2024
1 parent 658e3b2 commit 4e29c5e
Show file tree
Hide file tree
Showing 3 changed files with 155 additions and 41 deletions.
113 changes: 82 additions & 31 deletions geoutils/raster/interpolate.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,12 @@
import numpy as np
import rasterio as rio
from scipy.interpolate import RectBivariateSpline, RegularGridInterpolator
from scipy.ndimage import map_coordinates
from scipy.ndimage import map_coordinates, binary_dilation

from geoutils._typing import NDArrayNum, Number
from geoutils.raster.georeferencing import _coords, _outside_image, _res, _xy2ij

method_to_order = {"nearest": 0, "linear": 1, "cubic": 3, "quintic": 5, "slinear": 1, "pchip": 3, "splinef2d": 3}

def _interpn_interpolator(
coords: tuple[NDArrayNum, NDArrayNum],
Expand All @@ -19,45 +20,95 @@ def _interpn_interpolator(
method: Literal["nearest", "linear", "cubic", "quintic", "slinear", "pchip", "splinef2d"] = "linear",
) -> Callable[[tuple[NDArrayNum, NDArrayNum]], NDArrayNum]:
"""
Mirroring scipy.interpn function but returning interpolator directly: either a RegularGridInterpolator or
a RectBivariateSpline object. (required for speed when interpolating multiple times)
Create SciPy interpolator with nodata spreading at distance of half the method order rounded up (i.e., linear
spreads 1 nodata in each direction, cubic spreads 2, quintic 3).
From: https://github.com/scipy/scipy/blob/44e4ebaac992fde33f04638b99629d23973cb9b2/scipy/interpolate/_rgi.py#L743
Gives the exact same result as scipy.interpolate.interpn, and allows interpolator to be re-used if required (
for speed).
In practice, returns either a NaN-modified RegularGridInterpolator or a NaN-modified RectBivariateSpline object,
both expecting a tuple of X/Y coordinates to be evaluated.
Adapted from:
https://github.com/scipy/scipy/blob/44e4ebaac992fde33f04638b99629d23973cb9b2/scipy/interpolate/_rgi.py#L743.
"""

# Easy for the RegularGridInterpolator
# Adding masking of NaNs for methods not supporting it
method_support_nan = method in ["nearest"]
order = method_to_order[method]
dist_nodata_spread = int(np.ceil(order/2))

# If NaNs are not supported
if not method_support_nan:
# We compute the mask and dilate it to the order of interpolation (propagating NaNs)
mask_nan = ~np.isfinite(values)
new_mask = binary_dilation(mask_nan, iterations=dist_nodata_spread).astype("uint8")

# We create an interpolator for the mask too, using nearest
interp_mask = RegularGridInterpolator(
coords, new_mask, method="nearest", bounds_error=bounds_error, fill_value=1
)

# Replace NaN values by nearest neighbour to avoid biasing interpolation near NaNs with placeholder value
values[mask_nan] = 0

# For the RegularGridInterpolator
if method in RegularGridInterpolator._ALL_METHODS:

# We create the interpolator
interp = RegularGridInterpolator(
coords, values, method=method, bounds_error=bounds_error, fill_value=fill_value
)
return interp

# Otherwise need to wrap the fill value around RectBivariateSpline
interp = RectBivariateSpline(np.flip(coords[0]), coords[1], np.flip(values[:], axis=0))

def rectbivariate_interpolator_with_fillvalue(xi: tuple[NDArrayNum, NDArrayNum]) -> NDArrayNum:

# RectBivariateSpline doesn't support fill_value; we need to wrap here
xi_arr = np.array(xi)
xi_shape = xi_arr.shape
xi_arr = xi_arr.reshape(-1, xi_arr.shape[-1])
idx_valid = np.all(
(
coords[0][-1] <= xi_arr[:, 0],
xi_arr[:, 0] <= coords[0][0],
coords[1][0] <= xi_arr[:, 1],
xi_arr[:, 1] <= coords[1][-1],
),
axis=0,
)
# Make a copy of values for RectBivariateSpline
result = np.empty_like(xi_arr[:, 0])
result[idx_valid] = interp.ev(xi_arr[idx_valid, 0], xi_arr[idx_valid, 1])
result[np.logical_not(idx_valid)] = fill_value

return result.reshape(xi_shape[:-1])
# We create a new interpolator callable
def regulargrid_interpolator_with_nan(xi: tuple[NDArrayNum, NDArrayNum]) -> NDArrayNum:

results = interp(xi)

if not method_support_nan:
invalids = interp_mask(xi)
results[invalids.astype(bool)] = np.nan

return results

return regulargrid_interpolator_with_nan

# For the RectBivariateSpline
else:

return rectbivariate_interpolator_with_fillvalue
# The coordinates must be in ascending order, which requires flipping the array too (more costly)
interp = RectBivariateSpline(np.flip(coords[0]), coords[1], np.flip(values[:], axis=0))

# We create a new interpolator callable
def rectbivariate_interpolator_with_fillvalue(xi: tuple[NDArrayNum, NDArrayNum]) -> NDArrayNum:

# Get invalids
invalids = interp_mask(xi)

# RectBivariateSpline doesn't support fill_value, so we need to wrap here to add them
xi_arr = np.array(xi).T
xi_shape = xi_arr.shape
xi_arr = xi_arr.reshape(-1, xi_arr.shape[-1])
idx_valid = np.all(
(
coords[0][-1] <= xi_arr[:, 0],
xi_arr[:, 0] <= coords[0][0],
coords[1][0] <= xi_arr[:, 1],
xi_arr[:, 1] <= coords[1][-1],
),
axis=0,
)
# Make a copy of values for RectBivariateSpline
result = np.empty_like(xi_arr[:, 0])
result[idx_valid] = interp.ev(xi_arr[idx_valid, 0], xi_arr[idx_valid, 1])
result[np.logical_not(idx_valid)] = fill_value

# Add back NaNs from dilated mask
results = np.atleast_1d(result.reshape(xi_shape[:-1]))
results[invalids.astype(bool)] = np.nan

return results

return rectbivariate_interpolator_with_fillvalue


@overload
Expand Down
4 changes: 2 additions & 2 deletions geoutils/raster/raster.py
Original file line number Diff line number Diff line change
Expand Up @@ -3601,8 +3601,8 @@ def interp_points(
"""
Interpolate raster values at a set of points.
Uses scipy.ndimage.map_coordinates if the Raster is on an equal grid, otherwise uses scipy.interpn
on a regular grid.
Uses scipy.ndimage.map_coordinates if the Raster is on an equal grid using "nearest" or "linear" (for speed),
otherwise uses scipy.interpn on a regular grid.
Optionally, user can enforce the interpretation of pixel coordinates in self.tags['AREA_OR_POINT']
to ensure that the interpolation of points is done at the right location. See parameter description
Expand Down
79 changes: 71 additions & 8 deletions tests/test_raster/test_raster.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import xarray as xr
from pylint.lint import Run
from pylint.reporters.text import TextReporter
from scipy.ndimage import distance_transform_edt

import geoutils as gu
import geoutils.projtools as pt
Expand All @@ -27,6 +28,7 @@
from geoutils.misc import resampling_method_from_str
from geoutils.projtools import reproject_to_latlon
from geoutils.raster.raster import _default_nodata, _default_rio_attrs
from geoutils.raster.interpolate import method_to_order

DO_PLOT = False

Expand Down Expand Up @@ -2271,35 +2273,96 @@ def test_interp_points__synthetic(self, tag_aop: str | None, shift_aop: bool) ->
assert all(~np.isfinite(raster_points_mapcoords_edge))
assert all(~np.isfinite(raster_points_interpn_edge))

@pytest.mark.parametrize("example", [landsat_b4_path, aster_dem_path])
@pytest.mark.parametrize(
"method", ["nearest", "linear", "cubic", "quintic", "slinear", "pchip", "splinef2d"]
) # type: ignore
def test_interp_points__real(
self, method: Literal["nearest", "linear", "cubic", "quintic", "slinear", "pchip", "splinef2d"]
self, example: str, method: Literal["nearest", "linear", "cubic", "quintic", "slinear", "pchip", "splinef2d"]
) -> None:
"""Test interp_points for real data."""

r = gu.Raster(self.landsat_b4_path)
# 1/ Check the accuracy of the interpolation at an exact point, and between methods

r = gu.Raster(example)
r.set_area_or_point("Area", shift_area_or_point=False)

# Test for an invidiual point (shape can be tricky at 1 dimension)
x = 493120.0
y = 3101000.0
i, j = r.xy2ij(x, y)
# Test for an individual point (shape can be tricky at 1 dimension)
itest = 100
jtest = 100
x, y = r.ij2xy(itest, jtest)
val = r.interp_points((x, y), method=method, force_scipy_function="map_coordinates")[0]
val_img = r.data[int(i[0]), int(j[0])]
val_img = r.data[itest, jtest]
if "nearest" in method or "linear" in method:
assert val_img == val

# Check the result is exactly the same for both methods
val2 = r.interp_points((x, y), method=method, force_scipy_function="interpn")[0]
assert val2 == pytest.approx(val)

# Finally, check that interp convert to latlon
# Check that interp convert to latlon
lat, lon = gu.projtools.reproject_to_latlon([x, y], in_crs=r.crs)
val_latlon = r.interp_points((lat, lon), method=method, input_latlon=True)[0]
assert val == pytest.approx(val_latlon, abs=0.0001)

# 2/ Check the propagation of NaNs

# Convert raster to float
r = r.astype(np.float32)

# Create a NaN at a given pixel (we know the landsat example has no NaNs to begin with)
i0, j0 = (10, 10)
r[i0, j0] = np.nan

# All surrounding pixels with distance half the method order rounded up should be NaNs
order = method_to_order[method]
if method == "linear":
order = 1
d = int(np.ceil(order / 2))
# No NaN propagation for linear
indices_nan = [(i0 + i, j0 + j) for i in np.arange(-d, d+1) for j in np.arange(-d, d+1) if (np.abs(i) + np.abs(j)) <= d]
i,j = list(zip(*indices_nan))
x, y = r.ij2xy(i, j)
vals = r.interp_points((x, y), method=method, force_scipy_function="map_coordinates")[0]
vals2 = r.interp_points((x, y), method=method, force_scipy_function="interpn")[0]

assert all(np.isnan(np.atleast_1d(vals))) and all(np.isnan(np.atleast_1d(vals2)))

# 3/ Check that valid interpolated values at the edge of NaNs are free of errors

# We compare values interpolated right at the edge of valid values near a NaN between
# 1/ Implementation of interp_points (that replaces NaNs by a placeholder value of 0 during interpolation)
# 2/ Raster filled with nearest neighbour at NaN coordinates, then running interp_points
# If the interpolated value are free of the influence of the placeholder value of 0, the interpolated value
# should be exactly the same

# We get the indexes of valid pixels just at the edge of NaNs
indices_edge = [(i0 + i, j0 + j) for i in np.arange(-d-1, d+2) for j in np.arange(-d-1, d+2) if (np.abs(i) + np.abs(j)) == d+1]
i, j = list(zip(*indices_edge))
x, y = r.ij2xy(i, j)
# And get their interpolated value
vals = r.interp_points((x, y), method=method, force_scipy_function="map_coordinates")[0]
vals2 = r.interp_points((x, y), method=method, force_scipy_function="interpn")[0]

# Then we fill the NaNs in the raster with the nearest neighbour
r_arr = r.get_nanarray()
# Elegant solution from: https://stackoverflow.com/questions/5551286/filling-gaps-in-a-numpy-array
indices = distance_transform_edt(~np.isfinite(r_arr), return_distances=False, return_indices=True)
r_arr = r_arr[tuple(indices)]
r.data = r_arr

# All raster values should be valid now
assert np.all(np.isfinite(r_arr))

# And get the interpolated values
vals_near = r.interp_points((x, y), method=method, force_scipy_function="map_coordinates")[0]
vals2_near = r.interp_points((x, y), method=method, force_scipy_function="interpn")[0]

# Both sets of values should be exactly the same, without any NaNs
assert np.allclose(vals, vals_near, equal_nan=False)
assert np.allclose(vals2, vals2_near, equal_nan=False)


def test_value_at_coords(self) -> None:
"""
Test that value at coords works as intended
Expand Down

0 comments on commit 4e29c5e

Please sign in to comment.