-
Notifications
You must be signed in to change notification settings - Fork 66
/
Copy pathrun_exp.py
49 lines (41 loc) · 1.16 KB
/
run_exp.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import os
import sys
import torch
import logging
import traceback
import numpy as np
from pprint import pprint
from runner import *
from utils.logger import setup_logging
from utils.arg_helper import parse_arguments, get_config
# Turn off tensor summarization ("...") so tensors printed or logged during
# the experiment are shown in full rather than truncated.
torch.set_printoptions(profile="full")
def main():
    """Run one experiment: parse args, seed RNGs, set up logging, train or test.

    The runner class is selected by name from ``config.runner``. When
    ``args.test`` is truthy, ``runner.test()`` is called; otherwise
    ``runner.train()``.

    Terminates the process via ``sys.exit``: status 0 when the experiment
    finishes, status 1 when it raises. On failure the traceback is written
    to the experiment log first.
    """
    args = parse_arguments()
    config = get_config(args.config_file)

    # Seed numpy, torch (CPU) and every CUDA device with the same value.
    np.random.seed(config.seed)
    torch.manual_seed(config.seed)
    torch.cuda.manual_seed_all(config.seed)
    # Only use the GPU if the config asks for it AND CUDA is actually present.
    config.use_gpu = config.use_gpu and torch.cuda.is_available()

    # log info
    log_file = os.path.join(config.save_dir,
                            "log_exp_{}.txt".format(config.run_id))
    logger = setup_logging(args.log_level, log_file)
    logger.info("Writing log file to {}".format(log_file))
    logger.info("Exp instance id = {}".format(config.run_id))
    logger.info("Exp comment = {}".format(args.comment))
    logger.info("Config =")
    print(">" * 80)
    pprint(config)
    print("<" * 80)

    # Run the experiment
    exit_code = 0
    try:
        # NOTE(review): eval() executes config.runner as arbitrary Python.
        # It is expected to name a runner class made available by
        # ``from runner import *``; only load trusted config files.
        runner = eval(config.runner)(config)
        if args.test:
            runner.test()
        else:
            runner.train()
    except Exception:
        # Keep the full traceback in the experiment log, but report the
        # failure through the exit status instead of masking it with 0.
        # (Catching Exception rather than using a bare except lets
        # KeyboardInterrupt/SystemExit propagate normally.)
        logger.error(traceback.format_exc())
        exit_code = 1
    sys.exit(exit_code)
if __name__ == '__main__':
    # Script entry point: only run the experiment when executed directly,
    # not when this module is imported.
    main()