From 21140cdf8c43946acd9ea522b4fda66df5d859c9 Mon Sep 17 00:00:00 2001
From: chengzeyi
Date: Sun, 12 Jan 2025 22:07:04 +0800
Subject: [PATCH] release cache memory after get_output_data

---
 fbcache_nodes.py     |  2 ++
 first_block_cache.py | 22 ++++++++++++++++++++++
 pyproject.toml       |  2 +-
 3 files changed, 25 insertions(+), 1 deletion(-)

diff --git a/fbcache_nodes.py b/fbcache_nodes.py
index 65ad750..f776492 100644
--- a/fbcache_nodes.py
+++ b/fbcache_nodes.py
@@ -91,6 +91,8 @@ def patch(
         if residual_diff_threshold <= 0:
             return (model, )
 
+        first_block_cache.patch_get_output_data()
+
         using_validation = max_consecutive_cache_hits > 0 or start > 0 or end < 1
         if using_validation:
             model_sampling = model.get_model_object("model_sampling")
diff --git a/first_block_cache.py b/first_block_cache.py
index 81733ee..f764cba 100644
--- a/first_block_cache.py
+++ b/first_block_cache.py
@@ -76,6 +76,28 @@ def cache_context(cache_context):
         _current_cache_context = old_cache_context
 
 
+def patch_get_output_data():
+    import execution
+
+    get_output_data = getattr(execution, "get_output_data", None)
+    if get_output_data is None:
+        return
+
+    if getattr(get_output_data, "_patched", False):
+        return
+
+    def new_get_output_data(*args, **kwargs):
+        out = get_output_data(*args, **kwargs)
+        cache_context = get_current_cache_context()
+        if cache_context is not None:
+            cache_context.clear_buffers()
+            set_current_cache_context(None)
+        return out
+
+    new_get_output_data._patched = True
+    execution.get_output_data = new_get_output_data
+
+
 @torch.compiler.disable()
 def are_two_tensors_similar(t1, t2, *, threshold):
     if t1.shape != t2.shape:
diff --git a/pyproject.toml b/pyproject.toml
index 917760c..7c59fae 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,7 +1,7 @@
 [project]
 name = "wavespeed"
 description = "The all in one inference optimization solution for ComfyUI, universal, flexible, and fast."
-version = "1.1.0"
+version = "1.1.1"
 license = {file = "LICENSE"}
 
 [project.urls]