Skip to content

Commit

Permalink
simplify coefficient_of_variation
Browse files Browse the repository at this point in the history
  • Loading branch information
leoschwarz committed Aug 27, 2024
1 parent 7870f72 commit 4d8f674
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 19 deletions.
5 changes: 2 additions & 3 deletions src/depiction/image/feature_selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,8 @@ def retain_features(feature_selection: FeatureSelection, image: MultiChannelImag


def _select_features_cv(image: MultiChannelImage, n_features: int) -> list[str]:
cv = image.channel_stats.coefficient_of_variation
n_channels = len(cv)
return cv.drop_nulls().sort("cv").tail(min(n_features, n_channels))["c"].to_list()
cv = image.channel_stats.coefficient_of_variation.dropna("c")
return list(cv.sortby(cv, ascending=False).c.values[:n_features])


def _select_features_iqr(image: MultiChannelImage, n_features: int) -> list[str]:
Expand Down
13 changes: 3 additions & 10 deletions src/depiction/image/image_channel_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,20 +40,13 @@ def five_number_summary(self) -> pl.DataFrame:
).fill_nan(None)

@cached_property
def coefficient_of_variation(self) -> pl.DataFrame:
def coefficient_of_variation(self) -> xarray.DataArray:
"""Returns a DataFrame with the cv for each channel, columns 'c', and 'cv'."""
data = np.zeros(self._image.n_channels)
for i_channel in range(self._image.n_channels):
values = self._get_channel_values(i_channel=i_channel, drop_missing=True)
if len(values) == 0:
data[i_channel] = np.nan
continue
data[i_channel] = np.std(values) / np.mean(values)
return pl.DataFrame({"c": self._image.channel_names, "cv": data}).fill_nan(None)
return self._compute_scalar_metric(fn=lambda x: np.std(x) / np.mean(x), min_values=2)

@cached_property
def interquartile_range(self) -> xarray.DataArray:
"""Returns a DataFrame with the iqr for each channel c, columns 'c', and 'iqr'."""
"""Returns a DataArray with the iqr for each channel c, columns 'c', and 'iqr'."""
return self._compute_scalar_metric(fn=lambda x: np.percentile(x, 75) - np.percentile(x, 25), min_values=2)

@cached_property
Expand Down
8 changes: 7 additions & 1 deletion tests/unit/image/test_feature_selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import pytest
from xarray import DataArray

from depiction.image.feature_selection import FeatureSelectionIQR, select_features
from depiction.image.feature_selection import FeatureSelectionIQR, select_features, FeatureSelectionCV
from depiction.image.multi_channel_image import MultiChannelImage


Expand All @@ -18,6 +18,12 @@ def image() -> MultiChannelImage:
)


def test_select_features_cv(image):
fs = FeatureSelectionCV.model_validate(dict(n_features=2))
selection = select_features(feature_selection=fs, image=image)
assert selection == ["channel3", "channel2"]


def test_select_features_iqr(image):
fs = FeatureSelectionIQR.model_validate(dict(n_features=2))
selection = select_features(feature_selection=fs, image=image)
Expand Down
8 changes: 3 additions & 5 deletions tests/unit/image/test_image_channel_stats.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import numpy as np
import polars as pl
import polars.testing
import pytest
import xarray.testing

Expand Down Expand Up @@ -49,10 +48,9 @@ def test_coefficient_of_variation(mocker, image_channel_stats, mock_multi_channe
image_channel_stats, "_get_channel_values", side_effect=[mock_data[0], mock_data[1], mock_data[2]]
)
result = image_channel_stats.coefficient_of_variation
expected = pl.DataFrame(
{"c": ["channel1", "channel2", "channel3"], "cv": np.array([0.471404, 0.0, np.nan])}
).fill_nan(None)
pl.testing.assert_frame_equal(result, expected, check_dtype=False, atol=1e-5)
xarray.testing.assert_allclose(
result, xarray.DataArray([0.471404, 0.0, np.nan], coords={"c": ["channel1", "channel2", "channel3"]}, dims="c")
)


def test_interquartile_range(mocker, image_channel_stats, mock_multi_channel_image):
Expand Down

0 comments on commit 4d8f674

Please sign in to comment.