diff --git a/src/sim_recon/__init__.py b/src/sim_recon/__init__.py
index aaf509f..b0e4d51 100644
--- a/src/sim_recon/__init__.py
+++ b/src/sim_recon/__init__.py
@@ -2,11 +2,17 @@ if __name__ == "__main__":
-from .main import sim_reconstruct, sim_psf_to_otf, sim_reconstruct_single
+from .main import (
+    sim_reconstruct,
+    sim_reconstruct_multiple,
+    sim_reconstruct_single,
+    sim_psf_to_otf,
+)
 
 __all__ = [
     "__version__",
     "sim_reconstruct",
+    "sim_reconstruct_multiple",
     "sim_reconstruct_single",
     "sim_psf_to_otf",
 ]
diff --git a/src/sim_recon/main.py b/src/sim_recon/main.py
index 664c68c..756b90b 100644
--- a/src/sim_recon/main.py
+++ b/src/sim_recon/main.py
@@ -10,7 +10,7 @@
 )
 from .settings import ConfigManager
 from .otfs import convert_psfs_to_otfs
-from .recon import run_reconstructions, run_single_reconstruction
+from .recon import reconstruct_multiple, reconstruct_single
 
 
 if TYPE_CHECKING:
@@ -104,11 +104,18 @@ def sim_reconstruct(
     stitch_channels: bool = True,
     allow_missing_channels: bool = False,
     output_file_type: OutputFileTypes = "dv",
+    multiprocessing_pool: Pool | None = None,
     parallel_process: bool = False,
     **recon_kwargs: Any,
 ) -> None:
     """
-    Top level function for reconstructing SIM data
+    Top level function for reconstructing SIM data.
+
+    The handling of `processing_directory` depends on the number of `sim_data_paths`:
+    - For `len(sim_data_paths) == 1`, `sim_reconstruct_single` is used.
+    - For `len(sim_data_paths) > 1`, `sim_reconstruct_multiple` is used.
+
+    For consistent behaviour, one of those functions can be used instead.
 
     Parameters
     ----------
@@ -128,16 +135,88 @@ def sim_reconstruct(
         Clean up temporary directory and files after reconstruction, by default True
     stitch_channels : bool, optional
         Stitch channels back together after processing (otherwise output will be a separate DV per channel), by default True
+    allow_missing_channels: bool, optional
+        Attempt reconstruction of other channels in a multi-channel file if one or more are not configured, by default False
+    output_file_type: Literal["dv", "tiff"], optional
+        File type that output images will be saved as, by default "dv"
     parallel_process : bool, optional
         Run reconstructions in 2 processes concurrently, by default False
+    """
+    kwargs: dict[str, Any] = {
+        "config_path": config_path,
+        "output_directory": output_directory,
+        "processing_directory": processing_directory,
+        "otf_overrides": otf_overrides,
+        "overwrite": overwrite,
+        "cleanup": cleanup,
+        "stitch_channels": stitch_channels,
+        "allow_missing_channels": allow_missing_channels,
+        "output_file_type": output_file_type,
+        "multiprocessing_pool": multiprocessing_pool,
+        "parallel_process": parallel_process,
+    }
+    if len(sim_data_paths) == 1:
+        sim_reconstruct_single(
+            sim_data_paths[0],
+            **kwargs,
+            **recon_kwargs,
+        )
+    else:
+        sim_reconstruct_multiple(
+            *sim_data_paths,
+            **kwargs,
+            **recon_kwargs,
+        )
+
+
+def sim_reconstruct_multiple(
+    *sim_data_paths: str | PathLike[str],
+    config_path: str | PathLike[str] | None = None,
+    output_directory: str | PathLike[str] | None = None,
+    processing_directory: str | PathLike[str] | None = None,
+    otf_overrides: dict[int, Path] | None = None,
+    overwrite: bool = False,
+    cleanup: bool = True,
+    stitch_channels: bool = True,
+    allow_missing_channels: bool = False,
+    output_file_type: OutputFileTypes = "dv",
+    multiprocessing_pool: Pool | None = None,
+    parallel_process: bool = False,
+    **recon_kwargs: Any,
+) -> None:
+    """
+    Top level function for reconstructing multiple SIM data files.
+
+    Parameters
+    ----------
+    *sim_data_paths : str | PathLike[str]
+        Paths to SIM data files (DV expected)
+    config_path : str | PathLike[str] | None, optional
+        Path of the top level config file, by default None
+    output_directory : str | PathLike[str] | None, optional
+        Directory to save reconstructions in (reconstructions will be saved with the data files if not specified), by default None
+    processing_directory : str | PathLike[str] | None, optional
+        The directory in which a subdirectory of temporary processing files will be created for each of `sim_data_paths` (otherwise the output directory will be used), by default None
+    otf_overrides : dict[int, Path] | None, optional
+        A dictionary with emission wavelengths in nm as keys and paths to OTF files as values (these override configured OTFs), by default None
+    overwrite : bool, optional
+        Overwrite files if they already exist, by default False
+    cleanup : bool, optional
+        Clean up temporary directory and files after reconstruction, by default True
     allow_missing_channels: bool, optional
         Attempt reconstruction of other channels in a multi-channel file if one or more are not configured, by default False
     output_file_type: Literal["dv", "tiff"], optional
         File type that output images will be saved as, by default "dv"
+    stitch_channels : bool, optional
+        Stitch channels back together after processing (otherwise output will be a separate DV per channel), by default True
+    multiprocessing_pool : Pool | None, optional
+        Multiprocessing pool to run cudasirecon in (`maxtasksperchild=1` is recommended to avoid crashes), by default None
+    parallel_process : bool, optional
+        Run reconstructions in 2 processes concurrently, by default False
     """
     conf = load_configs(config_path, otf_overrides=otf_overrides)
     logger.info("Starting reconstructions...")
-    run_reconstructions(
+    reconstruct_multiple(
         conf,
         *sim_data_paths,
         output_directory=output_directory,
@@ -147,6 +226,7 @@ def sim_reconstruct(
         stitch_channels=stitch_channels,
         allow_partial=allow_missing_channels,
         output_file_type=output_file_type,
+        multiprocessing_pool=multiprocessing_pool,
         parallel_process=parallel_process,
         **recon_kwargs,
     )
@@ -154,6 +234,7 @@ def sim_reconstruct(
 
 def sim_reconstruct_single(
     sim_data_path: str | PathLike[str],
+    *,
     config_path: str | PathLike[str] | None = None,
     output_directory: str | PathLike[str] | None = None,
     processing_directory: str | PathLike[str] | None = None,
@@ -168,7 +249,7 @@ def sim_reconstruct_single(
     **recon_kwargs: Any,
 ) -> None:
     """
-    Top level function for reconstructing SIM data
+    Top level function for reconstructing a single SIM data file.
 
     Parameters
     ----------
@@ -179,7 +260,7 @@ def sim_reconstruct_single(
     output_directory : str | PathLike[str] | None, optional
         Directory to save reconstructions in (reconstructions will be saved with the data files if not specified), by default None
     processing_directory : str | PathLike[str] | None, optional
-        The directory in which the temporary files will be stored for processing (otherwise the output directory will be used), by default None
+        The directory in which the temporary files will be stored for processing (otherwise a subdirectory of the output directory will be used), by default None
     otf_overrides : dict[int, Path] | None, optional
         A dictionary with emission wavelengths in nm as keys and paths to OTF files as values (these override configured OTFs), by default None
     overwrite : bool, optional
@@ -199,7 +280,7 @@ def sim_reconstruct_single(
     """
     conf = load_configs(config_path, otf_overrides=otf_overrides)
     logger.info("Starting reconstruction of %s", sim_data_path)
-    run_single_reconstruction(
+    reconstruct_single(
         conf,
         sim_data_path,
         output_directory=output_directory,
diff --git a/src/sim_recon/recon.py b/src/sim_recon/recon.py
index aa7d7f3..9737265 100644
--- a/src/sim_recon/recon.py
+++ b/src/sim_recon/recon.py
@@ -2,7 +2,6 @@
 import logging
 import subprocess
 import multiprocessing
-import os
 import traceback
 from functools import partial
 from os.path import abspath
@@ -79,7 +78,7 @@ def _recon_get_result(
     return _result
 
 
-def subprocess_recon(
+def subprocess_cudasirecon_reconstruct(
     sim_path: Path, otf_path: Path, config_path: Path
 ) -> NDArray[np.float32]:
     """Useful to bypass the pycudasirecon library, if necessary"""
@@ -104,7 +103,7 @@ def subprocess_recon(
     raise ReconstructionError(f"No reconstruction file found at {expected_path}")
 
 
-def reconstruct(
+def cudasirecon_reconstruct(
     array: NDArray[Any],
     config_path: str | PathLike[str],
     zoomfact: float,
@@ -174,14 +173,9 @@ def reconstruct_from_processing_info(processing_info: ProcessingInfo) -> Proces
             )
         )
     )
-    rec_array = reconstruct(
+    rec_array = cudasirecon_reconstruct(
         data, processing_info.config_path, zoomfact, zzoom, ndirs, nphases
     )
-    # rec_array = subprocess_recon(
-    #     processing_info.image_path,
-    #     processing_info.otf_path,
-    #     processing_info.config_path,
-    # )
 
     if rec_array is None:
         raise ReconstructionError(
@@ -364,7 +358,17 @@ def _get_incomplete_channels(
     return incomplete_wavelengths
 
 
-def run_single_reconstruction(
+def create_process_pool(parallel_process: bool = False) -> Pool:
+    # `maxtasksperchild=1` is necessary to ensure the child process is cleaned
+    # up between tasks, as the cudasirecon process doesn't fully release memory
+    # afterwards
+    return multiprocessing.Pool(
+        processes=2 if parallel_process else 1,  # 2 processes max
+        maxtasksperchild=1,
+    )
+
+
+def reconstruct_single(
     conf: ConfigManager,
     sim_data_path: str | PathLike[str],
     *,
@@ -381,13 +385,7 @@ def run_single_reconstruction(
 ) -> None:
 
     if multiprocessing_pool is None:
-        # `maxtasksperchild=1` is necessary to ensure the child process is cleaned
-        # up between tasks, as the cudasirecon process doesn't fully release memory
-        # afterwards
-        pool = multiprocessing.Pool(
-            processes=2 if parallel_process else 1,  # 2 processes max
-            maxtasksperchild=1,
-        )
+        pool = create_process_pool(parallel_process)
     else:
         pool = multiprocessing_pool
 
@@ -542,7 +540,7 @@ def run_single_reconstruction(
             pool.close()
 
 
-def run_reconstructions(
+def reconstruct_multiple(
     conf: ConfigManager,
     *sim_data_paths: str | PathLike[str],
     output_directory: str | PathLike[str] | None,
@@ -552,50 +550,54 @@
     stitch_channels: bool = True,
     allow_missing_channels: bool = False,
     output_file_type: OutputFileTypes = "dv",
+    multiprocessing_pool: Pool | None = None,
     parallel_process: bool = False,
     **config_kwargs: Any,
 ) -> None:
 
-    logging_redirect = get_logging_redirect()
-    progress_wrapper = get_progress_wrapper()
+    if multiprocessing_pool is None:
+        pool = create_process_pool(parallel_process)
+    else:
+        pool = multiprocessing_pool
 
-    # `maxtasksperchild=1` is necessary to ensure the child process is cleaned
-    # up between tasks, as the cudasirecon process doesn't fully release memory
-    # afterwards
-    with (
-        multiprocessing.Pool(
-            processes=2 if parallel_process else 1,  # 2 processes max
-            maxtasksperchild=1,
-        ) as pool,
-        logging_redirect(),
-        delete_directory_if_empty(processing_directory),
-    ):
-        for sim_data_path in progress_wrapper(
-            sim_data_paths, desc="SIM data files", unit="file"
-        ):
-            sim_data_path = Path(sim_data_path)
+    try:
+        logging_redirect = get_logging_redirect()
+        progress_wrapper = get_progress_wrapper()
 
-            if processing_directory is None:
-                proc_dir = None
-            else:
-                # For multiple reconstructions sharing the same processing directory
-                # Use subdirectories from the data path stem
-                proc_dir = Path(processing_directory) / sim_data_path.stem
+        with (
+            logging_redirect(),
+            delete_directory_if_empty(processing_directory),
+        ):
+            for sim_data_path in progress_wrapper(
+                sim_data_paths, desc="SIM data files", unit="file"
+            ):
+                sim_data_path = Path(sim_data_path)
 
-            run_single_reconstruction(
-                conf,
-                sim_data_path,
-                output_directory=output_directory,
-                processing_directory=proc_dir,
-                overwrite=overwrite,
-                cleanup=cleanup,
-                stitch_channels=stitch_channels,
-                allow_missing_channels=allow_missing_channels,
-                output_file_type=output_file_type,
-                multiprocessing_pool=pool,
-                parallel_process=parallel_process,
-                **config_kwargs,
-            )
+                if processing_directory is None:
+                    proc_dir = None
+                else:
+                    # For multiple reconstructions sharing the same processing directory
+                    # Use subdirectories from the data path stem
+                    proc_dir = Path(processing_directory) / sim_data_path.stem
+
+                reconstruct_single(
+                    conf,
+                    sim_data_path,
+                    output_directory=output_directory,
+                    processing_directory=proc_dir,
+                    overwrite=overwrite,
+                    cleanup=cleanup,
+                    stitch_channels=stitch_channels,
+                    allow_missing_channels=allow_missing_channels,
+                    output_file_type=output_file_type,
+                    multiprocessing_pool=pool,
+                    parallel_process=parallel_process,
+                    **config_kwargs,
+                )
+    finally:
+        if multiprocessing_pool is None:
+            # Only close pools that were created by this function
+            pool.close()
 
 
 def _prepare_config_kwargs(