Training of an STU-Net model for MS lesion segmentation #29
Here are the steps taken to train an STU-Net model (the documentation of the repo is not up to date). In the project folder:
The trainer needs to be updated with the following code:

```python
import torch
from torch import nn
from torch.nn.parallel import DistributedDataParallel as DDP

from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer
from nnunetv2.utilities.label_handling.label_handling import determine_num_input_channels

# STUNet is assumed to be defined in the same module as this trainer
# (as in STUNetTrainer.py from the STU-Net repo).


class STUNetTrainer(nnUNetTrainer):
    def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True,
                 device: torch.device = torch.device('cuda')):
        super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device)
        self.num_epochs = 1000
        self.initial_lr = 1e-2

    @staticmethod
    def build_network_architecture(plans_manager,
                                   dataset_json,
                                   configuration_manager,
                                   num_input_channels,
                                   enable_deep_supervision: bool = True) -> nn.Module:
        label_manager = plans_manager.get_label_manager(dataset_json)
        num_classes = label_manager.num_segmentation_heads
        # STU-Net uses 3x3x3 convolutions in all 6 stages
        kernel_sizes = [[3, 3, 3]] * 6
        # Skip the first pooling entry (first stage does not downsample),
        # then force exactly 5 downsampling strides
        strides = configuration_manager.pool_op_kernel_sizes[1:]
        if len(strides) > 5:
            strides = strides[:5]
        while len(strides) < 5:
            strides.append([1, 1, 1])
        return STUNet(num_input_channels, num_classes, depth=[1] * 6, dims=[32 * x for x in [1, 2, 4, 8, 16, 16]],
                      pool_op_kernel_sizes=strides, conv_kernel_sizes=kernel_sizes,
                      enable_deep_supervision=enable_deep_supervision)

    def initialize(self):
        if not self.was_initialized:
            self.num_input_channels = determine_num_input_channels(self.plans_manager, self.configuration_manager,
                                                                   self.dataset_json)
            self.network = self.build_network_architecture(self.plans_manager, self.dataset_json,
                                                           self.configuration_manager,
                                                           self.num_input_channels,
                                                           enable_deep_supervision=True).to(self.device)
            # compile network for free speedup
            if self._do_i_compile():
                self.print_to_log_file('Compiling network...')
                self.network = torch.compile(self.network)
            self.optimizer, self.lr_scheduler = self.configure_optimizers()
            # if ddp, wrap in DDP wrapper
            if self.is_ddp:
                self.network = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.network)
                self.network = DDP(self.network, device_ids=[self.local_rank])
            self.loss = self._build_loss()
            self.was_initialized = True
        else:
            raise RuntimeError("You have called self.initialize even though the trainer was already initialized. "
                               "That should not happen.")
```

Then I exported the nnUNet folders (where the dataset has been preprocessed):

```bash
export nnUNet_raw="/home/plbenveniste/net/ms-lesion-agnostic/stunet_experiments/second_exp/nnUNet_raw"
export nnUNet_results="/home/plbenveniste/net/ms-lesion-agnostic/stunet_experiments/second_exp/nnUNet_results"
export nnUNet_preprocessed="/home/plbenveniste/net/ms-lesion-agnostic/stunet_experiments/second_exp/nnUNet_preprocessed"
```

Then I launched the training (on koios):

```bash
CUDA_VISIBLE_DEVICES=0 python nnUNet/nnunetv2/run/run_finetuning_stunet.py 201 3d_fullres 1 -pretrained_weights /home/plbenveniste/net/ms-lesion-agnostic/stunet_experiments/stunet_pretrained/base_ep4k.model -tr STUNetTrainer_base_ft
```

This issue helped: uni-medical/STU-Net#34
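The stride handling in `build_network_architecture` can be tried out on its own. A minimal sketch, assuming the first entry of nnUNet's `pool_op_kernel_sizes` is the non-downsampling first stage; `normalize_strides` is a hypothetical helper, not part of nnUNet or STU-Net:

```python
from typing import List


def normalize_strides(pool_op_kernel_sizes: List[List[int]], n_stages: int = 5) -> List[List[int]]:
    """Mirror the stride logic of build_network_architecture (hypothetical helper).

    Skips the first entry of the pooling plan, then truncates or pads with
    [1, 1, 1] (no downsampling) so that exactly `n_stages` strides remain.
    """
    # Copy each entry so the returned list shares nothing with the plans object
    strides = [list(s) for s in pool_op_kernel_sizes[1:]]
    strides = strides[:n_stages]
    while len(strides) < n_stages:
        strides.append([1, 1, 1])
    return strides


# A plan with only 4 downsampling steps after the first stage gets padded
plan = [[1, 1, 1], [2, 2, 2], [2, 2, 2], [2, 2, 2], [1, 2, 2]]
print(normalize_strides(plan))
# [[2, 2, 2], [2, 2, 2], [2, 2, 2], [1, 2, 2], [1, 1, 1]]
```

Padding with `[1, 1, 1]` keeps the 6-stage STU-Net layout (and therefore the pretrained weights) usable even when the nnUNet plan has fewer pooling steps.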
Before running inference, the file

The inference on the test set was done using:

```bash
CUDA_VISIBLE_DEVICES=0 nnUNetv2_predict -i /home/plbenveniste/net/ms-lesion-agnostic/stunet_experiments/second_exp/nnUNet_raw/Dataset201_msLesionAgnostic/imagesTs/ -o /home/plbenveniste/net/ms-lesion-agnostic/stunet_experiments/second_exp/nnUNet_results/Dataset201_msLesionAgnostic/STUNetTrainer_base_ft__nnUNetPlans__3d_fullres/fold_1/test_set -d 201 -c 3d_fullres -f 1 -chk checkpoint_best.pth -tr STUNetTrainer_base_ft
```

The results were computed (in the

```bash
python nnunet/evaluate_predictions.py -pred-folder /home/plbenveniste/net/ms-lesion-agnostic/stunet_experiments/second_exp/nnUNet_results/Dataset201_msLesionAgnostic/STUNetTrainer_base_ft__nnUNetPlans__3d_fullres/fold_1/test_set/ -label-folder ~/net/ms-lesion-agnostic/nnunet_experiments/nnUNet_raw/Dataset201_msLesionAgnostic/labelsTs -image-folder ~/net/ms-lesion-agnostic/nnunet_experiments/nnUNet_raw/Dataset201_msLesionAgnostic/imagesTs/ -conversion-dict ~/net/ms-lesion-agnostic/nnunet_experiments/nnUNet_raw/Dataset201_msLesionAgnostic/conversion_dict.json -output-folder /home/plbenveniste/net/ms-lesion-agnostic/stunet_experiments/second_exp/nnUNet_results/Dataset201_msLesionAgnostic/STUNetTrainer_base_ft__nnUNetPlans__3d_fullres/fold_1/test_set
```

And the plots were obtained with:

```bash
python nnunet/plot_performance.py --pred-dir-path /home/plbenveniste/net/ms-lesion-agnostic/stunet_experiments/second_exp/nnUNet_results/Dataset201_msLesionAgnostic/STUNetTrainer_base_ft__nnUNetPlans__3d_fullres/fold_1/test_set/ --data-json-path /home/plbenveniste/net/ms-lesion-agnostic/msd_data/dataset_2024-07-24_seed42_lesionOnly.json --split test
```
We evaluated this model on the following metrics: Dice, F1 score, PPV and sensitivity.

To evaluate the predictions:

```bash
python nnunet/evaluate_predictions.py -pred-folder /home/plbenveniste/net/ms-lesion-agnostic/stunet_experiments/second_exp/nnUNet_results/Dataset201_msLesionAgnostic/STUNetTrainer_base_ft__nnUNetPlans__3d_fullres/fold_1/test_set/ -label-folder ~/net/ms-lesion-agnostic/nnunet_experiments/nnUNet_raw/Dataset201_msLesionAgnostic/labelsTs -image-folder ~/net/ms-lesion-agnostic/nnunet_experiments/nnUNet_raw/Dataset201_msLesionAgnostic/imagesTs/ -conversion-dict ~/net/ms-lesion-agnostic/nnunet_experiments/nnUNet_raw/Dataset201_msLesionAgnostic/conversion_dict.json -output-folder /home/plbenveniste/net/ms-lesion-agnostic/stunet_experiments/second_exp/nnUNet_results/Dataset201_msLesionAgnostic/STUNetTrainer_base_ft__nnUNetPlans__3d_fullres/fold_1/test_set_2
```

To plot the results:

```bash
python nnunet/plot_performance.py --pred-dir-path /home/plbenveniste/net/ms-lesion-agnostic/stunet_experiments/second_exp/nnUNet_results/Dataset201_msLesionAgnostic/STUNetTrainer_base_ft__nnUNetPlans__3d_fullres/fold_1/test_set_2/ --data-json-path /home/plbenveniste/net/ms-lesion-agnostic/msd_data/dataset_2024-07-24_seed42_lesionOnly.json --split test
```

Here are the results:
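For reference, the textbook definitions of these metrics can be sketched as follows. This is an illustration only, assuming Dice is computed voxel-wise and PPV, sensitivity and F1 lesion-wise from counts of true-positive, false-positive and false-negative lesions; how `nnunet/evaluate_predictions.py` matches lesions and handles empty masks is not shown here, and `dice` and `detection_scores` are hypothetical names:

```python
import numpy as np


def dice(pred: np.ndarray, gt: np.ndarray) -> float:
    """Voxel-wise Dice coefficient between two binary masks."""
    pred, gt = pred.astype(bool), gt.astype(bool)
    denom = pred.sum() + gt.sum()
    if denom == 0:
        return 1.0  # assumption: two empty masks count as perfect agreement
    return float(2.0 * np.logical_and(pred, gt).sum() / denom)


def detection_scores(tp: int, fp: int, fn: int) -> dict:
    """Lesion-wise PPV, sensitivity and F1 from lesion counts."""
    ppv = tp / (tp + fp) if tp + fp else 0.0
    sensitivity = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * ppv * sensitivity / (ppv + sensitivity) if ppv + sensitivity else 0.0
    return {"PPV": ppv, "sensitivity": sensitivity, "F1": f1}


# Toy masks: 2 voxels each, overlapping on 1 voxel
pred = np.zeros((4, 4), dtype=np.uint8)
gt = np.zeros((4, 4), dtype=np.uint8)
pred[0, :2] = 1
gt[0, 1:3] = 1
print(dice(pred, gt))              # 0.5
print(detection_scores(3, 1, 1))   # PPV, sensitivity and F1 all 0.75
```

Since F1 is the harmonic mean of PPV and sensitivity per image, the mean F1 over a contrast is generally not the F1 of the mean PPV and mean sensitivity.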
To try to improve performance, I trained a second model using the larger pretrained STU-Net model. To do so, I reproduced the same steps as above to install and set up everything. It is stored in
I did the following steps:

```bash
export nnUNet_raw="/home/plbenveniste/net/ms-lesion-agnostic/stunet_experiments/stunet_large_experiment/nnUNet_raw"
export nnUNet_results="/home/plbenveniste/net/ms-lesion-agnostic/stunet_experiments/stunet_large_experiment/nnUNet_results"
export nnUNet_preprocessed="/home/plbenveniste/net/ms-lesion-agnostic/stunet_experiments/stunet_large_experiment/nnUNet_preprocessed"
CUDA_VISIBLE_DEVICES=0 python nnUNet/nnunetv2/run/run_finetuning_stunet.py 201 3d_fullres 1 -pretrained_weights /home/plbenveniste/net/ms-lesion-agnostic/stunet_experiments/stunet_pretrained/large_ep4k.model -tr STUNetTrainer_large_ft
```

One epoch takes 250 seconds, so training is going to take a while.
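To put the 250 s per epoch in perspective: assuming the fine-tuning trainer keeps the `num_epochs = 1000` set in `STUNetTrainer` above (`STUNetTrainer_large_ft` may override it) and that epoch time stays constant, the total works out to roughly 69 hours, i.e. about 3 days. `estimated_training_hours` is a hypothetical helper:

```python
def estimated_training_hours(seconds_per_epoch: float, num_epochs: int) -> float:
    """Wall-clock estimate, assuming a constant time per epoch."""
    return seconds_per_epoch * num_epochs / 3600.0


# Assumption: 1000 epochs, as set in STUNetTrainer above
hours = estimated_training_hours(250, 1000)
print(f"{hours:.1f} h, about {hours / 24:.1f} days")  # 69.4 h, about 2.9 days
```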
Because I saw that nnUNet performance improved when training on re-oriented images, I did the same for the base STU-Net model. For that, I used the nnUNet dataset 301. I exported the paths to the nnUNet re-oriented data and to the STU-Net results output:

```bash
export nnUNet_raw="/home/plbenveniste/net/ms-lesion-agnostic/nnunet_experiments/nnUNet_raw"
export nnUNet_preprocessed="/home/plbenveniste/net/ms-lesion-agnostic/nnunet_experiments/nnUNet_preprocessed"
export nnUNet_results="/home/plbenveniste/net/ms-lesion-agnostic/stunet_experiments/second_exp/nnUNet_results"
```

I ran training using:

```bash
CUDA_VISIBLE_DEVICES=1 python nnUNet/nnunetv2/run/run_finetuning_stunet.py 301 3d_fullres 0 -pretrained_weights /home/plbenveniste/net/ms-lesion-agnostic/stunet_experiments/stunet_pretrained/base_ep4k.model -tr STUNetTrainer_base_ft
```

Just looking at the training curve, we can see the improvement both in the final Dice reached and in training stability:

Predictions with checkpoint_best

Here is the command I used to run the predictions on the test set:

```bash
CUDA_VISIBLE_DEVICES=1 nnUNetv2_predict -i /home/plbenveniste/net/ms-lesion-agnostic/nnunet_experiments/nnUNet_raw/Dataset301_msLesionAgnostic/imagesTs/ -o /home/plbenveniste/net/ms-lesion-agnostic/stunet_experiments/second_exp/nnUNet_results/Dataset301_msLesionAgnostic/STUNetTrainer_base_ft__nnUNetPlans__3d_fullres/fold_0/test_set -d 301 -c 3d_fullres -f 0 -chk checkpoint_best.pth -tr STUNetTrainer_base_ft
```

To evaluate the predictions (with venv_nnunet), I ran:

```bash
python nnunet/evaluate_predictions.py -pred-folder /home/plbenveniste/net/ms-lesion-agnostic/stunet_experiments/second_exp/nnUNet_results/Dataset301_msLesionAgnostic/STUNetTrainer_base_ft__nnUNetPlans__3d_fullres/fold_0/test_set/ -label-folder ~/net/ms-lesion-agnostic/nnunet_experiments/nnUNet_raw/Dataset301_msLesionAgnostic/labelsTs -image-folder ~/net/ms-lesion-agnostic/nnunet_experiments/nnUNet_raw/Dataset301_msLesionAgnostic/imagesTs/ -conversion-dict ~/net/ms-lesion-agnostic/nnunet_experiments/nnUNet_raw/Dataset301_msLesionAgnostic/conversion_dict.json -output-folder /home/plbenveniste/net/ms-lesion-agnostic/stunet_experiments/second_exp/nnUNet_results/Dataset301_msLesionAgnostic/STUNetTrainer_base_ft__nnUNetPlans__3d_fullres/fold_0/test_set
```

To plot the performances:

```bash
python nnunet/plot_performance.py --pred-dir-path /home/plbenveniste/net/ms-lesion-agnostic/stunet_experiments/second_exp/nnUNet_results/Dataset301_msLesionAgnostic/STUNetTrainer_base_ft__nnUNetPlans__3d_fullres/fold_0/test_set/ --data-json-path /home/plbenveniste/net/ms-lesion-agnostic/msd_data/dataset_2024-07-24_seed42_lesionOnly.json --split test
```

Output:
Dice score per contrast (mean ± std)
PSIR (n=60): 0.3698 ± 0.2441
STIR (n=11): 0.4181 ± 0.2355
T2star (n=83): 0.5220 ± 0.1952
T2w (n=358): 0.4867 ± 0.1790
UNIT1 (n=57): 0.6030 ± 0.1588

Other metrics:
PPV score per contrast (mean ± std)
PSIR (n=60): 0.6265 ± 0.3512
STIR (n=11): 0.6615 ± 0.3925
T2star (n=83): 0.8395 ± 0.2638
T2w (n=358): 0.7592 ± 0.2898
UNIT1 (n=57): 0.8412 ± 0.2318
F1 score per contrast (mean ± std)
PSIR (n=60): 0.5346 ± 0.3251
STIR (n=11): 0.6459 ± 0.3729
T2star (n=83): 0.7936 ± 0.2358
T2w (n=358): 0.7483 ± 0.2568
UNIT1 (n=57): 0.8065 ± 0.2006
Sensitivity score per contrast (mean ± std)
PSIR (n=60): 0.5568 ± 0.3609
STIR (n=11): 0.7018 ± 0.3605
T2star (n=83): 0.8044 ± 0.2635
T2w (n=358): 0.8240 ± 0.2730
UNIT1 (n=57): 0.8330 ± 0.2337

Predictions with checkpoint_final

I also ran the predictions with the final checkpoint:

```bash
CUDA_VISIBLE_DEVICES=1 nnUNetv2_predict -i /home/plbenveniste/net/ms-lesion-agnostic/nnunet_experiments/nnUNet_raw/Dataset301_msLesionAgnostic/imagesTs/ -o /home/plbenveniste/net/ms-lesion-agnostic/stunet_experiments/second_exp/nnUNet_results/Dataset301_msLesionAgnostic/STUNetTrainer_base_ft__nnUNetPlans__3d_fullres/fold_0/test_set_checkpt_final -d 301 -c 3d_fullres -f 0 -chk checkpoint_final.pth -tr STUNetTrainer_base_ft
```

To evaluate the predictions:

```bash
python nnunet/evaluate_predictions.py -pred-folder /home/plbenveniste/net/ms-lesion-agnostic/stunet_experiments/second_exp/nnUNet_results/Dataset301_msLesionAgnostic/STUNetTrainer_base_ft__nnUNetPlans__3d_fullres/fold_0/test_set_checkpt_final -label-folder ~/net/ms-lesion-agnostic/nnunet_experiments/nnUNet_raw/Dataset301_msLesionAgnostic/labelsTs -image-folder ~/net/ms-lesion-agnostic/nnunet_experiments/nnUNet_raw/Dataset301_msLesionAgnostic/imagesTs/ -conversion-dict ~/net/ms-lesion-agnostic/nnunet_experiments/nnUNet_raw/Dataset301_msLesionAgnostic/conversion_dict.json -output-folder /home/plbenveniste/net/ms-lesion-agnostic/stunet_experiments/second_exp/nnUNet_results/Dataset301_msLesionAgnostic/STUNetTrainer_base_ft__nnUNetPlans__3d_fullres/fold_0/test_set_checkpt_final
```

To plot the performances:

```bash
python nnunet/plot_performance.py --pred-dir-path /home/plbenveniste/net/ms-lesion-agnostic/stunet_experiments/second_exp/nnUNet_results/Dataset301_msLesionAgnostic/STUNetTrainer_base_ft__nnUNetPlans__3d_fullres/fold_0/test_set_checkpt_final/ --data-json-path /home/plbenveniste/net/ms-lesion-agnostic/msd_data/dataset_2024-07-24_seed42_lesionOnly.json --split test
```

Output:
Dice score per contrast (mean ± std)
PSIR (n=60): 0.3609 ± 0.2463
STIR (n=11): 0.4212 ± 0.2330
T2star (n=83): 0.5117 ± 0.1980
T2w (n=358): 0.4825 ± 0.1797
UNIT1 (n=57): 0.6013 ± 0.1642

Other metrics:
PPV score per contrast (mean ± std)
PSIR (n=60): 0.6050 ± 0.3617
STIR (n=11): 0.6718 ± 0.3886
T2star (n=83): 0.8469 ± 0.2679
T2w (n=358): 0.7725 ± 0.2865
UNIT1 (n=57): 0.8593 ± 0.2213
F1 score per contrast (mean ± std)
PSIR (n=60): 0.5193 ± 0.3222
STIR (n=11): 0.6225 ± 0.3459
T2star (n=83): 0.7789 ± 0.2492
T2w (n=358): 0.7554 ± 0.2541
UNIT1 (n=57): 0.8058 ± 0.1996
Sensitivity score per contrast (mean ± std)
PSIR (n=60): 0.5533 ± 0.3646
STIR (n=11): 0.6499 ± 0.3456
T2star (n=83): 0.7734 ± 0.2751
T2w (n=358): 0.8237 ± 0.2701
UNIT1 (n=57): 0.8199 ± 0.2417
→ In my case, checkpoint_best performs best: its mean Dice is higher than that of checkpoint_final for every contrast except STIR (0.4181 vs 0.4212).
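The comparison can be checked directly from the mean Dice values reported above. A small sketch that prints the per-contrast difference and a subject-weighted mean over all contrasts; weighting each contrast mean by its n reproduces the mean over all 569 test images, up to the 4-decimal rounding of the reported values. The variable and function names are my own:

```python
# Mean Dice per contrast, copied from the two plot_performance.py outputs above
dice_best = {"PSIR": 0.3698, "STIR": 0.4181, "T2star": 0.5220, "T2w": 0.4867, "UNIT1": 0.6030}
dice_final = {"PSIR": 0.3609, "STIR": 0.4212, "T2star": 0.5117, "T2w": 0.4825, "UNIT1": 0.6013}
n_images = {"PSIR": 60, "STIR": 11, "T2star": 83, "T2w": 358, "UNIT1": 57}


def weighted_mean(scores: dict) -> float:
    """Mean over all test images, rebuilt from per-contrast means and counts."""
    total = sum(n_images.values())
    return sum(scores[c] * n_images[c] for c in scores) / total


for contrast in dice_best:
    diff = dice_best[contrast] - dice_final[contrast]
    print(f"{contrast:7s} best - final = {diff:+.4f}")

print(f"overall mean Dice: best={weighted_mean(dice_best):.4f}, final={weighted_mean(dice_final):.4f}")
```

Weighting by n matters here because T2w contributes 358 of the 569 test images, while the STIR difference (the only one in favour of checkpoint_final) rests on 11 images.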
For the large STU-Net model on reoriented images:

Commands used

```bash
export nnUNet_raw="/home/plbenveniste/net/ms-lesion-agnostic/nnunet_experiments/nnUNet_raw"
export nnUNet_preprocessed="/home/plbenveniste/net/ms-lesion-agnostic/nnunet_experiments/nnUNet_preprocessed"
export nnUNet_results="/home/plbenveniste/net/ms-lesion-agnostic/stunet_experiments/stunet_large_experiment/nnUNet_results"
```

Inference:

```bash
CUDA_VISIBLE_DEVICES=1 nnUNetv2_predict -i /home/plbenveniste/net/ms-lesion-agnostic/nnunet_experiments/nnUNet_raw/Dataset301_msLesionAgnostic/imagesTs/ -o /home/plbenveniste/net/ms-lesion-agnostic/stunet_experiments/stunet_large_experiment/nnUNet_results/Dataset301_msLesionAgnostic/STUNetTrainer_large_ft__nnUNetPlans__3d_fullres/fold_0/test_set -d 301 -c 3d_fullres -f 0 -chk checkpoint_best.pth -tr STUNetTrainer_large_ft
```

I had an error like this:

To evaluate the predictions and plot the performances:

```bash
python nnunet/evaluate_predictions.py -pred-folder /home/plbenveniste/net/ms-lesion-agnostic/stunet_experiments/stunet_large_experiment/nnUNet_results/Dataset301_msLesionAgnostic/STUNetTrainer_large_ft__nnUNetPlans__3d_fullres/fold_0/test_set/ -label-folder ~/net/ms-lesion-agnostic/nnunet_experiments/nnUNet_raw/Dataset301_msLesionAgnostic/labelsTs -image-folder ~/net/ms-lesion-agnostic/nnunet_experiments/nnUNet_raw/Dataset301_msLesionAgnostic/imagesTs/ -conversion-dict ~/net/ms-lesion-agnostic/nnunet_experiments/nnUNet_raw/Dataset301_msLesionAgnostic/conversion_dict.json -output-folder /home/plbenveniste/net/ms-lesion-agnostic/stunet_experiments/stunet_large_experiment/nnUNet_results/Dataset301_msLesionAgnostic/STUNetTrainer_large_ft__nnUNetPlans__3d_fullres/fold_0/test_set/
python nnunet/plot_performance.py --pred-dir-path /home/plbenveniste/net/ms-lesion-agnostic/stunet_experiments/stunet_large_experiment/nnUNet_results/Dataset301_msLesionAgnostic/STUNetTrainer_large_ft__nnUNetPlans__3d_fullres/fold_0/test_set/ --data-json-path /home/plbenveniste/net/ms-lesion-agnostic/msd_data/dataset_2024-07-24_seed42_lesionOnly.json --split test
```

Note: inference is VERY LONG! 2 min 50 s on GPU!

Results:
Dice score per contrast (mean ± std)
PSIR (n=60): 0.3573 ± 0.2525
STIR (n=11): 0.4134 ± 0.2358
T2star (n=83): 0.5556 ± 0.1885
T2w (n=358): 0.4916 ± 0.1973
UNIT1 (n=57): 0.6115 ± 0.1715

More results:
PPV score per contrast (mean ± std)
PSIR (n=60): 0.6142 ± 0.3703
STIR (n=11): 0.7152 ± 0.4100
T2star (n=83): 0.8238 ± 0.2744
T2w (n=358): 0.7939 ± 0.2863
UNIT1 (n=57): 0.8795 ± 0.2035
F1 score per contrast (mean ± std)
PSIR (n=60): 0.4880 ± 0.3315
STIR (n=11): 0.5916 ± 0.3737
T2star (n=83): 0.7851 ± 0.2562
T2w (n=358): 0.7499 ± 0.2657
UNIT1 (n=57): 0.8183 ± 0.2008
Sensitivity score per contrast (mean ± std)
PSIR (n=60): 0.4884 ± 0.3671
STIR (n=11): 0.5844 ± 0.4024
T2star (n=83): 0.7972 ± 0.2830
T2w (n=358): 0.7927 ± 0.2939
UNIT1 (n=57): 0.8153 ± 0.2461
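All the results above are printed as per-contrast `mean ± std` lines. A minimal sketch of how such a summary could be produced from per-image scores; `summarize_per_contrast` is a hypothetical helper, not the code of `nnunet/plot_performance.py`, and using the population standard deviation (rather than the sample one) is an assumption:

```python
import statistics
from collections import defaultdict


def summarize_per_contrast(scores, metric="Dice"):
    """Format '<contrast> (n=<count>): mean ± std' lines from (contrast, score) pairs."""
    groups = defaultdict(list)
    for contrast, score in scores:
        groups[contrast].append(score)
    lines = [f"{metric} score per contrast (mean ± std)"]
    for contrast in sorted(groups):
        vals = groups[contrast]
        # Population std (ddof=0); the real script may use a different convention
        std = statistics.pstdev(vals)
        lines.append(f"{contrast} (n={len(vals)}): {statistics.mean(vals):.4f} ± {std:.4f}")
    return "\n".join(lines)


# Toy per-image scores, made up for illustration (not results from these experiments)
toy = [("T2w", 0.5), ("T2w", 0.7), ("PSIR", 0.4)]
print(summarize_per_contrast(toy))
# Dice score per contrast (mean ± std)
# PSIR (n=1): 0.4000 ± 0.0000
# T2w (n=2): 0.6000 ± 0.1000
```

Sorting the contrast names alphabetically yields PSIR, STIR, T2star, T2w, UNIT1, the same order as in the outputs above.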
In this issue, I document the work done to segment MS lesions in the spinal cord using STU-Net.
I used the code from this repo: https://github.com/uni-medical/STU-Net
The dataset used is the nnUNet-preprocessed data stored in:
~/net/ms-lesion-agnostic/nnunet_experiments/nnUNet_raw/Dataset201_msLesionAgnostic/
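The dataset folder follows the nnUNet v2 naming scheme: `Dataset<3-digit ID>_<name>`, image files carrying a 4-digit channel suffix, and label files without it. A small sketch that builds the expected test image and label paths for one case; the case ID `msLesionAgnostic_001` and the `.nii.gz` ending are hypothetical examples, and `labelsTs` is not used by nnUNet itself but mirrors the folder passed to `evaluate_predictions.py` above:

```python
from pathlib import Path


def nnunet_test_paths(raw_root: str, dataset_id: int, dataset_name: str, case_id: str,
                      channel: int = 0, file_ending: str = ".nii.gz"):
    """Return (image, label) paths for one test case following nnUNet v2 naming."""
    dataset_dir = Path(raw_root) / f"Dataset{dataset_id:03d}_{dataset_name}"
    # Images carry a 4-digit channel suffix (_0000 for a single-channel input)
    image = dataset_dir / "imagesTs" / f"{case_id}_{channel:04d}{file_ending}"
    # Labels use the bare case ID
    label = dataset_dir / "labelsTs" / f"{case_id}{file_ending}"
    return image, label


# Hypothetical case ID, for illustration only
img, lbl = nnunet_test_paths("~/net/ms-lesion-agnostic/nnunet_experiments/nnUNet_raw",
                             201, "msLesionAgnostic", "msLesionAgnostic_001")
print(img)
print(lbl)
```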