diff --git a/spineps/phase_instance.py b/spineps/phase_instance.py index 80f904d..4a1f383 100755 --- a/spineps/phase_instance.py +++ b/spineps/phase_instance.py @@ -175,7 +175,7 @@ def get_corpus_coms( seg_nii.assert_affine(orientation=["P", "I", "R"]) # # Extract Corpus region and try to find all coms naively (some skips shouldnt matter) - corpus_nii = seg_nii.extract_label(Location.Vertebra_Corpus_border.value) + corpus_nii = seg_nii.extract_label([Location.Vertebra_Corpus_border, Location.Vertebra_Corpus]) corpus_nii.erode_msk_(mm=2, connectivity=2, verbose=False) if 1 in corpus_nii.unique() and corpus_size_cleaning > 0: corpus_nii.set_array_( @@ -192,7 +192,7 @@ def get_corpus_coms( ) if 1 not in corpus_nii.unique(): - logger.print("No 1 in corpus nifty, cannot make vertebra mask", Log_Type.FAIL) + logger.print(f"No corpus found after get_corpus_coms post process, cannot make vertebra mask. {corpus_nii.unique()}", Log_Type.FAIL) return None if not process_detect_and_solve_merged_corpi: @@ -256,19 +256,15 @@ def get_corpus_coms( stats_by_height.pop(vl) stats_by_height = dict(sorted(stats_by_height.items(), key=lambda x: x[1][0])) stats_by_height_keys = list(stats_by_height.keys()) - print(stats_by_height_keys) continue logger.print("Merged corpi, try to fix it", verbose=verbose) neighbor_verts = { stats_by_height_keys[idx + i]: stats_by_height[stats_by_height_keys[idx + i]] for i in [-5, -4, -3, -2, -1, 1, 2, 3, 4, 5] - if (idx + i) in stats_by_height_keys and stats_by_height_keys[idx + i] < 99 + if (idx + i) < len(stats_by_height_keys) and (idx + i) >= 0 and stats_by_height_keys[idx + i] < 99 } - # stats_by_height_keys[idx + i] - # for i in [-5, -4, -3, -2, -1, 1, 2, 3, 4, 5] - # if (idx + i) in stats_by_height_keys and stats_by_height_keys[idx + i] < 99 - # ] # (+-3) + logger.print("neighbor_vert_labels", neighbor_verts, verbose=verbose) if len(neighbor_verts) == 0: logger.print("Got no neighbor vert labels to fix", Log_Type.FAIL) @@ -505,6 +501,7 @@ def collect_vertebra_predictions( 47: 7, 48: 8, 49: 9, + 50: 9, Location.Spinal_Cord.value: 0, Location.Spinal_Canal.value: 0, Location.Vertebra_Disc.value: 0, diff --git a/spineps/seg_run.py b/spineps/seg_run.py index 301c5c3..d842c8d 100755 --- a/spineps/seg_run.py +++ b/spineps/seg_run.py @@ -262,6 +262,7 @@ def process_img_nii( # noqa: C901 proc_inst_clean_small_cc_artifacts: bool = True, proc_inst_largest_k_cc: int = 0, proc_inst_detect_and_solve_merged_corpi: bool = True, + vertebra_instance_labeling_offset=2, # Both proc_fill_3d_holes: bool = True, proc_assign_missing_cc: bool = True, @@ -460,7 +461,7 @@ def process_img_nii( # noqa: C901 seg_nii=seg_nii_back, vert_nii=whole_vert_nii, debug_data=debug_data_run, - labeling_offset=1, + labeling_offset=vertebra_instance_labeling_offset - 1, proc_clean_inst_by_sem=proc_clean_inst_by_sem, proc_assign_missing_cc=proc_assign_missing_cc, proc_vertebra_inconsistency=proc_vertebra_inconsistency, diff --git a/spineps/utils/filepaths.py b/spineps/utils/filepaths.py index fedd428..189f29b 100755 --- a/spineps/utils/filepaths.py +++ b/spineps/utils/filepaths.py @@ -9,6 +9,7 @@ # "/DATA/NAS/ongoing_projects/hendrik/mri_usage/models/" # ) # None # You can put an absolute path to the model weights here instead of using environment variable spineps_environment_path_backup = Path(__file__).parent.parent.joinpath("models") # EDIT this to use this instead of environment variable +spineps_environment_path_backup.mkdir(exist_ok=True) def get_mri_segmentor_models_dir() -> Path: diff --git a/spineps/utils/inference_api.py b/spineps/utils/inference_api.py index 95e0d80..8c18db2 100755 --- a/spineps/utils/inference_api.py +++ b/spineps/utils/inference_api.py @@ -49,7 +49,10 @@ def load_inf_model( elif ddevice == "cuda": # multithreading in torch doesn't help nnU-Net if run on GPU torch.set_num_threads(1) if init_threads else None - torch.set_num_interop_threads(1) if init_threads else None + try: + torch.set_num_interop_threads(1) if init_threads else None + except RuntimeError: + pass device = torch.device("cuda") else: device = torch.device("mps")