diff --git a/src/sim_recon/images/dataclasses.py b/src/sim_recon/images/dataclasses.py index bf7097d..883cd98 100644 --- a/src/sim_recon/images/dataclasses.py +++ b/src/sim_recon/images/dataclasses.py @@ -2,6 +2,7 @@ import logging from dataclasses import dataclass, field +from enum import StrEnum, auto from pathlib import Path from typing import TYPE_CHECKING, TypeVar, Generic @@ -15,13 +16,20 @@ logger = logging.getLogger(__name__) +class ProcessingStatus(StrEnum): + UNSET = auto() + PENDING = auto() + FAILED = auto() + COMPLETE = auto() + + @dataclass(frozen=True, slots=True) class BoundMrc: array: NDArray[Any] mrc: Mrc -@dataclass(frozen=True) +@dataclass(slots=True) class ProcessingInfo: image_path: Path otf_path: Path @@ -30,6 +38,7 @@ class ProcessingInfo: log_path: Path wavelengths: Wavelengths kwargs: dict[str, Any] + status: ProcessingStatus = ProcessingStatus.UNSET @dataclass(slots=True) @@ -65,4 +74,4 @@ def __post_init__(self): object.__setattr__(self, "emission_nm_int", emission_nm_int) def __str__(self): - return f"excitation: {self.excitation_nm}nm; emission: {self.emission_nm}nm" + return f"{self.emission_nm_int} (excitation: {self.excitation_nm}nm; emission: {self.emission_nm}nm)" diff --git a/src/sim_recon/images/tiff.py b/src/sim_recon/images/tiff.py index 360a689..c28170f 100644 --- a/src/sim_recon/images/tiff.py +++ b/src/sim_recon/images/tiff.py @@ -7,10 +7,15 @@ from typing import TYPE_CHECKING, cast from ..info import __version__ -from ..exceptions import PySimReconFileExistsError, PySimReconValueError +from ..exceptions import ( + PySimReconFileExistsError, + PySimReconValueError, + UndefinedValueError, + PySimReconIOError, +) if TYPE_CHECKING: - from typing import Any + from typing import Any, Generator from os import PathLike from numpy.typing import NDArray from .dataclasses import ImageChannel @@ -33,6 +38,17 @@ def read_tiff(filepath: str | PathLike[str]) -> NDArray[Any]: return tiff.asarray() +def generate_memmaps_from_tiffs( + 
*file_paths: str | PathLike[str], +) -> Generator[NDArray[Any], None, None]: + for fp in file_paths: + try: + yield tf.memmap(fp).squeeze() + except Exception as e: + logger.error("Unable to read image from %s: %s", fp, e) + raise + + def get_combined_array_from_tiffs( *file_paths: str | PathLike[str], ) -> NDArray[Any]: @@ -42,7 +58,10 @@ def get_combined_array_from_tiffs( ) if not file_paths: raise PySimReconValueError("Cannot create a combined array without files") - return np.stack(tuple(tf.memmap(fp).squeeze() for fp in file_paths), -3) + try: + return np.stack(tuple(generate_memmaps_from_tiffs(*file_paths)), -3) + except Exception: + raise PySimReconIOError("Failed to combine TIFF files") def write_tiff( @@ -51,6 +70,7 @@ def write_tiff( xy_pixel_size_microns: tuple[float | None, float | None] | None = None, ome: bool = True, overwrite: bool = False, + allow_empty_channels: bool = False, ) -> None: def get_channel_dict(channel: ImageChannel) -> dict[str, Any] | None: channel_dict: dict[str, Any] = {} @@ -111,8 +131,15 @@ def get_channel_dict(channel: ImageChannel) -> dict[str, Any] | None: ) as tiff: for channel in channels: if channel.array is None: - logger.warning("Channel %s has no array to write", channel.wavelengths) - continue + if allow_empty_channels: + logger.warning( + "Channel %s has no array to write", + channel.wavelengths, + ) + continue + raise UndefinedValueError( + f"{output_path} will not be created as channel {channel.wavelengths} has no array to write", + ) channel_kwargs = tiff_kwargs.copy() channel_kwargs["metadata"]["axes"] = ( "YX" if channel.array.ndim == 2 else "ZYX" diff --git a/src/sim_recon/recon.py b/src/sim_recon/recon.py index 07f4530..1665137 100644 --- a/src/sim_recon/recon.py +++ b/src/sim_recon/recon.py @@ -4,6 +4,7 @@ import multiprocessing import os import traceback +from functools import partial from os.path import abspath from shutil import copyfile from pathlib import Path @@ -23,7 +24,12 @@ write_tiff, 
get_combined_array_from_tiffs, ) -from .images.dataclasses import ImageChannel, Wavelengths, ProcessingInfo +from .images.dataclasses import ( + ImageChannel, + Wavelengths, + ProcessingInfo, + ProcessingStatus, +) from .settings import ConfigManager from .settings.formatting import ( formatters_to_default_value_kwargs, @@ -113,19 +119,25 @@ def reconstruct( return _recon_get_result(reconstructor, output_shape=(z, y, x)) # return reconstructor.get_result() - except Exception: - # Unlikely to ever hit this as errors from the C++ just kill the process + + # Rare to ever hit this as errors from the C++ often just kill the process + except OSError as e: + logger.error("Reconstruction failed: %s", e) + raise ReconstructionError( + f"Error during reconstruction with config {config_path}: {e}" + ) + except Exception as e: logger.error( - "Exception raised during reconstruction with config %s", + "Unexpected error during reconstruction with config %s", config_path, exc_info=True, ) raise ReconstructionError( - f"Exception raised during reconstruction with config {config_path}" + f"Unexpected error during reconstruction with config {config_path}: {e}" ) -def reconstruct_from_processing_info(processing_info: ProcessingInfo) -> Path: +def reconstruct_from_processing_info(processing_info: ProcessingInfo) -> ProcessingInfo: logger.info( "Starting reconstruction of %s with %s to be saved as %s", processing_info.image_path, @@ -140,14 +152,16 @@ def reconstruct_from_processing_info(processing_info: ProcessingInfo) -> Path: data = read_tiff(processing_info.image_path) # Cannot use a memmap here!
+ rec_array = None with redirect_output_to(processing_info.log_path): + sep = "-" * 80 print( "\n".join( ( - f"Channel {processing_info.wavelengths.emission_nm_int} ({processing_info.wavelengths})", + f"Channel {processing_info.wavelengths}", "Config used:", processing_info.config_path.read_text().strip(), - "-" * 80, + sep, "The text below is the output from cudasirecon", ) ) @@ -155,12 +169,21 @@ def reconstruct_from_processing_info(processing_info: ProcessingInfo) -> Path: rec_array = reconstruct( data, processing_info.config_path, zoomfact, zzoom, ndirs, nphases ) + # rec_array = subprocess_recon( + # processing_info.image_path, + # processing_info.otf_path, + # processing_info.config_path, + # ) + + if rec_array is None: + raise ReconstructionError( + f"No image was returned from reconstruction with {processing_info.config_path}" + ) + elif np.isnan(rec_array).all(): + raise ReconstructionError( + f"Empty (NaN) image was returned from reconstruction with {processing_info.config_path}" + ) - # rec_array = subprocess_recon( - # processing_info.image_path, - # processing_info.otf_path, - # processing_info.config_path, - # ) logger.info("Reconstructed %s", processing_info.image_path) recon_pixel_size = float(processing_info.kwargs["xyres"]) / zoomfact write_tiff( @@ -174,7 +197,8 @@ def reconstruct_from_processing_info(processing_info: ProcessingInfo) -> Path: processing_info.image_path, processing_info.output_path, ) - return Path(processing_info.output_path) + processing_info.status = ProcessingStatus.COMPLETE + return processing_info def _processing_files_to_output( @@ -215,7 +239,11 @@ def _processing_files_to_output( header="Reconstruction log", ) - output_files = tuple(pi.output_path for pi in processing_info_dict.values()) + output_files = tuple( + pi.output_path + for pi in processing_info_dict.values() + if pi.status == ProcessingStatus.COMPLETE + ) if not output_files: logger.warning( "No reconstructions were created from %s", @@ -239,6 +267,8 @@ def 
_processing_files_to_output( wavelength, processing_info, ) in processing_info_dict.items(): + if processing_info.status != ProcessingStatus.COMPLETE: + continue dv_path = create_output_path( sim_data_path, output_type="recon", @@ -264,6 +294,52 @@ def _processing_files_to_output( ) +def _reconstruction_process_callback( + processing_info: ProcessingInfo, + wavelength: int, + processing_info_dict: dict[int, ProcessingInfo], +) -> None: + logger.debug("Channel %i process complete", wavelength) + processing_info_dict[wavelength].status = processing_info.status + + +def _reconstruction_process_error_callback( + exception: BaseException, + sim_data_path: str | PathLike[str], + wavelength: int, + processing_info: ProcessingInfo, +) -> None: + processing_info.status = ProcessingStatus.FAILED + if isinstance(exception, PySimReconException): + exception_str = str(exception) + else: + exception_str = "".join(traceback.format_exception(exception)) + logger.error( + # exc_info doesn't work with the callback + "Error occurred during reconstruction of %s channel %i: %s", + sim_data_path, + wavelength, + exception_str, + ) + + +def _get_incomplete_channels( + processing_info_dict: dict[int, ProcessingInfo] +) -> list[int]: + incomplete_wavelengths: list[int] = [] + for wavelength, processing_info in processing_info_dict.items(): + if processing_info.status == ProcessingStatus.COMPLETE: + logger.debug("%i is complete", wavelength) + else: + incomplete_wavelengths.append(wavelength) + logger.warning( + "Channel %i reconstruction ended with status '%s'", + wavelength, + processing_info.status, + ) + return incomplete_wavelengths + + def run_reconstructions( conf: ConfigManager, *sim_data_paths: str | PathLike[str], @@ -326,16 +402,21 @@ def run_reconstructions( async_results: list[AsyncResult] = [] for wavelength, processing_info in processing_info_dict.items(): + processing_info.status = ProcessingStatus.PENDING async_results.append( pool.apply_async( 
reconstruct_from_processing_info, args=(processing_info,), - error_callback=lambda e: logger.error( - # exc_info doesn't work with this - "Error occurred during reconstruction of %s channel %i: %s", - sim_data_path, - wavelength, - "".join(traceback.format_exception(e)), + callback=partial( + _reconstruction_process_callback, + wavelength=wavelength, + processing_info_dict=processing_info_dict, + ), + error_callback=partial( + _reconstruction_process_error_callback, + sim_data_path=sim_data_path, + wavelength=wavelength, + processing_info=processing_info, ), ) ) @@ -348,6 +429,12 @@ def run_reconstructions( ): r.wait() + incomplete_channels = _get_incomplete_channels(processing_info_dict) + if incomplete_channels and not allow_missing_channels: + raise ReconstructionError( + f"Failed to reconstruct channels: {', '.join(str(i) for i in incomplete_channels)}" + ) + _processing_files_to_output( sim_data_path, file_output_directory=file_output_directory, @@ -371,9 +458,13 @@ def run_reconstructions( exc_info=True, ) except ConfigException as e: - logger.error("Unable to process %s: %s)", sim_data_path, e) + logger.error("Unable to process %s: %s", sim_data_path, e) + except PySimReconException as e: + logger.error("Reconstruction failed for %s: %s", sim_data_path, e) except Exception: - logger.error("Error occurred for %s", sim_data_path, exc_info=True) + logger.error( + "Unexpected error occurred for %s", sim_data_path, exc_info=True + ) def _prepare_config_kwargs( @@ -535,8 +626,7 @@ def _prepare_files( processing_info_dict[channel.wavelengths.emission_nm_int] = processing_info except PySimReconException as e: logger.error( - "Failed to prepare files for channel %i (%s) of %s: %s", - channel.wavelengths.emission_nm_int, + "Failed to prepare files for channel %s of %s: %s", channel.wavelengths, file_path, e, @@ -545,8 +635,7 @@ def _prepare_files( raise except Exception: logger.error( - "Unexpected error preparing files for channel %i (%s) of %s", - 
channel.wavelengths.emission_nm_int, + "Unexpected error preparing files for channel %s of %s", channel.wavelengths, file_path, exc_info=True,