diff --git a/parse_config.py b/parse_config.py index 309f153e..7cbc5f57 100644 --- a/parse_config.py +++ b/parse_config.py @@ -69,8 +69,9 @@ def from_args(cls, args, options=''): cfg_fname = Path(args.config) config = read_json(cfg_fname) - if args.config and resume: + if args.config and resume and args.fine_tune: # update new config for fine-tuning + # if not fine tune, the saved config will remain as it is config.update(read_json(args.config)) # parse custom cli options into dictionary diff --git a/test.py b/test.py index fc084fac..c639e0ea 100644 --- a/test.py +++ b/test.py @@ -72,6 +72,8 @@ def main(config): args = argparse.ArgumentParser(description='PyTorch Template') args.add_argument('-c', '--config', default=None, type=str, help='config file path (default: None)') + args.add_argument('-ft', '--fine-tune', default=False, type=bool, + help='fine tune pretrained model or not. If True, the saved config will be overridden by outer config') args.add_argument('-r', '--resume', default=None, type=str, help='path to latest checkpoint (default: None)') args.add_argument('-d', '--device', default=None, type=str, diff --git a/train.py b/train.py index a43f6c47..26012e88 100644 --- a/train.py +++ b/train.py @@ -58,6 +58,8 @@ def main(config): args = argparse.ArgumentParser(description='PyTorch Template') args.add_argument('-c', '--config', default=None, type=str, help='config file path (default: None)') + args.add_argument('-ft', '--fine-tune', default=False, type=bool, + help='fine tune pretrained model or not. If True, the saved config will be overridden by outer config') args.add_argument('-r', '--resume', default=None, type=str, help='path to latest checkpoint (default: None)') args.add_argument('-d', '--device', default=None, type=str,