diff --git a/src/depiction/tools/cli/correct_baseline.py b/src/depiction/tools/cli/correct_baseline.py index 6f0033a..691e008 100644 --- a/src/depiction/tools/cli/correct_baseline.py +++ b/src/depiction/tools/cli/correct_baseline.py @@ -18,14 +18,14 @@ class BaselineCorrectionConfig(BaseModel): - n_jobs: int + n_jobs: int | None baseline_variant: BaselineVariants = BaselineVariants.TopHat window_size: int | float = 5000.0 window_unit: Literal["ppm", "index"] = "ppm" @app.command -def config( +def run_config( input_imzml: Annotated[Path, Argument()], output_imzml: Annotated[Path, Argument()], config: Annotated[Path, Argument()], @@ -35,7 +35,7 @@ def config( @app.default -def main_args( +def run( input_imzml: Annotated[Path, Argument()], output_imzml: Annotated[Path, Argument()], n_jobs: Annotated[int, Option()] = None, @@ -43,11 +43,10 @@ def main_args( window_size: Annotated[int | float, Option()] = 5000, window_unit: Annotated[Literal["ppm", "index"], Option()] = "ppm", ): - parsed = BaselineCorrectionConfig( - n_jobs=n_jobs, baseline_type=baseline_variant, window_size=window_size, window_unit=window_unit + config = BaselineCorrectionConfig.validate( + dict(n_jobs=n_jobs, baseline_variant=baseline_variant, window_size=window_size, window_unit=window_unit) ) - parsed.validate() - correct_baseline(config=parsed, input_imzml=input_imzml, output_imzml=output_imzml) + correct_baseline(config=config, input_imzml=input_imzml, output_imzml=output_imzml) def correct_baseline(config: BaselineCorrectionConfig, input_imzml: Path, output_imzml: Path) -> None: diff --git a/tests/unit/tools/cli/test_correct_baseline.py b/tests/unit/tools/cli/test_correct_baseline.py index 39342b6..2faefd5 100644 --- a/tests/unit/tools/cli/test_correct_baseline.py +++ b/tests/unit/tools/cli/test_correct_baseline.py @@ -2,19 +2,17 @@ import pytest -from depiction.tools.cli.correct_baseline import correct_baseline +from depiction.tools.cli.correct_baseline import run from depiction.tools.correct_baseline import BaselineVariants -def test_correct_baseline_when_variant_zero(mocker) -> None: +def test_run_when_variant_zero(mocker) -> None: mock_copyfile = mocker.patch("shutil.copyfile") mock_logger = mocker.patch("depiction.tools.cli.correct_baseline.logger") mock_input_imzml = Path("/dev/null/hello.imzML") mock_output_imzml = mocker.MagicMock(name="mock_output_imzml") - correct_baseline( - input_imzml=mock_input_imzml, output_imzml=mock_output_imzml, baseline_variant=BaselineVariants.Zero - ) + run(input_imzml=mock_input_imzml, output_imzml=mock_output_imzml, baseline_variant=BaselineVariants.Zero) assert mock_copyfile.mock_calls == [ mocker.call(mock_input_imzml, mock_output_imzml), @@ -24,7 +22,7 @@ def test_correct_baseline_when_variant_zero(mocker) -> None: mock_logger.info.assert_called_once() -def test_correct_baseline_when_other_variant(mocker) -> None: +def test_run_when_other_variant(mocker) -> None: mock_logger = mocker.patch("loguru.logger") mock_imzml_mode = mocker.MagicMock(name="mock_imzml_mode", spec=[]) construct_imzml_read_file = mocker.patch("depiction.tools.cli.correct_baseline.ImzmlReadFile") @@ -34,9 +32,7 @@ def test_correct_baseline_when_other_variant(mocker) -> None: mock_input_imzml = Path("/dev/null/hello.imzML") mock_output_imzml = mocker.MagicMock(name="mock_output_imzml") - correct_baseline( - input_imzml=mock_input_imzml, output_imzml=mock_output_imzml, baseline_variant=BaselineVariants.TopHat - ) + run(input_imzml=mock_input_imzml, output_imzml=mock_output_imzml, baseline_variant=BaselineVariants.TopHat) construct_correct_baseline.assert_called_once_with( parallel_config=mocker.ANY, variant=BaselineVariants.TopHat, window_size=5000, window_unit="ppm"