Skip to content

Commit

Permalink
create new correct_baseline cli
Browse files Browse the repository at this point in the history
  • Loading branch information
leoschwarz committed Jul 1, 2024
1 parent 7558cb2 commit 49d526c
Show file tree
Hide file tree
Showing 4 changed files with 105 additions and 9 deletions.
40 changes: 40 additions & 0 deletions src/depiction/tools/cli/correct_baseline.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from __future__ import annotations

import shutil
from pathlib import Path
from typing import Annotated
from loguru import logger

import typer
from typer import Argument, Option

from depiction.parallel_ops import ParallelConfig
from depiction.persistence import ImzmlReadFile, ImzmlWriteFile
from depiction.tools.correct_baseline import BaselineVariants, CorrectBaseline


def correct_baseline(
input_imzml: Annotated[Path, Argument()],
output_imzml: Annotated[Path, Argument()],
n_jobs: Annotated[int, Option()] = None,
baseline_variant: Annotated[BaselineVariants, Option()] = BaselineVariants.TopHat,
) -> None:
"""Removes the baseline from the input imzML file and writes the result to the output imzML file."""
output_imzml.parent.mkdir(parents=True, exist_ok=True)
if baseline_variant == BaselineVariants.Zero:
logger.info("Baseline correction is deactivated, copying input to output")
shutil.copyfile(input_imzml, output_imzml)
shutil.copyfile(input_imzml.with_suffix(".ibd"), output_imzml.with_suffix(".ibd"))
else:
if n_jobs is None:
# TODO define some sane default for None and -1 n_jobs e.g. use all available up to a limit (None) or use all (1-r)
n_jobs = 10
parallel_config = ParallelConfig(n_jobs=n_jobs)
input_file = ImzmlReadFile(input_imzml)
output_file = ImzmlWriteFile(output_imzml, imzml_mode=input_file.imzml_mode)
correct_baseline = CorrectBaseline.from_variant(parallel_config=parallel_config, variant=baseline_variant)
correct_baseline.evaluate_file(input_file, output_file)


if __name__ == "__main__":
typer.run(correct_baseline)
23 changes: 14 additions & 9 deletions src/depiction/tools/correct_baseline.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,32 +21,34 @@
if TYPE_CHECKING:
from numpy.typing import NDArray
from pathlib import Path
from depiction.spectrum.baseline.baseline import Baseline


class BaselineVariants(str, enum.Enum):
tophat = "tophat"
loc_medians = "loc_medians"
TopHat = "TopHat"
LocMedians = "LocMedians"
Zero = "Zero"


class CorrectBaseline:
"""Implements baseline correction for imzml files."""

def __init__(self, parallel_config: ParallelConfig, baseline_correction) -> None:
def __init__(self, parallel_config: ParallelConfig, baseline_correction: Baseline) -> None:
self._parallel_config = parallel_config
self._baseline_correction = baseline_correction

@classmethod
def from_variant(
cls,
parallel_config: ParallelConfig,
variant: BaselineVariants = BaselineVariants.tophat,
variant: BaselineVariants = BaselineVariants.TopHat,
window_size: int | float = 5000,
window_unit: Literal["ppm", "index"] = "ppm",
) -> CorrectBaseline:
"""Creates an instance of CorrectBaseline with the specified variant."""
if variant == BaselineVariants.tophat:
if variant == BaselineVariants.TopHat:
baseline_correction = TophatBaseline(window_size=window_size, window_unit=window_unit)
elif variant == BaselineVariants.loc_medians:
elif variant == BaselineVariants.LocMedians:
baseline_correction = LocalMediansBaseline(window_size=window_size, window_unit=window_unit)
else:
raise ValueError(f"Unknown baseline variant: {variant}")
Expand Down Expand Up @@ -89,19 +91,22 @@ def _operation(

@staticmethod
def _get_baseline_correction(variant: BaselineVariants):
if variant == BaselineVariants.tophat:
if variant == BaselineVariants.TopHat:
return TophatBaseline(window_size=5000, window_unit="ppm")
elif variant == BaselineVariants.loc_medians:
elif variant == BaselineVariants.LocMedians:
return LocalMediansBaseline(window_size=5000, window_unit="ppm")
else:
raise ValueError(f"Unknown baseline variant: {variant}")


# TODO testing
# TODO use it in the workflow
# TODO replace use by tools.cli.correct_baseline
def main(
input_imzml: Path,
output_imzml: Path,
n_jobs: int = 20,
baseline_variant: BaselineVariants = BaselineVariants.tophat,
baseline_variant: BaselineVariants = BaselineVariants.TopHat,
) -> None:
"""Corrects the baseline of `input_imzml` and writes the results to `output_imzml`."""
parallel_config = ParallelConfig(n_jobs=n_jobs, task_size=None)
Expand Down
Empty file.
51 changes: 51 additions & 0 deletions tests/unit/tools/cli/test_correct_baseline.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
from pathlib import Path

import pytest

from depiction.tools.cli.correct_baseline import correct_baseline
from depiction.tools.correct_baseline import BaselineVariants


def test_correct_baseline_when_variant_zero(mocker) -> None:
mock_copyfile = mocker.patch("shutil.copyfile")
mock_logger = mocker.patch("depiction.tools.cli.correct_baseline.logger")
mock_input_imzml = Path("/dev/null/hello.imzML")
mock_output_imzml = mocker.MagicMock(name="mock_output_imzml")

correct_baseline(
input_imzml=mock_input_imzml, output_imzml=mock_output_imzml, baseline_variant=BaselineVariants.Zero
)

assert mock_copyfile.mock_calls == [
mocker.call(mock_input_imzml, mock_output_imzml),
mocker.call(Path("/dev/null/hello.ibd"), mocker.ANY),
]
mock_output_imzml.parent.mkdir.assert_called_once_with(parents=True, exist_ok=True)
mock_logger.info.assert_called_once()


def test_correct_baseline_when_other_variant(mocker) -> None:
mock_logger = mocker.patch("loguru.logger")
mock_imzml_mode = mocker.MagicMock(name="mock_imzml_mode", spec=[])
construct_imzml_read_file = mocker.patch("depiction.tools.cli.correct_baseline.ImzmlReadFile")
construct_imzml_read_file.return_value.imzml_mode = mock_imzml_mode
construct_imzml_write_file = mocker.patch("depiction.tools.cli.correct_baseline.ImzmlWriteFile")
construct_correct_baseline = mocker.patch("depiction.tools.cli.correct_baseline.CorrectBaseline.from_variant")
mock_input_imzml = Path("/dev/null/hello.imzML")
mock_output_imzml = mocker.MagicMock(name="mock_output_imzml")

correct_baseline(
input_imzml=mock_input_imzml, output_imzml=mock_output_imzml, baseline_variant=BaselineVariants.TopHat
)

construct_correct_baseline.assert_called_once_with(parallel_config=mocker.ANY, variant=BaselineVariants.TopHat)
construct_correct_baseline.return_value.evaluate_file.assert_called_once_with(
construct_imzml_read_file.return_value, construct_imzml_write_file.return_value
)
construct_imzml_read_file.assert_called_once_with(mock_input_imzml)
construct_imzml_write_file.assert_called_once_with(mock_output_imzml, imzml_mode=mock_imzml_mode)
mock_output_imzml.parent.mkdir.assert_called_once_with(parents=True, exist_ok=True)


if __name__ == "__main__":
pytest.main()

0 comments on commit 49d526c

Please sign in to comment.