add monai-based scripts for dataset conversion, training, and inference #60

Merged Oct 9, 2023 · 106 commits (changes shown below are from 16 commits)

Commits
2fa8821
add dataset conversion script for monai training
naga-karthik Jul 10, 2023
9395732
add file for loss functions
naga-karthik Jul 10, 2023
f109634
add transforms
naga-karthik Jul 10, 2023
d11a99c
add util functions
naga-karthik Jul 10, 2023
6590125
add working monai training script
naga-karthik Jul 10, 2023
cc244ae
change nearest interp to linear for labels in spacingD
naga-karthik Jul 10, 2023
7e8261a
fix bug in train/val/test metrics accumulation
naga-karthik Jul 10, 2023
d361b73
update plot_slices()
naga-karthik Jul 10, 2023
11f3ac5
fix cuda memory leak issues; remove unused variables
naga-karthik Jul 10, 2023
e88afbc
add PolyLR scheduler
naga-karthik Jul 11, 2023
b7ecd0a
swap order of spacingD and cropForegroundD
naga-karthik Jul 11, 2023
f73f22e
add empty patch filtering; change LR scheduler
naga-karthik Jul 11, 2023
5bb5315
add requirements
naga-karthik Jul 11, 2023
c7a0281
add option to load from joblib for ivadomed comparison
naga-karthik Jul 12, 2023
11961e7
add elastic and biasfield transforms
naga-karthik Jul 14, 2023
d230e59
change empty image patch filter to label patch
naga-karthik Jul 14, 2023
56dfff3
add initial version of inference script
naga-karthik Jul 17, 2023
ad3d847
add ref to chkpt loading for inference
naga-karthik Jul 17, 2023
407ae04
remove duplicated test transforms
naga-karthik Jul 19, 2023
e67427f
add feature to plot more slices
naga-karthik Jul 19, 2023
72ac6fe
minor fixes and improvements
naga-karthik Jul 19, 2023
0540c1f
update code to do inference on various datasets
naga-karthik Jul 20, 2023
5d93f9a
fix diceLoss; add DiceCE loss
naga-karthik Jul 21, 2023
bd81aca
change to explicitly normalizing the logits
naga-karthik Jul 21, 2023
29825f9
add option for attentionunet
naga-karthik Jul 21, 2023
6b43c59
remove todo; update args for path-data
naga-karthik Jul 25, 2023
ac2572c
remove RobustCELoss; update DiceCELoss
naga-karthik Jul 25, 2023
8ddddeb
Notable changes:
naga-karthik Jul 25, 2023
36da5e1
remove attn-unet; update save_exp_id to contain more hyperparams
naga-karthik Jul 25, 2023
ec95391
remove individual axis flip in RandFlipd causing collation error
naga-karthik Jul 25, 2023
1e5610b
bring changes from train to val transforms
naga-karthik Jul 25, 2023
d3fadaf
lower prob of RandFlip transform
naga-karthik Aug 1, 2023
e551206
minor changes
naga-karthik Aug 7, 2023
31f0b59
add Modified3DUNet model from ivadomed
naga-karthik Aug 7, 2023
8fa0f49
add initial version of AdapWingLoss
naga-karthik Aug 7, 2023
4db99f7
add RandomGaussianSmooth transform (i.e. RandomBlur)
naga-karthik Aug 7, 2023
3f4aa39
minor changes
naga-karthik Aug 15, 2023
60fbd9f
add func to compute avg csa
naga-karthik Aug 15, 2023
3494106
remove unused comments; change opt to Adam; add ivadomed unet
naga-karthik Aug 15, 2023
18e8c91
update to add csa loss during training/val
naga-karthik Aug 15, 2023
1023122
remove monai unet, add ivadomed unet
naga-karthik Aug 21, 2023
8f324c2
finalize csa loss-related changes
naga-karthik Aug 21, 2023
3aec0d5
minor fixes
naga-karthik Aug 21, 2023
a2677ce
refactor to add ensembling
naga-karthik Aug 21, 2023
801a155
add arg to run inference using unetr
naga-karthik Aug 26, 2023
22e5b18
add function to check for empty patches
naga-karthik Aug 26, 2023
27cf099
add TODO about adding RandSimulateLowResolution
naga-karthik Aug 26, 2023
182d27a
increase num_workers to speed up time/epoch
naga-karthik Aug 26, 2023
d8719b6
fix code to filter out empty patches
naga-karthik Aug 26, 2023
3f45cb8
minor update to UNetR params; add simple profiling
naga-karthik Aug 26, 2023
1945b0d
add argument for specifying model to use for inference
naga-karthik Aug 29, 2023
e319ca6
add dynunet model for training
naga-karthik Aug 29, 2023
41e47b8
update training_step for handling deep supervision outputs in the lo…
naga-karthik Aug 29, 2023
d6582c3
add model checkpointing based on val csa and dice
naga-karthik Aug 29, 2023
4017b9e
move network helper functions from models.py
naga-karthik Aug 29, 2023
674c5b4
add script for creating MSD datalists for running inference on pathlo…
naga-karthik Aug 29, 2023
91ebb0e
add function to create model used in nnunet
naga-karthik Aug 29, 2023
2f511e2
add option to train using the model used in nnunet
naga-karthik Aug 29, 2023
2e2a061
update train/validation step to deal with deep supervision outputs in l…
naga-karthik Aug 29, 2023
d1fc56f
add options to create datalists per contrast and specify hard/soft la…
naga-karthik Aug 29, 2023
8a83018
rearrange transforms to match ivadomed's; add RandSimulateLowResolution
naga-karthik Aug 31, 2023
c59edac
update train transforms like nnunet's
naga-karthik Sep 4, 2023
77d3bea
add working version of AdapWingLoss
naga-karthik Sep 7, 2023
dd0ece4
add training transforms as per ivadomed
naga-karthik Sep 11, 2023
0fc8b90
add val_transforms_with_center_crop()
naga-karthik Sep 11, 2023
2cff684
modify train/val steps for training with AdapWingLoss
naga-karthik Sep 11, 2023
f1a8343
modify dataloading to use val_transforms_with_center_crop
naga-karthik Sep 11, 2023
8777cd3
minor modifications to match ivadomed's training
naga-karthik Sep 11, 2023
284898f
add transforms used by nnunet; replace centerCrop
naga-karthik Sep 13, 2023
8c9fcd9
add feature to resume training from checkpoint
naga-karthik Sep 13, 2023
6b88362
fix warnings; remove comments; fix save names
naga-karthik Sep 13, 2023
3a76d31
add variants of validation transforms
naga-karthik Sep 18, 2023
0df19a6
re-use nnunet-like transforms
naga-karthik Sep 18, 2023
79b369f
update code to train nnunet-based model
naga-karthik Sep 18, 2023
5affa03
modify code to log everything to a txt file
naga-karthik Sep 18, 2023
e52c6c8
add code for inference with monai-based nnunet
naga-karthik Sep 18, 2023
36ae67e
save terminal outputs to log file; add args
naga-karthik Sep 18, 2023
9763ed5
add documentation on usage
naga-karthik Sep 18, 2023
c6e3e92
finalize train transforms
naga-karthik Sep 21, 2023
9bb37d3
rename to inference_transforms(); add DivisiblePadd
naga-karthik Sep 21, 2023
4ab0d4b
add requirements for cpu inference
naga-karthik Sep 21, 2023
82113d8
add init version of instructions
naga-karthik Sep 21, 2023
265c704
fix torch cpu version download
naga-karthik Sep 25, 2023
edaed3d
add script for inference on single image
naga-karthik Sep 25, 2023
043804c
add instructions for single-image inference
naga-karthik Sep 25, 2023
963e9d2
add inference transforms; remove transforms import
naga-karthik Sep 25, 2023
5330b19
add nibabel dep for monai; add scipy
naga-karthik Sep 25, 2023
e0d90b0
update args description for crop size
naga-karthik Sep 26, 2023
5a9a123
add sorting of subjects
naga-karthik Sep 26, 2023
57de777
remove dataset-based inference
naga-karthik Sep 27, 2023
890c465
update args description about cropping
naga-karthik Sep 28, 2023
650bb73
remove building_blocks
naga-karthik Sep 28, 2023
68f6bd3
move weights init class from building_blocks
naga-karthik Sep 28, 2023
2a18878
remove usage example in docstring
naga-karthik Sep 28, 2023
ee688d3
remove saving in separate folders; clean script
naga-karthik Oct 2, 2023
d5bea6b
create standalone script
naga-karthik Oct 2, 2023
f88d927
create light-weight, clean requirements txt
naga-karthik Oct 2, 2023
c4cbd61
change default crop size
naga-karthik Oct 8, 2023
4694169
fix bug with torch.clamp setting bg values to 0.5
naga-karthik Oct 8, 2023
2e35282
remove +cpu from torch version
naga-karthik Oct 9, 2023
5e3f97a
renamed as README, removed usage instructions
naga-karthik Oct 9, 2023
bdb3e98
cleaned by removing unused code
naga-karthik Oct 9, 2023
5cce1e8
add cupy for inference on gpu
naga-karthik Oct 9, 2023
2151b23
add args for contrast, label-type for per-contrast, hard/soft training
naga-karthik Oct 9, 2023
227ba30
fix to use the same crop size for train/val
naga-karthik Oct 9, 2023
990c357
clean code
naga-karthik Oct 9, 2023
264 changes: 264 additions & 0 deletions monai/create_msd_data.py
@@ -0,0 +1,264 @@
import os
import json
from tqdm import tqdm
import numpy as np
import argparse
import joblib
import nibabel as nib  # used for the header-only shape reads in the k-fold branch
from utils import FoldGenerator
from loguru import logger
from sklearn.model_selection import train_test_split

# TODO: split the data using the ivadomed joblib file

root = "/home/GRAMES.POLYMTL.CA/u114716/datasets/spine-generic_uncropped"

parser = argparse.ArgumentParser(description='Code for creating k-fold splits of the spine-generic dataset.')

parser.add_argument('--seed', default=42, type=int, help="Seed for reproducibility")
parser.add_argument('-ncvf', '--num-cv-folds', default=5, type=int,
                    help="[1-k] to create a k-fold dataset for cross-validation, 0 for a single file with all subjects")
parser.add_argument('-pd', '--path-data', default=root, type=str, help='Path to the dataset directory')
parser.add_argument('-pj', '--path-joblib', help='Path to the joblib file from ivadomed containing the dataset splits.',
                    default=None, type=str)
parser.add_argument('-po', '--path-out', type=str, help='Path to the output directory where the dataset json is saved')
args = parser.parse_args()


root = args.path_data
seed = args.seed
num_cv_folds = args.num_cv_folds  # for 100 subjects, performs a 60-20-20 split with num_cv_folds

# Get all subjects
# the participants.tsv file might not be up-to-date, hence rely on the existing folders
subjects = [subject for subject in os.listdir(root) if subject.startswith('sub-')]
logger.info(f"Total number of subjects in the root directory: {len(subjects)}")

if args.num_cv_folds != 0:
    # create k-fold CV datasets as usual

    # returns a nested list of length (num_cv_folds), each element (again, a list) consisting of
    # train, val, test indices and the fold number
    names_list = FoldGenerator(seed, num_cv_folds, len_data=len(subjects)).get_fold_names()

    for fold in range(num_cv_folds):

        train_ix, val_ix, test_ix, fold_num = names_list[fold]
        training_subjects = [subjects[tr_ix] for tr_ix in train_ix]
        validation_subjects = [subjects[v_ix] for v_ix in val_ix]
        test_subjects = [subjects[te_ix] for te_ix in test_ix]

        # keys to be defined in the dataset_0.json
        params = {}
        params["description"] = "sci-zurich naga"
        params["labels"] = {
            "0": "background",
            "1": "sc-lesion"
        }
        params["license"] = "nk"
        params["modality"] = {
            "0": "MRI"
        }
        params["name"] = "sci-zurich"
        params["numTest"] = len(test_subjects)
        params["numTraining"] = len(training_subjects) + len(validation_subjects)
        params["reference"] = "University of Zurich"
        params["tensorImageSize"] = "3D"

        train_val_subjects_dict = {
            "training": training_subjects,
            "validation": validation_subjects,
        }
        test_subjects_dict = {"test": test_subjects}

        # run loop for training and validation subjects
        temp_shapes_list = []
        for name, subs_list in train_val_subjects_dict.items():

            temp_list = []
            for subject_no, subject in enumerate(tqdm(subs_list, desc='Loading Volumes')):

                # another loop for going through sessions
                temp_subject_path = os.path.join(root, subject)
                num_sessions_per_subject = sum(os.path.isdir(os.path.join(temp_subject_path, pth))
                                               for pth in os.listdir(temp_subject_path))

                for ses_idx in range(1, num_sessions_per_subject + 1):
                    temp_data = {}
                    # get paths with session numbers
                    session = 'ses-0' + str(ses_idx)
                    subject_images_path = os.path.join(root, subject, session, 'anat')
                    subject_labels_path = os.path.join(root, 'derivatives', 'labels', subject, session, 'anat')

                    subject_image_file = os.path.join(subject_images_path, f'{subject}_{session}_acq-sag_T2w.nii.gz')
                    subject_label_file = os.path.join(subject_labels_path, f'{subject}_{session}_acq-sag_T2w_lesion-manual.nii.gz')

                    # record each subject's shape (header-only read) for the median computation below
                    temp_shapes_list.append(nib.load(subject_image_file).shape)

                    # store paths relative to the dataset root
                    temp_data["image"] = subject_image_file.replace(root + "/", '')
                    temp_data["label"] = subject_label_file.replace(root + "/", '')

                    temp_list.append(temp_data)

            params[name] = temp_list

        # calculate the median shape along each axis
        params["train_val_median_shape"] = np.median(temp_shapes_list, axis=0).tolist()

        # run a separate loop for the test subjects (no shape bookkeeping needed here)
        for name, subs_list in test_subjects_dict.items():
            temp_list = []
            for subject_no, subject in enumerate(tqdm(subs_list, desc='Loading Volumes')):

                temp_subject_path = os.path.join(root, subject)
                num_sessions_per_subject = sum(os.path.isdir(os.path.join(temp_subject_path, pth))
                                               for pth in os.listdir(temp_subject_path))

                for ses_idx in range(1, num_sessions_per_subject + 1):
                    temp_data = {}
                    session = 'ses-0' + str(ses_idx)
                    subject_images_path = os.path.join(root, subject, session, 'anat')
                    subject_labels_path = os.path.join(root, 'derivatives', 'labels', subject, session, 'anat')

                    subject_image_file = os.path.join(subject_images_path, f'{subject}_{session}_acq-sag_T2w.nii.gz')
                    subject_label_file = os.path.join(subject_labels_path, f'{subject}_{session}_acq-sag_T2w_lesion-manual.nii.gz')

                    temp_data["image"] = subject_image_file.replace(root + "/", '')
                    temp_data["label"] = subject_label_file.replace(root + "/", '')

                    temp_list.append(temp_data)

            params[name] = temp_list

        final_json = json.dumps(params, indent=4, sort_keys=True)
        with open(os.path.join(root, f"dataset_fold-{fold_num}.json"), "w") as json_file:
            json_file.write(final_json)
else:

    if args.path_joblib is not None:
        # load information from the joblib file to match ivadomed's train/val/test subjects
        joblib_file = os.path.join(args.path_joblib, 'split_datasets_all_seed=15.joblib')
        splits = joblib.load(joblib_file)
        # get the subjects from the joblib file
        train_subjects = sorted(set(sub.split('_')[0] for sub in splits['train']))
        val_subjects = sorted(set(sub.split('_')[0] for sub in splits['valid']))
        test_subjects = sorted(set(sub.split('_')[0] for sub in splits['test']))

    else:
        # create one json file with a 60-20-20 train-val-test split
        train_ratio, val_ratio, test_ratio = 0.6, 0.2, 0.2
        train_subjects, test_subjects = train_test_split(subjects, test_size=test_ratio, random_state=args.seed)
        # use the training split to further split into training and validation splits
        train_subjects, val_subjects = train_test_split(train_subjects, test_size=val_ratio / (train_ratio + val_ratio),
                                                        random_state=args.seed)

    logger.info(f"Number of training subjects: {len(train_subjects)}")
    logger.info(f"Number of validation subjects: {len(val_subjects)}")
    logger.info(f"Number of testing subjects: {len(test_subjects)}")

    # keys to be defined in the dataset json
    params = {}
    params["description"] = "spine-generic-uncropped"
    params["labels"] = {
        "0": "background",
        "1": "soft-sc-seg"
    }
    params["license"] = "nk"
    params["modality"] = {
        "0": "MRI"
    }
    params["name"] = "spine-generic"
    params["numTest"] = len(test_subjects)
    params["numTraining"] = len(train_subjects)
    params["numValidation"] = len(val_subjects)
    params["seed"] = args.seed
    params["reference"] = "University of Zurich"
    params["tensorImageSize"] = "3D"

    train_subjects_dict = {"train": train_subjects}
    val_subjects_dict = {"validation": val_subjects}
    test_subjects_dict = {"test": test_subjects}
    all_subjects_list = [train_subjects_dict, val_subjects_dict, test_subjects_dict]

    # define the contrasts; dwi lives in a different folder and uses a different filename pattern
    contrasts_list = ['T1w', 'T2w', 'T2star', 'flip-1_mt-on_MTS', 'flip-2_mt-off_MTS', 'dwi']

    for subjects_dict in tqdm(all_subjects_list, desc="Iterating through train/val/test splits"):

        for name, subs_list in subjects_dict.items():

            temp_list = []
            for subject_no, subject in enumerate(subs_list):

                for contrast in contrasts_list:
                    folder = 'dwi' if contrast == 'dwi' else 'anat'
                    fname = f"{subject}_rec-average_dwi" if contrast == 'dwi' else f"{subject}_{contrast}"
                    temp_data = {
                        "image": os.path.join(root, subject, folder, f"{fname}.nii.gz"),
                        "label": os.path.join(root, "derivatives", "labels_softseg", subject, folder,
                                              f"{fname}_softseg.nii.gz"),
                    }
                    # keep only image/label pairs where both files exist
                    if os.path.exists(temp_data["image"]) and os.path.exists(temp_data["label"]):
                        temp_list.append(temp_data)

            params[name] = temp_list
            logger.info(f"Number of images in {name} set: {len(temp_list)}")

    final_json = json.dumps(params, indent=4, sort_keys=True)
    with open(os.path.join(args.path_out, "dataset.json"), "w") as json_file:
        json_file.write(final_json)
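
The JSON written above follows MONAI's "Decathlon datalist" layout, so downstream scripts can read it with monai.data.load_decathlon_datalist. A minimal consumption sketch, assuming the single-split branch was run (it writes absolute paths under the "train" key; the k-fold branch writes relative paths under "training" and would additionally need base_dir set to the dataset root). The output path here is hypothetical and stands in for whatever was passed via --path-out:

from monai.data import load_decathlon_datalist

# hypothetical output location; use the directory passed via --path-out
datalist_path = "/path/to/output/dataset.json"

# each entry is a dict with "image" and "label" keys, ready for MONAI's dict-based transforms
train_files = load_decathlon_datalist(datalist_path, is_segmentation=True, data_list_key="train")
val_files = load_decathlon_datalist(datalist_path, is_segmentation=True, data_list_key="validation")

print(f"{len(train_files)} training pairs, e.g. {train_files[0]}")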






27 changes: 27 additions & 0 deletions monai/losses.py
@@ -0,0 +1,27 @@
import torch
import torch.nn as nn


class SoftDiceLoss(nn.Module):
    '''
    Soft Dice loss, useful in binary segmentation.
    Taken from: https://github.com/CoinCheung/pytorch-loss/blob/master/soft_dice_loss.py
    '''
    def __init__(self, p=1, smooth=1):
        super(SoftDiceLoss, self).__init__()
        self.p = p
        self.smooth = smooth

    def forward(self, preds, labels):
        '''
        inputs:
            preds: normalized probabilities (not logits) - tensor of shape (N, H, W, ...)
            labels: soft labels in [0, 1] - tensor of shape (N, H, W, ...)
        output:
            loss: tensor of shape (1,)
        '''
        numer = (preds * labels).sum()
        denor = (preds.pow(self.p) + labels.pow(self.p)).sum()
        loss = 1. - (2 * numer + self.smooth) / (denor + self.smooth)
        return loss
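
Note that forward expects normalized probabilities rather than raw logits, so callers apply a sigmoid (or softmax) themselves. A minimal usage sketch with made-up tensor shapes, assuming it runs in the same directory as losses.py:

import torch
from losses import SoftDiceLoss

criterion = SoftDiceLoss(p=1, smooth=1)

# hypothetical batch: 2 single-channel 32x32x32 volumes
logits = torch.randn(2, 1, 32, 32, 32, requires_grad=True)  # raw network outputs
targets = torch.rand(2, 1, 32, 32, 32)                      # soft labels in [0, 1]

preds = torch.sigmoid(logits)     # normalize logits to probabilities first
loss = criterion(preds, targets)  # scalar tensor
loss.backward()                   # differentiable end-to-end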