Skip to content

Commit

Permalink
return xarray
Browse files Browse the repository at this point in the history
  • Loading branch information
leoschwarz committed Aug 27, 2024
1 parent 7cf79e5 commit 7ce8db4
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 15 deletions.
27 changes: 12 additions & 15 deletions src/depiction/image/image_channel_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import numpy as np
import polars as pl
import xarray

if TYPE_CHECKING:
from depiction.image.multi_channel_image import MultiChannelImage
Expand Down Expand Up @@ -65,28 +66,24 @@ def interquartile_range(self) -> pl.DataFrame:
return pl.DataFrame({"c": self._image.channel_names, "iqr": data}).fill_nan(None)

@cached_property
def mean(self) -> pl.DataFrame:
"""Returns a DataFrame with the mean for each channel, columns 'c', and 'mean'."""
data = np.zeros(self._image.n_channels)
for i_channel in range(self._image.n_channels):
values = self._get_channel_values(i_channel=i_channel, drop_missing=True)
if len(values) == 0:
data[i_channel] = np.nan
continue
data[i_channel] = np.mean(values)
return pl.DataFrame({"c": self._image.channel_names, "mean": data}).fill_nan(None)
def mean(self) -> xarray.DataArray:
"""Returns a DataArray with the mean for each channel."""
return self._compute_scalar_metric(fn=np.mean, min_values=1)

@cached_property
def std(self) -> pl.DataFrame:
def std(self) -> xarray.DataArray:
"""Returns a DataFrame with the standard deviation for each channel, columns 'c', and 'std'."""
return self._compute_scalar_metric(fn=np.std, min_values=2)

def _compute_scalar_metric(self, fn, min_values: int):
data = np.zeros(self._image.n_channels)
for i_channel in range(self._image.n_channels):
values = self._get_channel_values(i_channel=i_channel, drop_missing=True)
if len(values) == 0:
if min_values <= len(values):
data[i_channel] = fn(values)
else:
data[i_channel] = np.nan
continue
data[i_channel] = np.std(values)
return pl.DataFrame({"c": self._image.channel_names, "std": data}).fill_nan(None)
return xarray.DataArray(data, dims="c", coords={"c": self._image.channel_names})

def _get_channel_values(self, i_channel: int, drop_missing: bool) -> np.ndarray:
"""Returns the values of the given channel."""
Expand Down
25 changes: 25 additions & 0 deletions tests/unit/image/test_image_channel_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import polars as pl
import pytest
import polars.testing
import xarray.testing

from depiction.image.image_channel_stats import ImageChannelStats


Expand Down Expand Up @@ -63,6 +65,29 @@ def test_interquartile_range(mocker, image_channel_stats, mock_multi_channel_ima
assert result.equals(expected)


def test_mean(mocker, image_channel_stats, mock_multi_channel_image):
mock_data = np.array([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10], [11, 12, 13, 14, 15]])
mocker.patch.object(
image_channel_stats, "_get_channel_values", side_effect=[mock_data[0], mock_data[1], mock_data[2]]
)
result = image_channel_stats.mean
xarray.testing.assert_allclose(
result, xarray.DataArray([3, 8, 13], coords={"c": ["channel1", "channel2", "channel3"]}, dims="c")
)


def test_std(mocker, image_channel_stats, mock_multi_channel_image):
mock_data = np.array([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10], [11, 12, 13, 14, 15]])
mocker.patch.object(
image_channel_stats, "_get_channel_values", side_effect=[mock_data[0], mock_data[1], mock_data[2]]
)
result = image_channel_stats.std
xarray.testing.assert_allclose(
result,
xarray.DataArray([1.414214, 1.414214, 1.414214], coords={"c": ["channel1", "channel2", "channel3"]}, dims="c"),
)


def test_get_channel_values(image_channel_stats, mock_multi_channel_image):
mock_data = np.array([1, 2, 3, 0, 5])
mock_multi_channel_image.data_flat.isel.return_value.values = mock_data
Expand Down

0 comments on commit 7ce8db4

Please sign in to comment.