-
Notifications
You must be signed in to change notification settings - Fork 35
/
domain_specific_deblur.py
89 lines (71 loc) · 3.1 KB
/
domain_specific_deblur.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import argparse
from math import ceil, log10
from pathlib import Path
import torchvision
import yaml
from PIL import Image
from torch.nn import DataParallel
from torch.utils.data import DataLoader, Dataset
class Images(Dataset):
    """Dataset of the ``*.png`` files in one directory, each repeated ``duplicates`` times.

    Repeating an image lets the model produce several HR images from the same
    input; every repeat is given its own output name.
    """

    def __init__(self, root_dir, duplicates):
        """
        Args:
            root_dir: directory scanned (non-recursively) for ``*.png`` files.
            duplicates: number of times each image appears in the dataset.
        """
        self.root_path = Path(root_dir)
        # Sort so the index -> file mapping (and therefore the processing and
        # output order) is deterministic; Path.glob yields filesystem order.
        self.image_list = sorted(self.root_path.glob("*.png"))
        # Number of times to duplicate the image in the dataset to produce multiple HR images
        self.duplicates = duplicates

    def __len__(self):
        """Return the number of (image, repeat) items: duplicates * number of files."""
        return self.duplicates * len(self.image_list)

    def __getitem__(self, idx):
        """Return ``(image_tensor, name)`` for flat index ``idx``.

        Indices ``k * duplicates`` .. ``(k + 1) * duplicates - 1`` all map to the
        k-th file. ``name`` is the file stem, suffixed with the 1-based repeat
        number (``_1`` .. ``_N``) when ``duplicates`` is greater than 1.
        """
        img_path = self.image_list[idx // self.duplicates]
        # ToTensor copies the decoded pixels, so the file can be closed right
        # away; the context manager releases the handle deterministically
        # instead of relying on Pillow's auto-close behaviour.
        with Image.open(img_path) as img:
            image = torchvision.transforms.ToTensor()(img)
        if self.duplicates == 1:
            return image, img_path.stem
        return image, f"{img_path.stem}_{idx % self.duplicates + 1}"
def _build_arg_parser():
    """Create the command-line parser for this script.

    All options are string paths (I/O locations and the YAML config file), so
    they are declared from a single table of (flag, default, help text).
    """
    arg_parser = argparse.ArgumentParser(description="PULSE")
    # I/O arguments
    path_options = (
        ("--input_dir", "imgs/blur_faces", "input data directory"),
        ("--output_dir", "experiments/domain_specific_deblur/results", "output data directory"),
        ("--cache_dir", "experiments/domain_specific_deblur/cache", "cache directory for model weights"),
        ("--yml_path", "options/domain_specific_deblur/stylegan2.yml", "configuration file"),
    )
    for flag, default_value, help_text in path_options:
        arg_parser.add_argument(flag, type=str, default=default_value, help=help_text)
    return arg_parser


parser = _build_arg_parser()
# Parsed options as a plain dict keyed by option name (e.g. "input_dir").
kwargs = vars(parser.parse_args())
# Load the experiment configuration (batch size, duplicates, model variant, ...).
with Path(kwargs["yml_path"]).open("rb") as config_file:
    opt = yaml.safe_load(config_file)

# Each input image is repeated opt["duplicates"] times to produce multiple HR images.
dataset = Images(kwargs["input_dir"], duplicates=opt["duplicates"])

out_path = Path(kwargs["output_dir"])
out_path.mkdir(parents=True, exist_ok=True)

dataloader = DataLoader(dataset, batch_size=opt["batch_size"])

# Select the model class: stylegan_ver == 1 uses StyleGAN, any other value
# uses StyleGAN2. The imports are deferred so only the chosen module loads.
if opt["stylegan_ver"] == 1:
    from models.dsd.dsd_stylegan import DSDStyleGAN as _dsd_model_cls
else:
    from models.dsd.dsd_stylegan2 import DSDStyleGAN2 as _dsd_model_cls
model = DataParallel(_dsd_model_cls(opt=opt, cache_dir=kwargs["cache_dir"]))
toPIL = torchvision.transforms.ToPILImage()


def _save_image(tensor, path):
    """Detach a single-image tensor, move it to CPU, clamp it to [0, 1] and save it to ``path``."""
    toPIL(tensor.cpu().detach().clamp(0, 1)).save(path)


for ref_im, ref_im_name in dataloader:
    # Loop over the names actually present in this batch rather than
    # range(opt["batch_size"]): the DataLoader keeps a final partial batch
    # (drop_last is not set), which would otherwise raise IndexError.
    if opt["save_intermediate"]:
        # Zero-pad the step index so per-step files sort in step order.
        # NOTE(review): the width is fixed for 100 steps (2 digits); confirm
        # against the number of steps the model actually yields.
        padding = ceil(log10(100))
        # Create and remember an HR/LR directory pair for EVERY image in the
        # batch, so each image's outputs land in its own folder (previously only
        # the last image's directories were kept and used for all images).
        image_dirs = []
        for name in ref_im_name:
            hr_dir = out_path / name / "HR"
            lr_dir = out_path / name / "LR"
            hr_dir.mkdir(parents=True, exist_ok=True)
            lr_dir.mkdir(parents=True, exist_ok=True)
            image_dirs.append((hr_dir, lr_dir))
        # The model yields (HR, LR) batches; j is the index of the yielded step.
        for j, (HR, LR) in enumerate(model(ref_im)):
            for i, (name, (hr_dir, lr_dir)) in enumerate(zip(ref_im_name, image_dirs)):
                file_name = f"{name}_{j:0{padding}}.png"
                _save_image(HR[i], hr_dir / file_name)
                _save_image(LR[i], lr_dir / file_name)
    else:
        # Every yield overwrites <name>.png, so the file left on disk is the HR
        # image from the model's last yield.
        for HR, _ in model(ref_im):
            for i, name in enumerate(ref_im_name):
                _save_image(HR[i], out_path / f"{name}.png")