forked from BUPT-GAMMA/OpenHGNN
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathgeneral_HGNN.py
95 lines (85 loc) · 3.71 KB
/
general_HGNN.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import dgl
from ..layers import SkipConnection
from openhgnn.models import BaseModel, register_model
from ..models.HeteroMLP import HGNNPostMP, HGNNPreMP
# Lookup table from the ``stage_type`` config value to the class that builds
# the GNN stage. 'skipsum' and 'skipconcat' share one class; the stage_type
# string is also forwarded to the constructor (see general_HGNN.__init__),
# presumably so that class can choose sum vs. concat skips — TODO confirm.
stage_dict = dict(
    stack=SkipConnection.HGNNStackStage,
    skipsum=SkipConnection.HGNNSkipStage,
    skipconcat=SkipConnection.HGNNSkipStage,
)
def HG_transformation(hg, metapaths_dict):
    """Build a heterograph with one relation per meta-path.

    For every ``(name, metapath)`` entry of ``metapaths_dict`` the meta-path
    reachable graph of ``hg`` is computed with
    ``dgl.metapath_reachable_graph``. Its edges are added to the result under
    the relation ``(src_type, name, dst_type)``.

    Parameters
    ----------
    hg : dgl.DGLGraph
        Source heterograph.
    metapaths_dict : dict
        Maps the new relation name to a meta-path, i.e. a sequence of edge
        types of ``hg`` accepted by ``dgl.metapath_reachable_graph``.

    Returns
    -------
    dgl.DGLGraph
        A heterograph whose edge types are the keys of ``metapaths_dict``.
        Every node type it contains keeps the node count it has in ``hg``.
    """
    graph_data = {}
    num_nodes_dict = {}
    for key, mp in metapaths_dict.items():
        mp_g = dgl.metapath_reachable_graph(hg, mp)
        # The reachable graph has exactly one relation: src_type -> dst_type.
        src_type, _, dst_type = mp_g.canonical_etypes[0]
        graph_data[(src_type, key, dst_type)] = mp_g.edges()
        # Record node counts explicitly. If DGL inferred them from the
        # largest node ID present in the edges, trailing nodes that no
        # meta-path reaches would be dropped. The new graph would then no
        # longer line up with per-node tensors sized for ``hg`` (presumably
        # the input features).
        num_nodes_dict[src_type] = hg.num_nodes(src_type)
        num_nodes_dict[dst_type] = hg.num_nodes(dst_type)
    return dgl.heterograph(graph_data, num_nodes_dict=num_nodes_dict)
@register_model('general_HGNN')
class general_HGNN(BaseModel):
    """General heterogeneous GNN model.

    The model is a three-part pipeline applied to ``self.hg``:

    1. ``pre_mp`` (optional): ``HGNNPreMP`` layers over all node types.
    2. ``hgnn`` (optional): a GNN stage chosen from ``stage_dict`` by
       ``args.stage_type``.
    3. ``post_mp``: ``HGNNPostMP`` layers applied only to the output node
       types.

    The graph the model runs on is chosen once, in ``build_model_from_args``,
    according to ``args.subgraph_extraction``.
    """

    @classmethod
    def build_model_from_args(cls, args, hg):
        """Build the model after rewriting ``hg`` per ``args.subgraph_extraction``.

        Supported modes:

        * ``'relation'``: use ``hg`` unchanged.
        * ``'metapath'``: one relation per entry of ``args.meta_paths_dict``.
        * ``'mixed'``: the meta-paths of ``args.meta_paths_dict`` plus every
          original relation of ``hg``, keyed by its relation name. A
          meta-path key that equals a relation name is overwritten by that
          relation.

        Raises
        ------
        ValueError
            If a mode that needs meta-paths has none, or the mode is unknown.
        """
        out_node_type = args.out_node_type
        mode = args.subgraph_extraction
        if mode == 'relation':
            new_hg = hg
            print('relation extraction!')
        elif mode == 'metapath':
            meta_paths_dict = getattr(args, 'meta_paths_dict', None)
            if not meta_paths_dict:
                raise ValueError('No meta-path is specified!')
            new_hg = HG_transformation(hg, meta_paths_dict)
            print('metapath extraction!')
        elif mode == 'mixed':
            meta_paths_dict = getattr(args, 'meta_paths_dict', None)
            if meta_paths_dict is None:
                raise ValueError('No meta-path is specified!')
            # Work on a copy. Adding relations to args.meta_paths_dict itself
            # would leak them into every later reader of the config.
            relation_dict = dict(meta_paths_dict)
            for etype in hg.canonical_etypes:
                relation_dict[etype[1]] = [etype]
            new_hg = HG_transformation(hg, relation_dict)
            print('mixed extraction!')
        else:
            raise ValueError('subgraph_extraction only supports relation, metapath and mixed')
        return cls(args, new_hg, out_node_type)

    def __init__(self, args, hg, out_node_type, **kwargs):
        """Create the pre-processing, GNN and post-processing layers.

        Parameters
        ----------
        args
            Config namespace. Reads ``layers_pre_mp``, ``hidden_dim``,
            ``layers_gnn``, ``stage_type``, ``gnn_type``, ``dropout``,
            ``activation``, ``has_bn``, ``has_l2norm``, ``num_heads``,
            ``macro_func``, ``layers_post_mp`` and ``out_dim``.
        hg : dgl.DGLGraph
            Graph the model runs on. It is stored as ``self.hg`` and used by
            ``forward`` instead of the graph passed there.
        out_node_type
            Collection of node type names whose representations ``forward``
            returns.
        **kwargs
            Accepted for interface compatibility; unused.
        """
        super(general_HGNN, self).__init__()
        self.hg = hg
        self.out_node_type = out_node_type
        # The first linear projection runs outside the model (in the
        # trainerflow), so pre_mp is only built when more than one
        # pre-processing layer is requested.
        if args.layers_pre_mp - 1 > 0:
            self.pre_mp = HGNNPreMP(args, self.hg.ntypes, args.layers_pre_mp, args.hidden_dim, args.hidden_dim)
        if args.layers_gnn > 0:
            HGNNStage = stage_dict[args.stage_type]
            self.hgnn = HGNNStage(gnn_type=args.gnn_type,
                                  rel_names=self.hg.etypes,
                                  stage_type=args.stage_type,
                                  dim_in=args.hidden_dim,
                                  dim_out=args.hidden_dim,
                                  num_layers=args.layers_gnn,
                                  skip_every=1,
                                  dropout=args.dropout,
                                  act=args.activation,
                                  has_bn=args.has_bn,
                                  has_l2norm=args.has_l2norm,
                                  num_heads=args.num_heads,
                                  macro_func=args.macro_func)
            # The stage reports its own output width (it may differ from
            # hidden_dim, e.g. for concatenating skip connections).
            gnn_out_dim = self.hgnn.dim_out
        else:
            # No GNN stage: features keep the hidden_dim width they have
            # after the external first linear / pre_mp.
            gnn_out_dim = args.hidden_dim
        self.post_mp = HGNNPostMP(args, self.out_node_type, args.layers_post_mp, gnn_out_dim, args.out_dim)

    def forward(self, hg, h_dict):
        """Run the pipeline and return representations of the output node types.

        Parameters
        ----------
        hg
            Ignored. The model always runs on ``self.hg``; this argument
            exists only for the common model call signature.
        h_dict : dict
            Maps node type name to its input features (presumably already
            projected to ``hidden_dim`` by the trainerflow — TODO confirm).
            Entries for node types absent from ``self.hg`` are dropped.

        Returns
        -------
        dict
            Maps each node type in ``self.out_node_type`` that is present
            in the features to the output of ``post_mp``.
        """
        graph = self.hg
        # Scope the graph the layers actually use, so that any ndata/edata
        # they write is discarded when forward returns.
        with graph.local_scope():
            h_dict = {ntype: h for ntype, h in h_dict.items() if ntype in graph.ntypes}
            if hasattr(self, 'pre_mp'):
                h_dict = self.pre_mp(h_dict)
            if hasattr(self, 'hgnn'):
                h_dict = self.hgnn(graph, h_dict)
            out_h = {ntype: h for ntype, h in h_dict.items() if ntype in self.out_node_type}
            if hasattr(self, 'post_mp'):
                out_h = self.post_mp(out_h)
        return out_h