From 93299573104e997e196fe231ffba9cf2e4696558 Mon Sep 17 00:00:00 2001 From: iback Date: Tue, 7 May 2024 09:31:18 +0000 Subject: [PATCH 1/5] renamed processing parameters consistently, added mincutmaxflow to combat merged vertebrae --- pyproject.toml | 2 +- spineps/entrypoint.py | 8 +- spineps/phase_instance.py | 184 ++++++++++++++++++++++++++++----- spineps/phase_post.py | 75 +++++++++++--- spineps/phase_pre.py | 8 +- spineps/phase_semantic.py | 78 ++++++++++---- spineps/seg_pipeline.py | 7 +- spineps/seg_run.py | 155 +++++++++++++++++++-------- spineps/utils/mincutmaxflow.py | 182 ++++++++++++++++++++++++++++++++ unit_tests/test_semantic.py | 4 +- 10 files changed, 586 insertions(+), 117 deletions(-) create mode 100644 spineps/utils/mincutmaxflow.py diff --git a/pyproject.toml b/pyproject.toml index 650f459..7899ee4 100755 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,7 +29,7 @@ torchmetrics = "^1.1.2" tqdm = "^4.66.1" einops= "^0.6.1" nnunetv2 = "2.2" -tptbox = "^0.0.9" +tptbox = "^0.1.0" antspyx = "*" rich = "^13.6.0" diff --git a/spineps/entrypoint.py b/spineps/entrypoint.py index 0d47f92..d4aca0a 100755 --- a/spineps/entrypoint.py +++ b/spineps/entrypoint.py @@ -209,8 +209,8 @@ def run_sample(opt: Namespace): "override_postpair": opt.override_postpair, "override_ctd": opt.override_ctd, # - "do_crop_semantic": not opt.nocrop, - "proc_n4correction": not opt.non4, + "proc_sem_crop_input": not opt.nocrop, + "proc_sem_n4_bias_correction": not opt.non4, "ignore_compatibility_issues": opt.ignore_inference_compatibility, "verbose": opt.verbose, } @@ -283,8 +283,8 @@ def run_dataset(opt: Namespace): "ignore_inference_compatibility": opt.ignore_inference_compatibility, "ignore_bids_filter": opt.ignore_bids_filter, # - "do_crop_semantic": not opt.nocrop, - "proc_n4correction": not opt.non4, + "proc_sem_crop_input": not opt.nocrop, + "proc_sem_n4_bias_correction": not opt.non4, "snapshot_copy_folder": opt.save_snaps_folder, "verbose": opt.verbose, } diff --git a/spineps/phase_instance.py b/spineps/phase_instance.py index 15589bb..c136bcc 100755 --- a/spineps/phase_instance.py +++ b/spineps/phase_instance.py @@ -1,12 +1,21 @@ # from utils.predictor import nnUNetPredictor import numpy as np from TPTBox import NII, Location, Log_Type -from TPTBox.core.np_utils import np_calc_crop_around_centerpoint, np_count_nonzero, np_dice, np_unique +from TPTBox.core.np_utils import ( + np_calc_crop_around_centerpoint, + np_center_of_mass, + np_count_nonzero, + np_dice, + np_extract_label, + np_unique, + np_volume, +) from tqdm import tqdm from spineps.seg_enums import ErrCode, OutputType from spineps.seg_model import Segmentation_Model from spineps.seg_pipeline import logger +from spineps.utils.mincutmaxflow import np_mincutmaxflow from spineps.utils.proc_functions import clean_cc_artifacts @@ -15,10 +24,11 @@ def predict_instance_mask( model: Segmentation_Model, debug_data: dict, pad_size: int = 0, - fill_holes: bool = True, + proc_inst_fill_3d_holes: bool = True, + proc_detect_and_solve_merged_corpi: bool = True, proc_corpus_clean: bool = True, - proc_cleanvert: bool = True, - proc_largest_cc: int = 0, + proc_inst_clean_small_cc_artifacts: bool = True, + proc_inst_largest_k_cc: int = 0, verbose: bool = False, ) -> tuple[NII | None, ErrCode]: """Based on subregion segmentation, feeds individual arcus coms to a network to get the vertebra body segmentations @@ -94,8 +104,9 @@ def predict_instance_mask( corpus_size_cleaning=corpus_size_cleaning if proc_corpus_clean else 0, cutout_size=cutout_size, debug_data=debug_data, - proc_largest_cc=proc_largest_cc, - fill_holes=False, + process_detect_and_solve_merged_corpi=proc_detect_and_solve_merged_corpi, + proc_inst_largest_k_cc=proc_inst_largest_k_cc, + proc_inst_fill_holes=False, verbose=verbose, ) if vert_predictions is None: @@ -109,7 +120,7 @@ def predict_instance_mask( hierarchical_existing_predictions, vert_size_threshold, debug_data=debug_data, - proc_cleanvert=proc_cleanvert, + proc_inst_clean_small_cc_artifacts=proc_inst_clean_small_cc_artifacts, ) del vert_predictions, hierarchical_existing_predictions if errcode != ErrCode.OK: @@ -122,7 +133,7 @@ def predict_instance_mask( uniq_labels = whole_vert_nii.unique() - if fill_holes: + if proc_inst_fill_3d_holes: whole_vert_nii.fill_holes_(verbose=logger) debug_data["inst_cropped_vert_arr_c_proc"] = whole_vert_nii.copy() n_vert_bodies = len(uniq_labels) @@ -155,18 +166,16 @@ def predict_instance_mask( return whole_vert_nii_uncropped, ErrCode.OK -def collect_vertebra_predictions( +def get_corpus_coms( seg_nii: NII, - model: Segmentation_Model, corpus_size_cleaning: int, - cutout_size, - debug_data: dict, - proc_largest_cc: int = 0, - fill_holes: bool = False, + process_detect_and_solve_merged_corpi: bool = True, verbose: bool = False, -) -> tuple[np.ndarray | None, list[str], int]: +) -> list: + seg_nii.assert_affine(orientation=["P", "I", "R"]) + # # Extract Corpus region and try to find all coms naively (some skips shouldnt matter) - corpus_nii = seg_nii.extract_label(49) + corpus_nii = seg_nii.extract_label(Location.Vertebra_Corpus_border.value) corpus_nii.erode_msk_(mm=2, connectivity=2, verbose=False) if 1 in corpus_nii.unique() and corpus_size_cleaning > 0: corpus_nii.set_array_( @@ -184,12 +193,133 @@ def collect_vertebra_predictions( if 1 not in corpus_nii.unique(): logger.print("No 1 in corpus nifty, cannot make vertebra mask", Log_Type.FAIL) - return None, [], 0 + return None + + if not process_detect_and_solve_merged_corpi: + corpus_coms = corpus_nii.get_segmentation_connected_components_center_of_mass(label=1, sort_by_axis=1) + corpus_coms.reverse() # from bottom to top + return corpus_coms + + ############ + # Detect merged vertebra and use mincutmaxflow algo on it + ############ + + # Get corpus CCs + corpus_cc, corpus_cc_n = corpus_nii.get_segmentation_connected_components(labels=1) + corpus_cc = corpus_cc[1] + corpus_cc_n = corpus_cc_n[1] + logger.print(f"Found {corpus_cc_n} Corpus ccs", verbose=verbose) + + # Check against ivd order + seg_sem = seg_nii.map_labels({Location.Endplate.value: Location.Vertebra_Disc.value}, verbose=False) + subreg_cc, subreg_cc_n = seg_sem.get_segmentation_connected_components(labels=Location.Vertebra_Disc.value) + subreg_cc = subreg_cc[Location.Vertebra_Disc.value] + subreg_cc[subreg_cc > 0] += 100 + subreg_cc_n = subreg_cc_n[Location.Vertebra_Disc.value] + logger.print(f"Found {subreg_cc_n} IVD ccs", verbose=verbose) + coms = np_center_of_mass(subreg_cc) + volumes = np_volume(subreg_cc) + stats = {i: (g[1], True, volumes[i]) for i, g in coms.items()} + + vert_coms = np_center_of_mass(corpus_cc) + vert_volumes = np_volume(corpus_cc) + + for i, g in vert_coms.items(): + stats[i] = (g[1], False, vert_volumes[i]) + + stats_by_height = dict(sorted(stats.items(), key=lambda x: x[1][0])) + stats_by_height_keys = list(stats_by_height.keys()) + + for vl in stats_by_height_keys: + idx = stats_by_height_keys.index(vl) + statsvl = stats_by_height[vl] + + is_ivd = statsvl[1] + selfheight = statsvl[0] + + nidx = idx + 1 + if nidx < 0 or nidx >= len(stats_by_height_keys): + continue - corpus_coms = corpus_nii.get_segmentation_connected_components_center_of_mass( - label=1, sort_by_axis=1 - ) # TODO replace with approx_com by bbox + nkey = stats_by_height_keys[nidx] + if nkey in stats_by_height and stats_by_height[nkey][1] == is_ivd: + neighbor = stats_by_height[nkey] + neighborheight = neighbor[0] + logger.print( + f"Wrong ivd-vert alternation found in label {vl}, is_ivd = {is_ivd}, neighbor {nkey}", + Log_Type.STRANGE, + verbose=verbose, + ) + # check if same heigh, then just merge ivd label + if abs(neighborheight - selfheight) < 10: + logger.print("Same height, just merge") + stats_by_height.pop(vl) + stats_by_height = dict(sorted(stats_by_height.items(), key=lambda x: x[1][0])) + stats_by_height_keys = list(stats_by_height.keys()) + print(stats_by_height_keys) + continue + + logger.print("Merged corpi, try to fix it", verbose=verbose) + neighbor_verts = { + stats_by_height_keys[idx + i]: stats_by_height[stats_by_height_keys[idx + i]] + for i in [-5, -4, -3, -2, -1, 1, 2, 3, 4, 5] + if (idx + i) in stats_by_height_keys and stats_by_height_keys[idx + i] < 99 + } + # stats_by_height_keys[idx + i] + # for i in [-5, -4, -3, -2, -1, 1, 2, 3, 4, 5] + # if (idx + i) in stats_by_height_keys and stats_by_height_keys[idx + i] < 99 + # ] # (+-3) + logger.print("neighbor_vert_labels", neighbor_verts, verbose=verbose) + if len(neighbor_verts) == 0: + logger.print("Got no neighbor vert labels to fix", Log_Type.FAIL) + continue + neighbor_volumes = [k[2] for nv, k in neighbor_verts.items()] + logger.print("neighbor_volumes", neighbor_volumes, verbose=verbose) + n_neighbors_without_target = len(neighbor_volumes) - 1 + if n_neighbors_without_target > 2: + argmax = np.argmax(neighbor_volumes) + n_avg_volume = sum([neighbor_volumes[i] for i in range(len(neighbor_volumes)) if i != argmax]) / n_neighbors_without_target + + diff_volume = neighbor_volumes[argmax] / n_avg_volume + if diff_volume > 1.5: + logger.print( + f"Volume difference detected in label {vl}, diff = {diff_volume}, volume = {neighbor_volumes[argmax]}, neighbor_avg = {n_avg_volume}", + Log_Type.STRANGE, + ) + + target_vert_id = list(neighbor_verts.keys())[argmax] + segvert = np_extract_label(corpus_cc, target_vert_id, inplace=False) + try: + split_vert = np_mincutmaxflow(segvert, None) + corpus_cc[split_vert == 2] = corpus_cc.max() + 1 + except Exception as e: + logger.print(f"MinCutMaxFlow failed with exception {e}") + + corpus_coms = list(np_center_of_mass(corpus_cc).values()) + corpus_coms.sort(key=lambda a: a[1]) corpus_coms.reverse() # from bottom to top + return corpus_coms + + +def collect_vertebra_predictions( + seg_nii: NII, + model: Segmentation_Model, + corpus_size_cleaning: int, + cutout_size, + debug_data: dict, + proc_inst_largest_k_cc: int = 0, + process_detect_and_solve_merged_corpi: bool = True, + proc_inst_fill_holes: bool = False, + verbose: bool = False, +) -> tuple[np.ndarray | None, list[str], int]: + corpus_coms = get_corpus_coms( + seg_nii, + corpus_size_cleaning=corpus_size_cleaning, + process_detect_and_solve_merged_corpi=process_detect_and_solve_merged_corpi, + verbose=verbose, + ) + if corpus_coms is None: + return None, [], 0 n_corpus_coms = len(corpus_coms) if n_corpus_coms < 3: @@ -204,7 +334,7 @@ def collect_vertebra_predictions( seg_nii.shape[2], ) hierarchical_existing_predictions = [] - hierarchical_predictions = np.zeros((n_corpus_coms, 3, *shp), dtype=corpus_nii.dtype) + hierarchical_predictions = np.zeros((n_corpus_coms, 3, *shp), dtype=seg_nii.dtype) # print("hierarchical_predictions", hierarchical_predictions.shape) vert_predict_template = np.zeros(shp, dtype=np.uint16) # print("vert_predict_template", vert_predict_template.shape) @@ -268,8 +398,8 @@ def collect_vertebra_predictions( vert_cut_nii = post_process_single_3vert_prediction( vert_cut_nii, None, - fill_holes=fill_holes, - largest_cc=proc_largest_cc, # type:ignore + fill_holes=proc_inst_fill_holes, + largest_cc=proc_inst_largest_k_cc, # type:ignore ) vert_labels = vert_cut_nii.unique() # 1,2,3 debug_data[f"inst_cutout_vert_nii_{com_idx}_proc"] = vert_cut_nii.copy() @@ -328,7 +458,7 @@ def from_vert3_predictions_make_vert_mask( vert_size_threshold: int, debug_data: dict, # - proc_cleanvert: bool = True, + proc_inst_clean_small_cc_artifacts: bool = True, verbose: bool = False, ) -> tuple[NII, dict, ErrCode]: # instance approach: each 1/2/3 pred finds it most agreeing partner in surrounding predictions (com idx -2 to +2 all three pred) @@ -348,7 +478,7 @@ def from_vert3_predictions_make_vert_mask( coupled_predictions=coupled_predictions, hierarchical_predictions=hierarchical_predictions, debug_data=debug_data, - proc_cleanvert=proc_cleanvert, + proc_clean_small_cc_artifacts=proc_inst_clean_small_cc_artifacts, vert_size_threshold=vert_size_threshold, verbose=verbose, ) @@ -488,7 +618,7 @@ def merge_coupled_predictions( coupled_predictions, hierarchical_predictions: np.ndarray, debug_data: dict, - proc_cleanvert: bool = True, + proc_clean_small_cc_artifacts: bool = True, vert_size_threshold: int = 0, verbose: bool = False, ) -> tuple[NII, dict, ErrCode]: @@ -531,7 +661,7 @@ def merge_coupled_predictions( return whole_vert_nii.set_array_(whole_vert_arr, verbose=False), debug_data, ErrCode.EMPTY # Cleanup step - if proc_cleanvert: + if proc_clean_small_cc_artifacts: whole_vert_arr = clean_cc_artifacts( whole_vert_arr, labels=np_unique(whole_vert_arr)[1:], # type:ignore diff --git a/spineps/phase_post.py b/spineps/phase_post.py index 3b0e3a9..8c41a9e 100644 --- a/spineps/phase_post.py +++ b/spineps/phase_post.py @@ -8,6 +8,7 @@ np_bbox_binary, np_center_of_mass, np_connected_components, + np_contacts, np_count_nonzero, np_dilate_msk, np_extract_label, @@ -25,8 +26,10 @@ def phase_postprocess_combined( debug_data: dict | None, labeling_offset: int = 0, proc_assign_missing_cc: bool = True, + proc_clean_inst_by_sem: bool = True, n_vert_bodies: int | None = None, - process_vertebra_inconsistency: bool = True, + process_merge_vertebra: bool = True, + proc_vertebra_inconsistency: bool = True, verbose: bool = False, ) -> tuple[NII, NII]: logger.print("Post process", Log_Type.STAGE) @@ -40,7 +43,9 @@ def phase_postprocess_combined( if debug_data is None: debug_data = {} # - vert_nii.apply_mask(seg_nii, inplace=True) + + if proc_clean_inst_by_sem: + vert_nii.apply_mask(seg_nii, inplace=True) crop_slices = seg_nii.compute_crop(dist=2) vert_uncropped_arr = np.zeros(vert_nii.shape, dtype=seg_nii.dtype) seg_uncropped_arr = np.zeros(vert_nii.shape, dtype=seg_nii.dtype) @@ -59,17 +64,21 @@ def phase_postprocess_combined( verbose=verbose, ) - if process_vertebra_inconsistency: + if process_merge_vertebra and Location.Vertebra_Disc.value in seg_nii_cleaned.unique(): + detect_and_solve_merged_vertebra(seg_nii_cleaned, whole_vert_nii_cleaned) + + if proc_vertebra_inconsistency: # Assigns superior/inferior based on instance label overlap assign_vertebra_inconsistency(seg_nii_cleaned, whole_vert_nii_cleaned) # Label vertebra top -> down - whole_vert_nii_cleaned, vert_labels = label_instance_top_to_bottom(whole_vert_nii_cleaned) - if labeling_offset != 0: - whole_vert_nii_cleaned.map_labels_({i: i + 1 for i in vert_labels if i != 0}, verbose=verbose) + whole_vert_nii_cleaned, vert_labels = label_instance_top_to_bottom(whole_vert_nii_cleaned, labeling_offset=labeling_offset) + # if labeling_offset != 0: + # whole_vert_nii_cleaned.map_labels_({i: i + 1 for i in vert_labels if i != 0}, verbose=verbose) logger.print(f"Labeled {len(vert_labels)} vertebra instances from top to bottom") - vert_arr_cleaned = add_ivd_ep_vert_label(whole_vert_nii_cleaned, seg_nii_cleaned) + # + # vert_arr_cleaned[seg_nii_cleaned.get_seg_array() == v_name2idx["S1"]] = v_name2idx["S1"] ############### # Uncrop @@ -305,20 +314,14 @@ def find_nearest_lower(seq, x): return max(values_lower) -def label_instance_top_to_bottom(vert_nii: NII): +def label_instance_top_to_bottom(vert_nii: NII, labeling_offset: int = 0): ori = vert_nii.orientation vert_nii.reorient_() vert_arr = vert_nii.get_seg_array() com_i = np_center_of_mass(vert_arr) - # Old, more precise version (but takes longer) - # comb = {} - # for i in present_labels: - # arr_i = vert_arr.copy() - # arr_i[arr_i != i] = 0 - # comb[i] = center_of_mass(arr_i) comb_l = list(zip(com_i.keys(), com_i.values(), strict=True)) comb_l.sort(key=lambda a: a[1][1]) # PIR - com_map = {comb_l[idx][0]: idx + 1 for idx in range(len(comb_l))} + com_map = {comb_l[idx][0]: idx + 1 + labeling_offset for idx in range(len(comb_l))} vert_nii.map_labels_(com_map, verbose=False) vert_nii.reorient_(ori) @@ -374,3 +377,45 @@ def assign_vertebra_inconsistency(seg_nii: NII, vert_nii: NII): ) vert_nii.set_array_(vert_arr) + + +def detect_and_solve_merged_vertebra(seg_nii: NII, vert_nii: NII): + seg_sem = seg_nii.map_labels({Location.Endplate.value: Location.Vertebra_Disc.value}, verbose=False) + # get all ivd CCs from seg_sem + + stats = {} + # Map IVDS + subreg_cc, subreg_cc_n = seg_sem.get_segmentation_connected_components(labels=Location.Vertebra_Disc.value) + subreg_cc = subreg_cc[Location.Vertebra_Disc.value] + 100 + subreg_cc_n = subreg_cc_n[Location.Vertebra_Disc.value] + + coms = np_center_of_mass(subreg_cc) + volumes = np_volume(subreg_cc) + stats = {i: (g[1], True, volumes[i]) for i, g in coms.items()} + + vert_coms = vert_nii.center_of_masses() + vert_volumes = vert_nii.volumes() + + for i, g in vert_coms.items(): + stats[i] = (g[1], False, vert_volumes[i]) + + stats_by_height = dict(sorted(stats.items(), key=lambda x: x[1][0])) + stats_by_height_keys = list(stats_by_height.keys()) + + # detect C2 split into two components + first_key, second_key = stats_by_height_keys[0], stats_by_height_keys[1] + first_stats, second_stats = stats_by_height[first_key], stats_by_height[second_key] + if first_stats[1] is False and second_stats[1] is False: # noqa: SIM102 + # both vertebra + if first_stats[2] < 0.5 * second_stats[2]: + # first is significantly smaller than second and they are close in height + # how many pixels are touching + vert_firsttwo_arr = vert_nii.extract_label(first_key).get_seg_array() + vert_firsttwo_arr2 = vert_nii.extract_label(second_key).get_seg_array() + vert_firsttwo_arr += vert_firsttwo_arr2 + 1 + contacts = np_contacts(vert_firsttwo_arr, connectivity=3) + if contacts[(1, 2)] > 20: + logger.print("Found first two instance weird, will merge", Log_Type.STRANGE) + vert_nii.map_labels_({first_key: second_key}, verbose=False) + + return seg_nii, vert_nii diff --git a/spineps/phase_pre.py b/spineps/phase_pre.py index fc66f13..1198aae 100644 --- a/spineps/phase_pre.py +++ b/spineps/phase_pre.py @@ -14,8 +14,8 @@ def preprocess_input( mri_nii: NII, debug_data: dict, # noqa: ARG001 pad_size: int = 4, - do_n4: bool = True, - do_crop: bool = True, + proc_do_n4_bias_correction: bool = True, + proc_crop_input: bool = True, verbose: bool = False, ) -> tuple[NII | None, ErrCode]: logger.print("Prepare input image", Log_Type.STAGE) @@ -25,7 +25,7 @@ def preprocess_input( try: # Enforce to range [0, 1500] mri_nii.normalize_to_range_(min_value=0, max_value=9000, verbose=logger) - crop = mri_nii.compute_crop(dist=0) if do_crop else (slice(None, None), slice(None, None), slice(None, None)) + crop = mri_nii.compute_crop(dist=0) if proc_crop_input else (slice(None, None), slice(None, None), slice(None, None)) except ValueError: logger.print("Image Nifty is empty, skip this", Log_Type.FAIL) return None, ErrCode.EMPTY @@ -34,7 +34,7 @@ def preprocess_input( logger.print(f"Crop down from {mri_nii.shape} to {cropped_nii.shape}", verbose=verbose) # N4 Bias field correction - if do_n4: + if proc_do_n4_bias_correction: n4_start = perf_counter() cropped_nii, _ = n4_bias(cropped_nii) # PIR logger.print(f"N4 Bias field correction done in {perf_counter() - n4_start} sec", verbose=True) diff --git a/spineps/phase_semantic.py b/spineps/phase_semantic.py index c0a8ecc..3a87e08 100755 --- a/spineps/phase_semantic.py +++ b/spineps/phase_semantic.py @@ -13,9 +13,11 @@ def predict_semantic_mask( mri_nii: NII, model: Segmentation_Model, - debug_data: dict, # noqa: ARG001 - fill_holes: bool = True, - clean_artifacts: bool = True, + debug_data: dict, + proc_fill_3d_holes: bool = True, + proc_clean_beyond_largest_bounding_box: bool = True, + proc_remove_inferior_beyond_canal: bool = True, + proc_clean_small_cc_artifacts: bool = True, verbose: bool = False, ) -> tuple[NII | None, NII | None, np.ndarray | None, ErrCode]: """Predicts the semantic mask, takes care of rescaling, and back @@ -45,10 +47,18 @@ def predict_semantic_mask( unc_nii = results.get(OutputType.unc, None) softmax_logits = results[OutputType.softmax_logits] + logger.print("Post-process semantic mask...") + + debug_data["sem_raw"] = seg_nii.copy() + if len(seg_nii.unique()) == 0: logger.print("Subregion mask is empty, skip this", Log_Type.FAIL) return seg_nii, unc_nii, softmax_logits, ErrCode.EMPTY - if clean_artifacts: + + if proc_remove_inferior_beyond_canal: + seg_nii = remove_nonsacrum_beyond_canal_height(seg_nii=seg_nii.copy()) + + if proc_clean_small_cc_artifacts: seg_nii.set_array_( clean_cc_artifacts( seg_nii, @@ -73,19 +83,41 @@ def predict_semantic_mask( ), verbose=verbose, ) - if fill_holes: - seg_nii = seg_nii.fill_holes_(fill_holes_labels, verbose=logger) - seg_nii = semantic_bounding_box_clean(seg_nii=seg_nii.copy()) + # Do two iterations of both processing if enabled to make sure + if proc_remove_inferior_beyond_canal: + seg_nii = remove_nonsacrum_beyond_canal_height(seg_nii=seg_nii.copy()) + + if proc_clean_beyond_largest_bounding_box: + seg_nii = semantic_bounding_box_clean(seg_nii=seg_nii.copy()) + + if proc_remove_inferior_beyond_canal and proc_clean_beyond_largest_bounding_box: + seg_nii = remove_nonsacrum_beyond_canal_height(seg_nii=seg_nii.copy()) + seg_nii = semantic_bounding_box_clean(seg_nii=seg_nii.copy()) + + if proc_fill_3d_holes: + seg_nii = seg_nii.fill_holes_(fill_holes_labels, verbose=logger) return seg_nii, unc_nii, softmax_logits, ErrCode.OK +def remove_nonsacrum_beyond_canal_height(seg_nii: NII): + seg_nii.assert_affine(orientation=("P", "I", "R")) + canal_nii = seg_nii.extract_label([Location.Spinal_Canal.value, Location.Spinal_Cord.value]) + crop_i = canal_nii.compute_crop(dist=16)[1] + seg_arr = seg_nii.get_seg_array() + sacrum_arr = seg_nii.extract_label(26).get_seg_array() + seg_arr[:, 0 : crop_i.start, :] = 0 + seg_arr[:, crop_i.stop :, :] = 0 + seg_arr[sacrum_arr == 1] = 26 + return seg_nii.set_array_(seg_arr) + + def semantic_bounding_box_clean(seg_nii: NII): ori = seg_nii.orientation seg_binary = seg_nii.reorient_().extract_label(list(seg_nii.unique())) # whole thing binary seg_bin_largest_k_cc_nii = seg_binary.get_largest_k_segmentation_connected_components( - k=20, labels=1, connectivity=3, return_original_labels=False + k=None, labels=1, connectivity=3, return_original_labels=False ) max_k = int(seg_bin_largest_k_cc_nii.max()) if max_k > 3: @@ -93,7 +125,9 @@ def semantic_bounding_box_clean(seg_nii: NII): # PIR largest_nii = seg_bin_largest_k_cc_nii.extract_label(1) # width fixed, and heigh include all connected components within bounding box, then repeat - p_slice, i_slice, r_slice = largest_nii.compute_crop(dist=5) + p_slice, i_slice, r_slice = largest_nii.compute_crop(dist=4) + bboxes = [(p_slice, i_slice, r_slice)] + # PIR -> fixed, extendable, extendable incorporated = [1] changed = True @@ -101,22 +135,28 @@ def semantic_bounding_box_clean(seg_nii: NII): changed = False for k in [l for l in range(2, max_k + 1) if l not in incorporated]: k_nii = seg_bin_largest_k_cc_nii.extract_label(k) - p, i, r = k_nii.compute_crop(dist=3) - i_slice_compare = slice( - max(i_slice.start - 10, 0), i_slice.stop + 10 - ) # more margin in inferior direction (allows for gaps in spine) - if overlap_slice(p_slice, p) and overlap_slice(i_slice_compare, i) and overlap_slice(r_slice, r): - # extend bbox - i_slice = slice(min(i_slice.start, i.start), max(i_slice.stop, i.stop)) - r_slice = slice(min(r_slice.start, r.start), max(r_slice.stop, r.stop)) - incorporated.append(k) - changed = True + p, i, r = k_nii.compute_crop(dist=4) + + for bbox in bboxes: + i_slice_compare = slice( + max(bbox[1].start - 4, 0), bbox[1].stop + 4 + ) # more margin in inferior direction (allows for gaps of size 15 in spine) + if overlap_slice(bbox[0], p) and overlap_slice(i_slice_compare, i) and overlap_slice(bbox[2], r): + # extend bbox + bboxes.append((p, i, r)) + incorporated.append(k) + changed = True + break seg_bin_arr = seg_binary.get_seg_array() crop = (p_slice, i_slice, r_slice) seg_bin_clean_arr = np.zeros(seg_bin_arr.shape) seg_bin_clean_arr[crop] = 1 + # everything below biggest k get always removed + largest_k_arr = seg_bin_largest_k_cc_nii.get_seg_array() + seg_bin_clean_arr[largest_k_arr == 0] = 0 + seg_arr = seg_nii.get_seg_array() # logger.print(seg_nii.volumes()) seg_arr[seg_bin_clean_arr != 1] = 0 diff --git a/spineps/seg_pipeline.py b/spineps/seg_pipeline.py index 17f074d..4b46853 100755 --- a/spineps/seg_pipeline.py +++ b/spineps/seg_pipeline.py @@ -1,5 +1,6 @@ # from utils.predictor import nnUNetPredictor import subprocess +from typing import Any from scipy.ndimage import center_of_mass from TPTBox import NII, Location, No_Logger, Zooms, v_name2idx @@ -8,8 +9,7 @@ from spineps.seg_model import Segmentation_Model -logger = No_Logger() -logger.override_prefix = "SPINEPS" +logger = No_Logger(prefix="SPINEPS") fill_holes_labels = [ Location.Vertebra_Corpus_border.value, @@ -37,6 +37,7 @@ def predict_centroids_from_both( vert_nii_cleaned: NII, seg_nii: NII, models: list[Segmentation_Model], + parameter: dict[str, Any], input_zms_pir: Zooms | None = None, ): """Calculates the centroids of each vertebra corpus by using both semantic and instance mask @@ -70,6 +71,8 @@ def predict_centroids_from_both( ctd.info["models"] = models_repr ctd.info["revision"] = pipeline_revision() ctd.info["timestamp"] = format_time_short(get_time()) + for pname, pvalue in parameter.items(): + ctd.info[pname] = str(pvalue) return ctd diff --git a/spineps/seg_run.py b/spineps/seg_run.py index 808ea9c..1e5bafd 100755 --- a/spineps/seg_run.py +++ b/spineps/seg_run.py @@ -48,15 +48,24 @@ def process_dataset( override_ctd: bool = False, snapshot_copy_folder: Path | None | bool = None, # - do_crop_semantic: bool = True, - # - proc_n4correction: bool = True, - proc_fillholes: bool = True, - proc_clean: bool = True, - proc_corpus_clean: bool = True, - proc_cleanvert: bool = True, + pad_size: int = 4, + # Processings + # Semantic + proc_sem_crop_input: bool = True, + proc_sem_n4_bias_correction: bool = True, + proc_sem_remove_inferior_beyond_canal: bool = False, + proc_sem_clean_beyond_largest_bounding_box: bool = True, + proc_sem_clean_small_cc_artifacts: bool = True, + # Instance + proc_inst_corpus_clean: bool = True, + proc_inst_clean_small_cc_artifacts: bool = True, + proc_inst_largest_k_cc: int = 0, + proc_inst_detect_and_solve_merged_corpi: bool = True, + # Both + proc_fill_3d_holes: bool = True, proc_assign_missing_cc: bool = True, - proc_largest_cc: int = 0, + proc_clean_inst_by_sem: bool = True, + proc_vertebra_inconsistency: bool = True, # ignore_model_compatibility: bool = False, ignore_inference_compatibility: bool = False, @@ -187,17 +196,25 @@ def process_dataset( override_postpair=override_postpair, override_ctd=override_ctd, # - do_crop_semantic=do_crop_semantic, - proc_n4correction=proc_n4correction, - proc_fillholes=proc_fillholes, + proc_pad_size=pad_size, + # + proc_sem_crop_input=proc_sem_crop_input, + proc_sem_n4_bias_correction=proc_sem_n4_bias_correction, + proc_fill_3d_holes=proc_fill_3d_holes, + proc_sem_remove_inferior_beyond_canal=proc_sem_remove_inferior_beyond_canal, + proc_sem_clean_beyond_largest_bounding_box=proc_sem_clean_beyond_largest_bounding_box, # - proc_clean=proc_clean, - proc_corpus_clean=proc_corpus_clean, - proc_cleanvert=proc_cleanvert, + proc_sem_clean_small_cc_artifacts=proc_sem_clean_small_cc_artifacts, + proc_inst_detect_and_solve_merged_corpi=proc_inst_detect_and_solve_merged_corpi, + proc_inst_corpus_clean=proc_inst_corpus_clean, + proc_inst_clean_small_cc_artifacts=proc_inst_clean_small_cc_artifacts, proc_assign_missing_cc=proc_assign_missing_cc, - proc_largest_cc=proc_largest_cc, + proc_inst_largest_k_cc=proc_inst_largest_k_cc, + proc_clean_inst_by_sem=proc_clean_inst_by_sem, + proc_vertebra_inconsistency=proc_vertebra_inconsistency, # snapshot_copy_folder=snapshot_copy_folder, + ignore_bids_filter=ignore_bids_filter, ignore_compatibility_issues=ignore_inference_compatibility, log_inference_time=log_inference_time, verbose=verbose, @@ -248,19 +265,28 @@ def process_img_nii( # noqa: C901 override_postpair: bool = False, override_ctd: bool = False, # - do_crop_semantic: bool = True, - proc_n4correction: bool = True, - proc_fillholes: bool = True, - # - proc_clean: bool = True, - proc_corpus_clean: bool = True, - proc_cleanvert: bool = True, + proc_pad_size: int = 4, + # Processings + # Semantic + proc_sem_crop_input: bool = True, + proc_sem_n4_bias_correction: bool = True, + proc_sem_remove_inferior_beyond_canal: bool = False, + proc_sem_clean_beyond_largest_bounding_box: bool = True, + proc_sem_clean_small_cc_artifacts: bool = True, + # Instance + proc_inst_corpus_clean: bool = True, + proc_inst_clean_small_cc_artifacts: bool = True, + proc_inst_largest_k_cc: int = 0, + proc_inst_detect_and_solve_merged_corpi: bool = True, + # Both + proc_fill_3d_holes: bool = True, proc_assign_missing_cc: bool = True, - proc_largest_cc: int = 0, - process_vertebra_inconsistency: bool = True, + proc_clean_inst_by_sem: bool = True, + proc_vertebra_inconsistency: bool = True, # lambda_semantic: Callable[[NII], NII] | None = None, snapshot_copy_folder: Path | None = None, + ignore_bids_filter: bool = False, ignore_compatibility_issues: bool = False, log_inference_time: bool = True, verbose: bool = False, @@ -305,9 +331,12 @@ def process_img_nii( # noqa: C901 Returns: ErrCode: Error code depicting whether the operation was successful or not """ + arguments = locals() input_format = img_ref.format - output_paths = output_paths_from_input(img_ref, derivative_name, snapshot_copy_folder, input_format=input_format) + output_paths = output_paths_from_input( + img_ref, derivative_name, snapshot_copy_folder, input_format=input_format, non_strict_mode=ignore_bids_filter + ) out_spine = output_paths["out_spine"] out_spine_raw = output_paths["out_spine_raw"] out_vert = output_paths["out_vert"] @@ -358,7 +387,7 @@ def process_img_nii( # noqa: C901 input_nii = img_ref.open_nii() input_package = InputPackage( input_nii, - pad_size=4, + pad_size=proc_pad_size, ) logger.print("Input image", input_nii.zoom, input_nii.orientation, input_nii.shape) @@ -371,21 +400,24 @@ def process_img_nii( # noqa: C901 input_nii, pad_size=input_package.pad_size, debug_data=debug_data_run, - do_crop=do_crop_semantic, - do_n4=proc_n4correction, + proc_crop_input=proc_sem_crop_input, + proc_do_n4_bias_correction=proc_sem_n4_bias_correction, verbose=verbose, ) if errcode != ErrCode.OK: logger.print("Got Error from preprocessing", Log_Type.FAIL) return output_paths, errcode # make subreg mask + assert input_preprocessed is not None seg_nii_modelres, unc_nii, softmax_logits, errcode = predict_semantic_mask( input_preprocessed, model_semantic, debug_data=debug_data_run, verbose=verbose, - fill_holes=proc_fillholes, - clean_artifacts=proc_clean, + proc_fill_3d_holes=proc_fill_3d_holes, + proc_clean_small_cc_artifacts=proc_sem_clean_small_cc_artifacts, + proc_clean_beyond_largest_bounding_box=proc_sem_clean_beyond_largest_bounding_box, + proc_remove_inferior_beyond_canal=proc_sem_remove_inferior_beyond_canal, ) if errcode != ErrCode.OK: return output_paths, errcode @@ -423,10 +455,11 @@ def process_img_nii( # noqa: C901 model_instance, debug_data=debug_data_run, verbose=verbose, - fill_holes=proc_fillholes, - proc_corpus_clean=proc_corpus_clean, - proc_cleanvert=proc_cleanvert, - proc_largest_cc=proc_largest_cc, + proc_inst_fill_3d_holes=proc_fill_3d_holes, + proc_detect_and_solve_merged_corpi=proc_inst_detect_and_solve_merged_corpi, + proc_corpus_clean=proc_inst_corpus_clean, + proc_inst_clean_small_cc_artifacts=proc_inst_clean_small_cc_artifacts, + proc_inst_largest_k_cc=proc_inst_largest_k_cc, ) if errcode != ErrCode.OK: logger.print(f"Vert Mask creation failed with errcode {errcode}", Log_Type.FAIL) @@ -457,8 +490,9 @@ def process_img_nii( # noqa: C901 vert_nii=whole_vert_nii, debug_data=debug_data_run, labeling_offset=1, + proc_clean_inst_by_sem=proc_clean_inst_by_sem, proc_assign_missing_cc=proc_assign_missing_cc, - process_vertebra_inconsistency=process_vertebra_inconsistency, + proc_vertebra_inconsistency=proc_vertebra_inconsistency, verbose=verbose, ) @@ -479,6 +513,7 @@ def process_img_nii( # noqa: C901 vert_nii_clean, seg_nii_clean, models=[model_semantic, model_instance], + parameter={l: v for l, v in arguments.items() if "proc_" in l}, input_zms_pir=input_package.zms_pir, ) ctd.rescale(input_package._zms, verbose=logger).reorient(input_package._orientation).save(out_ctd, verbose=logger) @@ -516,28 +551,60 @@ def process_img_nii( # noqa: C901 return output_paths, ErrCode.OK -def output_paths_from_input(img_ref: BIDS_FILE, derivative_name: str, snapshot_copy_folder: Path | None, input_format: str): - out_spine = img_ref.get_changed_path(format="msk", parent=derivative_name, info={"seg": "spine", "mod": img_ref.format}) - out_vert = img_ref.get_changed_path(format="msk", parent=derivative_name, info={"seg": "vert", "mod": img_ref.format}) - out_snap = img_ref.get_changed_path(format="snp", file_type="png", parent=derivative_name, info={"seg": "spine", "mod": img_ref.format}) - out_ctd = img_ref.get_changed_path(format="ctd", file_type="json", parent=derivative_name, info={"seg": "spine", "mod": img_ref.format}) +def output_paths_from_input( + img_ref: BIDS_FILE, + derivative_name: str, + snapshot_copy_folder: Path | None, + input_format: str, + non_strict_mode: bool = False, +): + out_spine = img_ref.get_changed_path( + bids_format="msk", parent=derivative_name, info={"seg": "spine", "mod": img_ref.format}, non_strict_mode=non_strict_mode + ) + out_vert = img_ref.get_changed_path( + bids_format="msk", parent=derivative_name, info={"seg": "vert", "mod": img_ref.format}, non_strict_mode=non_strict_mode + ) + out_snap = img_ref.get_changed_path( + bids_format="snp", + file_type="png", + parent=derivative_name, + info={"seg": "spine", "mod": img_ref.format}, + non_strict_mode=non_strict_mode, + ) + out_ctd = img_ref.get_changed_path( + bids_format="ctd", + file_type="json", + parent=derivative_name, + info={"seg": "spine", "mod": img_ref.format}, + non_strict_mode=non_strict_mode, + ) out_snap2 = snapshot_copy_folder.joinpath(out_snap.name) if snapshot_copy_folder is not None else out_snap # out_debug = out_vert.parent.joinpath(f"debug_{input_format}") # out_raw = out_vert.parent.joinpath(f"output_raw_{input_format}") # - out_spine_raw = img_ref.get_changed_path(format="msk", parent=derivative_name, info={"seg": "spine-raw", "mod": img_ref.format}) + out_spine_raw = img_ref.get_changed_path( + bids_format="msk", parent=derivative_name, info={"seg": "spine-raw", "mod": img_ref.format}, non_strict_mode=non_strict_mode + ) out_spine_raw = out_raw.joinpath(out_spine_raw.name) # - out_vert_raw = img_ref.get_changed_path(format="msk", parent=derivative_name, info={"seg": "vert-raw", "mod": img_ref.format}) + out_vert_raw = img_ref.get_changed_path( + bids_format="msk", parent=derivative_name, info={"seg": "vert-raw", "mod": img_ref.format}, non_strict_mode=non_strict_mode + ) out_vert_raw = out_raw.joinpath(out_vert_raw.name) # - out_unc = img_ref.get_changed_path(format="uncertainty", parent=derivative_name, info={"seg": "spine", "mod": img_ref.format}) + out_unc = img_ref.get_changed_path( + bids_format="uncertainty", parent=derivative_name, info={"seg": "spine", "mod": img_ref.format}, non_strict_mode=non_strict_mode + ) out_unc = out_raw.joinpath(out_unc.name) # out_logits = img_ref.get_changed_path( - file_type="npz", format="logit", parent=derivative_name, info={"seg": "spine", "mod": img_ref.format} + file_type="npz", + bids_format="logit", + parent=derivative_name, + info={"seg": "spine", "mod": img_ref.format}, + non_strict_mode=non_strict_mode, ) out_logits = out_raw.joinpath(out_logits.name) return { diff --git a/spineps/utils/mincutmaxflow.py b/spineps/utils/mincutmaxflow.py new file mode 100644 index 0000000..4f561c3 --- /dev/null +++ b/spineps/utils/mincutmaxflow.py @@ -0,0 +1,182 @@ +from typing import Optional, Union + +import cc3d +import networkx as nx +import numpy as np +from scipy.ndimage import binary_dilation, binary_erosion, generate_binary_structure +from TPTBox import NII + + +def mincutmaxflow( + vertebra_nii: NII, + separator_ivd: NII, +) -> NII: + connectivity = 6 + vol = vertebra_nii.get_seg_array() + return vertebra_nii.set_array( + split_cc( + vol=vol, + sep=separator_ivd.get_seg_array() if separator_ivd is not None else None, + connectivity=connectivity, + structure=generate_binary_structure(vol.ndim, connectivity), + min_vol=10, + ) + ) + + +def np_mincutmaxflow( + vertebra_arr: np.ndarray, + separator_ivd_arr: np.ndarray | None, +) -> np.ndarray: + connectivity = 6 + return split_cc( + vol=vertebra_arr, + sep=separator_ivd_arr if separator_ivd_arr is not None else None, + connectivity=connectivity, + structure=generate_binary_structure(vertebra_arr.ndim, connectivity), + min_vol=10, + ) + + +def split_cc( # noqa: C901 + vol: np.ndarray, + sep: np.ndarray | None = None, + connectivity: int = 6, + structure: None | np.ndarray | list[np.ndarray] = None, + min_vol: int | None = None, + max_cut: int | None = None, + max_ignore: int | None = 6, + voxel_dim: np.ndarray | None = None, + add_2d_edges: bool = False, +) -> np.ndarray: + """ + + @param vol: volume which only contains values of one connected component + @param sep: if given, vol is will not be eroded, sep will be dilatgit ed. + @param connectivity: 6 (voxel faces), 18 (+edges), or 26 (+corners) + @param structure: 3d numpy array of type true, with which the volume should be eroded + @param min_vol: minimal size of the both eroded + @param max_cut: + @param voxel_dim: weights along the dimensions (x y and z) for cost function (if not set, all are 1.0) + @param add_2d_edges: if True, not only add edges to left/up/depth,.. also to left+up,up+depth,left+depth,.. + @return: + """ + _, m = cc3d.connected_components(vol, connectivity=connectivity, return_N=True) + if m != 1: + raise Exception(f"volume is separable into {m} parts with {connectivity=} - it should be 1.") # noqa: TRY002 + vol_erode = vol + iterations = 0 + # if sep is not None: + # sep = np.invert(sep) + if isinstance(structure, np.ndarray): + structure = [structure] + while True: + # vol_erode_old = vol_erode + structure_acctual = structure[iterations % len(structure)] if structure is not None else None + if sep is not None: + sep = binary_dilation(sep, structure=structure_acctual) + # vol_erode = vol & sep + vol_erode = np.where(sep, 0, vol) + else: + vol_erode = binary_erosion(vol_erode, structure=structure_acctual) + cc_erode, m = cc3d.connected_components(vol_erode, connectivity=connectivity, return_N=True) + iterations += 1 + if m > 1 and max_ignore is not None: + res = np.unique(cc_erode, return_counts=True) + max_errors = sum([s for x, s in zip(*res, strict=True) if x > 0 and s <= max_ignore]) + idxs = [x for x, s in zip(*res, strict=True) if x > 0 and s > max_ignore] + m = len(idxs) + if m > 2: + break # should result in error "erosion with struc.." + if m == 2: + cc_erode_ = np.zeros(cc_erode.shape, dtype=cc_erode.dtype) + cc_erode_[cc_erode == idxs[0]] = 1 + cc_erode_[cc_erode == idxs[1]] = 2 + cc_erode = cc_erode_ + break + # otherwise, contiue! + if m == 0: + raise Exception( # noqa: TRY002 + f"cannot split volume into two parts after {iterations} iterations, all values are 0 after erosion." + ) + if m > 2: + raise Exception( # noqa: TRY002 + f"erosion with struture {structure} leads to {m} separate connected components after {iterations} isterations, expect 2." + ) + + S = cc_erode == 1 # noqa: N806 + T = cc_erode == 2 # noqa: N806 + G_ = vol ^ vol_erode # noqa: N806 + + if min_vol is not None and S.sum() < min_vol: + raise Exception( # noqa: TRY002 + f"after erosion for split, volume of one structure is {S.sum()} which is smaller than the accepted size {min_vol}." + ) + if min_vol is not None and T.sum() < min_vol: + raise Exception( # noqa: TRY002 + f"after erosion for split, volume of one structure is {T.sum()} which is smaller than the accepted size {min_vol}." + ) + if voxel_dim is None: + voxel_dim = np.ones([3]) + capacity_end = 1000 + else: + # this is max(x*y,y*z,x*z)*1000 + capacity_end = np.prod(voxel_dim) / np.min(voxel_dim) * 1000 + + S_dil = binary_dilation(S, structure=structure_acctual) # noqa: N806 + T_dil = binary_dilation(T, structure=structure_acctual) # noqa: N806 + to_S = np.argwhere(S_dil & G_) # noqa: N806 + to_T = np.argwhere(T_dil & G_) # noqa: N806 + if len(to_T) == 0 or len(to_S) == 0: + raise Exception("no connection between separated objects and remaining vertices found") # noqa: TRY002 + + G = nx.Graph() # noqa: N806 + import itertools + + for x, y, z in itertools.product([0, 1], repeat=3): + vec = np.array([x, y, z]) + xe, ye, ze = np.array(G_.shape) - vec + + def add_edges(points, diff1, diff2, cap): + """Add edge from point + diff1 to point + diff2 for each point in points. Must be tuples because + numpy arrays are note hashable, and nodes in the graph have to be hashable.""" + G.add_edges_from([*zip(map(tuple, points + diff1), map(tuple, points + diff2), strict=False)], capacity=cap) + + if x + y + z == 1: + # calculate the product of the both dimension, which are + # for x=1, y=z=0 it is y*z + capacity = np.prod(voxel_dim[vec == 0]) + add_edges(np.argwhere(G_[x:, y:, z:] & G_[:xe, :ye, :ze]), [0, 0, 0], [x, y, z], capacity) + if x + y + z == 2 and add_2d_edges: + # calculate a plane diagonal in two dimension and direct into the other dimension + # for x=y=1, z=0 it is sqrt(x^2+y^2)*z + capacity = voxel_dim[vec == 0] * np.linalg.norm(voxel_dim[vec == 1]) + add_edges(np.argwhere(G_[x:, y:, z:] & G_[:xe, :ye, :ze]), [0, 0, 0], [x, y, z], capacity) + if x == 1: + add_edges(np.argwhere(G_[:xe, y:, z:] & G_[x:, :ye, :ze]), [x, 0, 0], [0, y, z], capacity) + else: + add_edges(np.argwhere(G_[x:, :ye, z:] & G_[:xe, y:, :ze]), [0, y, 0], [x, 0, z], capacity) + + G.add_edges_from([((x, y, z), "t") for x, y, z in to_T], capacity=capacity_end) + G.add_edges_from([("s", (x, y, z)) for x, y, z in to_S], capacity=capacity_end) + + if not nx.has_path(G, "s", "t"): + raise Exception("no path exists from s to t") # noqa: TRY002 + + cut_value, (s_idx, t_idx) = nx.minimum_cut(G, "s", "t") + if max_cut is not None and cut_value > max_cut: + raise Exception(f"cut size is {cut_value} whereas maximal cut of {max_cut} is allowed") # noqa: TRY002 + # print(f"{cut_value=}") + + s_idx.remove("s") + t_idx.remove("t") + if len(s_idx) == 0 or len(t_idx) == 0: + raise Exception("vertices of one side are empty - this should not happen") # noqa: TRY002 + cc_erode[tuple(np.asarray(list(s_idx)).reshape([-1, 3]).transpose())] = 1 + cc_erode[tuple(np.asarray(list(t_idx)).reshape([-1, 3]).transpose())] = 2 + + lost = np.abs((cc_erode > 0).sum() - vol.sum()) + if lost > max_errors: + raise Exception(f"lost {lost} points while separating but only {max_errors} losts allowed") # noqa: TRY002 + return cc_erode + # print(cc_erode[(a + b) // 2]) diff --git a/unit_tests/test_semantic.py b/unit_tests/test_semantic.py index f64bb42..c4a20ce 100644 --- a/unit_tests/test_semantic.py +++ b/unit_tests/test_semantic.py @@ -79,7 +79,9 @@ def test_segment_scan(self): model = Segmentation_Model_Dummy() model.run = MagicMock(return_value={OutputType.seg: subreg, OutputType.softmax_logits: None}) debug_data = {} - seg_nii, unc_nii, softmax_logits, errcode = predict_semantic_mask(mri, model, debug_data=debug_data, verbose=True) + seg_nii, unc_nii, softmax_logits, errcode = predict_semantic_mask( + mri, model, debug_data=debug_data, verbose=True, proc_clean_small_cc_artifacts=False + ) predicted_volumes = seg_nii.volumes() ref_volumes = subreg.volumes() print(predicted_volumes) From a773398ea166c39cdfa3bc83aaf585cc58b4ef76 Mon Sep 17 00:00:00 2001 From: iback Date: Wed, 5 Jun 2024 06:59:06 +0000 Subject: [PATCH 2/5] added some arguments --- spineps/utils/mincutmaxflow.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/spineps/utils/mincutmaxflow.py b/spineps/utils/mincutmaxflow.py index 4f561c3..51f11b7 100644 --- a/spineps/utils/mincutmaxflow.py +++ b/spineps/utils/mincutmaxflow.py @@ -10,8 +10,10 @@ def mincutmaxflow( vertebra_nii: NII, separator_ivd: NII, + connectivity: int = 1, ) -> NII: - connectivity = 6 + assert 1 <= connectivity <= 3, f"expected connectivity in [1,3], but got {connectivity}" + connectivity = min(connectivity * 2, 8) if vertebra_nii.ndim == 2 else 6 if connectivity == 1 else 18 if connectivity == 2 else 26 vol = vertebra_nii.get_seg_array() return vertebra_nii.set_array( split_cc( @@ -20,6 +22,7 @@ def mincutmaxflow( connectivity=connectivity, structure=generate_binary_structure(vol.ndim, connectivity), min_vol=10, + voxel_dim=np.asarray(vertebra_nii.zoom), ) ) @@ -27,14 +30,18 @@ def mincutmaxflow( def np_mincutmaxflow( vertebra_arr: np.ndarray, separator_ivd_arr: np.ndarray | None, + connectivity: int = 1, + zoom: np.ndarray | None = None, ) -> np.ndarray: - connectivity = 6 + assert 1 <= connectivity <= 3, f"expected connectivity in [1,3], but got {connectivity}" + connectivity = min(connectivity * 2, 8) if vertebra_arr.ndim == 2 else 6 if connectivity == 1 else 18 if connectivity == 2 else 26 return split_cc( vol=vertebra_arr, sep=separator_ivd_arr if separator_ivd_arr is not None else None, connectivity=connectivity, structure=generate_binary_structure(vertebra_arr.ndim, connectivity), min_vol=10, + voxel_dim=np.asarray(zoom) if zoom is not None else None, ) @@ -47,12 +54,12 @@ def split_cc( # noqa: C901 max_cut: int | None = None, max_ignore: int | None = 6, voxel_dim: np.ndarray | None = None, - add_2d_edges: bool = False, + add_2d_edges: bool = True, ) -> np.ndarray: """ @param vol: volume which only contains values of one connected component - @param sep: if given, vol is will not be eroded, sep will be dilatgit ed. + @param sep: if given, vol is will not be eroded, sep will be dilated. @param connectivity: 6 (voxel faces), 18 (+edges), or 26 (+corners) @param structure: 3d numpy array of type true, with which the volume should be eroded @param min_vol: minimal size of the both eroded @@ -138,6 +145,7 @@ def split_cc( # noqa: C901 xe, ye, ze = np.array(G_.shape) - vec def add_edges(points, diff1, diff2, cap): + # print(f"Add edge {points}, {diff1}, {diff2}, {cap}") """Add edge from point + diff1 to point + diff2 for each point in points. Must be tuples because numpy arrays are note hashable, and nodes in the graph have to be hashable.""" G.add_edges_from([*zip(map(tuple, points + diff1), map(tuple, points + diff2), strict=False)], capacity=cap) @@ -178,5 +186,5 @@ def add_edges(points, diff1, diff2, cap): lost = np.abs((cc_erode > 0).sum() - vol.sum()) if lost > max_errors: raise Exception(f"lost {lost} points while separating but only {max_errors} losts allowed") # noqa: TRY002 - return cc_erode + return cc_erode, (S, T, G_, S_dil, T_dil, G) # print(cc_erode[(a + b) // 2]) From fa53744e0d8367f3b7ee9895cfd7ab20d7ba1815 Mon Sep 17 00:00:00 2001 From: iback Date: Wed, 5 Jun 2024 07:00:58 +0000 Subject: [PATCH 3/5] added fix_wrong_posterior_instance_label --- spineps/utils/proc_functions.py | 58 +++++++++++++++++++++++++++++++-- 1 file changed, 56 insertions(+), 2 deletions(-) diff --git a/spineps/utils/proc_functions.py b/spineps/utils/proc_functions.py index a12d404..11815dc 100755 --- a/spineps/utils/proc_functions.py +++ b/spineps/utils/proc_functions.py @@ -2,8 +2,8 @@ import numpy as np from ants.utils.convert_nibabel import from_nibabel from scipy.ndimage import center_of_mass -from TPTBox import NII, Logger_Interface -from TPTBox.core.np_utils import np_bbox_binary, np_count_nonzero, np_dilate_msk, np_unique, np_volume +from TPTBox import NII, Location, Logger_Interface +from TPTBox.core.np_utils import np_bbox_binary, np_count_nonzero, np_dilate_msk, np_unique, np_unique_withoutzero, np_volume from tqdm import tqdm @@ -181,3 +181,57 @@ def connected_components_3d(mask_image: np.ndarray, connectivity: int = 3, verbo if (n) != 1: # zero is a label print(f"subreg {subreg} does not have one CC (not counting zeros), got {n}") if verbose else None return subreg_cc, subreg_cc_stats + + +def fix_wrong_posterior_instance_label(seg_sem: NII, seg_inst: NII, logger) -> NII: + seg_sem = seg_sem.copy() + seg_inst = seg_inst.copy() + orientation = seg_sem.orientation + seg_sem.assert_affine(other=seg_inst) + seg_sem.reorient_() + seg_inst.reorient_() + + seg_inst_arr_proc = seg_inst.get_seg_array() + + instance_labels = [i for i in seg_inst.unique() if 1 <= i <= 25] + + for vert in instance_labels: + inst_vert = seg_inst.extract_label(vert) + # sem_vert = seg_sem.apply_mask(inst_vert) + + # Check if multiple CC exist + inst_vert_cc = inst_vert.get_largest_k_segmentation_connected_components(3, return_original_labels=False) + inst_vert_cc_n = int(inst_vert_cc.max()) + # + if inst_vert_cc_n == 1: + continue + # + # inst_vert_cc is labeled 1 to 3 + for i in range(2, inst_vert_cc_n + 1): + inst_vert_cc_i = inst_vert_cc.extract_label(i) + + crop = inst_vert_cc_i.compute_crop(dist=1) + inst_vert_cc_i_c = inst_vert_cc_i.apply_crop(crop) + + cc_sem_vert = seg_sem.apply_crop(crop).apply_mask(inst_vert_cc_i_c) + # cc_vert is semantic mask of only that cc of instance + + cc_sem_vert_labels = cc_sem_vert.unique() + # is that cc only arcus and spinosus? + if len(cc_sem_vert_labels) <= 2 and np.all( + [i in [Location.Arcus_Vertebrae.value, Location.Spinosus_Process.value] for i in cc_sem_vert_labels] + ): + # neighbor that have non arcus/spinosus label? + neighbor_instance_labels = seg_inst.apply_crop(crop).get_seg_array() + neighbor_instance_labels[inst_vert_cc_i_c.get_seg_array() == 1] = 0 + neighbor_instance_labels = np_unique_withoutzero(neighbor_instance_labels) + # which instance labels does it touch + logger.print(f"vert {vert}, cc_k {i} has instance neighbors {neighbor_instance_labels}") + # is it touching only one other instance label? + if len(neighbor_instance_labels) == 1 and neighbor_instance_labels[0] != vert: + to_label = neighbor_instance_labels[0] + logger.print(f"vert {vert}, cc_k {i} relabel to instance {to_label}") + seg_inst_arr_proc[inst_vert_cc_i.get_seg_array() == 1] = to_label + + seg_inst_proc = seg_inst.set_array(seg_inst_arr_proc).reorient_(orientation) + return seg_inst_proc From 7a6a7e1f99db44b104c50259d7a664748fa80c7f Mon Sep 17 00:00:00 2001 From: iback Date: Wed, 5 Jun 2024 07:01:50 +0000 Subject: [PATCH 4/5] added fix_wrong_posterior_instance_label to phase_post --- spineps/phase_post.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/spineps/phase_post.py b/spineps/phase_post.py index 8c41a9e..94f3a43 100644 --- a/spineps/phase_post.py +++ b/spineps/phase_post.py @@ -18,6 +18,7 @@ ) from spineps.seg_pipeline import logger, vertebra_subreg_labels +from spineps.utils.proc_functions import fix_wrong_posterior_instance_label def phase_postprocess_combined( @@ -30,6 +31,7 @@ def phase_postprocess_combined( n_vert_bodies: int | None = None, process_merge_vertebra: bool = True, proc_vertebra_inconsistency: bool = True, + proc_assign_posterior_instance_label: bool = True, verbose: bool = False, ) -> tuple[NII, NII]: logger.print("Post process", Log_Type.STAGE) @@ -67,6 +69,9 @@ def phase_postprocess_combined( if process_merge_vertebra and Location.Vertebra_Disc.value in seg_nii_cleaned.unique(): detect_and_solve_merged_vertebra(seg_nii_cleaned, whole_vert_nii_cleaned) + if proc_assign_posterior_instance_label: + whole_vert_nii_cleaned = fix_wrong_posterior_instance_label(seg_nii_cleaned, seg_inst=whole_vert_nii_cleaned, logger=logger) + if proc_vertebra_inconsistency: # Assigns superior/inferior based on instance label overlap assign_vertebra_inconsistency(seg_nii_cleaned, whole_vert_nii_cleaned) From e4926082d9449e8512a04be057ef12c948fa5b70 Mon Sep 17 00:00:00 2001 From: iback Date: Wed, 5 Jun 2024 07:08:25 +0000 Subject: [PATCH 5/5] replaced mincutmaxflow with my own plane splitting algorithm, seems to work decent enough for corpus COM detection --- spineps/phase_instance.py | 177 ++++++++++++++++++++++++++++-- spineps/utils/mincutmaxflow.py | 190 --------------------------------- 2 files changed, 171 insertions(+), 196 deletions(-) delete mode 100644 spineps/utils/mincutmaxflow.py diff --git a/spineps/phase_instance.py b/spineps/phase_instance.py index c136bcc..3c41e30 100755 --- a/spineps/phase_instance.py +++ b/spineps/phase_instance.py @@ -4,9 +4,13 @@ from TPTBox.core.np_utils import ( np_calc_crop_around_centerpoint, np_center_of_mass, + np_connected_components, np_count_nonzero, np_dice, + np_dilate_msk, + np_erode_msk, np_extract_label, + np_get_largest_k_connected_components, np_unique, np_volume, ) @@ -15,7 +19,6 @@ from spineps.seg_enums import ErrCode, OutputType from spineps.seg_model import Segmentation_Model from spineps.seg_pipeline import logger -from spineps.utils.mincutmaxflow import np_mincutmaxflow from spineps.utils.proc_functions import clean_cc_artifacts @@ -201,14 +204,14 @@ def get_corpus_coms( return corpus_coms ############ - # Detect merged vertebra and use mincutmaxflow algo on it + # Detect merged vertebra and use plane split algo on it ############ # Get corpus CCs corpus_cc, corpus_cc_n = corpus_nii.get_segmentation_connected_components(labels=1) corpus_cc = corpus_cc[1] corpus_cc_n = corpus_cc_n[1] - logger.print(f"Found {corpus_cc_n} Corpus ccs", verbose=verbose) + logger.print(f"Found {corpus_cc_n} Corpus ccs (naively)", verbose=verbose) # Check against ivd order seg_sem = seg_nii.map_labels({Location.Endplate.value: Location.Vertebra_Disc.value}, verbose=False) @@ -216,7 +219,7 @@ def get_corpus_coms( subreg_cc = subreg_cc[Location.Vertebra_Disc.value] subreg_cc[subreg_cc > 0] += 100 subreg_cc_n = subreg_cc_n[Location.Vertebra_Disc.value] - logger.print(f"Found {subreg_cc_n} IVD ccs", verbose=verbose) + logger.print(f"Found {subreg_cc_n} IVD ccs (naively)", verbose=verbose) coms = np_center_of_mass(subreg_cc) volumes = np_volume(subreg_cc) stats = {i: (g[1], True, volumes[i]) for i, g in coms.items()} @@ -290,17 +293,179 @@ def get_corpus_coms( target_vert_id = list(neighbor_verts.keys())[argmax] segvert = np_extract_label(corpus_cc, target_vert_id, inplace=False) try: - split_vert = np_mincutmaxflow(segvert, None) + logger.print("get_separating_components to split vertebra", verbose=verbose) + (spart, tpart, spart_dil, tpart_dil, stpart) = get_separating_components( + segvert, + connectivity=3, + ) + + logger.print("Splitting by plane") + plane_split_nii = get_plane_split(segvert, corpus_nii, spart, tpart, spart_dil, tpart_dil) + split_vert = split_by_plane(segvert, plane_split_nii) corpus_cc[split_vert == 2] = corpus_cc.max() + 1 + + seg_nii.set_array(corpus_cc).save( + "/DATA/NAS/ongoing_projects/hendrik/mri_usage/nako_fixmerge_test/derivatives_seg/mincutmaxflow_array.nii.gz" + ) except Exception as e: - logger.print(f"MinCutMaxFlow failed with exception {e}") + logger.print(f"Separating Corpi failed with exception {e}") corpus_coms = list(np_center_of_mass(corpus_cc).values()) corpus_coms.sort(key=lambda a: a[1]) corpus_coms.reverse() # from bottom to top + logger.print(f"Found {len(corpus_coms)} final Corpus ccs", verbose=verbose) return corpus_coms +def get_separating_components( + segvert: np.ndarray, + max_iter: int = 10, + connectivity: int = 3, +): + check_connectivtiy = 3 + vol = segvert.copy() + vol_old = vol.copy() + iterations = 0 + while True: + vol_erode = np_erode_msk(vol, mm=1, connectivity=connectivity) + subreg_cc, subreg_cc_n = np_connected_components(vol_erode, connectivity=check_connectivtiy) + if 1 in subreg_cc_n and subreg_cc_n[1] > 1: + vol = subreg_cc[1] + break + elif 1 not in subreg_cc_n: + vol_dilated = np_dilate_msk(vol, mm=1, connectivity=connectivity, mask=vol.copy()) + # use iteration before to get other CC + vol[vol_old != 0] = 2 # all possible voxels are 2 + vol[vol_dilated == 1] = 1 + + if 2 not in np_volume(vol): + raise Exception( # noqa: TRY002 + f"cannot split volume into two parts after {iterations} iterations, all values are 0 after erosion." + ) + volume = np_volume(vol) + dil_iter = 0 + while volume[1] / (volume[1] + volume[2]) < 0.5: + vol_dilated = np_dilate_msk(vol_dilated, mm=1, connectivity=connectivity, mask=vol.copy()) + # inst_nii.set_array(vol_dilated).save(files_out + f"subreg_cc_vol_dilated{dil_iter}.nii.gz") + vol[vol_dilated == 1] = 1 + # inst_nii.set_array(vol).save(files_out + f"subreg_cc_dilation{dil_iter}.nii.gz") + volume = np_volume(vol) + if 1 not in volume or 2 not in volume: + raise Exception("Could not divide into two instance") # noqa: TRY002 + dil_iter += 1 + + vol_1 = np_get_largest_k_connected_components(vol == 1, k=1, connectivity=check_connectivtiy).astype(np.uint8) + vol_2 = np_get_largest_k_connected_components(vol == 2, k=1, connectivity=check_connectivtiy).astype(np.uint8) + vol_2 *= 2 + vol[vol == 1] = vol_1[vol == 1] + vol[vol == 2] = vol_2[vol == 2] + break + vol_old = vol + vol = vol_erode + iterations += 1 + if iterations > max_iter: + raise Exception(f"Could not divide into two instance after max iterations {max_iter}") # noqa: TRY002 + if len(np_volume(vol)) != 2: + logger.print("Get largest two components") + subreg_cc_2k = np_get_largest_k_connected_components(vol, k=2, connectivity=check_connectivtiy, return_original_labels=False) + spart = subreg_cc_2k == 1 + tpart = subreg_cc_2k == 2 + else: + spart = vol == 1 + tpart = vol == 2 + + if spart.sum() == 0 or tpart.sum() == 0: + raise Exception("S or T are empty") # noqa: TRY002 + + spart_dil = np_dilate_msk(spart, mm=1, connectivity=connectivity) + tpart_dil = np_dilate_msk(tpart, mm=1, connectivity=connectivity) + stpart = (spart_dil + (tpart_dil * 2)).astype(np.uint8) + while 3 not in np_volume(stpart): + spart_dil = np_dilate_msk(spart_dil, mm=1, connectivity=connectivity) + stpart = (spart_dil + (tpart_dil * 2)).astype(np.uint8) + if 3 in np_volume(stpart): + break + tpart_dil = np_dilate_msk(tpart_dil, mm=1, connectivity=connectivity) + stpart = (spart_dil + (tpart_dil * 2)).astype(np.uint8) + # + stpart = spart_dil + (tpart_dil * 2) + return spart, tpart, spart_dil, tpart_dil, stpart + + +def get_plane_split( + segvert: np.ndarray, + compare_nii: NII, + spart: np.ndarray, + tpart: np.ndarray, + spart_dil: np.ndarray, + tpart_dil: np.ndarray, +): + s_dilint = spart_dil.astype(np.uint8) + t_dilint = tpart_dil.astype(np.uint8) + # + collision_arr = s_dilint + t_dilint + if 2 not in collision_arr: + logger.print("S and T dilated do not touch each other error") + return compare_nii.set_array(np.zeros_like(segvert)) + collision_arr[collision_arr != 2] = 0 + collision_arr[collision_arr == 2] = 1 + collision_point = np_center_of_mass(collision_arr)[1] + # TODO instead of COM, calc COM of S and T, make vector and use merge point along that vector as collision point? should be more accurate + # + normal_vector = np_center_of_mass(spart.astype(np.uint8))[1] - np_center_of_mass(tpart.astype(np.uint8))[1] + normalized_normal = normal_vector / np.linalg.norm(normal_vector) + # + axis = np.argmax(np.abs(normalized_normal)) + dims = [0, 1, 2] + dims.remove(axis) + dim1, dim2 = dims + # + shift_total = -collision_point.dot(normal_vector) + xx, yy = np.meshgrid(range(collision_arr.shape[dim1]), range(collision_arr.shape[dim2])) + zz = (-normal_vector[dim1] * xx - normal_vector[dim2] * yy - shift_total) * 1.0 / normal_vector[axis] + z_max = collision_arr.shape[axis] - 1 + zz[zz < 0] = 0 + zz[zz > z_max] = z_max + # make cords to array again + plane_coords = np.zeros([xx.shape[0], xx.shape[1], 3], dtype=int) + plane_coords[:, :, axis] = zz + plane_coords[:, :, dim1] = xx + plane_coords[:, :, dim2] = yy + + plane = segvert * 0 + plane[plane_coords[:, :, 0], plane_coords[:, :, 1], plane_coords[:, :, 2]] = 1 + plane_nii = compare_nii.set_array(plane) + + plane_filled_nii = plane_nii.copy() + orientation = plane_filled_nii.orientation + plane_filled_nii.reorient_(("S", "A", "R")) + plane_filled_arr = plane_filled_nii.get_array() + x_slice = np.ones_like(plane_filled_arr[0]) * np.max(plane_filled_arr) + 1 + # plane_filled[:, 0, :] = 2 + # plane_filled[:, -1, :] = 1 + for i in range(plane_filled_arr.shape[0]): + curr_slice = plane_filled_arr[i] + cond = np.where(curr_slice != 0) + x_slice[cond] = np.minimum(curr_slice[cond], x_slice[cond]) + plane_filled_arr[i] = x_slice + + plane_filled_nii.set_array_(plane_filled_arr).reorient_(orientation) + return plane_filled_nii + + +def split_by_plane( + segvert: np.ndarray, + plane_filled_nii: NII, +): + plane_filled_arr = plane_filled_nii.get_array() + # 1 above, 2 below + plane_filled_arr[segvert == 0] = 0 + segvert = plane_filled_arr + # segvert[plane_filled_arr == 1] = 1 + # segvert[plane_filled_arr == 2] = 2 + return segvert + + def collect_vertebra_predictions( seg_nii: NII, model: Segmentation_Model, diff --git a/spineps/utils/mincutmaxflow.py b/spineps/utils/mincutmaxflow.py deleted file mode 100644 index 51f11b7..0000000 --- a/spineps/utils/mincutmaxflow.py +++ /dev/null @@ -1,190 +0,0 @@ -from typing import Optional, Union - -import cc3d -import networkx as nx -import numpy as np -from scipy.ndimage import binary_dilation, binary_erosion, generate_binary_structure -from TPTBox import NII - - -def mincutmaxflow( - vertebra_nii: NII, - separator_ivd: NII, - connectivity: int = 1, -) -> NII: - assert 1 <= connectivity <= 3, f"expected connectivity in [1,3], but got {connectivity}" - connectivity = min(connectivity * 2, 8) if vertebra_nii.ndim == 2 else 6 if connectivity == 1 else 18 if connectivity == 2 else 26 - vol = vertebra_nii.get_seg_array() - return vertebra_nii.set_array( - split_cc( - vol=vol, - sep=separator_ivd.get_seg_array() if separator_ivd is not None else None, - connectivity=connectivity, - structure=generate_binary_structure(vol.ndim, connectivity), - min_vol=10, - voxel_dim=np.asarray(vertebra_nii.zoom), - ) - ) - - -def np_mincutmaxflow( - vertebra_arr: np.ndarray, - separator_ivd_arr: np.ndarray | None, - connectivity: int = 1, - zoom: np.ndarray | None = None, -) -> np.ndarray: - assert 1 <= connectivity <= 3, f"expected connectivity in [1,3], but got {connectivity}" - connectivity = min(connectivity * 2, 8) if vertebra_arr.ndim == 2 else 6 if connectivity == 1 else 18 if connectivity == 2 else 26 - return split_cc( - vol=vertebra_arr, - sep=separator_ivd_arr if separator_ivd_arr is not None else None, - connectivity=connectivity, - structure=generate_binary_structure(vertebra_arr.ndim, connectivity), - min_vol=10, - voxel_dim=np.asarray(zoom) if zoom is not None else None, - ) - - -def split_cc( # noqa: C901 - vol: np.ndarray, - sep: np.ndarray | None = None, - connectivity: int = 6, - structure: None | np.ndarray | list[np.ndarray] = None, - min_vol: int | None = None, - max_cut: int | None = None, - max_ignore: int | None = 6, - voxel_dim: np.ndarray | None = None, - add_2d_edges: bool = True, -) -> np.ndarray: - """ - - @param vol: volume which only contains values of one connected component - @param sep: if given, vol is will not be eroded, sep will be dilated. - @param connectivity: 6 (voxel faces), 18 (+edges), or 26 (+corners) - @param structure: 3d numpy array of type true, with which the volume should be eroded - @param min_vol: minimal size of the both eroded - @param max_cut: - @param voxel_dim: weights along the dimensions (x y and z) for cost function (if not set, all are 1.0) - @param add_2d_edges: if True, not only add edges to left/up/depth,.. also to left+up,up+depth,left+depth,.. - @return: - """ - _, m = cc3d.connected_components(vol, connectivity=connectivity, return_N=True) - if m != 1: - raise Exception(f"volume is separable into {m} parts with {connectivity=} - it should be 1.") # noqa: TRY002 - vol_erode = vol - iterations = 0 - # if sep is not None: - # sep = np.invert(sep) - if isinstance(structure, np.ndarray): - structure = [structure] - while True: - # vol_erode_old = vol_erode - structure_acctual = structure[iterations % len(structure)] if structure is not None else None - if sep is not None: - sep = binary_dilation(sep, structure=structure_acctual) - # vol_erode = vol & sep - vol_erode = np.where(sep, 0, vol) - else: - vol_erode = binary_erosion(vol_erode, structure=structure_acctual) - cc_erode, m = cc3d.connected_components(vol_erode, connectivity=connectivity, return_N=True) - iterations += 1 - if m > 1 and max_ignore is not None: - res = np.unique(cc_erode, return_counts=True) - max_errors = sum([s for x, s in zip(*res, strict=True) if x > 0 and s <= max_ignore]) - idxs = [x for x, s in zip(*res, strict=True) if x > 0 and s > max_ignore] - m = len(idxs) - if m > 2: - break # should result in error "erosion with struc.." - if m == 2: - cc_erode_ = np.zeros(cc_erode.shape, dtype=cc_erode.dtype) - cc_erode_[cc_erode == idxs[0]] = 1 - cc_erode_[cc_erode == idxs[1]] = 2 - cc_erode = cc_erode_ - break - # otherwise, contiue! - if m == 0: - raise Exception( # noqa: TRY002 - f"cannot split volume into two parts after {iterations} iterations, all values are 0 after erosion." - ) - if m > 2: - raise Exception( # noqa: TRY002 - f"erosion with struture {structure} leads to {m} separate connected components after {iterations} isterations, expect 2." - ) - - S = cc_erode == 1 # noqa: N806 - T = cc_erode == 2 # noqa: N806 - G_ = vol ^ vol_erode # noqa: N806 - - if min_vol is not None and S.sum() < min_vol: - raise Exception( # noqa: TRY002 - f"after erosion for split, volume of one structure is {S.sum()} which is smaller than the accepted size {min_vol}." - ) - if min_vol is not None and T.sum() < min_vol: - raise Exception( # noqa: TRY002 - f"after erosion for split, volume of one structure is {T.sum()} which is smaller than the accepted size {min_vol}." - ) - if voxel_dim is None: - voxel_dim = np.ones([3]) - capacity_end = 1000 - else: - # this is max(x*y,y*z,x*z)*1000 - capacity_end = np.prod(voxel_dim) / np.min(voxel_dim) * 1000 - - S_dil = binary_dilation(S, structure=structure_acctual) # noqa: N806 - T_dil = binary_dilation(T, structure=structure_acctual) # noqa: N806 - to_S = np.argwhere(S_dil & G_) # noqa: N806 - to_T = np.argwhere(T_dil & G_) # noqa: N806 - if len(to_T) == 0 or len(to_S) == 0: - raise Exception("no connection between separated objects and remaining vertices found") # noqa: TRY002 - - G = nx.Graph() # noqa: N806 - import itertools - - for x, y, z in itertools.product([0, 1], repeat=3): - vec = np.array([x, y, z]) - xe, ye, ze = np.array(G_.shape) - vec - - def add_edges(points, diff1, diff2, cap): - # print(f"Add edge {points}, {diff1}, {diff2}, {cap}") - """Add edge from point + diff1 to point + diff2 for each point in points. Must be tuples because - numpy arrays are note hashable, and nodes in the graph have to be hashable.""" - G.add_edges_from([*zip(map(tuple, points + diff1), map(tuple, points + diff2), strict=False)], capacity=cap) - - if x + y + z == 1: - # calculate the product of the both dimension, which are - # for x=1, y=z=0 it is y*z - capacity = np.prod(voxel_dim[vec == 0]) - add_edges(np.argwhere(G_[x:, y:, z:] & G_[:xe, :ye, :ze]), [0, 0, 0], [x, y, z], capacity) - if x + y + z == 2 and add_2d_edges: - # calculate a plane diagonal in two dimension and direct into the other dimension - # for x=y=1, z=0 it is sqrt(x^2+y^2)*z - capacity = voxel_dim[vec == 0] * np.linalg.norm(voxel_dim[vec == 1]) - add_edges(np.argwhere(G_[x:, y:, z:] & G_[:xe, :ye, :ze]), [0, 0, 0], [x, y, z], capacity) - if x == 1: - add_edges(np.argwhere(G_[:xe, y:, z:] & G_[x:, :ye, :ze]), [x, 0, 0], [0, y, z], capacity) - else: - add_edges(np.argwhere(G_[x:, :ye, z:] & G_[:xe, y:, :ze]), [0, y, 0], [x, 0, z], capacity) - - G.add_edges_from([((x, y, z), "t") for x, y, z in to_T], capacity=capacity_end) - G.add_edges_from([("s", (x, y, z)) for x, y, z in to_S], capacity=capacity_end) - - if not nx.has_path(G, "s", "t"): - raise Exception("no path exists from s to t") # noqa: TRY002 - - cut_value, (s_idx, t_idx) = nx.minimum_cut(G, "s", "t") - if max_cut is not None and cut_value > max_cut: - raise Exception(f"cut size is {cut_value} whereas maximal cut of {max_cut} is allowed") # noqa: TRY002 - # print(f"{cut_value=}") - - s_idx.remove("s") - t_idx.remove("t") - if len(s_idx) == 0 or len(t_idx) == 0: - raise Exception("vertices of one side are empty - this should not happen") # noqa: TRY002 - cc_erode[tuple(np.asarray(list(s_idx)).reshape([-1, 3]).transpose())] = 1 - cc_erode[tuple(np.asarray(list(t_idx)).reshape([-1, 3]).transpose())] = 2 - - lost = np.abs((cc_erode > 0).sum() - vol.sum()) - if lost > max_errors: - raise Exception(f"lost {lost} points while separating but only {max_errors} losts allowed") # noqa: TRY002 - return cc_erode, (S, T, G_, S_dil, T_dil, G) - # print(cc_erode[(a + b) // 2])