RedGNNT.py

import torch
import torch.nn as nn

from ..utils.utils import scatter
from . import BaseModel, register_model
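
# RedGNNT implements RED-GNN-style reasoning (the trailing 'T' presumably
# marks the transductive setting): starting from the query entities, it grows
# a relational subgraph layer by layer via loader.get_neighbors, updates node
# states with attention-weighted message passing (RedGNNLayer), fuses states
# across layers with a GRU, and scores every reached entity with a linear head.
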
@register_model('RedGNNT')
class RedGNNT(BaseModel):

    @classmethod
    def build_model_from_args(cls, args, loader):
        return cls(args, loader)

    def __init__(self, args, loader):
        super(RedGNNT, self).__init__()
        self.device = args.device
        self.hidden_dim = args.hidden_dim
        self.attn_dim = args.attn_dim
        self.n_layer = args.n_layer
        self.loader = loader
        self.n_rel = self.loader.n_rel

        # map the configured activation name to a callable ('idd' = identity)
        acts = {'relu': nn.ReLU(), 'tanh': torch.tanh, 'idd': lambda x: x}
        self.act = acts[args.act]

        self.gnn_layers = nn.ModuleList([
            RedGNNLayer(self.hidden_dim, self.hidden_dim, self.attn_dim, self.n_rel, act=self.act)
            for _ in range(self.n_layer)
        ])
        self.dropout = nn.Dropout(args.dropout)
        self.W_final = nn.Linear(self.hidden_dim, 1, bias=False)  # scalar score per node
        self.gate = nn.GRU(self.hidden_dim, self.hidden_dim)      # fuses states across layers
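
    # forward performs n_layer rounds of "expand, propagate, gate": the node
    # set grows each hop, and old_nodes_new_idx maps last round's nodes into
    # the enlarged set so the GRU state h0 can be re-aligned before gating.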
    def forward(self, subs, rels, mode='train'):  # subs: source entities, rels: query relations
        n = len(subs)
        q_sub = torch.LongTensor(subs).to(self.device)
        q_rel = torch.LongTensor(rels).to(self.device)

        h0 = torch.zeros((1, n, self.hidden_dim)).to(self.device)  # GRU state: 1 * n * d
        # nodes holds (batch_idx, entity_id) pairs; start from the query entities
        nodes = torch.cat([torch.arange(n).unsqueeze(1).to(self.device), q_sub.unsqueeze(1)], 1)
        hidden = torch.zeros(n, self.hidden_dim).to(self.device)

        for i in range(self.n_layer):
            # expand the subgraph by one hop around the current node set
            nodes, edges, old_nodes_new_idx = self.loader.get_neighbors(nodes.data.cpu().numpy(), mode=mode)
            edges = edges.to(self.device)
            old_nodes_new_idx = old_nodes_new_idx.to(self.device)

            hidden = self.gnn_layers[i](q_sub, q_rel, hidden, edges, nodes.size(0), old_nodes_new_idx)
            # re-align the GRU state with the enlarged node set, then gate
            h0 = torch.zeros(1, nodes.size(0), hidden.size(1)).to(self.device).index_copy_(1, old_nodes_new_idx, h0)
            hidden = self.dropout(hidden)
            hidden, h0 = self.gate(hidden.unsqueeze(0), h0)
            hidden = hidden.squeeze(0)

        # score every reached node and scatter into a full (n, n_ent) matrix;
        # entities never reached keep score 0
        scores = self.W_final(hidden).squeeze(-1)
        scores_all = torch.zeros((n, self.loader.n_ent)).to(self.device)
        scores_all[nodes[:, 0], nodes[:, 1]] = scores
        return scores_all
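

# RedGNNLayer: one hop of relational message passing. For each sampled edge
# (s, r, o), the message h_s + h_r is weighted by an attention score alpha
# conditioned on the query relation, then summed per target node o.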
class RedGNNLayer(torch.nn.Module):

    def __init__(self, in_dim, out_dim, attn_dim, n_rel, act=lambda x: x):
        super(RedGNNLayer, self).__init__()
        self.n_rel = n_rel
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.attn_dim = attn_dim
        self.act = act

        # 2*n_rel + 1 relation embeddings: forward and inverse relations,
        # plus one extra id (typically the self-loop)
        self.rela_embed = nn.Embedding(2 * n_rel + 1, in_dim)
        self.Ws_attn = nn.Linear(in_dim, attn_dim, bias=False)
        self.Wr_attn = nn.Linear(in_dim, attn_dim, bias=False)
        self.Wqr_attn = nn.Linear(in_dim, attn_dim)
        self.w_alpha = nn.Linear(attn_dim, 1)
        self.W_h = nn.Linear(in_dim, out_dim, bias=False)
    def forward(self, q_sub, q_rel, hidden, edges, n_node, old_nodes_new_idx):
        # edges: [batch_idx, head, rela, tail, old_idx, new_idx]
        sub = edges[:, 4]   # source position in the previous layer's node list
        rel = edges[:, 2]
        obj = edges[:, 5]   # target position in the current layer's node list
        hs = hidden[sub]
        hr = self.rela_embed(rel)

        r_idx = edges[:, 0]
        h_qr = self.rela_embed(q_rel)[r_idx]  # query-relation embedding per edge

        # attention weight conditioned on source state, edge relation, and query relation
        message = hs + hr
        alpha = torch.sigmoid(self.w_alpha(nn.ReLU()(self.Ws_attn(hs) + self.Wr_attn(hr) + self.Wqr_attn(h_qr))))
        message = alpha * message

        # sum the attention-weighted messages per target node
        message_agg = scatter(message, index=obj, dim=0, dim_size=n_node, reduce='sum')
        hidden_new = self.act(self.W_h(message_agg))
        return hidden_new
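

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the model). It runs one RedGNNLayer on a
# toy subgraph to show the expected tensor shapes. The toy edge tensor, sizes,
# and seed are illustrative assumptions, not values from the real loader, and
# the block assumes this module is imported inside its package so that the
# relative imports above (scatter, BaseModel) resolve.
if __name__ == "__main__":
    torch.manual_seed(0)
    n_rel, hidden_dim, attn_dim, n_node = 3, 8, 4, 4

    layer = RedGNNLayer(hidden_dim, hidden_dim, attn_dim, n_rel, act=torch.relu)

    # 5 toy edges over 4 nodes; columns: [batch_idx, head, rela, tail, old_idx, new_idx]
    edges = torch.tensor([[0, 0, 1, 1, 0, 1],
                          [0, 1, 2, 2, 1, 2],
                          [1, 0, 0, 3, 0, 3],
                          [1, 3, 1, 2, 3, 2],
                          [0, 2, 2, 3, 2, 3]])
    hidden = torch.randn(n_node, hidden_dim)  # current node states
    q_rel = torch.tensor([1, 0])              # one query relation per batch entry

    # q_sub and old_nodes_new_idx are unused inside the layer, so None is fine here
    out = layer(None, q_rel, hidden, edges, n_node, None)
    print(out.shape)  # torch.Size([4, 8])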