Skip to content

Commit

Permalink
Bugfix/plugin merging (#351)
Browse files Browse the repository at this point in the history
* remove target from prediction

* use base image as target if no merging occurs

* convert tensor to uint8 nparray)

* update configs

* precommit

* switch default to zeros

* update merge with inverted mask

* fix mask head

* always cast to float;

* float before numpy

* fix testing transforms

* decrease patch overlap

* lint

* update configs

* change default

* lint

* set default cache path

* add test, workers,cache changes

---------

Co-authored-by: Benjamin Morris <[email protected]>
Co-authored-by: Benjamin Morris <[email protected]>
  • Loading branch information
3 people authored May 2, 2024
1 parent d3107f9 commit 251f691
Show file tree
Hide file tree
Showing 8 changed files with 90 additions and 80 deletions.
105 changes: 46 additions & 59 deletions configs/data/im2im/segmentation_plugin.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ _target_: cyto_dl.datamodules.dataframe.DataframeDatamodule
path:
cache_dir:

num_workers: 0
num_workers: 4
batch_size: 1
pin_memory: True
split_column:
Expand All @@ -19,6 +19,8 @@ transforms:
train:
_target_: monai.transforms.Compose
transforms:
# remove nan keys
- _target_: cyto_dl.datamodules.dataframe.utils.RemoveNaNKeysd
# load
- _target_: monai.transforms.LoadImaged
keys: ${source_col}
Expand All @@ -34,10 +36,19 @@ transforms:
C: 0
- _target_: monai.transforms.LoadImaged
keys: ${target_col2}
allow_missing_keys: True
reader:
- _target_: cyto_dl.image.io.MonaiBioReader
dimension_order_out: ${eval:'"CZYX" if ${spatial_dims}==3 else "CYX"'}
C: 0
- _target_: monai.transforms.ThresholdIntensityd
allow_missing_keys: True
keys:
- ${target_col1}
- ${target_col2}
threshold: 0.1
above: False
cval: 1
# load merging mask - assumed not to exist by default
- _target_: cyto_dl.image.io.PolygonLoaderd
keys:
Expand All @@ -62,68 +73,29 @@ transforms:
- ${target_col2}
base_image_key: ${base_image_col}
output_name: target
#crop
- _target_: cyto_dl.image.transforms.RandomMultiScaleCropd
keys:
- ${source_col}
- target
- ${exclude_mask_col}
patch_shape: ${data._aux.patch_shape}
patch_per_image: 1
scales_dict: ${kv_to_dict:${data._aux._scales_dict}}

# augmentation
- _target_: monai.transforms.RandRotate90d
keys:
- ${source_col}
- target
- ${exclude_mask_col}
prob: 0.5
spatial_axes: [0, 1]
- _target_: monai.transforms.RandFlipd
keys:
- ${source_col}
- target
- ${exclude_mask_col}
prob: 0.5
spatial_axis: 0
- _target_: monai.transforms.RandFlipd
keys:
- ${source_col}
- target
- ${exclude_mask_col}
prob: 0.5
spatial_axis: 1
- _target_: monai.transforms.RandHistogramShiftd
prob: 0.1
keys:
- ${source_col}
- target
- ${exclude_mask_col}
num_control_points: [90, 500]
- _target_: monai.transforms.RandStdShiftIntensityd
prob: 0.1
keys:
- ${source_col}
- target
- ${exclude_mask_col}
factors: 0.1
- _target_: monai.transforms.RandAdjustContrastd
prob: 0.1
- _target_: monai.transforms.ToTensord
keys:
- ${source_col}
- target
- ${exclude_mask_col}
gamma: [0.9, 1.5]
- _target_: monai.transforms.ToTensord
dtype: float16

#crop
- _target_: cyto_dl.image.transforms.RandomMultiScaleCropd
keys:
- ${source_col}
- target
- ${exclude_mask_col}
patch_shape: ${data._aux.patch_shape}
patch_per_image: ${data._aux.patch_per_image}
scales_dict: ${kv_to_dict:${data._aux._scales_dict}}

test:
_target_: monai.transforms.Compose
transforms:
# remove nan keys
- _target_: cyto_dl.datamodules.dataframe.utils.RemoveNaNKeysd
# load
- _target_: monai.transforms.LoadImaged
keys: ${source_col}
Expand All @@ -139,6 +111,7 @@ transforms:
C: 0
- _target_: monai.transforms.LoadImaged
keys: ${target_col2}
allow_missing_keys: True
reader:
- _target_: cyto_dl.image.io.MonaiBioReader
dimension_order_out: ${eval:'"CZYX" if ${spatial_dims}==3 else "CYX"'}
Expand All @@ -158,7 +131,6 @@ transforms:
# normalize
- _target_: monai.transforms.NormalizeIntensityd
keys: ${source_col}

channel_wise: True
# merge masks
- _target_: cyto_dl.image.transforms.merge.Merged
Expand All @@ -168,11 +140,13 @@ transforms:
- ${target_col2}
base_image_key: ${base_image_col}
output_name: target

- _target_: monai.transforms.ToTensord
keys:
- ${source_col}
- target
- ${exclude_mask_col}
dtype: float16

predict:
_target_: monai.transforms.Compose
Expand All @@ -191,12 +165,12 @@ transforms:
- _target_: monai.transforms.ToTensord
keys:
- ${source_col}
- target
- ${exclude_mask_col}

valid:
_target_: monai.transforms.Compose
transforms:
# remove nan keys
- _target_: cyto_dl.datamodules.dataframe.utils.RemoveNaNKeysd
# load
- _target_: monai.transforms.LoadImaged
keys: ${source_col}
Expand All @@ -212,11 +186,20 @@ transforms:
C: 0
- _target_: monai.transforms.LoadImaged
keys: ${target_col2}
allow_missing_keys: True
reader:
- _target_: cyto_dl.image.io.MonaiBioReader
dimension_order_out: ${eval:'"CZYX" if ${spatial_dims}==3 else "CYX"'}
C: 0

- _target_: monai.transforms.ThresholdIntensityd
allow_missing_keys: True
keys:
- ${target_col1}
- ${target_col2}
threshold: 0.1
above: False
cval: 1
# load merging mask - assumed not to exist by default
- _target_: cyto_dl.image.io.PolygonLoaderd
keys:
Expand Down Expand Up @@ -244,21 +227,25 @@ transforms:
base_image_key: ${base_image_col}
output_name: target

# #crop
- _target_: cyto_dl.image.transforms.RandomMultiScaleCropd
- _target_: monai.transforms.ToTensord
keys:
- ${source_col}
- target
- ${exclude_mask_col}
patch_shape: ${data._aux.patch_shape}
patch_per_image: 1
scales_dict: ${kv_to_dict:${data._aux._scales_dict}}
- _target_: monai.transforms.ToTensord
dtype: float16

# #crop
- _target_: cyto_dl.image.transforms.RandomMultiScaleCropd
keys:
- ${source_col}
- target
- ${exclude_mask_col}
patch_shape: ${data._aux.patch_shape}
patch_per_image: ${data._aux.patch_per_image}
scales_dict: ${kv_to_dict:${data._aux._scales_dict}}

_aux:
patch_per_image: 1
_scales_dict:
- - target
- [1]
Expand Down
10 changes: 10 additions & 0 deletions configs/experiment/im2im/segmentation_plugin.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ ckpt_path: null # must override for prediction
experiment_name: experiment_name
run_name: run_name

persist_cache: False
test: False

# manifest columns
source_col: raw
target_col1: seg1
Expand All @@ -41,6 +44,8 @@ trainer:

data:
path: MUST_OVERRIDE # string path to manifest
cache_dir: ${paths.output_dir}/cache # string path to cache_dir (this speeds up data loading)
num_workers: MUST_OVERRIDE # this should be set based on the number of available CPU cores
split_column: null
batch_size: 1
_aux:
Expand All @@ -49,3 +54,8 @@ data:
paths:
output_dir: MUST_OVERRIDE
work_dir: ${paths.output_dir} # it's unclear to me if this is necessary or used

model:
_aux:
filters: MUST_OVERRIDE
overlap: 0
18 changes: 10 additions & 8 deletions configs/model/im2im/segmentation_plugin.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,17 @@ backbone:
spatial_dims: ${spatial_dims}
in_channels: ${raw_im_channels}
out_channels: 1
strides: ${model._aux.strides}
kernel_size: ${model._aux.kernel_size}
upsample_kernel_size: ${model._aux.upsample_kernel_size}
strides: [1, 2, 2]
kernel_size: [3, 3, 3]
upsample_kernel_size: [2, 2]
dropout: 0.0
res_block: True
filters: ${model._aux.filters}

task_heads:
target:
_target_: cyto_dl.nn.BaseHead
_target_: cyto_dl.nn.head.MaskHead
mask_key: ${exclude_mask_col}
loss:
_target_: monai.losses.MaskedDiceLoss
sigmoid: True
Expand All @@ -29,6 +31,7 @@ task_heads:
prediction:
_target_: cyto_dl.models.im2im.utils.postprocessing.AutoThreshold
method: "threshold_otsu"

save_input: True

optimizer:
Expand All @@ -47,10 +50,9 @@ lr_scheduler:
inference_args:
sw_batch_size: 1
roi_size: ${data._aux.patch_shape}
overlap: 0.25
overlap: ${model._aux.overlap}
mode: "gaussian"

_aux:
strides:
kernel_size:
upsample_kernel_size:
filters:
overlap:
2 changes: 2 additions & 0 deletions cyto_dl/image/io/polygon_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ def __call__(self, input_dict):
mask = np.logical_or(mask, polygon2mask(mask_shape, p))
if self.propagate_3d:
mask = np.stack([mask] * input_dict[self.shape_reference_key].shape[1])
# all ones except for regions in polygon
mask = ~mask
input_dict[key] = np.expand_dims(mask > 0, 0)
elif self.missing_key_mode == "raise":
raise KeyError(
Expand Down
24 changes: 16 additions & 8 deletions cyto_dl/image/transforms/merge.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from copy import deepcopy
from typing import Union

from monai.transforms import Transform
Expand Down Expand Up @@ -38,12 +39,6 @@ def __call__(self, input_dict):
input_dict: Dict[str, torch.Tensor]
dict of CZYX tensors/metadata/paths
"""
# no merging mask, return original dict
if self.mask_key not in input_dict or input_dict[self.mask_key] is None:
return input_dict

mask = input_dict[self.mask_key]

if self.base_image_key not in input_dict:
raise KeyError(
f"key `{self.base_image_key}` not available. Available keys are {input_dict.keys()}"
Expand All @@ -55,15 +50,28 @@ def __call__(self, input_dict):
f"Base image name `{base_image_name}` must match provided image keys `{self.image_keys}`"
)

if self.mask_key not in input_dict or input_dict[self.mask_key] is None:
# no merging mask, return original dict
input_dict[self.output_name] = deepcopy(input_dict[base_image_name])
# remove mask key if it exists
input_dict.pop(self.mask_key, None)
return input_dict

mask = input_dict[self.mask_key].astype(bool)
# From polygoan loader, 1 is everything outside of the polygon, 0 inside the polygon.
# For merging we want to inver this
mask = ~mask

for key in self.image_keys:
if key not in input_dict.keys():
raise KeyError(
f"key `{key}` not available. Available keys are {input_dict.keys()}"
)

base_image = input_dict[base_image_name]
self.image_keys.remove(base_image_name)
merge_image = input_dict[self.image_keys[0]]
merge_image = input_dict[
self.image_keys[0] if self.image_keys[1] == base_image_name else self.image_keys[1]
]

input_dict[self.output_name] = (base_image * ~mask) + (merge_image * mask)

Expand Down
5 changes: 4 additions & 1 deletion cyto_dl/models/im2im/utils/postprocessing/auto_thresh.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import importlib
from typing import Optional, Union

import numpy as np


class AutoThreshold:
def __init__(self, method: Optional[Union[float, str]] = None):
Expand All @@ -21,6 +23,7 @@ def thresh_func(image):
self.thresh_func = thresh_func

def __call__(self, image):
image = image.detach().cpu().float().numpy()
if self.thresh_func is None:
return image
return image > self.thresh_func(image)
return (image > self.thresh_func(image)).astype(np.uint8)
1 change: 1 addition & 0 deletions cyto_dl/nn/head/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .base_head import BaseHead
from .gan_head import GANHead
from .gan_head_superres import GANHead_resize
from .mask_head import MaskHead
from .res_blocks_head import ResBlocksHead
5 changes: 1 addition & 4 deletions cyto_dl/nn/head/mask_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,10 @@ def __init__(
save_input=False
Whether to save out example input images during training
"""
super().__init__()
self.loss = loss
self.postprocess = postprocess
super().__init__(loss, postprocess=postprocess, save_input=save_input)
self.mask_key = mask_key

self.model = torch.nn.Sequential(torch.nn.Identity())
self.save_input = save_input

def _calculate_loss(self, y_hat, y, mask):
return self.loss(y_hat, y, mask)
Expand Down

0 comments on commit 251f691

Please sign in to comment.