forked from SYSU-SAIL/PASSRnet
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
92 lines (81 loc) · 4.2 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
from PIL import Image
import os
from torch.utils.data.dataset import Dataset
from torchvision.transforms import ToTensor
import random
import torch
import numpy as np
from skimage import measure
from torch.nn import init
class TrainSetLoader(Dataset):
def __init__(self, dataset_dir, cfg):
super(TrainSetLoader, self).__init__()
self.dataset_dir = dataset_dir + '/patches_x' + str(cfg.scale_factor)
self.file_list = os.listdir(self.dataset_dir)
def __getitem__(self, index):
img_hr_left = Image.open(self.dataset_dir + '/' + self.file_list[index] + '/hr0.png')
img_hr_right = Image.open(self.dataset_dir + '/' + self.file_list[index] + '/hr1.png')
img_lr_left = Image.open(self.dataset_dir + '/' + self.file_list[index] + '/lr0.png')
img_lr_right = Image.open(self.dataset_dir + '/' + self.file_list[index] + '/lr1.png')
img_hr_left = np.array(img_hr_left, dtype=np.float32)
img_hr_right = np.array(img_hr_right, dtype=np.float32)
img_lr_left = np.array(img_lr_left, dtype=np.float32)
img_lr_right = np.array(img_lr_right, dtype=np.float32)
img_hr_left, img_hr_right, img_lr_left, img_lr_right = augumentation(img_hr_left, img_hr_right, img_lr_left, img_lr_right)
return toTensor(img_hr_left), toTensor(img_hr_right), toTensor(img_lr_left), toTensor(img_lr_right)
def __len__(self):
return len(self.file_list)
class TestSetLoader(Dataset):
def __init__(self, dataset_dir, scale_factor):
super(TestSetLoader, self).__init__()
self.dataset_dir = dataset_dir
self.scale_factor = scale_factor
self.file_list = os.listdir(dataset_dir + '/hr')
def __getitem__(self, index):
hr_image_left = Image.open(self.dataset_dir + '/hr/' + self.file_list[index] + '/hr0.png')
hr_image_right = Image.open(self.dataset_dir + '/hr/' + self.file_list[index] + '/hr1.png')
lr_image_left = Image.open(self.dataset_dir + '/lr_x' + str(self.scale_factor) + '/' + self.file_list[index] + '/lr0.png')
lr_image_right = Image.open(self.dataset_dir + '/lr_x' + str(self.scale_factor) + '/' + self.file_list[index] + '/lr1.png')
hr_image_left = ToTensor()(hr_image_left)
hr_image_right = ToTensor()(hr_image_right)
lr_image_left = ToTensor()(lr_image_left)
lr_image_right = ToTensor()(lr_image_right)
return hr_image_left, hr_image_right, lr_image_left, lr_image_right
def __len__(self):
return len(self.file_list)
def augumentation(hr_image_left, hr_image_right, lr_image_left, lr_image_right):
if random.random()<0.5: # flip horizonly
lr_image_left = lr_image_left[:, ::-1, :]
lr_image_right = lr_image_right[:, ::-1, :]
hr_image_left = hr_image_left[:, ::-1, :]
hr_image_right = hr_image_right[:, ::-1, :]
if random.random()<0.5: #flip vertically
lr_image_left = lr_image_left[::-1, :, :]
lr_image_right = lr_image_right[::-1, :, :]
hr_image_left = hr_image_left[::-1, :, :]
hr_image_right = hr_image_right[::-1, :, :]
""""no rotation
if random.random()<0.5:
lr_image_left = lr_image_left.transpose(1, 0, 2)
lr_image_right = lr_image_right.transpose(1, 0, 2)
hr_image_left = hr_image_left.transpose(1, 0, 2)
hr_image_right = hr_image_right.transpose(1, 0, 2)
"""
return np.ascontiguousarray(hr_image_left), np.ascontiguousarray(hr_image_right), \
np.ascontiguousarray(lr_image_left), np.ascontiguousarray(lr_image_right)
def toTensor(img):
img = torch.from_numpy(img.transpose((2, 0, 1)))
return img.float().div(255)
class L1Loss(object):
def __call__(self, input, target):
return torch.abs(input - target).mean()
def cal_psnr(img1, img2):
img1_np = np.array(img1)
img2_np = np.array(img2)
return measure.compare_psnr(img1_np, img2_np)
def save_ckpt(state, save_path='./log', filename='checkpoint.pth.tar'):
torch.save(state, os.path.join(save_path,filename))
def weights_init_xavier(m):
classname = m.__class__.__name__
if classname.find('Conv2d') != -1:
init.xavier_normal(m.weight.data)