fix lora with compile
chengzeyi committed Jan 8, 2025
1 parent 93a2387 · commit e8e3de8
Showing 4 changed files with 51 additions and 2 deletions.
4 changes: 3 additions & 1 deletion README.md
@@ -51,9 +51,11 @@ See [Apply First Block Cache on FLUX.1-dev](https://github.com/chengzeyi/ParaAtt
 To use the Enhanced `torch.compile`, simply add the `wavespeed->Compile Model+` node to your workflow after your `Load Diffusion Model` node or `Apply First Block Cache` node.
 The compilation process happens the first time you run the workflow; it takes quite a long time, but the result is cached for future runs.
 You can pass different `mode` values to make it run faster, for example `max-autotune` or `max-autotune-no-cudagraphs`.
+One of the advantages of this node over the original `TorchCompileModel` node is that it works with LoRA.
 
 **NOTE**: `torch.compile` might not work well with model offloading; try passing `--gpu-only` when launching `ComfyUI` to disable model offloading.
-**NOTE**: `torch.compile` does not work on Windows offcially and has problems working with LoRAs, you should not use this node if you are facing these issues.
+
+**NOTE**: `torch.compile` does not officially work on Windows; do not use this node if you hit such issues, or search for how to make it work on Windows.
 
 ![Usage of Enhanced `torch.compile`](./assets/usage_compile.png)
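
For intuition, here is a minimal, self-contained sketch (illustrative only, not this repo's code) of what compiling a model with a `mode` argument looks like in plain PyTorch; `ToyModel` is a stand-in for the loaded diffusion model:

```python
import torch
import torch.nn as nn

# Stand-in for the diffusion model the node would receive.
class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(64, 64)

    def forward(self, x):
        return self.proj(x)

model = ToyModel()

# `mode` trades a longer first-time compilation for faster subsequent runs;
# the compiled artifact is cached, so only the first execution pays the cost.
compiled = torch.compile(model, mode="max-autotune-no-cudagraphs")

out = compiled(torch.randn(1, 64))
```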

2 changes: 2 additions & 0 deletions misc_nodes.py
@@ -109,6 +109,8 @@ def patch(
     disable,
     backend,
 ):
+    utils.patch_optimized_module()
+
     import_path, function_name = compiler.rsplit(".", 1)
     module = importlib.import_module(import_path)
     compile_function = getattr(module, function_name)
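
The `utils.patch_optimized_module()` call added above runs before the compiler is resolved; the three lines after it locate the compile entry point from a dotted path. A self-contained sketch of that resolution pattern, using `"torch.compile"` as an illustrative value for `compiler`:

```python
import importlib

compiler = "torch.compile"  # illustrative value; the node receives this as input

# Split "torch.compile" into a module path and an attribute name,
# then import the module and look the attribute up dynamically.
import_path, function_name = compiler.rsplit(".", 1)
module = importlib.import_module(import_path)
compile_function = getattr(module, function_name)

assert callable(compile_function)
```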
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,7 +1,7 @@
 [project]
 name = "wavespeed"
 description = ""
-version = "1.0.6"
+version = "1.0.7"
 license = {file = "LICENSE"}
 
 [project.urls]
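
As a quick sanity check after updating, the installed version can be read back at runtime; a sketch assuming the package is actually pip-installed under the name `wavespeed` (for a plain ComfyUI custom-node checkout this may not apply):

```python
from importlib.metadata import version

# Hypothetical check; assumes "wavespeed" is installed as a package.
print(version("wavespeed"))  # e.g. "1.0.7"
```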
45 changes: 45 additions & 0 deletions utils.py
@@ -57,3 +57,48 @@ def foo(*args, **kwargs):

     with unittest.mock.patch.object(model_management, "load_models_gpu", foo):
         yield


+def patch_optimized_module():
+    try:
+        from torch._dynamo.eval_frame import OptimizedModule
+    except ImportError:
+        return
+
+    if getattr(OptimizedModule, "_patched", False):
+        return
+
+    def __getattribute__(self, name):
+        if name == "_orig_mod":
+            return object.__getattribute__(self, "_modules")[name]
+        if name in (
+            "__class__",
+            "_modules",
+            "state_dict",
+            "load_state_dict",
+            "parameters",
+            "named_parameters",
+            "buffers",
+            "named_buffers",
+            "children",
+            "named_children",
+            "modules",
+            "named_modules",
+        ):
+            return getattr(object.__getattribute__(self, "_orig_mod"), name)
+        return object.__getattribute__(self, name)
+
+    def __delattr__(self, name):
+        # unload_lora_weights() wants to del peft_config
+        return delattr(self._orig_mod, name)
+
+    @classmethod
+    def __instancecheck__(cls, instance):
+        return isinstance(instance, OptimizedModule) or issubclass(
+            object.__getattribute__(instance, "__class__"), cls
+        )
+
+    OptimizedModule.__getattribute__ = __getattribute__
+    OptimizedModule.__delattr__ = __delattr__
+    OptimizedModule.__instancecheck__ = __instancecheck__
+    OptimizedModule._patched = True
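
To see why this patch fixes LoRA, note that `torch.compile` wraps a module in `OptimizedModule` and stores the original under `_orig_mod`, so `state_dict()` keys normally gain a `_orig_mod.` prefix that weight-patching code does not expect. A minimal sketch, assuming `patch_optimized_module` is importable from the `utils.py` above:

```python
import torch
import torch.nn as nn

from utils import patch_optimized_module

model = nn.Linear(4, 4)
compiled = torch.compile(model)

# Before the patch, the wrapper reports prefixed keys:
print(list(compiled.state_dict().keys()))  # ['_orig_mod.weight', '_orig_mod.bias']

patch_optimized_module()

# After the patch, lookups such as `state_dict` are forwarded to the
# original module, so LoRA patching sees the expected key names:
print(list(compiled.state_dict().keys()))  # ['weight', 'bias']
```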
