Merge pull request #60 from sct-pipeline/nk/monai
add monai-based scripts for dataset conversion, training, and inference
naga-karthik authored Oct 9, 2023
2 parents 2b8e5b9 + 990c357 commit 112b450
Showing 10 changed files with 2,056 additions and 0 deletions.
73 changes: 73 additions & 0 deletions monai/README.md
@@ -0,0 +1,73 @@
## Instructions for running inference with the contrast-agnostic spinal cord segmentation model

The following steps are required for using the contrast-agnostic model.

### Setting up the environment and installing dependencies

The following commands show how to set up the environment. Note that the documentation assumes that the user has `conda` installed on their system. Instructions on installing `conda` can be found [here](https://conda.io/projects/conda/en/latest/user-guide/install/index.html).

1. Create a conda environment with the following command:

```bash
conda create -n venv_monai python=3.9
```

2. Activate the environment with the following command:

```bash
conda activate venv_monai
```

3. The list of necessary packages can be found in `requirements_inference.txt`. Use the following command for installation:

```bash
pip install -r requirements_inference.txt
```
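Once the requirements are installed, you can optionally sanity-check the environment by importing the core dependencies. This assumes `monai` and `torch` are listed in `requirements_inference.txt`:

```bash
# Quick sanity check -- assumes monai and torch are among the inference requirements
python -c "import monai, torch; print(monai.__version__, torch.__version__)"
```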

### Method 1: Running inference on a single image

The script for running inference is `run_inference_single_image.py`. Please run
```bash
python run_inference_single_image.py -h
```
to get the list of arguments and their descriptions.
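For reference, a typical invocation might look like the sketch below. The flag names here are illustrative only, so rely on the `-h` output above for the actual arguments:

```bash
# Illustrative sketch only -- the exact flag names may differ; check `-h`
python run_inference_single_image.py \
    --path-img sub-001_T2w.nii.gz \
    --chkp-path <path/to/checkpoint> \
    --path-out <path/to/output/folder>
```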


### Method 2: Running inference on a dataset (Advanced)

NOTE: This section is experimental and for advanced users only. Please use Method 1 for running inference.

#### Creating a datalist

The inference script expects the dataset to be described in a Medical Segmentation Decathlon-style `json` datalist containing image-label pairs. The `create_inference_msd_datalist.py` script lets you create one for your dataset. Use the following command to create the datalist:

```bash
python create_inference_msd_datalist.py --dataset-name spine-generic --path-data <path-to-dataset> --path-out <path-to-output-folder> --contrast-suffix T1w
```

- `--dataset-name` - Name of the dataset. The datalist will be saved as `<dname>_dataset.json`
- `--path-data` - Path to the BIDS dataset
- `--path-out` - Path to the output folder. The datalist will be saved under `<path-out>/<dname>_dataset.json`
- `--contrast-suffix` - The suffix of the contrast to be used for pairing images/labels

> **Note**
> This script is not meant to run off-the-shelf. Placeholders are provided to update the script with the .... TODO
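The resulting datalist is a plain `json` file listing image-label pairs per split. As a rough sketch (the key names follow the MSD convention used by `create_msd_data.py` in this repository; the exact fields written by `create_inference_msd_datalist.py` may differ):

```json
{
    "name": "spine-generic",
    "numTest": 2,
    "test": [
        {
            "image": "<path-to-dataset>/sub-001/anat/sub-001_T1w.nii.gz",
            "label": "<path-to-dataset>/derivatives/labels/sub-001/anat/sub-001_T1w_seg-manual.nii.gz"
        },
        {
            "image": "<path-to-dataset>/sub-002/anat/sub-002_T1w.nii.gz",
            "label": "<path-to-dataset>/derivatives/labels/sub-002/anat/sub-002_T1w_seg-manual.nii.gz"
        }
    ]
}
```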

#### Running inference

Use the following command:

```bash
python run_inference.py --path-json <path/to/json> --chkp-path <path/to/checkpoint> --path-out <path/to/output> --model <unet/nnunet> --crop_size <48x160x320> --device <gpu/cpu>
```

- `--path-json` - Path to the datalist created in the previous step
- `--chkp-path` - Path to the model checkpoint folder. This folder should contain `best_model_loss.ckpt`
- `--path-out` - Path to the output folder where the predictions will be saved
- `--model` - Model to be used for inference. Currently, only `unet` and `nnunet` are supported
- `--crop_size` - Crop size used for center-cropping the image before running inference. Recommended to be a multiple of 32 in each dimension
- `--device` - Device to be used for inference. Currently, only `gpu` and `cpu` are supported
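For example, a filled-in command could look like this (the paths are placeholders; the crop size follows the multiple-of-32 recommendation above):

```bash
python run_inference.py --path-json outputs/spine-generic_dataset.json --chkp-path checkpoints/nnunet_run1 --path-out predictions --model nnunet --crop_size 64x192x320 --device gpu
```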



226 changes: 226 additions & 0 deletions monai/create_msd_data.py
@@ -0,0 +1,226 @@
import os
import json
from tqdm import tqdm
import numpy as np
import argparse
import joblib
from utils import FoldGenerator
from loguru import logger
from sklearn.model_selection import train_test_split

# root = "/home/GRAMES.POLYMTL.CA/u114716/datasets/spine-generic_uncropped"

parser = argparse.ArgumentParser(description='Code for creating k-fold splits of the spine-generic dataset.')

parser.add_argument('-pd', '--path-data', required=True, type=str, help='Path to the data set directory')
parser.add_argument('-pj', '--path-joblib', help='Path to joblib file from ivadomed containing the dataset splits.',
                    default=None, type=str)
parser.add_argument('-po', '--path-out', type=str, help='Path to the output directory where dataset json is saved')
parser.add_argument("--contrast", default="t2w", type=str, help="Contrast to use for training",
                    choices=["t1w", "t2w", "t2star", "mton", "mtoff", "dwi", "all"])
parser.add_argument('--label-type', default='soft', type=str, help="Type of labels to use for training",
                    choices=['hard', 'soft'])
parser.add_argument('--seed', default=42, type=int, help="Seed for reproducibility")
args = parser.parse_args()
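# Example usage (illustrative paths; the flags correspond to the arguments defined above):
#   python create_msd_data.py --path-data /path/to/bids/dataset --path-out /path/to/output \
#       --contrast t2w --label-type soft --seed 42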


root = args.path_data
seed = args.seed
contrast = args.contrast
if args.label_type == 'soft':
    logger.info("Using SOFT LABELS ...")
    PATH_DERIVATIVES = os.path.join(root, "derivatives", "labels_softseg")
    SUFFIX = "softseg"
else:
    logger.info("Using HARD LABELS ...")
    PATH_DERIVATIVES = os.path.join(root, "derivatives", "labels")
    SUFFIX = "seg-manual"

# Get all subjects
# the participants.tsv file might not be up-to-date, hence rely on the existing folders
# subjects_df = pd.read_csv(os.path.join(root, 'participants.tsv'), sep='\t')
# subjects = subjects_df['participant_id'].values.tolist()
subjects = [subject for subject in os.listdir(root) if subject.startswith('sub-')]
logger.info(f"Total number of subjects in the root directory: {len(subjects)}")

if args.path_joblib is not None:
    # load information from the joblib to match train and test subjects
    # joblib_file = os.path.join(args.path_joblib, 'split_datasets_all_seed=15.joblib')
    splits = joblib.load(args.path_joblib)
    # get the subjects from the joblib file
    train_subjects = sorted(list(set([sub.split('_')[0] for sub in splits['train']])))
    val_subjects = sorted(list(set([sub.split('_')[0] for sub in splits['valid']])))
    test_subjects = sorted(list(set([sub.split('_')[0] for sub in splits['test']])))

else:
    # create one json file with 60-20-20 train-val-test split
    train_ratio, val_ratio, test_ratio = 0.6, 0.2, 0.2
    train_subjects, test_subjects = train_test_split(subjects, test_size=test_ratio, random_state=args.seed)
    # Use the training split to further split into training and validation splits
    train_subjects, val_subjects = train_test_split(train_subjects, test_size=val_ratio / (train_ratio + val_ratio),
                                                    random_state=args.seed)
    # sort the subjects
    train_subjects = sorted(train_subjects)
    val_subjects = sorted(val_subjects)
    test_subjects = sorted(test_subjects)

logger.info(f"Number of training subjects: {len(train_subjects)}")
logger.info(f"Number of validation subjects: {len(val_subjects)}")
logger.info(f"Number of testing subjects: {len(test_subjects)}")

# keys to be defined in the dataset_0.json
params = {}
params["description"] = "spine-generic-uncropped"
params["labels"] = {
"0": "background",
"1": "soft-sc-seg"
}
params["license"] = "nk"
params["modality"] = {
"0": "MRI"
}
params["name"] = "spine-generic"
params["numTest"] = len(test_subjects)
params["numTraining"] = len(train_subjects)
params["numValidation"] = len(val_subjects)
params["seed"] = args.seed
params["reference"] = "University of Zurich"
params["tensorImageSize"] = "3D"

train_subjects_dict = {"train": train_subjects}
val_subjects_dict = {"validation": val_subjects}
test_subjects_dict = {"test": test_subjects}
all_subjects_list = [train_subjects_dict, val_subjects_dict, test_subjects_dict]

# # define the contrasts
# contrasts_list = ['T1w', 'T2w', 'T2star', 'flip-1_mt-on_MTS', 'flip-2_mt-off_MTS', 'dwi']

for subjects_dict in tqdm(all_subjects_list, desc="Iterating through train/val/test splits"):

    for name, subs_list in subjects_dict.items():

        temp_list = []
        for subject_no, subject in enumerate(subs_list):

            if contrast == "all":
                temp_data_t1w = {}
                temp_data_t2w = {}
                temp_data_t2star = {}
                temp_data_mton_mts = {}
                temp_data_mtoff_mts = {}
                temp_data_dwi = {}

                # t1w
                temp_data_t1w["image"] = os.path.join(root, subject, 'anat', f"{subject}_T1w.nii.gz")
                temp_data_t1w["label"] = os.path.join(PATH_DERIVATIVES, subject, 'anat', f"{subject}_T1w_{SUFFIX}.nii.gz")
                if os.path.exists(temp_data_t1w["label"]) and os.path.exists(temp_data_t1w["image"]):
                    temp_list.append(temp_data_t1w)

                # t2w
                temp_data_t2w["image"] = os.path.join(root, subject, 'anat', f"{subject}_T2w.nii.gz")
                temp_data_t2w["label"] = os.path.join(PATH_DERIVATIVES, subject, 'anat', f"{subject}_T2w_{SUFFIX}.nii.gz")
                if os.path.exists(temp_data_t2w["label"]) and os.path.exists(temp_data_t2w["image"]):
                    temp_list.append(temp_data_t2w)

                # t2star
                temp_data_t2star["image"] = os.path.join(root, subject, 'anat', f"{subject}_T2star.nii.gz")
                temp_data_t2star["label"] = os.path.join(PATH_DERIVATIVES, subject, 'anat', f"{subject}_T2star_{SUFFIX}.nii.gz")
                if os.path.exists(temp_data_t2star["label"]) and os.path.exists(temp_data_t2star["image"]):
                    temp_list.append(temp_data_t2star)

                # mton_mts
                temp_data_mton_mts["image"] = os.path.join(root, subject, 'anat', f"{subject}_flip-1_mt-on_MTS.nii.gz")
                temp_data_mton_mts["label"] = os.path.join(PATH_DERIVATIVES, subject, 'anat', f"{subject}_flip-1_mt-on_MTS_{SUFFIX}.nii.gz")
                if os.path.exists(temp_data_mton_mts["label"]) and os.path.exists(temp_data_mton_mts["image"]):
                    temp_list.append(temp_data_mton_mts)

                # t1w_mts
                temp_data_mtoff_mts["image"] = os.path.join(root, subject, 'anat', f"{subject}_flip-2_mt-off_MTS.nii.gz")
                temp_data_mtoff_mts["label"] = os.path.join(PATH_DERIVATIVES, subject, 'anat', f"{subject}_flip-2_mt-off_MTS_{SUFFIX}.nii.gz")
                if os.path.exists(temp_data_mtoff_mts["label"]) and os.path.exists(temp_data_mtoff_mts["image"]):
                    temp_list.append(temp_data_mtoff_mts)

                # dwi
                temp_data_dwi["image"] = os.path.join(root, subject, 'dwi', f"{subject}_rec-average_dwi.nii.gz")
                temp_data_dwi["label"] = os.path.join(PATH_DERIVATIVES, subject, 'dwi', f"{subject}_rec-average_dwi_{SUFFIX}.nii.gz")
                if os.path.exists(temp_data_dwi["label"]) and os.path.exists(temp_data_dwi["image"]):
                    temp_list.append(temp_data_dwi)


elif contrast == "t1w": # t1w
temp_data_t1w = {}
temp_data_t1w["image"] = os.path.join(root, subject, 'anat', f"{subject}_T1w.nii.gz")
temp_data_t1w["label"] = os.path.join(PATH_DERIVATIVES, subject, 'anat', f"{subject}_T1w_{SUFFIX}.nii.gz")
if os.path.exists(temp_data_t1w["label"]) and os.path.exists(temp_data_t1w["image"]):
temp_list.append(temp_data_t1w)
else:
logger.info(f"Subject {subject} does not have T1w image or label.")


elif contrast == "t2w": # t2w
temp_data_t2w = {}
temp_data_t2w["image"] = os.path.join(root, subject, 'anat', f"{subject}_T2w.nii.gz")
temp_data_t2w["label"] = os.path.join(PATH_DERIVATIVES, subject, 'anat', f"{subject}_T2w_{SUFFIX}.nii.gz")
if os.path.exists(temp_data_t2w["label"]) and os.path.exists(temp_data_t2w["image"]):
temp_list.append(temp_data_t2w)
else:
logger.info(f"Subject {subject} does not have T2w image or label.")


elif contrast == "t2star": # t2star
temp_data_t2star = {}
temp_data_t2star["image"] = os.path.join(root, subject, 'anat', f"{subject}_T2star.nii.gz")
temp_data_t2star["label"] = os.path.join(PATH_DERIVATIVES, subject, 'anat', f"{subject}_T2star_{SUFFIX}.nii.gz")
if os.path.exists(temp_data_t2star["label"]) and os.path.exists(temp_data_t2star["image"]):
temp_list.append(temp_data_t2star)
else:
logger.info(f"Subject {subject} does not have T2star image or label.")


elif contrast == "mton": # mton_mts
temp_data_mton_mts = {}
temp_data_mton_mts["image"] = os.path.join(root, subject, 'anat', f"{subject}_flip-1_mt-on_MTS.nii.gz")
temp_data_mton_mts["label"] = os.path.join(PATH_DERIVATIVES, subject, 'anat', f"{subject}_flip-1_mt-on_MTS_{SUFFIX}.nii.gz")
if os.path.exists(temp_data_mton_mts["label"]) and os.path.exists(temp_data_mton_mts["image"]):
temp_list.append(temp_data_mton_mts)
else:
logger.info(f"Subject {subject} does not have MTOn image or label.")

elif contrast == "mtoff": # t1w_mts
temp_data_mtoff_mts = {}
temp_data_mtoff_mts["image"] = os.path.join(root, subject, 'anat', f"{subject}_flip-2_mt-off_MTS.nii.gz")
temp_data_mtoff_mts["label"] = os.path.join(PATH_DERIVATIVES, subject, 'anat', f"{subject}_flip-2_mt-off_MTS_{SUFFIX}.nii.gz")
if os.path.exists(temp_data_mtoff_mts["label"]) and os.path.exists(temp_data_mtoff_mts["image"]):
temp_list.append(temp_data_mtoff_mts)
else:
logger.info(f"Subject {subject} does not have MTOff image or label.")

elif contrast == "dwi": # dwi
temp_data_dwi = {}
temp_data_dwi["image"] = os.path.join(root, subject, 'dwi', f"{subject}_rec-average_dwi.nii.gz")
temp_data_dwi["label"] = os.path.join(PATH_DERIVATIVES, subject, 'dwi', f"{subject}_rec-average_dwi_{SUFFIX}.nii.gz")
if os.path.exists(temp_data_dwi["label"]) and os.path.exists(temp_data_dwi["image"]):
temp_list.append(temp_data_dwi)
else:
logger.info(f"Subject {subject} does not have DWI image or label.")

else:
raise ValueError(f"Contrast {contrast} not recognized.")


        params[name] = temp_list
        logger.info(f"Number of images in {name} set: {len(temp_list)}")

final_json = json.dumps(params, indent=4, sort_keys=True)
if not os.path.exists(args.path_out):
    os.makedirs(args.path_out, exist_ok=True)

jsonFile = open(args.path_out + "/" + f"dataset_{contrast}_{args.label_type}_seed{seed}.json", "w")
jsonFile.write(final_json)
jsonFile.close()





