-
Notifications
You must be signed in to change notification settings - Fork 0
/
run_monai_bundle.py
81 lines (68 loc) · 2.2 KB
/
run_monai_bundle.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import argparse
from monai.bundle.scripts import run
def get_parser():
    """Build the command-line parser for the bundle launcher.

    Registered options (all parsed as ``str``), in this order:
        --mode   required; one of the pipeline stages listed below.
        --sys    required; system specification ``low``, ``med`` or ``high``.
        --data   required; dataset name, e.g. ``brats_2021``.
        --model  optional; model configuration name, e.g. ``baseline_ce``.

    Returns:
        argparse.ArgumentParser: the configured (not yet invoked) parser.
    """
    operation_modes = [
        "train",
        "inference_predict",
        "inference_evaluation",
        "temperature_scaling",
        "temperature_scaling_eval",
    ]
    # (flag, extra add_argument kwargs) -- registration order is preserved so
    # the generated --help text lists options in the same sequence.
    option_specs = (
        (
            "--mode",
            {"required": True, "choices": operation_modes, "help": "Operation mode."},
        ),
        (
            "--sys",
            {
                "required": True,
                "choices": ["low", "med", "high"],
                "help": "System specification.",
            },
        ),
        (
            "--data",
            {"required": True, "help": "Dataset name, e.g., brats_2021."},
        ),
        (
            "--model",
            {"help": "Model configuration name, e.g., baseline_ce."},
        ),
    )

    cli = argparse.ArgumentParser(
        description="Run a MONAI bundle script with specified configurations."
    )
    for flag, settings in option_specs:
        cli.add_argument(flag, type=str, **settings)
    return cli
def get_config_files(args, config_dir="bundle/configs"):
    """Return the ordered list of YAML config paths for one bundle run.

    The list starts with the configs shared by every mode (common settings,
    the dataset config and the system-spec config), followed by the configs
    specific to ``args.mode``. NOTE(review): the files are presumably merged
    in order by MONAI, with later entries overriding earlier ones -- verify
    before reordering.

    Args:
        args: parsed CLI namespace; reads ``mode``, ``data``, ``sys`` and
            ``model`` (``model`` is only used in ``train`` mode).
        config_dir: directory prefix for every returned path. The default
            reproduces the script's original layout.

    Returns:
        list[str]: config file paths, common configs first.

    Raises:
        ValueError: if ``args.mode`` is ``"train"`` but no model name is set;
            otherwise a nonexistent ``train/None.yaml`` path would be built.
        KeyError: if ``args.mode`` is not one of the known modes.
    """
    if args.mode == "train" and not args.model:
        raise ValueError("Model configuration name is required for training.")

    config_files = [
        f"{config_dir}/common.yaml",
        f"{config_dir}/data/{args.data}.yaml",
        f"{config_dir}/sys/{args.sys}_spec.yaml",
    ]
    # Config basenames (relative to config_dir, without ".yaml") per mode.
    mode_specific_configs = {
        "train": ["train", "validation", f"train/{args.model}"],
        "inference_predict": ["inference_predict"],
        "inference_evaluation": ["inference_eval"],
        "temperature_scaling": ["temp_scale"],
        "temperature_scaling_eval": ["inference_eval", "temp_scale_eval"],
    }
    config_files.extend(
        f"{config_dir}/{name}.yaml" for name in mode_specific_configs[args.mode]
    )
    return config_files
def main(argv=None):
    """Parse the command line and launch the MONAI bundle run.

    Args:
        argv: optional argument list to parse instead of ``sys.argv[1:]``.
            ``None`` (the default) keeps the normal command-line behavior.

    Raises:
        ValueError: if ``--mode train`` is given without ``--model``.
    """
    parser = get_parser()
    args = parser.parse_args(argv)
    if args.mode == "train" and not args.model:
        raise ValueError("Model configuration name is required for training.")

    config_files = get_config_files(args)

    # Join only the parts that were actually supplied. --model is optional
    # outside training, and an f-string would otherwise embed the literal
    # text "None" (e.g. "None_brats_2021_low").
    name_parts = (args.model, args.data, args.sys)
    model_name = "_".join(part for part in name_parts if part is not None)
    if args.mode == "temperature_scaling_eval":
        model_name += "_temp_scaled"

    # bundle_root and model_name are not named parameters of run(); they are
    # forwarded as config-id overrides. NOTE(review): presumably the configs
    # use model_name to name checkpoints/outputs -- verify against the YAMLs.
    run(
        bundle_root="./bundle",
        meta_file="./bundle/configs/metadata.json",
        config_file=config_files,
        logging_file="./bundle/configs/logging.conf",
        model_name=model_name,
    )


if __name__ == "__main__":
    main()