-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathDFCN.py
71 lines (58 loc) · 2.49 KB
/
DFCN.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import torch
from torch import nn
from torch.nn import Parameter
import torch.nn.functional as F
from AE import AE
from IGAE import IGAE
class DFCN(nn.Module):
    """Fuse an ``AE`` and an ``IGAE`` latent code and soft-assign it to clusters.

    The ``AE`` encoder sees the feature matrix alone, the ``IGAE`` encoder sees
    the features together with ``adj``. Their codes are mixed with a learnable
    element-wise weight ``a`` (and its complement ``b = 1 - a``), aggregated
    over ``adj``, refined with a softmax-similarity global term scaled by the
    learnable ``gamma``, decoded by both sub-models, and compared against
    learnable cluster centres with a Student's t kernel.

    Args:
        ae_n_enc_1, ae_n_enc_2, ae_n_enc_3, ae_n_dec_1, ae_n_dec_2, ae_n_dec_3:
            layer widths forwarded unchanged to ``AE``.
        gae_n_enc_1, gae_n_enc_2, gae_n_enc_3, gae_n_dec_1, gae_n_dec_2,
        gae_n_dec_3: layer widths forwarded unchanged to ``IGAE``.
        n_input: input feature dimension, forwarded to both sub-models.
        n_z: width of the fusion weight ``a`` and of each cluster centre.
        n_clusters: number of cluster centres.
        v: degrees of freedom of the Student's t kernel.
        n_node: number of rows of the fusion weight ``a``; required despite
            the ``None`` default (kept for signature compatibility).
        device: device on which ``a`` is allocated (``None`` = default device).

    Raises:
        TypeError: if ``n_node`` is not supplied.
    """

    def __init__(self, ae_n_enc_1, ae_n_enc_2, ae_n_enc_3,
                 ae_n_dec_1, ae_n_dec_2, ae_n_dec_3,
                 gae_n_enc_1, gae_n_enc_2, gae_n_enc_3,
                 gae_n_dec_1, gae_n_dec_2, gae_n_dec_3,
                 n_input, n_z, n_clusters, v=1.0, n_node=None, device=None):
        super(DFCN, self).__init__()
        if n_node is None:
            raise TypeError("DFCN requires n_node (number of rows of the fusion weight 'a')")

        self.ae = AE(
            ae_n_enc_1=ae_n_enc_1,
            ae_n_enc_2=ae_n_enc_2,
            ae_n_enc_3=ae_n_enc_3,
            ae_n_dec_1=ae_n_dec_1,
            ae_n_dec_2=ae_n_dec_2,
            ae_n_dec_3=ae_n_dec_3,
            n_input=n_input,
            n_z=n_z)
        self.gae = IGAE(
            gae_n_enc_1=gae_n_enc_1,
            gae_n_enc_2=gae_n_enc_2,
            gae_n_enc_3=gae_n_enc_3,
            gae_n_dec_1=gae_n_dec_1,
            gae_n_dec_2=gae_n_dec_2,
            gae_n_dec_3=gae_n_dec_3,
            n_input=n_input)

        # Allocate on the target device *before* wrapping in Parameter:
        # calling ``.to(device)`` on a Parameter returns a non-leaf plain
        # tensor, which would silently drop ``a`` from ``parameters()`` and
        # leave it untrained.
        self.a = nn.Parameter(torch.full((n_node, n_z), 0.5, device=device), requires_grad=True)

        self.cluster_layer = nn.Parameter(torch.Tensor(n_clusters, n_z), requires_grad=True)
        torch.nn.init.xavier_normal_(self.cluster_layer.data)
        self.v = v
        # Starts at 0, so the global term in forward() is initially off.
        self.gamma = Parameter(torch.zeros(1))

    @property
    def b(self):
        """Complementary fusion weight ``1 - a``, recomputed from the current ``a``."""
        return 1 - self.a

    @staticmethod
    def _soft_assign(z, centers, v):
        """Return row-normalised Student's t similarities between rows of ``z`` and ``centers``.

        Computes ``(1 + ||z_i - mu_j||^2 / v) ** (-(v + 1) / 2)`` and
        normalises each row to sum to 1.
        """
        sq_dist = torch.sum(torch.pow(z.unsqueeze(1) - centers, 2), 2)
        q = 1.0 / (1.0 + sq_dist / v)
        q = q.pow((v + 1.0) / 2.0)
        return (q.t() / torch.sum(q, 1)).t()

    def forward(self, x, adj):
        """Encode, fuse, decode and soft-cluster one graph.

        Args:
            x: feature matrix fed to both encoders; presumably one row per
                node so it lines up with the ``n_node`` rows of ``a`` — TODO confirm.
            adj: adjacency matrix used with ``torch.spmm`` (likely sparse —
                verify against caller) and passed to the IGAE encoder/decoder.

        Returns:
            ``(x_hat, z_hat, adj_hat, z_ae, z_igae, q, q1, q2, z_tilde)`` where
            ``x_hat`` is the AE reconstruction, ``z_hat`` the first IGAE
            decoder output, ``adj_hat`` the sum of the IGAE encoder's and
            decoder's second outputs, ``q``/``q1``/``q2`` the soft assignments
            of ``z_tilde``/``z_ae``/``z_igae``, and ``z_tilde`` the fused code.
        """
        z_ae = self.ae.encoder(x)
        z_igae, z_igae_adj = self.gae.encoder(x, adj)

        # Element-wise convex-style mix; b is derived from a on every call.
        z_i = self.a * z_ae + self.b * z_igae
        # Local aggregation of the fused codes over the graph.
        z_l = torch.spmm(adj, z_i)
        # Global refinement: row-softmax similarity between all rows of z_l.
        s = F.softmax(torch.mm(z_l, z_l.t()), dim=1)
        z_g = torch.mm(s, z_l)
        z_tilde = self.gamma * z_g + z_l

        x_hat = self.ae.decoder(z_tilde)
        z_hat, z_hat_adj = self.gae.decoder(z_tilde, adj)
        adj_hat = z_igae_adj + z_hat_adj

        q = self._soft_assign(z_tilde, self.cluster_layer, self.v)
        q1 = self._soft_assign(z_ae, self.cluster_layer, self.v)
        q2 = self._soft_assign(z_igae, self.cluster_layer, self.v)
        return x_hat, z_hat, adj_hat, z_ae, z_igae, q, q1, q2, z_tilde