diff --git a/src/depiction/tools/cli/correct_baseline.py b/src/depiction/tools/cli/correct_baseline.py index ef8696b..64cf093 100644 --- a/src/depiction/tools/cli/correct_baseline.py +++ b/src/depiction/tools/cli/correct_baseline.py @@ -2,10 +2,10 @@ import shutil from pathlib import Path -from typing import Annotated -from loguru import logger +from typing import Annotated, Literal import typer +from loguru import logger from typer import Argument, Option from depiction.parallel_ops import ParallelConfig @@ -18,6 +18,8 @@ def correct_baseline( output_imzml: Annotated[Path, Argument()], n_jobs: Annotated[int, Option()] = None, baseline_variant: Annotated[BaselineVariants, Option()] = BaselineVariants.TopHat, + window_size: Annotated[int | float, Option()] = 5000, + window_unit: Annotated[Literal["ppm", "index"], Option()] = "ppm", ) -> None: """Removes the baseline from the input imzML file and writes the result to the output imzML file.""" output_imzml.parent.mkdir(parents=True, exist_ok=True) @@ -32,7 +34,9 @@ def correct_baseline( parallel_config = ParallelConfig(n_jobs=n_jobs) input_file = ImzmlReadFile(input_imzml) output_file = ImzmlWriteFile(output_imzml, imzml_mode=input_file.imzml_mode) - correct_baseline = CorrectBaseline.from_variant(parallel_config=parallel_config, variant=baseline_variant) + correct_baseline = CorrectBaseline.from_variant( + parallel_config=parallel_config, variant=baseline_variant, window_size=window_size, window_unit=window_unit + ) correct_baseline.evaluate_file(input_file, output_file) diff --git a/src/depiction/tools/correct_baseline.py b/src/depiction/tools/correct_baseline.py index a96bc61..41f587e 100644 --- a/src/depiction/tools/correct_baseline.py +++ b/src/depiction/tools/correct_baseline.py @@ -1,11 +1,9 @@ from __future__ import annotations import enum -from pathlib import Path from typing import TYPE_CHECKING, Literal import numpy as np -import typer from depiction.parallel_ops.parallel_config import ParallelConfig from depiction.parallel_ops.write_spectra_parallel import WriteSpectraParallel @@ -20,7 +18,6 @@ if TYPE_CHECKING: from numpy.typing import NDArray - from pathlib import Path from depiction.spectrum.baseline.baseline import Baseline @@ -97,25 +94,3 @@ def _get_baseline_correction(variant: BaselineVariants): return LocalMediansBaseline(window_size=5000, window_unit="ppm") else: raise ValueError(f"Unknown baseline variant: {variant}") - - -# TODO testing -# TODO use it in the workflow -# TODO replace use by tools.cli.correct_baseline -def main( - input_imzml: Path, - output_imzml: Path, - n_jobs: int = 20, - baseline_variant: BaselineVariants = BaselineVariants.TopHat, -) -> None: - """Corrects the baseline of `input_imzml` and writes the results to `output_imzml`.""" - parallel_config = ParallelConfig(n_jobs=n_jobs, task_size=None) - input_file = ImzmlReadFile(input_imzml) - output_file = ImzmlWriteFile(output_imzml, imzml_mode=input_file.imzml_mode) - output_imzml.parent.mkdir(exist_ok=True, parents=True) - correct_baseline = CorrectBaseline.from_variant(parallel_config=parallel_config, variant=baseline_variant) - correct_baseline.evaluate_file(input_file, output_file) - - -if __name__ == "__main__": - typer.run(main) diff --git a/src/depiction_targeted_preproc/workflow/proc/correct_baseline.py b/src/depiction_targeted_preproc/workflow/proc/correct_baseline.py index 2b7a3a1..4898ed5 100644 --- a/src/depiction_targeted_preproc/workflow/proc/correct_baseline.py +++ b/src/depiction_targeted_preproc/workflow/proc/correct_baseline.py @@ -1,41 +1,41 @@ -import shutil from pathlib import Path from typing import Annotated -import typer -from loguru import logger -from depiction.spectrum.baseline.tophat_baseline import TophatBaseline -from depiction.parallel_ops import ParallelConfig -from depiction.persistence import ImzmlReadFile, ImzmlWriteFile -from depiction.tools.correct_baseline import CorrectBaseline +import typer +from depiction.tools.cli.correct_baseline import correct_baseline +from depiction.tools.correct_baseline import BaselineVariants from depiction_targeted_preproc.pipeline_config.model import PipelineParameters, BaselineAdjustmentTophat -def correct_baseline( +def proc_correct_baseline( input_imzml_path: Annotated[Path, typer.Option()], config_path: Annotated[Path, typer.Option()], output_imzml_path: Annotated[Path, typer.Option()], ) -> None: config = PipelineParameters.parse_yaml(config_path) + window = {} match config.baseline_adjustment: case None: - logger.info("Baseline adjustment is deactivated") - shutil.copy(input_imzml_path, output_imzml_path) - shutil.copy(input_imzml_path.with_suffix(".ibd"), output_imzml_path.with_suffix(".ibd")) + baseline_variant = BaselineVariants.Zero case BaselineAdjustmentTophat(window_size=window_size, window_unit=window_unit): - baseline = TophatBaseline(window_size=window_size, window_unit=window_unit) - parallel_config = ParallelConfig(n_jobs=config.n_jobs, task_size=None) - read_file = ImzmlReadFile(input_imzml_path) - write_file = ImzmlWriteFile(output_imzml_path, imzml_mode=read_file.imzml_mode) - correct_baseline = CorrectBaseline(parallel_config=parallel_config, baseline_correction=baseline) - correct_baseline.evaluate_file(read_file, write_file) + baseline_variant = BaselineVariants.TopHat + window["window_size"] = window_size + window["window_unit"] = window_unit case _: raise ValueError(f"Unsupported baseline adjustment type: {config.baseline_adjustment.baseline_type}") + correct_baseline( + input_imzml=input_imzml_path, + output_imzml=output_imzml_path, + n_jobs=config.n_jobs, + baseline_variant=baseline_variant, + **window, + ) + def main() -> None: - typer.run(correct_baseline) + typer.run(proc_correct_baseline) if __name__ == "__main__":