From bbde8f277b32b6e62ef84aee3e0850b7e5f84141 Mon Sep 17 00:00:00 2001 From: Pierre-Louis Benveniste Date: Mon, 11 Mar 2024 15:04:46 -0400 Subject: [PATCH 001/108] created script to build msd dataset for monai nnunet model training --- monai/nnunet/1_create_msd_data.py | 153 ++++++++++++++++++++++++++++++ 1 file changed, 153 insertions(+) create mode 100644 monai/nnunet/1_create_msd_data.py diff --git a/monai/nnunet/1_create_msd_data.py b/monai/nnunet/1_create_msd_data.py new file mode 100644 index 0000000..0e6566a --- /dev/null +++ b/monai/nnunet/1_create_msd_data.py @@ -0,0 +1,153 @@ +""" +This file creates the MSD-style JSON datalist to train an nnunet model using monai. +The datasets used are CanProCo, Bavaria-quebec, basel and sct-testing-large. + +Arguments: + -pd, --path-data: Path to the data set directory + -pj, --path-joblib: Path to joblib file from ivadomed containing the dataset splits. + -po, --path-out: Path to the output directory where dataset json is saved + --contrast: Contrast to use for training + --label-type: Type of labels to use for training + --seed: Seed for reproducibility + +Example: + python create_msd_data.py ... + +TO DO: + * + +Pierre-Louis Benveniste +""" + +import os +import json +from tqdm import tqdm +import yaml +import argparse +from loguru import logger +from sklearn.model_selection import train_test_split +from datetime import date + +# root = "/home/GRAMES.POLYMTL.CA/u114716/datasets/spine-generic_uncropped" + +parser = argparse.ArgumentParser(description='Code for MSD-style JSON datalist for lesion-agnostic nnunet model training.') + +parser.add_argument('-pd', '--path-data', required=True, type=str, help='Path to the folder containing the datasets') +parser.add_argument('-po', '--path-out', type=str, help='Path to the output directory where dataset json is saved') +parser.add_argument('--seed', default=42, type=int, help="Seed for reproducibility") +args = parser.parse_args() + + +root = args.path_data +seed = args.seed + +# Get all subjects +canproco_path = os.path.join(root, "canproco") +basel_path = os.path.join(root, "basel-mp2rage") +bavaria_path = os.path.join(root, "bavaria-quebec-spine-ms") +sct_testing_path = os.path.join(root, "sct-testing-large") + +subjects_canproco = list(canproco_path.rglob('*_PSIR.nii.gz')) + list(canproco_path.rglob('*STIR.nii.gz')) +subjects_basel = list(basel_path.rglob('*UNIT1.nii.gz')) +subjects_sct = list(sct_testing_path.rglob('*_lesion-manual.nii.gz')) +subjects_bavaria = list(bavaria_path.rglob('*T2w.nii.gz')) + +subjects = subjects_canproco + subjects_basel + subjects_sct + subjects_bavaria +logger.info(f"Total number of subjects in the root directory: {subjects}") + +# create one json file with 60-20-20 train-val-test split +train_ratio, val_ratio, test_ratio = 0.6, 0.2, 0.2 +train_subjects, test_subjects = train_test_split(subjects, test_size=test_ratio, random_state=args.seed) +# Use the training split to further split into training and validation splits +train_subjects, val_subjects = train_test_split(train_subjects, test_size=val_ratio / (train_ratio + val_ratio), + random_state=args.seed, ) +# sort the subjects +train_subjects = sorted(train_subjects) +val_subjects = sorted(val_subjects) +test_subjects = sorted(test_subjects) + +logger.info(f"Number of training subjects: {len(train_subjects)}") +logger.info(f"Number of validation subjects: {len(val_subjects)}") +logger.info(f"Number of testing subjects: {len(test_subjects)}") + +# dump train/val/test splits into a yaml file +with open(f"data_split_{date}_seed{seed}.yaml", 'w') as file: + yaml.dump({'train': train_subjects, 'val': val_subjects, 'test': test_subjects}, file, indent=2, sort_keys=True) + +# keys to be defined in the dataset_0.json +params = {} +params["description"] = "ms-lesion-agnostic" +params["labels"] = { + "0": "background", + "1": "ms-lesion-seg" + } +params["license"] = "plb" +params["modality"] = { + "0": "MRI" + } +params["name"] = "ms-lesion-agnostic" +params["numTest"] = len(test_subjects) +params["numTraining"] = len(train_subjects) +params["numValidation"] = len(val_subjects) +params["seed"] = args.seed +params["reference"] = "NeuroPoly" +params["tensorImageSize"] = "3D" + +train_subjects_dict = {"train": train_subjects} +val_subjects_dict = {"validation": val_subjects} +test_subjects_dict = {"test": test_subjects} +all_subjects_list = [train_subjects_dict, val_subjects_dict, test_subjects_dict] + +for subjects_dict in tqdm(all_subjects_list, desc="Iterating through train/val/test splits"): + + for name, subs_list in subjects_dict.items(): + + temp_list = [] + for subject_no, subject in enumerate(subs_list): + + temp_data_canproco = {} + temp_data_basel = {} + temp_data_sct = {} + temp_data_bavaria = {} + + # Canproco + if 'canproco' in str(subject): + temp_data_canproco["label"] = subject + temp_data_canproco["image"] = str(subject).replace('_lesion-manual.nii.gz', '.nii.gz').replace('derivatives/labels/', '') + if os.path.exists(temp_data_canproco["label"]) and os.path.exists(temp_data_canproco["image"]): + temp_list.append(temp_data_canproco) + + # Basel + elif 'basel-mp2rage' in str(subject): + relative_path = subject.relative_to(bavaria_path).parent + temp_data_basel["image"] = subject + temp_data_basel["image"] = bavaria_path / 'derivatives' / 'labels' / relative_path / str(subject).replace('UNIT1.nii.gz', 'UNIT1_desc-rater3_label-lesion_seg.nii.gz') + if os.path.exists(temp_data_basel["label"]) and os.path.exists(temp_data_basel["image"]): + temp_list.append(temp_data_basel) + + # sct-testing-large + elif 'sct-testing-large' in str(subject): + temp_data_sct["label"] = subject + temp_data_sct["image"] = str(subject).replace('_lesion-manual.nii.gz', '.nii.gz').replace('derivatives/labels/', '') + if os.path.exists(temp_data_sct["label"]) and os.path.exists(temp_data_sct["image"]): + temp_list.append(temp_data_sct) + + + # Bavaria-quebec + elif 'bavaria-quebec-spine-ms' in str(subject): + relative_path = subject.relative_to(bavaria_path).parent + temp_data_bavaria["label"] = subject + temp_data_bavaria["image"] = bavaria_path / 'derivatives' / 'labels' / relative_path / subject.name.replace('T2w.nii.gz', 'T2w_lesion-manual.nii.gz') + if os.path.exists(temp_data_bavaria["label"]) and os.path.exists(temp_data_bavaria["image"]): + temp_list.append(temp_data_bavaria) + + params[name] = temp_list + logger.info(f"Number of images in {name} set: {len(temp_list)}") + +final_json = json.dumps(params, indent=4, sort_keys=True) +if not os.path.exists(args.path_out): + os.makedirs(args.path_out, exist_ok=True) + +jsonFile = open(args.path_out + "/" + f"dataset_{date}_seed{seed}.json", "w") +jsonFile.write(final_json) +jsonFile.close() From 081276a419bd90a1ef83e34d7ec5087bec2dc1c6 Mon Sep 17 00:00:00 2001 From: Pierre-Louis Benveniste Date: Mon, 11 Mar 2024 15:21:44 -0400 Subject: [PATCH 002/108] added requirements script --- monai/nnunet/requirements.txt | 2 ++ 1 file changed, 2 insertions(+) create mode 100644 monai/nnunet/requirements.txt diff --git a/monai/nnunet/requirements.txt b/monai/nnunet/requirements.txt new file mode 100644 index 0000000..ed959c6 --- /dev/null +++ b/monai/nnunet/requirements.txt @@ -0,0 +1,2 @@ +yaml +scikit-learn \ No newline at end of file From 0ef3f0a510beec08ec78c119efea383b3009344c Mon Sep 17 00:00:00 2001 From: Pierre-Louis Benveniste Date: Mon, 11 Mar 2024 15:26:13 -0400 Subject: [PATCH 003/108] removed yaml from requirements --- monai/nnunet/requirements.txt | 1 - 1 file changed, 1 deletion(-) diff --git a/monai/nnunet/requirements.txt b/monai/nnunet/requirements.txt index ed959c6..ff88936 100644 --- a/monai/nnunet/requirements.txt +++ b/monai/nnunet/requirements.txt @@ -1,2 +1 @@ -yaml scikit-learn \ No newline at end of file From 451ca30289fcff6097906b75d416d24df5305333 Mon Sep 17 00:00:00 2001 From: Pierre-Louis Benveniste Date: Tue, 12 Mar 2024 10:42:44 -0400 Subject: [PATCH 004/108] changed output file names --- monai/nnunet/1_create_msd_data.py | 29 +++++++++++++++-------------- 1 file changed, 15 insertions(+), 14 deletions(-) diff --git a/monai/nnunet/1_create_msd_data.py b/monai/nnunet/1_create_msd_data.py index 0e6566a..13077ff 100644 --- a/monai/nnunet/1_create_msd_data.py +++ b/monai/nnunet/1_create_msd_data.py @@ -27,6 +27,7 @@ from loguru import logger from sklearn.model_selection import train_test_split from datetime import date +from pathlib import Path # root = "/home/GRAMES.POLYMTL.CA/u114716/datasets/spine-generic_uncropped" @@ -42,10 +43,10 @@ seed = args.seed # Get all subjects -canproco_path = os.path.join(root, "canproco") -basel_path = os.path.join(root, "basel-mp2rage") -bavaria_path = os.path.join(root, "bavaria-quebec-spine-ms") -sct_testing_path = os.path.join(root, "sct-testing-large") +canproco_path = Path(os.path.join(root, "canproco")) +basel_path = Path(os.path.join(root, "basel-mp2rage")) +bavaria_path = Path(os.path.join(root, "bavaria-quebec-spine-ms")) +sct_testing_path = Path(os.path.join(root, "sct-testing-large")) subjects_canproco = list(canproco_path.rglob('*_PSIR.nii.gz')) + list(canproco_path.rglob('*STIR.nii.gz')) subjects_basel = list(basel_path.rglob('*UNIT1.nii.gz')) @@ -53,7 +54,7 @@ subjects_bavaria = list(bavaria_path.rglob('*T2w.nii.gz')) subjects = subjects_canproco + subjects_basel + subjects_sct + subjects_bavaria -logger.info(f"Total number of subjects in the root directory: {subjects}") +logger.info(f"Total number of subjects in the root directory: {len(subjects)}") # create one json file with 60-20-20 train-val-test split train_ratio, val_ratio, test_ratio = 0.6, 0.2, 0.2 @@ -71,7 +72,7 @@ logger.info(f"Number of testing subjects: {len(test_subjects)}") # dump train/val/test splits into a yaml file -with open(f"data_split_{date}_seed{seed}.yaml", 'w') as file: +with open(f"{args.path_out}/data_split_{str(date.today())}_seed{seed}.yaml", 'w') as file: yaml.dump({'train': train_subjects, 'val': val_subjects, 'test': test_subjects}, file, indent=2, sort_keys=True) # keys to be defined in the dataset_0.json @@ -112,22 +113,22 @@ # Canproco if 'canproco' in str(subject): - temp_data_canproco["label"] = subject + temp_data_canproco["label"] = str(subject) temp_data_canproco["image"] = str(subject).replace('_lesion-manual.nii.gz', '.nii.gz').replace('derivatives/labels/', '') if os.path.exists(temp_data_canproco["label"]) and os.path.exists(temp_data_canproco["image"]): temp_list.append(temp_data_canproco) # Basel elif 'basel-mp2rage' in str(subject): - relative_path = subject.relative_to(bavaria_path).parent - temp_data_basel["image"] = subject - temp_data_basel["image"] = bavaria_path / 'derivatives' / 'labels' / relative_path / str(subject).replace('UNIT1.nii.gz', 'UNIT1_desc-rater3_label-lesion_seg.nii.gz') + relative_path = subject.relative_to(basel_path).parent + temp_data_basel["image"] = str(subject) + temp_data_basel["label"] = str(basel_path / 'derivatives' / 'labels' / relative_path / str(subject).replace('UNIT1.nii.gz', 'UNIT1_desc-rater3_label-lesion_seg.nii.gz')) if os.path.exists(temp_data_basel["label"]) and os.path.exists(temp_data_basel["image"]): temp_list.append(temp_data_basel) # sct-testing-large elif 'sct-testing-large' in str(subject): - temp_data_sct["label"] = subject + temp_data_sct["label"] = str(subject) temp_data_sct["image"] = str(subject).replace('_lesion-manual.nii.gz', '.nii.gz').replace('derivatives/labels/', '') if os.path.exists(temp_data_sct["label"]) and os.path.exists(temp_data_sct["image"]): temp_list.append(temp_data_sct) @@ -136,8 +137,8 @@ # Bavaria-quebec elif 'bavaria-quebec-spine-ms' in str(subject): relative_path = subject.relative_to(bavaria_path).parent - temp_data_bavaria["label"] = subject - temp_data_bavaria["image"] = bavaria_path / 'derivatives' / 'labels' / relative_path / subject.name.replace('T2w.nii.gz', 'T2w_lesion-manual.nii.gz') + temp_data_bavaria["image"] = str(subject) + temp_data_bavaria["label"] = str(bavaria_path / 'derivatives' / 'labels' / relative_path / subject.name.replace('T2w.nii.gz', 'T2w_lesion-manual.nii.gz')) if os.path.exists(temp_data_bavaria["label"]) and os.path.exists(temp_data_bavaria["image"]): temp_list.append(temp_data_bavaria) @@ -148,6 +149,6 @@ if not os.path.exists(args.path_out): os.makedirs(args.path_out, exist_ok=True) -jsonFile = open(args.path_out + "/" + f"dataset_{date}_seed{seed}.json", "w") +jsonFile = open(args.path_out + "/" + f"dataset_{str(date.today())}_seed{seed}.json", "w") jsonFile.write(final_json) jsonFile.close() From 1dbd553e4eec44b37fdbf9b76582d7193c2fcc29 Mon Sep 17 00:00:00 2001 From: Pierre-Louis Benveniste Date: Tue, 12 Mar 2024 10:42:59 -0400 Subject: [PATCH 005/108] added missing requirements --- monai/nnunet/requirements.txt | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/monai/nnunet/requirements.txt b/monai/nnunet/requirements.txt index ff88936..7503a26 100644 --- a/monai/nnunet/requirements.txt +++ b/monai/nnunet/requirements.txt @@ -1 +1,3 @@ -scikit-learn \ No newline at end of file +scikit-learn +tqdm +loguru From 91bac326d2dd3914c737a40364b9ebdcffb733cb Mon Sep 17 00:00:00 2001 From: Pierre-Louis Benveniste Date: Tue, 12 Mar 2024 15:38:11 -0400 Subject: [PATCH 006/108] initialised config.yml file example --- monai/nnunet/config.yml | 71 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 71 insertions(+) create mode 100644 monai/nnunet/config.yml diff --git a/monai/nnunet/config.yml b/monai/nnunet/config.yml new file mode 100644 index 0000000..4e4a488 --- /dev/null +++ b/monai/nnunet/config.yml @@ -0,0 +1,71 @@ +seed: 15 +save_test_preds: True + +directories: + # Path to the saved models directory + models_dir: /home/GRAMES.POLYMTL.CA/u114716/contrast-agnostic/saved_models/followup + # Path to the saved results directory + results_dir: /home/GRAMES.POLYMTL.CA/u114716/contrast-agnostic/results/models_followup + # Path to the saved wandb logs directory + # if None, starts training from scratch. Otherwise, resumes training from the specified wandb run folder + wandb_run_folder: None + +dataset: + # Dataset name (will be used as "group_name" for wandb logging) + name: spine-generic + # Path to the dataset directory containing all datalists (.json files) + root_dir: /home/GRAMES.POLYMTL.CA/u114716/contrast-agnostic/datalists/spine-generic/seed15 + # Type of contrast to be used for training. "all" corresponds to training on all contrasts + contrast: all # choices: ["t1w", "t2w", "t2star", "mton", "mtoff", "dwi", "all"] + # Type of label to be used for training. + label_type: soft_bin # choices: ["hard", "soft", "soft_bin"] + +preprocessing: + # Online resampling of images to the specified spacing. + spacing: [1.0, 1.0, 1.0] + # Center crop/pad images to the specified size. (NOTE: done after resampling) + # values correspond to R-L, A-P, I-S axes of the image after 1mm isotropic resampling. + crop_pad_size: [64, 192, 320] + +opt: + name: adam + lr: 0.001 + max_epochs: 200 + batch_size: 2 + # Interval between validation checks in epochs + check_val_every_n_epochs: 5 + # Early stopping patience (this is until patience * check_val_every_n_epochs) + early_stopping_patience: 20 + + +model: + # Model architecture to be used for training (also to be specified as args in the command line) + nnunet: + # NOTE: these info are typically taken from nnUNetPlans.json (if an nnUNet model is trained) + base_num_features: 32 + max_num_features: 320 + n_conv_per_stage_encoder: [2, 2, 2, 2, 2, 2] + n_conv_per_stage_decoder: [2, 2, 2, 2, 2] + pool_op_kernel_sizes: [ + [1, 1, 1], + [2, 2, 2], + [2, 2, 2], + [2, 2, 2], + [2, 2, 2], + [1, 2, 2] + ] + enable_deep_supervision: True + + mednext: + num_input_channels: 1 + base_num_features: 32 + num_classes: 1 + kernel_size: 3 # 3x3x3 and 5x5x5 were tested in publication + block_counts: [2,2,2,2,1,1,1,1,1] # number of blocks in each layer + enable_deep_supervision: True + + swinunetr: + spatial_dims: 3 + depths: [2, 2, 2, 2] + num_heads: [3, 6, 12, 24] # number of heads in multi-head Attention + feature_size: 36 \ No newline at end of file From 48f109c472ee8150690c6fd937d7795c00e62948 Mon Sep 17 00:00:00 2001 From: Pierre-Louis Benveniste Date: Tue, 12 Mar 2024 15:39:06 -0400 Subject: [PATCH 007/108] initialised main file --- monai/nnunet/main.py | 692 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 692 insertions(+) create mode 100644 monai/nnunet/main.py diff --git a/monai/nnunet/main.py b/monai/nnunet/main.py new file mode 100644 index 0000000..058fb85 --- /dev/null +++ b/monai/nnunet/main.py @@ -0,0 +1,692 @@ +import os +import argparse +from datetime import datetime +from loguru import logger +import yaml + +import numpy as np +import wandb +import torch +import pytorch_lightning as pl +import torch.nn.functional as F +import matplotlib.pyplot as plt + +from utils import dice_score, PolyLRScheduler, plot_slices, check_empty_patch +from losses import AdapWingLoss +from transforms import train_transforms, val_transforms +from models import create_nnunet_from_plans + +from monai.utils import set_determinism +from monai.inferers import sliding_window_inference +from monai.networks.nets import UNETR, SwinUNETR +from monai.data import (DataLoader, CacheDataset, load_decathlon_datalist, decollate_batch) +from monai.transforms import (Compose, EnsureType, EnsureTyped, Invertd, SaveImage) + +# mednext +# from nnunet_mednext import MedNeXt + +def get_args(): + parser = argparse.ArgumentParser(description='Script for training contrast-agnositc SC segmentation model.') + + # arguments for model + parser.add_argument('-m', '--model', choices=['nnunet', 'mednext', 'swinunetr'], + default='nnunet', type=str, + help='Model type to be used. Options: nnunet, mednext, swinunetr.') + # path to the config file + parser.add_argument("--config", type=str, default="./config.json", + help="Path to the config file containing all training details.") + # saving + parser.add_argument('--debug', default=False, action='store_true', help='if true, results are not logged to wandb') + parser.add_argument('-c', '--continue_from_checkpoint', default=False, action='store_true', + help='Load model from checkpoint and continue training') + args = parser.parse_args() + + return args + + +# create a "model"-agnostic class with PL to use different models +class Model(pl.LightningModule): + def __init__(self, config, data_root, net, loss_function, optimizer_class, exp_id=None, results_path=None): + super().__init__() + self.cfg = config + self.save_hyperparameters(ignore=['net', 'loss_function']) + + self.root = data_root + self.net = net + self.lr = config["opt"]["lr"] + self.loss_function = loss_function + self.optimizer_class = optimizer_class + self.save_exp_id = exp_id + self.results_path = results_path + + self.best_val_dice, self.best_val_epoch = 0, 0 + self.best_val_loss = float("inf") + + # define cropping and padding dimensions + # NOTE about patch sizes: nnUNet defines patches using the median size of the dataset as the reference + # BUT, for SC images, this means a lot of context outside the spinal cord is included in the patches + # which could be sub-optimal. + # On the other hand, ivadomed used a patch-size that's heavily padded along the R-L direction so that + # only the SC is in context. + self.spacing = config["preprocessing"]["spacing"] + self.voxel_cropping_size = self.inference_roi_size = config["preprocessing"]["crop_pad_size"] + + # define post-processing transforms for validation, nothing fancy just making sure that it's a tensor (default) + self.val_post_pred = self.val_post_label = Compose([EnsureType()]) + + # define evaluation metric + self.soft_dice_metric = dice_score + + # temp lists for storing outputs from training, validation, and testing + self.train_step_outputs = [] + self.val_step_outputs = [] + self.test_step_outputs = [] + + + # -------------------------------- + # FORWARD PASS + # -------------------------------- + def forward(self, x): + + out = self.net(x) + # # NOTE: MONAI's models only output the logits, not the output after the final activation function + # # https://docs.monai.io/en/0.9.0/_modules/monai/networks/nets/unetr.html#UNETR.forward refers to the + # # UnetOutBlock (https://docs.monai.io/en/0.9.0/_modules/monai/networks/blocks/dynunet_block.html#UnetOutBlock) + # # as the final block applied to the input, which is just a convolutional layer with no activation function + # # Hence, we are used Normalized ReLU to normalize the logits to the final output + # normalized_out = F.relu(out) / F.relu(out).max() if bool(F.relu(out).max()) else F.relu(out) + + return out # returns logits + + + # -------------------------------- + # DATA PREPARATION + # -------------------------------- + def prepare_data(self): + # set deterministic training for reproducibility + set_determinism(seed=self.cfg["seed"]) + + # define training and validation transforms + transforms_train = train_transforms( + crop_size=self.voxel_cropping_size, + lbl_key='label' + ) + transforms_val = val_transforms(crop_size=self.inference_roi_size, lbl_key='label') + + # load the dataset + logger.info(f"Training with {self.cfg['dataset']['label_type']} labels ...") + dataset = os.path.join(self.root, + f"dataset_{self.cfg['dataset']['contrast']}_{self.cfg['dataset']['label_type']}_seed{self.cfg['seed']}.json" + ) + logger.info(f"Loading dataset: {dataset}") + train_files = load_decathlon_datalist(dataset, True, "train") + val_files = load_decathlon_datalist(dataset, True, "validation") + test_files = load_decathlon_datalist(dataset, True, "test") + + if args.debug: + train_files = train_files[:15] + val_files = val_files[:15] + test_files = test_files[:6] + + train_cache_rate = 0.25 if args.debug else 0.5 + self.train_ds = CacheDataset(data=train_files, transform=transforms_train, cache_rate=train_cache_rate, num_workers=4) + self.val_ds = CacheDataset(data=val_files, transform=transforms_val, cache_rate=0.25, num_workers=4) + + # define test transforms + transforms_test = val_transforms(crop_size=self.inference_roi_size, lbl_key='label') + + # define post-processing transforms for testing; taken (with explanations) from + # https://github.com/Project-MONAI/tutorials/blob/main/3d_segmentation/torch/unet_inference_dict.py#L66 + self.test_post_pred = Compose([ + EnsureTyped(keys=["pred", "label"]), + Invertd(keys=["pred", "label"], transform=transforms_test, + orig_keys=["image", "label"], + meta_keys=["pred_meta_dict", "label_meta_dict"], + nearest_interp=False, to_tensor=True), + ]) + self.test_ds = CacheDataset(data=test_files, transform=transforms_test, cache_rate=0.1, num_workers=4) + + + # -------------------------------- + # DATA LOADERS + # -------------------------------- + def train_dataloader(self): + return DataLoader(self.train_ds, batch_size=self.cfg["opt"]["batch_size"], shuffle=True, num_workers=16, + pin_memory=True, persistent_workers=True) + + def val_dataloader(self): + return DataLoader(self.val_ds, batch_size=1, shuffle=False, num_workers=16, pin_memory=True, + persistent_workers=True) + + def test_dataloader(self): + return DataLoader(self.test_ds, batch_size=1, shuffle=False, num_workers=8, pin_memory=True) + + + # -------------------------------- + # OPTIMIZATION + # -------------------------------- + def configure_optimizers(self): + if self.cfg["opt"]["name"] == "sgd": + optimizer = self.optimizer_class(self.parameters(), lr=self.lr, momentum=0.99, weight_decay=3e-5, nesterov=True) + else: + optimizer = self.optimizer_class(self.parameters(), lr=self.lr) + # scheduler = PolyLRScheduler(optimizer, self.lr, max_steps=self.args.max_epochs) + # NOTE: ivadomed using CosineAnnealingLR with T_max = 200 + scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=self.cfg["opt"]["max_epochs"]) + return [optimizer], [scheduler] + + + # -------------------------------- + # TRAINING + # -------------------------------- + def training_step(self, batch, batch_idx): + + inputs, labels = batch["image"], batch["label"] + + # check if any label image patch is empty in the batch + if check_empty_patch(labels) is None: + # print(f"Empty label patch found. Skipping training step ...") + return None + + output = self.forward(inputs) # logits + # print(f"labels.shape: {labels.shape} \t output.shape: {output.shape}") + + if args.model in ["nnunet", "mednext"] and self.cfg['model'][args.model]["enable_deep_supervision"]: + + # calculate dice loss for each output + loss, train_soft_dice = 0.0, 0.0 + for i in range(len(output)): + # give each output a weight which decreases exponentially (division by 2) as the resolution decreases + # this gives higher resolution outputs more weight in the loss + # NOTE: outputs[0] is the final pred, outputs[-1] is the lowest resolution pred (at the bottleneck) + # we're downsampling the GT to the resolution of each deepsupervision feature map output + # (instead of upsampling each deepsupervision feature map output to the final resolution) + downsampled_gt = F.interpolate(labels, size=output[i].shape[-3:], mode='trilinear', align_corners=False) + # print(f"downsampled_gt.shape: {downsampled_gt.shape} \t output[i].shape: {output[i].shape}") + loss += (0.5 ** i) * self.loss_function(output[i], downsampled_gt) + + # get probabilities from logits + out = F.relu(output[i]) / F.relu(output[i]).max() if bool(F.relu(output[i]).max()) else F.relu(output[i]) + + # calculate train dice + # NOTE: this is done on patches (and not entire 3D volume) because SlidingWindowInference is not used here + # So, take this dice score with a lot of salt + train_soft_dice += self.soft_dice_metric(out, downsampled_gt) + + # average dice loss across the outputs + loss /= len(output) + train_soft_dice /= len(output) + + else: + # calculate training loss + loss = self.loss_function(output, labels) + + # get probabilities from logits + output = F.relu(output) / F.relu(output).max() if bool(F.relu(output).max()) else F.relu(output) + + # calculate train dice + # NOTE: this is done on patches (and not entire 3D volume) because SlidingWindowInference is not used here + # So, take this dice score with a lot of salt + train_soft_dice = self.soft_dice_metric(output, labels) + + metrics_dict = { + "loss": loss.cpu(), + "train_soft_dice": train_soft_dice.detach().cpu(), + "train_number": len(inputs), + # "train_image": inputs[0].detach().cpu().squeeze(), + # "train_gt": labels[0].detach().cpu().squeeze(), + # "train_pred": output[0].detach().cpu().squeeze() + } + self.train_step_outputs.append(metrics_dict) + + return metrics_dict + + def on_train_epoch_end(self): + + if self.train_step_outputs == []: + # means the training step was skipped because of empty input patch + return None + else: + train_loss, train_soft_dice = 0, 0 + num_items = len(self.train_step_outputs) + for output in self.train_step_outputs: + train_loss += output["loss"].item() + train_soft_dice += output["train_soft_dice"].item() + + mean_train_loss = (train_loss / num_items) + mean_train_soft_dice = (train_soft_dice / num_items) + + wandb_logs = { + "train_soft_dice": mean_train_soft_dice, + "train_loss": mean_train_loss, + } + self.log_dict(wandb_logs) + + # # plot the training images + # fig = plot_slices(image=self.train_step_outputs[0]["train_image"], + # gt=self.train_step_outputs[0]["train_gt"], + # pred=self.train_step_outputs[0]["train_pred"], + # debug=args.debug) + # wandb.log({"training images": wandb.Image(fig)}) + + # free up memory + self.train_step_outputs.clear() + wandb_logs.clear() + # plt.close(fig) + + + # -------------------------------- + # VALIDATION + # -------------------------------- + def validation_step(self, batch, batch_idx): + + inputs, labels = batch["image"], batch["label"] + + # NOTE: this calculates the loss on the entire image after sliding window + outputs = sliding_window_inference(inputs, self.inference_roi_size, mode="gaussian", + sw_batch_size=4, predictor=self.forward, overlap=0.5,) + # outputs shape: (B, C, ) + if args.model in ["nnunet", "mednext"] and self.cfg['model'][args.model]["enable_deep_supervision"]: + # we only need the output with the highest resolution + outputs = outputs[0] + + # calculate validation loss + loss = self.loss_function(outputs, labels) + + # get probabilities from logits + outputs = F.relu(outputs) / F.relu(outputs).max() if bool(F.relu(outputs).max()) else F.relu(outputs) + + # post-process for calculating the evaluation metric + post_outputs = [self.val_post_pred(i) for i in decollate_batch(outputs)] + post_labels = [self.val_post_label(i) for i in decollate_batch(labels)] + val_soft_dice = self.soft_dice_metric(post_outputs[0], post_labels[0]) + + hard_preds, hard_labels = (post_outputs[0].detach() > 0.5).float(), (post_labels[0].detach() > 0.5).float() + val_hard_dice = self.soft_dice_metric(hard_preds, hard_labels) + + # NOTE: there was a massive memory leak when storing cuda tensors in this dict. Hence, + # using .detach() to avoid storing the whole computation graph + # Ref: https://discuss.pytorch.org/t/cuda-memory-leak-while-training/82855/2 + metrics_dict = { + "val_loss": loss.detach().cpu(), + "val_soft_dice": val_soft_dice.detach().cpu(), + "val_hard_dice": val_hard_dice.detach().cpu(), + "val_number": len(post_outputs), + # "val_image": inputs[0].detach().cpu().squeeze(), + # "val_gt": labels[0].detach().cpu().squeeze(), + # "val_pred": post_outputs[0].detach().cpu().squeeze(), + } + self.val_step_outputs.append(metrics_dict) + + return metrics_dict + + def on_validation_epoch_end(self): + + val_loss, num_items, val_soft_dice, val_hard_dice = 0, 0, 0, 0 + for output in self.val_step_outputs: + val_loss += output["val_loss"].sum().item() + val_soft_dice += output["val_soft_dice"].sum().item() + val_hard_dice += output["val_hard_dice"].sum().item() + num_items += output["val_number"] + + mean_val_loss = (val_loss / num_items) + mean_val_soft_dice = (val_soft_dice / num_items) + mean_val_hard_dice = (val_hard_dice / num_items) + + wandb_logs = { + "val_soft_dice": mean_val_soft_dice, + "val_hard_dice": mean_val_hard_dice, + "val_loss": mean_val_loss, + } + # save the best model based on validation dice score + if mean_val_soft_dice > self.best_val_dice: + self.best_val_dice = mean_val_soft_dice + self.best_val_epoch = self.current_epoch + + # save the best model based on validation CSA loss + if mean_val_loss < self.best_val_loss: + self.best_val_loss = mean_val_loss + self.best_val_epoch = self.current_epoch + + logger.info( + f"\nCurrent epoch: {self.current_epoch}" + f"\nAverage Soft Dice (VAL): {mean_val_soft_dice:.4f}" + f"\nAverage Hard Dice (VAL): {mean_val_hard_dice:.4f}" + f"\nAverage AdapWing Loss (VAL): {mean_val_loss:.4f}" + f"\nBest Average AdapWing Loss: {self.best_val_loss:.4f} at Epoch: {self.best_val_epoch}" + f"\n----------------------------------------------------") + + + # log on to wandb + self.log_dict(wandb_logs) + + # # plot the validation images + # fig = plot_slices(image=self.val_step_outputs[0]["val_image"], + # gt=self.val_step_outputs[0]["val_gt"], + # pred=self.val_step_outputs[0]["val_pred"],) + # wandb.log({"validation images": wandb.Image(fig)}) + + # free up memory + self.val_step_outputs.clear() + wandb_logs.clear() + # plt.close(fig) + + # return {"log": wandb_logs} + + # -------------------------------- + # TESTING + # -------------------------------- + def test_step(self, batch, batch_idx): + + test_input = batch["image"] + # print(batch["label_meta_dict"]["filename_or_obj"][0]) + batch["pred"] = sliding_window_inference(test_input, self.inference_roi_size, + sw_batch_size=4, predictor=self.forward, overlap=0.5) + + if args.model in ["nnunet", "mednext"] and self.cfg['model'][args.model]["enable_deep_supervision"]: + # we only need the output with the highest resolution + batch["pred"] = batch["pred"][0] + + # normalize the logits + batch["pred"] = F.relu(batch["pred"]) / F.relu(batch["pred"]).max() if bool(F.relu(batch["pred"]).max()) else F.relu(batch["pred"]) + + post_test_out = [self.test_post_pred(i) for i in decollate_batch(batch)] + + # make sure that the shapes of prediction and GT label are the same + # print(f"pred shape: {post_test_out[0]['pred'].shape}, label shape: {post_test_out[0]['label'].shape}") + assert post_test_out[0]['pred'].shape == post_test_out[0]['label'].shape + + pred, label = post_test_out[0]['pred'].cpu(), post_test_out[0]['label'].cpu() + + # save the prediction and label + if self.cfg["save_test_preds"]: + + subject_name = (batch["image_meta_dict"]["filename_or_obj"][0]).split("/")[-1].replace(".nii.gz", "") + logger.info(f"Saving subject: {subject_name}") + + # image saver class + save_folder = os.path.join(self.results_path, subject_name.split("_")[0]) + pred_saver = SaveImage( + output_dir=save_folder, output_postfix="pred", output_ext=".nii.gz", + separate_folder=False, print_log=False, resample=True) + # save the prediction + pred_saver(pred) + + # label_saver = SaveImage( + # output_dir=save_folder, output_postfix="gt", output_ext=".nii.gz", + # separate_folder=False, print_log=False, resample=True) + # # save the label + # label_saver(label) + + + # NOTE: Important point from the SoftSeg paper - binarize predictions before computing metrics + # calculate soft and hard dice here (for quick overview), other metrics can be computed from + # the saved predictions using ANIMA + # 1. Dice Score + test_soft_dice = self.soft_dice_metric(pred, label) + + # binarizing the predictions + pred = (post_test_out[0]['pred'].detach().cpu() > 0.5).float() + label = (post_test_out[0]['label'].detach().cpu() > 0.5).float() + + # 1.1 Hard Dice Score + test_hard_dice = self.soft_dice_metric(pred.numpy(), label.numpy()) + + metrics_dict = { + "test_hard_dice": test_hard_dice, + "test_soft_dice": test_soft_dice, + } + self.test_step_outputs.append(metrics_dict) + + return metrics_dict + + def on_test_epoch_end(self): + + avg_hard_dice_test, std_hard_dice_test = np.stack([x["test_hard_dice"] for x in self.test_step_outputs]).mean(), \ + np.stack([x["test_hard_dice"] for x in self.test_step_outputs]).std() + avg_soft_dice_test, std_soft_dice_test = np.stack([x["test_soft_dice"] for x in self.test_step_outputs]).mean(), \ + np.stack([x["test_soft_dice"] for x in self.test_step_outputs]).std() + + logger.info(f"Test (Soft) Dice: {avg_soft_dice_test}") + logger.info(f"Test (Hard) Dice: {avg_hard_dice_test}") + + self.avg_test_dice, self.std_test_dice = avg_soft_dice_test, std_soft_dice_test + self.avg_test_dice_hard, self.std_test_dice_hard = avg_hard_dice_test, std_hard_dice_test + + # free up memory + self.test_step_outputs.clear() + + +# -------------------------------- +# MAIN +# -------------------------------- +def main(args): + + # load config file + with open(args.config, "r") as f: + config = yaml.load(f, Loader=yaml.FullLoader) + + # Setting the seed + pl.seed_everything(config["seed"], workers=True) + + # define root path for finding datalists + dataset_root = config["dataset"]["root_dir"] + + # define optimizer + if config["opt"]["name"] == "adam": + optimizer_class = torch.optim.Adam + elif config["opt"]["name"] == "sgd": + optimizer_class = torch.optim.SGD + + + if config["model"]["nnunet"]["enable_deep_supervision"]: + logger.info(f"Using nnUNet model WITH deep supervision ...") + else: + logger.info(f"Using nnUNet model WITHOUT deep supervision ...") + + logger.info("Defining plans for nnUNet model ...") + # ========================================================================================= + # Define plans json taken from nnUNet_preprocessed folder + # ========================================================================================= + nnunet_plans = { + "UNet_class_name": "PlainConvUNet", + "UNet_base_num_features": config["model"]["nnunet"]["base_num_features"], + "n_conv_per_stage_encoder": config["model"]["nnunet"]["n_conv_per_stage_encoder"], + "n_conv_per_stage_decoder": config["model"]["nnunet"]["n_conv_per_stage_decoder"], + "pool_op_kernel_sizes": config["model"]["nnunet"]["pool_op_kernel_sizes"], + "conv_kernel_sizes": [ + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3] + ], + "unet_max_num_features": config["model"]["nnunet"]["max_num_features"], + } + + # define model + net = create_nnunet_from_plans(plans=nnunet_plans, num_input_channels=1, num_classes=1, + deep_supervision=config["model"]["nnunet"]["enable_deep_supervision"]) + # variable for saving patch size in the experiment id (same as crop_pad_size) + patch_size = f"{config['preprocessing']['crop_pad_size'][0]}x" \ + f"{config['preprocessing']['crop_pad_size'][1]}x" \ + f"{config['preprocessing']['crop_pad_size'][2]}" + # save experiment id + save_exp_id = f"{args.model}_seed={config['seed']}_" \ + f"{config['dataset']['contrast']}_{config['dataset']['label_type']}_" \ + f"nf={config['model']['nnunet']['base_num_features']}_" \ + f"opt={config['opt']['name']}_lr={config['opt']['lr']}_AdapW_" \ + f"bs={config['opt']['batch_size']}_{patch_size}" \ + + if args.debug: + save_exp_id = f"DEBUG_{save_exp_id}" + + + timestamp = datetime.now().strftime(f"%Y%m%d-%H%M") # prints in YYYYMMDD-HHMMSS format + save_exp_id = f"{save_exp_id}_{timestamp}" + + # save output to a log file + logger.add(os.path.join(config["directories"]["models_dir"], f"{save_exp_id}", "logs.txt"), rotation="10 MB", level="INFO") + + # save config file to the output folder + with open(os.path.join(config["directories"]["models_dir"], f"{save_exp_id}", "config.yaml"), "w") as f: + yaml.dump(config, f) + + # define loss function + loss_func = AdapWingLoss(theta=0.5, omega=8, alpha=2.1, epsilon=1, reduction="sum") + # NOTE: tried increasing omega and decreasing epsilon but results marginally worse than the above + # loss_func = AdapWingLoss(theta=0.5, omega=12, alpha=2.1, epsilon=0.5, reduction="sum") + logger.info(f"Using AdapWingLoss with theta={loss_func.theta}, omega={loss_func.omega}, alpha={loss_func.alpha}, epsilon={loss_func.epsilon} ...") + + # define callbacks + early_stopping = pl.callbacks.EarlyStopping( + monitor="val_loss", min_delta=0.00, + patience=config["opt"]["early_stopping_patience"], + verbose=False, mode="min") + + lr_monitor = pl.callbacks.LearningRateMonitor(logging_interval='epoch') + + # training from scratch + if not args.continue_from_checkpoint: + # to save the best model on validation + save_path = os.path.join(config["directories"]["models_dir"], f"{save_exp_id}") + if not os.path.exists(save_path): + os.makedirs(save_path, exist_ok=True) + + # to save the results/model predictions + results_path = os.path.join(config["directories"]["results_dir"], f"{save_exp_id}") + if not os.path.exists(results_path): + os.makedirs(results_path, exist_ok=True) + + # i.e. train by loading weights from scratch + pl_model = Model(config, data_root=dataset_root, + optimizer_class=optimizer_class, loss_function=loss_func, net=net, + exp_id=save_exp_id, results_path=results_path) + + # saving the best model based on validation loss + logger.info(f"Saving best model to {save_path}!") + checkpoint_callback_loss = pl.callbacks.ModelCheckpoint( + dirpath=save_path, filename='best_model', monitor='val_loss', + save_top_k=1, mode="min", save_last=True, save_weights_only=False) + + # # saving the best model based on soft validation dice score + # checkpoint_callback_dice = pl.callbacks.ModelCheckpoint( + # dirpath=save_path, filename='best_model_dice', monitor='val_soft_dice', + # save_top_k=1, mode="max", save_last=False, save_weights_only=True) + + logger.info(f"Starting training from scratch ...") + # wandb logger + exp_logger = pl.loggers.WandbLogger( + name=save_exp_id, + save_dir="/home/GRAMES.POLYMTL.CA/u114716/contrast-agnostic/saved_models", + group=config["dataset"]["name"], + log_model=True, # save best model using checkpoint callback + project='contrast-agnostic', + entity='naga-karthik', + config=config) + + # Saving training script to wandb + wandb.save("main.py") + wandb.save("transforms.py") + + # initialise Lightning's trainer. + trainer = pl.Trainer( + devices=1, accelerator="gpu", + logger=exp_logger, + callbacks=[checkpoint_callback_loss, lr_monitor, early_stopping], + check_val_every_n_epoch=config["opt"]["check_val_every_n_epochs"], + max_epochs=config["opt"]["max_epochs"], + precision=32, + # deterministic=True, + enable_progress_bar=False) + # profiler="simple",) # to profile the training time taken for each step + + # Train! + trainer.fit(pl_model) + logger.info(f" Training Done!") + + else: + logger.info(f" Resuming training from the latest checkpoint! ") + + # check if wandb run folder is provided to resume using the same run + if config["directories"]["wandb_run_folder"] is None: + raise ValueError("Please provide the wandb run folder to resume training using the same run on WandB!") + else: + wandb_run_folder = os.path.basename(config["directories"]["wandb_run_folder"]) + wandb_run_id = wandb_run_folder.split("-")[-1] + + save_exp_id = config["directories"]["models_dir"] + save_path = os.path.dirname(config["directories"]["models_dir"]) + logger.info(f"save_path: {save_path}") + results_path = config["directories"]["results_dir"] + + # i.e. train by loading existing weights + pl_model = Model(config, data_root=dataset_root, + optimizer_class=optimizer_class, loss_function=loss_func, net=net, + exp_id=save_exp_id, results_path=results_path) + + # saving the best model based on validation CSA loss + checkpoint_callback_loss = pl.callbacks.ModelCheckpoint( + dirpath=save_exp_id, filename='best_model', monitor='val_loss', + save_top_k=1, mode="min", save_last=True, save_weights_only=True) + + # # saving the best model based on soft validation dice score + # checkpoint_callback_dice = pl.callbacks.ModelCheckpoint( + # dirpath=save_exp_id, filename='best_model_dice', monitor='val_soft_dice', + # save_top_k=1, mode="max", save_last=False, save_weights_only=True) + + # wandb logger + exp_logger = pl.loggers.WandbLogger( + save_dir=save_path, + group=config["dataset"]["name"], + log_model=True, # save best model using checkpoint callback + project='contrast-agnostic', + entity='naga-karthik', + config=args, + id=wandb_run_id, resume='must') + + # initialise Lightning's trainer. + trainer = pl.Trainer( + devices=1, accelerator="gpu", + logger=exp_logger, + callbacks=[checkpoint_callback_loss, lr_monitor, early_stopping], + check_val_every_n_epoch=config["opt"]["check_val_every_n_epochs"], + max_epochs=config["opt"]["max_epochs"], + precision=32, + enable_progress_bar=True) + # profiler="simple",) # to profile the training time taken for each step + + # Train! + trainer.fit(pl_model, ckpt_path=os.path.join(save_exp_id, "last.ckpt"),) + logger.info(f" Training Done!") + + # Test! + trainer.test(pl_model) + logger.info(f"TESTING DONE!") + + # closing the current wandb instance so that a new one is created for the next fold + wandb.finish() + + # TODO: Figure out saving test metrics to a file + with open(os.path.join(results_path, 'test_metrics.txt'), 'a') as f: + print('\n-------------- Test Metrics ----------------', file=f) + print(f"{args.model}_seed={config['seed']}_" \ + f"{config['dataset']['contrast']}_{config['dataset']['label_type']}_" \ + f"nf={config['model']['nnunet']['base_num_features']}_" \ + f"opt={config['opt']['name']}_lr={config['opt']['lr']}_AdapW_" \ + f"bs={config['opt']['batch_size']}_{patch_size}" \ + f"_{timestamp}", file=f) + + print('\n-------------- Test Hard Dice Scores ----------------', file=f) + print("Hard Dice --> Mean: %0.3f, Std: %0.3f" % (pl_model.avg_test_dice_hard, pl_model.std_test_dice_hard), file=f) + + print('\n-------------- Test Soft Dice Scores ----------------', file=f) + print("Soft Dice --> Mean: %0.3f, Std: %0.3f" % (pl_model.avg_test_dice, pl_model.std_test_dice), file=f) + + print('-------------------------------------------------------', file=f) + + +if __name__ == "__main__": + args = get_args() + main(args) \ No newline at end of file From 5aa2eaeb09e5f9ab090596a1c3abfc94a84d89f0 Mon Sep 17 00:00:00 2001 From: Pierre-Louis Benveniste Date: Tue, 12 Mar 2024 15:39:23 -0400 Subject: [PATCH 008/108] initialised models file --- monai/nnunet/models.py | 90 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 90 insertions(+) create mode 100644 monai/nnunet/models.py diff --git a/monai/nnunet/models.py b/monai/nnunet/models.py new file mode 100644 index 0000000..55f8a91 --- /dev/null +++ b/monai/nnunet/models.py @@ -0,0 +1,90 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +# ---------------------------- Imports for nnUNet's Model ----------------------------- +from dynamic_network_architectures.architectures.unet import PlainConvUNet, ResidualEncoderUNet +from dynamic_network_architectures.building_blocks.helper import get_matching_instancenorm, convert_dim_to_conv_op +from dynamic_network_architectures.initialization.weight_init import init_last_bn_before_add_to_0 + + +# ====================================================================================================== +# Utils for nnUNet's Model +# ==================================================================================================== +class InitWeights_He(object): + def __init__(self, neg_slope=1e-2): + self.neg_slope = neg_slope + + def __call__(self, module): + if isinstance(module, nn.Conv3d) or isinstance(module, nn.ConvTranspose3d): + module.weight = nn.init.kaiming_normal_(module.weight, a=self.neg_slope) + if module.bias is not None: + module.bias = nn.init.constant_(module.bias, 0) + + +# ====================================================================================================== +# Define the network based on plans json +# ==================================================================================================== +def create_nnunet_from_plans(plans, num_input_channels: int, num_classes: int, deep_supervision: bool = True): + """ + Adapted from nnUNet's source code: + https://github.com/MIC-DKFZ/nnUNet/blob/master/nnunetv2/utilities/get_network_from_plans.py#L9 + + """ + num_stages = len(plans["conv_kernel_sizes"]) + + dim = len(plans["conv_kernel_sizes"][0]) + conv_op = convert_dim_to_conv_op(dim) + + segmentation_network_class_name = plans["UNet_class_name"] + mapping = { + 'PlainConvUNet': PlainConvUNet, + 'ResidualEncoderUNet': ResidualEncoderUNet + } + kwargs = { + 'PlainConvUNet': { + 'conv_bias': True, + 'norm_op': get_matching_instancenorm(conv_op), + 'norm_op_kwargs': {'eps': 1e-5, 'affine': True}, + 'dropout_op': None, 'dropout_op_kwargs': None, + 'nonlin': nn.LeakyReLU, 'nonlin_kwargs': {'inplace': True}, + }, + 'ResidualEncoderUNet': { + 'conv_bias': True, + 'norm_op': get_matching_instancenorm(conv_op), + 'norm_op_kwargs': {'eps': 1e-5, 'affine': True}, + 'dropout_op': None, 'dropout_op_kwargs': None, + 'nonlin': nn.LeakyReLU, 'nonlin_kwargs': {'inplace': True}, + } + } + assert segmentation_network_class_name in mapping.keys(), 'The network architecture specified by the plans file ' \ + 'is non-standard (maybe your own?). Yo\'ll have to dive ' \ + 'into either this ' \ + 'function (get_network_from_plans) or ' \ + 'the init of your nnUNetModule to accomodate that.' + network_class = mapping[segmentation_network_class_name] + + conv_or_blocks_per_stage = { + 'n_conv_per_stage' + if network_class != ResidualEncoderUNet else 'n_blocks_per_stage': plans["n_conv_per_stage_encoder"], + 'n_conv_per_stage_decoder': plans["n_conv_per_stage_decoder"] + } + + # network class name!! + model = network_class( + input_channels=num_input_channels, + n_stages=num_stages, + features_per_stage=[min(plans["UNet_base_num_features"] * 2 ** i, + plans["unet_max_num_features"]) for i in range(num_stages)], + conv_op=conv_op, + kernel_sizes=plans["conv_kernel_sizes"], + strides=plans["pool_op_kernel_sizes"], + num_classes=num_classes, + deep_supervision=deep_supervision, + **conv_or_blocks_per_stage, + **kwargs[segmentation_network_class_name] + ) + model.apply(InitWeights_He(1e-2)) + if network_class == ResidualEncoderUNet: + model.apply(init_last_bn_before_add_to_0) + + return model \ No newline at end of file From 5b233d2ec5205cf1e891e0a4b55aea517aab767c Mon Sep 17 00:00:00 2001 From: Pierre-Louis Benveniste Date: Tue, 12 Mar 2024 15:39:46 -0400 Subject: [PATCH 009/108] initalised transforms file --- monai/nnunet/transforms.py | 59 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 59 insertions(+) create mode 100644 monai/nnunet/transforms.py diff --git a/monai/nnunet/transforms.py b/monai/nnunet/transforms.py new file mode 100644 index 0000000..7f89cc1 --- /dev/null +++ b/monai/nnunet/transforms.py @@ -0,0 +1,59 @@ + +import numpy as np +from monai.transforms import (Compose, CropForegroundd, LoadImaged, RandFlipd, + Spacingd, RandScaleIntensityd, NormalizeIntensityd, RandAffined, + DivisiblePadd, RandAdjustContrastd, EnsureChannelFirstd, RandGaussianNoised, + RandGaussianSmoothd, Orientationd, Rand3DElasticd, RandBiasFieldd, + RandSimulateLowResolutiond, ResizeWithPadOrCropd) + + +def train_transforms(crop_size, lbl_key="label"): + + monai_transforms = [ + # pre-processing + LoadImaged(keys=["image", lbl_key]), + EnsureChannelFirstd(keys=["image", lbl_key]), + # NOTE: spine interpolation with order=2 is spline, order=1 is linear + Spacingd(keys=["image", "label"], pixdim=(1.0, 1.0, 1.0), mode=(2, 1)), + ResizeWithPadOrCropd(keys=["image", lbl_key], spatial_size=crop_size,), + # data-augmentation + RandAffined(keys=["image", lbl_key], mode=(2, 1), prob=0.9, + rotate_range=(-20. / 360 * 2. * np.pi, 20. / 360 * 2. * np.pi), # monai expects in radians + scale_range=(-0.2, 0.2), + translate_range=(-0.1, 0.1)), + Rand3DElasticd(keys=["image", lbl_key], prob=0.5, + sigma_range=(3.5, 5.5), + magnitude_range=(25., 35.)), + RandSimulateLowResolutiond(keys=["image"], zoom_range=(0.5, 1.0), prob=0.25), + RandAdjustContrastd(keys=["image"], gamma=(0.5, 3.), prob=0.5), # this is monai's RandomGamma + RandBiasFieldd(keys=["image"], coeff_range=(0.0, 0.5), degree=3, prob=0.3), + RandGaussianNoised(keys=["image"], mean=0.0, std=0.1, prob=0.1), + RandGaussianSmoothd(keys=["image"], sigma_x=(0., 2.), sigma_y=(0., 2.), sigma_z=(0., 2.0), prob=0.3), + RandScaleIntensityd(keys=["image"], factors=(-0.25, 1), prob=0.15), # this is nnUNet's BrightnessMultiplicativeTransform + RandFlipd(keys=["image", lbl_key], prob=0.3,), + NormalizeIntensityd(keys=["image"], nonzero=False, channel_wise=False), + ] + + return Compose(monai_transforms) + +def inference_transforms(crop_size, lbl_key="label"): + return Compose([ + LoadImaged(keys=["image", lbl_key], image_only=False), + EnsureChannelFirstd(keys=["image", lbl_key]), + # CropForegroundd(keys=["image", lbl_key], source_key="image"), + Orientationd(keys=["image", lbl_key], axcodes="RPI"), + Spacingd(keys=["image", lbl_key], pixdim=(1.0, 1.0, 1.0), mode=(2, 1)), # mode=("bilinear", "bilinear"),), + ResizeWithPadOrCropd(keys=["image", lbl_key], spatial_size=crop_size,), + DivisiblePadd(keys=["image", lbl_key], k=2**5), # pad inputs to ensure divisibility by no. of layers nnUNet has (5) + NormalizeIntensityd(keys=["image"], nonzero=False, channel_wise=False), + ]) + +def val_transforms(crop_size, lbl_key="label"): + return Compose([ + LoadImaged(keys=["image", lbl_key], image_only=False), + EnsureChannelFirstd(keys=["image", lbl_key]), + # CropForegroundd(keys=["image", lbl_key], source_key="image"), + Spacingd(keys=["image", lbl_key], pixdim=(1.0, 1.0, 1.0), mode=(2, 1)), # mode=("bilinear", "bilinear"),), + ResizeWithPadOrCropd(keys=["image", lbl_key], spatial_size=crop_size,), + NormalizeIntensityd(keys=["image"], nonzero=False, channel_wise=False), + ]) \ No newline at end of file From 794430a20d5512680eea8005fb0ecc8b8b0b4ad3 Mon Sep 17 00:00:00 2001 From: Pierre-Louis Benveniste Date: Tue, 12 Mar 2024 16:09:53 -0400 Subject: [PATCH 010/108] modified to fit our training parameters --- monai/nnunet/config.yml | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/monai/nnunet/config.yml b/monai/nnunet/config.yml index 4e4a488..7ee1dd4 100644 --- a/monai/nnunet/config.yml +++ b/monai/nnunet/config.yml @@ -3,18 +3,18 @@ save_test_preds: True directories: # Path to the saved models directory - models_dir: /home/GRAMES.POLYMTL.CA/u114716/contrast-agnostic/saved_models/followup + models_dir: /home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/models # Path to the saved results directory - results_dir: /home/GRAMES.POLYMTL.CA/u114716/contrast-agnostic/results/models_followup + results_dir: /home/GRAMES.POLYMTL.CAp119007/ms_lesion_agnostic/results/models_followup # Path to the saved wandb logs directory # if None, starts training from scratch. Otherwise, resumes training from the specified wandb run folder wandb_run_folder: None dataset: # Dataset name (will be used as "group_name" for wandb logging) - name: spine-generic + name: ms-lesion-agnostic # Path to the dataset directory containing all datalists (.json files) - root_dir: /home/GRAMES.POLYMTL.CA/u114716/contrast-agnostic/datalists/spine-generic/seed15 + root_dir: /home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/data/monai_data/dataset_2024-03-12_seed42.json # Type of contrast to be used for training. "all" corresponds to training on all contrasts contrast: all # choices: ["t1w", "t2w", "t2star", "mton", "mtoff", "dwi", "all"] # Type of label to be used for training. @@ -51,7 +51,7 @@ model: [2, 2, 2], [2, 2, 2], [2, 2, 2], - [2, 2, 2], + [1, 2, 2], [1, 2, 2] ] enable_deep_supervision: True From 06199803f2174fd5e592dd2a8e19d9a108147730 Mon Sep 17 00:00:00 2001 From: Pierre-Louis Benveniste Date: Tue, 12 Mar 2024 16:10:29 -0400 Subject: [PATCH 011/108] simplied main training script --- monai/nnunet/main.py | 174 ++++++++++++++----------------------------- 1 file changed, 56 insertions(+), 118 deletions(-) diff --git a/monai/nnunet/main.py b/monai/nnunet/main.py index 058fb85..1d0f163 100644 --- a/monai/nnunet/main.py +++ b/monai/nnunet/main.py @@ -475,9 +475,6 @@ def main(args): # define optimizer if config["opt"]["name"] == "adam": optimizer_class = torch.optim.Adam - elif config["opt"]["name"] == "sgd": - optimizer_class = torch.optim.SGD - if config["model"]["nnunet"]["enable_deep_supervision"]: logger.info(f"Using nnUNet model WITH deep supervision ...") @@ -518,9 +515,6 @@ def main(args): f"nf={config['model']['nnunet']['base_num_features']}_" \ f"opt={config['opt']['name']}_lr={config['opt']['lr']}_AdapW_" \ f"bs={config['opt']['batch_size']}_{patch_size}" \ - - if args.debug: - save_exp_id = f"DEBUG_{save_exp_id}" timestamp = datetime.now().strftime(f"%Y%m%d-%H%M") # prints in YYYYMMDD-HHMMSS format @@ -548,118 +542,62 @@ def main(args): lr_monitor = pl.callbacks.LearningRateMonitor(logging_interval='epoch') # training from scratch - if not args.continue_from_checkpoint: - # to save the best model on validation - save_path = os.path.join(config["directories"]["models_dir"], f"{save_exp_id}") - if not os.path.exists(save_path): - os.makedirs(save_path, exist_ok=True) - - # to save the results/model predictions - results_path = os.path.join(config["directories"]["results_dir"], f"{save_exp_id}") - if not os.path.exists(results_path): - os.makedirs(results_path, exist_ok=True) - - # i.e. train by loading weights from scratch - pl_model = Model(config, data_root=dataset_root, - optimizer_class=optimizer_class, loss_function=loss_func, net=net, - exp_id=save_exp_id, results_path=results_path) - - # saving the best model based on validation loss - logger.info(f"Saving best model to {save_path}!") - checkpoint_callback_loss = pl.callbacks.ModelCheckpoint( - dirpath=save_path, filename='best_model', monitor='val_loss', - save_top_k=1, mode="min", save_last=True, save_weights_only=False) - - # # saving the best model based on soft validation dice score - # checkpoint_callback_dice = pl.callbacks.ModelCheckpoint( - # dirpath=save_path, filename='best_model_dice', monitor='val_soft_dice', - # save_top_k=1, mode="max", save_last=False, save_weights_only=True) - - logger.info(f"Starting training from scratch ...") - # wandb logger - exp_logger = pl.loggers.WandbLogger( - name=save_exp_id, - save_dir="/home/GRAMES.POLYMTL.CA/u114716/contrast-agnostic/saved_models", - group=config["dataset"]["name"], - log_model=True, # save best model using checkpoint callback - project='contrast-agnostic', - entity='naga-karthik', - config=config) - - # Saving training script to wandb - wandb.save("main.py") - wandb.save("transforms.py") - - # initialise Lightning's trainer. - trainer = pl.Trainer( - devices=1, accelerator="gpu", - logger=exp_logger, - callbacks=[checkpoint_callback_loss, lr_monitor, early_stopping], - check_val_every_n_epoch=config["opt"]["check_val_every_n_epochs"], - max_epochs=config["opt"]["max_epochs"], - precision=32, - # deterministic=True, - enable_progress_bar=False) - # profiler="simple",) # to profile the training time taken for each step - - # Train! - trainer.fit(pl_model) - logger.info(f" Training Done!") - - else: - logger.info(f" Resuming training from the latest checkpoint! ") - - # check if wandb run folder is provided to resume using the same run - if config["directories"]["wandb_run_folder"] is None: - raise ValueError("Please provide the wandb run folder to resume training using the same run on WandB!") - else: - wandb_run_folder = os.path.basename(config["directories"]["wandb_run_folder"]) - wandb_run_id = wandb_run_folder.split("-")[-1] - - save_exp_id = config["directories"]["models_dir"] - save_path = os.path.dirname(config["directories"]["models_dir"]) - logger.info(f"save_path: {save_path}") - results_path = config["directories"]["results_dir"] - - # i.e. train by loading existing weights - pl_model = Model(config, data_root=dataset_root, - optimizer_class=optimizer_class, loss_function=loss_func, net=net, - exp_id=save_exp_id, results_path=results_path) - - # saving the best model based on validation CSA loss - checkpoint_callback_loss = pl.callbacks.ModelCheckpoint( - dirpath=save_exp_id, filename='best_model', monitor='val_loss', - save_top_k=1, mode="min", save_last=True, save_weights_only=True) - - # # saving the best model based on soft validation dice score - # checkpoint_callback_dice = pl.callbacks.ModelCheckpoint( - # dirpath=save_exp_id, filename='best_model_dice', monitor='val_soft_dice', - # save_top_k=1, mode="max", save_last=False, save_weights_only=True) - - # wandb logger - exp_logger = pl.loggers.WandbLogger( - save_dir=save_path, - group=config["dataset"]["name"], - log_model=True, # save best model using checkpoint callback - project='contrast-agnostic', - entity='naga-karthik', - config=args, - id=wandb_run_id, resume='must') - - # initialise Lightning's trainer. - trainer = pl.Trainer( - devices=1, accelerator="gpu", - logger=exp_logger, - callbacks=[checkpoint_callback_loss, lr_monitor, early_stopping], - check_val_every_n_epoch=config["opt"]["check_val_every_n_epochs"], - max_epochs=config["opt"]["max_epochs"], - precision=32, - enable_progress_bar=True) - # profiler="simple",) # to profile the training time taken for each step - - # Train! - trainer.fit(pl_model, ckpt_path=os.path.join(save_exp_id, "last.ckpt"),) - logger.info(f" Training Done!") + # to save the best model on validation + save_path = os.path.join(config["directories"]["models_dir"], f"{save_exp_id}") + if not os.path.exists(save_path): + os.makedirs(save_path, exist_ok=True) + + # to save the results/model predictions + results_path = os.path.join(config["directories"]["results_dir"], f"{save_exp_id}") + if not os.path.exists(results_path): + os.makedirs(results_path, exist_ok=True) + + # i.e. train by loading weights from scratch + pl_model = Model(config, data_root=dataset_root, + optimizer_class=optimizer_class, loss_function=loss_func, net=net, + exp_id=save_exp_id, results_path=results_path) + + # saving the best model based on validation loss + logger.info(f"Saving best model to {save_path}!") + checkpoint_callback_loss = pl.callbacks.ModelCheckpoint( + dirpath=save_path, filename='best_model', monitor='val_loss', + save_top_k=1, mode="min", save_last=True, save_weights_only=False) + + # # saving the best model based on soft validation dice score + # checkpoint_callback_dice = pl.callbacks.ModelCheckpoint( + # dirpath=save_path, filename='best_model_dice', monitor='val_soft_dice', + # save_top_k=1, mode="max", save_last=False, save_weights_only=True) + + logger.info(f"Starting training from scratch ...") + # wandb logger + exp_logger = pl.loggers.WandbLogger( + name=save_exp_id, + save_dir="/home/GRAMES.POLYMTL.CA/u114716/contrast-agnostic/saved_models", + group=config["dataset"]["name"], + log_model=True, # save best model using checkpoint callback + project='contrast-agnostic', + entity='naga-karthik', + config=config) + + # Saving training script to wandb + wandb.save("main.py") + wandb.save("transforms.py") + + # initialise Lightning's trainer. + trainer = pl.Trainer( + devices=1, accelerator="gpu", + logger=exp_logger, + callbacks=[checkpoint_callback_loss, lr_monitor, early_stopping], + check_val_every_n_epoch=config["opt"]["check_val_every_n_epochs"], + max_epochs=config["opt"]["max_epochs"], + precision=32, + # deterministic=True, + enable_progress_bar=False) + # profiler="simple",) # to profile the training time taken for each step + + # Train! + trainer.fit(pl_model) + logger.info(f" Training Done!") # Test! trainer.test(pl_model) From e04c423aceee1f6a220667fb8b9688f3ba7b8010 Mon Sep 17 00:00:00 2001 From: Pierre-Louis Benveniste Date: Wed, 13 Mar 2024 17:25:13 -0400 Subject: [PATCH 012/108] fixed canproco problem (having both img and label with same link) --- monai/nnunet/1_create_msd_data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/nnunet/1_create_msd_data.py b/monai/nnunet/1_create_msd_data.py index 13077ff..451f67b 100644 --- a/monai/nnunet/1_create_msd_data.py +++ b/monai/nnunet/1_create_msd_data.py @@ -48,7 +48,7 @@ bavaria_path = Path(os.path.join(root, "bavaria-quebec-spine-ms")) sct_testing_path = Path(os.path.join(root, "sct-testing-large")) -subjects_canproco = list(canproco_path.rglob('*_PSIR.nii.gz')) + list(canproco_path.rglob('*STIR.nii.gz')) +subjects_canproco = list(canproco_path.rglob('*_lesion-manual.nii.gz')) subjects_basel = list(basel_path.rglob('*UNIT1.nii.gz')) subjects_sct = list(sct_testing_path.rglob('*_lesion-manual.nii.gz')) subjects_bavaria = list(bavaria_path.rglob('*T2w.nii.gz')) From 26629c8388a965e0d393d0b48a4eee34b126d461 Mon Sep 17 00:00:00 2001 From: Pierre-Louis Benveniste Date: Wed, 13 Mar 2024 17:26:11 -0400 Subject: [PATCH 013/108] added training script for a monai trained UNETR --- monai/nnunet/train_monai_UNETR.py | 316 ++++++++++++++++++++++++++++++ 1 file changed, 316 insertions(+) create mode 100644 monai/nnunet/train_monai_UNETR.py diff --git a/monai/nnunet/train_monai_UNETR.py b/monai/nnunet/train_monai_UNETR.py new file mode 100644 index 0000000..ace0382 --- /dev/null +++ b/monai/nnunet/train_monai_UNETR.py @@ -0,0 +1,316 @@ +""" +This script is used to train a UNETR model. + +It takes as input the config file (a JSON file) + +This script is inspired from : https://github.com/Project-MONAI/tutorials/blob/main/3d_segmentation/unetr_btcv_segmentation_3d.ipynb + +Args: + -c: path to the config file + +Example: + python train_monai_nnunet.py -c /path/to/nnunet/config.json + +Pierre-Louis Benveniste +""" + +import argparse +import json +import os +import sys +import monai +from tqdm import tqdm +import matplotlib.pyplot as plt +import yaml + + +#Transforms import +from monai.transforms import ( + EnsureChannelFirstd, + Compose, + LoadImaged, + Orientationd, + RandFlipd, + RandShiftIntensityd, + Spacingd, + RandRotate90d, + NormalizeIntensityd, + RandCropByPosNegLabeld, + BatchInverseTransform, + RandAdjustContrastd, + AsDiscreted, + RandHistogramShiftd + ) + +# Dataset import +from monai.data import DataLoader, CacheDataset, load_decathlon_datalist, Dataset + + +# model import +import torch +from monai.networks.nets import UNETR +from monai.losses import DiceCELoss + +# For training and validation +from monai.data import decollate_batch +from monai.inferers import sliding_window_inference +from monai.metrics import DiceMetric +from monai.transforms import AsDiscrete + + + + +def get_parser(): + """ + This function returns the parser for the command line arguments. + """ + parser = argparse.ArgumentParser(description="Train a nnUNet model using monai") + parser.add_argument("-c", "--config", help="Path to the config file (.yml file)", required=True) + return parser + + +def validation(model, epoch_iterator_val, config, post_label, post_pred, dice_metric, global_step): + model.eval() + with torch.no_grad(): + for batch in epoch_iterator_val: + val_inputs, val_labels = (batch["image"].cuda(), batch["label"].cuda()) + val_outputs = sliding_window_inference(val_inputs, config["spatial_size"], 4, model) + val_labels_list = decollate_batch(val_labels) + val_labels_convert = [post_label(val_label_tensor) for val_label_tensor in val_labels_list] + val_outputs_list = decollate_batch(val_outputs) + val_output_convert = [post_pred(val_pred_tensor) for val_pred_tensor in val_outputs_list] + dice_metric(y_pred=val_output_convert, y=val_labels_convert) + epoch_iterator_val.set_description("Validate (%d / %d Steps)" % (global_step, 10.0)) # noqa: B038 + mean_dice_val = dice_metric.aggregate().item() + dice_metric.reset() + return mean_dice_val + +def train(model, config, global_step, train_loader, dice_val_best, global_step_best, loss_function, optimizer, epoch_loss_values, metric_values, val_loader): + model.train() + epoch_loss = 0 + step = 0 + epoch_iterator = tqdm(train_loader, desc="Training (X / X Steps) (loss=X.X)", dynamic_ncols=True) + for step, batch in enumerate(epoch_iterator): + step += 1 + x, y = (batch["image"].cuda(), batch["label"].cuda()) + logit_map = model(x) + print(logit_map) + loss = loss_function(logit_map, y) + loss.backward() + epoch_loss += loss.item() + optimizer.step() + optimizer.zero_grad() + epoch_iterator.set_description( # noqa: B038 + "Training (%d / %d Steps) (loss=%2.5f)" % (global_step, config["max_iterations"], loss) + ) + if (global_step % config["eval_num"] == 0 and global_step != 0) or global_step == config["max_iterations"]: + epoch_iterator_val = tqdm(val_loader, desc="Validate (X / X Steps) (dice=X.X)", dynamic_ncols=True) + dice_val = validation(epoch_iterator_val) + epoch_loss /= step + epoch_loss_values.append(epoch_loss) + metric_values.append(dice_val) + if dice_val > dice_val_best: + dice_val_best = dice_val + global_step_best = global_step + torch.save(model.state_dict(), config["best_model_path"]) + print( + "Model Was Saved ! Current Best Avg. Dice: {} Current Avg. Dice: {}".format(dice_val_best, dice_val) + ) + else: + print( + "Model Was Not Saved ! Current Best Avg. Dice: {} Current Avg. Dice: {}".format( + dice_val_best, dice_val + ) + ) + global_step += 1 + return global_step, dice_val_best, global_step_best + + +def main(): + """ + Main function of the script. + """ + + # We get the parser and parse the arguments + parser = get_parser() + args = parser.parse_args() + + # We load the config file (a yml file) + # load config file + with open(args.config, "r") as f: + config = yaml.load(f, Loader=yaml.FullLoader) + + + ##### ------------------ + # Monai should be installed with pip install monai[all] (to get all readers) + # We define the trasnformations for training and validation + train_transforms = Compose( + [ + LoadImaged(keys=["image", "label"], reader="NibabelReader"), + EnsureChannelFirstd(keys=["image", "label"]), + Orientationd(keys=["image", "label"], axcodes="RSP"), + Spacingd( + keys=["image", "label"], + pixdim=config["pixdim"], + mode=("bilinear", "nearest"), + ), + RandCropByPosNegLabeld( + keys=["image", "label"], + label_key="label", + spatial_size=config["spatial_size"], + pos=1, + neg=1, + num_samples=4, + image_key="image", + image_threshold=0, + ), + # Flips the image : left becomes right + RandFlipd( + keys=["image", "label"], + spatial_axis=[0], + prob=0.2, + ), + # Flips the image : supperior becomes inferior + RandFlipd( + keys=["image", "label"], + spatial_axis=[1], + prob=0.2, + ), + # Flips the image : anterior becomes posterior + RandFlipd( + keys=["image", "label"], + spatial_axis=[2], + prob=0.2, + ), + RandAdjustContrastd( + keys=["image"], + prob=0.2, + gamma=(0.5, 4.5), + invert_image=True, + ), + NormalizeIntensityd( + keys=["image", "label"], + nonzero=False, + channel_wise=False + ), + AsDiscreted( + keys=["label"], + num_classes=2, + threshold_values=True, + logit_thresh=0.2, + ) + ] + ) + val_transforms = Compose( + [ + LoadImaged(keys=["image", "label"]), + EnsureChannelFirstd(keys=["image", "label"]), + Orientationd(keys=["image", "label"], axcodes="RSP"), + Spacingd( + keys=["image", "label"], + pixdim=(1., 1., 1.0), + mode=("bilinear", "nearest"), + ), + NormalizeIntensityd( + keys=["image", "label"], + nonzero=False, + channel_wise=False + ), + AsDiscreted( + keys=["label"], + num_classes=2, + threshold_values=True, + logit_thresh=0.2, + ) + ] + ) + + # Path to data split (JSON file) + data_split_json_path = config["data"] + # We load the data lists + with open(data_split_json_path, "r") as f: + data = json.load(f) + train_list = data["train"] + val_list = data["validation"] + + # Path to the output directory + output_dir = config["output_dir"] + if not os.path.exists(output_dir): + os.makedirs(output_dir) + + # We load the train and validation data + print("Loading the training and validation data...") + # train_files = load_decathlon_datalist(data, True, "train") + train_ds = CacheDataset( + data=train_list, + transform=train_transforms, + cache_rate=0.25, + num_workers=4 + ) + train_loader = DataLoader(train_ds, batch_size=1, shuffle=True, num_workers=8, pin_memory=True, persistent_workers=True) + # val_files = load_decathlon_datalist(data, True, "validation") + val_ds = CacheDataset( + data=val_list, + transform=val_transforms, + cache_rate=0.25, + num_workers=4, + ) + val_loader = DataLoader(val_ds, batch_size=1, shuffle=False, num_workers=8, pin_memory=True) + + # plot 3 image and save them + fig, axes = plt.subplots(1, 3, figsize=(15, 5)) + for i, ax in enumerate(axes): + img = train_ds[i][0]['image'] + ax.imshow(img[0, 7, :, :], cmap="gray") + ax.set_title(f"Image {i+1}") + ax.axis('on') + plt.savefig(os.path.join(output_dir, "image.png")) + + + print("Preparing the UNETR model...") + # we define the device to use + device = torch.device("cuda") + + model = UNETR( + in_channels=1, + out_channels=2, + img_size=config["spatial_size"], + feature_size=config["feature_size"], + hidden_size=config["hidden_size"], + mlp_dim=config["mlp_dim"], + num_heads=config["num_heads"], + proj_type="perceptron", + norm_name="instance", + res_block=True, + dropout_rate=0.0, + ).to(device) + + loss_function = DiceCELoss(to_onehot_y=True, softmax=True) + torch.backends.cudnn.benchmark = True + optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-5) + + # We then train the model + post_label = AsDiscrete(to_onehot=2) + post_pred = AsDiscrete(argmax=True, to_onehot=2) + dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False) + global_step = 0 + dice_val_best = 0.0 + global_step_best = 0 + epoch_loss_values = [] + metric_values = [] + while global_step < config["max_iterations"]: + global_step, dice_val_best, global_step_best = train(model, config, global_step, train_loader, dice_val_best, global_step_best, loss_function, optimizer, epoch_loss_values, metric_values, val_loader) + model.load_state_dict(torch.load(config["best_model_path"])) + + print(f"train completed, best_metric: {dice_val_best:.4f} " f"at iteration: {global_step_best}") + + +if __name__ == "__main__": + main() + + + + + + + From e815e14a8ab5de06bd4b5b3a77f1377aa17e4ab4 Mon Sep 17 00:00:00 2001 From: Pierre-Louis Benveniste Date: Wed, 13 Mar 2024 17:26:40 -0400 Subject: [PATCH 014/108] added fake config file (for debugging) --- monai/nnunet/config_fake.yml | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) create mode 100644 monai/nnunet/config_fake.yml diff --git a/monai/nnunet/config_fake.yml b/monai/nnunet/config_fake.yml new file mode 100644 index 0000000..4cc5e99 --- /dev/null +++ b/monai/nnunet/config_fake.yml @@ -0,0 +1,26 @@ +# Description: Configuration file for the UNETR model +# Path to the data json file +data: /home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/data/monai_data/fake.json +# Path to the output directory +output_dir : /home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/results/ +# Resampling resolution +pixdim : [1.0, 1.0, 1.0] +# Spatial size of the input data +spatial_size : [16, 224, 224] + +# UNETR model parameters +feature_size : 16 +hidden_size : 768 +mlp_dim : 3072 +num_heads : 12 + +# Optimizer parameters +lr : 0.0001 +weight_decay: 0.00001 + +# Training parameters +max_iterations : 25000 +eval_num : 500 + +# Model saving +best_model_path : /home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/results/best_metric_model.pth \ No newline at end of file From 5aaf6db9af1d302bdc4245d9c038b4c5432bf0ff Mon Sep 17 00:00:00 2001 From: Pierre-Louis Benveniste Date: Wed, 13 Mar 2024 17:27:23 -0400 Subject: [PATCH 015/108] updated requirements to add monai[all] for problem with data loading --- monai/nnunet/requirements.txt | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/monai/nnunet/requirements.txt b/monai/nnunet/requirements.txt index 7503a26..e61d982 100644 --- a/monai/nnunet/requirements.txt +++ b/monai/nnunet/requirements.txt @@ -1,3 +1,5 @@ -scikit-learn tqdm -loguru +monai[all] +torch +torchvision +matplotlib From 8eafefffbf7dc6152b40e8c34f18736fb1dc6ce5 Mon Sep 17 00:00:00 2001 From: Pierre-Louis Benveniste Date: Wed, 13 Mar 2024 17:28:58 -0400 Subject: [PATCH 016/108] removed old config file --- monai/nnunet/config.yml | 71 ----------------------------------------- 1 file changed, 71 deletions(-) delete mode 100644 monai/nnunet/config.yml diff --git a/monai/nnunet/config.yml b/monai/nnunet/config.yml deleted file mode 100644 index 7ee1dd4..0000000 --- a/monai/nnunet/config.yml +++ /dev/null @@ -1,71 +0,0 @@ -seed: 15 -save_test_preds: True - -directories: - # Path to the saved models directory - models_dir: /home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/models - # Path to the saved results directory - results_dir: /home/GRAMES.POLYMTL.CAp119007/ms_lesion_agnostic/results/models_followup - # Path to the saved wandb logs directory - # if None, starts training from scratch. Otherwise, resumes training from the specified wandb run folder - wandb_run_folder: None - -dataset: - # Dataset name (will be used as "group_name" for wandb logging) - name: ms-lesion-agnostic - # Path to the dataset directory containing all datalists (.json files) - root_dir: /home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/data/monai_data/dataset_2024-03-12_seed42.json - # Type of contrast to be used for training. "all" corresponds to training on all contrasts - contrast: all # choices: ["t1w", "t2w", "t2star", "mton", "mtoff", "dwi", "all"] - # Type of label to be used for training. - label_type: soft_bin # choices: ["hard", "soft", "soft_bin"] - -preprocessing: - # Online resampling of images to the specified spacing. - spacing: [1.0, 1.0, 1.0] - # Center crop/pad images to the specified size. (NOTE: done after resampling) - # values correspond to R-L, A-P, I-S axes of the image after 1mm isotropic resampling. - crop_pad_size: [64, 192, 320] - -opt: - name: adam - lr: 0.001 - max_epochs: 200 - batch_size: 2 - # Interval between validation checks in epochs - check_val_every_n_epochs: 5 - # Early stopping patience (this is until patience * check_val_every_n_epochs) - early_stopping_patience: 20 - - -model: - # Model architecture to be used for training (also to be specified as args in the command line) - nnunet: - # NOTE: these info are typically taken from nnUNetPlans.json (if an nnUNet model is trained) - base_num_features: 32 - max_num_features: 320 - n_conv_per_stage_encoder: [2, 2, 2, 2, 2, 2] - n_conv_per_stage_decoder: [2, 2, 2, 2, 2] - pool_op_kernel_sizes: [ - [1, 1, 1], - [2, 2, 2], - [2, 2, 2], - [2, 2, 2], - [1, 2, 2], - [1, 2, 2] - ] - enable_deep_supervision: True - - mednext: - num_input_channels: 1 - base_num_features: 32 - num_classes: 1 - kernel_size: 3 # 3x3x3 and 5x5x5 were tested in publication - block_counts: [2,2,2,2,1,1,1,1,1] # number of blocks in each layer - enable_deep_supervision: True - - swinunetr: - spatial_dims: 3 - depths: [2, 2, 2, 2] - num_heads: [3, 6, 12, 24] # number of heads in multi-head Attention - feature_size: 36 \ No newline at end of file From 8f43eacd874eddb1da004df05a390b86d4bba0a2 Mon Sep 17 00:00:00 2001 From: Pierre-Louis Benveniste Date: Wed, 13 Mar 2024 19:18:46 -0400 Subject: [PATCH 017/108] changed links to dataset --- monai/nnunet/config_fake.yml | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/monai/nnunet/config_fake.yml b/monai/nnunet/config_fake.yml index 4cc5e99..9cd1d07 100644 --- a/monai/nnunet/config_fake.yml +++ b/monai/nnunet/config_fake.yml @@ -1,15 +1,16 @@ # Description: Configuration file for the UNETR model # Path to the data json file data: /home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/data/monai_data/fake.json +#data: /home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/data/monai_data/dataset_2024-03-13_seed42_canproco.json # Path to the output directory output_dir : /home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/results/ # Resampling resolution pixdim : [1.0, 1.0, 1.0] # Spatial size of the input data -spatial_size : [16, 224, 224] +spatial_size : [32, 176, 176] # UNETR model parameters -feature_size : 16 +feature_size : 8 hidden_size : 768 mlp_dim : 3072 num_heads : 12 @@ -20,7 +21,7 @@ weight_decay: 0.00001 # Training parameters max_iterations : 25000 -eval_num : 500 +eval_num : 5 # Model saving best_model_path : /home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/results/best_metric_model.pth \ No newline at end of file From 86e1a32cc18083107d975ecae5ff1fb5aa6d62af Mon Sep 17 00:00:00 2001 From: Pierre-Louis Benveniste Date: Wed, 13 Mar 2024 19:19:19 -0400 Subject: [PATCH 018/108] fixed training. Still need to fix validation params --- monai/nnunet/train_monai_UNETR.py | 41 ++++++++++++++++++++----------- 1 file changed, 27 insertions(+), 14 deletions(-) diff --git a/monai/nnunet/train_monai_UNETR.py b/monai/nnunet/train_monai_UNETR.py index ace0382..c8d955c 100644 --- a/monai/nnunet/train_monai_UNETR.py +++ b/monai/nnunet/train_monai_UNETR.py @@ -39,7 +39,8 @@ BatchInverseTransform, RandAdjustContrastd, AsDiscreted, - RandHistogramShiftd + RandHistogramShiftd, + ResizeWithPadOrCropd ) # Dataset import @@ -49,7 +50,7 @@ # model import import torch from monai.networks.nets import UNETR -from monai.losses import DiceCELoss +from monai.losses import DiceLoss # For training and validation from monai.data import decollate_batch @@ -85,7 +86,7 @@ def validation(model, epoch_iterator_val, config, post_label, post_pred, dice_me dice_metric.reset() return mean_dice_val -def train(model, config, global_step, train_loader, dice_val_best, global_step_best, loss_function, optimizer, epoch_loss_values, metric_values, val_loader): +def train(model, config, global_step, train_loader, dice_val_best, global_step_best, loss_function, optimizer, epoch_loss_values, metric_values, val_loader, post_label, post_pred, dice_metric): model.train() epoch_loss = 0 step = 0 @@ -94,7 +95,6 @@ def train(model, config, global_step, train_loader, dice_val_best, global_step_b step += 1 x, y = (batch["image"].cuda(), batch["label"].cuda()) logit_map = model(x) - print(logit_map) loss = loss_function(logit_map, y) loss.backward() epoch_loss += loss.item() @@ -105,7 +105,7 @@ def train(model, config, global_step, train_loader, dice_val_best, global_step_b ) if (global_step % config["eval_num"] == 0 and global_step != 0) or global_step == config["max_iterations"]: epoch_iterator_val = tqdm(val_loader, desc="Validate (X / X Steps) (dice=X.X)", dynamic_ncols=True) - dice_val = validation(epoch_iterator_val) + dice_val = validation(model, epoch_iterator_val, config, post_label, post_pred, dice_metric, global_step) epoch_loss /= step epoch_loss_values.append(epoch_loss) metric_values.append(dice_val) @@ -152,8 +152,9 @@ def main(): Spacingd( keys=["image", "label"], pixdim=config["pixdim"], - mode=("bilinear", "nearest"), + mode=(2, 1), ), + ResizeWithPadOrCropd(keys=["image", "label"], spatial_size=config["spatial_size"],), RandCropByPosNegLabeld( keys=["image", "label"], label_key="label", @@ -203,13 +204,24 @@ def main(): ) val_transforms = Compose( [ - LoadImaged(keys=["image", "label"]), + LoadImaged(keys=["image", "label"], reader="NibabelReader"), EnsureChannelFirstd(keys=["image", "label"]), Orientationd(keys=["image", "label"], axcodes="RSP"), Spacingd( keys=["image", "label"], - pixdim=(1., 1., 1.0), - mode=("bilinear", "nearest"), + pixdim=config["pixdim"], + mode=(2, 1), + ), + ResizeWithPadOrCropd(keys=["image", "label"], spatial_size=config["spatial_size"],), + RandCropByPosNegLabeld( + keys=["image", "label"], + label_key="label", + spatial_size=config["spatial_size"], + pos=1, + neg=1, + num_samples=4, + image_key="image", + image_threshold=0, ), NormalizeIntensityd( keys=["image", "label"], @@ -273,7 +285,7 @@ def main(): model = UNETR( in_channels=1, - out_channels=2, + out_channels=1, img_size=config["spatial_size"], feature_size=config["feature_size"], hidden_size=config["hidden_size"], @@ -285,13 +297,13 @@ def main(): dropout_rate=0.0, ).to(device) - loss_function = DiceCELoss(to_onehot_y=True, softmax=True) + loss_function = DiceLoss(softmax=True) torch.backends.cudnn.benchmark = True optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-5) # We then train the model - post_label = AsDiscrete(to_onehot=2) - post_pred = AsDiscrete(argmax=True, to_onehot=2) + post_label = AsDiscrete(to_onehot=1) + post_pred = AsDiscrete(argmax=True, to_onehot=1) dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False) global_step = 0 dice_val_best = 0.0 @@ -299,7 +311,8 @@ def main(): epoch_loss_values = [] metric_values = [] while global_step < config["max_iterations"]: - global_step, dice_val_best, global_step_best = train(model, config, global_step, train_loader, dice_val_best, global_step_best, loss_function, optimizer, epoch_loss_values, metric_values, val_loader) + global_step, dice_val_best, global_step_best = train(model, config, global_step, train_loader, dice_val_best, global_step_best, + loss_function, optimizer, epoch_loss_values, metric_values, val_loader, post_label, post_pred, dice_metric) model.load_state_dict(torch.load(config["best_model_path"])) print(f"train completed, best_metric: {dice_val_best:.4f} " f"at iteration: {global_step_best}") From 32ca1cddaaad0963327bbd719dcb1a2aed538070 Mon Sep 17 00:00:00 2001 From: Pierre-Louis Benveniste Date: Wed, 13 Mar 2024 19:20:35 -0400 Subject: [PATCH 019/108] removed files using pytorch ligthning training --- monai/nnunet/main.py | 630 ------------------------------------- monai/nnunet/models.py | 90 ------ monai/nnunet/transforms.py | 59 ---- 3 files changed, 779 deletions(-) delete mode 100644 monai/nnunet/main.py delete mode 100644 monai/nnunet/models.py delete mode 100644 monai/nnunet/transforms.py diff --git a/monai/nnunet/main.py b/monai/nnunet/main.py deleted file mode 100644 index 1d0f163..0000000 --- a/monai/nnunet/main.py +++ /dev/null @@ -1,630 +0,0 @@ -import os -import argparse -from datetime import datetime -from loguru import logger -import yaml - -import numpy as np -import wandb -import torch -import pytorch_lightning as pl -import torch.nn.functional as F -import matplotlib.pyplot as plt - -from utils import dice_score, PolyLRScheduler, plot_slices, check_empty_patch -from losses import AdapWingLoss -from transforms import train_transforms, val_transforms -from models import create_nnunet_from_plans - -from monai.utils import set_determinism -from monai.inferers import sliding_window_inference -from monai.networks.nets import UNETR, SwinUNETR -from monai.data import (DataLoader, CacheDataset, load_decathlon_datalist, decollate_batch) -from monai.transforms import (Compose, EnsureType, EnsureTyped, Invertd, SaveImage) - -# mednext -# from nnunet_mednext import MedNeXt - -def get_args(): - parser = argparse.ArgumentParser(description='Script for training contrast-agnositc SC segmentation model.') - - # arguments for model - parser.add_argument('-m', '--model', choices=['nnunet', 'mednext', 'swinunetr'], - default='nnunet', type=str, - help='Model type to be used. Options: nnunet, mednext, swinunetr.') - # path to the config file - parser.add_argument("--config", type=str, default="./config.json", - help="Path to the config file containing all training details.") - # saving - parser.add_argument('--debug', default=False, action='store_true', help='if true, results are not logged to wandb') - parser.add_argument('-c', '--continue_from_checkpoint', default=False, action='store_true', - help='Load model from checkpoint and continue training') - args = parser.parse_args() - - return args - - -# create a "model"-agnostic class with PL to use different models -class Model(pl.LightningModule): - def __init__(self, config, data_root, net, loss_function, optimizer_class, exp_id=None, results_path=None): - super().__init__() - self.cfg = config - self.save_hyperparameters(ignore=['net', 'loss_function']) - - self.root = data_root - self.net = net - self.lr = config["opt"]["lr"] - self.loss_function = loss_function - self.optimizer_class = optimizer_class - self.save_exp_id = exp_id - self.results_path = results_path - - self.best_val_dice, self.best_val_epoch = 0, 0 - self.best_val_loss = float("inf") - - # define cropping and padding dimensions - # NOTE about patch sizes: nnUNet defines patches using the median size of the dataset as the reference - # BUT, for SC images, this means a lot of context outside the spinal cord is included in the patches - # which could be sub-optimal. - # On the other hand, ivadomed used a patch-size that's heavily padded along the R-L direction so that - # only the SC is in context. - self.spacing = config["preprocessing"]["spacing"] - self.voxel_cropping_size = self.inference_roi_size = config["preprocessing"]["crop_pad_size"] - - # define post-processing transforms for validation, nothing fancy just making sure that it's a tensor (default) - self.val_post_pred = self.val_post_label = Compose([EnsureType()]) - - # define evaluation metric - self.soft_dice_metric = dice_score - - # temp lists for storing outputs from training, validation, and testing - self.train_step_outputs = [] - self.val_step_outputs = [] - self.test_step_outputs = [] - - - # -------------------------------- - # FORWARD PASS - # -------------------------------- - def forward(self, x): - - out = self.net(x) - # # NOTE: MONAI's models only output the logits, not the output after the final activation function - # # https://docs.monai.io/en/0.9.0/_modules/monai/networks/nets/unetr.html#UNETR.forward refers to the - # # UnetOutBlock (https://docs.monai.io/en/0.9.0/_modules/monai/networks/blocks/dynunet_block.html#UnetOutBlock) - # # as the final block applied to the input, which is just a convolutional layer with no activation function - # # Hence, we are used Normalized ReLU to normalize the logits to the final output - # normalized_out = F.relu(out) / F.relu(out).max() if bool(F.relu(out).max()) else F.relu(out) - - return out # returns logits - - - # -------------------------------- - # DATA PREPARATION - # -------------------------------- - def prepare_data(self): - # set deterministic training for reproducibility - set_determinism(seed=self.cfg["seed"]) - - # define training and validation transforms - transforms_train = train_transforms( - crop_size=self.voxel_cropping_size, - lbl_key='label' - ) - transforms_val = val_transforms(crop_size=self.inference_roi_size, lbl_key='label') - - # load the dataset - logger.info(f"Training with {self.cfg['dataset']['label_type']} labels ...") - dataset = os.path.join(self.root, - f"dataset_{self.cfg['dataset']['contrast']}_{self.cfg['dataset']['label_type']}_seed{self.cfg['seed']}.json" - ) - logger.info(f"Loading dataset: {dataset}") - train_files = load_decathlon_datalist(dataset, True, "train") - val_files = load_decathlon_datalist(dataset, True, "validation") - test_files = load_decathlon_datalist(dataset, True, "test") - - if args.debug: - train_files = train_files[:15] - val_files = val_files[:15] - test_files = test_files[:6] - - train_cache_rate = 0.25 if args.debug else 0.5 - self.train_ds = CacheDataset(data=train_files, transform=transforms_train, cache_rate=train_cache_rate, num_workers=4) - self.val_ds = CacheDataset(data=val_files, transform=transforms_val, cache_rate=0.25, num_workers=4) - - # define test transforms - transforms_test = val_transforms(crop_size=self.inference_roi_size, lbl_key='label') - - # define post-processing transforms for testing; taken (with explanations) from - # https://github.com/Project-MONAI/tutorials/blob/main/3d_segmentation/torch/unet_inference_dict.py#L66 - self.test_post_pred = Compose([ - EnsureTyped(keys=["pred", "label"]), - Invertd(keys=["pred", "label"], transform=transforms_test, - orig_keys=["image", "label"], - meta_keys=["pred_meta_dict", "label_meta_dict"], - nearest_interp=False, to_tensor=True), - ]) - self.test_ds = CacheDataset(data=test_files, transform=transforms_test, cache_rate=0.1, num_workers=4) - - - # -------------------------------- - # DATA LOADERS - # -------------------------------- - def train_dataloader(self): - return DataLoader(self.train_ds, batch_size=self.cfg["opt"]["batch_size"], shuffle=True, num_workers=16, - pin_memory=True, persistent_workers=True) - - def val_dataloader(self): - return DataLoader(self.val_ds, batch_size=1, shuffle=False, num_workers=16, pin_memory=True, - persistent_workers=True) - - def test_dataloader(self): - return DataLoader(self.test_ds, batch_size=1, shuffle=False, num_workers=8, pin_memory=True) - - - # -------------------------------- - # OPTIMIZATION - # -------------------------------- - def configure_optimizers(self): - if self.cfg["opt"]["name"] == "sgd": - optimizer = self.optimizer_class(self.parameters(), lr=self.lr, momentum=0.99, weight_decay=3e-5, nesterov=True) - else: - optimizer = self.optimizer_class(self.parameters(), lr=self.lr) - # scheduler = PolyLRScheduler(optimizer, self.lr, max_steps=self.args.max_epochs) - # NOTE: ivadomed using CosineAnnealingLR with T_max = 200 - scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=self.cfg["opt"]["max_epochs"]) - return [optimizer], [scheduler] - - - # -------------------------------- - # TRAINING - # -------------------------------- - def training_step(self, batch, batch_idx): - - inputs, labels = batch["image"], batch["label"] - - # check if any label image patch is empty in the batch - if check_empty_patch(labels) is None: - # print(f"Empty label patch found. Skipping training step ...") - return None - - output = self.forward(inputs) # logits - # print(f"labels.shape: {labels.shape} \t output.shape: {output.shape}") - - if args.model in ["nnunet", "mednext"] and self.cfg['model'][args.model]["enable_deep_supervision"]: - - # calculate dice loss for each output - loss, train_soft_dice = 0.0, 0.0 - for i in range(len(output)): - # give each output a weight which decreases exponentially (division by 2) as the resolution decreases - # this gives higher resolution outputs more weight in the loss - # NOTE: outputs[0] is the final pred, outputs[-1] is the lowest resolution pred (at the bottleneck) - # we're downsampling the GT to the resolution of each deepsupervision feature map output - # (instead of upsampling each deepsupervision feature map output to the final resolution) - downsampled_gt = F.interpolate(labels, size=output[i].shape[-3:], mode='trilinear', align_corners=False) - # print(f"downsampled_gt.shape: {downsampled_gt.shape} \t output[i].shape: {output[i].shape}") - loss += (0.5 ** i) * self.loss_function(output[i], downsampled_gt) - - # get probabilities from logits - out = F.relu(output[i]) / F.relu(output[i]).max() if bool(F.relu(output[i]).max()) else F.relu(output[i]) - - # calculate train dice - # NOTE: this is done on patches (and not entire 3D volume) because SlidingWindowInference is not used here - # So, take this dice score with a lot of salt - train_soft_dice += self.soft_dice_metric(out, downsampled_gt) - - # average dice loss across the outputs - loss /= len(output) - train_soft_dice /= len(output) - - else: - # calculate training loss - loss = self.loss_function(output, labels) - - # get probabilities from logits - output = F.relu(output) / F.relu(output).max() if bool(F.relu(output).max()) else F.relu(output) - - # calculate train dice - # NOTE: this is done on patches (and not entire 3D volume) because SlidingWindowInference is not used here - # So, take this dice score with a lot of salt - train_soft_dice = self.soft_dice_metric(output, labels) - - metrics_dict = { - "loss": loss.cpu(), - "train_soft_dice": train_soft_dice.detach().cpu(), - "train_number": len(inputs), - # "train_image": inputs[0].detach().cpu().squeeze(), - # "train_gt": labels[0].detach().cpu().squeeze(), - # "train_pred": output[0].detach().cpu().squeeze() - } - self.train_step_outputs.append(metrics_dict) - - return metrics_dict - - def on_train_epoch_end(self): - - if self.train_step_outputs == []: - # means the training step was skipped because of empty input patch - return None - else: - train_loss, train_soft_dice = 0, 0 - num_items = len(self.train_step_outputs) - for output in self.train_step_outputs: - train_loss += output["loss"].item() - train_soft_dice += output["train_soft_dice"].item() - - mean_train_loss = (train_loss / num_items) - mean_train_soft_dice = (train_soft_dice / num_items) - - wandb_logs = { - "train_soft_dice": mean_train_soft_dice, - "train_loss": mean_train_loss, - } - self.log_dict(wandb_logs) - - # # plot the training images - # fig = plot_slices(image=self.train_step_outputs[0]["train_image"], - # gt=self.train_step_outputs[0]["train_gt"], - # pred=self.train_step_outputs[0]["train_pred"], - # debug=args.debug) - # wandb.log({"training images": wandb.Image(fig)}) - - # free up memory - self.train_step_outputs.clear() - wandb_logs.clear() - # plt.close(fig) - - - # -------------------------------- - # VALIDATION - # -------------------------------- - def validation_step(self, batch, batch_idx): - - inputs, labels = batch["image"], batch["label"] - - # NOTE: this calculates the loss on the entire image after sliding window - outputs = sliding_window_inference(inputs, self.inference_roi_size, mode="gaussian", - sw_batch_size=4, predictor=self.forward, overlap=0.5,) - # outputs shape: (B, C, ) - if args.model in ["nnunet", "mednext"] and self.cfg['model'][args.model]["enable_deep_supervision"]: - # we only need the output with the highest resolution - outputs = outputs[0] - - # calculate validation loss - loss = self.loss_function(outputs, labels) - - # get probabilities from logits - outputs = F.relu(outputs) / F.relu(outputs).max() if bool(F.relu(outputs).max()) else F.relu(outputs) - - # post-process for calculating the evaluation metric - post_outputs = [self.val_post_pred(i) for i in decollate_batch(outputs)] - post_labels = [self.val_post_label(i) for i in decollate_batch(labels)] - val_soft_dice = self.soft_dice_metric(post_outputs[0], post_labels[0]) - - hard_preds, hard_labels = (post_outputs[0].detach() > 0.5).float(), (post_labels[0].detach() > 0.5).float() - val_hard_dice = self.soft_dice_metric(hard_preds, hard_labels) - - # NOTE: there was a massive memory leak when storing cuda tensors in this dict. Hence, - # using .detach() to avoid storing the whole computation graph - # Ref: https://discuss.pytorch.org/t/cuda-memory-leak-while-training/82855/2 - metrics_dict = { - "val_loss": loss.detach().cpu(), - "val_soft_dice": val_soft_dice.detach().cpu(), - "val_hard_dice": val_hard_dice.detach().cpu(), - "val_number": len(post_outputs), - # "val_image": inputs[0].detach().cpu().squeeze(), - # "val_gt": labels[0].detach().cpu().squeeze(), - # "val_pred": post_outputs[0].detach().cpu().squeeze(), - } - self.val_step_outputs.append(metrics_dict) - - return metrics_dict - - def on_validation_epoch_end(self): - - val_loss, num_items, val_soft_dice, val_hard_dice = 0, 0, 0, 0 - for output in self.val_step_outputs: - val_loss += output["val_loss"].sum().item() - val_soft_dice += output["val_soft_dice"].sum().item() - val_hard_dice += output["val_hard_dice"].sum().item() - num_items += output["val_number"] - - mean_val_loss = (val_loss / num_items) - mean_val_soft_dice = (val_soft_dice / num_items) - mean_val_hard_dice = (val_hard_dice / num_items) - - wandb_logs = { - "val_soft_dice": mean_val_soft_dice, - "val_hard_dice": mean_val_hard_dice, - "val_loss": mean_val_loss, - } - # save the best model based on validation dice score - if mean_val_soft_dice > self.best_val_dice: - self.best_val_dice = mean_val_soft_dice - self.best_val_epoch = self.current_epoch - - # save the best model based on validation CSA loss - if mean_val_loss < self.best_val_loss: - self.best_val_loss = mean_val_loss - self.best_val_epoch = self.current_epoch - - logger.info( - f"\nCurrent epoch: {self.current_epoch}" - f"\nAverage Soft Dice (VAL): {mean_val_soft_dice:.4f}" - f"\nAverage Hard Dice (VAL): {mean_val_hard_dice:.4f}" - f"\nAverage AdapWing Loss (VAL): {mean_val_loss:.4f}" - f"\nBest Average AdapWing Loss: {self.best_val_loss:.4f} at Epoch: {self.best_val_epoch}" - f"\n----------------------------------------------------") - - - # log on to wandb - self.log_dict(wandb_logs) - - # # plot the validation images - # fig = plot_slices(image=self.val_step_outputs[0]["val_image"], - # gt=self.val_step_outputs[0]["val_gt"], - # pred=self.val_step_outputs[0]["val_pred"],) - # wandb.log({"validation images": wandb.Image(fig)}) - - # free up memory - self.val_step_outputs.clear() - wandb_logs.clear() - # plt.close(fig) - - # return {"log": wandb_logs} - - # -------------------------------- - # TESTING - # -------------------------------- - def test_step(self, batch, batch_idx): - - test_input = batch["image"] - # print(batch["label_meta_dict"]["filename_or_obj"][0]) - batch["pred"] = sliding_window_inference(test_input, self.inference_roi_size, - sw_batch_size=4, predictor=self.forward, overlap=0.5) - - if args.model in ["nnunet", "mednext"] and self.cfg['model'][args.model]["enable_deep_supervision"]: - # we only need the output with the highest resolution - batch["pred"] = batch["pred"][0] - - # normalize the logits - batch["pred"] = F.relu(batch["pred"]) / F.relu(batch["pred"]).max() if bool(F.relu(batch["pred"]).max()) else F.relu(batch["pred"]) - - post_test_out = [self.test_post_pred(i) for i in decollate_batch(batch)] - - # make sure that the shapes of prediction and GT label are the same - # print(f"pred shape: {post_test_out[0]['pred'].shape}, label shape: {post_test_out[0]['label'].shape}") - assert post_test_out[0]['pred'].shape == post_test_out[0]['label'].shape - - pred, label = post_test_out[0]['pred'].cpu(), post_test_out[0]['label'].cpu() - - # save the prediction and label - if self.cfg["save_test_preds"]: - - subject_name = (batch["image_meta_dict"]["filename_or_obj"][0]).split("/")[-1].replace(".nii.gz", "") - logger.info(f"Saving subject: {subject_name}") - - # image saver class - save_folder = os.path.join(self.results_path, subject_name.split("_")[0]) - pred_saver = SaveImage( - output_dir=save_folder, output_postfix="pred", output_ext=".nii.gz", - separate_folder=False, print_log=False, resample=True) - # save the prediction - pred_saver(pred) - - # label_saver = SaveImage( - # output_dir=save_folder, output_postfix="gt", output_ext=".nii.gz", - # separate_folder=False, print_log=False, resample=True) - # # save the label - # label_saver(label) - - - # NOTE: Important point from the SoftSeg paper - binarize predictions before computing metrics - # calculate soft and hard dice here (for quick overview), other metrics can be computed from - # the saved predictions using ANIMA - # 1. Dice Score - test_soft_dice = self.soft_dice_metric(pred, label) - - # binarizing the predictions - pred = (post_test_out[0]['pred'].detach().cpu() > 0.5).float() - label = (post_test_out[0]['label'].detach().cpu() > 0.5).float() - - # 1.1 Hard Dice Score - test_hard_dice = self.soft_dice_metric(pred.numpy(), label.numpy()) - - metrics_dict = { - "test_hard_dice": test_hard_dice, - "test_soft_dice": test_soft_dice, - } - self.test_step_outputs.append(metrics_dict) - - return metrics_dict - - def on_test_epoch_end(self): - - avg_hard_dice_test, std_hard_dice_test = np.stack([x["test_hard_dice"] for x in self.test_step_outputs]).mean(), \ - np.stack([x["test_hard_dice"] for x in self.test_step_outputs]).std() - avg_soft_dice_test, std_soft_dice_test = np.stack([x["test_soft_dice"] for x in self.test_step_outputs]).mean(), \ - np.stack([x["test_soft_dice"] for x in self.test_step_outputs]).std() - - logger.info(f"Test (Soft) Dice: {avg_soft_dice_test}") - logger.info(f"Test (Hard) Dice: {avg_hard_dice_test}") - - self.avg_test_dice, self.std_test_dice = avg_soft_dice_test, std_soft_dice_test - self.avg_test_dice_hard, self.std_test_dice_hard = avg_hard_dice_test, std_hard_dice_test - - # free up memory - self.test_step_outputs.clear() - - -# -------------------------------- -# MAIN -# -------------------------------- -def main(args): - - # load config file - with open(args.config, "r") as f: - config = yaml.load(f, Loader=yaml.FullLoader) - - # Setting the seed - pl.seed_everything(config["seed"], workers=True) - - # define root path for finding datalists - dataset_root = config["dataset"]["root_dir"] - - # define optimizer - if config["opt"]["name"] == "adam": - optimizer_class = torch.optim.Adam - - if config["model"]["nnunet"]["enable_deep_supervision"]: - logger.info(f"Using nnUNet model WITH deep supervision ...") - else: - logger.info(f"Using nnUNet model WITHOUT deep supervision ...") - - logger.info("Defining plans for nnUNet model ...") - # ========================================================================================= - # Define plans json taken from nnUNet_preprocessed folder - # ========================================================================================= - nnunet_plans = { - "UNet_class_name": "PlainConvUNet", - "UNet_base_num_features": config["model"]["nnunet"]["base_num_features"], - "n_conv_per_stage_encoder": config["model"]["nnunet"]["n_conv_per_stage_encoder"], - "n_conv_per_stage_decoder": config["model"]["nnunet"]["n_conv_per_stage_decoder"], - "pool_op_kernel_sizes": config["model"]["nnunet"]["pool_op_kernel_sizes"], - "conv_kernel_sizes": [ - [3, 3, 3], - [3, 3, 3], - [3, 3, 3], - [3, 3, 3], - [3, 3, 3], - [3, 3, 3] - ], - "unet_max_num_features": config["model"]["nnunet"]["max_num_features"], - } - - # define model - net = create_nnunet_from_plans(plans=nnunet_plans, num_input_channels=1, num_classes=1, - deep_supervision=config["model"]["nnunet"]["enable_deep_supervision"]) - # variable for saving patch size in the experiment id (same as crop_pad_size) - patch_size = f"{config['preprocessing']['crop_pad_size'][0]}x" \ - f"{config['preprocessing']['crop_pad_size'][1]}x" \ - f"{config['preprocessing']['crop_pad_size'][2]}" - # save experiment id - save_exp_id = f"{args.model}_seed={config['seed']}_" \ - f"{config['dataset']['contrast']}_{config['dataset']['label_type']}_" \ - f"nf={config['model']['nnunet']['base_num_features']}_" \ - f"opt={config['opt']['name']}_lr={config['opt']['lr']}_AdapW_" \ - f"bs={config['opt']['batch_size']}_{patch_size}" \ - - - timestamp = datetime.now().strftime(f"%Y%m%d-%H%M") # prints in YYYYMMDD-HHMMSS format - save_exp_id = f"{save_exp_id}_{timestamp}" - - # save output to a log file - logger.add(os.path.join(config["directories"]["models_dir"], f"{save_exp_id}", "logs.txt"), rotation="10 MB", level="INFO") - - # save config file to the output folder - with open(os.path.join(config["directories"]["models_dir"], f"{save_exp_id}", "config.yaml"), "w") as f: - yaml.dump(config, f) - - # define loss function - loss_func = AdapWingLoss(theta=0.5, omega=8, alpha=2.1, epsilon=1, reduction="sum") - # NOTE: tried increasing omega and decreasing epsilon but results marginally worse than the above - # loss_func = AdapWingLoss(theta=0.5, omega=12, alpha=2.1, epsilon=0.5, reduction="sum") - logger.info(f"Using AdapWingLoss with theta={loss_func.theta}, omega={loss_func.omega}, alpha={loss_func.alpha}, epsilon={loss_func.epsilon} ...") - - # define callbacks - early_stopping = pl.callbacks.EarlyStopping( - monitor="val_loss", min_delta=0.00, - patience=config["opt"]["early_stopping_patience"], - verbose=False, mode="min") - - lr_monitor = pl.callbacks.LearningRateMonitor(logging_interval='epoch') - - # training from scratch - # to save the best model on validation - save_path = os.path.join(config["directories"]["models_dir"], f"{save_exp_id}") - if not os.path.exists(save_path): - os.makedirs(save_path, exist_ok=True) - - # to save the results/model predictions - results_path = os.path.join(config["directories"]["results_dir"], f"{save_exp_id}") - if not os.path.exists(results_path): - os.makedirs(results_path, exist_ok=True) - - # i.e. train by loading weights from scratch - pl_model = Model(config, data_root=dataset_root, - optimizer_class=optimizer_class, loss_function=loss_func, net=net, - exp_id=save_exp_id, results_path=results_path) - - # saving the best model based on validation loss - logger.info(f"Saving best model to {save_path}!") - checkpoint_callback_loss = pl.callbacks.ModelCheckpoint( - dirpath=save_path, filename='best_model', monitor='val_loss', - save_top_k=1, mode="min", save_last=True, save_weights_only=False) - - # # saving the best model based on soft validation dice score - # checkpoint_callback_dice = pl.callbacks.ModelCheckpoint( - # dirpath=save_path, filename='best_model_dice', monitor='val_soft_dice', - # save_top_k=1, mode="max", save_last=False, save_weights_only=True) - - logger.info(f"Starting training from scratch ...") - # wandb logger - exp_logger = pl.loggers.WandbLogger( - name=save_exp_id, - save_dir="/home/GRAMES.POLYMTL.CA/u114716/contrast-agnostic/saved_models", - group=config["dataset"]["name"], - log_model=True, # save best model using checkpoint callback - project='contrast-agnostic', - entity='naga-karthik', - config=config) - - # Saving training script to wandb - wandb.save("main.py") - wandb.save("transforms.py") - - # initialise Lightning's trainer. - trainer = pl.Trainer( - devices=1, accelerator="gpu", - logger=exp_logger, - callbacks=[checkpoint_callback_loss, lr_monitor, early_stopping], - check_val_every_n_epoch=config["opt"]["check_val_every_n_epochs"], - max_epochs=config["opt"]["max_epochs"], - precision=32, - # deterministic=True, - enable_progress_bar=False) - # profiler="simple",) # to profile the training time taken for each step - - # Train! - trainer.fit(pl_model) - logger.info(f" Training Done!") - - # Test! - trainer.test(pl_model) - logger.info(f"TESTING DONE!") - - # closing the current wandb instance so that a new one is created for the next fold - wandb.finish() - - # TODO: Figure out saving test metrics to a file - with open(os.path.join(results_path, 'test_metrics.txt'), 'a') as f: - print('\n-------------- Test Metrics ----------------', file=f) - print(f"{args.model}_seed={config['seed']}_" \ - f"{config['dataset']['contrast']}_{config['dataset']['label_type']}_" \ - f"nf={config['model']['nnunet']['base_num_features']}_" \ - f"opt={config['opt']['name']}_lr={config['opt']['lr']}_AdapW_" \ - f"bs={config['opt']['batch_size']}_{patch_size}" \ - f"_{timestamp}", file=f) - - print('\n-------------- Test Hard Dice Scores ----------------', file=f) - print("Hard Dice --> Mean: %0.3f, Std: %0.3f" % (pl_model.avg_test_dice_hard, pl_model.std_test_dice_hard), file=f) - - print('\n-------------- Test Soft Dice Scores ----------------', file=f) - print("Soft Dice --> Mean: %0.3f, Std: %0.3f" % (pl_model.avg_test_dice, pl_model.std_test_dice), file=f) - - print('-------------------------------------------------------', file=f) - - -if __name__ == "__main__": - args = get_args() - main(args) \ No newline at end of file diff --git a/monai/nnunet/models.py b/monai/nnunet/models.py deleted file mode 100644 index 55f8a91..0000000 --- a/monai/nnunet/models.py +++ /dev/null @@ -1,90 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F -# ---------------------------- Imports for nnUNet's Model ----------------------------- -from dynamic_network_architectures.architectures.unet import PlainConvUNet, ResidualEncoderUNet -from dynamic_network_architectures.building_blocks.helper import get_matching_instancenorm, convert_dim_to_conv_op -from dynamic_network_architectures.initialization.weight_init import init_last_bn_before_add_to_0 - - -# ====================================================================================================== -# Utils for nnUNet's Model -# ==================================================================================================== -class InitWeights_He(object): - def __init__(self, neg_slope=1e-2): - self.neg_slope = neg_slope - - def __call__(self, module): - if isinstance(module, nn.Conv3d) or isinstance(module, nn.ConvTranspose3d): - module.weight = nn.init.kaiming_normal_(module.weight, a=self.neg_slope) - if module.bias is not None: - module.bias = nn.init.constant_(module.bias, 0) - - -# ====================================================================================================== -# Define the network based on plans json -# ==================================================================================================== -def create_nnunet_from_plans(plans, num_input_channels: int, num_classes: int, deep_supervision: bool = True): - """ - Adapted from nnUNet's source code: - https://github.com/MIC-DKFZ/nnUNet/blob/master/nnunetv2/utilities/get_network_from_plans.py#L9 - - """ - num_stages = len(plans["conv_kernel_sizes"]) - - dim = len(plans["conv_kernel_sizes"][0]) - conv_op = convert_dim_to_conv_op(dim) - - segmentation_network_class_name = plans["UNet_class_name"] - mapping = { - 'PlainConvUNet': PlainConvUNet, - 'ResidualEncoderUNet': ResidualEncoderUNet - } - kwargs = { - 'PlainConvUNet': { - 'conv_bias': True, - 'norm_op': get_matching_instancenorm(conv_op), - 'norm_op_kwargs': {'eps': 1e-5, 'affine': True}, - 'dropout_op': None, 'dropout_op_kwargs': None, - 'nonlin': nn.LeakyReLU, 'nonlin_kwargs': {'inplace': True}, - }, - 'ResidualEncoderUNet': { - 'conv_bias': True, - 'norm_op': get_matching_instancenorm(conv_op), - 'norm_op_kwargs': {'eps': 1e-5, 'affine': True}, - 'dropout_op': None, 'dropout_op_kwargs': None, - 'nonlin': nn.LeakyReLU, 'nonlin_kwargs': {'inplace': True}, - } - } - assert segmentation_network_class_name in mapping.keys(), 'The network architecture specified by the plans file ' \ - 'is non-standard (maybe your own?). Yo\'ll have to dive ' \ - 'into either this ' \ - 'function (get_network_from_plans) or ' \ - 'the init of your nnUNetModule to accomodate that.' - network_class = mapping[segmentation_network_class_name] - - conv_or_blocks_per_stage = { - 'n_conv_per_stage' - if network_class != ResidualEncoderUNet else 'n_blocks_per_stage': plans["n_conv_per_stage_encoder"], - 'n_conv_per_stage_decoder': plans["n_conv_per_stage_decoder"] - } - - # network class name!! - model = network_class( - input_channels=num_input_channels, - n_stages=num_stages, - features_per_stage=[min(plans["UNet_base_num_features"] * 2 ** i, - plans["unet_max_num_features"]) for i in range(num_stages)], - conv_op=conv_op, - kernel_sizes=plans["conv_kernel_sizes"], - strides=plans["pool_op_kernel_sizes"], - num_classes=num_classes, - deep_supervision=deep_supervision, - **conv_or_blocks_per_stage, - **kwargs[segmentation_network_class_name] - ) - model.apply(InitWeights_He(1e-2)) - if network_class == ResidualEncoderUNet: - model.apply(init_last_bn_before_add_to_0) - - return model \ No newline at end of file diff --git a/monai/nnunet/transforms.py b/monai/nnunet/transforms.py deleted file mode 100644 index 7f89cc1..0000000 --- a/monai/nnunet/transforms.py +++ /dev/null @@ -1,59 +0,0 @@ - -import numpy as np -from monai.transforms import (Compose, CropForegroundd, LoadImaged, RandFlipd, - Spacingd, RandScaleIntensityd, NormalizeIntensityd, RandAffined, - DivisiblePadd, RandAdjustContrastd, EnsureChannelFirstd, RandGaussianNoised, - RandGaussianSmoothd, Orientationd, Rand3DElasticd, RandBiasFieldd, - RandSimulateLowResolutiond, ResizeWithPadOrCropd) - - -def train_transforms(crop_size, lbl_key="label"): - - monai_transforms = [ - # pre-processing - LoadImaged(keys=["image", lbl_key]), - EnsureChannelFirstd(keys=["image", lbl_key]), - # NOTE: spine interpolation with order=2 is spline, order=1 is linear - Spacingd(keys=["image", "label"], pixdim=(1.0, 1.0, 1.0), mode=(2, 1)), - ResizeWithPadOrCropd(keys=["image", lbl_key], spatial_size=crop_size,), - # data-augmentation - RandAffined(keys=["image", lbl_key], mode=(2, 1), prob=0.9, - rotate_range=(-20. / 360 * 2. * np.pi, 20. / 360 * 2. * np.pi), # monai expects in radians - scale_range=(-0.2, 0.2), - translate_range=(-0.1, 0.1)), - Rand3DElasticd(keys=["image", lbl_key], prob=0.5, - sigma_range=(3.5, 5.5), - magnitude_range=(25., 35.)), - RandSimulateLowResolutiond(keys=["image"], zoom_range=(0.5, 1.0), prob=0.25), - RandAdjustContrastd(keys=["image"], gamma=(0.5, 3.), prob=0.5), # this is monai's RandomGamma - RandBiasFieldd(keys=["image"], coeff_range=(0.0, 0.5), degree=3, prob=0.3), - RandGaussianNoised(keys=["image"], mean=0.0, std=0.1, prob=0.1), - RandGaussianSmoothd(keys=["image"], sigma_x=(0., 2.), sigma_y=(0., 2.), sigma_z=(0., 2.0), prob=0.3), - RandScaleIntensityd(keys=["image"], factors=(-0.25, 1), prob=0.15), # this is nnUNet's BrightnessMultiplicativeTransform - RandFlipd(keys=["image", lbl_key], prob=0.3,), - NormalizeIntensityd(keys=["image"], nonzero=False, channel_wise=False), - ] - - return Compose(monai_transforms) - -def inference_transforms(crop_size, lbl_key="label"): - return Compose([ - LoadImaged(keys=["image", lbl_key], image_only=False), - EnsureChannelFirstd(keys=["image", lbl_key]), - # CropForegroundd(keys=["image", lbl_key], source_key="image"), - Orientationd(keys=["image", lbl_key], axcodes="RPI"), - Spacingd(keys=["image", lbl_key], pixdim=(1.0, 1.0, 1.0), mode=(2, 1)), # mode=("bilinear", "bilinear"),), - ResizeWithPadOrCropd(keys=["image", lbl_key], spatial_size=crop_size,), - DivisiblePadd(keys=["image", lbl_key], k=2**5), # pad inputs to ensure divisibility by no. of layers nnUNet has (5) - NormalizeIntensityd(keys=["image"], nonzero=False, channel_wise=False), - ]) - -def val_transforms(crop_size, lbl_key="label"): - return Compose([ - LoadImaged(keys=["image", lbl_key], image_only=False), - EnsureChannelFirstd(keys=["image", lbl_key]), - # CropForegroundd(keys=["image", lbl_key], source_key="image"), - Spacingd(keys=["image", lbl_key], pixdim=(1.0, 1.0, 1.0), mode=(2, 1)), # mode=("bilinear", "bilinear"),), - ResizeWithPadOrCropd(keys=["image", lbl_key], spatial_size=crop_size,), - NormalizeIntensityd(keys=["image"], nonzero=False, channel_wise=False), - ]) \ No newline at end of file From 56479550e6a58df73d9a5946281dbf963e321cd3 Mon Sep 17 00:00:00 2001 From: Pierre-Louis Benveniste Date: Thu, 14 Mar 2024 18:44:39 -0400 Subject: [PATCH 020/108] working monai script based on Jan's code: but no dice score improvement --- monai/nnunet/config_fake.yml | 20 +- monai/nnunet/train_monai_UNETR.py | 316 +++++++++++++++++++++--------- 2 files changed, 239 insertions(+), 97 deletions(-) diff --git a/monai/nnunet/config_fake.yml b/monai/nnunet/config_fake.yml index 9cd1d07..3ec2549 100644 --- a/monai/nnunet/config_fake.yml +++ b/monai/nnunet/config_fake.yml @@ -1,13 +1,14 @@ # Description: Configuration file for the UNETR model # Path to the data json file -data: /home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/data/monai_data/fake.json -#data: /home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/data/monai_data/dataset_2024-03-13_seed42_canproco.json +#data: /home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/data/monai_data/fake.json +data: /home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/data/monai_data/dataset_2024-03-13_seed42_canproco.json # Path to the output directory output_dir : /home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/results/ # Resampling resolution pixdim : [1.0, 1.0, 1.0] # Spatial size of the input data -spatial_size : [32, 176, 176] +spatial_size : [16, 176, 176] +batch_size : 4 # UNETR model parameters feature_size : 8 @@ -18,10 +19,19 @@ num_heads : 12 # Optimizer parameters lr : 0.0001 weight_decay: 0.00001 +early_stopping_patience : 10 # Training parameters max_iterations : 25000 -eval_num : 5 +eval_num : 10 # Model saving -best_model_path : /home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/results/best_metric_model.pth \ No newline at end of file +best_model_path : /home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/results/best_metric_model.pth + +# log saving +log_path : /home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/results/ + +# WANDB +experiment_name : monai_unet_canproco + +seed : 42 \ No newline at end of file diff --git a/monai/nnunet/train_monai_UNETR.py b/monai/nnunet/train_monai_UNETR.py index c8d955c..ed40c29 100644 --- a/monai/nnunet/train_monai_UNETR.py +++ b/monai/nnunet/train_monai_UNETR.py @@ -22,6 +22,14 @@ from tqdm import tqdm import matplotlib.pyplot as plt import yaml +import numpy as np +import wandb +import time +from loguru import logger + +from monai.networks.layers import Norm + +os.environ["PYTORCH_USE_CUDA_DSA"] = "1" #Transforms import @@ -40,7 +48,8 @@ RandAdjustContrastd, AsDiscreted, RandHistogramShiftd, - ResizeWithPadOrCropd + ResizeWithPadOrCropd, + EnsureTyped ) # Dataset import @@ -49,7 +58,7 @@ # model import import torch -from monai.networks.nets import UNETR +from monai.networks.nets import UNETR, UNet from monai.losses import DiceLoss # For training and validation @@ -70,60 +79,80 @@ def get_parser(): return parser -def validation(model, epoch_iterator_val, config, post_label, post_pred, dice_metric, global_step): - model.eval() - with torch.no_grad(): - for batch in epoch_iterator_val: - val_inputs, val_labels = (batch["image"].cuda(), batch["label"].cuda()) - val_outputs = sliding_window_inference(val_inputs, config["spatial_size"], 4, model) - val_labels_list = decollate_batch(val_labels) - val_labels_convert = [post_label(val_label_tensor) for val_label_tensor in val_labels_list] - val_outputs_list = decollate_batch(val_outputs) - val_output_convert = [post_pred(val_pred_tensor) for val_pred_tensor in val_outputs_list] - dice_metric(y_pred=val_output_convert, y=val_labels_convert) - epoch_iterator_val.set_description("Validate (%d / %d Steps)" % (global_step, 10.0)) # noqa: B038 - mean_dice_val = dice_metric.aggregate().item() - dice_metric.reset() - return mean_dice_val - -def train(model, config, global_step, train_loader, dice_val_best, global_step_best, loss_function, optimizer, epoch_loss_values, metric_values, val_loader, post_label, post_pred, dice_metric): - model.train() - epoch_loss = 0 - step = 0 - epoch_iterator = tqdm(train_loader, desc="Training (X / X Steps) (loss=X.X)", dynamic_ncols=True) - for step, batch in enumerate(epoch_iterator): - step += 1 - x, y = (batch["image"].cuda(), batch["label"].cuda()) - logit_map = model(x) - loss = loss_function(logit_map, y) - loss.backward() - epoch_loss += loss.item() - optimizer.step() - optimizer.zero_grad() - epoch_iterator.set_description( # noqa: B038 - "Training (%d / %d Steps) (loss=%2.5f)" % (global_step, config["max_iterations"], loss) - ) - if (global_step % config["eval_num"] == 0 and global_step != 0) or global_step == config["max_iterations"]: - epoch_iterator_val = tqdm(val_loader, desc="Validate (X / X Steps) (dice=X.X)", dynamic_ncols=True) - dice_val = validation(model, epoch_iterator_val, config, post_label, post_pred, dice_metric, global_step) - epoch_loss /= step - epoch_loss_values.append(epoch_loss) - metric_values.append(dice_val) - if dice_val > dice_val_best: - dice_val_best = dice_val - global_step_best = global_step - torch.save(model.state_dict(), config["best_model_path"]) - print( - "Model Was Saved ! Current Best Avg. Dice: {} Current Avg. Dice: {}".format(dice_val_best, dice_val) - ) - else: - print( - "Model Was Not Saved ! Current Best Avg. Dice: {} Current Avg. Dice: {}".format( - dice_val_best, dice_val - ) - ) - global_step += 1 - return global_step, dice_val_best, global_step_best +# def validation(model, epoch_iterator_val, config, post_label, post_pred, dice_metric, global_step): +# model.eval() +# with torch.no_grad(): +# for batch in epoch_iterator_val: +# val_inputs, val_labels = (batch["image"].cuda(), batch["label"].cuda()) +# val_outputs = model(val_inputs) +# dice_metric(y_pred=val_outputs, y=val_labels) +# # val_outputs = sliding_window_inference(val_inputs, config["spatial_size"], 1, model) +# # val_labels_list = decollate_batch(val_labels) +# # val_labels_convert = [post_label(val_label_tensor) for val_label_tensor in val_labels_list] +# # val_outputs_list = decollate_batch(val_outputs) +# # val_output_convert = [post_pred(val_pred_tensor) for val_pred_tensor in val_outputs_list] +# # dice_metric(y_pred=val_output_convert, y=val_labels_convert) +# epoch_iterator_val.set_description("Validate (%d / %d Steps)" % (global_step, 10.0)) # noqa: B038 +# # for batch in epoch_iterator_val: +# # val_inputs, val_labels = ( +# # batch["image"].cuda(), +# # batch["label"].cuda(), +# # ) +# # # TODO: parametrize this +# # roi_size = config["spatial_size"] +# # sw_batch_size = 4 +# # val_outputs = sliding_window_inference( +# # val_inputs, roi_size, sw_batch_size, model) + +# # val_outputs = [post_pred(i) for i in decollate_batch(val_outputs)] +# # val_labels = [post_label(i) for i in decollate_batch(val_labels)] +# # # compute metric for current iteration +# # dice_metric(y_pred=val_outputs, y=val_labels) + +# mean_dice_val = dice_metric.aggregate().item() +# print("Mean dice val: ", mean_dice_val) +# dice_metric.reset() + +# return mean_dice_val + +# def train(model, config, global_step, train_loader, dice_val_best, global_step_best, loss_function, optimizer, epoch_loss_values, metric_values, val_loader, post_label, post_pred, dice_metric): +# model.train() +# epoch_loss = 0 +# step = 0 +# epoch_iterator = tqdm(train_loader, desc="Training (X / X Steps) (loss=X.X)", dynamic_ncols=True) +# for step, batch in enumerate(epoch_iterator): +# step += 1 +# x, y = (batch["image"].cuda(), batch["label"].cuda()) +# logit_map = model(x) +# loss = loss_function(logit_map, y) +# loss.backward() +# epoch_loss += loss.item() +# optimizer.step() +# optimizer.zero_grad() +# epoch_iterator.set_description( # noqa: B038 +# "Training (%d / %d Steps) (loss=%2.5f)" % (global_step, config["max_iterations"], loss) +# ) +# if (global_step % config["eval_num"] == 0 and global_step != 0) or global_step == config["max_iterations"]: +# epoch_iterator_val = tqdm(val_loader, desc="Validate (X / X Steps) (dice=X.X)", dynamic_ncols=True) +# dice_val = validation(model, epoch_iterator_val, config, post_label, post_pred, dice_metric, global_step) +# epoch_loss /= step +# epoch_loss_values.append(epoch_loss) +# metric_values.append(dice_val) +# if dice_val > dice_val_best: +# dice_val_best = dice_val +# global_step_best = global_step +# torch.save(model.state_dict(), config["best_model_path"]) +# print( +# "Model Was Saved ! Current Best Avg. Dice: {} Current Avg. Dice: {}".format(dice_val_best, dice_val) +# ) +# else: +# print( +# "Model Was Not Saved ! Current Best Avg. Dice: {} Current Avg. Dice: {}".format( +# dice_val_best, dice_val +# ) +# ) +# global_step += 1 +# return global_step, dice_val_best, global_step_best def main(): @@ -194,6 +223,7 @@ def main(): nonzero=False, channel_wise=False ), + EnsureTyped(keys=["image", "label"]), AsDiscreted( keys=["label"], num_classes=2, @@ -228,6 +258,7 @@ def main(): nonzero=False, channel_wise=False ), + EnsureTyped(keys=["image", "label"]), AsDiscreted( keys=["label"], num_classes=2, @@ -251,23 +282,21 @@ def main(): os.makedirs(output_dir) # We load the train and validation data - print("Loading the training and validation data...") - # train_files = load_decathlon_datalist(data, True, "train") + logger.info("Loading the training and validation data...") train_ds = CacheDataset( data=train_list, transform=train_transforms, - cache_rate=0.25, - num_workers=4 + cache_rate=0.1, + num_workers=2 ) - train_loader = DataLoader(train_ds, batch_size=1, shuffle=True, num_workers=8, pin_memory=True, persistent_workers=True) - # val_files = load_decathlon_datalist(data, True, "validation") + train_loader = DataLoader(train_ds, batch_size=config["batch_size"], shuffle=True, num_workers=4, pin_memory=True, persistent_workers=True) val_ds = CacheDataset( data=val_list, transform=val_transforms, - cache_rate=0.25, - num_workers=4, + cache_rate=0.1, + num_workers=0, ) - val_loader = DataLoader(val_ds, batch_size=1, shuffle=False, num_workers=8, pin_memory=True) + val_loader = DataLoader(val_ds, batch_size=1, shuffle=False, num_workers=0, pin_memory=True) # plot 3 image and save them fig, axes = plt.subplots(1, 3, figsize=(15, 5)) @@ -279,43 +308,146 @@ def main(): plt.savefig(os.path.join(output_dir, "image.png")) - print("Preparing the UNETR model...") + + print("Preparing the UNET model...") # we define the device to use - device = torch.device("cuda") + device = torch.device("cuda:0") - model = UNETR( + model = UNet( + spatial_dims=3, in_channels=1, out_channels=1, - img_size=config["spatial_size"], - feature_size=config["feature_size"], - hidden_size=config["hidden_size"], - mlp_dim=config["mlp_dim"], - num_heads=config["num_heads"], - proj_type="perceptron", - norm_name="instance", - res_block=True, - dropout_rate=0.0, + channels=(16, 32, 64, 128, 256), + strides=(2, 2, 2, 2), + kernel_size=3, + up_kernel_size=3, + num_res_units=0, + act='PRELU', + norm=Norm.BATCH, + dropout=0.0, + bias=True, + adn_ordering='NDA', ).to(device) - loss_function = DiceLoss(softmax=True) - torch.backends.cudnn.benchmark = True + loss_function = DiceLoss(to_onehot_y=True, softmax=True) optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-5) + dice_metric = DiceMetric(include_background=False, reduction="mean") + torch.backends.cudnn.benchmark = True + + # initialize wandb + wandb.init(project=config["experiment_name"], config=config) + + # 🐝 Log gen gradients of the models to wandb + wandb.watch(model, log_freq=100) + + # 🐝 Add training script as an artifact + artifact_script = wandb.Artifact(name='training', type='file') + artifact_script.add_file(local_path=os.path.abspath(__file__), name=os.path.basename(__file__)) + wandb.log_artifact(artifact_script) - # We then train the model - post_label = AsDiscrete(to_onehot=1) - post_pred = AsDiscrete(argmax=True, to_onehot=1) - dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False) - global_step = 0 - dice_val_best = 0.0 - global_step_best = 0 epoch_loss_values = [] - metric_values = [] - while global_step < config["max_iterations"]: - global_step, dice_val_best, global_step_best = train(model, config, global_step, train_loader, dice_val_best, global_step_best, - loss_function, optimizer, epoch_loss_values, metric_values, val_loader, post_label, post_pred, dice_metric) - model.load_state_dict(torch.load(config["best_model_path"])) + step_loss_values = [] + val_loss_values = [] + best_val_loss = 1000.0 + + for epoch in range(config["max_iterations"]): + logger.info("-" * 10) + logger.info(f"epoch {epoch + 1}/{config['max_iterations']}") + model.train() + epoch_loss = 0 + epoch_cl_loss = 0 + epoch_recon_loss = 0 + step = 0 + + for batch_data in train_loader: + step += 1 + start_time = time.time() + + inputs, gt_input = ( + batch_data["image"].to(device), + batch_data["label"].to(device), + ) + + optimizer.zero_grad() + output = model(inputs) - print(f"train completed, best_metric: {dice_val_best:.4f} " f"at iteration: {global_step_best}") + loss = loss_function(output, gt_input) + loss.detach().cpu() + + loss.backward() + optimizer.step() + epoch_loss += loss.item() + step_loss_values.append(loss.item()) + + + end_time = time.time() + logger.info( + f"{step}/{len(train_list) // train_loader.batch_size}, " + f"train_loss: {loss.item():.4f}, " + f"time taken: {end_time-start_time}s" + ) + + wandb.log({"Training/loss": loss.item()}) + + epoch_loss /= step + + epoch_loss_values.append(epoch_loss) + + # 🐝 log train_loss averaged over epoch to wandb + wandb.log({"Training/loss_epoch": epoch_loss}) + + + + logger.info(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}") + + if epoch % config["eval_num"] == 0: + logger.info("Entering Validation for epoch: {}".format(epoch + 1)) + total_val_loss = 0 + val_step = 0 + model.eval() + for val_batch in val_loader: + val_step += 1 + start_time = time.time() + inputs, gt_input = ( + val_batch["image"].to(device), + val_batch["label"].to(device), + ) + outputs = model(inputs) + val_loss = loss_function(outputs, gt_input) + total_val_loss += val_loss.item() + end_time = time.time() + + total_val_loss /= val_step + val_loss_values.append(total_val_loss) + + wandb.log({"Validation loss": total_val_loss}) + + logger.info(f"epoch {epoch + 1} Validation avg loss: {total_val_loss:.4f}, " f"time taken: {end_time-start_time}s") + + if total_val_loss < best_val_loss: + logger.info(f"Saving new model based on validation loss {total_val_loss:.4f}") + best_val_loss = total_val_loss + checkpoint = {"epoch": config["max_iterations"], "state_dict": model.state_dict(), "optimizer": optimizer.state_dict()} + torch.save(checkpoint, config["best_model_path"]) + + print("Done") + + + # # We then train the model + # post_label = AsDiscrete(to_onehot=2) + # post_pred = AsDiscrete(argmax=True, to_onehot=2) + # dice_metric = DiceMetric(include_background=False, reduction="mean", get_not_nans=False) + # global_step = 0 + # dice_val_best = 0.0 + # global_step_best = 0 + # epoch_loss_values = [] + # metric_values = [] + # while global_step < config["max_iterations"]: + # global_step, dice_val_best, global_step_best = train(model, config, global_step, train_loader, dice_val_best, global_step_best, + # loss_function, optimizer, epoch_loss_values, metric_values, val_loader, post_label, post_pred, dice_metric) + # model.load_state_dict(torch.load(config["best_model_path"])) + + # print(f"train completed, best_metric: {dice_val_best:.4f} " f"at iteration: {global_step_best}") if __name__ == "__main__": From a4d7958168c9ddc93dcb32a6c8f4b57a7423ca45 Mon Sep 17 00:00:00 2001 From: Pierre-Louis Benveniste Date: Thu, 14 Mar 2024 18:45:28 -0400 Subject: [PATCH 021/108] pytorch lightning script based on Naga's work --- monai/nnunet/losses.py | 85 +++ monai/nnunet/train_monai_unet_lightning.py | 588 +++++++++++++++++++++ monai/nnunet/utils.py | 20 + 3 files changed, 693 insertions(+) create mode 100644 monai/nnunet/losses.py create mode 100644 monai/nnunet/train_monai_unet_lightning.py create mode 100644 monai/nnunet/utils.py diff --git a/monai/nnunet/losses.py b/monai/nnunet/losses.py new file mode 100644 index 0000000..fb4ddfd --- /dev/null +++ b/monai/nnunet/losses.py @@ -0,0 +1,85 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import scipy +import numpy as np + + +class AdapWingLoss(nn.Module): + """ + Adaptive Wing loss used for heatmap regression + Adapted from: https://github.com/ivadomed/ivadomed/blob/master/ivadomed/losses.py#L341 + + .. seealso:: + Wang, Xinyao, Liefeng Bo, and Li Fuxin. "Adaptive wing loss for robust face alignment via heatmap regression." + Proceedings of the IEEE International Conference on Computer Vision. 2019. + + Args: + theta (float): Threshold to switch between the linear and non-linear parts of the piece-wise loss function. + alpha (float): Used to adapt the behaviour of the loss function at y=0 and y=1 and make loss smooth at 0 (background). + It needs to be slightly above 2 to maintain ideal properties. + omega (float): Multiplicative factor for non linear part of the loss. + epsilon (float): factor to avoid gradient explosion. It must not be too small + NOTE: Larger omega and smaller epsilon values will increase the influence on small errors and vice versa + """ + + def __init__(self, theta=0.5, alpha=2.1, omega=14, epsilon=1, reduction='sum'): + self.theta = theta + self.alpha = alpha + self.omega = omega + self.epsilon = epsilon + self.reduction = reduction + super(AdapWingLoss, self).__init__() + + def forward(self, input, target): + eps = self.epsilon + batch_size = target.size()[0] + + # Adaptive Wing loss. Section 4.2 of the paper. + # Compute adaptive factor + A = self.omega * (1 / (1 + torch.pow(self.theta / eps, + self.alpha - target))) * \ + (self.alpha - target) * torch.pow(self.theta / eps, + self.alpha - target - 1) * (1 / eps) + + # Constant term to link linear and non linear part + C = (self.theta * A - self.omega * torch.log(1 + torch.pow(self.theta / eps, self.alpha - target))) + + diff_hm = torch.abs(target - input) + AWingLoss = A * diff_hm - C + idx = diff_hm < self.theta + # NOTE: this is a memory-efficient version than the one in ivadomed losses.py + # where idx is True, compute the non-linear part of the loss, otherwise keep the linear part + # the non-linear parts ensures small errors (as given by idx) have a larger influence to refine the predictions at the boundaries + # the linear part makes the loss function behave more like the MSE loss, which has a linear influence + # (i.e. small errors where y=0 --> small influence --> small gradients) + AWingLoss = torch.where(idx, self.omega * torch.log(1 + torch.pow(diff_hm / eps, self.alpha - target)), AWingLoss) + + + # Mask for weighting the loss function. Section 4.3 of the paper. + mask = torch.zeros_like(target) + kernel = scipy.ndimage.generate_binary_structure(2, 2) + # For 3D segmentation tasks + if len(input.shape) == 5: + kernel = scipy.ndimage.generate_binary_structure(3, 2) + + for i in range(batch_size): + img_list = list() + img_list.append(np.round(target[i].cpu().numpy() * 255)) + img_merge = np.concatenate(img_list) + img_dilate = scipy.ndimage.binary_opening(img_merge, np.expand_dims(kernel, axis=0)) + # NOTE: why 51? the paper thresholds the dilated GT heatmap at 0.2. So, 51/255 = 0.2 + img_dilate[img_dilate < 51] = 1 # 0*omega+1 + img_dilate[img_dilate >= 51] = 1 + self.omega # 1*omega+1 + img_dilate = np.array(img_dilate, dtype=int) + + mask[i] = torch.tensor(img_dilate) + + AWingLoss *= mask + + sum_loss = torch.sum(AWingLoss) + if self.reduction == "sum": + return sum_loss + elif self.reduction == "mean": + all_pixel = torch.sum(mask) + return sum_loss / all_pixel \ No newline at end of file diff --git a/monai/nnunet/train_monai_unet_lightning.py b/monai/nnunet/train_monai_unet_lightning.py new file mode 100644 index 0000000..6f4e038 --- /dev/null +++ b/monai/nnunet/train_monai_unet_lightning.py @@ -0,0 +1,588 @@ +import os +import argparse +from datetime import datetime +from loguru import logger +import yaml + +import numpy as np +import wandb +import torch +import pytorch_lightning as pl +import torch.nn.functional as F +import matplotlib.pyplot as plt +from monai.metrics import DiceMetric + +# Added this to solve problem with too many files open +## Link here : https://github.com/pytorch/pytorch/issues/11201#issuecomment-421146936 +import torch.multiprocessing +torch.multiprocessing.set_sharing_strategy('file_system') + +from losses import AdapWingLoss + +from utils import dice_score, check_empty_patch +from monai.networks.nets import UNet + +from monai.networks.layers import Norm + + +from monai.transforms import ( + EnsureChannelFirstd, + Compose, + LoadImaged, + Orientationd, + RandFlipd, + RandShiftIntensityd, + Spacingd, + RandRotate90d, + NormalizeIntensityd, + RandCropByPosNegLabeld, + BatchInverseTransform, + RandAdjustContrastd, + AsDiscreted, + RandHistogramShiftd, + ResizeWithPadOrCropd, + EnsureTyped + ) + +from monai.utils import set_determinism +from monai.inferers import sliding_window_inference +from monai.data import (DataLoader, CacheDataset, load_decathlon_datalist, decollate_batch) +from monai.transforms import (Compose, EnsureType, EnsureTyped, Invertd, SaveImage) + + +def get_parser(): + """ + This function returns the parser for the command line arguments. + """ + parser = argparse.ArgumentParser(description="Train a nnUNet model using monai") + parser.add_argument("-c", "--config", help="Path to the config file (.yml file)", required=True) + return parser + + +# create a "model"-agnostic class with PL to use different models +class Model(pl.LightningModule): + def __init__(self, config, data_root, net, loss_function, optimizer_class, exp_id=None, results_path=None): + super().__init__() + self.cfg = config + self.save_hyperparameters(ignore=['net', 'loss_function']) + + self.root = data_root + self.net = net + self.lr = config["lr"] + self.loss_function = loss_function + self.optimizer_class = optimizer_class + self.save_exp_id = exp_id + self.results_path = results_path + + self.best_val_dice, self.best_val_epoch = 0, 0 + self.best_val_loss = float("inf") + + # define cropping and padding dimensions + # NOTE about patch sizes: nnUNet defines patches using the median size of the dataset as the reference + # BUT, for SC images, this means a lot of context outside the spinal cord is included in the patches + # which could be sub-optimal. + # On the other hand, ivadomed used a patch-size that's heavily padded along the R-L direction so that + # only the SC is in context. + self.spacing = config["spatial_size"] + self.voxel_cropping_size = self.inference_roi_size = config["spatial_size"] + + # define post-processing transforms for validation, nothing fancy just making sure that it's a tensor (default) + self.val_post_pred = self.val_post_label = Compose([EnsureType()]) + + # define evaluation metric + self.soft_dice_metric = dice_score + + # temp lists for storing outputs from training, validation, and testing + self.train_step_outputs = [] + self.val_step_outputs = [] + self.test_step_outputs = [] + + + # -------------------------------- + # FORWARD PASS + # -------------------------------- + def forward(self, x): + + out = self.net(x) + # # NOTE: MONAI's models only output the logits, not the output after the final activation function + # # https://docs.monai.io/en/0.9.0/_modules/monai/networks/nets/unetr.html#UNETR.forward refers to the + # # UnetOutBlock (https://docs.monai.io/en/0.9.0/_modules/monai/networks/blocks/dynunet_block.html#UnetOutBlock) + # # as the final block applied to the input, which is just a convolutional layer with no activation function + # # Hence, we are used Normalized ReLU to normalize the logits to the final output + # normalized_out = F.relu(out) / F.relu(out).max() if bool(F.relu(out).max()) else F.relu(out) + + return out # returns logits + + + # -------------------------------- + # DATA PREPARATION + # -------------------------------- + def prepare_data(self): + # set deterministic training for reproducibility + set_determinism(seed=self.cfg["seed"]) + + # define training and validation transforms + train_transforms = Compose( + [ + LoadImaged(keys=["image", "label"], reader="NibabelReader"), + EnsureChannelFirstd(keys=["image", "label"]), + Orientationd(keys=["image", "label"], axcodes="RSP"), + Spacingd( + keys=["image", "label"], + pixdim=self.cfg["pixdim"], + mode=(2, 1), + ), + ResizeWithPadOrCropd(keys=["image", "label"], spatial_size=self.cfg["spatial_size"],), + RandCropByPosNegLabeld( + keys=["image", "label"], + label_key="label", + spatial_size=self.cfg["spatial_size"], + pos=1, + neg=1, + num_samples=4, + image_key="image", + image_threshold=0, + ), + # Flips the image : left becomes right + RandFlipd( + keys=["image", "label"], + spatial_axis=[0], + prob=0.2, + ), + # Flips the image : supperior becomes inferior + RandFlipd( + keys=["image", "label"], + spatial_axis=[1], + prob=0.2, + ), + # Flips the image : anterior becomes posterior + RandFlipd( + keys=["image", "label"], + spatial_axis=[2], + prob=0.2, + ), + RandAdjustContrastd( + keys=["image"], + prob=0.2, + gamma=(0.5, 4.5), + invert_image=True, + ), + NormalizeIntensityd( + keys=["image", "label"], + nonzero=False, + channel_wise=False + ), + EnsureTyped(keys=["image", "label"]), + AsDiscreted( + keys=["label"], + num_classes=2, + threshold_values=True, + logit_thresh=0.2, + ) + ] + ) + val_transforms = Compose( + [ + LoadImaged(keys=["image", "label"], reader="NibabelReader"), + EnsureChannelFirstd(keys=["image", "label"]), + Orientationd(keys=["image", "label"], axcodes="RSP"), + Spacingd( + keys=["image", "label"], + pixdim=self.cfg["pixdim"], + mode=(2, 1), + ), + ResizeWithPadOrCropd(keys=["image", "label"], spatial_size=self.cfg["spatial_size"],), + RandCropByPosNegLabeld( + keys=["image", "label"], + label_key="label", + spatial_size=self.cfg["spatial_size"], + pos=1, + neg=1, + num_samples=4, + image_key="image", + image_threshold=0, + ), + NormalizeIntensityd( + keys=["image", "label"], + nonzero=False, + channel_wise=False + ), + EnsureTyped(keys=["image", "label"]), + AsDiscreted( + keys=["label"], + num_classes=2, + threshold_values=True, + logit_thresh=0.2, + ) + ] + ) + + # load the dataset + dataset = self.cfg["data"] + logger.info(f"Loading dataset: {dataset}") + train_files = load_decathlon_datalist(dataset, True, "train") + val_files = load_decathlon_datalist(dataset, True, "validation") + test_files = load_decathlon_datalist(dataset, True, "test") + + train_cache_rate = 0.5 + self.train_ds = CacheDataset(data=train_files, transform=train_transforms, cache_rate=train_cache_rate, num_workers=4) + self.val_ds = CacheDataset(data=val_files, transform=val_transforms, cache_rate=0.25, num_workers=4) + + # define test transforms + transforms_test = val_transforms + + # define post-processing transforms for testing; taken (with explanations) from + # https://github.com/Project-MONAI/tutorials/blob/main/3d_segmentation/torch/unet_inference_dict.py#L66 + self.test_post_pred = Compose([ + EnsureTyped(keys=["pred", "label"]), + Invertd(keys=["pred", "label"], transform=transforms_test, + orig_keys=["image", "label"], + meta_keys=["pred_meta_dict", "label_meta_dict"], + nearest_interp=False, to_tensor=True), + ]) + self.test_ds = CacheDataset(data=test_files, transform=transforms_test, cache_rate=0.1, num_workers=4) + + + # -------------------------------- + # DATA LOADERS + # -------------------------------- + def train_dataloader(self): + return DataLoader(self.train_ds, batch_size=self.cfg["batch_size"], shuffle=True, num_workers=4, + pin_memory=True, persistent_workers=True) + + def val_dataloader(self): + return DataLoader(self.val_ds, batch_size=1, shuffle=False, num_workers=4, pin_memory=True, + persistent_workers=True) + + def test_dataloader(self): + return DataLoader(self.test_ds, batch_size=1, shuffle=False, num_workers=1, pin_memory=True) + + + # -------------------------------- + # OPTIMIZATION + # -------------------------------- + def configure_optimizers(self): + optimizer = torch.optim.AdamW(self.parameters(), lr=self.cfg["lr"], weight_decay=self.cfg["weight_decay"]) + scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=self.cfg["max_iterations"]) + return [optimizer], [scheduler] + + + # -------------------------------- + # TRAINING + # -------------------------------- + def training_step(self, batch, batch_idx): + + inputs, labels = batch["image"], batch["label"] + + # check if any label image patch is empty in the batch + if check_empty_patch(labels) is None: + # print(f"Empty label patch found. Skipping training step ...") + return None + + output = self.forward(inputs) # logits + # print(f"labels.shape: {labels.shape} \t output.shape: {output.shape}") + + # calculate training loss + loss = self.loss_function(output, labels) + + # get probabilities from logits + output = F.relu(output) / F.relu(output).max() if bool(F.relu(output).max()) else F.relu(output) + + # calculate train dice + # NOTE: this is done on patches (and not entire 3D volume) because SlidingWindowInference is not used here + # So, take this dice score with a lot of salt + train_soft_dice = self.soft_dice_metric(output, labels) + + metrics_dict = { + "loss": loss.cpu(), + "train_soft_dice": train_soft_dice.detach().cpu(), + "train_number": len(inputs), + # "train_image": inputs[0].detach().cpu().squeeze(), + # "train_gt": labels[0].detach().cpu().squeeze(), + # "train_pred": output[0].detach().cpu().squeeze() + } + self.train_step_outputs.append(metrics_dict) + + return metrics_dict + + def on_train_epoch_end(self): + + if self.train_step_outputs == []: + # means the training step was skipped because of empty input patch + return None + else: + train_loss, train_soft_dice = 0, 0 + num_items = len(self.train_step_outputs) + for output in self.train_step_outputs: + train_loss += output["loss"].item() + train_soft_dice += output["train_soft_dice"].item() + + mean_train_loss = (train_loss / num_items) + mean_train_soft_dice = (train_soft_dice / num_items) + + wandb_logs = { + "train_soft_dice": mean_train_soft_dice, + "train_loss": mean_train_loss, + } + self.log_dict(wandb_logs) + + # # plot the training images + # fig = plot_slices(image=self.train_step_outputs[0]["train_image"], + # gt=self.train_step_outputs[0]["train_gt"], + # pred=self.train_step_outputs[0]["train_pred"], + # debug=args.debug) + # wandb.log({"training images": wandb.Image(fig)}) + + # free up memory + self.train_step_outputs.clear() + wandb_logs.clear() + # plt.close(fig) + + + # -------------------------------- + # VALIDATION + # -------------------------------- + def validation_step(self, batch, batch_idx): + + inputs, labels = batch["image"], batch["label"] + + # NOTE: this calculates the loss on the entire image after sliding window + outputs = sliding_window_inference(inputs, self.inference_roi_size, mode="gaussian", + sw_batch_size=4, predictor=self.forward, overlap=0.5,) + + # calculate validation loss + loss = self.loss_function(outputs, labels) + + # get probabilities from logits + outputs = F.relu(outputs) / F.relu(outputs).max() if bool(F.relu(outputs).max()) else F.relu(outputs) + + # post-process for calculating the evaluation metric + post_outputs = [self.val_post_pred(i) for i in decollate_batch(outputs)] + post_labels = [self.val_post_label(i) for i in decollate_batch(labels)] + val_soft_dice = self.soft_dice_metric(post_outputs[0], post_labels[0]) + + hard_preds, hard_labels = (post_outputs[0].detach() > 0.5).float(), (post_labels[0].detach() > 0.5).float() + val_hard_dice = self.soft_dice_metric(hard_preds, hard_labels) + + # NOTE: there was a massive memory leak when storing cuda tensors in this dict. Hence, + # using .detach() to avoid storing the whole computation graph + # Ref: https://discuss.pytorch.org/t/cuda-memory-leak-while-training/82855/2 + metrics_dict = { + "val_loss": loss.detach().cpu(), + "val_soft_dice": val_soft_dice.detach().cpu(), + "val_hard_dice": val_hard_dice.detach().cpu(), + "val_number": len(post_outputs), + # "val_image": inputs[0].detach().cpu().squeeze(), + # "val_gt": labels[0].detach().cpu().squeeze(), + # "val_pred": post_outputs[0].detach().cpu().squeeze(), + } + self.val_step_outputs.append(metrics_dict) + + return metrics_dict + + def on_validation_epoch_end(self): + + val_loss, num_items, val_soft_dice, val_hard_dice = 0, 0, 0, 0 + for output in self.val_step_outputs: + val_loss += output["val_loss"].sum().item() + val_soft_dice += output["val_soft_dice"].sum().item() + val_hard_dice += output["val_hard_dice"].sum().item() + num_items += output["val_number"] + + mean_val_loss = (val_loss / num_items) + mean_val_soft_dice = (val_soft_dice / num_items) + mean_val_hard_dice = (val_hard_dice / num_items) + + wandb_logs = { + "val_soft_dice": mean_val_soft_dice, + "val_hard_dice": mean_val_hard_dice, + "val_loss": mean_val_loss, + } + # save the best model based on validation dice score + if mean_val_soft_dice > self.best_val_dice: + self.best_val_dice = mean_val_soft_dice + self.best_val_epoch = self.current_epoch + + # save the best model based on validation loss + if mean_val_loss < self.best_val_loss: + self.best_val_loss = mean_val_loss + self.best_val_epoch = self.current_epoch + + logger.info( + f"\nCurrent epoch: {self.current_epoch}" + f"\nAverage Soft Dice (VAL): {mean_val_soft_dice:.4f}" + f"\nAverage Hard Dice (VAL): {mean_val_hard_dice:.4f}" + f"\nAverage AdapWing Loss (VAL): {mean_val_loss:.4f}" + f"\nBest Average AdapWing Loss: {self.best_val_loss:.4f} at Epoch: {self.best_val_epoch}" + f"\n----------------------------------------------------") + + + # free up memory + self.val_step_outputs.clear() + wandb_logs.clear() + + + # -------------------------------- + # TESTING + # -------------------------------- + def test_step(self, batch, batch_idx): + + test_input = batch["image"] + # print(batch["label_meta_dict"]["filename_or_obj"][0]) + batch["pred"] = sliding_window_inference(test_input, self.inference_roi_size, + sw_batch_size=4, predictor=self.forward, overlap=0.5) + + # normalize the logits + batch["pred"] = F.relu(batch["pred"]) / F.relu(batch["pred"]).max() if bool(F.relu(batch["pred"]).max()) else F.relu(batch["pred"]) + + post_test_out = [self.test_post_pred(i) for i in decollate_batch(batch)] + + # make sure that the shapes of prediction and GT label are the same + # print(f"pred shape: {post_test_out[0]['pred'].shape}, label shape: {post_test_out[0]['label'].shape}") + assert post_test_out[0]['pred'].shape == post_test_out[0]['label'].shape + + pred, label = post_test_out[0]['pred'].cpu(), post_test_out[0]['label'].cpu() + + # NOTE: Important point from the SoftSeg paper - binarize predictions before computing metrics + # calculate soft and hard dice here (for quick overview), other metrics can be computed from + # the saved predictions using ANIMA + # 1. Dice Score + test_soft_dice = self.soft_dice_metric(pred, label) + + # binarizing the predictions + pred = (post_test_out[0]['pred'].detach().cpu() > 0.5).float() + label = (post_test_out[0]['label'].detach().cpu() > 0.5).float() + + # 1.1 Hard Dice Score + test_hard_dice = self.soft_dice_metric(pred.numpy(), label.numpy()) + + metrics_dict = { + "test_hard_dice": test_hard_dice, + "test_soft_dice": test_soft_dice, + } + self.test_step_outputs.append(metrics_dict) + + return metrics_dict + + def on_test_epoch_end(self): + + avg_hard_dice_test, std_hard_dice_test = np.stack([x["test_hard_dice"] for x in self.test_step_outputs]).mean(), \ + np.stack([x["test_hard_dice"] for x in self.test_step_outputs]).std() + avg_soft_dice_test, std_soft_dice_test = np.stack([x["test_soft_dice"] for x in self.test_step_outputs]).mean(), \ + np.stack([x["test_soft_dice"] for x in self.test_step_outputs]).std() + + logger.info(f"Test (Soft) Dice: {avg_soft_dice_test}") + logger.info(f"Test (Hard) Dice: {avg_hard_dice_test}") + + self.avg_test_dice, self.std_test_dice = avg_soft_dice_test, std_soft_dice_test + self.avg_test_dice_hard, self.std_test_dice_hard = avg_hard_dice_test, std_hard_dice_test + + # free up memory + self.test_step_outputs.clear() + +# -------------------------------- +# MAIN +# -------------------------------- +def main(): + # get the parser + parser = get_parser() + args= parser.parse_args() + + # load config file + with open(args.config, "r") as f: + config = yaml.load(f, Loader=yaml.FullLoader) + + # Setting the seed + pl.seed_everything(config["seed"], workers=True) + + # define root path for finding datalists + dataset_root = config["data"] + + # define optimizer + optimizer_class = torch.optim.Adam + + wandb.init() + + logger.info("Defining plans for nnUNet model ...") + + + # define model + net = UNet( + spatial_dims=3, + in_channels=1, + out_channels=1, + channels=(16, 32, 64, 128, 256), + strides=(2, 2, 2, 2), + kernel_size=3, + up_kernel_size=3, + num_res_units=0, + act='PRELU', + norm=Norm.BATCH, + dropout=0.0, + bias=True, + adn_ordering='NDA', + ) + logger.add(os.path.join(config["log_path"], str(datetime.now()) + 'log.txt'), rotation="10 MB", level="INFO") + + + # define loss function + loss_func = AdapWingLoss(theta=0.5, omega=8, alpha=2.1, epsilon=1, reduction="sum") + # NOTE: tried increasing omega and decreasing epsilon but results marginally worse than the above + # loss_func = AdapWingLoss(theta=0.5, omega=12, alpha=2.1, epsilon=0.5, reduction="sum") + logger.info(f"Using AdapWingLoss with theta={loss_func.theta}, omega={loss_func.omega}, alpha={loss_func.alpha}, epsilon={loss_func.epsilon} ...") + + # define callbacks + early_stopping = pl.callbacks.EarlyStopping( + monitor="val_loss", min_delta=0.00, + patience=config["early_stopping_patience"], + verbose=False, mode="min") + + lr_monitor = pl.callbacks.LearningRateMonitor(logging_interval='epoch') + + # i.e. train by loading weights from scratch + pl_model = Model(config, data_root=dataset_root, + optimizer_class=optimizer_class, loss_function=loss_func, net=net, + exp_id="test", results_path=config["best_model_path"]) + + # saving the best model based on validation loss + checkpoint_callback_loss = pl.callbacks.ModelCheckpoint( + dirpath=config["best_model_path"], filename='best_model', monitor='val_loss', + save_top_k=1, mode="min", save_last=True, save_weights_only=False) + + + logger.info(f"Starting training from scratch ...") + # wandb logger + exp_logger = pl.loggers.WandbLogger( + name="test", + save_dir="/home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/results", + group="test-on-canproco", + log_model=True, # save best model using checkpoint callback + project='ms-lesion-agnostic', + entity='pierre-louis-benveniste', + config=config) + + # Saving training script to wandb + wandb.save("main.py") + + # initialise Lightning's trainer. + trainer = pl.Trainer( + devices=1, accelerator="gpu", + logger=exp_logger, + callbacks=[checkpoint_callback_loss, lr_monitor, early_stopping], + check_val_every_n_epoch=config["eval_num"], + max_epochs=config["max_iterations"], + precision="bf16-mixed", + # deterministic=True, + enable_progress_bar=True) + # profiler="simple",) # to profile the training time taken for each step + + # Train! + trainer.fit(pl_model) + logger.info(f" Training Done!") + + # Closing wandb log + wandb.finish() + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/monai/nnunet/utils.py b/monai/nnunet/utils.py new file mode 100644 index 0000000..fdebfd2 --- /dev/null +++ b/monai/nnunet/utils.py @@ -0,0 +1,20 @@ +import numpy as np +import matplotlib.pyplot as plt +from torch.optim.lr_scheduler import _LRScheduler +import torch + +def dice_score(prediction, groundtruth): + smooth = 1. + numer = (prediction * groundtruth).sum() + denor = (prediction + groundtruth).sum() + # loss = (2 * numer + self.smooth) / (denor + self.smooth) + dice = (2 * numer + smooth) / (denor + smooth) + return dice + +# Check if any label image patch is empty in the batch +def check_empty_patch(labels): + for i, label in enumerate(labels): + if torch.sum(label) == 0.0: + # print(f"Empty label patch found at index {i}. Skipping training step ...") + return None + return labels # If no empty patch is found, return the labels \ No newline at end of file From 4e768e4d05f90879ed2e32859f3231cfcc394aec Mon Sep 17 00:00:00 2001 From: Pierre-Louis Benveniste Date: Thu, 14 Mar 2024 18:46:36 -0400 Subject: [PATCH 022/108] modified requirements for pytorch lightning --- monai/nnunet/requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/monai/nnunet/requirements.txt b/monai/nnunet/requirements.txt index e61d982..cb8f988 100644 --- a/monai/nnunet/requirements.txt +++ b/monai/nnunet/requirements.txt @@ -3,3 +3,4 @@ monai[all] torch torchvision matplotlib +pytorch_lightning \ No newline at end of file From 186c6025cd528dc5e025fa362b466c0d03581e23 Mon Sep 17 00:00:00 2001 From: Pierre-Louis Benveniste Date: Mon, 18 Mar 2024 11:31:47 -0400 Subject: [PATCH 023/108] added multiply by -1 transform --- monai/nnunet/train_monai_unet_lightning.py | 34 +++++++++++++++------- monai/nnunet/utils.py | 7 ++++- 2 files changed, 29 insertions(+), 12 deletions(-) diff --git a/monai/nnunet/train_monai_unet_lightning.py b/monai/nnunet/train_monai_unet_lightning.py index 6f4e038..8f8773c 100644 --- a/monai/nnunet/train_monai_unet_lightning.py +++ b/monai/nnunet/train_monai_unet_lightning.py @@ -11,6 +11,7 @@ import torch.nn.functional as F import matplotlib.pyplot as plt from monai.metrics import DiceMetric +from monai.losses import DiceLoss # Added this to solve problem with too many files open ## Link here : https://github.com/pytorch/pytorch/issues/11201#issuecomment-421146936 @@ -19,7 +20,7 @@ from losses import AdapWingLoss -from utils import dice_score, check_empty_patch +from utils import dice_score, check_empty_patch, multiply_by_negative_one from monai.networks.nets import UNet from monai.networks.layers import Norm @@ -41,7 +42,8 @@ AsDiscreted, RandHistogramShiftd, ResizeWithPadOrCropd, - EnsureTyped + EnsureTyped, + RandLambdad, ) from monai.utils import set_determinism @@ -167,6 +169,12 @@ def prepare_data(self): gamma=(0.5, 4.5), invert_image=True, ), + # we add the multiplication of the image by -1 + RandLambdad( + keys='image', + func=multiply_by_negative_one, + prob=0.5 + ), NormalizeIntensityd( keys=["image", "label"], nonzero=False, @@ -225,8 +233,8 @@ def prepare_data(self): test_files = load_decathlon_datalist(dataset, True, "test") train_cache_rate = 0.5 - self.train_ds = CacheDataset(data=train_files, transform=train_transforms, cache_rate=train_cache_rate, num_workers=4) - self.val_ds = CacheDataset(data=val_files, transform=val_transforms, cache_rate=0.25, num_workers=4) + self.train_ds = CacheDataset(data=train_files, transform=train_transforms, cache_rate=train_cache_rate, num_workers=8) + self.val_ds = CacheDataset(data=val_files, transform=val_transforms, cache_rate=0.25, num_workers=8) # define test transforms transforms_test = val_transforms @@ -247,11 +255,11 @@ def prepare_data(self): # DATA LOADERS # -------------------------------- def train_dataloader(self): - return DataLoader(self.train_ds, batch_size=self.cfg["batch_size"], shuffle=True, num_workers=4, + return DataLoader(self.train_ds, batch_size=self.cfg["batch_size"], shuffle=True, num_workers=8, pin_memory=True, persistent_workers=True) def val_dataloader(self): - return DataLoader(self.val_ds, batch_size=1, shuffle=False, num_workers=4, pin_memory=True, + return DataLoader(self.val_ds, batch_size=1, shuffle=False, num_workers=8, pin_memory=True, persistent_workers=True) def test_dataloader(self): @@ -398,6 +406,9 @@ def on_validation_epoch_end(self): "val_hard_dice": mean_val_hard_dice, "val_loss": mean_val_loss, } + + self.log_dict(wandb_logs) + # save the best model based on validation dice score if mean_val_soft_dice > self.best_val_dice: self.best_val_dice = mean_val_soft_dice @@ -412,8 +423,8 @@ def on_validation_epoch_end(self): f"\nCurrent epoch: {self.current_epoch}" f"\nAverage Soft Dice (VAL): {mean_val_soft_dice:.4f}" f"\nAverage Hard Dice (VAL): {mean_val_hard_dice:.4f}" - f"\nAverage AdapWing Loss (VAL): {mean_val_loss:.4f}" - f"\nBest Average AdapWing Loss: {self.best_val_loss:.4f} at Epoch: {self.best_val_epoch}" + f"\nAverage DiceLoss (VAL): {mean_val_loss:.4f}" + f"\nBest Average DiceLoss: {self.best_val_loss:.4f} at Epoch: {self.best_val_epoch}" f"\n----------------------------------------------------") @@ -526,11 +537,12 @@ def main(): # define loss function - loss_func = AdapWingLoss(theta=0.5, omega=8, alpha=2.1, epsilon=1, reduction="sum") + #loss_func = AdapWingLoss(theta=0.5, omega=8, alpha=2.1, epsilon=1, reduction="sum") + loss_func = DiceLoss(sigmoid=True, smooth_dr=1e-4) # NOTE: tried increasing omega and decreasing epsilon but results marginally worse than the above # loss_func = AdapWingLoss(theta=0.5, omega=12, alpha=2.1, epsilon=0.5, reduction="sum") - logger.info(f"Using AdapWingLoss with theta={loss_func.theta}, omega={loss_func.omega}, alpha={loss_func.alpha}, epsilon={loss_func.epsilon} ...") - + #logger.info(f"Using AdapWingLoss with theta={loss_func.theta}, omega={loss_func.omega}, alpha={loss_func.alpha}, epsilon={loss_func.epsilon} ...") + logger.info(f"Using DiceLoss ...") # define callbacks early_stopping = pl.callbacks.EarlyStopping( monitor="val_loss", min_delta=0.00, diff --git a/monai/nnunet/utils.py b/monai/nnunet/utils.py index fdebfd2..32b1d29 100644 --- a/monai/nnunet/utils.py +++ b/monai/nnunet/utils.py @@ -17,4 +17,9 @@ def check_empty_patch(labels): if torch.sum(label) == 0.0: # print(f"Empty label patch found at index {i}. Skipping training step ...") return None - return labels # If no empty patch is found, return the labels \ No newline at end of file + return labels # If no empty patch is found, return the labels + +# Function to multiply by -1 +def multiply_by_negative_one(x): + print(f"Multiplyings by -1") + return x * -1 \ No newline at end of file From adaf04fe0ed6cb64eba5bb18404d5bb51fc8a86f Mon Sep 17 00:00:00 2001 From: Pierre-Louis Benveniste Date: Mon, 18 Mar 2024 11:33:23 -0400 Subject: [PATCH 024/108] parameters changed for config file for first inference run --- monai/nnunet/config_fake.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/monai/nnunet/config_fake.yml b/monai/nnunet/config_fake.yml index 3ec2549..bb27017 100644 --- a/monai/nnunet/config_fake.yml +++ b/monai/nnunet/config_fake.yml @@ -1,14 +1,14 @@ # Description: Configuration file for the UNETR model # Path to the data json file -#data: /home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/data/monai_data/fake.json +# data: /home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/data/monai_data/fake.json data: /home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/data/monai_data/dataset_2024-03-13_seed42_canproco.json # Path to the output directory output_dir : /home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/results/ # Resampling resolution pixdim : [1.0, 1.0, 1.0] # Spatial size of the input data -spatial_size : [16, 176, 176] -batch_size : 4 +spatial_size : [32, 208, 208] +batch_size : 16 # UNETR model parameters feature_size : 8 From dab7dccfa7980448a30f08194e33c13a4e6002af Mon Sep 17 00:00:00 2001 From: Pierre-Louis Benveniste Date: Thu, 28 Mar 2024 14:23:33 -0400 Subject: [PATCH 025/108] added SoftDiceLoss --- monai/nnunet/losses.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/monai/nnunet/losses.py b/monai/nnunet/losses.py index fb4ddfd..7449032 100644 --- a/monai/nnunet/losses.py +++ b/monai/nnunet/losses.py @@ -5,6 +5,32 @@ import numpy as np +class SoftDiceLoss(nn.Module): + ''' + soft-dice loss, useful in binary segmentation + taken from: https://github.com/CoinCheung/pytorch-loss/blob/master/soft_dice_loss.py + ''' + def __init__(self, p=1, smooth=1e-5): + super(SoftDiceLoss, self).__init__() + self.p = p + self.smooth = smooth + + def forward(self, logits, labels): + ''' + inputs: + preds: logits - tensor of shape (N, H, W, ...) + labels: soft labels [0,1] - tensor of shape(N, H, W, ...) + output: + loss: tensor of shape(1, ) + ''' + preds = logits # F.relu(logits) / F.relu(logits).max() if bool(F.relu(logits).max()) else F.relu(logits) + + numer = (preds * labels).sum() + denor = (preds.pow(self.p) + labels.pow(self.p)).sum() + # loss = 1. - (2 * numer + self.smooth) / (denor + self.smooth) + loss = - (2 * numer + self.smooth) / (denor + self.smooth) + return loss + class AdapWingLoss(nn.Module): """ Adaptive Wing loss used for heatmap regression From 80494e55146cf0935c91ac9ea60a57a3be19007a Mon Sep 17 00:00:00 2001 From: Pierre-Louis Benveniste Date: Thu, 28 Mar 2024 14:24:27 -0400 Subject: [PATCH 026/108] removed print from inverse function in utils --- monai/nnunet/utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/monai/nnunet/utils.py b/monai/nnunet/utils.py index 32b1d29..7d8cbc6 100644 --- a/monai/nnunet/utils.py +++ b/monai/nnunet/utils.py @@ -21,5 +21,4 @@ def check_empty_patch(labels): # Function to multiply by -1 def multiply_by_negative_one(x): - print(f"Multiplyings by -1") return x * -1 \ No newline at end of file From 4727dbc63735b8347ccf69524d7b6517cae21f40 Mon Sep 17 00:00:00 2001 From: Pierre-Louis Benveniste Date: Thu, 28 Mar 2024 16:56:46 -0400 Subject: [PATCH 027/108] changed resolution to 0.6 isotropic --- monai/nnunet/config_fake.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/nnunet/config_fake.yml b/monai/nnunet/config_fake.yml index bb27017..7c14831 100644 --- a/monai/nnunet/config_fake.yml +++ b/monai/nnunet/config_fake.yml @@ -5,7 +5,7 @@ data: /home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/data/monai_data/dataset # Path to the output directory output_dir : /home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/results/ # Resampling resolution -pixdim : [1.0, 1.0, 1.0] +pixdim : [0.6, 0.6, 0.6] # Spatial size of the input data spatial_size : [32, 208, 208] batch_size : 16 From 5992732e654564c95dafa52d88be43ed0fe7d9d7 Mon Sep 17 00:00:00 2001 From: Pierre-Louis Benveniste Date: Thu, 28 Mar 2024 16:57:13 -0400 Subject: [PATCH 028/108] added plot images function for wandb --- monai/nnunet/utils.py | 35 ++++++++++++++++++++++++++++++++++- 1 file changed, 34 insertions(+), 1 deletion(-) diff --git a/monai/nnunet/utils.py b/monai/nnunet/utils.py index 7d8cbc6..7fb9cbc 100644 --- a/monai/nnunet/utils.py +++ b/monai/nnunet/utils.py @@ -21,4 +21,37 @@ def check_empty_patch(labels): # Function to multiply by -1 def multiply_by_negative_one(x): - return x * -1 \ No newline at end of file + return x * -1 + +def plot_slices(image, gt, pred, debug=False): + """ + Plot the image, ground truth and prediction of the mid-sagittal axial slice + The orientaion is assumed to RPI + """ + + # bring everything to numpy + image = image.numpy() + gt = gt.numpy() + pred = pred.numpy() + + + mid_sagittal = image.shape[2]//2 + # plot X slices before and after the mid-sagittal slice in a grid + fig, axs = plt.subplots(3, 6, figsize=(10, 6)) + fig.suptitle('Original Image --> Ground Truth --> Prediction') + for i in range(6): + axs[0, i].imshow(image[:, :, mid_sagittal-3+i].T, cmap='gray'); axs[0, i].axis('off') + axs[1, i].imshow(gt[:, :, mid_sagittal-3+i].T); axs[1, i].axis('off') + axs[2, i].imshow(pred[:, :, mid_sagittal-3+i].T); axs[2, i].axis('off') + + # fig, axs = plt.subplots(1, 3, figsize=(10, 8)) + # fig.suptitle('Original Image --> Ground Truth --> Prediction') + # slice = image.shape[2]//2 + + # axs[0].imshow(image[:, :, slice].T, cmap='gray'); axs[0].axis('off') + # axs[1].imshow(gt[:, :, slice].T); axs[1].axis('off') + # axs[2].imshow(pred[:, :, slice].T); axs[2].axis('off') + + plt.tight_layout() + fig.show() + return fig \ No newline at end of file From 9aefb6d2e331133c951a124c920f8d553a9e665c Mon Sep 17 00:00:00 2001 From: Pierre-Louis Benveniste Date: Thu, 28 Mar 2024 16:58:22 -0400 Subject: [PATCH 029/108] changed loss function and added image printing --- monai/nnunet/train_monai_unet_lightning.py | 34 ++++++++++++---------- 1 file changed, 18 insertions(+), 16 deletions(-) diff --git a/monai/nnunet/train_monai_unet_lightning.py b/monai/nnunet/train_monai_unet_lightning.py index 8f8773c..0725f91 100644 --- a/monai/nnunet/train_monai_unet_lightning.py +++ b/monai/nnunet/train_monai_unet_lightning.py @@ -18,9 +18,9 @@ import torch.multiprocessing torch.multiprocessing.set_sharing_strategy('file_system') -from losses import AdapWingLoss +from losses import AdapWingLoss, SoftDiceLoss -from utils import dice_score, check_empty_patch, multiply_by_negative_one +from utils import dice_score, check_empty_patch, multiply_by_negative_one, plot_slices from monai.networks.nets import UNet from monai.networks.layers import Norm @@ -305,9 +305,9 @@ def training_step(self, batch, batch_idx): "loss": loss.cpu(), "train_soft_dice": train_soft_dice.detach().cpu(), "train_number": len(inputs), - # "train_image": inputs[0].detach().cpu().squeeze(), - # "train_gt": labels[0].detach().cpu().squeeze(), - # "train_pred": output[0].detach().cpu().squeeze() + "train_image": inputs[0].detach().cpu().squeeze(), + "train_gt": labels[0].detach().cpu().squeeze(), + "train_pred": output[0].detach().cpu().squeeze() } self.train_step_outputs.append(metrics_dict) @@ -334,12 +334,12 @@ def on_train_epoch_end(self): } self.log_dict(wandb_logs) - # # plot the training images - # fig = plot_slices(image=self.train_step_outputs[0]["train_image"], - # gt=self.train_step_outputs[0]["train_gt"], - # pred=self.train_step_outputs[0]["train_pred"], - # debug=args.debug) - # wandb.log({"training images": wandb.Image(fig)}) + # plot the training images + fig = plot_slices(image=self.train_step_outputs[0]["train_image"], + gt=self.train_step_outputs[0]["train_gt"], + pred=self.train_step_outputs[0]["train_pred"], + ) + wandb.log({"training images": wandb.Image(fig)}) # free up memory self.train_step_outputs.clear() @@ -380,9 +380,9 @@ def validation_step(self, batch, batch_idx): "val_soft_dice": val_soft_dice.detach().cpu(), "val_hard_dice": val_hard_dice.detach().cpu(), "val_number": len(post_outputs), - # "val_image": inputs[0].detach().cpu().squeeze(), - # "val_gt": labels[0].detach().cpu().squeeze(), - # "val_pred": post_outputs[0].detach().cpu().squeeze(), + "val_image": inputs[0].detach().cpu().squeeze(), + "val_gt": labels[0].detach().cpu().squeeze(), + "val_pred": post_outputs[0].detach().cpu().squeeze(), } self.val_step_outputs.append(metrics_dict) @@ -518,12 +518,13 @@ def main(): # define model + # TODO: make the model deeper net = UNet( spatial_dims=3, in_channels=1, out_channels=1, channels=(16, 32, 64, 128, 256), - strides=(2, 2, 2, 2), + strides=(2, 2, 2, 2, 2, 2), kernel_size=3, up_kernel_size=3, num_res_units=0, @@ -538,7 +539,8 @@ def main(): # define loss function #loss_func = AdapWingLoss(theta=0.5, omega=8, alpha=2.1, epsilon=1, reduction="sum") - loss_func = DiceLoss(sigmoid=True, smooth_dr=1e-4) + # loss_func = DiceLoss(sigmoid=True, smooth_dr=1e-4) + loss_func = SoftDiceLoss(smooth=1e-5) # NOTE: tried increasing omega and decreasing epsilon but results marginally worse than the above # loss_func = AdapWingLoss(theta=0.5, omega=12, alpha=2.1, epsilon=0.5, reduction="sum") #logger.info(f"Using AdapWingLoss with theta={loss_func.theta}, omega={loss_func.omega}, alpha={loss_func.alpha}, epsilon={loss_func.epsilon} ...") From 855e26fd02a1882555af9ce6140d7ae896d8ee35 Mon Sep 17 00:00:00 2001 From: Pierre-Louis Benveniste Date: Mon, 1 Apr 2024 12:03:33 -0400 Subject: [PATCH 030/108] changed some parameters for training --- monai/nnunet/train_monai_unet_lightning.py | 239 +++++++++++---------- 1 file changed, 129 insertions(+), 110 deletions(-) diff --git a/monai/nnunet/train_monai_unet_lightning.py b/monai/nnunet/train_monai_unet_lightning.py index 0725f91..53b8014 100644 --- a/monai/nnunet/train_monai_unet_lightning.py +++ b/monai/nnunet/train_monai_unet_lightning.py @@ -125,105 +125,105 @@ def prepare_data(self): # define training and validation transforms train_transforms = Compose( - [ - LoadImaged(keys=["image", "label"], reader="NibabelReader"), - EnsureChannelFirstd(keys=["image", "label"]), - Orientationd(keys=["image", "label"], axcodes="RSP"), - Spacingd( - keys=["image", "label"], - pixdim=self.cfg["pixdim"], - mode=(2, 1), - ), - ResizeWithPadOrCropd(keys=["image", "label"], spatial_size=self.cfg["spatial_size"],), - RandCropByPosNegLabeld( - keys=["image", "label"], - label_key="label", - spatial_size=self.cfg["spatial_size"], - pos=1, - neg=1, - num_samples=4, - image_key="image", - image_threshold=0, - ), - # Flips the image : left becomes right - RandFlipd( - keys=["image", "label"], - spatial_axis=[0], - prob=0.2, - ), - # Flips the image : supperior becomes inferior - RandFlipd( - keys=["image", "label"], - spatial_axis=[1], - prob=0.2, - ), - # Flips the image : anterior becomes posterior - RandFlipd( - keys=["image", "label"], - spatial_axis=[2], - prob=0.2, - ), - RandAdjustContrastd( - keys=["image"], - prob=0.2, - gamma=(0.5, 4.5), - invert_image=True, - ), - # we add the multiplication of the image by -1 - RandLambdad( - keys='image', - func=multiply_by_negative_one, - prob=0.5 + [ + LoadImaged(keys=["image", "label"], reader="NibabelReader"), + EnsureChannelFirstd(keys=["image", "label"]), + Orientationd(keys=["image", "label"], axcodes="RPI"), + Spacingd( + keys=["image", "label"], + pixdim=self.cfg["pixdim"], + mode=(2, 1), ), - NormalizeIntensityd( - keys=["image", "label"], - nonzero=False, - channel_wise=False - ), - EnsureTyped(keys=["image", "label"]), - AsDiscreted( - keys=["label"], - num_classes=2, - threshold_values=True, - logit_thresh=0.2, - ) - ] + ResizeWithPadOrCropd(keys=["image", "label"], spatial_size=self.cfg["spatial_size"],), + RandCropByPosNegLabeld( + keys=["image", "label"], + label_key="label", + spatial_size=self.cfg["spatial_size"], + pos=1, + neg=1, + num_samples=4, + image_key="image", + image_threshold=0, + ), + # Flips the image : left becomes right + RandFlipd( + keys=["image", "label"], + spatial_axis=[0], + prob=0.2, + ), + # Flips the image : supperior becomes inferior + RandFlipd( + keys=["image", "label"], + spatial_axis=[1], + prob=0.2, + ), + # Flips the image : anterior becomes posterior + RandFlipd( + keys=["image", "label"], + spatial_axis=[2], + prob=0.2, + ), + RandAdjustContrastd( + keys=["image"], + prob=0.2, + gamma=(0.5, 4.5), + invert_image=True, + ), + # we add the multiplication of the image by -1 + RandLambdad( + keys='image', + func=multiply_by_negative_one, + prob=0.5 + ), + NormalizeIntensityd( + keys=["image", "label"], + nonzero=False, + channel_wise=False + ), + EnsureTyped(keys=["image", "label"]), + AsDiscreted( + keys=["label"], + num_classes=2, + threshold_values=True, + logit_thresh=0.2, + ) + ] ) val_transforms = Compose( - [ - LoadImaged(keys=["image", "label"], reader="NibabelReader"), - EnsureChannelFirstd(keys=["image", "label"]), - Orientationd(keys=["image", "label"], axcodes="RSP"), - Spacingd( - keys=["image", "label"], - pixdim=self.cfg["pixdim"], - mode=(2, 1), - ), - ResizeWithPadOrCropd(keys=["image", "label"], spatial_size=self.cfg["spatial_size"],), - RandCropByPosNegLabeld( - keys=["image", "label"], - label_key="label", - spatial_size=self.cfg["spatial_size"], - pos=1, - neg=1, - num_samples=4, - image_key="image", - image_threshold=0, - ), - NormalizeIntensityd( - keys=["image", "label"], - nonzero=False, - channel_wise=False - ), - EnsureTyped(keys=["image", "label"]), - AsDiscreted( - keys=["label"], - num_classes=2, - threshold_values=True, - logit_thresh=0.2, - ) - ] - ) + [ + LoadImaged(keys=["image", "label"], reader="NibabelReader"), + EnsureChannelFirstd(keys=["image", "label"]), + Orientationd(keys=["image", "label"], axcodes="RPI"), + Spacingd( + keys=["image", "label"], + pixdim=self.cfg["pixdim"], + mode=(2, 1), + ), + ResizeWithPadOrCropd(keys=["image", "label"], spatial_size=self.cfg["spatial_size"],), + RandCropByPosNegLabeld( + keys=["image", "label"], + label_key="label", + spatial_size=self.cfg["spatial_size"], + pos=1, + neg=1, + num_samples=4, + image_key="image", + image_threshold=0, + ), + NormalizeIntensityd( + keys=["image", "label"], + nonzero=False, + channel_wise=False + ), + EnsureTyped(keys=["image", "label"]), + AsDiscreted( + keys=["label"], + num_classes=2, + threshold_values=True, + logit_thresh=0.2, + ) + ] + ) # load the dataset dataset = self.cfg["data"] @@ -282,10 +282,10 @@ def training_step(self, batch, batch_idx): inputs, labels = batch["image"], batch["label"] - # check if any label image patch is empty in the batch - if check_empty_patch(labels) is None: - # print(f"Empty label patch found. Skipping training step ...") - return None + # # check if any label image patch is empty in the batch + # if check_empty_patch(labels) is None: + # # print(f"Empty label patch found. Skipping training step ...") + # return None output = self.forward(inputs) # logits # print(f"labels.shape: {labels.shape} \t output.shape: {output.shape}") @@ -344,7 +344,7 @@ def on_train_epoch_end(self): # free up memory self.train_step_outputs.clear() wandb_logs.clear() - # plt.close(fig) + plt.close(fig) # -------------------------------- @@ -380,9 +380,12 @@ def validation_step(self, batch, batch_idx): "val_soft_dice": val_soft_dice.detach().cpu(), "val_hard_dice": val_hard_dice.detach().cpu(), "val_number": len(post_outputs), - "val_image": inputs[0].detach().cpu().squeeze(), - "val_gt": labels[0].detach().cpu().squeeze(), - "val_pred": post_outputs[0].detach().cpu().squeeze(), + "val_image_0": inputs[0].detach().cpu().squeeze(), + "val_gt_0": labels[0].detach().cpu().squeeze(), + "val_pred_0": post_outputs[0].detach().cpu().squeeze(), + # "val_image_1": inputs[1].detach().cpu().squeeze(), + # "val_gt_1": labels[1].detach().cpu().squeeze(), + # "val_pred_1": post_outputs[1].detach().cpu().squeeze(), } self.val_step_outputs.append(metrics_dict) @@ -426,7 +429,21 @@ def on_validation_epoch_end(self): f"\nAverage DiceLoss (VAL): {mean_val_loss:.4f}" f"\nBest Average DiceLoss: {self.best_val_loss:.4f} at Epoch: {self.best_val_epoch}" f"\n----------------------------------------------------") - + + # log on to wandb + self.log_dict(wandb_logs) + + # plot the validation images + fig0 = plot_slices(image=self.val_step_outputs[0]["val_image_0"], + gt=self.val_step_outputs[0]["val_gt_0"], + pred=self.val_step_outputs[0]["val_pred_0"],) + wandb.log({"validation images": wandb.Image(fig0)}) + plt.close(fig0) + # fig1 = plot_slices(image=self.val_step_outputs[0]["val_image_1"], + # gt=self.val_step_outputs[0]["val_gt_1"], + # pred=self.val_step_outputs[0]["val_pred_1"],) + # wandb.log({"validation images 1": wandb.Image(fig1)}) + # plt.close(fig1) # free up memory self.val_step_outputs.clear() @@ -512,7 +529,9 @@ def main(): # define optimizer optimizer_class = torch.optim.Adam - wandb.init() + wandb.init(project=f'monai-unet-ms-lesion-seg-canproco', config=config) + + wandb.name = "test123" logger.info("Defining plans for nnUNet model ...") @@ -523,8 +542,8 @@ def main(): spatial_dims=3, in_channels=1, out_channels=1, - channels=(16, 32, 64, 128, 256), - strides=(2, 2, 2, 2, 2, 2), + channels=config['unet_channels'], + strides=config['unet_strides'], kernel_size=3, up_kernel_size=3, num_res_units=0, @@ -544,7 +563,7 @@ def main(): # NOTE: tried increasing omega and decreasing epsilon but results marginally worse than the above # loss_func = AdapWingLoss(theta=0.5, omega=12, alpha=2.1, epsilon=0.5, reduction="sum") #logger.info(f"Using AdapWingLoss with theta={loss_func.theta}, omega={loss_func.omega}, alpha={loss_func.alpha}, epsilon={loss_func.epsilon} ...") - logger.info(f"Using DiceLoss ...") + logger.info(f"Using SoftDiceLoss ...") # define callbacks early_stopping = pl.callbacks.EarlyStopping( monitor="val_loss", min_delta=0.00, @@ -595,7 +614,7 @@ def main(): logger.info(f" Training Done!") # Closing wandb log - wandb.finish() + #wandb.finish() if __name__ == "__main__": From e1ed6e2756d1a760330406492d9371efdbe6aebe Mon Sep 17 00:00:00 2001 From: Pierre-Louis Benveniste Date: Mon, 1 Apr 2024 12:04:20 -0400 Subject: [PATCH 031/108] added the image plot function --- monai/nnunet/utils.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/monai/nnunet/utils.py b/monai/nnunet/utils.py index 7fb9cbc..9fe0ed8 100644 --- a/monai/nnunet/utils.py +++ b/monai/nnunet/utils.py @@ -29,20 +29,21 @@ def plot_slices(image, gt, pred, debug=False): The orientaion is assumed to RPI """ - # bring everything to numpy - image = image.numpy() - gt = gt.numpy() - pred = pred.numpy() + # bring everything to numpy + ## added the .float() because of issue : TypeError: Got unsupported ScalarType BFloat16 + image = image.float().numpy() + gt = gt.float().numpy() + pred = pred.float().numpy() - mid_sagittal = image.shape[2]//2 + mid_sagittal = image.shape[0]//2 # plot X slices before and after the mid-sagittal slice in a grid fig, axs = plt.subplots(3, 6, figsize=(10, 6)) fig.suptitle('Original Image --> Ground Truth --> Prediction') for i in range(6): - axs[0, i].imshow(image[:, :, mid_sagittal-3+i].T, cmap='gray'); axs[0, i].axis('off') - axs[1, i].imshow(gt[:, :, mid_sagittal-3+i].T); axs[1, i].axis('off') - axs[2, i].imshow(pred[:, :, mid_sagittal-3+i].T); axs[2, i].axis('off') + axs[0, i].imshow(image[mid_sagittal-3+i,:,:].T, cmap='gray'); axs[0, i].axis('off') + axs[1, i].imshow(gt[mid_sagittal-3+i,:,:].T); axs[1, i].axis('off') + axs[2, i].imshow(pred[mid_sagittal-3+i,:,:].T); axs[2, i].axis('off') # fig, axs = plt.subplots(1, 3, figsize=(10, 8)) # fig.suptitle('Original Image --> Ground Truth --> Prediction') From cffa941947c45678cd041e1a637fad8bf96670e6 Mon Sep 17 00:00:00 2001 From: Pierre-Louis Benveniste Date: Mon, 1 Apr 2024 12:04:55 -0400 Subject: [PATCH 032/108] changed model parameters for training --- monai/nnunet/config_fake.yml | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/monai/nnunet/config_fake.yml b/monai/nnunet/config_fake.yml index 7c14831..1df1718 100644 --- a/monai/nnunet/config_fake.yml +++ b/monai/nnunet/config_fake.yml @@ -5,10 +5,11 @@ data: /home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/data/monai_data/dataset # Path to the output directory output_dir : /home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/results/ # Resampling resolution +# pixdim : [1.0, 1.0, 1.0] pixdim : [0.6, 0.6, 0.6] # Spatial size of the input data -spatial_size : [32, 208, 208] -batch_size : 16 +spatial_size : [32, 32, 128] # RL, AP, IS +batch_size : 8 # UNETR model parameters feature_size : 8 @@ -19,11 +20,11 @@ num_heads : 12 # Optimizer parameters lr : 0.0001 weight_decay: 0.00001 -early_stopping_patience : 10 +early_stopping_patience : 100 # Training parameters -max_iterations : 25000 -eval_num : 10 +max_iterations : 3000 +eval_num : 5 # Model saving best_model_path : /home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/results/best_metric_model.pth @@ -34,4 +35,8 @@ log_path : /home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/results/ # WANDB experiment_name : monai_unet_canproco -seed : 42 \ No newline at end of file +seed : 42 + +# UNET model parameters +unet_channels : [16, 32, 64, 128, 256, 512] +unet_strides : [2, 2, 2, 2, 2, 2, 2] \ No newline at end of file From fc37ca061b4e98478de388668e5dbf5b2483e05f Mon Sep 17 00:00:00 2001 From: Pierre-Louis Benveniste Date: Mon, 1 Apr 2024 16:49:00 -0400 Subject: [PATCH 033/108] code reviewed with no prob but output still problematic --- monai/nnunet/config_fake.yml | 7 +- monai/nnunet/train_monai_unet_lightning.py | 141 +++++++++++---------- monai/nnunet/utils.py | 3 +- 3 files changed, 80 insertions(+), 71 deletions(-) diff --git a/monai/nnunet/config_fake.yml b/monai/nnunet/config_fake.yml index 1df1718..f26840f 100644 --- a/monai/nnunet/config_fake.yml +++ b/monai/nnunet/config_fake.yml @@ -1,7 +1,8 @@ # Description: Configuration file for the UNETR model # Path to the data json file # data: /home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/data/monai_data/fake.json -data: /home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/data/monai_data/dataset_2024-03-13_seed42_canproco.json +data: /home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/data/monai_data/fake_sc.json +# data: /home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/data/monai_data/dataset_2024-03-13_seed42_canproco.json # Path to the output directory output_dir : /home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/results/ # Resampling resolution @@ -24,7 +25,7 @@ early_stopping_patience : 100 # Training parameters max_iterations : 3000 -eval_num : 5 +eval_num : 2 # Model saving best_model_path : /home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/results/best_metric_model.pth @@ -38,5 +39,5 @@ experiment_name : monai_unet_canproco seed : 42 # UNET model parameters -unet_channels : [16, 32, 64, 128, 256, 512] +unet_channels : [32, 64, 128, 256, 512, 1024] unet_strides : [2, 2, 2, 2, 2, 2, 2] \ No newline at end of file diff --git a/monai/nnunet/train_monai_unet_lightning.py b/monai/nnunet/train_monai_unet_lightning.py index 53b8014..7da5cfa 100644 --- a/monai/nnunet/train_monai_unet_lightning.py +++ b/monai/nnunet/train_monai_unet_lightning.py @@ -21,7 +21,7 @@ from losses import AdapWingLoss, SoftDiceLoss from utils import dice_score, check_empty_patch, multiply_by_negative_one, plot_slices -from monai.networks.nets import UNet +from monai.networks.nets import UNet, BasicUNet from monai.networks.layers import Norm @@ -51,6 +51,12 @@ from monai.data import (DataLoader, CacheDataset, load_decathlon_datalist, decollate_batch) from monai.transforms import (Compose, EnsureType, EnsureTyped, Invertd, SaveImage) +# Added this because of following warning received: +## You are using a CUDA device ('NVIDIA RTX A6000') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` +## which will trade-off precision for performance. For more details, +## read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision +# torch.set_float32_matmul_precision('medium' | 'high') + def get_parser(): """ @@ -163,30 +169,30 @@ def prepare_data(self): spatial_axis=[2], prob=0.2, ), - RandAdjustContrastd( - keys=["image"], - prob=0.2, - gamma=(0.5, 4.5), - invert_image=True, - ), + # RandAdjustContrastd( + # keys=["image"], + # prob=0.2, + # gamma=(0.5, 4.5), + # invert_image=True, + # ), # we add the multiplication of the image by -1 - RandLambdad( - keys='image', - func=multiply_by_negative_one, - prob=0.5 - ), + # RandLambdad( + # keys='image', + # func=multiply_by_negative_one, + # prob=0.5 + # ), NormalizeIntensityd( keys=["image", "label"], nonzero=False, channel_wise=False ), EnsureTyped(keys=["image", "label"]), - AsDiscreted( - keys=["label"], - num_classes=2, - threshold_values=True, - logit_thresh=0.2, - ) + # AsDiscreted( + # keys=["label"], + # num_classes=2, + # threshold_values=True, + # logit_thresh=0.2, + # ) ] ) val_transforms = Compose( @@ -200,28 +206,28 @@ def prepare_data(self): mode=(2, 1), ), ResizeWithPadOrCropd(keys=["image", "label"], spatial_size=self.cfg["spatial_size"],), - RandCropByPosNegLabeld( - keys=["image", "label"], - label_key="label", - spatial_size=self.cfg["spatial_size"], - pos=1, - neg=1, - num_samples=4, - image_key="image", - image_threshold=0, - ), + # RandCropByPosNegLabeld( + # keys=["image", "label"], + # label_key="label", + # spatial_size=self.cfg["spatial_size"], + # pos=1, + # neg=1, + # num_samples=4, + # image_key="image", + # image_threshold=0, + # ), NormalizeIntensityd( keys=["image", "label"], nonzero=False, channel_wise=False ), EnsureTyped(keys=["image", "label"]), - AsDiscreted( - keys=["label"], - num_classes=2, - threshold_values=True, - logit_thresh=0.2, - ) + # AsDiscreted( + # keys=["label"], + # num_classes=2, + # threshold_values=True, + # logit_thresh=0.2, + # ) ] ) @@ -270,7 +276,7 @@ def test_dataloader(self): # OPTIMIZATION # -------------------------------- def configure_optimizers(self): - optimizer = torch.optim.AdamW(self.parameters(), lr=self.cfg["lr"], weight_decay=self.cfg["weight_decay"]) + optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr, weight_decay=self.cfg["weight_decay"]) scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=self.cfg["max_iterations"]) return [optimizer], [scheduler] @@ -283,9 +289,9 @@ def training_step(self, batch, batch_idx): inputs, labels = batch["image"], batch["label"] # # check if any label image patch is empty in the batch - # if check_empty_patch(labels) is None: - # # print(f"Empty label patch found. Skipping training step ...") - # return None + if check_empty_patch(labels) is None: + # print(f"Empty label patch found. Skipping training step ...") + return None output = self.forward(inputs) # logits # print(f"labels.shape: {labels.shape} \t output.shape: {output.shape}") @@ -340,11 +346,12 @@ def on_train_epoch_end(self): pred=self.train_step_outputs[0]["train_pred"], ) wandb.log({"training images": wandb.Image(fig)}) + plt.close(fig) # free up memory self.train_step_outputs.clear() wandb_logs.clear() - plt.close(fig) + # -------------------------------- @@ -406,7 +413,7 @@ def on_validation_epoch_end(self): wandb_logs = { "val_soft_dice": mean_val_soft_dice, - "val_hard_dice": mean_val_hard_dice, + #"val_hard_dice": mean_val_hard_dice, "val_loss": mean_val_loss, } @@ -425,7 +432,7 @@ def on_validation_epoch_end(self): logger.info( f"\nCurrent epoch: {self.current_epoch}" f"\nAverage Soft Dice (VAL): {mean_val_soft_dice:.4f}" - f"\nAverage Hard Dice (VAL): {mean_val_hard_dice:.4f}" + # f"\nAverage Hard Dice (VAL): {mean_val_hard_dice:.4f}" f"\nAverage DiceLoss (VAL): {mean_val_loss:.4f}" f"\nBest Average DiceLoss: {self.best_val_loss:.4f} at Epoch: {self.best_val_epoch}" f"\n----------------------------------------------------") @@ -439,11 +446,7 @@ def on_validation_epoch_end(self): pred=self.val_step_outputs[0]["val_pred_0"],) wandb.log({"validation images": wandb.Image(fig0)}) plt.close(fig0) - # fig1 = plot_slices(image=self.val_step_outputs[0]["val_image_1"], - # gt=self.val_step_outputs[0]["val_gt_1"], - # pred=self.val_step_outputs[0]["val_pred_1"],) - # wandb.log({"validation images 1": wandb.Image(fig1)}) - # plt.close(fig1) + # free up memory self.val_step_outputs.clear() @@ -531,28 +534,34 @@ def main(): wandb.init(project=f'monai-unet-ms-lesion-seg-canproco', config=config) - wandb.name = "test123" - logger.info("Defining plans for nnUNet model ...") # define model # TODO: make the model deeper - net = UNet( - spatial_dims=3, - in_channels=1, - out_channels=1, - channels=config['unet_channels'], - strides=config['unet_strides'], - kernel_size=3, - up_kernel_size=3, - num_res_units=0, - act='PRELU', - norm=Norm.BATCH, - dropout=0.0, - bias=True, - adn_ordering='NDA', - ) + # net = UNet( + # spatial_dims=3, + # in_channels=1, + # out_channels=1, + # channels=config['unet_channels'], + # strides=config['unet_strides'], + # kernel_size=3, + # up_kernel_size=3, + # num_res_units=0, + # act='PRELU', + # norm=Norm.INSTANCE, + # dropout=0.0, + # bias=True, + # adn_ordering='NDA', + # ) + net=UNet( + spatial_dims=3, + in_channels=1, + out_channels=1, + channels=(32, 64, 128, 256, 512), + strides=(2, 2, 2, 2), + ) + # net = BasicUNet(spatial_dims=3, features=(32, 64, 128, 256, 32), out_channels=1) logger.add(os.path.join(config["log_path"], str(datetime.now()) + 'log.txt'), rotation="10 MB", level="INFO") @@ -580,7 +589,7 @@ def main(): # saving the best model based on validation loss checkpoint_callback_loss = pl.callbacks.ModelCheckpoint( dirpath=config["best_model_path"], filename='best_model', monitor='val_loss', - save_top_k=1, mode="min", save_last=True, save_weights_only=False) + save_top_k=1, mode="min", save_last=True, save_weights_only=True) logger.info(f"Starting training from scratch ...") @@ -604,7 +613,7 @@ def main(): callbacks=[checkpoint_callback_loss, lr_monitor, early_stopping], check_val_every_n_epoch=config["eval_num"], max_epochs=config["max_iterations"], - precision="bf16-mixed", + precision=32, # deterministic=True, enable_progress_bar=True) # profiler="simple",) # to profile the training time taken for each step @@ -614,7 +623,7 @@ def main(): logger.info(f" Training Done!") # Closing wandb log - #wandb.finish() + wandb.finish() if __name__ == "__main__": diff --git a/monai/nnunet/utils.py b/monai/nnunet/utils.py index 9fe0ed8..f8e4d07 100644 --- a/monai/nnunet/utils.py +++ b/monai/nnunet/utils.py @@ -3,8 +3,7 @@ from torch.optim.lr_scheduler import _LRScheduler import torch -def dice_score(prediction, groundtruth): - smooth = 1. +def dice_score(prediction, groundtruth, smooth=1.): numer = (prediction * groundtruth).sum() denor = (prediction + groundtruth).sum() # loss = (2 * numer + self.smooth) / (denor + self.smooth) From 9f1effdeb2f32e13bd20623d3a82ea06daf097ef Mon Sep 17 00:00:00 2001 From: Pierre-Louis Benveniste Date: Mon, 1 Apr 2024 18:17:19 -0400 Subject: [PATCH 034/108] added lines to save images before training --- monai/nnunet/train_monai_unet_lightning.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/monai/nnunet/train_monai_unet_lightning.py b/monai/nnunet/train_monai_unet_lightning.py index 7da5cfa..20885e8 100644 --- a/monai/nnunet/train_monai_unet_lightning.py +++ b/monai/nnunet/train_monai_unet_lightning.py @@ -3,6 +3,8 @@ from datetime import datetime from loguru import logger import yaml +import nibabel as nib +from datetime import datetime import numpy as np import wandb @@ -48,6 +50,7 @@ from monai.utils import set_determinism from monai.inferers import sliding_window_inference +import time from monai.data import (DataLoader, CacheDataset, load_decathlon_datalist, decollate_batch) from monai.transforms import (Compose, EnsureType, EnsureTyped, Invertd, SaveImage) @@ -288,6 +291,23 @@ def training_step(self, batch, batch_idx): inputs, labels = batch["image"], batch["label"] + # print(inputs.shape, labels.shape) + # input_0 = inputs[0].detach().cpu().squeeze() + # print(input_0.shape) + # label_0 = labels[0].detach().cpu().squeeze() + + # time_0 = datetime.now() + # print(f"Time: {time_0}") + + # # save input 0 in a nifti file + # input_0_nifti = nib.Nifti1Image(input_0.numpy(), affine=np.eye(4)) + # nib.save(input_0_nifti, f"~/ms_lesion_agnostic/temp/input_0_{time_0}.nii.gz") + + # # save label in a nifti file + # label_nifti = nib.Nifti1Image(label_0.numpy(), affine=np.eye(4)) + # nib.save(label_nifti, f"~/ms_lesion_agnostic/temp/label_0_{time_0}.nii.gz") + + # # check if any label image patch is empty in the batch if check_empty_patch(labels) is None: # print(f"Empty label patch found. Skipping training step ...") From 01c0912bf3c0fd108277f6aa140e6859710fce1b Mon Sep 17 00:00:00 2001 From: Pierre-Louis Benveniste Date: Mon, 1 Apr 2024 19:17:55 -0400 Subject: [PATCH 035/108] correction: removed intensity normalisation for labels --- monai/nnunet/train_monai_unet_lightning.py | 37 ++++++++++++++-------- 1 file changed, 23 insertions(+), 14 deletions(-) diff --git a/monai/nnunet/train_monai_unet_lightning.py b/monai/nnunet/train_monai_unet_lightning.py index 20885e8..887a19c 100644 --- a/monai/nnunet/train_monai_unet_lightning.py +++ b/monai/nnunet/train_monai_unet_lightning.py @@ -13,7 +13,7 @@ import torch.nn.functional as F import matplotlib.pyplot as plt from monai.metrics import DiceMetric -from monai.losses import DiceLoss +from monai.losses import DiceLoss, DiceCELoss # Added this to solve problem with too many files open ## Link here : https://github.com/pytorch/pytorch/issues/11201#issuecomment-421146936 @@ -23,7 +23,7 @@ from losses import AdapWingLoss, SoftDiceLoss from utils import dice_score, check_empty_patch, multiply_by_negative_one, plot_slices -from monai.networks.nets import UNet, BasicUNet +from monai.networks.nets import UNet, BasicUNet, AttentionUnet from monai.networks.layers import Norm @@ -185,7 +185,7 @@ def prepare_data(self): # prob=0.5 # ), NormalizeIntensityd( - keys=["image", "label"], + keys=["image"], nonzero=False, channel_wise=False ), @@ -220,7 +220,7 @@ def prepare_data(self): # image_threshold=0, # ), NormalizeIntensityd( - keys=["image", "label"], + keys=["image"], nonzero=False, channel_wise=False ), @@ -315,12 +315,12 @@ def training_step(self, batch, batch_idx): output = self.forward(inputs) # logits # print(f"labels.shape: {labels.shape} \t output.shape: {output.shape}") - - # calculate training loss - loss = self.loss_function(output, labels) # get probabilities from logits output = F.relu(output) / F.relu(output).max() if bool(F.relu(output).max()) else F.relu(output) + + # calculate training loss + loss = self.loss_function(output, labels) # calculate train dice # NOTE: this is done on patches (and not entire 3D volume) because SlidingWindowInference is not used here @@ -385,12 +385,13 @@ def validation_step(self, batch, batch_idx): outputs = sliding_window_inference(inputs, self.inference_roi_size, mode="gaussian", sw_batch_size=4, predictor=self.forward, overlap=0.5,) - # calculate validation loss - loss = self.loss_function(outputs, labels) - # get probabilities from logits outputs = F.relu(outputs) / F.relu(outputs).max() if bool(F.relu(outputs).max()) else F.relu(outputs) + # calculate validation loss + loss = self.loss_function(outputs, labels) + + # post-process for calculating the evaluation metric post_outputs = [self.val_post_pred(i) for i in decollate_batch(outputs)] post_labels = [self.val_post_label(i) for i in decollate_batch(labels)] @@ -574,11 +575,18 @@ def main(): # bias=True, # adn_ordering='NDA', # ) - net=UNet( + # net=UNet( + # spatial_dims=3, + # in_channels=1, + # out_channels=1, + # channels=(16, 32, 64, 128, 256), + # strides=(2, 2, 2, 2), + # ) + net = AttentionUnet( spatial_dims=3, in_channels=1, out_channels=1, - channels=(32, 64, 128, 256, 512), + channels=(16, 32, 64, 128, 256), strides=(2, 2, 2, 2), ) # net = BasicUNet(spatial_dims=3, features=(32, 64, 128, 256, 32), out_channels=1) @@ -587,8 +595,9 @@ def main(): # define loss function #loss_func = AdapWingLoss(theta=0.5, omega=8, alpha=2.1, epsilon=1, reduction="sum") - # loss_func = DiceLoss(sigmoid=True, smooth_dr=1e-4) - loss_func = SoftDiceLoss(smooth=1e-5) + #loss_func = DiceLoss(sigmoid=True, smooth_dr=1e-4) + loss_func = DiceCELoss(sigmoid=True, smooth_dr=1e-4) + # loss_func = SoftDiceLoss(smooth=1e-5) # NOTE: tried increasing omega and decreasing epsilon but results marginally worse than the above # loss_func = AdapWingLoss(theta=0.5, omega=12, alpha=2.1, epsilon=0.5, reduction="sum") #logger.info(f"Using AdapWingLoss with theta={loss_func.theta}, omega={loss_func.omega}, alpha={loss_func.alpha}, epsilon={loss_func.epsilon} ...") From 6aa82252958ce3f819f662eed7eeb1f48797b7c5 Mon Sep 17 00:00:00 2001 From: Pierre-Louis Benveniste Date: Mon, 1 Apr 2024 19:22:26 -0400 Subject: [PATCH 036/108] fixed filename to save --- monai/nnunet/train_monai_unet_lightning.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/nnunet/train_monai_unet_lightning.py b/monai/nnunet/train_monai_unet_lightning.py index 887a19c..25a06c5 100644 --- a/monai/nnunet/train_monai_unet_lightning.py +++ b/monai/nnunet/train_monai_unet_lightning.py @@ -633,7 +633,7 @@ def main(): config=config) # Saving training script to wandb - wandb.save("main.py") + wandb.save("train_monai_unet_lightning.py") # initialise Lightning's trainer. trainer = pl.Trainer( From a89b5d705cbe8a17f24a5c41c85246a429a113c3 Mon Sep 17 00:00:00 2001 From: Pierre-Louis Benveniste Date: Tue, 2 Apr 2024 17:48:07 -0400 Subject: [PATCH 037/108] updated to add some data aug but then removed :/ --- monai/nnunet/config_fake.yml | 2 +- monai/nnunet/train_monai_unet_lightning.py | 58 ++++++---- monai/nnunet/utils.py | 118 ++++++++++++++++++++- 3 files changed, 157 insertions(+), 21 deletions(-) diff --git a/monai/nnunet/config_fake.yml b/monai/nnunet/config_fake.yml index f26840f..06c7cd9 100644 --- a/monai/nnunet/config_fake.yml +++ b/monai/nnunet/config_fake.yml @@ -7,7 +7,7 @@ data: /home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/data/monai_data/fake_sc output_dir : /home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/results/ # Resampling resolution # pixdim : [1.0, 1.0, 1.0] -pixdim : [0.6, 0.6, 0.6] +pixdim : [0.7, 0.7, 0.7] # Spatial size of the input data spatial_size : [32, 32, 128] # RL, AP, IS batch_size : 8 diff --git a/monai/nnunet/train_monai_unet_lightning.py b/monai/nnunet/train_monai_unet_lightning.py index 25a06c5..48c5e87 100644 --- a/monai/nnunet/train_monai_unet_lightning.py +++ b/monai/nnunet/train_monai_unet_lightning.py @@ -22,7 +22,7 @@ from losses import AdapWingLoss, SoftDiceLoss -from utils import dice_score, check_empty_patch, multiply_by_negative_one, plot_slices +from utils import dice_score, check_empty_patch, multiply_by_negative_one, plot_slices, create_nnunet_from_plans from monai.networks.nets import UNet, BasicUNet, AttentionUnet from monai.networks.layers import Norm @@ -46,6 +46,8 @@ ResizeWithPadOrCropd, EnsureTyped, RandLambdad, + CropForegroundd, + RandGaussianNoised, ) from monai.utils import set_determinism @@ -143,6 +145,7 @@ def prepare_data(self): pixdim=self.cfg["pixdim"], mode=(2, 1), ), + # CropForegroundd(keys=["image", "label"], source_key="label", margin=100), ResizeWithPadOrCropd(keys=["image", "label"], spatial_size=self.cfg["spatial_size"],), RandCropByPosNegLabeld( keys=["image", "label"], @@ -182,13 +185,22 @@ def prepare_data(self): # RandLambdad( # keys='image', # func=multiply_by_negative_one, - # prob=0.5 + # prob=0.2 # ), NormalizeIntensityd( keys=["image"], nonzero=False, channel_wise=False ), + # RandGaussianNoised( + # keys=["image"], + # prob=0.2, + # ), + # RandShiftIntensityd( + # keys=["image"], + # offsets=0.1, + # prob=0.2, + # ), EnsureTyped(keys=["image", "label"]), # AsDiscreted( # keys=["label"], @@ -208,6 +220,7 @@ def prepare_data(self): pixdim=self.cfg["pixdim"], mode=(2, 1), ), + # CropForegroundd(keys=["image", "label"], source_key="label", margin=100), ResizeWithPadOrCropd(keys=["image", "label"], spatial_size=self.cfg["spatial_size"],), # RandCropByPosNegLabeld( # keys=["image", "label"], @@ -242,8 +255,8 @@ def prepare_data(self): test_files = load_decathlon_datalist(dataset, True, "test") train_cache_rate = 0.5 - self.train_ds = CacheDataset(data=train_files, transform=train_transforms, cache_rate=train_cache_rate, num_workers=8) - self.val_ds = CacheDataset(data=val_files, transform=val_transforms, cache_rate=0.25, num_workers=8) + self.train_ds = CacheDataset(data=train_files, transform=train_transforms, cache_rate=train_cache_rate, num_workers=16) + self.val_ds = CacheDataset(data=val_files, transform=val_transforms, cache_rate=0.25, num_workers=16) # define test transforms transforms_test = val_transforms @@ -264,11 +277,11 @@ def prepare_data(self): # DATA LOADERS # -------------------------------- def train_dataloader(self): - return DataLoader(self.train_ds, batch_size=self.cfg["batch_size"], shuffle=True, num_workers=8, + return DataLoader(self.train_ds, batch_size=self.cfg["batch_size"], shuffle=True, num_workers=16, pin_memory=True, persistent_workers=True) def val_dataloader(self): - return DataLoader(self.val_ds, batch_size=1, shuffle=False, num_workers=8, pin_memory=True, + return DataLoader(self.val_ds, batch_size=1, shuffle=False, num_workers=16, pin_memory=True, persistent_workers=True) def test_dataloader(self): @@ -291,13 +304,12 @@ def training_step(self, batch, batch_idx): inputs, labels = batch["image"], batch["label"] - # print(inputs.shape, labels.shape) + # # print(inputs.shape, labels.shape) # input_0 = inputs[0].detach().cpu().squeeze() - # print(input_0.shape) + # # print(input_0.shape) # label_0 = labels[0].detach().cpu().squeeze() # time_0 = datetime.now() - # print(f"Time: {time_0}") # # save input 0 in a nifti file # input_0_nifti = nib.Nifti1Image(input_0.numpy(), affine=np.eye(4)) @@ -576,26 +588,32 @@ def main(): # adn_ordering='NDA', # ) # net=UNet( - # spatial_dims=3, - # in_channels=1, - # out_channels=1, - # channels=(16, 32, 64, 128, 256), - # strides=(2, 2, 2, 2), - # ) + # spatial_dims=3, + # in_channels=1, + # out_channels=1, + # channels=(32, 64, 128, 256), + # strides=(2, 2, 2 ), + + # # dropout=0.1 + # ) net = AttentionUnet( spatial_dims=3, in_channels=1, out_channels=1, - channels=(16, 32, 64, 128, 256), - strides=(2, 2, 2, 2), + channels=(64, 128, 256, 512, 1024, 2048), + strides=(2, 2, 2, 2, 2), + dropout=0.0 ) # net = BasicUNet(spatial_dims=3, features=(32, 64, 128, 256, 32), out_channels=1) + + # net = create_nnunet_from_plans() + logger.add(os.path.join(config["log_path"], str(datetime.now()) + 'log.txt'), rotation="10 MB", level="INFO") # define loss function #loss_func = AdapWingLoss(theta=0.5, omega=8, alpha=2.1, epsilon=1, reduction="sum") - #loss_func = DiceLoss(sigmoid=True, smooth_dr=1e-4) + # loss_func = DiceLoss(sigmoid=True, smooth_dr=1e-4) loss_func = DiceCELoss(sigmoid=True, smooth_dr=1e-4) # loss_func = SoftDiceLoss(smooth=1e-5) # NOTE: tried increasing omega and decreasing epsilon but results marginally worse than the above @@ -633,7 +651,9 @@ def main(): config=config) # Saving training script to wandb - wandb.save("train_monai_unet_lightning.py") + wandb.save("ms-lesion-agnostic/monai/nnunet/config_fake.yml") + wandb.save("ms-lesion-agnostic/monai/nnunet/train_monai_unet_lightning.py") + # initialise Lightning's trainer. trainer = pl.Trainer( diff --git a/monai/nnunet/utils.py b/monai/nnunet/utils.py index f8e4d07..dd65b1d 100644 --- a/monai/nnunet/utils.py +++ b/monai/nnunet/utils.py @@ -3,6 +3,13 @@ from torch.optim.lr_scheduler import _LRScheduler import torch +import torch.nn as nn +import torch.nn.functional as F + +from dynamic_network_architectures.architectures.unet import PlainConvUNet, ResidualEncoderUNet +from dynamic_network_architectures.building_blocks.helper import get_matching_instancenorm, convert_dim_to_conv_op +from dynamic_network_architectures.initialization.weight_init import init_last_bn_before_add_to_0 + def dice_score(prediction, groundtruth, smooth=1.): numer = (prediction * groundtruth).sum() denor = (prediction + groundtruth).sum() @@ -54,4 +61,113 @@ def plot_slices(image, gt, pred, debug=False): plt.tight_layout() fig.show() - return fig \ No newline at end of file + return fig + +nnunet_plans = { + "UNet_class_name": "PlainConvUNet", + "UNet_base_num_features": 32, + "n_conv_per_stage_encoder": [2, 2, 2, 2, 2, 2, 2], + "n_conv_per_stage_decoder": [2, 2, 2, 2, 2, 2], + "pool_op_kernel_sizes": [ + [1, 1, 1], + [1, 2, 2], + [1, 2, 2], + [2, 2, 2], + [2, 2, 2], + [1, 2, 2], + [1, 2, 2] + ], + "conv_kernel_sizes": [ + [1, 3, 3], + [1, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3] + ], + "unet_max_num_features": 320, +} + + +# ====================================================================================================== +# Utils for nnUNet's Model +# ==================================================================================================== +class InitWeights_He(object): + def __init__(self, neg_slope=1e-2): + self.neg_slope = neg_slope + + def __call__(self, module): + if isinstance(module, nn.Conv3d) or isinstance(module, nn.ConvTranspose3d): + module.weight = nn.init.kaiming_normal_(module.weight, a=self.neg_slope) + if module.bias is not None: + module.bias = nn.init.constant_(module.bias, 0) + + +# ====================================================================================================== +# Define the network based on plans json +# ==================================================================================================== +def create_nnunet_from_plans(plans=nnunet_plans, num_input_channels=1, num_classes=1, deep_supervision: bool = False): + """ + Adapted from nnUNet's source code: + https://github.com/MIC-DKFZ/nnUNet/blob/master/nnunetv2/utilities/get_network_from_plans.py#L9 + + """ + num_stages = len(plans["conv_kernel_sizes"]) + + dim = len(plans["conv_kernel_sizes"][0]) + conv_op = convert_dim_to_conv_op(dim) + + segmentation_network_class_name = plans["UNet_class_name"] + mapping = { + 'PlainConvUNet': PlainConvUNet, + 'ResidualEncoderUNet': ResidualEncoderUNet + } + kwargs = { + 'PlainConvUNet': { + 'conv_bias': True, + 'norm_op': get_matching_instancenorm(conv_op), + 'norm_op_kwargs': {'eps': 1e-5, 'affine': True}, + 'dropout_op': None, 'dropout_op_kwargs': None, + 'nonlin': nn.LeakyReLU, 'nonlin_kwargs': {'inplace': True}, + }, + 'ResidualEncoderUNet': { + 'conv_bias': True, + 'norm_op': get_matching_instancenorm(conv_op), + 'norm_op_kwargs': {'eps': 1e-5, 'affine': True}, + 'dropout_op': None, 'dropout_op_kwargs': None, + 'nonlin': nn.LeakyReLU, 'nonlin_kwargs': {'inplace': True}, + } + } + assert segmentation_network_class_name in mapping.keys(), 'The network architecture specified by the plans file ' \ + 'is non-standard (maybe your own?). Yo\'ll have to dive ' \ + 'into either this ' \ + 'function (get_network_from_plans) or ' \ + 'the init of your nnUNetModule to accomodate that.' + network_class = mapping[segmentation_network_class_name] + + conv_or_blocks_per_stage = { + 'n_conv_per_stage' + if network_class != ResidualEncoderUNet else 'n_blocks_per_stage': plans["n_conv_per_stage_encoder"], + 'n_conv_per_stage_decoder': plans["n_conv_per_stage_decoder"] + } + + # network class name!! + model = network_class( + input_channels=num_input_channels, + n_stages=num_stages, + features_per_stage=[min(plans["UNet_base_num_features"] * 2 ** i, + plans["unet_max_num_features"]) for i in range(num_stages)], + conv_op=conv_op, + kernel_sizes=plans["conv_kernel_sizes"], + strides=plans["pool_op_kernel_sizes"], + num_classes=num_classes, + deep_supervision=deep_supervision, + **conv_or_blocks_per_stage, + **kwargs[segmentation_network_class_name] + ) + model.apply(InitWeights_He(1e-2)) + if network_class == ResidualEncoderUNet: + model.apply(init_last_bn_before_add_to_0) + + return model \ No newline at end of file From 6342be07e99b0d0da98cf2af176e716b4d2d528d Mon Sep 17 00:00:00 2001 From: Pierre-Louis Benveniste Date: Wed, 3 Apr 2024 18:23:03 -0400 Subject: [PATCH 038/108] created file for unet training with multiple input channels --- ...monai_unet_lightning_multichannel_input.py | 687 ++++++++++++++++++ 1 file changed, 687 insertions(+) create mode 100644 monai/nnunet/train_monai_unet_lightning_multichannel_input.py diff --git a/monai/nnunet/train_monai_unet_lightning_multichannel_input.py b/monai/nnunet/train_monai_unet_lightning_multichannel_input.py new file mode 100644 index 0000000..0cb9a12 --- /dev/null +++ b/monai/nnunet/train_monai_unet_lightning_multichannel_input.py @@ -0,0 +1,687 @@ +import os +import argparse +from datetime import datetime +from loguru import logger +import yaml +import nibabel as nib +from datetime import datetime + +import numpy as np +import wandb +import torch +import pytorch_lightning as pl +import torch.nn.functional as F +import matplotlib.pyplot as plt +from monai.metrics import DiceMetric +from monai.losses import DiceLoss, DiceCELoss + +# Added this to solve problem with too many files open +## Link here : https://github.com/pytorch/pytorch/issues/11201#issuecomment-421146936 +import torch.multiprocessing +torch.multiprocessing.set_sharing_strategy('file_system') + +from losses import AdapWingLoss, SoftDiceLoss + +from utils import dice_score, check_empty_patch, multiply_by_negative_one, plot_slices, create_nnunet_from_plans, print_data_types +from monai.networks.nets import UNet, BasicUNet, AttentionUnet + +from monai.networks.layers import Norm + + +from monai.transforms import ( + EnsureChannelFirstd, + Compose, + LoadImaged, + Orientationd, + RandFlipd, + RandShiftIntensityd, + Spacingd, + RandRotate90d, + NormalizeIntensityd, + RandCropByPosNegLabeld, + BatchInverseTransform, + RandAdjustContrastd, + AsDiscreted, + RandHistogramShiftd, + ResizeWithPadOrCropd, + EnsureTyped, + RandLambdad, + CropForegroundd, + RandGaussianNoised, + ConcatItemsd + ) + +from monai.utils import set_determinism +from monai.inferers import sliding_window_inference +import time +from monai.data import (DataLoader, CacheDataset, load_decathlon_datalist, decollate_batch) +from monai.transforms import (Compose, EnsureType, EnsureTyped, Invertd, SaveImage) + +# Added this because of following warning received: +## You are using a CUDA device ('NVIDIA RTX A6000') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` +## which will trade-off precision for performance. For more details, +## read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision +# torch.set_float32_matmul_precision('medium' | 'high') + + +def get_parser(): + """ + This function returns the parser for the command line arguments. + """ + parser = argparse.ArgumentParser(description="Train a nnUNet model using monai") + parser.add_argument("-c", "--config", help="Path to the config file (.yml file)", required=True) + return parser + + +# create a "model"-agnostic class with PL to use different models +class Model(pl.LightningModule): + def __init__(self, config, data_root, net, loss_function, optimizer_class, exp_id=None, results_path=None): + super().__init__() + self.cfg = config + self.save_hyperparameters(ignore=['net', 'loss_function']) + + self.root = data_root + self.net = net + self.lr = config["lr"] + self.loss_function = loss_function + self.optimizer_class = optimizer_class + self.save_exp_id = exp_id + self.results_path = results_path + + self.best_val_dice, self.best_val_epoch = 0, 0 + self.best_val_loss = float("inf") + + # define cropping and padding dimensions + # NOTE about patch sizes: nnUNet defines patches using the median size of the dataset as the reference + # BUT, for SC images, this means a lot of context outside the spinal cord is included in the patches + # which could be sub-optimal. + # On the other hand, ivadomed used a patch-size that's heavily padded along the R-L direction so that + # only the SC is in context. + self.spacing = config["spatial_size"] + self.voxel_cropping_size = self.inference_roi_size = config["spatial_size"] + + # define post-processing transforms for validation, nothing fancy just making sure that it's a tensor (default) + self.val_post_pred = self.val_post_label = Compose([EnsureType()]) + + # define evaluation metric + self.soft_dice_metric = dice_score + + # temp lists for storing outputs from training, validation, and testing + self.train_step_outputs = [] + self.val_step_outputs = [] + self.test_step_outputs = [] + + + # -------------------------------- + # FORWARD PASS + # -------------------------------- + def forward(self, x): + + out = self.net(x) + # # NOTE: MONAI's models only output the logits, not the output after the final activation function + # # https://docs.monai.io/en/0.9.0/_modules/monai/networks/nets/unetr.html#UNETR.forward refers to the + # # UnetOutBlock (https://docs.monai.io/en/0.9.0/_modules/monai/networks/blocks/dynunet_block.html#UnetOutBlock) + # # as the final block applied to the input, which is just a convolutional layer with no activation function + # # Hence, we are used Normalized ReLU to normalize the logits to the final output + # normalized_out = F.relu(out) / F.relu(out).max() if bool(F.relu(out).max()) else F.relu(out) + + return out # returns logits + + + # -------------------------------- + # DATA PREPARATION + # -------------------------------- + def prepare_data(self): + # set deterministic training for reproducibility + set_determinism(seed=self.cfg["seed"]) + + # define training and validation transforms + train_transforms = Compose( + [ + LoadImaged(keys=["image", "sc", "label"], reader="NibabelReader"), + EnsureChannelFirstd(keys=["image", "sc", "label"]), + Orientationd(keys=["image", "sc", "label"], axcodes="RPI"), + Spacingd( + keys=["image", "sc", "label"], + pixdim=self.cfg["pixdim"], + mode=(2, 1, 1), + ), + # CropForegroundd(keys=["image", "label"], source_key="label", margin=100), + ResizeWithPadOrCropd(keys=["image", "sc", "label"], spatial_size=self.cfg["spatial_size"],), + RandCropByPosNegLabeld( + keys=["image", "sc", "label"], + label_key="label", + spatial_size=self.cfg["spatial_size"], + pos=1, + neg=1, + num_samples=4, + image_key="image", + image_threshold=0, + ), + # Flips the image : left becomes right + RandFlipd( + keys=["image", "sc", "label"], + spatial_axis=[0], + prob=0.2, + ), + # Flips the image : supperior becomes inferior + RandFlipd( + keys=["image", "sc", "label"], + spatial_axis=[1], + prob=0.2, + ), + # Flips the image : anterior becomes posterior + RandFlipd( + keys=["image","sc", "label"], + spatial_axis=[2], + prob=0.2, + ), + # RandAdjustContrastd( + # keys=["image"], + # prob=0.2, + # gamma=(0.5, 4.5), + # invert_image=True, + # ), + # we add the multiplication of the image by -1 + # RandLambdad( + # keys='image', + # func=multiply_by_negative_one, + # prob=0.2 + # ), + + # Normalize the intensity of the image + NormalizeIntensityd( + keys=["image"], + nonzero=False, + channel_wise=False + ), + # RandGaussianNoised( + # keys=["image"], + # prob=0.2, + # ), + # RandShiftIntensityd( + # keys=["image"], + # offsets=0.1, + # prob=0.2, + # ), + # Concatenates the image and the sc + ConcatItemsd(keys=["image", "sc"], name="inputs"), + EnsureTyped(keys=["inputs", "label"]), + # AsDiscreted( + # keys=["label"], + # num_classes=2, + # threshold_values=True, + # logit_thresh=0.2, + # ) + ] + ) + val_transforms = Compose( + [ + LoadImaged(keys=["image", "sc", "label"], reader="NibabelReader"), + EnsureChannelFirstd(keys=["image", "sc", "label"]), + Orientationd(keys=["image", "sc", "label"], axcodes="RPI"), + Spacingd( + keys=["image", "sc", "label"], + pixdim=self.cfg["pixdim"], + mode=(2, 1, 1), + ), + # CropForegroundd(keys=["image", "label"], source_key="label", margin=100), + ResizeWithPadOrCropd(keys=["image", "sc", "label"], spatial_size=self.cfg["spatial_size"],), + # RandCropByPosNegLabeld( + # keys=["image", "label"], + # label_key="label", + # spatial_size=self.cfg["spatial_size"], + # pos=1, + # neg=1, + # num_samples=4, + # image_key="image", + # image_threshold=0, + # ), + # Concatenates the image and the sc + ConcatItemsd(keys=["image", "sc"], name="inputs"), + # Normalize the intensity of the image + NormalizeIntensityd( + keys=["inputs"], + nonzero=False, + channel_wise=False + ), + EnsureTyped(keys=["inputs", "label"]), + # AsDiscreted( + # keys=["label"], + # num_classes=2, + # threshold_values=True, + # logit_thresh=0.2, + # ) + ] + ) + + # load the dataset + dataset = self.cfg["data"] + logger.info(f"Loading dataset: {dataset}") + train_files = load_decathlon_datalist(dataset, True, "train") + val_files = load_decathlon_datalist(dataset, True, "validation") + test_files = load_decathlon_datalist(dataset, True, "test") + + train_cache_rate = 0.5 + self.train_ds = CacheDataset(data=train_files, transform=train_transforms, cache_rate=train_cache_rate, num_workers=16) + self.val_ds = CacheDataset(data=val_files, transform=val_transforms, cache_rate=0.25, num_workers=16) + + # define test transforms + transforms_test = val_transforms + + # define post-processing transforms for testing; taken (with explanations) from + # https://github.com/Project-MONAI/tutorials/blob/main/3d_segmentation/torch/unet_inference_dict.py#L66 + self.test_post_pred = Compose([ + EnsureTyped(keys=["pred", "label"]), + Invertd(keys=["pred", "label"], transform=transforms_test, + orig_keys=["image", "label"], + meta_keys=["pred_meta_dict", "label_meta_dict"], + nearest_interp=False, to_tensor=True), + ]) + self.test_ds = CacheDataset(data=test_files, transform=transforms_test, cache_rate=0.1, num_workers=4) + + + # -------------------------------- + # DATA LOADERS + # -------------------------------- + def train_dataloader(self): + return DataLoader(self.train_ds, batch_size=self.cfg["batch_size"], shuffle=True, num_workers=16, + pin_memory=True, persistent_workers=True) + + def val_dataloader(self): + return DataLoader(self.val_ds, batch_size=1, shuffle=False, num_workers=16, pin_memory=True, + persistent_workers=True) + + def test_dataloader(self): + return DataLoader(self.test_ds, batch_size=1, shuffle=False, num_workers=1, pin_memory=True) + + + # -------------------------------- + # OPTIMIZATION + # -------------------------------- + def configure_optimizers(self): + optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr, weight_decay=self.cfg["weight_decay"]) + scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=self.cfg["max_iterations"]) + return [optimizer], [scheduler] + + + # -------------------------------- + # TRAINING + # -------------------------------- + def training_step(self, batch, batch_idx): + + inputs, labels = batch["inputs"], batch["label"] + + # # print(inputs.shape, labels.shape) + # input_0 = inputs[0].detach().cpu().squeeze() + # # print(input_0.shape) + # label_0 = labels[0].detach().cpu().squeeze() + + # time_0 = datetime.now() + + # # save input 0 in a nifti file + # input_0_nifti = nib.Nifti1Image(input_0.numpy(), affine=np.eye(4)) + # nib.save(input_0_nifti, f"~/ms_lesion_agnostic/temp/input_0_{time_0}.nii.gz") + + # # save label in a nifti file + # label_nifti = nib.Nifti1Image(label_0.numpy(), affine=np.eye(4)) + # nib.save(label_nifti, f"~/ms_lesion_agnostic/temp/label_0_{time_0}.nii.gz") + + + # # check if any label image patch is empty in the batch + if check_empty_patch(labels) is None: + # print(f"Empty label patch found. Skipping training step ...") + return None + + output = self.forward(inputs) # logits + # print(f"labels.shape: {labels.shape} \t output.shape: {output.shape}") + + # get probabilities from logits + output = F.relu(output) / F.relu(output).max() if bool(F.relu(output).max()) else F.relu(output) + + # calculate training loss + loss = self.loss_function(output, labels) + + # calculate train dice + # NOTE: this is done on patches (and not entire 3D volume) because SlidingWindowInference is not used here + # So, take this dice score with a lot of salt + train_soft_dice = self.soft_dice_metric(output, labels) + + metrics_dict = { + "loss": loss.cpu(), + "train_soft_dice": train_soft_dice.detach().cpu(), + "train_number": len(inputs), + "train_image": inputs[0].detach().cpu().squeeze(), + "train_gt": labels[0].detach().cpu().squeeze(), + "train_pred": output[0].detach().cpu().squeeze() + } + self.train_step_outputs.append(metrics_dict) + + return metrics_dict + + def on_train_epoch_end(self): + + if self.train_step_outputs == []: + # means the training step was skipped because of empty input patch + return None + else: + train_loss, train_soft_dice = 0, 0 + num_items = len(self.train_step_outputs) + for output in self.train_step_outputs: + train_loss += output["loss"].item() + train_soft_dice += output["train_soft_dice"].item() + + mean_train_loss = (train_loss / num_items) + mean_train_soft_dice = (train_soft_dice / num_items) + + wandb_logs = { + "train_soft_dice": mean_train_soft_dice, + "train_loss": mean_train_loss, + } + self.log_dict(wandb_logs) + + # # plot the training images + # fig = plot_slices(image=self.train_step_outputs[0]["train_image"], + # gt=self.train_step_outputs[0]["train_gt"], + # pred=self.train_step_outputs[0]["train_pred"], + # ) + # wandb.log({"training images": wandb.Image(fig)}) + # plt.close(fig) + + # free up memory + self.train_step_outputs.clear() + wandb_logs.clear() + + + + # -------------------------------- + # VALIDATION + # -------------------------------- + def validation_step(self, batch, batch_idx): + + inputs, labels = batch["inputs"], batch["label"] + + # NOTE: this calculates the loss on the entire image after sliding window + outputs = sliding_window_inference(inputs, self.inference_roi_size, mode="gaussian", + sw_batch_size=4, predictor=self.forward, overlap=0.5,) + + # get probabilities from logits + outputs = F.relu(outputs) / F.relu(outputs).max() if bool(F.relu(outputs).max()) else F.relu(outputs) + + # calculate validation loss + loss = self.loss_function(outputs, labels) + + + # post-process for calculating the evaluation metric + post_outputs = [self.val_post_pred(i) for i in decollate_batch(outputs)] + post_labels = [self.val_post_label(i) for i in decollate_batch(labels)] + val_soft_dice = self.soft_dice_metric(post_outputs[0], post_labels[0]) + + hard_preds, hard_labels = (post_outputs[0].detach() > 0.5).float(), (post_labels[0].detach() > 0.5).float() + val_hard_dice = self.soft_dice_metric(hard_preds, hard_labels) + + # NOTE: there was a massive memory leak when storing cuda tensors in this dict. Hence, + # using .detach() to avoid storing the whole computation graph + # Ref: https://discuss.pytorch.org/t/cuda-memory-leak-while-training/82855/2 + metrics_dict = { + "val_loss": loss.detach().cpu(), + "val_soft_dice": val_soft_dice.detach().cpu(), + "val_hard_dice": val_hard_dice.detach().cpu(), + "val_number": len(post_outputs), + "val_image_0": inputs[0].detach().cpu().squeeze(), + "val_gt_0": labels[0].detach().cpu().squeeze(), + "val_pred_0": post_outputs[0].detach().cpu().squeeze(), + # "val_image_1": inputs[1].detach().cpu().squeeze(), + # "val_gt_1": labels[1].detach().cpu().squeeze(), + # "val_pred_1": post_outputs[1].detach().cpu().squeeze(), + } + self.val_step_outputs.append(metrics_dict) + + return metrics_dict + + def on_validation_epoch_end(self): + + val_loss, num_items, val_soft_dice, val_hard_dice = 0, 0, 0, 0 + for output in self.val_step_outputs: + val_loss += output["val_loss"].sum().item() + val_soft_dice += output["val_soft_dice"].sum().item() + val_hard_dice += output["val_hard_dice"].sum().item() + num_items += output["val_number"] + + mean_val_loss = (val_loss / num_items) + mean_val_soft_dice = (val_soft_dice / num_items) + mean_val_hard_dice = (val_hard_dice / num_items) + + wandb_logs = { + "val_soft_dice": mean_val_soft_dice, + #"val_hard_dice": mean_val_hard_dice, + "val_loss": mean_val_loss, + } + + self.log_dict(wandb_logs) + + # save the best model based on validation dice score + if mean_val_soft_dice > self.best_val_dice: + self.best_val_dice = mean_val_soft_dice + self.best_val_epoch = self.current_epoch + + # save the best model based on validation loss + if mean_val_loss < self.best_val_loss: + self.best_val_loss = mean_val_loss + self.best_val_epoch = self.current_epoch + + logger.info( + f"\nCurrent epoch: {self.current_epoch}" + f"\nAverage Soft Dice (VAL): {mean_val_soft_dice:.4f}" + # f"\nAverage Hard Dice (VAL): {mean_val_hard_dice:.4f}" + f"\nAverage DiceLoss (VAL): {mean_val_loss:.4f}" + f"\nBest Average DiceLoss: {self.best_val_loss:.4f} at Epoch: {self.best_val_epoch}" + f"\n----------------------------------------------------") + + # log on to wandb + self.log_dict(wandb_logs) + + # # plot the validation images + # fig0 = plot_slices(image=self.val_step_outputs[0]["val_image_0"], + # gt=self.val_step_outputs[0]["val_gt_0"], + # pred=self.val_step_outputs[0]["val_pred_0"],) + # wandb.log({"validation images": wandb.Image(fig0)}) + # plt.close(fig0) + + + # free up memory + self.val_step_outputs.clear() + wandb_logs.clear() + + + # -------------------------------- + # TESTING + # -------------------------------- + def test_step(self, batch, batch_idx): + + test_input = batch["inputs"] + # print(batch["label_meta_dict"]["filename_or_obj"][0]) + batch["pred"] = sliding_window_inference(test_input, self.inference_roi_size, + sw_batch_size=4, predictor=self.forward, overlap=0.5) + + # normalize the logits + batch["pred"] = F.relu(batch["pred"]) / F.relu(batch["pred"]).max() if bool(F.relu(batch["pred"]).max()) else F.relu(batch["pred"]) + + post_test_out = [self.test_post_pred(i) for i in decollate_batch(batch)] + + # make sure that the shapes of prediction and GT label are the same + # print(f"pred shape: {post_test_out[0]['pred'].shape}, label shape: {post_test_out[0]['label'].shape}") + assert post_test_out[0]['pred'].shape == post_test_out[0]['label'].shape + + pred, label = post_test_out[0]['pred'].cpu(), post_test_out[0]['label'].cpu() + + # NOTE: Important point from the SoftSeg paper - binarize predictions before computing metrics + # calculate soft and hard dice here (for quick overview), other metrics can be computed from + # the saved predictions using ANIMA + # 1. Dice Score + test_soft_dice = self.soft_dice_metric(pred, label) + + # binarizing the predictions + pred = (post_test_out[0]['pred'].detach().cpu() > 0.5).float() + label = (post_test_out[0]['label'].detach().cpu() > 0.5).float() + + # 1.1 Hard Dice Score + test_hard_dice = self.soft_dice_metric(pred.numpy(), label.numpy()) + + metrics_dict = { + "test_hard_dice": test_hard_dice, + "test_soft_dice": test_soft_dice, + } + self.test_step_outputs.append(metrics_dict) + + return metrics_dict + + def on_test_epoch_end(self): + + avg_hard_dice_test, std_hard_dice_test = np.stack([x["test_hard_dice"] for x in self.test_step_outputs]).mean(), \ + np.stack([x["test_hard_dice"] for x in self.test_step_outputs]).std() + avg_soft_dice_test, std_soft_dice_test = np.stack([x["test_soft_dice"] for x in self.test_step_outputs]).mean(), \ + np.stack([x["test_soft_dice"] for x in self.test_step_outputs]).std() + + logger.info(f"Test (Soft) Dice: {avg_soft_dice_test}") + logger.info(f"Test (Hard) Dice: {avg_hard_dice_test}") + + self.avg_test_dice, self.std_test_dice = avg_soft_dice_test, std_soft_dice_test + self.avg_test_dice_hard, self.std_test_dice_hard = avg_hard_dice_test, std_hard_dice_test + + # free up memory + self.test_step_outputs.clear() + +# -------------------------------- +# MAIN +# -------------------------------- +def main(): + # get the parser + parser = get_parser() + args= parser.parse_args() + + # load config file + with open(args.config, "r") as f: + config = yaml.load(f, Loader=yaml.FullLoader) + + # Setting the seed + pl.seed_everything(config["seed"], workers=True) + + # define root path for finding datalists + dataset_root = config["data"] + + # define optimizer + optimizer_class = torch.optim.Adam + + wandb.init(project=f'monai-unet-ms-lesion-seg-canproco', config=config) + + logger.info("Defining plans for nnUNet model ...") + + + # define model + # TODO: make the model deeper + # net = UNet( + # spatial_dims=3, + # in_channels=1, + # out_channels=1, + # channels=config['unet_channels'], + # strides=config['unet_strides'], + # kernel_size=3, + # up_kernel_size=3, + # num_res_units=0, + # act='PRELU', + # norm=Norm.INSTANCE, + # dropout=0.0, + # bias=True, + # adn_ordering='NDA', + # ) + # net=UNet( + # spatial_dims=3, + # in_channels=1, + # out_channels=1, + # channels=(32, 64, 128, 256), + # strides=(2, 2, 2 ), + + # # dropout=0.1 + # ) + net = AttentionUnet( + spatial_dims=3, + in_channels=2, + out_channels=1, + channels=(32, 64, 128, 256, 512, 1024), + strides=(2, 2, 2, 2, 2), + dropout=0.1, + ) + # net = BasicUNet(spatial_dims=3, features=(32, 64, 128, 256, 32), out_channels=1) + + # net = create_nnunet_from_plans() + + logger.add(os.path.join(config["log_path"], str(datetime.now()) + 'log.txt'), rotation="10 MB", level="INFO") + + + # define loss function + #loss_func = AdapWingLoss(theta=0.5, omega=8, alpha=2.1, epsilon=1, reduction="sum") + # loss_func = DiceLoss(sigmoid=True, smooth_dr=1e-4) + loss_func = DiceCELoss(sigmoid=True, smooth_dr=1e-4) + # loss_func = SoftDiceLoss(smooth=1e-5) + # NOTE: tried increasing omega and decreasing epsilon but results marginally worse than the above + # loss_func = AdapWingLoss(theta=0.5, omega=12, alpha=2.1, epsilon=0.5, reduction="sum") + #logger.info(f"Using AdapWingLoss with theta={loss_func.theta}, omega={loss_func.omega}, alpha={loss_func.alpha}, epsilon={loss_func.epsilon} ...") + logger.info(f"Using SoftDiceLoss ...") + # define callbacks + early_stopping = pl.callbacks.EarlyStopping( + monitor="val_loss", min_delta=0.00, + patience=config["early_stopping_patience"], + verbose=False, mode="min") + + lr_monitor = pl.callbacks.LearningRateMonitor(logging_interval='epoch') + + # i.e. train by loading weights from scratch + pl_model = Model(config, data_root=dataset_root, + optimizer_class=optimizer_class, loss_function=loss_func, net=net, + exp_id="test", results_path=config["best_model_path"]) + + # saving the best model based on validation loss + checkpoint_callback_loss = pl.callbacks.ModelCheckpoint( + dirpath=config["best_model_path"], filename='best_model', monitor='val_loss', + save_top_k=1, mode="min", save_last=True, save_weights_only=True) + + + logger.info(f"Starting training from scratch ...") + # wandb logger + exp_logger = pl.loggers.WandbLogger( + name="test", + save_dir="/home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/results", + group="test-on-canproco", + log_model=True, # save best model using checkpoint callback + project='ms-lesion-agnostic', + entity='pierre-louis-benveniste', + config=config) + + # Saving training script to wandb + wandb.save("ms-lesion-agnostic/monai/nnunet/config_fake.yml") + wandb.save("ms-lesion-agnostic/monai/nnunet/train_monai_unet_lightning_multichannel.py") + + + # initialise Lightning's trainer. + trainer = pl.Trainer( + devices=1, accelerator="gpu", + logger=exp_logger, + callbacks=[checkpoint_callback_loss, lr_monitor, early_stopping], + check_val_every_n_epoch=config["eval_num"], + max_epochs=config["max_iterations"], + precision=32, + # deterministic=True, + enable_progress_bar=True) + # profiler="simple",) # to profile the training time taken for each step + + # Train! + trainer.fit(pl_model) + logger.info(f" Training Done!") + + # Closing wandb log + wandb.finish() + + +if __name__ == "__main__": + main() \ No newline at end of file From b858fbc54af40d06b6781128cd387985c1a6f8d1 Mon Sep 17 00:00:00 2001 From: Pierre-Louis Benveniste Date: Wed, 3 Apr 2024 18:23:44 -0400 Subject: [PATCH 039/108] created file for unet training with multiple output channels --- ...onai_unet_lightning_multichannel_output.py | 725 ++++++++++++++++++ 1 file changed, 725 insertions(+) create mode 100644 monai/nnunet/train_monai_unet_lightning_multichannel_output.py diff --git a/monai/nnunet/train_monai_unet_lightning_multichannel_output.py b/monai/nnunet/train_monai_unet_lightning_multichannel_output.py new file mode 100644 index 0000000..f3232b8 --- /dev/null +++ b/monai/nnunet/train_monai_unet_lightning_multichannel_output.py @@ -0,0 +1,725 @@ +import os +import argparse +from datetime import datetime +from loguru import logger +import yaml +import nibabel as nib +from datetime import datetime + +import numpy as np +import wandb +import torch +import pytorch_lightning as pl +import torch.nn.functional as F +import matplotlib.pyplot as plt +from monai.metrics import DiceMetric +from monai.losses import DiceLoss, DiceCELoss + +# Added this to solve problem with too many files open +## Link here : https://github.com/pytorch/pytorch/issues/11201#issuecomment-421146936 +import torch.multiprocessing +torch.multiprocessing.set_sharing_strategy('file_system') + +from losses import AdapWingLoss, SoftDiceLoss + +from utils import dice_score, check_empty_patch, multiply_by_negative_one, plot_slices, create_nnunet_from_plans, print_data_types +from monai.networks.nets import UNet, BasicUNet, AttentionUnet + +from monai.networks.layers import Norm + + +from monai.transforms import ( + EnsureChannelFirstd, + Compose, + LoadImaged, + Orientationd, + RandFlipd, + RandShiftIntensityd, + Spacingd, + RandRotate90d, + NormalizeIntensityd, + RandCropByPosNegLabeld, + BatchInverseTransform, + RandAdjustContrastd, + AsDiscreted, + RandHistogramShiftd, + ResizeWithPadOrCropd, + EnsureTyped, + RandLambdad, + CropForegroundd, + RandGaussianNoised, + ConcatItemsd + ) + +from monai.utils import set_determinism +from monai.inferers import sliding_window_inference +import time +from monai.data import (DataLoader, CacheDataset, load_decathlon_datalist, decollate_batch) +from monai.transforms import (Compose, EnsureType, EnsureTyped, Invertd, SaveImage) + +# Added this because of following warning received: +## You are using a CUDA device ('NVIDIA RTX A6000') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` +## which will trade-off precision for performance. For more details, +## read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision +# torch.set_float32_matmul_precision('medium' | 'high') + + +def get_parser(): + """ + This function returns the parser for the command line arguments. + """ + parser = argparse.ArgumentParser(description="Train a nnUNet model using monai") + parser.add_argument("-c", "--config", help="Path to the config file (.yml file)", required=True) + return parser + + +# create a "model"-agnostic class with PL to use different models +class Model(pl.LightningModule): + def __init__(self, config, data_root, net, loss_function, optimizer_class, exp_id=None, results_path=None): + super().__init__() + self.cfg = config + self.save_hyperparameters(ignore=['net', 'loss_function']) + + self.root = data_root + self.net = net + self.lr = config["lr"] + self.loss_function = loss_function + self.optimizer_class = optimizer_class + self.save_exp_id = exp_id + self.results_path = results_path + + self.best_val_dice, self.best_val_epoch = 0, 0 + self.best_val_loss = float("inf") + + # define cropping and padding dimensions + # NOTE about patch sizes: nnUNet defines patches using the median size of the dataset as the reference + # BUT, for SC images, this means a lot of context outside the spinal cord is included in the patches + # which could be sub-optimal. + # On the other hand, ivadomed used a patch-size that's heavily padded along the R-L direction so that + # only the SC is in context. + self.spacing = config["spatial_size"] + self.voxel_cropping_size = self.inference_roi_size = config["spatial_size"] + + # define post-processing transforms for validation, nothing fancy just making sure that it's a tensor (default) + self.val_post_pred = self.val_post_label = Compose([EnsureType()]) + + # define evaluation metric + self.soft_dice_metric = dice_score + + # temp lists for storing outputs from training, validation, and testing + self.train_step_outputs = [] + self.val_step_outputs = [] + self.test_step_outputs = [] + + + # -------------------------------- + # FORWARD PASS + # -------------------------------- + def forward(self, x): + + out = self.net(x) + # # NOTE: MONAI's models only output the logits, not the output after the final activation function + # # https://docs.monai.io/en/0.9.0/_modules/monai/networks/nets/unetr.html#UNETR.forward refers to the + # # UnetOutBlock (https://docs.monai.io/en/0.9.0/_modules/monai/networks/blocks/dynunet_block.html#UnetOutBlock) + # # as the final block applied to the input, which is just a convolutional layer with no activation function + # # Hence, we are used Normalized ReLU to normalize the logits to the final output + # normalized_out = F.relu(out) / F.relu(out).max() if bool(F.relu(out).max()) else F.relu(out) + + return out # returns logits + + + # -------------------------------- + # DATA PREPARATION + # -------------------------------- + def prepare_data(self): + # set deterministic training for reproducibility + set_determinism(seed=self.cfg["seed"]) + + # define training and validation transforms + train_transforms = Compose( + [ + LoadImaged(keys=["image", "sc", "label"], reader="NibabelReader"), + EnsureChannelFirstd(keys=["image", "sc", "label"]), + Orientationd(keys=["image", "sc", "label"], axcodes="RPI"), + Spacingd( + keys=["image", "sc", "label"], + pixdim=self.cfg["pixdim"], + mode=(2, 1, 1), + ), + # CropForegroundd(keys=["image", "label"], source_key="label", margin=100), + ResizeWithPadOrCropd(keys=["image", "sc", "label"], spatial_size=self.cfg["spatial_size"],), + RandCropByPosNegLabeld( + keys=["image", "sc", "label"], + label_key="label", + spatial_size=self.cfg["spatial_size"], + pos=1, + neg=1, + num_samples=4, + image_key="image", + image_threshold=0, + ), + # Flips the image : left becomes right + RandFlipd( + keys=["image", "sc", "label"], + spatial_axis=[0], + prob=0.2, + ), + # Flips the image : supperior becomes inferior + RandFlipd( + keys=["image", "sc", "label"], + spatial_axis=[1], + prob=0.2, + ), + # Flips the image : anterior becomes posterior + RandFlipd( + keys=["image","sc", "label"], + spatial_axis=[2], + prob=0.2, + ), + # RandAdjustContrastd( + # keys=["image"], + # prob=0.2, + # gamma=(0.5, 4.5), + # invert_image=True, + # ), + # we add the multiplication of the image by -1 + # RandLambdad( + # keys='image', + # func=multiply_by_negative_one, + # prob=0.2 + # ), + + # Normalize the intensity of the image + NormalizeIntensityd( + keys=["image"], + nonzero=False, + channel_wise=False + ), + # RandGaussianNoised( + # keys=["image"], + # prob=0.2, + # ), + # RandShiftIntensityd( + # keys=["image"], + # offsets=0.1, + # prob=0.2, + # ), + # Concatenates the image and the sc + ConcatItemsd(keys=["sc", "label"], name="outputs"), + EnsureTyped(keys=["image", "outputs"]), + # AsDiscreted( + # keys=["label"], + # num_classes=2, + # threshold_values=True, + # logit_thresh=0.2, + # ) + ] + ) + val_transforms = Compose( + [ + LoadImaged(keys=["image", "sc", "label"], reader="NibabelReader"), + EnsureChannelFirstd(keys=["image", "sc", "label"]), + Orientationd(keys=["image", "sc", "label"], axcodes="RPI"), + Spacingd( + keys=["image", "sc", "label"], + pixdim=self.cfg["pixdim"], + mode=(2, 1, 1), + ), + # CropForegroundd(keys=["image", "label"], source_key="label", margin=100), + ResizeWithPadOrCropd(keys=["image", "sc", "label"], spatial_size=self.cfg["spatial_size"],), + # RandCropByPosNegLabeld( + # keys=["image", "label"], + # label_key="label", + # spatial_size=self.cfg["spatial_size"], + # pos=1, + # neg=1, + # num_samples=4, + # image_key="image", + # image_threshold=0, + # ), + + # Normalize the intensity of the image + NormalizeIntensityd( + keys=["image"], + nonzero=False, + channel_wise=False + ), + # Concatenates the image and the sc + ConcatItemsd(keys=["sc", "label"], name="outputs"), + EnsureTyped(keys=["image", "outputs"]), + # AsDiscreted( + # keys=["label"], + # num_classes=2, + # threshold_values=True, + # logit_thresh=0.2, + # ) + ] + ) + + # load the dataset + dataset = self.cfg["data"] + logger.info(f"Loading dataset: {dataset}") + train_files = load_decathlon_datalist(dataset, True, "train") + val_files = load_decathlon_datalist(dataset, True, "validation") + test_files = load_decathlon_datalist(dataset, True, "test") + + train_cache_rate = 0.5 + self.train_ds = CacheDataset(data=train_files, transform=train_transforms, cache_rate=train_cache_rate, num_workers=16) + self.val_ds = CacheDataset(data=val_files, transform=val_transforms, cache_rate=0.25, num_workers=16) + + # define test transforms + transforms_test = val_transforms + + # define post-processing transforms for testing; taken (with explanations) from + # https://github.com/Project-MONAI/tutorials/blob/main/3d_segmentation/torch/unet_inference_dict.py#L66 + self.test_post_pred = Compose([ + EnsureTyped(keys=["pred", "label"]), + Invertd(keys=["pred", "label"], transform=transforms_test, + orig_keys=["image", "label"], + meta_keys=["pred_meta_dict", "label_meta_dict"], + nearest_interp=False, to_tensor=True), + ]) + self.test_ds = CacheDataset(data=test_files, transform=transforms_test, cache_rate=0.1, num_workers=4) + + + # -------------------------------- + # DATA LOADERS + # -------------------------------- + def train_dataloader(self): + return DataLoader(self.train_ds, batch_size=self.cfg["batch_size"], shuffle=True, num_workers=16, + pin_memory=True, persistent_workers=True) + + def val_dataloader(self): + return DataLoader(self.val_ds, batch_size=1, shuffle=False, num_workers=16, pin_memory=True, + persistent_workers=True) + + def test_dataloader(self): + return DataLoader(self.test_ds, batch_size=1, shuffle=False, num_workers=1, pin_memory=True) + + + # -------------------------------- + # OPTIMIZATION + # -------------------------------- + def configure_optimizers(self): + optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr, weight_decay=self.cfg["weight_decay"]) + scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=self.cfg["max_iterations"]) + return [optimizer], [scheduler] + + + # -------------------------------- + # TRAINING + # -------------------------------- + def training_step(self, batch, batch_idx): + + inputs, labels = batch["image"], batch["outputs"] + + # # print(inputs.shape, labels.shape) + # input_0 = inputs[0].detach().cpu().squeeze() + # # print(input_0.shape) + # label_0 = labels[0].detach().cpu().squeeze() + + # time_0 = datetime.now() + + # # save input 0 in a nifti file + # input_0_nifti = nib.Nifti1Image(input_0.numpy(), affine=np.eye(4)) + # nib.save(input_0_nifti, f"~/ms_lesion_agnostic/temp/input_0_{time_0}.nii.gz") + + # # save label in a nifti file + # label_nifti = nib.Nifti1Image(label_0.numpy(), affine=np.eye(4)) + # nib.save(label_nifti, f"~/ms_lesion_agnostic/temp/label_0_{time_0}.nii.gz") + + + # # check if any label image patch is empty in the batch + if check_empty_patch(labels) is None: + # print(f"Empty label patch found. Skipping training step ...") + return None + + output = self.forward(inputs) # logits + # print(f"labels.shape: {labels.shape} \t output.shape: {output.shape}") + + # get probabilities from logits + output = F.relu(output) / F.relu(output).max() if bool(F.relu(output).max()) else F.relu(output) + + # calculate training loss + loss = self.loss_function(output, labels) + + # calculate train loss for the sc and the lesion + loss_sc = self.loss_function(output[:, 0, ...], labels[:, 0, ...]) + loss_lesion = self.loss_function(output[:, 1, ...], labels[:, 1, ...]) + + # calculate train dice + # NOTE: this is done on patches (and not entire 3D volume) because SlidingWindowInference is not used here + # So, take this dice score with a lot of salt + train_soft_dice = self.soft_dice_metric(output, labels) + + # calculate the dice for the sc and the lesion + train_soft_dice_sc = self.soft_dice_metric(output[:, 0, ...], labels[:, 0, ...]) + train_soft_dice_lesion = self.soft_dice_metric(output[:, 1, ...], labels[:, 1, ...]) + + metrics_dict = { + "loss": loss.cpu(), + "loss_sc": loss_sc.cpu(), + "loss_lesion": loss_lesion.cpu(), + "train_soft_dice": train_soft_dice.detach().cpu(), + "train_soft_dice_sc": train_soft_dice_sc.detach().cpu(), + "train_soft_dice_lesion": train_soft_dice_lesion.detach().cpu(), + "train_number": len(inputs), + "train_image": inputs[0].detach().cpu().squeeze(), + "train_gt_sc": labels[0][0].detach().cpu().squeeze(), + "train_gt_lesion": labels[0][1].detach().cpu().squeeze(), + "train_pred_sc": output[0][0].detach().cpu().squeeze(), + "train_pred_lesion": output[0][1].detach().cpu().squeeze(), + + } + self.train_step_outputs.append(metrics_dict) + + return metrics_dict + + def on_train_epoch_end(self): + + if self.train_step_outputs == []: + # means the training step was skipped because of empty input patch + return None + else: + train_loss, train_soft_dice = 0, 0 + train_loss_sc, train_loss_lesion = 0, 0 + train_soft_dice_sc, train_soft_dice_lesion = 0, 0 + num_items = len(self.train_step_outputs) + for output in self.train_step_outputs: + train_loss += output["loss"].item() + train_soft_dice += output["train_soft_dice"].item() + train_loss_sc += output["loss_sc"].item() + train_loss_lesion += output["loss_lesion"].item() + train_soft_dice_sc += output["train_soft_dice_sc"].item() + train_soft_dice_lesion += output["train_soft_dice_lesion"].item() + + mean_train_loss = (train_loss / num_items) + mean_train_soft_dice = (train_soft_dice / num_items) + mean_train_loss_sc = (train_loss_sc / num_items) + mean_train_loss_lesion = (train_loss_lesion / num_items) + mean_train_soft_dice_sc = (train_soft_dice_sc / num_items) + mean_train_soft_dice_lesion = (train_soft_dice_lesion / num_items) + + wandb_logs = { + "train_soft_dice": mean_train_soft_dice, + "train_loss": mean_train_loss, + "train_loss_sc": mean_train_loss_sc, + "train_loss_lesion": mean_train_loss_lesion, + "train_soft_dice_sc": mean_train_soft_dice_sc, + "train_soft_dice_lesion": mean_train_soft_dice_lesion + } + self.log_dict(wandb_logs) + + # plot the training images + fig = plot_slices(image=self.train_step_outputs[0]["train_image"], + gt=self.train_step_outputs[0]["train_gt_lesion"], + pred=self.train_step_outputs[0]["train_pred_lesion"], + ) + wandb.log({"training images lesion": wandb.Image(fig)}) + plt.close(fig) + + # plot the training images + fig2 = plot_slices(image=self.train_step_outputs[0]["train_image"], + gt=self.train_step_outputs[0]["train_gt_sc"], + pred=self.train_step_outputs[0]["train_pred_sc"], + ) + wandb.log({"training images sc": wandb.Image(fig2)}) + plt.close(fig2) + + # free up memory + self.train_step_outputs.clear() + wandb_logs.clear() + + + + # -------------------------------- + # VALIDATION + # -------------------------------- + def validation_step(self, batch, batch_idx): + + inputs, labels = batch["image"], batch["outputs"] + + # NOTE: this calculates the loss on the entire image after sliding window + outputs = sliding_window_inference(inputs, self.inference_roi_size, mode="gaussian", + sw_batch_size=4, predictor=self.forward, overlap=0.5,) + + # get probabilities from logits + outputs = F.relu(outputs) / F.relu(outputs).max() if bool(F.relu(outputs).max()) else F.relu(outputs) + + # calculate validation loss + loss = self.loss_function(outputs, labels) + + + # post-process for calculating the evaluation metric + post_outputs = [self.val_post_pred(i) for i in decollate_batch(outputs)] + post_labels = [self.val_post_label(i) for i in decollate_batch(labels)] + val_soft_dice = self.soft_dice_metric(post_outputs[0], post_labels[0]) + + hard_preds, hard_labels = (post_outputs[0].detach() > 0.5).float(), (post_labels[0].detach() > 0.5).float() + val_hard_dice = self.soft_dice_metric(hard_preds, hard_labels) + + # NOTE: there was a massive memory leak when storing cuda tensors in this dict. Hence, + # using .detach() to avoid storing the whole computation graph + # Ref: https://discuss.pytorch.org/t/cuda-memory-leak-while-training/82855/2 + metrics_dict = { + "val_loss": loss.detach().cpu(), + "val_soft_dice": val_soft_dice.detach().cpu(), + "val_hard_dice": val_hard_dice.detach().cpu(), + "val_number": len(post_outputs), + "val_image_0": inputs[0].detach().cpu().squeeze(), + "val_gt_0": labels[0].detach().cpu().squeeze(), + "val_pred_0": post_outputs[0].detach().cpu().squeeze(), + # "val_image_1": inputs[1].detach().cpu().squeeze(), + # "val_gt_1": labels[1].detach().cpu().squeeze(), + # "val_pred_1": post_outputs[1].detach().cpu().squeeze(), + } + self.val_step_outputs.append(metrics_dict) + + return metrics_dict + + def on_validation_epoch_end(self): + + val_loss, num_items, val_soft_dice, val_hard_dice = 0, 0, 0, 0 + for output in self.val_step_outputs: + val_loss += output["val_loss"].sum().item() + val_soft_dice += output["val_soft_dice"].sum().item() + val_hard_dice += output["val_hard_dice"].sum().item() + num_items += output["val_number"] + + mean_val_loss = (val_loss / num_items) + mean_val_soft_dice = (val_soft_dice / num_items) + mean_val_hard_dice = (val_hard_dice / num_items) + + wandb_logs = { + "val_soft_dice": mean_val_soft_dice, + #"val_hard_dice": mean_val_hard_dice, + "val_loss": mean_val_loss, + } + + self.log_dict(wandb_logs) + + # save the best model based on validation dice score + if mean_val_soft_dice > self.best_val_dice: + self.best_val_dice = mean_val_soft_dice + self.best_val_epoch = self.current_epoch + + # save the best model based on validation loss + if mean_val_loss < self.best_val_loss: + self.best_val_loss = mean_val_loss + self.best_val_epoch = self.current_epoch + + logger.info( + f"\nCurrent epoch: {self.current_epoch}" + f"\nAverage Soft Dice (VAL): {mean_val_soft_dice:.4f}" + # f"\nAverage Hard Dice (VAL): {mean_val_hard_dice:.4f}" + f"\nAverage DiceLoss (VAL): {mean_val_loss:.4f}" + f"\nBest Average DiceLoss: {self.best_val_loss:.4f} at Epoch: {self.best_val_epoch}" + f"\n----------------------------------------------------") + + # log on to wandb + self.log_dict(wandb_logs) + + # # plot the validation images + # fig0 = plot_slices(image=self.val_step_outputs[0]["val_image_0"], + # gt=self.val_step_outputs[0]["val_gt_0"], + # pred=self.val_step_outputs[0]["val_pred_0"],) + # wandb.log({"validation images": wandb.Image(fig0)}) + # plt.close(fig0) + + + # free up memory + self.val_step_outputs.clear() + wandb_logs.clear() + + + # -------------------------------- + # TESTING + # -------------------------------- + def test_step(self, batch, batch_idx): + + test_input = batch["inputs"] + # print(batch["label_meta_dict"]["filename_or_obj"][0]) + batch["pred"] = sliding_window_inference(test_input, self.inference_roi_size, + sw_batch_size=4, predictor=self.forward, overlap=0.5) + + # normalize the logits + batch["pred"] = F.relu(batch["pred"]) / F.relu(batch["pred"]).max() if bool(F.relu(batch["pred"]).max()) else F.relu(batch["pred"]) + + post_test_out = [self.test_post_pred(i) for i in decollate_batch(batch)] + + # make sure that the shapes of prediction and GT label are the same + # print(f"pred shape: {post_test_out[0]['pred'].shape}, label shape: {post_test_out[0]['label'].shape}") + assert post_test_out[0]['pred'].shape == post_test_out[0]['label'].shape + + pred, label = post_test_out[0]['pred'].cpu(), post_test_out[0]['label'].cpu() + + # NOTE: Important point from the SoftSeg paper - binarize predictions before computing metrics + # calculate soft and hard dice here (for quick overview), other metrics can be computed from + # the saved predictions using ANIMA + # 1. Dice Score + test_soft_dice = self.soft_dice_metric(pred, label) + + # binarizing the predictions + pred = (post_test_out[0]['pred'].detach().cpu() > 0.5).float() + label = (post_test_out[0]['label'].detach().cpu() > 0.5).float() + + # 1.1 Hard Dice Score + test_hard_dice = self.soft_dice_metric(pred.numpy(), label.numpy()) + + metrics_dict = { + "test_hard_dice": test_hard_dice, + "test_soft_dice": test_soft_dice, + } + self.test_step_outputs.append(metrics_dict) + + return metrics_dict + + def on_test_epoch_end(self): + + avg_hard_dice_test, std_hard_dice_test = np.stack([x["test_hard_dice"] for x in self.test_step_outputs]).mean(), \ + np.stack([x["test_hard_dice"] for x in self.test_step_outputs]).std() + avg_soft_dice_test, std_soft_dice_test = np.stack([x["test_soft_dice"] for x in self.test_step_outputs]).mean(), \ + np.stack([x["test_soft_dice"] for x in self.test_step_outputs]).std() + + logger.info(f"Test (Soft) Dice: {avg_soft_dice_test}") + logger.info(f"Test (Hard) Dice: {avg_hard_dice_test}") + + self.avg_test_dice, self.std_test_dice = avg_soft_dice_test, std_soft_dice_test + self.avg_test_dice_hard, self.std_test_dice_hard = avg_hard_dice_test, std_hard_dice_test + + # free up memory + self.test_step_outputs.clear() + +# -------------------------------- +# MAIN +# -------------------------------- +def main(): + # get the parser + parser = get_parser() + args= parser.parse_args() + + # load config file + with open(args.config, "r") as f: + config = yaml.load(f, Loader=yaml.FullLoader) + + # Setting the seed + pl.seed_everything(config["seed"], workers=True) + + # define root path for finding datalists + dataset_root = config["data"] + + # define optimizer + optimizer_class = torch.optim.Adam + + wandb.init(project=f'monai-unet-ms-lesion-seg-canproco', config=config) + + logger.info("Defining plans for nnUNet model ...") + + + # define model + # TODO: make the model deeper + # net = UNet( + # spatial_dims=3, + # in_channels=1, + # out_channels=1, + # channels=config['unet_channels'], + # strides=config['unet_strides'], + # kernel_size=3, + # up_kernel_size=3, + # num_res_units=0, + # act='PRELU', + # norm=Norm.INSTANCE, + # dropout=0.0, + # bias=True, + # adn_ordering='NDA', + # ) + # net=UNet( + # spatial_dims=3, + # in_channels=1, + # out_channels=1, + # channels=(32, 64, 128, 256), + # strides=(2, 2, 2 ), + + # # dropout=0.1 + # ) + net = AttentionUnet( + spatial_dims=3, + in_channels=1, + out_channels=2, + channels=(32, 64, 128, 256, 512, 1024), + strides=(2, 2, 2, 2, 2), + dropout=0.1, + ) + # net = BasicUNet(spatial_dims=3, features=(32, 64, 128, 256, 32), out_channels=1) + + # net = create_nnunet_from_plans() + + logger.add(os.path.join(config["log_path"], str(datetime.now()) + 'log.txt'), rotation="10 MB", level="INFO") + + + # define loss function + #loss_func = AdapWingLoss(theta=0.5, omega=8, alpha=2.1, epsilon=1, reduction="sum") + # loss_func = DiceLoss(sigmoid=True, smooth_dr=1e-4) + loss_func = DiceCELoss(sigmoid=True, smooth_dr=1e-4) + # loss_func = SoftDiceLoss(smooth=1e-5) + # NOTE: tried increasing omega and decreasing epsilon but results marginally worse than the above + # loss_func = AdapWingLoss(theta=0.5, omega=12, alpha=2.1, epsilon=0.5, reduction="sum") + #logger.info(f"Using AdapWingLoss with theta={loss_func.theta}, omega={loss_func.omega}, alpha={loss_func.alpha}, epsilon={loss_func.epsilon} ...") + logger.info(f"Using SoftDiceLoss ...") + # define callbacks + early_stopping = pl.callbacks.EarlyStopping( + monitor="val_loss", min_delta=0.00, + patience=config["early_stopping_patience"], + verbose=False, mode="min") + + lr_monitor = pl.callbacks.LearningRateMonitor(logging_interval='epoch') + + # i.e. train by loading weights from scratch + pl_model = Model(config, data_root=dataset_root, + optimizer_class=optimizer_class, loss_function=loss_func, net=net, + exp_id="test", results_path=config["best_model_path"]) + + # saving the best model based on validation loss + checkpoint_callback_loss = pl.callbacks.ModelCheckpoint( + dirpath=config["best_model_path"], filename='best_model', monitor='val_loss', + save_top_k=1, mode="min", save_last=True, save_weights_only=True) + + + logger.info(f"Starting training from scratch ...") + # wandb logger + exp_logger = pl.loggers.WandbLogger( + name="test", + save_dir="/home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/results", + group="test-on-canproco", + log_model=True, # save best model using checkpoint callback + project='ms-lesion-agnostic', + entity='pierre-louis-benveniste', + config=config) + + # Saving training script to wandb + wandb.save("ms-lesion-agnostic/monai/nnunet/config_fake.yml") + wandb.save("ms-lesion-agnostic/monai/nnunet/train_monai_unet_lightning_regionBased.py") + + + # initialise Lightning's trainer. + trainer = pl.Trainer( + devices=1, accelerator="gpu", + logger=exp_logger, + callbacks=[checkpoint_callback_loss, lr_monitor, early_stopping], + check_val_every_n_epoch=config["eval_num"], + max_epochs=config["max_iterations"], + precision=32, + # deterministic=True, + enable_progress_bar=True) + # profiler="simple",) # to profile the training time taken for each step + + # Train! + trainer.fit(pl_model) + logger.info(f" Training Done!") + + # Closing wandb log + wandb.finish() + + +if __name__ == "__main__": + main() \ No newline at end of file From afb7b4f4ab06f664c1a78a6739e934824590a4ce Mon Sep 17 00:00:00 2001 From: Pierre-Louis Benveniste Date: Wed, 3 Apr 2024 18:49:53 -0400 Subject: [PATCH 040/108] training script cleaned for ms lesion seg --- monai/nnunet/config_fake.yml | 15 +- monai/nnunet/train_monai_unet_lightning.py | 161 ++++++++++++--------- 2 files changed, 105 insertions(+), 71 deletions(-) diff --git a/monai/nnunet/config_fake.yml b/monai/nnunet/config_fake.yml index 06c7cd9..2e8ac4a 100644 --- a/monai/nnunet/config_fake.yml +++ b/monai/nnunet/config_fake.yml @@ -1,7 +1,8 @@ # Description: Configuration file for the UNETR model # Path to the data json file # data: /home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/data/monai_data/fake.json -data: /home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/data/monai_data/fake_sc.json +data: /home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/data/monai_data/fake_lesion_sc.json +# data: /home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/data/monai_data/fake_sc.json # data: /home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/data/monai_data/dataset_2024-03-13_seed42_canproco.json # Path to the output directory output_dir : /home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/results/ @@ -9,8 +10,8 @@ output_dir : /home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/results/ # pixdim : [1.0, 1.0, 1.0] pixdim : [0.7, 0.7, 0.7] # Spatial size of the input data -spatial_size : [32, 32, 128] # RL, AP, IS -batch_size : 8 +spatial_size : [64, 128, 128] # RL, AP, IS +batch_size : 2 # UNETR model parameters feature_size : 8 @@ -19,7 +20,7 @@ mlp_dim : 3072 num_heads : 12 # Optimizer parameters -lr : 0.0001 +lr : 0.001 weight_decay: 0.00001 early_stopping_patience : 100 @@ -40,4 +41,8 @@ seed : 42 # UNET model parameters unet_channels : [32, 64, 128, 256, 512, 1024] -unet_strides : [2, 2, 2, 2, 2, 2, 2] \ No newline at end of file +unet_strides : [2, 2, 2, 2, 2, 2, 2] + +#Attention Unet +channels : [32, 64, 128, 256, 512] +strides : [2, 2, 2, 2, 2] \ No newline at end of file diff --git a/monai/nnunet/train_monai_unet_lightning.py b/monai/nnunet/train_monai_unet_lightning.py index 48c5e87..3dbebb9 100644 --- a/monai/nnunet/train_monai_unet_lightning.py +++ b/monai/nnunet/train_monai_unet_lightning.py @@ -5,15 +5,13 @@ import yaml import nibabel as nib from datetime import datetime - import numpy as np import wandb import torch import pytorch_lightning as pl import torch.nn.functional as F import matplotlib.pyplot as plt -from monai.metrics import DiceMetric -from monai.losses import DiceLoss, DiceCELoss +import time # Added this to solve problem with too many files open ## Link here : https://github.com/pytorch/pytorch/issues/11201#issuecomment-421146936 @@ -24,10 +22,9 @@ from utils import dice_score, check_empty_patch, multiply_by_negative_one, plot_slices, create_nnunet_from_plans from monai.networks.nets import UNet, BasicUNet, AttentionUnet - +from monai.metrics import DiceMetric +from monai.losses import DiceLoss, DiceCELoss from monai.networks.layers import Norm - - from monai.transforms import ( EnsureChannelFirstd, Compose, @@ -47,14 +44,15 @@ EnsureTyped, RandLambdad, CropForegroundd, - RandGaussianNoised, - ) - + RandGaussianNoised, + LabelToContourd, + Invertd, + SaveImage, + EnsureType +) from monai.utils import set_determinism from monai.inferers import sliding_window_inference -import time from monai.data import (DataLoader, CacheDataset, load_decathlon_datalist, decollate_batch) -from monai.transforms import (Compose, EnsureType, EnsureTyped, Invertd, SaveImage) # Added this because of following warning received: ## You are using a CUDA device ('NVIDIA RTX A6000') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` @@ -78,7 +76,6 @@ def __init__(self, config, data_root, net, loss_function, optimizer_class, exp_i super().__init__() self.cfg = config self.save_hyperparameters(ignore=['net', 'loss_function']) - self.root = data_root self.net = net self.lr = config["lr"] @@ -145,18 +142,28 @@ def prepare_data(self): pixdim=self.cfg["pixdim"], mode=(2, 1), ), - # CropForegroundd(keys=["image", "label"], source_key="label", margin=100), - ResizeWithPadOrCropd(keys=["image", "label"], spatial_size=self.cfg["spatial_size"],), - RandCropByPosNegLabeld( + # # This crops the image around areas where the mask is non-zero + # # (the margin is added because otherwise the image would be just the size of the lesion) + # CropForegroundd( + # keys=["image", "label"], + # source_key="label", + # margin=100 + # ), + # This resizes the image and the label to the spatial size defined in the config + ResizeWithPadOrCropd( keys=["image", "label"], - label_key="label", spatial_size=self.cfg["spatial_size"], - pos=1, - neg=1, - num_samples=4, - image_key="image", - image_threshold=0, ), + # RandCropByPosNegLabeld( + # keys=["image", "label"], + # label_key="label", + # spatial_size=self.cfg["spatial_size"], + # pos=1, + # neg=1, + # num_samples=4, + # image_key="image", + # image_threshold=0, + # ), # Flips the image : left becomes right RandFlipd( keys=["image", "label"], @@ -175,33 +182,38 @@ def prepare_data(self): spatial_axis=[2], prob=0.2, ), - # RandAdjustContrastd( - # keys=["image"], - # prob=0.2, - # gamma=(0.5, 4.5), - # invert_image=True, - # ), - # we add the multiplication of the image by -1 - # RandLambdad( - # keys='image', - # func=multiply_by_negative_one, - # prob=0.2 - # ), + # # RandAdjustContrastd( + # # keys=["image"], + # # prob=0.2, + # # gamma=(0.5, 4.5), + # # invert_image=True, + # # ), + # # we add the multiplication of the image by -1 + # # RandLambdad( + # # keys='image', + # # func=multiply_by_negative_one, + # # prob=0.2 + # # ), + # Normalize the intensity of the image NormalizeIntensityd( keys=["image"], nonzero=False, channel_wise=False ), - # RandGaussianNoised( - # keys=["image"], - # prob=0.2, - # ), - # RandShiftIntensityd( + # LabelToContourd( # keys=["image"], - # offsets=0.1, - # prob=0.2, + # kernel_type='Laplace', # ), - EnsureTyped(keys=["image", "label"]), + # # RandGaussianNoised( + # # keys=["image"], + # # prob=0.2, + # # ), + # # RandShiftIntensityd( + # # keys=["image"], + # # offsets=0.1, + # # prob=0.2, + # # ), + # EnsureTyped(keys=["image", "label"]), # AsDiscreted( # keys=["label"], # num_classes=2, @@ -221,7 +233,10 @@ def prepare_data(self): mode=(2, 1), ), # CropForegroundd(keys=["image", "label"], source_key="label", margin=100), - ResizeWithPadOrCropd(keys=["image", "label"], spatial_size=self.cfg["spatial_size"],), + ResizeWithPadOrCropd( + keys=["image", "label"], + spatial_size=self.cfg["spatial_size"], + ), # RandCropByPosNegLabeld( # keys=["image", "label"], # label_key="label", @@ -232,12 +247,17 @@ def prepare_data(self): # image_key="image", # image_threshold=0, # ), + # This normalizes the intensity of the image NormalizeIntensityd( keys=["image"], nonzero=False, channel_wise=False ), - EnsureTyped(keys=["image", "label"]), + # LabelToContourd( + # keys=["image"], + # kernel_type='Laplace', + # ), + # EnsureTyped(keys=["image", "label"]), # AsDiscreted( # keys=["label"], # num_classes=2, @@ -304,6 +324,9 @@ def training_step(self, batch, batch_idx): inputs, labels = batch["image"], batch["label"] + # The following was done to debug : + # I was checking the image and the label to see if they were empty or not + # # print(inputs.shape, labels.shape) # input_0 = inputs[0].detach().cpu().squeeze() # # print(input_0.shape) @@ -322,7 +345,7 @@ def training_step(self, batch, batch_idx): # # check if any label image patch is empty in the batch if check_empty_patch(labels) is None: - # print(f"Empty label patch found. Skipping training step ...") + print(f"Empty label patch found. Skipping training step ...") return None output = self.forward(inputs) # logits @@ -420,12 +443,9 @@ def validation_step(self, batch, batch_idx): "val_soft_dice": val_soft_dice.detach().cpu(), "val_hard_dice": val_hard_dice.detach().cpu(), "val_number": len(post_outputs), - "val_image_0": inputs[0].detach().cpu().squeeze(), - "val_gt_0": labels[0].detach().cpu().squeeze(), - "val_pred_0": post_outputs[0].detach().cpu().squeeze(), - # "val_image_1": inputs[1].detach().cpu().squeeze(), - # "val_gt_1": labels[1].detach().cpu().squeeze(), - # "val_pred_1": post_outputs[1].detach().cpu().squeeze(), + "val_image": inputs[0].detach().cpu().squeeze(), + "val_gt": labels[0].detach().cpu().squeeze(), + "val_pred": post_outputs[0].detach().cpu().squeeze(), } self.val_step_outputs.append(metrics_dict) @@ -446,7 +466,7 @@ def on_validation_epoch_end(self): wandb_logs = { "val_soft_dice": mean_val_soft_dice, - #"val_hard_dice": mean_val_hard_dice, + # "val_hard_dice": mean_val_hard_dice, "val_loss": mean_val_loss, } @@ -473,13 +493,19 @@ def on_validation_epoch_end(self): # log on to wandb self.log_dict(wandb_logs) - # plot the validation images - fig0 = plot_slices(image=self.val_step_outputs[0]["val_image_0"], - gt=self.val_step_outputs[0]["val_gt_0"], - pred=self.val_step_outputs[0]["val_pred_0"],) - wandb.log({"validation images": wandb.Image(fig0)}) - plt.close(fig0) + # plot 1 validation image + fig = plot_slices(image=self.val_step_outputs[0]["val_image"], + gt=self.val_step_outputs[0]["val_gt"], + pred=self.val_step_outputs[0]["val_pred"],) + wandb.log({"validation image 1": wandb.Image(fig)}) + plt.close(fig) + # plot another validation image + fig0 = plot_slices(image=self.val_step_outputs[1]["val_image"], + gt=self.val_step_outputs[1]["val_gt"], + pred=self.val_step_outputs[1]["val_pred"],) + wandb.log({"validation image 2": wandb.Image(fig0)}) + plt.close(fig0) # free up memory self.val_step_outputs.clear() @@ -565,13 +591,13 @@ def main(): # define optimizer optimizer_class = torch.optim.Adam - wandb.init(project=f'monai-unet-ms-lesion-seg-canproco', config=config) + wandb.init(project=f'monai-ms-lesion-seg', config=config) - logger.info("Defining plans for nnUNet model ...") + logger.info("Building the model ...") # define model - # TODO: make the model deeper + # net = UNet( # spatial_dims=3, # in_channels=1, @@ -587,6 +613,7 @@ def main(): # bias=True, # adn_ordering='NDA', # ) + # net=UNet( # spatial_dims=3, # in_channels=1, @@ -596,14 +623,16 @@ def main(): # # dropout=0.1 # ) + net = AttentionUnet( spatial_dims=3, in_channels=1, out_channels=1, - channels=(64, 128, 256, 512, 1024, 2048), - strides=(2, 2, 2, 2, 2), - dropout=0.0 + channels=(32, 64, 128), + strides=(2, 2, 2,), + dropout=0.1, ) + # net = BasicUNet(spatial_dims=3, features=(32, 64, 128, 256, 32), out_channels=1) # net = create_nnunet_from_plans() @@ -613,13 +642,13 @@ def main(): # define loss function #loss_func = AdapWingLoss(theta=0.5, omega=8, alpha=2.1, epsilon=1, reduction="sum") - # loss_func = DiceLoss(sigmoid=True, smooth_dr=1e-4) - loss_func = DiceCELoss(sigmoid=True, smooth_dr=1e-4) + # loss_func = DiceLoss(sigmoid=False, smooth_dr=1e-4) + loss_func = DiceCELoss(sigmoid=False, smooth_dr=1e-4) # loss_func = SoftDiceLoss(smooth=1e-5) # NOTE: tried increasing omega and decreasing epsilon but results marginally worse than the above # loss_func = AdapWingLoss(theta=0.5, omega=12, alpha=2.1, epsilon=0.5, reduction="sum") #logger.info(f"Using AdapWingLoss with theta={loss_func.theta}, omega={loss_func.omega}, alpha={loss_func.alpha}, epsilon={loss_func.epsilon} ...") - logger.info(f"Using SoftDiceLoss ...") + logger.info(f"Using DiceCELoss ...") # define callbacks early_stopping = pl.callbacks.EarlyStopping( monitor="val_loss", min_delta=0.00, From 9d6af9b8285beed66289543e9d477b5307008eb3 Mon Sep 17 00:00:00 2001 From: Pierre-Louis Benveniste Date: Thu, 4 Apr 2024 12:44:07 -0400 Subject: [PATCH 041/108] moved all files in monai and removed nnunet folder --- monai/{nnunet => }/1_create_msd_data.py | 0 monai/{nnunet => }/config_fake.yml | 5 +- monai/{nnunet => }/losses.py | 0 monai/{nnunet => }/requirements.txt | 0 monai/{nnunet => }/train_monai_UNETR.py | 0 .../train_monai_unet_lightning.py | 90 +++++++++---------- ...monai_unet_lightning_multichannel_input.py | 0 ...onai_unet_lightning_multichannel_output.py | 0 monai/{nnunet => }/utils.py | 0 9 files changed, 48 insertions(+), 47 deletions(-) rename monai/{nnunet => }/1_create_msd_data.py (100%) rename monai/{nnunet => }/config_fake.yml (90%) rename monai/{nnunet => }/losses.py (100%) rename monai/{nnunet => }/requirements.txt (100%) rename monai/{nnunet => }/train_monai_UNETR.py (100%) rename monai/{nnunet => }/train_monai_unet_lightning.py (94%) rename monai/{nnunet => }/train_monai_unet_lightning_multichannel_input.py (100%) rename monai/{nnunet => }/train_monai_unet_lightning_multichannel_output.py (100%) rename monai/{nnunet => }/utils.py (100%) diff --git a/monai/nnunet/1_create_msd_data.py b/monai/1_create_msd_data.py similarity index 100% rename from monai/nnunet/1_create_msd_data.py rename to monai/1_create_msd_data.py diff --git a/monai/nnunet/config_fake.yml b/monai/config_fake.yml similarity index 90% rename from monai/nnunet/config_fake.yml rename to monai/config_fake.yml index 2e8ac4a..bf666fe 100644 --- a/monai/nnunet/config_fake.yml +++ b/monai/config_fake.yml @@ -1,7 +1,8 @@ # Description: Configuration file for the UNETR model # Path to the data json file # data: /home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/data/monai_data/fake.json -data: /home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/data/monai_data/fake_lesion_sc.json +# data: /home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/data/monai_data/fake_lesion_sc.json +data: /home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/data/monai_data/fake_10_each.json # data: /home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/data/monai_data/fake_sc.json # data: /home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/data/monai_data/dataset_2024-03-13_seed42_canproco.json # Path to the output directory @@ -11,7 +12,7 @@ output_dir : /home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/results/ pixdim : [0.7, 0.7, 0.7] # Spatial size of the input data spatial_size : [64, 128, 128] # RL, AP, IS -batch_size : 2 +batch_size : 4 # UNETR model parameters feature_size : 8 diff --git a/monai/nnunet/losses.py b/monai/losses.py similarity index 100% rename from monai/nnunet/losses.py rename to monai/losses.py diff --git a/monai/nnunet/requirements.txt b/monai/requirements.txt similarity index 100% rename from monai/nnunet/requirements.txt rename to monai/requirements.txt diff --git a/monai/nnunet/train_monai_UNETR.py b/monai/train_monai_UNETR.py similarity index 100% rename from monai/nnunet/train_monai_UNETR.py rename to monai/train_monai_UNETR.py diff --git a/monai/nnunet/train_monai_unet_lightning.py b/monai/train_monai_unet_lightning.py similarity index 94% rename from monai/nnunet/train_monai_unet_lightning.py rename to monai/train_monai_unet_lightning.py index 3dbebb9..1f330c0 100644 --- a/monai/nnunet/train_monai_unet_lightning.py +++ b/monai/train_monai_unet_lightning.py @@ -154,16 +154,16 @@ def prepare_data(self): keys=["image", "label"], spatial_size=self.cfg["spatial_size"], ), - # RandCropByPosNegLabeld( - # keys=["image", "label"], - # label_key="label", - # spatial_size=self.cfg["spatial_size"], - # pos=1, - # neg=1, - # num_samples=4, - # image_key="image", - # image_threshold=0, - # ), + RandCropByPosNegLabeld( + keys=["image", "label"], + label_key="label", + spatial_size=self.cfg["spatial_size"], + pos=1, + neg=1, + num_samples=4, + image_key="image", + image_threshold=0, + ), # Flips the image : left becomes right RandFlipd( keys=["image", "label"], @@ -182,18 +182,18 @@ def prepare_data(self): spatial_axis=[2], prob=0.2, ), - # # RandAdjustContrastd( - # # keys=["image"], - # # prob=0.2, - # # gamma=(0.5, 4.5), - # # invert_image=True, - # # ), + # RandAdjustContrastd( + # keys=["image"], + # prob=0.2, + # gamma=(0.5, 4.5), + # invert_image=True, + # ), # # we add the multiplication of the image by -1 - # # RandLambdad( - # # keys='image', - # # func=multiply_by_negative_one, - # # prob=0.2 - # # ), + # RandLambdad( + # keys='image', + # func=multiply_by_negative_one, + # prob=0.2 + # ), # Normalize the intensity of the image NormalizeIntensityd( keys=["image"], @@ -204,15 +204,15 @@ def prepare_data(self): # keys=["image"], # kernel_type='Laplace', # ), - # # RandGaussianNoised( - # # keys=["image"], - # # prob=0.2, - # # ), - # # RandShiftIntensityd( - # # keys=["image"], - # # offsets=0.1, - # # prob=0.2, - # # ), + # RandGaussianNoised( + # keys=["image"], + # prob=0.2, + # ), + # RandShiftIntensityd( + # keys=["image"], + # offsets=0.1, + # prob=0.2, + # ), # EnsureTyped(keys=["image", "label"]), # AsDiscreted( # keys=["label"], @@ -237,16 +237,16 @@ def prepare_data(self): keys=["image", "label"], spatial_size=self.cfg["spatial_size"], ), - # RandCropByPosNegLabeld( - # keys=["image", "label"], - # label_key="label", - # spatial_size=self.cfg["spatial_size"], - # pos=1, - # neg=1, - # num_samples=4, - # image_key="image", - # image_threshold=0, - # ), + RandCropByPosNegLabeld( + keys=["image", "label"], + label_key="label", + spatial_size=self.cfg["spatial_size"], + pos=1, + neg=1, + num_samples=4, + image_key="image", + image_threshold=0, + ), # This normalizes the intensity of the image NormalizeIntensityd( keys=["image"], @@ -275,8 +275,8 @@ def prepare_data(self): test_files = load_decathlon_datalist(dataset, True, "test") train_cache_rate = 0.5 - self.train_ds = CacheDataset(data=train_files, transform=train_transforms, cache_rate=train_cache_rate, num_workers=16) - self.val_ds = CacheDataset(data=val_files, transform=val_transforms, cache_rate=0.25, num_workers=16) + self.train_ds = CacheDataset(data=train_files, transform=train_transforms, cache_rate=train_cache_rate, num_workers=8) + self.val_ds = CacheDataset(data=val_files, transform=val_transforms, cache_rate=0.25, num_workers=8) # define test transforms transforms_test = val_transforms @@ -297,11 +297,11 @@ def prepare_data(self): # DATA LOADERS # -------------------------------- def train_dataloader(self): - return DataLoader(self.train_ds, batch_size=self.cfg["batch_size"], shuffle=True, num_workers=16, + return DataLoader(self.train_ds, batch_size=self.cfg["batch_size"], shuffle=True, num_workers=8, pin_memory=True, persistent_workers=True) def val_dataloader(self): - return DataLoader(self.val_ds, batch_size=1, shuffle=False, num_workers=16, pin_memory=True, + return DataLoader(self.val_ds, batch_size=1, shuffle=False, num_workers=8, pin_memory=True, persistent_workers=True) def test_dataloader(self): @@ -623,7 +623,7 @@ def main(): # # dropout=0.1 # ) - + net = AttentionUnet( spatial_dims=3, in_channels=1, diff --git a/monai/nnunet/train_monai_unet_lightning_multichannel_input.py b/monai/train_monai_unet_lightning_multichannel_input.py similarity index 100% rename from monai/nnunet/train_monai_unet_lightning_multichannel_input.py rename to monai/train_monai_unet_lightning_multichannel_input.py diff --git a/monai/nnunet/train_monai_unet_lightning_multichannel_output.py b/monai/train_monai_unet_lightning_multichannel_output.py similarity index 100% rename from monai/nnunet/train_monai_unet_lightning_multichannel_output.py rename to monai/train_monai_unet_lightning_multichannel_output.py diff --git a/monai/nnunet/utils.py b/monai/utils.py similarity index 100% rename from monai/nnunet/utils.py rename to monai/utils.py From 97b495f635964e0f45058df3391ff0b592593e99 Mon Sep 17 00:00:00 2001 From: Pierre-Louis Benveniste Date: Thu, 4 Apr 2024 13:21:35 -0400 Subject: [PATCH 042/108] renamed config_fake.yml to config.yml --- monai/{config_fake.yml => config.yml} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename monai/{config_fake.yml => config.yml} (100%) diff --git a/monai/config_fake.yml b/monai/config.yml similarity index 100% rename from monai/config_fake.yml rename to monai/config.yml From 1df0305071467a8df0d3b35757b55dbe8336faba Mon Sep 17 00:00:00 2001 From: Pierre-Louis Benveniste Date: Thu, 4 Apr 2024 13:22:57 -0400 Subject: [PATCH 043/108] removed useless previous training script train_monai_UNETR.py --- monai/train_monai_UNETR.py | 461 ------------------------------------- 1 file changed, 461 deletions(-) delete mode 100644 monai/train_monai_UNETR.py diff --git a/monai/train_monai_UNETR.py b/monai/train_monai_UNETR.py deleted file mode 100644 index ed40c29..0000000 --- a/monai/train_monai_UNETR.py +++ /dev/null @@ -1,461 +0,0 @@ -""" -This script is used to train a UNETR model. - -It takes as input the config file (a JSON file) - -This script is inspired from : https://github.com/Project-MONAI/tutorials/blob/main/3d_segmentation/unetr_btcv_segmentation_3d.ipynb - -Args: - -c: path to the config file - -Example: - python train_monai_nnunet.py -c /path/to/nnunet/config.json - -Pierre-Louis Benveniste -""" - -import argparse -import json -import os -import sys -import monai -from tqdm import tqdm -import matplotlib.pyplot as plt -import yaml -import numpy as np -import wandb -import time -from loguru import logger - -from monai.networks.layers import Norm - -os.environ["PYTORCH_USE_CUDA_DSA"] = "1" - - -#Transforms import -from monai.transforms import ( - EnsureChannelFirstd, - Compose, - LoadImaged, - Orientationd, - RandFlipd, - RandShiftIntensityd, - Spacingd, - RandRotate90d, - NormalizeIntensityd, - RandCropByPosNegLabeld, - BatchInverseTransform, - RandAdjustContrastd, - AsDiscreted, - RandHistogramShiftd, - ResizeWithPadOrCropd, - EnsureTyped - ) - -# Dataset import -from monai.data import DataLoader, CacheDataset, load_decathlon_datalist, Dataset - - -# model import -import torch -from monai.networks.nets import UNETR, UNet -from monai.losses import DiceLoss - -# For training and validation -from monai.data import decollate_batch -from monai.inferers import sliding_window_inference -from monai.metrics import DiceMetric -from monai.transforms import AsDiscrete - - - - -def get_parser(): - """ - This function returns the parser for the command line arguments. - """ - parser = argparse.ArgumentParser(description="Train a nnUNet model using monai") - parser.add_argument("-c", "--config", help="Path to the config file (.yml file)", required=True) - return parser - - -# def validation(model, epoch_iterator_val, config, post_label, post_pred, dice_metric, global_step): -# model.eval() -# with torch.no_grad(): -# for batch in epoch_iterator_val: -# val_inputs, val_labels = (batch["image"].cuda(), batch["label"].cuda()) -# val_outputs = model(val_inputs) -# dice_metric(y_pred=val_outputs, y=val_labels) -# # val_outputs = sliding_window_inference(val_inputs, config["spatial_size"], 1, model) -# # val_labels_list = decollate_batch(val_labels) -# # val_labels_convert = [post_label(val_label_tensor) for val_label_tensor in val_labels_list] -# # val_outputs_list = decollate_batch(val_outputs) -# # val_output_convert = [post_pred(val_pred_tensor) for val_pred_tensor in val_outputs_list] -# # dice_metric(y_pred=val_output_convert, y=val_labels_convert) -# epoch_iterator_val.set_description("Validate (%d / %d Steps)" % (global_step, 10.0)) # noqa: B038 -# # for batch in epoch_iterator_val: -# # val_inputs, val_labels = ( -# # batch["image"].cuda(), -# # batch["label"].cuda(), -# # ) -# # # TODO: parametrize this -# # roi_size = config["spatial_size"] -# # sw_batch_size = 4 -# # val_outputs = sliding_window_inference( -# # val_inputs, roi_size, sw_batch_size, model) - -# # val_outputs = [post_pred(i) for i in decollate_batch(val_outputs)] -# # val_labels = [post_label(i) for i in decollate_batch(val_labels)] -# # # compute metric for current iteration -# # dice_metric(y_pred=val_outputs, y=val_labels) - -# mean_dice_val = dice_metric.aggregate().item() -# print("Mean dice val: ", mean_dice_val) -# dice_metric.reset() - -# return mean_dice_val - -# def train(model, config, global_step, train_loader, dice_val_best, global_step_best, loss_function, optimizer, epoch_loss_values, metric_values, val_loader, post_label, post_pred, dice_metric): -# model.train() -# epoch_loss = 0 -# step = 0 -# epoch_iterator = tqdm(train_loader, desc="Training (X / X Steps) (loss=X.X)", dynamic_ncols=True) -# for step, batch in enumerate(epoch_iterator): -# step += 1 -# x, y = (batch["image"].cuda(), batch["label"].cuda()) -# logit_map = model(x) -# loss = loss_function(logit_map, y) -# loss.backward() -# epoch_loss += loss.item() -# optimizer.step() -# optimizer.zero_grad() -# epoch_iterator.set_description( # noqa: B038 -# "Training (%d / %d Steps) (loss=%2.5f)" % (global_step, config["max_iterations"], loss) -# ) -# if (global_step % config["eval_num"] == 0 and global_step != 0) or global_step == config["max_iterations"]: -# epoch_iterator_val = tqdm(val_loader, desc="Validate (X / X Steps) (dice=X.X)", dynamic_ncols=True) -# dice_val = validation(model, epoch_iterator_val, config, post_label, post_pred, dice_metric, global_step) -# epoch_loss /= step -# epoch_loss_values.append(epoch_loss) -# metric_values.append(dice_val) -# if dice_val > dice_val_best: -# dice_val_best = dice_val -# global_step_best = global_step -# torch.save(model.state_dict(), config["best_model_path"]) -# print( -# "Model Was Saved ! Current Best Avg. Dice: {} Current Avg. Dice: {}".format(dice_val_best, dice_val) -# ) -# else: -# print( -# "Model Was Not Saved ! Current Best Avg. Dice: {} Current Avg. Dice: {}".format( -# dice_val_best, dice_val -# ) -# ) -# global_step += 1 -# return global_step, dice_val_best, global_step_best - - -def main(): - """ - Main function of the script. - """ - - # We get the parser and parse the arguments - parser = get_parser() - args = parser.parse_args() - - # We load the config file (a yml file) - # load config file - with open(args.config, "r") as f: - config = yaml.load(f, Loader=yaml.FullLoader) - - - ##### ------------------ - # Monai should be installed with pip install monai[all] (to get all readers) - # We define the trasnformations for training and validation - train_transforms = Compose( - [ - LoadImaged(keys=["image", "label"], reader="NibabelReader"), - EnsureChannelFirstd(keys=["image", "label"]), - Orientationd(keys=["image", "label"], axcodes="RSP"), - Spacingd( - keys=["image", "label"], - pixdim=config["pixdim"], - mode=(2, 1), - ), - ResizeWithPadOrCropd(keys=["image", "label"], spatial_size=config["spatial_size"],), - RandCropByPosNegLabeld( - keys=["image", "label"], - label_key="label", - spatial_size=config["spatial_size"], - pos=1, - neg=1, - num_samples=4, - image_key="image", - image_threshold=0, - ), - # Flips the image : left becomes right - RandFlipd( - keys=["image", "label"], - spatial_axis=[0], - prob=0.2, - ), - # Flips the image : supperior becomes inferior - RandFlipd( - keys=["image", "label"], - spatial_axis=[1], - prob=0.2, - ), - # Flips the image : anterior becomes posterior - RandFlipd( - keys=["image", "label"], - spatial_axis=[2], - prob=0.2, - ), - RandAdjustContrastd( - keys=["image"], - prob=0.2, - gamma=(0.5, 4.5), - invert_image=True, - ), - NormalizeIntensityd( - keys=["image", "label"], - nonzero=False, - channel_wise=False - ), - EnsureTyped(keys=["image", "label"]), - AsDiscreted( - keys=["label"], - num_classes=2, - threshold_values=True, - logit_thresh=0.2, - ) - ] - ) - val_transforms = Compose( - [ - LoadImaged(keys=["image", "label"], reader="NibabelReader"), - EnsureChannelFirstd(keys=["image", "label"]), - Orientationd(keys=["image", "label"], axcodes="RSP"), - Spacingd( - keys=["image", "label"], - pixdim=config["pixdim"], - mode=(2, 1), - ), - ResizeWithPadOrCropd(keys=["image", "label"], spatial_size=config["spatial_size"],), - RandCropByPosNegLabeld( - keys=["image", "label"], - label_key="label", - spatial_size=config["spatial_size"], - pos=1, - neg=1, - num_samples=4, - image_key="image", - image_threshold=0, - ), - NormalizeIntensityd( - keys=["image", "label"], - nonzero=False, - channel_wise=False - ), - EnsureTyped(keys=["image", "label"]), - AsDiscreted( - keys=["label"], - num_classes=2, - threshold_values=True, - logit_thresh=0.2, - ) - ] - ) - - # Path to data split (JSON file) - data_split_json_path = config["data"] - # We load the data lists - with open(data_split_json_path, "r") as f: - data = json.load(f) - train_list = data["train"] - val_list = data["validation"] - - # Path to the output directory - output_dir = config["output_dir"] - if not os.path.exists(output_dir): - os.makedirs(output_dir) - - # We load the train and validation data - logger.info("Loading the training and validation data...") - train_ds = CacheDataset( - data=train_list, - transform=train_transforms, - cache_rate=0.1, - num_workers=2 - ) - train_loader = DataLoader(train_ds, batch_size=config["batch_size"], shuffle=True, num_workers=4, pin_memory=True, persistent_workers=True) - val_ds = CacheDataset( - data=val_list, - transform=val_transforms, - cache_rate=0.1, - num_workers=0, - ) - val_loader = DataLoader(val_ds, batch_size=1, shuffle=False, num_workers=0, pin_memory=True) - - # plot 3 image and save them - fig, axes = plt.subplots(1, 3, figsize=(15, 5)) - for i, ax in enumerate(axes): - img = train_ds[i][0]['image'] - ax.imshow(img[0, 7, :, :], cmap="gray") - ax.set_title(f"Image {i+1}") - ax.axis('on') - plt.savefig(os.path.join(output_dir, "image.png")) - - - - print("Preparing the UNET model...") - # we define the device to use - device = torch.device("cuda:0") - - model = UNet( - spatial_dims=3, - in_channels=1, - out_channels=1, - channels=(16, 32, 64, 128, 256), - strides=(2, 2, 2, 2), - kernel_size=3, - up_kernel_size=3, - num_res_units=0, - act='PRELU', - norm=Norm.BATCH, - dropout=0.0, - bias=True, - adn_ordering='NDA', - ).to(device) - - loss_function = DiceLoss(to_onehot_y=True, softmax=True) - optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-5) - dice_metric = DiceMetric(include_background=False, reduction="mean") - torch.backends.cudnn.benchmark = True - - # initialize wandb - wandb.init(project=config["experiment_name"], config=config) - - # 🐝 Log gen gradients of the models to wandb - wandb.watch(model, log_freq=100) - - # 🐝 Add training script as an artifact - artifact_script = wandb.Artifact(name='training', type='file') - artifact_script.add_file(local_path=os.path.abspath(__file__), name=os.path.basename(__file__)) - wandb.log_artifact(artifact_script) - - epoch_loss_values = [] - step_loss_values = [] - val_loss_values = [] - best_val_loss = 1000.0 - - for epoch in range(config["max_iterations"]): - logger.info("-" * 10) - logger.info(f"epoch {epoch + 1}/{config['max_iterations']}") - model.train() - epoch_loss = 0 - epoch_cl_loss = 0 - epoch_recon_loss = 0 - step = 0 - - for batch_data in train_loader: - step += 1 - start_time = time.time() - - inputs, gt_input = ( - batch_data["image"].to(device), - batch_data["label"].to(device), - ) - - optimizer.zero_grad() - output = model(inputs) - - loss = loss_function(output, gt_input) - loss.detach().cpu() - - loss.backward() - optimizer.step() - epoch_loss += loss.item() - step_loss_values.append(loss.item()) - - - end_time = time.time() - logger.info( - f"{step}/{len(train_list) // train_loader.batch_size}, " - f"train_loss: {loss.item():.4f}, " - f"time taken: {end_time-start_time}s" - ) - - wandb.log({"Training/loss": loss.item()}) - - epoch_loss /= step - - epoch_loss_values.append(epoch_loss) - - # 🐝 log train_loss averaged over epoch to wandb - wandb.log({"Training/loss_epoch": epoch_loss}) - - - - logger.info(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}") - - if epoch % config["eval_num"] == 0: - logger.info("Entering Validation for epoch: {}".format(epoch + 1)) - total_val_loss = 0 - val_step = 0 - model.eval() - for val_batch in val_loader: - val_step += 1 - start_time = time.time() - inputs, gt_input = ( - val_batch["image"].to(device), - val_batch["label"].to(device), - ) - outputs = model(inputs) - val_loss = loss_function(outputs, gt_input) - total_val_loss += val_loss.item() - end_time = time.time() - - total_val_loss /= val_step - val_loss_values.append(total_val_loss) - - wandb.log({"Validation loss": total_val_loss}) - - logger.info(f"epoch {epoch + 1} Validation avg loss: {total_val_loss:.4f}, " f"time taken: {end_time-start_time}s") - - if total_val_loss < best_val_loss: - logger.info(f"Saving new model based on validation loss {total_val_loss:.4f}") - best_val_loss = total_val_loss - checkpoint = {"epoch": config["max_iterations"], "state_dict": model.state_dict(), "optimizer": optimizer.state_dict()} - torch.save(checkpoint, config["best_model_path"]) - - print("Done") - - - # # We then train the model - # post_label = AsDiscrete(to_onehot=2) - # post_pred = AsDiscrete(argmax=True, to_onehot=2) - # dice_metric = DiceMetric(include_background=False, reduction="mean", get_not_nans=False) - # global_step = 0 - # dice_val_best = 0.0 - # global_step_best = 0 - # epoch_loss_values = [] - # metric_values = [] - # while global_step < config["max_iterations"]: - # global_step, dice_val_best, global_step_best = train(model, config, global_step, train_loader, dice_val_best, global_step_best, - # loss_function, optimizer, epoch_loss_values, metric_values, val_loader, post_label, post_pred, dice_metric) - # model.load_state_dict(torch.load(config["best_model_path"])) - - # print(f"train completed, best_metric: {dice_val_best:.4f} " f"at iteration: {global_step_best}") - - -if __name__ == "__main__": - main() - - - - - - - From 27eb1854006d94cbece2cdc28b3d8a6d056ed114 Mon Sep 17 00:00:00 2001 From: Pierre-Louis Benveniste Date: Thu, 4 Apr 2024 22:31:03 -0400 Subject: [PATCH 044/108] script for training unet with finetuning data-aug parameters --- monai/train_monai_unet_lightning.py | 80 ++++++++++++++--------------- 1 file changed, 40 insertions(+), 40 deletions(-) diff --git a/monai/train_monai_unet_lightning.py b/monai/train_monai_unet_lightning.py index 1f330c0..bfb2800 100644 --- a/monai/train_monai_unet_lightning.py +++ b/monai/train_monai_unet_lightning.py @@ -154,16 +154,16 @@ def prepare_data(self): keys=["image", "label"], spatial_size=self.cfg["spatial_size"], ), - RandCropByPosNegLabeld( - keys=["image", "label"], - label_key="label", - spatial_size=self.cfg["spatial_size"], - pos=1, - neg=1, - num_samples=4, - image_key="image", - image_threshold=0, - ), + # RandCropByPosNegLabeld( + # keys=["image", "label"], + # label_key="label", + # spatial_size=self.cfg["spatial_size"], + # pos=1, + # neg=1, + # num_samples=4, + # image_key="image", + # image_threshold=0, + # ), # Flips the image : left becomes right RandFlipd( keys=["image", "label"], @@ -237,16 +237,16 @@ def prepare_data(self): keys=["image", "label"], spatial_size=self.cfg["spatial_size"], ), - RandCropByPosNegLabeld( - keys=["image", "label"], - label_key="label", - spatial_size=self.cfg["spatial_size"], - pos=1, - neg=1, - num_samples=4, - image_key="image", - image_threshold=0, - ), + # RandCropByPosNegLabeld( + # keys=["image", "label"], + # label_key="label", + # spatial_size=self.cfg["spatial_size"], + # pos=1, + # neg=1, + # num_samples=4, + # image_key="image", + # image_threshold=0, + # ), # This normalizes the intensity of the image NormalizeIntensityd( keys=["image"], @@ -591,7 +591,7 @@ def main(): # define optimizer optimizer_class = torch.optim.Adam - wandb.init(project=f'monai-ms-lesion-seg', config=config) + wandb.init(project=f'monai-ms-lesion-seg-unet', config=config) logger.info("Building the model ...") @@ -614,24 +614,24 @@ def main(): # adn_ordering='NDA', # ) - # net=UNet( - # spatial_dims=3, - # in_channels=1, - # out_channels=1, - # channels=(32, 64, 128, 256), - # strides=(2, 2, 2 ), + net=UNet( + spatial_dims=3, + in_channels=1, + out_channels=1, + channels=(32, 64, 128, 256), + strides=(2, 2, 2 ), - # # dropout=0.1 - # ) - - net = AttentionUnet( - spatial_dims=3, - in_channels=1, - out_channels=1, - channels=(32, 64, 128), - strides=(2, 2, 2,), - dropout=0.1, - ) + # dropout=0.1 + ) + + # net = AttentionUnet( + # spatial_dims=3, + # in_channels=1, + # out_channels=1, + # channels=(32, 64, 128), + # strides=(2, 2, 2,), + # dropout=0.1, + # ) # net = BasicUNet(spatial_dims=3, features=(32, 64, 128, 256, 32), out_channels=1) @@ -680,8 +680,8 @@ def main(): config=config) # Saving training script to wandb - wandb.save("ms-lesion-agnostic/monai/nnunet/config_fake.yml") - wandb.save("ms-lesion-agnostic/monai/nnunet/train_monai_unet_lightning.py") + wandb.save("ms-lesion-agnostic/monai/config.yml") + wandb.save("ms-lesion-agnostic/monai/train_monai_unet_lightning.py") # initialise Lightning's trainer. From 4d1996fe21085be26493535277f159226754ac26 Mon Sep 17 00:00:00 2001 From: Pierre-Louis Benveniste Date: Thu, 4 Apr 2024 23:21:40 -0400 Subject: [PATCH 045/108] modified training script with new data augmentation strategies --- monai/config.yml | 3 +- monai/train_monai_unet_lightning.py | 43 +++++++++++++++++++---------- 2 files changed, 31 insertions(+), 15 deletions(-) diff --git a/monai/config.yml b/monai/config.yml index bf666fe..55763f3 100644 --- a/monai/config.yml +++ b/monai/config.yml @@ -10,8 +10,9 @@ output_dir : /home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/results/ # Resampling resolution # pixdim : [1.0, 1.0, 1.0] pixdim : [0.7, 0.7, 0.7] +# pixdim : [0.5, 0.5, 0.5] # Spatial size of the input data -spatial_size : [64, 128, 128] # RL, AP, IS +spatial_size : [32, 128, 128] # RL, AP, IS batch_size : 4 # UNETR model parameters diff --git a/monai/train_monai_unet_lightning.py b/monai/train_monai_unet_lightning.py index bfb2800..91457e9 100644 --- a/monai/train_monai_unet_lightning.py +++ b/monai/train_monai_unet_lightning.py @@ -48,7 +48,8 @@ LabelToContourd, Invertd, SaveImage, - EnsureType + EnsureType, + Rand3DElasticd ) from monai.utils import set_determinism from monai.inferers import sliding_window_inference @@ -147,13 +148,9 @@ def prepare_data(self): # CropForegroundd( # keys=["image", "label"], # source_key="label", - # margin=100 + # margin=150 # ), - # This resizes the image and the label to the spatial size defined in the config - ResizeWithPadOrCropd( - keys=["image", "label"], - spatial_size=self.cfg["spatial_size"], - ), + # This crops the image around a foreground object of label with ratio pos/(pos+neg) (however, it cannot pad so keeping padding after) # RandCropByPosNegLabeld( # keys=["image", "label"], # label_key="label", @@ -163,7 +160,13 @@ def prepare_data(self): # num_samples=4, # image_key="image", # image_threshold=0, + # allow_smaller=True, # ), + # This resizes the image and the label to the spatial size defined in the config + ResizeWithPadOrCropd( + keys=["image", "label"], + spatial_size=self.cfg["spatial_size"], + ), # Flips the image : left becomes right RandFlipd( keys=["image", "label"], @@ -182,6 +185,14 @@ def prepare_data(self): spatial_axis=[2], prob=0.2, ), + # # Random elastic deformation + # Rand3DElasticd( + # keys=["image", "label"], + # sigma_range=(5, 7), + # magnitude_range=(50, 150), + # prob=0.2, + # mode=['bilinear', 'nearest'], + # ), # RandAdjustContrastd( # keys=["image"], # prob=0.2, @@ -232,11 +243,10 @@ def prepare_data(self): pixdim=self.cfg["pixdim"], mode=(2, 1), ), - # CropForegroundd(keys=["image", "label"], source_key="label", margin=100), - ResizeWithPadOrCropd( - keys=["image", "label"], - spatial_size=self.cfg["spatial_size"], - ), + # CropForegroundd( + # keys=["image", "label"], + # source_key="label", + # margin=150), # RandCropByPosNegLabeld( # keys=["image", "label"], # label_key="label", @@ -246,7 +256,13 @@ def prepare_data(self): # num_samples=4, # image_key="image", # image_threshold=0, + # allow_smaller=True, # ), + ResizeWithPadOrCropd( + keys=["image", "label"], + spatial_size=self.cfg["spatial_size"], + ), + # This normalizes the intensity of the image NormalizeIntensityd( keys=["image"], @@ -619,8 +635,7 @@ def main(): in_channels=1, out_channels=1, channels=(32, 64, 128, 256), - strides=(2, 2, 2 ), - + strides=(2, 2, 2, ), # dropout=0.1 ) From a28d17637bd26d4e00f37caa55b192b5fd70f65e Mon Sep 17 00:00:00 2001 From: Pierre-Louis Benveniste Date: Fri, 5 Apr 2024 17:43:50 -0400 Subject: [PATCH 046/108] fixed typos in script and arranged in functions --- monai/1_create_msd_data.py | 281 +++++++++++++++++++++---------------- 1 file changed, 157 insertions(+), 124 deletions(-) diff --git a/monai/1_create_msd_data.py b/monai/1_create_msd_data.py index 451f67b..df3f595 100644 --- a/monai/1_create_msd_data.py +++ b/monai/1_create_msd_data.py @@ -11,7 +11,7 @@ --seed: Seed for reproducibility Example: - python create_msd_data.py ... + python create_msd_data.py -pd /path/dataset -po /path/output TO DO: * @@ -29,126 +29,159 @@ from datetime import date from pathlib import Path -# root = "/home/GRAMES.POLYMTL.CA/u114716/datasets/spine-generic_uncropped" - -parser = argparse.ArgumentParser(description='Code for MSD-style JSON datalist for lesion-agnostic nnunet model training.') - -parser.add_argument('-pd', '--path-data', required=True, type=str, help='Path to the folder containing the datasets') -parser.add_argument('-po', '--path-out', type=str, help='Path to the output directory where dataset json is saved') -parser.add_argument('--seed', default=42, type=int, help="Seed for reproducibility") -args = parser.parse_args() - - -root = args.path_data -seed = args.seed - -# Get all subjects -canproco_path = Path(os.path.join(root, "canproco")) -basel_path = Path(os.path.join(root, "basel-mp2rage")) -bavaria_path = Path(os.path.join(root, "bavaria-quebec-spine-ms")) -sct_testing_path = Path(os.path.join(root, "sct-testing-large")) - -subjects_canproco = list(canproco_path.rglob('*_lesion-manual.nii.gz')) -subjects_basel = list(basel_path.rglob('*UNIT1.nii.gz')) -subjects_sct = list(sct_testing_path.rglob('*_lesion-manual.nii.gz')) -subjects_bavaria = list(bavaria_path.rglob('*T2w.nii.gz')) - -subjects = subjects_canproco + subjects_basel + subjects_sct + subjects_bavaria -logger.info(f"Total number of subjects in the root directory: {len(subjects)}") - -# create one json file with 60-20-20 train-val-test split -train_ratio, val_ratio, test_ratio = 0.6, 0.2, 0.2 -train_subjects, test_subjects = train_test_split(subjects, test_size=test_ratio, random_state=args.seed) -# Use the training split to further split into training and validation splits -train_subjects, val_subjects = train_test_split(train_subjects, test_size=val_ratio / (train_ratio + val_ratio), - random_state=args.seed, ) -# sort the subjects -train_subjects = sorted(train_subjects) -val_subjects = sorted(val_subjects) -test_subjects = sorted(test_subjects) - -logger.info(f"Number of training subjects: {len(train_subjects)}") -logger.info(f"Number of validation subjects: {len(val_subjects)}") -logger.info(f"Number of testing subjects: {len(test_subjects)}") - -# dump train/val/test splits into a yaml file -with open(f"{args.path_out}/data_split_{str(date.today())}_seed{seed}.yaml", 'w') as file: - yaml.dump({'train': train_subjects, 'val': val_subjects, 'test': test_subjects}, file, indent=2, sort_keys=True) - -# keys to be defined in the dataset_0.json -params = {} -params["description"] = "ms-lesion-agnostic" -params["labels"] = { - "0": "background", - "1": "ms-lesion-seg" - } -params["license"] = "plb" -params["modality"] = { - "0": "MRI" - } -params["name"] = "ms-lesion-agnostic" -params["numTest"] = len(test_subjects) -params["numTraining"] = len(train_subjects) -params["numValidation"] = len(val_subjects) -params["seed"] = args.seed -params["reference"] = "NeuroPoly" -params["tensorImageSize"] = "3D" - -train_subjects_dict = {"train": train_subjects} -val_subjects_dict = {"validation": val_subjects} -test_subjects_dict = {"test": test_subjects} -all_subjects_list = [train_subjects_dict, val_subjects_dict, test_subjects_dict] - -for subjects_dict in tqdm(all_subjects_list, desc="Iterating through train/val/test splits"): - - for name, subs_list in subjects_dict.items(): - - temp_list = [] - for subject_no, subject in enumerate(subs_list): - - temp_data_canproco = {} - temp_data_basel = {} - temp_data_sct = {} - temp_data_bavaria = {} - - # Canproco - if 'canproco' in str(subject): - temp_data_canproco["label"] = str(subject) - temp_data_canproco["image"] = str(subject).replace('_lesion-manual.nii.gz', '.nii.gz').replace('derivatives/labels/', '') - if os.path.exists(temp_data_canproco["label"]) and os.path.exists(temp_data_canproco["image"]): - temp_list.append(temp_data_canproco) - - # Basel - elif 'basel-mp2rage' in str(subject): - relative_path = subject.relative_to(basel_path).parent - temp_data_basel["image"] = str(subject) - temp_data_basel["label"] = str(basel_path / 'derivatives' / 'labels' / relative_path / str(subject).replace('UNIT1.nii.gz', 'UNIT1_desc-rater3_label-lesion_seg.nii.gz')) - if os.path.exists(temp_data_basel["label"]) and os.path.exists(temp_data_basel["image"]): - temp_list.append(temp_data_basel) - - # sct-testing-large - elif 'sct-testing-large' in str(subject): - temp_data_sct["label"] = str(subject) - temp_data_sct["image"] = str(subject).replace('_lesion-manual.nii.gz', '.nii.gz').replace('derivatives/labels/', '') - if os.path.exists(temp_data_sct["label"]) and os.path.exists(temp_data_sct["image"]): - temp_list.append(temp_data_sct) - - - # Bavaria-quebec - elif 'bavaria-quebec-spine-ms' in str(subject): - relative_path = subject.relative_to(bavaria_path).parent - temp_data_bavaria["image"] = str(subject) - temp_data_bavaria["label"] = str(bavaria_path / 'derivatives' / 'labels' / relative_path / subject.name.replace('T2w.nii.gz', 'T2w_lesion-manual.nii.gz')) - if os.path.exists(temp_data_bavaria["label"]) and os.path.exists(temp_data_bavaria["image"]): - temp_list.append(temp_data_bavaria) - - params[name] = temp_list - logger.info(f"Number of images in {name} set: {len(temp_list)}") - -final_json = json.dumps(params, indent=4, sort_keys=True) -if not os.path.exists(args.path_out): - os.makedirs(args.path_out, exist_ok=True) - -jsonFile = open(args.path_out + "/" + f"dataset_{str(date.today())}_seed{seed}.json", "w") -jsonFile.write(final_json) -jsonFile.close() + +def get_parser(): + """ + Get parser for script create_msd_data.py + + Input: + None + + Returns: + parser : argparse object + """ + + parser = argparse.ArgumentParser(description='Code for MSD-style JSON datalist for lesion-agnostic nnunet model training.') + + parser.add_argument('-pd', '--path-data', required=True, type=str, help='Path to the folder containing the datasets') + parser.add_argument('-po', '--path-out', type=str, help='Path to the output directory where dataset json is saved') + parser.add_argument('--seed', default=42, type=int, help="Seed for reproducibility") + + return parser + +def main(): + """ + This is the main function of the script. + + Input: + None + + Returns: + None + """ + # Get the arguments + parser = get_parser() + args = parser.parse_args() + + root = args.path_data + seed = args.seed + + # Get all subjects + canproco_path = Path(os.path.join(root, "canproco")) + basel_path = Path(os.path.join(root, "basel-mp2rage")) + bavaria_path = Path(os.path.join(root, "bavaria-quebec-spine-ms")) + sct_testing_path = Path(os.path.join(root, "sct-testing-large")) + + subjects_canproco = list(canproco_path.rglob('*_lesion-manual.nii.gz')) + subjects_basel = list(basel_path.rglob('*UNIT1.nii.gz')) + subjects_sct = list(sct_testing_path.rglob('*_lesion-manual.nii.gz')) + subjects_bavaria = list(bavaria_path.rglob('*T2w.nii.gz')) + + subjects = subjects_canproco + subjects_basel + subjects_sct + subjects_bavaria + # logger.info(f"Total number of subjects in the root directory: {len(subjects)}") + + # create one json file with 60-20-20 train-val-test split + train_ratio, val_ratio, test_ratio = 0.6, 0.2, 0.2 + train_subjects, test_subjects = train_test_split(subjects, test_size=test_ratio, random_state=args.seed) + # Use the training split to further split into training and validation splits + train_subjects, val_subjects = train_test_split(train_subjects, test_size=val_ratio / (train_ratio + val_ratio), + random_state=args.seed, ) + # sort the subjects + train_subjects = sorted(train_subjects) + val_subjects = sorted(val_subjects) + test_subjects = sorted(test_subjects) + + # logger.info(f"Number of training subjects: {len(train_subjects)}") + # logger.info(f"Number of validation subjects: {len(val_subjects)}") + # logger.info(f"Number of testing subjects: {len(test_subjects)}") + + # dump train/val/test splits into a yaml file + with open(f"{args.path_out}/data_split_{str(date.today())}_seed{seed}.yaml", 'w') as file: + yaml.dump({'train': train_subjects, 'val': val_subjects, 'test': test_subjects}, file, indent=2, sort_keys=True) + + # keys to be defined in the dataset_0.json + params = {} + params["description"] = "ms-lesion-agnostic" + params["labels"] = { + "0": "background", + "1": "ms-lesion-seg" + } + params["license"] = "plb" + params["modality"] = { + "0": "MRI" + } + params["name"] = "ms-lesion-agnostic" + params["seed"] = args.seed + params["reference"] = "NeuroPoly" + params["tensorImageSize"] = "3D" + + train_subjects_dict = {"train": train_subjects} + val_subjects_dict = {"validation": val_subjects} + test_subjects_dict = {"test": test_subjects} + all_subjects_list = [train_subjects_dict, val_subjects_dict, test_subjects_dict] + + # iterate through the train/val/test splits and add those which have both image and label + for subjects_dict in tqdm(all_subjects_list, desc="Iterating through train/val/test splits"): + + for name, subs_list in subjects_dict.items(): + + temp_list = [] + for subject_no, subject in enumerate(subs_list): + + temp_data_canproco = {} + temp_data_basel = {} + temp_data_sct = {} + temp_data_bavaria = {} + + # Canproco + if 'canproco' in str(subject): + temp_data_canproco["label"] = str(subject) + temp_data_canproco["image"] = str(subject).replace('_lesion-manual.nii.gz', '.nii.gz').replace('derivatives/labels/', '') + if os.path.exists(temp_data_canproco["label"]) and os.path.exists(temp_data_canproco["image"]): + temp_list.append(temp_data_canproco) + + # Basel + elif 'basel-mp2rage' in str(subject): + relative_path = subject.relative_to(basel_path).parent + temp_data_basel["image"] = str(subject) + temp_data_basel["label"] = str(basel_path / 'derivatives' / 'labels' / relative_path / str(subject).replace('UNIT1.nii.gz', 'UNIT1_desc-rater3_label-lesion_seg.nii.gz')) + if os.path.exists(temp_data_basel["label"]) and os.path.exists(temp_data_basel["image"]): + temp_list.append(temp_data_basel) + + # sct-testing-large + elif 'sct-testing-large' in str(subject): + temp_data_sct["label"] = str(subject) + temp_data_sct["image"] = str(subject).replace('_lesion-manual.nii.gz', '.nii.gz').replace('derivatives/labels/', '') + if os.path.exists(temp_data_sct["label"]) and os.path.exists(temp_data_sct["image"]): + temp_list.append(temp_data_sct) + + + # Bavaria-quebec + elif 'bavaria-quebec-spine-ms' in str(subject): + relative_path = subject.relative_to(bavaria_path).parent + temp_data_bavaria["image"] = str(subject) + temp_data_bavaria["label"] = str(bavaria_path / 'derivatives' / 'labels' / relative_path / subject.name.replace('T2w.nii.gz', 'T2w_lesion-manual.nii.gz')) + if os.path.exists(temp_data_bavaria["label"]) and os.path.exists(temp_data_bavaria["image"]): + temp_list.append(temp_data_bavaria) + + params[name] = temp_list + logger.info(f"Number of images in {name} set: {len(temp_list)}") + params["numTest"] = len(params["test"]) + params["numTraining"] = len(params["train"]) + params["numValidation"] = len(params["validation"]) + # Print total number of images + logger.info(f"Total number of images in the dataset: {params['numTest'] + params['numTraining'] + params['numValidation']}") + + final_json = json.dumps(params, indent=4, sort_keys=True) + if not os.path.exists(args.path_out): + os.makedirs(args.path_out, exist_ok=True) + + jsonFile = open(args.path_out + "/" + f"dataset_{str(date.today())}_seed{seed}.json", "w") + jsonFile.write(final_json) + jsonFile.close() + + return None + + +if __name__ == "__main__": + main() \ No newline at end of file From a9c293f983d784ff8b3a638ee667fb694552548d Mon Sep 17 00:00:00 2001 From: Pierre-Louis Benveniste Date: Fri, 5 Apr 2024 17:44:14 -0400 Subject: [PATCH 047/108] fixed parameters for model training on entire dataset --- monai/train_monai_unet_lightning.py | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/monai/train_monai_unet_lightning.py b/monai/train_monai_unet_lightning.py index 91457e9..47c1dd8 100644 --- a/monai/train_monai_unet_lightning.py +++ b/monai/train_monai_unet_lightning.py @@ -185,14 +185,14 @@ def prepare_data(self): spatial_axis=[2], prob=0.2, ), - # # Random elastic deformation - # Rand3DElasticd( - # keys=["image", "label"], - # sigma_range=(5, 7), - # magnitude_range=(50, 150), - # prob=0.2, - # mode=['bilinear', 'nearest'], - # ), + # Random elastic deformation + Rand3DElasticd( + keys=["image", "label"], + sigma_range=(5, 7), + magnitude_range=(50, 150), + prob=0.2, + mode=['bilinear', 'nearest'], + ), # RandAdjustContrastd( # keys=["image"], # prob=0.2, @@ -215,10 +215,10 @@ def prepare_data(self): # keys=["image"], # kernel_type='Laplace', # ), - # RandGaussianNoised( - # keys=["image"], - # prob=0.2, - # ), + RandGaussianNoised( + keys=["image"], + prob=0.2, + ), # RandShiftIntensityd( # keys=["image"], # offsets=0.1, @@ -634,8 +634,8 @@ def main(): spatial_dims=3, in_channels=1, out_channels=1, - channels=(32, 64, 128, 256), - strides=(2, 2, 2, ), + channels=(32, 64, 128, 256, 512), + strides=(2, 2, 2, 2, ), # dropout=0.1 ) From 29b8b694d92cbe7c370a21b0746da095bafbb408 Mon Sep 17 00:00:00 2001 From: Pierre-Louis Benveniste Date: Fri, 5 Apr 2024 18:29:48 -0400 Subject: [PATCH 048/108] added function to cound lesions and get total volume --- monai/1_create_msd_data.py | 57 ++++++++++++++++++++++++++++++++++++-- 1 file changed, 55 insertions(+), 2 deletions(-) diff --git a/monai/1_create_msd_data.py b/monai/1_create_msd_data.py index df3f595..dd02f84 100644 --- a/monai/1_create_msd_data.py +++ b/monai/1_create_msd_data.py @@ -28,6 +28,9 @@ from sklearn.model_selection import train_test_split from datetime import date from pathlib import Path +import nibabel as nib +import numpy as np +import skimage def get_parser(): @@ -45,10 +48,38 @@ def get_parser(): parser.add_argument('-pd', '--path-data', required=True, type=str, help='Path to the folder containing the datasets') parser.add_argument('-po', '--path-out', type=str, help='Path to the output directory where dataset json is saved') + parser.add_argument('--lesion-only', action='store_true', help='Use only masks which contain some lesions') parser.add_argument('--seed', default=42, type=int, help="Seed for reproducibility") return parser + +def count_lesion(label_file): + """ + This function takes a label file and counts the number of lesions in it. + + Input: + label_file : str : Path to the label file + + Returns: + count : int : Number of lesions in the label file + total_volume : float : Total volume of lesions in the label file + """ + + label = nib.load(label_file) + label_data = label.get_fdata() + + # get the total volume of the lesions + total_volume = np.sum(label_data) + resolution = label.header.get_zooms() + total_volume = total_volume * np.prod(resolution) + + # get the number of lesions + _, nb_lesions = skimage.measure.label(label_data, connectivity=2, return_num=True) + + return total_volume, nb_lesions + + def main(): """ This is the main function of the script. @@ -138,6 +169,11 @@ def main(): temp_data_canproco["label"] = str(subject) temp_data_canproco["image"] = str(subject).replace('_lesion-manual.nii.gz', '.nii.gz').replace('derivatives/labels/', '') if os.path.exists(temp_data_canproco["label"]) and os.path.exists(temp_data_canproco["image"]): + total_lesion_volume, nb_lesions = count_lesion(temp_data_canproco["label"]) + temp_data_canproco["total_lesion_volume"] = total_lesion_volume + temp_data_canproco["nb_lesions"] = nb_lesions + if args.lesion_only and nb_lesions == 0: + continue temp_list.append(temp_data_canproco) # Basel @@ -146,6 +182,11 @@ def main(): temp_data_basel["image"] = str(subject) temp_data_basel["label"] = str(basel_path / 'derivatives' / 'labels' / relative_path / str(subject).replace('UNIT1.nii.gz', 'UNIT1_desc-rater3_label-lesion_seg.nii.gz')) if os.path.exists(temp_data_basel["label"]) and os.path.exists(temp_data_basel["image"]): + total_lesion_volume, nb_lesions = count_lesion(temp_data_basel["label"]) + temp_data_basel["total_lesion_volume"] = total_lesion_volume + temp_data_basel["nb_lesions"] = nb_lesions + if args.lesion_only and nb_lesions == 0: + continue temp_list.append(temp_data_basel) # sct-testing-large @@ -153,6 +194,11 @@ def main(): temp_data_sct["label"] = str(subject) temp_data_sct["image"] = str(subject).replace('_lesion-manual.nii.gz', '.nii.gz').replace('derivatives/labels/', '') if os.path.exists(temp_data_sct["label"]) and os.path.exists(temp_data_sct["image"]): + total_lesion_volume, nb_lesions = count_lesion(temp_data_sct["label"]) + temp_data_sct["total_lesion_volume"] = total_lesion_volume + temp_data_sct["nb_lesions"] = nb_lesions + if args.lesion_only and nb_lesions == 0: + continue temp_list.append(temp_data_sct) @@ -162,6 +208,11 @@ def main(): temp_data_bavaria["image"] = str(subject) temp_data_bavaria["label"] = str(bavaria_path / 'derivatives' / 'labels' / relative_path / subject.name.replace('T2w.nii.gz', 'T2w_lesion-manual.nii.gz')) if os.path.exists(temp_data_bavaria["label"]) and os.path.exists(temp_data_bavaria["image"]): + total_lesion_volume, nb_lesions = count_lesion(temp_data_bavaria["label"]) + temp_data_bavaria["total_lesion_volume"] = total_lesion_volume + temp_data_bavaria["nb_lesions"] = nb_lesions + if args.lesion_only and nb_lesions == 0: + continue temp_list.append(temp_data_bavaria) params[name] = temp_list @@ -175,8 +226,10 @@ def main(): final_json = json.dumps(params, indent=4, sort_keys=True) if not os.path.exists(args.path_out): os.makedirs(args.path_out, exist_ok=True) - - jsonFile = open(args.path_out + "/" + f"dataset_{str(date.today())}_seed{seed}.json", "w") + if args.lesion_only: + jsonFile = open(args.path_out + "/" + f"dataset_{str(date.today())}_seed{seed}_lesionOnly.json", "w") + else: + jsonFile = open(args.path_out + "/" + f"dataset_{str(date.today())}_seed{seed}.json", "w") jsonFile.write(final_json) jsonFile.close() From fc1831845469c906cd0dc24ee539739100892d44 Mon Sep 17 00:00:00 2001 From: Pierre-Louis Benveniste Date: Fri, 5 Apr 2024 18:30:13 -0400 Subject: [PATCH 049/108] changed batch-size to 8 --- monai/config.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/config.yml b/monai/config.yml index 55763f3..286b91c 100644 --- a/monai/config.yml +++ b/monai/config.yml @@ -13,7 +13,7 @@ pixdim : [0.7, 0.7, 0.7] # pixdim : [0.5, 0.5, 0.5] # Spatial size of the input data spatial_size : [32, 128, 128] # RL, AP, IS -batch_size : 4 +batch_size : 8 # UNETR model parameters feature_size : 8 From b1563f82ffe2a44238221c70393d36160ff942cf Mon Sep 17 00:00:00 2001 From: Pierre-Louis Benveniste Date: Fri, 5 Apr 2024 18:31:06 -0400 Subject: [PATCH 050/108] changed to attentionUnet --- monai/train_monai_unet_lightning.py | 34 ++++++++++++++--------------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/monai/train_monai_unet_lightning.py b/monai/train_monai_unet_lightning.py index 47c1dd8..52e8ac9 100644 --- a/monai/train_monai_unet_lightning.py +++ b/monai/train_monai_unet_lightning.py @@ -630,23 +630,23 @@ def main(): # adn_ordering='NDA', # ) - net=UNet( - spatial_dims=3, - in_channels=1, - out_channels=1, - channels=(32, 64, 128, 256, 512), - strides=(2, 2, 2, 2, ), - # dropout=0.1 - ) - - # net = AttentionUnet( - # spatial_dims=3, - # in_channels=1, - # out_channels=1, - # channels=(32, 64, 128), - # strides=(2, 2, 2,), - # dropout=0.1, - # ) + # net=UNet( + # spatial_dims=3, + # in_channels=1, + # out_channels=1, + # channels=(32, 64, 128, 256, 512), + # strides=(2, 2, 2, 2, ), + # # dropout=0.1 + # ) + + net = AttentionUnet( + spatial_dims=3, + in_channels=1, + out_channels=1, + channels=(32, 64, 128, 256, 512), + strides=(2, 2, 2, 2,), + dropout=0.1, + ) # net = BasicUNet(spatial_dims=3, features=(32, 64, 128, 256, 32), out_channels=1) From 8e8ee701545e6bf893a4f619b6b8048ff474d54d Mon Sep 17 00:00:00 2001 From: Pierre-Louis Benveniste Date: Fri, 5 Apr 2024 18:33:43 -0400 Subject: [PATCH 051/108] added lesion only dataset on entirety of images --- monai/config.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/monai/config.yml b/monai/config.yml index 286b91c..9ba503a 100644 --- a/monai/config.yml +++ b/monai/config.yml @@ -2,9 +2,10 @@ # Path to the data json file # data: /home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/data/monai_data/fake.json # data: /home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/data/monai_data/fake_lesion_sc.json -data: /home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/data/monai_data/fake_10_each.json +# data: /home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/data/monai_data/fake_10_each.json # data: /home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/data/monai_data/fake_sc.json # data: /home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/data/monai_data/dataset_2024-03-13_seed42_canproco.json +data: /home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/data/monai_data/dataset_2024-04-05_seed42_lesionOnly.json # Path to the output directory output_dir : /home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/results/ # Resampling resolution From d9997a65a38b6c36a09d6aa8b75a1d4b64ea3a1c Mon Sep 17 00:00:00 2001 From: Pierre-Louis Benveniste Date: Wed, 10 Apr 2024 10:05:18 -0400 Subject: [PATCH 052/108] added crop foreground for model training --- monai/train_monai_unet_lightning.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/monai/train_monai_unet_lightning.py b/monai/train_monai_unet_lightning.py index 52e8ac9..d434d78 100644 --- a/monai/train_monai_unet_lightning.py +++ b/monai/train_monai_unet_lightning.py @@ -145,11 +145,11 @@ def prepare_data(self): ), # # This crops the image around areas where the mask is non-zero # # (the margin is added because otherwise the image would be just the size of the lesion) - # CropForegroundd( - # keys=["image", "label"], - # source_key="label", - # margin=150 - # ), + CropForegroundd( + keys=["image", "label"], + source_key="label", + margin=200 + ), # This crops the image around a foreground object of label with ratio pos/(pos+neg) (however, it cannot pad so keeping padding after) # RandCropByPosNegLabeld( # keys=["image", "label"], From 6f35a4a46107a8251a235a7432fd631a9135dfdd Mon Sep 17 00:00:00 2001 From: Pierre-Louis Benveniste Date: Wed, 10 Apr 2024 13:37:30 -0400 Subject: [PATCH 053/108] added more data augmentation --- monai/train_monai_unet_lightning.py | 66 +++++++++++++++++++---------- 1 file changed, 43 insertions(+), 23 deletions(-) diff --git a/monai/train_monai_unet_lightning.py b/monai/train_monai_unet_lightning.py index d434d78..b2fff01 100644 --- a/monai/train_monai_unet_lightning.py +++ b/monai/train_monai_unet_lightning.py @@ -49,7 +49,10 @@ Invertd, SaveImage, EnsureType, - Rand3DElasticd + Rand3DElasticd, + RandSimulateLowResolutiond, + RandBiasFieldd, + RandAffined ) from monai.utils import set_determinism from monai.inferers import sliding_window_inference @@ -143,25 +146,31 @@ def prepare_data(self): pixdim=self.cfg["pixdim"], mode=(2, 1), ), + # Normalize the intensity of the image + NormalizeIntensityd( + keys=["image"], + nonzero=False, + channel_wise=False + ), # # This crops the image around areas where the mask is non-zero # # (the margin is added because otherwise the image would be just the size of the lesion) - CropForegroundd( - keys=["image", "label"], - source_key="label", - margin=200 - ), - # This crops the image around a foreground object of label with ratio pos/(pos+neg) (however, it cannot pad so keeping padding after) - # RandCropByPosNegLabeld( + # CropForegroundd( # keys=["image", "label"], - # label_key="label", - # spatial_size=self.cfg["spatial_size"], - # pos=1, - # neg=1, - # num_samples=4, - # image_key="image", - # image_threshold=0, - # allow_smaller=True, + # source_key="label", + # margin=200 # ), + # This crops the image around a foreground object of label with ratio pos/(pos+neg) (however, it cannot pad so keeping padding after) + RandCropByPosNegLabeld( + keys=["image", "label"], + label_key="label", + spatial_size=self.cfg["spatial_size"], + pos=1, + neg=0, + num_samples=4, + image_key="image", + image_threshold=0, + allow_smaller=True, + ), # This resizes the image and the label to the spatial size defined in the config ResizeWithPadOrCropd( keys=["image", "label"], @@ -193,6 +202,13 @@ def prepare_data(self): prob=0.2, mode=['bilinear', 'nearest'], ), + # Random affine transform of the image + RandAffined( + keys=["image", "label"], + prob=0.2, + mode=('bilinear', 'nearest'), + padding_mode='zeros', + ), # RandAdjustContrastd( # keys=["image"], # prob=0.2, @@ -205,12 +221,6 @@ def prepare_data(self): # func=multiply_by_negative_one, # prob=0.2 # ), - # Normalize the intensity of the image - NormalizeIntensityd( - keys=["image"], - nonzero=False, - channel_wise=False - ), # LabelToContourd( # keys=["image"], # kernel_type='Laplace', @@ -218,7 +228,17 @@ def prepare_data(self): RandGaussianNoised( keys=["image"], prob=0.2, - ), + ), + # Random simulation of low resolution + RandSimulateLowResolutiond( + keys=["image"], + zoom_range=(0.8, 1.5), + prob=0.2), + # Adding a random bias field which is usefull considering that this sometimes done for image pre-processing + RandBiasFieldd(keys=["image"], + coeff_range=(0.0, 0.5), + degree=3, + prob=0.1), # RandShiftIntensityd( # keys=["image"], # offsets=0.1, From a3df6684f690ec777d7aac3e476905f55f44f86f Mon Sep 17 00:00:00 2001 From: Pierre-Louis Benveniste Date: Wed, 10 Apr 2024 15:53:39 -0400 Subject: [PATCH 054/108] added precision and recall metric --- monai/utils.py | 252 +++++++++++++++++++++++++++++-------------------- 1 file changed, 148 insertions(+), 104 deletions(-) diff --git a/monai/utils.py b/monai/utils.py index dd65b1d..04b3224 100644 --- a/monai/utils.py +++ b/monai/utils.py @@ -10,6 +10,9 @@ from dynamic_network_architectures.building_blocks.helper import get_matching_instancenorm, convert_dim_to_conv_op from dynamic_network_architectures.initialization.weight_init import init_last_bn_before_add_to_0 +import skimage + + def dice_score(prediction, groundtruth, smooth=1.): numer = (prediction * groundtruth).sum() denor = (prediction + groundtruth).sum() @@ -63,111 +66,152 @@ def plot_slices(image, gt, pred, debug=False): fig.show() return fig -nnunet_plans = { - "UNet_class_name": "PlainConvUNet", - "UNet_base_num_features": 32, - "n_conv_per_stage_encoder": [2, 2, 2, 2, 2, 2, 2], - "n_conv_per_stage_decoder": [2, 2, 2, 2, 2, 2], - "pool_op_kernel_sizes": [ - [1, 1, 1], - [1, 2, 2], - [1, 2, 2], - [2, 2, 2], - [2, 2, 2], - [1, 2, 2], - [1, 2, 2] - ], - "conv_kernel_sizes": [ - [1, 3, 3], - [1, 3, 3], - [3, 3, 3], - [3, 3, 3], - [3, 3, 3], - [3, 3, 3], - [3, 3, 3] - ], - "unet_max_num_features": 320, -} - - -# ====================================================================================================== -# Utils for nnUNet's Model -# ==================================================================================================== -class InitWeights_He(object): - def __init__(self, neg_slope=1e-2): - self.neg_slope = neg_slope - - def __call__(self, module): - if isinstance(module, nn.Conv3d) or isinstance(module, nn.ConvTranspose3d): - module.weight = nn.init.kaiming_normal_(module.weight, a=self.neg_slope) - if module.bias is not None: - module.bias = nn.init.constant_(module.bias, 0) - - -# ====================================================================================================== -# Define the network based on plans json -# ==================================================================================================== -def create_nnunet_from_plans(plans=nnunet_plans, num_input_channels=1, num_classes=1, deep_supervision: bool = False): - """ - Adapted from nnUNet's source code: - https://github.com/MIC-DKFZ/nnUNet/blob/master/nnunetv2/utilities/get_network_from_plans.py#L9 +def lesion_wise_precision_recall(prediction, groundtruth, iou_threshold=0.2): """ - num_stages = len(plans["conv_kernel_sizes"]) - - dim = len(plans["conv_kernel_sizes"][0]) - conv_op = convert_dim_to_conv_op(dim) - - segmentation_network_class_name = plans["UNet_class_name"] - mapping = { - 'PlainConvUNet': PlainConvUNet, - 'ResidualEncoderUNet': ResidualEncoderUNet - } - kwargs = { - 'PlainConvUNet': { - 'conv_bias': True, - 'norm_op': get_matching_instancenorm(conv_op), - 'norm_op_kwargs': {'eps': 1e-5, 'affine': True}, - 'dropout_op': None, 'dropout_op_kwargs': None, - 'nonlin': nn.LeakyReLU, 'nonlin_kwargs': {'inplace': True}, - }, - 'ResidualEncoderUNet': { - 'conv_bias': True, - 'norm_op': get_matching_instancenorm(conv_op), - 'norm_op_kwargs': {'eps': 1e-5, 'affine': True}, - 'dropout_op': None, 'dropout_op_kwargs': None, - 'nonlin': nn.LeakyReLU, 'nonlin_kwargs': {'inplace': True}, - } - } - assert segmentation_network_class_name in mapping.keys(), 'The network architecture specified by the plans file ' \ - 'is non-standard (maybe your own?). Yo\'ll have to dive ' \ - 'into either this ' \ - 'function (get_network_from_plans) or ' \ - 'the init of your nnUNetModule to accomodate that.' - network_class = mapping[segmentation_network_class_name] - - conv_or_blocks_per_stage = { - 'n_conv_per_stage' - if network_class != ResidualEncoderUNet else 'n_blocks_per_stage': plans["n_conv_per_stage_encoder"], - 'n_conv_per_stage_decoder': plans["n_conv_per_stage_decoder"] - } + This function computes the lesion-wise precision and recall. + + Args: + prediction: predicted segmentation mask + groundtruth: ground truth segmentation mask + iou_threshold: threshold for intersection over union (IoU) for a lesion to be considered as true positive + Returns: + precision: lesion-wise precision + recall: lesion-wise recall + """ + # Compute connected components in the predicted and ground truth segmentation masks + pred_labels = skimage.measure.label(prediction, connectivity=2) + gt_labels = skimage.measure.label(groundtruth, connectivity=2) + + # Compute the intersection over union (IoU) between each pair of connected components + iou_matrix = np.zeros((np.max(pred_labels), np.max(gt_labels))) + for i in range(np.max(pred_labels)): + for j in range(np.max(gt_labels)): + # Compute the intersection + intersection = np.sum((pred_labels == i + 1) * (gt_labels == j + 1)) + # Compute the union + union = np.sum((pred_labels == i + 1)) + np.sum((gt_labels == j + 1)) - intersection + # Compute the IoU + iou_matrix[i, j] = intersection / union + + # Compute lesion-wise precision and recall + true_positives = np.sum(np.max(iou_matrix, axis=1) > iou_threshold) + false_positives = np.sum(np.max(iou_matrix, axis=0) <= iou_threshold) + false_negatives = np.sum(np.max(iou_matrix, axis=1) <= iou_threshold) + precision = true_positives / (true_positives + false_positives) + recall = true_positives / (true_positives + false_negatives) + + return precision, recall + + +# ############################################################################################################ +# # NNUNet's Model +# ############################################################################################################ +# nnunet_plans = { +# "UNet_class_name": "PlainConvUNet", +# "UNet_base_num_features": 32, +# "n_conv_per_stage_encoder": [2, 2, 2, 2, 2, 2, 2], +# "n_conv_per_stage_decoder": [2, 2, 2, 2, 2, 2], +# "pool_op_kernel_sizes": [ +# [1, 1, 1], +# [1, 2, 2], +# [1, 2, 2], +# [2, 2, 2], +# [2, 2, 2], +# [1, 2, 2], +# [1, 2, 2] +# ], +# "conv_kernel_sizes": [ +# [1, 3, 3], +# [1, 3, 3], +# [3, 3, 3], +# [3, 3, 3], +# [3, 3, 3], +# [3, 3, 3], +# [3, 3, 3] +# ], +# "unet_max_num_features": 320, +# } + + +# # ====================================================================================================== +# # Utils for nnUNet's Model +# # ==================================================================================================== +# class InitWeights_He(object): +# def __init__(self, neg_slope=1e-2): +# self.neg_slope = neg_slope + +# def __call__(self, module): +# if isinstance(module, nn.Conv3d) or isinstance(module, nn.ConvTranspose3d): +# module.weight = nn.init.kaiming_normal_(module.weight, a=self.neg_slope) +# if module.bias is not None: +# module.bias = nn.init.constant_(module.bias, 0) + + +# # ====================================================================================================== +# # Define the network based on plans json +# # ==================================================================================================== +# def create_nnunet_from_plans(plans=nnunet_plans, num_input_channels=1, num_classes=1, deep_supervision: bool = False): +# """ +# Adapted from nnUNet's source code: +# https://github.com/MIC-DKFZ/nnUNet/blob/master/nnunetv2/utilities/get_network_from_plans.py#L9 + +# """ +# num_stages = len(plans["conv_kernel_sizes"]) + +# dim = len(plans["conv_kernel_sizes"][0]) +# conv_op = convert_dim_to_conv_op(dim) + +# segmentation_network_class_name = plans["UNet_class_name"] +# mapping = { +# 'PlainConvUNet': PlainConvUNet, +# 'ResidualEncoderUNet': ResidualEncoderUNet +# } +# kwargs = { +# 'PlainConvUNet': { +# 'conv_bias': True, +# 'norm_op': get_matching_instancenorm(conv_op), +# 'norm_op_kwargs': {'eps': 1e-5, 'affine': True}, +# 'dropout_op': None, 'dropout_op_kwargs': None, +# 'nonlin': nn.LeakyReLU, 'nonlin_kwargs': {'inplace': True}, +# }, +# 'ResidualEncoderUNet': { +# 'conv_bias': True, +# 'norm_op': get_matching_instancenorm(conv_op), +# 'norm_op_kwargs': {'eps': 1e-5, 'affine': True}, +# 'dropout_op': None, 'dropout_op_kwargs': None, +# 'nonlin': nn.LeakyReLU, 'nonlin_kwargs': {'inplace': True}, +# } +# } +# assert segmentation_network_class_name in mapping.keys(), 'The network architecture specified by the plans file ' \ +# 'is non-standard (maybe your own?). Yo\'ll have to dive ' \ +# 'into either this ' \ +# 'function (get_network_from_plans) or ' \ +# 'the init of your nnUNetModule to accomodate that.' +# network_class = mapping[segmentation_network_class_name] + +# conv_or_blocks_per_stage = { +# 'n_conv_per_stage' +# if network_class != ResidualEncoderUNet else 'n_blocks_per_stage': plans["n_conv_per_stage_encoder"], +# 'n_conv_per_stage_decoder': plans["n_conv_per_stage_decoder"] +# } - # network class name!! - model = network_class( - input_channels=num_input_channels, - n_stages=num_stages, - features_per_stage=[min(plans["UNet_base_num_features"] * 2 ** i, - plans["unet_max_num_features"]) for i in range(num_stages)], - conv_op=conv_op, - kernel_sizes=plans["conv_kernel_sizes"], - strides=plans["pool_op_kernel_sizes"], - num_classes=num_classes, - deep_supervision=deep_supervision, - **conv_or_blocks_per_stage, - **kwargs[segmentation_network_class_name] - ) - model.apply(InitWeights_He(1e-2)) - if network_class == ResidualEncoderUNet: - model.apply(init_last_bn_before_add_to_0) +# # network class name!! +# model = network_class( +# input_channels=num_input_channels, +# n_stages=num_stages, +# features_per_stage=[min(plans["UNet_base_num_features"] * 2 ** i, +# plans["unet_max_num_features"]) for i in range(num_stages)], +# conv_op=conv_op, +# kernel_sizes=plans["conv_kernel_sizes"], +# strides=plans["pool_op_kernel_sizes"], +# num_classes=num_classes, +# deep_supervision=deep_supervision, +# **conv_or_blocks_per_stage, +# **kwargs[segmentation_network_class_name] +# ) +# model.apply(InitWeights_He(1e-2)) +# if network_class == ResidualEncoderUNet: +# model.apply(init_last_bn_before_add_to_0) - return model \ No newline at end of file +# return model \ No newline at end of file From 1c3c557847c2ff5629871652126bd832d10d21d7 Mon Sep 17 00:00:00 2001 From: Pierre-Louis Benveniste Date: Wed, 10 Apr 2024 23:19:01 -0400 Subject: [PATCH 055/108] added precision and recall metric : function must be reviewed (not sure it works) --- monai/train_monai_unet_lightning.py | 94 +++++++++++++++++++---------- monai/utils.py | 72 +++++++++++++++------- 2 files changed, 112 insertions(+), 54 deletions(-) diff --git a/monai/train_monai_unet_lightning.py b/monai/train_monai_unet_lightning.py index b2fff01..1081c36 100644 --- a/monai/train_monai_unet_lightning.py +++ b/monai/train_monai_unet_lightning.py @@ -20,7 +20,7 @@ from losses import AdapWingLoss, SoftDiceLoss -from utils import dice_score, check_empty_patch, multiply_by_negative_one, plot_slices, create_nnunet_from_plans +from utils import dice_score, check_empty_patch, multiply_by_negative_one, plot_slices, lesion_wise_precision_recall from monai.networks.nets import UNet, BasicUNet, AttentionUnet from monai.metrics import DiceMetric from monai.losses import DiceLoss, DiceCELoss @@ -105,6 +105,7 @@ def __init__(self, config, data_root, net, loss_function, optimizer_class, exp_i # define evaluation metric self.soft_dice_metric = dice_score + self.lesion_wise_precision_recall = lesion_wise_precision_recall # temp lists for storing outputs from training, validation, and testing self.train_step_outputs = [] @@ -233,12 +234,15 @@ def prepare_data(self): RandSimulateLowResolutiond( keys=["image"], zoom_range=(0.8, 1.5), - prob=0.2), + prob=0.2 + ), # Adding a random bias field which is usefull considering that this sometimes done for image pre-processing - RandBiasFieldd(keys=["image"], - coeff_range=(0.0, 0.5), - degree=3, - prob=0.1), + RandBiasFieldd( + keys=["image"], + coeff_range=(0.0, 0.5), + degree=3, + prob=0.2 + ), # RandShiftIntensityd( # keys=["image"], # offsets=0.1, @@ -263,6 +267,12 @@ def prepare_data(self): pixdim=self.cfg["pixdim"], mode=(2, 1), ), + # This normalizes the intensity of the image + NormalizeIntensityd( + keys=["image"], + nonzero=False, + channel_wise=False + ), # CropForegroundd( # keys=["image", "label"], # source_key="label", @@ -282,13 +292,6 @@ def prepare_data(self): keys=["image", "label"], spatial_size=self.cfg["spatial_size"], ), - - # This normalizes the intensity of the image - NormalizeIntensityd( - keys=["image"], - nonzero=False, - channel_wise=False - ), # LabelToContourd( # keys=["image"], # kernel_type='Laplace', @@ -398,13 +401,19 @@ def training_step(self, batch, batch_idx): # So, take this dice score with a lot of salt train_soft_dice = self.soft_dice_metric(output, labels) + # Compute precision and recall + train_precision, train_recall = self.lesion_wise_precision_recall(output.detach().cpu(), labels.detach().cpu()) + print("sucess") + metrics_dict = { "loss": loss.cpu(), "train_soft_dice": train_soft_dice.detach().cpu(), "train_number": len(inputs), "train_image": inputs[0].detach().cpu().squeeze(), "train_gt": labels[0].detach().cpu().squeeze(), - "train_pred": output[0].detach().cpu().squeeze() + "train_pred": output[0].detach().cpu().squeeze(), + "train_precision": train_precision.detach().cpu(), + "train_recall": train_recall.detach().cpu(), } self.train_step_outputs.append(metrics_dict) @@ -417,18 +426,26 @@ def on_train_epoch_end(self): return None else: train_loss, train_soft_dice = 0, 0 + precision_score, recall_score = 0, 0 num_items = len(self.train_step_outputs) for output in self.train_step_outputs: train_loss += output["loss"].item() train_soft_dice += output["train_soft_dice"].item() + precision_score = output["train_precision"] + recall_score = output["train_recall"] mean_train_loss = (train_loss / num_items) mean_train_soft_dice = (train_soft_dice / num_items) + mean_precision_score = np.mean(precision_score.detach().numpy()) + mean_recall_score = np.mean(recall_score.detach().numpy()) wandb_logs = { "train_soft_dice": mean_train_soft_dice, "train_loss": mean_train_loss, + "train_precision": mean_precision_score, + "train_recall": mean_recall_score, } + self.log_dict(wandb_logs) # plot the training images @@ -471,6 +488,10 @@ def validation_step(self, batch, batch_idx): hard_preds, hard_labels = (post_outputs[0].detach() > 0.5).float(), (post_labels[0].detach() > 0.5).float() val_hard_dice = self.soft_dice_metric(hard_preds, hard_labels) + # compute precision and recall + val_precision, val_recall = self.lesion_wise_precision_recall(post_outputs[0].detach().cpu(), post_labels[0].detach().cpu()) + print("sucess val") + # NOTE: there was a massive memory leak when storing cuda tensors in this dict. Hence, # using .detach() to avoid storing the whole computation graph # Ref: https://discuss.pytorch.org/t/cuda-memory-leak-while-training/82855/2 @@ -482,6 +503,8 @@ def validation_step(self, batch, batch_idx): "val_image": inputs[0].detach().cpu().squeeze(), "val_gt": labels[0].detach().cpu().squeeze(), "val_pred": post_outputs[0].detach().cpu().squeeze(), + "val_precision": val_precision.detach().cpu(), + "val_recall": val_recall.detach().cpu(), } self.val_step_outputs.append(metrics_dict) @@ -490,20 +513,27 @@ def validation_step(self, batch, batch_idx): def on_validation_epoch_end(self): val_loss, num_items, val_soft_dice, val_hard_dice = 0, 0, 0, 0 + val_precision, val_recall = 0, 0 for output in self.val_step_outputs: val_loss += output["val_loss"].sum().item() val_soft_dice += output["val_soft_dice"].sum().item() val_hard_dice += output["val_hard_dice"].sum().item() num_items += output["val_number"] + val_precision += output["val_precision"].sum().item() + val_recall += output["val_recall"].sum().item() mean_val_loss = (val_loss / num_items) mean_val_soft_dice = (val_soft_dice / num_items) mean_val_hard_dice = (val_hard_dice / num_items) + mean_val_precision = (val_precision / num_items) + mean_val_recall = (val_recall / num_items) wandb_logs = { "val_soft_dice": mean_val_soft_dice, # "val_hard_dice": mean_val_hard_dice, "val_loss": mean_val_loss, + "val_precision": mean_val_precision, + "val_recall": mean_val_recall, } self.log_dict(wandb_logs) @@ -650,23 +680,23 @@ def main(): # adn_ordering='NDA', # ) - # net=UNet( - # spatial_dims=3, - # in_channels=1, - # out_channels=1, - # channels=(32, 64, 128, 256, 512), - # strides=(2, 2, 2, 2, ), - # # dropout=0.1 - # ) - - net = AttentionUnet( - spatial_dims=3, - in_channels=1, - out_channels=1, - channels=(32, 64, 128, 256, 512), - strides=(2, 2, 2, 2,), - dropout=0.1, - ) + net=UNet( + spatial_dims=3, + in_channels=1, + out_channels=1, + channels=(32, 64, 128), + strides=(2, 2, 2, ), + # dropout=0.1 + ) + + # net = AttentionUnet( + # spatial_dims=3, + # in_channels=1, + # out_channels=1, + # channels=(32, 64, 128, 256, 512), + # strides=(2, 2, 2, 2,), + # dropout=0.1, + # ) # net = BasicUNet(spatial_dims=3, features=(32, 64, 128, 256, 32), out_channels=1) @@ -727,7 +757,7 @@ def main(): check_val_every_n_epoch=config["eval_num"], max_epochs=config["max_iterations"], precision=32, - # deterministic=True, + # precision='bf16-mixed', enable_progress_bar=True) # profiler="simple",) # to profile the training time taken for each step diff --git a/monai/utils.py b/monai/utils.py index 04b3224..b00452b 100644 --- a/monai/utils.py +++ b/monai/utils.py @@ -79,28 +79,56 @@ def lesion_wise_precision_recall(prediction, groundtruth, iou_threshold=0.2): precision: lesion-wise precision recall: lesion-wise recall """ - # Compute connected components in the predicted and ground truth segmentation masks - pred_labels = skimage.measure.label(prediction, connectivity=2) - gt_labels = skimage.measure.label(groundtruth, connectivity=2) - - # Compute the intersection over union (IoU) between each pair of connected components - iou_matrix = np.zeros((np.max(pred_labels), np.max(gt_labels))) - for i in range(np.max(pred_labels)): - for j in range(np.max(gt_labels)): - # Compute the intersection - intersection = np.sum((pred_labels == i + 1) * (gt_labels == j + 1)) - # Compute the union - union = np.sum((pred_labels == i + 1)) + np.sum((gt_labels == j + 1)) - intersection - # Compute the IoU - iou_matrix[i, j] = intersection / union - - # Compute lesion-wise precision and recall - true_positives = np.sum(np.max(iou_matrix, axis=1) > iou_threshold) - false_positives = np.sum(np.max(iou_matrix, axis=0) <= iou_threshold) - false_negatives = np.sum(np.max(iou_matrix, axis=1) <= iou_threshold) - precision = true_positives / (true_positives + false_positives) - recall = true_positives / (true_positives + false_negatives) - + prediction_cpu = prediction#.detach().numpy() + groundtruth_cpu = groundtruth#.detach().numpy() + + precision = [] + recall = [] + print(prediction_cpu.shape) + for i in range(prediction_cpu.shape[0]): + # Compute connected components in the predicted and ground truth segmentation masks + if len(prediction_cpu.shape) == 4: + print("iteration") + pred_labels = skimage.measure.label(prediction_cpu[0], connectivity=2) + gt_labels = skimage.measure.label(groundtruth_cpu[0], connectivity=2) + print('c', pred_labels.shape) + print('d', gt_labels.shape) + if len(prediction_cpu.shape) == 5: + pred_labels = skimage.measure.label(prediction_cpu[i][0], connectivity=2) + gt_labels = skimage.measure.label(groundtruth_cpu[i][0], connectivity=2) + print('e', pred_labels.shape) + print('f', gt_labels.shape) + + # If there are no connected components in the predicted or ground truth segmentation masks we return 0 and continue + if np.max(pred_labels)==0 or np.max(gt_labels)==0: + precision+= [0] + recall+= [0] + continue + + # Compute the intersection over union (IoU) between each pair of connected components + iou_matrix = np.zeros((np.max(pred_labels), np.max(gt_labels))) + for i in range(np.max(pred_labels)): + for j in range(np.max(gt_labels)): + # Compute the intersection + intersection = np.sum((pred_labels == i + 1) * (gt_labels == j + 1)) + # Compute the union + union = np.sum((pred_labels == i + 1)) + np.sum((gt_labels == j + 1)) - intersection + # Compute the IoU + iou_matrix[i, j] = intersection / union + + # Compute lesion-wise precision and recall + true_positives = np.sum(np.max(iou_matrix, axis=1) > iou_threshold) + false_positives = np.sum(np.max(iou_matrix, axis=0) <= iou_threshold) + false_negatives = np.sum(np.max(iou_matrix, axis=1) <= iou_threshold) + precision += [true_positives / (true_positives + false_positives)] + recall+= [true_positives / (true_positives + false_negatives)] + + # Put it back in cuda + precision = torch.tensor(precision).cuda() + recall = torch.tensor(recall).cuda() + + print("precision", precision) + print("recall", recall) return precision, recall From 9d80931160199fb168b588bac5df2e5a0fe718d1 Mon Sep 17 00:00:00 2001 From: Pierre-Louis Benveniste Date: Tue, 16 Apr 2024 13:35:36 -0400 Subject: [PATCH 056/108] modified version of precision/recall metric --- monai/utils.py | 51 +++++++++++++++++++++++++++++++++++--------------- 1 file changed, 36 insertions(+), 15 deletions(-) diff --git a/monai/utils.py b/monai/utils.py index b00452b..c102263 100644 --- a/monai/utils.py +++ b/monai/utils.py @@ -67,7 +67,7 @@ def plot_slices(image, gt, pred, debug=False): return fig -def lesion_wise_precision_recall(prediction, groundtruth, iou_threshold=0.2): +def lesion_wise_precision_recall(prediction, groundtruth, iou_threshold=0.1): """ This function computes the lesion-wise precision and recall. @@ -84,29 +84,38 @@ def lesion_wise_precision_recall(prediction, groundtruth, iou_threshold=0.2): precision = [] recall = [] - print(prediction_cpu.shape) + # print(prediction_cpu.shape) for i in range(prediction_cpu.shape[0]): # Compute connected components in the predicted and ground truth segmentation masks if len(prediction_cpu.shape) == 4: - print("iteration") - pred_labels = skimage.measure.label(prediction_cpu[0], connectivity=2) - gt_labels = skimage.measure.label(groundtruth_cpu[0], connectivity=2) - print('c', pred_labels.shape) - print('d', gt_labels.shape) + # print("iteration") + # binarize the prediction and ground truth + prediction_cpu[0] = prediction_cpu[0] > 0.2 + groundtruth_cpu[0] = groundtruth_cpu[0] > 0.2 + # compute connected components + pred_labels, num_components_pred = skimage.measure.label(prediction_cpu[0], connectivity=2, return_num=True) + gt_labels, num_components_gt = skimage.measure.label(groundtruth_cpu[0], connectivity=2, return_num=True) + # print('c', pred_labels.shape) + # print('d', gt_labels.shape) if len(prediction_cpu.shape) == 5: - pred_labels = skimage.measure.label(prediction_cpu[i][0], connectivity=2) - gt_labels = skimage.measure.label(groundtruth_cpu[i][0], connectivity=2) - print('e', pred_labels.shape) - print('f', gt_labels.shape) + # binarize the prediction and ground truth + prediction_cpu[i][0] = prediction_cpu[i][0] > 0.2 + groundtruth_cpu[i][0] = groundtruth_cpu[i][0] > 0.2 + # compute connected components + pred_labels, num_components_pred = skimage.measure.label(prediction_cpu[i][0], connectivity=2, return_num=True) + gt_labels, num_components_gt = skimage.measure.label(groundtruth_cpu[i][0], connectivity=2, return_num=True) + # print('e', pred_labels.shape) + # print('f', gt_labels.shape) # If there are no connected components in the predicted or ground truth segmentation masks we return 0 and continue - if np.max(pred_labels)==0 or np.max(gt_labels)==0: + if num_components_gt==0 or num_components_pred==0: precision+= [0] recall+= [0] continue # Compute the intersection over union (IoU) between each pair of connected components iou_matrix = np.zeros((np.max(pred_labels), np.max(gt_labels))) + intersection_matrix = np.zeros((np.max(pred_labels), np.max(gt_labels))) for i in range(np.max(pred_labels)): for j in range(np.max(gt_labels)): # Compute the intersection @@ -115,14 +124,26 @@ def lesion_wise_precision_recall(prediction, groundtruth, iou_threshold=0.2): union = np.sum((pred_labels == i + 1)) + np.sum((gt_labels == j + 1)) - intersection # Compute the IoU iou_matrix[i, j] = intersection / union + # if iou_matrix[i, j] > 0: + # print("iou_matrix", iou_matrix[i, j]) + # Compute the intersection + intersection_matrix[i, j] = intersection + # # Compute lesion-wise precision and recall + # true_positives = np.sum(np.max(iou_matrix, axis=1) > iou_threshold) + # false_positives = np.sum(np.max(iou_matrix, axis=0) <= iou_threshold) + # false_negatives = np.sum(np.max(iou_matrix, axis=1) <= iou_threshold) + # precision += [true_positives / (true_positives + false_positives)] + # recall+= [true_positives / (true_positives + false_negatives)] + # Compute lesion-wise precision and recall - true_positives = np.sum(np.max(iou_matrix, axis=1) > iou_threshold) - false_positives = np.sum(np.max(iou_matrix, axis=0) <= iou_threshold) - false_negatives = np.sum(np.max(iou_matrix, axis=1) <= iou_threshold) + true_positives = np.sum(np.max(intersection_matrix, axis=1) > iou_threshold) + false_positives = np.sum(np.max(intersection_matrix, axis=0) <= iou_threshold) + false_negatives = np.sum(np.max(intersection_matrix, axis=1) <= iou_threshold) precision += [true_positives / (true_positives + false_positives)] recall+= [true_positives / (true_positives + false_negatives)] + # Put it back in cuda precision = torch.tensor(precision).cuda() recall = torch.tensor(recall).cuda() From 60a63e6d58a5285146416d4352848d03c7338e84 Mon Sep 17 00:00:00 2001 From: Pierre-Louis Benveniste Date: Tue, 16 Apr 2024 13:36:42 -0400 Subject: [PATCH 057/108] modified config file --- monai/config.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/monai/config.yml b/monai/config.yml index 9ba503a..e156248 100644 --- a/monai/config.yml +++ b/monai/config.yml @@ -13,8 +13,8 @@ output_dir : /home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/results/ pixdim : [0.7, 0.7, 0.7] # pixdim : [0.5, 0.5, 0.5] # Spatial size of the input data -spatial_size : [32, 128, 128] # RL, AP, IS -batch_size : 8 +spatial_size : [64, 128, 128] # RL, AP, IS +batch_size : 4 # smaller batch size lead to better generalization https://arxiv.org/abs/1609.04836 but longer to train # UNETR model parameters feature_size : 8 @@ -28,7 +28,7 @@ weight_decay: 0.00001 early_stopping_patience : 100 # Training parameters -max_iterations : 3000 +max_iterations : 1000 eval_num : 2 # Model saving From 318c117a76c7e9b8004ad8c202bfd0c022268a80 Mon Sep 17 00:00:00 2001 From: Pierre-Louis Benveniste Date: Tue, 16 Apr 2024 13:37:21 -0400 Subject: [PATCH 058/108] modified code to include swinUNETR model --- monai/train_monai_unet_lightning.py | 94 ++++++++++++++++------------- 1 file changed, 53 insertions(+), 41 deletions(-) diff --git a/monai/train_monai_unet_lightning.py b/monai/train_monai_unet_lightning.py index 1081c36..f3f952f 100644 --- a/monai/train_monai_unet_lightning.py +++ b/monai/train_monai_unet_lightning.py @@ -21,7 +21,7 @@ from losses import AdapWingLoss, SoftDiceLoss from utils import dice_score, check_empty_patch, multiply_by_negative_one, plot_slices, lesion_wise_precision_recall -from monai.networks.nets import UNet, BasicUNet, AttentionUnet +from monai.networks.nets import UNet, BasicUNet, AttentionUnet, SwinUNETR from monai.metrics import DiceMetric from monai.losses import DiceLoss, DiceCELoss from monai.networks.layers import Norm @@ -402,8 +402,8 @@ def training_step(self, batch, batch_idx): train_soft_dice = self.soft_dice_metric(output, labels) # Compute precision and recall - train_precision, train_recall = self.lesion_wise_precision_recall(output.detach().cpu(), labels.detach().cpu()) - print("sucess") + # train_precision, train_recall = self.lesion_wise_precision_recall(output.detach().cpu(), labels.detach().cpu()) + # print("sucess") metrics_dict = { "loss": loss.cpu(), @@ -412,8 +412,8 @@ def training_step(self, batch, batch_idx): "train_image": inputs[0].detach().cpu().squeeze(), "train_gt": labels[0].detach().cpu().squeeze(), "train_pred": output[0].detach().cpu().squeeze(), - "train_precision": train_precision.detach().cpu(), - "train_recall": train_recall.detach().cpu(), + # "train_precision": train_precision.detach().cpu(), + # "train_recall": train_recall.detach().cpu(), } self.train_step_outputs.append(metrics_dict) @@ -426,24 +426,24 @@ def on_train_epoch_end(self): return None else: train_loss, train_soft_dice = 0, 0 - precision_score, recall_score = 0, 0 + # precision_score, recall_score = 0, 0 num_items = len(self.train_step_outputs) for output in self.train_step_outputs: train_loss += output["loss"].item() train_soft_dice += output["train_soft_dice"].item() - precision_score = output["train_precision"] - recall_score = output["train_recall"] + # precision_score = output["train_precision"] + # recall_score = output["train_recall"] mean_train_loss = (train_loss / num_items) mean_train_soft_dice = (train_soft_dice / num_items) - mean_precision_score = np.mean(precision_score.detach().numpy()) - mean_recall_score = np.mean(recall_score.detach().numpy()) + # mean_precision_score = np.mean(precision_score.detach().numpy()) + # mean_recall_score = np.mean(recall_score.detach().numpy()) wandb_logs = { "train_soft_dice": mean_train_soft_dice, "train_loss": mean_train_loss, - "train_precision": mean_precision_score, - "train_recall": mean_recall_score, + # "train_precision": mean_precision_score, + # "train_recall": mean_recall_score, } self.log_dict(wandb_logs) @@ -489,8 +489,8 @@ def validation_step(self, batch, batch_idx): val_hard_dice = self.soft_dice_metric(hard_preds, hard_labels) # compute precision and recall - val_precision, val_recall = self.lesion_wise_precision_recall(post_outputs[0].detach().cpu(), post_labels[0].detach().cpu()) - print("sucess val") + # val_precision, val_recall = self.lesion_wise_precision_recall(post_outputs[0].detach().cpu(), post_labels[0].detach().cpu()) + # print("sucess val") # NOTE: there was a massive memory leak when storing cuda tensors in this dict. Hence, # using .detach() to avoid storing the whole computation graph @@ -503,8 +503,8 @@ def validation_step(self, batch, batch_idx): "val_image": inputs[0].detach().cpu().squeeze(), "val_gt": labels[0].detach().cpu().squeeze(), "val_pred": post_outputs[0].detach().cpu().squeeze(), - "val_precision": val_precision.detach().cpu(), - "val_recall": val_recall.detach().cpu(), + # "val_precision": val_precision.detach().cpu(), + # "val_recall": val_recall.detach().cpu(), } self.val_step_outputs.append(metrics_dict) @@ -513,27 +513,27 @@ def validation_step(self, batch, batch_idx): def on_validation_epoch_end(self): val_loss, num_items, val_soft_dice, val_hard_dice = 0, 0, 0, 0 - val_precision, val_recall = 0, 0 + # val_precision, val_recall = 0, 0 for output in self.val_step_outputs: val_loss += output["val_loss"].sum().item() val_soft_dice += output["val_soft_dice"].sum().item() val_hard_dice += output["val_hard_dice"].sum().item() num_items += output["val_number"] - val_precision += output["val_precision"].sum().item() - val_recall += output["val_recall"].sum().item() + # val_precision += output["val_precision"].sum().item() + # val_recall += output["val_recall"].sum().item() mean_val_loss = (val_loss / num_items) mean_val_soft_dice = (val_soft_dice / num_items) mean_val_hard_dice = (val_hard_dice / num_items) - mean_val_precision = (val_precision / num_items) - mean_val_recall = (val_recall / num_items) + # mean_val_precision = (val_precision / num_items) + # mean_val_recall = (val_recall / num_items) wandb_logs = { "val_soft_dice": mean_val_soft_dice, # "val_hard_dice": mean_val_hard_dice, "val_loss": mean_val_loss, - "val_precision": mean_val_precision, - "val_recall": mean_val_recall, + # "val_precision": mean_val_precision, + # "val_recall": mean_val_recall, } self.log_dict(wandb_logs) @@ -680,23 +680,35 @@ def main(): # adn_ordering='NDA', # ) - net=UNet( - spatial_dims=3, - in_channels=1, - out_channels=1, - channels=(32, 64, 128), - strides=(2, 2, 2, ), - # dropout=0.1 - ) - - # net = AttentionUnet( - # spatial_dims=3, - # in_channels=1, - # out_channels=1, - # channels=(32, 64, 128, 256, 512), - # strides=(2, 2, 2, 2,), - # dropout=0.1, - # ) + # net=UNet( + # spatial_dims=3, + # in_channels=1, + # out_channels=1, + # channels=(32, 64, 128), + # strides=(2, 2, 2, ), + # # dropout=0.1 + # ) + + net = AttentionUnet( + spatial_dims=3, + in_channels=1, + out_channels=1, + channels=(32, 64, 128, 256), + strides=(2, 2, 2,), + dropout=0.1, + ) + + # net = SwinUNETR( + # img_size=config["spatial_size"], + # spatial_dims=3, + # in_channels=1, + # out_channels=1, + # feature_size=48, + # use_checkpoint=True, + # ) + + # net.use_multiprocessing = False + # net = BasicUNet(spatial_dims=3, features=(32, 64, 128, 256, 32), out_channels=1) @@ -706,7 +718,7 @@ def main(): # define loss function - #loss_func = AdapWingLoss(theta=0.5, omega=8, alpha=2.1, epsilon=1, reduction="sum") + # loss_func = AdapWingLoss(theta=0.5, omega=8, alpha=2.1, epsilon=1, reduction="sum") # loss_func = DiceLoss(sigmoid=False, smooth_dr=1e-4) loss_func = DiceCELoss(sigmoid=False, smooth_dr=1e-4) # loss_func = SoftDiceLoss(smooth=1e-5) From 188ee484100bf5c471cc0e3f2b87ad56d5c03acd Mon Sep 17 00:00:00 2001 From: Pierre-Louis Benveniste Date: Tue, 16 Apr 2024 17:27:43 -0400 Subject: [PATCH 059/108] fixed wandb and config file for cleaner pipeline --- monai/config.yml | 29 +++++++----------- monai/train_monai_unet_lightning.py | 46 ++++++++++++++--------------- 2 files changed, 34 insertions(+), 41 deletions(-) diff --git a/monai/config.yml b/monai/config.yml index e156248..9b4b22a 100644 --- a/monai/config.yml +++ b/monai/config.yml @@ -1,4 +1,5 @@ # Description: Configuration file for the UNETR model + # Path to the data json file # data: /home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/data/monai_data/fake.json # data: /home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/data/monai_data/fake_lesion_sc.json @@ -6,21 +7,18 @@ # data: /home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/data/monai_data/fake_sc.json # data: /home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/data/monai_data/dataset_2024-03-13_seed42_canproco.json data: /home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/data/monai_data/dataset_2024-04-05_seed42_lesionOnly.json -# Path to the output directory -output_dir : /home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/results/ + # Resampling resolution # pixdim : [1.0, 1.0, 1.0] pixdim : [0.7, 0.7, 0.7] # pixdim : [0.5, 0.5, 0.5] + # Spatial size of the input data spatial_size : [64, 128, 128] # RL, AP, IS batch_size : 4 # smaller batch size lead to better generalization https://arxiv.org/abs/1609.04836 but longer to train -# UNETR model parameters -feature_size : 8 -hidden_size : 768 -mlp_dim : 3072 -num_heads : 12 +# Augmentation parameters +DA_probability : 0.2 # Optimizer parameters lr : 0.001 @@ -31,21 +29,16 @@ early_stopping_patience : 100 max_iterations : 1000 eval_num : 2 -# Model saving -best_model_path : /home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/results/best_metric_model.pth - -# log saving -log_path : /home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/results/ - -# WANDB -experiment_name : monai_unet_canproco +# Outputs +output_path : /home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/results/ +# Seed seed : 42 # UNET model parameters unet_channels : [32, 64, 128, 256, 512, 1024] unet_strides : [2, 2, 2, 2, 2, 2, 2] -#Attention Unet -channels : [32, 64, 128, 256, 512] -strides : [2, 2, 2, 2, 2] \ No newline at end of file +# AttentionUnet +attention_unet_channels : [32, 64, 128, 256, 512] +attention_unet_strides : [2, 2, 2, 2, 2] \ No newline at end of file diff --git a/monai/train_monai_unet_lightning.py b/monai/train_monai_unet_lightning.py index f3f952f..cd38c74 100644 --- a/monai/train_monai_unet_lightning.py +++ b/monai/train_monai_unet_lightning.py @@ -181,38 +181,38 @@ def prepare_data(self): RandFlipd( keys=["image", "label"], spatial_axis=[0], - prob=0.2, + prob=self.cfg["DA_probability"], ), # Flips the image : supperior becomes inferior RandFlipd( keys=["image", "label"], spatial_axis=[1], - prob=0.2, + prob=self.cfg["DA_probability"], ), # Flips the image : anterior becomes posterior RandFlipd( keys=["image", "label"], spatial_axis=[2], - prob=0.2, + prob=self.cfg["DA_probability"], ), # Random elastic deformation Rand3DElasticd( keys=["image", "label"], sigma_range=(5, 7), magnitude_range=(50, 150), - prob=0.2, + prob=self.cfg["DA_probability"], mode=['bilinear', 'nearest'], ), # Random affine transform of the image RandAffined( keys=["image", "label"], - prob=0.2, + prob=self.cfg["DA_probability"], mode=('bilinear', 'nearest'), padding_mode='zeros', ), # RandAdjustContrastd( # keys=["image"], - # prob=0.2, + # prob=self.cfg["DA_probability"], # gamma=(0.5, 4.5), # invert_image=True, # ), @@ -220,7 +220,7 @@ def prepare_data(self): # RandLambdad( # keys='image', # func=multiply_by_negative_one, - # prob=0.2 + # prob=self.cfg["DA_probability"] # ), # LabelToContourd( # keys=["image"], @@ -228,20 +228,20 @@ def prepare_data(self): # ), RandGaussianNoised( keys=["image"], - prob=0.2, + prob=self.cfg["DA_probability"], ), # Random simulation of low resolution RandSimulateLowResolutiond( keys=["image"], zoom_range=(0.8, 1.5), - prob=0.2 + prob=self.cfg["DA_probability"] ), # Adding a random bias field which is usefull considering that this sometimes done for image pre-processing RandBiasFieldd( keys=["image"], coeff_range=(0.0, 0.5), degree=3, - prob=0.2 + prob=self.cfg["DA_probability"] ), # RandShiftIntensityd( # keys=["image"], @@ -657,11 +657,13 @@ def main(): # define optimizer optimizer_class = torch.optim.Adam - wandb.init(project=f'monai-ms-lesion-seg-unet', config=config) + output_path = os.path.join(config["output_path"], str(datetime.now().date()) +"_" +str(datetime.now().time())) + os.makedirs(output_path, exist_ok=True) + + wandb.init(project=f'monai-ms-lesion-seg-unet', config=config, save_code=True, dir=output_path) logger.info("Building the model ...") - # define model # net = UNet( @@ -693,8 +695,8 @@ def main(): spatial_dims=3, in_channels=1, out_channels=1, - channels=(32, 64, 128, 256), - strides=(2, 2, 2,), + channels=config["attention_unet_channels"], + strides=config["attention_unet_strides"], dropout=0.1, ) @@ -714,8 +716,8 @@ def main(): # net = create_nnunet_from_plans() - logger.add(os.path.join(config["log_path"], str(datetime.now()) + 'log.txt'), rotation="10 MB", level="INFO") + logger.add(os.path.join(output_path, 'log.txt'), rotation="10 MB", level="INFO") # define loss function # loss_func = AdapWingLoss(theta=0.5, omega=8, alpha=2.1, epsilon=1, reduction="sum") @@ -734,14 +736,16 @@ def main(): lr_monitor = pl.callbacks.LearningRateMonitor(logging_interval='epoch') + best_model_path = os.path.join(output_path, "best_model.pth") + # i.e. train by loading weights from scratch pl_model = Model(config, data_root=dataset_root, optimizer_class=optimizer_class, loss_function=loss_func, net=net, - exp_id="test", results_path=config["best_model_path"]) + exp_id="test", results_path=best_model_path) # saving the best model based on validation loss checkpoint_callback_loss = pl.callbacks.ModelCheckpoint( - dirpath=config["best_model_path"], filename='best_model', monitor='val_loss', + dirpath= best_model_path, filename='best_model', monitor='val_loss', save_top_k=1, mode="min", save_last=True, save_weights_only=True) @@ -749,17 +753,13 @@ def main(): # wandb logger exp_logger = pl.loggers.WandbLogger( name="test", - save_dir="/home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/results", + save_dir=output_path, group="test-on-canproco", log_model=True, # save best model using checkpoint callback - project='ms-lesion-agnostic', - entity='pierre-louis-benveniste', config=config) # Saving training script to wandb - wandb.save("ms-lesion-agnostic/monai/config.yml") - wandb.save("ms-lesion-agnostic/monai/train_monai_unet_lightning.py") - + wandb.save(args.config) # initialise Lightning's trainer. trainer = pl.Trainer( From 01bce84c7bf4d26b9f859d9c468a88f6d4a29cfa Mon Sep 17 00:00:00 2001 From: Pierre-Louis Benveniste Date: Wed, 17 Apr 2024 13:05:59 -0400 Subject: [PATCH 060/108] new script to test the dataset --- monai/test_model.py | 184 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 184 insertions(+) create mode 100644 monai/test_model.py diff --git a/monai/test_model.py b/monai/test_model.py new file mode 100644 index 0000000..f1bf5a8 --- /dev/null +++ b/monai/test_model.py @@ -0,0 +1,184 @@ +""" +This code is used to test the model on a test set. +It uses the class Model which was defined in the file train_monai_unet_lightning.py. +""" +import os +from monai.transforms import ( + Compose, + LoadImaged, + EnsureChannelFirstd, + Orientationd, + Spacingd, + NormalizeIntensityd, + ResizeWithPadOrCropd, + Invertd, + EnsureTyped, + SaveImage, +) +from monai.data import (DataLoader, CacheDataset, load_decathlon_datalist, decollate_batch, Dataset) +from monai.networks.nets import AttentionUnet +import torch +from monai.inferers import sliding_window_inference +import torch.nn.functional as F +from utils import dice_score +import argparse +import yaml +import torch.multiprocessing +torch.multiprocessing.set_sharing_strategy('file_system') + + +def get_parser(): + """ + This function returns the parser for the command line arguments. + """ + parser = argparse.ArgumentParser(description="Test the model on the test set") + parser.add_argument("-c", "--config", help="Path to the config file (.yml file)", required=True) + parser.add_argument("--data_split", help="Data split to use (train, validation, test)", required=True, type=str) + return parser + + +def main(): + """ + This function is used to test the model on a test set. + + Args: + None + + Returns: + None + """ + # Get the parser + parser = get_parser() + args = parser.parse_args() + + # Load the config file + with open(args.config, "r") as f: + cfg = yaml.load(f, Loader=yaml.FullLoader) + + # Device + DEVICE = "cuda" + + # build output directory + output_dir = os.path.join(cfg["output_dir"], args.data_split +"_set") + if not os.path.exists(output_dir): + os.makedirs(output_dir) + + # Dict of dice score + dice_scores = {} + + # Load the data + test_files = load_decathlon_datalist(cfg["dataset"], True, args.data_split) + + #Create the test transforms + test_transforms = Compose( + [ + LoadImaged(keys=["image", "label"], reader="NibabelReader"), + EnsureChannelFirstd(keys=["image", "label"]), + Orientationd(keys=["image", "label"], axcodes="RPI"), + Spacingd( + keys=["image", "label"], + pixdim=cfg["pixdim"], + mode=(2, 1), + ), + NormalizeIntensityd( + keys=["image"], + nonzero=False, + channel_wise=False + ), + ResizeWithPadOrCropd( + keys=["image", "label"], + spatial_size=cfg["spatial_size"], + ), + ] + ) + + # Create the prediction post-processing function + ## For this to work I had to add cupy-cuda117==10.6.0 to the requirements + test_post_pred = Compose([ + EnsureTyped(keys=["pred"]), + Invertd(keys=["pred"], transform=test_transforms, + orig_keys=["image"], + meta_keys=["pred_meta_dict"], + nearest_interp=False, to_tensor=True), + ]) + + # Create the data loader + test_ds = CacheDataset(data=test_files, transform=test_transforms, cache_rate=0.1, num_workers=4) + test_data_loader = DataLoader(test_ds, batch_size=1, shuffle=False, num_workers=1) + + # Load the model + net = AttentionUnet( + spatial_dims=3, + in_channels=1, + out_channels=1, + channels=cfg["attention_unet_channels"], + strides=cfg["attention_unet_strides"], + dropout=0.1, + ) + net.to(DEVICE) + checkpoint = torch.load(cfg["path_to_model"], map_location=torch.device(DEVICE))["state_dict"] + # NOTE: remove the 'net.' prefix from the keys because of how the model was initialized in lightning + # https://discuss.pytorch.org/t/missing-keys-unexpected-keys-in-state-dict-when-loading-self-trained-model/22379/14 + for key in list(checkpoint.keys()): + if 'net.' in key: + checkpoint[key.replace('net.', '')] = checkpoint[key] + del checkpoint[key] + # remove the key loss_function.dice.class_weights because it is not needed + # I had the error but I don't really know why + if 'loss_function.dice.class_weight' in key: + del checkpoint[key] + net.load_state_dict(checkpoint) + net.eval() + + # Run inference + with torch.no_grad(): + for i, batch in enumerate(test_data_loader): + # get the test input + test_input = batch["image"].to(DEVICE) + + # run inference + batch["pred"] = sliding_window_inference(test_input, cfg["spatial_size"], mode="gaussian", + sw_batch_size=4, predictor=net, overlap=0.5, progress=False) + + # NOTE: monai's models do not normalize the output, so we need to do it manually + if bool(F.relu(batch["pred"]).max()): + batch["pred"] = F.relu(batch["pred"]) / F.relu(batch["pred"]).max() + else: + batch["pred"] = F.relu(batch["pred"]) + + # compute the dice score + dice = dice_score(batch["pred"].cpu(), batch["label"].cpu()) + + # post-process the prediction + post_test_out = [test_post_pred(i) for i in decollate_batch(batch)] + + pred = post_test_out[0]['pred'].cpu() + + # Threshold the prediction + pred[pred < 0.5] = 0 + + # Get file name + file_name = test_files[i]["image"].split("/")[-1].split(".")[0] + print(f"Saving {file_name}") + + # Save the prediction + pred_saver = SaveImage( + output_dir=output_dir , output_postfix="pred", output_ext=".nii.gz", + separate_folder=False, print_log=False) + # save the prediction + pred_saver(pred) + + # Save the dice score + dice_scores[file_name] = dice + + test_input.detach() + + + # Save the dice scores + with open(os.path.join(output_dir, "dice_scores.txt"), "w") as f: + for key, value in dice_scores.items(): + f.write(f"{key}: {value}\n") + + +if __name__ == "__main__": + main() \ No newline at end of file From d96f1e3ad761196ac26e9066521344d5b3f98901 Mon Sep 17 00:00:00 2001 From: Pierre-Louis Benveniste Date: Wed, 17 Apr 2024 13:06:14 -0400 Subject: [PATCH 061/108] config file for testing the dataset --- monai/config_test.yml | 7 +++++++ 1 file changed, 7 insertions(+) create mode 100644 monai/config_test.yml diff --git a/monai/config_test.yml b/monai/config_test.yml new file mode 100644 index 0000000..242988b --- /dev/null +++ b/monai/config_test.yml @@ -0,0 +1,7 @@ +dataset : /home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/data/monai_data/dataset_2024-04-05_seed42_lesionOnly.json +pixdim : [0.7, 0.7, 0.7] +spatial_size : [64, 128, 128] +attention_unet_channels : [32, 64, 128, 256, 512] +attention_unet_strides : [2, 2, 2, 2, 2] +path_to_model : /home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/results/2024-04-16_15:30:27.095475/best_model.pth/best_model.ckpt +output_dir : /home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/results/2024-04-16_15:30:27.095475/ \ No newline at end of file From 761d32f2cefd4ff802cb842f323fef5a3783673d Mon Sep 17 00:00:00 2001 From: Pierre-Louis Benveniste Date: Wed, 17 Apr 2024 13:06:30 -0400 Subject: [PATCH 062/108] added cupy install for inference --- monai/requirements.txt | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/monai/requirements.txt b/monai/requirements.txt index cb8f988..7618f5c 100644 --- a/monai/requirements.txt +++ b/monai/requirements.txt @@ -3,4 +3,5 @@ monai[all] torch torchvision matplotlib -pytorch_lightning \ No newline at end of file +pytorch_lightning +cupy-cuda117==10.6.0 \ No newline at end of file From 11abb7399e44834a812fa52c7614d369d44582b5 Mon Sep 17 00:00:00 2001 From: Pierre-Louis Benveniste Date: Wed, 17 Apr 2024 13:50:08 -0400 Subject: [PATCH 063/108] added script for plotting the performance (dice metric) on the data split --- monai/plot_performance.py | 65 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 65 insertions(+) create mode 100644 monai/plot_performance.py diff --git a/monai/plot_performance.py b/monai/plot_performance.py new file mode 100644 index 0000000..250b8f2 --- /dev/null +++ b/monai/plot_performance.py @@ -0,0 +1,65 @@ +"""" +This script is used to plot the performance of the model on the test set, validation and train set. +It saves a plot of dice scores per contrat in the output folder + +""" +import numpy as np +import matplotlib.pyplot as plt +import seaborn as sns +import pandas as pd +import argparse + + +def get_parser(): + """ + This function returns the parser for the command line arguments. + """ + parser = argparse.ArgumentParser(description="Plot the performance of the model") + parser.add_argument("--pred-dir-path", help="Path to the directory containing the dice_score.txt file", required=True) + return parser + + +def main(): + """ + This function is used to plot the performance of the model on the test set. + + Args: + None + + Returns: + None + """ + # Get the parser + parser = get_parser() + args = parser.parse_args() + + # Path to the dice_scores + path_to_outputs = args.pred_dir_path + dice_score_file = path_to_outputs + '/dice_scores.txt' + + # Open dice results (they are txt files) + test_dice_results = {} + with open(dice_score_file, 'r') as file: + for line in file: + key, value = line.strip().split(':') + test_dice_results[key] = float(value) + + # convert to a df with name and dice score + test_dice_results = pd.DataFrame(list(test_dice_results.items()), columns=['name', 'dice_score']) + + # Add the contrats column + test_dice_results['contrast'] = test_dice_results['name'].apply(lambda x: x.split('_')[-1]) + + # plot a violin plot per contrast + plt.figure(figsize=(20, 10)) + sns.violinplot(x='contrast', y='dice_score', data=test_dice_results) + plt.title('Dice scores per contrast') + plt.show() + + # Save the plot + plt.savefig(path_to_outputs + '/dice_scores.png') + print(f"Saved the dice_scores plot in {path_to_outputs}") + + +if __name__ == "__main__": + main() \ No newline at end of file From e4d6088ecf4f6cf9a3f40583c09010bad910c22d Mon Sep 17 00:00:00 2001 From: Pierre-Louis Benveniste Date: Wed, 17 Apr 2024 13:50:41 -0400 Subject: [PATCH 064/108] correct typo in parser --- monai/test_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/test_model.py b/monai/test_model.py index f1bf5a8..e34ae2b 100644 --- a/monai/test_model.py +++ b/monai/test_model.py @@ -33,7 +33,7 @@ def get_parser(): """ parser = argparse.ArgumentParser(description="Test the model on the test set") parser.add_argument("-c", "--config", help="Path to the config file (.yml file)", required=True) - parser.add_argument("--data_split", help="Data split to use (train, validation, test)", required=True, type=str) + parser.add_argument("--data-split", help="Data split to use (train, validation, test)", required=True, type=str) return parser From aa42e9420b11f7fc23a2c64f4e7f8b59b4b88b1f Mon Sep 17 00:00:00 2001 From: Pierre-Louis Benveniste Date: Wed, 17 Apr 2024 16:21:20 -0400 Subject: [PATCH 065/108] fixed typo on basel and bavaria data import --- monai/1_create_msd_data.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/monai/1_create_msd_data.py b/monai/1_create_msd_data.py index dd02f84..d0cc8c7 100644 --- a/monai/1_create_msd_data.py +++ b/monai/1_create_msd_data.py @@ -4,14 +4,12 @@ Arguments: -pd, --path-data: Path to the data set directory - -pj, --path-joblib: Path to joblib file from ivadomed containing the dataset splits. -po, --path-out: Path to the output directory where dataset json is saved - --contrast: Contrast to use for training - --label-type: Type of labels to use for training + --lesion-only: Use only masks which contain some lesions --seed: Seed for reproducibility Example: - python create_msd_data.py -pd /path/dataset -po /path/output + python create_msd_data.py -pd /path/dataset -po /path/output --lesion-only --seed 42 TO DO: * @@ -180,7 +178,7 @@ def main(): elif 'basel-mp2rage' in str(subject): relative_path = subject.relative_to(basel_path).parent temp_data_basel["image"] = str(subject) - temp_data_basel["label"] = str(basel_path / 'derivatives' / 'labels' / relative_path / str(subject).replace('UNIT1.nii.gz', 'UNIT1_desc-rater3_label-lesion_seg.nii.gz')) + temp_data_basel["label"] = str(basel_path) + '/derivatives/labels/' + str(relative_path) +'/'+ str(subject.name).replace('UNIT1.nii.gz', 'UNIT1_desc-rater3_label-lesion_seg.nii.gz') if os.path.exists(temp_data_basel["label"]) and os.path.exists(temp_data_basel["image"]): total_lesion_volume, nb_lesions = count_lesion(temp_data_basel["label"]) temp_data_basel["total_lesion_volume"] = total_lesion_volume @@ -206,7 +204,7 @@ def main(): elif 'bavaria-quebec-spine-ms' in str(subject): relative_path = subject.relative_to(bavaria_path).parent temp_data_bavaria["image"] = str(subject) - temp_data_bavaria["label"] = str(bavaria_path / 'derivatives' / 'labels' / relative_path / subject.name.replace('T2w.nii.gz', 'T2w_lesion-manual.nii.gz')) + temp_data_bavaria["label"] = str(bavaria_path) + '/derivatives/labels/' + str(relative_path) + '/' +str(subject.name).replace('T2w.nii.gz', 'lesions-manual_T2w.nii.gz') if os.path.exists(temp_data_bavaria["label"]) and os.path.exists(temp_data_bavaria["image"]): total_lesion_volume, nb_lesions = count_lesion(temp_data_bavaria["label"]) temp_data_bavaria["total_lesion_volume"] = total_lesion_volume From 55b0add3da74e4e54d685f5f0602f5f6360ff4de Mon Sep 17 00:00:00 2001 From: Pierre-Louis Benveniste Date: Tue, 4 Jun 2024 11:44:03 -0400 Subject: [PATCH 066/108] changes made for previous run (before ISMRM) --- monai/config.yml | 7 ++++--- monai/config_test.yml | 7 ++++--- monai/plot_performance.py | 11 ++++++++++- monai/train_monai_unet_lightning.py | 6 +++--- 4 files changed, 21 insertions(+), 10 deletions(-) diff --git a/monai/config.yml b/monai/config.yml index 9b4b22a..0668f5f 100644 --- a/monai/config.yml +++ b/monai/config.yml @@ -7,6 +7,7 @@ # data: /home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/data/monai_data/fake_sc.json # data: /home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/data/monai_data/dataset_2024-03-13_seed42_canproco.json data: /home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/data/monai_data/dataset_2024-04-05_seed42_lesionOnly.json +# data: /home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/data/monai_data/dataset_2024-04-17_seed42_lesionOnly.json # Resampling resolution # pixdim : [1.0, 1.0, 1.0] @@ -21,12 +22,12 @@ batch_size : 4 # smaller batch size lead to better generalization https://arxiv. DA_probability : 0.2 # Optimizer parameters -lr : 0.001 +lr : 0.0001 weight_decay: 0.00001 -early_stopping_patience : 100 +early_stopping_patience : 50 # Training parameters -max_iterations : 1000 +max_iterations : 250 eval_num : 2 # Outputs diff --git a/monai/config_test.yml b/monai/config_test.yml index 242988b..9c497e7 100644 --- a/monai/config_test.yml +++ b/monai/config_test.yml @@ -1,7 +1,8 @@ -dataset : /home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/data/monai_data/dataset_2024-04-05_seed42_lesionOnly.json +# dataset : /home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/data/monai_data/dataset_2024-04-17_seed42_lesionOnly.json +dataset : /home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/data/monai_data/dataset_2024-04-05_seed42_lesionOnly.json pixdim : [0.7, 0.7, 0.7] spatial_size : [64, 128, 128] attention_unet_channels : [32, 64, 128, 256, 512] attention_unet_strides : [2, 2, 2, 2, 2] -path_to_model : /home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/results/2024-04-16_15:30:27.095475/best_model.pth/best_model.ckpt -output_dir : /home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/results/2024-04-16_15:30:27.095475/ \ No newline at end of file +path_to_model : /home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/results/2024-04-21_16:06:04.890513/best_model.pth/best_model.ckpt +output_dir : /home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/results/2024-04-21_16:06:04.890513/ \ No newline at end of file diff --git a/monai/plot_performance.py b/monai/plot_performance.py index 250b8f2..eefa4af 100644 --- a/monai/plot_performance.py +++ b/monai/plot_performance.py @@ -50,9 +50,18 @@ def main(): # Add the contrats column test_dice_results['contrast'] = test_dice_results['name'].apply(lambda x: x.split('_')[-1]) - # plot a violin plot per contrast + # Count the number of samples per contrast + contrast_counts = test_dice_results['contrast'].value_counts() + + # In the df replace the contrats by the number of samples of the contarsts( for example, T2 becomes T2 (n=10)) + test_dice_results['contrast'] = test_dice_results['contrast'].apply(lambda x: x + f' (n={contrast_counts[x]})') + + # plot a violin plot per contrast plt.figure(figsize=(20, 10)) + plt.grid(True) sns.violinplot(x='contrast', y='dice_score', data=test_dice_results) + # y ranges from -0.2 to 1.2 + plt.ylim(-0.2, 1.2) plt.title('Dice scores per contrast') plt.show() diff --git a/monai/train_monai_unet_lightning.py b/monai/train_monai_unet_lightning.py index cd38c74..606a825 100644 --- a/monai/train_monai_unet_lightning.py +++ b/monai/train_monai_unet_lightning.py @@ -145,7 +145,7 @@ def prepare_data(self): Spacingd( keys=["image", "label"], pixdim=self.cfg["pixdim"], - mode=(2, 1), + mode=(2, 0), ), # Normalize the intensity of the image NormalizeIntensityd( @@ -220,7 +220,7 @@ def prepare_data(self): # RandLambdad( # keys='image', # func=multiply_by_negative_one, - # prob=self.cfg["DA_probability"] + # prob=0.5 # ), # LabelToContourd( # keys=["image"], @@ -265,7 +265,7 @@ def prepare_data(self): Spacingd( keys=["image", "label"], pixdim=self.cfg["pixdim"], - mode=(2, 1), + mode=(2, 0), ), # This normalizes the intensity of the image NormalizeIntensityd( From 149fa3401b57dd4e32e9abe15794286ff825ae77 Mon Sep 17 00:00:00 2001 From: Pierre-Louis Benveniste Date: Tue, 4 Jun 2024 13:46:42 -0400 Subject: [PATCH 067/108] add function to not take files which are in canproco/exclude.yml --- monai/1_create_msd_data.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/monai/1_create_msd_data.py b/monai/1_create_msd_data.py index d0cc8c7..0b277ee 100644 --- a/monai/1_create_msd_data.py +++ b/monai/1_create_msd_data.py @@ -9,7 +9,7 @@ --seed: Seed for reproducibility Example: - python create_msd_data.py -pd /path/dataset -po /path/output --lesion-only --seed 42 + python create_msd_data.py -pd /path/dataset -po /path/output --lesion-only --seed 42 --canproco-exclude /path/exclude_list.txt TO DO: * @@ -46,6 +46,7 @@ def get_parser(): parser.add_argument('-pd', '--path-data', required=True, type=str, help='Path to the folder containing the datasets') parser.add_argument('-po', '--path-out', type=str, help='Path to the output directory where dataset json is saved') + parser.add_argument('--canproco-exclude', type=str, help='Path to the file containing the list of subjects to exclude from CanProCo') parser.add_argument('--lesion-only', action='store_true', help='Use only masks which contain some lesions') parser.add_argument('--seed', default=42, type=int, help="Seed for reproducibility") @@ -106,6 +107,13 @@ def main(): subjects_sct = list(sct_testing_path.rglob('*_lesion-manual.nii.gz')) subjects_bavaria = list(bavaria_path.rglob('*T2w.nii.gz')) + # Path to the file containing the list of subjects to exclude from CanProCo + if args.canproco_exclude is not None: + with open(args.canproco_exclude, 'r') as file: + canproco_exclude_list = yaml.load(file, Loader=yaml.FullLoader) + # only keep the contrast psir and stir + canproco_exclude_list = canproco_exclude_list['PSIR'] + canproco_exclude_list['STIR'] + subjects = subjects_canproco + subjects_basel + subjects_sct + subjects_bavaria # logger.info(f"Total number of subjects in the root directory: {len(subjects)}") @@ -164,6 +172,11 @@ def main(): # Canproco if 'canproco' in str(subject): + subject_id = subject.name.replace('_PSIR_lesion-manual.nii.gz', '') + subject_id = subject_id.replace('_STIR_lesion-manual.nii.gz', '') + if subject_id in canproco_exclude_list: + print(f"Excluding {subject_id}") + continue temp_data_canproco["label"] = str(subject) temp_data_canproco["image"] = str(subject).replace('_lesion-manual.nii.gz', '.nii.gz').replace('derivatives/labels/', '') if os.path.exists(temp_data_canproco["label"]) and os.path.exists(temp_data_canproco["image"]): From f28bfe475fe1167d794962045bb0bc34b71466d8 Mon Sep 17 00:00:00 2001 From: Pierre-Louis Benveniste Date: Tue, 4 Jun 2024 16:01:37 -0400 Subject: [PATCH 068/108] added lesion wide metrics --- monai/utils.py | 310 +++++++++++++++++++++++++++++++++++++------------ 1 file changed, 235 insertions(+), 75 deletions(-) diff --git a/monai/utils.py b/monai/utils.py index c102263..90348e0 100644 --- a/monai/utils.py +++ b/monai/utils.py @@ -10,7 +10,8 @@ from dynamic_network_architectures.building_blocks.helper import get_matching_instancenorm, convert_dim_to_conv_op from dynamic_network_architectures.initialization.weight_init import init_last_bn_before_add_to_0 -import skimage +from scipy import ndimage + def dice_score(prediction, groundtruth, smooth=1.): @@ -67,6 +68,155 @@ def plot_slices(image, gt, pred, debug=False): return fig +def lesion_wise_tp_fp_fn(truth, prediction): + """ + Computes the true positives, false positives, and false negatives two masks. Masks are considered true positives + if at least one voxel overlaps between the truth and the prediction. + Adapted from: https://github.com/npnl/atlas2_grand_challenge/blob/main/isles/scoring.py#L341 + + Parameters + ---------- + truth : array-like, bool + 3D array. If not boolean, will be converted. + prediction : array-like, bool + 3D array with a shape matching 'truth'. If not boolean, will be converted. + empty_value : scalar, float + Optional. Value to which to default if there are no labels. Default: 1.0. + + Returns + ------- + tp (int): 3D connected-component from the ground-truth image that overlaps at least on one voxel with the prediction image. + fp (int): 3D connected-component from the prediction image that has no voxel overlapping with the ground-truth image. + fn (int): 3d connected-component from the ground-truth image that has no voxel overlapping with the prediction image. + + Notes + ----- + This function computes lesion-wise score by defining true positive lesions (tp), false positive lesions (fp) and + false negative lesions (fn) using 3D connected-component-analysis. + + tp: 3D connected-component from the ground-truth image that overlaps at least on one voxel with the prediction image. + fp: 3D connected-component from the prediction image that has no voxel overlapping with the ground-truth image. + fn: 3d connected-component from the ground-truth image that has no voxel overlapping with the prediction image. + """ + tp, fp, fn = 0, 0, 0 + + # For each true lesion, check if there is at least one overlapping voxel. This determines true positives and + # false negatives (unpredicted lesions) + labeled_ground_truth, num_lesions = ndimage.label(truth.astype(bool)) + for idx_lesion in range(1, num_lesions+1): + lesion = labeled_ground_truth == idx_lesion + lesion_pred_sum = lesion + prediction + if(np.max(lesion_pred_sum) > 1): + tp += 1 + else: + fn += 1 + + # For each predicted lesion, check if there is at least one overlapping voxel in the ground truth. + labaled_prediction, num_pred_lesions = ndimage.label(prediction.astype(bool)) + for idx_lesion in range(1, num_pred_lesions+1): + lesion = labaled_prediction == idx_lesion + lesion_pred_sum = lesion + truth + if(np.max(lesion_pred_sum) <= 1): # No overlap + fp += 1 + + return tp, fp, fn + + +def lesion_sensitivity(truth, prediction): + """ + Computes the lesion-wise sensitivity between two masks + Returns + ------- + sensitivity (float): Lesion-wise sensitivity as float. + Max score = 1 + Min score = 0 + If both images are empty (tp + fp + fn =0) = empty_value + """ + empty_value = 1.0 # Value to which to default if there are no labels. Default: 1.0. + + if np.sum(truth) == 0 and np.sum(prediction)==0: + # Both reference and prediction are empty --> model learned correctly + return 1.0 + # if the prediction is not empty and ref is empty, it's false positive + # if both are not empty, it's true positive + else: + + tp, _, fn = lesion_wise_tp_fp_fn(truth, prediction) + sensitivity = empty_value + + # Compute sensitivity + denom = tp + fn + if(denom != 0): + sensitivity = tp / denom + return sensitivity + + +def lesion_ppv(truth, prediction): + """ + Computes the lesion-wise positive predictive value (PPV) between two masks + Returns + ------- + ppv (float): Lesion-wise positive predictive value as float. + Max score = 1 + Min score = 0 + If both images are empty (tp + fp + fn =0) = empty_value + """ + if np.sum(truth) == 0 and np.sum(prediction)==0: + # Both reference and prediction are empty --> model learned correctly + return 1.0 + elif np.sum(truth) != 0 and np.sum(prediction)==0: + # Reference is not empty, prediction is empty --> model did not learn correctly (it's false negative) + return 0.0 + # if the predction is not empty and ref is empty, it's false positive + # if both are not empty, it's true positive + else: + tp, fp, _ = lesion_wise_tp_fp_fn(truth, prediction) + # ppv = 1.0 + + # Compute ppv + denom = tp + fp + # denom should ideally not be zero inside this else as it should be caught by the empty checks above + if(denom != 0): + ppv = tp / denom + return ppv + + +def lesion_f1_score(truth, prediction): + """ + Computes the lesion-wise F1-score between two masks by defining true positive lesions (tp), false positive lesions (fp) + and false negative lesions (fn) using 3D connected-component-analysis. + + Masks are considered true positives if at least one voxel overlaps between the truth and the prediction. + + Returns + ------- + f1_score : float + Lesion-wise F1-score as float. + Max score = 1 + Min score = 0 + If both images are empty (tp + fp + fn =0) = empty_value + """ + empty_value = 1.0 # Value to which to default if there are no labels. Default: 1.0. + + if np.sum(truth) == 0 and np.sum(prediction)==0: + # Both reference and prediction are empty --> model learned correctly + return 1.0 + elif np.sum(truth) != 0 and np.sum(prediction)==0: + # Reference is not empty, prediction is empty --> model did not learn correctly (it's false negative) + return 0.0 + # if the predction is not empty and ref is empty, it's false positive + # if both are not empty, it's true positive + else: + tp, fp, fn = lesion_wise_tp_fp_fn(truth, prediction) + f1_score = empty_value + + # Compute f1_score + denom = tp + (fp + fn)/2 + if(denom != 0): + f1_score = tp / denom + return f1_score + + def lesion_wise_precision_recall(prediction, groundtruth, iou_threshold=0.1): """ This function computes the lesion-wise precision and recall. @@ -75,82 +225,92 @@ def lesion_wise_precision_recall(prediction, groundtruth, iou_threshold=0.1): prediction: predicted segmentation mask groundtruth: ground truth segmentation mask iou_threshold: threshold for intersection over union (IoU) for a lesion to be considered as true positive - Returns: - precision: lesion-wise precision - recall: lesion-wise recall - """ - prediction_cpu = prediction#.detach().numpy() - groundtruth_cpu = groundtruth#.detach().numpy() - - precision = [] - recall = [] - # print(prediction_cpu.shape) - for i in range(prediction_cpu.shape[0]): - # Compute connected components in the predicted and ground truth segmentation masks - if len(prediction_cpu.shape) == 4: - # print("iteration") - # binarize the prediction and ground truth - prediction_cpu[0] = prediction_cpu[0] > 0.2 - groundtruth_cpu[0] = groundtruth_cpu[0] > 0.2 - # compute connected components - pred_labels, num_components_pred = skimage.measure.label(prediction_cpu[0], connectivity=2, return_num=True) - gt_labels, num_components_gt = skimage.measure.label(groundtruth_cpu[0], connectivity=2, return_num=True) - # print('c', pred_labels.shape) - # print('d', gt_labels.shape) - if len(prediction_cpu.shape) == 5: - # binarize the prediction and ground truth - prediction_cpu[i][0] = prediction_cpu[i][0] > 0.2 - groundtruth_cpu[i][0] = groundtruth_cpu[i][0] > 0.2 - # compute connected components - pred_labels, num_components_pred = skimage.measure.label(prediction_cpu[i][0], connectivity=2, return_num=True) - gt_labels, num_components_gt = skimage.measure.label(groundtruth_cpu[i][0], connectivity=2, return_num=True) - # print('e', pred_labels.shape) - # print('f', gt_labels.shape) + + +# def lesion_wise_precision_recall(prediction, groundtruth, iou_threshold=0.1): +# """ +# This function computes the lesion-wise precision and recall. + +# Args: +# prediction: predicted segmentation mask +# groundtruth: ground truth segmentation mask +# iou_threshold: threshold for intersection over union (IoU) for a lesion to be considered as true positive +# Returns: +# precision: lesion-wise precision +# recall: lesion-wise recall +# """ +# prediction_cpu = prediction#.detach().numpy() +# groundtruth_cpu = groundtruth#.detach().numpy() + +# precision = [] +# recall = [] +# # print(prediction_cpu.shape) +# for i in range(prediction_cpu.shape[0]): +# # Compute connected components in the predicted and ground truth segmentation masks +# if len(prediction_cpu.shape) == 4: +# # print("iteration") +# # binarize the prediction and ground truth +# prediction_cpu[0] = prediction_cpu[0] > 0.2 +# groundtruth_cpu[0] = groundtruth_cpu[0] > 0.2 +# # compute connected components +# pred_labels, num_components_pred = skimage.measure.label(prediction_cpu[0], connectivity=2, return_num=True) +# gt_labels, num_components_gt = skimage.measure.label(groundtruth_cpu[0], connectivity=2, return_num=True) +# # print('c', pred_labels.shape) +# # print('d', gt_labels.shape) +# if len(prediction_cpu.shape) == 5: +# # binarize the prediction and ground truth +# prediction_cpu[i][0] = prediction_cpu[i][0] > 0.2 +# groundtruth_cpu[i][0] = groundtruth_cpu[i][0] > 0.2 +# # compute connected components +# pred_labels, num_components_pred = skimage.measure.label(prediction_cpu[i][0], connectivity=2, return_num=True) +# gt_labels, num_components_gt = skimage.measure.label(groundtruth_cpu[i][0], connectivity=2, return_num=True) +# # print('e', pred_labels.shape) +# # print('f', gt_labels.shape) - # If there are no connected components in the predicted or ground truth segmentation masks we return 0 and continue - if num_components_gt==0 or num_components_pred==0: - precision+= [0] - recall+= [0] - continue - - # Compute the intersection over union (IoU) between each pair of connected components - iou_matrix = np.zeros((np.max(pred_labels), np.max(gt_labels))) - intersection_matrix = np.zeros((np.max(pred_labels), np.max(gt_labels))) - for i in range(np.max(pred_labels)): - for j in range(np.max(gt_labels)): - # Compute the intersection - intersection = np.sum((pred_labels == i + 1) * (gt_labels == j + 1)) - # Compute the union - union = np.sum((pred_labels == i + 1)) + np.sum((gt_labels == j + 1)) - intersection - # Compute the IoU - iou_matrix[i, j] = intersection / union - # if iou_matrix[i, j] > 0: - # print("iou_matrix", iou_matrix[i, j]) - # Compute the intersection - intersection_matrix[i, j] = intersection +# # If there are no connected components in the predicted or ground truth segmentation masks we return 0 and continue +# if num_components_gt==0 or num_components_pred==0: +# precision+= [0] +# recall+= [0] +# continue + +# # Compute the intersection over union (IoU) between each pair of connected components +# iou_matrix = np.zeros((np.max(pred_labels), np.max(gt_labels))) +# intersection_matrix = np.zeros((np.max(pred_labels), np.max(gt_labels))) +# for i in range(np.max(pred_labels)): +# for j in range(np.max(gt_labels)): +# # Compute the intersection +# intersection = np.sum((pred_labels == i + 1) * (gt_labels == j + 1)) +# # Compute the union +# union = np.sum((pred_labels == i + 1)) + np.sum((gt_labels == j + 1)) - intersection +# # Compute the IoU +# iou_matrix[i, j] = intersection / union +# # if iou_matrix[i, j] > 0: +# # print("iou_matrix", iou_matrix[i, j]) +# # Compute the intersection +# intersection_matrix[i, j] = intersection - # # Compute lesion-wise precision and recall - # true_positives = np.sum(np.max(iou_matrix, axis=1) > iou_threshold) - # false_positives = np.sum(np.max(iou_matrix, axis=0) <= iou_threshold) - # false_negatives = np.sum(np.max(iou_matrix, axis=1) <= iou_threshold) - # precision += [true_positives / (true_positives + false_positives)] - # recall+= [true_positives / (true_positives + false_negatives)] - - # Compute lesion-wise precision and recall - true_positives = np.sum(np.max(intersection_matrix, axis=1) > iou_threshold) - false_positives = np.sum(np.max(intersection_matrix, axis=0) <= iou_threshold) - false_negatives = np.sum(np.max(intersection_matrix, axis=1) <= iou_threshold) - precision += [true_positives / (true_positives + false_positives)] - recall+= [true_positives / (true_positives + false_negatives)] - - - # Put it back in cuda - precision = torch.tensor(precision).cuda() - recall = torch.tensor(recall).cuda() - - print("precision", precision) - print("recall", recall) - return precision, recall +# # # Compute lesion-wise precision and recall +# # true_positives = np.sum(np.max(iou_matrix, axis=1) > iou_threshold) +# # false_positives = np.sum(np.max(iou_matrix, axis=0) <= iou_threshold) +# # false_negatives = np.sum(np.max(iou_matrix, axis=1) <= iou_threshold) +# # precision += [true_positives / (true_positives + false_positives)] +# # recall+= [true_positives / (true_positives + false_negatives)] + +# # Compute lesion-wise precision and recall +# true_positives = np.sum(np.max(intersection_matrix, axis=1) > iou_threshold) +# false_positives = np.sum(np.max(intersection_matrix, axis=0) <= iou_threshold) +# false_negatives = np.sum(np.max(intersection_matrix, axis=1) <= iou_threshold) +# precision += [true_positives / (true_positives + false_positives)] +# recall+= [true_positives / (true_positives + false_negatives)] + + +# # Put it back in cuda +# precision = torch.tensor(precision).cuda() +# recall = torch.tensor(recall).cuda() + +# print("precision", precision) +# print("recall", recall) +# return precision, recall # ############################################################################################################ From 9ac26f3638e25256e6d6162fce3550b47f720d61 Mon Sep 17 00:00:00 2001 From: Pierre-Louis Benveniste Date: Tue, 4 Jun 2024 16:01:58 -0400 Subject: [PATCH 069/108] changed for bavaria dataset new format --- monai/1_create_msd_data.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/monai/1_create_msd_data.py b/monai/1_create_msd_data.py index 0b277ee..445645c 100644 --- a/monai/1_create_msd_data.py +++ b/monai/1_create_msd_data.py @@ -9,7 +9,7 @@ --seed: Seed for reproducibility Example: - python create_msd_data.py -pd /path/dataset -po /path/output --lesion-only --seed 42 --canproco-exclude /path/exclude_list.txt + python 1_create_msd_data.py -pd /path/dataset -po /path/output --lesion-only --seed 42 --canproco-exclude /path/exclude_list.txt TO DO: * @@ -105,7 +105,7 @@ def main(): subjects_canproco = list(canproco_path.rglob('*_lesion-manual.nii.gz')) subjects_basel = list(basel_path.rglob('*UNIT1.nii.gz')) subjects_sct = list(sct_testing_path.rglob('*_lesion-manual.nii.gz')) - subjects_bavaria = list(bavaria_path.rglob('*T2w.nii.gz')) + subjects_bavaria = list(bavaria_path.rglob('*_lesion-manual.nii.gz')) # Path to the file containing the list of subjects to exclude from CanProCo if args.canproco_exclude is not None: @@ -175,7 +175,6 @@ def main(): subject_id = subject.name.replace('_PSIR_lesion-manual.nii.gz', '') subject_id = subject_id.replace('_STIR_lesion-manual.nii.gz', '') if subject_id in canproco_exclude_list: - print(f"Excluding {subject_id}") continue temp_data_canproco["label"] = str(subject) temp_data_canproco["image"] = str(subject).replace('_lesion-manual.nii.gz', '.nii.gz').replace('derivatives/labels/', '') @@ -215,9 +214,8 @@ def main(): # Bavaria-quebec elif 'bavaria-quebec-spine-ms' in str(subject): - relative_path = subject.relative_to(bavaria_path).parent - temp_data_bavaria["image"] = str(subject) - temp_data_bavaria["label"] = str(bavaria_path) + '/derivatives/labels/' + str(relative_path) + '/' +str(subject.name).replace('T2w.nii.gz', 'lesions-manual_T2w.nii.gz') + temp_data_bavaria["label"] = str(subject) + temp_data_bavaria["image"] = str(subject).replace('_lesion-manual.nii.gz', '.nii.gz').replace('derivatives/labels/', '') if os.path.exists(temp_data_bavaria["label"]) and os.path.exists(temp_data_bavaria["image"]): total_lesion_volume, nb_lesions = count_lesion(temp_data_bavaria["label"]) temp_data_bavaria["total_lesion_volume"] = total_lesion_volume From dadc8acb259beafbbf40a94be0f9c6f458f9ed94 Mon Sep 17 00:00:00 2001 From: Pierre-Louis Benveniste Date: Tue, 4 Jun 2024 17:14:32 -0400 Subject: [PATCH 070/108] added function to remove small objects in utils --- monai/utils.py | 27 ++++++++++++++++++++++----- 1 file changed, 22 insertions(+), 5 deletions(-) diff --git a/monai/utils.py b/monai/utils.py index 90348e0..630b6ca 100644 --- a/monai/utils.py +++ b/monai/utils.py @@ -217,14 +217,31 @@ def lesion_f1_score(truth, prediction): return f1_score -def lesion_wise_precision_recall(prediction, groundtruth, iou_threshold=0.1): +def remove_small_lesions(lesion_seg, resolution, min_volume=50): """ - This function computes the lesion-wise precision and recall. + Remove lesions which are smaller than a given volume threshold. Args: - prediction: predicted segmentation mask - groundtruth: ground truth segmentation mask - iou_threshold: threshold for intersection over union (IoU) for a lesion to be considered as true positive + predictions (ndarray or nibabel object): Input segmentation. Image could be 2D or 3D. + resolution (list): Resolution of the image (Example: [1, 1, 1]) in mm + min_volume (float): Minimum volume of the lesion to be kept. in mm3 (Default is 5 voxels in canproco = 5*0.7*0.7*3=7.35 ) + + Returns: + ndarray or nibabel (same object as the input). + """ + # Find number of closed objects using skimage "label" + labeled_obj, num_obj = ndimage.label(np.copy(lesion_seg)) + # Compute the volume of each object + obj_volume = np.zeros(num_obj) + for i in range(num_obj): + obj_volume[i] = np.sum(labeled_obj == i+1)*np.prod(resolution) + # Remove objects with volume less than min_volume + lesion_seg = np.copy(lesion_seg) + for i in range(num_obj): + if obj_volume[i] < min_volume: + lesion_seg[labeled_obj == i+1] = 0 + labeled_obj, num_obj = ndimage.label(lesion_seg) + return lesion_seg # def lesion_wise_precision_recall(prediction, groundtruth, iou_threshold=0.1): From 3965b9b06fefed449dbc689c1dc3d06baf40a437 Mon Sep 17 00:00:00 2001 From: Pierre-Louis Benveniste Date: Tue, 4 Jun 2024 17:29:33 -0400 Subject: [PATCH 071/108] added remove small objects for train, val and inference --- monai/train_monai_unet_lightning.py | 57 +++++++++++++++++++---------- 1 file changed, 37 insertions(+), 20 deletions(-) diff --git a/monai/train_monai_unet_lightning.py b/monai/train_monai_unet_lightning.py index 606a825..4f87721 100644 --- a/monai/train_monai_unet_lightning.py +++ b/monai/train_monai_unet_lightning.py @@ -20,7 +20,7 @@ from losses import AdapWingLoss, SoftDiceLoss -from utils import dice_score, check_empty_patch, multiply_by_negative_one, plot_slices, lesion_wise_precision_recall +from utils import dice_score, check_empty_patch, multiply_by_negative_one, plot_slices, remove_small_lesions from monai.networks.nets import UNet, BasicUNet, AttentionUnet, SwinUNETR from monai.metrics import DiceMetric from monai.losses import DiceLoss, DiceCELoss @@ -105,7 +105,7 @@ def __init__(self, config, data_root, net, loss_function, optimizer_class, exp_i # define evaluation metric self.soft_dice_metric = dice_score - self.lesion_wise_precision_recall = lesion_wise_precision_recall + # self.lesion_wise_precision_recall = lesion_wise_precision_recall # temp lists for storing outputs from training, validation, and testing self.train_step_outputs = [] @@ -160,18 +160,18 @@ def prepare_data(self): # source_key="label", # margin=200 # ), - # This crops the image around a foreground object of label with ratio pos/(pos+neg) (however, it cannot pad so keeping padding after) - RandCropByPosNegLabeld( - keys=["image", "label"], - label_key="label", - spatial_size=self.cfg["spatial_size"], - pos=1, - neg=0, - num_samples=4, - image_key="image", - image_threshold=0, - allow_smaller=True, - ), + # # This crops the image around a foreground object of label with ratio pos/(pos+neg) (however, it cannot pad so keeping padding after) + # RandCropByPosNegLabeld( + # keys=["image", "label"], + # label_key="label", + # spatial_size=self.cfg["spatial_size"], + # pos=1, + # neg=0, + # num_samples=4, + # image_key="image", + # image_threshold=0, + # allow_smaller=True, + # ), # This resizes the image and the label to the spatial size defined in the config ResizeWithPadOrCropd( keys=["image", "label"], @@ -254,7 +254,13 @@ def prepare_data(self): # num_classes=2, # threshold_values=True, # logit_thresh=0.2, - # ) + # ), + # Remove small lesions in the label + RandLambdad( + keys='label', + func=lambda label: remove_small_lesions(label, self.cfg["pixdim"]), + prob=1.0 + ) ] ) val_transforms = Compose( @@ -303,6 +309,12 @@ def prepare_data(self): # threshold_values=True, # logit_thresh=0.2, # ) + # Remove small lesions in the label + RandLambdad( + keys='label', + func=lambda label: remove_small_lesions(label, self.cfg["pixdim"]), + prob=1.0 + ) ] ) @@ -328,6 +340,12 @@ def prepare_data(self): orig_keys=["image", "label"], meta_keys=["pred_meta_dict", "label_meta_dict"], nearest_interp=False, to_tensor=True), + # Remove small lesions in the label + RandLambdad( + keys='label', + func=lambda label: remove_small_lesions(label, self.cfg["pixdim"]), + prob=1.0 + ) ]) self.test_ds = CacheDataset(data=test_files, transform=transforms_test, cache_rate=0.1, num_workers=4) @@ -381,11 +399,10 @@ def training_step(self, batch, batch_idx): # label_nifti = nib.Nifti1Image(label_0.numpy(), affine=np.eye(4)) # nib.save(label_nifti, f"~/ms_lesion_agnostic/temp/label_0_{time_0}.nii.gz") - - # # check if any label image patch is empty in the batch - if check_empty_patch(labels) is None: - print(f"Empty label patch found. Skipping training step ...") - return None + # # # check if any label image patch is empty in the batch + # if check_empty_patch(labels) is None: + # print(f"Empty label patch found. Skipping training step ...") + # return None output = self.forward(inputs) # logits # print(f"labels.shape: {labels.shape} \t output.shape: {output.shape}") From 71faa5219c8b661a73c71a85d1f6e59d74dae44b Mon Sep 17 00:00:00 2001 From: Pierre-Louis Benveniste Date: Tue, 4 Jun 2024 17:32:54 -0400 Subject: [PATCH 072/108] changed the min volume threshold --- monai/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/utils.py b/monai/utils.py index 630b6ca..0b7a841 100644 --- a/monai/utils.py +++ b/monai/utils.py @@ -217,7 +217,7 @@ def lesion_f1_score(truth, prediction): return f1_score -def remove_small_lesions(lesion_seg, resolution, min_volume=50): +def remove_small_lesions(lesion_seg, resolution, min_volume=7.5): """ Remove lesions which are smaller than a given volume threshold. From baaa98073429d12da3f1853da0cf1c60cfb5f5c0 Mon Sep 17 00:00:00 2001 From: Pierre-Louis Benveniste Date: Wed, 26 Jun 2024 10:50:03 -0400 Subject: [PATCH 073/108] changed msd dataset creation for nih and updated bavaria unstiched data --- monai/1_create_msd_data.py | 124 +++++++++++++++++++++---------------- 1 file changed, 70 insertions(+), 54 deletions(-) diff --git a/monai/1_create_msd_data.py b/monai/1_create_msd_data.py index 445645c..6b2cee8 100644 --- a/monai/1_create_msd_data.py +++ b/monai/1_create_msd_data.py @@ -7,6 +7,7 @@ -po, --path-out: Path to the output directory where dataset json is saved --lesion-only: Use only masks which contain some lesions --seed: Seed for reproducibility + --canproco-exclude: Path to the file containing the list of subjects to exclude from CanProCo Example: python 1_create_msd_data.py -pd /path/dataset -po /path/output --lesion-only --seed 42 --canproco-exclude /path/exclude_list.txt @@ -97,15 +98,17 @@ def main(): seed = args.seed # Get all subjects - canproco_path = Path(os.path.join(root, "canproco")) basel_path = Path(os.path.join(root, "basel-mp2rage")) - bavaria_path = Path(os.path.join(root, "bavaria-quebec-spine-ms")) + bavaria_path = Path(os.path.join(root, "bavaria-quebec-spine-ms-unstitched")) + canproco_path = Path(os.path.join(root, "canproco")) + nih_path = Path(os.path.join(root, "nih-ms-mp2rage")) sct_testing_path = Path(os.path.join(root, "sct-testing-large")) - subjects_canproco = list(canproco_path.rglob('*_lesion-manual.nii.gz')) - subjects_basel = list(basel_path.rglob('*UNIT1.nii.gz')) - subjects_sct = list(sct_testing_path.rglob('*_lesion-manual.nii.gz')) - subjects_bavaria = list(bavaria_path.rglob('*_lesion-manual.nii.gz')) + derivatives_basel = list(basel_path.rglob('*_desc-rater3_label-lesion_seg.nii.gz')) + derivatives_bavaria = list(bavaria_path.rglob('*_lesion-manual.nii.gz')) + derivatives_canproco = list(canproco_path.rglob('*_lesion-manual.nii.gz')) + derivatives_nih = list(nih_path.rglob('*_label-lesion_seg.nii.gz')) + derivatives_sct = list(sct_testing_path.rglob('*_lesion-manual.nii.gz')) # Path to the file containing the list of subjects to exclude from CanProCo if args.canproco_exclude is not None: @@ -114,19 +117,19 @@ def main(): # only keep the contrast psir and stir canproco_exclude_list = canproco_exclude_list['PSIR'] + canproco_exclude_list['STIR'] - subjects = subjects_canproco + subjects_basel + subjects_sct + subjects_bavaria - # logger.info(f"Total number of subjects in the root directory: {len(subjects)}") + derivatives = derivatives_basel + derivatives_bavaria + derivatives_canproco + derivatives_nih + derivatives_sct + logger.info(f"Total number of derivatives in the root directory: {len(derivatives)}") # create one json file with 60-20-20 train-val-test split train_ratio, val_ratio, test_ratio = 0.6, 0.2, 0.2 - train_subjects, test_subjects = train_test_split(subjects, test_size=test_ratio, random_state=args.seed) + train_derivatives, test_derivatives = train_test_split(derivatives, test_size=test_ratio, random_state=args.seed) # Use the training split to further split into training and validation splits - train_subjects, val_subjects = train_test_split(train_subjects, test_size=val_ratio / (train_ratio + val_ratio), + train_derivatives, val_derivatives = train_test_split(train_derivatives, test_size=val_ratio / (train_ratio + val_ratio), random_state=args.seed, ) # sort the subjects - train_subjects = sorted(train_subjects) - val_subjects = sorted(val_subjects) - test_subjects = sorted(test_subjects) + train_derivatives = sorted(train_derivatives) + val_derivatives = sorted(val_derivatives) + test_derivatives = sorted(test_derivatives) # logger.info(f"Number of training subjects: {len(train_subjects)}") # logger.info(f"Number of validation subjects: {len(val_subjects)}") @@ -134,7 +137,7 @@ def main(): # dump train/val/test splits into a yaml file with open(f"{args.path_out}/data_split_{str(date.today())}_seed{seed}.yaml", 'w') as file: - yaml.dump({'train': train_subjects, 'val': val_subjects, 'test': test_subjects}, file, indent=2, sort_keys=True) + yaml.dump({'train': train_derivatives, 'val': val_derivatives, 'test': test_derivatives}, file, indent=2, sort_keys=True) # keys to be defined in the dataset_0.json params = {} @@ -152,32 +155,59 @@ def main(): params["reference"] = "NeuroPoly" params["tensorImageSize"] = "3D" - train_subjects_dict = {"train": train_subjects} - val_subjects_dict = {"validation": val_subjects} - test_subjects_dict = {"test": test_subjects} - all_subjects_list = [train_subjects_dict, val_subjects_dict, test_subjects_dict] + train_derivatives_dict = {"train": train_derivatives} + val_derivatives_dict = {"validation": val_derivatives} + test_derivatives_dict = {"test": test_derivatives} + all_derivatives_list = [train_derivatives_dict, val_derivatives_dict, test_derivatives_dict] # iterate through the train/val/test splits and add those which have both image and label - for subjects_dict in tqdm(all_subjects_list, desc="Iterating through train/val/test splits"): + for derivatives_dict in tqdm(all_derivatives_list, desc="Iterating through train/val/test splits"): - for name, subs_list in subjects_dict.items(): + for name, derivs_list in derivatives_dict.items(): temp_list = [] - for subject_no, subject in enumerate(subs_list): + for subject_no, derivative in enumerate(derivs_list): - temp_data_canproco = {} + temp_data_basel = {} - temp_data_sct = {} temp_data_bavaria = {} + temp_data_canproco = {} + temp_data_nih = {} + temp_data_sct = {} + + # Basel + if 'basel-mp2rage' in str(derivative): + relative_path = derivative.relative_to(basel_path).parent + temp_data_basel["label"] = str(derivative) + temp_data_basel["image"] = str(derivative).replace('_desc-rater3_label-lesion_seg.nii.gz', '.nii.gz').replace('derivatives/labels/', '') + if os.path.exists(temp_data_basel["label"]) and os.path.exists(temp_data_basel["image"]): + total_lesion_volume, nb_lesions = count_lesion(temp_data_basel["label"]) + temp_data_basel["total_lesion_volume"] = total_lesion_volume + temp_data_basel["nb_lesions"] = nb_lesions + if args.lesion_only and nb_lesions == 0: + continue + temp_list.append(temp_data_basel) + # Bavaria-quebec + elif 'bavaria-quebec-spine-ms' in str(derivative): + temp_data_bavaria["label"] = str(derivative) + temp_data_bavaria["image"] = str(derivative).replace('_lesion-manual.nii.gz', '.nii.gz').replace('derivatives/labels/', '') + if os.path.exists(temp_data_bavaria["label"]) and os.path.exists(temp_data_bavaria["image"]): + total_lesion_volume, nb_lesions = count_lesion(temp_data_bavaria["label"]) + temp_data_bavaria["total_lesion_volume"] = total_lesion_volume + temp_data_bavaria["nb_lesions"] = nb_lesions + if args.lesion_only and nb_lesions == 0: + continue + temp_list.append(temp_data_bavaria) + # Canproco - if 'canproco' in str(subject): - subject_id = subject.name.replace('_PSIR_lesion-manual.nii.gz', '') + elif 'canproco' in str(derivative): + subject_id = derivative.name.replace('_PSIR_lesion-manual.nii.gz', '') subject_id = subject_id.replace('_STIR_lesion-manual.nii.gz', '') if subject_id in canproco_exclude_list: continue - temp_data_canproco["label"] = str(subject) - temp_data_canproco["image"] = str(subject).replace('_lesion-manual.nii.gz', '.nii.gz').replace('derivatives/labels/', '') + temp_data_canproco["label"] = str(derivative) + temp_data_canproco["image"] = str(derivative).replace('_lesion-manual.nii.gz', '.nii.gz').replace('derivatives/labels/', '') if os.path.exists(temp_data_canproco["label"]) and os.path.exists(temp_data_canproco["image"]): total_lesion_volume, nb_lesions = count_lesion(temp_data_canproco["label"]) temp_data_canproco["total_lesion_volume"] = total_lesion_volume @@ -185,24 +215,23 @@ def main(): if args.lesion_only and nb_lesions == 0: continue temp_list.append(temp_data_canproco) - - # Basel - elif 'basel-mp2rage' in str(subject): - relative_path = subject.relative_to(basel_path).parent - temp_data_basel["image"] = str(subject) - temp_data_basel["label"] = str(basel_path) + '/derivatives/labels/' + str(relative_path) +'/'+ str(subject.name).replace('UNIT1.nii.gz', 'UNIT1_desc-rater3_label-lesion_seg.nii.gz') - if os.path.exists(temp_data_basel["label"]) and os.path.exists(temp_data_basel["image"]): - total_lesion_volume, nb_lesions = count_lesion(temp_data_basel["label"]) - temp_data_basel["total_lesion_volume"] = total_lesion_volume - temp_data_basel["nb_lesions"] = nb_lesions + + # nih-ms-mp2rage + elif 'nih-ms-mp2rage' in str(derivative): + temp_data_nih["label"] = str(derivative) + temp_data_nih["image"] = str(derivative).replace('_label-lesion_seg.nii.gz', '.nii.gz').replace('derivatives/labels/', '') + if os.path.exists(temp_data_nih["label"]) and os.path.exists(temp_data_nih["image"]): + total_lesion_volume, nb_lesions = count_lesion(temp_data_nih["label"]) + temp_data_nih["total_lesion_volume"] = total_lesion_volume + temp_data_nih["nb_lesions"] = nb_lesions if args.lesion_only and nb_lesions == 0: continue - temp_list.append(temp_data_basel) + temp_list.append(temp_data_nih) # sct-testing-large - elif 'sct-testing-large' in str(subject): - temp_data_sct["label"] = str(subject) - temp_data_sct["image"] = str(subject).replace('_lesion-manual.nii.gz', '.nii.gz').replace('derivatives/labels/', '') + elif 'sct-testing-large' in str(derivative): + temp_data_sct["label"] = str(derivative) + temp_data_sct["image"] = str(derivative).replace('_lesion-manual.nii.gz', '.nii.gz').replace('derivatives/labels/', '') if os.path.exists(temp_data_sct["label"]) and os.path.exists(temp_data_sct["image"]): total_lesion_volume, nb_lesions = count_lesion(temp_data_sct["label"]) temp_data_sct["total_lesion_volume"] = total_lesion_volume @@ -211,19 +240,6 @@ def main(): continue temp_list.append(temp_data_sct) - - # Bavaria-quebec - elif 'bavaria-quebec-spine-ms' in str(subject): - temp_data_bavaria["label"] = str(subject) - temp_data_bavaria["image"] = str(subject).replace('_lesion-manual.nii.gz', '.nii.gz').replace('derivatives/labels/', '') - if os.path.exists(temp_data_bavaria["label"]) and os.path.exists(temp_data_bavaria["image"]): - total_lesion_volume, nb_lesions = count_lesion(temp_data_bavaria["label"]) - temp_data_bavaria["total_lesion_volume"] = total_lesion_volume - temp_data_bavaria["nb_lesions"] = nb_lesions - if args.lesion_only and nb_lesions == 0: - continue - temp_list.append(temp_data_bavaria) - params[name] = temp_list logger.info(f"Number of images in {name} set: {len(temp_list)}") params["numTest"] = len(params["test"]) From 500de22e13252d29a600bdea5dfb58c01f5518a6 Mon Sep 17 00:00:00 2001 From: Pierre-Louis Benveniste Date: Wed, 26 Jun 2024 15:08:45 -0400 Subject: [PATCH 074/108] updated requirements and added loguru --- monai/requirements.txt | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/monai/requirements.txt b/monai/requirements.txt index 7618f5c..ef7795d 100644 --- a/monai/requirements.txt +++ b/monai/requirements.txt @@ -4,4 +4,5 @@ torch torchvision matplotlib pytorch_lightning -cupy-cuda117==10.6.0 \ No newline at end of file +cupy-cuda117==10.6.0 +loguru \ No newline at end of file From 8096913dda82b491eda8bf32c791b7c5fba850a4 Mon Sep 17 00:00:00 2001 From: Pierre-Louis Benveniste Date: Wed, 26 Jun 2024 16:54:44 -0400 Subject: [PATCH 075/108] corrected lesion mask name for nih --- monai/1_create_msd_data.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/1_create_msd_data.py b/monai/1_create_msd_data.py index 6b2cee8..204b6fe 100644 --- a/monai/1_create_msd_data.py +++ b/monai/1_create_msd_data.py @@ -107,7 +107,7 @@ def main(): derivatives_basel = list(basel_path.rglob('*_desc-rater3_label-lesion_seg.nii.gz')) derivatives_bavaria = list(bavaria_path.rglob('*_lesion-manual.nii.gz')) derivatives_canproco = list(canproco_path.rglob('*_lesion-manual.nii.gz')) - derivatives_nih = list(nih_path.rglob('*_label-lesion_seg.nii.gz')) + derivatives_nih = list(nih_path.rglob('*_desc-rater1_label-lesion_seg.nii.gz')) derivatives_sct = list(sct_testing_path.rglob('*_lesion-manual.nii.gz')) # Path to the file containing the list of subjects to exclude from CanProCo @@ -219,7 +219,7 @@ def main(): # nih-ms-mp2rage elif 'nih-ms-mp2rage' in str(derivative): temp_data_nih["label"] = str(derivative) - temp_data_nih["image"] = str(derivative).replace('_label-lesion_seg.nii.gz', '.nii.gz').replace('derivatives/labels/', '') + temp_data_nih["image"] = str(derivative).replace('_desc-rater1_label-lesion_seg.nii.gz', '.nii.gz').replace('derivatives/labels/', '') if os.path.exists(temp_data_nih["label"]) and os.path.exists(temp_data_nih["image"]): total_lesion_volume, nb_lesions = count_lesion(temp_data_nih["label"]) temp_data_nih["total_lesion_volume"] = total_lesion_volume From 3c7e488b3c966bad98a5b9a67ffaa9614864f393 Mon Sep 17 00:00:00 2001 From: Pierre-Louis Benveniste Date: Wed, 26 Jun 2024 16:55:16 -0400 Subject: [PATCH 076/108] corrected requirements --- monai/requirements.txt | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/monai/requirements.txt b/monai/requirements.txt index ef7795d..8a9c2f5 100644 --- a/monai/requirements.txt +++ b/monai/requirements.txt @@ -1,8 +1,11 @@ -tqdm -monai[all] -torch -torchvision -matplotlib -pytorch_lightning +numpy==1.24.3 +tqdm==4.65.0 +torch==2.0.1 +torchvision==0.15.2 +monai[all]==1.3.0 +matplotlib==3.8.2 +pytorch-lightning==2.2.1 cupy-cuda117==10.6.0 -loguru \ No newline at end of file +loguru==0.7.2 +wandb==0.15.12 +dynamic-network-architectures==0.2 \ No newline at end of file From 7cea4c83abac03de8d5974e8edd756ba071504b4 Mon Sep 17 00:00:00 2001 From: Pierre-Louis Benveniste Date: Mon, 15 Jul 2024 10:52:36 -0400 Subject: [PATCH 077/108] config file for training on ETS server --- monai/config.yml | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/monai/config.yml b/monai/config.yml index 0668f5f..36b7f3e 100644 --- a/monai/config.yml +++ b/monai/config.yml @@ -6,8 +6,9 @@ # data: /home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/data/monai_data/fake_10_each.json # data: /home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/data/monai_data/fake_sc.json # data: /home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/data/monai_data/dataset_2024-03-13_seed42_canproco.json -data: /home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/data/monai_data/dataset_2024-04-05_seed42_lesionOnly.json +# data: /home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/data/monai_data/dataset_2024-04-05_seed42_lesionOnly.json # data: /home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/data/monai_data/dataset_2024-04-17_seed42_lesionOnly.json +data: /home/plbenveniste/net/ms-lesion-agnostic/msd_data/dataset_2024-06-26_seed42_lesionOnly.json # Resampling resolution # pixdim : [1.0, 1.0, 1.0] @@ -31,7 +32,8 @@ max_iterations : 250 eval_num : 2 # Outputs -output_path : /home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/results/ +# output_path : /home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/results/ +output_path : /home/plbenveniste/net/ms-lesion-agnostic/results/ # Seed seed : 42 From 5fa1ac1a2799d844d03c502e7b70a8855334ffb4 Mon Sep 17 00:00:00 2001 From: Pierre-Louis Benveniste Date: Tue, 16 Jul 2024 10:17:30 -0400 Subject: [PATCH 078/108] added nnUNet data augmentation --- monai/train_monai_unet_lightning.py | 263 +++++++++++++++------------- 1 file changed, 142 insertions(+), 121 deletions(-) diff --git a/monai/train_monai_unet_lightning.py b/monai/train_monai_unet_lightning.py index 4f87721..bda010d 100644 --- a/monai/train_monai_unet_lightning.py +++ b/monai/train_monai_unet_lightning.py @@ -52,7 +52,11 @@ Rand3DElasticd, RandSimulateLowResolutiond, RandBiasFieldd, - RandAffined + RandAffined, + RandRotated, + RandZoomd, + RandGaussianSmoothd, + RandScaleIntensityd ) from monai.utils import set_determinism from monai.inferers import sliding_window_inference @@ -138,10 +142,11 @@ def prepare_data(self): # define training and validation transforms train_transforms = Compose( - [ + [ LoadImaged(keys=["image", "label"], reader="NibabelReader"), EnsureChannelFirstd(keys=["image", "label"]), Orientationd(keys=["image", "label"], axcodes="RPI"), + # This changes the spacing of the image Spacingd( keys=["image", "label"], pixdim=self.cfg["pixdim"], @@ -153,47 +158,21 @@ def prepare_data(self): nonzero=False, channel_wise=False ), - # # This crops the image around areas where the mask is non-zero - # # (the margin is added because otherwise the image would be just the size of the lesion) - # CropForegroundd( - # keys=["image", "label"], - # source_key="label", - # margin=200 - # ), - # # This crops the image around a foreground object of label with ratio pos/(pos+neg) (however, it cannot pad so keeping padding after) - # RandCropByPosNegLabeld( - # keys=["image", "label"], - # label_key="label", - # spatial_size=self.cfg["spatial_size"], - # pos=1, - # neg=0, - # num_samples=4, - # image_key="image", - # image_threshold=0, - # allow_smaller=True, - # ), # This resizes the image and the label to the spatial size defined in the config ResizeWithPadOrCropd( keys=["image", "label"], spatial_size=self.cfg["spatial_size"], ), - # Flips the image : left becomes right - RandFlipd( - keys=["image", "label"], - spatial_axis=[0], - prob=self.cfg["DA_probability"], - ), - # Flips the image : supperior becomes inferior - RandFlipd( - keys=["image", "label"], - spatial_axis=[1], - prob=self.cfg["DA_probability"], - ), - # Flips the image : anterior becomes posterior - RandFlipd( + # Spatial transforms + # Random rotation of the image + RandRotated( keys=["image", "label"], - spatial_axis=[2], + range_x=np.pi, + range_y=np.pi, + range_z=np.pi, prob=self.cfg["DA_probability"], + keep_size=True, + mode=('bilinear', 'nearest'), ), # Random elastic deformation Rand3DElasticd( @@ -203,6 +182,15 @@ def prepare_data(self): prob=self.cfg["DA_probability"], mode=['bilinear', 'nearest'], ), + # Changes the spacing of the image + RandZoomd( + keys=["image", "label"], + prob=self.cfg["DA_probability"], + min_zoom=0.75, + max_zoom=1.25, + mode=('bilinear', 'nearest'), + keep_size=True, + ), # Random affine transform of the image RandAffined( keys=["image", "label"], @@ -210,31 +198,49 @@ def prepare_data(self): mode=('bilinear', 'nearest'), padding_mode='zeros', ), - # RandAdjustContrastd( - # keys=["image"], - # prob=self.cfg["DA_probability"], - # gamma=(0.5, 4.5), - # invert_image=True, - # ), - # # we add the multiplication of the image by -1 - # RandLambdad( - # keys='image', - # func=multiply_by_negative_one, - # prob=0.5 - # ), - # LabelToContourd( - # keys=["image"], - # kernel_type='Laplace', - # ), + # Intensity transforms + # Random Gaussian noise is added to the image RandGaussianNoised( keys=["image"], prob=self.cfg["DA_probability"], + mean=0.0, + std=0.1, + ), + # Gaussian blur with RandGaussianSmoothd + RandGaussianSmoothd( + keys=["image"], + prob=self.cfg["DA_probability"], + sigma_x=(0.5, 1.), + sigma_y=(0.5, 1.), + sigma_z=(0.5, 1.), + ), + # Brightness transform: with RandScaleIntensityd + RandScaleIntensityd( + keys=["image"], + prob=self.cfg["DA_probability"], + factors=0.25, + ), + # Contrast transform: with RandAdjustContrastd + RandAdjustContrastd( + keys=["image"], + prob=self.cfg["DA_probability"], + invert_image=True, + retain_stats=True, ), - # Random simulation of low resolution + # Contrast transform: with RandAdjustContrastd + RandAdjustContrastd( + keys=["image"], + prob=self.cfg["DA_probability"], + invert_image=False, + retain_stats=True, + ), + # Simulate low resolution with RandSimulateLowResolutiond RandSimulateLowResolutiond( keys=["image"], - zoom_range=(0.8, 1.5), - prob=self.cfg["DA_probability"] + prob=self.cfg["DA_probability"], + downsample_mode='nearest', + upsample_mode='trilinear', + zoom_range=(0.5, 1.0), ), # Adding a random bias field which is usefull considering that this sometimes done for image pre-processing RandBiasFieldd( @@ -243,24 +249,58 @@ def prepare_data(self): degree=3, prob=self.cfg["DA_probability"] ), - # RandShiftIntensityd( + # Binary thresholding of the label + AsDiscreted( + keys=["label"], + threshold=0.5, + ), + + + # # This crops the image around areas where the mask is non-zero + # # (the margin is added because otherwise the image would be just the size of the lesion) + # CropForegroundd( + # keys=["image", "label"], + # source_key="label", + # margin=200 + # ), + # # This crops the image around a foreground object of label with ratio pos/(pos+neg) (however, it cannot pad so keeping padding after) + # RandCropByPosNegLabeld( + # keys=["image", "label"], + # label_key="label", + # spatial_size=self.cfg["spatial_size"], + # pos=1, + # neg=0, + # num_samples=4, + # image_key="image", + # image_threshold=0, + # allow_smaller=True, + # ), + # Multiplication of image by -1 + # RandLambdad( + # keys='image', + # func=multiply_by_negative_one, + # prob=0.5 + # ), + # Takes the laplacian of the image + # LabelToContourd( # keys=["image"], - # offsets=0.1, - # prob=0.2, + # kernel_type='Laplace', # ), + # EnsureTyped(keys=["image", "label"]), - # AsDiscreted( - # keys=["label"], - # num_classes=2, - # threshold_values=True, - # logit_thresh=0.2, - # ), - # Remove small lesions in the label - RandLambdad( - keys='label', - func=lambda label: remove_small_lesions(label, self.cfg["pixdim"]), - prob=1.0 - ) + + # # Remove small lesions in the label + # RandLambdad( + # keys='label', + # func=lambda label: remove_small_lesions(label, self.cfg["pixdim"]), + # prob=1.0 + # ) + + # This resizes the image and the label to the spatial size defined in the config + ResizeWithPadOrCropd( + keys=["image", "label"], + spatial_size=self.cfg["spatial_size"], + ), ] ) val_transforms = Compose( @@ -273,48 +313,27 @@ def prepare_data(self): pixdim=self.cfg["pixdim"], mode=(2, 0), ), - # This normalizes the intensity of the image NormalizeIntensityd( keys=["image"], nonzero=False, channel_wise=False ), - # CropForegroundd( - # keys=["image", "label"], - # source_key="label", - # margin=150), - # RandCropByPosNegLabeld( - # keys=["image", "label"], - # label_key="label", - # spatial_size=self.cfg["spatial_size"], - # pos=1, - # neg=1, - # num_samples=4, - # image_key="image", - # image_threshold=0, - # allow_smaller=True, - # ), ResizeWithPadOrCropd( keys=["image", "label"], spatial_size=self.cfg["spatial_size"], ), - # LabelToContourd( - # keys=["image"], - # kernel_type='Laplace', + # Binary thresholding of the label + AsDiscreted( + keys=["label"], + threshold=0.5, + ), + # # This normalizes the intensity of the image + # NormalizeIntensityd( + # keys=["image"], + # nonzero=False, + # channel_wise=False # ), # EnsureTyped(keys=["image", "label"]), - # AsDiscreted( - # keys=["label"], - # num_classes=2, - # threshold_values=True, - # logit_thresh=0.2, - # ) - # Remove small lesions in the label - RandLambdad( - keys='label', - func=lambda label: remove_small_lesions(label, self.cfg["pixdim"]), - prob=1.0 - ) ] ) @@ -326,28 +345,30 @@ def prepare_data(self): test_files = load_decathlon_datalist(dataset, True, "test") train_cache_rate = 0.5 + val_cache_rate = 0.25 self.train_ds = CacheDataset(data=train_files, transform=train_transforms, cache_rate=train_cache_rate, num_workers=8) - self.val_ds = CacheDataset(data=val_files, transform=val_transforms, cache_rate=0.25, num_workers=8) + self.val_ds = CacheDataset(data=val_files, transform=val_transforms, cache_rate=val_cache_rate, num_workers=8) # define test transforms transforms_test = val_transforms - # define post-processing transforms for testing; taken (with explanations) from - # https://github.com/Project-MONAI/tutorials/blob/main/3d_segmentation/torch/unet_inference_dict.py#L66 - self.test_post_pred = Compose([ - EnsureTyped(keys=["pred", "label"]), - Invertd(keys=["pred", "label"], transform=transforms_test, - orig_keys=["image", "label"], - meta_keys=["pred_meta_dict", "label_meta_dict"], - nearest_interp=False, to_tensor=True), - # Remove small lesions in the label - RandLambdad( - keys='label', - func=lambda label: remove_small_lesions(label, self.cfg["pixdim"]), - prob=1.0 - ) - ]) - self.test_ds = CacheDataset(data=test_files, transform=transforms_test, cache_rate=0.1, num_workers=4) + # Hidden because we don't use it + # # define post-processing transforms for testing; taken (with explanations) from + # # https://github.com/Project-MONAI/tutorials/blob/main/3d_segmentation/torch/unet_inference_dict.py#L66 + # self.test_post_pred = Compose([ + # EnsureTyped(keys=["pred", "label"]), + # Invertd(keys=["pred", "label"], transform=transforms_test, + # orig_keys=["image", "label"], + # meta_keys=["pred_meta_dict", "label_meta_dict"], + # nearest_interp=False, to_tensor=True), + # # # Remove small lesions in the label + # # RandLambdad( + # # keys='label', + # # func=lambda label: remove_small_lesions(label, self.cfg["pixdim"]), + # # prob=1.0 + # # ) + # ]) + # self.test_ds = CacheDataset(data=test_files, transform=transforms_test, cache_rate=0.1, num_workers=4) # -------------------------------- @@ -358,7 +379,7 @@ def train_dataloader(self): pin_memory=True, persistent_workers=True) def val_dataloader(self): - return DataLoader(self.val_ds, batch_size=1, shuffle=False, num_workers=8, pin_memory=True, + return DataLoader(self.val_ds, batch_size=1, shuffle=False, num_workers=4, pin_memory=True, persistent_workers=True) def test_dataloader(self): @@ -705,7 +726,7 @@ def main(): # out_channels=1, # channels=(32, 64, 128), # strides=(2, 2, 2, ), - # # dropout=0.1 + # dropout=0.1 # ) net = AttentionUnet( @@ -715,7 +736,7 @@ def main(): channels=config["attention_unet_channels"], strides=config["attention_unet_strides"], dropout=0.1, - ) + ) # net = SwinUNETR( # img_size=config["spatial_size"], From 88b661942c1ca28e9630266f2185eaed4c36705f Mon Sep 17 00:00:00 2001 From: Pierre-Louis Benveniste Date: Mon, 22 Jul 2024 17:16:57 -0400 Subject: [PATCH 079/108] added contrast, site and orientation in msd dataset --- monai/1_create_msd_data.py | 50 +++ monai/utils/image.py | 685 +++++++++++++++++++++++++++++++++++++ 2 files changed, 735 insertions(+) create mode 100644 monai/utils/image.py diff --git a/monai/1_create_msd_data.py b/monai/1_create_msd_data.py index 204b6fe..7f311e4 100644 --- a/monai/1_create_msd_data.py +++ b/monai/1_create_msd_data.py @@ -30,6 +30,7 @@ import nibabel as nib import numpy as np import skimage +from utils.image import Image def get_parser(): @@ -80,6 +81,40 @@ def count_lesion(label_file): return total_volume, nb_lesions +def get_orientation(image_path): + """ + This function takes an image file as input and returns its orientation. + + Input: + image_path : str : Path to the image file + + Returns: + orientation : str : Orientation of the image + """ + img = Image(str(image_path)) + pixdim = img.dim[4:7] + # if all pixdim are the same than, the image orientation is isotropic (a small threshold is used) + if np.allclose(pixdim, pixdim[0], atol=1e-3): + orientation = 'iso' + print("orientation", orientation) + return orientation + # Get arg of 2 lowest pixdim + arg = np.argsort(pixdim)[:2] + # Get corresponding orientation letters + orientation = ''.join([img.orientation[i] for i in arg]) + # print("orientation", orientation) + # if A-P and L-R : orientation is axial + if orientation in ['AL', 'LA', 'AR', 'RA', 'PL', 'LP', 'PR', 'RP']: + orientation = 'ax' + # elif A-P and I-S: orientation is sagittal + elif orientation in ['AI', 'IA', 'AS', 'SA', 'PI', 'IP', 'PS', 'SP']: + orientation = 'sag' + # Finaly for coronal: I-S and L-R + else: + orientation = 'cor' + return orientation + + def main(): """ This is the main function of the script. @@ -184,6 +219,9 @@ def main(): total_lesion_volume, nb_lesions = count_lesion(temp_data_basel["label"]) temp_data_basel["total_lesion_volume"] = total_lesion_volume temp_data_basel["nb_lesions"] = nb_lesions + temp_data_basel["site"]='basel' + temp_data_basel["contrast"] = str(derivative).replace('_desc-rater3_label-lesion_seg.nii.gz', '.nii.gz').split('_')[-1] + temp_data_basel["orientation"] = get_orientation(temp_data_basel["image"]) if args.lesion_only and nb_lesions == 0: continue temp_list.append(temp_data_basel) @@ -196,6 +234,9 @@ def main(): total_lesion_volume, nb_lesions = count_lesion(temp_data_bavaria["label"]) temp_data_bavaria["total_lesion_volume"] = total_lesion_volume temp_data_bavaria["nb_lesions"] = nb_lesions + temp_data_bavaria["site"]='bavaria-quebec' + temp_data_bavaria["contrast"] = str(derivative).replace('_lesion-manual.nii.gz', '.nii.gz').split('_')[-1].replace('.nii.gz', '') + temp_data_bavaria["orientation"] = get_orientation(temp_data_bavaria["image"]) if args.lesion_only and nb_lesions == 0: continue temp_list.append(temp_data_bavaria) @@ -212,6 +253,9 @@ def main(): total_lesion_volume, nb_lesions = count_lesion(temp_data_canproco["label"]) temp_data_canproco["total_lesion_volume"] = total_lesion_volume temp_data_canproco["nb_lesions"] = nb_lesions + temp_data_canproco["site"]='canproco' + temp_data_canproco["contrast"] = str(derivative).replace('_lesion-manual.nii.gz', '.nii.gz').split('_')[-1].replace('.nii.gz', '') + temp_data_canproco["orientation"] = get_orientation(temp_data_canproco["image"]) if args.lesion_only and nb_lesions == 0: continue temp_list.append(temp_data_canproco) @@ -224,6 +268,9 @@ def main(): total_lesion_volume, nb_lesions = count_lesion(temp_data_nih["label"]) temp_data_nih["total_lesion_volume"] = total_lesion_volume temp_data_nih["nb_lesions"] = nb_lesions + temp_data_nih["site"]='nih' + temp_data_nih["contrast"] = str(derivative).replace('_desc-rater1_label-lesion_seg.nii.gz', '.nii.gz').split('_')[-1].replace('.nii.gz', '') + temp_data_nih["orientation"] = get_orientation(temp_data_nih["image"]) if args.lesion_only and nb_lesions == 0: continue temp_list.append(temp_data_nih) @@ -236,6 +283,9 @@ def main(): total_lesion_volume, nb_lesions = count_lesion(temp_data_sct["label"]) temp_data_sct["total_lesion_volume"] = total_lesion_volume temp_data_sct["nb_lesions"] = nb_lesions + temp_data_sct["site"]='sct-testing-large' + temp_data_sct["contrast"] = str(derivative).replace('_lesion-manual.nii.gz', '.nii.gz').split('_')[-1].replace('.nii.gz', '') + temp_data_sct["orientation"] = get_orientation(temp_data_sct["image"]) if args.lesion_only and nb_lesions == 0: continue temp_list.append(temp_data_sct) diff --git a/monai/utils/image.py b/monai/utils/image.py new file mode 100644 index 0000000..03e670c --- /dev/null +++ b/monai/utils/image.py @@ -0,0 +1,685 @@ +import os +import numpy as np +import nibabel as nib +import logging +from copy import deepcopy + +logger = logging.getLogger(__name__) + +class Image(object): + """ + Compact version of SCT's Image Class (https://github.com/spinalcordtoolbox/spinalcordtoolbox/blob/master/spinalcordtoolbox/image.py#L245) + Create an object that behaves similarly to nibabel's image object. Useful additions include: dims, change_orientation and getNonZeroCoordinates. + """ + + def __init__(self, param=None, hdr=None, orientation=None, absolutepath=None, dim=None): + """ + :param param: string indicating a path to a image file or an `Image` object. + """ + + # initialization of all parameters + self.affine = None + self.data = None + self._path = None + self.ext = "" + + if absolutepath is not None: + self._path = os.path.abspath(absolutepath) + + # Case 1: load an image from file + if isinstance(param, str): + self.loadFromPath(param) + # Case 2: create a copy of an existing `Image` object + elif isinstance(param, type(self)): + self.copy(param) + # Case 3: create a blank image from a list of dimensions + elif isinstance(param, list): + self.data = np.zeros(param) + self.hdr = hdr.copy() if hdr is not None else nib.Nifti1Header() + self.hdr.set_data_shape(self.data.shape) + # Case 4: create an image from an existing data array + elif isinstance(param, (np.ndarray, np.generic)): + self.data = param + self.hdr = hdr.copy() if hdr is not None else nib.Nifti1Header() + self.hdr.set_data_shape(self.data.shape) + else: + raise TypeError('Image constructor takes at least one argument.') + + # Fix any mismatch between the array's datatype and the header datatype + self.fix_header_dtype() + + @property + def dim(self): + return get_dimension(self) + + @property + def orientation(self): + return get_orientation(self) + + @property + def absolutepath(self): + """ + Storage path (either actual or potential) + + Notes: + + - As several tools perform chdir() it's very important to have absolute paths + - When set, if relative: + + - If it already existed, it becomes a new basename in the old dirname + - Else, it becomes absolute (shortcut) + + Usually not directly touched (use `Image.save`), but in some cases it's + the best way to set it. + """ + return self._path + + @absolutepath.setter + def absolutepath(self, value): + if value is None: + self._path = None + return + elif not os.path.isabs(value) and self._path is not None: + value = os.path.join(os.path.dirname(self._path), value) + elif not os.path.isabs(value): + value = os.path.abspath(value) + self._path = value + + @property + def header(self): + return self.hdr + + @header.setter + def header(self, value): + self.hdr = value + + def __deepcopy__(self, memo): + return type(self)(deepcopy(self.data, memo), deepcopy(self.hdr, memo), deepcopy(self.orientation, memo), deepcopy(self.absolutepath, memo), deepcopy(self.dim, memo)) + + def copy(self, image=None): + if image is not None: + self.affine = deepcopy(image.affine) + self.data = deepcopy(image.data) + self.hdr = deepcopy(image.hdr) + self._path = deepcopy(image._path) + else: + return deepcopy(self) + + def loadFromPath(self, path): + """ + This function load an image from an absolute path using nibabel library + + :param path: path of the file from which the image will be loaded + :return: + """ + + self.absolutepath = os.path.abspath(path) + im_file = nib.load(self.absolutepath, mmap=True) + self.affine = im_file.affine.copy() + self.data = np.asanyarray(im_file.dataobj) + self.hdr = im_file.header.copy() + if path != self.absolutepath: + logger.debug("Loaded %s (%s) orientation %s shape %s", path, self.absolutepath, self.orientation, self.data.shape) + else: + logger.debug("Loaded %s orientation %s shape %s", path, self.orientation, self.data.shape) + + def change_orientation(self, orientation, inverse=False): + """ + Change orientation on image (in-place). + + :param orientation: orientation string (SCT "from" convention) + + :param inverse: if you think backwards, use this to specify that you actually\ + want to transform *from* the specified orientation, not *to*\ + it. + + """ + change_orientation(self, orientation, self, inverse=inverse) + return self + + def getNonZeroCoordinates(self, sorting=None, reverse_coord=False): + """ + This function return all the non-zero coordinates that the image contains. + Coordinate list can also be sorted by x, y, z, or the value with the parameter sorting='x', sorting='y', sorting='z' or sorting='value' + If reverse_coord is True, coordinate are sorted from larger to smaller. + + Removed Coordinate object + """ + n_dim = 1 + if self.dim[3] == 1: + n_dim = 3 + else: + n_dim = 4 + if self.dim[2] == 1: + n_dim = 2 + + if n_dim == 3: + X, Y, Z = (self.data > 0).nonzero() + list_coordinates = [[X[i], Y[i], Z[i], self.data[X[i], Y[i], Z[i]]] for i in range(0, len(X))] + elif n_dim == 2: + try: + X, Y = (self.data > 0).nonzero() + list_coordinates = [[X[i], Y[i], 0, self.data[X[i], Y[i]]] for i in range(0, len(X))] + except ValueError: + X, Y, Z = (self.data > 0).nonzero() + list_coordinates = [[X[i], Y[i], 0, self.data[X[i], Y[i], 0]] for i in range(0, len(X))] + + if sorting is not None: + if reverse_coord not in [True, False]: + raise ValueError('reverse_coord parameter must be a boolean') + + if sorting == 'x': + list_coordinates = sorted(list_coordinates, key=lambda el: el[0], reverse=reverse_coord) + elif sorting == 'y': + list_coordinates = sorted(list_coordinates, key=lambda el: el[1], reverse=reverse_coord) + elif sorting == 'z': + list_coordinates = sorted(list_coordinates, key=lambda el: el[2], reverse=reverse_coord) + elif sorting == 'value': + list_coordinates = sorted(list_coordinates, key=lambda el: el[3], reverse=reverse_coord) + else: + raise ValueError("sorting parameter must be either 'x', 'y', 'z' or 'value'") + + return list_coordinates + + def change_type(self, dtype): + """ + Change data type on image. + + Note: the image path is voided. + """ + change_type(self, dtype, self) + return self + + def fix_header_dtype(self): + """ + Change the header dtype to the match the datatype of the array. + """ + # Using bool for nibabel headers is unsupported, so use uint8 instead: + # `nibabel.spatialimages.HeaderDataError: data dtype "bool" not supported` + dtype_data = self.data.dtype + if dtype_data == bool: + dtype_data = np.uint8 + + dtype_header = self.hdr.get_data_dtype() + if dtype_header != dtype_data: + logger.warning(f"Image header specifies datatype '{dtype_header}', but array is of type " + f"'{dtype_data}'. Header metadata will be overwritten to use '{dtype_data}'.") + self.hdr.set_data_dtype(dtype_data) + + def save(self, path=None, dtype=None, verbose=1, mutable=False): + """ + Write an image in a nifti file + + :param path: Where to save the data, if None it will be taken from the\ + absolutepath member.\ + If path is a directory, will save to a file under this directory\ + with the basename from the absolutepath member. + + :param dtype: if not set, the image is saved in the same type as input data\ + if 'minimize', image storage space is minimized\ + (2, 'uint8', np.uint8, "NIFTI_TYPE_UINT8"),\ + (4, 'int16', np.int16, "NIFTI_TYPE_INT16"),\ + (8, 'int32', np.int32, "NIFTI_TYPE_INT32"),\ + (16, 'float32', np.float32, "NIFTI_TYPE_FLOAT32"),\ + (32, 'complex64', np.complex64, "NIFTI_TYPE_COMPLEX64"),\ + (64, 'float64', np.float64, "NIFTI_TYPE_FLOAT64"),\ + (256, 'int8', np.int8, "NIFTI_TYPE_INT8"),\ + (512, 'uint16', np.uint16, "NIFTI_TYPE_UINT16"),\ + (768, 'uint32', np.uint32, "NIFTI_TYPE_UINT32"),\ + (1024,'int64', np.int64, "NIFTI_TYPE_INT64"),\ + (1280, 'uint64', np.uint64, "NIFTI_TYPE_UINT64"),\ + (1536, 'float128', _float128t, "NIFTI_TYPE_FLOAT128"),\ + (1792, 'complex128', np.complex128, "NIFTI_TYPE_COMPLEX128"),\ + (2048, 'complex256', _complex256t, "NIFTI_TYPE_COMPLEX256"), + + :param mutable: whether to update members with newly created path or dtype + """ + if mutable: # do all modifications in-place + # Case 1: `path` not specified + if path is None: + if self.absolutepath: # Fallback to the original filepath + path = self.absolutepath + else: + raise ValueError("Don't know where to save the image (no absolutepath or path parameter)") + # Case 2: `path` points to an existing directory + elif os.path.isdir(path): + if self.absolutepath: # Use the original filename, but save to the directory specified by `path` + path = os.path.join(os.path.abspath(path), os.path.basename(self.absolutepath)) + else: + raise ValueError("Don't know where to save the image (path parameter is dir, but absolutepath is " + "missing)") + # Case 3: `path` points to a file (or a *nonexistent* directory) so use its value as-is + # (We're okay with letting nonexistent directories slip through, because it's difficult to distinguish + # between nonexistent directories and nonexistent files. Plus, `nibabel` will catch any further errors.) + else: + pass + + if os.path.isfile(path) and verbose: + logger.warning("File %s already exists. Will overwrite it.", path) + if os.path.isabs(path): + logger.debug("Saving image to %s orientation %s shape %s", + path, self.orientation, self.data.shape) + else: + logger.debug("Saving image to %s (%s) orientation %s shape %s", + path, os.path.abspath(path), self.orientation, self.data.shape) + + # Now that `path` has been set and log messages have been written, we can assign it to the image itself + self.absolutepath = os.path.abspath(path) + + if dtype is not None: + self.change_type(dtype) + + if self.hdr is not None: + self.hdr.set_data_shape(self.data.shape) + self.fix_header_dtype() + + # nb. that copy() is important because if it were a memory map, save() would corrupt it + dataobj = self.data.copy() + affine = None + header = self.hdr.copy() if self.hdr is not None else None + nib.save(nib.nifti1.Nifti1Image(dataobj, affine, header), self.absolutepath) + if not os.path.isfile(self.absolutepath): + raise RuntimeError(f"Couldn't save image to {self.absolutepath}") + else: + # if we're not operating in-place, then make any required modifications on a throw-away copy + self.copy().save(path, dtype, verbose, mutable=True) + return self + + +class SlicerOneAxis(object): + """ + Image slicer to use when you don't care about the 2D slice orientation, + and don't want to specify them. + The slicer will just iterate through the right axis that corresponds to + its specification. + + Can help getting ranges and slice indices. + + Copied from https://github.com/spinalcordtoolbox/spinalcordtoolbox/image.py + """ + + def __init__(self, im, axis="IS"): + opposite_character = {'L': 'R', 'R': 'L', 'A': 'P', 'P': 'A', 'I': 'S', 'S': 'I'} + axis_labels = "LRPAIS" + if len(axis) != 2: + raise ValueError() + if axis[0] not in axis_labels: + raise ValueError() + if axis[1] not in axis_labels: + raise ValueError() + if axis[0] != opposite_character[axis[1]]: + raise ValueError() + + for idx_axis in range(2): + dim_nr = im.orientation.find(axis[idx_axis]) + if dim_nr != -1: + break + if dim_nr == -1: + raise ValueError() + + # SCT convention + from_dir = im.orientation[dim_nr] + self.direction = +1 if axis[0] == from_dir else -1 + self.nb_slices = im.dim[dim_nr] + self.im = im + self.axis = axis + self._slice = lambda idx: tuple([(idx if x in axis else slice(None)) for x in im.orientation]) + + def __len__(self): + return self.nb_slices + + def __getitem__(self, idx): + """ + + :return: an image slice, at slicing index idx + :param idx: slicing index (according to the slicing direction) + """ + if isinstance(idx, slice): + raise NotImplementedError() + + if idx >= self.nb_slices: + raise IndexError("I just have {} slices!".format(self.nb_slices)) + + if self.direction == -1: + idx = self.nb_slices - 1 - idx + + return self.im.data[self._slice(idx)] + +def get_dimension(im_file, verbose=1): + """ + Copied from https://github.com/spinalcordtoolbox/spinalcordtoolbox/ + + Get dimension from Image or nibabel object. Manages 2D, 3D or 4D images. + + :param: im_file: Image or nibabel object + :return: nx, ny, nz, nt, px, py, pz, pt + """ + if not isinstance(im_file, (nib.nifti1.Nifti1Image, Image)): + raise TypeError("The provided image file is neither a nibabel.nifti1.Nifti1Image instance nor an Image instance") + # initializating ndims [nx, ny, nz, nt] and pdims [px, py, pz, pt] + ndims = [1, 1, 1, 1] + pdims = [1, 1, 1, 1] + data_shape = im_file.header.get_data_shape() + zooms = im_file.header.get_zooms() + for i in range(min(len(data_shape), 4)): + ndims[i] = data_shape[i] + pdims[i] = zooms[i] + return *ndims, *pdims + + +def change_orientation(im_src, orientation, im_dst=None, inverse=False): + """ + Copied from https://github.com/spinalcordtoolbox/spinalcordtoolbox/ + + :param im_src: source image + :param orientation: orientation string (SCT "from" convention) + :param im_dst: destination image (can be the source image for in-place + operation, can be unset to generate one) + :param inverse: if you think backwards, use this to specify that you actually + want to transform *from* the specified orientation, not *to* it. + :return: an image with changed orientation + + .. note:: + - the resulting image has no path member set + - if the source image is < 3D, it is reshaped to 3D and the destination is 3D + """ + + if len(im_src.data.shape) < 3: + pass # Will reshape to 3D + elif len(im_src.data.shape) == 3: + pass # OK, standard 3D volume + elif len(im_src.data.shape) == 4: + pass # OK, standard 4D volume + elif len(im_src.data.shape) == 5 and im_src.header.get_intent()[0] == "vector": + pass # OK, physical displacement field + else: + raise NotImplementedError("Don't know how to change orientation for this image") + + im_src_orientation = im_src.orientation + im_dst_orientation = orientation + if inverse: + im_src_orientation, im_dst_orientation = im_dst_orientation, im_src_orientation + + perm, inversion = _get_permutations(im_src_orientation, im_dst_orientation) + + if im_dst is None: + im_dst = im_src.copy() + im_dst._path = None + + im_src_data = im_src.data + if len(im_src_data.shape) < 3: + im_src_data = im_src_data.reshape(tuple(list(im_src_data.shape) + ([1] * (3 - len(im_src_data.shape))))) + + # Update data by performing inversions and swaps + + # axes inversion (flip) + data = im_src_data[::inversion[0], ::inversion[1], ::inversion[2]] + + # axes manipulations (transpose) + if perm == [1, 0, 2]: + data = np.swapaxes(data, 0, 1) + elif perm == [2, 1, 0]: + data = np.swapaxes(data, 0, 2) + elif perm == [0, 2, 1]: + data = np.swapaxes(data, 1, 2) + elif perm == [2, 0, 1]: + data = np.swapaxes(data, 0, 2) # transform [2, 0, 1] to [1, 0, 2] + data = np.swapaxes(data, 0, 1) # transform [1, 0, 2] to [0, 1, 2] + elif perm == [1, 2, 0]: + data = np.swapaxes(data, 0, 2) # transform [1, 2, 0] to [0, 2, 1] + data = np.swapaxes(data, 1, 2) # transform [0, 2, 1] to [0, 1, 2] + elif perm == [0, 1, 2]: + # do nothing + pass + else: + raise NotImplementedError() + + # Update header + + im_src_aff = im_src.hdr.get_best_affine() + aff = nib.orientations.inv_ornt_aff( + np.array((perm, inversion)).T, + im_src_data.shape) + im_dst_aff = np.matmul(im_src_aff, aff) + + im_dst.header.set_qform(im_dst_aff) + im_dst.header.set_sform(im_dst_aff) + im_dst.header.set_data_shape(data.shape) + im_dst.data = data + + return im_dst + + +def _get_permutations(im_src_orientation, im_dst_orientation): + """ + Copied from https://github.com/spinalcordtoolbox/spinalcordtoolbox/ + + :param im_src_orientation str: Orientation of source image. Example: 'RPI' + :param im_dest_orientation str: Orientation of destination image. Example: 'SAL' + :return: list of axes permutations and list of inversions to achieve an orientation change + """ + + opposite_character = {'L': 'R', 'R': 'L', 'A': 'P', 'P': 'A', 'I': 'S', 'S': 'I'} + + perm = [0, 1, 2] + inversion = [1, 1, 1] + for i, character in enumerate(im_src_orientation): + try: + perm[i] = im_dst_orientation.index(character) + except ValueError: + perm[i] = im_dst_orientation.index(opposite_character[character]) + inversion[i] = -1 + + return perm, inversion + + +def get_orientation(im): + """ + Copied from https://github.com/spinalcordtoolbox/spinalcordtoolbox/ + + :param im: an Image + :return: reference space string (ie. what's in Image.orientation) + """ + res = "".join(nib.orientations.aff2axcodes(im.hdr.get_best_affine())) + return orientation_string_nib2sct(res) + + +def orientation_string_nib2sct(s): + """ + Copied from https://github.com/spinalcordtoolbox/spinalcordtoolbox/ + + :return: SCT reference space code from nibabel one + """ + opposite_character = {'L': 'R', 'R': 'L', 'A': 'P', 'P': 'A', 'I': 'S', 'S': 'I'} + return "".join([opposite_character[x] for x in s]) + + +def change_type(im_src, dtype, im_dst=None): + """ + Change the voxel type of the image + + :param dtype: if not set, the image is saved in standard type\ + if 'minimize', image space is minimize\ + if 'minimize_int', image space is minimize and values are approximated to integers\ + (2, 'uint8', np.uint8, "NIFTI_TYPE_UINT8"),\ + (4, 'int16', np.int16, "NIFTI_TYPE_INT16"),\ + (8, 'int32', np.int32, "NIFTI_TYPE_INT32"),\ + (16, 'float32', np.float32, "NIFTI_TYPE_FLOAT32"),\ + (32, 'complex64', np.complex64, "NIFTI_TYPE_COMPLEX64"),\ + (64, 'float64', np.float64, "NIFTI_TYPE_FLOAT64"),\ + (256, 'int8', np.int8, "NIFTI_TYPE_INT8"),\ + (512, 'uint16', np.uint16, "NIFTI_TYPE_UINT16"),\ + (768, 'uint32', np.uint32, "NIFTI_TYPE_UINT32"),\ + (1024,'int64', np.int64, "NIFTI_TYPE_INT64"),\ + (1280, 'uint64', np.uint64, "NIFTI_TYPE_UINT64"),\ + (1536, 'float128', _float128t, "NIFTI_TYPE_FLOAT128"),\ + (1792, 'complex128', np.complex128, "NIFTI_TYPE_COMPLEX128"),\ + (2048, 'complex256', _complex256t, "NIFTI_TYPE_COMPLEX256"), + :return: + + Copied from https://github.com/spinalcordtoolbox/spinalcordtoolbox/ + """ + + if im_dst is None: + im_dst = im_src.copy() + im_dst._path = None + + if dtype is None: + return im_dst + + # get min/max from input image + min_in = np.nanmin(im_src.data) + max_in = np.nanmax(im_src.data) + + # find optimum type for the input image + if dtype in ('minimize', 'minimize_int'): + # warning: does not take intensity resolution into account, neither complex voxels + + # check if voxel values are real or integer + isInteger = True + if dtype == 'minimize': + for vox in im_src.data.flatten(): + if int(vox) != vox: + isInteger = False + break + + if isInteger: + if min_in >= 0: # unsigned + if max_in <= np.iinfo(np.uint8).max: + dtype = np.uint8 + elif max_in <= np.iinfo(np.uint16): + dtype = np.uint16 + elif max_in <= np.iinfo(np.uint32).max: + dtype = np.uint32 + elif max_in <= np.iinfo(np.uint64).max: + dtype = np.uint64 + else: + raise ValueError("Maximum value of the image is to big to be represented.") + else: + if max_in <= np.iinfo(np.int8).max and min_in >= np.iinfo(np.int8).min: + dtype = np.int8 + elif max_in <= np.iinfo(np.int16).max and min_in >= np.iinfo(np.int16).min: + dtype = np.int16 + elif max_in <= np.iinfo(np.int32).max and min_in >= np.iinfo(np.int32).min: + dtype = np.int32 + elif max_in <= np.iinfo(np.int64).max and min_in >= np.iinfo(np.int64).min: + dtype = np.int64 + else: + raise ValueError("Maximum value of the image is to big to be represented.") + else: + # if max_in <= np.finfo(np.float16).max and min_in >= np.finfo(np.float16).min: + # type = 'np.float16' # not supported by nibabel + if max_in <= np.finfo(np.float32).max and min_in >= np.finfo(np.float32).min: + dtype = np.float32 + elif max_in <= np.finfo(np.float64).max and min_in >= np.finfo(np.float64).min: + dtype = np.float64 + + dtype = to_dtype(dtype) + else: + dtype = to_dtype(dtype) + + # if output type is int, check if it needs intensity rescaling + if "int" in dtype.name: + # get min/max from output type + min_out = np.iinfo(dtype).min + max_out = np.iinfo(dtype).max + # before rescaling, check if there would be an intensity overflow + + if (min_in < min_out) or (max_in > max_out): + # This condition is important for binary images since we do not want to scale them + logger.warning(f"To avoid intensity overflow due to convertion to +{dtype.name}+, intensity will be rescaled to the maximum quantization scale") + # rescale intensity + data_rescaled = im_src.data * (max_out - min_out) / (max_in - min_in) + im_dst.data = data_rescaled - (data_rescaled.min() - min_out) + + # change type of data in both numpy array and nifti header + im_dst.data = getattr(np, dtype.name)(im_dst.data) + im_dst.hdr.set_data_dtype(dtype) + return im_dst + + +def to_dtype(dtype): + """ + Take a dtypeification and return an np.dtype + + :param dtype: dtypeification (string or np.dtype or None are supported for now) + :return: dtype or None + + Copied from https://github.com/spinalcordtoolbox/spinalcordtoolbox/ + """ + # TODO add more or filter on things supported by nibabel + + if dtype is None: + return None + if isinstance(dtype, type): + if isinstance(dtype(0).dtype, np.dtype): + return dtype(0).dtype + if isinstance(dtype, np.dtype): + return dtype + if isinstance(dtype, str): + return np.dtype(dtype) + + raise TypeError("data type {}: {} not understood".format(dtype.__class__, dtype)) + + +def zeros_like(img, dtype=None): + """ + + :param img: reference image + :param dtype: desired data type (optional) + :return: an Image with the same shape and header, filled with zeros + + Similar to numpy.zeros_like(), the goal of the function is to show the developer's + intent and avoid doing a copy, which is slower than initialization with a constant. + + Copied from https://github.com/spinalcordtoolbox/spinalcordtoolbox/image.py + """ + zimg = Image(np.zeros_like(img.data), hdr=img.hdr.copy()) + if dtype is not None: + zimg.change_type(dtype) + return zimg + + +def empty_like(img, dtype=None): + """ + :param img: reference image + :param dtype: desired data type (optional) + :return: an Image with the same shape and header, whose data is uninitialized + + Similar to numpy.empty_like(), the goal of the function is to show the developer's + intent and avoid touching the allocated memory, because it will be written to + afterwards. + + Copied from https://github.com/spinalcordtoolbox/spinalcordtoolbox/image.py + """ + dst = change_type(img, dtype) + return dst + + +def find_zmin_zmax(im, threshold=0.1): + """ + Find the min (and max) z-slice index below which (and above which) slices only have voxels below a given threshold. + + :param im: Image object + :param threshold: threshold to apply before looking for zmin/zmax, typically corresponding to noise level. + :return: [zmin, zmax] + + Copied from https://github.com/spinalcordtoolbox/spinalcordtoolbox/image.py + """ + slicer = SlicerOneAxis(im, axis="IS") + + # Make sure image is not empty + if not np.any(slicer): + logger.error('Input image is empty') + + # Iterate from bottom to top until we find data + for zmin in range(0, len(slicer)): + if np.any(slicer[zmin] > threshold): + break + + # Conversely from top to bottom + for zmax in range(len(slicer) - 1, zmin, -1): + if np.any(slicer[zmax] > threshold): + break + + return zmin, zmax \ No newline at end of file From f387067f32826bdbd037b944ffb97792ff7b7874 Mon Sep 17 00:00:00 2001 From: Pierre-Louis Benveniste Date: Mon, 22 Jul 2024 17:31:56 -0400 Subject: [PATCH 080/108] improved computation of orientation of image --- monai/1_create_msd_data.py | 24 ++++++++++-------------- 1 file changed, 10 insertions(+), 14 deletions(-) diff --git a/monai/1_create_msd_data.py b/monai/1_create_msd_data.py index 7f311e4..df027e0 100644 --- a/monai/1_create_msd_data.py +++ b/monai/1_create_msd_data.py @@ -92,26 +92,22 @@ def get_orientation(image_path): orientation : str : Orientation of the image """ img = Image(str(image_path)) + img.change_orientation('RPI') + # Get pixdim pixdim = img.dim[4:7] - # if all pixdim are the same than, the image orientation is isotropic (a small threshold is used) + # If all are the same, the image is isotropic if np.allclose(pixdim, pixdim[0], atol=1e-3): orientation = 'iso' - print("orientation", orientation) return orientation - # Get arg of 2 lowest pixdim - arg = np.argsort(pixdim)[:2] - # Get corresponding orientation letters - orientation = ''.join([img.orientation[i] for i in arg]) - # print("orientation", orientation) - # if A-P and L-R : orientation is axial - if orientation in ['AL', 'LA', 'AR', 'RA', 'PL', 'LP', 'PR', 'RP']: - orientation = 'ax' - # elif A-P and I-S: orientation is sagittal - elif orientation in ['AI', 'IA', 'AS', 'SA', 'PI', 'IP', 'PS', 'SP']: + # Elif, the lowest arg is 0 then the orientation is sagittal + elif np.argmax(pixdim) == 0: orientation = 'sag' - # Finaly for coronal: I-S and L-R - else: + # Elif, the lowest arg is 1 then the orientation is coronal + elif np.argmax(pixdim) == 1: orientation = 'cor' + # Else the orientation is axial + else: + orientation = 'ax' return orientation From ec256aba458cc711ef073563a7798d0c7a390b59 Mon Sep 17 00:00:00 2001 From: Pierre-Louis Benveniste Date: Mon, 22 Jul 2024 17:42:19 -0400 Subject: [PATCH 081/108] added __init__.py file for import possibility --- monai/utils/__init__.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 monai/utils/__init__.py diff --git a/monai/utils/__init__.py b/monai/utils/__init__.py new file mode 100644 index 0000000..e69de29 From 554b5228581556d211b372393c92437718b21015 Mon Sep 17 00:00:00 2001 From: Pierre-Louis Benveniste Date: Mon, 22 Jul 2024 17:43:14 -0400 Subject: [PATCH 082/108] removed unused files --- ...monai_unet_lightning_multichannel_input.py | 687 ----------------- ...onai_unet_lightning_multichannel_output.py | 725 ------------------ 2 files changed, 1412 deletions(-) delete mode 100644 monai/train_monai_unet_lightning_multichannel_input.py delete mode 100644 monai/train_monai_unet_lightning_multichannel_output.py diff --git a/monai/train_monai_unet_lightning_multichannel_input.py b/monai/train_monai_unet_lightning_multichannel_input.py deleted file mode 100644 index 0cb9a12..0000000 --- a/monai/train_monai_unet_lightning_multichannel_input.py +++ /dev/null @@ -1,687 +0,0 @@ -import os -import argparse -from datetime import datetime -from loguru import logger -import yaml -import nibabel as nib -from datetime import datetime - -import numpy as np -import wandb -import torch -import pytorch_lightning as pl -import torch.nn.functional as F -import matplotlib.pyplot as plt -from monai.metrics import DiceMetric -from monai.losses import DiceLoss, DiceCELoss - -# Added this to solve problem with too many files open -## Link here : https://github.com/pytorch/pytorch/issues/11201#issuecomment-421146936 -import torch.multiprocessing -torch.multiprocessing.set_sharing_strategy('file_system') - -from losses import AdapWingLoss, SoftDiceLoss - -from utils import dice_score, check_empty_patch, multiply_by_negative_one, plot_slices, create_nnunet_from_plans, print_data_types -from monai.networks.nets import UNet, BasicUNet, AttentionUnet - -from monai.networks.layers import Norm - - -from monai.transforms import ( - EnsureChannelFirstd, - Compose, - LoadImaged, - Orientationd, - RandFlipd, - RandShiftIntensityd, - Spacingd, - RandRotate90d, - NormalizeIntensityd, - RandCropByPosNegLabeld, - BatchInverseTransform, - RandAdjustContrastd, - AsDiscreted, - RandHistogramShiftd, - ResizeWithPadOrCropd, - EnsureTyped, - RandLambdad, - CropForegroundd, - RandGaussianNoised, - ConcatItemsd - ) - -from monai.utils import set_determinism -from monai.inferers import sliding_window_inference -import time -from monai.data import (DataLoader, CacheDataset, load_decathlon_datalist, decollate_batch) -from monai.transforms import (Compose, EnsureType, EnsureTyped, Invertd, SaveImage) - -# Added this because of following warning received: -## You are using a CUDA device ('NVIDIA RTX A6000') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` -## which will trade-off precision for performance. For more details, -## read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision -# torch.set_float32_matmul_precision('medium' | 'high') - - -def get_parser(): - """ - This function returns the parser for the command line arguments. - """ - parser = argparse.ArgumentParser(description="Train a nnUNet model using monai") - parser.add_argument("-c", "--config", help="Path to the config file (.yml file)", required=True) - return parser - - -# create a "model"-agnostic class with PL to use different models -class Model(pl.LightningModule): - def __init__(self, config, data_root, net, loss_function, optimizer_class, exp_id=None, results_path=None): - super().__init__() - self.cfg = config - self.save_hyperparameters(ignore=['net', 'loss_function']) - - self.root = data_root - self.net = net - self.lr = config["lr"] - self.loss_function = loss_function - self.optimizer_class = optimizer_class - self.save_exp_id = exp_id - self.results_path = results_path - - self.best_val_dice, self.best_val_epoch = 0, 0 - self.best_val_loss = float("inf") - - # define cropping and padding dimensions - # NOTE about patch sizes: nnUNet defines patches using the median size of the dataset as the reference - # BUT, for SC images, this means a lot of context outside the spinal cord is included in the patches - # which could be sub-optimal. - # On the other hand, ivadomed used a patch-size that's heavily padded along the R-L direction so that - # only the SC is in context. - self.spacing = config["spatial_size"] - self.voxel_cropping_size = self.inference_roi_size = config["spatial_size"] - - # define post-processing transforms for validation, nothing fancy just making sure that it's a tensor (default) - self.val_post_pred = self.val_post_label = Compose([EnsureType()]) - - # define evaluation metric - self.soft_dice_metric = dice_score - - # temp lists for storing outputs from training, validation, and testing - self.train_step_outputs = [] - self.val_step_outputs = [] - self.test_step_outputs = [] - - - # -------------------------------- - # FORWARD PASS - # -------------------------------- - def forward(self, x): - - out = self.net(x) - # # NOTE: MONAI's models only output the logits, not the output after the final activation function - # # https://docs.monai.io/en/0.9.0/_modules/monai/networks/nets/unetr.html#UNETR.forward refers to the - # # UnetOutBlock (https://docs.monai.io/en/0.9.0/_modules/monai/networks/blocks/dynunet_block.html#UnetOutBlock) - # # as the final block applied to the input, which is just a convolutional layer with no activation function - # # Hence, we are used Normalized ReLU to normalize the logits to the final output - # normalized_out = F.relu(out) / F.relu(out).max() if bool(F.relu(out).max()) else F.relu(out) - - return out # returns logits - - - # -------------------------------- - # DATA PREPARATION - # -------------------------------- - def prepare_data(self): - # set deterministic training for reproducibility - set_determinism(seed=self.cfg["seed"]) - - # define training and validation transforms - train_transforms = Compose( - [ - LoadImaged(keys=["image", "sc", "label"], reader="NibabelReader"), - EnsureChannelFirstd(keys=["image", "sc", "label"]), - Orientationd(keys=["image", "sc", "label"], axcodes="RPI"), - Spacingd( - keys=["image", "sc", "label"], - pixdim=self.cfg["pixdim"], - mode=(2, 1, 1), - ), - # CropForegroundd(keys=["image", "label"], source_key="label", margin=100), - ResizeWithPadOrCropd(keys=["image", "sc", "label"], spatial_size=self.cfg["spatial_size"],), - RandCropByPosNegLabeld( - keys=["image", "sc", "label"], - label_key="label", - spatial_size=self.cfg["spatial_size"], - pos=1, - neg=1, - num_samples=4, - image_key="image", - image_threshold=0, - ), - # Flips the image : left becomes right - RandFlipd( - keys=["image", "sc", "label"], - spatial_axis=[0], - prob=0.2, - ), - # Flips the image : supperior becomes inferior - RandFlipd( - keys=["image", "sc", "label"], - spatial_axis=[1], - prob=0.2, - ), - # Flips the image : anterior becomes posterior - RandFlipd( - keys=["image","sc", "label"], - spatial_axis=[2], - prob=0.2, - ), - # RandAdjustContrastd( - # keys=["image"], - # prob=0.2, - # gamma=(0.5, 4.5), - # invert_image=True, - # ), - # we add the multiplication of the image by -1 - # RandLambdad( - # keys='image', - # func=multiply_by_negative_one, - # prob=0.2 - # ), - - # Normalize the intensity of the image - NormalizeIntensityd( - keys=["image"], - nonzero=False, - channel_wise=False - ), - # RandGaussianNoised( - # keys=["image"], - # prob=0.2, - # ), - # RandShiftIntensityd( - # keys=["image"], - # offsets=0.1, - # prob=0.2, - # ), - # Concatenates the image and the sc - ConcatItemsd(keys=["image", "sc"], name="inputs"), - EnsureTyped(keys=["inputs", "label"]), - # AsDiscreted( - # keys=["label"], - # num_classes=2, - # threshold_values=True, - # logit_thresh=0.2, - # ) - ] - ) - val_transforms = Compose( - [ - LoadImaged(keys=["image", "sc", "label"], reader="NibabelReader"), - EnsureChannelFirstd(keys=["image", "sc", "label"]), - Orientationd(keys=["image", "sc", "label"], axcodes="RPI"), - Spacingd( - keys=["image", "sc", "label"], - pixdim=self.cfg["pixdim"], - mode=(2, 1, 1), - ), - # CropForegroundd(keys=["image", "label"], source_key="label", margin=100), - ResizeWithPadOrCropd(keys=["image", "sc", "label"], spatial_size=self.cfg["spatial_size"],), - # RandCropByPosNegLabeld( - # keys=["image", "label"], - # label_key="label", - # spatial_size=self.cfg["spatial_size"], - # pos=1, - # neg=1, - # num_samples=4, - # image_key="image", - # image_threshold=0, - # ), - # Concatenates the image and the sc - ConcatItemsd(keys=["image", "sc"], name="inputs"), - # Normalize the intensity of the image - NormalizeIntensityd( - keys=["inputs"], - nonzero=False, - channel_wise=False - ), - EnsureTyped(keys=["inputs", "label"]), - # AsDiscreted( - # keys=["label"], - # num_classes=2, - # threshold_values=True, - # logit_thresh=0.2, - # ) - ] - ) - - # load the dataset - dataset = self.cfg["data"] - logger.info(f"Loading dataset: {dataset}") - train_files = load_decathlon_datalist(dataset, True, "train") - val_files = load_decathlon_datalist(dataset, True, "validation") - test_files = load_decathlon_datalist(dataset, True, "test") - - train_cache_rate = 0.5 - self.train_ds = CacheDataset(data=train_files, transform=train_transforms, cache_rate=train_cache_rate, num_workers=16) - self.val_ds = CacheDataset(data=val_files, transform=val_transforms, cache_rate=0.25, num_workers=16) - - # define test transforms - transforms_test = val_transforms - - # define post-processing transforms for testing; taken (with explanations) from - # https://github.com/Project-MONAI/tutorials/blob/main/3d_segmentation/torch/unet_inference_dict.py#L66 - self.test_post_pred = Compose([ - EnsureTyped(keys=["pred", "label"]), - Invertd(keys=["pred", "label"], transform=transforms_test, - orig_keys=["image", "label"], - meta_keys=["pred_meta_dict", "label_meta_dict"], - nearest_interp=False, to_tensor=True), - ]) - self.test_ds = CacheDataset(data=test_files, transform=transforms_test, cache_rate=0.1, num_workers=4) - - - # -------------------------------- - # DATA LOADERS - # -------------------------------- - def train_dataloader(self): - return DataLoader(self.train_ds, batch_size=self.cfg["batch_size"], shuffle=True, num_workers=16, - pin_memory=True, persistent_workers=True) - - def val_dataloader(self): - return DataLoader(self.val_ds, batch_size=1, shuffle=False, num_workers=16, pin_memory=True, - persistent_workers=True) - - def test_dataloader(self): - return DataLoader(self.test_ds, batch_size=1, shuffle=False, num_workers=1, pin_memory=True) - - - # -------------------------------- - # OPTIMIZATION - # -------------------------------- - def configure_optimizers(self): - optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr, weight_decay=self.cfg["weight_decay"]) - scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=self.cfg["max_iterations"]) - return [optimizer], [scheduler] - - - # -------------------------------- - # TRAINING - # -------------------------------- - def training_step(self, batch, batch_idx): - - inputs, labels = batch["inputs"], batch["label"] - - # # print(inputs.shape, labels.shape) - # input_0 = inputs[0].detach().cpu().squeeze() - # # print(input_0.shape) - # label_0 = labels[0].detach().cpu().squeeze() - - # time_0 = datetime.now() - - # # save input 0 in a nifti file - # input_0_nifti = nib.Nifti1Image(input_0.numpy(), affine=np.eye(4)) - # nib.save(input_0_nifti, f"~/ms_lesion_agnostic/temp/input_0_{time_0}.nii.gz") - - # # save label in a nifti file - # label_nifti = nib.Nifti1Image(label_0.numpy(), affine=np.eye(4)) - # nib.save(label_nifti, f"~/ms_lesion_agnostic/temp/label_0_{time_0}.nii.gz") - - - # # check if any label image patch is empty in the batch - if check_empty_patch(labels) is None: - # print(f"Empty label patch found. Skipping training step ...") - return None - - output = self.forward(inputs) # logits - # print(f"labels.shape: {labels.shape} \t output.shape: {output.shape}") - - # get probabilities from logits - output = F.relu(output) / F.relu(output).max() if bool(F.relu(output).max()) else F.relu(output) - - # calculate training loss - loss = self.loss_function(output, labels) - - # calculate train dice - # NOTE: this is done on patches (and not entire 3D volume) because SlidingWindowInference is not used here - # So, take this dice score with a lot of salt - train_soft_dice = self.soft_dice_metric(output, labels) - - metrics_dict = { - "loss": loss.cpu(), - "train_soft_dice": train_soft_dice.detach().cpu(), - "train_number": len(inputs), - "train_image": inputs[0].detach().cpu().squeeze(), - "train_gt": labels[0].detach().cpu().squeeze(), - "train_pred": output[0].detach().cpu().squeeze() - } - self.train_step_outputs.append(metrics_dict) - - return metrics_dict - - def on_train_epoch_end(self): - - if self.train_step_outputs == []: - # means the training step was skipped because of empty input patch - return None - else: - train_loss, train_soft_dice = 0, 0 - num_items = len(self.train_step_outputs) - for output in self.train_step_outputs: - train_loss += output["loss"].item() - train_soft_dice += output["train_soft_dice"].item() - - mean_train_loss = (train_loss / num_items) - mean_train_soft_dice = (train_soft_dice / num_items) - - wandb_logs = { - "train_soft_dice": mean_train_soft_dice, - "train_loss": mean_train_loss, - } - self.log_dict(wandb_logs) - - # # plot the training images - # fig = plot_slices(image=self.train_step_outputs[0]["train_image"], - # gt=self.train_step_outputs[0]["train_gt"], - # pred=self.train_step_outputs[0]["train_pred"], - # ) - # wandb.log({"training images": wandb.Image(fig)}) - # plt.close(fig) - - # free up memory - self.train_step_outputs.clear() - wandb_logs.clear() - - - - # -------------------------------- - # VALIDATION - # -------------------------------- - def validation_step(self, batch, batch_idx): - - inputs, labels = batch["inputs"], batch["label"] - - # NOTE: this calculates the loss on the entire image after sliding window - outputs = sliding_window_inference(inputs, self.inference_roi_size, mode="gaussian", - sw_batch_size=4, predictor=self.forward, overlap=0.5,) - - # get probabilities from logits - outputs = F.relu(outputs) / F.relu(outputs).max() if bool(F.relu(outputs).max()) else F.relu(outputs) - - # calculate validation loss - loss = self.loss_function(outputs, labels) - - - # post-process for calculating the evaluation metric - post_outputs = [self.val_post_pred(i) for i in decollate_batch(outputs)] - post_labels = [self.val_post_label(i) for i in decollate_batch(labels)] - val_soft_dice = self.soft_dice_metric(post_outputs[0], post_labels[0]) - - hard_preds, hard_labels = (post_outputs[0].detach() > 0.5).float(), (post_labels[0].detach() > 0.5).float() - val_hard_dice = self.soft_dice_metric(hard_preds, hard_labels) - - # NOTE: there was a massive memory leak when storing cuda tensors in this dict. Hence, - # using .detach() to avoid storing the whole computation graph - # Ref: https://discuss.pytorch.org/t/cuda-memory-leak-while-training/82855/2 - metrics_dict = { - "val_loss": loss.detach().cpu(), - "val_soft_dice": val_soft_dice.detach().cpu(), - "val_hard_dice": val_hard_dice.detach().cpu(), - "val_number": len(post_outputs), - "val_image_0": inputs[0].detach().cpu().squeeze(), - "val_gt_0": labels[0].detach().cpu().squeeze(), - "val_pred_0": post_outputs[0].detach().cpu().squeeze(), - # "val_image_1": inputs[1].detach().cpu().squeeze(), - # "val_gt_1": labels[1].detach().cpu().squeeze(), - # "val_pred_1": post_outputs[1].detach().cpu().squeeze(), - } - self.val_step_outputs.append(metrics_dict) - - return metrics_dict - - def on_validation_epoch_end(self): - - val_loss, num_items, val_soft_dice, val_hard_dice = 0, 0, 0, 0 - for output in self.val_step_outputs: - val_loss += output["val_loss"].sum().item() - val_soft_dice += output["val_soft_dice"].sum().item() - val_hard_dice += output["val_hard_dice"].sum().item() - num_items += output["val_number"] - - mean_val_loss = (val_loss / num_items) - mean_val_soft_dice = (val_soft_dice / num_items) - mean_val_hard_dice = (val_hard_dice / num_items) - - wandb_logs = { - "val_soft_dice": mean_val_soft_dice, - #"val_hard_dice": mean_val_hard_dice, - "val_loss": mean_val_loss, - } - - self.log_dict(wandb_logs) - - # save the best model based on validation dice score - if mean_val_soft_dice > self.best_val_dice: - self.best_val_dice = mean_val_soft_dice - self.best_val_epoch = self.current_epoch - - # save the best model based on validation loss - if mean_val_loss < self.best_val_loss: - self.best_val_loss = mean_val_loss - self.best_val_epoch = self.current_epoch - - logger.info( - f"\nCurrent epoch: {self.current_epoch}" - f"\nAverage Soft Dice (VAL): {mean_val_soft_dice:.4f}" - # f"\nAverage Hard Dice (VAL): {mean_val_hard_dice:.4f}" - f"\nAverage DiceLoss (VAL): {mean_val_loss:.4f}" - f"\nBest Average DiceLoss: {self.best_val_loss:.4f} at Epoch: {self.best_val_epoch}" - f"\n----------------------------------------------------") - - # log on to wandb - self.log_dict(wandb_logs) - - # # plot the validation images - # fig0 = plot_slices(image=self.val_step_outputs[0]["val_image_0"], - # gt=self.val_step_outputs[0]["val_gt_0"], - # pred=self.val_step_outputs[0]["val_pred_0"],) - # wandb.log({"validation images": wandb.Image(fig0)}) - # plt.close(fig0) - - - # free up memory - self.val_step_outputs.clear() - wandb_logs.clear() - - - # -------------------------------- - # TESTING - # -------------------------------- - def test_step(self, batch, batch_idx): - - test_input = batch["inputs"] - # print(batch["label_meta_dict"]["filename_or_obj"][0]) - batch["pred"] = sliding_window_inference(test_input, self.inference_roi_size, - sw_batch_size=4, predictor=self.forward, overlap=0.5) - - # normalize the logits - batch["pred"] = F.relu(batch["pred"]) / F.relu(batch["pred"]).max() if bool(F.relu(batch["pred"]).max()) else F.relu(batch["pred"]) - - post_test_out = [self.test_post_pred(i) for i in decollate_batch(batch)] - - # make sure that the shapes of prediction and GT label are the same - # print(f"pred shape: {post_test_out[0]['pred'].shape}, label shape: {post_test_out[0]['label'].shape}") - assert post_test_out[0]['pred'].shape == post_test_out[0]['label'].shape - - pred, label = post_test_out[0]['pred'].cpu(), post_test_out[0]['label'].cpu() - - # NOTE: Important point from the SoftSeg paper - binarize predictions before computing metrics - # calculate soft and hard dice here (for quick overview), other metrics can be computed from - # the saved predictions using ANIMA - # 1. Dice Score - test_soft_dice = self.soft_dice_metric(pred, label) - - # binarizing the predictions - pred = (post_test_out[0]['pred'].detach().cpu() > 0.5).float() - label = (post_test_out[0]['label'].detach().cpu() > 0.5).float() - - # 1.1 Hard Dice Score - test_hard_dice = self.soft_dice_metric(pred.numpy(), label.numpy()) - - metrics_dict = { - "test_hard_dice": test_hard_dice, - "test_soft_dice": test_soft_dice, - } - self.test_step_outputs.append(metrics_dict) - - return metrics_dict - - def on_test_epoch_end(self): - - avg_hard_dice_test, std_hard_dice_test = np.stack([x["test_hard_dice"] for x in self.test_step_outputs]).mean(), \ - np.stack([x["test_hard_dice"] for x in self.test_step_outputs]).std() - avg_soft_dice_test, std_soft_dice_test = np.stack([x["test_soft_dice"] for x in self.test_step_outputs]).mean(), \ - np.stack([x["test_soft_dice"] for x in self.test_step_outputs]).std() - - logger.info(f"Test (Soft) Dice: {avg_soft_dice_test}") - logger.info(f"Test (Hard) Dice: {avg_hard_dice_test}") - - self.avg_test_dice, self.std_test_dice = avg_soft_dice_test, std_soft_dice_test - self.avg_test_dice_hard, self.std_test_dice_hard = avg_hard_dice_test, std_hard_dice_test - - # free up memory - self.test_step_outputs.clear() - -# -------------------------------- -# MAIN -# -------------------------------- -def main(): - # get the parser - parser = get_parser() - args= parser.parse_args() - - # load config file - with open(args.config, "r") as f: - config = yaml.load(f, Loader=yaml.FullLoader) - - # Setting the seed - pl.seed_everything(config["seed"], workers=True) - - # define root path for finding datalists - dataset_root = config["data"] - - # define optimizer - optimizer_class = torch.optim.Adam - - wandb.init(project=f'monai-unet-ms-lesion-seg-canproco', config=config) - - logger.info("Defining plans for nnUNet model ...") - - - # define model - # TODO: make the model deeper - # net = UNet( - # spatial_dims=3, - # in_channels=1, - # out_channels=1, - # channels=config['unet_channels'], - # strides=config['unet_strides'], - # kernel_size=3, - # up_kernel_size=3, - # num_res_units=0, - # act='PRELU', - # norm=Norm.INSTANCE, - # dropout=0.0, - # bias=True, - # adn_ordering='NDA', - # ) - # net=UNet( - # spatial_dims=3, - # in_channels=1, - # out_channels=1, - # channels=(32, 64, 128, 256), - # strides=(2, 2, 2 ), - - # # dropout=0.1 - # ) - net = AttentionUnet( - spatial_dims=3, - in_channels=2, - out_channels=1, - channels=(32, 64, 128, 256, 512, 1024), - strides=(2, 2, 2, 2, 2), - dropout=0.1, - ) - # net = BasicUNet(spatial_dims=3, features=(32, 64, 128, 256, 32), out_channels=1) - - # net = create_nnunet_from_plans() - - logger.add(os.path.join(config["log_path"], str(datetime.now()) + 'log.txt'), rotation="10 MB", level="INFO") - - - # define loss function - #loss_func = AdapWingLoss(theta=0.5, omega=8, alpha=2.1, epsilon=1, reduction="sum") - # loss_func = DiceLoss(sigmoid=True, smooth_dr=1e-4) - loss_func = DiceCELoss(sigmoid=True, smooth_dr=1e-4) - # loss_func = SoftDiceLoss(smooth=1e-5) - # NOTE: tried increasing omega and decreasing epsilon but results marginally worse than the above - # loss_func = AdapWingLoss(theta=0.5, omega=12, alpha=2.1, epsilon=0.5, reduction="sum") - #logger.info(f"Using AdapWingLoss with theta={loss_func.theta}, omega={loss_func.omega}, alpha={loss_func.alpha}, epsilon={loss_func.epsilon} ...") - logger.info(f"Using SoftDiceLoss ...") - # define callbacks - early_stopping = pl.callbacks.EarlyStopping( - monitor="val_loss", min_delta=0.00, - patience=config["early_stopping_patience"], - verbose=False, mode="min") - - lr_monitor = pl.callbacks.LearningRateMonitor(logging_interval='epoch') - - # i.e. train by loading weights from scratch - pl_model = Model(config, data_root=dataset_root, - optimizer_class=optimizer_class, loss_function=loss_func, net=net, - exp_id="test", results_path=config["best_model_path"]) - - # saving the best model based on validation loss - checkpoint_callback_loss = pl.callbacks.ModelCheckpoint( - dirpath=config["best_model_path"], filename='best_model', monitor='val_loss', - save_top_k=1, mode="min", save_last=True, save_weights_only=True) - - - logger.info(f"Starting training from scratch ...") - # wandb logger - exp_logger = pl.loggers.WandbLogger( - name="test", - save_dir="/home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/results", - group="test-on-canproco", - log_model=True, # save best model using checkpoint callback - project='ms-lesion-agnostic', - entity='pierre-louis-benveniste', - config=config) - - # Saving training script to wandb - wandb.save("ms-lesion-agnostic/monai/nnunet/config_fake.yml") - wandb.save("ms-lesion-agnostic/monai/nnunet/train_monai_unet_lightning_multichannel.py") - - - # initialise Lightning's trainer. - trainer = pl.Trainer( - devices=1, accelerator="gpu", - logger=exp_logger, - callbacks=[checkpoint_callback_loss, lr_monitor, early_stopping], - check_val_every_n_epoch=config["eval_num"], - max_epochs=config["max_iterations"], - precision=32, - # deterministic=True, - enable_progress_bar=True) - # profiler="simple",) # to profile the training time taken for each step - - # Train! - trainer.fit(pl_model) - logger.info(f" Training Done!") - - # Closing wandb log - wandb.finish() - - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/monai/train_monai_unet_lightning_multichannel_output.py b/monai/train_monai_unet_lightning_multichannel_output.py deleted file mode 100644 index f3232b8..0000000 --- a/monai/train_monai_unet_lightning_multichannel_output.py +++ /dev/null @@ -1,725 +0,0 @@ -import os -import argparse -from datetime import datetime -from loguru import logger -import yaml -import nibabel as nib -from datetime import datetime - -import numpy as np -import wandb -import torch -import pytorch_lightning as pl -import torch.nn.functional as F -import matplotlib.pyplot as plt -from monai.metrics import DiceMetric -from monai.losses import DiceLoss, DiceCELoss - -# Added this to solve problem with too many files open -## Link here : https://github.com/pytorch/pytorch/issues/11201#issuecomment-421146936 -import torch.multiprocessing -torch.multiprocessing.set_sharing_strategy('file_system') - -from losses import AdapWingLoss, SoftDiceLoss - -from utils import dice_score, check_empty_patch, multiply_by_negative_one, plot_slices, create_nnunet_from_plans, print_data_types -from monai.networks.nets import UNet, BasicUNet, AttentionUnet - -from monai.networks.layers import Norm - - -from monai.transforms import ( - EnsureChannelFirstd, - Compose, - LoadImaged, - Orientationd, - RandFlipd, - RandShiftIntensityd, - Spacingd, - RandRotate90d, - NormalizeIntensityd, - RandCropByPosNegLabeld, - BatchInverseTransform, - RandAdjustContrastd, - AsDiscreted, - RandHistogramShiftd, - ResizeWithPadOrCropd, - EnsureTyped, - RandLambdad, - CropForegroundd, - RandGaussianNoised, - ConcatItemsd - ) - -from monai.utils import set_determinism -from monai.inferers import sliding_window_inference -import time -from monai.data import (DataLoader, CacheDataset, load_decathlon_datalist, decollate_batch) -from monai.transforms import (Compose, EnsureType, EnsureTyped, Invertd, SaveImage) - -# Added this because of following warning received: -## You are using a CUDA device ('NVIDIA RTX A6000') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` -## which will trade-off precision for performance. For more details, -## read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision -# torch.set_float32_matmul_precision('medium' | 'high') - - -def get_parser(): - """ - This function returns the parser for the command line arguments. - """ - parser = argparse.ArgumentParser(description="Train a nnUNet model using monai") - parser.add_argument("-c", "--config", help="Path to the config file (.yml file)", required=True) - return parser - - -# create a "model"-agnostic class with PL to use different models -class Model(pl.LightningModule): - def __init__(self, config, data_root, net, loss_function, optimizer_class, exp_id=None, results_path=None): - super().__init__() - self.cfg = config - self.save_hyperparameters(ignore=['net', 'loss_function']) - - self.root = data_root - self.net = net - self.lr = config["lr"] - self.loss_function = loss_function - self.optimizer_class = optimizer_class - self.save_exp_id = exp_id - self.results_path = results_path - - self.best_val_dice, self.best_val_epoch = 0, 0 - self.best_val_loss = float("inf") - - # define cropping and padding dimensions - # NOTE about patch sizes: nnUNet defines patches using the median size of the dataset as the reference - # BUT, for SC images, this means a lot of context outside the spinal cord is included in the patches - # which could be sub-optimal. - # On the other hand, ivadomed used a patch-size that's heavily padded along the R-L direction so that - # only the SC is in context. - self.spacing = config["spatial_size"] - self.voxel_cropping_size = self.inference_roi_size = config["spatial_size"] - - # define post-processing transforms for validation, nothing fancy just making sure that it's a tensor (default) - self.val_post_pred = self.val_post_label = Compose([EnsureType()]) - - # define evaluation metric - self.soft_dice_metric = dice_score - - # temp lists for storing outputs from training, validation, and testing - self.train_step_outputs = [] - self.val_step_outputs = [] - self.test_step_outputs = [] - - - # -------------------------------- - # FORWARD PASS - # -------------------------------- - def forward(self, x): - - out = self.net(x) - # # NOTE: MONAI's models only output the logits, not the output after the final activation function - # # https://docs.monai.io/en/0.9.0/_modules/monai/networks/nets/unetr.html#UNETR.forward refers to the - # # UnetOutBlock (https://docs.monai.io/en/0.9.0/_modules/monai/networks/blocks/dynunet_block.html#UnetOutBlock) - # # as the final block applied to the input, which is just a convolutional layer with no activation function - # # Hence, we are used Normalized ReLU to normalize the logits to the final output - # normalized_out = F.relu(out) / F.relu(out).max() if bool(F.relu(out).max()) else F.relu(out) - - return out # returns logits - - - # -------------------------------- - # DATA PREPARATION - # -------------------------------- - def prepare_data(self): - # set deterministic training for reproducibility - set_determinism(seed=self.cfg["seed"]) - - # define training and validation transforms - train_transforms = Compose( - [ - LoadImaged(keys=["image", "sc", "label"], reader="NibabelReader"), - EnsureChannelFirstd(keys=["image", "sc", "label"]), - Orientationd(keys=["image", "sc", "label"], axcodes="RPI"), - Spacingd( - keys=["image", "sc", "label"], - pixdim=self.cfg["pixdim"], - mode=(2, 1, 1), - ), - # CropForegroundd(keys=["image", "label"], source_key="label", margin=100), - ResizeWithPadOrCropd(keys=["image", "sc", "label"], spatial_size=self.cfg["spatial_size"],), - RandCropByPosNegLabeld( - keys=["image", "sc", "label"], - label_key="label", - spatial_size=self.cfg["spatial_size"], - pos=1, - neg=1, - num_samples=4, - image_key="image", - image_threshold=0, - ), - # Flips the image : left becomes right - RandFlipd( - keys=["image", "sc", "label"], - spatial_axis=[0], - prob=0.2, - ), - # Flips the image : supperior becomes inferior - RandFlipd( - keys=["image", "sc", "label"], - spatial_axis=[1], - prob=0.2, - ), - # Flips the image : anterior becomes posterior - RandFlipd( - keys=["image","sc", "label"], - spatial_axis=[2], - prob=0.2, - ), - # RandAdjustContrastd( - # keys=["image"], - # prob=0.2, - # gamma=(0.5, 4.5), - # invert_image=True, - # ), - # we add the multiplication of the image by -1 - # RandLambdad( - # keys='image', - # func=multiply_by_negative_one, - # prob=0.2 - # ), - - # Normalize the intensity of the image - NormalizeIntensityd( - keys=["image"], - nonzero=False, - channel_wise=False - ), - # RandGaussianNoised( - # keys=["image"], - # prob=0.2, - # ), - # RandShiftIntensityd( - # keys=["image"], - # offsets=0.1, - # prob=0.2, - # ), - # Concatenates the image and the sc - ConcatItemsd(keys=["sc", "label"], name="outputs"), - EnsureTyped(keys=["image", "outputs"]), - # AsDiscreted( - # keys=["label"], - # num_classes=2, - # threshold_values=True, - # logit_thresh=0.2, - # ) - ] - ) - val_transforms = Compose( - [ - LoadImaged(keys=["image", "sc", "label"], reader="NibabelReader"), - EnsureChannelFirstd(keys=["image", "sc", "label"]), - Orientationd(keys=["image", "sc", "label"], axcodes="RPI"), - Spacingd( - keys=["image", "sc", "label"], - pixdim=self.cfg["pixdim"], - mode=(2, 1, 1), - ), - # CropForegroundd(keys=["image", "label"], source_key="label", margin=100), - ResizeWithPadOrCropd(keys=["image", "sc", "label"], spatial_size=self.cfg["spatial_size"],), - # RandCropByPosNegLabeld( - # keys=["image", "label"], - # label_key="label", - # spatial_size=self.cfg["spatial_size"], - # pos=1, - # neg=1, - # num_samples=4, - # image_key="image", - # image_threshold=0, - # ), - - # Normalize the intensity of the image - NormalizeIntensityd( - keys=["image"], - nonzero=False, - channel_wise=False - ), - # Concatenates the image and the sc - ConcatItemsd(keys=["sc", "label"], name="outputs"), - EnsureTyped(keys=["image", "outputs"]), - # AsDiscreted( - # keys=["label"], - # num_classes=2, - # threshold_values=True, - # logit_thresh=0.2, - # ) - ] - ) - - # load the dataset - dataset = self.cfg["data"] - logger.info(f"Loading dataset: {dataset}") - train_files = load_decathlon_datalist(dataset, True, "train") - val_files = load_decathlon_datalist(dataset, True, "validation") - test_files = load_decathlon_datalist(dataset, True, "test") - - train_cache_rate = 0.5 - self.train_ds = CacheDataset(data=train_files, transform=train_transforms, cache_rate=train_cache_rate, num_workers=16) - self.val_ds = CacheDataset(data=val_files, transform=val_transforms, cache_rate=0.25, num_workers=16) - - # define test transforms - transforms_test = val_transforms - - # define post-processing transforms for testing; taken (with explanations) from - # https://github.com/Project-MONAI/tutorials/blob/main/3d_segmentation/torch/unet_inference_dict.py#L66 - self.test_post_pred = Compose([ - EnsureTyped(keys=["pred", "label"]), - Invertd(keys=["pred", "label"], transform=transforms_test, - orig_keys=["image", "label"], - meta_keys=["pred_meta_dict", "label_meta_dict"], - nearest_interp=False, to_tensor=True), - ]) - self.test_ds = CacheDataset(data=test_files, transform=transforms_test, cache_rate=0.1, num_workers=4) - - - # -------------------------------- - # DATA LOADERS - # -------------------------------- - def train_dataloader(self): - return DataLoader(self.train_ds, batch_size=self.cfg["batch_size"], shuffle=True, num_workers=16, - pin_memory=True, persistent_workers=True) - - def val_dataloader(self): - return DataLoader(self.val_ds, batch_size=1, shuffle=False, num_workers=16, pin_memory=True, - persistent_workers=True) - - def test_dataloader(self): - return DataLoader(self.test_ds, batch_size=1, shuffle=False, num_workers=1, pin_memory=True) - - - # -------------------------------- - # OPTIMIZATION - # -------------------------------- - def configure_optimizers(self): - optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr, weight_decay=self.cfg["weight_decay"]) - scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=self.cfg["max_iterations"]) - return [optimizer], [scheduler] - - - # -------------------------------- - # TRAINING - # -------------------------------- - def training_step(self, batch, batch_idx): - - inputs, labels = batch["image"], batch["outputs"] - - # # print(inputs.shape, labels.shape) - # input_0 = inputs[0].detach().cpu().squeeze() - # # print(input_0.shape) - # label_0 = labels[0].detach().cpu().squeeze() - - # time_0 = datetime.now() - - # # save input 0 in a nifti file - # input_0_nifti = nib.Nifti1Image(input_0.numpy(), affine=np.eye(4)) - # nib.save(input_0_nifti, f"~/ms_lesion_agnostic/temp/input_0_{time_0}.nii.gz") - - # # save label in a nifti file - # label_nifti = nib.Nifti1Image(label_0.numpy(), affine=np.eye(4)) - # nib.save(label_nifti, f"~/ms_lesion_agnostic/temp/label_0_{time_0}.nii.gz") - - - # # check if any label image patch is empty in the batch - if check_empty_patch(labels) is None: - # print(f"Empty label patch found. Skipping training step ...") - return None - - output = self.forward(inputs) # logits - # print(f"labels.shape: {labels.shape} \t output.shape: {output.shape}") - - # get probabilities from logits - output = F.relu(output) / F.relu(output).max() if bool(F.relu(output).max()) else F.relu(output) - - # calculate training loss - loss = self.loss_function(output, labels) - - # calculate train loss for the sc and the lesion - loss_sc = self.loss_function(output[:, 0, ...], labels[:, 0, ...]) - loss_lesion = self.loss_function(output[:, 1, ...], labels[:, 1, ...]) - - # calculate train dice - # NOTE: this is done on patches (and not entire 3D volume) because SlidingWindowInference is not used here - # So, take this dice score with a lot of salt - train_soft_dice = self.soft_dice_metric(output, labels) - - # calculate the dice for the sc and the lesion - train_soft_dice_sc = self.soft_dice_metric(output[:, 0, ...], labels[:, 0, ...]) - train_soft_dice_lesion = self.soft_dice_metric(output[:, 1, ...], labels[:, 1, ...]) - - metrics_dict = { - "loss": loss.cpu(), - "loss_sc": loss_sc.cpu(), - "loss_lesion": loss_lesion.cpu(), - "train_soft_dice": train_soft_dice.detach().cpu(), - "train_soft_dice_sc": train_soft_dice_sc.detach().cpu(), - "train_soft_dice_lesion": train_soft_dice_lesion.detach().cpu(), - "train_number": len(inputs), - "train_image": inputs[0].detach().cpu().squeeze(), - "train_gt_sc": labels[0][0].detach().cpu().squeeze(), - "train_gt_lesion": labels[0][1].detach().cpu().squeeze(), - "train_pred_sc": output[0][0].detach().cpu().squeeze(), - "train_pred_lesion": output[0][1].detach().cpu().squeeze(), - - } - self.train_step_outputs.append(metrics_dict) - - return metrics_dict - - def on_train_epoch_end(self): - - if self.train_step_outputs == []: - # means the training step was skipped because of empty input patch - return None - else: - train_loss, train_soft_dice = 0, 0 - train_loss_sc, train_loss_lesion = 0, 0 - train_soft_dice_sc, train_soft_dice_lesion = 0, 0 - num_items = len(self.train_step_outputs) - for output in self.train_step_outputs: - train_loss += output["loss"].item() - train_soft_dice += output["train_soft_dice"].item() - train_loss_sc += output["loss_sc"].item() - train_loss_lesion += output["loss_lesion"].item() - train_soft_dice_sc += output["train_soft_dice_sc"].item() - train_soft_dice_lesion += output["train_soft_dice_lesion"].item() - - mean_train_loss = (train_loss / num_items) - mean_train_soft_dice = (train_soft_dice / num_items) - mean_train_loss_sc = (train_loss_sc / num_items) - mean_train_loss_lesion = (train_loss_lesion / num_items) - mean_train_soft_dice_sc = (train_soft_dice_sc / num_items) - mean_train_soft_dice_lesion = (train_soft_dice_lesion / num_items) - - wandb_logs = { - "train_soft_dice": mean_train_soft_dice, - "train_loss": mean_train_loss, - "train_loss_sc": mean_train_loss_sc, - "train_loss_lesion": mean_train_loss_lesion, - "train_soft_dice_sc": mean_train_soft_dice_sc, - "train_soft_dice_lesion": mean_train_soft_dice_lesion - } - self.log_dict(wandb_logs) - - # plot the training images - fig = plot_slices(image=self.train_step_outputs[0]["train_image"], - gt=self.train_step_outputs[0]["train_gt_lesion"], - pred=self.train_step_outputs[0]["train_pred_lesion"], - ) - wandb.log({"training images lesion": wandb.Image(fig)}) - plt.close(fig) - - # plot the training images - fig2 = plot_slices(image=self.train_step_outputs[0]["train_image"], - gt=self.train_step_outputs[0]["train_gt_sc"], - pred=self.train_step_outputs[0]["train_pred_sc"], - ) - wandb.log({"training images sc": wandb.Image(fig2)}) - plt.close(fig2) - - # free up memory - self.train_step_outputs.clear() - wandb_logs.clear() - - - - # -------------------------------- - # VALIDATION - # -------------------------------- - def validation_step(self, batch, batch_idx): - - inputs, labels = batch["image"], batch["outputs"] - - # NOTE: this calculates the loss on the entire image after sliding window - outputs = sliding_window_inference(inputs, self.inference_roi_size, mode="gaussian", - sw_batch_size=4, predictor=self.forward, overlap=0.5,) - - # get probabilities from logits - outputs = F.relu(outputs) / F.relu(outputs).max() if bool(F.relu(outputs).max()) else F.relu(outputs) - - # calculate validation loss - loss = self.loss_function(outputs, labels) - - - # post-process for calculating the evaluation metric - post_outputs = [self.val_post_pred(i) for i in decollate_batch(outputs)] - post_labels = [self.val_post_label(i) for i in decollate_batch(labels)] - val_soft_dice = self.soft_dice_metric(post_outputs[0], post_labels[0]) - - hard_preds, hard_labels = (post_outputs[0].detach() > 0.5).float(), (post_labels[0].detach() > 0.5).float() - val_hard_dice = self.soft_dice_metric(hard_preds, hard_labels) - - # NOTE: there was a massive memory leak when storing cuda tensors in this dict. Hence, - # using .detach() to avoid storing the whole computation graph - # Ref: https://discuss.pytorch.org/t/cuda-memory-leak-while-training/82855/2 - metrics_dict = { - "val_loss": loss.detach().cpu(), - "val_soft_dice": val_soft_dice.detach().cpu(), - "val_hard_dice": val_hard_dice.detach().cpu(), - "val_number": len(post_outputs), - "val_image_0": inputs[0].detach().cpu().squeeze(), - "val_gt_0": labels[0].detach().cpu().squeeze(), - "val_pred_0": post_outputs[0].detach().cpu().squeeze(), - # "val_image_1": inputs[1].detach().cpu().squeeze(), - # "val_gt_1": labels[1].detach().cpu().squeeze(), - # "val_pred_1": post_outputs[1].detach().cpu().squeeze(), - } - self.val_step_outputs.append(metrics_dict) - - return metrics_dict - - def on_validation_epoch_end(self): - - val_loss, num_items, val_soft_dice, val_hard_dice = 0, 0, 0, 0 - for output in self.val_step_outputs: - val_loss += output["val_loss"].sum().item() - val_soft_dice += output["val_soft_dice"].sum().item() - val_hard_dice += output["val_hard_dice"].sum().item() - num_items += output["val_number"] - - mean_val_loss = (val_loss / num_items) - mean_val_soft_dice = (val_soft_dice / num_items) - mean_val_hard_dice = (val_hard_dice / num_items) - - wandb_logs = { - "val_soft_dice": mean_val_soft_dice, - #"val_hard_dice": mean_val_hard_dice, - "val_loss": mean_val_loss, - } - - self.log_dict(wandb_logs) - - # save the best model based on validation dice score - if mean_val_soft_dice > self.best_val_dice: - self.best_val_dice = mean_val_soft_dice - self.best_val_epoch = self.current_epoch - - # save the best model based on validation loss - if mean_val_loss < self.best_val_loss: - self.best_val_loss = mean_val_loss - self.best_val_epoch = self.current_epoch - - logger.info( - f"\nCurrent epoch: {self.current_epoch}" - f"\nAverage Soft Dice (VAL): {mean_val_soft_dice:.4f}" - # f"\nAverage Hard Dice (VAL): {mean_val_hard_dice:.4f}" - f"\nAverage DiceLoss (VAL): {mean_val_loss:.4f}" - f"\nBest Average DiceLoss: {self.best_val_loss:.4f} at Epoch: {self.best_val_epoch}" - f"\n----------------------------------------------------") - - # log on to wandb - self.log_dict(wandb_logs) - - # # plot the validation images - # fig0 = plot_slices(image=self.val_step_outputs[0]["val_image_0"], - # gt=self.val_step_outputs[0]["val_gt_0"], - # pred=self.val_step_outputs[0]["val_pred_0"],) - # wandb.log({"validation images": wandb.Image(fig0)}) - # plt.close(fig0) - - - # free up memory - self.val_step_outputs.clear() - wandb_logs.clear() - - - # -------------------------------- - # TESTING - # -------------------------------- - def test_step(self, batch, batch_idx): - - test_input = batch["inputs"] - # print(batch["label_meta_dict"]["filename_or_obj"][0]) - batch["pred"] = sliding_window_inference(test_input, self.inference_roi_size, - sw_batch_size=4, predictor=self.forward, overlap=0.5) - - # normalize the logits - batch["pred"] = F.relu(batch["pred"]) / F.relu(batch["pred"]).max() if bool(F.relu(batch["pred"]).max()) else F.relu(batch["pred"]) - - post_test_out = [self.test_post_pred(i) for i in decollate_batch(batch)] - - # make sure that the shapes of prediction and GT label are the same - # print(f"pred shape: {post_test_out[0]['pred'].shape}, label shape: {post_test_out[0]['label'].shape}") - assert post_test_out[0]['pred'].shape == post_test_out[0]['label'].shape - - pred, label = post_test_out[0]['pred'].cpu(), post_test_out[0]['label'].cpu() - - # NOTE: Important point from the SoftSeg paper - binarize predictions before computing metrics - # calculate soft and hard dice here (for quick overview), other metrics can be computed from - # the saved predictions using ANIMA - # 1. Dice Score - test_soft_dice = self.soft_dice_metric(pred, label) - - # binarizing the predictions - pred = (post_test_out[0]['pred'].detach().cpu() > 0.5).float() - label = (post_test_out[0]['label'].detach().cpu() > 0.5).float() - - # 1.1 Hard Dice Score - test_hard_dice = self.soft_dice_metric(pred.numpy(), label.numpy()) - - metrics_dict = { - "test_hard_dice": test_hard_dice, - "test_soft_dice": test_soft_dice, - } - self.test_step_outputs.append(metrics_dict) - - return metrics_dict - - def on_test_epoch_end(self): - - avg_hard_dice_test, std_hard_dice_test = np.stack([x["test_hard_dice"] for x in self.test_step_outputs]).mean(), \ - np.stack([x["test_hard_dice"] for x in self.test_step_outputs]).std() - avg_soft_dice_test, std_soft_dice_test = np.stack([x["test_soft_dice"] for x in self.test_step_outputs]).mean(), \ - np.stack([x["test_soft_dice"] for x in self.test_step_outputs]).std() - - logger.info(f"Test (Soft) Dice: {avg_soft_dice_test}") - logger.info(f"Test (Hard) Dice: {avg_hard_dice_test}") - - self.avg_test_dice, self.std_test_dice = avg_soft_dice_test, std_soft_dice_test - self.avg_test_dice_hard, self.std_test_dice_hard = avg_hard_dice_test, std_hard_dice_test - - # free up memory - self.test_step_outputs.clear() - -# -------------------------------- -# MAIN -# -------------------------------- -def main(): - # get the parser - parser = get_parser() - args= parser.parse_args() - - # load config file - with open(args.config, "r") as f: - config = yaml.load(f, Loader=yaml.FullLoader) - - # Setting the seed - pl.seed_everything(config["seed"], workers=True) - - # define root path for finding datalists - dataset_root = config["data"] - - # define optimizer - optimizer_class = torch.optim.Adam - - wandb.init(project=f'monai-unet-ms-lesion-seg-canproco', config=config) - - logger.info("Defining plans for nnUNet model ...") - - - # define model - # TODO: make the model deeper - # net = UNet( - # spatial_dims=3, - # in_channels=1, - # out_channels=1, - # channels=config['unet_channels'], - # strides=config['unet_strides'], - # kernel_size=3, - # up_kernel_size=3, - # num_res_units=0, - # act='PRELU', - # norm=Norm.INSTANCE, - # dropout=0.0, - # bias=True, - # adn_ordering='NDA', - # ) - # net=UNet( - # spatial_dims=3, - # in_channels=1, - # out_channels=1, - # channels=(32, 64, 128, 256), - # strides=(2, 2, 2 ), - - # # dropout=0.1 - # ) - net = AttentionUnet( - spatial_dims=3, - in_channels=1, - out_channels=2, - channels=(32, 64, 128, 256, 512, 1024), - strides=(2, 2, 2, 2, 2), - dropout=0.1, - ) - # net = BasicUNet(spatial_dims=3, features=(32, 64, 128, 256, 32), out_channels=1) - - # net = create_nnunet_from_plans() - - logger.add(os.path.join(config["log_path"], str(datetime.now()) + 'log.txt'), rotation="10 MB", level="INFO") - - - # define loss function - #loss_func = AdapWingLoss(theta=0.5, omega=8, alpha=2.1, epsilon=1, reduction="sum") - # loss_func = DiceLoss(sigmoid=True, smooth_dr=1e-4) - loss_func = DiceCELoss(sigmoid=True, smooth_dr=1e-4) - # loss_func = SoftDiceLoss(smooth=1e-5) - # NOTE: tried increasing omega and decreasing epsilon but results marginally worse than the above - # loss_func = AdapWingLoss(theta=0.5, omega=12, alpha=2.1, epsilon=0.5, reduction="sum") - #logger.info(f"Using AdapWingLoss with theta={loss_func.theta}, omega={loss_func.omega}, alpha={loss_func.alpha}, epsilon={loss_func.epsilon} ...") - logger.info(f"Using SoftDiceLoss ...") - # define callbacks - early_stopping = pl.callbacks.EarlyStopping( - monitor="val_loss", min_delta=0.00, - patience=config["early_stopping_patience"], - verbose=False, mode="min") - - lr_monitor = pl.callbacks.LearningRateMonitor(logging_interval='epoch') - - # i.e. train by loading weights from scratch - pl_model = Model(config, data_root=dataset_root, - optimizer_class=optimizer_class, loss_function=loss_func, net=net, - exp_id="test", results_path=config["best_model_path"]) - - # saving the best model based on validation loss - checkpoint_callback_loss = pl.callbacks.ModelCheckpoint( - dirpath=config["best_model_path"], filename='best_model', monitor='val_loss', - save_top_k=1, mode="min", save_last=True, save_weights_only=True) - - - logger.info(f"Starting training from scratch ...") - # wandb logger - exp_logger = pl.loggers.WandbLogger( - name="test", - save_dir="/home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/results", - group="test-on-canproco", - log_model=True, # save best model using checkpoint callback - project='ms-lesion-agnostic', - entity='pierre-louis-benveniste', - config=config) - - # Saving training script to wandb - wandb.save("ms-lesion-agnostic/monai/nnunet/config_fake.yml") - wandb.save("ms-lesion-agnostic/monai/nnunet/train_monai_unet_lightning_regionBased.py") - - - # initialise Lightning's trainer. - trainer = pl.Trainer( - devices=1, accelerator="gpu", - logger=exp_logger, - callbacks=[checkpoint_callback_loss, lr_monitor, early_stopping], - check_val_every_n_epoch=config["eval_num"], - max_epochs=config["max_iterations"], - precision=32, - # deterministic=True, - enable_progress_bar=True) - # profiler="simple",) # to profile the training time taken for each step - - # Train! - trainer.fit(pl_model) - logger.info(f" Training Done!") - - # Closing wandb log - wandb.finish() - - -if __name__ == "__main__": - main() \ No newline at end of file From 1fdbdd35860df4e8e2932c3fb01da6b5827080d5 Mon Sep 17 00:00:00 2001 From: Pierre-Louis Benveniste Date: Mon, 22 Jul 2024 17:45:55 -0400 Subject: [PATCH 083/108] moved files to utils folder --- monai/{ => utils}/losses.py | 0 monai/{ => utils}/utils.py | 0 2 files changed, 0 insertions(+), 0 deletions(-) rename monai/{ => utils}/losses.py (100%) rename monai/{ => utils}/utils.py (100%) diff --git a/monai/losses.py b/monai/utils/losses.py similarity index 100% rename from monai/losses.py rename to monai/utils/losses.py diff --git a/monai/utils.py b/monai/utils/utils.py similarity index 100% rename from monai/utils.py rename to monai/utils/utils.py From 8292ff8cea997d506b4f644a660aeab49c73bc68 Mon Sep 17 00:00:00 2001 From: Pierre-Louis Benveniste Date: Tue, 23 Jul 2024 14:22:59 -0400 Subject: [PATCH 084/108] updated parameters for model testing --- monai/config_test.yml | 13 ++++++++++--- monai/test_model.py | 4 ++-- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/monai/config_test.yml b/monai/config_test.yml index 9c497e7..75a1a1d 100644 --- a/monai/config_test.yml +++ b/monai/config_test.yml @@ -1,8 +1,15 @@ # dataset : /home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/data/monai_data/dataset_2024-04-17_seed42_lesionOnly.json -dataset : /home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/data/monai_data/dataset_2024-04-05_seed42_lesionOnly.json +# dataset : /home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/data/monai_data/dataset_2024-04-05_seed42_lesionOnly.json +dataset : /home/plbenveniste/net/ms-lesion-agnostic/msd_data/dataset_2024-06-26_seed42_lesionOnly.json + pixdim : [0.7, 0.7, 0.7] spatial_size : [64, 128, 128] attention_unet_channels : [32, 64, 128, 256, 512] attention_unet_strides : [2, 2, 2, 2, 2] -path_to_model : /home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/results/2024-04-21_16:06:04.890513/best_model.pth/best_model.ckpt -output_dir : /home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/results/2024-04-21_16:06:04.890513/ \ No newline at end of file + +# path_to_model : /home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/results/2024-04-21_16:06:04.890513/best_model.pth/best_model.ckpt +path_to_model : /home/plbenveniste/net/ms-lesion-agnostic/results/2024-07-18_10:46:21.634514/best_model.pth/best_model.ckpt + +# output_dir : /home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/results/2024-04-21_16:06:04.890513/ +output_dir : /home/plbenveniste/net/ms-lesion-agnostic/results/2024-07-18_10:46:21.634514/ + diff --git a/monai/test_model.py b/monai/test_model.py index e34ae2b..3c9864a 100644 --- a/monai/test_model.py +++ b/monai/test_model.py @@ -20,7 +20,7 @@ import torch from monai.inferers import sliding_window_inference import torch.nn.functional as F -from utils import dice_score +from utils.utils import dice_score import argparse import yaml import torch.multiprocessing @@ -78,7 +78,7 @@ def main(): Spacingd( keys=["image", "label"], pixdim=cfg["pixdim"], - mode=(2, 1), + mode=(2, 0), ), NormalizeIntensityd( keys=["image"], From 7e361bac61157314b13a007ac7ca5c631bf543d5 Mon Sep 17 00:00:00 2001 From: Pierre-Louis Benveniste Date: Tue, 23 Jul 2024 15:53:42 -0400 Subject: [PATCH 085/108] updated inference script and evaluation plots scripts --- monai/plot_performance.py | 69 ++++++++++++++++++++++++++++++++++++--- monai/requirements.txt | 3 +- monai/test_model.py | 2 +- 3 files changed, 67 insertions(+), 7 deletions(-) diff --git a/monai/plot_performance.py b/monai/plot_performance.py index eefa4af..28dd871 100644 --- a/monai/plot_performance.py +++ b/monai/plot_performance.py @@ -8,6 +8,7 @@ import seaborn as sns import pandas as pd import argparse +import json def get_parser(): @@ -16,6 +17,8 @@ def get_parser(): """ parser = argparse.ArgumentParser(description="Plot the performance of the model") parser.add_argument("--pred-dir-path", help="Path to the directory containing the dice_score.txt file", required=True) + parser.add_argument("--data-json-path", help="Path to the json file containing the data split", required=True) + parser.add_argument("--split", help="Data split to use (train, validation, test)", required=True, type=str) return parser @@ -47,28 +50,84 @@ def main(): # convert to a df with name and dice score test_dice_results = pd.DataFrame(list(test_dice_results.items()), columns=['name', 'dice_score']) - # Add the contrats column - test_dice_results['contrast'] = test_dice_results['name'].apply(lambda x: x.split('_')[-1]) + # Create an empty column for the contrast, the site and the resolution + test_dice_results['contrast'] = None + test_dice_results['site'] = None + test_dice_results['resolution'] = None + + # Load the data json file + data_json_path = args.data_json_path + with open(data_json_path, 'r') as f: + jsondata = json.load(f) + + # Iterate over the test files + for file in test_dice_results['name']: + # We find the corresponding file in the json file + for data in jsondata[args.split]: + if data["image"] == file: + # Add the contrat, the site and the resolution to the df + test_dice_results.loc[test_dice_results['name'] == file, 'contrast'] = data['contrast'] + test_dice_results.loc[test_dice_results['name'] == file, 'site'] = data['site'] + test_dice_results.loc[test_dice_results['name'] == file, 'orientation'] = data['orientation'] # Count the number of samples per contrast contrast_counts = test_dice_results['contrast'].value_counts() # In the df replace the contrats by the number of samples of the contarsts( for example, T2 becomes T2 (n=10)) - test_dice_results['contrast'] = test_dice_results['contrast'].apply(lambda x: x + f' (n={contrast_counts[x]})') + test_dice_results['contrast_count'] = test_dice_results['contrast'].apply(lambda x: x + f' (n={contrast_counts[x]})') + + # Same for the site + site_counts = test_dice_results['site'].value_counts() + test_dice_results['site_count'] = test_dice_results['site'].apply(lambda x: x + f' (n={site_counts[x]})') + + # Same for the resolution + resolution_counts = test_dice_results['orientation'].value_counts() + test_dice_results['orientation_count'] = test_dice_results['orientation'].apply(lambda x: x + f' (n={resolution_counts[x]})') # plot a violin plot per contrast plt.figure(figsize=(20, 10)) plt.grid(True) - sns.violinplot(x='contrast', y='dice_score', data=test_dice_results) + sns.violinplot(x='contrast_count', y='dice_score', data=test_dice_results) # y ranges from -0.2 to 1.2 plt.ylim(-0.2, 1.2) plt.title('Dice scores per contrast') plt.show() # Save the plot - plt.savefig(path_to_outputs + '/dice_scores.png') + plt.savefig(path_to_outputs + '/dice_scores_contrast.png') print(f"Saved the dice_scores plot in {path_to_outputs}") + # plot a violin plot per site + plt.figure(figsize=(20, 10)) + plt.grid(True) + sns.violinplot(x='site_count', y='dice_score', data=test_dice_results) + # y ranges from -0.2 to 1.2 + plt.ylim(-0.2, 1.2) + plt.title('Dice scores per site') + plt.show() + + # Save the plot + plt.savefig(path_to_outputs + '/dice_scores_site.png') + print(f"Saved the dice_scores per site plot in {path_to_outputs}") + + # plot a violin plot per resolution + plt.figure(figsize=(20, 10)) + plt.grid(True) + sns.violinplot(x='orientation_count', y='dice_score', data=test_dice_results) + # y ranges from -0.2 to 1.2 + plt.ylim(-0.2, 1.2) + plt.title('Dice scores per orientation') + plt.show() + + # Save the plot + plt.savefig(path_to_outputs + '/dice_scores_orientation.png') + print(f"Saved the dice_scores per orientation plot in {path_to_outputs}") + + # Save the test_dice_results dataframe + test_dice_results.to_csv(path_to_outputs + '/dice_results.csv', index=False) + + return None + if __name__ == "__main__": main() \ No newline at end of file diff --git a/monai/requirements.txt b/monai/requirements.txt index 8a9c2f5..7b726e9 100644 --- a/monai/requirements.txt +++ b/monai/requirements.txt @@ -8,4 +8,5 @@ pytorch-lightning==2.2.1 cupy-cuda117==10.6.0 loguru==0.7.2 wandb==0.15.12 -dynamic-network-architectures==0.2 \ No newline at end of file +dynamic-network-architectures==0.2 +seaborn==0.13.2 \ No newline at end of file diff --git a/monai/test_model.py b/monai/test_model.py index 3c9864a..057badc 100644 --- a/monai/test_model.py +++ b/monai/test_model.py @@ -169,7 +169,7 @@ def main(): pred_saver(pred) # Save the dice score - dice_scores[file_name] = dice + dice_scores[test_files[i]["image"]] = dice test_input.detach() From bf0726209276e831e8e720b934a03d55161020a6 Mon Sep 17 00:00:00 2001 From: Pierre-Louis Benveniste Date: Tue, 23 Jul 2024 15:55:50 -0400 Subject: [PATCH 086/108] added removal of .nii.gz for UINT1 contrast --- monai/1_create_msd_data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/1_create_msd_data.py b/monai/1_create_msd_data.py index df027e0..ddce188 100644 --- a/monai/1_create_msd_data.py +++ b/monai/1_create_msd_data.py @@ -216,7 +216,7 @@ def main(): temp_data_basel["total_lesion_volume"] = total_lesion_volume temp_data_basel["nb_lesions"] = nb_lesions temp_data_basel["site"]='basel' - temp_data_basel["contrast"] = str(derivative).replace('_desc-rater3_label-lesion_seg.nii.gz', '.nii.gz').split('_')[-1] + temp_data_basel["contrast"] = str(derivative).replace('_desc-rater3_label-lesion_seg.nii.gz', '.nii.gz').split('_')[-1].replace('.nii.gz', '') temp_data_basel["orientation"] = get_orientation(temp_data_basel["image"]) if args.lesion_only and nb_lesions == 0: continue From 9f5587a43a26b97d95305a9ab343e7d6f1690db3 Mon Sep 17 00:00:00 2001 From: Pierre-Louis Benveniste Date: Thu, 1 Aug 2024 15:52:59 -0400 Subject: [PATCH 087/108] changed workers to 0 for test_model --- monai/test_model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/test_model.py b/monai/test_model.py index 057badc..17cd80b 100644 --- a/monai/test_model.py +++ b/monai/test_model.py @@ -103,8 +103,8 @@ def main(): ]) # Create the data loader - test_ds = CacheDataset(data=test_files, transform=test_transforms, cache_rate=0.1, num_workers=4) - test_data_loader = DataLoader(test_ds, batch_size=1, shuffle=False, num_workers=1) + test_ds = CacheDataset(data=test_files, transform=test_transforms, cache_rate=0.1, num_workers=0) + test_data_loader = DataLoader(test_ds, batch_size=1, shuffle=False, num_workers=0) # Load the model net = AttentionUnet( From c23d63f27d9510a7ae94d1fa2711e09182e44eba Mon Sep 17 00:00:00 2001 From: Pierre-Louis Benveniste Date: Thu, 1 Aug 2024 15:53:14 -0400 Subject: [PATCH 088/108] added more info in output --- monai/plot_performance.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/monai/plot_performance.py b/monai/plot_performance.py index 28dd871..7b586c5 100644 --- a/monai/plot_performance.py +++ b/monai/plot_performance.py @@ -69,6 +69,8 @@ def main(): test_dice_results.loc[test_dice_results['name'] == file, 'contrast'] = data['contrast'] test_dice_results.loc[test_dice_results['name'] == file, 'site'] = data['site'] test_dice_results.loc[test_dice_results['name'] == file, 'orientation'] = data['orientation'] + test_dice_results.loc[test_dice_results['name'] == file, 'nb_lesions'] = data['nb_lesions'] + test_dice_results.loc[test_dice_results['name'] == file, 'total_lesion_volume'] = data['total_lesion_volume'] # Count the number of samples per contrast contrast_counts = test_dice_results['contrast'].value_counts() From 034d97ec459e8568881ab57d263aa3c6c964cb55 Mon Sep 17 00:00:00 2001 From: Pierre-Louis Benveniste Date: Thu, 1 Aug 2024 15:53:29 -0400 Subject: [PATCH 089/108] created file for cropping aroung head --- monai/1_create_msd_data_head_cropped.py | 395 ++++++++++++++++++++++++ 1 file changed, 395 insertions(+) create mode 100644 monai/1_create_msd_data_head_cropped.py diff --git a/monai/1_create_msd_data_head_cropped.py b/monai/1_create_msd_data_head_cropped.py new file mode 100644 index 0000000..d53cd49 --- /dev/null +++ b/monai/1_create_msd_data_head_cropped.py @@ -0,0 +1,395 @@ +""" +This file creates the MSD-style JSON datalist to train an nnunet model using monai. +The datasets used are CanProCo, Bavaria-quebec, basel and sct-testing-large. + +Arguments: + -pd, --path-data: Path to the data set directory + -po, --path-out: Path to the output directory where dataset json is saved + --lesion-only: Use only masks which contain some lesions + --seed: Seed for reproducibility + --canproco-exclude: Path to the file containing the list of subjects to exclude from CanProCo + +Example: + python 1_create_msd_data.py -pd /path/dataset -po /path/output --lesion-only --seed 42 --canproco-exclude /path/exclude_list.txt + +TO DO: + * + +Pierre-Louis Benveniste +""" + +import os +import json +from tqdm import tqdm +import yaml +import argparse +from loguru import logger +from sklearn.model_selection import train_test_split +from datetime import date +from pathlib import Path +import nibabel as nib +import numpy as np +import skimage +from utils.image import Image + + +def get_parser(): + """ + Get parser for script create_msd_data.py + + Input: + None + + Returns: + parser : argparse object + """ + + parser = argparse.ArgumentParser(description='Code for MSD-style JSON datalist for lesion-agnostic nnunet model training.') + + parser.add_argument('-pd', '--path-data', required=True, type=str, help='Path to the folder containing the datasets') + parser.add_argument('-po', '--path-out', type=str, help='Path to the output directory where dataset json is saved') + parser.add_argument('--canproco-exclude', type=str, help='Path to the file containing the list of subjects to exclude from CanProCo') + parser.add_argument('--lesion-only', action='store_true', help='Use only masks which contain some lesions') + parser.add_argument('--seed', default=42, type=int, help="Seed for reproducibility") + + return parser + + +def count_lesion(label_file): + """ + This function takes a label file and counts the number of lesions in it. + + Input: + label_file : str : Path to the label file + + Returns: + count : int : Number of lesions in the label file + total_volume : float : Total volume of lesions in the label file + """ + + label = nib.load(label_file) + label_data = label.get_fdata() + + # get the total volume of the lesions + total_volume = np.sum(label_data) + resolution = label.header.get_zooms() + total_volume = total_volume * np.prod(resolution) + + # get the number of lesions + _, nb_lesions = skimage.measure.label(label_data, connectivity=2, return_num=True) + + return total_volume, nb_lesions + + +def get_orientation(image_path): + """ + This function takes an image file as input and returns its orientation. + + Input: + image_path : str : Path to the image file + + Returns: + orientation : str : Orientation of the image + """ + img = Image(str(image_path)) + img.change_orientation('RPI') + # Get pixdim + pixdim = img.dim[4:7] + # If all are the same, the image is isotropic + if np.allclose(pixdim, pixdim[0], atol=1e-3): + orientation = 'iso' + return orientation + # Elif, the lowest arg is 0 then the orientation is sagittal + elif np.argmax(pixdim) == 0: + orientation = 'sag' + # Elif, the lowest arg is 1 then the orientation is coronal + elif np.argmax(pixdim) == 1: + orientation = 'cor' + # Else the orientation is axial + else: + orientation = 'ax' + return orientation + + +def cropping_saving(image_path, label_path, cropped_head_data_folder): + """ + This function does the following action successively: + - copy image and label to the output folder for cropped head data + - segments the spinal cord on the image + - crops the image and label to the remove the superior part of the head (what is above the seg of the spinal cord) + - save the cropped image and label in the output folder + + Input: + image_path : str : Path to the image file + label_path : str : Path to the label file + cropped_head_data_folder : str : Path to the output folder + + Returns: + image_cropped : str : Path to the cropped image + seg_cropped : str : Path to the cropped label + """ + + # Copy image and label to the output folder for cropped head data + image_cropped = os.path.join(cropped_head_data_folder, image_path.split('/')[-1]) + seg_cropped = os.path.join(cropped_head_data_folder, label_path.split('/')[-1]) + img = Image(image_path) + img.change_orientation('RPI') + img.save(image_cropped) + seg = Image(label_path) + seg.change_orientation('RPI') + seg.save(seg_cropped) + + # Segment the spinal cord on the image + ## Create a temporary folder + temp_folder = os.path.join(cropped_head_data_folder, "temp") + os.makedirs(temp_folder, exist_ok=True) + ## Segment the spinal cord + os.system(f"sct_deepseg -i {image_cropped} -o {os.path.join(temp_folder, 'seg.nii.gz')} -task seg_sc_contrast_agnostic -thr 0.5") + ## Get the highest point of the spinal cord + spinal_cord_seg = Image(os.path.join(temp_folder, 'seg.nii.gz')) + spinal_cord_seg.change_orientation('RPI') + spinal_cord_seg_data = spinal_cord_seg.data + spinal_cord_superior = np.max(np.where(spinal_cord_seg_data == 1)[2]) + ## Remove the temporary folder + os.system(f"rm -rf {temp_folder}") + + # Crop the image and label to the remove the superior part of the head (what is above the seg of the spinal cord) + os.system(f"sct_crop_image -i {image_cropped} -o {image_cropped} -zmax {spinal_cord_superior}") + os.system(f"sct_crop_image -i {seg_cropped} -o {seg_cropped} -zmax {spinal_cord_superior}") + + return image_cropped, seg_cropped + + +def main(): + """ + This is the main function of the script. + + Input: + None + + Returns: + None + """ + # Get the arguments + parser = get_parser() + args = parser.parse_args() + + root = args.path_data + seed = args.seed + + # Get all subjects + basel_path = Path(os.path.join(root, "basel-mp2rage")) + bavaria_path = Path(os.path.join(root, "bavaria-quebec-spine-ms-unstitched")) + canproco_path = Path(os.path.join(root, "canproco")) + nih_path = Path(os.path.join(root, "nih-ms-mp2rage")) + sct_testing_path = Path(os.path.join(root, "sct-testing-large")) + + derivatives_basel = list(basel_path.rglob('*_desc-rater3_label-lesion_seg.nii.gz')) + derivatives_bavaria = list(bavaria_path.rglob('*_lesion-manual.nii.gz')) + derivatives_canproco = list(canproco_path.rglob('*_lesion-manual.nii.gz')) + derivatives_nih = list(nih_path.rglob('*_desc-rater1_label-lesion_seg.nii.gz')) + derivatives_sct = list(sct_testing_path.rglob('*_lesion-manual.nii.gz')) + + # Make the folder for the cropped images + cropped_head_data_folder = os.path.join(args.path_out, "cropped_head_data") + os.makedirs(args.path_out, exist_ok=True) + os.makedirs(cropped_head_data_folder, exist_ok=True) + os.makedirs(os.path.join(cropped_head_data_folder, "basel-mp2rage"), exist_ok=True) + os.makedirs(os.path.join(cropped_head_data_folder, "bavaria-quebec-spine-ms-unstitched"), exist_ok=True) + os.makedirs(os.path.join(cropped_head_data_folder, "canproco"), exist_ok=True) + os.makedirs(os.path.join(cropped_head_data_folder, "nih-ms-mp2rage"), exist_ok=True) + os.makedirs(os.path.join(cropped_head_data_folder, "sct-testing-large"), exist_ok=True) + + # Path to the file containing the list of subjects to exclude from CanProCo + if args.canproco_exclude is not None: + with open(args.canproco_exclude, 'r') as file: + canproco_exclude_list = yaml.load(file, Loader=yaml.FullLoader) + # only keep the contrast psir and stir + canproco_exclude_list = canproco_exclude_list['PSIR'] + canproco_exclude_list['STIR'] + + derivatives = derivatives_basel + derivatives_bavaria + derivatives_canproco + derivatives_nih + derivatives_sct + logger.info(f"Total number of derivatives in the root directory: {len(derivatives)}") + + # create one json file with 60-20-20 train-val-test split + train_ratio, val_ratio, test_ratio = 0.6, 0.2, 0.2 + train_derivatives, test_derivatives = train_test_split(derivatives, test_size=test_ratio, random_state=args.seed) + # Use the training split to further split into training and validation splits + train_derivatives, val_derivatives = train_test_split(train_derivatives, test_size=val_ratio / (train_ratio + val_ratio), + random_state=args.seed, ) + # sort the subjects + train_derivatives = sorted(train_derivatives) + val_derivatives = sorted(val_derivatives) + test_derivatives = sorted(test_derivatives) + + # logger.info(f"Number of training subjects: {len(train_subjects)}") + # logger.info(f"Number of validation subjects: {len(val_subjects)}") + # logger.info(f"Number of testing subjects: {len(test_subjects)}") + + # dump train/val/test splits into a yaml file + with open(f"{args.path_out}/data_split_{str(date.today())}_seed{seed}.yaml", 'w') as file: + yaml.dump({'train': train_derivatives, 'val': val_derivatives, 'test': test_derivatives}, file, indent=2, sort_keys=True) + + # keys to be defined in the dataset_0.json + params = {} + params["description"] = "ms-lesion-agnostic" + params["labels"] = { + "0": "background", + "1": "ms-lesion-seg" + } + params["license"] = "plb" + params["modality"] = { + "0": "MRI" + } + params["name"] = "ms-lesion-agnostic" + params["seed"] = args.seed + params["reference"] = "NeuroPoly" + params["tensorImageSize"] = "3D" + + train_derivatives_dict = {"train": train_derivatives} + val_derivatives_dict = {"validation": val_derivatives} + test_derivatives_dict = {"test": test_derivatives} + all_derivatives_list = [train_derivatives_dict, val_derivatives_dict, test_derivatives_dict] + + # iterate through the train/val/test splits and add those which have both image and label + for derivatives_dict in tqdm(all_derivatives_list, desc="Iterating through train/val/test splits"): + + for name, derivs_list in derivatives_dict.items(): + + temp_list = [] + for subject_no, derivative in enumerate(derivs_list): + + + temp_data_basel = {} + temp_data_bavaria = {} + temp_data_canproco = {} + temp_data_nih = {} + temp_data_sct = {} + + # Basel + if 'basel-mp2rage' in str(derivative): + relative_path = derivative.relative_to(basel_path).parent + temp_data_basel["label"] = str(derivative) + temp_data_basel["image"] = str(derivative).replace('_desc-rater3_label-lesion_seg.nii.gz', '.nii.gz').replace('derivatives/labels/', '') + if os.path.exists(temp_data_basel["label"]) and os.path.exists(temp_data_basel["image"]): + # Cropping image and seg and saving to the cropped_head_data folder + image, seg = cropping_saving(temp_data_basel["image"], temp_data_basel["label"], os.path.join(cropped_head_data_folder, "basel-mp2rage")) + temp_data_basel["label"] = seg + temp_data_basel["image"] = image + + total_lesion_volume, nb_lesions = count_lesion(temp_data_basel["label"]) + temp_data_basel["total_lesion_volume"] = total_lesion_volume + temp_data_basel["nb_lesions"] = nb_lesions + temp_data_basel["site"]='basel' + temp_data_basel["contrast"] = str(derivative).replace('_desc-rater3_label-lesion_seg.nii.gz', '.nii.gz').split('_')[-1].replace('.nii.gz', '') + temp_data_basel["orientation"] = get_orientation(temp_data_basel["image"]) + if args.lesion_only and nb_lesions == 0: + continue + temp_list.append(temp_data_basel) + + # Bavaria-quebec + elif 'bavaria-quebec-spine-ms' in str(derivative): + temp_data_bavaria["label"] = str(derivative) + temp_data_bavaria["image"] = str(derivative).replace('_lesion-manual.nii.gz', '.nii.gz').replace('derivatives/labels/', '') + if os.path.exists(temp_data_bavaria["label"]) and os.path.exists(temp_data_bavaria["image"]): + # Cropping image and seg and saving to the cropped_head_data folder + image, seg = cropping_saving(temp_data_bavaria["image"], temp_data_bavaria["label"], os.path.join(cropped_head_data_folder, "bavaria-quebec-spine-ms-unstitched")) + temp_data_bavaria["label"] = seg + temp_data_bavaria["image"] = image + + total_lesion_volume, nb_lesions = count_lesion(temp_data_bavaria["label"]) + temp_data_bavaria["total_lesion_volume"] = total_lesion_volume + temp_data_bavaria["nb_lesions"] = nb_lesions + temp_data_bavaria["site"]='bavaria-quebec' + temp_data_bavaria["contrast"] = str(derivative).replace('_lesion-manual.nii.gz', '.nii.gz').split('_')[-1].replace('.nii.gz', '') + temp_data_bavaria["orientation"] = get_orientation(temp_data_bavaria["image"]) + if args.lesion_only and nb_lesions == 0: + continue + temp_list.append(temp_data_bavaria) + + # Canproco + elif 'canproco' in str(derivative): + subject_id = derivative.name.replace('_PSIR_lesion-manual.nii.gz', '') + subject_id = subject_id.replace('_STIR_lesion-manual.nii.gz', '') + if subject_id in canproco_exclude_list: + continue + temp_data_canproco["label"] = str(derivative) + temp_data_canproco["image"] = str(derivative).replace('_lesion-manual.nii.gz', '.nii.gz').replace('derivatives/labels/', '') + if os.path.exists(temp_data_canproco["label"]) and os.path.exists(temp_data_canproco["image"]): + # Cropping image and seg and saving to the cropped_head_data folder + image, seg = cropping_saving(temp_data_canproco["image"], temp_data_canproco["label"], os.path.join(cropped_head_data_folder, "canproco")) + temp_data_canproco["label"] = seg + temp_data_canproco["image"] = image + + total_lesion_volume, nb_lesions = count_lesion(temp_data_canproco["label"]) + temp_data_canproco["total_lesion_volume"] = total_lesion_volume + temp_data_canproco["nb_lesions"] = nb_lesions + temp_data_canproco["site"]='canproco' + temp_data_canproco["contrast"] = str(derivative).replace('_lesion-manual.nii.gz', '.nii.gz').split('_')[-1].replace('.nii.gz', '') + temp_data_canproco["orientation"] = get_orientation(temp_data_canproco["image"]) + if args.lesion_only and nb_lesions == 0: + continue + temp_list.append(temp_data_canproco) + + # nih-ms-mp2rage + elif 'nih-ms-mp2rage' in str(derivative): + temp_data_nih["label"] = str(derivative) + temp_data_nih["image"] = str(derivative).replace('_desc-rater1_label-lesion_seg.nii.gz', '.nii.gz').replace('derivatives/labels/', '') + if os.path.exists(temp_data_nih["label"]) and os.path.exists(temp_data_nih["image"]): + # Cropping image and seg and saving to the cropped_head_data folder + image, seg = cropping_saving(temp_data_nih["image"], temp_data_nih["label"], os.path.join(cropped_head_data_folder, "nih-ms-mp2rage")) + temp_data_nih["label"] = seg + temp_data_nih["image"] = image + + total_lesion_volume, nb_lesions = count_lesion(temp_data_nih["label"]) + temp_data_nih["total_lesion_volume"] = total_lesion_volume + temp_data_nih["nb_lesions"] = nb_lesions + temp_data_nih["site"]='nih' + temp_data_nih["contrast"] = str(derivative).replace('_desc-rater1_label-lesion_seg.nii.gz', '.nii.gz').split('_')[-1].replace('.nii.gz', '') + temp_data_nih["orientation"] = get_orientation(temp_data_nih["image"]) + if args.lesion_only and nb_lesions == 0: + continue + temp_list.append(temp_data_nih) + + # sct-testing-large + elif 'sct-testing-large' in str(derivative): + temp_data_sct["label"] = str(derivative) + temp_data_sct["image"] = str(derivative).replace('_lesion-manual.nii.gz', '.nii.gz').replace('derivatives/labels/', '') + if os.path.exists(temp_data_sct["label"]) and os.path.exists(temp_data_sct["image"]): + # Cropping image and seg and saving to the cropped_head_data folder + image, seg = cropping_saving(temp_data_sct["image"], temp_data_sct["label"], os.path.join(cropped_head_data_folder, "sct-testing-large")) + temp_data_sct["label"] = seg + temp_data_sct["image"] = image + + total_lesion_volume, nb_lesions = count_lesion(temp_data_sct["label"]) + temp_data_sct["total_lesion_volume"] = total_lesion_volume + temp_data_sct["nb_lesions"] = nb_lesions + temp_data_sct["site"]='sct-testing-large' + temp_data_sct["contrast"] = str(derivative).replace('_lesion-manual.nii.gz', '.nii.gz').split('_')[-1].replace('.nii.gz', '') + temp_data_sct["orientation"] = get_orientation(temp_data_sct["image"]) + if args.lesion_only and nb_lesions == 0: + continue + temp_list.append(temp_data_sct) + + params[name] = temp_list + logger.info(f"Number of images in {name} set: {len(temp_list)}") + params["numTest"] = len(params["test"]) + params["numTraining"] = len(params["train"]) + params["numValidation"] = len(params["validation"]) + # Print total number of images + logger.info(f"Total number of images in the dataset: {params['numTest'] + params['numTraining'] + params['numValidation']}") + + final_json = json.dumps(params, indent=4, sort_keys=True) + if not os.path.exists(args.path_out): + os.makedirs(args.path_out, exist_ok=True) + if args.lesion_only: + jsonFile = open(args.path_out + "/" + f"dataset_{str(date.today())}_seed{seed}_lesionOnly.json", "w") + else: + jsonFile = open(args.path_out + "/" + f"dataset_{str(date.today())}_seed{seed}.json", "w") + jsonFile.write(final_json) + jsonFile.close() + + return None + + +if __name__ == "__main__": + main() \ No newline at end of file From 00ec30c36c79cadc313b3f1c5eacc88765e1737c Mon Sep 17 00:00:00 2001 From: Pierre-Louis Benveniste Date: Fri, 9 Aug 2024 16:56:50 -0400 Subject: [PATCH 090/108] updated training script to sota model training script (set workers to 0) --- monai/train_monai_unet_lightning.py | 258 +++++++++++++--------------- 1 file changed, 123 insertions(+), 135 deletions(-) diff --git a/monai/train_monai_unet_lightning.py b/monai/train_monai_unet_lightning.py index bda010d..4685086 100644 --- a/monai/train_monai_unet_lightning.py +++ b/monai/train_monai_unet_lightning.py @@ -12,15 +12,16 @@ import torch.nn.functional as F import matplotlib.pyplot as plt import time +import torch.multiprocessing # Added this to solve problem with too many files open ## Link here : https://github.com/pytorch/pytorch/issues/11201#issuecomment-421146936 -import torch.multiprocessing +## Linke to other issue: https://github.com/sct-pipeline/contrast-agnostic-softseg-spinalcord/issues/59 torch.multiprocessing.set_sharing_strategy('file_system') -from losses import AdapWingLoss, SoftDiceLoss +from utils.losses import AdapWingLoss, SoftDiceLoss -from utils import dice_score, check_empty_patch, multiply_by_negative_one, plot_slices, remove_small_lesions +from utils.utils import dice_score, check_empty_patch, multiply_by_negative_one, plot_slices, remove_small_lesions from monai.networks.nets import UNet, BasicUNet, AttentionUnet, SwinUNETR from monai.metrics import DiceMetric from monai.losses import DiceLoss, DiceCELoss @@ -142,11 +143,10 @@ def prepare_data(self): # define training and validation transforms train_transforms = Compose( - [ + [ LoadImaged(keys=["image", "label"], reader="NibabelReader"), EnsureChannelFirstd(keys=["image", "label"]), Orientationd(keys=["image", "label"], axcodes="RPI"), - # This changes the spacing of the image Spacingd( keys=["image", "label"], pixdim=self.cfg["pixdim"], @@ -158,21 +158,47 @@ def prepare_data(self): nonzero=False, channel_wise=False ), + # # This crops the image around areas where the mask is non-zero + # # (the margin is added because otherwise the image would be just the size of the lesion) + # CropForegroundd( + # keys=["image", "label"], + # source_key="label", + # margin=200 + # ), + # This crops the image around a foreground object of label with ratio pos/(pos+neg) (however, it cannot pad so keeping padding after) + RandCropByPosNegLabeld( + keys=["image", "label"], + label_key="label", + spatial_size=self.cfg["spatial_size"], + pos=1, + neg=0, + num_samples=4, + image_key="image", + image_threshold=0, + allow_smaller=True, + ), # This resizes the image and the label to the spatial size defined in the config ResizeWithPadOrCropd( keys=["image", "label"], spatial_size=self.cfg["spatial_size"], ), - # Spatial transforms - # Random rotation of the image - RandRotated( + # Flips the image : left becomes right + RandFlipd( keys=["image", "label"], - range_x=np.pi, - range_y=np.pi, - range_z=np.pi, + spatial_axis=[0], + prob=self.cfg["DA_probability"], + ), + # Flips the image : supperior becomes inferior + RandFlipd( + keys=["image", "label"], + spatial_axis=[1], + prob=self.cfg["DA_probability"], + ), + # Flips the image : anterior becomes posterior + RandFlipd( + keys=["image", "label"], + spatial_axis=[2], prob=self.cfg["DA_probability"], - keep_size=True, - mode=('bilinear', 'nearest'), ), # Random elastic deformation Rand3DElasticd( @@ -182,15 +208,6 @@ def prepare_data(self): prob=self.cfg["DA_probability"], mode=['bilinear', 'nearest'], ), - # Changes the spacing of the image - RandZoomd( - keys=["image", "label"], - prob=self.cfg["DA_probability"], - min_zoom=0.75, - max_zoom=1.25, - mode=('bilinear', 'nearest'), - keep_size=True, - ), # Random affine transform of the image RandAffined( keys=["image", "label"], @@ -198,49 +215,31 @@ def prepare_data(self): mode=('bilinear', 'nearest'), padding_mode='zeros', ), - # Intensity transforms - # Random Gaussian noise is added to the image + # RandAdjustContrastd( + # keys=["image"], + # prob=self.cfg["DA_probability"], + # gamma=(0.5, 4.5), + # invert_image=True, + # ), + # # we add the multiplication of the image by -1 + # RandLambdad( + # keys='image', + # func=multiply_by_negative_one, + # prob=0.5 + # ), + # LabelToContourd( + # keys=["image"], + # kernel_type='Laplace', + # ), RandGaussianNoised( keys=["image"], prob=self.cfg["DA_probability"], - mean=0.0, - std=0.1, - ), - # Gaussian blur with RandGaussianSmoothd - RandGaussianSmoothd( - keys=["image"], - prob=self.cfg["DA_probability"], - sigma_x=(0.5, 1.), - sigma_y=(0.5, 1.), - sigma_z=(0.5, 1.), - ), - # Brightness transform: with RandScaleIntensityd - RandScaleIntensityd( - keys=["image"], - prob=self.cfg["DA_probability"], - factors=0.25, - ), - # Contrast transform: with RandAdjustContrastd - RandAdjustContrastd( - keys=["image"], - prob=self.cfg["DA_probability"], - invert_image=True, - retain_stats=True, - ), - # Contrast transform: with RandAdjustContrastd - RandAdjustContrastd( - keys=["image"], - prob=self.cfg["DA_probability"], - invert_image=False, - retain_stats=True, ), - # Simulate low resolution with RandSimulateLowResolutiond + # Random simulation of low resolution RandSimulateLowResolutiond( keys=["image"], - prob=self.cfg["DA_probability"], - downsample_mode='nearest', - upsample_mode='trilinear', - zoom_range=(0.5, 1.0), + zoom_range=(0.8, 1.5), + prob=self.cfg["DA_probability"] ), # Adding a random bias field which is usefull considering that this sometimes done for image pre-processing RandBiasFieldd( @@ -249,58 +248,24 @@ def prepare_data(self): degree=3, prob=self.cfg["DA_probability"] ), - # Binary thresholding of the label - AsDiscreted( - keys=["label"], - threshold=0.5, - ), - - - # # This crops the image around areas where the mask is non-zero - # # (the margin is added because otherwise the image would be just the size of the lesion) - # CropForegroundd( - # keys=["image", "label"], - # source_key="label", - # margin=200 - # ), - # # This crops the image around a foreground object of label with ratio pos/(pos+neg) (however, it cannot pad so keeping padding after) - # RandCropByPosNegLabeld( - # keys=["image", "label"], - # label_key="label", - # spatial_size=self.cfg["spatial_size"], - # pos=1, - # neg=0, - # num_samples=4, - # image_key="image", - # image_threshold=0, - # allow_smaller=True, - # ), - # Multiplication of image by -1 - # RandLambdad( - # keys='image', - # func=multiply_by_negative_one, - # prob=0.5 - # ), - # Takes the laplacian of the image - # LabelToContourd( + # RandShiftIntensityd( # keys=["image"], - # kernel_type='Laplace', + # offsets=0.1, + # prob=0.2, # ), - # EnsureTyped(keys=["image", "label"]), - + # AsDiscreted( + # keys=["label"], + # num_classes=2, + # threshold_values=True, + # logit_thresh=0.2, + # ), # # Remove small lesions in the label # RandLambdad( # keys='label', # func=lambda label: remove_small_lesions(label, self.cfg["pixdim"]), # prob=1.0 # ) - - # This resizes the image and the label to the spatial size defined in the config - ResizeWithPadOrCropd( - keys=["image", "label"], - spatial_size=self.cfg["spatial_size"], - ), ] ) val_transforms = Compose( @@ -313,28 +278,50 @@ def prepare_data(self): pixdim=self.cfg["pixdim"], mode=(2, 0), ), + # This normalizes the intensity of the image NormalizeIntensityd( keys=["image"], nonzero=False, channel_wise=False ), + # CropForegroundd( + # keys=["image", "label"], + # source_key="label", + # margin=150), + # RandCropByPosNegLabeld( + # keys=["image", "label"], + # label_key="label", + # spatial_size=self.cfg["spatial_size"], + # pos=1, + # neg=1, + # num_samples=4, + # image_key="image", + # image_threshold=0, + # allow_smaller=True, + # ), ResizeWithPadOrCropd( keys=["image", "label"], spatial_size=self.cfg["spatial_size"], ), - # Binary thresholding of the label - AsDiscreted( - keys=["label"], - threshold=0.5, - ), - # # This normalizes the intensity of the image - # NormalizeIntensityd( - # keys=["image"], - # nonzero=False, - # channel_wise=False + # LabelToContourd( + # keys=["image"], + # kernel_type='Laplace', # ), # EnsureTyped(keys=["image", "label"]), + # AsDiscreted( + # keys=["label"], + # num_classes=2, + # threshold_values=True, + # logit_thresh=0.2, + # ) + # # Remove small lesions in the label + # RandLambdad( + # keys='label', + # func=lambda label: remove_small_lesions(label, self.cfg["pixdim"]), + # prob=1.0 + # ) ] + ) # load the dataset @@ -353,22 +340,22 @@ def prepare_data(self): transforms_test = val_transforms # Hidden because we don't use it - # # define post-processing transforms for testing; taken (with explanations) from - # # https://github.com/Project-MONAI/tutorials/blob/main/3d_segmentation/torch/unet_inference_dict.py#L66 - # self.test_post_pred = Compose([ - # EnsureTyped(keys=["pred", "label"]), - # Invertd(keys=["pred", "label"], transform=transforms_test, - # orig_keys=["image", "label"], - # meta_keys=["pred_meta_dict", "label_meta_dict"], - # nearest_interp=False, to_tensor=True), - # # # Remove small lesions in the label - # # RandLambdad( - # # keys='label', - # # func=lambda label: remove_small_lesions(label, self.cfg["pixdim"]), - # # prob=1.0 - # # ) - # ]) - # self.test_ds = CacheDataset(data=test_files, transform=transforms_test, cache_rate=0.1, num_workers=4) + # define post-processing transforms for testing; taken (with explanations) from + # https://github.com/Project-MONAI/tutorials/blob/main/3d_segmentation/torch/unet_inference_dict.py#L66 + self.test_post_pred = Compose([ + EnsureTyped(keys=["pred", "label"]), + Invertd(keys=["pred", "label"], transform=transforms_test, + orig_keys=["image", "label"], + meta_keys=["pred_meta_dict", "label_meta_dict"], + nearest_interp=False, to_tensor=True), + # # Remove small lesions in the label + # RandLambdad( + # keys='label', + # func=lambda label: remove_small_lesions(label, self.cfg["pixdim"]), + # prob=1.0 + # ) + ]) + self.test_ds = CacheDataset(data=test_files, transform=transforms_test, cache_rate=0.1, num_workers=4) # -------------------------------- @@ -378,14 +365,16 @@ def train_dataloader(self): return DataLoader(self.train_ds, batch_size=self.cfg["batch_size"], shuffle=True, num_workers=8, pin_memory=True, persistent_workers=True) + def val_dataloader(self): - return DataLoader(self.val_ds, batch_size=1, shuffle=False, num_workers=4, pin_memory=True, - persistent_workers=True) + return DataLoader(self.val_ds, batch_size=1, shuffle=False, num_workers=0, pin_memory=True, + persistent_workers=False) + def test_dataloader(self): return DataLoader(self.test_ds, batch_size=1, shuffle=False, num_workers=1, pin_memory=True) - + # -------------------------------- # OPTIMIZATION # -------------------------------- @@ -457,6 +446,7 @@ def training_step(self, batch, batch_idx): return metrics_dict + def on_train_epoch_end(self): if self.train_step_outputs == []: @@ -517,7 +507,6 @@ def validation_step(self, batch, batch_idx): # calculate validation loss loss = self.loss_function(outputs, labels) - # post-process for calculating the evaluation metric post_outputs = [self.val_post_pred(i) for i in decollate_batch(outputs)] post_labels = [self.val_post_label(i) for i in decollate_batch(labels)] @@ -658,6 +647,7 @@ def test_step(self, batch, batch_idx): return metrics_dict + def on_test_epoch_end(self): avg_hard_dice_test, std_hard_dice_test = np.stack([x["test_hard_dice"] for x in self.test_step_outputs]).mean(), \ @@ -674,6 +664,7 @@ def on_test_epoch_end(self): # free up memory self.test_step_outputs.clear() + # -------------------------------- # MAIN # -------------------------------- @@ -748,13 +739,11 @@ def main(): # ) # net.use_multiprocessing = False - # net = BasicUNet(spatial_dims=3, features=(32, 64, 128, 256, 32), out_channels=1) # net = create_nnunet_from_plans() - logger.add(os.path.join(output_path, 'log.txt'), rotation="10 MB", level="INFO") # define loss function @@ -786,7 +775,6 @@ def main(): dirpath= best_model_path, filename='best_model', monitor='val_loss', save_top_k=1, mode="min", save_last=True, save_weights_only=True) - logger.info(f"Starting training from scratch ...") # wandb logger exp_logger = pl.loggers.WandbLogger( From e038c67259323e40a8e7879217106ce5135d462c Mon Sep 17 00:00:00 2001 From: Pierre-Louis Benveniste Date: Fri, 9 Aug 2024 19:09:31 -0400 Subject: [PATCH 091/108] changed location of saving of yaml file to save with the same date as json file --- monai/1_create_msd_data_head_cropped.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/monai/1_create_msd_data_head_cropped.py b/monai/1_create_msd_data_head_cropped.py index d53cd49..cf28f9e 100644 --- a/monai/1_create_msd_data_head_cropped.py +++ b/monai/1_create_msd_data_head_cropped.py @@ -225,10 +225,6 @@ def main(): # logger.info(f"Number of validation subjects: {len(val_subjects)}") # logger.info(f"Number of testing subjects: {len(test_subjects)}") - # dump train/val/test splits into a yaml file - with open(f"{args.path_out}/data_split_{str(date.today())}_seed{seed}.yaml", 'w') as file: - yaml.dump({'train': train_derivatives, 'val': val_derivatives, 'test': test_derivatives}, file, indent=2, sort_keys=True) - # keys to be defined in the dataset_0.json params = {} params["description"] = "ms-lesion-agnostic" @@ -388,6 +384,10 @@ def main(): jsonFile.write(final_json) jsonFile.close() + # dump train/val/test splits into a yaml file + with open(f"{args.path_out}/data_split_{str(date.today())}_seed{seed}.yaml", 'w') as file: + yaml.dump({'train': train_derivatives, 'val': val_derivatives, 'test': test_derivatives}, file, indent=2, sort_keys=True) + return None From 391193c5c8e26c8dd3afc7d637c072457b25d047 Mon Sep 17 00:00:00 2001 From: Pierre-Louis Benveniste Date: Fri, 30 Aug 2024 15:11:01 -0400 Subject: [PATCH 092/108] init mednext training script --- monai/train_monai_mednext_lightning.py | 811 +++++++++++++++++++++++++ 1 file changed, 811 insertions(+) create mode 100644 monai/train_monai_mednext_lightning.py diff --git a/monai/train_monai_mednext_lightning.py b/monai/train_monai_mednext_lightning.py new file mode 100644 index 0000000..4685086 --- /dev/null +++ b/monai/train_monai_mednext_lightning.py @@ -0,0 +1,811 @@ +import os +import argparse +from datetime import datetime +from loguru import logger +import yaml +import nibabel as nib +from datetime import datetime +import numpy as np +import wandb +import torch +import pytorch_lightning as pl +import torch.nn.functional as F +import matplotlib.pyplot as plt +import time +import torch.multiprocessing + +# Added this to solve problem with too many files open +## Link here : https://github.com/pytorch/pytorch/issues/11201#issuecomment-421146936 +## Linke to other issue: https://github.com/sct-pipeline/contrast-agnostic-softseg-spinalcord/issues/59 +torch.multiprocessing.set_sharing_strategy('file_system') + +from utils.losses import AdapWingLoss, SoftDiceLoss + +from utils.utils import dice_score, check_empty_patch, multiply_by_negative_one, plot_slices, remove_small_lesions +from monai.networks.nets import UNet, BasicUNet, AttentionUnet, SwinUNETR +from monai.metrics import DiceMetric +from monai.losses import DiceLoss, DiceCELoss +from monai.networks.layers import Norm +from monai.transforms import ( + EnsureChannelFirstd, + Compose, + LoadImaged, + Orientationd, + RandFlipd, + RandShiftIntensityd, + Spacingd, + RandRotate90d, + NormalizeIntensityd, + RandCropByPosNegLabeld, + BatchInverseTransform, + RandAdjustContrastd, + AsDiscreted, + RandHistogramShiftd, + ResizeWithPadOrCropd, + EnsureTyped, + RandLambdad, + CropForegroundd, + RandGaussianNoised, + LabelToContourd, + Invertd, + SaveImage, + EnsureType, + Rand3DElasticd, + RandSimulateLowResolutiond, + RandBiasFieldd, + RandAffined, + RandRotated, + RandZoomd, + RandGaussianSmoothd, + RandScaleIntensityd +) +from monai.utils import set_determinism +from monai.inferers import sliding_window_inference +from monai.data import (DataLoader, CacheDataset, load_decathlon_datalist, decollate_batch) + +# Added this because of following warning received: +## You are using a CUDA device ('NVIDIA RTX A6000') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` +## which will trade-off precision for performance. For more details, +## read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision +# torch.set_float32_matmul_precision('medium' | 'high') + + +def get_parser(): + """ + This function returns the parser for the command line arguments. + """ + parser = argparse.ArgumentParser(description="Train a nnUNet model using monai") + parser.add_argument("-c", "--config", help="Path to the config file (.yml file)", required=True) + return parser + + +# create a "model"-agnostic class with PL to use different models +class Model(pl.LightningModule): + def __init__(self, config, data_root, net, loss_function, optimizer_class, exp_id=None, results_path=None): + super().__init__() + self.cfg = config + self.save_hyperparameters(ignore=['net', 'loss_function']) + self.root = data_root + self.net = net + self.lr = config["lr"] + self.loss_function = loss_function + self.optimizer_class = optimizer_class + self.save_exp_id = exp_id + self.results_path = results_path + + self.best_val_dice, self.best_val_epoch = 0, 0 + self.best_val_loss = float("inf") + + # define cropping and padding dimensions + # NOTE about patch sizes: nnUNet defines patches using the median size of the dataset as the reference + # BUT, for SC images, this means a lot of context outside the spinal cord is included in the patches + # which could be sub-optimal. + # On the other hand, ivadomed used a patch-size that's heavily padded along the R-L direction so that + # only the SC is in context. + self.spacing = config["spatial_size"] + self.voxel_cropping_size = self.inference_roi_size = config["spatial_size"] + + # define post-processing transforms for validation, nothing fancy just making sure that it's a tensor (default) + self.val_post_pred = self.val_post_label = Compose([EnsureType()]) + + # define evaluation metric + self.soft_dice_metric = dice_score + # self.lesion_wise_precision_recall = lesion_wise_precision_recall + + # temp lists for storing outputs from training, validation, and testing + self.train_step_outputs = [] + self.val_step_outputs = [] + self.test_step_outputs = [] + + + # -------------------------------- + # FORWARD PASS + # -------------------------------- + def forward(self, x): + + out = self.net(x) + # # NOTE: MONAI's models only output the logits, not the output after the final activation function + # # https://docs.monai.io/en/0.9.0/_modules/monai/networks/nets/unetr.html#UNETR.forward refers to the + # # UnetOutBlock (https://docs.monai.io/en/0.9.0/_modules/monai/networks/blocks/dynunet_block.html#UnetOutBlock) + # # as the final block applied to the input, which is just a convolutional layer with no activation function + # # Hence, we are used Normalized ReLU to normalize the logits to the final output + # normalized_out = F.relu(out) / F.relu(out).max() if bool(F.relu(out).max()) else F.relu(out) + + return out # returns logits + + + # -------------------------------- + # DATA PREPARATION + # -------------------------------- + def prepare_data(self): + # set deterministic training for reproducibility + set_determinism(seed=self.cfg["seed"]) + + # define training and validation transforms + train_transforms = Compose( + [ + LoadImaged(keys=["image", "label"], reader="NibabelReader"), + EnsureChannelFirstd(keys=["image", "label"]), + Orientationd(keys=["image", "label"], axcodes="RPI"), + Spacingd( + keys=["image", "label"], + pixdim=self.cfg["pixdim"], + mode=(2, 0), + ), + # Normalize the intensity of the image + NormalizeIntensityd( + keys=["image"], + nonzero=False, + channel_wise=False + ), + # # This crops the image around areas where the mask is non-zero + # # (the margin is added because otherwise the image would be just the size of the lesion) + # CropForegroundd( + # keys=["image", "label"], + # source_key="label", + # margin=200 + # ), + # This crops the image around a foreground object of label with ratio pos/(pos+neg) (however, it cannot pad so keeping padding after) + RandCropByPosNegLabeld( + keys=["image", "label"], + label_key="label", + spatial_size=self.cfg["spatial_size"], + pos=1, + neg=0, + num_samples=4, + image_key="image", + image_threshold=0, + allow_smaller=True, + ), + # This resizes the image and the label to the spatial size defined in the config + ResizeWithPadOrCropd( + keys=["image", "label"], + spatial_size=self.cfg["spatial_size"], + ), + # Flips the image : left becomes right + RandFlipd( + keys=["image", "label"], + spatial_axis=[0], + prob=self.cfg["DA_probability"], + ), + # Flips the image : supperior becomes inferior + RandFlipd( + keys=["image", "label"], + spatial_axis=[1], + prob=self.cfg["DA_probability"], + ), + # Flips the image : anterior becomes posterior + RandFlipd( + keys=["image", "label"], + spatial_axis=[2], + prob=self.cfg["DA_probability"], + ), + # Random elastic deformation + Rand3DElasticd( + keys=["image", "label"], + sigma_range=(5, 7), + magnitude_range=(50, 150), + prob=self.cfg["DA_probability"], + mode=['bilinear', 'nearest'], + ), + # Random affine transform of the image + RandAffined( + keys=["image", "label"], + prob=self.cfg["DA_probability"], + mode=('bilinear', 'nearest'), + padding_mode='zeros', + ), + # RandAdjustContrastd( + # keys=["image"], + # prob=self.cfg["DA_probability"], + # gamma=(0.5, 4.5), + # invert_image=True, + # ), + # # we add the multiplication of the image by -1 + # RandLambdad( + # keys='image', + # func=multiply_by_negative_one, + # prob=0.5 + # ), + # LabelToContourd( + # keys=["image"], + # kernel_type='Laplace', + # ), + RandGaussianNoised( + keys=["image"], + prob=self.cfg["DA_probability"], + ), + # Random simulation of low resolution + RandSimulateLowResolutiond( + keys=["image"], + zoom_range=(0.8, 1.5), + prob=self.cfg["DA_probability"] + ), + # Adding a random bias field which is usefull considering that this sometimes done for image pre-processing + RandBiasFieldd( + keys=["image"], + coeff_range=(0.0, 0.5), + degree=3, + prob=self.cfg["DA_probability"] + ), + # RandShiftIntensityd( + # keys=["image"], + # offsets=0.1, + # prob=0.2, + # ), + # EnsureTyped(keys=["image", "label"]), + # AsDiscreted( + # keys=["label"], + # num_classes=2, + # threshold_values=True, + # logit_thresh=0.2, + # ), + # # Remove small lesions in the label + # RandLambdad( + # keys='label', + # func=lambda label: remove_small_lesions(label, self.cfg["pixdim"]), + # prob=1.0 + # ) + ] + ) + val_transforms = Compose( + [ + LoadImaged(keys=["image", "label"], reader="NibabelReader"), + EnsureChannelFirstd(keys=["image", "label"]), + Orientationd(keys=["image", "label"], axcodes="RPI"), + Spacingd( + keys=["image", "label"], + pixdim=self.cfg["pixdim"], + mode=(2, 0), + ), + # This normalizes the intensity of the image + NormalizeIntensityd( + keys=["image"], + nonzero=False, + channel_wise=False + ), + # CropForegroundd( + # keys=["image", "label"], + # source_key="label", + # margin=150), + # RandCropByPosNegLabeld( + # keys=["image", "label"], + # label_key="label", + # spatial_size=self.cfg["spatial_size"], + # pos=1, + # neg=1, + # num_samples=4, + # image_key="image", + # image_threshold=0, + # allow_smaller=True, + # ), + ResizeWithPadOrCropd( + keys=["image", "label"], + spatial_size=self.cfg["spatial_size"], + ), + # LabelToContourd( + # keys=["image"], + # kernel_type='Laplace', + # ), + # EnsureTyped(keys=["image", "label"]), + # AsDiscreted( + # keys=["label"], + # num_classes=2, + # threshold_values=True, + # logit_thresh=0.2, + # ) + # # Remove small lesions in the label + # RandLambdad( + # keys='label', + # func=lambda label: remove_small_lesions(label, self.cfg["pixdim"]), + # prob=1.0 + # ) + ] + + ) + + # load the dataset + dataset = self.cfg["data"] + logger.info(f"Loading dataset: {dataset}") + train_files = load_decathlon_datalist(dataset, True, "train") + val_files = load_decathlon_datalist(dataset, True, "validation") + test_files = load_decathlon_datalist(dataset, True, "test") + + train_cache_rate = 0.5 + val_cache_rate = 0.25 + self.train_ds = CacheDataset(data=train_files, transform=train_transforms, cache_rate=train_cache_rate, num_workers=8) + self.val_ds = CacheDataset(data=val_files, transform=val_transforms, cache_rate=val_cache_rate, num_workers=8) + + # define test transforms + transforms_test = val_transforms + + # Hidden because we don't use it + # define post-processing transforms for testing; taken (with explanations) from + # https://github.com/Project-MONAI/tutorials/blob/main/3d_segmentation/torch/unet_inference_dict.py#L66 + self.test_post_pred = Compose([ + EnsureTyped(keys=["pred", "label"]), + Invertd(keys=["pred", "label"], transform=transforms_test, + orig_keys=["image", "label"], + meta_keys=["pred_meta_dict", "label_meta_dict"], + nearest_interp=False, to_tensor=True), + # # Remove small lesions in the label + # RandLambdad( + # keys='label', + # func=lambda label: remove_small_lesions(label, self.cfg["pixdim"]), + # prob=1.0 + # ) + ]) + self.test_ds = CacheDataset(data=test_files, transform=transforms_test, cache_rate=0.1, num_workers=4) + + + # -------------------------------- + # DATA LOADERS + # -------------------------------- + def train_dataloader(self): + return DataLoader(self.train_ds, batch_size=self.cfg["batch_size"], shuffle=True, num_workers=8, + pin_memory=True, persistent_workers=True) + + + def val_dataloader(self): + return DataLoader(self.val_ds, batch_size=1, shuffle=False, num_workers=0, pin_memory=True, + persistent_workers=False) + + + def test_dataloader(self): + return DataLoader(self.test_ds, batch_size=1, shuffle=False, num_workers=1, pin_memory=True) + + + # -------------------------------- + # OPTIMIZATION + # -------------------------------- + def configure_optimizers(self): + optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr, weight_decay=self.cfg["weight_decay"]) + scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=self.cfg["max_iterations"]) + return [optimizer], [scheduler] + + + # -------------------------------- + # TRAINING + # -------------------------------- + def training_step(self, batch, batch_idx): + + inputs, labels = batch["image"], batch["label"] + + # The following was done to debug : + # I was checking the image and the label to see if they were empty or not + + # # print(inputs.shape, labels.shape) + # input_0 = inputs[0].detach().cpu().squeeze() + # # print(input_0.shape) + # label_0 = labels[0].detach().cpu().squeeze() + + # time_0 = datetime.now() + + # # save input 0 in a nifti file + # input_0_nifti = nib.Nifti1Image(input_0.numpy(), affine=np.eye(4)) + # nib.save(input_0_nifti, f"~/ms_lesion_agnostic/temp/input_0_{time_0}.nii.gz") + + # # save label in a nifti file + # label_nifti = nib.Nifti1Image(label_0.numpy(), affine=np.eye(4)) + # nib.save(label_nifti, f"~/ms_lesion_agnostic/temp/label_0_{time_0}.nii.gz") + + # # # check if any label image patch is empty in the batch + # if check_empty_patch(labels) is None: + # print(f"Empty label patch found. Skipping training step ...") + # return None + + output = self.forward(inputs) # logits + # print(f"labels.shape: {labels.shape} \t output.shape: {output.shape}") + + # get probabilities from logits + output = F.relu(output) / F.relu(output).max() if bool(F.relu(output).max()) else F.relu(output) + + # calculate training loss + loss = self.loss_function(output, labels) + + # calculate train dice + # NOTE: this is done on patches (and not entire 3D volume) because SlidingWindowInference is not used here + # So, take this dice score with a lot of salt + train_soft_dice = self.soft_dice_metric(output, labels) + + # Compute precision and recall + # train_precision, train_recall = self.lesion_wise_precision_recall(output.detach().cpu(), labels.detach().cpu()) + # print("sucess") + + metrics_dict = { + "loss": loss.cpu(), + "train_soft_dice": train_soft_dice.detach().cpu(), + "train_number": len(inputs), + "train_image": inputs[0].detach().cpu().squeeze(), + "train_gt": labels[0].detach().cpu().squeeze(), + "train_pred": output[0].detach().cpu().squeeze(), + # "train_precision": train_precision.detach().cpu(), + # "train_recall": train_recall.detach().cpu(), + } + self.train_step_outputs.append(metrics_dict) + + return metrics_dict + + + def on_train_epoch_end(self): + + if self.train_step_outputs == []: + # means the training step was skipped because of empty input patch + return None + else: + train_loss, train_soft_dice = 0, 0 + # precision_score, recall_score = 0, 0 + num_items = len(self.train_step_outputs) + for output in self.train_step_outputs: + train_loss += output["loss"].item() + train_soft_dice += output["train_soft_dice"].item() + # precision_score = output["train_precision"] + # recall_score = output["train_recall"] + + mean_train_loss = (train_loss / num_items) + mean_train_soft_dice = (train_soft_dice / num_items) + # mean_precision_score = np.mean(precision_score.detach().numpy()) + # mean_recall_score = np.mean(recall_score.detach().numpy()) + + wandb_logs = { + "train_soft_dice": mean_train_soft_dice, + "train_loss": mean_train_loss, + # "train_precision": mean_precision_score, + # "train_recall": mean_recall_score, + } + + self.log_dict(wandb_logs) + + # plot the training images + fig = plot_slices(image=self.train_step_outputs[0]["train_image"], + gt=self.train_step_outputs[0]["train_gt"], + pred=self.train_step_outputs[0]["train_pred"], + ) + wandb.log({"training images": wandb.Image(fig)}) + plt.close(fig) + + # free up memory + self.train_step_outputs.clear() + wandb_logs.clear() + + + + # -------------------------------- + # VALIDATION + # -------------------------------- + def validation_step(self, batch, batch_idx): + + inputs, labels = batch["image"], batch["label"] + + # NOTE: this calculates the loss on the entire image after sliding window + outputs = sliding_window_inference(inputs, self.inference_roi_size, mode="gaussian", + sw_batch_size=4, predictor=self.forward, overlap=0.5,) + + # get probabilities from logits + outputs = F.relu(outputs) / F.relu(outputs).max() if bool(F.relu(outputs).max()) else F.relu(outputs) + + # calculate validation loss + loss = self.loss_function(outputs, labels) + + # post-process for calculating the evaluation metric + post_outputs = [self.val_post_pred(i) for i in decollate_batch(outputs)] + post_labels = [self.val_post_label(i) for i in decollate_batch(labels)] + val_soft_dice = self.soft_dice_metric(post_outputs[0], post_labels[0]) + + hard_preds, hard_labels = (post_outputs[0].detach() > 0.5).float(), (post_labels[0].detach() > 0.5).float() + val_hard_dice = self.soft_dice_metric(hard_preds, hard_labels) + + # compute precision and recall + # val_precision, val_recall = self.lesion_wise_precision_recall(post_outputs[0].detach().cpu(), post_labels[0].detach().cpu()) + # print("sucess val") + + # NOTE: there was a massive memory leak when storing cuda tensors in this dict. Hence, + # using .detach() to avoid storing the whole computation graph + # Ref: https://discuss.pytorch.org/t/cuda-memory-leak-while-training/82855/2 + metrics_dict = { + "val_loss": loss.detach().cpu(), + "val_soft_dice": val_soft_dice.detach().cpu(), + "val_hard_dice": val_hard_dice.detach().cpu(), + "val_number": len(post_outputs), + "val_image": inputs[0].detach().cpu().squeeze(), + "val_gt": labels[0].detach().cpu().squeeze(), + "val_pred": post_outputs[0].detach().cpu().squeeze(), + # "val_precision": val_precision.detach().cpu(), + # "val_recall": val_recall.detach().cpu(), + } + self.val_step_outputs.append(metrics_dict) + + return metrics_dict + + def on_validation_epoch_end(self): + + val_loss, num_items, val_soft_dice, val_hard_dice = 0, 0, 0, 0 + # val_precision, val_recall = 0, 0 + for output in self.val_step_outputs: + val_loss += output["val_loss"].sum().item() + val_soft_dice += output["val_soft_dice"].sum().item() + val_hard_dice += output["val_hard_dice"].sum().item() + num_items += output["val_number"] + # val_precision += output["val_precision"].sum().item() + # val_recall += output["val_recall"].sum().item() + + mean_val_loss = (val_loss / num_items) + mean_val_soft_dice = (val_soft_dice / num_items) + mean_val_hard_dice = (val_hard_dice / num_items) + # mean_val_precision = (val_precision / num_items) + # mean_val_recall = (val_recall / num_items) + + wandb_logs = { + "val_soft_dice": mean_val_soft_dice, + # "val_hard_dice": mean_val_hard_dice, + "val_loss": mean_val_loss, + # "val_precision": mean_val_precision, + # "val_recall": mean_val_recall, + } + + self.log_dict(wandb_logs) + + # save the best model based on validation dice score + if mean_val_soft_dice > self.best_val_dice: + self.best_val_dice = mean_val_soft_dice + self.best_val_epoch = self.current_epoch + + # save the best model based on validation loss + if mean_val_loss < self.best_val_loss: + self.best_val_loss = mean_val_loss + self.best_val_epoch = self.current_epoch + + logger.info( + f"\nCurrent epoch: {self.current_epoch}" + f"\nAverage Soft Dice (VAL): {mean_val_soft_dice:.4f}" + # f"\nAverage Hard Dice (VAL): {mean_val_hard_dice:.4f}" + f"\nAverage DiceLoss (VAL): {mean_val_loss:.4f}" + f"\nBest Average DiceLoss: {self.best_val_loss:.4f} at Epoch: {self.best_val_epoch}" + f"\n----------------------------------------------------") + + # log on to wandb + self.log_dict(wandb_logs) + + # plot 1 validation image + fig = plot_slices(image=self.val_step_outputs[0]["val_image"], + gt=self.val_step_outputs[0]["val_gt"], + pred=self.val_step_outputs[0]["val_pred"],) + wandb.log({"validation image 1": wandb.Image(fig)}) + plt.close(fig) + + # plot another validation image + fig0 = plot_slices(image=self.val_step_outputs[1]["val_image"], + gt=self.val_step_outputs[1]["val_gt"], + pred=self.val_step_outputs[1]["val_pred"],) + wandb.log({"validation image 2": wandb.Image(fig0)}) + plt.close(fig0) + + # free up memory + self.val_step_outputs.clear() + wandb_logs.clear() + + + # -------------------------------- + # TESTING + # -------------------------------- + def test_step(self, batch, batch_idx): + + test_input = batch["image"] + # print(batch["label_meta_dict"]["filename_or_obj"][0]) + batch["pred"] = sliding_window_inference(test_input, self.inference_roi_size, + sw_batch_size=4, predictor=self.forward, overlap=0.5) + + # normalize the logits + batch["pred"] = F.relu(batch["pred"]) / F.relu(batch["pred"]).max() if bool(F.relu(batch["pred"]).max()) else F.relu(batch["pred"]) + + post_test_out = [self.test_post_pred(i) for i in decollate_batch(batch)] + + # make sure that the shapes of prediction and GT label are the same + # print(f"pred shape: {post_test_out[0]['pred'].shape}, label shape: {post_test_out[0]['label'].shape}") + assert post_test_out[0]['pred'].shape == post_test_out[0]['label'].shape + + pred, label = post_test_out[0]['pred'].cpu(), post_test_out[0]['label'].cpu() + + # NOTE: Important point from the SoftSeg paper - binarize predictions before computing metrics + # calculate soft and hard dice here (for quick overview), other metrics can be computed from + # the saved predictions using ANIMA + # 1. Dice Score + test_soft_dice = self.soft_dice_metric(pred, label) + + # binarizing the predictions + pred = (post_test_out[0]['pred'].detach().cpu() > 0.5).float() + label = (post_test_out[0]['label'].detach().cpu() > 0.5).float() + + # 1.1 Hard Dice Score + test_hard_dice = self.soft_dice_metric(pred.numpy(), label.numpy()) + + metrics_dict = { + "test_hard_dice": test_hard_dice, + "test_soft_dice": test_soft_dice, + } + self.test_step_outputs.append(metrics_dict) + + return metrics_dict + + + def on_test_epoch_end(self): + + avg_hard_dice_test, std_hard_dice_test = np.stack([x["test_hard_dice"] for x in self.test_step_outputs]).mean(), \ + np.stack([x["test_hard_dice"] for x in self.test_step_outputs]).std() + avg_soft_dice_test, std_soft_dice_test = np.stack([x["test_soft_dice"] for x in self.test_step_outputs]).mean(), \ + np.stack([x["test_soft_dice"] for x in self.test_step_outputs]).std() + + logger.info(f"Test (Soft) Dice: {avg_soft_dice_test}") + logger.info(f"Test (Hard) Dice: {avg_hard_dice_test}") + + self.avg_test_dice, self.std_test_dice = avg_soft_dice_test, std_soft_dice_test + self.avg_test_dice_hard, self.std_test_dice_hard = avg_hard_dice_test, std_hard_dice_test + + # free up memory + self.test_step_outputs.clear() + + +# -------------------------------- +# MAIN +# -------------------------------- +def main(): + # get the parser + parser = get_parser() + args= parser.parse_args() + + # load config file + with open(args.config, "r") as f: + config = yaml.load(f, Loader=yaml.FullLoader) + + # Setting the seed + pl.seed_everything(config["seed"], workers=True) + + # define root path for finding datalists + dataset_root = config["data"] + + # define optimizer + optimizer_class = torch.optim.Adam + + output_path = os.path.join(config["output_path"], str(datetime.now().date()) +"_" +str(datetime.now().time())) + os.makedirs(output_path, exist_ok=True) + + wandb.init(project=f'monai-ms-lesion-seg-unet', config=config, save_code=True, dir=output_path) + + logger.info("Building the model ...") + + # define model + + # net = UNet( + # spatial_dims=3, + # in_channels=1, + # out_channels=1, + # channels=config['unet_channels'], + # strides=config['unet_strides'], + # kernel_size=3, + # up_kernel_size=3, + # num_res_units=0, + # act='PRELU', + # norm=Norm.INSTANCE, + # dropout=0.0, + # bias=True, + # adn_ordering='NDA', + # ) + + # net=UNet( + # spatial_dims=3, + # in_channels=1, + # out_channels=1, + # channels=(32, 64, 128), + # strides=(2, 2, 2, ), + # dropout=0.1 + # ) + + net = AttentionUnet( + spatial_dims=3, + in_channels=1, + out_channels=1, + channels=config["attention_unet_channels"], + strides=config["attention_unet_strides"], + dropout=0.1, + ) + + # net = SwinUNETR( + # img_size=config["spatial_size"], + # spatial_dims=3, + # in_channels=1, + # out_channels=1, + # feature_size=48, + # use_checkpoint=True, + # ) + + # net.use_multiprocessing = False + + # net = BasicUNet(spatial_dims=3, features=(32, 64, 128, 256, 32), out_channels=1) + + # net = create_nnunet_from_plans() + + logger.add(os.path.join(output_path, 'log.txt'), rotation="10 MB", level="INFO") + + # define loss function + # loss_func = AdapWingLoss(theta=0.5, omega=8, alpha=2.1, epsilon=1, reduction="sum") + # loss_func = DiceLoss(sigmoid=False, smooth_dr=1e-4) + loss_func = DiceCELoss(sigmoid=False, smooth_dr=1e-4) + # loss_func = SoftDiceLoss(smooth=1e-5) + # NOTE: tried increasing omega and decreasing epsilon but results marginally worse than the above + # loss_func = AdapWingLoss(theta=0.5, omega=12, alpha=2.1, epsilon=0.5, reduction="sum") + #logger.info(f"Using AdapWingLoss with theta={loss_func.theta}, omega={loss_func.omega}, alpha={loss_func.alpha}, epsilon={loss_func.epsilon} ...") + logger.info(f"Using DiceCELoss ...") + # define callbacks + early_stopping = pl.callbacks.EarlyStopping( + monitor="val_loss", min_delta=0.00, + patience=config["early_stopping_patience"], + verbose=False, mode="min") + + lr_monitor = pl.callbacks.LearningRateMonitor(logging_interval='epoch') + + best_model_path = os.path.join(output_path, "best_model.pth") + + # i.e. train by loading weights from scratch + pl_model = Model(config, data_root=dataset_root, + optimizer_class=optimizer_class, loss_function=loss_func, net=net, + exp_id="test", results_path=best_model_path) + + # saving the best model based on validation loss + checkpoint_callback_loss = pl.callbacks.ModelCheckpoint( + dirpath= best_model_path, filename='best_model', monitor='val_loss', + save_top_k=1, mode="min", save_last=True, save_weights_only=True) + + logger.info(f"Starting training from scratch ...") + # wandb logger + exp_logger = pl.loggers.WandbLogger( + name="test", + save_dir=output_path, + group="test-on-canproco", + log_model=True, # save best model using checkpoint callback + config=config) + + # Saving training script to wandb + wandb.save(args.config) + + # initialise Lightning's trainer. + trainer = pl.Trainer( + devices=1, accelerator="gpu", + logger=exp_logger, + callbacks=[checkpoint_callback_loss, lr_monitor, early_stopping], + check_val_every_n_epoch=config["eval_num"], + max_epochs=config["max_iterations"], + precision=32, + # precision='bf16-mixed', + enable_progress_bar=True) + # profiler="simple",) # to profile the training time taken for each step + + # Train! + trainer.fit(pl_model) + logger.info(f" Training Done!") + + # Closing wandb log + wandb.finish() + + +if __name__ == "__main__": + main() \ No newline at end of file From 25e7874a1120950e3c8904dcfab5d770e041ff79 Mon Sep 17 00:00:00 2001 From: Pierre-Louis Benveniste Date: Wed, 4 Sep 2024 14:34:11 -0400 Subject: [PATCH 093/108] added library for diffusion model --- monai/requirements.txt | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/monai/requirements.txt b/monai/requirements.txt index 7b726e9..c14734f 100644 --- a/monai/requirements.txt +++ b/monai/requirements.txt @@ -9,4 +9,5 @@ cupy-cuda117==10.6.0 loguru==0.7.2 wandb==0.15.12 dynamic-network-architectures==0.2 -seaborn==0.13.2 \ No newline at end of file +seaborn==0.13.2 +monai-generative==0.2.3 \ No newline at end of file From b7ee7201d045dfedc534a5de3ff09643db213872 Mon Sep 17 00:00:00 2001 From: Pierre-Louis Benveniste Date: Wed, 4 Sep 2024 14:34:44 -0400 Subject: [PATCH 094/108] first draft (non-functionning) of diffusion model training script --- monai/train_monai_diffusion_lightning.py | 824 +++++++++++++++++++++++ 1 file changed, 824 insertions(+) create mode 100644 monai/train_monai_diffusion_lightning.py diff --git a/monai/train_monai_diffusion_lightning.py b/monai/train_monai_diffusion_lightning.py new file mode 100644 index 0000000..b23862b --- /dev/null +++ b/monai/train_monai_diffusion_lightning.py @@ -0,0 +1,824 @@ +import os +import argparse +from datetime import datetime +from loguru import logger +import yaml +import nibabel as nib +from datetime import datetime +import numpy as np +import wandb +import torch +import pytorch_lightning as pl +import torch.nn.functional as F +import matplotlib.pyplot as plt +import time +import torch.multiprocessing + +# Added this to solve problem with too many files open +## Link here : https://github.com/pytorch/pytorch/issues/11201#issuecomment-421146936 +## Linke to other issue: https://github.com/sct-pipeline/contrast-agnostic-softseg-spinalcord/issues/59 +torch.multiprocessing.set_sharing_strategy('file_system') + +from utils.losses import AdapWingLoss, SoftDiceLoss + +from utils.utils import dice_score, check_empty_patch, multiply_by_negative_one, plot_slices, remove_small_lesions +from monai.networks.nets import UNet, BasicUNet, AttentionUnet, SwinUNETR +from generative.networks.nets.diffusion_model_unet import DiffusionModelUNet +from monai.metrics import DiceMetric +from monai.losses import DiceLoss, DiceCELoss +from monai.networks.layers import Norm +from monai.transforms import ( + EnsureChannelFirstd, + Compose, + LoadImaged, + Orientationd, + RandFlipd, + RandShiftIntensityd, + Spacingd, + RandRotate90d, + NormalizeIntensityd, + RandCropByPosNegLabeld, + BatchInverseTransform, + RandAdjustContrastd, + AsDiscreted, + RandHistogramShiftd, + ResizeWithPadOrCropd, + EnsureTyped, + RandLambdad, + CropForegroundd, + RandGaussianNoised, + LabelToContourd, + Invertd, + SaveImage, + EnsureType, + Rand3DElasticd, + RandSimulateLowResolutiond, + RandBiasFieldd, + RandAffined, + RandRotated, + RandZoomd, + RandGaussianSmoothd, + RandScaleIntensityd +) +from monai.utils import set_determinism +from monai.inferers import sliding_window_inference +from monai.data import (DataLoader, CacheDataset, load_decathlon_datalist, decollate_batch) + +# Added this because of following warning received: +## You are using a CUDA device ('NVIDIA RTX A6000') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` +## which will trade-off precision for performance. For more details, +## read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision +# torch.set_float32_matmul_precision('medium' | 'high') + + +def get_parser(): + """ + This function returns the parser for the command line arguments. + """ + parser = argparse.ArgumentParser(description="Train a nnUNet model using monai") + parser.add_argument("-c", "--config", help="Path to the config file (.yml file)", required=True) + return parser + + +# create a "model"-agnostic class with PL to use different models +class Model(pl.LightningModule): + def __init__(self, config, data_root, net, loss_function, optimizer_class, exp_id=None, results_path=None): + super().__init__() + self.cfg = config + self.save_hyperparameters(ignore=['net', 'loss_function']) + self.root = data_root + self.net = net + self.lr = config["lr"] + self.loss_function = loss_function + self.optimizer_class = optimizer_class + self.save_exp_id = exp_id + self.results_path = results_path + + self.best_val_dice, self.best_val_epoch = 0, 0 + self.best_val_loss = float("inf") + + # define cropping and padding dimensions + # NOTE about patch sizes: nnUNet defines patches using the median size of the dataset as the reference + # BUT, for SC images, this means a lot of context outside the spinal cord is included in the patches + # which could be sub-optimal. + # On the other hand, ivadomed used a patch-size that's heavily padded along the R-L direction so that + # only the SC is in context. + self.spacing = config["spatial_size"] + self.voxel_cropping_size = self.inference_roi_size = config["spatial_size"] + + # define post-processing transforms for validation, nothing fancy just making sure that it's a tensor (default) + self.val_post_pred = self.val_post_label = Compose([EnsureType()]) + + # define evaluation metric + self.soft_dice_metric = dice_score + # self.lesion_wise_precision_recall = lesion_wise_precision_recall + + # temp lists for storing outputs from training, validation, and testing + self.train_step_outputs = [] + self.val_step_outputs = [] + self.test_step_outputs = [] + + + # -------------------------------- + # FORWARD PASS + # -------------------------------- + def forward(self, x): + + out = self.net(x) + # # NOTE: MONAI's models only output the logits, not the output after the final activation function + # # https://docs.monai.io/en/0.9.0/_modules/monai/networks/nets/unetr.html#UNETR.forward refers to the + # # UnetOutBlock (https://docs.monai.io/en/0.9.0/_modules/monai/networks/blocks/dynunet_block.html#UnetOutBlock) + # # as the final block applied to the input, which is just a convolutional layer with no activation function + # # Hence, we are used Normalized ReLU to normalize the logits to the final output + # normalized_out = F.relu(out) / F.relu(out).max() if bool(F.relu(out).max()) else F.relu(out) + + return out # returns logits + + + # -------------------------------- + # DATA PREPARATION + # -------------------------------- + def prepare_data(self): + # set deterministic training for reproducibility + set_determinism(seed=self.cfg["seed"]) + + # define training and validation transforms + train_transforms = Compose( + [ + LoadImaged(keys=["image", "label"], reader="NibabelReader"), + EnsureChannelFirstd(keys=["image", "label"]), + Orientationd(keys=["image", "label"], axcodes="RPI"), + Spacingd( + keys=["image", "label"], + pixdim=self.cfg["pixdim"], + mode=(2, 0), + ), + # Normalize the intensity of the image + NormalizeIntensityd( + keys=["image"], + nonzero=False, + channel_wise=False + ), + # # This crops the image around areas where the mask is non-zero + # # (the margin is added because otherwise the image would be just the size of the lesion) + # CropForegroundd( + # keys=["image", "label"], + # source_key="label", + # margin=200 + # ), + # This crops the image around a foreground object of label with ratio pos/(pos+neg) (however, it cannot pad so keeping padding after) + RandCropByPosNegLabeld( + keys=["image", "label"], + label_key="label", + spatial_size=self.cfg["spatial_size"], + pos=1, + neg=0, + num_samples=4, + image_key="image", + image_threshold=0, + allow_smaller=True, + ), + # This resizes the image and the label to the spatial size defined in the config + ResizeWithPadOrCropd( + keys=["image", "label"], + spatial_size=self.cfg["spatial_size"], + ), + # Flips the image : left becomes right + RandFlipd( + keys=["image", "label"], + spatial_axis=[0], + prob=self.cfg["DA_probability"], + ), + # Flips the image : supperior becomes inferior + RandFlipd( + keys=["image", "label"], + spatial_axis=[1], + prob=self.cfg["DA_probability"], + ), + # Flips the image : anterior becomes posterior + RandFlipd( + keys=["image", "label"], + spatial_axis=[2], + prob=self.cfg["DA_probability"], + ), + # Random elastic deformation + Rand3DElasticd( + keys=["image", "label"], + sigma_range=(5, 7), + magnitude_range=(50, 150), + prob=self.cfg["DA_probability"], + mode=['bilinear', 'nearest'], + ), + # Random affine transform of the image + RandAffined( + keys=["image", "label"], + prob=self.cfg["DA_probability"], + mode=('bilinear', 'nearest'), + padding_mode='zeros', + ), + # RandAdjustContrastd( + # keys=["image"], + # prob=self.cfg["DA_probability"], + # gamma=(0.5, 4.5), + # invert_image=True, + # ), + # # we add the multiplication of the image by -1 + # RandLambdad( + # keys='image', + # func=multiply_by_negative_one, + # prob=0.5 + # ), + # LabelToContourd( + # keys=["image"], + # kernel_type='Laplace', + # ), + RandGaussianNoised( + keys=["image"], + prob=self.cfg["DA_probability"], + ), + # Random simulation of low resolution + RandSimulateLowResolutiond( + keys=["image"], + zoom_range=(0.8, 1.5), + prob=self.cfg["DA_probability"] + ), + # Adding a random bias field which is usefull considering that this sometimes done for image pre-processing + RandBiasFieldd( + keys=["image"], + coeff_range=(0.0, 0.5), + degree=3, + prob=self.cfg["DA_probability"] + ), + # RandShiftIntensityd( + # keys=["image"], + # offsets=0.1, + # prob=0.2, + # ), + # EnsureTyped(keys=["image", "label"]), + # AsDiscreted( + # keys=["label"], + # num_classes=2, + # threshold_values=True, + # logit_thresh=0.2, + # ), + # # Remove small lesions in the label + # RandLambdad( + # keys='label', + # func=lambda label: remove_small_lesions(label, self.cfg["pixdim"]), + # prob=1.0 + # ) + ] + ) + val_transforms = Compose( + [ + LoadImaged(keys=["image", "label"], reader="NibabelReader"), + EnsureChannelFirstd(keys=["image", "label"]), + Orientationd(keys=["image", "label"], axcodes="RPI"), + Spacingd( + keys=["image", "label"], + pixdim=self.cfg["pixdim"], + mode=(2, 0), + ), + # This normalizes the intensity of the image + NormalizeIntensityd( + keys=["image"], + nonzero=False, + channel_wise=False + ), + # CropForegroundd( + # keys=["image", "label"], + # source_key="label", + # margin=150), + # RandCropByPosNegLabeld( + # keys=["image", "label"], + # label_key="label", + # spatial_size=self.cfg["spatial_size"], + # pos=1, + # neg=1, + # num_samples=4, + # image_key="image", + # image_threshold=0, + # allow_smaller=True, + # ), + ResizeWithPadOrCropd( + keys=["image", "label"], + spatial_size=self.cfg["spatial_size"], + ), + # LabelToContourd( + # keys=["image"], + # kernel_type='Laplace', + # ), + # EnsureTyped(keys=["image", "label"]), + # AsDiscreted( + # keys=["label"], + # num_classes=2, + # threshold_values=True, + # logit_thresh=0.2, + # ) + # # Remove small lesions in the label + # RandLambdad( + # keys='label', + # func=lambda label: remove_small_lesions(label, self.cfg["pixdim"]), + # prob=1.0 + # ) + ] + + ) + + # load the dataset + dataset = self.cfg["data"] + logger.info(f"Loading dataset: {dataset}") + train_files = load_decathlon_datalist(dataset, True, "train") + val_files = load_decathlon_datalist(dataset, True, "validation") + test_files = load_decathlon_datalist(dataset, True, "test") + + train_cache_rate = 0.5 + val_cache_rate = 0.25 + self.train_ds = CacheDataset(data=train_files, transform=train_transforms, cache_rate=train_cache_rate, num_workers=8) + self.val_ds = CacheDataset(data=val_files, transform=val_transforms, cache_rate=val_cache_rate, num_workers=8) + + # define test transforms + transforms_test = val_transforms + + # Hidden because we don't use it + # define post-processing transforms for testing; taken (with explanations) from + # https://github.com/Project-MONAI/tutorials/blob/main/3d_segmentation/torch/unet_inference_dict.py#L66 + self.test_post_pred = Compose([ + EnsureTyped(keys=["pred", "label"]), + Invertd(keys=["pred", "label"], transform=transforms_test, + orig_keys=["image", "label"], + meta_keys=["pred_meta_dict", "label_meta_dict"], + nearest_interp=False, to_tensor=True), + # # Remove small lesions in the label + # RandLambdad( + # keys='label', + # func=lambda label: remove_small_lesions(label, self.cfg["pixdim"]), + # prob=1.0 + # ) + ]) + self.test_ds = CacheDataset(data=test_files, transform=transforms_test, cache_rate=0.1, num_workers=4) + + + # -------------------------------- + # DATA LOADERS + # -------------------------------- + def train_dataloader(self): + return DataLoader(self.train_ds, batch_size=self.cfg["batch_size"], shuffle=True, num_workers=8, + pin_memory=True, persistent_workers=True) + + + def val_dataloader(self): + return DataLoader(self.val_ds, batch_size=1, shuffle=False, num_workers=0, pin_memory=True, + persistent_workers=False) + + + def test_dataloader(self): + return DataLoader(self.test_ds, batch_size=1, shuffle=False, num_workers=1, pin_memory=True) + + + # -------------------------------- + # OPTIMIZATION + # -------------------------------- + def configure_optimizers(self): + optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr, weight_decay=self.cfg["weight_decay"]) + scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=self.cfg["max_iterations"]) + return [optimizer], [scheduler] + + + # -------------------------------- + # TRAINING + # -------------------------------- + def training_step(self, batch, batch_idx): + + inputs, labels = batch["image"], batch["label"] + + # The following was done to debug : + # I was checking the image and the label to see if they were empty or not + + # # print(inputs.shape, labels.shape) + # input_0 = inputs[0].detach().cpu().squeeze() + # # print(input_0.shape) + # label_0 = labels[0].detach().cpu().squeeze() + + # time_0 = datetime.now() + + # # save input 0 in a nifti file + # input_0_nifti = nib.Nifti1Image(input_0.numpy(), affine=np.eye(4)) + # nib.save(input_0_nifti, f"~/ms_lesion_agnostic/temp/input_0_{time_0}.nii.gz") + + # # save label in a nifti file + # label_nifti = nib.Nifti1Image(label_0.numpy(), affine=np.eye(4)) + # nib.save(label_nifti, f"~/ms_lesion_agnostic/temp/label_0_{time_0}.nii.gz") + + # # # check if any label image patch is empty in the batch + # if check_empty_patch(labels) is None: + # print(f"Empty label patch found. Skipping training step ...") + # return None + + output = self.forward(inputs) # logits + # print(f"labels.shape: {labels.shape} \t output.shape: {output.shape}") + + # get probabilities from logits + output = F.relu(output) / F.relu(output).max() if bool(F.relu(output).max()) else F.relu(output) + + # calculate training loss + loss = self.loss_function(output, labels) + + # calculate train dice + # NOTE: this is done on patches (and not entire 3D volume) because SlidingWindowInference is not used here + # So, take this dice score with a lot of salt + train_soft_dice = self.soft_dice_metric(output, labels) + + # Compute precision and recall + # train_precision, train_recall = self.lesion_wise_precision_recall(output.detach().cpu(), labels.detach().cpu()) + # print("sucess") + + metrics_dict = { + "loss": loss.cpu(), + "train_soft_dice": train_soft_dice.detach().cpu(), + "train_number": len(inputs), + "train_image": inputs[0].detach().cpu().squeeze(), + "train_gt": labels[0].detach().cpu().squeeze(), + "train_pred": output[0].detach().cpu().squeeze(), + # "train_precision": train_precision.detach().cpu(), + # "train_recall": train_recall.detach().cpu(), + } + self.train_step_outputs.append(metrics_dict) + + return metrics_dict + + + def on_train_epoch_end(self): + + if self.train_step_outputs == []: + # means the training step was skipped because of empty input patch + return None + else: + train_loss, train_soft_dice = 0, 0 + # precision_score, recall_score = 0, 0 + num_items = len(self.train_step_outputs) + for output in self.train_step_outputs: + train_loss += output["loss"].item() + train_soft_dice += output["train_soft_dice"].item() + # precision_score = output["train_precision"] + # recall_score = output["train_recall"] + + mean_train_loss = (train_loss / num_items) + mean_train_soft_dice = (train_soft_dice / num_items) + # mean_precision_score = np.mean(precision_score.detach().numpy()) + # mean_recall_score = np.mean(recall_score.detach().numpy()) + + wandb_logs = { + "train_soft_dice": mean_train_soft_dice, + "train_loss": mean_train_loss, + # "train_precision": mean_precision_score, + # "train_recall": mean_recall_score, + } + + self.log_dict(wandb_logs) + + # plot the training images + fig = plot_slices(image=self.train_step_outputs[0]["train_image"], + gt=self.train_step_outputs[0]["train_gt"], + pred=self.train_step_outputs[0]["train_pred"], + ) + wandb.log({"training images": wandb.Image(fig)}) + plt.close(fig) + + # free up memory + self.train_step_outputs.clear() + wandb_logs.clear() + + + + # -------------------------------- + # VALIDATION + # -------------------------------- + def validation_step(self, batch, batch_idx): + + inputs, labels = batch["image"], batch["label"] + + # NOTE: this calculates the loss on the entire image after sliding window + outputs = sliding_window_inference(inputs, self.inference_roi_size, mode="gaussian", + sw_batch_size=4, predictor=self.forward, overlap=0.5,) + + # get probabilities from logits + outputs = F.relu(outputs) / F.relu(outputs).max() if bool(F.relu(outputs).max()) else F.relu(outputs) + + # calculate validation loss + loss = self.loss_function(outputs, labels) + + # post-process for calculating the evaluation metric + post_outputs = [self.val_post_pred(i) for i in decollate_batch(outputs)] + post_labels = [self.val_post_label(i) for i in decollate_batch(labels)] + val_soft_dice = self.soft_dice_metric(post_outputs[0], post_labels[0]) + + hard_preds, hard_labels = (post_outputs[0].detach() > 0.5).float(), (post_labels[0].detach() > 0.5).float() + val_hard_dice = self.soft_dice_metric(hard_preds, hard_labels) + + # compute precision and recall + # val_precision, val_recall = self.lesion_wise_precision_recall(post_outputs[0].detach().cpu(), post_labels[0].detach().cpu()) + # print("sucess val") + + # NOTE: there was a massive memory leak when storing cuda tensors in this dict. Hence, + # using .detach() to avoid storing the whole computation graph + # Ref: https://discuss.pytorch.org/t/cuda-memory-leak-while-training/82855/2 + metrics_dict = { + "val_loss": loss.detach().cpu(), + "val_soft_dice": val_soft_dice.detach().cpu(), + "val_hard_dice": val_hard_dice.detach().cpu(), + "val_number": len(post_outputs), + "val_image": inputs[0].detach().cpu().squeeze(), + "val_gt": labels[0].detach().cpu().squeeze(), + "val_pred": post_outputs[0].detach().cpu().squeeze(), + # "val_precision": val_precision.detach().cpu(), + # "val_recall": val_recall.detach().cpu(), + } + self.val_step_outputs.append(metrics_dict) + + return metrics_dict + + def on_validation_epoch_end(self): + + val_loss, num_items, val_soft_dice, val_hard_dice = 0, 0, 0, 0 + # val_precision, val_recall = 0, 0 + for output in self.val_step_outputs: + val_loss += output["val_loss"].sum().item() + val_soft_dice += output["val_soft_dice"].sum().item() + val_hard_dice += output["val_hard_dice"].sum().item() + num_items += output["val_number"] + # val_precision += output["val_precision"].sum().item() + # val_recall += output["val_recall"].sum().item() + + mean_val_loss = (val_loss / num_items) + mean_val_soft_dice = (val_soft_dice / num_items) + mean_val_hard_dice = (val_hard_dice / num_items) + # mean_val_precision = (val_precision / num_items) + # mean_val_recall = (val_recall / num_items) + + wandb_logs = { + "val_soft_dice": mean_val_soft_dice, + # "val_hard_dice": mean_val_hard_dice, + "val_loss": mean_val_loss, + # "val_precision": mean_val_precision, + # "val_recall": mean_val_recall, + } + + self.log_dict(wandb_logs) + + # save the best model based on validation dice score + if mean_val_soft_dice > self.best_val_dice: + self.best_val_dice = mean_val_soft_dice + self.best_val_epoch = self.current_epoch + + # save the best model based on validation loss + if mean_val_loss < self.best_val_loss: + self.best_val_loss = mean_val_loss + self.best_val_epoch = self.current_epoch + + logger.info( + f"\nCurrent epoch: {self.current_epoch}" + f"\nAverage Soft Dice (VAL): {mean_val_soft_dice:.4f}" + # f"\nAverage Hard Dice (VAL): {mean_val_hard_dice:.4f}" + f"\nAverage DiceLoss (VAL): {mean_val_loss:.4f}" + f"\nBest Average DiceLoss: {self.best_val_loss:.4f} at Epoch: {self.best_val_epoch}" + f"\n----------------------------------------------------") + + # log on to wandb + self.log_dict(wandb_logs) + + # plot 1 validation image + fig = plot_slices(image=self.val_step_outputs[0]["val_image"], + gt=self.val_step_outputs[0]["val_gt"], + pred=self.val_step_outputs[0]["val_pred"],) + wandb.log({"validation image 1": wandb.Image(fig)}) + plt.close(fig) + + # plot another validation image + fig0 = plot_slices(image=self.val_step_outputs[1]["val_image"], + gt=self.val_step_outputs[1]["val_gt"], + pred=self.val_step_outputs[1]["val_pred"],) + wandb.log({"validation image 2": wandb.Image(fig0)}) + plt.close(fig0) + + # free up memory + self.val_step_outputs.clear() + wandb_logs.clear() + + + # -------------------------------- + # TESTING + # -------------------------------- + def test_step(self, batch, batch_idx): + + test_input = batch["image"] + # print(batch["label_meta_dict"]["filename_or_obj"][0]) + batch["pred"] = sliding_window_inference(test_input, self.inference_roi_size, + sw_batch_size=4, predictor=self.forward, overlap=0.5) + + # normalize the logits + batch["pred"] = F.relu(batch["pred"]) / F.relu(batch["pred"]).max() if bool(F.relu(batch["pred"]).max()) else F.relu(batch["pred"]) + + post_test_out = [self.test_post_pred(i) for i in decollate_batch(batch)] + + # make sure that the shapes of prediction and GT label are the same + # print(f"pred shape: {post_test_out[0]['pred'].shape}, label shape: {post_test_out[0]['label'].shape}") + assert post_test_out[0]['pred'].shape == post_test_out[0]['label'].shape + + pred, label = post_test_out[0]['pred'].cpu(), post_test_out[0]['label'].cpu() + + # NOTE: Important point from the SoftSeg paper - binarize predictions before computing metrics + # calculate soft and hard dice here (for quick overview), other metrics can be computed from + # the saved predictions using ANIMA + # 1. Dice Score + test_soft_dice = self.soft_dice_metric(pred, label) + + # binarizing the predictions + pred = (post_test_out[0]['pred'].detach().cpu() > 0.5).float() + label = (post_test_out[0]['label'].detach().cpu() > 0.5).float() + + # 1.1 Hard Dice Score + test_hard_dice = self.soft_dice_metric(pred.numpy(), label.numpy()) + + metrics_dict = { + "test_hard_dice": test_hard_dice, + "test_soft_dice": test_soft_dice, + } + self.test_step_outputs.append(metrics_dict) + + return metrics_dict + + + def on_test_epoch_end(self): + + avg_hard_dice_test, std_hard_dice_test = np.stack([x["test_hard_dice"] for x in self.test_step_outputs]).mean(), \ + np.stack([x["test_hard_dice"] for x in self.test_step_outputs]).std() + avg_soft_dice_test, std_soft_dice_test = np.stack([x["test_soft_dice"] for x in self.test_step_outputs]).mean(), \ + np.stack([x["test_soft_dice"] for x in self.test_step_outputs]).std() + + logger.info(f"Test (Soft) Dice: {avg_soft_dice_test}") + logger.info(f"Test (Hard) Dice: {avg_hard_dice_test}") + + self.avg_test_dice, self.std_test_dice = avg_soft_dice_test, std_soft_dice_test + self.avg_test_dice_hard, self.std_test_dice_hard = avg_hard_dice_test, std_hard_dice_test + + # free up memory + self.test_step_outputs.clear() + + +# -------------------------------- +# MAIN +# -------------------------------- +def main(): + # get the parser + parser = get_parser() + args= parser.parse_args() + + # load config file + with open(args.config, "r") as f: + config = yaml.load(f, Loader=yaml.FullLoader) + + # Setting the seed + pl.seed_everything(config["seed"], workers=True) + + # define root path for finding datalists + dataset_root = config["data"] + + # define optimizer + optimizer_class = torch.optim.Adam + + output_path = os.path.join(config["output_path"], str(datetime.now().date()) +"_" +str(datetime.now().time())) + os.makedirs(output_path, exist_ok=True) + + wandb.init(project=f'monai-ms-lesion-seg-unet', config=config, save_code=True, dir=output_path) + + logger.info("Building the model ...") + + # define model + + # net = UNet( + # spatial_dims=3, + # in_channels=1, + # out_channels=1, + # channels=config['unet_channels'], + # strides=config['unet_strides'], + # kernel_size=3, + # up_kernel_size=3, + # num_res_units=0, + # act='PRELU', + # norm=Norm.INSTANCE, + # dropout=0.0, + # bias=True, + # adn_ordering='NDA', + # ) + + # net=UNet( + # spatial_dims=3, + # in_channels=1, + # out_channels=1, + # channels=(32, 64, 128), + # strides=(2, 2, 2, ), + # dropout=0.1 + # ) + + # net = AttentionUnet( + # spatial_dims=3, + # in_channels=1, + # out_channels=1, + # channels=config["attention_unet_channels"], + # strides=config["attention_unet_strides"], + # dropout=0.1, + # ) + net = DiffusionModelUNet( + spatial_dims=3, + in_channels=1, + out_channels=1, + num_channels=(64, 128, 256, 256), + attention_levels=(False, False, True, True), + num_res_blocks=(2, 2, 2, 2), + num_head_channels=32, + with_conditioning=False, + norm_eps= 1e-6, + dropout_cattn=0.1, + ) + + # net = SwinUNETR( + # img_size=config["spatial_size"], + # spatial_dims=3, + # in_channels=1, + # out_channels=1, + # feature_size=48, + # use_checkpoint=True, + # ) + + # net.use_multiprocessing = False + + # net = BasicUNet(spatial_dims=3, features=(32, 64, 128, 256, 32), out_channels=1) + + # net = create_nnunet_from_plans() + + logger.add(os.path.join(output_path, 'log.txt'), rotation="10 MB", level="INFO") + + # define loss function + # loss_func = AdapWingLoss(theta=0.5, omega=8, alpha=2.1, epsilon=1, reduction="sum") + # loss_func = DiceLoss(sigmoid=False, smooth_dr=1e-4) + loss_func = DiceCELoss(sigmoid=False, smooth_dr=1e-4) + # loss_func = SoftDiceLoss(smooth=1e-5) + # NOTE: tried increasing omega and decreasing epsilon but results marginally worse than the above + # loss_func = AdapWingLoss(theta=0.5, omega=12, alpha=2.1, epsilon=0.5, reduction="sum") + #logger.info(f"Using AdapWingLoss with theta={loss_func.theta}, omega={loss_func.omega}, alpha={loss_func.alpha}, epsilon={loss_func.epsilon} ...") + logger.info(f"Using DiceCELoss ...") + # define callbacks + early_stopping = pl.callbacks.EarlyStopping( + monitor="val_loss", min_delta=0.00, + patience=config["early_stopping_patience"], + verbose=False, mode="min") + + lr_monitor = pl.callbacks.LearningRateMonitor(logging_interval='epoch') + + best_model_path = os.path.join(output_path, "best_model.pth") + + # i.e. train by loading weights from scratch + pl_model = Model(config, data_root=dataset_root, + optimizer_class=optimizer_class, loss_function=loss_func, net=net, + exp_id="test", results_path=best_model_path) + + # saving the best model based on validation loss + checkpoint_callback_loss = pl.callbacks.ModelCheckpoint( + dirpath= best_model_path, filename='best_model', monitor='val_loss', + save_top_k=1, mode="min", save_last=True, save_weights_only=True) + + logger.info(f"Starting training from scratch ...") + # wandb logger + exp_logger = pl.loggers.WandbLogger( + name="test", + save_dir=output_path, + group="test-on-canproco", + log_model=True, # save best model using checkpoint callback + config=config) + + # Saving training script to wandb + wandb.save(args.config) + + # initialise Lightning's trainer. + trainer = pl.Trainer( + devices=1, accelerator="gpu", + logger=exp_logger, + callbacks=[checkpoint_callback_loss, lr_monitor, early_stopping], + check_val_every_n_epoch=config["eval_num"], + max_epochs=config["max_iterations"], + precision=32, + # precision='bf16-mixed', + enable_progress_bar=True) + # profiler="simple",) # to profile the training time taken for each step + + # Train! + trainer.fit(pl_model) + logger.info(f" Training Done!") + + # Closing wandb log + wandb.finish() + + +if __name__ == "__main__": + main() \ No newline at end of file From 6f807f9682dfb69bbc61e13f5bc5212bc1d341d2 Mon Sep 17 00:00:00 2001 From: Pierre-Louis Benveniste Date: Wed, 4 Sep 2024 14:35:29 -0400 Subject: [PATCH 095/108] created script to train a mednext model --- monai/train_monai_mednext_lightning.py | 32 ++++++++++++++++++-------- 1 file changed, 23 insertions(+), 9 deletions(-) diff --git a/monai/train_monai_mednext_lightning.py b/monai/train_monai_mednext_lightning.py index 4685086..48bb0de 100644 --- a/monai/train_monai_mednext_lightning.py +++ b/monai/train_monai_mednext_lightning.py @@ -69,6 +69,8 @@ ## read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision # torch.set_float32_matmul_precision('medium' | 'high') +# Adding Mednext model +from nnunet_mednext import MedNeXt def get_parser(): """ @@ -172,7 +174,7 @@ def prepare_data(self): spatial_size=self.cfg["spatial_size"], pos=1, neg=0, - num_samples=4, + num_samples=2, image_key="image", image_threshold=0, allow_smaller=True, @@ -720,14 +722,14 @@ def main(): # dropout=0.1 # ) - net = AttentionUnet( - spatial_dims=3, - in_channels=1, - out_channels=1, - channels=config["attention_unet_channels"], - strides=config["attention_unet_strides"], - dropout=0.1, - ) + # net = AttentionUnet( + # spatial_dims=3, + # in_channels=1, + # out_channels=1, + # channels=config["attention_unet_channels"], + # strides=config["attention_unet_strides"], + # dropout=0.1, + # ) # net = SwinUNETR( # img_size=config["spatial_size"], @@ -738,6 +740,18 @@ def main(): # use_checkpoint=True, # ) + net = MedNeXt( + in_channels=1, + n_channels=32, + n_classes=1, + exp_r=2, + kernel_size=3, + do_res=True, + do_res_up_down=True, + checkpoint_style="outside_block", + block_counts=[2,2,2,2,1,1,1,1,1] + ) + # net.use_multiprocessing = False # net = BasicUNet(spatial_dims=3, features=(32, 64, 128, 256, 32), out_channels=1) From 74b6eed39fca129e3d3d350250bc11833302744d Mon Sep 17 00:00:00 2001 From: Pierre-Louis Benveniste Date: Wed, 4 Sep 2024 14:40:39 -0400 Subject: [PATCH 096/108] removed cropping of image before inference --- monai/test_model.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/monai/test_model.py b/monai/test_model.py index 17cd80b..ec96b37 100644 --- a/monai/test_model.py +++ b/monai/test_model.py @@ -85,10 +85,10 @@ def main(): nonzero=False, channel_wise=False ), - ResizeWithPadOrCropd( - keys=["image", "label"], - spatial_size=cfg["spatial_size"], - ), + # ResizeWithPadOrCropd( + # keys=["image", "label"], + # spatial_size=cfg["spatial_size"], + # ), ] ) @@ -159,7 +159,7 @@ def main(): # Get file name file_name = test_files[i]["image"].split("/")[-1].split(".")[0] - print(f"Saving {file_name}") + print(f"Saving {file_name}: dice score = {dice}") # Save the prediction pred_saver = SaveImage( From c497c4e222510df12b484a041452c8cab1917656 Mon Sep 17 00:00:00 2001 From: Pierre-Louis Benveniste Date: Wed, 4 Sep 2024 15:14:38 -0400 Subject: [PATCH 097/108] added new config files --- monai/config.yml | 3 +++ monai/config_test.yml | 7 ++++--- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/monai/config.yml b/monai/config.yml index 36b7f3e..5e953b2 100644 --- a/monai/config.yml +++ b/monai/config.yml @@ -9,6 +9,8 @@ # data: /home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/data/monai_data/dataset_2024-04-05_seed42_lesionOnly.json # data: /home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/data/monai_data/dataset_2024-04-17_seed42_lesionOnly.json data: /home/plbenveniste/net/ms-lesion-agnostic/msd_data/dataset_2024-06-26_seed42_lesionOnly.json +# data: /home/plbenveniste/net/ms-lesion-agnostic/msd_data/dataset_2024-08-13_seed42_lesionOnly.json +# data: /home/plbenveniste/net/ms-lesion-agnostic/msd_data/fake.json # Resampling resolution # pixdim : [1.0, 1.0, 1.0] @@ -34,6 +36,7 @@ eval_num : 2 # Outputs # output_path : /home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/results/ output_path : /home/plbenveniste/net/ms-lesion-agnostic/results/ +# output_path : /home/plbenveniste/net/ms-lesion-agnostic/results_cropped_head/ # Seed seed : 42 diff --git a/monai/config_test.yml b/monai/config_test.yml index 75a1a1d..21ed127 100644 --- a/monai/config_test.yml +++ b/monai/config_test.yml @@ -8,8 +8,9 @@ attention_unet_channels : [32, 64, 128, 256, 512] attention_unet_strides : [2, 2, 2, 2, 2] # path_to_model : /home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/results/2024-04-21_16:06:04.890513/best_model.pth/best_model.ckpt -path_to_model : /home/plbenveniste/net/ms-lesion-agnostic/results/2024-07-18_10:46:21.634514/best_model.pth/best_model.ckpt +# path_to_model : /home/plbenveniste/net/ms-lesion-agnostic/tta_exp/best_model.pth/best_model.ckpt +path_to_model : /home/plbenveniste/net/ms-lesion-agnostic/thresholding_optimisation/best_model.pth/best_model.ckpt # output_dir : /home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/results/2024-04-21_16:06:04.890513/ -output_dir : /home/plbenveniste/net/ms-lesion-agnostic/results/2024-07-18_10:46:21.634514/ - +# output_dir : /home/plbenveniste/net/ms-lesion-agnostic/tta_exp +output_dir : /home/plbenveniste/net/ms-lesion-agnostic/thresholding_optimisation From 825d816983d7159bbac97cc7589f9bbf4bf956cc Mon Sep 17 00:00:00 2001 From: Pierre-Louis Benveniste Date: Wed, 4 Sep 2024 15:15:57 -0400 Subject: [PATCH 098/108] added script to perform inference and compute the dice score with varying threshold --- monai/test_model_optThresh.py | 308 ++++++++++++++++++++++++++++++++++ 1 file changed, 308 insertions(+) create mode 100644 monai/test_model_optThresh.py diff --git a/monai/test_model_optThresh.py b/monai/test_model_optThresh.py new file mode 100644 index 0000000..41bacc8 --- /dev/null +++ b/monai/test_model_optThresh.py @@ -0,0 +1,308 @@ +""" +This code is used to test the model on a test set. +It uses the class Model which was defined in the file train_monai_unet_lightning.py. +""" +import os +from monai.transforms import ( + Compose, + LoadImaged, + EnsureChannelFirstd, + Orientationd, + Spacingd, + NormalizeIntensityd, + ResizeWithPadOrCropd, + Invertd, + EnsureTyped, + SaveImage, +) +from monai.data import (DataLoader, CacheDataset, load_decathlon_datalist, decollate_batch, Dataset) +from monai.networks.nets import AttentionUnet +import torch +from monai.inferers import sliding_window_inference +import torch.nn.functional as F +from utils.utils import dice_score +import argparse +import yaml +import torch.multiprocessing +torch.multiprocessing.set_sharing_strategy('file_system') + + +def get_parser(): + """ + This function returns the parser for the command line arguments. + """ + parser = argparse.ArgumentParser(description="Test the model on the test set") + parser.add_argument("-c", "--config", help="Path to the config file (.yml file)", required=True) + parser.add_argument("--data-split", help="Data split to use (train, validation, test)", required=True, type=str) + return parser + + +def main(): + """ + This function is used to test the model on a test set. + + Args: + None + + Returns: + None + """ + # Get the parser + parser = get_parser() + args = parser.parse_args() + + # Load the config file + with open(args.config, "r") as f: + cfg = yaml.load(f, Loader=yaml.FullLoader) + + # Device + DEVICE = "cuda" + + # build output directory + output_dir = os.path.join(cfg["output_dir"], args.data_split +"_set") + if not os.path.exists(output_dir): + os.makedirs(output_dir) + + # Dict of dice score + dice_scores = {} + dice_scores_0_01 = {} + dice_scores_0_02 = {} + dice_scores_0_05 = {} + dice_scores_0_1 = {} + dice_scores_0_2 = {} + dice_scores_0_3 = {} + dice_scores_0_4 = {} + dice_scores_0_5 = {} + dice_scores_0_6 = {} + dice_scores_0_7 = {} + dice_scores_0_8 = {} + dice_scores_0_9 = {} + + + # Load the data + test_files = load_decathlon_datalist(cfg["dataset"], True, args.data_split) + + #Create the test transforms + test_transforms = Compose( + [ + LoadImaged(keys=["image", "label"], reader="NibabelReader"), + EnsureChannelFirstd(keys=["image", "label"]), + Orientationd(keys=["image", "label"], axcodes="RPI"), + Spacingd( + keys=["image", "label"], + pixdim=cfg["pixdim"], + mode=(2, 0), + ), + NormalizeIntensityd( + keys=["image"], + nonzero=False, + channel_wise=False + ), + # ResizeWithPadOrCropd( + # keys=["image", "label"], + # spatial_size=cfg["spatial_size"], + # ), + ] + ) + + # Create the prediction post-processing function + ## For this to work I had to add cupy-cuda117==10.6.0 to the requirements + test_post_pred = Compose([ + EnsureTyped(keys=["pred"]), + Invertd(keys=["pred"], transform=test_transforms, + orig_keys=["image"], + meta_keys=["pred_meta_dict"], + nearest_interp=False, to_tensor=True), + ]) + + # Create the data loader + test_ds = CacheDataset(data=test_files, transform=test_transforms, cache_rate=0.1, num_workers=0) + test_data_loader = DataLoader(test_ds, batch_size=1, shuffle=False, num_workers=0) + + # Load the model + net = AttentionUnet( + spatial_dims=3, + in_channels=1, + out_channels=1, + channels=cfg["attention_unet_channels"], + strides=cfg["attention_unet_strides"], + dropout=0.1, + ) + net.to(DEVICE) + checkpoint = torch.load(cfg["path_to_model"], map_location=torch.device(DEVICE))["state_dict"] + # NOTE: remove the 'net.' prefix from the keys because of how the model was initialized in lightning + # https://discuss.pytorch.org/t/missing-keys-unexpected-keys-in-state-dict-when-loading-self-trained-model/22379/14 + for key in list(checkpoint.keys()): + if 'net.' in key: + checkpoint[key.replace('net.', '')] = checkpoint[key] + del checkpoint[key] + # remove the key loss_function.dice.class_weights because it is not needed + # I had the error but I don't really know why + if 'loss_function.dice.class_weight' in key: + del checkpoint[key] + net.load_state_dict(checkpoint) + net.eval() + + # Run inference + with torch.no_grad(): + for i, batch in enumerate(test_data_loader): + # get the test input + test_input = batch["image"].to(DEVICE) + + # run inference + batch["pred"] = sliding_window_inference(test_input, cfg["spatial_size"], mode="gaussian", + sw_batch_size=4, predictor=net, overlap=0.5, progress=False) + + # NOTE: monai's models do not normalize the output, so we need to do it manually + if bool(F.relu(batch["pred"]).max()): + batch["pred"] = F.relu(batch["pred"]) / F.relu(batch["pred"]).max() + else: + batch["pred"] = F.relu(batch["pred"]) + + # compute the dice score + # dice = dice_score(batch["pred"].cpu(), batch["label"].cpu()) + + # post-process the prediction + post_test_out = [test_post_pred(i) for i in decollate_batch(batch)] + + pred = post_test_out[0]['pred'].cpu() + + # Threshold the prediction and compute the dice score + pred_0 = pred.clone() + pred_0[pred_0 < 0.01] = 0 + pred_0[pred_0 >= 0.01] = 1 + dice = dice_score(pred_0, batch["label"].cpu()) + + pred_0_01 = pred.clone() + pred_0_01[pred_0_01 < 0.01] = 0 + pred_0_01[pred_0_01 >= 0.01] = 1 + dice_0_01 = dice_score(pred_0_01, batch["label"].cpu()) + + pred_0_02 = pred.clone() + pred_0_02[pred_0_02 < 0.02] = 0 + pred_0_02[pred_0_02 >= 0.02] = 1 + dice_0_02 = dice_score(pred_0_02, batch["label"].cpu()) + + pred_0_05 = pred.clone() + pred_0_05[pred_0_05 < 0.05] = 0 + pred_0_05[pred_0_05 >= 0.05] = 1 + dice_0_05 = dice_score(pred_0_05, batch["label"].cpu()) + + pred_0_1 = pred.clone() + pred_0_1[pred_0_1 < 0.1] = 0 + pred_0_1[pred_0_1 >= 0.1] = 1 + dice_0_1 = dice_score(pred_0_1, batch["label"].cpu()) + + pred_0_2 = pred.clone() + pred_0_2[pred_0_2 < 0.2] = 0 + pred_0_2[pred_0_2 >= 0.2] = 1 + dice_0_2 = dice_score(pred_0_2, batch["label"].cpu()) + + pred_0_3 = pred.clone() + pred_0_3[pred_0_3 < 0.3] = 0 + pred_0_3[pred_0_3 >= 0.3] = 1 + dice_0_3 = dice_score(pred_0_3, batch["label"].cpu()) + + pred_0_4 = pred.clone() + pred_0_4[pred_0_4 < 0.4] = 0 + pred_0_4[pred_0_4 >= 0.4] = 1 + dice_0_4 = dice_score(pred_0_4, batch["label"].cpu()) + + pred_0_5 = pred.clone() + pred_0_5[pred_0_5 < 0.5] = 0 + pred_0_5[pred_0_5 >= 0.5] = 1 + dice_0_5 = dice_score(pred_0_5, batch["label"].cpu()) + + pred_0_6 = pred.clone() + pred_0_6[pred_0_6 < 0.6] = 0 + pred_0_6[pred_0_6 >= 0.6] = 1 + dice_0_6 = dice_score(pred_0_6, batch["label"].cpu()) + + pred_0_7 = pred.clone() + pred_0_7[pred_0_7 < 0.7] = 0 + pred_0_7[pred_0_7 >= 0.7] = 1 + dice_0_7 = dice_score(pred_0_7, batch["label"].cpu()) + + pred_0_8 = pred.clone() + pred_0_8[pred_0_8 < 0.8] = 0 + pred_0_8[pred_0_8 >= 0.8] = 1 + dice_0_8 = dice_score(pred_0_8, batch["label"].cpu()) + + pred_0_9 = pred.clone() + pred_0_9[pred_0_9 < 0.9] = 0 + pred_0_9[pred_0_9 >= 0.9] = 1 + dice_0_9 = dice_score(pred_0_9, batch["label"].cpu()) + + # Get file name + file_name = test_files[i]["image"].split("/")[-1].split(".")[0] + print(f"Saving {file_name}: dice score = {dice}") + + # Save the prediction + pred_saver = SaveImage( + output_dir=output_dir , output_postfix="pred", output_ext=".nii.gz", + separate_folder=False, print_log=False) + # save the prediction + pred_saver(pred) + + # Save the dice score + dice_scores[test_files[i]["image"]] = dice + dice_scores_0_01[test_files[i]["image"]] = dice_0_01 + dice_scores_0_02[test_files[i]["image"]] = dice_0_02 + dice_scores_0_05[test_files[i]["image"]] = dice_0_05 + dice_scores_0_1[test_files[i]["image"]] = dice_0_1 + dice_scores_0_2[test_files[i]["image"]] = dice_0_2 + dice_scores_0_3[test_files[i]["image"]] = dice_0_3 + dice_scores_0_4[test_files[i]["image"]] = dice_0_4 + dice_scores_0_5[test_files[i]["image"]] = dice_0_5 + dice_scores_0_6[test_files[i]["image"]] = dice_0_6 + dice_scores_0_7[test_files[i]["image"]] = dice_0_7 + dice_scores_0_8[test_files[i]["image"]] = dice_0_8 + dice_scores_0_9[test_files[i]["image"]] = dice_0_9 + + test_input.detach() + + + # Save the dice scores + with open(os.path.join(output_dir, "dice_scores.txt"), "w") as f: + for key, value in dice_scores.items(): + f.write(f"{key}: {value}\n") + with open(os.path.join(output_dir, "dice_scores_0_01.txt"), "w") as f: + for key, value in dice_scores_0_01.items(): + f.write(f"{key}: {value}\n") + with open(os.path.join(output_dir, "dice_scores_0_02.txt"), "w") as f: + for key, value in dice_scores_0_02.items(): + f.write(f"{key}: {value}\n") + with open(os.path.join(output_dir, "dice_scores_0_05.txt"), "w") as f: + for key, value in dice_scores_0_05.items(): + f.write(f"{key}: {value}\n") + with open(os.path.join(output_dir, "dice_scores_0_1.txt"), "w") as f: + for key, value in dice_scores_0_1.items(): + f.write(f"{key}: {value}\n") + with open(os.path.join(output_dir, "dice_scores_0_2.txt"), "w") as f: + for key, value in dice_scores_0_2.items(): + f.write(f"{key}: {value}\n") + with open(os.path.join(output_dir, "dice_scores_0_3.txt"), "w") as f: + for key, value in dice_scores_0_3.items(): + f.write(f"{key}: {value}\n") + with open(os.path.join(output_dir, "dice_scores_0_4.txt"), "w") as f: + for key, value in dice_scores_0_4.items(): + f.write(f"{key}: {value}\n") + with open(os.path.join(output_dir, "dice_scores_0_5.txt"), "w") as f: + for key, value in dice_scores_0_5.items(): + f.write(f"{key}: {value}\n") + with open(os.path.join(output_dir, "dice_scores_0_6.txt"), "w") as f: + for key, value in dice_scores_0_6.items(): + f.write(f"{key}: {value}\n") + with open(os.path.join(output_dir, "dice_scores_0_7.txt"), "w") as f: + for key, value in dice_scores_0_7.items(): + f.write(f"{key}: {value}\n") + with open(os.path.join(output_dir, "dice_scores_0_8.txt"), "w") as f: + for key, value in dice_scores_0_8.items(): + f.write(f"{key}: {value}\n") + with open(os.path.join(output_dir, "dice_scores_0_9.txt"), "w") as f: + for key, value in dice_scores_0_9.items(): + f.write(f"{key}: {value}\n") + + +if __name__ == "__main__": + main() \ No newline at end of file From ba6d90ee592b6a2b492c9e1a3c316b91f98f4f9f Mon Sep 17 00:00:00 2001 From: Pierre-Louis Benveniste Date: Wed, 4 Sep 2024 15:52:48 -0400 Subject: [PATCH 099/108] fixed .cpu problem and added more thresholds --- monai/test_model_optThresh.py | 78 ++++++++++++++++++++++++++++------- 1 file changed, 64 insertions(+), 14 deletions(-) diff --git a/monai/test_model_optThresh.py b/monai/test_model_optThresh.py index 41bacc8..a51f6ff 100644 --- a/monai/test_model_optThresh.py +++ b/monai/test_model_optThresh.py @@ -25,6 +25,7 @@ import yaml import torch.multiprocessing torch.multiprocessing.set_sharing_strategy('file_system') +import numpy as np def get_parser(): @@ -77,6 +78,9 @@ def main(): dice_scores_0_7 = {} dice_scores_0_8 = {} dice_scores_0_9 = {} + dice_scores_0_95 = {} + dice_scores_0_98 = {} + dice_scores_0_99 = {} # Load the data @@ -160,78 +164,112 @@ def main(): batch["pred"] = F.relu(batch["pred"]) # compute the dice score - # dice = dice_score(batch["pred"].cpu(), batch["label"].cpu()) + dice = dice_score(batch["pred"].cpu(), batch["label"].cpu()) # post-process the prediction post_test_out = [test_post_pred(i) for i in decollate_batch(batch)] pred = post_test_out[0]['pred'].cpu() + + pred_cpu = batch["pred"].cpu() + label_cpu = batch["label"].cpu() # Threshold the prediction and compute the dice score - pred_0 = pred.clone() + pred_0 = pred_cpu.clone() pred_0[pred_0 < 0.01] = 0 pred_0[pred_0 >= 0.01] = 1 dice = dice_score(pred_0, batch["label"].cpu()) + print(f"For thresh 0 dice score = {dice}") - pred_0_01 = pred.clone() + pred_0_01 = pred_cpu.clone() pred_0_01[pred_0_01 < 0.01] = 0 pred_0_01[pred_0_01 >= 0.01] = 1 dice_0_01 = dice_score(pred_0_01, batch["label"].cpu()) + print(f"For thresh 0.01 dice score = {dice_0_01}") - pred_0_02 = pred.clone() + pred_0_02 = pred_cpu.clone() pred_0_02[pred_0_02 < 0.02] = 0 pred_0_02[pred_0_02 >= 0.02] = 1 dice_0_02 = dice_score(pred_0_02, batch["label"].cpu()) + print(f"For thresh 0.02 dice score = {dice_0_02}") - pred_0_05 = pred.clone() + pred_0_05 = pred_cpu.clone() pred_0_05[pred_0_05 < 0.05] = 0 pred_0_05[pred_0_05 >= 0.05] = 1 dice_0_05 = dice_score(pred_0_05, batch["label"].cpu()) + print(f"For thresh 0.05 dice score = {dice_0_05}") - pred_0_1 = pred.clone() + pred_0_1 = pred_cpu.clone() pred_0_1[pred_0_1 < 0.1] = 0 pred_0_1[pred_0_1 >= 0.1] = 1 dice_0_1 = dice_score(pred_0_1, batch["label"].cpu()) + print(f"For thresh 0.1 dice score = {dice_0_1}") - pred_0_2 = pred.clone() + pred_0_2 = pred_cpu.clone() pred_0_2[pred_0_2 < 0.2] = 0 pred_0_2[pred_0_2 >= 0.2] = 1 dice_0_2 = dice_score(pred_0_2, batch["label"].cpu()) + print(f"For thresh 0.2 dice score = {dice_0_2}") - pred_0_3 = pred.clone() + pred_0_3 = pred_cpu.clone() pred_0_3[pred_0_3 < 0.3] = 0 pred_0_3[pred_0_3 >= 0.3] = 1 dice_0_3 = dice_score(pred_0_3, batch["label"].cpu()) + print(f"For thresh 0.3 dice score = {dice_0_3}") - pred_0_4 = pred.clone() + pred_0_4 = pred_cpu.clone() pred_0_4[pred_0_4 < 0.4] = 0 pred_0_4[pred_0_4 >= 0.4] = 1 dice_0_4 = dice_score(pred_0_4, batch["label"].cpu()) + print(f"For thresh 0.4 dice score = {dice_0_4}") - pred_0_5 = pred.clone() + pred_0_5 = pred_cpu.clone() pred_0_5[pred_0_5 < 0.5] = 0 pred_0_5[pred_0_5 >= 0.5] = 1 dice_0_5 = dice_score(pred_0_5, batch["label"].cpu()) + print(f"For thresh 0.5 dice score = {dice_0_5}") - pred_0_6 = pred.clone() + pred_0_6 = pred_cpu.clone() pred_0_6[pred_0_6 < 0.6] = 0 pred_0_6[pred_0_6 >= 0.6] = 1 dice_0_6 = dice_score(pred_0_6, batch["label"].cpu()) + print(f"For thresh 0.6 dice score = {dice_0_6}") - pred_0_7 = pred.clone() + pred_0_7 = pred_cpu.clone() pred_0_7[pred_0_7 < 0.7] = 0 pred_0_7[pred_0_7 >= 0.7] = 1 dice_0_7 = dice_score(pred_0_7, batch["label"].cpu()) + print(f"For thresh 0.7 dice score = {dice_0_7}") - pred_0_8 = pred.clone() + pred_0_8 = pred_cpu.clone() pred_0_8[pred_0_8 < 0.8] = 0 pred_0_8[pred_0_8 >= 0.8] = 1 dice_0_8 = dice_score(pred_0_8, batch["label"].cpu()) + print(f"For thresh 0.8 dice score = {dice_0_8}") - pred_0_9 = pred.clone() + pred_0_9 = pred_cpu.clone() pred_0_9[pred_0_9 < 0.9] = 0 pred_0_9[pred_0_9 >= 0.9] = 1 dice_0_9 = dice_score(pred_0_9, batch["label"].cpu()) + print(f"For thresh 0.9 dice score = {dice_0_9}") + + pred_0_95 = pred_cpu.clone() + pred_0_95[pred_0_95 < 0.95] = 0 + pred_0_95[pred_0_95 >= 0.95] = 1 + dice_0_95 = dice_score(pred_0_95, batch["label"].cpu()) + print(f"For thresh 0.95 dice score = {dice_0_95}") + + pred_0_98 = pred_cpu.clone() + pred_0_98[pred_0_98 < 0.98] = 0 + pred_0_98[pred_0_98 >= 0.98] = 1 + dice_0_98 = dice_score(pred_0_98, batch["label"].cpu()) + print(f"For thresh 0.98 dice score = {dice_0_98}") + + pred_0_99 = pred_cpu.clone() + pred_0_99[pred_0_99 < 0.99] = 0 + pred_0_99[pred_0_99 >= 0.99] = 1 + dice_0_99 = dice_score(pred_0_99, batch["label"].cpu()) + print(f"For thresh 0.99 dice score = {dice_0_99}") # Get file name file_name = test_files[i]["image"].split("/")[-1].split(".")[0] @@ -258,6 +296,9 @@ def main(): dice_scores_0_7[test_files[i]["image"]] = dice_0_7 dice_scores_0_8[test_files[i]["image"]] = dice_0_8 dice_scores_0_9[test_files[i]["image"]] = dice_0_9 + dice_scores_0_95[test_files[i]["image"]] = dice_0_95 + dice_scores_0_98[test_files[i]["image"]] = dice_0_98 + dice_scores_0_99[test_files[i]["image"]] = dice_0_99 test_input.detach() @@ -302,6 +343,15 @@ def main(): with open(os.path.join(output_dir, "dice_scores_0_9.txt"), "w") as f: for key, value in dice_scores_0_9.items(): f.write(f"{key}: {value}\n") + with open(os.path.join(output_dir, "dice_scores_0_95.txt"), "w") as f: + for key, value in dice_scores_0_95.items(): + f.write(f"{key}: {value}\n") + with open(os.path.join(output_dir, "dice_scores_0_98.txt"), "w") as f: + for key, value in dice_scores_0_98.items(): + f.write(f"{key}: {value}\n") + with open(os.path.join(output_dir, "dice_scores_0_99.txt"), "w") as f: + for key, value in dice_scores_0_99.items(): + f.write(f"{key}: {value}\n") if __name__ == "__main__": From 9027f32dc63e5f9b6cf40f04ed51b1456fe21b19 Mon Sep 17 00:00:00 2001 From: Pierre-Louis Benveniste Date: Wed, 4 Sep 2024 15:53:27 -0400 Subject: [PATCH 100/108] first draft of script for TTA --- monai/average_tta_performance.py | 84 ++++++++++++ monai/test_model_tta.py | 216 +++++++++++++++++++++++++++++++ 2 files changed, 300 insertions(+) create mode 100644 monai/average_tta_performance.py create mode 100644 monai/test_model_tta.py diff --git a/monai/average_tta_performance.py b/monai/average_tta_performance.py new file mode 100644 index 0000000..14e6498 --- /dev/null +++ b/monai/average_tta_performance.py @@ -0,0 +1,84 @@ +""" +This file is used to get all the dice_scores_X.txt files in a directory and average them. + +Input: + - Path to the directory containing the dice_scores_X.txt files + +Output: + None + +Example: + python average_tta_performance.py --pred-dir-path /path/to/dice_scores + +Author: Pierre-Louis Benveniste +""" + +import os +import argparse +import numpy as np +import pandas as pd +from pathlib import Path + + +def get_parser(): + """ + This function returns the parser for the command line arguments. + """ + parser = argparse.ArgumentParser(description="Average the performance of the model") + parser.add_argument("--pred-dir-path", help="Path to the directory containing the dice_scores_X.txt files", required=True) + return parser + + +def main(): + """ + This function is used to average the performance of the model on the test set. + + Args: + None + + Returns: + None + """ + # Get the parser + parser = get_parser() + args = parser.parse_args() + + # Path to the dice_scores + path_to_outputs = args.pred_dir_path + + # Get all the dice_scores_X.txt files using rglob + dice_score_files = [str(file) for file in Path(path_to_outputs).rglob("dice_scores_*.txt")] + + # Dict to store the dice scores + dice_scores = {} + + # Loop over the dice_scores_X.txt files + for dice_score_file in dice_score_files: + # Open dice results (they are txt files) + with open(os.path.join(path_to_outputs, dice_score_file), 'r') as file: + for line in file: + key, value = line.strip().split(':') + if key in dice_scores: + dice_scores[key].append(float(value)) + else: + dice_scores[key] = [float(value)] + + # Average the dice scores ang get standard deviation + std = {} + for key in dice_scores: + std[key] = np.std(dice_scores[key]) + dice_scores[key] = np.mean(dice_scores[key]) + + # Save the averaged dice scores + with open(os.path.join(path_to_outputs, "dice_scores.txt"), 'w') as file: + for key in dice_scores: + file.write(f"{key}: {dice_scores[key]}\n") + + # Save the standard deviation + with open(os.path.join(path_to_outputs, "std.txt"), 'w') as file: + for key in std: + file.write(f"{key}: {std[key]}\n") + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/monai/test_model_tta.py b/monai/test_model_tta.py new file mode 100644 index 0000000..07871ed --- /dev/null +++ b/monai/test_model_tta.py @@ -0,0 +1,216 @@ +""" +This code is used to test the model on a test set. +It uses the class Model which was defined in the file train_monai_unet_lightning.py. +""" +import os +from monai.transforms import ( + Compose, + LoadImaged, + EnsureChannelFirstd, + Orientationd, + Spacingd, + NormalizeIntensityd, + ResizeWithPadOrCropd, + Invertd, + EnsureTyped, + SaveImage, + RandGaussianNoised, + RandFlipd, + Rand3DElasticd +) +from monai.data import (DataLoader, CacheDataset, load_decathlon_datalist, decollate_batch, Dataset) +from monai.networks.nets import AttentionUnet +import torch +from monai.inferers import sliding_window_inference +import torch.nn.functional as F +from utils.utils import dice_score +import argparse +import yaml +import torch.multiprocessing +torch.multiprocessing.set_sharing_strategy('file_system') + + +def get_parser(): + """ + This function returns the parser for the command line arguments. + """ + parser = argparse.ArgumentParser(description="Test the model on the test set") + parser.add_argument("-c", "--config", help="Path to the config file (.yml file)", required=True) + parser.add_argument("--data-split", help="Data split to use (train, validation, test)", required=True, type=str) + return parser + + +def main(): + """ + This function is used to test the model on a test set. + + Args: + None + + Returns: + None + """ + # Get the parser + parser = get_parser() + args = parser.parse_args() + + # Load the config file + with open(args.config, "r") as f: + cfg = yaml.load(f, Loader=yaml.FullLoader) + + # Device + DEVICE = "cuda" + + # build output directory + output_dir = os.path.join(cfg["output_dir"], args.data_split +"_set") + if not os.path.exists(output_dir): + os.makedirs(output_dir) + + # Num test time augmentations + n_tta = 10 + + # Dict of dice score + dice_scores = [{} for i in range(n_tta)] + + # Load the data + test_files = load_decathlon_datalist(cfg["dataset"], True, args.data_split) + + #Create the test transforms + test_transforms = Compose( + [ + LoadImaged(keys=["image", "label"], reader="NibabelReader"), + EnsureChannelFirstd(keys=["image", "label"]), + Orientationd(keys=["image", "label"], axcodes="RPI"), + Spacingd( + keys=["image", "label"], + pixdim=cfg["pixdim"], + mode=(2, 0), + ), + NormalizeIntensityd( + keys=["image"], + nonzero=False, + channel_wise=False + ), + RandGaussianNoised( + keys=["image"], + prob=0.2, + ), + # Flips the image : supperior becomes inferior + RandFlipd( + keys=["image"], + spatial_axis=[1], + prob=0.2, + ), + # Flips the image : anterior becomes posterior + RandFlipd( + keys=["image"], + spatial_axis=[2], + prob=0.2, + ), + # Random elastic deformation + Rand3DElasticd( + keys=["image"], + sigma_range=(5, 7), + magnitude_range=(50, 150), + prob=0.2, + mode='bilinear', + ), + # ResizeWithPadOrCropd( + # keys=["image", "label"], + # spatial_size=cfg["spatial_size"], + # ), + ] + ) + + # Create the prediction post-processing function + ## For this to work I had to add cupy-cuda117==10.6.0 to the requirements + test_post_pred = Compose([ + EnsureTyped(keys=["pred"]), + Invertd(keys=["pred"], transform=test_transforms, + orig_keys=["image"], + meta_keys=["pred_meta_dict"], + nearest_interp=False, to_tensor=True), + ]) + + # Load the model + net = AttentionUnet( + spatial_dims=3, + in_channels=1, + out_channels=1, + channels=cfg["attention_unet_channels"], + strides=cfg["attention_unet_strides"], + dropout=0.1, + ) + net.to(DEVICE) + checkpoint = torch.load(cfg["path_to_model"], map_location=torch.device(DEVICE))["state_dict"] + # NOTE: remove the 'net.' prefix from the keys because of how the model was initialized in lightning + # https://discuss.pytorch.org/t/missing-keys-unexpected-keys-in-state-dict-when-loading-self-trained-model/22379/14 + for key in list(checkpoint.keys()): + if 'net.' in key: + checkpoint[key.replace('net.', '')] = checkpoint[key] + del checkpoint[key] + # remove the key loss_function.dice.class_weights because it is not needed + # I had the error but I don't really know why + if 'loss_function.dice.class_weight' in key: + del checkpoint[key] + net.load_state_dict(checkpoint) + net.eval() + + # Create the data loader + test_ds = [CacheDataset(data=test_files, transform=test_transforms, cache_rate=0.1, num_workers=0) for i in range(n_tta)] + + # Run inference + with torch.no_grad(): + for k in range(n_tta): + test_data_loader = DataLoader(test_ds[k], batch_size=1, shuffle=False, num_workers=0) + for i, batch in enumerate(test_data_loader): + # get the test input + test_input = batch["image"].to(DEVICE) + + # run inference + batch["pred"] = sliding_window_inference(test_input, cfg["spatial_size"], mode="gaussian", + sw_batch_size=4, predictor=net, overlap=0.5, progress=False) + + # NOTE: monai's models do not normalize the output, so we need to do it manually + if bool(F.relu(batch["pred"]).max()): + batch["pred"] = F.relu(batch["pred"]) / F.relu(batch["pred"]).max() + else: + batch["pred"] = F.relu(batch["pred"]) + + # compute the dice score + dice = dice_score(batch["pred"].cpu(), batch["label"].cpu()) + + # post-process the prediction + post_test_out = [test_post_pred(i) for i in decollate_batch(batch)] + + pred = post_test_out[0]['pred'].cpu() + + # Threshold the prediction + pred[pred < 0.5] = 0 + + # Get file name + file_name = test_files[i]["image"].split("/")[-1].split(".")[0] + print(f"Saving {file_name}") + + # Save the prediction + pred_saver = SaveImage( + output_dir=output_dir , output_postfix="pred", output_ext=f"_{k}.nii.gz", + separate_folder=False, print_log=False) + # save the prediction + pred_saver(pred) + + # Save the dice score + dice_scores[k][test_files[i]["image"]] = dice + + test_input.detach() + + + # Save the dice scores + for j in range(n_tta): + with open(os.path.join(output_dir, f"dice_scores_{j}.txt"), "w") as f: + for key, value in dice_scores[j].items(): + f.write(f"{key}: {value}\n") + + +if __name__ == "__main__": + main() \ No newline at end of file From 11eec85e2876d4145a96f6efc91de1dcc37b2e22 Mon Sep 17 00:00:00 2001 From: Pierre-Louis Benveniste Date: Thu, 5 Sep 2024 10:51:20 -0400 Subject: [PATCH 101/108] added code to plot the opt threshold output --- monai/plot_optThresh.py | 85 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 85 insertions(+) create mode 100644 monai/plot_optThresh.py diff --git a/monai/plot_optThresh.py b/monai/plot_optThresh.py new file mode 100644 index 0000000..17d4908 --- /dev/null +++ b/monai/plot_optThresh.py @@ -0,0 +1,85 @@ +""" +This script plots the performance of the model based on the threshold applied to the predictions. + +Input: + --path-scores: Path to the directory containing the dice_scores_X.txt files + +Output: + None + +Example: + python plot_optThresh.py --path-scores /path/to/dice_scores + +Author: Pierre-Louis Benveniste +""" + +import os +import argparse +import numpy as np +from pathlib import Path +import pandas as pd +import matplotlib.pyplot as plt +import seaborn as sns + + +def get_parser(): + """ + This function returns the parser for the command line arguments. + """ + parser = argparse.ArgumentParser(description="Plot the optimal threshold") + parser.add_argument("--path-scores", help="Path to the directory containing the dice_scores_X.txt files", required=True) + return parser + + +def main(): + + # Get the parser + parser = get_parser() + args = parser.parse_args() + + # Path to the dice_scores + path_to_outputs = args.path_scores + + # Get all the dice_scores_X.txt files using rglob + dice_score_files = [str(file) for file in Path(path_to_outputs).rglob("dice_scores_*.txt")] + + # Create a list to store the dataframes + test_dice_results_list = [None] * len(dice_score_files) + + # For each file, get the threshold and the dice score + for i, dice_score_file in enumerate(dice_score_files): + test_dice_results = {} + with open(dice_score_file, 'r') as file: + for line in file: + key, value = line.strip().split(':') + test_dice_results[key] = float(value) + # convert to a df with name and dice score + test_dice_results_list[i] = pd.DataFrame(list(test_dice_results.items()), columns=['name', 'dice_score']) + # Create a column which stores the threshold + test_dice_results_list[i]['threshold'] = str(Path(dice_score_file).name).replace('dice_scores_', '').replace('.txt', '').replace('_', '.') + + # Concatenate all the dataframes + test_dice_results = pd.concat(test_dice_results_list) + + # Plot + plt.figure(figsize=(20, 10)) + plt.grid(True) + sns.violinplot(x='threshold', y='dice_score', data=test_dice_results) + # y ranges from -0.2 to 1.2 + plt.ylim(-0.2, 1.2) + plt.title('Dice scores per threshold') + plt.show() + + # Save the plot + plt.savefig(path_to_outputs + '/dice_scores_contrast.png') + print(f"Saved the dice_scores plot in {path_to_outputs}") + + # Print the average dice score per threshold + for thresh in test_dice_results['threshold'].unique(): + print(f"Threshold: {thresh} - Average dice score: {test_dice_results[test_dice_results['threshold'] == thresh]['dice_score'].mean()}") + + return None + + +if __name__ == "__main__": + main() \ No newline at end of file From 38c1594c4ecb54dfdcabd70b590857e416f94d28 Mon Sep 17 00:00:00 2001 From: Pierre-Louis Benveniste Date: Thu, 5 Sep 2024 11:15:32 -0400 Subject: [PATCH 102/108] fixed threshold to 0.5 --- monai/test_model.py | 14 +++++++++----- monai/test_model_tta.py | 18 +++++++++++------- 2 files changed, 20 insertions(+), 12 deletions(-) diff --git a/monai/test_model.py b/monai/test_model.py index ec96b37..f958bda 100644 --- a/monai/test_model.py +++ b/monai/test_model.py @@ -146,16 +146,20 @@ def main(): else: batch["pred"] = F.relu(batch["pred"]) - # compute the dice score - dice = dice_score(batch["pred"].cpu(), batch["label"].cpu()) + # Threshold the prediction with 0.5 based on this investigation: https://github.com/ivadomed/ms-lesion-agnostic/issues/32 + pred_cpu = batch["pred"].cpu() + pred_cpu[pred_cpu < 0.5] = 0 + pred_cpu[pred_cpu >= 0.5] = 1 + # Compute the dice score + dice = dice_score(pred_cpu, batch["label"].cpu) # post-process the prediction post_test_out = [test_post_pred(i) for i in decollate_batch(batch)] - - pred = post_test_out[0]['pred'].cpu() - # Threshold the prediction + # Threshold the prediction with 0.5 before saving + pred = post_test_out[0]['pred'].cpu() pred[pred < 0.5] = 0 + pred[pred >= 0.5] = 1 # Get file name file_name = test_files[i]["image"].split("/")[-1].split(".")[0] diff --git a/monai/test_model_tta.py b/monai/test_model_tta.py index 07871ed..8c86186 100644 --- a/monai/test_model_tta.py +++ b/monai/test_model_tta.py @@ -177,24 +177,28 @@ def main(): else: batch["pred"] = F.relu(batch["pred"]) - # compute the dice score - dice = dice_score(batch["pred"].cpu(), batch["label"].cpu()) + # Threshold the prediction with 0.5 based on this investigation: https://github.com/ivadomed/ms-lesion-agnostic/issues/32 + pred_cpu = batch["pred"].cpu() + pred_cpu[pred_cpu < 0.5] = 0 + pred_cpu[pred_cpu >= 0.5] = 1 + # Compute the dice score + dice = dice_score(pred_cpu, batch["label"].cpu) # post-process the prediction post_test_out = [test_post_pred(i) for i in decollate_batch(batch)] - - pred = post_test_out[0]['pred'].cpu() - # Threshold the prediction + # Threshold the prediction with 0.5 before saving + pred = post_test_out[0]['pred'].cpu() pred[pred < 0.5] = 0 + pred[pred >= 0.5] = 1 # Get file name file_name = test_files[i]["image"].split("/")[-1].split(".")[0] - print(f"Saving {file_name}") + print(f"Saving {file_name}: dice score = {dice}") # Save the prediction pred_saver = SaveImage( - output_dir=output_dir , output_postfix="pred", output_ext=f"_{k}.nii.gz", + output_dir=output_dir , output_postfix=f"pred_{k}", output_ext=".nii.gz", separate_folder=False, print_log=False) # save the prediction pred_saver(pred) From e3b1effe985100c5a3869775548ba8e75fe5986c Mon Sep 17 00:00:00 2001 From: Pierre-Louis Benveniste Date: Wed, 11 Sep 2024 10:57:14 -0400 Subject: [PATCH 103/108] fixed parenthesis when computing dice score --- monai/test_model.py | 2 +- monai/test_model_tta.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/test_model.py b/monai/test_model.py index f958bda..7f46a50 100644 --- a/monai/test_model.py +++ b/monai/test_model.py @@ -151,7 +151,7 @@ def main(): pred_cpu[pred_cpu < 0.5] = 0 pred_cpu[pred_cpu >= 0.5] = 1 # Compute the dice score - dice = dice_score(pred_cpu, batch["label"].cpu) + dice = dice_score(pred_cpu, batch["label"].cpu()) # post-process the prediction post_test_out = [test_post_pred(i) for i in decollate_batch(batch)] diff --git a/monai/test_model_tta.py b/monai/test_model_tta.py index 8c86186..1337f13 100644 --- a/monai/test_model_tta.py +++ b/monai/test_model_tta.py @@ -182,7 +182,7 @@ def main(): pred_cpu[pred_cpu < 0.5] = 0 pred_cpu[pred_cpu >= 0.5] = 1 # Compute the dice score - dice = dice_score(pred_cpu, batch["label"].cpu) + dice = dice_score(pred_cpu, batch["label"].cpu()) # post-process the prediction post_test_out = [test_post_pred(i) for i in decollate_batch(batch)] From a37c6a013a6fe6ae2cc6f5958081dc44e8f64c5d Mon Sep 17 00:00:00 2001 From: Pierre-Louis Benveniste Date: Wed, 11 Sep 2024 10:57:37 -0400 Subject: [PATCH 104/108] added script to compute TTA with 2nd strategy --- monai/compute_performance_tta_sum.py | 130 +++++++++++++++++++++++++++ 1 file changed, 130 insertions(+) create mode 100644 monai/compute_performance_tta_sum.py diff --git a/monai/compute_performance_tta_sum.py b/monai/compute_performance_tta_sum.py new file mode 100644 index 0000000..a8702a5 --- /dev/null +++ b/monai/compute_performance_tta_sum.py @@ -0,0 +1,130 @@ +""" +This script is used to sum all the image predictions of the same subject, then threshold to 0.5 and then compute the dice score. + +Input: + --path-pred: Path to the directory containing the predictions + --path-json: Path to the json file containing the data split + --split: Data split to use (train, validation, test) + --output-dir: Output directory to save the dice scores + +Output: + None + +Example: + python compute_performance_tta_sum.py --path-pred /path/to/predictions --path-json /path/to/data.json --split test --output-dir /path/to/output + +Author: Pierre-Louis Benveniste +""" + +import os +import numpy as np +import argparse +from pathlib import Path +import json +import nibabel as nib +from tqdm import tqdm + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--path-pred", type=str, required=True, help="Path to the directory containing the predictions") + parser.add_argument("--path-json", type=str, required=True, help="Path to the json file containing the data split") + parser.add_argument("--split", type=str, required=True, help="Data split to use (train, validation, test)") + parser.add_argument("--output-dir", type=str, required=True, help="Output directory to save the dice scores") + return parser.parse_args() + + +def dice_score(prediction, groundtruth, smooth=1.): + numer = (prediction * groundtruth).sum() + denor = (prediction + groundtruth).sum() + # loss = (2 * numer + self.smooth) / (denor + self.smooth) + dice = (2 * numer + smooth) / (denor + smooth) + return dice + + +def main(): + + # Parse arguments + args = parse_args() + path_pred = args.path_pred + path_json = args.path_json + split = args.split + output_dir = args.output_dir + + # Create the output directory + if not os.path.exists(output_dir): + os.makedirs(output_dir) + + # Get all the predictions (with rglob) + predictions = list(Path(path_pred).rglob("*.nii.gz")) + + # List of subjects + subjects = [pred.name for pred in predictions] + + n_tta = 10 + + for i in range(n_tta): + # Remove the _pred_0, _pred_1 ... _pred_9 at the end of the name + subjects = [sub.replace(f"_pred_{i}", "") for sub in subjects] + + # Open the conversion dictionary (its a json file) + with open(path_json, "r") as f: + conversion_dict = json.load(f) + conversion_dict = conversion_dict[split] + + # Dict of dice score + dice_scores = {} + + # Iterate over the subjects in the predictions + for subject in subjects: + print(f"Processing subject {subject}") + + # Get all predictions corresponding to the subject + subject_predictions = [str(pred) for pred in predictions if subject.replace(".nii.gz", "") in pred.name] + # print(subject_predictions) + + # Find the corresponding label from the conversion dict + + image_dict = [data for data in conversion_dict if subject in data["image"]] + label = image_dict[0]["label"] + image = image_dict[0]["image"] + + # We now sum all the predictions + summed_prediction = None + for pred in subject_predictions: + pred_data = nib.load(pred).get_fdata() + if summed_prediction is None: + summed_prediction = pred_data + else: + summed_prediction += pred_data + + # Threshold the summed prediction + summed_prediction[summed_prediction >= 0.5] = 1 + summed_prediction[summed_prediction < 0.5] = 0 + + # Load the label + label_data = nib.load(label).get_fdata() + + # Compute dice score + dice = dice_score(summed_prediction, label_data) + # print(f"Dice score for summed prediction: {dice}") + + # Compare the dice score with the individual predictions + for pred in subject_predictions: + pred_data = nib.load(pred).get_fdata() + dice_pred = dice_score(pred_data, label_data) + # print(f"Dice score for {pred}: {dice_pred}") + + # Save the dice score + dice_scores[image] = dice + + # Save the results + with open(os.path.join(output_dir, "dice_scores.txt"), "w") as f: + for key, value in dice_scores.items(): + f.write(f"{key}: {value}\n") + + return None + + +if __name__ == "__main__": + main() From 60506b423fbcbb0d9e95ae3194ea8f57e450c9de Mon Sep 17 00:00:00 2001 From: Pierre-Louis Benveniste Date: Wed, 11 Sep 2024 21:50:58 -0400 Subject: [PATCH 105/108] added script for mednext inference --- monai/test_model_mednext.py | 193 ++++++++++++++++++++++++++++++++++++ 1 file changed, 193 insertions(+) create mode 100644 monai/test_model_mednext.py diff --git a/monai/test_model_mednext.py b/monai/test_model_mednext.py new file mode 100644 index 0000000..9c1e6fd --- /dev/null +++ b/monai/test_model_mednext.py @@ -0,0 +1,193 @@ +""" +This code is used to test the model on a test set. +It uses the class Model which was defined in the file train_monai_unet_lightning.py. +""" +import os +from monai.transforms import ( + Compose, + LoadImaged, + EnsureChannelFirstd, + Orientationd, + Spacingd, + NormalizeIntensityd, + ResizeWithPadOrCropd, + Invertd, + EnsureTyped, + SaveImage, +) +from monai.data import (DataLoader, CacheDataset, load_decathlon_datalist, decollate_batch, Dataset) +from monai.networks.nets import AttentionUnet +import torch +from monai.inferers import sliding_window_inference +import torch.nn.functional as F +from utils.utils import dice_score +import argparse +import yaml +import torch.multiprocessing +torch.multiprocessing.set_sharing_strategy('file_system') + +from nnunet_mednext import MedNeXt + +def get_parser(): + """ + This function returns the parser for the command line arguments. + """ + parser = argparse.ArgumentParser(description="Test the model on the test set") + parser.add_argument("-c", "--config", help="Path to the config file (.yml file)", required=True) + parser.add_argument("--data-split", help="Data split to use (train, validation, test)", required=True, type=str) + return parser + + +def main(): + """ + This function is used to test the model on a test set. + + Args: + None + + Returns: + None + """ + # Get the parser + parser = get_parser() + args = parser.parse_args() + + # Load the config file + with open(args.config, "r") as f: + cfg = yaml.load(f, Loader=yaml.FullLoader) + + # Device + DEVICE = "cuda" + + # build output directory + output_dir = os.path.join(cfg["output_dir"], args.data_split +"_set") + if not os.path.exists(output_dir): + os.makedirs(output_dir) + + # Dict of dice score + dice_scores = {} + + # Load the data + test_files = load_decathlon_datalist(cfg["dataset"], True, args.data_split) + + #Create the test transforms + test_transforms = Compose( + [ + LoadImaged(keys=["image", "label"], reader="NibabelReader"), + EnsureChannelFirstd(keys=["image", "label"]), + Orientationd(keys=["image", "label"], axcodes="RPI"), + Spacingd( + keys=["image", "label"], + pixdim=cfg["pixdim"], + mode=(2, 0), + ), + NormalizeIntensityd( + keys=["image"], + nonzero=False, + channel_wise=False + ), + # ResizeWithPadOrCropd( + # keys=["image", "label"], + # spatial_size=cfg["spatial_size"], + # ), + ] + ) + + # Create the prediction post-processing function + ## For this to work I had to add cupy-cuda117==10.6.0 to the requirements + test_post_pred = Compose([ + EnsureTyped(keys=["pred"]), + Invertd(keys=["pred"], transform=test_transforms, + orig_keys=["image"], + meta_keys=["pred_meta_dict"], + nearest_interp=False, to_tensor=True), + ]) + + # Create the data loader + test_ds = CacheDataset(data=test_files, transform=test_transforms, cache_rate=0.1, num_workers=0) + test_data_loader = DataLoader(test_ds, batch_size=1, shuffle=False, num_workers=0) + + # Load the model + net = MedNeXt( + in_channels=1, + n_channels=32, + n_classes=1, + exp_r=2, + kernel_size=3, + do_res=True, + do_res_up_down=True, + checkpoint_style="outside_block", + block_counts=[2,2,2,2,1,1,1,1,1] + ) + + net.to(DEVICE) + checkpoint = torch.load(cfg["path_to_model"], map_location=torch.device(DEVICE))["state_dict"] + # NOTE: remove the 'net.' prefix from the keys because of how the model was initialized in lightning + # https://discuss.pytorch.org/t/missing-keys-unexpected-keys-in-state-dict-when-loading-self-trained-model/22379/14 + for key in list(checkpoint.keys()): + if 'net.' in key: + checkpoint[key.replace('net.', '')] = checkpoint[key] + del checkpoint[key] + # remove the key loss_function.dice.class_weights because it is not needed + # I had the error but I don't really know why + if 'loss_function.dice.class_weight' in key: + del checkpoint[key] + net.load_state_dict(checkpoint) + net.eval() + + # Run inference + with torch.no_grad(): + for i, batch in enumerate(test_data_loader): + # get the test input + test_input = batch["image"].to(DEVICE) + + # run inference + batch["pred"] = sliding_window_inference(test_input, cfg["spatial_size"], mode="gaussian", + sw_batch_size=4, predictor=net, overlap=0.5, progress=False) + + # NOTE: monai's models do not normalize the output, so we need to do it manually + if bool(F.relu(batch["pred"]).max()): + batch["pred"] = F.relu(batch["pred"]) / F.relu(batch["pred"]).max() + else: + batch["pred"] = F.relu(batch["pred"]) + + # Threshold the prediction with 0.5 based on this investigation: https://github.com/ivadomed/ms-lesion-agnostic/issues/32 + pred_cpu = batch["pred"].cpu() + pred_cpu[pred_cpu < 0.5] = 0 + pred_cpu[pred_cpu >= 0.5] = 1 + # Compute the dice score + dice = dice_score(pred_cpu, batch["label"].cpu()) + + # post-process the prediction + post_test_out = [test_post_pred(i) for i in decollate_batch(batch)] + + # Threshold the prediction with 0.5 before saving + pred = post_test_out[0]['pred'].cpu() + pred[pred < 0.5] = 0 + pred[pred >= 0.5] = 1 + + # Get file name + file_name = test_files[i]["image"].split("/")[-1].split(".")[0] + print(f"Saving {file_name}: dice score = {dice}") + + # Save the prediction + pred_saver = SaveImage( + output_dir=output_dir , output_postfix="pred", output_ext=".nii.gz", + separate_folder=False, print_log=False) + # save the prediction + pred_saver(pred) + + # Save the dice score + dice_scores[test_files[i]["image"]] = dice + + test_input.detach() + + + # Save the dice scores + with open(os.path.join(output_dir, "dice_scores.txt"), "w") as f: + for key, value in dice_scores.items(): + f.write(f"{key}: {value}\n") + + +if __name__ == "__main__": + main() \ No newline at end of file From a79cb44e3ebfac26e2957b69fa450426d833c01b Mon Sep 17 00:00:00 2001 From: Pierre-Louis Benveniste Date: Wed, 25 Sep 2024 17:51:38 -0400 Subject: [PATCH 106/108] added computation of f1-score, ppv and sensitivity --- monai/plot_performance.py | 109 +++++++++++++++++++++++++++++++------- monai/test_model.py | 26 +++++++-- monai/utils/utils.py | 68 ++++++++++++------------ 3 files changed, 146 insertions(+), 57 deletions(-) diff --git a/monai/plot_performance.py b/monai/plot_performance.py index 7b586c5..2fc719d 100644 --- a/monai/plot_performance.py +++ b/monai/plot_performance.py @@ -86,47 +86,118 @@ def main(): resolution_counts = test_dice_results['orientation'].value_counts() test_dice_results['orientation_count'] = test_dice_results['orientation'].apply(lambda x: x + f' (n={resolution_counts[x]})') - # plot a violin plot per contrast + # then we add the ppv score to the df + ppv_score_file = path_to_outputs + '/ppv_scores.txt' + ppv_scores = {} + with open(ppv_score_file, 'r') as file: + for line in file: + key, value = line.strip().split(':') + ppv_scores[key] = float(value) + test_dice_results['ppv_score'] = test_dice_results['name'].apply(lambda x: ppv_scores[x]) + + # then we add the f1 score to the df + f1_score_file = path_to_outputs + '/f1_scores.txt' + f1_scores = {} + with open(f1_score_file, 'r') as file: + for line in file: + key, value = line.strip().split(':') + f1_scores[key] = float(value) + test_dice_results['f1_score'] = test_dice_results['name'].apply(lambda x: f1_scores[x]) + + # then we add the sensitivity score to the df + sensitivity_score_file = path_to_outputs + '/sensitivity_scores.txt' + sensitivity_scores = {} + with open(sensitivity_score_file, 'r') as file: + for line in file: + key, value = line.strip().split(':') + sensitivity_scores[key] = float(value) + test_dice_results['sensitivity_score'] = test_dice_results['name'].apply(lambda x: sensitivity_scores[x]) + + # We rename th df to metrics_results + metrics_results = test_dice_results + + # Sort the order of the lines by contrast (alphabetical order) + metrics_results = metrics_results.sort_values(by='contrast').reset_index(drop=True) + + # plot a violin plot per contrast for dice scores plt.figure(figsize=(20, 10)) plt.grid(True) - sns.violinplot(x='contrast_count', y='dice_score', data=test_dice_results) + sns.violinplot(x='contrast_count', y='dice_score', data=metrics_results) # y ranges from -0.2 to 1.2 plt.ylim(-0.2, 1.2) plt.title('Dice scores per contrast') plt.show() - - # Save the plot + # # Save the plot plt.savefig(path_to_outputs + '/dice_scores_contrast.png') - print(f"Saved the dice_scores plot in {path_to_outputs}") + print(f"Saved the dice plot in {path_to_outputs}") - # plot a violin plot per site + # plot a violin plot per contrast for ppv scores plt.figure(figsize=(20, 10)) plt.grid(True) - sns.violinplot(x='site_count', y='dice_score', data=test_dice_results) + sns.violinplot(x='contrast_count', y='ppv_score', data=metrics_results) # y ranges from -0.2 to 1.2 plt.ylim(-0.2, 1.2) - plt.title('Dice scores per site') + plt.title('PPV scores per contrast') plt.show() - # Save the plot - plt.savefig(path_to_outputs + '/dice_scores_site.png') - print(f"Saved the dice_scores per site plot in {path_to_outputs}") + # # Save the plot + plt.savefig(path_to_outputs + '/ppv_scores_contrast.png') + print(f"Saved the ppv plot in {path_to_outputs}") - # plot a violin plot per resolution + # plot a violin plot per contrast for f1 scores plt.figure(figsize=(20, 10)) plt.grid(True) - sns.violinplot(x='orientation_count', y='dice_score', data=test_dice_results) + sns.violinplot(x='contrast_count', y='f1_score', data=metrics_results) # y ranges from -0.2 to 1.2 plt.ylim(-0.2, 1.2) - plt.title('Dice scores per orientation') + plt.title('F1 scores per contrast') plt.show() - # Save the plot - plt.savefig(path_to_outputs + '/dice_scores_orientation.png') - print(f"Saved the dice_scores per orientation plot in {path_to_outputs}") + # # Save the plot + plt.savefig(path_to_outputs + '/f1_scores_contrast.png') + print(f"Saved the F1 plot in {path_to_outputs}") + + # plot a violin plot per contrast for f1 scores + plt.figure(figsize=(20, 10)) + plt.grid(True) + sns.violinplot(x='contrast_count', y='sensitivity_score', data=metrics_results) + # y ranges from -0.2 to 1.2 + plt.ylim(-0.2, 1.2) + plt.title('Sensitivity scores per contrast') + plt.show() - # Save the test_dice_results dataframe - test_dice_results.to_csv(path_to_outputs + '/dice_results.csv', index=False) + # # Save the plot + plt.savefig(path_to_outputs + '/sensitivity_scores_contrast.png') + print(f"Saved the sensitivity plot in {path_to_outputs}") + + # # plot a violin plot per site + # plt.figure(figsize=(20, 10)) + # plt.grid(True) + # sns.violinplot(x='site_count', y='dice_score', data=test_dice_results, order = ['bavaria-quebec (n=208)', 'sct-testing-large (n=233)', 'canproco (n=71)','nih (n=25)','basel (n=32)']) + # # y ranges from -0.2 to 1.2 + # plt.ylim(-0.2, 1.2) + # plt.title('Dice scores per site') + # plt.show() + + # # Save the plot + # plt.savefig(path_to_outputs + '/dice_scores_site.png') + # print(f"Saved the dice_scores per site plot in {path_to_outputs}") + + # # plot a violin plot per resolution + # plt.figure(figsize=(20, 10)) + # plt.grid(True) + # sns.violinplot(x='orientation_count', y='dice_score', data=test_dice_results, order = ['iso (n=58)', 'ax (n=343)', 'sag (n=168)']) + # # y ranges from -0.2 to 1.2 + # plt.ylim(-0.2, 1.2) + # plt.title('Dice scores per orientation') + # plt.show() + + # # Save the plot + # plt.savefig(path_to_outputs + '/dice_scores_orientation.png') + # print(f"Saved the dice_scores per orientation plot in {path_to_outputs}") + + # # Save the test_dice_results dataframe + # test_dice_results.to_csv(path_to_outputs + '/dice_results.csv', index=False) return None diff --git a/monai/test_model.py b/monai/test_model.py index 7f46a50..d0bb1d5 100644 --- a/monai/test_model.py +++ b/monai/test_model.py @@ -20,7 +20,7 @@ import torch from monai.inferers import sliding_window_inference import torch.nn.functional as F -from utils.utils import dice_score +from utils.utils import dice_score, lesion_f1_score, lesion_ppv, lesion_sensitivity import argparse import yaml import torch.multiprocessing @@ -63,8 +63,11 @@ def main(): if not os.path.exists(output_dir): os.makedirs(output_dir) - # Dict of dice score + # Dict of scores dice_scores = {} + ppv_scores = {} + sensitivity_scores = {} + f1_scores = {} # Load the data test_files = load_decathlon_datalist(cfg["dataset"], True, args.data_split) @@ -152,6 +155,9 @@ def main(): pred_cpu[pred_cpu >= 0.5] = 1 # Compute the dice score dice = dice_score(pred_cpu, batch["label"].cpu()) + ppv = lesion_ppv(batch["label"].cpu(), pred_cpu) + sensitivity = lesion_sensitivity(batch["label"].cpu(), pred_cpu) + f1 = lesion_f1_score(batch["label"].cpu(), pred_cpu) # post-process the prediction post_test_out = [test_post_pred(i) for i in decollate_batch(batch)] @@ -163,7 +169,7 @@ def main(): # Get file name file_name = test_files[i]["image"].split("/")[-1].split(".")[0] - print(f"Saving {file_name}: dice score = {dice}") + print(f"Saving {file_name}: dice score = {dice}, f1 = {f1}") # Save the prediction pred_saver = SaveImage( @@ -172,8 +178,11 @@ def main(): # save the prediction pred_saver(pred) - # Save the dice score + # Save the scores dice_scores[test_files[i]["image"]] = dice + ppv_scores[test_files[i]["image"]] = ppv + sensitivity_scores[test_files[i]["image"]] = sensitivity + f1_scores[test_files[i]["image"]] = f1 test_input.detach() @@ -182,6 +191,15 @@ def main(): with open(os.path.join(output_dir, "dice_scores.txt"), "w") as f: for key, value in dice_scores.items(): f.write(f"{key}: {value}\n") + with open(os.path.join(output_dir, "ppv_scores.txt"), "w") as f: + for key, value in ppv_scores.items(): + f.write(f"{key}: {value}\n") + with open(os.path.join(output_dir, "sensitivity_scores.txt"), "w") as f: + for key, value in sensitivity_scores.items(): + f.write(f"{key}: {value}\n") + with open(os.path.join(output_dir, "f1_scores.txt"), "w") as f: + for key, value in f1_scores.items(): + f.write(f"{key}: {value}\n") if __name__ == "__main__": diff --git a/monai/utils/utils.py b/monai/utils/utils.py index 0b7a841..c462e70 100644 --- a/monai/utils/utils.py +++ b/monai/utils/utils.py @@ -122,34 +122,41 @@ def lesion_wise_tp_fp_fn(truth, prediction): return tp, fp, fn -def lesion_sensitivity(truth, prediction): +def lesion_f1_score(truth, prediction): """ - Computes the lesion-wise sensitivity between two masks + Computes the lesion-wise F1-score between two masks by defining true positive lesions (tp), false positive lesions (fp) + and false negative lesions (fn) using 3D connected-component-analysis. + + Masks are considered true positives if at least one voxel overlaps between the truth and the prediction. + Returns ------- - sensitivity (float): Lesion-wise sensitivity as float. + f1_score : float + Lesion-wise F1-score as float. Max score = 1 Min score = 0 If both images are empty (tp + fp + fn =0) = empty_value """ empty_value = 1.0 # Value to which to default if there are no labels. Default: 1.0. - if np.sum(truth) == 0 and np.sum(prediction)==0: + if not np.any(truth) and not np.any(prediction): # Both reference and prediction are empty --> model learned correctly return 1.0 - # if the prediction is not empty and ref is empty, it's false positive + elif np.any(truth) and not np.any(prediction): + # Reference is not empty, prediction is empty --> model did not learn correctly (it's false negative) + return 0.0 + # if the predction is not empty and ref is empty, it's false positive # if both are not empty, it's true positive else: + tp, fp, fn = lesion_wise_tp_fp_fn(truth, prediction) + f1_score = empty_value - tp, _, fn = lesion_wise_tp_fp_fn(truth, prediction) - sensitivity = empty_value - - # Compute sensitivity - denom = tp + fn + # Compute f1_score + denom = tp + (fp + fn)/2 if(denom != 0): - sensitivity = tp / denom - return sensitivity - + f1_score = tp / denom + return f1_score + def lesion_ppv(truth, prediction): """ @@ -161,10 +168,10 @@ def lesion_ppv(truth, prediction): Min score = 0 If both images are empty (tp + fp + fn =0) = empty_value """ - if np.sum(truth) == 0 and np.sum(prediction)==0: + if not np.any(truth) and not np.any(prediction): # Both reference and prediction are empty --> model learned correctly return 1.0 - elif np.sum(truth) != 0 and np.sum(prediction)==0: + elif np.any(truth) and not np.any(prediction): # Reference is not empty, prediction is empty --> model did not learn correctly (it's false negative) return 0.0 # if the predction is not empty and ref is empty, it's false positive @@ -179,42 +186,35 @@ def lesion_ppv(truth, prediction): if(denom != 0): ppv = tp / denom return ppv - -def lesion_f1_score(truth, prediction): - """ - Computes the lesion-wise F1-score between two masks by defining true positive lesions (tp), false positive lesions (fp) - and false negative lesions (fn) using 3D connected-component-analysis. - - Masks are considered true positives if at least one voxel overlaps between the truth and the prediction. +def lesion_sensitivity(truth, prediction): + """ + Computes the lesion-wise sensitivity between two masks Returns ------- - f1_score : float - Lesion-wise F1-score as float. + sensitivity (float): Lesion-wise sensitivity as float. Max score = 1 Min score = 0 If both images are empty (tp + fp + fn =0) = empty_value """ empty_value = 1.0 # Value to which to default if there are no labels. Default: 1.0. - if np.sum(truth) == 0 and np.sum(prediction)==0: + if not np.any(truth) and not np.any(prediction): # Both reference and prediction are empty --> model learned correctly return 1.0 - elif np.sum(truth) != 0 and np.sum(prediction)==0: - # Reference is not empty, prediction is empty --> model did not learn correctly (it's false negative) - return 0.0 # if the predction is not empty and ref is empty, it's false positive # if both are not empty, it's true positive else: - tp, fp, fn = lesion_wise_tp_fp_fn(truth, prediction) - f1_score = empty_value - # Compute f1_score - denom = tp + (fp + fn)/2 + tp, _, fn = lesion_wise_tp_fp_fn(truth, prediction) + sensitivity = empty_value + + # Compute sensitivity + denom = tp + fn if(denom != 0): - f1_score = tp / denom - return f1_score + sensitivity = tp / denom + return sensitivity def remove_small_lesions(lesion_seg, resolution, min_volume=7.5): From 953de3f8fe4a02a7507ff641e19fa8aa99db6b2c Mon Sep 17 00:00:00 2001 From: Pierre-Louis Benveniste Date: Thu, 31 Oct 2024 13:50:29 -0400 Subject: [PATCH 107/108] fixed utils command --- monai/config_test.yml | 9 +++++++-- monai/utils/utils.py | 2 +- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/monai/config_test.yml b/monai/config_test.yml index 21ed127..4fc3c59 100644 --- a/monai/config_test.yml +++ b/monai/config_test.yml @@ -1,6 +1,9 @@ # dataset : /home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/data/monai_data/dataset_2024-04-17_seed42_lesionOnly.json # dataset : /home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/data/monai_data/dataset_2024-04-05_seed42_lesionOnly.json dataset : /home/plbenveniste/net/ms-lesion-agnostic/msd_data/dataset_2024-06-26_seed42_lesionOnly.json +# dataset : /home/plbenveniste/net/ms-lesion-agnostic/msd_data/dataset_2024-08-13_seed42_lesionOnly.json +# dataset : /home/plbenveniste/net/ms-lesion-agnostic/msd_data/dataset_optThresh.json +# dataset : /home/plbenveniste/net/ms-lesion-agnostic/msd_data/fake.json pixdim : [0.7, 0.7, 0.7] spatial_size : [64, 128, 128] @@ -9,8 +12,10 @@ attention_unet_strides : [2, 2, 2, 2, 2] # path_to_model : /home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/results/2024-04-21_16:06:04.890513/best_model.pth/best_model.ckpt # path_to_model : /home/plbenveniste/net/ms-lesion-agnostic/tta_exp/best_model.pth/best_model.ckpt -path_to_model : /home/plbenveniste/net/ms-lesion-agnostic/thresholding_optimisation/best_model.pth/best_model.ckpt +path_to_model : /home/plbenveniste/net/ms-lesion-agnostic/results/2024-07-18_10:46:21.634514/best_model.pth/best_model.ckpt +# path_to_model : /home/plbenveniste/net/ms-lesion-agnostic/results/2024-09-02_12:14:28.124188/best_model.pth/best_model.ckpt # output_dir : /home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/results/2024-04-21_16:06:04.890513/ # output_dir : /home/plbenveniste/net/ms-lesion-agnostic/tta_exp -output_dir : /home/plbenveniste/net/ms-lesion-agnostic/thresholding_optimisation +output_dir : /home/plbenveniste/net/ms-lesion-agnostic/results/2024-07-18_10:46:21.634514/ +# output_dir : /home/plbenveniste/net/ms-lesion-agnostic/results/2024-09-02_12:14:28.124188/ \ No newline at end of file diff --git a/monai/utils/utils.py b/monai/utils/utils.py index c462e70..1e116ae 100644 --- a/monai/utils/utils.py +++ b/monai/utils/utils.py @@ -178,7 +178,7 @@ def lesion_ppv(truth, prediction): # if both are not empty, it's true positive else: tp, fp, _ = lesion_wise_tp_fp_fn(truth, prediction) - # ppv = 1.0 + ppv = 1.0 # Compute ppv denom = tp + fp From 15d13b45b33eb3e355c04ad5012e9e7c3b97375e Mon Sep 17 00:00:00 2001 From: Pierre-Louis Benveniste Date: Tue, 17 Dec 2024 10:31:11 -0500 Subject: [PATCH 108/108] removed dataset aggregation scripts --- monai/1_create_msd_data.py | 311 ------------------- monai/1_create_msd_data_head_cropped.py | 395 ------------------------ 2 files changed, 706 deletions(-) delete mode 100644 monai/1_create_msd_data.py delete mode 100644 monai/1_create_msd_data_head_cropped.py diff --git a/monai/1_create_msd_data.py b/monai/1_create_msd_data.py deleted file mode 100644 index ddce188..0000000 --- a/monai/1_create_msd_data.py +++ /dev/null @@ -1,311 +0,0 @@ -""" -This file creates the MSD-style JSON datalist to train an nnunet model using monai. -The datasets used are CanProCo, Bavaria-quebec, basel and sct-testing-large. - -Arguments: - -pd, --path-data: Path to the data set directory - -po, --path-out: Path to the output directory where dataset json is saved - --lesion-only: Use only masks which contain some lesions - --seed: Seed for reproducibility - --canproco-exclude: Path to the file containing the list of subjects to exclude from CanProCo - -Example: - python 1_create_msd_data.py -pd /path/dataset -po /path/output --lesion-only --seed 42 --canproco-exclude /path/exclude_list.txt - -TO DO: - * - -Pierre-Louis Benveniste -""" - -import os -import json -from tqdm import tqdm -import yaml -import argparse -from loguru import logger -from sklearn.model_selection import train_test_split -from datetime import date -from pathlib import Path -import nibabel as nib -import numpy as np -import skimage -from utils.image import Image - - -def get_parser(): - """ - Get parser for script create_msd_data.py - - Input: - None - - Returns: - parser : argparse object - """ - - parser = argparse.ArgumentParser(description='Code for MSD-style JSON datalist for lesion-agnostic nnunet model training.') - - parser.add_argument('-pd', '--path-data', required=True, type=str, help='Path to the folder containing the datasets') - parser.add_argument('-po', '--path-out', type=str, help='Path to the output directory where dataset json is saved') - parser.add_argument('--canproco-exclude', type=str, help='Path to the file containing the list of subjects to exclude from CanProCo') - parser.add_argument('--lesion-only', action='store_true', help='Use only masks which contain some lesions') - parser.add_argument('--seed', default=42, type=int, help="Seed for reproducibility") - - return parser - - -def count_lesion(label_file): - """ - This function takes a label file and counts the number of lesions in it. - - Input: - label_file : str : Path to the label file - - Returns: - count : int : Number of lesions in the label file - total_volume : float : Total volume of lesions in the label file - """ - - label = nib.load(label_file) - label_data = label.get_fdata() - - # get the total volume of the lesions - total_volume = np.sum(label_data) - resolution = label.header.get_zooms() - total_volume = total_volume * np.prod(resolution) - - # get the number of lesions - _, nb_lesions = skimage.measure.label(label_data, connectivity=2, return_num=True) - - return total_volume, nb_lesions - - -def get_orientation(image_path): - """ - This function takes an image file as input and returns its orientation. - - Input: - image_path : str : Path to the image file - - Returns: - orientation : str : Orientation of the image - """ - img = Image(str(image_path)) - img.change_orientation('RPI') - # Get pixdim - pixdim = img.dim[4:7] - # If all are the same, the image is isotropic - if np.allclose(pixdim, pixdim[0], atol=1e-3): - orientation = 'iso' - return orientation - # Elif, the lowest arg is 0 then the orientation is sagittal - elif np.argmax(pixdim) == 0: - orientation = 'sag' - # Elif, the lowest arg is 1 then the orientation is coronal - elif np.argmax(pixdim) == 1: - orientation = 'cor' - # Else the orientation is axial - else: - orientation = 'ax' - return orientation - - -def main(): - """ - This is the main function of the script. - - Input: - None - - Returns: - None - """ - # Get the arguments - parser = get_parser() - args = parser.parse_args() - - root = args.path_data - seed = args.seed - - # Get all subjects - basel_path = Path(os.path.join(root, "basel-mp2rage")) - bavaria_path = Path(os.path.join(root, "bavaria-quebec-spine-ms-unstitched")) - canproco_path = Path(os.path.join(root, "canproco")) - nih_path = Path(os.path.join(root, "nih-ms-mp2rage")) - sct_testing_path = Path(os.path.join(root, "sct-testing-large")) - - derivatives_basel = list(basel_path.rglob('*_desc-rater3_label-lesion_seg.nii.gz')) - derivatives_bavaria = list(bavaria_path.rglob('*_lesion-manual.nii.gz')) - derivatives_canproco = list(canproco_path.rglob('*_lesion-manual.nii.gz')) - derivatives_nih = list(nih_path.rglob('*_desc-rater1_label-lesion_seg.nii.gz')) - derivatives_sct = list(sct_testing_path.rglob('*_lesion-manual.nii.gz')) - - # Path to the file containing the list of subjects to exclude from CanProCo - if args.canproco_exclude is not None: - with open(args.canproco_exclude, 'r') as file: - canproco_exclude_list = yaml.load(file, Loader=yaml.FullLoader) - # only keep the contrast psir and stir - canproco_exclude_list = canproco_exclude_list['PSIR'] + canproco_exclude_list['STIR'] - - derivatives = derivatives_basel + derivatives_bavaria + derivatives_canproco + derivatives_nih + derivatives_sct - logger.info(f"Total number of derivatives in the root directory: {len(derivatives)}") - - # create one json file with 60-20-20 train-val-test split - train_ratio, val_ratio, test_ratio = 0.6, 0.2, 0.2 - train_derivatives, test_derivatives = train_test_split(derivatives, test_size=test_ratio, random_state=args.seed) - # Use the training split to further split into training and validation splits - train_derivatives, val_derivatives = train_test_split(train_derivatives, test_size=val_ratio / (train_ratio + val_ratio), - random_state=args.seed, ) - # sort the subjects - train_derivatives = sorted(train_derivatives) - val_derivatives = sorted(val_derivatives) - test_derivatives = sorted(test_derivatives) - - # logger.info(f"Number of training subjects: {len(train_subjects)}") - # logger.info(f"Number of validation subjects: {len(val_subjects)}") - # logger.info(f"Number of testing subjects: {len(test_subjects)}") - - # dump train/val/test splits into a yaml file - with open(f"{args.path_out}/data_split_{str(date.today())}_seed{seed}.yaml", 'w') as file: - yaml.dump({'train': train_derivatives, 'val': val_derivatives, 'test': test_derivatives}, file, indent=2, sort_keys=True) - - # keys to be defined in the dataset_0.json - params = {} - params["description"] = "ms-lesion-agnostic" - params["labels"] = { - "0": "background", - "1": "ms-lesion-seg" - } - params["license"] = "plb" - params["modality"] = { - "0": "MRI" - } - params["name"] = "ms-lesion-agnostic" - params["seed"] = args.seed - params["reference"] = "NeuroPoly" - params["tensorImageSize"] = "3D" - - train_derivatives_dict = {"train": train_derivatives} - val_derivatives_dict = {"validation": val_derivatives} - test_derivatives_dict = {"test": test_derivatives} - all_derivatives_list = [train_derivatives_dict, val_derivatives_dict, test_derivatives_dict] - - # iterate through the train/val/test splits and add those which have both image and label - for derivatives_dict in tqdm(all_derivatives_list, desc="Iterating through train/val/test splits"): - - for name, derivs_list in derivatives_dict.items(): - - temp_list = [] - for subject_no, derivative in enumerate(derivs_list): - - - temp_data_basel = {} - temp_data_bavaria = {} - temp_data_canproco = {} - temp_data_nih = {} - temp_data_sct = {} - - # Basel - if 'basel-mp2rage' in str(derivative): - relative_path = derivative.relative_to(basel_path).parent - temp_data_basel["label"] = str(derivative) - temp_data_basel["image"] = str(derivative).replace('_desc-rater3_label-lesion_seg.nii.gz', '.nii.gz').replace('derivatives/labels/', '') - if os.path.exists(temp_data_basel["label"]) and os.path.exists(temp_data_basel["image"]): - total_lesion_volume, nb_lesions = count_lesion(temp_data_basel["label"]) - temp_data_basel["total_lesion_volume"] = total_lesion_volume - temp_data_basel["nb_lesions"] = nb_lesions - temp_data_basel["site"]='basel' - temp_data_basel["contrast"] = str(derivative).replace('_desc-rater3_label-lesion_seg.nii.gz', '.nii.gz').split('_')[-1].replace('.nii.gz', '') - temp_data_basel["orientation"] = get_orientation(temp_data_basel["image"]) - if args.lesion_only and nb_lesions == 0: - continue - temp_list.append(temp_data_basel) - - # Bavaria-quebec - elif 'bavaria-quebec-spine-ms' in str(derivative): - temp_data_bavaria["label"] = str(derivative) - temp_data_bavaria["image"] = str(derivative).replace('_lesion-manual.nii.gz', '.nii.gz').replace('derivatives/labels/', '') - if os.path.exists(temp_data_bavaria["label"]) and os.path.exists(temp_data_bavaria["image"]): - total_lesion_volume, nb_lesions = count_lesion(temp_data_bavaria["label"]) - temp_data_bavaria["total_lesion_volume"] = total_lesion_volume - temp_data_bavaria["nb_lesions"] = nb_lesions - temp_data_bavaria["site"]='bavaria-quebec' - temp_data_bavaria["contrast"] = str(derivative).replace('_lesion-manual.nii.gz', '.nii.gz').split('_')[-1].replace('.nii.gz', '') - temp_data_bavaria["orientation"] = get_orientation(temp_data_bavaria["image"]) - if args.lesion_only and nb_lesions == 0: - continue - temp_list.append(temp_data_bavaria) - - # Canproco - elif 'canproco' in str(derivative): - subject_id = derivative.name.replace('_PSIR_lesion-manual.nii.gz', '') - subject_id = subject_id.replace('_STIR_lesion-manual.nii.gz', '') - if subject_id in canproco_exclude_list: - continue - temp_data_canproco["label"] = str(derivative) - temp_data_canproco["image"] = str(derivative).replace('_lesion-manual.nii.gz', '.nii.gz').replace('derivatives/labels/', '') - if os.path.exists(temp_data_canproco["label"]) and os.path.exists(temp_data_canproco["image"]): - total_lesion_volume, nb_lesions = count_lesion(temp_data_canproco["label"]) - temp_data_canproco["total_lesion_volume"] = total_lesion_volume - temp_data_canproco["nb_lesions"] = nb_lesions - temp_data_canproco["site"]='canproco' - temp_data_canproco["contrast"] = str(derivative).replace('_lesion-manual.nii.gz', '.nii.gz').split('_')[-1].replace('.nii.gz', '') - temp_data_canproco["orientation"] = get_orientation(temp_data_canproco["image"]) - if args.lesion_only and nb_lesions == 0: - continue - temp_list.append(temp_data_canproco) - - # nih-ms-mp2rage - elif 'nih-ms-mp2rage' in str(derivative): - temp_data_nih["label"] = str(derivative) - temp_data_nih["image"] = str(derivative).replace('_desc-rater1_label-lesion_seg.nii.gz', '.nii.gz').replace('derivatives/labels/', '') - if os.path.exists(temp_data_nih["label"]) and os.path.exists(temp_data_nih["image"]): - total_lesion_volume, nb_lesions = count_lesion(temp_data_nih["label"]) - temp_data_nih["total_lesion_volume"] = total_lesion_volume - temp_data_nih["nb_lesions"] = nb_lesions - temp_data_nih["site"]='nih' - temp_data_nih["contrast"] = str(derivative).replace('_desc-rater1_label-lesion_seg.nii.gz', '.nii.gz').split('_')[-1].replace('.nii.gz', '') - temp_data_nih["orientation"] = get_orientation(temp_data_nih["image"]) - if args.lesion_only and nb_lesions == 0: - continue - temp_list.append(temp_data_nih) - - # sct-testing-large - elif 'sct-testing-large' in str(derivative): - temp_data_sct["label"] = str(derivative) - temp_data_sct["image"] = str(derivative).replace('_lesion-manual.nii.gz', '.nii.gz').replace('derivatives/labels/', '') - if os.path.exists(temp_data_sct["label"]) and os.path.exists(temp_data_sct["image"]): - total_lesion_volume, nb_lesions = count_lesion(temp_data_sct["label"]) - temp_data_sct["total_lesion_volume"] = total_lesion_volume - temp_data_sct["nb_lesions"] = nb_lesions - temp_data_sct["site"]='sct-testing-large' - temp_data_sct["contrast"] = str(derivative).replace('_lesion-manual.nii.gz', '.nii.gz').split('_')[-1].replace('.nii.gz', '') - temp_data_sct["orientation"] = get_orientation(temp_data_sct["image"]) - if args.lesion_only and nb_lesions == 0: - continue - temp_list.append(temp_data_sct) - - params[name] = temp_list - logger.info(f"Number of images in {name} set: {len(temp_list)}") - params["numTest"] = len(params["test"]) - params["numTraining"] = len(params["train"]) - params["numValidation"] = len(params["validation"]) - # Print total number of images - logger.info(f"Total number of images in the dataset: {params['numTest'] + params['numTraining'] + params['numValidation']}") - - final_json = json.dumps(params, indent=4, sort_keys=True) - if not os.path.exists(args.path_out): - os.makedirs(args.path_out, exist_ok=True) - if args.lesion_only: - jsonFile = open(args.path_out + "/" + f"dataset_{str(date.today())}_seed{seed}_lesionOnly.json", "w") - else: - jsonFile = open(args.path_out + "/" + f"dataset_{str(date.today())}_seed{seed}.json", "w") - jsonFile.write(final_json) - jsonFile.close() - - return None - - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/monai/1_create_msd_data_head_cropped.py b/monai/1_create_msd_data_head_cropped.py deleted file mode 100644 index cf28f9e..0000000 --- a/monai/1_create_msd_data_head_cropped.py +++ /dev/null @@ -1,395 +0,0 @@ -""" -This file creates the MSD-style JSON datalist to train an nnunet model using monai. -The datasets used are CanProCo, Bavaria-quebec, basel and sct-testing-large. - -Arguments: - -pd, --path-data: Path to the data set directory - -po, --path-out: Path to the output directory where dataset json is saved - --lesion-only: Use only masks which contain some lesions - --seed: Seed for reproducibility - --canproco-exclude: Path to the file containing the list of subjects to exclude from CanProCo - -Example: - python 1_create_msd_data.py -pd /path/dataset -po /path/output --lesion-only --seed 42 --canproco-exclude /path/exclude_list.txt - -TO DO: - * - -Pierre-Louis Benveniste -""" - -import os -import json -from tqdm import tqdm -import yaml -import argparse -from loguru import logger -from sklearn.model_selection import train_test_split -from datetime import date -from pathlib import Path -import nibabel as nib -import numpy as np -import skimage -from utils.image import Image - - -def get_parser(): - """ - Get parser for script create_msd_data.py - - Input: - None - - Returns: - parser : argparse object - """ - - parser = argparse.ArgumentParser(description='Code for MSD-style JSON datalist for lesion-agnostic nnunet model training.') - - parser.add_argument('-pd', '--path-data', required=True, type=str, help='Path to the folder containing the datasets') - parser.add_argument('-po', '--path-out', type=str, help='Path to the output directory where dataset json is saved') - parser.add_argument('--canproco-exclude', type=str, help='Path to the file containing the list of subjects to exclude from CanProCo') - parser.add_argument('--lesion-only', action='store_true', help='Use only masks which contain some lesions') - parser.add_argument('--seed', default=42, type=int, help="Seed for reproducibility") - - return parser - - -def count_lesion(label_file): - """ - This function takes a label file and counts the number of lesions in it. - - Input: - label_file : str : Path to the label file - - Returns: - count : int : Number of lesions in the label file - total_volume : float : Total volume of lesions in the label file - """ - - label = nib.load(label_file) - label_data = label.get_fdata() - - # get the total volume of the lesions - total_volume = np.sum(label_data) - resolution = label.header.get_zooms() - total_volume = total_volume * np.prod(resolution) - - # get the number of lesions - _, nb_lesions = skimage.measure.label(label_data, connectivity=2, return_num=True) - - return total_volume, nb_lesions - - -def get_orientation(image_path): - """ - This function takes an image file as input and returns its orientation. - - Input: - image_path : str : Path to the image file - - Returns: - orientation : str : Orientation of the image - """ - img = Image(str(image_path)) - img.change_orientation('RPI') - # Get pixdim - pixdim = img.dim[4:7] - # If all are the same, the image is isotropic - if np.allclose(pixdim, pixdim[0], atol=1e-3): - orientation = 'iso' - return orientation - # Elif, the lowest arg is 0 then the orientation is sagittal - elif np.argmax(pixdim) == 0: - orientation = 'sag' - # Elif, the lowest arg is 1 then the orientation is coronal - elif np.argmax(pixdim) == 1: - orientation = 'cor' - # Else the orientation is axial - else: - orientation = 'ax' - return orientation - - -def cropping_saving(image_path, label_path, cropped_head_data_folder): - """ - This function does the following action successively: - - copy image and label to the output folder for cropped head data - - segments the spinal cord on the image - - crops the image and label to the remove the superior part of the head (what is above the seg of the spinal cord) - - save the cropped image and label in the output folder - - Input: - image_path : str : Path to the image file - label_path : str : Path to the label file - cropped_head_data_folder : str : Path to the output folder - - Returns: - image_cropped : str : Path to the cropped image - seg_cropped : str : Path to the cropped label - """ - - # Copy image and label to the output folder for cropped head data - image_cropped = os.path.join(cropped_head_data_folder, image_path.split('/')[-1]) - seg_cropped = os.path.join(cropped_head_data_folder, label_path.split('/')[-1]) - img = Image(image_path) - img.change_orientation('RPI') - img.save(image_cropped) - seg = Image(label_path) - seg.change_orientation('RPI') - seg.save(seg_cropped) - - # Segment the spinal cord on the image - ## Create a temporary folder - temp_folder = os.path.join(cropped_head_data_folder, "temp") - os.makedirs(temp_folder, exist_ok=True) - ## Segment the spinal cord - os.system(f"sct_deepseg -i {image_cropped} -o {os.path.join(temp_folder, 'seg.nii.gz')} -task seg_sc_contrast_agnostic -thr 0.5") - ## Get the highest point of the spinal cord - spinal_cord_seg = Image(os.path.join(temp_folder, 'seg.nii.gz')) - spinal_cord_seg.change_orientation('RPI') - spinal_cord_seg_data = spinal_cord_seg.data - spinal_cord_superior = np.max(np.where(spinal_cord_seg_data == 1)[2]) - ## Remove the temporary folder - os.system(f"rm -rf {temp_folder}") - - # Crop the image and label to the remove the superior part of the head (what is above the seg of the spinal cord) - os.system(f"sct_crop_image -i {image_cropped} -o {image_cropped} -zmax {spinal_cord_superior}") - os.system(f"sct_crop_image -i {seg_cropped} -o {seg_cropped} -zmax {spinal_cord_superior}") - - return image_cropped, seg_cropped - - -def main(): - """ - This is the main function of the script. - - Input: - None - - Returns: - None - """ - # Get the arguments - parser = get_parser() - args = parser.parse_args() - - root = args.path_data - seed = args.seed - - # Get all subjects - basel_path = Path(os.path.join(root, "basel-mp2rage")) - bavaria_path = Path(os.path.join(root, "bavaria-quebec-spine-ms-unstitched")) - canproco_path = Path(os.path.join(root, "canproco")) - nih_path = Path(os.path.join(root, "nih-ms-mp2rage")) - sct_testing_path = Path(os.path.join(root, "sct-testing-large")) - - derivatives_basel = list(basel_path.rglob('*_desc-rater3_label-lesion_seg.nii.gz')) - derivatives_bavaria = list(bavaria_path.rglob('*_lesion-manual.nii.gz')) - derivatives_canproco = list(canproco_path.rglob('*_lesion-manual.nii.gz')) - derivatives_nih = list(nih_path.rglob('*_desc-rater1_label-lesion_seg.nii.gz')) - derivatives_sct = list(sct_testing_path.rglob('*_lesion-manual.nii.gz')) - - # Make the folder for the cropped images - cropped_head_data_folder = os.path.join(args.path_out, "cropped_head_data") - os.makedirs(args.path_out, exist_ok=True) - os.makedirs(cropped_head_data_folder, exist_ok=True) - os.makedirs(os.path.join(cropped_head_data_folder, "basel-mp2rage"), exist_ok=True) - os.makedirs(os.path.join(cropped_head_data_folder, "bavaria-quebec-spine-ms-unstitched"), exist_ok=True) - os.makedirs(os.path.join(cropped_head_data_folder, "canproco"), exist_ok=True) - os.makedirs(os.path.join(cropped_head_data_folder, "nih-ms-mp2rage"), exist_ok=True) - os.makedirs(os.path.join(cropped_head_data_folder, "sct-testing-large"), exist_ok=True) - - # Path to the file containing the list of subjects to exclude from CanProCo - if args.canproco_exclude is not None: - with open(args.canproco_exclude, 'r') as file: - canproco_exclude_list = yaml.load(file, Loader=yaml.FullLoader) - # only keep the contrast psir and stir - canproco_exclude_list = canproco_exclude_list['PSIR'] + canproco_exclude_list['STIR'] - - derivatives = derivatives_basel + derivatives_bavaria + derivatives_canproco + derivatives_nih + derivatives_sct - logger.info(f"Total number of derivatives in the root directory: {len(derivatives)}") - - # create one json file with 60-20-20 train-val-test split - train_ratio, val_ratio, test_ratio = 0.6, 0.2, 0.2 - train_derivatives, test_derivatives = train_test_split(derivatives, test_size=test_ratio, random_state=args.seed) - # Use the training split to further split into training and validation splits - train_derivatives, val_derivatives = train_test_split(train_derivatives, test_size=val_ratio / (train_ratio + val_ratio), - random_state=args.seed, ) - # sort the subjects - train_derivatives = sorted(train_derivatives) - val_derivatives = sorted(val_derivatives) - test_derivatives = sorted(test_derivatives) - - # logger.info(f"Number of training subjects: {len(train_subjects)}") - # logger.info(f"Number of validation subjects: {len(val_subjects)}") - # logger.info(f"Number of testing subjects: {len(test_subjects)}") - - # keys to be defined in the dataset_0.json - params = {} - params["description"] = "ms-lesion-agnostic" - params["labels"] = { - "0": "background", - "1": "ms-lesion-seg" - } - params["license"] = "plb" - params["modality"] = { - "0": "MRI" - } - params["name"] = "ms-lesion-agnostic" - params["seed"] = args.seed - params["reference"] = "NeuroPoly" - params["tensorImageSize"] = "3D" - - train_derivatives_dict = {"train": train_derivatives} - val_derivatives_dict = {"validation": val_derivatives} - test_derivatives_dict = {"test": test_derivatives} - all_derivatives_list = [train_derivatives_dict, val_derivatives_dict, test_derivatives_dict] - - # iterate through the train/val/test splits and add those which have both image and label - for derivatives_dict in tqdm(all_derivatives_list, desc="Iterating through train/val/test splits"): - - for name, derivs_list in derivatives_dict.items(): - - temp_list = [] - for subject_no, derivative in enumerate(derivs_list): - - - temp_data_basel = {} - temp_data_bavaria = {} - temp_data_canproco = {} - temp_data_nih = {} - temp_data_sct = {} - - # Basel - if 'basel-mp2rage' in str(derivative): - relative_path = derivative.relative_to(basel_path).parent - temp_data_basel["label"] = str(derivative) - temp_data_basel["image"] = str(derivative).replace('_desc-rater3_label-lesion_seg.nii.gz', '.nii.gz').replace('derivatives/labels/', '') - if os.path.exists(temp_data_basel["label"]) and os.path.exists(temp_data_basel["image"]): - # Cropping image and seg and saving to the cropped_head_data folder - image, seg = cropping_saving(temp_data_basel["image"], temp_data_basel["label"], os.path.join(cropped_head_data_folder, "basel-mp2rage")) - temp_data_basel["label"] = seg - temp_data_basel["image"] = image - - total_lesion_volume, nb_lesions = count_lesion(temp_data_basel["label"]) - temp_data_basel["total_lesion_volume"] = total_lesion_volume - temp_data_basel["nb_lesions"] = nb_lesions - temp_data_basel["site"]='basel' - temp_data_basel["contrast"] = str(derivative).replace('_desc-rater3_label-lesion_seg.nii.gz', '.nii.gz').split('_')[-1].replace('.nii.gz', '') - temp_data_basel["orientation"] = get_orientation(temp_data_basel["image"]) - if args.lesion_only and nb_lesions == 0: - continue - temp_list.append(temp_data_basel) - - # Bavaria-quebec - elif 'bavaria-quebec-spine-ms' in str(derivative): - temp_data_bavaria["label"] = str(derivative) - temp_data_bavaria["image"] = str(derivative).replace('_lesion-manual.nii.gz', '.nii.gz').replace('derivatives/labels/', '') - if os.path.exists(temp_data_bavaria["label"]) and os.path.exists(temp_data_bavaria["image"]): - # Cropping image and seg and saving to the cropped_head_data folder - image, seg = cropping_saving(temp_data_bavaria["image"], temp_data_bavaria["label"], os.path.join(cropped_head_data_folder, "bavaria-quebec-spine-ms-unstitched")) - temp_data_bavaria["label"] = seg - temp_data_bavaria["image"] = image - - total_lesion_volume, nb_lesions = count_lesion(temp_data_bavaria["label"]) - temp_data_bavaria["total_lesion_volume"] = total_lesion_volume - temp_data_bavaria["nb_lesions"] = nb_lesions - temp_data_bavaria["site"]='bavaria-quebec' - temp_data_bavaria["contrast"] = str(derivative).replace('_lesion-manual.nii.gz', '.nii.gz').split('_')[-1].replace('.nii.gz', '') - temp_data_bavaria["orientation"] = get_orientation(temp_data_bavaria["image"]) - if args.lesion_only and nb_lesions == 0: - continue - temp_list.append(temp_data_bavaria) - - # Canproco - elif 'canproco' in str(derivative): - subject_id = derivative.name.replace('_PSIR_lesion-manual.nii.gz', '') - subject_id = subject_id.replace('_STIR_lesion-manual.nii.gz', '') - if subject_id in canproco_exclude_list: - continue - temp_data_canproco["label"] = str(derivative) - temp_data_canproco["image"] = str(derivative).replace('_lesion-manual.nii.gz', '.nii.gz').replace('derivatives/labels/', '') - if os.path.exists(temp_data_canproco["label"]) and os.path.exists(temp_data_canproco["image"]): - # Cropping image and seg and saving to the cropped_head_data folder - image, seg = cropping_saving(temp_data_canproco["image"], temp_data_canproco["label"], os.path.join(cropped_head_data_folder, "canproco")) - temp_data_canproco["label"] = seg - temp_data_canproco["image"] = image - - total_lesion_volume, nb_lesions = count_lesion(temp_data_canproco["label"]) - temp_data_canproco["total_lesion_volume"] = total_lesion_volume - temp_data_canproco["nb_lesions"] = nb_lesions - temp_data_canproco["site"]='canproco' - temp_data_canproco["contrast"] = str(derivative).replace('_lesion-manual.nii.gz', '.nii.gz').split('_')[-1].replace('.nii.gz', '') - temp_data_canproco["orientation"] = get_orientation(temp_data_canproco["image"]) - if args.lesion_only and nb_lesions == 0: - continue - temp_list.append(temp_data_canproco) - - # nih-ms-mp2rage - elif 'nih-ms-mp2rage' in str(derivative): - temp_data_nih["label"] = str(derivative) - temp_data_nih["image"] = str(derivative).replace('_desc-rater1_label-lesion_seg.nii.gz', '.nii.gz').replace('derivatives/labels/', '') - if os.path.exists(temp_data_nih["label"]) and os.path.exists(temp_data_nih["image"]): - # Cropping image and seg and saving to the cropped_head_data folder - image, seg = cropping_saving(temp_data_nih["image"], temp_data_nih["label"], os.path.join(cropped_head_data_folder, "nih-ms-mp2rage")) - temp_data_nih["label"] = seg - temp_data_nih["image"] = image - - total_lesion_volume, nb_lesions = count_lesion(temp_data_nih["label"]) - temp_data_nih["total_lesion_volume"] = total_lesion_volume - temp_data_nih["nb_lesions"] = nb_lesions - temp_data_nih["site"]='nih' - temp_data_nih["contrast"] = str(derivative).replace('_desc-rater1_label-lesion_seg.nii.gz', '.nii.gz').split('_')[-1].replace('.nii.gz', '') - temp_data_nih["orientation"] = get_orientation(temp_data_nih["image"]) - if args.lesion_only and nb_lesions == 0: - continue - temp_list.append(temp_data_nih) - - # sct-testing-large - elif 'sct-testing-large' in str(derivative): - temp_data_sct["label"] = str(derivative) - temp_data_sct["image"] = str(derivative).replace('_lesion-manual.nii.gz', '.nii.gz').replace('derivatives/labels/', '') - if os.path.exists(temp_data_sct["label"]) and os.path.exists(temp_data_sct["image"]): - # Cropping image and seg and saving to the cropped_head_data folder - image, seg = cropping_saving(temp_data_sct["image"], temp_data_sct["label"], os.path.join(cropped_head_data_folder, "sct-testing-large")) - temp_data_sct["label"] = seg - temp_data_sct["image"] = image - - total_lesion_volume, nb_lesions = count_lesion(temp_data_sct["label"]) - temp_data_sct["total_lesion_volume"] = total_lesion_volume - temp_data_sct["nb_lesions"] = nb_lesions - temp_data_sct["site"]='sct-testing-large' - temp_data_sct["contrast"] = str(derivative).replace('_lesion-manual.nii.gz', '.nii.gz').split('_')[-1].replace('.nii.gz', '') - temp_data_sct["orientation"] = get_orientation(temp_data_sct["image"]) - if args.lesion_only and nb_lesions == 0: - continue - temp_list.append(temp_data_sct) - - params[name] = temp_list - logger.info(f"Number of images in {name} set: {len(temp_list)}") - params["numTest"] = len(params["test"]) - params["numTraining"] = len(params["train"]) - params["numValidation"] = len(params["validation"]) - # Print total number of images - logger.info(f"Total number of images in the dataset: {params['numTest'] + params['numTraining'] + params['numValidation']}") - - final_json = json.dumps(params, indent=4, sort_keys=True) - if not os.path.exists(args.path_out): - os.makedirs(args.path_out, exist_ok=True) - if args.lesion_only: - jsonFile = open(args.path_out + "/" + f"dataset_{str(date.today())}_seed{seed}_lesionOnly.json", "w") - else: - jsonFile = open(args.path_out + "/" + f"dataset_{str(date.today())}_seed{seed}.json", "w") - jsonFile.write(final_json) - jsonFile.close() - - # dump train/val/test splits into a yaml file - with open(f"{args.path_out}/data_split_{str(date.today())}_seed{seed}.yaml", 'w') as file: - yaml.dump({'train': train_derivatives, 'val': val_derivatives, 'test': test_derivatives}, file, indent=2, sort_keys=True) - - return None - - -if __name__ == "__main__": - main() \ No newline at end of file