-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathsuper_resolve.py
84 lines (68 loc) · 3.54 KB
/
super_resolve.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import argparse
import torch
import torch.cuda
from net import StereoSRChrominance, StereoSRLuminance
# Patch geometry, hard-coded in the original script; presumably it must match the
# geometry the checkpoints were trained with.
PATCH_LR = 33   # side length, in low-resolution pixels, of one square patch (was `L`)
MAX_SHIFT = 64  # number of horizontally shifted right-view channels per patch (was `SH`)


def _parse_args():
    """Parse and validate the command-line options."""
    parser = argparse.ArgumentParser(description='Make Super Resolution Image')
    # Both views are mandatory: without them the script crashed inside Image.open(None).
    parser.add_argument('--left_image', type=str, required=True, help='left_image')
    parser.add_argument('--right_image', type=str, required=True, help='right_image')
    parser.add_argument('--scale_factor', type=int, default=2, help='scale factor')
    parser.add_argument('--lum_model', type=str, default='result/model_lum19.pth',
                        help='luminance network checkpoint (whole pickled module)')
    parser.add_argument('--chr_model', type=str, default='result/model_chr19.pth',
                        help='chrominance network checkpoint (whole pickled module)')
    parser.add_argument('--batch_size', type=int, default=0,
                        help='patches per forward pass (0 = all patches at once)')
    opt = parser.parse_args()
    if opt.scale_factor < 1:
        parser.error('--scale_factor must be a positive integer')
    return opt


def _load_ycbcr(path):
    """Read an image file and return it as an (H, W, 3) uint8 array of YCbCr values."""
    # The context manager closes the underlying file once convert() has loaded the pixels.
    with Image.open(path) as image:
        return np.array(image.convert('YCbCr'), dtype='uint8')


def _grid_shape(height, width, patch_hr, margin_hr):
    """Return (rows, cols): how many whole HR patches fit in the image.

    The leftmost ``margin_hr`` columns are skipped because every patch needs that
    much right-view context to its left for the disparity shifts. Images too small
    to hold a patch yield 0 (never a negative count).
    """
    rows = height // patch_hr
    cols = max(width - margin_hr, 0) // patch_hr
    return rows, cols


def _build_patches(img_l, img_r, rows, cols, scale):
    """Build the network input for every patch, in row-major order.

    Returns a uint8 array of shape (rows * cols, MAX_SHIFT + 3, P, P) with
    P = PATCH_LR * scale. Channel layout of each patch:
      * 0 .. MAX_SHIFT-1: channel 0 of the bicubic down/up-sampled right view,
        shifted left by 0, 1, ..., MAX_SHIFT-1 low-resolution pixels;
      * MAX_SHIFT .. MAX_SHIFT+2: the Y, Cb, Cr channels of the bicubic
        down/up-sampled left view.
    """
    patch_hr = PATCH_LR * scale
    margin_hr = MAX_SHIFT * scale
    window_lr = PATCH_LR + MAX_SHIFT - 1  # LR width of the right-view strip behind a patch
    window_hr = window_lr * scale
    patches = np.zeros((rows * cols, MAX_SHIFT + 3, patch_hr, patch_hr), dtype=np.uint8)

    for idx in range(rows * cols):
        y, x = divmod(idx, cols)
        top = y * patch_hr
        left = x * patch_hr + margin_hr
        left_hr = img_l[top:top + patch_hr, left:left + patch_hr, :]
        right_hr = img_r[top:top + patch_hr, left - (MAX_SHIFT - 1) * scale:left + patch_hr, :]

        # Simulate the low-resolution input: bicubic downsample, then bicubic upsample.
        left_lr = Image.fromarray(left_hr).resize((PATCH_LR, PATCH_LR), resample=Image.BICUBIC)
        left_up = np.array(left_lr.resize((patch_hr, patch_hr), resample=Image.BICUBIC))

        right_lr = Image.fromarray(right_hr).resize((window_lr, PATCH_LR), resample=Image.BICUBIC)
        # The original relied on resize()'s default filter here. BICUBIC is that default for
        # RGB-mode images since Pillow 7.0; naming it keeps results independent of the version.
        right_up = right_lr.resize((window_hr, patch_hr), resample=Image.BICUBIC)
        # NOTE(review): img_r already holds YCbCr values, but fromarray() tags them as 'RGB',
        # so convert('YCbCr') applies the colour transform a second time and channel 0 is not
        # the true luma. Kept because the checkpoints were presumably trained on the same
        # preprocessing; confirm against the training script before changing it. The convert
        # is per-pixel, so doing it once for the whole strip equals doing it per shifted slice.
        right_luma = np.array(right_up.convert('YCbCr'))[:, :, 0]

        patches[idx, MAX_SHIFT:] = left_up.transpose(2, 0, 1)
        for shift in range(MAX_SHIFT):
            # Window ending `shift` LR pixels before the right edge of the strip. Positive
            # indices avoid the empty `[-n:-0]` slice the original special-cased.
            stop = window_hr - shift * scale
            patches[idx, shift] = right_luma[:, stop - patch_hr:stop]
    return patches


def _load_model(path, device):
    """Load a whole pickled ``nn.Module`` from ``path`` onto ``device`` in eval mode.

    ``map_location`` lets checkpoints saved on a GPU load on CPU-only machines.
    The checkpoints store complete modules rather than state dicts. PyTorch 2.6
    changed the ``torch.load`` default to ``weights_only=True``, so
    ``weights_only=False`` is required. That unpickles arbitrary code, so only load
    trusted files. Unpickling presumably also needs the project's ``net`` module
    to be importable.
    """
    try:
        model = torch.load(path, map_location=device, weights_only=False)
    except TypeError:
        # PyTorch < 1.13 has no ``weights_only`` argument.
        model = torch.load(path, map_location=device)
    # eval(): disable training-only behaviour (dropout, batch-norm statistics updates).
    return model.to(device).eval()


def _run_models(lum_net, chr_net, inputs, batch_size, device):
    """Run the luminance and then the chrominance network over ``inputs``.

    Channels 0..MAX_SHIFT (right-view shifts plus left Y) go to ``lum_net``. The
    remaining left Cb/Cr channels, together with ``lum_net``'s result, go to
    ``chr_net``. A ``batch_size`` <= 0 processes every patch in a single forward
    pass, as the original script did. Returns the concatenated output on the CPU.
    """
    step = batch_size if batch_size > 0 else inputs.shape[0]
    chunks = []
    with torch.no_grad():  # inference only: do not build an autograd graph
        for start in range(0, inputs.shape[0], step):
            chunk = inputs[start:start + step].to(device)
            lum = lum_net(chunk[:, :MAX_SHIFT + 1])
            chunks.append(chr_net(chunk[:, MAX_SHIFT + 1:], lum).cpu())
    return torch.cat(chunks)


def _to_uint8(tensor):
    """Scale a tensor expected in [0, 1] to 0..255, round to nearest, return uint8 numpy."""
    # Rounding, not truncation: casting e.g. 254.9999 straight to uint8 would give 254.
    return np.clip(np.rint(tensor.detach().cpu().numpy() * 255), 0, 255).astype(np.uint8)


def _stitch(tiles, rows, cols):
    """Reassemble (rows*cols, C, h, w) tiles in row-major order into a (rows*h, cols*w, C) image."""
    _, channels, height, width = tiles.shape
    grid = tiles.reshape(rows, cols, channels, height, width)
    # (row, col, C, h, w) -> (row, h, col, w, C), then merge row/h and col/w.
    return grid.transpose(0, 3, 1, 4, 2).reshape(rows * height, cols * width, channels)


def _ycbcr_to_rgb(arr):
    """Wrap an (H, W, 3) uint8 YCbCr array as an RGB PIL image for display.

    frombytes() is used because the ``mode`` argument of fromarray() is deprecated
    in recent Pillow. The explicit RGB conversion gives matplotlib true colours.
    """
    height, width = arr.shape[:2]
    image = Image.frombytes('YCbCr', (width, height), np.ascontiguousarray(arr).tobytes())
    return image.convert('RGB')


def main():
    """Super-resolve the left view of a stereo pair and display input and output.

    The left image is cut into square patches of PATCH_LR * scale pixels; the
    leftmost MAX_SHIFT * scale columns are skipped so every patch has enough
    right-view context. Each patch is bicubic down/up-sampled and stacked with
    MAX_SHIFT shifted right-view channels. The stack is passed through the
    luminance and chrominance networks. The bicubic input and then the network
    output are shown with matplotlib.

    Exits with an error message when the images are unusable, i.e. the right view
    is smaller than the left or the image is too small to hold one patch.
    """
    opt = _parse_args()
    scale = opt.scale_factor
    patch_hr = PATCH_LR * scale
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    img_l = _load_ycbcr(opt.left_image)
    img_r = _load_ycbcr(opt.right_image)
    if img_r.shape[0] < img_l.shape[0] or img_r.shape[1] < img_l.shape[1]:
        raise SystemExit('right image {} is smaller than left image {} (height, width)'.format(
            img_r.shape[:2], img_l.shape[:2]))

    rows, cols = _grid_shape(img_l.shape[0], img_l.shape[1], patch_hr, MAX_SHIFT * scale)
    if rows == 0 or cols == 0:
        raise SystemExit('image too small: need at least {} x {} pixels (height x width) '
                         'at scale factor {}'.format(patch_hr, MAX_SHIFT * scale + patch_hr, scale))

    patches = _build_patches(img_l, img_r, rows, cols, scale)
    # uint8 -> float32 in [0, 1], the range the networks were fed in the original script.
    inputs = torch.from_numpy(patches).float().div_(255)

    lum_net = _load_model(opt.lum_model, device)
    chr_net = _load_model(opt.chr_model, device)
    output = _run_models(lum_net, chr_net, inputs, opt.batch_size, device)

    # The bicubic input is taken straight from the uint8 patches. The old float round trip
    # (x / 255 * 255, then a truncating uint8 cast) could drop some pixels by one level.
    input_img = _ycbcr_to_rgb(_stitch(patches[:, MAX_SHIFT:], rows, cols))
    super_img = _ycbcr_to_rgb(_stitch(_to_uint8(output), rows, cols))
    plt.imshow(input_img)
    plt.show()
    plt.imshow(super_img)
    plt.show()


if __name__ == '__main__':
    main()