apurvakokate
/
GraphNN-For-Learning-Dynamics-and-Generating-Policies-with-Explanations-using-Decision-Trees
Public
forked from josyulakrishna/GraphNN-For-Learning-Dynamics-and-Generating-Policies-with-Explanations-using-Decision-Trees
-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
101 lines (79 loc) · 3.13 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import torch
import numpy as np
def load_graph_features(G, action, state, delta_state, bs = 1, norm = True, noise = 0.003, std = None):
    """Copy positions, velocities, joint angles and actions into the features of ``G``.

    Each row of ``state`` is read as a 5-value prefix, then 18 position values
    (``state[:, 5:23]``, viewed as 6 nodes x 3), then 18 velocity values
    (``state[:, 23:41]``, viewed as 6 nodes x 3).  Node ``i`` receives
    ``[pos_i, vel_i]`` in its ``'feat'`` tensor (columns ``:3`` and ``3:``).
    Every edge receives ``[direction, joint_angle, action]`` for joint
    ``m = min(edge)``.

    Args:
        G: graph whose node/edge ``'feat'`` tensors are already allocated;
            they are written in place.
        action: per-joint actions; ``action[:, m]`` is used for joint ``m``.
        state: batch of state rows.  NOTE: the position/velocity slices are
            taken with ``view``, so the noise and centering applied below also
            modify this tensor in place.
        delta_state: when ``noise > 0`` the injected noise is subtracted from
            the matching slices of this tensor in place, so that
            ``state + delta_state`` is not changed by the noise.
        bs: unused; kept for backward compatibility.
        norm: if true, subtract the mean of the first two position components
            (taken over the 6 nodes) from every node's first two components.
        noise: scale of the Gaussian noise added to positions and velocities;
            ``noise <= 0`` disables it.
        std: per-feature noise scale; ``std[:, :3]`` scales position noise and
            ``std[:, 3:]`` scales velocity noise.  Required when ``noise > 0``.
            (Presumably shaped (6, 6) or (1, 6) -- confirm against callers.)

    Returns:
        ``G`` itself, updated in place.

    Raises:
        ValueError: if ``noise > 0`` and ``std`` is None.
    """
    if noise > 0 and std is None:
        # fail with a clear message instead of a TypeError on ``None[:, :3]``
        raise ValueError("std must be provided when noise > 0")

    pos = state[:, 5:5 + 18].view(-1, 6, 3)
    if noise > 0:
        # randn_like keeps the noise on pos's device/dtype (no forced .cuda())
        pos_noise = torch.randn_like(pos) * noise * std[:, :3]
        pos += pos_noise
        # cancel the injected noise in the delta so state + delta is preserved
        delta_state[:, 5:5 + 18] -= pos_noise.view(-1, 18)

    # difference of the third component between consecutive nodes, shifted by
    # 2*pi when it falls outside [-pi, pi]
    joints = pos[:, 1:, -1] - pos[:, :-1, -1]
    joints[joints > np.pi] -= np.pi * 2
    joints[joints < -np.pi] += np.pi * 2

    if norm:
        center_pos = torch.mean(pos[:, :, :2], dim = 1, keepdim = True)
        pos[:, :, :2] -= center_pos

    vel = state[:, 5 + 18:5 + 36].view(-1, 6, 3)
    if noise > 0:
        vel_noise = torch.randn_like(vel) * noise * std[:, 3:]
        vel += vel_noise
        delta_state[:, 5 + 18:5 + 36] -= vel_noise.view(-1, 18)

    for node in G.nodes():
        G.nodes[node]['feat'][:, :3] = pos[:, node]
        G.nodes[node]['feat'][:, 3:] = vel[:, node]

    for edge in G.edges():
        src, dst = edge[0], edge[1]
        feat = G[src][dst]['feat']
        # -1 for edges pointing to a higher-numbered node, +1 otherwise
        feat[:, 0] = -1 if src < dst else 1
        joint = min(edge)
        feat[:, 1] = joints[:, joint]
        feat[:, 2] = action[:, joint]
    return G
def build_graph_loss(G, state):
    """Return the mean per-node squared error between ``G``'s node features and ``state``.

    Target positions (``state[:, 5:23]``) and velocities (``state[:, 23:41]``)
    are each viewed as 6 nodes x 3.  Before comparison, the third target
    position component is shifted by 2*pi when it lies outside [-pi, pi].
    For every node the loss adds the mean squared error on feature columns
    ``:3`` (positions) and ``3:`` (velocities).  The sum is then divided by
    the number of nodes.

    Args:
        G: graph whose node ``'feat'`` tensors hold the predictions.
        state: target state batch; it is not modified.

    Returns:
        Scalar loss tensor.

    Raises:
        ZeroDivisionError: if ``G`` has no nodes.
    """
    n_nodes = len(G)
    # clone so the in-place angle wrap does not write through the view into
    # the caller's ``state`` tensor
    pos = state[:, 5:5 + 18].view(-1, 6, 3).clone()
    pos[:, :, 2] -= (pos[:, :, 2] > np.pi).float() * np.pi * 2
    pos[:, :, 2] += (pos[:, :, 2] < -np.pi).float() * np.pi * 2
    vel = state[:, 5 + 18:5 + 36].view(-1, 6, 3)

    loss = 0
    for node in G.nodes():
        feat = G.nodes[node]['feat']
        loss += torch.mean((feat[:, :3] - pos[:, node]) ** 2)
        loss += torch.mean((feat[:, 3:] - vel[:, node]) ** 2)
    loss /= n_nodes
    return loss
def build_graph_loss2(G, H):
    """Return the mean per-node squared error between the node features of ``G`` and ``H``.

    For each node of ``G``, the mean squared difference over feature columns
    ``:3`` and the mean squared difference over columns ``3:`` are both added
    to a running total.  ``H`` is indexed with the same node keys.  The total
    is divided by the number of nodes in ``G``, so an empty ``G`` raises
    ZeroDivisionError.
    """
    def _terms(node):
        # yields the two per-node terms in the same order they are accumulated
        mine = G.nodes[node]['feat']
        theirs = H.nodes[node]['feat']
        yield torch.mean((mine[:, :3] - theirs[:, :3]) ** 2)
        yield torch.mean((mine[:, 3:] - theirs[:, 3:]) ** 2)

    node_count = len(G)
    total = sum(term for node in G.nodes() for term in _terms(node))
    return total / node_count
def init_graph_features(G, graph_feat_size, node_feat_size, edge_feat_size, bs=1, cuda=False):
    """Attach zero-filled ``'feat'`` tensors to ``G``, each of its nodes and each of its edges.

    The graph tensor has shape ``(bs, graph_feat_size)``, node tensors have
    shape ``(bs, node_feat_size)`` and edge tensors have shape
    ``(bs, edge_feat_size)``.  When ``cuda`` is truthy every tensor is moved
    with ``.cuda()``.  ``G`` is modified in place and nothing is returned.
    """
    def _blank(width):
        buf = torch.zeros(bs, width)
        return buf.cuda() if cuda else buf

    G.graph['feat'] = _blank(graph_feat_size)
    for node in G.nodes():
        G.nodes[node]['feat'] = _blank(node_feat_size)
    for src, dst, *_ in G.edges():
        G[src][dst]['feat'] = _blank(edge_feat_size)
def detach(G):
    """Replace every ``'feat'`` tensor on ``G`` with its ``.detach()``-ed version.

    The graph-level tensor is processed first, then each node's tensor, then
    each edge's tensor.  The tensors are reassigned in place in the
    attribute dicts.

    Returns:
        ``G`` itself.
    """
    attr_dicts = [G.graph]
    attr_dicts.extend(G.nodes[node] for node in G.nodes())
    attr_dicts.extend(G[src][dst] for src, dst, *_ in G.edges())
    for attrs in attr_dicts:
        attrs['feat'] = attrs['feat'].detach()
    return G