-
Notifications
You must be signed in to change notification settings - Fork 67
/
eval.py
157 lines (118 loc) · 4.4 KB
/
eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
import argparse
import torch
import utils
import os
import pickle
from torch.utils import data
import numpy as np
from collections import defaultdict
import modules
# Restrict cuDNN to deterministic algorithms so repeated runs are comparable.
torch.backends.cudnn.deterministic = True

# Command-line options of the evaluation script, registered in display order.
parser = argparse.ArgumentParser()
for option_flag, option_spec in (
    ('--save-folder', {'type': str,
                       'default': 'checkpoints',
                       'help': 'Path to checkpoints.'}),
    ('--num-steps', {'type': int,
                     'default': 1,
                     'help': 'Number of prediction steps to evaluate.'}),
    ('--dataset', {'type': str,
                   'default': 'data/shapes_eval.h5',
                   'help': 'Dataset string.'}),
    ('--no-cuda', {'action': 'store_true',
                   'default': False,
                   'help': 'Disable CUDA training.'}),
):
    parser.add_argument(option_flag, **option_spec)

# Options given on the command line for this evaluation run.
args_eval = parser.parse_args()
# Paths of the pickled metadata and the model weights inside the save folder.
meta_file = os.path.join(args_eval.save_folder, 'metadata.pkl')
model_file = os.path.join(args_eval.save_folder, 'model.pt')

# NOTE(security): unpickling can execute arbitrary code; only point
# --save-folder at checkpoints from a trusted source.
# The context manager closes the file handle once the metadata is read.
with open(meta_file, 'rb') as meta_fh:
    args = pickle.load(meta_fh)['args']

# Override the stored settings with the values used for evaluation.
args.cuda = not args_eval.no_cuda and torch.cuda.is_available()
args.batch_size = 100
args.dataset = args_eval.dataset
args.seed = 0

# Seed every random number generator in use with the same fixed seed.
np.random.seed(args.seed)
torch.manual_seed(args.seed)
if args.cuda:
    torch.cuda.manual_seed(args.seed)

device = torch.device('cuda' if args.cuda else 'cpu')
# Evaluation data: paths whose length equals the number of prediction steps.
dataset = utils.PathDataset(
    hdf5_file=args.dataset, path_length=args_eval.num_steps)
eval_loader = data.DataLoader(
    dataset, batch_size=args.batch_size, shuffle=False, num_workers=4)

# Get data sample: the first element of a batch holds the observations.
# The built-in next()/iter() pair is used because DataLoader iterators in
# current PyTorch releases no longer provide a Python 2 style `.next()`.
obs = next(iter(eval_loader))[0]
# Shape of one observation: first step of the path, first sample in the batch.
input_shape = obs[0][0].size()

# Rebuild the model with the hyper-parameters stored in the metadata.
model = modules.ContrastiveSWM(
    embedding_dim=args.embedding_dim,
    hidden_dim=args.hidden_dim,
    action_dim=args.action_dim,
    input_dims=input_shape,
    num_objects=args.num_objects,
    sigma=args.sigma,
    hinge=args.hinge,
    ignore_action=args.ignore_action,
    copy_action=args.copy_action,
    encoder=args.encoder).to(device)

# map_location remaps the stored tensors onto the selected device, so weights
# saved from a GPU run also load on a CPU-only machine or with --no-cuda.
model.load_state_dict(torch.load(model_file, map_location=device))
model.eval()
# Ranking evaluation: every predicted state is compared with the encoded
# target states of all evaluated samples; Hits@k and the mean reciprocal
# rank (MRR) of the true target are reported.

# Cut-offs k for Hits@k; extend the list (e.g. [1, 5, 10]) for more cut-offs.
topk = [1]
hits_at = defaultdict(int)
num_samples = 0
rr_sum = 0

pred_states = []
next_states = []

with torch.no_grad():

    for data_batch in eval_loader:
        data_batch = [[t.to(device) for t in tensor]
                      for tensor in data_batch]
        observations, actions = data_batch

        # Batches that are not completely filled are left out.
        if observations[0].size(0) != args.batch_size:
            continue

        obs = observations[0]
        next_obs = observations[-1]

        state = model.obj_encoder(model.obj_extractor(obs))
        next_state = model.obj_encoder(model.obj_extractor(next_obs))

        # Roll the transition model forward, adding each predicted change
        # to the current state.
        pred_state = state
        for i in range(args_eval.num_steps):
            pred_trans = model.transition_model(pred_state, actions[i])
            pred_state = pred_state + pred_trans

        pred_states.append(pred_state.cpu())
        next_states.append(next_state.cpu())

    # Fail with a clear message instead of an opaque torch.cat error when
    # the dataset cannot fill a single batch.
    if not pred_states:
        raise RuntimeError(
            'No full batch of size {} could be drawn from {}; '
            'nothing to evaluate.'.format(args.batch_size, args.dataset))

    pred_state_cat = torch.cat(pred_states, dim=0)
    next_state_cat = torch.cat(next_states, dim=0)

    full_size = pred_state_cat.size(0)

    # Flatten object/feature dimensions
    next_state_flat = next_state_cat.view(full_size, -1)
    pred_state_flat = pred_state_cat.view(full_size, -1)

    dist_matrix = utils.pairwise_distance_matrix(
        next_state_flat, pred_state_flat)
    # Prepend the diagonal (distance of each sample to its own counterpart)
    # as column 0, so that index 0 always denotes the correct match.
    dist_matrix_diag = torch.diag(dist_matrix).unsqueeze(-1)
    dist_matrix_augmented = torch.cat(
        [dist_matrix_diag, dist_matrix], dim=1)

    # A stable sort keeps tied entries in column order, so column 0 ranks
    # ahead of any candidate at exactly the same distance (including its own
    # copy inside the matrix).
    indices = torch.from_numpy(
        np.argsort(dist_matrix_augmented.numpy(), axis=1, kind='stable')
    ).long()

    # Count only the batches that were actually evaluated.
    print('Processed {} batches of size {}'.format(
        len(pred_states), args.batch_size))

    labels = torch.zeros(
        indices.size(0), device=indices.device,
        dtype=torch.int64).unsqueeze(-1)

    num_samples += full_size
    print('Size of current topk evaluation batch: {}'.format(
        full_size))

    # Hits@k: the correct match (index 0) is among the k nearest entries.
    for k in topk:
        match = indices[:, :k] == labels
        num_matches = match.sum()
        hits_at[k] += num_matches.item()

    # Position of index 0 in each sorted row is the zero-based rank.
    match = indices == labels
    _, ranks = match.max(1)

    reciprocal_ranks = torch.reciprocal(ranks.double() + 1)
    # Accumulate a plain float so the MRR prints as a number, like Hits@k.
    rr_sum += reciprocal_ranks.sum().item()

    # Drop the collected states; they are no longer needed.
    pred_states = []
    next_states = []

for k in topk:
    print('Hits @ {}: {}'.format(k, hits_at[k] / float(num_samples)))

print('MRR: {}'.format(rr_sum / float(num_samples)))