Skip to content

Commit

Permalink
release cache memory after get_output_data
Browse files Browse the repository at this point in the history
  • Loading branch information
chengzeyi committed Jan 12, 2025
1 parent e92335e commit 21140cd
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 1 deletion.
2 changes: 2 additions & 0 deletions fbcache_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,8 @@ def patch(
if residual_diff_threshold <= 0:
return (model, )

first_block_cache.patch_get_output_data()

using_validation = max_consecutive_cache_hits > 0 or start > 0 or end < 1
if using_validation:
model_sampling = model.get_model_object("model_sampling")
Expand Down
22 changes: 22 additions & 0 deletions first_block_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,28 @@ def cache_context(cache_context):
_current_cache_context = old_cache_context


def patch_get_output_data():
    """Monkey-patch ComfyUI's ``execution.get_output_data`` so that the
    first-block cache's buffers are released after each node execution.

    The patch is best-effort and idempotent:
      * if the ``execution`` module has no ``get_output_data`` attribute
        (different ComfyUI version), nothing is patched;
      * if the function has already been wrapped (``_patched`` marker),
        it is not wrapped a second time.

    Returns:
        None. Side effect: ``execution.get_output_data`` may be replaced
        with a wrapper that clears the current cache context after the
        original call returns.
    """
    # Imported lazily: ``execution`` only exists when running inside ComfyUI.
    import execution
    import functools

    get_output_data = getattr(execution, "get_output_data", None)
    if get_output_data is None:
        # Hook point absent in this ComfyUI version — nothing to patch.
        return

    if getattr(get_output_data, "_patched", False):
        # Already wrapped — keep repeated calls idempotent.
        return

    @functools.wraps(get_output_data)
    def new_get_output_data(*args, **kwargs):
        out = get_output_data(*args, **kwargs)
        # Once the node's outputs have been produced, drop the cached
        # tensors so their memory is not held across executions.
        cache_context = get_current_cache_context()
        if cache_context is not None:
            cache_context.clear_buffers()
            set_current_cache_context(None)
        return out

    # Mark the wrapper so a later call to this function is a no-op.
    # (Set after ``wraps`` so the original's __dict__ cannot clobber it.)
    new_get_output_data._patched = True
    execution.get_output_data = new_get_output_data


@torch.compiler.disable()
def are_two_tensors_similar(t1, t2, *, threshold):
if t1.shape != t2.shape:
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[project]
name = "wavespeed"
description = "The all in one inference optimization solution for ComfyUI, universal, flexible, and fast."
version = "1.1.0"
version = "1.1.1"
license = {file = "LICENSE"}

[project.urls]
Expand Down

0 comments on commit 21140cd

Please sign in to comment.