NBF.py
import copy
from collections.abc import Sequence
from functools import reduce
from typing import Any, Optional, Tuple

import torch
from torch import Tensor, autograd, nn
from torch.nn import functional as F

from . import BaseModel, register_model
def broadcast(src: torch.Tensor, other: torch.Tensor, dim: int):
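    """Reshape a 1-D `src` so it aligns with dimension `dim` of `other` (inserting
    leading and trailing singleton dimensions), then expand it to `other`'s size."""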
if dim < 0:
dim = other.dim() + dim
if src.dim() == 1:
for _ in range(0, dim):
src = src.unsqueeze(0)
for _ in range(src.dim(), other.dim()):
src = src.unsqueeze(-1)
src = src.expand(other.size())
return src
def scatter_sum(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
out: Optional[torch.Tensor] = None,
dim_size: Optional[int] = None) -> torch.Tensor:
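    """A pure-PyTorch stand-in for torch_scatter's scatter_sum: sums the values of
    `src` into a new (or given) tensor at the positions given by `index` along `dim`.

    Minimal example:
        >>> src = torch.tensor([1., 2., 3., 4.])
        >>> index = torch.tensor([0, 1, 0, 1])
        >>> scatter_sum(src, index, dim=0)
        tensor([4., 6.])
    """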
index = broadcast(index, src, dim)
    if out is None:
        size = list(src.size())
        if dim_size is not None:
            size[dim] = dim_size
        elif index.numel() == 0:
            size[dim] = 0
        else:
            size[dim] = int(index.max()) + 1
        out = torch.zeros(size, dtype=src.dtype, device=src.device)
    return out.scatter_add_(dim, index, src)
def scatter_add(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
out: Optional[torch.Tensor] = None,
dim_size: Optional[int] = None) -> torch.Tensor:
return scatter_sum(src, index, dim, out, dim_size)
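# Note: scatter_sum / scatter_add / scatter_mean are implemented in pure PyTorch here,
# while the mul / min / max variants dispatch to the compiled `torch.ops.torch_scatter`
# kernels and therefore require the torch-scatter extension to be installed.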
def scatter_mul(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
out: Optional[torch.Tensor] = None,
dim_size: Optional[int] = None) -> torch.Tensor:
return torch.ops.torch_scatter.scatter_mul(src, index, dim, out, dim_size)
def scatter_mean(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
out: Optional[torch.Tensor] = None,
dim_size: Optional[int] = None) -> torch.Tensor:
out = scatter_sum(src, index, dim, out, dim_size)
dim_size = out.size(dim)
index_dim = dim
if index_dim < 0:
index_dim = index_dim + src.dim()
if index.dim() <= index_dim:
index_dim = index.dim() - 1
ones = torch.ones(index.size(), dtype=src.dtype, device=src.device)
count = scatter_sum(ones, index, index_dim, None, dim_size)
count[count < 1] = 1
count = broadcast(count, out, dim)
if out.is_floating_point():
out.true_divide_(count)
else:
out.div_(count, rounding_mode='floor')
return out
def scatter_min(
src: torch.Tensor, index: torch.Tensor, dim: int = -1,
out: Optional[torch.Tensor] = None,
dim_size: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]:
return torch.ops.torch_scatter.scatter_min(src, index, dim, out, dim_size)
def scatter_max(
src: torch.Tensor, index: torch.Tensor, dim: int = -1,
out: Optional[torch.Tensor] = None,
dim_size: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]:
return torch.ops.torch_scatter.scatter_max(src, index, dim, out, dim_size)
def scatter(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
out: Optional[torch.Tensor] = None, dim_size: Optional[int] = None,
reduce: str = "sum") -> torch.Tensor:
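    """Dispatch to the scatter reduction named by `reduce`:
    'sum'/'add', 'mul', 'mean', 'min', or 'max'."""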
    if reduce == 'sum' or reduce == 'add':
        return scatter_sum(src, index, dim, out, dim_size)
    elif reduce == 'mul':
        return scatter_mul(src, index, dim, out, dim_size)
    elif reduce == 'mean':
        return scatter_mean(src, index, dim, out, dim_size)
    elif reduce == 'min':
        return scatter_min(src, index, dim, out, dim_size)[0]
    elif reduce == 'max':
        return scatter_max(src, index, dim, out, dim_size)[0]
    else:
        raise ValueError(f"Unknown reduce operation `{reduce}`")
@register_model('NBF')
class NBFNet(BaseModel):
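    """Neural Bellman-Ford Networks (NBFNet, Zhu et al., NeurIPS 2021) for
    knowledge-graph link prediction.

    The query relation embedding is placed on the head entity as a boundary
    condition and propagated through a stack of generalized relational GNN layers
    (a learned Bellman-Ford iteration); an MLP then scores each candidate tail.
    """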
@classmethod
def build_model_from_args(cls, args, hg):
        return cls(input_dim=args.input_dim,
                   hidden_dims=args.hidden_dims,
                   num_relation=args.num_relation,
                   message_func=args.message_func,
                   aggregate_func=args.aggregate_func,
                   short_cut=args.short_cut,
                   layer_norm=args.layer_norm,
                   dependent=args.dependent,
                   )
def __init__(self, input_dim, hidden_dims, num_relation, message_func="distmult", aggregate_func="pna",
short_cut=False, layer_norm=False, activation="relu", concat_hidden=False, num_mlp_layer=2,
dependent=True, remove_one_hop=False, num_beam=10, path_topk=10):
super(NBFNet, self).__init__()
if not isinstance(hidden_dims, Sequence):
hidden_dims = [hidden_dims]
self.dims = [input_dim] + list(hidden_dims)
self.num_relation = num_relation
self.short_cut = short_cut # whether to use residual connections between GNN layers
self.concat_hidden = concat_hidden # whether to compute final states as a function of all layer outputs or last
self.remove_one_hop = remove_one_hop # whether to dynamically remove one-hop edges from edge_index
self.num_beam = num_beam
self.path_topk = path_topk
self.layers = nn.ModuleList()
for i in range(len(self.dims) - 1):
self.layers.append(GeneralizedRelationalConv(self.dims[i], self.dims[i + 1], num_relation,
self.dims[0], message_func, aggregate_func, layer_norm,
activation, dependent))
feature_dim = (sum(hidden_dims) if concat_hidden else hidden_dims[-1]) + input_dim
# additional relation embedding which serves as an initial 'query' for the NBFNet forward pass
# each layer has its own learnable relations matrix, so we send the total number of relations, too
self.query = nn.Embedding(num_relation, input_dim)
        mlp = []
for i in range(num_mlp_layer - 1):
mlp.append(nn.Linear(feature_dim, feature_dim))
mlp.append(nn.ReLU())
mlp.append(nn.Linear(feature_dim, 1))
self.mlp = nn.Sequential(*mlp)
def remove_easy_edges(self, data, h_index, t_index, r_index=None):
        # we remove training edges (we need to predict them at training time) from the edge index;
        # think of it as a dynamic edge dropout
h_index_ext = torch.cat([h_index, t_index], dim=-1)
t_index_ext = torch.cat([t_index, h_index], dim=-1)
r_index_ext = torch.cat([r_index, r_index + self.num_relation // 2], dim=-1)
        if self.remove_one_hop:
            # we remove all existing immediate edges between heads and tails in the batch
            edge_index = data.edge_index
            easy_edge = torch.stack([h_index_ext, t_index_ext]).flatten(1)
        else:
            # we remove existing immediate edges between heads and tails in the batch with the given relation
            edge_index = torch.cat([data.edge_index, data.edge_type.unsqueeze(0)])
            # note that here we add relation types r_index_ext to the matching query
            easy_edge = torch.stack([h_index_ext, t_index_ext, r_index_ext]).flatten(1)
        index = edge_match(edge_index, easy_edge)[0]
        mask = ~index_to_mask(index, data.num_edges)
data = copy.copy(data)
data.edge_index = data.edge_index[:, mask]
data.edge_type = data.edge_type[mask]
return data
def negative_sample_to_tail(self, h_index, t_index, r_index):
# convert p(h | t, r) to p(t' | h', r')
# h' = t, r' = r^{-1}, t' = h
is_t_neg = (h_index == h_index[:, [0]]).all(dim=-1, keepdim=True)
new_h_index = torch.where(is_t_neg, h_index, t_index)
new_t_index = torch.where(is_t_neg, t_index, h_index)
new_r_index = torch.where(is_t_neg, r_index, r_index + self.num_relation // 2)
return new_h_index, new_t_index, new_r_index
def bellmanford(self, data, h_index, r_index, separate_grad=False):
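        """Run the generalized Bellman-Ford iterations for a batch of (h, r) queries.

        The query relation embedding is written into the boundary condition at the
        head nodes, then propagated through all layers; returns the final node
        features (optionally concatenated across layers) and per-layer edge weights.
        """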
batch_size = len(r_index)
query = self.query(r_index)
index = h_index.unsqueeze(-1).expand_as(query)
boundary = torch.zeros(batch_size, data.num_nodes, self.dims[0], device=h_index.device)
# by the scatter operation we put query (relation) embeddings as init features of source (index) nodes
boundary.scatter_add_(1, index.unsqueeze(1), query.unsqueeze(1))
size = (data.num_nodes, data.num_nodes)
edge_weight = torch.ones(data.num_edges, device=h_index.device)
hiddens = []
edge_weights = []
layer_input = boundary
for layer in self.layers:
if separate_grad:
edge_weight = edge_weight.clone().requires_grad_()
# Bellman-Ford iteration, we send the original boundary condition in addition to the updated node states
hidden = layer(layer_input, query, boundary, data.edge_index, data.edge_type, size, edge_weight)
if self.short_cut and hidden.shape == layer_input.shape:
# residual connection here
hidden = hidden + layer_input
hiddens.append(hidden)
edge_weights.append(edge_weight)
layer_input = hidden
# original query (relation type) embeddings
node_query = query.unsqueeze(1).expand(-1, data.num_nodes, -1) # (batch_size, num_nodes, input_dim)
if self.concat_hidden:
output = torch.cat(hiddens + [node_query], dim=-1)
else:
output = torch.cat([hiddens[-1], node_query], dim=-1)
return {
"node_feature": output,
"edge_weights": edge_weights,
}
    def forward(self, data, batch):
        # data is the (training) graph; batch holds (h, t, r) index triples
h_index, t_index, r_index = batch.unbind(-1)
if self.training:
# Edge dropout in the training mode
# here we want to remove immediate edges (head, relation, tail) from the edge_index and edge_types
# to make NBFNet iteration learn non-trivial paths
data = self.remove_easy_edges(data, h_index, t_index, r_index)
data.num_edges = data.edge_index.shape[1]
shape = h_index.shape
# turn all triples in a batch into a tail prediction mode
h_index, t_index, r_index = self.negative_sample_to_tail(h_index, t_index, r_index)
assert (h_index[:, [0]] == h_index).all()
assert (r_index[:, [0]] == r_index).all()
# message passing and updated node representations
        output = self.bellmanford(data, h_index[:, 0], r_index[:, 0])  # (batch_size, num_nodes, feature_dim)
feature = output["node_feature"]
index = t_index.unsqueeze(-1).expand(-1, -1, feature.shape[-1])
# extract representations of tail entities from the updated node states
feature = feature.gather(1, index) # (batch_size, num_negative + 1, feature_dim)
# probability logit for each tail node in the batch
# (batch_size, num_negative + 1, dim) -> (batch_size, num_negative + 1)
score = self.mlp(feature).squeeze(-1)
return score.view(shape)
def visualize(self, data, batch):
assert batch.shape == (1, 3)
h_index, t_index, r_index = batch.unbind(-1)
output = self.bellmanford(data, h_index, r_index, separate_grad=True)
feature = output["node_feature"]
edge_weights = output["edge_weights"]
index = t_index.unsqueeze(0).unsqueeze(-1).expand(-1, -1, feature.shape[-1])
feature = feature.gather(1, index).squeeze(0)
score = self.mlp(feature).squeeze(-1)
edge_grads = autograd.grad(score, edge_weights)
distances, back_edges = self.beam_search_distance(data, edge_grads, h_index, t_index, self.num_beam)
paths, weights = self.topk_average_length(distances, back_edges, t_index, self.path_topk)
return paths, weights
@torch.no_grad()
def beam_search_distance(self, data, edge_grads, h_index, t_index, num_beam=10):
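        """Beam-search the top `num_beam` highest-gradient paths from h towards every node.

        Returns per-layer `distances` (cumulative edge-gradient scores, one
        (num_nodes, num_beam) tensor per GNN layer) and `back_edges`
        ((num_nodes, num_beam, 4) tensors of (h, t, r, prev_rank)) for backtracking.
        """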
# beam search the top-k distance from h to t (and to every other node)
num_nodes = data.num_nodes
input = torch.full((num_nodes, num_beam), float("-inf"), device=h_index.device)
input[h_index, 0] = 0
edge_mask = data.edge_index[0, :] != t_index
distances = []
back_edges = []
for edge_grad in edge_grads:
            # we don't allow any path to go out of t once it arrives at t
node_in, node_out = data.edge_index[:, edge_mask]
relation = data.edge_type[edge_mask]
edge_grad = edge_grad[edge_mask]
message = input[node_in] + edge_grad.unsqueeze(-1) # (num_edges, num_beam)
# (num_edges, num_beam, 3)
msg_source = torch.stack([node_in, node_out, relation], dim=-1).unsqueeze(1).expand(-1, num_beam, -1)
# (num_edges, num_beam)
is_duplicate = torch.isclose(message.unsqueeze(-1), message.unsqueeze(-2)) & \
(msg_source.unsqueeze(-2) == msg_source.unsqueeze(-3)).all(dim=-1)
# pick the first occurrence as the ranking in the previous node's beam
# this makes deduplication easier later
# and store it in msg_source
is_duplicate = is_duplicate.float() - \
torch.arange(num_beam, dtype=torch.float, device=message.device) / (num_beam + 1)
prev_rank = is_duplicate.argmax(dim=-1, keepdim=True)
msg_source = torch.cat([msg_source, prev_rank], dim=-1) # (num_edges, num_beam, 4)
node_out, order = node_out.sort()
node_out_set = torch.unique(node_out)
# sort messages w.r.t. node_out
message = message[order].flatten() # (num_edges * num_beam)
msg_source = msg_source[order].flatten(0, -2) # (num_edges * num_beam, 4)
size = node_out.bincount(minlength=num_nodes)
msg2out = size_to_index(size[node_out_set] * num_beam)
# deduplicate messages that are from the same source and the same beam
is_duplicate = (msg_source[1:] == msg_source[:-1]).all(dim=-1)
is_duplicate = torch.cat([torch.zeros(1, dtype=torch.bool, device=message.device), is_duplicate])
message = message[~is_duplicate]
msg_source = msg_source[~is_duplicate]
msg2out = msg2out[~is_duplicate]
size = msg2out.bincount(minlength=len(node_out_set))
if not torch.isinf(message).all():
# take the topk messages from the neighborhood
# distance: (len(node_out_set) * num_beam)
distance, rel_index = scatter_topk(message, size, k=num_beam)
abs_index = rel_index + (size.cumsum(0) - size).unsqueeze(-1)
# store msg_source for backtracking
back_edge = msg_source[abs_index] # (len(node_out_set) * num_beam, 4)
distance = distance.view(len(node_out_set), num_beam)
back_edge = back_edge.view(len(node_out_set), num_beam, 4)
# scatter distance / back_edge back to all nodes
distance = scatter_sum(distance, node_out_set, dim=0, dim_size=num_nodes) # (num_nodes, num_beam)
back_edge = scatter_sum(back_edge, node_out_set, dim=0, dim_size=num_nodes) # (num_nodes, num_beam, 4)
else:
distance = torch.full((num_nodes, num_beam), float("-inf"), device=message.device)
back_edge = torch.zeros(num_nodes, num_beam, 4, dtype=torch.long, device=message.device)
distances.append(distance)
back_edges.append(back_edge)
input = distance
return distances, back_edges
def topk_average_length(self, distances, back_edges, t_index, k=10):
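        """Backtrack `back_edges` to materialize paths ending at `t_index`.

        Returns up to `k` paths (lists of (h, t, r) triples) ranked by the
        average edge-gradient score per hop, together with those scores.
        """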
# backtrack distances and back_edges to generate the paths
paths = []
average_lengths = []
for i in range(len(distances)):
distance, order = distances[i][t_index].flatten(0, -1).sort(descending=True)
back_edge = back_edges[i][t_index].flatten(0, -2)[order]
for d, (h, t, r, prev_rank) in zip(distance[:k].tolist(), back_edge[:k].tolist()):
if d == float("-inf"):
break
path = [(h, t, r)]
for j in range(i - 1, -1, -1):
h, t, r, prev_rank = back_edges[j][h, prev_rank].tolist()
path.append((h, t, r))
paths.append(path[::-1])
average_lengths.append(d / len(path))
if paths:
average_lengths, paths = zip(*sorted(zip(average_lengths, paths), reverse=True)[:k])
return paths, average_lengths
def index_to_mask(index, size=None):
    index = index.view(-1)
    size = int(index.max()) + 1 if size is None else size
mask = index.new_zeros(size, dtype=torch.bool)
mask[index] = True
return mask
def size_to_index(size):
range = torch.arange(len(size), device=size.device)
index2sample = range.repeat_interleave(size)
return index2sample
def multi_slice_mask(starts, ends, length):
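    """Build a boolean mask of length `length` that is True inside the union of the
    half-open ranges [starts[i], ends[i]).

    The trick: scatter +1 at each start and -1 at each end, then the cumulative sum
    is positive exactly inside the ranges, e.g. starts=[1], ends=[3], length=5
    gives [False, True, True, False, False].
    """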
values = torch.cat([torch.ones_like(starts), -torch.ones_like(ends)])
slices = torch.cat([starts, ends])
mask = scatter_sum(values, slices, dim=0, dim_size=length + 1)[:-1]
mask = mask.cumsum(0).bool()
return mask
def scatter_extend(data, size, input, input_size):
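    """Concatenate segment-wise: `data` is split into segments of the given `size`s,
    and each segment is extended with the matching segment of `input` (of
    `input_size`); returns the merged tensor and the new segment sizes."""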
new_size = size + input_size
new_cum_size = new_size.cumsum(0)
new_data = torch.zeros(new_cum_size[-1], *data.shape[1:], dtype=data.dtype, device=data.device)
starts = new_cum_size - new_size
ends = starts + size
index = multi_slice_mask(starts, ends, new_cum_size[-1])
new_data[index] = data
new_data[~index] = input
return new_data, new_size
def scatter_topk(input, size, k, largest=True):
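    """Per-segment top-k: for each segment of `input` (segment lengths in `size`),
    return the k largest (or smallest) values and their within-segment indices.

    Implemented with a single global argsort after offsetting each segment far
    enough that segments cannot interleave; segments shorter than k are padded by
    repeating their last valid entry.
    """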
index2graph = size_to_index(size)
index2graph = index2graph.view([-1] + [1] * (input.ndim - 1))
mask = ~torch.isinf(input)
max = input[mask].max().item()
min = input[mask].min().item()
safe_input = input.clamp(2 * min - max, 2 * max - min)
offset = (max - min) * 4
if largest:
offset = -offset
input_ext = safe_input + offset * index2graph
index_ext = input_ext.argsort(dim=0, descending=largest)
num_actual = size.clamp(max=k)
num_padding = k - num_actual
starts = size.cumsum(0) - size
ends = starts + num_actual
mask = multi_slice_mask(starts, ends, len(index_ext)).nonzero().flatten()
if (num_padding > 0).any():
# special case: size < k, pad with the last valid index
padding = ends - 1
padding2graph = size_to_index(num_padding)
mask = scatter_extend(mask, num_actual, padding[padding2graph], num_padding)[0]
index = index_ext[mask] # (N * k, ...)
value = input.gather(0, index)
if isinstance(k, torch.Tensor) and k.shape == size.shape:
value = value.view(-1, *input.shape[1:])
index = index.view(-1, *input.shape[1:])
index = index - (size.cumsum(0) - size).repeat_interleave(k).view([-1] + [1] * (index.ndim - 1))
else:
value = value.view(-1, k, *input.shape[1:])
index = index.view(-1, k, *input.shape[1:])
index = index - (size.cumsum(0) - size).view([-1] + [1] * (index.ndim - 1))
return value, index
def edge_match(edge_index, query_index):
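    """For every edge pattern in `query_index`, find all matching edges in `edge_index`.

    Returns (edge_ids, num_match): the indices of matching edges in `edge_index`,
    concatenated query by query, and the number of matches per query.
    """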
    # O((n + q) log n) time
# O(n) memory
# edge_index: big underlying graph
# query_index: edges to match
# preparing unique hashing of edges, base: (max_node, max_relation) + 1
base = edge_index.max(dim=1)[0] + 1
# we will map edges to long ints, so we need to make sure the maximum product is less than MAX_LONG_INT
# idea: max number of edges = num_nodes * num_relations
# e.g. for a graph of 10 nodes / 5 relations, edge IDs 0...9 mean all possible outgoing edge types from node 0
# given a tuple (h, r), we will search for all other existing edges starting from head h
assert reduce(int.__mul__, base.tolist()) < torch.iinfo(torch.long).max
scale = base.cumprod(0)
scale = scale[-1] // scale
# hash both the original edge index and the query index to unique integers
edge_hash = (edge_index * scale.unsqueeze(-1)).sum(dim=0)
edge_hash, order = edge_hash.sort()
query_hash = (query_index * scale.unsqueeze(-1)).sum(dim=0)
# matched ranges: [start[i], end[i])
start = torch.bucketize(query_hash, edge_hash)
end = torch.bucketize(query_hash, edge_hash, right=True)
# num_match shows how many edges satisfy the (h, r) pattern for each query in the batch
num_match = end - start
# generate the corresponding ranges
offset = num_match.cumsum(0) - num_match
range = torch.arange(num_match.sum(), device=edge_index.device)
range = range + (start - offset).repeat_interleave(num_match)
return order[range], num_match
def negative_sampling(data, batch, num_negative, strict=True):
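    """Corrupt each positive (h, t, r) triple in `batch` with `num_negative` negatives.

    `data` is the (training) graph. The first half of the batch receives corrupted
    tails, the second half corrupted heads; with `strict=True`, known true triples
    are excluded from the candidate pool via `strict_negative_mask`.
    """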
batch_size = len(batch)
pos_h_index, pos_t_index, pos_r_index = batch.t()
# strict negative sampling vs random negative sampling
if strict:
t_mask, h_mask = strict_negative_mask(data, batch)
t_mask = t_mask[:batch_size // 2]
neg_t_candidate = t_mask.nonzero()[:, 1]
num_t_candidate = t_mask.sum(dim=-1)
# draw samples for negative tails
rand = torch.rand(len(t_mask), num_negative, device=batch.device)
index = (rand * num_t_candidate.unsqueeze(-1)).long()
index = index + (num_t_candidate.cumsum(0) - num_t_candidate).unsqueeze(-1)
neg_t_index = neg_t_candidate[index]
h_mask = h_mask[batch_size // 2:]
neg_h_candidate = h_mask.nonzero()[:, 1]
num_h_candidate = h_mask.sum(dim=-1)
# draw samples for negative heads
rand = torch.rand(len(h_mask), num_negative, device=batch.device)
index = (rand * num_h_candidate.unsqueeze(-1)).long()
index = index + (num_h_candidate.cumsum(0) - num_h_candidate).unsqueeze(-1)
neg_h_index = neg_h_candidate[index]
else:
neg_index = torch.randint(data.num_nodes, (batch_size, num_negative), device=batch.device)
neg_t_index, neg_h_index = neg_index[:batch_size // 2], neg_index[batch_size // 2:]
h_index = pos_h_index.unsqueeze(-1).repeat(1, num_negative + 1)
t_index = pos_t_index.unsqueeze(-1).repeat(1, num_negative + 1)
r_index = pos_r_index.unsqueeze(-1).repeat(1, num_negative + 1)
t_index[:batch_size // 2, 1:] = neg_t_index
h_index[batch_size // 2:, 1:] = neg_h_index
return torch.stack([h_index, t_index, r_index], dim=-1)
def all_negative(data, batch):
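    """Expand each positive triple into all possible tail and head corruptions.

    Returns `t_batch` and `h_batch` of shape (batch_size, num_nodes, 3), used for
    ranking against the full entity set at evaluation time.
    """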
pos_h_index, pos_t_index, pos_r_index = batch.t()
r_index = pos_r_index.unsqueeze(-1).expand(-1, data.num_nodes)
# generate all negative tails for this batch
all_index = torch.arange(data.num_nodes, device=batch.device)
h_index, t_index = torch.meshgrid(pos_h_index, all_index)
t_batch = torch.stack([h_index, t_index, r_index], dim=-1)
# generate all negative heads for this batch
all_index = torch.arange(data.num_nodes, device=batch.device)
t_index, h_index = torch.meshgrid(pos_t_index, all_index)
h_batch = torch.stack([h_index, t_index, r_index], dim=-1)
return t_batch, h_batch
def strict_negative_mask(data, batch):
# this function makes sure that for a given (h, r) batch we will NOT sample true tails as random negatives
# similarly, for a given (t, r) we will NOT sample existing true heads as random negatives
pos_h_index, pos_t_index, pos_r_index = batch.t()
# part I: sample hard negative tails
# edge index of all (head, relation) edges from the underlying graph
edge_index = torch.stack([data.edge_index[0], data.edge_type])
# edge index of current batch (head, relation) for which we will sample negatives
query_index = torch.stack([pos_h_index, pos_r_index])
# search for all true tails for the given (h, r) batch
edge_id, num_t_truth = edge_match(edge_index, query_index)
# build an index from the found edges
t_truth_index = data.edge_index[1, edge_id]
sample_id = torch.arange(len(num_t_truth), device=batch.device).repeat_interleave(num_t_truth)
t_mask = torch.ones(len(num_t_truth), data.num_nodes, dtype=torch.bool, device=batch.device)
# assign 0s to the mask with the found true tails
t_mask[sample_id, t_truth_index] = 0
t_mask.scatter_(1, pos_t_index.unsqueeze(-1), 0)
# part II: sample hard negative heads
# edge_index[1] denotes tails, so the edge index becomes (t, r)
edge_index = torch.stack([data.edge_index[1], data.edge_type])
# edge index of current batch (tail, relation) for which we will sample heads
query_index = torch.stack([pos_t_index, pos_r_index])
# search for all true heads for the given (t, r) batch
edge_id, num_h_truth = edge_match(edge_index, query_index)
# build an index from the found edges
h_truth_index = data.edge_index[0, edge_id]
sample_id = torch.arange(len(num_h_truth), device=batch.device).repeat_interleave(num_h_truth)
h_mask = torch.ones(len(num_h_truth), data.num_nodes, dtype=torch.bool, device=batch.device)
# assign 0s to the mask with the found true heads
h_mask[sample_id, h_truth_index] = 0
h_mask.scatter_(1, pos_h_index.unsqueeze(-1), 0)
return t_mask, h_mask
def compute_ranking(pred, target, mask=None):
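    """Rank of the true entity among all candidate scores: one plus the number of
    candidates whose score ties or exceeds the positive's; `mask` restricts which
    candidates count (filtered ranking)."""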
pos_pred = pred.gather(-1, target.unsqueeze(-1))
if mask is not None:
# filtered ranking
ranking = torch.sum((pos_pred <= pred) & mask, dim=-1) + 1
else:
# unfiltered ranking
ranking = torch.sum(pos_pred <= pred, dim=-1) + 1
return ranking
def degree(index: Tensor, num_nodes: Optional[int] = None,
dtype: Optional[torch.dtype] = None) -> Tensor:
N = maybe_num_nodes(index, num_nodes)
out = torch.zeros((N, ), dtype=dtype, device=index.device)
one = torch.ones((index.size(0), ), dtype=out.dtype, device=out.device)
    return out.scatter_add_(0, index, one)
def maybe_num_nodes(edge_index, num_nodes=None):
if num_nodes is not None:
return num_nodes
elif isinstance(edge_index, Tensor):
if is_torch_sparse_tensor(edge_index):
return max(edge_index.size(0), edge_index.size(1))
return int(edge_index.max()) + 1 if edge_index.numel() > 0 else 0
else:
return max(edge_index.size(0), edge_index.size(1))
def is_torch_sparse_tensor(src: Any) -> bool:
    if isinstance(src, Tensor):
        return src.layout in (torch.sparse_coo, torch.sparse_csr, torch.sparse_csc)
    return False
class GeneralizedRelationalConv(torch.nn.Module):
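    """One generalized Bellman-Ford iteration (relational message-passing layer).

    Messages combine neighbor states with relation embeddings (TransE-style
    addition, DistMult-style multiplication, or RotatE-style complex rotation) and
    are aggregated with sum / mean / max or the PNA aggregator.
    """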
eps = 1e-6
message2mul = {
"transe": "add",
"distmult": "mul",
}
def __init__(self, input_dim, output_dim, num_relation, query_input_dim, message_func="distmult",
aggregate_func="pna", layer_norm=False, activation="relu", dependent=True):
super(GeneralizedRelationalConv, self).__init__()
self.input_dim = input_dim
self.output_dim = output_dim
self.num_relation = num_relation
self.query_input_dim = query_input_dim
self.message_func = message_func
self.aggregate_func = aggregate_func
self.dependent = dependent
self.node_dim = -2
if layer_norm:
self.layer_norm = nn.LayerNorm(output_dim)
else:
self.layer_norm = None
if isinstance(activation, str):
self.activation = getattr(F, activation)
else:
self.activation = activation
if self.aggregate_func == "pna":
self.linear = nn.Linear(input_dim * 13, output_dim)
else:
self.linear = nn.Linear(input_dim * 2, output_dim)
if dependent:
# obtain relation embeddings as a projection of the query relation
self.relation_linear = nn.Linear(query_input_dim, num_relation * input_dim)
else:
# relation embeddings as an independent embedding matrix per each layer
self.relation = nn.Embedding(num_relation, input_dim)
def forward(self, input, query, boundary,
edge_index, edge_type,
size, edge_weight=None):
batch_size = len(query)
if self.dependent:
relation = self.relation_linear(query).view(batch_size, self.num_relation, self.input_dim)
        else:
            # dependent == False: use the layer's own relation embedding matrix
relation = self.relation.weight.expand(batch_size, -1, -1)
if edge_weight is None:
edge_weight = torch.ones(len(edge_type), device=input.device)
        # gather source-node states for every edge:
        # input (batch_size, num_nodes, input_dim) -> input_j (batch_size, num_edges, input_dim)
        input_j = input.index_select(1, edge_index[0])
        message_res = self.message(input_j=input_j, relation=relation, boundary=boundary, edge_type=edge_type)
        aggr_res = self.aggregate(input=message_res, edge_weight=edge_weight, index=edge_index[1],
                                  dim_size=input.shape[1])
        return self.update(update=aggr_res, input=input)
def message(self, input_j, relation, boundary, edge_type):
        relation_j = relation.index_select(self.node_dim, edge_type)
if self.message_func == "transe":
message = input_j + relation_j
elif self.message_func == "distmult":
message = input_j * relation_j
elif self.message_func == "rotate":
            x_j_re, x_j_im = input_j.chunk(2, dim=-1)
r_j_re, r_j_im = relation_j.chunk(2, dim=-1)
message_re = x_j_re * r_j_re - x_j_im * r_j_im
message_im = x_j_re * r_j_im + x_j_im * r_j_re
message = torch.cat([message_re, message_im], dim=-1)
else:
raise ValueError("Unknown message function `%s`" % self.message_func)
# augment messages with the boundary condition
message = torch.cat([message, boundary], dim=self.node_dim) # (num_edges + num_nodes, batch_size, input_dim)
return message
def aggregate(self, input, edge_weight, index, dim_size):
        index = torch.cat([index, torch.arange(dim_size, device=input.device)])  # (num_edges + num_nodes,)
edge_weight = torch.cat([edge_weight, torch.ones(dim_size, device=input.device)])
shape = [1] * input.ndim
shape[self.node_dim] = -1
edge_weight = edge_weight.view(shape)
if self.aggregate_func == "pna":
mean = scatter_mean(input * edge_weight, index, dim=self.node_dim, dim_size=dim_size)
sq_mean = scatter_mean(input ** 2 * edge_weight, index, dim=self.node_dim, dim_size=dim_size)
max = scatter(input * edge_weight, index, dim=self.node_dim, dim_size=dim_size, reduce="max")
            min = scatter(input * edge_weight, index, dim=self.node_dim, dim_size=dim_size, reduce="min")
            std = (sq_mean - mean ** 2).clamp(min=self.eps).sqrt()
features = torch.cat([mean.unsqueeze(-1), max.unsqueeze(-1), min.unsqueeze(-1), std.unsqueeze(-1)], dim=-1)
features = features.flatten(-2)
degree_out = degree(index, dim_size).unsqueeze(0).unsqueeze(-1)
scale = degree_out.log()
scale = scale / scale.mean()
scales = torch.cat([torch.ones_like(scale), scale, 1 / scale.clamp(min=1e-2)], dim=-1)
            output = (features.unsqueeze(-1) * scales.unsqueeze(-2)).flatten(-2)
        else:
            # non-PNA aggregators reduce directly with the requested function
            output = scatter(input * edge_weight, index, dim=self.node_dim,
                             dim_size=dim_size, reduce=self.aggregate_func)
        return output
def update(self, update, input):
output = self.linear(torch.cat([input, update], dim=-1))
if self.layer_norm:
output = self.layer_norm(output)
if self.activation:
output = self.activation(output)
return output
def message_and_aggregate(self, edge_index, input, relation, boundary, edge_type, edge_weight, index, dim_size):
# fused computation of message and aggregate steps with the custom rspmm cuda kernel
# speed up computation by several times
# reduce memory complexity from O(|E|d) to O(|V|d), so we can apply it to larger graphs
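        # NOTE: `generalized_rspmm` used below comes from the custom rspmm extension of
        # the original NBFNet codebase; it is neither defined nor imported in this file,
        # so this fused path only runs when that kernel is made available.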
batch_size, num_node = input.shape[:2]
input = input.transpose(0, 1).flatten(1)
relation = relation.transpose(0, 1).flatten(1)
boundary = boundary.transpose(0, 1).flatten(1)
degree_out = degree(index, dim_size).unsqueeze(-1) + 1
        if self.message_func in self.message2mul:
            mul = self.message2mul[self.message_func]
else:
raise ValueError("Unknown message function `%s`" % self.message_func)
if self.aggregate_func == "sum":
update = generalized_rspmm(edge_index, edge_type, edge_weight, relation, input, sum="add", mul=mul)
update = update + boundary
elif self.aggregate_func == "mean":
update = generalized_rspmm(edge_index, edge_type, edge_weight, relation, input, sum="add", mul=mul)
update = (update + boundary) / degree_out
elif self.aggregate_func == "max":
update = generalized_rspmm(edge_index, edge_type, edge_weight, relation, input, sum="max", mul=mul)
update = torch.max(update, boundary)
elif self.aggregate_func == "pna":
# we use PNA with 4 aggregators (mean / max / min / std)
# and 3 scalars (identity / log degree / reciprocal of log degree)
sum = generalized_rspmm(edge_index, edge_type, edge_weight, relation, input, sum="add", mul=mul)
sq_sum = generalized_rspmm(edge_index, edge_type, edge_weight, relation ** 2, input ** 2, sum="add",
mul=mul)
max = generalized_rspmm(edge_index, edge_type, edge_weight, relation, input, sum="max", mul=mul)
min = generalized_rspmm(edge_index, edge_type, edge_weight, relation, input, sum="min", mul=mul)
mean = (sum + boundary) / degree_out
sq_mean = (sq_sum + boundary ** 2) / degree_out
max = torch.max(max, boundary)
min = torch.min(min, boundary) # (node, batch_size * input_dim)
std = (sq_mean - mean ** 2).clamp(min=self.eps).sqrt()
features = torch.cat([mean.unsqueeze(-1), max.unsqueeze(-1), min.unsqueeze(-1), std.unsqueeze(-1)], dim=-1)
features = features.flatten(-2) # (node, batch_size * input_dim * 4)
scale = degree_out.log()
scale = scale / scale.mean()
scales = torch.cat([torch.ones_like(scale), scale, 1 / scale.clamp(min=1e-2)], dim=-1) # (node, 3)
update = (features.unsqueeze(-1) * scales.unsqueeze(-2)).flatten(-2) # (node, batch_size * input_dim * 4 * 3)
else:
raise ValueError("Unknown aggregation function `%s`" % self.aggregate_func)
update = update.view(num_node, batch_size, -1).transpose(0, 1)
return update
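    # The `propagate` override below is carried over from the original PyG
    # MessagePassing-based implementation: it relies on MessagePassing internals
    # (self.inspector, self._collect, the hook dictionaries) that a plain nn.Module
    # does not provide. `forward` above bypasses it and calls message / aggregate /
    # update directly.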
def propagate(self, edge_index, size=None, **kwargs):
if kwargs["edge_weight"].requires_grad or self.message_func == "rotate":
# the rspmm cuda kernel only works for TransE and DistMult message functions
# otherwise we invoke separate message & aggregate functions
return super(GeneralizedRelationalConv, self).propagate(edge_index, size, **kwargs)
for hook in self._propagate_forward_pre_hooks.values():
res = hook(self, (edge_index, size, kwargs))
if res is not None:
edge_index, size, kwargs = res
size = self._check_input(edge_index, size)
coll_dict = self._collect(self._fused_user_args, edge_index,
size, kwargs)
msg_aggr_kwargs = self.inspector.distribute("message_and_aggregate", coll_dict)
for hook in self._message_and_aggregate_forward_pre_hooks.values():
res = hook(self, (edge_index, msg_aggr_kwargs))
if res is not None:
edge_index, msg_aggr_kwargs = res
out = self.message_and_aggregate(edge_index, **msg_aggr_kwargs)
for hook in self._message_and_aggregate_forward_hooks.values():
res = hook(self, (edge_index, msg_aggr_kwargs), out)
if res is not None:
out = res
update_kwargs = self.inspector.distribute("update", coll_dict)
out = self.update(out, **update_kwargs)
for hook in self._propagate_forward_hooks.values():
res = hook(self, (edge_index, size, kwargs), out)
if res is not None:
out = res
return out