Illustrate Dragging Trajectory #75

Open · wants to merge 3 commits into main
drag_bench_evaluation/run_drag_diffusion.py (1 addition, 1 deletion)
@@ -155,7 +155,7 @@ def run_drag(source_image,

# feature shape: [1280,16,16], [1280,32,32], [640,64,64], [320,64,64]
# update according to the given supervision
- updated_init_code = drag_diffusion_update(model, init_code,
+ updated_init_code, opt_seq = drag_diffusion_update(model, init_code,
None, t, handle_points, target_points, mask, args)

# hijack the attention module
drag_ui.py (16 additions, 0 deletions)
@@ -61,6 +61,7 @@
prompt = gr.Textbox(label="Prompt")
lora_path = gr.Textbox(value="./lora_tmp", label="LoRA path")
lora_status_bar = gr.Textbox(label="display LoRA training status")
+ results_path = gr.Textbox(value="./results", label="Folder for results")

# algorithm specific parameters
with gr.Tab("Drag Config"):
@@ -79,6 +80,11 @@
latent_lr = gr.Number(value=0.01, label="latent lr")
start_step = gr.Number(value=0, label="start_step", precision=0, visible=False)
start_layer = gr.Number(value=10, label="start_layer", precision=0, visible=False)
+ save_optimization_seq_rgb = gr.Checkbox(
+     value=False,
+     label="Save Opt Seq",
+     info="Each optimization step's latent is saved as a PNG file. This will take more time."
+ )

with gr.Tab("Base Model Config"):
with gr.Row():
@@ -136,6 +142,7 @@
with gr.Row():
pos_prompt_gen = gr.Textbox(label="Positive Prompt")
neg_prompt_gen = gr.Textbox(label="Negative Prompt")
+ results_path_gen = gr.Textbox(value="./results", label="Folder for results")

with gr.Tab("Generation Config"):
with gr.Row():
@@ -218,6 +225,11 @@
latent_lr_gen = gr.Number(value=0.01, label="latent lr")
start_step_gen = gr.Number(value=0, label="start_step", precision=0, visible=False)
start_layer_gen = gr.Number(value=10, label="start_layer", precision=0, visible=False)
+ save_optimization_seq_rgb_gen = gr.Checkbox(
+     value=False,
+     label="Save Opt Seq",
+     info="Each optimization step's latent is saved as a PNG file. This will take more time."
+ )

# event definition
# event for dragging user-input real image
@@ -265,6 +277,8 @@
lora_path,
start_step,
start_layer,
+ results_path,
+ save_optimization_seq_rgb,
],
[output_image]
)
@@ -343,6 +357,8 @@
b2_gen,
s1_gen,
s2_gen,
+ results_path_gen,
+ save_optimization_seq_rgb_gen,
],
[output_image_gen]
)
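Note on the wiring above: in Gradio, the order of the click handler's inputs list determines which component value is bound to which positional parameter of the callback, so the two new components are appended last to line up with the new save_dir and save_seq parameters of run_drag / run_drag_gen. A minimal, self-contained sketch of that pattern (stub names are hypothetical, not the repo's actual handler):

```python
# Minimal sketch (hypothetical stub, not part of this diff): the last two entries
# of the inputs list feed the callback's save_dir and save_seq parameters.
import gradio as gr

def run_drag_stub(image, prompt, save_dir="./results", save_seq=False):
    # placeholder for the real run_drag; just echoes the new options
    return f"saving to {save_dir}, save_seq={save_seq}"

with gr.Blocks() as demo:
    image = gr.Image()
    prompt = gr.Textbox(label="Prompt")
    results_path = gr.Textbox(value="./results", label="Folder for results")
    save_opt_seq = gr.Checkbox(value=False, label="Save Opt Seq")
    status = gr.Textbox(label="Status")
    run_button = gr.Button("Run")
    # input order matches the positional parameter order of run_drag_stub
    run_button.click(run_drag_stub, [image, prompt, results_path, save_opt_seq], [status])

# demo.launch()
```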
utils/drag_utils.py (6 additions, 2 deletions)
@@ -101,6 +101,7 @@ def drag_diffusion_update(model,
# prepare optimizable init_code and optimizer
init_code.requires_grad_(True)
optimizer = torch.optim.Adam([init_code], lr=args.lr)
+ opt_seq = [init_code.detach().clone()]

# prepare for point tracking and background regularization
handle_points_init = copy.deepcopy(handle_points)
@@ -157,8 +158,9 @@
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()
+ opt_seq.append(init_code.detach().clone())

- return init_code
+ return init_code, opt_seq

def drag_diffusion_update_gen(model,
init_code,
@@ -218,6 +220,7 @@ def drag_diffusion_update_gen(model,

# prepare amp scaler for mixed-precision training
scaler = torch.cuda.amp.GradScaler()
+ opt_seq = [init_code.detach().clone()]
for step_idx in range(args.n_pix_step):
with torch.autocast(device_type='cuda', dtype=torch.float16):
if args.guidance_scale > 1.:
@@ -281,6 +284,7 @@
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()
+ opt_seq.append(init_code.detach().clone())

- return init_code
+ return init_code, opt_seq
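With these changes, both update functions return the latent trajectory alongside the final latent: the list starts with the initial latent and gains one detached clone after each completed optimization step. A minimal sketch of how a caller might inspect the returned sequence (hypothetical usage; it assumes model, init_code, text_embeddings, t, handle_points, target_points, mask and args are prepared as in run_drag):

```python
# Minimal sketch (hypothetical usage, not part of this diff): inspect the
# trajectory returned by the modified drag_diffusion_update.
import torch
from utils.drag_utils import drag_diffusion_update  # import path assumed from the repo layout

updated_init_code, opt_seq = drag_diffusion_update(
    model, init_code, text_embeddings, t, handle_points, target_points, mask, args)

# initial latent plus one snapshot per completed step: at most n_pix_step + 1 entries
assert len(opt_seq) <= args.n_pix_step + 1
# every snapshot has the same shape as the optimizable latent
assert all(latent.shape == init_code.shape for latent in opt_seq)
# the last snapshot should match the latent the function returns
print(torch.equal(opt_seq[-1], updated_init_code))  # expected: True
```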

utils/ui_utils.py (49 additions, 4 deletions)
@@ -185,7 +185,8 @@ def run_drag(source_image,
lora_path,
start_step,
start_layer,
- save_dir="./results"
+ save_dir="./results",
+ save_seq=True,
):
# initialize model
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
@@ -288,7 +289,7 @@
text_embeddings = text_embeddings.float()
model.unet = model.unet.float()

- updated_init_code = drag_diffusion_update(
+ updated_init_code, opt_seq = drag_diffusion_update(
model,
init_code,
text_embeddings,
@@ -345,6 +346,27 @@
save_prefix = datetime.datetime.now().strftime("%Y-%m-%d-%H%M-%S")
save_image(save_result, os.path.join(save_dir, save_prefix + '.png'))

+ if save_seq:
+     os.mkdir(os.path.join(save_dir, save_prefix))
+     # save the full list of latents as a .pt file
+     torch.save(opt_seq, os.path.join(save_dir, save_prefix, 'opt_seq.pt'))
+     # denoise the latents two at a time; when the count is odd, the last latent
+     # is duplicated to fill the batch and its second copy is skipped when saving
+     for i in range(0, len(opt_seq), 2):
+         latents = torch.cat([opt_seq[i].half(), opt_seq[i+1].half() if i+1 < len(opt_seq) else opt_seq[i].half()], dim=0)
+         gen_image_seq = model(
+             args.prompt,
+             encoder_hidden_states=torch.cat([text_embeddings]*2, dim=0),
+             batch_size=2,
+             latents=latents,
+             guidance_scale=args.guidance_scale,
+             num_inference_steps=args.n_inference_step,
+             num_actual_inference_steps=args.n_actual_inference_step
+         )
+         gen_image_seq = F.interpolate(gen_image_seq, (full_h, full_w), mode='bilinear')
+         save_image(gen_image_seq[0].unsqueeze(dim=0), os.path.join(save_dir, save_prefix, f'iter_{i}.png'))
+         if i+1 < len(opt_seq):
+             save_image(gen_image_seq[1].unsqueeze(dim=0), os.path.join(save_dir, save_prefix, f'iter_{i+1}.png'))

out_image = gen_image.cpu().permute(0, 2, 3, 1).numpy()[0]
out_image = (out_image * 255).astype(np.uint8)
return out_image
@@ -464,7 +486,9 @@ def run_drag_gen(
b2,
s1,
s2,
- save_dir="./results"):
+ save_dir="./results",
+ save_seq=True,
+ ):
# initialize model
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = DragPipeline.from_pretrained(model_path, torch_dtype=torch.float16)
@@ -572,7 +596,7 @@
init_code = init_code.to(torch.float32)
text_embeddings = text_embeddings.to(torch.float32)
model.unet = model.unet.to(torch.float32)
- updated_init_code = drag_diffusion_update_gen(model, init_code,
+ updated_init_code, opt_seq = drag_diffusion_update_gen(model, init_code,
text_embeddings, t, handle_points, target_points, mask, args)
updated_init_code = updated_init_code.to(torch.float16)
text_embeddings = text_embeddings.to(torch.float16)
@@ -619,6 +643,27 @@
save_prefix = datetime.datetime.now().strftime("%Y-%m-%d-%H%M-%S")
save_image(save_result, os.path.join(save_dir, save_prefix + '.png'))

+ if save_seq:
+     os.mkdir(os.path.join(save_dir, save_prefix))
+     # save the full list of latents as a .pt file
+     torch.save(opt_seq, os.path.join(save_dir, save_prefix, 'opt_seq.pt'))
+     # denoise the latents two at a time; when the count is odd, the last latent
+     # is duplicated to fill the batch and its second copy is skipped when saving
+     for i in range(0, len(opt_seq), 2):
+         latents = torch.cat([opt_seq[i].half(), opt_seq[i+1].half() if i+1 < len(opt_seq) else opt_seq[i].half()], dim=0)
+         gen_image_seq = model(
+             args.prompt,
+             encoder_hidden_states=torch.cat([text_embeddings]*2, dim=0),
+             batch_size=2,
+             latents=latents,
+             guidance_scale=args.guidance_scale,
+             num_inference_steps=args.n_inference_step,
+             num_actual_inference_steps=args.n_actual_inference_step
+         )
+         gen_image_seq = F.interpolate(gen_image_seq, (full_h, full_w), mode='bilinear')
+         save_image(gen_image_seq[0].unsqueeze(dim=0), os.path.join(save_dir, save_prefix, f'iter_{i}.png'))
+         if i+1 < len(opt_seq):
+             save_image(gen_image_seq[1].unsqueeze(dim=0), os.path.join(save_dir, save_prefix, f'iter_{i+1}.png'))

out_image = gen_image.cpu().permute(0, 2, 3, 1).numpy()[0]
out_image = (out_image * 255).astype(np.uint8)
return out_image
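As a usage note, the per-iteration PNGs written above can be stitched into an animation to actually illustrate the dragging trajectory. A minimal sketch (not part of this PR; it assumes Pillow is installed and that run_drag or run_drag_gen wrote iter_*.png into the timestamped subfolder of the results directory):

```python
# Minimal sketch (hypothetical helper, not part of this PR): turn the saved
# iter_*.png frames into a GIF of the dragging trajectory.
import glob
import os
from PIL import Image

def frames_to_gif(seq_dir, out_path="trajectory.gif", duration_ms=200):
    # sort frames by the iteration index encoded in the file name (iter_<i>.png)
    paths = sorted(
        glob.glob(os.path.join(seq_dir, "iter_*.png")),
        key=lambda p: int(os.path.splitext(os.path.basename(p))[0].split("_")[1]),
    )
    if not paths:
        raise FileNotFoundError(f"no iter_*.png frames found in {seq_dir}")
    frames = [Image.open(p).convert("RGB") for p in paths]
    # write an animated GIF that loops forever
    frames[0].save(out_path, save_all=True, append_images=frames[1:],
                   duration=duration_ms, loop=0)

# example: frames_to_gif("./results/<save_prefix>")  # <save_prefix> is the timestamped folder
```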