
Commit

fix LTXV support
chengzeyi committed Jan 9, 2025
1 parent 27b8ecd commit 3b3eac4
Showing 5 changed files with 744 additions and 4 deletions.
14 changes: 12 additions & 2 deletions README.md
@@ -29,10 +29,14 @@ git clone https://github.com/chengzeyi/Comfy-WaveSpeed.git

# Usage

## Demo Workflow
## Demo Workflows

You can find demo workflows in the `workflows` folder.

[FLUX.1-dev with First Block Cache and Compilation](./workflows/flux.json)

[LTXV with First Block Cache and Compilation](./workflows/ltxv.json)

## Dynamic Caching ([First Block Cache](https://github.com/chengzeyi/ParaAttention?tab=readme-ov-file#first-block-cache-our-dynamic-caching))

Inspired by TeaCache and other denoising caching algorithms, we introduce [First Block Cache (FBCache)](https://github.com/chengzeyi/ParaAttention?tab=readme-ov-file#first-block-cache-our-dynamic-caching), which uses the residual output of the first transformer block as the cache indicator.
@@ -42,7 +46,13 @@ This can significantly reduce the computation cost of the model, achieving a spe
To use First Block Cache, simply add the `wavespeed->Apply First Block Cache` node to your workflow after your `Load Diffusion Model` node and adjust the `residual_diff_threshold` value to suit your model, for example `0.07` for `flux-dev.safetensors` with `fp8_e4m3fn_fast` and 28 steps.
You can expect a speedup of 1.5x to 3.0x with acceptable accuracy loss.

It supports many models like `FLUX`, `LTXV` and `HunyuanVideo (native)`. Feel free to try it out and let us know if you have any issues!
Here are some configurations you can try for different models:

| Model | Steps | `residual_diff_threshold` |
| --- | --- | --- |
| `flux-dev.safetensors` with `fp8_e4m3fn_fast` | 28 | 0.07 |
| `ltx-video-2b-v0.9.1.safetensors` | 30 | 0.051 |

It supports many models like `FLUX`, `LTXV (native)` and `HunyuanVideo (native)`. Feel free to try it out and let us know if you have any issues!

See [Apply First Block Cache on FLUX.1-dev](https://github.com/chengzeyi/ParaAttention/blob/main/doc/fastest_flux.md#apply-first-block-cache-on-flux1-dev) for more information and a detailed comparison of quality and speed.
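
To make the mechanism above concrete, here is a minimal, hypothetical sketch of the residual-diff check. It is not the WaveSpeed implementation; the class name, the relative-difference metric, and the exact caching policy are illustrative assumptions:

```python
import torch

class FirstBlockCacheSketch:
    """Toy illustration of First Block Cache: run only the first transformer
    block, compare its residual with the residual from the last fully computed
    step, and reuse the cached output of the remaining blocks when similar."""

    def __init__(self, blocks, residual_diff_threshold=0.07):
        self.blocks = blocks  # callables: hidden_states -> hidden_states
        self.residual_diff_threshold = residual_diff_threshold
        self.prev_first_block_residual = None  # cache indicator
        self.cached_remaining_residual = None  # cached delta of blocks[1:]

    def __call__(self, hidden_states):
        original = hidden_states.clone()  # keep pre-block values for the residual
        hidden_states = self.blocks[0](hidden_states)
        first_block_residual = hidden_states - original

        if self.prev_first_block_residual is not None and self.cached_remaining_residual is not None:
            # Relative difference between this step's and the cached first-block residual.
            diff = (first_block_residual - self.prev_first_block_residual).abs().mean()
            rel_diff = diff / self.prev_first_block_residual.abs().mean().clamp_min(1e-6)
            if rel_diff < self.residual_diff_threshold:
                # Close enough to the last fully computed step: skip blocks[1:]
                # and reuse the cached residual of the remaining blocks.
                return hidden_states + self.cached_remaining_residual

        # Otherwise run the remaining blocks and refresh the cache.
        before_remaining = hidden_states.clone()
        for block in self.blocks[1:]:
            hidden_states = block(hidden_states)
        self.prev_first_block_residual = first_block_residual
        self.cached_remaining_residual = hidden_states - before_remaining
        return hidden_states
```

With a threshold of `0`, nothing is ever reused; raising `residual_diff_threshold` skips more steps and trades accuracy for speed, which is why the table above suggests different values per model.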

2 changes: 1 addition & 1 deletion fbcache_nodes.py
@@ -2,7 +2,6 @@
import unittest
import torch

from . import utils
from . import first_block_cache


@@ -56,6 +55,7 @@ def patch(
residual_diff_threshold=residual_diff_threshold,
cat_hidden_states_first=diffusion_model.__class__.__name__ == "HunyuanVideo",
return_hidden_states_only=diffusion_model.__class__.__name__ == "LTXVModel",
clone_original_hidden_states=diffusion_model.__class__.__name__ == "LTXVModel",
)
])
dummy_single_transformer_blocks = torch.nn.ModuleList()
7 changes: 7 additions & 0 deletions first_block_cache.py
@@ -125,6 +125,7 @@ def __init__(
return_hidden_states_first=True,
cat_hidden_states_first=False,
return_hidden_states_only=False,
clone_original_hidden_states=False,
):
super().__init__()
self.transformer_blocks = transformer_blocks
@@ -133,6 +134,7 @@ def __init__(
self.return_hidden_states_first = return_hidden_states_first
self.cat_hidden_states_first = cat_hidden_states_first
self.return_hidden_states_only = return_hidden_states_only
self.clone_original_hidden_states = clone_original_hidden_states

def forward(self, img, txt=None, *args, context=None, **kwargs):
if context is not None:
@@ -165,6 +167,8 @@ def forward(self, img, txt=None, *args, context=None, **kwargs):
(encoder_hidden_states, hidden_states))

original_hidden_states = hidden_states
if self.clone_original_hidden_states:
original_hidden_states = original_hidden_states.clone()
first_transformer_block = self.transformer_blocks[0]
hidden_states = first_transformer_block(
hidden_states, encoder_hidden_states, *args, **kwargs)
@@ -214,6 +218,9 @@ def call_remaining_transformer_blocks(self, hidden_states,
**kwargs):
original_hidden_states = hidden_states
original_encoder_hidden_states = encoder_hidden_states
if self.clone_original_hidden_states:
original_hidden_states = original_hidden_states.clone()
original_encoder_hidden_states = original_encoder_hidden_states.clone()
for block in self.transformer_blocks[1:]:
hidden_states = block(
hidden_states, encoder_hidden_states, *args, **kwargs)
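
A note on the new `clone_original_hidden_states` switch: since it is enabled only for `LTXVModel`, the LTXV blocks presumably modify their input tensor in place, so a plain alias of the pre-block hidden states would be silently overwritten and the residual computed from it would be wrong. A toy PyTorch illustration of that aliasing pitfall (hypothetical block, not the LTXV code):

```python
import torch

def in_place_block(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical transformer block that updates its input in place,
    # as the LTXV blocks appear to do.
    x.add_(1.0)
    return x

# Without a clone, "original" is just another name for the same storage,
# so the residual degenerates to zero.
hidden_states = torch.zeros(4)
original = hidden_states
out = in_place_block(hidden_states)
print(out - original)  # tensor([0., 0., 0., 0.]) -- wrong residual

# With clone_original_hidden_states, the pre-block values survive.
hidden_states = torch.zeros(4)
original = hidden_states.clone()
out = in_place_block(hidden_states)
print(out - original)  # tensor([1., 1., 1., 1.]) -- correct residual
```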
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,7 +1,7 @@
[project]
name = "wavespeed"
description = "The all in one inference optimization solution for ComfyUI, universal, flexible, and fast."
version = "1.0.10"
version = "1.0.11"
license = {file = "LICENSE"}

[project.urls]
