From 8d0ac05d4192e9f2cd808498b63040ed6b8ac3e2 Mon Sep 17 00:00:00 2001 From: Leonardo Schwarz Date: Tue, 15 Oct 2024 12:47:35 +0200 Subject: [PATCH] Update MultiChannelImage definition to explicitly manage background/foreground information. This is a far-reaching change with many side effects probably not yet identified. --- noxfile.py | 1 - .../calibration/calibration_method.py | 19 +- .../calibration/perform_calibration.py | 94 ++----- ...libration_method_chemical_peptide_noise.py | 3 +- .../spectrum/calibration_method_dummy.py | 6 +- ...alibration_method_global_constant_shift.py | 14 +- .../spectrum/calibration_method_mcc.py | 12 +- .../calibration_method_regress_shift.py | 12 +- .../spectrum/calibration_smoothing.py | 30 --- src/depiction/clustering/remap_clusters.py | 9 +- src/depiction/image/container/__init__.py | 0 .../image/container/alpha_channel.py | 29 +++ src/depiction/image/horizontal_concat.py | 11 +- src/depiction/image/image_channel_stats.py | 4 +- src/depiction/image/image_normalization.py | 20 +- src/depiction/image/multi_channel_image.py | 230 +++++++++++++++--- .../multi_channel_image_concatenation.py | 7 +- .../image/multi_channel_image_persistence.py | 48 ++++ src/depiction/image/smoothing/__init__.py | 0 .../image/smoothing/bilateral_filter.py | 50 ++++ .../image/smoothing/median_filter.py | 45 ++++ .../{ => smoothing}/spatial_smoothing.py | 0 .../spatial_smoothing_sparse_aware.py | 40 ++- src/depiction/image/sparse_representation.py | 43 ++-- src/depiction/persistence/types.py | 8 + src/depiction/tools/generate_ion_image.py | 10 +- .../tools/simulate/generate_label_image.py | 8 +- .../workflow/vis/images_norm.py | 24 +- .../workflow/vis/test_mass_shifts.py | 14 +- tests/unit/clustering/test_remap_clusters.py | 5 +- tests/unit/image/test_feature_selection.py | 15 +- tests/unit/image/test_horizontal_concat.py | 4 +- tests/unit/image/test_image_channel_stats.py | 22 +- tests/unit/image/test_image_normalization.py | 111 +++++---- tests/unit/image/test_multi_channel_image.py | 198 ++++++++++++--- .../test_multi_channel_image_concatenation.py | 41 ++-- .../unit/image/test_sparse_representation.py | 229 +++++++++-------- tests/unit/image/test_spatial_smoothing.py | 2 +- .../test_spatial_smoothing_sparse_aware.py | 104 +++++--- .../simulate/test_generate_label_image.py | 2 + tests/unit/tools/test_generate_ion_image.py | 17 +- 41 files changed, 989 insertions(+), 552 deletions(-) delete mode 100644 src/depiction/calibration/spectrum/calibration_smoothing.py create mode 100644 src/depiction/image/container/__init__.py create mode 100644 src/depiction/image/container/alpha_channel.py create mode 100644 src/depiction/image/multi_channel_image_persistence.py create mode 100644 src/depiction/image/smoothing/__init__.py create mode 100644 src/depiction/image/smoothing/bilateral_filter.py create mode 100644 src/depiction/image/smoothing/median_filter.py rename src/depiction/image/{ => smoothing}/spatial_smoothing.py (100%) rename src/depiction/image/{ => smoothing}/spatial_smoothing_sparse_aware.py (64%) diff --git a/noxfile.py b/noxfile.py index 263c607..2e5a477 100644 --- a/noxfile.py +++ b/noxfile.py @@ -3,7 +3,6 @@ nox.options.default_venv_backend = "uv" -@nox.session(reuse_venv=True) def lint(session: nox.Session) -> None: """Runs the linter.""" session.install("pre-commit") diff --git a/src/depiction/calibration/calibration_method.py b/src/depiction/calibration/calibration_method.py index 5d3de50..8693f85 100644 --- a/src/depiction/calibration/calibration_method.py +++ b/src/depiction/calibration/calibration_method.py @@ -1,8 +1,12 @@ -from typing import Protocol +from __future__ import annotations +from typing import Protocol, TYPE_CHECKING from numpy.typing import NDArray from xarray import DataArray +if TYPE_CHECKING: + from depiction.image.multi_channel_image import MultiChannelImage + class CalibrationMethod(Protocol): """Defines the interface for a spectrum calibration method.""" @@ -16,14 +20,15 @@ def extract_spectrum_features(self, peak_mz_arr: NDArray[float], peak_int_arr: N """ return DataArray([], dims=["c"]) - def preprocess_image_features(self, all_features: DataArray) -> DataArray: + # TODO update doc + # TODO update other methods + # TODO check if it works + def preprocess_image_features(self, all_features: MultiChannelImage) -> MultiChannelImage: """Preprocesses the extracted features from all spectra in an image. For example, image-wide smoothing of the features could be applied here. - If no preprocessing is necessary, the input DataArray should be returned. - :param all_features: a DataArray with the extracted features, with dimensions ["i", "c"] - and coordinates ["i", "x", "y"] for dimension "i" - :return: a DataArray with the preprocessed features, with dimensions ["i", "c"] - and coordinates ["i", "x", "y"] for dimension "i" + If no preprocessing is necessary, the input MultiChannelImage should be returned. + :param all_features: a MultiChannelImage with the extracted features + :return: a MultiChannelImage with the preprocessed features """ return all_features diff --git a/src/depiction/calibration/perform_calibration.py b/src/depiction/calibration/perform_calibration.py index 8b5bf47..e67e1cd 100644 --- a/src/depiction/calibration/perform_calibration.py +++ b/src/depiction/calibration/perform_calibration.py @@ -10,6 +10,7 @@ from xarray import DataArray from depiction.calibration.calibration_method import CalibrationMethod +from depiction.image import MultiChannelImage from depiction.parallel_ops import ParallelConfig, ReadSpectraParallel, WriteSpectraParallel from depiction.parallel_ops.parallel_map import ParallelMap from depiction.persistence.types import GenericReadFile, GenericWriteFile, GenericReader, GenericWriter @@ -26,77 +27,27 @@ def __init__( self._parallel_config = parallel_config self._coefficient_output_file = coefficient_output_file - # def _reshape(self, pattern: str, data: DataArray, coordinates) -> DataArray: - # if pattern == "i,c->y,x,c": - # data = data.copy() - # # TODO fix the deprecation here! - # data["i"] = pd.MultiIndex.from_arrays((coordinates[:, 1], coordinates[:, 0]), names=("y", "x")) - # data = data.unstack("i") - # return data.transpose("y", "x", "c") - # elif pattern == "y,x,c->i,c": - # data = data.transpose("y", "x", "c").copy() - # data = data.stack(i=("y", "x")).drop_vars(["i", "x", "y"]) - # # convert to integers - # data["i"] = np.arange(len(data["i"])) - # return data.transpose("i", "c") - # else: - # raise ValueError(f"Unknown pattern={repr(pattern)}") - - def _validate_per_spectra_array(self, array: DataArray, coordinates_2d: NDArray[float]) -> None: - """Checks the DataArray has the correct shapes and dimensions. Used for debugging.""" - # TODO make it configurable in the future, whether this check is executed, during development it definitely - # should be here since it can safe a ton of time - expected_coords = {"i", "x", "y"} - if set(array.coords) != expected_coords: - raise ValueError(f"Expected coords={expected_coords}, got={set(array.coords)}") - expected_dims = {"i", "c"} - - errors = [] - if set(array.dims) != expected_dims: - logger.error(f"Expected dims={expected_dims}, got={set(array.dims)}") - errors.append("Mismatch in dimensions") - if not np.array_equal(array.x.values, coordinates_2d[:, 0]): - logger.error(f"Expected x: values={coordinates_2d[:, 0]} shape={coordinates_2d[:, 0].shape}") - logger.error(f"Actual x: values={array.x.values} shape={array.x.values.shape}") - logger.info(f"(Expected x values without offset: {coordinates_2d[:, 0] - coordinates_2d[:, 0].min()})") - errors.append("Mismatch in x values") - if not np.array_equal(array.y.values, coordinates_2d[:, 1]): - logger.error(f"Expected y: values={coordinates_2d[:, 1]} shape={coordinates_2d[:, 1].shape}") - logger.error(f"Actual y: values={array.y.values} shape={array.y.values.shape}") - logger.info(f"(Expected y values without offset: {coordinates_2d[:, 1] - coordinates_2d[:, 1].min()})") - errors.append("Mismatch in y values") - if not np.array_equal(array.i.values, np.arange(len(array.i))): - errors.append("Mismatch in i values") - logger.error(f"Expected i: values={np.arange(len(array.i))} shape={np.arange(len(array.i)).shape}") - logger.error(f"Actual i: values={array.i.values} shape={array.i.values.shape}") - if errors: - raise ValueError(errors) - def calibrate_image( self, read_peaks: GenericReadFile, write_file: GenericWriteFile, read_full: Optional[GenericReadFile] = None ) -> None: - if read_full is None: - read_full = read_peaks + read_full = read_full or read_peaks logger.info("Extracting all features...") - all_features = self._extract_all_features(read_peaks).transpose("i", "c") - self._validate_per_spectra_array(all_features, coordinates_2d=read_peaks.coordinates_2d) + all_features = self._extract_all_features(read_peaks) self._write_data_array(all_features, group="features_raw") logger.info("Preprocessing features...") - all_features = self._calibration.preprocess_image_features(all_features=all_features).transpose("i", "c") - self._validate_per_spectra_array(all_features, coordinates_2d=read_peaks.coordinates_2d) + all_features = self._calibration.preprocess_image_features(all_features=all_features) self._write_data_array(all_features, group="features_processed") logger.info("Fitting models...") - model_coefs = self._fit_all_models(all_features=all_features).transpose("i", "c") - self._validate_per_spectra_array(model_coefs, coordinates_2d=read_peaks.coordinates_2d) + model_coefs = self._fit_all_models(all_features=all_features) self._write_data_array(model_coefs, group="model_coefs") logger.info("Applying models...") self._apply_all_models(read_file=read_full, write_file=write_file, all_model_coefs=model_coefs) - def _extract_all_features(self, read_peaks: GenericReadFile) -> DataArray: + def _extract_all_features(self, read_peaks: GenericReadFile) -> MultiChannelImage: read_parallel = ReadSpectraParallel.from_config(self._parallel_config) all_features = read_parallel.map_chunked( read_file=read_peaks, @@ -106,12 +57,14 @@ def _extract_all_features(self, read_peaks: GenericReadFile) -> DataArray: ), reduce_fn=lambda chunks: xarray.concat(chunks, dim="i"), ) - return all_features.assign_coords( - x=("i", read_peaks.coordinates_2d[:, 0]), y=("i", read_peaks.coordinates_2d[:, 1]) + return MultiChannelImage.from_flat( + values=all_features, + coordinates=read_peaks.coordinates_array_2d, + channel_names="c" not in all_features.coords, ) def _apply_all_models( - self, read_file: GenericReadFile, write_file: GenericWriteFile, all_model_coefs: DataArray + self, read_file: GenericReadFile, write_file: GenericWriteFile, all_model_coefs: MultiChannelImage ) -> None: write_parallel = WriteSpectraParallel.from_config(self._parallel_config) write_parallel.map_chunked_to_file( @@ -124,15 +77,19 @@ def _apply_all_models( ), ) - def _fit_all_models(self, all_features: DataArray) -> DataArray: + def _fit_all_models(self, all_features: MultiChannelImage) -> MultiChannelImage: parallel_map = ParallelMap.from_config(self._parallel_config) + # TODO to be refactored + all_features_flat = all_features.data_flat result = parallel_map( operation=self._fit_chunk_models, - tasks=np.array_split(all_features.coords["i"], self._parallel_config.n_jobs), + tasks=np.array_split(all_features_flat.coords["i"], self._parallel_config.n_jobs), reduce_fn=lambda chunks: xarray.concat(chunks, dim="i"), - bind_kwargs={"all_features": all_features}, + bind_kwargs={"all_features": all_features_flat}, + ) + return MultiChannelImage.from_flat( + result, coordinates=all_features.coordinates_flat, channel_names="c" not in result.coords ) - return result def _fit_chunk_models(self, spectra_indices: NDArray[int], all_features: DataArray) -> DataArray: collect = [] @@ -140,15 +97,12 @@ def _fit_chunk_models(self, spectra_indices: NDArray[int], all_features: DataArr features = all_features.sel(i=spectrum_id) model_coef = self._calibration.fit_spectrum_model(features=features) collect.append(model_coef) - combined = xarray.concat(collect, dim="i") - combined.coords["i"] = spectra_indices - return combined + return xarray.concat(collect, dim="i") - def _write_data_array(self, array: DataArray, group: str) -> None: + def _write_data_array(self, image: MultiChannelImage, group: str) -> None: if not self._coefficient_output_file: return - # TODO engine should not be necessary, but using it for debugging - array.to_netcdf(path=self._coefficient_output_file, group=group, format="NETCDF4", engine="netcdf4", mode="a") + image.write_hdf5(path=self._coefficient_output_file, mode="a", group=group) @staticmethod def _extract_chunk_features( @@ -171,12 +125,12 @@ def _calibrate_spectra( spectra_indices: list[int], writer: GenericWriter, calibration: CalibrationMethod, - all_model_coefs: DataArray, + all_model_coefs: MultiChannelImage, ) -> None: for spectrum_id in spectra_indices: # TODO sanity check the usage of i as spectrum_id (i.e. check the coords!) mz_arr, int_arr, coords = reader.get_spectrum_with_coords(spectrum_id) - features = all_model_coefs.sel(i=spectrum_id) + features = all_model_coefs.data_flat.isel(i=spectrum_id) calib_mz_arr, calib_int_arr = calibration.apply_spectrum_model( spectrum_mz_arr=mz_arr, spectrum_int_arr=int_arr, model_coef=features ) diff --git a/src/depiction/calibration/spectrum/calibration_method_chemical_peptide_noise.py b/src/depiction/calibration/spectrum/calibration_method_chemical_peptide_noise.py index 0d010d0..3be9e03 100644 --- a/src/depiction/calibration/spectrum/calibration_method_chemical_peptide_noise.py +++ b/src/depiction/calibration/spectrum/calibration_method_chemical_peptide_noise.py @@ -6,6 +6,7 @@ from depiction.calibration.calibration_method import CalibrationMethod from depiction.calibration.chemical_noise_bg_2019_boskamp_v2 import ChemicalNoiseCalibration +from depiction.image import MultiChannelImage class CalibrationMethodChemicalPeptideNoise(CalibrationMethod): @@ -30,7 +31,7 @@ def extract_spectrum_features(self, peak_mz_arr: NDArray[float], peak_int_arr: N # return DataArray(shifts_arr, dims=["c"]) return DataArray([], dims=["c"]) - def preprocess_image_features(self, all_features: DataArray) -> DataArray: + def preprocess_image_features(self, all_features: MultiChannelImage) -> MultiChannelImage: # TODO no smoothing applied for now, but could be added (just, avoid duplication with the RegressShift) return all_features diff --git a/src/depiction/calibration/spectrum/calibration_method_dummy.py b/src/depiction/calibration/spectrum/calibration_method_dummy.py index c0f5181..df692c6 100644 --- a/src/depiction/calibration/spectrum/calibration_method_dummy.py +++ b/src/depiction/calibration/spectrum/calibration_method_dummy.py @@ -1,7 +1,9 @@ -from depiction.calibration.calibration_method import CalibrationMethod from numpy.typing import NDArray from xarray import DataArray +from depiction.calibration.calibration_method import CalibrationMethod +from depiction.image import MultiChannelImage + class CalibrationMethodDummy(CalibrationMethod): """Returns the input data and creates some dummy coefficients to ensure compatibility.""" @@ -9,7 +11,7 @@ class CalibrationMethodDummy(CalibrationMethod): def extract_spectrum_features(self, peak_mz_arr: NDArray[float], peak_int_arr: NDArray[float]) -> DataArray: return DataArray([0], dims=["c"]) - def preprocess_image_features(self, all_features: DataArray) -> DataArray: + def preprocess_image_features(self, all_features: MultiChannelImage) -> MultiChannelImage: return all_features def fit_spectrum_model(self, features: DataArray) -> DataArray: diff --git a/src/depiction/calibration/spectrum/calibration_method_global_constant_shift.py b/src/depiction/calibration/spectrum/calibration_method_global_constant_shift.py index e71aa75..2b581e7 100644 --- a/src/depiction/calibration/spectrum/calibration_method_global_constant_shift.py +++ b/src/depiction/calibration/spectrum/calibration_method_global_constant_shift.py @@ -4,6 +4,7 @@ from depiction.calibration.calibration_method import CalibrationMethod from depiction.calibration.spectrum.reference_peak_distances import ReferencePeakDistances +from depiction.image import MultiChannelImage class CalibrationMethodGlobalConstantShift(CalibrationMethod): @@ -27,12 +28,17 @@ def extract_spectrum_features(self, peak_mz_arr: NDArray[float], peak_int_arr: N ) return DataArray(distances_mz, dims=["c"]) - def preprocess_image_features(self, all_features: DataArray) -> DataArray: + def preprocess_image_features(self, all_features: MultiChannelImage) -> MultiChannelImage: # we compute the actual global distance here - global_distance = np.nanmedian(all_features.values.ravel()) + global_distance = np.nanmedian(all_features.data_flat.ravel()) # create one copy per spectrum - n_spectra = all_features.sizes["i"] - return DataArray(np.full((n_spectra, 1), global_distance), dims=["i", "c"], coords=all_features.coords) + n_spectra = all_features.n_nonzero + return MultiChannelImage( + data=DataArray( + np.full((n_spectra, 1, 1), global_distance), dims=["y", "x", "c"], coords=all_features.coords + ), + is_foreground=DataArray(np.ones((n_spectra, 1), dtype=bool), dims=["y", "x"], coords=all_features.coords), + ) def fit_spectrum_model(self, features: DataArray) -> DataArray: return features diff --git a/src/depiction/calibration/spectrum/calibration_method_mcc.py b/src/depiction/calibration/spectrum/calibration_method_mcc.py index 918e952..3601f0f 100644 --- a/src/depiction/calibration/spectrum/calibration_method_mcc.py +++ b/src/depiction/calibration/spectrum/calibration_method_mcc.py @@ -7,7 +7,8 @@ from xarray import DataArray from depiction.calibration.calibration_method import CalibrationMethod -from depiction.calibration.spectrum.calibration_smoothing import smooth_image_features +from depiction.image import MultiChannelImage +from depiction.image.smoothing.spatial_smoothing_sparse_aware import SpatialSmoothingSparseAware class CalibrationMethodMassClusterCenterModel(CalibrationMethod): @@ -62,13 +63,12 @@ def compute_distance_from_MCC(self, delta: NDArray[float], l_none: float = 1.000 delta_lambda[i] = -1 + term1 return delta_lambda - def preprocess_image_features(self, all_features: DataArray) -> DataArray: + def preprocess_image_features(self, all_features: MultiChannelImage) -> MultiChannelImage: if self._model_smoothing_activated: - return smooth_image_features( - all_features=all_features, - kernel_size=self._model_smoothing_kernel_size, - kernel_std=self._model_smoothing_kernel_std, + smoother = SpatialSmoothingSparseAware( + kernel_size=self._model_smoothing_kernel_size, kernel_std=self._model_smoothing_kernel_std ) + return smoother.smooth_image(all_features) else: return all_features diff --git a/src/depiction/calibration/spectrum/calibration_method_regress_shift.py b/src/depiction/calibration/spectrum/calibration_method_regress_shift.py index 53aa99b..a809370 100644 --- a/src/depiction/calibration/spectrum/calibration_method_regress_shift.py +++ b/src/depiction/calibration/spectrum/calibration_method_regress_shift.py @@ -5,8 +5,9 @@ from depiction.calibration.calibration_method import CalibrationMethod from depiction.calibration.models import LinearModel from depiction.calibration.models.fit_model import fit_model -from depiction.calibration.spectrum.calibration_smoothing import smooth_image_features from depiction.calibration.spectrum.reference_peak_distances import ReferencePeakDistances +from depiction.image import MultiChannelImage +from depiction.image.smoothing.spatial_smoothing_sparse_aware import SpatialSmoothingSparseAware class CalibrationMethodRegressShift(CalibrationMethod): @@ -80,13 +81,12 @@ def extract_spectrum_features(self, peak_mz_arr: NDArray[float], peak_int_arr: N else: raise ValueError(f"Unknown unit={self._model_unit}") - def preprocess_image_features(self, all_features: DataArray) -> DataArray: + def preprocess_image_features(self, all_features: MultiChannelImage) -> MultiChannelImage: if self._input_smoothing_activated: - return smooth_image_features( - all_features=all_features, - kernel_size=self._input_smoothing_kernel_size, - kernel_std=self._input_smoothing_kernel_std, + smoother = SpatialSmoothingSparseAware( + kernel_size=self._input_smoothing_kernel_size, kernel_std=self._input_smoothing_kernel_std ) + return smoother.smooth_image(all_features) else: return all_features diff --git a/src/depiction/calibration/spectrum/calibration_smoothing.py b/src/depiction/calibration/spectrum/calibration_smoothing.py deleted file mode 100644 index 0d146c2..0000000 --- a/src/depiction/calibration/spectrum/calibration_smoothing.py +++ /dev/null @@ -1,30 +0,0 @@ -from __future__ import annotations - -import numpy as np - -from depiction.image.spatial_smoothing_sparse_aware import SpatialSmoothingSparseAware -from depiction.image.xarray_helper import XarrayHelper -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - from xarray import DataArray - - -# TODO should be refactored later - -# TODO test this case : a spectrum is all nan before, but present in the flat repr, -# it should not disappear like an actual background - - -def smooth_image_features(all_features: DataArray, kernel_size: int, kernel_std: float) -> DataArray: - """Smoothes the image features using a 2D Gaussian kernel, assuming the data is in a collapsed representation.""" - - def fn(array_2d: DataArray) -> DataArray: - smoother = SpatialSmoothingSparseAware( - kernel_size=kernel_size, - kernel_std=kernel_std, - ) - return smoother.smooth(array_2d, bg_value=np.nan) - - return XarrayHelper.apply_on_spatial_view(all_features, fn) - # return _apply_on_spatial_view(all_features, fn) diff --git a/src/depiction/clustering/remap_clusters.py b/src/depiction/clustering/remap_clusters.py index 8a1561d..31c427b 100644 --- a/src/depiction/clustering/remap_clusters.py +++ b/src/depiction/clustering/remap_clusters.py @@ -57,11 +57,10 @@ def remap_cluster_labels( :param mapping: a dictionary mapping the old cluster labels to the new ones :param cluster_channel: the name of the channel with the cluster labels """ - with xarray.set_options(keep_attrs=True): - relabeled = xarray.apply_ufunc( - lambda v: mapping.get(v, np.nan), image.data_spatial.sel(c=[cluster_channel]), vectorize=True - ) + relabeled = xarray.apply_ufunc( + lambda v: mapping.get(v, np.nan), image.data_spatial.sel(c=[cluster_channel]), vectorize=True + ) img_relabeled = image.drop_channels(coords=[cluster_channel], allow_missing=False).append_channels( - MultiChannelImage(relabeled) + MultiChannelImage(relabeled, is_foreground=image.fg_mask, is_foreground_label=image.is_foreground_label) ) return img_relabeled diff --git a/src/depiction/image/container/__init__.py b/src/depiction/image/container/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/depiction/image/container/alpha_channel.py b/src/depiction/image/container/alpha_channel.py new file mode 100644 index 0000000..62cc5ad --- /dev/null +++ b/src/depiction/image/container/alpha_channel.py @@ -0,0 +1,29 @@ +from __future__ import annotations + +import xarray + + +class AlphaChannel: + """Implements logic to stack an alpha channel on top of an arbitrary channel image and split it off again. + + The alpha channel is expected to be a boolean mask, where True indicates foreground and False background. + When stacked together it will take the numeric value of 1 for True and 0 for False. + + Additionally, while the image is supposed to have shape (y, x, c), the alpha channel will be stacked as a single + channel with shape (y, x). + """ + + def __init__(self, label: str) -> None: + self._alpha_label = label + + def stack(self, data_array: xarray.DataArray, is_fg_array: xarray.DataArray) -> xarray.DataArray: + """Stacks the alpha channel on top of the data array.""" + return xarray.concat( + [data_array, is_fg_array.expand_dims("c", axis=-1).assign_coords(c=[self._alpha_label])], dim="c" + ) + + def split(self, combined: xarray.DataArray) -> tuple[xarray.DataArray, xarray.DataArray]: + """Splits the alpha channel off the combined array and returns the data array and the alpha array.""" + data_array = combined.drop_sel(c=self._alpha_label) + is_fg_array = combined.sel(c=self._alpha_label).drop_vars("c").astype(bool) + return data_array, is_fg_array diff --git a/src/depiction/image/horizontal_concat.py b/src/depiction/image/horizontal_concat.py index b386a7d..2a99a5b 100644 --- a/src/depiction/image/horizontal_concat.py +++ b/src/depiction/image/horizontal_concat.py @@ -1,4 +1,6 @@ import xarray + +from depiction.image.container.alpha_channel import AlphaChannel from depiction.image.multi_channel_image import MultiChannelImage @@ -19,10 +21,10 @@ def horizontal_concat( # shift x coordinates iteratively xoffset = 0 concat = [] - bg_value = images[0].bg_value + alpha_channel = AlphaChannel(label=images[0].is_foreground_label) for i_image, image in enumerate(images): - data = image.data_spatial - data = data.pad(y=(0, ymax - data.y.values.max()), constant_values=bg_value) + data = alpha_channel.stack(image.data_spatial, image.fg_mask) + data = data.pad(y=(0, ymax - data.y.values.max()), constant_values=0) x_extent = data.x.values.max() - data.x.values.min() + 1 data_shifted = data.assign_coords(x=data.x - data.x.values.min() + xoffset) if add_index: @@ -31,4 +33,5 @@ def horizontal_concat( concat.append(data_shifted) xoffset += x_extent data = xarray.concat(concat, dim="x") - return MultiChannelImage(data) + array, is_foreground = alpha_channel.split(data) + return MultiChannelImage(array, is_foreground) diff --git a/src/depiction/image/image_channel_stats.py b/src/depiction/image/image_channel_stats.py index 67eb7e2..8639642 100644 --- a/src/depiction/image/image_channel_stats.py +++ b/src/depiction/image/image_channel_stats.py @@ -69,7 +69,5 @@ def _get_channel_values(self, i_channel: int, drop_missing: bool) -> np.ndarray: # TODO maybe caching data_flat would already make this faster, could be tested easily by temporarily adding the cache in the MultiChannelImage class data_channel = self._image.data_flat.isel(c=i_channel).values if drop_missing: - bg = self._image.bg_value - bg_mask = np.isnan(data_channel) if np.isnan(bg) else data_channel == bg - data_channel = data_channel[~bg_mask] + data_channel = data_channel[self._image.fg_mask_flat] return data_channel diff --git a/src/depiction/image/image_normalization.py b/src/depiction/image/image_normalization.py index 6d3d90f..48c4da4 100644 --- a/src/depiction/image/image_normalization.py +++ b/src/depiction/image/image_normalization.py @@ -15,8 +15,18 @@ class ImageNormalizationVariant(enum.Enum): # TODO maybe rename to ImageFeatureNormalization + + class ImageNormalization: - def normalize_xarray(self, image: xarray.DataArray, variant: ImageNormalizationVariant) -> xarray.DataArray: + def normalize_image( + self, image: MultiChannelImage, variant: ImageNormalizationVariant, dim: str = "c" + ) -> MultiChannelImage: + normalized = self._normalize_single_xarray(image.data_spatial, variant=variant) + return MultiChannelImage( + data=normalized, is_foreground=image.fg_mask, is_foreground_label=image.is_foreground_label + ) + + def _normalize_xarray(self, image: xarray.DataArray, variant: ImageNormalizationVariant) -> xarray.DataArray: # First, understand the dimensions of the image. known_dims = ["y", "x", "c"] missing_dims = set(known_dims) - set(image.dims) @@ -30,17 +40,15 @@ def normalize_xarray(self, image: xarray.DataArray, variant: ImageNormalizationV else: raise NotImplementedError("Multiple index columns are not supported yet.") - def normalize_image(self, image: MultiChannelImage, variant: ImageNormalizationVariant) -> MultiChannelImage: - return MultiChannelImage(self.normalize_xarray(image.data_spatial, variant=variant)) - def _normalize_single_xarray(self, image: xarray.DataArray, variant: ImageNormalizationVariant) -> xarray.DataArray: + bg_value = 0 with xarray.set_options(keep_attrs=True): if variant == ImageNormalizationVariant.VEC_NORM: norm = ((image**2).sum(["c"])) ** 0.5 - return xarray.where(norm != 0, image / norm, image.attrs.get("bg_value", 0)) + return xarray.where(norm != 0, image / norm, bg_value) elif variant == ImageNormalizationVariant.STD: std = image.std("c") - return xarray.where(std != 0, (image - image.mean("c")) / std, image.attrs.get("bg_value", 0)) + return xarray.where(std != 0, (image - image.mean("c")) / std, bg_value) else: raise NotImplementedError(f"Unknown variant: {variant}") diff --git a/src/depiction/image/multi_channel_image.py b/src/depiction/image/multi_channel_image.py index 349facc..87a41f3 100644 --- a/src/depiction/image/multi_channel_image.py +++ b/src/depiction/image/multi_channel_image.py @@ -1,16 +1,17 @@ from __future__ import annotations +import warnings from functools import cached_property -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Literal import numpy as np import xarray - -from depiction.image.image_channel_stats import ImageChannelStats -from depiction.image.sparse_representation import SparseRepresentation from numpy.typing import NDArray from xarray import DataArray +from depiction.image.image_channel_stats import ImageChannelStats +from depiction.image.multi_channel_image_persistence import MultiChannelImagePersistence +from depiction.image.sparse_representation import SparseRepresentation from depiction.persistence.format_ome_tiff import OmeTiff if TYPE_CHECKING: @@ -18,13 +19,100 @@ from pathlib import Path +# TODO would it be clever or stupid to call is_foreground "alpha" channel? + + class MultiChannelImage: - """Represents a multi-channel 2D image, internally backed by a `xarray.DataArray`.""" + """Represents a multi-channel 2D image, internally backed by a `xarray.DataArray`. + + The API is generally designed to be immutable, i.e. methods modifying the image return a new instance. + The image is internally represented in a dense layout, with the background/foreground being explicitly stored in + a `is_foreground` channel that is not part of the `n_channels` count but will be exported. + This is to make the conversion to and from sparse representation sane. + """ + + def __init__(self, data: DataArray, is_foreground: DataArray, is_foreground_label: str = "is_foreground") -> None: + if "bg_value" in data.attrs: + # TODO remove this warning at a later time + warnings.warn("bg_value is deprecated, use is_foreground instead", DeprecationWarning) + + # Assign the data + self._data = data.transpose("y", "x", "c").drop_attrs() + self._is_foreground = is_foreground.transpose("y", "x").drop_vars("c", errors="ignore").drop_attrs() + self._is_foreground_label = is_foreground_label + + # Validate the input + self._assert_data_and_foreground_dimensions() + self._assert_data_and_foreground_coords() + self._assert_foreground_is_boolean() + self._assert_data_channel_names_present() + + def _assert_data_channel_names_present(self) -> None: + """Asserts that the data has channel names and that they are strings.""" + if "c" not in self._data.coords: + raise ValueError("Data must have a 'c' coordinate for channel names.") + if self._data.sizes["c"] > 0 and not isinstance(self._data.c[0].item(), str): + raise ValueError(f"Channel names must be strings, but type is: {type(self._data.c[0].item())}.") + + def _assert_data_and_foreground_coords(self) -> None: + if np.not_equal(self._data.coords["y"], self._is_foreground.coords["y"]).any(): + raise ValueError("Inconsistent y coordinate values between data and is_foreground.") + if np.not_equal(self._data.coords["x"], self._is_foreground.coords["x"]).any(): + raise ValueError("Inconsistent x coordinate values between data and is_foreground.") + + def _assert_data_and_foreground_dimensions(self) -> None: + if ( + self._data.sizes["x"] != self._is_foreground.sizes["x"] + or self._data.sizes["y"] != self._is_foreground.sizes["y"] + ): + msg = ( + "'data' and 'is_foreground' must have the same dimensions, but " + f"data[y,x] = {self._data.sizes['y'], self._data.sizes['x']}, " + f"is_foreground[y,x] = {self._is_foreground.sizes['y'], self._is_foreground.sizes['x']}." + ) + raise ValueError(msg) + + def _assert_foreground_is_boolean(self) -> None: + if self._is_foreground.dtype != np.bool_: + raise ValueError(f"is_foreground must be a boolean array, but has dtype {self._is_foreground.dtype}.") - def __init__(self, data: DataArray) -> None: - self._data = data.transpose("y", "x", "c") - if "bg_value" not in self._data.attrs: - raise ValueError("The bg_value attribute must be set.") + @classmethod + def from_spatial( + cls, data: DataArray, bg_value: float = 0, is_foreground_label: str = "is_foreground" + ) -> MultiChannelImage: + # TODO improve this method + is_fg = cls._compute_is_foreground(data=data, bg_value=bg_value) + return cls(data=data, is_foreground=is_fg, is_foreground_label=is_foreground_label) + + @classmethod + def from_flat( + cls, + values: DataArray, + coordinates: DataArray | None, + channel_names: list[str] | bool = False, + bg_value: float = 0.0, + ): + coordinates = cls._extract_flat_coordinates(values) if coordinates is None else coordinates + if channel_names: + if "c" in values.coords: + msg = ( + "Either provide channel names as coordinate in values or as argument, but not both. " + "Use .drop_vars('c') to remove the `c` coordinate, " + "or use `.assign_coords(c=channel_names)` to add channel names directly." + ) + raise ValueError(msg) + elif isinstance(channel_names, bool): + channel_names = [str(i) for i in range(values.sizes["c"])] + else: + channel_names = [str(name) for name in channel_names] + values = values.assign_coords(c=channel_names) + + data, is_foreground = SparseRepresentation.flat_to_spatial( + sparse_values=values.transpose("i", "c"), + coordinates=cls._validate_coordinates(coordinates), + bg_value=bg_value, + ) + return cls(data=data, is_foreground=is_foreground) @classmethod def from_sparse( @@ -41,16 +129,17 @@ def from_sparse( :param channel_names: The names of the channels. :param bg_value: The background value. """ - data = SparseRepresentation.sparse_to_dense( + # TODO delete method + warnings.warn("from_sparse is deprecated, use from_flat instead", DeprecationWarning) + data, is_foreground = SparseRepresentation.flat_to_spatial( sparse_values=cls._validate_sparse_values(values), coordinates=cls._validate_coordinates(coordinates), bg_value=bg_value, ) - data.attrs["bg_value"] = bg_value channel_names = list(channel_names) if channel_names is not None else None if channel_names: data.coords["c"] = channel_names - return cls(data=data) + return cls(data=data, is_foreground=is_foreground) @property def n_channels(self) -> int: @@ -60,8 +149,7 @@ def n_channels(self) -> int: @property def n_nonzero(self) -> int: """Number of non-zero values.""" - # TODO efficient impl - return (~self.bg_mask).sum().item() + return self._is_foreground.sum().item() @property def dtype(self) -> np.dtype: @@ -69,14 +157,29 @@ def dtype(self) -> np.dtype: return self._data.dtype @property - def bg_value(self) -> int | float: - """The background value.""" - return self._data.attrs["bg_value"] + def is_foreground_label(self) -> str: + """The label for the is_foreground channel when persisting.""" + return self._is_foreground_label - @cached_property + @property + def fg_mask(self) -> DataArray: + """A boolean mask indicating the foreground values as `True` and non-foreground values as `False`.""" + return self._is_foreground + + @property def bg_mask(self) -> DataArray: """A boolean mask indicating the background values as `True` and non-background values as `False`.""" - return ((self._data == self.bg_value) | (self._data.isnull() & np.isnan(self.bg_value))).all(dim="c") + return ~self._is_foreground + + @property + def fg_mask_flat(self) -> DataArray: + """A boolean mask indicating the foreground values as `True` and non-foreground values as `False`.""" + return self._is_foreground.stack(i=("y", "x")).dropna(dim="i") + + @property + def bg_mask_flat(self) -> DataArray: + """A boolean mask indicating the background values as `True` and non-background values as `False`.""" + return ~self.fg_mask_flat @property def dimensions(self) -> tuple[int, int]: @@ -87,7 +190,6 @@ def dimensions(self) -> tuple[int, int]: @property def channel_names(self) -> list[str]: """Returns the names of the channels.""" - # TODO consider renaming to `channels` return [str(c) for c in self._data.coords["c"].values.tolist()] @property @@ -98,20 +200,32 @@ def data_spatial(self) -> DataArray: @property def data_flat(self) -> DataArray: """Returns the underlying data, in its flat form, i.e. dimensions (i, c), omitting any background values.""" - return self._data.where(~self.bg_mask).stack(i=("y", "x")).dropna(dim="i") + return self._data.stack(i=("y", "x")).isel(i=self.fg_mask_flat) @property def coordinates_flat(self) -> DataArray: """Returns the coordinates of the non-background values.""" orig_coords = self.data_flat.coords + # TODO make consistent + # return DataArray( + # np.stack((orig_coords["x"].values, orig_coords["y"].values), axis=0), + # dims=("d", "i"), + # coords={"d": ["x", "y"], "i": orig_coords["i"]}, + # ) return DataArray( np.stack((orig_coords["y"].values, orig_coords["x"].values), axis=0), dims=("d", "i"), coords={"d": ["y", "x"], "i": orig_coords["i"]}, ) - # TODO from_dense_array + def recompute_is_foreground(self, bg_value: float = 0.0) -> MultiChannelImage: + """Returns a copy of self with a recomputed is_foreground mask, based on the provided bg value.""" + is_foreground = self._compute_is_foreground(data=self._data, bg_value=bg_value) + return MultiChannelImage( + data=self._data, is_foreground=is_foreground, is_foreground_label=self._is_foreground_label + ) + # TODO rename to sel_channels def retain_channels( self, indices: Sequence[int] | None = None, coords: Sequence[Any] | None = None ) -> MultiChannelImage: @@ -119,39 +233,53 @@ def retain_channels( if (indices is not None) == (coords is not None): raise ValueError("Exactly one of indices and coords must be specified.") data = self._data.isel(c=indices) if indices is not None else self._data.sel(c=coords) - return MultiChannelImage(data=data) + return MultiChannelImage( + data=data, is_foreground=self._is_foreground, is_foreground_label=self._is_foreground_label + ) + # TODO rename to dropsel_channels def drop_channels(self, *, coords: Sequence[Any], allow_missing: bool) -> MultiChannelImage: """Returns a copy with the specified channels dropped.""" data = self._data.drop_sel(c=coords, errors="ignore" if allow_missing else "raise") - return MultiChannelImage(data=data) + return MultiChannelImage( + data=data, is_foreground=self._is_foreground, is_foreground_label=self._is_foreground_label + ) # TODO save_single_channel_image... does it belong here or into plotter? - def write_hdf5(self, path: Path) -> None: + def write_hdf5(self, path: Path, mode: Literal["a", "w"] = "w", group: str | None = None) -> None: """Writes the image to a HDF5 file (actually NETCDF4).""" - self._data.to_netcdf(path, format="NETCDF4") + return MultiChannelImagePersistence(image=self).write_hdf5(path=path, mode=mode, group=group) @classmethod - def read_hdf5(cls, path: Path, group: str | None = None) -> MultiChannelImage: + def read_hdf5( + cls, path: Path, group: str | None = None, is_foreground_label: str = "is_foreground" + ) -> MultiChannelImage: """Reads a MultiChannelImage from a HDF5 file (assuming it contains NETCDF data). :param path: The path to the HDF5 file. :param group: The group within the HDF5 file, by default None. + :param is_foreground_label: The label for the is_foreground channel, by default "is_foreground". """ - return cls(data=xarray.open_dataarray(path, group=group)) + return MultiChannelImagePersistence.read_hdf5(path=path, group=group, is_foreground_label=is_foreground_label) - # TODO is_valid_hdf5 # TODO combine_in_parallel, combine_sequentially: consider moving this somewhere else @classmethod - def read_ome_tiff(cls, path: Path) -> MultiChannelImage: + def read_ome_tiff(cls, path: Path, bg_value: float = 0.0) -> MultiChannelImage: """Reads a MultiChannelImage from a OME-TIFF file.""" - return MultiChannelImage(data=OmeTiff.read(path)) + data = OmeTiff.read(path) + return MultiChannelImage(data=data, is_foreground=cls._compute_is_foreground(data=data, bg_value=bg_value)) def with_channel_names(self, channel_names: Sequence[str]) -> MultiChannelImage: """Returns a copy with the specified channel names.""" - return MultiChannelImage(data=self._data.assign_coords(c=channel_names)) + # TODO too specific! it would be better to have a "rename_channels" method instead that allows specifying only some + # or, do a "select" like in polars + return MultiChannelImage( + data=self._data.assign_coords(c=channel_names), + is_foreground=self._is_foreground, + is_foreground_label=self._is_foreground_label, + ) @cached_property def channel_stats(self) -> ImageChannelStats: @@ -165,13 +293,19 @@ def append_channels(self, other: MultiChannelImage) -> MultiChannelImage: msg = f"Channels {common_channels} are present in both images." raise ValueError(msg) data = xarray.concat([self._data, other._data], dim="c") - return MultiChannelImage(data=data) + return MultiChannelImage( + data=data, is_foreground=self._is_foreground, is_foreground_label=self._is_foreground_label + ) def get_z_scaled(self) -> MultiChannelImage: """Returns a copy of self with each feature z-scaled.""" eps = 1e-12 with xarray.set_options(keep_attrs=True): - return MultiChannelImage(data=(self._data - self.channel_stats.mean + eps) / (self.channel_stats.std + eps)) + return MultiChannelImage( + data=(self._data - self.channel_stats.mean + eps) / (self.channel_stats.std + eps), + is_foreground=self._is_foreground, + is_foreground_label=self._is_foreground_label, + ) # TODO reconsider:there is actually a problem, whether it should use bg_mask only or also replace individual values # since both could be necessary it should be implemented in a sane and maintainable manner @@ -199,7 +333,14 @@ def __str__(self) -> str: def __repr__(self) -> str: return f"MultiChannelImage(data={self._data!r})" - @staticmethod + @classmethod + def _compute_is_foreground(cls, data: DataArray, bg_value: float = np.nan) -> DataArray: + """Computes the foreground mask from the data.""" + if np.isnan(bg_value): + return ~data.isnull().all(dim="c") + else: + return (data != bg_value).any(dim="c") + def _validate_sparse_values(values: NDArray[float] | DataArray) -> DataArray: """Converts the sparse values to a DataArray, if necessary.""" if hasattr(values, "coords"): @@ -212,9 +353,18 @@ def _validate_sparse_values(values: NDArray[float] | DataArray) -> DataArray: @staticmethod def _validate_coordinates(coordinates: NDArray[int] | DataArray) -> DataArray: """Converts the coordinates to a DataArray, if necessary.""" - if hasattr(coordinates, "coords"): - return coordinates.transpose("i", "d") + if not hasattr(coordinates, "coords"): + return DataArray(coordinates, dims=("i", "d"), coords={"d": ["x", "y"]}) else: - if coordinates.ndim != 2: - raise ValueError("Coordinates must be a 2D array.") - return DataArray(coordinates, dims=("i", "d")) + coordinates = coordinates.transpose("i", "d").sortby("d") + if not coordinates.coords["d"].values.tolist() == ["x", "y"]: + raise ValueError("Coordinates must have dimensions 'x' and 'y'.") + return coordinates + + @classmethod + def _extract_flat_coordinates(cls, values: DataArray) -> DataArray: + return DataArray( + np.stack([values.coords["x"].values, values.coords["y"].values], axis=1), + dims=("i", "d"), + coords={"d": ["x", "y"]}, + ) diff --git a/src/depiction/image/multi_channel_image_concatenation.py b/src/depiction/image/multi_channel_image_concatenation.py index 32fed78..011f77d 100644 --- a/src/depiction/image/multi_channel_image_concatenation.py +++ b/src/depiction/image/multi_channel_image_concatenation.py @@ -51,12 +51,7 @@ def get_single_image(self, index: int, min_coords: tuple[int, int] = (0, 0)) -> sel_coords = sel_coords - sel_coords.min(axis=1) + np.array(min_coords)[:, None] # create the individual image - return MultiChannelImage.from_sparse( - values=sel_values, - coordinates=sel_coords, - channel_names=sel_values.coords["c"].values.tolist(), - bg_value=sel_values.bg_value, - ) + return MultiChannelImage.from_flat(values=sel_values, coordinates=sel_coords) def get_single_images(self) -> list[MultiChannelImage]: return [self.get_single_image(index=index) for index in range(self.n_individual_images)] diff --git a/src/depiction/image/multi_channel_image_persistence.py b/src/depiction/image/multi_channel_image_persistence.py new file mode 100644 index 0000000..8fab0d0 --- /dev/null +++ b/src/depiction/image/multi_channel_image_persistence.py @@ -0,0 +1,48 @@ +from __future__ import annotations + +from pathlib import Path +from typing import TYPE_CHECKING, Literal + +import xarray + +from depiction.image.container.alpha_channel import AlphaChannel + +if TYPE_CHECKING: + from depiction.image.multi_channel_image import MultiChannelImage + + +# TODO currently the files are not closed, which you can observe e.g. in a notebook when the original files has been +# replaced + + +class MultiChannelImagePersistence: + """Implements the persistence layer logic for MultiChannelImage.""" + + def __init__(self, image: MultiChannelImage) -> None: + self._image = image + self._alpha_channel = AlphaChannel(label=image.is_foreground_label) + + def write_hdf5(self, path: Path, mode: Literal["a", "w"] = "w", group: str | None = None) -> None: + data_array = self._image.data_spatial + is_fg_array = self._image.fg_mask + + if not isinstance(data_array.coords["c"][0].item(), str): + # TODO this really should be validated against in the constructor, and the static methods need to set it + # TODO FIXME later + data_array = data_array.assign_coords(c=self._image.channel_names) + + combined_array = self._alpha_channel.stack(data_array=data_array, is_fg_array=is_fg_array) + # TODO engine should not be necessary, but using it for debugging + combined_array.to_netcdf(path, mode=mode, group=group, format="NETCDF4", engine="netcdf4") + + @classmethod + def read_hdf5( + cls, path: Path, group: str | None = None, is_foreground_label: str = "is_foreground" + ) -> MultiChannelImage: + from depiction.image.multi_channel_image import MultiChannelImage + + combined_array = xarray.open_dataarray(path, group=group) + data_array, is_fg_array = AlphaChannel(label=is_foreground_label).split(combined=combined_array) + return MultiChannelImage(data=data_array, is_foreground=is_fg_array, is_foreground_label=is_foreground_label) + + # TODO is_valid_hdf5 diff --git a/src/depiction/image/smoothing/__init__.py b/src/depiction/image/smoothing/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/depiction/image/smoothing/bilateral_filter.py b/src/depiction/image/smoothing/bilateral_filter.py new file mode 100644 index 0000000..1a46f23 --- /dev/null +++ b/src/depiction/image/smoothing/bilateral_filter.py @@ -0,0 +1,50 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING + +import cv2 +import numpy as np +import xarray as xr +from loguru import logger + +from depiction.image import MultiChannelImage +from depiction.image.xarray_helper import XarrayHelper + +if TYPE_CHECKING: + from numpy.typing import NDArray + + +@dataclass(frozen=True) +class SmoothBilateralFilter: + diameter: int = 5 + sigma_intensity: float = 5.0 + sigma_spatial: float = 20.0 + + def smooth_image(self, image: MultiChannelImage) -> MultiChannelImage: + data = XarrayHelper.ensure_dense(image.data_spatial) + is_foreground = XarrayHelper.ensure_dense(image.fg_mask) + dat = xr.apply_ufunc( + self._smooth_dense_image, + data, + is_foreground, + input_core_dims=[["y", "x"], ["y", "x"]], + output_core_dims=[["y", "x"]], + vectorize=True, + ) + return MultiChannelImage(dat, is_foreground=is_foreground, is_foreground_label=image.is_foreground_label) + + def _smooth_dense_image(self, image_2d: NDArray[float], is_foreground: NDArray[bool]) -> NDArray[float]: + if not np.issubdtype(image_2d.dtype, np.floating): + raise ValueError("The input image must be a floating point array.") + + # apply the bilateral filter + logger.info("Applying bilateral filter") + smoothed_image = cv2.bilateralFilter( + np.nan_to_num(image_2d.astype(np.float32)), + d=self.diameter, + sigmaColor=self.sigma_intensity, + sigmaSpace=self.sigma_spatial, + ) + smoothed_image[~is_foreground] = 0 + return smoothed_image diff --git a/src/depiction/image/smoothing/median_filter.py b/src/depiction/image/smoothing/median_filter.py new file mode 100644 index 0000000..72891db --- /dev/null +++ b/src/depiction/image/smoothing/median_filter.py @@ -0,0 +1,45 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING + +import numpy as np +import scipy +import scipy.ndimage +import xarray as xr +from loguru import logger + +from depiction.image import MultiChannelImage +from depiction.image.xarray_helper import XarrayHelper + +if TYPE_CHECKING: + from numpy.typing import NDArray + + +@dataclass(frozen=True) +class SmoothMedianFilter: + kernel_size: int = 9 + + def smooth_image(self, image: MultiChannelImage) -> MultiChannelImage: + data = XarrayHelper.ensure_dense(image.data_spatial) + is_foreground = XarrayHelper.ensure_dense(image.fg_mask) + dat = xr.apply_ufunc( + self._smooth_dense_image, + data, + is_foreground, + input_core_dims=[["y", "x"], ["y", "x"]], + output_core_dims=[["y", "x"]], + vectorize=True, + ) + return MultiChannelImage(dat, is_foreground=is_foreground, is_foreground_label=image.is_foreground_label) + + def _smooth_dense_image(self, image_2d: NDArray[float], is_foreground: NDArray[bool]) -> NDArray[float]: + if not np.issubdtype(image_2d.dtype, np.floating): + raise ValueError("The input image must be a floating point array.") + + # apply the bilateral filter + logger.info("Applying median filter") + # smoothed_image = cv2.medianBlur(np.nan_to_num(image_2d.astype(np.float32)), ksize=self.kernel_size) + smoothed_image = scipy.ndimage.median_filter(np.nan_to_num(image_2d.astype(np.float32)), size=self.kernel_size) + smoothed_image[~is_foreground] = 0 + return smoothed_image diff --git a/src/depiction/image/spatial_smoothing.py b/src/depiction/image/smoothing/spatial_smoothing.py similarity index 100% rename from src/depiction/image/spatial_smoothing.py rename to src/depiction/image/smoothing/spatial_smoothing.py diff --git a/src/depiction/image/spatial_smoothing_sparse_aware.py b/src/depiction/image/smoothing/spatial_smoothing_sparse_aware.py similarity index 64% rename from src/depiction/image/spatial_smoothing_sparse_aware.py rename to src/depiction/image/smoothing/spatial_smoothing_sparse_aware.py index a2c3c40..2dbc21a 100644 --- a/src/depiction/image/spatial_smoothing_sparse_aware.py +++ b/src/depiction/image/smoothing/spatial_smoothing_sparse_aware.py @@ -6,8 +6,8 @@ import scipy import xarray as xr from numpy.typing import NDArray -from xarray import DataArray +from depiction.image import MultiChannelImage from depiction.image.xarray_helper import XarrayHelper @@ -26,44 +26,42 @@ class SpatialSmoothingSparseAware: kernel_std: float use_interpolation: bool = False - def smooth(self, image: DataArray, bg_value: float = 0.0) -> DataArray: - image = image.transpose("y", "x", "c") - image = image.astype(np.promote_types(image.dtype, np.dtype(type(bg_value)).type)) - image = XarrayHelper.ensure_dense(image) - image = xr.apply_ufunc( + def smooth_image(self, image: MultiChannelImage) -> MultiChannelImage: + data_input = XarrayHelper.ensure_dense(image.data_spatial) + is_foreground = XarrayHelper.ensure_dense(image.fg_mask) + data_result = xr.apply_ufunc( self._smooth_dense, - image, - input_core_dims=[["y", "x"]], + data_input, + is_foreground, + input_core_dims=[["y", "x"], ["y", "x"]], output_core_dims=[["y", "x"]], vectorize=True, - kwargs={"bg_value": bg_value}, ) - return image.transpose("y", "x", "c") + if self.use_interpolation: + is_foreground[:] = True + return MultiChannelImage( + data_result, is_foreground=is_foreground, is_foreground_label=image.is_foreground_label + ) - def _smooth_dense(self, image_2d: NDArray[float], bg_value: float) -> NDArray[float]: - """Applies the spatial smoothing to the provided 2D image.""" - if not np.issubdtype(image_2d.dtype, np.floating): - raise ValueError("The input image must be a floating point array.") + def _smooth_dense(self, image_2d: NDArray[float], is_foreground: NDArray[float]) -> NDArray[float]: + image_2d = image_2d.astype(float) - # Get an initial kernel, and mask of the missing values. + # Get an initial kernel kernel = self.gaussian_kernel - is_missing = (image_2d == bg_value) | (np.isnan(bg_value) & np.isnan(image_2d)) # Apply the kernel to the image. smoothed_image = scipy.signal.convolve(np.nan_to_num(image_2d), kernel, mode="same") # Apply the kernel counting the sum of the weights, so we can normalize the data. - kernel_sum_image = scipy.signal.convolve((~is_missing).astype(float), kernel, mode="same") - # Values are zero, when a pixel and all its neighbors are missing. + kernel_sum_image = scipy.signal.convolve(is_foreground.astype(float), kernel, mode="same") + # Values are zero, when a pixel and all its neighbors are missing (but they are masked anyways). kernel_sum_image[np.abs(kernel_sum_image) < 1e-10] = 1 - # TODO double check this does not mess up the scaling of the values - # Normalize the image, and set the missing values to NaN. result_image = smoothed_image / kernel_sum_image if not self.use_interpolation: - result_image[is_missing] = bg_value + result_image[~is_foreground] = 0.0 # Return the result. return result_image diff --git a/src/depiction/image/sparse_representation.py b/src/depiction/image/sparse_representation.py index 988fdc1..e3ee2f2 100644 --- a/src/depiction/image/sparse_representation.py +++ b/src/depiction/image/sparse_representation.py @@ -19,29 +19,40 @@ class SparseRepresentation: """ @classmethod - def sparse_to_dense(cls, sparse_values: DataArray, coordinates: DataArray, bg_value: float) -> DataArray: - """Converts the sparse image representation into a dense image representation. - :param sparse_values: DataArray with "i" (index of value) and "c" (channel) dimensions - :param coordinates: DataArray with "i" (index of value) and "d" (dimension) dimensions - :param bg_value: the value to use for the background - :return: DataArray with "y", "x", and "c" dimensions - """ + def flat_to_spatial( + cls, sparse_values: DataArray, coordinates: DataArray, bg_value: float + ) -> tuple[DataArray, DataArray]: + # TODO fully test and simplify this method + original_coords = sparse_values.coords n_channels = sparse_values.sizes["c"] sparse_values = sparse_values.transpose("i", "c").values - coordinates = coordinates.transpose("i", "d").values + coordinates = coordinates.transpose("i", "d").astype(int) + if coordinates.coords["d"].values.tolist() != ["x", "y"]: + raise ValueError(f"Unexpected coordinates={coordinates.coords['d'].values}") + coordinates = coordinates.values - coordinates_extent = coordinates.max(axis=0) - coordinates.min(axis=0) + 1 - coordinates_shifted = coordinates - coordinates.min(axis=0) + coordinates_min = coordinates.min(axis=0) + coordinates_extent = coordinates.max(axis=0) - coordinates_min + 1 + coordinates_shifted = coordinates - coordinates_min dtype = np.promote_types(sparse_values.dtype, np.dtype(type(bg_value)).type) values_grid = np.full( (coordinates_extent[0], coordinates_extent[1], n_channels), fill_value=bg_value, dtype=dtype ) + is_foreground = np.zeros((coordinates_extent[0], coordinates_extent[1]), dtype=bool) for i_channel in range(n_channels): values_grid[tuple(coordinates_shifted.T) + (i_channel,)] = sparse_values[:, i_channel] + is_foreground[tuple(coordinates_shifted.T)] = True - # TODO coordinates might come in the wrong order FIXME - return DataArray(values_grid, dims=("y", "x", "c")) + coords = { + "x": np.arange(coordinates_min[0], coordinates_min[0] + coordinates_extent[0]), + "y": np.arange(coordinates_min[1], coordinates_min[1] + coordinates_extent[1]), + } + coords_c = {"c": original_coords["c"]} if "c" in original_coords else {} + return ( + DataArray(values_grid, dims=("x", "y", "c"), coords=coords | coords_c).transpose("y", "x", "c"), + DataArray(is_foreground, dims=("x", "y"), coords=coords).transpose("y", "x"), + ) @classmethod def dense_to_sparse(cls, grid_values: DataArray, bg_value: float | None) -> tuple[DataArray, DataArray]: @@ -69,11 +80,3 @@ def dense_to_sparse(cls, grid_values: DataArray, bg_value: float | None) -> tupl return DataArray(sparse_values, dims=("i", "c")), DataArray( coordinates, dims=("i", "d"), coords={"d": ["x", "y"]} ) - - @classmethod - def dense_to_sparse_coords(cls, grid_values: DataArray, coords: DataArray, is_shift_subtracted: bool) -> DataArray: - if not is_shift_subtracted: - coords = coords - coords.min(dim="i") - grid_values = grid_values.transpose("y", "x", "c").values - coords = coords.transpose("i", "d").values - return DataArray(grid_values[tuple(coords.T)], dims=["i", "c"]) diff --git a/src/depiction/persistence/types.py b/src/depiction/persistence/types.py index c60dbee..2bb8cb8 100644 --- a/src/depiction/persistence/types.py +++ b/src/depiction/persistence/types.py @@ -3,6 +3,7 @@ from typing import TYPE_CHECKING, Self, Protocol import numpy as np +from xarray import DataArray if TYPE_CHECKING: from pathlib import Path @@ -134,6 +135,8 @@ def imzml_mode(self) -> ImzmlModeEnum: """Mode of the .imzML file (continuous or processed).""" raise NotImplementedError + # TODO: coordinates = DataArray(read_peaks.coordinates_2d, dims=["i", "d"], coords={"d": ["x", "y"]}) + @property def coordinates(self) -> NDArray[int]: """Spatial coordinates of the spectra in the .imzML file. @@ -147,6 +150,11 @@ def coordinates_2d(self) -> NDArray[int]: # TODO double check convention and update docstring accordingly return self.coordinates[:, :2] + @property + def coordinates_array_2d(self) -> DataArray: + # TODO this should replace the old coordinates_2d later + return DataArray(self.coordinates_2d.astype(int), dims=("i", "d"), coords={"d": ["x", "y"]}) + @property def compact_metadata(self) -> dict[str, int | str | list[float]]: """Compact representation of general metadata about the .imzML file, useful when comparing a large diff --git a/src/depiction/tools/generate_ion_image.py b/src/depiction/tools/generate_ion_image.py index f02124c..15d0d41 100644 --- a/src/depiction/tools/generate_ion_image.py +++ b/src/depiction/tools/generate_ion_image.py @@ -49,7 +49,7 @@ def generate_ion_images_for_file( .set_xindex(["y", "x"]) .unstack("i") ) - return MultiChannelImage(data) + return MultiChannelImage.from_spatial(data, bg_value=np.nan) def _generate_channel_values( self, input_file: ImzmlReadFile, mz_values: Sequence[float], tol: float | Sequence[float] @@ -63,7 +63,7 @@ def _generate_channel_values( bind_args=dict(mz_values=mz_values, tol_values=tol), reduce_fn=lambda chunks: np.concatenate(chunks, axis=0), ) - return DataArray(array, dims=("i", "c"), attrs={"bg_value": np.nan}) + return DataArray(array, dims=("i", "c")) def generate_range_images_for_file( self, @@ -84,12 +84,10 @@ def generate_range_images_for_file( bind_args=dict(mz_ranges=mz_ranges), reduce_fn=lambda chunks: np.concatenate(chunks, axis=0), ) - return MultiChannelImage.from_sparse( + return MultiChannelImage.from_flat( values=channel_values, - coordinates=input_file.coordinates_2d, + coordinates=input_file.coordinates_array_2d, channel_names=channel_names, - # TODO clarfiy (see above) - bg_value=np.nan, ) @classmethod diff --git a/src/depiction/tools/simulate/generate_label_image.py b/src/depiction/tools/simulate/generate_label_image.py index 22d3a7f..64b01df 100644 --- a/src/depiction/tools/simulate/generate_label_image.py +++ b/src/depiction/tools/simulate/generate_label_image.py @@ -1,8 +1,10 @@ from collections.abc import Sequence from pathlib import Path -import polars as pl + import cyclopts import numpy as np +import polars as pl +import xarray from xarray import DataArray from depiction.image.multi_channel_image import MultiChannelImage @@ -73,8 +75,8 @@ def f(x, y): def render(self) -> MultiChannelImage: blended = np.sum(self._layers, axis=0) data = DataArray(blended, dims=("y", "x", "c"), coords={"c": [f"synthetic_{i}" for i in range(self._n_labels)]}) - data.attrs["bg_value"] = 0.0 - return MultiChannelImage(data) + is_foreground = xarray.DataArray(np.ones((self._image_height, self._image_width), dtype=bool), dims=("y", "x")) + return MultiChannelImage(data, is_foreground=is_foreground) app = cyclopts.App() diff --git a/src/depiction_targeted_preproc/workflow/vis/images_norm.py b/src/depiction_targeted_preproc/workflow/vis/images_norm.py index bf50ca2..b31550c 100644 --- a/src/depiction_targeted_preproc/workflow/vis/images_norm.py +++ b/src/depiction_targeted_preproc/workflow/vis/images_norm.py @@ -1,23 +1,19 @@ -from typing import Annotated +from pathlib import Path -import typer -import xarray +import cyclopts +from depiction.image import MultiChannelImage from depiction.image.image_normalization import ImageNormalization, ImageNormalizationVariant +app = cyclopts.App() -def vis_images_norm( - input_hdf5_path: Annotated[str, typer.Option()], - output_hdf5_path: Annotated[str, typer.Option()], -) -> None: - image_orig = xarray.open_dataset(input_hdf5_path).to_array("var").squeeze("var") - image_norm = ImageNormalization().normalize_xarray(image_orig, variant=ImageNormalizationVariant.VEC_NORM) - image_norm.to_netcdf(output_hdf5_path) - -def main(): - typer.run(vis_images_norm) +@app.default +def vis_images_norm(input_hdf5_path: Path, output_hdf5_path: Path) -> None: + image_orig = MultiChannelImage.read_hdf5(input_hdf5_path) + image_norm = ImageNormalization().normalize_image(image_orig, variant=ImageNormalizationVariant.VEC_NORM) + image_norm.write_hdf5(output_hdf5_path) if __name__ == "__main__": - main() + app() diff --git a/src/depiction_targeted_preproc/workflow/vis/test_mass_shifts.py b/src/depiction_targeted_preproc/workflow/vis/test_mass_shifts.py index 1f3879b..5e8502f 100644 --- a/src/depiction_targeted_preproc/workflow/vis/test_mass_shifts.py +++ b/src/depiction_targeted_preproc/workflow/vis/test_mass_shifts.py @@ -6,9 +6,11 @@ import typer import xarray as xr import yaml -from depiction.tools.calibrate import get_calibration_instance, CalibrationConfig from typer import Option +from depiction.image import MultiChannelImage +from depiction.tools.calibrate import get_calibration_instance, CalibrationConfig + def vis_test_mass_shifts( calib_hdf5_path: Annotated[Path, Option()], @@ -17,7 +19,7 @@ def vis_test_mass_shifts( output_hdf5_path: Annotated[Path, Option()], ) -> None: # load inputs - model_coefs = xr.open_dataarray(calib_hdf5_path, group="model_coefs") + model_coefs = MultiChannelImage.read_hdf5(calib_hdf5_path, group="model_coefs") config = CalibrationConfig.model_validate(yaml.safe_load(config_path.read_text())) mass_list = pl.read_csv(mass_list_path) @@ -38,16 +40,12 @@ def compute_shifts(coef): shifts = xr.apply_ufunc( compute_shifts, - model_coefs, + model_coefs.data_spatial, input_core_dims=[["c"]], output_core_dims=[["m"]], vectorize=True, ).rename({"m": "c"}) - shifts = shifts.assign_coords(c=test_masses) - - shifts_2d = shifts.set_xindex(["x", "y"]).unstack("i") - shifts_2d.attrs["bg_value"] = np.nan - + shifts_2d = shifts.assign_coords(c=test_masses) # save the result shifts_2d.to_netcdf(output_hdf5_path) diff --git a/tests/unit/clustering/test_remap_clusters.py b/tests/unit/clustering/test_remap_clusters.py index 17d11cc..dca368c 100644 --- a/tests/unit/clustering/test_remap_clusters.py +++ b/tests/unit/clustering/test_remap_clusters.py @@ -9,8 +9,8 @@ @pytest.fixture() def label_image() -> MultiChannelImage: values = np.array([[0, 1, 0], [0, 0, 0], [0, 2, 2]]).reshape(3, 3, 1) - data = xarray.DataArray(values, dims=("y", "x", "c"), coords={"c": ["cluster"]}, attrs={"bg_value": np.nan}) - return MultiChannelImage(data) + data = xarray.DataArray(values, dims=("y", "x", "c"), coords={"c": ["cluster"]}) + return MultiChannelImage(data, is_foreground=xarray.ones_like(data.isel(c=0), dtype=bool)) def test_get_centroids(): @@ -30,5 +30,4 @@ def test_remap_cluster_labels(label_image): remapped = remap_cluster_labels(image=label_image, mapping={0: 1, 1: 0, 2: 2}) expected = np.array([[1, 0, 1], [1, 1, 1], [1, 2, 2]]).reshape(3, 3, 1) assert np.allclose(remapped.data_spatial.values, expected) - assert np.isnan(remapped.data_spatial.attrs["bg_value"]) assert remapped.data_spatial.coords["c"].values == ["cluster"] diff --git a/tests/unit/image/test_feature_selection.py b/tests/unit/image/test_feature_selection.py index ab463ac..27c7859 100644 --- a/tests/unit/image/test_feature_selection.py +++ b/tests/unit/image/test_feature_selection.py @@ -1,5 +1,5 @@ -import numpy as np import pytest +import xarray from xarray import DataArray from depiction.image.feature_selection import FeatureSelectionIQR, select_features, FeatureSelectionCV @@ -8,13 +8,14 @@ @pytest.fixture() def image() -> MultiChannelImage: + data = DataArray( + [[[1.0, 2, 0.0], [1.0, 5, 0.5], [1.0, 10, 0.0], [1.0, 20, 0.5], [1.0, 30, 0.0], [1.0, 40, 0.5]]], + dims=("y", "x", "c"), + coords={"c": ["channel1", "channel2", "channel3"]}, + ) return MultiChannelImage( - DataArray( - [[[1.0, 2, 0.0], [1.0, 5, 0.5], [1.0, 10, 0.0], [1.0, 20, 0.5], [1.0, 30, 0.0], [1.0, 40, 0.5]]], - dims=("y", "x", "c"), - coords={"c": ["channel1", "channel2", "channel3"]}, - attrs={"bg_value": np.nan}, - ) + data=data, + is_foreground=xarray.ones_like(data.isel(c=0), dtype=bool).drop_vars("c"), ) diff --git a/tests/unit/image/test_horizontal_concat.py b/tests/unit/image/test_horizontal_concat.py index 7c8558a..5c1870c 100644 --- a/tests/unit/image/test_horizontal_concat.py +++ b/tests/unit/image/test_horizontal_concat.py @@ -3,6 +3,7 @@ import numpy as np import pytest import xarray as xr + from depiction.image.horizontal_concat import horizontal_concat from depiction.image.multi_channel_image import MultiChannelImage @@ -13,9 +14,8 @@ def sample_image(): np.random.rand(2, 3, 4), dims=["c", "y", "x"], coords={"c": ["red", "green"], "y": [0, 1, 2], "x": [0, 1, 2, 3]}, - attrs={"bg_value": 0.0}, ) - return MultiChannelImage(data) + return MultiChannelImage(data, is_foreground=xr.ones_like(data.isel(c=0), dtype=bool)) def test_horizontal_concat_success(sample_image): diff --git a/tests/unit/image/test_image_channel_stats.py b/tests/unit/image/test_image_channel_stats.py index 112b987..ca4cb8f 100644 --- a/tests/unit/image/test_image_channel_stats.py +++ b/tests/unit/image/test_image_channel_stats.py @@ -7,7 +7,7 @@ @pytest.fixture def mock_multi_channel_image(mocker): - return mocker.Mock(name="mock_image", n_channels=3, channel_names=["channel1", "channel2", "channel3"], bg_value=0) + return mocker.Mock(name="mock_image", n_channels=3, channel_names=["channel1", "channel2", "channel3"]) @pytest.fixture @@ -80,9 +80,11 @@ def test_std(mocker, image_channel_stats, mock_multi_channel_image): ) -def test_get_channel_values(image_channel_stats, mock_multi_channel_image): - mock_data = np.array([1, 2, 3, 0, 5]) +@pytest.mark.parametrize("bg_value", [0, np.nan]) +def test_get_channel_values(image_channel_stats, mock_multi_channel_image, bg_value: float): + mock_data = np.array([1, 2, 3, bg_value, 5]) mock_multi_channel_image.data_flat.isel.return_value.values = mock_data + mock_multi_channel_image.fg_mask_flat = np.array([True, True, True, False, True]) # Test without dropping missing values result = image_channel_stats._get_channel_values(0, drop_missing=False) @@ -93,20 +95,6 @@ def test_get_channel_values(image_channel_stats, mock_multi_channel_image): np.testing.assert_array_equal(result, np.array([1, 2, 3, 5])) -def test_get_channel_values_with_nan(image_channel_stats, mock_multi_channel_image): - mock_data = np.array([1, 2, 3, np.nan, 5]) - mock_multi_channel_image.data_flat.isel.return_value.values = mock_data - mock_multi_channel_image.bg_value = np.nan - - # Test without dropping missing values - result = image_channel_stats._get_channel_values(0, drop_missing=False) - np.testing.assert_array_equal(result, mock_data) - - # Test with dropping missing values (bg_value = nan) - result = image_channel_stats._get_channel_values(0, drop_missing=True) - np.testing.assert_array_equal(result, np.array([1, 2, 3, 5])) - - def test_five_number_summary_when_empty_channel(mocker, image_channel_stats, mock_multi_channel_image): mock_data = np.array([]) mocker.patch.object(image_channel_stats, "_get_channel_values", return_value=mock_data) diff --git a/tests/unit/image/test_image_normalization.py b/tests/unit/image/test_image_normalization.py index ea2e56a..be0d69f 100644 --- a/tests/unit/image/test_image_normalization.py +++ b/tests/unit/image/test_image_normalization.py @@ -14,103 +14,118 @@ def image_normalizer(): @pytest.fixture def single_image(): return xr.DataArray( - data=[[[2, 0], [0, 2]], [[1, 1], [4, 1]], [[0, 0], [0, 0]]], dims=["y", "x", "c"], attrs={"bg_value": 0} + data=[[[2, 0], [0, 2]], [[1, 1], [4, 1]], [[0, 0], [0, 0]]], + dims=["y", "x", "c"], + coords={"c": ["A", "B"]}, ) @pytest.fixture def multiple_images(): - return xr.DataArray(data=[[[[2, 0]]], [[[0, 3]]]], dims=["whatever", "y", "x", "c"], attrs={"bg_value": 0}) + return xr.DataArray(data=[[[[2, 0]]], [[[0, 3]]]], dims=["whatever", "y", "x", "c"], coords={"c": ["A", "B"]}) + + +def test_normalize_image(image_normalizer, single_image): + multi_channel_image = MultiChannelImage( + single_image, is_foreground=xr.ones_like(single_image.isel(c=0), dtype=bool) + ) + normalized_image = image_normalizer.normalize_image(multi_channel_image, variant=ImageNormalizationVariant.VEC_NORM) + assert isinstance(normalized_image, MultiChannelImage) + xr.testing.assert_allclose( + normalized_image.data_spatial, + image_normalizer._normalize_xarray(single_image, variant=ImageNormalizationVariant.VEC_NORM), + ) + + +def test_normalize_image_with_background(image_normalizer, single_image): + is_foreground = xr.ones_like(single_image.isel(c=0), dtype=bool).drop_vars("c") + is_foreground[0, 0] = False + is_foreground[1, 0] = False + multi_channel_image = MultiChannelImage(single_image, is_foreground=is_foreground) + normalized_image = image_normalizer.normalize_image(multi_channel_image, variant=ImageNormalizationVariant.VEC_NORM) + xr.testing.assert_equal(normalized_image.fg_mask, is_foreground) + xr.testing.assert_allclose( + normalized_image.data_spatial, + image_normalizer._normalize_xarray(single_image, variant=ImageNormalizationVariant.VEC_NORM), + ) def test_normalize_xarray_single_vec_norm(image_normalizer, single_image): - norm_vec = image_normalizer.normalize_xarray(single_image, variant=ImageNormalizationVariant.VEC_NORM) + norm_vec = image_normalizer._normalize_xarray(single_image, variant=ImageNormalizationVariant.VEC_NORM) expected = xr.DataArray( data=[[[1, 0], [0, 1]], [[0.707107, 0.707107], [0.970143, 0.242536]], [[0, 0], [0, 0]]], dims=["y", "x", "c"], - attrs={"bg_value": 0}, + coords={"c": ["A", "B"]}, ) xr.testing.assert_allclose(expected, norm_vec) - assert norm_vec.attrs["bg_value"] == single_image.attrs["bg_value"] +# TODO revisit this test +@pytest.mark.skip(reason="Reconsider") def test_normalize_xarray_single_vec_norm_with_nans(image_normalizer): image_with_nans = xr.DataArray( data=[[[2, np.nan], [0, 2]], [[1, 1], [4, np.nan]], [[np.nan, 0], [0, 0]]], dims=["y", "x", "c"], - attrs={"bg_value": np.nan}, + coords={"c": ["A", "B"]}, ) - norm_vec = image_normalizer.normalize_xarray(image_with_nans, variant=ImageNormalizationVariant.VEC_NORM) + norm_vec = image_normalizer._normalize_xarray(image_with_nans, variant=ImageNormalizationVariant.VEC_NORM) expected = xr.DataArray( data=[[[1, np.nan], [0, 1]], [[0.707107, 0.707107], [1, np.nan]], [[np.nan, np.nan], [np.nan, np.nan]]], dims=["y", "x", "c"], - attrs={"bg_value": np.nan}, + coords={"c": ["A", "B"]}, ) xr.testing.assert_allclose(expected, norm_vec) - assert np.isnan(norm_vec.attrs["bg_value"]) def test_normalize_xarray_single_std(image_normalizer, single_image): - norm_std = image_normalizer.normalize_xarray(single_image, variant=ImageNormalizationVariant.STD) + norm_std = image_normalizer._normalize_xarray(single_image, variant=ImageNormalizationVariant.STD) expected = xr.DataArray( data=[[[1.0, -1.0], [-1.0, 1.0]], [[0.0, 0.0], [1.0, -1.0]], [[0.0, 0.0], [0.0, 0.0]]], dims=["y", "x", "c"], - attrs={"bg_value": 0}, + coords={"c": ["A", "B"]}, ) xr.testing.assert_allclose(norm_std, expected, rtol=1e-5) - assert norm_std.attrs["bg_value"] == single_image.attrs["bg_value"] -def test_normalize_xarray_multiple_vec_norm(image_normalizer, multiple_images): - norm_vec = image_normalizer.normalize_xarray(multiple_images, variant=ImageNormalizationVariant.VEC_NORM) - expected = xr.DataArray( - data=[[[[1, 0]]], [[[0, 1]]]], - dims=["whatever", "y", "x", "c"], - coords={"whatever": [0, 1]}, - attrs={"bg_value": 0}, - ) - xr.testing.assert_allclose(expected, norm_vec) - assert norm_vec.attrs["bg_value"] == multiple_images.attrs["bg_value"] - - -def test_normalize_xarray_multiple_std(image_normalizer, multiple_images): - norm_std = image_normalizer.normalize_xarray(multiple_images, variant=ImageNormalizationVariant.STD) - expected = xr.DataArray( - data=[[[[1, -1]]], [[[-1, 1]]]], - dims=["whatever", "y", "x", "c"], - coords={"whatever": [0, 1]}, - attrs={"bg_value": 0}, - ) - xr.testing.assert_allclose(expected, norm_std) - assert norm_std.attrs["bg_value"] == multiple_images.attrs["bg_value"] - - -def test_normalize_image(image_normalizer, single_image): - multi_channel_image = MultiChannelImage(single_image) - normalized_image = image_normalizer.normalize_image(multi_channel_image, variant=ImageNormalizationVariant.VEC_NORM) - assert isinstance(normalized_image, MultiChannelImage) - xr.testing.assert_allclose( - normalized_image.data_spatial, - image_normalizer.normalize_xarray(single_image, variant=ImageNormalizationVariant.VEC_NORM), - ) - assert normalized_image.data_spatial.attrs["bg_value"] == single_image.attrs["bg_value"] +# def test_normalize_xarray_multiple_vec_norm(image_normalizer, multiple_images): +# norm_vec = image_normalizer._normalize_xarray(multiple_images, variant=ImageNormalizationVariant.VEC_NORM) +# expected = xr.DataArray( +# data=[[[[1, 0]]], [[[0, 1]]]], +# dims=["whatever", "y", "x", "c"], +# coords={"whatever": [0, 1]}, +# attrs={"bg_value": 0}, +# ) +# xr.testing.assert_allclose(expected, norm_vec) +# assert norm_vec.attrs["bg_value"] == multiple_images.attrs["bg_value"] +# +# +# def test_normalize_xarray_multiple_std(image_normalizer, multiple_images): +# norm_std = image_normalizer._normalize_xarray(multiple_images, variant=ImageNormalizationVariant.STD) +# expected = xr.DataArray( +# data=[[[[1, -1]]], [[[-1, 1]]]], +# dims=["whatever", "y", "x", "c"], +# coords={"whatever": [0, 1]}, +# attrs={"bg_value": 0}, +# ) +# xr.testing.assert_allclose(expected, norm_std) +# assert norm_std.attrs["bg_value"] == multiple_images.attrs["bg_value"] def test_missing_dimensions(image_normalizer): invalid_image = xr.DataArray(data=[[2, 0], [0, 2]], dims=["y", "x"]) with pytest.raises(ValueError, match="Missing required dimensions: {'c'}"): - image_normalizer.normalize_xarray(invalid_image, variant=ImageNormalizationVariant.VEC_NORM) + image_normalizer._normalize_xarray(invalid_image, variant=ImageNormalizationVariant.VEC_NORM) def test_multiple_index_dimensions(image_normalizer): invalid_image = xr.DataArray(data=[[[[[2, 0]]], [[[0, 3]]]]], dims=["dim1", "dim2", "y", "x", "c"]) with pytest.raises(NotImplementedError, match="Multiple index columns are not supported yet."): - image_normalizer.normalize_xarray(invalid_image, variant=ImageNormalizationVariant.VEC_NORM) + image_normalizer._normalize_xarray(invalid_image, variant=ImageNormalizationVariant.VEC_NORM) def test_unknown_variant(image_normalizer, single_image): with pytest.raises(NotImplementedError, match="Unknown variant: unknown"): - image_normalizer.normalize_xarray(single_image, variant="unknown") + image_normalizer._normalize_xarray(single_image, variant="unknown") if __name__ == "__main__": diff --git a/tests/unit/image/test_multi_channel_image.py b/tests/unit/image/test_multi_channel_image.py index 6ed8dcb..8f8cfce 100644 --- a/tests/unit/image/test_multi_channel_image.py +++ b/tests/unit/image/test_multi_channel_image.py @@ -23,23 +23,52 @@ def mock_data(mock_coords) -> DataArray: [[[2.0, 5], [4, 5]], [[6, 5], [8, 5]], [[10, 5], [12, 5]]], dims=("y", "x", "c"), coords=mock_coords, - attrs={"bg_value": 0}, + ) + + +@pytest.fixture +def mock_data_sparse(mock_coords) -> DataArray: + return DataArray( + [[[0, 0], [0, 0]], [[6, 5], [8, 5]], [[0, 0], [0, 0]]], + dims=("y", "x", "c"), + coords=mock_coords, ) @pytest.fixture def mock_image(mock_data) -> MultiChannelImage: """Dense mock image without any missing values.""" - return MultiChannelImage(data=mock_data) + return MultiChannelImage( + data=mock_data, is_foreground=DataArray([[True, True], [True, True], [True, True]], dims=("y", "x")) + ) -def test_from_numpy_sparse() -> None: - values = np.array([[1, 2, 3], [4, 5, 6]]) - coordinates = np.array([[0, 0], [1, 1]]) - image = MultiChannelImage.from_sparse(values=values, coordinates=coordinates, channel_names=["A", "B", "C"]) - assert image.channel_names == ["A", "B", "C"] - values = image.data_spatial.sel(c="B") - xarray.testing.assert_equal(DataArray([[2, 0], [0, 5]], dims=("y", "x"), coords={"c": "B"}, name="values"), values) +@pytest.fixture +def mock_image_sparse(mock_data_sparse) -> MultiChannelImage: + """Sparse mock image.""" + return MultiChannelImage( + data=mock_data_sparse, is_foreground=DataArray([[False, False], [True, True], [False, False]], dims=("y", "x")) + ) + + +@pytest.mark.parametrize( + ["values", "channel_names", "expected_channel_names"], + [ + (DataArray([[1, 2, 3], [4, 5, 6]], dims=("i", "c"), coords={"c": ["A", "B", "C"]}), False, ["A", "B", "C"]), + (DataArray([[1, 2, 3], [4, 5, 6]], dims=("i", "c")), ["A", "B", "C"], ["A", "B", "C"]), + (DataArray([[1, 2, 3], [4, 5, 6]], dims=("i", "c")), True, ["0", "1", "2"]), + ], +) +def test_from_flat(values, channel_names, expected_channel_names) -> None: + coords = DataArray([[0, 0], [1, 2]], dims=("i", "d"), coords={"d": ["x", "y"]}) + image = MultiChannelImage.from_flat(values, coords, channel_names) + assert image.channel_names == expected_channel_names + np.testing.assert_array_equal(image.data_spatial[0, 0, :], [1, 2, 3]) + np.testing.assert_array_equal(image.data_spatial[2, 1, :], [4, 5, 6]) + expected_fg_mask = DataArray( + [[True, False], [False, False], [False, True]], dims=("y", "x"), coords={"y": [0, 1, 2], "x": [0, 1]} + ) + xarray.testing.assert_equal(image.fg_mask, expected_fg_mask) def test_n_channels(mock_image: MultiChannelImage) -> None: @@ -51,7 +80,7 @@ def test_n_nonzero(mock_image: MultiChannelImage) -> None: def test_n_nonzero_when_sparse(mock_image: MultiChannelImage) -> None: - mock_image.data_spatial[1, 0, :] = 0 + mock_image._is_foreground[1, 0] = False assert mock_image.n_nonzero == 5 @@ -59,23 +88,38 @@ def test_dtype(mock_image: MultiChannelImage) -> None: assert mock_image.dtype == float -def test_bg_value(mock_image: MultiChannelImage) -> None: - assert mock_image.bg_value == 0.0 +def test_is_foreground_label(mock_image: MultiChannelImage) -> None: + assert mock_image.is_foreground_label == "is_foreground" + + +def test_bg_mask(mock_image: MultiChannelImage) -> None: + # TODO more interesting example + expected_bg_mask = DataArray([[False, False], [False, False], [False, False]], dims=("y", "x")) + xarray.testing.assert_equal(expected_bg_mask, mock_image.bg_mask) + + +def test_bg_mask_flat(mock_image: MultiChannelImage) -> None: + # TODO more interesting example + expected_bg_mask_flat = DataArray( + [False, False, False, False, False, False], + dims="i", + coords={"i": pd.MultiIndex.from_tuples([(0, 0), (0, 1), (1, 0), (1, 1), (2, 0), (2, 1)], names=("y", "x"))}, + ) + xarray.testing.assert_equal(expected_bg_mask_flat, mock_image.bg_mask_flat) -def test_bg_mask_when_0(mock_data: DataArray, mock_image: MultiChannelImage) -> None: - mock_data[1, :, :] = 0 - bg_mask = mock_image.bg_mask - expected_bg_mask = DataArray([[False, False], [True, True], [False, False]], dims=("y", "x")) - xarray.testing.assert_equal(expected_bg_mask, bg_mask) +def test_fg_mask(mock_image_sparse) -> None: + expected_fg_mask = DataArray([[False, False], [True, True], [False, False]], dims=("y", "x")) + xarray.testing.assert_equal(expected_fg_mask, mock_image_sparse.fg_mask) -def test_bg_mask_when_nan(mock_image: MultiChannelImage) -> None: - mock_image.data_spatial[1, :, :] = np.nan - mock_image.data_spatial.attrs["bg_value"] = np.nan - bg_mask = mock_image.bg_mask - expected_bg_mask = DataArray([[False, False], [True, True], [False, False]], dims=("y", "x")) - xarray.testing.assert_equal(expected_bg_mask, bg_mask) +def test_fg_mask_flat(mock_image_sparse) -> None: + expected_fg_mask_flat = DataArray( + [False, False, True, True, False, False], + dims="i", + coords={"i": pd.MultiIndex.from_tuples([(0, 0), (0, 1), (1, 0), (1, 1), (2, 0), (2, 1)], names=("y", "x"))}, + ) + xarray.testing.assert_equal(expected_fg_mask_flat, mock_image_sparse.fg_mask_flat) def test_dimensions(mock_image: MultiChannelImage) -> None: @@ -83,21 +127,17 @@ def test_dimensions(mock_image: MultiChannelImage) -> None: def test_channel_names_when_set(mock_image: MultiChannelImage) -> None: + # TODO there should be some functionality to make it work for on-the-fly generated channel names assert mock_image.channel_names == ["Channel A", "Channel B"] -@pytest.mark.parametrize("mock_coords", [{}]) -def test_channel_names_when_not_set(mock_image: MultiChannelImage) -> None: - assert mock_image.channel_names == ["0", "1"] - - def test_data_spatial(mock_data: DataArray, mock_image: MultiChannelImage) -> None: xarray.testing.assert_identical(mock_data, mock_image.data_spatial) -def test_data_flat(mock_data: DataArray, mock_image: MultiChannelImage) -> None: - mock_data[0, 0, :] = 0 - mock_data[1, 0, 0] = np.nan +def test_data_flat(mock_image: MultiChannelImage) -> None: + mock_image._is_foreground[0, 0] = False + mock_image._is_foreground[1, 0] = False expected = DataArray( [[4.0, 8, 10, 12], [5, 5, 5, 5]], dims=("c", "i"), @@ -105,14 +145,27 @@ def test_data_flat(mock_data: DataArray, mock_image: MultiChannelImage) -> None: "c": ["Channel A", "Channel B"], "i": pd.MultiIndex.from_tuples([(0, 1), (1, 1), (2, 0), (2, 1)], names=("y", "x")), }, - attrs={"bg_value": 0}, ) - xarray.testing.assert_identical(expected, mock_image.data_flat) + xarray.testing.assert_identical(mock_image.data_flat, expected) + + +def test_data_flat_preserves_fg_nan(mock_image: MultiChannelImage) -> None: + mock_image._is_foreground[0, 0] = False + mock_image.data_spatial[1, 0, 0] = np.nan + expected = DataArray( + [[4.0, np.nan, 8, 10, 12], [5, 5, 5, 5, 5]], + dims=("c", "i"), + coords={ + "c": ["Channel A", "Channel B"], + "i": pd.MultiIndex.from_tuples([(0, 1), (1, 0), (1, 1), (2, 0), (2, 1)], names=("y", "x")), + }, + ) + xarray.testing.assert_identical(mock_image.data_flat, expected) def test_coordinates_flat(mock_data: DataArray, mock_image: MultiChannelImage) -> None: - mock_data[0, 0, :] = 0 - mock_data[1, 0, 0] = np.nan + mock_image._is_foreground[0, 0] = False + mock_image._is_foreground[1, 0] = False expected = DataArray( [[0, 1, 2, 2], [1, 1, 0, 1]], dims=("d", "i"), @@ -124,6 +177,15 @@ def test_coordinates_flat(mock_data: DataArray, mock_image: MultiChannelImage) - xarray.testing.assert_identical(mock_image.coordinates_flat, expected) +def test_recompute_is_foreground(mocker: MockerFixture, mock_image: MultiChannelImage) -> None: + mock_compute = mocker.patch.object( + MultiChannelImage, "_compute_is_foreground", return_value=xarray.ones_like(mock_image.fg_mask) + ) + new_image = mock_image.recompute_is_foreground() + xarray.testing.assert_equal(new_image.fg_mask, mock_compute.return_value) + xarray.testing.assert_equal(new_image.data_spatial, mock_image.data_spatial) + + def test_retain_channels_when_both_none(mock_image: MultiChannelImage) -> None: with pytest.raises(ValueError): mock_image.retain_channels(None, None) @@ -163,21 +225,29 @@ def test_drop_channels_when_coords_and_not_allow_missing(mock_image: MultiChanne def test_write_hdf5(mocker: MockerFixture, mock_image: MultiChannelImage) -> None: mocker.patch("xarray.DataArray.to_netcdf") mock_image.write_hdf5(Path("test.h5")) - mock_image.data_spatial.to_netcdf.assert_called_once_with(Path("test.h5"), format="NETCDF4") + mock_image.data_spatial.to_netcdf.assert_called_once_with( + Path("test.h5"), engine="netcdf4", format="NETCDF4", group=None, mode="w" + ) def test_read_hdf5(mocker: MockerFixture, mock_data: DataArray) -> None: - mocker.patch("xarray.open_dataarray").return_value = mock_data + mock_is_foreground = xarray.ones_like(mock_data.isel(c=[0])).assign_coords(c=["is_foreground"]) + persisted_data = xarray.concat([mock_data, mock_is_foreground], dim="c") + mocker.patch("xarray.open_dataarray").return_value = persisted_data image = MultiChannelImage.read_hdf5(Path("test.h5")) xarray.open_dataarray.assert_called_once_with(Path("test.h5"), group=None) xarray.testing.assert_equal(image.data_spatial, mock_data) + xarray.testing.assert_equal(image.fg_mask, mock_is_foreground.isel(c=0).drop_vars("c")) def test_read_ome_tiff(mocker: MockerFixture, mock_data: DataArray) -> None: mock_read = mocker.patch.object(OmeTiff, "read", return_value=mock_data) + mock_foreground = xarray.ones_like(mock_data.isel(c=0), dtype=bool).drop_vars("c") + mocker.patch.object(MultiChannelImage, "_compute_is_foreground", return_value=mock_foreground) image = MultiChannelImage.read_ome_tiff(Path("test.ome.tiff")) xarray.testing.assert_equal(image.data_spatial, mock_data) mock_read.assert_called_once_with(Path("test.ome.tiff")) + xarray.testing.assert_equal(image.fg_mask, mock_foreground) def test_with_channel_names(mock_image: MultiChannelImage) -> None: @@ -201,9 +271,8 @@ def test_append_channels(mock_image: MultiChannelImage) -> None: data=np.arange(12).reshape(3, 2, 2), dims=("y", "x", "c"), coords={"c": ["Channel X", "Channel Y"]}, - attrs={"bg_value": 0}, ) - extra_image = MultiChannelImage(data=extra_image_data) + extra_image = MultiChannelImage(data=extra_image_data, is_foreground=mock_image.fg_mask) result = mock_image.append_channels(extra_image) assert result.channel_names == ["Channel A", "Channel B", "Channel X", "Channel Y"] assert result.retain_channels(coords=["Channel A", "Channel B"]).data_spatial.identical(mock_image.data_spatial) @@ -240,5 +309,56 @@ def test_repr(mocker: MockerFixture, mock_image: MultiChannelImage) -> None: assert repr(mock_image) == "MultiChannelImage(data=DataArray)" +@pytest.mark.parametrize("bg_value", [0, 1, np.nan]) +def test_compute_is_foreground(bg_value: float): + a = bg_value - 1 if np.isfinite(bg_value) else 1 + array = DataArray( + [ + [[bg_value, bg_value], [a, a]], + [[3, bg_value], [bg_value, bg_value]], + [[bg_value, bg_value], [bg_value, bg_value]], + ], + dims=("y", "x", "c"), + ) + mask = MultiChannelImage._compute_is_foreground(array, bg_value=bg_value) + xarray.testing.assert_equal(DataArray([[False, True], [True, False], [False, False]], dims=("y", "x")), mask) + + +@pytest.mark.parametrize( + "input_coordinates", + [ + np.array([[1, 2], [3, 4]]), + xarray.DataArray([[1, 2], [3, 4]], dims=("i", "d"), coords={"d": ["x", "y"]}), + xarray.DataArray([[2, 1], [4, 3]], dims=("i", "d"), coords={"d": ["y", "x"]}), + xarray.DataArray([[1, 3], [2, 4]], dims=("d", "i"), coords={"d": ["x", "y"]}), + ], +) +def test_validate_coordinates(input_coordinates): + result = MultiChannelImage._validate_coordinates(input_coordinates) + xarray.testing.assert_equal(result, xarray.DataArray([[1, 2], [3, 4]], dims=("i", "d"), coords={"d": ["x", "y"]})) + + +@pytest.mark.parametrize( + "input_coordinates", + [ + xarray.DataArray([[1, 2], [3, 4]], dims=("i", "d"), coords={"d": ["x", "z"]}), + ], +) +def test_validate_coordinates_when_invalid(input_coordinates): + with pytest.raises(ValueError): + MultiChannelImage._validate_coordinates(input_coordinates) + + +def test_extract_flat_coordinates(mock_image_sparse): + data_flat = xarray.DataArray( + [[6.0, 8], [5, 5]], dims=("c", "i"), coords={"i": pd.MultiIndex.from_arrays(([0, 1], [1, 1]), names=("x", "y"))} + ) + coords = MultiChannelImage._extract_flat_coordinates(data_flat) + xarray.testing.assert_equal( + coords, + xarray.DataArray([[0, 1], [1, 1]], dims=("i", "d"), coords={"d": ["x", "y"]}), + ) + + if __name__ == "__main__": pytest.main() diff --git a/tests/unit/image/test_multi_channel_image_concatenation.py b/tests/unit/image/test_multi_channel_image_concatenation.py index 65c68fa..7c39f95 100644 --- a/tests/unit/image/test_multi_channel_image_concatenation.py +++ b/tests/unit/image/test_multi_channel_image_concatenation.py @@ -16,14 +16,12 @@ def channel_names() -> list[str]: def _construct_single_image(data: np.ndarray, channel_names: list[str]) -> MultiChannelImage: - return MultiChannelImage( - data=DataArray( - data=data, - dims=("y", "x", "c"), - coords={"c": channel_names}, - attrs={"bg_value": np.nan}, - ) + data = DataArray( + data=data, + dims=("y", "x", "c"), + coords={"c": channel_names}, ) + return MultiChannelImage(data=data, is_foreground=xarray.ones_like(data.isel(c=0), dtype=bool).drop_vars("c")) @pytest.fixture() @@ -61,13 +59,12 @@ def test_get_combined_image( combined_image = concat_image.get_combined_image() assert combined_image.data_spatial.shape == (3, 7, 2) assert combined_image.channel_names == image_0.channel_names == image_1.channel_names - assert np.isnan(combined_image.bg_value) # image 0 assert combined_image.data_spatial[0, 0, 1] == 1.0 # image 1 assert combined_image.data_spatial[0, 3, 1] == 3.0 - # check nan (because of shape differences) - assert np.isnan(combined_image.data_spatial[2, 0, 0]) + # check is background (because of shape differences) + assert combined_image.bg_mask[2, 0] def test_get_combined_image_index(concat_image: MultiChannelImageConcatenation) -> None: @@ -101,13 +98,13 @@ def test_relabel_combined_image( concat_image: MultiChannelImageConcatenation, ) -> None: new_data = np.ones((3, 7, 4)) + data = xarray.DataArray( + data=new_data, + dims=("y", "x", "c"), + coords={"c": ["A", "B", "C", "D"]}, + ) new_combined_image = MultiChannelImage( - data=xarray.DataArray( - data=new_data, - dims=("y", "x", "c"), - coords={"c": ["A", "B", "C", "D"]}, - attrs={"bg_value": np.nan}, - ) + data=data, is_foreground=xarray.ones_like(data.isel(c=0), dtype=bool).drop_vars("c") ) relabeled_image = concat_image.relabel_combined_image(new_combined_image) assert relabeled_image.get_combined_image().channel_names == ["A", "B", "C", "D"] @@ -118,13 +115,13 @@ def test_relabel_combined_image_different_shape( concat_image: MultiChannelImageConcatenation, channel_names: list[str] ) -> None: new_data = np.ones((4, 8, len(channel_names))) * 5.0 + data = xarray.DataArray( + data=new_data, + dims=("y", "x", "c"), + coords={"c": channel_names}, + ) new_combined_image = MultiChannelImage( - data=xarray.DataArray( - data=new_data, - dims=("y", "x", "c"), - coords={"c": channel_names}, - attrs={"bg_value": np.nan}, - ) + data=data, is_foreground=xarray.ones_like(data.isel(c=0), dtype=bool).drop_vars("c") ) with pytest.raises(ValueError, match="The new image must have the same shape as the original combined image"): diff --git a/tests/unit/image/test_sparse_representation.py b/tests/unit/image/test_sparse_representation.py index 4286e71..d722fb9 100644 --- a/tests/unit/image/test_sparse_representation.py +++ b/tests/unit/image/test_sparse_representation.py @@ -1,46 +1,53 @@ -import unittest +from enum import Enum import numpy as np +import pytest import xarray.testing from xarray import DataArray from depiction.image.sparse_representation import SparseRepresentation -class TestSparseRepresentation(unittest.TestCase): - def setUp(self) -> None: - self.define_samples() +class Variant(Enum): + """The samples are all images, and commented for clarity in the following coordinate system + ^ +y + + -> +x + """ - def define_samples(self) -> None: - """The samples are all images, and commented for clarity in the following coordinate system - ^ +y - + -> +x - """ - # Simple example: + Simple = "1_simple" + MultiChannel = "2_multi" + Offset = "3_offset" + + +@pytest.fixture() +def sample(request): + if request.param == Variant.Simple: # +----+----+ # | 5 | NA | # +----+----+ # | 4 | 6 | # +----+----+ - self.dense_1_simple = DataArray([[[4.0], [6]], [[5], [np.nan]]], dims=["y", "x", "c"]) - self.sparse_1_simple = ( + return ( + DataArray([[[4.0], [6]], [[5], [np.nan]]], dims=["y", "x", "c"], coords={"x": [0, 1], "y": [0, 1]}), DataArray([[4.0], [5], [6]], dims=["i", "c"]), DataArray([[0, 0], [0, 1], [1, 0]], dims=["i", "d"], coords={"d": ["x", "y"]}), ) - - # Multi-channel example + elif request.param == Variant.MultiChannel: # +-------+--------+ # | 5, 15 | NA, NA | # +-------+--------+ # | 4, 14 | 6, 16 | # +-------+--------+ - self.dense_2_multi = DataArray([[[4, 14], [6, 16]], [[5.0, 15], [np.nan, np.nan]]], dims=["y", "x", "c"]) - self.sparse_2_multi = ( + return ( + DataArray( + [[[4, 14], [6, 16]], [[5.0, 15], [np.nan, np.nan]]], + dims=["y", "x", "c"], + coords={"x": [0, 1], "y": [0, 1]}, + ), DataArray([[4.0, 14], [5, 15], [6, 16]], dims=["i", "c"]), DataArray([[0, 0], [0, 1], [1, 0]], dims=["i", "d"], coords={"d": ["x", "y"]}), ) - - # Offset example + elif request.param == Variant.Offset: # +----+----+ # | NA | 5 | # +----+----+ @@ -48,90 +55,110 @@ def define_samples(self) -> None: # +----+----+ # | NA | NA | # +----+----+ - self.dense_3_offset = DataArray( - [[[np.nan], [5]], [[np.nan], [np.nan]], [[np.nan], [np.nan]]], dims=["y", "x", "c"] - ) - self.sparse_3_offset = ( + return ( + DataArray( + [[[np.nan], [5]], [[np.nan], [np.nan]], [[np.nan], [np.nan]]], + dims=["y", "x", "c"], + coords={"x": [0, 1], "y": [0, 1, 2]}, + ), DataArray([[5]], dims=["i", "c"]), DataArray([[1, 2]], dims=["i", "d"], coords={"d": ["x", "y"]}), ) - def test_dense_to_sparse_when_simple(self) -> None: - values, coords = SparseRepresentation.dense_to_sparse(grid_values=self.dense_1_simple, bg_value=np.nan) - xarray.testing.assert_equal(self.sparse_1_simple[1], coords) - xarray.testing.assert_equal(self.sparse_1_simple[0], values) - - def test_dense_to_sparse_when_multi_channel(self) -> None: - values, coords = SparseRepresentation.dense_to_sparse(grid_values=self.dense_2_multi, bg_value=np.nan) - xarray.testing.assert_equal(self.sparse_2_multi[1], coords) - xarray.testing.assert_equal(self.sparse_2_multi[0], values) - - # def test_dense_to_sparse_when_offset(self) -> None: - # values, coords = SparseRepresentation.dense_to_sparse(grid_values=self.dense_3_offset, bg_value=np.nan) - # xarray.testing.assert_equal(self.sparse_3_offset[1], coords) - # xarray.testing.assert_equal(self.sparse_3_offset[0], values) - - # def test_sparse_to_dense_when_real_bg(self): - # coordinates = DataArray([[0, 0], [1, 1], [1, 0]], dims=["i", "d"]) - # sparse_values = DataArray([[2, 3, 4]], dims=["c", "i"]) - # dense_values = SparseRepresentation.sparse_to_dense(sparse_values, coordinates, background_value=0) - # expected_array = DataArray([[[2], [0]], [[4], [3]]], dims=["y", "x", "c"]) - # xarray.testing.assert_equal(expected_array, dense_values) - - # def test_sparse_to_dense_when_nan_bg(self): - # coordinates = DataArray([[0, 0], [1, 1], [1, 0]], dims=["i", "d"]) - # sparse_values = DataArray([[2, 3, 4]], dims=["c", "i"]) - # dense_values = SparseRepresentation.sparse_to_dense(sparse_values, coordinates, background_value=np.nan) - # expected_array = DataArray([[[2], [np.nan]], [[4], [3]]], dims=["y", "x", "c"]) - # xarray.testing.assert_equal(expected_array, dense_values) - - # def test_sparse_to_dense_when_multi_channel(self): - # coordinates = DataArray([[0, 0], [1, 1], [1, 0]], dims=["i", "d"]) - # sparse_values = DataArray([[1, 2], [3, 4], [5, 6]], dims=["i", "c"]) - # dense_values = SparseRepresentation.sparse_to_dense(sparse_values, coordinates, background_value=0) - # expected_array = DataArray([[[1, 2], [0, 0]], [[5, 6], [3, 4]]], dims=["y", "x", "c"]) - # xarray.testing.assert_equal(expected_array, dense_values) - - # def test_dense_to_sparse_when_real_bg(self): - # dense_values = DataArray([[[2], [0]], [[4], [3]]], dims=["y", "x", "c"]) - # sparse_values, coordinates = SparseRepresentation.dense_to_sparse(dense_values, bg_value=0) - # expected_sparse_values = DataArray([[2], [4], [3]], dims=["i", "c"]) - # expected_coordinates = DataArray([[0, 0], [1, 0], [1, 1]], dims=["i", "d"]) - # xarray.testing.assert_equal(expected_sparse_values, sparse_values) - # xarray.testing.assert_equal(expected_coordinates, coordinates) - - # def test_dense_to_sparse_when_nan_bg(self): - # dense_values = DataArray([[[2], [np.nan]], [[4], [3]]], dims=["y", "x", "c"]) - # sparse_values, coordinates = SparseRepresentation.dense_to_sparse(dense_values, bg_value=np.nan) - # expected_sparse_values = DataArray([[2], [4], [3]], dims=["i", "c"]) - # expected_coordinates = DataArray([[0, 0], [1, 0], [1, 1]], dims=["i", "d"]) - # xarray.testing.assert_equal(expected_sparse_values, sparse_values) - # xarray.testing.assert_equal(expected_coordinates, coordinates) - - # def test_dense_to_sparse_when_none_bg(self): - # dense_values = DataArray([[[2], [0]], [[4], [3]]], dims=["y", "x", "c"]) - # sparse_values, coordinates = SparseRepresentation.dense_to_sparse(dense_values, bg_value=None) - # expected_sparse_values = DataArray([[2], [0], [4], [3]], dims=["i", "c"]) - # expected_coordinates = DataArray([[0, 0], [0, 1], [1, 0], [1, 1]], dims=["i", "d"]) - # xarray.testing.assert_equal(expected_sparse_values, sparse_values) - # xarray.testing.assert_equal(expected_coordinates, coordinates) - - # def test_dense_to_sparse_coords(self): - # dense_values = DataArray([[[2], [0]], [[4], [3]]], dims=["y", "x", "c"]) - # coordinates = DataArray([[0, 0], [0, 1]]) - # # sparse_values = - - # def test_round_trip(self): - # # NOTE: This is not really a unit test (TODO keep it?) - # sparse_values = DataArray([[2, 3], [4, 5], [6, 7]], dims=["i", "c"]) - # coordinates = DataArray([[0, 0], [1, 0], [2, 1]], dims=["i", "d"]) - - # dense_values = SparseRepresentation.sparse_to_dense(sparse_values, coordinates, background_value=0) - # new_sparse_values, new_coordinates = SparseRepresentation.dense_to_sparse(dense_values, bg_value=0) - - # xarray.testing.assert_equal(sparse_values, new_sparse_values) - # xarray.testing.assert_equal(coordinates, new_coordinates) - - -if __name__ == "__main__": - unittest.main() + +@pytest.fixture() +def spatial_array(sample) -> DataArray: + return sample[0] + + +@pytest.fixture() +def flat_data(sample) -> DataArray: + return sample[1] + + +@pytest.fixture() +def flat_coords(sample) -> DataArray: + return sample[2] + + +@pytest.mark.parametrize("sample", [Variant.Simple, Variant.MultiChannel], indirect=True) +def test_flat_to_spatial(flat_data, flat_coords, spatial_array): + result_values, result_is_fg = SparseRepresentation.flat_to_spatial(flat_data, flat_coords, bg_value=np.nan) + xarray.testing.assert_equal(result_values, spatial_array) + xarray.testing.assert_equal(result_is_fg, spatial_array.isel(c=0).notnull()) + + +@pytest.mark.parametrize("sample", [Variant.Offset], indirect=True) +def test_flat_to_spatial_when_offset(flat_data, flat_coords, spatial_array): + result_values, result_is_fg = SparseRepresentation.flat_to_spatial(flat_data, flat_coords, bg_value=np.nan) + xarray.testing.assert_equal( + result_values, xarray.DataArray([[[5]]], dims=["y", "x", "c"], coords={"x": [1], "y": [2]}) + ) + xarray.testing.assert_equal(result_is_fg, xarray.DataArray([[True]], dims=["y", "x"], coords={"x": [1], "y": [2]})) + + +# def test_dense_to_sparse_when_offset(self) -> None: +# values, coords = SparseRepresentation.dense_to_sparse(grid_values=self.dense_3_offset, bg_value=np.nan) +# xarray.testing.assert_equal(self.sparse_3_offset[1], coords) +# xarray.testing.assert_equal(self.sparse_3_offset[0], values) + +# def test_sparse_to_dense_when_real_bg(self): +# coordinates = DataArray([[0, 0], [1, 1], [1, 0]], dims=["i", "d"]) +# sparse_values = DataArray([[2, 3, 4]], dims=["c", "i"]) +# dense_values = SparseRepresentation.sparse_to_dense(sparse_values, coordinates, background_value=0) +# expected_array = DataArray([[[2], [0]], [[4], [3]]], dims=["y", "x", "c"]) +# xarray.testing.assert_equal(expected_array, dense_values) + +# def test_sparse_to_dense_when_nan_bg(self): +# coordinates = DataArray([[0, 0], [1, 1], [1, 0]], dims=["i", "d"]) +# sparse_values = DataArray([[2, 3, 4]], dims=["c", "i"]) +# dense_values = SparseRepresentation.sparse_to_dense(sparse_values, coordinates, background_value=np.nan) +# expected_array = DataArray([[[2], [np.nan]], [[4], [3]]], dims=["y", "x", "c"]) +# xarray.testing.assert_equal(expected_array, dense_values) + +# def test_sparse_to_dense_when_multi_channel(self): +# coordinates = DataArray([[0, 0], [1, 1], [1, 0]], dims=["i", "d"]) +# sparse_values = DataArray([[1, 2], [3, 4], [5, 6]], dims=["i", "c"]) +# dense_values = SparseRepresentation.sparse_to_dense(sparse_values, coordinates, background_value=0) +# expected_array = DataArray([[[1, 2], [0, 0]], [[5, 6], [3, 4]]], dims=["y", "x", "c"]) +# xarray.testing.assert_equal(expected_array, dense_values) + +# def test_dense_to_sparse_when_real_bg(self): +# dense_values = DataArray([[[2], [0]], [[4], [3]]], dims=["y", "x", "c"]) +# sparse_values, coordinates = SparseRepresentation.dense_to_sparse(dense_values, bg_value=0) +# expected_sparse_values = DataArray([[2], [4], [3]], dims=["i", "c"]) +# expected_coordinates = DataArray([[0, 0], [1, 0], [1, 1]], dims=["i", "d"]) +# xarray.testing.assert_equal(expected_sparse_values, sparse_values) +# xarray.testing.assert_equal(expected_coordinates, coordinates) + +# def test_dense_to_sparse_when_nan_bg(self): +# dense_values = DataArray([[[2], [np.nan]], [[4], [3]]], dims=["y", "x", "c"]) +# sparse_values, coordinates = SparseRepresentation.dense_to_sparse(dense_values, bg_value=np.nan) +# expected_sparse_values = DataArray([[2], [4], [3]], dims=["i", "c"]) +# expected_coordinates = DataArray([[0, 0], [1, 0], [1, 1]], dims=["i", "d"]) +# xarray.testing.assert_equal(expected_sparse_values, sparse_values) +# xarray.testing.assert_equal(expected_coordinates, coordinates) + +# def test_dense_to_sparse_when_none_bg(self): +# dense_values = DataArray([[[2], [0]], [[4], [3]]], dims=["y", "x", "c"]) +# sparse_values, coordinates = SparseRepresentation.dense_to_sparse(dense_values, bg_value=None) +# expected_sparse_values = DataArray([[2], [0], [4], [3]], dims=["i", "c"]) +# expected_coordinates = DataArray([[0, 0], [0, 1], [1, 0], [1, 1]], dims=["i", "d"]) +# xarray.testing.assert_equal(expected_sparse_values, sparse_values) +# xarray.testing.assert_equal(expected_coordinates, coordinates) + +# def test_dense_to_sparse_coords(self): +# dense_values = DataArray([[[2], [0]], [[4], [3]]], dims=["y", "x", "c"]) +# coordinates = DataArray([[0, 0], [0, 1]]) +# # sparse_values = + +# def test_round_trip(self): +# # NOTE: This is not really a unit test (TODO keep it?) +# sparse_values = DataArray([[2, 3], [4, 5], [6, 7]], dims=["i", "c"]) +# coordinates = DataArray([[0, 0], [1, 0], [2, 1]], dims=["i", "d"]) + +# dense_values = SparseRepresentation.sparse_to_dense(sparse_values, coordinates, background_value=0) +# new_sparse_values, new_coordinates = SparseRepresentation.dense_to_sparse(dense_values, bg_value=0) + +# xarray.testing.assert_equal(sparse_values, new_sparse_values) +# xarray.testing.assert_equal(coordinates, new_coordinates) diff --git a/tests/unit/image/test_spatial_smoothing.py b/tests/unit/image/test_spatial_smoothing.py index 25071c3..58e4e20 100644 --- a/tests/unit/image/test_spatial_smoothing.py +++ b/tests/unit/image/test_spatial_smoothing.py @@ -5,7 +5,7 @@ from sparse import GCXS from xarray import DataArray -from depiction.image.spatial_smoothing import SpatialSmoothing +from depiction.image.smoothing.spatial_smoothing import SpatialSmoothing class TestSpatialSmoothing(unittest.TestCase): diff --git a/tests/unit/image/test_spatial_smoothing_sparse_aware.py b/tests/unit/image/test_spatial_smoothing_sparse_aware.py index fcbbece..b1ca1af 100644 --- a/tests/unit/image/test_spatial_smoothing_sparse_aware.py +++ b/tests/unit/image/test_spatial_smoothing_sparse_aware.py @@ -3,11 +3,14 @@ import hypothesis import numpy as np import pytest +import xarray.testing from hypothesis import given, strategies from sparse import GCXS from xarray import DataArray -from depiction.image.spatial_smoothing_sparse_aware import SpatialSmoothingSparseAware +from depiction.image import MultiChannelImage +from depiction.image.smoothing.spatial_smoothing_sparse_aware import SpatialSmoothingSparseAware +from depiction.image.xarray_helper import XarrayHelper @pytest.fixture(autouse=True) @@ -47,14 +50,28 @@ def _convert_array(arr: DataArray, variant: str) -> DataArray: return DataArray(values, dims=arr.dims, coords=arr.coords, attrs=arr.attrs, name=arr.name) -@pytest.mark.parametrize("variant", ["dense", "sparse"]) -def test_smooth_when_unchanged(mock_smooth, variant): - image = DataArray(np.concatenate([np.ones((5, 5, 1)), np.zeros((5, 5, 1))], axis=0), dims=("y", "x", "c")) - image = _convert_array(image, variant) - smoothed = mock_smooth.smooth(image, bg_value=0) - np.testing.assert_array_almost_equal(1.0, smoothed.values[:5, :, 0], decimal=8) - np.testing.assert_array_almost_equal(0.0, smoothed.values[-5:, :, 0], decimal=8) - assert smoothed.dims == ("y", "x", "c") +@pytest.fixture() +def dense_data(): + return DataArray( + np.concatenate([np.ones((5, 5, 1)), np.zeros((5, 5, 1))], axis=0), + dims=("y", "x", "c"), + coords={"c": ["channel"]}, + ) + + +@pytest.fixture(params=["dense", "sparse"]) +def image(request, dense_data): + # TODO the whole idea of the old test of using nan and zero is not so relevant anymore, maybe it can be simplified + data = _convert_array(dense_data, request.param) + is_fg = data.isel(c=0) != 0 + return MultiChannelImage(data=data, is_foreground=is_fg) + + +def test_smooth_image_when_unchanged(mock_smooth, image): + smoothed = mock_smooth.smooth_image(image) + xarray.testing.assert_equal(smoothed.fg_mask, image.fg_mask) + np.testing.assert_array_almost_equal(1.0, smoothed.data_spatial[:5, :, 0], decimal=8) + np.testing.assert_array_almost_equal(0.0, smoothed.data_spatial[-5:, :, 0], decimal=8) @pytest.mark.parametrize("variant", ["dense", "sparse"]) @@ -63,44 +80,55 @@ def test_smooth_when_unchanged(mock_smooth, variant): deadline=timedelta(seconds=1), suppress_health_check=[hypothesis.HealthCheck.function_scoped_fixture] ) def test_smooth_preserves_values(mock_smooth, fill_value, variant): - dense_image = DataArray(np.full((2, 5, 1), fill_value=fill_value), dims=("y", "x", "c")) - image = _convert_array(dense_image, variant) - smoothed = mock_smooth.smooth(image, bg_value=0) - np.testing.assert_allclose(dense_image.values, smoothed.values, rtol=1e-8) - - -@pytest.mark.parametrize("variant", ["dense", "sparse"]) -def test_smooth_when_bg_nan(mock_smooth, variant): - dense_image = DataArray( - np.concatenate([np.full((5, 5, 1), np.nan), np.ones((5, 5, 1)), np.zeros((5, 5, 1))]), - dims=("y", "x", "c"), - ) - image = _convert_array(dense_image, variant) - smoothed = mock_smooth.smooth(image, bg_value=np.nan) - np.testing.assert_array_equal(np.nan, smoothed.values[:5, 0]) - np.testing.assert_array_almost_equal(1.0, smoothed.values[5:8, :, 0], decimal=8) - np.testing.assert_array_almost_equal(0.0, smoothed.values[-3:, :, 0], decimal=8) - smoothed_part = smoothed.values[8:12, :, 0] - for i_value, value in enumerate([0.94551132, 0.70130997, 0.29869003, 0.05448868]): - np.testing.assert_array_almost_equal(value, smoothed_part[i_value, :], decimal=8) + dense_data = DataArray(np.full((2, 5, 1), fill_value=fill_value), dims=("y", "x", "c"), coords={"c": ["channel"]}) + is_foreground = DataArray(np.full((2, 5), fill_value=True), dims=("y", "x")) + image = MultiChannelImage(data=_convert_array(dense_data, variant), is_foreground=is_foreground) + smoothed = mock_smooth.smooth_image(image) + xarray.testing.assert_equal(smoothed.fg_mask, image.fg_mask) + xarray.testing.assert_allclose(smoothed.data_spatial, dense_data, rtol=1e-8) + + +# @pytest.mark.parametrize("variant", ["dense", "sparse"]) +# def test_smooth_when_bg_nan(mock_smooth, variant): +# dense_image = DataArray( +# np.concatenate([np.full((5, 5, 1), np.nan), np.ones((5, 5, 1)), np.zeros((5, 5, 1))]), +# dims=("y", "x", "c"), +# ) +# image = _convert_array(dense_image, variant) +# smoothed = mock_smooth.smooth(image, bg_value=np.nan) +# np.testing.assert_array_equal(np.nan, smoothed.values[:5, 0]) +# np.testing.assert_array_almost_equal(1.0, smoothed.values[5:8, :, 0], decimal=8) +# np.testing.assert_array_almost_equal(0.0, smoothed.values[-3:, :, 0], decimal=8) +# smoothed_part = smoothed.values[8:12, :, 0] +# for i_value, value in enumerate([0.94551132, 0.70130997, 0.29869003, 0.05448868]): +# np.testing.assert_array_almost_equal(value, smoothed_part[i_value, :], decimal=8) @pytest.mark.parametrize("variant", ["dense", "sparse"]) def test_smooth_casts_when_integer(mock_smooth, variant): - image_dense = DataArray(np.full((2, 5, 1), fill_value=10, dtype=int), dims=("y", "x", "c")) - image = _convert_array(image_dense, variant) - res_values = mock_smooth.smooth(image=image) + data_full = DataArray(np.full((2, 5, 1), fill_value=10, dtype=int), dims=("y", "x", "c"), coords={"c": ["channel"]}) + is_foreground = DataArray(np.full((2, 5), fill_value=True), dims=("y", "x")) + image = MultiChannelImage(data=_convert_array(data_full, variant), is_foreground=is_foreground) + res_values = mock_smooth.smooth_image(image=image) assert res_values.dtype == np.float64 - np.testing.assert_allclose(image_dense.values, res_values.values, rtol=1e-8) + if variant == "dense": + xarray.testing.assert_allclose(res_values.data_spatial, image.data_spatial) + else: + np.testing.assert_allclose( + XarrayHelper.ensure_dense(res_values.data_spatial).values, + XarrayHelper.ensure_dense(image.data_spatial).values, + ) @pytest.mark.parametrize("mock_use_interpolation", [True]) def test_smooth_dense_when_use_interpolation(mock_smooth): - mock_image = np.full((9, 5), fill_value=3.0) - mock_image[4, 2] = np.nan - smoothed = mock_smooth._smooth_dense(image_2d=mock_image, bg_value=np.nan) - assert np.sum(np.isnan(smoothed)) == 0 - np.testing.assert_almost_equal(smoothed[4, 2], 3, decimal=6) + mock_data = DataArray(np.full((9, 5, 1), fill_value=3.0), dims=("y", "x", "c"), coords={"c": ["channel"]}) + mock_data[4, 2, 0] = np.nan + is_foreground = ~mock_data.isel(c=0).isnull() + mock_image = MultiChannelImage(data=mock_data, is_foreground=is_foreground) + smoothed = mock_smooth.smooth_image(mock_image) + assert np.sum(smoothed.bg_mask) == 0 + np.testing.assert_almost_equal(smoothed.data_spatial[4, 2], 3, decimal=6) @pytest.mark.parametrize("mock_kernel_size, mock_kernel_std", [(3, 1.0)]) diff --git a/tests/unit/tools/simulate/test_generate_label_image.py b/tests/unit/tools/simulate/test_generate_label_image.py index d9ee696..52c0945 100644 --- a/tests/unit/tools/simulate/test_generate_label_image.py +++ b/tests/unit/tools/simulate/test_generate_label_image.py @@ -57,6 +57,8 @@ def test_render(generate) -> None: ] ) ] + generate._image_height = 2 + generate._image_width = 2 image = generate.render() assert image.n_channels == 3 assert image.dimensions == (2, 2) diff --git a/tests/unit/tools/test_generate_ion_image.py b/tests/unit/tools/test_generate_ion_image.py index 2b2d179..22ac3c6 100644 --- a/tests/unit/tools/test_generate_ion_image.py +++ b/tests/unit/tools/test_generate_ion_image.py @@ -21,9 +21,7 @@ def mock_generate(mock_parallel_config: MagicMock) -> GenerateIonImage: def test_generate_ion_images_for_file(mocker, mock_generate: GenerateIonImage) -> None: mock_generate_channel_values = mocker.patch.object(GenerateIonImage, "_generate_channel_values") - mock_generate_channel_values.return_value = DataArray( - [[1, 2], [3, 4], [5, 6]], dims=("i", "c"), attrs={"bg_value": np.nan} - ) + mock_generate_channel_values.return_value = DataArray([[1, 2], [3, 4], [5, 6]], dims=("i", "c")) mock_input_file = MagicMock(name="mock_input_file", coordinates_2d=np.array([[0, 0], [0, 1], [1, 0]])) mock_mz_values = MagicMock(name="mock_mz_values", spec=[]) @@ -56,9 +54,7 @@ def test_generate_channel_values(mocker, mock_generate: GenerateIonImage, mock_p tol = [0.25, 0.5, 0.25] values = mock_generate._generate_channel_values(input_file=mock_input_file, mz_values=mock_mz_values, tol=tol) - xarray.testing.assert_identical( - values, DataArray(np.array([[1.0, 2], [3, 4]]), dims=("i", "c"), attrs={"bg_value": np.nan}) - ) + xarray.testing.assert_identical(values, DataArray(np.array([[1.0, 2], [3, 4]]), dims=("i", "c"))) mock_read_parallel_from.assert_called_once_with(mock_parallel_config) mock_read_parallel_from.return_value.map_chunked.assert_called_once_with( read_file=mock_input_file, @@ -107,7 +103,7 @@ def test_generate_range_images_for_file(mocker, mock_generate: GenerateIonImage, method_compute_for_mz_ranges = mocker.patch.object(GenerateIonImage, "_compute_for_mz_ranges") mock_multi_channel_image = mocker.patch("depiction.tools.generate_ion_image.MultiChannelImage") - mock_input_file = MagicMock(name="input_file", spec=["coordinates_2d"]) + mock_input_file = MagicMock(name="input_file", spec=["coordinates_array_2d"]) mock_mz_ranges = MagicMock(name="mz_ranges", spec=[]) result = mock_generate.generate_range_images_for_file( @@ -128,13 +124,12 @@ def test_generate_range_images_for_file(mocker, mock_generate: GenerateIonImage, reduced = reduce_fn([np.array([[1], [2]]), np.array([[3], [4]])]) np.testing.assert_array_equal(np.array([[1], [2], [3], [4]]), reduced) - mock_multi_channel_image.from_sparse.assert_called_once_with( + mock_multi_channel_image.from_flat.assert_called_once_with( values=mock_parallelize.map_chunked.return_value, - coordinates=mock_input_file.coordinates_2d, + coordinates=mock_input_file.coordinates_array_2d, channel_names=None, - bg_value=np.nan, ) - assert result == mock_multi_channel_image.from_sparse.return_value + assert result == mock_multi_channel_image.from_flat.return_value if __name__ == "__main__":