From 7870f7256f6fee6b6c6a7cfed77ac0504ec3bd2e Mon Sep 17 00:00:00 2001 From: Leonardo Schwarz Date: Tue, 27 Aug 2024 14:12:42 +0200 Subject: [PATCH] refactor iqr computation --- src/depiction/image/feature_selection.py | 5 ++-- src/depiction/image/image_channel_stats.py | 15 +++--------- tests/unit/image/test_feature_selection.py | 24 ++++++++++++++++++++ tests/unit/image/test_image_channel_stats.py | 23 ++++++++++++------- 4 files changed, 44 insertions(+), 23 deletions(-) create mode 100644 tests/unit/image/test_feature_selection.py diff --git a/src/depiction/image/feature_selection.py b/src/depiction/image/feature_selection.py index 661190a..1d39218 100644 --- a/src/depiction/image/feature_selection.py +++ b/src/depiction/image/feature_selection.py @@ -41,6 +41,5 @@ def _select_features_cv(image: MultiChannelImage, n_features: int) -> list[str]: def _select_features_iqr(image: MultiChannelImage, n_features: int) -> list[str]: - iqr = image.channel_stats.interquartile_range - n_channels = len(iqr) - return iqr.drop_nulls().sort("iqr").tail(min(n_features, n_channels))["c"].to_list() + iqr = image.channel_stats.interquartile_range.dropna("c") + return list(iqr.sortby(iqr, ascending=False).c.values[:n_features]) diff --git a/src/depiction/image/image_channel_stats.py b/src/depiction/image/image_channel_stats.py index ec62948..13d57ae 100644 --- a/src/depiction/image/image_channel_stats.py +++ b/src/depiction/image/image_channel_stats.py @@ -52,18 +52,9 @@ def coefficient_of_variation(self) -> pl.DataFrame: return pl.DataFrame({"c": self._image.channel_names, "cv": data}).fill_nan(None) @cached_property - def interquartile_range(self) -> pl.DataFrame: + def interquartile_range(self) -> xarray.DataArray: """Returns a DataFrame with the iqr for each channel c, columns 'c', and 'iqr'.""" - data = np.zeros(self._image.n_channels) - for i_channel in range(self._image.n_channels): - values = self._get_channel_values(i_channel=i_channel, drop_missing=True) - if len(values) == 0: - data[i_channel] = np.nan - continue - q1 = np.percentile(values, 25) - q3 = np.percentile(values, 75) - data[i_channel] = q3 - q1 - return pl.DataFrame({"c": self._image.channel_names, "iqr": data}).fill_nan(None) + return self._compute_scalar_metric(fn=lambda x: np.percentile(x, 75) - np.percentile(x, 25), min_values=2) @cached_property def mean(self) -> xarray.DataArray: @@ -72,7 +63,7 @@ def mean(self) -> xarray.DataArray: @cached_property def std(self) -> xarray.DataArray: - """Returns a DataFrame with the standard deviation for each channel, columns 'c', and 'std'.""" + """Returns a DataArray with the standard deviation for each channel, columns 'c', and 'std'.""" return self._compute_scalar_metric(fn=np.std, min_values=2) def _compute_scalar_metric(self, fn, min_values: int): diff --git a/tests/unit/image/test_feature_selection.py b/tests/unit/image/test_feature_selection.py new file mode 100644 index 0000000..a2a617d --- /dev/null +++ b/tests/unit/image/test_feature_selection.py @@ -0,0 +1,24 @@ +import numpy as np +import pytest +from xarray import DataArray + +from depiction.image.feature_selection import FeatureSelectionIQR, select_features +from depiction.image.multi_channel_image import MultiChannelImage + + +@pytest.fixture() +def image() -> MultiChannelImage: + return MultiChannelImage( + DataArray( + [[[1.0, 2, 0.0], [1.0, 5, 0.5], [1.0, 10, 0.0], [1.0, 20, 0.5], [1.0, 30, 0.0], [1.0, 40, 0.5]]], + dims=("y", "x", "c"), + coords={"c": ["channel1", "channel2", "channel3"]}, + attrs={"bg_value": np.nan}, + ) + ) + + +def test_select_features_iqr(image): + fs = FeatureSelectionIQR.model_validate(dict(n_features=2)) + selection = select_features(feature_selection=fs, image=image) + assert selection == ["channel2", "channel3"] diff --git a/tests/unit/image/test_image_channel_stats.py b/tests/unit/image/test_image_channel_stats.py index 8c0942e..f9dffb0 100644 --- a/tests/unit/image/test_image_channel_stats.py +++ b/tests/unit/image/test_image_channel_stats.py @@ -1,7 +1,7 @@ import numpy as np import polars as pl -import pytest import polars.testing +import pytest import xarray.testing from depiction.image.image_channel_stats import ImageChannelStats @@ -61,8 +61,9 @@ def test_interquartile_range(mocker, image_channel_stats, mock_multi_channel_ima image_channel_stats, "_get_channel_values", side_effect=[mock_data[0], mock_data[1], mock_data[2]] ) result = image_channel_stats.interquartile_range - expected = pl.DataFrame({"c": ["channel1", "channel2", "channel3"], "iqr": [2, 2, 2]}) - assert result.equals(expected) + xarray.testing.assert_allclose( + result, xarray.DataArray([2, 2, 2], coords={"c": ["channel1", "channel2", "channel3"]}, dims="c") + ) def test_mean(mocker, image_channel_stats, mock_multi_channel_image): @@ -115,19 +116,25 @@ def test_get_channel_values_with_nan(image_channel_stats, mock_multi_channel_ima np.testing.assert_array_equal(result, np.array([1, 2, 3, 5])) -def test_empty_channel(mocker, image_channel_stats, mock_multi_channel_image): +def test_five_number_summary_when_empty_channel(mocker, image_channel_stats, mock_multi_channel_image): mock_data = np.array([]) - mocker.patch.object(image_channel_stats, "_get_channel_values", return_value=mock_data) five_number_summary = image_channel_stats.five_number_summary - interquartile_range = image_channel_stats.interquartile_range - assert five_number_summary["min"][0] is None assert five_number_summary["q1"][0] is None assert five_number_summary["median"][0] is None assert five_number_summary["q3"][0] is None assert five_number_summary["max"][0] is None - assert interquartile_range["iqr"][0] is None + + +def test_interquartile_range_when_emtpy_channel(mocker, image_channel_stats, mock_multi_channel_image): + mock_data = np.array([]) + mocker.patch.object(image_channel_stats, "_get_channel_values", return_value=mock_data) + interquartile_range = image_channel_stats.interquartile_range + xarray.testing.assert_equal( + interquartile_range, + xarray.DataArray([None, None, None], coords={"c": ["channel1", "channel2", "channel3"]}, dims="c"), + ) if __name__ == "__main__":