diff --git a/ndfilters/_mean.py b/ndfilters/_mean.py index 1d9559c..42dff9c 100644 --- a/ndfilters/_mean.py +++ b/ndfilters/_mean.py @@ -1,5 +1,8 @@ +from typing import Literal import numpy as np import numba +import astropy.units as u +import ndfilters __all__ = [ "mean_filter", @@ -7,14 +10,14 @@ def mean_filter( - array: np.ndarray, + array: np.ndarray | u.Quantity, size: int | tuple[int, ...], axis: None | int | tuple[int, ...] = None, where: bool | np.ndarray = True, + mode: Literal["mirror", "nearest", "wrap", "truncate"] = "mirror", ) -> np.ndarray: """ Calculate a multidimensional rolling mean. - The kernel is truncated at the edges of the array. Parameters ---------- @@ -23,14 +26,21 @@ def mean_filter( size The shape of the kernel over which the mean will be calculated. axis - The axes over which to apply the kernel. If :obj:`None` the kernel - is applied to every axis. + The axes over which to apply the kernel. + Should either be a scalar or have the same number of items as `size`. + If :obj:`None` (the default) the kernel spans every axis of the array. where - A boolean mask used to select which elements of the input array to filter. + An optional mask that can be used to exclude parts of the array during + filtering. + mode + The method used to extend the input array beyond its boundaries. + See :func:`scipy.ndimage.generic_filter` for the definitions. + Currently, only "mirror", "nearest", "wrap", and "truncate" modes are + supported. Returns ------- - A copy of the array with a mean filter applied. + A copy of the array with the mean filter applied. Examples -------- @@ -47,92 +57,23 @@ def mean_filter( fig, axs = plt.subplots(ncols=2, sharex=True, sharey=True) axs[0].set_title("original image"); axs[0].imshow(img, cmap="gray"); - axs[1].set_title("mean filtered image"); + axs[1].set_title("filtered image"); axs[1].imshow(img_filtered, cmap="gray"); - """ - array, where = np.broadcast_arrays(array, where, subok=True) - - if axis is None: - axis = tuple(range(array.ndim)) - else: - axis = np.core.numeric.normalize_axis_tuple(axis=axis, ndim=array.ndim) - - if isinstance(size, int): - size = (size,) * len(axis) - - result = array - for sz, ax in zip(size, axis, strict=True): - result = _mean_filter_1d( - array=result, - size=sz, - axis=ax, - where=where, - ) - - return result - -def _mean_filter_1d( - array: np.ndarray, - size: int, - axis: int, - where: np.ndarray, -) -> np.ndarray: - - array = np.moveaxis(array, axis, ~0) - where = np.moveaxis(where, axis, ~0) - - shape = array.shape - - array = array.reshape(-1, shape[~0]) - where = where.reshape(-1, shape[~0]) - - result = _mean_filter_1d_numba( + """ + return ndfilters.generic_filter( array=array, + function=_mean, size=size, + axis=axis, where=where, - out=np.empty_like(array), + mode=mode, ) - result = result.reshape(shape) - result = np.moveaxis(result, ~0, axis) - - return result - - -@numba.njit(parallel=True, cache=True) -def _mean_filter_1d_numba( +@numba.njit +def _mean( array: np.ndarray, - size: int, - where: np.ndarray, - out: np.ndarray, -) -> np.ndarray: - - num_t, num_x = array.shape - - halfsize = size // 2 - - for t in numba.prange(num_t): - - for i in range(num_x): - - sum = 0 - count = 0 - for j in range(size): - - j2 = j - halfsize - - k = i + j2 - if k < 0: - continue - elif k >= num_x: - continue - - if where[t, k]: - sum += array[t, k] - count += 1 - - out[t, i] = sum / count - - return out + args: tuple[float], +) -> float: + return np.mean(array) diff --git a/ndfilters/_tests/test_mean.py b/ndfilters/_tests/test_mean.py index e13c64a..072221b 100644 --- a/ndfilters/_tests/test_mean.py +++ b/ndfilters/_tests/test_mean.py @@ -1,3 +1,4 @@ +from typing import Literal import pytest import numpy as np import scipy.ndimage @@ -32,15 +33,25 @@ (2, 1, 0), ], ) +@pytest.mark.parametrize( + argnames="mode", + argvalues=[ + "mirror", + "nearest", + "wrap", + ], +) def test_mean_filter( array: np.ndarray, size: int | tuple[int, ...], axis: None | int | tuple[int, ...], + mode: Literal["mirror", "nearest", "wrap", "truncate"], ): kwargs = dict( array=array, size=size, axis=axis, + mode=mode, ) if axis is None: @@ -74,11 +85,7 @@ def test_mean_filter( expected = scipy.ndimage.uniform_filter( input=array, size=size_scipy, - mode="constant", - ) / scipy.ndimage.uniform_filter( - input=np.ones(array.shape), - size=size_scipy, - mode="constant", + mode=mode, ) if isinstance(result, u.Quantity):