Skip to content

Commit

Permalink
Fix filename for inference (#50)
Browse files Browse the repository at this point in the history
* Use `os.path.expanduser` to expand `~`

* Add the `_0000` suffix

Context: #49

* gzip the input file

Context: #49
  • Loading branch information
valosekj authored May 30, 2024
1 parent 97a6370 commit 8922091
Showing 1 changed file with 22 additions and 6 deletions.
28 changes: 22 additions & 6 deletions packaging_ventral_rootlets/run_inference_single_subject.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,14 +117,31 @@ def main():
parser = get_parser()
args = parser.parse_args()

fname_file = args.i
fname_file_out = args.o
fname_file = os.path.expanduser(args.i)
fname_file_out = os.path.expanduser(args.o)
print(f'\nFound {fname_file} file.')

# If the fname_file is .nii, gzip it
# This is needed, because the filename suffix must match the `file_ending` in `dataset.json`. And as the
# `file_ending` for the ventral model is `.nii.gz`, we gzip the input file if it is not already gzipped.
# Context: https://github.com/ivadomed/model-spinal-rootlets/issues/49
if not fname_file.endswith('.nii.gz'):
print('Compressing the input image...')
os.system('gzip -f {}'.format(fname_file))
fname_file = fname_file + '.gz'
print(f'Compressed {fname_file}')

# Add .gz suffix to the output file if not already present. This is needed because we gzip the input file.
if not fname_file_out.endswith('.gz'):
fname_file_out = fname_file_out + '.gz'

# Create temporary directory in the temp to store the reoriented images
tmpdir = tmp_create()
# Copy the file to the temporary directory using shutil.copyfile
fname_file_tmp = os.path.join(tmpdir, os.path.basename(fname_file))
# NOTE: Add the `_0000` suffix, because nnUNet removes the last five characters:
# https://github.com/MIC-DKFZ/nnUNet/blob/master/nnunetv2/inference/predict_from_raw_data.py#L171C19-L172C51
# Context: https://github.com/ivadomed/model-spinal-rootlets/issues/49
fname_file_tmp = os.path.join(tmpdir, os.path.basename(add_suffix(fname_file, '_0000')))
shutil.copyfile(fname_file, fname_file_tmp)
print(f'Copied {fname_file} to {fname_file_tmp}')

Expand All @@ -138,8 +155,7 @@ def main():
# reorient the image to LPI using SCT
os.system('sct_image -i {} -setorient LPI -o {}'.format(fname_file_tmp, fname_file_tmp))

# NOTE: for individual images, the _0000 suffix is not needed.
# BUT, the images should be in a list of lists
# Note: even a single file must be in a list of lists
fname_file_tmp_list = [[fname_file_tmp]]

# Use fold_all (all train/val subjects were used for training) or specific fold(s)
Expand Down Expand Up @@ -169,7 +185,7 @@ def main():

# initializes the network architecture, loads the checkpoint
predictor.initialize_from_trained_model_folder(
join(args.path_model),
join(os.path.expanduser(args.path_model)),
use_folds=folds_avail,
checkpoint_name='checkpoint_final.pth' if not args.use_best_checkpoint else 'checkpoint_best.pth',
)
Expand Down

0 comments on commit 8922091

Please sign in to comment.