diff --git a/PixArt/diffusers_convert.py b/PixArt/diffusers_convert.py
index 2bab19e..7209476 100644
--- a/PixArt/diffusers_convert.py
+++ b/PixArt/diffusers_convert.py
@@ -104,24 +104,42 @@ def convert_state_dict(state_dict):
     return new_state_dict
 
 # Same as above but for LoRA weights:
-def convert_lora_state_dict(state_dict):
-    # peft
-    rep_ap = lambda x: x.replace(".weight", ".lora_A.weight")
-    rep_bp = lambda x: x.replace(".weight", ".lora_B.weight")
+def convert_lora_state_dict(state_dict, peft=True):
     # koyha
     rep_ak = lambda x: x.replace(".weight", ".lora_down.weight")
     rep_bk = lambda x: x.replace(".weight", ".lora_up.weight")
-
-    prefix = find_prefix(state_dict, "adaln_single.linear.lora_A.weight")
-    state_dict = {k[len(prefix):]:v for k,v in state_dict.items()}
+    rep_pk = lambda x: x.replace(".weight", ".alpha")
+    if peft: # peft
+        rep_ap = lambda x: x.replace(".weight", ".lora_A.weight")
+        rep_bp = lambda x: x.replace(".weight", ".lora_B.weight")
+        rep_pp = lambda x: x.replace(".weight", ".alpha")
+
+        prefix = find_prefix(state_dict, "adaln_single.linear.lora_A.weight")
+        state_dict = {k[len(prefix):]:v for k,v in state_dict.items()}
+    else: # OneTrainer
+        rep_ap = lambda x: x.replace(".", "_")[:-7] + ".lora_down.weight"
+        rep_bp = lambda x: x.replace(".", "_")[:-7] + ".lora_up.weight"
+        rep_pp = lambda x: x.replace(".", "_")[:-7] + ".alpha"
+
+        prefix = "lora_transformer_"
+        t5_marker = "lora_te_encoder"
+        t5_keys = []
+        for key in list(state_dict.keys()):
+            if key.startswith(prefix):
+                state_dict[key[len(prefix):]] = state_dict.pop(key)
+            elif t5_marker in key:
+                t5_keys.append(state_dict.pop(key))
+        if len(t5_keys) > 0:
+            print(f"Text Encoder not supported for PixArt LoRA, ignoring {len(t5_keys)} keys")
 
     cmap = []
     cmap_unet = conversion_map + conversion_map_ms # todo: 512 model
     for k, v in cmap_unet:
-        if not v.endswith(".weight"):
-            continue
-        cmap.append((rep_ak(k), rep_ap(v)))
-        cmap.append((rep_bk(k), rep_bp(v)))
+        if v.endswith(".weight"):
+            cmap.append((rep_ak(k), rep_ap(v)))
+            cmap.append((rep_bk(k), rep_bp(v)))
+            if not peft:
+                cmap.append((rep_pk(k), rep_pp(v)))
 
     missing = [k for k,v in cmap if v not in state_dict]
     new_state_dict = {k: state_dict[v] for k,v in cmap if k not in missing}
@@ -134,7 +152,12 @@ def convert_lora_state_dict(state_dict):
         new_state_dict[fk(f"blocks.{depth}.attn.qkv.weight")] = torch.cat((
             state_dict[key('q')], state_dict[key('k')], state_dict[key('v')]
         ), dim=0)
+        matched += [key('q'), key('k'), key('v')]
+        if not peft:
+            akey = lambda a: rep_pp(f"transformer_blocks.{depth}.attn1.to_{a}.weight")
+            new_state_dict[rep_pk((f"blocks.{depth}.attn.qkv.weight"))] = state_dict[akey("q")]
+            matched += [akey('q'), akey('k'), akey('v')]
 
         # Cross-attention (linear)
         key = lambda a: fp(f"transformer_blocks.{depth}.attn2.to_{a}.weight")
@@ -143,13 +166,19 @@ def convert_lora_state_dict(state_dict):
             state_dict[key('k')], state_dict[key('v')]
         ), dim=0)
         matched += [key('q'), key('k'), key('v')]
+        if not peft:
+            akey = lambda a: rep_pp(f"transformer_blocks.{depth}.attn2.to_{a}.weight")
+            new_state_dict[rep_pk((f"blocks.{depth}.cross_attn.q_linear.weight"))] = state_dict[akey("q")]
+            new_state_dict[rep_pk((f"blocks.{depth}.cross_attn.kv_linear.weight"))] = state_dict[akey("k")]
+            matched += [akey('q'), akey('k'), akey('v')]
+
     if len(matched) < len(state_dict):
         print(f"PixArt: LoRA conversion has leftover keys! ({len(matched)} vs {len(state_dict)})")
         print(list( set(state_dict.keys()) - set(matched) ))
 
     if len(missing) > 0:
-        print(f"PixArt: LoRA conversion has missing keys!")
+        print(f"PixArt: LoRA conversion has missing keys! (probably)")
         print(missing)
 
     return new_state_dict
diff --git a/PixArt/lora.py b/PixArt/lora.py
index 285fba6..70fa704 100644
--- a/PixArt/lora.py
+++ b/PixArt/lora.py
@@ -87,7 +87,6 @@ def replace_model_patcher(model):
 
     n.object_patches = model.object_patches.copy()
     n.model_options = copy.deepcopy(model.model_options)
-    n.model_keys = model.model_keys
     return n
 
 def find_peft_alpha(path):
@@ -120,9 +119,11 @@ def load_pixart_lora(model, lora, lora_path, strength):
     k_back = lambda x: x.replace(".lora_up.weight", "")
     # need to convert the actual weights for this to work.
    if any(True for x in lora.keys() if x.endswith("adaln_single.linear.lora_A.weight")):
-        lora = convert_lora_state_dict(lora)
+        lora = convert_lora_state_dict(lora, peft=True)
         alpha = find_peft_alpha(lora_path)
         lora.update({f"{k_back(x)}.alpha":torch.tensor(alpha) for x in lora.keys() if "lora_up" in x})
+    else: # OneTrainer
+        lora = convert_lora_state_dict(lora, peft=False)
 
     key_map = {k_back(x):f"diffusion_model.{k_back(x)}.weight" for x in lora.keys() if "lora_up" in x} # fake
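For reference, a minimal usage sketch of the converter with the new peft flag (not part of the patch). The format check mirrors the one added in load_pixart_lora; the safetensors loading and the file path are assumptions for illustration only:

    # Sketch: assumes the module is importable as PixArt.diffusers_convert and
    # that the LoRA is stored as safetensors; the path below is a placeholder.
    from safetensors.torch import load_file
    from PixArt.diffusers_convert import convert_lora_state_dict

    lora_sd = load_file("pixart_lora.safetensors")  # hypothetical file
    # PEFT/diffusers LoRAs carry "adaln_single.linear.lora_A.weight" keys;
    # OneTrainer/kohya-style LoRAs use the "lora_transformer_" prefix instead.
    is_peft = any(k.endswith("adaln_single.linear.lora_A.weight") for k in lora_sd)
    converted = convert_lora_state_dict(lora_sd, peft=is_peft)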