-
Notifications
You must be signed in to change notification settings - Fork 147
/
GATNE.py
157 lines (139 loc) · 5.73 KB
/
GATNE.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
import dgl
from . import register_model, BaseModel
import torch.nn as nn
import torch
import torch.nn.functional as F
from torch.nn import Parameter
import math
import dgl.function as fn
@register_model('GATNE-T')
class GATNE(BaseModel):
    """Multiplex node-embedding model registered under the name ``GATNE-T``.

    Every node owns one base embedding plus one "edge embedding" per edge
    type. For a sampled block, the edge embeddings of each destination node's
    neighbours are summed per edge type, combined with a learned attention
    over edge types, projected to the base embedding size and added to the
    base embedding. The output holds one L2-normalised vector per
    (destination node, edge type) pair.
    """

    @classmethod
    def build_model_from_args(cls, args, hg):
        """Build the model from ``args`` and the graph ``hg``.

        Reads ``args.dim``, ``args.edge_dim`` and ``args.att_dim`` and takes
        the node count and edge-type list from ``hg``.
        """
        etypes = hg.etypes
        return cls(
            num_nodes=hg.num_nodes(),
            embedding_size=args.dim,
            embedding_u_size=args.edge_dim,
            edge_types=etypes,
            edge_type_count=len(etypes),
            att_dim=args.att_dim,
        )

    def __init__(
        self,
        num_nodes,
        embedding_size,
        embedding_u_size,
        edge_types,
        edge_type_count,
        att_dim,
    ):
        """Allocate and initialise all learnable tables.

        Args:
            num_nodes: number of rows in the embedding tables.
            embedding_size: width of the base (output) embedding.
            embedding_u_size: width of each per-edge-type embedding.
            edge_types: edge-type identifiers, indexed positionally in
                ``forward``.
            edge_type_count: number of edge types iterated in ``forward``.
            att_dim: hidden width of the attention scorer.
        """
        super().__init__()

        self.num_nodes = num_nodes
        self.embedding_size = embedding_size
        self.embedding_u_size = embedding_u_size
        self.edge_types = edge_types
        self.edge_type_count = edge_type_count
        self.att_dim = att_dim

        def new_param(*shape):
            # Storage is uninitialised here; reset_parameters() fills it.
            return Parameter(torch.FloatTensor(*shape))

        # Registration order is kept stable so parameter ordering is unchanged.
        self.node_embeddings = new_param(num_nodes, embedding_size)
        self.node_type_embeddings = new_param(
            num_nodes, edge_type_count, embedding_u_size
        )
        self.trans_weights = new_param(
            edge_type_count, embedding_u_size, embedding_size
        )
        self.trans_weights_s1 = new_param(
            edge_type_count, embedding_u_size, att_dim
        )
        self.trans_weights_s2 = new_param(edge_type_count, att_dim, 1)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-draw every parameter.

        Embedding tables are uniform in [-1, 1]; the three transformation
        weights are normal with std ``1 / sqrt(embedding_size)``.
        """
        scale = 1.0 / math.sqrt(self.embedding_size)
        with torch.no_grad():
            for table in (self.node_embeddings, self.node_type_embeddings):
                table.uniform_(-1.0, 1.0)
            for weight in (
                self.trans_weights,
                self.trans_weights_s1,
                self.trans_weights_s2,
            ):
                weight.normal_(std=scale)

    def forward(self, block):
        """Compute per-edge-type embeddings for the block's destination nodes.

        Args:
            block: DGL block whose ``srcdata``/``dstdata`` carry ``dgl.NID``
                (indices into the embedding tables) and which has one
                relation per entry of ``self.edge_types``.

        Returns:
            Tensor viewed as ``[batch_size, edge_type_count, embedding_size]``
            with ``batch_size = block.number_of_dst_nodes()``, normalised
            along the last dimension.
        """
        src_ids = block.srcdata[dgl.NID]
        dst_ids = block.dstdata[dgl.NID]
        num_dst = block.number_of_dst_nodes()
        num_types = self.edge_type_count
        u_dim = self.embedding_u_size

        def per_node(weight, rows, cols):
            # Tile a [num_types, rows, cols] weight once per destination node
            # and flatten to [num_dst * num_types, rows, cols].
            tiled = weight.unsqueeze(0).repeat(num_dst, 1, 1, 1)
            return tiled.view(-1, rows, cols)

        with block.local_scope():
            # Sum neighbours' edge embeddings separately for every edge type.
            aggregated = []
            for idx in range(num_types):
                etype = self.edge_types[idx]
                block.srcdata[etype] = self.node_type_embeddings[src_ids, idx]
                block.dstdata[etype] = self.node_type_embeddings[dst_ids, idx]
                block.update_all(
                    fn.copy_u(etype, "m"), fn.sum("m", etype), etype=etype
                )
                aggregated.append(block.dstdata[etype])

            # One slice per edge type along dim 1.
            stacked = torch.stack(aggregated, 1)
            flat = stacked.unsqueeze(2).view(-1, 1, u_dim)

            w_out = per_node(self.trans_weights, u_dim, self.embedding_size)
            w_s1 = per_node(self.trans_weights_s1, u_dim, self.att_dim)
            w_s2 = per_node(self.trans_weights_s2, self.att_dim, 1)

            # Score each (node, edge type) pair, then softmax over edge types.
            hidden = torch.tanh(torch.matmul(flat, w_s1))
            scores = torch.matmul(hidden, w_s2).squeeze(2).view(-1, num_types)
            coeffs = F.softmax(scores, dim=1)
            # The same coefficient row is repeated for every output edge type.
            attention = coeffs.unsqueeze(1).repeat(1, num_types, 1)

            mixed = torch.matmul(attention, stacked).view(-1, 1, u_dim)
            projected = torch.matmul(mixed, w_out).view(
                -1, num_types, self.embedding_size
            )

            base = self.node_embeddings[dst_ids].unsqueeze(1).repeat(
                1, num_types, 1
            )
            # [batch_size, edge_type_count, embedding_size]
            return F.normalize(base + projected, dim=2)
class NSLoss(nn.Module):
    """Negative-sampling loss against a learned per-node context table.

    For each example the loss rewards a high dot product between the
    embedding and the context vector of its label, and a low dot product with
    ``num_sampled`` randomly drawn "noise" nodes.
    """

    def __init__(self, num_nodes, num_sampled, embedding_size):
        """Create the context table and the negative-sampling distribution.

        Args:
            num_nodes: number of rows in the context table; negatives are
                drawn from ``range(num_nodes)``.
            num_sampled: number of negatives drawn per example.
            embedding_size: width of the context vectors (and of ``embs``).
        """
        super(NSLoss, self).__init__()
        self.num_nodes = num_nodes
        self.num_sampled = num_sampled
        self.embedding_size = embedding_size
        self.weights = Parameter(torch.FloatTensor(num_nodes, embedding_size))
        # Sampling weight of node k: (log(k+2) - log(k+1)) / log(num_nodes+1),
        # i.e. lower indices are drawn more often. F.normalize rescales the
        # vector to unit L2 norm; torch.multinomial accepts unnormalised
        # weights. This is a plain tensor attribute, not a parameter.
        self.sample_weights = F.normalize(
            torch.Tensor(
                [
                    (math.log(k + 2) - math.log(k + 1)) / math.log(num_nodes + 1)
                    for k in range(num_nodes)
                ]
            ),
            dim=0,
        )
        self.reset_parameters()

    def reset_parameters(self):
        """Re-draw the context table from N(0, 1 / embedding_size)."""
        self.weights.data.normal_(std=1.0 / math.sqrt(self.embedding_size))

    def forward(self, input, embs, label):
        """Return the mean negative-sampling loss over the batch.

        Args:
            input: tensor whose first dimension is the batch size ``n``; only
                its shape is used.
            embs: ``[n, embedding_size]`` embeddings.
            label: ``[n]`` integer indices of the positive context nodes.

        Returns:
            Scalar tensor: ``-(1/n) * sum(log sigma(pos) + sum log sigma(-neg))``.
        """
        n = input.shape[0]

        # logsigmoid is used rather than log(sigmoid(x)): the latter underflows
        # to log(0) = -inf for strongly negative logits and poisons gradients.
        positive_logits = torch.sum(torch.mul(embs, self.weights[label]), 1)
        log_target = F.logsigmoid(positive_logits)

        negs = torch.multinomial(
            self.sample_weights, self.num_sampled * n, replacement=True
        ).view(n, self.num_sampled)
        noise = torch.neg(self.weights[negs])

        # [n, num_sampled, 1] -> [n, num_sampled]; squeezing the explicit dim
        # keeps the batch axis intact even when n == 1.
        negative_logits = torch.bmm(noise, embs.unsqueeze(2)).squeeze(2)
        sum_log_sampled = F.logsigmoid(negative_logits).sum(1)

        loss = log_target + sum_log_sampled
        return -loss.sum() / n