diff --git a/cyto_dl/datamodules/multidim_image.py b/cyto_dl/datamodules/multidim_image.py index 14643061..ffa2842e 100644 --- a/cyto_dl/datamodules/multidim_image.py +++ b/cyto_dl/datamodules/multidim_image.py @@ -1,33 +1,30 @@ from pathlib import Path -from typing import Callable, Dict, List, Optional, Tuple, Union +from typing import Callable, Dict, Optional, Sequence, Union -import numpy as np import pandas as pd -import torch from bioio import BioImage -from monai.data import DataLoader, Dataset, MetaTensor -from monai.transforms import Compose, ToTensor, apply_transform -from omegaconf import ListConfig +from monai.data import CacheDataset +from omegaconf import OmegaConf -class MultiDimImageDataset(Dataset): - """Dataset converting a `.csv` file listing multi dimensional (timelapse or multi-scene) files - and some metadata into batches of single- scene, single-timepoint, single-channel images.""" +class MultiDimImageDataset(CacheDataset): + """Dataset converting a `.csv` file or dictionary listing multi dimensional (timelapse or + multi-scene) files and some metadata into batches of metadata intended for the + AICSImageLoaderd class.""" def __init__( self, - csv_path: Union[Path, str], - img_path_column: str, - channel_column: str, - out_key: str, + csv_path: Optional[Union[Path, str]] = None, + img_path_column: str = "path", + channel_column: str = "channel", spatial_dims: int = 3, scene_column: str = "scene", time_start_column: str = "start", time_stop_column: str = "stop", time_step_column: str = "step", dict_meta: Optional[Dict] = None, - transform: Optional[Callable] = None, - dask_load: bool = True, + transform: Optional[Union[Callable, Sequence[Callable]]] = [], + **cache_kwargs, ): """ Parameters @@ -38,8 +35,6 @@ def __init__( column in `csv_path` that contains path to multi dimensional (timelapse or multi-scene) file channel_column:str Column in `csv_path` that contains which channel to extract from multi dimensional (timelapse or multi-scene) file. Should be an integer. - out_key:str - Key where single-scene/timepoint/channel is saved in output dictionary spatial_dims:int=3 Spatial dimension of output image. Must be 2 for YX or 3 for ZYX scene_column:str="scene", @@ -56,27 +51,27 @@ def __init__( If any of `start_column`, `stop_column`, or `step_column` are not specified, all timepoints are extracted. dict_meta: Optional[Dict] Dictionary version of CSV file. If not provided, CSV file is read from `csv_path`. - transform: Optional[Callable] = None - Callable to that accepts numpy array. For example, image normalization functions could be passed here. - dask_load: bool = True - Whether to use dask to load images. If False, full images are loaded into memory before extracting specified scenes/timepoints. + transform: Optional[Callable] = [] + List (or Compose Object) or Monai dictionary-style transforms to apply to the image metadata. Typically, the first transform should be AICSImageLoaderd. + cache_kwargs: + Additional keyword arguments to pass to `CacheDataset`. To skip the caching mechanism, set `cache_num` to 0. """ - super().__init__(None, transform) - df = pd.read_csv(csv_path) if csv_path is not None else pd.DataFrame([dict_meta]) - + df = ( + pd.read_csv(csv_path) + if csv_path is not None + else pd.DataFrame(OmegaConf.to_container(dict_meta)) + ) self.img_path_column = img_path_column self.channel_column = channel_column self.scene_column = scene_column self.time_start_column = time_start_column self.time_stop_column = time_stop_column self.time_step_column = time_step_column - self.out_key = out_key if spatial_dims not in (2, 3): raise ValueError(f"`spatial_dims` must be 2 or 3, got {spatial_dims}") self.spatial_dims = spatial_dims - self.dask_load = dask_load - - self.img_data = self.get_per_file_args(df) + data = self.get_per_file_args(df) + super().__init__(data, transform, **cache_kwargs) def _get_scenes(self, row, img): scenes = row.get(self.scene_column, -1) @@ -109,136 +104,11 @@ def get_per_file_args(self, df): for timepoint in timepoints: img_data.append( { - "img": img, - "dimension_order_out": "ZYX"[-self.spatial_dims :], + "dimension_order_out": "C" + "ZYX"[-self.spatial_dims :], "C": row[self.channel_column], "scene": scene, "T": timepoint, "original_path": row[self.img_path_column], } ) - img_data.reverse() return img_data - - def _metadata_to_str(self, metadata): - return "_".join([] + [f"{k}={v}" for k, v in metadata.items()]) - - def _ensure_channel_first(self, img): - while len(img.shape) < self.spatial_dims + 1: - img = np.expand_dims(img, 0) - return img - - def create_metatensor(self, img, meta): - if isinstance(img, np.ndarray): - img = torch.from_numpy(img.astype(float)) - if isinstance(img, MetaTensor): - img.meta.update(meta) - return img - elif isinstance(img, torch.Tensor): - return MetaTensor( - img, - meta=meta, - ) - raise ValueError(f"Expected img to be MetaTensor or torch.Tensor, got {type(img)}") - - def is_batch(self, x): - return isinstance(x, list) or len(x.shape) == self.spatial_dims + 2 - - def _transform(self, index: int): - img_data = self.img_data.pop() - img = img_data.pop("img") - original_path = img_data.pop("original_path") - scene = img_data.pop("scene") - img.set_scene(scene) - - if self.dask_load: - data_i = img.get_image_dask_data(**img_data).compute() - else: - data_i = img.get_image_data(**img_data) - # add scene and path information back to metadata - img_data["scene"] = scene - img_data["original_path"] = original_path - data_i = self._ensure_channel_first(data_i) - data_i = self.create_metatensor(data_i, img_data) - - output_img = ( - apply_transform(self.transform, data_i) if self.transform is not None else data_i - ) - # some monai transforms return a batch. When collated, the batch dimension gets moved to the channel dimension - if self.is_batch(output_img): - return [{self.out_key: img} for img in output_img] - return {self.out_key: output_img} - - def __len__(self): - return len(self.img_data) - - -def make_multidim_image_dataloader( - csv_path: Optional[Union[Path, str]] = None, - img_path_column: str = "path", - channel_column: str = "channel", - out_key: str = "image", - spatial_dims: int = 3, - scene_column: str = "scene", - time_start_column: str = "start", - time_stop_column: str = "stop", - time_step_column: str = "step", - dict_meta: Optional[Dict] = None, - transforms: Optional[Union[List[Callable], Tuple[Callable], ListConfig]] = None, - **dataloader_kwargs, -) -> DataLoader: - """Function to create a MultiDimImage DataLoader. Currently, this dataset is only useful during - prediction and cannot be used for training or testing. - - Parameters - ---------- - csv_path: Optional[Union[Path, str]] - path to csv - img_path_column: str - column in `csv_path` that contains path to multi dimensional (timelapse or multi-scene) file - channel_column: str - Column in `csv_path` that contains which channel to extract from multi dim image file. Should be an integer. - out_key: str - Key where single-scene/timepoint/channel is saved in output dictionary - spatial_dims: int - Spatial dimension of output image. Must be 2 for YX or 3 for ZYX - scene_column: str - Column in `csv_path` that contains scenes to extract from multiscene file. If not specified, all scenes will - be extracted. If multiple scenes are specified, they should be separated by a comma (e.g. `scene1,scene2`) - time_start_column: str - Column in `csv_path` specifying which timepoint in timelapse image to start extracting. If any of `start_column`, `stop_column`, or `step_column` - are not specified, all timepoints are extracted. - time_stop_column: str - Column in `csv_path` specifying which timepoint in timelapse image to stop extracting. If any of `start_column`, `stop_column`, or `step_column` - are not specified, all timepoints are extracted. - time_step_column: str - Column in `csv_path` specifying step between timepoints. For example, values in this column should be `2` if every other timepoint should be run. - If any of `start_column`, `stop_column`, or `step_column` are not specified, all timepoints are extracted. - dict_meta: Optional[Dict] - Dictionary version of CSV file. If not provided, CSV file is read from `csv_path`. - transforms: Optional[Union[List[Callable], Tuple[Callable], ListConfig]] - Callable or list of callables that accept numpy array. For example, image normalization functions could be passed here. Dataloading is already handled by the dataset. - - Returns - ------- - DataLoader - The DataLoader object for the MultiDimIMage dataset. - """ - if isinstance(transforms, (list, tuple, ListConfig)): - transforms = Compose(transforms) - dataset = MultiDimImageDataset( - csv_path, - img_path_column, - channel_column, - out_key, - spatial_dims, - scene_column=scene_column, - time_start_column=time_start_column, - time_stop_column=time_stop_column, - time_step_column=time_step_column, - dict_meta=dict_meta, - transform=transforms, - ) - # currently only supports a 0/1 workers - num_workers = min(dataloader_kwargs.pop("num_workers"), 1) - return DataLoader(dataset, num_workers=num_workers, **dataloader_kwargs) diff --git a/cyto_dl/image/io/__init__.py b/cyto_dl/image/io/__init__.py index fa360bb9..cf20e25c 100644 --- a/cyto_dl/image/io/__init__.py +++ b/cyto_dl/image/io/__init__.py @@ -1,3 +1,4 @@ +from .bioio_loader import BioIOImageLoaderd from .monai_bio_reader import MonaiBioReader from .numpy_reader import ReadNumpyFile from .ome_zarr_reader import OmeZarrReader diff --git a/cyto_dl/image/io/aicsimage_loader.py b/cyto_dl/image/io/bioio_loader.py similarity index 65% rename from cyto_dl/image/io/aicsimage_loader.py rename to cyto_dl/image/io/bioio_loader.py index 4c7b5a22..8a137f88 100644 --- a/cyto_dl/image/io/aicsimage_loader.py +++ b/cyto_dl/image/io/bioio_loader.py @@ -1,3 +1,4 @@ +import re from typing import List import numpy as np @@ -5,13 +6,14 @@ from monai.data import MetaTensor from monai.transforms import Transform +from cyto_dl.models.im2im.utils.postprocessing.arg_checking import get_dtype -class AICSImageLoaderd(Transform): - """Enumerates scenes and timepoints for dictionary with format. - {path_key: path, channel_key: channel, scene_key: scene, timepoint_key: timepoint}. Differs - from monai_bio_reader in that reading kwargs are passed in the dictionary, instead of fixed at - initialization. +class BioIOImageLoaderd(Transform): + """Enumerates scenes and timepoints for dictionary with format. + {path_key: path, channel_key: channel, scene_key: scene, timepoint_key: timepoint}. + Differs from monai_bio_reader in that reading kwargs are passed in the dictionary, instead of fixed at + initialization. The filepath will be saved in the dictionary as 'filename_or_obj' (with or without metadata depending on `include_meta_in_filename`). """ def __init__( @@ -23,6 +25,7 @@ def __init__( allow_missing_keys=False, dtype: np.dtype = np.float16, dask_load: bool = True, + include_meta_in_filename: bool = False, ): """ Parameters @@ -37,8 +40,12 @@ def __init__( Key for the output image allow_missing_keys : bool = False Whether to allow missing keys in the data dictionary + dtype : np.dtype = np.float16 + Data type to cast the image to dask_load: bool = True Whether to use dask to load images. If False, full images are loaded into memory before extracting specified scenes/timepoints. + include_meta_in_filename: bool = False + Whether to include metadata in the filename. Useful when loading multi-dimensional images with different kwargs. """ super().__init__() self.path_key = path_key @@ -46,14 +53,22 @@ def __init__( self.allow_missing_keys = allow_missing_keys self.out_key = out_key self.scene_key = scene_key - self.dtype = dtype + self.dtype = get_dtype(dtype) self.dask_load = dask_load + self.include_meta_in_filename = include_meta_in_filename def split_args(self, arg): - if "," in str(arg): + if isinstance(arg, str) and "," in arg: return list(map(int, arg.split(","))) return arg + def _get_filename(self, path, kwargs): + if self.include_meta_in_filename: + path = path.split(".")[0] + "_" + "_".join([f"{k}_{v}" for k, v in kwargs.items()]) + # remove illegal characters from filename + path = re.sub(r'[<>:"|?*]', "", path) + return path + def __call__(self, data): # copying prevents the dataset from being modified inplace - important when using partially cached datasets so that the memory use doesn't increase over time data = data.copy() @@ -69,6 +84,6 @@ def __call__(self, data): else: img = img.get_image_data(**kwargs) img = img.astype(self.dtype) - data[self.out_key] = MetaTensor(img, meta={"filename_or_obj": path, "kwargs": kwargs}) - + kwargs.update({"filename_or_obj": self._get_filename(path, kwargs)}) + data[self.out_key] = MetaTensor(img, meta=kwargs) return data diff --git a/cyto_dl/image/transforms/add_meta.py b/cyto_dl/image/transforms/add_meta.py index e2fa267d..5ba06bf7 100644 --- a/cyto_dl/image/transforms/add_meta.py +++ b/cyto_dl/image/transforms/add_meta.py @@ -2,6 +2,7 @@ from monai.data import MetaTensor from monai.transforms import Transform +from omegaconf import ListConfig class AddMeta(Transform): @@ -33,3 +34,35 @@ def __call__(self, data): del data[k] data[self.image_key].meta.update(new_meta) return data + + +class MetaToKey(Transform): + """Transform to add metadata from image key to the batch dictionary.""" + + def __init__(self, image_key: str, meta_keys: Sequence[str], replace: bool = False): + """ + Parameters + ---------- + image_key: str + Key in batch dictionary for image data. Must be a MetaTensor + meta_keys: Sequence[str] + List of keys to add to batch dictionary + replace: bool + If True, replace meta_keys in batch dictionary with those from image metadata if they already exist + """ + self.meta_keys = meta_keys if isinstance(meta_keys, (list, ListConfig)) else [meta_keys] + self.image_key = image_key + self.replace = replace + + def __call__(self, data): + for k in self.meta_keys: + if not isinstance(data[self.image_key], MetaTensor): + raise ValueError( + f"Image key {self.image_key} must be a MetaTensor, got {type(data[self.image_key])}" + ) + if k in data and not self.replace: + raise ValueError(f"Key {k} already exists in batch dictionary") + if k not in data[self.image_key].meta: + raise ValueError(f"Key {k} not found in image metadata") + data[k] = data[self.image_key].meta[k] + return data diff --git a/cyto_dl/models/im2im/utils/postprocessing/act_thresh_label.py b/cyto_dl/models/im2im/utils/postprocessing/act_thresh_label.py index 9ceb1412..206ff445 100644 --- a/cyto_dl/models/im2im/utils/postprocessing/act_thresh_label.py +++ b/cyto_dl/models/im2im/utils/postprocessing/act_thresh_label.py @@ -7,6 +7,8 @@ from skimage.exposure import rescale_intensity from skimage.measure import label +from cyto_dl.models.im2im.utils.postprocessing.arg_checking import get_dtype + class ActThreshLabel: """General-purpose postprocessing transform for applying any of an activation, threshold, @@ -40,22 +42,12 @@ def __init__( self.activation = activation self.threshold = threshold self.label = label - self.dtype = self._get_dtype(dtype) + self.dtype = get_dtype(dtype) self.ch = ch - self.rescale_dtype = self._get_dtype(rescale_dtype) + self.rescale_dtype = get_dtype(rescale_dtype) if self.rescale_dtype is not None: self.dtype = self.rescale_dtype - def _get_dtype(self, dtype: DTypeLike) -> DTypeLike: - if isinstance(dtype, str): - return get_class(dtype) - elif dtype is None: - return dtype - elif isinstance(dtype, type): - return dtype - else: - raise ValueError(f"Expected dtype to be DtypeLike, string, or None, got {type(dtype)}") - def __call__(self, img: torch.Tensor) -> np.ndarray: if self.ch > 0: img = img[self.ch] diff --git a/cyto_dl/models/im2im/utils/postprocessing/arg_checking.py b/cyto_dl/models/im2im/utils/postprocessing/arg_checking.py new file mode 100644 index 00000000..78e4fea1 --- /dev/null +++ b/cyto_dl/models/im2im/utils/postprocessing/arg_checking.py @@ -0,0 +1,11 @@ +from numpy.typing import DTypeLike + +def get_dtype(dtype: DTypeLike) -> DTypeLike: + if isinstance(dtype, str): + return get_class(dtype) + elif dtype is None: + return dtype + elif isinstance(dtype, type): + return dtype + else: + raise ValueError(f"Expected dtype to be DtypeLike, string, or None, got {type(dtype)}") \ No newline at end of file