-
Notifications
You must be signed in to change notification settings - Fork 147
/
HDE.py
65 lines (55 loc) · 2.3 KB
/
HDE.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import torch
import torch.nn as nn
import torch.nn.functional as F
from . import BaseModel, register_model
class GNN(nn.Module):
    """Aggregate a node's own features with its 1-hop and 2-hop neighbours.

    Each input row is a flattened feature vector laid out as::

        [node | neigh1 (k blocks) | neigh2 (k * k blocks)]

    where every block is ``input_dim`` wide and ``k = num_neighbor``, so a
    row must be ``input_dim * (1 + k + k * k)`` wide. The 2-hop section is
    viewed as ``(k, k, input_dim)`` and mean-pooled over its second axis
    (presumably the 2-hop neighbours grouped under each 1-hop neighbour;
    confirm against the data sampler). Each pooled vector is concatenated
    with the matching 1-hop neighbour and projected. The projections are
    then mean-pooled across the 1-hop neighbours, concatenated with the
    node's own features and passed through two more ReLU-activated linear
    layers.

    Args:
        input_dim: Width of one node's feature vector.
        output_dim: Width of the produced embedding.
        num_neighbor: Number of neighbours per hop (``k`` above).
        use_bias: Whether the linear layers carry an additive bias.
    """

    def __init__(self, input_dim, output_dim, num_neighbor, use_bias=True):
        super(GNN, self).__init__()
        self.input_dim = int(input_dim)
        # Kept alongside input_dim for backward compatibility; same value.
        self.num_fea = int(input_dim)
        self.output_dim = int(output_dim)
        self.num_neighbor = int(num_neighbor)
        self.use_bias = use_bias
        # use_bias used to be stored but ignored; apply it to every layer.
        self.linear1 = nn.Linear(self.input_dim * 2, 64, bias=use_bias)
        self.linear2 = nn.Linear(64 + self.num_fea, 64, bias=use_bias)
        self.linear3 = nn.Linear(64, self.output_dim, bias=use_bias)

    def forward(self, fea):
        """Embed a batch of flattened node/neighbour feature rows.

        Args:
            fea: 2-D tensor of shape ``(batch, input_dim * (1 + k + k * k))``.

        Returns:
            Tensor of shape ``(batch, output_dim)``, non-negative because of
            the final ReLU.

        Raises:
            ValueError: If ``fea`` is not 2-D or has the wrong row width.
        """
        num_fea = self.num_fea
        k = self.num_neighbor
        expected = num_fea * (1 + k + k * k)
        if fea.dim() != 2 or fea.shape[1] != expected:
            raise ValueError(
                f"GNN expects input of shape (batch, {expected}), "
                f"got {tuple(fea.shape)}"
            )
        # Explicit batch size: reshape with -1 cannot infer it for an empty batch.
        batch = fea.shape[0]
        node = fea[:, :num_fea]
        neigh1 = fea[:, num_fea:num_fea * (k + 1)].reshape(batch, k, num_fea)
        neigh2 = fea[:, num_fea * (k + 1):].reshape(batch, k, k, num_fea)
        neigh2_agg = torch.mean(neigh2, dim=2)
        tmp = torch.cat([neigh1, neigh2_agg], dim=2)
        tmp = F.relu(self.linear1(tmp))
        emb = torch.cat([node, torch.mean(tmp, dim=1)], dim=1)
        emb = F.relu(self.linear2(emb))
        emb = F.relu(self.linear3(emb))
        return emb
@register_model('HDE')
class HDE(BaseModel):
    """Score a pair of nodes with a shared ``GNN`` encoder and an MLP head.

    Both inputs are embedded by the same ``GNN`` instance, so the two sides
    share weights. The embeddings are concatenated and mapped to two output
    logits.

    Args:
        input_dim: Width of one node's feature vector.
        output_dim: Width of each node embedding produced by the encoder.
        num_neighbor: Number of neighbours per hop expected by the encoder.
        use_bias: Whether the linear layers (encoder and head) carry a bias.
    """

    def __init__(self, input_dim, output_dim, num_neighbor, use_bias=True):
        super(HDE, self).__init__()
        self.input_dim = int(input_dim)
        self.output_dim = int(output_dim)
        self.num_neighbor = int(num_neighbor)
        self.use_bias = use_bias
        # use_bias was previously accepted but never forwarded anywhere.
        self.aggregator = GNN(input_dim=input_dim,
                              output_dim=output_dim,
                              num_neighbor=num_neighbor,
                              use_bias=use_bias)
        self.linear1 = nn.Linear(2 * self.output_dim, 32, bias=use_bias)
        self.linear2 = nn.Linear(32, 2, bias=use_bias)

    def forward(self, fea_a, fea_b):
        """Return raw (un-normalised) logits of shape ``(batch, 2)`` for the pair.

        Args:
            fea_a: Flattened feature rows for the first node of each pair, in
                the layout the ``GNN`` aggregator expects.
            fea_b: Same for the second node of each pair.
        """
        emb_a = self.aggregator(fea_a)
        emb_b = self.aggregator(fea_b)
        emb = torch.cat([emb_a, emb_b], dim=1)
        emb = F.relu(self.linear1(emb))
        output = self.linear2(emb)
        return output

    @classmethod
    def build_model_from_args(cls, args, hg):
        """Build the model from a config namespace; ``hg`` is unused here."""
        return cls(input_dim=args.input_dim,
                   output_dim=args.output_dim,
                   num_neighbor=args.num_neighbor,
                   # Fall back to the constructor default when the config omits it.
                   use_bias=getattr(args, 'use_bias', True))