diff --git a/src/depiction/tools/cli/correct_baseline.py b/src/depiction/tools/cli/correct_baseline.py new file mode 100644 index 0000000..ef8696b --- /dev/null +++ b/src/depiction/tools/cli/correct_baseline.py @@ -0,0 +1,40 @@ +from __future__ import annotations + +import shutil +from pathlib import Path +from typing import Annotated +from loguru import logger + +import typer +from typer import Argument, Option + +from depiction.parallel_ops import ParallelConfig +from depiction.persistence import ImzmlReadFile, ImzmlWriteFile +from depiction.tools.correct_baseline import BaselineVariants, CorrectBaseline + + +def correct_baseline( + input_imzml: Annotated[Path, Argument()], + output_imzml: Annotated[Path, Argument()], + n_jobs: Annotated[int, Option()] = None, + baseline_variant: Annotated[BaselineVariants, Option()] = BaselineVariants.TopHat, +) -> None: + """Removes the baseline from the input imzML file and writes the result to the output imzML file.""" + output_imzml.parent.mkdir(parents=True, exist_ok=True) + if baseline_variant == BaselineVariants.Zero: + logger.info("Baseline correction is deactivated, copying input to output") + shutil.copyfile(input_imzml, output_imzml) + shutil.copyfile(input_imzml.with_suffix(".ibd"), output_imzml.with_suffix(".ibd")) + else: + if n_jobs is None: + # TODO define some sane default for None and -1 n_jobs e.g. use all available up to a limit (None) or use all (1-r) + n_jobs = 10 + parallel_config = ParallelConfig(n_jobs=n_jobs) + input_file = ImzmlReadFile(input_imzml) + output_file = ImzmlWriteFile(output_imzml, imzml_mode=input_file.imzml_mode) + correct_baseline = CorrectBaseline.from_variant(parallel_config=parallel_config, variant=baseline_variant) + correct_baseline.evaluate_file(input_file, output_file) + + +if __name__ == "__main__": + typer.run(correct_baseline) diff --git a/src/depiction/tools/correct_baseline.py b/src/depiction/tools/correct_baseline.py index bb6f4ba..a96bc61 100644 --- a/src/depiction/tools/correct_baseline.py +++ b/src/depiction/tools/correct_baseline.py @@ -21,17 +21,19 @@ if TYPE_CHECKING: from numpy.typing import NDArray from pathlib import Path + from depiction.spectrum.baseline.baseline import Baseline class BaselineVariants(str, enum.Enum): - tophat = "tophat" - loc_medians = "loc_medians" + TopHat = "TopHat" + LocMedians = "LocMedians" + Zero = "Zero" class CorrectBaseline: """Implements baseline correction for imzml files.""" - def __init__(self, parallel_config: ParallelConfig, baseline_correction) -> None: + def __init__(self, parallel_config: ParallelConfig, baseline_correction: Baseline) -> None: self._parallel_config = parallel_config self._baseline_correction = baseline_correction @@ -39,14 +41,14 @@ def __init__(self, parallel_config: ParallelConfig, baseline_correction) -> None def from_variant( cls, parallel_config: ParallelConfig, - variant: BaselineVariants = BaselineVariants.tophat, + variant: BaselineVariants = BaselineVariants.TopHat, window_size: int | float = 5000, window_unit: Literal["ppm", "index"] = "ppm", ) -> CorrectBaseline: """Creates an instance of CorrectBaseline with the specified variant.""" - if variant == BaselineVariants.tophat: + if variant == BaselineVariants.TopHat: baseline_correction = TophatBaseline(window_size=window_size, window_unit=window_unit) - elif variant == BaselineVariants.loc_medians: + elif variant == BaselineVariants.LocMedians: baseline_correction = LocalMediansBaseline(window_size=window_size, window_unit=window_unit) else: raise ValueError(f"Unknown baseline variant: {variant}") @@ -89,19 +91,22 @@ def _operation( @staticmethod def _get_baseline_correction(variant: BaselineVariants): - if variant == BaselineVariants.tophat: + if variant == BaselineVariants.TopHat: return TophatBaseline(window_size=5000, window_unit="ppm") - elif variant == BaselineVariants.loc_medians: + elif variant == BaselineVariants.LocMedians: return LocalMediansBaseline(window_size=5000, window_unit="ppm") else: raise ValueError(f"Unknown baseline variant: {variant}") +# TODO testing +# TODO use it in the workflow +# TODO replace use by tools.cli.correct_baseline def main( input_imzml: Path, output_imzml: Path, n_jobs: int = 20, - baseline_variant: BaselineVariants = BaselineVariants.tophat, + baseline_variant: BaselineVariants = BaselineVariants.TopHat, ) -> None: """Corrects the baseline of `input_imzml` and writes the results to `output_imzml`.""" parallel_config = ParallelConfig(n_jobs=n_jobs, task_size=None) diff --git a/tests/unit/tools/cli/__init__.py b/tests/unit/tools/cli/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/unit/tools/cli/test_correct_baseline.py b/tests/unit/tools/cli/test_correct_baseline.py new file mode 100644 index 0000000..3898770 --- /dev/null +++ b/tests/unit/tools/cli/test_correct_baseline.py @@ -0,0 +1,51 @@ +from pathlib import Path + +import pytest + +from depiction.tools.cli.correct_baseline import correct_baseline +from depiction.tools.correct_baseline import BaselineVariants + + +def test_correct_baseline_when_variant_zero(mocker) -> None: + mock_copyfile = mocker.patch("shutil.copyfile") + mock_logger = mocker.patch("depiction.tools.cli.correct_baseline.logger") + mock_input_imzml = Path("/dev/null/hello.imzML") + mock_output_imzml = mocker.MagicMock(name="mock_output_imzml") + + correct_baseline( + input_imzml=mock_input_imzml, output_imzml=mock_output_imzml, baseline_variant=BaselineVariants.Zero + ) + + assert mock_copyfile.mock_calls == [ + mocker.call(mock_input_imzml, mock_output_imzml), + mocker.call(Path("/dev/null/hello.ibd"), mocker.ANY), + ] + mock_output_imzml.parent.mkdir.assert_called_once_with(parents=True, exist_ok=True) + mock_logger.info.assert_called_once() + + +def test_correct_baseline_when_other_variant(mocker) -> None: + mock_logger = mocker.patch("loguru.logger") + mock_imzml_mode = mocker.MagicMock(name="mock_imzml_mode", spec=[]) + construct_imzml_read_file = mocker.patch("depiction.tools.cli.correct_baseline.ImzmlReadFile") + construct_imzml_read_file.return_value.imzml_mode = mock_imzml_mode + construct_imzml_write_file = mocker.patch("depiction.tools.cli.correct_baseline.ImzmlWriteFile") + construct_correct_baseline = mocker.patch("depiction.tools.cli.correct_baseline.CorrectBaseline.from_variant") + mock_input_imzml = Path("/dev/null/hello.imzML") + mock_output_imzml = mocker.MagicMock(name="mock_output_imzml") + + correct_baseline( + input_imzml=mock_input_imzml, output_imzml=mock_output_imzml, baseline_variant=BaselineVariants.TopHat + ) + + construct_correct_baseline.assert_called_once_with(parallel_config=mocker.ANY, variant=BaselineVariants.TopHat) + construct_correct_baseline.return_value.evaluate_file.assert_called_once_with( + construct_imzml_read_file.return_value, construct_imzml_write_file.return_value + ) + construct_imzml_read_file.assert_called_once_with(mock_input_imzml) + construct_imzml_write_file.assert_called_once_with(mock_output_imzml, imzml_mode=mock_imzml_mode) + mock_output_imzml.parent.mkdir.assert_called_once_with(parents=True, exist_ok=True) + + +if __name__ == "__main__": + pytest.main()