Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added "nearest", "wrap", and "truncate" modes to ndfilters.generic_filter(). #17

Merged
merged 9 commits into from
Aug 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 69 additions & 23 deletions ndfilters/_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def generic_filter(
size: int | tuple[int, ...],
axis: None | int | tuple[int, ...] = None,
where: bool | np.ndarray = True,
mode: Literal["mirror"] = "mirror",
mode: Literal["mirror", "nearest", "wrap", "truncate"] = "mirror",
args: tuple = (),
) -> np.ndarray:
"""
Expand Down Expand Up @@ -42,7 +42,8 @@ def generic_filter(
mode
The method used to extend the input array beyond its boundaries.
See :func:`scipy.ndimage.generic_filter` for the definitions.
Currently, only "reflect" mode is supported.
Currently, only "mirror", "nearest", "wrap", and "truncate" modes are
supported.
args
Extra arguments to pass to function.

Expand Down Expand Up @@ -98,9 +99,6 @@ def function(a: np.ndarray, args: tuple) -> float:
f"{size=} should have the same number of elements as {axis=}."
)

if mode != "mirror": # pragma: nocover
raise ValueError(f"Only mode='reflected' is supported, got {mode=}")

axis_numba = ~np.arange(len(axis))[::-1]

shape = array.shape
Expand Down Expand Up @@ -138,6 +136,30 @@ def function(a: np.ndarray, args: tuple) -> float:
return result


@numba.njit
def _rectify_index_lower(index: int, size: int, mode: str) -> int:
if mode == "mirror":
return -index
elif mode == "nearest":
return 0
elif mode == "wrap":
return index % size
else: # pragma: nocover
raise ValueError


@numba.njit
def _rectify_index_upper(index: int, size: int, mode: str) -> int:
if mode == "mirror":
return ~(index % size + 1)
elif mode == "nearest":
return size - 1
elif mode == "wrap":
return index % size
else: # pragma: nocover
raise ValueError


@numba.njit(parallel=True)
def _generic_filter_1d(
array: np.ndarray,
Expand All @@ -157,18 +179,22 @@ def _generic_filter_1d(

for ix in numba.prange(array_shape_x):

values = np.empty(shape=size)
mask = np.empty(shape=size, dtype=np.bool_)
values = np.zeros(shape=size)
mask = np.zeros(shape=size, dtype=np.bool_)

for kx in range(kernel_shape_x):

px = kx - kernel_shape_x // 2
jx = ix + px

if jx < 0:
jx = -jx
if mode == "truncate":
continue
jx = _rectify_index_lower(jx, array_shape_x, mode)
elif jx >= array_shape_x:
jx = ~(jx % array_shape_x + 1)
if mode == "truncate":
continue
jx = _rectify_index_upper(jx, array_shape_x, mode)

values[kx] = array[it, jx]
mask[kx] = where[it, jx]
Expand Down Expand Up @@ -198,28 +224,36 @@ def _generic_filter_2d(
for ix in numba.prange(array_shape_x):
for iy in numba.prange(array_shape_y):

values = np.empty(shape=size)
mask = np.empty(shape=size, dtype=np.bool_)
values = np.zeros(shape=size)
mask = np.zeros(shape=size, dtype=np.bool_)

for kx in range(kernel_shape_x):

px = kx - kernel_shape_x // 2
jx = ix + px

if jx < 0:
jx = -jx
if mode == "truncate":
continue
jx = _rectify_index_lower(jx, array_shape_x, mode)
elif jx >= array_shape_x:
jx = ~(jx % array_shape_x + 1)
if mode == "truncate":
continue
jx = _rectify_index_upper(jx, array_shape_x, mode)

for ky in range(kernel_shape_y):

py = ky - kernel_shape_y // 2
jy = iy + py

if jy < 0:
jy = -jy
if mode == "truncate":
continue
jy = _rectify_index_lower(jy, array_shape_y, mode)
elif jy >= array_shape_y:
jy = ~(jy % array_shape_y + 1)
if mode == "truncate":
continue
jy = _rectify_index_upper(jy, array_shape_y, mode)

values[kx, ky] = array[it, jx, jy]
mask[kx, ky] = where[it, jx, jy]
Expand Down Expand Up @@ -253,38 +287,50 @@ def _generic_filter_3d(
for iy in numba.prange(array_shape_y):
for iz in numba.prange(array_shape_z):

values = np.empty(shape=size)
mask = np.empty(shape=size, dtype=np.bool_)
values = np.zeros(shape=size)
mask = np.zeros(shape=size, dtype=np.bool_)

for kx in range(kernel_shape_x):

px = kx - kernel_shape_x // 2
jx = ix + px

if jx < 0:
jx = -jx
if mode == "truncate":
continue
jx = _rectify_index_lower(jx, array_shape_x, mode)
elif jx >= array_shape_x:
jx = ~(jx % array_shape_x + 1)
if mode == "truncate":
continue
jx = _rectify_index_upper(jx, array_shape_x, mode)

for ky in range(kernel_shape_y):

py = ky - kernel_shape_y // 2
jy = iy + py

if jy < 0:
jy = -jy
if mode == "truncate":
continue
jy = _rectify_index_lower(jy, array_shape_y, mode)
elif jy >= array_shape_y:
jy = ~(jy % array_shape_y + 1)
if mode == "truncate":
continue
jy = _rectify_index_upper(jy, array_shape_y, mode)

for kz in range(kernel_shape_z):

pz = kz - kernel_shape_z // 2
jz = iz + pz

if jz < 0:
jz = -jz
if mode == "truncate":
continue
jz = _rectify_index_lower(jz, array_shape_z, mode)
elif jz >= array_shape_z:
jz = ~(jz % array_shape_z + 1)
if mode == "truncate":
continue
jz = _rectify_index_upper(jz, array_shape_z, mode)

values[kx, ky, kz] = array[it, jx, jy, jz]
mask[kx, ky, kz] = where[it, jx, jy, jz]
Expand Down
32 changes: 19 additions & 13 deletions ndfilters/_tests/test_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,31 +36,37 @@ def _mean(a: np.ndarray, args: tuple = ()) -> float:
argnames="mode",
argvalues=[
"mirror",
"nearest",
"wrap",
"truncate",
pytest.param("foo", marks=pytest.mark.xfail),
],
)
def test_generic_filter(
array: np.ndarray | u.Quantity,
function: Callable[[np.ndarray], float],
size: int | tuple[int, ...],
mode: Literal["mirror"],
mode: Literal["mirror", "nearest", "wrap", "truncate"],
):
result = ndfilters.generic_filter(
array=array,
function=function,
size=size,
mode=mode,
)
assert result.shape == array.shape
assert result.sum() != 0

result_expected = scipy.ndimage.generic_filter(
input=array,
function=function,
size=size,
mode=mode,
)
if mode != "truncate":
result_expected = scipy.ndimage.generic_filter(
input=array,
function=function,
size=size,
mode=mode,
)

assert result.shape == array.shape
if isinstance(array, u.Quantity):
assert np.all(result.value == result_expected)
assert result.unit == array.unit
else:
assert np.all(result == result_expected)
if isinstance(array, u.Quantity):
assert np.all(result.value == result_expected)
assert result.unit == array.unit
else:
assert np.all(result == result_expected)
35 changes: 34 additions & 1 deletion ndfilters/_tests/test_trimmed_mean.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from typing import Literal
import pytest
import numpy as np
import scipy.ndimage
Expand All @@ -20,6 +21,7 @@
@pytest.mark.parametrize(
argnames="axis",
argvalues=[
None,
0,
-1,
(0,),
Expand All @@ -30,11 +32,29 @@
(2, 1, 0),
],
)
@pytest.mark.parametrize(
argnames="where",
argvalues=[
True,
False,
],
)
@pytest.mark.parametrize("proportion", [0.25, 0.45])
@pytest.mark.parametrize(
argnames="mode",
argvalues=[
"mirror",
"nearest",
"wrap",
"truncate",
],
)
def test_trimmed_mean_filter(
array: np.ndarray,
size: int | tuple[int, ...],
axis: None | int | tuple[int, ...],
where: bool | np.ndarray,
mode: Literal["mirror", "nearest", "wrap", "truncate"],
proportion: float,
):
if axis is None:
Expand All @@ -51,6 +71,7 @@ def test_trimmed_mean_filter(
size=size,
proportion=proportion,
axis=axis,
where=where,
)
return

Expand All @@ -66,6 +87,7 @@ def test_trimmed_mean_filter(
size=size,
proportion=proportion,
axis=axis,
where=where,
)
return

Expand All @@ -74,8 +96,19 @@ def test_trimmed_mean_filter(
size=size,
proportion=proportion,
axis=axis,
where=where,
mode=mode,
)

assert result.shape == array.shape
assert result.sum() != 0

if mode == "truncate":
return

if not np.all(where):
return

size_scipy = [1] * array.ndim
for i, ax in enumerate(axis_normalized):
size_scipy[ax] = size_normalized[i]
Expand All @@ -84,7 +117,7 @@ def test_trimmed_mean_filter(
input=array,
function=scipy.stats.trim_mean,
size=size_scipy,
mode="mirror",
mode=mode,
extra_keywords=dict(proportiontocut=proportion),
)

Expand Down
7 changes: 5 additions & 2 deletions ndfilters/_trimmed_mean.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def trimmed_mean_filter(
size: int | tuple[int, ...],
axis: None | int | tuple[int, ...] = None,
where: bool | np.ndarray = True,
mode: Literal["mirror"] = "mirror",
mode: Literal["mirror", "nearest", "wrap", "truncate"] = "mirror",
proportion: float = 0.25,
) -> np.ndarray:
"""
Expand All @@ -36,7 +36,8 @@ def trimmed_mean_filter(
mode
The method used to extend the input array beyond its boundaries.
See :func:`scipy.ndimage.generic_filter` for the definitions.
Currently, only "reflect" mode is supported.
Currently, only "mirror", "nearest", "wrap", and "truncate" modes are
supported.
proportion
The proportion to cut from the top and bottom of the distribution.

Expand Down Expand Up @@ -83,6 +84,8 @@ def _trimmed_mean(
(proportion,) = args

nobs = array.size
if nobs == 0:
return np.nan
lowercut = int(proportion * nobs)
uppercut = nobs - lowercut
if lowercut > uppercut: # pragma: nocover
Expand Down
Loading