Skip to content

Commit

Permalink
refactor pick_peaks into depiction
Browse files Browse the repository at this point in the history
  • Loading branch information
leoschwarz committed Jul 4, 2024
1 parent 4d2c5da commit b207a66
Show file tree
Hide file tree
Showing 7 changed files with 207 additions and 224 deletions.
40 changes: 40 additions & 0 deletions src/depiction/tools/cli/cli_pick_peaks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from __future__ import annotations

from pathlib import Path

import cyclopts
import yaml

from depiction.persistence import ImzmlReadFile, ImzmlWriteFile, ImzmlModeEnum
from depiction.tools.filter_peaks import FilterPeaksConfig, filter_peaks, FilterNHighestIntensityPartitionedConfig
from depiction.tools.pick_peaks import PickPeaksConfig, pick_peaks

app = cyclopts.App()


@app.command
def run_config(
    config: Path,
    input_imzml: Path,
    output_imzml: Path,
) -> None:
    """Run peak picking as described by a YAML configuration file.

    Args:
        config: Path to a YAML file whose content is validated as ``PickPeaksConfig``.
        input_imzml: Path of the imzML file to read spectra from.
        output_imzml: Path of the imzML file to write; it is opened in PROCESSED mode.
    """
    # ``model_validate`` is the pydantic v2 API (``BaseModel.validate`` is deprecated). The result is
    # bound to a new name so that ``config`` keeps referring to the path passed by the caller.
    parsed_config = PickPeaksConfig.model_validate(yaml.safe_load(config.read_text()))
    pick_peaks(
        config=parsed_config,
        input_file=ImzmlReadFile(input_imzml),
        output_file=ImzmlWriteFile(output_imzml, imzml_mode=ImzmlModeEnum.PROCESSED),
    )


# @app.default
# def run(
# input_imzml: Path,
# output_imzml: Path,
# *,
# n_jobs: int | None = None,
# ) -> None:
# pass


if __name__ == "__main__":
app()
193 changes: 115 additions & 78 deletions src/depiction/tools/pick_peaks.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,54 @@
import argparse
from __future__ import annotations

import shutil
from pathlib import Path
from typing import Literal, Any

from loguru import logger
from pydantic import BaseModel, Field

from depiction.parallel_ops import ParallelConfig, WriteSpectraParallel
from depiction.spectrum.peak_filtering import FilterByIntensity
from depiction.persistence import ImzmlModeEnum
from depiction.persistence import ImzmlWriteFile, ImzmlReadFile, ImzmlWriter, ImzmlReader
from depiction.spectrum.peak_filtering import PeakFilteringType
from depiction.spectrum.peak_picking import BasicInterpolatedPeakPicker
from depiction.persistence import ImzmlWriteFile, ImzmlReadFile, ImzmlWriter, ImzmlReader, ImzmlModeEnum
from depiction.spectrum.peak_picking.ms_peak_picker_wrapper import MSPeakPicker
from depiction.tools.filter_peaks import FilterPeaksConfig, get_peak_filter


class PeakPickerBasicInterpolatedConfig(BaseModel):
    """Configuration selecting the ``BasicInterpolated`` peak picker.

    ``get_peak_picker`` forwards ``min_prominence``, ``min_distance`` and ``min_distance_unit`` unchanged
    to ``BasicInterpolatedPeakPicker``.
    """

    # Tag used by the ``peak_picker_type`` discriminator of ``PickPeaksConfig.peak_picker``.
    peak_picker_type: Literal["BasicInterpolated"]
    # Required, there is no default.
    min_prominence: float
    # Optional; the allowed units for the value are listed in ``min_distance_unit``.
    min_distance: int | float | None = None
    min_distance_unit: Literal["index", "mz"] | None = None

    # TODO ensure min_distance and min_distance_unit are either both present or both missing
    # (ideally we would just have a better typing support here and provide as tuple,
    # but postpone for later)


class PeakPickerMSPeakPickerConfig(BaseModel):
    """Configuration selecting the ``MSPeakPicker`` peak picker."""

    # Tag used by the ``peak_picker_type`` discriminator of ``PickPeaksConfig.peak_picker``.
    peak_picker_type: Literal["MSPeakPicker"]
    # Forwarded as ``fit_type`` to ``MSPeakPicker``; "quadratic" is the only accepted value.
    fit_type: Literal["quadratic"] = "quadratic"


class PeakPickerFindMFPyConfig(BaseModel):
    """Configuration selecting the ``FindMFPy`` peak picker.

    All fields other than the tag are optional and are forwarded unchanged by ``get_peak_picker`` as
    keyword arguments of ``FindMFPeakpicker``.
    """

    # Tag used by the ``peak_picker_type`` discriminator of ``PickPeaksConfig.peak_picker``.
    peak_picker_type: Literal["FindMFPy"]
    # NOTE(review): the units and exact meaning of the values below are defined by FindMFPeakpicker,
    # which is not visible here - confirm there before changing any default.
    resolution: float = 10000.0
    width: float = 2.0
    int_width: float = 2.0
    int_threshold: float = 10.0
    area: bool = True
    max_peaks: int = 0


class PickPeaksConfig(BaseModel, use_enum_values=True, validate_default=True):
    """Top-level configuration consumed by ``pick_peaks``."""

    # Required; the concrete config class is chosen from the ``peak_picker_type`` value.
    peak_picker: PeakPickerBasicInterpolatedConfig | PeakPickerMSPeakPickerConfig | PeakPickerFindMFPyConfig = Field(
        ..., discriminator="peak_picker_type"
    )
    # No default is declared, so the key has to be given explicitly (it may be null).
    # The value is handed to ``get_peak_filter`` by ``pick_peaks``.
    peak_filtering: FilterPeaksConfig | None
    # When False, ``pick_peaks`` copies inputs that are already in PROCESSED mode instead of picking again.
    force_peak_picker: bool = False
    # Passed as ``n_jobs`` to ``ParallelConfig``.
    n_jobs: int


class PickPeaks:
Expand Down Expand Up @@ -39,83 +82,77 @@ def _operation(
logger.warning(f"Dropped spectrum {spectrum_index} as no peaks were found")


def debug_diagnose_threshold_correspondence(
peak_filtering: FilterByIntensity,
peak_picker: BasicInterpolatedPeakPicker,
input_imzml: ImzmlReadFile,
n_points: int,
) -> None:
unfiltered_peak_picker = BasicInterpolatedPeakPicker(
min_prominence=peak_picker.min_prominence,
min_distance=peak_picker.min_distance,
min_distance_unit=peak_picker.min_distance_unit,
peak_filtering=None,
)

# TODO remove/consolidate this debugging functionality
with input_imzml.reader() as reader:
for i_spectrum in range(0, input_imzml.n_spectra, input_imzml.n_spectra // n_points):
spec_mz_arr, spec_int_arr = reader.get_spectrum(i_spectrum)
_, peak_int_arr = unfiltered_peak_picker.pick_peaks(spec_mz_arr, spec_int_arr)
peak_filtering.debug_diagnose_threshold_correspondence(
spectrum_int_arr=spec_int_arr, peak_int_arr=peak_int_arr
# def debug_diagnose_threshold_correspondence(
# peak_filtering: FilterByIntensity,
# peak_picker: BasicInterpolatedPeakPicker,
# input_imzml: ImzmlReadFile,
# n_points: int,
# ) -> None:
# unfiltered_peak_picker = BasicInterpolatedPeakPicker(
# min_prominence=peak_picker.min_prominence,
# min_distance=peak_picker.min_distance,
# min_distance_unit=peak_picker.min_distance_unit,
# peak_filtering=None,
# )
#
# # TODO remove/consolidate this debugging functionality
# with input_imzml.reader() as reader:
# for i_spectrum in range(0, input_imzml.n_spectra, input_imzml.n_spectra // n_points):
# spec_mz_arr, spec_int_arr = reader.get_spectrum(i_spectrum)
# _, peak_int_arr = unfiltered_peak_picker.pick_peaks(spec_mz_arr, spec_int_arr)
# peak_filtering.debug_diagnose_threshold_correspondence(
# spectrum_int_arr=spec_int_arr, peak_int_arr=peak_int_arr
# )


def get_peak_picker(config: PickPeaksConfig, peak_filtering: PeakFilteringType | None) -> Any:
match config.peak_picker:
case PeakPickerBasicInterpolatedConfig() as peak_picker_config:
return BasicInterpolatedPeakPicker(
min_prominence=peak_picker_config.min_prominence,
min_distance=peak_picker_config.min_distance,
min_distance_unit=peak_picker_config.min_distance_unit,
peak_filtering=peak_filtering,
)
case PeakPickerMSPeakPickerConfig() as peak_picker_config:
return MSPeakPicker(fit_type=peak_picker_config.fit_type, peak_filtering=peak_filtering)
case PeakPickerFindMFPyConfig() as peak_picker_config:
# TODO refactor this later?
# NOTE: importing this here since it has non-standard dependencies
from depiction.spectrum.peak_picking.findmf_peak_picker import FindMFPeakpicker

return FindMFPeakpicker(
resolution=peak_picker_config.resolution,
width=peak_picker_config.width,
int_width=peak_picker_config.int_width,
int_threshold=peak_picker_config.int_threshold,
area=peak_picker_config.area,
max_peaks=peak_picker_config.max_peaks,
)
case _:
raise ValueError(f"Unsupported peak picker type: {config.peak_picker.peak_picker_type}")


def pick_peaks(
    config: PickPeaksConfig,
    input_file: ImzmlReadFile,
    output_file: ImzmlWriteFile,
) -> None:
    """Pick peaks in every spectrum of ``input_file`` and write the result to ``output_file``.

    The input is copied unchanged instead when no peak picker is configured, or when the input is already
    in PROCESSED mode and ``config.force_peak_picker`` is not set.

    Args:
        config: Selects the peak picker, the optional peak filter and the number of parallel jobs.
        input_file: The imzML file to read.
        output_file: The imzML file to write.
    """
    skip_picking = config.peak_picker is None or (
        not config.force_peak_picker and input_file.imzml_mode == ImzmlModeEnum.PROCESSED
    )
    if skip_picking:
        copy_without_picking(input_imzml_path=input_file.imzml_file, output_imzml_path=output_file.imzml_file)
        return

    # The filter and the picker are only constructed once it is known that they are needed. Building
    # them up front made the ``peak_picker is None`` case unreachable (``get_peak_picker`` raises for it)
    # and imported optional picker dependencies even when the file was merely copied.
    peak_filtering = get_peak_filter(config.peak_filtering)
    peak_picker = get_peak_picker(config, peak_filtering)
    parallel_config = ParallelConfig(n_jobs=config.n_jobs)

    # Deliberately not called ``pick_peaks`` to avoid shadowing this function inside its own body.
    operation = PickPeaks(peak_picker=peak_picker, parallel_config=parallel_config)
    operation.evaluate_file(read_file=input_file, write_file=output_file)


def copy_without_picking(input_imzml_path: Path, output_imzml_path: Path) -> None:
    """Copy an imzML file and its ``.ibd`` companion to the output location without modification.

    Args:
        input_imzml_path: Path of the source ``.imzML`` file; the ``.ibd`` file is looked up next to it.
        output_imzml_path: Path of the target ``.imzML`` file; the ``.ibd`` file is written next to it.
    """
    # TODO this is duplicated in several places and should be unified, in fact it could be a method of ImzmlReadFile
    logger.info("Peak picking is deactivated")

    # Metadata file first, then the binary data file that shares its stem.
    shutil.copy(src=input_imzml_path, dst=output_imzml_path)
    ibd_source = input_imzml_path.with_suffix(".ibd")
    ibd_target = output_imzml_path.with_suffix(".ibd")
    shutil.copy(src=ibd_source, dst=ibd_target)
22 changes: 15 additions & 7 deletions src/depiction_targeted_preproc/pipeline_config/default.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,21 @@ baseline_correction:
baseline_type: TopHat
window_size: 3000
window_unit: ppm
peak_picker:
peak_picker_type: FindMFPy
force_peak_picker: no
peak_filtering:
method: FilterNHighestIntensityPartitioned
max_count: 200
n_partitions: 8
pick_peaks:
peak_picker:
peak_picker_type: FindMFPy
force_peak_picker: no
peak_filtering:
filters:
- method: FilterNHighestIntensityPartitioned
max_count: 200
n_partitions: 8
n_jobs: 10
filter_peaks:
filters:
- method: FilterNHighestIntensityPartitioned
max_count: 200
n_partitions: 8
#peak_picker:
# peak_picker_type: BasicInterpolated
# min_prominence: 0.5
Expand Down
54 changes: 5 additions & 49 deletions src/depiction_targeted_preproc/pipeline_config/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@

from enum import Enum
from pathlib import Path
from typing import Literal, Union, Annotated, Self
from typing import Literal, Annotated, Self

import yaml
from pydantic import BaseModel, Field, ConfigDict

from depiction.tools.correct_baseline import BaselineCorrectionConfig
from depiction.tools.filter_peaks import FilterPeaksConfig
from depiction.tools.pick_peaks import PickPeaksConfig


class Model(BaseModel):
Expand All @@ -20,38 +22,6 @@ def parse_yaml(cls, path: Path) -> Self:
return cls.model_validate(yaml.unsafe_load(path.read_text()))


class PeakPickerBasicInterpolated(BaseModel):
    """Settings for the ``BasicInterpolated`` variant of the ``PeakPicker`` union."""

    # Tag read by the ``peak_picker_type`` discriminator of ``PeakPicker``.
    peak_picker_type: Literal["BasicInterpolated"]
    # Required, there is no default.
    min_prominence: float
    # Optional; the allowed units for the value are listed in ``min_distance_unit``.
    min_distance: Union[int, float, None] = None
    min_distance_unit: Literal["index", "mz"] | None = None

    # TODO ensure min_distance and min_distance_unit are either both present or both missing
    # (ideally we would just have a better typing support here and provide as tuple,
    # but postpone for later)


class PeakPickerMSPeakPicker(BaseModel):
    """Settings for the ``MSPeakPicker`` variant of the ``PeakPicker`` union."""

    # Tag read by the ``peak_picker_type`` discriminator of ``PeakPicker``.
    peak_picker_type: Literal["MSPeakPicker"]
    # "quadratic" is the only accepted value and also the default.
    fit_type: Literal["quadratic"] = "quadratic"


class PeakPickerFindMFPy(BaseModel):
    """Settings for the ``FindMFPy`` variant of the ``PeakPicker`` union.

    Every field except the tag has a default, so a config may consist of ``peak_picker_type`` alone.
    """

    # Tag read by the ``peak_picker_type`` discriminator of ``PeakPicker``.
    peak_picker_type: Literal["FindMFPy"]
    # NOTE(review): units and exact meaning of the values below are defined by the FindMFPy peak picker
    # implementation, which is not visible here - confirm there before changing any default.
    resolution: float = 10000.0
    width: float = 2.0
    int_width: float = 2.0
    int_threshold: float = 10.0
    area: bool = True
    max_peaks: int = 0


PeakPicker = Annotated[
None | PeakPickerBasicInterpolated | PeakPickerMSPeakPicker | PeakPickerFindMFPy,
Field(discriminator="peak_picker_type"),
]


class CalibrationRegressShift(BaseModel):
calibration_method: Literal["RegressShift"]

Expand Down Expand Up @@ -111,25 +81,11 @@ class PipelineArtifact(str, Enum):
DEBUG = "DEBUG"


class FilterNHighestIntensityPartitioned(BaseModel):
    """Settings for the ``FilterNHighestIntensityPartitioned`` peak filter; all fields are required."""

    # Fixed tag naming the filter method.
    method: Literal["FilterNHighestIntensityPartitioned"]
    # NOTE(review): presumably the maximum number of peaks kept and the number of partitions the
    # spectrum is split into - confirm against the filter implementation.
    max_count: int
    n_partitions: int


# PeakFiltering = Annotated[FilterNHighestIntensityPartitioned, Field(discriminator="method")] |None
PeakFiltering = FilterNHighestIntensityPartitioned | None


class PipelineParametersPreset(Model, use_enum_values=True, validate_default=True):
baseline_correction: BaselineCorrectionConfig
filter_peaks: FilterPeaksConfig
calibration: Calibration
peak_picker: PeakPicker
peak_filtering: PeakFiltering
force_peak_picker: bool


# class PipelineParameters(PipelineParametersPreset, use_enum_values=True):
pick_peaks: PickPeaksConfig


class PipelineParameters(PipelineParametersPreset, use_enum_values=True, validate_default=True):
Expand Down
Loading

0 comments on commit b207a66

Please sign in to comment.