Move FLUX_LORA_TRANSFORMER_PREFIX and FLUX_LORA_CLIP_PREFIX to a shared location.
RyanJDick authored and hipsterusername committed Oct 1, 2024
1 parent 68dbe45 commit 807f458
Showing 7 changed files with 18 additions and 20 deletions.
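The gist of the change, as a sketch: the two key-prefix constants move out of the Kohya-specific conversion module into a new shared flux_lora_constants.py, picking up corrected, format-neutral names along the way (the old names carried a "TRANFORMER" typo). The module paths and constant names below are taken from the diffs that follow, and assume the invokeai package at this commit:

# Before this commit, consumers imported the prefixes from the Kohya module:
#   from invokeai.backend.lora.conversions.flux_kohya_lora_conversion_utils import (
#       FLUX_KOHYA_CLIP_PREFIX,
#       FLUX_KOHYA_TRANFORMER_PREFIX,
#   )

# After this commit, all consumers import the renamed constants from one place:
from invokeai.backend.lora.conversions.flux_lora_constants import (
    FLUX_LORA_CLIP_PREFIX,
    FLUX_LORA_TRANSFORMER_PREFIX,
)

print(FLUX_LORA_TRANSFORMER_PREFIX)  # "lora_transformer-"
print(FLUX_LORA_CLIP_PREFIX)         # "lora_clip-"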
6 changes: 3 additions & 3 deletions invokeai/app/invocations/flux_denoise.py
@@ -30,7 +30,7 @@
pack,
unpack,
)
-from invokeai.backend.lora.conversions.flux_kohya_lora_conversion_utils import FLUX_KOHYA_TRANFORMER_PREFIX
+from invokeai.backend.lora.conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
from invokeai.backend.lora.lora_patcher import LoRAPatcher
from invokeai.backend.model_manager.config import ModelFormat
@@ -209,7 +209,7 @@ def _run_diffusion(
LoRAPatcher.apply_lora_patches(
model=transformer,
patches=self._lora_iterator(context),
-prefix=FLUX_KOHYA_TRANFORMER_PREFIX,
+prefix=FLUX_LORA_TRANSFORMER_PREFIX,
cached_weights=cached_weights,
)
)
@@ -220,7 +220,7 @@
LoRAPatcher.apply_lora_sidecar_patches(
model=transformer,
patches=self._lora_iterator(context),
-prefix=FLUX_KOHYA_TRANFORMER_PREFIX,
+prefix=FLUX_LORA_TRANSFORMER_PREFIX,
dtype=inference_dtype,
)
)
4 changes: 2 additions & 2 deletions invokeai/app/invocations/flux_text_encoder.py
Original file line number Diff line number Diff line change
@@ -10,7 +10,7 @@
from invokeai.app.invocations.primitives import FluxConditioningOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.flux.modules.conditioner import HFEncoder
-from invokeai.backend.lora.conversions.flux_kohya_lora_conversion_utils import FLUX_KOHYA_CLIP_PREFIX
+from invokeai.backend.lora.conversions.flux_lora_constants import FLUX_LORA_CLIP_PREFIX
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
from invokeai.backend.lora.lora_patcher import LoRAPatcher
from invokeai.backend.model_manager.config import ModelFormat
@@ -101,7 +101,7 @@ def _clip_encode(self, context: InvocationContext) -> torch.Tensor:
LoRAPatcher.apply_lora_patches(
model=clip_text_encoder,
patches=self._clip_lora_iterator(context),
-prefix=FLUX_KOHYA_CLIP_PREFIX,
+prefix=FLUX_LORA_CLIP_PREFIX,
cached_weights=cached_weights,
)
)
4 changes: 2 additions & 2 deletions invokeai/backend/lora/conversions/flux_diffusers_lora_conversion_utils.py
@@ -2,7 +2,7 @@

import torch

-from invokeai.backend.lora.conversions.flux_kohya_lora_conversion_utils import FLUX_KOHYA_TRANFORMER_PREFIX
+from invokeai.backend.lora.conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
from invokeai.backend.lora.layers.any_lora_layer import AnyLoRALayer
from invokeai.backend.lora.layers.concatenated_lora_layer import ConcatenatedLoRALayer
from invokeai.backend.lora.layers.lora_layer import LoRALayer
@@ -190,7 +190,7 @@ def add_qkv_lora_layer_if_present(src_keys: list[str], dst_qkv_key: str) -> None
# Assert that all keys were processed.
assert len(grouped_state_dict) == 0

-layers_with_prefix = {f"{FLUX_KOHYA_TRANFORMER_PREFIX}{k}": v for k, v in layers.items()}
+layers_with_prefix = {f"{FLUX_LORA_TRANSFORMER_PREFIX}{k}": v for k, v in layers.items()}

return LoRAModelRaw(layers=layers_with_prefix)

10 changes: 3 additions & 7 deletions invokeai/backend/lora/conversions/flux_kohya_lora_conversion_utils.py
@@ -3,6 +3,7 @@

import torch

+from invokeai.backend.lora.conversions.flux_lora_constants import FLUX_LORA_CLIP_PREFIX, FLUX_LORA_TRANSFORMER_PREFIX
from invokeai.backend.lora.layers.any_lora_layer import AnyLoRALayer
from invokeai.backend.lora.layers.utils import any_lora_layer_from_state_dict
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
@@ -23,11 +24,6 @@
FLUX_KOHYA_CLIP_KEY_REGEX = r"lora_te1_text_model_encoder_layers_(\d+)_(mlp|self_attn)_(\w+)\.?.*"


-# Prefixes used to distinguish between transformer and CLIP text encoder keys in the InvokeAI LoRA format.
-FLUX_KOHYA_TRANFORMER_PREFIX = "lora_transformer-"
-FLUX_KOHYA_CLIP_PREFIX = "lora_clip-"


def is_state_dict_likely_in_flux_kohya_format(state_dict: Dict[str, Any]) -> bool:
"""Checks if the provided state dict is likely in the Kohya FLUX LoRA format.
@@ -67,9 +63,9 @@ def lora_model_from_flux_kohya_state_dict(state_dict: Dict[str, torch.Tensor]) -
# Create LoRA layers.
layers: dict[str, AnyLoRALayer] = {}
for layer_key, layer_state_dict in transformer_grouped_sd.items():
-layers[FLUX_KOHYA_TRANFORMER_PREFIX + layer_key] = any_lora_layer_from_state_dict(layer_state_dict)
+layers[FLUX_LORA_TRANSFORMER_PREFIX + layer_key] = any_lora_layer_from_state_dict(layer_state_dict)
for layer_key, layer_state_dict in clip_grouped_sd.items():
-layers[FLUX_KOHYA_CLIP_PREFIX + layer_key] = any_lora_layer_from_state_dict(layer_state_dict)
+layers[FLUX_LORA_CLIP_PREFIX + layer_key] = any_lora_layer_from_state_dict(layer_state_dict)

# Create and return the LoRAModelRaw.
return LoRAModelRaw(layers=layers)
3 changes: 3 additions & 0 deletions invokeai/backend/lora/conversions/flux_lora_constants.py
@@ -0,0 +1,3 @@
+# Prefixes used to distinguish between transformer and CLIP text encoder keys in the FLUX InvokeAI LoRA format.
+FLUX_LORA_TRANSFORMER_PREFIX = "lora_transformer-"
+FLUX_LORA_CLIP_PREFIX = "lora_clip-"
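For context, these prefixes let a single converted FLUX LoRA hold layers for two different target models, with each patch site selecting its share by key prefix (as in the LoRAPatcher.apply_lora_patches(..., prefix=FLUX_LORA_TRANSFORMER_PREFIX, ...) and prefix=FLUX_LORA_CLIP_PREFIX calls above). A minimal, self-contained sketch of that routing idea follows; the layer keys are made up for illustration, and this is not InvokeAI's actual patcher logic:

FLUX_LORA_TRANSFORMER_PREFIX = "lora_transformer-"
FLUX_LORA_CLIP_PREFIX = "lora_clip-"

# Hypothetical keys of a converted LoRA model, for illustration only.
layers = {
    f"{FLUX_LORA_TRANSFORMER_PREFIX}double_blocks_0_img_attn_qkv": "transformer layer",
    f"{FLUX_LORA_CLIP_PREFIX}text_model_encoder_layers_0_mlp_fc1": "CLIP layer",
}

def layers_for(prefix: str) -> dict[str, str]:
    # Keep only the layers destined for one model, stripping the routing prefix.
    return {k.removeprefix(prefix): v for k, v in layers.items() if k.startswith(prefix)}

assert list(layers_for(FLUX_LORA_TRANSFORMER_PREFIX)) == ["double_blocks_0_img_attn_qkv"]
assert list(layers_for(FLUX_LORA_CLIP_PREFIX)) == ["text_model_encoder_layers_0_mlp_fc1"]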
4 changes: 2 additions & 2 deletions tests/backend/lora/conversions/test_flux_diffusers_lora_conversion_utils.py
@@ -5,7 +5,7 @@
is_state_dict_likely_in_flux_diffusers_format,
lora_model_from_flux_diffusers_state_dict,
)
-from invokeai.backend.lora.conversions.flux_kohya_lora_conversion_utils import FLUX_KOHYA_TRANFORMER_PREFIX
+from invokeai.backend.lora.conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
from tests.backend.lora.conversions.lora_state_dicts.flux_lora_diffusers_format import (
state_dict_keys as flux_diffusers_state_dict_keys,
)
@@ -51,7 +51,7 @@ def test_lora_model_from_flux_diffusers_state_dict():
concatenated_weights = ["to_k", "to_v", "proj_mlp", "add_k_proj", "add_v_proj"]
expected_lora_layers = {k for k in expected_lora_layers if not any(w in k for w in concatenated_weights)}
assert len(model.layers) == len(expected_lora_layers)
-assert all(k.startswith(FLUX_KOHYA_TRANFORMER_PREFIX) for k in model.layers.keys())
+assert all(k.startswith(FLUX_LORA_TRANSFORMER_PREFIX) for k in model.layers.keys())


def test_lora_model_from_flux_diffusers_state_dict_extra_keys_error():
7 changes: 3 additions & 4 deletions tests/backend/lora/conversions/test_flux_kohya_lora_conversion_utils.py
@@ -5,12 +5,11 @@
from invokeai.backend.flux.model import Flux
from invokeai.backend.flux.util import params
from invokeai.backend.lora.conversions.flux_kohya_lora_conversion_utils import (
-FLUX_KOHYA_CLIP_PREFIX,
-FLUX_KOHYA_TRANFORMER_PREFIX,
_convert_flux_transformer_kohya_state_dict_to_invoke_format,
is_state_dict_likely_in_flux_kohya_format,
lora_model_from_flux_kohya_state_dict,
)
+from invokeai.backend.lora.conversions.flux_lora_constants import FLUX_LORA_CLIP_PREFIX, FLUX_LORA_TRANSFORMER_PREFIX
from tests.backend.lora.conversions.lora_state_dicts.flux_lora_diffusers_format import (
state_dict_keys as flux_diffusers_state_dict_keys,
)
@@ -95,8 +94,8 @@ def test_lora_model_from_flux_kohya_state_dict(sd_keys: list[str]):
expected_layer_keys: set[str] = set()
for k in sd_keys:
# Replace prefixes.
k = k.replace("lora_unet_", FLUX_KOHYA_TRANFORMER_PREFIX)
k = k.replace("lora_te1_", FLUX_KOHYA_CLIP_PREFIX)
k = k.replace("lora_unet_", FLUX_LORA_TRANSFORMER_PREFIX)
k = k.replace("lora_te1_", FLUX_LORA_CLIP_PREFIX)
# Remove suffixes.
k = k.replace(".lora_up.weight", "")
k = k.replace(".lora_down.weight", "")
