Feature/jepa (#409)
* Bump version: 0.1.5 → 0.1.6

* add datamodules

* add tile crop transform

* add contrastive models

* add vic-reg specific head and loss

* rename encoder decoder

* remove gan version

* allow multi-ch reading from comma-separated list

* add non-overlapping tile cropper

* save out images during inference

* forward is in base class

* add vic reg loss to init

* add heatmap

* update predict step

* transform for generating block-style masks with a fixed number of masked patches

* add jepa models

* patchify speedup

* update jepa

* add guard rails

* update model

* remove non-jepa changes

* oops

* add inference

* add 2d support

* swap to jepa

* split up jepa into base, ijepa, and iwm

* clean up mask generator transform

* lint

* remove jepaseg class

* combine predictor classes

* generalize source and target, update predict step

* use patchify arg

* update docs

* sanitize patch args

* add ijepa configs

* update iwm configs

* restructure csv saver

* simplify iwm model

* fix domain embedding shape

* add struct to test data

* add iwm configs

* update predict args

* update test/predict transforms

* switch to csv logger

* precommit

* account for grid patch transform

* add pretrain transformers to tests

* update mae head return format

* update resume test

* update pos embed

* precommit

* remove encoder_decoder

---------

Co-authored-by: Benjamin Morris <[email protected]>
Co-authored-by: Benjamin Morris <[email protected]>
Co-authored-by: Benjamin Morris <[email protected]>
4 people authored Aug 14, 2024
1 parent 4b94064 commit 26b77af
Showing 22 changed files with 1,177 additions and 32 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -29,7 +29,7 @@ As part of the [Allen Institute for Cell Science's](https://allencell.org) missi

The bulk of `CytoDL`'s underlying structure is based on the [lightning-hydra-template](https://github.com/ashleve/lightning-hydra-template) organization - we highly recommend that you familiarize yourself with their (short) docs for detailed instructions on running training, overrides, etc.

-Our currently available code is roughly split into two domains: image-to-image transformations and representation learning. The image-to-image code (denoted im2im) contains configuration files detailing how to train and predict using models for resolution enhancement using conditional GANs (e.g. predicting 100x images from 20x images), semantic and instance segmentation, and label-free prediction. We also provide configs for Masked Autoencoder (MAE) pretraining using a Vision Transformer (ViT) backbone and for training segmentation decoders from these pretrained features. Representation learning code includes a wide variety of Variational Auto Encoder (VAE) architectures and contrastive learning methods such as [VICReg](https://github.com/facebookresearch/vicreg). Due to dependency issues, equivariant autoencoders are not currently supported on Windows.
+Our currently available code is roughly split into two domains: image-to-image transformations and representation learning. The image-to-image code (denoted im2im) contains configuration files detailing how to train and predict using models for resolution enhancement using conditional GANs (e.g. predicting 100x images from 20x images), semantic and instance segmentation, and label-free prediction. We also provide configs for Masked Autoencoder (MAE) and Joint Embedding Prediction Architecture ([JEPA](https://github.com/facebookresearch/jepa)) pretraining on 2D and 3D images using a Vision Transformer (ViT) backbone and for training segmentation decoders from these pretrained features. Representation learning code includes a wide variety of Variational Auto Encoder (VAE) architectures and contrastive learning methods such as [VICReg](https://github.com/facebookresearch/vicreg). Due to dependency issues, equivariant autoencoders are not currently supported on Windows.

As we rely on recent versions of pytorch, users wishing to train and run models on GPU hardware will need up-to-date NVIDIA drivers. Users with older GPUs should not expect code to work out of the box. Similarly, we do not currently support training/predicting on Mac GPUs. In most cases, cpu-based training should work when GPU training fails.
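
Following the lightning-hydra-template conventions referenced above, the new JEPA pretraining configs are selected by composing an experiment file over the repository defaults. Below is a minimal, illustrative sketch of that composition via Hydra's Python API; it assumes the template's standard `configs/train.yaml` root config and a working directory at the repository root (the CLI equivalent follows the `python train.py experiment=...` pattern shown in the experiment config further down).

```python
# Illustrative only: compose the ijepa experiment config over the repo defaults.
# Assumes the lightning-hydra-template layout (configs/train.yaml) used by CytoDL.
from hydra import compose, initialize
from omegaconf import OmegaConf

with initialize(version_base="1.3", config_path="configs"):
    cfg = compose(config_name="train", overrides=["experiment=im2im/ijepa"])
    # resolve=False because CytoDL registers custom resolvers (eval, kv_to_dict)
    # that are only available inside its own entry points.
    print(OmegaConf.to_yaml(cfg.data, resolve=False))
```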

134 changes: 134 additions & 0 deletions configs/data/im2im/ijepa.yaml
@@ -0,0 +1,134 @@
_target_: cyto_dl.datamodules.dataframe.DataframeDatamodule

path:
cache_dir:

num_workers: 0
shuffle: True
batch_size: 1
pin_memory: True

transforms:
  train:
    _target_: monai.transforms.Compose
    transforms:
      # channels are [blank, membrane, blank, structure, blank, nuclear dye, brightfield]
      - _target_: monai.transforms.LoadImaged
        keys: ${source_col}
        reader:
          - _target_: cyto_dl.image.io.MonaiBioReader
            # NOTE: eval is used so only the experiment file is required to change for beginning users. This is not recommended when creating your own configs.
            dimension_order_out: ${eval:'"CZYX" if ${spatial_dims}==3 else "CYX"'}
            C: 5
            Z: ${eval:'None if ${spatial_dims}==3 else 38'}
      - _target_: monai.transforms.Zoomd
        keys: ${source_col}
        zoom: 0.25
        keep_size: False
      - _target_: monai.transforms.ToTensord
        keys: ${source_col}
      - _target_: monai.transforms.NormalizeIntensityd
        keys: ${source_col}
        channel_wise: True
      - _target_: cyto_dl.image.transforms.RandomMultiScaleCropd
        keys:
          - ${source_col}
        patch_shape: ${data._aux.patch_shape}
        patch_per_image: 1
        scales_dict: ${kv_to_dict:${data._aux._scales_dict}}
      - _target_: cyto_dl.image.transforms.generate_jepa_masks.JEPAMaskGenerator
        mask_size: 4
        num_patches: ${model._aux.num_patches}

  test:
    _target_: monai.transforms.Compose
    transforms:
      # channels are [blank, membrane, blank, structure, blank, nuclear dye, brightfield]
      - _target_: monai.transforms.LoadImaged
        keys: ${source_col}
        reader:
          - _target_: cyto_dl.image.io.MonaiBioReader
            # NOTE: eval is used so only the experiment file is required to change for beginning users. This is not recommended when creating your own configs.
            dimension_order_out: ${eval:'"CZYX" if ${spatial_dims}==3 else "CYX"'}
            C: 5
            Z: ${eval:'None if ${spatial_dims}==3 else 38'}
      - _target_: monai.transforms.Zoomd
        keys: ${source_col}
        zoom: 0.25
        keep_size: False
      - _target_: monai.transforms.ToTensord
        keys: ${source_col}
      - _target_: monai.transforms.NormalizeIntensityd
        keys: ${source_col}
        channel_wise: True
      # extract out all patches
      - _target_: monai.transforms.GridPatchd
        keys:
          - ${source_col}
        patch_size: ${data._aux.patch_shape}
      - _target_: cyto_dl.image.transforms.generate_jepa_masks.JEPAMaskGenerator
        mask_size: 4
        num_patches: ${model._aux.num_patches}

  predict:
    _target_: monai.transforms.Compose
    transforms:
      # channels are [blank, membrane, blank, structure, blank, nuclear dye, brightfield]
      - _target_: monai.transforms.LoadImaged
        keys: ${source_col}
        reader:
          - _target_: cyto_dl.image.io.MonaiBioReader
            # NOTE: eval is used so only the experiment file is required to change for beginning users. This is not recommended when creating your own configs.
            dimension_order_out: ${eval:'"CZYX" if ${spatial_dims}==3 else "CYX"'}
            C: 5
            Z: ${eval:'None if ${spatial_dims}==3 else 38'}
      - _target_: monai.transforms.Zoomd
        keys: ${source_col}
        zoom: 0.25
        keep_size: False
      - _target_: monai.transforms.ToTensord
        keys: ${source_col}
      - _target_: monai.transforms.NormalizeIntensityd
        keys: ${source_col}
        channel_wise: True
      # extract out all patches
      - _target_: monai.transforms.GridPatchd
        keys:
          - ${source_col}
        patch_size: ${data._aux.patch_shape}

  valid:
    _target_: monai.transforms.Compose
    transforms:
      # channels are [blank, membrane, blank, structure, blank, nuclear dye, brightfield]
      - _target_: monai.transforms.LoadImaged
        keys: ${source_col}
        reader:
          - _target_: cyto_dl.image.io.MonaiBioReader
            # NOTE: eval is used so only the experiment file is required to change for beginning users. This is not recommended when creating your own configs.
            dimension_order_out: ${eval:'"CZYX" if ${spatial_dims}==3 else "CYX"'}
            C: 5
            Z: ${eval:'None if ${spatial_dims}==3 else 38'}
      - _target_: monai.transforms.Zoomd
        keys: ${source_col}
        zoom: 0.25
        keep_size: False
      - _target_: monai.transforms.ToTensord
        keys: ${source_col}
      - _target_: monai.transforms.NormalizeIntensityd
        keys: ${source_col}
        channel_wise: True
      - _target_: cyto_dl.image.transforms.RandomMultiScaleCropd
        keys:
          - ${source_col}
        patch_shape: ${data._aux.patch_shape}
        patch_per_image: 1
        scales_dict: ${kv_to_dict:${data._aux._scales_dict}}
      - _target_: cyto_dl.image.transforms.generate_jepa_masks.JEPAMaskGenerator
        mask_size: 4
        num_patches: ${model._aux.num_patches}

_aux:
  _scales_dict:
    - - ${source_col}
      - [1]
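
The `JEPAMaskGenerator` entries above (`mask_size: 4`, `num_patches` tied to the model config) produce block-style masks with a fixed number of masked patches, per the commit description. The snippet below is a rough numpy sketch of that idea for intuition only; it is not the transform's actual implementation, and the function name and `mask_ratio` argument are made up for illustration.

```python
# Illustrative sketch only: block-style masking over a patch grid, keeping the
# total number of masked patches fixed. This is NOT the JEPAMaskGenerator
# implementation from this PR, just the idea behind mask_size / num_patches.
import numpy as np


def block_mask(num_patches, mask_size=4, mask_ratio=0.25, rng=None):
    """Boolean mask over the patch grid with ~mask_ratio of patches masked,
    built from axis-aligned blocks of side `mask_size`."""
    rng = np.random.default_rng(rng)
    grid = np.zeros(num_patches, dtype=bool)
    target = int(round(mask_ratio * grid.size))
    while grid.sum() < target:
        # random block origin that fits entirely inside the grid
        origin = [rng.integers(0, n - mask_size + 1) for n in num_patches]
        block = tuple(slice(o, o + mask_size) for o in origin)
        grid[block] = True
    # trim any overshoot so the masked-patch count is exactly `target`
    extra = int(grid.sum()) - target
    if extra > 0:
        masked_idx = np.flatnonzero(grid)
        grid.flat[rng.choice(masked_idx, size=extra, replace=False)] = False
    return grid


mask = block_mask(num_patches=(8, 8, 8), mask_size=4, mask_ratio=0.25)
print(mask.shape, int(mask.sum()))  # (8, 8, 8) 128
```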
180 changes: 180 additions & 0 deletions configs/data/im2im/iwm.yaml
@@ -0,0 +1,180 @@
_target_: cyto_dl.datamodules.dataframe.DataframeDatamodule

path:
cache_dir:

num_workers: 0
shuffle: True
batch_size: 1
pin_memory: True

columns:
  - ${source_col}
  - struct

transforms:
  train:
    _target_: monai.transforms.Compose
    transforms:
      # channels are [blank, membrane, blank, structure, blank, nuclear dye, brightfield]
      - _target_: monai.transforms.LoadImaged
        keys: ${source_col}
        reader:
          - _target_: cyto_dl.image.io.MonaiBioReader
            # NOTE: eval is used so only the experiment file is required to change for beginning users. This is not recommended when creating your own configs.
            dimension_order_out: ${eval:'"CZYX" if ${spatial_dims}==3 else "CYX"'}
            # structure and brightfield channels
            C: [3, 5]
            Z: ${eval:'None if ${spatial_dims}==3 else 38'}
      - _target_: monai.transforms.Zoomd
        keys: ${source_col}
        zoom: 0.25
        keep_size: False
      # split two channel image into two separate keys called `source_col`_struct and `source_col`_brightfield
      - _target_: monai.transforms.SplitDimd
        keys: ${source_col}
        output_postfixes:
          - struct
          - brightfield
      # delete original key
      - _target_: monai.transforms.DeleteItemsd
        keys: ${source_col}
      - _target_: monai.transforms.ToTensord
        keys:
          - ${source_col}_struct
          - ${source_col}_brightfield
      - _target_: monai.transforms.NormalizeIntensityd
        keys:
          - ${source_col}_struct
          - ${source_col}_brightfield
        channel_wise: True
      - _target_: monai.transforms.RandSpatialCropSamplesd
        keys:
          - ${source_col}_struct
          - ${source_col}_brightfield
        roi_size: ${data._aux.patch_shape}
        num_samples: 1
        random_size: False
      - _target_: cyto_dl.image.transforms.generate_jepa_masks.JEPAMaskGenerator
        mask_size: 4
        num_patches: ${model._aux.num_patches}

  test:
    _target_: monai.transforms.Compose
    transforms:
      # channels are [blank, membrane, blank, structure, blank, nuclear dye, brightfield]
      - _target_: monai.transforms.LoadImaged
        keys: ${source_col}
        reader:
          - _target_: cyto_dl.image.io.MonaiBioReader
            # NOTE: eval is used so only the experiment file is required to change for beginning users. This is not recommended when creating your own configs.
            dimension_order_out: ${eval:'"CZYX" if ${spatial_dims}==3 else "CYX"'}
            # structure and brightfield channels
            C: [3, 5]
            Z: ${eval:'None if ${spatial_dims}==3 else 38'}
      - _target_: monai.transforms.Zoomd
        keys: ${source_col}
        zoom: 0.25
        keep_size: False
      # split two channel image into two separate keys called `source_col`_struct and `source_col`_brightfield
      - _target_: monai.transforms.SplitDimd
        keys: ${source_col}
        output_postfixes:
          - struct
          - brightfield
      # delete original key
      - _target_: monai.transforms.DeleteItemsd
        keys: ${source_col}
      - _target_: monai.transforms.ToTensord
        keys:
          - ${source_col}_struct
          - ${source_col}_brightfield
      - _target_: monai.transforms.NormalizeIntensityd
        keys:
          - ${source_col}_struct
          - ${source_col}_brightfield
        channel_wise: True
      # extract out all patches
      - _target_: monai.transforms.GridPatchd
        keys:
          - ${source_col}_struct
          - ${source_col}_brightfield
        patch_size: ${data._aux.patch_shape}

  predict:
    _target_: monai.transforms.Compose
    transforms:
      # channels are [blank, membrane, blank, structure, blank, nuclear dye, brightfield]
      - _target_: monai.transforms.LoadImaged
        keys: ${source_col}
        reader:
          - _target_: cyto_dl.image.io.MonaiBioReader
            # NOTE: eval is used so only the experiment file is required to change for beginning users. This is not recommended when creating your own configs.
            dimension_order_out: ${eval:'"CZYX" if ${spatial_dims}==3 else "CYX"'}
            # structure and brightfield channels
            C: 3
            Z: ${eval:'None if ${spatial_dims}==3 else 38'}
      - _target_: monai.transforms.Zoomd
        keys: ${source_col}
        zoom: 0.25
        keep_size: False
      - _target_: monai.transforms.ToTensord
        keys:
          - ${source_col}
      - _target_: monai.transforms.NormalizeIntensityd
        keys:
          - ${source_col}
        channel_wise: True
      # extract out all patches
      - _target_: monai.transforms.GridPatchd
        keys:
          - ${source_col}
        patch_size: ${data._aux.patch_shape}

  valid:
    _target_: monai.transforms.Compose
    transforms:
      # channels are [blank, membrane, blank, structure, blank, nuclear dye, brightfield]
      - _target_: monai.transforms.LoadImaged
        keys: ${source_col}
        reader:
          - _target_: cyto_dl.image.io.MonaiBioReader
            # NOTE: eval is used so only the experiment file is required to change for beginning users. This is not recommended when creating your own configs.
            dimension_order_out: ${eval:'"CZYX" if ${spatial_dims}==3 else "CYX"'}
            # structure and brightfield channels
            C: [3, 5]
            Z: ${eval:'None if ${spatial_dims}==3 else 38'}
      - _target_: monai.transforms.Zoomd
        keys: ${source_col}
        zoom: 0.25
        keep_size: False
      # split two channel image into two separate keys called `source_col`_struct and `source_col`_brightfield
      - _target_: monai.transforms.SplitDimd
        keys: ${source_col}
        output_postfixes:
          - struct
          - brightfield
      # delete original key
      - _target_: monai.transforms.DeleteItemsd
        keys: ${source_col}
      - _target_: monai.transforms.ToTensord
        keys:
          - ${source_col}_struct
          - ${source_col}_brightfield
      - _target_: monai.transforms.NormalizeIntensityd
        keys:
          - ${source_col}_struct
          - ${source_col}_brightfield
        channel_wise: True
      - _target_: monai.transforms.RandSpatialCropSamplesd
        keys:
          - ${source_col}_struct
          - ${source_col}_brightfield
        roi_size: ${data._aux.patch_shape}
        num_samples: 1
        random_size: False
      - _target_: cyto_dl.image.transforms.generate_jepa_masks.JEPAMaskGenerator
        mask_size: 4
        num_patches: ${model._aux.num_patches}

_aux:
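
A key difference from the ijepa config is that the IWM data pipeline loads a two-channel (structure + brightfield) image and uses MONAI's `SplitDimd` to break it into separate `_struct` and `_brightfield` dictionary keys before cropping. A self-contained sketch of that step is below; the `raw` key and the random tensor are placeholders, and the exact output shape may vary with MONAI's `keepdim` default.

```python
# Standalone sketch of the channel-splitting step used in the iwm config:
# a 2-channel CZYX image keyed by "raw" becomes "raw_struct" and "raw_brightfield".
import torch
from monai.transforms import Compose, DeleteItemsd, NormalizeIntensityd, SplitDimd

sample = {"raw": torch.rand(2, 16, 64, 64)}  # [structure, brightfield] channels

pipeline = Compose(
    [
        SplitDimd(keys="raw", output_postfixes=["struct", "brightfield"], dim=0),
        DeleteItemsd(keys="raw"),  # drop the original two-channel key
        NormalizeIntensityd(keys=["raw_struct", "raw_brightfield"], channel_wise=True),
    ]
)

out = pipeline(sample)
print(sorted(out.keys()), out["raw_struct"].shape)
# e.g. ['raw_brightfield', 'raw_struct'] torch.Size([1, 16, 64, 64])
```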
46 changes: 46 additions & 0 deletions configs/experiment/im2im/ijepa.yaml
@@ -0,0 +1,46 @@
# @package _global_

# to execute this experiment run:
# python train.py experiment=example

defaults:
  - override /data: im2im/ijepa.yaml
  - override /model: im2im/ijepa.yaml
  - override /callbacks: default.yaml
  - override /trainer: gpu.yaml
  - override /logger: csv.yaml

# all parameters below will be merged with parameters from default configurations set above
# this allows you to overwrite only specified parameters

tags: ["dev"]
seed: 12345

experiment_name: YOUR_EXP_NAME
run_name: YOUR_RUN_NAME

# only source_col is needed for ijepa
source_col: raw
spatial_dims: 3
raw_im_channels: 1

trainer:
  max_epochs: 100
  gradient_clip_val: 10

data:
  path: ${paths.data_dir}/example_experiment_data/segmentation
  cache_dir: ${paths.data_dir}/example_experiment_data/cache
  batch_size: 1
  _aux:
    # 2D
    # patch_shape: [16, 16]
    # 3D
    patch_shape: [16, 16, 16]

model:
  _aux:
    # 3D
    num_patches: [8, 8, 8]
    # 2D
    # num_patches: [8, 8]
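
Note that `data._aux.patch_shape` and `model._aux.num_patches` need to stay consistent: each axis of the crop has to divide evenly into the requested token grid (presumably what the `sanitize patch args` commit above refers to). A quick, illustrative check follows; this helper is not part of CytoDL.

```python
# Illustrative helper (not part of CytoDL): check that a crop of size patch_shape
# can be tiled into num_patches ViT tokens per axis, and report the token size.
def token_size(patch_shape, num_patches):
    if len(patch_shape) != len(num_patches):
        raise ValueError("patch_shape and num_patches must have the same rank")
    sizes = []
    for dim, n in zip(patch_shape, num_patches):
        if dim % n != 0:
            raise ValueError(f"{dim} voxels cannot be split into {n} equal patches")
        sizes.append(dim // n)
    return sizes


# values from the 3D settings above: 16^3 crops split into an 8x8x8 token grid
print(token_size([16, 16, 16], [8, 8, 8]))  # -> [2, 2, 2], i.e. 2x2x2-voxel tokens
```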