Skip to content

Commit

Permalink
ensure test passes
Browse files Browse the repository at this point in the history
  • Loading branch information
leoschwarz committed Jul 3, 2024
1 parent 7f75482 commit b520483
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 16 deletions.
13 changes: 6 additions & 7 deletions src/depiction/tools/cli/correct_baseline.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,14 @@


class BaselineCorrectionConfig(BaseModel):
n_jobs: int
n_jobs: int | None
baseline_variant: BaselineVariants = BaselineVariants.TopHat
window_size: int | float = 5000.0
window_unit: Literal["ppm", "index"] = "ppm"


@app.command
def config(
def run_config(
input_imzml: Annotated[Path, Argument()],
output_imzml: Annotated[Path, Argument()],
config: Annotated[Path, Argument()],
Expand All @@ -35,19 +35,18 @@ def config(


@app.default
def main_args(
def run(
input_imzml: Annotated[Path, Argument()],
output_imzml: Annotated[Path, Argument()],
n_jobs: Annotated[int, Option()] = None,
baseline_variant: Annotated[BaselineVariants, Option()] = BaselineVariants.TopHat,
window_size: Annotated[int | float, Option()] = 5000,
window_unit: Annotated[Literal["ppm", "index"], Option()] = "ppm",
):
parsed = BaselineCorrectionConfig(
n_jobs=n_jobs, baseline_type=baseline_variant, window_size=window_size, window_unit=window_unit
config = BaselineCorrectionConfig.validate(
dict(n_jobs=n_jobs, baseline_variant=baseline_variant, window_size=window_size, window_unit=window_unit)
)
parsed.validate()
correct_baseline(config=parsed, input_imzml=input_imzml, output_imzml=output_imzml)
correct_baseline(config=config, input_imzml=input_imzml, output_imzml=output_imzml)


def correct_baseline(config: BaselineCorrectionConfig, input_imzml: Path, output_imzml: Path) -> None:
Expand Down
14 changes: 5 additions & 9 deletions tests/unit/tools/cli/test_correct_baseline.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,17 @@

import pytest

from depiction.tools.cli.correct_baseline import correct_baseline
from depiction.tools.cli.correct_baseline import run
from depiction.tools.correct_baseline import BaselineVariants


def test_correct_baseline_when_variant_zero(mocker) -> None:
def test_run_when_variant_zero(mocker) -> None:
mock_copyfile = mocker.patch("shutil.copyfile")
mock_logger = mocker.patch("depiction.tools.cli.correct_baseline.logger")
mock_input_imzml = Path("/dev/null/hello.imzML")
mock_output_imzml = mocker.MagicMock(name="mock_output_imzml")

correct_baseline(
input_imzml=mock_input_imzml, output_imzml=mock_output_imzml, baseline_variant=BaselineVariants.Zero
)
run(input_imzml=mock_input_imzml, output_imzml=mock_output_imzml, baseline_variant=BaselineVariants.Zero)

assert mock_copyfile.mock_calls == [
mocker.call(mock_input_imzml, mock_output_imzml),
Expand All @@ -24,7 +22,7 @@ def test_correct_baseline_when_variant_zero(mocker) -> None:
mock_logger.info.assert_called_once()


def test_correct_baseline_when_other_variant(mocker) -> None:
def test_run_when_other_variant(mocker) -> None:
mock_logger = mocker.patch("loguru.logger")
mock_imzml_mode = mocker.MagicMock(name="mock_imzml_mode", spec=[])
construct_imzml_read_file = mocker.patch("depiction.tools.cli.correct_baseline.ImzmlReadFile")
Expand All @@ -34,9 +32,7 @@ def test_correct_baseline_when_other_variant(mocker) -> None:
mock_input_imzml = Path("/dev/null/hello.imzML")
mock_output_imzml = mocker.MagicMock(name="mock_output_imzml")

correct_baseline(
input_imzml=mock_input_imzml, output_imzml=mock_output_imzml, baseline_variant=BaselineVariants.TopHat
)
run(input_imzml=mock_input_imzml, output_imzml=mock_output_imzml, baseline_variant=BaselineVariants.TopHat)

construct_correct_baseline.assert_called_once_with(
parallel_config=mocker.ANY, variant=BaselineVariants.TopHat, window_size=5000, window_unit="ppm"
Expand Down

0 comments on commit b520483

Please sign in to comment.