-
Notifications
You must be signed in to change notification settings - Fork 155
/
predict.py
247 lines (224 loc) · 8.25 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
import tempfile
from argparse import Namespace
import numpy as np
import time
from cog import BasePredictor, Input, Path
import dlib
import imageio
import torch
from PIL import Image
from torchvision import transforms
from models.e4e import e4e
from models.psp import pSp
from scripts.align_faces_parallel import align_face
from scripts import encoder_bootstrapping_inference
from utils.common import tensor2im
from utils.inference_utils import run_on_batch
# Encoding domains exposed as the `encoding_type` choices of Predictor.predict.
DOMAINS = [
    "faces",
    "toonify",
]
class Predictor(BasePredictor):
    """Cog predictor that inverts a single image with ReStyle encoders.

    ``predict`` dispatches on ``encoding_type``: ``"toonify"`` runs encoder
    bootstrapping (the FFHQ "faces" network together with the toonify
    network); every other domain runs a single ReStyle network.
    """

    def setup(self):
        """Load checkpoints, the dlib landmark model, and the input transforms."""
        print("Starting setup!")
        self.model_paths = {
            "faces": "pretrained_models/restyle_psp_ffhq_encode.pt",
            "toonify": "pretrained_models/restyle_psp_toonify.pt",
        }
        print("Loading checkpoints...")
        # Loaded once onto the CPU. Afterwards only each checkpoint's "opts"
        # dict is read (see _load_net); the network itself is constructed from
        # opts.checkpoint_path on every request.
        self.checkpoints = {
            domain: torch.load(path, map_location="cpu")
            for domain, path in self.model_paths.items()
        }
        print("Done!")
        # Landmark model handed to align_face in run_alignment.
        self.shape_predictor = dlib.shape_predictor(
            "/content/shape_predictor_68_face_landmarks.dat"
        )
        # Resize to 256x256, convert to a tensor and map [0, 1] -> [-1, 1].
        self.default_transforms = transforms.Compose(
            [
                transforms.Resize((256, 256)),
                transforms.ToTensor(),
                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
            ]
        )
        # Same normalisation at 192 (h) x 256 (w), used for the "cars" domain,
        # which DOMAINS does not currently expose.
        self.cars_transforms = transforms.Compose(
            [
                transforms.Resize((192, 256)),
                transforms.ToTensor(),
                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
            ]
        )
        print("Setup complete!")

    def predict(
        self,
        input: Path = Input(description="Path to input image"),
        encoding_type: str = Input(
            choices=DOMAINS,
            description="Which domain you wish to run on",
        ),
        num_iterations: int = Input(
            default=5,
            ge=1,
            le=10,
            description="Number of ReStyle iterations to run. "
            "For `faces` we recommend 5 iterations and for `toonify` we recommend 1 to 2 iterations.",
        ),
        display_intermediate_results: bool = Input(
            # Previously `choices=False`, which is not a list of choices and
            # left the input without a default; the intent is default=False.
            default=False,
            description="Whether to display all intermediate outputs. If unchecked, will display only the final result.",
        ),
    ) -> Path:
        """Encode ``input`` in the requested domain and return the output PNG path."""
        if encoding_type == "toonify":
            return self.run_toonify_bootstrapping(
                input, num_iterations, display_intermediate_results
            )
        return self.run_default_encoding(
            input, encoding_type, num_iterations, display_intermediate_results
        )

    def _load_net(self, encoding_type):
        """Build an eval-mode network on the GPU for ``encoding_type``.

        Returns:
            ``(net, opts)``, where ``opts`` is a fresh ``Namespace`` built from
            a copy of the checkpoint's stored options, with ``checkpoint_path``
            pointed at the model file for ``encoding_type``.
        """
        print(f"Loading {encoding_type} model...")
        # Copy so the cached checkpoint dict is not mutated between requests.
        opts = dict(self.checkpoints[encoding_type]["opts"])
        opts["checkpoint_path"] = self.model_paths[encoding_type]
        opts = Namespace(**opts)
        # The "horses" (e4e) branch cannot be reached through predict(), since
        # DOMAINS does not list it; it is kept so the helper stays general.
        net = e4e(opts) if encoding_type == "horses" else pSp(opts)
        net.eval()
        net.cuda()
        print("Done!")
        return net, opts

    @staticmethod
    def _save_output(image):
        """Write ``image`` to ``output.png`` in a new temp dir and return its path."""
        out_path = Path(tempfile.mkdtemp()) / "output.png"
        imageio.imwrite(str(out_path), image)
        return out_path

    def run_default_encoding(
        self, input, encoding_type, num_iterations, display_intermediate_results
    ):
        """Encode ``input`` with the single ReStyle network for ``encoding_type``.

        Args:
            input: path to the input image.
            encoding_type: key into ``self.checkpoints`` / ``self.model_paths``.
            num_iterations: stored as ``opts.n_iters_per_batch`` for run_on_batch.
            display_intermediate_results: if true, tile one image per iteration
                side by side; otherwise only the last iteration is shown.

        Returns:
            Path to the written ``output.png``.
        """
        net, opts = self._load_net(encoding_type)
        opts.n_iters_per_batch = num_iterations
        opts.resize_outputs = False

        image_transforms = (
            self.cars_transforms if encoding_type == "cars" else self.default_transforms
        )
        if encoding_type == "faces":
            # Face inputs are landmark-aligned before encoding.
            print("Aligning image...")
            input_image = self.run_alignment(str(input))
            print("Done!")
        else:
            input_image = Image.open(str(input)).convert("RGB")
        transformed_image = image_transforms(input_image)

        print("Running inference...")
        with torch.no_grad():
            start = time.time()
            avg_image = self.get_avg_image(net, encoding_type)
            # The second return value (latents) is not needed here.
            result_batch, _ = run_on_batch(
                transformed_image.unsqueeze(0).cuda(), net, opts, avg_image
            )
            total_time = time.time() - start
        print(f"Finished inference in {total_time} seconds!")

        print("Preparing result...")
        # Compare against "cars", the same key used for the transforms above
        # and in get_avg_image (this was previously "cars_encode", which
        # never matched).
        resize_amount = (
            (512, 384)
            if encoding_type == "cars"
            else (opts.output_size, opts.output_size)
        )
        res = self.get_final_output(
            result_batch, resize_amount, display_intermediate_results, opts
        )
        return self._save_output(res)

    def run_toonify_bootstrapping(
        self, input, num_iterations, display_intermediate_results
    ):
        """Run encoder bootstrapping: FFHQ ("faces") network plus toonify network.

        The toonify checkpoint's options drive inference. Output frames are
        resized to 1024x1024.

        Returns:
            Path to the written ``output.png``.
        """
        net_ffhq, _ = self._load_net("faces")
        net_toonify, opts = self._load_net("toonify")
        opts.n_iters_per_batch = num_iterations
        opts.resize_outputs = False

        print("Aligning image...")
        input_image = self.run_alignment(str(input))
        print("Done!")
        transformed_image = self.default_transforms(input_image)

        print("Running inference...")
        with torch.no_grad():
            start = time.time()
            avg_image = self.get_avg_image(net_ffhq, encoding_type="faces")
            result_batch = encoder_bootstrapping_inference.run_on_batch(
                transformed_image.unsqueeze(0).cuda(),
                net_ffhq,
                net_toonify,
                opts,
                avg_image,
            )
            total_time = time.time() - start
        print(f"Finished inference in {total_time} seconds!")

        print("Preparing result...")
        res = self.get_final_output(
            result_batch, (1024, 1024), display_intermediate_results, opts
        )
        return self._save_output(res)

    def run_alignment(self, image_path):
        """Align the face in ``image_path`` using the dlib landmark predictor.

        Raises:
            ValueError: if ``align_face`` raises for any reason. The original
                exception is chained as ``__cause__`` for debugging.
        """
        try:
            return align_face(filepath=image_path, predictor=self.shape_predictor)
        except Exception as err:
            raise ValueError(
                "Oh no! Could not align face! \nPlease try another image!"
            ) from err

    @staticmethod
    def get_avg_image(net, encoding_type):
        """Decode ``net.latent_avg`` to an image tensor on CUDA.

        The result is passed to run_on_batch as ``avg_image``. For "cars",
        rows 32:224 of dimension 1 are kept (presumably the height axis of a
        (C, H, W) tensor, matching the 192-pixel-high cars input; TODO confirm).
        """
        avg_image = net(
            net.latent_avg.unsqueeze(0),
            input_code=True,
            randomize_noise=False,
            return_latents=False,
            average_code=True,
        )[0]
        avg_image = avg_image.to("cuda").float().detach()
        if encoding_type == "cars":
            avg_image = avg_image[:, 32:224, :]
        return avg_image

    @staticmethod
    def get_final_output(
        result_batch, resize_amount, display_intermediate_results, opts
    ):
        """Convert results to a single PIL image, tiling iterations horizontally.

        Args:
            result_batch: indexed as ``result_batch[0][iteration]``; only the
                first (and only) input of the batch is used.
            resize_amount: ``(width, height)`` passed to ``PIL.Image.resize``.
            display_intermediate_results: tile all ``opts.n_iters_per_batch``
                iterations if true, otherwise use only the last one.
            opts: options namespace providing ``n_iters_per_batch``.
        """
        result_tensors = result_batch[0]  # there's one image in our batch
        if display_intermediate_results:
            result_images = [
                tensor2im(result_tensors[iter_idx])
                for iter_idx in range(opts.n_iters_per_batch)
            ]
        else:
            result_images = [tensor2im(result_tensors[-1])]
        # One concatenation along the width axis instead of a growing loop.
        res = np.concatenate(
            [np.array(image.resize(resize_amount)) for image in result_images],
            axis=1,
        )
        return Image.fromarray(res)