Skip to content

Commit

Permalink
Merge pull request #76 from neuropoly/nm/extract_posterior_tip
Browse files Browse the repository at this point in the history
Extract posterior tip of the discs during level extraction
  • Loading branch information
NathanMolinier authored Nov 20, 2024
2 parents f0bd108 + 106b34d commit 88377f2
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 33 deletions.
2 changes: 2 additions & 0 deletions totalspineseg/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,6 +477,8 @@ def main():
output_path / 'step1_levels',
canal_labels=[1, 2],
disc_labels=list(range(63, 68)) + list(range(71, 83)) + list(range(91, 96)) + [100],
c1_label=11,
c2_label=50,
overwrite=True,
max_workers=max_workers,
quiet=quiet,
Expand Down
89 changes: 56 additions & 33 deletions totalspineseg/utils/extract_levels.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,11 @@ def main():
)
parser.add_argument(
'--c1-label', type=int, default=0,
help='The label for C1 vertebra in the segmentation, if provided it will be used to determine if C1 is in the segmentation.'
help='The label for C1 vertebra in the segmentation, if provided it will be used to extract the level 1.'
)
parser.add_argument(
'--c2-label', type=int, default=0,
help='The label for C2 vertebra in the segmentation (this label may also be used with other vertebrae), if provided it will be used to extract the level 1.'
)
parser.add_argument(
'--overwrite', '-r', action="store_true", default=False,
Expand All @@ -82,6 +86,7 @@ def main():
canal_labels = args.canal_labels
disc_labels = [l for raw in args.disc_labels for l in (raw if isinstance(raw, list) else [raw])]
c1_label = args.c1_label
c2_label = args.c2_label
overwrite = args.overwrite
max_workers = args.max_workers
quiet = args.quiet
Expand All @@ -98,6 +103,7 @@ def main():
canal_labels = {canal_labels}
disc_labels = {disc_labels}
c1_label = {c1_label}
c2_label = {c2_label}
overwrite = {overwrite}
max_workers = {max_workers}
quiet = {quiet}
Expand All @@ -112,6 +118,7 @@ def main():
canal_labels=canal_labels,
disc_labels=disc_labels,
c1_label=c1_label,
c2_label=c2_label,
overwrite=overwrite,
max_workers=max_workers,
quiet=quiet,
Expand All @@ -126,6 +133,7 @@ def extract_levels_mp(
canal_labels=[],
disc_labels=[],
c1_label=0,
c2_label=0,
overwrite=False,
max_workers=mp.cpu_count(),
quiet=False,
Expand All @@ -148,6 +156,7 @@ def extract_levels_mp(
canal_labels=canal_labels,
disc_labels=disc_labels,
c1_label=c1_label,
c2_label=c2_label,
overwrite=overwrite,
),
seg_path_list,
Expand All @@ -163,6 +172,7 @@ def _extract_levels(
canal_labels=[],
disc_labels=[],
c1_label=0,
c2_label=0,
overwrite=False,
):
'''
Expand All @@ -184,6 +194,7 @@ def _extract_levels(
canal_labels=canal_labels,
disc_labels=disc_labels,
c1_label=c1_label,
c2_label=c2_label,
)
except ValueError as e:
output_seg_path.is_file() and output_seg_path.unlink()
Expand All @@ -208,12 +219,14 @@ def extract_levels(
canal_labels=[],
disc_labels=[],
c1_label=0,
c2_label=0,
):
'''
Extract vertebrae levels from Spinal Canal and Discs.
The function extracts the vertebrae levels from the input segmentation by finding the closest voxel in the canal centerline to the middle of each disc.
The superior voxels in the canal centerline are set to 1 and the middle voxels between C2-C3 and the superior voxels are set to 2.
The function extracts the vertebrae levels from the input segmentation by finding the closest voxel in the canal anteriorline to the middle of each disc.
The superior voxels of the top vertebrae is set to 1 and the middle voxels between C2-C3 and the superior voxels are set to 2.
The generated labeling convention follows the one from SCT (https://spinalcordtoolbox.com/stable/user_section/tutorials/vertebral-labeling/labeling-conventions.html)
Parameters
----------
Expand Down Expand Up @@ -245,17 +258,16 @@ def extract_levels(
if not np.any(mask_canal):
raise ValueError(f"No canal labels found in the segmentation.")

# Create a mask the canal centerline by finding the middle voxels in x and y axes for each z index
# Create a canal anteriorline shifted toward the posterior tip by finding the middle voxels in x and the maximum voxels in y for each z index
mask_min_x_indices = np.min(indices[0], where=mask_canal, initial=np.iinfo(indices.dtype).max, keepdims=True, axis=(0, 1))
mask_max_x_indices = np.max(indices[0], where=mask_canal, initial=np.iinfo(indices.dtype).min, keepdims=True, axis=(0, 1))
mask_mid_x = indices[0] == ((mask_min_x_indices + mask_max_x_indices) // 2)
mask_min_y_indices = np.min(indices[1], where=mask_canal, initial=np.iinfo(indices.dtype).max, keepdims=True, axis=(0, 1))
mask_max_y_indices = np.max(indices[1], where=mask_canal, initial=np.iinfo(indices.dtype).min, keepdims=True, axis=(0, 1))
mask_mid_y = indices[1] == ((mask_min_y_indices + mask_max_y_indices) // 2)
mask_canal_centerline = mask_canal * mask_mid_x * mask_mid_y
mask_max_y = indices[1] == mask_max_y_indices
mask_canal_anteriorline = mask_canal * mask_mid_x * mask_max_y

# Get the indices of the canal centerline
canal_centerline_indices = np.array(np.nonzero(mask_canal_centerline)).T
# Get the indices of the canal anteriorline
canal_anteriorline_indices = np.array(np.nonzero(mask_canal_anteriorline)).T

# Get the labels of the discs in the segmentation
disc_labels_in_seg = np.array(disc_labels)[np.isin(disc_labels, seg_data)]
Expand Down Expand Up @@ -283,45 +295,56 @@ def extract_levels(
# Make the discs_indices 2D array
discs_indices = np.array(discs_indices).T

# Calculate the distance of each disc voxel to each canal centerline voxel
discs_distances_from_all_centerline = np.linalg.norm(discs_indices[:, None, :] - canal_centerline_indices[None, ...], axis=2)
# Calculate the distance of each disc voxel to each canal anteriorline voxel
discs_distances_from_all_anteriorline = np.linalg.norm(discs_indices[:, None, :] - canal_anteriorline_indices[None, ...], axis=2)

# Find the minimum distance for each disc voxel and the corresponding canal centerline index
discs_distance_from_centerline = np.min(discs_distances_from_all_centerline, axis=1)
discs_centerline_indices = canal_centerline_indices[np.argmin(discs_distances_from_all_centerline, axis=1)]
# Find the minimum distance for each disc voxel and the corresponding canal anteriorline index
discs_distance_from_anteriorline = np.min(discs_distances_from_all_anteriorline, axis=1)
discs_anteriorline_indices = canal_anteriorline_indices[np.argmin(discs_distances_from_all_anteriorline, axis=1)]

# Find the closest canal centerline voxel to each disc label
disc_labels_centerline_indices = [discs_centerline_indices[discs_indices_labels == label][np.argmin(discs_distance_from_centerline[discs_indices_labels == label])] for label in disc_labels_in_seg]
# Find the closest voxel to the canal anteriorline for each disc label (posterior tip)
disc_labels_anteriorline_indices = [discs_anteriorline_indices[discs_indices_labels == label][np.argmin(discs_distance_from_anteriorline[discs_indices_labels == label])] for label in disc_labels_in_seg]

# Set the output labels to the closest canal centerline voxels
for idx, label in zip(disc_labels_centerline_indices, disc_labels_in_seg):
# Set the output labels to the closest voxel to the canal anteriorline for each disc
for idx, label in zip(disc_labels_anteriorline_indices, disc_labels_in_seg):
output_seg_data[tuple(idx)] = map_labels[label]

# If C2-C3 is in the segmentation, set 1 and 2 to the superior voxels in the canal centerline and the middle voxels between C2-C3 and the superior voxels
if 3 in output_seg_data:
# If C2-C3 and C1 are in the segmentation, set 1 and 2
if 3 in output_seg_data and c2_label != 0 and c1_label != 0 and all(np.isin([c1_label, c2_label], seg_data)):
# Place 1 at the top of C2 if C1 is visible in the image
# Find the location of the C2-C3 disc
c2c3_index = np.unravel_index(np.argmax(output_seg_data == 3), seg_data.shape)

# Find the location of the superior voxels in the canal centerline
canal_superior_index = np.unravel_index(np.argmax(mask_canal_centerline * indices[2]), seg_data.shape)
# Find the maximum coordinate of the vertebra C1
c1_coords = np.where(seg_data == c1_label)
c1_z_max_index = np.max(c1_coords[2])

# Extract coordinate of the vertebrae
# The coordinate of 1 needs to be in the same slice as 3 but below the max index of C1
vert_coords = np.where(seg_data[c2c3_index[0],:,:c1_z_max_index] == c2_label)

# Check if not empty
if len(vert_coords[1]) > 0:
# Find top pixel of the vertebrae
argmax_z = np.argmax(vert_coords[1])
top_vert_voxel = tuple([c2c3_index[0]]+[vert_coords[i][argmax_z] for i in range(2)])

if (c1_label > 0 and c1_label in seg_data) or (c1_label == 0 and canal_superior_index[2] - c2c3_index[2] >= 8 and output_seg_data.shape[2] - canal_superior_index[2] >= 2):
# If C1 is in the segmentation or C2-C3 at least 8 voxels below the top of the canal and the top of the canal is at least 2 voxels from the top of the image
# Set 1 to the superior voxels
output_seg_data[canal_superior_index] = 1
# Set 1 to the superior voxels and project onto the anterior line
top_vert_distances_from_all_anteriorline = np.linalg.norm(top_vert_voxel - canal_anteriorline_indices[None, ...], axis=2)
top_vert_index_anteriorline = canal_anteriorline_indices[np.argmin(top_vert_distances_from_all_anteriorline, axis=1)]
output_seg_data[tuple(top_vert_index_anteriorline[0])] = 1

# Set 2 to the middle voxels between C2-C3 and the superior voxels
c1c2_z_index = (canal_superior_index[2] + c2c3_index[2]) // 2
c1c2_index = np.unravel_index(np.argmax(mask_canal_centerline * (indices[2] == c1c2_z_index)), seg_data.shape)
output_seg_data[c1c2_index] = 2
c1c2_index = tuple([(top_vert_voxel[i] + c2c3_index[i]) // 2 for i in range(3)])

elif canal_superior_index[2] - c2c3_index[2] >= 4:
# If C2-C3 at least 4 voxels below the top of the canal
output_seg_data[canal_superior_index] = 2
# Project 2 on the anterior line
c1c2_distances_from_all_anteriorline = np.linalg.norm(c1c2_index - canal_anteriorline_indices[None, ...], axis=2)
c1c2_index_anteriorline = canal_anteriorline_indices[np.argmin(c1c2_distances_from_all_anteriorline, axis=1)]
output_seg_data[tuple(c1c2_index_anteriorline[0])] = 2

output_seg = nib.Nifti1Image(output_seg_data, seg.affine, seg.header)

return output_seg

if __name__ == '__main__':
main()
main()

0 comments on commit 88377f2

Please sign in to comment.