From 4860c8bdb7a051b4d8a16281646ec0f4671d733d Mon Sep 17 00:00:00 2001 From: benjijamorris <54606172+benjijamorris@users.noreply.github.com> Date: Fri, 23 Feb 2024 14:44:50 -0800 Subject: [PATCH] Feature/head cleanup (#336) * only run heads that dont have existing outputs * make gan subclass of multitask model * update saving code * remove conv_proj_layer * make superres head class gan and resblockshead * remove metric calculation * update configs * update readme * remove conv_proj_layer; * precommit --------- Co-authored-by: Benjamin Morris --- README.md | 2 +- configs/data/im2im/gan_superres.yaml | 205 ++++++++++++++++++ configs/data/im2im/segmentation_superres.yaml | 192 ++++++++++++++++ configs/experiment/im2im/gan_superres.yaml | 44 ++++ .../im2im/segmentation_superres.yaml | 36 +++ configs/model/im2im/gan.yaml | 2 +- configs/model/im2im/gan_superres.yaml | 81 +++++++ configs/model/im2im/instance_seg.yaml | 2 +- configs/model/im2im/labelfree.yaml | 2 +- configs/model/im2im/segmentation.yaml | 2 +- configs/model/im2im/segmentation_plugin.yaml | 2 +- .../model/im2im/segmentation_superres.yaml | 62 ++++++ .../model/im2im/vit_segmentation_decoder.yaml | 2 +- cyto_dl/models/im2im/gan.py | 76 ++----- cyto_dl/models/im2im/multi_task.py | 88 +++++--- cyto_dl/nn/__init__.py | 2 +- cyto_dl/nn/head/__init__.py | 1 - cyto_dl/nn/head/base_head.py | 70 +++--- cyto_dl/nn/head/conv_proj_layer.py | 54 ----- cyto_dl/nn/head/gan_head.py | 16 +- cyto_dl/nn/head/gan_head_superres.py | 106 +++------ cyto_dl/nn/head/mae_head.py | 11 +- cyto_dl/nn/head/mask_head.py | 15 +- cyto_dl/nn/head/res_blocks_head.py | 9 +- 24 files changed, 772 insertions(+), 310 deletions(-) create mode 100644 configs/data/im2im/gan_superres.yaml create mode 100644 configs/data/im2im/segmentation_superres.yaml create mode 100644 configs/experiment/im2im/gan_superres.yaml create mode 100644 configs/experiment/im2im/segmentation_superres.yaml create mode 100644 configs/model/im2im/gan_superres.yaml create mode 100644 configs/model/im2im/segmentation_superres.yaml delete mode 100644 cyto_dl/nn/head/conv_proj_layer.py diff --git a/README.md b/README.md index b6c3cb20f..ecb4775c9 100644 --- a/README.md +++ b/README.md @@ -29,7 +29,7 @@ As part of the [Allen Institute for Cell Science's](allencell.org) mission to un The bulk of `CytoDL`'s underlying structure bases the [lightning-hydra-template](https://github.com/ashleve/lightning-hydra-template) organization - we highly recommend that you familiarize yourself with their (short) docs for detailed instructions on running training, overrides, etc. -Our currently available code is roughly split into two domains: image-to-image transformations and representation learning. The image-to-image code (denoted im2im) contains configuration files detailing how to train and predict using models for resolution enhancement using conditional GANs (e.g. predicting 100x images from 20x images), semantic and instance segmentation, and label-free prediction. We also provide configs for Masked Autoencoder (MAE) pretraining using a Vision Transformer (ViT) backbone (currently only supported for 3D images) and for training segmentation decoders from these pretrained features. Representation learning code includes a wide variety of Variational Auto Encoder (VAE) architectures. Due to dependency issues, equivariant autoencoders are not currently supported on Windows. +Our currently available code is roughly split into two domains: image-to-image transformations and representation learning. The image-to-image code (denoted im2im) contains configuration files detailing how to train and predict using models for resolution enhancement using conditional GANs (e.g. predicting 100x images from 20x images), semantic and instance segmentation, and label-free prediction. We also provide configs for Masked Autoencoder (MAE) pretraining using a Vision Transformer (ViT) backbone and for training segmentation decoders from these pretrained features. Representation learning code includes a wide variety of Variational Auto Encoder (VAE) architectures. Due to dependency issues, equivariant autoencoders are not currently supported on Windows. As we rely on recent versions of pytorch, users wishing to train and run models on GPU hardware will need up-to-date NVIDIA drivers. Users with older GPUs should not expect code to work out of the box. Similarly, we do not currently support training/predicting on Mac GPUs. In most cases, cpu-based training should work when GPU training fails. diff --git a/configs/data/im2im/gan_superres.yaml b/configs/data/im2im/gan_superres.yaml new file mode 100644 index 000000000..c93816740 --- /dev/null +++ b/configs/data/im2im/gan_superres.yaml @@ -0,0 +1,205 @@ +_target_: cyto_dl.datamodules.dataframe.DataframeDatamodule + +path: +cache_dir: + +num_workers: 0 +batch_size: 1 +pin_memory: True +split_column: +columns: + - ${source_col} + - ${target_col} + +transforms: + train: + _target_: monai.transforms.Compose + transforms: + # channels are [blank, membrane,blank, structure, blank, nuclear dye, brightfield ] + # target is the nuclear dyeimage + - _target_: monai.transforms.LoadImaged + keys: ${target_col} + reader: + - _target_: + cyto_dl.image.io.MonaiBioReader + # NOTE: eval is used so only the experiment file is required to change for beginning users. This is not recommended when creating your own configs. + dimension_order_out: ${eval:'"ZYX" if ${spatial_dims}==3 else "YX"'} + C: 5 + Z: ${eval:'None if ${spatial_dims}==3 else 38'} + # channels are [nucseg, cellseg, nuclear boundary seg, cell boundary seg] + # source image is the segmentation + - _target_: monai.transforms.LoadImaged + keys: ${source_col} + reader: + - _target_: cyto_dl.image.io.MonaiBioReader + dimension_order_out: ${eval:'"ZYX" if ${spatial_dims}==3 else "YX"'} + C: 0 + Z: ${eval:'None if ${spatial_dims}==3 else 38'} + - _target_: monai.transforms.EnsureChannelFirstd + channel_dim: "no_channel" + keys: ${data.columns} + - _target_: monai.transforms.Zoomd + keys: ${source_col} + zoom: 0.25 + keep_size: False + - _target_: monai.transforms.ToTensord + keys: ${data.columns} + # GANs use Tanh as final activation, target has to be in range [-1,1] + - _target_: monai.transforms.ScaleIntensityRangePercentilesd + keys: ${target_col} + lower: 0.01 + upper: 99.99 + b_min: -1 + b_max: 1 + clip: True + # input to synthetic image generation model is a semantic segmentation + - _target_: monai.transforms.ThresholdIntensityd + keys: ${source_col} + threshold: 0.1 + above: False + cval: 1 + - _target_: cyto_dl.image.transforms.RandomMultiScaleCropd + keys: ${data.columns} + patch_shape: ${data._aux.patch_shape} + patch_per_image: 1 + scales_dict: ${kv_to_dict:${data._aux._scales_dict}} + + test: + _target_: monai.transforms.Compose + transforms: + # channels are [blank, membrane,blank, structure, blank, nuclear dye, brightfield ] + # target is the nuclear dyeimage + - _target_: monai.transforms.LoadImaged + keys: ${target_col} + reader: + - _target_: cyto_dl.image.io.MonaiBioReader + dimension_order_out: ${eval:'"ZYX" if ${spatial_dims}==3 else "YX"'} + C: 5 + Z: ${eval:'None if ${spatial_dims}==3 else 38'} + # channels are [nucseg, cellseg, nuclear boundary seg, cell boundary seg] + # source image is the segmentation + - _target_: monai.transforms.LoadImaged + keys: ${source_col} + reader: + - _target_: cyto_dl.image.io.MonaiBioReader + dimension_order_out: ${eval:'"ZYX" if ${spatial_dims}==3 else "YX"'} + C: 0 + Z: ${eval:'None if ${spatial_dims}==3 else 38'} + - _target_: monai.transforms.EnsureChannelFirstd + channel_dim: "no_channel" + keys: ${data.columns} + #crop image so that it is divisible by 4 for downsampling + - _target_: monai.transforms.CenterSpatialCropd + keys: ${data.columns} + roi_size: [64, -1, -1] + - _target_: monai.transforms.Zoomd + keys: ${source_col} + zoom: 0.25 + keep_size: False + - _target_: monai.transforms.ToTensord + keys: ${data.columns} + # GANs use Tanh as final activation, target has to be in range [-1,1] + - _target_: monai.transforms.ScaleIntensityRangePercentilesd + keys: ${target_col} + lower: 0.01 + upper: 99.99 + b_min: -1 + b_max: 1 + clip: True + # input to synthetic image generation model is a semantic segmentation + - _target_: monai.transforms.ThresholdIntensityd + keys: ${source_col} + threshold: 0.1 + above: False + cval: 1 + + predict: + _target_: monai.transforms.Compose + transforms: + # channels are [nucseg, cellseg, nuclear boundary seg, cell boundary seg] + # source image is the segmentation + - _target_: monai.transforms.LoadImaged + keys: ${source_col} + reader: + - _target_: cyto_dl.image.io.MonaiBioReader + dimension_order_out: ${eval:'"ZYX" if ${spatial_dims}==3 else "YX"'} + C: 0 + Z: ${eval:'None if ${spatial_dims}==3 else 38'} + - _target_: monai.transforms.EnsureChannelFirstd + channel_dim: "no_channel" + keys: ${source_col} + #crop image so that it is divisible by 4 for downsampling + - _target_: monai.transforms.CenterSpatialCropd + keys: ${source_col} + roi_size: [64, -1, -1] + - _target_: monai.transforms.Zoomd + keys: ${source_col} + zoom: 0.25 + keep_size: False + + - _target_: monai.transforms.ToTensord + keys: ${source_col} + # input to synthetic image generation model is a semantic segmentation + - _target_: monai.transforms.ThresholdIntensityd + keys: ${source_col} + threshold: 0.1 + above: False + cval: 1 + + valid: + _target_: monai.transforms.Compose + transforms: + # channels are [blank, membrane,blank, structure, blank, nuclear dye, brightfield ] + # target is the nuclear dyeimage + - _target_: monai.transforms.LoadImaged + keys: ${target_col} + reader: + - _target_: cyto_dl.image.io.MonaiBioReader + dimension_order_out: ${eval:'"ZYX" if ${spatial_dims}==3 else "YX"'} + C: 5 + Z: ${eval:'None if ${spatial_dims}==3 else 38'} + # channels are [nucseg, cellseg, nuclear boundary seg, cell boundary seg] + # source image is the segmentation + - _target_: monai.transforms.LoadImaged + keys: ${source_col} + reader: + - _target_: cyto_dl.image.io.MonaiBioReader + dimension_order_out: ${eval:'"ZYX" if ${spatial_dims}==3 else "YX"'} + C: 0 + Z: ${eval:'None if ${spatial_dims}==3 else 38'} + - _target_: monai.transforms.EnsureChannelFirstd + channel_dim: "no_channel" + keys: ${data.columns} + - _target_: monai.transforms.Zoomd + keys: ${source_col} + zoom: 0.25 + keep_size: False + + - _target_: monai.transforms.ToTensord + keys: ${data.columns} + # GANs use Tanh as final activation, target has to be in range [-1,1] + - _target_: monai.transforms.ScaleIntensityRangePercentilesd + keys: ${target_col} + lower: 0.01 + upper: 99.99 + b_min: -1 + b_max: 1 + clip: True + # input to synthetic image generation model is a semantic segmentation + - _target_: monai.transforms.ThresholdIntensityd + keys: ${source_col} + threshold: 0.1 + above: False + cval: 1 + - _target_: cyto_dl.image.transforms.RandomMultiScaleCropd + keys: ${data.columns} + patch_shape: ${data._aux.patch_shape} + patch_per_image: 1 + scales_dict: ${kv_to_dict:${data._aux._scales_dict}} + +_aux: + _scales_dict: + - - ${target_col} + - [4] + - - ${source_col} + - [1] diff --git a/configs/data/im2im/segmentation_superres.yaml b/configs/data/im2im/segmentation_superres.yaml new file mode 100644 index 000000000..e75416b1a --- /dev/null +++ b/configs/data/im2im/segmentation_superres.yaml @@ -0,0 +1,192 @@ +_target_: cyto_dl.datamodules.dataframe.DataframeDatamodule + +path: +cache_dir: + +num_workers: 0 +batch_size: 1 +pin_memory: True +split_column: +columns: + - ${source_col} + - ${target_col} + +transforms: + train: + _target_: monai.transforms.Compose + transforms: + # channels are [blank, membrane,blank, structure, blank, nuclear dye, brightfield ] + - _target_: monai.transforms.LoadImaged + keys: ${source_col} + reader: + - _target_: cyto_dl.image.io.MonaiBioReader + # NOTE: eval is used so only the experiment file is required to change for beginning users. This is not recommended when creating your own configs. + dimension_order_out: ${eval:'"ZYX" if ${spatial_dims}==3 else "YX"'} + C: 5 + Z: ${eval:'None if ${spatial_dims}==3 else 38'} + # channels are [nucseg, cellseg, nuclear boundary seg, cell boundary seg] + - _target_: monai.transforms.LoadImaged + keys: ${target_col} + reader: + - _target_: cyto_dl.image.io.MonaiBioReader + dimension_order_out: ${eval:'"ZYX" if ${spatial_dims}==3 else "YX"'} + C: 0 + Z: ${eval:'None if ${spatial_dims}==3 else 38'} + - _target_: monai.transforms.EnsureChannelFirstd + channel_dim: "no_channel" + keys: ${data.columns} + # keep that target at full resolution to test superres + - _target_: monai.transforms.Zoomd + keys: ${source_col} + zoom: 0.25 + keep_size: False + - _target_: monai.transforms.ToTensord + keys: ${data.columns} + - _target_: monai.transforms.NormalizeIntensityd + keys: ${source_col} + channel_wise: True + - _target_: monai.transforms.ThresholdIntensityd + keys: ${target_col} + threshold: 0.1 + above: False + cval: 1 + - _target_: cyto_dl.image.transforms.RandomMultiScaleCropd + keys: ${data.columns} + patch_shape: ${data._aux.patch_shape} + patch_per_image: 1 + scales_dict: ${kv_to_dict:${data._aux._scales_dict}} + - _target_: monai.transforms.RandHistogramShiftd + prob: 0.1 + keys: ${source_col} + num_control_points: [90, 500] + + - _target_: monai.transforms.RandStdShiftIntensityd + prob: 0.1 + keys: ${source_col} + factors: 0.1 + + - _target_: monai.transforms.RandAdjustContrastd + prob: 0.1 + keys: ${source_col} + gamma: [0.9, 1.5] + + test: + _target_: monai.transforms.Compose + transforms: + # channels are [blank, membrane,blank, structure, blank, nuclear dye, brightfield ] + - _target_: monai.transforms.LoadImaged + keys: ${source_col} + reader: + - _target_: cyto_dl.image.io.MonaiBioReader + # NOTE: eval is used so only the experiment file is required to change for beginning users. This is not recommended when creating your own configs. + dimension_order_out: ${eval:'"ZYX" if ${spatial_dims}==3 else "YX"'} + C: 5 + Z: ${eval:'None if ${spatial_dims}==3 else 38'} + # channels are [nucseg, cellseg, nuclear boundary seg, cell boundary seg] + - _target_: monai.transforms.LoadImaged + keys: ${target_col} + reader: + - _target_: cyto_dl.image.io.MonaiBioReader + dimension_order_out: ${eval:'"ZYX" if ${spatial_dims}==3 else "YX"'} + C: 0 + Z: ${eval:'None if ${spatial_dims}==3 else 38'} + - _target_: monai.transforms.EnsureChannelFirstd + channel_dim: "no_channel" + keys: ${data.columns} + #crop image so that it is divisible by 4 for downsampling + - _target_: monai.transforms.CenterSpatialCropd + keys: ${data.columns} + roi_size: [64, -1, -1] + - _target_: monai.transforms.Zoomd + keys: ${source_col} + zoom: 0.25 + keep_size: False + - _target_: monai.transforms.ToTensord + keys: ${data.columns} + - _target_: monai.transforms.NormalizeIntensityd + keys: ${source_col} + channel_wise: True + - _target_: monai.transforms.ThresholdIntensityd + keys: ${target_col} + threshold: 0.1 + above: False + cval: 1 + + predict: + _target_: monai.transforms.Compose + transforms: + # channels are [blank, membrane,blank, structure, blank, nuclear dye, brightfield ] + - _target_: monai.transforms.LoadImaged + keys: ${source_col} + reader: + - _target_: cyto_dl.image.io.MonaiBioReader + # NOTE: eval is used so only the experiment file is required to change for beginning users. This is not recommended when creating your own configs. + dimension_order_out: ${eval:'"ZYX" if ${spatial_dims}==3 else "YX"'} + C: 5 + Z: ${eval:'None if ${spatial_dims}==3 else 38'} + - _target_: monai.transforms.EnsureChannelFirstd + channel_dim: "no_channel" + keys: ${source_col} + #crop image so that it is divisible by 4 for downsampling + - _target_: monai.transforms.CenterSpatialCropd + keys: ${source_col} + roi_size: [64, -1, -1] + - _target_: monai.transforms.Zoomd + keys: ${source_col} + zoom: 0.25 + keep_size: False + - _target_: monai.transforms.ToTensord + keys: ${source_col} + - _target_: monai.transforms.NormalizeIntensityd + keys: ${source_col} + channel_wise: True + + valid: + _target_: monai.transforms.Compose + transforms: + # channels are [blank, membrane,blank, structure, blank, nuclear dye, brightfield ] + - _target_: monai.transforms.LoadImaged + keys: ${source_col} + reader: + - _target_: cyto_dl.image.io.MonaiBioReader + # NOTE: eval is used so only the experiment file is required to change for beginning users. This is not recommended when creating your own configs. + dimension_order_out: ${eval:'"ZYX" if ${spatial_dims}==3 else "YX"'} + C: 5 + Z: ${eval:'None if ${spatial_dims}==3 else 38'} + # channels are [nucseg, cellseg, nuclear boundary seg, cell boundary seg] + - _target_: monai.transforms.LoadImaged + keys: ${target_col} + reader: + - _target_: cyto_dl.image.io.MonaiBioReader + dimension_order_out: ${eval:'"ZYX" if ${spatial_dims}==3 else "YX"'} + C: 0 + Z: ${eval:'None if ${spatial_dims}==3 else 38'} + - _target_: monai.transforms.EnsureChannelFirstd + channel_dim: "no_channel" + keys: ${data.columns} + - _target_: monai.transforms.Zoomd + keys: ${source_col} + zoom: 0.25 + keep_size: False + - _target_: monai.transforms.ToTensord + keys: ${data.columns} + - _target_: monai.transforms.NormalizeIntensityd + keys: ${source_col} + channel_wise: True + - _target_: monai.transforms.ThresholdIntensityd + keys: ${target_col} + threshold: 0.1 + above: False + cval: 1 + - _target_: cyto_dl.image.transforms.RandomMultiScaleCropd + keys: ${data.columns} + patch_shape: ${data._aux.patch_shape} + patch_per_image: 1 + scales_dict: ${kv_to_dict:${data._aux._scales_dict}} + +_aux: + _scales_dict: + - - ${target_col} + - [4] + - - ${source_col} + - [1] diff --git a/configs/experiment/im2im/gan_superres.yaml b/configs/experiment/im2im/gan_superres.yaml new file mode 100644 index 000000000..0fa6e0adb --- /dev/null +++ b/configs/experiment/im2im/gan_superres.yaml @@ -0,0 +1,44 @@ +# @package _global_ +# to execute this experiment run: +# python train.py experiment=example +defaults: + - override /data: im2im/gan_superres.yaml + - override /model: im2im/gan_superres.yaml + - override /callbacks: default.yaml + - override /trainer: gpu.yaml + - override /logger: csv.yaml + +# all parameters below will be merged with parameters from default configurations set above +# this allows you to overwrite only specified parameters + +tags: [dev] +seed: 12345 + +experiment_name: YOUR_EXP_NAME +run_name: YOUR_RUN_NAME +# for synthetic image generation, our source is a segmentation and we are predicting a raw image +source_col: seg +target_col: raw +spatial_dims: 3 +raw_im_channels: 1 + +trainer: + max_epochs: 100 + +data: + path: ${paths.data_dir}/example_experiment_data/segmentation + cache_dir: ${paths.data_dir}/example_experiment_data/cache + subsample: + batch_size: 1 + _aux: + # 2D + # patch_shape: [64, 64] + # 3D + patch_shape: [16, 32, 32] + +callbacks: + model_checkpoint: + monitor: val/loss/generator_loss + + early_stopping: + monitor: val/loss/generator_loss diff --git a/configs/experiment/im2im/segmentation_superres.yaml b/configs/experiment/im2im/segmentation_superres.yaml new file mode 100644 index 000000000..7db7dbba9 --- /dev/null +++ b/configs/experiment/im2im/segmentation_superres.yaml @@ -0,0 +1,36 @@ +# @package _global_ +# to execute this experiment run: +# python train.py experiment=example +defaults: + - override /data: im2im/segmentation_superres.yaml + - override /model: im2im/segmentation_superres.yaml + - override /callbacks: default.yaml + - override /trainer: gpu.yaml + - override /logger: csv.yaml + +# all parameters below will be merged with parameters from default configurations set above +# this allows you to overwrite only specified parameters + +tags: ["dev"] +seed: 12345 + +experiment_name: YOUR_EXP_NAME +run_name: YOUR_RUN_NAME +source_col: raw +target_col: seg +spatial_dims: 3 +raw_im_channels: 1 + +trainer: + max_epochs: 100 + +data: + path: ${paths.data_dir}/example_experiment_data/segmentation + cache_dir: ${paths.data_dir}/example_experiment_data/cache + subsample: + batch_size: 1 + _aux: + # 2D + # patch_shape: [64, 64] + # 3D + patch_shape: [16, 32, 32] diff --git a/configs/model/im2im/gan.yaml b/configs/model/im2im/gan.yaml index 1c84ba019..5e49ec94f 100644 --- a/configs/model/im2im/gan.yaml +++ b/configs/model/im2im/gan.yaml @@ -66,7 +66,7 @@ _aux: scales: 1 reconstruction_loss: _target_: torch.nn.MSELoss - save_raw: True + save_input: True postprocess: input: _target_: cyto_dl.models.im2im.utils.postprocessing.ActThreshLabel diff --git a/configs/model/im2im/gan_superres.yaml b/configs/model/im2im/gan_superres.yaml new file mode 100644 index 000000000..d391a491a --- /dev/null +++ b/configs/model/im2im/gan_superres.yaml @@ -0,0 +1,81 @@ +_target_: cyto_dl.models.im2im.GAN + +save_images_every_n_epochs: 1 +save_dir: ${paths.output_dir} + +x_key: ${source_col} + +backbone: + _target_: monai.networks.nets.DynUNet + spatial_dims: ${spatial_dims} + in_channels: ${raw_im_channels} + out_channels: 1 + strides: [1, 2, 2] + kernel_size: [3, 3, 3] + upsample_kernel_size: [2, 2] + filters: [16, 32, 64] + dropout: 0.0 + res_block: True + +task_heads: ${kv_to_dict:${model._aux._tasks}} + +discriminator: + _target_: cyto_dl.nn.discriminators.MultiScaleDiscriminator + n_scales: 1 + input_nc: 2 #conditioning image+real/fake image + n_layers: 1 + ndf: 16 #number of filters per layer + dim: ${spatial_dims} + +optimizer: + generator: + _partial_: True + _target_: torch.optim.Adam + lr: 0.0001 + weight_decay: 0.0001 + betas: [0.5, 0.999] + discriminator: + _partial_: True + _target_: torch.optim.Adam + lr: 0.0001 + weight_decay: 0.0001 + betas: [0.5, 0.999] + +lr_scheduler: + generator: + _partial_: True + _target_: torch.optim.lr_scheduler.ExponentialLR + gamma: 0.998 + discriminator: + _partial_: True + _target_: torch.optim.lr_scheduler.ExponentialLR + gamma: 0.998 + +inference_args: + sw_batch_size: 1 + roi_size: ${data._aux.patch_shape} + overlap: 0.25 + mode: "gaussian" + +_aux: + _tasks: + - - ${target_col} + - _target_: cyto_dl.nn.GANHead_resize + gan_loss: + _target_: cyto_dl.nn.losses.Pix2PixHD + scales: 1 + reconstruction_loss: + _target_: torch.nn.MSELoss + save_input: True + postprocess: + input: + _target_: cyto_dl.models.im2im.utils.postprocessing.ActThreshLabel + rescale_dtype: numpy.uint8 + prediction: + _target_: cyto_dl.models.im2im.utils.postprocessing.ActThreshLabel + rescale_dtype: numpy.uint8 + in_channels: 1 + out_channels: 1 + upsample_ratio: 4 + resolution: hr + upsample_method: nontrainable diff --git a/configs/model/im2im/instance_seg.yaml b/configs/model/im2im/instance_seg.yaml index 3d6913ee8..31bc89fb4 100644 --- a/configs/model/im2im/instance_seg.yaml +++ b/configs/model/im2im/instance_seg.yaml @@ -44,7 +44,7 @@ _aux: loss: _target_: cyto_dl.models.im2im.utils.InstanceSegLoss dim: ${spatial_dims} - save_raw: True + save_input: True postprocess: input: _target_: cyto_dl.models.im2im.utils.postprocessing.ActThreshLabel diff --git a/configs/model/im2im/labelfree.yaml b/configs/model/im2im/labelfree.yaml index 6713090cc..1f0c15a8f 100644 --- a/configs/model/im2im/labelfree.yaml +++ b/configs/model/im2im/labelfree.yaml @@ -51,4 +51,4 @@ _aux: prediction: _target_: cyto_dl.models.im2im.utils.postprocessing.ActThreshLabel rescale_dtype: numpy.uint8 - save_raw: True + save_input: True diff --git a/configs/model/im2im/segmentation.yaml b/configs/model/im2im/segmentation.yaml index e36397dce..873bf6950 100644 --- a/configs/model/im2im/segmentation.yaml +++ b/configs/model/im2im/segmentation.yaml @@ -54,4 +54,4 @@ _aux: activation: _target_: torch.nn.Sigmoid rescale_dtype: numpy.uint8 - save_raw: True + save_input: True diff --git a/configs/model/im2im/segmentation_plugin.yaml b/configs/model/im2im/segmentation_plugin.yaml index e50eba839..d5f38b30c 100644 --- a/configs/model/im2im/segmentation_plugin.yaml +++ b/configs/model/im2im/segmentation_plugin.yaml @@ -31,7 +31,7 @@ task_heads: activation: _target_: torch.nn.Sigmoid rescale_dtype: numpy.uint8 - save_raw: True + save_input: True optimizer: generator: diff --git a/configs/model/im2im/segmentation_superres.yaml b/configs/model/im2im/segmentation_superres.yaml new file mode 100644 index 000000000..5b175caee --- /dev/null +++ b/configs/model/im2im/segmentation_superres.yaml @@ -0,0 +1,62 @@ +_target_: cyto_dl.models.im2im.MultiTaskIm2Im + +save_images_every_n_epochs: 1 +save_dir: ${paths.output_dir} + +x_key: ${source_col} + +backbone: + _target_: monai.networks.nets.DynUNet + spatial_dims: ${spatial_dims} + in_channels: ${raw_im_channels} + out_channels: 1 + strides: [1, 2, 2] + kernel_size: [3, 3, 3] + upsample_kernel_size: [2, 2] + filters: [16, 32, 64] + dropout: 0.0 + res_block: True + +task_heads: ${kv_to_dict:${model._aux._tasks}} + +optimizer: + generator: + _partial_: True + _target_: torch.optim.Adam + lr: 0.0001 + weight_decay: 0.0001 + +lr_scheduler: + generator: + _partial_: True + _target_: torch.optim.lr_scheduler.ExponentialLR + gamma: 0.995 + +inference_args: + sw_batch_size: 1 + roi_size: ${data._aux.patch_shape} + overlap: 0.25 + mode: "gaussian" + +_aux: + _tasks: + - - ${target_col} + - _target_: cyto_dl.nn.ResBlocksHead + loss: + _target_: monai.losses.DiceCELoss + sigmoid: True + postprocess: + input: + _target_: cyto_dl.models.im2im.utils.postprocessing.ActThreshLabel + rescale_dtype: numpy.uint8 + prediction: + _target_: cyto_dl.models.im2im.utils.postprocessing.ActThreshLabel + activation: + _target_: torch.nn.Sigmoid + rescale_dtype: numpy.uint8 + save_input: True + in_channels: 1 + out_channels: 1 + upsample_ratio: 4 + resolution: hr + upsample_method: nontrainable diff --git a/configs/model/im2im/vit_segmentation_decoder.yaml b/configs/model/im2im/vit_segmentation_decoder.yaml index 88b7a39a3..25366e638 100644 --- a/configs/model/im2im/vit_segmentation_decoder.yaml +++ b/configs/model/im2im/vit_segmentation_decoder.yaml @@ -47,7 +47,7 @@ _aux: loss: _target_: cyto_dl.models.im2im.utils.InstanceSegLoss dim: ${spatial_dims} - save_raw: True + save_input: True postprocess: input: _target_: cyto_dl.models.im2im.utils.postprocessing.ActThreshLabel diff --git a/cyto_dl/models/im2im/gan.py b/cyto_dl/models/im2im/gan.py index 4cc7cca08..cf5b6a246 100644 --- a/cyto_dl/models/im2im/gan.py +++ b/cyto_dl/models/im2im/gan.py @@ -8,7 +8,7 @@ from monai.inferers import sliding_window_inference from torchmetrics import MeanMetric -from cyto_dl.models.base_model import BaseModel +from cyto_dl.models.im2im.multi_task import MultiTaskIm2Im _DEFAULT_METRICS = { "train/loss/discriminator_loss": MeanMetric(), @@ -23,7 +23,7 @@ } -class GAN(BaseModel): +class GAN(MultiTaskIm2Im): """Basic GAN model.""" def __init__( @@ -64,23 +64,19 @@ def __init__( """ metrics = base_kwargs.pop("metrics", _DEFAULT_METRICS) - super().__init__(metrics=metrics, **base_kwargs) + super().__init__( + metrics=metrics, backbone=backbone, task_heads=task_heads, x_key=x_key, **base_kwargs + ) self.automatic_optimization = False - for stage in ("train", "val", "test", "predict"): - (Path(save_dir) / f"{stage}_images").mkdir(exist_ok=True, parents=True) if compile is True and not sys.platform.startswith("win"): - self.backbone = torch.compile(backbone) self.discriminator = torch.compile(discriminator) - self.task_heads = torch.nn.ModuleDict( - {k: torch.compile(v) for k, v in task_heads.items()} - ) else: - self.backbone = backbone self.discriminator = discriminator - self.task_heads = torch.nn.ModuleDict(task_heads) assert len(self.task_heads.keys()) == 1, "Only single-head GANs are supported currently." + self.inference_heads = list(self.task_heads.keys()) + for k, head in self.task_heads.items(): head.update_params({"head_name": k, "x_key": x_key, "save_dir": save_dir}) @@ -111,10 +107,6 @@ def _train_forward(self, batch, stage, save_image, run_heads): for task in run_heads } - def forward(self, x, run_heads): - z = self.backbone(x) - return {task: self.task_heads[task](z) for task in run_heads} - def _inference_forward(self, batch, stage, save_image, run_heads): """during inference, we need to calculate per-fov loss/metrics/postprocessing. @@ -143,31 +135,6 @@ def _inference_forward(self, batch, stage, save_image, run_heads): for head_name, head in self.task_heads.items() } - def run_forward(self, batch, stage, save_image, run_heads): - if stage in ("train", "val"): - return self._train_forward(batch, stage, save_image, run_heads) - return self._inference_forward(batch, stage, save_image, run_heads) - - def should_save_image(self, batch_idx, stage): - return stage in ("test", "predict") or ( - batch_idx == 0 # noqa: FURB124 - and (self.current_epoch + 1) % self.hparams.save_images_every_n_epochs == 0 - ) - - def _sum_losses(self, losses): - summ = 0 - for k, v in losses.items(): - summ += v - losses["loss"] = summ - return losses - - def _get_run_heads(self, batch, stage): - if stage not in ("test", "predict"): - run_heads = [key for key in self.task_heads.keys() if key in batch] - else: - run_heads = list(self.task_heads.keys()) - return run_heads - def _extract_loss(self, outs, loss_type): loss = { f"{head_name}_{loss_type}": head_result[loss_type] @@ -176,13 +143,8 @@ def _extract_loss(self, outs, loss_type): return self._sum_losses(loss) def model_step(self, stage, batch, batch_idx): - batch["filenames"] = batch[self.hparams.x_key].meta["filename_or_obj"] - # convert monai metatensors to tensors - for k, v in batch.items(): - if isinstance(v, MetaTensor): - batch[k] = v.as_tensor() - - run_heads = self._get_run_heads(batch, stage) + run_heads, _ = self._get_run_heads(batch, stage, batch_idx) + batch = self._to_tensor(batch) outs = self.run_forward(batch, stage, self.should_save_image(batch_idx, stage), run_heads) loss_D = self._extract_loss(outs, "loss_D") @@ -213,20 +175,10 @@ def model_step(self, stage, batch, batch_idx): return loss_dict, None, None def predict_step(self, batch, batch_idx): - batch["filenames"] = batch[self.hparams.x_key].meta["filename_or_obj"] - # convert monai metatensors to tensors - for k, v in batch.items(): - if isinstance(v, MetaTensor): - batch[k] = v.as_tensor() stage = "predict" - run_heads = self._get_run_heads(batch, stage) - outs = self.run_forward(batch, stage, self.should_save_image(batch_idx, stage), run_heads) - # create input-> per head output mapping - io_map = {} - for head, output in outs.items(): - head_io_map = output["save_path"] - for in_file, out_file in zip(head_io_map["input"], head_io_map["output"]): - if in_file not in io_map: - io_map[in_file] = {} - io_map[in_file][head] = out_file + run_heads, io_map = self._get_run_heads(batch, stage, batch_idx) + if len(run_heads) > 0: + batch = self._to_tensor(batch) + save_image = self.should_save_image(batch_idx, stage) + self.run_forward(batch, stage, save_image, run_heads) return io_map diff --git a/cyto_dl/models/im2im/multi_task.py b/cyto_dl/models/im2im/multi_task.py index ada16b853..860cbd7c1 100644 --- a/cyto_dl/models/im2im/multi_task.py +++ b/cyto_dl/models/im2im/multi_task.py @@ -153,48 +153,72 @@ def should_save_image(self, batch_idx, stage): ) def _sum_losses(self, losses): - summ = 0 - for k, v in losses.items(): - summ += v - losses["loss"] = summ + losses["loss"] = torch.sum(torch.stack(list(losses.values()))) return losses - def _get_run_heads(self, batch, stage): - if stage not in ("test", "predict"): + def _get_unrun_heads(self, io_map): + """returns heads that don't have outputs yet.""" + updated_run_heads = [] + # check that all output files exist for each head + for head, head_io_map in io_map.items(): + for fn in head_io_map["output"]: + if not Path(fn).exists(): + updated_run_heads.append(head) + break + return updated_run_heads + + def _combine_io_maps(self, io_maps): + """aggregate io_maps from per-head to per-input image.""" + io_map = {} + # create input-> per head output mapping + for head, head_io_map in io_maps.items(): + for in_file, out_file in zip(head_io_map["input"], head_io_map["output"]): + if in_file not in io_map: + io_map[in_file] = {} + io_map[in_file][head] = out_file + return io_map + + def _get_run_heads(self, batch, stage, batch_idx): + """Get heads that are either specified as inference heads and don't have outputs yet or all + heads.""" + run_heads = self.inference_heads + if stage in ("train", "val"): run_heads = [key for key in self.task_heads.keys() if key in batch] - else: - run_heads = self.inference_heads - return run_heads - def model_step(self, stage, batch, batch_idx): - batch["filenames"] = batch[self.hparams.x_key].meta.get("filename_or_obj", batch_idx) - # convert monai metatensors to tensors + io_map = { + h: self.task_heads[h].generate_io_map( + batch[self.hparams.x_key].meta, stage, batch_idx, self.global_step + ) + for h in run_heads + } + + if stage == "predict": + # only run heads that don't have outputs yet for prediction + run_heads = self._get_unrun_heads(io_map) + io_map = self._combine_io_maps(io_map) + + return run_heads, io_map + + def _to_tensor(self, batch): + """convert monai metatensors to tensors.""" for k, v in batch.items(): if isinstance(v, MetaTensor): batch[k] = v.as_tensor() + return batch - run_heads = self._get_run_heads(batch, stage) - outs = self.run_forward(batch, stage, self.should_save_image(batch_idx, stage), run_heads) - + def model_step(self, stage, batch, batch_idx): + run_heads, _ = self._get_run_heads(batch, stage, batch_idx) + batch = self._to_tensor(batch) + save_image = self.should_save_image(batch_idx, stage) + outs = self.run_forward(batch, stage, save_image, run_heads) losses = {head_name: head_result["loss"] for head_name, head_result in outs.items()} - losses = self._sum_losses(losses) - return losses, None, None + return self._sum_losses(losses), None, None def predict_step(self, batch, batch_idx): stage = "predict" - batch["filenames"] = batch[self.hparams.x_key].meta["filename_or_obj"] - # convert monai metatensors to tensors - for k, v in batch.items(): - if isinstance(v, MetaTensor): - batch[k] = v.as_tensor() - run_heads = self._get_run_heads(batch, stage) - outs = self.run_forward(batch, stage, self.should_save_image(batch_idx, stage), run_heads) - # create input-> per head output mapping - io_map = {} - for head, output in outs.items(): - head_io_map = output["save_path"] - for in_file, out_file in zip(head_io_map["input"], head_io_map["output"]): - if in_file not in io_map: - io_map[in_file] = {} - io_map[in_file][head] = out_file + run_heads, io_map = self._get_run_heads(batch, stage, batch_idx) + if len(run_heads) > 0: + batch = self._to_tensor(batch) + save_image = self.should_save_image(batch_idx, stage) + self.run_forward(batch, stage, save_image, run_heads) return io_map diff --git a/cyto_dl/nn/__init__.py b/cyto_dl/nn/__init__.py index 24b3cce63..3b5fb03f7 100644 --- a/cyto_dl/nn/__init__.py +++ b/cyto_dl/nn/__init__.py @@ -1,4 +1,4 @@ -from .head import BaseHead, ConvProjectionLayer, GANHead, ResBlocksHead +from .head import BaseHead, GANHead, GANHead_resize, ResBlocksHead from .hr_skip import HRSkip from .losses import ( AdversarialLoss, diff --git a/cyto_dl/nn/head/__init__.py b/cyto_dl/nn/head/__init__.py index 8933f551f..07061d52e 100644 --- a/cyto_dl/nn/head/__init__.py +++ b/cyto_dl/nn/head/__init__.py @@ -1,5 +1,4 @@ from .base_head import BaseHead -from .conv_proj_layer import ConvProjectionLayer from .gan_head import GANHead from .gan_head_superres import GANHead_resize from .res_blocks_head import ResBlocksHead diff --git a/cyto_dl/nn/head/base_head.py b/cyto_dl/nn/head/base_head.py index 24b250615..73577fe63 100644 --- a/cyto_dl/nn/head/base_head.py +++ b/cyto_dl/nn/head/base_head.py @@ -15,7 +15,7 @@ def __init__( loss, postprocess={"input": detach, "prediction": detach}, calculate_metric=False, - save_raw=False, + save_input=False, ): """ Parameters @@ -26,7 +26,7 @@ def __init__( Postprocessing for `input` and `predictions` of head calculate_metric=False Whether to calculate a metric during training. Not used by GAN head. - save_raw=False + save_input=False Whether to save out example input images during training """ super().__init__() @@ -35,7 +35,7 @@ def __init__( self.calculate_metric = calculate_metric self.model = torch.nn.Sequential(torch.nn.Identity()) - self.save_raw = save_raw + self.save_input = save_input def update_params(self, params): for k, v in params.items(): @@ -47,45 +47,36 @@ def _calculate_loss(self, y_hat, y): def _postprocess(self, img, img_type): return [self.postprocess[img_type](img[i]) for i in range(img.shape[0])] - def _save(self, fn, img, stage): + def generate_io_map(self, meta, stage, batch_idx, step): + """generates map between input files and output files for a head.""" + # filename is determined by step in training during train/val and by its source filename for prediction/testing + filename_map = {"input": meta.get("filename_or_obj", [batch_idx])} if stage in ("train", "val", "test"): - (Path(self.save_dir) / f"{stage}_images").mkdir(exist_ok=True, parents=True) - out_path = Path(self.save_dir) / f"{stage}_images" / fn + out_paths = [Path(self.save_dir) / f"{stage}_images" / f"{step}_{self.head_name}.tif"] else: - out_path = Path(self.save_dir) / fn - OmeTiffWriter().save( - uri=out_path, - data=img.squeeze(), - dims_order="STCZYX"[-len(img.shape)], - ) - return out_path + out_paths = [ + Path(self.save_dir) / self.head_name / f"{Path(fn).stem}.tif" + for fn in filename_map["input"] + ] + # create output directory if it doesn't exist + out_paths[0].parent.mkdir(exist_ok=True, parents=True) - def _calculate_metric(self, y_hat, y): - raise NotImplementedError + filename_map["output"] = out_paths + self.filename_map = filename_map + return filename_map def save_image(self, y_hat, batch, stage, global_step): y_hat_out = self._postprocess(y_hat, img_type="prediction") - y_out, raw_out = None, None - filename_map = {"input": [], "output": []} - # filename is determined by step in training during train/val and by its source filename for prediction/testing - if stage in ("train", "val"): - y_out = self._postprocess(batch[self.head_name], img_type="input") - if self.save_raw: - raw_out = self._postprocess(batch[self.x_key], img_type="input") - save_name = [f"{global_step}_{self.head_name}.tif"] - else: - filename_map["input"] = batch["filenames"] - save_name = [f"{Path(fn).stem}_{self.head_name}.tif" for fn in batch["filenames"]] - n_save = len(y_hat_out) if stage in ("test", "predict") else 1 - for i in range(n_save): - out_path = self._save(save_name[i].replace(".tif", "_pred.tif"), y_hat_out[i], stage) + y_out = None + for i, out_path in enumerate(self.filename_map["output"]): + OmeTiffWriter.save(data=y_hat_out[i], uri=out_path) if stage in ("train", "val"): - self._save(save_name[i], y_out[i], stage) - if self.save_raw: - self._save(save_name[i].replace(".tif", "_raw.tif"), raw_out[i], stage) - else: - filename_map["output"].append(out_path) - return y_hat_out, y_out, filename_map + y_out = self._postprocess(batch[self.head_name], img_type="input") + OmeTiffWriter.save(data=y_out[i], uri=str(out_path).replace(".t", "_label.t")) + if self.save_input: + raw_out = self._postprocess(batch[self.x_key][i : i + 1], img_type="input") + OmeTiffWriter.save(data=raw_out, uri=str(out_path).replace(".t", "_input.t")) + return y_hat_out, y_out def forward(self, x): return self.model(x) @@ -112,17 +103,12 @@ def run_head( if stage != "predict": loss = self._calculate_loss(y_hat, batch[self.head_name]) - y_hat_out, y_out, out_paths = None, None, None + y_hat_out, y_out = None, None if save_image: - y_hat_out, y_out, out_paths = self.save_image(y_hat, batch, stage, global_step) + y_hat_out, y_out = self.save_image(y_hat, batch, stage, global_step) - metric = None - if self.calculate_metric and stage in ("val", "test"): - metric = self._calculate_metric(y_hat, batch[self.head_name]) return { "loss": loss, - "metric": metric, "y_hat_out": y_hat_out, "y_out": y_out, - "save_path": out_paths, } diff --git a/cyto_dl/nn/head/conv_proj_layer.py b/cyto_dl/nn/head/conv_proj_layer.py deleted file mode 100644 index c0658c157..000000000 --- a/cyto_dl/nn/head/conv_proj_layer.py +++ /dev/null @@ -1,54 +0,0 @@ -import math - -import numpy as np -import torch -from monai.networks.blocks import Convolution - - -class ConvProjectionLayer(torch.nn.Module): - """Layer for projecting e.g. 3D->2D image.""" - - def __init__(self, dim, pool_size: int, in_channels: int, out_channels: int): - """ - Parameters - --------- - dim - Dimension to project, e.g. 2 for projecting NCZYX -> NCYX - pool_size:int - Size of convolutional kernel for downsampling - in_channels:int - number of input channels - out_channels:int - number of output channels - """ - super().__init__() - self.dim = dim - n_downs = math.floor(np.log2(pool_size)) - modules = [] - for _ in range(n_downs): - modules.append( - Convolution( - spatial_dims=3, - in_channels=in_channels, - out_channels=out_channels, - kernel_size=[2, 1, 1], - strides=[2, 1, 1], - padding=[0, 0, 0], - ) - ) - remainder = pool_size - 2**n_downs - if remainder != 0: - modules.append( - Convolution( - spatial_dims=3, - in_channels=in_channels, - out_channels=out_channels, - kernel_size=[remainder, 1, 1], - strides=[remainder, 1, 1], - padding=[0, 0, 0], - ) - ) - self.model = torch.nn.Sequential(*modules) - - def __call__(self, x): - return self.model(x).squeeze(self.dim) diff --git a/cyto_dl/nn/head/gan_head.py b/cyto_dl/nn/head/gan_head.py index c241a3b03..b94921101 100644 --- a/cyto_dl/nn/head/gan_head.py +++ b/cyto_dl/nn/head/gan_head.py @@ -19,7 +19,7 @@ def __init__( reconstruction_loss_weight=100, postprocess={"input": detach, "prediction": detach}, calculate_metric=False, - save_raw=False, + save_input=False, ): """ Parameters @@ -34,10 +34,10 @@ def __init__( Postprocessing for `input` and `predictions` of head calculate_metric=False Whether to calculate a metric during training. Not used by GAN head. - save_raw=False + save_input=False Whether to save out example input images during training """ - super().__init__(None, postprocess, calculate_metric, save_raw) + super().__init__(None, postprocess, calculate_metric, save_input) self.gan_loss = gan_loss self.reconstruction_loss = reconstruction_loss self.reconstruction_loss_weight = reconstruction_loss_weight @@ -81,19 +81,13 @@ def run_head( ) loss_D, loss_G = self._calculate_loss(y_hat, batch, discriminator) - y_hat_out, y_out, out_paths = None, None, None + y_hat_out, y_out = None, None if save_image: - y_hat_out, y_out, out_paths = self.save_image(y_hat, batch, stage, global_step) - - metric = None - if self.calculate_metric and stage in ("val", "test"): - metric = self._calculate_metric(y_hat, batch[self.head_name]) + y_hat_out, y_out = self.save_image(y_hat, batch, stage, global_step) return { "loss_D": loss_D, "loss_G": loss_G, - "metric": metric, "y_hat_out": y_hat_out, "y_out": y_out, - "save_path": out_paths, } diff --git a/cyto_dl/nn/head/gan_head_superres.py b/cyto_dl/nn/head/gan_head_superres.py index 663d11891..31989ac5d 100644 --- a/cyto_dl/nn/head/gan_head_superres.py +++ b/cyto_dl/nn/head/gan_head_superres.py @@ -9,10 +9,11 @@ from cyto_dl.nn.losses import Pix2PixHD from .gan_head import GANHead +from .res_blocks_head import ResBlocksHead -class GANHead_resize(GANHead): - """GAN Task head with upsampling.""" +class GANHead_resize(GANHead, ResBlocksHead): + """Inherit run_head from GANHead, use __init__ and forward of ResBlocksHead.""" def __init__( self, @@ -23,7 +24,7 @@ def __init__( reconstruction_loss_weight=100, postprocess={"input": detach, "prediction": detach}, calculate_metric=False, - save_raw=False, + save_input=False, final_act: Callable = torch.nn.Identity(), resolution="lr", spatial_dims=3, @@ -47,69 +48,30 @@ def __init__( Postprocessing for `input` and `predictions` of head calculate_metric=False Whether to calculate a metric during training. Not used by GAN head. - save_raw=False + save_input=False Whether to save out example input images during training """ - super().__init__( - gan_loss, - reconstruction_loss, - reconstruction_loss_weight, - postprocess, - calculate_metric, - save_raw, - ) - - self.resolution = resolution - conv_input_channels = in_channels - modules = [first_layer] - upsample = torch.nn.Identity() - - upsample_ratio = upsample_ratio or [2] * spatial_dims - - if resolution == "hr": - if upsample_method == "pixelshuffle": - conv_input_channels //= 2**spatial_dims - assert len(upsample_ratio) == spatial_dims - upsample = UpSample( - spatial_dims=spatial_dims, - in_channels=in_channels, - out_channels=conv_input_channels, - scale_factor=upsample_ratio, - mode=upsample_method, - ) - for i in range(n_convs): - in_channels = conv_input_channels - if dense: - in_channels = (i + 1) * conv_input_channels - modules.append( - UnetResBlock( - spatial_dims=spatial_dims, - in_channels=in_channels, - out_channels=conv_input_channels, - stride=1, - kernel_size=3, - norm_name="INSTANCE", - dropout=dropout, - ) - ) - if dense: - # dense convolutions - modules = [modules[0]] + [DenseBlock(modules[1:])] - conv_input_channels *= n_convs + 1 - modules.extend( - ( - UnetOutBlock( - spatial_dims=spatial_dims, - in_channels=conv_input_channels, - out_channels=out_channels, - dropout=dropout, - ), - final_act, - ) - ) - self.model = torch.nn.ModuleDict( - {"upsample": upsample, "model": torch.nn.Sequential(*modules)} + ResBlocksHead.__init__( + self, + loss=None, + in_channels=in_channels, + out_channels=out_channels, + final_act=final_act, + postprocess=postprocess, + calculate_metric=calculate_metric, + save_input=save_input, + resolution=resolution, + spatial_dims=spatial_dims, + n_convs=n_convs, + dropout=dropout, + upsample_method=upsample_method, + upsample_ratio=upsample_ratio, + first_layer=first_layer, + dense=dense, ) + self.gan_loss = gan_loss + self.reconstruction_loss = reconstruction_loss + self.reconstruction_loss_weight = reconstruction_loss_weight def _ensure_same_shape(self, x, y): min_shape = np.minimum(x.shape, y.shape) @@ -118,20 +80,8 @@ def _ensure_same_shape(self, x, y): return x, y def _calculate_loss(self, y_hat, batch, discriminator): - # extract intermediate activations from discriminator for real and predicted images - y, y_hat = self._ensure_same_shape(batch[self.head_name], y_hat) - - features_discriminator = discriminator(batch[self.x_key], y, y_hat.detach()) - loss_D = self.gan_loss(features_discriminator, "discriminator") - - # passability of generated images - features_generator = discriminator(batch[self.x_key], y, y_hat) - loss_G = self.gan_loss(features_generator, "generator") - # image reconstruction quality - loss_reconstruction = self.reconstruction_loss(y, y_hat) - return loss_D, loss_G + loss_reconstruction * self.reconstruction_loss_weight + batch[self.head_name], y_hat = self._ensure_same_shape(batch[self.head_name], y_hat) + return GANHead._calculate_loss(self, y_hat, batch, discriminator) def forward(self, x): - if self.resolution == "hr": - x = self.model["upsample"](x) - return self.model["model"](x) + return ResBlocksHead.forward(self, x) diff --git a/cyto_dl/nn/head/mae_head.py b/cyto_dl/nn/head/mae_head.py index e8dbcddf2..0a7e73be8 100644 --- a/cyto_dl/nn/head/mae_head.py +++ b/cyto_dl/nn/head/mae_head.py @@ -18,24 +18,19 @@ def run_head( y_hat, mask = backbone_features else: raise ValueError("MAE head is only intended for use during training.") + loss = (batch[self.head_name] - y_hat) ** 2 if mask.sum() > 0: loss = loss[mask.bool()].mean() else: loss = loss.mean() - y_hat_out, y_out, out_paths = None, None, None + y_hat_out, y_out = None, None if save_image: - y_hat_out, y_out, out_paths = self.save_image(y_hat, batch, stage, global_step) - - metric = None - if self.calculate_metric and stage in ("val", "test"): - metric = self._calculate_metric(y_hat, batch[self.head_name]) + y_hat_out, y_out = self.save_image(y_hat, batch, stage, global_step) return { "loss": loss, - "metric": metric, "y_hat_out": y_hat_out, "y_out": y_out, - "save_path": out_paths, } diff --git a/cyto_dl/nn/head/mask_head.py b/cyto_dl/nn/head/mask_head.py index 75756fb95..05cc83a64 100644 --- a/cyto_dl/nn/head/mask_head.py +++ b/cyto_dl/nn/head/mask_head.py @@ -13,7 +13,7 @@ def __init__( mask_key: str = "mask", postprocess={"input": detach, "prediction": detach}, calculate_metric=False, - save_raw=False, + save_input=False, ): """ Parameters @@ -24,7 +24,7 @@ def __init__( Postprocessing for `input` and `predictions` of head calculate_metric=False Whether to calculate a metric during training. Not used by GAN head. - save_raw=False + save_input=False Whether to save out example input images during training """ super().__init__() @@ -34,7 +34,7 @@ def __init__( self.mask_key = mask_key self.model = torch.nn.Sequential(torch.nn.Identity()) - self.save_raw = save_raw + self.save_input = save_input def _calculate_loss(self, y_hat, y, mask): return self.loss(y_hat, y, mask) @@ -61,17 +61,12 @@ def run_head( if stage != "predict": loss = self._calculate_loss(y_hat, batch[self.head_name], batch[self.mask_key]) - y_hat_out, y_out, out_paths = None, None, None + y_hat_out, y_out = None, None if save_image: - y_hat_out, y_out, out_paths = self.save_image(y_hat, batch, stage, global_step) + y_hat_out, y_out = self.save_image(y_hat, batch, stage, global_step) - metric = None - if self.calculate_metric and stage in ("val", "test"): - metric = self._calculate_metric(y_hat, batch[self.head_name]) return { "loss": loss, - "metric": metric, "y_hat_out": y_hat_out, "y_out": y_out, - "save_path": out_paths, } diff --git a/cyto_dl/nn/head/res_blocks_head.py b/cyto_dl/nn/head/res_blocks_head.py index 3ea4d6891..eaf399613 100644 --- a/cyto_dl/nn/head/res_blocks_head.py +++ b/cyto_dl/nn/head/res_blocks_head.py @@ -21,7 +21,7 @@ def __init__( final_act: Callable = torch.nn.Identity(), postprocess={"input": detach, "prediction": detach}, calculate_metric=False, - save_raw=False, + save_input=False, resolution="lr", spatial_dims=3, n_convs=1, @@ -44,7 +44,7 @@ def __init__( Postprocessing functions for ground truth and model predictions calculate_metric=False Whether to calculate a metric. Currently not implemented - save_raw=False + save_input=False Whether to save raw image examples during training resolution="lr" Resolution of output image. If `lr`, no upsampling is done. If `hr`, `upsample_method` and `upsample_ratio` are used @@ -64,14 +64,15 @@ def __init__( dense=False Whether to use dense connections between convolutional layers """ - super().__init__(loss, postprocess, calculate_metric, save_raw) + super().__init__(loss, postprocess, calculate_metric, save_input) self.resolution = resolution conv_input_channels = in_channels modules = [first_layer] upsample = torch.nn.Identity() - upsample_ratio = upsample_ratio or [2] * spatial_dims + if isinstance(upsample_ratio, int): + upsample_ratio = [upsample_ratio] * spatial_dims if resolution == "hr": if upsample_method == "pixelshuffle":