diff --git a/imops/measure.py b/imops/measure.py index 58c0372c..15c3b04d 100644 --- a/imops/measure.py +++ b/imops/measure.py @@ -210,8 +210,8 @@ def center_of_mass( src_center_of_mass = _fast_labeled_center_of_mass if backend.fast else _labeled_center_of_mass - if array.dtype != 'float64': - array = array.astype(float) + if array.dtype not in ('float32', 'float64'): + array = array.astype(np.float32) n_dummy = 3 - ndim if n_dummy: diff --git a/imops/src/_measure.pyx b/imops/src/_measure.pyx index e842be0b..8c1057ce 100644 --- a/imops/src/_measure.pyx +++ b/imops/src/_measure.pyx @@ -8,9 +8,11 @@ import numpy as np cimport numpy as np +cimport cython from cython.parallel import prange +ctypedef cython.floating FLOAT ctypedef fused LABEL: signed char @@ -33,16 +35,16 @@ cdef inline Py_ssize_t _find(LABEL num, const LABEL[:] nums) noexcept nogil: return -1 -def _labeled_center_of_mass(const double[:, :, :] nums, const LABEL[:, :, :] labels, +def _labeled_center_of_mass(const FLOAT[:, :, :] nums, const LABEL[:, :, :] labels, const LABEL[:] index) -> np.ndarray: - cdef const double[:, :, ::1] contiguous_nums = np.ascontiguousarray(nums) + cdef const FLOAT[:, :, ::1] contiguous_nums = np.ascontiguousarray(nums) cdef const LABEL[:, :, ::1] contiguous_labels = np.ascontiguousarray(labels) cdef const LABEL[:] contiguous_index = np.ascontiguousarray(index) cdef Py_ssize_t index_len = len(index) - cdef double[:, ::1] output = np.zeros((index_len, 3)) - cdef double[:] normalizers = np.zeros(index_len) + cdef FLOAT[:, ::1] output = np.zeros_like(nums, shape=(index_len, 3)) + cdef FLOAT[:] normalizers = np.zeros_like(nums, shape=(index_len,)) cdef Py_ssize_t rows = nums.shape[0], cols = nums.shape[1], dims = nums.shape[2] cdef Py_ssize_t i, j, k, pos @@ -67,11 +69,11 @@ def _labeled_center_of_mass(const double[:, :, :] nums, const LABEL[:, :, :] lab return np.asarray(output) -def _center_of_mass(const double[:, :, :] nums, Py_ssize_t num_threads) -> np.ndarray: - cdef const double[:, :, ::1] contiguous_nums = np.ascontiguousarray(nums) +def _center_of_mass(const FLOAT[:, :, :] nums, Py_ssize_t num_threads) -> np.ndarray: + cdef const FLOAT[:, :, ::1] contiguous_nums = np.ascontiguousarray(nums) - cdef double output_x = 0, output_y = 0, output_z = 0 - cdef double normalizer = 0 + cdef FLOAT output_x = 0, output_y = 0, output_z = 0 + cdef FLOAT normalizer = 0 cdef Py_ssize_t rows = nums.shape[0], cols = nums.shape[1], dims = nums.shape[2] cdef Py_ssize_t i, j, k diff --git a/tests/test_measure.py b/tests/test_measure.py index c34a4928..07ae4262 100644 --- a/tests/test_measure.py +++ b/tests/test_measure.py @@ -353,4 +353,4 @@ def test_labeled_center_of_mass(backend, dtype, label_dtype): assert isinstance(x, tuple) assert isinstance(y, tuple) - allclose(out, desired_out, err_msg=(inp, inp.shape), rtol=1e-5) + allclose(out, desired_out, err_msg=(inp, inp.shape), rtol=1e-4)