From f83ba1cead4878c36841b37009c4b3d605a159e2 Mon Sep 17 00:00:00 2001 From: Leonardo Schwarz Date: Wed, 7 Aug 2024 10:46:49 +0200 Subject: [PATCH] add normalize_image method and fix attributes not being preserved --- src/depiction/image/image_normalization.py | 20 ++- tests/unit/image/test_image_normalization.py | 125 ++++++++++++++----- 2 files changed, 107 insertions(+), 38 deletions(-) diff --git a/src/depiction/image/image_normalization.py b/src/depiction/image/image_normalization.py index 0d66664..8fdff1b 100644 --- a/src/depiction/image/image_normalization.py +++ b/src/depiction/image/image_normalization.py @@ -2,6 +2,8 @@ import xarray +from depiction.image.multi_channel_image import MultiChannelImage + # TODO experimental code, untested etc # TODO in principle the interface this method would require is to apply over the pixels of the image @@ -28,13 +30,19 @@ def normalize_xarray(self, image: xarray.DataArray, variant: ImageNormalizationV else: raise NotImplementedError("Multiple index columns are not supported yet.") + def normalize_image(self, image: MultiChannelImage, variant: ImageNormalizationVariant) -> MultiChannelImage: + return MultiChannelImage(self.normalize_xarray(image.data_spatial, variant=variant)) + def _normalize_single_xarray(self, image: xarray.DataArray, variant: ImageNormalizationVariant) -> xarray.DataArray: - if variant == ImageNormalizationVariant.VEC_NORM: - return image / (((image**2).sum(["c"])) ** 0.5) - elif variant == ImageNormalizationVariant.STD: - return (image - image.mean("c")) / image.std("c") - else: - raise NotImplementedError(f"Unknown variant: {variant}") + with xarray.set_options(keep_attrs=True): + if variant == ImageNormalizationVariant.VEC_NORM: + norm = ((image**2).sum(["c"])) ** 0.5 + return xarray.where(norm != 0, image / norm, 0) + elif variant == ImageNormalizationVariant.STD: + std = image.std("c") + return xarray.where(std != 0, (image - image.mean("c")) / std, 0) + else: + raise NotImplementedError(f"Unknown variant: {variant}") def _normalize_multiple_xarray( self, image: xarray.DataArray, index_dim: str, variant: ImageNormalizationVariant diff --git a/tests/unit/image/test_image_normalization.py b/tests/unit/image/test_image_normalization.py index e4ffc78..015759f 100644 --- a/tests/unit/image/test_image_normalization.py +++ b/tests/unit/image/test_image_normalization.py @@ -1,39 +1,100 @@ -import unittest - -import numpy as np -import xarray +import pytest +import xarray as xr from depiction.image.image_normalization import ImageNormalizationVariant, ImageNormalization +from depiction.image.multi_channel_image import MultiChannelImage + + +@pytest.fixture +def image_normalizer(): + return ImageNormalization() + + +@pytest.fixture +def single_image(): + return xr.DataArray( + data=[[[2, 0], [0, 2]], [[1, 1], [4, 1]], [[0, 0], [0, 0]]], dims=["y", "x", "c"], attrs={"bg_value": 0} + ) + + +@pytest.fixture +def multiple_images(): + return xr.DataArray(data=[[[[2, 0]]], [[[0, 3]]]], dims=["whatever", "y", "x", "c"], attrs={"bg_value": 0}) + + +def test_normalize_xarray_single_vec_norm(image_normalizer, single_image): + norm_vec = image_normalizer.normalize_xarray(single_image, variant=ImageNormalizationVariant.VEC_NORM) + expected = xr.DataArray( + data=[[[1, 0], [0, 1]], [[0.707107, 0.707107], [0.970143, 0.242536]], [[0, 0], [0, 0]]], + dims=["y", "x", "c"], + attrs={"bg_value": 0}, + ) + xr.testing.assert_allclose(expected, norm_vec) + assert norm_vec.attrs["bg_value"] == single_image.attrs["bg_value"] + + +def test_normalize_xarray_single_std(image_normalizer, single_image): + norm_std = image_normalizer.normalize_xarray(single_image, variant=ImageNormalizationVariant.STD) + expected = xr.DataArray( + data=[[[1.0, -1.0], [-1.0, 1.0]], [[0.0, 0.0], [1.0, -1.0]], [[0.0, 0.0], [0.0, 0.0]]], + dims=["y", "x", "c"], + attrs={"bg_value": 0}, + ) + xr.testing.assert_allclose(norm_std, expected, rtol=1e-5) + assert norm_std.attrs["bg_value"] == single_image.attrs["bg_value"] + + +def test_normalize_xarray_multiple_vec_norm(image_normalizer, multiple_images): + norm_vec = image_normalizer.normalize_xarray(multiple_images, variant=ImageNormalizationVariant.VEC_NORM) + expected = xr.DataArray( + data=[[[[1, 0]]], [[[0, 1]]]], + dims=["whatever", "y", "x", "c"], + coords={"whatever": [0, 1]}, + attrs={"bg_value": 0}, + ) + xr.testing.assert_allclose(expected, norm_vec) + assert norm_vec.attrs["bg_value"] == multiple_images.attrs["bg_value"] + + +def test_normalize_xarray_multiple_std(image_normalizer, multiple_images): + norm_std = image_normalizer.normalize_xarray(multiple_images, variant=ImageNormalizationVariant.STD) + expected = xr.DataArray( + data=[[[[1, -1]]], [[[-1, 1]]]], + dims=["whatever", "y", "x", "c"], + coords={"whatever": [0, 1]}, + attrs={"bg_value": 0}, + ) + xr.testing.assert_allclose(expected, norm_std) + assert norm_std.attrs["bg_value"] == multiple_images.attrs["bg_value"] + + +def test_normalize_image(image_normalizer, single_image): + multi_channel_image = MultiChannelImage(single_image) + normalized_image = image_normalizer.normalize_image(multi_channel_image, variant=ImageNormalizationVariant.VEC_NORM) + assert isinstance(normalized_image, MultiChannelImage) + xr.testing.assert_allclose( + normalized_image.data_spatial, + image_normalizer.normalize_xarray(single_image, variant=ImageNormalizationVariant.VEC_NORM), + ) + assert normalized_image.data_spatial.attrs["bg_value"] == single_image.attrs["bg_value"] + + +def test_missing_dimensions(image_normalizer): + invalid_image = xr.DataArray(data=[[2, 0], [0, 2]], dims=["y", "x"]) + with pytest.raises(ValueError, match="Missing required dimensions: {'c'}"): + image_normalizer.normalize_xarray(invalid_image, variant=ImageNormalizationVariant.VEC_NORM) + + +def test_multiple_index_dimensions(image_normalizer): + invalid_image = xr.DataArray(data=[[[[[2, 0]]], [[[0, 3]]]]], dims=["dim1", "dim2", "y", "x", "c"]) + with pytest.raises(NotImplementedError, match="Multiple index columns are not supported yet."): + image_normalizer.normalize_xarray(invalid_image, variant=ImageNormalizationVariant.VEC_NORM) -class TestImageNormalization(unittest.TestCase): - def test_normalize_xarray_single_vec_norm(self) -> None: - images = xarray.DataArray( - data=[[[2, 0], [0, 2]], [[1, 1], [4, 1]], [[0, 0], [0, 0]]], - dims=["y", "x", "c"], - ) - norm_vec = ImageNormalization().normalize_xarray(images, variant=ImageNormalizationVariant.VEC_NORM) - self.assertEqual(norm_vec.shape, (3, 2, 2)) - expected = xarray.DataArray( - data=[[[1, 0], [0, 1]], [[0.707107, 0.707107], [0.970143, 0.242536]], [[np.nan, np.nan], [np.nan, np.nan]]], - dims=["y", "x", "c"], - ) - xarray.testing.assert_allclose(expected, norm_vec) - - def test_normalize_xarray_multiple_vec_norm(self) -> None: - images = xarray.DataArray( - data=[[[[2, 0]]], [[[0, 3]]]], - dims=["whatever", "y", "x", "c"], - ) - norm_vec = ImageNormalization().normalize_xarray(images, variant=ImageNormalizationVariant.VEC_NORM) - self.assertEqual(norm_vec.shape, (2, 1, 1, 2)) - expected = xarray.DataArray( - data=[[[[1, 0]]], [[[0, 1]]]], - dims=["whatever", "y", "x", "c"], - coords={"whatever": [0, 1]}, - ) - xarray.testing.assert_allclose(expected, norm_vec) +def test_unknown_variant(image_normalizer, single_image): + with pytest.raises(NotImplementedError, match="Unknown variant: unknown"): + image_normalizer.normalize_xarray(single_image, variant="unknown") if __name__ == "__main__": - unittest.main() + pytest.main()