Skip to content

Commit

Permalink
Improve user messages (#77)
Browse files Browse the repository at this point in the history
Various changes that allow better messages:
- Status tracking for reconstruction processes
- Better exception handling (e.g. only include the traceback when
logging unexpected exceptions)
- Improve behaviour and logging around TIFF combining and writing
- Raise an exception if `rec_array` isn't set (shouldn't be possible but
sometimes happens) or the array is all `NaN`s.
  • Loading branch information
thomasmfish authored Oct 17, 2024
2 parents 104379a + eddd9dd commit 869e3a0
Show file tree
Hide file tree
Showing 3 changed files with 159 additions and 34 deletions.
13 changes: 11 additions & 2 deletions src/sim_recon/images/dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import logging
from dataclasses import dataclass, field
from enum import StrEnum, auto
from pathlib import Path
from typing import TYPE_CHECKING, TypeVar, Generic

Expand All @@ -15,13 +16,20 @@
logger = logging.getLogger(__name__)


class ProcessingStatus(StrEnum):
UNSET = auto()
PENDING = auto()
FAILED = auto()
COMPLETE = auto()


@dataclass(frozen=True, slots=True)
class BoundMrc:
array: NDArray[Any]
mrc: Mrc


@dataclass(frozen=True)
@dataclass(slots=True)
class ProcessingInfo:
image_path: Path
otf_path: Path
Expand All @@ -30,6 +38,7 @@ class ProcessingInfo:
log_path: Path
wavelengths: Wavelengths
kwargs: dict[str, Any]
status: ProcessingStatus = ProcessingStatus.UNSET


@dataclass(slots=True)
Expand Down Expand Up @@ -65,4 +74,4 @@ def __post_init__(self):
object.__setattr__(self, "emission_nm_int", emission_nm_int)

def __str__(self):
return f"excitation: {self.excitation_nm}nm; emission: {self.emission_nm}nm"
return f"{self.emission_nm_int} (excitation: {self.excitation_nm}nm; emission: {self.emission_nm}nm)"
37 changes: 32 additions & 5 deletions src/sim_recon/images/tiff.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,15 @@
from typing import TYPE_CHECKING, cast

from ..info import __version__
from ..exceptions import PySimReconFileExistsError, PySimReconValueError
from ..exceptions import (
PySimReconFileExistsError,
PySimReconValueError,
UndefinedValueError,
PySimReconIOError,
)

if TYPE_CHECKING:
from typing import Any
from typing import Any, Generator
from os import PathLike
from numpy.typing import NDArray
from .dataclasses import ImageChannel
Expand All @@ -33,6 +38,17 @@ def read_tiff(filepath: str | PathLike[str]) -> NDArray[Any]:
return tiff.asarray()


def generate_memmaps_from_tiffs(
*file_paths: str | PathLike[str],
) -> Generator[NDArray[Any], None, None]:
for fp in file_paths:
try:
yield tf.memmap(fp).squeeze()
except Exception as e:
logger.error("Unable to read image from %s: %s", fp, e)
raise


def get_combined_array_from_tiffs(
*file_paths: str | PathLike[str],
) -> NDArray[Any]:
Expand All @@ -42,7 +58,10 @@ def get_combined_array_from_tiffs(
)
if not file_paths:
raise PySimReconValueError("Cannot create a combined array without files")
return np.stack(tuple(tf.memmap(fp).squeeze() for fp in file_paths), -3)
try:
return np.stack(tuple(generate_memmaps_from_tiffs(*file_paths)), -3)
except Exception:
raise PySimReconIOError("Failed to combine TIFF files")


def write_tiff(
Expand All @@ -51,6 +70,7 @@ def write_tiff(
xy_pixel_size_microns: tuple[float | None, float | None] | None = None,
ome: bool = True,
overwrite: bool = False,
allow_empty_channels: bool = False,
) -> None:
def get_channel_dict(channel: ImageChannel) -> dict[str, Any] | None:
channel_dict: dict[str, Any] = {}
Expand Down Expand Up @@ -111,8 +131,15 @@ def get_channel_dict(channel: ImageChannel) -> dict[str, Any] | None:
) as tiff:
for channel in channels:
if channel.array is None:
logger.warning("Channel %s has no array to write", channel.wavelengths)
continue
if allow_empty_channels:
logger.warning(
"Channel %s has no array to write",
channel.wavelengths,
)
continue
raise UndefinedValueError(
f"{output_path} will not be created as channel {channel.wavelengths} has no array to write",
)
channel_kwargs = tiff_kwargs.copy()
channel_kwargs["metadata"]["axes"] = (
"YX" if channel.array.ndim == 2 else "ZYX"
Expand Down
143 changes: 116 additions & 27 deletions src/sim_recon/recon.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import multiprocessing
import os
import traceback
from functools import partial
from os.path import abspath
from shutil import copyfile
from pathlib import Path
Expand All @@ -23,7 +24,12 @@
write_tiff,
get_combined_array_from_tiffs,
)
from .images.dataclasses import ImageChannel, Wavelengths, ProcessingInfo
from .images.dataclasses import (
ImageChannel,
Wavelengths,
ProcessingInfo,
ProcessingStatus,
)
from .settings import ConfigManager
from .settings.formatting import (
formatters_to_default_value_kwargs,
Expand Down Expand Up @@ -113,19 +119,25 @@ def reconstruct(

return _recon_get_result(reconstructor, output_shape=(z, y, x))
# return reconstructor.get_result()
except Exception:
# Unlikely to ever hit this as errors from the C++ just kill the process

# Rare to ever hit this as errors from the C++ often just kill the process
except OSError as e:
logger.error("Reconstruction failed: %s", e)
raise ReconstructionError(
f"Error during reconstruction with config {config_path}: {e}"
)
except Exception as e:
logger.error(
"Exception raised during reconstruction with config %s",
"Unexpected error during reconstruction",
config_path,
exc_info=True,
)
raise ReconstructionError(
f"Exception raised during reconstruction with config {config_path}"
f"Unexpected error during reconstruction with config {config_path}: {e}"
)


def reconstruct_from_processing_info(processing_info: ProcessingInfo) -> Path:
def reconstruct_from_processing_info(processing_info: ProcessingInfo) -> ProcessingInfo:
logger.info(
"Starting reconstruction of %s with %s to be saved as %s",
processing_info.image_path,
Expand All @@ -140,27 +152,38 @@ def reconstruct_from_processing_info(processing_info: ProcessingInfo) -> Path:

data = read_tiff(processing_info.image_path) # Cannot use a memmap here!

rec_array = None
with redirect_output_to(processing_info.log_path):
sep = "-" * 80
print(
"\n".join(
(
f"Channel {processing_info.wavelengths.emission_nm_int} ({processing_info.wavelengths})",
f"Channel {processing_info.wavelengths}",
"Config used:",
processing_info.config_path.read_text().strip(),
"-" * 80,
sep,
"The text below is the output from cudasirecon",
)
)
)
rec_array = reconstruct(
data, processing_info.config_path, zoomfact, zzoom, ndirs, nphases
)
# rec_array = subprocess_recon(
# processing_info.image_path,
# processing_info.otf_path,
# processing_info.config_path,
# )

if rec_array is None:
raise ReconstructionError(
f"No image was returned from reconstruction with {processing_info.config_path}"
)
elif np.isnan(rec_array).all():
raise ReconstructionError(
f"Empty (NaN) image was returned from reconstruction with {processing_info.config_path}"
)

# rec_array = subprocess_recon(
# processing_info.image_path,
# processing_info.otf_path,
# processing_info.config_path,
# )
logger.info("Reconstructed %s", processing_info.image_path)
recon_pixel_size = float(processing_info.kwargs["xyres"]) / zoomfact
write_tiff(
Expand All @@ -174,7 +197,8 @@ def reconstruct_from_processing_info(processing_info: ProcessingInfo) -> Path:
processing_info.image_path,
processing_info.output_path,
)
return Path(processing_info.output_path)
processing_info.status = ProcessingStatus.COMPLETE
return processing_info


def _processing_files_to_output(
Expand Down Expand Up @@ -215,7 +239,11 @@ def _processing_files_to_output(
header="Reconstruction log",
)

output_files = tuple(pi.output_path for pi in processing_info_dict.values())
output_files = tuple(
pi.output_path
for pi in processing_info_dict.values()
if pi.status == ProcessingStatus.COMPLETE
)
if not output_files:
logger.warning(
"No reconstructions were created from %s",
Expand All @@ -239,6 +267,8 @@ def _processing_files_to_output(
wavelength,
processing_info,
) in processing_info_dict.items():
if processing_info.status != ProcessingStatus.COMPLETE:
continue
dv_path = create_output_path(
sim_data_path,
output_type="recon",
Expand All @@ -264,6 +294,52 @@ def _processing_files_to_output(
)


def _reconstruction_process_callback(
processing_info: ProcessingInfo,
wavelength: int,
processing_info_dict: dict[int, ProcessingInfo],
) -> None:
logger.debug("Channel %i process complete", wavelength)
processing_info_dict[wavelength].status = processing_info.status


def _reconstruction_process_error_callback(
exception: BaseException,
sim_data_path: str | PathLike[str],
wavelength: int,
processing_info: ProcessingInfo,
) -> None:
processing_info.status = ProcessingStatus.FAILED
if isinstance(exception, PySimReconException):
exception_str = str(exception)
else:
exception_str = "".join(traceback.format_exception(exception))
logger.error(
# exc_info doesn't work with the callback
"Error occurred during reconstruction of %s channel %i: %s",
sim_data_path,
wavelength,
exception_str,
)


def _get_incomplete_channels(
processing_info_dict: dict[int, ProcessingInfo]
) -> list[int]:
incomplete_wavelengths: list[int] = []
for wavelength, processing_info in processing_info_dict.items():
if processing_info.status == ProcessingStatus.COMPLETE:
logger.debug("%i is complete", wavelength)
else:
incomplete_wavelengths.append(wavelength)
logger.warning(
"Channel %i reconstruction ended with status '%s'",
wavelength,
processing_info.status,
)
return incomplete_wavelengths


def run_reconstructions(
conf: ConfigManager,
*sim_data_paths: str | PathLike[str],
Expand Down Expand Up @@ -326,16 +402,21 @@ def run_reconstructions(

async_results: list[AsyncResult] = []
for wavelength, processing_info in processing_info_dict.items():
processing_info.status = ProcessingStatus.PENDING
async_results.append(
pool.apply_async(
reconstruct_from_processing_info,
args=(processing_info,),
error_callback=lambda e: logger.error(
# exc_info doesn't work with this
"Error occurred during reconstruction of %s channel %i: %s",
sim_data_path,
wavelength,
"".join(traceback.format_exception(e)),
callback=partial(
_reconstruction_process_callback,
wavelength=wavelength,
processing_info_dict=processing_info_dict,
),
error_callback=partial(
_reconstruction_process_error_callback,
sim_data_path=sim_data_path,
wavelength=wavelength,
processing_info=processing_info,
),
)
)
Expand All @@ -348,6 +429,12 @@ def run_reconstructions(
):
r.wait()

incomplete_channels = _get_incomplete_channels(processing_info_dict)
if incomplete_channels and not allow_missing_channels:
raise ReconstructionError(
f"Failed to reconstruct channels: {', '.join(str(i) for i in incomplete_channels)}"
)

_processing_files_to_output(
sim_data_path,
file_output_directory=file_output_directory,
Expand All @@ -371,9 +458,13 @@ def run_reconstructions(
exc_info=True,
)
except ConfigException as e:
logger.error("Unable to process %s: %s)", sim_data_path, e)
logger.error("Unable to process %s: %s", sim_data_path, e)
except PySimReconException as e:
logger.error("Reconstruction failed for %s: %s", sim_data_path, e)
except Exception:
logger.error("Error occurred for %s", sim_data_path, exc_info=True)
logger.error(
"Unexpected error occurred for %s", sim_data_path, exc_info=True
)


def _prepare_config_kwargs(
Expand Down Expand Up @@ -535,8 +626,7 @@ def _prepare_files(
processing_info_dict[channel.wavelengths.emission_nm_int] = processing_info
except PySimReconException as e:
logger.error(
"Failed to prepare files for channel %i (%s) of %s: %s",
channel.wavelengths.emission_nm_int,
"Failed to prepare files for channel %s of %s: %s",
channel.wavelengths,
file_path,
e,
Expand All @@ -545,8 +635,7 @@ def _prepare_files(
raise
except Exception:
logger.error(
"Unexpected error preparing files for channel %i (%s) of %s",
channel.wavelengths.emission_nm_int,
"Unexpected error preparing files for channel %s of %s",
channel.wavelengths,
file_path,
exc_info=True,
Expand Down

0 comments on commit 869e3a0

Please sign in to comment.