-
Notifications
You must be signed in to change notification settings - Fork 147
/
homo_GNN.py
76 lines (70 loc) · 2.88 KB
/
homo_GNN.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import dgl
from .HeteroMLP import HGNNPostMP, HGNNPreMP
from . import BaseModel, register_model
from ..layers import SkipConnection
# Lookup table from ``args.stage_type`` to the SkipConnection stage class that
# homo_GNN instantiates for its message-passing stack.
# 'skipsum' and 'skipconcat' both resolve to GNNSkipStage; homo_GNN also passes
# stage_type to the stage constructor, presumably so the stage can pick sum vs.
# concat itself — confirm in SkipConnection.
stage_dict = dict(
    stack=SkipConnection.GNNStackStage,
    skipsum=SkipConnection.GNNSkipStage,
    skipconcat=SkipConnection.GNNSkipStage,
)
@register_model('homo_GNN')
class homo_GNN(BaseModel):
    r"""
    General homogeneous GNN model for HGNN.

    Pipeline: HeteroMLP (pre-MP) + HomoGNN + HeteroMLP (post-MP).

    The heterograph is flattened with ``dgl.to_homogeneous``. Self-loops are
    normalised: all existing ones are removed, then one is added per node.
    An optional homogeneous GNN stage runs over all nodes. The resulting rows
    are split back into a node-type -> tensor dict before the post-MP layers.
    """

    @classmethod
    def build_model_from_args(cls, args, hg):
        """Construct the model from a config namespace ``args`` and graph ``hg``.

        ``args.out_node_type`` selects which node types receive outputs.
        """
        out_node_type = args.out_node_type
        return cls(args, hg, out_node_type)

    def __init__(self, args, hg, out_node_type, **kwargs):
        """
        Parameters
        ----------
        args : namespace
            Config object. Reads ``layers_pre_mp``, ``layers_gnn``,
            ``layers_post_mp``, ``hidden_dim``, ``out_dim``, ``stage_type``,
            ``gnn_type``, ``dropout``, ``activation``, ``has_bn``,
            ``num_heads`` and ``has_l2norm``.
        hg : DGL heterograph
            Only its ``ntypes`` / ``etypes`` lists are used here.
        out_node_type :
            Node types that get output embeddings. Presumably a list of
            node-type names — confirm against callers.
        **kwargs :
            Accepted for interface compatibility and ignored.
        """
        super().__init__()
        self.out_node_type = out_node_type
        # The pre-MP module is only built when more than one pre-MP layer is
        # requested. Otherwise the input features go straight to the GNN.
        if args.layers_pre_mp - 1 > 0:
            self.pre_mp = HGNNPreMP(args, hg.ntypes, args.layers_pre_mp, args.hidden_dim, args.hidden_dim)
        if args.layers_gnn > 0:
            GNNStage = stage_dict[args.stage_type]
            self.gnn = GNNStage(gnn_type=args.gnn_type,
                                stage_type=args.stage_type,
                                dim_in=args.hidden_dim,
                                dim_out=args.hidden_dim,
                                num_layers=args.layers_gnn,
                                skip_every=1,
                                dropout=args.dropout,
                                act=args.activation,
                                has_bn=args.has_bn,
                                num_heads=args.num_heads,
                                has_l2norm=args.has_l2norm,
                                num_etypes=len(hg.etypes),
                                num_ntypes=len(hg.ntypes))
            gnn_out_dim = self.gnn.dim_out
        else:
            # No GNN stage: embeddings reach post-MP unchanged. They have the
            # width the GNN would have consumed (hidden_dim). Reading
            # self.gnn.dim_out here would raise AttributeError.
            gnn_out_dim = args.hidden_dim
        self.post_mp = HGNNPostMP(args, self.out_node_type, args.layers_post_mp, gnn_out_dim, args.out_dim)

    def forward(self, hg, h_dict):
        """Run pre-MP, the homogeneous GNN and post-MP over ``hg``.

        Parameters
        ----------
        hg : DGL heterograph
            Graph to run on. Node data is only modified inside ``local_scope``.
        h_dict : dict
            Node-type -> feature tensor.

        Returns
        -------
        dict
            Node-type -> output tensor. With a single node type, that type is
            returned. Otherwise only the types in ``self.out_node_type`` are
            returned.
        """
        with hg.local_scope():
            if hasattr(self, 'pre_mp'):
                h_dict = self.pre_mp(h_dict)
            if len(hg.ntypes) == 1:
                hg.ndata['h'] = h_dict[hg.ntypes[0]]
            else:
                hg.ndata['h'] = h_dict
            homo_g = dgl.to_homogeneous(hg, ndata=['h'])
            # Drop any pre-existing self-loops, then add exactly one per node.
            homo_g = dgl.remove_self_loop(homo_g)
            homo_g = dgl.add_self_loop(homo_g)
            h = homo_g.ndata.pop('h')
            if hasattr(self, 'gnn'):
                h = self.gnn(homo_g, h)
            # Always rebuild the per-type dict, even without a GNN stage,
            # so out_h is bound on every path.
            if len(hg.ntypes) == 1:
                out_h = {hg.ntypes[0]: h}
            else:
                out_h = self.h2dict(h, hg.ndata['h'], self.out_node_type)
            if hasattr(self, 'post_mp'):
                out_h = self.post_mp(out_h)
        return out_h

    def h2dict(self, h, hdict, node_list):
        """Split stacked homogeneous rows ``h`` back into a node-type dict.

        Parameters
        ----------
        h : tensor
            Row-stacked embeddings for every node.
        hdict : dict
            Node-type -> tensor. Only ``value.shape[0]`` (the row count per
            type) is used. Assumes iteration order matches the order in which
            the node types were concatenated into ``h`` — verify.
        node_list : iterable of node-type names, or a single name
            Node types to keep.

        Returns
        -------
        dict
            Node-type -> slice of ``h`` for each kept type.
        """
        # A bare string would make `in` do substring matching, so wrap it.
        if isinstance(node_list, str):
            node_list = [node_list]
        offset = 0
        out_h = {}
        for ntype, value in hdict.items():
            count = value.shape[0]
            if ntype in node_list:
                out_h[ntype] = h[offset:offset + count]
            # Advance past every type, kept or not, so later slices line up.
            offset += count
        return out_h