Upgrade LoRA Block Weight
- support % syntax
ltdrdata committed Aug 28, 2024
1 parent 89e24b7 commit 50d42c3
Showing 3 changed files with 166 additions and 18 deletions.
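
The headline change: the block_vector input now accepts a `%`-prefixed spec that names block groups instead of listing every weight positionally. A hedged illustration (group names come from the regex patterns added below; the weight values are made up):

%base=1, default=0, in4-8=0.5, mid=1, out3=0.2
%default=0.5, double=1, single=0.3

`in`/`mid`/`out` address UNet-style blocks; `double`/`single` address models that expose double_blocks/single_blocks (e.g. Flux). `base` sets the base weight, `default` the fill value for blocks not otherwise mentioned, and ranges such as `in4-8` are 1-based and inclusive.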
2 changes: 1 addition & 1 deletion __init__.py
@@ -7,7 +7,7 @@

import importlib

-version_code = [0, 86, 2]
+version_code = [1, 0]
version_str = f"V{version_code[0]}.{version_code[1]}" + (f'.{version_code[2]}' if len(version_code) > 2 else '')
print(f"### Loading: ComfyUI-Inspire-Pack ({version_str})")

180 changes: 164 additions & 16 deletions inspire/lora_block_weight.py
@@ -1,3 +1,5 @@
import re

import folder_paths
import comfy.utils
import comfy.lora
@@ -34,6 +36,13 @@ def load_lbw_preset(filename):
return []


def parse_unet_num(s):
if s[1] == '.':
return int(s[0])
else:
return int(s)
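# Note: parse_unet_num always receives a two-character slice such as "1." or
# "10"; a trailing dot therefore marks a single-digit block index.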


class LoraLoaderBlockWeight:
def __init__(self):
self.loaded_lora = None
@@ -55,8 +64,8 @@ def INPUT_TYPES(s):
"lora_name": (lora_names, ),
"strength_model": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}),
"strength_clip": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}),
"inverse": ("BOOLEAN", {"default": False, "label_on": "True", "label_off": "False"}),
"seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
"inverse": ("BOOLEAN", {"default": False, "label_on": "True", "label_off": "False", "tooltip": "Apply the following weights for each block:\nTrue: 1 - weight\nFalse: weight"}),
"seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff, "tooltip": ""}),
"A": ("FLOAT", {"default": 4.0, "min": -10.0, "max": 10.0, "step": 0.01}),
"B": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}),
"preset": (preset,),
@@ -129,12 +138,151 @@ def norm_value(value):  # make to int if 1.0 or 0.0
else:
return value

@staticmethod
def block_spec_parser(loaded, spec):
if not spec.startswith("%"):
return spec
else:
items = [x.strip() for x in spec[1:].split(',')]

input_blocks_set = set()
middle_blocks_set = set()
output_blocks_set = set()
double_blocks_set = set()
single_blocks_set = set()

for key, v in loaded.items():
if isinstance(key, tuple):
k = key[0]
else:
k = key

k_unet = k[len("diffusion_model."):]

if k_unet.startswith("input_blocks."):
k_unet_num = k_unet[len("input_blocks."):len("input_blocks.")+2]
k_unet_int = parse_unet_num(k_unet_num)
input_blocks_set.add(k_unet_int)
elif k_unet.startswith("middle_block."):
k_unet_num = k_unet[len("middle_block."):len("middle_block.")+2]
k_unet_int = parse_unet_num(k_unet_num)
middle_blocks_set.add(k_unet_int)
elif k_unet.startswith("output_blocks."):
k_unet_num = k_unet[len("output_blocks."):len("output_blocks.")+2]
k_unet_int = parse_unet_num(k_unet_num)
output_blocks_set.add(k_unet_int)
elif k_unet.startswith("double_blocks."):
k_unet_num = k_unet[len("double_blocks."):len("double_blocks.") + 2]
k_unet_int = parse_unet_num(k_unet_num)
double_blocks_set.add(k_unet_int)
elif k_unet.startswith("single_blocks."):
k_unet_num = k_unet[len("single_blocks."):len("single_blocks.") + 2]
k_unet_int = parse_unet_num(k_unet_num)
single_blocks_set.add(k_unet_int)

pat1 = re.compile(r"(default|base)=([0-9.]+)")
pat2 = re.compile(r"(in|out|mid|double|single)([0-9]+)-([0-9]+)=([0-9.]+)")
pat3 = re.compile(r"(in|out|mid|double|single)([0-9]+)=([0-9.]+)")
pat4 = re.compile(r"(in|out|mid|double|single)=([0-9.]+)")
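
# Grammar of a "%" spec, one comma-separated item per rule (derived from the
# patterns above):
#   base=<v> / default=<v>   base weight / fill value for unlisted blocks
#   <group><a>-<b>=<v>       1-based inclusive range, e.g. in4-8=0.5
#   <group><n>=<v>           single block, e.g. out3=0.2
#   <group>=<v>              whole group, e.g. mid=1
# where <group> is one of: in, out, mid, double, single.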

base_spec = None
default_spec = 1.0

for item in items:
match = pat1.match(item)
if match:
if match[1] == 'base':
base_spec = match[2]
continue

if match[1] == 'default':
default_spec = match[2]
continue

if base_spec is None:
base_spec = default_spec

input_blocks = [default_spec] * len(input_blocks_set)
middle_blocks = [default_spec] * len(middle_blocks_set)
output_blocks = [default_spec] * len(output_blocks_set)
double_blocks = [default_spec] * len(double_blocks_set)
single_blocks = [default_spec] * len(single_blocks_set)

for item in items:
match = pat2.match(item)
if match:
for x in range(int(match[2])-1, int(match[3])):
value = float(match[4])

if x < 0:
continue

if match[1] == 'in' and len(input_blocks) > x:
input_blocks[x] = value
elif match[1] == 'out' and len(output_blocks) > x:
output_blocks[x] = value
elif match[1] == 'mid' and len(middle_blocks) > x:
middle_blocks[x] = value
elif match[1] == 'double' and len(double_blocks) > x:
double_blocks[x] = value
elif match[1] == 'single' and len(single_blocks) > x:
single_blocks[x] = value

continue

match = pat3.match(item)
if match:
value = float(match[3])
x = int(match[2]) - 1

if x < 0:
continue

if match[1] == 'in' and len(input_blocks) > x:
input_blocks[x] = value
elif match[1] == 'out' and len(output_blocks) > x:
output_blocks[x] = value
elif match[1] == 'mid' and len(middle_blocks) > x:
middle_blocks[x] = value
elif match[1] == 'double' and len(double_blocks) > x:
double_blocks[x] = value
elif match[1] == 'single' and len(single_blocks) > x:
single_blocks[x] = value

continue

match = pat4.match(item)
if match:
value = float(match[2])

if match[1] == 'in':
input_blocks = [value] * len(input_blocks)
elif match[1] == 'out':
output_blocks = [value] * len(output_blocks)
elif match[1] == 'mid':
middle_blocks = [value] * len(middle_blocks)
elif match[1] == 'double':
double_blocks = [value] * len(double_blocks)
elif match[1] == 'single':
single_blocks = [value] * len(single_blocks)

continue

# concat specs
res = [str(base_spec)]
for x in (input_blocks + middle_blocks + output_blocks + double_blocks + single_blocks):
res.append(str(x))

return ",".join(res)

@staticmethod
def load_lora_for_models(model, clip, lora, strength_model, strength_clip, inverse, seed, A, B, block_vector):
key_map = comfy.lora.model_lora_keys_unet(model.model)
key_map = comfy.lora.model_lora_keys_clip(clip.cond_stage_model, key_map)
loaded = comfy.lora.load_lora(lora, key_map)

block_vector = LoraLoaderBlockWeight.block_spec_parser(loaded, block_vector)

block_vector = block_vector.split(":")
if len(block_vector) > 1:
block_vector = block_vector[1]
@@ -153,13 +301,6 @@ def load_lora_for_models(model, clip, lora, strength_model, strength_clip, inverse, seed, A, B, block_vector):

last_k_unet_num = None
new_modelpatcher = model.clone()
-populated_ratio = strength_model
-
-def parse_unet_num(s):
-if s[1] == '.':
-return int(s[0])
-else:
-return int(s)

# sort: input, middle, output, others
input_blocks = []
@@ -204,6 +345,8 @@ def parse_unet_num(s):
np.random.seed(seed % (2**31))
populated_vector_list = []
ratios = []
ratio = 1.0

for k, v, k_unet_num, k_unet in (input_blocks + middle_blocks + output_blocks + double_blocks + single_blocks):
if last_k_unet_num != k_unet_num and len(vector) > vector_i:
ratios = LoraLoaderBlockWeight.convert_vector_value(A, B, vector[vector_i].strip())
@@ -220,6 +363,8 @@
else:
if len(ratios) > 0:
ratio = ratios.pop(0)
else:
pass # use last used ratio if no more user specified ratio is given

if inverse:
populated_ratio = 1 - ratio
@@ -243,25 +388,28 @@ def parse_unet_num(s):
if inverse:
populated_ratio = 1 - ratio
else:
-populated_ratio = 1
+populated_ratio = ratio

populated_vector_list.insert(0, LoraLoaderBlockWeight.norm_value(populated_ratio))

+new_clip = clip.clone()
for k, v, k_unet in others:
-new_modelpatcher.add_patches({k: v}, strength_model * populated_ratio)
+if 'text' in k_unet:
+new_clip.add_patches({k: v}, strength_clip * populated_ratio)
+else:
+new_modelpatcher.add_patches({k: v}, strength_model * populated_ratio)
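# Keys containing 'text' belong to the text encoder and are patched into the
# CLIP clone with strength_clip; everything else goes to the model patcher.
# This replaces the previous wholesale new_clip.add_patches(loaded, strength_clip).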

# if inverse:
# print(f"\t{k_unet} -> inv({ratio}) ")
# else:
# print(f"\t{k_unet} -> ({ratio}) ")

-new_clip = clip.clone()
-new_clip.add_patches(loaded, strength_clip)
populated_vector = ','.join(map(str, populated_vector_list))
-return (new_modelpatcher, new_clip, populated_vector)
+return new_modelpatcher, new_clip, populated_vector

def doit(self, model, clip, lora_name, strength_model, strength_clip, inverse, seed, A, B, preset, block_vector, bypass=False, category_filter=None):
if strength_model == 0 and strength_clip == 0 or bypass:
-return (model, clip, "")
+return model, clip, ""

lora_path = folder_paths.get_full_path("loras", lora_name)
lora = None
@@ -278,7 +426,7 @@ def doit(self, model, clip, lora_name, strength_model, strength_clip, inverse, seed, A, B, preset, block_vector, bypass=False, category_filter=None):
self.loaded_lora = (lora_path, lora)

model_lora, clip_lora, populated_vector = LoraLoaderBlockWeight.load_lora_for_models(model, clip, lora, strength_model, strength_clip, inverse, seed, A, B, block_vector)
-return (model_lora, clip_lora, populated_vector)
+return model_lora, clip_lora, populated_vector


class XY_Capsule_LoraBlockWeight:
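
For reference, a minimal sketch (not part of the commit) of the new parser in isolation. The import path and the `loaded` keys are assumptions: the keys stand in for what comfy.lora.load_lora would return, and the snippet presumes a ComfyUI environment where the Inspire package is importable.

# Hypothetical smoke test for the "%" spec parser.
from inspire.lora_block_weight import LoraLoaderBlockWeight  # assumed path

# Stand-in key set: 4 input blocks, 1 middle block, 4 output blocks.
loaded = {f"diffusion_model.input_blocks.{i}.attn": None for i in range(4)}
loaded["diffusion_model.middle_block.1.attn"] = None
loaded.update({f"diffusion_model.output_blocks.{i}.attn": None for i in range(4)})

spec = "%base=1, default=0, in2-3=0.5, mid=1"
print(LoraLoaderBlockWeight.block_spec_parser(loaded, spec))
# -> "1,0,0.5,0.5,0,1.0,0,0,0,0"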
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,7 +1,7 @@
[project]
name = "comfyui-inspire-pack"
description = "This extension provides various nodes to support Lora Block Weight and the Impact Pack. Provides many easily applicable regional features and applications for Variation Seed."
version = "0.86.2"
version = "1.0"
license = { file = "LICENSE" }
dependencies = ["matplotlib", "cachetools"]
