From cf9c7d1ae03bc14c97d6989b13d056244ca218b6 Mon Sep 17 00:00:00 2001 From: Philipenko Vladimir Date: Fri, 2 Aug 2024 16:38:57 +0300 Subject: [PATCH] `center_of_mass` in float32 --- imops/measure.py | 4 ++-- imops/src/_measure.pyx | 18 ++++++++++-------- tests/test_measure.py | 2 +- 3 files changed, 13 insertions(+), 11 deletions(-) diff --git a/imops/measure.py b/imops/measure.py index 58c0372c..15c3b04d 100644 --- a/imops/measure.py +++ b/imops/measure.py @@ -210,8 +210,8 @@ def center_of_mass( src_center_of_mass = _fast_labeled_center_of_mass if backend.fast else _labeled_center_of_mass - if array.dtype != 'float64': - array = array.astype(float) + if array.dtype not in ('float32', 'float64'): + array = array.astype(np.float32) n_dummy = 3 - ndim if n_dummy: diff --git a/imops/src/_measure.pyx b/imops/src/_measure.pyx index e842be0b..8c1057ce 100644 --- a/imops/src/_measure.pyx +++ b/imops/src/_measure.pyx @@ -8,9 +8,11 @@ import numpy as np cimport numpy as np +cimport cython from cython.parallel import prange +ctypedef cython.floating FLOAT ctypedef fused LABEL: signed char @@ -33,16 +35,16 @@ cdef inline Py_ssize_t _find(LABEL num, const LABEL[:] nums) noexcept nogil: return -1 -def _labeled_center_of_mass(const double[:, :, :] nums, const LABEL[:, :, :] labels, +def _labeled_center_of_mass(const FLOAT[:, :, :] nums, const LABEL[:, :, :] labels, const LABEL[:] index) -> np.ndarray: - cdef const double[:, :, ::1] contiguous_nums = np.ascontiguousarray(nums) + cdef const FLOAT[:, :, ::1] contiguous_nums = np.ascontiguousarray(nums) cdef const LABEL[:, :, ::1] contiguous_labels = np.ascontiguousarray(labels) cdef const LABEL[:] contiguous_index = np.ascontiguousarray(index) cdef Py_ssize_t index_len = len(index) - cdef double[:, ::1] output = np.zeros((index_len, 3)) - cdef double[:] normalizers = np.zeros(index_len) + cdef FLOAT[:, ::1] output = np.zeros_like(nums, shape=(index_len, 3)) + cdef FLOAT[:] normalizers = np.zeros_like(nums, shape=(index_len,)) cdef Py_ssize_t rows = nums.shape[0], cols = nums.shape[1], dims = nums.shape[2] cdef Py_ssize_t i, j, k, pos @@ -67,11 +69,11 @@ def _labeled_center_of_mass(const double[:, :, :] nums, const LABEL[:, :, :] lab return np.asarray(output) -def _center_of_mass(const double[:, :, :] nums, Py_ssize_t num_threads) -> np.ndarray: - cdef const double[:, :, ::1] contiguous_nums = np.ascontiguousarray(nums) +def _center_of_mass(const FLOAT[:, :, :] nums, Py_ssize_t num_threads) -> np.ndarray: + cdef const FLOAT[:, :, ::1] contiguous_nums = np.ascontiguousarray(nums) - cdef double output_x = 0, output_y = 0, output_z = 0 - cdef double normalizer = 0 + cdef FLOAT output_x = 0, output_y = 0, output_z = 0 + cdef FLOAT normalizer = 0 cdef Py_ssize_t rows = nums.shape[0], cols = nums.shape[1], dims = nums.shape[2] cdef Py_ssize_t i, j, k diff --git a/tests/test_measure.py b/tests/test_measure.py index c34a4928..07ae4262 100644 --- a/tests/test_measure.py +++ b/tests/test_measure.py @@ -353,4 +353,4 @@ def test_labeled_center_of_mass(backend, dtype, label_dtype): assert isinstance(x, tuple) assert isinstance(y, tuple) - allclose(out, desired_out, err_msg=(inp, inp.shape), rtol=1e-5) + allclose(out, desired_out, err_msg=(inp, inp.shape), rtol=1e-4)