inter_contra_model.py
import torch
import torch.nn as nn
import torch.nn.functional as F


def weights_init(m):
    """Initialize Linear and BatchNorm layers with a DCGAN-style scheme."""
    classname = m.__class__.__name__
    if classname.find('Linear') != -1:
        m.weight.data.normal_(0.0, 0.02)
        m.bias.data.fill_(0)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)


def reparameter(mu, sigma):
    """Reparameterization trick: sample from N(mu, sigma^2) via a standard normal draw."""
    return (torch.randn_like(mu) * sigma) + mu
class Embedding_Net(nn.Module):
    """Maps visual features to a hidden embedding and an L2-normalized projection."""

    def __init__(self, opt):
        super(Embedding_Net, self).__init__()
        self.fc1 = nn.Linear(opt.resSize, opt.embedSize)
        self.fc2 = nn.Linear(opt.embedSize, opt.outzSize)
        self.lrelu = nn.LeakyReLU(0.2, True)
        self.relu = nn.ReLU(True)
        self.apply(weights_init)

    def forward(self, features):
        embedding = self.relu(self.fc1(features))
        out_z = F.normalize(self.fc2(embedding), dim=1)
        return embedding, out_z
class MLP_G(nn.Module):
    """Conditional generator: synthesizes visual features from noise and class attributes."""

    def __init__(self, opt):
        super(MLP_G, self).__init__()
        self.fc1 = nn.Linear(opt.attSize + opt.nz, opt.ngh)
        self.fc2 = nn.Linear(opt.ngh, opt.resSize)
        self.lrelu = nn.LeakyReLU(0.2, True)
        # self.prelu = nn.PReLU()
        self.relu = nn.ReLU(True)
        self.apply(weights_init)

    def forward(self, noise, att):
        h = torch.cat((noise, att), 1)
        h = self.lrelu(self.fc1(h))
        h = self.relu(self.fc2(h))
        return h
class MLP_CRITIC(nn.Module):
    """Conditional critic: scores a (visual feature, attribute) pair with a single real-valued output."""

    def __init__(self, opt):
        super(MLP_CRITIC, self).__init__()
        self.fc1 = nn.Linear(opt.resSize + opt.attSize, opt.ndh)
        # self.fc2 = nn.Linear(opt.ndh, opt.ndh)
        self.fc2 = nn.Linear(opt.ndh, 1)
        self.lrelu = nn.LeakyReLU(0.2, True)
        self.apply(weights_init)

    def forward(self, x, att):
        h = torch.cat((x, att), 1)
        h = self.lrelu(self.fc1(h))
        h = self.fc2(h)
        return h
class Dis_Embed_Att(nn.Module):
    """Discriminator over a concatenated (embedding, attribute) vector; outputs a single score."""

    def __init__(self, opt):
        super(Dis_Embed_Att, self).__init__()
        self.fc1 = nn.Linear(opt.embedSize + opt.attSize, opt.nhF)
        # self.fc2 = nn.Linear(opt.ndh, opt.ndh)
        self.fc2 = nn.Linear(opt.nhF, 1)
        self.lrelu = nn.LeakyReLU(0.2, True)
        self.apply(weights_init)

    def forward(self, input):
        h = self.lrelu(self.fc1(input))
        h = self.fc2(h)
        return h
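
A minimal usage sketch follows. It assumes an `opt` namespace carrying the hyperparameter fields the modules read (`resSize`, `attSize`, `embedSize`, `outzSize`, `nz`, `ngh`, `ndh`, `nhF`); the concrete values below are illustrative placeholders, not the repository's actual defaults.

# Usage sketch (assumption: opt fields/values are placeholders for illustration only).
from argparse import Namespace

opt = Namespace(resSize=2048, attSize=312, embedSize=1024, outzSize=512,
                nz=312, ngh=4096, ndh=4096, nhF=2048)

netE = Embedding_Net(opt)
netG = MLP_G(opt)
netD = MLP_CRITIC(opt)
netDE = Dis_Embed_Att(opt)

batch = 64
features = torch.randn(batch, opt.resSize)            # real visual features
att = torch.randn(batch, opt.attSize)                 # class attribute vectors
noise = torch.randn(batch, opt.nz)                    # generator noise

fake = netG(noise, att)                               # synthesized features: (batch, resSize)
embedding, out_z = netE(features)                     # hidden embedding and normalized projection
critic_score = netD(features, att)                    # critic score: (batch, 1)
match_score = netDE(torch.cat((embedding, att), 1))   # embedding-attribute compatibility score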