Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve user messages #77

Merged
merged 8 commits into from
Oct 17, 2024
13 changes: 11 additions & 2 deletions src/sim_recon/images/dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import logging
from dataclasses import dataclass, field
from enum import StrEnum, auto
from pathlib import Path
from typing import TYPE_CHECKING, TypeVar, Generic

Expand All @@ -15,13 +16,20 @@
logger = logging.getLogger(__name__)


class ProcessingStatus(StrEnum):
UNSET = auto()
PENDING = auto()
FAILED = auto()
COMPLETE = auto()


@dataclass(frozen=True, slots=True)
class BoundMrc:
array: NDArray[Any]
mrc: Mrc


@dataclass(frozen=True)
@dataclass(slots=True)
class ProcessingInfo:
image_path: Path
otf_path: Path
Expand All @@ -30,6 +38,7 @@ class ProcessingInfo:
log_path: Path
wavelengths: Wavelengths
kwargs: dict[str, Any]
status: ProcessingStatus = ProcessingStatus.UNSET


@dataclass(slots=True)
Expand Down Expand Up @@ -65,4 +74,4 @@ def __post_init__(self):
object.__setattr__(self, "emission_nm_int", emission_nm_int)

def __str__(self):
return f"excitation: {self.excitation_nm}nm; emission: {self.emission_nm}nm"
return f"{self.emission_nm_int} (excitation: {self.excitation_nm}nm; emission: {self.emission_nm}nm)"
37 changes: 32 additions & 5 deletions src/sim_recon/images/tiff.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,15 @@
from typing import TYPE_CHECKING, cast

from ..info import __version__
from ..exceptions import PySimReconFileExistsError, PySimReconValueError
from ..exceptions import (
PySimReconFileExistsError,
PySimReconValueError,
UndefinedValueError,
PySimReconIOError,
)

if TYPE_CHECKING:
from typing import Any
from typing import Any, Generator
from os import PathLike
from numpy.typing import NDArray
from .dataclasses import ImageChannel
Expand All @@ -33,6 +38,17 @@ def read_tiff(filepath: str | PathLike[str]) -> NDArray[Any]:
return tiff.asarray()


def generate_memmaps_from_tiffs(
*file_paths: str | PathLike[str],
) -> Generator[NDArray[Any], None, None]:
for fp in file_paths:
try:
yield tf.memmap(fp).squeeze()
except Exception as e:
logger.error("Unable to read image from %s: %s", fp, e)
raise


def get_combined_array_from_tiffs(
*file_paths: str | PathLike[str],
) -> NDArray[Any]:
Expand All @@ -42,7 +58,10 @@ def get_combined_array_from_tiffs(
)
if not file_paths:
raise PySimReconValueError("Cannot create a combined array without files")
return np.stack(tuple(tf.memmap(fp).squeeze() for fp in file_paths), -3)
try:
return np.stack(tuple(generate_memmaps_from_tiffs(*file_paths)), -3)
except Exception:
raise PySimReconIOError("Failed to combine TIFF files")


def write_tiff(
Expand All @@ -51,6 +70,7 @@ def write_tiff(
xy_pixel_size_microns: tuple[float | None, float | None] | None = None,
ome: bool = True,
overwrite: bool = False,
allow_empty_channels: bool = False,
) -> None:
def get_channel_dict(channel: ImageChannel) -> dict[str, Any] | None:
channel_dict: dict[str, Any] = {}
Expand Down Expand Up @@ -111,8 +131,15 @@ def get_channel_dict(channel: ImageChannel) -> dict[str, Any] | None:
) as tiff:
for channel in channels:
if channel.array is None:
logger.warning("Channel %s has no array to write", channel.wavelengths)
continue
if allow_empty_channels:
logger.warning(
"Channel %s has no array to write",
channel.wavelengths,
)
continue
raise UndefinedValueError(
f"{output_path} will not be created as channel {channel.wavelengths} has no array to write",
)
channel_kwargs = tiff_kwargs.copy()
channel_kwargs["metadata"]["axes"] = (
"YX" if channel.array.ndim == 2 else "ZYX"
Expand Down
143 changes: 116 additions & 27 deletions src/sim_recon/recon.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import multiprocessing
import os
import traceback
from functools import partial
from os.path import abspath
from shutil import copyfile
from pathlib import Path
Expand All @@ -23,7 +24,12 @@
write_tiff,
get_combined_array_from_tiffs,
)
from .images.dataclasses import ImageChannel, Wavelengths, ProcessingInfo
from .images.dataclasses import (
ImageChannel,
Wavelengths,
ProcessingInfo,
ProcessingStatus,
)
from .settings import ConfigManager
from .settings.formatting import (
formatters_to_default_value_kwargs,
Expand Down Expand Up @@ -113,19 +119,25 @@ def reconstruct(

return _recon_get_result(reconstructor, output_shape=(z, y, x))
# return reconstructor.get_result()
except Exception:
# Unlikely to ever hit this as errors from the C++ just kill the process

# Rare to ever hit this as errors from the C++ often just kill the process
except OSError as e:
logger.error("Reconstruction failed: %s", e)
raise ReconstructionError(
f"Error during reconstruction with config {config_path}: {e}"
)
except Exception as e:
logger.error(
"Exception raised during reconstruction with config %s",
"Unexpected error during reconstruction",
config_path,
exc_info=True,
)
raise ReconstructionError(
f"Exception raised during reconstruction with config {config_path}"
f"Unexpected error during reconstruction with config {config_path}: {e}"
)


def reconstruct_from_processing_info(processing_info: ProcessingInfo) -> Path:
def reconstruct_from_processing_info(processing_info: ProcessingInfo) -> ProcessingInfo:
PeterC-DLS marked this conversation as resolved.
Show resolved Hide resolved
logger.info(
"Starting reconstruction of %s with %s to be saved as %s",
processing_info.image_path,
Expand All @@ -140,27 +152,38 @@ def reconstruct_from_processing_info(processing_info: ProcessingInfo) -> Path:

data = read_tiff(processing_info.image_path) # Cannot use a memmap here!

rec_array = None
with redirect_output_to(processing_info.log_path):
sep = "-" * 80
print(
"\n".join(
(
f"Channel {processing_info.wavelengths.emission_nm_int} ({processing_info.wavelengths})",
f"Channel {processing_info.wavelengths}",
"Config used:",
processing_info.config_path.read_text().strip(),
"-" * 80,
sep,
"The text below is the output from cudasirecon",
)
)
)
rec_array = reconstruct(
data, processing_info.config_path, zoomfact, zzoom, ndirs, nphases
)
# rec_array = subprocess_recon(
# processing_info.image_path,
# processing_info.otf_path,
# processing_info.config_path,
# )

if rec_array is None:
raise ReconstructionError(
f"No image was returned from reconstruction with {processing_info.config_path}"
)
elif np.isnan(rec_array).all():
raise ReconstructionError(
f"Empty (NaN) image was returned from reconstruction with {processing_info.config_path}"
)

# rec_array = subprocess_recon(
# processing_info.image_path,
# processing_info.otf_path,
# processing_info.config_path,
# )
logger.info("Reconstructed %s", processing_info.image_path)
recon_pixel_size = float(processing_info.kwargs["xyres"]) / zoomfact
write_tiff(
Expand All @@ -174,7 +197,8 @@ def reconstruct_from_processing_info(processing_info: ProcessingInfo) -> Path:
processing_info.image_path,
processing_info.output_path,
)
return Path(processing_info.output_path)
processing_info.status = ProcessingStatus.COMPLETE
return processing_info


def _processing_files_to_output(
Expand Down Expand Up @@ -215,7 +239,11 @@ def _processing_files_to_output(
header="Reconstruction log",
)

output_files = tuple(pi.output_path for pi in processing_info_dict.values())
output_files = tuple(
pi.output_path
for pi in processing_info_dict.values()
if pi.status == ProcessingStatus.COMPLETE
)
if not output_files:
logger.warning(
"No reconstructions were created from %s",
Expand All @@ -239,6 +267,8 @@ def _processing_files_to_output(
wavelength,
processing_info,
) in processing_info_dict.items():
if processing_info.status != ProcessingStatus.COMPLETE:
continue
dv_path = create_output_path(
sim_data_path,
output_type="recon",
Expand All @@ -264,6 +294,52 @@ def _processing_files_to_output(
)


def _reconstruction_process_callback(
processing_info: ProcessingInfo,
wavelength: int,
processing_info_dict: dict[int, ProcessingInfo],
) -> None:
logger.debug("Channel %i process complete", wavelength)
processing_info_dict[wavelength].status = processing_info.status


def _reconstruction_process_error_callback(
exception: BaseException,
sim_data_path: str | PathLike[str],
wavelength: int,
processing_info: ProcessingInfo,
) -> None:
processing_info.status = ProcessingStatus.FAILED
if isinstance(exception, PySimReconException):
exception_str = str(exception)
else:
exception_str = "".join(traceback.format_exception(exception))
logger.error(
# exc_info doesn't work with the callback
"Error occurred during reconstruction of %s channel %i: %s",
sim_data_path,
wavelength,
exception_str,
)


def _get_incomplete_channels(
processing_info_dict: dict[int, ProcessingInfo]
) -> list[int]:
incomplete_wavelengths: list[int] = []
for wavelength, processing_info in processing_info_dict.items():
if processing_info.status == ProcessingStatus.COMPLETE:
logger.debug("%i is complete", wavelength)
else:
incomplete_wavelengths.append(wavelength)
logger.warning(
"Channel %i reconstruction ended with status '%s'",
wavelength,
processing_info.status,
)
return incomplete_wavelengths


def run_reconstructions(
conf: ConfigManager,
*sim_data_paths: str | PathLike[str],
Expand Down Expand Up @@ -326,16 +402,21 @@ def run_reconstructions(

async_results: list[AsyncResult] = []
for wavelength, processing_info in processing_info_dict.items():
processing_info.status = ProcessingStatus.PENDING
PeterC-DLS marked this conversation as resolved.
Show resolved Hide resolved
async_results.append(
pool.apply_async(
reconstruct_from_processing_info,
args=(processing_info,),
error_callback=lambda e: logger.error(
# exc_info doesn't work with this
"Error occurred during reconstruction of %s channel %i: %s",
sim_data_path,
wavelength,
"".join(traceback.format_exception(e)),
callback=partial(
_reconstruction_process_callback,
wavelength=wavelength,
processing_info_dict=processing_info_dict,
),
error_callback=partial(
_reconstruction_process_error_callback,
sim_data_path=sim_data_path,
wavelength=wavelength,
processing_info=processing_info,
),
)
)
Expand All @@ -348,6 +429,12 @@ def run_reconstructions(
):
r.wait()

incomplete_channels = _get_incomplete_channels(processing_info_dict)
if incomplete_channels and not allow_missing_channels:
raise ReconstructionError(
f"Failed to reconstruct channels: {', '.join(str(i) for i in incomplete_channels)}"
)

_processing_files_to_output(
sim_data_path,
file_output_directory=file_output_directory,
Expand All @@ -371,9 +458,13 @@ def run_reconstructions(
exc_info=True,
)
except ConfigException as e:
logger.error("Unable to process %s: %s)", sim_data_path, e)
logger.error("Unable to process %s: %s", sim_data_path, e)
except PySimReconException as e:
logger.error("Reconstruction failed for %s: %s", sim_data_path, e)
except Exception:
logger.error("Error occurred for %s", sim_data_path, exc_info=True)
logger.error(
"Unexpected error occurred for %s", sim_data_path, exc_info=True
)


def _prepare_config_kwargs(
Expand Down Expand Up @@ -535,8 +626,7 @@ def _prepare_files(
processing_info_dict[channel.wavelengths.emission_nm_int] = processing_info
except PySimReconException as e:
logger.error(
"Failed to prepare files for channel %i (%s) of %s: %s",
channel.wavelengths.emission_nm_int,
"Failed to prepare files for channel %s of %s: %s",
channel.wavelengths,
file_path,
e,
Expand All @@ -545,8 +635,7 @@ def _prepare_files(
raise
except Exception:
logger.error(
"Unexpected error preparing files for channel %i (%s) of %s",
channel.wavelengths.emission_nm_int,
"Unexpected error preparing files for channel %s of %s",
channel.wavelengths,
file_path,
exc_info=True,
Expand Down