From 60c8f026737e5948432da692505fd0160e073f68 Mon Sep 17 00:00:00 2001 From: Leonardo Schwarz Date: Tue, 27 Aug 2024 13:59:21 +0200 Subject: [PATCH] simplify get_z_scaled by use of xarray --- src/depiction/image/multi_channel_image.py | 5 ++--- tests/unit/image/test_multi_channel_image.py | 1 - 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/src/depiction/image/multi_channel_image.py b/src/depiction/image/multi_channel_image.py index 1e021c5..5342765 100644 --- a/src/depiction/image/multi_channel_image.py +++ b/src/depiction/image/multi_channel_image.py @@ -164,11 +164,10 @@ def append_channels(self, other: MultiChannelImage) -> MultiChannelImage: return MultiChannelImage(data=data) def get_z_scaled(self) -> MultiChannelImage: - means = xarray.DataArray(self.channel_stats.mean["mean"], dims="c", coords={"c": self.channel_names}) - stds = xarray.DataArray(self.channel_stats.std["std"], dims="c", coords={"c": self.channel_names}) + """Returns a copy of self with each feature z-scaled.""" eps = 1e-12 with xarray.set_options(keep_attrs=True): - return MultiChannelImage(data=(self._data - means + eps) / (stds + eps)) + return MultiChannelImage(data=(self._data - self.channel_stats.mean + eps) / (self.channel_stats.std + eps)) # TODO reconsider:there is actually a problem, whether it should use bg_mask only or also replace individual values # since both could be necessary it should be implemented in a sane and maintainable manner diff --git a/tests/unit/image/test_multi_channel_image.py b/tests/unit/image/test_multi_channel_image.py index 938aefa..3e3b388 100644 --- a/tests/unit/image/test_multi_channel_image.py +++ b/tests/unit/image/test_multi_channel_image.py @@ -212,7 +212,6 @@ def test_append_channels(mock_image: MultiChannelImage) -> None: def test_get_z_scaled(mock_image: MultiChannelImage) -> None: result = mock_image.get_z_scaled() - # TODO I am not fully sure this is correct yet np.testing.assert_almost_equal( np.array([[-1.46385011, -0.87831007], [-0.29277002, 0.29277002], [0.87831007, 1.46385011]]), result.data_spatial[:, :, 0].values,