Skip to content

Commit

Permalink
Fix tiff channel handling (it was saving each channel as a separate i…
Browse files Browse the repository at this point in the history
…mage in the TIFF)
  • Loading branch information
Thomas Fish committed Oct 21, 2024
1 parent 70e4c1a commit 0841640
Show file tree
Hide file tree
Showing 4 changed files with 92 additions and 39 deletions.
3 changes: 2 additions & 1 deletion src/sim_recon/images/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,9 @@ def dv_to_tiff(
channel.array = complex_to_interleaved_float(channel.array)
write_tiff(
tiff_path,
*image_data.channels,
image_data.channels,
resolution=image_data.resolution,
overwrite=overwrite,
allow_missing_channel_info=True,
)
return Path(tiff_path)
2 changes: 1 addition & 1 deletion src/sim_recon/images/dv.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def handle_float_array(

def write_dv(
output_file: str | PathLike[str],
*channels: ImageChannel[Wavelengths],
channels: Collection[ImageChannel[Wavelengths]],
input_dv: mrc.Mrc,
resolution: ImageResolution,
overwrite: bool = False,
Expand Down
107 changes: 76 additions & 31 deletions src/sim_recon/images/tiff.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import logging
from pathlib import Path
import numpy as np
import tifffile as tf
from typing import TYPE_CHECKING

Expand All @@ -14,7 +15,8 @@
)

if TYPE_CHECKING:
from typing import Any, Generator
from typing import Any
from collections.abc import Generator, Collection
from os import PathLike
from numpy.typing import NDArray
from .dataclasses import Wavelengths
Expand Down Expand Up @@ -57,23 +59,46 @@ def generate_channels_from_tiffs(

def write_tiff(
output_path: str | PathLike[str],
*channels: ImageChannel,
*images: Collection[ImageChannel[Wavelengths] | ImageChannel[None]],
resolution: ImageResolution | None = None,
ome: bool = True,
allow_empty_channels: bool = False,
allow_missing_channel_info: bool = False,
overwrite: bool = False,
) -> Path:
def get_channel_dict(channel: ImageChannel) -> dict[str, Any] | None:
channel_dict: dict[str, Any] = {}
if channel.wavelengths is None:
return None
if channel.wavelengths.excitation_nm is not None:
channel_dict["ExcitationWavelength"] = channel.wavelengths.excitation_nm
channel_dict["ExcitationWavelengthUnits"] = "nm"
if channel.wavelengths.emission_nm is not None:
channel_dict["EmissionWavelength"] = channel.wavelengths.emission_nm
channel_dict["EmissionWavelengthUnits"] = "nm"
return channel_dict
def get_ome_channel_dict(*channels: ImageChannel) -> dict[str, Any] | None:
names: list[str] = []
excitation_wavelengths: list[float | None] = []
excitation_wavelength_units: list[str | None] = []
emission_wavelengths: list[float | None] = []
emission_wavelength_units: list[str | None] = []
for i, channel in enumerate(channels, 1):
if channel.wavelengths is None:
if not allow_missing_channel_info:
raise UndefinedValueError(f"Missing wavelengths for channel {i}")
names.append(f"Channel {i}")
excitation_wavelengths.append(None)
excitation_wavelength_units.append(None)
emission_wavelengths.append(None)
emission_wavelength_units.append(None)

else:
names.append(
f"Channel {i}"
if channel.wavelengths.emission_nm_int is None
else str(channel.wavelengths.emission_nm_int)
)
excitation_wavelengths.append(channel.wavelengths.excitation_nm)
excitation_wavelength_units.append("nm")
emission_wavelengths.append(channel.wavelengths.emission_nm)
emission_wavelength_units.append("nm")
return {
"Name": names,
"ExcitationWavelength": excitation_wavelengths,
"ExcitationWavelengthUnits": excitation_wavelength_units,
"EmissionWavelength": emission_wavelengths,
"EmissionWavelengthUnits": emission_wavelength_units,
}

output_path = Path(output_path)

Expand Down Expand Up @@ -103,6 +128,7 @@ def get_channel_dict(channel: ImageChannel) -> dict[str, Any] | None:
) # Use CENTIMETER for maximum compatibility

if ome:
tiff_kwargs["metadata"]["Name"] = output_path.name
if resolution is not None:
# OME PhysicalSize:
tiff_kwargs["metadata"]["PhysicalSizeX"] = resolution.x
Expand All @@ -120,24 +146,43 @@ def get_channel_dict(channel: ImageChannel) -> dict[str, Any] | None:
ome=ome,
shaped=not ome,
) as tiff:
for channel in channels:
if channel.array is None:
if allow_empty_channels:
logger.warning(
"Channel %s has no array to write",
channel.wavelengths,
for channels in images:
channels_to_write = []
for channel in channels:
if channel.array is None:
if allow_empty_channels:
logger.warning(
"Channel %s has no array to write",
channel.wavelengths,
)
continue
raise UndefinedValueError(
f"{output_path} will not be created as channel {channel.wavelengths} has no array to write",
)
continue
raise UndefinedValueError(
f"{output_path} will not be created as channel {channel.wavelengths} has no array to write",
)
channel_kwargs = tiff_kwargs.copy()
channel_kwargs["metadata"]["axes"] = (
"YX" if channel.array.ndim == 2 else "ZYX"
)
channels_to_write.append(channel)

array = np.stack(
[_.array for _ in channels], axis=0
) # adds another axis, even if only 1 channel

image_kwargs = tiff_kwargs.copy()
if array.ndim == 3:
# In case the images are 2D
image_kwargs["metadata"]["axes"] = "CYX"
else:
image_kwargs["metadata"]["axes"] = "CZYX"

if ome:
channel_dict = get_channel_dict(channel)
if channel_dict is not None:
channel_kwargs["metadata"]["Channel"] = channel_dict
tiff.write(channel.array, **channel_kwargs)
try:
channel_dict = get_ome_channel_dict(*channels)
if channel_dict is not None:
image_kwargs["metadata"]["Channel"] = channel_dict
except UndefinedValueError as e:
if not allow_missing_channel_info:
logger.error("Failed to write image '%s': %s", output_path, e)
raise
logger.warning(
"Writing without OME Channel metadata due to error: %s", e
)
tiff.write(array, **image_kwargs)
return output_path
19 changes: 13 additions & 6 deletions src/sim_recon/recon.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ def reconstruct_from_processing_info(processing_info: ProcessingInfo) -> Process
recon_pixel_size = float(processing_info.kwargs["xyres"]) / zoomfact
write_tiff(
processing_info.output_path,
ImageChannel(processing_info.wavelengths, rec_array),
(ImageChannel(processing_info.wavelengths, rec_array),),
resolution=ImageResolution(recon_pixel_size, recon_pixel_size),
overwrite=True,
)
Expand Down Expand Up @@ -266,12 +266,15 @@ def _reconstructions_to_output(
zzoom = zoom_factors[0][1]
write_output(
output_image_path,
*generate_channels_from_tiffs(*output_wavelengths_path_tuples),
tuple(
generate_channels_from_tiffs(*output_wavelengths_path_tuples)
),
resolution=ImageResolution(
input_resolution.x / zoom_fact,
input_resolution.y / zoom_fact,
input_resolution.z / zzoom,
))
),
)
return
except InvalidValueError as e:
logger.warning("Unable to stitch files due to error: %s", e)
Expand All @@ -294,7 +297,11 @@ def _reconstructions_to_output(
zzoom = processing_info.kwargs["zzoom"]
write_output(
output_image_path,
*generate_channels_from_tiffs((processing_info.wavelengths, processing_info.output_path)),
tuple(
generate_channels_from_tiffs(
(processing_info.wavelengths, processing_info.output_path)
)
),
resolution=ImageResolution(
input_resolution.x / zoom_fact,
input_resolution.y / zoom_fact,
Expand Down Expand Up @@ -454,7 +461,7 @@ def run_reconstructions(
processing_info_dict=processing_info_dict,
stitch_channels=stitch_channels,
overwrite=overwrite,
file_type=output_file_type
file_type=output_file_type,
)
finally:
proc_log_files: list[Path] = []
Expand Down Expand Up @@ -649,7 +656,7 @@ def _prepare_files(
)
write_tiff(
split_file_path,
channel,
(channel,),
resolution=image_data.resolution,
)

Expand Down

0 comments on commit 0841640

Please sign in to comment.