Skip to content

Commit

Permalink
fixed dtype bug, added assertAffine() from new TPTBox version
Browse files Browse the repository at this point in the history
  • Loading branch information
Hendrik-code committed Feb 5, 2024
1 parent ae3e26a commit 9d62d1d
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 10 deletions.
11 changes: 6 additions & 5 deletions spineps/phase_post.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# from utils.predictor import nnUNetPredictor
import numpy as np
from TPTBox import NII, Location, Log_Type, v_idx2name, v_name2idx
from TPTBox.core.np_utils import np_bbox_nd, np_connected_components, np_dilate_msk, np_map_labels, np_approx_center_of_mass
from scipy.ndimage import center_of_mass
from TPTBox import NII, Location, Log_Type, v_idx2name, v_name2idx
from TPTBox.core.np_utils import np_approx_center_of_mass, np_bbox_nd, np_connected_components, np_dilate_msk, np_map_labels

from spineps.seg_pipeline import logger, vertebra_subreg_labels

Expand All @@ -18,7 +18,7 @@ def phase_postprocess_combined(
) -> tuple[NII, NII]:
logger.print("Post process", Log_Type.STAGE)
with logger:
assert seg_nii.shape == vert_nii.shape, f"shape mismatch before cleaning, got {seg_nii.shape} and {vert_nii.shape}"
seg_nii.assert_affine(shape=vert_nii.shape)
# Post process semantic mask
###################
seg_nii = semantic_bounding_box_clean(seg_nii=seg_nii.copy())
Expand All @@ -31,8 +31,8 @@ def phase_postprocess_combined(
#
vert_nii.apply_mask(seg_nii, inplace=True)
crop_slices = seg_nii.compute_crop_slice(dist=3)
vert_uncropped_arr = np.zeros(vert_nii.shape)
seg_uncropped_arr = np.zeros(vert_nii.shape)
vert_uncropped_arr = np.zeros(vert_nii.shape, dtype=seg_nii.dtype)
seg_uncropped_arr = np.zeros(vert_nii.shape, dtype=seg_nii.dtype)

# Crop down
vert_nii.apply_crop_slice_(crop_slices)
Expand Down Expand Up @@ -63,6 +63,7 @@ def phase_postprocess_combined(
whole_vert_nii_cleaned.set_array_(vert_uncropped_arr, verbose=False)
#
seg_uncropped_arr[crop_slices] = seg_nii_cleaned.get_seg_array()

seg_nii_cleaned.set_array_(seg_uncropped_arr, verbose=False)
#
debug_data["vert_arr_crop_e_addivd"] = whole_vert_nii_cleaned.copy()
Expand Down
11 changes: 6 additions & 5 deletions spineps/seg_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,7 @@ def process_img_nii(
logger.print("Input image", zms, orientation, img_nii.shape, verbose=verbose)

# First stage
if not os.path.exists(out_spine) or override_semantic:
if not os.path.exists(out_spine_raw) or override_semantic:
# make subreg mask
seg_nii, seg_nii_modelres, unc_nii, softmax_logits, errcode = predict_semantic_mask(
img_nii,
Expand Down Expand Up @@ -377,6 +377,7 @@ def process_img_nii(
seg_nii_back = lambda_semantic(seg_nii_back)

seg_nii_back.nii = nib.nifti1.Nifti1Image(seg_nii_back.get_seg_array(), affine=affine, header=header)
seg_nii.assert_affine(other=img_nii)
seg_nii_back.save(out_spine_raw, verbose=logger)
if isinstance(seg_nii_modelres, NII) and save_modelres_mask:
seg_nii_modelres.save(str(out_spine_raw).replace("seg-spine", "seg-spineModelRes"), verbose=logger)
Expand All @@ -401,7 +402,7 @@ def process_img_nii(
return output_paths, ErrCode.SHAPE

# Second stage
if not os.path.exists(out_vert) or override_instance:
if not os.path.exists(out_vert_raw) or override_instance:
whole_vert_nii, errcode = predict_instance_mask(
seg_nii.copy(),
model_instance,
Expand All @@ -425,6 +426,7 @@ def process_img_nii(

# TODO make this better (instance mask gets to have same global coords)
whole_vert_nii.nii = nib.nifti1.Nifti1Image(whole_vert_nii.get_seg_array(), affine=affine, header=header)
whole_vert_nii.assert_affine(other=img_nii)
#
whole_vert_nii.save(out_vert_raw, verbose=logger)
done_something = True
Expand All @@ -446,9 +448,8 @@ def process_img_nii(
proc_assign_missing_cc=proc_assign_missing_cc,
verbose=verbose,
)
assert seg_nii_clean.shape == vert_nii_clean.shape, "shape mismatch after postprocess"
assert seg_nii_clean.zoom == zms, "zoom mismatch"
assert seg_nii_clean.orientation == orientation, "orientation mismatch"

seg_nii_clean.assert_affine(other=vert_nii_clean, zoom=zms, orientation=orientation)

seg_nii_clean.save(out_spine, verbose=logger)
vert_nii_clean.save(out_vert, verbose=logger)
Expand Down

0 comments on commit 9d62d1d

Please sign in to comment.