Skip to content

Commit

Permalink
add normalize_image method and fix attributes not being preserved
Browse files Browse the repository at this point in the history
  • Loading branch information
leoschwarz committed Aug 7, 2024
1 parent 5d9447f commit f83ba1c
Show file tree
Hide file tree
Showing 2 changed files with 107 additions and 38 deletions.
20 changes: 14 additions & 6 deletions src/depiction/image/image_normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

import xarray

from depiction.image.multi_channel_image import MultiChannelImage


# TODO experimental code, untested etc
# TODO in principle the interface this method would require is to apply over the pixels of the image
Expand All @@ -28,13 +30,19 @@ def normalize_xarray(self, image: xarray.DataArray, variant: ImageNormalizationV
else:
raise NotImplementedError("Multiple index columns are not supported yet.")

def normalize_image(self, image: MultiChannelImage, variant: ImageNormalizationVariant) -> MultiChannelImage:
return MultiChannelImage(self.normalize_xarray(image.data_spatial, variant=variant))

def _normalize_single_xarray(self, image: xarray.DataArray, variant: ImageNormalizationVariant) -> xarray.DataArray:
if variant == ImageNormalizationVariant.VEC_NORM:
return image / (((image**2).sum(["c"])) ** 0.5)
elif variant == ImageNormalizationVariant.STD:
return (image - image.mean("c")) / image.std("c")
else:
raise NotImplementedError(f"Unknown variant: {variant}")
with xarray.set_options(keep_attrs=True):
if variant == ImageNormalizationVariant.VEC_NORM:
norm = ((image**2).sum(["c"])) ** 0.5
return xarray.where(norm != 0, image / norm, 0)
elif variant == ImageNormalizationVariant.STD:
std = image.std("c")
return xarray.where(std != 0, (image - image.mean("c")) / std, 0)
else:
raise NotImplementedError(f"Unknown variant: {variant}")

def _normalize_multiple_xarray(
self, image: xarray.DataArray, index_dim: str, variant: ImageNormalizationVariant
Expand Down
125 changes: 93 additions & 32 deletions tests/unit/image/test_image_normalization.py
Original file line number Diff line number Diff line change
@@ -1,39 +1,100 @@
import unittest

import numpy as np
import xarray
import pytest
import xarray as xr

from depiction.image.image_normalization import ImageNormalizationVariant, ImageNormalization
from depiction.image.multi_channel_image import MultiChannelImage


@pytest.fixture
def image_normalizer():
return ImageNormalization()


@pytest.fixture
def single_image():
return xr.DataArray(
data=[[[2, 0], [0, 2]], [[1, 1], [4, 1]], [[0, 0], [0, 0]]], dims=["y", "x", "c"], attrs={"bg_value": 0}
)


@pytest.fixture
def multiple_images():
return xr.DataArray(data=[[[[2, 0]]], [[[0, 3]]]], dims=["whatever", "y", "x", "c"], attrs={"bg_value": 0})


def test_normalize_xarray_single_vec_norm(image_normalizer, single_image):
norm_vec = image_normalizer.normalize_xarray(single_image, variant=ImageNormalizationVariant.VEC_NORM)
expected = xr.DataArray(
data=[[[1, 0], [0, 1]], [[0.707107, 0.707107], [0.970143, 0.242536]], [[0, 0], [0, 0]]],
dims=["y", "x", "c"],
attrs={"bg_value": 0},
)
xr.testing.assert_allclose(expected, norm_vec)
assert norm_vec.attrs["bg_value"] == single_image.attrs["bg_value"]


def test_normalize_xarray_single_std(image_normalizer, single_image):
norm_std = image_normalizer.normalize_xarray(single_image, variant=ImageNormalizationVariant.STD)
expected = xr.DataArray(
data=[[[1.0, -1.0], [-1.0, 1.0]], [[0.0, 0.0], [1.0, -1.0]], [[0.0, 0.0], [0.0, 0.0]]],
dims=["y", "x", "c"],
attrs={"bg_value": 0},
)
xr.testing.assert_allclose(norm_std, expected, rtol=1e-5)
assert norm_std.attrs["bg_value"] == single_image.attrs["bg_value"]


def test_normalize_xarray_multiple_vec_norm(image_normalizer, multiple_images):
norm_vec = image_normalizer.normalize_xarray(multiple_images, variant=ImageNormalizationVariant.VEC_NORM)
expected = xr.DataArray(
data=[[[[1, 0]]], [[[0, 1]]]],
dims=["whatever", "y", "x", "c"],
coords={"whatever": [0, 1]},
attrs={"bg_value": 0},
)
xr.testing.assert_allclose(expected, norm_vec)
assert norm_vec.attrs["bg_value"] == multiple_images.attrs["bg_value"]


def test_normalize_xarray_multiple_std(image_normalizer, multiple_images):
norm_std = image_normalizer.normalize_xarray(multiple_images, variant=ImageNormalizationVariant.STD)
expected = xr.DataArray(
data=[[[[1, -1]]], [[[-1, 1]]]],
dims=["whatever", "y", "x", "c"],
coords={"whatever": [0, 1]},
attrs={"bg_value": 0},
)
xr.testing.assert_allclose(expected, norm_std)
assert norm_std.attrs["bg_value"] == multiple_images.attrs["bg_value"]


def test_normalize_image(image_normalizer, single_image):
multi_channel_image = MultiChannelImage(single_image)
normalized_image = image_normalizer.normalize_image(multi_channel_image, variant=ImageNormalizationVariant.VEC_NORM)
assert isinstance(normalized_image, MultiChannelImage)
xr.testing.assert_allclose(
normalized_image.data_spatial,
image_normalizer.normalize_xarray(single_image, variant=ImageNormalizationVariant.VEC_NORM),
)
assert normalized_image.data_spatial.attrs["bg_value"] == single_image.attrs["bg_value"]


def test_missing_dimensions(image_normalizer):
invalid_image = xr.DataArray(data=[[2, 0], [0, 2]], dims=["y", "x"])
with pytest.raises(ValueError, match="Missing required dimensions: {'c'}"):
image_normalizer.normalize_xarray(invalid_image, variant=ImageNormalizationVariant.VEC_NORM)


def test_multiple_index_dimensions(image_normalizer):
invalid_image = xr.DataArray(data=[[[[[2, 0]]], [[[0, 3]]]]], dims=["dim1", "dim2", "y", "x", "c"])
with pytest.raises(NotImplementedError, match="Multiple index columns are not supported yet."):
image_normalizer.normalize_xarray(invalid_image, variant=ImageNormalizationVariant.VEC_NORM)


class TestImageNormalization(unittest.TestCase):
def test_normalize_xarray_single_vec_norm(self) -> None:
images = xarray.DataArray(
data=[[[2, 0], [0, 2]], [[1, 1], [4, 1]], [[0, 0], [0, 0]]],
dims=["y", "x", "c"],
)
norm_vec = ImageNormalization().normalize_xarray(images, variant=ImageNormalizationVariant.VEC_NORM)
self.assertEqual(norm_vec.shape, (3, 2, 2))
expected = xarray.DataArray(
data=[[[1, 0], [0, 1]], [[0.707107, 0.707107], [0.970143, 0.242536]], [[np.nan, np.nan], [np.nan, np.nan]]],
dims=["y", "x", "c"],
)
xarray.testing.assert_allclose(expected, norm_vec)

def test_normalize_xarray_multiple_vec_norm(self) -> None:
images = xarray.DataArray(
data=[[[[2, 0]]], [[[0, 3]]]],
dims=["whatever", "y", "x", "c"],
)
norm_vec = ImageNormalization().normalize_xarray(images, variant=ImageNormalizationVariant.VEC_NORM)
self.assertEqual(norm_vec.shape, (2, 1, 1, 2))
expected = xarray.DataArray(
data=[[[[1, 0]]], [[[0, 1]]]],
dims=["whatever", "y", "x", "c"],
coords={"whatever": [0, 1]},
)
xarray.testing.assert_allclose(expected, norm_vec)
def test_unknown_variant(image_normalizer, single_image):
with pytest.raises(NotImplementedError, match="Unknown variant: unknown"):
image_normalizer.normalize_xarray(single_image, variant="unknown")


if __name__ == "__main__":
unittest.main()
pytest.main()

0 comments on commit f83ba1c

Please sign in to comment.