-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add initial code to get per spectrum evaluator
- Loading branch information
1 parent
8de506d
commit 23b0a7d
Showing
2 changed files
with
82 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
from functools import cached_property | ||
from numpy.typing import NDArray | ||
from typing import Protocol | ||
|
||
from depiction.parallel_ops import ParallelConfig | ||
from depiction.spectrum.peak_filtering import PeakFilteringType | ||
from depiction.tools.correct_baseline.config import BaselineCorrectionConfig | ||
from depiction.tools.correct_baseline.correct_baseline import CorrectBaseline | ||
from depiction.tools.filter_peaks.config import FilterPeaksConfig | ||
from depiction.tools.filter_peaks.filter_peaks import get_peak_filter | ||
from depiction.tools.pick_peaks.config import PickPeaksConfig | ||
from depiction.tools.pick_peaks.pick_peaks import get_peak_picker_from_config | ||
from depiction.tools.process_spectra.config import ( | ||
ProcessSpectraStepPickPeaks, | ||
ProcessSpectraStepRemoveBaseline, | ||
ProcessSpectraStepFilterPeaks, | ||
) | ||
|
||
|
||
class Evaluator(Protocol): | ||
def evaluate(self, mz_arr: NDArray[float], int_arr: NDArray[float]) -> tuple[NDArray[float], NDArray[float]]: | ||
raise NotImplementedError | ||
|
||
|
||
def get_evaluator(step_config) -> Evaluator: | ||
match step_config: | ||
case ProcessSpectraStepPickPeaks(pick=pick_peaks_config): | ||
return EvaluatePickPeaks(config=pick_peaks_config) | ||
case ProcessSpectraStepRemoveBaseline(baseline=baseline_config): | ||
return EvaluateRemoveBaseline(config=baseline_config) | ||
case ProcessSpectraStepFilterPeaks(filter=filter_peaks_config): | ||
return EvaluateFilterPeaks(config=filter_peaks_config) | ||
case _: | ||
raise ValueError(f"Unsupported step config: {step_config}") | ||
|
||
|
||
class EvaluatePickPeaks(Evaluator): | ||
def __init__(self, config: PickPeaksConfig) -> None: | ||
self._config = config | ||
|
||
@cached_property | ||
def _picker(self): | ||
return get_peak_picker_from_config(self._config) | ||
|
||
def evaluate(self, mz_arr, int_arr): | ||
return self._picker.pick_peaks(mz_arr, int_arr) | ||
|
||
|
||
class EvaluateRemoveBaseline(Evaluator): | ||
def __init__(self, config: BaselineCorrectionConfig) -> None: | ||
self._config = config | ||
|
||
@cached_property | ||
def _correct_baseline(self) -> CorrectBaseline: | ||
return CorrectBaseline.from_variant( | ||
parallel_config=ParallelConfig.no_parallelism(), | ||
variant=self._config.baseline_variant, | ||
window_size=self._config.window_size, | ||
window_unit=self._config.window_unit, | ||
) | ||
|
||
def evaluate(self, mz_arr, int_arr): | ||
int_arr_new = self._correct_baseline.evaluate_spectrum(mz_arr, int_arr) | ||
return mz_arr, int_arr_new | ||
|
||
|
||
class EvaluateFilterPeaks(Evaluator): | ||
def __init__(self, config: FilterPeaksConfig) -> None: | ||
self._config = config | ||
|
||
@cached_property | ||
def _filter(self) -> PeakFilteringType: | ||
return get_peak_filter(self._config) | ||
|
||
def evaluate(self, mz_arr, int_arr): | ||
# TODO this is going to be important, how to handle this, i think we really | ||
# need to remove the spectrum_mz_arr, spectrum_int_arr; but then it will | ||
# not be possible to implement some things anymore (e.g. relative to total TIC) | ||
# unless filters can request these info somehow | ||
return self._filter.filter_peaks(mz_arr, int_arr, mz_arr, int_arr) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,7 @@ | ||
from depiction.parallel_ops import ParallelConfig | ||
from depiction.persistence.types import GenericReadFile, GenericWriteFile | ||
from depiction.tools.process_spectra.config import ProcessSpectraConfig | ||
|
||
|
||
def process_spectra(read_file: GenericReadFile, write_file: GenericWriteFile, config: ProcessSpectraConfig) -> None: | ||
pass | ||
parallel_config = ParallelConfig(n_jobs=config.n_jobs) |