support SDXL
chengzeyi committed Jan 11, 2025
1 parent 805b67c commit 7b8f76f
Showing 5 changed files with 1,049 additions and 134 deletions.
5 changes: 5 additions & 0 deletions README.md
@@ -37,6 +37,10 @@ You can find demo workflows in the `workflows` folder.

[LTXV with First Block Cache and Compilation](./workflows/ltxv.json)

[SDXL with First Block Cache](./workflows/sdxl.json)

**NOTE**: The compilation node requires your environment to meet certain software and hardware requirements; please refer to the [Enhanced `torch.compile`](#enhanced-torchcompile) section for more information.

## Dynamic Caching ([First Block Cache](https://github.com/chengzeyi/ParaAttention?tab=readme-ov-file#first-block-cache-our-dynamic-caching))

Inspired by TeaCache and other denoising-caching algorithms, we introduce [First Block Cache (FBCache)](https://github.com/chengzeyi/ParaAttention?tab=readme-ov-file#first-block-cache-our-dynamic-caching), which uses the residual output of the first transformer block as the cache indicator.
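
In a minimal sketch (illustrative names only, not this extension's actual code), the mechanism works like this: run only the first block, compare its residual with the one cached from the previous step, and skip the remaining blocks whenever the relative difference stays below `residual_diff_threshold`:

```python
import torch

def fbcache_forward(blocks, hidden_states, cache, threshold):
    # Always run the first transformer block; its residual is the
    # cache indicator.
    first_out = blocks[0](hidden_states)
    first_residual = first_out - hidden_states

    prev = cache.get("first_residual")
    if prev is not None:
        # Relative change of the first-block residual between steps.
        diff = (first_residual - prev).abs().mean() / prev.abs().mean()
        if diff.item() < threshold:
            # Cache hit: skip the remaining blocks and reuse their
            # residual from the last fully computed step.
            return first_out + cache["remaining_residual"]

    # Cache miss: run the remaining blocks and refresh the cache.
    out = first_out
    for block in blocks[1:]:
        out = block(out)
    cache["first_residual"] = first_residual
    cache["remaining_residual"] = out - first_out
    return out
```

A cache hit costs one transformer block instead of the whole stack, which is where the speedups behind the configurations in the table below come from.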
@@ -52,6 +56,7 @@ Some configurations for different models that you can try:
| Model | Steps | `residual_diff_threshold` |
| - | - | - |
| `flux-dev.safetensors` with `fp8_e4m3fn_fast` | 28 | 0.12 |
| `ltx-video-2b-v0.9.1.safetensors` | 30 | 0.1 |
| `sd_xl_base_1.0.safetensors` | 25 | 0.2 |

It supports many models, such as `FLUX`, `LTXV (native and non-native)`, and `HunyuanVideo (native)`. Feel free to try it out and let us know if you have any issues!

218 changes: 124 additions & 94 deletions fbcache_nodes.py
@@ -90,49 +90,6 @@ def patch(
):
if residual_diff_threshold <= 0:
return (model, )
prev_timestep = None
current_timestep = None
consecutive_cache_hits = 0

model = model.clone()
diffusion_model = model.get_model_object(object_to_patch)

is_non_native_ltxv = False
if diffusion_model.__class__.__name__ == "LTXVTransformer3D":
is_non_native_ltxv = True
diffusion_model = diffusion_model.transformer

double_blocks_name = None
single_blocks_name = None
if hasattr(diffusion_model, "transformer_blocks"):
double_blocks_name = "transformer_blocks"
elif hasattr(diffusion_model, "double_blocks"):
double_blocks_name = "double_blocks"
elif hasattr(diffusion_model, "joint_blocks"):
double_blocks_name = "joint_blocks"
else:
raise ValueError("No double blocks found")

if hasattr(diffusion_model, "single_blocks"):
single_blocks_name = "single_blocks"

if is_non_native_ltxv:
original_create_skip_layer_mask = getattr(
diffusion_model, "create_skip_layer_mask", None)
if original_create_skip_layer_mask is not None:
# original_double_blocks = getattr(diffusion_model,
# double_blocks_name)

def new_create_skip_layer_mask(self, *args, **kwargs):
# with unittest.mock.patch.object(self, double_blocks_name,
# original_double_blocks):
# return original_create_skip_layer_mask(*args, **kwargs)
# return original_create_skip_layer_mask(*args, **kwargs)
raise RuntimeError(
"STG is not supported with FBCache yet")

diffusion_model.create_skip_layer_mask = new_create_skip_layer_mask.__get__(
diffusion_model)

using_validation = max_consecutive_cache_hits > 0 or start > 0 or end < 1
if using_validation:
@@ -153,57 +110,130 @@ def validate_use_cache(use_cached):
else:
validate_use_cache = None

cached_transformer_blocks = torch.nn.ModuleList([
first_block_cache.CachedTransformerBlocks(
None if double_blocks_name is None else getattr(
diffusion_model, double_blocks_name),
None if single_blocks_name is None else getattr(
diffusion_model, single_blocks_name),
residual_diff_threshold=residual_diff_threshold,
validate_can_use_cache_function=validate_use_cache,
cat_hidden_states_first=diffusion_model.__class__.__name__ ==
"HunyuanVideo",
return_hidden_states_only=diffusion_model.__class__.__name__
== "LTXVModel" or is_non_native_ltxv,
clone_original_hidden_states=diffusion_model.__class__.__name__
== "LTXVModel",
return_hidden_states_first=diffusion_model.__class__.__name__
!= "OpenAISignatureMMDITWrapper",
accept_hidden_states_first=diffusion_model.__class__.__name__
!= "OpenAISignatureMMDITWrapper",
)
])
dummy_single_transformer_blocks = torch.nn.ModuleList()

def model_unet_function_wrapper(model_function, kwargs):
nonlocal prev_timestep, current_timestep, consecutive_cache_hits

try:
input = kwargs["input"]
timestep = kwargs["timestep"]
c = kwargs["c"]
current_timestep = t = timestep[0].item()

if prev_timestep is None or t >= prev_timestep:
prev_timestep = t
consecutive_cache_hits = 0
first_block_cache.set_current_cache_context(
first_block_cache.create_cache_context())

with unittest.mock.patch.object(
diffusion_model,
double_blocks_name,
cached_transformer_blocks,
), unittest.mock.patch.object(
diffusion_model,
single_blocks_name,
dummy_single_transformer_blocks,
) if single_blocks_name is not None else contextlib.nullcontext(
):
return model_function(input, timestep, **c)
except model_management.InterruptProcessingException as exc:
prev_timestep = None
raise exc from None
prev_timestep = None
current_timestep = None
consecutive_cache_hits = 0

model = model.clone()
diffusion_model = model.get_model_object(object_to_patch)

if diffusion_model.__class__.__name__ == "UNetModel":

def model_unet_function_wrapper(model_function, kwargs):
nonlocal prev_timestep, current_timestep, consecutive_cache_hits

try:
input = kwargs["input"]
timestep = kwargs["timestep"]
c = kwargs["c"]
current_timestep = t = timestep[0].item()

if prev_timestep is None or t >= prev_timestep:
prev_timestep = t
consecutive_cache_hits = 0
first_block_cache.set_current_cache_context(
first_block_cache.create_cache_context())

with first_block_cache.patch_unet_model__forward(
diffusion_model,
residual_diff_threshold=residual_diff_threshold,
validate_can_use_cache_function=validate_use_cache
):
return model_function(input, timestep, **c)
except model_management.InterruptProcessingException as exc:
prev_timestep = None
raise exc from None
else:
is_non_native_ltxv = False
if diffusion_model.__class__.__name__ == "LTXVTransformer3D":
is_non_native_ltxv = True
diffusion_model = diffusion_model.transformer

double_blocks_name = None
single_blocks_name = None
if hasattr(diffusion_model, "transformer_blocks"):
double_blocks_name = "transformer_blocks"
elif hasattr(diffusion_model, "double_blocks"):
double_blocks_name = "double_blocks"
elif hasattr(diffusion_model, "joint_blocks"):
double_blocks_name = "joint_blocks"
else:
raise ValueError(
f"No double blocks found for {diffusion_model.__class__.__name__}"
)

if hasattr(diffusion_model, "single_blocks"):
single_blocks_name = "single_blocks"

if is_non_native_ltxv:
original_create_skip_layer_mask = getattr(
diffusion_model, "create_skip_layer_mask", None)
if original_create_skip_layer_mask is not None:
# original_double_blocks = getattr(diffusion_model,
# double_blocks_name)

def new_create_skip_layer_mask(self, *args, **kwargs):
# with unittest.mock.patch.object(self, double_blocks_name,
# original_double_blocks):
# return original_create_skip_layer_mask(*args, **kwargs)
# return original_create_skip_layer_mask(*args, **kwargs)
raise RuntimeError(
"STG is not supported with FBCache yet")

diffusion_model.create_skip_layer_mask = new_create_skip_layer_mask.__get__(
diffusion_model)

cached_transformer_blocks = torch.nn.ModuleList([
first_block_cache.CachedTransformerBlocks(
None if double_blocks_name is None else getattr(
diffusion_model, double_blocks_name),
None if single_blocks_name is None else getattr(
diffusion_model, single_blocks_name),
residual_diff_threshold=residual_diff_threshold,
validate_can_use_cache_function=validate_use_cache,
cat_hidden_states_first=diffusion_model.__class__.__name__
== "HunyuanVideo",
return_hidden_states_only=diffusion_model.__class__.
__name__ == "LTXVModel" or is_non_native_ltxv,
clone_original_hidden_states=diffusion_model.__class__.
__name__ == "LTXVModel",
return_hidden_states_first=diffusion_model.__class__.
__name__ != "OpenAISignatureMMDITWrapper",
accept_hidden_states_first=diffusion_model.__class__.
__name__ != "OpenAISignatureMMDITWrapper",
)
])
dummy_single_transformer_blocks = torch.nn.ModuleList()

def model_unet_function_wrapper(model_function, kwargs):
nonlocal prev_timestep, current_timestep, consecutive_cache_hits

try:
input = kwargs["input"]
timestep = kwargs["timestep"]
c = kwargs["c"]
current_timestep = t = timestep[0].item()

if prev_timestep is None or t >= prev_timestep:
prev_timestep = t
consecutive_cache_hits = 0
first_block_cache.set_current_cache_context(
first_block_cache.create_cache_context())

with unittest.mock.patch.object(
diffusion_model,
double_blocks_name,
cached_transformer_blocks,
), unittest.mock.patch.object(
diffusion_model,
single_blocks_name,
dummy_single_transformer_blocks,
) if single_blocks_name is not None else contextlib.nullcontext(
):
return model_function(input, timestep, **c)
except model_management.InterruptProcessingException as exc:
prev_timestep = None
raise exc from None

model.set_model_unet_function_wrapper(model_unet_function_wrapper)
return (model, )
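
The `unittest.mock.patch.object` calls above swap the model's block list for the caching wrapper only for the duration of a single forward call, and the non-decreasing-timestep check detects the start of a new sampling run (sampling walks timesteps downward). A self-contained sketch of that pattern, using illustrative names that are not part of this extension's API:

```python
import unittest.mock

import torch


class TinyTransformer(torch.nn.Module):
    """Stand-in for a diffusion model whose block list gets swapped."""

    def __init__(self):
        super().__init__()
        self.transformer_blocks = torch.nn.ModuleList(
            torch.nn.Linear(8, 8) for _ in range(4))

    def forward(self, x):
        for block in self.transformer_blocks:
            x = block(x)
        return x


prev_timestep = None


def call_with_cache(model, x, timestep, cached_blocks):
    global prev_timestep
    # Sampling walks timesteps downward; a timestep that is not strictly
    # smaller than the previous one marks a fresh sampling run, which is
    # the point where the real node rebuilds its cache context.
    if prev_timestep is None or timestep >= prev_timestep:
        print("new sampling run -> reset cache context")
    prev_timestep = timestep
    # Swap the real block list for the caching wrapper only for the
    # duration of this forward call; the original is restored on exit,
    # even if the call raises.
    with unittest.mock.patch.object(
            model, "transformer_blocks", cached_blocks):
        return model(x)


model = TinyTransformer()
stand_in = torch.nn.ModuleList([torch.nn.Identity()])  # caching stand-in
out = call_with_cache(model, torch.randn(2, 8), timestep=1.0,
                      cached_blocks=stand_in)
```

Patching per call rather than permanently keeps the original model intact if sampling is interrupted, which is also why the wrapper above resets `prev_timestep` when it catches `InterruptProcessingException`.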
