Skip to content

Commit

Permalink
minor refactoring for PR-82
Browse files Browse the repository at this point in the history
  • Loading branch information
Rahul Soni committed Aug 21, 2024
1 parent 312ae55 commit 5997cda
Show file tree
Hide file tree
Showing 4 changed files with 7 additions and 112 deletions.
2 changes: 1 addition & 1 deletion models/bamf_nnunet_mr_breast/config/default.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ execute:
- NiftiConverter
- module: NNUnetRunnerV2
nnunet_dataset: Dataset009_Breast
roi: BREAST,FGT
roi: BREAST+FGT
- module: NNUnetRunnerV2
nnunet_dataset: Dataset011_Breast
roi: BREAST+BREAST_CARCINOMA
Expand Down
9 changes: 4 additions & 5 deletions models/bamf_nnunet_mr_breast/meta.json
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
"description": "The MR scan of a patient.",
"format": "DICOM",
"modality": "MR",
"bodypartexamined": "BREAST,FGT,BREAST+TUMOR",
"bodypartexamined": "BREAST",
"slicethickness": "2.5mm",
"non-contrast": true,
"contrast": false
Expand All @@ -20,11 +20,10 @@
{
"label": "Segmentation",
"type": "Segmentation",
"description": "Segmentation for breast, fgt, and tumor",
"description": "Segmentation for breast, fgt, and breast carcinoma",
"classes": [
"BREAST",
"FGT",
"BREAST+TUMOR"
"BREAST+FGT",
"BREAST+BREAST_CARCINOMA"
]
}
],
Expand Down
6 changes: 2 additions & 4 deletions models/bamf_nnunet_mr_breast/utils/BreastPostProcessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,10 +92,8 @@ def task(self, instance: Instance, in_breast_and_fgt_data: InstanceData, in_brea
tumor_seg = np.copy(tumor_seg),
mr_path = in_mr_data.abspath
)
process_dir = self.config.data.requestTempDir(label="nnunet-breast-processor")
process_file = os.path.join(process_dir, f'bamf_nnunet_mr_breast.nii.gz')
sitk.WriteImage(
output_seg,
process_file,
out_data.abspath,
)
shutil.copyfile(process_file, out_data.abspath)

102 changes: 0 additions & 102 deletions models/bamf_nnunet_mr_breast/utils/NNUnetRunnerV2.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,107 +23,15 @@
@IO.Config('nnunet_dataset', str, None, the='nnunet dataset name')
@IO.Config('nnunet_config', str, None, the='nnunet model name (2d, 3d_lowres, 3d_fullres, 3d_cascade_fullres)')
@IO.Config('folds', int, None, the='number of folds to run nnunet on')
@IO.Config('use_tta', bool, True, the='flag to enable test time augmentation')
@IO.Config('export_prob_maps', bool, False, the='flag to export probability maps')
@IO.Config('prob_map_segments', list, [], the='segment labels for probability maps')
@IO.Config('roi', str, None, the='roi or comma separated list of roi the nnunet segments')
class NNUnetRunnerV2(Module):

nnunet_dataset: str
nnunet_config: str
input_data_type: DataType
folds: int # TODO: support optional config attributes
use_tta: bool
export_prob_maps: bool
prob_map_segments: list
roi: str

def export_prob_mask(self, nnunet_out_dir: str, ref_file: InstanceData, output_dtype: str = 'float32', structure_list: Optional[List[str]] = None):
"""
Convert softmax probability maps to NRRD. For simplicity, the probability maps
are converted by default to UInt8
Arguments:
model_output_folder : required - path to the folder where the inferred segmentation masks should be stored.
ref_file : required - InstanceData object of the generated segmentation mask used as reference file.
output_dtype : optional - output data type. Data type float16 is not supported by the NRRD standard,
so the choice should be between uint8, uint16 or float32.
structure_list : optional - list of the structures whose probability maps are stored in the
first channel of the `.npz` file (output from the nnU-Net pipeline
when `export_prob_maps` is set to True).
Outputs:
This function [...]
"""

# initialize structure list
if structure_list is None:
if self.roi is not None:
structure_list = self.roi.split(',')
else:
structure_list = []

# sanity check user inputs
assert(output_dtype in ["uint8", "uint16", "float32"])

# input file containing the raw information
pred_softmax_fn = 'VOLUME_001.npz'
pred_softmax_path = os.path.join(nnunet_out_dir, pred_softmax_fn)

# parse NRRD file - we will make use of if to populate the header of the
# NRRD mask we are going to get from the inferred segmentation mask
sitk_ct = sitk.ReadImage(ref_file.abspath)

# generate bundle for prob masks
# TODO: we really have to create folders (or add this as an option that defaults to true) automatically
prob_masks_bundle = ref_file.getDataBundle('prob_masks')
if not os.path.isdir(prob_masks_bundle.abspath):
os.mkdir(prob_masks_bundle.abspath)

# load softmax probability maps
pred_softmax_all = np.load(pred_softmax_path)["softmax"]

# iterate all channels
for channel in range(0, len(pred_softmax_all)):

structure = structure_list[channel] if channel < len(structure_list) else f"structure_{channel}"
pred_softmax_segmask = pred_softmax_all[channel].astype(dtype = np.float32)

if output_dtype == "float32":
# no rescale needed - the values will be between 0 and 1
# set SITK image dtype to Float32
sitk_dtype = sitk.sitkFloat32

elif output_dtype == "uint8":
# rescale between 0 and 255, quantize
pred_softmax_segmask = (255*pred_softmax_segmask).astype(np.int32)
# set SITK image dtype to UInt8
sitk_dtype = sitk.sitkUInt8

elif output_dtype == "uint16":
# rescale between 0 and 65536
pred_softmax_segmask = (65536*pred_softmax_segmask).astype(np.int32)
# set SITK image dtype to UInt16
sitk_dtype = sitk.sitkUInt16
else:
raise ValueError("Invalid output data type. Please choose between uint8, uint16 or float32.")

pred_softmax_segmask_sitk = sitk.GetImageFromArray(pred_softmax_segmask)
pred_softmax_segmask_sitk.CopyInformation(sitk_ct)
pred_softmax_segmask_sitk = sitk.Cast(pred_softmax_segmask_sitk, sitk_dtype)

# generate data
prob_mask = InstanceData(f'{structure}.nrrd', DataType(FileType.NRRD, {'mod': 'prob_mask', 'structure': structure}), bundle=prob_masks_bundle)

# export file
writer = sitk.ImageFileWriter()
writer.UseCompressionOn()
writer.SetFileName(prob_mask.abspath)
writer.Execute(pred_softmax_segmask_sitk)

# check if the file was written
if os.path.isfile(prob_mask.abspath):
self.v(f" > prob mask for {structure} saved to {prob_mask.abspath}")
prob_mask.confirm()

@IO.Instance()
@IO.Input("in_data", the="input data to run nnunet on")
@IO.Output("out_data", 'VOLUME_001.nii.gz', 'nifti:mod=seg:model=nnunet', data='in_data', the="output data from nnunet")
Expand Down Expand Up @@ -177,13 +85,7 @@ def task(self, instance: Instance, in_data: InstanceData, out_data: InstanceData
# add optional arguments
if self.folds is not None:
bash_command += ["--folds", str(self.folds)]

if not self.use_tta:
bash_command += ["--disable_tta"]

if self.export_prob_maps:
bash_command += ["--save_npz"]

# run command
self.subprocess(bash_command, text=True)

Expand All @@ -202,9 +104,5 @@ def task(self, instance: Instance, in_data: InstanceData, out_data: InstanceData
# copy output data to instance
shutil.copyfile(out_path, out_data.abspath)

# export probabiliy maps if requested as dynamic data
if self.export_prob_maps:
self.export_prob_mask(str(out_dir), out_data, 'float32', self.prob_map_segments)

# update meta dynamically
out_data.type.meta += meta

0 comments on commit 5997cda

Please sign in to comment.