Training of an STU-Net model for ms lesion segmentation #29

Open
plbenveniste opened this issue Aug 13, 2024 · 8 comments

In this issue, I document my work on segmenting MS lesions in the spinal cord using STU-Net.

I used the code from this repo: https://github.com/uni-medical/STU-Net

The dataset used is the nnUNet-preprocessed data stored in: ~/net/ms-lesion-agnostic/nnunet_experiments/nnUNet_raw/Dataset201_msLesionAgnostic/

@plbenveniste plbenveniste self-assigned this Aug 13, 2024

plbenveniste commented Aug 13, 2024

Here are the steps taken to train an STU-Net (the documentation of the repo is not up to date).

In the project folder:

  • Download the STU-Net repo: git clone https://github.com/uni-medical/STU-Net/
  • Download the nnUNet repo: git clone https://github.com/MIC-DKFZ/nnUNet/
  • Create a conda environment and activate it: conda create -n venv_stunet2 python=3.9, then conda activate venv_stunet2
  • Install PyTorch for nnUNet: conda install pytorch torchvision torchaudio pytorch-cuda=12.4 -c pytorch -c nvidia
  • Copy STU finetuning script to nnUNet: cp STU-Net/nnUNet-2.2/nnunetv2/run/run_finetuning_stunet.py nnUNet/nnunetv2/run/
  • Copy the trainer: cp STU-Net/nnUNet-2.2/nnunetv2/training/nnUNetTrainer/STUNetTrainer.py nnUNet/nnunetv2/training/nnUNetTrainer/
  • Install nnUNet: from inside the nnUNet folder, run pip install -e .
  • Download the pretrained model of your choice: in my case base_ep4k.model
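
The two copy steps above are easy to get wrong, so a quick check can confirm that the STU-Net files landed in the nnUNet tree. This is a minimal sketch, assuming both repos were cloned side by side in the project folder; `missing_stunet_files` is a hypothetical helper name, not part of either repo.

```python
from pathlib import Path

# Files copied from STU-Net into the nnUNet checkout in the steps above.
EXPECTED_FILES = [
    "nnUNet/nnunetv2/run/run_finetuning_stunet.py",
    "nnUNet/nnunetv2/training/nnUNetTrainer/STUNetTrainer.py",
]


def missing_stunet_files(project_dir):
    """Return the expected files that are absent under project_dir."""
    root = Path(project_dir)
    return [rel for rel in EXPECTED_FILES if not (root / rel).is_file()]


if __name__ == "__main__":
    missing = missing_stunet_files(".")
    if missing:
        print("Missing files:", *missing, sep="\n  ")
    else:
        print("All STU-Net files are in place.")
```

Run it from the project folder; an empty result means both files were copied.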

The trainer needs to be updated with the following code:

class STUNetTrainer(nnUNetTrainer):
    def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True,
                 device: torch.device = torch.device('cuda')):
        super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device)
        self.num_epochs = 1000
        self.initial_lr = 1e-2

    @staticmethod
    def build_network_architecture(plans_manager,
                                   dataset_json,
                                   configuration_manager,
                                   num_input_channels,
                                   enable_deep_supervision: bool = True) -> nn.Module:
        label_manager = plans_manager.get_label_manager(dataset_json)
        num_classes = label_manager.num_segmentation_heads
        kernel_sizes = [[3, 3, 3]] * 6
        # Skip the first entry of the plans and keep exactly five pooling strides.
        strides = configuration_manager.pool_op_kernel_sizes[1:]
        if len(strides) > 5:
            strides = strides[:5]
        while len(strides) < 5:
            strides.append([1, 1, 1])
        return STUNet(num_input_channels, num_classes, depth=[1] * 6, dims=[32 * x for x in [1, 2, 4, 8, 16, 16]],
                      pool_op_kernel_sizes=strides, conv_kernel_sizes=kernel_sizes, enable_deep_supervision=enable_deep_supervision)

    def initialize(self):
        if not self.was_initialized:
            self.num_input_channels = determine_num_input_channels(self.plans_manager, self.configuration_manager,
                                                                   self.dataset_json)

            self.network = self.build_network_architecture(self.plans_manager, self.dataset_json,
                                                           self.configuration_manager,
                                                           self.num_input_channels,
                                                           enable_deep_supervision=True).to(self.device)
            # compile network for free speedup
            if self._do_i_compile():
                self.print_to_log_file('Compiling network...')
                self.network = torch.compile(self.network)

            self.optimizer, self.lr_scheduler = self.configure_optimizers()
            # if ddp, wrap in DDP wrapper
            if self.is_ddp:
                self.network = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.network)
                self.network = DDP(self.network, device_ids=[self.local_rank])

            self.loss = self._build_loss()
            self.was_initialized = True
        else:
            raise RuntimeError("You have called self.initialize even though the trainer was already initialized. "
                               "That should not happen.")
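
The stride handling in build_network_architecture can be checked in isolation. The sketch below reproduces that logic as a standalone function (`normalize_strides` is a hypothetical name used only for illustration). It assumes the plans store one stride triplet per resolution stage, with the first entry skipped as in the trainer.

```python
def normalize_strides(pool_op_kernel_sizes, num_pools=5):
    """Return exactly num_pools stride triplets, as the trainer above does.

    The first entry of the plans is skipped, extra entries are dropped and
    missing ones are filled with [1, 1, 1] (no downsampling).
    """
    strides = [list(s) for s in pool_op_kernel_sizes[1:]]
    strides = strides[:num_pools]
    while len(strides) < num_pools:
        strides.append([1, 1, 1])
    return strides


# Channel widths passed to STUNet in the trainer above.
dims = [32 * x for x in [1, 2, 4, 8, 16, 16]]

if __name__ == "__main__":
    # Hypothetical plans with only four pooling steps after the first entry.
    plans = [[1, 1, 1], [1, 2, 2], [2, 2, 2], [2, 2, 2], [1, 2, 2]]
    print(normalize_strides(plans))
    print(dims)
```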

Then I exported the nnUNet environment variables (pointing to the folders where the dataset has been preprocessed):

export nnUNet_raw="/home/plbenveniste/net/ms-lesion-agnostic/stunet_experiments/second_exp/nnUNet_raw"
export nnUNet_results="/home/plbenveniste/net/ms-lesion-agnostic/stunet_experiments/second_exp/nnUNet_results"
export nnUNet_preprocessed="/home/plbenveniste/net/ms-lesion-agnostic/stunet_experiments/second_exp/nnUNet_preprocessed"
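
Since these variables must be set in the shell that launches training, it can help to verify them first. This is a minimal sketch using only the three variable names exported above; `missing_nnunet_vars` is a hypothetical helper, not an nnUNet function.

```python
import os

REQUIRED_VARS = ("nnUNet_raw", "nnUNet_preprocessed", "nnUNet_results")


def missing_nnunet_vars(env=None):
    """Return the nnUNet variables that are unset or empty in env."""
    env = os.environ if env is None else env
    return [name for name in REQUIRED_VARS if not env.get(name)]


if __name__ == "__main__":
    missing = missing_nnunet_vars()
    print("Missing:", missing if missing else "none")
```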

Then I launched the training (on koios):

CUDA_VISIBLE_DEVICES=0 python nnUNet/nnunetv2/run/run_finetuning_stunet.py 201 3d_fullres 1 -pretrained_weights /home/plbenveniste/net/ms-lesion-agnostic/stunet_experiments/stunet_pretrained/base_ep4k.model  -tr STUNetTrainer_base_ft

This issue was helpful: uni-medical/STU-Net#34

plbenveniste commented Aug 14, 2024

Before running inference, the file predict_from_raw_data.py needed to be copied from STU-Net/nnUNet-2.2/nnunetv2/inference/ to the nnUNet repo.

The inference on the test set was done using:

CUDA_VISIBLE_DEVICES=0 nnUNetv2_predict -i /home/plbenveniste/net/ms-lesion-agnostic/stunet_experiments/second_exp/nnUNet_raw/Dataset201_msLesionAgnostic/imagesTs/ -o /home/plbenveniste/net/ms-lesion-agnostic/stunet_experiments/second_exp/nnUNet_results/Dataset201_msLesionAgnostic/STUNetTrainer_base_ft__nnUNetPlans__3d_fullres/fold_1/test_set -d 201 -c 3d_fullres -f 1 -chk checkpoint_best.pth -tr STUNetTrainer_base_ft
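
The output folder in the command above follows the layout visible in these paths: `<nnUNet_results>/<dataset>/<trainer>__<plans>__<configuration>/fold_<n>`. Here is a minimal sketch that builds such a path, assuming the plans name `nnUNetPlans` seen above; `fold_dir` is a hypothetical helper, not an nnUNet function.

```python
from pathlib import PurePosixPath


def fold_dir(results_root, dataset, trainer, configuration, fold, plans="nnUNetPlans"):
    """Build the per-fold results folder following the layout used above."""
    model_dir = f"{trainer}__{plans}__{configuration}"
    return PurePosixPath(results_root) / dataset / model_dir / f"fold_{fold}"


if __name__ == "__main__":
    out = fold_dir(
        "/home/plbenveniste/net/ms-lesion-agnostic/stunet_experiments/second_exp/nnUNet_results",
        "Dataset201_msLesionAgnostic",
        "STUNetTrainer_base_ft",
        "3d_fullres",
        1,
    )
    print(out / "test_set")
```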

The results were computed (in the venv_nnunet environment) with:

python nnunet/evaluate_predictions.py -pred-folder /home/plbenveniste/net/ms-lesion-agnostic/stunet_experiments/second_exp/nnUNet_results/Dataset201_msLesionAgnostic/STUNetTrainer_base_ft__nnUNetPlans__3d_fullres/fold_1/test_set/ -label-folder ~/net/ms-lesion-agnostic/nnunet_experiments/nnUNet_raw/Dataset201_msLesionAgnostic/labelsTs  -image-folder ~/net/ms-lesion-agnostic/nnunet_experiments/nnUNet_raw/Dataset201_msLesionAgnostic/imagesTs/ -conversion-dict ~/net/ms-lesion-agnostic/nnunet_experiments/nnUNet_raw/Dataset201_msLesionAgnostic/conversion_dict.json -output-folder /home/plbenveniste/net/ms-lesion-agnostic/stunet_experiments/second_exp/nnUNet_results/Dataset201_msLesionAgnostic/STUNetTrainer_base_ft__nnUNetPlans__3d_fullres/fold_1/test_set

The plots were generated with:

python nnunet/plot_performance.py --pred-dir-path /home/plbenveniste/net/ms-lesion-agnostic/stunet_experiments/second_exp/nnUNet_results/Dataset201_msLesionAgnostic/STUNetTrainer_base_ft__nnUNetPlans__3d_fullres/fold_1/test_set/ --data-json-path /home/plbenveniste/net/ms-lesion-agnostic/msd_data/dataset_2024-07-24_seed42_lesionOnly.json --split test

plbenveniste commented Aug 15, 2024

Here are the results of the STU-Net fine-tuned from the base pretrained model:
[Figures: Dice scores per contrast, per orientation, and per site]

The results are quite similar to those of the trained nnUNet model.

TODO:

  • train a larger STU-Net to see how the performance evolves

plbenveniste commented Sep 25, 2024

We evaluated this model on the following metrics: Dice, F1 score, PPV, and sensitivity.
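
For reference, the four metrics can be written down from their standard definitions. The sketch below is not the code of evaluate_predictions.py. It assumes binary masks given as flat lists of 0/1 for Dice, and it assumes PPV, sensitivity and F1 are derived from counts of true positives, false positives and false negatives (whether those counts are per voxel or per lesion is a choice of the evaluation script). The values returned for empty denominators are also an assumption.

```python
def dice(pred, truth):
    """Dice = 2|P ∩ T| / (|P| + |T|) for two binary masks of equal length."""
    if len(pred) != len(truth):
        raise ValueError("masks must have the same length")
    intersection = sum(1 for p, t in zip(pred, truth) if p and t)
    total = sum(1 for p in pred if p) + sum(1 for t in truth if t)
    # Assumption: two empty masks count as perfect agreement.
    return 1.0 if total == 0 else 2.0 * intersection / total


def ppv(tp, fp):
    """Positive predictive value (precision) = TP / (TP + FP)."""
    return tp / (tp + fp) if (tp + fp) else 0.0


def sensitivity(tp, fn):
    """Sensitivity (recall) = TP / (TP + FN)."""
    return tp / (tp + fn) if (tp + fn) else 0.0


def f1(tp, fp, fn):
    """F1 = 2 TP / (2 TP + FP + FN), the harmonic mean of PPV and sensitivity."""
    denom = 2 * tp + fp + fn
    return 2 * tp / denom if denom else 0.0
```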

To evaluate the predictions:

python nnunet/evaluate_predictions.py -pred-folder /home/plbenveniste/net/ms-lesion-agnostic/stunet_experiments/second_exp/nnUNet_results/Dataset201_msLesionAgnostic/STUNetTrainer_base_ft__nnUNetPlans__3d_fullres/fold_1/test_set/ -label-folder ~/net/ms-lesion-agnostic/nnunet_experiments/nnUNet_raw/Dataset201_msLesionAgnostic/labelsTs  -image-folder ~/net/ms-lesion-agnostic/nnunet_experiments/nnUNet_raw/Dataset201_msLesionAgnostic/imagesTs/ -conversion-dict ~/net/ms-lesion-agnostic/nnunet_experiments/nnUNet_raw/Dataset201_msLesionAgnostic/conversion_dict.json -output-folder /home/plbenveniste/net/ms-lesion-agnostic/stunet_experiments/second_exp/nnUNet_results/Dataset201_msLesionAgnostic/STUNetTrainer_base_ft__nnUNetPlans__3d_fullres/fold_1/test_set_2

To plot the results:

python nnunet/plot_performance.py --pred-dir-path /home/plbenveniste/net/ms-lesion-agnostic/stunet_experiments/second_exp/nnUNet_results/Dataset201_msLesionAgnostic/STUNetTrainer_base_ft__nnUNetPlans__3d_fullres/fold_1/test_set_2/ --data-json-path /home/plbenveniste/net/ms-lesion-agnostic/msd_data/dataset_2024-07-24_seed42_lesionOnly.json --split test

Here are the results:

[Figures: Dice, F1, PPV, and sensitivity scores per contrast]

plbenveniste commented Oct 9, 2024

To try and improve performance, I trained a second model using the larger pretrained model from STUNet: large_ep4k.model

To do so, I reproduced the same steps as above to install and set up everything. It is stored in /home/plbenveniste/net/ms-lesion-agnostic/stunet_experiments/stunet_large_experiment.

I did the following steps:

  • copied the 201 raw and preprocessed dataset from nnUNet
  • then I exported the paths:
export nnUNet_raw="/home/plbenveniste/net/ms-lesion-agnostic/stunet_experiments/stunet_large_experiment/nnUNet_raw"
export nnUNet_results="/home/plbenveniste/net/ms-lesion-agnostic/stunet_experiments/stunet_large_experiment/nnUNet_results"
export nnUNet_preprocessed="/home/plbenveniste/net/ms-lesion-agnostic/stunet_experiments/stunet_large_experiment/nnUNet_preprocessed"
  • finally I ran training:
CUDA_VISIBLE_DEVICES=0 python nnUNet/nnunetv2/run/run_finetuning_stunet.py 201 3d_fullres 1 -pretrained_weights /home/plbenveniste/net/ms-lesion-agnostic/stunet_experiments/stunet_pretrained/large_ep4k.model  -tr STUNetTrainer_large_ft

One epoch takes 250 seconds so the training is going to be a bit long.
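
To put a rough number on this, the total training time is the epoch duration multiplied by the number of epochs. The sketch below assumes 1000 epochs (the value of num_epochs in the trainer code shown earlier; the large fine-tuning trainer may use a different value) and ignores any overhead outside the epochs themselves.

```python
def training_hours(seconds_per_epoch, num_epochs):
    """Estimate the total training time in hours."""
    return seconds_per_epoch * num_epochs / 3600


if __name__ == "__main__":
    # 250 s per epoch is the duration reported above; 1000 epochs is assumed.
    hours = training_hours(250, 1000)
    print(f"{hours:.1f} hours, i.e. about {hours / 24:.1f} days")
```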

plbenveniste commented Oct 12, 2024

Here is the training output of the model:

[Figure: training output]

To train the model fully, I modified the trainer to run for 2000 epochs. The training command is unchanged; only the number of epochs was changed, in the file nnUNet/nnunetv2/training/nnUNetTrainer/STUNetTrainer.py at line 131.

It is currently training.

plbenveniste commented Oct 17, 2024

Because the performance of the nnUNet improved when training on re-oriented images, I did the same for the base STUNet model. For that, I used the 301 nnUNet dataset.

I exported the paths to the nnUNet re-oriented data and to the STUNet results folder:

export nnUNet_raw="/home/plbenveniste/net/ms-lesion-agnostic/nnunet_experiments/nnUNet_raw"
export nnUNet_preprocessed="/home/plbenveniste/net/ms-lesion-agnostic/nnunet_experiments/nnUNet_preprocessed"
export nnUNet_results="/home/plbenveniste/net/ms-lesion-agnostic/stunet_experiments/second_exp/nnUNet_results"

I ran training using:

CUDA_VISIBLE_DEVICES=1 python nnUNet/nnunetv2/run/run_finetuning_stunet.py 301 3d_fullres 0 -pretrained_weights /home/plbenveniste/net/ms-lesion-agnostic/stunet_experiments/stunet_pretrained/base_ep4k.model  -tr STUNetTrainer_base_ft

Just by looking at the training curve, we can see an improvement in both the final Dice reached and the training stability:

[Animated figure: training curves]

Predictions with checkpoint_best

Here is the command I ran to predict on the test set:

CUDA_VISIBLE_DEVICES=1 nnUNetv2_predict -i /home/plbenveniste/net/ms-lesion-agnostic/nnunet_experiments/nnUNet_raw/Dataset301_msLesionAgnostic/imagesTs/ -o /home/plbenveniste/net/ms-lesion-agnostic/stunet_experiments/second_exp/nnUNet_results/Dataset301_msLesionAgnostic/STUNetTrainer_base_ft__nnUNetPlans__3d_fullres/fold_0/test_set -d 301 -c 3d_fullres -f 0 -chk checkpoint_best.pth -tr STUNetTrainer_base_ft

To evaluate the predictions (with venv_nnunet) I ran:

python nnunet/evaluate_predictions.py -pred-folder /home/plbenveniste/net/ms-lesion-agnostic/stunet_experiments/second_exp/nnUNet_results/Dataset301_msLesionAgnostic/STUNetTrainer_base_ft__nnUNetPlans__3d_fullres/fold_0/test_set/ -label-folder ~/net/ms-lesion-agnostic/nnunet_experiments/nnUNet_raw/Dataset301_msLesionAgnostic/labelsTs  -image-folder ~/net/ms-lesion-agnostic/nnunet_experiments/nnUNet_raw/Dataset301_msLesionAgnostic/imagesTs/ -conversion-dict ~/net/ms-lesion-agnostic/nnunet_experiments/nnUNet_raw/Dataset301_msLesionAgnostic/conversion_dict.json -output-folder /home/plbenveniste/net/ms-lesion-agnostic/stunet_experiments/second_exp/nnUNet_results/Dataset301_msLesionAgnostic/STUNetTrainer_base_ft__nnUNetPlans__3d_fullres/fold_0/test_set

To plot the performances:

python nnunet/plot_performance.py --pred-dir-path /home/plbenveniste/net/ms-lesion-agnostic/stunet_experiments/second_exp/nnUNet_results/Dataset301_msLesionAgnostic/STUNetTrainer_base_ft__nnUNetPlans__3d_fullres/fold_0/test_set/ --data-json-path /home/plbenveniste/net/ms-lesion-agnostic/msd_data/dataset_2024-07-24_seed42_lesionOnly.json --split test

Output:

Dice score per contrast (mean ± std)
PSIR (n=60): 0.3698 ± 0.2441
STIR (n=11): 0.4181 ± 0.2355
T2star (n=83): 0.5220 ± 0.1952
T2w (n=358): 0.4867 ± 0.1790
UNIT1 (n=57): 0.6030 ± 0.1588

[Figure: Dice scores per contrast]

Other metrics
PPV score per contrast (mean ± std)
PSIR (n=60): 0.6265 ± 0.3512
STIR (n=11): 0.6615 ± 0.3925
T2star (n=83): 0.8395 ± 0.2638
T2w (n=358): 0.7592 ± 0.2898
UNIT1 (n=57): 0.8412 ± 0.2318

F1 score per contrast (mean ± std)
PSIR (n=60): 0.5346 ± 0.3251
STIR (n=11): 0.6459 ± 0.3729
T2star (n=83): 0.7936 ± 0.2358
T2w (n=358): 0.7483 ± 0.2568
UNIT1 (n=57): 0.8065 ± 0.2006

Sensitivity score per contrast (mean ± std)
PSIR (n=60): 0.5568 ± 0.3609
STIR (n=11): 0.7018 ± 0.3605
T2star (n=83): 0.8044 ± 0.2635
T2w (n=358): 0.8240 ± 0.2730
UNIT1 (n=57): 0.8330 ± 0.2337

[Figures: sensitivity, F1, and PPV scores per contrast]

Predictions with checkpoint_final

I also ran the predictions with the checkpoint_final model to see which checkpoint performs better:

CUDA_VISIBLE_DEVICES=1 nnUNetv2_predict -i /home/plbenveniste/net/ms-lesion-agnostic/nnunet_experiments/nnUNet_raw/Dataset301_msLesionAgnostic/imagesTs/ -o /home/plbenveniste/net/ms-lesion-agnostic/stunet_experiments/second_exp/nnUNet_results/Dataset301_msLesionAgnostic/STUNetTrainer_base_ft__nnUNetPlans__3d_fullres/fold_0/test_set_checkpt_final -d 301 -c 3d_fullres -f 0 -chk checkpoint_final.pth -tr STUNetTrainer_base_ft

To evaluate the predictions:

python nnunet/evaluate_predictions.py -pred-folder /home/plbenveniste/net/ms-lesion-agnostic/stunet_experiments/second_exp/nnUNet_results/Dataset301_msLesionAgnostic/STUNetTrainer_base_ft__nnUNetPlans__3d_fullres/fold_0/test_set_checkpt_final -label-folder ~/net/ms-lesion-agnostic/nnunet_experiments/nnUNet_raw/Dataset301_msLesionAgnostic/labelsTs  -image-folder ~/net/ms-lesion-agnostic/nnunet_experiments/nnUNet_raw/Dataset301_msLesionAgnostic/imagesTs/ -conversion-dict ~/net/ms-lesion-agnostic/nnunet_experiments/nnUNet_raw/Dataset301_msLesionAgnostic/conversion_dict.json -output-folder /home/plbenveniste/net/ms-lesion-agnostic/stunet_experiments/second_exp/nnUNet_results/Dataset301_msLesionAgnostic/STUNetTrainer_base_ft__nnUNetPlans__3d_fullres/fold_0/test_set_checkpt_final

To plot the performances:

python nnunet/plot_performance.py --pred-dir-path /home/plbenveniste/net/ms-lesion-agnostic/stunet_experiments/second_exp/nnUNet_results/Dataset301_msLesionAgnostic/STUNetTrainer_base_ft__nnUNetPlans__3d_fullres/fold_0/test_set_checkpt_final/ --data-json-path /home/plbenveniste/net/ms-lesion-agnostic/msd_data/dataset_2024-07-24_seed42_lesionOnly.json --split test

Output:

Dice score per contrast (mean ± std)
PSIR (n=60): 0.3609 ± 0.2463
STIR (n=11): 0.4212 ± 0.2330
T2star (n=83): 0.5117 ± 0.1980
T2w (n=358): 0.4825 ± 0.1797
UNIT1 (n=57): 0.6013 ± 0.1642

[Figure: Dice scores per contrast]

Other metrics:
PPV score per contrast (mean ± std)
PSIR (n=60): 0.6050 ± 0.3617
STIR (n=11): 0.6718 ± 0.3886
T2star (n=83): 0.8469 ± 0.2679
T2w (n=358): 0.7725 ± 0.2865
UNIT1 (n=57): 0.8593 ± 0.2213

F1 score per contrast (mean ± std)
PSIR (n=60): 0.5193 ± 0.3222
STIR (n=11): 0.6225 ± 0.3459
T2star (n=83): 0.7789 ± 0.2492
T2w (n=358): 0.7554 ± 0.2541
UNIT1 (n=57): 0.8058 ± 0.1996

Sensitivity score per contrast (mean ± std)
PSIR (n=60): 0.5533 ± 0.3646
STIR (n=11): 0.6499 ± 0.3456
T2star (n=83): 0.7734 ± 0.2751
T2w (n=358): 0.8237 ± 0.2701
UNIT1 (n=57): 0.8199 ± 0.2417

[Figures: PPV, F1, and sensitivity scores per contrast]

→ In my case, checkpoint_best performs slightly better: its mean Dice is higher on 4 of the 5 contrasts (all except STIR).
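
This conclusion can be checked against the Dice values printed above. The sketch below copies the per-contrast means and sample counts from the two outputs and computes where one checkpoint beats the other, plus the sample-weighted mean. It only re-uses numbers already reported here.

```python
# (n, mean Dice) per contrast, copied from the two outputs above.
BEST = {"PSIR": (60, 0.3698), "STIR": (11, 0.4181), "T2star": (83, 0.5220),
        "T2w": (358, 0.4867), "UNIT1": (57, 0.6030)}
FINAL = {"PSIR": (60, 0.3609), "STIR": (11, 0.4212), "T2star": (83, 0.5117),
         "T2w": (358, 0.4825), "UNIT1": (57, 0.6013)}


def weighted_mean(scores):
    """Mean Dice over all images, weighting each contrast by its sample count."""
    total = sum(n for n, _ in scores.values())
    return sum(n * mean for n, mean in scores.values()) / total


def contrasts_where_first_wins(first, second):
    """Return the contrasts where `first` has the higher mean Dice."""
    return [c for c in first if first[c][1] > second[c][1]]


if __name__ == "__main__":
    print("best wins on:", contrasts_where_first_wins(BEST, FINAL))
    print(f"weighted mean Dice: best={weighted_mean(BEST):.4f}, "
          f"final={weighted_mean(FINAL):.4f}")
```

The differences are small compared with the reported standard deviations, so this is a descriptive comparison rather than a significance test.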

plbenveniste commented Nov 6, 2024

For the large STUNet model on reoriented images:

Commands used
export nnUNet_raw="/home/plbenveniste/net/ms-lesion-agnostic/nnunet_experiments/nnUNet_raw"
export nnUNet_preprocessed="/home/plbenveniste/net/ms-lesion-agnostic/nnunet_experiments/nnUNet_preprocessed"
export nnUNet_results="/home/plbenveniste/net/ms-lesion-agnostic/stunet_experiments/stunet_large_experiment/nnUNet_results"

Inference

CUDA_VISIBLE_DEVICES=1 nnUNetv2_predict -i /home/plbenveniste/net/ms-lesion-agnostic/nnunet_experiments/nnUNet_raw/Dataset301_msLesionAgnostic/imagesTs/ -o /home/plbenveniste/net/ms-lesion-agnostic/stunet_experiments/stunet_large_experiment/nnUNet_results/Dataset301_msLesionAgnostic/STUNetTrainer_large_ft__nnUNetPlans__3d_fullres/fold_0/test_set  -d 301 -c 3d_fullres -f 0 -chk checkpoint_best.pth -tr STUNetTrainer_large_ft

I first got this error: TypeError: build_network_architecture() got multiple values for argument 'enable_deep_supervision'. It was caused by the fact that I had forgotten to update the file predict_from_raw_data.py with the version from the STU-Net repo.
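
This kind of TypeError appears when a caller passes more positional arguments than the function expects, so that one of them lands on a parameter that is also given by keyword. The sketch below is a minimal reproduction with a stand-in function that has the same parameters as the trainer's build_network_architecture. The call with five positional arguments is an assumption about how a mismatched predict_from_raw_data.py might invoke it, not a copy of nnUNet code.

```python
def build_network_architecture(plans_manager, dataset_json, configuration_manager,
                               num_input_channels, enable_deep_supervision=True):
    """Stand-in with the same parameters as the STU-Net trainer method."""
    return "network"


def call_with_extra_positional():
    """Call the stand-in with five positional arguments plus the keyword."""
    try:
        # Hypothetical caller: the fifth positional value is bound to
        # enable_deep_supervision, which is then passed again by keyword.
        build_network_architecture("a", "b", "c", 1, 2, enable_deep_supervision=False)
    except TypeError as err:
        return str(err)
    return None


if __name__ == "__main__":
    print(call_with_extra_positional())
```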

python nnunet/evaluate_predictions.py -pred-folder /home/plbenveniste/net/ms-lesion-agnostic/stunet_experiments/stunet_large_experiment/nnUNet_results/Dataset301_msLesionAgnostic/STUNetTrainer_large_ft__nnUNetPlans__3d_fullres/fold_0/test_set/ -label-folder ~/net/ms-lesion-agnostic/nnunet_experiments/nnUNet_raw/Dataset301_msLesionAgnostic/labelsTs  -image-folder ~/net/ms-lesion-agnostic/nnunet_experiments/nnUNet_raw/Dataset301_msLesionAgnostic/imagesTs/ -conversion-dict ~/net/ms-lesion-agnostic/nnunet_experiments/nnUNet_raw/Dataset301_msLesionAgnostic/conversion_dict.json -output-folder /home/plbenveniste/net/ms-lesion-agnostic/stunet_experiments/stunet_large_experiment/nnUNet_results/Dataset301_msLesionAgnostic/STUNetTrainer_large_ft__nnUNetPlans__3d_fullres/fold_0/test_set/
python nnunet/plot_performance.py --pred-dir-path /home/plbenveniste/net/ms-lesion-agnostic/stunet_experiments/stunet_large_experiment/nnUNet_results/Dataset301_msLesionAgnostic/STUNetTrainer_large_ft__nnUNetPlans__3d_fullres/fold_0/test_set/ --data-json-path /home/plbenveniste/net/ms-lesion-agnostic/msd_data/dataset_2024-07-24_seed42_lesionOnly.json --split test

Note

Inference is VERY slow: 2 min 50 s on GPU!

Results:

Dice score per contrast (mean ± std)
PSIR (n=60): 0.3573 ± 0.2525
STIR (n=11): 0.4134 ± 0.2358
T2star (n=83): 0.5556 ± 0.1885
T2w (n=358): 0.4916 ± 0.1973
UNIT1 (n=57): 0.6115 ± 0.1715
More results
PPV score per contrast (mean ± std)
PSIR (n=60): 0.6142 ± 0.3703
STIR (n=11): 0.7152 ± 0.4100
T2star (n=83): 0.8238 ± 0.2744
T2w (n=358): 0.7939 ± 0.2863
UNIT1 (n=57): 0.8795 ± 0.2035

F1 score per contrast (mean ± std)
PSIR (n=60): 0.4880 ± 0.3315
STIR (n=11): 0.5916 ± 0.3737
T2star (n=83): 0.7851 ± 0.2562
T2w (n=358): 0.7499 ± 0.2657
UNIT1 (n=57): 0.8183 ± 0.2008

Sensitivity score per contrast (mean ± std)
PSIR (n=60): 0.4884 ± 0.3671
STIR (n=11): 0.5844 ± 0.4024
T2star (n=83): 0.7972 ± 0.2830
T2w (n=358): 0.7927 ± 0.2939
UNIT1 (n=57): 0.8153 ± 0.2461
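
To compare runs without copying numbers by hand, the printed lines can be parsed into structured data. This is a minimal sketch, assuming every result line has the form `<contrast> (n=<count>): <mean> ± <std>` as in the outputs above; `parse_scores` is a hypothetical helper, not part of the evaluation scripts.

```python
import re

LINE_RE = re.compile(r"^(\w+) \(n=(\d+)\): ([0-9.]+) ± ([0-9.]+)$")


def parse_scores(text):
    """Map each contrast to (n, mean, std); lines that do not match are skipped."""
    scores = {}
    for line in text.splitlines():
        match = LINE_RE.match(line.strip())
        if match:
            name, n, mean, std = match.groups()
            scores[name] = (int(n), float(mean), float(std))
    return scores


if __name__ == "__main__":
    sample = """Dice score per contrast (mean ± std)
PSIR (n=60): 0.3573 ± 0.2525
UNIT1 (n=57): 0.6115 ± 0.1715"""
    print(parse_scores(sample))
```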
