From 2f11603d5803f467de95f6e7b51e2a4b0cc8aa68 Mon Sep 17 00:00:00 2001 From: Leonardo Schwarz Date: Tue, 6 Aug 2024 17:05:30 +0200 Subject: [PATCH] add multi_channel_image_concatenation.py --- .../multi_channel_image_concatenation.py | 76 ++++++++++++ .../test_multi_channel_image_concatenation.py | 116 ++++++++++++++++++ 2 files changed, 192 insertions(+) create mode 100644 src/depiction/image/multi_channel_image_concatenation.py create mode 100644 tests/unit/image/test_multi_channel_image_concatenation.py diff --git a/src/depiction/image/multi_channel_image_concatenation.py b/src/depiction/image/multi_channel_image_concatenation.py new file mode 100644 index 0000000..367913b --- /dev/null +++ b/src/depiction/image/multi_channel_image_concatenation.py @@ -0,0 +1,76 @@ +from __future__ import annotations + +from functools import cached_property +from pathlib import Path + +import numpy as np + +from depiction.image.horizontal_concat import horizontal_concat +from depiction.image.multi_channel_image import MultiChannelImage + + +# TODO properly document (y, x) vs (x, y) and the min_coords in get_single_image and get_single_images + + +class MultiChannelImageConcatenation: + """Represents a concatenation of multiple multi-channel images, with potentially different shapes + (but same number of channels per image and same background value). + + This is done by concatenating the images in the spatial domain, so it is possible to obtain a single + multi-channel image as well as a list of individual multi-channel images. + """ + + def __init__(self, data: MultiChannelImage) -> None: + self._data = data + + @cached_property + def n_individual_images(self) -> int: + """Number of individual images.""" + return int(self._data.retain_channels(coords=["image_index"]).data_flat.max().values + 1) + + def get_combined_image(self) -> MultiChannelImage: + return self._data.drop_channels(coords=["image_index"], allow_missing=False) + + def get_combined_image_index(self) -> MultiChannelImage: + return self._data.retain_channels(coords=["image_index"]) + + def get_single_image(self, index: int, min_coords: tuple[int, int] = (0, 0)) -> MultiChannelImage: + # perform the selection in flat representation for sanity + # all_values = self._data.data_flat.drop_sel(c="image_index", allow_missing=False) + all_values = self._data.drop_channels(coords=["image_index"], allow_missing=False).data_flat + all_coords = self._data.coordinates_flat + + # determine the indices in flat representation, corresponding to the requested image + sel_indices = np.where(self._data.data_flat.sel(c="image_index").values == index)[0] + + # select the values and coordinates + sel_values = all_values.isel(i=sel_indices) + sel_coords = all_coords.isel(i=sel_indices) + + # readjust the coordinates + sel_coords = sel_coords - sel_coords.min(axis=1) + np.array(min_coords)[:, None] + + # create the individual image + return MultiChannelImage.from_sparse( + values=sel_values, + coordinates=sel_coords, + channel_names=sel_values.coords["c"].values.tolist(), + bg_value=sel_values.bg_value, + ) + + def get_single_images(self) -> list[MultiChannelImage]: + return [self.get_single_image(index=index) for index in range(self.n_individual_images)] + + @classmethod + def read_hdf5(cls, path: Path) -> MultiChannelImageConcatenation: + return cls(data=MultiChannelImage.read_hdf5(path=path)) + + def write_hdf5(self, path: Path) -> None: + self._data.write_hdf5(path=path) + + @classmethod + def concat_images(cls, images: list[MultiChannelImage]) -> MultiChannelImageConcatenation: + """Returns the horizontal concatenation of the provided images.""" + # TODO consider introducing a padding step as it would be more precise + data = horizontal_concat(images=images, add_index=True, index_channel="image_index") + return cls(data=data) diff --git a/tests/unit/image/test_multi_channel_image_concatenation.py b/tests/unit/image/test_multi_channel_image_concatenation.py new file mode 100644 index 0000000..54ac2eb --- /dev/null +++ b/tests/unit/image/test_multi_channel_image_concatenation.py @@ -0,0 +1,116 @@ +from __future__ import annotations + +import numpy as np +import pytest +import xarray +from pytest_mock import MockerFixture +from xarray import DataArray + +from depiction.image.multi_channel_image import MultiChannelImage +from depiction.image.multi_channel_image_concatenation import MultiChannelImageConcatenation + + +@pytest.fixture() +def channel_names() -> list[str]: + return ["Channel A", "Channel B"] + + +def _construct_single_image(data: np.ndarray, channel_names: list[str]) -> MultiChannelImage: + return MultiChannelImage( + data=DataArray( + data=data, + dims=("y", "x", "c"), + coords={"c": channel_names}, + attrs={"bg_value": np.nan}, + ) + ) + + +@pytest.fixture() +def image_0(channel_names: list[str]) -> MultiChannelImage: + """shape (y=2, x=3, c=2)""" + data = np.arange(12).reshape(2, 3, 2).astype(float) + return _construct_single_image(data=data, channel_names=channel_names) + + +@pytest.fixture() +def image_1(channel_names: list[str]) -> MultiChannelImage: + """shape (y=3, x=4, c=2)""" + data = np.arange(24).reshape(3, 4, 2).astype(float) * 3.0 + return _construct_single_image(data=data, channel_names=channel_names) + + +@pytest.fixture() +def concat_image(image_0: MultiChannelImage, image_1: MultiChannelImage) -> MultiChannelImageConcatenation: + # TODO not sure if this is the nicest way for testing, in general it would probably be nicer if the fixture would + # not use the method for construction but rather do it directly (maybe tbd later) + return MultiChannelImageConcatenation.concat_images([image_0, image_1]) + + +def test_concat_images(concat_image: MultiChannelImageConcatenation) -> None: + assert isinstance(concat_image, MultiChannelImageConcatenation) + + +def test_n_individual_images(concat_image: MultiChannelImageConcatenation) -> None: + assert concat_image.n_individual_images == 2 + + +def test_get_combined_image( + concat_image: MultiChannelImageConcatenation, image_0: MultiChannelImage, image_1: MultiChannelImage +) -> None: + combined_image = concat_image.get_combined_image() + assert combined_image.data_spatial.shape == (3, 7, 2) + assert combined_image.channel_names == image_0.channel_names == image_1.channel_names + assert np.isnan(combined_image.bg_value) + # image 0 + assert combined_image.data_spatial[0, 0, 1] == 1.0 + # image 1 + assert combined_image.data_spatial[0, 3, 1] == 3.0 + # check nan (because of shape differences) + assert np.isnan(combined_image.data_spatial[2, 0, 0]) + + +def test_get_combined_image_index(concat_image: MultiChannelImageConcatenation) -> None: + image_index = concat_image.get_combined_image_index() + assert image_index.data_spatial.shape == (3, 7, 1) + assert image_index.channel_names == ["image_index"] + expected_indices = np.zeros((3, 7, 1), dtype=int) + expected_indices[:, 3:, :] = 1 + np.testing.assert_array_equal(image_index.data_spatial, expected_indices) + + +@pytest.mark.parametrize(["image_index", "image_fixture"], [(0, "image_0"), (1, "image_1")]) +def test_get_single_image(request, concat_image: MultiChannelImageConcatenation, image_index: int, image_fixture: str): + expected_image = request.getfixturevalue(image_fixture) + result_image = concat_image.get_single_image(index=image_index) + assert result_image.dimensions == expected_image.dimensions + xarray.testing.assert_equal(result_image.coordinates_flat, expected_image.coordinates_flat) + xarray.testing.assert_equal(result_image.data_flat, expected_image.data_flat) + + +def test_get_single_images(concat_image: MultiChannelImageConcatenation, mocker: MockerFixture) -> None: + mock_get_single_image = mocker.patch.object( + MultiChannelImageConcatenation, "get_single_image", side_effect=["img1", "img2"] + ) + result = concat_image.get_single_images() + assert result == ["img1", "img2"] + assert mock_get_single_image.mock_calls == [mocker.call(index=0), mocker.call(index=1)] + + +def test_read_hdf5(mocker: MockerFixture) -> None: + mock_read_hdf5 = mocker.patch.object(MultiChannelImage, "read_hdf5") + mock_path = mocker.Mock(name="path", spec=[]) + # TODO nicer assertion + assert MultiChannelImageConcatenation.read_hdf5(path=mock_path)._data == mock_read_hdf5.return_value + mock_read_hdf5.assert_called_once_with(path=mock_path) + + +def test_write_hdf5(mocker: MockerFixture, concat_image: MultiChannelImageConcatenation) -> None: + mock_write_hdf5 = mocker.patch.object(MultiChannelImage, "write_hdf5") + mock_path = mocker.Mock(name="path", spec=[]) + concat_image.write_hdf5(path=mock_path) + mock_write_hdf5.assert_called_once_with(path=mock_path) + + +if __name__ == "__main__": + pytest.main()