From e8e3de8968b9a6eb1769d60a0cb90d34b1653907 Mon Sep 17 00:00:00 2001
From: chengzeyi <ichengzeyi@gmail.com>
Date: Wed, 8 Jan 2025 19:38:00 +0800
Subject: [PATCH] fix lora with compile

---
 README.md      |  4 +++-
 misc_nodes.py  |  2 ++
 pyproject.toml |  2 +-
 utils.py       | 45 +++++++++++++++++++++++++++++++++++++++++++++
 4 files changed, 51 insertions(+), 2 deletions(-)

diff --git a/README.md b/README.md
index d5955c8..f8effec 100644
--- a/README.md
+++ b/README.md
@@ -51,9 +51,11 @@ See [Apply First Block Cache on FLUX.1-dev](https://github.com/chengzeyi/ParaAtt
 
 To use the Enhanced `torch.compile`, simply add the `wavespeed->Compile Model+` node to your workflow after your `Load Diffusion Model` node or `Apply First Block Cache` node. The compilation process happens the first time you run the workflow, and it takes quite a long time, but it will be cached for future runs. You can pass different `mode` values to make it runs faster, for example `max-autotune` or `max-autotune-no-cudagraphs`.
 
+One of the advantages of this node over the original `TorchCompileModel` node is that it works with LoRA.
+
 **NOTE**: `torch.compile` might not be able to work with model offloading well, you could try passing `--gpu-only` when launching your `ComfyUI` to disable model offloading.
-**NOTE**: `torch.compile` does not work on Windows offcially and has problems working with LoRAs, you should not use this node if you are facing these issues.
+
+**NOTE**: `torch.compile` does not officially support Windows; if you run into this, either skip this node or search online for how to make it work.
 
 ![Usage of Enhanced `torch.compile`](./assets/usage_compile.png)
 
diff --git a/misc_nodes.py b/misc_nodes.py
index 5da28b7..75e4ce1 100644
--- a/misc_nodes.py
+++ b/misc_nodes.py
@@ -109,6 +109,8 @@ def patch(
         disable,
         backend,
     ):
+        utils.patch_optimized_module()
+
         import_path, function_name = compiler.rsplit(".", 1)
         module = importlib.import_module(import_path)
         compile_function = getattr(module, function_name)
diff --git a/pyproject.toml b/pyproject.toml
index 1d41bdf..8d5ecc1 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,7 +1,7 @@
 [project]
 name = "wavespeed"
 description = ""
-version = "1.0.6"
+version = "1.0.7"
 license = {file = "LICENSE"}
 
 [project.urls]
diff --git a/utils.py b/utils.py
index 383bc9f..c6ffee8 100644
--- a/utils.py
+++ b/utils.py
@@ -57,3 +57,48 @@ def foo(*args, **kwargs):
 
     with unittest.mock.patch.object(model_management, "load_models_gpu", foo):
         yield
+
+
+def patch_optimized_module():
+    try:
+        from torch._dynamo.eval_frame import OptimizedModule
+    except ImportError:
+        return
+
+    if getattr(OptimizedModule, "_patched", False):
+        return
+
+    def __getattribute__(self, name):
+        if name == "_orig_mod":
+            return object.__getattribute__(self, "_modules")[name]
+        if name in (
+            "__class__",
+            "_modules",
+            "state_dict",
+            "load_state_dict",
+            "parameters",
+            "named_parameters",
+            "buffers",
+            "named_buffers",
+            "children",
+            "named_children",
+            "modules",
+            "named_modules",
+        ):
+            return getattr(object.__getattribute__(self, "_orig_mod"), name)
+        return object.__getattribute__(self, name)
+
+    def __delattr__(self, name):
+        # unload_lora_weights() wants to del peft_config
+        return delattr(self._orig_mod, name)
+
+    @classmethod
+    def __instancecheck__(cls, instance):
+        return isinstance(instance, OptimizedModule) or issubclass(
+            object.__getattribute__(instance, "__class__"), cls
+        )
+
+    OptimizedModule.__getattribute__ = __getattribute__
+    OptimizedModule.__delattr__ = __delattr__
+    OptimizedModule.__instancecheck__ = __instancecheck__
+    OptimizedModule._patched = True
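
For reviewers, a minimal sketch (not part of the patch) of what `patch_optimized_module` changes at runtime. It assumes torch >= 2.0 and that the patched `utils.py` above is importable (e.g. with ComfyUI on the import path), and it uses a toy `nn.Linear` in place of a real diffusion model. `torch.compile` wraps a module in an `OptimizedModule` whose `state_dict()` keys are all prefixed with `_orig_mod.`; ComfyUI applies LoRA by matching keys in the model's `state_dict()`, so the wrapper hides every weight. After the patch, those lookups are forwarded to the wrapped module:

```python
# Sketch only: toy model instead of a diffusion model; assumes the patched
# utils.py (providing patch_optimized_module) is importable and torch >= 2.0.
import torch
import torch.nn as nn

from utils import patch_optimized_module

model = nn.Linear(4, 4)
compiled = torch.compile(model)  # returns an OptimizedModule wrapper

# Before patching: keys are namespaced under "_orig_mod.", so key-based
# weight patching (as ComfyUI's LoRA loader does) cannot find them.
print(list(compiled.state_dict()))  # ['_orig_mod.weight', '_orig_mod.bias']

patch_optimized_module()

# After patching: state_dict()/parameters()/__class__ and friends are
# forwarded to the wrapped module, so the compiled model looks like the
# original to callers that match weights by key or inspect the class.
print(list(compiled.state_dict()))      # ['weight', 'bias']
print(compiled.__class__ is nn.Linear)  # True
```

The `_patched` flag makes the call idempotent, so `patch()` in `misc_nodes.py` can invoke it unconditionally on every workflow run.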