# SRGAN.py
import torch
import torch.nn as nn


class ConvBlock(nn.Module):
    # Conv --> BN --> LeakyReLU (discriminator) / PReLU (generator)
    def __init__(self,
                 in_channels,
                 out_channels,
                 use_bn=True,
                 use_act=True,
                 discriminator=False,
                 **kwargs):
        super().__init__()
        self.use_act = use_act
        # The conv bias is redundant when BatchNorm follows, so drop it in that case.
        self.cnn = nn.Conv2d(in_channels, out_channels, **kwargs, bias=not use_bn)
        self.bn = nn.BatchNorm2d(out_channels) if use_bn else nn.Identity()
        self.act = (
            nn.LeakyReLU(0.2, inplace=True)
            if discriminator
            else nn.PReLU(num_parameters=out_channels)
        )

    def forward(self, x):
        return self.act(self.bn(self.cnn(x))) if self.use_act else self.bn(self.cnn(x))


class UpSampleBlock(nn.Module):
    def __init__(self, in_c, scale_factor):
        super().__init__()
        # Expand channels by scale_factor**2 so PixelShuffle can fold them into space.
        self.conv = nn.Conv2d(in_c, in_c * scale_factor ** 2, kernel_size=3, stride=1, padding=1)
        self.ps = nn.PixelShuffle(scale_factor)
        self.act = nn.PReLU(num_parameters=in_c)

    def forward(self, x):
        return self.act(self.ps(self.conv(x)))
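
# Shape intuition (illustrative sizes): the conv expands channels from C to
# C * r**2, then PixelShuffle folds those channels back into space,
# (N, C * r**2, H, W) -> (N, C, H * r, W * r).
# >>> up = UpSampleBlock(64, 2)
# >>> up(torch.randn(1, 64, 24, 24)).shape
# torch.Size([1, 64, 48, 48])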


class ResidualBlock(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.block1 = ConvBlock(
            in_channels,
            in_channels,
            kernel_size=3,
            stride=1,
            padding=1
        )
        # No activation on the second conv: the skip addition comes first.
        self.block2 = ConvBlock(
            in_channels,
            in_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            use_act=False
        )

    def forward(self, x):
        out = self.block1(x)
        out = self.block2(out)
        return out + x


class Generator(nn.Module):
    def __init__(self, in_channels=3, num_channels=64, num_blocks=16):
        super().__init__()
        self.initial = ConvBlock(in_channels, num_channels, kernel_size=9, stride=1, padding=4, use_bn=False)
        self.residual = nn.Sequential(*[ResidualBlock(num_channels) for _ in range(num_blocks)])
        self.convblock = ConvBlock(num_channels, num_channels, kernel_size=3, stride=1, padding=1, use_act=False)
        # Two 2x PixelShuffle stages give a fixed 4x upscale.
        self.upsample = nn.Sequential(UpSampleBlock(num_channels, 2), UpSampleBlock(num_channels, 2))
        self.final = nn.Conv2d(num_channels, in_channels, kernel_size=9, stride=1, padding=4)

    def forward(self, x):
        initial = self.initial(x)
        x = self.residual(initial)
        # Global skip connection around the residual trunk.
        x = self.convblock(x) + initial
        x = self.upsample(x)
        return torch.tanh(self.final(x))
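
# End-to-end shape check for the Generator (illustrative sizes): with two
# PixelShuffle(2) stages the upscale is a fixed 4x, so a 24x24 low-resolution
# input comes out as 96x96.
# >>> gen = Generator()
# >>> gen(torch.randn(1, 3, 24, 24)).shape
# torch.Size([1, 3, 96, 96])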


class Discriminator(nn.Module):
    def __init__(self, in_channels=3, features=[64, 64, 128, 128, 256, 256, 512, 512]):
        super().__init__()
        blocks = []
        for idx, feature in enumerate(features):
            # Alternate stride 1 / stride 2; skip BatchNorm on the first block.
            blocks.append(
                ConvBlock(
                    in_channels,
                    feature,
                    kernel_size=3,
                    stride=1 + idx % 2,
                    padding=1,
                    discriminator=True,
                    use_act=True,
                    use_bn=idx != 0,
                )
            )
            in_channels = feature
        self.blocks = nn.Sequential(*blocks)
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d((6, 6)),
            nn.Flatten(),
            nn.Linear(512 * 6 * 6, 1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(1024, 1)
        )

    def forward(self, x):
        x = self.blocks(x)
        return self.classifier(x)
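
# The stride pattern (1, 2, 1, 2, ...) halves the spatial size four times
# (96 -> 48 -> 24 -> 12 -> 6), and AdaptiveAvgPool2d((6, 6)) pins the
# classifier input to 512 * 6 * 6 for any reasonable input resolution.
# The output is a raw logit, so a typical pairing (illustrative; `fake` is a
# stand-in for a batch of generator outputs) would be:
# >>> criterion = nn.BCEWithLogitsLoss()
# >>> loss = criterion(disc(fake), torch.zeros(fake.size(0), 1))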


def test():
    low_resolution = 24
    # autocast is disabled (with a warning) on CPU-only machines, so this
    # smoke test still runs in fp32 without CUDA.
    with torch.cuda.amp.autocast():
        x = torch.randn((5, 3, low_resolution, low_resolution))
        gen = Generator()
        gen_out = gen(x)
        disc = Discriminator()
        disc_out = disc(gen_out)
        print(disc_out.shape)
        print(gen_out.shape)


if __name__ == '__main__':
    test()