Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use single reconstruction function from CLI as appropriate #90

Merged
merged 8 commits into from
Nov 15, 2024
Merged
8 changes: 7 additions & 1 deletion src/sim_recon/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,17 @@


if __name__ == "__main__":
from .main import sim_reconstruct, sim_psf_to_otf, sim_reconstruct_single
from .main import (
sim_reconstruct,
sim_reconstruct_multiple,
sim_reconstruct_single,
sim_psf_to_otf,
)

__all__ = [
"__version__",
"sim_reconstruct",
"sim_reconstruct_multiple",
"sim_reconstruct_single",
"sim_psf_to_otf",
]
Expand Down
93 changes: 87 additions & 6 deletions src/sim_recon/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
)
from .settings import ConfigManager
from .otfs import convert_psfs_to_otfs
from .recon import run_reconstructions, run_single_reconstruction
from .recon import reconstruct_multiple, reconstruct_single


if TYPE_CHECKING:
Expand Down Expand Up @@ -104,11 +104,18 @@ def sim_reconstruct(
stitch_channels: bool = True,
allow_missing_channels: bool = False,
output_file_type: OutputFileTypes = "dv",
multiprocessing_pool: Pool | None = None,
parallel_process: bool = False,
**recon_kwargs: Any,
) -> None:
"""
Top level function for reconstructing SIM data
Top level function for reconstructing SIM data.

The handling of `processing_directory` depends on the number of `sim_data_paths`:
- For `len(sim_data_paths) == 1`, `sim_reconstruct_single` is used.
- For `len(sim_data_paths) > 1`, `sim_reconstruct_multiple` is used.

For consistent behaviour, one of those functions can be used instead.

Parameters
----------
Expand All @@ -128,16 +135,88 @@ def sim_reconstruct(
Clean up temporary directory and files after reconstruction, by default True
stitch_channels : bool, optional
Stitch channels back together after processing (otherwise output will be a separate DV per channel), by default True
allow_missing_channels: bool, optional
Attempt reconstruction of other channels in a multi-channel file if one or more are not configured, by default False
output_file_type: Literal["dv", "tiff"], optional
File type that output images will be saved as, by default "dv"
parallel_process : bool, optional
Run reconstructions in 2 processes concurrently, by default False
"""
kwargs: dict[str, Any] = {
"config_path": config_path,
"output_directory": output_directory,
"processing_directory": processing_directory,
"otf_overrides": otf_overrides,
"overwrite": overwrite,
"cleanup": cleanup,
"stitch_channels": stitch_channels,
"allow_missing_channels": allow_missing_channels,
"output_file_type": output_file_type,
"multiprocessing_pool": multiprocessing_pool,
"parallel_process": parallel_process,
}
if len(sim_data_paths) == 1:
sim_reconstruct_single(
sim_data_paths[0],
**kwargs,
**recon_kwargs,
)
else:
sim_reconstruct_multiple(
*sim_data_paths,
**kwargs,
**recon_kwargs,
)


def sim_reconstruct_multiple(
*sim_data_paths: str | PathLike[str],
config_path: str | PathLike[str] | None = None,
output_directory: str | PathLike[str] | None = None,
processing_directory: str | PathLike[str] | None = None,
otf_overrides: dict[int, Path] | None = None,
overwrite: bool = False,
cleanup: bool = True,
stitch_channels: bool = True,
allow_missing_channels: bool = False,
output_file_type: OutputFileTypes = "dv",
multiprocessing_pool: Pool | None = None,
parallel_process: bool = False,
**recon_kwargs: Any,
) -> None:
"""
Top level function for reconstructing multiple SIM data files.

Parameters
----------
*sim_data_paths : str | PathLike[str]
Paths to SIM data files (DV expected)
config_path : str | PathLike[str] | None, optional
Path of the top level config file, by default None
output_directory : str | PathLike[str] | None, optional
Directory to save reconstructions in (reconstructions will be saved with the data files if not specified), by default None
processing_directory : str | PathLike[str] | None, optional
The directory in which a subdirectory containing temporary files will be stored for each of `sim_data_paths` for processing (otherwise the output directory will be used), by default None
otf_overrides : dict[int, Path] | None, optional
A dictionary with emission wavelengths in nm as keys and paths to OTF files as values (these override configured OTFs), by default None
overwrite : bool, optional
Overwrite files if they already exist, by default False
cleanup : bool, optional
Clean up temporary directory and files after reconstruction, by default True
allow_missing_channels: bool, optional
Attempt reconstruction of other channels in a multi-channel file if one or more are not configured, by default False
output_file_type: Literal["dv", "tiff"], optional
File type that output images will be saved as, by default "dv"
stitch_channels : bool, optional
Stitch channels back together after processing (otherwise output will be a separate DV per channel), by default True
multiprocessing_pool : Pool | None, optional
Multiprocessing pool to run cudasirecon in (`maxtasksperchild=1` is recommended to avoid crashes), by default None
parallel_process : bool, optional
Run reconstructions in 2 processes concurrently, by default False
"""
conf = load_configs(config_path, otf_overrides=otf_overrides)
logger.info("Starting reconstructions...")
run_reconstructions(
reconstruct_multiple(
conf,
*sim_data_paths,
output_directory=output_directory,
Expand All @@ -147,13 +226,15 @@ def sim_reconstruct(
stitch_channels=stitch_channels,
allow_partial=allow_missing_channels,
output_file_type=output_file_type,
multiprocessing_pool=multiprocessing_pool,
parallel_process=parallel_process,
**recon_kwargs,
)


def sim_reconstruct_single(
sim_data_path: str | PathLike[str],
*,
config_path: str | PathLike[str] | None = None,
output_directory: str | PathLike[str] | None = None,
processing_directory: str | PathLike[str] | None = None,
Expand All @@ -168,7 +249,7 @@ def sim_reconstruct_single(
**recon_kwargs: Any,
) -> None:
"""
Top level function for reconstructing SIM data
Top level function for reconstructing a single SIM data file.

Parameters
----------
Expand All @@ -179,7 +260,7 @@ def sim_reconstruct_single(
output_directory : str | PathLike[str] | None, optional
Directory to save reconstructions in (reconstructions will be saved with the data files if not specified), by default None
processing_directory : str | PathLike[str] | None, optional
The directory in which the temporary files will be stored for processing (otherwise the output directory will be used), by default None
The directory in which the temporary files will be stored for processing (otherwise a subdirectory of output directory will be used), by default None
otf_overrides : dict[int, Path] | None, optional
A dictionary with emission wavelengths in nm as keys and paths to OTF files as values (these override configured OTFs), by default None
overwrite : bool, optional
Expand All @@ -199,7 +280,7 @@ def sim_reconstruct_single(
"""
conf = load_configs(config_path, otf_overrides=otf_overrides)
logger.info("Starting reconstruction of %s", sim_data_path)
run_single_reconstruction(
reconstruct_single(
conf,
sim_data_path,
output_directory=output_directory,
Expand Down
112 changes: 57 additions & 55 deletions src/sim_recon/recon.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import logging
import subprocess
import multiprocessing
import os
import traceback
from functools import partial
from os.path import abspath
Expand Down Expand Up @@ -79,7 +78,7 @@ def _recon_get_result(
return _result


def subprocess_recon(
def subprocess_cudasirecon_reconstruct(
sim_path: Path, otf_path: Path, config_path: Path
) -> NDArray[np.float32]:
"""Useful to bypass the pycudasirecon library, if necessary"""
Expand All @@ -104,7 +103,7 @@ def subprocess_recon(
raise ReconstructionError(f"No reconstruction file found at {expected_path}")


def reconstruct(
def cudasirecon_reconstruct(
array: NDArray[Any],
config_path: str | PathLike[str],
zoomfact: float,
Expand Down Expand Up @@ -174,14 +173,9 @@ def reconstruct_from_processing_info(processing_info: ProcessingInfo) -> Process
)
)
)
rec_array = reconstruct(
rec_array = cudasirecon_reconstruct(
data, processing_info.config_path, zoomfact, zzoom, ndirs, nphases
)
# rec_array = subprocess_recon(
# processing_info.image_path,
# processing_info.otf_path,
# processing_info.config_path,
# )

if rec_array is None:
raise ReconstructionError(
Expand Down Expand Up @@ -364,7 +358,17 @@ def _get_incomplete_channels(
return incomplete_wavelengths


def run_single_reconstruction(
def create_process_pool(parallel_process: bool = False) -> Pool:
# `maxtasksperchild=1` is necessary to ensure the child process is cleaned
# up between tasks, as the cudasirecon process doesn't fully release memory
# afterwards
return multiprocessing.Pool(
processes=2 if parallel_process else 1, # 2 processes max
maxtasksperchild=1,
)


def reconstruct_single(
conf: ConfigManager,
sim_data_path: str | PathLike[str],
*,
Expand All @@ -381,13 +385,7 @@ def run_single_reconstruction(
) -> None:

if multiprocessing_pool is None:
# `maxtasksperchild=1` is necessary to ensure the child process is cleaned
# up between tasks, as the cudasirecon process doesn't fully release memory
# afterwards
pool = multiprocessing.Pool(
processes=2 if parallel_process else 1, # 2 processes max
maxtasksperchild=1,
)
pool = create_process_pool(parallel_process)
else:
pool = multiprocessing_pool

Expand Down Expand Up @@ -542,7 +540,7 @@ def run_single_reconstruction(
pool.close()


def run_reconstructions(
def reconstruct_multiple(
conf: ConfigManager,
*sim_data_paths: str | PathLike[str],
output_directory: str | PathLike[str] | None,
Expand All @@ -552,50 +550,54 @@ def run_reconstructions(
stitch_channels: bool = True,
allow_missing_channels: bool = False,
output_file_type: OutputFileTypes = "dv",
multiprocessing_pool: Pool | None = None,
parallel_process: bool = False,
**config_kwargs: Any,
) -> None:

logging_redirect = get_logging_redirect()
progress_wrapper = get_progress_wrapper()
if multiprocessing_pool is None:
pool = create_process_pool(parallel_process)
else:
pool = multiprocessing_pool

# `maxtasksperchild=1` is necessary to ensure the child process is cleaned
# up between tasks, as the cudasirecon process doesn't fully release memory
# afterwards
with (
multiprocessing.Pool(
processes=2 if parallel_process else 1, # 2 processes max
maxtasksperchild=1,
) as pool,
logging_redirect(),
delete_directory_if_empty(processing_directory),
):
for sim_data_path in progress_wrapper(
sim_data_paths, desc="SIM data files", unit="file"
):
sim_data_path = Path(sim_data_path)
try:
logging_redirect = get_logging_redirect()
progress_wrapper = get_progress_wrapper()

if processing_directory is None:
proc_dir = None
else:
# For multiple reconstructions sharing the same processing directory
# Use subdirectories from the data path stem
proc_dir = Path(processing_directory) / sim_data_path.stem
with (
logging_redirect(),
delete_directory_if_empty(processing_directory),
):
for sim_data_path in progress_wrapper(
sim_data_paths, desc="SIM data files", unit="file"
):
sim_data_path = Path(sim_data_path)

run_single_reconstruction(
conf,
sim_data_path,
output_directory=output_directory,
processing_directory=proc_dir,
overwrite=overwrite,
cleanup=cleanup,
stitch_channels=stitch_channels,
allow_missing_channels=allow_missing_channels,
output_file_type=output_file_type,
multiprocessing_pool=pool,
parallel_process=parallel_process,
**config_kwargs,
)
if processing_directory is None:
proc_dir = None
else:
# For multiple reconstructions sharing the same processing directory
# Use subdirectories from the data path stem
proc_dir = Path(processing_directory) / sim_data_path.stem

reconstruct_single(
conf,
sim_data_path,
output_directory=output_directory,
processing_directory=proc_dir,
overwrite=overwrite,
cleanup=cleanup,
stitch_channels=stitch_channels,
allow_missing_channels=allow_missing_channels,
output_file_type=output_file_type,
multiprocessing_pool=pool,
parallel_process=parallel_process,
**config_kwargs,
)
finally:
if multiprocessing_pool is None:
# Only close pools that were created by this function
pool.close()


def _prepare_config_kwargs(
Expand Down