Skip to content

Commit

Permalink
Modified ndfilters.mean_filter() to use `ndfilters.generic_filter()…
Browse files Browse the repository at this point in the history
…`. (#18)
  • Loading branch information
byrdie authored Sep 17, 2024
1 parent 3ae5925 commit f6535d6
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 91 deletions.
113 changes: 27 additions & 86 deletions ndfilters/_mean.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,23 @@
from typing import Literal
import numpy as np
import numba
import astropy.units as u
import ndfilters

__all__ = [
"mean_filter",
]


def mean_filter(
array: np.ndarray,
array: np.ndarray | u.Quantity,
size: int | tuple[int, ...],
axis: None | int | tuple[int, ...] = None,
where: bool | np.ndarray = True,
mode: Literal["mirror", "nearest", "wrap", "truncate"] = "mirror",
) -> np.ndarray:
"""
Calculate a multidimensional rolling mean.
The kernel is truncated at the edges of the array.
Parameters
----------
Expand All @@ -23,14 +26,21 @@ def mean_filter(
size
The shape of the kernel over which the mean will be calculated.
axis
The axes over which to apply the kernel. If :obj:`None` the kernel
is applied to every axis.
The axes over which to apply the kernel.
Should either be a scalar or have the same number of items as `size`.
If :obj:`None` (the default) the kernel spans every axis of the array.
where
A boolean mask used to select which elements of the input array to filter.
An optional mask that can be used to exclude parts of the array during
filtering.
mode
The method used to extend the input array beyond its boundaries.
See :func:`scipy.ndimage.generic_filter` for the definitions.
Currently, only "mirror", "nearest", "wrap", and "truncate" modes are
supported.
Returns
-------
A copy of the array with a mean filter applied.
A copy of the array with the mean filter applied.
Examples
--------
Expand All @@ -47,92 +57,23 @@ def mean_filter(
fig, axs = plt.subplots(ncols=2, sharex=True, sharey=True)
axs[0].set_title("original image");
axs[0].imshow(img, cmap="gray");
axs[1].set_title("mean filtered image");
axs[1].set_title("filtered image");
axs[1].imshow(img_filtered, cmap="gray");
"""
array, where = np.broadcast_arrays(array, where, subok=True)

if axis is None:
axis = tuple(range(array.ndim))
else:
axis = np.core.numeric.normalize_axis_tuple(axis=axis, ndim=array.ndim)

if isinstance(size, int):
size = (size,) * len(axis)

result = array
for sz, ax in zip(size, axis, strict=True):
result = _mean_filter_1d(
array=result,
size=sz,
axis=ax,
where=where,
)

return result

def _mean_filter_1d(
array: np.ndarray,
size: int,
axis: int,
where: np.ndarray,
) -> np.ndarray:

array = np.moveaxis(array, axis, ~0)
where = np.moveaxis(where, axis, ~0)

shape = array.shape

array = array.reshape(-1, shape[~0])
where = where.reshape(-1, shape[~0])

result = _mean_filter_1d_numba(
"""
return ndfilters.generic_filter(
array=array,
function=_mean,
size=size,
axis=axis,
where=where,
out=np.empty_like(array),
mode=mode,
)

result = result.reshape(shape)

result = np.moveaxis(result, ~0, axis)

return result


@numba.njit(parallel=True, cache=True)
def _mean_filter_1d_numba(
@numba.njit
def _mean(
array: np.ndarray,
size: int,
where: np.ndarray,
out: np.ndarray,
) -> np.ndarray:

num_t, num_x = array.shape

halfsize = size // 2

for t in numba.prange(num_t):

for i in range(num_x):

sum = 0
count = 0
for j in range(size):

j2 = j - halfsize

k = i + j2
if k < 0:
continue
elif k >= num_x:
continue

if where[t, k]:
sum += array[t, k]
count += 1

out[t, i] = sum / count

return out
args: tuple[float],
) -> float:
return np.mean(array)
17 changes: 12 additions & 5 deletions ndfilters/_tests/test_mean.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from typing import Literal
import pytest
import numpy as np
import scipy.ndimage
Expand Down Expand Up @@ -32,15 +33,25 @@
(2, 1, 0),
],
)
@pytest.mark.parametrize(
argnames="mode",
argvalues=[
"mirror",
"nearest",
"wrap",
],
)
def test_mean_filter(
array: np.ndarray,
size: int | tuple[int, ...],
axis: None | int | tuple[int, ...],
mode: Literal["mirror", "nearest", "wrap", "truncate"],
):
kwargs = dict(
array=array,
size=size,
axis=axis,
mode=mode,
)

if axis is None:
Expand Down Expand Up @@ -74,11 +85,7 @@ def test_mean_filter(
expected = scipy.ndimage.uniform_filter(
input=array,
size=size_scipy,
mode="constant",
) / scipy.ndimage.uniform_filter(
input=np.ones(array.shape),
size=size_scipy,
mode="constant",
mode=mode,
)

if isinstance(result, u.Quantity):
Expand Down

0 comments on commit f6535d6

Please sign in to comment.