-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #4 from axondeepseg/ac/add_nnunet_scripts
Add nnUNetv2 scripts
- Loading branch information
Showing
4 changed files
with
394 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,312 @@ | ||
#!/usr/bin/env python3 | ||
""" | ||
Prepares a new dataset for nnUNetv2, focusing on BF segmentation. | ||
Features: | ||
- Training Set Compilation: Includes all subjects with annotations in the | ||
training set. nnUNetv2 will perform automatic cross-validation using these | ||
annotated subjects. | ||
- Testing Set Assignment: Allocates subjects without annotations to the | ||
testing set, facilitating model performance evaluation on unseen data. | ||
- Inspiration: The structure and methodology of this script is | ||
inspired by Armand Collin's work. The original script by Armand Collin can | ||
be found at: | ||
https://github.com/axondeepseg/model_seg_rabbit_axon-myelin_bf/blob/main/nnUNet_scripts/prepare_data.py | ||
""" | ||
|
||
|
||
# Module metadata exposed as standard dunder attributes.
__author__ = "Arthur Boschet"
__license__ = "MIT"
|
||
|
||
import argparse | ||
import csv | ||
import json | ||
import os | ||
import re | ||
from pathlib import Path | ||
from typing import Dict, List, Literal | ||
|
||
import cv2 | ||
import numpy as np | ||
|
||
|
||
def create_directories(base_dir: str, subdirs: List[str]):
    """
    Ensure that each requested subdirectory exists under ``base_dir``.

    Missing parent directories are created as needed and directories that
    already exist are left untouched.

    Parameters
    ----------
    base_dir : str
        Directory under which the subdirectories are created.
    subdirs : List[str]
        Names of the subdirectories to create inside ``base_dir``.
    """
    root = Path(base_dir)
    for name in subdirs:
        (root / name).mkdir(parents=True, exist_ok=True)
|
||
|
||
def save_json(data: Dict, file_path: str):
    """
    Write ``data`` to ``file_path`` as pretty-printed JSON.

    Parameters
    ----------
    data : Dict
        JSON-serializable dictionary to persist.
    file_path : str
        Destination path; an existing file is overwritten.
    """
    serialized = json.dumps(data, indent=2)
    with open(file_path, "w") as json_file:
        json_file.write(serialized)
|
||
|
||
def process_images(
    datapath: Path,
    out_folder: str,
    participants_to_sample_dict: Dict[str, List[str]],
    bids_to_nnunet_dict: Dict[str, int],
    dataset_name: str,
    is_test: bool = False,
):
    """
    Copy each subject's BF images into the nnUNetv2 raw-data layout.

    Images are read as grayscale from the BIDS tree and re-written under
    ``imagesTr`` (training) or ``imagesTs`` (testing) using nnUNetv2 case
    naming: ``{dataset_name}_{case_id:03d}_0000.png``.

    Parameters
    ----------
    datapath : Path
        Path to the BIDS data directory.
    out_folder : str
        Output dataset directory containing ``imagesTr``/``imagesTs``.
    participants_to_sample_dict : Dict[str, List[str]]
        Mapping of participant IDs to their sample IDs.
    bids_to_nnunet_dict : Dict[str, int]
        Mapping of ``str((participant_id, sample_id))`` keys to case IDs.
    dataset_name : str
        Name of the dataset, used as the filename prefix.
    is_test : bool, optional
        If True, images go to ``imagesTs``; otherwise ``imagesTr``.

    Raises
    ------
    FileNotFoundError
        If an expected image file cannot be read.
    """
    folder_type = "imagesTs" if is_test else "imagesTr"
    image_suffix = "_0000"  # nnUNetv2 channel suffix (single input channel)

    for subject, images in participants_to_sample_dict.items():
        for image in images:
            case_id = bids_to_nnunet_dict[str((subject, image))]
            image_path = os.path.join(
                datapath, subject, "micr", f"{subject}_{image}_acq-roi_BF.png"
            )
            img = cv2.imread(str(image_path), cv2.IMREAD_GRAYSCALE)
            # cv2.imread returns None instead of raising on a missing or
            # unreadable file; fail loudly rather than writing empty images.
            if img is None:
                raise FileNotFoundError(f"Could not read image: {image_path}")
            fname = f"{dataset_name}_{case_id:03d}{image_suffix}.png"
            cv2.imwrite(os.path.join(out_folder, folder_type, fname), img)
|
||
|
||
def process_labels(
    datapath: Path,
    out_folder: str,
    participants_to_sample_dict: Dict[str, List[str]],
    bids_to_nnunet_dict: Dict[str, int],
    dataset_name: str,
    label_type: Literal["axonmyelin", "myelin", "axon"] = "axonmyelin",
):
    """
    Convert each subject's manual segmentation into nnUNetv2 label maps.

    Grayscale label PNGs are read from the BIDS ``derivatives/labels`` tree,
    rescaled to integer class indices, and written to ``labelsTr`` using
    nnUNetv2 case naming: ``{dataset_name}_{case_id:03d}.png``.

    Parameters
    ----------
    datapath : Path
        Path to the BIDS data directory.
    out_folder : str
        Output dataset directory containing ``labelsTr``.
    participants_to_sample_dict : Dict[str, List[str]]
        Mapping of participant IDs to their sample IDs.
    bids_to_nnunet_dict : Dict[str, int]
        Mapping of ``str((participant_id, sample_id))`` keys to case IDs.
    dataset_name : str
        Name of the dataset, used as the filename prefix.
    label_type : Literal["axonmyelin", "myelin", "axon"], optional
        Annotation type to convert. Defaults to 'axonmyelin'.

    Raises
    ------
    FileNotFoundError
        If an expected label file cannot be read.
    """
    # Divisor mapping gray values to class indices. For 'axonmyelin' the
    # masks presumably encode background/myelin/axon as 0/127/~255, so
    # round(value / 127) yields classes 0/1/2; binary masks (0/255) divide
    # by 255 to yield 0/1. TODO(review): confirm encoding against the data.
    label_type_to_divisor = {"axonmyelin": 127, "myelin": 255, "axon": 255}
    divisor = label_type_to_divisor[label_type]

    for subject, images in participants_to_sample_dict.items():
        for image in images:
            case_id = bids_to_nnunet_dict[str((subject, image))]
            label_path = os.path.join(
                datapath,
                "derivatives",
                "labels",
                subject,
                "micr",
                f"{subject}_{image}_acq-roi_BF_seg-{label_type}-manual.png",
            )
            raw = cv2.imread(str(label_path), cv2.IMREAD_GRAYSCALE)
            # cv2.imread returns None instead of raising on a missing or
            # unreadable file; fail loudly rather than crashing in np.round.
            if raw is None:
                raise FileNotFoundError(f"Could not read label: {label_path}")
            label = np.round(raw / divisor)
            fname = f"{dataset_name}_{case_id:03d}.png"
            cv2.imwrite(os.path.join(out_folder, "labelsTr", fname), label)
|
||
|
||
def create_bids_to_nnunet_dict(file_path: Path) -> Dict[str, int]:
    """
    Build the mapping from BIDS identifiers to sequential nnUNet case IDs.

    Reads a tab-separated ``samples.tsv`` whose first two columns are
    ``sample_id`` and ``participant_id`` and assigns case IDs in file order.

    Parameters
    ----------
    file_path : Path
        Path to the ``samples.tsv`` file listing the samples.

    Returns
    -------
    Dict[str, int]
        Mapping of ``str((participant_id, sample_id))`` keys to case IDs
        starting at 1.
    """
    with open(file_path, "r") as tsv_file:
        rows = csv.reader(tsv_file, delimiter="\t")
        next(rows)  # skip the header row
        return {
            str((row[1], row[0])): case_id
            for case_id, row in enumerate(rows, start=1)
        }
|
||
|
||
def main(args):
    """
    Convert a BIDS dataset into the nnUNetv2 raw-data layout.

    Creates ``Dataset{ID}_{NAME}`` under ``TARGETDIR/nnUNet_raw``, writes
    ``dataset.json``, copies train/test images and training labels, and
    saves the BIDS-to-case-ID mapping next to the output.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed command line arguments (DATAPATH, TARGETDIR, DATASETNAME,
        DESCRIPTION, SPLITJSON, LABELTYPE, DATASETID).
    """
    dataset_name = args.DATASETNAME
    description = args.DESCRIPTION
    datapath = Path(args.DATAPATH)
    target_dir = Path(args.TARGETDIR)
    train_test_split_path = Path(args.SPLITJSON)
    label_type = args.LABELTYPE
    # nnUNetv2 expects a zero-padded 3-digit dataset ID in the folder name.
    dataset_id = str(args.DATASETID).zfill(3)

    out_folder = os.path.join(
        target_dir, "nnUNet_raw", f"Dataset{dataset_id}_{dataset_name}"
    )
    create_directories(out_folder, ["imagesTr", "labelsTr", "imagesTs"])

    bids_to_nnunet_dict = create_bids_to_nnunet_dict(
        os.path.join(datapath, "samples.tsv")
    )

    with open(train_test_split_path, "r") as f:
        train_test_split_dict = json.load(f)

    # Invert the {sample_id: participant_id} split maps into
    # {participant_id: [sample_ids]} groupings used by the processors.
    train_participant_to_sample_dict = _group_by_participant(
        train_test_split_dict["train"]
    )
    test_participant_to_sample_dict = _group_by_participant(
        train_test_split_dict["test"]
    )

    # nnUNetv2 label map for each supported annotation type; an unknown
    # LABELTYPE now fails fast with a KeyError instead of silently falling
    # back to the axon-only map.
    labels_by_type = {
        "axonmyelin": {"background": 0, "myelin": 1, "axon": 2},
        "myelin": {"background": 0, "myelin": 1},
        "axon": {"background": 0, "axon": 1},
    }
    dataset_info = {
        "name": dataset_name,
        "description": description,
        "labels": labels_by_type[label_type],
        "channel_names": {"0": "rescale_to_0_1"},
        "numTraining": sum(
            len(images) for images in train_participant_to_sample_dict.values()
        ),
        "numTest": sum(
            len(images) for images in test_participant_to_sample_dict.values()
        ),
        "file_ending": ".png",
    }
    save_json(dataset_info, os.path.join(out_folder, "dataset.json"))

    process_images(
        datapath,
        out_folder,
        train_participant_to_sample_dict,
        bids_to_nnunet_dict,
        dataset_name,
        is_test=False,
    )
    process_labels(
        datapath,
        out_folder,
        train_participant_to_sample_dict,
        bids_to_nnunet_dict,
        dataset_name,
        label_type=label_type,
    )
    process_images(
        datapath,
        out_folder,
        test_participant_to_sample_dict,
        bids_to_nnunet_dict,
        dataset_name,
        is_test=True,
    )

    save_json(
        bids_to_nnunet_dict, os.path.join(target_dir, "subject_to_case_identifier.json")
    )


def _group_by_participant(sample_to_participant: Dict[str, str]) -> Dict[str, List[str]]:
    """Invert a {sample_id: participant_id} map into {participant_id: [sample_ids]}."""
    grouped: Dict[str, List[str]] = {}
    for sample_id, participant_id in sample_to_participant.items():
        grouped.setdefault(participant_id, []).append(sample_id)
    return grouped
|
||
|
||
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Convert a BIDS dataset to the nnUNetv2 raw format."
    )
    parser.add_argument("DATAPATH", help="Path to the original dataset in BIDS format")
    parser.add_argument(
        "--TARGETDIR",
        default=".",
        help="Target directory for the new dataset, defaults to current directory",
    )
    parser.add_argument(
        "--DATASETNAME",
        default="BF_RAT",
        help="Name of the new dataset, defaults to BF_RAT",
    )
    parser.add_argument(
        "--DESCRIPTION",
        default="BF axon and myelin segmentation dataset for nnUNetv2",
        help="Description of the new dataset, defaults to BF segmentation dataset for nnUNetv2",
    )
    parser.add_argument(
        "--SPLITJSON",
        default="nn_unet_scripts/train_test_split.json",
        help="Path to the train_test_split.json file",
    )
    parser.add_argument(
        "--LABELTYPE",
        default="axonmyelin",
        # Restrict to the values process_labels/main actually support so an
        # invalid value fails at parse time with a clear message.
        choices=["axonmyelin", "myelin", "axon"],
        help="Type of label to use. Options are 'axonmyelin', 'myelin', or 'axon'. Defaults to 'axonmyelin'",
    )
    parser.add_argument(
        "--DATASETID",
        default=3,
        type=int,
        help="ID of the dataset. This ID is formatted with 3 digits. For example, 1 becomes '001', 23 becomes '023', etc. Defaults to 3",
    )
    args = parser.parse_args()
    main(args)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
#!/bin/bash
# Sets up the nnUNet environment, converts a BIDS dataset to nnUNetv2
# format, then runs preprocessing and dataset integrity verification.
#
# Usage: $0 PATH_TO_ORIGINAL_DATASET RESULTS_DIR [DATASET_ID] [LABEL_TYPE] [DATASET_NAME]
if [ "$#" -lt 2 ]; then
    echo "Usage: $0 PATH_TO_ORIGINAL_DATASET RESULTS_DIR [DATASET_ID] [LABEL_TYPE] [DATASET_NAME]"
    exit 1
fi

# nnUNet configuration to preprocess (2D microscopy images).
config="2d"

PATH_TO_ORIGINAL_DATASET=$1
# Quote "$2" so paths containing spaces survive realpath.
RESULTS_DIR=$(realpath "$2")
dataset_id=${3:-3}
label_type=${4:-"axonmyelin"} # 'axonmyelin', 'myelin', or 'axon'. Defaults to 'axonmyelin'
dataset_name=${5:-"BF_RAT"}

echo "-------------------------------------------------------"
echo "Converting dataset to nnUNetv2 format"
echo "-------------------------------------------------------"

# Run the conversion script. All expansions are quoted to avoid word
# splitting on paths or names containing whitespace.
python convert_from_bids_to_nnunetv2_format.py "$PATH_TO_ORIGINAL_DATASET" --TARGETDIR "$RESULTS_DIR" --DATASETID "$dataset_id" --LABELTYPE "$label_type" --DATASETNAME "$dataset_name" --SPLITJSON train_test_split.json

# nnUNetv2 locates its data through these environment variables.
export nnUNet_raw="$RESULTS_DIR/nnUNet_raw"
export nnUNet_preprocessed="$RESULTS_DIR/nnUNet_preprocessed"
export nnUNet_results="$RESULTS_DIR/nnUNet_results"

echo "-------------------------------------------------------"
echo "Running preprocessing and verifying dataset integrity"
echo "-------------------------------------------------------"

nnUNetv2_plan_and_preprocess -d "${dataset_id}" --verify_dataset_integrity -c "${config}"
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
#!/bin/bash
#
# Training nnUNetv2 on multiple folds, then running inference on the test
# images for each fold.

config=2d
dataset_id=003
dataset_name=Dataset003_BF_RAT
nnunet_trainer="nnUNetTrainer"
# GPU index passed to CUDA_VISIBLE_DEVICES.
device=4

# Select number of folds here
folds=(0 1 2)

for fold in "${folds[@]}"; do
    echo "-------------------------------------------"
    echo "Training on Fold $fold"
    echo "-------------------------------------------"

    # training
    # BUGFIX: the original used $(device), which is command substitution
    # (it tries to run a command named `device`); ${device} expands the
    # variable defined above.
    CUDA_VISIBLE_DEVICES=${device} nnUNetv2_train ${dataset_id} ${config} ${fold} -tr ${nnunet_trainer}

    echo ""
    echo "-------------------------------------------"
    echo "Training completed, Testing on Fold $fold"
    echo "-------------------------------------------"

    # inference
    CUDA_VISIBLE_DEVICES=${device} nnUNetv2_predict -i "${nnUNet_raw}/${dataset_name}/imagesTs" -tr ${nnunet_trainer} -o "${nnUNet_results}/${nnunet_trainer}__nnUNetPlans__${config}/fold_${fold}/test" -d ${dataset_id} -f ${fold} -c ${config}

    echo ""
    echo "-------------------------------------------"
    echo " Inference completed on Fold $fold"
    echo "-------------------------------------------"

done
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
{ | ||
"train": { | ||
"sample-uoftRat04": "sub-uoftRat04", | ||
"sample-uoftRat08": "sub-uoftRat08", | ||
"sample-uoftRat09": "sub-uoftRat09", | ||
"sample-uoftRat10": "sub-uoftRat10", | ||
"sample-uoftRat16": "sub-uoftRat16", | ||
"sample-uoftRat17": "sub-uoftRat17" | ||
}, | ||
"test": { | ||
"sample-uoftRat02": "sub-uoftRat02", | ||
"sample-uoftRat07": "sub-uoftRat07" | ||
} | ||
} |