Skip to content

Commit

Permalink
Merge pull request #56 from neuro-ml/dev
Browse files Browse the repository at this point in the history
Forgot to pass `num_threads`
  • Loading branch information
vovaf709 authored Jun 5, 2024
2 parents d9eaabd + 162eb66 commit d15ec60
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 25 deletions.
29 changes: 16 additions & 13 deletions imops/morphology.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,7 +374,7 @@ def binary_opening(


def distance_transform_edt(
input: np.ndarray,
image: np.ndarray,
sampling: Tuple[float] = None,
return_distances: bool = True,
return_indices: bool = False,
Expand All @@ -384,7 +384,7 @@ def distance_transform_edt(
"""
Fast parallelizable Euclidean distance transform for <= 3D inputs
This function calculates the distance transform of the `input`, by
This function calculates the distance transform of the `image`, by
replacing each foreground (non-zero) element, with its
shortest distance to the background (any zero-valued element).
Expand All @@ -394,10 +394,10 @@ def distance_transform_edt(
Parameters
----------
input : array_like
image : array_like
input data to transform. Can be any type but will be converted
into binary: 1 wherever input equates to True, 0 elsewhere
sampling : tuple of `input.ndim` floats, optional
sampling : tuple of `image.ndim` floats, optional
spacing of elements along each dimension. If a sequence, must be of
length equal to the input rank; if a single number, this is used for
all axes. If not specified, a grid spacing of unity is implied
Expand Down Expand Up @@ -484,28 +484,31 @@ def distance_transform_edt(
num_threads = normalize_num_threads(num_threads, backend, warn_stacklevel=3)

if backend.name == 'Scipy':
return scipy_distance_transform_edt(input, sampling, return_distances, return_indices)
return scipy_distance_transform_edt(image, sampling, return_distances, return_indices)

if input.ndim > 3:
if image.ndim > 3:
warn("Fast Euclidean Distance Transform is only supported for ndim<=3. Falling back to scipy's implementation.")
return scipy_distance_transform_edt(input, sampling, return_distances, return_indices)
return scipy_distance_transform_edt(image, sampling, return_distances, return_indices)

if (not return_distances) and (not return_indices):
raise RuntimeError('At least one of `return_distances`/`return_indices` must be True')

input = np.atleast_1d(np.where(input, 1, 0).astype(np.int8))
if image.dtype != bool:
image = np.atleast_1d(np.where(image, 1, 0))
if sampling is not None:
sampling = _ni_support._normalize_sequence(sampling, input.ndim)
sampling = _ni_support._normalize_sequence(sampling, image.ndim)
sampling = np.asarray(sampling, dtype=np.float64)
if not sampling.flags.contiguous:
sampling = sampling.copy()

if return_indices:
ft = np.zeros((input.ndim,) + input.shape, dtype=np.int32)
euclidean_feature_transform(input, sampling, ft)
ft = np.zeros((image.ndim,) + image.shape, dtype=np.int32)
euclidean_feature_transform(image, sampling, ft)

if return_distances:
dt = edt(input, anisotropy=sampling.astype(np.float32)) if sampling is not None else edt(input)
if sampling is not None:
dt = edt(image, anisotropy=sampling.astype(np.float32), parallel=num_threads)
else:
dt = edt(image, parallel=num_threads)

result = []
if return_distances:
Expand Down
24 changes: 12 additions & 12 deletions imops/zoom.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ def zoom_to_shape(


def _zoom(
input: np.ndarray,
image: np.ndarray,
zoom: Sequence[float],
output: np.ndarray = None,
order: int = 1,
Expand All @@ -216,15 +216,15 @@ def _zoom(
if backend.name not in ('Scipy', 'Numba', 'Cython'):
raise ValueError(f'Unsupported backend "{backend.name}".')

ndim = input.ndim
dtype = input.dtype
ndim = image.ndim
dtype = image.dtype
cval = np.dtype(dtype).type(cval)
zoom = fill_by_indices(np.ones(input.ndim, 'float64'), zoom, range(input.ndim))
zoom = fill_by_indices(np.ones(image.ndim, 'float64'), zoom, range(image.ndim))
num_threads = normalize_num_threads(num_threads, backend, warn_stacklevel=4)

if backend.name == 'Scipy':
return scipy_zoom(
input, zoom, output=output, order=order, mode=mode, cval=cval, prefilter=prefilter, grid_mode=grid_mode
image, zoom, output=output, order=order, mode=mode, cval=cval, prefilter=prefilter, grid_mode=grid_mode
)

if (
Expand All @@ -246,7 +246,7 @@ def _zoom(
stacklevel=3,
)
return scipy_zoom(
input, zoom, output=output, order=order, mode=mode, cval=cval, prefilter=prefilter, grid_mode=grid_mode
image, zoom, output=output, order=order, mode=mode, cval=cval, prefilter=prefilter, grid_mode=grid_mode
)

if backend.name == 'Cython':
Expand All @@ -264,28 +264,28 @@ def _zoom(
n_dummy = 3 - ndim if ndim <= 3 else 0

if n_dummy:
input = input[(None,) * n_dummy]
image = image[(None,) * n_dummy]
zoom = [*(1,) * n_dummy, *zoom]

zoom = np.array(zoom, dtype=np.float64)
is_contiguous = input.data.c_contiguous
is_contiguous = image.data.c_contiguous
c_contiguous_permutaion = None
args = () if backend.name in ('Numba',) else (num_threads,)

if not is_contiguous:
c_contiguous_permutaion = get_c_contiguous_permutaion(input)
c_contiguous_permutaion = get_c_contiguous_permutaion(image)
if c_contiguous_permutaion is not None:
out = src_zoom(
np.transpose(input, c_contiguous_permutaion),
np.transpose(image, c_contiguous_permutaion),
zoom[c_contiguous_permutaion],
cval,
*args,
)
else:
warn("Input array can't be represented as C-contiguous, performance can drop a lot.", stacklevel=3)
out = src_zoom(input, zoom, cval, *args)
out = src_zoom(image, zoom, cval, *args)
else:
out = src_zoom(input, zoom, cval, *args)
out = src_zoom(image, zoom, cval, *args)

if c_contiguous_permutaion is not None:
out = np.transpose(out, inverse_permutation(c_contiguous_permutaion))
Expand Down

0 comments on commit d15ec60

Please sign in to comment.