support non-native LTXV model for first block cache
chengzeyi committed Jan 10, 2025
1 parent 0a3c320 commit f5dfdb6
Showing 4 changed files with 116 additions and 51 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -53,7 +53,7 @@ Some configurations for different models that you can try:
| `flux-dev.safetensors` with `fp8_e4m3fn_fast` | 28 | 0.12 |
| `ltx-video-2b-v0.9.1.safetensors` | 30 | 0.1 |

It supports many models like `FLUX`, `LTXV (native)` and `HunyuanVideo (native)`, feel free to try it out and let us know if you have any issues!
It supports many models like `FLUX`, `LTXV (native and non-native)` and `HunyuanVideo (native)`, feel free to try it out and let us know if you have any issues!

See [Apply First Block Cache on FLUX.1-dev](https://github.com/chengzeyi/ParaAttention/blob/main/doc/fastest_flux.md#apply-first-block-cache-on-flux1-dev) for more information and detailed comparison on quality and speed.
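The first block cache technique referenced here works by running only the first transformer block each step, measuring how much its residual changed since the previous step, and reusing the cached contribution of the remaining blocks when the relative change stays below `residual_diff_threshold`. The sketch below is a minimal illustration of that rule; `FirstBlockCache` and `blocks` are illustrative names, not the repository's API.

```python
import torch


class FirstBlockCache:
    """Minimal sketch of the first-block-cache rule (illustrative only)."""

    def __init__(self, blocks, residual_diff_threshold=0.1):
        self.blocks = blocks  # list of callables: hidden_states -> hidden_states
        self.residual_diff_threshold = residual_diff_threshold
        self.prev_first_residual = None   # first block residual from the previous step
        self.cached_tail_residual = None  # residual contributed by the remaining blocks

    def __call__(self, hidden_states):
        original = hidden_states
        hidden_states = self.blocks[0](hidden_states)
        first_residual = hidden_states - original

        use_cache = False
        if self.prev_first_residual is not None and self.cached_tail_residual is not None:
            # Relative change of the first block's residual between steps.
            diff = (first_residual - self.prev_first_residual).abs().mean()
            norm = self.prev_first_residual.abs().mean()
            use_cache = bool(diff / (norm + 1e-8) < self.residual_diff_threshold)

        self.prev_first_residual = first_residual
        if use_cache:
            # Skip the remaining blocks and replay their cached contribution.
            return hidden_states + self.cached_tail_residual

        before_tail = hidden_states
        for block in self.blocks[1:]:
            hidden_states = block(hidden_states)
        self.cached_tail_residual = hidden_states - before_tail
        return hidden_states


# Toy usage with dummy blocks.
blocks = [lambda x: x + 0.1 * torch.tanh(x) for _ in range(4)]
cached_model = FirstBlockCache(blocks, residual_diff_threshold=0.1)
out = cached_model(torch.randn(1, 16))
```

Lower thresholds are stricter and cache less, and a threshold of 0 disables the effect entirely, which matches the node tooltip in the diff below.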

105 changes: 72 additions & 33 deletions fbcache_nodes.py
@@ -23,35 +23,52 @@ def INPUT_TYPES(s):
"residual_diff_threshold": (
"FLOAT",
{
"default": 0.0,
"min": 0.0,
"max": 1.0,
"step": 0.001,
"tooltip": "Controls the tolerance for caching with lower values being more strict. Setting this to 0 disables the FBCache effect.",
"default":
0.0,
"min":
0.0,
"max":
1.0,
"step":
0.001,
"tooltip":
"Controls the tolerance for caching with lower values being more strict. Setting this to 0 disables the FBCache effect.",
},
),
"start": (
"FLOAT", {
"default": 0.0,
"step": 0.01,
"max": 1.0,
"min": 0.0,
"tooltip": "Start time as a percentage of sampling where the FBCache effect can apply. Example: 0.0 would signify 0% (the beginning of sampling), 0.5 would signify 50%.",
"FLOAT",
{
"default":
0.0,
"step":
0.01,
"max":
1.0,
"min":
0.0,
"tooltip":
"Start time as a percentage of sampling where the FBCache effect can apply. Example: 0.0 would signify 0% (the beginning of sampling), 0.5 would signify 50%.",
},
),
"end": (
"FLOAT", {
"default": 1.0,
"step": 0.01,
"max": 1.0,
"min": 0.0,
"tooltip": "End time as a percentage of sampling where the FBCache effect can apply. Example: 1.0 would signify 100% (the end of sampling), 0.5 would signify 50%.",
}
),
"end": ("FLOAT", {
"default":
1.0,
"step":
0.01,
"max":
1.0,
"min":
0.0,
"tooltip":
"End time as a percentage of sampling where the FBCache effect can apply. Example: 1.0 would signify 100% (the end of sampling), 0.5 would signify 50%.",
}),
"max_consecutive_cache_hits": (
"INT", {
"default": -1,
"tooltip": "Allows limiting how many cached results can be used in a row. For example, setting this to 1 will mean there will be at least one full model call after each cached result. Set to 0 or lower to disable cache limiting.",
"INT",
{
"default":
-1,
"tooltip":
"Allows limiting how many cached results can be used in a row. For example, setting this to 1 will mean there will be at least one full model call after each cached result. Set to 0 or lower to disable cache limiting.",
},
),
}
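Note the inverted comparison used later in this file: `model_sampling.percent_to_sigma` maps the `start`/`end` percentages to sigmas, and because sigma decreases as sampling progresses, caching is allowed while `end_sigma <= current_sigma <= start_sigma`, with `max_consecutive_cache_hits` capping how many cached steps may occur in a row. Below is a hedged standalone restatement of that gating, with `percent_to_sigma` passed in rather than taken from a ComfyUI model object.

```python
def make_cache_gate(percent_to_sigma, start=0.0, end=1.0,
                    max_consecutive_cache_hits=-1):
    # percent_to_sigma maps "fraction of sampling done" to a sigma value;
    # sigma is large early in sampling and small at the end, so the
    # percentage window flips into an inverted sigma window.
    start_sigma = float(percent_to_sigma(start))
    end_sigma = float(percent_to_sigma(end))
    hits = 0

    def allow_cache(current_sigma, first_block_says_cache):
        nonlocal hits
        ok = first_block_says_cache and end_sigma <= current_sigma <= start_sigma
        if max_consecutive_cache_hits > 0:
            ok = ok and hits < max_consecutive_cache_hits
        hits = hits + 1 if ok else 0
        return ok

    return allow_cache


# Toy sigma schedule for illustration only.
gate = make_cache_gate(lambda pct: 14.6 * (1.0 - pct), start=0.1, end=0.9,
                       max_consecutive_cache_hits=2)
```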
@@ -72,14 +89,19 @@ def patch(
end=1.0,
):
if residual_diff_threshold <= 0:
return (model,)
return (model, )
prev_timestep = None
current_timestep = None
consecutive_cache_hits = 0

model = model.clone()
diffusion_model = model.get_model_object(object_to_patch)

is_non_native_ltxv = False
if diffusion_model.__class__.__name__ == "LTXVTransformer3D":
is_non_native_ltxv = True
diffusion_model = diffusion_model.transformer

double_blocks_name = None
single_blocks_name = None
if hasattr(diffusion_model, "transformer_blocks"):
@@ -89,25 +111,41 @@ def patch(
elif hasattr(diffusion_model, "joint_blocks"):
double_blocks_name = "joint_blocks"
else:
raise ValueError("No transformer blocks found")
raise ValueError("No double blocks found")

if hasattr(diffusion_model, "single_blocks"):
single_blocks_name = "single_blocks"

if is_non_native_ltxv:
original_create_skip_layer_mask = getattr(
diffusion_model, "create_skip_layer_mask", None)
if original_create_skip_layer_mask is not None:
original_double_blocks = getattr(diffusion_model,
double_blocks_name)

def new_create_skip_layer_mask(self, *args, **kwargs):
with unittest.mock.patch.object(self, double_blocks_name,
original_double_blocks):
return original_create_skip_layer_mask(*args, **kwargs)
return original_create_skip_layer_mask(*args, **kwargs)

diffusion_model.create_skip_layer_mask = new_create_skip_layer_mask.__get__(
diffusion_model)

using_validation = max_consecutive_cache_hits > 0 or start > 0 or end < 1
if using_validation:
model_sampling = model.get_model_object("model_sampling")
start_sigma, end_sigma = (
float(model_sampling.percent_to_sigma(pct))
for pct in (start, end)
)
start_sigma, end_sigma = (float(
model_sampling.percent_to_sigma(pct)) for pct in (start, end))
del model_sampling

@torch.compiler.disable()
def validate_use_cache(use_cached):
nonlocal consecutive_cache_hits
use_cached = use_cached and end_sigma <= current_timestep <= start_sigma
use_cached = use_cached and (max_consecutive_cache_hits < 1 or consecutive_cache_hits < max_consecutive_cache_hits)
use_cached = use_cached and (max_consecutive_cache_hits < 1
or consecutive_cache_hits
< max_consecutive_cache_hits)
consecutive_cache_hits = consecutive_cache_hits + 1 if use_cached else 0
return use_cached
else:
@@ -123,8 +161,8 @@ def validate_use_cache(use_cached):
validate_can_use_cache_function=validate_use_cache,
cat_hidden_states_first=diffusion_model.__class__.__name__ ==
"HunyuanVideo",
return_hidden_states_only=diffusion_model.__class__.__name__ ==
"LTXVModel",
return_hidden_states_only=diffusion_model.__class__.__name__
== "LTXVModel" or is_non_native_ltxv,
clone_original_hidden_states=diffusion_model.__class__.__name__
== "LTXVModel",
return_hidden_states_first=diffusion_model.__class__.__name__
@@ -158,7 +196,8 @@ def model_unet_function_wrapper(model_function, kwargs):
diffusion_model,
single_blocks_name,
dummy_single_transformer_blocks,
) if single_blocks_name is not None else contextlib.nullcontext():
) if single_blocks_name is not None else contextlib.nullcontext(
):
return model_function(input, timestep, **c)
except model_management.InterruptProcessingException as exc:
prev_timestep = None
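The non-native path added above unwraps `LTXVTransformer3D` to its inner `.transformer` and, because the wrapper's `create_skip_layer_mask` helper needs to see the original block list even while the blocks are swapped for the cached version, rebinds that helper through `unittest.mock.patch.object`. The sketch below shows the same pattern in isolation; `Wrapper`, `Inner`, and `build_mask` are placeholder names, not the repository's classes.

```python
import unittest.mock


class Inner:
    def __init__(self):
        self.transformer_blocks = ["block0", "block1", "block2"]

    def build_mask(self):
        # A helper that must see the *original* block list to size its output.
        return [1] * len(self.transformer_blocks)


class Wrapper:
    def __init__(self):
        self.transformer = Inner()


model = Wrapper()
inner = model.transformer  # unwrap, as done for LTXVTransformer3D

original_blocks = inner.transformer_blocks
original_build_mask = inner.build_mask  # bound method


def patched_build_mask(self, *args, **kwargs):
    # Temporarily restore the original blocks so the helper behaves as before,
    # even if transformer_blocks has been swapped for a cached stand-in.
    with unittest.mock.patch.object(self, "transformer_blocks", original_blocks):
        return original_build_mask(*args, **kwargs)


# Bind the replacement as a method on this one instance.
inner.build_mask = patched_build_mask.__get__(inner)

inner.transformer_blocks = ["cached_first_block"]  # e.g. the cached stand-in
assert len(inner.build_mask()) == 3  # still sized against the original blocks
```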
58 changes: 42 additions & 16 deletions first_block_cache.py
@@ -143,38 +143,54 @@ def __init__(
self.clone_original_hidden_states = clone_original_hidden_states

def forward(self, *args, **kwargs):
img_arg_name = None
if "img" in kwargs:
img_arg_name = "img"
elif "hidden_states" in kwargs:
img_arg_name = "hidden_states"
txt_arg_name = None
if "txt" in kwargs:
txt_arg_name = "txt"
elif "context" in kwargs:
txt_arg_name = "context"
elif "encoder_hidden_states" in kwargs:
txt_arg_name = "encoder_hidden_states"
if self.accept_hidden_states_first:
if args:
img = args[0]
args = args[1:]
else:
img = kwargs.pop("img")
img = kwargs.pop(img_arg_name)
if args:
txt = args[0]
args = args[1:]
else:
txt = kwargs.pop("txt" if "txt" in kwargs else "context")
txt = kwargs.pop(txt_arg_name)
else:
if args:
txt = args[0]
args = args[1:]
else:
txt = kwargs.pop("txt" if "txt" in kwargs else "context")
txt = kwargs.pop(txt_arg_name)
if args:
img = args[0]
args = args[1:]
else:
img = kwargs.pop("img")
img = kwargs.pop(img_arg_name)
hidden_states = img
encoder_hidden_states = txt
if self.residual_diff_threshold <= 0.0:
for block in self.transformer_blocks:
if self.accept_hidden_states_first:
if txt_arg_name == "encoder_hidden_states":
hidden_states = block(
hidden_states, encoder_hidden_states, *args, **kwargs)
hidden_states, *args, encoder_hidden_states=encoder_hidden_states, **kwargs)
else:
hidden_states = block(
encoder_hidden_states, hidden_states, *args, **kwargs)
if self.accept_hidden_states_first:
hidden_states = block(
hidden_states, encoder_hidden_states, *args, **kwargs)
else:
hidden_states = block(
encoder_hidden_states, hidden_states, *args, **kwargs)
if not self.return_hidden_states_only:
hidden_states, encoder_hidden_states = hidden_states
if not self.return_hidden_states_first:
@@ -200,12 +216,16 @@ def forward(self, *args, **kwargs):
if self.clone_original_hidden_states:
original_hidden_states = original_hidden_states.clone()
first_transformer_block = self.transformer_blocks[0]
if self.accept_hidden_states_first:
if txt_arg_name == "encoder_hidden_states":
hidden_states = first_transformer_block(
hidden_states, encoder_hidden_states, *args, **kwargs)
hidden_states, *args, encoder_hidden_states=encoder_hidden_states, **kwargs)
else:
hidden_states = first_transformer_block(
encoder_hidden_states, hidden_states, *args, **kwargs)
if self.accept_hidden_states_first:
hidden_states = first_transformer_block(
hidden_states, encoder_hidden_states, *args, **kwargs)
else:
hidden_states = first_transformer_block(
encoder_hidden_states, hidden_states, *args, **kwargs)
if not self.return_hidden_states_only:
hidden_states, encoder_hidden_states = hidden_states
if not self.return_hidden_states_first:
@@ -237,6 +257,7 @@ def forward(self, *args, **kwargs):
) = self.call_remaining_transformer_blocks(hidden_states,
encoder_hidden_states,
*args,
txt_arg_name=txt_arg_name,
**kwargs)
set_buffer("hidden_states_residual", hidden_states_residual)
set_buffer("encoder_hidden_states_residual",
@@ -252,19 +273,24 @@ def forward(self, *args, **kwargs):

def call_remaining_transformer_blocks(self, hidden_states,
encoder_hidden_states, *args,
txt_arg_name=None,
**kwargs):
original_hidden_states = hidden_states
original_encoder_hidden_states = encoder_hidden_states
if self.clone_original_hidden_states:
original_hidden_states = original_hidden_states.clone()
original_encoder_hidden_states = original_encoder_hidden_states.clone()
for block in self.transformer_blocks[1:]:
if self.accept_hidden_states_first:
if txt_arg_name == "encoder_hidden_states":
hidden_states = block(
hidden_states, encoder_hidden_states, *args, **kwargs)
hidden_states, *args, encoder_hidden_states=encoder_hidden_states, **kwargs)
else:
hidden_states = block(
encoder_hidden_states, hidden_states, *args, **kwargs)
if self.accept_hidden_states_first:
hidden_states = block(
hidden_states, encoder_hidden_states, *args, **kwargs)
else:
hidden_states = block(
encoder_hidden_states, hidden_states, *args, **kwargs)
if not self.return_hidden_states_only:
hidden_states, encoder_hidden_states = hidden_states
if not self.return_hidden_states_first:
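The `forward` changes in this file let one cached block list drive both ComfyUI-native blocks, which take the image/text streams positionally as `img`/`txt` (or `context`), and diffusers-style blocks, which expect `hidden_states` plus an `encoder_hidden_states` keyword, as in non-native LTXV. A condensed sketch of that dispatch follows; `resolve_arg_names` and `call_block` are illustrative helpers, not functions from this file.

```python
def resolve_arg_names(kwargs):
    # Work out which spelling this model uses for the image and text streams.
    if "img" in kwargs:
        img_arg_name = "img"            # ComfyUI-native FLUX-style blocks
    elif "hidden_states" in kwargs:
        img_arg_name = "hidden_states"  # diffusers-style blocks
    else:
        img_arg_name = None

    if "txt" in kwargs:
        txt_arg_name = "txt"
    elif "context" in kwargs:
        txt_arg_name = "context"
    elif "encoder_hidden_states" in kwargs:
        txt_arg_name = "encoder_hidden_states"  # non-native LTXV path
    else:
        txt_arg_name = None
    return img_arg_name, txt_arg_name


def call_block(block, hidden_states, encoder_hidden_states, txt_arg_name,
               accept_hidden_states_first, *args, **kwargs):
    if txt_arg_name == "encoder_hidden_states":
        # Diffusers-style signature: text stream goes in as a keyword argument.
        return block(hidden_states, *args,
                     encoder_hidden_states=encoder_hidden_states, **kwargs)
    if accept_hidden_states_first:
        return block(hidden_states, encoder_hidden_states, *args, **kwargs)
    return block(encoder_hidden_states, hidden_states, *args, **kwargs)
```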
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,7 +1,7 @@
[project]
name = "wavespeed"
description = "The all in one inference optimization solution for ComfyUI, universal, flexible, and fast."
version = "1.0.18"
version = "1.0.19"
license = {file = "LICENSE"}

[project.urls]
