diff --git a/src/depiction/image/image_channel_stats.py b/src/depiction/image/image_channel_stats.py index edbb076..ec62948 100644 --- a/src/depiction/image/image_channel_stats.py +++ b/src/depiction/image/image_channel_stats.py @@ -5,6 +5,7 @@ import numpy as np import polars as pl +import xarray if TYPE_CHECKING: from depiction.image.multi_channel_image import MultiChannelImage @@ -65,28 +66,24 @@ def interquartile_range(self) -> pl.DataFrame: return pl.DataFrame({"c": self._image.channel_names, "iqr": data}).fill_nan(None) @cached_property - def mean(self) -> pl.DataFrame: - """Returns a DataFrame with the mean for each channel, columns 'c', and 'mean'.""" - data = np.zeros(self._image.n_channels) - for i_channel in range(self._image.n_channels): - values = self._get_channel_values(i_channel=i_channel, drop_missing=True) - if len(values) == 0: - data[i_channel] = np.nan - continue - data[i_channel] = np.mean(values) - return pl.DataFrame({"c": self._image.channel_names, "mean": data}).fill_nan(None) + def mean(self) -> xarray.DataArray: + """Returns a DataArray with the mean for each channel.""" + return self._compute_scalar_metric(fn=np.mean, min_values=1) @cached_property - def std(self) -> pl.DataFrame: + def std(self) -> xarray.DataArray: """Returns a DataFrame with the standard deviation for each channel, columns 'c', and 'std'.""" + return self._compute_scalar_metric(fn=np.std, min_values=2) + + def _compute_scalar_metric(self, fn, min_values: int): data = np.zeros(self._image.n_channels) for i_channel in range(self._image.n_channels): values = self._get_channel_values(i_channel=i_channel, drop_missing=True) - if len(values) == 0: + if min_values <= len(values): + data[i_channel] = fn(values) + else: data[i_channel] = np.nan - continue - data[i_channel] = np.std(values) - return pl.DataFrame({"c": self._image.channel_names, "std": data}).fill_nan(None) + return xarray.DataArray(data, dims="c", coords={"c": self._image.channel_names}) def _get_channel_values(self, i_channel: int, drop_missing: bool) -> np.ndarray: """Returns the values of the given channel.""" diff --git a/tests/unit/image/test_image_channel_stats.py b/tests/unit/image/test_image_channel_stats.py index e711c17..8c0942e 100644 --- a/tests/unit/image/test_image_channel_stats.py +++ b/tests/unit/image/test_image_channel_stats.py @@ -2,6 +2,8 @@ import polars as pl import pytest import polars.testing +import xarray.testing + from depiction.image.image_channel_stats import ImageChannelStats @@ -63,6 +65,29 @@ def test_interquartile_range(mocker, image_channel_stats, mock_multi_channel_ima assert result.equals(expected) +def test_mean(mocker, image_channel_stats, mock_multi_channel_image): + mock_data = np.array([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10], [11, 12, 13, 14, 15]]) + mocker.patch.object( + image_channel_stats, "_get_channel_values", side_effect=[mock_data[0], mock_data[1], mock_data[2]] + ) + result = image_channel_stats.mean + xarray.testing.assert_allclose( + result, xarray.DataArray([3, 8, 13], coords={"c": ["channel1", "channel2", "channel3"]}, dims="c") + ) + + +def test_std(mocker, image_channel_stats, mock_multi_channel_image): + mock_data = np.array([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10], [11, 12, 13, 14, 15]]) + mocker.patch.object( + image_channel_stats, "_get_channel_values", side_effect=[mock_data[0], mock_data[1], mock_data[2]] + ) + result = image_channel_stats.std + xarray.testing.assert_allclose( + result, + xarray.DataArray([1.414214, 1.414214, 1.414214], coords={"c": ["channel1", "channel2", "channel3"]}, dims="c"), + ) + + def test_get_channel_values(image_channel_stats, mock_multi_channel_image): mock_data = np.array([1, 2, 3, 0, 5]) mock_multi_channel_image.data_flat.isel.return_value.values = mock_data