forked from RaphaelMeudec/deblur-gan
-
Notifications
You must be signed in to change notification settings - Fork 1
/
model.py
121 lines (96 loc) · 3.95 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
from keras.layers import Input, Activation, Add
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import Conv2D, Conv2DTranspose
from keras.layers.core import Dense, Flatten, Lambda
from keras.layers.normalization import BatchNormalization
from keras.models import Model
from layer_utils import ReflectionPadding2D, res_block
from subpixel import SubpixelConv2D
from keras.layers.convolutional import UpSampling2D
# Hyper-parameter "chr" (channel rate) as defined in the paper.
channel_rate = 64

# Channel counts of the generator input and of its output.
input_nc = output_nc = 3

# Filter counts of the first convolution of the generator (ngf) and of the
# discriminator (ndf); deeper layers use multiples of these.
ngf = ndf = 64

# NOTE: image_shape must be a multiple of patch_shape.
image_shape = (256, 256, 3)
patch_shape = (channel_rate, channel_rate, 3)

input_shape_generator = (256, 256, input_nc)
input_shape_discriminator = (256, 256, output_nc)

# Number of residual blocks stacked in the middle of the generator.
n_blocks_gen = 9
def generator_model():
    """Build the ResNet-block generator.

    Layout: a reflection-padded 7x7 stem convolution, two stride-2
    downsampling convolutions, ``n_blocks_gen`` residual blocks, two
    upsampling stages (``UpSampling2D`` followed by a 3x3 convolution) and a
    9x9 output convolution with ``tanh``. The ``tanh`` output is added to the
    network input and the sum is divided by two.

    Returns:
        A Keras ``Model`` named ``'Generator'`` taking inputs of shape
        ``image_shape``.
    """
    n_downsampling = 2
    image = Input(shape=image_shape)

    # Stem: the reflection padding of 3 is followed by a 7x7 'valid' conv.
    net = ReflectionPadding2D((3, 3))(image)
    net = Conv2D(filters=ngf, kernel_size=(7, 7), padding='valid')(net)
    net = BatchNormalization()(net)
    net = Activation('relu')(net)

    # Downsampling: stride-2 convolutions with ngf*2, then ngf*4 filters.
    for level in range(n_downsampling):
        down_filters = ngf * 2**level * 2
        net = Conv2D(filters=down_filters, kernel_size=(3, 3),
                     strides=2, padding='same')(net)
        net = BatchNormalization()(net)
        net = Activation('relu')(net)

    # Residual trunk at the widest filter count.
    trunk_filters = ngf * 2**n_downsampling
    for _ in range(n_blocks_gen):
        net = res_block(net, trunk_filters, use_dropout=True)

    # Upsampling: UpSampling2D + 3x3 convolution is used for each stage; a
    # single stride-2 Conv2DTranspose was the disabled alternative here.
    for exponent in range(n_downsampling, 0, -1):
        up_filters = int(ngf * 2**exponent / 2)
        net = UpSampling2D()(net)
        net = Conv2D(filters=up_filters, kernel_size=(3, 3),
                     padding='same')(net)
        net = BatchNormalization()(net)
        net = Activation('relu')(net)

    # Output head: 9x9 'same' convolution (no reflection padding) + tanh.
    net = Conv2D(filters=output_nc, kernel_size=(9, 9), padding='same')(net)
    net = Activation('tanh')(net)

    # Residual connection to the input, then halve the sum.
    # NOTE(review): presumably this keeps the result within [-1, 1] when the
    # input is scaled to [-1, 1] -- confirm against the data pipeline.
    summed = Add()([net, image])
    halved = Lambda(lambda z: z/2)(summed)

    return Model(inputs=image, outputs=halved, name='Generator')
def discriminator_model(n_layers=3, use_sigmoid=False):
    """Build the convolutional discriminator.

    Layout: a stride-2 4x4 convolution, ``n_layers`` stride-2 4x4
    convolution blocks (convolution, batch normalisation, LeakyReLU), one
    stride-1 block, and a single-filter 4x4 convolution. The resulting map is
    flattened and passed through ``Dense(1024, tanh)`` and
    ``Dense(1, sigmoid)``.

    Args:
        n_layers: Number of stride-2 convolution blocks after the first
            convolution. Defaults to 3, the value that used to be hard-coded.
        use_sigmoid: If true, apply a sigmoid to the single-filter map before
            it is flattened. Defaults to False, the value that used to be
            hard-coded.

    Returns:
        A Keras ``Model`` named ``'Discriminator'`` taking inputs of shape
        ``input_shape_discriminator``.
    """
    inputs = Input(shape=input_shape_discriminator)

    x = Conv2D(filters=ndf, kernel_size=(4, 4), strides=2, padding='same')(inputs)
    x = LeakyReLU(0.2)(x)

    # The filter multiplier doubles with depth (1, 2, 4, ...) and is capped at 8.
    for n in range(n_layers):
        nf_mult = min(2**n, 8)
        x = Conv2D(filters=ndf*nf_mult, kernel_size=(4, 4), strides=2, padding='same')(x)
        x = BatchNormalization()(x)
        x = LeakyReLU(0.2)(x)

    # One more block at stride 1 with the next (capped) multiplier.
    nf_mult = min(2**n_layers, 8)
    x = Conv2D(filters=ndf*nf_mult, kernel_size=(4, 4), strides=1, padding='same')(x)
    x = BatchNormalization()(x)
    x = LeakyReLU(0.2)(x)

    # Collapse to a single-filter map.
    x = Conv2D(filters=1, kernel_size=(4, 4), strides=1, padding='same')(x)
    if use_sigmoid:
        x = Activation('sigmoid')(x)

    # Dense head reducing the flattened map to one sigmoid output.
    x = Flatten()(x)
    x = Dense(1024, activation='tanh')(x)
    x = Dense(1, activation='sigmoid')(x)

    model = Model(inputs=inputs, outputs=x, name='Discriminator')
    return model
def generator_containing_discriminator(generator, discriminator):
    """Chain ``generator`` into ``discriminator`` as one model.

    Args:
        generator: Callable model applied to an input of shape ``image_shape``.
        discriminator: Callable model applied to the generator's output.

    Returns:
        A Keras ``Model`` mapping the input to the discriminator's output for
        the generated image.
    """
    source = Input(shape=image_shape)
    verdict = discriminator(generator(source))
    return Model(inputs=source, outputs=verdict)
def generator_containing_discriminator_multiple_outputs(generator, discriminator):
    """Chain ``generator`` into ``discriminator``, exposing both results.

    Args:
        generator: Callable model applied to an input of shape ``image_shape``.
        discriminator: Callable model applied to the generator's output.

    Returns:
        A Keras ``Model`` with two outputs, in this order: the generated
        image and the discriminator's output for that image.
    """
    source = Input(shape=image_shape)
    fake_image = generator(source)
    verdict = discriminator(fake_image)
    return Model(inputs=source, outputs=[fake_image, verdict])
if __name__ == '__main__':
    # Print the layer summary of each network when run as a script.
    generator_model().summary()
    discriminator_model().summary()
    # The combined model is built from fresh generator/discriminator instances.
    generator_containing_discriminator(
        generator_model(), discriminator_model()
    ).summary()