diff --git a/src/depiction/calibration/methods/calibration_method_constant_global_shift.py b/src/depiction/calibration/methods/calibration_method_constant_global_shift.py index 27af4f0..414bb81 100644 --- a/src/depiction/calibration/methods/calibration_method_constant_global_shift.py +++ b/src/depiction/calibration/methods/calibration_method_constant_global_shift.py @@ -1,4 +1,5 @@ import numpy as np +import xarray from numpy.typing import NDArray from xarray import DataArray @@ -30,16 +31,12 @@ def extract_spectrum_features(self, peak_mz_arr: NDArray[float], peak_int_arr: N def preprocess_image_features(self, all_features: MultiChannelImage) -> MultiChannelImage: # we compute the actual global distance here - global_distance = np.nanmedian(all_features.data_flat.ravel()) - # create one copy per spectrum - n_spectra = all_features.n_nonzero + global_distance = np.nanmedian(all_features.data_flat.data.ravel()) + # return one value per spectrum return MultiChannelImage( - data=DataArray( - np.full((n_spectra, 1, 1), global_distance), dims=["y", "x", "c"], coords=all_features.coordinates_flat - ), - is_foreground=DataArray( - np.ones((n_spectra, 1), dtype=bool), dims=["y", "x"], coords=all_features.coordinates_flat - ), + data=xarray.full_like(all_features.data_spatial.isel(c=[0]), global_distance), + is_foreground=all_features.fg_mask, + is_foreground_label=all_features.is_foreground_label, ) def fit_spectrum_model(self, features: DataArray) -> DataArray: