From 2fa882139980bfe8ee91d5ea08e91d439dc31f5d Mon Sep 17 00:00:00 2001 From: Naga Karthik Date: Sun, 9 Jul 2023 20:51:01 -0400 Subject: [PATCH 001/106] add dataset conversion script for monai training --- monai/create_msd_data.py | 251 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 251 insertions(+) create mode 100644 monai/create_msd_data.py diff --git a/monai/create_msd_data.py b/monai/create_msd_data.py new file mode 100644 index 00000000..7991f0da --- /dev/null +++ b/monai/create_msd_data.py @@ -0,0 +1,251 @@ +import os +import json +from tqdm import tqdm +import numpy as np +import argparse +# import nibabel as nib +from utils import FoldGenerator +from loguru import logger +from sklearn.model_selection import train_test_split + +# TODO: split the data using ivadomed joblib file + +root = "/home/GRAMES.POLYMTL.CA/u114716/datasets/spine-generic_uncropped" + +parser = argparse.ArgumentParser(description='Code for creating k-fold splits of the spine-generic dataset.') + +parser.add_argument('--seed', default=42, type=int, help="Seed for reproducibility") +parser.add_argument('-ncvf', '--num-cv-folds', default=5, type=int, + help="[1-k] To create a k-fold dataset for cross validation, 0 for single file with all subjects") +parser.add_argument('-pd', '--path-data', default=root, type=str, help='Path to the data set directory') +parser.add_argument('-po', '--path-out', type=str, help='Path to the output directory where dataset json is saved') +args = parser.parse_args() + + +root = args.path_data +seed = args.seed +num_cv_folds = args.num_cv_folds # for 100 subjects, performs a 60-20-20 split with num_cv_folds + +# Get all subjects +# the participants.tsv file might not be up-to-date, hence rely on the existing folders +# subjects_df = pd.read_csv(os.path.join(root, 'participants.tsv'), sep='\t') +# subjects = subjects_df['participant_id'].values.tolist() +subjects = [subject for subject in os.listdir(root) if subject.startswith('sub-')] +logger.info(f"Total number of subjects in the root directory: {len(subjects)}") + +if args.num_cv_folds != 0: + # create k-fold CV datasets as usual + + # returns a nested list of length (num_cv_folds), each element (again, a list) consisting of + # train, val, test indices and the fold number + names_list = FoldGenerator(seed, num_cv_folds, len_data=len(subjects)).get_fold_names() + + for fold in range(num_cv_folds): + + train_ix, val_ix, test_ix, fold_num = names_list[fold] + training_subjects = [subjects[tr_ix] for tr_ix in train_ix] + validation_subjects = [subjects[v_ix] for v_ix in val_ix] + test_subjects = [subjects[te_ix] for te_ix in test_ix] + + # keys to be defined in the dataset_0.json + params = {} + params["description"] = "sci-zurich naga" + params["labels"] = { + "0": "background", + "1": "sc-lesion" + } + params["license"] = "nk" + params["modality"] = { + "0": "MRI" + } + params["name"] = "sci-zurich" + params["numTest"] = len(test_subjects) + params["numTraining"] = len(training_subjects) + len(validation_subjects) + params["reference"] = "University of Zurich" + params["tensorImageSize"] = "3D" + + + train_val_subjects_dict = { + "training": training_subjects, + "validation": validation_subjects, + } + test_subjects_dict = {"test": test_subjects} + + # run loop for training and validation subjects + temp_shapes_list = [] + for name, subs_list in train_val_subjects_dict.items(): + + temp_list = [] + for subject_no, subject in enumerate(tqdm(subs_list, desc='Loading Volumes')): + + # Another for loop for going through sessions + 
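# NOTE (editor): for orientation, the datalist that this branch writes at the
# end of the fold loop follows MONAI's decathlon-datalist layout; a minimal
# sketch of one entry (subject/session names hypothetical):
#   {
#     "training": [
#       {"image": "sub-zh01/ses-01/anat/sub-zh01_ses-01_acq-sag_T2w.nii.gz",
#        "label": "derivatives/labels/sub-zh01/ses-01/anat/sub-zh01_ses-01_acq-sag_T2w_lesion-manual.nii.gz"}
#     ],
#     "validation": [...],
#     "test": [...]
#   }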
temp_subject_path = os.path.join(root, subject) + num_sessions_per_subject = sum(os.path.isdir(os.path.join(temp_subject_path, pth)) for pth in os.listdir(temp_subject_path)) + + for ses_idx in range(1, num_sessions_per_subject+1): + temp_data = {} + # Get paths with session numbers + session = 'ses-0' + str(ses_idx) + subject_images_path = os.path.join(root, subject, session, 'anat') + subject_labels_path = os.path.join(root, 'derivatives', 'labels', subject, session, 'anat') + + subject_image_file = os.path.join(subject_images_path, '%s_%s_acq-sag_T2w.nii.gz' % (subject, session)) + subject_label_file = os.path.join(subject_labels_path, '%s_%s_acq-sag_T2w_lesion-manual.nii.gz' % (subject, session)) + + # get shapes of each subject to calculate median later + # temp_shapes_list.append(np.shape(nib.load(subject_image_file).get_fdata())) + + # # load GT mask + # gt_label = nib.load(subject_label_file).get_fdata() + # bbox_coords = get_bounding_boxes(mask=gt_label) + + # store in a temp dictionary + temp_data["image"] = subject_image_file.replace(root+"/", '') # .strip(root) + temp_data["label"] = subject_label_file.replace(root+"/", '') # .strip(root) + # temp_data["box"] = bbox_coords + + temp_list.append(temp_data) + + params[name] = temp_list + + # print(temp_shapes_list) + # calculate the median shapes along each axis + params["train_val_median_shape"] = np.median(temp_shapes_list, axis=0).tolist() + + # run separate loop for testing + for name, subs_list in test_subjects_dict.items(): + temp_list = [] + for subject_no, subject in enumerate(tqdm(subs_list, desc='Loading Volumes')): + + # Another for loop for going through sessions + temp_subject_path = os.path.join(root, subject) + num_sessions_per_subject = sum(os.path.isdir(os.path.join(temp_subject_path, pth)) for pth in os.listdir(temp_subject_path)) + + for ses_idx in range(1, num_sessions_per_subject+1): + temp_data = {} + # Get paths with session numbers + session = 'ses-0' + str(ses_idx) + subject_images_path = os.path.join(root, subject, session, 'anat') + subject_labels_path = os.path.join(root, 'derivatives', 'labels', subject, session, 'anat') + + subject_image_file = os.path.join(subject_images_path, '%s_%s_acq-sag_T2w.nii.gz' % (subject, session)) + subject_label_file = os.path.join(subject_labels_path, '%s_%s_acq-sag_T2w_lesion-manual.nii.gz' % (subject, session)) + + # # load GT mask + # gt_label = nib.load(subject_label_file).get_fdata() + # bbox_coords = get_bounding_boxes(mask=gt_label) + + temp_data["image"] = subject_image_file.replace(root+"/", '') + temp_data["label"] = subject_label_file.replace(root+"/", '') + # temp_data["box"] = bbox_coords + + temp_list.append(temp_data) + + params[name] = temp_list + + final_json = json.dumps(params, indent=4, sort_keys=True) + jsonFile = open(root + "/" + f"dataset_fold-{fold_num}.json", "w") + jsonFile.write(final_json) + jsonFile.close() +else: + # create one json file with 60-20-20 train-val-test split + train_ratio, val_ratio, test_ratio = 0.6, 0.2, 0.2 + train_subjects, test_subjects = train_test_split(subjects, test_size=test_ratio, random_state=args.seed) + # Use the training split to further split into training and validation splits + train_subjects, val_subjects = train_test_split(train_subjects, test_size=val_ratio / (train_ratio + val_ratio), + random_state=args.seed, ) + + logger.info(f"Number of training subjects: {len(train_subjects)}") + logger.info(f"Number of validation subjects: {len(val_subjects)}") + logger.info(f"Number of testing subjects: 
{len(test_subjects)}") + + # keys to be defined in the dataset_0.json + params = {} + params["description"] = "spine-generic-uncropped" + params["labels"] = { + "0": "background", + "1": "soft-sc-seg" + } + params["license"] = "nk" + params["modality"] = { + "0": "MRI" + } + params["name"] = "spine-generic" + params["numTest"] = len(test_subjects) + params["numTraining"] = len(train_subjects) + params["numValidation"] = len(val_subjects) + params["seed"] = args.seed + params["reference"] = "University of Zurich" + params["tensorImageSize"] = "3D" + + train_subjects_dict = {"train": train_subjects} + val_subjects_dict = {"validation": val_subjects} + test_subjects_dict = {"test": test_subjects} + all_subjects_list = [train_subjects_dict, val_subjects_dict, test_subjects_dict] + + # define the contrasts + contrasts_list = ['T1w', 'T2w', 'T2star', 'flip-1_mt-on_MTS', 'flip-2_mt-off_MTS', 'dwi'] + + for subjects_dict in tqdm(all_subjects_list, desc="Iterating through train/val/test splits"): + + for name, subs_list in subjects_dict.items(): + + temp_list = [] + for subject_no, subject in enumerate(subs_list): + + temp_data_t1w = {} + temp_data_t2w = {} + temp_data_t2star = {} + temp_data_mton_mts = {} + temp_data_mtoff_mts = {} + temp_data_dwi = {} + + # t1w + temp_data_t1w["image"] = os.path.join(root, subject, 'anat', f"{subject}_T1w.nii.gz") + temp_data_t1w["label"] = os.path.join(root, "derivatives", "labels_softseg", subject, 'anat', f"{subject}_T1w_softseg.nii.gz") + if os.path.exists(temp_data_t1w["label"]) and os.path.exists(temp_data_t1w["image"]): + temp_list.append(temp_data_t1w) + + # t2w + temp_data_t2w["image"] = os.path.join(root, subject, 'anat', f"{subject}_T2w.nii.gz") + temp_data_t2w["label"] = os.path.join(root, "derivatives", "labels_softseg", subject, 'anat', f"{subject}_T2w_softseg.nii.gz") + if os.path.exists(temp_data_t2w["label"]) and os.path.exists(temp_data_t2w["image"]): + temp_list.append(temp_data_t2w) + + # t2star + temp_data_t2star["image"] = os.path.join(root, subject, 'anat', f"{subject}_T2star.nii.gz") + temp_data_t2star["label"] = os.path.join(root, "derivatives", "labels_softseg", subject, 'anat', f"{subject}_T2star_softseg.nii.gz") + if os.path.exists(temp_data_t2star["label"]) and os.path.exists(temp_data_t2star["image"]): + temp_list.append(temp_data_t2star) + + # mton_mts + temp_data_mton_mts["image"] = os.path.join(root, subject, 'anat', f"{subject}_flip-1_mt-on_MTS.nii.gz") + temp_data_mton_mts["label"] = os.path.join(root, "derivatives", "labels_softseg", subject, 'anat', f"{subject}_flip-1_mt-on_MTS_softseg.nii.gz") + if os.path.exists(temp_data_mton_mts["label"]) and os.path.exists(temp_data_mton_mts["image"]): + temp_list.append(temp_data_mton_mts) + + # t1w_mts + temp_data_mtoff_mts["image"] = os.path.join(root, subject, 'anat', f"{subject}_flip-2_mt-off_MTS.nii.gz") + temp_data_mtoff_mts["label"] = os.path.join(root, "derivatives", "labels_softseg", subject, 'anat', f"{subject}_flip-2_mt-off_MTS_softseg.nii.gz") + if os.path.exists(temp_data_mtoff_mts["label"]) and os.path.exists(temp_data_mtoff_mts["image"]): + temp_list.append(temp_data_mtoff_mts) + + # dwi + temp_data_dwi["image"] = os.path.join(root, subject, 'dwi', f"{subject}_rec-average_dwi.nii.gz") + temp_data_dwi["label"] = os.path.join(root, "derivatives", "labels_softseg", subject, 'dwi', f"{subject}_rec-average_dwi_softseg.nii.gz") + if os.path.exists(temp_data_dwi["label"]) and os.path.exists(temp_data_dwi["image"]): + temp_list.append(temp_data_dwi) + + params[name] = temp_list + 
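# NOTE (editor): the six per-contrast blocks above are identical up to the
# (folder, suffix) pair; a loop would remove the duplication. A minimal sketch,
# not part of the original script, assuming the same existence checks:
#     for subdir, suffix in [("anat", "T1w"), ("anat", "T2w"), ("anat", "T2star"),
#                            ("anat", "flip-1_mt-on_MTS"),
#                            ("anat", "flip-2_mt-off_MTS"),
#                            ("dwi", "rec-average_dwi")]:
#         img = os.path.join(root, subject, subdir, f"{subject}_{suffix}.nii.gz")
#         lbl = os.path.join(root, "derivatives", "labels_softseg", subject, subdir,
#                            f"{subject}_{suffix}_softseg.nii.gz")
#         if os.path.exists(img) and os.path.exists(lbl):
#             temp_list.append({"image": img, "label": lbl})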
logger.info(f"Number of images in {name} set: {len(temp_list)}") + + final_json = json.dumps(params, indent=4, sort_keys=True) + jsonFile = open(args.path_out + "/" + f"dataset.json", "w") + jsonFile.write(final_json) + jsonFile.close() + + + + + + From 93957322821f4eb22b869b2b402f55897bcf2956 Mon Sep 17 00:00:00 2001 From: Naga Karthik Date: Sun, 9 Jul 2023 20:51:38 -0400 Subject: [PATCH 002/106] add file for loss functions --- monai/losses.py | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) create mode 100644 monai/losses.py diff --git a/monai/losses.py b/monai/losses.py new file mode 100644 index 00000000..ff151f48 --- /dev/null +++ b/monai/losses.py @@ -0,0 +1,27 @@ +import torch +import torch.nn as nn + + +class SoftDiceLoss(nn.Module): + ''' + soft-dice loss, useful in binary segmentation + taken from: https://github.com/CoinCheung/pytorch-loss/blob/master/soft_dice_loss.py + ''' + def __init__(self, p=1, smooth=1): + super(SoftDiceLoss, self).__init__() + self.p = p + self.smooth = smooth + + def forward(self, preds, labels): + ''' + inputs: + preds: normalized probabilities (not logits) - tensor of shape (N, H, W, ...) + labels: soft labels [0,1] - tensor of shape(N, H, W, ...) + output: + loss: tensor of shape(1, ) + ''' + # probs = torch.sigmoid(logits) + numer = (preds * labels).sum() + denor = (preds.pow(self.p) + labels.pow(self.p)).sum() + loss = 1. - (2 * numer + self.smooth) / (denor + self.smooth) + return loss \ No newline at end of file From f109634af5867f10d5078bc5ab0090f6aa5e02ae Mon Sep 17 00:00:00 2001 From: Naga Karthik Date: Sun, 9 Jul 2023 20:52:00 -0400 Subject: [PATCH 003/106] add transforms --- monai/transforms.py | 59 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 59 insertions(+) create mode 100644 monai/transforms.py diff --git a/monai/transforms.py b/monai/transforms.py new file mode 100644 index 00000000..b2da4e14 --- /dev/null +++ b/monai/transforms.py @@ -0,0 +1,59 @@ + +from monai.transforms import (SpatialPadd, Compose, CropForegroundd, LoadImaged, RandFlipd, + RandCropByPosNegLabeld, Spacingd, RandRotate90d, ToTensord, NormalizeIntensityd, + EnsureType, RandWeightedCropd, HistogramNormalized, EnsureTyped, Invertd, SaveImaged, + EnsureChannelFirstd, CenterSpatialCropd, RandSpatialCropSamplesd, Orientationd) + +# median image size in voxels - taken from nnUNet +# median_size = (123, 255, 214) +# so pad with this size + +def train_transforms(crop_size, num_samples_pv, lbl_key="label"): + return Compose([ + LoadImaged(keys=["image", lbl_key]), + EnsureChannelFirstd(keys=["image", lbl_key]), + # Orientationd(keys=["image", lbl_key], axcodes="RPI"), + # TODO: if the source_key is set to "label", then the cropping is only around the label mask + CropForegroundd(keys=["image", lbl_key], source_key="image"), # crops >0 values with a bounding box + Spacingd(keys=["image", "label"], pixdim=(1.0, 1.0, 1.0), mode=("bilinear", "nearest"),), + SpatialPadd(keys=["image", lbl_key], spatial_size=(123, 255, 214), method="symmetric"), + # RandSpatialCropSamplesd(keys=["image", lbl_key], roi_size=crop_size, num_samples=num_samples_pv, random_center=True, random_size=False), + RandCropByPosNegLabeld(keys=["image", "label"], label_key="label", + spatial_size=crop_size, pos=1, neg=1, num_samples=num_samples_pv, + # if num_samples=4, then 4 samples/image are randomly generated + image_key="image", image_threshold=0.), + RandFlipd(keys=["image", lbl_key], spatial_axis=[0], prob=0.50,), + RandFlipd(keys=["image", lbl_key], spatial_axis=[1], 
prob=0.50,), + RandFlipd(keys=["image", lbl_key],spatial_axis=[2],prob=0.50,), + RandRotate90d(keys=["image", lbl_key], prob=0.10, max_k=3,), + Orientationd(keys=["image", lbl_key], axcodes="RPI"), # NOTE: if not using it here, then it results in collation error + # HistogramNormalized(keys=["image"], mask=None), + NormalizeIntensityd(keys=["image"], nonzero=False, channel_wise=True), + # ToTensord(keys=["image", lbl_key]), + ]) + +def val_transforms(lbl_key="label"): + return Compose([ + LoadImaged(keys=["image", lbl_key]), + EnsureChannelFirstd(keys=["image", lbl_key]), + Orientationd(keys=["image", lbl_key], axcodes="RPI"), + CropForegroundd(keys=["image", lbl_key], source_key="image"), + Spacingd(keys=["image", "label"], pixdim=(1.0, 1.0, 1.0), mode=("bilinear", "nearest"),), + # SpatialPadd(keys=["image", lbl_key], spatial_size=(123, 255, 214), method="symmetric"), + # HistogramNormalized(keys=["image"], mask=None), + NormalizeIntensityd(keys=["image"], nonzero=False, channel_wise=True), + # ToTensord(keys=["image", lbl_key]), + ]) + +def test_transforms(lbl_key="label"): + return Compose([ + LoadImaged(keys=["image", lbl_key]), + EnsureChannelFirstd(keys=["image", lbl_key]), + Orientationd(keys=["image", lbl_key], axcodes="RPI"), + CropForegroundd(keys=["image", lbl_key], source_key="image"), + Spacingd(keys=["image", "label"], pixdim=(1.0, 1.0, 1.0), mode=("bilinear", "nearest"),), + # AddChanneld(keys=["image", lbl_key]), + # HistogramNormalized(keys=["image"], mask=None), + NormalizeIntensityd(keys=["image"], nonzero=False, channel_wise=True), + # ToTensord(keys=["image", lbl_key]), + ]) \ No newline at end of file From d11a99cb21c2f3c8ece15831218eff2ce437c567 Mon Sep 17 00:00:00 2001 From: Naga Karthik Date: Sun, 9 Jul 2023 20:53:01 -0400 Subject: [PATCH 004/106] add util functions --- monai/utils.py | 207 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 207 insertions(+) create mode 100644 monai/utils.py diff --git a/monai/utils.py b/monai/utils.py new file mode 100644 index 00000000..a2dd10fc --- /dev/null +++ b/monai/utils.py @@ -0,0 +1,207 @@ +import numpy as np +import matplotlib.pyplot as plt + + +class FoldGenerator: + """ + Adapted from https://github.com/MIC-DKFZ/medicaldetectiontoolkit/blob/master/utils/dataloader_utils.py#L59 + Generates splits of indices for a given length of a dataset to perform n-fold cross-validation. + splits each fold into 3 subsets for training, validation and testing. + This form of cross validation uses an inner loop test set, which is useful if test scores shall be reported on a + statistically reliable amount of patients, despite limited size of a dataset. + If hold out test set is provided and hence no inner loop test set needed, just add test_idxs to the training data in the dataloader. + This creates straight-forward train-val splits. + :returns names list: list of len n_splits. each element is a list of len 3 for train_ix, val_ix, test_ix. + """ + def __init__(self, seed, n_splits, len_data): + """ + :param seed: Random seed for splits. + :param n_splits: number of splits, e.g. 5 splits for 5-fold cross-validation + :param len_data: number of elements in the dataset. + """ + self.tr_ix = [] + self.val_ix = [] + self.te_ix = [] + self.slicer = None + self.missing = 0 + self.fold = 0 + self.len_data = len_data + self.n_splits = n_splits + self.myseed = seed + self.boost_val = 0 + + def init_indices(self): + + t = list(np.arange(self.len_cv_names)) + # round up to next splittable data amount. 
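# NOTE (editor): worked example for the branch below, with len(t) = 100 and
# n_splits = 5: ceil(100 / 5) = 20 and int(20 // 1.5) = 13, so each fold gets
# 13 test, 13 val and 74 train indices; the mod/missing bookkeeping further
# down absorbs the remainder when len(t) is not divisible by n_splits.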
+ if self.n_splits == 5: + split_length = int(np.ceil(len(t) / float(self.n_splits)) // 1.5) + else: + split_length = int(np.ceil(len(t) / float(self.n_splits))) + self.slicer = split_length + print(self.slicer) + self.mod = len(t) % self.n_splits + if self.mod > 0: + # missing is the number of folds, in which the new splits are reduced to account for missing data. + self.missing = self.n_splits - self.mod + + # for 100 subjects, performs a 60-20-20 split with n_splits + self.te_ix = t[:self.slicer] + self.tr_ix = t[self.slicer:] + self.val_ix = self.tr_ix[:self.slicer] + self.tr_ix = self.tr_ix[self.slicer:] + + def new_fold(self): + + slicer = self.slicer + if self.fold < self.missing: + slicer = self.slicer - 1 + + temp = self.te_ix + + # catch exception mod == 1: test set collects 1+ data since walk through both roudned up splits. + # account for by reducing last fold split by 1. + if self.fold == self.n_splits-2 and self.mod ==1: + temp += self.val_ix[-1:] + self.val_ix = self.val_ix[:-1] + + self.te_ix = self.val_ix + self.val_ix = self.tr_ix[:slicer] + self.tr_ix = self.tr_ix[slicer:] + temp + + + def get_fold_names(self): + names_list = [] + rgen = np.random.RandomState(self.myseed) + cv_names = np.arange(self.len_data) + + rgen.shuffle(cv_names) + self.len_cv_names = len(cv_names) + self.init_indices() + + for split in range(self.n_splits): + train_names, val_names, test_names = cv_names[self.tr_ix], cv_names[self.val_ix], cv_names[self.te_ix] + names_list.append([train_names, val_names, test_names, self.fold]) + self.new_fold() + self.fold += 1 + + return names_list + + +def numeric_score(prediction, groundtruth): + """Computation of statistical numerical scores: + + * FP = Soft False Positives + * FN = Soft False Negatives + * TP = Soft True Positives + * TN = Soft True Negatives + + Robust to hard or soft input masks. For example:: + prediction=np.asarray([0, 0.5, 1]) + groundtruth=np.asarray([0, 1, 1]) + Leads to FP = 1.5 + + Note: It assumes input values are between 0 and 1. + + Args: + prediction (ndarray): Binary prediction. + groundtruth (ndarray): Binary groundtruth. + + Returns: + float, float, float, float: FP, FN, TP, TN + """ + FP = float(np.sum(prediction * (1.0 - groundtruth))) + FN = float(np.sum((1.0 - prediction) * groundtruth)) + TP = float(np.sum(prediction * groundtruth)) + TN = float(np.sum((1.0 - prediction) * (1.0 - groundtruth))) + return FP, FN, TP, TN + + +def precision_score(prediction, groundtruth, err_value=0.0): + """Positive predictive value (PPV). + + Precision equals the number of true positive voxels divided by the sum of true and false positive voxels. + True and false positives are computed on soft masks, see ``"numeric_score"``. + Taken from: https://github.com/ivadomed/ivadomed/blob/master/ivadomed/metrics.py + + Args: + prediction (ndarray): First array. + groundtruth (ndarray): Second array. + err_value (float): Value returned in case of error. + + Returns: + float: Precision score. + """ + FP, FN, TP, TN = numeric_score(prediction, groundtruth) + if (TP + FP) <= 0.0: + return err_value + + precision = np.divide(TP, TP + FP) + return precision + + +def recall_score(prediction, groundtruth, err_value=0.0): + """True positive rate (TPR). + + Recall equals the number of true positive voxels divided by the sum of true positive and false negative voxels. + True positive and false negative values are computed on soft masks, see ``"numeric_score"``. 
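    Example (soft masks): prediction = [0, 0.5, 1], groundtruth = [0, 1, 1]
    gives TP = 1.5 and FN = 0.5, so recall = 1.5 / (1.5 + 0.5) = 0.75.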
+ Taken from: https://github.com/ivadomed/ivadomed/blob/master/ivadomed/metrics.py + + Args: + prediction (ndarray): First array. + groundtruth (ndarray): Second array. + err_value (float): Value returned in case of error. + + Returns: + float: Recall score. + """ + FP, FN, TP, TN = numeric_score(prediction, groundtruth) + if (TP + FN) <= 0.0: + return err_value + TPR = np.divide(TP, TP + FN) + return TPR + + +def dice_score(prediction, groundtruth): + smooth = 1. + numer = (prediction * groundtruth).sum() + denor = (prediction + groundtruth).sum() + # loss = (2 * numer + self.smooth) / (denor + self.smooth) + dice = (2 * numer + smooth) / (denor + smooth) + return dice + + +def plot_slices(image, gt, pred): + """ + Plot the image, ground truth and prediction of the mid-sagittal axial slice + The orientaion is assumed to RPI + """ + + # bring everything to numpy + image = image.detach().cpu().numpy() + gt = gt.detach().cpu().numpy() + pred = pred.detach().cpu().numpy() + + fig, axs = plt.subplots(1, 3, figsize=(10, 8)) + fig.suptitle('Original Image --> Ground Truth --> Prediction') + slice = image.shape[2]//2 + + axs[0].imshow(image[:, :, slice].T, cmap='gray'); axs[0].axis('off') + axs[1].imshow(gt[:, :, slice].T); axs[1].axis('off') + axs[2].imshow(pred[:, :, slice].T); axs[2].axis('off') + + plt.tight_layout() + fig.show() + return fig + + + + + +if __name__ == "__main__": + + seed = 54 + num_cv_folds = 10 + names_list = FoldGenerator(seed, num_cv_folds, 100).get_fold_names() + tr_ix, val_tx, te_ix, fold = names_list[0] + print(len(tr_ix), len(val_tx), len(te_ix)) \ No newline at end of file From 659012555949af3d99158465be571846ce685772 Mon Sep 17 00:00:00 2001 From: Naga Karthik Date: Sun, 9 Jul 2023 20:53:34 -0400 Subject: [PATCH 005/106] add working monai training script --- monai/main.py | 562 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 562 insertions(+) create mode 100644 monai/main.py diff --git a/monai/main.py b/monai/main.py new file mode 100644 index 00000000..12271966 --- /dev/null +++ b/monai/main.py @@ -0,0 +1,562 @@ +import os +import argparse +from datetime import datetime +from loguru import logger + +import numpy as np +import wandb +import torch +import pytorch_lightning as pl +import torch.nn.functional as F + +from utils import precision_score, recall_score, dice_score, plot_slices +from losses import SoftDiceLoss +from transforms import train_transforms, val_transforms, test_transforms + +from monai.utils import set_determinism +from monai.inferers import sliding_window_inference +from monai.networks.nets import UNet, DynUNet, BasicUNet, UNETR +from monai.data import (DataLoader, Dataset, CacheDataset, load_decathlon_datalist, decollate_batch) +from monai.transforms import (Compose, EnsureType, EnsureTyped, Invertd, SaveImaged, SaveImage) + + +# create a "model"-agnostic class with PL to use different models +class Model(pl.LightningModule): + def __init__(self, args, data_root, fold_num, net, loss_function, optimizer_class, + exp_id=None, results_path=None): + super().__init__() + self.args = args + self.save_hyperparameters(ignore=['net']) + + # if self.args.unet_depth == 3: + # from models import ModifiedUNet3DEncoder, ModifiedUNet3DDecoder # this is 3-level UNet + # logger.info("Using UNet with Depth = 3! ") + # else: + # from models_original import ModifiedUNet3DEncoder, ModifiedUNet3DDecoder + # logger.info("Using UNet with Depth = 4! 
") + + self.root = data_root + self.fold_num = fold_num + self.net = net + # self.load_pretrained = load_pretrained + self.lr = args.learning_rate + self.loss_function = loss_function + self.optimizer_class = optimizer_class + self.save_exp_id = exp_id + self.results_path = results_path + + self.best_val_dice, self.best_val_epoch = 0, 0 + self.metric_values = [] + self.epoch_losses, self.epoch_soft_dice_train, self.epoch_hard_dice_train = [], [], [] + + # define cropping and padding dimensions + # NOTE: taken from nnUNet_plans.json + self.voxel_cropping_size = (80, 192, 160) + self.inference_roi_size = (80, 192, 160) + + # define post-processing transforms for validation, nothing fancy just making sure that it's a tensor (default) + self.val_post_pred = Compose([EnsureType()]) + self.val_post_label = Compose([EnsureType()]) + + # define evaluation metric + self.soft_dice_metric = dice_score + + # temp lists for storing outputs from training, validation, and testing + self.training_step_outputs = {} + self.val_step_outputs = {} + self.test_step_outputs = {} + + + # -------------------------------- + # FORWARD PASS + # -------------------------------- + def forward(self, x): + # x, context_features = self.encoder(x) + # preds = self.decoder(x, context_features) + + out = self.net(x) + # NOTE: MONAI's models only output the logits, not the output after the final activation function + # https://docs.monai.io/en/0.9.0/_modules/monai/networks/nets/unetr.html#UNETR.forward refers to the + # UnetOutBlock (https://docs.monai.io/en/0.9.0/_modules/monai/networks/blocks/dynunet_block.html#UnetOutBlock) + # as the final block applied to the input, which is just a convolutional layer with no activation function + # Hence, we are used Normalized ReLU to normalize the logits to the final output + normalized_out = F.relu(out) / F.relu(out).max() if bool(F.relu(out).max()) else F.relu(out) + + return normalized_out + + + # -------------------------------- + # DATA PREPARATION + # -------------------------------- + def prepare_data(self): + # set deterministic training for reproducibility + set_determinism(seed=self.args.seed) + + # define training and validation transforms + transforms_train = train_transforms( + crop_size=self.voxel_cropping_size, + num_samples_pv=self.args.num_samples_per_volume, + lbl_key='label' + ) + transforms_val = val_transforms(lbl_key='label') + + # load the dataset + dataset = os.path.join(self.root, f"dataset.json") + train_files = load_decathlon_datalist(dataset, True, "train") + val_files = load_decathlon_datalist(dataset, True, "validation") + test_files = load_decathlon_datalist(dataset, True, "test") + + if args.debug: + train_files = train_files[:2] + val_files = val_files[:2] + test_files = test_files[:6] + + self.train_ds = CacheDataset(data=train_files, transform=transforms_train, cache_rate=0.1, num_workers=4) + self.val_ds = CacheDataset(data=val_files, transform=transforms_val, cache_rate=0.1, num_workers=4) + + # define test transforms + transforms_test = test_transforms(lbl_key='label') + + # define post-processing transforms for testing; taken (with explanations) from + # https://github.com/Project-MONAI/tutorials/blob/main/3d_segmentation/torch/unet_inference_dict.py#L66 + self.test_post_pred = Compose([ + EnsureTyped(keys=["pred", "label"]), + Invertd(keys=["pred", "label"], transform=transforms_test, + orig_keys=["image", "label"], + meta_keys=["pred_meta_dict", "label_meta_dict"], + nearest_interp=False, to_tensor=True), + ]) + + self.test_ds = 
CacheDataset(data=test_files, transform=transforms_test, cache_rate=0.1, num_workers=4) + + + # -------------------------------- + # DATA LOADERS + # -------------------------------- + def train_dataloader(self): + # NOTE: if num_samples=4 in RandCropByPosNegLabeld and batch_size=2, then 2 x 4 images are generated for network training + return DataLoader(self.train_ds, batch_size=self.args.batch_size, shuffle=True, num_workers=4, + pin_memory=True,) # collate_fn=pad_list_data_collate) + # list_data_collate is only useful when each input in the batch has different shape + + def val_dataloader(self): + return DataLoader(self.val_ds, batch_size=1, shuffle=False, num_workers=4, pin_memory=True) + + def test_dataloader(self): + return DataLoader(self.test_ds, batch_size=1, shuffle=False, num_workers=4, pin_memory=True) + + + # -------------------------------- + # OPTIMIZATION + # -------------------------------- + def configure_optimizers(self): + optimizer = self.optimizer_class(self.parameters(), lr=self.lr, weight_decay=1e-5) + # TODO: look at poly learning rate scheduler + scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.5) + return [optimizer], [scheduler] + + + # -------------------------------- + # TRAINING + # -------------------------------- + def training_step(self, batch, batch_idx): + + inputs, labels = batch["image"], batch["label"] + output = self.forward(inputs) + + # calculate training loss + loss = self.loss_function(output, labels) + + # calculate train dice + # NOTE: this is done on patches (and not entire 3D volume) because SlidingWindowInference is not used here + # So, take this dice score with a lot of salt + train_soft_dice = self.soft_dice_metric(output, labels) + train_hard_dice = self.soft_dice_metric((output.detach() > 0.5).float(), (labels.detach() > 0.5).float()) + + self.training_step_outputs["loss"] = loss + self.training_step_outputs["train_soft_dice"] = train_soft_dice + self.training_step_outputs["train_hard_dice"] = train_hard_dice + self.training_step_outputs["train_number"] = len(inputs) + + # get input image and label for visualization + if batch_idx == 0: + self.training_step_outputs["train_image"] = inputs[0].squeeze() + self.training_step_outputs["train_gt"] = labels[0].squeeze() + self.training_step_outputs["train_pred"] = output[0].squeeze() + + return self.training_step_outputs + + def on_train_epoch_end(self): + avg_loss = torch.stack([self.training_step_outputs["loss"]]).mean() + avg_soft_dice_train = torch.stack([self.training_step_outputs["train_soft_dice"]]).mean() + avg_hard_dice_train = torch.stack([self.training_step_outputs["train_hard_dice"]]).mean() + + self.log('train_soft_dice', avg_soft_dice_train, on_step=False, on_epoch=True) + + # plot the training images + fig = plot_slices(image=self.training_step_outputs["train_image"], + gt=self.training_step_outputs["train_gt"], + pred=self.training_step_outputs["train_pred"],) + wandb.log({"training images": wandb.Image(fig)}) + + self.training_step_outputs.clear() # free up memory + + + # -------------------------------- + # VALIDATION + # -------------------------------- + def validation_step(self, batch, batch_idx): + + inputs, labels = batch["image"], batch["label"] + + outputs = sliding_window_inference(inputs, self.inference_roi_size, sw_batch_size=4, predictor=self.forward, overlap=0.5,) + # outputs shape: (B, C, ) + + # calculate validation loss + loss = self.loss_function(outputs, labels) + + # post-process for calculating the evaluation metric + 
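# NOTE (editor): decollate_batch splits the batched (B, C, ...) output back
# into a list of B per-sample tensors so that the Compose post-transforms run
# per sample, e.g. decollate_batch(torch.ones(2, 1, 4, 4)) -> list of two
# (1, 4, 4) tensors; with the val dataloader's batch_size=1, the [0] indexing
# below is safe.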
post_outputs = [self.val_post_pred(i) for i in decollate_batch(outputs)] + post_labels = [self.val_post_label(i) for i in decollate_batch(labels)] + val_soft_dice = self.soft_dice_metric(post_outputs[0], post_labels[0]) + val_hard_dice = self.soft_dice_metric((post_outputs[0].detach() > 0.5).float(), (post_labels[0].detach() > 0.5).float()) + + self.val_step_outputs["val_loss"] = loss + self.val_step_outputs["val_soft_dice"] = val_soft_dice + self.val_step_outputs["val_hard_dice"] = val_hard_dice + self.val_step_outputs["val_number"] = len(post_outputs) + + # get input image and label for visualization + if batch_idx == 0: + self.val_step_outputs["val_image"] = inputs[0].squeeze() + self.val_step_outputs["val_gt"] = labels[0].squeeze() + self.val_step_outputs["val_pred"] = outputs[0].squeeze() + + return self.val_step_outputs + + def on_validation_epoch_end(self): + + avg_loss = torch.stack([self.val_step_outputs["val_loss"]]).mean() + avg_soft_dice_val = torch.stack([self.val_step_outputs["val_soft_dice"]]).mean() + avg_hard_dice_val = torch.stack([self.val_step_outputs["val_hard_dice"]]).mean() + + wandb_logs = { + "val_soft_dice": avg_soft_dice_val, + "val_hard_dice": avg_hard_dice_val, + "val_loss": avg_loss, + } + if avg_soft_dice_val > self.best_val_dice: + self.best_val_dice = avg_soft_dice_val + self.best_val_epoch = self.current_epoch + + print( + f"Current epoch: {self.current_epoch}" + f"\nCurrent Mean Soft Dice: {avg_soft_dice_val:.4f}" + f"\nCurrent Mean Hard Dice: {avg_hard_dice_val:.4f}" + f"\nBest Mean Dice: {self.best_val_dice:.4f} at Epoch: {self.best_val_epoch}" + f"\n----------------------------------------------------") + + self.metric_values.append(avg_soft_dice_val) + + # log on to wandb + self.log_dict(wandb_logs) + + # plot the validation images + fig = plot_slices(image=self.val_step_outputs["val_image"], + gt=self.val_step_outputs["val_gt"], + pred=self.val_step_outputs["val_pred"],) + wandb.log({"validation images": wandb.Image(fig)}) + + # free up memory + self.val_step_outputs.clear() + + return {"log": wandb_logs} + + # -------------------------------- + # TESTING + # -------------------------------- + # def test_step(self, batch, batch_idx, dataloader_idx): + + def test_step(self, batch, batch_idx): + + test_input, test_label = batch["image"], batch["label"] + # print(batch["label_meta_dict"]["filename_or_obj"][0]) + # print(f"test_input.shape: {test_input.shape} \t test_label.shape: {test_label.shape}") + batch["pred"] = sliding_window_inference(test_input, self.inference_roi_size, sw_batch_size=4, predictor=self.forward, overlap=0.5) + # print(f"batch['pred'].shape: {batch['pred'].shape}") + + # # upon fsleyes visualization, observed that very small values need to be set to zero, but NOT fully binarizing the pred + # batch["pred"][batch["pred"] < 0.099] = 0.0 + + post_test_out = [self.test_post_pred(i) for i in decollate_batch(batch)] + + # make sure that the shapes of prediction and GT label are the same + # print(f"pred shape: {post_test_out[0]['pred'].shape}, label shape: {post_test_out[0]['label'].shape}") + assert post_test_out[0]['pred'].shape == post_test_out[0]['label'].shape + + pred, label = post_test_out[0]['pred'].cpu(), post_test_out[0]['label'].cpu() + + # save the prediction and label + if self.args.save_test_preds: + + subject_name = (batch["image_meta_dict"]["filename_or_obj"][0]).split("/")[-1].replace(".nii.gz", "") + print(f"Saving subject: {subject_name}") + + # image saver class + save_folder = os.path.join(self.results_path, 
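# NOTE (editor): subject_name above is the filename without ".nii.gz" (e.g.
# the hypothetical "sub-xyz_T2w"), so split("_")[0] groups outputs per
# subject: <results_path>/sub-xyz/sub-xyz_T2w_pred.nii.gz, given
# output_postfix="pred" in the SaveImage writer below.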
subject_name.split("_")[0]) + pred_saver = SaveImage( + output_dir=save_folder, output_postfix="pred", output_ext=".nii.gz", + separate_folder=False, print_log=False) + # save the prediction + pred_saver(pred) + + label_saver = SaveImage( + output_dir=save_folder, output_postfix="gt", output_ext=".nii.gz", + separate_folder=False, print_log=False) + # save the label + label_saver(label) + + + # NOTE: Important point from the SoftSeg paper - binarize predictions before computing metrics + # calculate all metrics here + # 1. Dice Score + test_soft_dice = self.soft_dice_metric(pred, label) + + # binarizing the predictions + pred = (post_test_out[0]['pred'].detach().cpu() > 0.5).float() + label = (post_test_out[0]['label'].detach().cpu() > 0.5).float() + + # 1.1 Hard Dice Score + test_hard_dice = self.soft_dice_metric(pred.numpy(), label.numpy()) + # 2. Precision Score + test_precision = precision_score(pred.numpy(), label.numpy()) + # 3. Recall Score + test_recall = recall_score(pred.numpy(), label.numpy()) + + self.test_step_outputs["test_hard_dice"] = test_hard_dice + self.test_step_outputs["test_soft_dice"] = test_soft_dice + self.test_step_outputs["test_precision"] = test_precision + self.test_step_outputs["test_recall"] = test_recall + + + return self.test_step_outputs + + def on_test_epoch_end(self): + + # avg_hard_dice_test = torch.stack([x["test_hard_dice"] for x in outputs]).mean().cpu().numpy() + avg_hard_dice_test, std_hard_dice_test = np.stack([self.test_step_outputs["test_hard_dice"]]).mean(), \ + np.stack([self.test_step_outputs["test_hard_dice"]]).std() + avg_soft_dice_test, std_soft_dice_test = np.stack([self.test_step_outputs["test_soft_dice"]]).mean(), \ + np.stack([self.test_step_outputs["test_soft_dice"]]).std() + avg_precision_test = np.stack([self.test_step_outputs["test_precision"]]).mean() + avg_recall_test = np.stack([self.test_step_outputs["test_recall"]]).mean() + + logger.info(f"Test (Soft) Dice: {avg_soft_dice_test}") + logger.info(f"Test (Hard) Dice: {avg_hard_dice_test}") + logger.info(f"Test Precision Score: {avg_precision_test}") + logger.info(f"Test Recall Score: {avg_recall_test}") + + self.avg_test_dice, self.std_test_dice = avg_soft_dice_test, std_soft_dice_test + self.avg_test_dice_hard, self.std_test_dice_hard = avg_hard_dice_test, std_hard_dice_test + self.avg_test_precision = avg_precision_test + self.avg_test_recall = avg_recall_test + + +# -------------------------------- +# MAIN +# -------------------------------- +def main(args): + # Setting the seed + pl.seed_everything(args.seed, workers=True) + + # define root path for finding datalists + dataset_root = "/home/GRAMES.POLYMTL.CA/u114716/contrast-agnostic/contrast-agnostic-softseg-spinalcord/monai" + + # define optimizer + if args.optimizer in ["adamw", "AdamW", "Adamw"]: + optimizer_class = torch.optim.AdamW + elif args.optimizer in ["SGD", "sgd"]: + optimizer_class = torch.optim.SGD + + # define models + if args.model in ["unet", "UNet"]: + net = UNet(spatial_dims=3, + in_channels=1, out_channels=1, + channels=( + args.init_filters, + args.init_filters * 2, + args.init_filters * 4, + args.init_filters * 8, + args.init_filters * 16 + ), + strides=(2, 2, 2, 2), + num_res_units=2, + ) + save_exp_id =f"{args.model}_lr={args.learning_rate}" + elif args.model in ["unetr", "UNETR"]: + # define image size to be fed to the model + img_size = (96, 96, 96) + + # define model + net = UNETR(spatial_dims=3, + in_channels=1, out_channels=1, + img_size=img_size, + feature_size=args.feature_size, + 
hidden_size=args.hidden_size, + mlp_dim=args.mlp_dim, + num_heads=args.num_heads, + pos_embed="perceptron", + norm_name="instance", + res_block=True, + dropout_rate=0.2, + ) + save_exp_id = f"{args.model}_lr={args.learning_rate}" \ + f"_fs={args.feature_size}_hs={args.hidden_size}_mlpd={args.mlp_dim}_nh={args.num_heads}" + + # define loss function + loss_func = SoftDiceLoss(p=1, smooth=1.0) + + # TODO: move this inside the for loop when using more folds + # to save the best model on validation + save_path = os.path.join(args.save_path, f"{save_exp_id}") + if not os.path.exists(save_path): + os.makedirs(save_path, exist_ok=True) + + # to save the results/model predictions + results_path = os.path.join(args.results_dir, f"{save_exp_id}") + if not os.path.exists(results_path): + os.makedirs(results_path, exist_ok=True) + + # train across all folds of the dataset + for fold in range(args.num_cv_folds): + logger.info(f" Training on fold {fold+1} out of {args.num_cv_folds} folds! ") + + timestamp = datetime.now().strftime(f"%Y%m%d-%H%M") # prints in YYYYMMDD-HHMMSS format + save_exp_id = f"{save_exp_id}_fold={fold}_{timestamp}" + + # i.e. train by loading weights from scratch + pl_model = Model(args, data_root=dataset_root, fold_num=fold, + optimizer_class=optimizer_class, loss_function=loss_func, net=net, + exp_id=save_exp_id, results_path=results_path) + + # don't use wandb logger if in debug mode + # if not args.debug: + exp_logger = pl.loggers.WandbLogger( + name=save_exp_id, + save_dir=args.save_path, + group=f"{args.model}", + log_model=True, # save best model using checkpoint callback + project='contrast-agnostic', + entity='naga-karthik', + config=args) + # else: + # exp_logger = pl.loggers.CSVLogger(save_dir=args.save_path, name="my_exp_name") + + checkpoint_callback = pl.callbacks.ModelCheckpoint( + dirpath=save_path, filename=save_exp_id, monitor='val_loss', + save_top_k=1, mode="min", save_last=False, save_weights_only=True) + + lr_monitor = pl.callbacks.LearningRateMonitor(logging_interval='epoch') + + early_stopping = pl.callbacks.EarlyStopping(monitor="val_loss", min_delta=0.00, patience=args.patience, + verbose=False, mode="min") + + # initialise Lightning's trainer. + trainer = pl.Trainer( + devices=1, accelerator="gpu", # strategy="ddp", + logger=exp_logger, + callbacks=[checkpoint_callback, lr_monitor, early_stopping], + check_val_every_n_epoch=args.check_val_every_n_epochs, + max_epochs=args.max_epochs, + precision=32, + # deterministic=True, + enable_progress_bar=args.enable_progress_bar) + + # Train! + trainer.fit(pl_model) + logger.info(f" Training Done!") + + # TODO: Come back to testing when hyperparamters have been fixed after cross-validation + # Test! 
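# NOTE (editor): trainer.test(pl_model) below evaluates the weights held in
# memory at the end of fit(); to score the checkpoint selected on val_loss
# instead, recent Lightning versions accept
#     trainer.test(pl_model, ckpt_path="best")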
+ trainer.test(pl_model) + logger.info(f"TESTING DONE!") + + # closing the current wandb instance so that a new one is created for the next fold + wandb.finish() + + # TODO: Figure out saving test metrics to a file + with open(os.path.join(results_path, 'test_metrics.txt'), 'a') as f: + print('\n-------------- Test Metrics ----------------', file=f) + print(f"\nSeed Used: {args.seed}", file=f) + print(f"\ninitf={args.init_filters}_lr={args.learning_rate}_bs={args.batch_size}_{timestamp}", file=f) + # print(f"\n{np.array(centers_list)[None, :]}", file=f) + # print(f"\n{np.array(centers_list)[:, None]}", file=f) + + print('\n-------------- Test Hard Dice Scores ----------------', file=f) + print("Hard Dice --> Mean: %0.3f, Std: %0.3f" % (pl_model.avg_test_dice_hard, pl_model.std_test_dice_hard), file=f) + + print('\n-------------- Test Soft Dice Scores ----------------', file=f) + print("Soft Dice --> Mean: %0.3f, Std: %0.3f" % (pl_model.avg_test_dice, pl_model.std_test_dice), file=f) + + print('\n-------------- Test Precision Scores ----------------', file=f) + print("Precision --> Mean: %0.3f" % (pl_model.avg_test_precision), file=f) + + print('\n-------------- Test Recall Scores -------------------', file=f) + print("Recall --> Mean: %0.3f" % (pl_model.avg_test_recall), file=f) + + print('-------------------------------------------------------', file=f) + + +if __name__ == "__main__": + + parser = argparse.ArgumentParser(description='Script for training custom models for SCI Lesion Segmentation.') + # Arguments for model, data, and training and saving + parser.add_argument('-m', '--model', + choices=['unet', 'UNet', 'unetr', 'UNETR', 'segresnet', 'SegResNet'], + default='unet', type=str, help='Model type to be used') + # dataset + parser.add_argument('-nspv', '--num_samples_per_volume', default=4, type=int, help="Number of samples to crop per volume") + parser.add_argument('-ncv', '--num_cv_folds', default=5, type=int, help="Number of cross validation folds") + + # unet model + parser.add_argument('-initf', '--init_filters', default=16, type=int, help="Number of Filters in Init Layer") + # parser.add_argument('-ps', '--patch_size', type=int, default=128, help='List containing subvolume size') + parser.add_argument('-dep', '--unet_depth', default=3, type=int, help="Depth of UNet model") + + # unetr model + parser.add_argument('-fs', '--feature_size', default=16, type=int, help="Feature Size") + parser.add_argument('-hs', '--hidden_size', default=768, type=int, help='Dimensionality of hidden embeddings') + parser.add_argument('-mlpd', '--mlp_dim', default=2048, type=int, help='Dimensionality of MLP layer') + parser.add_argument('-nh', '--num_heads', default=12, type=int, help='Number of heads in Multi-head Attention') + + # optimizations + parser.add_argument('-me', '--max_epochs', default=1000, type=int, help='Number of epochs for the training process') + parser.add_argument('-bs', '--batch_size', default=2, type=int, help='Batch size of the training and validation processes') + parser.add_argument('-opt', '--optimizer', + choices=['adamw', 'AdamW', 'SGD', 'sgd'], + default='adamw', type=str, help='Optimizer to use') + parser.add_argument('-lr', '--learning_rate', default=1e-4, type=float, help='Learning rate for training the model') + parser.add_argument('-pat', '--patience', default=200, type=int, + help='number of validation steps (val_every_n_iters) to wait before early stopping') + parser.add_argument('-epb', '--enable_progress_bar', default=False, action='store_true', + help='by 
default is disabled since it doesnt work in colab') + parser.add_argument('-cve', '--check_val_every_n_epochs', default=1, type=int, help='num of epochs to wait before validation') + # saving + parser.add_argument('-sp', '--save_path', + default=f"/home/GRAMES.POLYMTL.CA/u114716/contrast-agnostic/saved_models", + type=str, help='Path to the saved models directory') + parser.add_argument('-c', '--continue_from_checkpoint', default=False, action='store_true', + help='Load model from checkpoint and continue training') + parser.add_argument('-se', '--seed', default=42, type=int, help='Set seeds for reproducibility') + parser.add_argument('-debug', default=False, action='store_true', help='if true, results are not logged to wandb') + parser.add_argument('-stp', '--save_test_preds', default=False, action='store_true', + help='if true, test predictions are saved in `save_path`') + # testing + parser.add_argument('-rd', '--results_dir', + default=f"/home/GRAMES.POLYMTL.CA/u114716/contrast-agnostic/results", + type=str, help='Path to the model prediction results directory') + + + args = parser.parse_args() + + main(args) \ No newline at end of file From cc244aedec33a1887ec5c1351a752090ad2c3878 Mon Sep 17 00:00:00 2001 From: Naga Karthik Date: Mon, 10 Jul 2023 15:00:17 -0400 Subject: [PATCH 006/106] change nearest interp to linear for labels in spacingD --- monai/transforms.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/monai/transforms.py b/monai/transforms.py index b2da4e14..505f4bb1 100644 --- a/monai/transforms.py +++ b/monai/transforms.py @@ -15,7 +15,7 @@ def train_transforms(crop_size, num_samples_pv, lbl_key="label"): # Orientationd(keys=["image", lbl_key], axcodes="RPI"), # TODO: if the source_key is set to "label", then the cropping is only around the label mask CropForegroundd(keys=["image", lbl_key], source_key="image"), # crops >0 values with a bounding box - Spacingd(keys=["image", "label"], pixdim=(1.0, 1.0, 1.0), mode=("bilinear", "nearest"),), + Spacingd(keys=["image", "label"], pixdim=(1.0, 1.0, 1.0), mode=("bilinear", "bilinear"),), SpatialPadd(keys=["image", lbl_key], spatial_size=(123, 255, 214), method="symmetric"), # RandSpatialCropSamplesd(keys=["image", lbl_key], roi_size=crop_size, num_samples=num_samples_pv, random_center=True, random_size=False), RandCropByPosNegLabeld(keys=["image", "label"], label_key="label", @@ -38,7 +38,7 @@ def val_transforms(lbl_key="label"): EnsureChannelFirstd(keys=["image", lbl_key]), Orientationd(keys=["image", lbl_key], axcodes="RPI"), CropForegroundd(keys=["image", lbl_key], source_key="image"), - Spacingd(keys=["image", "label"], pixdim=(1.0, 1.0, 1.0), mode=("bilinear", "nearest"),), + Spacingd(keys=["image", "label"], pixdim=(1.0, 1.0, 1.0), mode=("bilinear", "bilinear"),), # SpatialPadd(keys=["image", lbl_key], spatial_size=(123, 255, 214), method="symmetric"), # HistogramNormalized(keys=["image"], mask=None), NormalizeIntensityd(keys=["image"], nonzero=False, channel_wise=True), @@ -51,7 +51,7 @@ def test_transforms(lbl_key="label"): EnsureChannelFirstd(keys=["image", lbl_key]), Orientationd(keys=["image", lbl_key], axcodes="RPI"), CropForegroundd(keys=["image", lbl_key], source_key="image"), - Spacingd(keys=["image", "label"], pixdim=(1.0, 1.0, 1.0), mode=("bilinear", "nearest"),), + Spacingd(keys=["image", "label"], pixdim=(1.0, 1.0, 1.0), mode=("bilinear", "bilinear"),), # AddChanneld(keys=["image", lbl_key]), # HistogramNormalized(keys=["image"], mask=None), NormalizeIntensityd(keys=["image"], 
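# NOTE (editor): with nonzero=False the mean/std are computed over all voxels,
# background zeros included; channel_wise=True normalizes each channel
# separately, which changes nothing here since the images are single-channel.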
nonzero=False, channel_wise=True), From 7e8261a00c70a29b5ca9f3801554f087470aac45 Mon Sep 17 00:00:00 2001 From: Naga Karthik Date: Mon, 10 Jul 2023 15:13:28 -0400 Subject: [PATCH 007/106] fix bug in train/val/test metrics accumulation --- monai/main.py | 192 +++++++++++++++++++++++++++++--------------------- 1 file changed, 111 insertions(+), 81 deletions(-) diff --git a/monai/main.py b/monai/main.py index 12271966..c7aa52a1 100644 --- a/monai/main.py +++ b/monai/main.py @@ -8,6 +8,7 @@ import torch import pytorch_lightning as pl import torch.nn.functional as F +import matplotlib.pyplot as plt from utils import precision_score, recall_score, dice_score, plot_slices from losses import SoftDiceLoss @@ -62,9 +63,9 @@ def __init__(self, args, data_root, fold_num, net, loss_function, optimizer_clas self.soft_dice_metric = dice_score # temp lists for storing outputs from training, validation, and testing - self.training_step_outputs = {} - self.val_step_outputs = {} - self.test_step_outputs = {} + self.train_step_outputs = [] + self.val_step_outputs = [] + self.test_step_outputs = [] # -------------------------------- @@ -111,8 +112,8 @@ def prepare_data(self): val_files = val_files[:2] test_files = test_files[:6] - self.train_ds = CacheDataset(data=train_files, transform=transforms_train, cache_rate=0.1, num_workers=4) - self.val_ds = CacheDataset(data=val_files, transform=transforms_val, cache_rate=0.1, num_workers=4) + self.train_ds = CacheDataset(data=train_files, transform=transforms_train, cache_rate=0.25, num_workers=4) + self.val_ds = CacheDataset(data=val_files, transform=transforms_val, cache_rate=0.5, num_workers=4) # define test transforms transforms_test = test_transforms(lbl_key='label') @@ -127,7 +128,7 @@ def prepare_data(self): nearest_interp=False, to_tensor=True), ]) - self.test_ds = CacheDataset(data=test_files, transform=transforms_test, cache_rate=0.1, num_workers=4) + self.test_ds = CacheDataset(data=test_files, transform=transforms_test, cache_rate=0.5, num_workers=4) # -------------------------------- @@ -171,35 +172,47 @@ def training_step(self, batch, batch_idx): # NOTE: this is done on patches (and not entire 3D volume) because SlidingWindowInference is not used here # So, take this dice score with a lot of salt train_soft_dice = self.soft_dice_metric(output, labels) - train_hard_dice = self.soft_dice_metric((output.detach() > 0.5).float(), (labels.detach() > 0.5).float()) - - self.training_step_outputs["loss"] = loss - self.training_step_outputs["train_soft_dice"] = train_soft_dice - self.training_step_outputs["train_hard_dice"] = train_hard_dice - self.training_step_outputs["train_number"] = len(inputs) - - # get input image and label for visualization - if batch_idx == 0: - self.training_step_outputs["train_image"] = inputs[0].squeeze() - self.training_step_outputs["train_gt"] = labels[0].squeeze() - self.training_step_outputs["train_pred"] = output[0].squeeze() + # train_hard_dice = self.soft_dice_metric((output.detach() > 0.5).float(), (labels.detach() > 0.5).float()) + + metrics_dict = { + "loss": loss, + "train_soft_dice": train_soft_dice, + # "train_hard_dice": train_hard_dice, + "train_number": len(inputs), + # "train_image": inputs[0].squeeze(), + # "train_gt": labels[0].squeeze(), + # "train_pred": output[0].squeeze() + } + self.train_step_outputs.append(metrics_dict) - return self.training_step_outputs + return metrics_dict + # TODO: remove on_train_epoch_end to save memory def on_train_epoch_end(self): - avg_loss = 
torch.stack([self.training_step_outputs["loss"]]).mean() - avg_soft_dice_train = torch.stack([self.training_step_outputs["train_soft_dice"]]).mean() - avg_hard_dice_train = torch.stack([self.training_step_outputs["train_hard_dice"]]).mean() + train_loss, num_items, train_soft_dice = 0, 0, 0 + for output in self.train_step_outputs: + train_loss += output["loss"].sum().item() + train_soft_dice += output["train_soft_dice"].sum().item() + num_items += output["train_number"] + + mean_train_loss = torch.tensor(train_loss / num_items) + mean_train_soft_dice = torch.tensor(train_soft_dice / num_items) - self.log('train_soft_dice', avg_soft_dice_train, on_step=False, on_epoch=True) + wandb_logs = { + "train_soft_dice": mean_train_soft_dice, + "train_loss": mean_train_loss, + } + self.log_dict(wandb_logs) - # plot the training images - fig = plot_slices(image=self.training_step_outputs["train_image"], - gt=self.training_step_outputs["train_gt"], - pred=self.training_step_outputs["train_pred"],) - wandb.log({"training images": wandb.Image(fig)}) + # # plot the training images + # fig = plot_slices(image=self.train_step_outputs[0]["train_image"], + # gt=self.train_step_outputs[0]["train_gt"], + # pred=self.train_step_outputs[0]["train_pred"],) + # wandb.log({"training images": wandb.Image(fig)}) - self.training_step_outputs.clear() # free up memory + # free up memory + self.train_step_outputs.clear() + # plt.close(fig) # -------------------------------- @@ -209,7 +222,8 @@ def validation_step(self, batch, batch_idx): inputs, labels = batch["image"], batch["label"] - outputs = sliding_window_inference(inputs, self.inference_roi_size, sw_batch_size=4, predictor=self.forward, overlap=0.5,) + outputs = sliding_window_inference(inputs, self.inference_roi_size, + sw_batch_size=4, predictor=self.forward, overlap=0.5,) # outputs shape: (B, C, ) # calculate validation loss @@ -219,70 +233,79 @@ def validation_step(self, batch, batch_idx): post_outputs = [self.val_post_pred(i) for i in decollate_batch(outputs)] post_labels = [self.val_post_label(i) for i in decollate_batch(labels)] val_soft_dice = self.soft_dice_metric(post_outputs[0], post_labels[0]) - val_hard_dice = self.soft_dice_metric((post_outputs[0].detach() > 0.5).float(), (post_labels[0].detach() > 0.5).float()) - - self.val_step_outputs["val_loss"] = loss - self.val_step_outputs["val_soft_dice"] = val_soft_dice - self.val_step_outputs["val_hard_dice"] = val_hard_dice - self.val_step_outputs["val_number"] = len(post_outputs) - - # get input image and label for visualization - if batch_idx == 0: - self.val_step_outputs["val_image"] = inputs[0].squeeze() - self.val_step_outputs["val_gt"] = labels[0].squeeze() - self.val_step_outputs["val_pred"] = outputs[0].squeeze() - - return self.val_step_outputs + val_hard_dice = self.soft_dice_metric( + (post_outputs[0].detach() > 0.5).float(), (post_labels[0].detach() > 0.5).float() + ) + + metrics_dict = { + "val_loss": loss, + "val_soft_dice": val_soft_dice, + "val_hard_dice": val_hard_dice, + "val_number": len(post_outputs), + "val_image": inputs[0].squeeze(), + "val_gt": labels[0].squeeze(), + "val_pred": post_outputs[0].detach().squeeze(), + } + self.val_step_outputs.append(metrics_dict) + + return metrics_dict def on_validation_epoch_end(self): - avg_loss = torch.stack([self.val_step_outputs["val_loss"]]).mean() - avg_soft_dice_val = torch.stack([self.val_step_outputs["val_soft_dice"]]).mean() - avg_hard_dice_val = torch.stack([self.val_step_outputs["val_hard_dice"]]).mean() + val_loss, num_items, 
val_soft_dice, val_hard_dice = 0, 0, 0, 0 + for output in self.val_step_outputs: + val_loss += output["val_loss"].sum().item() + val_soft_dice += output["val_soft_dice"].sum().item() + val_hard_dice += output["val_hard_dice"].sum().item() + num_items += output["val_number"] + + mean_val_loss = torch.tensor(val_loss / num_items) + mean_val_soft_dice = torch.tensor(val_soft_dice / num_items) + mean_val_hard_dice = torch.tensor(val_hard_dice / num_items) wandb_logs = { - "val_soft_dice": avg_soft_dice_val, - "val_hard_dice": avg_hard_dice_val, - "val_loss": avg_loss, + "val_soft_dice": mean_val_soft_dice, + # "val_hard_dice": avg_hard_dice_val, + "val_loss": mean_val_loss, } - if avg_soft_dice_val > self.best_val_dice: - self.best_val_dice = avg_soft_dice_val + if mean_val_soft_dice > self.best_val_dice: + self.best_val_dice = mean_val_soft_dice self.best_val_epoch = self.current_epoch print( f"Current epoch: {self.current_epoch}" - f"\nCurrent Mean Soft Dice: {avg_soft_dice_val:.4f}" - f"\nCurrent Mean Hard Dice: {avg_hard_dice_val:.4f}" - f"\nBest Mean Dice: {self.best_val_dice:.4f} at Epoch: {self.best_val_epoch}" + f"\nAverage Soft Dice (VAL): {mean_val_soft_dice:.4f}" + f"\nAverage Hard Dice (VAL): {mean_val_hard_dice:.4f}" + f"\nBest Average Soft Dice: {self.best_val_dice:.4f} at Epoch: {self.best_val_epoch}" f"\n----------------------------------------------------") - self.metric_values.append(avg_soft_dice_val) + self.metric_values.append(mean_val_soft_dice) # log on to wandb self.log_dict(wandb_logs) # plot the validation images - fig = plot_slices(image=self.val_step_outputs["val_image"], - gt=self.val_step_outputs["val_gt"], - pred=self.val_step_outputs["val_pred"],) + fig = plot_slices(image=self.val_step_outputs[0]["val_image"], + gt=self.val_step_outputs[0]["val_gt"], + pred=self.val_step_outputs[0]["val_pred"],) wandb.log({"validation images": wandb.Image(fig)}) # free up memory self.val_step_outputs.clear() + plt.close(fig) return {"log": wandb_logs} # -------------------------------- # TESTING # -------------------------------- - # def test_step(self, batch, batch_idx, dataloader_idx): - def test_step(self, batch, batch_idx): test_input, test_label = batch["image"], batch["label"] # print(batch["label_meta_dict"]["filename_or_obj"][0]) # print(f"test_input.shape: {test_input.shape} \t test_label.shape: {test_label.shape}") - batch["pred"] = sliding_window_inference(test_input, self.inference_roi_size, sw_batch_size=4, predictor=self.forward, overlap=0.5) + batch["pred"] = sliding_window_inference(test_input, self.inference_roi_size, + sw_batch_size=4, predictor=self.forward, overlap=0.5) # print(f"batch['pred'].shape: {batch['pred'].shape}") # # upon fsleyes visualization, observed that very small values need to be set to zero, but NOT fully binarizing the pred @@ -333,24 +356,25 @@ def test_step(self, batch, batch_idx): # 3. 
Recall Score test_recall = recall_score(pred.numpy(), label.numpy()) - self.test_step_outputs["test_hard_dice"] = test_hard_dice - self.test_step_outputs["test_soft_dice"] = test_soft_dice - self.test_step_outputs["test_precision"] = test_precision - self.test_step_outputs["test_recall"] = test_recall - + metrics_dict = { + "test_hard_dice": test_hard_dice, + "test_soft_dice": test_soft_dice, + "test_precision": test_precision, + "test_recall": test_recall, + } + self.test_step_outputs.append(metrics_dict) - return self.test_step_outputs + return metrics_dict def on_test_epoch_end(self): - - # avg_hard_dice_test = torch.stack([x["test_hard_dice"] for x in outputs]).mean().cpu().numpy() - avg_hard_dice_test, std_hard_dice_test = np.stack([self.test_step_outputs["test_hard_dice"]]).mean(), \ - np.stack([self.test_step_outputs["test_hard_dice"]]).std() - avg_soft_dice_test, std_soft_dice_test = np.stack([self.test_step_outputs["test_soft_dice"]]).mean(), \ - np.stack([self.test_step_outputs["test_soft_dice"]]).std() - avg_precision_test = np.stack([self.test_step_outputs["test_precision"]]).mean() - avg_recall_test = np.stack([self.test_step_outputs["test_recall"]]).mean() - + + avg_hard_dice_test, std_hard_dice_test = np.stack([x["test_hard_dice"] for x in self.test_step_outputs]).mean(), \ + np.stack([x["test_hard_dice"] for x in self.test_step_outputs]).std() + avg_soft_dice_test, std_soft_dice_test = np.stack([x["test_soft_dice"] for x in self.test_step_outputs]).mean(), \ + np.stack([x["test_soft_dice"] for x in self.test_step_outputs]).std() + avg_precision_test = np.stack([x["test_precision"] for x in self.test_step_outputs]).mean() + avg_recall_test = np.stack([x["test_recall"] for x in self.test_step_outputs]).mean() + logger.info(f"Test (Soft) Dice: {avg_soft_dice_test}") logger.info(f"Test (Hard) Dice: {avg_hard_dice_test}") logger.info(f"Test Precision Score: {avg_precision_test}") @@ -361,6 +385,9 @@ def on_test_epoch_end(self): self.avg_test_precision = avg_precision_test self.avg_test_recall = avg_recall_test + # free up memory + self.test_step_outputs.clear() + # -------------------------------- # MAIN @@ -390,9 +417,9 @@ def main(args): args.init_filters * 16 ), strides=(2, 2, 2, 2), - num_res_units=2, + num_res_units=3, ) - save_exp_id =f"{args.model}_lr={args.learning_rate}" + save_exp_id =f"{args.model}_nf={args.init_filters}_nrs=3_lr={args.learning_rate}" elif args.model in ["unetr", "UNETR"]: # define image size to be fed to the model img_size = (96, 96, 96) @@ -417,6 +444,9 @@ def main(args): loss_func = SoftDiceLoss(p=1, smooth=1.0) # TODO: move this inside the for loop when using more folds + timestamp = datetime.now().strftime(f"%Y%m%d-%H%M") # prints in YYYYMMDD-HHMMSS format + save_exp_id = f"{save_exp_id}_{timestamp}" + # to save the best model on validation save_path = os.path.join(args.save_path, f"{save_exp_id}") if not os.path.exists(save_path): @@ -431,8 +461,8 @@ def main(args): for fold in range(args.num_cv_folds): logger.info(f" Training on fold {fold+1} out of {args.num_cv_folds} folds! ") - timestamp = datetime.now().strftime(f"%Y%m%d-%H%M") # prints in YYYYMMDD-HHMMSS format - save_exp_id = f"{save_exp_id}_fold={fold}_{timestamp}" + # timestamp = datetime.now().strftime(f"%Y%m%d-%H%M") # prints in YYYYMMDD-HHMMSS format + # save_exp_id = f"{save_exp_id}_fold={fold}_{timestamp}" # i.e. 
train by loading weights from scratch pl_model = Model(args, data_root=dataset_root, fold_num=fold, From d361b73382bff95f4d15fb8b7ea7a937d7efee01 Mon Sep 17 00:00:00 2001 From: Naga Karthik Date: Mon, 10 Jul 2023 17:49:06 -0400 Subject: [PATCH 008/106] update plot_slices() --- monai/utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/monai/utils.py b/monai/utils.py index a2dd10fc..66adce38 100644 --- a/monai/utils.py +++ b/monai/utils.py @@ -178,9 +178,9 @@ def plot_slices(image, gt, pred): """ # bring everything to numpy - image = image.detach().cpu().numpy() - gt = gt.detach().cpu().numpy() - pred = pred.detach().cpu().numpy() + image = image.numpy() + gt = gt.numpy() + pred = pred.numpy() fig, axs = plt.subplots(1, 3, figsize=(10, 8)) fig.suptitle('Original Image --> Ground Truth --> Prediction') From 11f3ac56322d094621ff1d9572d62f66cef2051b Mon Sep 17 00:00:00 2001 From: Naga Karthik Date: Mon, 10 Jul 2023 17:55:13 -0400 Subject: [PATCH 009/106] fix cuda memory leak issues; remove unused variables --- monai/main.py | 55 +++++++++++++++++++++++++-------------------------- 1 file changed, 27 insertions(+), 28 deletions(-) diff --git a/monai/main.py b/monai/main.py index c7aa52a1..42c0cc29 100644 --- a/monai/main.py +++ b/monai/main.py @@ -47,8 +47,6 @@ def __init__(self, args, data_root, fold_num, net, loss_function, optimizer_clas self.results_path = results_path self.best_val_dice, self.best_val_epoch = 0, 0 - self.metric_values = [] - self.epoch_losses, self.epoch_soft_dice_train, self.epoch_hard_dice_train = [], [], [] # define cropping and padding dimensions # NOTE: taken from nnUNet_plans.json @@ -113,7 +111,7 @@ def prepare_data(self): test_files = test_files[:6] self.train_ds = CacheDataset(data=train_files, transform=transforms_train, cache_rate=0.25, num_workers=4) - self.val_ds = CacheDataset(data=val_files, transform=transforms_val, cache_rate=0.5, num_workers=4) + self.val_ds = CacheDataset(data=val_files, transform=transforms_val, cache_rate=0.25, num_workers=4) # define test transforms transforms_test = test_transforms(lbl_key='label') @@ -127,8 +125,7 @@ def prepare_data(self): meta_keys=["pred_meta_dict", "label_meta_dict"], nearest_interp=False, to_tensor=True), ]) - - self.test_ds = CacheDataset(data=test_files, transform=transforms_test, cache_rate=0.5, num_workers=4) + self.test_ds = CacheDataset(data=test_files, transform=transforms_test, cache_rate=0.1, num_workers=4) # -------------------------------- @@ -175,8 +172,8 @@ def training_step(self, batch, batch_idx): # train_hard_dice = self.soft_dice_metric((output.detach() > 0.5).float(), (labels.detach() > 0.5).float()) metrics_dict = { - "loss": loss, - "train_soft_dice": train_soft_dice, + "loss": loss.cpu(), + "train_soft_dice": train_soft_dice.detach().cpu(), # "train_hard_dice": train_hard_dice, "train_number": len(inputs), # "train_image": inputs[0].squeeze(), @@ -187,7 +184,6 @@ def training_step(self, batch, batch_idx): return metrics_dict - # TODO: remove on_train_epoch_end to save memory def on_train_epoch_end(self): train_loss, num_items, train_soft_dice = 0, 0, 0 for output in self.train_step_outputs: @@ -195,12 +191,12 @@ def on_train_epoch_end(self): train_soft_dice += output["train_soft_dice"].sum().item() num_items += output["train_number"] - mean_train_loss = torch.tensor(train_loss / num_items) - mean_train_soft_dice = torch.tensor(train_soft_dice / num_items) + mean_train_loss = (train_loss / num_items) + mean_train_soft_dice = (train_soft_dice / num_items) wandb_logs 
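# Illustrative aside (not from the patch): cache_rate in CacheDataset controls the
# fraction of samples whose deterministic transforms are precomputed and held in RAM;
# the reductions above (0.5 -> 0.25 / 0.1) trade epoch speed for host memory.
import numpy as np
from monai.data import CacheDataset
from monai.transforms import Compose

data = [{"image": np.random.rand(1, 8, 8, 8).astype("float32")} for _ in range(10)]  # dummy samples
ds = CacheDataset(data=data, transform=Compose([]), cache_rate=0.25, num_workers=2)
print(len(ds))   # all 10 items are still served; only ~25% are cached up front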
= { - "train_soft_dice": mean_train_soft_dice, - "train_loss": mean_train_loss, + "train_soft_dice": mean_train_soft_dice, + "train_loss": mean_train_loss } self.log_dict(wandb_logs) @@ -212,6 +208,7 @@ def on_train_epoch_end(self): # free up memory self.train_step_outputs.clear() + wandb_logs.clear() # plt.close(fig) @@ -237,14 +234,17 @@ def validation_step(self, batch, batch_idx): (post_outputs[0].detach() > 0.5).float(), (post_labels[0].detach() > 0.5).float() ) + # NOTE: there was a massive memory leak when storing cuda tensors in this dict. Hence, + # using .detach() to avoid storing the whole computation graph + # Ref: https://discuss.pytorch.org/t/cuda-memory-leak-while-training/82855/2 metrics_dict = { - "val_loss": loss, - "val_soft_dice": val_soft_dice, - "val_hard_dice": val_hard_dice, + "val_loss": loss.detach().cpu(), + "val_soft_dice": val_soft_dice.detach().cpu(), + "val_hard_dice": val_hard_dice.detach().cpu(), "val_number": len(post_outputs), - "val_image": inputs[0].squeeze(), - "val_gt": labels[0].squeeze(), - "val_pred": post_outputs[0].detach().squeeze(), + "val_image": inputs[0].detach().cpu().squeeze(), + "val_gt": labels[0].detach().cpu().squeeze(), + "val_pred": post_outputs[0].detach().cpu().squeeze(), } self.val_step_outputs.append(metrics_dict) @@ -259,9 +259,9 @@ def on_validation_epoch_end(self): val_hard_dice += output["val_hard_dice"].sum().item() num_items += output["val_number"] - mean_val_loss = torch.tensor(val_loss / num_items) - mean_val_soft_dice = torch.tensor(val_soft_dice / num_items) - mean_val_hard_dice = torch.tensor(val_hard_dice / num_items) + mean_val_loss = (val_loss / num_items) + mean_val_soft_dice = (val_soft_dice / num_items) + mean_val_hard_dice = (val_hard_dice / num_items) wandb_logs = { "val_soft_dice": mean_val_soft_dice, @@ -278,8 +278,6 @@ def on_validation_epoch_end(self): f"\nAverage Hard Dice (VAL): {mean_val_hard_dice:.4f}" f"\nBest Average Soft Dice: {self.best_val_dice:.4f} at Epoch: {self.best_val_epoch}" f"\n----------------------------------------------------") - - self.metric_values.append(mean_val_soft_dice) # log on to wandb self.log_dict(wandb_logs) @@ -292,9 +290,10 @@ def on_validation_epoch_end(self): # free up memory self.val_step_outputs.clear() + wandb_logs.clear() plt.close(fig) - return {"log": wandb_logs} + # return {"log": wandb_logs} # -------------------------------- # TESTING @@ -417,9 +416,9 @@ def main(args): args.init_filters * 16 ), strides=(2, 2, 2, 2), - num_res_units=3, + num_res_units=2, ) - save_exp_id =f"{args.model}_nf={args.init_filters}_nrs=3_lr={args.learning_rate}" + save_exp_id =f"{args.model}_nf={args.init_filters}_nrs=2_lr={args.learning_rate}" elif args.model in ["unetr", "UNETR"]: # define image size to be fed to the model img_size = (96, 96, 96) @@ -498,7 +497,7 @@ def main(args): callbacks=[checkpoint_callback, lr_monitor, early_stopping], check_val_every_n_epoch=args.check_val_every_n_epochs, max_epochs=args.max_epochs, - precision=32, + precision=32, # TODO: see if 16-bit precision is stable # deterministic=True, enable_progress_bar=args.enable_progress_bar) @@ -566,7 +565,7 @@ def main(args): choices=['adamw', 'AdamW', 'SGD', 'sgd'], default='adamw', type=str, help='Optimizer to use') parser.add_argument('-lr', '--learning_rate', default=1e-4, type=float, help='Learning rate for training the model') - parser.add_argument('-pat', '--patience', default=200, type=int, + parser.add_argument('-pat', '--patience', default=15, type=int, help='number of validation steps 
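# Illustrative aside (not from the patch): the memory leak mentioned in the NOTE above
# comes from caching loss/metric tensors that still reference the autograd graph.
# Detaching (and moving to CPU) before storing releases that graph between steps.
import torch

x = torch.rand(4, 3, requires_grad=True)
loss = (x ** 2).mean()
leaky = [loss]                  # keeps the whole computation graph alive
safe = [loss.detach().cpu()]    # keeps only the value
print(leaky[0].grad_fn is not None, safe[0].grad_fn is None)   # True True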
(val_every_n_iters) to wait before early stopping') parser.add_argument('-epb', '--enable_progress_bar', default=False, action='store_true', help='by default is disabled since it doesnt work in colab') From e88afbcd1037c8bb17525fe0ba55bfa4fe570472 Mon Sep 17 00:00:00 2001 From: Naga Karthik Date: Tue, 11 Jul 2023 18:23:53 -0400 Subject: [PATCH 010/106] add PolyLR scheduler --- monai/utils.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/monai/utils.py b/monai/utils.py index 66adce38..291c01ec 100644 --- a/monai/utils.py +++ b/monai/utils.py @@ -1,5 +1,6 @@ import numpy as np import matplotlib.pyplot as plt +from torch.optim.lr_scheduler import _LRScheduler class FoldGenerator: @@ -195,7 +196,29 @@ def plot_slices(image, gt, pred): return fig +class PolyLRScheduler(_LRScheduler): + """ + Polynomial learning rate scheduler. Taken from: + https://github.com/MIC-DKFZ/nnUNet/blob/master/nnunetv2/training/lr_scheduler/polylr.py + + """ + def __init__(self, optimizer, initial_lr: float, max_steps: int, exponent: float = 0.9, current_step: int = None): + self.optimizer = optimizer + self.initial_lr = initial_lr + self.max_steps = max_steps + self.exponent = exponent + self.ctr = 0 + super().__init__(optimizer, current_step if current_step is not None else -1, False) + + def step(self, current_step=None): + if current_step is None or current_step == -1: + current_step = self.ctr + self.ctr += 1 + + new_lr = self.initial_lr * (1 - current_step / self.max_steps) ** self.exponent + for param_group in self.optimizer.param_groups: + param_group['lr'] = new_lr if __name__ == "__main__": From b7ecd0ab5106aade0d85f85da46c33a8d1a5d0a4 Mon Sep 17 00:00:00 2001 From: Naga Karthik Date: Tue, 11 Jul 2023 18:24:40 -0400 Subject: [PATCH 011/106] swap order of spacingD and cropForegroundD --- monai/transforms.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/monai/transforms.py b/monai/transforms.py index 505f4bb1..bf33ffdb 100644 --- a/monai/transforms.py +++ b/monai/transforms.py @@ -14,9 +14,10 @@ def train_transforms(crop_size, num_samples_pv, lbl_key="label"): EnsureChannelFirstd(keys=["image", lbl_key]), # Orientationd(keys=["image", lbl_key], axcodes="RPI"), # TODO: if the source_key is set to "label", then the cropping is only around the label mask - CropForegroundd(keys=["image", lbl_key], source_key="image"), # crops >0 values with a bounding box Spacingd(keys=["image", "label"], pixdim=(1.0, 1.0, 1.0), mode=("bilinear", "bilinear"),), - SpatialPadd(keys=["image", lbl_key], spatial_size=(123, 255, 214), method="symmetric"), + CropForegroundd(keys=["image", lbl_key], source_key="image"), # crops >0 values with a bounding box + SpatialPadd(keys=["image", lbl_key], spatial_size=(64, 128, 128), method="symmetric"), + # SpatialPadd(keys=["image", lbl_key], spatial_size=(123, 255, 214), method="symmetric"), # RandSpatialCropSamplesd(keys=["image", lbl_key], roi_size=crop_size, num_samples=num_samples_pv, random_center=True, random_size=False), RandCropByPosNegLabeld(keys=["image", "label"], label_key="label", spatial_size=crop_size, pos=1, neg=1, num_samples=num_samples_pv, @@ -37,8 +38,8 @@ def val_transforms(lbl_key="label"): LoadImaged(keys=["image", lbl_key]), EnsureChannelFirstd(keys=["image", lbl_key]), Orientationd(keys=["image", lbl_key], axcodes="RPI"), - CropForegroundd(keys=["image", lbl_key], source_key="image"), Spacingd(keys=["image", "label"], pixdim=(1.0, 1.0, 1.0), mode=("bilinear", "bilinear"),), + CropForegroundd(keys=["image", 
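# Illustrative aside (not from the patch): a quick numeric check of the polynomial decay
# implemented by PolyLRScheduler above, lr = initial_lr * (1 - step / max_steps) ** exponent.
initial_lr, max_steps, exponent = 1e-3, 200, 0.9
for step in (0, 50, 100, 199):
    lr = initial_lr * (1 - step / max_steps) ** exponent
    print(f"step {step:3d}: lr = {lr:.2e}")   # decays smoothly from 1e-3 towards 0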
lbl_key], source_key="image"), # SpatialPadd(keys=["image", lbl_key], spatial_size=(123, 255, 214), method="symmetric"), # HistogramNormalized(keys=["image"], mask=None), NormalizeIntensityd(keys=["image"], nonzero=False, channel_wise=True), @@ -50,8 +51,8 @@ def test_transforms(lbl_key="label"): LoadImaged(keys=["image", lbl_key]), EnsureChannelFirstd(keys=["image", lbl_key]), Orientationd(keys=["image", lbl_key], axcodes="RPI"), - CropForegroundd(keys=["image", lbl_key], source_key="image"), Spacingd(keys=["image", "label"], pixdim=(1.0, 1.0, 1.0), mode=("bilinear", "bilinear"),), + CropForegroundd(keys=["image", lbl_key], source_key="image"), # AddChanneld(keys=["image", lbl_key]), # HistogramNormalized(keys=["image"], mask=None), NormalizeIntensityd(keys=["image"], nonzero=False, channel_wise=True), From f73f22ecff97b39de8b986ecd6af013d5299cdaf Mon Sep 17 00:00:00 2001 From: Naga Karthik Date: Tue, 11 Jul 2023 18:25:56 -0400 Subject: [PATCH 012/106] add empty patch filtering; change LR scheduler --- monai/main.py | 95 +++++++++++++++++++++++++++++---------------------- 1 file changed, 54 insertions(+), 41 deletions(-) diff --git a/monai/main.py b/monai/main.py index 42c0cc29..c6bc68af 100644 --- a/monai/main.py +++ b/monai/main.py @@ -10,7 +10,7 @@ import torch.nn.functional as F import matplotlib.pyplot as plt -from utils import precision_score, recall_score, dice_score, plot_slices +from utils import precision_score, recall_score, dice_score, plot_slices, PolyLRScheduler from losses import SoftDiceLoss from transforms import train_transforms, val_transforms, test_transforms @@ -50,8 +50,8 @@ def __init__(self, args, data_root, fold_num, net, loss_function, optimizer_clas # define cropping and padding dimensions # NOTE: taken from nnUNet_plans.json - self.voxel_cropping_size = (80, 192, 160) - self.inference_roi_size = (80, 192, 160) + self.voxel_cropping_size = (64, 128, 128) # (80, 192, 160) + self.inference_roi_size = (64, 128, 128) # (80, 192, 160) # define post-processing transforms for validation, nothing fancy just making sure that it's a tensor (default) self.val_post_pred = Compose([EnsureType()]) @@ -149,8 +149,8 @@ def test_dataloader(self): # -------------------------------- def configure_optimizers(self): optimizer = self.optimizer_class(self.parameters(), lr=self.lr, weight_decay=1e-5) - # TODO: look at poly learning rate scheduler - scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.5) + # scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.5) + scheduler = PolyLRScheduler(optimizer, self.lr, max_steps=self.args.max_epochs) return [optimizer], [scheduler] @@ -160,6 +160,12 @@ def configure_optimizers(self): def training_step(self, batch, batch_idx): inputs, labels = batch["image"], batch["label"] + + # filter empty input patches + if not inputs.any(): + print("Encountered empty input patch. 
Skipping...") + return None + output = self.forward(inputs) # calculate training loss @@ -176,40 +182,45 @@ def training_step(self, batch, batch_idx): "train_soft_dice": train_soft_dice.detach().cpu(), # "train_hard_dice": train_hard_dice, "train_number": len(inputs), - # "train_image": inputs[0].squeeze(), - # "train_gt": labels[0].squeeze(), - # "train_pred": output[0].squeeze() + "train_image": inputs[0].detach().cpu().squeeze(), + "train_gt": labels[0].detach().cpu().squeeze(), + "train_pred": output[0].detach().cpu().squeeze() } self.train_step_outputs.append(metrics_dict) return metrics_dict def on_train_epoch_end(self): - train_loss, num_items, train_soft_dice = 0, 0, 0 - for output in self.train_step_outputs: - train_loss += output["loss"].sum().item() - train_soft_dice += output["train_soft_dice"].sum().item() - num_items += output["train_number"] - - mean_train_loss = (train_loss / num_items) - mean_train_soft_dice = (train_soft_dice / num_items) - wandb_logs = { - "train_soft_dice": mean_train_soft_dice, - "train_loss": mean_train_loss - } - self.log_dict(wandb_logs) + if self.train_step_outputs == []: + # means the training step was skipped because of empty input patch + return None + else: + train_loss, num_items, train_soft_dice = 0, 0, 0 + for output in self.train_step_outputs: + train_loss += output["loss"].sum().item() + train_soft_dice += output["train_soft_dice"].sum().item() + num_items += output["train_number"] + + mean_train_loss = (train_loss / num_items) + mean_train_soft_dice = (train_soft_dice / num_items) - # # plot the training images - # fig = plot_slices(image=self.train_step_outputs[0]["train_image"], - # gt=self.train_step_outputs[0]["train_gt"], - # pred=self.train_step_outputs[0]["train_pred"],) - # wandb.log({"training images": wandb.Image(fig)}) + wandb_logs = { + "train_soft_dice": mean_train_soft_dice, + "train_loss": mean_train_loss + } + self.log_dict(wandb_logs) - # free up memory - self.train_step_outputs.clear() - wandb_logs.clear() - # plt.close(fig) + # plot the training images + fig = plot_slices(image=self.train_step_outputs[0]["train_image"], + gt=self.train_step_outputs[0]["train_gt"], + pred=self.train_step_outputs[0]["train_pred"],) + wandb.log({"training images": wandb.Image(fig)}) + + # free up memory + self.train_step_outputs.clear() + wandb_logs.clear() + plt.close(fig) # -------------------------------- @@ -242,9 +253,9 @@ def validation_step(self, batch, batch_idx): "val_soft_dice": val_soft_dice.detach().cpu(), "val_hard_dice": val_hard_dice.detach().cpu(), "val_number": len(post_outputs), - "val_image": inputs[0].detach().cpu().squeeze(), - "val_gt": labels[0].detach().cpu().squeeze(), - "val_pred": post_outputs[0].detach().cpu().squeeze(), + # "val_image": inputs[0].detach().cpu().squeeze(), + # "val_gt": labels[0].detach().cpu().squeeze(), + # "val_pred": post_outputs[0].detach().cpu().squeeze(), } self.val_step_outputs.append(metrics_dict) @@ -265,7 +276,7 @@ def on_validation_epoch_end(self): wandb_logs = { "val_soft_dice": mean_val_soft_dice, - # "val_hard_dice": avg_hard_dice_val, + "val_hard_dice": mean_val_hard_dice, "val_loss": mean_val_loss, } if mean_val_soft_dice > self.best_val_dice: @@ -282,16 +293,16 @@ def on_validation_epoch_end(self): # log on to wandb self.log_dict(wandb_logs) - # plot the validation images - fig = plot_slices(image=self.val_step_outputs[0]["val_image"], - gt=self.val_step_outputs[0]["val_gt"], - pred=self.val_step_outputs[0]["val_pred"],) - wandb.log({"validation images": wandb.Image(fig)}) + # 
# plot the validation images + # fig = plot_slices(image=self.val_step_outputs[0]["val_image"], + # gt=self.val_step_outputs[0]["val_gt"], + # pred=self.val_step_outputs[0]["val_pred"],) + # wandb.log({"validation images": wandb.Image(fig)}) # free up memory self.val_step_outputs.clear() wandb_logs.clear() - plt.close(fig) + # plt.close(fig) # return {"log": wandb_logs} @@ -518,6 +529,7 @@ def main(args): print('\n-------------- Test Metrics ----------------', file=f) print(f"\nSeed Used: {args.seed}", file=f) print(f"\ninitf={args.init_filters}_lr={args.learning_rate}_bs={args.batch_size}_{timestamp}", file=f) + print(f"\npatch_size={pl_model.voxel_cropping_size}", file=f) # print(f"\n{np.array(centers_list)[None, :]}", file=f) # print(f"\n{np.array(centers_list)[:, None]}", file=f) @@ -565,8 +577,9 @@ def main(args): choices=['adamw', 'AdamW', 'SGD', 'sgd'], default='adamw', type=str, help='Optimizer to use') parser.add_argument('-lr', '--learning_rate', default=1e-4, type=float, help='Learning rate for training the model') - parser.add_argument('-pat', '--patience', default=15, type=int, + parser.add_argument('-pat', '--patience', default=25, type=int, help='number of validation steps (val_every_n_iters) to wait before early stopping') + # NOTE: patience is acutally until (patience * check_val_every_n_epochs) epochs parser.add_argument('-epb', '--enable_progress_bar', default=False, action='store_true', help='by default is disabled since it doesnt work in colab') parser.add_argument('-cve', '--check_val_every_n_epochs', default=1, type=int, help='num of epochs to wait before validation') From 5bb5315c38d0d2cbfae92d4dcf72195f6e677559 Mon Sep 17 00:00:00 2001 From: Naga Karthik Date: Tue, 11 Jul 2023 18:31:56 -0400 Subject: [PATCH 013/106] add requirements --- monai/requirements.txt | 108 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 108 insertions(+) create mode 100644 monai/requirements.txt diff --git a/monai/requirements.txt b/monai/requirements.txt new file mode 100644 index 00000000..435823c4 --- /dev/null +++ b/monai/requirements.txt @@ -0,0 +1,108 @@ +appdirs==1.4.4 +asttokens==2.2.1 +backcall==0.2.0 +backports.functools-lru-cache==1.6.5 +Brotli==1.0.9 +certifi==2023.5.7 +cffi==1.15.1 +charset-normalizer==3.1.0 +click==8.1.3 +cmake==3.26.4 +colorama==0.4.6 +comm==0.1.3 +contourpy==1.1.0 +cycler==0.11.0 +debugpy==1.6.7 +decorator==5.1.1 +docker-pycreds==0.4.0 +executing==1.2.0 +filelock==3.12.2 +fonttools==4.40.0 +fsspec==2023.6.0 +gitdb==4.0.10 +GitPython==3.1.31 +gmpy2==2.1.2 +idna==3.4 +importlib-metadata==6.8.0 +importlib-resources==6.0.0 +ipykernel==6.24.0 +ipython==8.14.0 +jedi==0.18.2 +Jinja2==3.1.2 +joblib==1.3.0 +jupyter_client==8.3.0 +jupyter_core==5.3.1 +kiwisolver==1.4.4 +lightning-utilities==0.9.0 +lit==16.0.6 +loguru==0.7.0 +MarkupSafe==2.1.3 +matplotlib==3.7.2 +matplotlib-inline==0.1.6 +monai==1.2.0 +mpmath==1.3.0 +nest-asyncio==1.5.6 +networkx==3.1 +nibabel==5.1.0 +numpy==1.25.0 +nvidia-cublas-cu11==11.10.3.66 +nvidia-cuda-cupti-cu11==11.7.101 +nvidia-cuda-nvrtc-cu11==11.7.99 +nvidia-cuda-runtime-cu11==11.7.99 +nvidia-cudnn-cu11==8.5.0.96 +nvidia-cufft-cu11==10.9.0.58 +nvidia-curand-cu11==10.2.10.91 +nvidia-cusolver-cu11==11.4.0.1 +nvidia-cusparse-cu11==11.7.4.91 +nvidia-nccl-cu11==2.14.3 +nvidia-nvtx-cu11==11.7.91 +packaging==23.1 +pandas==2.0.3 +parso==0.8.3 +pathtools==0.1.2 +pexpect==4.8.0 +pickleshare==0.7.5 +Pillow==10.0.0 +pip==23.1.2 +platformdirs==3.8.0 +pooch==1.7.0 +prompt-toolkit==3.0.39 +protobuf==3.20.3 +psutil==5.9.5 +ptyprocess==0.7.0 
+pure-eval==0.2.2
+pycparser==2.21
+Pygments==2.15.1
+pyparsing==3.0.9
+PySocks==1.7.1
+python-dateutil==2.8.2
+pytorch-lightning==2.0.4
+pytz==2023.3
+PyYAML==6.0
+pyzmq==25.1.0
+requests==2.31.0
+scikit-learn==1.3.0
+scipy==1.11.1
+sentry-sdk==1.21.1
+setproctitle==1.3.2
+setuptools==68.0.0
+six==1.16.0
+smmap==3.0.5
+stack-data==0.6.2
+sympy==1.12
+threadpoolctl==3.1.0
+torch==2.0.0+cu117
+torchaudio==2.0.1+cu117
+torchmetrics==0.11.4
+torchvision==0.15.1+cu117
+tornado==6.3.2
+tqdm==4.65.0
+traitlets==5.9.0
+triton==2.0.0
+typing_extensions==4.7.1
+tzdata==2023.3
+urllib3==2.0.3
+wandb==0.15.5
+wcwidth==0.2.6
+wheel==0.40.0
+zipp==3.15.0

From c7a02816d4253608c8d7410b453cecf156028ce1 Mon Sep 17 00:00:00 2001
From: Naga Karthik
Date: Wed, 12 Jul 2023 11:33:15 -0400
Subject: [PATCH 014/106] add option to load from joblib for ivadomed comparison

---
 monai/create_msd_data.py | 27 ++++++++++++++++++-------
 1 file changed, 20 insertions(+), 7 deletions(-)

diff --git a/monai/create_msd_data.py b/monai/create_msd_data.py
index 7991f0da..c9f15c41 100644
--- a/monai/create_msd_data.py
+++ b/monai/create_msd_data.py
@@ -3,7 +3,7 @@
 from tqdm import tqdm
 import numpy as np
 import argparse
-# import nibabel as nib
+import joblib
 from utils import FoldGenerator
 from loguru import logger
 from sklearn.model_selection import train_test_split
@@ -18,6 +18,8 @@
 parser.add_argument('-ncvf', '--num-cv-folds', default=5, type=int,
                     help="[1-k] To create a k-fold dataset for cross validation, 0 for single file with all subjects")
 parser.add_argument('-pd', '--path-data', default=root, type=str, help='Path to the data set directory')
+parser.add_argument('-pj', '--path-joblib', help='Path to joblib file from ivadomed containing the dataset splits.',
+                    default=None, type=str)
 parser.add_argument('-po', '--path-out', type=str, help='Path to the output directory where dataset json is saved')
 args = parser.parse_args()
@@ -148,12 +150,23 @@
     jsonFile.write(final_json)
     jsonFile.close()
 else:
-    # create one json file with 60-20-20 train-val-test split
-    train_ratio, val_ratio, test_ratio = 0.6, 0.2, 0.2
-    train_subjects, test_subjects = train_test_split(subjects, test_size=test_ratio, random_state=args.seed)
-    # Use the training split to further split into training and validation splits
-    train_subjects, val_subjects = train_test_split(train_subjects, test_size=val_ratio / (train_ratio + val_ratio),
-                                                    random_state=args.seed, )
+
+    if args.path_joblib is not None:
+        # load information from the joblib to match train and test subjects
+        joblib_file = os.path.join(args.path_joblib, 'split_datasets_all_seed=15.joblib')
+        splits = joblib.load(joblib_file)
+        # get the subjects from the joblib file
+        train_subjects = sorted(list(set([sub.split('_')[0] for sub in splits['train']])))
+        val_subjects = sorted(list(set([sub.split('_')[0] for sub in splits['valid']])))
+        test_subjects = sorted(list(set([sub.split('_')[0] for sub in splits['test']])))
+
+    else:
+        # create one json file with 60-20-20 train-val-test split
+        train_ratio, val_ratio, test_ratio = 0.6, 0.2, 0.2
+        train_subjects, test_subjects = train_test_split(subjects, test_size=test_ratio, random_state=args.seed)
+        # Use the training split to further split into training and validation splits
+        train_subjects, val_subjects = train_test_split(train_subjects, test_size=val_ratio / (train_ratio + val_ratio),
+                                                        random_state=args.seed, )

 logger.info(f"Number of training subjects: {len(train_subjects)}")
 logger.info(f"Number of validation subjects: 
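# Illustrative aside (not from the patch): why the second split above uses
# val_ratio / (train_ratio + val_ratio). After 20% is held out for test, the validation
# share of the remaining 80% must be 0.2 / 0.8 = 0.25 to land on a 60-20-20 split.
from sklearn.model_selection import train_test_split

subjects = [f"sub-{i:03d}" for i in range(100)]   # hypothetical subject IDs
train_val, test = train_test_split(subjects, test_size=0.2, random_state=42)
train, val = train_test_split(train_val, test_size=0.2 / 0.8, random_state=42)
print(len(train), len(val), len(test))            # 60 20 20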
{len(val_subjects)}") From 11961e78cde64e9474d828d4eb30005b81c5d9f8 Mon Sep 17 00:00:00 2001 From: Naga Karthik Date: Fri, 14 Jul 2023 07:58:31 -0400 Subject: [PATCH 015/106] add elastic and biasfield transforms --- monai/transforms.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/monai/transforms.py b/monai/transforms.py index bf33ffdb..b8f1c497 100644 --- a/monai/transforms.py +++ b/monai/transforms.py @@ -2,7 +2,8 @@ from monai.transforms import (SpatialPadd, Compose, CropForegroundd, LoadImaged, RandFlipd, RandCropByPosNegLabeld, Spacingd, RandRotate90d, ToTensord, NormalizeIntensityd, EnsureType, RandWeightedCropd, HistogramNormalized, EnsureTyped, Invertd, SaveImaged, - EnsureChannelFirstd, CenterSpatialCropd, RandSpatialCropSamplesd, Orientationd) + EnsureChannelFirstd, CenterSpatialCropd, RandSpatialCropSamplesd, Orientationd, + Rand3DElasticd, RandBiasFieldd) # median image size in voxels - taken from nnUNet # median_size = (123, 255, 214) @@ -13,7 +14,6 @@ def train_transforms(crop_size, num_samples_pv, lbl_key="label"): LoadImaged(keys=["image", lbl_key]), EnsureChannelFirstd(keys=["image", lbl_key]), # Orientationd(keys=["image", lbl_key], axcodes="RPI"), - # TODO: if the source_key is set to "label", then the cropping is only around the label mask Spacingd(keys=["image", "label"], pixdim=(1.0, 1.0, 1.0), mode=("bilinear", "bilinear"),), CropForegroundd(keys=["image", lbl_key], source_key="image"), # crops >0 values with a bounding box SpatialPadd(keys=["image", lbl_key], spatial_size=(64, 128, 128), method="symmetric"), @@ -22,11 +22,13 @@ def train_transforms(crop_size, num_samples_pv, lbl_key="label"): RandCropByPosNegLabeld(keys=["image", "label"], label_key="label", spatial_size=crop_size, pos=1, neg=1, num_samples=num_samples_pv, # if num_samples=4, then 4 samples/image are randomly generated - image_key="image", image_threshold=0.), + image_key="image", image_threshold=0.), + Rand3DElasticd(keys=["image", lbl_key], sigma_range=(3.5, 5.5), magnitude_range=(25, 35), prob=0.5), + RandBiasFieldd(keys=["image", lbl_key], coeff_range=(0.0, 0.5), prob=0.25, degree=3), RandFlipd(keys=["image", lbl_key], spatial_axis=[0], prob=0.50,), RandFlipd(keys=["image", lbl_key], spatial_axis=[1], prob=0.50,), RandFlipd(keys=["image", lbl_key],spatial_axis=[2],prob=0.50,), - RandRotate90d(keys=["image", lbl_key], prob=0.10, max_k=3,), + # RandRotate90d(keys=["image", lbl_key], prob=0.10, max_k=3,), Orientationd(keys=["image", lbl_key], axcodes="RPI"), # NOTE: if not using it here, then it results in collation error # HistogramNormalized(keys=["image"], mask=None), NormalizeIntensityd(keys=["image"], nonzero=False, channel_wise=True), From d230e5996409055149f4e6ec847da2dea463c726 Mon Sep 17 00:00:00 2001 From: Naga Karthik Date: Fri, 14 Jul 2023 08:01:11 -0400 Subject: [PATCH 016/106] change empty image patch filter to label patch --- monai/main.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/monai/main.py b/monai/main.py index c6bc68af..391d2314 100644 --- a/monai/main.py +++ b/monai/main.py @@ -161,9 +161,9 @@ def training_step(self, batch, batch_idx): inputs, labels = batch["image"], batch["label"] - # filter empty input patches - if not inputs.any(): - print("Encountered empty input patch. Skipping...") + # filter empty label patches + if not labels.any(): + print("Encountered empty label patch. 
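# Illustrative aside (not from the patch): the elastic and bias-field augmentations
# added above, exercised on a dummy sample. In this sketch the bias field is applied to
# the image key only, since it is an intensity perturbation; note that the patch itself
# applies it to both the image and label keys.
import numpy as np
from monai.transforms import Compose, Rand3DElasticd, RandBiasFieldd

sample = {
    "image": np.random.rand(1, 32, 32, 32).astype("float32"),
    "label": (np.random.rand(1, 32, 32, 32) > 0.9).astype("float32"),
}
aug = Compose([
    Rand3DElasticd(keys=["image", "label"], sigma_range=(3.5, 5.5),
                   magnitude_range=(25, 35), prob=1.0),
    RandBiasFieldd(keys=["image"], coeff_range=(0.0, 0.5), degree=3, prob=1.0),
])
out = aug(sample)
print(out["image"].shape, out["label"].shape)   # spatial shape is preserved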
Skipping...") return None output = self.forward(inputs) @@ -427,9 +427,9 @@ def main(args): args.init_filters * 16 ), strides=(2, 2, 2, 2), - num_res_units=2, + num_res_units=4, ) - save_exp_id =f"{args.model}_nf={args.init_filters}_nrs=2_lr={args.learning_rate}" + save_exp_id =f"{args.model}_nf={args.init_filters}_nrs=4_lr={args.learning_rate}" elif args.model in ["unetr", "UNETR"]: # define image size to be fed to the model img_size = (96, 96, 96) From 56dfff34b3ff7f7c794f9df163f239c605bec687 Mon Sep 17 00:00:00 2001 From: Naga Karthik Date: Mon, 17 Jul 2023 16:52:36 -0400 Subject: [PATCH 017/106] add initial verson of infeerence script --- monai/run_inference.py | 235 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 235 insertions(+) create mode 100644 monai/run_inference.py diff --git a/monai/run_inference.py b/monai/run_inference.py new file mode 100644 index 00000000..986689c8 --- /dev/null +++ b/monai/run_inference.py @@ -0,0 +1,235 @@ +import os +import argparse +import numpy as np +from loguru import logger +import torch.nn.functional as F +import torch +import json +from time import time + +from monai.inferers import sliding_window_inference +from monai.data import (DataLoader, CacheDataset, load_decathlon_datalist, decollate_batch) +from monai.transforms import (Compose, EnsureTyped, Invertd, SaveImage) +from monai.networks.nets import UNet + +from transforms import test_transforms +from utils import precision_score, recall_score, dice_score + +DEBUG = True +INIT_FILTERS=8 +INFERENCE_ROI_SIZE = (64, 128, 128) # (80, 192, 160) +DEVICE = "cpu" + + +def get_parser(): + + parser = argparse.ArgumentParser(description="Run inference on a MONAI-trained model") + + parser.add_argument("--path-json", type=str, required=True, + help="Path to the json file containing the test dataset in MSD format") + parser.add_argument("--chkp-path", type=str, required=True, help="Path to the checkpoint file") + parser.add_argument("--path-out", type=str, required=True, + help="Path to the output folder where to store the predictions and associated metrics") + + return parser + + +# -------------------------------- +# DATA +# -------------------------------- +def prepare_data(root): + # set deterministic training for reproducibility + # set_determinism(seed=self.args.seed) + + # load the dataset + dataset = os.path.join(root, f"dataset.json") + test_files = load_decathlon_datalist(dataset, True, "test") + + if DEBUG: # args.debug: + test_files = test_files[:3] + + # define test transforms + transforms_test = test_transforms(lbl_key='label') + + # define post-processing transforms for testing; taken (with explanations) from + # https://github.com/Project-MONAI/tutorials/blob/main/3d_segmentation/torch/unet_inference_dict.py#L66 + test_post_pred = Compose([ + EnsureTyped(keys=["pred", "label"]), + Invertd(keys=["pred", "label"], transform=transforms_test, + orig_keys=["image", "label"], + meta_keys=["pred_meta_dict", "label_meta_dict"], + nearest_interp=False, to_tensor=True), + ]) + test_ds = CacheDataset(data=test_files, transform=transforms_test, cache_rate=0.1, num_workers=4) + + return test_ds, test_post_pred + + +def main(args): + + # define start time + start = time() + + # define root path for finding datalists + dataset_root = args.path_json + + # TODO: change the name of the checkpoint file to best_model.ckpt + chkp_path = os.path.join(args.chkp_path, "unet_nf=8_nrs=4_lr=0.001_20230713-1206.ckpt") + + results_path = args.path_out + folder_name = chkp_path.split("/")[-2] + results_path = 
os.path.join(results_path, folder_name) + if not os.path.exists(results_path): + os.makedirs(results_path, exist_ok=True) + + checkpoint = torch.load(chkp_path, map_location=torch.device('cpu'))["state_dict"] + # NOTE: remove the 'net.' prefix from the keys because of how the model was initialized in lightning + for key in list(checkpoint.keys()): + if 'net.' in key: + checkpoint[key.replace('net.', '')] = checkpoint[key] + del checkpoint[key] + + # initialize the model + net = UNet(spatial_dims=3, + in_channels=1, out_channels=1, + channels=( + INIT_FILTERS, + INIT_FILTERS * 2, + INIT_FILTERS * 4, + INIT_FILTERS * 8, + INIT_FILTERS * 16 + ), + strides=(2, 2, 2, 2), + num_res_units=4,) + + # load the trained model weights + net.load_state_dict(checkpoint) + net.to(DEVICE) + + # define the dataset and dataloader + test_ds, test_post_pred = prepare_data(dataset_root) + test_loader = DataLoader(test_ds, batch_size=1, shuffle=False, num_workers=4, pin_memory=True) + + # define list to collect the test metrics + test_step_outputs = [] + test_summary = {} + + # iterate over the dataset and compute metrics + net.eval() + with torch.no_grad(): + for i, batch in enumerate(test_loader): + # compute time for inference per subject + start_time = time() + + test_input = batch["image"] + batch["pred"] = sliding_window_inference(test_input, INFERENCE_ROI_SIZE, + sw_batch_size=4, predictor=net, overlap=0.5) + # NOTE: monai's models do not normalize the output, so we need to do it manually + if bool(F.relu(batch["pred"]).max()): + batch["pred"] = F.relu(batch["pred"]) / F.relu(batch["pred"]).max() + else: + batch["pred"] = F.relu(batch["pred"]) + + # # upon fsleyes visualization, observed that very small values need to be set to zero, but NOT fully binarizing the pred + # batch["pred"][batch["pred"] < 0.099] = 0.0 + + post_test_out = [test_post_pred(i) for i in decollate_batch(batch)] + + # make sure that the shapes of prediction and GT label are the same + # print(f"pred shape: {post_test_out[0]['pred'].shape}, label shape: {post_test_out[0]['label'].shape}") + assert post_test_out[0]['pred'].shape == post_test_out[0]['label'].shape + + pred, label = post_test_out[0]['pred'].cpu(), post_test_out[0]['label'].cpu() + + # save the prediction and label + subject_name = (batch["image_meta_dict"]["filename_or_obj"][0]).split("/")[-1].replace(".nii.gz", "") + print(f"Saving subject: {subject_name}") + + # image saver class + save_folder = os.path.join(results_path, subject_name.split("_")[0]) + pred_saver = SaveImage( + output_dir=save_folder, output_postfix="pred", output_ext=".nii.gz", + separate_folder=False, print_log=False) + # save the prediction + pred_saver(pred) + + # label_saver = SaveImage( + # output_dir=save_folder, output_postfix="gt", output_ext=".nii.gz", + # separate_folder=False, print_log=False) + # # save the label + # label_saver(label) + + # NOTE: Important point from the SoftSeg paper - binarize predictions before computing metrics + # calculate all metrics here + # 1. Dice Score + test_soft_dice = dice_score(pred, label) + + # binarizing the predictions + pred = (post_test_out[0]['pred'].detach().cpu() > 0.5).float() + label = (post_test_out[0]['label'].detach().cpu() > 0.5).float() + + # 1.1 Hard Dice Score + test_hard_dice = dice_score(pred.numpy(), label.numpy()) + # 2. Precision Score + test_precision = precision_score(pred.numpy(), label.numpy()) + # 3. 
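# Illustrative aside: the exact precision/recall/dice implementations live in utils.py
# and are not shown in this excerpt; a common numpy formulation for binary masks is:
import numpy as np

def binary_metrics(pred, gt):
    tp = np.sum((pred == 1) & (gt == 1))
    fp = np.sum((pred == 1) & (gt == 0))
    fn = np.sum((pred == 0) & (gt == 1))
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    dice = 2 * tp / (2 * tp + fp + fn)
    return precision, recall, dice

print(binary_metrics(np.array([1, 1, 0, 0]), np.array([1, 0, 1, 0])))   # (0.5, 0.5, 0.5)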
Recall Score + test_recall = recall_score(pred.numpy(), label.numpy()) + + end_time = time() + metrics_dict = { + "subject_name_and_contrast": subject_name, + "dice_binary": round(test_hard_dice, 2), + "dice_soft": round(test_soft_dice.item(), 2), + "precision": round(test_precision, 2), + "recall": round(test_recall, 2), + # TODO: add relative volume difference here + "inference_time_in_sec": round((end_time - start_time), 2), + } + test_step_outputs.append(metrics_dict) + + # save the test summary + test_summary["metrics_per_subject"] = test_step_outputs + + # compute the average of all metrics + avg_hard_dice_test, std_hard_dice_test = np.stack([x["dice_binary"] for x in test_step_outputs]).mean(), \ + np.stack([x["dice_binary"] for x in test_step_outputs]).std() + avg_soft_dice_test, std_soft_dice_test = np.stack([x["dice_soft"] for x in test_step_outputs]).mean(), \ + np.stack([x["dice_soft"] for x in test_step_outputs]).std() + avg_precision_test = np.stack([x["precision"] for x in test_step_outputs]).mean() + avg_recall_test = np.stack([x["recall"] for x in test_step_outputs]).mean() + avg_inference_time = np.stack([x["inference_time_in_sec"] for x in test_step_outputs]).mean() + + # store the average metrics in a dict + avg_metrics = { + "avg_dice_binary": round(avg_hard_dice_test, 2), + "avg_dice_soft": round(avg_soft_dice_test, 2), + "avg_precision": round(avg_precision_test, 2), + "avg_recall": round(avg_recall_test, 2), + "avg_inference_time_in_sec": round(avg_inference_time, 2), + } + test_summary["metrics_avg_across_cohort"] = avg_metrics + + logger.info(f"Test (Soft) Dice: {avg_soft_dice_test}") + logger.info(f"Test (Hard) Dice: {avg_hard_dice_test}") + logger.info(f"Test Precision Score: {avg_precision_test}") + logger.info(f"Test Recall Score: {avg_recall_test}") + logger.info(f"Average Inference Time per Subject: {avg_inference_time:.2f}s") + + # dump the test summary to a json file + with open(os.path.join(results_path, "test_summary.json"), "w") as f: + json.dump(test_summary, f, indent=4, sort_keys=True) + + # free up memory + test_step_outputs.clear() + + end = time() + + print("=====================================================================") + print(f"Total time taken for inference: {(end - start) / 60:.2f} minutes") + print("=====================================================================") + + +if __name__ == "__main__": + + args = get_parser().parse_args() + main(args) \ No newline at end of file From ad3d8474d3b008265d3304ef4d51f9b5071f0371 Mon Sep 17 00:00:00 2001 From: Naga Karthik Date: Mon, 17 Jul 2023 16:57:41 -0400 Subject: [PATCH 018/106] add ref to chkpt loading for inference --- monai/run_inference.py | 1 + 1 file changed, 1 insertion(+) diff --git a/monai/run_inference.py b/monai/run_inference.py index 986689c8..73b75948 100644 --- a/monai/run_inference.py +++ b/monai/run_inference.py @@ -84,6 +84,7 @@ def main(args): checkpoint = torch.load(chkp_path, map_location=torch.device('cpu'))["state_dict"] # NOTE: remove the 'net.' prefix from the keys because of how the model was initialized in lightning + # https://discuss.pytorch.org/t/missing-keys-unexpected-keys-in-state-dict-when-loading-self-trained-model/22379/14 for key in list(checkpoint.keys()): if 'net.' 
in key: checkpoint[key.replace('net.', '')] = checkpoint[key] From 407ae043c35b88eccb3f73678b9cb0323b23df3f Mon Sep 17 00:00:00 2001 From: Naga Karthik Date: Tue, 18 Jul 2023 20:20:46 -0400 Subject: [PATCH 019/106] remove duplicated test transforms --- monai/transforms.py | 19 ++++--------------- 1 file changed, 4 insertions(+), 15 deletions(-) diff --git a/monai/transforms.py b/monai/transforms.py index b8f1c497..266e0245 100644 --- a/monai/transforms.py +++ b/monai/transforms.py @@ -19,8 +19,10 @@ def train_transforms(crop_size, num_samples_pv, lbl_key="label"): SpatialPadd(keys=["image", lbl_key], spatial_size=(64, 128, 128), method="symmetric"), # SpatialPadd(keys=["image", lbl_key], spatial_size=(123, 255, 214), method="symmetric"), # RandSpatialCropSamplesd(keys=["image", lbl_key], roi_size=crop_size, num_samples=num_samples_pv, random_center=True, random_size=False), + # NOTE: used with neg together to calculate the ratio pos / (pos + neg) for the probability to pick a + # foreground voxel as a center rather than a background voxel. RandCropByPosNegLabeld(keys=["image", "label"], label_key="label", - spatial_size=crop_size, pos=1, neg=1, num_samples=num_samples_pv, + spatial_size=crop_size, pos=3, neg=1, num_samples=num_samples_pv, # if num_samples=4, then 4 samples/image are randomly generated image_key="image", image_threshold=0.), Rand3DElasticd(keys=["image", lbl_key], sigma_range=(3.5, 5.5), magnitude_range=(25, 35), prob=0.5), @@ -40,23 +42,10 @@ def val_transforms(lbl_key="label"): LoadImaged(keys=["image", lbl_key]), EnsureChannelFirstd(keys=["image", lbl_key]), Orientationd(keys=["image", lbl_key], axcodes="RPI"), - Spacingd(keys=["image", "label"], pixdim=(1.0, 1.0, 1.0), mode=("bilinear", "bilinear"),), + Spacingd(keys=["image", lbl_key], pixdim=(1.0, 1.0, 1.0), mode=("bilinear", "bilinear"),), CropForegroundd(keys=["image", lbl_key], source_key="image"), # SpatialPadd(keys=["image", lbl_key], spatial_size=(123, 255, 214), method="symmetric"), # HistogramNormalized(keys=["image"], mask=None), NormalizeIntensityd(keys=["image"], nonzero=False, channel_wise=True), # ToTensord(keys=["image", lbl_key]), ]) - -def test_transforms(lbl_key="label"): - return Compose([ - LoadImaged(keys=["image", lbl_key]), - EnsureChannelFirstd(keys=["image", lbl_key]), - Orientationd(keys=["image", lbl_key], axcodes="RPI"), - Spacingd(keys=["image", "label"], pixdim=(1.0, 1.0, 1.0), mode=("bilinear", "bilinear"),), - CropForegroundd(keys=["image", lbl_key], source_key="image"), - # AddChanneld(keys=["image", lbl_key]), - # HistogramNormalized(keys=["image"], mask=None), - NormalizeIntensityd(keys=["image"], nonzero=False, channel_wise=True), - # ToTensord(keys=["image", lbl_key]), - ]) \ No newline at end of file From e67427f803c6d2355eb4c331aac09ee7a1ab56a0 Mon Sep 17 00:00:00 2001 From: Naga Karthik Date: Tue, 18 Jul 2023 20:21:46 -0400 Subject: [PATCH 020/106] add feature to plot more slices --- monai/utils.py | 36 ++++++++++++++++++++++++++++-------- 1 file changed, 28 insertions(+), 8 deletions(-) diff --git a/monai/utils.py b/monai/utils.py index 291c01ec..a1fa8cfe 100644 --- a/monai/utils.py +++ b/monai/utils.py @@ -172,7 +172,7 @@ def dice_score(prediction, groundtruth): return dice -def plot_slices(image, gt, pred): +def plot_slices(image, gt, pred, debug=False): """ Plot the image, ground truth and prediction of the mid-sagittal axial slice The orientaion is assumed to RPI @@ -183,13 +183,33 @@ def plot_slices(image, gt, pred): gt = gt.numpy() pred = pred.numpy() - fig, axs = 
plt.subplots(1, 3, figsize=(10, 8)) - fig.suptitle('Original Image --> Ground Truth --> Prediction') - slice = image.shape[2]//2 - - axs[0].imshow(image[:, :, slice].T, cmap='gray'); axs[0].axis('off') - axs[1].imshow(gt[:, :, slice].T); axs[1].axis('off') - axs[2].imshow(pred[:, :, slice].T); axs[2].axis('off') + if not debug: + mid_sagittal = image.shape[2]//2 + # plot X slices before and after the mid-sagittal slice in a grid + fig, axs = plt.subplots(3, 6, figsize=(10, 6)) + fig.suptitle('Original Image --> Ground Truth --> Prediction') + for i in range(6): + axs[0, i].imshow(image[:, :, mid_sagittal-3+i].T, cmap='gray'); axs[0, i].axis('off') + axs[1, i].imshow(gt[:, :, mid_sagittal-3+i].T); axs[1, i].axis('off') + axs[2, i].imshow(pred[:, :, mid_sagittal-3+i].T); axs[2, i].axis('off') + + # fig, axs = plt.subplots(1, 3, figsize=(10, 8)) + # fig.suptitle('Original Image --> Ground Truth --> Prediction') + # slice = image.shape[2]//2 + + # axs[0].imshow(image[:, :, slice].T, cmap='gray'); axs[0].axis('off') + # axs[1].imshow(gt[:, :, slice].T); axs[1].axis('off') + # axs[2].imshow(pred[:, :, slice].T); axs[2].axis('off') + + else: # plot multiple slices + mid_sagittal = image.shape[2]//2 + # plot X slices before and after the mid-sagittal slice in a grid + fig, axs = plt.subplots(3, 14, figsize=(20, 8)) + fig.suptitle('Original Image --> Ground Truth --> Prediction') + for i in range(14): + axs[0, i].imshow(image[:, :, mid_sagittal-7+i].T, cmap='gray'); axs[0, i].axis('off') + axs[1, i].imshow(gt[:, :, mid_sagittal-7+i].T); axs[1, i].axis('off') + axs[2, i].imshow(pred[:, :, mid_sagittal-7+i].T); axs[2, i].axis('off') plt.tight_layout() fig.show() From 72ac6fe49e362dddc1082c9426416f32d6c994f1 Mon Sep 17 00:00:00 2001 From: Naga Karthik Date: Tue, 18 Jul 2023 20:22:26 -0400 Subject: [PATCH 021/106] minor fixes and improvements --- monai/main.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/monai/main.py b/monai/main.py index 391d2314..a3dd321f 100644 --- a/monai/main.py +++ b/monai/main.py @@ -12,7 +12,7 @@ from utils import precision_score, recall_score, dice_score, plot_slices, PolyLRScheduler from losses import SoftDiceLoss -from transforms import train_transforms, val_transforms, test_transforms +from transforms import train_transforms, val_transforms from monai.utils import set_determinism from monai.inferers import sliding_window_inference @@ -49,8 +49,7 @@ def __init__(self, args, data_root, fold_num, net, loss_function, optimizer_clas self.best_val_dice, self.best_val_epoch = 0, 0 # define cropping and padding dimensions - # NOTE: taken from nnUNet_plans.json - self.voxel_cropping_size = (64, 128, 128) # (80, 192, 160) + self.voxel_cropping_size = (64, 128, 128) # (80, 192, 160) taken from nnUNet_plans.json self.inference_roi_size = (64, 128, 128) # (80, 192, 160) # define post-processing transforms for validation, nothing fancy just making sure that it's a tensor (default) @@ -114,7 +113,7 @@ def prepare_data(self): self.val_ds = CacheDataset(data=val_files, transform=transforms_val, cache_rate=0.25, num_workers=4) # define test transforms - transforms_test = test_transforms(lbl_key='label') + transforms_test = val_transforms(lbl_key='label') # define post-processing transforms for testing; taken (with explanations) from # https://github.com/Project-MONAI/tutorials/blob/main/3d_segmentation/torch/unet_inference_dict.py#L66 @@ -214,7 +213,8 @@ def on_train_epoch_end(self): # plot the training images fig = 
plot_slices(image=self.train_step_outputs[0]["train_image"], gt=self.train_step_outputs[0]["train_gt"], - pred=self.train_step_outputs[0]["train_pred"],) + pred=self.train_step_outputs[0]["train_pred"], + debug=args.debug) wandb.log({"training images": wandb.Image(fig)}) # free up memory @@ -429,7 +429,7 @@ def main(args): strides=(2, 2, 2, 2), num_res_units=4, ) - save_exp_id =f"{args.model}_nf={args.init_filters}_nrs=4_lr={args.learning_rate}" + save_exp_id =f"{args.model}_nf={args.init_filters}_nrs=4_lr={args.learning_rate}_bs={args.batch_size}" elif args.model in ["unetr", "UNETR"]: # define image size to be fed to the model img_size = (96, 96, 96) @@ -493,7 +493,7 @@ def main(args): # exp_logger = pl.loggers.CSVLogger(save_dir=args.save_path, name="my_exp_name") checkpoint_callback = pl.callbacks.ModelCheckpoint( - dirpath=save_path, filename=save_exp_id, monitor='val_loss', + dirpath=save_path, filename='best_model', monitor='val_loss', save_top_k=1, mode="min", save_last=False, save_weights_only=True) lr_monitor = pl.callbacks.LearningRateMonitor(logging_interval='epoch') @@ -516,6 +516,9 @@ def main(args): trainer.fit(pl_model) logger.info(f" Training Done!") + # Saving training script to wandb + wandb.save("main.py") + # TODO: Come back to testing when hyperparamters have been fixed after cross-validation # Test! trainer.test(pl_model) From 0540c1f1c0ced5109c12fd8c710aa873fdf07630 Mon Sep 17 00:00:00 2001 From: Naga Karthik Date: Thu, 20 Jul 2023 15:44:09 -0400 Subject: [PATCH 022/106] update code to do inference on various datasets --- monai/run_inference.py | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/monai/run_inference.py b/monai/run_inference.py index 73b75948..623e5f51 100644 --- a/monai/run_inference.py +++ b/monai/run_inference.py @@ -12,10 +12,10 @@ from monai.transforms import (Compose, EnsureTyped, Invertd, SaveImage) from monai.networks.nets import UNet -from transforms import test_transforms +from transforms import val_transforms from utils import precision_score, recall_score, dice_score -DEBUG = True +DEBUG = False INIT_FILTERS=8 INFERENCE_ROI_SIZE = (64, 128, 128) # (80, 192, 160) DEVICE = "cpu" @@ -27,9 +27,11 @@ def get_parser(): parser.add_argument("--path-json", type=str, required=True, help="Path to the json file containing the test dataset in MSD format") - parser.add_argument("--chkp-path", type=str, required=True, help="Path to the checkpoint file") + parser.add_argument("--chkp-path", type=str, required=True, help="Path to the checkpoint folder") parser.add_argument("--path-out", type=str, required=True, help="Path to the output folder where to store the predictions and associated metrics") + parser.add_argument("-dname", "--dataset-name", type=str, default="spine-generic", + help="Name of the dataset to run inference on") return parser @@ -37,19 +39,19 @@ def get_parser(): # -------------------------------- # DATA # -------------------------------- -def prepare_data(root): +def prepare_data(root, dataset_name="spine-generic"): # set deterministic training for reproducibility # set_determinism(seed=self.args.seed) # load the dataset - dataset = os.path.join(root, f"dataset.json") + dataset = os.path.join(root, f"{dataset_name}_dataset.json") test_files = load_decathlon_datalist(dataset, True, "test") if DEBUG: # args.debug: test_files = test_files[:3] # define test transforms - transforms_test = test_transforms(lbl_key='label') + transforms_test = val_transforms(lbl_key='label') # define post-processing 
transforms for testing; taken (with explanations) from # https://github.com/Project-MONAI/tutorials/blob/main/3d_segmentation/torch/unet_inference_dict.py#L66 @@ -72,13 +74,14 @@ def main(args): # define root path for finding datalists dataset_root = args.path_json + dataset_name = args.dataset_name # TODO: change the name of the checkpoint file to best_model.ckpt - chkp_path = os.path.join(args.chkp_path, "unet_nf=8_nrs=4_lr=0.001_20230713-1206.ckpt") + chkp_path = os.path.join(args.chkp_path, "best_model.ckpt") results_path = args.path_out - folder_name = chkp_path.split("/")[-2] - results_path = os.path.join(results_path, folder_name) + model_name = chkp_path.split("/")[-2] + results_path = os.path.join(results_path, dataset_name, model_name) if not os.path.exists(results_path): os.makedirs(results_path, exist_ok=True) @@ -108,7 +111,7 @@ def main(args): net.to(DEVICE) # define the dataset and dataloader - test_ds, test_post_pred = prepare_data(dataset_root) + test_ds, test_post_pred = prepare_data(dataset_root, dataset_name) test_loader = DataLoader(test_ds, batch_size=1, shuffle=False, num_workers=4, pin_memory=True) # define list to collect the test metrics From 5d93f9a7cc970fb7bb2acf889a625e6e82f3db74 Mon Sep 17 00:00:00 2001 From: Naga Karthik Date: Fri, 21 Jul 2023 13:08:22 -0400 Subject: [PATCH 023/106] fix diceLoss; add DiceCE loss --- monai/losses.py | 60 +++++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 56 insertions(+), 4 deletions(-) diff --git a/monai/losses.py b/monai/losses.py index ff151f48..5aeb61c4 100644 --- a/monai/losses.py +++ b/monai/losses.py @@ -1,7 +1,12 @@ import torch import torch.nn as nn +from torch import Tensor +import torch.nn.functional as F +# TODO: also check out nnUNet's implementation of soft-dice loss (if required) +# https://github.com/MIC-DKFZ/nnUNet/blob/master/nnunetv2/training/loss/dice.py + class SoftDiceLoss(nn.Module): ''' soft-dice loss, useful in binary segmentation @@ -12,16 +17,63 @@ def __init__(self, p=1, smooth=1): self.p = p self.smooth = smooth - def forward(self, preds, labels): + def forward(self, logits, labels): ''' inputs: - preds: normalized probabilities (not logits) - tensor of shape (N, H, W, ...) + preds: logits - tensor of shape (N, H, W, ...) labels: soft labels [0,1] - tensor of shape(N, H, W, ...) output: loss: tensor of shape(1, ) ''' - # probs = torch.sigmoid(logits) + preds = F.relu(logits) / F.relu(logits).max() if bool(F.relu(logits).max()) else F.relu(logits) + numer = (preds * labels).sum() denor = (preds.pow(self.p) + labels.pow(self.p)).sum() - loss = 1. - (2 * numer + self.smooth) / (denor + self.smooth) + # loss = 1. - (2 * numer + self.smooth) / (denor + self.smooth) + loss = - (2 * numer + self.smooth) / (denor + self.smooth) + return loss + + +class RobustCrossEntropyLoss(nn.CrossEntropyLoss): + """ + this is just a compatibility layer because my target tensor is float and has an extra dimension + + input must be logits, not probabilities! 
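# Illustrative aside (not from the patch): sanity-checking the re-signed soft Dice loss
# above on binary tensors. With smooth=1, a perfect binary prediction gives exactly -1
# (the minimum) and zero overlap gives -0.2, so minimizing the loss maximizes overlap.
import torch

def soft_dice_loss(preds, labels, p=1, smooth=1.0):
    numer = (preds * labels).sum()
    denor = (preds.pow(p) + labels.pow(p)).sum()
    return - (2 * numer + smooth) / (denor + smooth)

target = torch.tensor([0., 1., 1., 0.])
print(soft_dice_loss(target, target))       # tensor(-1.)
print(soft_dice_loss(1 - target, target))   # tensor(-0.2000)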
+ adapted from: https://github.com/MIC-DKFZ/nnUNet/blob/master/nnunetv2/training/loss/robust_ce_loss.py + """ + def forward(self, input: Tensor, target: Tensor) -> Tensor: + + # binarize soft labels with threshold 0.5 before using cross-entropy + target = torch.where(target > 0.5, torch.ones_like(target), torch.zeros_like(target)) + + if len(target.shape) == len(input.shape): + assert target.shape[1] == 1 + target = target[:, 0] + + return super().forward(input, target.long()) + + +class DiceCrossEntropyLoss(nn.Module): + def __init__(self, weight_ce=1.0, weight_dice=1.0): + super(DiceCrossEntropyLoss).__init__() + self.ce_weight = weight_ce + self.dice_weight = weight_dice + + self.dice_loss = SoftDiceLoss() + self.ce_loss = RobustCrossEntropyLoss() + + def forward(self, preds, labels): + ''' + inputs: + preds: logits (not probabilities!) - tensor of shape (N, H, W, ...) + labels: soft labels [0,1] - tensor of shape(N, H, W, ...) + output: + loss: tensor of shape(1, ) + ''' + ce_loss = self.ce_loss(preds, labels) + + # dice loss will convert logits to probabilities + dice_loss = self.dice_loss(preds, labels) + + loss = self.ce_weight * ce_loss + self.dice_weight * dice_loss return loss \ No newline at end of file From bd81aca3052c7d0e95e8aaaa8342c006989d152c Mon Sep 17 00:00:00 2001 From: Naga Karthik Date: Fri, 21 Jul 2023 13:11:37 -0400 Subject: [PATCH 024/106] change to explicitly normalizing the logits --- monai/main.py | 27 ++++++++++++++++++--------- 1 file changed, 18 insertions(+), 9 deletions(-) diff --git a/monai/main.py b/monai/main.py index a3dd321f..e051eaf8 100644 --- a/monai/main.py +++ b/monai/main.py @@ -72,15 +72,15 @@ def forward(self, x): # x, context_features = self.encoder(x) # preds = self.decoder(x, context_features) - out = self.net(x) - # NOTE: MONAI's models only output the logits, not the output after the final activation function - # https://docs.monai.io/en/0.9.0/_modules/monai/networks/nets/unetr.html#UNETR.forward refers to the - # UnetOutBlock (https://docs.monai.io/en/0.9.0/_modules/monai/networks/blocks/dynunet_block.html#UnetOutBlock) - # as the final block applied to the input, which is just a convolutional layer with no activation function - # Hence, we are used Normalized ReLU to normalize the logits to the final output - normalized_out = F.relu(out) / F.relu(out).max() if bool(F.relu(out).max()) else F.relu(out) + out = self.net(x) + # # NOTE: MONAI's models only output the logits, not the output after the final activation function + # # https://docs.monai.io/en/0.9.0/_modules/monai/networks/nets/unetr.html#UNETR.forward refers to the + # # UnetOutBlock (https://docs.monai.io/en/0.9.0/_modules/monai/networks/blocks/dynunet_block.html#UnetOutBlock) + # # as the final block applied to the input, which is just a convolutional layer with no activation function + # # Hence, we are used Normalized ReLU to normalize the logits to the final output + # normalized_out = F.relu(out) / F.relu(out).max() if bool(F.relu(out).max()) else F.relu(out) - return normalized_out + return out # returns logits # -------------------------------- @@ -165,11 +165,14 @@ def training_step(self, batch, batch_idx): print("Encountered empty label patch. 
Skipping...") return None - output = self.forward(inputs) + output = self.forward(inputs) # logits # calculate training loss loss = self.loss_function(output, labels) + # get probabilities from logits + output = F.relu(output) / F.relu(output).max() if bool(F.relu(output).max()) else F.relu(output) + # calculate train dice # NOTE: this is done on patches (and not entire 3D volume) because SlidingWindowInference is not used here # So, take this dice score with a lot of salt @@ -236,6 +239,9 @@ def validation_step(self, batch, batch_idx): # calculate validation loss loss = self.loss_function(outputs, labels) + + # get probabilities from logits + outputs = F.relu(outputs) / F.relu(outputs).max() if bool(F.relu(outputs).max()) else F.relu(outputs) # post-process for calculating the evaluation metric post_outputs = [self.val_post_pred(i) for i in decollate_batch(outputs)] @@ -317,6 +323,9 @@ def test_step(self, batch, batch_idx): batch["pred"] = sliding_window_inference(test_input, self.inference_roi_size, sw_batch_size=4, predictor=self.forward, overlap=0.5) # print(f"batch['pred'].shape: {batch['pred'].shape}") + + # normalize the logits + batch["pred"] = F.relu(batch["pred"]) / F.relu(batch["pred"]).max() if bool(F.relu(batch["pred"]).max()) else F.relu(batch["pred"]) # # upon fsleyes visualization, observed that very small values need to be set to zero, but NOT fully binarizing the pred # batch["pred"][batch["pred"] < 0.099] = 0.0 From 29825f976167c0b6956848c43e19b386d76db1c8 Mon Sep 17 00:00:00 2001 From: Naga Karthik Date: Fri, 21 Jul 2023 13:12:46 -0400 Subject: [PATCH 025/106] add option for attentionunet --- monai/main.py | 30 +++++++++++++++++++++++++----- 1 file changed, 25 insertions(+), 5 deletions(-) diff --git a/monai/main.py b/monai/main.py index e051eaf8..1aad64bc 100644 --- a/monai/main.py +++ b/monai/main.py @@ -11,15 +11,17 @@ import matplotlib.pyplot as plt from utils import precision_score, recall_score, dice_score, plot_slices, PolyLRScheduler -from losses import SoftDiceLoss +from losses import SoftDiceLoss, DiceCrossEntropyLoss from transforms import train_transforms, val_transforms from monai.utils import set_determinism from monai.inferers import sliding_window_inference -from monai.networks.nets import UNet, DynUNet, BasicUNet, UNETR +from monai.networks.nets import UNet, BasicUNet, UNETR, AttentionUnet from monai.data import (DataLoader, Dataset, CacheDataset, load_decathlon_datalist, decollate_batch) from monai.transforms import (Compose, EnsureType, EnsureTyped, Invertd, SaveImaged, SaveImage) +# TODO: change back to adam, bs=4 and start with lr=1e-3 this time, nrs=2 +# TODO: try one model with compound loss function # create a "model"-agnostic class with PL to use different models class Model(pl.LightningModule): @@ -147,7 +149,10 @@ def test_dataloader(self): # OPTIMIZATION # -------------------------------- def configure_optimizers(self): - optimizer = self.optimizer_class(self.parameters(), lr=self.lr, weight_decay=1e-5) + if self.args.optimizer == "sgd": + optimizer = self.optimizer_class(self.parameters(), lr=self.lr, momentum=0.99, weight_decay=1e-5, nesterov=True) + else: + optimizer = self.optimizer_class(self.parameters(), lr=self.lr, weight_decay=1e-5) # scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.5) scheduler = PolyLRScheduler(optimizer, self.lr, max_steps=self.args.max_epochs) return [optimizer], [scheduler] @@ -438,7 +443,7 @@ def main(args): strides=(2, 2, 2, 2), num_res_units=4, ) - save_exp_id 
=f"{args.model}_nf={args.init_filters}_nrs=4_lr={args.learning_rate}_bs={args.batch_size}" + save_exp_id =f"{args.model}_nf={args.init_filters}_nrs=4_opt={args.optimizer}_lr={args.learning_rate}_bs={args.batch_size}" elif args.model in ["unetr", "UNETR"]: # define image size to be fed to the model img_size = (96, 96, 96) @@ -459,6 +464,21 @@ def main(args): save_exp_id = f"{args.model}_lr={args.learning_rate}" \ f"_fs={args.feature_size}_hs={args.hidden_size}_mlpd={args.mlp_dim}_nh={args.num_heads}" + elif args.model == "attentionunet": + net = AttentionUnet(spatial_dims=3, + in_channels=1, out_channels=1, + channels=( + args.init_filters, + args.init_filters * 2, + args.init_filters * 4, + args.init_filters * 8, + args.init_filters * 16 + ), + strides=(2, 2, 2, 2), + # dropout=0.2, + ) + save_exp_id = f"attn-unet_nf={args.init_filters}_opt={args.optimizer}_lr={args.learning_rate}_bs={args.batch_size}" + # define loss function loss_func = SoftDiceLoss(p=1, smooth=1.0) @@ -565,7 +585,7 @@ def main(args): parser = argparse.ArgumentParser(description='Script for training custom models for SCI Lesion Segmentation.') # Arguments for model, data, and training and saving parser.add_argument('-m', '--model', - choices=['unet', 'UNet', 'unetr', 'UNETR', 'segresnet', 'SegResNet'], + choices=['unet', 'UNet', 'unetr', 'UNETR', 'attentionunet'], default='unet', type=str, help='Model type to be used') # dataset parser.add_argument('-nspv', '--num_samples_per_volume', default=4, type=int, help="Number of samples to crop per volume") From 6b43c599f7b48f730a91f282e857361944d721e1 Mon Sep 17 00:00:00 2001 From: Naga Karthik Date: Tue, 25 Jul 2023 13:25:13 -0400 Subject: [PATCH 026/106] remove todo; update args for path-data --- monai/create_msd_data.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/monai/create_msd_data.py b/monai/create_msd_data.py index c9f15c41..40199c85 100644 --- a/monai/create_msd_data.py +++ b/monai/create_msd_data.py @@ -8,16 +8,14 @@ from loguru import logger from sklearn.model_selection import train_test_split -# TODO: split the data using ivadomed joblib file - -root = "/home/GRAMES.POLYMTL.CA/u114716/datasets/spine-generic_uncropped" +# root = "/home/GRAMES.POLYMTL.CA/u114716/datasets/spine-generic_uncropped" parser = argparse.ArgumentParser(description='Code for creating k-fold splits of the spine-generic dataset.') parser.add_argument('--seed', default=42, type=int, help="Seed for reproducibility") parser.add_argument('-ncvf', '--num-cv-folds', default=5, type=int, help="[1-k] To create a k-fold dataset for cross validation, 0 for single file with all subjects") -parser.add_argument('-pd', '--path-data', default=root, type=str, help='Path to the data set directory') +parser.add_argument('-pd', '--path-data', required=True, type=str, help='Path to the data set directory') parser.add_argument('-pj', '--path-joblib', help='Path to joblib file from ivadomed containing the dataset splits.', default=None, type=str) parser.add_argument('-po', '--path-out', type=str, help='Path to the output directory where dataset json is saved') From ac2572cf572d6b0a447430b966f98a9caf8f1ff3 Mon Sep 17 00:00:00 2001 From: Naga Karthik Date: Tue, 25 Jul 2023 13:26:17 -0400 Subject: [PATCH 027/106] remove RobustCELoss; update DiceCELoss --- monai/losses.py | 24 +++--------------------- 1 file changed, 3 insertions(+), 21 deletions(-) diff --git a/monai/losses.py b/monai/losses.py index 5aeb61c4..835c9568 100644 --- a/monai/losses.py +++ b/monai/losses.py @@ -34,33 +34,15 @@ def 
forward(self, logits, labels):
         return loss
 
 
-class RobustCrossEntropyLoss(nn.CrossEntropyLoss):
-    """
-    this is just a compatibility layer because my target tensor is float and has an extra dimension
-
-    input must be logits, not probabilities!
-    adapted from: https://github.com/MIC-DKFZ/nnUNet/blob/master/nnunetv2/training/loss/robust_ce_loss.py
-    """
-    def forward(self, input: Tensor, target: Tensor) -> Tensor:
-
-        # binarize soft labels with threshold 0.5 before using cross-entropy
-        target = torch.where(target > 0.5, torch.ones_like(target), torch.zeros_like(target))
-
-        if len(target.shape) == len(input.shape):
-            assert target.shape[1] == 1
-            target = target[:, 0]
-
-        return super().forward(input, target.long())
-
-
 class DiceCrossEntropyLoss(nn.Module):
     def __init__(self, weight_ce=1.0, weight_dice=1.0):
-        super(DiceCrossEntropyLoss).__init__()
+        super(DiceCrossEntropyLoss, self).__init__()
         self.ce_weight = weight_ce
         self.dice_weight = weight_dice
 
         self.dice_loss = SoftDiceLoss()
-        self.ce_loss = RobustCrossEntropyLoss()
+        # self.ce_loss = RobustCrossEntropyLoss()
+        self.ce_loss = nn.CrossEntropyLoss()
 
     def forward(self, preds, labels):
         '''
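The one-line `super()` fix above is worth a second look. A minimal, self-contained sketch (a toy class, not any file from this repo) of why the old one-argument form breaks an `nn.Module` subclass:

    import torch.nn as nn

    class ToyLoss(nn.Module):
        def __init__(self):
            # `super(ToyLoss).__init__()` builds an *unbound* super object, so
            # nn.Module.__init__ never runs for this instance, and the attribute
            # assignment below then raises
            # "cannot assign module before Module.__init__() call"
            super(ToyLoss, self).__init__()   # the corrected two-argument form
            self.ce = nn.CrossEntropyLoss()

    ToyLoss()  # constructs cleanly with the fix

From 8ddddeb029fc17c1009d14732ca303d2b472d28a Mon Sep 17 00:00:00 2001
From: Naga Karthik
Date: Tue, 25 Jul 2023 13:27:33 -0400
Subject: [PATCH 028/106] Notable changes: 1. Change order of preprocessing as
 per nnunet 2. Choose padding and cropping size based on iso-resampled median
 shape of images 3. Add more data-augmentation transforms

---
 monai/transforms.py | 41 +++++++++++++++++++++++++----------------
 1 file changed, 25 insertions(+), 16 deletions(-)

diff --git a/monai/transforms.py b/monai/transforms.py
index 266e0245..752fa4aa 100644
--- a/monai/transforms.py
+++ b/monai/transforms.py
@@ -1,40 +1,49 @@
+import numpy as np
 from monai.transforms import (SpatialPadd, Compose, CropForegroundd, LoadImaged, RandFlipd,
-                              RandCropByPosNegLabeld, Spacingd, RandRotate90d, ToTensord, NormalizeIntensityd,
-                              EnsureType, RandWeightedCropd, HistogramNormalized, EnsureTyped, Invertd, SaveImaged,
-                              EnsureChannelFirstd, CenterSpatialCropd, RandSpatialCropSamplesd, Orientationd,
-                              Rand3DElasticd, RandBiasFieldd)
+                              RandCropByPosNegLabeld, Spacingd, RandRotated, NormalizeIntensityd,
+                              RandWeightedCropd, RandAdjustContrastd, EnsureChannelFirstd, RandGaussianNoised,
+                              Orientationd, Rand3DElasticd, RandBiasFieldd)
 
 # median image size in voxels - taken from nnUNet
-# median_size = (123, 255, 214)
-# so pad with this size
+# median_size = (123, 255, 214) # so pad with this size
+# median_size after 1mm isotropic resampling
+# median_size = [ 192. 228. 106.]
+
+# Order in which nnunet does preprocessing:
+# 1. Crop to non-zero
+# 2. Normalization
+# 3. 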
Resample to target spacing def train_transforms(crop_size, num_samples_pv, lbl_key="label"): return Compose([ + # pre-processing LoadImaged(keys=["image", lbl_key]), EnsureChannelFirstd(keys=["image", lbl_key]), - # Orientationd(keys=["image", lbl_key], axcodes="RPI"), - Spacingd(keys=["image", "label"], pixdim=(1.0, 1.0, 1.0), mode=("bilinear", "bilinear"),), CropForegroundd(keys=["image", lbl_key], source_key="image"), # crops >0 values with a bounding box - SpatialPadd(keys=["image", lbl_key], spatial_size=(64, 128, 128), method="symmetric"), - # SpatialPadd(keys=["image", lbl_key], spatial_size=(123, 255, 214), method="symmetric"), - # RandSpatialCropSamplesd(keys=["image", lbl_key], roi_size=crop_size, num_samples=num_samples_pv, random_center=True, random_size=False), + NormalizeIntensityd(keys=["image"], nonzero=False, channel_wise=False), + Spacingd(keys=["image", "label"], pixdim=(1.0, 1.0, 1.0), mode=("bilinear", "bilinear"),), + # data-augmentation + SpatialPadd(keys=["image", lbl_key], spatial_size=(192, 228, 106), method="symmetric"), # NOTE: used with neg together to calculate the ratio pos / (pos + neg) for the probability to pick a # foreground voxel as a center rather than a background voxel. RandCropByPosNegLabeld(keys=["image", "label"], label_key="label", spatial_size=crop_size, pos=3, neg=1, num_samples=num_samples_pv, # if num_samples=4, then 4 samples/image are randomly generated image_key="image", image_threshold=0.), - Rand3DElasticd(keys=["image", lbl_key], sigma_range=(3.5, 5.5), magnitude_range=(25, 35), prob=0.5), + RandGaussianNoised(keys=["image"], mean=0.0, std=0.1, prob=0.1), + Rand3DElasticd(keys=["image", lbl_key], sigma_range=(3.5, 5.5), magnitude_range=(25, 35), prob=0.25), RandBiasFieldd(keys=["image", lbl_key], coeff_range=(0.0, 0.5), prob=0.25, degree=3), + RandAdjustContrastd(keys=["image"], gamma=(0.7, 1.5), prob=0.2), RandFlipd(keys=["image", lbl_key], spatial_axis=[0], prob=0.50,), RandFlipd(keys=["image", lbl_key], spatial_axis=[1], prob=0.50,), RandFlipd(keys=["image", lbl_key],spatial_axis=[2],prob=0.50,), - # RandRotate90d(keys=["image", lbl_key], prob=0.10, max_k=3,), + RandRotated(keys=["image", lbl_key], mode=("bilinear", "bilinear"), prob=0.1, + range_x=(-30. / 360 * 2. * np.pi, 30. / 360 * 2. * np.pi), # NOTE: -pi/6 to pi/6 + range_y=(-30. / 360 * 2. * np.pi, 30. / 360 * 2. * np.pi), + range_z=(-30. / 360 * 2. * np.pi, 30. / 360 * 2. 
* np.pi)), + # re-orientation Orientationd(keys=["image", lbl_key], axcodes="RPI"), # NOTE: if not using it here, then it results in collation error - # HistogramNormalized(keys=["image"], mask=None), - NormalizeIntensityd(keys=["image"], nonzero=False, channel_wise=True), - # ToTensord(keys=["image", lbl_key]), ]) def val_transforms(lbl_key="label"): From 36da5e162d8d9f57077c934a48cc9f0d3fb1c485 Mon Sep 17 00:00:00 2001 From: Naga Karthik Date: Tue, 25 Jul 2023 13:31:36 -0400 Subject: [PATCH 029/106] remove attn-unet; update save_exp_id to contain more hyperparams --- monai/main.py | 33 +++++++++------------------------ 1 file changed, 9 insertions(+), 24 deletions(-) diff --git a/monai/main.py b/monai/main.py index 1aad64bc..89f3133b 100644 --- a/monai/main.py +++ b/monai/main.py @@ -20,9 +20,6 @@ from monai.data import (DataLoader, Dataset, CacheDataset, load_decathlon_datalist, decollate_batch) from monai.transforms import (Compose, EnsureType, EnsureTyped, Invertd, SaveImaged, SaveImage) -# TODO: change back to adam, bs=4 and start with lr=1e-3 this time, nrs=2 -# TODO: try one model with compound loss function - # create a "model"-agnostic class with PL to use different models class Model(pl.LightningModule): def __init__(self, args, data_root, fold_num, net, loss_function, optimizer_class, @@ -51,8 +48,8 @@ def __init__(self, args, data_root, fold_num, net, loss_function, optimizer_clas self.best_val_dice, self.best_val_epoch = 0, 0 # define cropping and padding dimensions - self.voxel_cropping_size = (64, 128, 128) # (80, 192, 160) taken from nnUNet_plans.json - self.inference_roi_size = (64, 128, 128) # (80, 192, 160) + self.voxel_cropping_size = (160, 224, 96) # (80, 192, 160) taken from nnUNet_plans.json + self.inference_roi_size = (160, 224, 96) # define post-processing transforms for validation, nothing fancy just making sure that it's a tensor (default) self.val_post_pred = Compose([EnsureType()]) @@ -107,8 +104,8 @@ def prepare_data(self): test_files = load_decathlon_datalist(dataset, True, "test") if args.debug: - train_files = train_files[:2] - val_files = val_files[:2] + train_files = train_files[:10] + val_files = val_files[:10] test_files = test_files[:6] self.train_ds = CacheDataset(data=train_files, transform=transforms_train, cache_rate=0.25, num_workers=4) @@ -443,7 +440,9 @@ def main(args): strides=(2, 2, 2, 2), num_res_units=4, ) - save_exp_id =f"{args.model}_nf={args.init_filters}_nrs=4_opt={args.optimizer}_lr={args.learning_rate}_bs={args.batch_size}" + patch_size = "160x224x96" + save_exp_id =f"{args.model}_nf={args.init_filters}_nrs=4_opt={args.optimizer}_lr={args.learning_rate}" \ + f"_diceCE_bs={args.batch_size}_{patch_size}" elif args.model in ["unetr", "UNETR"]: # define image size to be fed to the model img_size = (96, 96, 96) @@ -464,23 +463,9 @@ def main(args): save_exp_id = f"{args.model}_lr={args.learning_rate}" \ f"_fs={args.feature_size}_hs={args.hidden_size}_mlpd={args.mlp_dim}_nh={args.num_heads}" - elif args.model == "attentionunet": - net = AttentionUnet(spatial_dims=3, - in_channels=1, out_channels=1, - channels=( - args.init_filters, - args.init_filters * 2, - args.init_filters * 4, - args.init_filters * 8, - args.init_filters * 16 - ), - strides=(2, 2, 2, 2), - # dropout=0.2, - ) - save_exp_id = f"attn-unet_nf={args.init_filters}_opt={args.optimizer}_lr={args.learning_rate}_bs={args.batch_size}" - # define loss function - loss_func = SoftDiceLoss(p=1, smooth=1.0) + # loss_func = SoftDiceLoss(p=1, smooth=1.0) + loss_func = 
DiceCrossEntropyLoss(weight_ce=1.0, weight_dice=1.0) # TODO: move this inside the for loop when using more folds timestamp = datetime.now().strftime(f"%Y%m%d-%H%M") # prints in YYYYMMDD-HHMMSS format From ec95391cdc415327b12d762f902afd1217a830cf Mon Sep 17 00:00:00 2001 From: Naga Karthik Date: Tue, 25 Jul 2023 15:16:31 -0400 Subject: [PATCH 030/106] remove individual axis flip in RandFlipd causing collation error --- monai/transforms.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/monai/transforms.py b/monai/transforms.py index 752fa4aa..1cafc3a5 100644 --- a/monai/transforms.py +++ b/monai/transforms.py @@ -33,17 +33,15 @@ def train_transforms(crop_size, num_samples_pv, lbl_key="label"): image_key="image", image_threshold=0.), RandGaussianNoised(keys=["image"], mean=0.0, std=0.1, prob=0.1), Rand3DElasticd(keys=["image", lbl_key], sigma_range=(3.5, 5.5), magnitude_range=(25, 35), prob=0.25), - RandBiasFieldd(keys=["image", lbl_key], coeff_range=(0.0, 0.5), prob=0.25, degree=3), + RandBiasFieldd(keys=["image", lbl_key], coeff_range=(0.0, 0.5), degree=3, prob=0.25), RandAdjustContrastd(keys=["image"], gamma=(0.7, 1.5), prob=0.2), - RandFlipd(keys=["image", lbl_key], spatial_axis=[0], prob=0.50,), - RandFlipd(keys=["image", lbl_key], spatial_axis=[1], prob=0.50,), - RandFlipd(keys=["image", lbl_key],spatial_axis=[2],prob=0.50,), - RandRotated(keys=["image", lbl_key], mode=("bilinear", "bilinear"), prob=0.1, + RandFlipd(keys=["image", lbl_key], spatial_axis=None, prob=0.5,), + RandRotated(keys=["image", lbl_key], mode=("bilinear", "bilinear"), prob=0.2, range_x=(-30. / 360 * 2. * np.pi, 30. / 360 * 2. * np.pi), # NOTE: -pi/6 to pi/6 range_y=(-30. / 360 * 2. * np.pi, 30. / 360 * 2. * np.pi), range_z=(-30. / 360 * 2. * np.pi, 30. / 360 * 2. 
* np.pi)), - # re-orientation - Orientationd(keys=["image", lbl_key], axcodes="RPI"), # NOTE: if not using it here, then it results in collation error + # # re-orientation + # Orientationd(keys=["image", lbl_key], axcodes="RPI"), # NOTE: if not using it here, then it results in collation error ]) def val_transforms(lbl_key="label"): From 1e5610b4242212345e425ee58508fbe5f16e9a5b Mon Sep 17 00:00:00 2001 From: Naga Karthik Date: Tue, 25 Jul 2023 15:44:43 -0400 Subject: [PATCH 031/106] bring changes from train to val transforms --- monai/transforms.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/monai/transforms.py b/monai/transforms.py index 1cafc3a5..d019f371 100644 --- a/monai/transforms.py +++ b/monai/transforms.py @@ -48,11 +48,9 @@ def val_transforms(lbl_key="label"): return Compose([ LoadImaged(keys=["image", lbl_key]), EnsureChannelFirstd(keys=["image", lbl_key]), - Orientationd(keys=["image", lbl_key], axcodes="RPI"), - Spacingd(keys=["image", lbl_key], pixdim=(1.0, 1.0, 1.0), mode=("bilinear", "bilinear"),), + # Orientationd(keys=["image", lbl_key], axcodes="RPI"), CropForegroundd(keys=["image", lbl_key], source_key="image"), + NormalizeIntensityd(keys=["image"], nonzero=False, channel_wise=False), + Spacingd(keys=["image", lbl_key], pixdim=(1.0, 1.0, 1.0), mode=("bilinear", "bilinear"),), # SpatialPadd(keys=["image", lbl_key], spatial_size=(123, 255, 214), method="symmetric"), - # HistogramNormalized(keys=["image"], mask=None), - NormalizeIntensityd(keys=["image"], nonzero=False, channel_wise=True), - # ToTensord(keys=["image", lbl_key]), ]) From d3fadaf32c7b752cb619d928987b0e87e0c21d76 Mon Sep 17 00:00:00 2001 From: Naga Karthik Date: Mon, 31 Jul 2023 23:34:12 -0400 Subject: [PATCH 032/106] lower prob of RandFlip transform --- monai/transforms.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/transforms.py b/monai/transforms.py index d019f371..d417af6a 100644 --- a/monai/transforms.py +++ b/monai/transforms.py @@ -35,7 +35,7 @@ def train_transforms(crop_size, num_samples_pv, lbl_key="label"): Rand3DElasticd(keys=["image", lbl_key], sigma_range=(3.5, 5.5), magnitude_range=(25, 35), prob=0.25), RandBiasFieldd(keys=["image", lbl_key], coeff_range=(0.0, 0.5), degree=3, prob=0.25), RandAdjustContrastd(keys=["image"], gamma=(0.7, 1.5), prob=0.2), - RandFlipd(keys=["image", lbl_key], spatial_axis=None, prob=0.5,), + RandFlipd(keys=["image", lbl_key], spatial_axis=None, prob=0.2,), RandRotated(keys=["image", lbl_key], mode=("bilinear", "bilinear"), prob=0.2, range_x=(-30. / 360 * 2. * np.pi, 30. / 360 * 2. * np.pi), # NOTE: -pi/6 to pi/6 range_y=(-30. / 360 * 2. * np.pi, 30. / 360 * 2. 
* np.pi), From e551206a21514ff5a4094de573867a4e8800b04c Mon Sep 17 00:00:00 2001 From: Naga Karthik Date: Mon, 7 Aug 2023 09:47:56 -0400 Subject: [PATCH 033/106] minor changes --- monai/create_msd_data.py | 2 +- monai/transforms.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/monai/create_msd_data.py b/monai/create_msd_data.py index 40199c85..488d90e9 100644 --- a/monai/create_msd_data.py +++ b/monai/create_msd_data.py @@ -152,7 +152,7 @@ if args.path_joblib is not None: # load information from the joblib to match train and test subjects joblib_file = os.path.join(args.path_joblib, 'split_datasets_all_seed=15.joblib') - splits = joblib.load("split_datasets_all_seed=15.joblib") + splits = joblib.load(joblib_file) # get the subjects from the joblib file train_subjects = sorted(list(set([sub.split('_')[0] for sub in splits['train']]))) val_subjects = sorted(list(set([sub.split('_')[0] for sub in splits['valid']]))) diff --git a/monai/transforms.py b/monai/transforms.py index d417af6a..1a2c029c 100644 --- a/monai/transforms.py +++ b/monai/transforms.py @@ -22,7 +22,7 @@ def train_transforms(crop_size, num_samples_pv, lbl_key="label"): EnsureChannelFirstd(keys=["image", lbl_key]), CropForegroundd(keys=["image", lbl_key], source_key="image"), # crops >0 values with a bounding box NormalizeIntensityd(keys=["image"], nonzero=False, channel_wise=False), - Spacingd(keys=["image", "label"], pixdim=(1.0, 1.0, 1.0), mode=("bilinear", "bilinear"),), + Spacingd(keys=["image", "label"], pixdim=(1.0, 1.0, 1.0), mode=("bilinear", "nearest"),), # data-augmentation SpatialPadd(keys=["image", lbl_key], spatial_size=(192, 228, 106), method="symmetric"), # NOTE: used with neg together to calculate the ratio pos / (pos + neg) for the probability to pick a @@ -36,7 +36,7 @@ def train_transforms(crop_size, num_samples_pv, lbl_key="label"): RandBiasFieldd(keys=["image", lbl_key], coeff_range=(0.0, 0.5), degree=3, prob=0.25), RandAdjustContrastd(keys=["image"], gamma=(0.7, 1.5), prob=0.2), RandFlipd(keys=["image", lbl_key], spatial_axis=None, prob=0.2,), - RandRotated(keys=["image", lbl_key], mode=("bilinear", "bilinear"), prob=0.2, + RandRotated(keys=["image", lbl_key], mode=("bilinear", "nearest"), prob=0.2, range_x=(-30. / 360 * 2. * np.pi, 30. / 360 * 2. * np.pi), # NOTE: -pi/6 to pi/6 range_y=(-30. / 360 * 2. * np.pi, 30. / 360 * 2. * np.pi), range_z=(-30. / 360 * 2. * np.pi, 30. / 360 * 2. 
* np.pi)), @@ -51,6 +51,6 @@ def val_transforms(lbl_key="label"): # Orientationd(keys=["image", lbl_key], axcodes="RPI"), CropForegroundd(keys=["image", lbl_key], source_key="image"), NormalizeIntensityd(keys=["image"], nonzero=False, channel_wise=False), - Spacingd(keys=["image", lbl_key], pixdim=(1.0, 1.0, 1.0), mode=("bilinear", "bilinear"),), + Spacingd(keys=["image", lbl_key], pixdim=(1.0, 1.0, 1.0), mode=("bilinear", "nearest"),), # SpatialPadd(keys=["image", lbl_key], spatial_size=(123, 255, 214), method="symmetric"), ]) From 31f0b59fad3fa8145bef4c4394e132a6546799da Mon Sep 17 00:00:00 2001 From: Naga Karthik Date: Mon, 7 Aug 2023 09:48:46 -0400 Subject: [PATCH 034/106] add Modified3DUNet model from ivadomed --- monai/models.py | 246 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 246 insertions(+) create mode 100644 monai/models.py diff --git a/monai/models.py b/monai/models.py new file mode 100644 index 00000000..8d11516f --- /dev/null +++ b/monai/models.py @@ -0,0 +1,246 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +def conv_norm_lrelu(feat_in, feat_out): + """Conv3D + InstanceNorm3D + LeakyReLU block""" + return nn.Sequential( + nn.Conv3d(feat_in, feat_out, kernel_size=3, stride=1, padding=1, bias=False), + nn.InstanceNorm3d(feat_out), + nn.LeakyReLU() + ) + + +def norm_lrelu_conv(feat_in, feat_out): + """InstanceNorm3D + LeakyReLU + Conv3D block""" + return nn.Sequential( + nn.InstanceNorm3d(feat_in), + nn.LeakyReLU(), + nn.Conv3d(feat_in, feat_out, kernel_size=3, stride=1, padding=1, bias=False) + ) + + +def lrelu_conv(feat_in, feat_out): + """LeakyReLU + Conv3D block""" + return nn.Sequential( + nn.LeakyReLU(), + nn.Conv3d(feat_in, feat_out, kernel_size=3, stride=1, padding=1, bias=False) + ) + + +def norm_lrelu_upscale_conv_norm_lrelu(feat_in, feat_out): + """InstanceNorm3D + LeakyReLU + 2X Upsample + Conv3D + InstanceNorm3D + LeakyReLU block""" + return nn.Sequential( + nn.InstanceNorm3d(feat_in), + nn.LeakyReLU(), + nn.Upsample(scale_factor=2, mode='nearest'), + nn.Conv3d(feat_in, feat_out, kernel_size=3, stride=1, padding=1, bias=False), + nn.InstanceNorm3d(feat_out), + nn.LeakyReLU() + ) + +# ---------------------------- ModifiedUNet3D Encoder Implementation ----------------------------- +class ModifiedUNet3DEncoder(nn.Module): + """Encoder for ModifiedUNet3D. 
Adapted from ivadomed.models""" + def __init__(self, in_channels=1, base_n_filter=8): + super(ModifiedUNet3DEncoder, self).__init__() + + # Initialize common operations + self.lrelu = nn.LeakyReLU() + self.dropout3d = nn.Dropout3d(p=0.5) + self.upsample = nn.Upsample(scale_factor=2, mode='nearest') + + # Level 1 context pathway + self.conv3d_c1_1 = nn.Conv3d(in_channels, base_n_filter, kernel_size=3, stride=1, padding=1, bias=False) + self.conv3d_c1_2 = nn.Conv3d(base_n_filter, base_n_filter, kernel_size=3, stride=1, padding=1, bias=False) + self.lrelu_conv_c1 = lrelu_conv(base_n_filter, base_n_filter) + self.inorm3d_c1 = nn.InstanceNorm3d(base_n_filter) + + # Level 2 context pathway + self.conv3d_c2 = nn.Conv3d(base_n_filter, base_n_filter * 2, kernel_size=3, stride=2, padding=1, bias=False) + self.norm_lrelu_conv_c2 = norm_lrelu_conv(base_n_filter * 2, base_n_filter * 2) + self.inorm3d_c2 = nn.InstanceNorm3d(base_n_filter * 2) + + # Level 3 context pathway + self.conv3d_c3 = nn.Conv3d(base_n_filter * 2, base_n_filter * 4, kernel_size=3, stride=2, padding=1, bias=False) + self.norm_lrelu_conv_c3 = norm_lrelu_conv(base_n_filter * 4, base_n_filter * 4) + self.inorm3d_c3 = nn.InstanceNorm3d(base_n_filter * 4) + + # Level 4 context pathway + self.conv3d_c4 = nn.Conv3d(base_n_filter * 4, base_n_filter * 8, kernel_size=3, stride=2, padding=1, bias=False) + self.norm_lrelu_conv_c4 = norm_lrelu_conv(base_n_filter * 8, base_n_filter * 8) + self.inorm3d_c4 = nn.InstanceNorm3d(base_n_filter * 8) + + # Level 5 context pathway, level 0 localization pathway + self.conv3d_c5 = nn.Conv3d(base_n_filter * 8, base_n_filter * 16, kernel_size=3, stride=2, padding=1, bias=False) + self.norm_lrelu_conv_c5 = norm_lrelu_conv(base_n_filter * 16, base_n_filter * 16) + self.norm_lrelu_upscale_conv_norm_lrelu_l0 = norm_lrelu_upscale_conv_norm_lrelu(base_n_filter * 16, base_n_filter * 8) + + def forward(self, x): + # Level 1 context pathway + out = self.conv3d_c1_1(x) + residual_1 = out + out = self.lrelu(out) + out = self.conv3d_c1_2(out) + out = self.dropout3d(out) + out = self.lrelu_conv_c1(out) + + # Element Wise Summation + out += residual_1 + context_1 = self.lrelu(out) + out = self.inorm3d_c1(out) + out = self.lrelu(out) + + # Level 2 context pathway + out = self.conv3d_c2(out) + residual_2 = out + out = self.norm_lrelu_conv_c2(out) + out = self.dropout3d(out) + out = self.norm_lrelu_conv_c2(out) + out += residual_2 + out = self.inorm3d_c2(out) + out = self.lrelu(out) + context_2 = out + + # Level 3 context pathway + out = self.conv3d_c3(out) + residual_3 = out + out = self.norm_lrelu_conv_c3(out) + out = self.dropout3d(out) + out = self.norm_lrelu_conv_c3(out) + out += residual_3 + out = self.inorm3d_c3(out) + out = self.lrelu(out) + context_3 = out + + # Level 4 context pathway + out = self.conv3d_c4(out) + residual_4 = out + out = self.norm_lrelu_conv_c4(out) + out = self.dropout3d(out) + out = self.norm_lrelu_conv_c4(out) + out += residual_4 + out = self.inorm3d_c4(out) + out = self.lrelu(out) + context_4 = out + + # Level 5 + out = self.conv3d_c5(out) + residual_5 = out + out = self.norm_lrelu_conv_c5(out) + out = self.dropout3d(out) + out = self.norm_lrelu_conv_c5(out) + out += residual_5 + + out = self.norm_lrelu_upscale_conv_norm_lrelu_l0(out) + + context_features = [context_1, context_2, context_3, context_4] + + return out, context_features + + +# ---------------------------- ModifiedUNet3D Decoder Implementation ----------------------------- +class ModifiedUNet3DDecoder(nn.Module): + """Decoder for 
ModifiedUNet3D. Adapted from ivadomed.models""" + def __init__(self, n_classes=1, base_n_filter=8): + super(ModifiedUNet3DDecoder, self).__init__() + + # Initialize common operations + self.lrelu = nn.LeakyReLU() + self.dropout3d = nn.Dropout3d(p=0.5) + self.upsample = nn.Upsample(scale_factor=2, mode='nearest') + + self.conv3d_l0 = nn.Conv3d(base_n_filter * 8, base_n_filter * 8, kernel_size=1, stride=1, padding=0, bias=False) + self.inorm3d_l0 = nn.InstanceNorm3d(base_n_filter * 8) + + # Level 1 localization pathway + self.conv_norm_lrelu_l1 = conv_norm_lrelu(base_n_filter * 16, base_n_filter * 16) + self.conv3d_l1 = nn.Conv3d(base_n_filter * 16, base_n_filter * 8, kernel_size=1, stride=1, padding=0, bias=False) + self.norm_lrelu_upscale_conv_norm_lrelu_l1 = norm_lrelu_upscale_conv_norm_lrelu(base_n_filter * 8, base_n_filter * 4) + + # Level 2 localization pathway + self.conv_norm_lrelu_l2 = conv_norm_lrelu(base_n_filter * 8, base_n_filter * 8) + self.conv3d_l2 = nn.Conv3d(base_n_filter * 8, base_n_filter * 4, kernel_size=1, stride=1, padding=0, bias=False) + self.norm_lrelu_upscale_conv_norm_lrelu_l2 = norm_lrelu_upscale_conv_norm_lrelu(base_n_filter * 4, base_n_filter * 2) + + # Level 3 localization pathway + self.conv_norm_lrelu_l3 = conv_norm_lrelu(base_n_filter * 4, base_n_filter * 4) + self.conv3d_l3 = nn.Conv3d(base_n_filter * 4, base_n_filter * 2, kernel_size=1, stride=1, padding=0, bias=False) + self.norm_lrelu_upscale_conv_norm_lrelu_l3 = norm_lrelu_upscale_conv_norm_lrelu(base_n_filter * 2, base_n_filter) + + # Level 4 localization pathway + self.conv_norm_lrelu_l4 = conv_norm_lrelu(base_n_filter * 2, base_n_filter * 2) + self.conv3d_l4 = nn.Conv3d(base_n_filter * 2, n_classes, kernel_size=1, stride=1, padding=0, bias=False) + + self.ds2_1x1_conv3d = nn.Conv3d(base_n_filter * 8, n_classes, kernel_size=1, stride=1, padding=0, bias=False) + self.ds3_1x1_conv3d = nn.Conv3d(base_n_filter * 4, n_classes, kernel_size=1, stride=1, padding=0, bias=False) + + def forward(self, x, context_features): + # Get context features from the encoder + context_1, context_2, context_3, context_4 = context_features + + out = self.conv3d_l0(x) + out = self.inorm3d_l0(out) + out = self.lrelu(out) + + # Level 1 localization pathway + out = torch.cat([out, context_4], dim=1) + out = self.conv_norm_lrelu_l1(out) + out = self.conv3d_l1(out) + out = self.norm_lrelu_upscale_conv_norm_lrelu_l1(out) + + # Level 2 localization pathway + out = torch.cat([out, context_3], dim=1) + out = self.conv_norm_lrelu_l2(out) + ds2 = out + out = self.conv3d_l2(out) + out = self.norm_lrelu_upscale_conv_norm_lrelu_l2(out) + + # Level 3 localization pathway + out = torch.cat([out, context_2], dim=1) + out = self.conv_norm_lrelu_l3(out) + ds3 = out + out = self.conv3d_l3(out) + out = self.norm_lrelu_upscale_conv_norm_lrelu_l3(out) + + # Level 4 localization pathway + out = torch.cat([out, context_1], dim=1) + out = self.conv_norm_lrelu_l4(out) + out_pred = self.conv3d_l4(out) + + ds2_1x1_conv = self.ds2_1x1_conv3d(ds2) + ds1_ds2_sum_upscale = self.upsample(ds2_1x1_conv) + ds3_1x1_conv = self.ds3_1x1_conv3d(ds3) + ds1_ds2_sum_upscale_ds3_sum = ds1_ds2_sum_upscale + ds3_1x1_conv + ds1_ds2_sum_upscale_ds3_sum_upscale = self.upsample(ds1_ds2_sum_upscale_ds3_sum) + + out = out_pred + ds1_ds2_sum_upscale_ds3_sum_upscale + + # # Final Activation Layer + # out = F.relu(out) / F.relu(out).max() if bool(F.relu(out).max()) else F.relu(out) + # out = out.squeeze() + + return out # this is just the logits, not the probablities + + +# 
---------------------------- ModifiedUNet3D Implementation -----------------------------
+class ModifiedUNet3D(nn.Module):
+    """ModifiedUNet3D with Encoder + Decoder. Adapted from ivadomed.models"""
+    def __init__(self, in_channels=1, out_channels=1, init_filters=8):
+        super(ModifiedUNet3D, self).__init__()
+        self.unet_encoder = ModifiedUNet3DEncoder(in_channels=in_channels, base_n_filter=init_filters)
+        self.unet_decoder = ModifiedUNet3DDecoder(n_classes=out_channels, base_n_filter=init_filters)
+
+    def forward(self, x):
+
+        x, context_features = self.unet_encoder(x)
+        # x: (B, 8 * F, SV // 8, SV // 8, SV // 8)
+        # context_features: [4]
+        #   0 -> (B, F, SV, SV, SV)
+        #   1 -> (B, 2 * F, SV / 2, SV / 2, SV / 2)
+        #   2 -> (B, 4 * F, SV / 4, SV / 4, SV / 4)
+        #   3 -> (B, 8 * F, SV / 8, SV / 8, SV / 8)
+
+        seg_logits = self.unet_decoder(x, context_features)
+
+        return seg_logits
\ No newline at end of file
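A quick shape check of the assembled model. A minimal sketch: the 32^3 input size is arbitrary (chosen so the four stride-2 downsampling steps divide it evenly), and it assumes the monai/models.py added by this patch is importable:

    import torch
    from models import ModifiedUNet3D

    net = ModifiedUNet3D(in_channels=1, out_channels=1, init_filters=8)
    x = torch.randn(2, 1, 32, 32, 32)   # (B, C, H, W, D)
    logits = net(x)                     # raw logits, no final activation
    print(logits.shape)                 # torch.Size([2, 1, 32, 32, 32])

From 8fa0f4979e6163ae27dd3b1f75fcea4d3994c4ab Mon Sep 17 00:00:00 2001
From: Naga Karthik
Date: Mon, 7 Aug 2023 09:49:28 -0400
Subject: [PATCH 035/106] add initial version of AdapWingLoss

---
 monai/losses.py | 77 +++++++++++++++++++++++++++++++++++++++++++++++--
 1 file changed, 75 insertions(+), 2 deletions(-)

diff --git a/monai/losses.py b/monai/losses.py
index 835c9568..46eb8c74 100644
--- a/monai/losses.py
+++ b/monai/losses.py
@@ -1,7 +1,8 @@
 import torch
 import torch.nn as nn
-from torch import Tensor
 import torch.nn.functional as F
+import scipy
+import numpy as np
 
 
 # TODO: also check out nnUNet's implementation of soft-dice loss (if required)
@@ -58,4 +59,76 @@ def forward(self, preds, labels):
         dice_loss = self.dice_loss(preds, labels)
 
         loss = self.ce_weight * ce_loss + self.dice_weight * dice_loss
-        return loss
\ No newline at end of file
+        return loss
+
+
+class AdapWingLoss(nn.Module):
+    """
+    Adaptive Wing loss
+    Used for heatmap ground truth.
+    Adapted from: https://github.com/ivadomed/ivadomed/blob/master/ivadomed/losses.py#L341
+
+    .. seealso::
+        Wang, Xinyao, Liefeng Bo, and Li Fuxin. "Adaptive wing loss for robust face alignment via heatmap regression."
+        Proceedings of the IEEE International Conference on Computer Vision. 2019.
+
+    Args:
+        theta (float): Threshold between the linear and non-linear loss.
+        alpha (float): Used to adapt loss shape to input shape and make loss smooth at 0 (background).
+        It needs to be slightly above 2 to maintain ideal properties.
+        omega (float): Multiplying factor for the non-linear part of the loss.
+        epsilon (float): factor to avoid gradient explosion. 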
It must not be too small.
+    """
+
+    def __init__(self, theta=0.5, alpha=2.1, omega=14, epsilon=1, reduction='sum'):
+        self.theta = theta
+        self.alpha = alpha
+        self.omega = omega
+        self.epsilon = epsilon
+        self.reduction = reduction
+        super(AdapWingLoss, self).__init__()
+
+    def forward(self, input, target):
+        batch_size = target.size()[0]
+        hm_num = target.size()[1]
+
+        mask = torch.zeros_like(target)
+        kernel = scipy.ndimage.generate_binary_structure(2, 2)
+        # For 3D segmentation tasks
+        if len(input.shape) == 5:
+            kernel = scipy.ndimage.generate_binary_structure(3, 2)
+
+        for i in range(batch_size):
+            img_list = list()
+            img_list.append(np.round(target[i].cpu().numpy() * 255))
+            img_merge = np.concatenate(img_list)
+            img_dilate = scipy.ndimage.binary_opening(img_merge, np.expand_dims(kernel, axis=0))
+            img_dilate[img_dilate < 51] = 1  # 0*omega+1
+            img_dilate[img_dilate >= 51] = 1 + self.omega  # 1*omega+1
+            img_dilate = np.array(img_dilate, dtype=int)
+
+            mask[i] = torch.tensor(img_dilate)
+
+        eps = self.epsilon
+        # Compute the adaptive factor
+        A = self.omega * (1 / (1 + torch.pow(self.theta / eps,
+                                             self.alpha - target))) * \
+            (self.alpha - target) * torch.pow(self.theta / eps,
+                                              self.alpha - target - 1) * (1 / eps)
+
+        # Constant term to link the linear and non-linear parts
+        C = (self.theta * A - self.omega * torch.log(1 + torch.pow(self.theta / eps, self.alpha - target)))
+
+        diff_hm = torch.abs(target - input)
+        AWingLoss = A * diff_hm - C
+        idx = torch.argwhere(diff_hm < self.theta)
+        AWingLoss[idx] = self.omega * torch.log(1 + torch.pow(diff_hm / eps, self.alpha - target))[idx]
+
+        # AWingLoss *= mask
+        sum_loss = torch.sum(AWingLoss)
+        if self.reduction == "sum":
+            return sum_loss
+        elif self.reduction == "mean":
+            all_pixel = torch.sum(mask)
+            return sum_loss / all_pixel
+

From 4db99f7746ff9769b8c0d58e3507a9940e7fd70e Mon Sep 17 00:00:00 2001
From: Naga Karthik
Date: Mon, 7 Aug 2023 13:19:22 -0400
Subject: [PATCH 036/106] add RandomGaussianSmooth transform (i.e., RandomBlur)

---
 monai/transforms.py | 15 ++++++++-------
 1 file changed, 8 insertions(+), 7 deletions(-)

diff --git a/monai/transforms.py b/monai/transforms.py
index 1a2c029c..c623702f 100644
--- a/monai/transforms.py
+++ b/monai/transforms.py
@@ -3,7 +3,7 @@
 from monai.transforms import (SpatialPadd, Compose, CropForegroundd, LoadImaged, RandFlipd,
                               RandCropByPosNegLabeld, Spacingd, RandRotated, NormalizeIntensityd,
                               RandWeightedCropd, RandAdjustContrastd, EnsureChannelFirstd, RandGaussianNoised,
-                              Orientationd, Rand3DElasticd, RandBiasFieldd)
+                              RandGaussianSmoothd, Orientationd, Rand3DElasticd, RandBiasFieldd)
 
@@ -31,11 +31,12 @@ def train_transforms(crop_size, num_samples_pv, lbl_key="label"):
                                  spatial_size=crop_size, pos=3, neg=1, num_samples=num_samples_pv,
                                  # if num_samples=4, then 4 samples/image are randomly generated
                                  image_key="image", image_threshold=0.),
-        RandGaussianNoised(keys=["image"], mean=0.0, std=0.1, prob=0.1),
-        Rand3DElasticd(keys=["image", lbl_key], sigma_range=(3.5, 5.5), magnitude_range=(25, 35), prob=0.25),
+        # RandGaussianNoised(keys=["image"], mean=0.0, std=0.1, prob=0.1),
+        Rand3DElasticd(keys=["image", lbl_key], sigma_range=(3.5, 5.5), magnitude_range=(25, 35), prob=0.5),
+        
RandAdjustContrastd(keys=["image"], gamma=(0.7, 1.5), prob=0.4), # this is monai's RandomGamma + RandBiasFieldd(keys=["image"], coeff_range=(0.0, 0.5), degree=3, prob=0.3), + RandGaussianSmoothd(keys=["image"], sigma_x=(0.0, 2.0), sigma_y=(0.0, 2.0), sigma_z=(0.0, 2.0), prob=0.3), + RandFlipd(keys=["image", lbl_key], spatial_axis=None, prob=0.4,), RandRotated(keys=["image", lbl_key], mode=("bilinear", "nearest"), prob=0.2, range_x=(-30. / 360 * 2. * np.pi, 30. / 360 * 2. * np.pi), # NOTE: -pi/6 to pi/6 range_y=(-30. / 360 * 2. * np.pi, 30. / 360 * 2. * np.pi), @@ -48,7 +49,7 @@ def val_transforms(lbl_key="label"): return Compose([ LoadImaged(keys=["image", lbl_key]), EnsureChannelFirstd(keys=["image", lbl_key]), - # Orientationd(keys=["image", lbl_key], axcodes="RPI"), + Orientationd(keys=["image", lbl_key], axcodes="RPI"), CropForegroundd(keys=["image", lbl_key], source_key="image"), NormalizeIntensityd(keys=["image"], nonzero=False, channel_wise=False), Spacingd(keys=["image", lbl_key], pixdim=(1.0, 1.0, 1.0), mode=("bilinear", "nearest"),), From 3f4aa390ba09da53e2f77ad1d835ea439a349d01 Mon Sep 17 00:00:00 2001 From: Naga Karthik Date: Tue, 15 Aug 2023 17:47:09 -0400 Subject: [PATCH 037/106] minor changes --- monai/run_inference.py | 5 +++-- monai/transforms.py | 6 +++--- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/monai/run_inference.py b/monai/run_inference.py index 623e5f51..5dd4bf4b 100644 --- a/monai/run_inference.py +++ b/monai/run_inference.py @@ -17,7 +17,7 @@ DEBUG = False INIT_FILTERS=8 -INFERENCE_ROI_SIZE = (64, 128, 128) # (80, 192, 160) +INFERENCE_ROI_SIZE = (160, 224, 96) # (80, 192, 160) DEVICE = "cpu" @@ -45,6 +45,7 @@ def prepare_data(root, dataset_name="spine-generic"): # load the dataset dataset = os.path.join(root, f"{dataset_name}_dataset.json") + # dataset = os.path.join(root, f"dataset_ivado_comparison.json") test_files = load_decathlon_datalist(dataset, True, "test") if DEBUG: # args.debug: @@ -76,7 +77,6 @@ def main(args): dataset_root = args.path_json dataset_name = args.dataset_name - # TODO: change the name of the checkpoint file to best_model.ckpt chkp_path = os.path.join(args.chkp_path, "best_model.ckpt") results_path = args.path_out @@ -187,6 +187,7 @@ def main(args): "precision": round(test_precision, 2), "recall": round(test_recall, 2), # TODO: add relative volume difference here + # NOTE: RVD is usually compared with binary objects (not soft) "inference_time_in_sec": round((end_time - start_time), 2), } test_step_outputs.append(metrics_dict) diff --git a/monai/transforms.py b/monai/transforms.py index c623702f..7f12352d 100644 --- a/monai/transforms.py +++ b/monai/transforms.py @@ -8,7 +8,7 @@ # median image size in voxels - taken from nnUNet # median_size = (123, 255, 214) # so pad with this size # median_size after 1mm isotropic resampling -# median_size = [ 192. 228. 106.] +# median_size = [ 192. 228. 106.] # Order in which nnunet does preprocessing: # 1. 
Crop to non-zero
@@ -22,7 +22,7 @@ def train_transforms(crop_size, num_samples_pv, lbl_key="label"):
         EnsureChannelFirstd(keys=["image", lbl_key]),
         CropForegroundd(keys=["image", lbl_key], source_key="image"), # crops >0 values with a bounding box
         NormalizeIntensityd(keys=["image"], nonzero=False, channel_wise=False),
-        Spacingd(keys=["image", "label"], pixdim=(1.0, 1.0, 1.0), mode=("bilinear", "nearest"),),
+        Spacingd(keys=["image", "label"], pixdim=(1.0, 1.0, 1.0), mode=("bilinear", "bilinear"),),
         # data-augmentation
         SpatialPadd(keys=["image", lbl_key], spatial_size=(192, 228, 106), method="symmetric"),
@@ -52,6 +52,6 @@ def val_transforms(lbl_key="label"):
         Orientationd(keys=["image", lbl_key], axcodes="RPI"),
         CropForegroundd(keys=["image", lbl_key], source_key="image"),
         NormalizeIntensityd(keys=["image"], nonzero=False, channel_wise=False),
-        Spacingd(keys=["image", lbl_key], pixdim=(1.0, 1.0, 1.0), mode=("bilinear", "nearest"),),
+        Spacingd(keys=["image", lbl_key], pixdim=(1.0, 1.0, 1.0), mode=("bilinear", "bilinear"),),
         # SpatialPadd(keys=["image", lbl_key], spatial_size=(123, 255, 214), method="symmetric"),
     ])

From 60fbd9f0df7637023ba2bac6d8ed91209b526a56 Mon Sep 17 00:00:00 2001
From: Naga Karthik
Date: Tue, 15 Aug 2023 17:47:42 -0400
Subject: [PATCH 038/106] add func to compute avg csa

---
 monai/utils.py | 12 ++++++++++++
 1 file changed, 12 insertions(+)

diff --git a/monai/utils.py b/monai/utils.py
index a1fa8cfe..b2b35039 100644
--- a/monai/utils.py
+++ b/monai/utils.py
@@ -1,6 +1,7 @@
 import numpy as np
 import matplotlib.pyplot as plt
 from torch.optim.lr_scheduler import _LRScheduler
+import torch
 
 
 class FoldGenerator:
@@ -216,6 +217,17 @@ def plot_slices(image, gt, pred, debug=False):
     return fig
 
 
+def compute_average_csa(patch, spacing):
+    num_slices = patch.shape[2]
+    areas = torch.empty(num_slices)
+    for slice_idx in range(num_slices):
+        slice_mask = patch[:, :, slice_idx]
+        area = torch.count_nonzero(slice_mask) * (spacing[0] * spacing[1])
+        areas[slice_idx] = area
+
+    return torch.mean(areas)
+
+
 class PolyLRScheduler(_LRScheduler):
     """
     Polynomial learning rate scheduler. Taken from:
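To make the helper's convention concrete (slices are indexed along the third axis, areas use the in-plane spacing), here is a minimal usage sketch with an invented 3-slice binary mask; it assumes the monai/utils.py from this patch is importable:

    import torch
    from utils import compute_average_csa

    patch = torch.zeros(4, 4, 3)   # (H, W, n_slices)
    patch[0:2, 0:2, 0] = 1.0       # slice 0: 4 foreground voxels
    patch[0, 0:2, 1] = 1.0         # slice 1: 2 foreground voxels; slice 2 stays empty

    print(compute_average_csa(patch, spacing=(1.0, 1.0, 1.0)))
    # tensor(2.) -> (4 + 2 + 0) / 3 mm^2 at 1 mm in-plane resolution

From 3494106a918769d6e96bd88c0a2a1f6f3105dd77 Mon Sep 17 00:00:00 2001
From: Naga Karthik
Date: Tue, 15 Aug 2023 17:52:07 -0400
Subject: [PATCH 039/106] remove unused comments; change opt to Adam, add
 ivadomed unet

---
 monai/main.py | 71 ++++++++++++++++++++++++++------------------------
 1 file changed, 36 insertions(+), 35 deletions(-)

diff --git a/monai/main.py b/monai/main.py
index 89f3133b..8429afc6 100644
--- a/monai/main.py
+++ b/monai/main.py
@@ -11,8 +11,9 @@
 import matplotlib.pyplot as plt
 
 from utils import precision_score, recall_score, dice_score, plot_slices, PolyLRScheduler
-from losses import SoftDiceLoss, DiceCrossEntropyLoss
+from losses import SoftDiceLoss, AdapWingLoss
 from transforms import train_transforms, val_transforms
+from models import ModifiedUNet3D
 
 from monai.utils import set_determinism
@@ -28,13 +29,6 @@ def __init__(self, args, data_root, fold_num, net, loss_function, optimizer_clas
         self.args = args
         self.save_hyperparameters(ignore=['net'])
 
-        # if self.args.unet_depth == 3:
-        #     from models import ModifiedUNet3DEncoder, ModifiedUNet3DDecoder  # this is 3-level UNet
-        #     logger.info("Using UNet with Depth = 3! 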
") - # else: - # from models_original import ModifiedUNet3DEncoder, ModifiedUNet3DDecoder - # logger.info("Using UNet with Depth = 4! ") - self.root = data_root self.fold_num = fold_num self.net = net @@ -98,7 +92,7 @@ def prepare_data(self): transforms_val = val_transforms(lbl_key='label') # load the dataset - dataset = os.path.join(self.root, f"dataset.json") + dataset = os.path.join(self.root, f"spine-generic-ivado-comparison_dataset.json") train_files = load_decathlon_datalist(dataset, True, "train") val_files = load_decathlon_datalist(dataset, True, "validation") test_files = load_decathlon_datalist(dataset, True, "test") @@ -108,7 +102,8 @@ def prepare_data(self): val_files = val_files[:10] test_files = test_files[:6] - self.train_ds = CacheDataset(data=train_files, transform=transforms_train, cache_rate=0.25, num_workers=4) + train_cache_rate = 0.25 if args.debug else 0.5 + self.train_ds = CacheDataset(data=train_files, transform=transforms_train, cache_rate=train_cache_rate, num_workers=4) self.val_ds = CacheDataset(data=val_files, transform=transforms_val, cache_rate=0.25, num_workers=4) # define test transforms @@ -421,28 +416,36 @@ def main(args): dataset_root = "/home/GRAMES.POLYMTL.CA/u114716/contrast-agnostic/contrast-agnostic-softseg-spinalcord/monai" # define optimizer - if args.optimizer in ["adamw", "AdamW", "Adamw"]: - optimizer_class = torch.optim.AdamW + if args.optimizer in ["adam", "Adam"]: + optimizer_class = torch.optim.Adam elif args.optimizer in ["SGD", "sgd"]: optimizer_class = torch.optim.SGD # define models if args.model in ["unet", "UNet"]: - net = UNet(spatial_dims=3, - in_channels=1, out_channels=1, - channels=( - args.init_filters, - args.init_filters * 2, - args.init_filters * 4, - args.init_filters * 8, - args.init_filters * 16 - ), - strides=(2, 2, 2, 2), - num_res_units=4, - ) - patch_size = "160x224x96" - save_exp_id =f"{args.model}_nf={args.init_filters}_nrs=4_opt={args.optimizer}_lr={args.learning_rate}" \ - f"_diceCE_bs={args.batch_size}_{patch_size}" + # # this is the MONAI model + # net = UNet(spatial_dims=3, + # in_channels=1, out_channels=1, + # channels=( + # args.init_filters, + # args.init_filters * 2, + # args.init_filters * 4, + # args.init_filters * 8, + # args.init_filters * 16 + # ), + # strides=(2, 2, 2, 2), + # num_res_units=4, + # ) + # patch_size = "160x224x96" + # save_exp_id =f"{args.model}_nf={args.init_filters}_nrs=4_opt={args.optimizer}_lr={args.learning_rate}" \ + # f"_diceL_nspv={args.num_samples_per_volume}_bs={args.batch_size}_{patch_size}" + + # This is the ivadomed model + net = ModifiedUNet3D(in_channels=1, out_channels=1, init_filters=args.init_filters) + patch_size = "160x224x96" # "64x128x64" + save_exp_id =f"ivado_{args.model}_nf={args.init_filters}_opt={args.optimizer}_lr={args.learning_rate}" \ + f"_CsaDiceL_nspv={args.num_samples_per_volume}_bs={args.batch_size}_{patch_size}" + elif args.model in ["unetr", "UNETR"]: # define image size to be fed to the model img_size = (96, 96, 96) @@ -464,8 +467,9 @@ def main(args): f"_fs={args.feature_size}_hs={args.hidden_size}_mlpd={args.mlp_dim}_nh={args.num_heads}" # define loss function - # loss_func = SoftDiceLoss(p=1, smooth=1.0) - loss_func = DiceCrossEntropyLoss(weight_ce=1.0, weight_dice=1.0) + loss_func = SoftDiceLoss(p=1, smooth=1.0) + # loss_func = DiceCrossEntropyLoss(weight_ce=1.0, weight_dice=1.0) + # loss_func = AdapWingLoss(epsilon=1, theta=0.5, alpha=2.1, omega=8.0, reduction='mean') # TODO: move this inside the for loop when using more folds timestamp = 
datetime.now().strftime(f"%Y%m%d-%H%M") # prints in YYYYMMDD-HHMMSS format @@ -498,7 +502,7 @@ def main(args): exp_logger = pl.loggers.WandbLogger( name=save_exp_id, save_dir=args.save_path, - group=f"{args.model}", + group=f"{args.model}_Adam", log_model=True, # save best model using checkpoint callback project='contrast-agnostic', entity='naga-karthik', @@ -533,7 +537,6 @@ def main(args): # Saving training script to wandb wandb.save("main.py") - # TODO: Come back to testing when hyperparamters have been fixed after cross-validation # Test! trainer.test(pl_model) logger.info(f"TESTING DONE!") @@ -547,8 +550,6 @@ def main(args): print(f"\nSeed Used: {args.seed}", file=f) print(f"\ninitf={args.init_filters}_lr={args.learning_rate}_bs={args.batch_size}_{timestamp}", file=f) print(f"\npatch_size={pl_model.voxel_cropping_size}", file=f) - # print(f"\n{np.array(centers_list)[None, :]}", file=f) - # print(f"\n{np.array(centers_list)[:, None]}", file=f) print('\n-------------- Test Hard Dice Scores ----------------', file=f) print("Hard Dice --> Mean: %0.3f, Std: %0.3f" % (pl_model.avg_test_dice_hard, pl_model.std_test_dice_hard), file=f) @@ -591,8 +592,8 @@ def main(args): parser.add_argument('-me', '--max_epochs', default=1000, type=int, help='Number of epochs for the training process') parser.add_argument('-bs', '--batch_size', default=2, type=int, help='Batch size of the training and validation processes') parser.add_argument('-opt', '--optimizer', - choices=['adamw', 'AdamW', 'SGD', 'sgd'], - default='adamw', type=str, help='Optimizer to use') + choices=['adam', 'Adam', 'SGD', 'sgd'], + default='adam', type=str, help='Optimizer to use') parser.add_argument('-lr', '--learning_rate', default=1e-4, type=float, help='Learning rate for training the model') parser.add_argument('-pat', '--patience', default=25, type=int, help='number of validation steps (val_every_n_iters) to wait before early stopping') From 18e8c91c3d1b5420109a48dde92dd6052dac13ad Mon Sep 17 00:00:00 2001 From: Naga Karthik Date: Tue, 15 Aug 2023 17:53:08 -0400 Subject: [PATCH 040/106] update to add csa loss during training/val --- monai/main.py | 96 +++++++++++++++++++++++++++++++++++++-------------- 1 file changed, 71 insertions(+), 25 deletions(-) diff --git a/monai/main.py b/monai/main.py index 8429afc6..92f27e97 100644 --- a/monai/main.py +++ b/monai/main.py @@ -10,14 +10,14 @@ import torch.nn.functional as F import matplotlib.pyplot as plt -from utils import precision_score, recall_score, dice_score, plot_slices, PolyLRScheduler +from utils import precision_score, recall_score, dice_score, compute_average_csa, PolyLRScheduler from losses import SoftDiceLoss, AdapWingLoss from transforms import train_transforms, val_transforms from models import ModifiedUNet3D from monai.utils import set_determinism from monai.inferers import sliding_window_inference -from monai.networks.nets import UNet, BasicUNet, UNETR, AttentionUnet +from monai.networks.nets import UNet, UNETR from monai.data import (DataLoader, Dataset, CacheDataset, load_decathlon_datalist, decollate_batch) from monai.transforms import (Compose, EnsureType, EnsureTyped, Invertd, SaveImaged, SaveImage) @@ -43,7 +43,8 @@ def __init__(self, args, data_root, fold_num, net, loss_function, optimizer_clas # define cropping and padding dimensions self.voxel_cropping_size = (160, 224, 96) # (80, 192, 160) taken from nnUNet_plans.json - self.inference_roi_size = (160, 224, 96) + self.inference_roi_size = (160, 224, 96) + self.spacing = (1.0, 1.0, 1.0) # define post-processing 
transforms for validation, nothing fancy just making sure that it's a tensor (default) self.val_post_pred = Compose([EnsureType()]) @@ -57,6 +58,9 @@ def __init__(self, args, data_root, fold_num, net, loss_function, optimizer_clas self.val_step_outputs = [] self.test_step_outputs = [] + # MSE loss for comparing the CSA values + self.mse_loss = torch.nn.MSELoss() + # -------------------------------- # FORWARD PASS @@ -164,8 +168,9 @@ def training_step(self, batch, batch_idx): output = self.forward(inputs) # logits - # calculate training loss - loss = self.loss_function(output, labels) + # calculate training loss + # NOTE: the diceLoss expects the input to be logits (which it then normalizes inside) + dice_loss = self.loss_function(output, labels) # get probabilities from logits output = F.relu(output) / F.relu(output).max() if bool(F.relu(output).max()) else F.relu(output) @@ -176,14 +181,29 @@ def training_step(self, batch, batch_idx): train_soft_dice = self.soft_dice_metric(output, labels) # train_hard_dice = self.soft_dice_metric((output.detach() > 0.5).float(), (labels.detach() > 0.5).float()) + # binarize the predictions and the labels + output = (output.detach() > 0.5).float() + labels = (labels.detach() > 0.5).float() + + # compute CSA for each element of the batch + csa_loss = 0.0 + for batch_idx in range(output.shape[0]): + pred_patch_csa = compute_average_csa(output[batch_idx].squeeze(), self.spacing) + gt_patch_csa = compute_average_csa(labels[batch_idx].squeeze(), self.spacing) + csa_loss += (pred_patch_csa - gt_patch_csa) ** 2 + + # total loss + loss = dice_loss + csa_loss + metrics_dict = { "loss": loss.cpu(), + "dice_loss": dice_loss.cpu(), + "csa_loss": csa_loss.cpu(), "train_soft_dice": train_soft_dice.detach().cpu(), - # "train_hard_dice": train_hard_dice, "train_number": len(inputs), - "train_image": inputs[0].detach().cpu().squeeze(), - "train_gt": labels[0].detach().cpu().squeeze(), - "train_pred": output[0].detach().cpu().squeeze() + # "train_image": inputs[0].detach().cpu().squeeze(), + # "train_gt": labels[0].detach().cpu().squeeze(), + # "train_pred": output[0].detach().cpu().squeeze() } self.train_step_outputs.append(metrics_dict) @@ -195,32 +215,39 @@ def on_train_epoch_end(self): # means the training step was skipped because of empty input patch return None else: - train_loss, num_items, train_soft_dice = 0, 0, 0 + train_loss, train_dice_loss, train_csa_loss = 0, 0, 0 + num_items, train_soft_dice = 0, 0 for output in self.train_step_outputs: - train_loss += output["loss"].sum().item() - train_soft_dice += output["train_soft_dice"].sum().item() + train_loss += output["loss"].item() + train_dice_loss += output["dice_loss"].item() + train_csa_loss += output["csa_loss"].item() + train_soft_dice += output["train_soft_dice"].item() num_items += output["train_number"] mean_train_loss = (train_loss / num_items) + mean_train_dice_loss = (train_dice_loss / num_items) + mean_train_csa_loss = (train_csa_loss / num_items) mean_train_soft_dice = (train_soft_dice / num_items) wandb_logs = { "train_soft_dice": mean_train_soft_dice, - "train_loss": mean_train_loss + "train_loss": mean_train_loss, + "train_dice_loss": mean_train_dice_loss, + "train_csa_loss": mean_train_csa_loss, } self.log_dict(wandb_logs) - # plot the training images - fig = plot_slices(image=self.train_step_outputs[0]["train_image"], - gt=self.train_step_outputs[0]["train_gt"], - pred=self.train_step_outputs[0]["train_pred"], - debug=args.debug) - wandb.log({"training images": wandb.Image(fig)}) + # # plot the 
training images
+            # fig = plot_slices(image=self.train_step_outputs[0]["train_image"],
+            #                   gt=self.train_step_outputs[0]["train_gt"],
+            #                   pred=self.train_step_outputs[0]["train_pred"],
+            #                   debug=args.debug)
+            # wandb.log({"training images": wandb.Image(fig)})
 
             # free up memory
             self.train_step_outputs.clear()
             wandb_logs.clear()
-            plt.close(fig)
+            # plt.close(fig)
 
 
     # --------------------------------
@@ -235,7 +262,7 @@ def validation_step(self, batch, batch_idx):
         # outputs shape: (B, C, )
 
         # calculate validation loss
-        loss = self.loss_function(outputs, labels)
+        dice_loss = self.loss_function(outputs, labels)
+
+        # get probabilities from logits
+        outputs = F.relu(outputs) / F.relu(outputs).max() if bool(F.relu(outputs).max()) else F.relu(outputs)
 
         # post-process for calculating the evaluation metric
         post_outputs = [self.val_post_pred(i) for i in decollate_batch(outputs)]
         post_labels = [self.val_post_label(i) for i in decollate_batch(labels)]
         val_soft_dice = self.soft_dice_metric(post_outputs[0], post_labels[0])
-        val_hard_dice = self.soft_dice_metric(
-            (post_outputs[0].detach() > 0.5).float(), (post_labels[0].detach() > 0.5).float()
-        )
+
+        hard_preds, hard_labels = (post_outputs[0].detach() > 0.5).float(), (post_labels[0].detach() > 0.5).float()
+        val_hard_dice = self.soft_dice_metric(hard_preds, hard_labels)
+
+        # compute val CSA loss
+        val_csa_loss = 0.0
+        for batch_idx in range(hard_preds.shape[0]):
+            pred_patch_csa = compute_average_csa(hard_preds[batch_idx].squeeze(), self.spacing)
+            gt_patch_csa = compute_average_csa(hard_labels[batch_idx].squeeze(), self.spacing)
+            val_csa_loss += (pred_patch_csa - gt_patch_csa) ** 2
+
+        # total loss
+        loss = dice_loss + val_csa_loss
 
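The slice-wise CSA penalty above can also be computed without the explicit Python loop. A hedged, equivalent sketch (same quantities; it assumes already-binarized masks shaped (B, H, W, n_slices) and the isotropic `spacing` used in this patch):

    import torch

    def csa_mse(hard_preds, hard_labels, spacing=(1.0, 1.0, 1.0)):
        pix_area = spacing[0] * spacing[1]
        # foreground count per slice over (H, W), averaged across slices -> per-sample CSA
        pred_csa = hard_preds.sum(dim=(-3, -2)).mean(dim=-1) * pix_area
        gt_csa = hard_labels.sum(dim=(-3, -2)).mean(dim=-1) * pix_area
        # squared CSA error summed over the batch, matching the loop above
        return ((pred_csa - gt_csa) ** 2).sum()

         # NOTE: there was a massive memory leak when storing cuda tensors in this dict. 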
Hence, # using .detach() to avoid storing the whole computation graph # Ref: https://discuss.pytorch.org/t/cuda-memory-leak-while-training/82855/2 metrics_dict = { "val_loss": loss.detach().cpu(), + "val_dice_loss": dice_loss.detach().cpu(), + "val_csa_loss": val_csa_loss.detach().cpu(), "val_soft_dice": val_soft_dice.detach().cpu(), "val_hard_dice": val_hard_dice.detach().cpu(), "val_number": len(post_outputs), @@ -267,20 +306,27 @@ def validation_step(self, batch, batch_idx): def on_validation_epoch_end(self): val_loss, num_items, val_soft_dice, val_hard_dice = 0, 0, 0, 0 + val_dice_loss, val_csa_loss = 0, 0 for output in self.val_step_outputs: val_loss += output["val_loss"].sum().item() val_soft_dice += output["val_soft_dice"].sum().item() val_hard_dice += output["val_hard_dice"].sum().item() + val_dice_loss += output["val_dice_loss"].sum().item() + val_csa_loss += output["val_csa_loss"].sum().item() num_items += output["val_number"] mean_val_loss = (val_loss / num_items) mean_val_soft_dice = (val_soft_dice / num_items) mean_val_hard_dice = (val_hard_dice / num_items) + mean_val_dice_loss = (val_dice_loss / num_items) + mean_val_csa_loss = (val_csa_loss / num_items) wandb_logs = { "val_soft_dice": mean_val_soft_dice, "val_hard_dice": mean_val_hard_dice, "val_loss": mean_val_loss, + "val_dice_loss": mean_val_dice_loss, + "val_csa_loss": mean_val_csa_loss, } if mean_val_soft_dice > self.best_val_dice: self.best_val_dice = mean_val_soft_dice From 1023122fbe275154636275761eb27e71d0b760ca Mon Sep 17 00:00:00 2001 From: Naga Karthik Date: Mon, 21 Aug 2023 09:26:58 -0400 Subject: [PATCH 041/106] remove monai unet, add ivadomed unet --- monai/run_inference.py | 20 +++++--------------- 1 file changed, 5 insertions(+), 15 deletions(-) diff --git a/monai/run_inference.py b/monai/run_inference.py index 5dd4bf4b..a4f1fae4 100644 --- a/monai/run_inference.py +++ b/monai/run_inference.py @@ -10,10 +10,10 @@ from monai.inferers import sliding_window_inference from monai.data import (DataLoader, CacheDataset, load_decathlon_datalist, decollate_batch) from monai.transforms import (Compose, EnsureTyped, Invertd, SaveImage) -from monai.networks.nets import UNet from transforms import val_transforms from utils import precision_score, recall_score, dice_score +from models import ModifiedUNet3D DEBUG = False INIT_FILTERS=8 @@ -85,27 +85,17 @@ def main(args): if not os.path.exists(results_path): os.makedirs(results_path, exist_ok=True) - checkpoint = torch.load(chkp_path, map_location=torch.device('cpu'))["state_dict"] + checkpoint = torch.load(chkp_path, map_location=torch.device(DEVICE))["state_dict"] # NOTE: remove the 'net.' prefix from the keys because of how the model was initialized in lightning # https://discuss.pytorch.org/t/missing-keys-unexpected-keys-in-state-dict-when-loading-self-trained-model/22379/14 for key in list(checkpoint.keys()): if 'net.' 
in key: checkpoint[key.replace('net.', '')] = checkpoint[key] del checkpoint[key] - - # initialize the model - net = UNet(spatial_dims=3, - in_channels=1, out_channels=1, - channels=( - INIT_FILTERS, - INIT_FILTERS * 2, - INIT_FILTERS * 4, - INIT_FILTERS * 8, - INIT_FILTERS * 16 - ), - strides=(2, 2, 2, 2), - num_res_units=4,) + # initialize ivadomed unet model + net = ModifiedUNet3D(in_channels=1, out_channels=1, init_filters=INIT_FILTERS) + # load the trained model weights net.load_state_dict(checkpoint) net.to(DEVICE) From 8f324c2aa59b20f72a6c14220d70edac3c126839 Mon Sep 17 00:00:00 2001 From: Naga Karthik Date: Mon, 21 Aug 2023 16:27:53 -0400 Subject: [PATCH 042/106] finalize csa loss-related changes --- monai/main.py | 44 ++++++++++++++++++++++++++++++++------------ 1 file changed, 32 insertions(+), 12 deletions(-) diff --git a/monai/main.py b/monai/main.py index 92f27e97..b12f32d3 100644 --- a/monai/main.py +++ b/monai/main.py @@ -40,6 +40,7 @@ def __init__(self, args, data_root, fold_num, net, loss_function, optimizer_clas self.results_path = results_path self.best_val_dice, self.best_val_epoch = 0, 0 + self.best_val_csa = float("inf") # define cropping and padding dimensions self.voxel_cropping_size = (160, 224, 96) # (80, 192, 160) taken from nnUNet_plans.json @@ -192,6 +193,9 @@ def training_step(self, batch, batch_idx): gt_patch_csa = compute_average_csa(labels[batch_idx].squeeze(), self.spacing) csa_loss += (pred_patch_csa - gt_patch_csa) ** 2 + # average CSA loss across the batch + csa_loss = csa_loss / output.shape[0] + # total loss loss = dice_loss + csa_loss @@ -282,8 +286,11 @@ def validation_step(self, batch, batch_idx): gt_patch_csa = compute_average_csa(hard_labels[batch_idx].squeeze(), self.spacing) val_csa_loss += (pred_patch_csa - gt_patch_csa) ** 2 + # average CSA loss across the batch + val_csa_loss = val_csa_loss / hard_preds.shape[0] + # total loss - loss = dice_loss + val_csa_loss + loss = dice_loss + val_csa_loss # NOTE: there was a massive memory leak when storing cuda tensors in this dict. 
Hence, # using .detach() to avoid storing the whole computation graph @@ -328,15 +335,23 @@ def on_validation_epoch_end(self): "val_dice_loss": mean_val_dice_loss, "val_csa_loss": mean_val_csa_loss, } - if mean_val_soft_dice > self.best_val_dice: - self.best_val_dice = mean_val_soft_dice + # # save the best model based on validation dice score + # if mean_val_soft_dice > self.best_val_dice: + # self.best_val_dice = mean_val_soft_dice + # self.best_val_epoch = self.current_epoch + + # save the best model based on validation CSA loss + if mean_val_csa_loss < self.best_val_csa: + self.best_val_csa = mean_val_csa_loss self.best_val_epoch = self.current_epoch print( f"Current epoch: {self.current_epoch}" f"\nAverage Soft Dice (VAL): {mean_val_soft_dice:.4f}" f"\nAverage Hard Dice (VAL): {mean_val_hard_dice:.4f}" - f"\nBest Average Soft Dice: {self.best_val_dice:.4f} at Epoch: {self.best_val_epoch}" + f"\nAverage CSA (VAL): {mean_val_csa_loss:.4f}" + # f"\nBest Average Soft Dice: {self.best_val_dice:.4f} at Epoch: {self.best_val_epoch}" + f"\nBest Average CSA: {self.best_val_csa:.4f} at Epoch: {self.best_val_epoch}" f"\n----------------------------------------------------") # log on to wandb @@ -490,7 +505,8 @@ def main(args): net = ModifiedUNet3D(in_channels=1, out_channels=1, init_filters=args.init_filters) patch_size = "160x224x96" # "64x128x64" save_exp_id =f"ivado_{args.model}_nf={args.init_filters}_opt={args.optimizer}_lr={args.learning_rate}" \ - f"_CsaDiceL_nspv={args.num_samples_per_volume}_bs={args.batch_size}_{patch_size}" + f"_CSAdiceL_bestValCSA_nspv={args.num_samples_per_volume}" \ + f"_bs={args.batch_size}_{patch_size}" elif args.model in ["unetr", "UNETR"]: # define image size to be fed to the model @@ -553,18 +569,22 @@ def main(args): project='contrast-agnostic', entity='naga-karthik', config=args) - # else: - # exp_logger = pl.loggers.CSVLogger(save_dir=args.save_path, name="my_exp_name") + # # saving the best model based on soft validation dice score + # checkpoint_callback = pl.callbacks.ModelCheckpoint( + # dirpath=save_path, filename='best_model', monitor='val_soft_dice', + # save_top_k=5, mode="max", save_last=False, save_weights_only=True) checkpoint_callback = pl.callbacks.ModelCheckpoint( - dirpath=save_path, filename='best_model', monitor='val_loss', - save_top_k=1, mode="min", save_last=False, save_weights_only=True) + dirpath=save_path, filename='best_model', monitor='val_csa_loss', + save_top_k=5, mode="min", save_last=False, save_weights_only=True) - lr_monitor = pl.callbacks.LearningRateMonitor(logging_interval='epoch') - - early_stopping = pl.callbacks.EarlyStopping(monitor="val_loss", min_delta=0.00, patience=args.patience, + # early_stopping = pl.callbacks.EarlyStopping(monitor="val_soft_dice", min_delta=0.00, patience=args.patience, + # verbose=False, mode="max") + early_stopping = pl.callbacks.EarlyStopping(monitor="val_csa_loss", min_delta=0.00, patience=args.patience, verbose=False, mode="min") + lr_monitor = pl.callbacks.LearningRateMonitor(logging_interval='epoch') + # initialise Lightning's trainer. 
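# NOTE (editorial sketch): the CSA criterion above relies on compute_average_csa()
# from monai/utils.py, which these patches import but never show. A minimal sketch
# consistent with its call sites (a binarized 3-D patch plus the voxel spacing held
# in self.spacing) might look like:
#
#   def compute_average_csa(patch, spacing):
#       """Mean cross-sectional area (mm^2) over the axial slices of a patch."""
#       areas = []
#       for slice_idx in range(patch.shape[2]):      # assumes axial slices along dim 2
#           slice_mask = patch[:, :, slice_idx]
#           areas.append(torch.count_nonzero(slice_mask) * spacing[0] * spacing[1])
#       return torch.stack(areas).float().mean()
#
# The axial orientation and the use of the first two spacing entries are assumptions;
# the actual implementation lives in utils.py. Note also that both CSA terms are
# computed from detach()-ed, binarized tensors, so they contribute no gradient: their
# effect is on logging and on the CSA-based checkpointing/early stopping set up above.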
trainer = pl.Trainer( devices=1, accelerator="gpu", # strategy="ddp", From 3aec0d59cf93171d233ace30979bcb139c2db3c6 Mon Sep 17 00:00:00 2001 From: Naga Karthik Date: Mon, 21 Aug 2023 16:28:30 -0400 Subject: [PATCH 043/106] minor fixes --- monai/main.py | 30 ++++-------------------------- 1 file changed, 4 insertions(+), 26 deletions(-) diff --git a/monai/main.py b/monai/main.py index b12f32d3..5835ff5c 100644 --- a/monai/main.py +++ b/monai/main.py @@ -59,9 +59,6 @@ def __init__(self, args, data_root, fold_num, net, loss_function, optimizer_clas self.val_step_outputs = [] self.test_step_outputs = [] - # MSE loss for comparing the CSA values - self.mse_loss = torch.nn.MSELoss() - # -------------------------------- # FORWARD PASS @@ -261,7 +258,7 @@ def validation_step(self, batch, batch_idx): inputs, labels = batch["image"], batch["label"] - outputs = sliding_window_inference(inputs, self.inference_roi_size, + outputs = sliding_window_inference(inputs, self.inference_roi_size, mode="gaussian", sw_batch_size=4, predictor=self.forward, overlap=0.5,) # outputs shape: (B, C, ) @@ -483,25 +480,8 @@ def main(args): optimizer_class = torch.optim.SGD # define models - if args.model in ["unet", "UNet"]: - # # this is the MONAI model - # net = UNet(spatial_dims=3, - # in_channels=1, out_channels=1, - # channels=( - # args.init_filters, - # args.init_filters * 2, - # args.init_filters * 4, - # args.init_filters * 8, - # args.init_filters * 16 - # ), - # strides=(2, 2, 2, 2), - # num_res_units=4, - # ) - # patch_size = "160x224x96" - # save_exp_id =f"{args.model}_nf={args.init_filters}_nrs=4_opt={args.optimizer}_lr={args.learning_rate}" \ - # f"_diceL_nspv={args.num_samples_per_volume}_bs={args.batch_size}_{patch_size}" - - # This is the ivadomed model + if args.model in ["unet", "UNet"]: + # this is the ivadomed unet model net = ModifiedUNet3D(in_channels=1, out_channels=1, init_filters=args.init_filters) patch_size = "160x224x96" # "64x128x64" save_exp_id =f"ivado_{args.model}_nf={args.init_filters}_opt={args.optimizer}_lr={args.learning_rate}" \ @@ -530,8 +510,6 @@ def main(args): # define loss function loss_func = SoftDiceLoss(p=1, smooth=1.0) - # loss_func = DiceCrossEntropyLoss(weight_ce=1.0, weight_dice=1.0) - # loss_func = AdapWingLoss(epsilon=1, theta=0.5, alpha=2.1, omega=8.0, reduction='mean') # TODO: move this inside the for loop when using more folds timestamp = datetime.now().strftime(f"%Y%m%d-%H%M") # prints in YYYYMMDD-HHMMSS format @@ -564,7 +542,7 @@ def main(args): exp_logger = pl.loggers.WandbLogger( name=save_exp_id, save_dir=args.save_path, - group=f"{args.model}_Adam", + group=f"{args.model}_final", log_model=True, # save best model using checkpoint callback project='contrast-agnostic', entity='naga-karthik', From a2677ce2d5519ca681c137fdcc2fd0c27d1e7590 Mon Sep 17 00:00:00 2001 From: Naga Karthik Date: Mon, 21 Aug 2023 17:06:29 -0400 Subject: [PATCH 044/106] refactor to add ensembling --- monai/run_inference.py | 102 ++++++++++++++++++++++++----------------- 1 file changed, 60 insertions(+), 42 deletions(-) diff --git a/monai/run_inference.py b/monai/run_inference.py index a4f1fae4..a5cb757a 100644 --- a/monai/run_inference.py +++ b/monai/run_inference.py @@ -6,6 +6,7 @@ import torch import json from time import time +from tqdm import tqdm from monai.inferers import sliding_window_inference from monai.data import (DataLoader, CacheDataset, load_decathlon_datalist, decollate_batch) @@ -18,7 +19,8 @@ DEBUG = False INIT_FILTERS=8 INFERENCE_ROI_SIZE = (160, 224, 96) # (80, 192, 160) 
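# NOTE (editorial): patch 043 above adds mode="gaussian" to sliding_window_inference().
# With Gaussian blending, overlapping windows are averaged with weights that peak at
# the window centre, so border voxels (where the network has the least context) count
# less and stitching artifacts between patches are reduced. A hedged usage sketch with
# the constants defined here:
#
#   preds = sliding_window_inference(
#       inputs,                 # (B, C, H, W, D) input tensor
#       INFERENCE_ROI_SIZE,     # window size, e.g. (160, 224, 96)
#       sw_batch_size=4,        # number of windows per forward pass
#       predictor=net,          # callable returning logits
#       overlap=0.5,            # fractional overlap between windows
#       mode="gaussian")        # Gaussian instead of uniform blending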
-DEVICE = "cpu" +DEVICE = "cuda" if torch.cuda.is_available() else "cpu" +# DEVICE = "cpu" def get_parser(): @@ -63,7 +65,7 @@ def prepare_data(root, dataset_name="spine-generic"): meta_keys=["pred_meta_dict", "label_meta_dict"], nearest_interp=False, to_tensor=True), ]) - test_ds = CacheDataset(data=test_files, transform=transforms_test, cache_rate=0.1, num_workers=4) + test_ds = CacheDataset(data=test_files, transform=transforms_test, cache_rate=0.25, num_workers=4) return test_ds, test_post_pred @@ -77,68 +79,84 @@ def main(args): dataset_root = args.path_json dataset_name = args.dataset_name - chkp_path = os.path.join(args.chkp_path, "best_model.ckpt") + # chkp_path = os.path.join(args.chkp_path, "best_model.ckpt") results_path = args.path_out - model_name = chkp_path.split("/")[-2] + model_name = args.chkp_path.split("/")[-1] results_path = os.path.join(results_path, dataset_name, model_name) if not os.path.exists(results_path): os.makedirs(results_path, exist_ok=True) - checkpoint = torch.load(chkp_path, map_location=torch.device(DEVICE))["state_dict"] - # NOTE: remove the 'net.' prefix from the keys because of how the model was initialized in lightning - # https://discuss.pytorch.org/t/missing-keys-unexpected-keys-in-state-dict-when-loading-self-trained-model/22379/14 - for key in list(checkpoint.keys()): - if 'net.' in key: - checkpoint[key.replace('net.', '')] = checkpoint[key] - del checkpoint[key] - - # initialize ivadomed unet model - net = ModifiedUNet3D(in_channels=1, out_channels=1, init_filters=INIT_FILTERS) - - # load the trained model weights - net.load_state_dict(checkpoint) - net.to(DEVICE) - # define the dataset and dataloader test_ds, test_post_pred = prepare_data(dataset_root, dataset_name) test_loader = DataLoader(test_ds, batch_size=1, shuffle=False, num_workers=4, pin_memory=True) + # initialize ivadomed unet model + net = ModifiedUNet3D(in_channels=1, out_channels=1, init_filters=INIT_FILTERS) + # define list to collect the test metrics test_step_outputs = [] test_summary = {} - + + preds_stack = [] # iterate over the dataset and compute metrics - net.eval() with torch.no_grad(): - for i, batch in enumerate(test_loader): + for batch in test_loader: # compute time for inference per subject start_time = time() + + # get the test input + test_input = batch["image"].to(DEVICE) + + # load the checkpoints + for chkp in os.listdir(args.chkp_path): + chkp_path = os.path.join(args.chkp_path, chkp) + # print(f"Loading checkpoint: {chkp_path}") - test_input = batch["image"] - batch["pred"] = sliding_window_inference(test_input, INFERENCE_ROI_SIZE, - sw_batch_size=4, predictor=net, overlap=0.5) - # NOTE: monai's models do not normalize the output, so we need to do it manually - if bool(F.relu(batch["pred"]).max()): - batch["pred"] = F.relu(batch["pred"]) / F.relu(batch["pred"]).max() - else: - batch["pred"] = F.relu(batch["pred"]) - - # # upon fsleyes visualization, observed that very small values need to be set to zero, but NOT fully binarizing the pred - # batch["pred"][batch["pred"] < 0.099] = 0.0 - - post_test_out = [test_post_pred(i) for i in decollate_batch(batch)] - - # make sure that the shapes of prediction and GT label are the same - # print(f"pred shape: {post_test_out[0]['pred'].shape}, label shape: {post_test_out[0]['label'].shape}") - assert post_test_out[0]['pred'].shape == post_test_out[0]['label'].shape - - pred, label = post_test_out[0]['pred'].cpu(), post_test_out[0]['label'].cpu() + checkpoint = torch.load(chkp_path, 
map_location=torch.device(DEVICE))["state_dict"] + # NOTE: remove the 'net.' prefix from the keys because of how the model was initialized in lightning + # https://discuss.pytorch.org/t/missing-keys-unexpected-keys-in-state-dict-when-loading-self-trained-model/22379/14 + for key in list(checkpoint.keys()): + if 'net.' in key: + checkpoint[key.replace('net.', '')] = checkpoint[key] + del checkpoint[key] + + # load the trained model weights + net.load_state_dict(checkpoint) + net.to(DEVICE) + net.eval() + + # run inference + batch["pred"] = sliding_window_inference(test_input, INFERENCE_ROI_SIZE, mode="gaussian", + sw_batch_size=4, predictor=net, overlap=0.5, progress=False) + # NOTE: monai's models do not normalize the output, so we need to do it manually + if bool(F.relu(batch["pred"]).max()): + batch["pred"] = F.relu(batch["pred"]) / F.relu(batch["pred"]).max() + else: + batch["pred"] = F.relu(batch["pred"]) + + post_test_out = [test_post_pred(i) for i in decollate_batch(batch)] + + # make sure that the shapes of prediction and GT label are the same + # print(f"pred shape: {post_test_out[0]['pred'].shape}, label shape: {post_test_out[0]['label'].shape}") + assert post_test_out[0]['pred'].shape == post_test_out[0]['label'].shape + + pred, label = post_test_out[0]['pred'].cpu(), post_test_out[0]['label'].cpu() + + # stack the predictions + preds_stack.append(pred) - # save the prediction and label + # save the (soft) prediction and label subject_name = (batch["image_meta_dict"]["filename_or_obj"][0]).split("/")[-1].replace(".nii.gz", "") print(f"Saving subject: {subject_name}") + # take the average of the predictions + pred = torch.stack(preds_stack).mean(dim=0) + preds_stack.clear() + + # check whether the prediction and label have the same shape + assert pred.shape == label.shape, f"Prediction and label shapes are different: {pred.shape} vs {label.shape}" + # image saver class save_folder = os.path.join(results_path, subject_name.split("_")[0]) pred_saver = SaveImage( From 801a155bb56426c57eca7a1a21bbafeea4438a81 Mon Sep 17 00:00:00 2001 From: Naga Karthik Date: Sat, 26 Aug 2023 16:10:14 -0400 Subject: [PATCH 045/106] add arg to run inference using unetr --- monai/run_inference.py | 36 ++++++++++++++++++++++++++++++------ 1 file changed, 30 insertions(+), 6 deletions(-) diff --git a/monai/run_inference.py b/monai/run_inference.py index a5cb757a..66a12aa1 100644 --- a/monai/run_inference.py +++ b/monai/run_inference.py @@ -15,12 +15,20 @@ from transforms import val_transforms from utils import precision_score, recall_score, dice_score from models import ModifiedUNet3D +from monai.networks.nets import UNETR +DEVICE = "cuda" if torch.cuda.is_available() else "cpu" DEBUG = False -INIT_FILTERS=8 INFERENCE_ROI_SIZE = (160, 224, 96) # (80, 192, 160) -DEVICE = "cuda" if torch.cuda.is_available() else "cpu" -# DEVICE = "cpu" +# UNET params +INIT_FILTERS=8 +# UNETR params +FEATURE_SIZE = 8 +HIDDEN_SIZE = 512 +MLP_DIM = 1024 +NUM_HEADS = 8 + +EXAMPLE_INPUT = torch.randn(1, 1, 160, 224, 96).to(DEVICE) def get_parser(): @@ -34,6 +42,7 @@ def get_parser(): help="Path to the output folder where to store the predictions and associated metrics") parser.add_argument("-dname", "--dataset-name", type=str, default="spine-generic", help="Name of the dataset to run inference on") + parser.add_argument("--model", type=str, default="unet", help="Name of the model to use for inference") return parser @@ -91,8 +100,23 @@ def main(args): test_ds, test_post_pred = prepare_data(dataset_root, dataset_name) test_loader 
= DataLoader(test_ds, batch_size=1, shuffle=False, num_workers=4, pin_memory=True) - # initialize ivadomed unet model - net = ModifiedUNet3D(in_channels=1, out_channels=1, init_filters=INIT_FILTERS) + if args.model == "unet": + # initialize ivadomed unet model + net = ModifiedUNet3D(in_channels=1, out_channels=1, init_filters=INIT_FILTERS) + elif args.model == "unetr": + # initialize unetr model + net = UNETR(spatial_dims=3, + in_channels=1, out_channels=1, + img_size=INFERENCE_ROI_SIZE, + feature_size=FEATURE_SIZE, + hidden_size=HIDDEN_SIZE, + mlp_dim=MLP_DIM, + num_heads=NUM_HEADS, + pos_embed="conv", + norm_name="instance", + res_block=True, + dropout_rate=0.2, + ) # define list to collect the test metrics test_step_outputs = [] @@ -111,7 +135,7 @@ def main(args): # load the checkpoints for chkp in os.listdir(args.chkp_path): chkp_path = os.path.join(args.chkp_path, chkp) - # print(f"Loading checkpoint: {chkp_path}") + print(f"Loading checkpoint: {chkp_path}") checkpoint = torch.load(chkp_path, map_location=torch.device(DEVICE))["state_dict"] # NOTE: remove the 'net.' prefix from the keys because of how the model was initialized in lightning From 22e5b1849296eca43f055301ed94c11a3d0fc3f8 Mon Sep 17 00:00:00 2001 From: Naga Karthik Date: Sat, 26 Aug 2023 16:10:42 -0400 Subject: [PATCH 046/106] add function to check for empty patches --- monai/utils.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/monai/utils.py b/monai/utils.py index b2b35039..9b670edc 100644 --- a/monai/utils.py +++ b/monai/utils.py @@ -4,6 +4,15 @@ import torch +# Check if any label image patch is empty in the batch +def check_empty_patch(labels): + for i, label in enumerate(labels): + if torch.sum(label) == 0.0: + # print(f"Empty label patch found at index {i}. Skipping training step ...") + return None + return labels # If no empty patch is found, return the labels + + class FoldGenerator: """ Adapted from https://github.com/MIC-DKFZ/medicaldetectiontoolkit/blob/master/utils/dataloader_utils.py#L59 From 27cf099976bdf8e4e3a50014450f37d79ebc96d5 Mon Sep 17 00:00:00 2001 From: Naga Karthik Date: Sat, 26 Aug 2023 16:11:23 -0400 Subject: [PATCH 047/106] add TODO about adding RandSimulateLowResolution --- monai/transforms.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/monai/transforms.py b/monai/transforms.py index 7f12352d..04f80e28 100644 --- a/monai/transforms.py +++ b/monai/transforms.py @@ -1,7 +1,7 @@ import numpy as np from monai.transforms import (SpatialPadd, Compose, CropForegroundd, LoadImaged, RandFlipd, - RandCropByPosNegLabeld, Spacingd, RandRotated, NormalizeIntensityd, + RandCropByPosNegLabeld, Spacingd, RandRotated, NormalizeIntensityd, RandWeightedCropd, RandAdjustContrastd, EnsureChannelFirstd, RandGaussianNoised, RandGaussianSmoothd, Orientationd, Rand3DElasticd, RandBiasFieldd) @@ -32,7 +32,8 @@ def train_transforms(crop_size, num_samples_pv, lbl_key="label"): # if num_samples=4, then 4 samples/image are randomly generated image_key="image", image_threshold=0.), # RandGaussianNoised(keys=["image"], mean=0.0, std=0.1, prob=0.1), - Rand3DElasticd(keys=["image", lbl_key], sigma_range=(3.5, 5.5), magnitude_range=(25, 35), prob=0.5), + Rand3DElasticd(keys=["image", lbl_key], sigma_range=(3.5, 5.5), magnitude_range=(25, 35), prob=0.5), + # TODO: Try Spacingd with low resolution here with prob=0.5 RandAdjustContrastd(keys=["image"], gamma=(0.7, 1.5), prob=0.4), # this is monai's RandomGamma RandBiasFieldd(keys=["image"], coeff_range=(0.0, 0.5), degree=3, prob=0.3), 
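# NOTE (editorial sketch): one way to realize the low-resolution TODO above is MONAI's
# RandSimulateLowResolutiond (available from MONAI v1.2), which, like nnUNet's
# low-resolution augmentation, downsamples the image and resamples it back, e.g.:
#   RandSimulateLowResolutiond(keys=["image"], zoom_range=(0.5, 1.0), prob=0.5),
# The zoom range and probability here are illustrative values, not taken from these patches.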
RandGaussianSmoothd(keys=["image"], sigma_x=(0.0, 2.0), sigma_y=(0.0, 2.0), sigma_z=(0.0, 2.0), prob=0.3), From 182d27a098fe91b102e08ec0ee013c8a54877bd9 Mon Sep 17 00:00:00 2001 From: Naga Karthik Date: Sat, 26 Aug 2023 16:13:09 -0400 Subject: [PATCH 048/106] increase num_workers to speed up time/epoch --- monai/main.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/monai/main.py b/monai/main.py index 5835ff5c..dbfb9bb5 100644 --- a/monai/main.py +++ b/monai/main.py @@ -59,6 +59,9 @@ def __init__(self, args, data_root, fold_num, net, loss_function, optimizer_clas self.val_step_outputs = [] self.test_step_outputs = [] + # specify example_input_array for model summary + self.example_input_array = torch.rand(1, 1, 160, 224, 96) + # -------------------------------- # FORWARD PASS @@ -128,15 +131,16 @@ def prepare_data(self): # -------------------------------- def train_dataloader(self): # NOTE: if num_samples=4 in RandCropByPosNegLabeld and batch_size=2, then 2 x 4 images are generated for network training - return DataLoader(self.train_ds, batch_size=self.args.batch_size, shuffle=True, num_workers=4, - pin_memory=True,) # collate_fn=pad_list_data_collate) + return DataLoader(self.train_ds, batch_size=self.args.batch_size, shuffle=True, num_workers=16, + pin_memory=True, persistent_workers=True) # collate_fn=pad_list_data_collate) # list_data_collate is only useful when each input in the batch has different shape def val_dataloader(self): - return DataLoader(self.val_ds, batch_size=1, shuffle=False, num_workers=4, pin_memory=True) + return DataLoader(self.val_ds, batch_size=1, shuffle=False, num_workers=16, pin_memory=True, + persistent_workers=True) def test_dataloader(self): - return DataLoader(self.test_ds, batch_size=1, shuffle=False, num_workers=4, pin_memory=True) + return DataLoader(self.test_ds, batch_size=1, shuffle=False, num_workers=8, pin_memory=True) # -------------------------------- From d8719b664cd9c77e29c4da0b85c91a7460ac8425 Mon Sep 17 00:00:00 2001 From: Naga Karthik Date: Sat, 26 Aug 2023 16:14:17 -0400 Subject: [PATCH 049/106] fix code to filter out empty patches --- monai/main.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/monai/main.py b/monai/main.py index dbfb9bb5..0a461bd0 100644 --- a/monai/main.py +++ b/monai/main.py @@ -10,7 +10,8 @@ import torch.nn.functional as F import matplotlib.pyplot as plt -from utils import precision_score, recall_score, dice_score, compute_average_csa, PolyLRScheduler +from utils import precision_score, recall_score, dice_score, compute_average_csa, \ + PolyLRScheduler, check_empty_patch from losses import SoftDiceLoss, AdapWingLoss from transforms import train_transforms, val_transforms from models import ModifiedUNet3D @@ -163,9 +164,9 @@ def training_step(self, batch, batch_idx): inputs, labels = batch["image"], batch["label"] - # filter empty label patches - if not labels.any(): - print("Encountered empty label patch. Skipping...") + # check if any label image patch is empty in the batch + if check_empty_patch(labels) is None: + # print(f"Empty label patch found. 
Skipping training step ...") return None output = self.forward(inputs) # logits @@ -489,7 +490,7 @@ def main(args): net = ModifiedUNet3D(in_channels=1, out_channels=1, init_filters=args.init_filters) patch_size = "160x224x96" # "64x128x64" save_exp_id =f"ivado_{args.model}_nf={args.init_filters}_opt={args.optimizer}_lr={args.learning_rate}" \ - f"_CSAdiceL_bestValCSA_nspv={args.num_samples_per_volume}" \ + f"_CSAdiceL_bestValCSA_nspv={args.num_samples_per_volume}_fltr" \ f"_bs={args.batch_size}_{patch_size}" elif args.model in ["unetr", "UNETR"]: From 3f45cb837b4da5973d1149c83b1bbeb08bb766b5 Mon Sep 17 00:00:00 2001 From: Naga Karthik Date: Sat, 26 Aug 2023 16:15:11 -0400 Subject: [PATCH 050/106] minor update UNetR params; add simple profiling --- monai/main.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/monai/main.py b/monai/main.py index 0a461bd0..a9d13da2 100644 --- a/monai/main.py +++ b/monai/main.py @@ -495,7 +495,7 @@ def main(args): elif args.model in ["unetr", "UNETR"]: # define image size to be fed to the model - img_size = (96, 96, 96) + img_size = (160, 224, 96) # define model net = UNETR(spatial_dims=3, @@ -505,13 +505,15 @@ def main(args): hidden_size=args.hidden_size, mlp_dim=args.mlp_dim, num_heads=args.num_heads, - pos_embed="perceptron", + pos_embed="conv", norm_name="instance", res_block=True, dropout_rate=0.2, ) - save_exp_id = f"{args.model}_lr={args.learning_rate}" \ - f"_fs={args.feature_size}_hs={args.hidden_size}_mlpd={args.mlp_dim}_nh={args.num_heads}" + img_size = f"{img_size[0]}x{img_size[1]}x{img_size[2]}" + save_exp_id = f"{args.model}_opt={args.optimizer}_lr={args.learning_rate}" \ + f"_fs={args.feature_size}_hs={args.hidden_size}_mlpd={args.mlp_dim}_nh={args.num_heads}" \ + f"_CSAdiceL_nspv={args.num_samples_per_volume}_bs={args.batch_size}_{img_size}" \ # define loss function loss_func = SoftDiceLoss(p=1, smooth=1.0) @@ -577,7 +579,8 @@ def main(args): max_epochs=args.max_epochs, precision=32, # TODO: see if 16-bit precision is stable # deterministic=True, - enable_progress_bar=args.enable_progress_bar) + enable_progress_bar=args.enable_progress_bar, + profiler="simple",) # to profile the training time taken for each step # Train! trainer.fit(pl_model) From 1945b0d5045e0fb18ca9390f3b17e443c7c44cef Mon Sep 17 00:00:00 2001 From: Naga Karthik Date: Mon, 28 Aug 2023 23:23:41 -0400 Subject: [PATCH 051/106] add argument for specifying model to use for inference --- monai/run_inference.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/monai/run_inference.py b/monai/run_inference.py index 66a12aa1..52d481b9 100644 --- a/monai/run_inference.py +++ b/monai/run_inference.py @@ -42,7 +42,8 @@ def get_parser(): help="Path to the output folder where to store the predictions and associated metrics") parser.add_argument("-dname", "--dataset-name", type=str, default="spine-generic", help="Name of the dataset to run inference on") - parser.add_argument("--model", type=str, default="unet", help="Name of the model to use for inference") + parser.add_argument("--model", type=str, default="unet", required=True, + help="Name of the model to use for inference") return parser @@ -135,7 +136,7 @@ def main(args): # load the checkpoints for chkp in os.listdir(args.chkp_path): chkp_path = os.path.join(args.chkp_path, chkp) - print(f"Loading checkpoint: {chkp_path}") + # print(f"Loading checkpoint: {chkp_path}") checkpoint = torch.load(chkp_path, map_location=torch.device(DEVICE))["state_dict"] # NOTE: remove the 'net.' 
prefix from the keys because of how the model was initialized in lightning From e319ca675eb93c962f67832d9b0030acc9a56da9 Mon Sep 17 00:00:00 2001 From: Naga Karthik Date: Mon, 28 Aug 2023 23:27:31 -0400 Subject: [PATCH 052/106] add dynunet model for training --- monai/main.py | 44 ++++++++++++++++++++++++++++++++++++-------- 1 file changed, 36 insertions(+), 8 deletions(-) diff --git a/monai/main.py b/monai/main.py index a9d13da2..80e2ac89 100644 --- a/monai/main.py +++ b/monai/main.py @@ -12,13 +12,13 @@ from utils import precision_score, recall_score, dice_score, compute_average_csa, \ PolyLRScheduler, check_empty_patch -from losses import SoftDiceLoss, AdapWingLoss +from losses import SoftDiceLoss from transforms import train_transforms, val_transforms from models import ModifiedUNet3D from monai.utils import set_determinism from monai.inferers import sliding_window_inference -from monai.networks.nets import UNet, UNETR +from monai.networks.nets import UNet, UNETR, DynUNet from monai.data import (DataLoader, Dataset, CacheDataset, load_decathlon_datalist, decollate_batch) from monai.transforms import (Compose, EnsureType, EnsureTyped, Invertd, SaveImaged, SaveImage) @@ -485,15 +485,16 @@ def main(args): optimizer_class = torch.optim.SGD # define models - if args.model in ["unet", "UNet"]: + if args.model in ["unet"]: + logger.info(f" Using ivadomed's UNet model! ") # this is the ivadomed unet model net = ModifiedUNet3D(in_channels=1, out_channels=1, init_filters=args.init_filters) - patch_size = "160x224x96" # "64x128x64" + patch_size = "160x224x96" # "64x128x64" save_exp_id =f"ivado_{args.model}_nf={args.init_filters}_opt={args.optimizer}_lr={args.learning_rate}" \ - f"_CSAdiceL_bestValCSA_nspv={args.num_samples_per_volume}_fltr" \ + f"_CSAdiceL_bestValCSA_nspv={args.num_samples_per_volume}" \ f"_bs={args.batch_size}_{patch_size}" - elif args.model in ["unetr", "UNETR"]: + elif args.model in ["unetr"]: # define image size to be fed to the model img_size = (160, 224, 96) @@ -515,6 +516,34 @@ def main(args): f"_fs={args.feature_size}_hs={args.hidden_size}_mlpd={args.mlp_dim}_nh={args.num_heads}" \ f"_CSAdiceL_nspv={args.num_samples_per_volume}_bs={args.batch_size}_{img_size}" \ + elif args.model in ["dynunet"]: + logger.info(f" Using MONAI's DynUNet model! 
") + + # NOTE: these values are taken from nnUNetPlans.json + kernel_sizes = (3, 3, 3, 3, 3, 3) + stride_sizes = ((1, 1, 1), 2, 2, 2, 2, (1, 2, 2)) + # num_filters = (8, 16, 32, 64, 128, 256) + num_filters = (16, 32, 64, 128, 256, 320) + + # define model + net = DynUNet(spatial_dims=3, + in_channels=1, out_channels=1, + kernel_size=kernel_sizes, + strides=stride_sizes, + upsample_kernel_size=stride_sizes[1:], + filters=num_filters, + norm_name="instance", + deep_supervision=True, + deep_supr_num=4, #(len(stride_sizes)-2), + res_block=True, + dropout=0.3, + ) + patch_size = "160x224x96" + save_exp_id =f"{args.model}_initf=16_DS=4opt={args.optimizer}_lr={args.learning_rate}" \ + f"nspv={args.num_samples_per_volume}" \ + f"_bs={args.batch_size}_{patch_size}" + + # define loss function loss_func = SoftDiceLoss(p=1, smooth=1.0) @@ -622,8 +651,7 @@ def main(args): parser = argparse.ArgumentParser(description='Script for training custom models for SCI Lesion Segmentation.') # Arguments for model, data, and training and saving - parser.add_argument('-m', '--model', - choices=['unet', 'UNet', 'unetr', 'UNETR', 'attentionunet'], + parser.add_argument('-m', '--model', choices=['unet', 'unetr', 'dynunet'], default='unet', type=str, help='Model type to be used') # dataset parser.add_argument('-nspv', '--num_samples_per_volume', default=4, type=int, help="Number of samples to crop per volume") From 41e47b8b81017d7eb5faecfb9445828314375e55 Mon Sep 17 00:00:00 2001 From: Naga Karthik Date: Mon, 28 Aug 2023 23:28:43 -0400 Subject: [PATCH 053/106] update training_step to for handling deepsupervison outputs in the loss function --- monai/main.py | 99 ++++++++++++++++++++++++++++++++++++--------------- 1 file changed, 70 insertions(+), 29 deletions(-) diff --git a/monai/main.py b/monai/main.py index 80e2ac89..45d5cb43 100644 --- a/monai/main.py +++ b/monai/main.py @@ -164,39 +164,80 @@ def training_step(self, batch, batch_idx): inputs, labels = batch["image"], batch["label"] - # check if any label image patch is empty in the batch - if check_empty_patch(labels) is None: - # print(f"Empty label patch found. Skipping training step ...") - return None + # NOTE: surprisingly, filtering out empty patches is adding more CSA bias + # # check if any label image patch is empty in the batch + # if check_empty_patch(labels) is None: + # # print(f"Empty label patch found. 
Skipping training step ...") + # return None output = self.forward(inputs) # logits - - # calculate training loss - # NOTE: the diceLoss expects the input to be logits (which it then normalizes inside) - dice_loss = self.loss_function(output, labels) - - # get probabilities from logits - output = F.relu(output) / F.relu(output).max() if bool(F.relu(output).max()) else F.relu(output) - - # calculate train dice - # NOTE: this is done on patches (and not entire 3D volume) because SlidingWindowInference is not used here - # So, take this dice score with a lot of salt - train_soft_dice = self.soft_dice_metric(output, labels) - # train_hard_dice = self.soft_dice_metric((output.detach() > 0.5).float(), (labels.detach() > 0.5).float()) - - # binarize the predictions and the labels - output = (output.detach() > 0.5).float() - labels = (labels.detach() > 0.5).float() + # if using dynunet, output.shape = (B, num_upsample_layers+1, C, H, W, D) + # print(f"labels.shape: {labels.shape} \t output.shape: {output.shape}") - # compute CSA for each element of the batch - csa_loss = 0.0 - for batch_idx in range(output.shape[0]): - pred_patch_csa = compute_average_csa(output[batch_idx].squeeze(), self.spacing) - gt_patch_csa = compute_average_csa(labels[batch_idx].squeeze(), self.spacing) - csa_loss += (pred_patch_csa - gt_patch_csa) ** 2 + if self.args.model == "dynunet": + # unbind the preds to calculate loss for each output + outputs = torch.unbind(output, dim=1) + + # calculate dice loss for each output + dice_loss, train_soft_dice = 0.0, 0.0 + for i in range(len(outputs)): + # give each output a weight which decreases exponentially (division by 2) as the resolution decreases + # this gives higher resolution outputs more weight in the loss + # NOTE: outputs[0] is the final pred, outputs[-1] is the lowest resolution pred (at the bottleneck) + dice_loss += (0.5 ** i) * self.loss_function(outputs[i], labels) + + # get probabilities from logits + out = F.relu(outputs[i]) / F.relu(outputs[i]).max() if bool(F.relu(outputs[i]).max()) else F.relu(outputs[i]) + + # calculate train dice + # NOTE: this is done on patches (and not entire 3D volume) because SlidingWindowInference is not used here + # So, take this dice score with a lot of salt + train_soft_dice += self.soft_dice_metric(out, labels) + + # average dice loss across the outputs + dice_loss = dice_loss / len(outputs) + train_soft_dice = train_soft_dice / len(outputs) - # average CSA loss across the batch - csa_loss = csa_loss / output.shape[0] + # binarize the predictions and the labels (take only the final feature map i.e. the final prediction) + output = (outputs[0].detach() > 0.5).float() + labels = (labels.detach() > 0.5).float() + + # compute CSA for each element of the batch + # NOTE: the CSA is computed only for the final feature map (i.e. 
the prediction, not the intermediate deepsupervision feature maps) + csa_loss = 0.0 + for batch_idx in range(output.shape[0]): + pred_patch_csa = compute_average_csa(output[batch_idx].squeeze(), self.spacing) + gt_patch_csa = compute_average_csa(labels[batch_idx].squeeze(), self.spacing) + csa_loss += (pred_patch_csa - gt_patch_csa) ** 2 + # average CSA loss across the batch + csa_loss = csa_loss / output.shape[0] + + else: + # calculate training loss + # NOTE: the diceLoss expects the input to be logits (which it then normalizes inside) + dice_loss = self.loss_function(output, labels) + + # get probabilities from logits + output = F.relu(output) / F.relu(output).max() if bool(F.relu(output).max()) else F.relu(output) + + # calculate train dice + # NOTE: this is done on patches (and not entire 3D volume) because SlidingWindowInference is not used here + # So, take this dice score with a lot of salt + train_soft_dice = self.soft_dice_metric(output, labels) + # train_hard_dice = self.soft_dice_metric((output.detach() > 0.5).float(), (labels.detach() > 0.5).float()) + + # binarize the predictions and the labels + output = (output.detach() > 0.5).float() + labels = (labels.detach() > 0.5).float() + + # compute CSA for each element of the batch + csa_loss = 0.0 + for batch_idx in range(output.shape[0]): + pred_patch_csa = compute_average_csa(output[batch_idx].squeeze(), self.spacing) + gt_patch_csa = compute_average_csa(labels[batch_idx].squeeze(), self.spacing) + csa_loss += (pred_patch_csa - gt_patch_csa) ** 2 + # average CSA loss across the batch + csa_loss = csa_loss / output.shape[0] # total loss loss = dice_loss + csa_loss From d6582c389c43833f0c5575e00202f5f4a1b0cce2 Mon Sep 17 00:00:00 2001 From: Naga Karthik Date: Mon, 28 Aug 2023 23:29:40 -0400 Subject: [PATCH 054/106] add model checkpointing based on val csa and dice --- monai/main.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/monai/main.py b/monai/main.py index 45d5cb43..ac60a5cb 100644 --- a/monai/main.py +++ b/monai/main.py @@ -619,7 +619,7 @@ def main(args): exp_logger = pl.loggers.WandbLogger( name=save_exp_id, save_dir=args.save_path, - group=f"{args.model}_final", + group=f"{args.model}", #_final", log_model=True, # save best model using checkpoint callback project='contrast-agnostic', entity='naga-karthik', @@ -629,9 +629,15 @@ def main(args): # checkpoint_callback = pl.callbacks.ModelCheckpoint( # dirpath=save_path, filename='best_model', monitor='val_soft_dice', # save_top_k=5, mode="max", save_last=False, save_weights_only=True) - checkpoint_callback = pl.callbacks.ModelCheckpoint( - dirpath=save_path, filename='best_model', monitor='val_csa_loss', - save_top_k=5, mode="min", save_last=False, save_weights_only=True) + # saving the best model based on validation CSA loss + checkpoint_callback_csa = pl.callbacks.ModelCheckpoint( + dirpath=save_path, filename='best_model_csa', monitor='val_csa_loss', + save_top_k=1, mode="min", save_last=False, save_weights_only=True) + + # saving the best model based on soft validation dice score + checkpoint_callback_dice = pl.callbacks.ModelCheckpoint( + dirpath=save_path, filename='best_model_dice', monitor='val_soft_dice', + save_top_k=1, mode="max", save_last=False, save_weights_only=True) # early_stopping = pl.callbacks.EarlyStopping(monitor="val_soft_dice", min_delta=0.00, patience=args.patience, # verbose=False, mode="max") @@ -644,7 +650,7 @@ def main(args): trainer = pl.Trainer( devices=1, accelerator="gpu", # strategy="ddp", 
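# NOTE (editorial sketch): on the deep-supervision loss in training_step above, head i
# gets weight 0.5**i (full-resolution head first) and the weighted sum is divided by
# the number of heads. nnUNet instead normalizes the weights to sum to 1 and zeroes the
# lowest-resolution head; that variant, assuming `output` is the tuple of per-head logits:
#
#   weights = [0.5 ** i for i in range(len(output))]
#   weights[-1] = 0.0                  # drop the bottleneck head, as nnUNet does
#   norm = sum(weights)
#   dice_loss = sum((w / norm) * self.loss_function(out, labels)
#                   for w, out in zip(weights, output))
#
# With MONAI's DynUNet all supervision heads are already upsampled to the full output
# size (output.shape = (B, num_heads, C, H, W, D)), so full-size labels can be used
# directly, as patch 053 does.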
logger=exp_logger, - callbacks=[checkpoint_callback, lr_monitor, early_stopping], + callbacks=[checkpoint_callback_csa, checkpoint_callback_dice, lr_monitor, early_stopping], check_val_every_n_epoch=args.check_val_every_n_epochs, max_epochs=args.max_epochs, precision=32, # TODO: see if 16-bit precision is stable From 4017b9e126b58eddcf11f207479e822778bb1b79 Mon Sep 17 00:00:00 2001 From: Naga Karthik Date: Tue, 29 Aug 2023 15:20:10 -0400 Subject: [PATCH 055/106] move network helper functions from models.py --- monai/building_blocks.py | 54 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 54 insertions(+) create mode 100644 monai/building_blocks.py diff --git a/monai/building_blocks.py b/monai/building_blocks.py new file mode 100644 index 00000000..b7d8e6a9 --- /dev/null +++ b/monai/building_blocks.py @@ -0,0 +1,54 @@ +""" +Some useful blocks for building the network architecture. +""" + +import torch.nn as nn + + +def conv_norm_lrelu(feat_in, feat_out): + """Conv3D + InstanceNorm3D + LeakyReLU block""" + return nn.Sequential( + nn.Conv3d(feat_in, feat_out, kernel_size=3, stride=1, padding=1, bias=False), + nn.InstanceNorm3d(feat_out), + nn.LeakyReLU() + ) + + +def norm_lrelu_conv(feat_in, feat_out): + """InstanceNorm3D + LeakyReLU + Conv3D block""" + return nn.Sequential( + nn.InstanceNorm3d(feat_in), + nn.LeakyReLU(), + nn.Conv3d(feat_in, feat_out, kernel_size=3, stride=1, padding=1, bias=False) + ) + + +def lrelu_conv(feat_in, feat_out): + """LeakyReLU + Conv3D block""" + return nn.Sequential( + nn.LeakyReLU(), + nn.Conv3d(feat_in, feat_out, kernel_size=3, stride=1, padding=1, bias=False) + ) + + +def norm_lrelu_upscale_conv_norm_lrelu(feat_in, feat_out): + """InstanceNorm3D + LeakyReLU + 2X Upsample + Conv3D + InstanceNorm3D + LeakyReLU block""" + return nn.Sequential( + nn.InstanceNorm3d(feat_in), + nn.LeakyReLU(), + nn.Upsample(scale_factor=2, mode='nearest'), + nn.Conv3d(feat_in, feat_out, kernel_size=3, stride=1, padding=1, bias=False), + nn.InstanceNorm3d(feat_out), + nn.LeakyReLU() + ) + + +class InitWeights_He(object): + def __init__(self, neg_slope=1e-2): + self.neg_slope = neg_slope + + def __call__(self, module): + if isinstance(module, nn.Conv3d) or isinstance(module, nn.ConvTranspose3d): + module.weight = nn.init.kaiming_normal_(module.weight, a=self.neg_slope) + if module.bias is not None: + module.bias = nn.init.constant_(module.bias, 0) From 674c5b46aef1562d9978dbfeebf5cb2f2bd62830 Mon Sep 17 00:00:00 2001 From: Naga Karthik Date: Tue, 29 Aug 2023 15:21:34 -0400 Subject: [PATCH 056/106] add script for creating MSD datalists for running inference on pathlogy datasets --- monai/create_inference_msd_datalist.py | 98 ++++++++++++++++++++++++++ 1 file changed, 98 insertions(+) create mode 100644 monai/create_inference_msd_datalist.py diff --git a/monai/create_inference_msd_datalist.py b/monai/create_inference_msd_datalist.py new file mode 100644 index 00000000..9cbcba56 --- /dev/null +++ b/monai/create_inference_msd_datalist.py @@ -0,0 +1,98 @@ +import os +import json +import argparse +import joblib +from loguru import logger + +parser = argparse.ArgumentParser(description='Code for creating k-fold splits of the spine-generic dataset.') + +parser.add_argument('-dname', '--dataset-name', default='spine-generic', type=str, help='Name of the dataset') +parser.add_argument('-pd', '--path-data', required=True, type=str, help='Path to the data set directory') +parser.add_argument('-pj', '--path-joblib', help='Path to joblib file from ivadomed containing the dataset splits.', 
+ default=None, type=str) +parser.add_argument('-po', '--path-out', type=str, help='Path to the output directory where dataset json is saved') +parser.add_argument('-csuf', '--contrast-suffix', type=str, default='T1w', + help='Contrast suffix used in the BIDS dataset') +args = parser.parse_args() + + +def main(args): + + root = args.path_data + contrast = args.contrast_suffix + + # Get all subjects + # the participants.tsv file might not be up-to-date, hence rely on the existing folders + # subjects_df = pd.read_csv(os.path.join(root, 'participants.tsv'), sep='\t') + # subjects = subjects_df['participant_id'].values.tolist() + subjects = [subject for subject in os.listdir(root) if subject.startswith('sub-')] + logger.info(f"Total number of subjects in the root directory: {len(subjects)}") + + + if args.path_joblib is not None: + # load information from the joblib to match train and test subjects + joblib_file = os.path.join(args.path_joblib, 'split_datasets_all_seed=15.joblib') + splits = joblib.load("split_datasets_all_seed=15.joblib") + # get the subjects from the joblib file + # train_subjects = sorted(list(set([sub.split('_')[0] for sub in splits['train']]))) + # val_subjects = sorted(list(set([sub.split('_')[0] for sub in splits['valid']]))) + test_subjects = sorted(list(set([sub.split('_')[0] for sub in splits['test']]))) + + else: + test_subjects = subjects + + logger.info(f"Number of testing subjects: {len(test_subjects)}") + + # keys to be defined in the dataset_0.json + params = {} + params["description"] = args.dataset_name + params["labels"] = { + "0": "background", + "1": "soft-sc-seg" + } + params["license"] = "nk" + params["modality"] = { + "0": "MRI" + } + params["name"] = "spine-generic" + params["numTest"] = len(test_subjects) + params["reference"] = "University of Zurich" + params["tensorImageSize"] = "3D" + + test_subjects_dict = {"test": test_subjects} + + for name, subs_list in test_subjects_dict.items(): + + temp_list = [] + for subject_no, subject in enumerate(subs_list): + + temp_data= {} + + temp_data["image"] = os.path.join(root, subject, 'anat', f"{subject}_{contrast}.nii.gz") + if args.dataset_name == "sci-colorado": + temp_data["label"] = os.path.join(root, "derivatives", "labels", subject, 'anat', f"{subject}_{contrast}_seg-manual.nii.gz") + elif args.dataset_name == "basel-mp2rage-rpi": + temp_data["label"] = os.path.join(root, "derivatives", "labels", subject, 'anat', f"{subject}_{contrast}_label-SC_seg.nii.gz") + else: + raise NotImplementedError(f"Dataset {args.dataset_name} not implemented yet.") + + if os.path.exists(temp_data["label"]) and os.path.exists(temp_data["image"]): + temp_list.append(temp_data) + else: + logger.info(f"Subject {subject} does not have label or image file. 
Skipping it.") + + params[name] = temp_list + logger.info(f"Number of images in {name} set: {len(temp_list)}") + + final_json = json.dumps(params, indent=4, sort_keys=True) + jsonFile = open(args.path_out + "/" + f"{args.dataset_name}_dataset.json", "w") + jsonFile.write(final_json) + jsonFile.close() + + +if __name__ == "__main__": + main(args) + + + + From 91ebb0e8f44b572f79502d8919bc794f9d4eb8a9 Mon Sep 17 00:00:00 2001 From: Naga Karthik Date: Tue, 29 Aug 2023 15:22:56 -0400 Subject: [PATCH 057/106] add function to create model used in nnunet --- monai/models.py | 155 +++++++++++++++++++++++++++++++++++++----------- 1 file changed, 119 insertions(+), 36 deletions(-) diff --git a/monai/models.py b/monai/models.py index 8d11516f..195a3cf3 100644 --- a/monai/models.py +++ b/monai/models.py @@ -1,44 +1,111 @@ import torch import torch.nn as nn import torch.nn.functional as F - -def conv_norm_lrelu(feat_in, feat_out): - """Conv3D + InstanceNorm3D + LeakyReLU block""" - return nn.Sequential( - nn.Conv3d(feat_in, feat_out, kernel_size=3, stride=1, padding=1, bias=False), - nn.InstanceNorm3d(feat_out), - nn.LeakyReLU() - ) - - -def norm_lrelu_conv(feat_in, feat_out): - """InstanceNorm3D + LeakyReLU + Conv3D block""" - return nn.Sequential( - nn.InstanceNorm3d(feat_in), - nn.LeakyReLU(), - nn.Conv3d(feat_in, feat_out, kernel_size=3, stride=1, padding=1, bias=False) - ) - - -def lrelu_conv(feat_in, feat_out): - """LeakyReLU + Conv3D block""" - return nn.Sequential( - nn.LeakyReLU(), - nn.Conv3d(feat_in, feat_out, kernel_size=3, stride=1, padding=1, bias=False) +from building_blocks import conv_norm_lrelu, norm_lrelu_conv, lrelu_conv, norm_lrelu_upscale_conv_norm_lrelu, InitWeights_He + +# ---------------------------- Imports for nnUNet's Model ----------------------------- +from dynamic_network_architectures.architectures.unet import PlainConvUNet, ResidualEncoderUNet +from dynamic_network_architectures.building_blocks.helper import get_matching_instancenorm, convert_dim_to_conv_op +from dynamic_network_architectures.initialization.weight_init import init_last_bn_before_add_to_0 + + +# ====================================================================================================== +# Define plans json taken from nnUNet +# ====================================================================================================== +nnunet_plans = { + "UNet_class_name": "PlainConvUNet", + "UNet_base_num_features": 32, + "n_conv_per_stage_encoder": [2, 2, 2, 2, 2, 2], + "n_conv_per_stage_decoder": [2, 2, 2, 2, 2], + "pool_op_kernel_sizes": [ + [1, 1, 1], + [2, 2, 2], + [2, 2, 2], + [2, 2, 2], + [2, 2, 2], + [1, 2, 2] + ], + "conv_kernel_sizes": [ + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3] + ], + "unet_max_num_features": 320, +} + + +# ====================================================================================================== +# Define the network based on plans json +# ==================================================================================================== +def create_nnunet_from_plans(plans, num_input_channels: int, num_classes: int, deep_supervision: bool = True): + """ + Adapted from nnUNet's source code: + https://github.com/MIC-DKFZ/nnUNet/blob/master/nnunetv2/utilities/get_network_from_plans.py#L9 + + """ + num_stages = len(plans["conv_kernel_sizes"]) + + dim = len(plans["conv_kernel_sizes"][0]) + conv_op = convert_dim_to_conv_op(dim) + + segmentation_network_class_name = plans["UNet_class_name"] + mapping = { + 'PlainConvUNet': 
PlainConvUNet, + 'ResidualEncoderUNet': ResidualEncoderUNet + } + kwargs = { + 'PlainConvUNet': { + 'conv_bias': True, + 'norm_op': get_matching_instancenorm(conv_op), + 'norm_op_kwargs': {'eps': 1e-5, 'affine': True}, + 'dropout_op': None, 'dropout_op_kwargs': None, + 'nonlin': nn.LeakyReLU, 'nonlin_kwargs': {'inplace': True}, + }, + 'ResidualEncoderUNet': { + 'conv_bias': True, + 'norm_op': get_matching_instancenorm(conv_op), + 'norm_op_kwargs': {'eps': 1e-5, 'affine': True}, + 'dropout_op': None, 'dropout_op_kwargs': None, + 'nonlin': nn.LeakyReLU, 'nonlin_kwargs': {'inplace': True}, + } + } + assert segmentation_network_class_name in mapping.keys(), 'The network architecture specified by the plans file ' \ + 'is non-standard (maybe your own?). Yo\'ll have to dive ' \ + 'into either this ' \ + 'function (get_network_from_plans) or ' \ + 'the init of your nnUNetModule to accomodate that.' + network_class = mapping[segmentation_network_class_name] + + conv_or_blocks_per_stage = { + 'n_conv_per_stage' + if network_class != ResidualEncoderUNet else 'n_blocks_per_stage': plans["n_conv_per_stage_encoder"], + 'n_conv_per_stage_decoder': plans["n_conv_per_stage_decoder"] + } + + # network class name!! + model = network_class( + input_channels=num_input_channels, + n_stages=num_stages, + features_per_stage=[min(plans["UNet_base_num_features"] * 2 ** i, + plans["unet_max_num_features"]) for i in range(num_stages)], + conv_op=conv_op, + kernel_sizes=plans["conv_kernel_sizes"], + strides=plans["pool_op_kernel_sizes"], + num_classes=num_classes, + deep_supervision=deep_supervision, + **conv_or_blocks_per_stage, + **kwargs[segmentation_network_class_name] ) + model.apply(InitWeights_He(1e-2)) + if network_class == ResidualEncoderUNet: + model.apply(init_last_bn_before_add_to_0) + + return model -def norm_lrelu_upscale_conv_norm_lrelu(feat_in, feat_out): - """InstanceNorm3D + LeakyReLU + 2X Upsample + Conv3D + InstanceNorm3D + LeakyReLU block""" - return nn.Sequential( - nn.InstanceNorm3d(feat_in), - nn.LeakyReLU(), - nn.Upsample(scale_factor=2, mode='nearest'), - nn.Conv3d(feat_in, feat_out, kernel_size=3, stride=1, padding=1, bias=False), - nn.InstanceNorm3d(feat_out), - nn.LeakyReLU() - ) - # ---------------------------- ModifiedUNet3D Encoder Implementation ----------------------------- class ModifiedUNet3DEncoder(nn.Module): """Encoder for ModifiedUNet3D. 
Adapted from ivadomed.models""" @@ -243,4 +310,20 @@ def forward(self, x): seg_logits = self.unet_decoder(x, context_features) - return seg_logits \ No newline at end of file + return seg_logits + + + +if __name__ == "__main__": + + enable_deep_supervision = True + model = create_nnunet_from_plans(nnunet_plans, 1, 1, enable_deep_supervision) + input = torch.randn(1, 1, 160, 224, 96) + output = model(input) + if enable_deep_supervision: + for i in range(len(output)): + print(output[i].shape) + else: + print(output.shape) + + # print(output.shape) From 2f511e24e8f97e38bc5e2377fd2e3b159ea9bf4d Mon Sep 17 00:00:00 2001 From: Naga Karthik Date: Tue, 29 Aug 2023 15:26:20 -0400 Subject: [PATCH 058/106] add option to train using the model used in nnunet --- monai/main.py | 45 +++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 43 insertions(+), 2 deletions(-) diff --git a/monai/main.py b/monai/main.py index ac60a5cb..d05bd0f3 100644 --- a/monai/main.py +++ b/monai/main.py @@ -14,7 +14,7 @@ PolyLRScheduler, check_empty_patch from losses import SoftDiceLoss from transforms import train_transforms, val_transforms -from models import ModifiedUNet3D +from models import ModifiedUNet3D, create_nnunet_from_plans from monai.utils import set_determinism from monai.inferers import sliding_window_inference @@ -516,6 +516,33 @@ def main(args): # Setting the seed pl.seed_everything(args.seed, workers=True) + # ====================================================================================================== + # Define plans json taken from nnUNet_preprocessed folder + # ====================================================================================================== + nnunet_plans = { + "UNet_class_name": "PlainConvUNet", + "UNet_base_num_features": args.init_filters, + "n_conv_per_stage_encoder": [2, 2, 2, 2, 2, 2], + "n_conv_per_stage_decoder": [2, 2, 2, 2, 2], + "pool_op_kernel_sizes": [ + [1, 1, 1], + [2, 2, 2], + [2, 2, 2], + [2, 2, 2], + [2, 2, 2], + [1, 2, 2] + ], + "conv_kernel_sizes": [ + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3] + ], + "unet_max_num_features": 320, + } + # define root path for finding datalists dataset_root = "/home/GRAMES.POLYMTL.CA/u114716/contrast-agnostic/contrast-agnostic-softseg-spinalcord/monai" @@ -584,6 +611,19 @@ def main(args): f"nspv={args.num_samples_per_volume}" \ f"_bs={args.batch_size}_{patch_size}" + elif args.model in ["nnunet"]: + if args.enable_DS: + logger.info(f" Using nnUNet model WITH deep supervision! ") + else: + logger.info(f" Using nnUNet model WITHOUT deep supervision! 
") + + # define model + net = create_nnunet_from_plans(plans=nnunet_plans, num_input_channels=1, num_classes=1, deep_supervision=args.enable_DS) + patch_size = "160x224x96" + save_exp_id =f"{args.model}_nf={args.init_filters}_DS={int(args.enable_DS)}" \ + f"_opt={args.optimizer}_lr={args.learning_rate}" \ + f"_CSAdiceL_nspv={args.num_samples_per_volume}" \ + f"_bs={args.batch_size}_{patch_size}" # define loss function loss_func = SoftDiceLoss(p=1, smooth=1.0) @@ -698,8 +738,9 @@ def main(args): parser = argparse.ArgumentParser(description='Script for training custom models for SCI Lesion Segmentation.') # Arguments for model, data, and training and saving - parser.add_argument('-m', '--model', choices=['unet', 'unetr', 'dynunet'], + parser.add_argument('-m', '--model', choices=['unet', 'unetr', 'nnunet'], default='unet', type=str, help='Model type to be used') + parser.add_argument('--enable_DS', default=False, action='store_true', help='Enable Deep Supervision') # dataset parser.add_argument('-nspv', '--num_samples_per_volume', default=4, type=int, help="Number of samples to crop per volume") parser.add_argument('-ncv', '--num_cv_folds', default=5, type=int, help="Number of cross validation folds") From 2e2a06150fb76782633601076f39acb1183baaa8 Mon Sep 17 00:00:00 2001 From: Naga Karthik Date: Tue, 29 Aug 2023 15:27:20 -0400 Subject: [PATCH 059/106] update train/validation step to deal with deepsupervison outputs in loss calculation --- monai/main.py | 26 ++++++++++++++++---------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/monai/main.py b/monai/main.py index d05bd0f3..428c4035 100644 --- a/monai/main.py +++ b/monai/main.py @@ -174,32 +174,34 @@ def training_step(self, batch, batch_idx): # if using dynunet, output.shape = (B, num_upsample_layers+1, C, H, W, D) # print(f"labels.shape: {labels.shape} \t output.shape: {output.shape}") - if self.args.model == "dynunet": - # unbind the preds to calculate loss for each output - outputs = torch.unbind(output, dim=1) + if self.args.model == "nnunet" and self.args.enable_DS: # calculate dice loss for each output dice_loss, train_soft_dice = 0.0, 0.0 - for i in range(len(outputs)): + for i in range(len(output)): # give each output a weight which decreases exponentially (division by 2) as the resolution decreases # this gives higher resolution outputs more weight in the loss # NOTE: outputs[0] is the final pred, outputs[-1] is the lowest resolution pred (at the bottleneck) - dice_loss += (0.5 ** i) * self.loss_function(outputs[i], labels) + # we're downsampling the GT to the resolution of each deepsupervision feature map output + # (instead of upsampling each deepsupervision feature map output to the final resolution) + downsampled_gt = F.interpolate(labels, size=output[i].shape[-3:], mode='trilinear', align_corners=False) + # print(f"downsampled_gt.shape: {downsampled_gt.shape} \t output[i].shape: {output[i].shape}") + dice_loss += (0.5 ** i) * self.loss_function(output[i], downsampled_gt) # get probabilities from logits - out = F.relu(outputs[i]) / F.relu(outputs[i]).max() if bool(F.relu(outputs[i]).max()) else F.relu(outputs[i]) + out = F.relu(output[i]) / F.relu(output[i]).max() if bool(F.relu(output[i]).max()) else F.relu(output[i]) # calculate train dice # NOTE: this is done on patches (and not entire 3D volume) because SlidingWindowInference is not used here # So, take this dice score with a lot of salt - train_soft_dice += self.soft_dice_metric(out, labels) + train_soft_dice += self.soft_dice_metric(out, downsampled_gt) # 
average dice loss across the outputs - dice_loss = dice_loss / len(outputs) - train_soft_dice = train_soft_dice / len(outputs) + dice_loss = dice_loss / len(output) + train_soft_dice = train_soft_dice / len(output) # binarize the predictions and the labels (take only the final feature map i.e. the final prediction) - output = (outputs[0].detach() > 0.5).float() + output = (output[0].detach() > 0.5).float() labels = (labels.detach() > 0.5).float() # compute CSA for each element of the batch @@ -308,6 +310,10 @@ def validation_step(self, batch, batch_idx): sw_batch_size=4, predictor=self.forward, overlap=0.5,) # outputs shape: (B, C, ) + if self.args.model == "nnunet" and self.args.enable_DS: + # we only need the output with the highest resolution + outputs = outputs[0] + # calculate validation loss dice_loss = self.loss_function(outputs, labels) From d1fc56fad568175ccbfbd7f8bbc9b196b0e8b2d0 Mon Sep 17 00:00:00 2001 From: Naga Karthik Date: Tue, 29 Aug 2023 16:49:48 -0400 Subject: [PATCH 060/106] add options to create datalists per contrast and specify hard/soft labels --- monai/create_msd_data.py | 334 +++++++++++++++++---------------------- 1 file changed, 147 insertions(+), 187 deletions(-) diff --git a/monai/create_msd_data.py b/monai/create_msd_data.py index 488d90e9..14a6d492 100644 --- a/monai/create_msd_data.py +++ b/monai/create_msd_data.py @@ -12,19 +12,29 @@ parser = argparse.ArgumentParser(description='Code for creating k-fold splits of the spine-generic dataset.') -parser.add_argument('--seed', default=42, type=int, help="Seed for reproducibility") -parser.add_argument('-ncvf', '--num-cv-folds', default=5, type=int, - help="[1-k] To create a k-fold dataset for cross validation, 0 for single file with all subjects") parser.add_argument('-pd', '--path-data', required=True, type=str, help='Path to the data set directory') parser.add_argument('-pj', '--path-joblib', help='Path to joblib file from ivadomed containing the dataset splits.', default=None, type=str) parser.add_argument('-po', '--path-out', type=str, help='Path to the output directory where dataset json is saved') +parser.add_argument("--contrast", default="t2w", type=str, help="Contrast to use for training", + choices=["t1w", "t2w", "t2star", "mton", "mtoff", "dwi", "all"]) +parser.add_argument('--label-type', default='soft', type=str, help="Type of labels to use for training", + choices=['hard', 'soft']) +parser.add_argument('--seed', default=42, type=int, help="Seed for reproducibility") args = parser.parse_args() root = args.path_data seed = args.seed -num_cv_folds = args.num_cv_folds # for 100 subjects, performs a 60-20-20 split with num_cv_folds +contrast = args.contrast +if args.label_type == 'soft': + logger.info("Using SOFT LABELS ...") + PATH_DERIVATIVES = os.path.join(root, "derivatives", "labels_softseg") + SUFFIX = "softseg" +else: + logger.info("Using HARD LABELS ...") + PATH_DERIVATIVES = os.path.join(root, "derivatives", "labels") + SUFFIX = "seg-manual" # Get all subjects # the participants.tsv file might not be up-to-date, hence rely on the existing folders @@ -33,177 +43,62 @@ subjects = [subject for subject in os.listdir(root) if subject.startswith('sub-')] logger.info(f"Total number of subjects in the root directory: {len(subjects)}") -if args.num_cv_folds != 0: - # create k-fold CV datasets as usual - - # returns a nested list of length (num_cv_folds), each element (again, a list) consisting of - # train, val, test indices and the fold number - names_list = FoldGenerator(seed, num_cv_folds, 
len_data=len(subjects)).get_fold_names() - - for fold in range(num_cv_folds): - - train_ix, val_ix, test_ix, fold_num = names_list[fold] - training_subjects = [subjects[tr_ix] for tr_ix in train_ix] - validation_subjects = [subjects[v_ix] for v_ix in val_ix] - test_subjects = [subjects[te_ix] for te_ix in test_ix] - - # keys to be defined in the dataset_0.json - params = {} - params["description"] = "sci-zurich naga" - params["labels"] = { - "0": "background", - "1": "sc-lesion" - } - params["license"] = "nk" - params["modality"] = { - "0": "MRI" - } - params["name"] = "sci-zurich" - params["numTest"] = len(test_subjects) - params["numTraining"] = len(training_subjects) + len(validation_subjects) - params["reference"] = "University of Zurich" - params["tensorImageSize"] = "3D" - - - train_val_subjects_dict = { - "training": training_subjects, - "validation": validation_subjects, - } - test_subjects_dict = {"test": test_subjects} - - # run loop for training and validation subjects - temp_shapes_list = [] - for name, subs_list in train_val_subjects_dict.items(): - - temp_list = [] - for subject_no, subject in enumerate(tqdm(subs_list, desc='Loading Volumes')): - - # Another for loop for going through sessions - temp_subject_path = os.path.join(root, subject) - num_sessions_per_subject = sum(os.path.isdir(os.path.join(temp_subject_path, pth)) for pth in os.listdir(temp_subject_path)) - - for ses_idx in range(1, num_sessions_per_subject+1): - temp_data = {} - # Get paths with session numbers - session = 'ses-0' + str(ses_idx) - subject_images_path = os.path.join(root, subject, session, 'anat') - subject_labels_path = os.path.join(root, 'derivatives', 'labels', subject, session, 'anat') - - subject_image_file = os.path.join(subject_images_path, '%s_%s_acq-sag_T2w.nii.gz' % (subject, session)) - subject_label_file = os.path.join(subject_labels_path, '%s_%s_acq-sag_T2w_lesion-manual.nii.gz' % (subject, session)) - - # get shapes of each subject to calculate median later - # temp_shapes_list.append(np.shape(nib.load(subject_image_file).get_fdata())) - - # # load GT mask - # gt_label = nib.load(subject_label_file).get_fdata() - # bbox_coords = get_bounding_boxes(mask=gt_label) - - # store in a temp dictionary - temp_data["image"] = subject_image_file.replace(root+"/", '') # .strip(root) - temp_data["label"] = subject_label_file.replace(root+"/", '') # .strip(root) - # temp_data["box"] = bbox_coords - - temp_list.append(temp_data) - - params[name] = temp_list - - # print(temp_shapes_list) - # calculate the median shapes along each axis - params["train_val_median_shape"] = np.median(temp_shapes_list, axis=0).tolist() - - # run separate loop for testing - for name, subs_list in test_subjects_dict.items(): - temp_list = [] - for subject_no, subject in enumerate(tqdm(subs_list, desc='Loading Volumes')): - - # Another for loop for going through sessions - temp_subject_path = os.path.join(root, subject) - num_sessions_per_subject = sum(os.path.isdir(os.path.join(temp_subject_path, pth)) for pth in os.listdir(temp_subject_path)) - - for ses_idx in range(1, num_sessions_per_subject+1): - temp_data = {} - # Get paths with session numbers - session = 'ses-0' + str(ses_idx) - subject_images_path = os.path.join(root, subject, session, 'anat') - subject_labels_path = os.path.join(root, 'derivatives', 'labels', subject, session, 'anat') - - subject_image_file = os.path.join(subject_images_path, '%s_%s_acq-sag_T2w.nii.gz' % (subject, session)) - subject_label_file = os.path.join(subject_labels_path, 
'%s_%s_acq-sag_T2w_lesion-manual.nii.gz' % (subject, session)) - - # # load GT mask - # gt_label = nib.load(subject_label_file).get_fdata() - # bbox_coords = get_bounding_boxes(mask=gt_label) - - temp_data["image"] = subject_image_file.replace(root+"/", '') - temp_data["label"] = subject_label_file.replace(root+"/", '') - # temp_data["box"] = bbox_coords - - temp_list.append(temp_data) - - params[name] = temp_list - - final_json = json.dumps(params, indent=4, sort_keys=True) - jsonFile = open(root + "/" + f"dataset_fold-{fold_num}.json", "w") - jsonFile.write(final_json) - jsonFile.close() -else: - - if args.path_joblib is not None: - # load information from the joblib to match train and test subjects - joblib_file = os.path.join(args.path_joblib, 'split_datasets_all_seed=15.joblib') - splits = joblib.load(joblib_file) - # get the subjects from the joblib file - train_subjects = sorted(list(set([sub.split('_')[0] for sub in splits['train']]))) - val_subjects = sorted(list(set([sub.split('_')[0] for sub in splits['valid']]))) - test_subjects = sorted(list(set([sub.split('_')[0] for sub in splits['test']]))) - - else: - # create one json file with 60-20-20 train-val-test split - train_ratio, val_ratio, test_ratio = 0.6, 0.2, 0.2 - train_subjects, test_subjects = train_test_split(subjects, test_size=test_ratio, random_state=args.seed) - # Use the training split to further split into training and validation splits - train_subjects, val_subjects = train_test_split(train_subjects, test_size=val_ratio / (train_ratio + val_ratio), - random_state=args.seed, ) - - logger.info(f"Number of training subjects: {len(train_subjects)}") - logger.info(f"Number of validation subjects: {len(val_subjects)}") - logger.info(f"Number of testing subjects: {len(test_subjects)}") - - # keys to be defined in the dataset_0.json - params = {} - params["description"] = "spine-generic-uncropped" - params["labels"] = { - "0": "background", - "1": "soft-sc-seg" - } - params["license"] = "nk" - params["modality"] = { - "0": "MRI" - } - params["name"] = "spine-generic" - params["numTest"] = len(test_subjects) - params["numTraining"] = len(train_subjects) - params["numValidation"] = len(val_subjects) - params["seed"] = args.seed - params["reference"] = "University of Zurich" - params["tensorImageSize"] = "3D" - - train_subjects_dict = {"train": train_subjects} - val_subjects_dict = {"validation": val_subjects} - test_subjects_dict = {"test": test_subjects} - all_subjects_list = [train_subjects_dict, val_subjects_dict, test_subjects_dict] - - # define the contrasts - contrasts_list = ['T1w', 'T2w', 'T2star', 'flip-1_mt-on_MTS', 'flip-2_mt-off_MTS', 'dwi'] - - for subjects_dict in tqdm(all_subjects_list, desc="Iterating through train/val/test splits"): - - for name, subs_list in subjects_dict.items(): - - temp_list = [] - for subject_no, subject in enumerate(subs_list): - +if args.path_joblib is not None: + # load information from the joblib to match train and test subjects + # joblib_file = os.path.join(args.path_joblib, 'split_datasets_all_seed=15.joblib') + splits = joblib.load(args.path_joblib) + # get the subjects from the joblib file + train_subjects = sorted(list(set([sub.split('_')[0] for sub in splits['train']]))) + val_subjects = sorted(list(set([sub.split('_')[0] for sub in splits['valid']]))) + test_subjects = sorted(list(set([sub.split('_')[0] for sub in splits['test']]))) + +else: + # create one json file with 60-20-20 train-val-test split + train_ratio, val_ratio, test_ratio = 0.6, 0.2, 0.2 + train_subjects, 
test_subjects = train_test_split(subjects, test_size=test_ratio, random_state=args.seed) + # Use the training split to further split into training and validation splits + train_subjects, val_subjects = train_test_split(train_subjects, test_size=val_ratio / (train_ratio + val_ratio), + random_state=args.seed, ) + +logger.info(f"Number of training subjects: {len(train_subjects)}") +logger.info(f"Number of validation subjects: {len(val_subjects)}") +logger.info(f"Number of testing subjects: {len(test_subjects)}") + +# keys to be defined in the dataset_0.json +params = {} +params["description"] = "spine-generic-uncropped" +params["labels"] = { + "0": "background", + "1": "soft-sc-seg" + } +params["license"] = "nk" +params["modality"] = { + "0": "MRI" + } +params["name"] = "spine-generic" +params["numTest"] = len(test_subjects) +params["numTraining"] = len(train_subjects) +params["numValidation"] = len(val_subjects) +params["seed"] = args.seed +params["reference"] = "University of Zurich" +params["tensorImageSize"] = "3D" + +train_subjects_dict = {"train": train_subjects} +val_subjects_dict = {"validation": val_subjects} +test_subjects_dict = {"test": test_subjects} +all_subjects_list = [train_subjects_dict, val_subjects_dict, test_subjects_dict] + +# # define the contrasts +# contrasts_list = ['T1w', 'T2w', 'T2star', 'flip-1_mt-on_MTS', 'flip-2_mt-off_MTS', 'dwi'] + +for subjects_dict in tqdm(all_subjects_list, desc="Iterating through train/val/test splits"): + + for name, subs_list in subjects_dict.items(): + + temp_list = [] + for subject_no, subject in enumerate(subs_list): + + if contrast == "all": temp_data_t1w = {} temp_data_t2w = {} temp_data_t2star = {} @@ -213,47 +108,112 @@ # t1w temp_data_t1w["image"] = os.path.join(root, subject, 'anat', f"{subject}_T1w.nii.gz") - temp_data_t1w["label"] = os.path.join(root, "derivatives", "labels_softseg", subject, 'anat', f"{subject}_T1w_softseg.nii.gz") + temp_data_t1w["label"] = os.path.join(PATH_DERIVATIVES, subject, 'anat', f"{subject}_T1w_{SUFFIX}.nii.gz") if os.path.exists(temp_data_t1w["label"]) and os.path.exists(temp_data_t1w["image"]): temp_list.append(temp_data_t1w) # t2w temp_data_t2w["image"] = os.path.join(root, subject, 'anat', f"{subject}_T2w.nii.gz") - temp_data_t2w["label"] = os.path.join(root, "derivatives", "labels_softseg", subject, 'anat', f"{subject}_T2w_softseg.nii.gz") + temp_data_t2w["label"] = os.path.join(PATH_DERIVATIVES, subject, 'anat', f"{subject}_T2w_{SUFFIX}.nii.gz") if os.path.exists(temp_data_t2w["label"]) and os.path.exists(temp_data_t2w["image"]): temp_list.append(temp_data_t2w) # t2star temp_data_t2star["image"] = os.path.join(root, subject, 'anat', f"{subject}_T2star.nii.gz") - temp_data_t2star["label"] = os.path.join(root, "derivatives", "labels_softseg", subject, 'anat', f"{subject}_T2star_softseg.nii.gz") + temp_data_t2star["label"] = os.path.join(PATH_DERIVATIVES, subject, 'anat', f"{subject}_T2star_{SUFFIX}.nii.gz") if os.path.exists(temp_data_t2star["label"]) and os.path.exists(temp_data_t2star["image"]): temp_list.append(temp_data_t2star) # mton_mts temp_data_mton_mts["image"] = os.path.join(root, subject, 'anat', f"{subject}_flip-1_mt-on_MTS.nii.gz") - temp_data_mton_mts["label"] = os.path.join(root, "derivatives", "labels_softseg", subject, 'anat', f"{subject}_flip-1_mt-on_MTS_softseg.nii.gz") + temp_data_mton_mts["label"] = os.path.join(PATH_DERIVATIVES, subject, 'anat', f"{subject}_flip-1_mt-on_MTS_{SUFFIX}.nii.gz") if os.path.exists(temp_data_mton_mts["label"]) and 
os.path.exists(temp_data_mton_mts["image"]): temp_list.append(temp_data_mton_mts) # t1w_mts temp_data_mtoff_mts["image"] = os.path.join(root, subject, 'anat', f"{subject}_flip-2_mt-off_MTS.nii.gz") - temp_data_mtoff_mts["label"] = os.path.join(root, "derivatives", "labels_softseg", subject, 'anat', f"{subject}_flip-2_mt-off_MTS_softseg.nii.gz") + temp_data_mtoff_mts["label"] = os.path.join(PATH_DERIVATIVES, subject, 'anat', f"{subject}_flip-2_mt-off_MTS_{SUFFIX}.nii.gz") if os.path.exists(temp_data_mtoff_mts["label"]) and os.path.exists(temp_data_mtoff_mts["image"]): temp_list.append(temp_data_mtoff_mts) # dwi temp_data_dwi["image"] = os.path.join(root, subject, 'dwi', f"{subject}_rec-average_dwi.nii.gz") - temp_data_dwi["label"] = os.path.join(root, "derivatives", "labels_softseg", subject, 'dwi', f"{subject}_rec-average_dwi_softseg.nii.gz") + temp_data_dwi["label"] = os.path.join(PATH_DERIVATIVES, subject, 'dwi', f"{subject}_rec-average_dwi_{SUFFIX}.nii.gz") + if os.path.exists(temp_data_dwi["label"]) and os.path.exists(temp_data_dwi["image"]): + temp_list.append(temp_data_dwi) + + + elif contrast == "t1w": # t1w + temp_data_t1w = {} + temp_data_t1w["image"] = os.path.join(root, subject, 'anat', f"{subject}_T1w.nii.gz") + temp_data_t1w["label"] = os.path.join(PATH_DERIVATIVES, subject, 'anat', f"{subject}_T1w_{SUFFIX}.nii.gz") + if os.path.exists(temp_data_t1w["label"]) and os.path.exists(temp_data_t1w["image"]): + temp_list.append(temp_data_t1w) + else: + logger.info(f"Subject {subject} does not have T1w image or label.") + + + elif contrast == "t2w": # t2w + temp_data_t2w = {} + temp_data_t2w["image"] = os.path.join(root, subject, 'anat', f"{subject}_T2w.nii.gz") + temp_data_t2w["label"] = os.path.join(PATH_DERIVATIVES, subject, 'anat', f"{subject}_T2w_{SUFFIX}.nii.gz") + if os.path.exists(temp_data_t2w["label"]) and os.path.exists(temp_data_t2w["image"]): + temp_list.append(temp_data_t2w) + else: + logger.info(f"Subject {subject} does not have T2w image or label.") + + + elif contrast == "t2star": # t2star + temp_data_t2star = {} + temp_data_t2star["image"] = os.path.join(root, subject, 'anat', f"{subject}_T2star.nii.gz") + temp_data_t2star["label"] = os.path.join(PATH_DERIVATIVES, subject, 'anat', f"{subject}_T2star_{SUFFIX}.nii.gz") + if os.path.exists(temp_data_t2star["label"]) and os.path.exists(temp_data_t2star["image"]): + temp_list.append(temp_data_t2star) + else: + logger.info(f"Subject {subject} does not have T2star image or label.") + + + elif contrast == "mton": # mton_mts + temp_data_mton_mts = {} + temp_data_mton_mts["image"] = os.path.join(root, subject, 'anat', f"{subject}_flip-1_mt-on_MTS.nii.gz") + temp_data_mton_mts["label"] = os.path.join(PATH_DERIVATIVES, subject, 'anat', f"{subject}_flip-1_mt-on_MTS_{SUFFIX}.nii.gz") + if os.path.exists(temp_data_mton_mts["label"]) and os.path.exists(temp_data_mton_mts["image"]): + temp_list.append(temp_data_mton_mts) + else: + logger.info(f"Subject {subject} does not have MTOn image or label.") + + elif contrast == "mtoff": # t1w_mts + temp_data_mtoff_mts = {} + temp_data_mtoff_mts["image"] = os.path.join(root, subject, 'anat', f"{subject}_flip-2_mt-off_MTS.nii.gz") + temp_data_mtoff_mts["label"] = os.path.join(PATH_DERIVATIVES, subject, 'anat', f"{subject}_flip-2_mt-off_MTS_{SUFFIX}.nii.gz") + if os.path.exists(temp_data_mtoff_mts["label"]) and os.path.exists(temp_data_mtoff_mts["image"]): + temp_list.append(temp_data_mtoff_mts) + else: + logger.info(f"Subject {subject} does not have MTOff image or label.") + + elif contrast == 
"dwi": # dwi + temp_data_dwi = {} + temp_data_dwi["image"] = os.path.join(root, subject, 'dwi', f"{subject}_rec-average_dwi.nii.gz") + temp_data_dwi["label"] = os.path.join(PATH_DERIVATIVES, subject, 'dwi', f"{subject}_rec-average_dwi_{SUFFIX}.nii.gz") if os.path.exists(temp_data_dwi["label"]) and os.path.exists(temp_data_dwi["image"]): temp_list.append(temp_data_dwi) + else: + logger.info(f"Subject {subject} does not have DWI image or label.") + + else: + raise ValueError(f"Contrast {contrast} not recognized.") - params[name] = temp_list - logger.info(f"Number of images in {name} set: {len(temp_list)}") + + params[name] = temp_list + logger.info(f"Number of images in {name} set: {len(temp_list)}") + +final_json = json.dumps(params, indent=4, sort_keys=True) +if not os.path.exists(args.path_out): + os.makedirs(args.path_out, exist_ok=True) - final_json = json.dumps(params, indent=4, sort_keys=True) - jsonFile = open(args.path_out + "/" + f"dataset.json", "w") - jsonFile.write(final_json) - jsonFile.close() +jsonFile = open(args.path_out + "/" + f"dataset_{contrast}_{args.label_type}_seed{seed}.json", "w") +jsonFile.write(final_json) +jsonFile.close() From 8a830186f29ae22fd40d16f4c600de0aaedd53d3 Mon Sep 17 00:00:00 2001 From: Naga Karthik Date: Thu, 31 Aug 2023 17:54:32 -0400 Subject: [PATCH 061/106] rearrange transforms to match ivadomed's; add RandSimulateLowResolution --- monai/transforms.py | 77 +++++++++++++++++++++++++++------------------ 1 file changed, 46 insertions(+), 31 deletions(-) diff --git a/monai/transforms.py b/monai/transforms.py index 04f80e28..795a0e10 100644 --- a/monai/transforms.py +++ b/monai/transforms.py @@ -1,9 +1,10 @@ import numpy as np from monai.transforms import (SpatialPadd, Compose, CropForegroundd, LoadImaged, RandFlipd, - RandCropByPosNegLabeld, Spacingd, RandRotated, NormalizeIntensityd, + RandCropByPosNegLabeld, Spacingd, RandRotated, NormalizeIntensityd, RandAffined, RandWeightedCropd, RandAdjustContrastd, EnsureChannelFirstd, RandGaussianNoised, - RandGaussianSmoothd, Orientationd, Rand3DElasticd, RandBiasFieldd) + RandGaussianSmoothd, Orientationd, Rand3DElasticd, RandBiasFieldd, RandSimulateLowResolutiond) +# import torchio as tio # median image size in voxels - taken from nnUNet # median_size = (123, 255, 214) # so pad with this size @@ -16,35 +17,49 @@ # 3. Resample to target spacing def train_transforms(crop_size, num_samples_pv, lbl_key="label"): - return Compose([ - # pre-processing - LoadImaged(keys=["image", lbl_key]), - EnsureChannelFirstd(keys=["image", lbl_key]), - CropForegroundd(keys=["image", lbl_key], source_key="image"), # crops >0 values with a bounding box - NormalizeIntensityd(keys=["image"], nonzero=False, channel_wise=False), - Spacingd(keys=["image", "label"], pixdim=(1.0, 1.0, 1.0), mode=("bilinear", "bilinear"),), - # data-augmentation - SpatialPadd(keys=["image", lbl_key], spatial_size=(192, 228, 106), method="symmetric"), - # NOTE: used with neg together to calculate the ratio pos / (pos + neg) for the probability to pick a - # foreground voxel as a center rather than a background voxel. 
- RandCropByPosNegLabeld(keys=["image", "label"], label_key="label", - spatial_size=crop_size, pos=3, neg=1, num_samples=num_samples_pv, - # if num_samples=4, then 4 samples/image are randomly generated - image_key="image", image_threshold=0.), - # RandGaussianNoised(keys=["image"], mean=0.0, std=0.1, prob=0.1), - Rand3DElasticd(keys=["image", lbl_key], sigma_range=(3.5, 5.5), magnitude_range=(25, 35), prob=0.5), - # TODO: Try Spacingd with low resolution here with prob=0.5 - RandAdjustContrastd(keys=["image"], gamma=(0.7, 1.5), prob=0.4), # this is monai's RandomGamma - RandBiasFieldd(keys=["image"], coeff_range=(0.0, 0.5), degree=3, prob=0.3), - RandGaussianSmoothd(keys=["image"], sigma_x=(0.0, 2.0), sigma_y=(0.0, 2.0), sigma_z=(0.0, 2.0), prob=0.3), - RandFlipd(keys=["image", lbl_key], spatial_axis=None, prob=0.4,), - RandRotated(keys=["image", lbl_key], mode=("bilinear", "nearest"), prob=0.2, - range_x=(-30. / 360 * 2. * np.pi, 30. / 360 * 2. * np.pi), # NOTE: -pi/6 to pi/6 - range_y=(-30. / 360 * 2. * np.pi, 30. / 360 * 2. * np.pi), - range_z=(-30. / 360 * 2. * np.pi, 30. / 360 * 2. * np.pi)), - # # re-orientation - # Orientationd(keys=["image", lbl_key], axcodes="RPI"), # NOTE: if not using it here, then it results in collation error - ]) + + monai_transforms = [ + # pre-processing + LoadImaged(keys=["image", lbl_key]), + EnsureChannelFirstd(keys=["image", lbl_key]), + CropForegroundd(keys=["image", lbl_key], source_key="image"), # crops >0 values with a bounding box + NormalizeIntensityd(keys=["image"], nonzero=False, channel_wise=False), + Spacingd(keys=["image", "label"], pixdim=(1.0, 1.0, 1.0), mode=("bilinear", "bilinear"),), + # data-augmentation + SpatialPadd(keys=["image", lbl_key], spatial_size=(192, 228, 106), method="symmetric"), + # NOTE: used with neg together to calculate the ratio pos / (pos + neg) for the probability to pick a + # foreground voxel as a center rather than a background voxel. + RandCropByPosNegLabeld(keys=["image", "label"], label_key="label", + spatial_size=crop_size, pos=2, neg=1, num_samples=num_samples_pv, + # if num_samples=4, then 4 samples/image are randomly generated + image_key="image", image_threshold=0.), + # transforms used by ivadomed and nnunet + RandAffined(keys=["image", lbl_key], mode=("bilinear", "nearest"), prob=1.0, + rotate_range=(-20.0, 20.0), scale_range=(0.8, 1.2), translate_range=(-0.1, 0.1)), + Rand3DElasticd(keys=["image", lbl_key], sigma_range=(3.5, 5.5), magnitude_range=(25, 35), prob=0.5), + RandSimulateLowResolutiond(keys=["image"], zoom_range=(0.5, 1.0), prob=0.25), + RandAdjustContrastd(keys=["image"], gamma=(0.5, 1.5), prob=0.5), # this is monai's RandomGamma + RandBiasFieldd(keys=["image"], coeff_range=(0.0, 0.5), degree=3, prob=0.3), + RandGaussianSmoothd(keys=["image"], sigma_x=(0.0, 2.0), sigma_y=(0.0, 2.0), sigma_z=(0.0, 2.0), prob=0.3), + # RandGaussianNoised(keys=["image"], mean=0.0, std=0.1, prob=0.1), + # TODO: Try Spacingd with low resolution here with prob=0.5 + # RandFlipd(keys=["image", lbl_key], spatial_axis=None, prob=0.4,), + # RandRotated(keys=["image", lbl_key], mode=("bilinear", "nearest"), prob=0.2, + # range_x=(-30. / 360 * 2. * np.pi, 30. / 360 * 2. * np.pi), # NOTE: -pi/6 to pi/6 + # range_y=(-30. / 360 * 2. * np.pi, 30. / 360 * 2. * np.pi), + # range_z=(-30. / 360 * 2. * np.pi, 30. / 360 * 2. 
* np.pi)),
+        # # re-orientation
+        # Orientationd(keys=["image", lbl_key], axcodes="RPI"),   # NOTE: if not using it here, then it results in collation error
+    ]
+
+    # tio_transforms = [
+    #     # tio.RandomBiasField(coefficients=0.5, order=3, p=0.3, include=["image"]),
+    #     # Multiply spacing of one of the 3 axes by a factor randomly chosen in [1, 4]
+    #     tio.RandomAnisotropy(axes=(0, 1, 2), downsampling=(1.0, 4.0), p=0.3, include=["image", "label"]),   # from nnUNetPlans - median spacing is 0.9x0.9x5.0,
+    # ]
+
+    # return Compose(monai_transforms + tio_transforms)
+    return Compose(monai_transforms)

 def val_transforms(lbl_key="label"):
     return Compose([

From c59edac28e65e3f358411c5167a844a6118305b7 Mon Sep 17 00:00:00 2001
From: Naga Karthik
Date: Mon, 4 Sep 2023 14:49:44 -0400
Subject: [PATCH 062/106] update train transforms like nnunet's

---
 monai/transforms.py | 30 ++++++++++++++++------------
 1 file changed, 18 insertions(+), 12 deletions(-)

diff --git a/monai/transforms.py b/monai/transforms.py
index 795a0e10..eb3128bb 100644
--- a/monai/transforms.py
+++ b/monai/transforms.py
@@ -1,15 +1,17 @@
 import numpy as np
 from monai.transforms import (SpatialPadd, Compose, CropForegroundd, LoadImaged, RandFlipd,
-                            RandCropByPosNegLabeld, Spacingd, RandRotated, NormalizeIntensityd, RandAffined,
+                            RandCropByPosNegLabeld, Spacingd, RandScaleIntensityd, NormalizeIntensityd, RandAffined,
                             RandWeightedCropd, RandAdjustContrastd, EnsureChannelFirstd, RandGaussianNoised,
-                            RandGaussianSmoothd, Orientationd, Rand3DElasticd, RandBiasFieldd, RandSimulateLowResolutiond)
+                            RandGaussianSmoothd, Orientationd, Rand3DElasticd, RandBiasFieldd, RandSimulateLowResolutiond,
+                            ResizeWithPadOrCropd)
 # import torchio as tio

 # median image size in voxels - taken from nnUNet
-# median_size = (123, 255, 214) # so pad with this size
+# median_size = (123, 255, 214) as per 0.9 iso resampling and patch_size = (80, 192, 160)
+# note that the order of the axes is different in nnunet and monai (dims 0 and 2 are swapped)
 # median_size after 1mm isotropic resampling
-# median_size = [ 192. 228. 106.]
+# median_size = [ 192. 228. 106.]

 # Order in which nnunet does preprocessing:
 # 1. Crop to non-zero
@@ -26,24 +28,28 @@ def train_transforms(crop_size, num_samples_pv, lbl_key="label"):
        NormalizeIntensityd(keys=["image"], nonzero=False, channel_wise=False),
        Spacingd(keys=["image", "label"], pixdim=(1.0, 1.0, 1.0), mode=("bilinear", "bilinear"),),
        # data-augmentation
-       SpatialPadd(keys=["image", lbl_key], spatial_size=(192, 228, 106), method="symmetric"),
+       # SpatialPadd(keys=["image", lbl_key], spatial_size=(192, 228, 106), method="symmetric"),
+       SpatialPadd(keys=["image", lbl_key], spatial_size=crop_size, method="symmetric"),   # pad with the same size as crop_size
        # NOTE: used with neg together to calculate the ratio pos / (pos + neg) for the probability to pick a
        # foreground voxel as a center rather than a background voxel.
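
On the axis-order comment above: a minimal sketch of the swap between nnU-Net's reported sizes and this pipeline's convention (the helper name is illustrative, not part of the codebase; the cropping transform continues below):

def swap_dims_0_2(size):
    # dims 0 and 2 are swapped between nnU-Net's reporting and this MONAI pipeline,
    # e.g. nnU-Net's patch (80, 192, 160) reads as (160, 192, 80) here
    return (size[2], size[1], size[0])
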
RandCropByPosNegLabeld(keys=["image", "label"], label_key="label", spatial_size=crop_size, pos=2, neg=1, num_samples=num_samples_pv, # if num_samples=4, then 4 samples/image are randomly generated image_key="image", image_threshold=0.), - # transforms used by ivadomed and nnunet - RandAffined(keys=["image", lbl_key], mode=("bilinear", "nearest"), prob=1.0, + # ResizeWithPadOrCropd(keys=["image", lbl_key], spatial_size=crop_size,), + + # re-ordering transforms as used by nnunet + RandAffined(keys=["image", lbl_key], mode=("bilinear", "bilinear"), prob=0.25, rotate_range=(-20.0, 20.0), scale_range=(0.8, 1.2), translate_range=(-0.1, 0.1)), - Rand3DElasticd(keys=["image", lbl_key], sigma_range=(3.5, 5.5), magnitude_range=(25, 35), prob=0.5), + # Rand3DElasticd(keys=["image", lbl_key], sigma_range=(3.5, 5.5), magnitude_range=(25, 35), prob=0.5), + RandGaussianNoised(keys=["image"], mean=0.0, std=0.1, prob=0.1), + RandGaussianSmoothd(keys=["image"], sigma_x=(0.5, 1.0), sigma_y=(0.5, 1.0), sigma_z=(0.5, 1.0), prob=0.25), + RandScaleIntensityd(keys=["image"], factors=(-0.25, 1), prob=0.15), # this is nnUNet's BrightnessMultiplicativeTransform RandSimulateLowResolutiond(keys=["image"], zoom_range=(0.5, 1.0), prob=0.25), - RandAdjustContrastd(keys=["image"], gamma=(0.5, 1.5), prob=0.5), # this is monai's RandomGamma + RandAdjustContrastd(keys=["image"], gamma=(0.5, 1.5), prob=0.3), # this is monai's RandomGamma RandBiasFieldd(keys=["image"], coeff_range=(0.0, 0.5), degree=3, prob=0.3), - RandGaussianSmoothd(keys=["image"], sigma_x=(0.0, 2.0), sigma_y=(0.0, 2.0), sigma_z=(0.0, 2.0), prob=0.3), - # RandGaussianNoised(keys=["image"], mean=0.0, std=0.1, prob=0.1), # TODO: Try Spacingd with low resolution here with prob=0.5 - # RandFlipd(keys=["image", lbl_key], spatial_axis=None, prob=0.4,), + RandFlipd(keys=["image", lbl_key], prob=0.5,), # RandRotated(keys=["image", lbl_key], mode=("bilinear", "nearest"), prob=0.2, # range_x=(-30. / 360 * 2. * np.pi, 30. / 360 * 2. * np.pi), # NOTE: -pi/6 to pi/6 # range_y=(-30. / 360 * 2. * np.pi, 30. / 360 * 2. * np.pi), From 77d3beab7728d389ea825529c188ca4af83820cc Mon Sep 17 00:00:00 2001 From: Naga Karthik Date: Thu, 7 Sep 2023 10:11:59 -0400 Subject: [PATCH 063/106] add working version of AdapWingLoss --- monai/losses.py | 52 ++++++++++++++++++++++++++++--------------------- 1 file changed, 30 insertions(+), 22 deletions(-) diff --git a/monai/losses.py b/monai/losses.py index 46eb8c74..9c1ecdfa 100644 --- a/monai/losses.py +++ b/monai/losses.py @@ -64,8 +64,7 @@ def forward(self, preds, labels): class AdapWingLoss(nn.Module): """ - Adaptive Wing loss - Used for heatmap ground truth. + Adaptive Wing loss used for heatmap regression Adapted from: https://github.com/ivadomed/ivadomed/blob/master/ivadomed/losses.py#L341 .. seealso:: @@ -73,11 +72,12 @@ class AdapWingLoss(nn.Module): Proceedings of the IEEE International Conference on Computer Vision. 2019. Args: - theta (float): Threshold between linear and non linear loss. - alpha (float): Used to adapt loss shape to input shape and make loss smooth at 0 (background). + theta (float): Threshold to switch between the linear and non-linear parts of the piece-wise loss function. + alpha (float): Used to adapt the behaviour of the loss function at y=0 and y=1 and make loss smooth at 0 (background). It needs to be slightly above 2 to maintain ideal properties. - omega (float): Multiplicating factor for non linear part of the loss. + omega (float): Multiplicative factor for non linear part of the loss. 
epsilon (float): factor to avoid gradient explosion. It must not be too small
+        NOTE: Larger omega and smaller epsilon values will increase the influence on small errors and vice versa
     """

     def __init__(self, theta=0.5, alpha=2.1, omega=14, epsilon=1, reduction='sum'):
@@ -89,9 +89,31 @@ def __init__(self, theta=0.5, alpha=2.1, omega=14, epsilon=1, reduction='sum'):
         super(AdapWingLoss, self).__init__()

     def forward(self, input, target):
+        eps = self.epsilon
         batch_size = target.size()[0]
-        hm_num = target.size()[1]

+        # Adaptive Wing loss. Section 4.2 of the paper.
+        # Compute the adaptive factor
+        A = self.omega * (1 / (1 + torch.pow(self.theta / eps,
+                                             self.alpha - target))) * \
+            (self.alpha - target) * torch.pow(self.theta / eps,
+                                              self.alpha - target - 1) * (1 / eps)
+
+        # Constant term linking the linear and non-linear parts
+        C = (self.theta * A - self.omega * torch.log(1 + torch.pow(self.theta / eps, self.alpha - target)))
+
+        diff_hm = torch.abs(target - input)
+        AWingLoss = A * diff_hm - C
+        idx = diff_hm < self.theta
+        # NOTE: this is a more memory-efficient version than the one in ivadomed's losses.py
+        # where idx is True, compute the non-linear part of the loss; otherwise keep the linear part
+        # the non-linear part ensures that small errors (as given by idx) have a larger influence, refining the predictions at the boundaries
+        # the linear part makes the loss function behave more like the MSE loss, which has a linear influence
+        # (i.e. small errors where y=0 --> small influence --> small gradients)
+        AWingLoss = torch.where(idx, self.omega * torch.log(1 + torch.pow(diff_hm / eps, self.alpha - target)), AWingLoss)
+
+
+        # Mask for weighting the loss function. Section 4.3 of the paper.
         mask = torch.zeros_like(target)
         kernel = scipy.ndimage.generate_binary_structure(2, 2)
         # For 3D segmentation tasks
@@ -103,32 +125,18 @@
             img_list.append(np.round(target[i].cpu().numpy() * 255))
         img_merge = np.concatenate(img_list)
         img_dilate = scipy.ndimage.binary_opening(img_merge, np.expand_dims(kernel, axis=0))
+        # NOTE: why 51? the paper thresholds the dilated GT heatmap at 0.2.
So, 51/255 = 0.2 img_dilate[img_dilate < 51] = 1 # 0*omega+1 img_dilate[img_dilate >= 51] = 1 + self.omega # 1*omega+1 img_dilate = np.array(img_dilate, dtype=int) mask[i] = torch.tensor(img_dilate) - eps = self.epsilon - # Compute adaptative factor - A = self.omega * (1 / (1 + torch.pow(self.theta / eps, - self.alpha - target))) * \ - (self.alpha - target) * torch.pow(self.theta / eps, - self.alpha - target - 1) * (1 / eps) - - # Constant term to link linear and non linear part - C = (self.theta * A - self.omega * torch.log(1 + torch.pow(self.theta / eps, self.alpha - target))) - - diff_hm = torch.abs(target - input) - AWingLoss = A * diff_hm - C - idx = torch.argwhere(diff_hm < self.theta) - AWingLoss[idx] = self.omega * torch.log(1 + torch.pow(diff_hm / eps, self.alpha - target))[idx] + AWingLoss *= mask - # AWingLoss *= mask sum_loss = torch.sum(AWingLoss) if self.reduction == "sum": return sum_loss elif self.reduction == "mean": all_pixel = torch.sum(mask) return sum_loss / all_pixel - From dd0ece4c1fa981eea5beb94c3e2766d8a2ebf893 Mon Sep 17 00:00:00 2001 From: Naga Karthik Date: Mon, 11 Sep 2023 18:21:22 -0400 Subject: [PATCH 064/106] add training transforms as per ivadomed --- monai/transforms.py | 91 ++++++++++++++++++++++++++------------------- 1 file changed, 52 insertions(+), 39 deletions(-) diff --git a/monai/transforms.py b/monai/transforms.py index eb3128bb..c6c473da 100644 --- a/monai/transforms.py +++ b/monai/transforms.py @@ -18,53 +18,66 @@ # 2. Normalization # 3. Resample to target spacing +# Order in which ivadomed does preprocessing: +# 1. Resample to 1mm iso +# 2. CenterCrop using 46x176x288 +# 3. RandomAffine --> RandomElastic --> RandomGamma --> RandomBiasField --> RandomBlur --> NormalizeInstance + + def train_transforms(crop_size, num_samples_pv, lbl_key="label"): monai_transforms = [ - # pre-processing + # # pre-processing + # LoadImaged(keys=["image", lbl_key]), + # EnsureChannelFirstd(keys=["image", lbl_key]), + # CropForegroundd(keys=["image", lbl_key], source_key="image"), # crops >0 values with a bounding box + # NormalizeIntensityd(keys=["image"], nonzero=False, channel_wise=False), + # Spacingd(keys=["image", "label"], pixdim=(1.0, 1.0, 1.0), mode=("bilinear", "bilinear"),), + # # data-augmentation + # # SpatialPadd(keys=["image", lbl_key], spatial_size=(192, 228, 106), method="symmetric"), + # SpatialPadd(keys=["image", lbl_key], spatial_size=crop_size, method="symmetric"), # pad with the same size as crop_size + # # NOTE: used with neg together to calculate the ratio pos / (pos + neg) for the probability to pick a + # # foreground voxel as a center rather than a background voxel. 
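
Looking back at the AdapWingLoss added just above, the piecewise function condenses to this self-contained sketch (elementwise, with the same defaults as the class; the dilation-based mask and the reduction that the class applies afterwards are omitted here). The commented-out transform list resumes below.

import torch

def adaptive_wing(pred, target, theta=0.5, alpha=2.1, omega=14.0, eps=1.0):
    # elementwise Adaptive Wing loss; A and C make the two pieces meet
    # continuously at diff == theta
    diff = torch.abs(target - pred)
    A = omega * (1 / (1 + (theta / eps) ** (alpha - target))) \
        * (alpha - target) * ((theta / eps) ** (alpha - target - 1)) / eps
    C = theta * A - omega * torch.log(1 + (theta / eps) ** (alpha - target))
    # non-linear branch for small errors (diff < theta), linear branch otherwise
    return torch.where(diff < theta,
                       omega * torch.log(1 + (diff / eps) ** (alpha - target)),
                       A * diff - C)
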
+ # RandCropByPosNegLabeld(keys=["image", "label"], label_key="label", + # spatial_size=crop_size, pos=2, neg=1, num_samples=num_samples_pv, + # # if num_samples=4, then 4 samples/image are randomly generated + # image_key="image", image_threshold=0.), + # # re-ordering transforms as used by nnunet + # RandAffined(keys=["image", lbl_key], mode=("bilinear", "bilinear"), prob=0.75, + # rotate_range=(-20.0, 20.0), scale_range=(0.8, 1.2), translate_range=(-0.1, 0.1)), + # # Rand3DElasticd(keys=["image", lbl_key], sigma_range=(3.5, 5.5), magnitude_range=(25, 35), prob=0.5), + # RandGaussianNoised(keys=["image"], mean=0.0, std=0.1, prob=0.1), + # RandGaussianSmoothd(keys=["image"], sigma_x=(0.5, 1.0), sigma_y=(0.5, 1.0), sigma_z=(0.5, 1.0), prob=0.25), + # RandScaleIntensityd(keys=["image"], factors=(-0.25, 1), prob=0.15), # this is nnUNet's BrightnessMultiplicativeTransform + # RandSimulateLowResolutiond(keys=["image"], zoom_range=(0.5, 1.0), prob=0.25), + # RandAdjustContrastd(keys=["image"], gamma=(0.5, 1.5), prob=0.3), # this is monai's RandomGamma + # RandBiasFieldd(keys=["image"], coeff_range=(0.0, 0.5), degree=3, prob=0.3), + # RandFlipd(keys=["image", lbl_key], prob=0.5,), + # # RandRotated(keys=["image", lbl_key], mode=("bilinear", "nearest"), prob=0.2, + # # range_x=(-30. / 360 * 2. * np.pi, 30. / 360 * 2. * np.pi), # NOTE: -pi/6 to pi/6 + # # range_y=(-30. / 360 * 2. * np.pi, 30. / 360 * 2. * np.pi), + # # range_z=(-30. / 360 * 2. * np.pi, 30. / 360 * 2. * np.pi)), + + # defining transforms as used by ivadomed (with the same probabilities) LoadImaged(keys=["image", lbl_key]), EnsureChannelFirstd(keys=["image", lbl_key]), - CropForegroundd(keys=["image", lbl_key], source_key="image"), # crops >0 values with a bounding box - NormalizeIntensityd(keys=["image"], nonzero=False, channel_wise=False), - Spacingd(keys=["image", "label"], pixdim=(1.0, 1.0, 1.0), mode=("bilinear", "bilinear"),), - # data-augmentation - # SpatialPadd(keys=["image", lbl_key], spatial_size=(192, 228, 106), method="symmetric"), - SpatialPadd(keys=["image", lbl_key], spatial_size=crop_size, method="symmetric"), # pad with the same size as crop_size - # NOTE: used with neg together to calculate the ratio pos / (pos + neg) for the probability to pick a - # foreground voxel as a center rather than a background voxel. 
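
One detail of the ivadomed-style RandAffined further below: MONAI expects rotate_range in radians, so the in-line `-20. / 360 * 2. * np.pi` expressions are just manual degree-to-radian conversions, equivalent to this sketch:

import numpy as np

rotation_deg = 20.0
rotation_rad = rotation_deg / 360.0 * 2.0 * np.pi   # same as np.deg2rad(20.0) ~= 0.349
assert np.isclose(rotation_rad, np.deg2rad(rotation_deg))

The removed transform list continues below.
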
- RandCropByPosNegLabeld(keys=["image", "label"], label_key="label", - spatial_size=crop_size, pos=2, neg=1, num_samples=num_samples_pv, - # if num_samples=4, then 4 samples/image are randomly generated - image_key="image", image_threshold=0.), - # ResizeWithPadOrCropd(keys=["image", lbl_key], spatial_size=crop_size,), - - # re-ordering transforms as used by nnunet - RandAffined(keys=["image", lbl_key], mode=("bilinear", "bilinear"), prob=0.25, - rotate_range=(-20.0, 20.0), scale_range=(0.8, 1.2), translate_range=(-0.1, 0.1)), - # Rand3DElasticd(keys=["image", lbl_key], sigma_range=(3.5, 5.5), magnitude_range=(25, 35), prob=0.5), - RandGaussianNoised(keys=["image"], mean=0.0, std=0.1, prob=0.1), - RandGaussianSmoothd(keys=["image"], sigma_x=(0.5, 1.0), sigma_y=(0.5, 1.0), sigma_z=(0.5, 1.0), prob=0.25), - RandScaleIntensityd(keys=["image"], factors=(-0.25, 1), prob=0.15), # this is nnUNet's BrightnessMultiplicativeTransform - RandSimulateLowResolutiond(keys=["image"], zoom_range=(0.5, 1.0), prob=0.25), - RandAdjustContrastd(keys=["image"], gamma=(0.5, 1.5), prob=0.3), # this is monai's RandomGamma + Spacingd(keys=["image", "label"], pixdim=(1.0, 1.0, 1.0), mode=(2, 1)), + ResizeWithPadOrCropd(keys=["image", lbl_key], spatial_size=crop_size,), + RandAffined(keys=["image", lbl_key], mode=(2, 1), prob=1.0, + rotate_range=(-20. / 360 * 2. * np.pi, 20. / 360 * 2. * np.pi), # monai expects in radians + scale_range=(-0.2, 0.2), # ivadomed uses sth like scale_x = random.uniform(1 - self.scale[0], 1 + self.scale[0]), but monai adds 1.0 to the scale + translate_range=(-0.1, 0.1)), + Rand3DElasticd(keys=["image", lbl_key], prob=0.5, + sigma_range=(3.5, 5.5), + magnitude_range=(25., 35.)), + # RandSimulateLowResolutiond(keys=["image"], zoom_range=(0.5, 1.0), prob=0.25), + RandAdjustContrastd(keys=["image"], gamma=(0.5, 3.), prob=0.5), # this is monai's RandomGamma RandBiasFieldd(keys=["image"], coeff_range=(0.0, 0.5), degree=3, prob=0.3), - # TODO: Try Spacingd with low resolution here with prob=0.5 - RandFlipd(keys=["image", lbl_key], prob=0.5,), - # RandRotated(keys=["image", lbl_key], mode=("bilinear", "nearest"), prob=0.2, - # range_x=(-30. / 360 * 2. * np.pi, 30. / 360 * 2. * np.pi), # NOTE: -pi/6 to pi/6 - # range_y=(-30. / 360 * 2. * np.pi, 30. / 360 * 2. * np.pi), - # range_z=(-30. / 360 * 2. * np.pi, 30. / 360 * 2. 
* np.pi)), - # # re-orientation - # Orientationd(keys=["image", lbl_key], axcodes="RPI"), # NOTE: if not using it here, then it results in collation error + RandGaussianSmoothd(keys=["image"], sigma_x=(0., 2.), sigma_y=(0., 2.), sigma_z=(0., 2.0), prob=0.3), + # RandFlipd(keys=["image", lbl_key], prob=0.5,), + NormalizeIntensityd(keys=["image"], nonzero=False, channel_wise=False), ] - # tio_transforms = [ - # # tio.RandomBiasField(coefficients=0.5, order=3, p=0.3, include=["image"]), - # # Multiply spacing of one of the 3 axes by a factor randomly chosen in [1, 4] - # tio.RandomAnisotropy(axes=(0, 1, 2), downsampling=(1.0, 4.0), p=0.3, include=["image", "label"]), # from nnUNetPlans - median spacing is 0.9x0.9x5.0, - # ] - - # return Compose(monai_transforms + tio_transforms) return Compose(monai_transforms) def val_transforms(lbl_key="label"): From 0fc8b9046c0e39dce4d082b3858be9b545c3db9a Mon Sep 17 00:00:00 2001 From: Naga Karthik Date: Mon, 11 Sep 2023 18:22:14 -0400 Subject: [PATCH 065/106] add val_transforms_with_center_crop() --- monai/transforms.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/monai/transforms.py b/monai/transforms.py index c6c473da..ac75b460 100644 --- a/monai/transforms.py +++ b/monai/transforms.py @@ -5,7 +5,6 @@ RandWeightedCropd, RandAdjustContrastd, EnsureChannelFirstd, RandGaussianNoised, RandGaussianSmoothd, Orientationd, Rand3DElasticd, RandBiasFieldd, RandSimulateLowResolutiond, ResizeWithPadOrCropd) -# import torchio as tio # median image size in voxels - taken from nnUNet # median_size = (123, 255, 214) as per 0.9 iso resampling and patch_size = (80, 192, 160) @@ -84,9 +83,19 @@ def val_transforms(lbl_key="label"): return Compose([ LoadImaged(keys=["image", lbl_key]), EnsureChannelFirstd(keys=["image", lbl_key]), - Orientationd(keys=["image", lbl_key], axcodes="RPI"), + # Orientationd(keys=["image", lbl_key], axcodes="RPI"), CropForegroundd(keys=["image", lbl_key], source_key="image"), + Spacingd(keys=["image", lbl_key], pixdim=(1.0, 1.0, 1.0), mode=(2, 1)), # mode=("bilinear", "bilinear"),), + NormalizeIntensityd(keys=["image"], nonzero=False, channel_wise=False), + ]) + +def val_transforms_with_center_crop(crop_size, lbl_key="label"): + return Compose([ + LoadImaged(keys=["image", lbl_key]), + EnsureChannelFirstd(keys=["image", lbl_key]), + # CropForegroundd(keys=["image", lbl_key], source_key="image"), + Spacingd(keys=["image", lbl_key], pixdim=(1.0, 1.0, 1.0), mode=(2, 1)), # mode=("bilinear", "bilinear"),), + ResizeWithPadOrCropd(keys=["image", lbl_key], spatial_size=crop_size,), + # TODO: do cropping only in R-L so sth like (48, -1, -1) NormalizeIntensityd(keys=["image"], nonzero=False, channel_wise=False), - Spacingd(keys=["image", lbl_key], pixdim=(1.0, 1.0, 1.0), mode=("bilinear", "bilinear"),), - # SpatialPadd(keys=["image", lbl_key], spatial_size=(123, 255, 214), method="symmetric"), ]) From 2cff684ebccaa3896a7d8cf3a966b421196e391c Mon Sep 17 00:00:00 2001 From: Naga Karthik Date: Mon, 11 Sep 2023 18:26:18 -0400 Subject: [PATCH 066/106] modify train/val steps for training with AdapWingLoss --- monai/main.py | 148 ++++++++++++++++++++++++++------------------------ 1 file changed, 77 insertions(+), 71 deletions(-) diff --git a/monai/main.py b/monai/main.py index 428c4035..a73baa6b 100644 --- a/monai/main.py +++ b/monai/main.py @@ -11,8 +11,8 @@ import matplotlib.pyplot as plt from utils import precision_score, recall_score, dice_score, compute_average_csa, \ - PolyLRScheduler, check_empty_patch -from losses 
import SoftDiceLoss + PolyLRScheduler, plot_slices, check_empty_patch +from losses import SoftDiceLoss, AdapWingLoss from transforms import train_transforms, val_transforms from models import ModifiedUNet3D, create_nnunet_from_plans @@ -200,24 +200,25 @@ def training_step(self, batch, batch_idx): dice_loss = dice_loss / len(output) train_soft_dice = train_soft_dice / len(output) - # binarize the predictions and the labels (take only the final feature map i.e. the final prediction) - output = (output[0].detach() > 0.5).float() - labels = (labels.detach() > 0.5).float() + # # binarize the predictions and the labels (take only the final feature map i.e. the final prediction) + # output = (output[0].detach() > 0.5).float() + # labels = (labels.detach() > 0.5).float() - # compute CSA for each element of the batch - # NOTE: the CSA is computed only for the final feature map (i.e. the prediction, not the intermediate deepsupervision feature maps) - csa_loss = 0.0 - for batch_idx in range(output.shape[0]): - pred_patch_csa = compute_average_csa(output[batch_idx].squeeze(), self.spacing) - gt_patch_csa = compute_average_csa(labels[batch_idx].squeeze(), self.spacing) - csa_loss += (pred_patch_csa - gt_patch_csa) ** 2 - # average CSA loss across the batch - csa_loss = csa_loss / output.shape[0] + # # compute CSA for each element of the batch + # # NOTE: the CSA is computed only for the final feature map (i.e. the prediction, not the intermediate deepsupervision feature maps) + # csa_loss = 0.0 + # for batch_idx in range(output.shape[0]): + # pred_patch_csa = compute_average_csa(output[batch_idx].squeeze(), self.spacing) + # gt_patch_csa = compute_average_csa(labels[batch_idx].squeeze(), self.spacing) + # csa_loss += (pred_patch_csa - gt_patch_csa) ** 2 + # # average CSA loss across the batch + # csa_loss = csa_loss / output.shape[0] else: # calculate training loss # NOTE: the diceLoss expects the input to be logits (which it then normalizes inside) - dice_loss = self.loss_function(output, labels) + # dice_loss = self.loss_function(output, labels) + loss = self.loss_function(output, labels) # get probabilities from logits output = F.relu(output) / F.relu(output).max() if bool(F.relu(output).max()) else F.relu(output) @@ -228,26 +229,26 @@ def training_step(self, batch, batch_idx): train_soft_dice = self.soft_dice_metric(output, labels) # train_hard_dice = self.soft_dice_metric((output.detach() > 0.5).float(), (labels.detach() > 0.5).float()) - # binarize the predictions and the labels - output = (output.detach() > 0.5).float() - labels = (labels.detach() > 0.5).float() + # # binarize the predictions and the labels + # output = (output.detach() > 0.5).float() + # labels = (labels.detach() > 0.5).float() - # compute CSA for each element of the batch - csa_loss = 0.0 - for batch_idx in range(output.shape[0]): - pred_patch_csa = compute_average_csa(output[batch_idx].squeeze(), self.spacing) - gt_patch_csa = compute_average_csa(labels[batch_idx].squeeze(), self.spacing) - csa_loss += (pred_patch_csa - gt_patch_csa) ** 2 - # average CSA loss across the batch - csa_loss = csa_loss / output.shape[0] - - # total loss - loss = dice_loss + csa_loss + # # compute CSA for each element of the batch + # csa_loss = 0.0 + # for batch_idx in range(output.shape[0]): + # pred_patch_csa = compute_average_csa(output[batch_idx].squeeze(), self.spacing) + # gt_patch_csa = compute_average_csa(labels[batch_idx].squeeze(), self.spacing) + # csa_loss += (pred_patch_csa - gt_patch_csa) ** 2 + # # average CSA loss across the batch 
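
For the deep-supervision branch earlier in this training_step (introduced in patch 059 above), the weighting scheme reduces to the following sketch; the function name and signature are illustrative, not part of main.py. The commented-out CSA code continues below.

import torch.nn.functional as F

def deep_supervision_loss(outputs, labels, loss_fn):
    # outputs[0] is the finest prediction; each coarser head gets half the weight
    # of the previous one, and the GT (a 5D (B, C, H, W, D) tensor) is
    # downsampled to match that head's resolution
    total = 0.0
    for i, out in enumerate(outputs):
        gt = F.interpolate(labels, size=out.shape[-3:], mode='trilinear', align_corners=False)
        total = total + (0.5 ** i) * loss_fn(out, gt)
    return total / len(outputs)
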
+ # csa_loss = csa_loss / output.shape[0] + + # # total loss + # loss = dice_loss + csa_loss metrics_dict = { "loss": loss.cpu(), - "dice_loss": dice_loss.cpu(), - "csa_loss": csa_loss.cpu(), + # "dice_loss": dice_loss.cpu(), + # "csa_loss": csa_loss.cpu(), "train_soft_dice": train_soft_dice.detach().cpu(), "train_number": len(inputs), # "train_image": inputs[0].detach().cpu().squeeze(), @@ -264,25 +265,25 @@ def on_train_epoch_end(self): # means the training step was skipped because of empty input patch return None else: - train_loss, train_dice_loss, train_csa_loss = 0, 0, 0 - num_items, train_soft_dice = 0, 0 + train_loss, train_dice_loss, train_csa_loss, train_soft_dice = 0, 0, 0, 0 + num_items = len(self.train_step_outputs) for output in self.train_step_outputs: train_loss += output["loss"].item() - train_dice_loss += output["dice_loss"].item() - train_csa_loss += output["csa_loss"].item() + # train_dice_loss += output["dice_loss"].item() + # train_csa_loss += output["csa_loss"].item() train_soft_dice += output["train_soft_dice"].item() - num_items += output["train_number"] + # num_items += output["train_number"] mean_train_loss = (train_loss / num_items) - mean_train_dice_loss = (train_dice_loss / num_items) - mean_train_csa_loss = (train_csa_loss / num_items) + # mean_train_dice_loss = (train_dice_loss / num_items) + # mean_train_csa_loss = (train_csa_loss / num_items) mean_train_soft_dice = (train_soft_dice / num_items) wandb_logs = { "train_soft_dice": mean_train_soft_dice, "train_loss": mean_train_loss, - "train_dice_loss": mean_train_dice_loss, - "train_csa_loss": mean_train_csa_loss, + # "train_dice_loss": mean_train_dice_loss, + # "train_csa_loss": mean_train_csa_loss, } self.log_dict(wandb_logs) @@ -306,16 +307,18 @@ def validation_step(self, batch, batch_idx): inputs, labels = batch["image"], batch["label"] + # NOTE: this calculates the loss on the entire image after sliding window outputs = sliding_window_inference(inputs, self.inference_roi_size, mode="gaussian", - sw_batch_size=4, predictor=self.forward, overlap=0.5,) + sw_batch_size=1, predictor=self.forward, overlap=0.5,) # outputs shape: (B, C, ) - + if self.args.model == "nnunet" and self.args.enable_DS: # we only need the output with the highest resolution outputs = outputs[0] # calculate validation loss - dice_loss = self.loss_function(outputs, labels) + # dice_loss = self.loss_function(outputs, labels) + loss = self.loss_function(outputs, labels) # get probabilities from logits outputs = F.relu(outputs) / F.relu(outputs).max() if bool(F.relu(outputs).max()) else F.relu(outputs) @@ -328,26 +331,26 @@ def validation_step(self, batch, batch_idx): hard_preds, hard_labels = (post_outputs[0].detach() > 0.5).float(), (post_labels[0].detach() > 0.5).float() val_hard_dice = self.soft_dice_metric(hard_preds, hard_labels) - # compute val CSA loss - val_csa_loss = 0.0 - for batch_idx in range(hard_preds.shape[0]): - pred_patch_csa = compute_average_csa(hard_preds[batch_idx].squeeze(), self.spacing) - gt_patch_csa = compute_average_csa(hard_labels[batch_idx].squeeze(), self.spacing) - val_csa_loss += (pred_patch_csa - gt_patch_csa) ** 2 + # # compute val CSA loss + # val_csa_loss = 0.0 + # for batch_idx in range(hard_preds.shape[0]): + # pred_patch_csa = compute_average_csa(hard_preds[batch_idx].squeeze(), self.spacing) + # gt_patch_csa = compute_average_csa(hard_labels[batch_idx].squeeze(), self.spacing) + # val_csa_loss += (pred_patch_csa - gt_patch_csa) ** 2 - # average CSA loss across the batch - val_csa_loss = val_csa_loss 
/ hard_preds.shape[0] + # # average CSA loss across the batch + # val_csa_loss = val_csa_loss / hard_preds.shape[0] - # total loss - loss = dice_loss + val_csa_loss + # # total loss + # loss = dice_loss + val_csa_loss # NOTE: there was a massive memory leak when storing cuda tensors in this dict. Hence, # using .detach() to avoid storing the whole computation graph # Ref: https://discuss.pytorch.org/t/cuda-memory-leak-while-training/82855/2 metrics_dict = { "val_loss": loss.detach().cpu(), - "val_dice_loss": dice_loss.detach().cpu(), - "val_csa_loss": val_csa_loss.detach().cpu(), + # "val_dice_loss": dice_loss.detach().cpu(), + # "val_csa_loss": val_csa_loss.detach().cpu(), "val_soft_dice": val_soft_dice.detach().cpu(), "val_hard_dice": val_hard_dice.detach().cpu(), "val_number": len(post_outputs), @@ -367,40 +370,41 @@ def on_validation_epoch_end(self): val_loss += output["val_loss"].sum().item() val_soft_dice += output["val_soft_dice"].sum().item() val_hard_dice += output["val_hard_dice"].sum().item() - val_dice_loss += output["val_dice_loss"].sum().item() - val_csa_loss += output["val_csa_loss"].sum().item() + # val_dice_loss += output["val_dice_loss"].sum().item() + # val_csa_loss += output["val_csa_loss"].sum().item() num_items += output["val_number"] mean_val_loss = (val_loss / num_items) mean_val_soft_dice = (val_soft_dice / num_items) mean_val_hard_dice = (val_hard_dice / num_items) - mean_val_dice_loss = (val_dice_loss / num_items) - mean_val_csa_loss = (val_csa_loss / num_items) + # mean_val_dice_loss = (val_dice_loss / num_items) + # mean_val_csa_loss = (val_csa_loss / num_items) wandb_logs = { "val_soft_dice": mean_val_soft_dice, "val_hard_dice": mean_val_hard_dice, "val_loss": mean_val_loss, - "val_dice_loss": mean_val_dice_loss, - "val_csa_loss": mean_val_csa_loss, + # "val_dice_loss": mean_val_dice_loss, + # "val_csa_loss": mean_val_csa_loss, } - # # save the best model based on validation dice score - # if mean_val_soft_dice > self.best_val_dice: - # self.best_val_dice = mean_val_soft_dice - # self.best_val_epoch = self.current_epoch + # save the best model based on validation dice score + if mean_val_soft_dice > self.best_val_dice: + self.best_val_dice = mean_val_soft_dice + self.best_val_epoch = self.current_epoch # save the best model based on validation CSA loss - if mean_val_csa_loss < self.best_val_csa: - self.best_val_csa = mean_val_csa_loss + # if mean_val_loss < self.best_val_csa: + if mean_val_loss < self.best_val_loss: + self.best_val_loss = mean_val_loss self.best_val_epoch = self.current_epoch print( f"Current epoch: {self.current_epoch}" f"\nAverage Soft Dice (VAL): {mean_val_soft_dice:.4f}" f"\nAverage Hard Dice (VAL): {mean_val_hard_dice:.4f}" - f"\nAverage CSA (VAL): {mean_val_csa_loss:.4f}" - # f"\nBest Average Soft Dice: {self.best_val_dice:.4f} at Epoch: {self.best_val_epoch}" - f"\nBest Average CSA: {self.best_val_csa:.4f} at Epoch: {self.best_val_epoch}" + f"\nAverage AdapWing Loss (VAL): {mean_val_loss:.4f}" + f"\nBest Average Soft Dice: {self.best_val_dice:.4f} at Epoch: {self.best_val_epoch}" + # f"\nBest Average AdapWing Loss: {self.best_val_loss:.4f} at Epoch: {self.best_val_epoch}" f"\n----------------------------------------------------") # log on to wandb @@ -632,7 +636,9 @@ def main(args): f"_bs={args.batch_size}_{patch_size}" # define loss function - loss_func = SoftDiceLoss(p=1, smooth=1.0) + # loss_func = SoftDiceLoss(p=1, smooth=1.0) + loss_func = AdapWingLoss(theta=0.5, omega=8, alpha=2.1, epsilon=1, reduction="sum") + logger.info(f"Using 
AdapWingLoss with theta={loss_func.theta}, omega={loss_func.omega}, alpha={loss_func.alpha}, epsilon={loss_func.epsilon}!") # TODO: move this inside the for loop when using more folds timestamp = datetime.now().strftime(f"%Y%m%d-%H%M") # prints in YYYYMMDD-HHMMSS format From f1a8343787194faf6fe84cd6f21305355f973d89 Mon Sep 17 00:00:00 2001 From: Naga Karthik Date: Mon, 11 Sep 2023 18:28:33 -0400 Subject: [PATCH 067/106] modify dataloading to use val_transforms_with_center_crop --- monai/main.py | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/monai/main.py b/monai/main.py index a73baa6b..57d1329b 100644 --- a/monai/main.py +++ b/monai/main.py @@ -13,7 +13,7 @@ from utils import precision_score, recall_score, dice_score, compute_average_csa, \ PolyLRScheduler, plot_slices, check_empty_patch from losses import SoftDiceLoss, AdapWingLoss -from transforms import train_transforms, val_transforms +from transforms import train_transforms, val_transforms, val_transforms_with_center_crop from models import ModifiedUNet3D, create_nnunet_from_plans from monai.utils import set_determinism @@ -41,11 +41,17 @@ def __init__(self, args, data_root, fold_num, net, loss_function, optimizer_clas self.results_path = results_path self.best_val_dice, self.best_val_epoch = 0, 0 - self.best_val_csa = float("inf") + # self.best_val_csa = float("inf") + self.best_val_loss = float("inf") # define cropping and padding dimensions - self.voxel_cropping_size = (160, 224, 96) # (80, 192, 160) taken from nnUNet_plans.json - self.inference_roi_size = (160, 224, 96) + # NOTE about patch sizes: nnUNet defines patches using the median size of the dataset as the reference + # BUT, for SC images, this means a lot of context outside the spinal cord is included in the patches + # which could be sub-optimal. + # On the other hand, ivadomed used a patch-size that's heavily padded along the R-L direction so that + # only the SC is in context. 
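+        # (For scale: at the 1 mm isotropic spacing used in this pipeline, (48, 176, 288)
+        # voxels spans a 48 x 176 x 288 mm field of view, i.e. tight around the cord in
+        # R-L and long along the other two axes, unlike the earlier (160, 224, 96) patch.)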
+ self.voxel_cropping_size = (48, 176, 288) + self.inference_roi_size = self.voxel_cropping_size self.spacing = (1.0, 1.0, 1.0) # define post-processing transforms for validation, nothing fancy just making sure that it's a tensor (default) @@ -95,7 +101,8 @@ def prepare_data(self): num_samples_pv=self.args.num_samples_per_volume, lbl_key='label' ) - transforms_val = val_transforms(lbl_key='label') + # transforms_val = val_transforms(lbl_key='label') + transforms_val = val_transforms_with_center_crop(crop_size=self.voxel_cropping_size, lbl_key='label') # load the dataset dataset = os.path.join(self.root, f"spine-generic-ivado-comparison_dataset.json") @@ -104,8 +111,8 @@ def prepare_data(self): test_files = load_decathlon_datalist(dataset, True, "test") if args.debug: - train_files = train_files[:10] - val_files = val_files[:10] + train_files = train_files[:15] + val_files = val_files[:15] test_files = test_files[:6] train_cache_rate = 0.25 if args.debug else 0.5 @@ -113,7 +120,8 @@ def prepare_data(self): self.val_ds = CacheDataset(data=val_files, transform=transforms_val, cache_rate=0.25, num_workers=4) # define test transforms - transforms_test = val_transforms(lbl_key='label') + # transforms_test = val_transforms(lbl_key='label') + transforms_test = val_transforms_with_center_crop(crop_size=self.voxel_cropping_size, lbl_key='label') # define post-processing transforms for testing; taken (with explanations) from # https://github.com/Project-MONAI/tutorials/blob/main/3d_segmentation/torch/unet_inference_dict.py#L66 From 8777cd3f8e2b626954ffca7dd084b9f8cb3ac499 Mon Sep 17 00:00:00 2001 From: Naga Karthik Date: Mon, 11 Sep 2023 18:29:25 -0400 Subject: [PATCH 068/106] minor modifications to match ivadomed's training --- monai/main.py | 106 +++++++++++++++++++------------------------------- 1 file changed, 41 insertions(+), 65 deletions(-) diff --git a/monai/main.py b/monai/main.py index 57d1329b..a739821a 100644 --- a/monai/main.py +++ b/monai/main.py @@ -22,6 +22,10 @@ from monai.data import (DataLoader, Dataset, CacheDataset, load_decathlon_datalist, decollate_batch) from monai.transforms import (Compose, EnsureType, EnsureTyped, Invertd, SaveImaged, SaveImage) +# TODO: +# 1. Remove centerCropping but train with AdapWingLoss +# 2. 
Modify probs of transformations; add new ones + # create a "model"-agnostic class with PL to use different models class Model(pl.LightningModule): def __init__(self, args, data_root, fold_num, net, loss_function, optimizer_class, @@ -33,7 +37,6 @@ def __init__(self, args, data_root, fold_num, net, loss_function, optimizer_clas self.root = data_root self.fold_num = fold_num self.net = net - # self.load_pretrained = load_pretrained self.lr = args.learning_rate self.loss_function = loss_function self.optimizer_class = optimizer_class @@ -66,9 +69,6 @@ def __init__(self, args, data_root, fold_num, net, loss_function, optimizer_clas self.val_step_outputs = [] self.test_step_outputs = [] - # specify example_input_array for model summary - self.example_input_array = torch.rand(1, 1, 160, 224, 96) - # -------------------------------- # FORWARD PASS @@ -141,8 +141,7 @@ def prepare_data(self): def train_dataloader(self): # NOTE: if num_samples=4 in RandCropByPosNegLabeld and batch_size=2, then 2 x 4 images are generated for network training return DataLoader(self.train_ds, batch_size=self.args.batch_size, shuffle=True, num_workers=16, - pin_memory=True, persistent_workers=True) # collate_fn=pad_list_data_collate) - # list_data_collate is only useful when each input in the batch has different shape + pin_memory=True, persistent_workers=True) def val_dataloader(self): return DataLoader(self.val_ds, batch_size=1, shuffle=False, num_workers=16, pin_memory=True, @@ -157,11 +156,12 @@ def test_dataloader(self): # -------------------------------- def configure_optimizers(self): if self.args.optimizer == "sgd": - optimizer = self.optimizer_class(self.parameters(), lr=self.lr, momentum=0.99, weight_decay=1e-5, nesterov=True) + optimizer = self.optimizer_class(self.parameters(), lr=self.lr, momentum=0.99, weight_decay=3e-5, nesterov=True) else: - optimizer = self.optimizer_class(self.parameters(), lr=self.lr, weight_decay=1e-5) - # scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.5) - scheduler = PolyLRScheduler(optimizer, self.lr, max_steps=self.args.max_epochs) + optimizer = self.optimizer_class(self.parameters(), lr=self.lr) + # scheduler = PolyLRScheduler(optimizer, self.lr, max_steps=self.args.max_epochs) + # NOTE: ivadomed using CosineAnnealingLR with T_max = 200 + scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=self.args.max_epochs) return [optimizer], [scheduler] @@ -172,11 +172,11 @@ def training_step(self, batch, batch_idx): inputs, labels = batch["image"], batch["label"] - # NOTE: surprisingly, filtering out empty patches is adding more CSA bias - # # check if any label image patch is empty in the batch - # if check_empty_patch(labels) is None: - # # print(f"Empty label patch found. Skipping training step ...") - # return None + # NOTE: surprisingly, filtering out empty patches is adding more CSA bias; TODO: verify with new patch size + # check if any label image patch is empty in the batch + if check_empty_patch(labels) is None: + print(f"Empty label patch found. 
Skipping training step ...") + return None output = self.forward(inputs) # logits # if using dynunet, output.shape = (B, num_upsample_layers+1, C, H, W, D) @@ -440,9 +440,13 @@ def test_step(self, batch, batch_idx): # print(batch["label_meta_dict"]["filename_or_obj"][0]) # print(f"test_input.shape: {test_input.shape} \t test_label.shape: {test_label.shape}") batch["pred"] = sliding_window_inference(test_input, self.inference_roi_size, - sw_batch_size=4, predictor=self.forward, overlap=0.5) + sw_batch_size=1, predictor=self.forward, overlap=0.5) # print(f"batch['pred'].shape: {batch['pred'].shape}") + if self.args.model == "nnunet" and self.args.enable_DS: + # we only need the output with the highest resolution + batch["pred"] = batch["pred"][0] + # normalize the logits batch["pred"] = F.relu(batch["pred"]) / F.relu(batch["pred"]).max() if bool(F.relu(batch["pred"]).max()) else F.relu(batch["pred"]) @@ -548,7 +552,7 @@ def main(args): [2, 2, 2], [2, 2, 2], [2, 2, 2], - [1, 2, 2] + [1, 2, 2] # dims 0 and 2 of nnunet and monai images are swapped. Hence, 2x2x1 instead of 1x2x2 ], "conv_kernel_sizes": [ [3, 3, 3], @@ -575,10 +579,11 @@ def main(args): logger.info(f" Using ivadomed's UNet model! ") # this is the ivadomed unet model net = ModifiedUNet3D(in_channels=1, out_channels=1, init_filters=args.init_filters) - patch_size = "160x224x96" # "64x128x64" - save_exp_id =f"ivado_{args.model}_nf={args.init_filters}_opt={args.optimizer}_lr={args.learning_rate}" \ - f"_CSAdiceL_bestValCSA_nspv={args.num_samples_per_volume}" \ - f"_bs={args.batch_size}_{patch_size}" + patch_size = "48x176x288" # "160x224x96" # "64x128x64" + # f"_CSAdiceL_nspv={args.num_samples_per_volume}" \ + save_exp_id =f"ivado_reImp_{args.model}_nf={args.init_filters}_opt={args.optimizer}_lr={args.learning_rate}" \ + f"_AdapW_ValCCrop_Splin_bs={args.batch_size}_{patch_size}" + elif args.model in ["unetr"]: # define image size to be fed to the model @@ -602,33 +607,6 @@ def main(args): f"_fs={args.feature_size}_hs={args.hidden_size}_mlpd={args.mlp_dim}_nh={args.num_heads}" \ f"_CSAdiceL_nspv={args.num_samples_per_volume}_bs={args.batch_size}_{img_size}" \ - elif args.model in ["dynunet"]: - logger.info(f" Using MONAI's DynUNet model! ") - - # NOTE: these values are taken from nnUNetPlans.json - kernel_sizes = (3, 3, 3, 3, 3, 3) - stride_sizes = ((1, 1, 1), 2, 2, 2, 2, (1, 2, 2)) - # num_filters = (8, 16, 32, 64, 128, 256) - num_filters = (16, 32, 64, 128, 256, 320) - - # define model - net = DynUNet(spatial_dims=3, - in_channels=1, out_channels=1, - kernel_size=kernel_sizes, - strides=stride_sizes, - upsample_kernel_size=stride_sizes[1:], - filters=num_filters, - norm_name="instance", - deep_supervision=True, - deep_supr_num=4, #(len(stride_sizes)-2), - res_block=True, - dropout=0.3, - ) - patch_size = "160x224x96" - save_exp_id =f"{args.model}_initf=16_DS=4opt={args.optimizer}_lr={args.learning_rate}" \ - f"nspv={args.num_samples_per_volume}" \ - f"_bs={args.batch_size}_{patch_size}" - elif args.model in ["nnunet"]: if args.enable_DS: logger.info(f" Using nnUNet model WITH deep supervision! 
") @@ -637,10 +615,10 @@ def main(args): # define model net = create_nnunet_from_plans(plans=nnunet_plans, num_input_channels=1, num_classes=1, deep_supervision=args.enable_DS) - patch_size = "160x224x96" + patch_size = "160x192x80" save_exp_id =f"{args.model}_nf={args.init_filters}_DS={int(args.enable_DS)}" \ f"_opt={args.optimizer}_lr={args.learning_rate}" \ - f"_CSAdiceL_nspv={args.num_samples_per_volume}" \ + f"_CSAdiceL_sameTFs_nspv={args.num_samples_per_volume}" \ f"_bs={args.batch_size}_{patch_size}" # define loss function @@ -676,22 +654,19 @@ def main(args): # don't use wandb logger if in debug mode # if not args.debug: + grp = f"monai_ivado_{args.model}" if args.model == "unet" else f"monai_{args.model}" exp_logger = pl.loggers.WandbLogger( name=save_exp_id, save_dir=args.save_path, - group=f"{args.model}", #_final", + group=grp, log_model=True, # save best model using checkpoint callback project='contrast-agnostic', entity='naga-karthik', config=args) - - # # saving the best model based on soft validation dice score - # checkpoint_callback = pl.callbacks.ModelCheckpoint( - # dirpath=save_path, filename='best_model', monitor='val_soft_dice', - # save_top_k=5, mode="max", save_last=False, save_weights_only=True) + # saving the best model based on validation CSA loss - checkpoint_callback_csa = pl.callbacks.ModelCheckpoint( - dirpath=save_path, filename='best_model_csa', monitor='val_csa_loss', + checkpoint_callback_loss = pl.callbacks.ModelCheckpoint( + dirpath=save_path, filename='best_model_loss', monitor='val_loss', save_top_k=1, mode="min", save_last=False, save_weights_only=True) # saving the best model based on soft validation dice score @@ -701,30 +676,31 @@ def main(args): # early_stopping = pl.callbacks.EarlyStopping(monitor="val_soft_dice", min_delta=0.00, patience=args.patience, # verbose=False, mode="max") - early_stopping = pl.callbacks.EarlyStopping(monitor="val_csa_loss", min_delta=0.00, patience=args.patience, + early_stopping = pl.callbacks.EarlyStopping(monitor="val_loss", min_delta=0.00, patience=args.patience, verbose=False, mode="min") lr_monitor = pl.callbacks.LearningRateMonitor(logging_interval='epoch') + # Saving training script to wandb + wandb.save("main.py") + wandb.save("transforms.py") + # initialise Lightning's trainer. trainer = pl.Trainer( devices=1, accelerator="gpu", # strategy="ddp", logger=exp_logger, - callbacks=[checkpoint_callback_csa, checkpoint_callback_dice, lr_monitor, early_stopping], + callbacks=[checkpoint_callback_loss, checkpoint_callback_dice, lr_monitor, early_stopping], check_val_every_n_epoch=args.check_val_every_n_epochs, max_epochs=args.max_epochs, precision=32, # TODO: see if 16-bit precision is stable # deterministic=True, - enable_progress_bar=args.enable_progress_bar, - profiler="simple",) # to profile the training time taken for each step + enable_progress_bar=args.enable_progress_bar,) + # profiler="simple",) # to profile the training time taken for each step # Train! trainer.fit(pl_model) logger.info(f" Training Done!") - # Saving training script to wandb - wandb.save("main.py") - # Test! 
trainer.test(pl_model) logger.info(f"TESTING DONE!") From 284898fde2cd67fa5161bd7ef161ecb46c6fc24d Mon Sep 17 00:00:00 2001 From: Naga Karthik Date: Tue, 12 Sep 2023 22:11:47 -0400 Subject: [PATCH 069/106] add transforms used by nnunet; replace centerCrop --- monai/transforms.py | 79 +++++++++++++++++++++++---------------------- 1 file changed, 40 insertions(+), 39 deletions(-) diff --git a/monai/transforms.py b/monai/transforms.py index ac75b460..8b77d55f 100644 --- a/monai/transforms.py +++ b/monai/transforms.py @@ -26,62 +26,63 @@ def train_transforms(crop_size, num_samples_pv, lbl_key="label"): monai_transforms = [ - # # pre-processing - # LoadImaged(keys=["image", lbl_key]), - # EnsureChannelFirstd(keys=["image", lbl_key]), - # CropForegroundd(keys=["image", lbl_key], source_key="image"), # crops >0 values with a bounding box - # NormalizeIntensityd(keys=["image"], nonzero=False, channel_wise=False), - # Spacingd(keys=["image", "label"], pixdim=(1.0, 1.0, 1.0), mode=("bilinear", "bilinear"),), - # # data-augmentation - # # SpatialPadd(keys=["image", lbl_key], spatial_size=(192, 228, 106), method="symmetric"), - # SpatialPadd(keys=["image", lbl_key], spatial_size=crop_size, method="symmetric"), # pad with the same size as crop_size - # # NOTE: used with neg together to calculate the ratio pos / (pos + neg) for the probability to pick a - # # foreground voxel as a center rather than a background voxel. - # RandCropByPosNegLabeld(keys=["image", "label"], label_key="label", - # spatial_size=crop_size, pos=2, neg=1, num_samples=num_samples_pv, - # # if num_samples=4, then 4 samples/image are randomly generated - # image_key="image", image_threshold=0.), - # # re-ordering transforms as used by nnunet - # RandAffined(keys=["image", lbl_key], mode=("bilinear", "bilinear"), prob=0.75, - # rotate_range=(-20.0, 20.0), scale_range=(0.8, 1.2), translate_range=(-0.1, 0.1)), - # # Rand3DElasticd(keys=["image", lbl_key], sigma_range=(3.5, 5.5), magnitude_range=(25, 35), prob=0.5), - # RandGaussianNoised(keys=["image"], mean=0.0, std=0.1, prob=0.1), - # RandGaussianSmoothd(keys=["image"], sigma_x=(0.5, 1.0), sigma_y=(0.5, 1.0), sigma_z=(0.5, 1.0), prob=0.25), - # RandScaleIntensityd(keys=["image"], factors=(-0.25, 1), prob=0.15), # this is nnUNet's BrightnessMultiplicativeTransform - # RandSimulateLowResolutiond(keys=["image"], zoom_range=(0.5, 1.0), prob=0.25), - # RandAdjustContrastd(keys=["image"], gamma=(0.5, 1.5), prob=0.3), # this is monai's RandomGamma - # RandBiasFieldd(keys=["image"], coeff_range=(0.0, 0.5), degree=3, prob=0.3), - # RandFlipd(keys=["image", lbl_key], prob=0.5,), - # # RandRotated(keys=["image", lbl_key], mode=("bilinear", "nearest"), prob=0.2, - # # range_x=(-30. / 360 * 2. * np.pi, 30. / 360 * 2. * np.pi), # NOTE: -pi/6 to pi/6 - # # range_y=(-30. / 360 * 2. * np.pi, 30. / 360 * 2. * np.pi), - # # range_z=(-30. / 360 * 2. * np.pi, 30. / 360 * 2. 
* np.pi)),

-        # defining transforms as used by ivadomed (with the same probabilities)
+        # pre-processing
         LoadImaged(keys=["image", lbl_key]),
         EnsureChannelFirstd(keys=["image", lbl_key]),
+        CropForegroundd(keys=["image", lbl_key], source_key="image"),     # crops >0 values with a bounding box
+        # NOTE: for the interpolation mode, order=2 is spline and order=1 is linear
         Spacingd(keys=["image", "label"], pixdim=(1.0, 1.0, 1.0), mode=(2, 1)),
-        ResizeWithPadOrCropd(keys=["image", lbl_key], spatial_size=crop_size,),
-        RandAffined(keys=["image", lbl_key], mode=(2, 1), prob=1.0,
+        # data-augmentation
+        # SpatialPadd(keys=["image", lbl_key], spatial_size=(192, 228, 106), method="symmetric"),
+        SpatialPadd(keys=["image", lbl_key], spatial_size=crop_size, method="symmetric"),   # pad with the same size as crop_size
+        # NOTE: used with neg together to calculate the ratio pos / (pos + neg) for the probability to pick a
+        # foreground voxel as a center rather than a background voxel.
+        RandCropByPosNegLabeld(keys=["image", "label"], label_key="label",
+                               spatial_size=crop_size, pos=3, neg=1, num_samples=num_samples_pv,
+                               # if num_samples=4, then 4 samples/image are randomly generated
+                               image_key="image", image_threshold=0.),
+        # re-ordering transforms as used by nnunet
+        RandAffined(keys=["image", lbl_key], mode=(2, 1), prob=0.75,
                     rotate_range=(-20. / 360 * 2. * np.pi, 20. / 360 * 2. * np.pi),    # monai expects in radians
-                    scale_range=(-0.2, 0.2),   # ivadomed uses sth like scale_x = random.uniform(1 - self.scale[0], 1 + self.scale[0]), but monai adds 1.0 to the scale
+                    scale_range=(-0.2, 0.2),
                     translate_range=(-0.1, 0.1)),
         Rand3DElasticd(keys=["image", lbl_key], prob=0.5,
                        sigma_range=(3.5, 5.5),
                        magnitude_range=(25., 35.)),
-        # RandSimulateLowResolutiond(keys=["image"], zoom_range=(0.5, 1.0), prob=0.25),
+        RandSimulateLowResolutiond(keys=["image"], zoom_range=(0.5, 1.0), prob=0.25),
         RandAdjustContrastd(keys=["image"], gamma=(0.5, 3.), prob=0.5),    # this is monai's RandomGamma
         RandBiasFieldd(keys=["image"], coeff_range=(0.0, 0.5), degree=3, prob=0.3),
+        RandGaussianNoised(keys=["image"], mean=0.0, std=0.1, prob=0.1),
         RandGaussianSmoothd(keys=["image"], sigma_x=(0., 2.), sigma_y=(0., 2.), sigma_z=(0., 2.0), prob=0.3),
-        # RandFlipd(keys=["image", lbl_key], prob=0.5,),
+        RandScaleIntensityd(keys=["image"], factors=(-0.25, 1), prob=0.15),   # this is nnUNet's BrightnessMultiplicativeTransform
+        RandFlipd(keys=["image", lbl_key], prob=0.3,),
         NormalizeIntensityd(keys=["image"], nonzero=False, channel_wise=False),
+
+        # # defining transforms as used by ivadomed (with the same probabilities)
+        # LoadImaged(keys=["image", lbl_key], image_only=False),   # image_only=True would skip loading the metadata
+        # EnsureChannelFirstd(keys=["image", lbl_key]),
+        # Spacingd(keys=["image", "label"], pixdim=(1.0, 1.0, 1.0), mode=(2, 1)),
+        # ResizeWithPadOrCropd(keys=["image", lbl_key], spatial_size=crop_size,),
+        # RandAffined(keys=["image", lbl_key], mode=(2, 1), prob=1.0,
+        #             rotate_range=(-20. / 360 * 2. * np.pi, 20. / 360 * 2.
* np.pi), # monai expects in radians + # scale_range=(-0.2, 0.2), # ivadomed uses sth like scale_x = random.uniform(1 - self.scale[0], 1 + self.scale[0]), but monai adds 1.0 to the scale + # translate_range=(-0.1, 0.1)), + # Rand3DElasticd(keys=["image", lbl_key], prob=0.5, + # sigma_range=(3.5, 5.5), + # magnitude_range=(25., 35.)), + # # RandSimulateLowResolutiond(keys=["image"], zoom_range=(0.5, 1.0), prob=0.25), + # RandAdjustContrastd(keys=["image"], gamma=(0.5, 3.), prob=0.5), # this is monai's RandomGamma + # RandBiasFieldd(keys=["image"], coeff_range=(0.0, 0.5), degree=3, prob=0.3), + # RandGaussianSmoothd(keys=["image"], sigma_x=(0., 2.), sigma_y=(0., 2.), sigma_z=(0., 2.0), prob=0.3), + # # RandFlipd(keys=["image", lbl_key], prob=0.5,), + # NormalizeIntensityd(keys=["image"], nonzero=False, channel_wise=False), ] return Compose(monai_transforms) def val_transforms(lbl_key="label"): return Compose([ - LoadImaged(keys=["image", lbl_key]), + LoadImaged(keys=["image", lbl_key], image_only=False), EnsureChannelFirstd(keys=["image", lbl_key]), # Orientationd(keys=["image", lbl_key], axcodes="RPI"), CropForegroundd(keys=["image", lbl_key], source_key="image"), @@ -91,7 +92,7 @@ def val_transforms(lbl_key="label"): def val_transforms_with_center_crop(crop_size, lbl_key="label"): return Compose([ - LoadImaged(keys=["image", lbl_key]), + LoadImaged(keys=["image", lbl_key], image_only=False), EnsureChannelFirstd(keys=["image", lbl_key]), # CropForegroundd(keys=["image", lbl_key], source_key="image"), Spacingd(keys=["image", lbl_key], pixdim=(1.0, 1.0, 1.0), mode=(2, 1)), # mode=("bilinear", "bilinear"),), From 8c9fcd9454824b4cad5911c2e162dbc9ec8b07cc Mon Sep 17 00:00:00 2001 From: Naga Karthik Date: Tue, 12 Sep 2023 22:14:06 -0400 Subject: [PATCH 070/106] add feature to resume training from checkpoint --- monai/main.py | 166 +++++++++++++++++++++++++++++++++----------------- 1 file changed, 109 insertions(+), 57 deletions(-) diff --git a/monai/main.py b/monai/main.py index a739821a..6bb1f8fb 100644 --- a/monai/main.py +++ b/monai/main.py @@ -630,30 +630,43 @@ def main(args): timestamp = datetime.now().strftime(f"%Y%m%d-%H%M") # prints in YYYYMMDD-HHMMSS format save_exp_id = f"{save_exp_id}_{timestamp}" - # to save the best model on validation - save_path = os.path.join(args.save_path, f"{save_exp_id}") - if not os.path.exists(save_path): - os.makedirs(save_path, exist_ok=True) + # define callbacks + # early_stopping = pl.callbacks.EarlyStopping(monitor="val_soft_dice", min_delta=0.00, patience=args.patience, + # verbose=False, mode="max") + early_stopping = pl.callbacks.EarlyStopping(monitor="val_loss", min_delta=0.00, patience=args.patience, + verbose=False, mode="min") - # to save the results/model predictions - results_path = os.path.join(args.results_dir, f"{save_exp_id}") - if not os.path.exists(results_path): - os.makedirs(results_path, exist_ok=True) + lr_monitor = pl.callbacks.LearningRateMonitor(logging_interval='epoch') - # train across all folds of the dataset - for fold in range(args.num_cv_folds): - logger.info(f" Training on fold {fold+1} out of {args.num_cv_folds} folds! ") - # timestamp = datetime.now().strftime(f"%Y%m%d-%H%M") # prints in YYYYMMDD-HHMMSS format - # save_exp_id = f"{save_exp_id}_fold={fold}_{timestamp}" + if not args.continue_from_checkpoint: + # to save the best model on validation + save_path = os.path.join(args.save_path, f"{save_exp_id}") + if not os.path.exists(save_path): + os.makedirs(save_path, exist_ok=True) - # i.e. 
train by loading weights from scratch - pl_model = Model(args, data_root=dataset_root, fold_num=fold, - optimizer_class=optimizer_class, loss_function=loss_func, net=net, - exp_id=save_exp_id, results_path=results_path) + # to save the results/model predictions + results_path = os.path.join(args.results_dir, f"{save_exp_id}") + if not os.path.exists(results_path): + os.makedirs(results_path, exist_ok=True) - # don't use wandb logger if in debug mode - # if not args.debug: + # i.e. train by loading weights from scratch + pl_model = Model(args, data_root=dataset_root, + optimizer_class=optimizer_class, loss_function=loss_func, net=net, + exp_id=save_exp_id, results_path=results_path) + + # saving the best model based on validation loss + checkpoint_callback_loss = pl.callbacks.ModelCheckpoint( + dirpath=save_path, filename='best_model_loss', monitor='val_loss', + save_top_k=1, mode="min", save_last=True, save_weights_only=False) + + # saving the best model based on soft validation dice score + checkpoint_callback_dice = pl.callbacks.ModelCheckpoint( + dirpath=save_path, filename='best_model_dice', monitor='val_soft_dice', + save_top_k=1, mode="max", save_last=False, save_weights_only=True) + + logger.info(f" Starting training from scratch! ") + # wandb logger grp = f"monai_ivado_{args.model}" if args.model == "unet" else f"monai_{args.model}" exp_logger = pl.loggers.WandbLogger( name=save_exp_id, @@ -663,27 +676,67 @@ def main(args): project='contrast-agnostic', entity='naga-karthik', config=args) + + # Saving training script to wandb + wandb.save("main.py") + wandb.save("transforms.py") + + # initialise Lightning's trainer. + trainer = pl.Trainer( + devices=1, accelerator="gpu", # strategy="ddp", + logger=exp_logger, + callbacks=[checkpoint_callback_loss, checkpoint_callback_dice, lr_monitor, early_stopping], + check_val_every_n_epoch=args.check_val_every_n_epochs, + max_epochs=args.max_epochs, + precision=32, # TODO: see if 16-bit precision is stable + # deterministic=True, + enable_progress_bar=args.enable_progress_bar,) + # profiler="simple",) # to profile the training time taken for each step + + # Train! + trainer.fit(pl_model) + logger.info(f" Training Done!") + + else: + logger.info(f" Resuming training from the latest checkpoint! ") + + # check if wandb run folder is provided to resume using the same run + if args.wandb_run_folder is None: + raise ValueError("Please provide the wandb run folder to resume training using the same run on WandB!") + else: + wandb_run_folder = os.path.basename(args.wandb_run_folder) + wandb_run_id = wandb_run_folder.split("-")[-1] + + save_exp_id = args.save_path + save_path = os.path.dirname(args.save_path) + print(f"save_path: {save_path}") + results_path = args.results_dir + + # i.e. 
train by loading weights from scratch + pl_model = Model(args, data_root=dataset_root, + optimizer_class=optimizer_class, loss_function=loss_func, net=net, + exp_id=save_exp_id, results_path=results_path) # saving the best model based on validation CSA loss checkpoint_callback_loss = pl.callbacks.ModelCheckpoint( dirpath=save_path, filename='best_model_loss', monitor='val_loss', - save_top_k=1, mode="min", save_last=False, save_weights_only=True) + save_top_k=1, mode="min", save_last=True, save_weights_only=True) # saving the best model based on soft validation dice score checkpoint_callback_dice = pl.callbacks.ModelCheckpoint( dirpath=save_path, filename='best_model_dice', monitor='val_soft_dice', save_top_k=1, mode="max", save_last=False, save_weights_only=True) - - # early_stopping = pl.callbacks.EarlyStopping(monitor="val_soft_dice", min_delta=0.00, patience=args.patience, - # verbose=False, mode="max") - early_stopping = pl.callbacks.EarlyStopping(monitor="val_loss", min_delta=0.00, patience=args.patience, - verbose=False, mode="min") - lr_monitor = pl.callbacks.LearningRateMonitor(logging_interval='epoch') - - # Saving training script to wandb - wandb.save("main.py") - wandb.save("transforms.py") + # wandb logger + grp = f"monai_ivado_{args.model}" if args.model == "unet" else f"monai_{args.model}" + exp_logger = pl.loggers.WandbLogger( + save_dir=save_path, + group=grp, + log_model=True, # save best model using checkpoint callback + project='contrast-agnostic', + entity='naga-karthik', + config=args, + id=wandb_run_id, resume='must') # initialise Lightning's trainer. trainer = pl.Trainer( @@ -692,42 +745,41 @@ def main(args): callbacks=[checkpoint_callback_loss, checkpoint_callback_dice, lr_monitor, early_stopping], check_val_every_n_epoch=args.check_val_every_n_epochs, max_epochs=args.max_epochs, - precision=32, # TODO: see if 16-bit precision is stable - # deterministic=True, + precision=32, enable_progress_bar=args.enable_progress_bar,) # profiler="simple",) # to profile the training time taken for each step # Train! - trainer.fit(pl_model) + trainer.fit(pl_model, ckpt_path=os.path.join(save_exp_id, "last.ckpt"),) logger.info(f" Training Done!") - # Test! - trainer.test(pl_model) - logger.info(f"TESTING DONE!") + # Test! 
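+    # NOTE: passing ckpt_path to trainer.fit() above makes Lightning restore the model weights,
+    # the optimizer/LR-scheduler states and the epoch counter from "last.ckpt", so the test
+    # below runs on the resumed model.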
+ trainer.test(pl_model) + logger.info(f"TESTING DONE!") - # closing the current wandb instance so that a new one is created for the next fold - wandb.finish() + # closing the current wandb instance so that a new one is created for the next fold + wandb.finish() + + # TODO: Figure out saving test metrics to a file + with open(os.path.join(results_path, 'test_metrics.txt'), 'a') as f: + print('\n-------------- Test Metrics ----------------', file=f) + print(f"\nSeed Used: {args.seed}", file=f) + print(f"\ninitf={args.init_filters}_lr={args.learning_rate}_bs={args.batch_size}_{timestamp}", file=f) + print(f"\npatch_size={pl_model.voxel_cropping_size}", file=f) - # TODO: Figure out saving test metrics to a file - with open(os.path.join(results_path, 'test_metrics.txt'), 'a') as f: - print('\n-------------- Test Metrics ----------------', file=f) - print(f"\nSeed Used: {args.seed}", file=f) - print(f"\ninitf={args.init_filters}_lr={args.learning_rate}_bs={args.batch_size}_{timestamp}", file=f) - print(f"\npatch_size={pl_model.voxel_cropping_size}", file=f) - - print('\n-------------- Test Hard Dice Scores ----------------', file=f) - print("Hard Dice --> Mean: %0.3f, Std: %0.3f" % (pl_model.avg_test_dice_hard, pl_model.std_test_dice_hard), file=f) + print('\n-------------- Test Hard Dice Scores ----------------', file=f) + print("Hard Dice --> Mean: %0.3f, Std: %0.3f" % (pl_model.avg_test_dice_hard, pl_model.std_test_dice_hard), file=f) - print('\n-------------- Test Soft Dice Scores ----------------', file=f) - print("Soft Dice --> Mean: %0.3f, Std: %0.3f" % (pl_model.avg_test_dice, pl_model.std_test_dice), file=f) + print('\n-------------- Test Soft Dice Scores ----------------', file=f) + print("Soft Dice --> Mean: %0.3f, Std: %0.3f" % (pl_model.avg_test_dice, pl_model.std_test_dice), file=f) - print('\n-------------- Test Precision Scores ----------------', file=f) - print("Precision --> Mean: %0.3f" % (pl_model.avg_test_precision), file=f) + print('\n-------------- Test Precision Scores ----------------', file=f) + print("Precision --> Mean: %0.3f" % (pl_model.avg_test_precision), file=f) - print('\n-------------- Test Recall Scores -------------------', file=f) - print("Recall --> Mean: %0.3f" % (pl_model.avg_test_recall), file=f) + print('\n-------------- Test Recall Scores -------------------', file=f) + print("Recall --> Mean: %0.3f" % (pl_model.avg_test_recall), file=f) - print('-------------------------------------------------------', file=f) + print('-------------------------------------------------------', file=f) if __name__ == "__main__": @@ -739,7 +791,6 @@ def main(args): parser.add_argument('--enable_DS', default=False, action='store_true', help='Enable Deep Supervision') # dataset parser.add_argument('-nspv', '--num_samples_per_volume', default=4, type=int, help="Number of samples to crop per volume") - parser.add_argument('-ncv', '--num_cv_folds', default=5, type=int, help="Number of cross validation folds") # unet model parser.add_argument('-initf', '--init_filters', default=16, type=int, help="Number of Filters in Init Layer") @@ -769,12 +820,13 @@ def main(args): parser.add_argument('-sp', '--save_path', default=f"/home/GRAMES.POLYMTL.CA/u114716/contrast-agnostic/saved_models", type=str, help='Path to the saved models directory') - parser.add_argument('-c', '--continue_from_checkpoint', default=False, action='store_true', - help='Load model from checkpoint and continue training') parser.add_argument('-se', '--seed', default=42, type=int, help='Set seeds for 
reproducibility') parser.add_argument('-debug', default=False, action='store_true', help='if true, results are not logged to wandb') parser.add_argument('-stp', '--save_test_preds', default=False, action='store_true', help='if true, test predictions are saved in `save_path`') + parser.add_argument('-c', '--continue_from_checkpoint', default=False, action='store_true', + help='Load model from checkpoint and continue training') + parser.add_argument('-wdb-run', '--wandb-run-folder', default=None, type=str, help='Path to the wandb run folder') # testing parser.add_argument('-rd', '--results_dir', default=f"/home/GRAMES.POLYMTL.CA/u114716/contrast-agnostic/results", From 6b8836271c0fdaec451ab1bc78dbb6720b745b83 Mon Sep 17 00:00:00 2001 From: Naga Karthik Date: Tue, 12 Sep 2023 22:14:48 -0400 Subject: [PATCH 071/106] fix warnings; remove comments; fix save names --- monai/main.py | 42 +++++++++++++++++++----------------------- 1 file changed, 19 insertions(+), 23 deletions(-) diff --git a/monai/main.py b/monai/main.py index 6bb1f8fb..4c5140b8 100644 --- a/monai/main.py +++ b/monai/main.py @@ -23,19 +23,16 @@ from monai.transforms import (Compose, EnsureType, EnsureTyped, Invertd, SaveImaged, SaveImage) # TODO: -# 1. Remove centerCropping but train with AdapWingLoss -# 2. Modify probs of transformations; add new ones +# 1. increase omega in adapwingloss # create a "model"-agnostic class with PL to use different models class Model(pl.LightningModule): - def __init__(self, args, data_root, fold_num, net, loss_function, optimizer_class, - exp_id=None, results_path=None): + def __init__(self, args, data_root, net, loss_function, optimizer_class, exp_id=None, results_path=None): super().__init__() self.args = args - self.save_hyperparameters(ignore=['net']) + self.save_hyperparameters(ignore=['net', 'loss_function']) self.root = data_root - self.fold_num = fold_num self.net = net self.lr = args.learning_rate self.loss_function = loss_function @@ -53,7 +50,8 @@ def __init__(self, args, data_root, fold_num, net, loss_function, optimizer_clas # which could be sub-optimal. # On the other hand, ivadomed used a patch-size that's heavily padded along the R-L direction so that # only the SC is in context. 
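        # NOTE: inference_roi_size aliases voxel_cropping_size below, so changing the training
        # patch size also changes the sliding-window ROI used at validation/test time.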
- self.voxel_cropping_size = (48, 176, 288) + # self.voxel_cropping_size = (48, 176, 288) + self.voxel_cropping_size = (48, 192, 256) self.inference_roi_size = self.voxel_cropping_size self.spacing = (1.0, 1.0, 1.0) @@ -101,8 +99,8 @@ def prepare_data(self): num_samples_pv=self.args.num_samples_per_volume, lbl_key='label' ) - # transforms_val = val_transforms(lbl_key='label') - transforms_val = val_transforms_with_center_crop(crop_size=self.voxel_cropping_size, lbl_key='label') + transforms_val = val_transforms(lbl_key='label') + # transforms_val = val_transforms_with_center_crop(crop_size=self.voxel_cropping_size, lbl_key='label') # load the dataset dataset = os.path.join(self.root, f"spine-generic-ivado-comparison_dataset.json") @@ -120,8 +118,8 @@ def prepare_data(self): self.val_ds = CacheDataset(data=val_files, transform=transforms_val, cache_rate=0.25, num_workers=4) # define test transforms - # transforms_test = val_transforms(lbl_key='label') - transforms_test = val_transforms_with_center_crop(crop_size=self.voxel_cropping_size, lbl_key='label') + transforms_test = val_transforms(lbl_key='label') + # transforms_test = val_transforms_with_center_crop(crop_size=self.voxel_cropping_size, lbl_key='label') # define post-processing transforms for testing; taken (with explanations) from # https://github.com/Project-MONAI/tutorials/blob/main/3d_segmentation/torch/unet_inference_dict.py#L66 @@ -175,7 +173,7 @@ def training_step(self, batch, batch_idx): # NOTE: surprisingly, filtering out empty patches is adding more CSA bias; TODO: verify with new patch size # check if any label image patch is empty in the batch if check_empty_patch(labels) is None: - print(f"Empty label patch found. Skipping training step ...") + # print(f"Empty label patch found. Skipping training step ...") return None output = self.forward(inputs) # logits @@ -317,7 +315,7 @@ def validation_step(self, batch, batch_idx): # NOTE: this calculates the loss on the entire image after sliding window outputs = sliding_window_inference(inputs, self.inference_roi_size, mode="gaussian", - sw_batch_size=1, predictor=self.forward, overlap=0.5,) + sw_batch_size=4, predictor=self.forward, overlap=0.5,) # outputs shape: (B, C, ) if self.args.model == "nnunet" and self.args.enable_DS: @@ -440,7 +438,7 @@ def test_step(self, batch, batch_idx): # print(batch["label_meta_dict"]["filename_or_obj"][0]) # print(f"test_input.shape: {test_input.shape} \t test_label.shape: {test_label.shape}") batch["pred"] = sliding_window_inference(test_input, self.inference_roi_size, - sw_batch_size=1, predictor=self.forward, overlap=0.5) + sw_batch_size=4, predictor=self.forward, overlap=0.5) # print(f"batch['pred'].shape: {batch['pred'].shape}") if self.args.model == "nnunet" and self.args.enable_DS: @@ -450,9 +448,6 @@ def test_step(self, batch, batch_idx): # normalize the logits batch["pred"] = F.relu(batch["pred"]) / F.relu(batch["pred"]).max() if bool(F.relu(batch["pred"]).max()) else F.relu(batch["pred"]) - # # upon fsleyes visualization, observed that very small values need to be set to zero, but NOT fully binarizing the pred - # batch["pred"][batch["pred"] < 0.099] = 0.0 - post_test_out = [self.test_post_pred(i) for i in decollate_batch(batch)] # make sure that the shapes of prediction and GT label are the same @@ -471,13 +466,13 @@ def test_step(self, batch, batch_idx): save_folder = os.path.join(self.results_path, subject_name.split("_")[0]) pred_saver = SaveImage( output_dir=save_folder, output_postfix="pred", output_ext=".nii.gz", - 
separate_folder=False, print_log=False) + separate_folder=False, print_log=False, resample=True) # save the prediction pred_saver(pred) label_saver = SaveImage( output_dir=save_folder, output_postfix="gt", output_ext=".nii.gz", - separate_folder=False, print_log=False) + separate_folder=False, print_log=False, resample=True) # save the label label_saver(label) @@ -579,10 +574,11 @@ def main(args): logger.info(f" Using ivadomed's UNet model! ") # this is the ivadomed unet model net = ModifiedUNet3D(in_channels=1, out_channels=1, init_filters=args.init_filters) - patch_size = "48x176x288" # "160x224x96" # "64x128x64" - # f"_CSAdiceL_nspv={args.num_samples_per_volume}" \ - save_exp_id =f"ivado_reImp_{args.model}_nf={args.init_filters}_opt={args.optimizer}_lr={args.learning_rate}" \ - f"_AdapW_ValCCrop_Splin_bs={args.batch_size}_{patch_size}" + patch_size = "48x192x256" # "160x224x96" + save_exp_id =f"ivado_{args.model}_nf={args.init_filters}_opt={args.optimizer}_lr={args.learning_rate}" \ + f"_AdapW_nspv={args.num_samples_per_volume}_bs={args.batch_size}_{patch_size}" + if args.debug: + save_exp_id = f"DEBUG_{save_exp_id}" elif args.model in ["unetr"]: From 3a76d312382e0ebbfd617a9dd35b3365401a2c1b Mon Sep 17 00:00:00 2001 From: Naga Karthik Date: Mon, 18 Sep 2023 12:05:59 -0400 Subject: [PATCH 072/106] add variants of validation transforms --- monai/transforms.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/monai/transforms.py b/monai/transforms.py index 8b77d55f..f00f628a 100644 --- a/monai/transforms.py +++ b/monai/transforms.py @@ -80,7 +80,7 @@ def train_transforms(crop_size, num_samples_pv, lbl_key="label"): return Compose(monai_transforms) -def val_transforms(lbl_key="label"): +def val_transforms_without_center_crop(lbl_key="label"): return Compose([ LoadImaged(keys=["image", lbl_key], image_only=False), EnsureChannelFirstd(keys=["image", lbl_key]), @@ -90,7 +90,19 @@ def val_transforms(lbl_key="label"): NormalizeIntensityd(keys=["image"], nonzero=False, channel_wise=False), ]) -def val_transforms_with_center_crop(crop_size, lbl_key="label"): +def val_transforms_with_orientation_and_crop(crop_size, lbl_key="label"): + return Compose([ + LoadImaged(keys=["image", lbl_key], image_only=False), + EnsureChannelFirstd(keys=["image", lbl_key]), + # CropForegroundd(keys=["image", lbl_key], source_key="image"), + Orientationd(keys=["image", lbl_key], axcodes="RPI"), + Spacingd(keys=["image", lbl_key], pixdim=(1.0, 1.0, 1.0), mode=(2, 1)), # mode=("bilinear", "bilinear"),), + ResizeWithPadOrCropd(keys=["image", lbl_key], spatial_size=crop_size,), + # TODO: do cropping only in R-L so sth like (48, -1, -1) + NormalizeIntensityd(keys=["image"], nonzero=False, channel_wise=False), + ]) + +def val_transforms(crop_size, lbl_key="label"): return Compose([ LoadImaged(keys=["image", lbl_key], image_only=False), EnsureChannelFirstd(keys=["image", lbl_key]), From 0df19a648983b37034f28ce8708420e9e21fad51 Mon Sep 17 00:00:00 2001 From: Naga Karthik Date: Mon, 18 Sep 2023 12:06:40 -0400 Subject: [PATCH 073/106] re-use nnunet-like transforms --- monai/transforms.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/monai/transforms.py b/monai/transforms.py index f00f628a..a0a18c03 100644 --- a/monai/transforms.py +++ b/monai/transforms.py @@ -22,6 +22,8 @@ # 2. CenterCrop using 46x176x288 # 3. 
RandomAffine --> RandomElastic --> RandomGamma --> RandomBiasField --> RandomBlur --> NormalizeInstance +# TODO: Use cropping on R-L (56) and A-P (176) but not on S-I to avoid cropping the last few slices +# At test time, check performance on cropped and uncropped full images def train_transforms(crop_size, num_samples_pv, lbl_key="label"): @@ -29,18 +31,18 @@ def train_transforms(crop_size, num_samples_pv, lbl_key="label"): # pre-processing LoadImaged(keys=["image", lbl_key]), EnsureChannelFirstd(keys=["image", lbl_key]), - CropForegroundd(keys=["image", lbl_key], source_key="image"), # crops >0 values with a bounding box + # CropForegroundd(keys=["image", lbl_key], source_key="image"), # crops >0 values with a bounding box # NOTE: spine interpolation with order=2 is spline, order=1 is linear Spacingd(keys=["image", "label"], pixdim=(1.0, 1.0, 1.0), mode=(2, 1)), + ResizeWithPadOrCropd(keys=["image", lbl_key], spatial_size=crop_size,), # data-augmentation - # SpatialPadd(keys=["image", lbl_key], spatial_size=(192, 228, 106), method="symmetric"), - SpatialPadd(keys=["image", lbl_key], spatial_size=crop_size, method="symmetric"), # pad with the same size as crop_size - # NOTE: used with neg together to calculate the ratio pos / (pos + neg) for the probability to pick a - # foreground voxel as a center rather than a background voxel. - RandCropByPosNegLabeld(keys=["image", "label"], label_key="label", - spatial_size=crop_size, pos=3, neg=1, num_samples=num_samples_pv, - # if num_samples=4, then 4 samples/image are randomly generated - image_key="image", image_threshold=0.), + # SpatialPadd(keys=["image", lbl_key], spatial_size=crop_size, method="symmetric"), # pad with the same size as crop_size + # # NOTE: used with neg together to calculate the ratio pos / (pos + neg) for the probability to pick a + # # foreground voxel as a center rather than a background voxel. + # RandCropByPosNegLabeld(keys=["image", "label"], label_key="label", + # spatial_size=crop_size, pos=3, neg=1, num_samples=num_samples_pv, + # # if num_samples=4, then 4 samples/image are randomly generated + # image_key="image", image_threshold=0.), # re-ordering transforms as used by nnunet RandAffined(keys=["image", lbl_key], mode=(2, 1), prob=0.75, rotate_range=(-20. / 360 * 2. * np.pi, 20. / 360 * 2. * np.pi), # monai expects in radians From 79b369fff6338b59cd950e186e4a1013f4079c87 Mon Sep 17 00:00:00 2001 From: Naga Karthik Date: Mon, 18 Sep 2023 12:11:07 -0400 Subject: [PATCH 074/106] update code to train nnunet-based model --- monai/main.py | 42 ++++++++++++++++++++++-------------------- 1 file changed, 22 insertions(+), 20 deletions(-) diff --git a/monai/main.py b/monai/main.py index 4c5140b8..ae6126a4 100644 --- a/monai/main.py +++ b/monai/main.py @@ -13,7 +13,7 @@ from utils import precision_score, recall_score, dice_score, compute_average_csa, \ PolyLRScheduler, plot_slices, check_empty_patch from losses import SoftDiceLoss, AdapWingLoss -from transforms import train_transforms, val_transforms, val_transforms_with_center_crop +from transforms import train_transforms, val_transforms from models import ModifiedUNet3D, create_nnunet_from_plans from monai.utils import set_determinism @@ -51,9 +51,12 @@ def __init__(self, args, data_root, net, loss_function, optimizer_class, exp_id= # On the other hand, ivadomed used a patch-size that's heavily padded along the R-L direction so that # only the SC is in context. 
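        # NOTE (assumption, based on MONAI's documented behaviour): sliding_window_inference uses
        # the full image size along any axis whose roi_size entry is non-positive, which is why a
        # val_crop_size of "-1" is expanded to (-1, -1, -1) below.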
# self.voxel_cropping_size = (48, 176, 288) - self.voxel_cropping_size = (48, 192, 256) - self.inference_roi_size = self.voxel_cropping_size self.spacing = (1.0, 1.0, 1.0) + self.voxel_cropping_size = (48, 160, 320) + self.inference_roi_size = tuple([int(i) for i in args.val_crop_size.split("x")]) + if self.inference_roi_size == (-1,): # means no cropping is required + logger.info(f"Using full image for validation ...") + self.inference_roi_size = (-1, -1, -1) # define post-processing transforms for validation, nothing fancy just making sure that it's a tensor (default) self.val_post_pred = Compose([EnsureType()]) @@ -99,7 +102,7 @@ def prepare_data(self): num_samples_pv=self.args.num_samples_per_volume, lbl_key='label' ) - transforms_val = val_transforms(lbl_key='label') + transforms_val = val_transforms(crop_size=self.inference_roi_size, lbl_key='label') # transforms_val = val_transforms_with_center_crop(crop_size=self.voxel_cropping_size, lbl_key='label') # load the dataset @@ -118,7 +121,7 @@ def prepare_data(self): self.val_ds = CacheDataset(data=val_files, transform=transforms_val, cache_rate=0.25, num_workers=4) # define test transforms - transforms_test = val_transforms(lbl_key='label') + transforms_test = val_transforms(crop_size=self.inference_roi_size, lbl_key='label') # transforms_test = val_transforms_with_center_crop(crop_size=self.voxel_cropping_size, lbl_key='label') # define post-processing transforms for testing; taken (with explanations) from @@ -183,7 +186,7 @@ def training_step(self, batch, batch_idx): if self.args.model == "nnunet" and self.args.enable_DS: # calculate dice loss for each output - dice_loss, train_soft_dice = 0.0, 0.0 + loss, train_soft_dice = 0.0, 0.0 for i in range(len(output)): # give each output a weight which decreases exponentially (division by 2) as the resolution decreases # this gives higher resolution outputs more weight in the loss @@ -192,7 +195,7 @@ def training_step(self, batch, batch_idx): # (instead of upsampling each deepsupervision feature map output to the final resolution) downsampled_gt = F.interpolate(labels, size=output[i].shape[-3:], mode='trilinear', align_corners=False) # print(f"downsampled_gt.shape: {downsampled_gt.shape} \t output[i].shape: {output[i].shape}") - dice_loss += (0.5 ** i) * self.loss_function(output[i], downsampled_gt) + loss += (0.5 ** i) * self.loss_function(output[i], downsampled_gt) # get probabilities from logits out = F.relu(output[i]) / F.relu(output[i]).max() if bool(F.relu(output[i]).max()) else F.relu(output[i]) @@ -203,8 +206,8 @@ def training_step(self, batch, batch_idx): train_soft_dice += self.soft_dice_metric(out, downsampled_gt) # average dice loss across the outputs - dice_loss = dice_loss / len(output) - train_soft_dice = train_soft_dice / len(output) + loss /= len(output) + train_soft_dice /= len(output) # # binarize the predictions and the labels (take only the final feature map i.e. the final prediction) # output = (output[0].detach() > 0.5).float() @@ -547,7 +550,7 @@ def main(args): [2, 2, 2], [2, 2, 2], [2, 2, 2], - [1, 2, 2] # dims 0 and 2 of nnunet and monai images are swapped. Hence, 2x2x1 instead of 1x2x2 + [1, 2, 2] ], "conv_kernel_sizes": [ [3, 3, 3], @@ -574,9 +577,9 @@ def main(args): logger.info(f" Using ivadomed's UNet model! 
") # this is the ivadomed unet model net = ModifiedUNet3D(in_channels=1, out_channels=1, init_filters=args.init_filters) - patch_size = "48x192x256" # "160x224x96" + patch_size = "48x160x320" # "160x224x96" save_exp_id =f"ivado_{args.model}_nf={args.init_filters}_opt={args.optimizer}_lr={args.learning_rate}" \ - f"_AdapW_nspv={args.num_samples_per_volume}_bs={args.batch_size}_{patch_size}" + f"_AdapW_valCCrop_bs={args.batch_size}_{patch_size}" if args.debug: save_exp_id = f"DEBUG_{save_exp_id}" @@ -611,17 +614,12 @@ def main(args): # define model net = create_nnunet_from_plans(plans=nnunet_plans, num_input_channels=1, num_classes=1, deep_supervision=args.enable_DS) - patch_size = "160x192x80" + patch_size = "48x160x320" save_exp_id =f"{args.model}_nf={args.init_filters}_DS={int(args.enable_DS)}" \ f"_opt={args.optimizer}_lr={args.learning_rate}" \ - f"_CSAdiceL_sameTFs_nspv={args.num_samples_per_volume}" \ + f"_AdapW_CCrop" \ f"_bs={args.batch_size}_{patch_size}" - # define loss function - # loss_func = SoftDiceLoss(p=1, smooth=1.0) - loss_func = AdapWingLoss(theta=0.5, omega=8, alpha=2.1, epsilon=1, reduction="sum") - logger.info(f"Using AdapWingLoss with theta={loss_func.theta}, omega={loss_func.omega}, alpha={loss_func.alpha}, epsilon={loss_func.epsilon}!") - # TODO: move this inside the for loop when using more folds timestamp = datetime.now().strftime(f"%Y%m%d-%H%M") # prints in YYYYMMDD-HHMMSS format save_exp_id = f"{save_exp_id}_{timestamp}" @@ -786,7 +784,11 @@ def main(args): default='unet', type=str, help='Model type to be used') parser.add_argument('--enable_DS', default=False, action='store_true', help='Enable Deep Supervision') # dataset - parser.add_argument('-nspv', '--num_samples_per_volume', default=4, type=int, help="Number of samples to crop per volume") + parser.add_argument('-nspv', '--num_samples_per_volume', default=4, type=int, help="Number of samples to crop per volume") + # define args for cropping size. inputs should be in the format of "48x192x256" + parser.add_argument('-val-crop', '--val_crop_size', type=str, default="48x192x256", + help='Center crop size for validation and testing. Values correspond to R-L, A-P, I-S axes' + 'of the image. Use -1 if no cropping is intended. 
Default: 48x160x320') # unet model parser.add_argument('-initf', '--init_filters', default=16, type=int, help="Number of Filters in Init Layer") From 5affa03504fa4eb3a81253ab82bbfb4e97373dc5 Mon Sep 17 00:00:00 2001 From: Naga Karthik Date: Mon, 18 Sep 2023 12:12:00 -0400 Subject: [PATCH 075/106] modify code to log everything to a txt file --- monai/main.py | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/monai/main.py b/monai/main.py index ae6126a4..0e81ca5b 100644 --- a/monai/main.py +++ b/monai/main.py @@ -407,14 +407,15 @@ def on_validation_epoch_end(self): self.best_val_loss = mean_val_loss self.best_val_epoch = self.current_epoch - print( - f"Current epoch: {self.current_epoch}" + logger.info( + f"\nCurrent epoch: {self.current_epoch}" f"\nAverage Soft Dice (VAL): {mean_val_soft_dice:.4f}" f"\nAverage Hard Dice (VAL): {mean_val_hard_dice:.4f}" f"\nAverage AdapWing Loss (VAL): {mean_val_loss:.4f}" f"\nBest Average Soft Dice: {self.best_val_dice:.4f} at Epoch: {self.best_val_epoch}" # f"\nBest Average AdapWing Loss: {self.best_val_loss:.4f} at Epoch: {self.best_val_epoch}" f"\n----------------------------------------------------") + # log on to wandb self.log_dict(wandb_logs) @@ -463,7 +464,7 @@ def test_step(self, batch, batch_idx): if self.args.save_test_preds: subject_name = (batch["image_meta_dict"]["filename_or_obj"][0]).split("/")[-1].replace(".nii.gz", "") - print(f"Saving subject: {subject_name}") + logger.info(f"Saving subject: {subject_name}") # image saver class save_folder = os.path.join(self.results_path, subject_name.split("_")[0]) @@ -624,6 +625,13 @@ def main(args): timestamp = datetime.now().strftime(f"%Y%m%d-%H%M") # prints in YYYYMMDD-HHMMSS format save_exp_id = f"{save_exp_id}_{timestamp}" + # save output to a log file + logger.add(os.path.join(args.save_path, f"{save_exp_id}", "logs.txt"), rotation="10 MB", level="INFO") + + # define loss function + loss_func = AdapWingLoss(theta=0.5, omega=8, alpha=2.1, epsilon=1, reduction="sum") + logger.info(f"Using AdapWingLoss with theta={loss_func.theta}, omega={loss_func.omega}, alpha={loss_func.alpha}, epsilon={loss_func.epsilon}!") + # define callbacks # early_stopping = pl.callbacks.EarlyStopping(monitor="val_soft_dice", min_delta=0.00, patience=args.patience, # verbose=False, mode="max") @@ -650,6 +658,7 @@ def main(args): exp_id=save_exp_id, results_path=results_path) # saving the best model based on validation loss + logger.info(f"Saving best model to {save_path}!") checkpoint_callback_loss = pl.callbacks.ModelCheckpoint( dirpath=save_path, filename='best_model_loss', monitor='val_loss', save_top_k=1, mode="min", save_last=True, save_weights_only=False) @@ -703,7 +712,7 @@ def main(args): save_exp_id = args.save_path save_path = os.path.dirname(args.save_path) - print(f"save_path: {save_path}") + logger.info(f"save_path: {save_path}") results_path = args.results_dir # i.e. 
train by loading weights from scratch @@ -713,12 +722,12 @@ def main(args): # saving the best model based on validation CSA loss checkpoint_callback_loss = pl.callbacks.ModelCheckpoint( - dirpath=save_path, filename='best_model_loss', monitor='val_loss', + dirpath=save_exp_id, filename='best_model_loss', monitor='val_loss', save_top_k=1, mode="min", save_last=True, save_weights_only=True) # saving the best model based on soft validation dice score checkpoint_callback_dice = pl.callbacks.ModelCheckpoint( - dirpath=save_path, filename='best_model_dice', monitor='val_soft_dice', + dirpath=save_exp_id, filename='best_model_dice', monitor='val_soft_dice', save_top_k=1, mode="max", save_last=False, save_weights_only=True) # wandb logger From e52c6c830471a841a42a42423de8ba831930e01b Mon Sep 17 00:00:00 2001 From: Naga Karthik Date: Mon, 18 Sep 2023 12:49:11 -0400 Subject: [PATCH 076/106] add code for inference with monai-based nnunet --- monai/run_inference.py | 91 +++++++++++++++++++++++++++--------------- 1 file changed, 58 insertions(+), 33 deletions(-) diff --git a/monai/run_inference.py b/monai/run_inference.py index 52d481b9..2bc6814b 100644 --- a/monai/run_inference.py +++ b/monai/run_inference.py @@ -12,23 +12,38 @@ from monai.data import (DataLoader, CacheDataset, load_decathlon_datalist, decollate_batch) from monai.transforms import (Compose, EnsureTyped, Invertd, SaveImage) -from transforms import val_transforms +from transforms import val_transforms, val_transforms_with_orientation_and_crop from utils import precision_score, recall_score, dice_score -from models import ModifiedUNet3D -from monai.networks.nets import UNETR - -DEVICE = "cuda" if torch.cuda.is_available() else "cpu" -DEBUG = False -INFERENCE_ROI_SIZE = (160, 224, 96) # (80, 192, 160) -# UNET params -INIT_FILTERS=8 -# UNETR params -FEATURE_SIZE = 8 -HIDDEN_SIZE = 512 -MLP_DIM = 1024 -NUM_HEADS = 8 - -EXAMPLE_INPUT = torch.randn(1, 1, 160, 224, 96).to(DEVICE) +from models import ModifiedUNet3D, create_nnunet_from_plans + + +# NNUNET global params +INIT_FILTERS=32 +ENABLE_DS = True + +nnunet_plans = { + "UNet_class_name": "PlainConvUNet", + "UNet_base_num_features": INIT_FILTERS, + "n_conv_per_stage_encoder": [2, 2, 2, 2, 2, 2], + "n_conv_per_stage_decoder": [2, 2, 2, 2, 2], + "pool_op_kernel_sizes": [ + [1, 1, 1], + [2, 2, 2], + [2, 2, 2], + [2, 2, 2], + [2, 2, 2], + [1, 2, 2] + ], + "conv_kernel_sizes": [ + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3] + ], + "unet_max_num_features": 320, +} def get_parser(): @@ -64,7 +79,7 @@ def prepare_data(root, dataset_name="spine-generic"): test_files = test_files[:3] # define test transforms - transforms_test = val_transforms(lbl_key='label') + transforms_test = val_transforms_with_orientation_and_crop(crop_size=crop_size, lbl_key='label') # define post-processing transforms for testing; taken (with explanations) from # https://github.com/Project-MONAI/tutorials/blob/main/3d_segmentation/torch/unet_inference_dict.py#L66 @@ -93,31 +108,36 @@ def main(args): results_path = args.path_out model_name = args.chkp_path.split("/")[-1] - results_path = os.path.join(results_path, dataset_name, model_name) + if args.best_model_type == "dice": + chkp_paths = [os.path.join(args.chkp_path, "best_model_dice.ckpt")] + results_path = os.path.join(results_path, dataset_name, model_name, "best_dice") + elif args.best_model_type == "loss": + chkp_paths = [os.path.join(args.chkp_path, "best_model_loss.ckpt")] + results_path = os.path.join(results_path, dataset_name, 
model_name)
+
+    # save terminal outputs to a file
+    logger.add(os.path.join(results_path, "logs.txt"), rotation="10 MB", level="INFO")
+
+    logger.info(f"Saving results to: {results_path}")
     if not os.path.exists(results_path):
         os.makedirs(results_path, exist_ok=True)

+    # define cropping size
+    inference_roi_size = tuple([int(i) for i in args.crop_size.split("x")])
+    if inference_roi_size == (-1,):   # means no cropping is required
+        logger.info(f"Doing Sliding Window Inference on Whole Images ...")
+        inference_roi_size = (-1, -1, -1)
+
     # define the dataset and dataloader
-    test_ds, test_post_pred = prepare_data(dataset_root, dataset_name)
+    test_ds, test_post_pred = prepare_data(dataset_root, dataset_name, crop_size=inference_roi_size)
     test_loader = DataLoader(test_ds, batch_size=1, shuffle=False, num_workers=4, pin_memory=True)

     if args.model == "unet":
         # initialize ivadomed unet model
         net = ModifiedUNet3D(in_channels=1, out_channels=1, init_filters=INIT_FILTERS)
-    elif args.model == "unetr":
-        # initialize unetr model
-        net = UNETR(spatial_dims=3,
-                    in_channels=1, out_channels=1,
-                    img_size=INFERENCE_ROI_SIZE,
-                    feature_size=FEATURE_SIZE,
-                    hidden_size=HIDDEN_SIZE,
-                    mlp_dim=MLP_DIM,
-                    num_heads=NUM_HEADS,
-                    pos_embed="conv",
-                    norm_name="instance",
-                    res_block=True,
-                    dropout_rate=0.2,
-                )
+    elif args.model == "nnunet":
+        # define model
+        net = create_nnunet_from_plans(plans=nnunet_plans, num_input_channels=1, num_classes=1, deep_supervision=ENABLE_DS)

     # define list to collect the test metrics
     test_step_outputs = []

From 36ae67ef88df834766e2a7c793ebcd6980018092 Mon Sep 17 00:00:00 2001
From: Naga Karthik
Date: Mon, 18 Sep 2023 12:52:17 -0400
Subject: [PATCH 077/106] save terminal outputs to log file; add args

---
 monai/run_inference.py | 51 ++++++++++++++++++++++++------------------
 1 file changed, 29 insertions(+), 22 deletions(-)

diff --git a/monai/run_inference.py b/monai/run_inference.py
index 2bc6814b..a3c8a507 100644
--- a/monai/run_inference.py
+++ b/monai/run_inference.py
@@ -59,24 +59,35 @@ def get_parser():
                         help="Name of the dataset to run inference on")
     parser.add_argument("--model", type=str, default="unet", required=True,
                         help="Name of the model to use for inference")
-
+    parser.add_argument("--best-model-type", type=str, default="dice", required=True, choices=["csa", "dice", "loss", "all"],
+                        help="Type of the best model to use for inference i.e. based on csa/dice/loss")
+    # define args for cropping size. inputs should be in the format of "48x192x256"
+    parser.add_argument('-crop', '--crop-size', type=str, default="48x160x320",
+                        help='Patch size used for center cropping the images during inference. Values correspond to R-L, A-P, I-S axes'
+                             'of the image. Sliding window will be run across the cropped images. Use -1 if no cropping is intended '
+                             '(sliding window will run across the whole image). Note: heavy R-L, A-P cropping is recommended for best '
+                             'results.
Default: 48x160x320') + parser.add_argument('-debug', default=False, action='store_true', + help='run inference only on a few images to check if things are working') + parser.add_argument('--device', default="gpu", type=str, choices=["gpu", "cpu"], + help='Device to run inference on. Default: gpu') + return parser # -------------------------------- # DATA # -------------------------------- -def prepare_data(root, dataset_name="spine-generic"): +def prepare_data(root, dataset_name="spine-generic", crop_size=(48, 160, 320)): # set deterministic training for reproducibility # set_determinism(seed=self.args.seed) # load the dataset dataset = os.path.join(root, f"{dataset_name}_dataset.json") - # dataset = os.path.join(root, f"dataset_ivado_comparison.json") test_files = load_decathlon_datalist(dataset, True, "test") - if DEBUG: # args.debug: - test_files = test_files[:3] + if args.debug: + test_files = test_files[:6] # define test transforms transforms_test = val_transforms_with_orientation_and_crop(crop_size=crop_size, lbl_key='label') @@ -100,12 +111,17 @@ def main(args): # define start time start = time() + # define device + if args.device == "gpu" and not torch.cuda.is_available(): + logger.warning("GPU not available, using CPU instead") + DEVICE = torch.device("cpu") + else: + DEVICE = torch.device("cuda" if torch.cuda.is_available() and args.device == "gpu" else "cpu") + # define root path for finding datalists dataset_root = args.path_json dataset_name = args.dataset_name - # chkp_path = os.path.join(args.chkp_path, "best_model.ckpt") - results_path = args.path_out model_name = args.chkp_path.split("/")[-1] if args.best_model_type == "dice": @@ -154,9 +170,7 @@ def main(args): test_input = batch["image"].to(DEVICE) # load the checkpoints - for chkp in os.listdir(args.chkp_path): - chkp_path = os.path.join(args.chkp_path, chkp) - # print(f"Loading checkpoint: {chkp_path}") + for chkp_path in chkp_paths: checkpoint = torch.load(chkp_path, map_location=torch.device(DEVICE))["state_dict"] # NOTE: remove the 'net.' 
prefix from the keys because of how the model was initialized in lightning @@ -172,7 +186,7 @@ def main(args): net.eval() # run inference - batch["pred"] = sliding_window_inference(test_input, INFERENCE_ROI_SIZE, mode="gaussian", + batch["pred"] = sliding_window_inference(test_input, inference_roi_size, mode="gaussian", sw_batch_size=4, predictor=net, overlap=0.5, progress=False) if ENABLE_DS and args.model == "nnunet": @@ -188,7 +202,6 @@ def main(args): post_test_out = [test_post_pred(i) for i in decollate_batch(batch)] # make sure that the shapes of prediction and GT label are the same - # print(f"pred shape: {post_test_out[0]['pred'].shape}, label shape: {post_test_out[0]['label'].shape}") assert post_test_out[0]['pred'].shape == post_test_out[0]['label'].shape pred, label = post_test_out[0]['pred'].cpu(), post_test_out[0]['label'].cpu() @@ -198,7 +211,7 @@ def main(args): # save the (soft) prediction and label subject_name = (batch["image_meta_dict"]["filename_or_obj"][0]).split("/")[-1].replace(".nii.gz", "") - print(f"Saving subject: {subject_name}") + logger.info(f"Saving subject: {subject_name}") # take the average of the predictions pred = torch.stack(preds_stack).mean(dim=0) @@ -214,12 +227,6 @@ def main(args): separate_folder=False, print_log=False) # save the prediction pred_saver(pred) - - # label_saver = SaveImage( - # output_dir=save_folder, output_postfix="gt", output_ext=".nii.gz", - # separate_folder=False, print_log=False) - # # save the label - # label_saver(label) # NOTE: Important point from the SoftSeg paper - binarize predictions before computing metrics # calculate all metrics here @@ -287,9 +294,9 @@ def main(args): end = time() - print("=====================================================================") - print(f"Total time taken for inference: {(end - start) / 60:.2f} minutes") - print("=====================================================================") + logger.info("=====================================================================") + logger.info(f"Total time taken for inference: {(end - start) / 60:.2f} minutes") + logger.info("=====================================================================") if __name__ == "__main__": From 9763ed510889d03c3e8ca9a38b0e368bfe99dcb5 Mon Sep 17 00:00:00 2001 From: Naga Karthik Date: Mon, 18 Sep 2023 12:52:51 -0400 Subject: [PATCH 078/106] add documentation on usage --- monai/run_inference.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/monai/run_inference.py b/monai/run_inference.py index a3c8a507..4df19c21 100644 --- a/monai/run_inference.py +++ b/monai/run_inference.py @@ -1,3 +1,16 @@ +""" +Script to run inference on a MONAI-based model for contrast-agnostic soft segmentation of the spinal cord. +Predictions are stored in independent folders for each subject. Summary of the test metrics (both per subject and overall) +are stored in a json file, along with the time taken for inference. 
+ +Usage: + python run_inference.py --path-json --chkp-path --path-out + --model --best-model-type --crop_size <48x160x320> --device + +Author: Naga Karthik + +""" + import os import argparse import numpy as np @@ -6,7 +19,6 @@ import torch import json from time import time -from tqdm import tqdm from monai.inferers import sliding_window_inference from monai.data import (DataLoader, CacheDataset, load_decathlon_datalist, decollate_batch) From c6e3e92ee478ca79503994a9ce6dbaf7e6f46570 Mon Sep 17 00:00:00 2001 From: Naga Karthik Date: Wed, 20 Sep 2023 22:28:27 -0400 Subject: [PATCH 079/106] finalize train transforms --- monai/transforms.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/transforms.py b/monai/transforms.py index a0a18c03..d7c2289f 100644 --- a/monai/transforms.py +++ b/monai/transforms.py @@ -44,7 +44,7 @@ def train_transforms(crop_size, num_samples_pv, lbl_key="label"): # # if num_samples=4, then 4 samples/image are randomly generated # image_key="image", image_threshold=0.), # re-ordering transforms as used by nnunet - RandAffined(keys=["image", lbl_key], mode=(2, 1), prob=0.75, + RandAffined(keys=["image", lbl_key], mode=(2, 1), prob=0.9, rotate_range=(-20. / 360 * 2. * np.pi, 20. / 360 * 2. * np.pi), # monai expects in radians scale_range=(-0.2, 0.2), translate_range=(-0.1, 0.1)), From 9bb37d31123d3069d83e16526c8400febe05446d Mon Sep 17 00:00:00 2001 From: Naga Karthik Date: Wed, 20 Sep 2023 22:29:44 -0400 Subject: [PATCH 080/106] rename to inference_transforms(); add DivisiblePadd --- monai/transforms.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/monai/transforms.py b/monai/transforms.py index d7c2289f..7005809a 100644 --- a/monai/transforms.py +++ b/monai/transforms.py @@ -2,7 +2,7 @@ import numpy as np from monai.transforms import (SpatialPadd, Compose, CropForegroundd, LoadImaged, RandFlipd, RandCropByPosNegLabeld, Spacingd, RandScaleIntensityd, NormalizeIntensityd, RandAffined, - RandWeightedCropd, RandAdjustContrastd, EnsureChannelFirstd, RandGaussianNoised, + DivisiblePadd, RandAdjustContrastd, EnsureChannelFirstd, RandGaussianNoised, RandGaussianSmoothd, Orientationd, Rand3DElasticd, RandBiasFieldd, RandSimulateLowResolutiond, ResizeWithPadOrCropd) @@ -92,7 +92,7 @@ def val_transforms_without_center_crop(lbl_key="label"): NormalizeIntensityd(keys=["image"], nonzero=False, channel_wise=False), ]) -def val_transforms_with_orientation_and_crop(crop_size, lbl_key="label"): +def inference_transforms(crop_size, lbl_key="label"): return Compose([ LoadImaged(keys=["image", lbl_key], image_only=False), EnsureChannelFirstd(keys=["image", lbl_key]), @@ -100,7 +100,7 @@ def val_transforms_with_orientation_and_crop(crop_size, lbl_key="label"): Orientationd(keys=["image", lbl_key], axcodes="RPI"), Spacingd(keys=["image", lbl_key], pixdim=(1.0, 1.0, 1.0), mode=(2, 1)), # mode=("bilinear", "bilinear"),), ResizeWithPadOrCropd(keys=["image", lbl_key], spatial_size=crop_size,), - # TODO: do cropping only in R-L so sth like (48, -1, -1) + DivisiblePadd(keys=["image", lbl_key], k=2**5), # pad inputs to ensure divisibility by no. 
of layers nnUNet has (5) NormalizeIntensityd(keys=["image"], nonzero=False, channel_wise=False), ]) From 4ab0d4b675460f8c1bd3eff42f7a4f04c8d91d34 Mon Sep 17 00:00:00 2001 From: Naga Karthik Date: Thu, 21 Sep 2023 11:04:47 -0400 Subject: [PATCH 081/106] add requirements for cpu inference --- monai/requirements_inference.txt | 6 ++++++ 1 file changed, 6 insertions(+) create mode 100644 monai/requirements_inference.txt diff --git a/monai/requirements_inference.txt b/monai/requirements_inference.txt new file mode 100644 index 00000000..211f79d0 --- /dev/null +++ b/monai/requirements_inference.txt @@ -0,0 +1,6 @@ +dynamic_network_architectures==0.2 +joblib==1.3.0 +loguru==0.7.0 +monai==1.2.0 +numpy==1.24.4 +torch==2.0.0+cpu From 82113d82aaa31ac8b138e1270d4f6c3e82952c9e Mon Sep 17 00:00:00 2001 From: Naga Karthik Date: Thu, 21 Sep 2023 12:03:49 -0400 Subject: [PATCH 082/106] add init version of instructions --- monai/inference_instructions.md | 60 +++++++++++++++++++++++++++++++++ 1 file changed, 60 insertions(+) create mode 100644 monai/inference_instructions.md diff --git a/monai/inference_instructions.md b/monai/inference_instructions.md new file mode 100644 index 00000000..3fe27ea5 --- /dev/null +++ b/monai/inference_instructions.md @@ -0,0 +1,60 @@ +## Instructions for running inference with the contrast-agnostic spinal cord segmentation model + +The following steps are required for using the contrast-agnostic model. + +### Step 1: Setting up the environment and Installing dependencies + +The following commands show how to set up the environment. Note that the documentation assumes that the user has `conda` installed on their system. Instructions on installing `conda` can be found [here](https://conda.io/projects/conda/en/latest/user-guide/install/index.html). + +1. Create a conda environment with the following command: + +```bash +conda create -n venv_monai python=3.9 +``` + +2. Activate the environment with the following command: + +```bash +conda activate venv_monai +``` + +3. The list of necessary packages can be found in `requirements_inference.txt`. Use the following command for installation: + +```bash +pip install -r requirements_inference.txt +``` + +### Step 2: Creating a datalist + +The inference script assumes the dataset to be in Medical Segmentation Decathlon-style `json` file format containing image-label pairs. The `create_inference_msd_datalist.py` script allows to create one for your dataset. Use the following command to create the datalist: + +```bash +python create_inference_msd_datalist.py --dataset-name spine-generic --path-data --path-out --contrast-suffix T1w +``` + +`--dataset-name` - Corresponds to name of the dataset. The datalist will be saved as `_dataset.json` +`--path-data` - Path to the BIDS dataset +`--path-out` - Path to the output folder. The datalist will be saved under `/_dataset.json` +`--contrast-suffix` - The suffix of the contrast to be used for pairing images/labels + +> **Note** +> This script is not meant to run off-the-shelf. Placeholders are provided to update the script with the .... TODO + + +### Step 3: Running inference + +Use the following command: + +```bash +python run_inference.py --path-json --chkp-path --path-out --model --crop_size <48x160x320> --device +``` + +`--path-json` - Path to the datalist created in Step 2 +`--chkp-path` - Path to the model checkpoint. This folder should contain the `best_model_loss.ckpt` +`--path-out` - Path to the output folder where the predictions will be saved +`--model` - Model to be used for inference. 
Currently, only `unet` and `nnunet` are supported +`--crop_size` - Crop size used for center cropping the image before running inference. Recommended to be set to a multiple of 32 +`--device` - Device to be used for inference. Currently, only `gpu` and `cpu` are supported + + + From 265c70440ef539eaff2ead21c6d418fda71df449 Mon Sep 17 00:00:00 2001 From: Naga Karthik Date: Mon, 25 Sep 2023 16:31:11 -0400 Subject: [PATCH 083/106] fix torch cpu version download --- monai/requirements_inference.txt | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/monai/requirements_inference.txt b/monai/requirements_inference.txt index 211f79d0..2f50d4d8 100644 --- a/monai/requirements_inference.txt +++ b/monai/requirements_inference.txt @@ -3,4 +3,5 @@ joblib==1.3.0 loguru==0.7.0 monai==1.2.0 numpy==1.24.4 -torch==2.0.0+cpu +--extra-index-url https://download.pytorch.org/whl/cpu +torch==2.0.0+cpu From edaed3da98ded979ebfe9820cc867b7230158f35 Mon Sep 17 00:00:00 2001 From: Naga Karthik Date: Mon, 25 Sep 2023 17:08:59 -0400 Subject: [PATCH 084/106] add script for inference on single image --- monai/run_inference_single_image.py | 268 ++++++++++++++++++++++++++++ 1 file changed, 268 insertions(+) create mode 100644 monai/run_inference_single_image.py diff --git a/monai/run_inference_single_image.py b/monai/run_inference_single_image.py new file mode 100644 index 00000000..a5c70ac8 --- /dev/null +++ b/monai/run_inference_single_image.py @@ -0,0 +1,268 @@ +""" +Script to run inference on a MONAI-based model for contrast-agnostic soft segmentation of the spinal cord. +Prediction is stored in an independent folder given by subject name. The time taken for inference is stored in a json file. + +Usage: + python run_inference_single_image.py --path-img /path/to/my-awesome-SC-image.nii.gz --chkp-path /path/to/best/model + --path-out /path/to/output/folder --crop-size <64x160x320> --device + +Author: Naga Karthik + +""" + +import os +import argparse +import numpy as np +from loguru import logger +import torch.nn.functional as F +import torch +import json +from time import time + +from monai.inferers import sliding_window_inference +from monai.data import (DataLoader, CacheDataset, load_decathlon_datalist, decollate_batch) +from monai.transforms import (Compose, EnsureTyped, Invertd, SaveImage) + +from transforms import inference_transforms_single_image +from models import create_nnunet_from_plans + + +# NNUNET global params +INIT_FILTERS=32 +ENABLE_DS = True + +nnunet_plans = { + "UNet_class_name": "PlainConvUNet", + "UNet_base_num_features": INIT_FILTERS, + "n_conv_per_stage_encoder": [2, 2, 2, 2, 2, 2], + "n_conv_per_stage_decoder": [2, 2, 2, 2, 2], + "pool_op_kernel_sizes": [ + [1, 1, 1], + [2, 2, 2], + [2, 2, 2], + [2, 2, 2], + [2, 2, 2], + [1, 2, 2] + ], + "conv_kernel_sizes": [ + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3] + ], + "unet_max_num_features": 320, +} + + +def get_parser(): + + parser = argparse.ArgumentParser(description="Run inference on a MONAI-trained model") + + parser.add_argument("--path-img", type=str, required=True, + help="Path to the image to run inference on") + parser.add_argument("--chkp-path", type=str, required=True, help="Path to the checkpoint folder") + parser.add_argument("--path-out", type=str, required=True, + help="Path to the output folder where to store the predictions and associated metrics") + parser.add_argument('-crop', '--crop-size', type=str, default="48x160x320", + help='Patch size used for center cropping the images during 
inference. Values correspond to R-L, A-P, I-S axes' + 'of the image. Sliding window will be run across the cropped images. Use -1 if no cropping is intended ' + '(sliding window will run across the whole image). Note, heavy R-L, A-P cropping is recommmended for best ' + 'results. Default: 48x160x320') + parser.add_argument('--device', default="gpu", type=str, choices=["gpu", "cpu"], + help='Device to run inference on. Default: gpu') + + return parser + + +# -------------------------------- +# DATA +# -------------------------------- +def prepare_data(path_image, path_out, crop_size=(48, 160, 320)): + + # create a temporary datalist containing the image + # boiler plate keys to be defined in the MSD-style datalist + params = {} + params["description"] = "my-awesome-SC-image" + params["labels"] = { + "0": "background", + "1": "soft-sc-seg" + } + params["modality"] = { + "0": "MRI" + } + params["tensorImageSize"] = "3D" + params["test"] = [ + { + "image": path_image + } + ] + + final_json = json.dumps(params, indent=4, sort_keys=True) + jsonFile = open(path_out + "/" + f"temp_msd_datalist.json", "w") + jsonFile.write(final_json) + jsonFile.close() + + dataset = os.path.join(path_out, f"temp_msd_datalist.json") + test_files = load_decathlon_datalist(dataset, True, "test") + + # define test transforms + transforms_test = inference_transforms_single_image(crop_size=crop_size) + + # define post-processing transforms for testing; taken (with explanations) from + # https://github.com/Project-MONAI/tutorials/blob/main/3d_segmentation/torch/unet_inference_dict.py#L66 + test_post_pred = Compose([ + EnsureTyped(keys=["pred"]), + Invertd(keys=["pred"], transform=transforms_test, + orig_keys=["image"], + meta_keys=["pred_meta_dict"], + nearest_interp=False, to_tensor=True), + ]) + test_ds = CacheDataset(data=test_files, transform=transforms_test, cache_rate=0.75, num_workers=8) + + return test_ds, test_post_pred + + +def main(args): + + # define start time + start = time() + + # define device + if args.device == "gpu" and not torch.cuda.is_available(): + logger.warning("GPU not available, using CPU instead") + DEVICE = torch.device("cpu") + else: + DEVICE = torch.device("cuda" if torch.cuda.is_available() and args.device == "gpu" else "cpu") + + # define root path for finding datalists + path_image = args.path_img + results_path = args.path_out + chkp_path = os.path.join(args.chkp_path, "best_model_loss.ckpt") + + # save terminal outputs to a file + logger.add(os.path.join(results_path, "logs.txt"), rotation="10 MB", level="INFO") + + logger.info(f"Saving results to: {results_path}") + if not os.path.exists(results_path): + os.makedirs(results_path, exist_ok=True) + + # define inference patch size and center crop size + crop_size = tuple([int(i) for i in args.crop_size.split("x")]) + inference_roi_size = (64, 160, 320) + + # define the dataset and dataloader + test_ds, test_post_pred = prepare_data(path_image, results_path, crop_size=crop_size) + test_loader = DataLoader(test_ds, batch_size=1, shuffle=False, num_workers=8, pin_memory=True) + + # define model + net = create_nnunet_from_plans(plans=nnunet_plans, num_input_channels=1, num_classes=1, deep_supervision=ENABLE_DS) + + # define list to collect the test metrics + test_step_outputs = [] + test_summary = {} + + # iterate over the dataset and compute metrics + with torch.no_grad(): + for batch in test_loader: + + # compute time for inference per subject + start_time = time() + + # get the test input + test_input = batch["image"].to(DEVICE) + + # 
this loop only takes about 0.2s on average on a CPU + checkpoint = torch.load(chkp_path, map_location=torch.device(DEVICE))["state_dict"] + # NOTE: remove the 'net.' prefix from the keys because of how the model was initialized in lightning + # https://discuss.pytorch.org/t/missing-keys-unexpected-keys-in-state-dict-when-loading-self-trained-model/22379/14 + for key in list(checkpoint.keys()): + if 'net.' in key: + checkpoint[key.replace('net.', '')] = checkpoint[key] + del checkpoint[key] + + # load the trained model weights + net.load_state_dict(checkpoint) + net.to(DEVICE) + net.eval() + + # run inference + batch["pred"] = sliding_window_inference(test_input, inference_roi_size, mode="gaussian", + sw_batch_size=4, predictor=net, overlap=0.5, progress=False) + + # take only the highest resolution prediction + batch["pred"] = batch["pred"][0] + + # NOTE: monai's models do not normalize the output, so we need to do it manually + if bool(F.relu(batch["pred"]).max()): + batch["pred"] = F.relu(batch["pred"]) / F.relu(batch["pred"]).max() + else: + batch["pred"] = F.relu(batch["pred"]) + + post_test_out = [test_post_pred(i) for i in decollate_batch(batch)] + + pred = post_test_out[0]['pred'].cpu() + + # clip the prediction between 0.5 and 1 + pred = torch.clamp(pred, 0.5, 1) + # # threshold the prediction + # pred = (pred > 0.1).float() + + # get subject name + subject_name = (batch["image_meta_dict"]["filename_or_obj"][0]).split("/")[-1].replace(".nii.gz", "") + logger.info(f"Saving subject: {subject_name}") + + # this takes about 0.25s on average on a CPU + # image saver class + save_folder = os.path.join(results_path, subject_name.split("_")[0]) + pred_saver = SaveImage( + output_dir=save_folder, output_postfix="pred", output_ext=".nii.gz", + separate_folder=False, print_log=False) + # save the prediction + pred_saver(pred) + + end_time = time() + metrics_dict = { + "subject_name_and_contrast": subject_name, + "inference_time_in_sec": round((end_time - start_time), 2), + } + test_step_outputs.append(metrics_dict) + + # save the test summary + test_summary["metrics_per_subject"] = test_step_outputs + + # compute the average inference time + avg_inference_time = np.stack([x["inference_time_in_sec"] for x in test_step_outputs]).mean() + + # store the average metrics in a dict + avg_metrics = { + "avg_inference_time_in_sec": round(avg_inference_time, 2), + } + test_summary["metrics_avg_across_cohort"] = avg_metrics + + logger.info("========================================================") + logger.info(f" Inference Time per Subject: {avg_inference_time:.2f}s") + logger.info("========================================================") + + + # dump the test summary to a json file + with open(os.path.join(results_path, "test_summary.json"), "w") as f: + json.dump(test_summary, f, indent=4, sort_keys=True) + + # free up memory + test_step_outputs.clear() + test_summary.clear() + os.remove(os.path.join(results_path, "temp_msd_datalist.json")) + + end = time() + + # logger.info("===============================================================") + # logger.info(f"Total time taken for inference: {(end - start) / 60:.2f} minutes") + # logger.info("===============================================================") + + +if __name__ == "__main__": + + args = get_parser().parse_args() + main(args) \ No newline at end of file From 043804c28e7f3ae75d02962b2228bca2c9a16b85 Mon Sep 17 00:00:00 2001 From: Naga Karthik Date: Mon, 25 Sep 2023 17:15:46 -0400 Subject: [PATCH 085/106] add instructions for single-image 
inference --- monai/inference_instructions.md | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/monai/inference_instructions.md b/monai/inference_instructions.md index 3fe27ea5..3f159710 100644 --- a/monai/inference_instructions.md +++ b/monai/inference_instructions.md @@ -2,7 +2,7 @@ The following steps are required for using the contrast-agnostic model. -### Step 1: Setting up the environment and Installing dependencies +### Setting up the environment and Installing dependencies The following commands show how to set up the environment. Note that the documentation assumes that the user has `conda` installed on their system. Instructions on installing `conda` can be found [here](https://conda.io/projects/conda/en/latest/user-guide/install/index.html). @@ -24,7 +24,22 @@ conda activate venv_monai pip install -r requirements_inference.txt ``` -### Step 2: Creating a datalist +### Method 1: Running inference on a single image + +```bash +python run_inference_single_image.py --path-img /path/to/my-awesome-SC-image.nii.gz --chkp-path /path/to/best/model --path-out /path/to/output/folder --crop-size <64x160x320> --device +``` + +`--path-img` - Path to the image to be segmented +`--chkp-path` - Path to the model checkpoint. This folder should contain the `best_model_loss.ckpt` +`--path-out` - Path to the output folder where the predictions will be saved +`--crop_size` - Crop size used for center cropping the image before running inference. Recommended to be set to a multiple of 32 +`--device` - Device to be used for inference. Currently, only `gpu` and `cpu` are supported + + +### Method 2: Running inference on a dataset + +#### Creating a datalist The inference script assumes the dataset to be in Medical Segmentation Decathlon-style `json` file format containing image-label pairs. The `create_inference_msd_datalist.py` script allows to create one for your dataset. Use the following command to create the datalist: @@ -41,7 +56,7 @@ python create_inference_msd_datalist.py --dataset-name spine-generic --path-data > This script is not meant to run off-the-shelf. Placeholders are provided to update the script with the .... 
TODO -### Step 3: Running inference +#### Running inference Use the following command: From 963e9d29c846c095604854f41112a1431f9fd235 Mon Sep 17 00:00:00 2001 From: Naga Karthik Date: Mon, 25 Sep 2023 17:28:32 -0400 Subject: [PATCH 086/106] add inference transforms; remove transforms import --- monai/run_inference_single_image.py | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/monai/run_inference_single_image.py b/monai/run_inference_single_image.py index a5c70ac8..3ed28624 100644 --- a/monai/run_inference_single_image.py +++ b/monai/run_inference_single_image.py @@ -19,12 +19,12 @@ import json from time import time +from models import create_nnunet_from_plans from monai.inferers import sliding_window_inference from monai.data import (DataLoader, CacheDataset, load_decathlon_datalist, decollate_batch) -from monai.transforms import (Compose, EnsureTyped, Invertd, SaveImage) - -from transforms import inference_transforms_single_image -from models import create_nnunet_from_plans +from monai.transforms import (Compose, EnsureTyped, Invertd, SaveImage, Spacingd, + LoadImaged, NormalizeIntensityd, EnsureChannelFirstd, + DivisiblePadd, Orientationd, ResizeWithPadOrCropd) # NNUNET global params @@ -76,6 +76,18 @@ def get_parser(): return parser +# define transforms for inference +def inference_transforms_single_image(crop_size): + return Compose([ + LoadImaged(keys=["image"], image_only=False), + EnsureChannelFirstd(keys=["image"]), + Orientationd(keys=["image"], axcodes="RPI"), + Spacingd(keys=["image"], pixdim=(1.0, 1.0, 1.0), mode=(2)), + ResizeWithPadOrCropd(keys=["image"], spatial_size=crop_size,), + DivisiblePadd(keys=["image"], k=2**5), # pad inputs to ensure divisibility by no. of layers nnUNet has (5) + NormalizeIntensityd(keys=["image"], nonzero=False, channel_wise=False), + ]) + # -------------------------------- # DATA # -------------------------------- From 5330b19307fa732bdad203bf79ba9cd48af81a8b Mon Sep 17 00:00:00 2001 From: Naga Karthik Date: Mon, 25 Sep 2023 17:45:24 -0400 Subject: [PATCH 087/106] add nibabel dep for monai; add scipy --- monai/requirements_inference.txt | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/monai/requirements_inference.txt b/monai/requirements_inference.txt index 2f50d4d8..2cb4845f 100644 --- a/monai/requirements_inference.txt +++ b/monai/requirements_inference.txt @@ -1,7 +1,8 @@ dynamic_network_architectures==0.2 joblib==1.3.0 loguru==0.7.0 -monai==1.2.0 +monai[nibabel]==1.2.0 +scipy==1.11.2 numpy==1.24.4 --extra-index-url https://download.pytorch.org/whl/cpu torch==2.0.0+cpu From e0d90b037dfb33af7ccc041894b7057e1a3745e3 Mon Sep 17 00:00:00 2001 From: Naga Karthik Date: Tue, 26 Sep 2023 18:42:07 -0400 Subject: [PATCH 088/106] udpate args description for crop size --- monai/run_inference_single_image.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/monai/run_inference_single_image.py b/monai/run_inference_single_image.py index 3ed28624..2a02ac8c 100644 --- a/monai/run_inference_single_image.py +++ b/monai/run_inference_single_image.py @@ -65,11 +65,11 @@ def get_parser(): parser.add_argument("--chkp-path", type=str, required=True, help="Path to the checkpoint folder") parser.add_argument("--path-out", type=str, required=True, help="Path to the output folder where to store the predictions and associated metrics") - parser.add_argument('-crop', '--crop-size', type=str, default="48x160x320", + parser.add_argument('-crop', '--crop-size', type=str, 
default="64x160x320", help='Patch size used for center cropping the images during inference. Values correspond to R-L, A-P, I-S axes' - 'of the image. Sliding window will be run across the cropped images. Use -1 if no cropping is intended ' - '(sliding window will run across the whole image). Note, heavy R-L, A-P cropping is recommmended for best ' - 'results. Default: 48x160x320') + ' *in mm*. All images are resampled to 1mm isotropic before cropping. Inference is run on the cropped images.' + ' Use -1 if no cropping is intended. Note, heavy R-L cropping that positions the SC at the center of the image ' + 'is recommmended for best results. Default: 64x160x320') parser.add_argument('--device', default="gpu", type=str, choices=["gpu", "cpu"], help='Device to run inference on. Default: gpu') @@ -91,7 +91,7 @@ def inference_transforms_single_image(crop_size): # -------------------------------- # DATA # -------------------------------- -def prepare_data(path_image, path_out, crop_size=(48, 160, 320)): +def prepare_data(path_image, path_out, crop_size=(64, 160, 320)): # create a temporary datalist containing the image # boiler plate keys to be defined in the MSD-style datalist From 5a9a1236c1940f5fa648840f4cb4949f60f60bd6 Mon Sep 17 00:00:00 2001 From: Naga Karthik Date: Tue, 26 Sep 2023 18:48:24 -0400 Subject: [PATCH 089/106] add sorting of subjects --- monai/create_msd_data.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/monai/create_msd_data.py b/monai/create_msd_data.py index 14a6d492..e7ccbff7 100644 --- a/monai/create_msd_data.py +++ b/monai/create_msd_data.py @@ -59,6 +59,10 @@ # Use the training split to further split into training and validation splits train_subjects, val_subjects = train_test_split(train_subjects, test_size=val_ratio / (train_ratio + val_ratio), random_state=args.seed, ) + # sort the subjects + train_subjects = sorted(train_subjects) + val_subjects = sorted(val_subjects) + test_subjects = sorted(test_subjects) logger.info(f"Number of training subjects: {len(train_subjects)}") logger.info(f"Number of validation subjects: {len(val_subjects)}") From 57de7776e60ae43447ca32da2b1bdb4e21ccf235 Mon Sep 17 00:00:00 2001 From: Naga Karthik Date: Wed, 27 Sep 2023 12:52:20 -0400 Subject: [PATCH 090/106] remove dataset-based inference --- monai/create_inference_msd_datalist.py | 98 -------- monai/run_inference.py | 317 ------------------------- 2 files changed, 415 deletions(-) delete mode 100644 monai/create_inference_msd_datalist.py delete mode 100644 monai/run_inference.py diff --git a/monai/create_inference_msd_datalist.py b/monai/create_inference_msd_datalist.py deleted file mode 100644 index 9cbcba56..00000000 --- a/monai/create_inference_msd_datalist.py +++ /dev/null @@ -1,98 +0,0 @@ -import os -import json -import argparse -import joblib -from loguru import logger - -parser = argparse.ArgumentParser(description='Code for creating k-fold splits of the spine-generic dataset.') - -parser.add_argument('-dname', '--dataset-name', default='spine-generic', type=str, help='Name of the dataset') -parser.add_argument('-pd', '--path-data', required=True, type=str, help='Path to the data set directory') -parser.add_argument('-pj', '--path-joblib', help='Path to joblib file from ivadomed containing the dataset splits.', - default=None, type=str) -parser.add_argument('-po', '--path-out', type=str, help='Path to the output directory where dataset json is saved') -parser.add_argument('-csuf', '--contrast-suffix', type=str, default='T1w', - help='Contrast suffix used in the BIDS 
dataset') -args = parser.parse_args() - - -def main(args): - - root = args.path_data - contrast = args.contrast_suffix - - # Get all subjects - # the participants.tsv file might not be up-to-date, hence rely on the existing folders - # subjects_df = pd.read_csv(os.path.join(root, 'participants.tsv'), sep='\t') - # subjects = subjects_df['participant_id'].values.tolist() - subjects = [subject for subject in os.listdir(root) if subject.startswith('sub-')] - logger.info(f"Total number of subjects in the root directory: {len(subjects)}") - - - if args.path_joblib is not None: - # load information from the joblib to match train and test subjects - joblib_file = os.path.join(args.path_joblib, 'split_datasets_all_seed=15.joblib') - splits = joblib.load("split_datasets_all_seed=15.joblib") - # get the subjects from the joblib file - # train_subjects = sorted(list(set([sub.split('_')[0] for sub in splits['train']]))) - # val_subjects = sorted(list(set([sub.split('_')[0] for sub in splits['valid']]))) - test_subjects = sorted(list(set([sub.split('_')[0] for sub in splits['test']]))) - - else: - test_subjects = subjects - - logger.info(f"Number of testing subjects: {len(test_subjects)}") - - # keys to be defined in the dataset_0.json - params = {} - params["description"] = args.dataset_name - params["labels"] = { - "0": "background", - "1": "soft-sc-seg" - } - params["license"] = "nk" - params["modality"] = { - "0": "MRI" - } - params["name"] = "spine-generic" - params["numTest"] = len(test_subjects) - params["reference"] = "University of Zurich" - params["tensorImageSize"] = "3D" - - test_subjects_dict = {"test": test_subjects} - - for name, subs_list in test_subjects_dict.items(): - - temp_list = [] - for subject_no, subject in enumerate(subs_list): - - temp_data= {} - - temp_data["image"] = os.path.join(root, subject, 'anat', f"{subject}_{contrast}.nii.gz") - if args.dataset_name == "sci-colorado": - temp_data["label"] = os.path.join(root, "derivatives", "labels", subject, 'anat', f"{subject}_{contrast}_seg-manual.nii.gz") - elif args.dataset_name == "basel-mp2rage-rpi": - temp_data["label"] = os.path.join(root, "derivatives", "labels", subject, 'anat', f"{subject}_{contrast}_label-SC_seg.nii.gz") - else: - raise NotImplementedError(f"Dataset {args.dataset_name} not implemented yet.") - - if os.path.exists(temp_data["label"]) and os.path.exists(temp_data["image"]): - temp_list.append(temp_data) - else: - logger.info(f"Subject {subject} does not have label or image file. Skipping it.") - - params[name] = temp_list - logger.info(f"Number of images in {name} set: {len(temp_list)}") - - final_json = json.dumps(params, indent=4, sort_keys=True) - jsonFile = open(args.path_out + "/" + f"{args.dataset_name}_dataset.json", "w") - jsonFile.write(final_json) - jsonFile.close() - - -if __name__ == "__main__": - main(args) - - - - diff --git a/monai/run_inference.py b/monai/run_inference.py deleted file mode 100644 index 4df19c21..00000000 --- a/monai/run_inference.py +++ /dev/null @@ -1,317 +0,0 @@ -""" -Script to run inference on a MONAI-based model for contrast-agnostic soft segmentation of the spinal cord. -Predictions are stored in independent folders for each subject. Summary of the test metrics (both per subject and overall) -are stored in a json file, along with the time taken for inference. 
- -Usage: - python run_inference.py --path-json --chkp-path --path-out - --model --best-model-type --crop_size <48x160x320> --device - -Author: Naga Karthik - -""" - -import os -import argparse -import numpy as np -from loguru import logger -import torch.nn.functional as F -import torch -import json -from time import time - -from monai.inferers import sliding_window_inference -from monai.data import (DataLoader, CacheDataset, load_decathlon_datalist, decollate_batch) -from monai.transforms import (Compose, EnsureTyped, Invertd, SaveImage) - -from transforms import val_transforms, val_transforms_with_orientation_and_crop -from utils import precision_score, recall_score, dice_score -from models import ModifiedUNet3D, create_nnunet_from_plans - - -# NNUNET global params -INIT_FILTERS=32 -ENABLE_DS = True - -nnunet_plans = { - "UNet_class_name": "PlainConvUNet", - "UNet_base_num_features": INIT_FILTERS, - "n_conv_per_stage_encoder": [2, 2, 2, 2, 2, 2], - "n_conv_per_stage_decoder": [2, 2, 2, 2, 2], - "pool_op_kernel_sizes": [ - [1, 1, 1], - [2, 2, 2], - [2, 2, 2], - [2, 2, 2], - [2, 2, 2], - [1, 2, 2] - ], - "conv_kernel_sizes": [ - [3, 3, 3], - [3, 3, 3], - [3, 3, 3], - [3, 3, 3], - [3, 3, 3], - [3, 3, 3] - ], - "unet_max_num_features": 320, -} - - -def get_parser(): - - parser = argparse.ArgumentParser(description="Run inference on a MONAI-trained model") - - parser.add_argument("--path-json", type=str, required=True, - help="Path to the json file containing the test dataset in MSD format") - parser.add_argument("--chkp-path", type=str, required=True, help="Path to the checkpoint folder") - parser.add_argument("--path-out", type=str, required=True, - help="Path to the output folder where to store the predictions and associated metrics") - parser.add_argument("-dname", "--dataset-name", type=str, default="spine-generic", - help="Name of the dataset to run inference on") - parser.add_argument("--model", type=str, default="unet", required=True, - help="Name of the model to use for inference") - parser.add_argument("--best-model-type", type=str, default="dice", required=True, choices=["csa", "dice", "loss", "all"], - help="Type of the best model to use for inference i.e. based on csa/dice/both") - # define args for cropping size. inputs should be in the format of "48x192x256" - parser.add_argument('-crop', '--crop-size', type=str, default="48x160x320", - help='Patch size used for center cropping the images during inference. Values correspond to R-L, A-P, I-S axes' - 'of the image. Sliding window will be run across the cropped images. Use -1 if no cropping is intended ' - '(sliding window will run across the whole image). Note, heavy R-L, A-P cropping is recommmended for best ' - 'results. Default: 48x160x320') - parser.add_argument('-debug', default=False, action='store_true', - help='run inference only on a few images to check if things are working') - parser.add_argument('--device', default="gpu", type=str, choices=["gpu", "cpu"], - help='Device to run inference on. 
Default: gpu') - - return parser - - -# -------------------------------- -# DATA -# -------------------------------- -def prepare_data(root, dataset_name="spine-generic", crop_size=(48, 160, 320)): - # set deterministic training for reproducibility - # set_determinism(seed=self.args.seed) - - # load the dataset - dataset = os.path.join(root, f"{dataset_name}_dataset.json") - test_files = load_decathlon_datalist(dataset, True, "test") - - if args.debug: - test_files = test_files[:6] - - # define test transforms - transforms_test = val_transforms_with_orientation_and_crop(crop_size=crop_size, lbl_key='label') - - # define post-processing transforms for testing; taken (with explanations) from - # https://github.com/Project-MONAI/tutorials/blob/main/3d_segmentation/torch/unet_inference_dict.py#L66 - test_post_pred = Compose([ - EnsureTyped(keys=["pred", "label"]), - Invertd(keys=["pred", "label"], transform=transforms_test, - orig_keys=["image", "label"], - meta_keys=["pred_meta_dict", "label_meta_dict"], - nearest_interp=False, to_tensor=True), - ]) - test_ds = CacheDataset(data=test_files, transform=transforms_test, cache_rate=0.25, num_workers=4) - - return test_ds, test_post_pred - - -def main(args): - - # define start time - start = time() - - # define device - if args.device == "gpu" and not torch.cuda.is_available(): - logger.warning("GPU not available, using CPU instead") - DEVICE = torch.device("cpu") - else: - DEVICE = torch.device("cuda" if torch.cuda.is_available() and args.device == "gpu" else "cpu") - - # define root path for finding datalists - dataset_root = args.path_json - dataset_name = args.dataset_name - - results_path = args.path_out - model_name = args.chkp_path.split("/")[-1] - if args.best_model_type == "dice": - chkp_paths = [os.path.join(args.chkp_path, "best_model_dice.ckpt")] - results_path = os.path.join(results_path, dataset_name, model_name, "best_dice") - elif args.best_model_type == "loss": - chkp_paths = [os.path.join(args.chkp_path, "best_model_loss.ckpt")] - results_path = os.path.join(results_path, dataset_name, model_name) - - # save terminal outputs to a file - logger.add(os.path.join(results_path, "logs.txt"), rotation="10 MB", level="INFO") - - logger.info(f"Saving results to: {results_path}") - if not os.path.exists(results_path): - os.makedirs(results_path, exist_ok=True) - - # define cropping size - inference_roi_size = tuple([int(i) for i in args.crop_size.split("x")]) - if inference_roi_size == (-1,): # means no cropping is required - logger.info(f"Doing Sliding Window Inference on Whole Images ...") - inference_roi_size = (-1, -1, -1) - - # define the dataset and dataloader - test_ds, test_post_pred = prepare_data(dataset_root, dataset_name, crop_size=inference_roi_size) - test_loader = DataLoader(test_ds, batch_size=1, shuffle=False, num_workers=4, pin_memory=True) - - if args.model == "unet": - # initialize ivadomed unet model - net = ModifiedUNet3D(in_channels=1, out_channels=1, init_filters=INIT_FILTERS) - elif args.model == "nnunet": - # define model - net = create_nnunet_from_plans(plans=nnunet_plans, num_input_channels=1, num_classes=1, deep_supervision=ENABLE_DS) - - # define list to collect the test metrics - test_step_outputs = [] - test_summary = {} - - preds_stack = [] - # iterate over the dataset and compute metrics - with torch.no_grad(): - for batch in test_loader: - # compute time for inference per subject - start_time = time() - - # get the test input - test_input = batch["image"].to(DEVICE) - - # load the checkpoints - for 
chkp_path in chkp_paths: - - checkpoint = torch.load(chkp_path, map_location=torch.device(DEVICE))["state_dict"] - # NOTE: remove the 'net.' prefix from the keys because of how the model was initialized in lightning - # https://discuss.pytorch.org/t/missing-keys-unexpected-keys-in-state-dict-when-loading-self-trained-model/22379/14 - for key in list(checkpoint.keys()): - if 'net.' in key: - checkpoint[key.replace('net.', '')] = checkpoint[key] - del checkpoint[key] - - # load the trained model weights - net.load_state_dict(checkpoint) - net.to(DEVICE) - net.eval() - - # run inference - batch["pred"] = sliding_window_inference(test_input, inference_roi_size, mode="gaussian", - sw_batch_size=4, predictor=net, overlap=0.5, progress=False) - - if ENABLE_DS and args.model == "nnunet": - # take only the highest resolution prediction - batch["pred"] = batch["pred"][0] - - # NOTE: monai's models do not normalize the output, so we need to do it manually - if bool(F.relu(batch["pred"]).max()): - batch["pred"] = F.relu(batch["pred"]) / F.relu(batch["pred"]).max() - else: - batch["pred"] = F.relu(batch["pred"]) - - post_test_out = [test_post_pred(i) for i in decollate_batch(batch)] - - # make sure that the shapes of prediction and GT label are the same - assert post_test_out[0]['pred'].shape == post_test_out[0]['label'].shape - - pred, label = post_test_out[0]['pred'].cpu(), post_test_out[0]['label'].cpu() - - # stack the predictions - preds_stack.append(pred) - - # save the (soft) prediction and label - subject_name = (batch["image_meta_dict"]["filename_or_obj"][0]).split("/")[-1].replace(".nii.gz", "") - logger.info(f"Saving subject: {subject_name}") - - # take the average of the predictions - pred = torch.stack(preds_stack).mean(dim=0) - preds_stack.clear() - - # check whether the prediction and label have the same shape - assert pred.shape == label.shape, f"Prediction and label shapes are different: {pred.shape} vs {label.shape}" - - # image saver class - save_folder = os.path.join(results_path, subject_name.split("_")[0]) - pred_saver = SaveImage( - output_dir=save_folder, output_postfix="pred", output_ext=".nii.gz", - separate_folder=False, print_log=False) - # save the prediction - pred_saver(pred) - - # NOTE: Important point from the SoftSeg paper - binarize predictions before computing metrics - # calculate all metrics here - # 1. Dice Score - test_soft_dice = dice_score(pred, label) - - # binarizing the predictions - pred = (post_test_out[0]['pred'].detach().cpu() > 0.5).float() - label = (post_test_out[0]['label'].detach().cpu() > 0.5).float() - - # 1.1 Hard Dice Score - test_hard_dice = dice_score(pred.numpy(), label.numpy()) - # 2. Precision Score - test_precision = precision_score(pred.numpy(), label.numpy()) - # 3. 
Recall Score - test_recall = recall_score(pred.numpy(), label.numpy()) - - end_time = time() - metrics_dict = { - "subject_name_and_contrast": subject_name, - "dice_binary": round(test_hard_dice, 2), - "dice_soft": round(test_soft_dice.item(), 2), - "precision": round(test_precision, 2), - "recall": round(test_recall, 2), - # TODO: add relative volume difference here - # NOTE: RVD is usually compared with binary objects (not soft) - "inference_time_in_sec": round((end_time - start_time), 2), - } - test_step_outputs.append(metrics_dict) - - # save the test summary - test_summary["metrics_per_subject"] = test_step_outputs - - # compute the average of all metrics - avg_hard_dice_test, std_hard_dice_test = np.stack([x["dice_binary"] for x in test_step_outputs]).mean(), \ - np.stack([x["dice_binary"] for x in test_step_outputs]).std() - avg_soft_dice_test, std_soft_dice_test = np.stack([x["dice_soft"] for x in test_step_outputs]).mean(), \ - np.stack([x["dice_soft"] for x in test_step_outputs]).std() - avg_precision_test = np.stack([x["precision"] for x in test_step_outputs]).mean() - avg_recall_test = np.stack([x["recall"] for x in test_step_outputs]).mean() - avg_inference_time = np.stack([x["inference_time_in_sec"] for x in test_step_outputs]).mean() - - # store the average metrics in a dict - avg_metrics = { - "avg_dice_binary": round(avg_hard_dice_test, 2), - "avg_dice_soft": round(avg_soft_dice_test, 2), - "avg_precision": round(avg_precision_test, 2), - "avg_recall": round(avg_recall_test, 2), - "avg_inference_time_in_sec": round(avg_inference_time, 2), - } - test_summary["metrics_avg_across_cohort"] = avg_metrics - - logger.info(f"Test (Soft) Dice: {avg_soft_dice_test}") - logger.info(f"Test (Hard) Dice: {avg_hard_dice_test}") - logger.info(f"Test Precision Score: {avg_precision_test}") - logger.info(f"Test Recall Score: {avg_recall_test}") - logger.info(f"Average Inference Time per Subject: {avg_inference_time:.2f}s") - - # dump the test summary to a json file - with open(os.path.join(results_path, "test_summary.json"), "w") as f: - json.dump(test_summary, f, indent=4, sort_keys=True) - - # free up memory - test_step_outputs.clear() - - end = time() - - logger.info("=====================================================================") - logger.info(f"Total time taken for inference: {(end - start) / 60:.2f} minutes") - logger.info("=====================================================================") - - -if __name__ == "__main__": - - args = get_parser().parse_args() - main(args) \ No newline at end of file From 890c465dacb0237e15a8f6c649ae3d457e9a56db Mon Sep 17 00:00:00 2001 From: Naga Karthik Date: Thu, 28 Sep 2023 12:33:26 -0400 Subject: [PATCH 091/106] update args description about cropping --- monai/run_inference_single_image.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/monai/run_inference_single_image.py b/monai/run_inference_single_image.py index 2a02ac8c..e0cc0f90 100644 --- a/monai/run_inference_single_image.py +++ b/monai/run_inference_single_image.py @@ -66,12 +66,13 @@ def get_parser(): parser.add_argument("--path-out", type=str, required=True, help="Path to the output folder where to store the predictions and associated metrics") parser.add_argument('-crop', '--crop-size', type=str, default="64x160x320", - help='Patch size used for center cropping the images during inference. Values correspond to R-L, A-P, I-S axes' - ' *in mm*. All images are resampled to 1mm isotropic before cropping. Inference is run on the cropped images.' 
- ' Use -1 if no cropping is intended. Note, heavy R-L cropping that positions the SC at the center of the image ' - 'is recommmended for best results. Default: 64x160x320') + help='Size of the window used to crop the volume before inference (NOTE: Images are resampled to 1mm' + ' isotropic before cropping). The window is centered in the middle of the volume. Dimensions are in the' + ' order R-L, A-P, I-S. Use -1 for no cropping in a specific axis, example: “64x160x-1”.' + ' NOTE: heavy R-L cropping is recommended for positioning the SC at the center of the image.' + ' Default: 64x160x320') parser.add_argument('--device', default="gpu", type=str, choices=["gpu", "cpu"], - help='Device to run inference on. Default: gpu') + help='Device to run inference on. Default: cpu') return parser From 650bb73dbca038fe41ce14837b2432f33535027c Mon Sep 17 00:00:00 2001 From: Naga Karthik Date: Thu, 28 Sep 2023 12:39:55 -0400 Subject: [PATCH 092/106] remove building_blocks --- monai/building_blocks.py | 54 ---------------------------------------- 1 file changed, 54 deletions(-) delete mode 100644 monai/building_blocks.py diff --git a/monai/building_blocks.py b/monai/building_blocks.py deleted file mode 100644 index b7d8e6a9..00000000 --- a/monai/building_blocks.py +++ /dev/null @@ -1,54 +0,0 @@ -""" -Some useful blocks for building the network architecture. -""" - -import torch.nn as nn - - -def conv_norm_lrelu(feat_in, feat_out): - """Conv3D + InstanceNorm3D + LeakyReLU block""" - return nn.Sequential( - nn.Conv3d(feat_in, feat_out, kernel_size=3, stride=1, padding=1, bias=False), - nn.InstanceNorm3d(feat_out), - nn.LeakyReLU() - ) - - -def norm_lrelu_conv(feat_in, feat_out): - """InstanceNorm3D + LeakyReLU + Conv3D block""" - return nn.Sequential( - nn.InstanceNorm3d(feat_in), - nn.LeakyReLU(), - nn.Conv3d(feat_in, feat_out, kernel_size=3, stride=1, padding=1, bias=False) - ) - - -def lrelu_conv(feat_in, feat_out): - """LeakyReLU + Conv3D block""" - return nn.Sequential( - nn.LeakyReLU(), - nn.Conv3d(feat_in, feat_out, kernel_size=3, stride=1, padding=1, bias=False) - ) - - -def norm_lrelu_upscale_conv_norm_lrelu(feat_in, feat_out): - """InstanceNorm3D + LeakyReLU + 2X Upsample + Conv3D + InstanceNorm3D + LeakyReLU block""" - return nn.Sequential( - nn.InstanceNorm3d(feat_in), - nn.LeakyReLU(), - nn.Upsample(scale_factor=2, mode='nearest'), - nn.Conv3d(feat_in, feat_out, kernel_size=3, stride=1, padding=1, bias=False), - nn.InstanceNorm3d(feat_out), - nn.LeakyReLU() - ) - - -class InitWeights_He(object): - def __init__(self, neg_slope=1e-2): - self.neg_slope = neg_slope - - def __call__(self, module): - if isinstance(module, nn.Conv3d) or isinstance(module, nn.ConvTranspose3d): - module.weight = nn.init.kaiming_normal_(module.weight, a=self.neg_slope) - if module.bias is not None: - module.bias = nn.init.constant_(module.bias, 0) From 68f6bd3a670e3b6de842d8fe5eb2c01c406c4834 Mon Sep 17 00:00:00 2001 From: Naga Karthik Date: Thu, 28 Sep 2023 12:40:32 -0400 Subject: [PATCH 093/106] move weights init class from building_blocks --- monai/models.py | 222 +++--------------------------------------------- 1 file changed, 14 insertions(+), 208 deletions(-) diff --git a/monai/models.py b/monai/models.py index 195a3cf3..90f4f210 100644 --- a/monai/models.py +++ b/monai/models.py @@ -1,7 +1,6 @@ import torch import torch.nn as nn import torch.nn.functional as F -from building_blocks import conv_norm_lrelu, norm_lrelu_conv, lrelu_conv, norm_lrelu_upscale_conv_norm_lrelu, InitWeights_He # ---------------------------- 
Imports for nnUNet's Model ----------------------------- from dynamic_network_architectures.architectures.unet import PlainConvUNet, ResidualEncoderUNet @@ -37,6 +36,20 @@ } +# ====================================================================================================== +# Utils for nnUNet's Model +# ==================================================================================================== +class InitWeights_He(object): + def __init__(self, neg_slope=1e-2): + self.neg_slope = neg_slope + + def __call__(self, module): + if isinstance(module, nn.Conv3d) or isinstance(module, nn.ConvTranspose3d): + module.weight = nn.init.kaiming_normal_(module.weight, a=self.neg_slope) + if module.bias is not None: + module.bias = nn.init.constant_(module.bias, 0) + + # ====================================================================================================== # Define the network based on plans json # ==================================================================================================== @@ -106,213 +119,6 @@ def create_nnunet_from_plans(plans, num_input_channels: int, num_classes: int, d return model -# ---------------------------- ModifiedUNet3D Encoder Implementation ----------------------------- -class ModifiedUNet3DEncoder(nn.Module): - """Encoder for ModifiedUNet3D. Adapted from ivadomed.models""" - def __init__(self, in_channels=1, base_n_filter=8): - super(ModifiedUNet3DEncoder, self).__init__() - - # Initialize common operations - self.lrelu = nn.LeakyReLU() - self.dropout3d = nn.Dropout3d(p=0.5) - self.upsample = nn.Upsample(scale_factor=2, mode='nearest') - - # Level 1 context pathway - self.conv3d_c1_1 = nn.Conv3d(in_channels, base_n_filter, kernel_size=3, stride=1, padding=1, bias=False) - self.conv3d_c1_2 = nn.Conv3d(base_n_filter, base_n_filter, kernel_size=3, stride=1, padding=1, bias=False) - self.lrelu_conv_c1 = lrelu_conv(base_n_filter, base_n_filter) - self.inorm3d_c1 = nn.InstanceNorm3d(base_n_filter) - - # Level 2 context pathway - self.conv3d_c2 = nn.Conv3d(base_n_filter, base_n_filter * 2, kernel_size=3, stride=2, padding=1, bias=False) - self.norm_lrelu_conv_c2 = norm_lrelu_conv(base_n_filter * 2, base_n_filter * 2) - self.inorm3d_c2 = nn.InstanceNorm3d(base_n_filter * 2) - - # Level 3 context pathway - self.conv3d_c3 = nn.Conv3d(base_n_filter * 2, base_n_filter * 4, kernel_size=3, stride=2, padding=1, bias=False) - self.norm_lrelu_conv_c3 = norm_lrelu_conv(base_n_filter * 4, base_n_filter * 4) - self.inorm3d_c3 = nn.InstanceNorm3d(base_n_filter * 4) - - # Level 4 context pathway - self.conv3d_c4 = nn.Conv3d(base_n_filter * 4, base_n_filter * 8, kernel_size=3, stride=2, padding=1, bias=False) - self.norm_lrelu_conv_c4 = norm_lrelu_conv(base_n_filter * 8, base_n_filter * 8) - self.inorm3d_c4 = nn.InstanceNorm3d(base_n_filter * 8) - - # Level 5 context pathway, level 0 localization pathway - self.conv3d_c5 = nn.Conv3d(base_n_filter * 8, base_n_filter * 16, kernel_size=3, stride=2, padding=1, bias=False) - self.norm_lrelu_conv_c5 = norm_lrelu_conv(base_n_filter * 16, base_n_filter * 16) - self.norm_lrelu_upscale_conv_norm_lrelu_l0 = norm_lrelu_upscale_conv_norm_lrelu(base_n_filter * 16, base_n_filter * 8) - - def forward(self, x): - # Level 1 context pathway - out = self.conv3d_c1_1(x) - residual_1 = out - out = self.lrelu(out) - out = self.conv3d_c1_2(out) - out = self.dropout3d(out) - out = self.lrelu_conv_c1(out) - - # Element Wise Summation - out += residual_1 - context_1 = self.lrelu(out) - out = self.inorm3d_c1(out) - out = 
self.lrelu(out) - - # Level 2 context pathway - out = self.conv3d_c2(out) - residual_2 = out - out = self.norm_lrelu_conv_c2(out) - out = self.dropout3d(out) - out = self.norm_lrelu_conv_c2(out) - out += residual_2 - out = self.inorm3d_c2(out) - out = self.lrelu(out) - context_2 = out - - # Level 3 context pathway - out = self.conv3d_c3(out) - residual_3 = out - out = self.norm_lrelu_conv_c3(out) - out = self.dropout3d(out) - out = self.norm_lrelu_conv_c3(out) - out += residual_3 - out = self.inorm3d_c3(out) - out = self.lrelu(out) - context_3 = out - - # Level 4 context pathway - out = self.conv3d_c4(out) - residual_4 = out - out = self.norm_lrelu_conv_c4(out) - out = self.dropout3d(out) - out = self.norm_lrelu_conv_c4(out) - out += residual_4 - out = self.inorm3d_c4(out) - out = self.lrelu(out) - context_4 = out - - # Level 5 - out = self.conv3d_c5(out) - residual_5 = out - out = self.norm_lrelu_conv_c5(out) - out = self.dropout3d(out) - out = self.norm_lrelu_conv_c5(out) - out += residual_5 - - out = self.norm_lrelu_upscale_conv_norm_lrelu_l0(out) - - context_features = [context_1, context_2, context_3, context_4] - - return out, context_features - - -# ---------------------------- ModifiedUNet3D Decoder Implementation ----------------------------- -class ModifiedUNet3DDecoder(nn.Module): - """Decoder for ModifiedUNet3D. Adapted from ivadomed.models""" - def __init__(self, n_classes=1, base_n_filter=8): - super(ModifiedUNet3DDecoder, self).__init__() - - # Initialize common operations - self.lrelu = nn.LeakyReLU() - self.dropout3d = nn.Dropout3d(p=0.5) - self.upsample = nn.Upsample(scale_factor=2, mode='nearest') - - self.conv3d_l0 = nn.Conv3d(base_n_filter * 8, base_n_filter * 8, kernel_size=1, stride=1, padding=0, bias=False) - self.inorm3d_l0 = nn.InstanceNorm3d(base_n_filter * 8) - - # Level 1 localization pathway - self.conv_norm_lrelu_l1 = conv_norm_lrelu(base_n_filter * 16, base_n_filter * 16) - self.conv3d_l1 = nn.Conv3d(base_n_filter * 16, base_n_filter * 8, kernel_size=1, stride=1, padding=0, bias=False) - self.norm_lrelu_upscale_conv_norm_lrelu_l1 = norm_lrelu_upscale_conv_norm_lrelu(base_n_filter * 8, base_n_filter * 4) - - # Level 2 localization pathway - self.conv_norm_lrelu_l2 = conv_norm_lrelu(base_n_filter * 8, base_n_filter * 8) - self.conv3d_l2 = nn.Conv3d(base_n_filter * 8, base_n_filter * 4, kernel_size=1, stride=1, padding=0, bias=False) - self.norm_lrelu_upscale_conv_norm_lrelu_l2 = norm_lrelu_upscale_conv_norm_lrelu(base_n_filter * 4, base_n_filter * 2) - - # Level 3 localization pathway - self.conv_norm_lrelu_l3 = conv_norm_lrelu(base_n_filter * 4, base_n_filter * 4) - self.conv3d_l3 = nn.Conv3d(base_n_filter * 4, base_n_filter * 2, kernel_size=1, stride=1, padding=0, bias=False) - self.norm_lrelu_upscale_conv_norm_lrelu_l3 = norm_lrelu_upscale_conv_norm_lrelu(base_n_filter * 2, base_n_filter) - - # Level 4 localization pathway - self.conv_norm_lrelu_l4 = conv_norm_lrelu(base_n_filter * 2, base_n_filter * 2) - self.conv3d_l4 = nn.Conv3d(base_n_filter * 2, n_classes, kernel_size=1, stride=1, padding=0, bias=False) - - self.ds2_1x1_conv3d = nn.Conv3d(base_n_filter * 8, n_classes, kernel_size=1, stride=1, padding=0, bias=False) - self.ds3_1x1_conv3d = nn.Conv3d(base_n_filter * 4, n_classes, kernel_size=1, stride=1, padding=0, bias=False) - - def forward(self, x, context_features): - # Get context features from the encoder - context_1, context_2, context_3, context_4 = context_features - - out = self.conv3d_l0(x) - out = self.inorm3d_l0(out) - out = self.lrelu(out) 
- - # Level 1 localization pathway - out = torch.cat([out, context_4], dim=1) - out = self.conv_norm_lrelu_l1(out) - out = self.conv3d_l1(out) - out = self.norm_lrelu_upscale_conv_norm_lrelu_l1(out) - - # Level 2 localization pathway - out = torch.cat([out, context_3], dim=1) - out = self.conv_norm_lrelu_l2(out) - ds2 = out - out = self.conv3d_l2(out) - out = self.norm_lrelu_upscale_conv_norm_lrelu_l2(out) - - # Level 3 localization pathway - out = torch.cat([out, context_2], dim=1) - out = self.conv_norm_lrelu_l3(out) - ds3 = out - out = self.conv3d_l3(out) - out = self.norm_lrelu_upscale_conv_norm_lrelu_l3(out) - - # Level 4 localization pathway - out = torch.cat([out, context_1], dim=1) - out = self.conv_norm_lrelu_l4(out) - out_pred = self.conv3d_l4(out) - - ds2_1x1_conv = self.ds2_1x1_conv3d(ds2) - ds1_ds2_sum_upscale = self.upsample(ds2_1x1_conv) - ds3_1x1_conv = self.ds3_1x1_conv3d(ds3) - ds1_ds2_sum_upscale_ds3_sum = ds1_ds2_sum_upscale + ds3_1x1_conv - ds1_ds2_sum_upscale_ds3_sum_upscale = self.upsample(ds1_ds2_sum_upscale_ds3_sum) - - out = out_pred + ds1_ds2_sum_upscale_ds3_sum_upscale - - # # Final Activation Layer - # out = F.relu(out) / F.relu(out).max() if bool(F.relu(out).max()) else F.relu(out) - # out = out.squeeze() - - return out # this is just the logits, not the probablities - - -# ---------------------------- ModifiedUNet3D Implementation ----------------------------- -class ModifiedUNet3D(nn.Module): - """ModifiedUNet3D with Encoder + Decoder. Adapted from ivadomed.models""" - def __init__(self, in_channels=1, out_channels=1, init_filters=8): - super(ModifiedUNet3D, self).__init__() - self.unet_encoder = ModifiedUNet3DEncoder(in_channels=in_channels, base_n_filter=init_filters) - self.unet_decoder = ModifiedUNet3DDecoder(n_classes=out_channels, base_n_filter=init_filters) - - def forward(self, x): - - x, context_features = self.unet_encoder(x) - # x: (B, 8 * F, SV // 8, SV // 8, SV // 8) - # context_features: [4] - # 0 -> (B, F, SV, SV, SV) - # 1 -> (B, 2 * F, SV / 2, SV / 2, SV / 2) - # 2 -> (B, 4 * F, SV / 4, SV / 4, SV / 4) - # 3 -> (B, 8 * F, SV / 8, SV / 8, SV / 8) - - seg_logits = self.unet_decoder(x, context_features) - - return seg_logits - - if __name__ == "__main__": From 2a18878d4d11509a2a08094960e81ea50b4bc510 Mon Sep 17 00:00:00 2001 From: Naga Karthik Date: Thu, 28 Sep 2023 12:47:48 -0400 Subject: [PATCH 094/106] remove usage example in docstring --- monai/run_inference_single_image.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/monai/run_inference_single_image.py b/monai/run_inference_single_image.py index e0cc0f90..1fa54b26 100644 --- a/monai/run_inference_single_image.py +++ b/monai/run_inference_single_image.py @@ -2,10 +2,6 @@ Script to run inference on a MONAI-based model for contrast-agnostic soft segmentation of the spinal cord. Prediction is stored in an independent folder given by subject name. The time taken for inference is stored in a json file. 
-Usage: - python run_inference_single_image.py --path-img /path/to/my-awesome-SC-image.nii.gz --chkp-path /path/to/best/model - --path-out /path/to/output/folder --crop-size <64x160x320> --device - Author: Naga Karthik """ From ee688d319819375ed6ac3fed0489fe8fd8a28dfc Mon Sep 17 00:00:00 2001 From: Naga Karthik Date: Mon, 2 Oct 2023 11:02:32 -0400 Subject: [PATCH 095/106] remove saving in separate folders; clean script --- monai/run_inference_single_image.py | 20 +++++++------------- 1 file changed, 7 insertions(+), 13 deletions(-) diff --git a/monai/run_inference_single_image.py b/monai/run_inference_single_image.py index 1fa54b26..3b43b032 100644 --- a/monai/run_inference_single_image.py +++ b/monai/run_inference_single_image.py @@ -1,6 +1,5 @@ """ Script to run inference on a MONAI-based model for contrast-agnostic soft segmentation of the spinal cord. -Prediction is stored in an independent folder given by subject name. The time taken for inference is stored in a json file. Author: Naga Karthik @@ -73,7 +72,9 @@ def get_parser(): return parser -# define transforms for inference +# =========================================================================== +# Test-time Transforms +# =========================================================================== def inference_transforms_single_image(crop_size): return Compose([ LoadImaged(keys=["image"], image_only=False), @@ -133,11 +134,11 @@ def prepare_data(path_image, path_out, crop_size=(64, 160, 320)): return test_ds, test_post_pred +# =========================================================================== +# Inference method +# =========================================================================== def main(args): - # define start time - start = time() - # define device if args.device == "gpu" and not torch.cuda.is_available(): logger.warning("GPU not available, using CPU instead") @@ -224,9 +225,8 @@ def main(args): # this takes about 0.25s on average on a CPU # image saver class - save_folder = os.path.join(results_path, subject_name.split("_")[0]) pred_saver = SaveImage( - output_dir=save_folder, output_postfix="pred", output_ext=".nii.gz", + output_dir=results_path, output_postfix="pred", output_ext=".nii.gz", separate_folder=False, print_log=False) # save the prediction pred_saver(pred) @@ -264,12 +264,6 @@ def main(args): test_summary.clear() os.remove(os.path.join(results_path, "temp_msd_datalist.json")) - end = time() - - # logger.info("===============================================================") - # logger.info(f"Total time taken for inference: {(end - start) / 60:.2f} minutes") - # logger.info("===============================================================") - if __name__ == "__main__": From d5bea6b1d4a949fc9d26e7a35770aa5383ff5284 Mon Sep 17 00:00:00 2001 From: Naga Karthik Date: Mon, 2 Oct 2023 11:03:14 -0400 Subject: [PATCH 096/106] create standalone script --- monai/run_inference_single_image.py | 99 +++++++++++++++++++++++++++-- 1 file changed, 95 insertions(+), 4 deletions(-) diff --git a/monai/run_inference_single_image.py b/monai/run_inference_single_image.py index 3b43b032..85a2201c 100644 --- a/monai/run_inference_single_image.py +++ b/monai/run_inference_single_image.py @@ -11,17 +11,24 @@ from loguru import logger import torch.nn.functional as F import torch +import torch.nn as nn import json from time import time -from models import create_nnunet_from_plans from monai.inferers import sliding_window_inference from monai.data import (DataLoader, CacheDataset, load_decathlon_datalist, 
 decollate_batch)
 from monai.transforms import (Compose, EnsureTyped, Invertd, SaveImage, Spacingd,
                               LoadImaged, NormalizeIntensityd, EnsureChannelFirstd,
                               DivisiblePadd, Orientationd, ResizeWithPadOrCropd)
+from dynamic_network_architectures.architectures.unet import PlainConvUNet, ResidualEncoderUNet
+from dynamic_network_architectures.building_blocks.helper import get_matching_instancenorm, convert_dim_to_conv_op
+from dynamic_network_architectures.initialization.weight_init import init_last_bn_before_add_to_0
+
+# TODO:
+# 1. Add options for hard/soft labels, post-processing, etc.?
+#    sct_deepseg already has these https://spinalcordtoolbox.com/user_section/command-line.html#sct-deepseg
+
 # NNUNET global params
 INIT_FILTERS=32
 ENABLE_DS = True
@@ -86,9 +93,93 @@ def inference_transforms_single_image(crop_size):
         NormalizeIntensityd(keys=["image"], nonzero=False, channel_wise=False),
     ])

-# --------------------------------
-# DATA
-# --------------------------------
+
+# ===========================================================================
+# Model utils
+# ===========================================================================
+class InitWeights_He(object):
+    def __init__(self, neg_slope=1e-2):
+        self.neg_slope = neg_slope
+
+    def __call__(self, module):
+        if isinstance(module, nn.Conv3d) or isinstance(module, nn.ConvTranspose3d):
+            module.weight = nn.init.kaiming_normal_(module.weight, a=self.neg_slope)
+            if module.bias is not None:
+                module.bias = nn.init.constant_(module.bias, 0)
+
+
+# ============================================================================
+# Define the network based on nnunet_plans dict
+# ============================================================================
+def create_nnunet_from_plans(plans, num_input_channels: int, num_classes: int, deep_supervision: bool = True):
+    """
+    Adapted from nnUNet's source code:
+    https://github.com/MIC-DKFZ/nnUNet/blob/master/nnunetv2/utilities/get_network_from_plans.py#L9
+
+    """
+    num_stages = len(plans["conv_kernel_sizes"])
+
+    dim = len(plans["conv_kernel_sizes"][0])
+    conv_op = convert_dim_to_conv_op(dim)
+
+    segmentation_network_class_name = plans["UNet_class_name"]
+    mapping = {
+        'PlainConvUNet': PlainConvUNet,
+        'ResidualEncoderUNet': ResidualEncoderUNet
+    }
+    kwargs = {
+        'PlainConvUNet': {
+            'conv_bias': True,
+            'norm_op': get_matching_instancenorm(conv_op),
+            'norm_op_kwargs': {'eps': 1e-5, 'affine': True},
+            'dropout_op': None, 'dropout_op_kwargs': None,
+            'nonlin': nn.LeakyReLU, 'nonlin_kwargs': {'inplace': True},
+        },
+        'ResidualEncoderUNet': {
+            'conv_bias': True,
+            'norm_op': get_matching_instancenorm(conv_op),
+            'norm_op_kwargs': {'eps': 1e-5, 'affine': True},
+            'dropout_op': None, 'dropout_op_kwargs': None,
+            'nonlin': nn.LeakyReLU, 'nonlin_kwargs': {'inplace': True},
+        }
+    }
+    assert segmentation_network_class_name in mapping.keys(), 'The network architecture specified by the plans file ' \
+                                                              'is non-standard (maybe your own?). You\'ll have to dive ' \
+                                                              'into either this ' \
+                                                              'function (get_network_from_plans) or ' \
+                                                              'the init of your nnUNetModule to accommodate that.'
+    network_class = mapping[segmentation_network_class_name]
+
+    conv_or_blocks_per_stage = {
+        'n_conv_per_stage'
+        if network_class != ResidualEncoderUNet else 'n_blocks_per_stage': plans["n_conv_per_stage_encoder"],
+        'n_conv_per_stage_decoder': plans["n_conv_per_stage_decoder"]
+    }
+
+    # instantiate the network class resolved from the plans file
+ model = network_class( + input_channels=num_input_channels, + n_stages=num_stages, + features_per_stage=[min(plans["UNet_base_num_features"] * 2 ** i, + plans["unet_max_num_features"]) for i in range(num_stages)], + conv_op=conv_op, + kernel_sizes=plans["conv_kernel_sizes"], + strides=plans["pool_op_kernel_sizes"], + num_classes=num_classes, + deep_supervision=deep_supervision, + **conv_or_blocks_per_stage, + **kwargs[segmentation_network_class_name] + ) + model.apply(InitWeights_He(1e-2)) + if network_class == ResidualEncoderUNet: + model.apply(init_last_bn_before_add_to_0) + + return model + + +# =========================================================================== +# Prepare temporary dataset for inference +# =========================================================================== def prepare_data(path_image, path_out, crop_size=(64, 160, 320)): # create a temporary datalist containing the image From f88d927c6709dcebb79d38ccd3bd45aa2e355474 Mon Sep 17 00:00:00 2001 From: Naga Karthik Date: Mon, 2 Oct 2023 11:22:07 -0400 Subject: [PATCH 097/106] create light-weight, clean requirements txt --- monai/requirements.txt | 110 +++-------------------------------------- 1 file changed, 7 insertions(+), 103 deletions(-) diff --git a/monai/requirements.txt b/monai/requirements.txt index 435823c4..81ed0668 100644 --- a/monai/requirements.txt +++ b/monai/requirements.txt @@ -1,108 +1,12 @@ -appdirs==1.4.4 -asttokens==2.2.1 -backcall==0.2.0 -backports.functools-lru-cache==1.6.5 -Brotli==1.0.9 -certifi==2023.5.7 -cffi==1.15.1 -charset-normalizer==3.1.0 -click==8.1.3 -cmake==3.26.4 -colorama==0.4.6 -comm==0.1.3 -contourpy==1.1.0 -cycler==0.11.0 -debugpy==1.6.7 -decorator==5.1.1 -docker-pycreds==0.4.0 -executing==1.2.0 -filelock==3.12.2 -fonttools==4.40.0 -fsspec==2023.6.0 -gitdb==4.0.10 -GitPython==3.1.31 -gmpy2==2.1.2 -idna==3.4 -importlib-metadata==6.8.0 -importlib-resources==6.0.0 -ipykernel==6.24.0 -ipython==8.14.0 -jedi==0.18.2 -Jinja2==3.1.2 +dynamic_network_architectures==0.2 joblib==1.3.0 -jupyter_client==8.3.0 -jupyter_core==5.3.1 -kiwisolver==1.4.4 -lightning-utilities==0.9.0 -lit==16.0.6 loguru==0.7.0 -MarkupSafe==2.1.3 matplotlib==3.7.2 -matplotlib-inline==0.1.6 -monai==1.2.0 -mpmath==1.3.0 -nest-asyncio==1.5.6 -networkx==3.1 -nibabel==5.1.0 -numpy==1.25.0 -nvidia-cublas-cu11==11.10.3.66 -nvidia-cuda-cupti-cu11==11.7.101 -nvidia-cuda-nvrtc-cu11==11.7.99 -nvidia-cuda-runtime-cu11==11.7.99 -nvidia-cudnn-cu11==8.5.0.96 -nvidia-cufft-cu11==10.9.0.58 -nvidia-curand-cu11==10.2.10.91 -nvidia-cusolver-cu11==11.4.0.1 -nvidia-cusparse-cu11==11.7.4.91 -nvidia-nccl-cu11==2.14.3 -nvidia-nvtx-cu11==11.7.91 -packaging==23.1 -pandas==2.0.3 -parso==0.8.3 -pathtools==0.1.2 -pexpect==4.8.0 -pickleshare==0.7.5 -Pillow==10.0.0 -pip==23.1.2 -platformdirs==3.8.0 -pooch==1.7.0 -prompt-toolkit==3.0.39 -protobuf==3.20.3 -psutil==5.9.5 -ptyprocess==0.7.0 -pure-eval==0.2.2 -pycparser==2.21 -Pygments==2.15.1 -pyparsing==3.0.9 -PySocks==1.7.1 -python-dateutil==2.8.2 -pytorch-lightning==2.0.4 -pytz==2023.3 -PyYAML==6.0 -pyzmq==25.1.0 -requests==2.31.0 -scikit-learn==1.3.0 -scipy==1.11.1 -sentry-sdk==1.21.1 -setproctitle==1.3.2 -setuptools==68.0.0 -six==1.16.0 -smmap==3.0.5 -stack-data==0.6.2 -sympy==1.12 -threadpoolctl==3.1.0 -torch==2.0.0+cu117 -torchaudio==2.0.1+cu117 -torchmetrics==0.11.4 -torchvision==0.15.1+cu117 -tornado==6.3.2 +monai[all]==1.2.0 +numpy==1.24.4 +pytorch_lightning==2.0.4 +scikit_learn==1.3.0 +scipy==1.11.2 +torch==2.0.0 tqdm==4.65.0 -traitlets==5.9.0 -triton==2.0.0 
-typing_extensions==4.7.1
-tzdata==2023.3
-urllib3==2.0.3
 wandb==0.15.5
-wcwidth==0.2.6
-wheel==0.40.0
-zipp==3.15.0

From c4cbd61da6021237128d9fe6d4b41a1c712e90ea Mon Sep 17 00:00:00 2001
From: Naga Karthik
Date: Sun, 8 Oct 2023 15:48:24 -0400
Subject: [PATCH 098/106] change default crop size

---
 monai/run_inference_single_image.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/monai/run_inference_single_image.py b/monai/run_inference_single_image.py
index 85a2201c..9ec7d2e7 100644
--- a/monai/run_inference_single_image.py
+++ b/monai/run_inference_single_image.py
@@ -67,12 +67,12 @@ def get_parser():
     parser.add_argument("--chkp-path", type=str, required=True, help="Path to the checkpoint folder")
     parser.add_argument("--path-out", type=str, required=True,
                         help="Path to the output folder where to store the predictions and associated metrics")
-    parser.add_argument('-crop', '--crop-size', type=str, default="64x160x320",
+    parser.add_argument('-crop', '--crop-size', type=str, default="64x192x-1",
                         help='Size of the window used to crop the volume before inference (NOTE: Images are resampled to 1mm'
                         ' isotropic before cropping). The window is centered in the middle of the volume. Dimensions are in the'
                         ' order R-L, A-P, I-S. Use -1 for no cropping in a specific axis, example: “64x160x-1”.'
                         ' NOTE: heavy R-L cropping is recommended for positioning the SC at the center of the image.'
-                        ' Default: 64x160x320')
+                        ' Default: 64x192x-1')
     parser.add_argument('--device', default="gpu", type=str, choices=["gpu", "cpu"],
                         help='Device to run inference on. Default: gpu')
@@ -251,7 +251,7 @@ def main(args):

     # define inference patch size and center crop size
     crop_size = tuple([int(i) for i in args.crop_size.split("x")])
-    inference_roi_size = (64, 160, 320)
+    inference_roi_size = (64, 192, 320)

     # define the dataset and dataloader
     test_ds, test_post_pred = prepare_data(path_image, results_path, crop_size=crop_size)

From 4694169352fb962f607e7684fe778677adc123dc Mon Sep 17 00:00:00 2001
From: Naga Karthik
Date: Sun, 8 Oct 2023 15:49:00 -0400
Subject: [PATCH 099/106] fix bug with torch.clamp setting bg values to 0.5

---
 monai/run_inference_single_image.py | 10 ++++------
 1 file changed, 4 insertions(+), 6 deletions(-)

diff --git a/monai/run_inference_single_image.py b/monai/run_inference_single_image.py
index 9ec7d2e7..7d08f377 100644
--- a/monai/run_inference_single_image.py
+++ b/monai/run_inference_single_image.py
@@ -25,10 +25,6 @@
 from dynamic_network_architectures.initialization.weight_init import init_last_bn_before_add_to_0

-# TODO:
-# 1. Add options for hard/soft labels, post-processing, etc.?
-# sct_deepseg already has these https://spinalcordtoolbox.com/user_section/command-line.html#sct-deepseg - # NNUNET global params INIT_FILTERS=32 ENABLE_DS = True @@ -306,9 +302,11 @@ def main(args): pred = post_test_out[0]['pred'].cpu() # clip the prediction between 0.5 and 1 + # turns out this sets the background to 0.5 and the SC to 1 (which is not correct) + # details: https://github.com/sct-pipeline/contrast-agnostic-softseg-spinalcord/issues/71 pred = torch.clamp(pred, 0.5, 1) - # # threshold the prediction - # pred = (pred > 0.1).float() + # set background values to 0 + pred[pred <= 0.5] = 0 # get subject name subject_name = (batch["image_meta_dict"]["filename_or_obj"][0]).split("/")[-1].replace(".nii.gz", "") From 2e3528269754b30986e636b45e57bfe51df1b1fb Mon Sep 17 00:00:00 2001 From: Naga Karthik Date: Mon, 9 Oct 2023 10:15:40 -0400 Subject: [PATCH 100/106] remove +cpu from torch version --- monai/requirements_inference.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/requirements_inference.txt b/monai/requirements_inference.txt index 2cb4845f..d99de5ae 100644 --- a/monai/requirements_inference.txt +++ b/monai/requirements_inference.txt @@ -5,4 +5,4 @@ monai[nibabel]==1.2.0 scipy==1.11.2 numpy==1.24.4 --extra-index-url https://download.pytorch.org/whl/cpu -torch==2.0.0+cpu +torch==2.0.0 From 5e3f97a31aa64b4883be7a3217090bc46aa74af2 Mon Sep 17 00:00:00 2001 From: Naga Karthik Date: Mon, 9 Oct 2023 10:51:33 -0400 Subject: [PATCH 101/106] renamed as README, removed usage instructions --- monai/{inference_instructions.md => README.md} | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) rename monai/{inference_instructions.md => README.md} (79%) diff --git a/monai/inference_instructions.md b/monai/README.md similarity index 79% rename from monai/inference_instructions.md rename to monai/README.md index 3f159710..c73c3fc7 100644 --- a/monai/inference_instructions.md +++ b/monai/README.md @@ -26,18 +26,16 @@ pip install -r requirements_inference.txt ### Method 1: Running inference on a single image -```bash -python run_inference_single_image.py --path-img /path/to/my-awesome-SC-image.nii.gz --chkp-path /path/to/best/model --path-out /path/to/output/folder --crop-size <64x160x320> --device +The script for running inference is `run_inference_single_image.py`. Please run ``` - -`--path-img` - Path to the image to be segmented -`--chkp-path` - Path to the model checkpoint. This folder should contain the `best_model_loss.ckpt` -`--path-out` - Path to the output folder where the predictions will be saved -`--crop_size` - Crop size used for center cropping the image before running inference. Recommended to be set to a multiple of 32 -`--device` - Device to be used for inference. Currently, only `gpu` and `cpu` are supported +python run_inference_single_image.py -h +``` +to get the list of arguments and their descriptions. + +### Method 2: Running inference on a dataset (Advanced) -### Method 2: Running inference on a dataset +NOTE: This section is experimental and for advanced users only. Please use Method 1 for running inference. 
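Before moving on to the datalist instructions below, a note on the `torch.clamp` fix in [PATCH 099/106] above: the old post-processing clamped soft predictions to [0.5, 1], which silently lifted every background voxel to 0.5. The following minimal sketch (illustrative values only, not code from the repository) reproduces the bug and the fix:

```python
import torch

# Soft predictions in [0, 1]; the first two voxels are background.
pred = torch.tensor([0.0, 0.2, 0.6, 0.9])

# Old behaviour: clamping to [0.5, 1] lifts background voxels to 0.5.
clamped = torch.clamp(pred, 0.5, 1)   # tensor([0.5000, 0.5000, 0.6000, 0.9000])

# The fix: zero out everything at or below the 0.5 cutoff after clamping.
clamped[clamped <= 0.5] = 0           # tensor([0.0000, 0.0000, 0.6000, 0.9000])
```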
#### Creating a datalist

From bdb3e987b51bc362ff2fe5203897111fc01a0580 Mon Sep 17 00:00:00 2001
From: Naga Karthik
Date: Mon, 9 Oct 2023 11:27:57 -0400
Subject: [PATCH 102/106] cleaned by removing unused code

---
 monai/transforms.py | 69 +++++----------------------------------------
 1 file changed, 7 insertions(+), 62 deletions(-)

diff --git a/monai/transforms.py b/monai/transforms.py
index 7005809a..76ef9ad1 100644
--- a/monai/transforms.py
+++ b/monai/transforms.py
@@ -1,49 +1,24 @@
 import numpy as np

-from monai.transforms import (SpatialPadd, Compose, CropForegroundd, LoadImaged, RandFlipd,
-                              RandCropByPosNegLabeld, Spacingd, RandScaleIntensityd, NormalizeIntensityd, RandAffined,
+from monai.transforms import (Compose, CropForegroundd, LoadImaged, RandFlipd,
+                              Spacingd, RandScaleIntensityd, NormalizeIntensityd, RandAffined,
                               DivisiblePadd, RandAdjustContrastd, EnsureChannelFirstd, RandGaussianNoised,
-                              RandGaussianSmoothd, Orientationd, Rand3DElasticd, RandBiasFieldd, RandSimulateLowResolutiond,
+                              RandGaussianSmoothd, Orientationd, Rand3DElasticd, RandBiasFieldd,
                               ResizeWithPadOrCropd)

-# median image size in voxels - taken from nnUNet
-# median_size = (123, 255, 214) as per 0.9 iso resampling and patch_size = (80, 192, 160)
-# note that the order of the axes is different in nnunet and monai (dims 0 and 2 are swapped)
-# median_size after 1mm isotropic resampling
-# median_size = [ 192. 228. 106.]
+# TODO: Add RandSimulateLowResolutiond transform when monai 1.3.0 is released.
+# Right now, in v1.2.0, it is not implemented yet (I had to manually add it to the source code)

-# Order in which nnunet does preprocessing:
-# 1. Crop to non-zero
-# 2. Normalization
-# 3. Resample to target spacing
-
-# Order in which ivadomed does preprocessing:
-# 1. Resample to 1mm iso
-# 2. CenterCrop using 46x176x288
-# 3. RandomAffine --> RandomElastic --> RandomGamma --> RandomBiasField --> RandomBlur --> NormalizeInstance
-
-# TODO: Use cropping on R-L (56) and A-P (176) but not on S-I to avoid cropping the last few slices
-# At test time, check performance on cropped and uncropped full images
-
-def train_transforms(crop_size, num_samples_pv, lbl_key="label"):
+def train_transforms(crop_size, lbl_key="label"):

     monai_transforms = [
         # pre-processing
         LoadImaged(keys=["image", lbl_key]),
         EnsureChannelFirstd(keys=["image", lbl_key]),
-        # CropForegroundd(keys=["image", lbl_key], source_key="image"),     # crops >0 values with a bounding box
         # NOTE: for interpolation, order=2 is spline, order=1 is linear
         Spacingd(keys=["image", "label"], pixdim=(1.0, 1.0, 1.0), mode=(2, 1)),
         ResizeWithPadOrCropd(keys=["image", lbl_key], spatial_size=crop_size,),
         # data-augmentation
-        # SpatialPadd(keys=["image", lbl_key], spatial_size=crop_size, method="symmetric"),   # pad with the same size as crop_size
-        # # NOTE: used with neg together to calculate the ratio pos / (pos + neg) for the probability to pick a
-        # # foreground voxel as a center rather than a background voxel.
-        # RandCropByPosNegLabeld(keys=["image", "label"], label_key="label",
-        #                        spatial_size=crop_size, pos=3, neg=1, num_samples=num_samples_pv,
-        #                        # if num_samples=4, then 4 samples/image are randomly generated
-        #                        image_key="image", image_threshold=0.),
-        # re-ordering transforms as used by nnunet
         RandAffined(keys=["image", lbl_key], mode=(2, 1), prob=0.9,
                     rotate_range=(-20. / 360 * 2.
* np.pi), # monai expects in radians scale_range=(-0.2, 0.2), @@ -51,7 +26,7 @@ def train_transforms(crop_size, num_samples_pv, lbl_key="label"): Rand3DElasticd(keys=["image", lbl_key], prob=0.5, sigma_range=(3.5, 5.5), magnitude_range=(25., 35.)), - RandSimulateLowResolutiond(keys=["image"], zoom_range=(0.5, 1.0), prob=0.25), + # RandSimulateLowResolutiond(keys=["image"], zoom_range=(0.5, 1.0), prob=0.25), RandAdjustContrastd(keys=["image"], gamma=(0.5, 3.), prob=0.5), # this is monai's RandomGamma RandBiasFieldd(keys=["image"], coeff_range=(0.0, 0.5), degree=3, prob=0.3), RandGaussianNoised(keys=["image"], mean=0.0, std=0.1, prob=0.1), @@ -59,39 +34,10 @@ def train_transforms(crop_size, num_samples_pv, lbl_key="label"): RandScaleIntensityd(keys=["image"], factors=(-0.25, 1), prob=0.15), # this is nnUNet's BrightnessMultiplicativeTransform RandFlipd(keys=["image", lbl_key], prob=0.3,), NormalizeIntensityd(keys=["image"], nonzero=False, channel_wise=False), - - # # defining transforms as used by ivadomed (with the same probabilities) - # LoadImaged(keys=["image", lbl_key], image_only=False), # image_only=True to avoid loading the label - # EnsureChannelFirstd(keys=["image", lbl_key]), - # Spacingd(keys=["image", "label"], pixdim=(1.0, 1.0, 1.0), mode=(2, 1)), - # ResizeWithPadOrCropd(keys=["image", lbl_key], spatial_size=crop_size,), - # RandAffined(keys=["image", lbl_key], mode=(2, 1), prob=1.0, - # rotate_range=(-20. / 360 * 2. * np.pi, 20. / 360 * 2. * np.pi), # monai expects in radians - # scale_range=(-0.2, 0.2), # ivadomed uses sth like scale_x = random.uniform(1 - self.scale[0], 1 + self.scale[0]), but monai adds 1.0 to the scale - # translate_range=(-0.1, 0.1)), - # Rand3DElasticd(keys=["image", lbl_key], prob=0.5, - # sigma_range=(3.5, 5.5), - # magnitude_range=(25., 35.)), - # # RandSimulateLowResolutiond(keys=["image"], zoom_range=(0.5, 1.0), prob=0.25), - # RandAdjustContrastd(keys=["image"], gamma=(0.5, 3.), prob=0.5), # this is monai's RandomGamma - # RandBiasFieldd(keys=["image"], coeff_range=(0.0, 0.5), degree=3, prob=0.3), - # RandGaussianSmoothd(keys=["image"], sigma_x=(0., 2.), sigma_y=(0., 2.), sigma_z=(0., 2.0), prob=0.3), - # # RandFlipd(keys=["image", lbl_key], prob=0.5,), - # NormalizeIntensityd(keys=["image"], nonzero=False, channel_wise=False), ] return Compose(monai_transforms) -def val_transforms_without_center_crop(lbl_key="label"): - return Compose([ - LoadImaged(keys=["image", lbl_key], image_only=False), - EnsureChannelFirstd(keys=["image", lbl_key]), - # Orientationd(keys=["image", lbl_key], axcodes="RPI"), - CropForegroundd(keys=["image", lbl_key], source_key="image"), - Spacingd(keys=["image", lbl_key], pixdim=(1.0, 1.0, 1.0), mode=(2, 1)), # mode=("bilinear", "bilinear"),), - NormalizeIntensityd(keys=["image"], nonzero=False, channel_wise=False), - ]) - def inference_transforms(crop_size, lbl_key="label"): return Compose([ LoadImaged(keys=["image", lbl_key], image_only=False), @@ -111,6 +57,5 @@ def val_transforms(crop_size, lbl_key="label"): # CropForegroundd(keys=["image", lbl_key], source_key="image"), Spacingd(keys=["image", lbl_key], pixdim=(1.0, 1.0, 1.0), mode=(2, 1)), # mode=("bilinear", "bilinear"),), ResizeWithPadOrCropd(keys=["image", lbl_key], spatial_size=crop_size,), - # TODO: do cropping only in R-L so sth like (48, -1, -1) NormalizeIntensityd(keys=["image"], nonzero=False, channel_wise=False), ]) From 5cce1e86465bed5e5281bf8764088f07930b75ab Mon Sep 17 00:00:00 2001 From: Naga Karthik Date: Mon, 9 Oct 2023 11:59:25 -0400 Subject: [PATCH 
103/106] add cupy for inference on gpu --- monai/requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/monai/requirements.txt b/monai/requirements.txt index 81ed0668..4f69d6a1 100644 --- a/monai/requirements.txt +++ b/monai/requirements.txt @@ -1,3 +1,4 @@ +cupy-cuda117==10.6.0 dynamic_network_architectures==0.2 joblib==1.3.0 loguru==0.7.0 From 2151b23d1a0eb50d3bd5489a08e59af3f52cdcb3 Mon Sep 17 00:00:00 2001 From: Naga Karthik Date: Mon, 9 Oct 2023 12:04:30 -0400 Subject: [PATCH 104/106] add args for contrast,label-type for per-contrast,hard/soft training --- monai/main.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/monai/main.py b/monai/main.py index 0e81ca5b..eccb55ed 100644 --- a/monai/main.py +++ b/monai/main.py @@ -106,7 +106,8 @@ def prepare_data(self): # transforms_val = val_transforms_with_center_crop(crop_size=self.voxel_cropping_size, lbl_key='label') # load the dataset - dataset = os.path.join(self.root, f"spine-generic-ivado-comparison_dataset.json") + dataset = os.path.join(self.root, f"dataset_{self.args.contrast}_{self.args.label_type}_seed15.json") + logger.info(f"Loading dataset: {dataset}") train_files = load_decathlon_datalist(dataset, True, "train") val_files = load_decathlon_datalist(dataset, True, "validation") test_files = load_decathlon_datalist(dataset, True, "test") @@ -565,7 +566,7 @@ def main(args): } # define root path for finding datalists - dataset_root = "/home/GRAMES.POLYMTL.CA/u114716/contrast-agnostic/contrast-agnostic-softseg-spinalcord/monai" + dataset_root = "/home/GRAMES.POLYMTL.CA/u114716/contrast-agnostic/datalists/spine-generic/seed15" # define optimizer if args.optimizer in ["adam", "Adam"]: @@ -793,12 +794,11 @@ def main(args): default='unet', type=str, help='Model type to be used') parser.add_argument('--enable_DS', default=False, action='store_true', help='Enable Deep Supervision') # dataset - parser.add_argument('-nspv', '--num_samples_per_volume', default=4, type=int, help="Number of samples to crop per volume") - # define args for cropping size. inputs should be in the format of "48x192x256" - parser.add_argument('-val-crop', '--val_crop_size', type=str, default="48x192x256", - help='Center crop size for validation and testing. Values correspond to R-L, A-P, I-S axes' - 'of the image. Use -1 if no cropping is intended. Default: 48x160x320') - + parser.add_argument("--contrast", default="t2w", type=str, help="Contrast to use for training", + choices=["t1w", "t2w", "t2star", "mton", "mtoff", "dwi", "all"]) + parser.add_argument('--label-type', default='soft', type=str, help="Type of labels to use for training", + choices=['hard', 'soft']) + # unet model parser.add_argument('-initf', '--init_filters', default=16, type=int, help="Number of Filters in Init Layer") # parser.add_argument('-ps', '--patch_size', type=int, default=128, help='List containing subvolume size') From 227ba30f43d7b72380cdd3a1115c5336b205c386 Mon Sep 17 00:00:00 2001 From: Naga Karthik Date: Mon, 9 Oct 2023 12:05:16 -0400 Subject: [PATCH 105/106] fix same crop size of train/val --- monai/main.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/monai/main.py b/monai/main.py index eccb55ed..b20aeee4 100644 --- a/monai/main.py +++ b/monai/main.py @@ -50,13 +50,12 @@ def __init__(self, args, data_root, net, loss_function, optimizer_class, exp_id= # which could be sub-optimal. 
# On the other hand, ivadomed used a patch-size that's heavily padded along the R-L direction so that
         # only the SC is in context.
-        # self.voxel_cropping_size = (48, 176, 288)
         self.spacing = (1.0, 1.0, 1.0)
-        self.voxel_cropping_size = (48, 160, 320)
-        self.inference_roi_size = tuple([int(i) for i in args.val_crop_size.split("x")])
-        if self.inference_roi_size == (-1,):   # means no cropping is required
-            logger.info(f"Using full image for validation ...")
-            self.inference_roi_size = (-1, -1, -1)
+        self.voxel_cropping_size = self.inference_roi_size = tuple([int(i) for i in args.crop_size.split("x")])
+        # self.inference_roi_size = tuple([int(i) for i in args.val_crop_size.split("x")])
+        # if self.inference_roi_size == (-1,):   # means no cropping is required
+        #     logger.info(f"Using full image for validation ...")
+        #     self.inference_roi_size = (-1, -1, -1)

         # define post-processing transforms for validation, nothing fancy just making sure that it's a tensor (default)
         self.val_post_pred = Compose([EnsureType()])
@@ -794,6 +793,10 @@ def main(args):
                         default='unet', type=str, help='Model type to be used')
     parser.add_argument('--enable_DS', default=False, action='store_true', help='Enable Deep Supervision')
     # dataset
+    # define args for cropping size
+    parser.add_argument('-crop', '--crop_size', type=str, default="64x192x320",
+                        help='Center crop size for training/validation. Values correspond to R-L, A-P, I-S axes '
+                        'of the image after 1mm isotropic resampling. Default: 64x192x320')
     parser.add_argument("--contrast", default="t2w", type=str, help="Contrast to use for training",
                         choices=["t1w", "t2w", "t2star", "mton", "mtoff", "dwi", "all"])
     parser.add_argument('--label-type', default='soft', type=str, help="Type of labels to use for training",
                         choices=['hard', 'soft'])

From 990c3570914e0dcc9c4fc820625f91d0610861de Mon Sep 17 00:00:00 2001
From: Naga Karthik
Date: Mon, 9 Oct 2023 12:06:43 -0400
Subject: [PATCH 106/106] clean code

---
 monai/main.py | 141 +++++++++++---------------------------------------
 1 file changed, 30 insertions(+), 111 deletions(-)

diff --git a/monai/main.py b/monai/main.py
index b20aeee4..dd87afb0 100644
--- a/monai/main.py
+++ b/monai/main.py
@@ -10,20 +10,18 @@
 import torch.nn.functional as F
 import matplotlib.pyplot as plt

-from utils import precision_score, recall_score, dice_score, compute_average_csa, \
+from utils import precision_score, recall_score, dice_score, \
     PolyLRScheduler, plot_slices, check_empty_patch
 from losses import SoftDiceLoss, AdapWingLoss
 from transforms import train_transforms, val_transforms
-from models import ModifiedUNet3D, create_nnunet_from_plans
+from models import create_nnunet_from_plans

 from monai.utils import set_determinism
 from monai.inferers import sliding_window_inference
-from monai.networks.nets import UNet, UNETR, DynUNet
+from monai.networks.nets import UNETR
 from monai.data import (DataLoader, Dataset, CacheDataset, load_decathlon_datalist, decollate_batch)
-from monai.transforms import (Compose, EnsureType, EnsureTyped, Invertd, SaveImaged, SaveImage)
+from monai.transforms import (Compose, EnsureType, EnsureTyped, Invertd, SaveImage)

-# TODO:
-# 1.
increase omega in adapwingloss # create a "model"-agnostic class with PL to use different models class Model(pl.LightningModule): @@ -41,7 +39,6 @@ def __init__(self, args, data_root, net, loss_function, optimizer_class, exp_id= self.results_path = results_path self.best_val_dice, self.best_val_epoch = 0, 0 - # self.best_val_csa = float("inf") self.best_val_loss = float("inf") # define cropping and padding dimensions @@ -53,9 +50,6 @@ def __init__(self, args, data_root, net, loss_function, optimizer_class, exp_id= self.spacing = (1.0, 1.0, 1.0) self.voxel_cropping_size = self.inference_roi_size = tuple([int(i) for i in args.crop_size.split("x")]) # self.inference_roi_size = tuple([int(i) for i in args.val_crop_size.split("x")]) - # if self.inference_roi_size == (-1,): # means no cropping is required - # logger.info(f"Using full image for validation ...") - # self.inference_roi_size = (-1, -1, -1) # define post-processing transforms for validation, nothing fancy just making sure that it's a tensor (default) self.val_post_pred = Compose([EnsureType()]) @@ -74,8 +68,6 @@ def __init__(self, args, data_root, net, loss_function, optimizer_class, exp_id= # FORWARD PASS # -------------------------------- def forward(self, x): - # x, context_features = self.encoder(x) - # preds = self.decoder(x, context_features) out = self.net(x) # # NOTE: MONAI's models only output the logits, not the output after the final activation function @@ -98,11 +90,9 @@ def prepare_data(self): # define training and validation transforms transforms_train = train_transforms( crop_size=self.voxel_cropping_size, - num_samples_pv=self.args.num_samples_per_volume, lbl_key='label' ) transforms_val = val_transforms(crop_size=self.inference_roi_size, lbl_key='label') - # transforms_val = val_transforms_with_center_crop(crop_size=self.voxel_cropping_size, lbl_key='label') # load the dataset dataset = os.path.join(self.root, f"dataset_{self.args.contrast}_{self.args.label_type}_seed15.json") @@ -122,7 +112,6 @@ def prepare_data(self): # define test transforms transforms_test = val_transforms(crop_size=self.inference_roi_size, lbl_key='label') - # transforms_test = val_transforms_with_center_crop(crop_size=self.voxel_cropping_size, lbl_key='label') # define post-processing transforms for testing; taken (with explanations) from # https://github.com/Project-MONAI/tutorials/blob/main/3d_segmentation/torch/unet_inference_dict.py#L66 @@ -140,7 +129,6 @@ def prepare_data(self): # DATA LOADERS # -------------------------------- def train_dataloader(self): - # NOTE: if num_samples=4 in RandCropByPosNegLabeld and batch_size=2, then 2 x 4 images are generated for network training return DataLoader(self.train_ds, batch_size=self.args.batch_size, shuffle=True, num_workers=16, pin_memory=True, persistent_workers=True) @@ -173,14 +161,12 @@ def training_step(self, batch, batch_idx): inputs, labels = batch["image"], batch["label"] - # NOTE: surprisingly, filtering out empty patches is adding more CSA bias; TODO: verify with new patch size # check if any label image patch is empty in the batch if check_empty_patch(labels) is None: # print(f"Empty label patch found. 
Skipping training step ...") return None output = self.forward(inputs) # logits - # if using dynunet, output.shape = (B, num_upsample_layers+1, C, H, W, D) # print(f"labels.shape: {labels.shape} \t output.shape: {output.shape}") if self.args.model == "nnunet" and self.args.enable_DS: @@ -209,24 +195,8 @@ def training_step(self, batch, batch_idx): loss /= len(output) train_soft_dice /= len(output) - # # binarize the predictions and the labels (take only the final feature map i.e. the final prediction) - # output = (output[0].detach() > 0.5).float() - # labels = (labels.detach() > 0.5).float() - - # # compute CSA for each element of the batch - # # NOTE: the CSA is computed only for the final feature map (i.e. the prediction, not the intermediate deepsupervision feature maps) - # csa_loss = 0.0 - # for batch_idx in range(output.shape[0]): - # pred_patch_csa = compute_average_csa(output[batch_idx].squeeze(), self.spacing) - # gt_patch_csa = compute_average_csa(labels[batch_idx].squeeze(), self.spacing) - # csa_loss += (pred_patch_csa - gt_patch_csa) ** 2 - # # average CSA loss across the batch - # csa_loss = csa_loss / output.shape[0] - else: # calculate training loss - # NOTE: the diceLoss expects the input to be logits (which it then normalizes inside) - # dice_loss = self.loss_function(output, labels) loss = self.loss_function(output, labels) # get probabilities from logits @@ -236,28 +206,9 @@ def training_step(self, batch, batch_idx): # NOTE: this is done on patches (and not entire 3D volume) because SlidingWindowInference is not used here # So, take this dice score with a lot of salt train_soft_dice = self.soft_dice_metric(output, labels) - # train_hard_dice = self.soft_dice_metric((output.detach() > 0.5).float(), (labels.detach() > 0.5).float()) - - # # binarize the predictions and the labels - # output = (output.detach() > 0.5).float() - # labels = (labels.detach() > 0.5).float() - - # # compute CSA for each element of the batch - # csa_loss = 0.0 - # for batch_idx in range(output.shape[0]): - # pred_patch_csa = compute_average_csa(output[batch_idx].squeeze(), self.spacing) - # gt_patch_csa = compute_average_csa(labels[batch_idx].squeeze(), self.spacing) - # csa_loss += (pred_patch_csa - gt_patch_csa) ** 2 - # # average CSA loss across the batch - # csa_loss = csa_loss / output.shape[0] - - # # total loss - # loss = dice_loss + csa_loss metrics_dict = { "loss": loss.cpu(), - # "dice_loss": dice_loss.cpu(), - # "csa_loss": csa_loss.cpu(), "train_soft_dice": train_soft_dice.detach().cpu(), "train_number": len(inputs), # "train_image": inputs[0].detach().cpu().squeeze(), @@ -274,25 +225,18 @@ def on_train_epoch_end(self): # means the training step was skipped because of empty input patch return None else: - train_loss, train_dice_loss, train_csa_loss, train_soft_dice = 0, 0, 0, 0 + train_loss, train_soft_dice = 0, 0 num_items = len(self.train_step_outputs) for output in self.train_step_outputs: train_loss += output["loss"].item() - # train_dice_loss += output["dice_loss"].item() - # train_csa_loss += output["csa_loss"].item() train_soft_dice += output["train_soft_dice"].item() - # num_items += output["train_number"] mean_train_loss = (train_loss / num_items) - # mean_train_dice_loss = (train_dice_loss / num_items) - # mean_train_csa_loss = (train_csa_loss / num_items) mean_train_soft_dice = (train_soft_dice / num_items) wandb_logs = { "train_soft_dice": mean_train_soft_dice, "train_loss": mean_train_loss, - # "train_dice_loss": mean_train_dice_loss, - # "train_csa_loss": 
mean_train_csa_loss, } self.log_dict(wandb_logs) @@ -326,7 +270,6 @@ def validation_step(self, batch, batch_idx): outputs = outputs[0] # calculate validation loss - # dice_loss = self.loss_function(outputs, labels) loss = self.loss_function(outputs, labels) # get probabilities from logits @@ -340,26 +283,11 @@ def validation_step(self, batch, batch_idx): hard_preds, hard_labels = (post_outputs[0].detach() > 0.5).float(), (post_labels[0].detach() > 0.5).float() val_hard_dice = self.soft_dice_metric(hard_preds, hard_labels) - # # compute val CSA loss - # val_csa_loss = 0.0 - # for batch_idx in range(hard_preds.shape[0]): - # pred_patch_csa = compute_average_csa(hard_preds[batch_idx].squeeze(), self.spacing) - # gt_patch_csa = compute_average_csa(hard_labels[batch_idx].squeeze(), self.spacing) - # val_csa_loss += (pred_patch_csa - gt_patch_csa) ** 2 - - # # average CSA loss across the batch - # val_csa_loss = val_csa_loss / hard_preds.shape[0] - - # # total loss - # loss = dice_loss + val_csa_loss - # NOTE: there was a massive memory leak when storing cuda tensors in this dict. Hence, # using .detach() to avoid storing the whole computation graph # Ref: https://discuss.pytorch.org/t/cuda-memory-leak-while-training/82855/2 metrics_dict = { "val_loss": loss.detach().cpu(), - # "val_dice_loss": dice_loss.detach().cpu(), - # "val_csa_loss": val_csa_loss.detach().cpu(), "val_soft_dice": val_soft_dice.detach().cpu(), "val_hard_dice": val_hard_dice.detach().cpu(), "val_number": len(post_outputs), @@ -374,27 +302,20 @@ def validation_step(self, batch, batch_idx): def on_validation_epoch_end(self): val_loss, num_items, val_soft_dice, val_hard_dice = 0, 0, 0, 0 - val_dice_loss, val_csa_loss = 0, 0 for output in self.val_step_outputs: val_loss += output["val_loss"].sum().item() val_soft_dice += output["val_soft_dice"].sum().item() val_hard_dice += output["val_hard_dice"].sum().item() - # val_dice_loss += output["val_dice_loss"].sum().item() - # val_csa_loss += output["val_csa_loss"].sum().item() num_items += output["val_number"] mean_val_loss = (val_loss / num_items) mean_val_soft_dice = (val_soft_dice / num_items) mean_val_hard_dice = (val_hard_dice / num_items) - # mean_val_dice_loss = (val_dice_loss / num_items) - # mean_val_csa_loss = (val_csa_loss / num_items) wandb_logs = { "val_soft_dice": mean_val_soft_dice, "val_hard_dice": mean_val_hard_dice, "val_loss": mean_val_loss, - # "val_dice_loss": mean_val_dice_loss, - # "val_csa_loss": mean_val_csa_loss, } # save the best model based on validation dice score if mean_val_soft_dice > self.best_val_dice: @@ -402,7 +323,6 @@ def on_validation_epoch_end(self): self.best_val_epoch = self.current_epoch # save the best model based on validation CSA loss - # if mean_val_loss < self.best_val_csa: if mean_val_loss < self.best_val_loss: self.best_val_loss = mean_val_loss self.best_val_epoch = self.current_epoch @@ -412,8 +332,8 @@ def on_validation_epoch_end(self): f"\nAverage Soft Dice (VAL): {mean_val_soft_dice:.4f}" f"\nAverage Hard Dice (VAL): {mean_val_hard_dice:.4f}" f"\nAverage AdapWing Loss (VAL): {mean_val_loss:.4f}" - f"\nBest Average Soft Dice: {self.best_val_dice:.4f} at Epoch: {self.best_val_epoch}" - # f"\nBest Average AdapWing Loss: {self.best_val_loss:.4f} at Epoch: {self.best_val_epoch}" + # f"\nBest Average Soft Dice: {self.best_val_dice:.4f} at Epoch: {self.best_val_epoch}" + f"\nBest Average AdapWing Loss: {self.best_val_loss:.4f} at Epoch: {self.best_val_epoch}" f"\n----------------------------------------------------") @@ -438,7 +358,7 @@ 
def on_validation_epoch_end(self): # -------------------------------- def test_step(self, batch, batch_idx): - test_input, test_label = batch["image"], batch["label"] + test_input = batch["image"] # print(batch["label_meta_dict"]["filename_or_obj"][0]) # print(f"test_input.shape: {test_input.shape} \t test_label.shape: {test_label.shape}") batch["pred"] = sliding_window_inference(test_input, self.inference_roi_size, @@ -474,11 +394,11 @@ def test_step(self, batch, batch_idx): # save the prediction pred_saver(pred) - label_saver = SaveImage( - output_dir=save_folder, output_postfix="gt", output_ext=".nii.gz", - separate_folder=False, print_log=False, resample=True) - # save the label - label_saver(label) + # label_saver = SaveImage( + # output_dir=save_folder, output_postfix="gt", output_ext=".nii.gz", + # separate_folder=False, print_log=False, resample=True) + # # save the label + # label_saver(label) # NOTE: Important point from the SoftSeg paper - binarize predictions before computing metrics @@ -574,18 +494,7 @@ def main(args): optimizer_class = torch.optim.SGD # define models - if args.model in ["unet"]: - logger.info(f" Using ivadomed's UNet model! ") - # this is the ivadomed unet model - net = ModifiedUNet3D(in_channels=1, out_channels=1, init_filters=args.init_filters) - patch_size = "48x160x320" # "160x224x96" - save_exp_id =f"ivado_{args.model}_nf={args.init_filters}_opt={args.optimizer}_lr={args.learning_rate}" \ - f"_AdapW_valCCrop_bs={args.batch_size}_{patch_size}" - if args.debug: - save_exp_id = f"DEBUG_{save_exp_id}" - - - elif args.model in ["unetr"]: + if args.model in ["unetr"]: # define image size to be fed to the model img_size = (160, 224, 96) @@ -615,11 +524,19 @@ def main(args): # define model net = create_nnunet_from_plans(plans=nnunet_plans, num_input_channels=1, num_classes=1, deep_supervision=args.enable_DS) - patch_size = "48x160x320" - save_exp_id =f"{args.model}_nf={args.init_filters}_DS={int(args.enable_DS)}" \ + patch_size = "64x192x320" + save_exp_id =f"{args.model}_{args.contrast}_{args.label_type}_nf={args.init_filters}" \ f"_opt={args.optimizer}_lr={args.learning_rate}" \ - f"_AdapW_CCrop" \ + f"_AdapW" \ f"_bs={args.batch_size}_{patch_size}" + # save_exp_id =f"{args.model}_{args.contrast}_{args.label_type}_nf={args.init_filters}" \ + # f"_opt={args.optimizer}_lr={args.learning_rate}" \ + # f"_DiceL" \ + # f"_bs={args.batch_size}_{patch_size}" + + if args.debug: + save_exp_id = f"DEBUG_{save_exp_id}" + # TODO: move this inside the for loop when using more folds timestamp = datetime.now().strftime(f"%Y%m%d-%H%M") # prints in YYYYMMDD-HHMMSS format @@ -629,7 +546,11 @@ def main(args): logger.add(os.path.join(args.save_path, f"{save_exp_id}", "logs.txt"), rotation="10 MB", level="INFO") # define loss function + # loss_func = SoftDiceLoss(p=1, smooth=1.0) + # logger.info(f"Using SoftDiceLoss with p={loss_func.p}, smooth={loss_func.smooth}!") loss_func = AdapWingLoss(theta=0.5, omega=8, alpha=2.1, epsilon=1, reduction="sum") + # NOTE: tried increasing omega and decreasing epsilon but results marginally worse than the above + # loss_func = AdapWingLoss(theta=0.5, omega=12, alpha=2.1, epsilon=0.5, reduction="sum") logger.info(f"Using AdapWingLoss with theta={loss_func.theta}, omega={loss_func.omega}, alpha={loss_func.alpha}, epsilon={loss_func.epsilon}!") # define callbacks @@ -789,7 +710,7 @@ def main(args): parser = argparse.ArgumentParser(description='Script for training custom models for SCI Lesion Segmentation.') # Arguments for model, data, and training 
and saving
-    parser.add_argument('-m', '--model', choices=['unet', 'unetr', 'nnunet'],
+    parser.add_argument('-m', '--model', choices=['unetr', 'nnunet'],
                         default='nnunet', type=str, help='Model type to be used')
     parser.add_argument('--enable_DS', default=False, action='store_true', help='Enable Deep Supervision')
     # dataset
@@ -804,8 +725,6 @@ def main(args):

     # unet model
     parser.add_argument('-initf', '--init_filters', default=16, type=int, help="Number of Filters in Init Layer")
-    # parser.add_argument('-ps', '--patch_size', type=int, default=128, help='List containing subvolume size')
-    parser.add_argument('-dep', '--unet_depth', default=3, type=int, help="Depth of UNet model")

     # unetr model
     parser.add_argument('-fs', '--feature_size', default=16, type=int, help="Feature Size")
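To round out the crop-size arguments introduced across these patches (`--crop-size` for inference, `--crop_size` for training), the sketch below shows the shared "RxAxI" string convention, where -1 disables cropping along an axis. Note that `parse_crop_size` is a hypothetical helper written for illustration, not a function from the repository:

```python
def parse_crop_size(spec: str) -> tuple:
    """Parse an 'RxAxI' crop spec, e.g. '64x192x-1', into a tuple of ints.

    Mirrors the pattern used in the patches above:
    tuple([int(i) for i in args.crop_size.split("x")]).
    A value of -1 means no cropping along that axis (R-L, A-P, I-S order).
    """
    dims = tuple(int(i) for i in spec.split("x"))
    assert len(dims) == 3, "expected three 'x'-separated values (R-L, A-P, I-S)"
    return dims


print(parse_crop_size("64x192x-1"))   # (64, 192, -1)
print(parse_crop_size("64x192x320"))  # (64, 192, 320)
```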