From 50d42c30b56dbc767d76738c6f58d39194786cf4 Mon Sep 17 00:00:00 2001 From: "Dr.Lt.Data" Date: Thu, 29 Aug 2024 01:35:27 +0900 Subject: [PATCH] Upgrade LoRA Block Weight - support % syntax --- __init__.py | 2 +- inspire/lora_block_weight.py | 180 +++++++++++++++++++++++++++++++---- pyproject.toml | 2 +- 3 files changed, 166 insertions(+), 18 deletions(-) diff --git a/__init__.py b/__init__.py index 36746e5..a84d356 100644 --- a/__init__.py +++ b/__init__.py @@ -7,7 +7,7 @@ import importlib -version_code = [0, 86, 2] +version_code = [1, 0] version_str = f"V{version_code[0]}.{version_code[1]}" + (f'.{version_code[2]}' if len(version_code) > 2 else '') print(f"### Loading: ComfyUI-Inspire-Pack ({version_str})") diff --git a/inspire/lora_block_weight.py b/inspire/lora_block_weight.py index 8d9a7d3..f74a241 100644 --- a/inspire/lora_block_weight.py +++ b/inspire/lora_block_weight.py @@ -1,3 +1,5 @@ +import regex + import folder_paths import comfy.utils import comfy.lora @@ -34,6 +36,13 @@ def load_lbw_preset(filename): return [] +def parse_unet_num(s): + if s[1] == '.': + return int(s[0]) + else: + return int(s) + + class LoraLoaderBlockWeight: def __init__(self): self.loaded_lora = None @@ -55,8 +64,8 @@ def INPUT_TYPES(s): "lora_name": (lora_names, ), "strength_model": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}), "strength_clip": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}), - "inverse": ("BOOLEAN", {"default": False, "label_on": "True", "label_off": "False"}), - "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}), + "inverse": ("BOOLEAN", {"default": False, "label_on": "True", "label_off": "False", "tooltip": "Apply the following weights for each block:\nTrue: 1 - weight\nFalse: weight"}), + "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff, "tooltip": ""}), "A": ("FLOAT", {"default": 4.0, "min": -10.0, "max": 10.0, "step": 0.01}), "B": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}), "preset": (preset,), @@ -129,12 +138,151 @@ def norm_value(value): # make to int if 1.0 or 0.0 else: return value + @staticmethod + def block_spec_parser(loaded, spec): + if not spec.startswith("%"): + return spec + else: + items = [x.strip() for x in spec[1:].split(',')] + + input_blocks_set = set() + middle_blocks_set= set() + output_blocks_set = set() + double_blocks_set = set() + single_blocks_set = set() + + for key, v in loaded.items(): + if isinstance(key, tuple): + k = key[0] + else: + k = key + + k_unet = k[len("diffusion_model."):] + + if k_unet.startswith("input_blocks."): + k_unet_num = k_unet[len("input_blocks."):len("input_blocks.")+2] + k_unet_int = parse_unet_num(k_unet_num) + input_blocks_set.add(k_unet_int) + elif k_unet.startswith("middle_block."): + k_unet_num = k_unet[len("middle_block."):len("middle_block.")+2] + k_unet_int = parse_unet_num(k_unet_num) + middle_blocks_set.add(k_unet_int) + elif k_unet.startswith("output_blocks."): + k_unet_num = k_unet[len("output_blocks."):len("output_blocks.")+2] + k_unet_int = parse_unet_num(k_unet_num) + output_blocks_set.add(k_unet_int) + elif k_unet.startswith("double_blocks."): + k_unet_num = k_unet[len("double_blocks."):len("double_blocks.") + 2] + k_unet_int = parse_unet_num(k_unet_num) + double_blocks_set.add(k_unet_int) + elif k_unet.startswith("single_blocks."): + k_unet_num = k_unet[len("single_blocks."):len("single_blocks.") + 2] + k_unet_int = parse_unet_num(k_unet_num) + single_blocks_set.add(k_unet_int) + + pat1 = re.compile(r"(default|base)=([0-9.]+)") + pat2 = re.compile(r"(in|out|mid|double|single)([0-9]+)-([0-9]+)=([0-9.]+)") + pat3 = re.compile(r"(in|out|mid|double|single)([0-9]+)=([0-9.]+)") + pat4 = re.compile(r"(in|out|mid|double|single)=([0-9.]+)") + + base_spec = None + default_spec = 1.0 + + for item in items: + match = pat1.match(item) + if match: + if match[1] == 'base': + base_spec = match[2] + continue + + if match[1] == 'default': + default_spec = match[2] + continue + + if base_spec is None: + base_spec = default_spec + + input_blocks = [default_spec] * len(input_blocks_set) + middle_blocks = [default_spec] * len(middle_blocks_set) + output_blocks = [default_spec] * len(output_blocks_set) + double_blocks = [default_spec] * len(double_blocks_set) + single_blocks = [default_spec] * len(single_blocks_set) + + for item in items: + match = pat2.match(item) + if match: + for x in range(int(match[2])-1, int(match[3])): + value = float(match[4]) + + if x < 0: + continue + + if match[1] == 'in' and len(input_blocks) > x: + input_blocks[x] = value + elif match[1] == 'out' and len(output_blocks) > x: + output_blocks[x] = value + elif match[1] == 'mid' and len(middle_blocks) > x: + middle_blocks[x] = value + elif match[1] == 'double' and len(double_blocks) > x: + double_blocks[x] = value + elif match[1] == 'single' and len(single_blocks) > x: + single_blocks[x] = value + + continue + + match = pat3.match(item) + if match: + value = float(match[3]) + x = int(match[2]) - 1 + + if x < 0: + continue + + if match[1] == 'in' and len(input_blocks) > x: + input_blocks[x] = value + elif match[1] == 'out' and len(output_blocks) > x: + output_blocks[x] = value + elif match[1] == 'mid' and len(middle_blocks) > x: + middle_blocks[x] = value + elif match[1] == 'double' and len(double_blocks) > x: + double_blocks[x] = value + elif match[1] == 'single' and len(single_blocks) > x: + single_blocks[x] = value + + continue + + match = pat4.match(item) + if match: + value = float(match[2]) + + if match[1] == 'in': + input_blocks = [value] * len(input_blocks) + elif match[1] == 'out': + output_blocks = [value] * len(output_blocks) + elif match[1] == 'mid': + middle_blocks = [value] * len(middle_blocks) + elif match[1] == 'double': + double_blocks = [value] * len(double_blocks) + elif match[1] == 'single': + single_blocks = [value] * len(single_blocks) + + continue + + # concat specs + res = [str(base_spec)] + for x in (input_blocks + middle_blocks + output_blocks + double_blocks + single_blocks): + res.append(str(x)) + + return ",".join(res) + @staticmethod def load_lora_for_models(model, clip, lora, strength_model, strength_clip, inverse, seed, A, B, block_vector): key_map = comfy.lora.model_lora_keys_unet(model.model) key_map = comfy.lora.model_lora_keys_clip(clip.cond_stage_model, key_map) loaded = comfy.lora.load_lora(lora, key_map) + block_vector = LoraLoaderBlockWeight.block_spec_parser(loaded, block_vector) + block_vector = block_vector.split(":") if len(block_vector) > 1: block_vector = block_vector[1] @@ -153,13 +301,6 @@ def load_lora_for_models(model, clip, lora, strength_model, strength_clip, inver last_k_unet_num = None new_modelpatcher = model.clone() - populated_ratio = strength_model - - def parse_unet_num(s): - if s[1] == '.': - return int(s[0]) - else: - return int(s) # sort: input, middle, output, others input_blocks = [] @@ -204,6 +345,8 @@ def parse_unet_num(s): np.random.seed(seed % (2**31)) populated_vector_list = [] ratios = [] + ratio = 1.0 + for k, v, k_unet_num, k_unet in (input_blocks + middle_blocks + output_blocks + double_blocks + single_blocks): if last_k_unet_num != k_unet_num and len(vector) > vector_i: ratios = LoraLoaderBlockWeight.convert_vector_value(A, B, vector[vector_i].strip()) @@ -220,6 +363,8 @@ def parse_unet_num(s): else: if len(ratios) > 0: ratio = ratios.pop(0) + else: + pass # use last used ratio if no more user specified ratio is given if inverse: populated_ratio = 1 - ratio @@ -243,25 +388,28 @@ def parse_unet_num(s): if inverse: populated_ratio = 1 - ratio else: - populated_ratio = 1 + populated_ratio = ratio populated_vector_list.insert(0, LoraLoaderBlockWeight.norm_value(populated_ratio)) + new_clip = clip.clone() for k, v, k_unet in others: - new_modelpatcher.add_patches({k: v}, strength_model * populated_ratio) + if 'text' in k_unet: + new_clip.add_patches({k: v}, strength_clip * populated_ratio) + else: + new_modelpatcher.add_patches({k: v}, strength_model * populated_ratio) + # if inverse: # print(f"\t{k_unet} -> inv({ratio}) ") # else: # print(f"\t{k_unet} -> ({ratio}) ") - new_clip = clip.clone() - new_clip.add_patches(loaded, strength_clip) populated_vector = ','.join(map(str, populated_vector_list)) - return (new_modelpatcher, new_clip, populated_vector) + return new_modelpatcher, new_clip, populated_vector def doit(self, model, clip, lora_name, strength_model, strength_clip, inverse, seed, A, B, preset, block_vector, bypass=False, category_filter=None): if strength_model == 0 and strength_clip == 0 or bypass: - return (model, clip, "") + return model, clip, "" lora_path = folder_paths.get_full_path("loras", lora_name) lora = None @@ -278,7 +426,7 @@ def doit(self, model, clip, lora_name, strength_model, strength_clip, inverse, s self.loaded_lora = (lora_path, lora) model_lora, clip_lora, populated_vector = LoraLoaderBlockWeight.load_lora_for_models(model, clip, lora, strength_model, strength_clip, inverse, seed, A, B, block_vector) - return (model_lora, clip_lora, populated_vector) + return model_lora, clip_lora, populated_vector class XY_Capsule_LoraBlockWeight: diff --git a/pyproject.toml b/pyproject.toml index 2f0d779..d8e2d63 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [project] name = "comfyui-inspire-pack" description = "This extension provides various nodes to support Lora Block Weight and the Impact Pack. Provides many easily applicable regional features and applications for Variation Seed." -version = "0.86.2" +version = "1.0" license = { file = "LICENSE" } dependencies = ["matplotlib", "cachetools"]