Skip to content

Commit

Permalink
add multi_channel_image_concatenation.py
Browse files Browse the repository at this point in the history
  • Loading branch information
leoschwarz committed Aug 6, 2024
1 parent d550ec7 commit 2f11603
Show file tree
Hide file tree
Showing 2 changed files with 192 additions and 0 deletions.
76 changes: 76 additions & 0 deletions src/depiction/image/multi_channel_image_concatenation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
from __future__ import annotations

from functools import cached_property
from pathlib import Path

import numpy as np

from depiction.image.horizontal_concat import horizontal_concat
from depiction.image.multi_channel_image import MultiChannelImage


# TODO properly document (y, x) vs (x, y) and the min_coords in get_single_image and get_single_images


class MultiChannelImageConcatenation:
"""Represents a concatenation of multiple multi-channel images, with potentially different shapes
(but same number of channels per image and same background value).
This is done by concatenating the images in the spatial domain, so it is possible to obtain a single
multi-channel image as well as a list of individual multi-channel images.
"""

def __init__(self, data: MultiChannelImage) -> None:
self._data = data

@cached_property
def n_individual_images(self) -> int:
"""Number of individual images."""
return int(self._data.retain_channels(coords=["image_index"]).data_flat.max().values + 1)

def get_combined_image(self) -> MultiChannelImage:
return self._data.drop_channels(coords=["image_index"], allow_missing=False)

def get_combined_image_index(self) -> MultiChannelImage:
return self._data.retain_channels(coords=["image_index"])

def get_single_image(self, index: int, min_coords: tuple[int, int] = (0, 0)) -> MultiChannelImage:
# perform the selection in flat representation for sanity
# all_values = self._data.data_flat.drop_sel(c="image_index", allow_missing=False)
all_values = self._data.drop_channels(coords=["image_index"], allow_missing=False).data_flat
all_coords = self._data.coordinates_flat

# determine the indices in flat representation, corresponding to the requested image
sel_indices = np.where(self._data.data_flat.sel(c="image_index").values == index)[0]

# select the values and coordinates
sel_values = all_values.isel(i=sel_indices)
sel_coords = all_coords.isel(i=sel_indices)

# readjust the coordinates
sel_coords = sel_coords - sel_coords.min(axis=1) + np.array(min_coords)[:, None]

# create the individual image
return MultiChannelImage.from_sparse(
values=sel_values,
coordinates=sel_coords,
channel_names=sel_values.coords["c"].values.tolist(),
bg_value=sel_values.bg_value,
)

def get_single_images(self) -> list[MultiChannelImage]:
return [self.get_single_image(index=index) for index in range(self.n_individual_images)]

@classmethod
def read_hdf5(cls, path: Path) -> MultiChannelImageConcatenation:
return cls(data=MultiChannelImage.read_hdf5(path=path))

def write_hdf5(self, path: Path) -> None:
self._data.write_hdf5(path=path)

@classmethod
def concat_images(cls, images: list[MultiChannelImage]) -> MultiChannelImageConcatenation:
"""Returns the horizontal concatenation of the provided images."""
# TODO consider introducing a padding step as it would be more precise
data = horizontal_concat(images=images, add_index=True, index_channel="image_index")
return cls(data=data)
116 changes: 116 additions & 0 deletions tests/unit/image/test_multi_channel_image_concatenation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
from __future__ import annotations

import numpy as np
import pytest
import xarray
from pytest_mock import MockerFixture
from xarray import DataArray

from depiction.image.multi_channel_image import MultiChannelImage
from depiction.image.multi_channel_image_concatenation import MultiChannelImageConcatenation


@pytest.fixture()
def channel_names() -> list[str]:
return ["Channel A", "Channel B"]


def _construct_single_image(data: np.ndarray, channel_names: list[str]) -> MultiChannelImage:
return MultiChannelImage(
data=DataArray(
data=data,
dims=("y", "x", "c"),
coords={"c": channel_names},
attrs={"bg_value": np.nan},
)
)


@pytest.fixture()
def image_0(channel_names: list[str]) -> MultiChannelImage:
"""shape (y=2, x=3, c=2)"""
data = np.arange(12).reshape(2, 3, 2).astype(float)
return _construct_single_image(data=data, channel_names=channel_names)


@pytest.fixture()
def image_1(channel_names: list[str]) -> MultiChannelImage:
"""shape (y=3, x=4, c=2)"""
data = np.arange(24).reshape(3, 4, 2).astype(float) * 3.0
return _construct_single_image(data=data, channel_names=channel_names)


@pytest.fixture()
def concat_image(image_0: MultiChannelImage, image_1: MultiChannelImage) -> MultiChannelImageConcatenation:
# TODO not sure if this is the nicest way for testing, in general it would probably be nicer if the fixture would
# not use the method for construction but rather do it directly (maybe tbd later)
return MultiChannelImageConcatenation.concat_images([image_0, image_1])


def test_concat_images(concat_image: MultiChannelImageConcatenation) -> None:
assert isinstance(concat_image, MultiChannelImageConcatenation)


def test_n_individual_images(concat_image: MultiChannelImageConcatenation) -> None:
assert concat_image.n_individual_images == 2


def test_get_combined_image(
concat_image: MultiChannelImageConcatenation, image_0: MultiChannelImage, image_1: MultiChannelImage
) -> None:
combined_image = concat_image.get_combined_image()
assert combined_image.data_spatial.shape == (3, 7, 2)
assert combined_image.channel_names == image_0.channel_names == image_1.channel_names
assert np.isnan(combined_image.bg_value)
# image 0
assert combined_image.data_spatial[0, 0, 1] == 1.0
# image 1
assert combined_image.data_spatial[0, 3, 1] == 3.0
# check nan (because of shape differences)
assert np.isnan(combined_image.data_spatial[2, 0, 0])


def test_get_combined_image_index(concat_image: MultiChannelImageConcatenation) -> None:
image_index = concat_image.get_combined_image_index()
assert image_index.data_spatial.shape == (3, 7, 1)
assert image_index.channel_names == ["image_index"]
expected_indices = np.zeros((3, 7, 1), dtype=int)
expected_indices[:, 3:, :] = 1
np.testing.assert_array_equal(image_index.data_spatial, expected_indices)


@pytest.mark.parametrize(["image_index", "image_fixture"], [(0, "image_0"), (1, "image_1")])
def test_get_single_image(request, concat_image: MultiChannelImageConcatenation, image_index: int, image_fixture: str):
expected_image = request.getfixturevalue(image_fixture)
result_image = concat_image.get_single_image(index=image_index)
assert result_image.dimensions == expected_image.dimensions
xarray.testing.assert_equal(result_image.coordinates_flat, expected_image.coordinates_flat)
xarray.testing.assert_equal(result_image.data_flat, expected_image.data_flat)


def test_get_single_images(concat_image: MultiChannelImageConcatenation, mocker: MockerFixture) -> None:
mock_get_single_image = mocker.patch.object(
MultiChannelImageConcatenation, "get_single_image", side_effect=["img1", "img2"]
)
result = concat_image.get_single_images()
assert result == ["img1", "img2"]
assert mock_get_single_image.mock_calls == [mocker.call(index=0), mocker.call(index=1)]


def test_read_hdf5(mocker: MockerFixture) -> None:
mock_read_hdf5 = mocker.patch.object(MultiChannelImage, "read_hdf5")
mock_path = mocker.Mock(name="path", spec=[])
# TODO nicer assertion
assert MultiChannelImageConcatenation.read_hdf5(path=mock_path)._data == mock_read_hdf5.return_value
mock_read_hdf5.assert_called_once_with(path=mock_path)


def test_write_hdf5(mocker: MockerFixture, concat_image: MultiChannelImageConcatenation) -> None:
mock_write_hdf5 = mocker.patch.object(MultiChannelImage, "write_hdf5")
mock_path = mocker.Mock(name="path", spec=[])
concat_image.write_hdf5(path=mock_path)
mock_write_hdf5.assert_called_once_with(path=mock_path)


if __name__ == "__main__":
pytest.main()

0 comments on commit 2f11603

Please sign in to comment.