-
Notifications
You must be signed in to change notification settings - Fork 7
/
test.py
50 lines (42 loc) · 1.36 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import torch
import os
import ray
from time import time
from runner_for_test import TestRunner
from config import config
from env import Env
# Evaluation settings, read once from the project configuration at import time.
cfg = config()
# model_path: directory that contains 'model_states.pth' (the checkpoint loaded by test()).
# device: taken from the config as well.
model_path, device = cfg.model_path, cfg.device
# Decoding strategy handed to TestRunner.
decode_type = 'greedy'
# Number of evaluation environments; each one is seeded with its loop index.
test_size = 500
def test():
    """Evaluate the saved model on ``test_size`` seeded environments.

    Loads ``model_states.pth`` from ``model_path`` into a ``TestRunner``
    configured with ``decode_type``, then for each seed ``0 .. test_size - 1``
    builds an ``Env`` and runs ``runner.sample`` under ``torch.no_grad()``.
    After every instance it prints the running average of the returned
    ``max_length`` and of the per-instance inference time; at the end it
    prints the final averages and the total inference time.
    """
    # Monotonic, high-resolution clock intended for measuring intervals
    # (time.time() can jump if the system clock is adjusted).
    from time import perf_counter

    runner = TestRunner(metaAgentID=0, cfg=cfg, decode_type=decode_type)
    # map_location places the checkpoint tensors on the configured device, so a
    # checkpoint saved on a GPU can still be loaded on a CPU-only machine.
    # NOTE(review): assumes cfg.device is a torch.device or device string.
    checkpoint = torch.load(os.path.join(model_path, 'model_states.pth'),
                            map_location=device)
    runner.model.load_state_dict(checkpoint['model'])
    # NOTE(review): runner.model is not switched to eval() here; confirm that
    # TestRunner does so if the model has dropout / batch-norm layers.

    sum_max_length = 0.0
    sum_time = 0.0
    # Kept at 0 so the final summary still prints when test_size == 0.
    average_max_length = 0
    average_time = 0
    for i in range(test_size):
        print(i)
        env = Env(cfg, seed=i)
        t1 = perf_counter()
        with torch.no_grad():
            # .item() is inside the timed region: it blocks until any pending
            # asynchronous (e.g. CUDA) work producing the result has finished,
            # so the measured time covers the whole sample.
            max_length = runner.sample(env).item()
        t = perf_counter() - t1

        sum_max_length += max_length
        sum_time += t
        n_done = i + 1
        average_max_length = sum_max_length / n_done
        average_time = sum_time / n_done
        print('average_max_length', average_max_length)
        print('average_time', average_time)

    print('average_max_length', average_max_length)
    print('average_time', average_time)
    print('sum_time', sum_time)
if __name__ == '__main__':
    # Script entry point: run the evaluation only when this file is executed
    # directly, not when it is imported.
    test()