Skip to content

Commit

Permalink
improve MultiChannelIMage.retain_channels
Browse files Browse the repository at this point in the history
  • Loading branch information
leoschwarz committed Jul 18, 2024
1 parent ac36268 commit 7d4759a
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 18 deletions.
11 changes: 8 additions & 3 deletions src/depiction/image/multi_channel_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from xarray import DataArray

from depiction.image.sparse_representation import SparseRepresentation
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
from collections.abc import Sequence
Expand Down Expand Up @@ -92,9 +92,14 @@ def data_flat(self) -> DataArray:

# TODO get_single_channel_dense_array

def retain_channels(self, channel_indices: Sequence[int]) -> MultiChannelImage:
def retain_channels(
self, indices: Sequence[int] | None = None, coords: Sequence[Any] | None = None
) -> MultiChannelImage:
"""Returns a copy with only the specified channels retained."""
return MultiChannelImage(data=self._data.isel(c=channel_indices))
if (indices is not None) == (coords is not None):
raise ValueError("Exactly one of indices and coords must be specified.")
data = self._data.isel(c=indices) if indices is not None else self._data.sel(c=coords)
return MultiChannelImage(data=data)

# TODO save_single_channel_image... does it belong here or into plotter?

Expand Down
2 changes: 1 addition & 1 deletion src/depiction/visualize/visualize_mass_shift_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def plot_test_mass_maps_and_histograms(
self._plot_row(
ax_map=axs[i_mass, 0],
ax_hist=axs[i_mass, 1],
correction_image=correction_image.retain_channels(channel_indices=[i_mass]),
correction_image=correction_image.retain_channels(indices=[i_mass]),
hist_bins=hist_bins,
same_scale=same_scale,
scale_percentile=scale_percentile,
Expand Down
52 changes: 38 additions & 14 deletions tests/unit/image/test_multi_channel_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,14 @@

@pytest.fixture
def mock_coords() -> dict[str, list[str]]:
return {"c": ["Channel A"]}
return {"c": ["Channel A", "Channel B"]}


@pytest.fixture
def mock_data(mock_coords) -> DataArray:
"""Dense mock data without any missing values."""
return DataArray(
[[[2.0], [4]], [[6], [8]], [[10], [12]]],
[[[2.0, 5], [4, 5]], [[6, 5], [8, 5]], [[10, 5], [12, 5]]],
dims=("y", "x", "c"),
coords=mock_coords,
attrs={"bg_value": 0},
Expand All @@ -41,15 +41,15 @@ def test_from_numpy_sparse() -> None:


def test_n_channels(mock_image: MultiChannelImage) -> None:
assert mock_image.n_channels == 1
assert mock_image.n_channels == 2


def test_n_nonzero(mock_image: MultiChannelImage) -> None:
assert mock_image.n_nonzero == 6


def test_n_nonzero_when_sparse(mock_image: MultiChannelImage) -> None:
mock_image.data_spatial[1, 0, 0] = 0
mock_image.data_spatial[1, 0, :] = 0
assert mock_image.n_nonzero == 5


Expand Down Expand Up @@ -81,33 +81,57 @@ def test_dimensions(mock_image: MultiChannelImage) -> None:


def test_channel_names_when_set(mock_image: MultiChannelImage) -> None:
assert mock_image.channel_names == ["Channel A"]
assert mock_image.channel_names == ["Channel A", "Channel B"]


@pytest.mark.parametrize("mock_coords", [{}])
def test_channel_names_when_not_set(mock_image: MultiChannelImage) -> None:
assert mock_image.channel_names == ["0"]
assert mock_image.channel_names == ["0", "1"]


def test_data_spatial(mock_data: DataArray, mock_image: MultiChannelImage) -> None:
xarray.testing.assert_identical(mock_data, mock_image.data_spatial)


def test_data_flat(mock_data: DataArray, mock_image: MultiChannelImage) -> None:
mock_data[0, 0, 0] = 0
mock_data[0, 0, :] = 0
mock_data[1, 0, 0] = np.nan
expected = DataArray(
[[4.0, 8, 10, 12]],
[[4.0, 8, 10, 12], [5, 5, 5, 5]],
dims=("c", "i"),
coords={
"c": ["Channel A"],
"c": ["Channel A", "Channel B"],
"i": pd.MultiIndex.from_tuples([(0, 1), (1, 1), (2, 0), (2, 1)], names=("y", "x")),
},
attrs={"bg_value": 0},
)
xarray.testing.assert_identical(expected, mock_image.data_flat)


def test_retain_channels_when_both_none(mock_image: MultiChannelImage) -> None:
with pytest.raises(ValueError):
mock_image.retain_channels(None, None)


def test_retain_channels_by_indices(mock_image: MultiChannelImage) -> None:
indices = [1]
result = mock_image.retain_channels(indices=indices)
assert result.channel_names == ["Channel B"]
np.testing.assert_array_equal(result.data_spatial.values, mock_image.data_spatial.values[:, :, [1]])


def test_retain_channels_by_coords(mock_image: MultiChannelImage) -> None:
coords = ["Channel B"]
result = mock_image.retain_channels(coords=coords)
assert result.channel_names == coords
np.testing.assert_array_equal(result.data_spatial.values, mock_image.data_spatial.values[:, :, [1]])


def test_retain_channels_when_both_provided(mock_image: MultiChannelImage) -> None:
with pytest.raises(ValueError):
mock_image.retain_channels(indices=[0, 1], coords=["red", "blue"])


def test_write_hdf5(mocker: MockerFixture, mock_image: MultiChannelImage) -> None:
mocker.patch("xarray.DataArray.to_netcdf")
mock_image.write_hdf5(Path("test.h5"))
Expand All @@ -122,15 +146,15 @@ def test_read_hdf5(mocker: MockerFixture, mock_data: DataArray) -> None:


def test_with_channel_names(mock_image: MultiChannelImage) -> None:
image = mock_image.with_channel_names(channel_names=["New Channel Name"])
assert image.channel_names == ["New Channel Name"]
image = mock_image.with_channel_names(channel_names=["New Channel Name", "B"])
assert image.channel_names == ["New Channel Name", "B"]
assert image.dimensions == mock_image.dimensions
assert image.n_channels == mock_image.n_channels == 1
# TODO check values
assert image.n_channels == mock_image.n_channels == 2
np.testing.assert_array_equal(image.data_spatial.values, mock_image.data_spatial.values)


def test_str(mock_image: MultiChannelImage) -> None:
assert str(mock_image) == "MultiChannelImage(size_y=3, size_x=2, n_channels=1)"
assert str(mock_image) == "MultiChannelImage(size_y=3, size_x=2, n_channels=2)"


def test_repr(mocker: MockerFixture, mock_image: MultiChannelImage) -> None:
Expand Down

0 comments on commit 7d4759a

Please sign in to comment.