diff --git a/.gitignore b/.gitignore index 0d46c1d..d50afa6 100644 --- a/.gitignore +++ b/.gitignore @@ -173,3 +173,6 @@ cython_debug/ # and can be added to the global gitignore or merged into this file. For a more nuclear # option (not recommended) you can uncomment the following to ignore the entire idea folder. #.idea/ + +# Extra files +/totalspineseg/models/nnUNet diff --git a/README.md b/README.md index c41b861..56a207a 100644 --- a/README.md +++ b/README.md @@ -33,9 +33,9 @@ TotalSpineSeg uses a hybrid approach that integrates nnU-Net with an iterative a For comparison, we also trained a single model (`Dataset103`) that outputs individual label values for each vertebra and IVD in a single step. -![Figure 1](https://github.com/user-attachments/assets/7b82d6b8-d584-47ef-8504-fe06962bb82e) +![Figure 1](https://github.com/user-attachments/assets/9017fb8e-bed5-413f-a80f-b123a97f5735) -**Figure 1**: Illustration of the hybrid method for automatic segmentation of spinal structures. (A) MRI image used to train the Step 1 model. (B) The Step 1 model outputs nine classes. (C) Individual IVDs extracted from the output labels. (D) Odd IVDs extracted from the individual IVDs. (E) MRI image and odd IVDs used as inputs to train the Step 2 model, which outputs ten classes. (F) Final segmentation with individual labels for each vertebra and IVD. +**Figure 1**: Illustration of the hybrid method for automatic segmentation of spinal structures. (A) Input MRI image. (B) Step 1 model prediction. (C) Odd IVDs extraction from the Step1 prediction. (D) Step 2 model prediction. (E) Final segmentation with individual labels for each vertebra and IVD. ## Datasets @@ -62,44 +62,49 @@ When not available, sacrum segmentations were generated using the [totalsegmenta 1. Open a `bash` terminal in the directory where you want to work. -1. Create the installation directory: - ```bash - mkdir TotalSpineSeg - cd TotalSpineSeg - ``` +2. Create the installation directory: +```bash +mkdir TotalSpineSeg +cd TotalSpineSeg +``` -1. Create and activate a virtual environment (highly recommended): +3. Create and activate a virtual environment using one of the following options (highly recommended): + - venv ```bash python3 -m venv venv source venv/bin/activate ``` + - conda env + ``` + conda create -n myenv python=3.9 + conda activate myenv + ``` -1. Clone and install this repository: +4. Install this repository using one of the following options: + - Git clone (for developpers) + > **Note:** If you pull a new version from GitHub, make sure to rerun this command with the flag `--upgrade` ```bash git clone https://github.com/neuropoly/totalspineseg.git python3 -m pip install -e totalspineseg ``` - -1. For CUDA GPU support, install **PyTorch** following the instructions on their [website](https://pytorch.org/). Be sure to add the `--upgrade` flag to your installation command to replace any existing PyTorch installation. - Example: - ```bash - python3 -m pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118 --upgrade - ``` - -1. Set the path to TotalSpineSeg and data folders in the virtual environment: - ```bash - mkdir data - export TOTALSPINESEG="$(realpath totalspineseg)" - export TOTALSPINESEG_DATA="$(realpath data)" - echo "export TOTALSPINESEG=\"$TOTALSPINESEG\"" >> venv/bin/activate - echo "export TOTALSPINESEG_DATA=\"$TOTALSPINESEG_DATA\"" >> venv/bin/activate + - PyPI installation (for inference only) + ``` + python3 -m pip install totalspineseg ``` -**Note:** If you pull a new version from GitHub, make sure to reinstall the package to apply the updates using the following command: +5. For CUDA GPU support, install **PyTorch** following the instructions on their [website](https://pytorch.org/). Be sure to add the `--upgrade` flag to your installation command to replace any existing PyTorch installation. + Example: ```bash -python3 -m pip install -e $TOTALSPINESEG --upgrade +python3 -m pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118 --upgrade ``` +6. **OPTIONAL STEP:** Define a folder where weights will be stored: +> By default, weights will be stored in the package under `totalspineseg/models` + ```bash + mkdir data + export TOTALSPINESEG_DATA="$(realpath data)" + ``` + ## Training To train the TotalSpineSeg model, you will need the following hardware specifications: @@ -109,38 +114,47 @@ To train the TotalSpineSeg model, you will need the following hardware specifica Please ensure that your system meets these requirements before proceeding with the training process. -1. Make sure that the `bash` terminal is opened with the virtual environment (if used) activated (using `source /venv/bin/activate`). +1. Make sure that the `bash` terminal is opened with the virtual environment activated (see [Installation](#installation)). -1. Ensure training dependencies are installed: - ```bash - apt-get install git git-annex jq -y - ``` +2. Ensure training dependencies are installed: +```bash +apt-get install git git-annex jq -y +``` -1. Download the required datasets into `$TOTALSPINESEG_DATA/bids` (make sure you have access to the specified repositories): - ```bash - bash "$TOTALSPINESEG"/scripts/download_datasets.sh - ``` +3. Set the path to TotalSpineSeg and data folders in the virtual environment: +```bash +mkdir data +export TOTALSPINESEG="$(realpath totalspineseg)" +export TOTALSPINESEG_DATA="$(realpath data)" +echo "export TOTALSPINESEG=\"$TOTALSPINESEG\"" >> venv/bin/activate +echo "export TOTALSPINESEG_DATA=\"$TOTALSPINESEG_DATA\"" >> venv/bin/activate +``` -1. Temporary step (until all labels are pushed into the repositories) - Download labels into `$TOTALSPINESEG_DATA/bids`: - ```bash - curl -L -O https://github.com/neuropoly/totalspineseg/releases/download/labels/labels_iso_bids_0924.zip - unzip -qo labels_iso_bids_0924.zip -d "$TOTALSPINESEG_DATA" - rm labels_iso_bids_0924.zip - ``` +4. Download the required datasets into `$TOTALSPINESEG_DATA/bids` (make sure you have access to the specified repositories): +```bash +bash "$TOTALSPINESEG"/scripts/download_datasets.sh +``` -1. Prepare datasets in nnUNetv2 structure into `$TOTALSPINESEG_DATA/nnUnet`: - ```bash - bash "$TOTALSPINESEG"/scripts/prepare_datasets.sh [DATASET_ID] [-noaug] - ``` +5. Temporary step (until all labels are pushed into the repositories) - Download labels into `$TOTALSPINESEG_DATA/bids`: +```bash +curl -L -O https://github.com/neuropoly/totalspineseg/releases/download/labels/labels_iso_bids_0924.zip +unzip -qo labels_iso_bids_0924.zip -d "$TOTALSPINESEG_DATA" +rm labels_iso_bids_0924.zip +``` + +6. Prepare datasets in nnUNetv2 structure into `$TOTALSPINESEG_DATA/nnUnet`: +```bash +bash "$TOTALSPINESEG"/scripts/prepare_datasets.sh [DATASET_ID] [-noaug] +``` The script optionally accepts `DATASET_ID` as the first positional argument to specify the dataset to prepare. It can be either 101, 102, 103, or all. If `all` is specified, it will prepare all datasets (101, 102, 103). By default, it will prepare datasets 101 and 102. Additionally, you can use the `-noaug` parameter to prepare the datasets without data augmentations. -1. Train the model: - ```bash - bash "$TOTALSPINESEG"/scripts/train.sh [DATASET_ID [FOLD]] - ``` +7. Train the model: +```bash +bash "$TOTALSPINESEG"/scripts/train.sh [DATASET_ID [FOLD]] +``` The script optionally accepts `DATASET_ID` as the first positional argument to specify the dataset to train. It can be either 101, 102, 103, or all. If `all` is specified, it will train all datasets (101, 102, 103). By default, it will train datasets 101 and 102. @@ -148,23 +162,24 @@ Please ensure that your system meets these requirements before proceeding with t ## Inference -1. Make sure that the `bash` terminal is opened with the virtual environment (if used) activated (using `source /venv/bin/activate`). +1. Make sure that the `bash` terminal is opened with the virtual environment activated (see [Installation](#installation)). -1. Run the model on a folder containing the images in .nii.gz format, or on a single .nii.gz file: - ```bash - totalspineseg INPUT OUTPUT_FOLDER [--step1] [--iso] - ``` +2. Run the model on a folder containing the images in .nii.gz format, or on a single .nii.gz file: +> If you haven't trained the model, the script will automatically download the pre-trained models from the GitHub release. +```bash +totalspineseg INPUT OUTPUT_FOLDER [--step1] [--iso] +``` - This will process the images in INPUT or the single image and save the results in OUTPUT_FOLDER. If you haven't trained the model, the script will automatically download the pre-trained models from the GitHub release. + This will process the images in INPUT or the single image and save the results in OUTPUT_FOLDER. **Important Note:** By default, the output segmentations are resampled back to the input image space. If you prefer to obtain the outputs in the model's original 1mm isotropic resolution, especially useful for visualization purposes, we strongly recommend using the `--iso` argument. Additionally, you can use the `--step1` parameter to run only the step 1 model, which outputs a single label for all vertebrae, including the sacrum. For more options, you can use the `--help` parameter: - ```bash - totalspineseg --help - ``` +```bash +totalspineseg --help +``` **Output Data Structure:** diff --git a/pyproject.toml b/pyproject.toml index 7b2a300..ae9a099 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,17 +1,13 @@ [project] name = "totalspineseg" -version = "20241005" +version = "20241129" requires-python = ">=3.9" description = "TotalSpineSeg is a tool for automatic instance segmentation and labeling of all vertebrae, intervertebral discs (IVDs), spinal cord, and spinal canal in MRI images." readme = "README.md" authors = [ { name = "Yehuda Warszawer", email = "yehuda.warszawer@sheba.health.gov.il"}, - { name = "Nathan Molinier"}, - { name = "Jan Valosek"}, - { name = "Emanuel Shirbint"}, - { name = "Pierre-Louis Benveniste"}, + { name = "Nathan Molinier", email = "nathan.molinier@polymtl.ca"}, { name = "Anat Achiron"}, - { name = "Arman Eshaghi"}, { name = "Julien Cohen-Adad"}, ] classifiers = [ @@ -21,9 +17,6 @@ classifiers = [ "Topic :: Scientific/Engineering :: Artificial Intelligence", "Topic :: Scientific/Engineering :: Image Recognition", "Topic :: Scientific/Engineering :: Medical Science Apps.", - "Topic :: Scientific/Engineering :: MRI Images.", - "Topic :: Scientific/Engineering :: Spinal Cord.", - "Topic :: Scientific/Engineering :: Spine.", ] keywords = [ 'deep learning', @@ -53,18 +46,19 @@ dependencies = [ # https://github.com/MIC-DKFZ/nnUNet/issues/2480 # --verify_dataset_integrity not working in nnunetv2==2.4.2 do we need to update this when fixed # https://github.com/MIC-DKFZ/nnUNet/issues/2144 - "nnunetv2==2.4.2", + "nnunetv2<=2.4.2", "psutil", ] [project.urls] homepage = "https://github.com/neuropoly/totalspineseg" repository = "https://github.com/neuropoly/totalspineseg" -Dataset101_TotalSpineSeg_step1 = "https://github.com/neuropoly/totalspineseg/releases/download/r20241005/Dataset101_TotalSpineSeg_step1_r20241005.zip" -Dataset102_TotalSpineSeg_step2 = "https://github.com/neuropoly/totalspineseg/releases/download/r20241005/Dataset102_TotalSpineSeg_step2_r20241005.zip" +Dataset101_TotalSpineSeg_step1 = "https://github.com/neuropoly/totalspineseg/releases/download/r20241115/Dataset101_TotalSpineSeg_step1_r20241115.zip" +Dataset102_TotalSpineSeg_step2 = "https://github.com/neuropoly/totalspineseg/releases/download/r20241115/Dataset102_TotalSpineSeg_step2_r20241115.zip" [project.scripts] totalspineseg = "totalspineseg.inference:main" +totalspineseg_init = "totalspineseg.init_inference:main" totalspineseg_cpdir = "totalspineseg.utils.cpdir:main" totalspineseg_fill_canal = "totalspineseg.utils.fill_canal:main" totalspineseg_augment = "totalspineseg.utils.augment:main" @@ -81,6 +75,7 @@ totalspineseg_extract_soft = "totalspineseg.utils.extract_soft:main" totalspineseg_extract_levels = "totalspineseg.utils.extract_levels:main" totalspineseg_extract_alternate = "totalspineseg.utils.extract_alternate:main" totalspineseg_install_weights = "totalspineseg.utils.install_weights:main" +totalspineseg_predict_nnunet = "totalspineseg.utils.predict_nnunet:main" [build-system] requires = ["pip>=23", "setuptools>=67"] @@ -90,4 +85,4 @@ build-backend = "setuptools.build_meta" include-package-data = true [tool.setuptools.package-data] -'totalspineseg' = ['resources/**.json'] \ No newline at end of file +'totalspineseg' = ['resources/**.json'] diff --git a/scripts/train.sh b/scripts/train.sh index 234d059..4c9f3da 100644 --- a/scripts/train.sh +++ b/scripts/train.sh @@ -63,8 +63,12 @@ export nnUNet_results="$TOTALSPINESEG_DATA"/nnUNet/results export nnUNet_exports="$TOTALSPINESEG_DATA"/nnUNet/exports nnUNetTrainer=${3:-nnUNetTrainer_DASegOrd0_NoMirroring} -nnUNetPlanner=${4:-nnUNetPlannerResEncL} -nnUNetPlans=${5:-nnUNetResEncUNetLPlans} +nnUNetPlanner=${4:-ExperimentPlanner} +# Note on nnUNetPlans_small configuration: +# To train with a small patch size, verify that the nnUNetPlans_small.json file +# in $nnUNet_preprocessed/Dataset10[1,2]_TotalSpineSeg_step[1,2] matches the version provided in the release. +# Make any necessary updates to this file before starting the training process. +nnUNetPlans=${5:-nnUNetPlans_small} configuration=3d_fullres data_identifier=nnUNetPlans_3d_fullres diff --git a/totalspineseg/__init__.py b/totalspineseg/__init__.py index b4c49f6..e371e3b 100644 --- a/totalspineseg/__init__.py +++ b/totalspineseg/__init__.py @@ -13,4 +13,9 @@ from .utils.reorient_canonical import reorient_canonical_mp from .utils.resample import resample, resample_mp from .utils.transform_seg2image import transform_seg2image, transform_seg2image_mp -from .utils.install_weights import install_weights \ No newline at end of file +from .utils.install_weights import install_weights +from .utils.predict_nnunet import predict_nnunet +from .utils.utils import ZIP_URLS, VERSION +from . import models + +__version__ = VERSION \ No newline at end of file diff --git a/totalspineseg/inference.py b/totalspineseg/inference.py index 84dc2f9..4a08b18 100644 --- a/totalspineseg/inference.py +++ b/totalspineseg/inference.py @@ -1,9 +1,10 @@ -import os, argparse, warnings, subprocess, textwrap, torch, psutil, shutil +import os, argparse, warnings, textwrap, torch, psutil, shutil from fnmatch import fnmatch from pathlib import Path -from importlib.metadata import metadata +import importlib.resources from tqdm import tqdm from totalspineseg import * +from totalspineseg.init_inference import init_inference warnings.filterwarnings("ignore") @@ -61,10 +62,9 @@ def main(): help='Run only step 1 of the inference process.' ) parser.add_argument( - '--data-dir', '-d', type=Path, default=Path(os.environ.get('TOTALSPINESEG_DATA', '')), required='TOTALSPINESEG_DATA' not in os.environ, + '--data-dir', '-d', type=Path, default=None, help=' '.join(f''' - The path to store the nnUNet data, defaults to the TOTALSPINESEG_DATA environment variable if set. - If the TOTALSPINESEG_DATA environment variable is not set, the path must be provided. + The path to store the nnUNet data. '''.split()) ) parser.add_argument( @@ -95,54 +95,155 @@ def main(): suffix = args.suffix loc_suffix = args.loc_suffix step1_only = args.step1 - data_path = args.data_dir max_workers = args.max_workers max_workers_nnunet = min(args.max_workers_nnunet, max_workers) device = args.device quiet = args.quiet + # Init data_path + if not args.data_dir is None: + data_path = args.data_dir + elif 'TOTALSPINESEG_DATA' in os.environ: + data_path = Path(os.environ.get('TOTALSPINESEG_DATA', '')) + else: + data_path = importlib.resources.files(models) + + # Default release to use + default_release = list(ZIP_URLS.values())[0].split('/')[-2] + + # Install weights if not present + init_inference( + data_path=data_path, + dict_urls=ZIP_URLS, + quiet=quiet + ) + + # Run inference + inference( + input_path=input_path, + output_path=output_path, + data_path=data_path, + default_release=default_release, + output_iso=output_iso, + loc_path=loc_path, + suffix=suffix, + loc_suffix=loc_suffix, + step1_only=step1_only, + max_workers=max_workers, + max_workers_nnunet=max_workers_nnunet, + device=device, + quiet=quiet + ) + + +def inference( + input_path, + output_path, + data_path, + default_release, + output_iso=False, + loc_path=None, + suffix=[''], + loc_suffix='', + step1_only=False, + max_workers=os.cpu_count(), + max_workers_nnunet=int(max(min(os.cpu_count(), psutil.virtual_memory().total / 2**30 // 8), 1)), + device='cuda', + quiet=False + ): + ''' + Inference function + + Parameters + ---------- + input_path : pathlib.Path or string + The input folder path containing the niftii images. + output_path : pathlib.Path or string + The output folder path that will contain the predictions. + data_path : pathlib.Path or string + Folder path containing the network weights. + default_release : string + Default release used for inference. + output_iso : bool + If False, output predictions will be resampled to the original space. + loc_path : None or pathlib.Path/string + The localizer folder path containing the niftii predictions of the localizer. + suffix : string + Suffix to use for the input images + loc_suffix : string + Suffix to use for the localizer images + step1_only : bool + If True only the prediction of the first model will be computed. + max_workers : int + Max worker to run in parallel proccess, defaults to numer of available cores + max_workers_nnunet : int + Max worker to run in parallel proccess for nnUNet + device : 'cuda' or 'cpu' + Device to run the nnUNet model on + quiet : bool + If True, will reduce the amount of displayed information + + Returns + ------- + list of string + List of output folders. + ''' + # Convert paths to Path like objects + if isinstance(input_path, str): + input_path = Path(input_path) + else: + if not isinstance(input_path, Path): + raise ValueError('input_path should be a Path object from pathlib or a string') + + if isinstance(output_path, str): + output_path = Path(output_path) + else: + if not isinstance(output_path, Path): + raise ValueError('output_path should be a Path object from pathlib or a string') + + if isinstance(data_path, str): + data_path = Path(data_path) + else: + if not isinstance(data_path, Path): + raise ValueError('data_path should be a Path object from pathlib or a string') + # Check if the data folder exists if not data_path.exists(): - raise FileNotFoundError(' '.join(f''' - The totalspineseg data folder does not exist at {data_path}, - if it is not the correct path, please set the TOTALSPINESEG_DATA environment variable to the correct path, - or use the --data-dir argument to specify the correct path. - '''.split())) + raise FileNotFoundError(f"The totalspineseg data folder does not exist at {data_path}.") # Datasets data step1_dataset = 'Dataset101_TotalSpineSeg_step1' step2_dataset = 'Dataset102_TotalSpineSeg_step2' - # Read urls from 'pyproject.toml' - step1_zip_url = dict([_.split(', ') for _ in metadata('totalspineseg').get_all('Project-URL')])[step1_dataset] - step2_zip_url = dict([_.split(', ') for _ in metadata('totalspineseg').get_all('Project-URL')])[step2_dataset] - fold = 0 - # Set nnUNet paths - nnUNet_raw = data_path / 'nnUNet' / 'raw' - nnUNet_preprocessed = data_path / 'nnUNet' / 'preprocessed' + # Set nnUNet results path nnUNet_results = data_path / 'nnUNet' / 'results' - nnUNet_exports = data_path / 'nnUNet' / 'exports' - # If not both steps models are installed, use the release subfolder + # If not both steps models are installed, use the default release subfolder if not (nnUNet_results / step1_dataset).is_dir() or not (nnUNet_results / step2_dataset).is_dir(): - # TODO Think of better way to get the release - release = step1_zip_url.split('/')[-2] - nnUNet_results = nnUNet_results / release - - # Create the nnUNet directories if they do not exist - nnUNet_raw.mkdir(parents=True, exist_ok=True) - nnUNet_preprocessed.mkdir(parents=True, exist_ok=True) - nnUNet_results.mkdir(parents=True, exist_ok=True) - nnUNet_exports.mkdir(parents=True, exist_ok=True) - - # Set nnUNet environment variables - os.environ['nnUNet_def_n_proc'] = str(max_workers_nnunet) - os.environ['nnUNet_n_proc_DA'] = str(max_workers_nnunet) - os.environ['nnUNet_raw'] = str(nnUNet_raw) - os.environ['nnUNet_preprocessed'] = str(nnUNet_preprocessed) - os.environ['nnUNet_results'] = str(nnUNet_results) + nnUNet_results = nnUNet_results / default_release + # Check if weights are available + if not (nnUNet_results / step1_dataset).is_dir() or not (nnUNet_results / step2_dataset).is_dir(): + raise FileNotFoundError('Model weights are missing.') + + # Load device + if isinstance(device, str): + assert device in ['cpu', 'cuda', 'mps'], f'-device must be either cpu, mps or cuda. Other devices are not tested/supported. Got: {device}.' + if device == 'cpu': + # let's allow torch to use hella threads + import multiprocessing + torch.set_num_threads(multiprocessing.cpu_count()) + device = torch.device('cpu') + elif device == 'cuda': + # multithreading in torch doesn't help nnU-Net if run on GPU + torch.set_num_threads(1) + torch.set_num_interop_threads(1) + device = torch.device('cuda') + else: + device = torch.device('mps') + else: + assert isinstance(device, torch.device) # Print the argument values if not quiet if not quiet: @@ -158,19 +259,9 @@ def main(): data_dir = "{data_path}" max_workers = {max_workers} max_workers_nnunet = {max_workers_nnunet} - device = "{device}" + device = "{device.type}" ''')) - # Installing the pretrained models if not already installed - for dataset, zip_url in [(step1_dataset, step1_zip_url), (step2_dataset, step2_zip_url)]: - install_weights( - nnunet_dataset=dataset, - zip_url=zip_url, - results_folder=nnUNet_results, - exports_folder=nnUNet_exports, - quiet=quiet - ) - if not quiet: print('\n' 'Making input dir with _0000 suffix:') if input_path.name.endswith('.nii.gz'): # If the input is a single file, copy it to the input_raw folder @@ -254,7 +345,7 @@ def main(): }, ) - if not quiet: print('\n' 'Converting 4D images to 3D:') + if not quiet: print('\n' 'Preprocessing images:') average4d_mp( output_path / 'input', output_path / 'input', @@ -265,7 +356,7 @@ def main(): quiet=quiet, ) - if not quiet: print('\n' 'Transforming images to canonical space:') + if not quiet: print('\n' 'Reorienting images to LPI(-):') reorient_canonical_mp( output_path / 'input', output_path / 'input', @@ -300,22 +391,21 @@ def main(): # Check if the final checkpoint exists, if not use the latest checkpoint checkpoint = 'checkpoint_final.pth' if (nnUNet_results / step1_dataset / f'{nnUNetTrainer}__{nnUNetPlans}__{configuration}' / f'fold_{fold}' / 'checkpoint_final.pth').is_file() else 'checkpoint_latest.pth' + # Construct step 1 model folder + model_folder_step1 = nnUNet_results / step1_dataset / f'{nnUNetTrainer}__{nnUNetPlans}__{configuration}' + if not quiet: print('\n' 'Running step 1 model:') - subprocess.run([ - 'nnUNetv2_predict', - '-d', step1_dataset, - '-i', str(output_path / 'input'), - '-o', str(output_path / 'step1_raw'), - '-f', str(fold), - '-c', configuration, - '-p', nnUNetPlans, - '-tr', nnUNetTrainer, - '-npp', str(max_workers_nnunet), - '-nps', str(max_workers_nnunet), - '-chk', checkpoint, - '-device', device, - '--save_probabilities', - ]) + predict_nnunet( + model_folder=model_folder_step1, + images_dir=output_path / 'input', + output_dir=output_path / 'step1_raw', + folds = str(fold), + save_probabilities = True, + checkpoint = checkpoint, + npp = max_workers_nnunet, + nps = max_workers_nnunet, + device = device + ) # Remove unnecessary files from output folder (output_path / 'step1_raw' / 'dataset.json').unlink(missing_ok=True) @@ -389,7 +479,7 @@ def main(): quiet=quiet, ) - if not quiet: print('\n' 'Filling spinal cancal label to include all non cord spinal canal:') + if not quiet: print('\n' 'Filling spinal canal label to include all non cord spinal canal:') # This will put the spinal canal label in all the voxels between the canal and the cord. fill_canal_mp( output_path / 'step1_output', @@ -566,21 +656,20 @@ def main(): # Check if the final checkpoint exists, if not use the latest checkpoint checkpoint = 'checkpoint_final.pth' if (nnUNet_results / step2_dataset / f'{nnUNetTrainer}__{nnUNetPlans}__{configuration}' / f'fold_{fold}' / 'checkpoint_final.pth').is_file() else 'checkpoint_latest.pth' + # Construct step 2 model folder + model_folder_step2 = nnUNet_results / step2_dataset / f'{nnUNetTrainer}__{nnUNetPlans}__{configuration}' + if not quiet: print('\n' 'Running step 2 model:') - subprocess.run([ - 'nnUNetv2_predict', - '-d', step2_dataset, - '-i', str(output_path / 'step2_input'), - '-o', str(output_path / 'step2_raw'), - '-f', str(fold), - '-c', configuration, - '-p', nnUNetPlans, - '-tr', nnUNetTrainer, - '-npp', str(max_workers_nnunet), - '-nps', str(max_workers_nnunet), - '-chk', checkpoint, - '-device', device - ]) + predict_nnunet( + model_folder=model_folder_step2, + images_dir=output_path / 'step2_input', + output_dir=output_path / 'step2_raw', + folds = str(fold), + checkpoint = checkpoint, + npp = max_workers_nnunet, + nps = max_workers_nnunet, + device = device + ) # Remove unnecessary files from output folder (output_path / 'step2_raw' / 'dataset.json').unlink(missing_ok=True) @@ -659,7 +748,7 @@ def main(): quiet=quiet, ) - if not quiet: print('\n' 'Filling spinal cancal label to include all non cord spinal canal:') + if not quiet: print('\n' 'Filling spinal canal label to include all non cord spinal canal:') # This will put the spinal canal label in all the voxels between the canal and the cord. fill_canal_mp( output_path / 'step2_output', @@ -767,9 +856,27 @@ def main(): max_workers=max_workers, quiet=quiet, ) + # Print all the output paths + if not quiet: print('\nResults of iterative labeling algorithm for step 1:') + if not quiet: print(f'{str(output_path)}/step1_output',) + + if not quiet: print('\nSpinal cord soft segmentations:') + if not quiet: print(f'{str(output_path)}/step1_cord',) + + if not quiet: print('\nSpinal canal soft segmentations:') + if not quiet: print(f'{str(output_path)}/step1_canal',) + + if not quiet: print('\nSingle voxel in canal centerline at each intervertebral disc level:') + if not quiet: print(f'{str(output_path)}/step1_levels',) + + if not quiet and not step1_only: print('\nSegmentation and labeling of the vertebrae, discs, spinal cord and spinal canal:') + if not quiet and not step1_only: print(f'{str(output_path)}/step2_output',) # Remove the input_raw folder shutil.rmtree(output_path / 'input_raw', ignore_errors=True) + + # Return list of output paths + return [str(output_path / folder) for folder in os.listdir(str(output_path))] if __name__ == '__main__': main() \ No newline at end of file diff --git a/totalspineseg/init_inference.py b/totalspineseg/init_inference.py new file mode 100644 index 0000000..79fd392 --- /dev/null +++ b/totalspineseg/init_inference.py @@ -0,0 +1,111 @@ +import argparse, textwrap, os +from pathlib import Path +import importlib.resources +from totalspineseg import models, install_weights +from totalspineseg.utils.utils import ZIP_URLS + +def main(): + parser = argparse.ArgumentParser( + description=textwrap.dedent(''' + This script downloads the pretrained models from the GitHub releases. + '''), + epilog=textwrap.dedent(''' + Examples: + totalspineseg_init + totalspineseg_init --quiet + '''), + formatter_class=argparse.RawTextHelpFormatter + ) + parser.add_argument( + '--store-export', action="store_false", + help='Store exported zip file, default to true.' + ) + parser.add_argument( + '--data-dir', '-d', type=Path, default=None, + help=' '.join(f''' + The path to store the nnUNet data. + '''.split()) + ) + parser.add_argument( + '--quiet', '-q', action="store_true", + help='Do not display inputs and progress bar, defaults to false (display).' + ) + + # Parse the command-line arguments + args = parser.parse_args() + + # Get the command-line argument values + store_export = args.store_export + quiet = args.quiet + + # Init data_path + if not args.data_dir is None: + data_path = args.data_dir + elif 'TOTALSPINESEG_DATA' in os.environ: + data_path = Path(os.environ.get('TOTALSPINESEG_DATA', '')) + else: + data_path = importlib.resources.files(models) + + # Initialize inference + init_inference( + data_path=data_path, + dict_urls=ZIP_URLS, + store_export=store_export, + quiet=quiet + ) + + +def init_inference( + data_path, + dict_urls=ZIP_URLS, + store_export=True, + quiet=False + ): + ''' + Function used to download and install nnUNetV2 weights + + Parameters + ---------- + data_path : pathlib.Path or string + Folder path containing the network weights. + dict_urls : dictionary + Url dictionary containing all the weights that need to be downloaded. + quiet : bool + If True, will reduce the amount of displayed information + + Returns + ------- + list of string + List of output folders. + ''' + # Convert data_path to Path like object + if isinstance(data_path, str): + data_path = Path(data_path) + else: + if not isinstance(data_path, Path): + raise ValueError('data_path should be a Path object from pathlib or a string') + + # Set nnUNet paths + nnUNet_results = data_path / 'nnUNet' / 'results' + nnUNet_exports = data_path / 'nnUNet' / 'exports' + + # If not both steps models are installed, use the release subfolder + if not all([(nnUNet_results / dataset).is_dir() for dataset in dict_urls.keys()]): + # TODO Think of better way to get the release + weights_release = list(dict_urls.values())[0].split('/')[-2] + nnUNet_results = nnUNet_results / weights_release + + # Installing the pretrained models if not already installed + for dataset, zip_url in dict_urls.items(): + install_weights( + nnunet_dataset=dataset, + zip_url=zip_url, + results_folder=nnUNet_results, + exports_folder=nnUNet_exports, + store_export=store_export, + quiet=quiet + ) + + +if __name__=='__main__': + main() \ No newline at end of file diff --git a/totalspineseg/models/__init__.py b/totalspineseg/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/totalspineseg/utils/install_weights.py b/totalspineseg/utils/install_weights.py index 4ad2f4c..3ad54eb 100644 --- a/totalspineseg/utils/install_weights.py +++ b/totalspineseg/utils/install_weights.py @@ -1,4 +1,5 @@ -import os, argparse, subprocess, textwrap +import os, argparse, textwrap +import zipfile from pathlib import Path from urllib.request import urlretrieve from tqdm import tqdm @@ -36,9 +37,13 @@ def main(): '--exports-folder', type=Path, required=True, help='Exports folder where the zipped weights will be dowloaded (Required).' ) + parser.add_argument( + '--store-export', type=bool, default=True, + help='Store exported zip file, default to true.' + ) parser.add_argument( '--quiet', '-q', action="store_true", default=False, - help='Do not display inputs and progress bar, defaults to false (display).' + help='Do not display inputs and progress bar, default to false. (display)' ) # Parse the command-line arguments @@ -49,6 +54,7 @@ def main(): zip_url = args.zip_url results_folder = args.results_folder exports_folder = args.exports_folder + store_export = args.store_export quiet = args.quiet # Print the argument values if not quiet @@ -59,6 +65,7 @@ def main(): zip_url = "{zip_url}" results_folder = "{results_folder}" exports_folder = "{exports_folder}" + store_export = "{store_export}" quiet = {quiet} ''')) @@ -67,6 +74,7 @@ def main(): zip_url=zip_url, results_folder=results_folder, exports_folder=exports_folder, + store_export=store_export, quiet=quiet, ) @@ -75,6 +83,7 @@ def install_weights( zip_url, results_folder, exports_folder, + store_export=True, quiet=False, ): ''' @@ -110,8 +119,20 @@ def install_weights( # If the pretrained model is not installed, install it from zip if not quiet: print(f'Installing the pretrained model from {zip_file}...') # Install the pretrained model from the zip file - os.environ['nnUNet_results'] = str(results_folder) - subprocess.run(['nnUNetv2_install_pretrained_model_from_zip', str(zip_file)]) + install_model_from_zip_file(str(zip_file), extract_folder=str(results_folder)) + + # Remove export + if not store_export: + os.remove(str(zip_file)) + + +def install_model_from_zip_file(zip_file: str, extract_folder): + ''' + Based on https://github.com/MIC-DKFZ/nnUNet/blob/master/nnunetv2/model_sharing/model_import.py + ''' + with zipfile.ZipFile(zip_file, 'r') as zip_ref: + zip_ref.extractall(extract_folder) + if __name__ == '__main__': main() \ No newline at end of file diff --git a/totalspineseg/utils/predict_nnunet.py b/totalspineseg/utils/predict_nnunet.py new file mode 100644 index 0000000..5b1434d --- /dev/null +++ b/totalspineseg/utils/predict_nnunet.py @@ -0,0 +1,250 @@ +import argparse, textwrap +import os +from pathlib import Path +import torch + +# This is just to silence nnUNet warnings. These variables should have no purpose/effect. +# There are sadly no other workarounds at the moment, see: +# https://github.com/MIC-DKFZ/nnUNet/blob/227d68e77f00ec8792405bc1c62a88ddca714697/nnunetv2/paths.py#L21 +os.environ['nnUNet_raw'] = "./nnUNet_raw" +os.environ['nnUNet_preprocessed'] = "./nnUNet_preprocessed" +os.environ['nnUNet_results'] = "./nnUNet_results" + +from nnunetv2.utilities.file_path_utilities import get_output_folder +from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor + +def main(): + # Description and arguments + parser = argparse.ArgumentParser( + description=' '.join(f''' + This script runs nnUNetV2 inference. Based on https://github.com/MIC-DKFZ/nnUNet/blob/master/nnunetv2/inference/predict_from_raw_data.py. + '''.split()), + epilog=textwrap.dedent(''' + Example: + predict_nnunet -i in_folder -o out_folder -d 101 -c 3d_fullres -p nnUNetPlans_small -tr nnUNetTrainer_DASegOrd0_NoMirroring -f 0 + '''), + formatter_class=argparse.RawTextHelpFormatter + ) + parser.add_argument( + '--images-dir', '-i', type=Path, required=True, + help='The folder where input NIfTI images files are located (required).' + ) + parser.add_argument( + '--output-dir', '-o', type=Path, required=True, + help='The folder where nnUNet predictions will be stored (required).' + ) + parser.add_argument( + '--dataset', '-d', type=str, required=True, + help='nnUNet dataset number, example 567 for Dataset567_... folder under nnunet_results (required).' + ) + parser.add_argument( + '--configuration', '-c', type=str, required=True, + help='nnUNet configuration' + ) + parser.add_argument( + '--plans', '-p', type=str, default='nnUNetPlans', + help='nnUNet plans, default is "nnUNetPlans".' + ) + parser.add_argument( + '--trainer', '-tr', type=str, default='nnUNetTrainer', + help='nnUNet trainer, default is "nnUNetTrainer".' + ) + parser.add_argument( + '--folds', '-f', nargs='+', type=str, default=(0, 1, 2, 3, 4), + help='nnUNet folds, default is "(0, 1, 2, 3, 4)".' + ) + parser.add_argument( + '-step-size', type=float, default=0.5, + help='Step size for sliding window prediction, default is "0.5".' + ) + parser.add_argument( + '--disable-tta', action='store_true', default=False, + help='Set this flag to disable test time data augmentation, default is false.' + ) + parser.add_argument( + '--verbose', action='store_true', + help="Display extra information, defaults to false (display)." + ) + parser.add_argument( + '--save-probabilities', action='store_true', + help='Set this to export predicted class "probabilities", default is false' + ) + parser.add_argument( + '--continue-prediction', action='store_true', + help='Continue an aborted previous prediction (will not overwrite existing files), default is false' + ) + parser.add_argument( + '-chk', type=str, default='checkpoint_final.pth', + help='Name of the checkpoint you want to use, default is "checkpoint_final.pth".' + ) + parser.add_argument( + '-npp', type=int, default=3, + help='Number of processes used for preprocessing, default is "3".' + ) + parser.add_argument( + '-nps', type=int, default=3, + help='Number of processes used for segmentation export, default is "3".' + ) + parser.add_argument( + '-prev-stage-predictions', type=str, default=None, + help='Folder containing the predictions of the previous stage. Required for cascaded models, default is None' + ) + parser.add_argument( + '-num-parts', type=int, default=1, + help='Number of separate nnUNetv2_predict call that you will be making, default is "1".' + ) + parser.add_argument( + '-part-id', type=int, required=False, default=0, + help='If multiple nnUNetv2_predict exist, default is "0".' + ) + parser.add_argument( + '-device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu', + help='Use this to set the device the inference should run with, default is "cuda".' + ) + parser.add_argument( + '--disable-progress-bar', action='store_true', + help='Set this flag to disable progress bar, default is false.' + ) + + # Parse the command-line arguments + args = parser.parse_args() + + # Get the command-line argument values + images_dir = args.images_dir + output_dir = args.output_dir + dataset = args.dataset + configuration = args.configuration + plans = args.plans + trainer = args.trainer + folds = args.folds + step_size = args.step_size + disable_tta = args.disable_tta + save_probabilities = args.save_probabilities + continue_prediction = args.continue_prediction + checkpoint = args.chk + npp = args.npp + nps = args.nps + prev_stage_predictions = args.prev_stage_predictions + num_parts = args.num_parts + part_id = args.part_id + device = args.device + verbose = args.verbose + disable_progress_bar = args.disable_progress_bar + + # Get model folder + model_folder = get_output_folder(dataset, trainer, plans, configuration) + + # Print the argument values if not quiet + if verbose: + print(textwrap.dedent(f''' + Running {Path(__file__).stem} with the following params: + images_dir = "{images_dir}" + output_dir = "{output_dir}" + dataset = "{dataset}" + configuration = "{configuration}" + plans = "{plans}" + trainer = "{trainer}" + folds = "{folds}" + step_size = "{step_size}" + disable_tta = "{disable_tta}" + save_probabilities = "{save_probabilities}" + continue_prediction = "{continue_prediction}" + checkpoint = "{checkpoint}" + npp = "{npp}" + nps = "{nps}" + prev_stage_predictions = "{prev_stage_predictions}" + num_parts = "{num_parts}" + part_id = "{part_id}" + device = "{device}" + verbose = "{verbose}" + disable_progress_bar = "{disable_progress_bar}" + ''')) + + # Load device + if isinstance(device, str): + assert device in ['cpu', 'cuda', 'mps'], f'-device must be either cpu, mps or cuda. Other devices are not tested/supported. Got: {device}.' + if device == 'cpu': + # let's allow torch to use hella threads + import multiprocessing + torch.set_num_threads(multiprocessing.cpu_count()) + device = torch.device('cpu') + elif device == 'cuda': + # multithreading in torch doesn't help nnU-Net if run on GPU + torch.set_num_threads(1) + torch.set_num_interop_threads(1) + device = torch.device('cuda') + else: + device = torch.device('mps') + + predict_nnunet( + model_folder = model_folder, + images_dir = images_dir, + output_dir = output_dir, + folds = folds, + step_size = step_size, + disable_tta = disable_tta, + save_probabilities = save_probabilities, + continue_prediction = continue_prediction, + checkpoint = checkpoint, + npp = npp, + nps = nps, + prev_stage_predictions = prev_stage_predictions, + num_parts = num_parts, + part_id = part_id, + device = device, + verbose = verbose, + disable_progress_bar = disable_progress_bar + ) + + +def predict_nnunet( + model_folder, + images_dir, + output_dir, + device, # torch device + folds = (0, 1, 2, 3, 4), + step_size = 0.5, + disable_tta = False, + save_probabilities = False, + continue_prediction = False, + checkpoint = 'checkpoint_final.pth', + npp = 3, + nps = 3, + prev_stage_predictions = None, + num_parts = 1, + part_id = 0, + verbose = False, + disable_progress_bar = False +): + # Check variables + folds = [i if i == 'all' else int(i) for i in folds] + assert part_id < num_parts + + # Create output folder if does not exists + output_dir.mkdir(parents=True, exist_ok=True) + + # Start nnUNet inference + predictor = nnUNetPredictor(tile_step_size=step_size, + use_gaussian=True, + use_mirroring=not disable_tta, + perform_everything_on_device=True, + device=device, + verbose=verbose, + verbose_preprocessing=verbose, + allow_tqdm=not disable_progress_bar) + + predictor.initialize_from_trained_model_folder( + model_folder, + folds, + checkpoint_name=checkpoint + ) + predictor.predict_from_files(str(images_dir), str(output_dir), save_probabilities=save_probabilities, + overwrite=not continue_prediction, + num_processes_preprocessing=npp, + num_processes_segmentation_export=nps, + folder_with_segs_from_prev_stage=prev_stage_predictions, + num_parts=num_parts, + part_id=part_id) + +if __name__=='__main__': + main() \ No newline at end of file diff --git a/totalspineseg/utils/utils.py b/totalspineseg/utils/utils.py new file mode 100644 index 0000000..0c2a05e --- /dev/null +++ b/totalspineseg/utils/utils.py @@ -0,0 +1,7 @@ +from importlib.metadata import metadata + +# Weights zip urls +ZIP_URLS = dict([meta.split(', ') for meta in metadata('totalspineseg').get_all('Project-URL') if meta.startswith('Dataset')]) + +# Version +VERSION = metadata('totalspineseg').get('version')