diff --git a/src/depiction/tools/process_spectra/evaluators.py b/src/depiction/tools/process_spectra/evaluators.py new file mode 100644 index 0000000..1129b28 --- /dev/null +++ b/src/depiction/tools/process_spectra/evaluators.py @@ -0,0 +1,80 @@ +from functools import cached_property +from numpy.typing import NDArray +from typing import Protocol + +from depiction.parallel_ops import ParallelConfig +from depiction.spectrum.peak_filtering import PeakFilteringType +from depiction.tools.correct_baseline.config import BaselineCorrectionConfig +from depiction.tools.correct_baseline.correct_baseline import CorrectBaseline +from depiction.tools.filter_peaks.config import FilterPeaksConfig +from depiction.tools.filter_peaks.filter_peaks import get_peak_filter +from depiction.tools.pick_peaks.config import PickPeaksConfig +from depiction.tools.pick_peaks.pick_peaks import get_peak_picker_from_config +from depiction.tools.process_spectra.config import ( + ProcessSpectraStepPickPeaks, + ProcessSpectraStepRemoveBaseline, + ProcessSpectraStepFilterPeaks, +) + + +class Evaluator(Protocol): + def evaluate(self, mz_arr: NDArray[float], int_arr: NDArray[float]) -> tuple[NDArray[float], NDArray[float]]: + raise NotImplementedError + + +def get_evaluator(step_config) -> Evaluator: + match step_config: + case ProcessSpectraStepPickPeaks(pick=pick_peaks_config): + return EvaluatePickPeaks(config=pick_peaks_config) + case ProcessSpectraStepRemoveBaseline(baseline=baseline_config): + return EvaluateRemoveBaseline(config=baseline_config) + case ProcessSpectraStepFilterPeaks(filter=filter_peaks_config): + return EvaluateFilterPeaks(config=filter_peaks_config) + case _: + raise ValueError(f"Unsupported step config: {step_config}") + + +class EvaluatePickPeaks(Evaluator): + def __init__(self, config: PickPeaksConfig) -> None: + self._config = config + + @cached_property + def _picker(self): + return get_peak_picker_from_config(self._config) + + def evaluate(self, mz_arr, int_arr): + return self._picker.pick_peaks(mz_arr, int_arr) + + +class EvaluateRemoveBaseline(Evaluator): + def __init__(self, config: BaselineCorrectionConfig) -> None: + self._config = config + + @cached_property + def _correct_baseline(self) -> CorrectBaseline: + return CorrectBaseline.from_variant( + parallel_config=ParallelConfig.no_parallelism(), + variant=self._config.baseline_variant, + window_size=self._config.window_size, + window_unit=self._config.window_unit, + ) + + def evaluate(self, mz_arr, int_arr): + int_arr_new = self._correct_baseline.evaluate_spectrum(mz_arr, int_arr) + return mz_arr, int_arr_new + + +class EvaluateFilterPeaks(Evaluator): + def __init__(self, config: FilterPeaksConfig) -> None: + self._config = config + + @cached_property + def _filter(self) -> PeakFilteringType: + return get_peak_filter(self._config) + + def evaluate(self, mz_arr, int_arr): + # TODO this is going to be important, how to handle this, i think we really + # need to remove the spectrum_mz_arr, spectrum_int_arr; but then it will + # not be possible to implement some things anymore (e.g. relative to total TIC) + # unless filters can request these info somehow + return self._filter.filter_peaks(mz_arr, int_arr, mz_arr, int_arr) diff --git a/src/depiction/tools/process_spectra/process.py b/src/depiction/tools/process_spectra/process.py index 313b3e9..92c8fad 100644 --- a/src/depiction/tools/process_spectra/process.py +++ b/src/depiction/tools/process_spectra/process.py @@ -1,6 +1,7 @@ +from depiction.parallel_ops import ParallelConfig from depiction.persistence.types import GenericReadFile, GenericWriteFile from depiction.tools.process_spectra.config import ProcessSpectraConfig def process_spectra(read_file: GenericReadFile, write_file: GenericWriteFile, config: ProcessSpectraConfig) -> None: - pass + parallel_config = ParallelConfig(n_jobs=config.n_jobs)