forked from ad50810344/compression
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest.py
143 lines (135 loc) · 5.35 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
import os
import argparse
from model import *
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
import json
import time
from datasets import Datasets, TestKodakDataset
from tensorboardX import SummaryWriter
from Meter import AverageMeter
# Enable cuDNN for faster convolution kernels.
torch.backends.cudnn.enabled = True
# gpu_num = 4
gpu_num = torch.cuda.device_count()
# Learning rate; the commented factor suggests it was once scaled by GPU count.
cur_lr = base_lr = 1e-4# * gpu_num
# Rate-distortion trade-off weight (lambda); overridable via --config JSON.
train_lambda = 8192
# Log every `print_freq` steps during training.
print_freq = 100
cal_step = 40
warmup_step = 0# // gpu_num
batch_size = 4
# Training budget; evaluation only uses these indirectly via parse_config.
tot_epoch = 1000000
tot_step = 2500000
# Decay the LR by `lr_decay` after `decay_interval` steps.
decay_interval = 2200000
lr_decay = 0.1
image_size = 256
# Shared logger; handlers are attached in the __main__ block below.
logger = logging.getLogger("ImageCompression")
tb_logger = None
global_step = 0
# Checkpoint/evaluation cadence (in steps), overridable via --config JSON.
save_model_freq = 50000
test_step = 10000
# Channel widths of the analysis/synthesis transforms; parse_config may
# shrink them to 128/192 when train_lambda < 4096.
out_channel_N = 192
out_channel_M = 320
# Command-line interface for the evaluation script.
parser = argparse.ArgumentParser(
    description='Pytorch reimplement for variational image compression with a scale hyperprior')
parser.add_argument('-n', '--name', default='', help='output training details')
parser.add_argument('-p', '--pretrain', default='', help='load pretrain model')
parser.add_argument('-t', '--test', default='', help='test dataset')
parser.add_argument('--config', dest='config', required=False,
                    help='hyperparameter in json format')
parser.add_argument('--seed', type=int, default=234,
                    help='seed for random functions, and network initialization')
parser.add_argument('--val', dest='val_path', required=True,
                    help='the path of validation dataset')
def parse_config(config):
    """Load hyperparameters from a JSON file and update the module-level settings.

    Fixes two defects in the original: it ignored its ``config`` parameter and
    re-read the global ``args.config`` instead (breaking any caller that passes
    a different path), and it leaked the file handle opened by ``open()``.

    Args:
        config: path to a JSON configuration file.
    """
    global tot_epoch, tot_step, base_lr, cur_lr, lr_decay, decay_interval, train_lambda, batch_size, print_freq, \
        out_channel_M, out_channel_N, save_model_freq, test_step
    with open(config) as f:
        cfg = json.load(f)
    if 'tot_epoch' in cfg:
        tot_epoch = cfg['tot_epoch']
    if 'tot_step' in cfg:
        tot_step = cfg['tot_step']
    if 'train_lambda' in cfg:
        train_lambda = cfg['train_lambda']
        # Smaller rate-distortion weights use the smaller network configuration
        # (may still be overridden by explicit out_channel_* keys below).
        if train_lambda < 4096:
            out_channel_N = 128
            out_channel_M = 192
        else:
            out_channel_N = 192
            out_channel_M = 320
    if 'batch_size' in cfg:
        batch_size = cfg['batch_size']
    if 'print_freq' in cfg:
        print_freq = cfg['print_freq']
    if 'test_step' in cfg:
        test_step = cfg['test_step']
    if 'save_model_freq' in cfg:
        save_model_freq = cfg['save_model_freq']
    if 'lr' in cfg:
        lr_cfg = cfg['lr']
        if 'base' in lr_cfg:
            base_lr = lr_cfg['base']
            cur_lr = base_lr
        if 'decay' in lr_cfg:
            lr_decay = lr_cfg['decay']
        if 'decay_interval' in lr_cfg:
            decay_interval = lr_cfg['decay_interval']
    if 'out_channel_N' in cfg:
        out_channel_N = cfg['out_channel_N']
    if 'out_channel_M' in cfg:
        out_channel_M = cfg['out_channel_M']
def test(step):
    """Run the model over the test set, logging per-image and average metrics.

    Relies on module globals ``net`` and ``test_loader`` being set up by the
    __main__ block. Logs bpp, PSNR, MS-SSIM and MS-SSIM (dB) for each image,
    then the dataset-wide averages.

    Args:
        step: global training step of the evaluated checkpoint (for logging).
    """
    with torch.no_grad():
        net.eval()
        total_bpp = 0
        total_psnr = 0
        total_msssim = 0
        total_msssim_db = 0
        num_images = 0
        for image in test_loader:
            recon, mse, bpp_feature, bpp_z, bpp = net(image)
            mse = torch.mean(mse)
            bpp_feature = torch.mean(bpp_feature)
            bpp_z = torch.mean(bpp_z)
            bpp = torch.mean(bpp)
            # PSNR assumes pixel values in [0, 1] (data_range=1.0 below).
            psnr = 10 * (torch.log(1. / mse) / np.log(10))
            msssim = ms_ssim(recon.cpu().detach(), image, data_range=1.0, size_average=True)
            msssim_db = -10 * (torch.log(1 - msssim) / np.log(10))
            total_bpp += bpp
            total_psnr += psnr
            total_msssim += msssim
            total_msssim_db += msssim_db
            num_images += 1
            logger.info("Num: {}, Bpp:{:.6f}, PSNR:{:.6f}, MS-SSIM:{:.6f}, MS-SSIM-DB:{:.6f}".format(num_images, bpp, psnr, msssim, msssim_db))
        logger.info("Test on Kodak dataset: model-{}".format(step))
        logger.info("Dataset Average result---Dataset Num: {}, Bpp:{:.6f}, PSNR:{:.6f}, MS-SSIM:{:.6f}, MS-SSIM-DB:{:.6f}".format(
            num_images, total_bpp / num_images, total_psnr / num_images,
            total_msssim / num_images, total_msssim_db / num_images))
if __name__ == "__main__":
    args = parser.parse_args()
    torch.manual_seed(args.seed)
    # Log format: [time][file][line][level] message.
    # (The original assigned a throwaway formatter first; that dead line is removed.)
    formatter = logging.Formatter('[%(asctime)s][%(filename)s][L%(lineno)d][%(levelname)s] %(message)s')
    stdhandler = logging.StreamHandler()
    stdhandler.setLevel(logging.INFO)
    stdhandler.setFormatter(formatter)
    logger.addHandler(stdhandler)
    logger.setLevel(logging.INFO)
    logger.info("image compression test")
    # --config is optional; the original crashed with open(None) when it was omitted.
    if args.config:
        logger.info("config : ")
        with open(args.config) as f:
            logger.info(f.read())
        parse_config(args.config)
    logger.info("out_channel_N:{}, out_channel_M:{}".format(out_channel_N, out_channel_M))
    model = ImageCompressor(out_channel_N, out_channel_M)
    if args.pretrain != '':
        logger.info("loading model:{}".format(args.pretrain))
        global_step = load_model(model, args.pretrain)
    net = model.cuda()
    net = torch.nn.DataParallel(net, list(range(gpu_num)))
    # Original bug: the 'kodak' branch built the dataset and then fell through
    # to "No test dataset"/exit unconditionally — the exit belongs in an else.
    if args.test == 'kodak':
        test_dataset = TestKodakDataset(data_dir=args.val_path)
    else:
        logger.info("No test dataset")
        exit(-1)
    test_loader = DataLoader(dataset=test_dataset, shuffle=False, batch_size=1, pin_memory=True)
    test(global_step)