From 8a0d8b7c19d0f94bbfea5364470c5ab84058e67d Mon Sep 17 00:00:00 2001 From: Roy Smart Date: Fri, 23 Aug 2024 10:19:14 -0600 Subject: [PATCH 1/9] Added "nearest", "wrap", and "truncate" modes to `ndfilters.generic_filter()`. --- ndfilters/_generic.py | 92 ++++++++++++++++++++++++++++---------- ndfilters/_trimmed_mean.py | 7 ++- 2 files changed, 74 insertions(+), 25 deletions(-) diff --git a/ndfilters/_generic.py b/ndfilters/_generic.py index de22186..fb28ee3 100644 --- a/ndfilters/_generic.py +++ b/ndfilters/_generic.py @@ -14,7 +14,7 @@ def generic_filter( size: int | tuple[int, ...], axis: None | int | tuple[int, ...] = None, where: bool | np.ndarray = True, - mode: Literal["mirror"] = "mirror", + mode: Literal["mirror", "nearest", "wrap", "truncate"] = "mirror", args: tuple = (), ) -> np.ndarray: """ @@ -42,7 +42,8 @@ def generic_filter( mode The method used to extend the input array beyond its boundaries. See :func:`scipy.ndimage.generic_filter` for the definitions. - Currently, only "reflect" mode is supported. + Currently, only "mirror", "nearest", "wrap", and "truncate" modes are + supported. args Extra arguments to pass to function. @@ -98,9 +99,6 @@ def function(a: np.ndarray, args: tuple) -> float: f"{size=} should have the same number of elements as {axis=}." ) - if mode != "mirror": # pragma: nocover - raise ValueError(f"Only mode='reflected' is supported, got {mode=}") - axis_numba = ~np.arange(len(axis))[::-1] shape = array.shape @@ -138,6 +136,30 @@ def function(a: np.ndarray, args: tuple) -> float: return result +@numba.njit +def _rectify_index_lower(index: int, size: int, mode: str) -> int: + if mode == "mirror": + return -index + elif mode == "nearest": + return 0 + elif mode == "wrap": + return index % size + else: + raise ValueError + + +@numba.njit +def _rectify_index_upper(index: int, size: int, mode: str) -> int: + if mode == "mirror": + return ~(index % size + 1) + elif mode == "nearest": + return size - 1 + elif mode == "wrap": + return index % size + else: + raise ValueError + + @numba.njit(parallel=True) def _generic_filter_1d( array: np.ndarray, @@ -157,8 +179,8 @@ def _generic_filter_1d( for ix in numba.prange(array_shape_x): - values = np.empty(shape=size) - mask = np.empty(shape=size, dtype=np.bool_) + values = np.zeros(shape=size) + mask = np.zeros(shape=size, dtype=np.bool_) for kx in range(kernel_shape_x): @@ -166,9 +188,13 @@ def _generic_filter_1d( jx = ix + px if jx < 0: - jx = -jx + if mode == "truncate": + continue + jx = _rectify_index_lower(jx, array_shape_x, mode) elif jx >= array_shape_x: - jx = ~(jx % array_shape_x + 1) + if mode == "truncate": + continue + jx = _rectify_index_upper(jx, array_shape_x, mode) values[kx] = array[it, jx] mask[kx] = where[it, jx] @@ -198,8 +224,8 @@ def _generic_filter_2d( for ix in numba.prange(array_shape_x): for iy in numba.prange(array_shape_y): - values = np.empty(shape=size) - mask = np.empty(shape=size, dtype=np.bool_) + values = np.zeros(shape=size) + mask = np.zeros(shape=size, dtype=np.bool_) for kx in range(kernel_shape_x): @@ -207,9 +233,13 @@ def _generic_filter_2d( jx = ix + px if jx < 0: - jx = -jx + if mode == "truncate": + continue + jx = _rectify_index_lower(jx, array_shape_x, mode) elif jx >= array_shape_x: - jx = ~(jx % array_shape_x + 1) + if mode == "truncate": + continue + jx = _rectify_index_upper(jx, array_shape_x, mode) for ky in range(kernel_shape_y): @@ -217,9 +247,13 @@ def _generic_filter_2d( jy = iy + py if jy < 0: - jy = -jy + if mode == "truncate": + continue + jy = _rectify_index_lower(jy, array_shape_y, mode) elif jy >= array_shape_y: - jy = ~(jy % array_shape_y + 1) + if mode == "truncate": + continue + jy = _rectify_index_upper(jy, array_shape_y, mode) values[kx, ky] = array[it, jx, jy] mask[kx, ky] = where[it, jx, jy] @@ -253,8 +287,8 @@ def _generic_filter_3d( for iy in numba.prange(array_shape_y): for iz in numba.prange(array_shape_z): - values = np.empty(shape=size) - mask = np.empty(shape=size, dtype=np.bool_) + values = np.zeros(shape=size) + mask = np.zeros(shape=size, dtype=np.bool_) for kx in range(kernel_shape_x): @@ -262,9 +296,13 @@ def _generic_filter_3d( jx = ix + px if jx < 0: - jx = -jx + if mode == "truncate": + continue + jx = _rectify_index_lower(jx, array_shape_x, mode) elif jx >= array_shape_x: - jx = ~(jx % array_shape_x + 1) + if mode == "truncate": + continue + jx = _rectify_index_upper(jx, array_shape_x, mode) for ky in range(kernel_shape_y): @@ -272,9 +310,13 @@ def _generic_filter_3d( jy = iy + py if jy < 0: - jy = -jy + if mode == "truncate": + continue + jy = _rectify_index_lower(jy, array_shape_y, mode) elif jy >= array_shape_y: - jy = ~(jy % array_shape_y + 1) + if mode == "truncate": + continue + jy = _rectify_index_upper(jy, array_shape_y, mode) for kz in range(kernel_shape_z): @@ -282,9 +324,13 @@ def _generic_filter_3d( jz = iz + pz if jz < 0: - jz = -jz + if mode == "truncate": + continue + jz = _rectify_index_lower(jz, array_shape_z, mode) elif jz >= array_shape_z: - jz = ~(jz % array_shape_z + 1) + if mode == "truncate": + continue + jz = _rectify_index_upper(jz, array_shape_z, mode) values[kx, ky, kz] = array[it, jx, jy, jz] mask[kx, ky, kz] = where[it, jx, jy, jz] diff --git a/ndfilters/_trimmed_mean.py b/ndfilters/_trimmed_mean.py index 1bd011a..54eceb4 100644 --- a/ndfilters/_trimmed_mean.py +++ b/ndfilters/_trimmed_mean.py @@ -14,7 +14,7 @@ def trimmed_mean_filter( size: int | tuple[int, ...], axis: None | int | tuple[int, ...] = None, where: bool | np.ndarray = True, - mode: Literal["mirror"] = "mirror", + mode: Literal["mirror", "nearest", "wrap", "truncate"] = "mirror", proportion: float = 0.25, ) -> np.ndarray: """ @@ -36,7 +36,8 @@ def trimmed_mean_filter( mode The method used to extend the input array beyond its boundaries. See :func:`scipy.ndimage.generic_filter` for the definitions. - Currently, only "reflect" mode is supported. + Currently, only "mirror", "nearest", "wrap", and "truncate" modes are + supported. proportion The proportion to cut from the top and bottom of the distribution. @@ -83,6 +84,8 @@ def _trimmed_mean( (proportion,) = args nobs = array.size + if nobs == 0: + return np.nan lowercut = int(proportion * nobs) uppercut = nobs - lowercut if lowercut > uppercut: # pragma: nocover From f1fdadf4c6dce8406520e31ca2e4f99a881b4fe5 Mon Sep 17 00:00:00 2001 From: Roy Smart Date: Fri, 23 Aug 2024 10:40:17 -0600 Subject: [PATCH 2/9] coverage --- ndfilters/_tests/test_generic.py | 32 +++++++++++++++++++------------- 1 file changed, 19 insertions(+), 13 deletions(-) diff --git a/ndfilters/_tests/test_generic.py b/ndfilters/_tests/test_generic.py index 4506429..ed15ceb 100644 --- a/ndfilters/_tests/test_generic.py +++ b/ndfilters/_tests/test_generic.py @@ -36,13 +36,17 @@ def _mean(a: np.ndarray, args: tuple = ()) -> float: argnames="mode", argvalues=[ "mirror", + "nearest", + "wrap", + "truncate", + pytest.param("foo", marks=pytest.mark.xfail), ], ) def test_generic_filter( array: np.ndarray | u.Quantity, function: Callable[[np.ndarray], float], size: int | tuple[int, ...], - mode: Literal["mirror"], + mode: Literal["mirror", "nearest", "wrap", "truncate"], ): result = ndfilters.generic_filter( array=array, @@ -50,17 +54,19 @@ def test_generic_filter( size=size, mode=mode, ) + assert result.shape == array.shape + assert result.sum() != 0 - result_expected = scipy.ndimage.generic_filter( - input=array, - function=function, - size=size, - mode=mode, - ) + if mode != "truncate": + result_expected = scipy.ndimage.generic_filter( + input=array, + function=function, + size=size, + mode=mode, + ) - assert result.shape == array.shape - if isinstance(array, u.Quantity): - assert np.all(result.value == result_expected) - assert result.unit == array.unit - else: - assert np.all(result == result_expected) + if isinstance(array, u.Quantity): + assert np.all(result.value == result_expected) + assert result.unit == array.unit + else: + assert np.all(result == result_expected) From 72c358b31c8db5897414f106eb2690f0560fd3c5 Mon Sep 17 00:00:00 2001 From: Roy Smart Date: Fri, 23 Aug 2024 10:51:24 -0600 Subject: [PATCH 3/9] More coverage --- ndfilters/_generic.py | 4 +-- ndfilters/_tests/test_trimmed_mean.py | 39 +++++++++++++++++++-------- 2 files changed, 30 insertions(+), 13 deletions(-) diff --git a/ndfilters/_generic.py b/ndfilters/_generic.py index fb28ee3..73cd4f3 100644 --- a/ndfilters/_generic.py +++ b/ndfilters/_generic.py @@ -144,7 +144,7 @@ def _rectify_index_lower(index: int, size: int, mode: str) -> int: return 0 elif mode == "wrap": return index % size - else: + else: # pragma: nocover raise ValueError @@ -156,7 +156,7 @@ def _rectify_index_upper(index: int, size: int, mode: str) -> int: return size - 1 elif mode == "wrap": return index % size - else: + else: # pragma: nocover raise ValueError diff --git a/ndfilters/_tests/test_trimmed_mean.py b/ndfilters/_tests/test_trimmed_mean.py index 989c0b2..1d213a3 100644 --- a/ndfilters/_tests/test_trimmed_mean.py +++ b/ndfilters/_tests/test_trimmed_mean.py @@ -1,3 +1,4 @@ +from typing import Literal import pytest import numpy as np import scipy.ndimage @@ -31,10 +32,21 @@ ], ) @pytest.mark.parametrize("proportion", [0.25, 0.45]) +@pytest.mark.parametrize( + argnames="mode", + argvalues=[ + "mirror", + "nearest", + "wrap", + "truncate", + pytest.param("foo", marks=pytest.mark.xfail), + ], +) def test_trimmed_mean_filter( array: np.ndarray, size: int | tuple[int, ...], axis: None | int | tuple[int, ...], + mode: Literal["mirror", "nearest", "wrap", "truncate"], proportion: float, ): if axis is None: @@ -76,16 +88,21 @@ def test_trimmed_mean_filter( axis=axis, ) - size_scipy = [1] * array.ndim - for i, ax in enumerate(axis_normalized): - size_scipy[ax] = size_normalized[i] + assert result.shape == array.shape + assert result.sum() != 0 - expected = scipy.ndimage.generic_filter( - input=array, - function=scipy.stats.trim_mean, - size=size_scipy, - mode="mirror", - extra_keywords=dict(proportiontocut=proportion), - ) + if mode != "truncate": + + size_scipy = [1] * array.ndim + for i, ax in enumerate(axis_normalized): + size_scipy[ax] = size_normalized[i] + + expected = scipy.ndimage.generic_filter( + input=array, + function=scipy.stats.trim_mean, + size=size_scipy, + mode="mirror", + extra_keywords=dict(proportiontocut=proportion), + ) - assert np.allclose(result, expected) + assert np.allclose(result, expected) From cc35ac1b4052016f318e1f2125e0d2516b66745f Mon Sep 17 00:00:00 2001 From: Roy Smart Date: Fri, 23 Aug 2024 10:52:08 -0600 Subject: [PATCH 4/9] trim number of tests --- ndfilters/_tests/test_trimmed_mean.py | 1 - 1 file changed, 1 deletion(-) diff --git a/ndfilters/_tests/test_trimmed_mean.py b/ndfilters/_tests/test_trimmed_mean.py index 1d213a3..5617d4a 100644 --- a/ndfilters/_tests/test_trimmed_mean.py +++ b/ndfilters/_tests/test_trimmed_mean.py @@ -39,7 +39,6 @@ "nearest", "wrap", "truncate", - pytest.param("foo", marks=pytest.mark.xfail), ], ) def test_trimmed_mean_filter( From de39a828bd172c91e2f3f0b1222b7ce9096a7790 Mon Sep 17 00:00:00 2001 From: Roy Smart Date: Fri, 23 Aug 2024 11:19:38 -0600 Subject: [PATCH 5/9] test fixes --- ndfilters/_tests/test_trimmed_mean.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ndfilters/_tests/test_trimmed_mean.py b/ndfilters/_tests/test_trimmed_mean.py index 5617d4a..9abc514 100644 --- a/ndfilters/_tests/test_trimmed_mean.py +++ b/ndfilters/_tests/test_trimmed_mean.py @@ -85,6 +85,7 @@ def test_trimmed_mean_filter( size=size, proportion=proportion, axis=axis, + mode=mode, ) assert result.shape == array.shape @@ -100,7 +101,7 @@ def test_trimmed_mean_filter( input=array, function=scipy.stats.trim_mean, size=size_scipy, - mode="mirror", + mode=mode, extra_keywords=dict(proportiontocut=proportion), ) From fb0390acf0f251a419a2285a4c98a0daf7e743fe Mon Sep 17 00:00:00 2001 From: Roy Smart Date: Fri, 23 Aug 2024 12:50:36 -0600 Subject: [PATCH 6/9] coverage --- ndfilters/_tests/test_trimmed_mean.py | 36 ++++++++++++++++++--------- 1 file changed, 24 insertions(+), 12 deletions(-) diff --git a/ndfilters/_tests/test_trimmed_mean.py b/ndfilters/_tests/test_trimmed_mean.py index 9abc514..5f12b05 100644 --- a/ndfilters/_tests/test_trimmed_mean.py +++ b/ndfilters/_tests/test_trimmed_mean.py @@ -31,6 +31,13 @@ (2, 1, 0), ], ) +@pytest.mark.parametrize( + argnames="where", + argvalues=[ + True, + False, + ] +) @pytest.mark.parametrize("proportion", [0.25, 0.45]) @pytest.mark.parametrize( argnames="mode", @@ -45,6 +52,7 @@ def test_trimmed_mean_filter( array: np.ndarray, size: int | tuple[int, ...], axis: None | int | tuple[int, ...], + where: bool | np.ndarray, mode: Literal["mirror", "nearest", "wrap", "truncate"], proportion: float, ): @@ -91,18 +99,22 @@ def test_trimmed_mean_filter( assert result.shape == array.shape assert result.sum() != 0 - if mode != "truncate": + if mode == "truncate": + return + + if not np.all(where): + return - size_scipy = [1] * array.ndim - for i, ax in enumerate(axis_normalized): - size_scipy[ax] = size_normalized[i] + size_scipy = [1] * array.ndim + for i, ax in enumerate(axis_normalized): + size_scipy[ax] = size_normalized[i] - expected = scipy.ndimage.generic_filter( - input=array, - function=scipy.stats.trim_mean, - size=size_scipy, - mode=mode, - extra_keywords=dict(proportiontocut=proportion), - ) + expected = scipy.ndimage.generic_filter( + input=array, + function=scipy.stats.trim_mean, + size=size_scipy, + mode=mode, + extra_keywords=dict(proportiontocut=proportion), + ) - assert np.allclose(result, expected) + assert np.allclose(result, expected) From 52581847a26a8603c71994afb0f59192642cde5d Mon Sep 17 00:00:00 2001 From: Roy Smart Date: Fri, 23 Aug 2024 12:51:29 -0600 Subject: [PATCH 7/9] black --- ndfilters/_tests/test_trimmed_mean.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ndfilters/_tests/test_trimmed_mean.py b/ndfilters/_tests/test_trimmed_mean.py index 5f12b05..0556648 100644 --- a/ndfilters/_tests/test_trimmed_mean.py +++ b/ndfilters/_tests/test_trimmed_mean.py @@ -36,7 +36,7 @@ argvalues=[ True, False, - ] + ], ) @pytest.mark.parametrize("proportion", [0.25, 0.45]) @pytest.mark.parametrize( From 4a4dfd5bff1c3f1bb967ce130aa0e3b5b3c86725 Mon Sep 17 00:00:00 2001 From: Roy Smart Date: Fri, 23 Aug 2024 13:03:43 -0600 Subject: [PATCH 8/9] test fixes --- ndfilters/_tests/test_trimmed_mean.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/ndfilters/_tests/test_trimmed_mean.py b/ndfilters/_tests/test_trimmed_mean.py index 0556648..a18605c 100644 --- a/ndfilters/_tests/test_trimmed_mean.py +++ b/ndfilters/_tests/test_trimmed_mean.py @@ -70,6 +70,7 @@ def test_trimmed_mean_filter( size=size, proportion=proportion, axis=axis, + where=where, ) return @@ -85,6 +86,7 @@ def test_trimmed_mean_filter( size=size, proportion=proportion, axis=axis, + where=where, ) return @@ -93,6 +95,7 @@ def test_trimmed_mean_filter( size=size, proportion=proportion, axis=axis, + where=where, mode=mode, ) From 8addfd9fa8fe801caa652f11c59d5598c048121c Mon Sep 17 00:00:00 2001 From: Roy Smart Date: Fri, 23 Aug 2024 13:07:46 -0600 Subject: [PATCH 9/9] coverage --- ndfilters/_tests/test_trimmed_mean.py | 1 + 1 file changed, 1 insertion(+) diff --git a/ndfilters/_tests/test_trimmed_mean.py b/ndfilters/_tests/test_trimmed_mean.py index a18605c..50b3441 100644 --- a/ndfilters/_tests/test_trimmed_mean.py +++ b/ndfilters/_tests/test_trimmed_mean.py @@ -21,6 +21,7 @@ @pytest.mark.parametrize( argnames="axis", argvalues=[ + None, 0, -1, (0,),