diff --git a/pyproject.toml b/pyproject.toml
index 650f459..7899ee4 100755
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -29,7 +29,7 @@
 torchmetrics = "^1.1.2"
 tqdm = "^4.66.1"
 einops= "^0.6.1"
 nnunetv2 = "2.2"
-tptbox = "^0.0.9"
+tptbox = "^0.1.0"
 antspyx = "*"
 rich = "^13.6.0"
diff --git a/spineps/entrypoint.py b/spineps/entrypoint.py
index d5da75f..21cec79 100755
--- a/spineps/entrypoint.py
+++ b/spineps/entrypoint.py
@@ -210,8 +210,8 @@ def run_sample(opt: Namespace):
         "override_postpair": opt.override_postpair,
         "override_ctd": opt.override_ctd,
         #
-        "do_crop_semantic": not opt.nocrop,
-        "proc_n4correction": not opt.non4,
+        "proc_sem_crop_input": not opt.nocrop,
+        "proc_sem_n4_bias_correction": not opt.non4,
         "ignore_compatibility_issues": opt.ignore_inference_compatibility,
         "verbose": opt.verbose,
     }
@@ -284,8 +284,8 @@ def run_dataset(opt: Namespace):
         "ignore_inference_compatibility": opt.ignore_inference_compatibility,
         "ignore_bids_filter": opt.ignore_bids_filter,
         #
-        "do_crop_semantic": not opt.nocrop,
-        "proc_n4correction": not opt.non4,
+        "proc_sem_crop_input": not opt.nocrop,
+        "proc_sem_n4_bias_correction": not opt.non4,
         "snapshot_copy_folder": opt.save_snaps_folder,
         "verbose": opt.verbose,
     }
diff --git a/spineps/phase_instance.py b/spineps/phase_instance.py
index 15589bb..3c41e30 100755
--- a/spineps/phase_instance.py
+++ b/spineps/phase_instance.py
@@ -1,7 +1,19 @@
 # from utils.predictor import nnUNetPredictor
 import numpy as np
 from TPTBox import NII, Location, Log_Type
-from TPTBox.core.np_utils import np_calc_crop_around_centerpoint, np_count_nonzero, np_dice, np_unique
+from TPTBox.core.np_utils import (
+    np_calc_crop_around_centerpoint,
+    np_center_of_mass,
+    np_connected_components,
+    np_count_nonzero,
+    np_dice,
+    np_dilate_msk,
+    np_erode_msk,
+    np_extract_label,
+    np_get_largest_k_connected_components,
+    np_unique,
+    np_volume,
+)
 from tqdm import tqdm

 from spineps.seg_enums import ErrCode, OutputType
@@ -15,10 +27,11 @@ def predict_instance_mask(
     model: Segmentation_Model,
     debug_data: dict,
     pad_size: int = 0,
-    fill_holes: bool = True,
+    proc_inst_fill_3d_holes: bool = True,
+    proc_detect_and_solve_merged_corpi: bool = True,
     proc_corpus_clean: bool = True,
-    proc_cleanvert: bool = True,
-    proc_largest_cc: int = 0,
+    proc_inst_clean_small_cc_artifacts: bool = True,
+    proc_inst_largest_k_cc: int = 0,
     verbose: bool = False,
 ) -> tuple[NII | None, ErrCode]:
     """Based on subregion segmentation, feeds individual arcus coms to a network to get the vertebra body segmentations
@@ -94,8 +107,9 @@
         corpus_size_cleaning=corpus_size_cleaning if proc_corpus_clean else 0,
         cutout_size=cutout_size,
         debug_data=debug_data,
-        proc_largest_cc=proc_largest_cc,
-        fill_holes=False,
+        process_detect_and_solve_merged_corpi=proc_detect_and_solve_merged_corpi,
+        proc_inst_largest_k_cc=proc_inst_largest_k_cc,
+        proc_inst_fill_holes=False,
         verbose=verbose,
     )
     if vert_predictions is None:
@@ -109,7 +123,7 @@
         hierarchical_existing_predictions,
         vert_size_threshold,
         debug_data=debug_data,
-        proc_cleanvert=proc_cleanvert,
+        proc_inst_clean_small_cc_artifacts=proc_inst_clean_small_cc_artifacts,
     )
     del vert_predictions, hierarchical_existing_predictions
     if errcode != ErrCode.OK:
@@ -122,7 +136,7 @@

     uniq_labels = whole_vert_nii.unique()

-    if fill_holes:
+    if proc_inst_fill_3d_holes:
         whole_vert_nii.fill_holes_(verbose=logger)
     debug_data["inst_cropped_vert_arr_c_proc"] = whole_vert_nii.copy()
     n_vert_bodies = len(uniq_labels)
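The renamed `proc_inst_fill_3d_holes` flag gates TPTBox's `NII.fill_holes_` on the merged instance mask. For intuition only, a scipy-based stand-in for per-label 3D hole filling (assumed behavior, not the library's actual implementation):

```python
import numpy as np
from scipy.ndimage import binary_fill_holes

def fill_3d_holes(seg: np.ndarray) -> np.ndarray:
    """Fill enclosed background cavities inside every labeled instance."""
    out = seg.copy()
    for label in np.unique(seg):
        if label == 0:  # skip background
            continue
        filled = binary_fill_holes(seg == label)
        out[filled & (out == 0)] = label  # only write into enclosed background
    return out
```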
@@ -155,18 +169,16 @@ def predict_instance_mask(
     return whole_vert_nii_uncropped, ErrCode.OK


-def collect_vertebra_predictions(
+def get_corpus_coms(
     seg_nii: NII,
-    model: Segmentation_Model,
     corpus_size_cleaning: int,
-    cutout_size,
-    debug_data: dict,
-    proc_largest_cc: int = 0,
-    fill_holes: bool = False,
+    process_detect_and_solve_merged_corpi: bool = True,
     verbose: bool = False,
-) -> tuple[np.ndarray | None, list[str], int]:
+) -> list:
+    seg_nii.assert_affine(orientation=["P", "I", "R"])
+    #
     # Extract Corpus region and try to find all coms naively (some skips shouldnt matter)
-    corpus_nii = seg_nii.extract_label(49)
+    corpus_nii = seg_nii.extract_label(Location.Vertebra_Corpus_border.value)
     corpus_nii.erode_msk_(mm=2, connectivity=2, verbose=False)
     if 1 in corpus_nii.unique() and corpus_size_cleaning > 0:
         corpus_nii.set_array_(
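The `set_array_` call this hunk ends on feeds the eroded corpus mask through a size-based component cleanup driven by `corpus_size_cleaning`. A minimal scipy sketch of that idea, assuming a binary mask and a voxel-count threshold (the project itself uses its own TPTBox-based helpers):

```python
import numpy as np
from scipy.ndimage import label

def drop_small_components(mask: np.ndarray, min_voxels: int) -> np.ndarray:
    # Label connected components, then zero out any component smaller
    # than min_voxels (the role corpus_size_cleaning plays above).
    labeled, _ = label(mask > 0)
    sizes = np.bincount(labeled.ravel())
    sizes[0] = 0  # background never counts
    keep = np.isin(labeled, np.flatnonzero(sizes >= min_voxels))
    return np.where(keep, mask, 0)
```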
logger.print("Merged corpi, try to fix it", verbose=verbose) + neighbor_verts = { + stats_by_height_keys[idx + i]: stats_by_height[stats_by_height_keys[idx + i]] + for i in [-5, -4, -3, -2, -1, 1, 2, 3, 4, 5] + if (idx + i) in stats_by_height_keys and stats_by_height_keys[idx + i] < 99 + } + # stats_by_height_keys[idx + i] + # for i in [-5, -4, -3, -2, -1, 1, 2, 3, 4, 5] + # if (idx + i) in stats_by_height_keys and stats_by_height_keys[idx + i] < 99 + # ] # (+-3) + logger.print("neighbor_vert_labels", neighbor_verts, verbose=verbose) + if len(neighbor_verts) == 0: + logger.print("Got no neighbor vert labels to fix", Log_Type.FAIL) + continue + neighbor_volumes = [k[2] for nv, k in neighbor_verts.items()] + logger.print("neighbor_volumes", neighbor_volumes, verbose=verbose) + n_neighbors_without_target = len(neighbor_volumes) - 1 + if n_neighbors_without_target > 2: + argmax = np.argmax(neighbor_volumes) + n_avg_volume = sum([neighbor_volumes[i] for i in range(len(neighbor_volumes)) if i != argmax]) / n_neighbors_without_target + + diff_volume = neighbor_volumes[argmax] / n_avg_volume + if diff_volume > 1.5: + logger.print( + f"Volume difference detected in label {vl}, diff = {diff_volume}, volume = {neighbor_volumes[argmax]}, neighbor_avg = {n_avg_volume}", + Log_Type.STRANGE, + ) + + target_vert_id = list(neighbor_verts.keys())[argmax] + segvert = np_extract_label(corpus_cc, target_vert_id, inplace=False) + try: + logger.print("get_separating_components to split vertebra", verbose=verbose) + (spart, tpart, spart_dil, tpart_dil, stpart) = get_separating_components( + segvert, + connectivity=3, + ) + + logger.print("Splitting by plane") + plane_split_nii = get_plane_split(segvert, corpus_nii, spart, tpart, spart_dil, tpart_dil) + split_vert = split_by_plane(segvert, plane_split_nii) + corpus_cc[split_vert == 2] = corpus_cc.max() + 1 + + seg_nii.set_array(corpus_cc).save( + "/DATA/NAS/ongoing_projects/hendrik/mri_usage/nako_fixmerge_test/derivatives_seg/mincutmaxflow_array.nii.gz" + ) + except Exception as e: + logger.print(f"Separating Corpi failed with exception {e}") + + corpus_coms = list(np_center_of_mass(corpus_cc).values()) + corpus_coms.sort(key=lambda a: a[1]) corpus_coms.reverse() # from bottom to top + logger.print(f"Found {len(corpus_coms)} final Corpus ccs", verbose=verbose) + return corpus_coms + + +def get_separating_components( + segvert: np.ndarray, + max_iter: int = 10, + connectivity: int = 3, +): + check_connectivtiy = 3 + vol = segvert.copy() + vol_old = vol.copy() + iterations = 0 + while True: + vol_erode = np_erode_msk(vol, mm=1, connectivity=connectivity) + subreg_cc, subreg_cc_n = np_connected_components(vol_erode, connectivity=check_connectivtiy) + if 1 in subreg_cc_n and subreg_cc_n[1] > 1: + vol = subreg_cc[1] + break + elif 1 not in subreg_cc_n: + vol_dilated = np_dilate_msk(vol, mm=1, connectivity=connectivity, mask=vol.copy()) + # use iteration before to get other CC + vol[vol_old != 0] = 2 # all possible voxels are 2 + vol[vol_dilated == 1] = 1 + + if 2 not in np_volume(vol): + raise Exception( # noqa: TRY002 + f"cannot split volume into two parts after {iterations} iterations, all values are 0 after erosion." 
+
+
+def get_separating_components(
+    segvert: np.ndarray,
+    max_iter: int = 10,
+    connectivity: int = 3,
+):
+    check_connectivity = 3
+    vol = segvert.copy()
+    vol_old = vol.copy()
+    iterations = 0
+    while True:
+        vol_erode = np_erode_msk(vol, mm=1, connectivity=connectivity)
+        subreg_cc, subreg_cc_n = np_connected_components(vol_erode, connectivity=check_connectivity)
+        if 1 in subreg_cc_n and subreg_cc_n[1] > 1:
+            vol = subreg_cc[1]
+            break
+        elif 1 not in subreg_cc_n:
+            vol_dilated = np_dilate_msk(vol, mm=1, connectivity=connectivity, mask=vol.copy())
+            # use iteration before to get other CC
+            vol[vol_old != 0] = 2  # all possible voxels are 2
+            vol[vol_dilated == 1] = 1
+
+            if 2 not in np_volume(vol):
+                raise Exception(  # noqa: TRY002
+                    f"cannot split volume into two parts after {iterations} iterations, all values are 0 after erosion."
+                )
+            volume = np_volume(vol)
+            dil_iter = 0
+            while volume[1] / (volume[1] + volume[2]) < 0.5:
+                vol_dilated = np_dilate_msk(vol_dilated, mm=1, connectivity=connectivity, mask=vol.copy())
+                # inst_nii.set_array(vol_dilated).save(files_out + f"subreg_cc_vol_dilated{dil_iter}.nii.gz")
+                vol[vol_dilated == 1] = 1
+                # inst_nii.set_array(vol).save(files_out + f"subreg_cc_dilation{dil_iter}.nii.gz")
+                volume = np_volume(vol)
+                if 1 not in volume or 2 not in volume:
+                    raise Exception("Could not divide into two instances")  # noqa: TRY002
+                dil_iter += 1
+
+            vol_1 = np_get_largest_k_connected_components(vol == 1, k=1, connectivity=check_connectivity).astype(np.uint8)
+            vol_2 = np_get_largest_k_connected_components(vol == 2, k=1, connectivity=check_connectivity).astype(np.uint8)
+            vol_2 *= 2
+            vol[vol == 1] = vol_1[vol == 1]
+            vol[vol == 2] = vol_2[vol == 2]
+            break
+        vol_old = vol
+        vol = vol_erode
+        iterations += 1
+        if iterations > max_iter:
+            raise Exception(f"Could not divide into two instances after max iterations {max_iter}")  # noqa: TRY002
+    if len(np_volume(vol)) != 2:
+        logger.print("Get largest two components")
+        subreg_cc_2k = np_get_largest_k_connected_components(vol, k=2, connectivity=check_connectivity, return_original_labels=False)
+        spart = subreg_cc_2k == 1
+        tpart = subreg_cc_2k == 2
+    else:
+        spart = vol == 1
+        tpart = vol == 2
+
+    if spart.sum() == 0 or tpart.sum() == 0:
+        raise Exception("S or T is empty")  # noqa: TRY002
+
+    spart_dil = np_dilate_msk(spart, mm=1, connectivity=connectivity)
+    tpart_dil = np_dilate_msk(tpart, mm=1, connectivity=connectivity)
+    stpart = (spart_dil + (tpart_dil * 2)).astype(np.uint8)
+    while 3 not in np_volume(stpart):
+        spart_dil = np_dilate_msk(spart_dil, mm=1, connectivity=connectivity)
+        stpart = (spart_dil + (tpart_dil * 2)).astype(np.uint8)
+        if 3 in np_volume(stpart):
+            break
+        tpart_dil = np_dilate_msk(tpart_dil, mm=1, connectivity=connectivity)
+        stpart = (spart_dil + (tpart_dil * 2)).astype(np.uint8)
+    #
+    stpart = spart_dil + (tpart_dil * 2)
+    return spart, tpart, spart_dil, tpart_dil, stpart
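`get_separating_components` peels the merged mask with repeated erosions until it decomposes into at least two connected components, dilating back out of the degenerate case where erosion deletes everything. The core loop reduced to a scipy sketch (simplified assumptions: binary input, no millimeter-aware erosion, no rebalancing of part sizes):

```python
import numpy as np
from scipy.ndimage import binary_erosion, label

def erode_until_split(mask: np.ndarray, max_iter: int = 10):
    # Peel the mask one erosion at a time until it falls apart into two
    # components; these become the S/T seeds used by the plane split.
    vol = mask.astype(bool)
    for _ in range(max_iter):
        eroded = binary_erosion(vol)
        labeled, n = label(eroded)
        if n >= 2:
            return labeled == 1, labeled == 2
        if n == 0:
            raise RuntimeError("mask vanished before it split")
        vol = eroded
    raise RuntimeError(f"no split after {max_iter} erosions")
```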
+
+
+def get_plane_split(
+    segvert: np.ndarray,
+    compare_nii: NII,
+    spart: np.ndarray,
+    tpart: np.ndarray,
+    spart_dil: np.ndarray,
+    tpart_dil: np.ndarray,
+):
+    s_dilint = spart_dil.astype(np.uint8)
+    t_dilint = tpart_dil.astype(np.uint8)
+    #
+    collision_arr = s_dilint + t_dilint
+    if 2 not in collision_arr:
+        logger.print("S and T dilated do not touch each other")
+        return compare_nii.set_array(np.zeros_like(segvert))
+    collision_arr[collision_arr != 2] = 0
+    collision_arr[collision_arr == 2] = 1
+    collision_point = np_center_of_mass(collision_arr)[1]
+    # TODO instead of COM, calc COM of S and T, make vector and use merge point along that vector as collision point? should be more accurate
+    #
+    normal_vector = np_center_of_mass(spart.astype(np.uint8))[1] - np_center_of_mass(tpart.astype(np.uint8))[1]
+    normalized_normal = normal_vector / np.linalg.norm(normal_vector)
+    #
+    axis = np.argmax(np.abs(normalized_normal))
+    dims = [0, 1, 2]
+    dims.remove(axis)
+    dim1, dim2 = dims
+    #
+    shift_total = -collision_point.dot(normal_vector)
+    xx, yy = np.meshgrid(range(collision_arr.shape[dim1]), range(collision_arr.shape[dim2]))
+    zz = (-normal_vector[dim1] * xx - normal_vector[dim2] * yy - shift_total) * 1.0 / normal_vector[axis]
+    z_max = collision_arr.shape[axis] - 1
+    zz[zz < 0] = 0
+    zz[zz > z_max] = z_max
+    # make coords into an array again
+    plane_coords = np.zeros([xx.shape[0], xx.shape[1], 3], dtype=int)
+    plane_coords[:, :, axis] = zz
+    plane_coords[:, :, dim1] = xx
+    plane_coords[:, :, dim2] = yy
+
+    plane = segvert * 0
+    plane[plane_coords[:, :, 0], plane_coords[:, :, 1], plane_coords[:, :, 2]] = 1
+    plane_nii = compare_nii.set_array(plane)
+
+    plane_filled_nii = plane_nii.copy()
+    orientation = plane_filled_nii.orientation
+    plane_filled_nii.reorient_(("S", "A", "R"))
+    plane_filled_arr = plane_filled_nii.get_array()
+    x_slice = np.ones_like(plane_filled_arr[0]) * np.max(plane_filled_arr) + 1
+    # plane_filled[:, 0, :] = 2
+    # plane_filled[:, -1, :] = 1
+    for i in range(plane_filled_arr.shape[0]):
+        curr_slice = plane_filled_arr[i]
+        cond = np.where(curr_slice != 0)
+        x_slice[cond] = np.minimum(curr_slice[cond], x_slice[cond])
+        plane_filled_arr[i] = x_slice
+
+    plane_filled_nii.set_array_(plane_filled_arr).reorient_(orientation)
+    return plane_filled_nii
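The rasterization above solves the plane equation n·x + d = 0, with the normal n taken from the S/T centers of mass and d fixed by the collision point, for the coordinate along the dominant axis. A short sanity check of that identity with made-up numbers:

```python
import numpy as np

n = np.array([0.2, 1.0, 0.1])    # normal: COM(S) - COM(T)
p = np.array([12.0, 30.0, 8.0])  # collision point lying on the plane
d = -p.dot(n)                    # the shift_total of the code above
x, z = 5.0, 20.0
y = (-n[0] * x - n[2] * z - d) / n[1]  # solve along axis 1, where |n| is largest
assert abs(n.dot([x, y, z]) + d) < 1e-9  # the solved point lies on the plane
```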
+
+
+def split_by_plane(
+    segvert: np.ndarray,
+    plane_filled_nii: NII,
+):
+    plane_filled_arr = plane_filled_nii.get_array()
+    # 1 above, 2 below
+    plane_filled_arr[segvert == 0] = 0
+    segvert = plane_filled_arr
+    # segvert[plane_filled_arr == 1] = 1
+    # segvert[plane_filled_arr == 2] = 2
+    return segvert
+
+
+def collect_vertebra_predictions(
+    seg_nii: NII,
+    model: Segmentation_Model,
+    corpus_size_cleaning: int,
+    cutout_size,
+    debug_data: dict,
+    proc_inst_largest_k_cc: int = 0,
+    process_detect_and_solve_merged_corpi: bool = True,
+    proc_inst_fill_holes: bool = False,
+    verbose: bool = False,
+) -> tuple[np.ndarray | None, list[str], int]:
+    corpus_coms = get_corpus_coms(
+        seg_nii,
+        corpus_size_cleaning=corpus_size_cleaning,
+        process_detect_and_solve_merged_corpi=process_detect_and_solve_merged_corpi,
+        verbose=verbose,
+    )
+    if corpus_coms is None:
+        return None, [], 0
     n_corpus_coms = len(corpus_coms)

     if n_corpus_coms < 3:
@@ -204,7 +499,7 @@ def collect_vertebra_predictions(
         seg_nii.shape[2],
     )
     hierarchical_existing_predictions = []
-    hierarchical_predictions = np.zeros((n_corpus_coms, 3, *shp), dtype=corpus_nii.dtype)
+    hierarchical_predictions = np.zeros((n_corpus_coms, 3, *shp), dtype=seg_nii.dtype)
     # print("hierarchical_predictions", hierarchical_predictions.shape)
     vert_predict_template = np.zeros(shp, dtype=np.uint16)
     # print("vert_predict_template", vert_predict_template.shape)
@@ -268,8 +563,8 @@ def collect_vertebra_predictions(
             vert_cut_nii = post_process_single_3vert_prediction(
                 vert_cut_nii,
                 None,
-                fill_holes=fill_holes,
-                largest_cc=proc_largest_cc,  # type:ignore
+                fill_holes=proc_inst_fill_holes,
+                largest_cc=proc_inst_largest_k_cc,  # type:ignore
             )
             vert_labels = vert_cut_nii.unique()  # 1,2,3
             debug_data[f"inst_cutout_vert_nii_{com_idx}_proc"] = vert_cut_nii.copy()
@@ -328,7 +623,7 @@ def from_vert3_predictions_make_vert_mask(
     vert_size_threshold: int,
     debug_data: dict,
     #
-    proc_cleanvert: bool = True,
+    proc_inst_clean_small_cc_artifacts: bool = True,
     verbose: bool = False,
 ) -> tuple[NII, dict, ErrCode]:
     # instance approach: each 1/2/3 pred finds it most agreeing partner in surrounding predictions (com idx -2 to +2 all three pred)
@@ -348,7 +643,7 @@ def from_vert3_predictions_make_vert_mask(
         coupled_predictions=coupled_predictions,
         hierarchical_predictions=hierarchical_predictions,
         debug_data=debug_data,
-        proc_cleanvert=proc_cleanvert,
+        proc_clean_small_cc_artifacts=proc_inst_clean_small_cc_artifacts,
         vert_size_threshold=vert_size_threshold,
         verbose=verbose,
     )
@@ -488,7 +783,7 @@ def merge_coupled_predictions(
     coupled_predictions,
     hierarchical_predictions: np.ndarray,
     debug_data: dict,
-    proc_cleanvert: bool = True,
+    proc_clean_small_cc_artifacts: bool = True,
     vert_size_threshold: int = 0,
     verbose: bool = False,
 ) -> tuple[NII, dict, ErrCode]:
@@ -531,7 +826,7 @@ def merge_coupled_predictions(
         return whole_vert_nii.set_array_(whole_vert_arr, verbose=False), debug_data, ErrCode.EMPTY

     # Cleanup step
-    if proc_cleanvert:
+    if proc_clean_small_cc_artifacts:
         whole_vert_arr = clean_cc_artifacts(
             whole_vert_arr,
             labels=np_unique(whole_vert_arr)[1:],  # type:ignore
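For reference, the agreement measure behind the coupling step (`np_dice` in the imports above) on binary masks; this mirrors what the comment in `from_vert3_predictions_make_vert_mask` describes, each 1/2/3 prediction pairing up with the neighboring prediction it overlaps best:

```python
import numpy as np

def dice(a: np.ndarray, b: np.ndarray) -> float:
    # 2|A∩B| / (|A| + |B|): 1.0 for identical masks, 0.0 for disjoint ones.
    a, b = a.astype(bool), b.astype(bool)
    denom = np.count_nonzero(a) + np.count_nonzero(b)
    return 2.0 * np.count_nonzero(a & b) / denom if denom else 0.0
```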
diff --git a/spineps/phase_post.py b/spineps/phase_post.py
index 3b0e3a9..94f3a43 100644
--- a/spineps/phase_post.py
+++ b/spineps/phase_post.py
@@ -8,6 +8,7 @@
     np_bbox_binary,
     np_center_of_mass,
     np_connected_components,
+    np_contacts,
     np_count_nonzero,
     np_dilate_msk,
     np_extract_label,
@@ -17,6 +18,7 @@
 )

 from spineps.seg_pipeline import logger, vertebra_subreg_labels
+from spineps.utils.proc_functions import fix_wrong_posterior_instance_label


 def phase_postprocess_combined(
@@ -25,8 +27,11 @@
     debug_data: dict | None,
     labeling_offset: int = 0,
     proc_assign_missing_cc: bool = True,
+    proc_clean_inst_by_sem: bool = True,
     n_vert_bodies: int | None = None,
-    process_vertebra_inconsistency: bool = True,
+    process_merge_vertebra: bool = True,
+    proc_vertebra_inconsistency: bool = True,
+    proc_assign_posterior_instance_label: bool = True,
     verbose: bool = False,
 ) -> tuple[NII, NII]:
     logger.print("Post process", Log_Type.STAGE)
@@ -40,7 +45,9 @@
     if debug_data is None:
         debug_data = {}
     #
-    vert_nii.apply_mask(seg_nii, inplace=True)
+
+    if proc_clean_inst_by_sem:
+        vert_nii.apply_mask(seg_nii, inplace=True)
     crop_slices = seg_nii.compute_crop(dist=2)
     vert_uncropped_arr = np.zeros(vert_nii.shape, dtype=seg_nii.dtype)
     seg_uncropped_arr = np.zeros(vert_nii.shape, dtype=seg_nii.dtype)
@@ -59,17 +66,24 @@
             verbose=verbose,
         )

-        if process_vertebra_inconsistency:
+        if process_merge_vertebra and Location.Vertebra_Disc.value in seg_nii_cleaned.unique():
+            detect_and_solve_merged_vertebra(seg_nii_cleaned, whole_vert_nii_cleaned)
+
+        if proc_assign_posterior_instance_label:
+            whole_vert_nii_cleaned = fix_wrong_posterior_instance_label(seg_nii_cleaned, seg_inst=whole_vert_nii_cleaned, logger=logger)
+
+        if proc_vertebra_inconsistency:
             # Assigns superior/inferior based on instance label overlap
             assign_vertebra_inconsistency(seg_nii_cleaned, whole_vert_nii_cleaned)

         # Label vertebra top -> down
-        whole_vert_nii_cleaned, vert_labels = label_instance_top_to_bottom(whole_vert_nii_cleaned)
-        if labeling_offset != 0:
-            whole_vert_nii_cleaned.map_labels_({i: i + 1 for i in vert_labels if i != 0}, verbose=verbose)
+        whole_vert_nii_cleaned, vert_labels = label_instance_top_to_bottom(whole_vert_nii_cleaned, labeling_offset=labeling_offset)
+        # if labeling_offset != 0:
+        #     whole_vert_nii_cleaned.map_labels_({i: i + 1 for i in vert_labels if i != 0}, verbose=verbose)
         logger.print(f"Labeled {len(vert_labels)} vertebra instances from top to bottom")

         vert_arr_cleaned = add_ivd_ep_vert_label(whole_vert_nii_cleaned, seg_nii_cleaned)
+        #
+        # vert_arr_cleaned[seg_nii_cleaned.get_seg_array() == v_name2idx["S1"]] = v_name2idx["S1"]

     ###############
     # Uncrop
@@ -305,20 +319,14 @@ def find_nearest_lower(seq, x):
     return max(values_lower)


-def label_instance_top_to_bottom(vert_nii: NII):
+def label_instance_top_to_bottom(vert_nii: NII, labeling_offset: int = 0):
     ori = vert_nii.orientation
     vert_nii.reorient_()
     vert_arr = vert_nii.get_seg_array()
     com_i = np_center_of_mass(vert_arr)
-    # Old, more precise version (but takes longer)
-    # comb = {}
-    # for i in present_labels:
-    #     arr_i = vert_arr.copy()
-    #     arr_i[arr_i != i] = 0
-    #     comb[i] = center_of_mass(arr_i)
     comb_l = list(zip(com_i.keys(), com_i.values(), strict=True))
     comb_l.sort(key=lambda a: a[1][1])  # PIR
-    com_map = {comb_l[idx][0]: idx + 1 for idx in range(len(comb_l))}
+    com_map = {comb_l[idx][0]: idx + 1 + labeling_offset for idx in range(len(comb_l))}

     vert_nii.map_labels_(com_map, verbose=False)
     vert_nii.reorient_(ori)
@@ -374,3 +382,45 @@ def assign_vertebra_inconsistency(seg_nii: NII, vert_nii: NII):
         )

     vert_nii.set_array_(vert_arr)
+
+
+def detect_and_solve_merged_vertebra(seg_nii: NII, vert_nii: NII):
+    seg_sem = seg_nii.map_labels({Location.Endplate.value: Location.Vertebra_Disc.value}, verbose=False)
+    # get all ivd CCs from seg_sem
+
+    stats = {}
+    # Map IVDS
+    subreg_cc, subreg_cc_n = seg_sem.get_segmentation_connected_components(labels=Location.Vertebra_Disc.value)
+    subreg_cc = subreg_cc[Location.Vertebra_Disc.value] + 100
+    subreg_cc_n = subreg_cc_n[Location.Vertebra_Disc.value]
+
+    coms = np_center_of_mass(subreg_cc)
+    volumes = np_volume(subreg_cc)
+    stats = {i: (g[1], True, volumes[i]) for i, g in coms.items()}
+
+    vert_coms = vert_nii.center_of_masses()
+    vert_volumes = vert_nii.volumes()
+
+    for i, g in vert_coms.items():
+        stats[i] = (g[1], False, vert_volumes[i])
+
+    stats_by_height = dict(sorted(stats.items(), key=lambda x: x[1][0]))
+    stats_by_height_keys = list(stats_by_height.keys())
+
+    # detect C2 split into two components
+    first_key, second_key = stats_by_height_keys[0], stats_by_height_keys[1]
+    first_stats, second_stats = stats_by_height[first_key], stats_by_height[second_key]
+    if first_stats[1] is False and second_stats[1] is False:  # noqa: SIM102
+        # both vertebrae
+        if first_stats[2] < 0.5 * second_stats[2]:
+            # first is significantly smaller than second and they are close in height
+            # how many pixels are touching
+            vert_firsttwo_arr = vert_nii.extract_label(first_key).get_seg_array()
+            vert_firsttwo_arr2 = vert_nii.extract_label(second_key).get_seg_array()
+            vert_firsttwo_arr += vert_firsttwo_arr2 + 1
+            contacts = np_contacts(vert_firsttwo_arr, connectivity=3)
+            if contacts[(1, 2)] > 20:
+                logger.print("Found first two instances weird, will merge", Log_Type.STRANGE)
+                vert_nii.map_labels_({first_key: second_key}, verbose=False)
+
+    return seg_nii, vert_nii
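`detect_and_solve_merged_vertebra` merges the two lowest instances (a C2 split into two components) when the volume ratio and the contact count between them cross thresholds. An illustrative dilation-based contact test, assuming this approximates what `np_contacts` tallies:

```python
import numpy as np
from scipy.ndimage import binary_dilation

def touching_voxels(a: np.ndarray, b: np.ndarray) -> int:
    # Count voxels of 3D mask a adjacent to mask b (26-neighborhood),
    # an assumed stand-in for TPTBox's np_contacts bookkeeping.
    neighborhood = np.ones((3, 3, 3), dtype=bool)
    return int(np.count_nonzero(binary_dilation(b.astype(bool), structure=neighborhood) & a.astype(bool)))
```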
diff --git a/spineps/phase_pre.py b/spineps/phase_pre.py
index fc66f13..1198aae 100644
--- a/spineps/phase_pre.py
+++ b/spineps/phase_pre.py
@@ -14,8 +14,8 @@ def preprocess_input(
     mri_nii: NII,
     debug_data: dict,  # noqa: ARG001
     pad_size: int = 4,
-    do_n4: bool = True,
-    do_crop: bool = True,
+    proc_do_n4_bias_correction: bool = True,
+    proc_crop_input: bool = True,
     verbose: bool = False,
 ) -> tuple[NII | None, ErrCode]:
     logger.print("Prepare input image", Log_Type.STAGE)
@@ -25,7 +25,7 @@
     try:
         # Enforce to range [0, 1500]
         mri_nii.normalize_to_range_(min_value=0, max_value=9000, verbose=logger)
-        crop = mri_nii.compute_crop(dist=0) if do_crop else (slice(None, None), slice(None, None), slice(None, None))
+        crop = mri_nii.compute_crop(dist=0) if proc_crop_input else (slice(None, None), slice(None, None), slice(None, None))
     except ValueError:
         logger.print("Image Nifty is empty, skip this", Log_Type.FAIL)
         return None, ErrCode.EMPTY
@@ -34,7 +34,7 @@
     logger.print(f"Crop down from {mri_nii.shape} to {cropped_nii.shape}", verbose=verbose)

     # N4 Bias field correction
-    if do_n4:
+    if proc_do_n4_bias_correction:
         n4_start = perf_counter()
         cropped_nii, _ = n4_bias(cropped_nii)  # PIR
         logger.print(f"N4 Bias field correction done in {perf_counter() - n4_start} sec", verbose=True)
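`n4_bias` is a project helper around ANTs' N4 bias-field correction (provided by the `antspyx` dependency from pyproject.toml). For orientation, the underlying ANTsPy call looks roughly like this, with hypothetical file names:

```python
import ants

img = ants.image_read("sub-01_T2w.nii.gz")       # hypothetical input path
corrected = ants.n4_bias_field_correction(img)   # remove low-frequency intensity inhomogeneity
ants.image_write(corrected, "sub-01_T2w_n4.nii.gz")
```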
diff --git a/spineps/phase_semantic.py b/spineps/phase_semantic.py
index c0a8ecc..3a87e08 100755
--- a/spineps/phase_semantic.py
+++ b/spineps/phase_semantic.py
@@ -13,9 +13,11 @@ def predict_semantic_mask(
     mri_nii: NII,
     model: Segmentation_Model,
-    debug_data: dict,  # noqa: ARG001
-    fill_holes: bool = True,
-    clean_artifacts: bool = True,
+    debug_data: dict,
+    proc_fill_3d_holes: bool = True,
+    proc_clean_beyond_largest_bounding_box: bool = True,
+    proc_remove_inferior_beyond_canal: bool = True,
+    proc_clean_small_cc_artifacts: bool = True,
     verbose: bool = False,
 ) -> tuple[NII | None, NII | None, np.ndarray | None, ErrCode]:
     """Predicts the semantic mask, takes care of rescaling, and back
@@ -45,10 +47,18 @@
     unc_nii = results.get(OutputType.unc, None)
     softmax_logits = results[OutputType.softmax_logits]

+    logger.print("Post-process semantic mask...")
+
+    debug_data["sem_raw"] = seg_nii.copy()
+
     if len(seg_nii.unique()) == 0:
         logger.print("Subregion mask is empty, skip this", Log_Type.FAIL)
         return seg_nii, unc_nii, softmax_logits, ErrCode.EMPTY
-    if clean_artifacts:
+
+    if proc_remove_inferior_beyond_canal:
+        seg_nii = remove_nonsacrum_beyond_canal_height(seg_nii=seg_nii.copy())
+
+    if proc_clean_small_cc_artifacts:
         seg_nii.set_array_(
             clean_cc_artifacts(
                 seg_nii,
@@ -73,19 +83,41 @@
             ),
             verbose=verbose,
         )
-    if fill_holes:
-        seg_nii = seg_nii.fill_holes_(fill_holes_labels, verbose=logger)
-    seg_nii = semantic_bounding_box_clean(seg_nii=seg_nii.copy())
+    # Do two iterations of both processing steps if enabled, to make sure each profits from the other's result
+    if proc_remove_inferior_beyond_canal:
+        seg_nii = remove_nonsacrum_beyond_canal_height(seg_nii=seg_nii.copy())
+
+    if proc_clean_beyond_largest_bounding_box:
+        seg_nii = semantic_bounding_box_clean(seg_nii=seg_nii.copy())
+
+    if proc_remove_inferior_beyond_canal and proc_clean_beyond_largest_bounding_box:
+        seg_nii = remove_nonsacrum_beyond_canal_height(seg_nii=seg_nii.copy())
+        seg_nii = semantic_bounding_box_clean(seg_nii=seg_nii.copy())
+
+    if proc_fill_3d_holes:
+        seg_nii = seg_nii.fill_holes_(fill_holes_labels, verbose=logger)

     return seg_nii, unc_nii, softmax_logits, ErrCode.OK


+def remove_nonsacrum_beyond_canal_height(seg_nii: NII):
+    seg_nii.assert_affine(orientation=("P", "I", "R"))
+    canal_nii = seg_nii.extract_label([Location.Spinal_Canal.value, Location.Spinal_Cord.value])
+    crop_i = canal_nii.compute_crop(dist=16)[1]
+    seg_arr = seg_nii.get_seg_array()
+    sacrum_arr = seg_nii.extract_label(26).get_seg_array()
+    seg_arr[:, 0 : crop_i.start, :] = 0
+    seg_arr[:, crop_i.stop :, :] = 0
+    seg_arr[sacrum_arr == 1] = 26
+    return seg_nii.set_array_(seg_arr)
+
+
 def semantic_bounding_box_clean(seg_nii: NII):
     ori = seg_nii.orientation
     seg_binary = seg_nii.reorient_().extract_label(list(seg_nii.unique()))  # whole thing binary
     seg_bin_largest_k_cc_nii = seg_binary.get_largest_k_segmentation_connected_components(
-        k=20, labels=1, connectivity=3, return_original_labels=False
+        k=None, labels=1, connectivity=3, return_original_labels=False
     )
     max_k = int(seg_bin_largest_k_cc_nii.max())
     if max_k > 3:
@@ -93,7 +125,9 @@ def semantic_bounding_box_clean(seg_nii: NII):
     # PIR
     largest_nii = seg_bin_largest_k_cc_nii.extract_label(1)
     # width fixed, and heigh include all connected components within bounding box, then repeat
-    p_slice, i_slice, r_slice = largest_nii.compute_crop(dist=5)
+    p_slice, i_slice, r_slice = largest_nii.compute_crop(dist=4)
+    bboxes = [(p_slice, i_slice, r_slice)]
+
     # PIR -> fixed, extendable, extendable
     incorporated = [1]
     changed = True
     while changed:
         changed = False
         for k in [l for l in range(2, max_k + 1) if l not in incorporated]:
             k_nii = seg_bin_largest_k_cc_nii.extract_label(k)
-            p, i, r = k_nii.compute_crop(dist=3)
-            i_slice_compare = slice(
-                max(i_slice.start - 10, 0), i_slice.stop + 10
-            )  # more margin in inferior direction (allows for gaps in spine)
-            if overlap_slice(p_slice, p) and overlap_slice(i_slice_compare, i) and overlap_slice(r_slice, r):
-                # extend bbox
-                i_slice = slice(min(i_slice.start, i.start), max(i_slice.stop, i.stop))
-                r_slice = slice(min(r_slice.start, r.start), max(r_slice.stop, r.stop))
-                incorporated.append(k)
-                changed = True
+            p, i, r = k_nii.compute_crop(dist=4)
+
+            for bbox in bboxes:
+                i_slice_compare = slice(
+                    max(bbox[1].start - 4, 0), bbox[1].stop + 4
+                )  # more margin in inferior direction (allows for gaps of size 15 in spine)
+                if overlap_slice(bbox[0], p) and overlap_slice(i_slice_compare, i) and overlap_slice(bbox[2], r):
+                    # extend bbox
+                    bboxes.append((p, i, r))
+                    incorporated.append(k)
+                    changed = True
+                    break

     seg_bin_arr = seg_binary.get_seg_array()
     crop = (p_slice, i_slice, r_slice)
     seg_bin_clean_arr = np.zeros(seg_bin_arr.shape)
     seg_bin_clean_arr[crop] = 1
+    # everything outside the largest k components always gets removed
+    largest_k_arr = seg_bin_largest_k_cc_nii.get_seg_array()
+    seg_bin_clean_arr[largest_k_arr == 0] = 0
+
     seg_arr = seg_nii.get_seg_array()
     # logger.print(seg_nii.volumes())
     seg_arr[seg_bin_clean_arr != 1] = 0
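`semantic_bounding_box_clean` now grows a family of bounding boxes outward from the largest component, admitting any component whose per-axis slices overlap an already accepted box (with extra inferior margin for gaps). `overlap_slice` itself is not part of this diff; a plausible reading of its contract:

```python
def overlap_slice(a: slice, b: slice) -> bool:
    # Do two half-open index ranges intersect? (assumed semantics of the
    # overlap_slice helper that this diff calls but does not define)
    return a.start < b.stop and b.start < a.stop

assert overlap_slice(slice(0, 10), slice(8, 12))
assert not overlap_slice(slice(0, 10), slice(10, 12))
```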
diff --git a/spineps/seg_pipeline.py b/spineps/seg_pipeline.py
index 17f074d..4b46853 100755
--- a/spineps/seg_pipeline.py
+++ b/spineps/seg_pipeline.py
@@ -1,5 +1,6 @@
 # from utils.predictor import nnUNetPredictor
 import subprocess
+from typing import Any

 from scipy.ndimage import center_of_mass
 from TPTBox import NII, Location, No_Logger, Zooms, v_name2idx
@@ -8,8 +9,7 @@
 from spineps.seg_model import Segmentation_Model

-logger = No_Logger()
-logger.override_prefix = "SPINEPS"
+logger = No_Logger(prefix="SPINEPS")

 fill_holes_labels = [
     Location.Vertebra_Corpus_border.value,
@@ -37,6 +37,7 @@ def predict_centroids_from_both(
     vert_nii_cleaned: NII,
     seg_nii: NII,
     models: list[Segmentation_Model],
+    parameter: dict[str, Any],
     input_zms_pir: Zooms | None = None,
 ):
     """Calculates the centroids of each vertebra corpus by using both semantic and instance mask
@@ -70,6 +71,8 @@
     ctd.info["models"] = models_repr
     ctd.info["revision"] = pipeline_revision()
     ctd.info["timestamp"] = format_time_short(get_time())
+    for pname, pvalue in parameter.items():
+        ctd.info[pname] = str(pvalue)

     return ctd
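The new `parameter` argument turns the centroid JSON into a provenance record: every `proc_*` setting used for a run is stringified into `ctd.info`. `process_img_nii` below gathers those settings with a `locals()` snapshot taken on entry; the pattern in isolation:

```python
def example(proc_fill_3d_holes: bool = True, proc_pad_size: int = 4, verbose: bool = False):
    arguments = locals()  # snapshot first, before any new local names appear
    return {k: str(v) for k, v in arguments.items() if "proc_" in k}

print(example())  # {'proc_fill_3d_holes': 'True', 'proc_pad_size': '4'}
```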
diff --git a/spineps/seg_run.py b/spineps/seg_run.py
index 808ea9c..1e5bafd 100755
--- a/spineps/seg_run.py
+++ b/spineps/seg_run.py
@@ -48,15 +48,24 @@ def process_dataset(
     override_ctd: bool = False,
     snapshot_copy_folder: Path | None | bool = None,
     #
-    do_crop_semantic: bool = True,
-    #
-    proc_n4correction: bool = True,
-    proc_fillholes: bool = True,
-    proc_clean: bool = True,
-    proc_corpus_clean: bool = True,
-    proc_cleanvert: bool = True,
+    pad_size: int = 4,
+    # Processing flags
+    # Semantic
+    proc_sem_crop_input: bool = True,
+    proc_sem_n4_bias_correction: bool = True,
+    proc_sem_remove_inferior_beyond_canal: bool = False,
+    proc_sem_clean_beyond_largest_bounding_box: bool = True,
+    proc_sem_clean_small_cc_artifacts: bool = True,
+    # Instance
+    proc_inst_corpus_clean: bool = True,
+    proc_inst_clean_small_cc_artifacts: bool = True,
+    proc_inst_largest_k_cc: int = 0,
+    proc_inst_detect_and_solve_merged_corpi: bool = True,
+    # Both
+    proc_fill_3d_holes: bool = True,
     proc_assign_missing_cc: bool = True,
-    proc_largest_cc: int = 0,
+    proc_clean_inst_by_sem: bool = True,
+    proc_vertebra_inconsistency: bool = True,
     #
     ignore_model_compatibility: bool = False,
     ignore_inference_compatibility: bool = False,
@@ -187,17 +196,25 @@
                 override_postpair=override_postpair,
                 override_ctd=override_ctd,
                 #
-                do_crop_semantic=do_crop_semantic,
-                proc_n4correction=proc_n4correction,
-                proc_fillholes=proc_fillholes,
+                proc_pad_size=pad_size,
+                #
+                proc_sem_crop_input=proc_sem_crop_input,
+                proc_sem_n4_bias_correction=proc_sem_n4_bias_correction,
+                proc_fill_3d_holes=proc_fill_3d_holes,
+                proc_sem_remove_inferior_beyond_canal=proc_sem_remove_inferior_beyond_canal,
+                proc_sem_clean_beyond_largest_bounding_box=proc_sem_clean_beyond_largest_bounding_box,
                 #
-                proc_clean=proc_clean,
-                proc_corpus_clean=proc_corpus_clean,
-                proc_cleanvert=proc_cleanvert,
+                proc_sem_clean_small_cc_artifacts=proc_sem_clean_small_cc_artifacts,
+                proc_inst_detect_and_solve_merged_corpi=proc_inst_detect_and_solve_merged_corpi,
+                proc_inst_corpus_clean=proc_inst_corpus_clean,
+                proc_inst_clean_small_cc_artifacts=proc_inst_clean_small_cc_artifacts,
                 proc_assign_missing_cc=proc_assign_missing_cc,
-                proc_largest_cc=proc_largest_cc,
+                proc_inst_largest_k_cc=proc_inst_largest_k_cc,
+                proc_clean_inst_by_sem=proc_clean_inst_by_sem,
+                proc_vertebra_inconsistency=proc_vertebra_inconsistency,
                 #
                 snapshot_copy_folder=snapshot_copy_folder,
+                ignore_bids_filter=ignore_bids_filter,
                 ignore_compatibility_issues=ignore_inference_compatibility,
                 log_inference_time=log_inference_time,
                 verbose=verbose,
@@ -248,19 +265,28 @@ def process_img_nii(  # noqa: C901
     override_postpair: bool = False,
     override_ctd: bool = False,
     #
-    do_crop_semantic: bool = True,
-    proc_n4correction: bool = True,
-    proc_fillholes: bool = True,
-    #
-    proc_clean: bool = True,
-    proc_corpus_clean: bool = True,
-    proc_cleanvert: bool = True,
+    proc_pad_size: int = 4,
+    # Processing flags
+    # Semantic
+    proc_sem_crop_input: bool = True,
+    proc_sem_n4_bias_correction: bool = True,
+    proc_sem_remove_inferior_beyond_canal: bool = False,
+    proc_sem_clean_beyond_largest_bounding_box: bool = True,
+    proc_sem_clean_small_cc_artifacts: bool = True,
+    # Instance
+    proc_inst_corpus_clean: bool = True,
+    proc_inst_clean_small_cc_artifacts: bool = True,
+    proc_inst_largest_k_cc: int = 0,
+    proc_inst_detect_and_solve_merged_corpi: bool = True,
+    # Both
+    proc_fill_3d_holes: bool = True,
     proc_assign_missing_cc: bool = True,
-    proc_largest_cc: int = 0,
-    process_vertebra_inconsistency: bool = True,
+    proc_clean_inst_by_sem: bool = True,
+    proc_vertebra_inconsistency: bool = True,
     #
     lambda_semantic: Callable[[NII], NII] | None = None,
     snapshot_copy_folder: Path | None = None,
+    ignore_bids_filter: bool = False,
     ignore_compatibility_issues: bool = False,
     log_inference_time: bool = True,
     verbose: bool = False,
@@ -305,9 +331,12 @@ def process_img_nii(  # noqa: C901
     Returns:
         ErrCode: Error code depicting whether the operation was successful or not
     """
+    arguments = locals()
     input_format = img_ref.format

-    output_paths = output_paths_from_input(img_ref, derivative_name, snapshot_copy_folder, input_format=input_format)
+    output_paths = output_paths_from_input(
+        img_ref, derivative_name, snapshot_copy_folder, input_format=input_format, non_strict_mode=ignore_bids_filter
+    )
     out_spine = output_paths["out_spine"]
     out_spine_raw = output_paths["out_spine_raw"]
     out_vert = output_paths["out_vert"]
@@ -358,7 +387,7 @@ def process_img_nii(  # noqa: C901
     input_nii = img_ref.open_nii()
     input_package = InputPackage(
         input_nii,
-        pad_size=4,
+        pad_size=proc_pad_size,
     )
     logger.print("Input image", input_nii.zoom, input_nii.orientation, input_nii.shape)
@@ -371,21 +400,24 @@
         input_nii,
         pad_size=input_package.pad_size,
         debug_data=debug_data_run,
-        do_crop=do_crop_semantic,
-        do_n4=proc_n4correction,
+        proc_crop_input=proc_sem_crop_input,
+        proc_do_n4_bias_correction=proc_sem_n4_bias_correction,
         verbose=verbose,
     )
     if errcode != ErrCode.OK:
         logger.print("Got Error from preprocessing", Log_Type.FAIL)
         return output_paths, errcode
     # make subreg mask
+    assert input_preprocessed is not None
     seg_nii_modelres, unc_nii, softmax_logits, errcode = predict_semantic_mask(
         input_preprocessed,
         model_semantic,
         debug_data=debug_data_run,
         verbose=verbose,
-        fill_holes=proc_fillholes,
-        clean_artifacts=proc_clean,
+        proc_fill_3d_holes=proc_fill_3d_holes,
+        proc_clean_small_cc_artifacts=proc_sem_clean_small_cc_artifacts,
+        proc_clean_beyond_largest_bounding_box=proc_sem_clean_beyond_largest_bounding_box,
+        proc_remove_inferior_beyond_canal=proc_sem_remove_inferior_beyond_canal,
     )
     if errcode != ErrCode.OK:
         return output_paths, errcode
@@ -423,10 +455,11 @@
         model_instance,
         debug_data=debug_data_run,
         verbose=verbose,
-        fill_holes=proc_fillholes,
-        proc_corpus_clean=proc_corpus_clean,
-        proc_cleanvert=proc_cleanvert,
-        proc_largest_cc=proc_largest_cc,
+        proc_inst_fill_3d_holes=proc_fill_3d_holes,
+        proc_detect_and_solve_merged_corpi=proc_inst_detect_and_solve_merged_corpi,
+        proc_corpus_clean=proc_inst_corpus_clean,
+        proc_inst_clean_small_cc_artifacts=proc_inst_clean_small_cc_artifacts,
+        proc_inst_largest_k_cc=proc_inst_largest_k_cc,
     )
     if errcode != ErrCode.OK:
         logger.print(f"Vert Mask creation failed with errcode {errcode}", Log_Type.FAIL)
@@ -457,8 +490,9 @@
         vert_nii=whole_vert_nii,
         debug_data=debug_data_run,
         labeling_offset=1,
+        proc_clean_inst_by_sem=proc_clean_inst_by_sem,
         proc_assign_missing_cc=proc_assign_missing_cc,
-        process_vertebra_inconsistency=process_vertebra_inconsistency,
+        proc_vertebra_inconsistency=proc_vertebra_inconsistency,
        verbose=verbose,
     )
@@ -479,6 +513,7 @@
         vert_nii_clean,
         seg_nii_clean,
         models=[model_semantic, model_instance],
+        parameter={l: v for l, v in arguments.items() if "proc_" in l},
         input_zms_pir=input_package.zms_pir,
     )
     ctd.rescale(input_package._zms, verbose=logger).reorient(input_package._orientation).save(out_ctd, verbose=logger)
@@ -516,28 +551,60 @@ def process_img_nii(  # noqa: C901
     return output_paths, ErrCode.OK


-def output_paths_from_input(img_ref: BIDS_FILE, derivative_name: str, snapshot_copy_folder: Path | None, input_format: str):
-    out_spine = img_ref.get_changed_path(format="msk", parent=derivative_name, info={"seg": "spine", "mod": img_ref.format})
-    out_vert = img_ref.get_changed_path(format="msk", parent=derivative_name, info={"seg": "vert", "mod": img_ref.format})
-    out_snap = img_ref.get_changed_path(format="snp", file_type="png", parent=derivative_name, info={"seg": "spine", "mod": img_ref.format})
-    out_ctd = img_ref.get_changed_path(format="ctd", file_type="json", parent=derivative_name, info={"seg": "spine", "mod": img_ref.format})
+def output_paths_from_input(
+    img_ref: BIDS_FILE,
+    derivative_name: str,
+    snapshot_copy_folder: Path | None,
+    input_format: str,
+    non_strict_mode: bool = False,
+):
+    out_spine = img_ref.get_changed_path(
+        bids_format="msk", parent=derivative_name, info={"seg": "spine", "mod": img_ref.format}, non_strict_mode=non_strict_mode
+    )
+    out_vert = img_ref.get_changed_path(
+        bids_format="msk", parent=derivative_name, info={"seg": "vert", "mod": img_ref.format}, non_strict_mode=non_strict_mode
+    )
+    out_snap = img_ref.get_changed_path(
+        bids_format="snp",
+        file_type="png",
+        parent=derivative_name,
+        info={"seg": "spine", "mod": img_ref.format},
+        non_strict_mode=non_strict_mode,
+    )
+    out_ctd = img_ref.get_changed_path(
+        bids_format="ctd",
+        file_type="json",
+        parent=derivative_name,
+        info={"seg": "spine", "mod": img_ref.format},
+        non_strict_mode=non_strict_mode,
+    )
     out_snap2 = snapshot_copy_folder.joinpath(out_snap.name) if snapshot_copy_folder is not None else out_snap
     # out_debug = out_vert.parent.joinpath(f"debug_{input_format}")
     # out_raw = out_vert.parent.joinpath(f"output_raw_{input_format}")
     #
-    out_spine_raw = img_ref.get_changed_path(format="msk", parent=derivative_name, info={"seg": "spine-raw", "mod": img_ref.format})
+    out_spine_raw = img_ref.get_changed_path(
+        bids_format="msk", parent=derivative_name, info={"seg": "spine-raw", "mod": img_ref.format}, non_strict_mode=non_strict_mode
+    )
     out_spine_raw = out_raw.joinpath(out_spine_raw.name)
     #
-    out_vert_raw = img_ref.get_changed_path(format="msk", parent=derivative_name, info={"seg": "vert-raw", "mod": img_ref.format})
+    out_vert_raw = img_ref.get_changed_path(
+        bids_format="msk", parent=derivative_name, info={"seg": "vert-raw", "mod": img_ref.format}, non_strict_mode=non_strict_mode
+    )
     out_vert_raw = out_raw.joinpath(out_vert_raw.name)
     #
-    out_unc = img_ref.get_changed_path(format="uncertainty", parent=derivative_name, info={"seg": "spine", "mod": img_ref.format})
+    out_unc = img_ref.get_changed_path(
+        bids_format="uncertainty", parent=derivative_name, info={"seg": "spine", "mod": img_ref.format}, non_strict_mode=non_strict_mode
+    )
     out_unc = out_raw.joinpath(out_unc.name)
     #
     out_logits = img_ref.get_changed_path(
-        file_type="npz", format="logit", parent=derivative_name, info={"seg": "spine", "mod": img_ref.format}
+        file_type="npz",
+        bids_format="logit",
+        parent=derivative_name,
+        info={"seg": "spine", "mod": img_ref.format},
+        non_strict_mode=non_strict_mode,
    )
     out_logits = out_raw.joinpath(out_logits.name)

     return {
diff --git a/spineps/utils/proc_functions.py b/spineps/utils/proc_functions.py
index a12d404..11815dc 100755
--- a/spineps/utils/proc_functions.py
+++ b/spineps/utils/proc_functions.py
@@ -2,8 +2,8 @@
 import numpy as np
 from ants.utils.convert_nibabel import from_nibabel
 from scipy.ndimage import center_of_mass
-from TPTBox import NII, Logger_Interface
-from TPTBox.core.np_utils import np_bbox_binary, np_count_nonzero, np_dilate_msk, np_unique, np_volume
+from TPTBox import NII, Location, Logger_Interface
+from TPTBox.core.np_utils import np_bbox_binary, np_count_nonzero, np_dilate_msk, np_unique, np_unique_withoutzero, np_volume
 from tqdm import tqdm

@@ -181,3 +181,57 @@ def connected_components_3d(mask_image: np.ndarray, connectivity: int = 3, verbo
         if (n) != 1:  # zero is a label
             print(f"subreg {subreg} does not have one CC (not counting zeros), got {n}") if verbose else None
     return subreg_cc, subreg_cc_stats
+
+
+def fix_wrong_posterior_instance_label(seg_sem: NII, seg_inst: NII, logger) -> NII:
+    seg_sem = seg_sem.copy()
+    seg_inst = seg_inst.copy()
+    orientation = seg_sem.orientation
+    seg_sem.assert_affine(other=seg_inst)
+    seg_sem.reorient_()
+    seg_inst.reorient_()
+
+    seg_inst_arr_proc = seg_inst.get_seg_array()
+
+    instance_labels = [i for i in seg_inst.unique() if 1 <= i <= 25]
+
+    for vert in instance_labels:
+        inst_vert = seg_inst.extract_label(vert)
+        # sem_vert = seg_sem.apply_mask(inst_vert)
+
+        # Check if multiple CC exist
+        inst_vert_cc = inst_vert.get_largest_k_segmentation_connected_components(3, return_original_labels=False)
+        inst_vert_cc_n = int(inst_vert_cc.max())
+        #
+        if inst_vert_cc_n == 1:
+            continue
+        #
+        # inst_vert_cc is labeled 1 to 3
+        for i in range(2, inst_vert_cc_n + 1):
+            inst_vert_cc_i = inst_vert_cc.extract_label(i)
+
+            crop = inst_vert_cc_i.compute_crop(dist=1)
+            inst_vert_cc_i_c = inst_vert_cc_i.apply_crop(crop)
+
+            cc_sem_vert = seg_sem.apply_crop(crop).apply_mask(inst_vert_cc_i_c)
+            # cc_vert is semantic mask of only that cc of instance
+
+            cc_sem_vert_labels = cc_sem_vert.unique()
+            # is that cc only arcus and spinosus?
+            if len(cc_sem_vert_labels) <= 2 and np.all(
+                [i in [Location.Arcus_Vertebrae.value, Location.Spinosus_Process.value] for i in cc_sem_vert_labels]
+            ):
+                # neighbors that have a non-arcus/spinosus label?
+                neighbor_instance_labels = seg_inst.apply_crop(crop).get_seg_array()
+                neighbor_instance_labels[inst_vert_cc_i_c.get_seg_array() == 1] = 0
+                neighbor_instance_labels = np_unique_withoutzero(neighbor_instance_labels)
+                # which instance labels does it touch
+                logger.print(f"vert {vert}, cc_k {i} has instance neighbors {neighbor_instance_labels}")
+                # is it touching only one other instance label?
+                if len(neighbor_instance_labels) == 1 and neighbor_instance_labels[0] != vert:
+                    to_label = neighbor_instance_labels[0]
+                    logger.print(f"vert {vert}, cc_k {i} relabel to instance {to_label}")
+                    seg_inst_arr_proc[inst_vert_cc_i.get_seg_array() == 1] = to_label
+
+    seg_inst_proc = seg_inst.set_array(seg_inst_arr_proc).reorient_(orientation)
+    return seg_inst_proc
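`fix_wrong_posterior_instance_label` reassigns a connected component that contains only posterior tissue (arcus plus spinosus) to the single neighboring instance it touches. The neighbor lookup, reduced to a numpy/scipy sketch of the crop-and-mask logic above:

```python
import numpy as np
from scipy.ndimage import binary_dilation

def touching_instance_labels(cc: np.ndarray, inst: np.ndarray) -> np.ndarray:
    # Instance labels in the one-voxel ring around component cc,
    # excluding background; a unique result triggers the relabel.
    cc = cc.astype(bool)
    ring = binary_dilation(cc) & ~cc
    return np.setdiff1d(np.unique(inst[ring]), [0])
```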
diff --git a/unit_tests/test_semantic.py b/unit_tests/test_semantic.py
index f64bb42..c4a20ce 100644
--- a/unit_tests/test_semantic.py
+++ b/unit_tests/test_semantic.py
@@ -79,7 +79,9 @@ def test_segment_scan(self):
         model = Segmentation_Model_Dummy()
         model.run = MagicMock(return_value={OutputType.seg: subreg, OutputType.softmax_logits: None})
         debug_data = {}
-        seg_nii, unc_nii, softmax_logits, errcode = predict_semantic_mask(mri, model, debug_data=debug_data, verbose=True)
+        seg_nii, unc_nii, softmax_logits, errcode = predict_semantic_mask(
+            mri, model, debug_data=debug_data, verbose=True, proc_clean_small_cc_artifacts=False
+        )
         predicted_volumes = seg_nii.volumes()
         ref_volumes = subreg.volumes()
         print(predicted_volumes)