Processing resolution #2

Merged
10 commits merged on Feb 8, 2024
spineps/phase_instance.py: 67 changes (29 additions, 38 deletions)
@@ -1,7 +1,7 @@
 # from utils.predictor import nnUNetPredictor
 import numpy as np
 from TPTBox import NII, Location, Log_Type
-from TPTBox.core.np_utils import np_calc_crop_around_centerpoint, np_dice
+from TPTBox.core.np_utils import np_calc_crop_around_centerpoint, np_count_nonzero, np_dice, np_unique
 from tqdm import tqdm

 from spineps.seg_enums import ErrCode, OutputType
@@ -14,7 +14,7 @@ def predict_instance_mask(
     seg_nii: NII,
     model: Segmentation_Model,
     debug_data: dict,
-    pad_size: int = 2,
+    pad_size: int = 0,
     fill_holes: bool = True,
     use_height_estimate: bool = False,
     proc_corpus_clean: bool = True,
@@ -59,7 +59,8 @@ def predict_instance_mask(
 zms = seg_nii_rdy.zoom
 logger.print("zms", zms, verbose=verbose)
 expected_zms = model.calc_recommended_resampling_zoom(seg_nii_rdy.zoom)
-seg_nii_rdy.rescale_(expected_zms, verbose=logger) # in PIR
+if not seg_nii_rdy.assert_affine(zoom=expected_zms, raise_error=False):
+    seg_nii_rdy.rescale_(expected_zms, verbose=logger) # in PIR
 #
 seg_nii_uncropped = seg_nii_rdy.copy()
 logger.print(
@@ -69,9 +70,9 @@ def predict_instance_mask(
 uncropped_vert_mask = np.zeros(seg_nii_uncropped.shape)
 logger.print("Vertebra uncropped_vert_mask empty", uncropped_vert_mask.shape, verbose=verbose)
 #
-crop = seg_nii_rdy.compute_crop_slice(dist=5)
+crop = seg_nii_rdy.compute_crop(dist=5)
 # logger.print("Crop", crop, verbose=verbose)
-seg_nii_rdy.apply_crop_slice_(crop)
+seg_nii_rdy.apply_crop_(crop)
 logger.print(f"Crop down from {uncropped_vert_mask.shape} to {seg_nii_rdy.shape}", verbose=verbose)
 # arr[crop] = X, then set nifty to arr
 logger.print("Vertebra seg_nii_rdy", seg_nii_rdy.zoom, seg_nii_rdy.orientation, seg_nii_rdy.shape, verbose=verbose)
@@ -83,7 +84,9 @@ def predict_instance_mask(
 vert_size_threshold = max(int(vert_size_threshold / (expected_zms[0] * expected_zms[1] * expected_zms[2])), 40)

 seg_labels = seg_nii.unique()
-assert 49 in seg_labels, f"no corpus ({Location.Vertebra_Corpus_border.value}) labels in this segmentation, cannot proceed"
+if 49 not in seg_labels:
+    logger.print(f"no corpus ({Location.Vertebra_Corpus_border.value}) labels in this segmentation, cannot proceed", Log_Type.FAIL)
+    return None, ErrCode.EMPTY

 # get all the 3vert predictions
 vert_predictions, hierarchical_existing_predictions, n_corpus_coms = collect_vertebra_predictions(
@@ -144,17 +147,17 @@ def predict_instance_mask(
 debug_data["inst_uncropped_vert_arr_a"] = whole_vert_nii_uncropped.copy()

 # Resample back to input space
-whole_vert_nii_uncropped.rescale_(zms, verbose=verbose)
-debug_data["inst_uncropped_vert_arr_b_rescale"] = whole_vert_nii_uncropped.copy()
-whole_vert_nii_uncropped.reorient_(orientation, verbose=verbose)
-debug_data["inst_uncropped_vert_arr_c_reorient"] = whole_vert_nii_uncropped.copy()
+# whole_vert_nii_uncropped.rescale_(zms, verbose=verbose)
+# debug_data["inst_uncropped_vert_arr_b_rescale"] = whole_vert_nii_uncropped.copy()
+# whole_vert_nii_uncropped.reorient_(orientation, verbose=verbose)
+# debug_data["inst_uncropped_vert_arr_c_reorient"] = whole_vert_nii_uncropped.copy()
 if pad_size > 0:
     # logger.print(whole_vert_nii_uncropped.shape)
     arr = whole_vert_nii_uncropped.get_array()
     arr = arr[pad_size:-pad_size, pad_size:-pad_size, pad_size:-pad_size]
     whole_vert_nii_uncropped.set_array_(arr)
     # logger.print(whole_vert_nii_uncropped.shape)
-whole_vert_nii_uncropped.pad_to(shp, inplace=True)
+# whole_vert_nii_uncropped.pad_to(shp, inplace=True)

 return whole_vert_nii_uncropped, ErrCode.OK
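For context, a small self-contained example of the symmetric pad-and-crop bookkeeping that the `if pad_size > 0` branch above undoes; with the new default of pad_size = 0 the branch is skipped entirely. Shapes and names below are made up for illustration, using plain NumPy:

    import numpy as np

    pad_size = 2
    arr = np.zeros((10, 12, 14))
    padded = np.pad(arr, pad_size)  # pads every axis on both sides -> (14, 16, 18)
    restored = padded[pad_size:-pad_size, pad_size:-pad_size, pad_size:-pad_size]
    assert restored.shape == arr.shape  # cropping the same amount restores (10, 12, 14)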
@@ -197,6 +200,10 @@ def collect_vertebra_predictions(
 corpus_coms.reverse() # from bottom to top
 n_corpus_coms = len(corpus_coms)

+if n_corpus_coms < 3:
+    logger.print(f"Too few vertebra semantically segmented ({n_corpus_coms})", Log_Type.FAIL)
+    return None, [], 0
+
 shp = (
     # n_corpus_coms,
     # 3
@@ -248,26 +255,7 @@ def collect_vertebra_predictions(
 break
 seg_at_com = seg_arr_c[int(com[0])][int(com[1])][int(com[2])] != 0

-if use_height_estimate:
-    com_above = 0 if com_idx < 2 else com_idx - 2
-    com_above = corpus_coms[com_above] # (com, bbox)
-    com_below = n_corpus_coms - 1 if com_idx > n_corpus_coms - 3 else com_idx + 2
-    com_below = corpus_coms[com_below]
-    com_y = com[1]
-    height = min(
-        int(max(abs(com_above[1] - com_y), abs(com_below[1] - com_y)) * 2.5), # type:ignore
-        cutout_size[1],
-    )
-    if len(corpus_coms) <= 3:
-        height = min(int(abs(com_above[1] - com_below[1]) * 2.5), cutout_size[1])
-    height = height + 1 if height % 2 != 0 else height
-    cutout_size2 = (
-        cutout_size[0],
-        height,
-        cutout_size[2],
-    )
-else:
-    cutout_size2 = cutout_size
+cutout_size2 = cutout_size

 # Calc cutout
 arr_cut, cutout_coords, paddings = np_calc_crop_around_centerpoint(com, seg_arr_c, cutout_size2)
@@ -278,15 +266,18 @@ def collect_vertebra_predictions(
 cut_nii,
 resample_to_recommended=False,
 pad_size=0,
-# resample_output_to_input_space=False,
+resample_output_to_input_space=False,
 verbose=False,
 )
-vert_cut_nii = results[OutputType.seg_modelres].reorient_()
+vert_cut_nii = results[OutputType.seg].reorient_()
 # print("vert_cut_nii", vert_cut_nii.shape)
 # logger.print(f"Done {com_idx}")
 debug_data[f"inst_cutout_vert_nii_{com_idx}_pred"] = vert_cut_nii.copy()
 vert_cut_nii = post_process_single_3vert_prediction(
-    vert_cut_nii, None, fill_holes=fill_holes, largest_cc=proc_largest_cc # type:ignore
+    vert_cut_nii,
+    None,
+    fill_holes=fill_holes,
+    largest_cc=proc_largest_cc, # type:ignore
 )
 vert_labels = vert_cut_nii.unique() # 1,2,3
 debug_data[f"inst_cutout_vert_nii_{com_idx}_proc"] = vert_cut_nii.copy()
@@ -527,13 +518,13 @@ def merge_coupled_predictions(
 combine[combine < m] = 0
 combine[combine != 0] = idx

-count_new = np.count_nonzero(combine)
+count_new = np_count_nonzero(combine)
 if count_new == 0:
     logger.print("ZERO instance mask failure on vertebra instance creation", Log_Type.FAIL)
     return seg_nii, debug_data, ErrCode.EMPTY
 fixed_n = combine.copy()
 fixed_n[whole_vert_arr != 0] = 0
-count_cut = np.count_nonzero(fixed_n)
+count_cut = np_count_nonzero(fixed_n)
 relative_overlap = (count_new - count_cut) / count_new
 if relative_overlap > 0.6:
     logger.print(k, f" was skipped because it overlaps {round(relative_overlap, 4)} with established verts", verbose=verbose)
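For context, a small worked example of the overlap fraction computed above: count_new counts all voxels of the candidate instance, count_cut counts those not already claimed by established vertebrae, and the candidate is skipped when more than 60% of it is already covered. The arrays below are made up, and plain np.count_nonzero stands in for TPTBox's np_count_nonzero:

    import numpy as np

    candidate = np.zeros((5, 5), dtype=np.uint8)
    candidate[1:4, 1:4] = 7    # 9 voxels belonging to the new instance
    established = np.zeros((5, 5), dtype=np.uint8)
    established[0:3, 0:3] = 3  # voxels already assigned to another vertebra

    count_new = np.count_nonzero(candidate)      # 9
    remaining = candidate.copy()
    remaining[established != 0] = 0              # drop voxels already claimed
    count_cut = np.count_nonzero(remaining)      # 5
    relative_overlap = (count_new - count_cut) / count_new  # 4/9 ~ 0.44, kept (<= 0.6)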
@@ … @@

 debug_data["inst_crop_vert_arr_a_raw"] = seg_nii.set_array(whole_vert_arr)

-if len(np.unique(whole_vert_arr)) == 1:
+if len(np_unique(whole_vert_arr)) == 1:
     logger.print("Vert mask empty, will skip", Log_Type.FAIL)
     return whole_vert_nii.set_array_(whole_vert_arr, verbose=False), debug_data, ErrCode.EMPTY

 # Cleanup step
 if proc_cleanvert:
     whole_vert_arr = clean_cc_artifacts(
         whole_vert_arr,
-        labels=np.unique(whole_vert_arr)[1:], # type:ignore
+        labels=np_unique(whole_vert_arr)[1:], # type:ignore
         cc_size_threshold=vert_size_threshold,
         only_delete=True,
         logger=logger,