Rename ModelPatcher -> LayerPatcher to avoid conflicts with another ModelPatcher definition.
RyanJDick authored and psychedelicious committed Dec 17, 2024
1 parent e01d799 commit c407a25
Showing 9 changed files with 33 additions and 33 deletions.
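For context, the collision this rename resolves: several invocations (e.g. compel.py below) import both the legacy weight-patching helper and the new patch framework, and before this commit both classes were named ModelPatcher, so the second import shadowed the first. A minimal before/after sketch of the imports, using the module paths shown in the diff:

# Before this commit: both imports bind the name ModelPatcher, and the second
# silently shadows the first within the importing module.
from invokeai.backend.model_patcher import ModelPatcher
from invokeai.backend.patches.model_patcher import ModelPatcher  # shadows the line above

# After this commit: the new patcher has a distinct name, so both helpers can be
# referenced side by side.
from invokeai.backend.model_patcher import ModelPatcher
from invokeai.backend.patches.model_patcher import LayerPatcher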
6 changes: 3 additions & 3 deletions invokeai/app/invocations/compel.py
@@ -21,7 +21,7 @@
from invokeai.app.util.ti_utils import generate_ti_list
from invokeai.backend.model_patcher import ModelPatcher
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
- from invokeai.backend.patches.model_patcher import ModelPatcher
+ from invokeai.backend.patches.model_patcher import LayerPatcher
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
BasicConditioningInfo,
ConditioningFieldData,
@@ -82,7 +82,7 @@ def _lora_loader() -> Iterator[Tuple[ModelPatchRaw, float]]:
# apply all patches while the model is on the target device
text_encoder_info.model_on_device() as (cached_weights, text_encoder),
tokenizer_info as tokenizer,
- ModelPatcher.apply_model_patches(
+ LayerPatcher.apply_model_patches(
model=text_encoder,
patches=_lora_loader(),
prefix="lora_te_",
@@ -179,7 +179,7 @@ def _lora_loader() -> Iterator[Tuple[ModelPatchRaw, float]]:
# apply all patches while the model is on the target device
text_encoder_info.model_on_device() as (cached_weights, text_encoder),
tokenizer_info as tokenizer,
- ModelPatcher.apply_model_patches(
+ LayerPatcher.apply_model_patches(
text_encoder,
patches=_lora_loader(),
prefix=lora_prefix,
4 changes: 2 additions & 2 deletions invokeai/app/invocations/denoise_latents.py
@@ -40,7 +40,7 @@
from invokeai.backend.model_manager import BaseModelType, ModelVariantType
from invokeai.backend.model_patcher import ModelPatcher
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
- from invokeai.backend.patches.model_patcher import ModelPatcher
+ from invokeai.backend.patches.model_patcher import LayerPatcher
from invokeai.backend.stable_diffusion import PipelineIntermediateState
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext, DenoiseInputs
from invokeai.backend.stable_diffusion.diffusers_pipeline import (
@@ -1003,7 +1003,7 @@ def _lora_loader() -> Iterator[Tuple[ModelPatchRaw, float]]:
ModelPatcher.apply_freeu(unet, self.unet.freeu_config),
SeamlessExt.static_patch_model(unet, self.unet.seamless_axes), # FIXME
# Apply the LoRA after unet has been moved to its target device for faster patching.
- ModelPatcher.apply_model_patches(
+ LayerPatcher.apply_model_patches(
model=unet,
patches=_lora_loader(),
prefix="lora_unet_",
6 changes: 3 additions & 3 deletions invokeai/app/invocations/flux_denoise.py
@@ -50,7 +50,7 @@
from invokeai.backend.model_manager.config import ModelFormat
from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
- from invokeai.backend.patches.model_patcher import ModelPatcher
+ from invokeai.backend.patches.model_patcher import LayerPatcher
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo
from invokeai.backend.util.devices import TorchDevice
@@ -306,7 +306,7 @@ def _run_diffusion(
if config.format in [ModelFormat.Checkpoint]:
# The model is non-quantized, so we can apply the LoRA weights directly into the model.
exit_stack.enter_context(
- ModelPatcher.apply_model_patches(
+ LayerPatcher.apply_model_patches(
model=transformer,
patches=self._lora_iterator(context),
prefix=FLUX_LORA_TRANSFORMER_PREFIX,
@@ -321,7 +321,7 @@ def _run_diffusion(
# The model is quantized, so apply the LoRA weights as sidecar layers. This results in slower inference,
# than directly patching the weights, but is agnostic to the quantization format.
exit_stack.enter_context(
- ModelPatcher.apply_model_sidecar_patches(
+ LayerPatcher.apply_model_sidecar_patches(
model=transformer,
patches=self._lora_iterator(context),
prefix=FLUX_LORA_TRANSFORMER_PREFIX,
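For readers skimming the flux_denoise.py hunks above, a condensed sketch of the two LoRA-patching paths after the rename, written as a hypothetical helper (apply_flux_lora_patches and its parameters are illustrative stand-ins, not part of the diff); the dispatch on config.format, the prefix constant, and the dtype argument follow the surrounding code:

from contextlib import ExitStack
from typing import Iterator, Tuple

import torch

from invokeai.backend.model_manager.config import ModelFormat
from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
from invokeai.backend.patches.model_patcher import LayerPatcher


def apply_flux_lora_patches(
    exit_stack: ExitStack,
    transformer: torch.nn.Module,
    config,  # model config with a .format field (e.g. ModelFormat.Checkpoint)
    lora_iterator: Iterator[Tuple[ModelPatchRaw, float]],
    inference_dtype: torch.dtype,
) -> None:
    if config.format in [ModelFormat.Checkpoint]:
        # Non-quantized model: merge the LoRA weights directly into the model weights.
        exit_stack.enter_context(
            LayerPatcher.apply_model_patches(
                model=transformer,
                patches=lora_iterator,
                prefix=FLUX_LORA_TRANSFORMER_PREFIX,
            )
        )
    else:
        # Quantized model: attach the LoRA weights as sidecar layers instead,
        # which is slower than direct patching but agnostic to the quantization format.
        exit_stack.enter_context(
            LayerPatcher.apply_model_sidecar_patches(
                model=transformer,
                patches=lora_iterator,
                prefix=FLUX_LORA_TRANSFORMER_PREFIX,
                dtype=inference_dtype,
            )
        )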
4 changes: 2 additions & 2 deletions invokeai/app/invocations/flux_text_encoder.py
@@ -20,7 +20,7 @@
from invokeai.backend.model_manager.config import ModelFormat
from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_CLIP_PREFIX
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
- from invokeai.backend.patches.model_patcher import ModelPatcher
+ from invokeai.backend.patches.model_patcher import LayerPatcher
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData, FLUXConditioningInfo


@@ -111,7 +111,7 @@ def _clip_encode(self, context: InvocationContext) -> torch.Tensor:
if clip_text_encoder_config.format in [ModelFormat.Diffusers]:
# The model is non-quantized, so we can apply the LoRA weights directly into the model.
exit_stack.enter_context(
- ModelPatcher.apply_model_patches(
+ LayerPatcher.apply_model_patches(
model=clip_text_encoder,
patches=self._clip_lora_iterator(context),
prefix=FLUX_LORA_CLIP_PREFIX,
4 changes: 2 additions & 2 deletions invokeai/app/invocations/sd3_text_encoder.py
@@ -19,7 +19,7 @@
from invokeai.backend.model_manager.config import ModelFormat
from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_CLIP_PREFIX
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
- from invokeai.backend.patches.model_patcher import ModelPatcher
+ from invokeai.backend.patches.model_patcher import LayerPatcher
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData, SD3ConditioningInfo

# The SD3 T5 Max Sequence Length set based on the default in diffusers.
@@ -150,7 +150,7 @@ def _clip_encode(
if clip_text_encoder_config.format in [ModelFormat.Diffusers]:
# The model is non-quantized, so we can apply the LoRA weights directly into the model.
exit_stack.enter_context(
- ModelPatcher.apply_model_patches(
+ LayerPatcher.apply_model_patches(
model=clip_text_encoder,
patches=self._clip_lora_iterator(context, clip_model),
prefix=FLUX_LORA_CLIP_PREFIX,
@@ -23,7 +23,7 @@
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
- from invokeai.backend.patches.model_patcher import ModelPatcher
+ from invokeai.backend.patches.model_patcher import LayerPatcher
from invokeai.backend.stable_diffusion.diffusers_pipeline import ControlNetData, PipelineIntermediateState
from invokeai.backend.stable_diffusion.multi_diffusion_pipeline import (
MultiDiffusionPipeline,
@@ -207,7 +207,7 @@ def _lora_loader() -> Iterator[Tuple[ModelPatchRaw, float]]:
with (
ExitStack() as exit_stack,
unet_info as unet,
- ModelPatcher.apply_model_patches(model=unet, patches=_lora_loader(), prefix="lora_unet_"),
+ LayerPatcher.apply_model_patches(model=unet, patches=_lora_loader(), prefix="lora_unet_"),
):
assert isinstance(unet, UNet2DConditionModel)
latents = latents.to(device=unet.device, dtype=unet.dtype)
22 changes: 11 additions & 11 deletions invokeai/backend/patches/model_patcher.py
@@ -13,7 +13,7 @@
from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage


- class ModelPatcher:
+ class LayerPatcher:
@staticmethod
@torch.no_grad()
@contextmanager
@@ -37,7 +37,7 @@ def apply_model_patches(
original_weights = OriginalWeightsStorage(cached_weights)
try:
for patch, patch_weight in patches:
- ModelPatcher.apply_model_patch(
+ LayerPatcher.apply_model_patch(
model=model,
prefix=prefix,
patch=patch,
@@ -85,11 +85,11 @@ def apply_model_patch(
if not layer_key.startswith(prefix):
continue

- module_key, module = ModelPatcher._get_submodule(
+ module_key, module = LayerPatcher._get_submodule(
model, layer_key[prefix_len:], layer_key_is_flattened=layer_keys_are_flattened
)

- ModelPatcher._apply_model_layer_patch(
+ LayerPatcher._apply_model_layer_patch(
module_to_patch=module,
module_to_patch_key=module_key,
patch=layer,
@@ -169,7 +169,7 @@ def apply_model_sidecar_patches(
original_modules: dict[str, torch.nn.Module] = {}
try:
for patch, patch_weight in patches:
- ModelPatcher._apply_model_sidecar_patch(
+ LayerPatcher._apply_model_sidecar_patch(
model=model,
prefix=prefix,
patch=patch,
@@ -182,9 +182,9 @@
# Restore original modules.
# Note: This logic assumes no nested modules in original_modules.
for module_key, orig_module in original_modules.items():
- module_parent_key, module_name = ModelPatcher._split_parent_key(module_key)
+ module_parent_key, module_name = LayerPatcher._split_parent_key(module_key)
parent_module = model.get_submodule(module_parent_key)
- ModelPatcher._set_submodule(parent_module, module_name, orig_module)
+ LayerPatcher._set_submodule(parent_module, module_name, orig_module)

@staticmethod
def _apply_model_sidecar_patch(
@@ -212,11 +212,11 @@ def _apply_model_sidecar_patch(
if not layer_key.startswith(prefix):
continue

- module_key, module = ModelPatcher._get_submodule(
+ module_key, module = LayerPatcher._get_submodule(
model, layer_key[prefix_len:], layer_key_is_flattened=layer_keys_are_flattened
)

- ModelPatcher._apply_model_layer_wrapper_patch(
+ LayerPatcher._apply_model_layer_wrapper_patch(
model=model,
module_to_patch=module,
module_to_patch_key=module_key,
@@ -242,9 +242,9 @@ def _apply_model_layer_wrapper_patch(
if not isinstance(module_to_patch, BaseSidecarWrapper):
wrapped_module = wrap_module_with_sidecar_wrapper(orig_module=module_to_patch)
original_modules[module_to_patch_key] = module_to_patch
- module_parent_key, module_name = ModelPatcher._split_parent_key(module_to_patch_key)
+ module_parent_key, module_name = LayerPatcher._split_parent_key(module_to_patch_key)
module_parent = model.get_submodule(module_parent_key)
- ModelPatcher._set_submodule(module_parent, module_name, wrapped_module)
+ LayerPatcher._set_submodule(module_parent, module_name, wrapped_module)
else:
assert module_to_patch_key in original_modules
wrapped_module = module_to_patch
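The sidecar path in the hunks above swaps a target submodule for a wrapper and later restores the original via _split_parent_key/_set_submodule. A standalone sketch of that wrap-and-restore pattern, assuming only PyTorch (SidecarWrapper and swap_submodule are illustrative names, not InvokeAI's API; a real sidecar wrapper would add a LoRA-style delta in forward):

import torch


class SidecarWrapper(torch.nn.Module):
    """Wraps an existing module; a real implementation would add a patch delta to its output."""

    def __init__(self, orig_module: torch.nn.Module):
        super().__init__()
        self.orig_module = orig_module

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.orig_module(x)


def swap_submodule(model: torch.nn.Module, key: str, new_module: torch.nn.Module) -> torch.nn.Module:
    """Replace the submodule at dot-separated `key` and return the module that was there."""
    parent_key, _, name = key.rpartition(".")
    parent = model.get_submodule(parent_key) if parent_key else model
    original = getattr(parent, name)
    setattr(parent, name, new_module)
    return original


model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Linear(4, 2))
original = swap_submodule(model, "1", SidecarWrapper(model[1]))  # patch
model(torch.randn(1, 4))                                         # run while patched
swap_submodule(model, "1", original)                             # restore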
4 changes: 2 additions & 2 deletions invokeai/backend/stable_diffusion/extensions/lora.py
@@ -6,7 +6,7 @@
from diffusers import UNet2DConditionModel

from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
- from invokeai.backend.patches.model_patcher import ModelPatcher
+ from invokeai.backend.patches.model_patcher import LayerPatcher
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase

if TYPE_CHECKING:
@@ -31,7 +31,7 @@ def __init__(
def patch_unet(self, unet: UNet2DConditionModel, original_weights: OriginalWeightsStorage):
lora_model = self._node_context.models.load(self._model_id).model
assert isinstance(lora_model, ModelPatchRaw)
- ModelPatcher.apply_model_patch(
+ LayerPatcher.apply_model_patch(
model=unet,
prefix="lora_unet_",
patch=lora_model,
12 changes: 6 additions & 6 deletions tests/backend/patches/test_lora_patcher.py
@@ -3,7 +3,7 @@

from invokeai.backend.patches.layers.lora_layer import LoRALayer
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
- from invokeai.backend.patches.model_patcher import ModelPatcher
+ from invokeai.backend.patches.model_patcher import LayerPatcher


class DummyModule(torch.nn.Module):
@@ -53,7 +53,7 @@ def test_apply_lora_patches(device: str, num_layers: int):
orig_linear_weight = model.linear_layer_1.weight.data.detach().clone()
expected_patched_linear_weight = orig_linear_weight + (lora_rank * lora_weight * num_layers)

- with ModelPatcher.apply_model_patches(model=model, patches=lora_models, prefix=""):
+ with LayerPatcher.apply_model_patches(model=model, patches=lora_models, prefix=""):
# After patching, all LoRA layer weights should have been moved back to the cpu.
for lora, _ in lora_models:
assert lora.layers["linear_layer_1"].up.device.type == "cpu"
@@ -93,7 +93,7 @@ def test_apply_lora_patches_change_device():

orig_linear_weight = model.linear_layer_1.weight.data.detach().clone()

- with ModelPatcher.apply_model_patches(model=model, patches=[(lora, 0.5)], prefix=""):
+ with LayerPatcher.apply_model_patches(model=model, patches=[(lora, 0.5)], prefix=""):
# After patching, all LoRA layer weights should have been moved back to the cpu.
assert lora_layers["linear_layer_1"].up.device.type == "cpu"
assert lora_layers["linear_layer_1"].down.device.type == "cpu"
@@ -146,7 +146,7 @@ def test_apply_lora_sidecar_patches(device: str, num_layers: int):
output_before_patch = model(input)

# Patch the model and run inference during the patch.
- with ModelPatcher.apply_model_sidecar_patches(model=model, patches=lora_models, prefix="", dtype=dtype):
+ with LayerPatcher.apply_model_sidecar_patches(model=model, patches=lora_models, prefix="", dtype=dtype):
output_during_patch = model(input)

# Run inference after unpatching.
@@ -186,10 +186,10 @@ def test_apply_lora_sidecar_patches_matches_apply_lora_patches(num_layers: int):

input = torch.randn(1, linear_in_features, device="cpu", dtype=dtype)

- with ModelPatcher.apply_model_patches(model=model, patches=lora_models, prefix=""):
+ with LayerPatcher.apply_model_patches(model=model, patches=lora_models, prefix=""):
output_lora_patches = model(input)

- with ModelPatcher.apply_model_sidecar_patches(model=model, patches=lora_models, prefix="", dtype=dtype):
+ with LayerPatcher.apply_model_sidecar_patches(model=model, patches=lora_models, prefix="", dtype=dtype):
output_lora_sidecar_patches = model(input)

# Note: We set atol=1e-5 because the test failed occasionally with the default atol=1e-8. Slight numerical
