-
Notifications
You must be signed in to change notification settings - Fork 557
/
predict.py
125 lines (106 loc) · 3.74 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import os
import sys
sys.path.insert(0, "stylegan-encoder")
import tempfile
import warnings
import imageio
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from skimage.transform import resize
from skimage import img_as_ubyte
import torch
import torchvision.transforms as transforms
import dlib
from cog import BasePredictor, Path, Input
from demo import load_checkpoints
from demo import make_animation
from ffhq_dataset.face_alignment import image_align
from ffhq_dataset.landmarks_detector import LandmarksDetector
# Suppress every Python warning for the whole process (global side effect).
warnings.filterwarnings("ignore")

# dlib 68-point landmark model file, shared by both module-level objects below.
_LANDMARK_MODEL_FILE = "shape_predictor_68_face_landmarks.dat"

# NOTE(review): PREDICTOR is not referenced anywhere else in this file and
# loads the same model file a second time; confirm no other module imports it
# before removing.
PREDICTOR = dlib.shape_predictor(_LANDMARK_MODEL_FILE)

# Landmark source for align_image(), which calls its get_landmarks().
LANDMARKS_DETECTOR = LandmarksDetector(_LANDMARK_MODEL_FILE)
class Predictor(BasePredictor):
    """Cog predictor that animates a still image with a driving video.

    ``setup`` loads one set of motion-transfer networks per dataset;
    ``predict`` selects the set matching ``dataset_name``.
    """

    # Datasets with a checkpoint at checkpoints/<name>.pth.tar and a config at
    # config/<name>-<resolution>.yaml. Also used as the Cog input choices.
    DATASETS = ("vox", "taichi", "ted", "mgif")

    @staticmethod
    def _resolution(dataset_name):
        """Return the square frame size (in pixels) used by ``dataset_name``."""
        # "ted" ships a 384x384 config; every other dataset uses 256x256.
        return 384 if dataset_name == "ted" else 256

    def setup(self):
        """Load the networks for every dataset onto cuda:0 once."""
        self.device = torch.device("cuda:0")
        self.inpainting = {}
        self.kp_detector = {}
        self.dense_motion_network = {}
        self.avd_network = {}
        for d in self.DATASETS:
            (
                self.inpainting[d],
                self.kp_detector[d],
                self.dense_motion_network[d],
                self.avd_network[d],
            ) = load_checkpoints(
                config_path=f"config/{d}-{self._resolution(d)}.yaml",
                checkpoint_path=f"checkpoints/{d}.pth.tar",
                device=self.device,
            )

    @staticmethod
    def _to_rgb(image, pixel):
        """Resize ``image`` to ``pixel`` x ``pixel`` and keep three channels."""
        if image.ndim == 2:
            image = image[..., np.newaxis]
        if image.shape[-1] < 3:
            # Grey (or grey + alpha): replicate the first channel so the
            # trailing [..., :3] selects colour channels, not image columns.
            image = np.repeat(image[..., :1], 3, axis=-1)
        # Extra channels (e.g. alpha) are dropped after resizing.
        return resize(image, (pixel, pixel))[..., :3]

    @staticmethod
    def _load_source_image(path, dataset_name, work_dir):
        """Read the source image, face-aligning it first for "vox".

        Raises:
            ValueError: if alignment produced no image (no face detected).
        """
        if dataset_name != "vox":
            return imageio.imread(path)
        aligned_path = work_dir / "aligned.png"
        align_image(path, str(aligned_path))
        # align_image writes nothing when no face is detected. Checking a
        # fresh per-call path avoids silently reusing an earlier call's image.
        if not aligned_path.exists():
            raise ValueError("No face detected in the source image.")
        return imageio.imread(str(aligned_path))

    @staticmethod
    def _read_driving_video(path):
        """Return ``(frames, fps)`` decoded from the video at ``path``.

        Raises:
            ValueError: if not a single frame could be decoded.
        """
        with imageio.get_reader(path) as reader:
            fps = reader.get_meta_data()["fps"]
            frames = []
            try:
                for frame in reader:
                    frames.append(frame)
            except RuntimeError:
                # Best effort: keep the frames decoded before a read error
                # (e.g. a truncated tail) instead of failing the request.
                pass
        if not frames:
            raise ValueError("Could not read any frames from the driving video.")
        return frames, fps

    def predict(
        self,
        source_image: Path = Input(
            description="Input source image.",
        ),
        driving_video: Path = Input(
            description="Choose a micromotion.",
        ),
        dataset_name: str = Input(
            choices=list(DATASETS),
            default="vox",
            description="Choose a dataset.",
        ),
    ) -> Path:
        """Animate ``source_image`` with the motion of ``driving_video``.

        Args:
            source_image: still image to animate; face-aligned first when
                ``dataset_name`` is "vox".
            driving_video: video whose motion is transferred to the image.
            dataset_name: which pretrained model set to use.

        Returns:
            Path to ``output.mp4`` in a fresh temporary directory, written at
            the driving video's frame rate.

        Raises:
            ValueError: if no face is found in the source image ("vox"), or
                the driving video yields no readable frames.
        """
        # make_animation also accepts 'standard' and 'avd'.
        predict_mode = "relative"
        pixel = self._resolution(dataset_name)
        # Per-call directory: intermediate and output files never collide
        # with those of another prediction.
        work_dir = Path(tempfile.mkdtemp())

        source = self._to_rgb(
            self._load_source_image(str(source_image), dataset_name, work_dir),
            pixel,
        )
        raw_frames, fps = self._read_driving_video(str(driving_video))
        frames = [self._to_rgb(frame, pixel) for frame in raw_frames]

        predictions = make_animation(
            source,
            frames,
            self.inpainting[dataset_name],
            self.kp_detector[dataset_name],
            self.dense_motion_network[dataset_name],
            self.avd_network[dataset_name],
            device=self.device,
            mode=predict_mode,
        )

        out_path = work_dir / "output.mp4"
        imageio.mimsave(
            str(out_path), [img_as_ubyte(frame) for frame in predictions], fps=fps
        )
        return out_path
def align_image(raw_img_path, aligned_face_path):
    """Write a face-aligned crop of ``raw_img_path`` to ``aligned_face_path``.

    Landmarks come from the module-level ``LANDMARKS_DETECTOR``. Every face
    would be saved to the same output path, so only the last face reported
    by the detector is aligned. Nothing is written when no face is found.

    Returns:
        True if an aligned image was written, False if no face was detected.
    """
    last_landmarks = None
    # Skip aligning/saving earlier faces whose output would be overwritten.
    for face_landmarks in LANDMARKS_DETECTOR.get_landmarks(raw_img_path):
        last_landmarks = face_landmarks
    if last_landmarks is None:
        return False
    image_align(raw_img_path, aligned_face_path, last_landmarks)
    return True