diff --git a/src/depiction/image/image_normalization.py b/src/depiction/image/image_normalization.py index 8fdff1b..6d3d90f 100644 --- a/src/depiction/image/image_normalization.py +++ b/src/depiction/image/image_normalization.py @@ -37,10 +37,10 @@ def _normalize_single_xarray(self, image: xarray.DataArray, variant: ImageNormal with xarray.set_options(keep_attrs=True): if variant == ImageNormalizationVariant.VEC_NORM: norm = ((image**2).sum(["c"])) ** 0.5 - return xarray.where(norm != 0, image / norm, 0) + return xarray.where(norm != 0, image / norm, image.attrs.get("bg_value", 0)) elif variant == ImageNormalizationVariant.STD: std = image.std("c") - return xarray.where(std != 0, (image - image.mean("c")) / std, 0) + return xarray.where(std != 0, (image - image.mean("c")) / std, image.attrs.get("bg_value", 0)) else: raise NotImplementedError(f"Unknown variant: {variant}") diff --git a/src/depiction/tools/clustering.py b/src/depiction/tools/clustering.py index c0b7ff9..8be4291 100644 --- a/src/depiction/tools/clustering.py +++ b/src/depiction/tools/clustering.py @@ -100,7 +100,10 @@ def clustering( full_features=image_features.data_flat.values.T, ) label_image = MultiChannelImage.from_sparse( - values=full_labels[:, np.newaxis], coordinates=image_full_features.coordinates_flat, channel_names=["cluster"] + values=full_labels[:, np.newaxis], + coordinates=image_full_features.coordinates_flat, + channel_names=["cluster"], + bg_value=np.nan, ) # write the result of the operation diff --git a/tests/unit/image/test_image_normalization.py b/tests/unit/image/test_image_normalization.py index 015759f..ea2e56a 100644 --- a/tests/unit/image/test_image_normalization.py +++ b/tests/unit/image/test_image_normalization.py @@ -1,3 +1,4 @@ +import numpy as np import pytest import xarray as xr @@ -33,6 +34,22 @@ def test_normalize_xarray_single_vec_norm(image_normalizer, single_image): assert norm_vec.attrs["bg_value"] == single_image.attrs["bg_value"] +def test_normalize_xarray_single_vec_norm_with_nans(image_normalizer): + image_with_nans = xr.DataArray( + data=[[[2, np.nan], [0, 2]], [[1, 1], [4, np.nan]], [[np.nan, 0], [0, 0]]], + dims=["y", "x", "c"], + attrs={"bg_value": np.nan}, + ) + norm_vec = image_normalizer.normalize_xarray(image_with_nans, variant=ImageNormalizationVariant.VEC_NORM) + expected = xr.DataArray( + data=[[[1, np.nan], [0, 1]], [[0.707107, 0.707107], [1, np.nan]], [[np.nan, np.nan], [np.nan, np.nan]]], + dims=["y", "x", "c"], + attrs={"bg_value": np.nan}, + ) + xr.testing.assert_allclose(expected, norm_vec) + assert np.isnan(norm_vec.attrs["bg_value"]) + + def test_normalize_xarray_single_std(image_normalizer, single_image): norm_std = image_normalizer.normalize_xarray(single_image, variant=ImageNormalizationVariant.STD) expected = xr.DataArray(