forked from BUPT-GAMMA/OpenHGNN
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathHeteroMLP.py
70 lines (58 loc) · 2.03 KB
/
HeteroMLP.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
from ..layers.HeteroLinear import HeteroMLPLayer
from ..layers.GeneralGNNLayer import MultiLinearLayer
def HGNNPreMP(args, node_types, num_pre_mp, in_dim, hidden_dim):
    """
    Build the per-node-type MLP applied before message passing.

    Every node type gets its own stack of ``num_pre_mp`` linear layers with
    dimensions ``in_dim -> hidden_dim -> ... -> hidden_dim``.

    Note:
        The final layer has an activation (``final_act=True``).

    Parameters
    ----------
    args
        Config object. ``activation``, ``dropout``, ``has_l2norm`` and
        ``has_bn`` are read from it and forwarded to ``HeteroMLPLayer``.
        It is only accessed when ``num_pre_mp > 0``.
    node_types
        Iterable of node-type keys; one MLP is configured per entry.
    num_pre_mp : int
        Number of linear layers. When it is not positive, no module is built.
    in_dim
        Input dimension of the first layer.
    hidden_dim
        Output dimension of every layer.

    Returns
    -------
    HeteroMLPLayer or None
        The per-type MLP, or ``None`` when ``num_pre_mp <= 0``.
    """
    if num_pre_mp <= 0:
        # No pre-MP stage requested: explicitly signal "no module".
        return None

    # Layer dims: in_dim followed by num_pre_mp copies of hidden_dim.
    dims = [in_dim] + [hidden_dim] * num_pre_mp
    # Give each node type its own list so the per-type configs never alias.
    linear_dict = {ntype: list(dims) for ntype in node_types}
    return HeteroMLPLayer(linear_dict, act=args.activation, dropout=args.dropout,
                          has_l2norm=args.has_l2norm, has_bn=args.has_bn, final_act=True)
def HGNNPostMP(args, node_types, num_post_mp, hidden_dim, out_dim):
    """
    Build the per-node-type MLP applied after message passing.

    Every node type gets its own stack of ``num_post_mp`` linear layers with
    dimensions ``hidden_dim -> ... -> hidden_dim -> out_dim``.

    Note:
        The final layer has no activation (``final_act=False``).

    Parameters
    ----------
    args
        Config object. ``activation``, ``dropout``, ``has_l2norm`` and
        ``has_bn`` are read from it and forwarded to ``HeteroMLPLayer``.
        It is only accessed when ``num_post_mp > 0``.
    node_types
        Iterable of node-type keys; one MLP is configured per entry.
    num_post_mp : int
        Number of linear layers. When it is not positive, no module is built.
    hidden_dim
        Input dimension of the first layer and output dimension of every
        layer except the last.
    out_dim
        Output dimension of the last layer.

    Returns
    -------
    HeteroMLPLayer or None
        The per-type MLP, or ``None`` when ``num_post_mp <= 0``.
    """
    if num_post_mp <= 0:
        # No post-MP stage requested: explicitly signal "no module".
        return None

    # num_post_mp layers need num_post_mp + 1 dims: hidden_dim repeated
    # num_post_mp times, then out_dim.
    dims = [hidden_dim] * num_post_mp + [out_dim]
    # Give each node type its own list so the per-type configs never alias.
    linear_dict = {ntype: list(dims) for ntype in node_types}
    return HeteroMLPLayer(linear_dict, act=args.activation, dropout=args.dropout,
                          has_l2norm=args.has_l2norm, has_bn=args.has_bn, final_act=False)
# def GNNPreMP(args, in_dim, hidden_dim):
# linear_list = [in_dim] + args.layers_pre_mp * [hidden_dim]
# return MultiLinearLayer(linear_list, dropout=args.dropout, act=args.activation, has_bn=args.has_bn,
# has_l2norm=args.has_l2norm)
#
#
# def GNNPostMP(args, hidden_dim, out_dim):
# linear_list = args.layers_pre_mp * [hidden_dim] + [out_dim]
# return MultiLinearLayer(linear_list, dropout=args.dropout, act=args.activation, has_bn=args.has_bn,
# has_l2norm=args.has_l2norm)