diff --git a/README.md b/README.md index 2515c69..4ecd58a 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # TotalSpineSeg -TotalSpineSeg is a tool for automatic instance segmentation and labeling of all vertebrae, intervertebral discs (IVDs), spinal cord, and spinal canal in MRI images. It follows the [TotalSegmentator classes](https://github.com/wasserth/TotalSegmentator/tree/v1.5.7#class-details) with additional classes for IVDs, spinal cord, and spinal canal (see [list of classes](#list-of-classes)). The model is based on [nnUNet](https://github.com/MIC-DKFZ/nnUNet) as the backbone for training and inference. +TotalSpineSeg is a tool for automatic instance segmentation of all vertebrae, intervertebral discs (IVDs), spinal cord, and spinal canal in MRI images. It is robust to various MRI contrasts, acquisition orientations, and resolutions. The model used in TotalSpineSeg is based on [nnUNet](https://github.com/MIC-DKFZ/nnUNet) as the backbone for training and inference. If you use this model, please cite our work: > Warszawer Y, Molinier N, Valošek J, Shirbint E, Benveniste PL, Achiron A, Eshaghi A and Cohen-Adad J. _Fully Automatic Vertebrae and Spinal Cord Segmentation Using a Hybrid Approach Combining nnU-Net and Iterative Algorithm_. Proceedings of the 32th Annual Meeting of ISMRM. 2024 @@ -229,53 +229,53 @@ For a more detailed view of the output examples, you can check the [PDF version] | Label | Name | |:------|:-----| -| 18 | vertebrae_L5 | -| 19 | vertebrae_L4 | -| 20 | vertebrae_L3 | -| 21 | vertebrae_L2 | -| 22 | vertebrae_L1 | -| 23 | vertebrae_T12 | -| 24 | vertebrae_T11 | -| 25 | vertebrae_T10 | -| 26 | vertebrae_T9 | -| 27 | vertebrae_T8 | -| 28 | vertebrae_T7 | -| 29 | vertebrae_T6 | -| 30 | vertebrae_T5 | -| 31 | vertebrae_T4 | -| 32 | vertebrae_T3 | -| 33 | vertebrae_T2 | -| 34 | vertebrae_T1 | -| 35 | vertebrae_C7 | -| 36 | vertebrae_C6 | -| 37 | vertebrae_C5 | -| 38 | vertebrae_C4 | -| 39 | vertebrae_C3 | -| 40 | vertebrae_C2 | -| 41 | vertebrae_C1 | -| 92 | sacrum | -| 200 | spinal_cord | -| 201 | spinal_canal | -| 202 | disc_L5_S | -| 203 | disc_L4_L5 | -| 204 | disc_L3_L4 | -| 205 | disc_L2_L3 | -| 206 | disc_L1_L2 | -| 207 | disc_T12_L1 | -| 208 | disc_T11_T12 | -| 209 | disc_T10_T11 | -| 210 | disc_T9_T10 | -| 211 | disc_T8_T9 | -| 212 | disc_T7_T8 | -| 213 | disc_T6_T7 | -| 214 | disc_T5_T6 | -| 215 | disc_T4_T5 | -| 216 | disc_T3_T4 | -| 217 | disc_T2_T3 | -| 218 | disc_T1_T2 | -| 219 | disc_C7_T1 | -| 220 | disc_C6_C7 | -| 221 | disc_C5_C6 | -| 222 | disc_C4_C5 | -| 223 | disc_C3_C4 | -| 224 | disc_C2_C3 | +| 1 | spinal_cord | +| 2 | spinal_canal | +| 11 | vertebrae_C1 | +| 12 | vertebrae_C2 | +| 13 | vertebrae_C3 | +| 14 | vertebrae_C4 | +| 15 | vertebrae_C5 | +| 16 | vertebrae_C6 | +| 17 | vertebrae_C7 | +| 21 | vertebrae_T1 | +| 22 | vertebrae_T2 | +| 23 | vertebrae_T3 | +| 24 | vertebrae_T4 | +| 25 | vertebrae_T5 | +| 26 | vertebrae_T6 | +| 27 | vertebrae_T7 | +| 28 | vertebrae_T8 | +| 29 | vertebrae_T9 | +| 30 | vertebrae_T10 | +| 31 | vertebrae_T11 | +| 32 | vertebrae_T12 | +| 41 | vertebrae_L1 | +| 42 | vertebrae_L2 | +| 43 | vertebrae_L3 | +| 44 | vertebrae_L4 | +| 45 | vertebrae_L5 | +| 50 | sacrum | +| 63 | disc_C2_C3 | +| 64 | disc_C3_C4 | +| 65 | disc_C4_C5 | +| 66 | disc_C5_C6 | +| 67 | disc_C6_C7 | +| 71 | disc_C7_T1 | +| 72 | disc_T1_T2 | +| 73 | disc_T2_T3 | +| 74 | disc_T3_T4 | +| 75 | disc_T4_T5 | +| 76 | disc_T5_T6 | +| 77 | disc_T6_T7 | +| 78 | disc_T7_T8 | +| 79 | disc_T8_T9 | +| 80 | disc_T9_T10 | +| 81 | disc_T10_T11 | +| 82 | disc_T11_T12 | +| 91 | disc_T12_L1 | +| 92 | disc_L1_L2 | +| 93 | disc_L2_L3 | +| 94 | disc_L3_L4 | +| 95 | disc_L4_L5 | +| 100 | disc_L5_S | diff --git a/pyproject.toml b/pyproject.toml index 996f13a..9a8daee 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -75,6 +75,7 @@ totalspineseg_average4d = "totalspineseg.utils.average4d:main" totalspineseg_crop_image2seg = "totalspineseg.utils.crop_image2seg:main" totalspineseg_extract_soft = "totalspineseg.utils.extract_soft:main" totalspineseg_extract_levels = "totalspineseg.utils.extract_levels:main" +totalspineseg_extract_alternate = "totalspineseg.utils.extract_alternate:main" totalspineseg_add_nnunet_trainer = "totalspineseg.utils.add_nnunet_trainer:main" [build-system] diff --git a/totalspineseg/__init__.py b/totalspineseg/__init__.py index 44cac5f..dc4417c 100644 --- a/totalspineseg/__init__.py +++ b/totalspineseg/__init__.py @@ -3,6 +3,7 @@ from .utils.cpdir import cpdir_mp from .utils.crop_image2seg import crop_image2seg, crop_image2seg_mp from .utils.extract_levels import extract_levels, extract_levels_mp +from .utils.extract_alternate import extract_alternate, extract_alternate_mp from .utils.extract_soft import extract_soft, extract_soft_mp from .utils.fill_canal import fill_canal, fill_canal_mp from .utils.iterative_label import iterative_label, iterative_label_mp diff --git a/totalspineseg/inference.py b/totalspineseg/inference.py index a753122..aceb487 100644 --- a/totalspineseg/inference.py +++ b/totalspineseg/inference.py @@ -125,9 +125,6 @@ def main(): nnUNet_exports.mkdir(parents=True, exist_ok=True) # Set the nnUNet variables - nnUNetTrainer = 'nnUNetTrainer_16000epochs' - nnUNetPlans = 'nnUNetPlans' - configuration = '3d_fullres_small' step1_dataset = 'Dataset101_TotalSpineSeg_step1' step2_dataset = 'Dataset102_TotalSpineSeg_step2' fold = 0 @@ -233,6 +230,27 @@ def main(): max_workers=max_workers, quiet=quiet, ) + preview_jpg_mp( + output_path / 'input', + output_path / 'preview', + segs_path=output_path / 'localizers', + output_suffix='_loc_tags', + override=True, + max_workers=max_workers, + quiet=quiet, + label_texts_right={ + 11: 'C1', 12: 'C2', 13: 'C3', 14: 'C4', 15: 'C5', 16: 'C6', 17: 'C7', + 21: 'T1', 22: 'T2', 23: 'T3', 24: 'T4', 25: 'T5', 26: 'T6', 27: 'T7', + 28: 'T8', 29: 'T9', 30: 'T10', 31: 'T11', 32: 'T12', + 41: 'L1', 42: 'L2', 43: 'L3', 44: 'L4', 45: 'L5', + }, + label_texts_left={ + 50: 'Sacrum', 63: 'C2C3', 64: 'C3C4', 65: 'C4C5', 66: 'C5C6', 67: 'C6C7', + 71: 'C7T1', 72: 'T1T2', 73: 'T2T3', 74: 'T3T4', 75: 'T4T5', 76: 'T5T6', 77: 'T6T7', + 78: 'T7T8', 79: 'T8T9', 80: 'T9T10', 81: 'T10T11', 82: 'T11T12', + 91: 'T12L1', 92: 'L1L2', 93: 'L2L3', 94: 'L3L4', 95: 'L4L5', 100: 'L5S' + }, + ) if not quiet: print('\n' 'Converting 4D images to 3D:') average4d_mp( @@ -275,6 +293,8 @@ def main(): quiet=quiet, ) + # Get the nnUNet parameters from the results folder + nnUNetTrainer, nnUNetPlans, configuration = next((nnUNet_results / step1_dataset).glob('*/fold_*')).parent.name.split('__') # Check if the final checkpoint exists, if not use the latest checkpoint checkpoint = 'checkpoint_final.pth' if (nnUNet_results / step1_dataset / f'{nnUNetTrainer}__{nnUNetPlans}__{configuration}' / f'fold_{fold}' / 'checkpoint_final.pth').is_file() else 'checkpoint_latest.pth' @@ -323,10 +343,16 @@ def main(): iterative_label_mp( output_path / 'step1_output', output_path / 'step1_output', + selected_disc_landmarks=[2, 5, 3, 4], disc_labels=[1, 2, 3, 4, 5], - init_disc={2:224, 5:202, 3:219, 4:207}, - output_disc_step=-1, - map_input_dict={6:92, 7:201, 8:201, 9:200}, + disc_landmark_labels=[2, 3, 4, 5], + disc_landmark_output_labels=[63, 71, 91, 100], + canal_labels=[7, 8], + canal_output_label=2, + cord_labels=[9], + cord_output_label=1, + sacrum_labels=[6], + sacrum_output_label=50, override=True, max_workers=max_workers, quiet=quiet, @@ -336,11 +362,17 @@ def main(): output_path / 'step1_output', output_path / 'step1_output', locs_path=output_path / 'localizers', + selected_disc_landmarks=[2, 5], disc_labels=[1, 2, 3, 4, 5], - init_disc={2:224, 5:202}, - output_disc_step=-1, - loc_disc_labels=list(range(202, 225)), - map_input_dict={6:92, 7:201, 8:201, 9:200}, + disc_landmark_labels=[2, 3, 4, 5], + disc_landmark_output_labels=[63, 71, 91, 100], + loc_disc_labels=list(range(63, 101)), + canal_labels=[7, 8], + canal_output_label=2, + cord_labels=[9], + cord_output_label=1, + sacrum_labels=[6], + sacrum_output_label=50, override=True, max_workers=max_workers, quiet=quiet, @@ -351,8 +383,8 @@ def main(): fill_canal_mp( output_path / 'step1_output', output_path / 'step1_output', - canal_label=201, - cord_label=200, + canal_label=2, + cord_label=1, largest_canal=True, largest_cord=True, override=True, @@ -380,6 +412,21 @@ def main(): max_workers=max_workers, quiet=quiet, ) + preview_jpg_mp( + output_path / 'input', + output_path / 'preview', + segs_path=output_path / 'step1_output', + output_suffix='_step1_output_tags', + override=True, + max_workers=max_workers, + quiet=quiet, + label_texts_left={ + 63: 'C2C3', 64: 'C3C4', 65: 'C4C5', 66: 'C5C6', 67: 'C6C7', + 71: 'C7T1', 72: 'T1T2', 73: 'T2T3', 74: 'T3T4', 75: 'T4T5', 76: 'T5T6', 77: 'T6T7', + 78: 'T7T8', 79: 'T8T9', 80: 'T9T10', 81: 'T10T11', 82: 'T11T12', + 91: 'T12L1', 92: 'L1L2', 93: 'L2L3', 94: 'L3L4', 95: 'L4L5', 100: 'L5S' + }, + ) if not quiet: print('\n' 'Extracting spinal cord soft segmentation from step 1 model output:') extract_soft_mp( @@ -387,7 +434,7 @@ def main(): output_path / 'step1_output', output_path / 'step1_cord', label=9, - seg_labels=[200], + seg_labels=[1], dilate=1, override=True, max_workers=max_workers, @@ -400,7 +447,7 @@ def main(): output_path / 'step1_output', output_path / 'step1_canal', label=7, - seg_labels=[200, 201], + seg_labels=[1, 2], dilate=1, override=True, max_workers=max_workers, @@ -417,9 +464,8 @@ def main(): extract_levels_mp( output_path / 'step1_output', output_path / 'step1_levels', - canal_labels=[200, 201], - c2c3_label=224, - step=-1, + canal_labels=[1, 2], + disc_labels=list(range(63, 68)) + list(range(71, 83)) + list(range(91, 96)) + [100], override=True, max_workers=max_workers, quiet=quiet, @@ -459,18 +505,14 @@ def main(): quiet=quiet, ) - # Load label mappings from JSON file - with open(resources_path / 'labels_maps' / 'nnunet_step2_input.json', 'r', encoding='utf-8') as map_file: - map_dict = json.load(map_file) - if not quiet: print('\n' 'Mapping the IVDs labels from the step1 model output to the odd IVDs:') # This will also delete labels without odd IVDs - map_labels_mp( + extract_alternate_mp( output_path / 'step2_input', output_path / 'step2_input', - map_dict=map_dict, seg_suffix='_0001', output_seg_suffix='_0001', + labels=list(range(63, 101)), override=True, max_workers=max_workers, quiet=quiet, @@ -493,6 +535,8 @@ def main(): if not f.with_name(f.name.replace('_0000.nii.gz', '_0001.nii.gz')).exists(): f.unlink() + # Get the nnUNet parameters from the results folder + nnUNetTrainer, nnUNetPlans, configuration = next((nnUNet_results / step2_dataset).glob('*/fold_*')).parent.name.split('__') # Check if the final checkpoint exists, if not use the latest checkpoint checkpoint = 'checkpoint_final.pth' if (nnUNet_results / step2_dataset / f'{nnUNetTrainer}__{nnUNetPlans}__{configuration}' / f'fold_{fold}' / 'checkpoint_final.pth').is_file() else 'checkpoint_latest.pth' @@ -543,17 +587,19 @@ def main(): iterative_label_mp( output_path / 'step2_output', output_path / 'step2_output', + selected_disc_landmarks=[4, 7, 5, 6], disc_labels=[1, 2, 3, 4, 5, 6, 7], + disc_landmark_labels=[4, 5, 6, 7], + disc_landmark_output_labels=[63, 71, 91, 100], vertebrae_labels=[9, 10, 11, 12, 13, 14], + vertebrae_landmark_output_labels=[13, 21, 41, 50], vertebrae_extra_labels=[8], - init_disc={4:224, 7:202, 5:219, 6:207}, - init_vertebrae={11:40, 14:17, 12:34, 13:23}, - step_diff_label=True, - step_diff_disc=True, - output_disc_step=-1, - output_vertebrae_step=-1, - map_output_dict={17:92}, - map_input_dict={14:92, 15:201, 16:201, 17:200}, + canal_labels=[15, 16], + canal_output_label=2, + cord_labels=[17], + cord_output_label=1, + sacrum_labels=[14], + sacrum_output_label=50, override=True, max_workers=max_workers, quiet=quiet, @@ -563,19 +609,20 @@ def main(): output_path / 'step2_output', output_path / 'step2_output', locs_path=output_path / 'localizers', + selected_disc_landmarks=[4, 7], disc_labels=[1, 2, 3, 4, 5, 6, 7], + disc_landmark_labels=[4, 5, 6, 7], + disc_landmark_output_labels=[63, 71, 91, 100], vertebrae_labels=[9, 10, 11, 12, 13, 14], + vertebrae_landmark_output_labels=[13, 21, 41, 50], vertebrae_extra_labels=[8], - init_disc={4:224, 7:202}, - init_vertebrae={11:40, 14:17}, - loc_disc_labels=list(range(202, 225)), - loc_vertebrae_labels=list(range(18, 42)) + [92], - step_diff_label=True, - step_diff_disc=True, - output_disc_step=-1, - output_vertebrae_step=-1, - map_output_dict={17:92}, - map_input_dict={14:92, 15:201, 16:201, 17:200}, + loc_disc_labels=list(range(63, 101)), + canal_labels=[15, 16], + canal_output_label=2, + cord_labels=[17], + cord_output_label=1, + sacrum_labels=[14], + sacrum_output_label=50, override=True, max_workers=max_workers, quiet=quiet, @@ -586,8 +633,8 @@ def main(): fill_canal_mp( output_path / 'step2_output', output_path / 'step2_output', - canal_label=201, - cord_label=200, + canal_label=2, + cord_label=1, largest_canal=True, largest_cord=True, override=True, @@ -615,6 +662,27 @@ def main(): max_workers=max_workers, quiet=quiet, ) + preview_jpg_mp( + output_path / 'input', + output_path / 'preview', + segs_path=output_path / 'step2_output', + output_suffix='_step2_output_tags', + override=True, + max_workers=max_workers, + quiet=quiet, + label_texts_right={ + 11: 'C1', 12: 'C2', 13: 'C3', 14: 'C4', 15: 'C5', 16: 'C6', 17: 'C7', + 21: 'T1', 22: 'T2', 23: 'T3', 24: 'T4', 25: 'T5', 26: 'T6', 27: 'T7', + 28: 'T8', 29: 'T9', 30: 'T10', 31: 'T11', 32: 'T12', + 41: 'L1', 42: 'L2', 43: 'L3', 44: 'L4', 45: 'L5', + }, + label_texts_left={ + 50: 'Sacrum', 63: 'C2C3', 64: 'C3C4', 65: 'C4C5', 66: 'C5C6', 67: 'C6C7', + 71: 'C7T1', 72: 'T1T2', 73: 'T2T3', 74: 'T3T4', 75: 'T4T5', 76: 'T5T6', 77: 'T6T7', + 78: 'T7T8', 79: 'T8T9', 80: 'T9T10', 81: 'T10T11', 82: 'T11T12', + 91: 'T12L1', 92: 'L1L2', 93: 'L2L3', 94: 'L3L4', 95: 'L4L5', 100: 'L5S' + }, + ) if __name__ == '__main__': main() \ No newline at end of file diff --git a/totalspineseg/utils/extract_alternate.py b/totalspineseg/utils/extract_alternate.py new file mode 100644 index 0000000..048833d --- /dev/null +++ b/totalspineseg/utils/extract_alternate.py @@ -0,0 +1,267 @@ +import sys, argparse, textwrap +from pathlib import Path +import numpy as np +import nibabel as nib +import multiprocessing as mp +from functools import partial +from tqdm.contrib.concurrent import process_map +import warnings + +warnings.filterwarnings("ignore") + +def main(): + + # Parse command line arguments + parser = argparse.ArgumentParser( + description=' '.join(f''' + Extract alternate labels from the segmentation. + '''.split()), + epilog=textwrap.dedent(''' + Examples: + extract_alternate -s labels -o levels --labels 60-100 -r + For BIDS: + extract_alternate -s derivatives/labels -o derivatives/labels --seg-suffix "_seg" --output-seg-suffix "_levels" -d "sub-" -u "anat" --labels 60-100 -r + '''), + formatter_class=argparse.RawTextHelpFormatter + ) + + parser.add_argument( + '--segs-dir', '-s', type=Path, required=True, + help='Folder containing input segmentations.' + ) + parser.add_argument( + '--output-segs-dir', '-o', type=Path, required=True, + help='Folder to save output segmentations.' + ) + parser.add_argument( + '--subject-dir', '-d', type=str, default=None, nargs='?', const='', + help=' '.join(f''' + Is every subject has its oen direcrory. + If this argument will be provided without value it will look for any directory in the segmentation directory. + If value also provided it will be used as a prefix to subject directory (for example "sub-"), defaults to False (no subjet directory). + '''.split()) + ) + parser.add_argument( + '--subject-subdir', '-u', type=str, default='', + help='Subfolder inside subject folder containing masks (for example "anat"), defaults to no subfolder.' + ) + parser.add_argument( + '--prefix', '-p', type=str, default='', + help='File prefix to work on.' + ) + parser.add_argument( + '--seg-suffix', type=str, default='', + help='Segmentation suffix, defaults to "".' + ) + parser.add_argument( + '--output-seg-suffix', type=str, default='', + help='Suffix for output segmentation, defaults to "".' + ) + parser.add_argument( + '--labels', type=lambda x:list(range(int(x.split('-')[0]), int(x.split('-')[-1]) + 1)), nargs='+', required=True, + help='The labels to extract alternate elements from.' + ) + parser.add_argument( + '--prioratize-labels', type=lambda x:list(range(int(x.split('-')[0]), int(x.split('-')[-1]) + 1)), nargs='+', default=[], + help='Specify labels that will be prioratized in the output, the first label in the list will be included in the output, defaults to [] (The first label in the list that is in the segmentation).' + ) + parser.add_argument( + '--override', '-r', action="store_true", default=False, + help='Override existing output files, defaults to false (Do not override).' + ) + parser.add_argument( + '--max-workers', '-w', type=int, default=mp.cpu_count(), + help='Max worker to run in parallel proccess, defaults to multiprocessing.cpu_count().' + ) + parser.add_argument( + '--quiet', '-q', action="store_true", default=False, + help='Do not display inputs and progress bar, defaults to false (display).' + ) + + # Parse the command-line arguments + args = parser.parse_args() + + # Get arguments + segs_path = args.segs_dir + output_segs_path = args.output_segs_dir + subject_dir = args.subject_dir + subject_subdir = args.subject_subdir + prefix = args.prefix + seg_suffix = args.seg_suffix + output_seg_suffix = args.output_seg_suffix + labels = [_ for __ in args.labels for _ in (__ if isinstance(__, list) else [__])] + prioratize_labels = [_ for __ in args.prioratize_labels for _ in (__ if isinstance(__, list) else [__])] + override = args.override + max_workers = args.max_workers + quiet = args.quiet + + # Print the argument values if not quiet + if not quiet: + print(textwrap.dedent(f''' + Running {Path(__file__).stem} with the following params: + segs_dir = "{segs_path}" + output_segs_dir = "{output_segs_path}" + subject_dir = "{subject_dir}" + subject_subdir = "{subject_subdir}" + prefix = "{prefix}" + seg_suffix = "{seg_suffix}" + output_seg_suffix = "{output_seg_suffix}" + labels = {labels} + prioratize_labels = {prioratize_labels} + override = {override} + max_workers = {max_workers} + quiet = {quiet} + ''')) + + extract_alternate_mp( + segs_path=segs_path, + output_segs_path=output_segs_path, + subject_dir=subject_dir, + subject_subdir=subject_subdir, + prefix=prefix, + seg_suffix=seg_suffix, + output_seg_suffix=output_seg_suffix, + labels=labels, + prioratize_labels=prioratize_labels, + override=override, + max_workers=max_workers, + quiet=quiet, + ) + +def extract_alternate_mp( + segs_path, + output_segs_path, + subject_dir=None, + subject_subdir='', + prefix='', + seg_suffix='', + output_seg_suffix='', + labels=[], + prioratize_labels=[], + override=False, + max_workers=mp.cpu_count(), + quiet=False, + ): + ''' + Wrapper function to handle multiprocessing. + ''' + segs_path = Path(segs_path) + output_segs_path = Path(output_segs_path) + + glob_pattern = "" + if subject_dir is not None: + glob_pattern += f"{subject_dir}*/" + if len(subject_subdir) > 0: + glob_pattern += f"{subject_subdir}/" + glob_pattern += f'{prefix}*{seg_suffix}.nii.gz' + + # Process the NIfTI image and segmentation files + seg_path_list = list(segs_path.glob(glob_pattern)) + output_seg_path_list = [output_segs_path / _.relative_to(segs_path).parent / _.name.replace(f'{seg_suffix}.nii.gz', f'{output_seg_suffix}.nii.gz') for _ in seg_path_list] + + process_map( + partial( + _extract_alternate, + labels=labels, + prioratize_labels=prioratize_labels, + override=override, + ), + seg_path_list, + output_seg_path_list, + max_workers=max_workers, + chunksize=1, + disable=quiet, + ) + +def _extract_alternate( + seg_path, + output_seg_path, + labels=[], + prioratize_labels=[], + override=False, + ): + ''' + Wrapper function to handle IO. + ''' + seg_path = Path(seg_path) + output_seg_path = Path(output_seg_path) + + # If the output image already exists and we are not overriding it, return + if not override and output_seg_path.exists(): + return + + # Load segmentation + seg = nib.load(seg_path) + + try: + output_seg = extract_alternate( + seg, + labels=labels, + prioratize_labels=prioratize_labels, + ) + except ValueError as e: + output_seg_path.is_file() and output_seg_path.unlink() + print(f'Error: {seg_path}, {e}') + return + + # Ensure correct segmentation dtype, affine and header + output_seg = nib.Nifti1Image( + np.asanyarray(output_seg.dataobj).round().astype(np.uint8), + output_seg.affine, output_seg.header + ) + output_seg.set_data_dtype(np.uint8) + output_seg.set_qform(output_seg.affine) + output_seg.set_sform(output_seg.affine) + + # Make sure output directory exists and save the segmentation + output_seg_path.parent.mkdir(parents=True, exist_ok=True) + nib.save(output_seg, output_seg_path) + +def extract_alternate( + seg, + labels=[], + prioratize_labels=[], + ): + ''' + Extract vertebrae levels from Spinal Canal and Discs. + + The function extracts the vertebrae levels from the input segmentation by finding the closest voxel in the canal centerline to the middle of each disc. + The superior voxels in the canal centerline are set to 1 and the middle voxels between C2-C3 and the superior voxels are set to 2. + + Parameters + ---------- + seg : nibabel.Nifti1Image + The input segmentation. + labels : list of int + The labels to extract alternate elements from. + prioratize_labels : list of int + Specify labels that will be prioratized in the output, the first label in the list will be included in the output. + + Returns + ------- + nibabel.Nifti1Image + The output segmentation with the vertebrae levels. + ''' + seg_data = np.asanyarray(seg.dataobj).round().astype(np.uint8) + + output_seg_data = np.zeros_like(seg_data) + + # Get the labels in the segmentation + labels = np.array(labels)[np.isin(labels, seg_data)] + + # Get the labels to prioratize in the output that are in the segmentation and in the labels + prioratize_labels = np.array(prioratize_labels)[np.isin(prioratize_labels, labels)] + + selected_labels = labels[::2] + + if len(prioratize_labels) > 0 and prioratize_labels[0] not in selected_labels: + selected_labels = labels[1::2] + + output_seg_data[np.isin(seg_data, selected_labels)] = 1 + + output_seg = nib.Nifti1Image(output_seg_data, seg.affine, seg.header) + + return output_seg + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/totalspineseg/utils/extract_levels.py b/totalspineseg/utils/extract_levels.py index 264563b..2fef326 100644 --- a/totalspineseg/utils/extract_levels.py +++ b/totalspineseg/utils/extract_levels.py @@ -18,9 +18,9 @@ def main(): '''.split()), epilog=textwrap.dedent(''' Examples: - extract_levels -s labels -o levels --canal-labels 200 201 --c2c3-label 224 --step -1 -r + extract_levels -s labels -o levels --canal-labels 1 2 --disc-labels 60-64 70-81 90-94 100 -r For BIDS: - extract_levels -s derivatives/labels -o derivatives/labels --seg-suffix "_seg" --output-seg-suffix "_levels" -d "sub-" -u "anat" --canal-labels 200 201 --c2c3-label 224 --step -1 -r + extract_levels -s derivatives/labels -o derivatives/labels --seg-suffix "_seg" --output-seg-suffix "_levels" -d "sub-" -u "anat" --canal-labels 1 2 --disc-labels 60-64 70-81 90-94 100 -r '''), formatter_class=argparse.RawTextHelpFormatter ) @@ -58,16 +58,16 @@ def main(): help='Suffix for output segmentation, defaults to "".' ) parser.add_argument( - '--canal-labels', type=int, nargs='+', required=True, + '--canal-labels', type=lambda x:list(range(int(x.split('-')[0]), int(x.split('-')[-1]) + 1)), nargs='+', required=True, help='The canal labels.' ) parser.add_argument( - '--c2c3-label', type=int, required=True, - help='The label for C2-C3 disc.' + '--disc-labels', type=lambda x:list(range(int(x.split('-')[0]), int(x.split('-')[-1]) + 1)), nargs='+', required=True, + help='The disc labels starting at C2C3 ordered from superior to inferior.' ) parser.add_argument( - '--step', type=int, default=1, - help='The step to take between discs labels in the input, defaults to 1.' + '--c1-label', type=int, default=0, + help='The label for C1 vertebra in the segmentation, if provided it will be used to determine if C1 is in the segmentation.' ) parser.add_argument( '--override', '-r', action="store_true", default=False, @@ -94,8 +94,8 @@ def main(): seg_suffix = args.seg_suffix output_seg_suffix = args.output_seg_suffix canal_labels = args.canal_labels - c2c3_label = args.c2c3_label - step = args.step + disc_labels = [_ for __ in args.disc_labels for _ in (__ if isinstance(__, list) else [__])] + c1_label = args.c1_label override = args.override max_workers = args.max_workers quiet = args.quiet @@ -112,8 +112,8 @@ def main(): seg_suffix = "{seg_suffix}" output_seg_suffix = "{output_seg_suffix}" canal_labels = {canal_labels} - c2c3_label = {c2c3_label} - step = {step} + disc_labels = {disc_labels} + c1_label = {c1_label} override = {override} max_workers = {max_workers} quiet = {quiet} @@ -128,8 +128,8 @@ def main(): seg_suffix=seg_suffix, output_seg_suffix=output_seg_suffix, canal_labels=canal_labels, - c2c3_label=c2c3_label, - step=step, + disc_labels=disc_labels, + c1_label=c1_label, override=override, max_workers=max_workers, quiet=quiet, @@ -144,8 +144,8 @@ def extract_levels_mp( seg_suffix='', output_seg_suffix='', canal_labels=[], - c2c3_label=3, - step=1, + disc_labels=[], + c1_label=0, override=False, max_workers=mp.cpu_count(), quiet=False, @@ -171,8 +171,8 @@ def extract_levels_mp( partial( _extract_levels, canal_labels=canal_labels, - step=step, - c2c3_label=c2c3_label, + disc_labels=disc_labels, + c1_label=c1_label, override=override, ), seg_path_list, @@ -186,8 +186,8 @@ def _extract_levels( seg_path, output_seg_path, canal_labels=[], - c2c3_label=3, - step=1, + disc_labels=[], + c1_label=0, override=False, ): ''' @@ -207,8 +207,8 @@ def _extract_levels( output_seg = extract_levels( seg, canal_labels=canal_labels, - c2c3_label=c2c3_label, - step=step, + disc_labels=disc_labels, + c1_label=c1_label, ) except ValueError as e: output_seg_path.is_file() and output_seg_path.unlink() @@ -231,8 +231,8 @@ def _extract_levels( def extract_levels( seg, canal_labels=[], - c2c3_label=3, - step=1, + disc_labels=[], + c1_label=0, ): ''' Extract vertebrae levels from Spinal Canal and Discs. @@ -246,10 +246,10 @@ def extract_levels( The input segmentation. canal_labels : list The canal labels. - c2c3_label : int - The label for C2-C3 disc. - step : int - The step to take between discs labels in the input. + disc_labels : list + The disc labels starting at C2C3 ordered from superior to inferior. + c1_label : int + The label for C1 vertebra in the segmentation, if provided it will be used to determine if C1 is in the segmentation. Returns ------- @@ -282,18 +282,20 @@ def extract_levels( # Get the indices of the canal centerline mask_canal_centerline_indices = np.array(np.nonzero(mask_canal_centerline)) - # Get the labels of the discs and the output labels - disc_labels = list(range(c2c3_label, c2c3_label + step * 23, step)) - out_labels = list(range(3, 26)) - - # Filter the discs that are in the segmentation - in_seg = np.isin(disc_labels, seg_data) - map_labels = [(d, o) for d, o, i in zip(disc_labels, out_labels, in_seg) if i] + # Get the labels of the discs in the segmentation + disc_labels_in_seg = np.array(disc_labels)[np.isin(disc_labels, seg_data)] # If no disc labels found in the segmentation raise an error - if len(map_labels) == 0: + if len(disc_labels_in_seg) == 0: raise ValueError(f"No disc labels found in the segmentation.") + # Get the labels of the discs and the output labels + first_disk_idx = disc_labels.index(disc_labels_in_seg[0]) + out_labels = list(range(3 + first_disk_idx, 3 + first_disk_idx + len(disc_labels_in_seg))) + + # Filter the discs that are in the segmentation + map_labels = zip(disc_labels_in_seg, out_labels) + # Loop over the discs from C2-C3 to L5-S1 and find the closest voxel in the canal centerline for disc_label, out_label in map_labels: # Create a mask of the disc @@ -320,8 +322,8 @@ def extract_levels( # Find the location of the superior voxels in the canal centerline canal_superior_index = np.unravel_index(np.argmax(mask_canal_centerline * indices[2]), seg_data.shape) - if canal_superior_index[2] - c2c3_index[2] >= 8 and output_seg_data.shape[2] - canal_superior_index[2] >= 2: - # If C2-C3 at least 8 voxels below the top of the canal and the top of the canal is at least 2 voxels from the top of the image + if (c1_label > 0 and c1_label in seg_data) or (c1_label == 0 and canal_superior_index[2] - c2c3_index[2] >= 8 and output_seg_data.shape[2] - canal_superior_index[2] >= 2): + # If C1 is in the segmentation or C2-C3 at least 8 voxels below the top of the canal and the top of the canal is at least 2 voxels from the top of the image # Set 1 to the superior voxels output_seg_data[canal_superior_index] = 1 diff --git a/totalspineseg/utils/iterative_label.py b/totalspineseg/utils/iterative_label.py index 714a00e..361b85a 100644 --- a/totalspineseg/utils/iterative_label.py +++ b/totalspineseg/utils/iterative_label.py @@ -20,10 +20,12 @@ def main(): '''.split()), epilog=textwrap.dedent(''' Examples: - iterative_label -s labels_init -o labels --disc-labels 1-7 --vertebrae-labels 9-14 --vertebrae-extra-labels 8 --init-disc 4:224 7:202 5:219 6:207 --init-vertebrae 11:40 14:17 12:34 13:23 --step-diff-label --step-diff-disc --output-disc-step -1 --output-vertebrae-step -1 --map-output 17:92 --map-input 14:92 16:201 17:200 -r - iterative_label -s labels_init -o labels -l localizers --disc-labels 1-7 --vertebrae-labels 9-14 --vertebrae-extra-labels 8 --init-disc 4:224 7:202 --init-vertebrae 11:40 14:17 --step-diff-label --step-diff-disc --output-disc-step -1 --output-vertebrae-step -1 --loc-disc-labels 202-224 --loc-vertebrae-labels 18-41 92 --map-output 17:92 --map-input 14:92 16:201 17:200 -r + iterative_label -s labels_init -o labels --selected-disc-landmarks 2 5 3 4 --disc-labels 1-5 --disc-landmark-labels 2 3 4 5 --disc-landmark-output-labels 63 71 91 100 --canal-labels 7 8 --canal-output-label 2 --cord-labels 9 --cord-output-label 1 -r + iterative_label -s labels_init -o labels --selected-disc-landmarks 2 5 3 4 --disc-labels 1-7 --disc-landmark-labels 4 5 6 7 --disc-landmark-output-labels 63 71 91 100 --vertebrae-labels 9-14 --vertebrae-landmark-output-labels 13 21 41 50 --vertebrae-extra-labels 8 --canal-labels 15 16 --canal-output-label 2 --cord-labels 17 --cord-output-label 1 --sacrum-labels 14 --sacrum-output-label 50 -r + iterative_label -s labels_init -o labels -l localizers --selected-disc-landmarks 2 5 --disc-labels 1-5 --disc-landmark-labels 2 3 4 5 --disc-landmark-output-labels 63 71 91 100 --canal-labels 7 8 --canal-output-label 2 --cord-labels 9 --cord-output-label 1 --loc-disc-labels 63-100 -r + iterative_label -s labels_init -o labels -l localizers --selected-disc-landmarks 4 7 --disc-labels 1-7 --disc-landmark-labels 4 5 6 7 --disc-landmark-output-labels 63 71 91 100 --vertebrae-labels 9-14 --vertebrae-landmark-output-labels 13 21 41 50 --vertebrae-extra-labels 8 --canal-labels 15 16 --canal-output-label 2 --cord-labels 17 --cord-output-label 1 --sacrum-labels 14 --sacrum-output-label 50 --loc-disc-labels 63-100 -r For BIDS: - iterative_label -s derivatives/labels -o derivatives/labels --seg-suffix "_seg" --output-seg-suffix "_seg_seq" -d "sub-" -u "anat" --disc-labels 1 2 3 4 5 6 7 --vertebrae-labels 9 10 11 12 13 14 --vertebrae-extra-labels 8 --init-disc 4:224 7:202 5:219 6:207 --init-vertebrae 11:40 14:17 12:34 13:23 --step-diff-label --step-diff-disc --output-disc-step -1 --output-vertebrae-step -1 --map-output 17:92 --map-input 14:92 16:201 17:200 -r + iterative_label -s derivatives/labels -o derivatives/labels --seg-suffix "_seg" --output-seg-suffix "_seg_seq" -d "sub-" -u "anat" --selected-disc-landmarks 2 5 3 4 --disc-labels 1-5 --disc-landmark-labels 2 3 4 5 --disc-landmark-output-labels 63 71 91 100 --canal-labels 7 8 --canal-output-label 2 --cord-labels 9 --cord-output-label 1 -r '''), formatter_class=argparse.RawTextHelpFormatter ) @@ -39,7 +41,7 @@ def main(): parser.add_argument( '--locs-dir', '-l', type=Path, default=None, help=' '.join(f''' - Folder containing localizers' segmentations to help the labeling. Optional. + Folder containing localizers' segmentations to help the labeling if landmarks not found, Optional. The algorithm will transform the localizer to the segmentation space and use it to detect the matching vertebrae and disc if the init label not found. Mathcing will based on the magority of the voxels of the first vertebrae or disc in the localizer, that intersect with the input segmentation. '''.split()) @@ -72,84 +74,81 @@ def main(): '--loc-suffix', type=str, default='', help='Localizer suffix, defaults to "".' ) + parser.add_argument( + '--selected-disc-landmarks', type=int, nargs='+', default=[], + help='The selected disc labels to use as a landmark from the disc_landmark_labels.' + ) parser.add_argument( '--disc-labels', type=lambda x:list(range(int(x.split('-')[0]), int(x.split('-')[-1]) + 1)), nargs='+', default=[], help='The disc labels.' ) parser.add_argument( - '--init-disc', type=lambda x:map(int, x.split(':')), nargs='+', default=[], - help='Init labels list for disc ordered by priority (input_label:output_label !!without space!!). for example 4:224 5:219 6:202' + '--disc-landmark-labels', type=int, nargs=4, + help='All disc labels that can be used as a landmark: C2C3, C7T1, T12L1 and L5S1.' ) parser.add_argument( - '--output-disc-step', type=int, default=1, - help='The step to take between disc labels in the output, defaults to 1.' + '--disc-landmark-output-labels', type=int, nargs=4, + help='List of output labels for discs C2C3, C7T1, T12L1 and L5S1.' ) parser.add_argument( - '--loc-disc-labels', type=lambda x:list(range(int(x.split('-')[0]), int(x.split('-')[-1]) + 1)), nargs='+', default=[], - help='The disc labels in the localizer used for detecting first disc.' + '--disc-output-step', type=int, default=1, + help='The step to take between disc labels in the output, defaults to 1.' ) parser.add_argument( '--vertebrae-labels', type=lambda x:list(range(int(x.split('-')[0]), int(x.split('-')[-1]) + 1)), nargs='+', default=[], help='The vertebrae labels.' ) + parser.add_argument( + '--vertebrae-landmark-output-labels', type=int, nargs=4, + help='List of output labels for vertebrae C3, T1, L1, Sacrum.' + ) + parser.add_argument( + '--vertebrae-output-step', type=int, default=1, + help='The step to take between vertebrae labels in the output, defaults to 1.' + ) parser.add_argument( '--vertebrae-extra-labels', type=lambda x:list(range(int(x.split('-')[0]), int(x.split('-')[-1]) + 1)), nargs='+', default=[], help='Extra vertebrae labels to add to add to adjacent vertebrae labels.' ) parser.add_argument( - '--init-vertebrae', type=lambda x:map(int, x.split(':')), nargs='+', default=[], - help='Init labels list for vertebrae ordered by priority (input_label:output_label !!without space!!). for example 10:41 11:34 12:18' + '--region-max-sizes', type=int, nargs=4, default=[5, 12, 6, 1], + help='The maximum number of discs/vertebrae for each region (Cervical from C3, Thoracic, Lumbar, Sacrum), defaults to [5, 12, 6, 1].' ) parser.add_argument( - '--output-vertebrae-step', type=int, default=1, - help='The step to take between vertebrae labels in the output, defaults to 1.' + '--loc-disc-labels', type=lambda x:list(range(int(x.split('-')[0]), int(x.split('-')[-1]) + 1)), nargs='+', default=[], + help='The disc labels in the localizer used for detecting first disc.' ) parser.add_argument( - '--loc-vertebrae-labels', type=lambda x:list(range(int(x.split('-')[0]), int(x.split('-')[-1]) + 1)), nargs='+', default=[], - help='The disc labels in the localizer used for detecting first vertebrae.' + '--canal-labels', type=lambda x:list(range(int(x.split('-')[0]), int(x.split('-')[-1]) + 1)), nargs='+', default=[], + help='The canal labels in the segmentation.' ) parser.add_argument( - '--map-input', type=str, nargs='+', default=[], - help=' '.join(f''' - A dict mapping labels from input into the output segmentation. - The format should be input_label:output_label without any spaces. - For example, 14:92 16:201 17:200 to map the input sacrum label 14 to 92, canal label 16 to 201 and spinal cord label 17 to 200. - '''.split()) + '--canal-output-label', type=int, default=0, + help='Output label for the canal, defaults to 0 (Do not output).' ) parser.add_argument( - '--map-output', type=str, nargs='+', default=[], - help=' '.join(f''' - A dict mapping labels from the output of the iterative labeling algorithm into different labels in the output segmentation. - The format should be input_label:output_label without any spaces. - For example, 17:92 to map the iteratively labeled vertebrae 17 to the sacrum label 92. - '''.split()) + '--cord-labels', type=lambda x:list(range(int(x.split('-')[0]), int(x.split('-')[-1]) + 1)), nargs='+', default=[], + help='The spinal cord labels in the segmentation.' ) parser.add_argument( - '--dilation-size', type=int, default=1, - help='Number of voxels to dilate before finding connected voxels to label, defaults to 1 (No dilation).' + '--cord-output-label', type=int, default=0, + help='Output label for the spinal cord, defaults to 0 (Do not output).' ) parser.add_argument( - '--step-diff-label', action="store_true", default=False, - help=' '.join(f''' - Make step only for different labels. When looping on the labels on the z axis, it will give a new label to the next label only - if it is different from the previous label. This is useful if there are labels for odd and even vertebrae, so the next label will - be for even vertebrae only if the previous label was odd. If it is still odd, it should give the same label. - '''.split()) + '--sacrum-labels', type=lambda x:list(range(int(x.split('-')[0]), int(x.split('-')[-1]) + 1)), nargs='+', default=[], + help='The sacrum labels in the segmentation.' ) parser.add_argument( - '--step-diff-disc', action="store_true", default=False, - help=' '.join(f''' - Make step only for different discs. When looping on the labels on the z axis, it will give a new label to the next label only - if there is a disc between them. This exclude the first and last vertebrae since it can be C1 or only contain the spinous process. - '''.split()) + '--sacrum-output-label', type=int, default=0, + help='Output label for the sacrum, defaults to 0 (Do not output).' ) parser.add_argument( - '--default-superior-disc', type=int, default=0, - help='Default superior disc label if no init label found, defaults to 0 (Raise error if init label not found).' + '--dilation-size', type=int, default=1, + help='Number of voxels to dilate before finding connected voxels to label, defaults to 1 (No dilation).' ) parser.add_argument( - '--default-superior-vertebrae', type=int, default=0, - help='Default superior vertebrae label if no init label found, defaults to 0 (Raise error if init label not found).' + '--default-superior-disc', type=int, default=0, + help='Default superior disc label if no init label found, defaults to 0 (Raise error if init label not found).' ) parser.add_argument( '--override', '-r', action="store_true", default=False, @@ -177,22 +176,25 @@ def main(): seg_suffix = args.seg_suffix output_seg_suffix = args.output_seg_suffix loc_suffix = args.loc_suffix + selected_disc_landmarks = args.selected_disc_landmarks disc_labels = [_ for __ in args.disc_labels for _ in (__ if isinstance(__, list) else [__])] - init_disc = dict(args.init_disc) - output_disc_step = args.output_disc_step - loc_disc_labels = [_ for __ in args.loc_disc_labels for _ in (__ if isinstance(__, list) else [__])] + disc_landmark_labels = args.disc_landmark_labels + disc_landmark_output_labels = args.disc_landmark_output_labels + disc_output_step = args.disc_output_step vertebrae_labels = [_ for __ in args.vertebrae_labels for _ in (__ if isinstance(__, list) else [__])] + vertebrae_landmark_output_labels = args.vertebrae_landmark_output_labels + vertebrae_output_step = args.vertebrae_output_step vertebrae_extra_labels = [_ for __ in args.vertebrae_extra_labels for _ in (__ if isinstance(__, list) else [__])] - init_vertebrae = dict(args.init_vertebrae) - output_vertebrae_step = args.output_vertebrae_step - loc_vertebrae_labels = [_ for __ in args.loc_vertebrae_labels for _ in (__ if isinstance(__, list) else [__])] - map_input_list = args.map_input - map_output_list = args.map_output + region_max_sizes = args.region_max_sizes + loc_disc_labels = [_ for __ in args.loc_disc_labels for _ in (__ if isinstance(__, list) else [__])] + canal_labels = [_ for __ in args.canal_labels for _ in (__ if isinstance(__, list) else [__])] + canal_output_label = args.canal_output_label + cord_labels = [_ for __ in args.cord_labels for _ in (__ if isinstance(__, list) else [__])] + cord_output_label = args.cord_output_label + sacrum_labels = [_ for __ in args.sacrum_labels for _ in (__ if isinstance(__, list) else [__])] + sacrum_output_label = args.sacrum_output_label dilation_size = args.dilation_size - step_diff_label = args.step_diff_label - step_diff_disc = args.step_diff_disc default_superior_disc = args.default_superior_disc - default_superior_vertebrae = args.default_superior_vertebrae override = args.override max_workers = args.max_workers quiet = args.quiet @@ -210,38 +212,30 @@ def main(): seg_suffix = "{seg_suffix}" output_seg_suffix = "{output_seg_suffix}" loc_suffix = "{loc_suffix}" + selected_disc_landmarks = {selected_disc_landmarks} disc_labels = {disc_labels} - init_disc = {init_disc} - output_disc_step = {output_disc_step} - loc_disc_labels = {loc_disc_labels} + disc_landmark_labels = {disc_landmark_labels} + disc_landmark_output_labels = {disc_landmark_output_labels} + disc_output_step = {disc_output_step} vertebrae_labels = {vertebrae_labels} + vertebrae_landmark_output_labels = {vertebrae_landmark_output_labels} + vertebrae_output_step = {vertebrae_output_step} vertebrae_extra_labels = {vertebrae_extra_labels} - init_vertebrae = {init_vertebrae} - output_vertebrae_step = {output_vertebrae_step} - loc_vertebrae_labels = {loc_vertebrae_labels} - map_input = {map_input_list} - map_output = {map_output_list} + region_max_sizes = {region_max_sizes} + loc_disc_labels = {loc_disc_labels} + canal_labels = {canal_labels} + canal_output_label = {canal_output_label} + cord_labels = {cord_labels} + cord_output_label = {cord_output_label} + sacrum_labels = {sacrum_labels} + sacrum_output_label = {sacrum_output_label} dilation_size = {dilation_size} - step_diff_label = {step_diff_label} - step_diff_disc = {step_diff_disc} default_superior_disc = {default_superior_disc} - default_superior_vertebrae = {default_superior_vertebrae} override = {override} max_workers = {max_workers} quiet = {quiet} ''')) - # Load maps into a dict - try: - map_input_dict = {int(l_in): int(l_out) for l_in, l_out in map(lambda x:x.split(':'), map_input_list)} - except: - raise ValueError("Input param map_input is not in the right structure. Make sure it is in the right format, e.g., 1:2 3:5") - - try: - map_output_dict = {int(l_in): int(l_out) for l_in, l_out in map(lambda x:x.split(':'), map_output_list)} - except: - raise ValueError("Input param map_output is not in the right structure. Make sure it is in the right format, e.g., 1:2 3:5") - iterative_label_mp( segs_path=segs_path, output_segs_path=output_segs_path, @@ -252,22 +246,25 @@ def main(): seg_suffix=seg_suffix, output_seg_suffix=output_seg_suffix, loc_suffix=loc_suffix, + selected_disc_landmarks=selected_disc_landmarks, disc_labels=disc_labels, - init_disc=init_disc, - output_disc_step=output_disc_step, - loc_disc_labels=loc_disc_labels, + disc_landmark_labels=disc_landmark_labels, + disc_landmark_output_labels=disc_landmark_output_labels, + disc_output_step=disc_output_step, vertebrae_labels=vertebrae_labels, + vertebrae_landmark_output_labels=vertebrae_landmark_output_labels, + vertebrae_output_step=vertebrae_output_step, vertebrae_extra_labels=vertebrae_extra_labels, - init_vertebrae=init_vertebrae, - output_vertebrae_step=output_vertebrae_step, - loc_vertebrae_labels=loc_vertebrae_labels, - map_input_dict=map_input_dict, - map_output_dict=map_output_dict, + region_max_sizes=region_max_sizes, + loc_disc_labels=loc_disc_labels, + canal_labels=canal_labels, + canal_output_label=canal_output_label, + cord_labels=cord_labels, + cord_output_label=cord_output_label, + sacrum_labels=sacrum_labels, + sacrum_output_label=sacrum_output_label, dilation_size=dilation_size, - step_diff_label=step_diff_label, - step_diff_disc=step_diff_disc, default_superior_disc=default_superior_disc, - default_superior_vertebrae=default_superior_vertebrae, override=override, max_workers=max_workers, quiet=quiet, @@ -283,22 +280,25 @@ def iterative_label_mp( seg_suffix='', output_seg_suffix='', loc_suffix='', + selected_disc_landmarks=[], disc_labels=[], - init_disc={}, - output_disc_step=1, - loc_disc_labels=[], + disc_landmark_labels=[], + disc_landmark_output_labels=[], + disc_output_step=1, vertebrae_labels=[], + vertebrae_landmark_output_labels=[], + vertebrae_output_step=1, vertebrae_extra_labels=[], - init_vertebrae={}, - output_vertebrae_step=1, - loc_vertebrae_labels=[], - map_input_dict={}, - map_output_dict={}, + region_max_sizes=[5, 12, 6, 1], + loc_disc_labels=[], + canal_labels=[], + canal_output_label=0, + cord_labels=[], + cord_output_label=0, + sacrum_labels=[], + sacrum_output_label=0, dilation_size=1, - step_diff_label=False, - step_diff_disc=False, default_superior_disc=0, - default_superior_vertebrae=0, override=False, max_workers=mp.cpu_count(), quiet=False, @@ -325,22 +325,25 @@ def iterative_label_mp( process_map( partial( _iterative_label, + selected_disc_landmarks=selected_disc_landmarks, disc_labels=disc_labels, - output_disc_step=output_disc_step, - loc_disc_labels=loc_disc_labels, - init_disc=init_disc, + disc_landmark_labels=disc_landmark_labels, + disc_landmark_output_labels=disc_landmark_output_labels, + disc_output_step=disc_output_step, vertebrae_labels=vertebrae_labels, + vertebrae_landmark_output_labels=vertebrae_landmark_output_labels, + vertebrae_output_step=vertebrae_output_step, vertebrae_extra_labels=vertebrae_extra_labels, - init_vertebrae=init_vertebrae, - output_vertebrae_step=output_vertebrae_step, - loc_vertebrae_labels=loc_vertebrae_labels, - map_input_dict=map_input_dict, - map_output_dict=map_output_dict, + region_max_sizes=region_max_sizes, + loc_disc_labels=loc_disc_labels, + canal_labels=canal_labels, + canal_output_label=canal_output_label, + cord_labels=cord_labels, + cord_output_label=cord_output_label, + sacrum_labels=sacrum_labels, + sacrum_output_label=sacrum_output_label, dilation_size=dilation_size, - step_diff_label=step_diff_label, - step_diff_disc=step_diff_disc, default_superior_disc=default_superior_disc, - default_superior_vertebrae=default_superior_vertebrae, override=override, ), seg_path_list, @@ -355,22 +358,25 @@ def _iterative_label( seg_path, output_seg_path, loc_path=None, + selected_disc_landmarks=[], disc_labels=[], - init_disc={}, - output_disc_step=1, - loc_disc_labels=[], + disc_landmark_labels=[], + disc_landmark_output_labels=[], + disc_output_step=1, vertebrae_labels=[], + vertebrae_landmark_output_labels=[], + vertebrae_output_step=1, vertebrae_extra_labels=[], - init_vertebrae={}, - output_vertebrae_step=1, - loc_vertebrae_labels=[], - map_input_dict={}, - map_output_dict={}, + region_max_sizes=[5, 12, 6, 1], + loc_disc_labels=[], + canal_labels=[], + canal_output_label=0, + cord_labels=[], + cord_output_label=0, + sacrum_labels=[], + sacrum_output_label=0, dilation_size=1, - step_diff_label=False, - step_diff_disc=False, default_superior_disc=0, - default_superior_vertebrae=0, override=False, ): ''' @@ -392,22 +398,25 @@ def _iterative_label( output_seg = iterative_label( seg, loc, + selected_disc_landmarks=selected_disc_landmarks, disc_labels=disc_labels, - init_disc=init_disc, - output_disc_step=output_disc_step, - loc_disc_labels=loc_disc_labels, + disc_landmark_labels=disc_landmark_labels, + disc_landmark_output_labels=disc_landmark_output_labels, + disc_output_step=disc_output_step, vertebrae_labels=vertebrae_labels, + vertebrae_landmark_output_labels=vertebrae_landmark_output_labels, + vertebrae_output_step=vertebrae_output_step, vertebrae_extra_labels=vertebrae_extra_labels, - init_vertebrae=init_vertebrae, - output_vertebrae_step=output_vertebrae_step, - loc_vertebrae_labels=loc_vertebrae_labels, - map_input_dict=map_input_dict, - map_output_dict=map_output_dict, + region_max_sizes=region_max_sizes, + loc_disc_labels=loc_disc_labels, + canal_labels=canal_labels, + canal_output_label=canal_output_label, + cord_labels=cord_labels, + cord_output_label=cord_output_label, + sacrum_labels=sacrum_labels, + sacrum_output_label=sacrum_output_label, dilation_size=dilation_size, - step_diff_label=step_diff_label, - step_diff_disc=step_diff_disc, - default_superior_disc=default_superior_disc, - default_superior_vertebrae=default_superior_vertebrae, + disc_default_superior_output=default_superior_disc, ) except ValueError as e: output_seg_path.is_file() and output_seg_path.unlink() @@ -430,22 +439,25 @@ def _iterative_label( def iterative_label( seg, loc=None, + selected_disc_landmarks=[], disc_labels=[], - init_disc={}, - output_disc_step=1, - loc_disc_labels=[], + disc_landmark_labels=[], + disc_landmark_output_labels=[], + disc_output_step=1, vertebrae_labels=[], + vertebrae_landmark_output_labels=[], + vertebrae_output_step=1, vertebrae_extra_labels=[], - init_vertebrae={}, - output_vertebrae_step=1, - loc_vertebrae_labels=[], - map_input_dict={}, - map_output_dict={}, + region_max_sizes=[5, 12, 6, 1], + loc_disc_labels=[], + canal_labels=[], + canal_output_label=0, + cord_labels=[], + cord_output_label=0, + sacrum_labels=[], + sacrum_output_label=0, dilation_size=1, - step_diff_label=False, - step_diff_disc=False, - default_superior_disc=0, - default_superior_vertebrae=0, + disc_default_superior_output=0, ): ''' Label Vertebrae, IVDs, Spinal Cord and canal from init segmentation. @@ -454,8 +466,12 @@ def iterative_label( 2. Find connected voxels for each vertebrae label and label them into separate labels 3. Combine sequential vertebrae labels based on some conditions 4. Combine extra labels with adjacent vertebrae labels - 5. Map labels from the iteative algorithm output, to the final output (e.g., map the vertebrae label from the iteative algorithm output to the special sacrum label) - 6. Map input labels to the final output (e.g., map the input sacrum, canal and spinal cord labels to the output labels) + 5. Find the landmark disc labels and output labels + 6. Label the discs with the output labels + 7. Find the matching vertebrae labels to the discs landmarks + 8. Label the vertebrae with the output labels + 9. Map labels from the iteative algorithm output, to the final output (e.g., map the vertebrae label from the iteative algorithm output to the special sacrum label) + 10. Map input labels to the final output (e.g., map the input sacrum, canal and spinal cord labels to the output labels) Parameters ---------- @@ -463,262 +479,634 @@ def iterative_label( Segmentation image loc : nibabel.nifti1.Nifti1Image Localizer image to use for detecting first vertebrae and disc (optional) + selected_disc_landmarks : list + List of disc labels to use as a landmark from the disc_landmark_labels disc_labels : list - The disc labels - init_disc : dict - Init labels list for disc ordered by priority (input_label:output_label) - output_disc_step : int + The disc labels in the segmentation + disc_landmark_labels : list + All disc labels that can be used as a landmark: [C2C3, C7T1, T12L1, L5S1] + disc_landmark_output_labels : list + List of output labels for discs [C2C3, C7T1, T12L1, L5S1] + disc_output_step : int The step to take between disc labels in the output - loc_disc_labels : list - Localizer labels to use for detecting first disc vertebrae_labels : list - The vertebrae labels + The vertebrae labels in the segmentation + vertebrae_landmark_output_labels : list + List of output labels for vertebrae [C3, T1, L1, Sacrum] + vertebrae_output_step : int + The step to take between vertebrae labels in the output vertebrae_extra_labels : list Extra vertebrae labels to add to add to adjacent vertebrae labels - init_vertebrae : dict - Init labels list for vertebrae ordered by priority (input_label:output_label) - output_vertebrae_step : int - The step to take between vertebrae labels in the output - loc_vertebrae_labels : list - Localizer labels to use for detecting first vertebrae - map_input_dict : dict - A dict mapping labels from input into the output segmentation - map_output_dict : dict - A dict mapping labels from the output of the iterative labeling algorithm into different labels in the output segmentation + region_max_sizes : list + The maximum number of discs/vertebrae for each region (Cervical from C3, Thoracic, Lumbar, Sacrum). + loc_disc_labels : list + Localizer labels to use for detecting first disc + canal_labels : list + Canal labels in the segmentation + canal_output_label : int + Output label for the canal + cord_labels : list + Spinal Cord labels in the segmentation + cord_output_label : int + Output label for the spinal cord + sacrum_labels : list + Sacrum labels in the segmentation + sacrum_output_label : int + Output label for the sacrum dilation_size : int Number of voxels to dilate before finding connected voxels to label - step_diff_label : bool - Make step only for different labels - step_diff_disc : bool - Make step only for different discs default_superior_disc : int Default superior disc label if no init label found - default_superior_vertebrae : int - Default superior vertebrae label if no init label found Returns ------- nibabel.nifti1.Nifti1Image Segmentation image with labeled vertebrae, IVDs, Spinal Cord and canal ''' + # Region default sizes for the discs and vertebrae (Cervical, Thoracic, Lumbar, Sacrum) + region_default_sizes=[5, 12, 5, 1] + seg_data = np.asanyarray(seg.dataobj).round().astype(np.uint8) output_seg_data = np.zeros_like(seg_data) - loc_data = loc and np.asanyarray(loc.dataobj).round().astype(np.uint8) + # Get the canal centerline indices to use for sorting the discs and vertebrae based on the prjection on the canal centerline + canal_centerline_indices = _get_canal_centerline_indices(seg_data, canal_labels + cord_labels) - # If localizer is provided, transform it to the segmentation space - if loc_data is not None: - loc_data = tio.Resample( - tio.ScalarImage(tensor=seg_data[None, ...], affine=seg.affine) - )( - tio.LabelMap(tensor=loc_data[None, ...], affine=loc.affine) - ).data.numpy()[0, ...].astype(np.uint8) + # Get the mask of the voxels anterior to the canal, this helps in sorting the vertebrae considering only the vertebrae body + mask_aterior_to_canal = _get_mask_aterior_to_canal(seg_data, canal_labels + cord_labels) + + # Get sorted connected components superio-inferior (SI) for the disc labels + disc_mask_labeled, disc_num_labels, disc_sorted_labels, disc_sorted_z_indices = _get_si_sorted_components( + seg, + disc_labels, + canal_centerline_indices, + mask_aterior_to_canal, + dilation_size, + combine_labels=True, + ) + + # Get sorted connected components superio-inferior (SI) for the vertebrae labels + vert_mask_labeled, vert_num_labels, vert_sorted_labels, vert_sorted_z_indices = _get_si_sorted_components( + seg, + vertebrae_labels, + canal_centerline_indices, + mask_aterior_to_canal, + dilation_size, + ) + + # Combine sequential vertebrae labels if they have the same value in the original segmentation + vert_mask_labeled, vert_num_labels, vert_sorted_labels, vert_sorted_z_indices = _merge_vertebrae_with_same_label( + seg, + vertebrae_labels, + vert_mask_labeled, + vert_num_labels, + vert_sorted_labels, + vert_sorted_z_indices, + canal_centerline_indices, + mask_aterior_to_canal, + ) + + # Combine sequential vertebrae labels if there is no disc between them + vert_mask_labeled, vert_num_labels, vert_sorted_labels, vert_sorted_z_indices = _merge_vertebrae_labels_with_no_disc_between( + seg, + vert_mask_labeled, + vert_num_labels, + vert_sorted_labels, + vert_sorted_z_indices, + disc_sorted_z_indices, + canal_centerline_indices, + mask_aterior_to_canal, + ) + + # Combine extra labels with adjacent vertebrae labels + vert_mask_labeled, vert_num_labels, vert_sorted_labels, vert_sorted_z_indices = _merge_extra_labels_with_adjacent_vertebrae( + seg, + vert_mask_labeled, + vert_num_labels, + vert_sorted_labels, + vert_sorted_z_indices, + vertebrae_extra_labels, + canal_centerline_indices, + mask_aterior_to_canal, + ) + + # Get the landmark disc labels and output labels - {label in sorted labels: output label} + # TODO Currently only the first 2 landmark from selected_disc_landmarks is used, to get all landmarks see TODO in the function + map_disc_sorted_labels_landmark2output = _get_landmark_output_labels( + seg, + loc, + disc_mask_labeled, + disc_sorted_labels, + selected_disc_landmarks, + disc_landmark_labels, + disc_landmark_output_labels, + loc_disc_labels, + disc_default_superior_output, + ) + + # Build a list containing all possible labels for the disc ordered superio-inferior + all_possible_disc_output_labels = [] + for l, s in zip(disc_landmark_output_labels, region_max_sizes): + for i in range(s): + all_possible_disc_output_labels.append(l + i * disc_output_step) + + # Make a list containing all possible labels for the disc ordered superio-inferior with the default region sizes + all_default_disc_output_labels = [] + for l, s in zip(disc_landmark_output_labels, region_default_sizes): + for i in range(s): + all_default_disc_output_labels.append(l + i * disc_output_step) + + # Make a dict mapping the sorted disc labels to the output labels + map_disc_sorted_labels_2output = {} + + # We loop over all the landmarks starting from the most superior + for l_disc in [_ for _ in disc_sorted_labels if _ in map_disc_sorted_labels_landmark2output]: + + # If this is the most superior landmark, we have to adjust the start indices to start from the most superior label in the image + if len(map_disc_sorted_labels_2output) == 0: + # Get the index of the current landmark in the sorted disc labels + start_l = disc_sorted_labels.index(l_disc) + + # Get the index of the current landmark in the list of all default disc output labels + start_o_def = all_default_disc_output_labels.index(map_disc_sorted_labels_landmark2output[l_disc]) + + # Adjust the start indices + start_l, start_o_def = max(0, start_l - start_o_def), max(0, start_o_def - start_l) + + # Map the sorted disc labels to the output labels + for l, o in zip(disc_sorted_labels[start_l:], all_default_disc_output_labels[start_o_def:]): + map_disc_sorted_labels_2output[l] = o + + # Get the index of the current landmark in the sorted disc labels + start_l = disc_sorted_labels.index(l_disc) + + # Get the index of the current landmark in the list of all possible disc output labels + start_o = all_possible_disc_output_labels.index(map_disc_sorted_labels_landmark2output[l_disc]) + + # Map the sorted disc labels to the output labels + # This will ovveride the mapping from the previous landmarks for all labels inferior to the current landmark + for l, o in zip(disc_sorted_labels[start_l:], all_possible_disc_output_labels[start_o:]): + map_disc_sorted_labels_2output[l] = o + + # Label the discs with the output labels superio-inferior + for l, o in map_disc_sorted_labels_2output.items(): + output_seg_data[disc_mask_labeled == l] = o + + if vert_num_labels > 0: + # Build a list containing all possible labels for the vertebrae ordered superio-inferior + # We start with the C1 and C2 labels as the first landmark is the C3 vertebrae + all_possible_vertebrae_output_labels = [ + vertebrae_landmark_output_labels[0] - 2 * vertebrae_output_step, # C1 + vertebrae_landmark_output_labels[0] - vertebrae_output_step # C2 + ] + for l, s in zip(vertebrae_landmark_output_labels, region_max_sizes): + for i in range(s): + all_possible_vertebrae_output_labels.append(l + i * vertebrae_output_step) + + # Make a list containing all possible labels for the vertebrae ordered superio-inferior with the default region sizes + all_default_vertebrae_output_labels = [ + vertebrae_landmark_output_labels[0] - 2 * vertebrae_output_step, # C1 + vertebrae_landmark_output_labels[0] - vertebrae_output_step # C2 + ] + for l, s in zip(vertebrae_landmark_output_labels, region_default_sizes): + for i in range(s): + all_default_vertebrae_output_labels.append(l + i * vertebrae_output_step) + + # Sort the combined disc+vert labels by their z-index + sorted_labels = vert_sorted_labels + disc_sorted_labels + sorted_z_indices = vert_sorted_z_indices + disc_sorted_z_indices + is_vert = [True] * len(vert_sorted_labels) + [False] * len(disc_sorted_labels) + + # Sort the labels by their z-index (reversed to go from superior to inferior) + sorted_z_indices, sorted_labels, is_vert = zip(*sorted(zip(sorted_z_indices, sorted_labels, is_vert))[::-1]) + + # Make a dict mapping disc to vertebrae labels + disc_output_labels_2vert = dict(zip(all_possible_disc_output_labels, all_possible_vertebrae_output_labels[2:])) + + # Make a dict mapping the sorted vertebrae labels to the output labels + map_vert_sorted_labels_2output = {} + + l_vert_output = 0 + # We loop over all the labels starting from the most superior, and we map the vertebrae labels to the output labels + for idx, curr_l, curr_is_vert in zip(range(len(sorted_labels)), sorted_labels, is_vert): + if not curr_is_vert: # This is a disc + # If the current disc is not in the map, continue + if curr_l not in map_disc_sorted_labels_2output: + continue + + # Get the output label for the disc and vertebrae + l_disc_output = map_disc_sorted_labels_2output[curr_l] + l_vert_output = disc_output_labels_2vert[l_disc_output] + + if idx > 0 and len(map_vert_sorted_labels_2output) == 0: # This is the first disc + # Get the index of the current vertebrae in the default vertebrae output labels list + i = all_default_vertebrae_output_labels.index(l_vert_output) + + # Get the labels of the vertebrae superior to the current disc + prev_vert_ls = [l for l, _is_v in zip(sorted_labels[idx - 1::-1], is_vert[idx - 1::-1]) if _is_v] + + # Map all the vertebrae superior to the current disc to the default vertebrae output labels + for l, o in zip(prev_vert_ls, all_default_vertebrae_output_labels[i - 1::-1]): + map_vert_sorted_labels_2output[l] = o + + elif l_vert_output > 0: # This is a vertebrae + map_vert_sorted_labels_2output[curr_l] = l_vert_output + + # Label the vertebrae with the output labels superio-inferior + for l, o in map_vert_sorted_labels_2output.items(): + output_seg_data[vert_mask_labeled == l] = o + + # Map Spinal Canal to the output label + if canal_labels is not None and len(canal_labels) > 0 and canal_output_label > 0: + output_seg_data[np.isin(seg_data, canal_labels)] = canal_output_label + + # Map Spinal Cord to the output label + if cord_labels is not None and len(cord_labels) > 0 and cord_output_label > 0: + output_seg_data[np.isin(seg_data, cord_labels)] = cord_output_label + + # Map Sacrum to the output label + if sacrum_labels is not None and len(sacrum_labels) > 0 and sacrum_output_label > 0: + output_seg_data[np.isin(seg_data, sacrum_labels)] = sacrum_output_label + + output_seg = nib.Nifti1Image(output_seg_data, seg.affine, seg.header) + + return output_seg + +def _get_canal_centerline_indices( + seg_data, + canal_labels=[], + ): + ''' + Get the indices of the canal centerline. + ''' + # Get array of indices for x, y, and z axes + indices = np.indices(seg_data.shape) + + # Create a mask of the canal + mask_canal = np.isin(seg_data, canal_labels) + + + # Create a mask the canal centerline by finding the middle voxels in x and y axes for each z index + mask_min_x_indices = np.min(indices[0], where=mask_canal, initial=np.iinfo(indices.dtype).max, keepdims=True, axis=(0, 1)) + mask_max_x_indices = np.max(indices[0], where=mask_canal, initial=np.iinfo(indices.dtype).min, keepdims=True, axis=(0, 1)) + mask_mid_x = indices[0] == ((mask_min_x_indices + mask_max_x_indices) // 2) + mask_min_y_indices = np.min(indices[1], where=mask_canal, initial=np.iinfo(indices.dtype).max, keepdims=True, axis=(0, 1)) + mask_max_y_indices = np.max(indices[1], where=mask_canal, initial=np.iinfo(indices.dtype).min, keepdims=True, axis=(0, 1)) + mask_mid_y = indices[1] == ((mask_min_y_indices + mask_max_y_indices) // 2) + mask_canal_centerline = mask_canal * mask_mid_x * mask_mid_y + + # Get the indices of the canal centerline + return np.array(np.nonzero(mask_canal_centerline)).T + +def _sort_labels_si( + mask_labeled, + labels, + canal_centerline_indices, + mask_aterior_to_canal=None, + ): + ''' + Sort the labels by their z-index (reversed to go from superior to inferior). + ''' + # Get the indices of the center of mass for each label + labels_indices = np.array(ndi.center_of_mass(np.isin(mask_labeled, labels), mask_labeled, labels)) + + # Get the distance of each label indices from the canal centerline + labels_distances_from_centerline = np.linalg.norm(labels_indices[:, None, :] - canal_centerline_indices[None, ...],axis=2) + + # Get the z-index of the closest canal centerline voxel for each label + labels_z_indices = canal_centerline_indices[np.argmin(labels_distances_from_centerline, axis=1), -1] + + # If mask_aterior_to_canal is provided, calculate the center of mass in this mask if the label is inside the mask + if mask_aterior_to_canal is not None: + # Save the existing labels z-index in a dict + labels_z_indices_dict = dict(zip(labels, labels_z_indices)) + + # Get the part that is anterior to the canal od mask_labeled + mask_labeled_aterior_to_canal = mask_aterior_to_canal * mask_labeled + + # Get the labels that contain voxels anterior to the canal + labels_masked = np.isin(labels, mask_labeled_aterior_to_canal) + + # Get the indices of the center of mass for each label + labels_masked_indices = np.array(ndi.center_of_mass(np.isin(mask_labeled_aterior_to_canal, labels_masked), mask_labeled_aterior_to_canal, labels_masked)) + + # Get the distance of each label indices for each voxel in the canal centerline + labels_masked_distances_from_centerline = np.linalg.norm(labels_masked_indices[:, None, :] - canal_centerline_indices[None, :],axis=2) + + # Get the z-index of the closest canal centerline voxel for each label + labels_masked_z_indices = canal_centerline_indices[np.argmin(labels_masked_distances_from_centerline, axis=1), -1] + + # Update the dict with the new z-index of the labels anterior to the canal + for l, z in zip(labels_masked, labels_masked_z_indices): + labels_z_indices_dict[l] = z + + # Update the labels_z_indices from the dict + labels_z_indices = [labels_z_indices_dict[l] for l in labels] + + # Sort the labels by their z-index (reversed to go from superior to inferior) + return zip(*sorted(zip(labels_z_indices, labels))[::-1]) + +def _get_si_sorted_components( + seg, + labels, + canal_centerline_indices, + mask_aterior_to_canal=None, + dilation_size=1, + combine_labels=False, + ): + ''' + Get sorted connected components superio-inferior (SI) for the given labels in the segmentation. + ''' + seg_data = np.asanyarray(seg.dataobj).round().astype(np.uint8) binary_dilation_structure = ndi.iterate_structure(ndi.generate_binary_structure(3, 1), dilation_size) - # Arrays to store the z indexes of the discs sorted superior to inferior - disc_sorted_z_indexes = [] + # Skip if no labels are provided + if len(labels) == 0: + return None, 0, [], [] - # We run the same iterative algorithm for discs and vertebrae - for labels, extra_labels, step, init, default_sup, loc_labels, is_vert in ( - (disc_labels, [], output_disc_step, init_disc, default_superior_disc, loc_disc_labels, False), - (vertebrae_labels, vertebrae_extra_labels, output_vertebrae_step, init_vertebrae, default_superior_vertebrae, loc_vertebrae_labels, True)): + if combine_labels: + # For discs, combine all labels before label continue voxels since the discs not touching each other + _labels = [labels] + else: + _labels = [[_] for _ in labels] - # Skip if no labels are provided - if len(labels) == 0: - continue + # Init labeled segmentation + mask_labeled, num_labels = np.zeros_like(seg_data, dtype=np.uint32), 0 - if is_vert: - _labels = [[_] for _ in labels] - else: - # For discs, combine all labels before label continue voxels since the discs not touching each other - _labels = [labels] + # For each label, find connected voxels and label them into separate labels + for l in _labels: + mask = np.isin(seg_data, l) - # Init labeled segmentation - mask_labeled, num_labels = np.zeros_like(seg_data, dtype=np.uint32), 0 + # Dilate the mask to combine small disconnected regions + mask_dilated = ndi.binary_dilation(mask, binary_dilation_structure) - # For each label, find connected voxels and label them into separate labels - for l in _labels: - mask = np.isin(seg_data, l) + # Label the connected voxels in the dilated mask into separate labels + tmp_mask_labeled, tmp_num_labels = ndi.label(mask_dilated.astype(np.uint32), np.ones((3, 3, 3))) - # Dilate the mask to combine small disconnected regions - mask_dilated = ndi.binary_dilation(mask, binary_dilation_structure) + # Undo dilation + tmp_mask_labeled *= mask - # Label the connected voxels in the dilated mask into separate labels - tmp_mask_labeled, tmp_num_labels = ndi.label(mask_dilated.astype(np.uint32), np.ones((3, 3, 3))) + # Add current labels to the labeled segmentation + if tmp_num_labels > 0: + mask_labeled[tmp_mask_labeled != 0] = tmp_mask_labeled[tmp_mask_labeled != 0] + num_labels + num_labels += tmp_num_labels - # Undo dilation - tmp_mask_labeled *= mask + # If no label found, raise error + if num_labels == 0: + raise ValueError(f"Some label must be in the segmentation (labels: {labels})") - # Add current labels to the labeled segmentation - if tmp_num_labels > 0: - mask_labeled[tmp_mask_labeled != 0] = tmp_mask_labeled[tmp_mask_labeled != 0] + num_labels - num_labels += tmp_num_labels + # Reduce size of mask_labeled + if mask_labeled.max() < np.iinfo(np.uint8).max: + mask_labeled = mask_labeled.astype(np.uint8) + elif mask_labeled.max() < np.iinfo(np.uint16).max: + mask_labeled = mask_labeled.astype(np.uint16) - # If no label found, raise error - if num_labels == 0: - raise ValueError(f"Some label must be in the segmentation (labels: {labels})") + # Sort the labels by their z-index (reversed to go from superior to inferior) + sorted_z_indices, sorted_labels = _sort_labels_si( + mask_labeled, range(1,num_labels+1), canal_centerline_indices, mask_aterior_to_canal + ) + return mask_labeled, num_labels, list(sorted_labels), list(sorted_z_indices) - # Reduce size of mask_labeled - if mask_labeled.max() < np.iinfo(np.uint8).max: - mask_labeled = mask_labeled.astype(np.uint8) - elif mask_labeled.max() < np.iinfo(np.uint16).max: - mask_labeled = mask_labeled.astype(np.uint16) +def _get_mask_aterior_to_canal( + seg_data, + canal_labels=[], + ): + ''' + Get the mask of the voxels anterior to the canal. + ''' + # Get array of indices for x, y, and z axes + indices = np.indices(seg_data.shape) - # Get the z index of the center of mass for each label - canonical_mask_labeled = np.asanyarray(nib.as_closest_canonical(nib.Nifti1Image(mask_labeled, seg.affine, seg.header)).dataobj).round().astype(mask_labeled.dtype) - mask_labeled_z_indexes = [_[-1] for _ in ndi.center_of_mass(canonical_mask_labeled != 0, canonical_mask_labeled, range(1, num_labels + 1))] + # Create a mask of the canal + mask_canal = np.isin(seg_data, canal_labels) - # Sort the labels by their z-index (reversed to go from superior to inferior) - sorted_z_indexes, sorted_labels = zip(*sorted(zip(mask_labeled_z_indexes,range(1,num_labels+1)))[::-1]) - - # In this part we loop over the labels superior to inferior and combine sequential vertebrae labels based on some conditions - if is_vert: - # Combine sequential vertebrae labels if they have the same value in the original segmentation - # This is useful when part part of the vertebrae is not connected to the main part but have the same odd/even value - if step_diff_label and len(labels) - len(init) > 1: - new_sorted_labels = [] - - # Store the previous label and the original label of the previous label - prev_l, prev_orig_label = 0, 0 - - # Loop over the sorted labels - for l in sorted_labels: - # Get the original label of the current label - curr_orig_label = seg_data[mask_labeled == l].flat[0] - - # Combine the current label with the previous label if they have the same original label - if curr_orig_label == prev_orig_label: - # Combine the current label with the previous label - mask_labeled[mask_labeled == l] = prev_l - num_labels -= 1 - - else: - # Add the current label to the new sorted labels - new_sorted_labels.append(l) - prev_l, prev_orig_label = l, curr_orig_label - - # Get the z index of the center of mass for each label - canonical_mask_labeled = np.asanyarray(nib.as_closest_canonical(nib.Nifti1Image(mask_labeled, seg.affine, seg.header)).dataobj).round().astype(mask_labeled.dtype) - mask_labeled_z_indexes = [_[-1] for _ in ndi.center_of_mass(canonical_mask_labeled != 0, canonical_mask_labeled, new_sorted_labels)] - - # Sort the labels by their z-index (reversed to go from superior to inferior) - sorted_z_indexes, sorted_labels = zip(*sorted(zip(mask_labeled_z_indexes, new_sorted_labels))[::-1]) - - # Reduce size of mask_labeled - if mask_labeled.max() < np.iinfo(np.uint8).max: - mask_labeled = mask_labeled.astype(np.uint8) - elif mask_labeled.max() < np.iinfo(np.uint16).max: - mask_labeled = mask_labeled.astype(np.uint16) - - # Combine sequential vertebrae labels if there is no disc between them - if step_diff_disc and len(disc_sorted_z_indexes) > 0: - new_sorted_labels = [] - - # Store the previous label and the z index of the previous label - prev_l, prev_z = 0, 0 - - for l, z in zip(sorted_labels, sorted_z_indexes): - # Do not combine first and last vertebrae since it can be C1 or only contain the spinous process - if l not in sorted_labels[:2] and l != sorted_labels[-1] and prev_l > 0 and not any(z < _ < prev_z for _ in disc_sorted_z_indexes): - # Combine the current label with the previous label - mask_labeled[mask_labeled == l] = prev_l - num_labels -= 1 - - else: - # Add the current label to the new sorted labels - new_sorted_labels.append(l) - prev_l, prev_z = l, z - - sorted_labels = new_sorted_labels - - # Reduce size of mask_labeled - if mask_labeled.max() < np.iinfo(np.uint8).max: - mask_labeled = mask_labeled.astype(np.uint8) - elif mask_labeled.max() < np.iinfo(np.uint16).max: - mask_labeled = mask_labeled.astype(np.uint16) + # Create a mask the canal centerline by finding the middle voxels in x and y axes for each z index + mask_min_y_indices = np.min(indices[1], where=mask_canal, initial=np.iinfo(indices.dtype).max, keepdims=True, axis=(0, 1)) + mask_max_y_indices = np.max(indices[1], where=mask_canal, initial=np.iinfo(indices.dtype).min, keepdims=True, axis=(0, 1)) + mask_mid_y_indices = (mask_min_y_indices + mask_max_y_indices) // 2 + + return indices[1] > mask_mid_y_indices + +def _merge_vertebrae_with_same_label( + seg, + labels, + mask_labeled, + num_labels, + sorted_labels, + sorted_z_indices, + canal_centerline_indices, + mask_aterior_to_canal=None, + ): + ''' + Combine sequential vertebrae labels if they have the same value in the original segmentation. + This is useful when part part of the vertebrae is not connected to the main part but have the same odd/even value. + ''' + if num_labels == 0 or len(labels) <= 1: + return mask_labeled, num_labels, sorted_labels, sorted_z_indices + + seg_data = np.asanyarray(seg.dataobj).round().astype(np.uint8) + + new_sorted_labels = [] + + # Store the previous label and the original label of the previous label + prev_l, prev_orig_label = 0, 0 + + # Loop over the sorted labels + for l in sorted_labels: + # Get the original label of the current label + curr_orig_label = seg_data[mask_labeled == l].flat[0] + + # Combine the current label with the previous label if they have the same original label + if curr_orig_label == prev_orig_label: + # Combine the current label with the previous label + mask_labeled[mask_labeled == l] = prev_l + num_labels -= 1 else: - # Save the z indexes of the discs - disc_sorted_z_indexes = sorted_z_indexes + # Add the current label to the new sorted labels + new_sorted_labels.append(l) + prev_l, prev_orig_label = l, curr_orig_label - # Find the most superior label in the segmentation - first_label = 0 - for k, v in init.items(): - if k in seg_data: - first_label = v - step * sorted_labels.index(mask_labeled[seg_data == k].flat[0]) - break + sorted_labels = new_sorted_labels + + # Reduce size of mask_labeled + if mask_labeled.max() < np.iinfo(np.uint8).max: + mask_labeled = mask_labeled.astype(np.uint8) + elif mask_labeled.max() < np.iinfo(np.uint16).max: + mask_labeled = mask_labeled.astype(np.uint16) + + # Sort the labels by their z-index (reversed to go from superior to inferior) + sorted_z_indices, sorted_labels = _sort_labels_si( + mask_labeled, sorted_labels, canal_centerline_indices, mask_aterior_to_canal + ) - # If no init label found, set it from the localizer - if first_label == 0 and loc_data is not None: - # Make mask for the intersection of the localizer labels and the labels in the segmentation - mask = np.isin(loc_data, loc_labels) * np.isin(mask_labeled, sorted_labels) + return mask_labeled, num_labels, list(sorted_labels), list(sorted_z_indices) +def _merge_vertebrae_labels_with_no_disc_between( + seg, + mask_labeled, + num_labels, + sorted_labels, + sorted_z_indices, + disc_sorted_z_indices, + canal_centerline_indices, + mask_aterior_to_canal, + ): + ''' + Combine sequential vertebrae labels if there is no disc between them. + ''' + if num_labels == 0 or len(disc_sorted_z_indices) == 0: + return mask_labeled, num_labels, sorted_labels, sorted_z_indices + + new_sorted_labels = [] + + # Store the previous label and the z index of the previous label + prev_l, prev_z = 0, 0 + + for l, z in zip(sorted_labels, sorted_z_indices): + # Do not combine first and last vertebrae since it can be C1 or only contain the spinous process + if l not in sorted_labels[:2] and l != sorted_labels[-1] and prev_l > 0 and not any(z < _ < prev_z for _ in disc_sorted_z_indices): + # Combine the current label with the previous label + mask_labeled[mask_labeled == l] = prev_l + num_labels -= 1 + + else: + # Add the current label to the new sorted labels + new_sorted_labels.append(l) + prev_l, prev_z = l, z + + sorted_labels = new_sorted_labels + + # Reduce size of mask_labeled + if mask_labeled.max() < np.iinfo(np.uint8).max: + mask_labeled = mask_labeled.astype(np.uint8) + elif mask_labeled.max() < np.iinfo(np.uint16).max: + mask_labeled = mask_labeled.astype(np.uint16) + + # Sort the labels by their z-index (reversed to go from superior to inferior) + sorted_z_indices, sorted_labels = _sort_labels_si( + mask_labeled, sorted_labels, canal_centerline_indices, mask_aterior_to_canal + ) + + return mask_labeled, num_labels, list(sorted_labels), list(sorted_z_indices) + +def _merge_extra_labels_with_adjacent_vertebrae( + seg, + mask_labeled, + num_labels, + sorted_labels, + sorted_z_indices, + extra_labels, + canal_centerline_indices, + mask_aterior_to_canal, + ): + ''' + Combine extra labels with adjacent vertebrae labels. + This is useful for combining remaining of general vertebrae labels that introduce for region based training but not used in the final segmentation. + ''' + if num_labels == 0 or len(extra_labels) == 0: + return mask_labeled, num_labels, sorted_labels, sorted_z_indices + + seg_data = np.asanyarray(seg.dataobj).round().astype(np.uint8) + + mask_extra = np.isin(seg_data, extra_labels) + + # Loop over vertebral labels (from inferior because the transverse process make it steal from above) + for i in range(num_labels - 1, -1, -1): + # Mkae mask for the current vertebrae with filling the holes and dilating it + mask = _fill(mask_labeled == sorted_labels[i]) + mask = ndi.binary_dilation(mask, ndi.iterate_structure(ndi.generate_binary_structure(3, 1), 1)) + + # Add the intersection of the mask with the extra labels to the current verebrae + mask_labeled[mask_extra * mask] = sorted_labels[i] + + # Sort the labels by their z-index (reversed to go from superior to inferior) + sorted_z_indices, sorted_labels = _sort_labels_si( + mask_labeled, sorted_labels, canal_centerline_indices, mask_aterior_to_canal + ) + + return mask_labeled, num_labels, list(sorted_labels), list(sorted_z_indices) + +def _get_landmark_output_labels( + seg, + loc, + mask_labeled, + sorted_labels, + selected_landmarks, + landmark_labels, + landmark_output_labels, + loc_labels, + default_superior_output, + ): + ''' + Get dict mapping labels from sorted_labels to the output labels based on the landmarks in the segmentation or localizer. + ''' + seg_data = np.asanyarray(seg.dataobj).round().astype(np.uint8) + + loc_data = loc and np.asanyarray(loc.dataobj).round().astype(np.uint8) + + map_landmark_labels = dict(zip(landmark_labels, landmark_output_labels)) + + # If localizer is provided, transform it to the segmentation space + if loc_data is not None: + loc_data = tio.Resample( + tio.ScalarImage(tensor=seg_data[None, ...], affine=seg.affine) + )( + tio.LabelMap(tensor=loc_data[None, ...], affine=loc.affine) + ).data.numpy()[0, ...].astype(np.uint8) + + # Init dict to store the output labels for the landmarks + map_landmark_outputs = {} + + # First we try to look for the landmarks in the segmentation + for l in selected_landmarks: + ############################################################################################################ + # TODO Remove this reake when we trust all the landmarks to get all landmarks instead of the first 2 + if len(map_landmark_outputs) > 0 and selected_landmarks.index(l) > 1: + break + ############################################################################################################ + if l in map_landmark_labels and l in seg_data: + map_landmark_outputs[np.argmax(np.bincount(mask_labeled[seg_data == l].flat))] = map_landmark_labels[l] + + # If no init label found, set it from the localizer + if len(map_landmark_outputs) == 0 and loc_data is not None: + # Make mask for the intersection of the localizer labels and the labels in the segmentation + mask = np.isin(loc_data, loc_labels) * np.isin(mask_labeled, sorted_labels) + mask_labeled_masked = mask * mask_labeled + loc_data_masked = mask * loc_data + + # First we try to look for the landmarks in the localizer + # TODO Edge case if map_output_dict used for discs, but it is not used in the current implementation + for output_label in np.array(landmark_output_labels)[np.isin(landmark_output_labels, loc_data_masked)].tolist(): + # Map the label with the most voxels in the localizer landmark to the output label + map_landmark_outputs[np.argmax(np.bincount(mask_labeled_masked[loc_data_masked == output_label].flat))] = output_label + + if len(map_landmark_outputs) == 0: # Get the first label from sorted_labels that is in the localizer specified labels - mask_labeled_masked = mask * mask_labeled first_sorted_labels_in_loc = next(np.array(sorted_labels)[np.isin(sorted_labels, mask_labeled_masked)].flat, 0) if first_sorted_labels_in_loc > 0: - # Get the target label for first_sorted_labels_in_loc - the label from the localizer that has the most voxels in it - loc_data_masked = mask * loc_data - target = np.argmax(np.bincount(loc_data_masked[mask_labeled_masked == first_sorted_labels_in_loc].flat)) - # If target in map_output_dict reverse it from the reversed map - # TODO Edge case if multiple keys have the same value, not used in the current implementation - target = {v: k for k, v in map_output_dict.items()}.get(target, target) - first_label = target - step * sorted_labels.index(first_sorted_labels_in_loc) - - # If no init label found, set the default superior label - if first_label == 0 and default_sup > 0: - first_label = default_sup - - # If no init label found, print error - if first_label == 0: - raise ValueError(f"Some initiation label must be in the segmentation (init: {list(init.keys())})") - - # Combine extra labels with adjacent vertebrae labels - if len(extra_labels) > 0: - mask_extra = np.isin(seg_data, extra_labels) - - # Loop over vertebral labels (from inferior because the transverse process make it steal from above) - for i in range(num_labels - 1, -1, -1): - # Mkae mask for the current vertebrae with filling the holes and dilating it - mask = fill(mask_labeled == sorted_labels[i]) - mask = ndi.binary_dilation(mask, ndi.iterate_structure(ndi.generate_binary_structure(3, 1), 1)) - - # Add the intersection of the mask with the extra labels to the current verebrae - mask_labeled[mask_extra * mask] = sorted_labels[i] - - # Set the output value for the current label - for i in range(num_labels): - output_seg_data[mask_labeled == sorted_labels[i]] = first_label + step * i - - # Use the map to map labels from the iteative algorithm output, to the final output - # This is useful to map the vertebrae label from the iteative algorithm output to the special sacrum label - for orig, new in map_output_dict.items(): - if int(orig) in output_seg_data: - output_seg_data[output_seg_data == int(new)] = 0 - output_seg_data[output_seg_data == int(orig)] = int(new) - - # Use the map to map input labels to the final output - # This is useful to map the input sacrum, canal and spinal cord labels to the output labels - for orig, new in map_input_dict.items(): - if int(orig) in seg_data: - output_seg_data[output_seg_data == int(new)] = 0 - mask = seg_data == int(orig) - - # Map also all labels that are currently in the mask - # This is useful for example if we addedd from extra_labels to the sacrum and we want them to map together with the sacrum - mask_labes = [_ for _ in np.unique(output_seg_data[mask]) if _ != 0] - if len(mask_labes) > 0: - mask |= np.isin(output_seg_data, mask_labes) - output_seg_data[mask] = int(new) + # Get the output label for first_sorted_labels_in_loc, the label from the localizer that has the most voxels in it + map_landmark_outputs[first_sorted_labels_in_loc] = np.argmax(np.bincount(loc_data_masked[mask_labeled_masked == first_sorted_labels_in_loc].flat)) - output_seg = nib.Nifti1Image(output_seg_data, seg.affine, seg.header) + # If no init label found, set the default superior label + if len(map_landmark_outputs) == 0 and default_superior_output > 0: + map_landmark_outputs[sorted_labels[0]] = default_superior_output - return output_seg + # If no init label found, print error + if len(map_landmark_outputs) == 0: + if loc_data is not None: + raise ValueError( + f"At least one of the landmarks must be in the segmentation or localizer (landmarks: {selected_landmarks}. " + f"Check {loc_labels}), make sure the localizer is in the same space as the segmentation" + ) + raise ValueError(f"At least one of the landmarks must be in the segmentation or localizer (landmarks: {selected_landmarks})") + + return map_landmark_outputs -def fill(mask): +def _fill(mask): ''' Fill holes in a binary mask @@ -732,25 +1120,20 @@ def fill(mask): np.ndarray Binary mask with holes filled ''' + # Get array of indices for x, y, and z axes + indices = np.indices(mask.shape) - # Create an array of x indices with the same shape as the mask - x_indices = np.broadcast_to(np.arange(mask.shape[0])[..., np.newaxis, np.newaxis], mask.shape) - # Create an array of y indices with the same shape as the mask - y_indices = np.broadcast_to(np.arange(mask.shape[1])[..., np.newaxis], mask.shape) - # Create an array of z indices with the same shape as the mask - z_indices = np.broadcast_to(np.arange(mask.shape[2]), mask.shape) - - mask_min_x = np.min(np.where(mask, x_indices, np.inf), axis=0)[np.newaxis, ...] - mask_max_x = np.max(np.where(mask, x_indices, -np.inf), axis=0)[np.newaxis, ...] - mask_min_y = np.min(np.where(mask, y_indices, np.inf), axis=1)[:, np.newaxis, :] - mask_max_y = np.max(np.where(mask, y_indices, -np.inf), axis=1)[:, np.newaxis, :] - mask_min_z = np.min(np.where(mask, z_indices, np.inf), axis=2)[:, :, np.newaxis] - mask_max_z = np.max(np.where(mask, z_indices, -np.inf), axis=2)[:, :, np.newaxis] + mask_min_x = np.min(np.where(mask, indices[0], np.inf), axis=0)[np.newaxis, ...] + mask_max_x = np.max(np.where(mask, indices[0], -np.inf), axis=0)[np.newaxis, ...] + mask_min_y = np.min(np.where(mask, indices[1], np.inf), axis=1)[:, np.newaxis, :] + mask_max_y = np.max(np.where(mask, indices[1], -np.inf), axis=1)[:, np.newaxis, :] + mask_min_z = np.min(np.where(mask, indices[2], np.inf), axis=2)[:, :, np.newaxis] + mask_max_z = np.max(np.where(mask, indices[2], -np.inf), axis=2)[:, :, np.newaxis] return \ - ((mask_min_x <= x_indices) & (x_indices <= mask_max_x)) | \ - ((mask_min_y <= y_indices) & (y_indices <= mask_max_y)) | \ - ((mask_min_z <= z_indices) & (z_indices <= mask_max_z)) + ((mask_min_x <= indices[0]) & (indices[0] <= mask_max_x)) | \ + ((mask_min_y <= indices[1]) & (indices[1] <= mask_max_y)) | \ + ((mask_min_z <= indices[2]) & (indices[2] <= mask_max_z)) if __name__ == '__main__': main() \ No newline at end of file diff --git a/totalspineseg/utils/preview_jpg.py b/totalspineseg/utils/preview_jpg.py index cf82199..8f66e98 100644 --- a/totalspineseg/utils/preview_jpg.py +++ b/totalspineseg/utils/preview_jpg.py @@ -1,5 +1,5 @@ -import sys, argparse, textwrap -from PIL import Image +import sys, argparse, textwrap, json +from PIL import Image, ImageDraw, ImageFont import multiprocessing as mp from functools import partial from pathlib import Path @@ -78,13 +78,31 @@ def main(): '--sliceloc', '-l', type=float, default=0.5, help='Slice location within the specified orientation (0-1). Default is 0.5 (middle slice).' ) + parser.add_argument( + '--label-text-right', '-ltr', type=str, nargs='+', default=[], + help=' '.join(''' + JSON file or mapping from label integers to text labels to be placed on the right side. + The format should be input_label:text_label without any spaces. + For example, you can use a JSON file like right_labels.json containing {"1": "SC", "2": "Canal"}, + or provide mappings directly like 1:SC 2:Canal + '''.split()), + ) + parser.add_argument( + '--label-text-left', '-ltl', type=str, nargs='+', default=[], + help=' '.join(''' + JSON file or mapping from label integers to text labels to be placed on the left side. + The format should be input_label:text_label without any spaces. + For example, you can use a JSON file like left_labels.json containing {"1": "SC", "2": "Canal"}, + or provide mappings directly like 1:SC 2:Canal + '''.split()), + ) parser.add_argument( '--override', '-r', action="store_true", default=False, help='Override existing output files, defaults to false (Do not override).' ) parser.add_argument( '--max-workers', '-w', type=int, default=mp.cpu_count(), - help='Max worker to run in parallel proccess, defaults to multiprocessing.cpu_count().' + help='Max worker to run in parallel process, defaults to multiprocessing.cpu_count().' ) parser.add_argument( '--quiet', '-q', action="store_true", default=False, @@ -106,10 +124,18 @@ def main(): seg_suffix = args.seg_suffix orient = args.orient sliceloc = args.sliceloc + label_text_right_list = args.label_text_right + label_text_left_list = args.label_text_left override = args.override max_workers = args.max_workers quiet = args.quiet + # Load label_texts_right into a dict + label_texts_right = load_label_texts(label_text_right_list, 'label-text-right') + + # Load label_texts_left into a dict + label_texts_left = load_label_texts(label_text_left_list, 'label-text-left') + # Print the argument values if not quiet if not quiet: print(textwrap.dedent(f''' @@ -125,6 +151,8 @@ def main(): seg_suffix = "{seg_suffix}" orient = "{orient}" sliceloc = {sliceloc} + label_texts_right = {label_texts_right} + label_texts_left = {label_texts_left} override = {override} max_workers = {max_workers} quiet = {quiet} @@ -142,11 +170,27 @@ def main(): seg_suffix=seg_suffix, orient=orient, sliceloc=sliceloc, + label_texts_right=label_texts_right, + label_texts_left=label_texts_left, override=override, max_workers=max_workers, quiet=quiet, ) +def load_label_texts(label_text_list, param_name): + if len(label_text_list) == 1 and label_text_list[0][-5:] == '.json': + # Load label mappings from JSON file + with open(label_text_list[0], 'r', encoding='utf-8') as map_file: + label_texts = json.load(map_file) + # Ensure keys are ints + label_texts = {int(k): v for k, v in label_texts.items()} + else: + try: + label_texts = {int(l_in): l_out for l_in, l_out in map(lambda x: x.split(':'), label_text_list)} + except: + raise ValueError(f"Input param --{param_name} is not in the right structure. Make sure it is in the right format, e.g., 1:SC 2:Canal") + return label_texts + def preview_jpg_mp( images_path, output_path, @@ -159,6 +203,8 @@ def preview_jpg_mp( seg_suffix='', orient='sag', sliceloc=0.5, + label_texts_right={}, + label_texts_left={}, override=False, max_workers=mp.cpu_count(), quiet=False, @@ -189,6 +235,8 @@ def preview_jpg_mp( _preview_jpg, orient=orient, sliceloc=sliceloc, + label_texts_right=label_texts_right, + label_texts_left=label_texts_left, override=override, ), image_path_list, @@ -205,6 +253,8 @@ def _preview_jpg( seg_path=None, orient='sag', sliceloc=0.5, + label_texts_right={}, + label_texts_left={}, override=False, ): ''' @@ -240,8 +290,9 @@ def _preview_jpg( # Repeat the grayscale slice 3 times to create a color image slice_img = np.repeat(slice_img[:, :, np.newaxis], 3, axis=2).astype(np.uint8) - # Create a blank color image with the same dimensions as the input image - output_data = np.zeros_like(slice_img, dtype=np.uint8) + # Flip and rotate the image slice + slice_img = np.flipud(slice_img) + slice_img = np.rot90(slice_img, k=1) if seg_path and seg_path.is_file(): try: @@ -257,6 +308,10 @@ def _preview_jpg( slice_seg = seg_data.take(slice_index, axis=axis) + # Flip and rotate the segmentation slice + slice_seg = np.flipud(slice_seg) + slice_seg = np.rot90(slice_seg, k=1) + # Generate consistent random colors for each label unique_labels = np.unique(slice_seg).astype(int) colors = {} @@ -264,21 +319,135 @@ def _preview_jpg( rs = np.random.RandomState(np.random.MT19937(np.random.SeedSequence(label * 10))) colors[label] = rs.randint(0, 255, 3) + # Create a blank color image with the same dimensions as the input image + output_data = np.zeros_like(slice_img, dtype=np.uint8) + # Apply the segmentation mask to the image and assign colors for label, color in colors.items(): if label != 0: # Ignore the background (label 0) mask = slice_seg == label output_data[mask] = color - output_data = np.where(output_data > 0, output_data, slice_img) - - # Rotate the image 90 degrees counter-clockwise and flip it vertically - output_data = np.flipud(output_data) - output_data = np.rot90(output_data, k=1) + output_data = np.where(output_data > 0, output_data, slice_img) + else: + output_data = slice_img + unique_labels = [] + colors = {} - # Create an Image object from the output Image object as a JPG file + # Create an Image object from the output data output_image = Image.fromarray(output_data, mode="RGB") + # Draw text labels if label_texts are provided + if (label_texts_right or label_texts_left) and seg_path and seg_path.is_file(): + draw = ImageDraw.Draw(output_image) + # Use a bold TrueType font for better sharpness and boldness + try: + font = ImageFont.truetype("DejaVuSans-Bold.ttf", size=15) + except IOError: + font = ImageFont.load_default() + width, height = output_image.size + used_positions = [] + for label in unique_labels: + if label != 0: + text = None + side = None + if label in label_texts_right: + text = label_texts_right[label] + side = 'right' + elif label in label_texts_left: + text = label_texts_left[label] + side = 'left' + + if text and side: + mask = slice_seg == label + positions = np.argwhere(mask) + if positions.size > 0: + # Get the bounding box of the label + ys, xs = positions[:, 0], positions[:, 1] + x_min, x_max = xs.min(), xs.max() + y_min, y_max = ys.min(), ys.max() + # Start from just outside the label + if side == 'right': + x_new = x_max + 1 + else: + x_new = x_min - 1 + y_new = int((y_min + y_max) / 2) + + # Ensure starting positions are within image bounds + x_new = min(max(0, x_new), width - 1) + y_new = min(max(0, y_new), height - 1) + + # Search for a position outside the segmentation + found_position = False + if side == 'right': + for dx in range(0, width - x_new): + if slice_seg[y_new, min(x_new + dx, width - 1)] == 0: + x_new = x_new + dx + found_position = True + break + else: + for dx in range(0, x_new + 1): + if slice_seg[y_new, max(x_new - dx, 0)] == 0: + x_new = x_new - dx + found_position = True + break + if not found_position: + # Try moving down + for dy in range(1, height - y_new): + if slice_seg[min(y_new + dy, height - 1), x_new] == 0: + y_new = y_new + dy + found_position = True + break + if not found_position: + # Try moving up + for dy in range(1, y_new + 1): + if slice_seg[max(y_new - dy, 0), x_new] == 0: + y_new = y_new - dy + found_position = True + break + + if found_position: + text_color = tuple(colors[label].tolist()) + + # Avoid overlapping labels + if (x_new, y_new) not in used_positions: + # Get text size + try: + # For Pillow >= 8.0.0 + bbox = font.getbbox(text) + text_width = bbox[2] - bbox[0] + text_height = bbox[3] - bbox[1] + except AttributeError: + try: + # For older versions + text_width, text_height = font.getsize(text) + except AttributeError: + # As a last resort, approximate text size + text_width, text_height = draw.textsize(text, font=font) + + # Adjust x_new for left side labels + if side == 'left': + x_new = x_new - text_width + + # Ensure text is within image bounds + x_new = min(max(0, x_new), width - text_width) + y_new = min(max(0, y_new), height - text_height) + + # Draw outline by drawing text multiple times around the perimeter + outline_color = (255, 255, 255) # White outline + outline_offsets = [(-1, -1), (-1, 0), (-1, 1), (0, -1), (0, 1), (1, -1), (1, 0), (1, 1)] + for dx, dy in outline_offsets: + draw.text((x_new + dx, y_new + dy), text, font=font, fill=outline_color) + + # Draw the text over the outline multiple times to make it thicker + for _ in range(3): + draw.text((x_new, y_new), text, fill=text_color, font=font) + + used_positions.append((x_new, y_new)) + else: + # No suitable position found, skip drawing the label + pass + # Make sure output directory exists and save the image output_path.parent.mkdir(parents=True, exist_ok=True) output_image.save(output_path)