forked from XueweiMeng/derain_filter
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmodel_dehaze.py
executable file
·37 lines (29 loc) · 1.26 KB
/
model_dehaze.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import functools
import tensorflow as tf
# Convolution factory used by every layer in this module: each call is
# forwarded to tf.layers.conv2d with 'same' padding and a He-normal kernel
# initializer. The initializer object is constructed once, right here, and the
# same instance is passed to every layer built through conv2d().
conv2d = functools.partial(
    tf.layers.conv2d,
    kernel_initializer=tf.keras.initializers.he_normal(),
    padding='same',
)
class Model(object):
    """Small fully-convolutional network built from TF1 ``tf.layers`` ops.

    :meth:`forward` builds, under the ``can`` variable scope and in order:

    * ``enc``  -- 3x3 conv, ``channel`` filters, leaky ReLU;
    * ``depth - 3`` middle layers named ``conv0``, ``conv1``, ... -- 3x3
      conv, ``channel`` filters, leaky ReLU, dilation rate 1;
    * ``dec1`` -- 3x3 conv, ``channel`` filters, leaky ReLU;
    * ``dec2`` -- 1x1 conv, 3 filters, no activation.

    ``depth`` is therefore the total number of conv layers. The default
    ``depth=7`` yields four middle layers (``conv0`` .. ``conv3``). Every
    layer is created through the module-level ``conv2d`` helper.
    """

    def __init__(self, channel=24, depth=7):
        """Store the network hyper-parameters.

        Args:
            channel: number of filters in every conv layer except ``dec2``.
            depth: total number of conv layers (``enc`` + middle layers +
                ``dec1`` + ``dec2``); must be at least 3.

        Raises:
            ValueError: if ``depth`` is smaller than 3.
        """
        super(Model, self).__init__()
        if depth < 3:
            # enc, dec1 and dec2 are always built, so fewer than 3 layers
            # cannot be honoured.
            raise ValueError(
                'depth must be >= 3 (enc, dec1 and dec2 are always built), '
                'got %r' % (depth,))
        self.channel = channel
        self.depth = depth

    def _middle_dilations(self):
        """Return the dilation rate of each middle conv layer.

        There are ``depth - 3`` middle layers, all with dilation rate 1, so
        the default ``depth=7`` gives ``[1, 1, 1, 1]``.
        """
        return [1] * (self.depth - 3)

    def forward(self, O):
        """Build the network on ``O`` and return the output tensor.

        Args:
            O: input image tensor. Presumably a batch in NHWC layout (the
                ``tf.layers.conv2d`` default ``channels_last``) -- TODO confirm.

        Returns:
            The 3-channel output of the final linear 1x1 conv (``dec2``).
        """
        with tf.variable_scope('can'):
            x = conv2d(O, self.channel, 3, activation=tf.nn.leaky_relu,
                       name='enc')
            # Layer names conv0, conv1, ... are kept stable so existing
            # checkpoints of the default depth still match.
            for i, dilation in enumerate(self._middle_dilations()):
                x = conv2d(x, self.channel, 3, activation=tf.nn.leaky_relu,
                           dilation_rate=(dilation, dilation),
                           name='conv' + str(i))
            x = conv2d(x, self.channel, 3, activation=tf.nn.leaky_relu,
                       name='dec1')
            O_R = conv2d(x, 3, 1, activation=None, name='dec2')
        return O_R

    def get_metrics(self, B, P, R, O_R):
        """Return a dict of scalar loss/quality tensors.

        Args:
            B, P: image tensors compared by PSNR and SSIM. Presumably ``B`` is
                the ground truth and ``P`` the prediction -- TODO confirm.
            R, O_R: tensors compared by the MSE loss, with ``R`` as labels and
                ``O_R`` as predictions. Presumably ``O_R`` is the output of
                :meth:`forward` -- TODO confirm against callers.

        Returns:
            dict with keys ``'loss'`` (``tf.losses.mean_squared_error`` of
            ``O_R`` against ``R``), ``'psnr'`` and ``'ssim'`` (means of
            ``tf.image.psnr`` / ``tf.image.ssim`` between ``B`` and ``P``).

        PSNR and SSIM use ``max_val=1.0``, i.e. they assume pixel values in
        [0, 1].
        """
        metrics = {
            'loss': tf.losses.mean_squared_error(R, O_R),
            'psnr': tf.reduce_mean(tf.image.psnr(B, P, max_val=1.0)),
            'ssim': tf.reduce_mean(tf.image.ssim(B, P, max_val=1.0)),
        }
        return metrics