-
Notifications
You must be signed in to change notification settings - Fork 3
/
main.py
127 lines (104 loc) · 4.4 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import argparse
import warnings
import torch
# point robot
from config.pointrobot import \
args_pointrobot_sparse_hyperx, args_pointrobot_sparse_varibad
# multi-step gridworld
from config.rooms import \
args_room_hyperx, args_room_varibad, args_room_varibad_x_state
# sparse ant-goal
from config.sparse_ant_goal import \
args_sparse_ant_goal_rl2, \
args_sparse_ant_goal_humplik, args_sparse_ant_goal_hyperx, args_sparse_ant_goal_varibad
# sparse cheetah-dir environments
from config.sparse_cheetah_dir import \
cds_belief_oracle, cds_varibad, cds_hyperx, cds_rl2, cds_humplik
# mountain treasure
from config.treasure_hunt import \
args_treasure_varibad, args_treasure_hyperx, \
args_treasure_varibad_x_state, args_treasure_rl2, args_treasure_humplik
from learner import Learner
from metalearner import MetaLearner
# Module-level torch device: the first CUDA device when CUDA is available, else the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
def _load_config_args(env, rest_args):
    """Return the experiment arguments produced by the config named ``env``.

    Args:
        env: value of ``--env-type``; selects one of the imported config modules.
        rest_args: command-line arguments not consumed by the top-level parser,
            forwarded unchanged to the selected module's ``get_args``.

    Returns:
        Whatever the selected config module's ``get_args(rest_args)`` returns.

    Raises:
        NotImplementedError: if ``env`` does not name a known config.
    """
    configs = {
        # --- Mountain Treasure ---
        'treasure_hunt_varibad': args_treasure_varibad,
        'treasure_hunt_hyperx': args_treasure_hyperx,
        'treasure_hunt_varibad_x_state': args_treasure_varibad_x_state,
        'treasure_hunt_rl2': args_treasure_rl2,
        'treasure_hunt_humplik': args_treasure_humplik,
        # --- Multi-Stage GridWorld ---
        'room_varibad': args_room_varibad,
        'room_varibad_x_state': args_room_varibad_x_state,
        'room_hyperx': args_room_hyperx,
        # --- Sparse MuJoCo Half Cheetah ---
        'cds_belief_oracle': cds_belief_oracle,
        'cds_varibad': cds_varibad,
        'cds_hyperx': cds_hyperx,
        'cds_rl2': cds_rl2,
        'cds_humplik': cds_humplik,
        # --- Sparse MuJoCo Ant Goal ---
        'sparse_ant_goal_rl2': args_sparse_ant_goal_rl2,
        'sparse_ant_goal_varibad': args_sparse_ant_goal_varibad,
        'sparse_ant_goal_humplik': args_sparse_ant_goal_humplik,
        'sparse_ant_goal_hyperx': args_sparse_ant_goal_hyperx,
        # --- Sparse Point Robot ---
        'pointrobot_sparse_varibad': args_pointrobot_sparse_varibad,
        'pointrobot_sparse_hyperx': args_pointrobot_sparse_hyperx,
    }
    try:
        config = configs[env]
    except KeyError:
        # Same exception type as before, but tell the user what went wrong.
        raise NotImplementedError(
            'Unknown --env-type {!r}. Available: {}'.format(env, ', '.join(sorted(configs)))
        ) from None
    return config.get_args(rest_args)


def _check_deterministic_execution(args):
    """Report on, and reject incompatible settings for, deterministic execution.

    Does nothing unless ``args.deterministic_execution`` is truthy. This function
    only prints/warns; it does not itself configure cuDNN or seeding.

    Raises:
        RuntimeError: if deterministic execution is requested with
            ``args.num_processes > 1``.
    """
    if not args.deterministic_execution:
        return
    print('Invoking deterministic code execution.')
    if torch.backends.cudnn.enabled:
        warnings.warn('Running with deterministic CUDNN.')
    if args.num_processes > 1:
        # Note the trailing space: the two literals are implicitly concatenated.
        raise RuntimeError('If you want fully deterministic code, run it with num_processes=1. '
                           'Warning: This will slow things down.')


def _finalize_args(args):
    """Set argument fields that are derived from other arguments (mutates ``args``)."""
    # Truthy if any of the individual exploration-bonus flags is truthy.
    args.add_exploration_bonus = (args.exploration_bonus_hyperstate or
                                  args.exploration_bonus_state or
                                  args.exploration_bonus_belief or
                                  args.exploration_bonus_vae_error)
    # Not every config defines `disable_decoder`, hence getattr with a default.
    if getattr(args, 'disable_decoder', False):
        args.decode_reward = False
        args.decode_state = False
        args.decode_task = False


def main():
    """Parse the command line, build the experiment config, and train once per seed.

    ``--env-type`` (default ``'room_hyperx'``) selects the config module; all
    remaining command-line arguments are forwarded to that module's ``get_args``.

    Raises:
        NotImplementedError: for an unknown ``--env-type``.
        RuntimeError: if deterministic execution is combined with ``num_processes > 1``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--env-type', default='room_hyperx')
    args, rest_args = parser.parse_known_args()

    args = _load_config_args(args.env_type, rest_args)
    _check_deterministic_execution(args)
    _finalize_args(args)

    # `seed` is either a single int or, otherwise, an iterable of seeds.
    seed_list = [args.seed] if isinstance(args.seed, int) else args.seed
    for seed in seed_list:
        # The same args object is reused for every seed, so per-run fields are reset here.
        args.seed = seed
        args.action_space = None
        # start training
        if args.disable_metalearner:
            learner = Learner(args)
        else:
            learner = MetaLearner(args)
        learner.train()
# Entry point: train only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()