Commit 93a2387
closes #6; fix LTXV support
chengzeyi committed Jan 8, 2025
1 parent 9894a54 commit 93a2387
Showing 3 changed files with 33 additions and 17 deletions.
1 change: 1 addition & 0 deletions fbcache_nodes.py
@@ -55,6 +55,7 @@ def patch(
                     diffusion_model, "single_blocks") else None,
                 residual_diff_threshold=residual_diff_threshold,
                 cat_hidden_states_first=diffusion_model.__class__.__name__ == "HunyuanVideo",
+                return_hidden_states_only=diffusion_model.__class__.__name__ == "LTXVModel",
             )
         ])
         dummy_single_transformer_blocks = torch.nn.ModuleList()
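For orientation, the node turns the new flag on purely from the model's class name, the same way it already selects `cat_hidden_states_first` for HunyuanVideo. A minimal sketch of that dispatch follows; the helper name `wrapper_flags` is hypothetical and not part of fbcache_nodes.py, only the class-name checks come from the diff above.

```python
# Hypothetical helper illustrating the flag selection visible in the diff;
# it is not code from fbcache_nodes.py itself.
def wrapper_flags(diffusion_model):
    name = diffusion_model.__class__.__name__
    return {
        # HunyuanVideo concatenates text tokens before image tokens.
        "cat_hidden_states_first": name == "HunyuanVideo",
        # For LTXVModel the cached wrapper is asked to return only
        # hidden_states instead of a (hidden, encoder) tuple.
        "return_hidden_states_only": name == "LTXVModel",
    }
```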
47 changes: 31 additions & 16 deletions first_block_cache.py
@@ -124,13 +124,15 @@ def __init__(
         residual_diff_threshold,
         return_hidden_states_first=True,
         cat_hidden_states_first=False,
+        return_hidden_states_only=False,
     ):
         super().__init__()
         self.transformer_blocks = transformer_blocks
         self.single_transformer_blocks = single_transformer_blocks
         self.residual_diff_threshold = residual_diff_threshold
         self.return_hidden_states_first = return_hidden_states_first
         self.cat_hidden_states_first = cat_hidden_states_first
+        self.return_hidden_states_only = return_hidden_states_only

     def forward(self, img, txt=None, *args, context=None, **kwargs):
         if context is not None:
@@ -139,10 +141,12 @@ def forward(self, img, txt=None, *args, context=None, **kwargs):
         encoder_hidden_states = txt
         if self.residual_diff_threshold <= 0.0:
             for block in self.transformer_blocks:
-                hidden_states, encoder_hidden_states = block(
+                hidden_states = block(
                     hidden_states, encoder_hidden_states, *args, **kwargs)
-                if not self.return_hidden_states_first:
-                    hidden_states, encoder_hidden_states = encoder_hidden_states, hidden_states
+                if not self.return_hidden_states_only:
+                    hidden_states, encoder_hidden_states = hidden_states
+                    if not self.return_hidden_states_first:
+                        hidden_states, encoder_hidden_states = encoder_hidden_states, hidden_states
             if self.single_transformer_blocks is not None:
                 hidden_states = torch.cat(
                     [hidden_states, encoder_hidden_states]
@@ -153,16 +157,21 @@ def forward(self, img, txt=None, *args, context=None, **kwargs):
                     hidden_states = block(hidden_states, *args, **kwargs)
                 hidden_states = hidden_states[:,
                                               encoder_hidden_states.shape[1]:]
-            return ((hidden_states, encoder_hidden_states)
-                    if self.return_hidden_states_first else
-                    (encoder_hidden_states, hidden_states))
+            if self.return_hidden_states_only:
+                return hidden_states
+            else:
+                return ((hidden_states, encoder_hidden_states)
+                        if self.return_hidden_states_first else
+                        (encoder_hidden_states, hidden_states))

         original_hidden_states = hidden_states
         first_transformer_block = self.transformer_blocks[0]
-        hidden_states, encoder_hidden_states = first_transformer_block(
+        hidden_states = first_transformer_block(
             hidden_states, encoder_hidden_states, *args, **kwargs)
-        if not self.return_hidden_states_first:
-            hidden_states, encoder_hidden_states = encoder_hidden_states, hidden_states
+        if not self.return_hidden_states_only:
+            hidden_states, encoder_hidden_states = hidden_states
+            if not self.return_hidden_states_first:
+                hidden_states, encoder_hidden_states = encoder_hidden_states, hidden_states
         first_hidden_states_residual = hidden_states - original_hidden_states
         del original_hidden_states
@@ -186,26 +195,32 @@ def forward(self, img, txt=None, *args, context=None, **kwargs):
                 encoder_hidden_states_residual,
             ) = self.call_remaining_transformer_blocks(hidden_states,
                                                        encoder_hidden_states,
-                                                       *args, **kwargs)
+                                                       *args,
+                                                       **kwargs)
             set_buffer("hidden_states_residual", hidden_states_residual)
             set_buffer("encoder_hidden_states_residual",
                        encoder_hidden_states_residual)
         torch._dynamo.graph_break()

-        return ((hidden_states,
-                 encoder_hidden_states) if self.return_hidden_states_first else
-                (encoder_hidden_states, hidden_states))
+        if self.return_hidden_states_only:
+            return hidden_states
+        else:
+            return ((hidden_states,
+                     encoder_hidden_states) if self.return_hidden_states_first else
+                    (encoder_hidden_states, hidden_states))

     def call_remaining_transformer_blocks(self, hidden_states,
                                           encoder_hidden_states, *args,
                                           **kwargs):
         original_hidden_states = hidden_states
         original_encoder_hidden_states = encoder_hidden_states
         for block in self.transformer_blocks[1:]:
-            hidden_states, encoder_hidden_states = block(
+            hidden_states = block(
                 hidden_states, encoder_hidden_states, *args, **kwargs)
-            if not self.return_hidden_states_first:
-                hidden_states, encoder_hidden_states = encoder_hidden_states, hidden_states
+            if not self.return_hidden_states_only:
+                hidden_states, encoder_hidden_states = hidden_states
+                if not self.return_hidden_states_first:
+                    hidden_states, encoder_hidden_states = encoder_hidden_states, hidden_states
         if self.single_transformer_blocks is not None:
             hidden_states = torch.cat(
                 [hidden_states, encoder_hidden_states]
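Most of the churn in this file is the same guard repeated in three places: a transformer block may now return either a bare tensor or a (hidden, encoder) pair whose order depends on `return_hidden_states_first`, and only in the tuple case is there anything to unpack or swap. Below is a standalone sketch of that normalization; the helper name is hypothetical, and returning `None` for the missing encoder stream is a simplification rather than what the diff does inline.

```python
def unpack_block_output(out,
                        return_hidden_states_only=False,
                        return_hidden_states_first=True):
    """Hypothetical helper mirroring the repeated unpacking guard in the diff.

    `out` is whatever a transformer block returned: a bare tensor when only
    hidden states flow through (the LTXV case), otherwise a 2-tuple whose
    order is governed by `return_hidden_states_first`.
    """
    if return_hidden_states_only:
        # Nothing to unpack; there is no separate encoder stream.
        return out, None
    hidden_states, encoder_hidden_states = out
    if not return_hidden_states_first:
        # The block emitted (encoder, hidden); swap into a fixed order.
        hidden_states, encoder_hidden_states = encoder_hidden_states, hidden_states
    return hidden_states, encoder_hidden_states
```

In the commit itself this guard is written out inline in `forward` and `call_remaining_transformer_blocks` rather than factored into a helper.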
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,7 +1,7 @@
 [project]
 name = "wavespeed"
 description = ""
-version = "1.0.5"
+version = "1.0.6"
 license = {file = "LICENSE"}

 [project.urls]
