diff --git a/src/depiction/image/feature_selection.py b/src/depiction/image/feature_selection.py index 1d39218..6473113 100644 --- a/src/depiction/image/feature_selection.py +++ b/src/depiction/image/feature_selection.py @@ -35,9 +35,8 @@ def retain_features(feature_selection: FeatureSelection, image: MultiChannelImag def _select_features_cv(image: MultiChannelImage, n_features: int) -> list[str]: - cv = image.channel_stats.coefficient_of_variation - n_channels = len(cv) - return cv.drop_nulls().sort("cv").tail(min(n_features, n_channels))["c"].to_list() + cv = image.channel_stats.coefficient_of_variation.dropna("c") + return list(cv.sortby(cv, ascending=False).c.values[:n_features]) def _select_features_iqr(image: MultiChannelImage, n_features: int) -> list[str]: diff --git a/src/depiction/image/image_channel_stats.py b/src/depiction/image/image_channel_stats.py index 13d57ae..b4473f3 100644 --- a/src/depiction/image/image_channel_stats.py +++ b/src/depiction/image/image_channel_stats.py @@ -40,20 +40,13 @@ def five_number_summary(self) -> pl.DataFrame: ).fill_nan(None) @cached_property - def coefficient_of_variation(self) -> pl.DataFrame: + def coefficient_of_variation(self) -> xarray.DataArray: """Returns a DataFrame with the cv for each channel, columns 'c', and 'cv'.""" - data = np.zeros(self._image.n_channels) - for i_channel in range(self._image.n_channels): - values = self._get_channel_values(i_channel=i_channel, drop_missing=True) - if len(values) == 0: - data[i_channel] = np.nan - continue - data[i_channel] = np.std(values) / np.mean(values) - return pl.DataFrame({"c": self._image.channel_names, "cv": data}).fill_nan(None) + return self._compute_scalar_metric(fn=lambda x: np.std(x) / np.mean(x), min_values=2) @cached_property def interquartile_range(self) -> xarray.DataArray: - """Returns a DataFrame with the iqr for each channel c, columns 'c', and 'iqr'.""" + """Returns a DataArray with the iqr for each channel c, columns 'c', and 'iqr'.""" return self._compute_scalar_metric(fn=lambda x: np.percentile(x, 75) - np.percentile(x, 25), min_values=2) @cached_property diff --git a/tests/unit/image/test_feature_selection.py b/tests/unit/image/test_feature_selection.py index a2a617d..ab463ac 100644 --- a/tests/unit/image/test_feature_selection.py +++ b/tests/unit/image/test_feature_selection.py @@ -2,7 +2,7 @@ import pytest from xarray import DataArray -from depiction.image.feature_selection import FeatureSelectionIQR, select_features +from depiction.image.feature_selection import FeatureSelectionIQR, select_features, FeatureSelectionCV from depiction.image.multi_channel_image import MultiChannelImage @@ -18,6 +18,12 @@ def image() -> MultiChannelImage: ) +def test_select_features_cv(image): + fs = FeatureSelectionCV.model_validate(dict(n_features=2)) + selection = select_features(feature_selection=fs, image=image) + assert selection == ["channel3", "channel2"] + + def test_select_features_iqr(image): fs = FeatureSelectionIQR.model_validate(dict(n_features=2)) selection = select_features(feature_selection=fs, image=image) diff --git a/tests/unit/image/test_image_channel_stats.py b/tests/unit/image/test_image_channel_stats.py index f9dffb0..735c95c 100644 --- a/tests/unit/image/test_image_channel_stats.py +++ b/tests/unit/image/test_image_channel_stats.py @@ -1,6 +1,5 @@ import numpy as np import polars as pl -import polars.testing import pytest import xarray.testing @@ -49,10 +48,9 @@ def test_coefficient_of_variation(mocker, image_channel_stats, mock_multi_channe image_channel_stats, "_get_channel_values", side_effect=[mock_data[0], mock_data[1], mock_data[2]] ) result = image_channel_stats.coefficient_of_variation - expected = pl.DataFrame( - {"c": ["channel1", "channel2", "channel3"], "cv": np.array([0.471404, 0.0, np.nan])} - ).fill_nan(None) - pl.testing.assert_frame_equal(result, expected, check_dtype=False, atol=1e-5) + xarray.testing.assert_allclose( + result, xarray.DataArray([0.471404, 0.0, np.nan], coords={"c": ["channel1", "channel2", "channel3"]}, dims="c") + ) def test_interquartile_range(mocker, image_channel_stats, mock_multi_channel_image):