Skip to content

Commit

Permalink
Update MultiChannelImage definition to explicitly manage background/foreground information.
Browse files Browse the repository at this point in the history

This is a far-reaching change with many side effects probably not yet identified.
  • Loading branch information
leoschwarz committed Oct 15, 2024
1 parent 0954621 commit 8d0ac05
Show file tree
Hide file tree
Showing 41 changed files with 989 additions and 552 deletions.
1 change: 0 additions & 1 deletion noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
nox.options.default_venv_backend = "uv"


@nox.session(reuse_venv=True)
def lint(session: nox.Session) -> None:
"""Runs the linter."""
session.install("pre-commit")
Expand Down
19 changes: 12 additions & 7 deletions src/depiction/calibration/calibration_method.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
from typing import Protocol
from __future__ import annotations
from typing import Protocol, TYPE_CHECKING

from numpy.typing import NDArray
from xarray import DataArray

if TYPE_CHECKING:
from depiction.image.multi_channel_image import MultiChannelImage


class CalibrationMethod(Protocol):
"""Defines the interface for a spectrum calibration method."""
Expand All @@ -16,14 +20,15 @@ def extract_spectrum_features(self, peak_mz_arr: NDArray[float], peak_int_arr: N
"""
return DataArray([], dims=["c"])

def preprocess_image_features(self, all_features: DataArray) -> DataArray:
# TODO update doc
# TODO update other methods
# TODO check if it works
def preprocess_image_features(self, all_features: MultiChannelImage) -> MultiChannelImage:
"""Preprocesses the extracted features from all spectra in an image.
For example, image-wide smoothing of the features could be applied here.
If no preprocessing is necessary, the input DataArray should be returned.
:param all_features: a DataArray with the extracted features, with dimensions ["i", "c"]
and coordinates ["i", "x", "y"] for dimension "i"
:return: a DataArray with the preprocessed features, with dimensions ["i", "c"]
and coordinates ["i", "x", "y"] for dimension "i"
If no preprocessing is necessary, the input MultiChannelImage should be returned.
:param all_features: a MultiChannelImage with the extracted features
:return: a MultiChannelImage with the preprocessed features
"""
return all_features

Expand Down
94 changes: 24 additions & 70 deletions src/depiction/calibration/perform_calibration.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from xarray import DataArray

from depiction.calibration.calibration_method import CalibrationMethod
from depiction.image import MultiChannelImage
from depiction.parallel_ops import ParallelConfig, ReadSpectraParallel, WriteSpectraParallel
from depiction.parallel_ops.parallel_map import ParallelMap
from depiction.persistence.types import GenericReadFile, GenericWriteFile, GenericReader, GenericWriter
Expand All @@ -26,77 +27,27 @@ def __init__(
self._parallel_config = parallel_config
self._coefficient_output_file = coefficient_output_file

# def _reshape(self, pattern: str, data: DataArray, coordinates) -> DataArray:
# if pattern == "i,c->y,x,c":
# data = data.copy()
# # TODO fix the deprecation here!
# data["i"] = pd.MultiIndex.from_arrays((coordinates[:, 1], coordinates[:, 0]), names=("y", "x"))
# data = data.unstack("i")
# return data.transpose("y", "x", "c")
# elif pattern == "y,x,c->i,c":
# data = data.transpose("y", "x", "c").copy()
# data = data.stack(i=("y", "x")).drop_vars(["i", "x", "y"])
# # convert to integers
# data["i"] = np.arange(len(data["i"]))
# return data.transpose("i", "c")
# else:
# raise ValueError(f"Unknown pattern={repr(pattern)}")

def _validate_per_spectra_array(self, array: DataArray, coordinates_2d: NDArray[float]) -> None:
"""Checks the DataArray has the correct shapes and dimensions. Used for debugging."""
# TODO make it configurable in the future, whether this check is executed, during development it definitely
# should be here since it can save a ton of time
expected_coords = {"i", "x", "y"}
if set(array.coords) != expected_coords:
raise ValueError(f"Expected coords={expected_coords}, got={set(array.coords)}")
expected_dims = {"i", "c"}

errors = []
if set(array.dims) != expected_dims:
logger.error(f"Expected dims={expected_dims}, got={set(array.dims)}")
errors.append("Mismatch in dimensions")
if not np.array_equal(array.x.values, coordinates_2d[:, 0]):
logger.error(f"Expected x: values={coordinates_2d[:, 0]} shape={coordinates_2d[:, 0].shape}")
logger.error(f"Actual x: values={array.x.values} shape={array.x.values.shape}")
logger.info(f"(Expected x values without offset: {coordinates_2d[:, 0] - coordinates_2d[:, 0].min()})")
errors.append("Mismatch in x values")
if not np.array_equal(array.y.values, coordinates_2d[:, 1]):
logger.error(f"Expected y: values={coordinates_2d[:, 1]} shape={coordinates_2d[:, 1].shape}")
logger.error(f"Actual y: values={array.y.values} shape={array.y.values.shape}")
logger.info(f"(Expected y values without offset: {coordinates_2d[:, 1] - coordinates_2d[:, 1].min()})")
errors.append("Mismatch in y values")
if not np.array_equal(array.i.values, np.arange(len(array.i))):
errors.append("Mismatch in i values")
logger.error(f"Expected i: values={np.arange(len(array.i))} shape={np.arange(len(array.i)).shape}")
logger.error(f"Actual i: values={array.i.values} shape={array.i.values.shape}")
if errors:
raise ValueError(errors)

def calibrate_image(
self, read_peaks: GenericReadFile, write_file: GenericWriteFile, read_full: Optional[GenericReadFile] = None
) -> None:
if read_full is None:
read_full = read_peaks
read_full = read_full or read_peaks

logger.info("Extracting all features...")
all_features = self._extract_all_features(read_peaks).transpose("i", "c")
self._validate_per_spectra_array(all_features, coordinates_2d=read_peaks.coordinates_2d)
all_features = self._extract_all_features(read_peaks)
self._write_data_array(all_features, group="features_raw")

logger.info("Preprocessing features...")
all_features = self._calibration.preprocess_image_features(all_features=all_features).transpose("i", "c")
self._validate_per_spectra_array(all_features, coordinates_2d=read_peaks.coordinates_2d)
all_features = self._calibration.preprocess_image_features(all_features=all_features)
self._write_data_array(all_features, group="features_processed")

logger.info("Fitting models...")
model_coefs = self._fit_all_models(all_features=all_features).transpose("i", "c")
self._validate_per_spectra_array(model_coefs, coordinates_2d=read_peaks.coordinates_2d)
model_coefs = self._fit_all_models(all_features=all_features)
self._write_data_array(model_coefs, group="model_coefs")

logger.info("Applying models...")
self._apply_all_models(read_file=read_full, write_file=write_file, all_model_coefs=model_coefs)

def _extract_all_features(self, read_peaks: GenericReadFile) -> DataArray:
def _extract_all_features(self, read_peaks: GenericReadFile) -> MultiChannelImage:
read_parallel = ReadSpectraParallel.from_config(self._parallel_config)
all_features = read_parallel.map_chunked(
read_file=read_peaks,
Expand All @@ -106,12 +57,14 @@ def _extract_all_features(self, read_peaks: GenericReadFile) -> DataArray:
),
reduce_fn=lambda chunks: xarray.concat(chunks, dim="i"),
)
return all_features.assign_coords(
x=("i", read_peaks.coordinates_2d[:, 0]), y=("i", read_peaks.coordinates_2d[:, 1])
return MultiChannelImage.from_flat(
values=all_features,
coordinates=read_peaks.coordinates_array_2d,
channel_names="c" not in all_features.coords,
)

def _apply_all_models(
self, read_file: GenericReadFile, write_file: GenericWriteFile, all_model_coefs: DataArray
self, read_file: GenericReadFile, write_file: GenericWriteFile, all_model_coefs: MultiChannelImage
) -> None:
write_parallel = WriteSpectraParallel.from_config(self._parallel_config)
write_parallel.map_chunked_to_file(
Expand All @@ -124,31 +77,32 @@ def _apply_all_models(
),
)

def _fit_all_models(self, all_features: DataArray) -> DataArray:
def _fit_all_models(self, all_features: MultiChannelImage) -> MultiChannelImage:
parallel_map = ParallelMap.from_config(self._parallel_config)
# TODO to be refactored
all_features_flat = all_features.data_flat
result = parallel_map(
operation=self._fit_chunk_models,
tasks=np.array_split(all_features.coords["i"], self._parallel_config.n_jobs),
tasks=np.array_split(all_features_flat.coords["i"], self._parallel_config.n_jobs),
reduce_fn=lambda chunks: xarray.concat(chunks, dim="i"),
bind_kwargs={"all_features": all_features},
bind_kwargs={"all_features": all_features_flat},
)
return MultiChannelImage.from_flat(
result, coordinates=all_features.coordinates_flat, channel_names="c" not in result.coords
)
return result

def _fit_chunk_models(self, spectra_indices: NDArray[int], all_features: DataArray) -> DataArray:
collect = []
for spectrum_id in spectra_indices:
features = all_features.sel(i=spectrum_id)
model_coef = self._calibration.fit_spectrum_model(features=features)
collect.append(model_coef)
combined = xarray.concat(collect, dim="i")
combined.coords["i"] = spectra_indices
return combined
return xarray.concat(collect, dim="i")

def _write_data_array(self, array: DataArray, group: str) -> None:
def _write_data_array(self, image: MultiChannelImage, group: str) -> None:
if not self._coefficient_output_file:
return
# TODO engine should not be necessary, but using it for debugging
array.to_netcdf(path=self._coefficient_output_file, group=group, format="NETCDF4", engine="netcdf4", mode="a")
image.write_hdf5(path=self._coefficient_output_file, mode="a", group=group)

@staticmethod
def _extract_chunk_features(
Expand All @@ -171,12 +125,12 @@ def _calibrate_spectra(
spectra_indices: list[int],
writer: GenericWriter,
calibration: CalibrationMethod,
all_model_coefs: DataArray,
all_model_coefs: MultiChannelImage,
) -> None:
for spectrum_id in spectra_indices:
# TODO sanity check the usage of i as spectrum_id (i.e. check the coords!)
mz_arr, int_arr, coords = reader.get_spectrum_with_coords(spectrum_id)
features = all_model_coefs.sel(i=spectrum_id)
features = all_model_coefs.data_flat.isel(i=spectrum_id)
calib_mz_arr, calib_int_arr = calibration.apply_spectrum_model(
spectrum_mz_arr=mz_arr, spectrum_int_arr=int_arr, model_coef=features
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from depiction.calibration.calibration_method import CalibrationMethod
from depiction.calibration.chemical_noise_bg_2019_boskamp_v2 import ChemicalNoiseCalibration
from depiction.image import MultiChannelImage


class CalibrationMethodChemicalPeptideNoise(CalibrationMethod):
Expand All @@ -30,7 +31,7 @@ def extract_spectrum_features(self, peak_mz_arr: NDArray[float], peak_int_arr: N
# return DataArray(shifts_arr, dims=["c"])
return DataArray([], dims=["c"])

def preprocess_image_features(self, all_features: DataArray) -> DataArray:
def preprocess_image_features(self, all_features: MultiChannelImage) -> MultiChannelImage:
# TODO no smoothing applied for now, but could be added (just, avoid duplication with the RegressShift)
return all_features

Expand Down
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
from depiction.calibration.calibration_method import CalibrationMethod
from numpy.typing import NDArray
from xarray import DataArray

from depiction.calibration.calibration_method import CalibrationMethod
from depiction.image import MultiChannelImage


class CalibrationMethodDummy(CalibrationMethod):
"""Returns the input data and creates some dummy coefficients to ensure compatibility."""

def extract_spectrum_features(self, peak_mz_arr: NDArray[float], peak_int_arr: NDArray[float]) -> DataArray:
return DataArray([0], dims=["c"])

def preprocess_image_features(self, all_features: DataArray) -> DataArray:
def preprocess_image_features(self, all_features: MultiChannelImage) -> MultiChannelImage:
return all_features

def fit_spectrum_model(self, features: DataArray) -> DataArray:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from depiction.calibration.calibration_method import CalibrationMethod
from depiction.calibration.spectrum.reference_peak_distances import ReferencePeakDistances
from depiction.image import MultiChannelImage


class CalibrationMethodGlobalConstantShift(CalibrationMethod):
Expand All @@ -27,12 +28,17 @@ def extract_spectrum_features(self, peak_mz_arr: NDArray[float], peak_int_arr: N
)
return DataArray(distances_mz, dims=["c"])

def preprocess_image_features(self, all_features: DataArray) -> DataArray:
def preprocess_image_features(self, all_features: MultiChannelImage) -> MultiChannelImage:
# we compute the actual global distance here
global_distance = np.nanmedian(all_features.values.ravel())
global_distance = np.nanmedian(all_features.data_flat.ravel())
# create one copy per spectrum
n_spectra = all_features.sizes["i"]
return DataArray(np.full((n_spectra, 1), global_distance), dims=["i", "c"], coords=all_features.coords)
n_spectra = all_features.n_nonzero
return MultiChannelImage(
data=DataArray(
np.full((n_spectra, 1, 1), global_distance), dims=["y", "x", "c"], coords=all_features.coords
),
is_foreground=DataArray(np.ones((n_spectra, 1), dtype=bool), dims=["y", "x"], coords=all_features.coords),
)

def fit_spectrum_model(self, features: DataArray) -> DataArray:
return features
Expand Down
12 changes: 6 additions & 6 deletions src/depiction/calibration/spectrum/calibration_method_mcc.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
from xarray import DataArray

from depiction.calibration.calibration_method import CalibrationMethod
from depiction.calibration.spectrum.calibration_smoothing import smooth_image_features
from depiction.image import MultiChannelImage
from depiction.image.smoothing.spatial_smoothing_sparse_aware import SpatialSmoothingSparseAware


class CalibrationMethodMassClusterCenterModel(CalibrationMethod):
Expand Down Expand Up @@ -62,13 +63,12 @@ def compute_distance_from_MCC(self, delta: NDArray[float], l_none: float = 1.000
delta_lambda[i] = -1 + term1
return delta_lambda

def preprocess_image_features(self, all_features: DataArray) -> DataArray:
def preprocess_image_features(self, all_features: MultiChannelImage) -> MultiChannelImage:
if self._model_smoothing_activated:
return smooth_image_features(
all_features=all_features,
kernel_size=self._model_smoothing_kernel_size,
kernel_std=self._model_smoothing_kernel_std,
smoother = SpatialSmoothingSparseAware(
kernel_size=self._model_smoothing_kernel_size, kernel_std=self._model_smoothing_kernel_std
)
return smoother.smooth_image(all_features)
else:
return all_features

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@
from depiction.calibration.calibration_method import CalibrationMethod
from depiction.calibration.models import LinearModel
from depiction.calibration.models.fit_model import fit_model
from depiction.calibration.spectrum.calibration_smoothing import smooth_image_features
from depiction.calibration.spectrum.reference_peak_distances import ReferencePeakDistances
from depiction.image import MultiChannelImage
from depiction.image.smoothing.spatial_smoothing_sparse_aware import SpatialSmoothingSparseAware


class CalibrationMethodRegressShift(CalibrationMethod):
Expand Down Expand Up @@ -80,13 +81,12 @@ def extract_spectrum_features(self, peak_mz_arr: NDArray[float], peak_int_arr: N
else:
raise ValueError(f"Unknown unit={self._model_unit}")

def preprocess_image_features(self, all_features: DataArray) -> DataArray:
def preprocess_image_features(self, all_features: MultiChannelImage) -> MultiChannelImage:
if self._input_smoothing_activated:
return smooth_image_features(
all_features=all_features,
kernel_size=self._input_smoothing_kernel_size,
kernel_std=self._input_smoothing_kernel_std,
smoother = SpatialSmoothingSparseAware(
kernel_size=self._input_smoothing_kernel_size, kernel_std=self._input_smoothing_kernel_std
)
return smoother.smooth_image(all_features)
else:
return all_features

Expand Down
30 changes: 0 additions & 30 deletions src/depiction/calibration/spectrum/calibration_smoothing.py

This file was deleted.

9 changes: 4 additions & 5 deletions src/depiction/clustering/remap_clusters.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,10 @@ def remap_cluster_labels(
:param mapping: a dictionary mapping the old cluster labels to the new ones
:param cluster_channel: the name of the channel with the cluster labels
"""
with xarray.set_options(keep_attrs=True):
relabeled = xarray.apply_ufunc(
lambda v: mapping.get(v, np.nan), image.data_spatial.sel(c=[cluster_channel]), vectorize=True
)
relabeled = xarray.apply_ufunc(
lambda v: mapping.get(v, np.nan), image.data_spatial.sel(c=[cluster_channel]), vectorize=True
)
img_relabeled = image.drop_channels(coords=[cluster_channel], allow_missing=False).append_channels(
MultiChannelImage(relabeled)
MultiChannelImage(relabeled, is_foreground=image.fg_mask, is_foreground_label=image.is_foreground_label)
)
return img_relabeled
Empty file.
Loading

0 comments on commit 8d0ac05

Please sign in to comment.