Skip to content

Commit

Permalink
added padding to semantic phase with edge mode, makes boundary segmen…
Browse files Browse the repository at this point in the history
…tations waaay better
  • Loading branch information
Hendrik-code committed Jan 31, 2024
1 parent 197cab9 commit 61a5291
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 16 deletions.
1 change: 1 addition & 0 deletions spineps/phase_instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,7 @@ def collect_vertebra_predictions(
results = model.segment_scan(
cut_nii,
resample_to_recommended=False,
pad_size=0,
# resample_output_to_input_space=False,
verbose=False,
)
Expand Down
1 change: 1 addition & 0 deletions spineps/phase_post.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,7 @@ def label_instance_top_to_bottom(vert_nii: NII):
present_labels = list(vert_nii.unique())
vert_arr = vert_nii.get_seg_array()
com_i = np_approx_center_of_mass(vert_arr, present_labels)
# TODO
comb = {}
for i in present_labels:
arr_i = vert_arr.copy()
Expand Down
1 change: 1 addition & 0 deletions spineps/phase_semantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ def predict_semantic_mask(

results = model.segment_scan(
mri_nii_rdy,
pad_size=2,
resample_to_recommended=True,
verbose=verbose,
) # type:ignore
Expand Down
37 changes: 22 additions & 15 deletions spineps/seg_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def calc_recommended_resampling_zoom(self, input_zoom: Zooms) -> Zooms:
def segment_scan(
self,
input: Image_Reference | dict[InputType, Image_Reference],
pad_size: int = 4,
pad_size: int = 2,
step_size: float | None = 0.5,
resample_to_recommended: bool = True,
verbose: bool = False,
Expand Down Expand Up @@ -124,27 +124,34 @@ def segment_scan(
# Check if all required inputs are there
if not set(list(inputdict.keys())).issuperset(self.inference_config.expected_inputs):
self.print(f"expected {self.inference_config.expected_inputs}, but only got {list(inputdict.keys())}")

# Convert input to nifty
first_nii = to_nii(
inputdict[self.inference_config.expected_inputs[0]], seg=self.inference_config.expected_inputs[0] == InputType.seg
)
orig_shape = first_nii.shape
orientation = first_nii.orientation
zms = first_nii.zoom
#
orig_shape = None
orientation = None
zms = None
#
input_niftys_in_order = []
zms_pir: Zooms = None
for idx, id in enumerate(self.inference_config.expected_inputs):
nii = to_nii(inputdict[id], seg=id == InputType.seg)

if pad_size > 0:
arr = nii.get_array()
arr = np.pad(arr, 2, mode="edge")
nii.set_array_(arr)
input_niftys_in_order.append(nii)

if orig_shape is None:
orig_shape = nii.shape
orientation = nii.orientation
zms = nii.zoom

assert (
nii.shape == orig_shape and nii.orientation == orientation and nii.zoom == zms
), "All inputs need to be of same shape, orientation and zoom, got at least two different."
nii.reorient_(self.inference_config.model_expected_orientation, verbose=self.logger)
zms_pir = nii.zoom
if resample_to_recommended:
nii.rescale_(self.calc_recommended_resampling_zoom(zms_pir), verbose=self.logger)
input_niftys_in_order.append(nii)

if not resample_to_recommended:
self.print("resample_to_recommended set to False, segmentation might not work. Proceed at own risk", Log_Type.WARNING)
Expand All @@ -158,7 +165,6 @@ def segment_scan(
self.print("Run Segmentation")
result = self.run(
input=input_niftys_in_order,
pad_size=pad_size,
verbose=verbose,
)
assert OutputType.seg in result and isinstance(result[OutputType.seg], NII), "No seg output in segmentation result"
Expand All @@ -173,6 +179,11 @@ def segment_scan(
v.pad_to(orig_shape, inplace=True)
if k == OutputType.seg:
v.map_labels_(self.inference_config.segmentation_labels, verbose=self.logger)
if pad_size > 0:
arr = v.get_array()
arr = arr[pad_size:-pad_size, pad_size:-pad_size, pad_size:-pad_size]
v.set_array_(arr)

self.print(f"out_seg {k}", v.zoom, v.orientation, v.shape, verbose=verbose)
self.print("Segmenting done!")
return result
Expand All @@ -197,7 +208,6 @@ def acquisition(self) -> Acquisition:
def run(
self,
input: list[NII],
pad_size: int = 2,
verbose: bool = False,
) -> dict[OutputType, NII | None]:
pass
Expand Down Expand Up @@ -273,14 +283,12 @@ def load(self, folds: tuple[str, ...] | None = None) -> Self:
def run(
self,
input: list[NII],
pad_size: int = 2,
verbose: bool = False,
) -> dict[OutputType, NII | None]:
self.print("Segmenting...")
seg_nii, unc_nii, softmax_logits = run_inference(
input,
self.predictor,
# pad_size=pad_size,
)
self.print("Segmentation done!")
self.print("out_inf", seg_nii.zoom, seg_nii.orientation, seg_nii.shape, verbose=verbose)
Expand Down Expand Up @@ -314,7 +322,6 @@ def load(self, folds: tuple[str, ...] | None = None) -> Self:
def run(
self,
input: list[NII],
pad_size: int = 2, # TODO not implemented yet
verbose: bool = False,
) -> dict[OutputType, NII | None]:
assert len(input) == 1, "Unet3D does not support more than one input"
Expand Down
1 change: 0 additions & 1 deletion spineps/utils/inference_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,6 @@ def load_inf_model(
def run_inference(
input: str | NII | list[NII],
predictor: nnUNetPredictor,
# pad_size: int = 2,
reorient_PIR: bool = False,
) -> tuple[NII, NII | None, np.ndarray]:
"""Runs nnUnet model inference on one input.
Expand Down

0 comments on commit 61a5291

Please sign in to comment.