-
Notifications
You must be signed in to change notification settings - Fork 1
/
drag_pipeline.py
280 lines (246 loc) · 9.69 KB
/
drag_pipeline.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import warnings
from dataclasses import dataclass
from itertools import product
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Union
import numpy as np
import PIL
import torch
import torch.nn.functional as F
from einops import rearrange
from torch import Tensor
from tqdm import tqdm
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
from torchvision import transforms
from diffusers import DiffusionPipeline
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.schedulers import DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler
from diffusers.utils import randn_tensor, BaseOutput
from motion_sup import unet_feat_hook, MotionSup
def load_img(path, target_size=512):
    """Load an image as an RGB tensor scaled to the range [-1, 1].

    The shorter side is resized to ``target_size`` and the result is
    center-cropped to ``target_size`` x ``target_size``.

    Args:
        path: Anything accepted by ``PIL.Image.open`` (a file path or file object).
        target_size: Output height and width in pixels.

    Returns:
        A float tensor of shape (3, target_size, target_size) with values in [-1, 1].
    """
    # ``import PIL`` alone does not load the ``PIL.Image`` submodule; import it
    # explicitly instead of relying on another library having done so first.
    from PIL import Image

    # Close the underlying file deterministically. ``convert`` returns a new,
    # fully loaded image, so it remains valid after the ``with`` block exits.
    with Image.open(path) as img:
        image = img.convert("RGB")

    tform = transforms.Compose(
        [
            transforms.Resize(target_size),
            transforms.CenterCrop(target_size),
            transforms.ToTensor(),  # uint8 HWC [0, 255] -> float CHW [0, 1]
        ]
    )
    # Map [0, 1] -> [-1, 1].
    return 2.0 * tform(image) - 1.0
def backward_ddim(x_t, alpha_t: "alpha_t", alpha_tm1: "alpha_{t-1}", eps_xt):
    """Take one deterministic DDIM step from noise level ``alpha_t`` to ``alpha_tm1``.

    Computes::

        x_out = sqrt(a_tm1) * [ (a_t^-1/2 - a_tm1^-1/2) * x_t
                                + (sqrt(1/a_tm1 - 1) - sqrt(1/a_t - 1)) * eps ] + x_t

    which is algebraically equal to the usual DDIM update
    ``sqrt(a_tm1 / a_t) * x_t + (sqrt(1 - a_tm1) - sqrt(a_tm1 * (1 - a_t) / a_t)) * eps``.
    With ``alpha_tm1`` at a less-noisy level this moves from noise towards the
    image; passing the alphas the other way round performs an inversion step.

    Args:
        x_t: Current sample (float or tensor).
        alpha_t: Cumulative alpha product at the current noise level.
        alpha_tm1: Cumulative alpha product at the target noise level.
        eps_xt: Noise prediction for ``x_t``.

    Returns:
        The sample at the target noise level.
    """
    x_coef = alpha_t**-0.5 - alpha_tm1**-0.5
    eps_coef = (1 / alpha_tm1 - 1) ** 0.5 - (1 / alpha_t - 1) ** 0.5
    scale = alpha_tm1**0.5
    return scale * (x_coef * x_t + eps_coef * eps_xt) + x_t
def forward_ddim(x_t, alpha_t: "alpha_t", alpha_tp1: "alpha_{t+1}", eps_xt):
    """Take one DDIM inversion step, moving the sample from the image towards noise.

    The update rule is exactly ``backward_ddim``; only the target noise level
    differs (``alpha_tp1``, the next, noisier level, instead of ``alpha_{t-1}``).
    """
    return backward_ddim(x_t=x_t, alpha_t=alpha_t, alpha_tm1=alpha_tp1, eps_xt=eps_xt)
class DragDiffusionPipeline(DiffusionPipeline):
    """DDIM sampling / inversion pipeline with latent optimisation for point dragging.

    Modify from:
    https://github.com/cccntu/efficient-prompt-to-prompt/blob/main/ddim-inversion.ipynb

    A forward hook (``unet_feat_hook``) is registered on the last UNet up-block
    at construction time. The methods below read ``self.unet_feat_cache`` after
    UNet calls, so that hook is presumably what fills the cache (confirm in
    ``motion_sup``).
    """

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
        safety_checker=None,
        feature_extractor=None,
    ):
        super().__init__()
        # Feature cache read after UNet forwards; presumably written by
        # ``unet_feat_hook`` (registered just below).
        self.unet_feat_cache: Optional[Tensor] = None
        unet.up_blocks[-1].register_forward_hook(hook=partial(unet_feat_hook, self))
        # ``safety_checker`` and ``feature_extractor`` are accepted but neither
        # registered nor used by this pipeline.
        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
        )
        # ``forward_diffusion(...)`` is ``backward_diffusion(..., reverse_process=True)``:
        # timesteps are walked in reverse order with swapped alphas (image -> noise).
        self.forward_diffusion = partial(self.backward_diffusion, reverse_process=True)

    def _latent_scaling_factor(self) -> float:
        """Return the VAE latent scale, falling back to 0.18215 when the config lacks one."""
        return getattr(self.vae.config, "scaling_factor", 0.18215)

    @torch.inference_mode()
    def get_text_embedding(self, prompt):
        """Tokenize ``prompt`` (padded/truncated to the tokenizer max length) and encode it.

        Returns:
            The first element of the text-encoder output (the last hidden state
            for ``CLIPTextModel``).
        """
        text_input_ids = self.tokenizer(
            prompt,
            padding="max_length",
            truncation=True,
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        ).input_ids
        return self.text_encoder(text_input_ids.to(self.device))[0]

    @torch.inference_mode()
    def get_image_latents(self, image, sample=True, rng_generator=None):
        """Encode ``image`` with the VAE and return scaled latents.

        Args:
            image: Image tensor passed straight to ``vae.encode`` (presumably in
                [-1, 1], as produced by ``load_img`` — confirm at call sites).
            sample: Draw from the latent distribution if True, else use its mode.
            rng_generator: Optional ``torch.Generator`` used when sampling.

        Returns:
            Latents multiplied by the VAE scaling factor.
        """
        encoding_dist = self.vae.encode(image).latent_dist
        if sample:
            encoding = encoding_dist.sample(generator=rng_generator)
        else:
            encoding = encoding_dist.mode()
        return encoding * self._latent_scaling_factor()

    @torch.enable_grad()
    def tune_latent(
        self,
        latents,
        mask,
        timestep,
        prompt_embeds,
        cross_attention_kwargs,
        handle_points,
        target_points,
        steps=150,
        return_inbetween=False,
    ):
        """Iteratively update a latent with ``MotionSup`` using UNet features.

        One UNet forward on ``latents`` populates the feature cache used as the
        reference for ``MotionSup``. Each iteration then runs the UNet on the
        current ``motion_track.ref_latent``, calls ``motion_track.step`` (which
        updates ``ref_latent``) and ``motion_track.search_handle``.

        Args:
            latents: Starting latent; its dtype is used for all returned latents.
            mask: Forwarded to ``MotionSup``.
            timestep: UNet timestep used for every forward pass.
            prompt_embeds: Text conditioning for the UNet.
            cross_attention_kwargs: Forwarded to the UNet.
            handle_points: Forwarded to ``MotionSup``.
            target_points: Forwarded to ``MotionSup``.
            steps: Number of optimisation steps passed to ``MotionSup``.
            return_inbetween: If True, also return intermediate latents.

        Returns:
            The final latent, or ``(final_latent, cache)`` when
            ``return_inbetween`` is True. ``cache`` holds copies taken at about
            16 evenly spaced iterations plus the last one.
        """
        # Reference forward pass: fills ``self.unet_feat_cache`` via the hook.
        self.unet(
            latents,
            timestep,
            encoder_hidden_states=prompt_embeds,
            cross_attention_kwargs=cross_attention_kwargs,
            return_dict=False,
        )
        motion_track = MotionSup(
            handle_points,
            target_points,
            self.unet_feat_cache[0],
            latents,
            mask,
            steps=steps,
        )

        cache = []
        # Set for O(1) membership checks inside the loop.
        cache_step = set(range(0, steps, max(1, steps // 16))) | {steps - 1}
        bar = tqdm(range(motion_track.steps))
        for j in bar:
            self.unet(
                motion_track.ref_latent.to(latents.dtype),
                timestep,
                encoder_hidden_states=prompt_embeds,
                cross_attention_kwargs=cross_attention_kwargs,
                return_dict=False,
            )
            feat_map = self.unet_feat_cache[0]
            loss = motion_track.step(feat_map, prev_feat=None)  # updates motion_track.ref_latent
            motion_track.search_handle(motion_track.ref_feat, feat_map)
            bar.set_postfix({"l1_loss": loss})
            if return_inbetween and j in cache_step:
                cur_latent = motion_track.ref_latent.data.clone().detach()
                cache.append(cur_latent.to(latents.dtype))

        final_latent = motion_track.ref_latent.to(latents.dtype).data
        if return_inbetween:
            return final_latent, cache
        return final_latent

    @property
    def timesteps_tensor(self):
        """Scheduler timesteps moved to the pipeline device."""
        return self.scheduler.timesteps.to(self.device)

    @torch.inference_mode()
    def backward_diffusion(
        self,
        use_old_emb_i=25,
        text_embeddings=None,
        old_text_embeddings=None,
        new_text_embeddings=None,
        latents: Optional[torch.FloatTensor] = None,
        num_inference_steps: int = 50,
        guidance_scale: float = 1,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: Optional[int] = 1,
        reverse_process: bool = False,
        early_stop_step: Optional[int] = None,
        **kwargs,
    ):
        """Run deterministic DDIM over the scheduler timesteps, starting from ``latents``.

        With ``reverse_process=False`` timesteps are visited in scheduler order
        (noise -> image). With ``reverse_process=True`` they are visited in
        reverse order and the two alphas are swapped, i.e. DDIM inversion
        (image -> noise).

        Args:
            use_old_emb_i, old_text_embeddings, new_text_embeddings, **kwargs:
                Accepted for API compatibility; not used by this implementation.
            text_embeddings: UNet conditioning. When classifier-free guidance is
                active this must stack the unconditional and conditional
                embeddings, because the latents are duplicated and the noise
                prediction is split with ``chunk(2)``.
            latents: Starting latents; required, with batch size 1.
            num_inference_steps: Number of scheduler timesteps.
            guidance_scale: Classifier-free guidance weight; guidance is only
                applied when > 1.
            callback: Called as ``callback(i, t, latents)`` before the update at
                every ``callback_steps``-th iteration.
            callback_steps: Callback period; ``None`` is treated as 1.
            reverse_process: Run the inversion direction.
            early_stop_step: If not None, stop after the iteration whose index
                is ``early_stop_step`` (0 means after the first step).

        Returns:
            The resulting latents.

        Raises:
            ValueError: If ``latents`` is None.
        """
        if latents is None:
            raise ValueError("`latents` must be provided; this method does not sample initial noise.")
        assert latents.size(0) == 1, "Current implmentation don't support batch size > 1!"
        if callback_steps is None:
            callback_steps = 1

        # ``guidance_scale`` follows the guidance weight ``w`` of eq. (2) in the
        # Imagen paper (https://arxiv.org/pdf/2205.11487.pdf); 1 means no CFG.
        do_classifier_free_guidance = guidance_scale > 1.0

        self.scheduler.set_timesteps(num_inference_steps)
        # Some schedulers (e.g. PNDM) keep timesteps as arrays; move them once up front.
        timesteps_tensor = self.scheduler.timesteps.to(self.device)
        if reverse_process:
            timesteps_tensor = reversed(timesteps_tensor)
        # Distance between consecutive timesteps; constant for the whole loop.
        step_size = self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps

        # Scale the initial latents by the std the scheduler expects.
        latents = latents * self.scheduler.init_noise_sigma

        for i, t in enumerate(self.progress_bar(timesteps_tensor)):
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            prev_timestep = t - step_size

            if callback is not None and i % callback_steps == 0:
                callback(i, t, latents)

            alpha_prod_t = self.scheduler.alphas_cumprod[t]
            alpha_prod_t_prev = (
                self.scheduler.alphas_cumprod[prev_timestep]
                if prev_timestep >= 0
                else self.scheduler.final_alpha_cumprod
            )
            if reverse_process:
                # Inversion: step from the less-noisy level to the noisier one.
                alpha_prod_t, alpha_prod_t_prev = alpha_prod_t_prev, alpha_prod_t
            latents = backward_ddim(
                x_t=latents,
                alpha_t=alpha_prod_t,
                alpha_tm1=alpha_prod_t_prev,
                eps_xt=noise_pred,
            )

            # ``is not None`` so that early_stop_step=0 is honoured.
            if early_stop_step is not None and i >= early_stop_step:
                break
        return latents

    @torch.inference_mode()
    def decode_image(self, latents: torch.FloatTensor, **kwargs) -> torch.Tensor:
        """Decode latents with the VAE, one sample at a time.

        Returns:
            The decoded samples concatenated along dim 0. ``torch_to_numpy``
            treats them as lying in [-1, 1].
        """
        scaled_latents = latents / self._latent_scaling_factor()
        image = [self.vae.decode(scaled_latents[i : i + 1]).sample for i in range(len(latents))]
        return torch.cat(image, dim=0)

    @torch.inference_mode()
    def torch_to_numpy(self, image) -> np.ndarray:
        """Map a [-1, 1] image tensor to a [0, 1] NumPy array with channels last.

        Assumes ``image`` is (batch, channels, height, width) — it is permuted
        with ``(0, 2, 3, 1)``.
        """
        image = (image / 2 + 0.5).clamp(0, 1)
        return image.cpu().permute(0, 2, 3, 1).numpy()