Skip to content

Commit

Permalink
rename to sel_channels (xarray naming)
Browse files Browse the repository at this point in the history
  • Loading branch information
leoschwarz committed Oct 25, 2024
1 parent 3f789bc commit d81a850
Show file tree
Hide file tree
Showing 8 changed files with 19 additions and 20 deletions.
2 changes: 1 addition & 1 deletion src/depiction/image/feature_selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def select_features(feature_selection: FeatureSelection, image: MultiChannelImag
def retain_features(feature_selection: FeatureSelection, image: MultiChannelImage) -> MultiChannelImage:
"""Returns a new ``MultiChannelImage`` that is a copy of ``image`` with only the selected features remaining."""
selected_features = select_features(feature_selection=feature_selection, image=image)
return image.retain_channels(coords=selected_features)
return image.sel_channels(coords=selected_features)


def _select_features_cv(image: MultiChannelImage, n_features: int) -> list[str]:
Expand Down
3 changes: 1 addition & 2 deletions src/depiction/image/multi_channel_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,8 +197,7 @@ def recompute_is_foreground(self, bg_value: float = 0.0) -> MultiChannelImage:
data=self._data, is_foreground=is_foreground, is_foreground_label=self._is_foreground_label
)

# TODO rename to sel_channels
def retain_channels(
def sel_channels(
self, indices: Sequence[int] | None = None, coords: Sequence[Any] | None = None
) -> MultiChannelImage:
"""Returns a copy with only the specified channels retained."""
Expand Down
6 changes: 3 additions & 3 deletions src/depiction/image/multi_channel_image_concatenation.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,13 @@ def __init__(self, data: MultiChannelImage) -> None:
@cached_property
def n_individual_images(self) -> int:
"""Number of individual images."""
return int(self._data.retain_channels(coords=["image_index"]).data_flat.max().values + 1)
return int(self._data.sel_channels(coords=["image_index"]).data_flat.max().values + 1)

def get_combined_image(self) -> MultiChannelImage:
return self._data.drop_channels(coords=["image_index"], allow_missing=False)

def get_combined_image_index(self) -> MultiChannelImage:
return self._data.retain_channels(coords=["image_index"])
return self._data.sel_channels(coords=["image_index"])

def get_single_image(self, index: int, min_coords: tuple[int, int] = (0, 0)) -> MultiChannelImage:
# perform the selection in flat representation for sanity
Expand Down Expand Up @@ -65,7 +65,7 @@ def relabel_combined_image(self, image: MultiChannelImage) -> MultiChannelImageC
original_combined = self.get_combined_image()
if image.dimensions != original_combined.dimensions:
raise ValueError("The new image must have the same shape as the original combined image")
labeled = image.append_channels(self._data.retain_channels(coords=["image_index"]))
labeled = image.append_channels(self._data.sel_channels(coords=["image_index"]))
return MultiChannelImageConcatenation(data=labeled)

@classmethod
Expand Down
2 changes: 1 addition & 1 deletion src/depiction/tools/clustering.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def _preprocess(input_image, n_best_features):
image=image_full_features, variant=ImageNormalizationVariant.STD
)
if "image_index" in input_image.channel_names:
image_full_image_index = input_image.retain_channels(coords=["image_index"])
image_full_image_index = input_image.sel_channels(coords=["image_index"])
else:
with xarray.set_options(keep_attrs=True):
image_full_image_index = MultiChannelImage(
Expand Down
2 changes: 1 addition & 1 deletion src/depiction/visualize/visualize_mass_shift_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def plot_test_mass_maps_and_histograms(
self._plot_row(
ax_map=axs[i_mass, 0],
ax_hist=axs[i_mass, 1],
correction_image=correction_image.retain_channels(indices=[i_mass]),
correction_image=correction_image.sel_channels(indices=[i_mass]),
hist_bins=hist_bins,
same_scale=same_scale,
scale_percentile=scale_percentile,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def render_single_channel_png(
channel_index: int = 0,
) -> None:
image = MultiChannelImage.read_hdf5(input_hdf5)
image = image.retain_channels(indices=[channel_index])
image = image.sel_channels(indices=[channel_index])

plt.figure()
image.data_spatial.squeeze().plot.imshow(yincrease=False, ax=plt.gca(), x="x", y="y", cmap="tab10")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def visualize_cluster_umap_coefs(
channel: str = "cluster",
) -> None:
# load the input data
umap_image = MultiChannelImage.read_hdf5(path=input_umap_hdf5_path).retain_channels(coords=["umap_x", "umap_y"])
umap_image = MultiChannelImage.read_hdf5(path=input_umap_hdf5_path).sel_channels(coords=["umap_x", "umap_y"])
cluster_image = MultiChannelImage.read_hdf5(path=input_cluster_hdf5_path)
combined_image = umap_image.append_channels(cluster_image)

Expand Down
20 changes: 10 additions & 10 deletions tests/unit/image/test_multi_channel_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,28 +184,28 @@ def test_recompute_is_foreground(mocker: MockerFixture, mock_image: MultiChannel
xarray.testing.assert_equal(new_image.data_spatial, mock_image.data_spatial)


def test_retain_channels_when_both_none(mock_image: MultiChannelImage) -> None:
def test_sel_channels_when_both_none(mock_image: MultiChannelImage) -> None:
with pytest.raises(ValueError):
mock_image.retain_channels(None, None)
mock_image.sel_channels(None, None)


def test_retain_channels_by_indices(mock_image: MultiChannelImage) -> None:
def test_sel_channels_by_indices(mock_image: MultiChannelImage) -> None:
indices = [1]
result = mock_image.retain_channels(indices=indices)
result = mock_image.sel_channels(indices=indices)
assert result.channel_names == ["Channel B"]
np.testing.assert_array_equal(result.data_spatial.values, mock_image.data_spatial.values[:, :, [1]])


def test_retain_channels_by_coords(mock_image: MultiChannelImage) -> None:
def test_sel_channels_by_coords(mock_image: MultiChannelImage) -> None:
coords = ["Channel B"]
result = mock_image.retain_channels(coords=coords)
result = mock_image.sel_channels(coords=coords)
assert result.channel_names == coords
np.testing.assert_array_equal(result.data_spatial.values, mock_image.data_spatial.values[:, :, [1]])


def test_retain_channels_when_both_provided(mock_image: MultiChannelImage) -> None:
def test_sel_channels_when_both_provided(mock_image: MultiChannelImage) -> None:
with pytest.raises(ValueError):
mock_image.retain_channels(indices=[0, 1], coords=["red", "blue"])
mock_image.sel_channels(indices=[0, 1], coords=["red", "blue"])


def test_drop_channels_when_coords_and_allow_missing(mock_image: MultiChannelImage) -> None:
Expand Down Expand Up @@ -263,8 +263,8 @@ def test_append_channels(mock_image: MultiChannelImage) -> None:
extra_image = MultiChannelImage(data=extra_image_data, is_foreground=mock_image.fg_mask)
result = mock_image.append_channels(extra_image)
assert result.channel_names == ["Channel A", "Channel B", "Channel X", "Channel Y"]
assert result.retain_channels(coords=["Channel A", "Channel B"]).data_spatial.identical(mock_image.data_spatial)
assert result.retain_channels(coords=["Channel X", "Channel Y"]).data_spatial.identical(extra_image.data_spatial)
assert result.sel_channels(coords=["Channel A", "Channel B"]).data_spatial.identical(mock_image.data_spatial)
assert result.sel_channels(coords=["Channel X", "Channel Y"]).data_spatial.identical(extra_image.data_spatial)


def test_get_z_scaled(mock_image: MultiChannelImage) -> None:
Expand Down

0 comments on commit d81a850

Please sign in to comment.