Skip to content

Commit

Permalink
Merge pull request #11 from Hendrik-code/fix_merged_vertebra
Browse files Browse the repository at this point in the history
Revamped processing techniques
  • Loading branch information
Hendrik-code authored Jun 5, 2024
2 parents 5f317c6 + aefe7a7 commit cee3a6f
Show file tree
Hide file tree
Showing 10 changed files with 630 additions and 119 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ torchmetrics = "^1.1.2"
tqdm = "^4.66.1"
einops= "^0.6.1"
nnunetv2 = "2.2"
tptbox = "^0.0.9"
tptbox = "^0.1.0"
antspyx = "*"
rich = "^13.6.0"

Expand Down
8 changes: 4 additions & 4 deletions spineps/entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,8 +210,8 @@ def run_sample(opt: Namespace):
"override_postpair": opt.override_postpair,
"override_ctd": opt.override_ctd,
#
"do_crop_semantic": not opt.nocrop,
"proc_n4correction": not opt.non4,
"proc_sem_crop_input": not opt.nocrop,
"proc_sem_n4_bias_correction": not opt.non4,
"ignore_compatibility_issues": opt.ignore_inference_compatibility,
"verbose": opt.verbose,
}
Expand Down Expand Up @@ -284,8 +284,8 @@ def run_dataset(opt: Namespace):
"ignore_inference_compatibility": opt.ignore_inference_compatibility,
"ignore_bids_filter": opt.ignore_bids_filter,
#
"do_crop_semantic": not opt.nocrop,
"proc_n4correction": not opt.non4,
"proc_sem_crop_input": not opt.nocrop,
"proc_sem_n4_bias_correction": not opt.non4,
"snapshot_copy_folder": opt.save_snaps_folder,
"verbose": opt.verbose,
}
Expand Down
349 changes: 322 additions & 27 deletions spineps/phase_instance.py

Large diffs are not rendered by default.

80 changes: 65 additions & 15 deletions spineps/phase_post.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
np_bbox_binary,
np_center_of_mass,
np_connected_components,
np_contacts,
np_count_nonzero,
np_dilate_msk,
np_extract_label,
Expand All @@ -17,6 +18,7 @@
)

from spineps.seg_pipeline import logger, vertebra_subreg_labels
from spineps.utils.proc_functions import fix_wrong_posterior_instance_label


def phase_postprocess_combined(
Expand All @@ -25,8 +27,11 @@ def phase_postprocess_combined(
debug_data: dict | None,
labeling_offset: int = 0,
proc_assign_missing_cc: bool = True,
proc_clean_inst_by_sem: bool = True,
n_vert_bodies: int | None = None,
process_vertebra_inconsistency: bool = True,
process_merge_vertebra: bool = True,
proc_vertebra_inconsistency: bool = True,
proc_assign_posterior_instance_label: bool = True,
verbose: bool = False,
) -> tuple[NII, NII]:
logger.print("Post process", Log_Type.STAGE)
Expand All @@ -40,7 +45,9 @@ def phase_postprocess_combined(
if debug_data is None:
debug_data = {}
#
vert_nii.apply_mask(seg_nii, inplace=True)

if proc_clean_inst_by_sem:
vert_nii.apply_mask(seg_nii, inplace=True)
crop_slices = seg_nii.compute_crop(dist=2)
vert_uncropped_arr = np.zeros(vert_nii.shape, dtype=seg_nii.dtype)
seg_uncropped_arr = np.zeros(vert_nii.shape, dtype=seg_nii.dtype)
Expand All @@ -59,17 +66,24 @@ def phase_postprocess_combined(
verbose=verbose,
)

if process_vertebra_inconsistency:
if process_merge_vertebra and Location.Vertebra_Disc.value in seg_nii_cleaned.unique():
detect_and_solve_merged_vertebra(seg_nii_cleaned, whole_vert_nii_cleaned)

if proc_assign_posterior_instance_label:
whole_vert_nii_cleaned = fix_wrong_posterior_instance_label(seg_nii_cleaned, seg_inst=whole_vert_nii_cleaned, logger=logger)

if proc_vertebra_inconsistency:
# Assigns superior/inferior based on instance label overlap
assign_vertebra_inconsistency(seg_nii_cleaned, whole_vert_nii_cleaned)

# Label vertebra top -> down
whole_vert_nii_cleaned, vert_labels = label_instance_top_to_bottom(whole_vert_nii_cleaned)
if labeling_offset != 0:
whole_vert_nii_cleaned.map_labels_({i: i + 1 for i in vert_labels if i != 0}, verbose=verbose)
whole_vert_nii_cleaned, vert_labels = label_instance_top_to_bottom(whole_vert_nii_cleaned, labeling_offset=labeling_offset)
# if labeling_offset != 0:
# whole_vert_nii_cleaned.map_labels_({i: i + 1 for i in vert_labels if i != 0}, verbose=verbose)
logger.print(f"Labeled {len(vert_labels)} vertebra instances from top to bottom")

vert_arr_cleaned = add_ivd_ep_vert_label(whole_vert_nii_cleaned, seg_nii_cleaned)
#
#
vert_arr_cleaned[seg_nii_cleaned.get_seg_array() == v_name2idx["S1"]] = v_name2idx["S1"]
###############
# Uncrop
Expand Down Expand Up @@ -305,20 +319,14 @@ def find_nearest_lower(seq, x):
return max(values_lower)


def label_instance_top_to_bottom(vert_nii: NII):
def label_instance_top_to_bottom(vert_nii: NII, labeling_offset: int = 0):
ori = vert_nii.orientation
vert_nii.reorient_()
vert_arr = vert_nii.get_seg_array()
com_i = np_center_of_mass(vert_arr)
# Old, more precise version (but takes longer)
# comb = {}
# for i in present_labels:
# arr_i = vert_arr.copy()
# arr_i[arr_i != i] = 0
# comb[i] = center_of_mass(arr_i)
comb_l = list(zip(com_i.keys(), com_i.values(), strict=True))
comb_l.sort(key=lambda a: a[1][1]) # PIR
com_map = {comb_l[idx][0]: idx + 1 for idx in range(len(comb_l))}
com_map = {comb_l[idx][0]: idx + 1 + labeling_offset for idx in range(len(comb_l))}

vert_nii.map_labels_(com_map, verbose=False)
vert_nii.reorient_(ori)
Expand Down Expand Up @@ -374,3 +382,45 @@ def assign_vertebra_inconsistency(seg_nii: NII, vert_nii: NII):
)

vert_nii.set_array_(vert_arr)


def detect_and_solve_merged_vertebra(seg_nii: NII, vert_nii: NII):
seg_sem = seg_nii.map_labels({Location.Endplate.value: Location.Vertebra_Disc.value}, verbose=False)
# get all ivd CCs from seg_sem

stats = {}
# Map IVDS
subreg_cc, subreg_cc_n = seg_sem.get_segmentation_connected_components(labels=Location.Vertebra_Disc.value)
subreg_cc = subreg_cc[Location.Vertebra_Disc.value] + 100
subreg_cc_n = subreg_cc_n[Location.Vertebra_Disc.value]

coms = np_center_of_mass(subreg_cc)
volumes = np_volume(subreg_cc)
stats = {i: (g[1], True, volumes[i]) for i, g in coms.items()}

vert_coms = vert_nii.center_of_masses()
vert_volumes = vert_nii.volumes()

for i, g in vert_coms.items():
stats[i] = (g[1], False, vert_volumes[i])

stats_by_height = dict(sorted(stats.items(), key=lambda x: x[1][0]))
stats_by_height_keys = list(stats_by_height.keys())

# detect C2 split into two components
first_key, second_key = stats_by_height_keys[0], stats_by_height_keys[1]
first_stats, second_stats = stats_by_height[first_key], stats_by_height[second_key]
if first_stats[1] is False and second_stats[1] is False: # noqa: SIM102
# both vertebra
if first_stats[2] < 0.5 * second_stats[2]:
# first is significantly smaller than second and they are close in height
# how many pixels are touching
vert_firsttwo_arr = vert_nii.extract_label(first_key).get_seg_array()
vert_firsttwo_arr2 = vert_nii.extract_label(second_key).get_seg_array()
vert_firsttwo_arr += vert_firsttwo_arr2 + 1
contacts = np_contacts(vert_firsttwo_arr, connectivity=3)
if contacts[(1, 2)] > 20:
logger.print("Found first two instance weird, will merge", Log_Type.STRANGE)
vert_nii.map_labels_({first_key: second_key}, verbose=False)

return seg_nii, vert_nii
8 changes: 4 additions & 4 deletions spineps/phase_pre.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ def preprocess_input(
mri_nii: NII,
debug_data: dict, # noqa: ARG001
pad_size: int = 4,
do_n4: bool = True,
do_crop: bool = True,
proc_do_n4_bias_correction: bool = True,
proc_crop_input: bool = True,
verbose: bool = False,
) -> tuple[NII | None, ErrCode]:
logger.print("Prepare input image", Log_Type.STAGE)
Expand All @@ -25,7 +25,7 @@ def preprocess_input(
try:
# Enforce to range [0, 1500]
mri_nii.normalize_to_range_(min_value=0, max_value=9000, verbose=logger)
crop = mri_nii.compute_crop(dist=0) if do_crop else (slice(None, None), slice(None, None), slice(None, None))
crop = mri_nii.compute_crop(dist=0) if proc_crop_input else (slice(None, None), slice(None, None), slice(None, None))
except ValueError:
logger.print("Image Nifty is empty, skip this", Log_Type.FAIL)
return None, ErrCode.EMPTY
Expand All @@ -34,7 +34,7 @@ def preprocess_input(
logger.print(f"Crop down from {mri_nii.shape} to {cropped_nii.shape}", verbose=verbose)

# N4 Bias field correction
if do_n4:
if proc_do_n4_bias_correction:
n4_start = perf_counter()
cropped_nii, _ = n4_bias(cropped_nii) # PIR
logger.print(f"N4 Bias field correction done in {perf_counter() - n4_start} sec", verbose=True)
Expand Down
78 changes: 59 additions & 19 deletions spineps/phase_semantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,11 @@
def predict_semantic_mask(
mri_nii: NII,
model: Segmentation_Model,
debug_data: dict, # noqa: ARG001
fill_holes: bool = True,
clean_artifacts: bool = True,
debug_data: dict,
proc_fill_3d_holes: bool = True,
proc_clean_beyond_largest_bounding_box: bool = True,
proc_remove_inferior_beyond_canal: bool = True,
proc_clean_small_cc_artifacts: bool = True,
verbose: bool = False,
) -> tuple[NII | None, NII | None, np.ndarray | None, ErrCode]:
"""Predicts the semantic mask, takes care of rescaling, and back
Expand Down Expand Up @@ -45,10 +47,18 @@ def predict_semantic_mask(
unc_nii = results.get(OutputType.unc, None)
softmax_logits = results[OutputType.softmax_logits]

logger.print("Post-process semantic mask...")

debug_data["sem_raw"] = seg_nii.copy()

if len(seg_nii.unique()) == 0:
logger.print("Subregion mask is empty, skip this", Log_Type.FAIL)
return seg_nii, unc_nii, softmax_logits, ErrCode.EMPTY
if clean_artifacts:

if proc_remove_inferior_beyond_canal:
seg_nii = remove_nonsacrum_beyond_canal_height(seg_nii=seg_nii.copy())

if proc_clean_small_cc_artifacts:
seg_nii.set_array_(
clean_cc_artifacts(
seg_nii,
Expand All @@ -73,50 +83,80 @@ def predict_semantic_mask(
),
verbose=verbose,
)
if fill_holes:
seg_nii = seg_nii.fill_holes_(fill_holes_labels, verbose=logger)

seg_nii = semantic_bounding_box_clean(seg_nii=seg_nii.copy())
# Do two iterations of both processing if enabled to make sure
if proc_remove_inferior_beyond_canal:
seg_nii = remove_nonsacrum_beyond_canal_height(seg_nii=seg_nii.copy())

if proc_clean_beyond_largest_bounding_box:
seg_nii = semantic_bounding_box_clean(seg_nii=seg_nii.copy())

if proc_remove_inferior_beyond_canal and proc_clean_beyond_largest_bounding_box:
seg_nii = remove_nonsacrum_beyond_canal_height(seg_nii=seg_nii.copy())
seg_nii = semantic_bounding_box_clean(seg_nii=seg_nii.copy())

if proc_fill_3d_holes:
seg_nii = seg_nii.fill_holes_(fill_holes_labels, verbose=logger)

return seg_nii, unc_nii, softmax_logits, ErrCode.OK


def remove_nonsacrum_beyond_canal_height(seg_nii: NII):
seg_nii.assert_affine(orientation=("P", "I", "R"))
canal_nii = seg_nii.extract_label([Location.Spinal_Canal.value, Location.Spinal_Cord.value])
crop_i = canal_nii.compute_crop(dist=16)[1]
seg_arr = seg_nii.get_seg_array()
sacrum_arr = seg_nii.extract_label(26).get_seg_array()
seg_arr[:, 0 : crop_i.start, :] = 0
seg_arr[:, crop_i.stop :, :] = 0
seg_arr[sacrum_arr == 1] = 26
return seg_nii.set_array_(seg_arr)


def semantic_bounding_box_clean(seg_nii: NII):
ori = seg_nii.orientation
seg_binary = seg_nii.reorient_().extract_label(list(seg_nii.unique())) # whole thing binary
seg_bin_largest_k_cc_nii = seg_binary.get_largest_k_segmentation_connected_components(
k=20, labels=1, connectivity=3, return_original_labels=False
k=None, labels=1, connectivity=3, return_original_labels=False
)
max_k = int(seg_bin_largest_k_cc_nii.max())
if max_k > 3:
logger.print(f"Found {max_k} unique connected components in semantic mask", Log_Type.STRANGE)
# PIR
largest_nii = seg_bin_largest_k_cc_nii.extract_label(1)
# width fixed, and heigh include all connected components within bounding box, then repeat
p_slice, i_slice, r_slice = largest_nii.compute_crop(dist=5)
p_slice, i_slice, r_slice = largest_nii.compute_crop(dist=4)
bboxes = [(p_slice, i_slice, r_slice)]

# PIR -> fixed, extendable, extendable
incorporated = [1]
changed = True
while changed:
changed = False
for k in [l for l in range(2, max_k + 1) if l not in incorporated]:
k_nii = seg_bin_largest_k_cc_nii.extract_label(k)
p, i, r = k_nii.compute_crop(dist=3)
i_slice_compare = slice(
max(i_slice.start - 10, 0), i_slice.stop + 10
) # more margin in inferior direction (allows for gaps in spine)
if overlap_slice(p_slice, p) and overlap_slice(i_slice_compare, i) and overlap_slice(r_slice, r):
# extend bbox
i_slice = slice(min(i_slice.start, i.start), max(i_slice.stop, i.stop))
r_slice = slice(min(r_slice.start, r.start), max(r_slice.stop, r.stop))
incorporated.append(k)
changed = True
p, i, r = k_nii.compute_crop(dist=4)

for bbox in bboxes:
i_slice_compare = slice(
max(bbox[1].start - 4, 0), bbox[1].stop + 4
) # more margin in inferior direction (allows for gaps of size 15 in spine)
if overlap_slice(bbox[0], p) and overlap_slice(i_slice_compare, i) and overlap_slice(bbox[2], r):
# extend bbox
bboxes.append((p, i, r))
incorporated.append(k)
changed = True
break

seg_bin_arr = seg_binary.get_seg_array()
crop = (p_slice, i_slice, r_slice)
seg_bin_clean_arr = np.zeros(seg_bin_arr.shape)
seg_bin_clean_arr[crop] = 1

# everything below biggest k get always removed
largest_k_arr = seg_bin_largest_k_cc_nii.get_seg_array()
seg_bin_clean_arr[largest_k_arr == 0] = 0

seg_arr = seg_nii.get_seg_array()
# logger.print(seg_nii.volumes())
seg_arr[seg_bin_clean_arr != 1] = 0
Expand Down
7 changes: 5 additions & 2 deletions spineps/seg_pipeline.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# from utils.predictor import nnUNetPredictor
import subprocess
from typing import Any

from scipy.ndimage import center_of_mass
from TPTBox import NII, Location, No_Logger, Zooms, v_name2idx
Expand All @@ -8,8 +9,7 @@

from spineps.seg_model import Segmentation_Model

logger = No_Logger()
logger.override_prefix = "SPINEPS"
logger = No_Logger(prefix="SPINEPS")

fill_holes_labels = [
Location.Vertebra_Corpus_border.value,
Expand Down Expand Up @@ -37,6 +37,7 @@ def predict_centroids_from_both(
vert_nii_cleaned: NII,
seg_nii: NII,
models: list[Segmentation_Model],
parameter: dict[str, Any],
input_zms_pir: Zooms | None = None,
):
"""Calculates the centroids of each vertebra corpus by using both semantic and instance mask
Expand Down Expand Up @@ -70,6 +71,8 @@ def predict_centroids_from_both(
ctd.info["models"] = models_repr
ctd.info["revision"] = pipeline_revision()
ctd.info["timestamp"] = format_time_short(get_time())
for pname, pvalue in parameter.items():
ctd.info[pname] = str(pvalue)
return ctd


Expand Down
Loading

0 comments on commit cee3a6f

Please sign in to comment.