diff --git a/monai/README.md b/monai/README.md
new file mode 100644
index 00000000..c73c3fc7
--- /dev/null
+++ b/monai/README.md
@@ -0,0 +1,73 @@
+## Instructions for running inference with the contrast-agnostic spinal cord segmentation model
+
+The following steps are required to use the contrast-agnostic model.
+
+### Setting up the environment and installing dependencies
+
+The following commands show how to set up the environment. Note that the documentation assumes that the user has `conda` installed on their system. Instructions on installing `conda` can be found [here](https://conda.io/projects/conda/en/latest/user-guide/install/index.html).
+
+1. Create a conda environment with the following command:
+
+```bash
+conda create -n venv_monai python=3.9
+```
+
+2. Activate the environment with the following command:
+
+```bash
+conda activate venv_monai
+```
+
+3. The list of necessary packages can be found in `requirements_inference.txt`. Use the following command for installation:
+
+```bash
+pip install -r requirements_inference.txt
+```
+
+### Method 1: Running inference on a single image
+
+The script for running inference is `run_inference_single_image.py`. Please run
+```
+python run_inference_single_image.py -h
+```
+to get the list of arguments and their descriptions.
+
+
+### Method 2: Running inference on a dataset (Advanced)
+
+NOTE: This section is experimental and for advanced users only. Please use Method 1 for running inference.
+
+#### Creating a datalist
+
+The inference script assumes that the dataset is described by a Medical Segmentation Decathlon-style `json` file containing image-label pairs. The `create_inference_msd_datalist.py` script can be used to create one for your dataset. Use the following command to create the datalist:
+
+```bash
+python create_inference_msd_datalist.py --dataset-name spine-generic --path-data --path-out --contrast-suffix T1w
+```
+
+- `--dataset-name` - Corresponds to the name of the dataset. The datalist will be saved as `_dataset.json`
+- `--path-data` - Path to the BIDS dataset
+- `--path-out` - Path to the output folder. The datalist will be saved under `/_dataset.json`
+- `--contrast-suffix` - The suffix of the contrast to be used for pairing images/labels
+
+> **Note**
+> This script is not meant to run off-the-shelf. Placeholders are provided to update the script with the .... TODO
+
+
+#### Running inference
+
+Use the following command:
+
+```bash
+python run_inference.py --path-json --chkp-path --path-out --model --crop_size <48x160x320> --device
+```
+
+- `--path-json` - Path to the datalist created in Step 2
+- `--chkp-path` - Path to the model checkpoint folder. This folder should contain `best_model_loss.ckpt`
+- `--path-out` - Path to the output folder where the predictions will be saved
+- `--model` - Model to be used for inference. Currently, only `unet` and `nnunet` are supported
+- `--crop_size` - Crop size used for center cropping the image before running inference. Recommended to be set to a multiple of 32
+- `--device` - Device to be used for inference.
Currently, only `gpu` and `cpu` are supported + + + diff --git a/monai/create_msd_data.py b/monai/create_msd_data.py new file mode 100644 index 00000000..e7ccbff7 --- /dev/null +++ b/monai/create_msd_data.py @@ -0,0 +1,226 @@ +import os +import json +from tqdm import tqdm +import numpy as np +import argparse +import joblib +from utils import FoldGenerator +from loguru import logger +from sklearn.model_selection import train_test_split + +# root = "/home/GRAMES.POLYMTL.CA/u114716/datasets/spine-generic_uncropped" + +parser = argparse.ArgumentParser(description='Code for creating k-fold splits of the spine-generic dataset.') + +parser.add_argument('-pd', '--path-data', required=True, type=str, help='Path to the data set directory') +parser.add_argument('-pj', '--path-joblib', help='Path to joblib file from ivadomed containing the dataset splits.', + default=None, type=str) +parser.add_argument('-po', '--path-out', type=str, help='Path to the output directory where dataset json is saved') +parser.add_argument("--contrast", default="t2w", type=str, help="Contrast to use for training", + choices=["t1w", "t2w", "t2star", "mton", "mtoff", "dwi", "all"]) +parser.add_argument('--label-type', default='soft', type=str, help="Type of labels to use for training", + choices=['hard', 'soft']) +parser.add_argument('--seed', default=42, type=int, help="Seed for reproducibility") +args = parser.parse_args() + + +root = args.path_data +seed = args.seed +contrast = args.contrast +if args.label_type == 'soft': + logger.info("Using SOFT LABELS ...") + PATH_DERIVATIVES = os.path.join(root, "derivatives", "labels_softseg") + SUFFIX = "softseg" +else: + logger.info("Using HARD LABELS ...") + PATH_DERIVATIVES = os.path.join(root, "derivatives", "labels") + SUFFIX = "seg-manual" + +# Get all subjects +# the participants.tsv file might not be up-to-date, hence rely on the existing folders +# subjects_df = pd.read_csv(os.path.join(root, 'participants.tsv'), sep='\t') +# subjects = subjects_df['participant_id'].values.tolist() +subjects = [subject for subject in os.listdir(root) if subject.startswith('sub-')] +logger.info(f"Total number of subjects in the root directory: {len(subjects)}") + +if args.path_joblib is not None: + # load information from the joblib to match train and test subjects + # joblib_file = os.path.join(args.path_joblib, 'split_datasets_all_seed=15.joblib') + splits = joblib.load(args.path_joblib) + # get the subjects from the joblib file + train_subjects = sorted(list(set([sub.split('_')[0] for sub in splits['train']]))) + val_subjects = sorted(list(set([sub.split('_')[0] for sub in splits['valid']]))) + test_subjects = sorted(list(set([sub.split('_')[0] for sub in splits['test']]))) + +else: + # create one json file with 60-20-20 train-val-test split + train_ratio, val_ratio, test_ratio = 0.6, 0.2, 0.2 + train_subjects, test_subjects = train_test_split(subjects, test_size=test_ratio, random_state=args.seed) + # Use the training split to further split into training and validation splits + train_subjects, val_subjects = train_test_split(train_subjects, test_size=val_ratio / (train_ratio + val_ratio), + random_state=args.seed, ) + # sort the subjects + train_subjects = sorted(train_subjects) + val_subjects = sorted(val_subjects) + test_subjects = sorted(test_subjects) + +logger.info(f"Number of training subjects: {len(train_subjects)}") +logger.info(f"Number of validation subjects: {len(val_subjects)}") +logger.info(f"Number of testing subjects: {len(test_subjects)}") + +# keys to be defined in the 
dataset_0.json +params = {} +params["description"] = "spine-generic-uncropped" +params["labels"] = { + "0": "background", + "1": "soft-sc-seg" + } +params["license"] = "nk" +params["modality"] = { + "0": "MRI" + } +params["name"] = "spine-generic" +params["numTest"] = len(test_subjects) +params["numTraining"] = len(train_subjects) +params["numValidation"] = len(val_subjects) +params["seed"] = args.seed +params["reference"] = "University of Zurich" +params["tensorImageSize"] = "3D" + +train_subjects_dict = {"train": train_subjects} +val_subjects_dict = {"validation": val_subjects} +test_subjects_dict = {"test": test_subjects} +all_subjects_list = [train_subjects_dict, val_subjects_dict, test_subjects_dict] + +# # define the contrasts +# contrasts_list = ['T1w', 'T2w', 'T2star', 'flip-1_mt-on_MTS', 'flip-2_mt-off_MTS', 'dwi'] + +for subjects_dict in tqdm(all_subjects_list, desc="Iterating through train/val/test splits"): + + for name, subs_list in subjects_dict.items(): + + temp_list = [] + for subject_no, subject in enumerate(subs_list): + + if contrast == "all": + temp_data_t1w = {} + temp_data_t2w = {} + temp_data_t2star = {} + temp_data_mton_mts = {} + temp_data_mtoff_mts = {} + temp_data_dwi = {} + + # t1w + temp_data_t1w["image"] = os.path.join(root, subject, 'anat', f"{subject}_T1w.nii.gz") + temp_data_t1w["label"] = os.path.join(PATH_DERIVATIVES, subject, 'anat', f"{subject}_T1w_{SUFFIX}.nii.gz") + if os.path.exists(temp_data_t1w["label"]) and os.path.exists(temp_data_t1w["image"]): + temp_list.append(temp_data_t1w) + + # t2w + temp_data_t2w["image"] = os.path.join(root, subject, 'anat', f"{subject}_T2w.nii.gz") + temp_data_t2w["label"] = os.path.join(PATH_DERIVATIVES, subject, 'anat', f"{subject}_T2w_{SUFFIX}.nii.gz") + if os.path.exists(temp_data_t2w["label"]) and os.path.exists(temp_data_t2w["image"]): + temp_list.append(temp_data_t2w) + + # t2star + temp_data_t2star["image"] = os.path.join(root, subject, 'anat', f"{subject}_T2star.nii.gz") + temp_data_t2star["label"] = os.path.join(PATH_DERIVATIVES, subject, 'anat', f"{subject}_T2star_{SUFFIX}.nii.gz") + if os.path.exists(temp_data_t2star["label"]) and os.path.exists(temp_data_t2star["image"]): + temp_list.append(temp_data_t2star) + + # mton_mts + temp_data_mton_mts["image"] = os.path.join(root, subject, 'anat', f"{subject}_flip-1_mt-on_MTS.nii.gz") + temp_data_mton_mts["label"] = os.path.join(PATH_DERIVATIVES, subject, 'anat', f"{subject}_flip-1_mt-on_MTS_{SUFFIX}.nii.gz") + if os.path.exists(temp_data_mton_mts["label"]) and os.path.exists(temp_data_mton_mts["image"]): + temp_list.append(temp_data_mton_mts) + + # t1w_mts + temp_data_mtoff_mts["image"] = os.path.join(root, subject, 'anat', f"{subject}_flip-2_mt-off_MTS.nii.gz") + temp_data_mtoff_mts["label"] = os.path.join(PATH_DERIVATIVES, subject, 'anat', f"{subject}_flip-2_mt-off_MTS_{SUFFIX}.nii.gz") + if os.path.exists(temp_data_mtoff_mts["label"]) and os.path.exists(temp_data_mtoff_mts["image"]): + temp_list.append(temp_data_mtoff_mts) + + # dwi + temp_data_dwi["image"] = os.path.join(root, subject, 'dwi', f"{subject}_rec-average_dwi.nii.gz") + temp_data_dwi["label"] = os.path.join(PATH_DERIVATIVES, subject, 'dwi', f"{subject}_rec-average_dwi_{SUFFIX}.nii.gz") + if os.path.exists(temp_data_dwi["label"]) and os.path.exists(temp_data_dwi["image"]): + temp_list.append(temp_data_dwi) + + + elif contrast == "t1w": # t1w + temp_data_t1w = {} + temp_data_t1w["image"] = os.path.join(root, subject, 'anat', f"{subject}_T1w.nii.gz") + temp_data_t1w["label"] = 
os.path.join(PATH_DERIVATIVES, subject, 'anat', f"{subject}_T1w_{SUFFIX}.nii.gz") + if os.path.exists(temp_data_t1w["label"]) and os.path.exists(temp_data_t1w["image"]): + temp_list.append(temp_data_t1w) + else: + logger.info(f"Subject {subject} does not have T1w image or label.") + + + elif contrast == "t2w": # t2w + temp_data_t2w = {} + temp_data_t2w["image"] = os.path.join(root, subject, 'anat', f"{subject}_T2w.nii.gz") + temp_data_t2w["label"] = os.path.join(PATH_DERIVATIVES, subject, 'anat', f"{subject}_T2w_{SUFFIX}.nii.gz") + if os.path.exists(temp_data_t2w["label"]) and os.path.exists(temp_data_t2w["image"]): + temp_list.append(temp_data_t2w) + else: + logger.info(f"Subject {subject} does not have T2w image or label.") + + + elif contrast == "t2star": # t2star + temp_data_t2star = {} + temp_data_t2star["image"] = os.path.join(root, subject, 'anat', f"{subject}_T2star.nii.gz") + temp_data_t2star["label"] = os.path.join(PATH_DERIVATIVES, subject, 'anat', f"{subject}_T2star_{SUFFIX}.nii.gz") + if os.path.exists(temp_data_t2star["label"]) and os.path.exists(temp_data_t2star["image"]): + temp_list.append(temp_data_t2star) + else: + logger.info(f"Subject {subject} does not have T2star image or label.") + + + elif contrast == "mton": # mton_mts + temp_data_mton_mts = {} + temp_data_mton_mts["image"] = os.path.join(root, subject, 'anat', f"{subject}_flip-1_mt-on_MTS.nii.gz") + temp_data_mton_mts["label"] = os.path.join(PATH_DERIVATIVES, subject, 'anat', f"{subject}_flip-1_mt-on_MTS_{SUFFIX}.nii.gz") + if os.path.exists(temp_data_mton_mts["label"]) and os.path.exists(temp_data_mton_mts["image"]): + temp_list.append(temp_data_mton_mts) + else: + logger.info(f"Subject {subject} does not have MTOn image or label.") + + elif contrast == "mtoff": # t1w_mts + temp_data_mtoff_mts = {} + temp_data_mtoff_mts["image"] = os.path.join(root, subject, 'anat', f"{subject}_flip-2_mt-off_MTS.nii.gz") + temp_data_mtoff_mts["label"] = os.path.join(PATH_DERIVATIVES, subject, 'anat', f"{subject}_flip-2_mt-off_MTS_{SUFFIX}.nii.gz") + if os.path.exists(temp_data_mtoff_mts["label"]) and os.path.exists(temp_data_mtoff_mts["image"]): + temp_list.append(temp_data_mtoff_mts) + else: + logger.info(f"Subject {subject} does not have MTOff image or label.") + + elif contrast == "dwi": # dwi + temp_data_dwi = {} + temp_data_dwi["image"] = os.path.join(root, subject, 'dwi', f"{subject}_rec-average_dwi.nii.gz") + temp_data_dwi["label"] = os.path.join(PATH_DERIVATIVES, subject, 'dwi', f"{subject}_rec-average_dwi_{SUFFIX}.nii.gz") + if os.path.exists(temp_data_dwi["label"]) and os.path.exists(temp_data_dwi["image"]): + temp_list.append(temp_data_dwi) + else: + logger.info(f"Subject {subject} does not have DWI image or label.") + + else: + raise ValueError(f"Contrast {contrast} not recognized.") + + + params[name] = temp_list + logger.info(f"Number of images in {name} set: {len(temp_list)}") + +final_json = json.dumps(params, indent=4, sort_keys=True) +if not os.path.exists(args.path_out): + os.makedirs(args.path_out, exist_ok=True) + +jsonFile = open(args.path_out + "/" + f"dataset_{contrast}_{args.label_type}_seed{seed}.json", "w") +jsonFile.write(final_json) +jsonFile.close() + + + + + + diff --git a/monai/losses.py b/monai/losses.py new file mode 100644 index 00000000..9c1ecdfa --- /dev/null +++ b/monai/losses.py @@ -0,0 +1,142 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import scipy +import numpy as np + + +# TODO: also check out nnUNet's implementation of soft-dice loss (if required) +# 
https://github.com/MIC-DKFZ/nnUNet/blob/master/nnunetv2/training/loss/dice.py + +class SoftDiceLoss(nn.Module): + ''' + soft-dice loss, useful in binary segmentation + taken from: https://github.com/CoinCheung/pytorch-loss/blob/master/soft_dice_loss.py + ''' + def __init__(self, p=1, smooth=1): + super(SoftDiceLoss, self).__init__() + self.p = p + self.smooth = smooth + + def forward(self, logits, labels): + ''' + inputs: + preds: logits - tensor of shape (N, H, W, ...) + labels: soft labels [0,1] - tensor of shape(N, H, W, ...) + output: + loss: tensor of shape(1, ) + ''' + preds = F.relu(logits) / F.relu(logits).max() if bool(F.relu(logits).max()) else F.relu(logits) + + numer = (preds * labels).sum() + denor = (preds.pow(self.p) + labels.pow(self.p)).sum() + # loss = 1. - (2 * numer + self.smooth) / (denor + self.smooth) + loss = - (2 * numer + self.smooth) / (denor + self.smooth) + return loss + + +class DiceCrossEntropyLoss(nn.Module): + def __init__(self, weight_ce=1.0, weight_dice=1.0): + super(DiceCrossEntropyLoss, self).__init__() + self.ce_weight = weight_ce + self.dice_weight = weight_dice + + self.dice_loss = SoftDiceLoss() + # self.ce_loss = RobustCrossEntropyLoss() + self.ce_loss = nn.CrossEntropyLoss() + + def forward(self, preds, labels): + ''' + inputs: + preds: logits (not probabilities!) - tensor of shape (N, H, W, ...) + labels: soft labels [0,1] - tensor of shape(N, H, W, ...) + output: + loss: tensor of shape(1, ) + ''' + ce_loss = self.ce_loss(preds, labels) + + # dice loss will convert logits to probabilities + dice_loss = self.dice_loss(preds, labels) + + loss = self.ce_weight * ce_loss + self.dice_weight * dice_loss + return loss + + +class AdapWingLoss(nn.Module): + """ + Adaptive Wing loss used for heatmap regression + Adapted from: https://github.com/ivadomed/ivadomed/blob/master/ivadomed/losses.py#L341 + + .. seealso:: + Wang, Xinyao, Liefeng Bo, and Li Fuxin. "Adaptive wing loss for robust face alignment via heatmap regression." + Proceedings of the IEEE International Conference on Computer Vision. 2019. + + Args: + theta (float): Threshold to switch between the linear and non-linear parts of the piece-wise loss function. + alpha (float): Used to adapt the behaviour of the loss function at y=0 and y=1 and make loss smooth at 0 (background). + It needs to be slightly above 2 to maintain ideal properties. + omega (float): Multiplicative factor for non linear part of the loss. + epsilon (float): factor to avoid gradient explosion. It must not be too small + NOTE: Larger omega and smaller epsilon values will increase the influence on small errors and vice versa + """ + + def __init__(self, theta=0.5, alpha=2.1, omega=14, epsilon=1, reduction='sum'): + self.theta = theta + self.alpha = alpha + self.omega = omega + self.epsilon = epsilon + self.reduction = reduction + super(AdapWingLoss, self).__init__() + + def forward(self, input, target): + eps = self.epsilon + batch_size = target.size()[0] + + # Adaptive Wing loss. Section 4.2 of the paper. 
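+        # For reference, the piecewise definition implemented below (y = target, y_hat = input) is:
+        #     AWing(y, y_hat) = omega * ln(1 + (|y - y_hat| / epsilon)**(alpha - y))    if |y - y_hat| < theta
+        #                     = A * |y - y_hat| - C                                     otherwise
+        # where A and C (computed next) make the two pieces continuous and smooth at |y - y_hat| = theta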
+ # Compute adaptive factor + A = self.omega * (1 / (1 + torch.pow(self.theta / eps, + self.alpha - target))) * \ + (self.alpha - target) * torch.pow(self.theta / eps, + self.alpha - target - 1) * (1 / eps) + + # Constant term to link linear and non linear part + C = (self.theta * A - self.omega * torch.log(1 + torch.pow(self.theta / eps, self.alpha - target))) + + diff_hm = torch.abs(target - input) + AWingLoss = A * diff_hm - C + idx = diff_hm < self.theta + # NOTE: this is a memory-efficient version than the one in ivadomed losses.py + # where idx is True, compute the non-linear part of the loss, otherwise keep the linear part + # the non-linear parts ensures small errors (as given by idx) have a larger influence to refine the predictions at the boundaries + # the linear part makes the loss function behave more like the MSE loss, which has a linear influence + # (i.e. small errors where y=0 --> small influence --> small gradients) + AWingLoss = torch.where(idx, self.omega * torch.log(1 + torch.pow(diff_hm / eps, self.alpha - target)), AWingLoss) + + + # Mask for weighting the loss function. Section 4.3 of the paper. + mask = torch.zeros_like(target) + kernel = scipy.ndimage.generate_binary_structure(2, 2) + # For 3D segmentation tasks + if len(input.shape) == 5: + kernel = scipy.ndimage.generate_binary_structure(3, 2) + + for i in range(batch_size): + img_list = list() + img_list.append(np.round(target[i].cpu().numpy() * 255)) + img_merge = np.concatenate(img_list) + img_dilate = scipy.ndimage.binary_opening(img_merge, np.expand_dims(kernel, axis=0)) + # NOTE: why 51? the paper thresholds the dilated GT heatmap at 0.2. So, 51/255 = 0.2 + img_dilate[img_dilate < 51] = 1 # 0*omega+1 + img_dilate[img_dilate >= 51] = 1 + self.omega # 1*omega+1 + img_dilate = np.array(img_dilate, dtype=int) + + mask[i] = torch.tensor(img_dilate) + + AWingLoss *= mask + + sum_loss = torch.sum(AWingLoss) + if self.reduction == "sum": + return sum_loss + elif self.reduction == "mean": + all_pixel = torch.sum(mask) + return sum_loss / all_pixel diff --git a/monai/main.py b/monai/main.py new file mode 100644 index 00000000..dd87afb0 --- /dev/null +++ b/monai/main.py @@ -0,0 +1,767 @@ +import os +import argparse +from datetime import datetime +from loguru import logger + +import numpy as np +import wandb +import torch +import pytorch_lightning as pl +import torch.nn.functional as F +import matplotlib.pyplot as plt + +from utils import precision_score, recall_score, dice_score, \ + PolyLRScheduler, plot_slices, check_empty_patch +from losses import SoftDiceLoss, AdapWingLoss +from transforms import train_transforms, val_transforms +from models import create_nnunet_from_plans + +from monai.utils import set_determinism +from monai.inferers import sliding_window_inference +from monai.networks.nets import UNETR +from monai.data import (DataLoader, Dataset, CacheDataset, load_decathlon_datalist, decollate_batch) +from monai.transforms import (Compose, EnsureType, EnsureTyped, Invertd, SaveImage) + + +# create a "model"-agnostic class with PL to use different models +class Model(pl.LightningModule): + def __init__(self, args, data_root, net, loss_function, optimizer_class, exp_id=None, results_path=None): + super().__init__() + self.args = args + self.save_hyperparameters(ignore=['net', 'loss_function']) + + self.root = data_root + self.net = net + self.lr = args.learning_rate + self.loss_function = loss_function + self.optimizer_class = optimizer_class + self.save_exp_id = exp_id + self.results_path = results_path + + 
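+        # best validation metrics (and the epoch they occurred at) seen so far;
+        # updated in on_validation_epoch_end() and used for logging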
self.best_val_dice, self.best_val_epoch = 0, 0 + self.best_val_loss = float("inf") + + # define cropping and padding dimensions + # NOTE about patch sizes: nnUNet defines patches using the median size of the dataset as the reference + # BUT, for SC images, this means a lot of context outside the spinal cord is included in the patches + # which could be sub-optimal. + # On the other hand, ivadomed used a patch-size that's heavily padded along the R-L direction so that + # only the SC is in context. + self.spacing = (1.0, 1.0, 1.0) + self.voxel_cropping_size = self.inference_roi_size = tuple([int(i) for i in args.crop_size.split("x")]) + # self.inference_roi_size = tuple([int(i) for i in args.val_crop_size.split("x")]) + + # define post-processing transforms for validation, nothing fancy just making sure that it's a tensor (default) + self.val_post_pred = Compose([EnsureType()]) + self.val_post_label = Compose([EnsureType()]) + + # define evaluation metric + self.soft_dice_metric = dice_score + + # temp lists for storing outputs from training, validation, and testing + self.train_step_outputs = [] + self.val_step_outputs = [] + self.test_step_outputs = [] + + + # -------------------------------- + # FORWARD PASS + # -------------------------------- + def forward(self, x): + + out = self.net(x) + # # NOTE: MONAI's models only output the logits, not the output after the final activation function + # # https://docs.monai.io/en/0.9.0/_modules/monai/networks/nets/unetr.html#UNETR.forward refers to the + # # UnetOutBlock (https://docs.monai.io/en/0.9.0/_modules/monai/networks/blocks/dynunet_block.html#UnetOutBlock) + # # as the final block applied to the input, which is just a convolutional layer with no activation function + # # Hence, we are used Normalized ReLU to normalize the logits to the final output + # normalized_out = F.relu(out) / F.relu(out).max() if bool(F.relu(out).max()) else F.relu(out) + + return out # returns logits + + + # -------------------------------- + # DATA PREPARATION + # -------------------------------- + def prepare_data(self): + # set deterministic training for reproducibility + set_determinism(seed=self.args.seed) + + # define training and validation transforms + transforms_train = train_transforms( + crop_size=self.voxel_cropping_size, + lbl_key='label' + ) + transforms_val = val_transforms(crop_size=self.inference_roi_size, lbl_key='label') + + # load the dataset + dataset = os.path.join(self.root, f"dataset_{self.args.contrast}_{self.args.label_type}_seed15.json") + logger.info(f"Loading dataset: {dataset}") + train_files = load_decathlon_datalist(dataset, True, "train") + val_files = load_decathlon_datalist(dataset, True, "validation") + test_files = load_decathlon_datalist(dataset, True, "test") + + if args.debug: + train_files = train_files[:15] + val_files = val_files[:15] + test_files = test_files[:6] + + train_cache_rate = 0.25 if args.debug else 0.5 + self.train_ds = CacheDataset(data=train_files, transform=transforms_train, cache_rate=train_cache_rate, num_workers=4) + self.val_ds = CacheDataset(data=val_files, transform=transforms_val, cache_rate=0.25, num_workers=4) + + # define test transforms + transforms_test = val_transforms(crop_size=self.inference_roi_size, lbl_key='label') + + # define post-processing transforms for testing; taken (with explanations) from + # https://github.com/Project-MONAI/tutorials/blob/main/3d_segmentation/torch/unet_inference_dict.py#L66 + self.test_post_pred = Compose([ + EnsureTyped(keys=["pred", "label"]), + 
Invertd(keys=["pred", "label"], transform=transforms_test, + orig_keys=["image", "label"], + meta_keys=["pred_meta_dict", "label_meta_dict"], + nearest_interp=False, to_tensor=True), + ]) + self.test_ds = CacheDataset(data=test_files, transform=transforms_test, cache_rate=0.1, num_workers=4) + + + # -------------------------------- + # DATA LOADERS + # -------------------------------- + def train_dataloader(self): + return DataLoader(self.train_ds, batch_size=self.args.batch_size, shuffle=True, num_workers=16, + pin_memory=True, persistent_workers=True) + + def val_dataloader(self): + return DataLoader(self.val_ds, batch_size=1, shuffle=False, num_workers=16, pin_memory=True, + persistent_workers=True) + + def test_dataloader(self): + return DataLoader(self.test_ds, batch_size=1, shuffle=False, num_workers=8, pin_memory=True) + + + # -------------------------------- + # OPTIMIZATION + # -------------------------------- + def configure_optimizers(self): + if self.args.optimizer == "sgd": + optimizer = self.optimizer_class(self.parameters(), lr=self.lr, momentum=0.99, weight_decay=3e-5, nesterov=True) + else: + optimizer = self.optimizer_class(self.parameters(), lr=self.lr) + # scheduler = PolyLRScheduler(optimizer, self.lr, max_steps=self.args.max_epochs) + # NOTE: ivadomed using CosineAnnealingLR with T_max = 200 + scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=self.args.max_epochs) + return [optimizer], [scheduler] + + + # -------------------------------- + # TRAINING + # -------------------------------- + def training_step(self, batch, batch_idx): + + inputs, labels = batch["image"], batch["label"] + + # check if any label image patch is empty in the batch + if check_empty_patch(labels) is None: + # print(f"Empty label patch found. 
Skipping training step ...") + return None + + output = self.forward(inputs) # logits + # print(f"labels.shape: {labels.shape} \t output.shape: {output.shape}") + + if self.args.model == "nnunet" and self.args.enable_DS: + + # calculate dice loss for each output + loss, train_soft_dice = 0.0, 0.0 + for i in range(len(output)): + # give each output a weight which decreases exponentially (division by 2) as the resolution decreases + # this gives higher resolution outputs more weight in the loss + # NOTE: outputs[0] is the final pred, outputs[-1] is the lowest resolution pred (at the bottleneck) + # we're downsampling the GT to the resolution of each deepsupervision feature map output + # (instead of upsampling each deepsupervision feature map output to the final resolution) + downsampled_gt = F.interpolate(labels, size=output[i].shape[-3:], mode='trilinear', align_corners=False) + # print(f"downsampled_gt.shape: {downsampled_gt.shape} \t output[i].shape: {output[i].shape}") + loss += (0.5 ** i) * self.loss_function(output[i], downsampled_gt) + + # get probabilities from logits + out = F.relu(output[i]) / F.relu(output[i]).max() if bool(F.relu(output[i]).max()) else F.relu(output[i]) + + # calculate train dice + # NOTE: this is done on patches (and not entire 3D volume) because SlidingWindowInference is not used here + # So, take this dice score with a lot of salt + train_soft_dice += self.soft_dice_metric(out, downsampled_gt) + + # average dice loss across the outputs + loss /= len(output) + train_soft_dice /= len(output) + + else: + # calculate training loss + loss = self.loss_function(output, labels) + + # get probabilities from logits + output = F.relu(output) / F.relu(output).max() if bool(F.relu(output).max()) else F.relu(output) + + # calculate train dice + # NOTE: this is done on patches (and not entire 3D volume) because SlidingWindowInference is not used here + # So, take this dice score with a lot of salt + train_soft_dice = self.soft_dice_metric(output, labels) + + metrics_dict = { + "loss": loss.cpu(), + "train_soft_dice": train_soft_dice.detach().cpu(), + "train_number": len(inputs), + # "train_image": inputs[0].detach().cpu().squeeze(), + # "train_gt": labels[0].detach().cpu().squeeze(), + # "train_pred": output[0].detach().cpu().squeeze() + } + self.train_step_outputs.append(metrics_dict) + + return metrics_dict + + def on_train_epoch_end(self): + + if self.train_step_outputs == []: + # means the training step was skipped because of empty input patch + return None + else: + train_loss, train_soft_dice = 0, 0 + num_items = len(self.train_step_outputs) + for output in self.train_step_outputs: + train_loss += output["loss"].item() + train_soft_dice += output["train_soft_dice"].item() + + mean_train_loss = (train_loss / num_items) + mean_train_soft_dice = (train_soft_dice / num_items) + + wandb_logs = { + "train_soft_dice": mean_train_soft_dice, + "train_loss": mean_train_loss, + } + self.log_dict(wandb_logs) + + # # plot the training images + # fig = plot_slices(image=self.train_step_outputs[0]["train_image"], + # gt=self.train_step_outputs[0]["train_gt"], + # pred=self.train_step_outputs[0]["train_pred"], + # debug=args.debug) + # wandb.log({"training images": wandb.Image(fig)}) + + # free up memory + self.train_step_outputs.clear() + wandb_logs.clear() + # plt.close(fig) + + + # -------------------------------- + # VALIDATION + # -------------------------------- + def validation_step(self, batch, batch_idx): + + inputs, labels = batch["image"], batch["label"] + + # NOTE: this 
calculates the loss on the entire image after sliding window + outputs = sliding_window_inference(inputs, self.inference_roi_size, mode="gaussian", + sw_batch_size=4, predictor=self.forward, overlap=0.5,) + # outputs shape: (B, C, ) + + if self.args.model == "nnunet" and self.args.enable_DS: + # we only need the output with the highest resolution + outputs = outputs[0] + + # calculate validation loss + loss = self.loss_function(outputs, labels) + + # get probabilities from logits + outputs = F.relu(outputs) / F.relu(outputs).max() if bool(F.relu(outputs).max()) else F.relu(outputs) + + # post-process for calculating the evaluation metric + post_outputs = [self.val_post_pred(i) for i in decollate_batch(outputs)] + post_labels = [self.val_post_label(i) for i in decollate_batch(labels)] + val_soft_dice = self.soft_dice_metric(post_outputs[0], post_labels[0]) + + hard_preds, hard_labels = (post_outputs[0].detach() > 0.5).float(), (post_labels[0].detach() > 0.5).float() + val_hard_dice = self.soft_dice_metric(hard_preds, hard_labels) + + # NOTE: there was a massive memory leak when storing cuda tensors in this dict. Hence, + # using .detach() to avoid storing the whole computation graph + # Ref: https://discuss.pytorch.org/t/cuda-memory-leak-while-training/82855/2 + metrics_dict = { + "val_loss": loss.detach().cpu(), + "val_soft_dice": val_soft_dice.detach().cpu(), + "val_hard_dice": val_hard_dice.detach().cpu(), + "val_number": len(post_outputs), + # "val_image": inputs[0].detach().cpu().squeeze(), + # "val_gt": labels[0].detach().cpu().squeeze(), + # "val_pred": post_outputs[0].detach().cpu().squeeze(), + } + self.val_step_outputs.append(metrics_dict) + + return metrics_dict + + def on_validation_epoch_end(self): + + val_loss, num_items, val_soft_dice, val_hard_dice = 0, 0, 0, 0 + for output in self.val_step_outputs: + val_loss += output["val_loss"].sum().item() + val_soft_dice += output["val_soft_dice"].sum().item() + val_hard_dice += output["val_hard_dice"].sum().item() + num_items += output["val_number"] + + mean_val_loss = (val_loss / num_items) + mean_val_soft_dice = (val_soft_dice / num_items) + mean_val_hard_dice = (val_hard_dice / num_items) + + wandb_logs = { + "val_soft_dice": mean_val_soft_dice, + "val_hard_dice": mean_val_hard_dice, + "val_loss": mean_val_loss, + } + # save the best model based on validation dice score + if mean_val_soft_dice > self.best_val_dice: + self.best_val_dice = mean_val_soft_dice + self.best_val_epoch = self.current_epoch + + # save the best model based on validation CSA loss + if mean_val_loss < self.best_val_loss: + self.best_val_loss = mean_val_loss + self.best_val_epoch = self.current_epoch + + logger.info( + f"\nCurrent epoch: {self.current_epoch}" + f"\nAverage Soft Dice (VAL): {mean_val_soft_dice:.4f}" + f"\nAverage Hard Dice (VAL): {mean_val_hard_dice:.4f}" + f"\nAverage AdapWing Loss (VAL): {mean_val_loss:.4f}" + # f"\nBest Average Soft Dice: {self.best_val_dice:.4f} at Epoch: {self.best_val_epoch}" + f"\nBest Average AdapWing Loss: {self.best_val_loss:.4f} at Epoch: {self.best_val_epoch}" + f"\n----------------------------------------------------") + + + # log on to wandb + self.log_dict(wandb_logs) + + # # plot the validation images + # fig = plot_slices(image=self.val_step_outputs[0]["val_image"], + # gt=self.val_step_outputs[0]["val_gt"], + # pred=self.val_step_outputs[0]["val_pred"],) + # wandb.log({"validation images": wandb.Image(fig)}) + + # free up memory + self.val_step_outputs.clear() + wandb_logs.clear() + # plt.close(fig) + + # 
return {"log": wandb_logs} + + # -------------------------------- + # TESTING + # -------------------------------- + def test_step(self, batch, batch_idx): + + test_input = batch["image"] + # print(batch["label_meta_dict"]["filename_or_obj"][0]) + # print(f"test_input.shape: {test_input.shape} \t test_label.shape: {test_label.shape}") + batch["pred"] = sliding_window_inference(test_input, self.inference_roi_size, + sw_batch_size=4, predictor=self.forward, overlap=0.5) + # print(f"batch['pred'].shape: {batch['pred'].shape}") + + if self.args.model == "nnunet" and self.args.enable_DS: + # we only need the output with the highest resolution + batch["pred"] = batch["pred"][0] + + # normalize the logits + batch["pred"] = F.relu(batch["pred"]) / F.relu(batch["pred"]).max() if bool(F.relu(batch["pred"]).max()) else F.relu(batch["pred"]) + + post_test_out = [self.test_post_pred(i) for i in decollate_batch(batch)] + + # make sure that the shapes of prediction and GT label are the same + # print(f"pred shape: {post_test_out[0]['pred'].shape}, label shape: {post_test_out[0]['label'].shape}") + assert post_test_out[0]['pred'].shape == post_test_out[0]['label'].shape + + pred, label = post_test_out[0]['pred'].cpu(), post_test_out[0]['label'].cpu() + + # save the prediction and label + if self.args.save_test_preds: + + subject_name = (batch["image_meta_dict"]["filename_or_obj"][0]).split("/")[-1].replace(".nii.gz", "") + logger.info(f"Saving subject: {subject_name}") + + # image saver class + save_folder = os.path.join(self.results_path, subject_name.split("_")[0]) + pred_saver = SaveImage( + output_dir=save_folder, output_postfix="pred", output_ext=".nii.gz", + separate_folder=False, print_log=False, resample=True) + # save the prediction + pred_saver(pred) + + # label_saver = SaveImage( + # output_dir=save_folder, output_postfix="gt", output_ext=".nii.gz", + # separate_folder=False, print_log=False, resample=True) + # # save the label + # label_saver(label) + + + # NOTE: Important point from the SoftSeg paper - binarize predictions before computing metrics + # calculate all metrics here + # 1. Dice Score + test_soft_dice = self.soft_dice_metric(pred, label) + + # binarizing the predictions + pred = (post_test_out[0]['pred'].detach().cpu() > 0.5).float() + label = (post_test_out[0]['label'].detach().cpu() > 0.5).float() + + # 1.1 Hard Dice Score + test_hard_dice = self.soft_dice_metric(pred.numpy(), label.numpy()) + # 2. Precision Score + test_precision = precision_score(pred.numpy(), label.numpy()) + # 3. 
Recall Score + test_recall = recall_score(pred.numpy(), label.numpy()) + + metrics_dict = { + "test_hard_dice": test_hard_dice, + "test_soft_dice": test_soft_dice, + "test_precision": test_precision, + "test_recall": test_recall, + } + self.test_step_outputs.append(metrics_dict) + + return metrics_dict + + def on_test_epoch_end(self): + + avg_hard_dice_test, std_hard_dice_test = np.stack([x["test_hard_dice"] for x in self.test_step_outputs]).mean(), \ + np.stack([x["test_hard_dice"] for x in self.test_step_outputs]).std() + avg_soft_dice_test, std_soft_dice_test = np.stack([x["test_soft_dice"] for x in self.test_step_outputs]).mean(), \ + np.stack([x["test_soft_dice"] for x in self.test_step_outputs]).std() + avg_precision_test = np.stack([x["test_precision"] for x in self.test_step_outputs]).mean() + avg_recall_test = np.stack([x["test_recall"] for x in self.test_step_outputs]).mean() + + logger.info(f"Test (Soft) Dice: {avg_soft_dice_test}") + logger.info(f"Test (Hard) Dice: {avg_hard_dice_test}") + logger.info(f"Test Precision Score: {avg_precision_test}") + logger.info(f"Test Recall Score: {avg_recall_test}") + + self.avg_test_dice, self.std_test_dice = avg_soft_dice_test, std_soft_dice_test + self.avg_test_dice_hard, self.std_test_dice_hard = avg_hard_dice_test, std_hard_dice_test + self.avg_test_precision = avg_precision_test + self.avg_test_recall = avg_recall_test + + # free up memory + self.test_step_outputs.clear() + + +# -------------------------------- +# MAIN +# -------------------------------- +def main(args): + # Setting the seed + pl.seed_everything(args.seed, workers=True) + + # ====================================================================================================== + # Define plans json taken from nnUNet_preprocessed folder + # ====================================================================================================== + nnunet_plans = { + "UNet_class_name": "PlainConvUNet", + "UNet_base_num_features": args.init_filters, + "n_conv_per_stage_encoder": [2, 2, 2, 2, 2, 2], + "n_conv_per_stage_decoder": [2, 2, 2, 2, 2], + "pool_op_kernel_sizes": [ + [1, 1, 1], + [2, 2, 2], + [2, 2, 2], + [2, 2, 2], + [2, 2, 2], + [1, 2, 2] + ], + "conv_kernel_sizes": [ + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3] + ], + "unet_max_num_features": 320, + } + + # define root path for finding datalists + dataset_root = "/home/GRAMES.POLYMTL.CA/u114716/contrast-agnostic/datalists/spine-generic/seed15" + + # define optimizer + if args.optimizer in ["adam", "Adam"]: + optimizer_class = torch.optim.Adam + elif args.optimizer in ["SGD", "sgd"]: + optimizer_class = torch.optim.SGD + + # define models + if args.model in ["unetr"]: + # define image size to be fed to the model + img_size = (160, 224, 96) + + # define model + net = UNETR(spatial_dims=3, + in_channels=1, out_channels=1, + img_size=img_size, + feature_size=args.feature_size, + hidden_size=args.hidden_size, + mlp_dim=args.mlp_dim, + num_heads=args.num_heads, + pos_embed="conv", + norm_name="instance", + res_block=True, + dropout_rate=0.2, + ) + img_size = f"{img_size[0]}x{img_size[1]}x{img_size[2]}" + save_exp_id = f"{args.model}_opt={args.optimizer}_lr={args.learning_rate}" \ + f"_fs={args.feature_size}_hs={args.hidden_size}_mlpd={args.mlp_dim}_nh={args.num_heads}" \ + f"_CSAdiceL_nspv={args.num_samples_per_volume}_bs={args.batch_size}_{img_size}" \ + + elif args.model in ["nnunet"]: + if args.enable_DS: + logger.info(f" Using nnUNet model WITH deep supervision! 
") + else: + logger.info(f" Using nnUNet model WITHOUT deep supervision! ") + + # define model + net = create_nnunet_from_plans(plans=nnunet_plans, num_input_channels=1, num_classes=1, deep_supervision=args.enable_DS) + patch_size = "64x192x320" + save_exp_id =f"{args.model}_{args.contrast}_{args.label_type}_nf={args.init_filters}" \ + f"_opt={args.optimizer}_lr={args.learning_rate}" \ + f"_AdapW" \ + f"_bs={args.batch_size}_{patch_size}" + # save_exp_id =f"{args.model}_{args.contrast}_{args.label_type}_nf={args.init_filters}" \ + # f"_opt={args.optimizer}_lr={args.learning_rate}" \ + # f"_DiceL" \ + # f"_bs={args.batch_size}_{patch_size}" + + if args.debug: + save_exp_id = f"DEBUG_{save_exp_id}" + + + # TODO: move this inside the for loop when using more folds + timestamp = datetime.now().strftime(f"%Y%m%d-%H%M") # prints in YYYYMMDD-HHMMSS format + save_exp_id = f"{save_exp_id}_{timestamp}" + + # save output to a log file + logger.add(os.path.join(args.save_path, f"{save_exp_id}", "logs.txt"), rotation="10 MB", level="INFO") + + # define loss function + # loss_func = SoftDiceLoss(p=1, smooth=1.0) + # logger.info(f"Using SoftDiceLoss with p={loss_func.p}, smooth={loss_func.smooth}!") + loss_func = AdapWingLoss(theta=0.5, omega=8, alpha=2.1, epsilon=1, reduction="sum") + # NOTE: tried increasing omega and decreasing epsilon but results marginally worse than the above + # loss_func = AdapWingLoss(theta=0.5, omega=12, alpha=2.1, epsilon=0.5, reduction="sum") + logger.info(f"Using AdapWingLoss with theta={loss_func.theta}, omega={loss_func.omega}, alpha={loss_func.alpha}, epsilon={loss_func.epsilon}!") + + # define callbacks + # early_stopping = pl.callbacks.EarlyStopping(monitor="val_soft_dice", min_delta=0.00, patience=args.patience, + # verbose=False, mode="max") + early_stopping = pl.callbacks.EarlyStopping(monitor="val_loss", min_delta=0.00, patience=args.patience, + verbose=False, mode="min") + + lr_monitor = pl.callbacks.LearningRateMonitor(logging_interval='epoch') + + + if not args.continue_from_checkpoint: + # to save the best model on validation + save_path = os.path.join(args.save_path, f"{save_exp_id}") + if not os.path.exists(save_path): + os.makedirs(save_path, exist_ok=True) + + # to save the results/model predictions + results_path = os.path.join(args.results_dir, f"{save_exp_id}") + if not os.path.exists(results_path): + os.makedirs(results_path, exist_ok=True) + + # i.e. train by loading weights from scratch + pl_model = Model(args, data_root=dataset_root, + optimizer_class=optimizer_class, loss_function=loss_func, net=net, + exp_id=save_exp_id, results_path=results_path) + + # saving the best model based on validation loss + logger.info(f"Saving best model to {save_path}!") + checkpoint_callback_loss = pl.callbacks.ModelCheckpoint( + dirpath=save_path, filename='best_model_loss', monitor='val_loss', + save_top_k=1, mode="min", save_last=True, save_weights_only=False) + + # saving the best model based on soft validation dice score + checkpoint_callback_dice = pl.callbacks.ModelCheckpoint( + dirpath=save_path, filename='best_model_dice', monitor='val_soft_dice', + save_top_k=1, mode="max", save_last=False, save_weights_only=True) + + logger.info(f" Starting training from scratch! 
") + # wandb logger + grp = f"monai_ivado_{args.model}" if args.model == "unet" else f"monai_{args.model}" + exp_logger = pl.loggers.WandbLogger( + name=save_exp_id, + save_dir=args.save_path, + group=grp, + log_model=True, # save best model using checkpoint callback + project='contrast-agnostic', + entity='naga-karthik', + config=args) + + # Saving training script to wandb + wandb.save("main.py") + wandb.save("transforms.py") + + # initialise Lightning's trainer. + trainer = pl.Trainer( + devices=1, accelerator="gpu", # strategy="ddp", + logger=exp_logger, + callbacks=[checkpoint_callback_loss, checkpoint_callback_dice, lr_monitor, early_stopping], + check_val_every_n_epoch=args.check_val_every_n_epochs, + max_epochs=args.max_epochs, + precision=32, # TODO: see if 16-bit precision is stable + # deterministic=True, + enable_progress_bar=args.enable_progress_bar,) + # profiler="simple",) # to profile the training time taken for each step + + # Train! + trainer.fit(pl_model) + logger.info(f" Training Done!") + + else: + logger.info(f" Resuming training from the latest checkpoint! ") + + # check if wandb run folder is provided to resume using the same run + if args.wandb_run_folder is None: + raise ValueError("Please provide the wandb run folder to resume training using the same run on WandB!") + else: + wandb_run_folder = os.path.basename(args.wandb_run_folder) + wandb_run_id = wandb_run_folder.split("-")[-1] + + save_exp_id = args.save_path + save_path = os.path.dirname(args.save_path) + logger.info(f"save_path: {save_path}") + results_path = args.results_dir + + # i.e. train by loading weights from scratch + pl_model = Model(args, data_root=dataset_root, + optimizer_class=optimizer_class, loss_function=loss_func, net=net, + exp_id=save_exp_id, results_path=results_path) + + # saving the best model based on validation CSA loss + checkpoint_callback_loss = pl.callbacks.ModelCheckpoint( + dirpath=save_exp_id, filename='best_model_loss', monitor='val_loss', + save_top_k=1, mode="min", save_last=True, save_weights_only=True) + + # saving the best model based on soft validation dice score + checkpoint_callback_dice = pl.callbacks.ModelCheckpoint( + dirpath=save_exp_id, filename='best_model_dice', monitor='val_soft_dice', + save_top_k=1, mode="max", save_last=False, save_weights_only=True) + + # wandb logger + grp = f"monai_ivado_{args.model}" if args.model == "unet" else f"monai_{args.model}" + exp_logger = pl.loggers.WandbLogger( + save_dir=save_path, + group=grp, + log_model=True, # save best model using checkpoint callback + project='contrast-agnostic', + entity='naga-karthik', + config=args, + id=wandb_run_id, resume='must') + + # initialise Lightning's trainer. + trainer = pl.Trainer( + devices=1, accelerator="gpu", # strategy="ddp", + logger=exp_logger, + callbacks=[checkpoint_callback_loss, checkpoint_callback_dice, lr_monitor, early_stopping], + check_val_every_n_epoch=args.check_val_every_n_epochs, + max_epochs=args.max_epochs, + precision=32, + enable_progress_bar=args.enable_progress_bar,) + # profiler="simple",) # to profile the training time taken for each step + + # Train! + trainer.fit(pl_model, ckpt_path=os.path.join(save_exp_id, "last.ckpt"),) + logger.info(f" Training Done!") + + # Test! 
+ trainer.test(pl_model) + logger.info(f"TESTING DONE!") + + # closing the current wandb instance so that a new one is created for the next fold + wandb.finish() + + # TODO: Figure out saving test metrics to a file + with open(os.path.join(results_path, 'test_metrics.txt'), 'a') as f: + print('\n-------------- Test Metrics ----------------', file=f) + print(f"\nSeed Used: {args.seed}", file=f) + print(f"\ninitf={args.init_filters}_lr={args.learning_rate}_bs={args.batch_size}_{timestamp}", file=f) + print(f"\npatch_size={pl_model.voxel_cropping_size}", file=f) + + print('\n-------------- Test Hard Dice Scores ----------------', file=f) + print("Hard Dice --> Mean: %0.3f, Std: %0.3f" % (pl_model.avg_test_dice_hard, pl_model.std_test_dice_hard), file=f) + + print('\n-------------- Test Soft Dice Scores ----------------', file=f) + print("Soft Dice --> Mean: %0.3f, Std: %0.3f" % (pl_model.avg_test_dice, pl_model.std_test_dice), file=f) + + print('\n-------------- Test Precision Scores ----------------', file=f) + print("Precision --> Mean: %0.3f" % (pl_model.avg_test_precision), file=f) + + print('\n-------------- Test Recall Scores -------------------', file=f) + print("Recall --> Mean: %0.3f" % (pl_model.avg_test_recall), file=f) + + print('-------------------------------------------------------', file=f) + + +if __name__ == "__main__": + + parser = argparse.ArgumentParser(description='Script for training custom models for SCI Lesion Segmentation.') + # Arguments for model, data, and training and saving + parser.add_argument('-m', '--model', choices=['unetr', 'nnunet'], + default='unet', type=str, help='Model type to be used') + parser.add_argument('--enable_DS', default=False, action='store_true', help='Enable Deep Supervision') + # dataset + # define args for cropping size + parser.add_argument('-crop', '--crop_size', type=str, default="64x192x320", + help='Center crop size for training/validation. Values correspond to R-L, A-P, I-S axes' + 'of the image after 1mm isotropic resampling. 
Default: 64x192x320') + parser.add_argument("--contrast", default="t2w", type=str, help="Contrast to use for training", + choices=["t1w", "t2w", "t2star", "mton", "mtoff", "dwi", "all"]) + parser.add_argument('--label-type', default='soft', type=str, help="Type of labels to use for training", + choices=['hard', 'soft']) + + # unet model + parser.add_argument('-initf', '--init_filters', default=16, type=int, help="Number of Filters in Init Layer") + + # unetr model + parser.add_argument('-fs', '--feature_size', default=16, type=int, help="Feature Size") + parser.add_argument('-hs', '--hidden_size', default=768, type=int, help='Dimensionality of hidden embeddings') + parser.add_argument('-mlpd', '--mlp_dim', default=2048, type=int, help='Dimensionality of MLP layer') + parser.add_argument('-nh', '--num_heads', default=12, type=int, help='Number of heads in Multi-head Attention') + + # optimizations + parser.add_argument('-me', '--max_epochs', default=1000, type=int, help='Number of epochs for the training process') + parser.add_argument('-bs', '--batch_size', default=2, type=int, help='Batch size of the training and validation processes') + parser.add_argument('-opt', '--optimizer', + choices=['adam', 'Adam', 'SGD', 'sgd'], + default='adam', type=str, help='Optimizer to use') + parser.add_argument('-lr', '--learning_rate', default=1e-4, type=float, help='Learning rate for training the model') + parser.add_argument('-pat', '--patience', default=25, type=int, + help='number of validation steps (val_every_n_iters) to wait before early stopping') + # NOTE: patience is acutally until (patience * check_val_every_n_epochs) epochs + parser.add_argument('-epb', '--enable_progress_bar', default=False, action='store_true', + help='by default is disabled since it doesnt work in colab') + parser.add_argument('-cve', '--check_val_every_n_epochs', default=1, type=int, help='num of epochs to wait before validation') + # saving + parser.add_argument('-sp', '--save_path', + default=f"/home/GRAMES.POLYMTL.CA/u114716/contrast-agnostic/saved_models", + type=str, help='Path to the saved models directory') + parser.add_argument('-se', '--seed', default=42, type=int, help='Set seeds for reproducibility') + parser.add_argument('-debug', default=False, action='store_true', help='if true, results are not logged to wandb') + parser.add_argument('-stp', '--save_test_preds', default=False, action='store_true', + help='if true, test predictions are saved in `save_path`') + parser.add_argument('-c', '--continue_from_checkpoint', default=False, action='store_true', + help='Load model from checkpoint and continue training') + parser.add_argument('-wdb-run', '--wandb-run-folder', default=None, type=str, help='Path to the wandb run folder') + # testing + parser.add_argument('-rd', '--results_dir', + default=f"/home/GRAMES.POLYMTL.CA/u114716/contrast-agnostic/results", + type=str, help='Path to the model prediction results directory') + + + args = parser.parse_args() + + main(args) \ No newline at end of file diff --git a/monai/models.py b/monai/models.py new file mode 100644 index 00000000..90f4f210 --- /dev/null +++ b/monai/models.py @@ -0,0 +1,135 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +# ---------------------------- Imports for nnUNet's Model ----------------------------- +from dynamic_network_architectures.architectures.unet import PlainConvUNet, ResidualEncoderUNet +from dynamic_network_architectures.building_blocks.helper import get_matching_instancenorm, convert_dim_to_conv_op +from 
dynamic_network_architectures.initialization.weight_init import init_last_bn_before_add_to_0 + + +# ====================================================================================================== +# Define plans json taken from nnUNet +# ====================================================================================================== +nnunet_plans = { + "UNet_class_name": "PlainConvUNet", + "UNet_base_num_features": 32, + "n_conv_per_stage_encoder": [2, 2, 2, 2, 2, 2], + "n_conv_per_stage_decoder": [2, 2, 2, 2, 2], + "pool_op_kernel_sizes": [ + [1, 1, 1], + [2, 2, 2], + [2, 2, 2], + [2, 2, 2], + [2, 2, 2], + [1, 2, 2] + ], + "conv_kernel_sizes": [ + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3] + ], + "unet_max_num_features": 320, +} + + +# ====================================================================================================== +# Utils for nnUNet's Model +# ==================================================================================================== +class InitWeights_He(object): + def __init__(self, neg_slope=1e-2): + self.neg_slope = neg_slope + + def __call__(self, module): + if isinstance(module, nn.Conv3d) or isinstance(module, nn.ConvTranspose3d): + module.weight = nn.init.kaiming_normal_(module.weight, a=self.neg_slope) + if module.bias is not None: + module.bias = nn.init.constant_(module.bias, 0) + + +# ====================================================================================================== +# Define the network based on plans json +# ==================================================================================================== +def create_nnunet_from_plans(plans, num_input_channels: int, num_classes: int, deep_supervision: bool = True): + """ + Adapted from nnUNet's source code: + https://github.com/MIC-DKFZ/nnUNet/blob/master/nnunetv2/utilities/get_network_from_plans.py#L9 + + """ + num_stages = len(plans["conv_kernel_sizes"]) + + dim = len(plans["conv_kernel_sizes"][0]) + conv_op = convert_dim_to_conv_op(dim) + + segmentation_network_class_name = plans["UNet_class_name"] + mapping = { + 'PlainConvUNet': PlainConvUNet, + 'ResidualEncoderUNet': ResidualEncoderUNet + } + kwargs = { + 'PlainConvUNet': { + 'conv_bias': True, + 'norm_op': get_matching_instancenorm(conv_op), + 'norm_op_kwargs': {'eps': 1e-5, 'affine': True}, + 'dropout_op': None, 'dropout_op_kwargs': None, + 'nonlin': nn.LeakyReLU, 'nonlin_kwargs': {'inplace': True}, + }, + 'ResidualEncoderUNet': { + 'conv_bias': True, + 'norm_op': get_matching_instancenorm(conv_op), + 'norm_op_kwargs': {'eps': 1e-5, 'affine': True}, + 'dropout_op': None, 'dropout_op_kwargs': None, + 'nonlin': nn.LeakyReLU, 'nonlin_kwargs': {'inplace': True}, + } + } + assert segmentation_network_class_name in mapping.keys(), 'The network architecture specified by the plans file ' \ + 'is non-standard (maybe your own?). Yo\'ll have to dive ' \ + 'into either this ' \ + 'function (get_network_from_plans) or ' \ + 'the init of your nnUNetModule to accomodate that.' + network_class = mapping[segmentation_network_class_name] + + conv_or_blocks_per_stage = { + 'n_conv_per_stage' + if network_class != ResidualEncoderUNet else 'n_blocks_per_stage': plans["n_conv_per_stage_encoder"], + 'n_conv_per_stage_decoder': plans["n_conv_per_stage_decoder"] + } + + # network class name!! 
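+    # e.g. with the plans defined above (UNet_base_num_features=32, unet_max_num_features=320, 6 stages),
+    # features_per_stage below evaluates to [32, 64, 128, 256, 320, 320]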
+ model = network_class( + input_channels=num_input_channels, + n_stages=num_stages, + features_per_stage=[min(plans["UNet_base_num_features"] * 2 ** i, + plans["unet_max_num_features"]) for i in range(num_stages)], + conv_op=conv_op, + kernel_sizes=plans["conv_kernel_sizes"], + strides=plans["pool_op_kernel_sizes"], + num_classes=num_classes, + deep_supervision=deep_supervision, + **conv_or_blocks_per_stage, + **kwargs[segmentation_network_class_name] + ) + model.apply(InitWeights_He(1e-2)) + if network_class == ResidualEncoderUNet: + model.apply(init_last_bn_before_add_to_0) + + return model + + + +if __name__ == "__main__": + + enable_deep_supervision = True + model = create_nnunet_from_plans(nnunet_plans, 1, 1, enable_deep_supervision) + input = torch.randn(1, 1, 160, 224, 96) + output = model(input) + if enable_deep_supervision: + for i in range(len(output)): + print(output[i].shape) + else: + print(output.shape) + + # print(output.shape) diff --git a/monai/requirements.txt b/monai/requirements.txt new file mode 100644 index 00000000..4f69d6a1 --- /dev/null +++ b/monai/requirements.txt @@ -0,0 +1,13 @@ +cupy-cuda117==10.6.0 +dynamic_network_architectures==0.2 +joblib==1.3.0 +loguru==0.7.0 +matplotlib==3.7.2 +monai[all]==1.2.0 +numpy==1.24.4 +pytorch_lightning==2.0.4 +scikit_learn==1.3.0 +scipy==1.11.2 +torch==2.0.0 +tqdm==4.65.0 +wandb==0.15.5 diff --git a/monai/requirements_inference.txt b/monai/requirements_inference.txt new file mode 100644 index 00000000..d99de5ae --- /dev/null +++ b/monai/requirements_inference.txt @@ -0,0 +1,8 @@ +dynamic_network_architectures==0.2 +joblib==1.3.0 +loguru==0.7.0 +monai[nibabel]==1.2.0 +scipy==1.11.2 +numpy==1.24.4 +--extra-index-url https://download.pytorch.org/whl/cpu +torch==2.0.0 diff --git a/monai/run_inference_single_image.py b/monai/run_inference_single_image.py new file mode 100644 index 00000000..7d08f377 --- /dev/null +++ b/monai/run_inference_single_image.py @@ -0,0 +1,360 @@ +""" +Script to run inference on a MONAI-based model for contrast-agnostic soft segmentation of the spinal cord. 
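+
+Example usage (placeholder paths; the flags are those defined in get_parser() below):
+    python run_inference_single_image.py --path-img /path/to/image.nii.gz --chkp-path /path/to/checkpoint_folder --path-out /path/to/output_folder --device gpu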
+ +Author: Naga Karthik + +""" + +import os +import argparse +import numpy as np +from loguru import logger +import torch.nn.functional as F +import torch +import torch.nn as nn +import json +from time import time + +from monai.inferers import sliding_window_inference +from monai.data import (DataLoader, CacheDataset, load_decathlon_datalist, decollate_batch) +from monai.transforms import (Compose, EnsureTyped, Invertd, SaveImage, Spacingd, + LoadImaged, NormalizeIntensityd, EnsureChannelFirstd, + DivisiblePadd, Orientationd, ResizeWithPadOrCropd) +from dynamic_network_architectures.architectures.unet import PlainConvUNet, ResidualEncoderUNet +from dynamic_network_architectures.building_blocks.helper import get_matching_instancenorm, convert_dim_to_conv_op +from dynamic_network_architectures.initialization.weight_init import init_last_bn_before_add_to_0 + + +# NNUNET global params +INIT_FILTERS=32 +ENABLE_DS = True + +nnunet_plans = { + "UNet_class_name": "PlainConvUNet", + "UNet_base_num_features": INIT_FILTERS, + "n_conv_per_stage_encoder": [2, 2, 2, 2, 2, 2], + "n_conv_per_stage_decoder": [2, 2, 2, 2, 2], + "pool_op_kernel_sizes": [ + [1, 1, 1], + [2, 2, 2], + [2, 2, 2], + [2, 2, 2], + [2, 2, 2], + [1, 2, 2] + ], + "conv_kernel_sizes": [ + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3] + ], + "unet_max_num_features": 320, +} + + +def get_parser(): + + parser = argparse.ArgumentParser(description="Run inference on a MONAI-trained model") + + parser.add_argument("--path-img", type=str, required=True, + help="Path to the image to run inference on") + parser.add_argument("--chkp-path", type=str, required=True, help="Path to the checkpoint folder") + parser.add_argument("--path-out", type=str, required=True, + help="Path to the output folder where to store the predictions and associated metrics") + parser.add_argument('-crop', '--crop-size', type=str, default="64x192x-1", + help='Size of the window used to crop the volume before inference (NOTE: Images are resampled to 1mm' + ' isotropic before cropping). The window is centered in the middle of the volume. Dimensions are in the' + ' order R-L, A-P, I-S. Use -1 for no cropping in a specific axis, example: “64x160x-1”.' + ' NOTE: heavy R-L cropping is recommended for positioning the SC at the center of the image.' + ' Default: 64x192x-1') + parser.add_argument('--device', default="gpu", type=str, choices=["gpu", "cpu"], + help='Device to run inference on. Default: cpu') + + return parser + + +# =========================================================================== +# Test-time Transforms +# =========================================================================== +def inference_transforms_single_image(crop_size): + return Compose([ + LoadImaged(keys=["image"], image_only=False), + EnsureChannelFirstd(keys=["image"]), + Orientationd(keys=["image"], axcodes="RPI"), + Spacingd(keys=["image"], pixdim=(1.0, 1.0, 1.0), mode=(2)), + ResizeWithPadOrCropd(keys=["image"], spatial_size=crop_size,), + DivisiblePadd(keys=["image"], k=2**5), # pad inputs to ensure divisibility by no. 
of layers nnUNet has (5) + NormalizeIntensityd(keys=["image"], nonzero=False, channel_wise=False), + ]) + + +# =========================================================================== +# Model utils +# =========================================================================== +class InitWeights_He(object): + def __init__(self, neg_slope=1e-2): + self.neg_slope = neg_slope + + def __call__(self, module): + if isinstance(module, nn.Conv3d) or isinstance(module, nn.ConvTranspose3d): + module.weight = nn.init.kaiming_normal_(module.weight, a=self.neg_slope) + if module.bias is not None: + module.bias = nn.init.constant_(module.bias, 0) + + +# ============================================================================ +# Define the network based on nnunet_plans dict +# ============================================================================ +def create_nnunet_from_plans(plans, num_input_channels: int, num_classes: int, deep_supervision: bool = True): + """ + Adapted from nnUNet's source code: + https://github.com/MIC-DKFZ/nnUNet/blob/master/nnunetv2/utilities/get_network_from_plans.py#L9 + + """ + num_stages = len(plans["conv_kernel_sizes"]) + + dim = len(plans["conv_kernel_sizes"][0]) + conv_op = convert_dim_to_conv_op(dim) + + segmentation_network_class_name = plans["UNet_class_name"] + mapping = { + 'PlainConvUNet': PlainConvUNet, + 'ResidualEncoderUNet': ResidualEncoderUNet + } + kwargs = { + 'PlainConvUNet': { + 'conv_bias': True, + 'norm_op': get_matching_instancenorm(conv_op), + 'norm_op_kwargs': {'eps': 1e-5, 'affine': True}, + 'dropout_op': None, 'dropout_op_kwargs': None, + 'nonlin': nn.LeakyReLU, 'nonlin_kwargs': {'inplace': True}, + }, + 'ResidualEncoderUNet': { + 'conv_bias': True, + 'norm_op': get_matching_instancenorm(conv_op), + 'norm_op_kwargs': {'eps': 1e-5, 'affine': True}, + 'dropout_op': None, 'dropout_op_kwargs': None, + 'nonlin': nn.LeakyReLU, 'nonlin_kwargs': {'inplace': True}, + } + } + assert segmentation_network_class_name in mapping.keys(), 'The network architecture specified by the plans file ' \ + 'is non-standard (maybe your own?). You\'ll have to dive ' \ + 'into either this ' \ + 'function (get_network_from_plans) or ' \ + 'the init of your nnUNetModule to accommodate that.' + network_class = mapping[segmentation_network_class_name] + + conv_or_blocks_per_stage = { + 'n_conv_per_stage' + if network_class != ResidualEncoderUNet else 'n_blocks_per_stage': plans["n_conv_per_stage_encoder"], + 'n_conv_per_stage_decoder': plans["n_conv_per_stage_decoder"] + } + + # network class name!!
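+ # With the nnunet_plans defined in this script (base features = 32, max features = 320, 6 stages), + # features_per_stage below evaluates to min(32 * 2**i, 320) = [32, 64, 128, 256, 320, 320]; + # kernel sizes and strides come directly from the plans dict. When deep_supervision is True, + # the forward pass returns one prediction per decoder resolution (highest resolution first), + # which is why the inference code in main() keeps only pred[0].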
+ model = network_class( + input_channels=num_input_channels, + n_stages=num_stages, + features_per_stage=[min(plans["UNet_base_num_features"] * 2 ** i, + plans["unet_max_num_features"]) for i in range(num_stages)], + conv_op=conv_op, + kernel_sizes=plans["conv_kernel_sizes"], + strides=plans["pool_op_kernel_sizes"], + num_classes=num_classes, + deep_supervision=deep_supervision, + **conv_or_blocks_per_stage, + **kwargs[segmentation_network_class_name] + ) + model.apply(InitWeights_He(1e-2)) + if network_class == ResidualEncoderUNet: + model.apply(init_last_bn_before_add_to_0) + + return model + + +# =========================================================================== +# Prepare temporary dataset for inference +# =========================================================================== +def prepare_data(path_image, path_out, crop_size=(64, 160, 320)): + + # create a temporary datalist containing the image + # boiler plate keys to be defined in the MSD-style datalist + params = {} + params["description"] = "my-awesome-SC-image" + params["labels"] = { + "0": "background", + "1": "soft-sc-seg" + } + params["modality"] = { + "0": "MRI" + } + params["tensorImageSize"] = "3D" + params["test"] = [ + { + "image": path_image + } + ] + + final_json = json.dumps(params, indent=4, sort_keys=True) + jsonFile = open(path_out + "/" + f"temp_msd_datalist.json", "w") + jsonFile.write(final_json) + jsonFile.close() + + dataset = os.path.join(path_out, f"temp_msd_datalist.json") + test_files = load_decathlon_datalist(dataset, True, "test") + + # define test transforms + transforms_test = inference_transforms_single_image(crop_size=crop_size) + + # define post-processing transforms for testing; taken (with explanations) from + # https://github.com/Project-MONAI/tutorials/blob/main/3d_segmentation/torch/unet_inference_dict.py#L66 + test_post_pred = Compose([ + EnsureTyped(keys=["pred"]), + Invertd(keys=["pred"], transform=transforms_test, + orig_keys=["image"], + meta_keys=["pred_meta_dict"], + nearest_interp=False, to_tensor=True), + ]) + test_ds = CacheDataset(data=test_files, transform=transforms_test, cache_rate=0.75, num_workers=8) + + return test_ds, test_post_pred + + +# =========================================================================== +# Inference method +# =========================================================================== +def main(args): + + # define device + if args.device == "gpu" and not torch.cuda.is_available(): + logger.warning("GPU not available, using CPU instead") + DEVICE = torch.device("cpu") + else: + DEVICE = torch.device("cuda" if torch.cuda.is_available() and args.device == "gpu" else "cpu") + + # define root path for finding datalists + path_image = args.path_img + results_path = args.path_out + chkp_path = os.path.join(args.chkp_path, "best_model_loss.ckpt") + + # save terminal outputs to a file + logger.add(os.path.join(results_path, "logs.txt"), rotation="10 MB", level="INFO") + + logger.info(f"Saving results to: {results_path}") + if not os.path.exists(results_path): + os.makedirs(results_path, exist_ok=True) + + # define inference patch size and center crop size + crop_size = tuple([int(i) for i in args.crop_size.split("x")]) + inference_roi_size = (64, 192, 320) + + # define the dataset and dataloader + test_ds, test_post_pred = prepare_data(path_image, results_path, crop_size=crop_size) + test_loader = DataLoader(test_ds, batch_size=1, shuffle=False, num_workers=8, pin_memory=True) + + # define model + net = 
create_nnunet_from_plans(plans=nnunet_plans, num_input_channels=1, num_classes=1, deep_supervision=ENABLE_DS) + + # define list to collect the test metrics + test_step_outputs = [] + test_summary = {} + + # iterate over the dataset and compute metrics + with torch.no_grad(): + for batch in test_loader: + + # compute time for inference per subject + start_time = time() + + # get the test input + test_input = batch["image"].to(DEVICE) + + # this loop only takes about 0.2s on average on a CPU + checkpoint = torch.load(chkp_path, map_location=torch.device(DEVICE))["state_dict"] + # NOTE: remove the 'net.' prefix from the keys because of how the model was initialized in lightning + # https://discuss.pytorch.org/t/missing-keys-unexpected-keys-in-state-dict-when-loading-self-trained-model/22379/14 + for key in list(checkpoint.keys()): + if 'net.' in key: + checkpoint[key.replace('net.', '')] = checkpoint[key] + del checkpoint[key] + + # load the trained model weights + net.load_state_dict(checkpoint) + net.to(DEVICE) + net.eval() + + # run inference + batch["pred"] = sliding_window_inference(test_input, inference_roi_size, mode="gaussian", + sw_batch_size=4, predictor=net, overlap=0.5, progress=False) + + # take only the highest resolution prediction + batch["pred"] = batch["pred"][0] + + # NOTE: monai's models do not normalize the output, so we need to do it manually + if bool(F.relu(batch["pred"]).max()): + batch["pred"] = F.relu(batch["pred"]) / F.relu(batch["pred"]).max() + else: + batch["pred"] = F.relu(batch["pred"]) + + post_test_out = [test_post_pred(i) for i in decollate_batch(batch)] + + pred = post_test_out[0]['pred'].cpu() + + # clip the prediction between 0.5 and 1 + # turns out this sets the background to 0.5 and the SC to 1 (which is not correct) + # details: https://github.com/sct-pipeline/contrast-agnostic-softseg-spinalcord/issues/71 + pred = torch.clamp(pred, 0.5, 1) + # set background values to 0 + pred[pred <= 0.5] = 0 + + # get subject name + subject_name = (batch["image_meta_dict"]["filename_or_obj"][0]).split("/")[-1].replace(".nii.gz", "") + logger.info(f"Saving subject: {subject_name}") + + # this takes about 0.25s on average on a CPU + # image saver class + pred_saver = SaveImage( + output_dir=results_path, output_postfix="pred", output_ext=".nii.gz", + separate_folder=False, print_log=False) + # save the prediction + pred_saver(pred) + + end_time = time() + metrics_dict = { + "subject_name_and_contrast": subject_name, + "inference_time_in_sec": round((end_time - start_time), 2), + } + test_step_outputs.append(metrics_dict) + + # save the test summary + test_summary["metrics_per_subject"] = test_step_outputs + + # compute the average inference time + avg_inference_time = np.stack([x["inference_time_in_sec"] for x in test_step_outputs]).mean() + + # store the average metrics in a dict + avg_metrics = { + "avg_inference_time_in_sec": round(avg_inference_time, 2), + } + test_summary["metrics_avg_across_cohort"] = avg_metrics + + logger.info("========================================================") + logger.info(f" Inference Time per Subject: {avg_inference_time:.2f}s") + logger.info("========================================================") + + + # dump the test summary to a json file + with open(os.path.join(results_path, "test_summary.json"), "w") as f: + json.dump(test_summary, f, indent=4, sort_keys=True) + + # free up memory + test_step_outputs.clear() + test_summary.clear() + os.remove(os.path.join(results_path, "temp_msd_datalist.json")) + + +if __name__ == 
"__main__": + + args = get_parser().parse_args() + main(args) \ No newline at end of file diff --git a/monai/transforms.py b/monai/transforms.py new file mode 100644 index 00000000..76ef9ad1 --- /dev/null +++ b/monai/transforms.py @@ -0,0 +1,61 @@ + +import numpy as np +from monai.transforms import (Compose, CropForegroundd, LoadImaged, RandFlipd, + Spacingd, RandScaleIntensityd, NormalizeIntensityd, RandAffined, + DivisiblePadd, RandAdjustContrastd, EnsureChannelFirstd, RandGaussianNoised, + RandGaussianSmoothd, Orientationd, Rand3DElasticd, RandBiasFieldd, + ResizeWithPadOrCropd) + +# TODO: Add RandSimulateLowResolutiond transform when monai 1.3.0 is released. +# Right now, in v1.2.0, it is not implemented yet (I had to manually add in the source code) + +def train_transforms(crop_size, lbl_key="label"): + + monai_transforms = [ + # pre-processing + LoadImaged(keys=["image", lbl_key]), + EnsureChannelFirstd(keys=["image", lbl_key]), + # NOTE: spine interpolation with order=2 is spline, order=1 is linear + Spacingd(keys=["image", "label"], pixdim=(1.0, 1.0, 1.0), mode=(2, 1)), + ResizeWithPadOrCropd(keys=["image", lbl_key], spatial_size=crop_size,), + # data-augmentation + RandAffined(keys=["image", lbl_key], mode=(2, 1), prob=0.9, + rotate_range=(-20. / 360 * 2. * np.pi, 20. / 360 * 2. * np.pi), # monai expects in radians + scale_range=(-0.2, 0.2), + translate_range=(-0.1, 0.1)), + Rand3DElasticd(keys=["image", lbl_key], prob=0.5, + sigma_range=(3.5, 5.5), + magnitude_range=(25., 35.)), + # RandSimulateLowResolutiond(keys=["image"], zoom_range=(0.5, 1.0), prob=0.25), + RandAdjustContrastd(keys=["image"], gamma=(0.5, 3.), prob=0.5), # this is monai's RandomGamma + RandBiasFieldd(keys=["image"], coeff_range=(0.0, 0.5), degree=3, prob=0.3), + RandGaussianNoised(keys=["image"], mean=0.0, std=0.1, prob=0.1), + RandGaussianSmoothd(keys=["image"], sigma_x=(0., 2.), sigma_y=(0., 2.), sigma_z=(0., 2.0), prob=0.3), + RandScaleIntensityd(keys=["image"], factors=(-0.25, 1), prob=0.15), # this is nnUNet's BrightnessMultiplicativeTransform + RandFlipd(keys=["image", lbl_key], prob=0.3,), + NormalizeIntensityd(keys=["image"], nonzero=False, channel_wise=False), + ] + + return Compose(monai_transforms) + +def inference_transforms(crop_size, lbl_key="label"): + return Compose([ + LoadImaged(keys=["image", lbl_key], image_only=False), + EnsureChannelFirstd(keys=["image", lbl_key]), + # CropForegroundd(keys=["image", lbl_key], source_key="image"), + Orientationd(keys=["image", lbl_key], axcodes="RPI"), + Spacingd(keys=["image", lbl_key], pixdim=(1.0, 1.0, 1.0), mode=(2, 1)), # mode=("bilinear", "bilinear"),), + ResizeWithPadOrCropd(keys=["image", lbl_key], spatial_size=crop_size,), + DivisiblePadd(keys=["image", lbl_key], k=2**5), # pad inputs to ensure divisibility by no. 
of layers nnUNet has (5) + NormalizeIntensityd(keys=["image"], nonzero=False, channel_wise=False), + ]) + +def val_transforms(crop_size, lbl_key="label"): + return Compose([ + LoadImaged(keys=["image", lbl_key], image_only=False), + EnsureChannelFirstd(keys=["image", lbl_key]), + # CropForegroundd(keys=["image", lbl_key], source_key="image"), + Spacingd(keys=["image", lbl_key], pixdim=(1.0, 1.0, 1.0), mode=(2, 1)), # mode=("bilinear", "bilinear"),), + ResizeWithPadOrCropd(keys=["image", lbl_key], spatial_size=crop_size,), + NormalizeIntensityd(keys=["image"], nonzero=False, channel_wise=False), + ]) diff --git a/monai/utils.py b/monai/utils.py new file mode 100644 index 00000000..9b670edc --- /dev/null +++ b/monai/utils.py @@ -0,0 +1,271 @@ +import numpy as np +import matplotlib.pyplot as plt +from torch.optim.lr_scheduler import _LRScheduler +import torch + + +# Check if any label image patch is empty in the batch +def check_empty_patch(labels): + for i, label in enumerate(labels): + if torch.sum(label) == 0.0: + # print(f"Empty label patch found at index {i}. Skipping training step ...") + return None + return labels # If no empty patch is found, return the labels + + +class FoldGenerator: + """ + Adapted from https://github.com/MIC-DKFZ/medicaldetectiontoolkit/blob/master/utils/dataloader_utils.py#L59 + Generates splits of indices for a given length of a dataset to perform n-fold cross-validation. + splits each fold into 3 subsets for training, validation and testing. + This form of cross validation uses an inner loop test set, which is useful if test scores shall be reported on a + statistically reliable amount of patients, despite limited size of a dataset. + If hold out test set is provided and hence no inner loop test set needed, just add test_idxs to the training data in the dataloader. + This creates straight-forward train-val splits. + :returns names list: list of len n_splits. each element is a list of len 3 for train_ix, val_ix, test_ix. + """ + def __init__(self, seed, n_splits, len_data): + """ + :param seed: Random seed for splits. + :param n_splits: number of splits, e.g. 5 splits for 5-fold cross-validation + :param len_data: number of elements in the dataset. + """ + self.tr_ix = [] + self.val_ix = [] + self.te_ix = [] + self.slicer = None + self.missing = 0 + self.fold = 0 + self.len_data = len_data + self.n_splits = n_splits + self.myseed = seed + self.boost_val = 0 + + def init_indices(self): + + t = list(np.arange(self.len_cv_names)) + # round up to next splittable data amount. + if self.n_splits == 5: + split_length = int(np.ceil(len(t) / float(self.n_splits)) // 1.5) + else: + split_length = int(np.ceil(len(t) / float(self.n_splits))) + self.slicer = split_length + print(self.slicer) + self.mod = len(t) % self.n_splits + if self.mod > 0: + # missing is the number of folds, in which the new splits are reduced to account for missing data. + self.missing = self.n_splits - self.mod + + # for 100 subjects, performs a 60-20-20 split with n_splits + self.te_ix = t[:self.slicer] + self.tr_ix = t[self.slicer:] + self.val_ix = self.tr_ix[:self.slicer] + self.tr_ix = self.tr_ix[self.slicer:] + + def new_fold(self): + + slicer = self.slicer + if self.fold < self.missing: + slicer = self.slicer - 1 + + temp = self.te_ix + + # catch exception mod == 1: test set collects 1+ data since walk through both roudned up splits. + # account for by reducing last fold split by 1. 
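+ # rotate the splits for the next fold: the current validation indices become the test set, + # the next `slicer`-sized block of training indices becomes the validation set, and the + # previous test indices are appended back to the training set.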
+ if self.fold == self.n_splits-2 and self.mod ==1: + temp += self.val_ix[-1:] + self.val_ix = self.val_ix[:-1] + + self.te_ix = self.val_ix + self.val_ix = self.tr_ix[:slicer] + self.tr_ix = self.tr_ix[slicer:] + temp + + + def get_fold_names(self): + names_list = [] + rgen = np.random.RandomState(self.myseed) + cv_names = np.arange(self.len_data) + + rgen.shuffle(cv_names) + self.len_cv_names = len(cv_names) + self.init_indices() + + for split in range(self.n_splits): + train_names, val_names, test_names = cv_names[self.tr_ix], cv_names[self.val_ix], cv_names[self.te_ix] + names_list.append([train_names, val_names, test_names, self.fold]) + self.new_fold() + self.fold += 1 + + return names_list + + +def numeric_score(prediction, groundtruth): + """Computation of statistical numerical scores: + + * FP = Soft False Positives + * FN = Soft False Negatives + * TP = Soft True Positives + * TN = Soft True Negatives + + Robust to hard or soft input masks. For example:: + prediction=np.asarray([0, 0.5, 1]) + groundtruth=np.asarray([0, 1, 1]) + Leads to FP = 1.5 + + Note: It assumes input values are between 0 and 1. + + Args: + prediction (ndarray): Binary prediction. + groundtruth (ndarray): Binary groundtruth. + + Returns: + float, float, float, float: FP, FN, TP, TN + """ + FP = float(np.sum(prediction * (1.0 - groundtruth))) + FN = float(np.sum((1.0 - prediction) * groundtruth)) + TP = float(np.sum(prediction * groundtruth)) + TN = float(np.sum((1.0 - prediction) * (1.0 - groundtruth))) + return FP, FN, TP, TN + + +def precision_score(prediction, groundtruth, err_value=0.0): + """Positive predictive value (PPV). + + Precision equals the number of true positive voxels divided by the sum of true and false positive voxels. + True and false positives are computed on soft masks, see ``"numeric_score"``. + Taken from: https://github.com/ivadomed/ivadomed/blob/master/ivadomed/metrics.py + + Args: + prediction (ndarray): First array. + groundtruth (ndarray): Second array. + err_value (float): Value returned in case of error. + + Returns: + float: Precision score. + """ + FP, FN, TP, TN = numeric_score(prediction, groundtruth) + if (TP + FP) <= 0.0: + return err_value + + precision = np.divide(TP, TP + FP) + return precision + + +def recall_score(prediction, groundtruth, err_value=0.0): + """True positive rate (TPR). + + Recall equals the number of true positive voxels divided by the sum of true positive and false negative voxels. + True positive and false negative values are computed on soft masks, see ``"numeric_score"``. + Taken from: https://github.com/ivadomed/ivadomed/blob/master/ivadomed/metrics.py + + Args: + prediction (ndarray): First array. + groundtruth (ndarray): Second array. + err_value (float): Value returned in case of error. + + Returns: + float: Recall score. + """ + FP, FN, TP, TN = numeric_score(prediction, groundtruth) + if (TP + FN) <= 0.0: + return err_value + TPR = np.divide(TP, TP + FN) + return TPR + + +def dice_score(prediction, groundtruth): + smooth = 1. 
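+ # soft Dice: (2 * sum(pred * gt) + smooth) / (sum(pred) + sum(gt) + smooth); + # the smoothing term keeps the score defined (and equal to 1) when both masks are empty.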
+ numer = (prediction * groundtruth).sum() + denor = (prediction + groundtruth).sum() + # loss = (2 * numer + self.smooth) / (denor + self.smooth) + dice = (2 * numer + smooth) / (denor + smooth) + return dice + + +def plot_slices(image, gt, pred, debug=False): + """ + Plot the image, ground truth and prediction of the mid-sagittal axial slice + The orientaion is assumed to RPI + """ + + # bring everything to numpy + image = image.numpy() + gt = gt.numpy() + pred = pred.numpy() + + if not debug: + mid_sagittal = image.shape[2]//2 + # plot X slices before and after the mid-sagittal slice in a grid + fig, axs = plt.subplots(3, 6, figsize=(10, 6)) + fig.suptitle('Original Image --> Ground Truth --> Prediction') + for i in range(6): + axs[0, i].imshow(image[:, :, mid_sagittal-3+i].T, cmap='gray'); axs[0, i].axis('off') + axs[1, i].imshow(gt[:, :, mid_sagittal-3+i].T); axs[1, i].axis('off') + axs[2, i].imshow(pred[:, :, mid_sagittal-3+i].T); axs[2, i].axis('off') + + # fig, axs = plt.subplots(1, 3, figsize=(10, 8)) + # fig.suptitle('Original Image --> Ground Truth --> Prediction') + # slice = image.shape[2]//2 + + # axs[0].imshow(image[:, :, slice].T, cmap='gray'); axs[0].axis('off') + # axs[1].imshow(gt[:, :, slice].T); axs[1].axis('off') + # axs[2].imshow(pred[:, :, slice].T); axs[2].axis('off') + + else: # plot multiple slices + mid_sagittal = image.shape[2]//2 + # plot X slices before and after the mid-sagittal slice in a grid + fig, axs = plt.subplots(3, 14, figsize=(20, 8)) + fig.suptitle('Original Image --> Ground Truth --> Prediction') + for i in range(14): + axs[0, i].imshow(image[:, :, mid_sagittal-7+i].T, cmap='gray'); axs[0, i].axis('off') + axs[1, i].imshow(gt[:, :, mid_sagittal-7+i].T); axs[1, i].axis('off') + axs[2, i].imshow(pred[:, :, mid_sagittal-7+i].T); axs[2, i].axis('off') + + plt.tight_layout() + fig.show() + return fig + + +def compute_average_csa(patch, spacing): + num_slices = patch.shape[2] + areas = torch.empty(num_slices) + for slice_idx in range(num_slices): + slice_mask = patch[:, :, slice_idx] + area = torch.count_nonzero(slice_mask) * (spacing[0] * spacing[1]) + areas[slice_idx] = area + + return torch.mean(areas) + + +class PolyLRScheduler(_LRScheduler): + """ + Polynomial learning rate scheduler. Taken from: + https://github.com/MIC-DKFZ/nnUNet/blob/master/nnunetv2/training/lr_scheduler/polylr.py + + """ + + def __init__(self, optimizer, initial_lr: float, max_steps: int, exponent: float = 0.9, current_step: int = None): + self.optimizer = optimizer + self.initial_lr = initial_lr + self.max_steps = max_steps + self.exponent = exponent + self.ctr = 0 + super().__init__(optimizer, current_step if current_step is not None else -1, False) + + def step(self, current_step=None): + if current_step is None or current_step == -1: + current_step = self.ctr + self.ctr += 1 + + new_lr = self.initial_lr * (1 - current_step / self.max_steps) ** self.exponent + for param_group in self.optimizer.param_groups: + param_group['lr'] = new_lr + + +if __name__ == "__main__": + + seed = 54 + num_cv_folds = 10 + names_list = FoldGenerator(seed, num_cv_folds, 100).get_fold_names() + tr_ix, val_tx, te_ix, fold = names_list[0] + print(len(tr_ix), len(val_tx), len(te_ix)) \ No newline at end of file