PixArt LoRA changes
city96 committed Jun 11, 2024
1 parent 0f83907 commit 9da9978
Showing 2 changed files with 44 additions and 14 deletions.
53 changes: 41 additions & 12 deletions PixArt/diffusers_convert.py
@@ -104,24 +104,42 @@ def convert_state_dict(state_dict):
     return new_state_dict

 # Same as above but for LoRA weights:
-def convert_lora_state_dict(state_dict):
-    # peft
-    rep_ap = lambda x: x.replace(".weight", ".lora_A.weight")
-    rep_bp = lambda x: x.replace(".weight", ".lora_B.weight")
+def convert_lora_state_dict(state_dict, peft=True):
     # koyha
     rep_ak = lambda x: x.replace(".weight", ".lora_down.weight")
     rep_bk = lambda x: x.replace(".weight", ".lora_up.weight")
-
-    prefix = find_prefix(state_dict, "adaln_single.linear.lora_A.weight")
-    state_dict = {k[len(prefix):]:v for k,v in state_dict.items()}
+    rep_pk = lambda x: x.replace(".weight", ".alpha")
+    if peft: # peft
+        rep_ap = lambda x: x.replace(".weight", ".lora_A.weight")
+        rep_bp = lambda x: x.replace(".weight", ".lora_B.weight")
+        rep_pp = lambda x: x.replace(".weight", ".alpha")
+
+        prefix = find_prefix(state_dict, "adaln_single.linear.lora_A.weight")
+        state_dict = {k[len(prefix):]:v for k,v in state_dict.items()}
+    else: # OneTrainer
+        rep_ap = lambda x: x.replace(".", "_")[:-7] + ".lora_down.weight"
+        rep_bp = lambda x: x.replace(".", "_")[:-7] + ".lora_up.weight"
+        rep_pp = lambda x: x.replace(".", "_")[:-7] + ".alpha"
+
+        prefix = "lora_transformer_"
+        t5_marker = "lora_te_encoder"
+        t5_keys = []
+        for key in list(state_dict.keys()):
+            if key.startswith(prefix):
+                state_dict[key[len(prefix):]] = state_dict.pop(key)
+            elif t5_marker in key:
+                t5_keys.append(state_dict.pop(key))
+        if len(t5_keys) > 0:
+            print(f"Text Encoder not supported for PixArt LoRA, ignoring {len(t5_keys)} keys")

     cmap = []
     cmap_unet = conversion_map + conversion_map_ms # todo: 512 model
     for k, v in cmap_unet:
-        if not v.endswith(".weight"):
-            continue
-        cmap.append((rep_ak(k), rep_ap(v)))
-        cmap.append((rep_bk(k), rep_bp(v)))
+        if v.endswith(".weight"):
+            cmap.append((rep_ak(k), rep_ap(v)))
+            cmap.append((rep_bk(k), rep_bp(v)))
+            if not peft:
+                cmap.append((rep_pk(k), rep_pp(v)))

     missing = [k for k,v in cmap if v not in state_dict]
     new_state_dict = {k: state_dict[v] for k,v in cmap if k not in missing}
@@ -134,7 +152,12 @@ def convert_lora_state_dict(state_dict):
         new_state_dict[fk(f"blocks.{depth}.attn.qkv.weight")] = torch.cat((
             state_dict[key('q')], state_dict[key('k')], state_dict[key('v')]
         ), dim=0)
+
         matched += [key('q'), key('k'), key('v')]
+        if not peft:
+            akey = lambda a: rep_pp(f"transformer_blocks.{depth}.attn1.to_{a}.weight")
+            new_state_dict[rep_pk((f"blocks.{depth}.attn.qkv.weight"))] = state_dict[akey("q")]
+            matched += [akey('q'), akey('k'), akey('v')]

         # Cross-attention (linear)
         key = lambda a: fp(f"transformer_blocks.{depth}.attn2.to_{a}.weight")
@@ -143,13 +166,19 @@ def convert_lora_state_dict(state_dict):
             state_dict[key('k')], state_dict[key('v')]
         ), dim=0)
         matched += [key('q'), key('k'), key('v')]
+        if not peft:
+            akey = lambda a: rep_pp(f"transformer_blocks.{depth}.attn2.to_{a}.weight")
+            new_state_dict[rep_pk((f"blocks.{depth}.cross_attn.q_linear.weight"))] = state_dict[akey("q")]
+            new_state_dict[rep_pk((f"blocks.{depth}.cross_attn.kv_linear.weight"))] = state_dict[akey("k")]
+            matched += [akey('q'), akey('k'), akey('v')]
+

     if len(matched) < len(state_dict):
         print(f"PixArt: LoRA conversion has leftover keys! ({len(matched)} vs {len(state_dict)})")
         print(list( set(state_dict.keys()) - set(matched) ))

     if len(missing) > 0:
-        print(f"PixArt: LoRA conversion has missing keys!")
+        print(f"PixArt: LoRA conversion has missing keys! (probably)")
         print(missing)

     return new_state_dict
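For reference, here is a minimal illustration (not part of the commit) of what the two naming schemes handled above produce for a single diffusers-style key. The key name is hypothetical; the transforms simply mirror the rep_* lambdas in the peft and OneTrainer branches, after any checkpoint prefix ("lora_transformer_" in the OneTrainer case) has already been stripped.

    # Illustration only: mirrors the rep_ap/rep_bp/rep_pp lambdas above.
    diffusers_key = "transformer_blocks.0.attn1.to_q.weight"  # hypothetical example key

    # peft branch: ".weight" is swapped for the adapter suffixes
    peft_down = diffusers_key.replace(".weight", ".lora_A.weight")  # ...to_q.lora_A.weight
    peft_up   = diffusers_key.replace(".weight", ".lora_B.weight")  # ...to_q.lora_B.weight

    # OneTrainer branch: dots become underscores, the trailing "_weight"
    # (7 characters) is cut off, then kohya-style suffixes are appended
    flat = diffusers_key.replace(".", "_")[:-7]  # transformer_blocks_0_attn1_to_q
    ot_down  = flat + ".lora_down.weight"
    ot_up    = flat + ".lora_up.weight"
    ot_alpha = flat + ".alpha"
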
5 changes: 3 additions & 2 deletions PixArt/lora.py
@@ -87,7 +87,6 @@ def replace_model_patcher(model):

     n.object_patches = model.object_patches.copy()
     n.model_options = copy.deepcopy(model.model_options)
-    n.model_keys = model.model_keys
     return n

 def find_peft_alpha(path):
@@ -120,9 +119,11 @@ def load_pixart_lora(model, lora, lora_path, strength):
     k_back = lambda x: x.replace(".lora_up.weight", "")
     # need to convert the actual weights for this to work.
     if any(True for x in lora.keys() if x.endswith("adaln_single.linear.lora_A.weight")):
-        lora = convert_lora_state_dict(lora)
+        lora = convert_lora_state_dict(lora, peft=True)
         alpha = find_peft_alpha(lora_path)
         lora.update({f"{k_back(x)}.alpha":torch.tensor(alpha) for x in lora.keys() if "lora_up" in x})
+    else: # OneTrainer
+        lora = convert_lora_state_dict(lora, peft=False)

     key_map = {k_back(x):f"diffusion_model.{k_back(x)}.weight" for x in lora.keys() if "lora_up" in x} # fake

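Taken together with the converter changes above, the loader now dispatches on the checkpoint format before building the key map. Below is a hedged, standalone sketch of that dispatch; the wrapper name convert_any_pixart_lora is made up for illustration, lora is assumed to already be a flat dict of tensors, and convert_lora_state_dict / find_peft_alpha are the helpers touched by this commit.

    import torch

    def convert_any_pixart_lora(lora, lora_path):
        # peft/diffusers-trained checkpoints keep the adaln_single key with a
        # ".lora_A.weight" suffix; anything else is treated as OneTrainer-style.
        if any(x.endswith("adaln_single.linear.lora_A.weight") for x in lora.keys()):
            lora = convert_lora_state_dict(lora, peft=True)
            # peft does not carry alpha in the state dict, so it is recovered
            # separately and attached to every converted lora_up key.
            alpha = find_peft_alpha(lora_path)
            lora.update({
                f"{x.replace('.lora_up.weight', '')}.alpha": torch.tensor(alpha)
                for x in list(lora.keys()) if "lora_up" in x
            })
        else:  # OneTrainer: alpha keys come out of the conversion itself
            lora = convert_lora_state_dict(lora, peft=False)
        return lora
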
