Skip to content

Commit

Permalink
call the new cli script from the workflow
Browse files Browse the repository at this point in the history
  • Loading branch information
leoschwarz committed Jul 1, 2024
1 parent 49d526c commit 6638330
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 46 deletions.
10 changes: 7 additions & 3 deletions src/depiction/tools/cli/correct_baseline.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@

import shutil
from pathlib import Path
from typing import Annotated
from loguru import logger
from typing import Annotated, Literal

import typer
from loguru import logger
from typer import Argument, Option

from depiction.parallel_ops import ParallelConfig
Expand All @@ -18,6 +18,8 @@ def correct_baseline(
output_imzml: Annotated[Path, Argument()],
n_jobs: Annotated[int, Option()] = None,
baseline_variant: Annotated[BaselineVariants, Option()] = BaselineVariants.TopHat,
window_size: Annotated[int | float, Option()] = 5000,
window_unit: Annotated[Literal["ppm", "index"], Option()] = "ppm",
) -> None:
"""Removes the baseline from the input imzML file and writes the result to the output imzML file."""
output_imzml.parent.mkdir(parents=True, exist_ok=True)
Expand All @@ -32,7 +34,9 @@ def correct_baseline(
parallel_config = ParallelConfig(n_jobs=n_jobs)
input_file = ImzmlReadFile(input_imzml)
output_file = ImzmlWriteFile(output_imzml, imzml_mode=input_file.imzml_mode)
correct_baseline = CorrectBaseline.from_variant(parallel_config=parallel_config, variant=baseline_variant)
correct_baseline = CorrectBaseline.from_variant(
parallel_config=parallel_config, variant=baseline_variant, window_size=window_size, window_unit=window_unit
)
correct_baseline.evaluate_file(input_file, output_file)


Expand Down
25 changes: 0 additions & 25 deletions src/depiction/tools/correct_baseline.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
from __future__ import annotations

import enum
from pathlib import Path
from typing import TYPE_CHECKING, Literal

import numpy as np
import typer

from depiction.parallel_ops.parallel_config import ParallelConfig
from depiction.parallel_ops.write_spectra_parallel import WriteSpectraParallel
Expand All @@ -20,7 +18,6 @@

if TYPE_CHECKING:
from numpy.typing import NDArray
from pathlib import Path
from depiction.spectrum.baseline.baseline import Baseline


Expand Down Expand Up @@ -97,25 +94,3 @@ def _get_baseline_correction(variant: BaselineVariants):
return LocalMediansBaseline(window_size=5000, window_unit="ppm")
else:
raise ValueError(f"Unknown baseline variant: {variant}")


# TODO testing
# TODO use it in the workflow
# TODO replace use by tools.cli.correct_baseline
def main(
input_imzml: Path,
output_imzml: Path,
n_jobs: int = 20,
baseline_variant: BaselineVariants = BaselineVariants.TopHat,
) -> None:
"""Corrects the baseline of `input_imzml` and writes the results to `output_imzml`."""
parallel_config = ParallelConfig(n_jobs=n_jobs, task_size=None)
input_file = ImzmlReadFile(input_imzml)
output_file = ImzmlWriteFile(output_imzml, imzml_mode=input_file.imzml_mode)
output_imzml.parent.mkdir(exist_ok=True, parents=True)
correct_baseline = CorrectBaseline.from_variant(parallel_config=parallel_config, variant=baseline_variant)
correct_baseline.evaluate_file(input_file, output_file)


if __name__ == "__main__":
typer.run(main)
36 changes: 18 additions & 18 deletions src/depiction_targeted_preproc/workflow/proc/correct_baseline.py
Original file line number Diff line number Diff line change
@@ -1,41 +1,41 @@
import shutil
from pathlib import Path
from typing import Annotated
import typer
from loguru import logger

from depiction.spectrum.baseline.tophat_baseline import TophatBaseline
from depiction.parallel_ops import ParallelConfig
from depiction.persistence import ImzmlReadFile, ImzmlWriteFile
from depiction.tools.correct_baseline import CorrectBaseline
import typer

from depiction.tools.cli.correct_baseline import correct_baseline
from depiction.tools.correct_baseline import BaselineVariants
from depiction_targeted_preproc.pipeline_config.model import PipelineParameters, BaselineAdjustmentTophat


def correct_baseline(
def proc_correct_baseline(
input_imzml_path: Annotated[Path, typer.Option()],
config_path: Annotated[Path, typer.Option()],
output_imzml_path: Annotated[Path, typer.Option()],
) -> None:
config = PipelineParameters.parse_yaml(config_path)
window = {}
match config.baseline_adjustment:
case None:
logger.info("Baseline adjustment is deactivated")
shutil.copy(input_imzml_path, output_imzml_path)
shutil.copy(input_imzml_path.with_suffix(".ibd"), output_imzml_path.with_suffix(".ibd"))
baseline_variant = BaselineVariants.Zero
case BaselineAdjustmentTophat(window_size=window_size, window_unit=window_unit):
baseline = TophatBaseline(window_size=window_size, window_unit=window_unit)
parallel_config = ParallelConfig(n_jobs=config.n_jobs, task_size=None)
read_file = ImzmlReadFile(input_imzml_path)
write_file = ImzmlWriteFile(output_imzml_path, imzml_mode=read_file.imzml_mode)
correct_baseline = CorrectBaseline(parallel_config=parallel_config, baseline_correction=baseline)
correct_baseline.evaluate_file(read_file, write_file)
baseline_variant = BaselineVariants.TopHat
window["window_size"] = window_size
window["window_unit"] = window_unit
case _:
raise ValueError(f"Unsupported baseline adjustment type: {config.baseline_adjustment.baseline_type}")

correct_baseline(
input_imzml=input_imzml_path,
output_imzml=output_imzml_path,
n_jobs=config.n_jobs,
baseline_variant=baseline_variant,
**window,
)


def main() -> None:
typer.run(correct_baseline)
typer.run(proc_correct_baseline)


if __name__ == "__main__":
Expand Down

0 comments on commit 6638330

Please sign in to comment.