forked from BUPT-GAMMA/OpenHGNN
-
Notifications
You must be signed in to change notification settings - Fork 0
/
RSHN.py
204 lines (170 loc) · 7.08 KB
/
RSHN.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
# The implementation of ICDM 2019 paper "Relation Structure-Aware Heterogeneous Graph Neural Network" RSHN.
# @Time : 2021/3/1
# @Author : Tianyu Zhao
# @Email : [email protected]
import dgl
import torch as th
import torch.nn as nn
import torch.nn.functional as F
from dgl import function as fn
from dgl.utils import expand_as_pair
from dgl.nn.functional import edge_softmax
from . import BaseModel, register_model
from ..sampler.RSHN_sampler import coarsened_line_graph
@register_model('RSHN')
class RSHN(BaseModel):
    r"""
    Relation structure-aware heterogeneous graph neural network (RSHN).

    RSHN first builds a coarsened line graph to obtain edge features, then uses a
    Message Passing Neural Network (MPNN) to propagate node and edge features.
    The coarsened line graph is built through the ``coarsened_line_graph`` API.

    Parameters
    ----------
    dim : int
        Hidden size used by every node layer and as the input size of the output projection.
    out_dim : int
        Output size of the final linear projection.
    num_node_layer : int
        Number of ``GraphConv`` node layers.
    num_edge_layer : int
        Number of ``AGNNConv`` edge layers.
    dropout : float
        Dropout probability shared by the edge layers' output and the node layers.

    Attributes
    -----------
    edge_layers : nn.ModuleList of AGNNConv
        Applied in Edge Layer.
    node_layers : nn.ModuleList of GraphConv
        Applied on the heterogeneous graph.
    cl_graph : dgl.DGLGraph
        Coarsened line graph used to propagate edge features.
        Attached by :meth:`build_model_from_args`, not by ``__init__``.
    linear_e1 : nn.Linear
        Maps the coarsened-line-graph node features to ``hidden_dim``.
        Attached by :meth:`build_model_from_args`, not by ``__init__``.
    """

    @classmethod
    def build_model_from_args(cls, args, hg):
        r"""
        Build an RSHN model together with its coarsened line graph.

        Reads ``hidden_dim``, ``out_dim``, ``num_node_layer``, ``num_edge_layer``,
        ``dropout``, ``rw_len``, ``batch_size``, ``dataset`` and ``device`` from ``args``.

        Parameters
        ----------
        args : namespace-like object carrying the fields listed above.
        hg : graph handed to ``coarsened_line_graph.get_cl_graph``.

        Returns
        -------
        RSHN
            Model with ``cl_graph`` and ``linear_e1`` attached.
        """
        rshn = cls(dim=args.hidden_dim,
                   out_dim=args.out_dim,
                   num_node_layer=args.num_node_layer,
                   num_edge_layer=args.num_edge_layer,
                   dropout=args.dropout
                   )
        cl = coarsened_line_graph(rw_len=args.rw_len, batch_size=args.batch_size, n_dataset=args.dataset,
                                  symmetric=True)
        cl_graph = cl.get_cl_graph(hg).to(args.device)
        cl_graph = cl.init_cl_graph(cl_graph)
        rshn.cl_graph = cl_graph

        # The input width equals the number of nodes of the coarsened line graph.
        linear_e1 = nn.Linear(in_features=cl_graph.num_nodes(), out_features=args.hidden_dim, bias=False)
        nn.init.xavier_uniform_(linear_e1.weight)
        rshn.linear_e1 = linear_e1
        return rshn

    def __init__(self, dim, out_dim, num_node_layer, num_edge_layer, dropout):
        super(RSHN, self).__init__()
        self.num_node_layer = num_node_layer
        # Edge layers operate on the coarsened line graph.
        self.edge_layers = nn.ModuleList([AGNNConv() for _ in range(num_edge_layer)])
        # Node layers operate on the heterogeneous graph.
        self.node_layers = nn.ModuleList([
            GraphConv(in_feats=dim, out_feats=dim, dropout=dropout, activation=th.tanh)
            for _ in range(num_node_layer)
        ])
        self.linear = nn.Linear(in_features=dim, out_features=out_dim, bias=False)
        self.dropout = nn.Dropout(dropout)

        self.init_para()

    def init_para(self):
        r"""Hook for extra parameter initialisation; currently does nothing."""
        return

    def forward(self, hg, n_feats, *args, **kwargs):
        r"""
        First, apply edge_layer in cl_graph to get edge embedding.
        Then, propagate node and edge features through GraphConv.

        Requires ``self.cl_graph`` and ``self.linear_e1``, which are attached by
        :meth:`build_model_from_args`.

        Parameters
        ----------
        hg : heterogeneous graph exposing ``canonical_etypes`` and ``num_edges``.
        n_feats : dict
            Maps node type to its feature tensor. The dict is not modified.

        Returns
        -------
        dict
            A new dict mapping node type to the output of the final linear projection.
        """
        # Edge embedding: run the edge layers over the coarsened line graph.
        h = self.cl_graph.ndata['h']
        h_e = self.cl_graph.edata['w']
        for layer in self.edge_layers:
            h = th.relu(layer(self.cl_graph, h, h_e))
            h = self.dropout(h)

        h = self.linear_e1(h)
        # Row i of h is broadcast to every edge of the i-th canonical edge type.
        edge_weight = {
            e: h[i].expand(hg.num_edges(e), -1)
            for i, e in enumerate(hg.canonical_etypes)
        }
        if hasattr(hg, 'ntypes'):
            # full graph training
            for layer in self.node_layers:
                n_feats = layer(hg, n_feats, edge_weight)
        else:
            # minibatch training is not implemented; features go straight to the projection
            pass
        # Build a fresh dict so the caller's n_feats is never overwritten in place.
        return {ntype: self.linear(feat) for ntype, feat in n_feats.items()}
class AGNNConv(nn.Module):
    r"""
    Attention-style edge layer used on the coarsened line graph.

    Source features are L2-normalised, weighted by a softmax over
    ``beta * edge_weight`` (normalised by source node), summed into the
    destination nodes and added to ``(1 + eps)`` times the destination features.

    Parameters
    ----------
    eps : float
        Initial value of ``eps``. Default: 0.
    train_eps : bool
        If True, ``eps`` is a learnable parameter, otherwise a buffer.
    learn_beta : bool
        If True, ``beta`` is a learnable parameter, otherwise a buffer fixed to 1.
    """

    def __init__(self,
                 eps=0.,
                 train_eps=False,
                 learn_beta=True):
        super(AGNNConv, self).__init__()
        self.initial_eps = eps
        self.learn_beta = learn_beta

        # beta starts at 1 in both modes, so a fixed beta never holds uninitialised memory.
        if learn_beta:
            self.beta = nn.Parameter(th.ones(1))
        else:
            self.register_buffer('beta', th.ones(1))

        # eps is a one-element tensor holding the value (not a tensor *sized* by eps).
        eps_init = th.tensor([float(eps)])
        if train_eps:
            self.eps = nn.Parameter(eps_init)
        else:
            self.register_buffer('eps', eps_init)

        self.reset_parameters()

    def reset_parameters(self):
        r"""Reset ``eps`` to its initial value and, when learnable, ``beta`` to 1."""
        self.eps.data.fill_(self.initial_eps)
        if self.learn_beta:
            self.beta.data.fill_(1)

    def forward(self, graph, feat, edge_weight):
        r"""
        Parameters
        ----------
        graph : dgl.DGLGraph
        feat : tensor, or pair of tensors as accepted by ``expand_as_pair``.
        edge_weight : tensor of per-edge weights.

        Returns
        -------
        tensor
            ``(1 + eps) * feat_dst`` plus the aggregated messages.
        """
        with graph.local_scope():
            feat_src, feat_dst = expand_as_pair(feat, graph)
            graph.srcdata['norm_h'] = F.normalize(feat_src, p=2, dim=-1)

            e = self.beta * edge_weight
            graph.edata['p'] = edge_softmax(graph, e, norm_by='src')
            graph.update_all(fn.u_mul_e('norm_h', 'p', 'm'), fn.sum('m', 'h'))
            rst = graph.dstdata.pop('h')
            # Use the destination half so blocks / pair inputs line up with rst;
            # for a plain tensor on a full graph feat_dst is feat itself.
            rst = (1 + self.eps) * feat_dst + rst
            return rst
class GraphConv(nn.Module):
    r"""
    Node layer that sums (optionally edge-weighted) messages per destination type
    over every relation, adds the node's own input feature, and applies a shared
    linear map, an optional activation and dropout.

    Parameters
    ----------
    in_feats : int
        Input feature size.
    out_feats : int
        Output feature size.
    dropout : float
        Dropout probability applied to the layer output.
    activation : callable, optional
        Applied after the linear map when given.
    """

    def __init__(self,
                 in_feats,
                 out_feats, dropout,
                 activation=None,
                 ):
        super(GraphConv, self).__init__()
        self._in_feats = in_feats
        self._out_feats = out_feats
        self.weight1 = nn.Parameter(th.Tensor(in_feats, out_feats))
        self.reset_parameters()
        self.dropout = nn.Dropout(dropout)
        self.activation = activation

    def reset_parameters(self):
        r"""Xavier-uniform initialisation of the weight matrix."""
        nn.init.xavier_uniform_(self.weight1)

    def forward(self, hg, feat, edge_weight=None):
        r"""
        Parameters
        ----------
        hg : heterogeneous DGL graph.
        feat : dict
            Maps node type to its feature tensor.
        edge_weight : optional
            Stored as edge data ``'_edge_weight'`` and multiplied into the source
            features; when None the source features are copied unchanged.

        Returns
        -------
        dict
            Maps every node type that is the destination of at least one relation
            to its updated feature tensor.
        """
        with hg.local_scope():
            outputs = {}
            if edge_weight is not None:
                hg.edata['_edge_weight'] = edge_weight
                aggregate_fn = fn.u_mul_e('h', '_edge_weight', 'm')
            else:
                # copy_u replaces the deprecated copy_src alias.
                aggregate_fn = fn.copy_u('h', 'm')

            # Self loops are dropped for relations whose first two entries coincide;
            # the node's own feature is added back explicitly in _apply below.
            for e in hg.canonical_etypes:
                if e[0] == e[1]:
                    hg = dgl.remove_self_loop(hg, etype=e)
            feat_src, _ = expand_as_pair(feat, hg)

            # aggregate first then mult W
            hg.srcdata['h'] = feat_src
            for e in hg.canonical_etypes:
                stype, etype, dtype = e
                sub_graph = hg[stype, etype, dtype]
                sub_graph.update_all(aggregate_fn, fn.sum(msg='m', out='out'))
                temp = hg.ndata['out'].pop(dtype)
                if isinstance(temp, dict):
                    temp = temp[dtype]
                # Sum contributions of all relations that share a destination type.
                if outputs.get(dtype) is None:
                    outputs[dtype] = temp
                else:
                    outputs[dtype].add_(temp)

            # NOTE: no degree normalisation is applied to the aggregated messages.
            def _apply(ntype, h):
                h = th.matmul(h + feat[ntype], self.weight1)
                if self.activation:
                    h = self.activation(h)
                return self.dropout(h)

            return {ntype: _apply(ntype, h) for ntype, h in outputs.items()}