From 253ce5a59da858f5281155aced8ecb29fbf4d8be Mon Sep 17 00:00:00 2001 From: Yehuda Warszawer <36595323+yw7@users.noreply.github.com> Date: Wed, 11 Sep 2024 19:50:24 +0300 Subject: [PATCH 01/26] Refactor iterative labeling logic init from discs only Reorganized the spine labeling logic by separating the functionality into distinct functions for better readability and maintainability. Removed `init_vertebrae`, `step_diff_label`, `step_diff_disc`, and related parameters to streamline the process. Added new parameters `output_c2c3` and `output_c2` for setting specific output labels. This refactor facilitates easier adjustments and enhancements in the future, improving code usability. Resolves issue #50 --- totalspineseg/inference.py | 11 +- totalspineseg/utils/iterative_label.py | 572 +++++++++++++------------ 2 files changed, 313 insertions(+), 270 deletions(-) diff --git a/totalspineseg/inference.py b/totalspineseg/inference.py index 86af7e5..c6ee540 100644 --- a/totalspineseg/inference.py +++ b/totalspineseg/inference.py @@ -547,11 +547,10 @@ def main(): vertebrae_labels=[9, 10, 11, 12, 13, 14], vertebrae_extra_labels=[8], init_disc={4:224, 7:202, 5:219, 6:207}, - init_vertebrae={11:40, 14:17, 12:34, 13:23}, - step_diff_label=True, - step_diff_disc=True, output_disc_step=-1, output_vertebrae_step=-1, + output_c2c3=224, + output_c2=40, map_output_dict={17:92}, map_input_dict={14:92, 15:201, 16:201, 17:200}, override=True, @@ -567,13 +566,11 @@ def main(): vertebrae_labels=[9, 10, 11, 12, 13, 14], vertebrae_extra_labels=[8], init_disc={4:224, 7:202}, - init_vertebrae={11:40, 14:17}, loc_disc_labels=list(range(202, 225)), - loc_vertebrae_labels=list(range(18, 42)) + [92], - step_diff_label=True, - step_diff_disc=True, output_disc_step=-1, output_vertebrae_step=-1, + output_c2c3=224, + output_c2=40, map_output_dict={17:92}, map_input_dict={14:92, 15:201, 16:201, 17:200}, override=True, diff --git a/totalspineseg/utils/iterative_label.py b/totalspineseg/utils/iterative_label.py index 7f98196..1997088 100644 --- a/totalspineseg/utils/iterative_label.py +++ b/totalspineseg/utils/iterative_label.py @@ -20,10 +20,10 @@ def main(): '''.split()), epilog=textwrap.dedent(''' Examples: - iterative_label -s labels_init -o labels --disc-labels 1-7 --vertebrae-labels 9-14 --vertebrae-extra-labels 8 --init-disc 4:224 7:202 5:219 6:207 --init-vertebrae 11:40 14:17 12:34 13:23 --step-diff-label --step-diff-disc --output-disc-step -1 --output-vertebrae-step -1 --map-output 17:92 --map-input 14:92 16:201 17:200 -r - iterative_label -s labels_init -o labels -l localizers --disc-labels 1-7 --vertebrae-labels 9-14 --vertebrae-extra-labels 8 --init-disc 4:224 7:202 --init-vertebrae 11:40 14:17 --step-diff-label --step-diff-disc --output-disc-step -1 --output-vertebrae-step -1 --loc-disc-labels 202-224 --loc-vertebrae-labels 18-41 92 --map-output 17:92 --map-input 14:92 16:201 17:200 -r + iterative_label -s labels_init -o labels --disc-labels 1-7 --vertebrae-labels 9-14 --vertebrae-extra-labels 8 --init-disc 4:224 7:202 5:219 6:207 --output-disc-step -1 --output-vertebrae-step -1 --map-output 17:92 --map-input 14:92 16:201 17:200 --output-c2c3 224 --output-c2 40 -r + iterative_label -s labels_init -o labels -l localizers --disc-labels 1-7 --vertebrae-labels 9-14 --vertebrae-extra-labels 8 --init-disc 4:224 7:202 --output-disc-step -1 --output-vertebrae-step -1 --loc-disc-labels 202-224 --map-output 17:92 --map-input 14:92 16:201 17:200 --output-c2c3 224 --output-c2 40 -r For BIDS: - iterative_label -s derivatives/labels -o derivatives/labels --seg-suffix "_seg" --output-seg-suffix "_seg_seq" -d "sub-" -u "anat" --disc-labels 1 2 3 4 5 6 7 --vertebrae-labels 9 10 11 12 13 14 --vertebrae-extra-labels 8 --init-disc 4:224 7:202 5:219 6:207 --init-vertebrae 11:40 14:17 12:34 13:23 --step-diff-label --step-diff-disc --output-disc-step -1 --output-vertebrae-step -1 --map-output 17:92 --map-input 14:92 16:201 17:200 -r + iterative_label -s derivatives/labels -o derivatives/labels --seg-suffix "_seg" --output-seg-suffix "_seg_seq" -d "sub-" -u "anat" --disc-labels 1 2 3 4 5 6 7 --vertebrae-labels 9 10 11 12 13 14 --vertebrae-extra-labels 8 --init-disc 4:224 7:202 5:219 6:207 --output-disc-step -1 --output-vertebrae-step -1 --map-output 17:92 --map-input 14:92 16:201 17:200 --output-c2c3 224 --output-c2 40 -r '''), formatter_class=argparse.RawTextHelpFormatter ) @@ -80,6 +80,10 @@ def main(): '--init-disc', type=lambda x:map(int, x.split(':')), nargs='+', default=[], help='Init labels list for disc ordered by priority (input_label:output_label !!without space!!). for example 4:224 5:219 6:202' ) + parser.add_argument( + '--output-c2c3', type=int, default=0, + help='The output label for C2C3, used to calculate the first vertebrae label, defaults to 0.' + ) parser.add_argument( '--output-disc-step', type=int, default=1, help='The step to take between disc labels in the output, defaults to 1.' @@ -97,17 +101,13 @@ def main(): help='Extra vertebrae labels to add to add to adjacent vertebrae labels.' ) parser.add_argument( - '--init-vertebrae', type=lambda x:map(int, x.split(':')), nargs='+', default=[], - help='Init labels list for vertebrae ordered by priority (input_label:output_label !!without space!!). for example 10:41 11:34 12:18' + '--output-c2', type=int, default=0, + help='The output label for C2, used to calculate the first vertebrae label, defaults to 0.' ) parser.add_argument( '--output-vertebrae-step', type=int, default=1, help='The step to take between vertebrae labels in the output, defaults to 1.' ) - parser.add_argument( - '--loc-vertebrae-labels', type=lambda x:list(range(int(x.split('-')[0]), int(x.split('-')[-1]) + 1)), nargs='+', default=[], - help='The disc labels in the localizer used for detecting first vertebrae.' - ) parser.add_argument( '--map-input', type=str, nargs='+', default=[], help=' '.join(f''' @@ -128,29 +128,10 @@ def main(): '--dilation-size', type=int, default=1, help='Number of voxels to dilate before finding connected voxels to label, defaults to 1 (No dilation).' ) - parser.add_argument( - '--step-diff-label', action="store_true", default=False, - help=' '.join(f''' - Make step only for different labels. When looping on the labels on the z axis, it will give a new label to the next label only - if it is different from the previous label. This is useful if there are labels for odd and even vertebrae, so the next label will - be for even vertebrae only if the previous label was odd. If it is still odd, it should give the same label. - '''.split()) - ) - parser.add_argument( - '--step-diff-disc', action="store_true", default=False, - help=' '.join(f''' - Make step only for different discs. When looping on the labels on the z axis, it will give a new label to the next label only - if there is a disc between them. This exclude the first and last vertebrae since it can be C1 or only contain the spinous process. - '''.split()) - ) parser.add_argument( '--default-superior-disc', type=int, default=0, help='Default superior disc label if no init label found, defaults to 0 (Raise error if init label not found).' ) - parser.add_argument( - '--default-superior-vertebrae', type=int, default=0, - help='Default superior vertebrae label if no init label found, defaults to 0 (Raise error if init label not found).' - ) parser.add_argument( '--override', '-r', action="store_true", default=False, help='Override existing output files, defaults to false (Do not override).' @@ -179,20 +160,17 @@ def main(): loc_suffix = args.loc_suffix disc_labels = [_ for __ in args.disc_labels for _ in (__ if isinstance(__, list) else [__])] init_disc = dict(args.init_disc) + output_c2c3 = args.output_c2c3 output_disc_step = args.output_disc_step loc_disc_labels = [_ for __ in args.loc_disc_labels for _ in (__ if isinstance(__, list) else [__])] vertebrae_labels = [_ for __ in args.vertebrae_labels for _ in (__ if isinstance(__, list) else [__])] vertebrae_extra_labels = [_ for __ in args.vertebrae_extra_labels for _ in (__ if isinstance(__, list) else [__])] - init_vertebrae = dict(args.init_vertebrae) + output_c2 = args.output_c2 output_vertebrae_step = args.output_vertebrae_step - loc_vertebrae_labels = [_ for __ in args.loc_vertebrae_labels for _ in (__ if isinstance(__, list) else [__])] map_input_list = args.map_input map_output_list = args.map_output dilation_size = args.dilation_size - step_diff_label = args.step_diff_label - step_diff_disc = args.step_diff_disc default_superior_disc = args.default_superior_disc - default_superior_vertebrae = args.default_superior_vertebrae override = args.override max_workers = args.max_workers quiet = args.quiet @@ -212,20 +190,17 @@ def main(): loc_suffix = "{loc_suffix}" disc_labels = {disc_labels} init_disc = {init_disc} + output_c2c3 = {output_c2c3} output_disc_step = {output_disc_step} loc_disc_labels = {loc_disc_labels} vertebrae_labels = {vertebrae_labels} vertebrae_extra_labels = {vertebrae_extra_labels} - init_vertebrae = {init_vertebrae} + output_c2 = {output_c2} output_vertebrae_step = {output_vertebrae_step} - loc_vertebrae_labels = {loc_vertebrae_labels} map_input = {map_input_list} map_output = {map_output_list} dilation_size = {dilation_size} - step_diff_label = {step_diff_label} - step_diff_disc = {step_diff_disc} default_superior_disc = {default_superior_disc} - default_superior_vertebrae = {default_superior_vertebrae} override = {override} max_workers = {max_workers} quiet = {quiet} @@ -254,20 +229,17 @@ def main(): loc_suffix=loc_suffix, disc_labels=disc_labels, init_disc=init_disc, + output_c2c3=output_c2c3, output_disc_step=output_disc_step, loc_disc_labels=loc_disc_labels, vertebrae_labels=vertebrae_labels, vertebrae_extra_labels=vertebrae_extra_labels, - init_vertebrae=init_vertebrae, + output_c2=output_c2, output_vertebrae_step=output_vertebrae_step, - loc_vertebrae_labels=loc_vertebrae_labels, map_input_dict=map_input_dict, map_output_dict=map_output_dict, dilation_size=dilation_size, - step_diff_label=step_diff_label, - step_diff_disc=step_diff_disc, default_superior_disc=default_superior_disc, - default_superior_vertebrae=default_superior_vertebrae, override=override, max_workers=max_workers, quiet=quiet, @@ -285,20 +257,17 @@ def iterative_label_mp( loc_suffix='', disc_labels=[], init_disc={}, + output_c2c3=0, output_disc_step=1, loc_disc_labels=[], vertebrae_labels=[], vertebrae_extra_labels=[], - init_vertebrae={}, + output_c2=0, output_vertebrae_step=1, - loc_vertebrae_labels=[], map_input_dict={}, map_output_dict={}, dilation_size=1, - step_diff_label=False, - step_diff_disc=False, default_superior_disc=0, - default_superior_vertebrae=0, override=False, max_workers=mp.cpu_count(), quiet=False, @@ -326,21 +295,18 @@ def iterative_label_mp( partial( _iterative_label, disc_labels=disc_labels, + output_c2c3=output_c2c3, output_disc_step=output_disc_step, loc_disc_labels=loc_disc_labels, init_disc=init_disc, vertebrae_labels=vertebrae_labels, vertebrae_extra_labels=vertebrae_extra_labels, - init_vertebrae=init_vertebrae, + output_c2=output_c2, output_vertebrae_step=output_vertebrae_step, - loc_vertebrae_labels=loc_vertebrae_labels, map_input_dict=map_input_dict, map_output_dict=map_output_dict, dilation_size=dilation_size, - step_diff_label=step_diff_label, - step_diff_disc=step_diff_disc, default_superior_disc=default_superior_disc, - default_superior_vertebrae=default_superior_vertebrae, override=override, ), seg_path_list, @@ -357,20 +323,17 @@ def _iterative_label( loc_path=None, disc_labels=[], init_disc={}, + output_c2c3=0, output_disc_step=1, loc_disc_labels=[], vertebrae_labels=[], vertebrae_extra_labels=[], - init_vertebrae={}, + output_c2=0, output_vertebrae_step=1, - loc_vertebrae_labels=[], map_input_dict={}, map_output_dict={}, dilation_size=1, - step_diff_label=False, - step_diff_disc=False, default_superior_disc=0, - default_superior_vertebrae=0, override=False, ): ''' @@ -394,20 +357,17 @@ def _iterative_label( loc, disc_labels=disc_labels, init_disc=init_disc, + output_c2c3=output_c2c3, output_disc_step=output_disc_step, loc_disc_labels=loc_disc_labels, vertebrae_labels=vertebrae_labels, vertebrae_extra_labels=vertebrae_extra_labels, - init_vertebrae=init_vertebrae, + output_c2=output_c2, output_vertebrae_step=output_vertebrae_step, - loc_vertebrae_labels=loc_vertebrae_labels, map_input_dict=map_input_dict, map_output_dict=map_output_dict, dilation_size=dilation_size, - step_diff_label=step_diff_label, - step_diff_disc=step_diff_disc, default_superior_disc=default_superior_disc, - default_superior_vertebrae=default_superior_vertebrae, ) except ValueError as e: output_seg_path.is_file() and output_seg_path.unlink() @@ -432,20 +392,17 @@ def iterative_label( loc=None, disc_labels=[], init_disc={}, + output_c2c3=0, output_disc_step=1, loc_disc_labels=[], vertebrae_labels=[], vertebrae_extra_labels=[], - init_vertebrae={}, + output_c2=0, output_vertebrae_step=1, - loc_vertebrae_labels=[], map_input_dict={}, map_output_dict={}, dilation_size=1, - step_diff_label=False, - step_diff_disc=False, default_superior_disc=0, - default_superior_vertebrae=0, ): ''' Label Vertebrae, IVDs, Spinal Cord and canal from init segmentation. @@ -467,6 +424,8 @@ def iterative_label( The disc labels init_disc : dict Init labels list for disc ordered by priority (input_label:output_label) + output_c2c3 : int + The output label for C2C3, used to calculate the first vertebrae label output_disc_step : int The step to take between disc labels in the output loc_disc_labels : list @@ -475,26 +434,18 @@ def iterative_label( The vertebrae labels vertebrae_extra_labels : list Extra vertebrae labels to add to add to adjacent vertebrae labels - init_vertebrae : dict - Init labels list for vertebrae ordered by priority (input_label:output_label) + output_c2 : int + The output label for C2, used to calculate the first vertebrae label output_vertebrae_step : int The step to take between vertebrae labels in the output - loc_vertebrae_labels : list - Localizer labels to use for detecting first vertebrae map_input_dict : dict A dict mapping labels from input into the output segmentation map_output_dict : dict A dict mapping labels from the output of the iterative labeling algorithm into different labels in the output segmentation dilation_size : int Number of voxels to dilate before finding connected voxels to label - step_diff_label : bool - Make step only for different labels - step_diff_disc : bool - Make step only for different discs default_superior_disc : int Default superior disc label if no init label found - default_superior_vertebrae : int - Default superior vertebrae label if no init label found Returns ------- @@ -505,193 +456,61 @@ def iterative_label( output_seg_data = np.zeros_like(seg_data) - loc_data = loc and np.asanyarray(loc.dataobj).round().astype(np.uint8) - - # If localizer is provided, transform it to the segmentation space - if loc_data is not None: - loc_data = tio.Resample( - tio.ScalarImage(tensor=seg_data[None, ...], affine=seg.affine) - )( - tio.LabelMap(tensor=loc_data[None, ...], affine=loc.affine) - ).data.numpy()[0, ...].astype(np.uint8) - - binary_dilation_structure = ndi.iterate_structure(ndi.generate_binary_structure(3, 1), dilation_size) - - # Arrays to store the z indexes of the discs sorted superior to inferior - disc_sorted_z_indexes = [] - - # We run the same iterative algorithm for discs and vertebrae - for labels, extra_labels, step, init, default_sup, loc_labels, is_vert in ( - (disc_labels, [], output_disc_step, init_disc, default_superior_disc, loc_disc_labels, False), - (vertebrae_labels, vertebrae_extra_labels, output_vertebrae_step, init_vertebrae, default_superior_vertebrae, loc_vertebrae_labels, True)): - - # Skip if no labels are provided - if len(labels) == 0: - continue - - if is_vert: - _labels = [[_] for _ in labels] - else: - # For discs, combine all labels before label continue voxels since the discs not touching each other - _labels = [labels] - - # Init labeled segmentation - mask_labeled, num_labels = np.zeros_like(seg_data, dtype=np.uint32), 0 - - # For each label, find connected voxels and label them into separate labels - for l in _labels: - mask = np.isin(seg_data, l) - - # Dilate the mask to combine small disconnected regions - mask_dilated = ndi.binary_dilation(mask, binary_dilation_structure) - - # Label the connected voxels in the dilated mask into separate labels - tmp_mask_labeled, tmp_num_labels = ndi.label(mask_dilated.astype(np.uint32), np.ones((3, 3, 3))) - - # Undo dilation - tmp_mask_labeled *= mask - - # Add current labels to the labeled segmentation - if tmp_num_labels > 0: - mask_labeled[tmp_mask_labeled != 0] = tmp_mask_labeled[tmp_mask_labeled != 0] + num_labels - num_labels += tmp_num_labels - - # If no label found, raise error - if num_labels == 0: - raise ValueError(f"Some label must be in the segmentation (labels: {labels})") + # Get sorted connected components superio-inferior (SI) for the disc and vertebrae labels + disc_mask_labeled, disc_num_labels, disc_sorted_labels, disc_sorted_z_indexes = _get_si_sorted_components( + seg, + disc_labels, + dilation_size, + ) - # Reduce size of mask_labeled - if mask_labeled.max() < np.iinfo(np.uint8).max: - mask_labeled = mask_labeled.astype(np.uint8) - elif mask_labeled.max() < np.iinfo(np.uint16).max: - mask_labeled = mask_labeled.astype(np.uint16) + # Get sorted connected components superio-inferior (SI) for the vertebrae labels + vert_mask_labeled, vert_num_labels, vert_sorted_labels, vert_sorted_z_indexes = _get_si_sorted_components( + seg, + vertebrae_labels, + dilation_size, + combine_labels=True, + ) - # Get the z index of the center of mass for each label - canonical_mask_labeled = np.asanyarray(nib.as_closest_canonical(nib.Nifti1Image(mask_labeled, seg.affine, seg.header)).dataobj).round().astype(mask_labeled.dtype) - mask_labeled_z_indexes = [_[-1] for _ in ndi.center_of_mass(canonical_mask_labeled != 0, canonical_mask_labeled, range(1, num_labels + 1))] + # Combine sequential vertebrae labels based on some conditions + vert_mask_labeled, vert_num_labels, vert_sorted_labels, vert_sorted_z_indexes = _merge_vertebrae_labels( + seg, + vertebrae_labels, + vert_mask_labeled, + vert_num_labels, + vert_sorted_labels, + vert_sorted_z_indexes, + disc_sorted_z_indexes, + vertebrae_extra_labels, + ) - # Sort the labels by their z-index (reversed to go from superior to inferior) - sorted_z_indexes, sorted_labels = zip(*sorted(zip(mask_labeled_z_indexes,range(1,num_labels+1)))[::-1]) - - # In this part we loop over the labels superior to inferior and combine sequential vertebrae labels based on some conditions - if is_vert: - # Combine sequential vertebrae labels if they have the same value in the original segmentation - # This is useful when part part of the vertebrae is not connected to the main part but have the same odd/even value - if step_diff_label and len(labels) - len(init) > 1: - new_sorted_labels = [] - - # Store the previous label and the original label of the previous label - prev_l, prev_orig_label = 0, 0 - - # Loop over the sorted labels - for l in sorted_labels: - # Get the original label of the current label - curr_orig_label = seg_data[mask_labeled == l].flat[0] - - # Combine the current label with the previous label if they have the same original label - if curr_orig_label == prev_orig_label: - # Combine the current label with the previous label - mask_labeled[mask_labeled == l] = prev_l - num_labels -= 1 - - else: - # Add the current label to the new sorted labels - new_sorted_labels.append(l) - prev_l, prev_orig_label = l, curr_orig_label - - # Get the z index of the center of mass for each label - canonical_mask_labeled = np.asanyarray(nib.as_closest_canonical(nib.Nifti1Image(mask_labeled, seg.affine, seg.header)).dataobj).round().astype(mask_labeled.dtype) - mask_labeled_z_indexes = [_[-1] for _ in ndi.center_of_mass(canonical_mask_labeled != 0, canonical_mask_labeled, new_sorted_labels)] - - # Sort the labels by their z-index (reversed to go from superior to inferior) - sorted_z_indexes, sorted_labels = zip(*sorted(zip(mask_labeled_z_indexes, new_sorted_labels))[::-1]) - - # Reduce size of mask_labeled - if mask_labeled.max() < np.iinfo(np.uint8).max: - mask_labeled = mask_labeled.astype(np.uint8) - elif mask_labeled.max() < np.iinfo(np.uint16).max: - mask_labeled = mask_labeled.astype(np.uint16) - - # Combine sequential vertebrae labels if there is no disc between them - if step_diff_disc and len(disc_sorted_z_indexes) > 0: - new_sorted_labels = [] - - # Store the previous label and the z index of the previous label - prev_l, prev_z = 0, 0 - - for l, z in zip(sorted_labels, sorted_z_indexes): - # Do not combine first and last vertebrae since it can be C1 or only contain the spinous process - if l not in sorted_labels[:2] and l != sorted_labels[-1] and prev_l > 0 and not any(z < _ < prev_z for _ in disc_sorted_z_indexes): - # Combine the current label with the previous label - mask_labeled[mask_labeled == l] = prev_l - num_labels -= 1 - - else: - # Add the current label to the new sorted labels - new_sorted_labels.append(l) - prev_l, prev_z = l, z - - sorted_labels = new_sorted_labels - - # Reduce size of mask_labeled - if mask_labeled.max() < np.iinfo(np.uint8).max: - mask_labeled = mask_labeled.astype(np.uint8) - elif mask_labeled.max() < np.iinfo(np.uint16).max: - mask_labeled = mask_labeled.astype(np.uint16) - - else: - # Save the z indexes of the discs - disc_sorted_z_indexes = sorted_z_indexes - - # Find the most superior label in the segmentation - first_label = 0 - for k, v in init.items(): - if k in seg_data: - first_label = v - step * sorted_labels.index(mask_labeled[seg_data == k].flat[0]) - break - - # If no init label found, set it from the localizer - if first_label == 0 and loc_data is not None: - # Make mask for the intersection of the localizer labels and the labels in the segmentation - mask = np.isin(loc_data, loc_labels) * np.isin(mask_labeled, sorted_labels) - - # Get the first label from sorted_labels that is in the localizer specified labels - mask_labeled_masked = mask * mask_labeled - first_sorted_labels_in_loc = next(np.array(sorted_labels)[np.isin(sorted_labels, mask_labeled_masked)].flat, 0) - - if first_sorted_labels_in_loc > 0: - # Get the target label for first_sorted_labels_in_loc - the label from the localizer that has the most voxels in it - loc_data_masked = mask * loc_data - target = np.argmax(np.bincount(loc_data_masked[mask_labeled_masked == first_sorted_labels_in_loc].flat)) - # If target in map_output_dict reverse it from the reversed map - # TODO Edge case if multiple keys have the same value, not used in the current implementation - target = {v: k for k, v in map_output_dict.items()}.get(target, target) - first_label = target - step * sorted_labels.index(first_sorted_labels_in_loc) - - # If no init label found, set the default superior label - if first_label == 0 and default_sup > 0: - first_label = default_sup - - # If no init label found, print error - if first_label == 0: - raise ValueError(f"Some initiation label must be in the segmentation (init: {list(init.keys())})") - - # Combine extra labels with adjacent vertebrae labels - if len(extra_labels) > 0: - mask_extra = np.isin(seg_data, extra_labels) - - # Loop over vertebral labels (from inferior because the transverse process make it steal from above) - for i in range(num_labels - 1, -1, -1): - # Mkae mask for the current vertebrae with filling the holes and dilating it - mask = fill(mask_labeled == sorted_labels[i]) - mask = ndi.binary_dilation(mask, ndi.iterate_structure(ndi.generate_binary_structure(3, 1), 1)) - - # Add the intersection of the mask with the extra labels to the current verebrae - mask_labeled[mask_extra * mask] = sorted_labels[i] - - # Set the output value for the current label - for i in range(num_labels): - output_seg_data[mask_labeled == sorted_labels[i]] = first_label + step * i + # Get the first disc label + disc_first_label = _get_first_label( + seg, + loc, + disc_mask_labeled, + disc_sorted_labels, + init_disc, + output_disc_step, + loc_disc_labels, + default_superior_disc, + map_output_dict, + ) + + # Sort the combined disc+vert labels by their z-index + sorted_labels = vert_sorted_labels + disc_sorted_labels + sorted_z_indexes = vert_sorted_z_indexes + disc_sorted_z_indexes + is_vert = [True] * len(vert_sorted_labels) + [False] * len(disc_sorted_labels) + + # Sort the labels by their z-index (reversed to go from superior to inferior) + sorted_z_indexes, sorted_labels, is_vert = zip(*sorted(zip(sorted_z_indexes, sorted_labels, is_vert))[::-1]) + + # Get index of first vert in is_vert + # To cover cases with C1 and C2 we have to adjust the first vertebrae label by the position of the first disc + vert_first_label = output_c2 + (disc_first_label - output_c2c3) * (output_vertebrae_step / output_disc_step) - (is_vert.index(False) - 1) * output_vertebrae_step + for i in range(vert_num_labels): + output_seg_data[vert_mask_labeled == vert_sorted_labels[i]] = vert_first_label + output_vertebrae_step * i + for i in range(disc_num_labels): + output_seg_data[disc_mask_labeled == disc_sorted_labels[i]] = disc_first_label + output_disc_step * i # Use the map to map labels from the iteative algorithm output, to the final output # This is useful to map the vertebrae label from the iteative algorithm output to the special sacrum label @@ -718,7 +537,234 @@ def iterative_label( return output_seg -def fill(mask): +def _get_si_sorted_components( + seg, + labels, + dilation_size=1, + combine_labels=False, + ): + ''' + Get sorted connected components superio-inferior (SI) for the given labels in the segmentation. + ''' + seg_data = np.asanyarray(seg.dataobj).round().astype(np.uint8) + + binary_dilation_structure = ndi.iterate_structure(ndi.generate_binary_structure(3, 1), dilation_size) + + # Skip if no labels are provided + if len(labels) == 0: + return None, 0, [], [] + + if combine_labels: + _labels = [[_] for _ in labels] + else: + # For discs, combine all labels before label continue voxels since the discs not touching each other + _labels = [labels] + + # Init labeled segmentation + mask_labeled, num_labels = np.zeros_like(seg_data, dtype=np.uint32), 0 + + # For each label, find connected voxels and label them into separate labels + for l in _labels: + mask = np.isin(seg_data, l) + + # Dilate the mask to combine small disconnected regions + mask_dilated = ndi.binary_dilation(mask, binary_dilation_structure) + + # Label the connected voxels in the dilated mask into separate labels + tmp_mask_labeled, tmp_num_labels = ndi.label(mask_dilated.astype(np.uint32), np.ones((3, 3, 3))) + + # Undo dilation + tmp_mask_labeled *= mask + + # Add current labels to the labeled segmentation + if tmp_num_labels > 0: + mask_labeled[tmp_mask_labeled != 0] = tmp_mask_labeled[tmp_mask_labeled != 0] + num_labels + num_labels += tmp_num_labels + + # If no label found, raise error + if num_labels == 0: + raise ValueError(f"Some label must be in the segmentation (labels: {labels})") + + # Reduce size of mask_labeled + if mask_labeled.max() < np.iinfo(np.uint8).max: + mask_labeled = mask_labeled.astype(np.uint8) + elif mask_labeled.max() < np.iinfo(np.uint16).max: + mask_labeled = mask_labeled.astype(np.uint16) + + # Get the z index of the center of mass for each label + canonical_mask_labeled = np.asanyarray(nib.as_closest_canonical(nib.Nifti1Image(mask_labeled, seg.affine, seg.header)).dataobj).round().astype(mask_labeled.dtype) + mask_labeled_z_indexes = [_[-1] for _ in ndi.center_of_mass(canonical_mask_labeled != 0, canonical_mask_labeled, range(1, num_labels + 1))] + + # Sort the labels by their z-index (reversed to go from superior to inferior) + sorted_z_indexes, sorted_labels = zip(*sorted(zip(mask_labeled_z_indexes,range(1,num_labels+1)))[::-1]) + + return mask_labeled, num_labels, list(sorted_labels), list(sorted_z_indexes) + +def _merge_vertebrae_labels( + seg, + labels, + mask_labeled, + num_labels, + sorted_labels, + sorted_z_indexes, + disc_sorted_z_indexes, + extra_labels, + ): + ''' + Combine sequential vertebrae labels based on some conditions. + ''' + if num_labels == 0: + return mask_labeled, num_labels, sorted_labels, sorted_z_indexes + + seg_data = np.asanyarray(seg.dataobj).round().astype(np.uint8) + # Combine sequential vertebrae labels if they have the same value in the original segmentation + # This is useful when part part of the vertebrae is not connected to the main part but have the same odd/even value + if len(labels) > 1: + new_sorted_labels = [] + + # Store the previous label and the original label of the previous label + prev_l, prev_orig_label = 0, 0 + + # Loop over the sorted labels + for l in sorted_labels: + # Get the original label of the current label + curr_orig_label = seg_data[mask_labeled == l].flat[0] + + # Combine the current label with the previous label if they have the same original label + if curr_orig_label == prev_orig_label: + # Combine the current label with the previous label + mask_labeled[mask_labeled == l] = prev_l + num_labels -= 1 + + else: + # Add the current label to the new sorted labels + new_sorted_labels.append(l) + prev_l, prev_orig_label = l, curr_orig_label + + # Get the z index of the center of mass for each label + canonical_mask_labeled = np.asanyarray(nib.as_closest_canonical(nib.Nifti1Image(mask_labeled, seg.affine, seg.header)).dataobj).round().astype(mask_labeled.dtype) + mask_labeled_z_indexes = [_[-1] for _ in ndi.center_of_mass(canonical_mask_labeled != 0, canonical_mask_labeled, new_sorted_labels)] + + # Sort the labels by their z-index (reversed to go from superior to inferior) + sorted_z_indexes, sorted_labels = zip(*sorted(zip(mask_labeled_z_indexes, new_sorted_labels))[::-1]) + + # Reduce size of mask_labeled + if mask_labeled.max() < np.iinfo(np.uint8).max: + mask_labeled = mask_labeled.astype(np.uint8) + elif mask_labeled.max() < np.iinfo(np.uint16).max: + mask_labeled = mask_labeled.astype(np.uint16) + + # Combine sequential vertebrae labels if there is no disc between them + if len(disc_sorted_z_indexes) > 0: + new_sorted_labels = [] + + # Store the previous label and the z index of the previous label + prev_l, prev_z = 0, 0 + + for l, z in zip(sorted_labels, sorted_z_indexes): + # Do not combine first and last vertebrae since it can be C1 or only contain the spinous process + if l not in sorted_labels[:2] and l != sorted_labels[-1] and prev_l > 0 and not any(z < _ < prev_z for _ in disc_sorted_z_indexes): + # Combine the current label with the previous label + mask_labeled[mask_labeled == l] = prev_l + num_labels -= 1 + + else: + # Add the current label to the new sorted labels + new_sorted_labels.append(l) + prev_l, prev_z = l, z + + sorted_labels = new_sorted_labels + + # Reduce size of mask_labeled + if mask_labeled.max() < np.iinfo(np.uint8).max: + mask_labeled = mask_labeled.astype(np.uint8) + elif mask_labeled.max() < np.iinfo(np.uint16).max: + mask_labeled = mask_labeled.astype(np.uint16) + + # Combine extra labels with adjacent vertebrae labels + if len(extra_labels) > 0: + mask_extra = np.isin(seg_data, extra_labels) + + # Loop over vertebral labels (from inferior because the transverse process make it steal from above) + for i in range(num_labels - 1, -1, -1): + # Mkae mask for the current vertebrae with filling the holes and dilating it + mask = _fill(mask_labeled == sorted_labels[i]) + mask = ndi.binary_dilation(mask, ndi.iterate_structure(ndi.generate_binary_structure(3, 1), 1)) + + # Add the intersection of the mask with the extra labels to the current verebrae + mask_labeled[mask_extra * mask] = sorted_labels[i] + + # Get the z index of the center of mass for each label + canonical_mask_labeled = np.asanyarray(nib.as_closest_canonical(nib.Nifti1Image(mask_labeled, seg.affine, seg.header)).dataobj).round().astype(mask_labeled.dtype) + mask_labeled_z_indexes = [_[-1] for _ in ndi.center_of_mass(canonical_mask_labeled != 0, canonical_mask_labeled, sorted_labels)] + + # Sort the labels by their z-index (reversed to go from superior to inferior) + sorted_z_indexes, sorted_labels = zip(*sorted(zip(mask_labeled_z_indexes, new_sorted_labels))[::-1]) + + return mask_labeled, num_labels, list(sorted_labels), list(sorted_z_indexes) + +def _get_first_label( + seg, + loc, + mask_labeled, + sorted_labels, + init, + step, + loc_labels, + default_superior, + map_output_dict, + ): + ''' + Get the first label for the iterative labeling algorithm. + ''' + seg_data = np.asanyarray(seg.dataobj).round().astype(np.uint8) + + loc_data = loc and np.asanyarray(loc.dataobj).round().astype(np.uint8) + + # If localizer is provided, transform it to the segmentation space + if loc_data is not None: + loc_data = tio.Resample( + tio.ScalarImage(tensor=seg_data[None, ...], affine=seg.affine) + )( + tio.LabelMap(tensor=loc_data[None, ...], affine=loc.affine) + ).data.numpy()[0, ...].astype(np.uint8) + + # Find the most superior label in the segmentation + first_label = 0 + for k, v in init.items(): + if k in seg_data: + first_label = v - step * sorted_labels.index(mask_labeled[seg_data == k].flat[0]) + break + + # If no init label found, set it from the localizer + if first_label == 0 and loc_data is not None: + # Make mask for the intersection of the localizer labels and the labels in the segmentation + mask = np.isin(loc_data, loc_labels) * np.isin(mask_labeled, sorted_labels) + + # Get the first label from sorted_labels that is in the localizer specified labels + mask_labeled_masked = mask * mask_labeled + first_sorted_labels_in_loc = next(np.array(sorted_labels)[np.isin(sorted_labels, mask_labeled_masked)].flat, 0) + + if first_sorted_labels_in_loc > 0: + # Get the target label for first_sorted_labels_in_loc - the label from the localizer that has the most voxels in it + loc_data_masked = mask * loc_data + target = np.argmax(np.bincount(loc_data_masked[mask_labeled_masked == first_sorted_labels_in_loc].flat)) + # If target in map_output_dict reverse it from the reversed map + # TODO Edge case if multiple keys have the same value, not used in the current implementation + target = {v: k for k, v in map_output_dict.items()}.get(target, target) + first_label = target - step * sorted_labels.index(first_sorted_labels_in_loc) + + # If no init label found, set the default superior label + if first_label == 0 and default_superior > 0: + first_label = default_superior + + # If no init label found, print error + if first_label == 0: + raise ValueError(f"Some initiation label must be in the segmentation (init: {list(init.keys())})") + + return first_label + +def _fill(mask): ''' Fill holes in a binary mask From ae4e9edef9521c1991ec60313c02947b49a98030 Mon Sep 17 00:00:00 2001 From: Yehuda Warszawer <36595323+yw7@users.noreply.github.com> Date: Thu, 12 Sep 2024 22:31:34 +0300 Subject: [PATCH 02/26] Improve label naming for clarity in iterative_label function Renamed variables and comments to improve clarity and understandability of the labeling process. Adjusted label computation logic to accurately reflect the sorting and labeling operations, especially for handling cases involving C1 and C2 vertebrae. This enhances readability and maintains consistency within the code base. No functional changes are expected. --- totalspineseg/utils/iterative_label.py | 36 ++++++++++++++------------ 1 file changed, 20 insertions(+), 16 deletions(-) diff --git a/totalspineseg/utils/iterative_label.py b/totalspineseg/utils/iterative_label.py index 1997088..e7de427 100644 --- a/totalspineseg/utils/iterative_label.py +++ b/totalspineseg/utils/iterative_label.py @@ -456,7 +456,7 @@ def iterative_label( output_seg_data = np.zeros_like(seg_data) - # Get sorted connected components superio-inferior (SI) for the disc and vertebrae labels + # Get sorted connected components superio-inferior (SI) for the disc labels disc_mask_labeled, disc_num_labels, disc_sorted_labels, disc_sorted_z_indexes = _get_si_sorted_components( seg, disc_labels, @@ -484,7 +484,7 @@ def iterative_label( ) # Get the first disc label - disc_first_label = _get_first_label( + superior_disc_output_label = _get_superior_output_label( seg, loc, disc_mask_labeled, @@ -504,13 +504,17 @@ def iterative_label( # Sort the labels by their z-index (reversed to go from superior to inferior) sorted_z_indexes, sorted_labels, is_vert = zip(*sorted(zip(sorted_z_indexes, sorted_labels, is_vert))[::-1]) - # Get index of first vert in is_vert - # To cover cases with C1 and C2 we have to adjust the first vertebrae label by the position of the first disc - vert_first_label = output_c2 + (disc_first_label - output_c2c3) * (output_vertebrae_step / output_disc_step) - (is_vert.index(False) - 1) * output_vertebrae_step + # Get the superior output label for the vertebrae based on the superior disc label + # For C1 and C2 we have to adjust the first vertebrae label by the position of the first disc with substraction of (is_vert.index(False) - 1) * output_vertebrae_step + superior_vert_output_label = output_c2 + (superior_disc_output_label - output_c2c3) * (output_vertebrae_step / output_disc_step) - (is_vert.index(False) - 1) * output_vertebrae_step + + # Label the vertebrae with the output labels superio-inferior for i in range(vert_num_labels): - output_seg_data[vert_mask_labeled == vert_sorted_labels[i]] = vert_first_label + output_vertebrae_step * i + output_seg_data[vert_mask_labeled == vert_sorted_labels[i]] = superior_vert_output_label + output_vertebrae_step * i + + # Label the discs with the output labels superio-inferior for i in range(disc_num_labels): - output_seg_data[disc_mask_labeled == disc_sorted_labels[i]] = disc_first_label + output_disc_step * i + output_seg_data[disc_mask_labeled == disc_sorted_labels[i]] = superior_disc_output_label + output_disc_step * i # Use the map to map labels from the iteative algorithm output, to the final output # This is useful to map the vertebrae label from the iteative algorithm output to the special sacrum label @@ -703,7 +707,7 @@ def _merge_vertebrae_labels( return mask_labeled, num_labels, list(sorted_labels), list(sorted_z_indexes) -def _get_first_label( +def _get_superior_output_label( seg, loc, mask_labeled, @@ -730,14 +734,14 @@ def _get_first_label( ).data.numpy()[0, ...].astype(np.uint8) # Find the most superior label in the segmentation - first_label = 0 + superior_output_label = 0 for k, v in init.items(): if k in seg_data: - first_label = v - step * sorted_labels.index(mask_labeled[seg_data == k].flat[0]) + superior_output_label = v - step * sorted_labels.index(mask_labeled[seg_data == k].flat[0]) break # If no init label found, set it from the localizer - if first_label == 0 and loc_data is not None: + if superior_output_label == 0 and loc_data is not None: # Make mask for the intersection of the localizer labels and the labels in the segmentation mask = np.isin(loc_data, loc_labels) * np.isin(mask_labeled, sorted_labels) @@ -752,17 +756,17 @@ def _get_first_label( # If target in map_output_dict reverse it from the reversed map # TODO Edge case if multiple keys have the same value, not used in the current implementation target = {v: k for k, v in map_output_dict.items()}.get(target, target) - first_label = target - step * sorted_labels.index(first_sorted_labels_in_loc) + superior_output_label = target - step * sorted_labels.index(first_sorted_labels_in_loc) # If no init label found, set the default superior label - if first_label == 0 and default_superior > 0: - first_label = default_superior + if superior_output_label == 0 and default_superior > 0: + superior_output_label = default_superior # If no init label found, print error - if first_label == 0: + if superior_output_label == 0: raise ValueError(f"Some initiation label must be in the segmentation (init: {list(init.keys())})") - return first_label + return superior_output_label def _fill(mask): ''' From f1182be3a19d8741c266833ba5d39cb99552cb92 Mon Sep 17 00:00:00 2001 From: Yehuda Warszawer <36595323+yw7@users.noreply.github.com> Date: Thu, 12 Sep 2024 22:46:24 +0300 Subject: [PATCH 03/26] Fix labeling logic in _get_superior_output_label() Refined the logic for determining the superior output label by replacing the use of the first element with the most frequent element. This ensures better accuracy in cases with multiple segmented labels and improves robustness in varying data scenarios. Addresses potential inaccuracies in segmentation outputs noted during recent evaluations. --- totalspineseg/utils/iterative_label.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/totalspineseg/utils/iterative_label.py b/totalspineseg/utils/iterative_label.py index e7de427..99295c4 100644 --- a/totalspineseg/utils/iterative_label.py +++ b/totalspineseg/utils/iterative_label.py @@ -737,7 +737,7 @@ def _get_superior_output_label( superior_output_label = 0 for k, v in init.items(): if k in seg_data: - superior_output_label = v - step * sorted_labels.index(mask_labeled[seg_data == k].flat[0]) + superior_output_label = v - step * sorted_labels.index(np.argmax(np.bincount(mask_labeled[seg_data == k].flat))) break # If no init label found, set it from the localizer From c96c9f0033b6b1755ba2ce446772312d56ee7ade Mon Sep 17 00:00:00 2001 From: Yehuda Warszawer <36595323+yw7@users.noreply.github.com> Date: Sat, 14 Sep 2024 01:18:47 +0300 Subject: [PATCH 04/26] Refactor disc and vertebrae labeling process Enhanced robustness of segmentation process by: - Integrating support for diverse MRI contrasts, orientations, and resolutions. - Removing redundant label definitions and consolidating critical ones. - Simplifying inferring by refining mapping from inputs to outputs. - Removing initial, redundant disc labels and updating key parameter names. - Streamlining logic for merging vertebrae and handling extra labels. - Adding detailed mapping of landmarks to output labels for discs and vertebrae. These changes improve the accuracy and flexibility of spinal segmentation. --- README.md | 100 ++--- totalspineseg/inference.py | 44 +- totalspineseg/utils/iterative_label.py | 577 ++++++++++++++++--------- 3 files changed, 448 insertions(+), 273 deletions(-) diff --git a/README.md b/README.md index 1e13977..401425c 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # TotalSpineSeg -TotalSpineSeg is a tool for automatic instance segmentation and labeling of all vertebrae, intervertebral discs (IVDs), spinal cord, and spinal canal in MRI images. It follows the [TotalSegmentator classes](https://github.com/wasserth/TotalSegmentator/tree/v1.5.7#class-details) with additional classes for IVDs, spinal cord, and spinal canal (see [list of classes](#list-of-classes)). The model is based on [nnUNet](https://github.com/MIC-DKFZ/nnUNet) as the backbone for training and inference. +TotalSpineSeg is a tool for automatic instance segmentation of all vertebrae, intervertebral discs (IVDs), spinal cord, and spinal canal in MRI images. It is robust to various MRI contrasts, acquisition orientations, and resolutions. The model used in TotalSpineSeg is based on [nnUNet](https://github.com/MIC-DKFZ/nnUNet) as the backbone for training and inference. If you use this model, please cite our work: > Warszawer Y, Molinier N, Valošek J, Shirbint E, Benveniste PL, Achiron A, Eshaghi A and Cohen-Adad J. _Fully Automatic Vertebrae and Spinal Cord Segmentation Using a Hybrid Approach Combining nnU-Net and Iterative Algorithm_. Proceedings of the 32th Annual Meeting of ISMRM. 2024 @@ -222,53 +222,53 @@ For a more detailed view of the output examples, you can check the [PDF version] | Label | Name | |:------|:-----| -| 18 | vertebrae_L5 | -| 19 | vertebrae_L4 | -| 20 | vertebrae_L3 | -| 21 | vertebrae_L2 | -| 22 | vertebrae_L1 | -| 23 | vertebrae_T12 | -| 24 | vertebrae_T11 | -| 25 | vertebrae_T10 | -| 26 | vertebrae_T9 | +| 1 | spinal_cord | +| 2 | spinal_canal | +| 10 | vertebrae_C1 | +| 11 | vertebrae_C2 | +| 12 | vertebrae_C3 | +| 13 | vertebrae_C4 | +| 14 | vertebrae_C5 | +| 15 | vertebrae_C6 | +| 16 | vertebrae_C7 | +| 20 | vertebrae_T1 | +| 21 | vertebrae_T2 | +| 22 | vertebrae_T3 | +| 23 | vertebrae_T4 | +| 24 | vertebrae_T5 | +| 25 | vertebrae_T6 | +| 26 | vertebrae_T7 | | 27 | vertebrae_T8 | -| 28 | vertebrae_T7 | -| 29 | vertebrae_T6 | -| 30 | vertebrae_T5 | -| 31 | vertebrae_T4 | -| 32 | vertebrae_T3 | -| 33 | vertebrae_T2 | -| 34 | vertebrae_T1 | -| 35 | vertebrae_C7 | -| 36 | vertebrae_C6 | -| 37 | vertebrae_C5 | -| 38 | vertebrae_C4 | -| 39 | vertebrae_C3 | -| 40 | vertebrae_C2 | -| 41 | vertebrae_C1 | -| 92 | sacrum | -| 200 | spinal_cord | -| 201 | spinal_canal | -| 202 | disc_L5_S | -| 203 | disc_L4_L5 | -| 204 | disc_L3_L4 | -| 205 | disc_L2_L3 | -| 206 | disc_L1_L2 | -| 207 | disc_T12_L1 | -| 208 | disc_T11_T12 | -| 209 | disc_T10_T11 | -| 210 | disc_T9_T10 | -| 211 | disc_T8_T9 | -| 212 | disc_T7_T8 | -| 213 | disc_T6_T7 | -| 214 | disc_T5_T6 | -| 215 | disc_T4_T5 | -| 216 | disc_T3_T4 | -| 217 | disc_T2_T3 | -| 218 | disc_T1_T2 | -| 219 | disc_C7_T1 | -| 220 | disc_C6_C7 | -| 221 | disc_C5_C6 | -| 222 | disc_C4_C5 | -| 223 | disc_C3_C4 | -| 224 | disc_C2_C3 | +| 28 | vertebrae_T9 | +| 29 | vertebrae_T10 | +| 30 | vertebrae_T11 | +| 31 | vertebrae_T12 | +| 40 | vertebrae_L1 | +| 41 | vertebrae_L2 | +| 42 | vertebrae_L3 | +| 43 | vertebrae_L4 | +| 44 | vertebrae_L5 | +| 50 | sacrum | +| 60 | disc_C2_C3 | +| 61 | disc_C3_C4 | +| 62 | disc_C4_C5 | +| 63 | disc_C5_C6 | +| 64 | disc_C6_C7 | +| 70 | disc_C7_T1 | +| 71 | disc_T1_T2 | +| 72 | disc_T2_T3 | +| 73 | disc_T3_T4 | +| 74 | disc_T4_T5 | +| 75 | disc_T5_T6 | +| 76 | disc_T6_T7 | +| 77 | disc_T7_T8 | +| 78 | disc_T8_T9 | +| 79 | disc_T9_T10 | +| 80 | disc_T10_T11 | +| 81 | disc_T11_T12 | +| 90 | disc_T12_L1 | +| 91 | disc_L1_L2 | +| 92 | disc_L2_L3 | +| 93 | disc_L3_L4 | +| 94 | disc_L4_L5 | +| 100 | disc_L5_S | diff --git a/totalspineseg/inference.py b/totalspineseg/inference.py index c6ee540..655ded7 100644 --- a/totalspineseg/inference.py +++ b/totalspineseg/inference.py @@ -323,10 +323,11 @@ def main(): iterative_label_mp( output_path / 'step1_output', output_path / 'step1_output', + selected_disc_landmarks=[2, 5, 3, 4], disc_labels=[1, 2, 3, 4, 5], - init_disc={2:224, 5:202, 3:219, 4:207}, - output_disc_step=-1, - map_input_dict={6:92, 7:201, 8:201, 9:200}, + disc_landmark_labels=[2, 3, 4, 5], + disc_landmark_output_labels=[60, 70, 90, 100], + map_input_dict={6:50, 7:2, 8:2, 9:1}, override=True, max_workers=max_workers, quiet=quiet, @@ -336,11 +337,12 @@ def main(): output_path / 'step1_output', output_path / 'step1_output', locs_path=output_path / 'localizers', + selected_disc_landmarks=[2, 5], disc_labels=[1, 2, 3, 4, 5], - init_disc={2:224, 5:202}, - output_disc_step=-1, - loc_disc_labels=list(range(202, 225)), - map_input_dict={6:92, 7:201, 8:201, 9:200}, + disc_landmark_labels=[2, 3, 4, 5], + disc_landmark_output_labels=[60, 70, 90, 100], + loc_disc_labels=list(range(60, 101)), + map_input_dict={6:50, 7:2, 8:2, 9:1}, override=True, max_workers=max_workers, quiet=quiet, @@ -543,16 +545,15 @@ def main(): iterative_label_mp( output_path / 'step2_output', output_path / 'step2_output', + selected_disc_landmarks=[4, 7, 5, 6], disc_labels=[1, 2, 3, 4, 5, 6, 7], + disc_landmark_labels=[4, 5, 6, 7], + disc_landmark_output_labels=[60, 70, 90, 100], vertebrae_labels=[9, 10, 11, 12, 13, 14], + vertebrae_landmark_output_labels=[12, 20, 40, 50], vertebrae_extra_labels=[8], - init_disc={4:224, 7:202, 5:219, 6:207}, - output_disc_step=-1, - output_vertebrae_step=-1, - output_c2c3=224, - output_c2=40, - map_output_dict={17:92}, - map_input_dict={14:92, 15:201, 16:201, 17:200}, + map_output_dict={17:50}, + map_input_dict={14:50, 15:2, 16:2, 17:1}, override=True, max_workers=max_workers, quiet=quiet, @@ -562,17 +563,16 @@ def main(): output_path / 'step2_output', output_path / 'step2_output', locs_path=output_path / 'localizers', + selected_disc_landmarks=[4, 7], disc_labels=[1, 2, 3, 4, 5, 6, 7], + disc_landmark_labels=[4, 5, 6, 7], + disc_landmark_output_labels=[60, 70, 90, 100], vertebrae_labels=[9, 10, 11, 12, 13, 14], + vertebrae_landmark_output_labels=[12, 20, 40, 50], vertebrae_extra_labels=[8], - init_disc={4:224, 7:202}, - loc_disc_labels=list(range(202, 225)), - output_disc_step=-1, - output_vertebrae_step=-1, - output_c2c3=224, - output_c2=40, - map_output_dict={17:92}, - map_input_dict={14:92, 15:201, 16:201, 17:200}, + loc_disc_labels=list(range(60, 101)), + map_output_dict={17:50}, + map_input_dict={14:50, 15:2, 16:2, 17:1}, override=True, max_workers=max_workers, quiet=quiet, diff --git a/totalspineseg/utils/iterative_label.py b/totalspineseg/utils/iterative_label.py index 99295c4..de7e9b9 100644 --- a/totalspineseg/utils/iterative_label.py +++ b/totalspineseg/utils/iterative_label.py @@ -20,10 +20,12 @@ def main(): '''.split()), epilog=textwrap.dedent(''' Examples: - iterative_label -s labels_init -o labels --disc-labels 1-7 --vertebrae-labels 9-14 --vertebrae-extra-labels 8 --init-disc 4:224 7:202 5:219 6:207 --output-disc-step -1 --output-vertebrae-step -1 --map-output 17:92 --map-input 14:92 16:201 17:200 --output-c2c3 224 --output-c2 40 -r - iterative_label -s labels_init -o labels -l localizers --disc-labels 1-7 --vertebrae-labels 9-14 --vertebrae-extra-labels 8 --init-disc 4:224 7:202 --output-disc-step -1 --output-vertebrae-step -1 --loc-disc-labels 202-224 --map-output 17:92 --map-input 14:92 16:201 17:200 --output-c2c3 224 --output-c2 40 -r + iterative_label -s labels_init -o labels --selected-disc-landmarks 2 5 3 4 --disc-labels 1-5 --disc-landmark-labels 2 3 4 5 --disc-landmark-output-labels 60 70 90 100 --map-input 6:50 7:2 8:2 9:1 -r + iterative_label -s labels_init -o labels --selected-disc-landmarks 2 5 3 4 --disc-labels 1-7 --disc-landmark-labels 4 5 6 7 --disc-landmark-output-labels 60 70 90 100 --vertebrae-labels 9-14 --vertebrae-landmark-output-labels 12 20 40 50 --vertebrae-extra-labels 8 --map-output 17:50 --map-input 14:50 15:2 16:2 17:1 -r + iterative_label -s labels_init -o labels -l localizers --selected-disc-landmarks 2 5 --disc-labels 1-5 --disc-landmark-labels 2 3 4 5 --disc-landmark-output-labels 60 70 90 100 --map-input 6:50 7:2 8:2 9:1 --loc-disc-labels 60-100 -r + iterative_label -s labels_init -o labels -l localizers --selected-disc-landmarks 4 7 --disc-labels 1-7 --disc-landmark-labels 4 5 6 7 --disc-landmark-output-labels 60 70 90 100 --vertebrae-labels 9-14 --vertebrae-landmark-output-labels 12 20 40 50 --vertebrae-extra-labels 8 --map-output 17:50 --map-input 14:50 15:2 16:2 17:1 --loc-disc-labels 60-100 -r For BIDS: - iterative_label -s derivatives/labels -o derivatives/labels --seg-suffix "_seg" --output-seg-suffix "_seg_seq" -d "sub-" -u "anat" --disc-labels 1 2 3 4 5 6 7 --vertebrae-labels 9 10 11 12 13 14 --vertebrae-extra-labels 8 --init-disc 4:224 7:202 5:219 6:207 --output-disc-step -1 --output-vertebrae-step -1 --map-output 17:92 --map-input 14:92 16:201 17:200 --output-c2c3 224 --output-c2 40 -r + iterative_label -s derivatives/labels -o derivatives/labels --seg-suffix "_seg" --output-seg-suffix "_seg_seq" -d "sub-" -u "anat" --selected-disc-landmarks 2 5 3 4 --disc-labels 1-5 --disc-landmark-labels 2 3 4 5 --disc-landmark-output-labels 60 70 90 100 --map-input 6:50 7:2 8:2 9:1 -r '''), formatter_class=argparse.RawTextHelpFormatter ) @@ -72,41 +74,49 @@ def main(): '--loc-suffix', type=str, default='', help='Localizer suffix, defaults to "".' ) + parser.add_argument( + '--selected-disc-landmarks', type=int, nargs='+', default=[], + help='The selected disc labels to use as a landmark from the disc_landmark_labels.' + ) parser.add_argument( '--disc-labels', type=lambda x:list(range(int(x.split('-')[0]), int(x.split('-')[-1]) + 1)), nargs='+', default=[], help='The disc labels.' ) parser.add_argument( - '--init-disc', type=lambda x:map(int, x.split(':')), nargs='+', default=[], - help='Init labels list for disc ordered by priority (input_label:output_label !!without space!!). for example 4:224 5:219 6:202' + '--disc-landmark-labels', type=int, nargs=4, + help='All disc labels that can be used as a landmark: C2C3, C7T1, T12L1 and L5S1.' ) parser.add_argument( - '--output-c2c3', type=int, default=0, - help='The output label for C2C3, used to calculate the first vertebrae label, defaults to 0.' + '--disc-landmark-output-labels', type=int, nargs=4, + help='List of output labels for discs C2C3, C7T1, T12L1 and L5S1.' ) parser.add_argument( - '--output-disc-step', type=int, default=1, + '--disc-output-step', type=int, default=1, help='The step to take between disc labels in the output, defaults to 1.' ) - parser.add_argument( - '--loc-disc-labels', type=lambda x:list(range(int(x.split('-')[0]), int(x.split('-')[-1]) + 1)), nargs='+', default=[], - help='The disc labels in the localizer used for detecting first disc.' - ) parser.add_argument( '--vertebrae-labels', type=lambda x:list(range(int(x.split('-')[0]), int(x.split('-')[-1]) + 1)), nargs='+', default=[], help='The vertebrae labels.' ) + parser.add_argument( + '--vertebrae-landmark-output-labels', type=int, nargs=4, + help='List of output labels for vertebrae C3, T1, L1, Sacrum.' + ) + parser.add_argument( + '--vertebrae-output-step', type=int, default=1, + help='The step to take between vertebrae labels in the output, defaults to 1.' + ) parser.add_argument( '--vertebrae-extra-labels', type=lambda x:list(range(int(x.split('-')[0]), int(x.split('-')[-1]) + 1)), nargs='+', default=[], help='Extra vertebrae labels to add to add to adjacent vertebrae labels.' ) parser.add_argument( - '--output-c2', type=int, default=0, - help='The output label for C2, used to calculate the first vertebrae label, defaults to 0.' + '--region-max-sizes', type=int, nargs=4, default=[5, 12, 6, 1], + help='The maximum number of discs/vertebrae for each region (Cervical from C3, Thoracic, Lumbar, Sacrum), defaults to [5, 12, 6, 1].' ) parser.add_argument( - '--output-vertebrae-step', type=int, default=1, - help='The step to take between vertebrae labels in the output, defaults to 1.' + '--loc-disc-labels', type=lambda x:list(range(int(x.split('-')[0]), int(x.split('-')[-1]) + 1)), nargs='+', default=[], + help='The disc labels in the localizer used for detecting first disc.' ) parser.add_argument( '--map-input', type=str, nargs='+', default=[], @@ -158,15 +168,17 @@ def main(): seg_suffix = args.seg_suffix output_seg_suffix = args.output_seg_suffix loc_suffix = args.loc_suffix + selected_disc_landmarks = args.selected_disc_landmarks disc_labels = [_ for __ in args.disc_labels for _ in (__ if isinstance(__, list) else [__])] - init_disc = dict(args.init_disc) - output_c2c3 = args.output_c2c3 - output_disc_step = args.output_disc_step - loc_disc_labels = [_ for __ in args.loc_disc_labels for _ in (__ if isinstance(__, list) else [__])] + disc_landmark_labels = args.disc_landmark_labels + disc_landmark_output_labels = args.disc_landmark_output_labels + disc_output_step = args.disc_output_step vertebrae_labels = [_ for __ in args.vertebrae_labels for _ in (__ if isinstance(__, list) else [__])] + vertebrae_landmark_output_labels = args.vertebrae_landmark_output_labels + vertebrae_output_step = args.vertebrae_output_step vertebrae_extra_labels = [_ for __ in args.vertebrae_extra_labels for _ in (__ if isinstance(__, list) else [__])] - output_c2 = args.output_c2 - output_vertebrae_step = args.output_vertebrae_step + region_max_sizes = args.region_max_sizes + loc_disc_labels = [_ for __ in args.loc_disc_labels for _ in (__ if isinstance(__, list) else [__])] map_input_list = args.map_input map_output_list = args.map_output dilation_size = args.dilation_size @@ -188,15 +200,17 @@ def main(): seg_suffix = "{seg_suffix}" output_seg_suffix = "{output_seg_suffix}" loc_suffix = "{loc_suffix}" + selected_disc_landmarks = {selected_disc_landmarks} disc_labels = {disc_labels} - init_disc = {init_disc} - output_c2c3 = {output_c2c3} - output_disc_step = {output_disc_step} - loc_disc_labels = {loc_disc_labels} + disc_landmark_labels = {disc_landmark_labels} + disc_landmark_output_labels = {disc_landmark_output_labels} + disc_output_step = {disc_output_step} vertebrae_labels = {vertebrae_labels} + vertebrae_landmark_output_labels = {vertebrae_landmark_output_labels} + vertebrae_output_step = {vertebrae_output_step} vertebrae_extra_labels = {vertebrae_extra_labels} - output_c2 = {output_c2} - output_vertebrae_step = {output_vertebrae_step} + region_max_sizes = {region_max_sizes} + loc_disc_labels = {loc_disc_labels} map_input = {map_input_list} map_output = {map_output_list} dilation_size = {dilation_size} @@ -227,15 +241,17 @@ def main(): seg_suffix=seg_suffix, output_seg_suffix=output_seg_suffix, loc_suffix=loc_suffix, + selected_disc_landmarks=selected_disc_landmarks, disc_labels=disc_labels, - init_disc=init_disc, - output_c2c3=output_c2c3, - output_disc_step=output_disc_step, - loc_disc_labels=loc_disc_labels, + disc_landmark_labels=disc_landmark_labels, + disc_landmark_output_labels=disc_landmark_output_labels, + disc_output_step=disc_output_step, vertebrae_labels=vertebrae_labels, + vertebrae_landmark_output_labels=vertebrae_landmark_output_labels, + vertebrae_output_step=vertebrae_output_step, vertebrae_extra_labels=vertebrae_extra_labels, - output_c2=output_c2, - output_vertebrae_step=output_vertebrae_step, + region_max_sizes=region_max_sizes, + loc_disc_labels=loc_disc_labels, map_input_dict=map_input_dict, map_output_dict=map_output_dict, dilation_size=dilation_size, @@ -255,15 +271,17 @@ def iterative_label_mp( seg_suffix='', output_seg_suffix='', loc_suffix='', + selected_disc_landmarks=[], disc_labels=[], - init_disc={}, - output_c2c3=0, - output_disc_step=1, - loc_disc_labels=[], + disc_landmark_labels=[], + disc_landmark_output_labels=[], + disc_output_step=1, vertebrae_labels=[], + vertebrae_landmark_output_labels=[], + vertebrae_output_step=1, vertebrae_extra_labels=[], - output_c2=0, - output_vertebrae_step=1, + region_max_sizes=[5, 12, 6, 1], + loc_disc_labels=[], map_input_dict={}, map_output_dict={}, dilation_size=1, @@ -294,15 +312,17 @@ def iterative_label_mp( process_map( partial( _iterative_label, + selected_disc_landmarks=selected_disc_landmarks, disc_labels=disc_labels, - output_c2c3=output_c2c3, - output_disc_step=output_disc_step, - loc_disc_labels=loc_disc_labels, - init_disc=init_disc, + disc_landmark_labels=disc_landmark_labels, + disc_landmark_output_labels=disc_landmark_output_labels, + disc_output_step=disc_output_step, vertebrae_labels=vertebrae_labels, + vertebrae_landmark_output_labels=vertebrae_landmark_output_labels, + vertebrae_output_step=vertebrae_output_step, vertebrae_extra_labels=vertebrae_extra_labels, - output_c2=output_c2, - output_vertebrae_step=output_vertebrae_step, + region_max_sizes=region_max_sizes, + loc_disc_labels=loc_disc_labels, map_input_dict=map_input_dict, map_output_dict=map_output_dict, dilation_size=dilation_size, @@ -321,15 +341,17 @@ def _iterative_label( seg_path, output_seg_path, loc_path=None, + selected_disc_landmarks=[], disc_labels=[], - init_disc={}, - output_c2c3=0, - output_disc_step=1, - loc_disc_labels=[], + disc_landmark_labels=[], + disc_landmark_output_labels=[], + disc_output_step=1, vertebrae_labels=[], + vertebrae_landmark_output_labels=[], + vertebrae_output_step=1, vertebrae_extra_labels=[], - output_c2=0, - output_vertebrae_step=1, + region_max_sizes=[5, 12, 6, 1], + loc_disc_labels=[], map_input_dict={}, map_output_dict={}, dilation_size=1, @@ -355,19 +377,21 @@ def _iterative_label( output_seg = iterative_label( seg, loc, + selected_disc_landmarks=selected_disc_landmarks, disc_labels=disc_labels, - init_disc=init_disc, - output_c2c3=output_c2c3, - output_disc_step=output_disc_step, - loc_disc_labels=loc_disc_labels, + disc_landmark_labels=disc_landmark_labels, + disc_landmark_output_labels=disc_landmark_output_labels, + disc_output_step=disc_output_step, vertebrae_labels=vertebrae_labels, + vertebrae_landmark_output_labels=vertebrae_landmark_output_labels, + vertebrae_output_step=vertebrae_output_step, vertebrae_extra_labels=vertebrae_extra_labels, - output_c2=output_c2, - output_vertebrae_step=output_vertebrae_step, + region_max_sizes=region_max_sizes, + loc_disc_labels=loc_disc_labels, map_input_dict=map_input_dict, map_output_dict=map_output_dict, dilation_size=dilation_size, - default_superior_disc=default_superior_disc, + disc_default_superior_output=default_superior_disc, ) except ValueError as e: output_seg_path.is_file() and output_seg_path.unlink() @@ -390,19 +414,21 @@ def _iterative_label( def iterative_label( seg, loc=None, + selected_disc_landmarks=[], disc_labels=[], - init_disc={}, - output_c2c3=0, - output_disc_step=1, - loc_disc_labels=[], + disc_landmark_labels=[], + disc_landmark_output_labels=[], + disc_output_step=1, vertebrae_labels=[], + vertebrae_landmark_output_labels=[], + vertebrae_output_step=1, vertebrae_extra_labels=[], - output_c2=0, - output_vertebrae_step=1, + region_max_sizes=[5, 12, 6, 1], + loc_disc_labels=[], map_input_dict={}, map_output_dict={}, dilation_size=1, - default_superior_disc=0, + disc_default_superior_output=0, ): ''' Label Vertebrae, IVDs, Spinal Cord and canal from init segmentation. @@ -411,8 +437,12 @@ def iterative_label( 2. Find connected voxels for each vertebrae label and label them into separate labels 3. Combine sequential vertebrae labels based on some conditions 4. Combine extra labels with adjacent vertebrae labels - 5. Map labels from the iteative algorithm output, to the final output (e.g., map the vertebrae label from the iteative algorithm output to the special sacrum label) - 6. Map input labels to the final output (e.g., map the input sacrum, canal and spinal cord labels to the output labels) + 5. Find the landmark disc labels and output labels + 6. Label the discs with the output labels + 7. Find the matching vertebrae labels to the discs landmarks + 8. Label the vertebrae with the output labels + 9. Map labels from the iteative algorithm output, to the final output (e.g., map the vertebrae label from the iteative algorithm output to the special sacrum label) + 10. Map input labels to the final output (e.g., map the input sacrum, canal and spinal cord labels to the output labels) Parameters ---------- @@ -420,24 +450,28 @@ def iterative_label( Segmentation image loc : nibabel.nifti1.Nifti1Image Localizer image to use for detecting first vertebrae and disc (optional) + selected_disc_landmarks : list + List of disc labels to use as a landmark from the disc_landmark_labels disc_labels : list - The disc labels - init_disc : dict - Init labels list for disc ordered by priority (input_label:output_label) - output_c2c3 : int - The output label for C2C3, used to calculate the first vertebrae label - output_disc_step : int + The disc labels in the segmentation + disc_landmark_labels : list + All disc labels that can be used as a landmark: [C2C3, C7T1, T12L1, L5S1] + disc_landmark_output_labels : list + List of output labels for discs [C2C3, C7T1, T12L1, L5S1] + disc_output_step : int The step to take between disc labels in the output - loc_disc_labels : list - Localizer labels to use for detecting first disc vertebrae_labels : list - The vertebrae labels + The vertebrae labels in the segmentation + vertebrae_landmark_output_labels : list + List of output labels for vertebrae [C3, T1, L1, Sacrum] + vertebrae_output_step : int + The step to take between vertebrae labels in the output vertebrae_extra_labels : list Extra vertebrae labels to add to add to adjacent vertebrae labels - output_c2 : int - The output label for C2, used to calculate the first vertebrae label - output_vertebrae_step : int - The step to take between vertebrae labels in the output + region_max_sizes : list + The maximum number of discs/vertebrae for each region (Cervical from C3, Thoracic, Lumbar, Sacrum). + loc_disc_labels : list + Localizer labels to use for detecting first disc map_input_dict : dict A dict mapping labels from input into the output segmentation map_output_dict : dict @@ -461,6 +495,7 @@ def iterative_label( seg, disc_labels, dilation_size, + combine_labels=True, ) # Get sorted connected components superio-inferior (SI) for the vertebrae labels @@ -468,53 +503,140 @@ def iterative_label( seg, vertebrae_labels, dilation_size, - combine_labels=True, ) - # Combine sequential vertebrae labels based on some conditions - vert_mask_labeled, vert_num_labels, vert_sorted_labels, vert_sorted_z_indexes = _merge_vertebrae_labels( + # Combine sequential vertebrae labels if they have the same value in the original segmentation + vert_mask_labeled, vert_num_labels, vert_sorted_labels, vert_sorted_z_indexes = _merge_vertebrae_with_same_label( seg, vertebrae_labels, vert_mask_labeled, vert_num_labels, vert_sorted_labels, vert_sorted_z_indexes, + ) + + # Combine sequential vertebrae labels if there is no disc between them + vert_mask_labeled, vert_num_labels, vert_sorted_labels, vert_sorted_z_indexes = _merge_vertebrae_labels_with_no_disc_between( + seg, + vert_mask_labeled, + vert_num_labels, + vert_sorted_labels, + vert_sorted_z_indexes, disc_sorted_z_indexes, + ) + + # Combine extra labels with adjacent vertebrae labels + vert_mask_labeled, vert_num_labels, vert_sorted_labels, vert_sorted_z_indexes = _merge_extra_labels_with_adjacent_vertebrae( + seg, + vert_mask_labeled, + vert_num_labels, + vert_sorted_labels, + vert_sorted_z_indexes, vertebrae_extra_labels, ) - # Get the first disc label - superior_disc_output_label = _get_superior_output_label( + # Get the landmark disc labels and output labels - {label in sorted labels: output label} + # TODO Currently only the first 2 landmark from selected_disc_landmarks is used, to get all landmarks see TODO in the function + map_disc_sorted_labels_landmark2output = _get_landmark_output_labels( seg, loc, disc_mask_labeled, disc_sorted_labels, - init_disc, - output_disc_step, + selected_disc_landmarks, + disc_landmark_labels, + disc_landmark_output_labels, loc_disc_labels, - default_superior_disc, - map_output_dict, + disc_default_superior_output, ) - # Sort the combined disc+vert labels by their z-index - sorted_labels = vert_sorted_labels + disc_sorted_labels - sorted_z_indexes = vert_sorted_z_indexes + disc_sorted_z_indexes - is_vert = [True] * len(vert_sorted_labels) + [False] * len(disc_sorted_labels) + # Build a list containing all possible labels for the disc ordered superio-inferior + all_possible_disc_output_labels = [] + for l, s in zip(disc_landmark_output_labels, region_max_sizes): + for i in range(s): + all_possible_disc_output_labels.append(l + i * disc_output_step) - # Sort the labels by their z-index (reversed to go from superior to inferior) - sorted_z_indexes, sorted_labels, is_vert = zip(*sorted(zip(sorted_z_indexes, sorted_labels, is_vert))[::-1]) + # Make a dict mapping the sorted disc labels to the output labels + map_disc_sorted_labels_2output = {} - # Get the superior output label for the vertebrae based on the superior disc label - # For C1 and C2 we have to adjust the first vertebrae label by the position of the first disc with substraction of (is_vert.index(False) - 1) * output_vertebrae_step - superior_vert_output_label = output_c2 + (superior_disc_output_label - output_c2c3) * (output_vertebrae_step / output_disc_step) - (is_vert.index(False) - 1) * output_vertebrae_step + # We loop over all the landmarks starting from the most superior + for l in [_ for _ in disc_sorted_labels if _ in map_disc_sorted_labels_landmark2output]: + # Get the index of the current landmark in the sorted disc labels + start_l = disc_sorted_labels.index(l) - # Label the vertebrae with the output labels superio-inferior - for i in range(vert_num_labels): - output_seg_data[vert_mask_labeled == vert_sorted_labels[i]] = superior_vert_output_label + output_vertebrae_step * i + # Get the index of the current landmark in the list of all possible disc output labels + start_o = all_possible_disc_output_labels.index(map_disc_sorted_labels_landmark2output[l]) + + # If this is the most superior landmark, we have to adjust the start indexes to start from the most superior label in the image + if len(map_disc_sorted_labels_2output) == 0: + start_l, start_o = max(0, start_l - start_o), max(0, start_o - start_l) + + # Map the sorted disc labels to the output labels + # This will ovveride the mapping from the previous landmarks for all labels inferior to the current landmark + for l, o in zip(disc_sorted_labels[start_l:], all_possible_disc_output_labels[start_o:]): + map_disc_sorted_labels_2output[l] = o # Label the discs with the output labels superio-inferior - for i in range(disc_num_labels): - output_seg_data[disc_mask_labeled == disc_sorted_labels[i]] = superior_disc_output_label + output_disc_step * i + for l, o in map_disc_sorted_labels_2output.items(): + output_seg_data[disc_mask_labeled == l] = o + + if vert_num_labels > 0: + # Build a list containing all possible labels for the vertebrae ordered superio-inferior + # We start with the C1 and C2 labels as the first landmark is the C3 vertebrae + all_possible_vertebrae_output_labels = [ + vertebrae_landmark_output_labels[0] - 2 * vertebrae_output_step, + vertebrae_landmark_output_labels[0] - vertebrae_output_step + ] + for l, s in zip(vertebrae_landmark_output_labels, region_max_sizes): + for i in range(s): + all_possible_vertebrae_output_labels.append(l + i * vertebrae_output_step) + + # Sort the combined disc+vert labels by their z-index + sorted_labels = vert_sorted_labels + disc_sorted_labels + sorted_z_indexes = vert_sorted_z_indexes + disc_sorted_z_indexes + is_vert = [True] * len(vert_sorted_labels) + [False] * len(disc_sorted_labels) + + # Sort the labels by their z-index (reversed to go from superior to inferior) + sorted_z_indexes, sorted_labels, is_vert = zip(*sorted(zip(sorted_z_indexes, sorted_labels, is_vert))[::-1]) + + # Make a dict mapping disc to vertebrae labels + disc_landmark_output_labels_2vert = dict(zip(disc_landmark_output_labels, vertebrae_landmark_output_labels)) + + # Make a dict mapping the sorted vertebrae labels to the output labels + map_vert_sorted_labels_2output = {} + + # We loop over all the landmarks starting from the most superior + for l_disc in [_ for _ in disc_sorted_labels if _ in map_disc_sorted_labels_landmark2output]: + # Get the index of the current landmark in the sorted disc labels + sorted_labels_l_disc_idx = list(zip(sorted_labels, is_vert)).index((l_disc, False)) + + # Continue if no more vertebrae labels + if True not in is_vert[sorted_labels_l_disc_idx:]: + continue + + # Get te vert label that is just next to the l_disc label + l = next(_l for _l, _ in list(zip(sorted_labels, is_vert))[sorted_labels_l_disc_idx:] if _) + + # Get the index of the current landmark in the sorted vertebrae labels + start_l = vert_sorted_labels.index(l) + + # Get the output disc label for the l_disc + l_disc_output = map_disc_sorted_labels_2output[l_disc] + + # Get the index of the current vert landmark in the list of all possible vertebrae output labels + start_o = all_possible_vertebrae_output_labels.index(disc_landmark_output_labels_2vert[l_disc_output]) + + # If this is the most superior landmark, we have to adjust the start indexes to start from the most superior label in the image + if len(map_vert_sorted_labels_2output) == 0: + start_l, start_o = max(0, start_l - start_o), max(0, start_o - start_l) + + # Map the sorted vert labels to the output labels + # This will ovveride the mapping from the previous landmarks for all labels inferior to the current landmark + for l, o in zip(vert_sorted_labels[start_l:], all_possible_vertebrae_output_labels[start_o:]): + map_vert_sorted_labels_2output[l] = o + + # Label the vertebrae with the output labels superio-inferior + for l, o in map_vert_sorted_labels_2output.items(): + output_seg_data[vert_mask_labeled == l] = o # Use the map to map labels from the iteative algorithm output, to the final output # This is useful to map the vertebrae label from the iteative algorithm output to the special sacrum label @@ -559,10 +681,10 @@ def _get_si_sorted_components( return None, 0, [], [] if combine_labels: - _labels = [[_] for _ in labels] - else: # For discs, combine all labels before label continue voxels since the discs not touching each other _labels = [labels] + else: + _labels = [[_] for _ in labels] # Init labeled segmentation mask_labeled, num_labels = np.zeros_like(seg_data, dtype=np.uint32), 0 @@ -604,127 +726,166 @@ def _get_si_sorted_components( return mask_labeled, num_labels, list(sorted_labels), list(sorted_z_indexes) -def _merge_vertebrae_labels( +def _merge_vertebrae_with_same_label( seg, labels, mask_labeled, num_labels, sorted_labels, sorted_z_indexes, - disc_sorted_z_indexes, - extra_labels, ): ''' - Combine sequential vertebrae labels based on some conditions. + Combine sequential vertebrae labels if they have the same value in the original segmentation. + This is useful when part part of the vertebrae is not connected to the main part but have the same odd/even value. ''' - if num_labels == 0: + if num_labels == 0 or len(labels) <= 1: return mask_labeled, num_labels, sorted_labels, sorted_z_indexes seg_data = np.asanyarray(seg.dataobj).round().astype(np.uint8) - # Combine sequential vertebrae labels if they have the same value in the original segmentation - # This is useful when part part of the vertebrae is not connected to the main part but have the same odd/even value - if len(labels) > 1: - new_sorted_labels = [] - # Store the previous label and the original label of the previous label - prev_l, prev_orig_label = 0, 0 + new_sorted_labels = [] - # Loop over the sorted labels - for l in sorted_labels: - # Get the original label of the current label - curr_orig_label = seg_data[mask_labeled == l].flat[0] + # Store the previous label and the original label of the previous label + prev_l, prev_orig_label = 0, 0 - # Combine the current label with the previous label if they have the same original label - if curr_orig_label == prev_orig_label: - # Combine the current label with the previous label - mask_labeled[mask_labeled == l] = prev_l - num_labels -= 1 + # Loop over the sorted labels + for l in sorted_labels: + # Get the original label of the current label + curr_orig_label = seg_data[mask_labeled == l].flat[0] - else: - # Add the current label to the new sorted labels - new_sorted_labels.append(l) - prev_l, prev_orig_label = l, curr_orig_label + # Combine the current label with the previous label if they have the same original label + if curr_orig_label == prev_orig_label: + # Combine the current label with the previous label + mask_labeled[mask_labeled == l] = prev_l + num_labels -= 1 - # Get the z index of the center of mass for each label - canonical_mask_labeled = np.asanyarray(nib.as_closest_canonical(nib.Nifti1Image(mask_labeled, seg.affine, seg.header)).dataobj).round().astype(mask_labeled.dtype) - mask_labeled_z_indexes = [_[-1] for _ in ndi.center_of_mass(canonical_mask_labeled != 0, canonical_mask_labeled, new_sorted_labels)] + else: + # Add the current label to the new sorted labels + new_sorted_labels.append(l) + prev_l, prev_orig_label = l, curr_orig_label - # Sort the labels by their z-index (reversed to go from superior to inferior) - sorted_z_indexes, sorted_labels = zip(*sorted(zip(mask_labeled_z_indexes, new_sorted_labels))[::-1]) + sorted_labels = new_sorted_labels - # Reduce size of mask_labeled - if mask_labeled.max() < np.iinfo(np.uint8).max: - mask_labeled = mask_labeled.astype(np.uint8) - elif mask_labeled.max() < np.iinfo(np.uint16).max: - mask_labeled = mask_labeled.astype(np.uint16) + # Reduce size of mask_labeled + if mask_labeled.max() < np.iinfo(np.uint8).max: + mask_labeled = mask_labeled.astype(np.uint8) + elif mask_labeled.max() < np.iinfo(np.uint16).max: + mask_labeled = mask_labeled.astype(np.uint16) - # Combine sequential vertebrae labels if there is no disc between them - if len(disc_sorted_z_indexes) > 0: - new_sorted_labels = [] + # Get the z index of the center of mass for each label + canonical_mask_labeled = np.asanyarray(nib.as_closest_canonical(nib.Nifti1Image(mask_labeled, seg.affine, seg.header)).dataobj).round().astype(mask_labeled.dtype) + mask_labeled_z_indexes = [_[-1] for _ in ndi.center_of_mass(canonical_mask_labeled != 0, canonical_mask_labeled, sorted_labels)] - # Store the previous label and the z index of the previous label - prev_l, prev_z = 0, 0 + # Sort the labels by their z-index (reversed to go from superior to inferior) + sorted_z_indexes, sorted_labels = zip(*sorted(zip(mask_labeled_z_indexes, sorted_labels))[::-1]) - for l, z in zip(sorted_labels, sorted_z_indexes): - # Do not combine first and last vertebrae since it can be C1 or only contain the spinous process - if l not in sorted_labels[:2] and l != sorted_labels[-1] and prev_l > 0 and not any(z < _ < prev_z for _ in disc_sorted_z_indexes): - # Combine the current label with the previous label - mask_labeled[mask_labeled == l] = prev_l - num_labels -= 1 + return mask_labeled, num_labels, list(sorted_labels), list(sorted_z_indexes) - else: - # Add the current label to the new sorted labels - new_sorted_labels.append(l) - prev_l, prev_z = l, z +def _merge_vertebrae_labels_with_no_disc_between( + seg, + mask_labeled, + num_labels, + sorted_labels, + sorted_z_indexes, + disc_sorted_z_indexes, + ): + ''' + Combine sequential vertebrae labels if there is no disc between them. + ''' + if num_labels == 0 or len(disc_sorted_z_indexes) == 0: + return mask_labeled, num_labels, sorted_labels, sorted_z_indexes - sorted_labels = new_sorted_labels + new_sorted_labels = [] - # Reduce size of mask_labeled - if mask_labeled.max() < np.iinfo(np.uint8).max: - mask_labeled = mask_labeled.astype(np.uint8) - elif mask_labeled.max() < np.iinfo(np.uint16).max: - mask_labeled = mask_labeled.astype(np.uint16) + # Store the previous label and the z index of the previous label + prev_l, prev_z = 0, 0 - # Combine extra labels with adjacent vertebrae labels - if len(extra_labels) > 0: - mask_extra = np.isin(seg_data, extra_labels) + for l, z in zip(sorted_labels, sorted_z_indexes): + # Do not combine first and last vertebrae since it can be C1 or only contain the spinous process + if l not in sorted_labels[:2] and l != sorted_labels[-1] and prev_l > 0 and not any(z < _ < prev_z for _ in disc_sorted_z_indexes): + # Combine the current label with the previous label + mask_labeled[mask_labeled == l] = prev_l + num_labels -= 1 + + else: + # Add the current label to the new sorted labels + new_sorted_labels.append(l) + prev_l, prev_z = l, z - # Loop over vertebral labels (from inferior because the transverse process make it steal from above) - for i in range(num_labels - 1, -1, -1): - # Mkae mask for the current vertebrae with filling the holes and dilating it - mask = _fill(mask_labeled == sorted_labels[i]) - mask = ndi.binary_dilation(mask, ndi.iterate_structure(ndi.generate_binary_structure(3, 1), 1)) + sorted_labels = new_sorted_labels - # Add the intersection of the mask with the extra labels to the current verebrae - mask_labeled[mask_extra * mask] = sorted_labels[i] + # Reduce size of mask_labeled + if mask_labeled.max() < np.iinfo(np.uint8).max: + mask_labeled = mask_labeled.astype(np.uint8) + elif mask_labeled.max() < np.iinfo(np.uint16).max: + mask_labeled = mask_labeled.astype(np.uint16) # Get the z index of the center of mass for each label canonical_mask_labeled = np.asanyarray(nib.as_closest_canonical(nib.Nifti1Image(mask_labeled, seg.affine, seg.header)).dataobj).round().astype(mask_labeled.dtype) mask_labeled_z_indexes = [_[-1] for _ in ndi.center_of_mass(canonical_mask_labeled != 0, canonical_mask_labeled, sorted_labels)] # Sort the labels by their z-index (reversed to go from superior to inferior) - sorted_z_indexes, sorted_labels = zip(*sorted(zip(mask_labeled_z_indexes, new_sorted_labels))[::-1]) + sorted_z_indexes, sorted_labels = zip(*sorted(zip(mask_labeled_z_indexes, sorted_labels))[::-1]) return mask_labeled, num_labels, list(sorted_labels), list(sorted_z_indexes) -def _get_superior_output_label( +def _merge_extra_labels_with_adjacent_vertebrae( + seg, + mask_labeled, + num_labels, + sorted_labels, + sorted_z_indexes, + extra_labels, + ): + ''' + Combine extra labels with adjacent vertebrae labels. + This is useful for combining remaining of general vertebrae labels that introduce for region based training but not used in the final segmentation. + ''' + if num_labels == 0 or len(extra_labels) == 0: + return mask_labeled, num_labels, sorted_labels, sorted_z_indexes + + seg_data = np.asanyarray(seg.dataobj).round().astype(np.uint8) + + mask_extra = np.isin(seg_data, extra_labels) + + # Loop over vertebral labels (from inferior because the transverse process make it steal from above) + for i in range(num_labels - 1, -1, -1): + # Mkae mask for the current vertebrae with filling the holes and dilating it + mask = _fill(mask_labeled == sorted_labels[i]) + mask = ndi.binary_dilation(mask, ndi.iterate_structure(ndi.generate_binary_structure(3, 1), 1)) + + # Add the intersection of the mask with the extra labels to the current verebrae + mask_labeled[mask_extra * mask] = sorted_labels[i] + + # Get the z index of the center of mass for each label + canonical_mask_labeled = np.asanyarray(nib.as_closest_canonical(nib.Nifti1Image(mask_labeled, seg.affine, seg.header)).dataobj).round().astype(mask_labeled.dtype) + mask_labeled_z_indexes = [_[-1] for _ in ndi.center_of_mass(canonical_mask_labeled != 0, canonical_mask_labeled, sorted_labels)] + + # Sort the labels by their z-index (reversed to go from superior to inferior) + sorted_z_indexes, sorted_labels = zip(*sorted(zip(mask_labeled_z_indexes, sorted_labels))[::-1]) + + return mask_labeled, num_labels, list(sorted_labels), list(sorted_z_indexes) + +def _get_landmark_output_labels( seg, loc, mask_labeled, sorted_labels, - init, - step, + selected_landmarks, + landmark_labels, + landmark_output_labels, loc_labels, - default_superior, - map_output_dict, + default_superior_output, ): ''' - Get the first label for the iterative labeling algorithm. + Get dict mapping labels from sorted_labels to the output labels based on the landmarks in the segmentation or localizer. ''' seg_data = np.asanyarray(seg.dataobj).round().astype(np.uint8) loc_data = loc and np.asanyarray(loc.dataobj).round().astype(np.uint8) + map_landmark_labels = dict(zip(landmark_labels, landmark_output_labels)) + # If localizer is provided, transform it to the segmentation space if loc_data is not None: loc_data = tio.Resample( @@ -733,40 +894,54 @@ def _get_superior_output_label( tio.LabelMap(tensor=loc_data[None, ...], affine=loc.affine) ).data.numpy()[0, ...].astype(np.uint8) - # Find the most superior label in the segmentation - superior_output_label = 0 - for k, v in init.items(): - if k in seg_data: - superior_output_label = v - step * sorted_labels.index(np.argmax(np.bincount(mask_labeled[seg_data == k].flat))) + # Init dict to store the output labels for the landmarks + map_landmark_outputs = {} + + # First we try to look for the landmarks in the segmentation + for l in selected_landmarks: + ############################################################################################################ + # TODO Remove this reake when we trust all the landmarks to get all landmarks instead of the first 2 + if len(map_landmark_outputs) > 0 and selected_landmarks.index(l) > 1: break + ############################################################################################################ + if l in map_landmark_labels and l in seg_data: + map_landmark_outputs[np.argmax(np.bincount(mask_labeled[seg_data == l].flat))] = map_landmark_labels[l] # If no init label found, set it from the localizer - if superior_output_label == 0 and loc_data is not None: + if len(map_landmark_outputs) == 0 and loc_data is not None: # Make mask for the intersection of the localizer labels and the labels in the segmentation mask = np.isin(loc_data, loc_labels) * np.isin(mask_labeled, sorted_labels) - - # Get the first label from sorted_labels that is in the localizer specified labels mask_labeled_masked = mask * mask_labeled - first_sorted_labels_in_loc = next(np.array(sorted_labels)[np.isin(sorted_labels, mask_labeled_masked)].flat, 0) + loc_data_masked = mask * loc_data + + # First we try to look for the landmarks in the localizer + # TODO Edge case if map_output_dict used for discs, but it is not used in the current implementation + for output_label in np.array(landmark_output_labels)[np.isin(landmark_output_labels, loc_data_masked)].tolist(): + # Map the label with the most voxels in the localizer landmark to the output label + map_landmark_outputs[np.argmax(np.bincount(mask_labeled_masked[loc_data_masked == output_label].flat))] = output_label - if first_sorted_labels_in_loc > 0: - # Get the target label for first_sorted_labels_in_loc - the label from the localizer that has the most voxels in it - loc_data_masked = mask * loc_data - target = np.argmax(np.bincount(loc_data_masked[mask_labeled_masked == first_sorted_labels_in_loc].flat)) - # If target in map_output_dict reverse it from the reversed map - # TODO Edge case if multiple keys have the same value, not used in the current implementation - target = {v: k for k, v in map_output_dict.items()}.get(target, target) - superior_output_label = target - step * sorted_labels.index(first_sorted_labels_in_loc) + if len(map_landmark_outputs) == 0: + # Get the first label from sorted_labels that is in the localizer specified labels + first_sorted_labels_in_loc = next(np.array(sorted_labels)[np.isin(sorted_labels, mask_labeled_masked)].flat, 0) + + if first_sorted_labels_in_loc > 0: + # Get the output label for first_sorted_labels_in_loc, the label from the localizer that has the most voxels in it + map_landmark_outputs[first_sorted_labels_in_loc] = np.argmax(np.bincount(loc_data_masked[mask_labeled_masked == first_sorted_labels_in_loc].flat)) # If no init label found, set the default superior label - if superior_output_label == 0 and default_superior > 0: - superior_output_label = default_superior + if len(map_landmark_outputs) == 0 and default_superior_output > 0: + map_landmark_outputs[sorted_labels[0]] = default_superior_output # If no init label found, print error - if superior_output_label == 0: - raise ValueError(f"Some initiation label must be in the segmentation (init: {list(init.keys())})") - - return superior_output_label + if len(map_landmark_outputs) == 0: + if loc_data is not None: + raise ValueError( + f"At least one of the landmarks must be in the segmentation or localizer (landmarks: {selected_landmarks}. " + f"Check {loc_labels}), make sure the localizer is in the same space as the segmentation" + ) + raise ValueError(f"At least one of the landmarks must be in the segmentation or localizer (landmarks: {selected_landmarks})") + + return map_landmark_outputs def _fill(mask): ''' From 2446f36bf1583b6311e7a4940cc4f23816753159 Mon Sep 17 00:00:00 2001 From: Yehuda Warszawer <36595323+yw7@users.noreply.github.com> Date: Sat, 14 Sep 2024 03:29:57 +0300 Subject: [PATCH 05/26] Replace labelling schema; introduce 'extract_alternate' - Changed canal and cord labels from 200 and 201 to 1 and 2 across multiple files for consistency and simplicity. - Updated 'extract_levels' to use a new approach for specifying disc labels instead of a step-based method. - Introduced 'extract_alternate' function for mapping labels, enhancing the flexibility and maintainability of label processing. --- pyproject.toml | 1 + totalspineseg/__init__.py | 1 + totalspineseg/inference.py | 24 +-- totalspineseg/utils/extract_alternate.py | 258 +++++++++++++++++++++++ totalspineseg/utils/extract_levels.py | 47 ++--- 5 files changed, 285 insertions(+), 46 deletions(-) create mode 100644 totalspineseg/utils/extract_alternate.py diff --git a/pyproject.toml b/pyproject.toml index 996f13a..9a8daee 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -75,6 +75,7 @@ totalspineseg_average4d = "totalspineseg.utils.average4d:main" totalspineseg_crop_image2seg = "totalspineseg.utils.crop_image2seg:main" totalspineseg_extract_soft = "totalspineseg.utils.extract_soft:main" totalspineseg_extract_levels = "totalspineseg.utils.extract_levels:main" +totalspineseg_extract_alternate = "totalspineseg.utils.extract_alternate:main" totalspineseg_add_nnunet_trainer = "totalspineseg.utils.add_nnunet_trainer:main" [build-system] diff --git a/totalspineseg/__init__.py b/totalspineseg/__init__.py index 44cac5f..dc4417c 100644 --- a/totalspineseg/__init__.py +++ b/totalspineseg/__init__.py @@ -3,6 +3,7 @@ from .utils.cpdir import cpdir_mp from .utils.crop_image2seg import crop_image2seg, crop_image2seg_mp from .utils.extract_levels import extract_levels, extract_levels_mp +from .utils.extract_alternate import extract_alternate, extract_alternate_mp from .utils.extract_soft import extract_soft, extract_soft_mp from .utils.fill_canal import fill_canal, fill_canal_mp from .utils.iterative_label import iterative_label, iterative_label_mp diff --git a/totalspineseg/inference.py b/totalspineseg/inference.py index 655ded7..8e7d5ab 100644 --- a/totalspineseg/inference.py +++ b/totalspineseg/inference.py @@ -353,8 +353,8 @@ def main(): fill_canal_mp( output_path / 'step1_output', output_path / 'step1_output', - canal_label=201, - cord_label=200, + canal_label=2, + cord_label=1, largest_canal=True, largest_cord=True, override=True, @@ -389,7 +389,7 @@ def main(): output_path / 'step1_output', output_path / 'step1_cord', label=9, - seg_labels=[200], + seg_labels=[1], dilate=1, override=True, max_workers=max_workers, @@ -402,7 +402,7 @@ def main(): output_path / 'step1_output', output_path / 'step1_canal', label=7, - seg_labels=[200, 201], + seg_labels=[1, 2], dilate=1, override=True, max_workers=max_workers, @@ -419,9 +419,8 @@ def main(): extract_levels_mp( output_path / 'step1_output', output_path / 'step1_levels', - canal_labels=[200, 201], - c2c3_label=224, - step=-1, + canal_labels=[1, 2], + disc_labels=list(range(60, 65)) + list(range(70, 82)) + list(range(90, 95)) + [100], override=True, max_workers=max_workers, quiet=quiet, @@ -461,16 +460,11 @@ def main(): quiet=quiet, ) - # Load label mappings from JSON file - with open(resources_path / 'labels_maps' / 'nnunet_step2_input.json', 'r', encoding='utf-8') as map_file: - map_dict = json.load(map_file) - if not quiet: print('\n' 'Mapping the IVDs labels from the step1 model output to the odd IVDs:') # This will also delete labels without odd IVDs - map_labels_mp( + extract_alternate_mp( output_path / 'step2_input', output_path / 'step2_input', - map_dict=map_dict, seg_suffix='_0001', output_seg_suffix='_0001', override=True, @@ -583,8 +577,8 @@ def main(): fill_canal_mp( output_path / 'step2_output', output_path / 'step2_output', - canal_label=201, - cord_label=200, + canal_label=2, + cord_label=1, largest_canal=True, largest_cord=True, override=True, diff --git a/totalspineseg/utils/extract_alternate.py b/totalspineseg/utils/extract_alternate.py new file mode 100644 index 0000000..63c2568 --- /dev/null +++ b/totalspineseg/utils/extract_alternate.py @@ -0,0 +1,258 @@ +import sys, argparse, textwrap +from pathlib import Path +import numpy as np +import nibabel as nib +import multiprocessing as mp +from functools import partial +from tqdm.contrib.concurrent import process_map +import warnings + +warnings.filterwarnings("ignore") + +def main(): + + # Parse command line arguments + parser = argparse.ArgumentParser( + description=' '.join(f''' + Extract alternate labels from the segmentation. + '''.split()), + epilog=textwrap.dedent(''' + Examples: + extract_alternate -s labels -o levels --labels 60-100 -r + For BIDS: + extract_alternate -s derivatives/labels -o derivatives/labels --seg-suffix "_seg" --output-seg-suffix "_levels" -d "sub-" -u "anat" --labels 60-100 -r + '''), + formatter_class=argparse.RawTextHelpFormatter + ) + + parser.add_argument( + '--segs-dir', '-s', type=Path, required=True, + help='Folder containing input segmentations.' + ) + parser.add_argument( + '--output-segs-dir', '-o', type=Path, required=True, + help='Folder to save output segmentations.' + ) + parser.add_argument( + '--subject-dir', '-d', type=str, default=None, nargs='?', const='', + help=' '.join(f''' + Is every subject has its oen direcrory. + If this argument will be provided without value it will look for any directory in the segmentation directory. + If value also provided it will be used as a prefix to subject directory (for example "sub-"), defaults to False (no subjet directory). + '''.split()) + ) + parser.add_argument( + '--subject-subdir', '-u', type=str, default='', + help='Subfolder inside subject folder containing masks (for example "anat"), defaults to no subfolder.' + ) + parser.add_argument( + '--prefix', '-p', type=str, default='', + help='File prefix to work on.' + ) + parser.add_argument( + '--seg-suffix', type=str, default='', + help='Segmentation suffix, defaults to "".' + ) + parser.add_argument( + '--output-seg-suffix', type=str, default='', + help='Suffix for output segmentation, defaults to "".' + ) + parser.add_argument( + '--labels', type=lambda x:list(range(int(x.split('-')[0]), int(x.split('-')[-1]) + 1)), nargs='+', required=True, + help='The labels to extract alternate elements from.' + ) + parser.add_argument( + '--from-second', action="store_true", default=False, + help='Start from the second label, defaults to false (start from the first label).' + ) + parser.add_argument( + '--override', '-r', action="store_true", default=False, + help='Override existing output files, defaults to false (Do not override).' + ) + parser.add_argument( + '--max-workers', '-w', type=int, default=mp.cpu_count(), + help='Max worker to run in parallel proccess, defaults to multiprocessing.cpu_count().' + ) + parser.add_argument( + '--quiet', '-q', action="store_true", default=False, + help='Do not display inputs and progress bar, defaults to false (display).' + ) + + # Parse the command-line arguments + args = parser.parse_args() + + # Get arguments + segs_path = args.segs_dir + output_segs_path = args.output_segs_dir + subject_dir = args.subject_dir + subject_subdir = args.subject_subdir + prefix = args.prefix + seg_suffix = args.seg_suffix + output_seg_suffix = args.output_seg_suffix + labels = [_ for __ in args.labels for _ in (__ if isinstance(__, list) else [__])] + from_second = args.from_second + override = args.override + max_workers = args.max_workers + quiet = args.quiet + + # Print the argument values if not quiet + if not quiet: + print(textwrap.dedent(f''' + Running {Path(__file__).stem} with the following params: + segs_dir = "{segs_path}" + output_segs_dir = "{output_segs_path}" + subject_dir = "{subject_dir}" + subject_subdir = "{subject_subdir}" + prefix = "{prefix}" + seg_suffix = "{seg_suffix}" + output_seg_suffix = "{output_seg_suffix}" + labels = {labels} + from_second = {from_second} + override = {override} + max_workers = {max_workers} + quiet = {quiet} + ''')) + + extract_alternate_mp( + segs_path=segs_path, + output_segs_path=output_segs_path, + subject_dir=subject_dir, + subject_subdir=subject_subdir, + prefix=prefix, + seg_suffix=seg_suffix, + output_seg_suffix=output_seg_suffix, + labels=labels, + from_second=from_second, + override=override, + max_workers=max_workers, + quiet=quiet, + ) + +def extract_alternate_mp( + segs_path, + output_segs_path, + subject_dir=None, + subject_subdir='', + prefix='', + seg_suffix='', + output_seg_suffix='', + labels=[], + from_second=False, + override=False, + max_workers=mp.cpu_count(), + quiet=False, + ): + ''' + Wrapper function to handle multiprocessing. + ''' + segs_path = Path(segs_path) + output_segs_path = Path(output_segs_path) + + glob_pattern = "" + if subject_dir is not None: + glob_pattern += f"{subject_dir}*/" + if len(subject_subdir) > 0: + glob_pattern += f"{subject_subdir}/" + glob_pattern += f'{prefix}*{seg_suffix}.nii.gz' + + # Process the NIfTI image and segmentation files + seg_path_list = list(segs_path.glob(glob_pattern)) + output_seg_path_list = [output_segs_path / _.relative_to(segs_path).parent / _.name.replace(f'{seg_suffix}.nii.gz', f'{output_seg_suffix}.nii.gz') for _ in seg_path_list] + + process_map( + partial( + _extract_alternate, + labels=labels, + from_second=from_second, + override=override, + ), + seg_path_list, + output_seg_path_list, + max_workers=max_workers, + chunksize=1, + disable=quiet, + ) + +def _extract_alternate( + seg_path, + output_seg_path, + labels=[], + from_second=False, + override=False, + ): + ''' + Wrapper function to handle IO. + ''' + seg_path = Path(seg_path) + output_seg_path = Path(output_seg_path) + + # If the output image already exists and we are not overriding it, return + if not override and output_seg_path.exists(): + return + + # Load segmentation + seg = nib.load(seg_path) + + try: + output_seg = extract_alternate( + seg, + labels=labels, + from_second=from_second, + ) + except ValueError as e: + output_seg_path.is_file() and output_seg_path.unlink() + print(f'Error: {seg_path}, {e}') + return + + # Ensure correct segmentation dtype, affine and header + output_seg = nib.Nifti1Image( + np.asanyarray(output_seg.dataobj).round().astype(np.uint8), + output_seg.affine, output_seg.header + ) + output_seg.set_data_dtype(np.uint8) + output_seg.set_qform(output_seg.affine) + output_seg.set_sform(output_seg.affine) + + # Make sure output directory exists and save the segmentation + output_seg_path.parent.mkdir(parents=True, exist_ok=True) + nib.save(output_seg, output_seg_path) + +def extract_alternate( + seg, + labels=[], + from_second=False, + ): + ''' + Extract vertebrae levels from Spinal Canal and Discs. + + The function extracts the vertebrae levels from the input segmentation by finding the closest voxel in the canal centerline to the middle of each disc. + The superior voxels in the canal centerline are set to 1 and the middle voxels between C2-C3 and the superior voxels are set to 2. + + Parameters + ---------- + seg : nibabel.Nifti1Image + The input segmentation. + labels : list of int + The labels to extract alternate elements from. + from_second : bool + Start from the second label, defaults to False. + + Returns + ------- + nibabel.Nifti1Image + The output segmentation with the vertebrae levels. + ''' + seg_data = np.asanyarray(seg.dataobj).round().astype(np.uint8) + + output_seg_data = np.zeros_like(seg_data) + + selected_labels = np.array(labels)[np.isin(labels, seg_data)][1 if from_second else 0::2] + + output_seg_data[np.isin(seg_data, selected_labels)] = 1 + + output_seg = nib.Nifti1Image(output_seg_data, seg.affine, seg.header) + + return output_seg + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/totalspineseg/utils/extract_levels.py b/totalspineseg/utils/extract_levels.py index 264563b..f3b3cd9 100644 --- a/totalspineseg/utils/extract_levels.py +++ b/totalspineseg/utils/extract_levels.py @@ -18,9 +18,9 @@ def main(): '''.split()), epilog=textwrap.dedent(''' Examples: - extract_levels -s labels -o levels --canal-labels 200 201 --c2c3-label 224 --step -1 -r + extract_levels -s labels -o levels --canal-labels 1 2 --disc-labels 60-64 70-81 90-94 100 -r For BIDS: - extract_levels -s derivatives/labels -o derivatives/labels --seg-suffix "_seg" --output-seg-suffix "_levels" -d "sub-" -u "anat" --canal-labels 200 201 --c2c3-label 224 --step -1 -r + extract_levels -s derivatives/labels -o derivatives/labels --seg-suffix "_seg" --output-seg-suffix "_levels" -d "sub-" -u "anat" --canal-labels 1 2 --disc-labels 60-64 70-81 90-94 100 -r '''), formatter_class=argparse.RawTextHelpFormatter ) @@ -58,16 +58,12 @@ def main(): help='Suffix for output segmentation, defaults to "".' ) parser.add_argument( - '--canal-labels', type=int, nargs='+', required=True, + '--canal-labels', type=lambda x:list(range(int(x.split('-')[0]), int(x.split('-')[-1]) + 1)), nargs='+', required=True, help='The canal labels.' ) parser.add_argument( - '--c2c3-label', type=int, required=True, - help='The label for C2-C3 disc.' - ) - parser.add_argument( - '--step', type=int, default=1, - help='The step to take between discs labels in the input, defaults to 1.' + '--disc-labels', type=lambda x:list(range(int(x.split('-')[0]), int(x.split('-')[-1]) + 1)), nargs='+', required=True, + help='The disc labels starting at C2C3 ordered from superior to inferior.' ) parser.add_argument( '--override', '-r', action="store_true", default=False, @@ -94,8 +90,7 @@ def main(): seg_suffix = args.seg_suffix output_seg_suffix = args.output_seg_suffix canal_labels = args.canal_labels - c2c3_label = args.c2c3_label - step = args.step + disc_labels = [_ for __ in args.disc_labels for _ in (__ if isinstance(__, list) else [__])] override = args.override max_workers = args.max_workers quiet = args.quiet @@ -112,8 +107,7 @@ def main(): seg_suffix = "{seg_suffix}" output_seg_suffix = "{output_seg_suffix}" canal_labels = {canal_labels} - c2c3_label = {c2c3_label} - step = {step} + disc_labels = {disc_labels} override = {override} max_workers = {max_workers} quiet = {quiet} @@ -128,8 +122,7 @@ def main(): seg_suffix=seg_suffix, output_seg_suffix=output_seg_suffix, canal_labels=canal_labels, - c2c3_label=c2c3_label, - step=step, + disc_labels=disc_labels, override=override, max_workers=max_workers, quiet=quiet, @@ -144,8 +137,7 @@ def extract_levels_mp( seg_suffix='', output_seg_suffix='', canal_labels=[], - c2c3_label=3, - step=1, + disc_labels=[], override=False, max_workers=mp.cpu_count(), quiet=False, @@ -171,8 +163,7 @@ def extract_levels_mp( partial( _extract_levels, canal_labels=canal_labels, - step=step, - c2c3_label=c2c3_label, + disc_labels=disc_labels, override=override, ), seg_path_list, @@ -186,8 +177,7 @@ def _extract_levels( seg_path, output_seg_path, canal_labels=[], - c2c3_label=3, - step=1, + disc_labels=[], override=False, ): ''' @@ -207,8 +197,7 @@ def _extract_levels( output_seg = extract_levels( seg, canal_labels=canal_labels, - c2c3_label=c2c3_label, - step=step, + disc_labels=disc_labels ) except ValueError as e: output_seg_path.is_file() and output_seg_path.unlink() @@ -231,8 +220,7 @@ def _extract_levels( def extract_levels( seg, canal_labels=[], - c2c3_label=3, - step=1, + disc_labels=[], ): ''' Extract vertebrae levels from Spinal Canal and Discs. @@ -246,10 +234,8 @@ def extract_levels( The input segmentation. canal_labels : list The canal labels. - c2c3_label : int - The label for C2-C3 disc. - step : int - The step to take between discs labels in the input. + disc_labels : list + The disc labels starting at C2C3 ordered from superior to inferior. Returns ------- @@ -283,8 +269,7 @@ def extract_levels( mask_canal_centerline_indices = np.array(np.nonzero(mask_canal_centerline)) # Get the labels of the discs and the output labels - disc_labels = list(range(c2c3_label, c2c3_label + step * 23, step)) - out_labels = list(range(3, 26)) + out_labels = list(range(3, 3 + len(disc_labels))) # Filter the discs that are in the segmentation in_seg = np.isin(disc_labels, seg_data) From a0753a4ca3307b1e27cc8b8c9f3306b2dc6c96ae Mon Sep 17 00:00:00 2001 From: Yehuda Warszawer <36595323+yw7@users.noreply.github.com> Date: Sat, 14 Sep 2024 03:50:00 +0300 Subject: [PATCH 06/26] Fix issue by adding range labels 60-101 for extract_alternate --- totalspineseg/inference.py | 1 + 1 file changed, 1 insertion(+) diff --git a/totalspineseg/inference.py b/totalspineseg/inference.py index 8e7d5ab..0974d8c 100644 --- a/totalspineseg/inference.py +++ b/totalspineseg/inference.py @@ -467,6 +467,7 @@ def main(): output_path / 'step2_input', seg_suffix='_0001', output_seg_suffix='_0001', + labels=list(range(60, 101)), override=True, max_workers=max_workers, quiet=quiet, From 610f2b8f89ba175259735873f35c390d6b85e48a Mon Sep 17 00:00:00 2001 From: Yehuda Warszawer <36595323+yw7@users.noreply.github.com> Date: Sat, 14 Sep 2024 13:15:04 +0300 Subject: [PATCH 07/26] Refactor label prioritization and mapping logic Replaced the 'from_second' argument with a more flexible and powerful 'prioritize_labels' argument, allowing specific labels to be prioritized in the output. Improved the logic for extracting alternates by explicitly ensuring prioritized labels' inclusion when necessary. Additionally, adjusted mappings in `iterative_label` function to correctly associate disc and vertebrae labels, enhancing output accuracy and clarity. Addresses issues with label handling and improves overall robustness. --- totalspineseg/utils/extract_alternate.py | 35 +++++++++++++++--------- totalspineseg/utils/iterative_label.py | 12 ++++---- 2 files changed, 28 insertions(+), 19 deletions(-) diff --git a/totalspineseg/utils/extract_alternate.py b/totalspineseg/utils/extract_alternate.py index 63c2568..048833d 100644 --- a/totalspineseg/utils/extract_alternate.py +++ b/totalspineseg/utils/extract_alternate.py @@ -62,8 +62,8 @@ def main(): help='The labels to extract alternate elements from.' ) parser.add_argument( - '--from-second', action="store_true", default=False, - help='Start from the second label, defaults to false (start from the first label).' + '--prioratize-labels', type=lambda x:list(range(int(x.split('-')[0]), int(x.split('-')[-1]) + 1)), nargs='+', default=[], + help='Specify labels that will be prioratized in the output, the first label in the list will be included in the output, defaults to [] (The first label in the list that is in the segmentation).' ) parser.add_argument( '--override', '-r', action="store_true", default=False, @@ -90,7 +90,7 @@ def main(): seg_suffix = args.seg_suffix output_seg_suffix = args.output_seg_suffix labels = [_ for __ in args.labels for _ in (__ if isinstance(__, list) else [__])] - from_second = args.from_second + prioratize_labels = [_ for __ in args.prioratize_labels for _ in (__ if isinstance(__, list) else [__])] override = args.override max_workers = args.max_workers quiet = args.quiet @@ -107,7 +107,7 @@ def main(): seg_suffix = "{seg_suffix}" output_seg_suffix = "{output_seg_suffix}" labels = {labels} - from_second = {from_second} + prioratize_labels = {prioratize_labels} override = {override} max_workers = {max_workers} quiet = {quiet} @@ -122,7 +122,7 @@ def main(): seg_suffix=seg_suffix, output_seg_suffix=output_seg_suffix, labels=labels, - from_second=from_second, + prioratize_labels=prioratize_labels, override=override, max_workers=max_workers, quiet=quiet, @@ -137,7 +137,7 @@ def extract_alternate_mp( seg_suffix='', output_seg_suffix='', labels=[], - from_second=False, + prioratize_labels=[], override=False, max_workers=mp.cpu_count(), quiet=False, @@ -163,7 +163,7 @@ def extract_alternate_mp( partial( _extract_alternate, labels=labels, - from_second=from_second, + prioratize_labels=prioratize_labels, override=override, ), seg_path_list, @@ -177,7 +177,7 @@ def _extract_alternate( seg_path, output_seg_path, labels=[], - from_second=False, + prioratize_labels=[], override=False, ): ''' @@ -197,7 +197,7 @@ def _extract_alternate( output_seg = extract_alternate( seg, labels=labels, - from_second=from_second, + prioratize_labels=prioratize_labels, ) except ValueError as e: output_seg_path.is_file() and output_seg_path.unlink() @@ -220,7 +220,7 @@ def _extract_alternate( def extract_alternate( seg, labels=[], - from_second=False, + prioratize_labels=[], ): ''' Extract vertebrae levels from Spinal Canal and Discs. @@ -234,8 +234,8 @@ def extract_alternate( The input segmentation. labels : list of int The labels to extract alternate elements from. - from_second : bool - Start from the second label, defaults to False. + prioratize_labels : list of int + Specify labels that will be prioratized in the output, the first label in the list will be included in the output. Returns ------- @@ -246,7 +246,16 @@ def extract_alternate( output_seg_data = np.zeros_like(seg_data) - selected_labels = np.array(labels)[np.isin(labels, seg_data)][1 if from_second else 0::2] + # Get the labels in the segmentation + labels = np.array(labels)[np.isin(labels, seg_data)] + + # Get the labels to prioratize in the output that are in the segmentation and in the labels + prioratize_labels = np.array(prioratize_labels)[np.isin(prioratize_labels, labels)] + + selected_labels = labels[::2] + + if len(prioratize_labels) > 0 and prioratize_labels[0] not in selected_labels: + selected_labels = labels[1::2] output_seg_data[np.isin(seg_data, selected_labels)] = 1 diff --git a/totalspineseg/utils/iterative_label.py b/totalspineseg/utils/iterative_label.py index de7e9b9..b7b2559 100644 --- a/totalspineseg/utils/iterative_label.py +++ b/totalspineseg/utils/iterative_label.py @@ -583,8 +583,8 @@ def iterative_label( # Build a list containing all possible labels for the vertebrae ordered superio-inferior # We start with the C1 and C2 labels as the first landmark is the C3 vertebrae all_possible_vertebrae_output_labels = [ - vertebrae_landmark_output_labels[0] - 2 * vertebrae_output_step, - vertebrae_landmark_output_labels[0] - vertebrae_output_step + vertebrae_landmark_output_labels[0] - 2 * vertebrae_output_step, # C1 + vertebrae_landmark_output_labels[0] - vertebrae_output_step # C2 ] for l, s in zip(vertebrae_landmark_output_labels, region_max_sizes): for i in range(s): @@ -599,7 +599,7 @@ def iterative_label( sorted_z_indexes, sorted_labels, is_vert = zip(*sorted(zip(sorted_z_indexes, sorted_labels, is_vert))[::-1]) # Make a dict mapping disc to vertebrae labels - disc_landmark_output_labels_2vert = dict(zip(disc_landmark_output_labels, vertebrae_landmark_output_labels)) + disc_output_labels_2vert = dict(zip(all_possible_disc_output_labels, all_possible_vertebrae_output_labels[2:])) # Make a dict mapping the sorted vertebrae labels to the output labels map_vert_sorted_labels_2output = {} @@ -613,8 +613,8 @@ def iterative_label( if True not in is_vert[sorted_labels_l_disc_idx:]: continue - # Get te vert label that is just next to the l_disc label - l = next(_l for _l, _ in list(zip(sorted_labels, is_vert))[sorted_labels_l_disc_idx:] if _) + # Get te vert label that is just next to the l_disc label inferiorly + l = next(_l for _l, _is_v in list(zip(sorted_labels, is_vert))[sorted_labels_l_disc_idx:] if _is_v) # Get the index of the current landmark in the sorted vertebrae labels start_l = vert_sorted_labels.index(l) @@ -623,7 +623,7 @@ def iterative_label( l_disc_output = map_disc_sorted_labels_2output[l_disc] # Get the index of the current vert landmark in the list of all possible vertebrae output labels - start_o = all_possible_vertebrae_output_labels.index(disc_landmark_output_labels_2vert[l_disc_output]) + start_o = all_possible_vertebrae_output_labels.index(disc_output_labels_2vert[l_disc_output]) # If this is the most superior landmark, we have to adjust the start indexes to start from the most superior label in the image if len(map_vert_sorted_labels_2output) == 0: From 55aa76fa66c8a439051f644f34624b72a6cd40f8 Mon Sep 17 00:00:00 2001 From: Yehuda Warszawer <36595323+yw7@users.noreply.github.com> Date: Sat, 14 Sep 2024 14:20:24 +0300 Subject: [PATCH 08/26] Improve disc label handling in extract_levels function Revise the disc label extraction logic to ensure correct mapping of disc labels found in segmentation data. This update prevents errors when disc labels are missing and adjusts output labels based on actual disc labels present. Enhances robustness of the extraction process. --- totalspineseg/utils/extract_levels.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/totalspineseg/utils/extract_levels.py b/totalspineseg/utils/extract_levels.py index f3b3cd9..235bd4c 100644 --- a/totalspineseg/utils/extract_levels.py +++ b/totalspineseg/utils/extract_levels.py @@ -268,17 +268,20 @@ def extract_levels( # Get the indices of the canal centerline mask_canal_centerline_indices = np.array(np.nonzero(mask_canal_centerline)) - # Get the labels of the discs and the output labels - out_labels = list(range(3, 3 + len(disc_labels))) - - # Filter the discs that are in the segmentation - in_seg = np.isin(disc_labels, seg_data) - map_labels = [(d, o) for d, o, i in zip(disc_labels, out_labels, in_seg) if i] + # Get the labels of the discs in the segmentation + disc_labels_in_seg = np.array(disc_labels)[np.isin(disc_labels, seg_data)] # If no disc labels found in the segmentation raise an error - if len(map_labels) == 0: + if len(disc_labels_in_seg) == 0: raise ValueError(f"No disc labels found in the segmentation.") + # Get the labels of the discs and the output labels + first_disk_idx = disc_labels.index(disc_labels_in_seg[0]) + out_labels = list(range(3 + first_disk_idx, 3 + first_disk_idx + len(disc_labels_in_seg))) + + # Filter the discs that are in the segmentation + map_labels = dict(zip(disc_labels_in_seg, out_labels)) + # Loop over the discs from C2-C3 to L5-S1 and find the closest voxel in the canal centerline for disc_label, out_label in map_labels: # Create a mask of the disc From 1966341330fcda8b75dbeff0dd66ca0e4bd85d93 Mon Sep 17 00:00:00 2001 From: Yehuda Warszawer <36595323+yw7@users.noreply.github.com> Date: Sat, 14 Sep 2024 14:25:36 +0300 Subject: [PATCH 09/26] Fix disc label mapping error in extract_levels Corrected the disc label mapping in `extract_levels` to ensure proper iteration over map_labels. This fixes an issue where the previous dictionary implementation caused unexpected behavior by looping over keys rather than items, ensuring accurate processing of segmentation labels. --- totalspineseg/utils/extract_levels.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/totalspineseg/utils/extract_levels.py b/totalspineseg/utils/extract_levels.py index 235bd4c..3262171 100644 --- a/totalspineseg/utils/extract_levels.py +++ b/totalspineseg/utils/extract_levels.py @@ -280,7 +280,7 @@ def extract_levels( out_labels = list(range(3 + first_disk_idx, 3 + first_disk_idx + len(disc_labels_in_seg))) # Filter the discs that are in the segmentation - map_labels = dict(zip(disc_labels_in_seg, out_labels)) + map_labels = zip(disc_labels_in_seg, out_labels) # Loop over the discs from C2-C3 to L5-S1 and find the closest voxel in the canal centerline for disc_label, out_label in map_labels: From 22ca4204ce30cef22d2762eb18ccbf6515cbe4d8 Mon Sep 17 00:00:00 2001 From: Yehuda Warszawer <36595323+yw7@users.noreply.github.com> Date: Sat, 14 Sep 2024 17:36:31 +0300 Subject: [PATCH 10/26] Add label text options for image preview Enhanced the `preview_jpg` utility to support placing label texts on segmented images. Added new CLI arguments to accept label mappings or JSON files for labels on the right and left sides of the image. Updated image processing to flip and rotate images appropriately and draw labels with outlines and specified colors. Improved overall clarity by adjusting label positions to avoid overlaps. --- totalspineseg/inference.py | 35 ++++++ totalspineseg/utils/preview_jpg.py | 191 +++++++++++++++++++++++++++-- 2 files changed, 215 insertions(+), 11 deletions(-) diff --git a/totalspineseg/inference.py b/totalspineseg/inference.py index 0974d8c..9c0624c 100644 --- a/totalspineseg/inference.py +++ b/totalspineseg/inference.py @@ -232,6 +232,19 @@ def main(): override=True, max_workers=max_workers, quiet=quiet, + label_texts_right={ + 1: 'SC', 2: 'Canal', + 10: 'C1', 11: 'C2', 12: 'C3', 13: 'C4', 14: 'C5', 15: 'C6', 16: 'C7', + 20: 'T1', 21: 'T2', 22: 'T3', 23: 'T4', 24: 'T5', 25: 'T6', 26: 'T7', + 27: 'T8', 28: 'T9', 29: 'T10', 30: 'T11', 31: 'T12', + 40: 'L1', 41: 'L2', 42: 'L3', 43: 'L4', 44: 'L5' + }, + label_texts_left={ + 50: 'Sacrum', 60: 'C2C3', 61: 'C3C4', 62: 'C4C5', 63: 'C5C6', 64: 'C6C7', 70: 'C7T1', + 71: 'T1T2', 72: 'T2T3', 73: 'T3T4', 74: 'T4T5', 75: 'T5T6', 76: 'T6T7', 77: 'T7T8', + 78: 'T8T9', 79: 'T9T10', 80: 'T10T11', 81: 'T11T12', 90: 'T12L1', + 91: 'L1L2', 92: 'L2L3', 93: 'L3L4', 94: 'L4L5', 100: 'L5S' + }, ) if not quiet: print('\n' 'Converting 4D images to 3D:') @@ -381,6 +394,15 @@ def main(): override=True, max_workers=max_workers, quiet=quiet, + label_texts_right={ + 1: 'SC', 2: 'Canal', + }, + label_texts_left={ + 50: 'Vertebrae', 60: 'C2C3', 61: 'C3C4', 62: 'C4C5', 63: 'C5C6', 64: 'C6C7', 70: 'C7T1', + 71: 'T1T2', 72: 'T2T3', 73: 'T3T4', 74: 'T4T5', 75: 'T5T6', 76: 'T6T7', 77: 'T7T8', + 78: 'T8T9', 79: 'T9T10', 80: 'T10T11', 81: 'T11T12', 90: 'T12L1', + 91: 'L1L2', 92: 'L2L3', 93: 'L3L4', 94: 'L4L5', 100: 'L5S' + }, ) if not quiet: print('\n' 'Extracting spinal cord soft segmentation from step 1 model output:') @@ -606,6 +628,19 @@ def main(): override=True, max_workers=max_workers, quiet=quiet, + label_texts_right={ + 1: 'SC', 2: 'Canal', + 10: 'C1', 11: 'C2', 12: 'C3', 13: 'C4', 14: 'C5', 15: 'C6', 16: 'C7', + 20: 'T1', 21: 'T2', 22: 'T3', 23: 'T4', 24: 'T5', 25: 'T6', 26: 'T7', + 27: 'T8', 28: 'T9', 29: 'T10', 30: 'T11', 31: 'T12', + 40: 'L1', 41: 'L2', 42: 'L3', 43: 'L4', 44: 'L5' + }, + label_texts_left={ + 50: 'Sacrum', 60: 'C2C3', 61: 'C3C4', 62: 'C4C5', 63: 'C5C6', 64: 'C6C7', 70: 'C7T1', + 71: 'T1T2', 72: 'T2T3', 73: 'T3T4', 74: 'T4T5', 75: 'T5T6', 76: 'T6T7', 77: 'T7T8', + 78: 'T8T9', 79: 'T9T10', 80: 'T10T11', 81: 'T11T12', 90: 'T12L1', + 91: 'L1L2', 92: 'L2L3', 93: 'L3L4', 94: 'L4L5', 100: 'L5S' + }, ) if __name__ == '__main__': diff --git a/totalspineseg/utils/preview_jpg.py b/totalspineseg/utils/preview_jpg.py index cf82199..8f66e98 100644 --- a/totalspineseg/utils/preview_jpg.py +++ b/totalspineseg/utils/preview_jpg.py @@ -1,5 +1,5 @@ -import sys, argparse, textwrap -from PIL import Image +import sys, argparse, textwrap, json +from PIL import Image, ImageDraw, ImageFont import multiprocessing as mp from functools import partial from pathlib import Path @@ -78,13 +78,31 @@ def main(): '--sliceloc', '-l', type=float, default=0.5, help='Slice location within the specified orientation (0-1). Default is 0.5 (middle slice).' ) + parser.add_argument( + '--label-text-right', '-ltr', type=str, nargs='+', default=[], + help=' '.join(''' + JSON file or mapping from label integers to text labels to be placed on the right side. + The format should be input_label:text_label without any spaces. + For example, you can use a JSON file like right_labels.json containing {"1": "SC", "2": "Canal"}, + or provide mappings directly like 1:SC 2:Canal + '''.split()), + ) + parser.add_argument( + '--label-text-left', '-ltl', type=str, nargs='+', default=[], + help=' '.join(''' + JSON file or mapping from label integers to text labels to be placed on the left side. + The format should be input_label:text_label without any spaces. + For example, you can use a JSON file like left_labels.json containing {"1": "SC", "2": "Canal"}, + or provide mappings directly like 1:SC 2:Canal + '''.split()), + ) parser.add_argument( '--override', '-r', action="store_true", default=False, help='Override existing output files, defaults to false (Do not override).' ) parser.add_argument( '--max-workers', '-w', type=int, default=mp.cpu_count(), - help='Max worker to run in parallel proccess, defaults to multiprocessing.cpu_count().' + help='Max worker to run in parallel process, defaults to multiprocessing.cpu_count().' ) parser.add_argument( '--quiet', '-q', action="store_true", default=False, @@ -106,10 +124,18 @@ def main(): seg_suffix = args.seg_suffix orient = args.orient sliceloc = args.sliceloc + label_text_right_list = args.label_text_right + label_text_left_list = args.label_text_left override = args.override max_workers = args.max_workers quiet = args.quiet + # Load label_texts_right into a dict + label_texts_right = load_label_texts(label_text_right_list, 'label-text-right') + + # Load label_texts_left into a dict + label_texts_left = load_label_texts(label_text_left_list, 'label-text-left') + # Print the argument values if not quiet if not quiet: print(textwrap.dedent(f''' @@ -125,6 +151,8 @@ def main(): seg_suffix = "{seg_suffix}" orient = "{orient}" sliceloc = {sliceloc} + label_texts_right = {label_texts_right} + label_texts_left = {label_texts_left} override = {override} max_workers = {max_workers} quiet = {quiet} @@ -142,11 +170,27 @@ def main(): seg_suffix=seg_suffix, orient=orient, sliceloc=sliceloc, + label_texts_right=label_texts_right, + label_texts_left=label_texts_left, override=override, max_workers=max_workers, quiet=quiet, ) +def load_label_texts(label_text_list, param_name): + if len(label_text_list) == 1 and label_text_list[0][-5:] == '.json': + # Load label mappings from JSON file + with open(label_text_list[0], 'r', encoding='utf-8') as map_file: + label_texts = json.load(map_file) + # Ensure keys are ints + label_texts = {int(k): v for k, v in label_texts.items()} + else: + try: + label_texts = {int(l_in): l_out for l_in, l_out in map(lambda x: x.split(':'), label_text_list)} + except: + raise ValueError(f"Input param --{param_name} is not in the right structure. Make sure it is in the right format, e.g., 1:SC 2:Canal") + return label_texts + def preview_jpg_mp( images_path, output_path, @@ -159,6 +203,8 @@ def preview_jpg_mp( seg_suffix='', orient='sag', sliceloc=0.5, + label_texts_right={}, + label_texts_left={}, override=False, max_workers=mp.cpu_count(), quiet=False, @@ -189,6 +235,8 @@ def preview_jpg_mp( _preview_jpg, orient=orient, sliceloc=sliceloc, + label_texts_right=label_texts_right, + label_texts_left=label_texts_left, override=override, ), image_path_list, @@ -205,6 +253,8 @@ def _preview_jpg( seg_path=None, orient='sag', sliceloc=0.5, + label_texts_right={}, + label_texts_left={}, override=False, ): ''' @@ -240,8 +290,9 @@ def _preview_jpg( # Repeat the grayscale slice 3 times to create a color image slice_img = np.repeat(slice_img[:, :, np.newaxis], 3, axis=2).astype(np.uint8) - # Create a blank color image with the same dimensions as the input image - output_data = np.zeros_like(slice_img, dtype=np.uint8) + # Flip and rotate the image slice + slice_img = np.flipud(slice_img) + slice_img = np.rot90(slice_img, k=1) if seg_path and seg_path.is_file(): try: @@ -257,6 +308,10 @@ def _preview_jpg( slice_seg = seg_data.take(slice_index, axis=axis) + # Flip and rotate the segmentation slice + slice_seg = np.flipud(slice_seg) + slice_seg = np.rot90(slice_seg, k=1) + # Generate consistent random colors for each label unique_labels = np.unique(slice_seg).astype(int) colors = {} @@ -264,21 +319,135 @@ def _preview_jpg( rs = np.random.RandomState(np.random.MT19937(np.random.SeedSequence(label * 10))) colors[label] = rs.randint(0, 255, 3) + # Create a blank color image with the same dimensions as the input image + output_data = np.zeros_like(slice_img, dtype=np.uint8) + # Apply the segmentation mask to the image and assign colors for label, color in colors.items(): if label != 0: # Ignore the background (label 0) mask = slice_seg == label output_data[mask] = color - output_data = np.where(output_data > 0, output_data, slice_img) - - # Rotate the image 90 degrees counter-clockwise and flip it vertically - output_data = np.flipud(output_data) - output_data = np.rot90(output_data, k=1) + output_data = np.where(output_data > 0, output_data, slice_img) + else: + output_data = slice_img + unique_labels = [] + colors = {} - # Create an Image object from the output Image object as a JPG file + # Create an Image object from the output data output_image = Image.fromarray(output_data, mode="RGB") + # Draw text labels if label_texts are provided + if (label_texts_right or label_texts_left) and seg_path and seg_path.is_file(): + draw = ImageDraw.Draw(output_image) + # Use a bold TrueType font for better sharpness and boldness + try: + font = ImageFont.truetype("DejaVuSans-Bold.ttf", size=15) + except IOError: + font = ImageFont.load_default() + width, height = output_image.size + used_positions = [] + for label in unique_labels: + if label != 0: + text = None + side = None + if label in label_texts_right: + text = label_texts_right[label] + side = 'right' + elif label in label_texts_left: + text = label_texts_left[label] + side = 'left' + + if text and side: + mask = slice_seg == label + positions = np.argwhere(mask) + if positions.size > 0: + # Get the bounding box of the label + ys, xs = positions[:, 0], positions[:, 1] + x_min, x_max = xs.min(), xs.max() + y_min, y_max = ys.min(), ys.max() + # Start from just outside the label + if side == 'right': + x_new = x_max + 1 + else: + x_new = x_min - 1 + y_new = int((y_min + y_max) / 2) + + # Ensure starting positions are within image bounds + x_new = min(max(0, x_new), width - 1) + y_new = min(max(0, y_new), height - 1) + + # Search for a position outside the segmentation + found_position = False + if side == 'right': + for dx in range(0, width - x_new): + if slice_seg[y_new, min(x_new + dx, width - 1)] == 0: + x_new = x_new + dx + found_position = True + break + else: + for dx in range(0, x_new + 1): + if slice_seg[y_new, max(x_new - dx, 0)] == 0: + x_new = x_new - dx + found_position = True + break + if not found_position: + # Try moving down + for dy in range(1, height - y_new): + if slice_seg[min(y_new + dy, height - 1), x_new] == 0: + y_new = y_new + dy + found_position = True + break + if not found_position: + # Try moving up + for dy in range(1, y_new + 1): + if slice_seg[max(y_new - dy, 0), x_new] == 0: + y_new = y_new - dy + found_position = True + break + + if found_position: + text_color = tuple(colors[label].tolist()) + + # Avoid overlapping labels + if (x_new, y_new) not in used_positions: + # Get text size + try: + # For Pillow >= 8.0.0 + bbox = font.getbbox(text) + text_width = bbox[2] - bbox[0] + text_height = bbox[3] - bbox[1] + except AttributeError: + try: + # For older versions + text_width, text_height = font.getsize(text) + except AttributeError: + # As a last resort, approximate text size + text_width, text_height = draw.textsize(text, font=font) + + # Adjust x_new for left side labels + if side == 'left': + x_new = x_new - text_width + + # Ensure text is within image bounds + x_new = min(max(0, x_new), width - text_width) + y_new = min(max(0, y_new), height - text_height) + + # Draw outline by drawing text multiple times around the perimeter + outline_color = (255, 255, 255) # White outline + outline_offsets = [(-1, -1), (-1, 0), (-1, 1), (0, -1), (0, 1), (1, -1), (1, 0), (1, 1)] + for dx, dy in outline_offsets: + draw.text((x_new + dx, y_new + dy), text, font=font, fill=outline_color) + + # Draw the text over the outline multiple times to make it thicker + for _ in range(3): + draw.text((x_new, y_new), text, fill=text_color, font=font) + + used_positions.append((x_new, y_new)) + else: + # No suitable position found, skip drawing the label + pass + # Make sure output directory exists and save the image output_path.parent.mkdir(parents=True, exist_ok=True) output_image.save(output_path) From 3d8a559cda348e8d3a28df39c3884c3404ea73d6 Mon Sep 17 00:00:00 2001 From: Yehuda Warszawer <36595323+yw7@users.noreply.github.com> Date: Sat, 14 Sep 2024 17:47:12 +0300 Subject: [PATCH 11/26] Add preview image generation for segmentation steps 1 and 2 Enhanced the inference script to generate preview images for both step 1 and step 2 outputs. This addition allows for better visualization and verification of segmentation results. The previews include labeled images to help in quick assessment of different stages of the processing pipeline. --- totalspineseg/inference.py | 45 ++++++++++++++++++++++++++++++-------- 1 file changed, 36 insertions(+), 9 deletions(-) diff --git a/totalspineseg/inference.py b/totalspineseg/inference.py index 9c0624c..a5e2210 100644 --- a/totalspineseg/inference.py +++ b/totalspineseg/inference.py @@ -232,6 +232,15 @@ def main(): override=True, max_workers=max_workers, quiet=quiet, + ) + preview_jpg_mp( + output_path / 'input', + output_path / 'preview', + segs_path=output_path / 'localizers', + output_suffix='_loc_tags', + override=True, + max_workers=max_workers, + quiet=quiet, label_texts_right={ 1: 'SC', 2: 'Canal', 10: 'C1', 11: 'C2', 12: 'C3', 13: 'C4', 14: 'C5', 15: 'C6', 16: 'C7', @@ -394,15 +403,24 @@ def main(): override=True, max_workers=max_workers, quiet=quiet, - label_texts_right={ - 1: 'SC', 2: 'Canal', - }, - label_texts_left={ - 50: 'Vertebrae', 60: 'C2C3', 61: 'C3C4', 62: 'C4C5', 63: 'C5C6', 64: 'C6C7', 70: 'C7T1', - 71: 'T1T2', 72: 'T2T3', 73: 'T3T4', 74: 'T4T5', 75: 'T5T6', 76: 'T6T7', 77: 'T7T8', - 78: 'T8T9', 79: 'T9T10', 80: 'T10T11', 81: 'T11T12', 90: 'T12L1', - 91: 'L1L2', 92: 'L2L3', 93: 'L3L4', 94: 'L4L5', 100: 'L5S' - }, + ) + preview_jpg_mp( + output_path / 'input', + output_path / 'preview', + segs_path=output_path / 'step1_output', + output_suffix='_step1_output_tags', + override=True, + max_workers=max_workers, + quiet=quiet, + label_texts_right={ + 1: 'SC', 2: 'Canal', + }, + label_texts_left={ + 50: 'Vertebrae', 60: 'C2C3', 61: 'C3C4', 62: 'C4C5', 63: 'C5C6', 64: 'C6C7', 70: 'C7T1', + 71: 'T1T2', 72: 'T2T3', 73: 'T3T4', 74: 'T4T5', 75: 'T5T6', 76: 'T6T7', 77: 'T7T8', + 78: 'T8T9', 79: 'T9T10', 80: 'T10T11', 81: 'T11T12', 90: 'T12L1', + 91: 'L1L2', 92: 'L2L3', 93: 'L3L4', 94: 'L4L5', 100: 'L5S' + }, ) if not quiet: print('\n' 'Extracting spinal cord soft segmentation from step 1 model output:') @@ -628,6 +646,15 @@ def main(): override=True, max_workers=max_workers, quiet=quiet, + ) + preview_jpg_mp( + output_path / 'input', + output_path / 'preview', + segs_path=output_path / 'step2_output', + output_suffix='_step2_output_tags', + override=True, + max_workers=max_workers, + quiet=quiet, label_texts_right={ 1: 'SC', 2: 'Canal', 10: 'C1', 11: 'C2', 12: 'C3', 13: 'C4', 14: 'C5', 15: 'C6', 16: 'C7', From 64700dddcadef7e5469edd08f0026763dce504f9 Mon Sep 17 00:00:00 2001 From: Yehuda Warszawer <36595323+yw7@users.noreply.github.com> Date: Sat, 14 Sep 2024 17:56:52 +0300 Subject: [PATCH 12/26] Remove redundant spinal cord and canal labels Simplified label configuration by removing redundant 'SC' (Spinal Cord) and 'Canal' labels from `label_texts_right`. This reduces ambiguity and maintains a consistent labeling strategy within the segmentation process. Resolves potential confusion in the labeling scheme, ensuring clarity and consistency for future maintenance and understanding. --- totalspineseg/inference.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/totalspineseg/inference.py b/totalspineseg/inference.py index a5e2210..747f531 100644 --- a/totalspineseg/inference.py +++ b/totalspineseg/inference.py @@ -242,7 +242,6 @@ def main(): max_workers=max_workers, quiet=quiet, label_texts_right={ - 1: 'SC', 2: 'Canal', 10: 'C1', 11: 'C2', 12: 'C3', 13: 'C4', 14: 'C5', 15: 'C6', 16: 'C7', 20: 'T1', 21: 'T2', 22: 'T3', 23: 'T4', 24: 'T5', 25: 'T6', 26: 'T7', 27: 'T8', 28: 'T9', 29: 'T10', 30: 'T11', 31: 'T12', @@ -412,9 +411,6 @@ def main(): override=True, max_workers=max_workers, quiet=quiet, - label_texts_right={ - 1: 'SC', 2: 'Canal', - }, label_texts_left={ 50: 'Vertebrae', 60: 'C2C3', 61: 'C3C4', 62: 'C4C5', 63: 'C5C6', 64: 'C6C7', 70: 'C7T1', 71: 'T1T2', 72: 'T2T3', 73: 'T3T4', 74: 'T4T5', 75: 'T5T6', 76: 'T6T7', 77: 'T7T8', @@ -656,7 +652,6 @@ def main(): max_workers=max_workers, quiet=quiet, label_texts_right={ - 1: 'SC', 2: 'Canal', 10: 'C1', 11: 'C2', 12: 'C3', 13: 'C4', 14: 'C5', 15: 'C6', 16: 'C7', 20: 'T1', 21: 'T2', 22: 'T3', 23: 'T4', 24: 'T5', 25: 'T6', 26: 'T7', 27: 'T8', 28: 'T9', 29: 'T10', 30: 'T11', 31: 'T12', From 23ff7d51cee9f4812bed6907ae65d69806c85f70 Mon Sep 17 00:00:00 2001 From: Yehuda Warszawer <36595323+yw7@users.noreply.github.com> Date: Sat, 14 Sep 2024 17:59:08 +0300 Subject: [PATCH 13/26] Remove unused 'Vertebrae' label from inference script The 'Vertebrae' label text was unnecessary and has been removed to streamline the label definition. This change helps avoid potential confusion and ensures only relevant label texts are included. No functional impact expected. --- totalspineseg/inference.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/totalspineseg/inference.py b/totalspineseg/inference.py index 747f531..6fa8400 100644 --- a/totalspineseg/inference.py +++ b/totalspineseg/inference.py @@ -412,7 +412,7 @@ def main(): max_workers=max_workers, quiet=quiet, label_texts_left={ - 50: 'Vertebrae', 60: 'C2C3', 61: 'C3C4', 62: 'C4C5', 63: 'C5C6', 64: 'C6C7', 70: 'C7T1', + 60: 'C2C3', 61: 'C3C4', 62: 'C4C5', 63: 'C5C6', 64: 'C6C7', 70: 'C7T1', 71: 'T1T2', 72: 'T2T3', 73: 'T3T4', 74: 'T4T5', 75: 'T5T6', 76: 'T6T7', 77: 'T7T8', 78: 'T8T9', 79: 'T9T10', 80: 'T10T11', 81: 'T11T12', 90: 'T12L1', 91: 'L1L2', 92: 'L2L3', 93: 'L3L4', 94: 'L4L5', 100: 'L5S' From 3062368c0dd4508e6b361fe574e9279fc32828e0 Mon Sep 17 00:00:00 2001 From: Yehuda Warszawer <36595323+yw7@users.noreply.github.com> Date: Sat, 14 Sep 2024 19:04:50 +0300 Subject: [PATCH 14/26] Improve label mapping for spine segmentation Add default region sizes to handle disc and vertebrae labels more accurately. Adjust start indexes based on the default region sizes to ensure consistent mapping starting from the most superior label. This enhances label consistency and correctness in spine segmentation. --- totalspineseg/utils/iterative_label.py | 64 ++++++++++++++++++++++---- 1 file changed, 54 insertions(+), 10 deletions(-) diff --git a/totalspineseg/utils/iterative_label.py b/totalspineseg/utils/iterative_label.py index b7b2559..f6a0398 100644 --- a/totalspineseg/utils/iterative_label.py +++ b/totalspineseg/utils/iterative_label.py @@ -490,6 +490,9 @@ def iterative_label( output_seg_data = np.zeros_like(seg_data) + # Region default sizes for the discs and vertebrae (Cervical, Thoracic, Lumbar, Sacrum) + region_default_sizes=[5, 12, 5, 1], + # Get sorted connected components superio-inferior (SI) for the disc labels disc_mask_labeled, disc_num_labels, disc_sorted_labels, disc_sorted_z_indexes = _get_si_sorted_components( seg, @@ -555,21 +558,39 @@ def iterative_label( for i in range(s): all_possible_disc_output_labels.append(l + i * disc_output_step) + # Make a list containing all possible labels for the vertebrae ordered superio-inferior with the default region sizes + all_default_disc_output_labels = [] + for l, s in zip(disc_landmark_output_labels, region_default_sizes): + for i in range(s): + all_default_disc_output_labels.append(l + i * disc_output_step) + # Make a dict mapping the sorted disc labels to the output labels map_disc_sorted_labels_2output = {} # We loop over all the landmarks starting from the most superior for l in [_ for _ in disc_sorted_labels if _ in map_disc_sorted_labels_landmark2output]: + + # If this is the most superior landmark, we have to adjust the start indexes to start from the most superior label in the image + if len(map_disc_sorted_labels_2output) == 0: + # Get the index of the current landmark in the sorted disc labels + start_l = disc_sorted_labels.index(l) + + # Get the index of the current landmark in the list of all default disc output labels + start_o_def = all_default_disc_output_labels.index(map_disc_sorted_labels_landmark2output[l]) + + # Adjust the start indexes + start_l, start_o_def = max(0, start_l - start_o_def), max(0, start_o_def - start_l) + + # Map the sorted disc labels to the output labels + for l, o in zip(disc_sorted_labels[start_l:], all_default_disc_output_labels[start_o:]): + map_disc_sorted_labels_2output[l] = o + # Get the index of the current landmark in the sorted disc labels start_l = disc_sorted_labels.index(l) # Get the index of the current landmark in the list of all possible disc output labels start_o = all_possible_disc_output_labels.index(map_disc_sorted_labels_landmark2output[l]) - # If this is the most superior landmark, we have to adjust the start indexes to start from the most superior label in the image - if len(map_disc_sorted_labels_2output) == 0: - start_l, start_o = max(0, start_l - start_o), max(0, start_o - start_l) - # Map the sorted disc labels to the output labels # This will ovveride the mapping from the previous landmarks for all labels inferior to the current landmark for l, o in zip(disc_sorted_labels[start_l:], all_possible_disc_output_labels[start_o:]): @@ -590,6 +611,15 @@ def iterative_label( for i in range(s): all_possible_vertebrae_output_labels.append(l + i * vertebrae_output_step) + # Make a list containing all possible labels for the vertebrae ordered superio-inferior with the default region sizes + all_default_vertebrae_output_labels = [ + vertebrae_landmark_output_labels[0] - 2 * vertebrae_output_step, # C1 + vertebrae_landmark_output_labels[0] - vertebrae_output_step # C2 + ] + for l, s in zip(vertebrae_landmark_output_labels, region_default_sizes): + for i in range(s): + all_default_vertebrae_output_labels.append(l + i * vertebrae_output_step) + # Sort the combined disc+vert labels by their z-index sorted_labels = vert_sorted_labels + disc_sorted_labels sorted_z_indexes = vert_sorted_z_indexes + disc_sorted_z_indexes @@ -616,18 +646,32 @@ def iterative_label( # Get te vert label that is just next to the l_disc label inferiorly l = next(_l for _l, _is_v in list(zip(sorted_labels, is_vert))[sorted_labels_l_disc_idx:] if _is_v) - # Get the index of the current landmark in the sorted vertebrae labels - start_l = vert_sorted_labels.index(l) - # Get the output disc label for the l_disc l_disc_output = map_disc_sorted_labels_2output[l_disc] - # Get the index of the current vert landmark in the list of all possible vertebrae output labels - start_o = all_possible_vertebrae_output_labels.index(disc_output_labels_2vert[l_disc_output]) + # Get the output vert label for the l_disc + l_disc_vert_output = disc_output_labels_2vert[l_disc_output] # If this is the most superior landmark, we have to adjust the start indexes to start from the most superior label in the image if len(map_vert_sorted_labels_2output) == 0: - start_l, start_o = max(0, start_l - start_o), max(0, start_o - start_l) + # Get the index of the current landmark in the sorted vertebrae labels + start_l = vert_sorted_labels.index(l) + + # Get the index of the current vert landmark in the list of all possible vertebrae output labels + start_o_def = all_default_vertebrae_output_labels.index(l_disc_vert_output) + + # Adjust the start indexes + start_l, start_o_def = max(0, start_l - start_o_def), max(0, start_o_def - start_l) + + # Map the sorted vert labels to the output labels + for l, o in zip(vert_sorted_labels[start_l:], all_default_vertebrae_output_labels[start_o:]): + map_vert_sorted_labels_2output[l] = o + + # Get the index of the current landmark in the sorted vertebrae labels + start_l = vert_sorted_labels.index(l) + + # Get the index of the current vert landmark in the list of all possible vertebrae output labels + start_o = all_possible_vertebrae_output_labels.index(l_disc_vert_output) # Map the sorted vert labels to the output labels # This will ovveride the mapping from the previous landmarks for all labels inferior to the current landmark From 149b4c00657f4cbc934ff784c3c348c38d7a838d Mon Sep 17 00:00:00 2001 From: Yehuda Warszawer <36595323+yw7@users.noreply.github.com> Date: Sat, 14 Sep 2024 19:08:47 +0300 Subject: [PATCH 15/26] Fix typo in region_default_sizes definition Removed an extraneous comma from the `region_default_sizes` definition in `iterative_label.py`. This prevents potential issues with unintended tuple creation and ensures the sizes are correctly assigned for each spinal region. --- totalspineseg/utils/iterative_label.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/totalspineseg/utils/iterative_label.py b/totalspineseg/utils/iterative_label.py index f6a0398..4bf7b0e 100644 --- a/totalspineseg/utils/iterative_label.py +++ b/totalspineseg/utils/iterative_label.py @@ -491,7 +491,7 @@ def iterative_label( output_seg_data = np.zeros_like(seg_data) # Region default sizes for the discs and vertebrae (Cervical, Thoracic, Lumbar, Sacrum) - region_default_sizes=[5, 12, 5, 1], + region_default_sizes=[5, 12, 5, 1] # Get sorted connected components superio-inferior (SI) for the disc labels disc_mask_labeled, disc_num_labels, disc_sorted_labels, disc_sorted_z_indexes = _get_si_sorted_components( From 56ac7c4b6955f45138f0d8a3c0c69e05ada4544f Mon Sep 17 00:00:00 2001 From: Yehuda Warszawer <36595323+yw7@users.noreply.github.com> Date: Sat, 14 Sep 2024 19:10:59 +0300 Subject: [PATCH 16/26] Fix indexing error in label mapping Corrects the labeling index calculation by using `start_o_def` instead of `start_o` during the mapping of disc and vertebrae labels to output labels. This ensures accurate label alignment and prevents potential mismatches in the labeling process. --- totalspineseg/utils/iterative_label.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/totalspineseg/utils/iterative_label.py b/totalspineseg/utils/iterative_label.py index 4bf7b0e..53c8f2f 100644 --- a/totalspineseg/utils/iterative_label.py +++ b/totalspineseg/utils/iterative_label.py @@ -582,7 +582,7 @@ def iterative_label( start_l, start_o_def = max(0, start_l - start_o_def), max(0, start_o_def - start_l) # Map the sorted disc labels to the output labels - for l, o in zip(disc_sorted_labels[start_l:], all_default_disc_output_labels[start_o:]): + for l, o in zip(disc_sorted_labels[start_l:], all_default_disc_output_labels[start_o_def:]): map_disc_sorted_labels_2output[l] = o # Get the index of the current landmark in the sorted disc labels @@ -664,7 +664,7 @@ def iterative_label( start_l, start_o_def = max(0, start_l - start_o_def), max(0, start_o_def - start_l) # Map the sorted vert labels to the output labels - for l, o in zip(vert_sorted_labels[start_l:], all_default_vertebrae_output_labels[start_o:]): + for l, o in zip(vert_sorted_labels[start_l:], all_default_vertebrae_output_labels[start_o_def:]): map_vert_sorted_labels_2output[l] = o # Get the index of the current landmark in the sorted vertebrae labels From 3eb57a72375226a14af112e7052e67076e8a5a85 Mon Sep 17 00:00:00 2001 From: Yehuda Warszawer <36595323+yw7@users.noreply.github.com> Date: Sat, 14 Sep 2024 22:13:08 +0300 Subject: [PATCH 17/26] Improve variable naming for clarity in iterative_label function Enhanced readability by renaming loop and temporary variables to more descriptive names throughout `iterative_label`. This makes the code easier to understand and maintain by clearly distinguishing between disc and vertebrae related variables. No functional changes were made. --- totalspineseg/utils/iterative_label.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/totalspineseg/utils/iterative_label.py b/totalspineseg/utils/iterative_label.py index 53c8f2f..7e09fdd 100644 --- a/totalspineseg/utils/iterative_label.py +++ b/totalspineseg/utils/iterative_label.py @@ -568,15 +568,15 @@ def iterative_label( map_disc_sorted_labels_2output = {} # We loop over all the landmarks starting from the most superior - for l in [_ for _ in disc_sorted_labels if _ in map_disc_sorted_labels_landmark2output]: + for l_disc in [_ for _ in disc_sorted_labels if _ in map_disc_sorted_labels_landmark2output]: # If this is the most superior landmark, we have to adjust the start indexes to start from the most superior label in the image if len(map_disc_sorted_labels_2output) == 0: # Get the index of the current landmark in the sorted disc labels - start_l = disc_sorted_labels.index(l) + start_l = disc_sorted_labels.index(l_disc) # Get the index of the current landmark in the list of all default disc output labels - start_o_def = all_default_disc_output_labels.index(map_disc_sorted_labels_landmark2output[l]) + start_o_def = all_default_disc_output_labels.index(map_disc_sorted_labels_landmark2output[l_disc]) # Adjust the start indexes start_l, start_o_def = max(0, start_l - start_o_def), max(0, start_o_def - start_l) @@ -586,10 +586,10 @@ def iterative_label( map_disc_sorted_labels_2output[l] = o # Get the index of the current landmark in the sorted disc labels - start_l = disc_sorted_labels.index(l) + start_l = disc_sorted_labels.index(l_disc) # Get the index of the current landmark in the list of all possible disc output labels - start_o = all_possible_disc_output_labels.index(map_disc_sorted_labels_landmark2output[l]) + start_o = all_possible_disc_output_labels.index(map_disc_sorted_labels_landmark2output[l_disc]) # Map the sorted disc labels to the output labels # This will ovveride the mapping from the previous landmarks for all labels inferior to the current landmark @@ -644,7 +644,7 @@ def iterative_label( continue # Get te vert label that is just next to the l_disc label inferiorly - l = next(_l for _l, _is_v in list(zip(sorted_labels, is_vert))[sorted_labels_l_disc_idx:] if _is_v) + l_vert = next(_l for _l, _is_v in list(zip(sorted_labels, is_vert))[sorted_labels_l_disc_idx:] if _is_v) # Get the output disc label for the l_disc l_disc_output = map_disc_sorted_labels_2output[l_disc] @@ -655,7 +655,7 @@ def iterative_label( # If this is the most superior landmark, we have to adjust the start indexes to start from the most superior label in the image if len(map_vert_sorted_labels_2output) == 0: # Get the index of the current landmark in the sorted vertebrae labels - start_l = vert_sorted_labels.index(l) + start_l = vert_sorted_labels.index(l_vert) # Get the index of the current vert landmark in the list of all possible vertebrae output labels start_o_def = all_default_vertebrae_output_labels.index(l_disc_vert_output) @@ -668,7 +668,7 @@ def iterative_label( map_vert_sorted_labels_2output[l] = o # Get the index of the current landmark in the sorted vertebrae labels - start_l = vert_sorted_labels.index(l) + start_l = vert_sorted_labels.index(l_vert) # Get the index of the current vert landmark in the list of all possible vertebrae output labels start_o = all_possible_vertebrae_output_labels.index(l_disc_vert_output) From b4cb2adf245659a3373f43226eeef58b478a0c59 Mon Sep 17 00:00:00 2001 From: Yehuda Warszawer <36595323+yw7@users.noreply.github.com> Date: Sun, 15 Sep 2024 03:52:57 +0300 Subject: [PATCH 18/26] Simplify vertebrae to output labels mapping logic Refactored the label mapping logic to streamline the process of mapping vertebrae labels to output labels. The new approach eliminates redundant computations and reduces the complexity by focusing directly on indexing and mapping. This change improves code readability and maintainability, while ensuring accurate and consistent label assignments. No functional modifications introduced. --- totalspineseg/utils/iterative_label.py | 61 ++++++++------------------ 1 file changed, 18 insertions(+), 43 deletions(-) diff --git a/totalspineseg/utils/iterative_label.py b/totalspineseg/utils/iterative_label.py index 7e09fdd..4f6ce91 100644 --- a/totalspineseg/utils/iterative_label.py +++ b/totalspineseg/utils/iterative_label.py @@ -634,49 +634,24 @@ def iterative_label( # Make a dict mapping the sorted vertebrae labels to the output labels map_vert_sorted_labels_2output = {} - # We loop over all the landmarks starting from the most superior - for l_disc in [_ for _ in disc_sorted_labels if _ in map_disc_sorted_labels_landmark2output]: - # Get the index of the current landmark in the sorted disc labels - sorted_labels_l_disc_idx = list(zip(sorted_labels, is_vert)).index((l_disc, False)) - - # Continue if no more vertebrae labels - if True not in is_vert[sorted_labels_l_disc_idx:]: - continue - - # Get te vert label that is just next to the l_disc label inferiorly - l_vert = next(_l for _l, _is_v in list(zip(sorted_labels, is_vert))[sorted_labels_l_disc_idx:] if _is_v) - - # Get the output disc label for the l_disc - l_disc_output = map_disc_sorted_labels_2output[l_disc] - - # Get the output vert label for the l_disc - l_disc_vert_output = disc_output_labels_2vert[l_disc_output] - - # If this is the most superior landmark, we have to adjust the start indexes to start from the most superior label in the image - if len(map_vert_sorted_labels_2output) == 0: - # Get the index of the current landmark in the sorted vertebrae labels - start_l = vert_sorted_labels.index(l_vert) - - # Get the index of the current vert landmark in the list of all possible vertebrae output labels - start_o_def = all_default_vertebrae_output_labels.index(l_disc_vert_output) - - # Adjust the start indexes - start_l, start_o_def = max(0, start_l - start_o_def), max(0, start_o_def - start_l) - - # Map the sorted vert labels to the output labels - for l, o in zip(vert_sorted_labels[start_l:], all_default_vertebrae_output_labels[start_o_def:]): - map_vert_sorted_labels_2output[l] = o - - # Get the index of the current landmark in the sorted vertebrae labels - start_l = vert_sorted_labels.index(l_vert) - - # Get the index of the current vert landmark in the list of all possible vertebrae output labels - start_o = all_possible_vertebrae_output_labels.index(l_disc_vert_output) - - # Map the sorted vert labels to the output labels - # This will ovveride the mapping from the previous landmarks for all labels inferior to the current landmark - for l, o in zip(vert_sorted_labels[start_l:], all_possible_vertebrae_output_labels[start_o:]): - map_vert_sorted_labels_2output[l] = o + l_vert_output = 0 + # We loop over all the labels starting from the most superior, and we map the vertebrae labels to the output labels + for idx, curr_l, curr_is_vert in zip(range(len(sorted_labels)), sorted_labels, is_vert): + if not curr_is_vert: # This is a disc + # Get the output label for the disc and vertebrae + l_disc_output = map_disc_sorted_labels_2output[curr_l] + l_vert_output = disc_output_labels_2vert[l_disc_output] + + if idx > 0 and len(map_vert_sorted_labels_2output) == 0: # This is the first disc + # Get the index of the current vertebrae in the default vertebrae output labels list + i = all_default_vertebrae_output_labels.index(l_vert_output) + + # Map all the vertebrae superior to the current disc to the default vertebrae output labels + for l, o in zip(sorted_labels[idx - 1::-1], all_default_vertebrae_output_labels[i - 1::-1]): + map_vert_sorted_labels_2output[l] = o + + elif l_vert_output > 0: # This is a vertebrae + map_vert_sorted_labels_2output[curr_l] = l_vert_output # Label the vertebrae with the output labels superio-inferior for l, o in map_vert_sorted_labels_2output.items(): From 7270f8325d1a016830135a4158f6fa1e45d6f0e5 Mon Sep 17 00:00:00 2001 From: Yehuda Warszawer <36595323+yw7@users.noreply.github.com> Date: Sun, 15 Sep 2024 15:37:17 +0300 Subject: [PATCH 19/26] Handle missing disc labels in vertebrae mapping Added a check to skip discs not present in the output map to prevent key errors. Also optimized the mapping of vertebrae superior to discs by isolating them before assigning default labels. Improves robustness of the labeling process. --- totalspineseg/utils/iterative_label.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/totalspineseg/utils/iterative_label.py b/totalspineseg/utils/iterative_label.py index 4f6ce91..e1dfbcf 100644 --- a/totalspineseg/utils/iterative_label.py +++ b/totalspineseg/utils/iterative_label.py @@ -638,6 +638,10 @@ def iterative_label( # We loop over all the labels starting from the most superior, and we map the vertebrae labels to the output labels for idx, curr_l, curr_is_vert in zip(range(len(sorted_labels)), sorted_labels, is_vert): if not curr_is_vert: # This is a disc + # If the current disc is not in the map, continue + if curr_l not in map_disc_sorted_labels_2output: + continue + # Get the output label for the disc and vertebrae l_disc_output = map_disc_sorted_labels_2output[curr_l] l_vert_output = disc_output_labels_2vert[l_disc_output] @@ -646,8 +650,11 @@ def iterative_label( # Get the index of the current vertebrae in the default vertebrae output labels list i = all_default_vertebrae_output_labels.index(l_vert_output) + # Get the labels of the vertebrae superior to the current disc + prev_vert_ls = [l for l, _is_v in zip(sorted_labels[idx - 1::-1], is_vert[idx - 1::-1]) if _is_v] + # Map all the vertebrae superior to the current disc to the default vertebrae output labels - for l, o in zip(sorted_labels[idx - 1::-1], all_default_vertebrae_output_labels[i - 1::-1]): + for l, o in zip(prev_vert_ls, all_default_vertebrae_output_labels[i - 1::-1]): map_vert_sorted_labels_2output[l] = o elif l_vert_output > 0: # This is a vertebrae From f17dc3b02b1f892a444c2756ef23b95cd75ab8ac Mon Sep 17 00:00:00 2001 From: Yehuda Warszawer <36595323+yw7@users.noreply.github.com> Date: Sat, 21 Sep 2024 19:14:06 +0300 Subject: [PATCH 20/26] Refactor label sorting functions to improve code readability and robustness Replaced --map-input and --map-output options with more specific options for handling canal, cord, and sacrum labels. This enhances clarity and maintainability by directly addressing these anatomical structures. Introduced --canal-labels, --cord-labels, and --sacrum-labels arguments, along with corresponding output label options (--canal-output-label, --cord-output-label, --sacrum-output-label). Also refactored label sorting functions to improve code readability and robustness. These changes facilitate more accurate vertebrae and disc labeling in segmentation processes. Breaks backward compatibility with previous map options. Make sure to update scripts accordingly. --- totalspineseg/utils/iterative_label.py | 373 +++++++++++++++++-------- 1 file changed, 255 insertions(+), 118 deletions(-) diff --git a/totalspineseg/utils/iterative_label.py b/totalspineseg/utils/iterative_label.py index e1dfbcf..9dc0422 100644 --- a/totalspineseg/utils/iterative_label.py +++ b/totalspineseg/utils/iterative_label.py @@ -20,12 +20,12 @@ def main(): '''.split()), epilog=textwrap.dedent(''' Examples: - iterative_label -s labels_init -o labels --selected-disc-landmarks 2 5 3 4 --disc-labels 1-5 --disc-landmark-labels 2 3 4 5 --disc-landmark-output-labels 60 70 90 100 --map-input 6:50 7:2 8:2 9:1 -r - iterative_label -s labels_init -o labels --selected-disc-landmarks 2 5 3 4 --disc-labels 1-7 --disc-landmark-labels 4 5 6 7 --disc-landmark-output-labels 60 70 90 100 --vertebrae-labels 9-14 --vertebrae-landmark-output-labels 12 20 40 50 --vertebrae-extra-labels 8 --map-output 17:50 --map-input 14:50 15:2 16:2 17:1 -r - iterative_label -s labels_init -o labels -l localizers --selected-disc-landmarks 2 5 --disc-labels 1-5 --disc-landmark-labels 2 3 4 5 --disc-landmark-output-labels 60 70 90 100 --map-input 6:50 7:2 8:2 9:1 --loc-disc-labels 60-100 -r - iterative_label -s labels_init -o labels -l localizers --selected-disc-landmarks 4 7 --disc-labels 1-7 --disc-landmark-labels 4 5 6 7 --disc-landmark-output-labels 60 70 90 100 --vertebrae-labels 9-14 --vertebrae-landmark-output-labels 12 20 40 50 --vertebrae-extra-labels 8 --map-output 17:50 --map-input 14:50 15:2 16:2 17:1 --loc-disc-labels 60-100 -r + iterative_label -s labels_init -o labels --selected-disc-landmarks 2 5 3 4 --disc-labels 1-5 --disc-landmark-labels 2 3 4 5 --disc-landmark-output-labels 60 70 90 100 --canal-labels 7 8 --canal-output-label 2 --cord-labels 9 --cord-output-label 1 -r + iterative_label -s labels_init -o labels --selected-disc-landmarks 2 5 3 4 --disc-labels 1-7 --disc-landmark-labels 4 5 6 7 --disc-landmark-output-labels 60 70 90 100 --vertebrae-labels 9-14 --vertebrae-landmark-output-labels 12 20 40 50 --vertebrae-extra-labels 8 --canal-labels 15 16 --canal-output-label 2 --cord-labels 17 --cord-output-label 1 --sacrum-labels 14 --sacrum-output-label 50 -r + iterative_label -s labels_init -o labels -l localizers --selected-disc-landmarks 2 5 --disc-labels 1-5 --disc-landmark-labels 2 3 4 5 --disc-landmark-output-labels 60 70 90 100 --canal-labels 7 8 --canal-output-label 2 --cord-labels 9 --cord-output-label 1 --loc-disc-labels 60-100 -r + iterative_label -s labels_init -o labels -l localizers --selected-disc-landmarks 4 7 --disc-labels 1-7 --disc-landmark-labels 4 5 6 7 --disc-landmark-output-labels 60 70 90 100 --vertebrae-labels 9-14 --vertebrae-landmark-output-labels 12 20 40 50 --vertebrae-extra-labels 8 --canal-labels 15 16 --canal-output-label 2 --cord-labels 17 --cord-output-label 1 --sacrum-labels 14 --sacrum-output-label 50 --loc-disc-labels 60-100 -r For BIDS: - iterative_label -s derivatives/labels -o derivatives/labels --seg-suffix "_seg" --output-seg-suffix "_seg_seq" -d "sub-" -u "anat" --selected-disc-landmarks 2 5 3 4 --disc-labels 1-5 --disc-landmark-labels 2 3 4 5 --disc-landmark-output-labels 60 70 90 100 --map-input 6:50 7:2 8:2 9:1 -r + iterative_label -s derivatives/labels -o derivatives/labels --seg-suffix "_seg" --output-seg-suffix "_seg_seq" -d "sub-" -u "anat" --selected-disc-landmarks 2 5 3 4 --disc-labels 1-5 --disc-landmark-labels 2 3 4 5 --disc-landmark-output-labels 60 70 90 100 --canal-labels 7 8 --canal-output-label 2 --cord-labels 9 --cord-output-label 1 -r '''), formatter_class=argparse.RawTextHelpFormatter ) @@ -119,20 +119,28 @@ def main(): help='The disc labels in the localizer used for detecting first disc.' ) parser.add_argument( - '--map-input', type=str, nargs='+', default=[], - help=' '.join(f''' - A dict mapping labels from input into the output segmentation. - The format should be input_label:output_label without any spaces. - For example, 14:92 16:201 17:200 to map the input sacrum label 14 to 92, canal label 16 to 201 and spinal cord label 17 to 200. - '''.split()) + '--canal-labels', type=lambda x:list(range(int(x.split('-')[0]), int(x.split('-')[-1]) + 1)), nargs='+', default=[], + help='The canal labels in the segmentation.' ) parser.add_argument( - '--map-output', type=str, nargs='+', default=[], - help=' '.join(f''' - A dict mapping labels from the output of the iterative labeling algorithm into different labels in the output segmentation. - The format should be input_label:output_label without any spaces. - For example, 17:92 to map the iteratively labeled vertebrae 17 to the sacrum label 92. - '''.split()) + '--canal-output-label', type=int, default=0, + help='Output label for the canal, defaults to 0 (Do not output).' + ) + parser.add_argument( + '--cord-labels', type=lambda x:list(range(int(x.split('-')[0]), int(x.split('-')[-1]) + 1)), nargs='+', default=[], + help='The spinal cord labels in the segmentation.' + ) + parser.add_argument( + '--cord-output-label', type=int, default=0, + help='Output label for the spinal cord, defaults to 0 (Do not output).' + ) + parser.add_argument( + '--sacrum-labels', type=lambda x:list(range(int(x.split('-')[0]), int(x.split('-')[-1]) + 1)), nargs='+', default=[], + help='The sacrum labels in the segmentation.' + ) + parser.add_argument( + '--sacrum-output-label', type=int, default=0, + help='Output label for the sacrum, defaults to 0 (Do not output).' ) parser.add_argument( '--dilation-size', type=int, default=1, @@ -179,8 +187,12 @@ def main(): vertebrae_extra_labels = [_ for __ in args.vertebrae_extra_labels for _ in (__ if isinstance(__, list) else [__])] region_max_sizes = args.region_max_sizes loc_disc_labels = [_ for __ in args.loc_disc_labels for _ in (__ if isinstance(__, list) else [__])] - map_input_list = args.map_input - map_output_list = args.map_output + canal_labels = [_ for __ in args.canal_labels for _ in (__ if isinstance(__, list) else [__])] + canal_output_label = args.canal_output_label + cord_labels = [_ for __ in args.cord_labels for _ in (__ if isinstance(__, list) else [__])] + cord_output_label = args.cord_output_label + sacrum_labels = [_ for __ in args.sacrum_labels for _ in (__ if isinstance(__, list) else [__])] + sacrum_output_label = args.sacrum_output_label dilation_size = args.dilation_size default_superior_disc = args.default_superior_disc override = args.override @@ -211,8 +223,12 @@ def main(): vertebrae_extra_labels = {vertebrae_extra_labels} region_max_sizes = {region_max_sizes} loc_disc_labels = {loc_disc_labels} - map_input = {map_input_list} - map_output = {map_output_list} + canal_labels = {canal_labels} + canal_output_label = {canal_output_label} + cord_labels = {cord_labels} + cord_output_label = {cord_output_label} + sacrum_labels = {sacrum_labels} + sacrum_output_label = {sacrum_output_label} dilation_size = {dilation_size} default_superior_disc = {default_superior_disc} override = {override} @@ -220,17 +236,6 @@ def main(): quiet = {quiet} ''')) - # Load maps into a dict - try: - map_input_dict = {int(l_in): int(l_out) for l_in, l_out in map(lambda x:x.split(':'), map_input_list)} - except: - raise ValueError("Input param map_input is not in the right structure. Make sure it is in the right format, e.g., 1:2 3:5") - - try: - map_output_dict = {int(l_in): int(l_out) for l_in, l_out in map(lambda x:x.split(':'), map_output_list)} - except: - raise ValueError("Input param map_output is not in the right structure. Make sure it is in the right format, e.g., 1:2 3:5") - iterative_label_mp( segs_path=segs_path, output_segs_path=output_segs_path, @@ -252,8 +257,12 @@ def main(): vertebrae_extra_labels=vertebrae_extra_labels, region_max_sizes=region_max_sizes, loc_disc_labels=loc_disc_labels, - map_input_dict=map_input_dict, - map_output_dict=map_output_dict, + canal_labels=canal_labels, + canal_output_label=canal_output_label, + cord_labels=cord_labels, + cord_output_label=cord_output_label, + sacrum_labels=sacrum_labels, + sacrum_output_label=sacrum_output_label, dilation_size=dilation_size, default_superior_disc=default_superior_disc, override=override, @@ -282,8 +291,12 @@ def iterative_label_mp( vertebrae_extra_labels=[], region_max_sizes=[5, 12, 6, 1], loc_disc_labels=[], - map_input_dict={}, - map_output_dict={}, + canal_labels=[], + canal_output_label=0, + cord_labels=[], + cord_output_label=0, + sacrum_labels=[], + sacrum_output_label=0, dilation_size=1, default_superior_disc=0, override=False, @@ -323,8 +336,12 @@ def iterative_label_mp( vertebrae_extra_labels=vertebrae_extra_labels, region_max_sizes=region_max_sizes, loc_disc_labels=loc_disc_labels, - map_input_dict=map_input_dict, - map_output_dict=map_output_dict, + canal_labels=canal_labels, + canal_output_label=canal_output_label, + cord_labels=cord_labels, + cord_output_label=cord_output_label, + sacrum_labels=sacrum_labels, + sacrum_output_label=sacrum_output_label, dilation_size=dilation_size, default_superior_disc=default_superior_disc, override=override, @@ -352,8 +369,12 @@ def _iterative_label( vertebrae_extra_labels=[], region_max_sizes=[5, 12, 6, 1], loc_disc_labels=[], - map_input_dict={}, - map_output_dict={}, + canal_labels=[], + canal_output_label=0, + cord_labels=[], + cord_output_label=0, + sacrum_labels=[], + sacrum_output_label=0, dilation_size=1, default_superior_disc=0, override=False, @@ -388,8 +409,12 @@ def _iterative_label( vertebrae_extra_labels=vertebrae_extra_labels, region_max_sizes=region_max_sizes, loc_disc_labels=loc_disc_labels, - map_input_dict=map_input_dict, - map_output_dict=map_output_dict, + canal_labels=canal_labels, + canal_output_label=canal_output_label, + cord_labels=cord_labels, + cord_output_label=cord_output_label, + sacrum_labels=sacrum_labels, + sacrum_output_label=sacrum_output_label, dilation_size=dilation_size, disc_default_superior_output=default_superior_disc, ) @@ -425,8 +450,12 @@ def iterative_label( vertebrae_extra_labels=[], region_max_sizes=[5, 12, 6, 1], loc_disc_labels=[], - map_input_dict={}, - map_output_dict={}, + canal_labels=[], + canal_output_label=0, + cord_labels=[], + cord_output_label=0, + sacrum_labels=[], + sacrum_output_label=0, dilation_size=1, disc_default_superior_output=0, ): @@ -472,10 +501,18 @@ def iterative_label( The maximum number of discs/vertebrae for each region (Cervical from C3, Thoracic, Lumbar, Sacrum). loc_disc_labels : list Localizer labels to use for detecting first disc - map_input_dict : dict - A dict mapping labels from input into the output segmentation - map_output_dict : dict - A dict mapping labels from the output of the iterative labeling algorithm into different labels in the output segmentation + canal_labels : list + Canal labels in the segmentation + canal_output_label : int + Output label for the canal + cord_labels : list + Spinal Cord labels in the segmentation + cord_output_label : int + Output label for the spinal cord + sacrum_labels : list + Sacrum labels in the segmentation + sacrum_output_label : int + Output label for the sacrum dilation_size : int Number of voxels to dilate before finding connected voxels to label default_superior_disc : int @@ -486,56 +523,72 @@ def iterative_label( nibabel.nifti1.Nifti1Image Segmentation image with labeled vertebrae, IVDs, Spinal Cord and canal ''' + # Region default sizes for the discs and vertebrae (Cervical, Thoracic, Lumbar, Sacrum) + region_default_sizes=[5, 12, 5, 1] + seg_data = np.asanyarray(seg.dataobj).round().astype(np.uint8) output_seg_data = np.zeros_like(seg_data) - # Region default sizes for the discs and vertebrae (Cervical, Thoracic, Lumbar, Sacrum) - region_default_sizes=[5, 12, 5, 1] + # Get the canal centerline indices to use for sorting the discs and vertebrae based on the prjection on the canal centerline + canal_centerline_indices = _get_canal_centerline_indices(seg_data, canal_labels + cord_labels) + + # Get the mask of the voxels anterior to the canal, this helps in sorting the vertebrae considering only the vertebrae body + mask_aterior_to_canal = _get_mask_aterior_to_canal(seg_data, canal_labels + cord_labels) # Get sorted connected components superio-inferior (SI) for the disc labels - disc_mask_labeled, disc_num_labels, disc_sorted_labels, disc_sorted_z_indexes = _get_si_sorted_components( + disc_mask_labeled, disc_num_labels, disc_sorted_labels, disc_sorted_z_indices = _get_si_sorted_components( seg, disc_labels, + canal_centerline_indices, + mask_aterior_to_canal, dilation_size, combine_labels=True, ) # Get sorted connected components superio-inferior (SI) for the vertebrae labels - vert_mask_labeled, vert_num_labels, vert_sorted_labels, vert_sorted_z_indexes = _get_si_sorted_components( + vert_mask_labeled, vert_num_labels, vert_sorted_labels, vert_sorted_z_indices = _get_si_sorted_components( seg, vertebrae_labels, + canal_centerline_indices, + mask_aterior_to_canal, dilation_size, ) # Combine sequential vertebrae labels if they have the same value in the original segmentation - vert_mask_labeled, vert_num_labels, vert_sorted_labels, vert_sorted_z_indexes = _merge_vertebrae_with_same_label( + vert_mask_labeled, vert_num_labels, vert_sorted_labels, vert_sorted_z_indices = _merge_vertebrae_with_same_label( seg, vertebrae_labels, vert_mask_labeled, vert_num_labels, vert_sorted_labels, - vert_sorted_z_indexes, + vert_sorted_z_indices, + canal_centerline_indices, + mask_aterior_to_canal, ) # Combine sequential vertebrae labels if there is no disc between them - vert_mask_labeled, vert_num_labels, vert_sorted_labels, vert_sorted_z_indexes = _merge_vertebrae_labels_with_no_disc_between( + vert_mask_labeled, vert_num_labels, vert_sorted_labels, vert_sorted_z_indices = _merge_vertebrae_labels_with_no_disc_between( seg, vert_mask_labeled, vert_num_labels, vert_sorted_labels, - vert_sorted_z_indexes, - disc_sorted_z_indexes, + vert_sorted_z_indices, + disc_sorted_z_indices, + canal_centerline_indices, + mask_aterior_to_canal, ) # Combine extra labels with adjacent vertebrae labels - vert_mask_labeled, vert_num_labels, vert_sorted_labels, vert_sorted_z_indexes = _merge_extra_labels_with_adjacent_vertebrae( + vert_mask_labeled, vert_num_labels, vert_sorted_labels, vert_sorted_z_indices = _merge_extra_labels_with_adjacent_vertebrae( seg, vert_mask_labeled, vert_num_labels, vert_sorted_labels, - vert_sorted_z_indexes, + vert_sorted_z_indices, vertebrae_extra_labels, + canal_centerline_indices, + mask_aterior_to_canal, ) # Get the landmark disc labels and output labels - {label in sorted labels: output label} @@ -558,7 +611,7 @@ def iterative_label( for i in range(s): all_possible_disc_output_labels.append(l + i * disc_output_step) - # Make a list containing all possible labels for the vertebrae ordered superio-inferior with the default region sizes + # Make a list containing all possible labels for the disc ordered superio-inferior with the default region sizes all_default_disc_output_labels = [] for l, s in zip(disc_landmark_output_labels, region_default_sizes): for i in range(s): @@ -570,7 +623,7 @@ def iterative_label( # We loop over all the landmarks starting from the most superior for l_disc in [_ for _ in disc_sorted_labels if _ in map_disc_sorted_labels_landmark2output]: - # If this is the most superior landmark, we have to adjust the start indexes to start from the most superior label in the image + # If this is the most superior landmark, we have to adjust the start indices to start from the most superior label in the image if len(map_disc_sorted_labels_2output) == 0: # Get the index of the current landmark in the sorted disc labels start_l = disc_sorted_labels.index(l_disc) @@ -578,7 +631,7 @@ def iterative_label( # Get the index of the current landmark in the list of all default disc output labels start_o_def = all_default_disc_output_labels.index(map_disc_sorted_labels_landmark2output[l_disc]) - # Adjust the start indexes + # Adjust the start indices start_l, start_o_def = max(0, start_l - start_o_def), max(0, start_o_def - start_l) # Map the sorted disc labels to the output labels @@ -622,11 +675,11 @@ def iterative_label( # Sort the combined disc+vert labels by their z-index sorted_labels = vert_sorted_labels + disc_sorted_labels - sorted_z_indexes = vert_sorted_z_indexes + disc_sorted_z_indexes + sorted_z_indices = vert_sorted_z_indices + disc_sorted_z_indices is_vert = [True] * len(vert_sorted_labels) + [False] * len(disc_sorted_labels) # Sort the labels by their z-index (reversed to go from superior to inferior) - sorted_z_indexes, sorted_labels, is_vert = zip(*sorted(zip(sorted_z_indexes, sorted_labels, is_vert))[::-1]) + sorted_z_indices, sorted_labels, is_vert = zip(*sorted(zip(sorted_z_indices, sorted_labels, is_vert))[::-1]) # Make a dict mapping disc to vertebrae labels disc_output_labels_2vert = dict(zip(all_possible_disc_output_labels, all_possible_vertebrae_output_labels[2:])) @@ -664,34 +717,101 @@ def iterative_label( for l, o in map_vert_sorted_labels_2output.items(): output_seg_data[vert_mask_labeled == l] = o - # Use the map to map labels from the iteative algorithm output, to the final output - # This is useful to map the vertebrae label from the iteative algorithm output to the special sacrum label - for orig, new in map_output_dict.items(): - if int(orig) in output_seg_data: - output_seg_data[output_seg_data == int(new)] = 0 - output_seg_data[output_seg_data == int(orig)] = int(new) - - # Use the map to map input labels to the final output - # This is useful to map the input sacrum, canal and spinal cord labels to the output labels - for orig, new in map_input_dict.items(): - if int(orig) in seg_data: - output_seg_data[output_seg_data == int(new)] = 0 - mask = seg_data == int(orig) - - # Map also all labels that are currently in the mask - # This is useful for example if we addedd from extra_labels to the sacrum and we want them to map together with the sacrum - mask_labes = [_ for _ in np.unique(output_seg_data[mask]) if _ != 0] - if len(mask_labes) > 0: - mask |= np.isin(output_seg_data, mask_labes) - output_seg_data[mask] = int(new) + # Map Spinal Canal to the output label + if canal_labels is not None and len(canal_labels) > 0 and canal_output_label > 0: + output_seg_data[np.isin(seg_data, canal_labels)] = canal_output_label + + # Map Spinal Cord to the output label + if cord_labels is not None and len(cord_labels) > 0 and cord_output_label > 0: + output_seg_data[np.isin(seg_data, cord_labels)] = cord_output_label + + # Map Sacrum to the output label + if sacrum_labels is not None and len(sacrum_labels) > 0 and sacrum_output_label > 0: + output_seg_data[np.isin(seg_data, sacrum_labels)] = sacrum_output_label output_seg = nib.Nifti1Image(output_seg_data, seg.affine, seg.header) return output_seg +def _get_canal_centerline_indices( + seg_data, + canal_labels=[], + ): + ''' + Get the indices of the canal centerline. + ''' + # Get array of indices for x, y, and z axes + indices = np.indices(seg_data.shape) + + # Create a mask of the canal + mask_canal = np.isin(seg_data, canal_labels) + + + # Create a mask the canal centerline by finding the middle voxels in x and y axes for each z index + mask_min_x_indices = np.min(indices[0], where=mask_canal, initial=np.iinfo(indices.dtype).max, keepdims=True, axis=(0, 1)) + mask_max_x_indices = np.max(indices[0], where=mask_canal, initial=np.iinfo(indices.dtype).min, keepdims=True, axis=(0, 1)) + mask_mid_x = indices[0] == ((mask_min_x_indices + mask_max_x_indices) // 2) + mask_min_y_indices = np.min(indices[1], where=mask_canal, initial=np.iinfo(indices.dtype).max, keepdims=True, axis=(0, 1)) + mask_max_y_indices = np.max(indices[1], where=mask_canal, initial=np.iinfo(indices.dtype).min, keepdims=True, axis=(0, 1)) + mask_mid_y = indices[1] == ((mask_min_y_indices + mask_max_y_indices) // 2) + mask_canal_centerline = mask_canal * mask_mid_x * mask_mid_y + + # Get the indices of the canal centerline + return np.array(np.nonzero(mask_canal_centerline)).T + +def _sort_labels_si( + mask_labeled, + labels, + canal_centerline_indices, + mask_aterior_to_canal=None, + ): + ''' + Sort the labels by their z-index (reversed to go from superior to inferior). + ''' + # Get the indices of the center of mass for each label + labels_indices = np.array(ndi.center_of_mass(np.isin(mask_labeled, labels), mask_labeled, labels)) + + # Get the distance of each label indices from the canal centerline + labels_distances_from_centerline = np.linalg.norm(labels_indices[:, None, :] - canal_centerline_indices[None, ...],axis=2) + + # Get the z-index of the closest canal centerline voxel for each label + labels_z_indices = canal_centerline_indices[np.argmin(labels_distances_from_centerline, axis=1), -1] + + # If mask_aterior_to_canal is provided, calculate the center of mass in this mask if the label is inside the mask + if mask_aterior_to_canal is not None: + # Save the existing labels z-index in a dict + labels_z_indices_dict = dict(zip(labels, labels_z_indices)) + + # Get the part that is anterior to the canal od mask_labeled + mask_labeled_aterior_to_canal = mask_aterior_to_canal * mask_labeled + + # Get the labels that contain voxels anterior to the canal + labels_masked = np.isin(labels, mask_labeled_aterior_to_canal) + + # Get the indices of the center of mass for each label + labels_masked_indices = np.array(ndi.center_of_mass(np.isin(mask_labeled_aterior_to_canal, labels_masked), mask_labeled_aterior_to_canal, labels_masked)) + + # Get the distance of each label indices for each voxel in the canal centerline + labels_masked_distances_from_centerline = np.linalg.norm(labels_masked_indices[:, None, :] - canal_centerline_indices[None, :],axis=2) + + # Get the z-index of the closest canal centerline voxel for each label + labels_masked_z_indices = canal_centerline_indices[np.argmin(labels_masked_distances_from_centerline, axis=1), -1] + + # Update the dict with the new z-index of the labels anterior to the canal + for l, z in zip(labels_masked, labels_masked_z_indices): + labels_z_indices_dict[l] = z + + # Update the labels_z_indices from the dict + labels_z_indices = [labels_z_indices_dict[l] for l in labels] + + # Sort the labels by their z-index (reversed to go from superior to inferior) + return zip(*sorted(zip(labels_z_indices, labels))[::-1]) + def _get_si_sorted_components( seg, labels, + canal_centerline_indices, + mask_aterior_to_canal=None, dilation_size=1, combine_labels=False, ): @@ -743,14 +863,31 @@ def _get_si_sorted_components( elif mask_labeled.max() < np.iinfo(np.uint16).max: mask_labeled = mask_labeled.astype(np.uint16) - # Get the z index of the center of mass for each label - canonical_mask_labeled = np.asanyarray(nib.as_closest_canonical(nib.Nifti1Image(mask_labeled, seg.affine, seg.header)).dataobj).round().astype(mask_labeled.dtype) - mask_labeled_z_indexes = [_[-1] for _ in ndi.center_of_mass(canonical_mask_labeled != 0, canonical_mask_labeled, range(1, num_labels + 1))] - # Sort the labels by their z-index (reversed to go from superior to inferior) - sorted_z_indexes, sorted_labels = zip(*sorted(zip(mask_labeled_z_indexes,range(1,num_labels+1)))[::-1]) + sorted_z_indices, sorted_labels = _sort_labels_si( + mask_labeled, range(1,num_labels+1), canal_centerline_indices, mask_aterior_to_canal + ) + return mask_labeled, num_labels, list(sorted_labels), list(sorted_z_indices) + +def _get_mask_aterior_to_canal( + seg_data, + canal_labels=[], + ): + ''' + Get the mask of the voxels anterior to the canal. + ''' + # Get array of indices for x, y, and z axes + indices = np.indices(seg_data.shape) - return mask_labeled, num_labels, list(sorted_labels), list(sorted_z_indexes) + # Create a mask of the canal + mask_canal = np.isin(seg_data, canal_labels) + + # Create a mask the canal centerline by finding the middle voxels in x and y axes for each z index + mask_min_y_indices = np.min(indices[1], where=mask_canal, initial=np.iinfo(indices.dtype).max, keepdims=True, axis=(0, 1)) + mask_max_y_indices = np.max(indices[1], where=mask_canal, initial=np.iinfo(indices.dtype).min, keepdims=True, axis=(0, 1)) + mask_mid_y_indices = (mask_min_y_indices + mask_max_y_indices) // 2 + + return indices[1] > mask_mid_y_indices def _merge_vertebrae_with_same_label( seg, @@ -758,14 +895,16 @@ def _merge_vertebrae_with_same_label( mask_labeled, num_labels, sorted_labels, - sorted_z_indexes, + sorted_z_indices, + canal_centerline_indices, + mask_aterior_to_canal=None, ): ''' Combine sequential vertebrae labels if they have the same value in the original segmentation. This is useful when part part of the vertebrae is not connected to the main part but have the same odd/even value. ''' if num_labels == 0 or len(labels) <= 1: - return mask_labeled, num_labels, sorted_labels, sorted_z_indexes + return mask_labeled, num_labels, sorted_labels, sorted_z_indices seg_data = np.asanyarray(seg.dataobj).round().astype(np.uint8) @@ -798,37 +937,37 @@ def _merge_vertebrae_with_same_label( elif mask_labeled.max() < np.iinfo(np.uint16).max: mask_labeled = mask_labeled.astype(np.uint16) - # Get the z index of the center of mass for each label - canonical_mask_labeled = np.asanyarray(nib.as_closest_canonical(nib.Nifti1Image(mask_labeled, seg.affine, seg.header)).dataobj).round().astype(mask_labeled.dtype) - mask_labeled_z_indexes = [_[-1] for _ in ndi.center_of_mass(canonical_mask_labeled != 0, canonical_mask_labeled, sorted_labels)] - # Sort the labels by their z-index (reversed to go from superior to inferior) - sorted_z_indexes, sorted_labels = zip(*sorted(zip(mask_labeled_z_indexes, sorted_labels))[::-1]) + sorted_z_indices, sorted_labels = _sort_labels_si( + mask_labeled, sorted_labels, canal_centerline_indices, mask_aterior_to_canal + ) - return mask_labeled, num_labels, list(sorted_labels), list(sorted_z_indexes) + return mask_labeled, num_labels, list(sorted_labels), list(sorted_z_indices) def _merge_vertebrae_labels_with_no_disc_between( seg, mask_labeled, num_labels, sorted_labels, - sorted_z_indexes, - disc_sorted_z_indexes, + sorted_z_indices, + disc_sorted_z_indices, + canal_centerline_indices, + mask_aterior_to_canal, ): ''' Combine sequential vertebrae labels if there is no disc between them. ''' - if num_labels == 0 or len(disc_sorted_z_indexes) == 0: - return mask_labeled, num_labels, sorted_labels, sorted_z_indexes + if num_labels == 0 or len(disc_sorted_z_indices) == 0: + return mask_labeled, num_labels, sorted_labels, sorted_z_indices new_sorted_labels = [] # Store the previous label and the z index of the previous label prev_l, prev_z = 0, 0 - for l, z in zip(sorted_labels, sorted_z_indexes): + for l, z in zip(sorted_labels, sorted_z_indices): # Do not combine first and last vertebrae since it can be C1 or only contain the spinous process - if l not in sorted_labels[:2] and l != sorted_labels[-1] and prev_l > 0 and not any(z < _ < prev_z for _ in disc_sorted_z_indexes): + if l not in sorted_labels[:2] and l != sorted_labels[-1] and prev_l > 0 and not any(z < _ < prev_z for _ in disc_sorted_z_indices): # Combine the current label with the previous label mask_labeled[mask_labeled == l] = prev_l num_labels -= 1 @@ -846,29 +985,29 @@ def _merge_vertebrae_labels_with_no_disc_between( elif mask_labeled.max() < np.iinfo(np.uint16).max: mask_labeled = mask_labeled.astype(np.uint16) - # Get the z index of the center of mass for each label - canonical_mask_labeled = np.asanyarray(nib.as_closest_canonical(nib.Nifti1Image(mask_labeled, seg.affine, seg.header)).dataobj).round().astype(mask_labeled.dtype) - mask_labeled_z_indexes = [_[-1] for _ in ndi.center_of_mass(canonical_mask_labeled != 0, canonical_mask_labeled, sorted_labels)] - # Sort the labels by their z-index (reversed to go from superior to inferior) - sorted_z_indexes, sorted_labels = zip(*sorted(zip(mask_labeled_z_indexes, sorted_labels))[::-1]) + sorted_z_indices, sorted_labels = _sort_labels_si( + mask_labeled, sorted_labels, canal_centerline_indices, mask_aterior_to_canal + ) - return mask_labeled, num_labels, list(sorted_labels), list(sorted_z_indexes) + return mask_labeled, num_labels, list(sorted_labels), list(sorted_z_indices) def _merge_extra_labels_with_adjacent_vertebrae( seg, mask_labeled, num_labels, sorted_labels, - sorted_z_indexes, + sorted_z_indices, extra_labels, + canal_centerline_indices, + mask_aterior_to_canal, ): ''' Combine extra labels with adjacent vertebrae labels. This is useful for combining remaining of general vertebrae labels that introduce for region based training but not used in the final segmentation. ''' if num_labels == 0 or len(extra_labels) == 0: - return mask_labeled, num_labels, sorted_labels, sorted_z_indexes + return mask_labeled, num_labels, sorted_labels, sorted_z_indices seg_data = np.asanyarray(seg.dataobj).round().astype(np.uint8) @@ -883,14 +1022,12 @@ def _merge_extra_labels_with_adjacent_vertebrae( # Add the intersection of the mask with the extra labels to the current verebrae mask_labeled[mask_extra * mask] = sorted_labels[i] - # Get the z index of the center of mass for each label - canonical_mask_labeled = np.asanyarray(nib.as_closest_canonical(nib.Nifti1Image(mask_labeled, seg.affine, seg.header)).dataobj).round().astype(mask_labeled.dtype) - mask_labeled_z_indexes = [_[-1] for _ in ndi.center_of_mass(canonical_mask_labeled != 0, canonical_mask_labeled, sorted_labels)] - # Sort the labels by their z-index (reversed to go from superior to inferior) - sorted_z_indexes, sorted_labels = zip(*sorted(zip(mask_labeled_z_indexes, sorted_labels))[::-1]) + sorted_z_indices, sorted_labels = _sort_labels_si( + mask_labeled, sorted_labels, canal_centerline_indices, mask_aterior_to_canal + ) - return mask_labeled, num_labels, list(sorted_labels), list(sorted_z_indexes) + return mask_labeled, num_labels, list(sorted_labels), list(sorted_z_indices) def _get_landmark_output_labels( seg, From 9203640cf5bd8d95efc4c91bcd903d28d344e3d0 Mon Sep 17 00:00:00 2001 From: Yehuda Warszawer <36595323+yw7@users.noreply.github.com> Date: Sat, 21 Sep 2024 19:28:22 +0300 Subject: [PATCH 21/26] Dynamically retrieve nnUNet parameters from results Updated the script to dynamically retrieve `nnUNetTrainer`, `nnUNetPlans`, and `configuration` parameters from the results folder instead of hardcoding them. This enhances flexibility and adaptability to different environments and configurations, ensuring the code works seamlessly across various setups. --- totalspineseg/inference.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/totalspineseg/inference.py b/totalspineseg/inference.py index 6fa8400..de25935 100644 --- a/totalspineseg/inference.py +++ b/totalspineseg/inference.py @@ -125,9 +125,6 @@ def main(): nnUNet_exports.mkdir(parents=True, exist_ok=True) # Set the nnUNet variables - nnUNetTrainer = 'nnUNetTrainer_16000epochs' - nnUNetPlans = 'nnUNetPlans' - configuration = '3d_fullres_small' step1_dataset = 'Dataset101_TotalSpineSeg_step1' step2_dataset = 'Dataset102_TotalSpineSeg_step2' fold = 0 @@ -296,6 +293,8 @@ def main(): quiet=quiet, ) + # Get the nnUNet parameters from the results folder + nnUNetTrainer, nnUNetPlans, configuration = next((nnUNet_results / step1_dataset).glob('*/fold_*')).parent.name.split('__') # Check if the final checkpoint exists, if not use the latest checkpoint checkpoint = 'checkpoint_final.pth' if (nnUNet_results / step1_dataset / f'{nnUNetTrainer}__{nnUNetPlans}__{configuration}' / f'fold_{fold}' / 'checkpoint_final.pth').is_file() else 'checkpoint_latest.pth' @@ -526,6 +525,8 @@ def main(): if not f.with_name(f.name.replace('_0000.nii.gz', '_0001.nii.gz')).exists(): f.unlink() + # Get the nnUNet parameters from the results folder + nnUNetTrainer, nnUNetPlans, configuration = next((nnUNet_results / step2_dataset).glob('*/fold_*')).parent.name.split('__') # Check if the final checkpoint exists, if not use the latest checkpoint checkpoint = 'checkpoint_final.pth' if (nnUNet_results / step2_dataset / f'{nnUNetTrainer}__{nnUNetPlans}__{configuration}' / f'fold_{fold}' / 'checkpoint_final.pth').is_file() else 'checkpoint_latest.pth' From c6faf6ad30236034fc7804dc5be2d2651e45eca3 Mon Sep 17 00:00:00 2001 From: Yehuda Warszawer <36595323+yw7@users.noreply.github.com> Date: Sat, 21 Sep 2024 19:41:19 +0300 Subject: [PATCH 22/26] Refactor label mapping to match labeling algo Replaced `map_input_dict` and `map_output_dict` with more descriptive structures: `canal_labels`, `canal_output_label`, `cord_labels`, `cord_output_label`, `sacrum_labels`, and `sacrum_output_label`. --- totalspineseg/inference.py | 30 ++++++++++++++++++++++++------ 1 file changed, 24 insertions(+), 6 deletions(-) diff --git a/totalspineseg/inference.py b/totalspineseg/inference.py index de25935..64f4f7a 100644 --- a/totalspineseg/inference.py +++ b/totalspineseg/inference.py @@ -347,7 +347,12 @@ def main(): disc_labels=[1, 2, 3, 4, 5], disc_landmark_labels=[2, 3, 4, 5], disc_landmark_output_labels=[60, 70, 90, 100], - map_input_dict={6:50, 7:2, 8:2, 9:1}, + canal_labels=[7, 8], + canal_output_label=2, + cord_labels=[9], + cord_output_label=1, + sacrum_labels=[6], + sacrum_output_label=50, override=True, max_workers=max_workers, quiet=quiet, @@ -362,7 +367,12 @@ def main(): disc_landmark_labels=[2, 3, 4, 5], disc_landmark_output_labels=[60, 70, 90, 100], loc_disc_labels=list(range(60, 101)), - map_input_dict={6:50, 7:2, 8:2, 9:1}, + canal_labels=[7, 8], + canal_output_label=2, + cord_labels=[9], + cord_output_label=1, + sacrum_labels=[6], + sacrum_output_label=50, override=True, max_workers=max_workers, quiet=quiet, @@ -584,8 +594,12 @@ def main(): vertebrae_labels=[9, 10, 11, 12, 13, 14], vertebrae_landmark_output_labels=[12, 20, 40, 50], vertebrae_extra_labels=[8], - map_output_dict={17:50}, - map_input_dict={14:50, 15:2, 16:2, 17:1}, + canal_labels=[15, 16], + canal_output_label=2, + cord_labels=[17], + cord_output_label=1, + sacrum_labels=[14], + sacrum_output_label=50, override=True, max_workers=max_workers, quiet=quiet, @@ -603,8 +617,12 @@ def main(): vertebrae_landmark_output_labels=[12, 20, 40, 50], vertebrae_extra_labels=[8], loc_disc_labels=list(range(60, 101)), - map_output_dict={17:50}, - map_input_dict={14:50, 15:2, 16:2, 17:1}, + canal_labels=[15, 16], + canal_output_label=2, + cord_labels=[17], + cord_output_label=1, + sacrum_labels=[14], + sacrum_output_label=50, override=True, max_workers=max_workers, quiet=quiet, From 4b3440bc07cbc9a810b4a4111579fd6078e1d588 Mon Sep 17 00:00:00 2001 From: Yehuda Warszawer <36595323+yw7@users.noreply.github.com> Date: Sat, 21 Sep 2024 20:01:22 +0300 Subject: [PATCH 23/26] Refine help messaging for localizer segmentation Updated help messages to improve clarity on the usage of localizer segmentations in detecting vertebrae and discs. Specifically, corrected explanations related to alignment requirements and voxel-based matching, as well as refined descriptions of the folder's content. These changes enhance user understanding of the localizer-based detection process, reducing potential misconfigurations and usage errors. --- totalspineseg/inference.py | 4 ++-- totalspineseg/utils/iterative_label.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/totalspineseg/inference.py b/totalspineseg/inference.py index 64f4f7a..79420c5 100644 --- a/totalspineseg/inference.py +++ b/totalspineseg/inference.py @@ -40,8 +40,8 @@ def main(): help=' '.join(f''' Folder containing localizers segmentations or a single .nii.gz localizer segmentation to use for detecting first vertebrae and disc if C1 and C2-C3 disc or the Sacrum and L5-S disc not found in the image, Optional. This is the output of the model applied on localizer images. It can be the output of step 2, or step 1 if you only want to run step 1 (step1 flag). - The algorithm will transform the localizer to the segmentation space and use it to detect the matching vertebrae and discs. - Mathcing will based on the magority of the voxels of the first vertebrae or disc in the localizer, that intersect with image. + The algorithm will use the localizers' segmentations to detect the matching vertebrae and discs. The localizer and the image must be aligned. + Matching will based on the majority of the voxels of the first vertebra or disc in the localizer, that intersect with image. The file names should be in match with the image file names, or you can use the --suffix and --loc-suffix to match the files. '''.split()) ) diff --git a/totalspineseg/utils/iterative_label.py b/totalspineseg/utils/iterative_label.py index 9dc0422..0f2076c 100644 --- a/totalspineseg/utils/iterative_label.py +++ b/totalspineseg/utils/iterative_label.py @@ -41,7 +41,7 @@ def main(): parser.add_argument( '--locs-dir', '-l', type=Path, default=None, help=' '.join(f''' - Folder containing localizers segmentations to use for detecting first vertebrae and disc if init label not found, Optional. + Folder containing localizers' segmentations to help the labeling if landmarks not found, Optional. The algorithm will transform the localizer to the segmentation space and use it to detect the matching vertebrae and disc if the init label not found. Mathcing will based on the magority of the voxels of the first vertebrae or disc in the localizer, that intersect with the input segmentation. '''.split()) From 4eab269ac2ce2276f664e1296427f54b0c2a3c2e Mon Sep 17 00:00:00 2001 From: Yehuda Warszawer <36595323+yw7@users.noreply.github.com> Date: Sat, 21 Sep 2024 20:19:16 +0300 Subject: [PATCH 24/26] Add support for C1 vertebra label in extract_levels Enhanced the spine segmentation utility to optionally include a label for the C1 vertebra. This allows the application to accurately identify and process C1 if provided, improving segmentation versatility and accuracy. Adjusted the logic to handle cases where C1 is present in the data. This new feature is controlled via an additional command-line argument. --- totalspineseg/utils/extract_levels.py | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/totalspineseg/utils/extract_levels.py b/totalspineseg/utils/extract_levels.py index 3262171..2fef326 100644 --- a/totalspineseg/utils/extract_levels.py +++ b/totalspineseg/utils/extract_levels.py @@ -65,6 +65,10 @@ def main(): '--disc-labels', type=lambda x:list(range(int(x.split('-')[0]), int(x.split('-')[-1]) + 1)), nargs='+', required=True, help='The disc labels starting at C2C3 ordered from superior to inferior.' ) + parser.add_argument( + '--c1-label', type=int, default=0, + help='The label for C1 vertebra in the segmentation, if provided it will be used to determine if C1 is in the segmentation.' + ) parser.add_argument( '--override', '-r', action="store_true", default=False, help='Override existing output files, defaults to false (Do not override).' @@ -91,6 +95,7 @@ def main(): output_seg_suffix = args.output_seg_suffix canal_labels = args.canal_labels disc_labels = [_ for __ in args.disc_labels for _ in (__ if isinstance(__, list) else [__])] + c1_label = args.c1_label override = args.override max_workers = args.max_workers quiet = args.quiet @@ -108,6 +113,7 @@ def main(): output_seg_suffix = "{output_seg_suffix}" canal_labels = {canal_labels} disc_labels = {disc_labels} + c1_label = {c1_label} override = {override} max_workers = {max_workers} quiet = {quiet} @@ -123,6 +129,7 @@ def main(): output_seg_suffix=output_seg_suffix, canal_labels=canal_labels, disc_labels=disc_labels, + c1_label=c1_label, override=override, max_workers=max_workers, quiet=quiet, @@ -138,6 +145,7 @@ def extract_levels_mp( output_seg_suffix='', canal_labels=[], disc_labels=[], + c1_label=0, override=False, max_workers=mp.cpu_count(), quiet=False, @@ -164,6 +172,7 @@ def extract_levels_mp( _extract_levels, canal_labels=canal_labels, disc_labels=disc_labels, + c1_label=c1_label, override=override, ), seg_path_list, @@ -178,6 +187,7 @@ def _extract_levels( output_seg_path, canal_labels=[], disc_labels=[], + c1_label=0, override=False, ): ''' @@ -197,7 +207,8 @@ def _extract_levels( output_seg = extract_levels( seg, canal_labels=canal_labels, - disc_labels=disc_labels + disc_labels=disc_labels, + c1_label=c1_label, ) except ValueError as e: output_seg_path.is_file() and output_seg_path.unlink() @@ -221,6 +232,7 @@ def extract_levels( seg, canal_labels=[], disc_labels=[], + c1_label=0, ): ''' Extract vertebrae levels from Spinal Canal and Discs. @@ -236,6 +248,8 @@ def extract_levels( The canal labels. disc_labels : list The disc labels starting at C2C3 ordered from superior to inferior. + c1_label : int + The label for C1 vertebra in the segmentation, if provided it will be used to determine if C1 is in the segmentation. Returns ------- @@ -308,8 +322,8 @@ def extract_levels( # Find the location of the superior voxels in the canal centerline canal_superior_index = np.unravel_index(np.argmax(mask_canal_centerline * indices[2]), seg_data.shape) - if canal_superior_index[2] - c2c3_index[2] >= 8 and output_seg_data.shape[2] - canal_superior_index[2] >= 2: - # If C2-C3 at least 8 voxels below the top of the canal and the top of the canal is at least 2 voxels from the top of the image + if (c1_label > 0 and c1_label in seg_data) or (c1_label == 0 and canal_superior_index[2] - c2c3_index[2] >= 8 and output_seg_data.shape[2] - canal_superior_index[2] >= 2): + # If C1 is in the segmentation or C2-C3 at least 8 voxels below the top of the canal and the top of the canal is at least 2 voxels from the top of the image # Set 1 to the superior voxels output_seg_data[canal_superior_index] = 1 From 7b890b2230348aa4748c66e49b5fc421f8a84d5e Mon Sep 17 00:00:00 2001 From: Yehuda Warszawer <36595323+yw7@users.noreply.github.com> Date: Sat, 21 Sep 2024 21:21:44 +0300 Subject: [PATCH 25/26] Revised spine landmark labels for consistency Updated the labeling scheme for vertebrae and discs in README and inference scripts. Labels now follow a sequential and standardized format, improving clarity and consistency across the codebase and documentation. Adjusted ranges and references to reflect the new label assignments accurately. --- README.md | 92 +++++++++++++------------- totalspineseg/inference.py | 60 ++++++++--------- totalspineseg/utils/iterative_label.py | 10 +-- 3 files changed, 81 insertions(+), 81 deletions(-) diff --git a/README.md b/README.md index 401425c..c491ade 100644 --- a/README.md +++ b/README.md @@ -224,51 +224,51 @@ For a more detailed view of the output examples, you can check the [PDF version] |:------|:-----| | 1 | spinal_cord | | 2 | spinal_canal | -| 10 | vertebrae_C1 | -| 11 | vertebrae_C2 | -| 12 | vertebrae_C3 | -| 13 | vertebrae_C4 | -| 14 | vertebrae_C5 | -| 15 | vertebrae_C6 | -| 16 | vertebrae_C7 | -| 20 | vertebrae_T1 | -| 21 | vertebrae_T2 | -| 22 | vertebrae_T3 | -| 23 | vertebrae_T4 | -| 24 | vertebrae_T5 | -| 25 | vertebrae_T6 | -| 26 | vertebrae_T7 | -| 27 | vertebrae_T8 | -| 28 | vertebrae_T9 | -| 29 | vertebrae_T10 | -| 30 | vertebrae_T11 | -| 31 | vertebrae_T12 | -| 40 | vertebrae_L1 | -| 41 | vertebrae_L2 | -| 42 | vertebrae_L3 | -| 43 | vertebrae_L4 | -| 44 | vertebrae_L5 | +| 11 | vertebrae_C1 | +| 12 | vertebrae_C2 | +| 13 | vertebrae_C3 | +| 14 | vertebrae_C4 | +| 15 | vertebrae_C5 | +| 16 | vertebrae_C6 | +| 17 | vertebrae_C7 | +| 21 | vertebrae_T1 | +| 22 | vertebrae_T2 | +| 23 | vertebrae_T3 | +| 24 | vertebrae_T4 | +| 25 | vertebrae_T5 | +| 26 | vertebrae_T6 | +| 27 | vertebrae_T7 | +| 28 | vertebrae_T8 | +| 29 | vertebrae_T9 | +| 30 | vertebrae_T10 | +| 31 | vertebrae_T11 | +| 32 | vertebrae_T12 | +| 41 | vertebrae_L1 | +| 42 | vertebrae_L2 | +| 43 | vertebrae_L3 | +| 44 | vertebrae_L4 | +| 45 | vertebrae_L5 | | 50 | sacrum | -| 60 | disc_C2_C3 | -| 61 | disc_C3_C4 | -| 62 | disc_C4_C5 | -| 63 | disc_C5_C6 | -| 64 | disc_C6_C7 | -| 70 | disc_C7_T1 | -| 71 | disc_T1_T2 | -| 72 | disc_T2_T3 | -| 73 | disc_T3_T4 | -| 74 | disc_T4_T5 | -| 75 | disc_T5_T6 | -| 76 | disc_T6_T7 | -| 77 | disc_T7_T8 | -| 78 | disc_T8_T9 | -| 79 | disc_T9_T10 | -| 80 | disc_T10_T11 | -| 81 | disc_T11_T12 | -| 90 | disc_T12_L1 | -| 91 | disc_L1_L2 | -| 92 | disc_L2_L3 | -| 93 | disc_L3_L4 | -| 94 | disc_L4_L5 | +| 63 | disc_C2_C3 | +| 64 | disc_C3_C4 | +| 65 | disc_C4_C5 | +| 66 | disc_C5_C6 | +| 67 | disc_C6_C7 | +| 71 | disc_C7_T1 | +| 72 | disc_T1_T2 | +| 73 | disc_T2_T3 | +| 74 | disc_T3_T4 | +| 75 | disc_T4_T5 | +| 76 | disc_T5_T6 | +| 77 | disc_T6_T7 | +| 78 | disc_T7_T8 | +| 79 | disc_T8_T9 | +| 80 | disc_T9_T10 | +| 81 | disc_T10_T11 | +| 82 | disc_T11_T12 | +| 91 | disc_T12_L1 | +| 92 | disc_L1_L2 | +| 93 | disc_L2_L3 | +| 94 | disc_L3_L4 | +| 95 | disc_L4_L5 | | 100 | disc_L5_S | diff --git a/totalspineseg/inference.py b/totalspineseg/inference.py index 79420c5..aceb487 100644 --- a/totalspineseg/inference.py +++ b/totalspineseg/inference.py @@ -239,16 +239,16 @@ def main(): max_workers=max_workers, quiet=quiet, label_texts_right={ - 10: 'C1', 11: 'C2', 12: 'C3', 13: 'C4', 14: 'C5', 15: 'C6', 16: 'C7', - 20: 'T1', 21: 'T2', 22: 'T3', 23: 'T4', 24: 'T5', 25: 'T6', 26: 'T7', - 27: 'T8', 28: 'T9', 29: 'T10', 30: 'T11', 31: 'T12', - 40: 'L1', 41: 'L2', 42: 'L3', 43: 'L4', 44: 'L5' + 11: 'C1', 12: 'C2', 13: 'C3', 14: 'C4', 15: 'C5', 16: 'C6', 17: 'C7', + 21: 'T1', 22: 'T2', 23: 'T3', 24: 'T4', 25: 'T5', 26: 'T6', 27: 'T7', + 28: 'T8', 29: 'T9', 30: 'T10', 31: 'T11', 32: 'T12', + 41: 'L1', 42: 'L2', 43: 'L3', 44: 'L4', 45: 'L5', }, label_texts_left={ - 50: 'Sacrum', 60: 'C2C3', 61: 'C3C4', 62: 'C4C5', 63: 'C5C6', 64: 'C6C7', 70: 'C7T1', - 71: 'T1T2', 72: 'T2T3', 73: 'T3T4', 74: 'T4T5', 75: 'T5T6', 76: 'T6T7', 77: 'T7T8', - 78: 'T8T9', 79: 'T9T10', 80: 'T10T11', 81: 'T11T12', 90: 'T12L1', - 91: 'L1L2', 92: 'L2L3', 93: 'L3L4', 94: 'L4L5', 100: 'L5S' + 50: 'Sacrum', 63: 'C2C3', 64: 'C3C4', 65: 'C4C5', 66: 'C5C6', 67: 'C6C7', + 71: 'C7T1', 72: 'T1T2', 73: 'T2T3', 74: 'T3T4', 75: 'T4T5', 76: 'T5T6', 77: 'T6T7', + 78: 'T7T8', 79: 'T8T9', 80: 'T9T10', 81: 'T10T11', 82: 'T11T12', + 91: 'T12L1', 92: 'L1L2', 93: 'L2L3', 94: 'L3L4', 95: 'L4L5', 100: 'L5S' }, ) @@ -346,7 +346,7 @@ def main(): selected_disc_landmarks=[2, 5, 3, 4], disc_labels=[1, 2, 3, 4, 5], disc_landmark_labels=[2, 3, 4, 5], - disc_landmark_output_labels=[60, 70, 90, 100], + disc_landmark_output_labels=[63, 71, 91, 100], canal_labels=[7, 8], canal_output_label=2, cord_labels=[9], @@ -365,8 +365,8 @@ def main(): selected_disc_landmarks=[2, 5], disc_labels=[1, 2, 3, 4, 5], disc_landmark_labels=[2, 3, 4, 5], - disc_landmark_output_labels=[60, 70, 90, 100], - loc_disc_labels=list(range(60, 101)), + disc_landmark_output_labels=[63, 71, 91, 100], + loc_disc_labels=list(range(63, 101)), canal_labels=[7, 8], canal_output_label=2, cord_labels=[9], @@ -421,10 +421,10 @@ def main(): max_workers=max_workers, quiet=quiet, label_texts_left={ - 60: 'C2C3', 61: 'C3C4', 62: 'C4C5', 63: 'C5C6', 64: 'C6C7', 70: 'C7T1', - 71: 'T1T2', 72: 'T2T3', 73: 'T3T4', 74: 'T4T5', 75: 'T5T6', 76: 'T6T7', 77: 'T7T8', - 78: 'T8T9', 79: 'T9T10', 80: 'T10T11', 81: 'T11T12', 90: 'T12L1', - 91: 'L1L2', 92: 'L2L3', 93: 'L3L4', 94: 'L4L5', 100: 'L5S' + 63: 'C2C3', 64: 'C3C4', 65: 'C4C5', 66: 'C5C6', 67: 'C6C7', + 71: 'C7T1', 72: 'T1T2', 73: 'T2T3', 74: 'T3T4', 75: 'T4T5', 76: 'T5T6', 77: 'T6T7', + 78: 'T7T8', 79: 'T8T9', 80: 'T9T10', 81: 'T10T11', 82: 'T11T12', + 91: 'T12L1', 92: 'L1L2', 93: 'L2L3', 94: 'L3L4', 95: 'L4L5', 100: 'L5S' }, ) @@ -465,7 +465,7 @@ def main(): output_path / 'step1_output', output_path / 'step1_levels', canal_labels=[1, 2], - disc_labels=list(range(60, 65)) + list(range(70, 82)) + list(range(90, 95)) + [100], + disc_labels=list(range(63, 68)) + list(range(71, 83)) + list(range(91, 96)) + [100], override=True, max_workers=max_workers, quiet=quiet, @@ -512,7 +512,7 @@ def main(): output_path / 'step2_input', seg_suffix='_0001', output_seg_suffix='_0001', - labels=list(range(60, 101)), + labels=list(range(63, 101)), override=True, max_workers=max_workers, quiet=quiet, @@ -590,9 +590,9 @@ def main(): selected_disc_landmarks=[4, 7, 5, 6], disc_labels=[1, 2, 3, 4, 5, 6, 7], disc_landmark_labels=[4, 5, 6, 7], - disc_landmark_output_labels=[60, 70, 90, 100], + disc_landmark_output_labels=[63, 71, 91, 100], vertebrae_labels=[9, 10, 11, 12, 13, 14], - vertebrae_landmark_output_labels=[12, 20, 40, 50], + vertebrae_landmark_output_labels=[13, 21, 41, 50], vertebrae_extra_labels=[8], canal_labels=[15, 16], canal_output_label=2, @@ -612,11 +612,11 @@ def main(): selected_disc_landmarks=[4, 7], disc_labels=[1, 2, 3, 4, 5, 6, 7], disc_landmark_labels=[4, 5, 6, 7], - disc_landmark_output_labels=[60, 70, 90, 100], + disc_landmark_output_labels=[63, 71, 91, 100], vertebrae_labels=[9, 10, 11, 12, 13, 14], - vertebrae_landmark_output_labels=[12, 20, 40, 50], + vertebrae_landmark_output_labels=[13, 21, 41, 50], vertebrae_extra_labels=[8], - loc_disc_labels=list(range(60, 101)), + loc_disc_labels=list(range(63, 101)), canal_labels=[15, 16], canal_output_label=2, cord_labels=[17], @@ -671,16 +671,16 @@ def main(): max_workers=max_workers, quiet=quiet, label_texts_right={ - 10: 'C1', 11: 'C2', 12: 'C3', 13: 'C4', 14: 'C5', 15: 'C6', 16: 'C7', - 20: 'T1', 21: 'T2', 22: 'T3', 23: 'T4', 24: 'T5', 25: 'T6', 26: 'T7', - 27: 'T8', 28: 'T9', 29: 'T10', 30: 'T11', 31: 'T12', - 40: 'L1', 41: 'L2', 42: 'L3', 43: 'L4', 44: 'L5' + 11: 'C1', 12: 'C2', 13: 'C3', 14: 'C4', 15: 'C5', 16: 'C6', 17: 'C7', + 21: 'T1', 22: 'T2', 23: 'T3', 24: 'T4', 25: 'T5', 26: 'T6', 27: 'T7', + 28: 'T8', 29: 'T9', 30: 'T10', 31: 'T11', 32: 'T12', + 41: 'L1', 42: 'L2', 43: 'L3', 44: 'L4', 45: 'L5', }, label_texts_left={ - 50: 'Sacrum', 60: 'C2C3', 61: 'C3C4', 62: 'C4C5', 63: 'C5C6', 64: 'C6C7', 70: 'C7T1', - 71: 'T1T2', 72: 'T2T3', 73: 'T3T4', 74: 'T4T5', 75: 'T5T6', 76: 'T6T7', 77: 'T7T8', - 78: 'T8T9', 79: 'T9T10', 80: 'T10T11', 81: 'T11T12', 90: 'T12L1', - 91: 'L1L2', 92: 'L2L3', 93: 'L3L4', 94: 'L4L5', 100: 'L5S' + 50: 'Sacrum', 63: 'C2C3', 64: 'C3C4', 65: 'C4C5', 66: 'C5C6', 67: 'C6C7', + 71: 'C7T1', 72: 'T1T2', 73: 'T2T3', 74: 'T3T4', 75: 'T4T5', 76: 'T5T6', 77: 'T6T7', + 78: 'T7T8', 79: 'T8T9', 80: 'T9T10', 81: 'T10T11', 82: 'T11T12', + 91: 'T12L1', 92: 'L1L2', 93: 'L2L3', 94: 'L3L4', 95: 'L4L5', 100: 'L5S' }, ) diff --git a/totalspineseg/utils/iterative_label.py b/totalspineseg/utils/iterative_label.py index 0f2076c..7e33354 100644 --- a/totalspineseg/utils/iterative_label.py +++ b/totalspineseg/utils/iterative_label.py @@ -20,12 +20,12 @@ def main(): '''.split()), epilog=textwrap.dedent(''' Examples: - iterative_label -s labels_init -o labels --selected-disc-landmarks 2 5 3 4 --disc-labels 1-5 --disc-landmark-labels 2 3 4 5 --disc-landmark-output-labels 60 70 90 100 --canal-labels 7 8 --canal-output-label 2 --cord-labels 9 --cord-output-label 1 -r - iterative_label -s labels_init -o labels --selected-disc-landmarks 2 5 3 4 --disc-labels 1-7 --disc-landmark-labels 4 5 6 7 --disc-landmark-output-labels 60 70 90 100 --vertebrae-labels 9-14 --vertebrae-landmark-output-labels 12 20 40 50 --vertebrae-extra-labels 8 --canal-labels 15 16 --canal-output-label 2 --cord-labels 17 --cord-output-label 1 --sacrum-labels 14 --sacrum-output-label 50 -r - iterative_label -s labels_init -o labels -l localizers --selected-disc-landmarks 2 5 --disc-labels 1-5 --disc-landmark-labels 2 3 4 5 --disc-landmark-output-labels 60 70 90 100 --canal-labels 7 8 --canal-output-label 2 --cord-labels 9 --cord-output-label 1 --loc-disc-labels 60-100 -r - iterative_label -s labels_init -o labels -l localizers --selected-disc-landmarks 4 7 --disc-labels 1-7 --disc-landmark-labels 4 5 6 7 --disc-landmark-output-labels 60 70 90 100 --vertebrae-labels 9-14 --vertebrae-landmark-output-labels 12 20 40 50 --vertebrae-extra-labels 8 --canal-labels 15 16 --canal-output-label 2 --cord-labels 17 --cord-output-label 1 --sacrum-labels 14 --sacrum-output-label 50 --loc-disc-labels 60-100 -r + iterative_label -s labels_init -o labels --selected-disc-landmarks 2 5 3 4 --disc-labels 1-5 --disc-landmark-labels 2 3 4 5 --disc-landmark-output-labels 63 71 91 100 --canal-labels 7 8 --canal-output-label 2 --cord-labels 9 --cord-output-label 1 -r + iterative_label -s labels_init -o labels --selected-disc-landmarks 2 5 3 4 --disc-labels 1-7 --disc-landmark-labels 4 5 6 7 --disc-landmark-output-labels 63 71 91 100 --vertebrae-labels 9-14 --vertebrae-landmark-output-labels 13 21 41 50 --vertebrae-extra-labels 8 --canal-labels 15 16 --canal-output-label 2 --cord-labels 17 --cord-output-label 1 --sacrum-labels 14 --sacrum-output-label 50 -r + iterative_label -s labels_init -o labels -l localizers --selected-disc-landmarks 2 5 --disc-labels 1-5 --disc-landmark-labels 2 3 4 5 --disc-landmark-output-labels 63 71 91 100 --canal-labels 7 8 --canal-output-label 2 --cord-labels 9 --cord-output-label 1 --loc-disc-labels 63-100 -r + iterative_label -s labels_init -o labels -l localizers --selected-disc-landmarks 4 7 --disc-labels 1-7 --disc-landmark-labels 4 5 6 7 --disc-landmark-output-labels 63 71 91 100 --vertebrae-labels 9-14 --vertebrae-landmark-output-labels 13 21 41 50 --vertebrae-extra-labels 8 --canal-labels 15 16 --canal-output-label 2 --cord-labels 17 --cord-output-label 1 --sacrum-labels 14 --sacrum-output-label 50 --loc-disc-labels 63-100 -r For BIDS: - iterative_label -s derivatives/labels -o derivatives/labels --seg-suffix "_seg" --output-seg-suffix "_seg_seq" -d "sub-" -u "anat" --selected-disc-landmarks 2 5 3 4 --disc-labels 1-5 --disc-landmark-labels 2 3 4 5 --disc-landmark-output-labels 60 70 90 100 --canal-labels 7 8 --canal-output-label 2 --cord-labels 9 --cord-output-label 1 -r + iterative_label -s derivatives/labels -o derivatives/labels --seg-suffix "_seg" --output-seg-suffix "_seg_seq" -d "sub-" -u "anat" --selected-disc-landmarks 2 5 3 4 --disc-labels 1-5 --disc-landmark-labels 2 3 4 5 --disc-landmark-output-labels 63 71 91 100 --canal-labels 7 8 --canal-output-label 2 --cord-labels 9 --cord-output-label 1 -r '''), formatter_class=argparse.RawTextHelpFormatter ) From dbad232ee17bb419a20a2c570c530d7e30273a0d Mon Sep 17 00:00:00 2001 From: Yehuda Warszawer <36595323+yw7@users.noreply.github.com> Date: Sat, 21 Sep 2024 21:32:49 +0300 Subject: [PATCH 26/26] Simplify axis index array creation Refactored index array creation for x, y, z axes using `np.indices` to streamline code and reduce redundancy. This enhances readability and maintenance, reducing the risk of errors associated with broadcasting and manual index array creation. No functional changes to the logic were introduced. --- totalspineseg/utils/iterative_label.py | 27 +++++++++++--------------- 1 file changed, 11 insertions(+), 16 deletions(-) diff --git a/totalspineseg/utils/iterative_label.py b/totalspineseg/utils/iterative_label.py index 7e33354..361b85a 100644 --- a/totalspineseg/utils/iterative_label.py +++ b/totalspineseg/utils/iterative_label.py @@ -1120,25 +1120,20 @@ def _fill(mask): np.ndarray Binary mask with holes filled ''' + # Get array of indices for x, y, and z axes + indices = np.indices(mask.shape) - # Create an array of x indices with the same shape as the mask - x_indices = np.broadcast_to(np.arange(mask.shape[0])[..., np.newaxis, np.newaxis], mask.shape) - # Create an array of y indices with the same shape as the mask - y_indices = np.broadcast_to(np.arange(mask.shape[1])[..., np.newaxis], mask.shape) - # Create an array of z indices with the same shape as the mask - z_indices = np.broadcast_to(np.arange(mask.shape[2]), mask.shape) - - mask_min_x = np.min(np.where(mask, x_indices, np.inf), axis=0)[np.newaxis, ...] - mask_max_x = np.max(np.where(mask, x_indices, -np.inf), axis=0)[np.newaxis, ...] - mask_min_y = np.min(np.where(mask, y_indices, np.inf), axis=1)[:, np.newaxis, :] - mask_max_y = np.max(np.where(mask, y_indices, -np.inf), axis=1)[:, np.newaxis, :] - mask_min_z = np.min(np.where(mask, z_indices, np.inf), axis=2)[:, :, np.newaxis] - mask_max_z = np.max(np.where(mask, z_indices, -np.inf), axis=2)[:, :, np.newaxis] + mask_min_x = np.min(np.where(mask, indices[0], np.inf), axis=0)[np.newaxis, ...] + mask_max_x = np.max(np.where(mask, indices[0], -np.inf), axis=0)[np.newaxis, ...] + mask_min_y = np.min(np.where(mask, indices[1], np.inf), axis=1)[:, np.newaxis, :] + mask_max_y = np.max(np.where(mask, indices[1], -np.inf), axis=1)[:, np.newaxis, :] + mask_min_z = np.min(np.where(mask, indices[2], np.inf), axis=2)[:, :, np.newaxis] + mask_max_z = np.max(np.where(mask, indices[2], -np.inf), axis=2)[:, :, np.newaxis] return \ - ((mask_min_x <= x_indices) & (x_indices <= mask_max_x)) | \ - ((mask_min_y <= y_indices) & (y_indices <= mask_max_y)) | \ - ((mask_min_z <= z_indices) & (z_indices <= mask_max_z)) + ((mask_min_x <= indices[0]) & (indices[0] <= mask_max_x)) | \ + ((mask_min_y <= indices[1]) & (indices[1] <= mask_max_y)) | \ + ((mask_min_z <= indices[2]) & (indices[2] <= mask_max_z)) if __name__ == '__main__': main() \ No newline at end of file