diff --git a/totalspineseg/utils/transform_seg2image.py b/totalspineseg/utils/transform_seg2image.py index 7914cd3..dbf782d 100644 --- a/totalspineseg/utils/transform_seg2image.py +++ b/totalspineseg/utils/transform_seg2image.py @@ -236,6 +236,10 @@ def transform_seg2image( nibabel.Nifti1Image Output segmentation. ''' + # Check if the input image is 4D and take the first image from the last axis for resampling + if len(np.asanyarray(image.dataobj).shape) == 4: + image = image.slicer[..., 0] + image_data = np.asanyarray(image.dataobj).astype(np.float64) image_affine = image.affine.copy() seg_data = np.asanyarray(seg.dataobj)