-
Notifications
You must be signed in to change notification settings - Fork 18
/
Copy pathHLVC_layer2_B-frame_decoder.py
108 lines (81 loc) · 4.02 KB
/
HLVC_layer2_B-frame_decoder.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import argparse
import numpy as np
import tensorflow as tf
import tensorflow_compression as tfc
from scipy import misc
import CNN_img
import motion
import MC_network
import os
# Set TensorFlow's C++ log-level variable to '3' to quiet native log output.
# NOTE(review): this runs after `import tensorflow`; confirm it still takes
# effect for the messages you care about.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

# Session configuration with soft device placement enabled, then the single
# session that the rest of the script uses to restore and run the graph.
config = tf.ConfigProto()
config.allow_soft_placement = True
sess = tf.Session(config=config)
# Command-line interface of the B-frame decoder. Defaults are shown in --help
# through ArgumentDefaultsHelpFormatter.
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)

# Path options, registered in their original order: the two reference frames,
# the decoded output frame, and the bitstream file.
# (A "--raw" option with default 'raw.png' is left disabled.)
for _flag, _default in (
        ("--ref_1", 'ref_1.png'),
        ("--ref_2", 'ref_2.png'),
        ("--com", 'dec.png'),
        ("--bin", 'bits_B.bin'),
):
    parser.add_argument(_flag, default=_default)

# --mode and --l together select which checkpoint directory is restored.
parser.add_argument(
    "--mode",
    default='PSNR',
    choices=['PSNR', 'MS-SSIM'],
)
parser.add_argument(
    "--l",
    type=int,
    default=4096,
    choices=[32, 64, 128, 256, 1024, 2048, 4096, 8192],
)

# --N is passed to the synthesis networks as their filter count; --M is the
# channel count of the entropy-coded latents. Only 128 is accepted for each.
parser.add_argument("--N", type=int, default=128, choices=[128])
parser.add_argument("--M", type=int, default=128, choices=[128])

args = parser.parse_args()
# Fixed tensor layout used by the placeholders: one frame per batch, three channels.
batch_size, Channel = 1, 3

# Load the two reference frames and prepend a batch axis to each.
Y0_com_img = misc.imread(args.ref_1)[np.newaxis, ...]
# Y1_raw_img = misc.imread(args.raw)[np.newaxis, ...]
Y2_com_img = misc.imread(args.ref_2)[np.newaxis, ...]

# The frame size is taken from the first reference (axis 1 = rows, axis 2 = columns
# of the batched array).
Height = Y0_com_img.shape[1]
Width = Y0_com_img.shape[2]

# Graph inputs: the two reference frames as float tensors ...
frame_shape = [batch_size, Height, Width, Channel]
Y0_com = tf.placeholder(tf.float32, shape=frame_shape)
# Y1_raw = tf.placeholder(tf.float32, shape=frame_shape)
Y2_com = tf.placeholder(tf.float32, shape=frame_shape)

# ... and the two scalar byte strings (motion and residual) read from the bitstream.
string_mv_tensor = tf.placeholder(tf.string, shape=[])
string_res_tensor = tf.placeholder(tf.string, shape=[])
# Motion Decoding
# Entropy-decode the motion string into a latent with spatial size
# (Height//16, Width//16) and args.M channels.
# NOTE(review): the statements in this section create the graph's layers and
# variables; their creation order presumably determines the auto-generated
# variable names that saver.restore() matches against the checkpoint --
# confirm before reordering anything here.
entropy_bottleneck_mv = tfc.EntropyBottleneck(dtype=tf.float32, name='entropy_bottleneck')
flow_latent_hat = entropy_bottleneck_mv.decompress(
tf.expand_dims(string_mv_tensor, 0), [Height//16, Width//16, args.M], channels=args.M)
# Residual Decoding
# Same procedure for the residual string, using a second entropy bottleneck
# with its own name.
entropy_bottleneck_res = tfc.EntropyBottleneck(dtype=tf.float32, name='entropy_bottleneck_1_1')
res_latent_hat = entropy_bottleneck_res.decompress(
tf.expand_dims(string_res_tensor, 0), [Height//16, Width//16, args.M], channels=args.M)
# Synthesize a 4-channel motion tensor from the motion latent and split it
# along the last axis into two 2-channel flows, one per reference frame.
flow_hat = CNN_img.MV_synthesis(flow_latent_hat, args.N, out_filters=4)
[flow_hat_0, flow_hat_2] = tf.split(flow_hat, [2, 2], axis=-1)
# Motion Compensation
# Warp each reference frame with its own flow and average the two warped frames.
Y1_warp_hat_0 = tf.contrib.image.dense_image_warp(Y0_com, flow_hat_0)
Y1_warp_hat_2 = tf.contrib.image.dense_image_warp(Y2_com, flow_hat_2)
Y1_warp_hat = (Y1_warp_hat_0 + Y1_warp_hat_2)/2.0
# The motion-compensation network receives the flows, both references and the
# averaged warped frame, concatenated along the channel (last) axis.
MC_input = tf.concat([flow_hat, Y0_com, Y2_com, Y1_warp_hat], axis=-1)
Y1_MC = MC_network.MC(MC_input)
# Synthesize the residual from its latent.
Res_hat = CNN_img.Res_synthesis(res_latent_hat, num_filters=args.N)
# Reconstructed frame
# Residual plus motion-compensated prediction, clipped to the [0, 1] range.
Y1_com = tf.clip_by_value(Res_hat + Y1_MC, 0, 1)
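# Disabled quality evaluation: it needs the raw frame Y1_raw, whose
# placeholder is commented out above. The selector is the --mode option.
# if args.mode == 'PSNR':
#     train_mse = tf.reduce_mean(tf.squared_difference(Y1_com, Y1_raw))
#     quality = 10.0*tf.log(1.0/train_mse)/tf.log(10.0)
# elif args.mode == 'MS-SSIM':
#     quality = tf.math.reduce_mean(tf.image.ssim_multiscale(Y1_com, Y1_raw, max_val=1))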
# Restore the checkpoint selected by --mode and --l into the session.
saver = tf.train.Saver(max_to_keep=None)
model_path = ''.join([
    './HLVC_model/Layer2_B-frame/',
    'Layer2_B_',
    args.mode,
    '_',
    str(args.l),
    '_model/model.ckpt',
])
saver.restore(sess, model_path)
# Parse the B-frame bitstream file. Layout, as consumed below:
#   4 bytes  float32  quality value stored with the stream
#   2 bytes  uint16   length in bytes of the motion string
#   N bytes           motion string
#   rest              residual string (everything up to end of file)
with open(args.bin, "rb") as ff:
    header = ff.read(6)
    if len(header) != 6:
        # Fail with a clear message instead of an obscure NumPy error.
        raise ValueError(
            'Bitstream {!r} is truncated: expected a 6-byte header, got {} byte(s)'
            .format(args.bin, len(header)))
    quality_com = np.frombuffer(header[:4], dtype=np.float32)
    mv_len = np.frombuffer(header[4:], dtype=np.uint16)

    # Use the builtin int on the single element: the np.int alias was removed
    # in NumPy 1.24, and converting a 1-element array to a scalar is deprecated.
    mv_bytes = int(mv_len[0])
    string_mv = ff.read(mv_bytes)
    if len(string_mv) != mv_bytes:
        raise ValueError(
            'Bitstream {!r} is truncated: motion string should be {} byte(s), got {}'
            .format(args.bin, mv_bytes, len(string_mv)))

    string_res = ff.read()
# Run the decoder graph once. Reference frames are scaled from 8-bit values to
# [0, 1]; the two byte strings come from the bitstream file.
feed = {
    Y0_com: Y0_com_img / 255.0,
    # Y1_raw: Y1_raw_img / 255.0,
    Y2_com: Y2_com_img / 255.0,
    string_mv_tensor: string_mv,
    string_res_tensor: string_res,
}
compressed_frame = sess.run(Y1_com, feed_dict=feed)

# Bits per pixel of the bitstream file; only the disabled report below uses it.
bpp = os.path.getsize(args.bin) * 8 / Height / Width

# Scale the single frame in the batch back to 0..255, round, and save it as 8-bit.
scaled_frame = compressed_frame[0] * 255.0
misc.imsave(args.com, np.uint8(np.round(scaled_frame)))
# print('Decoded', args.mode + ' (before WRQE) = ' + str(quality_com[0]), 'bpp = ' + str(bpp))