Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/simplify multidim dataloading (#448) #450

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
178 changes: 24 additions & 154 deletions cyto_dl/datamodules/multidim_image.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,30 @@
from pathlib import Path
from typing import Callable, Dict, List, Optional, Tuple, Union
from typing import Callable, Dict, Optional, Sequence, Union

import numpy as np
import pandas as pd
import torch
from bioio import BioImage
from monai.data import DataLoader, Dataset, MetaTensor
from monai.transforms import Compose, ToTensor, apply_transform
from omegaconf import ListConfig
from monai.data import CacheDataset
from omegaconf import OmegaConf


class MultiDimImageDataset(Dataset):
"""Dataset converting a `.csv` file listing multi dimensional (timelapse or multi-scene) files
and some metadata into batches of single- scene, single-timepoint, single-channel images."""
class MultiDimImageDataset(CacheDataset):
"""Dataset converting a `.csv` file or dictionary listing multi dimensional (timelapse or
multi-scene) files and some metadata into batches of metadata intended for the
AICSImageLoaderd class."""

def __init__(
self,
csv_path: Union[Path, str],
img_path_column: str,
channel_column: str,
out_key: str,
csv_path: Optional[Union[Path, str]] = None,
img_path_column: str = "path",
channel_column: str = "channel",
spatial_dims: int = 3,
scene_column: str = "scene",
time_start_column: str = "start",
time_stop_column: str = "stop",
time_step_column: str = "step",
dict_meta: Optional[Dict] = None,
transform: Optional[Callable] = None,
dask_load: bool = True,
transform: Optional[Union[Callable, Sequence[Callable]]] = [],
**cache_kwargs,
):
"""
Parameters
Expand All @@ -38,8 +35,6 @@ def __init__(
column in `csv_path` that contains path to multi dimensional (timelapse or multi-scene) file
channel_column:str
Column in `csv_path` that contains which channel to extract from multi dimensional (timelapse or multi-scene) file. Should be an integer.
out_key:str
Key where single-scene/timepoint/channel is saved in output dictionary
spatial_dims:int=3
Spatial dimension of output image. Must be 2 for YX or 3 for ZYX
scene_column:str="scene",
Expand All @@ -56,27 +51,27 @@ def __init__(
If any of `start_column`, `stop_column`, or `step_column` are not specified, all timepoints are extracted.
dict_meta: Optional[Dict]
Dictionary version of CSV file. If not provided, CSV file is read from `csv_path`.
transform: Optional[Callable] = None
Callable that accepts a numpy array. For example, image normalization functions could be passed here.
dask_load: bool = True
Whether to use dask to load images. If False, full images are loaded into memory before extracting specified scenes/timepoints.
transform: Optional[Callable] = []
List (or Compose object) of Monai dictionary-style transforms to apply to the image metadata. Typically, the first transform should be AICSImageLoaderd.
cache_kwargs:
Additional keyword arguments to pass to `CacheDataset`. To skip the caching mechanism, set `cache_num` to 0.
"""
super().__init__(None, transform)
df = pd.read_csv(csv_path) if csv_path is not None else pd.DataFrame([dict_meta])

df = (
pd.read_csv(csv_path)
if csv_path is not None
else pd.DataFrame(OmegaConf.to_container(dict_meta))
)
self.img_path_column = img_path_column
self.channel_column = channel_column
self.scene_column = scene_column
self.time_start_column = time_start_column
self.time_stop_column = time_stop_column
self.time_step_column = time_step_column
self.out_key = out_key
if spatial_dims not in (2, 3):
raise ValueError(f"`spatial_dims` must be 2 or 3, got {spatial_dims}")
self.spatial_dims = spatial_dims
self.dask_load = dask_load

self.img_data = self.get_per_file_args(df)
data = self.get_per_file_args(df)
super().__init__(data, transform, **cache_kwargs)

def _get_scenes(self, row, img):
scenes = row.get(self.scene_column, -1)
Expand Down Expand Up @@ -109,136 +104,11 @@ def get_per_file_args(self, df):
for timepoint in timepoints:
img_data.append(
{
"img": img,
"dimension_order_out": "ZYX"[-self.spatial_dims :],
"dimension_order_out": "C" + "ZYX"[-self.spatial_dims :],
"C": row[self.channel_column],
"scene": scene,
"T": timepoint,
"original_path": row[self.img_path_column],
}
)
img_data.reverse()
return img_data

def _metadata_to_str(self, metadata):
return "_".join([] + [f"{k}={v}" for k, v in metadata.items()])

def _ensure_channel_first(self, img):
while len(img.shape) < self.spatial_dims + 1:
img = np.expand_dims(img, 0)
return img

def create_metatensor(self, img, meta):
    """Return `img` as a ``MetaTensor`` carrying `meta`.

    Parameters
    ----------
    img
        ``np.ndarray``, ``torch.Tensor`` or ``MetaTensor``. Numpy arrays are
        cast with ``astype(float)`` and converted with ``torch.from_numpy``.
    meta
        Metadata mapping. It is merged into the existing metadata when `img`
        is already a ``MetaTensor``, otherwise it becomes the metadata of the
        newly created ``MetaTensor``.

    Returns
    -------
    MetaTensor
        The same object (updated in place) when `img` is a ``MetaTensor``,
        otherwise a new ``MetaTensor`` wrapping the tensor data.

    Raises
    ------
    ValueError
        If `img` is none of the supported types.
    """
    if isinstance(img, np.ndarray):
        img = torch.from_numpy(img.astype(float))
    # Validate first so unsupported inputs fail with a message that lists every
    # accepted type. MetaTensor derives from torch.Tensor, so this check also
    # lets MetaTensor inputs through.
    if not isinstance(img, torch.Tensor):
        raise ValueError(
            f"Expected img to be np.ndarray, MetaTensor or torch.Tensor, got {type(img)}"
        )
    if isinstance(img, MetaTensor):
        img.meta.update(meta)
        return img
    return MetaTensor(img, meta=meta)

def is_batch(self, x):
    """Return True when `x` is a list or has ``spatial_dims + 2`` dimensions."""
    if isinstance(x, list):
        return True
    batched_ndim = self.spatial_dims + 2
    return len(x.shape) == batched_ndim

def _transform(self, index: int):
    """Load the next queued image, wrap it as a MetaTensor and apply the transform.

    `index` is not used: each call pops the last entry of ``self.img_data``,
    so the list shrinks with every call. The entry's ``img`` object is switched
    to the entry's scene and the remaining entry keys are forwarded as keyword
    arguments to the image reader.

    Returns a dict ``{out_key: image}``, or a list of such dicts when the
    transformed output is a batch according to ``is_batch``.
    """
    # Consume the next entry; this mutates self.img_data.
    img_data = self.img_data.pop()
    # Remove the keys that must not be forwarded as reader kwargs.
    img = img_data.pop("img")
    original_path = img_data.pop("original_path")
    scene = img_data.pop("scene")
    img.set_scene(scene)

    if self.dask_load:
        # Lazy read through dask; only the requested selection is computed.
        data_i = img.get_image_dask_data(**img_data).compute()
    else:
        data_i = img.get_image_data(**img_data)
    # add scene and path information back to metadata
    img_data["scene"] = scene
    img_data["original_path"] = original_path
    data_i = self._ensure_channel_first(data_i)
    data_i = self.create_metatensor(data_i, img_data)

    output_img = (
        apply_transform(self.transform, data_i) if self.transform is not None else data_i
    )
    # some monai transforms return a batch. When collated, the batch dimension gets moved to the channel dimension
    if self.is_batch(output_img):
        return [{self.out_key: img} for img in output_img]
    return {self.out_key: output_img}

def __len__(self):
    """Return the number of entries currently held in ``self.img_data``."""
    return len(self.img_data)


def make_multidim_image_dataloader(
    csv_path: Optional[Union[Path, str]] = None,
    img_path_column: str = "path",
    channel_column: str = "channel",
    out_key: str = "image",
    spatial_dims: int = 3,
    scene_column: str = "scene",
    time_start_column: str = "start",
    time_stop_column: str = "stop",
    time_step_column: str = "step",
    dict_meta: Optional[Dict] = None,
    transforms: Optional[Union[List[Callable], Tuple[Callable], ListConfig]] = None,
    **dataloader_kwargs,
) -> DataLoader:
    """Function to create a MultiDimImage DataLoader. Currently, this dataset is only useful during
    prediction and cannot be used for training or testing.

    Parameters
    ----------
    csv_path: Optional[Union[Path, str]]
        path to csv
    img_path_column: str
        column in `csv_path` that contains path to multi dimensional (timelapse or multi-scene) file
    channel_column: str
        Column in `csv_path` that contains which channel to extract from multi dim image file. Should be an integer.
    out_key: str
        Key where single-scene/timepoint/channel is saved in output dictionary
    spatial_dims: int
        Spatial dimension of output image. Must be 2 for YX or 3 for ZYX
    scene_column: str
        Column in `csv_path` that contains scenes to extract from multiscene file. If not specified, all scenes will
        be extracted. If multiple scenes are specified, they should be separated by a comma (e.g. `scene1,scene2`)
    time_start_column: str
        Column in `csv_path` specifying which timepoint in timelapse image to start extracting. If any of `start_column`, `stop_column`, or `step_column`
        are not specified, all timepoints are extracted.
    time_stop_column: str
        Column in `csv_path` specifying which timepoint in timelapse image to stop extracting. If any of `start_column`, `stop_column`, or `step_column`
        are not specified, all timepoints are extracted.
    time_step_column: str
        Column in `csv_path` specifying step between timepoints. For example, values in this column should be `2` if every other timepoint should be run.
        If any of `start_column`, `stop_column`, or `step_column` are not specified, all timepoints are extracted.
    dict_meta: Optional[Dict]
        Dictionary version of CSV file. If not provided, CSV file is read from `csv_path`.
    transforms: Optional[Union[List[Callable], Tuple[Callable], ListConfig]]
        Callable or list of callables that accept numpy array. For example, image normalization functions could be passed here. Dataloading is already handled by the dataset.
    dataloader_kwargs:
        Additional keyword arguments forwarded to `DataLoader`. `num_workers` is capped at 1 and
        defaults to 0 when omitted.

    Returns
    -------
    DataLoader
        The DataLoader object for the MultiDimImage dataset.
    """
    # A plain sequence of transforms is chained into a single callable.
    if isinstance(transforms, (list, tuple, ListConfig)):
        transforms = Compose(transforms)
    dataset = MultiDimImageDataset(
        csv_path,
        img_path_column,
        channel_column,
        out_key,
        spatial_dims,
        scene_column=scene_column,
        time_start_column=time_start_column,
        time_stop_column=time_stop_column,
        time_step_column=time_step_column,
        dict_meta=dict_meta,
        transform=transforms,
    )
    # currently only supports 0/1 workers; default to 0 rather than raising
    # KeyError when the caller does not pass `num_workers`
    num_workers = min(dataloader_kwargs.pop("num_workers", 0), 1)
    return DataLoader(dataset, num_workers=num_workers, **dataloader_kwargs)
1 change: 1 addition & 0 deletions cyto_dl/image/io/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .bioio_loader import BioIOImageLoaderd
from .monai_bio_reader import MonaiBioReader
from .numpy_reader import ReadNumpyFile
from .ome_zarr_reader import OmeZarrReader
Expand Down
Original file line number Diff line number Diff line change
@@ -1,17 +1,19 @@
import re
from typing import List

import numpy as np
from bioio import BioImage
from monai.data import MetaTensor
from monai.transforms import Transform

from cyto_dl.models.im2im.utils.postprocessing.arg_checking import get_dtype

class AICSImageLoaderd(Transform):
"""Enumerates scenes and timepoints for dictionary with format.

{path_key: path, channel_key: channel, scene_key: scene, timepoint_key: timepoint}. Differs
from monai_bio_reader in that reading kwargs are passed in the dictionary, instead of fixed at
initialization.
class BioIOImageLoaderd(Transform):
"""Enumerates scenes and timepoints for dictionary with format.
{path_key: path, channel_key: channel, scene_key: scene, timepoint_key: timepoint}.
Differs from monai_bio_reader in that reading kwargs are passed in the dictionary, instead of fixed at
initialization. The filepath will be saved in the dictionary as 'filename_or_obj' (with or without metadata depending on `include_meta_in_filename`).
"""

def __init__(
Expand All @@ -23,6 +25,7 @@ def __init__(
allow_missing_keys=False,
dtype: np.dtype = np.float16,
dask_load: bool = True,
include_meta_in_filename: bool = False,
):
"""
Parameters
Expand All @@ -37,23 +40,35 @@ def __init__(
Key for the output image
allow_missing_keys : bool = False
Whether to allow missing keys in the data dictionary
dtype : np.dtype = np.float16
Data type to cast the image to
dask_load: bool = True
Whether to use dask to load images. If False, full images are loaded into memory before extracting specified scenes/timepoints.
include_meta_in_filename: bool = False
Whether to include metadata in the filename. Useful when loading multi-dimensional images with different kwargs.
"""
super().__init__()
self.path_key = path_key
self.kwargs_keys = kwargs_keys
self.allow_missing_keys = allow_missing_keys
self.out_key = out_key
self.scene_key = scene_key
self.dtype = dtype
self.dtype = get_dtype(dtype)
self.dask_load = dask_load
self.include_meta_in_filename = include_meta_in_filename

def split_args(self, arg):
    """Split a comma-separated string into a list of integers.

    Only strings containing a comma are split. Every other value, including
    lists, tuples and plain integers, is returned unchanged, so non-string
    values whose ``str()`` happens to contain a comma never reach ``.split``.

    Raises
    ------
    ValueError
        If a comma-separated piece cannot be converted with ``int``.
    """
    if isinstance(arg, str) and "," in arg:
        return [int(piece) for piece in arg.split(",")]
    return arg

def _get_filename(self, path, kwargs):
if self.include_meta_in_filename:
path = path.split(".")[0] + "_" + "_".join([f"{k}_{v}" for k, v in kwargs.items()])
# remove illegal characters from filename
path = re.sub(r'[<>:"|?*]', "", path)
return path

def __call__(self, data):
# copying prevents the dataset from being modified inplace - important when using partially cached datasets so that the memory use doesn't increase over time
data = data.copy()
Expand All @@ -69,6 +84,6 @@ def __call__(self, data):
else:
img = img.get_image_data(**kwargs)
img = img.astype(self.dtype)
data[self.out_key] = MetaTensor(img, meta={"filename_or_obj": path, "kwargs": kwargs})

kwargs.update({"filename_or_obj": self._get_filename(path, kwargs)})
data[self.out_key] = MetaTensor(img, meta=kwargs)
return data
33 changes: 33 additions & 0 deletions cyto_dl/image/transforms/add_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from monai.data import MetaTensor
from monai.transforms import Transform
from omegaconf import ListConfig


class AddMeta(Transform):
Expand Down Expand Up @@ -33,3 +34,35 @@ def __call__(self, data):
del data[k]
data[self.image_key].meta.update(new_meta)
return data


class MetaToKey(Transform):
    """Transform to add metadata from image key to the batch dictionary."""

    def __init__(self, image_key: str, meta_keys: Sequence[str], replace: bool = False):
        """
        Parameters
        ----------
        image_key: str
            Key in batch dictionary for image data. Must be a MetaTensor
        meta_keys: Sequence[str]
            List of keys to add to batch dictionary. A single key may also be passed
            on its own; lists, tuples and ListConfigs are treated as collections of keys.
        replace: bool
            If True, replace meta_keys in batch dictionary with those from image metadata if they already exist
        """
        super().__init__()
        # Tuples are accepted as key collections too; previously a tuple was
        # wrapped whole and treated as one (never matching) key.
        if isinstance(meta_keys, (list, tuple, ListConfig)):
            self.meta_keys = list(meta_keys)
        else:
            self.meta_keys = [meta_keys]
        self.image_key = image_key
        self.replace = replace

    def __call__(self, data):
        """Copy each configured metadata entry of the image into `data`.

        Raises
        ------
        ValueError
            If the image is not a MetaTensor, if a key already exists in `data`
            while `replace` is False, or if a key is missing from the image metadata.
        """
        if not self.meta_keys:
            return data

        # The image type does not depend on the key, so validate it once.
        image = data[self.image_key]
        if not isinstance(image, MetaTensor):
            raise ValueError(
                f"Image key {self.image_key} must be a MetaTensor, got {type(image)}"
            )
        for k in self.meta_keys:
            if k in data and not self.replace:
                raise ValueError(f"Key {k} already exists in batch dictionary")
            if k not in image.meta:
                raise ValueError(f"Key {k} not found in image metadata")
            data[k] = image.meta[k]
        return data
16 changes: 4 additions & 12 deletions cyto_dl/models/im2im/utils/postprocessing/act_thresh_label.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from skimage.exposure import rescale_intensity
from skimage.measure import label

from cyto_dl.models.im2im.utils.postprocessing.arg_checking import get_dtype


class ActThreshLabel:
"""General-purpose postprocessing transform for applying any of an activation, threshold,
Expand Down Expand Up @@ -40,22 +42,12 @@ def __init__(
self.activation = activation
self.threshold = threshold
self.label = label
self.dtype = self._get_dtype(dtype)
self.dtype = get_dtype(dtype)
self.ch = ch
self.rescale_dtype = self._get_dtype(rescale_dtype)
self.rescale_dtype = get_dtype(rescale_dtype)
if self.rescale_dtype is not None:
self.dtype = self.rescale_dtype

def _get_dtype(self, dtype: DTypeLike) -> DTypeLike:
if isinstance(dtype, str):
return get_class(dtype)
elif dtype is None:
return dtype
elif isinstance(dtype, type):
return dtype
else:
raise ValueError(f"Expected dtype to be DtypeLike, string, or None, got {type(dtype)}")

def __call__(self, img: torch.Tensor) -> np.ndarray:
if self.ch > 0:
img = img[self.ch]
Expand Down
Loading
Loading