Skip to content

Commit

Permalink
fix compability with old version of comfy
Browse files Browse the repository at this point in the history
  • Loading branch information
chengzeyi committed Jan 14, 2025
1 parent 3bf95c0 commit a7caa12
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 11 deletions.
27 changes: 17 additions & 10 deletions first_block_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -568,6 +568,10 @@ def call_remaining_blocks(self, blocks_replace, control, img, txt, vec, pe,
attn_mask, ca_idx, timesteps):
original_hidden_states = img

extra_block_forward_kwargs = {}
if attn_mask is not None:
extra_block_forward_kwargs["attn_mask"] = attn_mask

for i, block in enumerate(self.double_blocks):
if i < 1:
continue
Expand All @@ -580,7 +584,7 @@ def block_wrap(args):
txt=args["txt"],
vec=args["vec"],
pe=args["pe"],
attn_mask=args.get("attn_mask"))
**extra_block_forward_kwargs)
return out

out = blocks_replace[("double_block",
Expand All @@ -589,7 +593,7 @@ def block_wrap(args):
"txt": txt,
"vec": vec,
"pe": pe,
"attn_mask": attn_mask
**extra_block_forward_kwargs
}, {
"original_block": block_wrap
})
Expand All @@ -600,7 +604,7 @@ def block_wrap(args):
txt=txt,
vec=vec,
pe=pe,
attn_mask=attn_mask)
**extra_block_forward_kwargs)

if control is not None: # Controlnet
control_i = control.get("input")
Expand Down Expand Up @@ -630,21 +634,21 @@ def block_wrap(args):
out["img"] = block(args["img"],
vec=args["vec"],
pe=args["pe"],
attn_mask=args.get("attn_mask"))
**extra_block_forward_kwargs)
return out

out = blocks_replace[("single_block",
i)]({
"img": img,
"vec": vec,
"pe": pe,
"attn_mask": attn_mask
**extra_block_forward_kwargs
}, {
"original_block": block_wrap
})
img = out["img"]
else:
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask)
img = block(img, vec=vec, pe=pe, **extra_block_forward_kwargs)

if control is not None: # Controlnet
control_o = control.get("output")
Expand Down Expand Up @@ -709,8 +713,11 @@ def forward_orig(
ids = torch.cat((txt_ids, img_ids), dim=1)
pe = self.pe_embedder(ids)

blocks_replace = patches_replace.get("dit", {})
ca_idx = 0
extra_block_forward_kwargs = {}
if attn_mask is not None:
extra_block_forward_kwargs["attn_mask"] = attn_mask
blocks_replace = patches_replace.get("dit", {})
for i, block in enumerate(self.double_blocks):
if i >= 1:
break
Expand All @@ -723,7 +730,7 @@ def block_wrap(args):
txt=args["txt"],
vec=args["vec"],
pe=args["pe"],
attn_mask=args.get("attn_mask"))
**extra_block_forward_kwargs)
return out

out = blocks_replace[("double_block",
Expand All @@ -732,7 +739,7 @@ def block_wrap(args):
"txt": txt,
"vec": vec,
"pe": pe,
"attn_mask": attn_mask
**extra_block_forward_kwargs
}, {
"original_block": block_wrap
})
Expand All @@ -743,7 +750,7 @@ def block_wrap(args):
txt=txt,
vec=vec,
pe=pe,
attn_mask=attn_mask)
**extra_block_forward_kwargs)

if control is not None: # Controlnet
control_i = control.get("input")
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[project]
name = "wavespeed"
description = "The all in one inference optimization solution for ComfyUI, universal, flexible, and fast."
version = "1.1.5"
version = "1.1.6"
license = {file = "LICENSE"}

[project.urls]
Expand Down

0 comments on commit a7caa12

Please sign in to comment.