From 93a23878af8f60efce1d163fd5cd0a700cd76dd7 Mon Sep 17 00:00:00 2001
From: chengzeyi
Date: Wed, 8 Jan 2025 17:48:36 +0800
Subject: [PATCH] closes #6; fix LTXV support

---
 fbcache_nodes.py     |  1 +
 first_block_cache.py | 47 +++++++++++++++++++++++++++++---------------
 pyproject.toml       |  2 +-
 3 files changed, 33 insertions(+), 17 deletions(-)

diff --git a/fbcache_nodes.py b/fbcache_nodes.py
index fd40c07..a51652b 100644
--- a/fbcache_nodes.py
+++ b/fbcache_nodes.py
@@ -55,6 +55,7 @@ def patch(
                 diffusion_model, "single_blocks") else None,
             residual_diff_threshold=residual_diff_threshold,
             cat_hidden_states_first=diffusion_model.__class__.__name__ == "HunyuanVideo",
+            return_hidden_states_only=diffusion_model.__class__.__name__ == "LTXVModel",
         )
     ])
     dummy_single_transformer_blocks = torch.nn.ModuleList()
diff --git a/first_block_cache.py b/first_block_cache.py
index caa01c7..b95cef0 100644
--- a/first_block_cache.py
+++ b/first_block_cache.py
@@ -124,6 +124,7 @@ def __init__(
         residual_diff_threshold,
         return_hidden_states_first=True,
         cat_hidden_states_first=False,
+        return_hidden_states_only=False,
     ):
         super().__init__()
         self.transformer_blocks = transformer_blocks
@@ -131,6 +132,7 @@ def __init__(
         self.residual_diff_threshold = residual_diff_threshold
         self.return_hidden_states_first = return_hidden_states_first
         self.cat_hidden_states_first = cat_hidden_states_first
+        self.return_hidden_states_only = return_hidden_states_only
 
     def forward(self, img, txt=None, *args, context=None, **kwargs):
         if context is not None:
@@ -139,10 +141,12 @@ def forward(self, img, txt=None, *args, context=None, **kwargs):
         encoder_hidden_states = txt
         if self.residual_diff_threshold <= 0.0:
             for block in self.transformer_blocks:
-                hidden_states, encoder_hidden_states = block(
+                hidden_states = block(
                     hidden_states, encoder_hidden_states, *args, **kwargs)
-                if not self.return_hidden_states_first:
-                    hidden_states, encoder_hidden_states = encoder_hidden_states, hidden_states
+                if not self.return_hidden_states_only:
+                    hidden_states, encoder_hidden_states = hidden_states
+                    if not self.return_hidden_states_first:
+                        hidden_states, encoder_hidden_states = encoder_hidden_states, hidden_states
             if self.single_transformer_blocks is not None:
                 hidden_states = torch.cat(
                     [hidden_states, encoder_hidden_states]
@@ -153,16 +157,21 @@ def forward(self, img, txt=None, *args, context=None, **kwargs):
                     hidden_states = block(hidden_states, *args, **kwargs)
                 hidden_states = hidden_states[:, encoder_hidden_states.shape[1]:]
 
-            return ((hidden_states, encoder_hidden_states)
-                    if self.return_hidden_states_first else
-                    (encoder_hidden_states, hidden_states))
+            if self.return_hidden_states_only:
+                return hidden_states
+            else:
+                return ((hidden_states, encoder_hidden_states)
+                        if self.return_hidden_states_first else
+                        (encoder_hidden_states, hidden_states))
 
         original_hidden_states = hidden_states
         first_transformer_block = self.transformer_blocks[0]
-        hidden_states, encoder_hidden_states = first_transformer_block(
+        hidden_states = first_transformer_block(
             hidden_states, encoder_hidden_states, *args, **kwargs)
-        if not self.return_hidden_states_first:
-            hidden_states, encoder_hidden_states = encoder_hidden_states, hidden_states
+        if not self.return_hidden_states_only:
+            hidden_states, encoder_hidden_states = hidden_states
+            if not self.return_hidden_states_first:
+                hidden_states, encoder_hidden_states = encoder_hidden_states, hidden_states
         first_hidden_states_residual = hidden_states - original_hidden_states
         del original_hidden_states
 
@@ -186,15 +195,19 @@ def forward(self, img, txt=None, *args, context=None, **kwargs):
                 encoder_hidden_states_residual,
             ) = self.call_remaining_transformer_blocks(hidden_states,
                                                        encoder_hidden_states,
-                                                       *args, **kwargs)
+                                                       *args,
+                                                       **kwargs)
             set_buffer("hidden_states_residual", hidden_states_residual)
             set_buffer("encoder_hidden_states_residual",
                        encoder_hidden_states_residual)
         torch._dynamo.graph_break()
-        return ((hidden_states,
-                 encoder_hidden_states) if self.return_hidden_states_first else
-                (encoder_hidden_states, hidden_states))
+        if self.return_hidden_states_only:
+            return hidden_states
+        else:
+            return ((hidden_states,
+                     encoder_hidden_states) if self.return_hidden_states_first else
+                    (encoder_hidden_states, hidden_states))
 
     def call_remaining_transformer_blocks(self, hidden_states,
                                           encoder_hidden_states, *args,
@@ -202,10 +215,12 @@ def call_remaining_transformer_blocks(self, hidden_states,
         original_hidden_states = hidden_states
         original_encoder_hidden_states = encoder_hidden_states
         for block in self.transformer_blocks[1:]:
-            hidden_states, encoder_hidden_states = block(
+            hidden_states = block(
                 hidden_states, encoder_hidden_states, *args, **kwargs)
-            if not self.return_hidden_states_first:
-                hidden_states, encoder_hidden_states = encoder_hidden_states, hidden_states
+            if not self.return_hidden_states_only:
+                hidden_states, encoder_hidden_states = hidden_states
+                if not self.return_hidden_states_first:
+                    hidden_states, encoder_hidden_states = encoder_hidden_states, hidden_states
         if self.single_transformer_blocks is not None:
             hidden_states = torch.cat(
                 [hidden_states, encoder_hidden_states]
diff --git a/pyproject.toml b/pyproject.toml
index 76e6879..1d41bdf 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,7 +1,7 @@
 [project]
 name = "wavespeed"
 description = ""
-version = "1.0.5"
+version = "1.0.6"
 license = {file = "LICENSE"}
 
 [project.urls]
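
Note for reviewers: the core change is that the wrapper no longer assumes every
wrapped block returns a (hidden_states, encoder_hidden_states) tuple; LTXVModel's
blocks return a bare tensor, so the raw return value is bound first and unpacked
only when return_hidden_states_only is false. Below is a minimal standalone
sketch of that unpacking convention; BareTensorBlock, TupleBlock, and run_block
are hypothetical stand-ins for illustration, not code from this repository.

    import torch

    # Hypothetical stand-ins for the two block styles (illustrative only):
    # LTXV-style blocks return a bare tensor, while Flux/HunyuanVideo-style
    # blocks return a (hidden_states, encoder_hidden_states) tuple.
    class BareTensorBlock(torch.nn.Module):
        def forward(self, hidden_states, encoder_hidden_states):
            return hidden_states + 1.0

    class TupleBlock(torch.nn.Module):
        def forward(self, hidden_states, encoder_hidden_states):
            return hidden_states + 1.0, encoder_hidden_states

    def run_block(block, hidden_states, encoder_hidden_states,
                  return_hidden_states_only, return_hidden_states_first=True):
        # Same pattern as the patch: bind the raw return value first, then
        # unpack it only when the block is known to return a tuple.
        hidden_states = block(hidden_states, encoder_hidden_states)
        if not return_hidden_states_only:
            hidden_states, encoder_hidden_states = hidden_states
            if not return_hidden_states_first:
                hidden_states, encoder_hidden_states = (encoder_hidden_states,
                                                        hidden_states)
        return hidden_states, encoder_hidden_states

    h, e = torch.zeros(2), torch.ones(2)
    # LTXV-like path: no unpacking; encoder_hidden_states passes through.
    h1, e1 = run_block(BareTensorBlock(), h, e, return_hidden_states_only=True)
    # Flux-like path: the tuple is unpacked exactly as before the patch.
    h2, e2 = run_block(TupleBlock(), h, e, return_hidden_states_only=False)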