From 1ba6b5fdd1ad92a69d3c9fe779e386c99b36aa09 Mon Sep 17 00:00:00 2001
From: chengzeyi
Date: Mon, 13 Jan 2025 14:32:24 +0800
Subject: [PATCH] Support FLUX ControlNet and fix missing cache context

---
 README.md                      |  14 +-
 fbcache_nodes.py               |  22 +-
 first_block_cache.py           | 218 +++++++-
 pyproject.toml                 |   2 +-
 workflows/flux_controlnet.json | 888 +++++++++++++++++++++++++++++++++
 5 files changed, 1131 insertions(+), 13 deletions(-)
 create mode 100644 workflows/flux_controlnet.json

diff --git a/README.md b/README.md
index ead539c..3f43f0f 100644
--- a/README.md
+++ b/README.md
@@ -33,13 +33,13 @@ git clone https://github.com/chengzeyi/Comfy-WaveSpeed.git
 
 You can find demo workflows in the `workflows` folder.
 
-[FLUX.1-dev with First Block Cache and Compilation](./workflows/flux.json)
-
-[LTXV with First Block Cache and Compilation](./workflows/ltxv.json)
-
-[HunyuanVideo with First Block Cache](./workflows/hunyuan_video.json)
-
-[SDXL with First Block Cache](./workflows/sdxl.json)
+| Workflow | Path |
+| - | - |
+| FLUX.1-dev with First Block Cache and Compilation | [workflows/flux.json](./workflows/flux.json) |
+| FLUX.1-dev ControlNet with First Block Cache and Compilation | [workflows/flux_controlnet.json](./workflows/flux_controlnet.json) |
+| LTXV with First Block Cache and Compilation | [workflows/ltxv.json](./workflows/ltxv.json) |
+| HunyuanVideo with First Block Cache | [workflows/hunyuan_video.json](./workflows/hunyuan_video.json) |
+| SDXL with First Block Cache | [workflows/sdxl.json](./workflows/sdxl.json) |
 
 **NOTE**: The `Compile Model+` node requires your environment to meet certain software and hardware requirements; see the [Enhanced `torch.compile`](#enhanced-torchcompile) section for details. If you have problems with the compilation node, remove it from the workflow and use only the `Apply First Block Cache` node.
 
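For context on what the two Python files below implement: first block cache runs the first transformer block on every step, compares its output with the output cached from the previous step, and skips the remaining blocks (reusing their cached residual) when the relative change is below `residual_diff_threshold`. A minimal sketch of that gate, assuming a mean relative L1 metric; `relative_l1_distance` and `can_use_cache` are illustrative names here, not this module's actual API:

```python
import torch

def relative_l1_distance(prev: torch.Tensor, curr: torch.Tensor) -> float:
    # Mean relative L1 distance between the cached and the current
    # first-block hidden states.
    return ((curr - prev).abs().mean() / prev.abs().mean()).item()

def can_use_cache(curr: torch.Tensor, prev: torch.Tensor | None,
                  threshold: float) -> bool:
    # Skip the remaining blocks only when the first block's output has
    # barely moved since the previous sampling step.
    if prev is None or threshold <= 0:
        return False
    return relative_l1_distance(prev, curr) < threshold

# Two nearly identical first-block outputs pass the 0.12 threshold that the
# ApplyFBCacheOnModel node uses in the new flux_controlnet.json workflow.
prev = torch.randn(1, 16, 64)
curr = prev + 0.01 * torch.randn_like(prev)
print(can_use_cache(curr, prev, threshold=0.12))  # True with high probability
```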
diff --git a/fbcache_nodes.py b/fbcache_nodes.py
index 47279c0..4aaeab8 100644
--- a/fbcache_nodes.py
+++ b/fbcache_nodes.py
@@ -121,9 +121,17 @@ def validate_use_cache(use_cached):
         model = model.clone()
         diffusion_model = model.get_model_object(object_to_patch)
 
-        if diffusion_model.__class__.__name__ == "UNetModel":
+        if diffusion_model.__class__.__name__ in ("UNetModel", "Flux"):
 
-            patch__forward = first_block_cache.create_patch_unet_model__forward(
+            if diffusion_model.__class__.__name__ == "UNetModel":
+                create_patch_function = first_block_cache.create_patch_unet_model__forward
+            elif diffusion_model.__class__.__name__ == "Flux":
+                create_patch_function = first_block_cache.create_patch_flux_forward_orig
+            else:
+                raise ValueError(
+                    f"Unsupported model {diffusion_model.__class__.__name__}")
+
+            patch_forward = create_patch_function(
                 diffusion_model,
                 residual_diff_threshold=residual_diff_threshold,
                 validate_can_use_cache_function=validate_use_cache,
@@ -144,7 +152,11 @@ def model_unet_function_wrapper(model_function, kwargs):
                     first_block_cache.set_current_cache_context(
                         first_block_cache.create_cache_context())
 
-                with patch__forward():
+                if first_block_cache.get_current_cache_context() is None:
+                    first_block_cache.set_current_cache_context(
+                        first_block_cache.create_cache_context())
+
+                with patch_forward():
                     return model_function(input, timestep, **c)
             except model_management.InterruptProcessingException as exc:
                 prev_timestep = None
@@ -226,6 +238,10 @@ def model_unet_function_wrapper(model_function, kwargs):
                     first_block_cache.set_current_cache_context(
                         first_block_cache.create_cache_context())
 
+                if first_block_cache.get_current_cache_context() is None:
+                    first_block_cache.set_current_cache_context(
+                        first_block_cache.create_cache_context())
+
                 with unittest.mock.patch.object(
                         diffusion_model,
                         double_blocks_name,
diff --git a/first_block_cache.py b/first_block_cache.py
index f764cba..783f060 100644
--- a/first_block_cache.py
+++ b/first_block_cache.py
@@ -371,6 +371,7 @@ def call_remaining_transformer_blocks(self,
     return hidden_states, encoder_hidden_states, hidden_states_residual, encoder_hidden_states_residual
 
 
+# Based on 90f349f93df3083a507854d7fc7c3e1bb9014e24
 def create_patch_unet_model__forward(model,
                                      *,
                                      residual_diff_threshold,
@@ -379,7 +380,7 @@ def create_patch_unet_model__forward(model,
 
     def call_remaining_blocks(self, transformer_options, control,
                               transformer_patches, hs, h, *args, **kwargs):
-        original_h = h
+        original_hidden_states = h
 
         for id, module in enumerate(self.input_blocks):
             if id < 2:
@@ -421,7 +422,7 @@ def call_remaining_blocks(self, transformer_options, control,
                 output_shape = None
             h = forward_timestep_embed(module, h, *args, output_shape, **kwargs)
 
-        hidden_states_residual = h - original_h
+        hidden_states_residual = h - original_hidden_states
         return h, hidden_states_residual
 
     def unet_model__forward(self,
@@ -546,3 +547,216 @@ def patch__forward():
         yield
 
     return patch__forward
+
+
+# Based on 90f349f93df3083a507854d7fc7c3e1bb9014e24
+def create_patch_flux_forward_orig(model,
+                                   *,
+                                   residual_diff_threshold,
+                                   validate_can_use_cache_function=None):
+    from torch import Tensor
+    from comfy.ldm.flux.model import timestep_embedding
+
+    def call_remaining_blocks(self, blocks_replace, control, img, txt, vec,
+                              pe, attn_mask):
+        original_hidden_states = img
+
+        for i, block in enumerate(self.double_blocks):
+            if i < 1:
+                continue
+            if ("double_block", i) in blocks_replace:
+
+                def block_wrap(args):
+                    out = {}
+                    out["img"], out["txt"] = block(
+                        img=args["img"],
txt=args["txt"], + vec=args["vec"], + pe=args["pe"], + attn_mask=args.get("attn_mask")) + return out + + out = blocks_replace[("double_block", + i)]({ + "img": img, + "txt": txt, + "vec": vec, + "pe": pe, + "attn_mask": attn_mask + }, { + "original_block": block_wrap + }) + txt = out["txt"] + img = out["img"] + else: + img, txt = block(img=img, + txt=txt, + vec=vec, + pe=pe, + attn_mask=attn_mask) + + if control is not None: # Controlnet + control_i = control.get("input") + if i < len(control_i): + add = control_i[i] + if add is not None: + img += add + + img = torch.cat((txt, img), 1) + + for i, block in enumerate(self.single_blocks): + if ("single_block", i) in blocks_replace: + + def block_wrap(args): + out = {} + out["img"] = block(args["img"], + vec=args["vec"], + pe=args["pe"], + attn_mask=args.get("attn_mask")) + return out + + out = blocks_replace[("single_block", + i)]({ + "img": img, + "vec": vec, + "pe": pe, + "attn_mask": attn_mask + }, { + "original_block": block_wrap + }) + img = out["img"] + else: + img = block(img, vec=vec, pe=pe, attn_mask=attn_mask) + + if control is not None: # Controlnet + control_o = control.get("output") + if i < len(control_o): + add = control_o[i] + if add is not None: + img[:, txt.shape[1]:, ...] += add + + img = img[:, txt.shape[1]:, ...] + + img = img.contiguous() + hidden_states_residual = img - original_hidden_states + return img, hidden_states_residual + + def forward_orig( + self, + img: Tensor, + img_ids: Tensor, + txt: Tensor, + txt_ids: Tensor, + timesteps: Tensor, + y: Tensor, + guidance: Tensor = None, + control=None, + transformer_options={}, + attn_mask: Tensor = None, + ) -> Tensor: + patches_replace = transformer_options.get("patches_replace", {}) + if img.ndim != 3 or txt.ndim != 3: + raise ValueError( + "Input img and txt tensors must have 3 dimensions.") + + # running on sequences img + img = self.img_in(img) + vec = self.time_in(timestep_embedding(timesteps, 256).to(img.dtype)) + if self.params.guidance_embed: + if guidance is None: + raise ValueError( + "Didn't get guidance strength for guidance distilled model." 
+ ) + vec = vec + self.guidance_in( + timestep_embedding(guidance, 256).to(img.dtype)) + + vec = vec + self.vector_in(y[:, :self.params.vec_in_dim]) + txt = self.txt_in(txt) + + ids = torch.cat((txt_ids, img_ids), dim=1) + pe = self.pe_embedder(ids) + + blocks_replace = patches_replace.get("dit", {}) + for i, block in enumerate(self.double_blocks): + if i >= 1: + break + if ("double_block", i) in blocks_replace: + + def block_wrap(args): + out = {} + out["img"], out["txt"] = block( + img=args["img"], + txt=args["txt"], + vec=args["vec"], + pe=args["pe"], + attn_mask=args.get("attn_mask")) + return out + + out = blocks_replace[("double_block", + i)]({ + "img": img, + "txt": txt, + "vec": vec, + "pe": pe, + "attn_mask": attn_mask + }, { + "original_block": block_wrap + }) + txt = out["txt"] + img = out["img"] + else: + img, txt = block(img=img, + txt=txt, + vec=vec, + pe=pe, + attn_mask=attn_mask) + + if control is not None: # Controlnet + control_i = control.get("input") + if i < len(control_i): + add = control_i[i] + if add is not None: + img += add + + if i == 0: + first_hidden_states_residual = img + can_use_cache = get_can_use_cache( + first_hidden_states_residual, + threshold=residual_diff_threshold, + ) + if validate_can_use_cache_function is not None: + can_use_cache = validate_can_use_cache_function(can_use_cache) + if not can_use_cache: + set_buffer("first_hidden_states_residual", + first_hidden_states_residual) + del first_hidden_states_residual + + torch._dynamo.graph_break() + if can_use_cache: + img = apply_prev_hidden_states_residual(img) + else: + img, hidden_states_residual = call_remaining_blocks( + self, + blocks_replace, + control, + img, + txt, + vec, + pe, + attn_mask, + ) + set_buffer("hidden_states_residual", hidden_states_residual) + torch._dynamo.graph_break() + + img = self.final_layer(img, + vec) # (N, T, patch_size ** 2 * out_channels) + return img + + new_forward_orig = forward_orig.__get__(model) + + @contextlib.contextmanager + def patch_forward_orig(): + with unittest.mock.patch.object(model, "forward_orig", new_forward_orig): + yield + + return patch_forward_orig diff --git a/pyproject.toml b/pyproject.toml index 7c59fae..c7a0d02 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [project] name = "wavespeed" description = "The all in one inference optimization solution for ComfyUI, universal, flexible, and fast." 
-version = "1.1.1" +version = "1.1.2" license = {file = "LICENSE"} [project.urls] diff --git a/workflows/flux_controlnet.json b/workflows/flux_controlnet.json new file mode 100644 index 0000000..b7bf98e --- /dev/null +++ b/workflows/flux_controlnet.json @@ -0,0 +1,888 @@ +{ + "last_node_id": 37, + "last_link_id": 65, + "nodes": [ + { + "id": 3, + "type": "KSampler", + "pos": [ + 1280, + 100 + ], + "size": [ + 315, + 262 + ], + "flags": {}, + "order": 15, + "mode": 0, + "inputs": [ + { + "name": "model", + "type": "MODEL", + "link": 65 + }, + { + "name": "positive", + "type": "CONDITIONING", + "link": 18 + }, + { + "name": "negative", + "type": "CONDITIONING", + "link": 19 + }, + { + "name": "latent_image", + "type": "LATENT", + "link": 54 + } + ], + "outputs": [ + { + "name": "LATENT", + "type": "LATENT", + "links": [ + 7 + ], + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "KSampler" + }, + "widgets_values": [ + 0, + "fixed", + 20, + 1, + "euler", + "normal", + 1 + ] + }, + { + "id": 7, + "type": "CLIPTextEncode", + "pos": [ + 212, + 417 + ], + "size": [ + 425.27801513671875, + 180.6060791015625 + ], + "flags": { + "collapsed": true + }, + "order": 9, + "mode": 0, + "inputs": [ + { + "name": "clip", + "type": "CLIP", + "link": 59 + } + ], + "outputs": [ + { + "name": "CONDITIONING", + "type": "CONDITIONING", + "links": [ + 17 + ], + "slot_index": 0 + } + ], + "title": "CLIP Text Encode (Negative Prompt)", + "properties": { + "Node name for S&R": "CLIPTextEncode" + }, + "widgets_values": [ + "" + ], + "color": "#322", + "bgcolor": "#533" + }, + { + "id": 8, + "type": "VAEDecode", + "pos": [ + 1620, + 98 + ], + "size": [ + 210, + 46 + ], + "flags": {}, + "order": 16, + "mode": 0, + "inputs": [ + { + "name": "samples", + "type": "LATENT", + "link": 7 + }, + { + "name": "vae", + "type": "VAE", + "link": 62 + } + ], + "outputs": [ + { + "name": "IMAGE", + "type": "IMAGE", + "links": [ + 9 + ], + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "VAEDecode" + }, + "widgets_values": [] + }, + { + "id": 9, + "type": "SaveImage", + "pos": [ + 1865, + 99 + ], + "size": [ + 828.9535522460938, + 893.8475341796875 + ], + "flags": {}, + "order": 17, + "mode": 0, + "inputs": [ + { + "name": "images", + "type": "IMAGE", + "link": 9 + } + ], + "outputs": [], + "properties": {}, + "widgets_values": [ + "ComfyUI" + ] + }, + { + "id": 14, + "type": "ControlNetApplySD3", + "pos": [ + 930, + 100 + ], + "size": [ + 315, + 186 + ], + "flags": {}, + "order": 14, + "mode": 0, + "inputs": [ + { + "name": "positive", + "type": "CONDITIONING", + "link": 42 + }, + { + "name": "negative", + "type": "CONDITIONING", + "link": 17 + }, + { + "name": "control_net", + "type": "CONTROL_NET", + "link": 52 + }, + { + "name": "vae", + "type": "VAE", + "link": 60 + }, + { + "name": "image", + "type": "IMAGE", + "link": 50 + } + ], + "outputs": [ + { + "name": "positive", + "type": "CONDITIONING", + "links": [ + 18 + ], + "slot_index": 0, + "shape": 3 + }, + { + "name": "negative", + "type": "CONDITIONING", + "links": [ + 19 + ], + "slot_index": 1, + "shape": 3 + } + ], + "properties": { + "Node name for S&R": "ControlNetApplySD3" + }, + "widgets_values": [ + 0.4, + 0, + 1 + ] + }, + { + "id": 15, + "type": "ControlNetLoader", + "pos": [ + 570, + -60 + ], + "size": [ + 315, + 58 + ], + "flags": {}, + "order": 0, + "mode": 0, + "inputs": [], + "outputs": [ + { + "name": "CONTROL_NET", + "type": "CONTROL_NET", + "links": [ + 52 + ], + "slot_index": 0, + "shape": 3 + } + ], + "properties": { + 
"Node name for S&R": "ControlNetLoader" + }, + "widgets_values": [ + "instantx_flux_canny.safetensors" + ] + }, + { + "id": 17, + "type": "LoadImage", + "pos": [ + 220, + 530 + ], + "size": [ + 315, + 314.0000305175781 + ], + "flags": {}, + "order": 1, + "mode": 0, + "inputs": [], + "outputs": [ + { + "name": "IMAGE", + "type": "IMAGE", + "links": [ + 49 + ], + "slot_index": 0, + "shape": 3 + }, + { + "name": "MASK", + "type": "MASK", + "links": null, + "shape": 3 + } + ], + "properties": { + "Node name for S&R": "LoadImage" + }, + "widgets_values": [ + "girl_in_field.png", + "image" + ] + }, + { + "id": 18, + "type": "Canny", + "pos": [ + 560, + 530 + ], + "size": [ + 315, + 82 + ], + "flags": {}, + "order": 7, + "mode": 0, + "inputs": [ + { + "name": "image", + "type": "IMAGE", + "link": 49 + } + ], + "outputs": [ + { + "name": "IMAGE", + "type": "IMAGE", + "links": [ + 26, + 50 + ], + "slot_index": 0, + "shape": 3 + } + ], + "properties": { + "Node name for S&R": "Canny" + }, + "widgets_values": [ + 0.2, + 0.3 + ] + }, + { + "id": 19, + "type": "PreviewImage", + "pos": [ + 900, + 530 + ], + "size": [ + 571.5869140625, + 625.5296020507812 + ], + "flags": {}, + "order": 11, + "mode": 0, + "inputs": [ + { + "name": "images", + "type": "IMAGE", + "link": 26 + } + ], + "outputs": [], + "properties": { + "Node name for S&R": "PreviewImage" + }, + "widgets_values": [] + }, + { + "id": 23, + "type": "CLIPTextEncode", + "pos": [ + 210, + 196 + ], + "size": [ + 422.84503173828125, + 164.31304931640625 + ], + "flags": {}, + "order": 10, + "mode": 0, + "inputs": [ + { + "name": "clip", + "type": "CLIP", + "link": 61 + } + ], + "outputs": [ + { + "name": "CONDITIONING", + "type": "CONDITIONING", + "links": [ + 41 + ], + "slot_index": 0 + } + ], + "title": "CLIP Text Encode (Positive Prompt)", + "properties": { + "Node name for S&R": "CLIPTextEncode" + }, + "widgets_values": [ + "anime girl smiling with long hair standing in a football arena with a single massive sword hanging from her back" + ], + "color": "#232", + "bgcolor": "#353" + }, + { + "id": 26, + "type": "FluxGuidance", + "pos": [ + 570, + 50 + ], + "size": [ + 317.4000244140625, + 58 + ], + "flags": {}, + "order": 13, + "mode": 0, + "inputs": [ + { + "name": "conditioning", + "type": "CONDITIONING", + "link": 41 + } + ], + "outputs": [ + { + "name": "CONDITIONING", + "type": "CONDITIONING", + "links": [ + 42 + ], + "slot_index": 0, + "shape": 3 + } + ], + "properties": { + "Node name for S&R": "FluxGuidance" + }, + "widgets_values": [ + 3.5 + ] + }, + { + "id": 28, + "type": "EmptySD3LatentImage", + "pos": [ + 930, + 340 + ], + "size": [ + 315, + 106 + ], + "flags": {}, + "order": 2, + "mode": 0, + "inputs": [], + "outputs": [ + { + "name": "LATENT", + "type": "LATENT", + "links": [ + 54 + ], + "slot_index": 0, + "shape": 3 + } + ], + "properties": { + "Node name for S&R": "EmptySD3LatentImage" + }, + "widgets_values": [ + 1024, + 1024, + 1 + ] + }, + { + "id": 32, + "type": "VAELoader", + "pos": [ + -180, + 230 + ], + "size": [ + 311.81634521484375, + 60.429901123046875 + ], + "flags": {}, + "order": 6, + "mode": 0, + "inputs": [], + "outputs": [ + { + "name": "VAE", + "type": "VAE", + "links": [ + 60, + 62 + ], + "slot_index": 0, + "shape": 3 + } + ], + "properties": { + "Node name for S&R": "VAELoader" + }, + "widgets_values": [ + "ae.safetensors" + ] + }, + { + "id": 33, + "type": "Note", + "pos": [ + -180, + 380 + ], + "size": [ + 336, + 288 + ], + "flags": {}, + "order": 3, + "mode": 0, + "inputs": [], + "outputs": [], + 
"properties": { + "text": "" + }, + "widgets_values": [ + "If you get an error in any of the nodes above make sure the files are in the correct directories.\n\nSee the top of the examples page for the links : https://comfyanonymous.github.io/ComfyUI_examples/flux/\n\nflux1-dev.safetensors goes in: ComfyUI/models/unet/\n\nt5xxl_fp16.safetensors and clip_l.safetensors go in: ComfyUI/models/clip/\n\nae.safetensors goes in: ComfyUI/models/vae/\n\n\nTip: You can set the weight_dtype above to one of the fp8 types if you have memory issues." + ], + "color": "#432", + "bgcolor": "#653" + }, + { + "id": 34, + "type": "UNETLoader", + "pos": [ + -180, + -60 + ], + "size": [ + 315, + 82 + ], + "flags": {}, + "order": 4, + "mode": 0, + "inputs": [], + "outputs": [ + { + "name": "MODEL", + "type": "MODEL", + "links": [ + 63 + ], + "slot_index": 0, + "shape": 3 + } + ], + "properties": { + "Node name for S&R": "UNETLoader" + }, + "widgets_values": [ + "flux1-dev.safetensors", + "fp8_e4m3fn_fast" + ], + "color": "#223", + "bgcolor": "#335" + }, + { + "id": 35, + "type": "DualCLIPLoader", + "pos": [ + -180, + 90 + ], + "size": [ + 315, + 106 + ], + "flags": {}, + "order": 5, + "mode": 0, + "inputs": [], + "outputs": [ + { + "name": "CLIP", + "type": "CLIP", + "links": [ + 59, + 61 + ], + "slot_index": 0, + "shape": 3 + } + ], + "properties": { + "Node name for S&R": "DualCLIPLoader" + }, + "widgets_values": [ + "t5xxl_fp8_e4m3fn.safetensors", + "clip_l.safetensors", + "flux", + "default" + ] + }, + { + "id": 36, + "type": "ApplyFBCacheOnModel", + "pos": [ + 200, + -160 + ], + "size": [ + 315, + 154 + ], + "flags": {}, + "order": 8, + "mode": 0, + "inputs": [ + { + "name": "model", + "type": "MODEL", + "link": 63 + } + ], + "outputs": [ + { + "name": "MODEL", + "type": "MODEL", + "links": [ + 64 + ], + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "ApplyFBCacheOnModel" + }, + "widgets_values": [ + "diffusion_model", + 0.12, + 0, + 1, + -1 + ] + }, + { + "id": 37, + "type": "EnhancedCompileModel", + "pos": [ + 560, + -410 + ], + "size": [ + 400, + 294 + ], + "flags": {}, + "order": 12, + "mode": 0, + "inputs": [ + { + "name": "model", + "type": "*", + "link": 64 + } + ], + "outputs": [ + { + "name": "*", + "type": "*", + "links": [ + 65 + ], + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "EnhancedCompileModel" + }, + "widgets_values": [ + true, + "diffusion_model", + "torch.compile", + false, + false, + "", + "", + false, + "inductor" + ] + } + ], + "links": [ + [ + 7, + 3, + 0, + 8, + 0, + "LATENT" + ], + [ + 9, + 8, + 0, + 9, + 0, + "IMAGE" + ], + [ + 17, + 7, + 0, + 14, + 1, + "CONDITIONING" + ], + [ + 18, + 14, + 0, + 3, + 1, + "CONDITIONING" + ], + [ + 19, + 14, + 1, + 3, + 2, + "CONDITIONING" + ], + [ + 26, + 18, + 0, + 19, + 0, + "IMAGE" + ], + [ + 41, + 23, + 0, + 26, + 0, + "CONDITIONING" + ], + [ + 42, + 26, + 0, + 14, + 0, + "CONDITIONING" + ], + [ + 49, + 17, + 0, + 18, + 0, + "IMAGE" + ], + [ + 50, + 18, + 0, + 14, + 4, + "IMAGE" + ], + [ + 52, + 15, + 0, + 14, + 2, + "CONTROL_NET" + ], + [ + 54, + 28, + 0, + 3, + 3, + "LATENT" + ], + [ + 59, + 35, + 0, + 7, + 0, + "CLIP" + ], + [ + 60, + 32, + 0, + 14, + 3, + "VAE" + ], + [ + 61, + 35, + 0, + 23, + 0, + "CLIP" + ], + [ + 62, + 32, + 0, + 8, + 1, + "VAE" + ], + [ + 63, + 34, + 0, + 36, + 0, + "MODEL" + ], + [ + 64, + 36, + 0, + 37, + 0, + "*" + ], + [ + 65, + 37, + 0, + 3, + 0, + "MODEL" + ] + ], + "groups": [], + "config": {}, + "extra": { + "ds": { + "scale": 0.5131581182307068, + "offset": [ + 
230.9978013084971, + 284.1700529197747 + ] + } + }, + "version": 0.4 +} \ No newline at end of file
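A final note on the patching mechanism in `first_block_cache.py` above: `create_patch_flux_forward_orig` binds the replacement with `forward_orig.__get__(model)` and installs it through `unittest.mock.patch.object` inside a context manager, so the original `forward_orig` is restored the moment sampling returns. A self-contained toy sketch of that pattern, assuming a stand-in `ToyModel` (the doubled return value is illustrative and stands in for the cache logic):

```python
import contextlib
import unittest.mock

class ToyModel:
    def forward_orig(self, x):
        return x + 1  # stand-in for the real transformer forward

def create_patch(model):
    def new_forward_orig(self, x):
        # The real patch puts the first-block-cache logic here; this toy
        # version just doubles the unpatched output.
        return ToyModel.forward_orig(self, x) * 2

    # Bind the plain function to this instance, mirroring
    # forward_orig.__get__(model) in the patch above.
    bound = new_forward_orig.__get__(model)

    @contextlib.contextmanager
    def patch_forward_orig():
        # Swap forward_orig only while the context manager is active;
        # mock.patch.object restores the original attribute on exit.
        with unittest.mock.patch.object(model, "forward_orig", bound):
            yield

    return patch_forward_orig

model = ToyModel()
patch = create_patch(model)
with patch():
    assert model.forward_orig(1) == 4  # patched: (1 + 1) * 2
assert model.forward_orig(1) == 2      # original behavior restored
```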