evaluate_stat_runs.py
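"""Evaluate a sequence of saved MADDPG checkpoints ('stat runs').

For every checkpoint episode number listed in `rrange`, roll out
`config.n_episodes` episodes with exploration disabled, record each
episode's mean per-step global reward, and pickle the collected lists
to <model dir>/stat_runs.
"""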
import argparse
import pickle as pkl
import time
from pathlib import Path

import imageio
import numpy as np
import torch
from torch.autograd import Variable

from algorithms.maddpg import MADDPG
from utils.make_env import make_env

def run(config):
    original_model_path = (Path('./models') / config.env_id / config.model_name /
                           ('run%i' % config.run_num))
    # if config.incremental is not None:
    #     model_path = model_path / 'incremental' / ('model_ep%i.pt' %
    #                                                config.incremental)
    # else:
    #     model_path = model_path / 'model.pt'
    if config.save_gifs:
        gif_path = original_model_path.parent / 'gifs'
        gif_path.mkdir(exist_ok=True)

    # Checkpoint episode numbers saved in the 'incremental' folder, one per stat run
    rrange = [1, 1001, 2001, 3001, 4001, 5001, 6001, 7001,
              8001, 9001]
    stat_run_all_models = []
    for r in rrange:
        print("Model: %i" % r)
        model_path = original_model_path / 'incremental' / ('model_ep%i.pt' % r)
        maddpg = MADDPG.init_from_save(model_path)
        env = make_env(config.env_id, discrete_action=maddpg.discrete_action)
        maddpg.prep_rollouts(device='cpu')
        ifi = 1 / config.fps  # inter-frame interval
        stat_return_list = []
        for ep_i in range(config.n_episodes):
            print("Episode %i of %i" % (ep_i + 1, config.n_episodes))
            obs = env.reset()
            if config.save_gifs:
                frames = []
                frames.append(env.render('rgb_array')[0])
            # env.render('human')
            episode_reward = 0
            for t_i in range(config.episode_length):
                calc_start = time.time()
                # rearrange observations to be per agent, and convert to torch Variable
                torch_obs = [Variable(torch.Tensor(obs[i]).view(1, -1),
                                      requires_grad=False)
                             for i in range(maddpg.nagents)]
                # get actions as torch Variables
                torch_actions = maddpg.step(torch_obs, explore=False)
                # convert actions to numpy arrays
                actions = [ac.data.numpy().flatten() for ac in torch_actions]
                obs, rewards, dones, infos = env.step(actions)
                # accumulate the global reward (shared across agents)
                episode_reward += rewards[0][0]
                if config.save_gifs:
                    frames.append(env.render('rgb_array')[0])
                calc_end = time.time()
                elapsed = calc_end - calc_start
                if elapsed < ifi:
                    time.sleep(ifi - elapsed)
                # env.render('human')
            if config.save_gifs:
                gif_num = 0
                while (gif_path / ('%i_%i.gif' % (gif_num, ep_i))).exists():
                    gif_num += 1
                imageio.mimsave(str(gif_path / ('%i_%i.gif' % (gif_num, ep_i))),
                                frames, duration=ifi)
            # end of one episode (one stat run): record its mean per-step return
            stat_return_list.append(episode_reward / config.episode_length)
        # end of model: collect this checkpoint's returns
        stat_run_all_models.append(stat_return_list)
        env.close()
    with open(str(original_model_path) + "/stat_runs", "wb") as pickling_on:
        pkl.dump(stat_run_all_models, pickling_on)
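
# Not part of the original evaluation flow: a minimal sketch of how the
# pickled 'stat_runs' file can be read back. It holds one list of per-episode
# mean returns per checkpoint in rrange; this hypothetical helper reports the
# mean and standard deviation for each checkpoint.
def summarize_stat_runs(model_path):
    with open(str(model_path) + "/stat_runs", "rb") as f:
        stat_run_all_models = pkl.load(f)
    for i, returns in enumerate(stat_run_all_models):
        print("Checkpoint %i: mean return %.3f (std %.3f)"
              % (i, np.mean(returns), np.std(returns)))
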
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--env_id", default="simple_spread",
                        help="Name of environment")
    parser.add_argument("--model_name", default="Exp",
                        help="Name of model")
    parser.add_argument("--run_num", default=68, type=int)
    parser.add_argument("--save_gifs", action="store_true",
                        help="Saves gif of each episode into model directory")
    parser.add_argument("--incremental", default=None, type=int,
                        help="Load incremental policy from given episode "
                             "rather than final policy")
    parser.add_argument("--n_episodes", default=15, type=int)
    parser.add_argument("--episode_length", default=25, type=int)
    parser.add_argument("--fps", default=30, type=int)
    config = parser.parse_args()

    run(config)
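
# Example invocation with the defaults above:
#   python evaluate_stat_runs.py --env_id simple_spread --model_name Exp --run_num 68
# loads ./models/simple_spread/Exp/run68/incremental/model_ep{1,...,9001}.pt and
# writes the pickled returns to ./models/simple_spread/Exp/run68/stat_runs.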