Skip to content

Commit

Permalink
minor refactoring for PR-82
Browse files Browse the repository at this point in the history
  • Loading branch information
Rahul Soni committed Aug 21, 2024
1 parent 5997cda commit 24d3b6b
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 9 deletions.
4 changes: 1 addition & 3 deletions models/bamf_nnunet_mr_breast/config/default.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ execute:
- NiftiConverter
- module: NNUnetRunnerV2
nnunet_dataset: Dataset009_Breast
roi: BREAST+FGT
roi: BREAST,FGT
- module: NNUnetRunnerV2
nnunet_dataset: Dataset011_Breast
roi: BREAST+BREAST_CARCINOMA
Expand All @@ -49,8 +49,6 @@ modules:

NNUnetRunnerV2:
in_data: nifti:mod=mr
nnunet_config: 3d_fullres
export_prob_maps: False

BreastPostProcessor:
in_breast_data: nifti:mod=seg:nnunet_task=Dataset009_Breast
Expand Down
9 changes: 3 additions & 6 deletions models/bamf_nnunet_mr_breast/utils/NNUnetRunnerV2.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,11 @@

@IO.ConfigInput('in_data', 'nifti:mod=mr', the="input data to run nnunet on")
@IO.Config('nnunet_dataset', str, None, the='nnunet dataset name')
@IO.Config('nnunet_config', str, None, the='nnunet model name (2d, 3d_lowres, 3d_fullres, 3d_cascade_fullres)')
@IO.Config('folds', int, None, the='number of folds to run nnunet on')
@IO.Config('roi', str, None, the='roi or comma separated list of roi the nnunet segments')
class NNUnetRunnerV2(Module):

nnunet_dataset: str
nnunet_config: str
input_data_type: DataType
folds: int # TODO: support optional config attributes
roi: str
Expand All @@ -40,7 +38,6 @@ def task(self, instance: Instance, in_data: InstanceData, out_data: InstanceData
# get the nnunet model to run
self.v("Running nnUNet_predict.")
self.v(f" > dataset: {self.nnunet_dataset}")
self.v(f" > config: {self.nnunet_config}")
self.v(f" > input data: {in_data.abspath}")
self.v(f" > output data: {out_data.abspath}")

Expand Down Expand Up @@ -80,7 +77,7 @@ def task(self, instance: Instance, in_data: InstanceData, out_data: InstanceData
bash_command += ["-i", str(inp_dir)]
bash_command += ["-o", str(out_dir)]
bash_command += ["-d", self.nnunet_dataset]
bash_command += ["-c", self.nnunet_config]
bash_command += ["-c", "3d_fullres"]

# add optional arguments
if self.folds is not None:
Expand All @@ -91,9 +88,9 @@ def task(self, instance: Instance, in_data: InstanceData, out_data: InstanceData

# output meta
meta = {
"model": "nnunet",
"model": "nnunet-v2",
"nnunet_dataset": self.nnunet_dataset,
"nnunet_config": self.nnunet_config,
"nnunet_config": "3d_fullres",
"roi": self.roi
}

Expand Down

0 comments on commit 24d3b6b

Please sign in to comment.