From 253ce5a59da858f5281155aced8ecb29fbf4d8be Mon Sep 17 00:00:00 2001 From: Yehuda Warszawer <36595323+yw7@users.noreply.github.com> Date: Wed, 11 Sep 2024 19:50:24 +0300 Subject: [PATCH 1/3] Refactor iterative labeling logic init from discs only Reorganized the spine labeling logic by separating the functionality into distinct functions for better readability and maintainability. Removed `init_vertebrae`, `step_diff_label`, `step_diff_disc`, and related parameters to streamline the process. Added new parameters `output_c2c3` and `output_c2` for setting specific output labels. This refactor facilitates easier adjustments and enhancements in the future, improving code usability. Resolves issue #50 --- totalspineseg/inference.py | 11 +- totalspineseg/utils/iterative_label.py | 572 +++++++++++++------------ 2 files changed, 313 insertions(+), 270 deletions(-) diff --git a/totalspineseg/inference.py b/totalspineseg/inference.py index 86af7e5..c6ee540 100644 --- a/totalspineseg/inference.py +++ b/totalspineseg/inference.py @@ -547,11 +547,10 @@ def main(): vertebrae_labels=[9, 10, 11, 12, 13, 14], vertebrae_extra_labels=[8], init_disc={4:224, 7:202, 5:219, 6:207}, - init_vertebrae={11:40, 14:17, 12:34, 13:23}, - step_diff_label=True, - step_diff_disc=True, output_disc_step=-1, output_vertebrae_step=-1, + output_c2c3=224, + output_c2=40, map_output_dict={17:92}, map_input_dict={14:92, 15:201, 16:201, 17:200}, override=True, @@ -567,13 +566,11 @@ def main(): vertebrae_labels=[9, 10, 11, 12, 13, 14], vertebrae_extra_labels=[8], init_disc={4:224, 7:202}, - init_vertebrae={11:40, 14:17}, loc_disc_labels=list(range(202, 225)), - loc_vertebrae_labels=list(range(18, 42)) + [92], - step_diff_label=True, - step_diff_disc=True, output_disc_step=-1, output_vertebrae_step=-1, + output_c2c3=224, + output_c2=40, map_output_dict={17:92}, map_input_dict={14:92, 15:201, 16:201, 17:200}, override=True, diff --git a/totalspineseg/utils/iterative_label.py 
b/totalspineseg/utils/iterative_label.py index 7f98196..1997088 100644 --- a/totalspineseg/utils/iterative_label.py +++ b/totalspineseg/utils/iterative_label.py @@ -20,10 +20,10 @@ def main(): '''.split()), epilog=textwrap.dedent(''' Examples: - iterative_label -s labels_init -o labels --disc-labels 1-7 --vertebrae-labels 9-14 --vertebrae-extra-labels 8 --init-disc 4:224 7:202 5:219 6:207 --init-vertebrae 11:40 14:17 12:34 13:23 --step-diff-label --step-diff-disc --output-disc-step -1 --output-vertebrae-step -1 --map-output 17:92 --map-input 14:92 16:201 17:200 -r - iterative_label -s labels_init -o labels -l localizers --disc-labels 1-7 --vertebrae-labels 9-14 --vertebrae-extra-labels 8 --init-disc 4:224 7:202 --init-vertebrae 11:40 14:17 --step-diff-label --step-diff-disc --output-disc-step -1 --output-vertebrae-step -1 --loc-disc-labels 202-224 --loc-vertebrae-labels 18-41 92 --map-output 17:92 --map-input 14:92 16:201 17:200 -r + iterative_label -s labels_init -o labels --disc-labels 1-7 --vertebrae-labels 9-14 --vertebrae-extra-labels 8 --init-disc 4:224 7:202 5:219 6:207 --output-disc-step -1 --output-vertebrae-step -1 --map-output 17:92 --map-input 14:92 16:201 17:200 --output-c2c3 224 --output-c2 40 -r + iterative_label -s labels_init -o labels -l localizers --disc-labels 1-7 --vertebrae-labels 9-14 --vertebrae-extra-labels 8 --init-disc 4:224 7:202 --output-disc-step -1 --output-vertebrae-step -1 --loc-disc-labels 202-224 --map-output 17:92 --map-input 14:92 16:201 17:200 --output-c2c3 224 --output-c2 40 -r For BIDS: - iterative_label -s derivatives/labels -o derivatives/labels --seg-suffix "_seg" --output-seg-suffix "_seg_seq" -d "sub-" -u "anat" --disc-labels 1 2 3 4 5 6 7 --vertebrae-labels 9 10 11 12 13 14 --vertebrae-extra-labels 8 --init-disc 4:224 7:202 5:219 6:207 --init-vertebrae 11:40 14:17 12:34 13:23 --step-diff-label --step-diff-disc --output-disc-step -1 --output-vertebrae-step -1 --map-output 17:92 --map-input 14:92 16:201 17:200 -r + 
iterative_label -s derivatives/labels -o derivatives/labels --seg-suffix "_seg" --output-seg-suffix "_seg_seq" -d "sub-" -u "anat" --disc-labels 1 2 3 4 5 6 7 --vertebrae-labels 9 10 11 12 13 14 --vertebrae-extra-labels 8 --init-disc 4:224 7:202 5:219 6:207 --output-disc-step -1 --output-vertebrae-step -1 --map-output 17:92 --map-input 14:92 16:201 17:200 --output-c2c3 224 --output-c2 40 -r '''), formatter_class=argparse.RawTextHelpFormatter ) @@ -80,6 +80,10 @@ def main(): '--init-disc', type=lambda x:map(int, x.split(':')), nargs='+', default=[], help='Init labels list for disc ordered by priority (input_label:output_label !!without space!!). for example 4:224 5:219 6:202' ) + parser.add_argument( + '--output-c2c3', type=int, default=0, + help='The output label for C2C3, used to calculate the first vertebrae label, defaults to 0.' + ) parser.add_argument( '--output-disc-step', type=int, default=1, help='The step to take between disc labels in the output, defaults to 1.' @@ -97,17 +101,13 @@ def main(): help='Extra vertebrae labels to add to add to adjacent vertebrae labels.' ) parser.add_argument( - '--init-vertebrae', type=lambda x:map(int, x.split(':')), nargs='+', default=[], - help='Init labels list for vertebrae ordered by priority (input_label:output_label !!without space!!). for example 10:41 11:34 12:18' + '--output-c2', type=int, default=0, + help='The output label for C2, used to calculate the first vertebrae label, defaults to 0.' ) parser.add_argument( '--output-vertebrae-step', type=int, default=1, help='The step to take between vertebrae labels in the output, defaults to 1.' ) - parser.add_argument( - '--loc-vertebrae-labels', type=lambda x:list(range(int(x.split('-')[0]), int(x.split('-')[-1]) + 1)), nargs='+', default=[], - help='The disc labels in the localizer used for detecting first vertebrae.' 
- ) parser.add_argument( '--map-input', type=str, nargs='+', default=[], help=' '.join(f''' @@ -128,29 +128,10 @@ def main(): '--dilation-size', type=int, default=1, help='Number of voxels to dilate before finding connected voxels to label, defaults to 1 (No dilation).' ) - parser.add_argument( - '--step-diff-label', action="store_true", default=False, - help=' '.join(f''' - Make step only for different labels. When looping on the labels on the z axis, it will give a new label to the next label only - if it is different from the previous label. This is useful if there are labels for odd and even vertebrae, so the next label will - be for even vertebrae only if the previous label was odd. If it is still odd, it should give the same label. - '''.split()) - ) - parser.add_argument( - '--step-diff-disc', action="store_true", default=False, - help=' '.join(f''' - Make step only for different discs. When looping on the labels on the z axis, it will give a new label to the next label only - if there is a disc between them. This exclude the first and last vertebrae since it can be C1 or only contain the spinous process. - '''.split()) - ) parser.add_argument( '--default-superior-disc', type=int, default=0, help='Default superior disc label if no init label found, defaults to 0 (Raise error if init label not found).' ) - parser.add_argument( - '--default-superior-vertebrae', type=int, default=0, - help='Default superior vertebrae label if no init label found, defaults to 0 (Raise error if init label not found).' - ) parser.add_argument( '--override', '-r', action="store_true", default=False, help='Override existing output files, defaults to false (Do not override).' 
@@ -179,20 +160,17 @@ def main(): loc_suffix = args.loc_suffix disc_labels = [_ for __ in args.disc_labels for _ in (__ if isinstance(__, list) else [__])] init_disc = dict(args.init_disc) + output_c2c3 = args.output_c2c3 output_disc_step = args.output_disc_step loc_disc_labels = [_ for __ in args.loc_disc_labels for _ in (__ if isinstance(__, list) else [__])] vertebrae_labels = [_ for __ in args.vertebrae_labels for _ in (__ if isinstance(__, list) else [__])] vertebrae_extra_labels = [_ for __ in args.vertebrae_extra_labels for _ in (__ if isinstance(__, list) else [__])] - init_vertebrae = dict(args.init_vertebrae) + output_c2 = args.output_c2 output_vertebrae_step = args.output_vertebrae_step - loc_vertebrae_labels = [_ for __ in args.loc_vertebrae_labels for _ in (__ if isinstance(__, list) else [__])] map_input_list = args.map_input map_output_list = args.map_output dilation_size = args.dilation_size - step_diff_label = args.step_diff_label - step_diff_disc = args.step_diff_disc default_superior_disc = args.default_superior_disc - default_superior_vertebrae = args.default_superior_vertebrae override = args.override max_workers = args.max_workers quiet = args.quiet @@ -212,20 +190,17 @@ def main(): loc_suffix = "{loc_suffix}" disc_labels = {disc_labels} init_disc = {init_disc} + output_c2c3 = {output_c2c3} output_disc_step = {output_disc_step} loc_disc_labels = {loc_disc_labels} vertebrae_labels = {vertebrae_labels} vertebrae_extra_labels = {vertebrae_extra_labels} - init_vertebrae = {init_vertebrae} + output_c2 = {output_c2} output_vertebrae_step = {output_vertebrae_step} - loc_vertebrae_labels = {loc_vertebrae_labels} map_input = {map_input_list} map_output = {map_output_list} dilation_size = {dilation_size} - step_diff_label = {step_diff_label} - step_diff_disc = {step_diff_disc} default_superior_disc = {default_superior_disc} - default_superior_vertebrae = {default_superior_vertebrae} override = {override} max_workers = {max_workers} quiet = {quiet} @@ 
-254,20 +229,17 @@ def main(): loc_suffix=loc_suffix, disc_labels=disc_labels, init_disc=init_disc, + output_c2c3=output_c2c3, output_disc_step=output_disc_step, loc_disc_labels=loc_disc_labels, vertebrae_labels=vertebrae_labels, vertebrae_extra_labels=vertebrae_extra_labels, - init_vertebrae=init_vertebrae, + output_c2=output_c2, output_vertebrae_step=output_vertebrae_step, - loc_vertebrae_labels=loc_vertebrae_labels, map_input_dict=map_input_dict, map_output_dict=map_output_dict, dilation_size=dilation_size, - step_diff_label=step_diff_label, - step_diff_disc=step_diff_disc, default_superior_disc=default_superior_disc, - default_superior_vertebrae=default_superior_vertebrae, override=override, max_workers=max_workers, quiet=quiet, @@ -285,20 +257,17 @@ def iterative_label_mp( loc_suffix='', disc_labels=[], init_disc={}, + output_c2c3=0, output_disc_step=1, loc_disc_labels=[], vertebrae_labels=[], vertebrae_extra_labels=[], - init_vertebrae={}, + output_c2=0, output_vertebrae_step=1, - loc_vertebrae_labels=[], map_input_dict={}, map_output_dict={}, dilation_size=1, - step_diff_label=False, - step_diff_disc=False, default_superior_disc=0, - default_superior_vertebrae=0, override=False, max_workers=mp.cpu_count(), quiet=False, @@ -326,21 +295,18 @@ def iterative_label_mp( partial( _iterative_label, disc_labels=disc_labels, + output_c2c3=output_c2c3, output_disc_step=output_disc_step, loc_disc_labels=loc_disc_labels, init_disc=init_disc, vertebrae_labels=vertebrae_labels, vertebrae_extra_labels=vertebrae_extra_labels, - init_vertebrae=init_vertebrae, + output_c2=output_c2, output_vertebrae_step=output_vertebrae_step, - loc_vertebrae_labels=loc_vertebrae_labels, map_input_dict=map_input_dict, map_output_dict=map_output_dict, dilation_size=dilation_size, - step_diff_label=step_diff_label, - step_diff_disc=step_diff_disc, default_superior_disc=default_superior_disc, - default_superior_vertebrae=default_superior_vertebrae, override=override, ), seg_path_list, @@ -357,20 
+323,17 @@ def _iterative_label( loc_path=None, disc_labels=[], init_disc={}, + output_c2c3=0, output_disc_step=1, loc_disc_labels=[], vertebrae_labels=[], vertebrae_extra_labels=[], - init_vertebrae={}, + output_c2=0, output_vertebrae_step=1, - loc_vertebrae_labels=[], map_input_dict={}, map_output_dict={}, dilation_size=1, - step_diff_label=False, - step_diff_disc=False, default_superior_disc=0, - default_superior_vertebrae=0, override=False, ): ''' @@ -394,20 +357,17 @@ def _iterative_label( loc, disc_labels=disc_labels, init_disc=init_disc, + output_c2c3=output_c2c3, output_disc_step=output_disc_step, loc_disc_labels=loc_disc_labels, vertebrae_labels=vertebrae_labels, vertebrae_extra_labels=vertebrae_extra_labels, - init_vertebrae=init_vertebrae, + output_c2=output_c2, output_vertebrae_step=output_vertebrae_step, - loc_vertebrae_labels=loc_vertebrae_labels, map_input_dict=map_input_dict, map_output_dict=map_output_dict, dilation_size=dilation_size, - step_diff_label=step_diff_label, - step_diff_disc=step_diff_disc, default_superior_disc=default_superior_disc, - default_superior_vertebrae=default_superior_vertebrae, ) except ValueError as e: output_seg_path.is_file() and output_seg_path.unlink() @@ -432,20 +392,17 @@ def iterative_label( loc=None, disc_labels=[], init_disc={}, + output_c2c3=0, output_disc_step=1, loc_disc_labels=[], vertebrae_labels=[], vertebrae_extra_labels=[], - init_vertebrae={}, + output_c2=0, output_vertebrae_step=1, - loc_vertebrae_labels=[], map_input_dict={}, map_output_dict={}, dilation_size=1, - step_diff_label=False, - step_diff_disc=False, default_superior_disc=0, - default_superior_vertebrae=0, ): ''' Label Vertebrae, IVDs, Spinal Cord and canal from init segmentation. 
@@ -467,6 +424,8 @@ def iterative_label( The disc labels init_disc : dict Init labels list for disc ordered by priority (input_label:output_label) + output_c2c3 : int + The output label for C2C3, used to calculate the first vertebrae label output_disc_step : int The step to take between disc labels in the output loc_disc_labels : list @@ -475,26 +434,18 @@ def iterative_label( The vertebrae labels vertebrae_extra_labels : list Extra vertebrae labels to add to add to adjacent vertebrae labels - init_vertebrae : dict - Init labels list for vertebrae ordered by priority (input_label:output_label) + output_c2 : int + The output label for C2, used to calculate the first vertebrae label output_vertebrae_step : int The step to take between vertebrae labels in the output - loc_vertebrae_labels : list - Localizer labels to use for detecting first vertebrae map_input_dict : dict A dict mapping labels from input into the output segmentation map_output_dict : dict A dict mapping labels from the output of the iterative labeling algorithm into different labels in the output segmentation dilation_size : int Number of voxels to dilate before finding connected voxels to label - step_diff_label : bool - Make step only for different labels - step_diff_disc : bool - Make step only for different discs default_superior_disc : int Default superior disc label if no init label found - default_superior_vertebrae : int - Default superior vertebrae label if no init label found Returns ------- @@ -505,193 +456,61 @@ def iterative_label( output_seg_data = np.zeros_like(seg_data) - loc_data = loc and np.asanyarray(loc.dataobj).round().astype(np.uint8) - - # If localizer is provided, transform it to the segmentation space - if loc_data is not None: - loc_data = tio.Resample( - tio.ScalarImage(tensor=seg_data[None, ...], affine=seg.affine) - )( - tio.LabelMap(tensor=loc_data[None, ...], affine=loc.affine) - ).data.numpy()[0, ...].astype(np.uint8) - - binary_dilation_structure = 
ndi.iterate_structure(ndi.generate_binary_structure(3, 1), dilation_size) - - # Arrays to store the z indexes of the discs sorted superior to inferior - disc_sorted_z_indexes = [] - - # We run the same iterative algorithm for discs and vertebrae - for labels, extra_labels, step, init, default_sup, loc_labels, is_vert in ( - (disc_labels, [], output_disc_step, init_disc, default_superior_disc, loc_disc_labels, False), - (vertebrae_labels, vertebrae_extra_labels, output_vertebrae_step, init_vertebrae, default_superior_vertebrae, loc_vertebrae_labels, True)): - - # Skip if no labels are provided - if len(labels) == 0: - continue - - if is_vert: - _labels = [[_] for _ in labels] - else: - # For discs, combine all labels before label continue voxels since the discs not touching each other - _labels = [labels] - - # Init labeled segmentation - mask_labeled, num_labels = np.zeros_like(seg_data, dtype=np.uint32), 0 - - # For each label, find connected voxels and label them into separate labels - for l in _labels: - mask = np.isin(seg_data, l) - - # Dilate the mask to combine small disconnected regions - mask_dilated = ndi.binary_dilation(mask, binary_dilation_structure) - - # Label the connected voxels in the dilated mask into separate labels - tmp_mask_labeled, tmp_num_labels = ndi.label(mask_dilated.astype(np.uint32), np.ones((3, 3, 3))) - - # Undo dilation - tmp_mask_labeled *= mask - - # Add current labels to the labeled segmentation - if tmp_num_labels > 0: - mask_labeled[tmp_mask_labeled != 0] = tmp_mask_labeled[tmp_mask_labeled != 0] + num_labels - num_labels += tmp_num_labels - - # If no label found, raise error - if num_labels == 0: - raise ValueError(f"Some label must be in the segmentation (labels: {labels})") + # Get sorted connected components superio-inferior (SI) for the disc and vertebrae labels + disc_mask_labeled, disc_num_labels, disc_sorted_labels, disc_sorted_z_indexes = _get_si_sorted_components( + seg, + disc_labels, + dilation_size, + ) - # Reduce 
size of mask_labeled - if mask_labeled.max() < np.iinfo(np.uint8).max: - mask_labeled = mask_labeled.astype(np.uint8) - elif mask_labeled.max() < np.iinfo(np.uint16).max: - mask_labeled = mask_labeled.astype(np.uint16) + # Get sorted connected components superio-inferior (SI) for the vertebrae labels + vert_mask_labeled, vert_num_labels, vert_sorted_labels, vert_sorted_z_indexes = _get_si_sorted_components( + seg, + vertebrae_labels, + dilation_size, + combine_labels=True, + ) - # Get the z index of the center of mass for each label - canonical_mask_labeled = np.asanyarray(nib.as_closest_canonical(nib.Nifti1Image(mask_labeled, seg.affine, seg.header)).dataobj).round().astype(mask_labeled.dtype) - mask_labeled_z_indexes = [_[-1] for _ in ndi.center_of_mass(canonical_mask_labeled != 0, canonical_mask_labeled, range(1, num_labels + 1))] + # Combine sequential vertebrae labels based on some conditions + vert_mask_labeled, vert_num_labels, vert_sorted_labels, vert_sorted_z_indexes = _merge_vertebrae_labels( + seg, + vertebrae_labels, + vert_mask_labeled, + vert_num_labels, + vert_sorted_labels, + vert_sorted_z_indexes, + disc_sorted_z_indexes, + vertebrae_extra_labels, + ) - # Sort the labels by their z-index (reversed to go from superior to inferior) - sorted_z_indexes, sorted_labels = zip(*sorted(zip(mask_labeled_z_indexes,range(1,num_labels+1)))[::-1]) - - # In this part we loop over the labels superior to inferior and combine sequential vertebrae labels based on some conditions - if is_vert: - # Combine sequential vertebrae labels if they have the same value in the original segmentation - # This is useful when part part of the vertebrae is not connected to the main part but have the same odd/even value - if step_diff_label and len(labels) - len(init) > 1: - new_sorted_labels = [] - - # Store the previous label and the original label of the previous label - prev_l, prev_orig_label = 0, 0 - - # Loop over the sorted labels - for l in sorted_labels: - # Get the original 
label of the current label - curr_orig_label = seg_data[mask_labeled == l].flat[0] - - # Combine the current label with the previous label if they have the same original label - if curr_orig_label == prev_orig_label: - # Combine the current label with the previous label - mask_labeled[mask_labeled == l] = prev_l - num_labels -= 1 - - else: - # Add the current label to the new sorted labels - new_sorted_labels.append(l) - prev_l, prev_orig_label = l, curr_orig_label - - # Get the z index of the center of mass for each label - canonical_mask_labeled = np.asanyarray(nib.as_closest_canonical(nib.Nifti1Image(mask_labeled, seg.affine, seg.header)).dataobj).round().astype(mask_labeled.dtype) - mask_labeled_z_indexes = [_[-1] for _ in ndi.center_of_mass(canonical_mask_labeled != 0, canonical_mask_labeled, new_sorted_labels)] - - # Sort the labels by their z-index (reversed to go from superior to inferior) - sorted_z_indexes, sorted_labels = zip(*sorted(zip(mask_labeled_z_indexes, new_sorted_labels))[::-1]) - - # Reduce size of mask_labeled - if mask_labeled.max() < np.iinfo(np.uint8).max: - mask_labeled = mask_labeled.astype(np.uint8) - elif mask_labeled.max() < np.iinfo(np.uint16).max: - mask_labeled = mask_labeled.astype(np.uint16) - - # Combine sequential vertebrae labels if there is no disc between them - if step_diff_disc and len(disc_sorted_z_indexes) > 0: - new_sorted_labels = [] - - # Store the previous label and the z index of the previous label - prev_l, prev_z = 0, 0 - - for l, z in zip(sorted_labels, sorted_z_indexes): - # Do not combine first and last vertebrae since it can be C1 or only contain the spinous process - if l not in sorted_labels[:2] and l != sorted_labels[-1] and prev_l > 0 and not any(z < _ < prev_z for _ in disc_sorted_z_indexes): - # Combine the current label with the previous label - mask_labeled[mask_labeled == l] = prev_l - num_labels -= 1 - - else: - # Add the current label to the new sorted labels - new_sorted_labels.append(l) - prev_l, 
prev_z = l, z - - sorted_labels = new_sorted_labels - - # Reduce size of mask_labeled - if mask_labeled.max() < np.iinfo(np.uint8).max: - mask_labeled = mask_labeled.astype(np.uint8) - elif mask_labeled.max() < np.iinfo(np.uint16).max: - mask_labeled = mask_labeled.astype(np.uint16) - - else: - # Save the z indexes of the discs - disc_sorted_z_indexes = sorted_z_indexes - - # Find the most superior label in the segmentation - first_label = 0 - for k, v in init.items(): - if k in seg_data: - first_label = v - step * sorted_labels.index(mask_labeled[seg_data == k].flat[0]) - break - - # If no init label found, set it from the localizer - if first_label == 0 and loc_data is not None: - # Make mask for the intersection of the localizer labels and the labels in the segmentation - mask = np.isin(loc_data, loc_labels) * np.isin(mask_labeled, sorted_labels) - - # Get the first label from sorted_labels that is in the localizer specified labels - mask_labeled_masked = mask * mask_labeled - first_sorted_labels_in_loc = next(np.array(sorted_labels)[np.isin(sorted_labels, mask_labeled_masked)].flat, 0) - - if first_sorted_labels_in_loc > 0: - # Get the target label for first_sorted_labels_in_loc - the label from the localizer that has the most voxels in it - loc_data_masked = mask * loc_data - target = np.argmax(np.bincount(loc_data_masked[mask_labeled_masked == first_sorted_labels_in_loc].flat)) - # If target in map_output_dict reverse it from the reversed map - # TODO Edge case if multiple keys have the same value, not used in the current implementation - target = {v: k for k, v in map_output_dict.items()}.get(target, target) - first_label = target - step * sorted_labels.index(first_sorted_labels_in_loc) - - # If no init label found, set the default superior label - if first_label == 0 and default_sup > 0: - first_label = default_sup - - # If no init label found, print error - if first_label == 0: - raise ValueError(f"Some initiation label must be in the segmentation (init: 
{list(init.keys())})") - - # Combine extra labels with adjacent vertebrae labels - if len(extra_labels) > 0: - mask_extra = np.isin(seg_data, extra_labels) - - # Loop over vertebral labels (from inferior because the transverse process make it steal from above) - for i in range(num_labels - 1, -1, -1): - # Mkae mask for the current vertebrae with filling the holes and dilating it - mask = fill(mask_labeled == sorted_labels[i]) - mask = ndi.binary_dilation(mask, ndi.iterate_structure(ndi.generate_binary_structure(3, 1), 1)) - - # Add the intersection of the mask with the extra labels to the current verebrae - mask_labeled[mask_extra * mask] = sorted_labels[i] - - # Set the output value for the current label - for i in range(num_labels): - output_seg_data[mask_labeled == sorted_labels[i]] = first_label + step * i + # Get the first disc label + disc_first_label = _get_first_label( + seg, + loc, + disc_mask_labeled, + disc_sorted_labels, + init_disc, + output_disc_step, + loc_disc_labels, + default_superior_disc, + map_output_dict, + ) + + # Sort the combined disc+vert labels by their z-index + sorted_labels = vert_sorted_labels + disc_sorted_labels + sorted_z_indexes = vert_sorted_z_indexes + disc_sorted_z_indexes + is_vert = [True] * len(vert_sorted_labels) + [False] * len(disc_sorted_labels) + + # Sort the labels by their z-index (reversed to go from superior to inferior) + sorted_z_indexes, sorted_labels, is_vert = zip(*sorted(zip(sorted_z_indexes, sorted_labels, is_vert))[::-1]) + + # Get index of first vert in is_vert + # To cover cases with C1 and C2 we have to adjust the first vertebrae label by the position of the first disc + vert_first_label = output_c2 + (disc_first_label - output_c2c3) * (output_vertebrae_step / output_disc_step) - (is_vert.index(False) - 1) * output_vertebrae_step + for i in range(vert_num_labels): + output_seg_data[vert_mask_labeled == vert_sorted_labels[i]] = vert_first_label + output_vertebrae_step * i + for i in range(disc_num_labels): 
+ output_seg_data[disc_mask_labeled == disc_sorted_labels[i]] = disc_first_label + output_disc_step * i # Use the map to map labels from the iteative algorithm output, to the final output # This is useful to map the vertebrae label from the iteative algorithm output to the special sacrum label @@ -718,7 +537,234 @@ def iterative_label( return output_seg -def fill(mask): +def _get_si_sorted_components( + seg, + labels, + dilation_size=1, + combine_labels=False, + ): + ''' + Get sorted connected components superio-inferior (SI) for the given labels in the segmentation. + ''' + seg_data = np.asanyarray(seg.dataobj).round().astype(np.uint8) + + binary_dilation_structure = ndi.iterate_structure(ndi.generate_binary_structure(3, 1), dilation_size) + + # Skip if no labels are provided + if len(labels) == 0: + return None, 0, [], [] + + if combine_labels: + _labels = [[_] for _ in labels] + else: + # For discs, combine all labels before label continue voxels since the discs not touching each other + _labels = [labels] + + # Init labeled segmentation + mask_labeled, num_labels = np.zeros_like(seg_data, dtype=np.uint32), 0 + + # For each label, find connected voxels and label them into separate labels + for l in _labels: + mask = np.isin(seg_data, l) + + # Dilate the mask to combine small disconnected regions + mask_dilated = ndi.binary_dilation(mask, binary_dilation_structure) + + # Label the connected voxels in the dilated mask into separate labels + tmp_mask_labeled, tmp_num_labels = ndi.label(mask_dilated.astype(np.uint32), np.ones((3, 3, 3))) + + # Undo dilation + tmp_mask_labeled *= mask + + # Add current labels to the labeled segmentation + if tmp_num_labels > 0: + mask_labeled[tmp_mask_labeled != 0] = tmp_mask_labeled[tmp_mask_labeled != 0] + num_labels + num_labels += tmp_num_labels + + # If no label found, raise error + if num_labels == 0: + raise ValueError(f"Some label must be in the segmentation (labels: {labels})") + + # Reduce size of mask_labeled + if 
mask_labeled.max() < np.iinfo(np.uint8).max: + mask_labeled = mask_labeled.astype(np.uint8) + elif mask_labeled.max() < np.iinfo(np.uint16).max: + mask_labeled = mask_labeled.astype(np.uint16) + + # Get the z index of the center of mass for each label + canonical_mask_labeled = np.asanyarray(nib.as_closest_canonical(nib.Nifti1Image(mask_labeled, seg.affine, seg.header)).dataobj).round().astype(mask_labeled.dtype) + mask_labeled_z_indexes = [_[-1] for _ in ndi.center_of_mass(canonical_mask_labeled != 0, canonical_mask_labeled, range(1, num_labels + 1))] + + # Sort the labels by their z-index (reversed to go from superior to inferior) + sorted_z_indexes, sorted_labels = zip(*sorted(zip(mask_labeled_z_indexes,range(1,num_labels+1)))[::-1]) + + return mask_labeled, num_labels, list(sorted_labels), list(sorted_z_indexes) + +def _merge_vertebrae_labels( + seg, + labels, + mask_labeled, + num_labels, + sorted_labels, + sorted_z_indexes, + disc_sorted_z_indexes, + extra_labels, + ): + ''' + Combine sequential vertebrae labels based on some conditions. 
+ ''' + if num_labels == 0: + return mask_labeled, num_labels, sorted_labels, sorted_z_indexes + + seg_data = np.asanyarray(seg.dataobj).round().astype(np.uint8) + # Combine sequential vertebrae labels if they have the same value in the original segmentation + # This is useful when part part of the vertebrae is not connected to the main part but have the same odd/even value + if len(labels) > 1: + new_sorted_labels = [] + + # Store the previous label and the original label of the previous label + prev_l, prev_orig_label = 0, 0 + + # Loop over the sorted labels + for l in sorted_labels: + # Get the original label of the current label + curr_orig_label = seg_data[mask_labeled == l].flat[0] + + # Combine the current label with the previous label if they have the same original label + if curr_orig_label == prev_orig_label: + # Combine the current label with the previous label + mask_labeled[mask_labeled == l] = prev_l + num_labels -= 1 + + else: + # Add the current label to the new sorted labels + new_sorted_labels.append(l) + prev_l, prev_orig_label = l, curr_orig_label + + # Get the z index of the center of mass for each label + canonical_mask_labeled = np.asanyarray(nib.as_closest_canonical(nib.Nifti1Image(mask_labeled, seg.affine, seg.header)).dataobj).round().astype(mask_labeled.dtype) + mask_labeled_z_indexes = [_[-1] for _ in ndi.center_of_mass(canonical_mask_labeled != 0, canonical_mask_labeled, new_sorted_labels)] + + # Sort the labels by their z-index (reversed to go from superior to inferior) + sorted_z_indexes, sorted_labels = zip(*sorted(zip(mask_labeled_z_indexes, new_sorted_labels))[::-1]) + + # Reduce size of mask_labeled + if mask_labeled.max() < np.iinfo(np.uint8).max: + mask_labeled = mask_labeled.astype(np.uint8) + elif mask_labeled.max() < np.iinfo(np.uint16).max: + mask_labeled = mask_labeled.astype(np.uint16) + + # Combine sequential vertebrae labels if there is no disc between them + if len(disc_sorted_z_indexes) > 0: + new_sorted_labels = [] + 
+ # Store the previous label and the z index of the previous label + prev_l, prev_z = 0, 0 + + for l, z in zip(sorted_labels, sorted_z_indexes): + # Do not combine first and last vertebrae since it can be C1 or only contain the spinous process + if l not in sorted_labels[:2] and l != sorted_labels[-1] and prev_l > 0 and not any(z < _ < prev_z for _ in disc_sorted_z_indexes): + # Combine the current label with the previous label + mask_labeled[mask_labeled == l] = prev_l + num_labels -= 1 + + else: + # Add the current label to the new sorted labels + new_sorted_labels.append(l) + prev_l, prev_z = l, z + + sorted_labels = new_sorted_labels + + # Reduce size of mask_labeled + if mask_labeled.max() < np.iinfo(np.uint8).max: + mask_labeled = mask_labeled.astype(np.uint8) + elif mask_labeled.max() < np.iinfo(np.uint16).max: + mask_labeled = mask_labeled.astype(np.uint16) + + # Combine extra labels with adjacent vertebrae labels + if len(extra_labels) > 0: + mask_extra = np.isin(seg_data, extra_labels) + + # Loop over vertebral labels (from inferior because the transverse process make it steal from above) + for i in range(num_labels - 1, -1, -1): + # Mkae mask for the current vertebrae with filling the holes and dilating it + mask = _fill(mask_labeled == sorted_labels[i]) + mask = ndi.binary_dilation(mask, ndi.iterate_structure(ndi.generate_binary_structure(3, 1), 1)) + + # Add the intersection of the mask with the extra labels to the current verebrae + mask_labeled[mask_extra * mask] = sorted_labels[i] + + # Get the z index of the center of mass for each label + canonical_mask_labeled = np.asanyarray(nib.as_closest_canonical(nib.Nifti1Image(mask_labeled, seg.affine, seg.header)).dataobj).round().astype(mask_labeled.dtype) + mask_labeled_z_indexes = [_[-1] for _ in ndi.center_of_mass(canonical_mask_labeled != 0, canonical_mask_labeled, sorted_labels)] + + # Sort the labels by their z-index (reversed to go from superior to inferior) + sorted_z_indexes, sorted_labels = 
zip(*sorted(zip(mask_labeled_z_indexes, new_sorted_labels))[::-1]) + + return mask_labeled, num_labels, list(sorted_labels), list(sorted_z_indexes) + +def _get_first_label( + seg, + loc, + mask_labeled, + sorted_labels, + init, + step, + loc_labels, + default_superior, + map_output_dict, + ): + ''' + Get the first label for the iterative labeling algorithm. + ''' + seg_data = np.asanyarray(seg.dataobj).round().astype(np.uint8) + + loc_data = loc and np.asanyarray(loc.dataobj).round().astype(np.uint8) + + # If localizer is provided, transform it to the segmentation space + if loc_data is not None: + loc_data = tio.Resample( + tio.ScalarImage(tensor=seg_data[None, ...], affine=seg.affine) + )( + tio.LabelMap(tensor=loc_data[None, ...], affine=loc.affine) + ).data.numpy()[0, ...].astype(np.uint8) + + # Find the most superior label in the segmentation + first_label = 0 + for k, v in init.items(): + if k in seg_data: + first_label = v - step * sorted_labels.index(mask_labeled[seg_data == k].flat[0]) + break + + # If no init label found, set it from the localizer + if first_label == 0 and loc_data is not None: + # Make mask for the intersection of the localizer labels and the labels in the segmentation + mask = np.isin(loc_data, loc_labels) * np.isin(mask_labeled, sorted_labels) + + # Get the first label from sorted_labels that is in the localizer specified labels + mask_labeled_masked = mask * mask_labeled + first_sorted_labels_in_loc = next(np.array(sorted_labels)[np.isin(sorted_labels, mask_labeled_masked)].flat, 0) + + if first_sorted_labels_in_loc > 0: + # Get the target label for first_sorted_labels_in_loc - the label from the localizer that has the most voxels in it + loc_data_masked = mask * loc_data + target = np.argmax(np.bincount(loc_data_masked[mask_labeled_masked == first_sorted_labels_in_loc].flat)) + # If target in map_output_dict reverse it from the reversed map + # TODO Edge case if multiple keys have the same value, not used in the current 
implementation + target = {v: k for k, v in map_output_dict.items()}.get(target, target) + first_label = target - step * sorted_labels.index(first_sorted_labels_in_loc) + + # If no init label found, set the default superior label + if first_label == 0 and default_superior > 0: + first_label = default_superior + + # If no init label found, print error + if first_label == 0: + raise ValueError(f"Some initiation label must be in the segmentation (init: {list(init.keys())})") + + return first_label + +def _fill(mask): ''' Fill holes in a binary mask From ae4e9edef9521c1991ec60313c02947b49a98030 Mon Sep 17 00:00:00 2001 From: Yehuda Warszawer <36595323+yw7@users.noreply.github.com> Date: Thu, 12 Sep 2024 22:31:34 +0300 Subject: [PATCH 2/3] Improve label naming for clarity in iterative_label function Renamed variables and comments to improve clarity and understandability of the labeling process. Adjusted label computation logic to accurately reflect the sorting and labeling operations, especially for handling cases involving C1 and C2 vertebrae. This enhances readability and maintains consistency within the code base. No functional changes are expected. 
--- totalspineseg/utils/iterative_label.py | 36 ++++++++++++++------------ 1 file changed, 20 insertions(+), 16 deletions(-) diff --git a/totalspineseg/utils/iterative_label.py b/totalspineseg/utils/iterative_label.py index 1997088..e7de427 100644 --- a/totalspineseg/utils/iterative_label.py +++ b/totalspineseg/utils/iterative_label.py @@ -456,7 +456,7 @@ def iterative_label( output_seg_data = np.zeros_like(seg_data) - # Get sorted connected components superio-inferior (SI) for the disc and vertebrae labels + # Get sorted connected components superio-inferior (SI) for the disc labels disc_mask_labeled, disc_num_labels, disc_sorted_labels, disc_sorted_z_indexes = _get_si_sorted_components( seg, disc_labels, @@ -484,7 +484,7 @@ def iterative_label( ) # Get the first disc label - disc_first_label = _get_first_label( + superior_disc_output_label = _get_superior_output_label( seg, loc, disc_mask_labeled, @@ -504,13 +504,17 @@ def iterative_label( # Sort the labels by their z-index (reversed to go from superior to inferior) sorted_z_indexes, sorted_labels, is_vert = zip(*sorted(zip(sorted_z_indexes, sorted_labels, is_vert))[::-1]) - # Get index of first vert in is_vert - # To cover cases with C1 and C2 we have to adjust the first vertebrae label by the position of the first disc - vert_first_label = output_c2 + (disc_first_label - output_c2c3) * (output_vertebrae_step / output_disc_step) - (is_vert.index(False) - 1) * output_vertebrae_step + # Get the superior output label for the vertebrae based on the superior disc label + # For C1 and C2 we have to adjust the first vertebrae label by the position of the first disc with subtraction of (is_vert.index(False) - 1) * output_vertebrae_step + superior_vert_output_label = output_c2 + (superior_disc_output_label - output_c2c3) * (output_vertebrae_step / output_disc_step) - (is_vert.index(False) - 1) * output_vertebrae_step + + # Label the vertebrae with the output labels superio-inferior for i in range(vert_num_labels): -
output_seg_data[vert_mask_labeled == vert_sorted_labels[i]] = vert_first_label + output_vertebrae_step * i + output_seg_data[vert_mask_labeled == vert_sorted_labels[i]] = superior_vert_output_label + output_vertebrae_step * i + + # Label the discs with the output labels superio-inferior for i in range(disc_num_labels): - output_seg_data[disc_mask_labeled == disc_sorted_labels[i]] = disc_first_label + output_disc_step * i + output_seg_data[disc_mask_labeled == disc_sorted_labels[i]] = superior_disc_output_label + output_disc_step * i # Use the map to map labels from the iteative algorithm output, to the final output # This is useful to map the vertebrae label from the iteative algorithm output to the special sacrum label @@ -703,7 +707,7 @@ def _merge_vertebrae_labels( return mask_labeled, num_labels, list(sorted_labels), list(sorted_z_indexes) -def _get_first_label( +def _get_superior_output_label( seg, loc, mask_labeled, @@ -730,14 +734,14 @@ def _get_first_label( ).data.numpy()[0, ...].astype(np.uint8) # Find the most superior label in the segmentation - first_label = 0 + superior_output_label = 0 for k, v in init.items(): if k in seg_data: - first_label = v - step * sorted_labels.index(mask_labeled[seg_data == k].flat[0]) + superior_output_label = v - step * sorted_labels.index(mask_labeled[seg_data == k].flat[0]) break # If no init label found, set it from the localizer - if first_label == 0 and loc_data is not None: + if superior_output_label == 0 and loc_data is not None: # Make mask for the intersection of the localizer labels and the labels in the segmentation mask = np.isin(loc_data, loc_labels) * np.isin(mask_labeled, sorted_labels) @@ -752,17 +756,17 @@ def _get_first_label( # If target in map_output_dict reverse it from the reversed map # TODO Edge case if multiple keys have the same value, not used in the current implementation target = {v: k for k, v in map_output_dict.items()}.get(target, target) - first_label = target - step * 
sorted_labels.index(first_sorted_labels_in_loc) + superior_output_label = target - step * sorted_labels.index(first_sorted_labels_in_loc) # If no init label found, set the default superior label - if first_label == 0 and default_superior > 0: - first_label = default_superior + if superior_output_label == 0 and default_superior > 0: + superior_output_label = default_superior # If no init label found, print error - if first_label == 0: + if superior_output_label == 0: raise ValueError(f"Some initiation label must be in the segmentation (init: {list(init.keys())})") - return first_label + return superior_output_label def _fill(mask): ''' From f1182be3a19d8741c266833ba5d39cb99552cb92 Mon Sep 17 00:00:00 2001 From: Yehuda Warszawer <36595323+yw7@users.noreply.github.com> Date: Thu, 12 Sep 2024 22:46:24 +0300 Subject: [PATCH 3/3] Fix labeling logic in _get_superior_output_label() Refined the logic for determining the superior output label by replacing the use of the first element with the most frequent element. This ensures better accuracy in cases with multiple segmented labels and improves robustness in varying data scenarios. Addresses potential inaccuracies in segmentation outputs noted during recent evaluations. --- totalspineseg/utils/iterative_label.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/totalspineseg/utils/iterative_label.py b/totalspineseg/utils/iterative_label.py index e7de427..99295c4 100644 --- a/totalspineseg/utils/iterative_label.py +++ b/totalspineseg/utils/iterative_label.py @@ -737,7 +737,7 @@ def _get_superior_output_label( superior_output_label = 0 for k, v in init.items(): if k in seg_data: - superior_output_label = v - step * sorted_labels.index(mask_labeled[seg_data == k].flat[0]) + superior_output_label = v - step * sorted_labels.index(np.argmax(np.bincount(mask_labeled[seg_data == k].flat))) break # If no init label found, set it from the localizer