Skip to content

Commit

Permalink
Enable saving reconstruction outputs as TIFF (#83)
Browse files Browse the repository at this point in the history
The DV file format is a bit specialised, and it was fairly easy to make this change
using the existing code. I had some trouble with the channels being saved
as separate images, according to the OME-TIFF metadata, but that's fixed
now.
  • Loading branch information
thomasmfish authored Oct 23, 2024
2 parents b2d9f0a + 34494fe commit 2c6ba2e
Show file tree
Hide file tree
Showing 9 changed files with 211 additions and 134 deletions.
7 changes: 7 additions & 0 deletions src/sim_recon/cli/parsing/recon.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,13 @@ def parse_args(
action="store_true",
help="If specified, attempt reconstruction of other channels in a multi-channel file if one or more are not configured",
)
parser.add_argument(
"--type",
dest="output_file_type",
choices=["dv", "tiff"],
default="dv",
help="File type of output images",
)
parser.add_argument(
"--overwrite",
action="store_true",
Expand Down
1 change: 1 addition & 0 deletions src/sim_recon/cli/recon.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def main() -> None:
stitch_channels=namespace.stitch_channels,
parallel_process=namespace.parallel_process,
allow_missing_channels=namespace.allow_missing_channels,
output_file_type=namespace.output_file_type,
**recon_kwargs,
)

Expand Down
14 changes: 9 additions & 5 deletions src/sim_recon/files/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,14 +43,17 @@ def get_temporary_path(directory: Path, stem: str, suffix: str) -> Path:
)


def ensure_unique_filepath(path: Path, max_iter: int = 99) -> Path:
def ensure_unique_filepath(
output_directory: Path, stem: str, suffix: str, max_iter: int = 99
) -> Path:
path = output_directory / f"{stem}{suffix}"
if not path.exists():
return path
if max_iter <= 1:
raise PySimReconValueError("max_iter must be >1")
output_path = None
for i in range(1, max_iter + 1):
output_path = path.with_name(f"{path.stem}_{i}{path.suffix}")
output_path = output_directory / f"{stem}_{i}{suffix}"
if not output_path.exists():
logger.debug("'%s' was not unique, so '%s' will be used", path, output_path)
return output_path
Expand Down Expand Up @@ -115,11 +118,12 @@ def create_output_path(
output_directory = Path(output_directory)

file_stem = "_".join(output_fp_parts)
output_path = output_directory / ensure_valid_filename(f"{file_stem}{suffix}")

if ensure_unique:
output_path = ensure_unique_filepath(output_path, max_iter=max_path_iter)
return output_path
return ensure_unique_filepath(
output_directory, stem=file_stem, suffix=suffix, max_iter=max_path_iter
)
return output_directory / ensure_valid_filename(f"{file_stem}{suffix}")


@contextmanager
Expand Down
5 changes: 3 additions & 2 deletions src/sim_recon/images/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,9 @@ def dv_to_tiff(
channel.array = complex_to_interleaved_float(channel.array)
write_tiff(
tiff_path,
*image_data.channels,
xy_pixel_size_microns=(image_data.resolution.x, image_data.resolution.y),
image_data.channels,
resolution=image_data.resolution,
overwrite=overwrite,
allow_missing_channel_info=True,
)
return Path(tiff_path)
4 changes: 2 additions & 2 deletions src/sim_recon/images/dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,14 +50,14 @@ class ImageData:
@dataclass(slots=True)
class ImageChannel(Generic[OptionalWavelengths]):
wavelengths: OptionalWavelengths
array: NDArray[Any] | None = None
array: NDArray[Any]


@dataclass(slots=True, frozen=True)
class ImageResolution:
x: float
y: float
z: float
z: float | None = None


@dataclass(slots=True, frozen=True)
Expand Down
34 changes: 14 additions & 20 deletions src/sim_recon/images/dv.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,16 +148,15 @@ def handle_float_array(


def write_dv(
input_file: str | PathLike[str],
output_file: str | PathLike[str],
array: NDArray[Any],
wavelengths: Collection[int],
zoomfact: float,
zzoom: int,
channels: Collection[ImageChannel[Wavelengths]],
input_dv: mrc.Mrc,
resolution: ImageResolution,
overwrite: bool = False,
) -> Path:
output_file = Path(output_file)

array = np.stack([c.array for c in channels], axis=-3)
wavelengths = tuple(c.wavelengths.emission_nm_int for c in channels)
if array.size == 0:
raise PySimReconValueError(
"%s will not be created as the array is empty", output_file
Expand All @@ -176,24 +175,19 @@ def write_dv(
else:
raise PySimReconFileExistsError(f"File {output_file} already exists")

if len(wavelengths) != array.shape[-3]:
raise InvalidValueError(
"Length of wavelengths list must be equal to the number of channels in the array"
)
wave = [*wavelengths, 0, 0, 0, 0, 0][:5]
# header_array = get_mrc_header_array(input_file)
bound_mrc = read_mrc_bound_array(input_file)
resolution = image_resolution_from_mrc(bound_mrc.mrc, warn_not_square=False)
metadata = {
"dx": resolution.x,
"dy": resolution.y,
"wave": wave,
}
if resolution.z is not None:
metadata["dz"] = resolution.z
mrc.save(
array,
output_file,
hdr=bound_mrc.mrc.hdr,
metadata={
"dx": resolution.x / zoomfact,
"dy": resolution.y / zoomfact,
"dz": resolution.z / zzoom,
"wave": wave,
},
hdr=input_dv.hdr,
metadata=metadata,
)
logger.info(
"%s saved",
Expand Down
182 changes: 110 additions & 72 deletions src/sim_recon/images/tiff.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,21 +4,22 @@
from pathlib import Path
import numpy as np
import tifffile as tf
from typing import TYPE_CHECKING, cast
from typing import TYPE_CHECKING

from .dataclasses import ImageResolution, ImageChannel
from ..info import __version__
from ..exceptions import (
PySimReconFileExistsError,
PySimReconValueError,
UndefinedValueError,
PySimReconIOError,
)

if TYPE_CHECKING:
from typing import Any, Generator
from typing import Any
from collections.abc import Generator, Collection
from os import PathLike
from numpy.typing import NDArray
from .dataclasses import ImageChannel
from .dataclasses import Wavelengths


logger = logging.getLogger(__name__)
Expand All @@ -38,51 +39,66 @@ def read_tiff(filepath: str | PathLike[str]) -> NDArray[Any]:
return tiff.asarray()


def generate_memmaps_from_tiffs(
*file_paths: str | PathLike[str],
) -> Generator[NDArray[Any], None, None]:
for fp in file_paths:
try:
yield tf.memmap(fp).squeeze()
except Exception as e:
logger.error("Unable to read image from %s: %s", fp, e)
raise


def get_combined_array_from_tiffs(
*file_paths: str | PathLike[str],
) -> NDArray[Any]:
logger.debug(
"Combining tiffs from:\n%s",
"\n\t".join(str(fp) for fp in file_paths),
)
if not file_paths:
raise PySimReconValueError("Cannot create a combined array without files")
def get_memmap_from_tiff(file_path: str | PathLike[str]) -> NDArray[Any]:
try:
return np.stack(tuple(generate_memmaps_from_tiffs(*file_paths)), -3)
except Exception:
raise PySimReconIOError("Failed to combine TIFF files")
return tf.memmap(file_path).squeeze()
except Exception as e:
logger.error("Unable to read image from %s: %s", file_path, e)
raise


def generate_channels_from_tiffs(
    *wavelengths_path_tuple: tuple[Wavelengths, Path]
) -> Generator[ImageChannel[Wavelengths], None, None]:
    """Lazily yield one ``ImageChannel`` per ``(wavelengths, path)`` pair.

    Each array is obtained with ``get_memmap_from_tiff`` and paired with the
    wavelengths supplied alongside its path. Channels are yielded in the
    order the pairs were given.

    Args:
        *wavelengths_path_tuple: ``(wavelengths, tiff_path)`` pairs, one per
            channel.

    Yields:
        An ``ImageChannel`` holding the wavelengths and the array read from
        the corresponding TIFF file.

    Raises:
        PySimReconIOError: If a TIFF file cannot be read. The underlying
            exception is attached as ``__cause__``.
    """
    for wavelengths, fp in wavelengths_path_tuple:
        # Only the read is guarded: keeping the ``yield`` outside the ``try``
        # means an exception thrown into the generator by its consumer is not
        # misreported as a TIFF read failure.
        try:
            array = get_memmap_from_tiff(fp)
        except Exception as e:
            raise PySimReconIOError(f"Failed to read TIFF file '{fp}'") from e
        yield ImageChannel(wavelengths=wavelengths, array=array)


def write_tiff(
output_path: str | PathLike[str],
*channels: ImageChannel,
xy_pixel_size_microns: tuple[float | None, float | None] | None = None,
*images: Collection[ImageChannel[Wavelengths] | ImageChannel[None]],
resolution: ImageResolution | None = None,
ome: bool = True,
overwrite: bool = False,
allow_empty_channels: bool = False,
) -> None:
def get_channel_dict(channel: ImageChannel) -> dict[str, Any] | None:
channel_dict: dict[str, Any] = {}
if channel.wavelengths is None:
return None
if channel.wavelengths.excitation_nm is not None:
channel_dict["ExcitationWavelength"] = channel.wavelengths.excitation_nm
channel_dict["ExcitationWavelengthUnits"] = "nm"
if channel.wavelengths.emission_nm is not None:
channel_dict["EmissionWavelength"] = channel.wavelengths.emission_nm
channel_dict["EmissionWavelengthUnits"] = "nm"
return channel_dict
allow_missing_channel_info: bool = False,
overwrite: bool = False,
) -> Path:
def get_ome_channel_dict(*channels: ImageChannel) -> dict[str, Any] | None:
names: list[str] = []
excitation_wavelengths: list[float | None] = []
excitation_wavelength_units: list[str | None] = []
emission_wavelengths: list[float | None] = []
emission_wavelength_units: list[str | None] = []
for i, channel in enumerate(channels, 1):
if channel.wavelengths is None:
if not allow_missing_channel_info:
raise UndefinedValueError(f"Missing wavelengths for channel {i}")
names.append(f"Channel {i}")
excitation_wavelengths.append(None)
excitation_wavelength_units.append(None)
emission_wavelengths.append(None)
emission_wavelength_units.append(None)

else:
names.append(
f"Channel {i}"
if channel.wavelengths.emission_nm_int is None
else str(channel.wavelengths.emission_nm_int)
)
excitation_wavelengths.append(channel.wavelengths.excitation_nm)
excitation_wavelength_units.append("nm")
emission_wavelengths.append(channel.wavelengths.emission_nm)
emission_wavelength_units.append("nm")
return {
"Name": names,
"ExcitationWavelength": excitation_wavelengths,
"ExcitationWavelengthUnits": excitation_wavelength_units,
"EmissionWavelength": emission_wavelengths,
"EmissionWavelengthUnits": emission_wavelength_units,
}

output_path = Path(output_path)

Expand All @@ -95,32 +111,34 @@ def get_channel_dict(channel: ImageChannel) -> dict[str, Any] | None:
else:
raise PySimReconFileExistsError(f"File {output_path} already exists")

tiff_metadata: dict[str, Any] = {}
tiff_kwargs: dict[str, Any] = {
"software": f"PySIMRecon {__version__}",
"photometric": "MINISBLACK",
"metadata": {},
"metadata": tiff_metadata,
}

if xy_pixel_size_microns is not None and None not in xy_pixel_size_microns:
xy_pixel_size_microns = cast(tuple[float, float], xy_pixel_size_microns)
if resolution is not None:
# TIFF tags:
tiff_kwargs["resolution"] = (
1e4 / xy_pixel_size_microns[0],
1e4 / xy_pixel_size_microns[1],
1e4 / resolution.x,
1e4 / resolution.y,
)
tiff_kwargs["resolutionunit"] = (
tf.RESUNIT.CENTIMETER
) # Use CENTIMETER for maximum compatibility

if ome:
if xy_pixel_size_microns is not None:
tiff_metadata["Name"] = output_path.name
if resolution is not None:
# OME PhysicalSize:
if xy_pixel_size_microns[0] is not None:
tiff_kwargs["metadata"]["PhysicalSizeX"] = xy_pixel_size_microns[0]
tiff_kwargs["metadata"]["PhysicalSizeXUnit"] = "µm"
if xy_pixel_size_microns[1] is not None:
tiff_kwargs["metadata"]["PhysicalSizeY"] = xy_pixel_size_microns[1]
tiff_kwargs["metadata"]["PhysicalSizeYUnit"] = "µm"
tiff_metadata["PhysicalSizeX"] = resolution.x
tiff_metadata["PhysicalSizeXUnit"] = "µm"
tiff_metadata["PhysicalSizeY"] = resolution.x
tiff_metadata["PhysicalSizeYUnit"] = "µm"
if resolution.z is not None:
tiff_metadata["PhysicalSizeZ"] = resolution.z
tiff_metadata["PhysicalSizeYUnit"] = "µm"

with tf.TiffWriter(
output_path,
Expand All @@ -129,23 +147,43 @@ def get_channel_dict(channel: ImageChannel) -> dict[str, Any] | None:
ome=ome,
shaped=not ome,
) as tiff:
for channel in channels:
if channel.array is None:
if allow_empty_channels:
logger.warning(
"Channel %s has no array to write",
channel.wavelengths,
for channels in images:
channels_to_write = []
for channel in channels:
if channel.array is None:
if allow_empty_channels:
logger.warning(
"Channel %s has no array to write",
channel.wavelengths,
)
continue
raise UndefinedValueError(
f"{output_path} will not be created as channel {channel.wavelengths} has no array to write",
)
continue
raise UndefinedValueError(
f"{output_path} will not be created as channel {channel.wavelengths} has no array to write",
)
channel_kwargs = tiff_kwargs.copy()
channel_kwargs["metadata"]["axes"] = (
"YX" if channel.array.ndim == 2 else "ZYX"
)
channels_to_write.append(channel)

array = np.stack(
[_.array for _ in channels], axis=0
) # adds another axis, even if only 1 channel

image_kwargs = tiff_kwargs.copy()
if array.ndim == 3:
# In case the images are 2D
tiff_metadata["axes"] = "CYX"
else:
tiff_metadata["axes"] = "CZYX"

if ome:
channel_dict = get_channel_dict(channel)
if channel_dict is not None:
channel_kwargs["metadata"]["Channel"] = channel_dict
tiff.write(channel.array, **channel_kwargs)
try:
channel_dict = get_ome_channel_dict(*channels)
if channel_dict is not None:
tiff_metadata["Channel"] = channel_dict
except UndefinedValueError as e:
if not allow_missing_channel_info:
logger.error("Failed to write image '%s': %s", output_path, e)
raise
logger.warning(
"Writing without OME Channel metadata due to error: %s", e
)
tiff.write(array, **image_kwargs)
return output_path
Loading

0 comments on commit 2c6ba2e

Please sign in to comment.