diff --git a/Makefile b/Makefile
index 3eec088e8..af147b5bc 100644
--- a/Makefile
+++ b/Makefile
@@ -84,7 +84,6 @@ sync-reqs-files: requirements/$(PLATFORM)/requirements.txt \
 	requirements/$(PLATFORM)/torchserve-requirements.txt \
 	requirements/$(PLATFORM)/equiv-requirements.txt \
 	requirements/$(PLATFORM)/spharm-requirements.txt \
-	requirements/$(PLATFORM)/omnipose-requirements.txt \
 	requirements/$(PLATFORM)/all-requirements.txt \
 	requirements/$(PLATFORM)/test-requirements.txt \
 	requirements/$(PLATFORM)/docs-requirements.txt
diff --git a/configs/data/im2im/skoots.yaml b/configs/data/im2im/instance_seg.yaml
similarity index 96%
rename from configs/data/im2im/skoots.yaml
rename to configs/data/im2im/instance_seg.yaml
index 4b916b116..85c5848f1 100644
--- a/configs/data/im2im/skoots.yaml
+++ b/configs/data/im2im/instance_seg.yaml
@@ -35,7 +35,7 @@ transforms:
       - _target_: monai.transforms.Zoomd
         keys: ${data.columns}
         zoom: 0.25
-      - _target_: cyto_dl.models.im2im.utils.SkootsPreprocessd
+      - _target_: cyto_dl.models.im2im.utils.InstanceSegPreprocessd
         label_keys: ${target_col}
         dim: ${spatial_dims}
       - _target_: monai.transforms.ToTensord
@@ -86,7 +86,7 @@ transforms:
       - _target_: monai.transforms.Zoomd
         keys: ${data.columns}
         zoom: 0.25
-      - _target_: cyto_dl.models.im2im.utils.SkootsPreprocessd
+      - _target_: cyto_dl.models.im2im.utils.InstanceSegPreprocessd
         label_keys: ${target_col}
         dim: ${spatial_dims}
       - _target_: monai.transforms.ToTensord
@@ -141,7 +141,7 @@ transforms:
      - _target_: monai.transforms.Zoomd
        keys: ${data.columns}
        zoom: 0.25
-      - _target_: cyto_dl.models.im2im.utils.SkootsPreprocessd
+      - _target_: cyto_dl.models.im2im.utils.InstanceSegPreprocessd
        label_keys: ${target_col}
        dim: ${spatial_dims}
      - _target_: monai.transforms.ToTensord
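Each `_target_` entry in the renamed config above is instantiated by Hydra at runtime. A minimal sketch of that mechanism (assuming `hydra-core`, `omegaconf`, and `monai` are installed; interpolations such as `${data.columns}` are resolved against the full config in the real pipeline, so a literal list stands in for them here):

```python
# Minimal sketch: how a `_target_` config entry becomes an object via Hydra.
# The literal keys list stands in for the ${data.columns} interpolation.
from hydra.utils import instantiate
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "_target_": "monai.transforms.Zoomd",
        "keys": ["raw", "seg"],
        "zoom": 0.25,
    }
)
zoom = instantiate(cfg)  # same as monai.transforms.Zoomd(keys=["raw", "seg"], zoom=0.25)
```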
diff --git a/configs/data/im2im/omnipose.yaml b/configs/data/im2im/omnipose.yaml
deleted file mode 100644
index 8000e53df..000000000
--- a/configs/data/im2im/omnipose.yaml
+++ /dev/null
@@ -1,163 +0,0 @@
-_target_: cyto_dl.datamodules.dataframe.DataframeDatamodule
-
-path:
-cache_dir:
-
-num_workers: 0
-batch_size: 1
-pin_memory: True
-split_column:
-columns:
-  - ${source_col}
-  - ${target_col}
-
-transforms:
-  train:
-    _target_: monai.transforms.Compose
-    transforms:
-      # channels are [blank, membrane,blank, structure, nuclear dye, brightfield ]
-      - _target_: monai.transforms.LoadImaged
-        keys: ${source_col}
-        reader:
-          - _target_: cyto_dl.image.io.MonaiBioReader
-            dimension_order_out: "ZYX"
-            C: 4
-      # channels are [nucseg, cellseg, nuclear boundary seg, cell boundary seg]
-      - _target_: monai.transforms.LoadImaged
-        keys: ${target_col}
-        reader:
-          - _target_: cyto_dl.image.io.MonaiBioReader
-            dimension_order_out: "ZYX"
-            C: 0
-      - _target_: monai.transforms.EnsureChannelFirstd
-        channel_dim: "no_channel"
-        keys: ${data.columns}
-      - _target_: monai.transforms.Zoomd
-        keys: ${data.columns}
-        zoom: 0.25
-      - _target_: cyto_dl.models.im2im.utils.omnipose.OmniposePreprocessd
-        label_keys: ${target_col}
-        dim: ${spatial_dims}
-      - _target_: monai.transforms.ToTensord
-        keys: ${data.columns}
-      - _target_: monai.transforms.NormalizeIntensityd
-        keys: ${source_col}
-        channel_wise: True
-      - _target_: cyto_dl.image.transforms.RandomMultiScaleCropd
-        keys: ${data.columns}
-        patch_shape: ${data._aux.patch_shape}
-        patch_per_image: 1
-        scales_dict: ${kv_to_dict:${data._aux._scales_dict}}
-      - _target_: monai.transforms.RandHistogramShiftd
-        prob: 0.1
-        keys: ${source_col}
-        num_control_points: [90, 500]
-
-      - _target_: monai.transforms.RandStdShiftIntensityd
-        prob: 0.1
-        keys: ${source_col}
-        factors: 0.1
-
-      - _target_: monai.transforms.RandAdjustContrastd
-        prob: 0.1
-        keys: ${source_col}
-        gamma: [0.9, 1.5]
-
-  test:
-    _target_: monai.transforms.Compose
-    transforms:
-      # channels are [blank, membrane,blank, structure, nuclear dye, brightfield ]
-      - _target_: monai.transforms.LoadImaged
-        keys: ${source_col}
-        reader:
-          - _target_: cyto_dl.image.io.MonaiBioReader
-            dimension_order_out: "ZYX"
-            C: 5
-      # channels are [nucseg, cellseg, nuclear boundary seg, cell boundary seg]
-      - _target_: monai.transforms.LoadImaged
-        keys: ${target_col}
-        reader:
-          - _target_: cyto_dl.image.io.MonaiBioReader
-            dimension_order_out: "ZYX"
-            C: 0
-      - _target_: monai.transforms.EnsureChannelFirstd
-        channel_dim: "no_channel"
-        keys: ${data.columns}
-      - _target_: monai.transforms.Zoomd
-        keys: ${data.columns}
-        zoom: 0.25
-      - _target_: cyto_dl.models.im2im.utils.omnipose.OmniposePreprocessd
-        label_keys: ${target_col}
-        dim: ${spatial_dims}
-      - _target_: monai.transforms.ToTensord
-        keys: ${data.columns}
-      - _target_: monai.transforms.NormalizeIntensityd
-        keys: ${source_col}
-        channel_wise: True
-
-  predict:
-    _target_: monai.transforms.Compose
-    transforms:
-      # channels are [blank, membrane,blank, structure, nuclear dye, brightfield ]
-      - _target_: monai.transforms.LoadImaged
-        keys: ${source_col}
-        reader:
-          - _target_: cyto_dl.image.io.MonaiBioReader
-            dimension_order_out: "ZYX"
-            C: 5
-      - _target_: monai.transforms.EnsureChannelFirstd
-        channel_dim: "no_channel"
-        keys: ${data.columns}
-      - _target_: monai.transforms.Zoomd
-        keys: ${data.columns}
-        zoom: 0.25
-      - _target_: monai.transforms.ToTensord
-        keys: ${source_col}
-      - _target_: monai.transforms.NormalizeIntensityd
-        keys: ${source_col}
-        channel_wise: True
-
-  valid:
-    _target_: monai.transforms.Compose
-    transforms:
-      # channels are [blank, membrane,blank, structure, nuclear dye, brightfield ]
-      - _target_: monai.transforms.LoadImaged
-        keys: ${source_col}
-        reader:
-          - _target_: cyto_dl.image.io.MonaiBioReader
-            dimension_order_out: "ZYX"
-            C: 5
-            expand_user: False
-      # channels are [nucseg, cellseg, nuclear boundary seg, cell boundary seg]
-      - _target_: monai.transforms.LoadImaged
-        keys: ${target_col}
-        reader:
-          - _target_: cyto_dl.image.io.MonaiBioReader
-            dimension_order_out: "ZYX"
-            C: 0
-      - _target_: monai.transforms.EnsureChannelFirstd
-        channel_dim: "no_channel"
-        keys: ${data.columns}
-      - _target_: monai.transforms.Zoomd
-        keys: ${data.columns}
-        zoom: 0.25
-      - _target_: cyto_dl.models.im2im.utils.omnipose.OmniposePreprocessd
-        label_keys: ${target_col}
-        dim: ${spatial_dims}
-      - _target_: monai.transforms.ToTensord
-        keys: ${data.columns}
-      - _target_: monai.transforms.NormalizeIntensityd
-        keys: ${source_col}
-        channel_wise: True
-      - _target_: cyto_dl.image.transforms.RandomMultiScaleCropd
-        keys: ${data.columns}
-        patch_shape: ${data._aux.patch_shape}
-        patch_per_image: 1
-        scales_dict: ${kv_to_dict:${data._aux._scales_dict}}
-
-_aux:
-  _scales_dict:
-    - - ${target_col}
-      - [1]
-    - - ${source_col}
-      - [1]
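The deleted config above selects single channels at load time through the reader's `C` argument (e.g. `C: 4` for the nuclear dye in the raw image, `C: 0` for the nuclear segmentation in the ground truth). A sketch of that load step, with the `MonaiBioReader` arguments taken from the YAML and the file path purely illustrative:

```python
# Sketch of the channel-selecting load used in these configs. Reader
# arguments mirror the config above; the path is a placeholder.
from monai.transforms import LoadImaged

from cyto_dl.image.io import MonaiBioReader

load_raw = LoadImaged(
    keys=["raw"],
    reader=MonaiBioReader(dimension_order_out="ZYX", C=4),
)
sample = load_raw({"raw": "path/to/image.ome.tiff"})  # sample["raw"] is a ZYX array
```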
diff --git a/configs/experiment/im2im/omnipose.yaml b/configs/experiment/im2im/instance_seg.yaml
similarity index 89%
rename from configs/experiment/im2im/omnipose.yaml
rename to configs/experiment/im2im/instance_seg.yaml
index 1929fcb12..00c9be320 100644
--- a/configs/experiment/im2im/omnipose.yaml
+++ b/configs/experiment/im2im/instance_seg.yaml
@@ -4,8 +4,8 @@
 # python train.py experiment=example
 
 defaults:
-  - override /data: im2im/omnipose.yaml
-  - override /model: im2im/omnipose.yaml
+  - override /data: im2im/instance_seg.yaml
+  - override /model: im2im/instance_seg.yaml
   - override /callbacks: default.yaml
   - override /trainer: gpu.yaml
   - override /logger: mlflow.yaml
@@ -30,6 +30,5 @@ data:
   path: ${paths.data_dir}/example_experiment_data/segmentation
   cache_dir: ${paths.data_dir}/example_experiment_data/cache
   batch_size: 1
-
   _aux:
     patch_shape: [16, 32, 32]
diff --git a/configs/experiment/im2im/skoots.yaml b/configs/experiment/im2im/skoots.yaml
deleted file mode 100644
index 263da1ee3..000000000
--- a/configs/experiment/im2im/skoots.yaml
+++ /dev/null
@@ -1,36 +0,0 @@
-# @package _global_
-
-# to execute this experiment run:
-# python train.py experiment=example
-
-defaults:
-  - override /data: im2im/skoots.yaml
-  - override /model: im2im/skoots.yaml
-  - override /callbacks: default.yaml
-  - override /trainer: gpu.yaml
-  - override /logger: mlflow.yaml
-
-# all parameters below will be merged with parameters from default configurations set above
-# this allows you to overwrite only specified parameters
-
-tags: ["dev"]
-seed: 12345
-
-experiment_name: YOUR_EXPERIMENT_NAME_HERE
-run_name: YOUR_RUN_NAME_HERE
-source_col: raw
-target_col: seg
-spatial_dims: 3
-raw_im_channels: 1
-
-trainer:
-  max_epochs: 100
-  precision: 16
-
-data:
-  path: ${paths.data_dir}/example_experiment_data/segmentation
-  cache_dir: ${paths.data_dir}/example_experiment_data/cache
-  batch_size: 1
-
-  _aux:
-    patch_shape: [16, 32, 32]
diff --git a/configs/model/im2im/skoots.yaml b/configs/model/im2im/instance_seg.yaml
similarity index 94%
rename from configs/model/im2im/skoots.yaml
rename to configs/model/im2im/instance_seg.yaml
index d5a686bbe..76462b6b9 100644
--- a/configs/model/im2im/skoots.yaml
+++ b/configs/model/im2im/instance_seg.yaml
@@ -39,7 +39,7 @@ _aux:
     - - ${target_col}
       - _target_: cyto_dl.nn.BaseHead
         loss:
-          _target_: cyto_dl.models.im2im.utils.SkootsLoss
+          _target_: cyto_dl.models.im2im.utils.InstanceSegLoss
           dim: ${spatial_dims}
         save_raw: True
         postprocess:
diff --git a/configs/model/im2im/omnipose.yaml b/configs/model/im2im/omnipose.yaml
deleted file mode 100644
index 57459397f..000000000
--- a/configs/model/im2im/omnipose.yaml
+++ /dev/null
@@ -1,51 +0,0 @@
-_target_: cyto_dl.models.im2im.MultiTaskIm2Im
-
-save_images_every_n_epochs: 1
-x_key: ${source_col}
-save_dir: ${paths.output_dir}
-
-backbone:
-  _target_: monai.networks.nets.DynUNet
-  spatial_dims: ${spatial_dims}
-  in_channels: ${raw_im_channels}
-  out_channels: 5
-  strides: [1, 2, 2, 2]
-  kernel_size: [3, 3, 3, 3]
-  upsample_kernel_size: [2, 2, 2]
-  dropout: 0.0
-  res_block: True
-
-task_heads: ${kv_to_dict:${model._aux._tasks}}
-
-optimizer:
-  generator:
-    _partial_: True
-    _target_: torch.optim.AdamW
-    lr: 0.0001
-    weight_decay: 0.001
-
-lr_scheduler:
-  generator:
-    _partial_: True
-    _target_: torch.optim.lr_scheduler.ExponentialLR
-    gamma: 0.998
-
-inference_args:
-  sw_batch_size: 1
-  roi_size: ${data._aux.patch_shape}
-
-_aux:
-  _tasks:
-    - - ${target_col}
-      - _target_: cyto_dl.nn.BaseHead
-        loss:
-          _target_: cyto_dl.models.im2im.utils.OmniposeLoss
-          dim: ${spatial_dims}
-        save_raw: True
-        postprocess:
-          input:
-            _target_: cyto_dl.models.im2im.utils.postprocessing.ActThreshLabel
-            dtype: numpy.float32
-          prediction:
-            _target_: cyto_dl.models.im2im.utils.postprocessing.ActThreshLabel
-            dtype: numpy.float32
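These model configs build `task_heads` through the `${kv_to_dict:...}` resolver, which turns the `_tasks` list of `[key, value]` pairs into a mapping. A hypothetical re-implementation to show the shape of the conversion (cyto-dl's actual resolver may differ in details):

```python
# Hypothetical kv_to_dict resolver: converts a list of [key, value] pairs,
# like model._aux._tasks above, into a dict keyed by task name.
from omegaconf import OmegaConf


def kv_to_dict(kv_pairs):
    return {key: value for key, value in kv_pairs}


OmegaConf.register_new_resolver("kv_to_dict", kv_to_dict)

cfg = OmegaConf.create(
    {
        "_tasks": [["seg", {"_target_": "cyto_dl.nn.BaseHead"}]],
        "task_heads": "${kv_to_dict:${_tasks}}",
    }
)
print(OmegaConf.to_container(cfg, resolve=True)["task_heads"])
# expected: {'seg': {'_target_': 'cyto_dl.nn.BaseHead'}}
```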
diff --git a/configs/model/test/omnipose.yaml b/configs/model/test/omnipose.yaml
deleted file mode 100644
index f0a1ec31f..000000000
--- a/configs/model/test/omnipose.yaml
+++ /dev/null
@@ -1,21 +0,0 @@
-defaults:
-  - test/base.yaml@_here_
-
-task_heads: ${kv_to_dict:${model._aux._tasks}}
-inference_args:
-  sw_batch_size: 1
-  roi_size: ${data._aux.patch_shape}
-_aux:
-  _tasks:
-    - - ${target_col}
-      - _target_: cyto_dl.nn.BaseHead
-        loss:
-          _target_: cyto_dl.models.im2im.utils.omnipose.OmniposeLoss
-          dim: 3
-        postprocess:
-          input:
-            _target_: cyto_dl.models.im2im.utils.postprocessing.detach
-            _partial_: True
-          prediction:
-            _target_: cyto_dl.models.im2im.utils.postprocessing.detach
-            _partial_: True
diff --git a/configs/trainer/default.yaml b/configs/trainer/default.yaml
index 41bf720f9..a5002f065 100644
--- a/configs/trainer/default.yaml
+++ b/configs/trainer/default.yaml
@@ -7,9 +7,8 @@ max_epochs: 10
 
 accelerator: cpu
 devices: 1
-
 # mixed precision for extra speed-up
-# precision: 16
+precision: 16
 
 # perform a validation loop every N training epochs
 check_val_every_n_epoch: 1
diff --git a/configs/trainer/gpu.yaml b/configs/trainer/gpu.yaml
index 4cf3b353b..d6f5786b3 100644
--- a/configs/trainer/gpu.yaml
+++ b/configs/trainer/gpu.yaml
@@ -1,5 +1,4 @@
 defaults:
   - default.yaml
-precision: 16
 accelerator: gpu
 devices: 1
diff --git a/cyto_dl/models/im2im/utils/__init__.py b/cyto_dl/models/im2im/utils/__init__.py
index c4c31478d..e59b5090c 100644
--- a/cyto_dl/models/im2im/utils/__init__.py
+++ b/cyto_dl/models/im2im/utils/__init__.py
@@ -1,15 +1,12 @@
 try:
-    from .omnipose import OmniposeClustering, OmniposeLoss, OmniposePreprocessd
+    from .instance_seg import (
+        InstanceSegCluster,
+        InstanceSegLoss,
+        InstanceSegPreprocessd,
+    )
 except (ModuleNotFoundError, ImportError):
-    OmniposeClustering = None
-    OmniposeLoss = None
-    OmniposePreprocessd = None
-
-try:
-    from .skoots import SkootsCluster, SkootsLoss, SkootsPreprocessd
-except (ModuleNotFoundError, ImportError):
-    SkootsCluster = None
-    SkootsLoss = None
-    SkootsPreprocessd = None
+    InstanceSegCluster = None
+    InstanceSegLoss = None
+    InstanceSegPreprocessd = None
 
 from .postprocessing import ActThreshLabel, DictToIm, detach
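The consolidated `try/except` above is the optional-dependency pattern used throughout cyto-dl: when an extra isn't installed, the names still import but are bound to `None`. A sketch of how calling code can guard on that (the error message is illustrative):

```python
# Consumer-side guard for the optional-import pattern above: the package
# binds InstanceSegLoss to None when its extra dependencies are missing.
from cyto_dl.models.im2im.utils import InstanceSegLoss


def make_instance_seg_loss(dim: int = 3):
    if InstanceSegLoss is None:  # extra deps not installed
        raise ImportError(
            "InstanceSegLoss is unavailable; install the instance segmentation dependencies."
        )
    return InstanceSegLoss(dim=dim)
```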
diff --git a/cyto_dl/models/im2im/utils/skoots.py b/cyto_dl/models/im2im/utils/instance_seg.py
similarity index 93%
rename from cyto_dl/models/im2im/utils/skoots.py
rename to cyto_dl/models/im2im/utils/instance_seg.py
index cb673066e..1b451c851 100644
--- a/cyto_dl/models/im2im/utils/skoots.py
+++ b/cyto_dl/models/im2im/utils/instance_seg.py
@@ -27,7 +27,7 @@
 from cyto_dl.nn.losses.loss_wrapper import CMAP_loss
 
 
-class SkootsPreprocessd(Transform):
+class InstanceSegPreprocessd(Transform):
     def __init__(
         self,
         label_keys: Union[Sequence[str], str],
@@ -41,7 +41,7 @@ def __init__(
         Parameters
         ----------
         label_keys: Union[Sequence[str], str]
-            Keys of instance segmentations in input dictionary to convert to Skoots ground truth images.
+            Keys of instance segmentations in input dictionary to convert to InstanceSeg ground truth images.
         kernel_size: int=3
             Size of kernel for gaussian smoothing of flows
         thin: int=5
@@ -157,7 +157,7 @@ def embed_from_skel(self, skel, iseg):
             skel_points = skel_boundary.eq(i).nonzero()
             if skel_points.numel() == 0:
                 continue
-            # distances should take into account z anisotropy 
+            # distances should take into account z anisotropy
             point_embeddings = self._get_point_embeddings(
                 object_points.mul(self.anisotropy), skel_points.mul(self.anisotropy)
             )
@@ -215,9 +215,6 @@ def __call__(self, image_dict):
                 continue
             im = image_dict.pop(key)
             im = im.as_tensor() if isinstance(im, MetaTensor) else im
-            import time
-
-            t0 = time.time()
             im_numpy = im.numpy().astype(int).squeeze()
             skel = self.shrink(im_numpy)
             skel_edt = torch.from_numpy(edt.edt(skel > 0)).unsqueeze(0)
@@ -226,18 +223,15 @@ def __call__(self, image_dict):
             embed = self.embed_from_skel(skel, im.squeeze(0).clone())
             cmap = self._get_cmap(skel_edt.squeeze(), im)
             bound = torch.from_numpy(find_boundaries(im_numpy)).unsqueeze(0)
-            # image_dict[key]= torch.cat([(skel>0).unsqueeze(0),im>0, embed, bound, cmap])
             image_dict[key] = torch.cat([skel_edt, im > 0, embed, bound, cmap]).float()
-
-            print(time.time() - t0)
         return image_dict
 
 
-class SkootsRandFlipd(RandomizableTransform):
-    """Flipping Augmentation for Skoots training.
+class InstanceSegRandFlipd(RandomizableTransform):
+    """Flipping Augmentation for InstanceSeg training.
 
-    When flipping ground truths generated by `SkootsPreprocessD`, the sign of gradients have to be
-    changed after flipping.
+    When flipping ground truths generated by `InstanceSegPreprocessd`, the sign of the gradients
+    has to be changed after flipping.
     """
 
     def __init__(
@@ -255,9 +249,9 @@ def __init__(
         spatial_axis:int
             axis to flip across
         label_keys:Union[str, Sequence[str]]=[]
-            key or list of keys generated by SkootsPreprocessD to flip
+            key or list of keys generated by InstanceSegPreprocessd to flip
         image_keys:Union[str, Sequence[str]]=[]
-            key or list of keys NOT generated by SkootsPreprocessd to flip
+            key or list of keys NOT generated by InstanceSegPreprocessd to flip
         prob:float=0.1
             probability of flipping
         dim:int=3
@@ -284,7 +278,7 @@ def _flip(self, img, is_label):
         if is_label:
             assert (
                 img.shape[0] == 4 + self.dim
-            ), f"Expected generated skoots ground truth to have {4+self.dim} channels, got {img.shape[0]}"
+            ), f"Expected generated InstanceSeg ground truth to have {4+self.dim} channels, got {img.shape[0]}"
             flipped_flows = img[2 : 2 + self.dim]
             flipped_flows[self.spatial_axis] *= -1
             img[2 : 2 + self.dim] = flipped_flows
@@ -303,8 +297,8 @@ def __call__(self, image_dict):
         return image_dict
 
 
-class SkootsLoss:
-    """Loss function for Skoots."""
+class InstanceSegLoss:
+    """Loss function for InstanceSeg."""
 
     def __init__(self, dim: int = 3):
         """
@@ -315,7 +309,6 @@ def __init__(self, dim: int = 3):
         """
         self.dim = dim
 
-        # self.skeleton_loss = CMAP_loss(torch.nn.BCEWithLogitsLoss(reduction='none'))
         self.skeleton_loss = CMAP_loss(torch.nn.MSELoss(reduction="none"))
         self.vector_loss = CMAP_loss(torch.nn.MSELoss(reduction="none"))
         self.boundary_loss = CMAP_loss(torch.nn.BCEWithLogitsLoss(reduction="none"))
@@ -347,9 +340,9 @@ def __call__(self, y_hat, y):
         return vector_loss + skeleton_loss + semantic_loss + boundary_loss
 
 
-class SkootsCluster:
+class InstanceSegCluster:
     """
-    Clustering for SKOOTS - finds skeletons and assigns semantic points to skeleton based on spatial embedding and nearest neighbor distances.
+    Clustering for InstanceSeg - finds skeletons and assigns semantic points to skeleton based on spatial embedding and nearest neighbor distances.
     """
 
     def __init__(
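The sign handling in `_flip` above is the subtle part of this file: mirroring a vector field along an axis also reverses the direction of its component on that axis, so the flipped component must be negated. A self-contained check of that identity using plain `torch`, independent of cyto-dl:

```python
# Check the rule InstanceSegRandFlipd applies: the gradient of a flipped image
# equals the flipped gradient with the flipped-axis component negated.
import torch

img = torch.rand(8, 8, 8)
axis = 0

grad = torch.stack(torch.gradient(img))           # (3, Z, Y, X) vector field
grad_flipped = torch.flip(grad, dims=[axis + 1])  # +1 skips the component dim
grad_flipped[axis] *= -1                          # negate flipped component

expected = torch.stack(torch.gradient(torch.flip(img, dims=[axis])))
assert torch.allclose(grad_flipped, expected)
```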
diff --git a/cyto_dl/models/im2im/utils/omnipose.py b/cyto_dl/models/im2im/utils/omnipose.py
deleted file mode 100644
index 25a3fe62a..000000000
--- a/cyto_dl/models/im2im/utils/omnipose.py
+++ /dev/null
@@ -1,432 +0,0 @@
-# modified from https://github.com/kevinjohncutler/omnipose/blob/main/omnipose/core.py
-import warnings
-from typing import Sequence, Union
-
-import dask
-import edt
-import numpy as np
-import torch
-from cellpose_omni.core import (
-    ArcCosDotLoss,
-    DerivativeLoss,
-    DivergenceLoss,
-    NormLoss,
-    WeightedLoss,
-)
-from monai.data import MetaTensor
-from monai.transforms import Flip, RandomizableTransform, Transform
-from omegaconf import ListConfig
-from omnipose.core import compute_masks, diameters, masks_to_flows
-from scipy.ndimage import find_objects
-from scipy.spatial import ConvexHull
-from skimage.filters import apply_hysteresis_threshold, gaussian
-from skimage.measure import label
-from skimage.morphology import binary_dilation, remove_small_holes
-from skimage.segmentation import expand_labels
-from skimage.transform import rescale, resize
-
-
-class OmniposePreprocessd(Transform):
-    """Wrapper of core functions from [Omnipose](https://github.com/kevinjohncutler/omnipose) to
-    create an 5/6 channel Omnipose ground truth consisting of boundary, weight mask, flow, and
-    smooth distance images from an input instance segmentation."""
-
-    def __init__(
-        self,
-        label_keys: Union[Sequence[str], str],
-        dim: int = 3,
-        allow_missing_keys: bool = False,
-    ):
-        """
-        Parameters
-        ----------
-        label_keys: Union[Sequence[str], str]
-            Keys of instance segmentations in input dictionary to convert to omnipose ground truth images.
-        dim:int=3
-            Spatial dimension of images
-        allow_missing_keys:bool=False
-            Whether to raise error if key in `label_keys` is not present
-        """
-        super().__init__()
-        self.label_keys = (
-            label_keys if isinstance(label_keys, (list, ListConfig)) else [label_keys]
-        )
-        self.dim = dim
-        self.allow_missing_keys = allow_missing_keys
-
-    def __call__(self, image_dict):
-        warnings.warn(
-            "OmniPose preprocessing is slow, consider setting `persist_cache: True` in your experiment config"
-        )
-        for key in self.label_keys:
-            if key not in image_dict:
-                if not self.allow_missing_keys:
-                    raise KeyError(
-                        f"Key {key} not found in data. Available keys are {image_dict.keys()}"
-                    )
-                continue
-
-            im = image_dict[key]
-            im = im.as_tensor() if isinstance(im, MetaTensor) else im
-            numpy_im = im.numpy().squeeze()
-
-            if np.max(numpy_im) <= 0:
-                raise ValueError("Ground truth images for Omnipose must have at least 1 label")
-
-            out_im = np.zeros([4 + self.dim] + list(numpy_im.shape))
-
-            (
-                instance_seg,
-                rough_distance,
-                boundaries,
-                smooth_distance,
-                flows,
-            ) = masks_to_flows(numpy_im, omni=True, dim=self.dim, use_gpu=True, device=im.device)
-            cutoff = diameters(instance_seg, rough_distance) / 2
-            smooth_distance[rough_distance <= 0] = -cutoff
-
-            bg_edt = edt.edt(numpy_im < 0.5, black_border=True)
-            boundary_weighted_mask = gaussian(1 - np.clip(bg_edt, 0, cutoff) / cutoff, 1) + 0.5
-            out_im[0] = boundaries
-            out_im[1] = boundary_weighted_mask
-            out_im[2] = instance_seg
-            out_im[3 : 3 + self.dim] = flows * 5.0  # weighted for loss function?
-            out_im[3 + self.dim] = smooth_distance
-            image_dict[key] = out_im
-        return image_dict
-
-
-class OmniposeRandFlipd(RandomizableTransform):
-    """Flipping Augmentation for Omnipose training.
-
-    When flipping ground truths generated by `OmniposePreprocessD, the sign of gradients have to be
-    changed after flipping.
-    """
-
-    def __init__(
-        self,
-        spatial_axis: int,
-        label_keys: Union[str, Sequence[str]] = [],
-        image_keys: Union[str, Sequence[str]] = [],
-        prob: float = 0.1,
-        dim: int = 3,
-        allow_missing_keys: bool = False,
-    ):
-        """
-        Parameters
-        --------------
-        spatial_axis:int
-            axis to flip across
-        label_keys:Union[str, Sequence[str]]=[]
-            key or list of keys generated by OmniposePreprocessD to flip
-        image_keys:Union[str, Sequence[str]]=[]
-            key or list of keys NOT generated by OmniposePreprocessd to flip
-        prob:float=0.1
-            probability of flipping
-        dim:int=3
-            spatial dimensions of images
-        allow_missing_keys:bool=False
-            Whether to raise error if a provided key is missing
-        """
-        super().__init__()
-        self.image_keys = (
-            image_keys if isinstance(image_keys, (list, ListConfig)) else [image_keys]
-        )
-        self.label_keys = (
-            label_keys if isinstance(label_keys, (list, ListConfig)) else [label_keys]
-        )
-        self.dim = dim
-        self.allow_missing_keys = allow_missing_keys
-        self.flipper = Flip(spatial_axis)
-        self.prob = prob
-        self.spatial_axis = spatial_axis
-
-    def _flip(self, img, is_label):
-        img = self.flipper(img)
-
-        if is_label:
-            assert (
-                img.shape[0] == 4 + self.dim
-            ), f"Expected generated omnipose ground truth to have {4+self.dim} channels, got {img.shape[0]}"
-            flipped_flows = img[3 : 3 + self.dim]
-            flipped_flows[self.spatial_axis] *= -1
-            img[3 : 3 + self.dim] = flipped_flows
-        return img
-
-    def __call__(self, image_dict):
-        do_flip = self.R.rand() < self.prob
-        if do_flip:
-            for key in self.label_keys + self.image_keys:
-                if key not in image_dict:
-                    if not self.allow_missing_keys:
-                        raise KeyError(
-                            f"Key {key} not found in data. Available keys are {image_dict.keys()}"
-                        )
-                    continue
-                image_dict[key] = self._flip(image_dict[key], key in self.label_keys)
-        return image_dict
-
-
-class OmniposeLoss:
-    """Loss function for Omnipose."""
-
-    def __init__(self, dim: int = 3):
-        """
-        Parameters
-        --------------
-        dim:int=3
-            Spatial dimension of input images.
- """ - - self.dim = dim - self.weighted_flow_MSE = WeightedLoss() - self.angular_flow_loss = ArcCosDotLoss() - self.DerivativeLoss = DerivativeLoss() - self.boundary_seg_loss = torch.nn.BCEWithLogitsLoss(reduction="mean") - self.NormLoss = NormLoss() - self.distance_field_mse = WeightedLoss() - self.criterion11 = DerivativeLoss() - self.criterion16 = DivergenceLoss() - - def __call__(self, y_hat, y): - """ - Parameters - -------------- - y: ND-array, float - transformed labels in array [nimg x nchan x xy[0] x xy[1]] - y[:,0] boundary field - y[:,1] boundary-emphasized weights - y[:,2] cell masks - y[:,3:3+self.dim] flow components - y[:,3+self.dim] smooth distance field - - y_hat: ND-tensor, float - network predictions, with dimension D, these are: - y_hat[:,:D] flow field components at 0,1,...,D-1 - y_hat[:,D] distance fields at D - y_hat[:,D+1] boundary fields at D+1 - - """ - boundary = y[:, 0] - w = y[:, 1] - cellmask = (y[:, 2] > 0).bool() # acts as a mask now, not output - - # calculat loss on entire patch if no cells present - this helps - # remove background artifacts - for img_id in range(cellmask.shape[0]): - if torch.sum(cellmask[img_id]) == 0: - cellmask[img_id] = True - veci = y[:, -(self.dim + 1) : -1] - dist = y[:, -1] # now distance transform replaces probability - - # prediction - flow = y_hat[:, : self.dim] # 0,1,...self.dim-1 - dt = y_hat[:, self.dim] - bd = y_hat[:, self.dim + 1] - a = 10.0 - - # stacked versions for weighting vector fields with scalars - wt = torch.stack([w] * self.dim, dim=1) - ct = torch.stack([cellmask] * self.dim, dim=1) - - # luckily, torch.gradient did exist after all and derivative loss was easy to implement. Could also fix divergenceloss, but I have not been using it. - # the rest seem good to go. - - loss1 = 10.0 * self.weighted_flow_MSE(flow, veci, wt) # weighted MSE - loss2 = self.angular_flow_loss(flow, veci, w, cellmask) # ArcCosDotLoss - loss3 = self.DerivativeLoss(flow, veci, wt, ct) / a # DerivativeLoss - loss4 = 2.0 * self.boundary_seg_loss(bd, boundary) # BCElogits - loss5 = 2.0 * self.NormLoss(flow, veci, w, cellmask) # loss on norm - loss6 = 2.0 * self.distance_field_mse(dt, dist, w) # weighted MSE - loss7 = ( - self.criterion11( - dt.unsqueeze(1), - dist.unsqueeze(1), - w.unsqueeze(1), - cellmask.unsqueeze(1), - ) - / a - ) - loss8 = self.criterion16(flow, veci, cellmask) # divergence loss - loss = loss1 + loss2 + loss3 + loss4 + loss5 + loss6 + loss7 + loss8 - return loss - - -# setting a high flow threshold avoids erroneous removal of masks that are fine. -# debugging whether this is a training issue... -class OmniposeClustering: - """Run clustering on downsampled version of flows, then use original resolution distance field - to mask instance segmentations.""" - - def __init__( - self, - mask_threshold=0, - rescale_factor=1.0, - min_object_size=100, - hole_size=0, - flow_threshold=1e8, - spatial_dim=3, - boundary_seg=True, - naive_label=False, - fine_threshold=True, - convex_ratio_threshold=1.2, - ): - """ - Parameters - -------------- - mask_threshold: float - Threshold to binarize distance transform - rescale_factor: float - Rescaling before omnipose clustering - min_object_size: int - Minimum object size to include in final segmentation - hole_size: int - Maximum hole size to include in final segmentation - flow_threshold: float - Remove masks with anomalous flows below threshold. 
-            NOTE, currently must be set high
-            to avoid removal of most masks
-        spatial_dim: int
-            Spatial dimensions of input data
-        boundary_seg: bool
-            whether to use omnipose's boundary_seg clustering. If False, uses standard euler integration
-        naive_label: bool
-            Whether to attempt to label objects and only cluster objects that look like merged segmentations
-            based on `convex_ratio_threshold`. Much faster, but can lead to worse segmentations.
-        fine_threshold: bool
-            Whether to use hysteresis threshold for finer detail in thin segmentation structures
-        convex_ratio_threshold: float
-            Anomaly threshold for running omnipose clustering if `naive_label==True`.
-        """
-
-        assert (
-            0 < rescale_factor <= 1.0
-        ), f"Rescale factor must be in range [0,1], got {rescale_factor}"
-        self.mask_threshold = mask_threshold
-        self.rescale_factor = rescale_factor
-        self.min_object_size = min_object_size
-        self.hole_size = hole_size
-        self.flow_threshold = flow_threshold
-        self.spatial_dim = spatial_dim
-        self.boundary_seg = boundary_seg
-        self.naive_label = naive_label
-        self.clustering_function = self.do_naive_labeling if naive_label else self.get_mask
-        self.fine_threshold = fine_threshold
-        self.convex_ratio_threshold = convex_ratio_threshold
-
-    def rescale_instance(self, im, seg):
-        seg = resize(seg, im.shape, order=0, anti_aliasing=False, preserve_range=True)
-        seg = expand_labels(seg, distance=3)
-        seg[im == 0] = 0
-        return seg
-
-    def get_mask(self, flow, dist, device):
-        flow = rescale(
-            flow,
-            [1] + [self.rescale_factor] * self.spatial_dim,
-            order=3,
-            preserve_range=True,
-            anti_aliasing=False,
-        )
-        rescale_dist = rescale(
-            dist, self.rescale_factor, order=3, preserve_range=True, anti_aliasing=False
-        )
-        mask, p, tr, bounds, affinity_graph = compute_masks(
-            flow,
-            rescale_dist,
-            nclasses=4,
-            dim=self.spatial_dim,
-            use_gpu=True,
-            device=device,
-            min_size=self.min_object_size,
-            flow_threshold=self.flow_threshold,
-            boundary_seg=self.boundary_seg,
-            mask_threshold=self.mask_threshold,
-            do_3D=True,
-        )
-        mask = self.rescale_instance(dist > self.mask_threshold, mask)
-        return mask
-
-    def pad_slice(self, s, padding, constraints):
-        new_slice = [slice(None, None, None)]
-        for slice_part, c in zip(s, constraints):
-            start = max(0, slice_part.start - padding)
-            stop = min(c, slice_part.stop + padding)
-            new_slice.append(slice(start, stop, None))
-        return new_slice
-
-    def is_merged_segmentation(self, mask_crop, area):
-        mask_points = np.asarray(list(zip(*np.where(mask_crop))))
-
-        # look for self.convex_ratio_threshold
-
-    @dask.delayed
-    def get_separated_masks(self, flow_crop, mask_crop, dist_crop, device, crop):
-        area = np.sum(mask_crop)
-        if area < self.min_object_size:
-            return
-        if self.is_merged_segmentation(mask_crop, area):
-            flow_crop[:, ~mask_crop] = 0
-            dist_crop[~mask_crop] = dist_crop.min()
-            mask = self.get_mask(flow_crop, dist_crop, device)
-            return {"slice": tuple(crop[1:]), "mask": mask}
-
-        return {
-            "slice": tuple(crop[1:]),
-            "mask": remove_small_holes(
-                mask_crop > 0,
-                area_threshold=self.hole_size,
-                connectivity=self.spatial_dim,
-            ),
-        }
-
-    def do_naive_labeling(self, flow, dist, device):
-        """label thresholded distance transform to get objects, then run clustering only on objects
-        that seem to be merged segmentations.
-
-        Useful for well-separated, round objects like nuclei
-        """
-        if self.fine_threshold:
-            cellmask = apply_hysteresis_threshold(
-                dist, low=self.mask_threshold - 1, high=self.mask_threshold
-            )
-        else:
-            cellmask = dist > self.mask_threshold
-        naive_labeling = label(cellmask)
-        out_image = np.zeros_like(naive_labeling, dtype=np.uint16)
-        regions = find_objects(naive_labeling)
-        results = []
-        for val, region in enumerate(regions, start=1):
-            padded_crop = self.pad_slice(region, 5, naive_labeling.shape)
-            results.append(
-                self.get_separated_masks(
-                    flow[tuple(padded_crop)].copy(),
-                    naive_labeling[tuple(padded_crop[1:])] == val,
-                    dist[tuple(padded_crop[1:])].copy(),
-                    device,
-                    padded_crop,
-                )
-            )
-        results = dask.compute(*results)
-        highest_cell_idx = 0
-        for r in results:
-            if r is None:
-                continue
-            mask = r["mask"].astype(np.uint16)
-            mask[mask > 0] += highest_cell_idx
-            out_image[r["slice"]] += mask
-            highest_cell_idx += np.max(r["mask"])
-        return out_image
-
-    def __call__(self, im):
-        device = im.device
-        im = im.detach().cpu().numpy()
-        flow = im[: self.spatial_dim]
-        dist = im[self.spatial_dim]
-        mask = self.clustering_function(flow, dist, device)
-        return mask
diff --git a/cyto_dl/nn/losses/__init__.py b/cyto_dl/nn/losses/__init__.py
index 9857549c7..4d974f9be 100644
--- a/cyto_dl/nn/losses/__init__.py
+++ b/cyto_dl/nn/losses/__init__.py
@@ -4,7 +4,6 @@
 from .cosine_loss import CosineLoss
 from .gan_loss import GANLoss, Pix2PixHD
 from .gaussian_nll_loss import GaussianNLLLoss
-from .geomloss import GeomLoss
 from .threshold_loss import ThresholdLoss
 from .weibull import WeibullLogLoss
 from .weighted_mse_loss import WeightedMSELoss
@@ -13,3 +12,8 @@
     from .spharm_loss import SpharmLoss
 except (ModuleNotFoundError, ImportError):
     SpharmLoss = None
+
+try:
+    from .geomloss import GeomLoss
+except (ModuleNotFoundError, ImportError):
+    GeomLoss = None
diff --git a/docs/conf.py b/docs/conf.py
old mode 100755
new mode 100644
diff --git a/docs/using_examples.rst b/docs/using_examples.rst
index 160675cd7..b494c0d76 100644
--- a/docs/using_examples.rst
+++ b/docs/using_examples.rst
@@ -23,7 +23,7 @@ Our data configs all follow the same structure - image loading, image normalizat
 c. Image augmentation
 
    Targeted image augmentation can increase model robustness. Again, monai provides excellent options for `intensity `_ and `spatial `_ augmentations. For spatial augmentations, ensure that your input and ground truth images are both passed to the transformation, while for intensity augmentations ensure that only the input image is changed.
-   **Note** For Omnipose, use Omnipose-specific spatial transforms. Naive implementations of flipping/rotation/ other spatial transforms will make augmented vector fields incorrect.
+   **Note** For instance segmentation, use instance segmentation-specific spatial transforms. Naive implementations of flipping/rotation/other spatial transforms will make augmented vector fields incorrect.
 
 2. Changes to the `model` config
@@ -32,7 +32,7 @@ The model config specifies neural network architecture and optimization paramete
    `monai `_ provides many cutting edge networks. Crucial parameters to change are the `spatial_dims` if you are changing from a 3D to 2D task, `in_channels` if you want to provide multi-channel images to the network, and `out_channels`. For multi-task learning, it is important to increase the number of `out_channels` so that the task heads are not bottlenecked by the number of `out_channels` in the backbone.
 
 b. Modifying the `task_heads`
-   `task_heads` can be modified by changing their loss function (suggested if you are changing e.g. from labelfree to segmentation), postprocessing (if you are changing from segmentation to omnipose), and `task-head` type (if you are changing from a segmentation network to a GAN).
+   `task_heads` can be modified by changing their loss function (suggested if you are changing e.g. from labelfree to segmentation), postprocessing (if you are changing from segmentation to instance segmentation), and `task-head` type (if you are changing from a segmentation network to a GAN).
 
    `torch `_ and `monai `_ provide many loss functions. We provide basic [postprocessing](cyto_dl/models/utils/postprocessing) functions. Additional `task_heads` can be added for multi-task learning. The name of each `task_head` should line up with the name of an image in your training batch. For example, if our batch looks like `{'raw':torch.Tensor, 'segmentation':torch.Tensor, 'distance':torch.Tensor}` and `raw` is our input image, we should provide `task_heads` for `segmentation` and `distance` that predict a segmentation and distance map respectively.
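The naming contract described in the last paragraph above can be stated as a one-line invariant; a schematic check (the names are illustrative, not cyto-dl API):

```python
# Schematic: every non-input batch key must have a matching task head.
import torch

x_key = "raw"
batch = {
    "raw": torch.zeros(1, 1, 16, 32, 32),           # network input
    "segmentation": torch.zeros(1, 1, 16, 32, 32),  # target for one head
    "distance": torch.zeros(1, 1, 16, 32, 32),      # target for another head
}
task_heads = {"segmentation": ..., "distance": ...}  # head stand-ins

assert set(task_heads) == set(batch) - {x_key}
```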
diff --git a/pyproject.toml b/pyproject.toml
index 817c0f292..b3e650c32 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -40,9 +40,11 @@ dependencies = [
   "tifffile~=2023.2",
   "tqdm~=4.64",
   "protobuf<=3.20.1",
-  "lightning>=2.0.1.post0",
+  "lightning~=2.0",
   "ostat>=0.2",
   "einops>=0.6.1",
+  "edt>=2.3.1",
+  "astropy~=5.2",
 ]
 
 requires-python = ">=3.8,<3.11"
@@ -58,11 +60,6 @@ spharm = [
   "aicscytoparam~=0.1",
   "pyshtools~=4.9",
 ]
-omnipose = [
-  "edt>=0.56.4",
-  "torchvf~=0.1.2",
-  "omnipose @ git+https://github.com/kevinjohncutler/omnipose.git@ce2e9aa",
-]
 s3 = [
   "boto3>=1.23.5,<1.24.5",
   "s3fs~=2023.1"
 ]
@@ -82,7 +79,7 @@ pcloud = [
   "torchio>=0.19.1",
 ]
 all = [
-  "cyto-dl[equiv,spharm,omnipose,s3,torchserve,pcloud]",
+  "cyto-dl[equiv,spharm,s3,torchserve,pcloud]",
 ]
 test = [
   "cyto-dl[all]",
diff --git a/tests/conftest.py b/tests/conftest.py
index e9daebb52..8f92c5f83 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -13,7 +13,7 @@
 OmegaConf.register_new_resolver("eval", eval)
 
 # Experiment configs to test
-experiment_types = ["gan", "segmentation", "omnipose", "labelfree", "skoots"]
+experiment_types = ["instance_seg", "gan", "segmentation", "labelfree"]
 
 
 @pytest.fixture(scope="package", params=experiment_types)
diff --git a/tests/test_gradient_flip.py b/tests/test_gradient_flip.py
deleted file mode 100644
index fffafcc78..000000000
--- a/tests/test_gradient_flip.py
+++ /dev/null
@@ -1,27 +0,0 @@
-import pytest
-import torch
-from monai.transforms import Flip
-
-from cyto_dl.models.im2im.utils.omnipose import OmniposeRandFlipd
-
-
-@pytest.mark.parametrize("spatial_axis", [0, 1, 2])
-def test_gradient_flip(spatial_axis):
-    # flip image and compute gradient
-    img = torch.rand((30, 30, 30))
-    flipper = Flip(spatial_axis=spatial_axis)
-    # transforms expects CZYX tensor
-    img_flip = flipper(img.unsqueeze(0))
-    img_flip_grad = torch.stack(torch.gradient(img_flip.squeeze(0)))
-
-    # compute gradient and flip image
-    grad = torch.stack(torch.gradient(img))
-    # put grad image into fake omnipose-generated gt
-    omnipose_im = torch.ones((7, 30, 30, 30))
-    omnipose_im[3:6] = grad
-    grad_flipper = OmniposeRandFlipd(label_keys=["im"], spatial_axis=spatial_axis, prob=1.0)
-    flip_im = grad_flipper({"im": omnipose_im})["im"]
-    # extract gradient
-    flip_grad = flip_im[3:6]
-
-    assert torch.equal(img_flip_grad, flip_grad)
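One gap worth noting: deleting `tests/test_gradient_flip.py` removes coverage of the flip-sign behavior, and this diff adds no equivalent for `InstanceSegRandFlipd`. A sketch of a replacement, mirroring the deleted test but using the instance-seg channel layout (flows at channels `2 : 2 + dim`, `4 + dim` channels total, per the assert in `_flip` above); untested, so treat it as a starting point rather than a drop-in:

```python
# Possible replacement for the deleted test, targeting InstanceSegRandFlipd.
# Channel layout follows the _flip assert: 7 channels for dim=3, flows at 2:5.
import pytest
import torch
from monai.transforms import Flip

from cyto_dl.models.im2im.utils.instance_seg import InstanceSegRandFlipd


@pytest.mark.parametrize("spatial_axis", [0, 1, 2])
def test_gradient_flip(spatial_axis):
    # flip image and compute gradient (transforms expect a channel-first tensor)
    img = torch.rand((30, 30, 30))
    img_flip = Flip(spatial_axis=spatial_axis)(img.unsqueeze(0))
    img_flip_grad = torch.stack(torch.gradient(img_flip.squeeze(0)))

    # compute gradient, then embed it in a fake 7-channel ground truth
    gt = torch.ones((7, 30, 30, 30))
    gt[2:5] = torch.stack(torch.gradient(img))

    flipper = InstanceSegRandFlipd(label_keys=["im"], spatial_axis=spatial_axis, prob=1.0)
    flip_grad = flipper({"im": gt})["im"][2:5]

    assert torch.equal(img_flip_grad, flip_grad)
```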