diff --git a/spineps/phase_instance.py b/spineps/phase_instance.py index 316c3d9..7b9ebbc 100755 --- a/spineps/phase_instance.py +++ b/spineps/phase_instance.py @@ -15,7 +15,7 @@ def predict_instance_mask( seg_nii: NII, model: Segmentation_Model, debug_data: dict, - resample_output_to_input_space: bool = True, + pad_size: int = 2, fill_holes: bool = True, use_height_estimate: bool = False, proc_corpus_clean: bool = True, @@ -47,6 +47,16 @@ def predict_instance_mask( seg_nii_rdy = seg_nii.reorient(verbose=logger) debug_data["inst_uncropped_Subreg_nii_a_PIR"] = seg_nii_rdy.copy() + + # Padding? + if pad_size > 0: + # logger.print(seg_nii_rdy.shape) + arr = seg_nii_rdy.get_array() + arr = np.pad(arr, pad_size, mode="edge") + seg_nii_rdy.set_array_(arr) + # logger.print(seg_nii_rdy.shape) + # + zms = seg_nii_rdy.zoom logger.print("zms", zms, verbose=verbose) expected_zms = model.calc_recommended_resampling_zoom(seg_nii_rdy.zoom) @@ -61,7 +71,7 @@ def predict_instance_mask( logger.print("Vertebra uncropped_vert_mask empty", uncropped_vert_mask.shape, verbose=verbose) # crop = seg_nii_rdy.compute_crop_slice(dist=5) - logger.print("Crop", crop, verbose=verbose) + # logger.print("Crop", crop, verbose=verbose) seg_nii_rdy.apply_crop_slice_(crop) logger.print(f"Crop down from {uncropped_vert_mask.shape} to {seg_nii_rdy.shape}", verbose=verbose) # arr[crop] = X, then set nifty to arr @@ -134,12 +144,18 @@ def predict_instance_mask( whole_vert_nii_uncropped = seg_nii_uncropped.set_array(uncropped_vert_mask) debug_data["inst_uncropped_vert_arr_a"] = whole_vert_nii_uncropped.copy() - if resample_output_to_input_space: - whole_vert_nii_uncropped.rescale_(zms, verbose=verbose) - debug_data["inst_uncropped_vert_arr_b_rescale"] = whole_vert_nii_uncropped.copy() - whole_vert_nii_uncropped.reorient_(orientation, verbose=verbose) - debug_data["inst_uncropped_vert_arr_c_reorient"] = whole_vert_nii_uncropped.copy() - whole_vert_nii_uncropped.pad_to(shp, inplace=True) + # Resample back to input space + whole_vert_nii_uncropped.rescale_(zms, verbose=verbose) + debug_data["inst_uncropped_vert_arr_b_rescale"] = whole_vert_nii_uncropped.copy() + whole_vert_nii_uncropped.reorient_(orientation, verbose=verbose) + debug_data["inst_uncropped_vert_arr_c_reorient"] = whole_vert_nii_uncropped.copy() + if pad_size > 0: + # logger.print(whole_vert_nii_uncropped.shape) + arr = whole_vert_nii_uncropped.get_array() + arr = arr[pad_size:-pad_size, pad_size:-pad_size, pad_size:-pad_size] + whole_vert_nii_uncropped.set_array_(arr) + # logger.print(whole_vert_nii_uncropped.shape) + whole_vert_nii_uncropped.pad_to(shp, inplace=True) return whole_vert_nii_uncropped, ErrCode.OK @@ -176,7 +192,9 @@ def collect_vertebra_predictions( logger.print("No 1 in corpus nifty, cannot make vertebra mask", Log_Type.FAIL) return None, [], 0 - corpus_coms = corpus_nii.get_segmentation_connected_components_center_of_mass(label=1, sort_by_axis=1) + corpus_coms = corpus_nii.get_segmentation_connected_components_center_of_mass( + label=1, sort_by_axis=1 + ) # TODO replace with approx_com by bbox corpus_coms.reverse() # from bottom to top n_corpus_coms = len(corpus_coms) diff --git a/spineps/phase_post.py b/spineps/phase_post.py index 509a092..ff0a27c 100644 --- a/spineps/phase_post.py +++ b/spineps/phase_post.py @@ -87,8 +87,7 @@ def mask_cleaning_other( verbose: bool = False, ) -> tuple[NII, NII]: # make copy where both masks clean each other - vert_nii_cleaned = whole_vert_nii.copy() - vert_arr_cleaned = vert_nii_cleaned.get_seg_array() + vert_arr_cleaned = whole_vert_nii.get_seg_array() subreg_vert_nii = seg_nii.extract_label(vertebra_subreg_labels) subreg_vert_arr = subreg_vert_nii.get_seg_array() # if dilation_fill: @@ -124,7 +123,7 @@ def mask_cleaning_other( elif n_vert_pixels_rel_diff > 0.5: logger.print(f"A volume of {n_vert_pixels_rel_diff} * avg_vertebra_volume in subreg not matched by vertebra mask", Log_Type.WARNING) - return vert_nii_cleaned.set_array_(vert_arr_cleaned), seg_nii.set_array(subreg_arr) + return whole_vert_nii.set_array(vert_arr_cleaned), seg_nii.set_array(subreg_arr) def assign_missing_cc( @@ -144,12 +143,14 @@ def assign_missing_cc( logger.print("No CC had to be assigned", Log_Type.OK, verbose=verbose) return target_arr, reference_arr, deletion_map # subreg_arr_vert_rest is not hit pixels bei vertebra prediction - subreg_cc, _ = np_connected_components(subreg_arr_vert_rest, connectivity=1) + subreg_cc, _ = np_connected_components(subreg_arr_vert_rest, connectivity=2) + loop_counts = 0 # for label, for each cc for label, subreg_cc_map in subreg_cc.items(): if label == 0: continue cc_labels = np.unique(subreg_cc_map)[1:] + loop_counts += len(cc_labels) # print(cc_labels) for cc_l in cc_labels: cc_map = subreg_cc_map.copy() @@ -181,6 +182,7 @@ def assign_missing_cc( deletion_map[cc_bbox][cc_map_c == 1] = 1 # print("vert_arr\n", vert_arr) # print() + logger.print(f"Assign missing cc: Processed {loop_counts} missed ccs") return target_arr, reference_arr, deletion_map @@ -202,7 +204,7 @@ def add_ivd_ep_vert_label(whole_vert_nii: NII, seg_nii: NII): vert_l[subreg_arr != 49] = 0 # com of corpus region vert_l[vert_l != 0] = 1 if np.count_nonzero(vert_l) > 0: - coms_vert_dict[l] = center_of_mass(vert_l)[1] + coms_vert_dict[l] = np_approx_center_of_mass(vert_l, label_ref=1)[1][1] # center_of_mass(vert_l)[1] else: coms_vert_dict[l] = 0 @@ -223,7 +225,7 @@ def add_ivd_ep_vert_label(whole_vert_nii: NII, seg_nii: NII): continue c_l = subreg_cc.copy() c_l[c_l != c] = 0 - com_y = center_of_mass(c_l)[1] + com_y = np_approx_center_of_mass(c_l, label_ref=c)[c][1] # center_of_mass(c_l)[1] if com_y < min(coms_vert_y): label = min(coms_vert_labels) - 1 @@ -262,7 +264,7 @@ def add_ivd_ep_vert_label(whole_vert_nii: NII, seg_nii: NII): continue c_l = ep_cc.copy() c_l[c_l != c] = 0 - com_y = center_of_mass(c_l)[1] + com_y = np_approx_center_of_mass(c_l, label_ref=c)[c][1] # center_of_mass(c_l)[1] nearest_lower = find_nearest_lower(coms_vert_y, com_y) label = [i for i in coms_vert_dict if coms_vert_dict[i] == nearest_lower][0] mapping_ep_cc_to_vert_label[c] = label @@ -340,12 +342,12 @@ def label_instance_top_to_bottom(vert_nii: NII): present_labels = list(vert_nii.unique()) vert_arr = vert_nii.get_seg_array() com_i = np_approx_center_of_mass(vert_arr, present_labels) - # TODO - comb = {} - for i in present_labels: - arr_i = vert_arr.copy() - arr_i[arr_i != i] = 0 - comb[i] = center_of_mass(arr_i) + # Old, more precise version (but takes longer) + # comb = {} + # for i in present_labels: + # arr_i = vert_arr.copy() + # arr_i[arr_i != i] = 0 + # comb[i] = center_of_mass(arr_i) comb_l = list(zip(com_i.keys(), com_i.values())) comb_l.sort(key=lambda a: a[1][1]) # PIR com_map = {comb_l[idx][0]: idx + 1 for idx in range(len(comb_l))} diff --git a/spineps/seg_model.py b/spineps/seg_model.py index 604f3a2..a8e3dec 100755 --- a/spineps/seg_model.py +++ b/spineps/seg_model.py @@ -132,22 +132,24 @@ def segment_scan( input_niftys_in_order = [] zms_pir: Zooms = None for idx, id in enumerate(self.inference_config.expected_inputs): + # Make nifty nii = to_nii(inputdict[id], seg=id == InputType.seg) - + # Padding if pad_size > 0: arr = nii.get_array() - arr = np.pad(arr, 2, mode="edge") + arr = np.pad(arr, pad_size, mode="edge") nii.set_array_(arr) input_niftys_in_order.append(nii) - + # Save first values for comparison if orig_shape is None: orig_shape = nii.shape orientation = nii.orientation zms = nii.zoom - + # Consistency check assert ( nii.shape == orig_shape and nii.orientation == orientation and nii.zoom == zms ), "All inputs need to be of same shape, orientation and zoom, got at least two different." + # Reorient and rescale nii.reorient_(self.inference_config.model_expected_orientation, verbose=self.logger) zms_pir = nii.zoom if resample_to_recommended: diff --git a/spineps/seg_run.py b/spineps/seg_run.py index e54a013..891f681 100755 --- a/spineps/seg_run.py +++ b/spineps/seg_run.py @@ -202,14 +202,18 @@ def process_dataset( logger.print() logger.print(f"Processed {processed_seen_counter} scans with {modalities}", Log_Type.BOLD) - logger.print( - f"Scans that were skipped because all derivatives were present: {processed_alldone_counter}" - ) if processed_alldone_counter > 0 else None + ( + logger.print(f"Scans that were skipped because all derivatives were present: {processed_alldone_counter}") + if processed_alldone_counter > 0 + else None + ) not_processed_ok = processed_seen_counter - processed_alldone_counter - processed_counter if not_processed_ok > 0: logger.print(f"Scans that were not properly processed: {not_processed_ok}") - logger.print("Consult the log file for more info!") if save_log_data else logger.print( - "Set save_log_data=True to get a detailed log. Here are the scans in question:" + ( + logger.print("Consult the log file for more info!") + if save_log_data + else logger.print("Set save_log_data=True to get a detailed log. Here are the scans in question:") ) logger.print(not_properly_processed) @@ -403,7 +407,6 @@ def process_img_nii( model_instance, debug_data=debug_data_run, use_height_estimate=False, - resample_output_to_input_space=True, verbose=verbose, fill_holes=proc_fillholes, proc_corpus_clean=proc_corpus_clean,