diff --git a/src/sim_recon/images/__init__.py b/src/sim_recon/images/__init__.py index 5706d59..6212b5d 100644 --- a/src/sim_recon/images/__init__.py +++ b/src/sim_recon/images/__init__.py @@ -58,8 +58,9 @@ def dv_to_tiff( channel.array = complex_to_interleaved_float(channel.array) write_tiff( tiff_path, - *image_data.channels, + image_data.channels, resolution=image_data.resolution, overwrite=overwrite, + allow_missing_channel_info=True, ) return Path(tiff_path) diff --git a/src/sim_recon/images/dv.py b/src/sim_recon/images/dv.py index bdade54..11cdb7e 100644 --- a/src/sim_recon/images/dv.py +++ b/src/sim_recon/images/dv.py @@ -149,7 +149,7 @@ def handle_float_array( def write_dv( output_file: str | PathLike[str], - *channels: ImageChannel[Wavelengths], + channels: Collection[ImageChannel[Wavelengths]], input_dv: mrc.Mrc, resolution: ImageResolution, overwrite: bool = False, diff --git a/src/sim_recon/images/tiff.py b/src/sim_recon/images/tiff.py index 5ed6b7d..7d16158 100644 --- a/src/sim_recon/images/tiff.py +++ b/src/sim_recon/images/tiff.py @@ -2,6 +2,7 @@ import logging from pathlib import Path +import numpy as np import tifffile as tf from typing import TYPE_CHECKING @@ -14,7 +15,8 @@ ) if TYPE_CHECKING: - from typing import Any, Generator + from typing import Any + from collections.abc import Generator, Collection from os import PathLike from numpy.typing import NDArray from .dataclasses import Wavelengths @@ -57,23 +59,46 @@ def generate_channels_from_tiffs( def write_tiff( output_path: str | PathLike[str], - *channels: ImageChannel, + *images: Collection[ImageChannel[Wavelengths] | ImageChannel[None]], resolution: ImageResolution | None = None, ome: bool = True, allow_empty_channels: bool = False, + allow_missing_channel_info: bool = False, overwrite: bool = False, ) -> Path: - def get_channel_dict(channel: ImageChannel) -> dict[str, Any] | None: - channel_dict: dict[str, Any] = {} - if channel.wavelengths is None: - return None - if channel.wavelengths.excitation_nm is not None: - channel_dict["ExcitationWavelength"] = channel.wavelengths.excitation_nm - channel_dict["ExcitationWavelengthUnits"] = "nm" - if channel.wavelengths.emission_nm is not None: - channel_dict["EmissionWavelength"] = channel.wavelengths.emission_nm - channel_dict["EmissionWavelengthUnits"] = "nm" - return channel_dict + def get_ome_channel_dict(*channels: ImageChannel) -> dict[str, Any] | None: + names: list[str] = [] + excitation_wavelengths: list[float | None] = [] + excitation_wavelength_units: list[str | None] = [] + emission_wavelengths: list[float | None] = [] + emission_wavelength_units: list[str | None] = [] + for i, channel in enumerate(channels, 1): + if channel.wavelengths is None: + if not allow_missing_channel_info: + raise UndefinedValueError(f"Missing wavelengths for channel {i}") + names.append(f"Channel {i}") + excitation_wavelengths.append(None) + excitation_wavelength_units.append(None) + emission_wavelengths.append(None) + emission_wavelength_units.append(None) + + else: + names.append( + f"Channel {i}" + if channel.wavelengths.emission_nm_int is None + else str(channel.wavelengths.emission_nm_int) + ) + excitation_wavelengths.append(channel.wavelengths.excitation_nm) + excitation_wavelength_units.append("nm") + emission_wavelengths.append(channel.wavelengths.emission_nm) + emission_wavelength_units.append("nm") + return { + "Name": names, + "ExcitationWavelength": excitation_wavelengths, + "ExcitationWavelengthUnits": excitation_wavelength_units, + "EmissionWavelength": emission_wavelengths, + "EmissionWavelengthUnits": emission_wavelength_units, + } output_path = Path(output_path) @@ -103,6 +128,7 @@ def get_channel_dict(channel: ImageChannel) -> dict[str, Any] | None: ) # Use CENTIMETER for maximum compatibility if ome: + tiff_kwargs["metadata"]["Name"] = output_path.name if resolution is not None: # OME PhysicalSize: tiff_kwargs["metadata"]["PhysicalSizeX"] = resolution.x @@ -120,24 +146,43 @@ def get_channel_dict(channel: ImageChannel) -> dict[str, Any] | None: ome=ome, shaped=not ome, ) as tiff: - for channel in channels: - if channel.array is None: - if allow_empty_channels: - logger.warning( - "Channel %s has no array to write", - channel.wavelengths, + for channels in images: + channels_to_write = [] + for channel in channels: + if channel.array is None: + if allow_empty_channels: + logger.warning( + "Channel %s has no array to write", + channel.wavelengths, + ) + continue + raise UndefinedValueError( + f"{output_path} will not be created as channel {channel.wavelengths} has no array to write", ) - continue - raise UndefinedValueError( - f"{output_path} will not be created as channel {channel.wavelengths} has no array to write", - ) - channel_kwargs = tiff_kwargs.copy() - channel_kwargs["metadata"]["axes"] = ( - "YX" if channel.array.ndim == 2 else "ZYX" - ) + channels_to_write.append(channel) + + array = np.stack( + [_.array for _ in channels], axis=0 + ) # adds another axis, even if only 1 channel + + image_kwargs = tiff_kwargs.copy() + if array.ndim == 3: + # In case the images are 2D + image_kwargs["metadata"]["axes"] = "CYX" + else: + image_kwargs["metadata"]["axes"] = "CZYX" + if ome: - channel_dict = get_channel_dict(channel) - if channel_dict is not None: - channel_kwargs["metadata"]["Channel"] = channel_dict - tiff.write(channel.array, **channel_kwargs) + try: + channel_dict = get_ome_channel_dict(*channels) + if channel_dict is not None: + image_kwargs["metadata"]["Channel"] = channel_dict + except UndefinedValueError as e: + if not allow_missing_channel_info: + logger.error("Failed to write image '%s': %s", output_path, e) + raise + logger.warning( + "Writing without OME Channel metadata due to error: %s", e + ) + tiff.write(array, **image_kwargs) return output_path diff --git a/src/sim_recon/recon.py b/src/sim_recon/recon.py index 281b382..c5a3ac6 100644 --- a/src/sim_recon/recon.py +++ b/src/sim_recon/recon.py @@ -191,7 +191,7 @@ def reconstruct_from_processing_info(processing_info: ProcessingInfo) -> Process recon_pixel_size = float(processing_info.kwargs["xyres"]) / zoomfact write_tiff( processing_info.output_path, - ImageChannel(processing_info.wavelengths, rec_array), + (ImageChannel(processing_info.wavelengths, rec_array),), resolution=ImageResolution(recon_pixel_size, recon_pixel_size), overwrite=True, ) @@ -266,12 +266,15 @@ def _reconstructions_to_output( zzoom = zoom_factors[0][1] write_output( output_image_path, - *generate_channels_from_tiffs(*output_wavelengths_path_tuples), + tuple( + generate_channels_from_tiffs(*output_wavelengths_path_tuples) + ), resolution=ImageResolution( input_resolution.x / zoom_fact, input_resolution.y / zoom_fact, input_resolution.z / zzoom, - )) + ), + ) return except InvalidValueError as e: logger.warning("Unable to stitch files due to error: %s", e) @@ -294,7 +297,11 @@ def _reconstructions_to_output( zzoom = processing_info.kwargs["zzoom"] write_output( output_image_path, - *generate_channels_from_tiffs((processing_info.wavelengths, processing_info.output_path)), + tuple( + generate_channels_from_tiffs( + (processing_info.wavelengths, processing_info.output_path) + ) + ), resolution=ImageResolution( input_resolution.x / zoom_fact, input_resolution.y / zoom_fact, @@ -454,7 +461,7 @@ def run_reconstructions( processing_info_dict=processing_info_dict, stitch_channels=stitch_channels, overwrite=overwrite, - file_type=output_file_type + file_type=output_file_type, ) finally: proc_log_files: list[Path] = [] @@ -649,7 +656,7 @@ def _prepare_files( ) write_tiff( split_file_path, - channel, + (channel,), resolution=image_data.resolution, )