forked from naoto0804/pytorch-inpainting-with-partial-conv
-
Notifications
You must be signed in to change notification settings - Fork 0
/
net.py
214 lines (172 loc) · 7.84 KB
/
net.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
def weights_init(init_type='gaussian'):
    """Build an ``nn.Module.apply`` callback that initializes Conv*/Linear layers.

    Args:
        init_type: one of
            ``'gaussian'``   -- normal(mean=0, std=0.02),
            ``'xavier'``     -- Xavier-normal with gain sqrt(2),
            ``'kaiming'``    -- Kaiming-normal (a=0, mode='fan_in'),
            ``'orthogonal'`` -- orthogonal with gain sqrt(2),
            ``'default'``    -- keep the module's existing weights.

    Returns:
        A function ``init_fun(m)``. For every module whose class name starts
        with ``Conv`` or ``Linear`` and that has a ``weight`` attribute, it
        re-initializes the weight per ``init_type`` and sets the bias (if
        present) to zero. The bias is zeroed even for ``'default'``. Other
        modules (e.g. BatchNorm) are left untouched.

    Raises:
        ValueError: if ``init_type`` is not one of the names above. This is
            checked immediately, not on the first matching module, and uses a
            real exception so it is not stripped by ``python -O``.
    """
    initializers = {
        'gaussian': lambda w: nn.init.normal_(w, 0.0, 0.02),
        'xavier': lambda w: nn.init.xavier_normal_(w, gain=math.sqrt(2)),
        'kaiming': lambda w: nn.init.kaiming_normal_(w, a=0, mode='fan_in'),
        'orthogonal': lambda w: nn.init.orthogonal_(w, gain=math.sqrt(2)),
        'default': lambda w: None,
    }
    if init_type not in initializers:
        raise ValueError("Unsupported initialization: {}".format(init_type))
    init_weight = initializers[init_type]

    def init_fun(m):
        classname = m.__class__.__name__
        if classname.startswith(('Conv', 'Linear')) and hasattr(m, 'weight'):
            init_weight(m.weight)
            if getattr(m, 'bias', None) is not None:
                nn.init.constant_(m.bias, 0.0)

    return init_fun
class VGG16FeatureExtractor(nn.Module):
    """Frozen, pretrained VGG16 prefix that exposes three feature maps.

    The first 17 layers of ``torchvision.models.vgg16().features`` are split
    into three consecutive stages:
    ``enc_1`` (layers 0-4), ``enc_2`` (layers 5-9) and ``enc_3`` (layers 10-16).
    All of their parameters have ``requires_grad`` set to False, so the
    extractor is never trained.
    """

    def __init__(self):
        super().__init__()
        weights_enum = getattr(models, 'VGG16_Weights', None)
        if weights_enum is not None:
            # torchvision >= 0.13: ``pretrained=True`` is deprecated and is
            # documented as equivalent to these weights.
            vgg16 = models.vgg16(weights=weights_enum.IMAGENET1K_V1)
        else:
            # Older torchvision only understands the boolean flag.
            vgg16 = models.vgg16(pretrained=True)
        self.enc_1 = nn.Sequential(*vgg16.features[:5])
        self.enc_2 = nn.Sequential(*vgg16.features[5:10])
        self.enc_3 = nn.Sequential(*vgg16.features[10:17])

        # fix the encoder: these weights are never updated
        for stage in (self.enc_1, self.enc_2, self.enc_3):
            for param in stage.parameters():
                param.requires_grad = False

    def forward(self, image):
        """Run ``image`` through the three stages in sequence.

        Returns:
            A list ``[out_1, out_2, out_3]`` holding the output of ``enc_1``,
            ``enc_2`` and ``enc_3`` in order (the input itself is excluded).
        """
        results = [image]
        for i in range(3):
            func = getattr(self, 'enc_{:d}'.format(i + 1))
            results.append(func(results[-1]))
        return results[1:]
class PartialConv(nn.Module):
    """Partial convolution: a Conv2d that only "sees" valid (masked-in) pixels.

    Two convolutions with identical geometry are kept side by side:

    * ``input_conv`` is the trainable convolution. It is applied to
      ``input * mask`` and initialized with ``weights_init('kaiming')``.
    * ``mask_conv`` is frozen and bias-free, with an all-ones kernel. Its
      output is the sum of the mask over each receptive field, written sum(M).

    ``forward(input, mask)`` returns ``(output, new_mask)``:

    * Where sum(M) > 0, ``output = (input_conv(input * mask) - b) / sum(M) + b``,
      with ``b`` the conv bias (or 0 without bias).
    * Where sum(M) == 0 (a "hole"), ``output`` is 0.
    * ``new_mask`` has ``output``'s shape. It is 0 at holes and 1 elsewhere.

    ``mask`` must have ``in_channels`` channels because ``mask_conv`` consumes
    it directly.

    NOTE(review): the partial-convolution paper additionally scales by
    sum(1) (the window size); this implementation divides by sum(M) only.
    Changing that would alter the scale existing checkpoints were trained with.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True):
        super().__init__()
        geometry = (in_channels, out_channels, kernel_size, stride, padding,
                    dilation, groups)
        self.input_conv = nn.Conv2d(*geometry, bias=bias)
        self.mask_conv = nn.Conv2d(*geometry, bias=False)
        self.input_conv.apply(weights_init('kaiming'))
        # All-ones kernel: mask_conv(mask) is the per-window sum of the mask.
        torch.nn.init.constant_(self.mask_conv.weight, 1.0)
        # The mask kernel is fixed; it never receives gradient updates.
        for frozen in self.mask_conv.parameters():
            frozen.requires_grad_(False)

    def forward(self, input, mask):
        # http://masc.cs.gmu.edu/wiki/partialconv
        # C(X) = W^T * X + b, C(0) = b, D(M) = 1 * M + 0 = sum(M)
        # W^T * (M .* X) / sum(M) + b = [C(M .* X) - C(0)] / D(M) + C(0)
        raw = self.input_conv(input * mask)
        conv_bias = self.input_conv.bias
        if conv_bias is None:
            bias_map = torch.zeros_like(raw)
        else:
            bias_map = conv_bias.view(1, -1, 1, 1).expand_as(raw)

        with torch.no_grad():
            coverage = self.mask_conv(mask)
            holes = coverage == 0
            # Use 1 as the divisor at holes to avoid dividing by zero; those
            # positions are overwritten with 0 right after.
            denominator = coverage.masked_fill_(holes, 1.0)

        renormalized = (raw - bias_map) / denominator + bias_map
        output = renormalized.masked_fill_(holes, 0.0)
        new_mask = torch.ones_like(output).masked_fill_(holes, 0.0)
        return output, new_mask
class PCBActiv(nn.Module):
    """PartialConv -> optional BatchNorm2d -> optional activation.

    ``sample`` selects the convolution geometry (kernel, stride, padding):

    * ``'down-7'`` -> (7, 2, 3)
    * ``'down-5'`` -> (5, 2, 2)
    * ``'down-3'`` -> (3, 2, 1)
    * anything else -> (3, 1, 1)

    ``bn`` adds a BatchNorm2d after the convolution. ``activ`` is ``'relu'``,
    ``'leaky'`` (LeakyReLU with slope 0.2), or any other value (e.g. ``None``)
    for no activation. ``forward`` returns the processed features together
    with the updated mask produced by the partial convolution.
    """

    def __init__(self, in_ch, out_ch, bn=True, sample='none-3', activ='relu',
                 conv_bias=False):
        super().__init__()
        geometry_by_sample = {
            'down-7': (7, 2, 3),
            'down-5': (5, 2, 2),
            'down-3': (3, 2, 1),
        }
        kernel, stride, pad = geometry_by_sample.get(sample, (3, 1, 1))
        self.conv = PartialConv(in_ch, out_ch, kernel, stride, pad,
                                bias=conv_bias)
        if bn:
            self.bn = nn.BatchNorm2d(out_ch)
        if activ == 'leaky':
            self.activation = nn.LeakyReLU(negative_slope=0.2)
        elif activ == 'relu':
            self.activation = nn.ReLU()

    def forward(self, input, input_mask):
        features, mask = self.conv(input, input_mask)
        # Apply whichever optional stages were registered, in this order.
        for stage_name in ('bn', 'activation'):
            if hasattr(self, stage_name):
                features = getattr(self, stage_name)(features)
        return features, mask
class PConvUNet(nn.Module):
    """U-Net of partial-convolution blocks for masked image inpainting.

    Encoder ``enc_1 .. enc_{layer_size}``:
        stride-2 ``PCBActiv`` stages with 64, 128, 256 and 512 output
        channels. Every stage past the fourth also has 512. ``enc_1`` has no
        BatchNorm.
    Decoder ``dec_{layer_size} .. dec_1``:
        each stage upsamples the running features and mask by 2x. It then
        concatenates (along channels) the output of the encoder one level up,
        using the raw input / input mask for ``dec_1``. Finally it applies a
        stride-1 ``PCBActiv`` with LeakyReLU. ``dec_1`` maps back to
        ``input_channels`` with bias, no BatchNorm and no activation.

    Args:
        layer_size: number of encoder (and decoder) stages.
        input_channels: channels of the image, and of the mask.
        upsampling_mode: ``F.interpolate`` mode used for the features. Masks
            are always upsampled with ``'nearest'``.

    Because every encoder stage halves the resolution and every decoder stage
    doubles it before concatenating, the input height and width must be
    divisible by ``2 ** layer_size``. Otherwise the skip concatenation fails.

    Set ``freeze_enc_bn = True`` to keep encoder BatchNorm layers in eval
    mode during training (see ``train``).
    """

    def __init__(self, layer_size=7, input_channels=3, upsampling_mode='nearest'):
        super().__init__()
        self.freeze_enc_bn = False
        self.upsampling_mode = upsampling_mode
        self.layer_size = layer_size

        # Encoder: every stage halves the spatial resolution.
        self.enc_1 = PCBActiv(input_channels, 64, bn=False, sample='down-7')
        self.enc_2 = PCBActiv(64, 128, sample='down-5')
        self.enc_3 = PCBActiv(128, 256, sample='down-5')
        self.enc_4 = PCBActiv(256, 512, sample='down-3')
        for i in range(4, self.layer_size):
            name = 'enc_{:d}'.format(i + 1)
            setattr(self, name, PCBActiv(512, 512, sample='down-3'))

        # Decoder: input channels = upsampled features + skip connection.
        for i in range(4, self.layer_size):
            name = 'dec_{:d}'.format(i + 1)
            setattr(self, name, PCBActiv(512 + 512, 512, activ='leaky'))
        self.dec_4 = PCBActiv(512 + 256, 256, activ='leaky')
        self.dec_3 = PCBActiv(256 + 128, 128, activ='leaky')
        self.dec_2 = PCBActiv(128 + 64, 64, activ='leaky')
        self.dec_1 = PCBActiv(64 + input_channels, input_channels,
                              bn=False, activ=None, conv_bias=True)

    def forward(self, input, input_mask):
        """Inpaint ``input`` where ``input_mask`` is 0.

        Returns:
            ``(output, output_mask)`` as produced by ``dec_1``. ``output``
            has ``input_channels`` channels.
        """
        # feats[i] / masks[i] hold the output of enc_i; index 0 is the input.
        feats, masks = [input], [input_mask]
        for i in range(1, self.layer_size + 1):
            encoder = getattr(self, 'enc_{:d}'.format(i))
            h, h_mask = encoder(feats[-1], masks[-1])
            feats.append(h)
            masks.append(h_mask)

        # dec_i: upsample the deeper result, concat the skip from level
        # i - 1 (the raw input for i == 1), then apply the partial conv.
        h, h_mask = feats[-1], masks[-1]
        for i in range(self.layer_size, 0, -1):
            h = F.interpolate(h, scale_factor=2, mode=self.upsampling_mode)
            h_mask = F.interpolate(h_mask, scale_factor=2, mode='nearest')
            h = torch.cat([h, feats[i - 1]], dim=1)
            h_mask = torch.cat([h_mask, masks[i - 1]], dim=1)
            h, h_mask = getattr(self, 'dec_{:d}'.format(i))(h, h_mask)
        return h, h_mask

    def train(self, mode=True):
        """Set train/eval mode, optionally keeping encoder BatchNorm frozen.

        When ``self.freeze_enc_bn`` is True, every ``nn.BatchNorm2d`` whose
        qualified name contains ``'enc'`` is switched back to eval mode. Its
        running statistics therefore stop updating. Only the mode flag
        changes; ``requires_grad`` is untouched.

        Returns:
            ``self``, like ``nn.Module.train``. ``nn.Module.eval`` returns
            ``self.train(False)``, so without this both ``model.train()`` and
            ``model.eval()`` would return None.
        """
        super().train(mode)
        if self.freeze_enc_bn:
            for name, module in self.named_modules():
                if isinstance(module, nn.BatchNorm2d) and 'enc' in name:
                    module.eval()
        return self
if __name__ == '__main__':
    # Smoke test: a PartialConv over a partly masked 5x5 input must yield
    # NaN-free gradients for the input and for its own weight and bias.
    shape = (1, 3, 5, 5)
    x = torch.ones(shape)
    x_mask = torch.ones(shape)
    x_mask[:, :, 2:, 2:] = 0  # mask out the bottom-right 3x3 corner
    pconv = PartialConv(3, 3, 3, 1, 1)
    criterion = nn.L1Loss()
    x.requires_grad = True
    out, out_mask = pconv(x, x_mask)
    target = torch.randn(1, 3, 5, 5)
    criterion(out, target).backward()
    assert not torch.isnan(x.grad).any().item()
    assert not torch.isnan(pconv.input_conv.weight.grad).any().item()
    assert not torch.isnan(pconv.input_conv.bias.grad).any().item()
    # Full-model check (disabled; the default 7-level model needs H and W
    # divisible by 2 ** 7, which a 5x5 input is not):
    # model = PConvUNet()
    # output, output_mask = model(x, x_mask)