forked from jacobgil/pytorch-zssr
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdata.py
71 lines (60 loc) · 2.38 KB
/
data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import PIL
import numpy as np
import sys
import random
import torch
from torchvision import transforms
from torchvision.transforms import functional as F
import numbers
import cv2
from source_target_transforms import *
class DataSampler:
    """Draw random (HR, LR) training crops generated from a single image.

    The image is rescaled to a range of smaller sizes (the HR images). Each HR
    image is bicubically downscaled by ``sr_factor`` and upscaled back to the
    HR size, giving an LR image of the same size as its HR partner.
    """

    def __init__(self, img, sr_factor, crop_size):
        """
        Args:
            img: source image; must expose PIL's ``size`` (width, height) and
                ``resize`` API (a ``PIL.Image.Image`` in practice).
            sr_factor: scale between each HR image and its LR counterpart.
            crop_size: forwarded to ``RandomCrop`` for every sampled pair.

        Raises:
            ValueError: if no rescaled size has both sides divisible by
                ``sr_factor`` (e.g. the image is too small), so there would be
                nothing to sample from.
        """
        self.img = img
        self.sr_factor = sr_factor
        self.pairs = self.create_hr_lr_pairs()
        if not self.pairs:
            # Without this check np.sum(sizes) is 0 (NaN probabilities) and
            # generate_data fails later with an obscure IndexError.
            raise ValueError(
                'no HR/LR pairs could be built from an image of size %s with '
                'sr_factor=%s' % (img.size, sr_factor))
        # Each pair is sampled with probability proportional to its HR area.
        img_area = float(img.size[0] * img.size[1])
        sizes = np.float32([hr.size[0] * hr.size[1] / img_area
                            for hr, _ in self.pairs])
        self.pair_probabilities = sizes / np.sum(sizes)
        # Random 90-degree rotations, flips and a crop, applied to the
        # (hr, lr) tuple as a whole. NOTE(review): presumably the same random
        # parameters are used for both images -- confirm in
        # source_target_transforms, which is not visible here.
        self.transform = transforms.Compose([
            RandomRotationFromSequence([0, 90, 180, 270]),
            RandomHorizontalFlip(),
            RandomVerticalFlip(),
            RandomCrop(crop_size),
            ToTensor()])

    @staticmethod
    def _hr_sizes(width, height, sr_factor):
        """Return the (width, height) HR sizes whose sides are both divisible by ``sr_factor``.

        The smaller side sweeps every integer from ``max(1, smaller // 5)`` up
        to its full length; the larger side is scaled by the same ratio and
        rounded. The returned sizes are exactly the ones that passed the
        divisibility check, so no further rounding is needed when resizing.
        """
        smaller_side = min(width, height)
        larger_side = max(width, height)
        width_is_smaller = width <= height
        sizes = []
        # Start at 1 at the latest: a zero-sized side cannot be resized to.
        for small in range(max(1, smaller_side // 5), smaller_side + 1):
            large = round(larger_side * (float(small) / smaller_side))
            if small % sr_factor == 0 and large % sr_factor == 0:
                sizes.append((small, large) if width_is_smaller else (large, small))
        return sizes

    def create_hr_lr_pairs(self):
        """Build the list of ``(hr, lr)`` image pairs to sample from.

        Each ``hr`` is ``self.img`` bicubically resized to one of the sizes from
        ``_hr_sizes``; ``lr`` is ``hr`` downscaled by ``sr_factor`` and then
        upscaled back to ``hr.size`` (both bicubic).
        """
        width, height = self.img.size[0], self.img.size[1]
        pairs = []
        for hr_size in self._hr_sizes(width, height, self.sr_factor):
            # Resize to the exact size that was validated; recomputing it from
            # a float zoom with int() can truncate (e.g. 100 * 0.58 -> 57).
            hr = self.img.resize(hr_size, resample=PIL.Image.BICUBIC)
            lr_size = (round(hr_size[0] / self.sr_factor),
                       round(hr_size[1] / self.sr_factor))
            lr = hr.resize(lr_size, resample=PIL.Image.BICUBIC)
            lr = lr.resize(hr.size, resample=PIL.Image.BICUBIC)
            pairs.append((hr, lr))
        return pairs

    def generate_data(self):
        """Yield ``(hr_tensor, lr_tensor)`` forever, each with a leading dim of size 1.

        A pair is chosen at random according to ``self.pair_probabilities``,
        passed through ``self.transform``, and a batch dimension is prepended
        with ``torch.unsqueeze(..., 0)``.
        """
        while True:
            hr, lr = random.choices(self.pairs, weights=self.pair_probabilities, k=1)[0]
            hr_tensor, lr_tensor = self.transform((hr, lr))
            hr_tensor = torch.unsqueeze(hr_tensor, 0)
            lr_tensor = torch.unsqueeze(lr_tensor, 0)
            yield hr_tensor, lr_tensor
if __name__ == '__main__':
    # Smoke test: draw a few (hr, lr) samples from an image and report shapes.
    # Usage: python data.py IMAGE [CROP_SIZE] [NUM_SAMPLES]
    img = PIL.Image.open(sys.argv[1])
    # 128 / 10 are arbitrary demo defaults, overridable from the command line.
    crop_size = int(sys.argv[2]) if len(sys.argv) > 2 else 128
    num_samples = int(sys.argv[3]) if len(sys.argv) > 3 else 10
    # DataSampler requires crop_size; omitting it raised TypeError.
    sampler = DataSampler(img, 2, crop_size)
    # zip with a finite range stops the otherwise infinite generator.
    for idx, (hr, lr) in zip(range(num_samples), sampler.generate_data()):
        # generate_data prepends a batch dim of size 1; drop it before the
        # 3-axis CHW -> HWC transpose (which fails on a 4-D array).
        hr = hr[0].numpy().transpose((1, 2, 0))
        lr = lr[0].numpy().transpose((1, 2, 0))
        print(idx, hr.shape, lr.shape)