From f8d8f48f63b414160f633e5ba6f22842a44f7758 Mon Sep 17 00:00:00 2001 From: s1dlx Date: Wed, 2 Aug 2023 17:56:14 +0100 Subject: [PATCH] will it merge? --- merge_models.py | 3 +++ pyproject.toml | 2 +- sd_meh/__init__.py | 2 +- sd_meh/merge.py | 32 +++++++++++++++++++++++++++----- sd_meh/rebasin.py | 6 ++---- sd_meh/utils.py | 17 +++++++++++------ 6 files changed, 45 insertions(+), 17 deletions(-) diff --git a/merge_models.py b/merge_models.py index 530cf01..524bd0d 100644 --- a/merge_models.py +++ b/merge_models.py @@ -111,6 +111,7 @@ type=click.Choice(["DEBUG", "INFO", "WARNING", "ERROR"], case_sensitive=False), default="INFO", ) +@click.option("-xl", "--sdxl", "sdxl", is_flag=True) def main( model_a, model_b, @@ -137,6 +138,7 @@ def main( presets_alpha_lambda, presets_beta_lambda, logging_level, + sdxl, ): if logging_level: logging.basicConfig(format="%(levelname)s: %(message)s", level=logging_level) @@ -157,6 +159,7 @@ def main( block_weights_preset_beta_b, presets_alpha_lambda, presets_beta_lambda, + sdxl, ) merged = merge_models( diff --git a/pyproject.toml b/pyproject.toml index b2d2cde..93cb8ee 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "sd-meh" -version = "0.9.4" +version = "0.10.0" description = "stable diffusion merging execution helper" authors = ["s1dlx "] license = "MIT" diff --git a/sd_meh/__init__.py b/sd_meh/__init__.py index e94731c..61fb31c 100644 --- a/sd_meh/__init__.py +++ b/sd_meh/__init__.py @@ -1 +1 @@ -__version__ = "0.9.4" +__version__ = "0.10.0" diff --git a/sd_meh/merge.py b/sd_meh/merge.py index 760391e..03b54b2 100644 --- a/sd_meh/merge.py +++ b/sd_meh/merge.py @@ -28,6 +28,10 @@ NUM_OUTPUT_BLOCKS = 12 NUM_TOTAL_BLOCKS = NUM_INPUT_BLOCKS + NUM_MID_BLOCK + NUM_OUTPUT_BLOCKS +NUM_INPUT_BLOCKS_XL = 9 +NUM_OUTPUT_BLOCKS_XL = 9 +NUM_TOTAL_BLOCKS_XL = NUM_INPUT_BLOCKS_XL + NUM_MID_BLOCK + NUM_OUTPUT_BLOCKS_XL + KEY_POSITION_IDS = ".".join( [ "cond_stage_model", @@ -144,6 +148,11 @@ def merge_models( ) -> Dict: thetas = load_thetas(models, prune, device, precision) + sdxl = ( + "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.weight" + in thetas["model_a"].keys() + ) + logging.info(f"start merging with {merge_mode} method") if re_basin: merged = rebasin_merge( @@ -157,6 +166,7 @@ def merge_models( device=device, work_device=work_device, threads=threads, + sdxl=sdxl, ) else: merged = simple_merge( @@ -169,6 +179,7 @@ def merge_models( device=device, work_device=work_device, threads=threads, + sdxl=sdxl, ) return un_prune_model(merged, thetas, models, device, prune, precision) @@ -221,6 +232,7 @@ def simple_merge( device: str = "cpu", work_device: Optional[str] = None, threads: int = 1, + sdxl: bool = False, ) -> Dict: futures = [] with tqdm(thetas["model_a"].keys(), desc="stage 1") as progress: @@ -238,6 +250,7 @@ def simple_merge( weights_clip, device, work_device, + sdxl, ) futures.append(future) @@ -270,6 +283,7 @@ def rebasin_merge( device="cpu", work_device=None, threads: int = 1, + sdxl: bool = False, ): # WARNING: not sure how this does when 3 models are involved... @@ -299,6 +313,7 @@ def rebasin_merge( device, work_device, threads, + sdxl, ) log_vram("simple merge done") @@ -367,6 +382,7 @@ def merge_key( weights_clip: bool = False, device: str = "cpu", work_device: Optional[str] = None, + sdxl: bool = False, ) -> Optional[Tuple[str, Dict]]: if work_device is None: work_device = device @@ -391,16 +407,22 @@ def merge_key( if "time_embed" in key: weight_index = 0 # before input blocks elif ".out." in key: - weight_index = NUM_TOTAL_BLOCKS - 1 # after output blocks + weight_index = ( + NUM_TOTAL_BLOCKS_XL - 1 if sdxl else NUM_TOTAL_BLOCKS - 1 + ) # after output blocks elif m := re_inp.search(key): weight_index = int(m.groups()[0]) elif re_mid.search(key): - weight_index = NUM_INPUT_BLOCKS + weight_index = NUM_INPUT_BLOCKS_XL if sdxl else NUM_INPUT_BLOCKS elif m := re_out.search(key): - weight_index = NUM_INPUT_BLOCKS + NUM_MID_BLOCK + int(m.groups()[0]) + weight_index = ( + (NUM_INPUT_BLOCKS_XL if sdxl else NUM_INPUT_BLOCKS) + + NUM_MID_BLOCK + + int(m.groups()[0]) + ) - if weight_index >= NUM_TOTAL_BLOCKS: - raise ValueError(f"illegal block index {key}") + if weight_index >= (NUM_TOTAL_BLOCKS_XL if sdxl else NUM_TOTAL_BLOCKS): + raise ValueError(f"illegal block index {weight_index} for key {key}") if weight_index >= 0: current_bases = {k: w[weight_index] for k, w in weights.items()} diff --git a/sd_meh/rebasin.py b/sd_meh/rebasin.py index 2fbb418..010d67f 100644 --- a/sd_meh/rebasin.py +++ b/sd_meh/rebasin.py @@ -2200,11 +2200,9 @@ def apply_permutation(ps: PermutationSpec, perm, params): def update_model_a(ps: PermutationSpec, perm, model_a, new_alpha): for k in model_a: try: - perm_params = get_permuted_param( - ps, perm, k, model_a - ) + perm_params = get_permuted_param(ps, perm, k, model_a) model_a[k] = model_a[k] * (1 - new_alpha) + new_alpha * perm_params - except RuntimeError: # dealing with pix2pix and inpainting models + except RuntimeError: # dealing with pix2pix and inpainting models continue return model_a diff --git a/sd_meh/utils.py b/sd_meh/utils.py index f507ae8..f0d2295 100644 --- a/sd_meh/utils.py +++ b/sd_meh/utils.py @@ -2,7 +2,7 @@ import logging from sd_meh import merge_methods -from sd_meh.merge import NUM_TOTAL_BLOCKS +from sd_meh.merge import NUM_TOTAL_BLOCKS, NUM_TOTAL_BLOCKS_XL from sd_meh.presets import BLOCK_WEIGHTS_PRESETS MERGE_METHODS = dict(inspect.getmembers(merge_methods, inspect.isfunction)) @@ -13,25 +13,25 @@ ] -def compute_weights(weights, base): +def compute_weights(weights, base, sdxl: bool = False): if not weights: - return [base] * NUM_TOTAL_BLOCKS + return [base] * (NUM_TOTAL_BLOCKS_XL if sdxl else NUM_TOTAL_BLOCKS) if "," not in weights: return weights w_alpha = list(map(float, weights.split(","))) - if len(w_alpha) == NUM_TOTAL_BLOCKS: + if len(w_alpha) == (NUM_TOTAL_BLOCKS_XL if sdxl else NUM_TOTAL_BLOCKS): return w_alpha -def assemble_weights_and_bases(preset, weights, base, greek_letter): +def assemble_weights_and_bases(preset, weights, base, greek_letter, sdxl: bool = False): logging.info(f"Assembling {greek_letter} w&b") if preset: logging.info(f"Using {preset} preset") base, *weights = BLOCK_WEIGHTS_PRESETS[preset] bases = {greek_letter: base} - weights = {greek_letter: compute_weights(weights, base)} + weights = {greek_letter: compute_weights(weights, base, sdxl)} logging.info(f"base_{greek_letter}: {bases[greek_letter]}") logging.info(f"{greek_letter} weights: {weights[greek_letter]}") @@ -70,12 +70,14 @@ def weights_and_bases( block_weights_preset_beta_b, presets_alpha_lambda, presets_beta_lambda, + sdxl: bool = False, ): weights, bases = assemble_weights_and_bases( block_weights_preset_alpha, weights_alpha, base_alpha, "alpha", + sdxl, ) if block_weights_preset_alpha_b: @@ -85,6 +87,7 @@ def weights_and_bases( None, None, "alpha", + sdxl, ) weights, bases = interpolate_presets( weights, @@ -101,6 +104,7 @@ def weights_and_bases( weights_beta, base_beta, "beta", + sdxl, ) if block_weights_preset_beta_b: @@ -110,6 +114,7 @@ def weights_and_bases( None, None, "beta", + sdxl, ) weights, bases = interpolate_presets( weights,