support SD3.5
chengzeyi committed Jan 13, 2025
1 parent 85f6fdf commit e73a08f
Showing 4 changed files with 685 additions and 17 deletions.
6 changes: 4 additions & 2 deletions README.md
@@ -39,6 +39,7 @@ You can find demo workflows in the `workflows` folder.
| FLUX.1-dev ControlNet with First Block Cache and Compilation | [workflows/flux_controlnet.json](./workflows/flux_controlnet.json)
| LTXV with First Block Cache and Compilation | [workflows/ltxv.json](./workflows/ltxv.json)
| HunyuanVideo with First Block Cache | [workflows/hunyuan_video.json](./workflows/hunyuan_video.json)
| SD3.5 with First Block Cache and Compilation | [workflows/sd3.5.json](./workflows/sd3.5.json)
| SDXL with First Block Cache | [workflows/sdxl.json](./workflows/sdxl.json)

**NOTE**: The `Compile Model+` node requires your computation to meet some software and hardware requirements, please refer to the [Enhanced `torch.compile`](#enhanced-torchcompile) section for more information.
@@ -54,17 +55,18 @@ This can significantly reduce the computation cost of the model, achieving a speedup
To use First Block Cache, add the `wavespeed->Apply First Block Cache` node to your workflow after the `Load Diffusion Model` node and set `residual_diff_threshold` to a value suitable for your model, for example `0.12` for `flux-dev.safetensors` with `fp8_e4m3fn_fast` and 28 steps.
You can expect a speedup of 1.5x to 3.0x with an acceptable loss in accuracy.
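
Conceptually, the node runs only the first transformer block on each step, compares that block's output residual with the residual cached on the previous step, and reuses the cached output of the remaining blocks when the two are close enough. Below is a minimal sketch of that decision, assuming a relative mean-absolute-difference metric; the node's actual similarity test and function names may differ:

```python
import torch


def can_use_cache(first_block_residual: torch.Tensor,
                  prev_first_block_residual: torch.Tensor,
                  residual_diff_threshold: float) -> bool:
    """Decide whether the cached output of the remaining blocks can be reused."""
    # Relative mean absolute difference between the current first-block
    # residual and the one cached on the previous sampling step.
    diff = (first_block_residual - prev_first_block_residual).abs().mean()
    norm = prev_first_block_residual.abs().mean().clamp_min(1e-8)
    return bool(diff / norm < residual_diff_threshold)
```

When the check fails, all blocks run as usual and the residuals are cached again for the next step.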

First Block Cache supports many models, including `FLUX`, `LTXV` (native and non-native), `HunyuanVideo` (native), `SD3.5`, and `SDXL`. Feel free to try it out and let us know if you run into any issues!

Some configurations for different models that you can try:

| Model | Steps | `residual_diff_threshold` |
| - | - | - |
| `flux-dev.safetensors` with `fp8_e4m3fn_fast` | 28 | 0.12 |
| `ltx-video-2b-v0.9.1.safetensors` | 30 | 0.1 |
| `hunyuan_video_720_cfgdistill_fp8_e4m3fn.safetensors` | 20 | 0.1 |
| `sd3.5_large_fp8_scaled.safetensors` | 30 | 0.12 |
| `sd_xl_base_1.0.safetensors` | 25 | 0.2 |

See [Apply First Block Cache on FLUX.1-dev](https://github.com/chengzeyi/ParaAttention/blob/main/doc/fastest_flux.md#apply-first-block-cache-on-flux1-dev) for more information and detailed comparison on quality and speed.

![Usage of First Block Cache](./assets/usage_fbcache.png)
37 changes: 23 additions & 14 deletions first_block_cache.py
@@ -121,9 +121,11 @@ def apply_prev_hidden_states_residual(hidden_states,

encoder_hidden_states_residual = get_buffer(
"encoder_hidden_states_residual")
assert encoder_hidden_states_residual is not None, "encoder_hidden_states_residual must be set before"
encoder_hidden_states = encoder_hidden_states_residual + encoder_hidden_states
encoder_hidden_states = encoder_hidden_states.contiguous()
if encoder_hidden_states_residual is None:
encoder_hidden_states = None
else:
encoder_hidden_states = encoder_hidden_states_residual + encoder_hidden_states
encoder_hidden_states = encoder_hidden_states.contiguous()

return hidden_states, encoder_hidden_states

@@ -294,8 +296,9 @@ def forward(self, *args, **kwargs):
txt_arg_name=txt_arg_name,
**kwargs)
set_buffer("hidden_states_residual", hidden_states_residual)
set_buffer("encoder_hidden_states_residual",
encoder_hidden_states_residual)
if encoder_hidden_states_residual is not None:
set_buffer("encoder_hidden_states_residual",
encoder_hidden_states_residual)
torch._dynamo.graph_break()

if self.return_hidden_states_only:
@@ -359,15 +362,19 @@ def call_remaining_transformer_blocks(self,
dim=1)

hidden_states_shape = hidden_states.shape
encoder_hidden_states_shape = encoder_hidden_states.shape

hidden_states = hidden_states.flatten().contiguous().reshape(
hidden_states_shape)
encoder_hidden_states = encoder_hidden_states.flatten().contiguous(
).reshape(encoder_hidden_states_shape)

if encoder_hidden_states is not None:
encoder_hidden_states_shape = encoder_hidden_states.shape
encoder_hidden_states = encoder_hidden_states.flatten().contiguous(
).reshape(encoder_hidden_states_shape)

hidden_states_residual = hidden_states - original_hidden_states
encoder_hidden_states_residual = encoder_hidden_states - original_encoder_hidden_states
if encoder_hidden_states is None:
encoder_hidden_states_residual = None
else:
encoder_hidden_states_residual = encoder_hidden_states - original_encoder_hidden_states
return hidden_states, encoder_hidden_states, hidden_states_residual, encoder_hidden_states_residual


@@ -557,8 +564,8 @@ def create_patch_flux_forward_orig(model,
from torch import Tensor
from comfy.ldm.flux.model import timestep_embedding

def call_remaining_blocks(self, blocks_replace, control, img, txt, vec,
pe, attn_mask):
def call_remaining_blocks(self, blocks_replace, control, img, txt, vec, pe,
attn_mask):
original_hidden_states = img

for i, block in enumerate(self.double_blocks):
@@ -725,7 +732,8 @@ def block_wrap(args):
threshold=residual_diff_threshold,
)
if validate_can_use_cache_function is not None:
can_use_cache = validate_can_use_cache_function(can_use_cache)
can_use_cache = validate_can_use_cache_function(
can_use_cache)
if not can_use_cache:
set_buffer("first_hidden_states_residual",
first_hidden_states_residual)
@@ -756,7 +764,8 @@

@contextlib.contextmanager
def patch_forward_orig():
with unittest.mock.patch.object(model, "forward_orig", new_forward_orig):
with unittest.mock.patch.object(model, "forward_orig",
new_forward_orig):
yield

return patch_forward_orig
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,7 +1,7 @@
[project]
name = "wavespeed"
description = "The all in one inference optimization solution for ComfyUI, universal, flexible, and fast."
version = "1.1.3"
version = "1.1.4"
license = {file = "LICENSE"}

[project.urls]