Skip to content

Commit

Permalink
small bug fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
ga84mun committed Aug 13, 2024
1 parent 5608507 commit 1ffe7e4
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 10 deletions.
13 changes: 5 additions & 8 deletions spineps/phase_instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ def get_corpus_coms(
seg_nii.assert_affine(orientation=["P", "I", "R"])
#
# Extract Corpus region and try to find all coms naively (some skips shouldnt matter)
corpus_nii = seg_nii.extract_label(Location.Vertebra_Corpus_border.value)
corpus_nii = seg_nii.extract_label([Location.Vertebra_Corpus_border, Location.Vertebra_Corpus])
corpus_nii.erode_msk_(mm=2, connectivity=2, verbose=False)
if 1 in corpus_nii.unique() and corpus_size_cleaning > 0:
corpus_nii.set_array_(
Expand All @@ -192,7 +192,7 @@ def get_corpus_coms(
)

if 1 not in corpus_nii.unique():
logger.print("No 1 in corpus nifty, cannot make vertebra mask", Log_Type.FAIL)
logger.print(f"No corpus found after get_corpus_coms post process, cannot make vertebra mask. {corpus_nii.unique()}", Log_Type.FAIL)
return None

if not process_detect_and_solve_merged_corpi:
Expand Down Expand Up @@ -256,19 +256,15 @@ def get_corpus_coms(
stats_by_height.pop(vl)
stats_by_height = dict(sorted(stats_by_height.items(), key=lambda x: x[1][0]))
stats_by_height_keys = list(stats_by_height.keys())
print(stats_by_height_keys)
continue

logger.print("Merged corpi, try to fix it", verbose=verbose)
neighbor_verts = {
stats_by_height_keys[idx + i]: stats_by_height[stats_by_height_keys[idx + i]]
for i in [-5, -4, -3, -2, -1, 1, 2, 3, 4, 5]
if (idx + i) in stats_by_height_keys and stats_by_height_keys[idx + i] < 99
if (idx + i) < len(stats_by_height_keys) and (idx + i) >= 0 and stats_by_height_keys[idx + i] < 99
}
# stats_by_height_keys[idx + i]
# for i in [-5, -4, -3, -2, -1, 1, 2, 3, 4, 5]
# if (idx + i) in stats_by_height_keys and stats_by_height_keys[idx + i] < 99
# ] # (+-3)

logger.print("neighbor_vert_labels", neighbor_verts, verbose=verbose)
if len(neighbor_verts) == 0:
logger.print("Got no neighbor vert labels to fix", Log_Type.FAIL)
Expand Down Expand Up @@ -505,6 +501,7 @@ def collect_vertebra_predictions(
47: 7,
48: 8,
49: 9,
50: 9,
Location.Spinal_Cord.value: 0,
Location.Spinal_Canal.value: 0,
Location.Vertebra_Disc.value: 0,
Expand Down
3 changes: 2 additions & 1 deletion spineps/seg_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,7 @@ def process_img_nii( # noqa: C901
proc_inst_clean_small_cc_artifacts: bool = True,
proc_inst_largest_k_cc: int = 0,
proc_inst_detect_and_solve_merged_corpi: bool = True,
vertebra_instance_labeling_offset=2,
# Both
proc_fill_3d_holes: bool = True,
proc_assign_missing_cc: bool = True,
Expand Down Expand Up @@ -460,7 +461,7 @@ def process_img_nii( # noqa: C901
seg_nii=seg_nii_back,
vert_nii=whole_vert_nii,
debug_data=debug_data_run,
labeling_offset=1,
labeling_offset=vertebra_instance_labeling_offset - 1,
proc_clean_inst_by_sem=proc_clean_inst_by_sem,
proc_assign_missing_cc=proc_assign_missing_cc,
proc_vertebra_inconsistency=proc_vertebra_inconsistency,
Expand Down
1 change: 1 addition & 0 deletions spineps/utils/filepaths.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
# "/DATA/NAS/ongoing_projects/hendrik/mri_usage/models/"
# ) # None # You can put an absolute path to the model weights here instead of using environment variable
spineps_environment_path_backup = Path(__file__).parent.parent.joinpath("models") # EDIT this to use this instead of environment variable
spineps_environment_path_backup.mkdir(exist_ok=True)


def get_mri_segmentor_models_dir() -> Path:
Expand Down
5 changes: 4 additions & 1 deletion spineps/utils/inference_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,10 @@ def load_inf_model(
elif ddevice == "cuda":
# multithreading in torch doesn't help nnU-Net if run on GPU
torch.set_num_threads(1) if init_threads else None
torch.set_num_interop_threads(1) if init_threads else None
try:
torch.set_num_interop_threads(1) if init_threads else None
except RuntimeError:
pass
device = torch.device("cuda")
else:
device = torch.device("mps")
Expand Down

0 comments on commit 1ffe7e4

Please sign in to comment.