Skip to content

Commit

Permalink
Rename LoRAModelRaw to ModelPatchRaw.
Browse files Browse the repository at this point in the history
  • Loading branch information
RyanJDick committed Dec 17, 2024
1 parent b820862 commit 7fad4c9
Show file tree
Hide file tree
Showing 15 changed files with 52 additions and 50 deletions.
10 changes: 5 additions & 5 deletions invokeai/app/invocations/compel.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.ti_utils import generate_ti_list
from invokeai.backend.model_patcher import ModelPatcher
from invokeai.backend.patches.lora_model_raw import LoRAModelRaw
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
from invokeai.backend.patches.model_patcher import ModelPatcher
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
BasicConditioningInfo,
Expand Down Expand Up @@ -66,10 +66,10 @@ def invoke(self, context: InvocationContext) -> ConditioningOutput:
tokenizer_info = context.models.load(self.clip.tokenizer)
text_encoder_info = context.models.load(self.clip.text_encoder)

def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
def _lora_loader() -> Iterator[Tuple[ModelPatchRaw, float]]:
for lora in self.clip.loras:
lora_info = context.models.load(lora.lora)
assert isinstance(lora_info.model, LoRAModelRaw)
assert isinstance(lora_info.model, ModelPatchRaw)
yield (lora_info.model, lora.weight)
del lora_info
return
Expand Down Expand Up @@ -162,11 +162,11 @@ def run_clip_compel(
c_pooled = None
return c, c_pooled

def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
def _lora_loader() -> Iterator[Tuple[ModelPatchRaw, float]]:
for lora in clip_field.loras:
lora_info = context.models.load(lora.lora)
lora_model = lora_info.model
assert isinstance(lora_model, LoRAModelRaw)
assert isinstance(lora_model, ModelPatchRaw)
yield (lora_model, lora.weight)
del lora_info
return
Expand Down
6 changes: 3 additions & 3 deletions invokeai/app/invocations/denoise_latents.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
from invokeai.backend.model_manager import BaseModelType, ModelVariantType
from invokeai.backend.model_patcher import ModelPatcher
from invokeai.backend.patches.lora_model_raw import LoRAModelRaw
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
from invokeai.backend.patches.model_patcher import ModelPatcher
from invokeai.backend.stable_diffusion import PipelineIntermediateState
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext, DenoiseInputs
Expand Down Expand Up @@ -987,10 +987,10 @@ def _old_invoke(self, context: InvocationContext) -> LatentsOutput:
def step_callback(state: PipelineIntermediateState) -> None:
context.util.sd_step_callback(state, unet_config.base)

def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
def _lora_loader() -> Iterator[Tuple[ModelPatchRaw, float]]:
for lora in self.unet.loras:
lora_info = context.models.load(lora.lora)
assert isinstance(lora_info.model, LoRAModelRaw)
assert isinstance(lora_info.model, ModelPatchRaw)
yield (lora_info.model, lora.weight)
del lora_info
return
Expand Down
6 changes: 3 additions & 3 deletions invokeai/app/invocations/flux_denoise.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
from invokeai.backend.flux.text_conditioning import FluxTextConditioning
from invokeai.backend.model_manager.config import ModelFormat
from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
from invokeai.backend.patches.lora_model_raw import LoRAModelRaw
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
from invokeai.backend.patches.model_patcher import ModelPatcher
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo
Expand Down Expand Up @@ -715,15 +715,15 @@ def _prep_ip_adapter_extensions(

return pos_ip_adapter_extensions, neg_ip_adapter_extensions

def _lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[LoRAModelRaw, float]]:
def _lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[ModelPatchRaw, float]]:
loras: list[Union[LoRAField, ControlLoRAField]] = [*self.transformer.loras]
if self.control_lora:
# Note: Since FLUX structural control LoRAs modify the shape of some weights, it is important that they are
# applied last.
loras.append(self.control_lora)
for lora in loras:
lora_info = context.models.load(lora.lora)
assert isinstance(lora_info.model, LoRAModelRaw)
assert isinstance(lora_info.model, ModelPatchRaw)
yield (lora_info.model, lora.weight)
del lora_info

Expand Down
6 changes: 3 additions & 3 deletions invokeai/app/invocations/flux_text_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from invokeai.backend.flux.modules.conditioner import HFEncoder
from invokeai.backend.model_manager.config import ModelFormat
from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_CLIP_PREFIX
from invokeai.backend.patches.lora_model_raw import LoRAModelRaw
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
from invokeai.backend.patches.model_patcher import ModelPatcher
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData, FLUXConditioningInfo

Expand Down Expand Up @@ -130,9 +130,9 @@ def _clip_encode(self, context: InvocationContext) -> torch.Tensor:
assert isinstance(pooled_prompt_embeds, torch.Tensor)
return pooled_prompt_embeds

def _clip_lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[LoRAModelRaw, float]]:
def _clip_lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[ModelPatchRaw, float]]:
for lora in self.clip.loras:
lora_info = context.models.load(lora.lora)
assert isinstance(lora_info.model, LoRAModelRaw)
assert isinstance(lora_info.model, ModelPatchRaw)
yield (lora_info.model, lora.weight)
del lora_info
6 changes: 3 additions & 3 deletions invokeai/app/invocations/sd3_text_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.config import ModelFormat
from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_CLIP_PREFIX
from invokeai.backend.patches.lora_model_raw import LoRAModelRaw
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
from invokeai.backend.patches.model_patcher import ModelPatcher
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData, SD3ConditioningInfo

Expand Down Expand Up @@ -193,9 +193,9 @@ def _clip_encode(

def _clip_lora_iterator(
self, context: InvocationContext, clip_model: CLIPField
) -> Iterator[Tuple[LoRAModelRaw, float]]:
) -> Iterator[Tuple[ModelPatchRaw, float]]:
for lora in clip_model.loras:
lora_info = context.models.load(lora.lora)
assert isinstance(lora_info.model, LoRAModelRaw)
assert isinstance(lora_info.model, ModelPatchRaw)
yield (lora_info.model, lora.weight)
del lora_info
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from invokeai.app.invocations.model import UNetField
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.patches.lora_model_raw import LoRAModelRaw
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
from invokeai.backend.patches.model_patcher import ModelPatcher
from invokeai.backend.stable_diffusion.diffusers_pipeline import ControlNetData, PipelineIntermediateState
from invokeai.backend.stable_diffusion.multi_diffusion_pipeline import (
Expand Down Expand Up @@ -194,10 +194,10 @@ def step_callback(state: PipelineIntermediateState) -> None:
context.util.sd_step_callback(state, unet_config.base)

# Prepare an iterator that yields the UNet's LoRA models and their weights.
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
def _lora_loader() -> Iterator[Tuple[ModelPatchRaw, float]]:
for lora in self.unet.loras:
lora_info = context.models.load(lora.lora)
assert isinstance(lora_info.model, LoRAModelRaw)
assert isinstance(lora_info.model, ModelPatchRaw)
yield (lora_info.model, lora.weight)
del lora_info

Expand Down
4 changes: 2 additions & 2 deletions invokeai/backend/model_manager/load/model_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
from invokeai.backend.model_manager.config import AnyModel
from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel
from invokeai.backend.patches.lora_model_raw import LoRAModelRaw
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel
from invokeai.backend.textual_inversion import TextualInversionModelRaw
from invokeai.backend.util.calc_tensor_size import calc_tensor_size
Expand All @@ -43,7 +43,7 @@ def calc_model_size_by_data(logger: logging.Logger, model: AnyModel) -> int:
(
TextualInversionModelRaw,
IPAdapter,
LoRAModelRaw,
ModelPatchRaw,
SpandrelImageToImageModel,
GroundingDinoPipeline,
SegmentAnythingPipeline,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from invokeai.backend.patches.layers.lora_layer import LoRALayer
from invokeai.backend.patches.layers.set_parameter_layer import SetParameterLayer
from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
from invokeai.backend.patches.lora_model_raw import LoRAModelRaw
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw

# A regex pattern that matches all of the keys in the Flux Dev/Canny LoRA format.
# Example keys:
Expand Down Expand Up @@ -43,7 +43,7 @@ def is_state_dict_likely_flux_control(state_dict: Dict[str, Any]) -> bool:
)


def lora_model_from_flux_control_state_dict(state_dict: Dict[str, torch.Tensor]) -> LoRAModelRaw:
def lora_model_from_flux_control_state_dict(state_dict: Dict[str, torch.Tensor]) -> ModelPatchRaw:
# Group keys by layer.
grouped_state_dict: dict[str, dict[str, torch.Tensor]] = {}
for key, value in state_dict.items():
Expand Down Expand Up @@ -81,4 +81,4 @@ def lora_model_from_flux_control_state_dict(state_dict: Dict[str, torch.Tensor])
else:
raise ValueError(f"{layer_key} not expected")

return LoRAModelRaw(layers=layers)
return ModelPatchRaw(layers=layers)
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from invokeai.backend.patches.layers.concatenated_lora_layer import ConcatenatedLoRALayer
from invokeai.backend.patches.layers.lora_layer import LoRALayer
from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
from invokeai.backend.patches.lora_model_raw import LoRAModelRaw
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw


def is_state_dict_likely_in_flux_diffusers_format(state_dict: Dict[str, torch.Tensor]) -> bool:
Expand All @@ -30,7 +30,9 @@ def is_state_dict_likely_in_flux_diffusers_format(state_dict: Dict[str, torch.Te
return all_keys_in_peft_format and all_expected_keys_present


def lora_model_from_flux_diffusers_state_dict(state_dict: Dict[str, torch.Tensor], alpha: float | None) -> LoRAModelRaw:
def lora_model_from_flux_diffusers_state_dict(
state_dict: Dict[str, torch.Tensor], alpha: float | None
) -> ModelPatchRaw:
"""Loads a state dict in the Diffusers FLUX LoRA format into a LoRAModelRaw object.
This function is based on:
Expand Down Expand Up @@ -215,7 +217,7 @@ def add_qkv_lora_layer_if_present(

layers_with_prefix = {f"{FLUX_LORA_TRANSFORMER_PREFIX}{k}": v for k, v in layers.items()}

return LoRAModelRaw(layers=layers_with_prefix)
return ModelPatchRaw(layers=layers_with_prefix)


def _group_by_layer(state_dict: Dict[str, torch.Tensor]) -> dict[str, dict[str, torch.Tensor]]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
FLUX_LORA_CLIP_PREFIX,
FLUX_LORA_TRANSFORMER_PREFIX,
)
from invokeai.backend.patches.lora_model_raw import LoRAModelRaw
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw

# A regex pattern that matches all of the transformer keys in the Kohya FLUX LoRA format.
# Example keys:
Expand Down Expand Up @@ -39,7 +39,7 @@ def is_state_dict_likely_in_flux_kohya_format(state_dict: Dict[str, Any]) -> boo
)


def lora_model_from_flux_kohya_state_dict(state_dict: Dict[str, torch.Tensor]) -> LoRAModelRaw:
def lora_model_from_flux_kohya_state_dict(state_dict: Dict[str, torch.Tensor]) -> ModelPatchRaw:
# Group keys by layer.
grouped_state_dict: dict[str, dict[str, torch.Tensor]] = {}
for key, value in state_dict.items():
Expand Down Expand Up @@ -71,7 +71,7 @@ def lora_model_from_flux_kohya_state_dict(state_dict: Dict[str, torch.Tensor]) -
layers[FLUX_LORA_CLIP_PREFIX + layer_key] = any_lora_layer_from_state_dict(layer_state_dict)

# Create and return the LoRAModelRaw.
return LoRAModelRaw(layers=layers)
return ModelPatchRaw(layers=layers)


T = TypeVar("T")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,17 @@

from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
from invokeai.backend.patches.layers.utils import any_lora_layer_from_state_dict
from invokeai.backend.patches.lora_model_raw import LoRAModelRaw
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw


def lora_model_from_sd_state_dict(state_dict: Dict[str, torch.Tensor]) -> LoRAModelRaw:
def lora_model_from_sd_state_dict(state_dict: Dict[str, torch.Tensor]) -> ModelPatchRaw:
grouped_state_dict: dict[str, dict[str, torch.Tensor]] = _group_state(state_dict)

layers: dict[str, BaseLayerPatch] = {}
for layer_key, values in grouped_state_dict.items():
layers[layer_key] = any_lora_layer_from_state_dict(values)

return LoRAModelRaw(layers=layers)
return ModelPatchRaw(layers=layers)


def _group_state(state_dict: Dict[str, torch.Tensor]) -> Dict[str, Dict[str, torch.Tensor]]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from invokeai.backend.raw_model import RawModel


class LoRAModelRaw(RawModel):
class ModelPatchRaw(RawModel):
def __init__(self, layers: Mapping[str, BaseLayerPatch]):
self.layers = layers

Expand Down
10 changes: 5 additions & 5 deletions invokeai/backend/patches/model_patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
from invokeai.backend.patches.layers.flux_control_lora_layer import FluxControlLoRALayer
from invokeai.backend.patches.lora_model_raw import LoRAModelRaw
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
from invokeai.backend.patches.pad_with_zeros import pad_with_zeros
from invokeai.backend.patches.sidecar_wrappers.base_sidecar_wrapper import BaseSidecarWrapper
from invokeai.backend.patches.sidecar_wrappers.utils import wrap_module_with_sidecar_wrapper
Expand All @@ -19,7 +19,7 @@ class ModelPatcher:
@contextmanager
def apply_model_patches(
model: torch.nn.Module,
patches: Iterable[Tuple[LoRAModelRaw, float]],
patches: Iterable[Tuple[ModelPatchRaw, float]],
prefix: str,
cached_weights: Optional[Dict[str, torch.Tensor]] = None,
):
Expand Down Expand Up @@ -57,7 +57,7 @@ def apply_model_patches(
def apply_model_patch(
model: torch.nn.Module,
prefix: str,
patch: LoRAModelRaw,
patch: ModelPatchRaw,
patch_weight: float,
original_weights: OriginalWeightsStorage,
):
Expand Down Expand Up @@ -148,7 +148,7 @@ def _apply_model_layer_patch(
@contextmanager
def apply_model_sidecar_patches(
model: torch.nn.Module,
patches: Iterable[Tuple[LoRAModelRaw, float]],
patches: Iterable[Tuple[ModelPatchRaw, float]],
prefix: str,
dtype: torch.dtype,
):
Expand Down Expand Up @@ -189,7 +189,7 @@ def apply_model_sidecar_patches(
@staticmethod
def _apply_model_sidecar_patch(
model: torch.nn.Module,
patch: LoRAModelRaw,
patch: ModelPatchRaw,
patch_weight: float,
prefix: str,
original_modules: dict[str, torch.nn.Module],
Expand Down
4 changes: 2 additions & 2 deletions invokeai/backend/stable_diffusion/extensions/lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from diffusers import UNet2DConditionModel

from invokeai.backend.patches.lora_model_raw import LoRAModelRaw
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
from invokeai.backend.patches.model_patcher import ModelPatcher
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase

Expand All @@ -30,7 +30,7 @@ def __init__(
@contextmanager
def patch_unet(self, unet: UNet2DConditionModel, original_weights: OriginalWeightsStorage):
lora_model = self._node_context.models.load(self._model_id).model
assert isinstance(lora_model, LoRAModelRaw)
assert isinstance(lora_model, ModelPatchRaw)
ModelPatcher.apply_model_patch(
model=unet,
prefix="lora_unet_",
Expand Down
Loading

0 comments on commit 7fad4c9

Please sign in to comment.