
Commit

kijai committed Jun 28, 2024
2 parents d041789 + a0ea689 commit 989e506
Showing 1 changed file with 12 additions and 18 deletions.
nodes.py
@@ -786,12 +786,14 @@ class DynamiCrafterBatchInterpolation:
     def INPUT_TYPES(s):
         return {"required": {
             "model": ("DCMODEL",),
+            "clip_vision": ("CLIP_VISION",),
+            "positive": ("CONDITIONING",),
+            "negative": ("CONDITIONING",),
             "images": ("IMAGE",),
             "steps": ("INT", {"default": 50, "min": 1, "max": 200, "step": 1}),
             "cfg": ("FLOAT", {"default": 7.0, "min": 0.0, "max": 20.0, "step": 0.01}),
             "eta": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 20.0, "step": 0.01}),
             "frames": ("INT", {"default": 16, "min": 1, "max": 100, "step": 1}),
-            "prompt": ("STRING", {"multiline": True, "default": "",}),
             "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
             "fs": ("INT", {"default": 10, "min": 2, "max": 100, "step": 1}),
             "keep_model_loaded": ("BOOLEAN", {"default": True}),
@@ -813,7 +815,8 @@ def INPUT_TYPES(s):
     FUNCTION = "process"
     CATEGORY = "DynamiCrafterWrapper"
 
-    def process(self, model, images, prompt, cfg, steps, eta, seed, fs, keep_model_loaded, frames, vae_dtype, cut_near_keyframes):
+    def process(self, model, images, clip_vision, positive, negative, cfg, steps, eta, seed, fs, keep_model_loaded,
+                frames, vae_dtype, cut_near_keyframes):
         assert images.shape[0] > 1, "DynamiCrafterBatchInterpolation needs at least 2 images"
         device = mm.get_torch_device()
         mm.unload_all_models()
@@ -847,7 +850,6 @@ def process(self, model, images, prompt, cfg, steps, eta, seed, fs, keep_model_l
         if orig_H % 64 != 0 or orig_W % 64 != 0:
             images = F.interpolate(images, size=(H, W), mode="bicubic")
 
-        split_prompt = split_and_trim(prompt)
         out = []
         autocast_condition = (dtype != torch.float32) and not comfy.model_management.is_device_mps(device)
         with torch.autocast(comfy.model_management.get_autocast_device(device), dtype=dtype) if autocast_condition else nullcontext():
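With prompts now arriving pre-encoded as CONDITIONING, the per-keyframe prompt splitting is dropped. For context, the removed split_and_trim helper presumably turned the multiline prompt into one trimmed prompt per segment; a plausible reconstruction (the repo's actual helper may differ):

```python
def split_and_trim(prompt: str) -> list[str]:
    # Plausible reconstruction of the removed helper: one prompt per
    # non-empty line of the multiline input, whitespace-trimmed.
    return [line.strip() for line in prompt.splitlines() if line.strip()]
```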
@@ -869,18 +871,11 @@ def process(self, model, images, prompt, cfg, steps, eta, seed, fs, keep_model_l
 
                 self.model.first_stage_model.to('cpu')
 
-                self.model.cond_stage_model.to(device)
-                self.model.embedder.to(device)
                 self.model.image_proj_model.to(device)
 
-                try:
-                    text_emb = self.model.get_learned_conditioning([split_prompt[i]])
-                    print("Prompt: ", split_prompt[i])
-                except:
-                    text_emb = self.model.get_learned_conditioning([split_prompt[0]])
-                    print("Prompt: ", split_prompt[0])
+                text_emb = positive[0][0].to(device)
+                cond_images = clip_vision.encode_image(image.permute(0, 2, 3, 1))['last_hidden_state'].to(device)
 
-                cond_images = self.model.embedder(image)
                 img_emb = self.model.image_proj_model(cond_images)
                 imtext_cond = torch.cat([text_emb, img_emb], dim=1)
 
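The replacement leans on two ComfyUI conventions: a CONDITIONING value is a list of [tensor, extras_dict] pairs, so positive[0][0] is the text-embedding tensor, and clip_vision.encode_image takes images as [B, H, W, C], hence the permute from the model's [B, C, H, W] layout. A sketch of the new conditioning path under those assumptions:

```python
import torch

def build_imtext_cond(positive, clip_vision, image, image_proj_model, device):
    # positive comes from a text-encode node: a list of
    # [cond_tensor, extras_dict] pairs, so [0][0] is the embedding
    # tensor of shape [batch, tokens, dim].
    text_emb = positive[0][0].to(device)
    # image is [B, C, H, W] inside this node; ComfyUI's CLIP vision
    # wrapper expects [B, H, W, C] floats, hence the permute.
    vision_out = clip_vision.encode_image(image.permute(0, 2, 3, 1))
    img_tokens = vision_out['last_hidden_state'].to(device)
    # Project the vision tokens into the cross-attention space and
    # concatenate with the text embedding along the token axis.
    img_emb = image_proj_model(img_tokens)
    return torch.cat([text_emb, img_emb], dim=1)
```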
@@ -895,13 +890,14 @@ def process(self, model, images, prompt, cfg, steps, eta, seed, fs, keep_model_l
                 guidance_rescale = 0.7
 
                 ## construct unconditional guidance
-                if cfg != 1.0:
-                    uc_emb = self.model.get_learned_conditioning([""])
+                if cfg != 1.0:
+                    uc_emb = negative[0][0].to(device)
                     ## process image embedding token
                     if hasattr(self.model, 'embedder'):
-                        uc_img = torch.zeros(noise_shape[0],3,224,224).to(self.model.device)
+                        uc_img = torch.rand(noise_shape[0], 3, 224, 224).to(self.model.device)
                         ## img: b c h w >> b l c
-                        uc_img = self.model.embedder(uc_img)
+                        uc_img = clip_vision.encode_image(uc_img.permute(0, 2, 3, 1))['last_hidden_state'].to(
+                            self.model.device)
                         uc_img = self.model.image_proj_model(uc_img)
                         uc_emb = torch.cat([uc_emb, uc_img], dim=1)
                     if isinstance(cond, dict):
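The cfg != 1.0 guard and guidance_rescale = 0.7 follow standard classifier-free guidance. For reference, the textbook combination of the conditional and unconditional branches, including the rescale trick from "Common Diffusion Noise Schedules and Sample Steps are Flawed" (Lin et al., 2023), looks roughly like this (a generic sketch, not code from this repo):

```python
import torch

def apply_cfg(noise_cond, noise_uncond, cfg_scale, guidance_rescale=0.7):
    # Classifier-free guidance: extrapolate away from the unconditional branch.
    noise_cfg = noise_uncond + cfg_scale * (noise_cond - noise_uncond)
    if guidance_rescale > 0.0:
        # Rescale trick: match the conditional prediction's per-sample std
        # to curb over-saturation at high cfg_scale, then blend.
        dims = list(range(1, noise_cfg.ndim))
        std_cond = noise_cond.std(dim=dims, keepdim=True)
        std_cfg = noise_cfg.std(dim=dims, keepdim=True)
        rescaled = noise_cfg * (std_cond / std_cfg)
        noise_cfg = guidance_rescale * rescaled + (1.0 - guidance_rescale) * noise_cfg
    return noise_cfg
```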
@@ -911,8 +907,6 @@ def process(self, model, images, prompt, cfg, steps, eta, seed, fs, keep_model_l
                         uc = uc_emb
                 else:
                     uc = None
 
-                self.model.cond_stage_model.to('cpu')
-                self.model.embedder.to('cpu')
                 self.model.image_proj_model.to('cpu')
 
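The surviving .to('cpu') call is part of the node's VRAM-offload pattern: each submodule visits the GPU only while it is needed, then returns to system RAM. A minimal sketch of that pattern:

```python
import torch

def run_offloaded(module, x, device):
    # Move the submodule to the GPU just for its forward pass, then
    # send it back to system RAM so the next stage has VRAM to use.
    module.to(device)
    try:
        with torch.no_grad():
            out = module(x)
    finally:
        module.to('cpu')
    return out
```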
