From e8e3de8968b9a6eb1769d60a0cb90d34b1653907 Mon Sep 17 00:00:00 2001
From: chengzeyi <ichengzeyi@gmail.com>
Date: Wed, 8 Jan 2025 19:38:00 +0800
Subject: [PATCH] fix lora with compile

---
 README.md      |  4 +++-
 misc_nodes.py  |  2 ++
 pyproject.toml |  2 +-
 utils.py       | 45 +++++++++++++++++++++++++++++++++++++++++++++
 4 files changed, 51 insertions(+), 2 deletions(-)

diff --git a/README.md b/README.md
index d5955c8..f8effec 100644
--- a/README.md
+++ b/README.md
@@ -51,9 +51,11 @@ See [Apply First Block Cache on FLUX.1-dev](https://github.com/chengzeyi/ParaAtt
 
 To use the Enhanced `torch.compile`, simply add the `wavespeed->Compile Model+` node to your workflow after your `Load Diffusion Model` node or `Apply First Block Cache` node. The compilation process happens the first time you run the workflow, and it takes quite a long time, but it will be cached for future runs. You can pass different `mode` values to make it runs faster, for example `max-autotune` or `max-autotune-no-cudagraphs`.
 
+One of the advantages of this node over the original `TorchCompileModel` node is that it works with LoRA.
+
 **NOTE**: `torch.compile` might not be able to work with model offloading well, you could try passing `--gpu-only` when launching your `ComfyUI` to disable model offloading.
-**NOTE**: `torch.compile` does not work on Windows offcially and has problems working with LoRAs, you should not use this node if you are facing these issues.
+
+**NOTE**: `torch.compile` does not officially support Windows; if you run into this, either skip this node or search online for how to make it work.
 
 ![Usage of Enhanced `torch.compile`](./assets/usage_compile.png)
 
diff --git a/misc_nodes.py b/misc_nodes.py
index 5da28b7..75e4ce1 100644
--- a/misc_nodes.py
+++ b/misc_nodes.py
@@ -109,6 +109,8 @@ def patch(
         disable,
         backend,
     ):
+        utils.patch_optimized_module()
+
         import_path, function_name = compiler.rsplit(".", 1)
         module = importlib.import_module(import_path)
         compile_function = getattr(module, function_name)
diff --git a/pyproject.toml b/pyproject.toml
index 1d41bdf..8d5ecc1 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,7 +1,7 @@
 [project]
 name = "wavespeed"
 description = ""
-version = "1.0.6"
+version = "1.0.7"
 license = {file = "LICENSE"}
 
 [project.urls]
diff --git a/utils.py b/utils.py
index 383bc9f..c6ffee8 100644
--- a/utils.py
+++ b/utils.py
@@ -57,3 +57,48 @@ def foo(*args, **kwargs):
 
     with unittest.mock.patch.object(model_management, "load_models_gpu", foo):
         yield
+
+
+def patch_optimized_module():
+    try:
+        from torch._dynamo.eval_frame import OptimizedModule
+    except ImportError:
+        return
+
+    if getattr(OptimizedModule, "_patched", False):
+        return
+
+    def __getattribute__(self, name):
+        if name == "_orig_mod":
+            return object.__getattribute__(self, "_modules")[name]
+        if name in (
+            "__class__",
+            "_modules",
+            "state_dict",
+            "load_state_dict",
+            "parameters",
+            "named_parameters",
+            "buffers",
+            "named_buffers",
+            "children",
+            "named_children",
+            "modules",
+            "named_modules",
+        ):
+            return getattr(object.__getattribute__(self, "_orig_mod"), name)
+        return object.__getattribute__(self, name)
+
+    def __delattr__(self, name):
+        # unload_lora_weights() wants to del peft_config
+        return delattr(self._orig_mod, name)
+
+    @classmethod
+    def __instancecheck__(cls, instance):
+        return isinstance(instance, OptimizedModule) or issubclass(
+            object.__getattribute__(instance, "__class__"), cls
+        )
+
+    OptimizedModule.__getattribute__ = __getattribute__
+    OptimizedModule.__delattr__ = __delattr__
+    OptimizedModule.__instancecheck__ = __instancecheck__
+    OptimizedModule._patched = True
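
For reviewers, a minimal sketch (not part of the patch) of what `patch_optimized_module` changes at runtime. It assumes torch >= 2.0 and that the patched `utils.py` above is importable (e.g. with ComfyUI on the import path), and it uses a toy `nn.Linear` in place of a real diffusion model. `torch.compile` wraps a module in an `OptimizedModule` whose `state_dict()` keys are all prefixed with `_orig_mod.`; ComfyUI applies LoRA by matching keys in the model's `state_dict()`, so the wrapper hides every weight. After the patch, those lookups are forwarded to the wrapped module:

```python
# Sketch only: toy model instead of a diffusion model; assumes the patched
# utils.py (providing patch_optimized_module) is importable and torch >= 2.0.
import torch
import torch.nn as nn

from utils import patch_optimized_module

model = nn.Linear(4, 4)
compiled = torch.compile(model)  # returns an OptimizedModule wrapper

# Before patching: keys are namespaced under "_orig_mod.", so key-based
# weight patching (as ComfyUI's LoRA loader does) cannot find them.
print(list(compiled.state_dict()))  # ['_orig_mod.weight', '_orig_mod.bias']

patch_optimized_module()

# After patching: state_dict()/parameters()/__class__ and friends are
# forwarded to the wrapped module, so the compiled model looks like the
# original to callers that match weights by key or inspect the class.
print(list(compiled.state_dict()))      # ['weight', 'bias']
print(compiled.__class__ is nn.Linear)  # True
```

The `_patched` flag makes the call idempotent, so `patch()` in `misc_nodes.py` can invoke it unconditionally on every workflow run.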