Skip to content

Commit

Permalink
refactor iqr computation
Browse files Browse the repository at this point in the history
  • Loading branch information
leoschwarz committed Aug 27, 2024
1 parent 60c8f02 commit 7870f72
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 23 deletions.
5 changes: 2 additions & 3 deletions src/depiction/image/feature_selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,5 @@ def _select_features_cv(image: MultiChannelImage, n_features: int) -> list[str]:


def _select_features_iqr(image: MultiChannelImage, n_features: int) -> list[str]:
iqr = image.channel_stats.interquartile_range
n_channels = len(iqr)
return iqr.drop_nulls().sort("iqr").tail(min(n_features, n_channels))["c"].to_list()
iqr = image.channel_stats.interquartile_range.dropna("c")
return list(iqr.sortby(iqr, ascending=False).c.values[:n_features])
15 changes: 3 additions & 12 deletions src/depiction/image/image_channel_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,18 +52,9 @@ def coefficient_of_variation(self) -> pl.DataFrame:
return pl.DataFrame({"c": self._image.channel_names, "cv": data}).fill_nan(None)

@cached_property
def interquartile_range(self) -> pl.DataFrame:
def interquartile_range(self) -> xarray.DataArray:
"""Returns a DataArray with the iqr for each channel, indexed by the coordinate 'c'."""
data = np.zeros(self._image.n_channels)
for i_channel in range(self._image.n_channels):
values = self._get_channel_values(i_channel=i_channel, drop_missing=True)
if len(values) == 0:
data[i_channel] = np.nan
continue
q1 = np.percentile(values, 25)
q3 = np.percentile(values, 75)
data[i_channel] = q3 - q1
return pl.DataFrame({"c": self._image.channel_names, "iqr": data}).fill_nan(None)
return self._compute_scalar_metric(fn=lambda x: np.percentile(x, 75) - np.percentile(x, 25), min_values=2)

@cached_property
def mean(self) -> xarray.DataArray:
Expand All @@ -72,7 +63,7 @@ def mean(self) -> xarray.DataArray:

@cached_property
def std(self) -> xarray.DataArray:
"""Returns a DataFrame with the standard deviation for each channel, columns 'c', and 'std'."""
"""Returns a DataArray with the standard deviation for each channel, indexed by the coordinate 'c'."""
return self._compute_scalar_metric(fn=np.std, min_values=2)

def _compute_scalar_metric(self, fn, min_values: int):
Expand Down
24 changes: 24 additions & 0 deletions tests/unit/image/test_feature_selection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import numpy as np
import pytest
from xarray import DataArray

from depiction.image.feature_selection import FeatureSelectionIQR, select_features
from depiction.image.multi_channel_image import MultiChannelImage


@pytest.fixture()
def image() -> MultiChannelImage:
return MultiChannelImage(
DataArray(
[[[1.0, 2, 0.0], [1.0, 5, 0.5], [1.0, 10, 0.0], [1.0, 20, 0.5], [1.0, 30, 0.0], [1.0, 40, 0.5]]],
dims=("y", "x", "c"),
coords={"c": ["channel1", "channel2", "channel3"]},
attrs={"bg_value": np.nan},
)
)


def test_select_features_iqr(image):
fs = FeatureSelectionIQR.model_validate(dict(n_features=2))
selection = select_features(feature_selection=fs, image=image)
assert selection == ["channel2", "channel3"]
23 changes: 15 additions & 8 deletions tests/unit/image/test_image_channel_stats.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import numpy as np
import polars as pl
import pytest
import polars.testing
import pytest
import xarray.testing

from depiction.image.image_channel_stats import ImageChannelStats
Expand Down Expand Up @@ -61,8 +61,9 @@ def test_interquartile_range(mocker, image_channel_stats, mock_multi_channel_ima
image_channel_stats, "_get_channel_values", side_effect=[mock_data[0], mock_data[1], mock_data[2]]
)
result = image_channel_stats.interquartile_range
expected = pl.DataFrame({"c": ["channel1", "channel2", "channel3"], "iqr": [2, 2, 2]})
assert result.equals(expected)
xarray.testing.assert_allclose(
result, xarray.DataArray([2, 2, 2], coords={"c": ["channel1", "channel2", "channel3"]}, dims="c")
)


def test_mean(mocker, image_channel_stats, mock_multi_channel_image):
Expand Down Expand Up @@ -115,19 +116,25 @@ def test_get_channel_values_with_nan(image_channel_stats, mock_multi_channel_ima
np.testing.assert_array_equal(result, np.array([1, 2, 3, 5]))


def test_empty_channel(mocker, image_channel_stats, mock_multi_channel_image):
def test_five_number_summary_when_empty_channel(mocker, image_channel_stats, mock_multi_channel_image):
mock_data = np.array([])

mocker.patch.object(image_channel_stats, "_get_channel_values", return_value=mock_data)
five_number_summary = image_channel_stats.five_number_summary
interquartile_range = image_channel_stats.interquartile_range

assert five_number_summary["min"][0] is None
assert five_number_summary["q1"][0] is None
assert five_number_summary["median"][0] is None
assert five_number_summary["q3"][0] is None
assert five_number_summary["max"][0] is None
assert interquartile_range["iqr"][0] is None


def test_interquartile_range_when_emtpy_channel(mocker, image_channel_stats, mock_multi_channel_image):
mock_data = np.array([])
mocker.patch.object(image_channel_stats, "_get_channel_values", return_value=mock_data)
interquartile_range = image_channel_stats.interquartile_range
xarray.testing.assert_equal(
interquartile_range,
xarray.DataArray([None, None, None], coords={"c": ["channel1", "channel2", "channel3"]}, dims="c"),
)


if __name__ == "__main__":
Expand Down

0 comments on commit 7870f72

Please sign in to comment.