-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathupdater.py
149 lines (117 loc) · 5.91 KB
/
updater.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
import numpy as np
import chainer
import chainer.functions as F
import chainer.links as L
import six
import os
from chainer import cuda, optimizers, serializers, Variable
from chainer import training
from PIL import Image
import chainer.computational_graph as c
def calc_loss_perceptual(hout_dict, hcomp_dict, hgt_dict):
    """Perceptual loss: summed L1 distances between VGG feature maps.

    For every extracted layer, compares both the raw network output
    (``hout_dict``) and the composited output (``hcomp_dict``) against the
    ground-truth features (``hgt_dict``). All three dicts map layer name ->
    feature Variable and are assumed to share the same keys.
    """
    total = None
    for name in hout_dict.keys():
        err_out = F.mean_absolute_error(hout_dict[name], hgt_dict[name])
        err_comp = F.mean_absolute_error(hcomp_dict[name], hgt_dict[name])
        if total is None:
            total = err_out + err_comp
        else:
            # keep left-to-right accumulation order of the original
            total = total + err_out + err_comp
    return total
def vgg_extract(vgg_model, Img, layers=['pool1', 'pool2', 'pool3'], in_size=224):
    """Extract VGG features from a random in_size x in_size crop of each image.

    Args:
        vgg_model: VGG feature extractor called as ``vgg_model(x, layers=...)``.
        Img: batch Variable of shape (B, C, H, W); values in [-1, 1]
            (mapped below to [0, 255] before mean subtraction).
            NOTE(review): channel order is assumed BGR to match the
            [103.939, 116.779, 123.68] means — confirm against the loader.
        layers: VGG layer names to return.
        in_size: crop side length; requires H >= in_size and W >= in_size.

    Returns:
        dict mapping layer name -> feature Variable for the cropped batch.
    """
    B, C, H, W = Img.shape
    # [-1, 1] -> [0, 255]
    Img = (Img + 1) * 127.5
    # subtract the per-channel VGG means [103.939, 116.779, 123.68]
    Img_chanells = [F.expand_dims(Img[:, i, :, :], axis=1) for i in range(3)]
    Img_chanells[0] -= 103.939
    Img_chanells[1] -= 116.779
    Img_chanells[2] -= 123.68
    Img = F.concat(Img_chanells, axis=1)
    limx = H - in_size
    limy = W - in_size
    # high bound is exclusive in randint: use lim+1 so every valid offset
    # (including lim itself) can be sampled, and so lim == 0 (image exactly
    # in_size) no longer raises ValueError.
    xs = np.random.randint(0, limx + 1, B)
    ys = np.random.randint(0, limy + 1, B)
    lis = [F.expand_dims(Img[i, :, x:x + in_size, y:y + in_size], axis=0)
           for i, (x, y) in enumerate(zip(xs, ys))]
    Img_cropped = F.concat(lis, axis=0)
    return vgg_model(Img_cropped, layers=layers)
def calc_loss_style(hout_dict, hcomp_dict, hgt_dict):
    """Style loss: L1 distance between Gram matrices of VGG features.

    For every layer in ``hgt_dict``, flattens the (B, C, H, W) features to
    (B, C, H*W), forms Gram matrices via batched matmul, and accumulates the
    mean absolute error of output-vs-gt and comp-vs-gt Grams, each scaled by
    1 / (C * H * W). Returns the sum of the two accumulated losses.
    """
    loss_out = None
    loss_comp = None
    for name in hgt_dict.keys():
        B, C, H, W = hout_dict[name].shape
        denom = C * H * W
        flat_out = F.reshape(hout_dict[name], (B, C, H * W))
        flat_comp = F.reshape(hcomp_dict[name], (B, C, H * W))
        flat_gt = F.reshape(hgt_dict[name], (B, C, H * W))
        gram_out = F.batch_matmul(flat_out, flat_out, transb=True)
        gram_comp = F.batch_matmul(flat_comp, flat_comp, transb=True)
        gram_gt = F.batch_matmul(flat_gt, flat_gt, transb=True)
        term_out = F.mean_absolute_error(gram_out, gram_gt) / denom
        term_comp = F.mean_absolute_error(gram_comp, gram_gt) / denom
        if loss_out is None:
            loss_out = term_out
            loss_comp = term_comp
        else:
            loss_out += term_out
            loss_comp += term_comp
    return loss_out + loss_comp
def calc_loss_tv(Icomp, mask, xp=np):
    """Total-variation loss over the 1-pixel-dilated hole region of Icomp.

    Builds P, a binary mask that is 1 wherever the hole mask or any of its
    4-neighbours is set (1-pixel dilation), then penalizes horizontal and
    vertical differences of Icomp inside P.

    Args:
        Icomp: composited image Variable, (B, C, H, W).
        mask: 0-1 mask Variable of the same shape (1 = valid pixel).
        xp: array module (numpy or cupy) matching the data's device.

    Returns:
        scalar loss Variable.
    """
    # BUG FIX: the original did `canvas = mask.data`, aliasing the mask's
    # backing array — the in-place `+=` dilation below then corrupted the
    # caller's mask. Copy so the dilation stays local to this function.
    canvas = mask.data.copy()
    canvas[:, :, :, :-1] += mask.data[:, :, :, 1:]   # shift-left overlap
    canvas[:, :, :, 1:] += mask.data[:, :, :, :-1]   # shift-right overlap
    canvas[:, :, :-1, :] += mask.data[:, :, 1:, :]   # shift-up overlap
    canvas[:, :, 1:, :] += mask.data[:, :, :-1, :]   # shift-down overlap
    # threshold at 0.5: 1 where any contribution landed, 0 elsewhere
    P = Variable((xp.sign(canvas - 0.5) + 1.0) * 0.5)
    return F.mean_absolute_error(P[:, :, :, 1:] * Icomp[:, :, :, 1:],
                                 P[:, :, :, :-1] * Icomp[:, :, :, :-1]) \
         + F.mean_absolute_error(P[:, :, 1:, :] * Icomp[:, :, 1:, :],
                                 P[:, :, :-1, :] * Icomp[:, :, :-1, :])
def imgcrop_batch(img, pos_list, size=128):
    """Crop a size x size patch from each image in the batch.

    Args:
        img: Variable of shape (B, C, H, W).
        pos_list: per-image (x, y) top-left corners, one pair per batch item.
        size: side length of the square crop.

    Returns:
        Variable of shape (B, C, size, size) with the stacked crops.
    """
    num, ch, full_h, full_w = img.shape
    crops = []
    for idx, (px, py) in enumerate(pos_list):
        patch = img[idx, :, px:px + size, py:py + size]
        crops.append(F.expand_dims(patch, axis=0))
    return F.concat(crops, axis=0)
class Updater(chainer.training.StandardUpdater):
    """Training updater for an inpainting model with VGG-based losses.

    Expects ``kwargs['models'] = (vgg, model)`` — a frozen VGG feature
    extractor and the inpainting network — and ``kwargs['params']`` with the
    loss weights lambda1..lambda4, 'image_size', 'eval_folder' and 'dataset'.
    The total loss is
        L_valid + l1*L_hole + l2*L_perceptual + l3*L_style + l4*L_tv.
    """

    def __init__(self, *args, **kwargs):
        self.vgg, self.model = kwargs.pop('models')
        params = kwargs.pop('params')
        self._lambda1 = params['lambda1']      # hole-loss weight
        self._lambda2 = params['lambda2']      # perceptual-loss weight
        self._lambda3 = params['lambda3']      # style-loss weight
        self._lambda4 = params['lambda4']      # total-variation-loss weight
        self._image_size = params['image_size']
        # NOTE(review): attribute keeps the original 'foler' typo in case code
        # outside this file reads it; consider a project-wide rename.
        self._eval_foler = params['eval_folder']
        self._dataset = params['dataset']
        self._iter = 0
        super(Updater, self).__init__(*args, **kwargs)

    def update_core(self):
        """Run one training iteration: forward, losses, backward, update."""
        xp = self.model.xp
        self._iter += 1
        # each batch item: (original image (3,H,W), 0-1 mask broadcast to (3,H,W))
        batch = self.get_iterator('main').next()
        batchsize = len(batch)
        w_in = self._image_size
        x_train = np.zeros((batchsize, 3, w_in, w_in)).astype("f")
        mask_train = np.zeros((batchsize, 3, w_in, w_in)).astype("f")
        for i in range(batchsize):
            x_train[i, :] = batch[i][0]      # original image
            mask_train[i, :] = batch[i][1]   # 0-1 mask (1 = valid pixel)
        x_train = xp.array(x_train)
        mask_train = xp.array(mask_train)
        mask_b = xp.array(mask_train.astype("bool"))
        I_gt = Variable(x_train)
        M = Variable(mask_train)
        M_b = Variable(mask_b)
        I_out = self.model(I_gt, M)
        # composited output: ground truth where the mask is valid, prediction in holes
        I_comp = F.where(M_b, I_gt, I_out)
        fs_I_gt = vgg_extract(self.vgg, I_gt)      # feature dict
        fs_I_out = vgg_extract(self.vgg, I_out)    # feature dict
        fs_I_comp = vgg_extract(self.vgg, I_comp)  # feature dict
        opt_model = self.get_optimizer('model')
        L_valid = F.mean_absolute_error(M * I_out, M * I_gt)
        L_hole = F.mean_absolute_error((1 - M) * I_out, (1 - M) * I_gt)
        # BUG FIX: arguments were passed as (gt, out, comp), but the function
        # signature is (hout, hcomp, hgt) — the original therefore compared
        # gt and out against the *composited* features and never computed the
        # out-vs-gt term. Pass (out, comp, gt) to match the signature.
        L_perceptual = calc_loss_perceptual(fs_I_out, fs_I_comp, fs_I_gt)
        L_style = calc_loss_style(fs_I_out, fs_I_comp, fs_I_gt)
        L_tv = calc_loss_tv(I_comp, M, xp=xp)
        L_total = L_valid + self._lambda1 * L_hole + \
            self._lambda2 * L_perceptual + \
            self._lambda3 * L_style + self._lambda4 * L_tv
        self.vgg.cleargrads()
        self.model.cleargrads()
        L_total.backward()
        opt_model.update()
        chainer.report({'L_valid': L_valid})
        chainer.report({'L_hole': L_hole})
        chainer.report({'L_perceptual': L_perceptual})
        chainer.report({'L_style': L_style})
        chainer.report({'L_tv': L_tv})