-
Notifications
You must be signed in to change notification settings - Fork 19
/
test.py
92 lines (68 loc) · 2.69 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import torch
import torch.utils.data as Data
import torchvision
import torch.optim as optim
from torch.utils.data import DataLoader
from lib.network import CSNet
from torch import nn
import time
import os
import argparse
from tqdm import tqdm
from data_utils import TestDatasetFromFolder, psnr
import torchvision.transforms as transforms
from torch.autograd import Variable
from torchvision.transforms import ToPILImage
# Command-line options for evaluating a trained CSNet on a folder of test images.
# The description now matches what this script does (it tests, it does not train).
parser = argparse.ArgumentParser(description='Test CSNet reconstruction on a folder of images')
parser.add_argument('--block_size', default=32, type=int, help='CS block size')
parser.add_argument('--save_img', default=1, type=int,
                    help='save the reconstructed images when > 0')
parser.add_argument('--sub_rate', default=0.1, type=float, help='sampling sub rate')
parser.add_argument('--NetWeights', type=str,
                    default='epochs_subrate_0.1_blocksize_32/net_epoch_200_0.001724.pth',
                    help="path of CSNet weights for testing")
# The dataset folder used to be hard-coded; the default keeps the original
# location so existing invocations behave exactly as before.
parser.add_argument('--data_dir', type=str, default='/media/gdh-95/data/Set14',
                    help='folder containing the test images')
opt = parser.parse_args()

BLOCK_SIZE = opt.block_size

# One image per batch, in a fixed (unshuffled) order.
# NOTE(review): batch_size=1 presumably because test images differ in size — confirm
# against TestDatasetFromFolder before raising it.
val_set = TestDatasetFromFolder(opt.data_dir, blocksize=BLOCK_SIZE)
val_loader = DataLoader(dataset=val_set, num_workers=4, batch_size=1, shuffle=False)
# Build the network and the MSE criterion, restore the trained weights (if a
# checkpoint path was given), and move both to the GPU when one is available.
net = CSNet(BLOCK_SIZE, opt.sub_rate)
mse_loss = nn.MSELoss()

if opt.NetWeights:
    # map_location='cpu' lets a checkpoint that was saved from a GPU be restored
    # on a CPU-only machine; the model is moved to the GPU afterwards if present.
    # NOTE: torch.load unpickles the file — only load checkpoints you trust.
    state_dict = torch.load(opt.NetWeights, map_location='cpu')
    net.load_state_dict(state_dict)

if torch.cuda.is_available():
    net.cuda()
    mse_loss.cuda()
# Evaluate the network on every (input, target) pair yielded by val_loader:
# clamp each reconstruction to [0, 1], accumulate PSNR and MSE, optionally save
# the reconstruction as a PNG, and finally report the average PSNR.
epoch = 1  # single evaluation pass; only used as the progress-bar label

save_dir = 'results' + '_subrate_' + str(opt.sub_rate) + '_blocksize_' + str(BLOCK_SIZE)
os.makedirs(save_dir, exist_ok=True)

# Same placement rule as the original per-tensor .cuda() calls.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
to_pil = ToPILImage()

net.eval()
train_bar = tqdm(val_loader)
running_results = {'batch_sizes': 0, 'g_loss': 0, }
psnrs = 0.0
img_id = 0

# Inference only: disabling autograd avoids building graphs and saves memory.
with torch.no_grad():
    for data, target in train_bar:
        batch_size = data.size(0)
        if batch_size <= 0:
            continue

        running_results['batch_sizes'] += batch_size
        img_id += 1

        real_img = target.to(device)
        z = data.to(device)

        # Clip the reconstruction into [0, 1] before measuring quality.
        fake_img = torch.clamp(net(z), 0, 1)

        psnr_t = psnr(fake_img.cpu(), real_img.cpu())
        psnrs += psnr_t

        g_loss = mse_loss(fake_img, real_img)
        running_results['g_loss'] += g_loss.item() * batch_size

        train_bar.set_description(desc='[%d] Loss_G: %.4f' % (
            epoch, running_results['g_loss'] / running_results['batch_sizes']))

        if opt.save_img > 0:
            # Drop the batch dimension (batch holds a single image) before
            # converting to a PIL image.
            res = torch.squeeze(fake_img.cpu(), 0)
            file_name = 'res_' + str(img_id) + '_' + str(psnr_t) + '.png'
            to_pil(res).save(os.path.join(save_dir, file_name))

if img_id > 0:
    print("averate psnrs is: ", psnrs/img_id)
else:
    # An empty dataset previously crashed with ZeroDivisionError here.
    print("no test images were processed; average PSNR is undefined")