-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add multi_channel_image_concatenation.py
- Loading branch information
1 parent
d550ec7
commit 2f11603
Showing
2 changed files
with
192 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
from __future__ import annotations | ||
|
||
from functools import cached_property | ||
from pathlib import Path | ||
|
||
import numpy as np | ||
|
||
from depiction.image.horizontal_concat import horizontal_concat | ||
from depiction.image.multi_channel_image import MultiChannelImage | ||
|
||
|
||
# TODO properly document (y, x) vs (x, y) and the min_coords in get_single_image and get_single_images | ||
|
||
|
||
class MultiChannelImageConcatenation: | ||
"""Represents a concatenation of multiple multi-channel images, with potentially different shapes | ||
(but same number of channels per image and same background value). | ||
This is done by concatenating the images in the spatial domain, so it is possible to obtain a single | ||
multi-channel image as well as a list of individual multi-channel images. | ||
""" | ||
|
||
def __init__(self, data: MultiChannelImage) -> None: | ||
self._data = data | ||
|
||
@cached_property | ||
def n_individual_images(self) -> int: | ||
"""Number of individual images.""" | ||
return int(self._data.retain_channels(coords=["image_index"]).data_flat.max().values + 1) | ||
|
||
def get_combined_image(self) -> MultiChannelImage: | ||
return self._data.drop_channels(coords=["image_index"], allow_missing=False) | ||
|
||
def get_combined_image_index(self) -> MultiChannelImage: | ||
return self._data.retain_channels(coords=["image_index"]) | ||
|
||
def get_single_image(self, index: int, min_coords: tuple[int, int] = (0, 0)) -> MultiChannelImage: | ||
# perform the selection in flat representation for sanity | ||
# all_values = self._data.data_flat.drop_sel(c="image_index", allow_missing=False) | ||
all_values = self._data.drop_channels(coords=["image_index"], allow_missing=False).data_flat | ||
all_coords = self._data.coordinates_flat | ||
|
||
# determine the indices in flat representation, corresponding to the requested image | ||
sel_indices = np.where(self._data.data_flat.sel(c="image_index").values == index)[0] | ||
|
||
# select the values and coordinates | ||
sel_values = all_values.isel(i=sel_indices) | ||
sel_coords = all_coords.isel(i=sel_indices) | ||
|
||
# readjust the coordinates | ||
sel_coords = sel_coords - sel_coords.min(axis=1) + np.array(min_coords)[:, None] | ||
|
||
# create the individual image | ||
return MultiChannelImage.from_sparse( | ||
values=sel_values, | ||
coordinates=sel_coords, | ||
channel_names=sel_values.coords["c"].values.tolist(), | ||
bg_value=sel_values.bg_value, | ||
) | ||
|
||
def get_single_images(self) -> list[MultiChannelImage]: | ||
return [self.get_single_image(index=index) for index in range(self.n_individual_images)] | ||
|
||
@classmethod | ||
def read_hdf5(cls, path: Path) -> MultiChannelImageConcatenation: | ||
return cls(data=MultiChannelImage.read_hdf5(path=path)) | ||
|
||
def write_hdf5(self, path: Path) -> None: | ||
self._data.write_hdf5(path=path) | ||
|
||
@classmethod | ||
def concat_images(cls, images: list[MultiChannelImage]) -> MultiChannelImageConcatenation: | ||
"""Returns the horizontal concatenation of the provided images.""" | ||
# TODO consider introducing a padding step as it would be more precise | ||
data = horizontal_concat(images=images, add_index=True, index_channel="image_index") | ||
return cls(data=data) |
116 changes: 116 additions & 0 deletions
116
tests/unit/image/test_multi_channel_image_concatenation.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,116 @@ | ||
from __future__ import annotations | ||
|
||
import numpy as np | ||
import pytest | ||
import xarray | ||
from pytest_mock import MockerFixture | ||
from xarray import DataArray | ||
|
||
from depiction.image.multi_channel_image import MultiChannelImage | ||
from depiction.image.multi_channel_image_concatenation import MultiChannelImageConcatenation | ||
|
||
|
||
@pytest.fixture() | ||
def channel_names() -> list[str]: | ||
return ["Channel A", "Channel B"] | ||
|
||
|
||
def _construct_single_image(data: np.ndarray, channel_names: list[str]) -> MultiChannelImage: | ||
return MultiChannelImage( | ||
data=DataArray( | ||
data=data, | ||
dims=("y", "x", "c"), | ||
coords={"c": channel_names}, | ||
attrs={"bg_value": np.nan}, | ||
) | ||
) | ||
|
||
|
||
@pytest.fixture() | ||
def image_0(channel_names: list[str]) -> MultiChannelImage: | ||
"""shape (y=2, x=3, c=2)""" | ||
data = np.arange(12).reshape(2, 3, 2).astype(float) | ||
return _construct_single_image(data=data, channel_names=channel_names) | ||
|
||
|
||
@pytest.fixture() | ||
def image_1(channel_names: list[str]) -> MultiChannelImage: | ||
"""shape (y=3, x=4, c=2)""" | ||
data = np.arange(24).reshape(3, 4, 2).astype(float) * 3.0 | ||
return _construct_single_image(data=data, channel_names=channel_names) | ||
|
||
|
||
@pytest.fixture() | ||
def concat_image(image_0: MultiChannelImage, image_1: MultiChannelImage) -> MultiChannelImageConcatenation: | ||
# TODO not sure if this is the nicest way for testing, in general it would probably be nicer if the fixture would | ||
# not use the method for construction but rather do it directly (maybe tbd later) | ||
return MultiChannelImageConcatenation.concat_images([image_0, image_1]) | ||
|
||
|
||
def test_concat_images(concat_image: MultiChannelImageConcatenation) -> None: | ||
assert isinstance(concat_image, MultiChannelImageConcatenation) | ||
|
||
|
||
def test_n_individual_images(concat_image: MultiChannelImageConcatenation) -> None: | ||
assert concat_image.n_individual_images == 2 | ||
|
||
|
||
def test_get_combined_image( | ||
concat_image: MultiChannelImageConcatenation, image_0: MultiChannelImage, image_1: MultiChannelImage | ||
) -> None: | ||
combined_image = concat_image.get_combined_image() | ||
assert combined_image.data_spatial.shape == (3, 7, 2) | ||
assert combined_image.channel_names == image_0.channel_names == image_1.channel_names | ||
assert np.isnan(combined_image.bg_value) | ||
# image 0 | ||
assert combined_image.data_spatial[0, 0, 1] == 1.0 | ||
# image 1 | ||
assert combined_image.data_spatial[0, 3, 1] == 3.0 | ||
# check nan (because of shape differences) | ||
assert np.isnan(combined_image.data_spatial[2, 0, 0]) | ||
|
||
|
||
def test_get_combined_image_index(concat_image: MultiChannelImageConcatenation) -> None: | ||
image_index = concat_image.get_combined_image_index() | ||
assert image_index.data_spatial.shape == (3, 7, 1) | ||
assert image_index.channel_names == ["image_index"] | ||
expected_indices = np.zeros((3, 7, 1), dtype=int) | ||
expected_indices[:, 3:, :] = 1 | ||
np.testing.assert_array_equal(image_index.data_spatial, expected_indices) | ||
|
||
|
||
@pytest.mark.parametrize(["image_index", "image_fixture"], [(0, "image_0"), (1, "image_1")]) | ||
def test_get_single_image(request, concat_image: MultiChannelImageConcatenation, image_index: int, image_fixture: str): | ||
expected_image = request.getfixturevalue(image_fixture) | ||
result_image = concat_image.get_single_image(index=image_index) | ||
assert result_image.dimensions == expected_image.dimensions | ||
xarray.testing.assert_equal(result_image.coordinates_flat, expected_image.coordinates_flat) | ||
xarray.testing.assert_equal(result_image.data_flat, expected_image.data_flat) | ||
|
||
|
||
def test_get_single_images(concat_image: MultiChannelImageConcatenation, mocker: MockerFixture) -> None: | ||
mock_get_single_image = mocker.patch.object( | ||
MultiChannelImageConcatenation, "get_single_image", side_effect=["img1", "img2"] | ||
) | ||
result = concat_image.get_single_images() | ||
assert result == ["img1", "img2"] | ||
assert mock_get_single_image.mock_calls == [mocker.call(index=0), mocker.call(index=1)] | ||
|
||
|
||
def test_read_hdf5(mocker: MockerFixture) -> None: | ||
mock_read_hdf5 = mocker.patch.object(MultiChannelImage, "read_hdf5") | ||
mock_path = mocker.Mock(name="path", spec=[]) | ||
# TODO nicer assertion | ||
assert MultiChannelImageConcatenation.read_hdf5(path=mock_path)._data == mock_read_hdf5.return_value | ||
mock_read_hdf5.assert_called_once_with(path=mock_path) | ||
|
||
|
||
def test_write_hdf5(mocker: MockerFixture, concat_image: MultiChannelImageConcatenation) -> None: | ||
mock_write_hdf5 = mocker.patch.object(MultiChannelImage, "write_hdf5") | ||
mock_path = mocker.Mock(name="path", spec=[]) | ||
concat_image.write_hdf5(path=mock_path) | ||
mock_write_hdf5.assert_called_once_with(path=mock_path) | ||
|
||
|
||
if __name__ == "__main__": | ||
pytest.main() |