Skip to content

Commit

Permalink
center_of_mass in float32
Browse files Browse the repository at this point in the history
  • Loading branch information
vovaf709 committed Aug 2, 2024
1 parent d3d378a commit cf9c7d1
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 11 deletions.
4 changes: 2 additions & 2 deletions imops/measure.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,8 +210,8 @@ def center_of_mass(

src_center_of_mass = _fast_labeled_center_of_mass if backend.fast else _labeled_center_of_mass

if array.dtype != 'float64':
array = array.astype(float)
if array.dtype not in ('float32', 'float64'):
array = array.astype(np.float32)

n_dummy = 3 - ndim
if n_dummy:
Expand Down
18 changes: 10 additions & 8 deletions imops/src/_measure.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,11 @@
import numpy as np

cimport numpy as np
cimport cython

from cython.parallel import prange

ctypedef cython.floating FLOAT

ctypedef fused LABEL:
signed char
Expand All @@ -33,16 +35,16 @@ cdef inline Py_ssize_t _find(LABEL num, const LABEL[:] nums) noexcept nogil:
return -1


def _labeled_center_of_mass(const double[:, :, :] nums, const LABEL[:, :, :] labels,
def _labeled_center_of_mass(const FLOAT[:, :, :] nums, const LABEL[:, :, :] labels,
const LABEL[:] index) -> np.ndarray:
cdef const double[:, :, ::1] contiguous_nums = np.ascontiguousarray(nums)
cdef const FLOAT[:, :, ::1] contiguous_nums = np.ascontiguousarray(nums)
cdef const LABEL[:, :, ::1] contiguous_labels = np.ascontiguousarray(labels)
cdef const LABEL[:] contiguous_index = np.ascontiguousarray(index)

cdef Py_ssize_t index_len = len(index)

cdef double[:, ::1] output = np.zeros((index_len, 3))
cdef double[:] normalizers = np.zeros(index_len)
cdef FLOAT[:, ::1] output = np.zeros_like(nums, shape=(index_len, 3))
cdef FLOAT[:] normalizers = np.zeros_like(nums, shape=(index_len,))

cdef Py_ssize_t rows = nums.shape[0], cols = nums.shape[1], dims = nums.shape[2]
cdef Py_ssize_t i, j, k, pos
Expand All @@ -67,11 +69,11 @@ def _labeled_center_of_mass(const double[:, :, :] nums, const LABEL[:, :, :] lab
return np.asarray(output)


def _center_of_mass(const double[:, :, :] nums, Py_ssize_t num_threads) -> np.ndarray:
cdef const double[:, :, ::1] contiguous_nums = np.ascontiguousarray(nums)
def _center_of_mass(const FLOAT[:, :, :] nums, Py_ssize_t num_threads) -> np.ndarray:
cdef const FLOAT[:, :, ::1] contiguous_nums = np.ascontiguousarray(nums)

cdef double output_x = 0, output_y = 0, output_z = 0
cdef double normalizer = 0
cdef FLOAT output_x = 0, output_y = 0, output_z = 0
cdef FLOAT normalizer = 0

cdef Py_ssize_t rows = nums.shape[0], cols = nums.shape[1], dims = nums.shape[2]
cdef Py_ssize_t i, j, k
Expand Down
2 changes: 1 addition & 1 deletion tests/test_measure.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,4 +353,4 @@ def test_labeled_center_of_mass(backend, dtype, label_dtype):
assert isinstance(x, tuple)
assert isinstance(y, tuple)

allclose(out, desired_out, err_msg=(inp, inp.shape), rtol=1e-5)
allclose(out, desired_out, err_msg=(inp, inp.shape), rtol=1e-4)

0 comments on commit cf9c7d1

Please sign in to comment.