IGAE.py
import opt  # project-local config module; exposes the parsed CLI arguments as opt.args
import torch
from torch import nn
import torch.nn.functional as F
from torch.nn import Module, Parameter


class GNNLayer(Module):
    """One graph-convolution layer: a bias-free linear transform followed by
    sparse propagation with the (normalized) adjacency matrix."""

    def __init__(self, in_features, out_features):
        super(GNNLayer, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        # The activation and the weight orientation are dataset-specific:
        # dblp/hhar store the weight transposed and apply it with F.linear,
        # the remaining datasets apply it with torch.mm.
        if opt.args.name == "dblp" or opt.args.name == "hhar":
            self.act = nn.Tanh()
            self.weight = Parameter(torch.FloatTensor(out_features, in_features))
        elif opt.args.name == "usps" or opt.args.name == "acm" or opt.args.name == "cite":
            self.act = nn.Tanh()
            self.weight = Parameter(torch.FloatTensor(in_features, out_features))
        elif opt.args.name == "reut":
            self.act = nn.LeakyReLU(0.2, inplace=True)
            self.weight = Parameter(torch.FloatTensor(in_features, out_features))
        else:
            # Fail loudly instead of hitting an AttributeError below.
            raise ValueError("unsupported dataset name: %s" % opt.args.name)
        torch.nn.init.xavier_uniform_(self.weight)

    def forward(self, features, adj, active=False):
        # Linear transform, optionally followed by the dataset-specific
        # activation; no bias term is used.
        if active:
            if opt.args.name == "dblp" or opt.args.name == "hhar":
                support = self.act(F.linear(features, self.weight))
            else:  # usps, acm, cite, reut
                support = self.act(torch.mm(features, self.weight))
        else:
            if opt.args.name == "dblp" or opt.args.name == "hhar":
                support = F.linear(features, self.weight)
            else:  # usps, acm, cite, reut
                support = torch.mm(features, self.weight)
        # Propagate over the graph: sparse adjacency times dense features.
        output = torch.spmm(adj, support)
        return output


class IGAE_encoder(nn.Module):
    """Three-layer GNN encoder; also returns an inner-product reconstruction
    of the adjacency computed from the latent embedding."""

    def __init__(self, gae_n_enc_1, gae_n_enc_2, gae_n_enc_3, n_input):
        super(IGAE_encoder, self).__init__()
        self.gnn_1 = GNNLayer(n_input, gae_n_enc_1)
        self.gnn_2 = GNNLayer(gae_n_enc_1, gae_n_enc_2)
        self.gnn_3 = GNNLayer(gae_n_enc_2, gae_n_enc_3)
        self.s = nn.Sigmoid()

    def forward(self, x, adj):
        # hhar skips the intermediate activations; all other datasets use them.
        z = self.gnn_1(x, adj, active=opt.args.name != "hhar")
        z = self.gnn_2(z, adj, active=opt.args.name != "hhar")
        z_igae = self.gnn_3(z, adj, active=False)
        # Inner-product decoder on the embedding: sigmoid(Z Z^T).
        z_igae_adj = self.s(torch.mm(z_igae, z_igae.t()))
        return z_igae, z_igae_adj


class IGAE_decoder(nn.Module):
    """Three-layer GNN decoder that reconstructs the node features (z_hat)
    and an adjacency estimate from that reconstruction."""

    def __init__(self, gae_n_dec_1, gae_n_dec_2, gae_n_dec_3, n_input):
        super(IGAE_decoder, self).__init__()
        self.gnn_4 = GNNLayer(gae_n_dec_1, gae_n_dec_2)
        self.gnn_5 = GNNLayer(gae_n_dec_2, gae_n_dec_3)
        self.gnn_6 = GNNLayer(gae_n_dec_3, n_input)
        self.s = nn.Sigmoid()

    def forward(self, z_igae, adj):
        z = self.gnn_4(z_igae, adj, active=opt.args.name != "hhar")
        z = self.gnn_5(z, adj, active=opt.args.name != "hhar")
        z_hat = self.gnn_6(z, adj, active=opt.args.name != "hhar")
        z_hat_adj = self.s(torch.mm(z_hat, z_hat.t()))
        return z_hat, z_hat_adj


class IGAE(nn.Module):
    """Improved Graph Auto-Encoder: encoder plus decoder; the reconstructed
    adjacency is the sum of the encoder- and decoder-side estimates."""

    def __init__(self, gae_n_enc_1, gae_n_enc_2, gae_n_enc_3,
                 gae_n_dec_1, gae_n_dec_2, gae_n_dec_3, n_input):
        super(IGAE, self).__init__()
        self.encoder = IGAE_encoder(
            gae_n_enc_1=gae_n_enc_1,
            gae_n_enc_2=gae_n_enc_2,
            gae_n_enc_3=gae_n_enc_3,
            n_input=n_input)
        self.decoder = IGAE_decoder(
            gae_n_dec_1=gae_n_dec_1,
            gae_n_dec_2=gae_n_dec_2,
            gae_n_dec_3=gae_n_dec_3,
            n_input=n_input)

    def forward(self, x, adj):
        z_igae, z_igae_adj = self.encoder(x, adj)
        z_hat, z_hat_adj = self.decoder(z_igae, adj)
        adj_hat = z_igae_adj + z_hat_adj
        return z_igae, z_hat, adj_hat
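

# ---------------------------------------------------------------------------
# Minimal smoke-test sketch (not part of the original file). It assumes that
# GNNLayer only reads opt.args.name, stubs that field to "acm", and uses
# made-up layer sizes and a random toy graph; the repo's own training script
# remains the authoritative entry point.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import argparse

    # opt.args is normally populated by the repo's CLI parser; stub the one
    # field GNNLayer reads so it picks its weight orientation and activation.
    opt.args = argparse.Namespace(name="acm")

    torch.manual_seed(0)
    n_nodes, n_input = 6, 10

    # Toy layer sizes, chosen here for illustration only.
    model = IGAE(gae_n_enc_1=32, gae_n_enc_2=16, gae_n_enc_3=8,
                 gae_n_dec_1=8, gae_n_dec_2=16, gae_n_dec_3=32,
                 n_input=n_input)

    x = torch.rand(n_nodes, n_input)

    # Symmetric random graph with self-loops, symmetrically normalized and
    # converted to a sparse tensor so torch.spmm in GNNLayer.forward applies.
    a = (torch.rand(n_nodes, n_nodes) > 0.5).float()
    a = ((a + a.t()) > 0).float()
    a.fill_diagonal_(1.0)
    d_inv_sqrt = a.sum(1).pow(-0.5)
    adj = (d_inv_sqrt.unsqueeze(1) * a * d_inv_sqrt.unsqueeze(0)).to_sparse()

    z_igae, z_hat, adj_hat = model(x, adj)
    print(z_igae.shape, z_hat.shape, adj_hat.shape)  # (6, 8), (6, 10), (6, 6)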