update to support diffusers==0.24.0
Yujun-Shi committed Jan 29, 2024
1 parent e86e49a commit ebe659a
Showing 7 changed files with 325 additions and 93 deletions.
3 changes: 2 additions & 1 deletion README.md
@@ -36,9 +36,10 @@
</p>

## Disclaimer
This is a research project, NOT a commercial product.
This is a research project, NOT a commercial product. Users are granted the freedom to create images using this tool, but they are expected to comply with local laws and utilize it in a responsible manner. The developers do not assume any responsibility for potential misuse by users.

## News and Update
* [Jan 29th] Update to support diffusers==0.24.0!
* [Oct 23rd] Code and data of DragBench are released! Please check README under "drag_bench_evaluation" for details.
* [Oct 16th] Integrate [FreeU](https://chenyangsi.top/FreeU/) when dragging generated image.
* [Oct 3rd] Speeding up LoRA training when editing real images. (**Now only around 20s on A100!**)
242 changes: 200 additions & 42 deletions drag_pipeline.py

Large diffs are not rendered by default.

5 changes: 3 additions & 2 deletions environment.yaml
@@ -2,6 +2,7 @@ name: dragdiff
channels:
- pytorch
- defaults
- nvidia
dependencies:
- python=3.8.5
- pip=22.3.1
@@ -31,11 +32,11 @@ dependencies:
- addict==2.4.0
- yapf==0.32.0
- prettytable==3.6.0
- safetensors==0.2.7
- safetensors==0.3.1
- basicsr==1.4.2
- accelerate==0.17.0
- decord==0.6.0
- diffusers==0.17.1
- diffusers==0.24.0
- moviepy==1.0.3
- opencv_python==4.7.0.68
- Pillow==9.4.0
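
The pins that drive this migration are diffusers 0.17.1 → 0.24.0 and safetensors 0.2.7 → 0.3.1 (plus the new `nvidia` channel). A quick sanity check after recreating the environment — a minimal sketch whose assertion values simply mirror the pins above:

```python
# Minimal sanity check after rebuilding the conda env from environment.yaml;
# the version strings mirror the pins in this commit.
import diffusers
import safetensors

assert diffusers.__version__ == "0.24.0", diffusers.__version__
assert safetensors.__version__ == "0.3.1", safetensors.__version__
print("diffusers/safetensors match the pinned versions")
```
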
13 changes: 8 additions & 5 deletions utils/attn_utils.py
@@ -136,7 +136,7 @@ def forward(x, encoder_hidden_states=None, attention_mask=None, context=None, ma
# forward function for lora attention processor
# modified from __call__ function of LoRAAttnProcessor2_0 in diffusers v0.17.1
def override_lora_attn_proc_forward(attn, editor, place_in_unet):
def forward(hidden_states, encoder_hidden_states=None, attention_mask=None, lora_scale=1.0):
def forward(hidden_states, encoder_hidden_states=None, attention_mask=None):
residual = hidden_states
input_ndim = hidden_states.ndim
is_cross = encoder_hidden_states is not None
@@ -158,15 +158,17 @@ def forward(hidden_states, encoder_hidden_states=None, attention_mask=None, lora
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

query = attn.to_q(hidden_states) + lora_scale * attn.processor.to_q_lora(hidden_states)
# query = attn.to_q(hidden_states) + lora_scale * attn.to_q.lora_layer(hidden_states)
query = attn.to_q(hidden_states)

if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

key = attn.to_k(encoder_hidden_states) + lora_scale * attn.processor.to_k_lora(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states) + lora_scale * attn.processor.to_v_lora(encoder_hidden_states)
# key = attn.to_k(encoder_hidden_states) + lora_scale * attn.to_k.lora_layer(encoder_hidden_states)
# value = attn.to_v(encoder_hidden_states) + lora_scale * attn.to_v.lora_layer(encoder_hidden_states)
key, value = attn.to_k(encoder_hidden_states), attn.to_v(encoder_hidden_states)

query, key, value = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=attn.heads), (query, key, value))

@@ -176,7 +178,8 @@ def forward(hidden_states, encoder_hidden_states=None, attention_mask=None, lora
attn.heads, scale=attn.scale)

# linear proj
hidden_states = attn.to_out[0](hidden_states) + lora_scale * attn.processor.to_out_lora(hidden_states)
# hidden_states = attn.to_out[0](hidden_states) + lora_scale * attn.to_out[0].lora_layer(hidden_states)
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
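
Context for the dropped `lora_scale * attn.processor.to_q_lora(...)` terms: from diffusers 0.24 onward the LoRA weights are attached to the projection layers themselves via `set_lora_layer` (see the `utils/lora_utils.py` changes below), so a plain `attn.to_q(hidden_states)` already includes the low-rank update. A toy sketch of that mechanism, using simplified stand-ins rather than the actual diffusers classes:

```python
# Toy stand-ins (not the diffusers source) illustrating why the processor-level
# `lora_scale * to_q_lora(...)` terms are no longer needed: the low-rank update
# lives on the Linear itself, so attn.to_q(x) already applies it.
import torch
import torch.nn as nn

class ToyLoRALinearLayer(nn.Module):
    """Low-rank residual up(down(x)), analogous to diffusers' LoRALinearLayer."""
    def __init__(self, in_features, out_features, rank=4):
        super().__init__()
        self.down = nn.Linear(in_features, rank, bias=False)
        self.up = nn.Linear(rank, out_features, bias=False)
        nn.init.normal_(self.down.weight, std=1.0 / rank)
        nn.init.zeros_(self.up.weight)  # starts as a no-op, like the real layer

    def forward(self, x):
        return self.up(self.down(x))

class ToyLoRACompatibleLinear(nn.Linear):
    """Linear whose forward adds the attached LoRA delta when one is set."""
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.lora_layer = None

    def set_lora_layer(self, lora_layer):
        self.lora_layer = lora_layer

    def forward(self, x, scale=1.0):
        out = super().forward(x)
        if self.lora_layer is not None:
            out = out + scale * self.lora_layer(x)
        return out

to_q = ToyLoRACompatibleLinear(320, 320)
to_q.set_lora_layer(ToyLoRALinearLayer(320, 320, rank=16))
x = torch.randn(1, 64, 320)
print(to_q(x).shape)  # torch.Size([1, 64, 320]); LoRA applied inside the projection call
```
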

6 changes: 4 additions & 2 deletions utils/drag_utils.py
@@ -92,7 +92,8 @@ def drag_diffusion_update(model,

# the init output feature of unet
with torch.no_grad():
unet_output, F0 = model.forward_unet_features(init_code, t, encoder_hidden_states=text_embeddings,
unet_output, F0 = model.forward_unet_features(init_code, t,
encoder_hidden_states=text_embeddings,
layer_idx=args.unet_feature_idx, interp_res_h=args.sup_res_h, interp_res_w=args.sup_res_w)
x_prev_0,_ = model.step(unet_output, t, init_code)
# init_code_orig = copy.deepcopy(init_code)
@@ -110,7 +111,8 @@ def drag_diffusion_update(model,
scaler = torch.cuda.amp.GradScaler()
for step_idx in range(args.n_pix_step):
with torch.autocast(device_type='cuda', dtype=torch.float16):
unet_output, F1 = model.forward_unet_features(init_code, t, encoder_hidden_states=text_embeddings,
unet_output, F1 = model.forward_unet_features(init_code, t,
encoder_hidden_states=text_embeddings,
layer_idx=args.unet_feature_idx, interp_res_h=args.sup_res_h, interp_res_w=args.sup_res_w)
x_prev_updated,_ = model.step(unet_output, t, init_code)

129 changes: 93 additions & 36 deletions utils/lora_utils.py
100644 → 100755
@@ -30,17 +30,16 @@
from diffusers.models.attention_processor import (
AttnAddedKVProcessor,
AttnAddedKVProcessor2_0,
LoRAAttnAddedKVProcessor,
LoRAAttnProcessor,
LoRAAttnProcessor2_0,
SlicedAttnAddedKVProcessor,
)
from diffusers.models.lora import LoRALinearLayer
from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version
from diffusers.training_utils import unet_lora_state_dict
from diffusers.utils import check_min_version, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available

# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.17.0")
check_min_version("0.24.0")


def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
@@ -146,6 +145,13 @@ def train_lora(image,
unet = UNet2DConditionModel.from_pretrained(
model_path, subfolder="unet", revision=None
)
pipeline = StableDiffusionPipeline.from_pretrained(
pretrained_model_name_or_path=model_path,
vae=vae,
unet=unet,
text_encoder=text_encoder,
scheduler=noise_scheduler,
torch_dtype=torch.float16)

# set device and dtype
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
@@ -158,36 +164,71 @@ def train_lora(image,
vae.to(device, dtype=torch.float16)
text_encoder.to(device, dtype=torch.float16)

# initialize UNet LoRA
unet_lora_attn_procs = {}
for name, attn_processor in unet.attn_processors.items():
cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
if name.startswith("mid_block"):
hidden_size = unet.config.block_out_channels[-1]
elif name.startswith("up_blocks"):
block_id = int(name[len("up_blocks.")])
hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
elif name.startswith("down_blocks"):
block_id = int(name[len("down_blocks.")])
hidden_size = unet.config.block_out_channels[block_id]
else:
raise NotImplementedError("name must start with up_blocks, mid_blocks, or down_blocks")
# Set correct lora layers
unet_lora_parameters = []
for attn_processor_name, attn_processor in unet.attn_processors.items():
# Parse the attention module.
attn_module = unet
for n in attn_processor_name.split(".")[:-1]:
attn_module = getattr(attn_module, n)

# Set the `lora_layer` attribute of the attention-related matrices.
attn_module.to_q.set_lora_layer(
LoRALinearLayer(
in_features=attn_module.to_q.in_features,
out_features=attn_module.to_q.out_features,
rank=lora_rank
)
)
attn_module.to_k.set_lora_layer(
LoRALinearLayer(
in_features=attn_module.to_k.in_features,
out_features=attn_module.to_k.out_features,
rank=lora_rank
)
)
attn_module.to_v.set_lora_layer(
LoRALinearLayer(
in_features=attn_module.to_v.in_features,
out_features=attn_module.to_v.out_features,
rank=lora_rank
)
)
attn_module.to_out[0].set_lora_layer(
LoRALinearLayer(
in_features=attn_module.to_out[0].in_features,
out_features=attn_module.to_out[0].out_features,
rank=lora_rank,
)
)

# Accumulate the LoRA params to optimize.
unet_lora_parameters.extend(attn_module.to_q.lora_layer.parameters())
unet_lora_parameters.extend(attn_module.to_k.lora_layer.parameters())
unet_lora_parameters.extend(attn_module.to_v.lora_layer.parameters())
unet_lora_parameters.extend(attn_module.to_out[0].lora_layer.parameters())

if isinstance(attn_processor, (AttnAddedKVProcessor, SlicedAttnAddedKVProcessor, AttnAddedKVProcessor2_0)):
lora_attn_processor_class = LoRAAttnAddedKVProcessor
else:
lora_attn_processor_class = (
LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
attn_module.add_k_proj.set_lora_layer(
LoRALinearLayer(
in_features=attn_module.add_k_proj.in_features,
out_features=attn_module.add_k_proj.out_features,
rank=lora_rank,
)
)
unet_lora_attn_procs[name] = lora_attn_processor_class(
hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=lora_rank
)
attn_module.add_v_proj.set_lora_layer(
LoRALinearLayer(
in_features=attn_module.add_v_proj.in_features,
out_features=attn_module.add_v_proj.out_features,
rank=lora_rank,
)
)
unet_lora_parameters.extend(attn_module.add_k_proj.lora_layer.parameters())
unet_lora_parameters.extend(attn_module.add_v_proj.lora_layer.parameters())

unet.set_attn_processor(unet_lora_attn_procs)
unet_lora_layers = AttnProcsLayers(unet.attn_processors)

# Optimizer creation
params_to_optimize = (unet_lora_layers.parameters())
params_to_optimize = (unet_lora_parameters)
optimizer = torch.optim.AdamW(
params_to_optimize,
lr=lora_lr,
@@ -206,9 +247,11 @@ def train_lora(image,
)

# prepare accelerator
unet_lora_layers = accelerator.prepare_model(unet_lora_layers)
optimizer = accelerator.prepare_optimizer(optimizer)
lr_scheduler = accelerator.prepare_scheduler(lr_scheduler)
# unet_lora_layers = accelerator.prepare_model(unet_lora_layers)
# optimizer = accelerator.prepare_optimizer(optimizer)
# lr_scheduler = accelerator.prepare_scheduler(lr_scheduler)

unet,optimizer,lr_scheduler = accelerator.prepare(unet,optimizer,lr_scheduler)

# initialize text embeddings
with torch.no_grad():
@@ -221,11 +264,15 @@ def train_lora(image,
)
text_embedding = text_embedding.repeat(lora_batch_size, 1, 1)

# initialize latent distribution
image_transforms = transforms.Compose(
# initialize image transforms
image_transforms_pil = transforms.Compose(
[
transforms.Resize(512, interpolation=transforms.InterpolationMode.BILINEAR),
transforms.RandomCrop(512),
]
)
image_transforms_tensor = transforms.Compose(
[
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
]
@@ -234,8 +281,14 @@ def train_lora(image,
for step in progress.tqdm(range(lora_step), desc="training LoRA"):
unet.train()
image_batch = []
image_pil_batch = []
for _ in range(lora_batch_size):
image_transformed = image_transforms(Image.fromarray(image)).to(device, dtype=torch.float16)
# first store pil image
image_transformed = image_transforms_pil(Image.fromarray(image))
image_pil_batch.append(image_transformed)

# then store tensor image
image_transformed = image_transforms_tensor(image_transformed).to(device, dtype=torch.float16)
image_transformed = image_transformed.unsqueeze(dim=0)
image_batch.append(image_transformed)

@@ -258,7 +311,9 @@ def train_lora(image,
noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)

# Predict the noise residual
model_pred = unet(noisy_model_input, timesteps, text_embedding).sample
model_pred = unet(noisy_model_input,
timesteps,
text_embedding).sample

# Get the target for loss depending on the prediction type
if noise_scheduler.config.prediction_type == "epsilon":
@@ -282,6 +337,7 @@ def train_lora(image,
# unwrap_model is used to remove all special modules added when doing distributed training
# so here, there is no need to call unwrap_model
# unet_lora_layers = accelerator.unwrap_model(unet_lora_layers)
unet_lora_layers = unet_lora_state_dict(unet)
LoraLoaderMixin.save_lora_weights(
save_directory=save_lora_path_intermediate,
unet_lora_layers=unet_lora_layers,
@@ -294,6 +350,7 @@ def train_lora(image,
# unwrap_model is used to remove all special modules added when doing distributed training
# so here, there is no need to call unwrap_model
# unet_lora_layers = accelerator.unwrap_model(unet_lora_layers)
unet_lora_layers = unet_lora_state_dict(unet)
LoraLoaderMixin.save_lora_weights(
save_directory=save_lora_path,
unet_lora_layers=unet_lora_layers,
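
Once trained, the state dict produced by `unet_lora_state_dict(unet)` and written with `LoraLoaderMixin.save_lora_weights` can be re-attached to a pipeline through `load_lora_weights`. A hedged sketch of that round trip — the model id and LoRA directory below are placeholders, and DragDiffusion's own inference path may wire this up differently:

```python
# Hedged sketch: re-loading the weights written by LoraLoaderMixin.save_lora_weights above.
# "runwayml/stable-diffusion-v1-5" and "./lora_tmp" are placeholder paths, not taken from the diff.
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# load_lora_weights (LoraLoaderMixin) replaces the older unet.load_attn_procs route;
# point it at the directory that was passed as save_directory.
pipe.load_lora_weights("./lora_tmp")
image = pipe("a photo of a mountain lake", num_inference_steps=30).images[0]
```
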
20 changes: 15 additions & 5 deletions utils/ui_utils.py
100644 → 100755
@@ -32,6 +32,7 @@
import torch.nn.functional as F

from diffusers import DDIMScheduler, AutoencoderKL, DPMSolverMultistepScheduler
from diffusers.models.embeddings import ImageProjection
from drag_pipeline import DragPipeline

from torchvision.utils import save_image
@@ -268,7 +269,7 @@ def run_drag(source_image,
# the latent code resolution is too small, only 64*64
invert_code = model.invert(source_image,
prompt,
text_embeddings=text_embeddings,
encoder_hidden_states=text_embeddings,
guidance_scale=args.guidance_scale,
num_inference_steps=args.n_inference_step,
num_actual_inference_steps=args.n_actual_inference_step)
@@ -282,12 +283,21 @@ def run_drag(source_image,
t = model.scheduler.timesteps[args.n_inference_step - args.n_actual_inference_step]

# feature shape: [1280,16,16], [1280,32,32], [640,64,64], [320,64,64]
# update according to the given supervision
# convert dtype to float for optimization
init_code = init_code.float()
text_embeddings = text_embeddings.float()
model.unet = model.unet.float()
updated_init_code = drag_diffusion_update(model, init_code,
text_embeddings, t, handle_points, target_points, mask, args)

updated_init_code = drag_diffusion_update(
model,
init_code,
text_embeddings,
t,
handle_points,
target_points,
mask,
args)

updated_init_code = updated_init_code.half()
text_embeddings = text_embeddings.half()
model.unet = model.unet.half()
@@ -309,7 +319,7 @@ def run_drag(source_image,
# inference the synthesized image
gen_image = model(
prompt=args.prompt,
text_embeddings=torch.cat([text_embeddings, text_embeddings], dim=0),
encoder_hidden_states=torch.cat([text_embeddings]*2, dim=0),
batch_size=2,
latents=torch.cat([init_code_orig, updated_init_code], dim=0),
guidance_scale=args.guidance_scale,
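
The `float()`/`half()` round trip above follows a simple pattern: cast the latent, the text embeddings, and the UNet to fp32 so the drag optimization has stable gradients, run the UNet forward under fp16 autocast, then cast everything back to fp16 for DDIM denoising. An illustrative stand-alone sketch of that pattern, with toy tensors and a stand-in module rather than DragPipeline itself:

```python
# Illustrative sketch of the fp32-optimize / fp16-sample round trip
# (toy stand-ins, not DragPipeline; assumes a CUDA device is available).
import torch
import torch.nn as nn

assert torch.cuda.is_available()
unet = nn.Linear(64, 64).cuda().half()              # stand-in for model.unet
init_code = torch.randn(1, 64, device="cuda").half()

# 1) optimize in fp32 for numerically stable gradient steps
unet = unet.float()
init_code = init_code.float().requires_grad_(True)
with torch.autocast(device_type="cuda", dtype=torch.float16):
    loss = unet(init_code).pow(2).mean()            # fp16 compute inside autocast
loss.backward()

# 2) cast back to fp16 before handing the latent to the DDIM sampler
init_code = init_code.detach().half()
unet = unet.half()
```
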
