diff --git a/totalspineseg/inference.py b/totalspineseg/inference.py index 86af7e5..c6ee540 100644 --- a/totalspineseg/inference.py +++ b/totalspineseg/inference.py @@ -547,11 +547,10 @@ def main(): vertebrae_labels=[9, 10, 11, 12, 13, 14], vertebrae_extra_labels=[8], init_disc={4:224, 7:202, 5:219, 6:207}, - init_vertebrae={11:40, 14:17, 12:34, 13:23}, - step_diff_label=True, - step_diff_disc=True, output_disc_step=-1, output_vertebrae_step=-1, + output_c2c3=224, + output_c2=40, map_output_dict={17:92}, map_input_dict={14:92, 15:201, 16:201, 17:200}, override=True, @@ -567,13 +566,11 @@ def main(): vertebrae_labels=[9, 10, 11, 12, 13, 14], vertebrae_extra_labels=[8], init_disc={4:224, 7:202}, - init_vertebrae={11:40, 14:17}, loc_disc_labels=list(range(202, 225)), - loc_vertebrae_labels=list(range(18, 42)) + [92], - step_diff_label=True, - step_diff_disc=True, output_disc_step=-1, output_vertebrae_step=-1, + output_c2c3=224, + output_c2=40, map_output_dict={17:92}, map_input_dict={14:92, 15:201, 16:201, 17:200}, override=True, diff --git a/totalspineseg/utils/iterative_label.py b/totalspineseg/utils/iterative_label.py index 7f98196..99295c4 100644 --- a/totalspineseg/utils/iterative_label.py +++ b/totalspineseg/utils/iterative_label.py @@ -20,10 +20,10 @@ def main(): '''.split()), epilog=textwrap.dedent(''' Examples: - iterative_label -s labels_init -o labels --disc-labels 1-7 --vertebrae-labels 9-14 --vertebrae-extra-labels 8 --init-disc 4:224 7:202 5:219 6:207 --init-vertebrae 11:40 14:17 12:34 13:23 --step-diff-label --step-diff-disc --output-disc-step -1 --output-vertebrae-step -1 --map-output 17:92 --map-input 14:92 16:201 17:200 -r - iterative_label -s labels_init -o labels -l localizers --disc-labels 1-7 --vertebrae-labels 9-14 --vertebrae-extra-labels 8 --init-disc 4:224 7:202 --init-vertebrae 11:40 14:17 --step-diff-label --step-diff-disc --output-disc-step -1 --output-vertebrae-step -1 --loc-disc-labels 202-224 --loc-vertebrae-labels 18-41 92 --map-output 17:92 --map-input 14:92 16:201 17:200 -r + iterative_label -s labels_init -o labels --disc-labels 1-7 --vertebrae-labels 9-14 --vertebrae-extra-labels 8 --init-disc 4:224 7:202 5:219 6:207 --output-disc-step -1 --output-vertebrae-step -1 --map-output 17:92 --map-input 14:92 16:201 17:200 --output-c2c3 224 --output-c2 40 -r + iterative_label -s labels_init -o labels -l localizers --disc-labels 1-7 --vertebrae-labels 9-14 --vertebrae-extra-labels 8 --init-disc 4:224 7:202 --output-disc-step -1 --output-vertebrae-step -1 --loc-disc-labels 202-224 --map-output 17:92 --map-input 14:92 16:201 17:200 --output-c2c3 224 --output-c2 40 -r For BIDS: - iterative_label -s derivatives/labels -o derivatives/labels --seg-suffix "_seg" --output-seg-suffix "_seg_seq" -d "sub-" -u "anat" --disc-labels 1 2 3 4 5 6 7 --vertebrae-labels 9 10 11 12 13 14 --vertebrae-extra-labels 8 --init-disc 4:224 7:202 5:219 6:207 --init-vertebrae 11:40 14:17 12:34 13:23 --step-diff-label --step-diff-disc --output-disc-step -1 --output-vertebrae-step -1 --map-output 17:92 --map-input 14:92 16:201 17:200 -r + iterative_label -s derivatives/labels -o derivatives/labels --seg-suffix "_seg" --output-seg-suffix "_seg_seq" -d "sub-" -u "anat" --disc-labels 1 2 3 4 5 6 7 --vertebrae-labels 9 10 11 12 13 14 --vertebrae-extra-labels 8 --init-disc 4:224 7:202 5:219 6:207 --output-disc-step -1 --output-vertebrae-step -1 --map-output 17:92 --map-input 14:92 16:201 17:200 --output-c2c3 224 --output-c2 40 -r '''), formatter_class=argparse.RawTextHelpFormatter ) @@ -80,6 +80,10 @@ def main(): '--init-disc', type=lambda x:map(int, x.split(':')), nargs='+', default=[], help='Init labels list for disc ordered by priority (input_label:output_label !!without space!!). for example 4:224 5:219 6:202' ) + parser.add_argument( + '--output-c2c3', type=int, default=0, + help='The output label for C2C3, used to calculate the first vertebrae label, defaults to 0.' + ) parser.add_argument( '--output-disc-step', type=int, default=1, help='The step to take between disc labels in the output, defaults to 1.' @@ -97,17 +101,13 @@ def main(): help='Extra vertebrae labels to add to add to adjacent vertebrae labels.' ) parser.add_argument( - '--init-vertebrae', type=lambda x:map(int, x.split(':')), nargs='+', default=[], - help='Init labels list for vertebrae ordered by priority (input_label:output_label !!without space!!). for example 10:41 11:34 12:18' + '--output-c2', type=int, default=0, + help='The output label for C2, used to calculate the first vertebrae label, defaults to 0.' ) parser.add_argument( '--output-vertebrae-step', type=int, default=1, help='The step to take between vertebrae labels in the output, defaults to 1.' ) - parser.add_argument( - '--loc-vertebrae-labels', type=lambda x:list(range(int(x.split('-')[0]), int(x.split('-')[-1]) + 1)), nargs='+', default=[], - help='The disc labels in the localizer used for detecting first vertebrae.' - ) parser.add_argument( '--map-input', type=str, nargs='+', default=[], help=' '.join(f''' @@ -128,29 +128,10 @@ def main(): '--dilation-size', type=int, default=1, help='Number of voxels to dilate before finding connected voxels to label, defaults to 1 (No dilation).' ) - parser.add_argument( - '--step-diff-label', action="store_true", default=False, - help=' '.join(f''' - Make step only for different labels. When looping on the labels on the z axis, it will give a new label to the next label only - if it is different from the previous label. This is useful if there are labels for odd and even vertebrae, so the next label will - be for even vertebrae only if the previous label was odd. If it is still odd, it should give the same label. - '''.split()) - ) - parser.add_argument( - '--step-diff-disc', action="store_true", default=False, - help=' '.join(f''' - Make step only for different discs. When looping on the labels on the z axis, it will give a new label to the next label only - if there is a disc between them. This exclude the first and last vertebrae since it can be C1 or only contain the spinous process. - '''.split()) - ) parser.add_argument( '--default-superior-disc', type=int, default=0, help='Default superior disc label if no init label found, defaults to 0 (Raise error if init label not found).' ) - parser.add_argument( - '--default-superior-vertebrae', type=int, default=0, - help='Default superior vertebrae label if no init label found, defaults to 0 (Raise error if init label not found).' - ) parser.add_argument( '--override', '-r', action="store_true", default=False, help='Override existing output files, defaults to false (Do not override).' @@ -179,20 +160,17 @@ def main(): loc_suffix = args.loc_suffix disc_labels = [_ for __ in args.disc_labels for _ in (__ if isinstance(__, list) else [__])] init_disc = dict(args.init_disc) + output_c2c3 = args.output_c2c3 output_disc_step = args.output_disc_step loc_disc_labels = [_ for __ in args.loc_disc_labels for _ in (__ if isinstance(__, list) else [__])] vertebrae_labels = [_ for __ in args.vertebrae_labels for _ in (__ if isinstance(__, list) else [__])] vertebrae_extra_labels = [_ for __ in args.vertebrae_extra_labels for _ in (__ if isinstance(__, list) else [__])] - init_vertebrae = dict(args.init_vertebrae) + output_c2 = args.output_c2 output_vertebrae_step = args.output_vertebrae_step - loc_vertebrae_labels = [_ for __ in args.loc_vertebrae_labels for _ in (__ if isinstance(__, list) else [__])] map_input_list = args.map_input map_output_list = args.map_output dilation_size = args.dilation_size - step_diff_label = args.step_diff_label - step_diff_disc = args.step_diff_disc default_superior_disc = args.default_superior_disc - default_superior_vertebrae = args.default_superior_vertebrae override = args.override max_workers = args.max_workers quiet = args.quiet @@ -212,20 +190,17 @@ def main(): loc_suffix = "{loc_suffix}" disc_labels = {disc_labels} init_disc = {init_disc} + output_c2c3 = {output_c2c3} output_disc_step = {output_disc_step} loc_disc_labels = {loc_disc_labels} vertebrae_labels = {vertebrae_labels} vertebrae_extra_labels = {vertebrae_extra_labels} - init_vertebrae = {init_vertebrae} + output_c2 = {output_c2} output_vertebrae_step = {output_vertebrae_step} - loc_vertebrae_labels = {loc_vertebrae_labels} map_input = {map_input_list} map_output = {map_output_list} dilation_size = {dilation_size} - step_diff_label = {step_diff_label} - step_diff_disc = {step_diff_disc} default_superior_disc = {default_superior_disc} - default_superior_vertebrae = {default_superior_vertebrae} override = {override} max_workers = {max_workers} quiet = {quiet} @@ -254,20 +229,17 @@ def main(): loc_suffix=loc_suffix, disc_labels=disc_labels, init_disc=init_disc, + output_c2c3=output_c2c3, output_disc_step=output_disc_step, loc_disc_labels=loc_disc_labels, vertebrae_labels=vertebrae_labels, vertebrae_extra_labels=vertebrae_extra_labels, - init_vertebrae=init_vertebrae, + output_c2=output_c2, output_vertebrae_step=output_vertebrae_step, - loc_vertebrae_labels=loc_vertebrae_labels, map_input_dict=map_input_dict, map_output_dict=map_output_dict, dilation_size=dilation_size, - step_diff_label=step_diff_label, - step_diff_disc=step_diff_disc, default_superior_disc=default_superior_disc, - default_superior_vertebrae=default_superior_vertebrae, override=override, max_workers=max_workers, quiet=quiet, @@ -285,20 +257,17 @@ def iterative_label_mp( loc_suffix='', disc_labels=[], init_disc={}, + output_c2c3=0, output_disc_step=1, loc_disc_labels=[], vertebrae_labels=[], vertebrae_extra_labels=[], - init_vertebrae={}, + output_c2=0, output_vertebrae_step=1, - loc_vertebrae_labels=[], map_input_dict={}, map_output_dict={}, dilation_size=1, - step_diff_label=False, - step_diff_disc=False, default_superior_disc=0, - default_superior_vertebrae=0, override=False, max_workers=mp.cpu_count(), quiet=False, @@ -326,21 +295,18 @@ def iterative_label_mp( partial( _iterative_label, disc_labels=disc_labels, + output_c2c3=output_c2c3, output_disc_step=output_disc_step, loc_disc_labels=loc_disc_labels, init_disc=init_disc, vertebrae_labels=vertebrae_labels, vertebrae_extra_labels=vertebrae_extra_labels, - init_vertebrae=init_vertebrae, + output_c2=output_c2, output_vertebrae_step=output_vertebrae_step, - loc_vertebrae_labels=loc_vertebrae_labels, map_input_dict=map_input_dict, map_output_dict=map_output_dict, dilation_size=dilation_size, - step_diff_label=step_diff_label, - step_diff_disc=step_diff_disc, default_superior_disc=default_superior_disc, - default_superior_vertebrae=default_superior_vertebrae, override=override, ), seg_path_list, @@ -357,20 +323,17 @@ def _iterative_label( loc_path=None, disc_labels=[], init_disc={}, + output_c2c3=0, output_disc_step=1, loc_disc_labels=[], vertebrae_labels=[], vertebrae_extra_labels=[], - init_vertebrae={}, + output_c2=0, output_vertebrae_step=1, - loc_vertebrae_labels=[], map_input_dict={}, map_output_dict={}, dilation_size=1, - step_diff_label=False, - step_diff_disc=False, default_superior_disc=0, - default_superior_vertebrae=0, override=False, ): ''' @@ -394,20 +357,17 @@ def _iterative_label( loc, disc_labels=disc_labels, init_disc=init_disc, + output_c2c3=output_c2c3, output_disc_step=output_disc_step, loc_disc_labels=loc_disc_labels, vertebrae_labels=vertebrae_labels, vertebrae_extra_labels=vertebrae_extra_labels, - init_vertebrae=init_vertebrae, + output_c2=output_c2, output_vertebrae_step=output_vertebrae_step, - loc_vertebrae_labels=loc_vertebrae_labels, map_input_dict=map_input_dict, map_output_dict=map_output_dict, dilation_size=dilation_size, - step_diff_label=step_diff_label, - step_diff_disc=step_diff_disc, default_superior_disc=default_superior_disc, - default_superior_vertebrae=default_superior_vertebrae, ) except ValueError as e: output_seg_path.is_file() and output_seg_path.unlink() @@ -432,20 +392,17 @@ def iterative_label( loc=None, disc_labels=[], init_disc={}, + output_c2c3=0, output_disc_step=1, loc_disc_labels=[], vertebrae_labels=[], vertebrae_extra_labels=[], - init_vertebrae={}, + output_c2=0, output_vertebrae_step=1, - loc_vertebrae_labels=[], map_input_dict={}, map_output_dict={}, dilation_size=1, - step_diff_label=False, - step_diff_disc=False, default_superior_disc=0, - default_superior_vertebrae=0, ): ''' Label Vertebrae, IVDs, Spinal Cord and canal from init segmentation. @@ -467,6 +424,8 @@ def iterative_label( The disc labels init_disc : dict Init labels list for disc ordered by priority (input_label:output_label) + output_c2c3 : int + The output label for C2C3, used to calculate the first vertebrae label output_disc_step : int The step to take between disc labels in the output loc_disc_labels : list @@ -475,26 +434,18 @@ def iterative_label( The vertebrae labels vertebrae_extra_labels : list Extra vertebrae labels to add to add to adjacent vertebrae labels - init_vertebrae : dict - Init labels list for vertebrae ordered by priority (input_label:output_label) + output_c2 : int + The output label for C2, used to calculate the first vertebrae label output_vertebrae_step : int The step to take between vertebrae labels in the output - loc_vertebrae_labels : list - Localizer labels to use for detecting first vertebrae map_input_dict : dict A dict mapping labels from input into the output segmentation map_output_dict : dict A dict mapping labels from the output of the iterative labeling algorithm into different labels in the output segmentation dilation_size : int Number of voxels to dilate before finding connected voxels to label - step_diff_label : bool - Make step only for different labels - step_diff_disc : bool - Make step only for different discs default_superior_disc : int Default superior disc label if no init label found - default_superior_vertebrae : int - Default superior vertebrae label if no init label found Returns ------- @@ -505,193 +456,65 @@ def iterative_label( output_seg_data = np.zeros_like(seg_data) - loc_data = loc and np.asanyarray(loc.dataobj).round().astype(np.uint8) - - # If localizer is provided, transform it to the segmentation space - if loc_data is not None: - loc_data = tio.Resample( - tio.ScalarImage(tensor=seg_data[None, ...], affine=seg.affine) - )( - tio.LabelMap(tensor=loc_data[None, ...], affine=loc.affine) - ).data.numpy()[0, ...].astype(np.uint8) - - binary_dilation_structure = ndi.iterate_structure(ndi.generate_binary_structure(3, 1), dilation_size) - - # Arrays to store the z indexes of the discs sorted superior to inferior - disc_sorted_z_indexes = [] - - # We run the same iterative algorithm for discs and vertebrae - for labels, extra_labels, step, init, default_sup, loc_labels, is_vert in ( - (disc_labels, [], output_disc_step, init_disc, default_superior_disc, loc_disc_labels, False), - (vertebrae_labels, vertebrae_extra_labels, output_vertebrae_step, init_vertebrae, default_superior_vertebrae, loc_vertebrae_labels, True)): - - # Skip if no labels are provided - if len(labels) == 0: - continue - - if is_vert: - _labels = [[_] for _ in labels] - else: - # For discs, combine all labels before label continue voxels since the discs not touching each other - _labels = [labels] - - # Init labeled segmentation - mask_labeled, num_labels = np.zeros_like(seg_data, dtype=np.uint32), 0 - - # For each label, find connected voxels and label them into separate labels - for l in _labels: - mask = np.isin(seg_data, l) + # Get sorted connected components superio-inferior (SI) for the disc labels + disc_mask_labeled, disc_num_labels, disc_sorted_labels, disc_sorted_z_indexes = _get_si_sorted_components( + seg, + disc_labels, + dilation_size, + ) - # Dilate the mask to combine small disconnected regions - mask_dilated = ndi.binary_dilation(mask, binary_dilation_structure) + # Get sorted connected components superio-inferior (SI) for the vertebrae labels + vert_mask_labeled, vert_num_labels, vert_sorted_labels, vert_sorted_z_indexes = _get_si_sorted_components( + seg, + vertebrae_labels, + dilation_size, + combine_labels=True, + ) - # Label the connected voxels in the dilated mask into separate labels - tmp_mask_labeled, tmp_num_labels = ndi.label(mask_dilated.astype(np.uint32), np.ones((3, 3, 3))) + # Combine sequential vertebrae labels based on some conditions + vert_mask_labeled, vert_num_labels, vert_sorted_labels, vert_sorted_z_indexes = _merge_vertebrae_labels( + seg, + vertebrae_labels, + vert_mask_labeled, + vert_num_labels, + vert_sorted_labels, + vert_sorted_z_indexes, + disc_sorted_z_indexes, + vertebrae_extra_labels, + ) - # Undo dilation - tmp_mask_labeled *= mask + # Get the first disc label + superior_disc_output_label = _get_superior_output_label( + seg, + loc, + disc_mask_labeled, + disc_sorted_labels, + init_disc, + output_disc_step, + loc_disc_labels, + default_superior_disc, + map_output_dict, + ) - # Add current labels to the labeled segmentation - if tmp_num_labels > 0: - mask_labeled[tmp_mask_labeled != 0] = tmp_mask_labeled[tmp_mask_labeled != 0] + num_labels - num_labels += tmp_num_labels + # Sort the combined disc+vert labels by their z-index + sorted_labels = vert_sorted_labels + disc_sorted_labels + sorted_z_indexes = vert_sorted_z_indexes + disc_sorted_z_indexes + is_vert = [True] * len(vert_sorted_labels) + [False] * len(disc_sorted_labels) - # If no label found, raise error - if num_labels == 0: - raise ValueError(f"Some label must be in the segmentation (labels: {labels})") + # Sort the labels by their z-index (reversed to go from superior to inferior) + sorted_z_indexes, sorted_labels, is_vert = zip(*sorted(zip(sorted_z_indexes, sorted_labels, is_vert))[::-1]) - # Reduce size of mask_labeled - if mask_labeled.max() < np.iinfo(np.uint8).max: - mask_labeled = mask_labeled.astype(np.uint8) - elif mask_labeled.max() < np.iinfo(np.uint16).max: - mask_labeled = mask_labeled.astype(np.uint16) + # Get the superior output label for the vertebrae based on the superior disc label + # For C1 and C2 we have to adjust the first vertebrae label by the position of the first disc with substraction of (is_vert.index(False) - 1) * output_vertebrae_step + superior_vert_output_label = output_c2 + (superior_disc_output_label - output_c2c3) * (output_vertebrae_step / output_disc_step) - (is_vert.index(False) - 1) * output_vertebrae_step - # Get the z index of the center of mass for each label - canonical_mask_labeled = np.asanyarray(nib.as_closest_canonical(nib.Nifti1Image(mask_labeled, seg.affine, seg.header)).dataobj).round().astype(mask_labeled.dtype) - mask_labeled_z_indexes = [_[-1] for _ in ndi.center_of_mass(canonical_mask_labeled != 0, canonical_mask_labeled, range(1, num_labels + 1))] + # Label the vertebrae with the output labels superio-inferior + for i in range(vert_num_labels): + output_seg_data[vert_mask_labeled == vert_sorted_labels[i]] = superior_vert_output_label + output_vertebrae_step * i - # Sort the labels by their z-index (reversed to go from superior to inferior) - sorted_z_indexes, sorted_labels = zip(*sorted(zip(mask_labeled_z_indexes,range(1,num_labels+1)))[::-1]) - - # In this part we loop over the labels superior to inferior and combine sequential vertebrae labels based on some conditions - if is_vert: - # Combine sequential vertebrae labels if they have the same value in the original segmentation - # This is useful when part part of the vertebrae is not connected to the main part but have the same odd/even value - if step_diff_label and len(labels) - len(init) > 1: - new_sorted_labels = [] - - # Store the previous label and the original label of the previous label - prev_l, prev_orig_label = 0, 0 - - # Loop over the sorted labels - for l in sorted_labels: - # Get the original label of the current label - curr_orig_label = seg_data[mask_labeled == l].flat[0] - - # Combine the current label with the previous label if they have the same original label - if curr_orig_label == prev_orig_label: - # Combine the current label with the previous label - mask_labeled[mask_labeled == l] = prev_l - num_labels -= 1 - - else: - # Add the current label to the new sorted labels - new_sorted_labels.append(l) - prev_l, prev_orig_label = l, curr_orig_label - - # Get the z index of the center of mass for each label - canonical_mask_labeled = np.asanyarray(nib.as_closest_canonical(nib.Nifti1Image(mask_labeled, seg.affine, seg.header)).dataobj).round().astype(mask_labeled.dtype) - mask_labeled_z_indexes = [_[-1] for _ in ndi.center_of_mass(canonical_mask_labeled != 0, canonical_mask_labeled, new_sorted_labels)] - - # Sort the labels by their z-index (reversed to go from superior to inferior) - sorted_z_indexes, sorted_labels = zip(*sorted(zip(mask_labeled_z_indexes, new_sorted_labels))[::-1]) - - # Reduce size of mask_labeled - if mask_labeled.max() < np.iinfo(np.uint8).max: - mask_labeled = mask_labeled.astype(np.uint8) - elif mask_labeled.max() < np.iinfo(np.uint16).max: - mask_labeled = mask_labeled.astype(np.uint16) - - # Combine sequential vertebrae labels if there is no disc between them - if step_diff_disc and len(disc_sorted_z_indexes) > 0: - new_sorted_labels = [] - - # Store the previous label and the z index of the previous label - prev_l, prev_z = 0, 0 - - for l, z in zip(sorted_labels, sorted_z_indexes): - # Do not combine first and last vertebrae since it can be C1 or only contain the spinous process - if l not in sorted_labels[:2] and l != sorted_labels[-1] and prev_l > 0 and not any(z < _ < prev_z for _ in disc_sorted_z_indexes): - # Combine the current label with the previous label - mask_labeled[mask_labeled == l] = prev_l - num_labels -= 1 - - else: - # Add the current label to the new sorted labels - new_sorted_labels.append(l) - prev_l, prev_z = l, z - - sorted_labels = new_sorted_labels - - # Reduce size of mask_labeled - if mask_labeled.max() < np.iinfo(np.uint8).max: - mask_labeled = mask_labeled.astype(np.uint8) - elif mask_labeled.max() < np.iinfo(np.uint16).max: - mask_labeled = mask_labeled.astype(np.uint16) - - else: - # Save the z indexes of the discs - disc_sorted_z_indexes = sorted_z_indexes - - # Find the most superior label in the segmentation - first_label = 0 - for k, v in init.items(): - if k in seg_data: - first_label = v - step * sorted_labels.index(mask_labeled[seg_data == k].flat[0]) - break - - # If no init label found, set it from the localizer - if first_label == 0 and loc_data is not None: - # Make mask for the intersection of the localizer labels and the labels in the segmentation - mask = np.isin(loc_data, loc_labels) * np.isin(mask_labeled, sorted_labels) - - # Get the first label from sorted_labels that is in the localizer specified labels - mask_labeled_masked = mask * mask_labeled - first_sorted_labels_in_loc = next(np.array(sorted_labels)[np.isin(sorted_labels, mask_labeled_masked)].flat, 0) - - if first_sorted_labels_in_loc > 0: - # Get the target label for first_sorted_labels_in_loc - the label from the localizer that has the most voxels in it - loc_data_masked = mask * loc_data - target = np.argmax(np.bincount(loc_data_masked[mask_labeled_masked == first_sorted_labels_in_loc].flat)) - # If target in map_output_dict reverse it from the reversed map - # TODO Edge case if multiple keys have the same value, not used in the current implementation - target = {v: k for k, v in map_output_dict.items()}.get(target, target) - first_label = target - step * sorted_labels.index(first_sorted_labels_in_loc) - - # If no init label found, set the default superior label - if first_label == 0 and default_sup > 0: - first_label = default_sup - - # If no init label found, print error - if first_label == 0: - raise ValueError(f"Some initiation label must be in the segmentation (init: {list(init.keys())})") - - # Combine extra labels with adjacent vertebrae labels - if len(extra_labels) > 0: - mask_extra = np.isin(seg_data, extra_labels) - - # Loop over vertebral labels (from inferior because the transverse process make it steal from above) - for i in range(num_labels - 1, -1, -1): - # Mkae mask for the current vertebrae with filling the holes and dilating it - mask = fill(mask_labeled == sorted_labels[i]) - mask = ndi.binary_dilation(mask, ndi.iterate_structure(ndi.generate_binary_structure(3, 1), 1)) - - # Add the intersection of the mask with the extra labels to the current verebrae - mask_labeled[mask_extra * mask] = sorted_labels[i] - - # Set the output value for the current label - for i in range(num_labels): - output_seg_data[mask_labeled == sorted_labels[i]] = first_label + step * i + # Label the discs with the output labels superio-inferior + for i in range(disc_num_labels): + output_seg_data[disc_mask_labeled == disc_sorted_labels[i]] = superior_disc_output_label + output_disc_step * i # Use the map to map labels from the iteative algorithm output, to the final output # This is useful to map the vertebrae label from the iteative algorithm output to the special sacrum label @@ -718,7 +541,234 @@ def iterative_label( return output_seg -def fill(mask): +def _get_si_sorted_components( + seg, + labels, + dilation_size=1, + combine_labels=False, + ): + ''' + Get sorted connected components superio-inferior (SI) for the given labels in the segmentation. + ''' + seg_data = np.asanyarray(seg.dataobj).round().astype(np.uint8) + + binary_dilation_structure = ndi.iterate_structure(ndi.generate_binary_structure(3, 1), dilation_size) + + # Skip if no labels are provided + if len(labels) == 0: + return None, 0, [], [] + + if combine_labels: + _labels = [[_] for _ in labels] + else: + # For discs, combine all labels before label continue voxels since the discs not touching each other + _labels = [labels] + + # Init labeled segmentation + mask_labeled, num_labels = np.zeros_like(seg_data, dtype=np.uint32), 0 + + # For each label, find connected voxels and label them into separate labels + for l in _labels: + mask = np.isin(seg_data, l) + + # Dilate the mask to combine small disconnected regions + mask_dilated = ndi.binary_dilation(mask, binary_dilation_structure) + + # Label the connected voxels in the dilated mask into separate labels + tmp_mask_labeled, tmp_num_labels = ndi.label(mask_dilated.astype(np.uint32), np.ones((3, 3, 3))) + + # Undo dilation + tmp_mask_labeled *= mask + + # Add current labels to the labeled segmentation + if tmp_num_labels > 0: + mask_labeled[tmp_mask_labeled != 0] = tmp_mask_labeled[tmp_mask_labeled != 0] + num_labels + num_labels += tmp_num_labels + + # If no label found, raise error + if num_labels == 0: + raise ValueError(f"Some label must be in the segmentation (labels: {labels})") + + # Reduce size of mask_labeled + if mask_labeled.max() < np.iinfo(np.uint8).max: + mask_labeled = mask_labeled.astype(np.uint8) + elif mask_labeled.max() < np.iinfo(np.uint16).max: + mask_labeled = mask_labeled.astype(np.uint16) + + # Get the z index of the center of mass for each label + canonical_mask_labeled = np.asanyarray(nib.as_closest_canonical(nib.Nifti1Image(mask_labeled, seg.affine, seg.header)).dataobj).round().astype(mask_labeled.dtype) + mask_labeled_z_indexes = [_[-1] for _ in ndi.center_of_mass(canonical_mask_labeled != 0, canonical_mask_labeled, range(1, num_labels + 1))] + + # Sort the labels by their z-index (reversed to go from superior to inferior) + sorted_z_indexes, sorted_labels = zip(*sorted(zip(mask_labeled_z_indexes,range(1,num_labels+1)))[::-1]) + + return mask_labeled, num_labels, list(sorted_labels), list(sorted_z_indexes) + +def _merge_vertebrae_labels( + seg, + labels, + mask_labeled, + num_labels, + sorted_labels, + sorted_z_indexes, + disc_sorted_z_indexes, + extra_labels, + ): + ''' + Combine sequential vertebrae labels based on some conditions. + ''' + if num_labels == 0: + return mask_labeled, num_labels, sorted_labels, sorted_z_indexes + + seg_data = np.asanyarray(seg.dataobj).round().astype(np.uint8) + # Combine sequential vertebrae labels if they have the same value in the original segmentation + # This is useful when part part of the vertebrae is not connected to the main part but have the same odd/even value + if len(labels) > 1: + new_sorted_labels = [] + + # Store the previous label and the original label of the previous label + prev_l, prev_orig_label = 0, 0 + + # Loop over the sorted labels + for l in sorted_labels: + # Get the original label of the current label + curr_orig_label = seg_data[mask_labeled == l].flat[0] + + # Combine the current label with the previous label if they have the same original label + if curr_orig_label == prev_orig_label: + # Combine the current label with the previous label + mask_labeled[mask_labeled == l] = prev_l + num_labels -= 1 + + else: + # Add the current label to the new sorted labels + new_sorted_labels.append(l) + prev_l, prev_orig_label = l, curr_orig_label + + # Get the z index of the center of mass for each label + canonical_mask_labeled = np.asanyarray(nib.as_closest_canonical(nib.Nifti1Image(mask_labeled, seg.affine, seg.header)).dataobj).round().astype(mask_labeled.dtype) + mask_labeled_z_indexes = [_[-1] for _ in ndi.center_of_mass(canonical_mask_labeled != 0, canonical_mask_labeled, new_sorted_labels)] + + # Sort the labels by their z-index (reversed to go from superior to inferior) + sorted_z_indexes, sorted_labels = zip(*sorted(zip(mask_labeled_z_indexes, new_sorted_labels))[::-1]) + + # Reduce size of mask_labeled + if mask_labeled.max() < np.iinfo(np.uint8).max: + mask_labeled = mask_labeled.astype(np.uint8) + elif mask_labeled.max() < np.iinfo(np.uint16).max: + mask_labeled = mask_labeled.astype(np.uint16) + + # Combine sequential vertebrae labels if there is no disc between them + if len(disc_sorted_z_indexes) > 0: + new_sorted_labels = [] + + # Store the previous label and the z index of the previous label + prev_l, prev_z = 0, 0 + + for l, z in zip(sorted_labels, sorted_z_indexes): + # Do not combine first and last vertebrae since it can be C1 or only contain the spinous process + if l not in sorted_labels[:2] and l != sorted_labels[-1] and prev_l > 0 and not any(z < _ < prev_z for _ in disc_sorted_z_indexes): + # Combine the current label with the previous label + mask_labeled[mask_labeled == l] = prev_l + num_labels -= 1 + + else: + # Add the current label to the new sorted labels + new_sorted_labels.append(l) + prev_l, prev_z = l, z + + sorted_labels = new_sorted_labels + + # Reduce size of mask_labeled + if mask_labeled.max() < np.iinfo(np.uint8).max: + mask_labeled = mask_labeled.astype(np.uint8) + elif mask_labeled.max() < np.iinfo(np.uint16).max: + mask_labeled = mask_labeled.astype(np.uint16) + + # Combine extra labels with adjacent vertebrae labels + if len(extra_labels) > 0: + mask_extra = np.isin(seg_data, extra_labels) + + # Loop over vertebral labels (from inferior because the transverse process make it steal from above) + for i in range(num_labels - 1, -1, -1): + # Mkae mask for the current vertebrae with filling the holes and dilating it + mask = _fill(mask_labeled == sorted_labels[i]) + mask = ndi.binary_dilation(mask, ndi.iterate_structure(ndi.generate_binary_structure(3, 1), 1)) + + # Add the intersection of the mask with the extra labels to the current verebrae + mask_labeled[mask_extra * mask] = sorted_labels[i] + + # Get the z index of the center of mass for each label + canonical_mask_labeled = np.asanyarray(nib.as_closest_canonical(nib.Nifti1Image(mask_labeled, seg.affine, seg.header)).dataobj).round().astype(mask_labeled.dtype) + mask_labeled_z_indexes = [_[-1] for _ in ndi.center_of_mass(canonical_mask_labeled != 0, canonical_mask_labeled, sorted_labels)] + + # Sort the labels by their z-index (reversed to go from superior to inferior) + sorted_z_indexes, sorted_labels = zip(*sorted(zip(mask_labeled_z_indexes, new_sorted_labels))[::-1]) + + return mask_labeled, num_labels, list(sorted_labels), list(sorted_z_indexes) + +def _get_superior_output_label( + seg, + loc, + mask_labeled, + sorted_labels, + init, + step, + loc_labels, + default_superior, + map_output_dict, + ): + ''' + Get the first label for the iterative labeling algorithm. + ''' + seg_data = np.asanyarray(seg.dataobj).round().astype(np.uint8) + + loc_data = loc and np.asanyarray(loc.dataobj).round().astype(np.uint8) + + # If localizer is provided, transform it to the segmentation space + if loc_data is not None: + loc_data = tio.Resample( + tio.ScalarImage(tensor=seg_data[None, ...], affine=seg.affine) + )( + tio.LabelMap(tensor=loc_data[None, ...], affine=loc.affine) + ).data.numpy()[0, ...].astype(np.uint8) + + # Find the most superior label in the segmentation + superior_output_label = 0 + for k, v in init.items(): + if k in seg_data: + superior_output_label = v - step * sorted_labels.index(np.argmax(np.bincount(mask_labeled[seg_data == k].flat))) + break + + # If no init label found, set it from the localizer + if superior_output_label == 0 and loc_data is not None: + # Make mask for the intersection of the localizer labels and the labels in the segmentation + mask = np.isin(loc_data, loc_labels) * np.isin(mask_labeled, sorted_labels) + + # Get the first label from sorted_labels that is in the localizer specified labels + mask_labeled_masked = mask * mask_labeled + first_sorted_labels_in_loc = next(np.array(sorted_labels)[np.isin(sorted_labels, mask_labeled_masked)].flat, 0) + + if first_sorted_labels_in_loc > 0: + # Get the target label for first_sorted_labels_in_loc - the label from the localizer that has the most voxels in it + loc_data_masked = mask * loc_data + target = np.argmax(np.bincount(loc_data_masked[mask_labeled_masked == first_sorted_labels_in_loc].flat)) + # If target in map_output_dict reverse it from the reversed map + # TODO Edge case if multiple keys have the same value, not used in the current implementation + target = {v: k for k, v in map_output_dict.items()}.get(target, target) + superior_output_label = target - step * sorted_labels.index(first_sorted_labels_in_loc) + + # If no init label found, set the default superior label + if superior_output_label == 0 and default_superior > 0: + superior_output_label = default_superior + + # If no init label found, print error + if superior_output_label == 0: + raise ValueError(f"Some initiation label must be in the segmentation (init: {list(init.keys())})") + + return superior_output_label + +def _fill(mask): ''' Fill holes in a binary mask