diff --git a/README.md b/README.md
index 3698704..0527713 100644
--- a/README.md
+++ b/README.md
@@ -53,7 +53,7 @@ Some configurations for different models that you can try:
 | `flux-dev.safetensors` with `fp8_e4m3fn_fast` | 28 | 0.12 |
 | `ltx-video-2b-v0.9.1.safetensors` | 30 | 0.1 |
 
-It supports many models like `FLUX`, `LTXV (native)` and `HunyuanVideo (native)`, feel free to try it out and let us know if you have any issues!
+It supports many models like `FLUX`, `LTXV (native and non-native)` and `HunyuanVideo (native)`. Feel free to try it out and let us know if you have any issues!
 
 See [Apply First Block Cache on FLUX.1-dev](https://github.com/chengzeyi/ParaAttention/blob/main/doc/fastest_flux.md#apply-first-block-cache-on-flux1-dev) for more information and detailed comparison on quality and speed.
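For context on the `residual_diff_threshold` values in the table above: First Block Cache (from ParaAttention) runs only the first transformer block on every step and reuses the cached output of the remaining blocks whenever the first block's residual has barely changed since the previous step. A minimal sketch of that comparison; the helper name is hypothetical and the relative-L1 metric is an assumption, the actual check lives in `first_block_cache.py`:

```python
import torch


def similar_enough(curr_residual: torch.Tensor,
                   prev_residual: torch.Tensor,
                   threshold: float) -> bool:
    # Hypothetical helper, not part of this repo. If the first block's
    # output residual moved less than `threshold` (relative L1) since the
    # previous sampling step, the remaining blocks are skipped and the
    # cached final residual is reused.
    mean_diff = (curr_residual - prev_residual).abs().mean()
    mean_prev = prev_residual.abs().mean()
    return (mean_diff / mean_prev).item() < threshold
```

Lower thresholds reuse the cache less often (stricter), and `0` disables caching entirely, which matches the tooltip text in the node definition below.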
diff --git a/fbcache_nodes.py b/fbcache_nodes.py
index e112a34..8a78900 100644
--- a/fbcache_nodes.py
+++ b/fbcache_nodes.py
@@ -23,35 +23,52 @@ def INPUT_TYPES(s):
                 "residual_diff_threshold": (
                     "FLOAT",
                     {
-                        "default": 0.0,
-                        "min": 0.0,
-                        "max": 1.0,
-                        "step": 0.001,
-                        "tooltip": "Controls the tolerance for caching with lower values being more strict. Setting this to 0 disables the FBCache effect.",
+                        "default":
+                        0.0,
+                        "min":
+                        0.0,
+                        "max":
+                        1.0,
+                        "step":
+                        0.001,
+                        "tooltip":
+                        "Controls the tolerance for caching with lower values being more strict. Setting this to 0 disables the FBCache effect.",
                     },
                 ),
                 "start": (
-                    "FLOAT", {
-                        "default": 0.0,
-                        "step": 0.01,
-                        "max": 1.0,
-                        "min": 0.0,
-                        "tooltip": "Start time as a percentage of sampling where the FBCache effect can apply. Example: 0.0 would signify 0% (the beginning of sampling), 0.5 would signify 50%.",
+                    "FLOAT",
+                    {
+                        "default":
+                        0.0,
+                        "step":
+                        0.01,
+                        "max":
+                        1.0,
+                        "min":
+                        0.0,
+                        "tooltip":
+                        "Start time as a percentage of sampling where the FBCache effect can apply. Example: 0.0 would signify 0% (the beginning of sampling), 0.5 would signify 50%.",
                     },
                 ),
-                "end": (
-                    "FLOAT", {
-                        "default": 1.0,
-                        "step": 0.01,
-                        "max": 1.0,
-                        "min": 0.0,
-                        "tooltip": "End time as a percentage of sampling where the FBCache effect can apply. Example: 1.0 would signify 100% (the end of sampling), 0.5 would signify 50%.",
-                    }
-                ),
+                "end": ("FLOAT", {
+                    "default":
+                    1.0,
+                    "step":
+                    0.01,
+                    "max":
+                    1.0,
+                    "min":
+                    0.0,
+                    "tooltip":
+                    "End time as a percentage of sampling where the FBCache effect can apply. Example: 1.0 would signify 100% (the end of sampling), 0.5 would signify 50%.",
+                }),
                 "max_consecutive_cache_hits": (
-                    "INT", {
-                        "default": -1,
-                        "tooltip": "Allows limiting how many cached results can be used in a row. For example, setting this to 1 will mean there will be at least one full model call after each cached result. Set to 0 or lower to disable cache limiting.",
+                    "INT",
+                    {
+                        "default":
+                        -1,
+                        "tooltip":
+                        "Allows limiting how many cached results can be used in a row. For example, setting this to 1 will mean there will be at least one full model call after each cached result. Set to 0 or lower to disable cache limiting.",
                     },
                 ),
             }
@@ -72,7 +89,7 @@ def patch(
         end=1.0,
     ):
         if residual_diff_threshold <= 0:
-            return (model,)
+            return (model, )
         prev_timestep = None
         current_timestep = None
         consecutive_cache_hits = 0
@@ -80,6 +97,11 @@ def patch(
         model = model.clone()
         diffusion_model = model.get_model_object(object_to_patch)
 
+        is_non_native_ltxv = False
+        if diffusion_model.__class__.__name__ == "LTXVTransformer3D":
+            is_non_native_ltxv = True
+            diffusion_model = diffusion_model.transformer
+
         double_blocks_name = None
         single_blocks_name = None
         if hasattr(diffusion_model, "transformer_blocks"):
@@ -89,25 +111,41 @@ def patch(
         elif hasattr(diffusion_model, "joint_blocks"):
             double_blocks_name = "joint_blocks"
         else:
-            raise ValueError("No transformer blocks found")
+            raise ValueError("No double blocks found")
         if hasattr(diffusion_model, "single_blocks"):
             single_blocks_name = "single_blocks"
 
+        if is_non_native_ltxv:
+            original_create_skip_layer_mask = getattr(
+                diffusion_model, "create_skip_layer_mask", None)
+            if original_create_skip_layer_mask is not None:
+                original_double_blocks = getattr(diffusion_model,
+                                                 double_blocks_name)
+
+                def new_create_skip_layer_mask(self, *args, **kwargs):
+                    with unittest.mock.patch.object(self, double_blocks_name,
+                                                    original_double_blocks):
+                        return original_create_skip_layer_mask(*args, **kwargs)
+
+                diffusion_model.create_skip_layer_mask = new_create_skip_layer_mask.__get__(
+                    diffusion_model)
+
         using_validation = max_consecutive_cache_hits > 0 or start > 0 or end < 1
         if using_validation:
             model_sampling = model.get_model_object("model_sampling")
-            start_sigma, end_sigma = (
-                float(model_sampling.percent_to_sigma(pct))
-                for pct in (start, end)
-            )
+            start_sigma, end_sigma = (float(
+                model_sampling.percent_to_sigma(pct)) for pct in (start, end))
             del model_sampling
 
             @torch.compiler.disable()
             def validate_use_cache(use_cached):
                 nonlocal consecutive_cache_hits
                 use_cached = use_cached and end_sigma <= current_timestep <= start_sigma
-                use_cached = use_cached and (max_consecutive_cache_hits < 1 or consecutive_cache_hits < max_consecutive_cache_hits)
+                use_cached = use_cached and (max_consecutive_cache_hits < 1
+                                             or consecutive_cache_hits
+                                             < max_consecutive_cache_hits)
                 consecutive_cache_hits = consecutive_cache_hits + 1 if use_cached else 0
                 return use_cached
         else:
@@ -123,8 +161,8 @@ def validate_use_cache(use_cached):
             validate_can_use_cache_function=validate_use_cache,
             cat_hidden_states_first=diffusion_model.__class__.__name__ ==
             "HunyuanVideo",
-            return_hidden_states_only=diffusion_model.__class__.__name__ ==
-            "LTXVModel",
+            return_hidden_states_only=diffusion_model.__class__.__name__
+            == "LTXVModel" or is_non_native_ltxv,
             clone_original_hidden_states=diffusion_model.__class__.__name__ ==
             "LTXVModel",
             return_hidden_states_first=diffusion_model.__class__.__name__
@@ -158,7 +196,8 @@ def model_unet_function_wrapper(model_function, kwargs):
                         diffusion_model,
                         single_blocks_name,
                         dummy_single_transformer_blocks,
-                    ) if single_blocks_name is not None else contextlib.nullcontext():
+                    ) if single_blocks_name is not None else contextlib.nullcontext(
+                    ):
                         return model_function(input, timestep, **c)
                 except model_management.InterruptProcessingException as exc:
                     prev_timestep = None
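The non-native LTXV path above uses two Python idioms worth noting: `unittest.mock.patch.object` as a scoped setattr that restores the original transformer blocks only for the duration of a call, and `__get__` to bind a plain function to an instance so it behaves like a method. A self-contained sketch of that pattern; the toy class and attribute names are illustrative, not from the repo:

```python
import unittest.mock


class Toy:
    blocks = "patched blocks"

    def describe(self):
        return self.blocks


toy = Toy()
original_blocks = "original blocks"
original_describe = toy.describe


def new_describe(self, *args, **kwargs):
    # Temporarily restore the original attribute while calling through,
    # mirroring how create_skip_layer_mask is wrapped in the diff above.
    with unittest.mock.patch.object(self, "blocks", original_blocks):
        return original_describe(*args, **kwargs)


# __get__ binds the plain function to the instance so it acts as a method.
toy.describe = new_describe.__get__(toy)
assert toy.describe() == "original blocks"
assert toy.blocks == "patched blocks"  # the patch is scoped to the call
```

This keeps the FBCache block replacement in place for normal forward passes while letting `create_skip_layer_mask`, which needs to inspect the real block list, see the unpatched attribute.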
diff --git a/first_block_cache.py b/first_block_cache.py
index 6a94fe2..f6b9e30 100644
--- a/first_block_cache.py
+++ b/first_block_cache.py
@@ -143,38 +143,54 @@ def __init__(
         self.clone_original_hidden_states = clone_original_hidden_states
 
     def forward(self, *args, **kwargs):
+        img_arg_name = None
+        if "img" in kwargs:
+            img_arg_name = "img"
+        elif "hidden_states" in kwargs:
+            img_arg_name = "hidden_states"
+        txt_arg_name = None
+        if "txt" in kwargs:
+            txt_arg_name = "txt"
+        elif "context" in kwargs:
+            txt_arg_name = "context"
+        elif "encoder_hidden_states" in kwargs:
+            txt_arg_name = "encoder_hidden_states"
         if self.accept_hidden_states_first:
             if args:
                 img = args[0]
                 args = args[1:]
             else:
-                img = kwargs.pop("img")
+                img = kwargs.pop(img_arg_name)
             if args:
                 txt = args[0]
                 args = args[1:]
             else:
-                txt = kwargs.pop("txt" if "txt" in kwargs else "context")
+                txt = kwargs.pop(txt_arg_name)
         else:
             if args:
                 txt = args[0]
                 args = args[1:]
             else:
-                txt = kwargs.pop("txt" if "txt" in kwargs else "context")
+                txt = kwargs.pop(txt_arg_name)
             if args:
                 img = args[0]
                 args = args[1:]
             else:
-                img = kwargs.pop("img")
+                img = kwargs.pop(img_arg_name)
         hidden_states = img
         encoder_hidden_states = txt
         if self.residual_diff_threshold <= 0.0:
             for block in self.transformer_blocks:
-                if self.accept_hidden_states_first:
+                if txt_arg_name == "encoder_hidden_states":
                     hidden_states = block(
-                        hidden_states, encoder_hidden_states, *args, **kwargs)
+                        hidden_states, *args, encoder_hidden_states=encoder_hidden_states, **kwargs)
                 else:
-                    hidden_states = block(
-                        encoder_hidden_states, hidden_states, *args, **kwargs)
+                    if self.accept_hidden_states_first:
+                        hidden_states = block(
+                            hidden_states, encoder_hidden_states, *args, **kwargs)
+                    else:
+                        hidden_states = block(
+                            encoder_hidden_states, hidden_states, *args, **kwargs)
                 if not self.return_hidden_states_only:
                     hidden_states, encoder_hidden_states = hidden_states
                     if not self.return_hidden_states_first:
@@ -200,12 +216,16 @@ def forward(self, *args, **kwargs):
         if self.clone_original_hidden_states:
             original_hidden_states = original_hidden_states.clone()
         first_transformer_block = self.transformer_blocks[0]
-        if self.accept_hidden_states_first:
+        if txt_arg_name == "encoder_hidden_states":
             hidden_states = first_transformer_block(
-                hidden_states, encoder_hidden_states, *args, **kwargs)
+                hidden_states, *args, encoder_hidden_states=encoder_hidden_states, **kwargs)
         else:
-            hidden_states = first_transformer_block(
-                encoder_hidden_states, hidden_states, *args, **kwargs)
+            if self.accept_hidden_states_first:
+                hidden_states = first_transformer_block(
+                    hidden_states, encoder_hidden_states, *args, **kwargs)
+            else:
+                hidden_states = first_transformer_block(
+                    encoder_hidden_states, hidden_states, *args, **kwargs)
         if not self.return_hidden_states_only:
             hidden_states, encoder_hidden_states = hidden_states
             if not self.return_hidden_states_first:
@@ -237,6 +257,7 @@ def forward(self, *args, **kwargs):
             ) = self.call_remaining_transformer_blocks(hidden_states,
                                                        encoder_hidden_states,
                                                        *args,
+                                                       txt_arg_name=txt_arg_name,
                                                        **kwargs)
             set_buffer("hidden_states_residual", hidden_states_residual)
             set_buffer("encoder_hidden_states_residual",
@@ -252,6 +273,7 @@ def forward(self, *args, **kwargs):
     def call_remaining_transformer_blocks(self, hidden_states,
                                           encoder_hidden_states,
                                           *args,
+                                          txt_arg_name=None,
                                           **kwargs):
         original_hidden_states = hidden_states
         original_encoder_hidden_states = encoder_hidden_states
@@ -259,12 +281,16 @@ def call_remaining_transformer_blocks(self, hidden_states,
             original_hidden_states = original_hidden_states.clone()
             original_encoder_hidden_states = original_encoder_hidden_states.clone()
         for block in self.transformer_blocks[1:]:
-            if self.accept_hidden_states_first:
+            if txt_arg_name == "encoder_hidden_states":
                 hidden_states = block(
-                    hidden_states, encoder_hidden_states, *args, **kwargs)
+                    hidden_states, *args, encoder_hidden_states=encoder_hidden_states, **kwargs)
             else:
-                hidden_states = block(
-                    encoder_hidden_states, hidden_states, *args, **kwargs)
+                if self.accept_hidden_states_first:
+                    hidden_states = block(
+                        hidden_states, encoder_hidden_states, *args, **kwargs)
+                else:
+                    hidden_states = block(
+                        encoder_hidden_states, hidden_states, *args, **kwargs)
             if not self.return_hidden_states_only:
                 hidden_states, encoder_hidden_states = hidden_states
                 if not self.return_hidden_states_first:
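The `forward` changes above generalize the dual-stream argument handling: native models pass `img`/`txt`, ComfyUI-style models pass `hidden_states`/`context`, and diffusers-style blocks take `hidden_states` plus an `encoder_hidden_states` keyword. A small sketch of that detection, with a hypothetical helper name:

```python
def detect_stream_names(kwargs):
    # Mirrors the lookup added to forward(): find which keyword names this
    # model family uses for the image and text streams.
    img_name = next(
        (name for name in ("img", "hidden_states") if name in kwargs), None)
    txt_name = next(
        (name for name in ("txt", "context", "encoder_hidden_states")
         if name in kwargs), None)
    return img_name, txt_name


# Diffusers-style call detected from the keyword names alone.
img_name, txt_name = detect_stream_names(
    {"hidden_states": ..., "encoder_hidden_states": ...})
assert (img_name, txt_name) == ("hidden_states", "encoder_hidden_states")
```

This is also why the block-call dispatch special-cases `txt_arg_name == "encoder_hidden_states"`: diffusers-style blocks must receive the text stream back as a keyword argument rather than positionally.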
"encoder_hidden_states": hidden_states = block( - hidden_states, encoder_hidden_states, *args, **kwargs) + hidden_states, *args, encoder_hidden_states=encoder_hidden_states, **kwargs) else: - hidden_states = block( - encoder_hidden_states, hidden_states, *args, **kwargs) + if self.accept_hidden_states_first: + hidden_states = block( + hidden_states, encoder_hidden_states, *args, **kwargs) + else: + hidden_states = block( + encoder_hidden_states, hidden_states, *args, **kwargs) if not self.return_hidden_states_only: hidden_states, encoder_hidden_states = hidden_states if not self.return_hidden_states_first: diff --git a/pyproject.toml b/pyproject.toml index a0dbac9..a773d11 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [project] name = "wavespeed" description = "The all in one inference optimization solution for ComfyUI, universal, flexible, and fast." -version = "1.0.18" +version = "1.0.19" license = {file = "LICENSE"} [project.urls]