forked from chaovven/PyRL
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
executable file
·100 lines (79 loc) · 3.31 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import numpy as np
import datetime
import os
from os.path import dirname, abspath
from copy import deepcopy
from sacred import Experiment, SETTINGS
from sacred.observers import FileStorageObserver
import sys
import torch as th
from run import run
import yaml
import gym
from gym.spaces import Box, Discrete
from envs import REGISTRY as ENV_REGISTRY
# Module-level sacred experiment named "pyrl". `my_main` is registered on it
# with `@ex.main`; the script entry point attaches an observer and the merged
# config to it before calling `ex.run_commandline`.
ex = Experiment("pyrl")
@ex.main
def my_main(_run, _config, _log):
    """Seed the RNGs, build the environment, and hand everything to ``run``.

    The incoming config is copied, then extended with values read from the
    environment (``state_dim``, ``ep_limit``, ``discrete``, ``action_dim`` and,
    for ``Box`` action spaces, ``max_action`` / ``min_action``) and with the
    output directory ``tb_path``.

    Args:
        _run: Forwarded unchanged to ``run`` (presumably the sacred run
            object -- confirm against the sacred docs).
        _config: Experiment configuration; must provide at least ``seed``,
            ``env``, ``learner``, ``name``, ``lr`` and ``time``, plus
            ``critic_lr`` unless the learner is ``dqn``.
        _log: Forwarded unchanged to ``run``.
    """
    # Work on a private copy so the keys added below never touch the
    # object that was passed in.
    config = deepcopy(_config)

    # Setting the random seed throughout the modules
    seed = config["seed"]
    np.random.seed(seed)
    th.manual_seed(seed)

    # customize your own environment here, stored in ./envs
    env_name = config['env']
    if env_name in ENV_REGISTRY:
        env = ENV_REGISTRY[env_name]()
    else:
        env = gym.make(env_name)
    # NOTE(review): the original indentation was lost; seeding is assumed to
    # apply to both registry and gym environments -- confirm.
    env.seed(seed)

    # add info about env
    config['state_dim'] = env.observation_space.shape[0]
    config['ep_limit'] = env._max_episode_steps
    action_space = env.action_space
    if isinstance(action_space, Box):
        config['discrete'] = False
        config['action_dim'] = action_space.shape[0]
        config['max_action'] = action_space.high
        config['min_action'] = action_space.low
    elif isinstance(action_space, Discrete):
        config['discrete'] = True
        config['action_dim'] = action_space.n

    # fields that appear in the event filename
    if config['learner'] in ('dqn',):
        # no critic learning rate in the token for dqn
        critic_lr = ''
    else:
        critic_lr = '__clr={}'.format(config["critic_lr"])
    token_fields = (
        config["env"],
        config["learner"],
        config["name"],
        config["lr"],
        critic_lr,
        config["time"],
    )
    unique_token = '{}-{}_{}_lr={}{}_{}'.format(*token_fields)

    results_root = os.path.join(dirname(abspath(__file__)), "results")
    config['tb_path'] = os.path.join(results_root, config['env'], unique_token)

    run(_run, config, _log, env)
def _get_config(params, arg_name, subfolder):
config_name = None
for _i, _v in enumerate(params):
if _v.split("=")[0] == arg_name:
config_name = _v.split("=")[1]
del params[_i]
break
if config_name is not None:
with open(os.path.join(os.path.dirname(__file__), "config", subfolder, "{}.yaml".format(config_name)),
"r") as f:
try:
config_dict = yaml.load(f)
except yaml.YAMLError as exc:
assert False, "{}.yaml error: {}".format(config_name, exc)
return config_dict
if __name__ == '__main__':
    # Work on a copy of argv: _get_config() removes the "--alg=..." entry so
    # that only the remaining arguments reach sacred's command-line parser.
    params = deepcopy(sys.argv)

    # load default.yaml
    default_path = os.path.join(os.path.dirname(__file__), "config", "default.yaml")
    with open(default_path, "r") as f:
        try:
            # safe_load: yaml.load() without an explicit Loader is rejected by
            # PyYAML >= 6. An empty file parses to None, hence the `or {}`.
            config_dict = yaml.safe_load(f) or {}
        except yaml.YAMLError as exc:
            # Raised explicitly (not `assert False`) so the check survives
            # `python -O`; the exception type is unchanged.
            raise AssertionError("default.yaml error: {}".format(exc)) from exc

    # Load algorithm configs
    alg_config = _get_config(params, "--alg", 'algs')
    # _get_config returns None when no "--alg=<name>" argument was given;
    # in that case keep the defaults instead of crashing in dict.update().
    if alg_config is not None:
        config_dict.update(alg_config)

    # The timestamp names both the sacred output directory and, via the
    # config, the per-run results path built in my_main.
    config_dict['time'] = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    ex.observers.append(
        FileStorageObserver(os.path.join(dirname(abspath(__file__)), "results/sacred", config_dict['time'])))

    # add all the config to sacred
    ex.add_config(config_dict)

    ex.run_commandline(params)