support flux controlnet and fix no cache context
chengzeyi committed Jan 13, 2025
1 parent 57b3d61 commit 1ba6b5f
Showing 5 changed files with 1,131 additions and 13 deletions.
14 changes: 7 additions & 7 deletions README.md
@@ -33,13 +33,13 @@ git clone https://github.com/chengzeyi/Comfy-WaveSpeed.git

You can find demo workflows in the `workflows` folder.

[FLUX.1-dev with First Block Cache and Compilation](./workflows/flux.json)

[LTXV with First Block Cache and Compilation](./workflows/ltxv.json)

[HunyuanVideo with First Block Cache](./workflows/hunyuan_video.json)

[SDXL with First Block Cache](./workflows/sdxl.json)
| Workflow | Path |
| - | - |
| FLUX.1-dev with First Block Cache and Compilation | [workflows/flux.json](./workflows/flux.json) |
| FLUX.1-dev ControlNet with First Block Cache and Compilation | [workflows/flux_controlnet.json](./workflows/flux_controlnet.json) |
| LTXV with First Block Cache and Compilation | [workflows/ltxv.json](./workflows/ltxv.json) |
| HunyuanVideo with First Block Cache | [workflows/hunyuan_video.json](./workflows/hunyuan_video.json) |
| SDXL with First Block Cache | [workflows/sdxl.json](./workflows/sdxl.json) |

**NOTE**: The `Compile Model+` node requires your environment to meet certain software and hardware requirements; please refer to the [Enhanced `torch.compile`](#enhanced-torchcompile) section for more information.
If you run into problems with the compilation node, you can remove it from the workflow and use only the `Apply First Block Cache` node.
22 changes: 19 additions & 3 deletions fbcache_nodes.py
@@ -121,9 +121,17 @@ def validate_use_cache(use_cached):
model = model.clone()
diffusion_model = model.get_model_object(object_to_patch)

if diffusion_model.__class__.__name__ == "UNetModel":
if diffusion_model.__class__.__name__ in ("UNetModel", "Flux"):

patch__forward = first_block_cache.create_patch_unet_model__forward(
if diffusion_model.__class__.__name__ == "UNetModel":
create_patch_function = first_block_cache.create_patch_unet_model__forward
elif diffusion_model.__class__.__name__ == "Flux":
create_patch_function = first_block_cache.create_patch_flux_forward_orig
else:
raise ValueError(
f"Unsupported model {diffusion_model.__class__.__name__}")

patch_forward = create_patch_function(
diffusion_model,
residual_diff_threshold=residual_diff_threshold,
validate_can_use_cache_function=validate_use_cache,
@@ -144,7 +152,11 @@ def model_unet_function_wrapper(model_function, kwargs):
first_block_cache.set_current_cache_context(
first_block_cache.create_cache_context())

with patch__forward():
if first_block_cache.get_current_cache_context() is None:
first_block_cache.set_current_cache_context(
first_block_cache.create_cache_context())

with patch_forward():
return model_function(input, timestep, **c)
except model_management.InterruptProcessingException as exc:
prev_timestep = None
@@ -226,6 +238,10 @@ def model_unet_function_wrapper(model_function, kwargs):
first_block_cache.set_current_cache_context(
first_block_cache.create_cache_context())

if first_block_cache.get_current_cache_context() is None:
first_block_cache.set_current_cache_context(
first_block_cache.create_cache_context())

with unittest.mock.patch.object(
diffusion_model,
double_blocks_name,
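The "fix no cache context" half of the commit title is the guard added in both wrappers above: previously a cache context was created only when the timestep moved backwards (signalling a new sampling run), so the first invocation of a run could reach the patched forward with no context set. A minimal sketch of the guarded wrapper, reusing the function names from the diff (the surrounding bookkeeping is simplified):

```python
# Sketch of the cache-context guard added in model_unet_function_wrapper.
# Names follow the diff above; the wrapper body is simplified.
def model_unet_function_wrapper(model_function, kwargs):
    input, timestep, c = kwargs["input"], kwargs["timestep"], kwargs["c"]
    # Fallback: if no cache context exists yet for this sampling run,
    # create one so the residual buffers have somewhere to live.
    if first_block_cache.get_current_cache_context() is None:
        first_block_cache.set_current_cache_context(
            first_block_cache.create_cache_context())
    with patch_forward():
        return model_function(input, timestep, **c)
```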
218 changes: 216 additions & 2 deletions first_block_cache.py
@@ -371,6 +371,7 @@ def call_remaining_transformer_blocks(self,
return hidden_states, encoder_hidden_states, hidden_states_residual, encoder_hidden_states_residual


# Based on 90f349f93df3083a507854d7fc7c3e1bb9014e24
def create_patch_unet_model__forward(model,
*,
residual_diff_threshold,
@@ -379,7 +380,7 @@ def create_patch_unet_model__forward(model,

def call_remaining_blocks(self, transformer_options, control,
transformer_patches, hs, h, *args, **kwargs):
original_h = h
original_hidden_states = h

for id, module in enumerate(self.input_blocks):
if id < 2:
@@ -421,7 +422,7 @@ def call_remaining_blocks(self, transformer_options, control,
output_shape = None
h = forward_timestep_embed(module, h, *args, output_shape,
**kwargs)
hidden_states_residual = h - original_h
hidden_states_residual = h - original_hidden_states
return h, hidden_states_residual

def unet_model__forward(self,
@@ -546,3 +547,216 @@ def patch__forward():
yield

return patch__forward
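
Both patch factories return a context manager rather than permanently mutating the model: `unittest.mock.patch.object` swaps the bound method in for the duration of one call and restores the original afterwards, even if sampling is interrupted. A generic sketch of the pattern (the `swap_method` helper is hypothetical, not part of the repository):

```python
import contextlib
import unittest.mock

@contextlib.contextmanager
def swap_method(obj, name, replacement):
    # Temporarily rebind obj.<name> to `replacement`; the original
    # attribute is restored when the with-block exits, even on exceptions.
    with unittest.mock.patch.object(obj, name, replacement):
        yield

# Usage sketch with hypothetical objects:
# new_forward = make_patched_forward(model).__get__(model)  # bind to instance
# with swap_method(model, "forward_orig", new_forward):
#     out = model(x)
```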


# Based on 90f349f93df3083a507854d7fc7c3e1bb9014e24
def create_patch_flux_forward_orig(model,
*,
residual_diff_threshold,
validate_can_use_cache_function=None):
from torch import Tensor
from comfy.ldm.flux.model import timestep_embedding

def call_remaining_blocks(self, blocks_replace, control, img, txt, vec,
pe, attn_mask):
original_hidden_states = img

for i, block in enumerate(self.double_blocks):
if i < 1:
continue
if ("double_block", i) in blocks_replace:

def block_wrap(args):
out = {}
out["img"], out["txt"] = block(
img=args["img"],
txt=args["txt"],
vec=args["vec"],
pe=args["pe"],
attn_mask=args.get("attn_mask"))
return out

out = blocks_replace[("double_block",
i)]({
"img": img,
"txt": txt,
"vec": vec,
"pe": pe,
"attn_mask": attn_mask
}, {
"original_block": block_wrap
})
txt = out["txt"]
img = out["img"]
else:
img, txt = block(img=img,
txt=txt,
vec=vec,
pe=pe,
attn_mask=attn_mask)

if control is not None: # Controlnet
control_i = control.get("input")
if i < len(control_i):
add = control_i[i]
if add is not None:
img += add

img = torch.cat((txt, img), 1)

for i, block in enumerate(self.single_blocks):
if ("single_block", i) in blocks_replace:

def block_wrap(args):
out = {}
out["img"] = block(args["img"],
vec=args["vec"],
pe=args["pe"],
attn_mask=args.get("attn_mask"))
return out

out = blocks_replace[("single_block",
i)]({
"img": img,
"vec": vec,
"pe": pe,
"attn_mask": attn_mask
}, {
"original_block": block_wrap
})
img = out["img"]
else:
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask)

if control is not None: # Controlnet
control_o = control.get("output")
if i < len(control_o):
add = control_o[i]
if add is not None:
img[:, txt.shape[1]:, ...] += add

img = img[:, txt.shape[1]:, ...]

img = img.contiguous()
hidden_states_residual = img - original_hidden_states
return img, hidden_states_residual
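
The new ControlNet support is the pair of `if control is not None` hooks in the loops above: after double block `i`, entry `i` of `control["input"]` is added to the image tokens, and after single block `i`, entry `i` of `control["output"]` is added only to the image part of the concatenated sequence (the first `txt.shape[1]` tokens are text). A condensed sketch of that indexing, with a hypothetical helper name:

```python
import torch

def add_controlnet_residual(img: torch.Tensor, residuals, i: int,
                            txt_len: int = 0) -> torch.Tensor:
    # Hypothetical helper mirroring the two hooks above: add the i-th
    # ControlNet residual after block i, when one was supplied.
    if residuals is None or i >= len(residuals):
        return img
    add = residuals[i]
    if add is not None:
        # In the single-block loop the txt tokens are prepended to img, so
        # only the tokens after txt_len receive the residual; in the
        # double-block loop txt_len is 0 and the whole tensor is offset.
        img[:, txt_len:, ...] += add
    return img
```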

def forward_orig(
self,
img: Tensor,
img_ids: Tensor,
txt: Tensor,
txt_ids: Tensor,
timesteps: Tensor,
y: Tensor,
guidance: Tensor = None,
control=None,
transformer_options={},
attn_mask: Tensor = None,
) -> Tensor:
patches_replace = transformer_options.get("patches_replace", {})
if img.ndim != 3 or txt.ndim != 3:
raise ValueError(
"Input img and txt tensors must have 3 dimensions.")

# running on sequences img
img = self.img_in(img)
vec = self.time_in(timestep_embedding(timesteps, 256).to(img.dtype))
if self.params.guidance_embed:
if guidance is None:
raise ValueError(
"Didn't get guidance strength for guidance distilled model."
)
vec = vec + self.guidance_in(
timestep_embedding(guidance, 256).to(img.dtype))

vec = vec + self.vector_in(y[:, :self.params.vec_in_dim])
txt = self.txt_in(txt)

ids = torch.cat((txt_ids, img_ids), dim=1)
pe = self.pe_embedder(ids)

blocks_replace = patches_replace.get("dit", {})
for i, block in enumerate(self.double_blocks):
if i >= 1:
break
if ("double_block", i) in blocks_replace:

def block_wrap(args):
out = {}
out["img"], out["txt"] = block(
img=args["img"],
txt=args["txt"],
vec=args["vec"],
pe=args["pe"],
attn_mask=args.get("attn_mask"))
return out

out = blocks_replace[("double_block",
i)]({
"img": img,
"txt": txt,
"vec": vec,
"pe": pe,
"attn_mask": attn_mask
}, {
"original_block": block_wrap
})
txt = out["txt"]
img = out["img"]
else:
img, txt = block(img=img,
txt=txt,
vec=vec,
pe=pe,
attn_mask=attn_mask)

if control is not None: # Controlnet
control_i = control.get("input")
if i < len(control_i):
add = control_i[i]
if add is not None:
img += add

if i == 0:
first_hidden_states_residual = img
can_use_cache = get_can_use_cache(
first_hidden_states_residual,
threshold=residual_diff_threshold,
)
if validate_can_use_cache_function is not None:
can_use_cache = validate_can_use_cache_function(can_use_cache)
if not can_use_cache:
set_buffer("first_hidden_states_residual",
first_hidden_states_residual)
del first_hidden_states_residual

torch._dynamo.graph_break()
if can_use_cache:
img = apply_prev_hidden_states_residual(img)
else:
img, hidden_states_residual = call_remaining_blocks(
self,
blocks_replace,
control,
img,
txt,
vec,
pe,
attn_mask,
)
set_buffer("hidden_states_residual", hidden_states_residual)
torch._dynamo.graph_break()

img = self.final_layer(img,
vec) # (N, T, patch_size ** 2 * out_channels)
return img

new_forward_orig = forward_orig.__get__(model)

@contextlib.contextmanager
def patch_forward_orig():
with unittest.mock.patch.object(model, "forward_orig", new_forward_orig):
yield

return patch_forward_orig
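
The caching decision for FLUX mirrors the UNet path: the first double block always runs, its output is compared against the value stored on the previous step, and if the relative change stays below `residual_diff_threshold`, the cached residual of all remaining blocks is applied instead of executing them (the `torch._dynamo.graph_break()` calls keep this data-dependent branch out of compiled graphs). A minimal sketch of such a comparison, assuming a relative mean-absolute-difference metric (the actual `get_can_use_cache` may differ in detail):

```python
from typing import Optional

import torch

def can_use_cache(first_block_out: torch.Tensor,
                  prev_first_block_out: Optional[torch.Tensor],
                  threshold: float) -> bool:
    # Small relative change in the first block's output is taken as a
    # proxy for small change in all later blocks, so their cached
    # residual can be reused for this step.
    if prev_first_block_out is None or threshold <= 0:
        return False
    diff = (first_block_out - prev_first_block_out).abs().mean()
    norm = prev_first_block_out.abs().mean()
    return (diff / norm).item() < threshold
```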
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,7 +1,7 @@
[project]
name = "wavespeed"
description = "The all in one inference optimization solution for ComfyUI, universal, flexible, and fast."
version = "1.1.1"
version = "1.1.2"
license = {file = "LICENSE"}

[project.urls]
workflows/flux_controlnet.json (new workflow file; diff not rendered)
