import logging
import os
from datetime import datetime
from functools import reduce, partial
from operator import getitem
from pathlib import Path
from logger import setup_logging
from utils import get_project_root, read_json, write_json


class ConfigParser:
    def __init__(self, config, resume=None, modification=None, mode=None, use_tune=False, run_id=None,
                 create_save_log_dir=True):
"""
Class to parse configuration json file. Handles hyperparameters for training, initializations of modules,
checkpoint saving and logging module.
Args:
config (dict): Configuration dictionary.
resume (str, optional): Path to the checkpoint to resume training/testing from. Defaults to None.
modification (dict, optional): Modification dictionary to update the configuration. Defaults to None.
mode (str, optional): Mode of operation ('train' or 'test'). Defaults to None.
use_tune (bool, optional): Whether to use tuning. Defaults to False.
run_id (str, optional): Identifier for the run. Defaults to None - in this case, timestamp is used
create_save_log_dir (bool, optional): Whether to create directories for saving logs and outputs.
Defaults to True.
"""
# load config file and apply modification
self._config = _update_config(config, modification)
self._resume = resume
self._use_tune = use_tune
# set save_dir where trained model and log will be saved.
save_dir = Path(self.config['trainer']['save_dir'])
exper_name = self.config['name']
if run_id is None: # use timestamp as default run-id
run_id = datetime.now().strftime(r'%m%d_%H%M%S')
details = "_{}_bs{}{}".format("ml" if self.config['arch']['args']['multi_label_training'] else "sl",
self.config['data_loader']['args']['batch_size'],
self.config['run_details'])
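        # e.g., run_id '0412_153059' with multi-label training, batch size 64 and empty
        # run_details yields the run folder name '0412_153059_ml_bs64'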
if create_save_log_dir:
# Training
            if mode is None or mode == 'train':
                self._save_dir = Path(
                    os.path.join(get_project_root(), save_dir / 'models' / exper_name / str(run_id + details)))
                self._log_dir = Path(
                    os.path.join(get_project_root(), save_dir / 'log' / exper_name / str(run_id + details)))
            else:
                self._save_dir = None
                self._log_dir = None
# Testing
            if mode is not None and mode == 'test':
                assert resume is not None, "checkpoint must be provided for testing"
                assert 'valid' in self.config['data_loader']['test_dir'].lower() or \
                       'test' in self.config['data_loader']['test_dir'].lower(), \
                    "Path should point to a validation or test dir"
                self._test_output_dir = Path(os.path.join(resume.parent, 'test_output')) if \
                    'test' in self.config['data_loader']['test_dir'].lower() \
                    else Path(os.path.join(resume.parent, 'valid_output'))
else:
self._test_output_dir = None # For training not needed
# make directory for saving checkpoints and log and test outputs (if needed).
            exist_ok = run_id == ''  # only an explicitly empty run_id may reuse an existing directory
if self._save_dir is not None:
self.save_dir.mkdir(parents=True, exist_ok=exist_ok)
if self._log_dir is not None:
self.log_dir.mkdir(parents=True, exist_ok=exist_ok)
if self.test_output_dir is not None:
self.test_output_dir.mkdir(parents=True, exist_ok=True)
else:
self._save_dir = None
self._log_dir = None
self._test_output_dir = None
# save updated config file to the checkpoint dir
if self._save_dir is not None:
write_json(self.config, self.save_dir / 'config.json')
        # configure the logging module (formerly guarded by `if not self._use_tune:`;
        # when tuning is active, logging may be set up again within the train method)
        setup_logging(self.log_dir)
self.log_levels = {
0: logging.WARNING,
1: logging.INFO,
2: logging.DEBUG
}
self._do_some_sanity_checks()

    @classmethod
    def from_args(cls, args, mode=None, options=()):
        """
        Initialize this class from CLI arguments. Used in train and test scripts.
        `options` is expected to be an iterable of objects with `flags`, `type` and
        `target` attributes (e.g., a namedtuple per custom option).
        """
for opt in options:
args.add_argument(*opt.flags, default=None, type=opt.type)
if not isinstance(args, tuple):
args, unknown = args.parse_known_args()
if args.device is not None:
os.environ["CUDA_VISIBLE_DEVICES"] = args.device
if args.resume is not None:
resume = Path(args.resume)
if args.tune:
cfg_path = resume.parent.parent / 'config.json'
else:
cfg_path = resume.parent / 'config.json'
else:
msg_no_cfg = "Configuration file need to be specified. Add '-c config.json', for example."
assert args.config is not None, msg_no_cfg
resume = None
cfg_path = Path(os.path.join(get_project_root(), args.config))
config = read_json(cfg_path)
if args.config and resume:
# update new config for fine-tuning
config.update(read_json(args.config))
if args.config and args.seed:
# Append the manual set seed to the config
config['SEED'] = args.seed
# parse custom cli options into dictionary
modification = {opt.target: getattr(args, _get_opt_name(opt.flags)) for opt in options}
return cls(config, resume, modification, mode, args.tune)
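    # Minimal usage sketch, assuming `args` is an argparse parser that already defines
    # config/resume/device/tune/seed; the option spec below is hypothetical:
    #   CustomArgs = collections.namedtuple('CustomArgs', 'flags type target')
    #   options = [CustomArgs(['--lr', '--learning_rate'], type=float,
    #                         target='optimizer;args;lr')]
    #   config = ConfigParser.from_args(args, mode='train', options=options)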

    def init_obj(self, name, module, *args, **kwargs):
"""
Finds a function handle with the name given as 'type' in config, and returns the
instance initialized with corresponding arguments given.
`object = config.init_obj('name', module, a, b=1)`
is equivalent to
`object = module.name(a, b=1)`
"""
module_name = self[name]['type'] # e.g., MnistDataLoader
module_args = dict(self[name]['args']) # e.g., {'data_dir': 'data/', 'batch_size': 4, ..., 'seq_len': 72000}
# assert all([k not in module_args for k in kwargs]), 'Overwriting kwargs given in config file is not allowed'
module_args.update(kwargs)
# Gets a named attribute (here named "module_name") from an object (here from the given data-loader module)
return getattr(module, module_name)(*args, **module_args)
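    # e.g., with config['data_loader']['type'] == 'MnistDataLoader' defined in a module
    # `module_data` (hypothetical name): config.init_obj('data_loader', module_data)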

    def init_ftn(self, name, module, *args, **kwargs):
"""
Finds a function handle with the name given as 'type' in config, and returns the
function with given arguments fixed with functools.partial.
`function = config.init_ftn('name', module, a, b=1)`
is equivalent to
`function = lambda *args, **kwargs: module.name(a, *args, b=1, **kwargs)`.
"""
module_name = self[name]['type']
module_args = dict(self[name]['args'])
        assert all(k not in module_args for k in kwargs), 'Overwriting kwargs given in config file is not allowed'
module_args.update(kwargs)
return partial(getattr(module, module_name), *args, **module_args)
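    # e.g., assuming a module `module_loss` exposing the function named in
    # config['loss']['type']: loss_fn = config.init_ftn('loss', module_loss)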

    def __getitem__(self, name):
"""Access items like ordinary dict."""
return self.config[name]

    def get_logger(self, name, verbosity=2):
        msg_verbosity = 'verbosity option {} is invalid. Valid options are {}.'.format(
            verbosity, self.log_levels.keys())
assert verbosity in self.log_levels, msg_verbosity
logger = logging.getLogger(name)
logger.setLevel(self.log_levels[verbosity])
return logger
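    # e.g., logger = config.get_logger('train')  # logs at DEBUG level by default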

    # managed attributes (config and use_tune are read-only; the dirs and resume have setters)
    @property
    def config(self):
        return self._config

    @property
    def resume(self):
        return self._resume

    @resume.setter
    def resume(self, value):
        self._resume = value

    @property
    def save_dir(self):
        return self._save_dir

    @save_dir.setter
    def save_dir(self, value):
        self._save_dir = value

    @property
    def log_dir(self):
        return self._log_dir

    @log_dir.setter
    def log_dir(self, value):
        self._log_dir = value

    @property
    def test_output_dir(self):
        return self._test_output_dir

    @test_output_dir.setter
    def test_output_dir(self, value):
        self._test_output_dir = value

    @property
    def use_tune(self):
        return self._use_tune

    def _do_some_sanity_checks(self):
        # Each loss type implies a fixed (multi_label_training, apply_final_activation)
        # combination; anything else is a misconfiguration.
        loss_type = self.config["loss"]["type"]
        multi_label = self.config["arch"]["args"]["multi_label_training"]
        final_activation = self.config["arch"]["args"]["apply_final_activation"]
        msg = "The used loss does not fit to the rest of the configuration"
        if loss_type in ("BCE_with_logits", "balanced_BCE_with_logits",
                         "focal_binary_cross_entropy_with_logits"):
            assert multi_label and not final_activation, msg
        elif loss_type == "BCE":
            assert multi_label and final_activation, msg
        elif loss_type == "nll_loss":
            assert not multi_label and final_activation, msg
        elif loss_type in ("cross_entropy_loss", "balanced_cross_entropy"):
            assert not multi_label and not final_activation, msg


# helper functions to update config dict with custom cli options
def _update_config(config, modification):
    if modification is None:
        return config

    for k, v in modification.items():
        if v is not None:
            if isinstance(v, Path):
                v = str(v)
            _set_by_path(config, k, v)
    return config


def _get_opt_name(flags):
    for flg in flags:
        if flg.startswith('--'):
            return flg.replace('--', '')
    return flags[0].replace('--', '')


def _set_by_path(tree, keys, value):
    """Set a value in a nested object in tree by a ';'-separated sequence of keys."""
    keys = keys.split(';')
    _get_by_path(tree, keys[:-1])[keys[-1]] = value


def _get_by_path(tree, keys):
    """Access a nested object in tree by sequence of keys."""
    return reduce(getitem, keys, tree)
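

# Minimal sketch of how the ';'-separated modification targets resolve (the toy
# dict below is illustrative only, not a real experiment config):
if __name__ == '__main__':
    toy_cfg = {'optimizer': {'args': {'lr': 0.001}}}
    _set_by_path(toy_cfg, 'optimizer;args;lr', 0.01)
    assert _get_by_path(toy_cfg, ['optimizer', 'args'])['lr'] == 0.01
    # _update_config applies a whole modification dict at once
    print(_update_config(toy_cfg, {'optimizer;args;lr': 0.1}))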