From b820862eabbb439c30c579c666208414cd413d15 Mon Sep 17 00:00:00 2001
From: Ryan Dick
Date: Sat, 14 Dec 2024 15:37:26 +0000
Subject: [PATCH] Rename ModelPatcher methods to reflect that they are general
 model patching methods and are not LoRA-specific.

---
 invokeai/app/invocations/compel.py              |  4 ++--
 invokeai/app/invocations/denoise_latents.py     |  2 +-
 invokeai/app/invocations/flux_denoise.py        |  4 ++--
 invokeai/app/invocations/flux_text_encoder.py   |  2 +-
 invokeai/app/invocations/sd3_text_encoder.py    |  2 +-
 .../tiled_multi_diffusion_denoise_latents.py    |  2 +-
 invokeai/backend/patches/model_patcher.py       | 20 +++++++++----------
 .../stable_diffusion/extensions/lora.py         |  2 +-
 tests/backend/patches/test_lora_patcher.py      | 10 +++++-----
 9 files changed, 24 insertions(+), 24 deletions(-)

diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py
index 82cef172fc1..93523d0052b 100644
--- a/invokeai/app/invocations/compel.py
+++ b/invokeai/app/invocations/compel.py
@@ -82,7 +82,7 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
             # apply all patches while the model is on the target device
             text_encoder_info.model_on_device() as (cached_weights, text_encoder),
             tokenizer_info as tokenizer,
-            ModelPatcher.apply_lora_patches(
+            ModelPatcher.apply_model_patches(
                 model=text_encoder,
                 patches=_lora_loader(),
                 prefix="lora_te_",
@@ -179,7 +179,7 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
             # apply all patches while the model is on the target device
             text_encoder_info.model_on_device() as (cached_weights, text_encoder),
             tokenizer_info as tokenizer,
-            ModelPatcher.apply_lora_patches(
+            ModelPatcher.apply_model_patches(
                 text_encoder,
                 patches=_lora_loader(),
                 prefix=lora_prefix,
diff --git a/invokeai/app/invocations/denoise_latents.py b/invokeai/app/invocations/denoise_latents.py
index 19ca0172a34..791dc0868c3 100644
--- a/invokeai/app/invocations/denoise_latents.py
+++ b/invokeai/app/invocations/denoise_latents.py
@@ -1003,7 +1003,7 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
                 ModelPatcher.apply_freeu(unet, self.unet.freeu_config),
                 SeamlessExt.static_patch_model(unet, self.unet.seamless_axes),  # FIXME
                 # Apply the LoRA after unet has been moved to its target device for faster patching.
-                ModelPatcher.apply_lora_patches(
+                ModelPatcher.apply_model_patches(
                     model=unet,
                     patches=_lora_loader(),
                     prefix="lora_unet_",
diff --git a/invokeai/app/invocations/flux_denoise.py b/invokeai/app/invocations/flux_denoise.py
index 3aef3eb24a7..60e103d148d 100644
--- a/invokeai/app/invocations/flux_denoise.py
+++ b/invokeai/app/invocations/flux_denoise.py
@@ -306,7 +306,7 @@ def _run_diffusion(
         if config.format in [ModelFormat.Checkpoint]:
             # The model is non-quantized, so we can apply the LoRA weights directly into the model.
             exit_stack.enter_context(
-                ModelPatcher.apply_lora_patches(
+                ModelPatcher.apply_model_patches(
                     model=transformer,
                     patches=self._lora_iterator(context),
                     prefix=FLUX_LORA_TRANSFORMER_PREFIX,
@@ -321,7 +321,7 @@ def _run_diffusion(
             # The model is quantized, so apply the LoRA weights as sidecar layers. This results in slower inference
             # than directly patching the weights, but is agnostic to the quantization format.
             exit_stack.enter_context(
-                ModelPatcher.apply_lora_sidecar_patches(
+                ModelPatcher.apply_model_sidecar_patches(
                     model=transformer,
                     patches=self._lora_iterator(context),
                     prefix=FLUX_LORA_TRANSFORMER_PREFIX,
diff --git a/invokeai/app/invocations/flux_text_encoder.py b/invokeai/app/invocations/flux_text_encoder.py
index 0d4c417ae03..424863704f9 100644
--- a/invokeai/app/invocations/flux_text_encoder.py
+++ b/invokeai/app/invocations/flux_text_encoder.py
@@ -111,7 +111,7 @@ def _clip_encode(self, context: InvocationContext) -> torch.Tensor:
         if clip_text_encoder_config.format in [ModelFormat.Diffusers]:
             # The model is non-quantized, so we can apply the LoRA weights directly into the model.
             exit_stack.enter_context(
-                ModelPatcher.apply_lora_patches(
+                ModelPatcher.apply_model_patches(
                     model=clip_text_encoder,
                     patches=self._clip_lora_iterator(context),
                     prefix=FLUX_LORA_CLIP_PREFIX,
diff --git a/invokeai/app/invocations/sd3_text_encoder.py b/invokeai/app/invocations/sd3_text_encoder.py
index a23af31e1f6..43c9b41fb29 100644
--- a/invokeai/app/invocations/sd3_text_encoder.py
+++ b/invokeai/app/invocations/sd3_text_encoder.py
@@ -150,7 +150,7 @@ def _clip_encode(
         if clip_text_encoder_config.format in [ModelFormat.Diffusers]:
             # The model is non-quantized, so we can apply the LoRA weights directly into the model.
             exit_stack.enter_context(
-                ModelPatcher.apply_lora_patches(
+                ModelPatcher.apply_model_patches(
                     model=clip_text_encoder,
                     patches=self._clip_lora_iterator(context, clip_model),
                     prefix=FLUX_LORA_CLIP_PREFIX,
diff --git a/invokeai/app/invocations/tiled_multi_diffusion_denoise_latents.py b/invokeai/app/invocations/tiled_multi_diffusion_denoise_latents.py
index 4910300e1a2..b1c92ccf858 100644
--- a/invokeai/app/invocations/tiled_multi_diffusion_denoise_latents.py
+++ b/invokeai/app/invocations/tiled_multi_diffusion_denoise_latents.py
@@ -207,7 +207,7 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
         with (
             ExitStack() as exit_stack,
             unet_info as unet,
-            ModelPatcher.apply_lora_patches(model=unet, patches=_lora_loader(), prefix="lora_unet_"),
+            ModelPatcher.apply_model_patches(model=unet, patches=_lora_loader(), prefix="lora_unet_"),
         ):
             assert isinstance(unet, UNet2DConditionModel)
             latents = latents.to(device=unet.device, dtype=unet.dtype)
diff --git a/invokeai/backend/patches/model_patcher.py b/invokeai/backend/patches/model_patcher.py
index 0ae60c2d7a8..36738793ffc 100644
--- a/invokeai/backend/patches/model_patcher.py
+++ b/invokeai/backend/patches/model_patcher.py
@@ -17,7 +17,7 @@ class ModelPatcher:
     @staticmethod
     @torch.no_grad()
     @contextmanager
-    def apply_lora_patches(
+    def apply_model_patches(
         model: torch.nn.Module,
         patches: Iterable[Tuple[LoRAModelRaw, float]],
         prefix: str,
@@ -37,7 +37,7 @@ def apply_lora_patches(
         original_weights = OriginalWeightsStorage(cached_weights)
         try:
             for patch, patch_weight in patches:
-                ModelPatcher.apply_lora_patch(
+                ModelPatcher.apply_model_patch(
                     model=model,
                     prefix=prefix,
                     patch=patch,
@@ -54,7 +54,7 @@ def apply_lora_patches(

     @staticmethod
     @torch.no_grad()
-    def apply_lora_patch(
+    def apply_model_patch(
         model: torch.nn.Module,
         prefix: str,
         patch: LoRAModelRaw,
@@ -89,7 +89,7 @@ def apply_lora_patch(
                 model, layer_key[prefix_len:], layer_key_is_flattened=layer_keys_are_flattened
             )

-            ModelPatcher._apply_lora_layer_patch(
+            ModelPatcher._apply_model_layer_patch(
                 module_to_patch=module,
                 module_to_patch_key=module_key,
                 patch=layer,
@@ -99,7 +99,7 @@ def apply_lora_patch(

     @staticmethod
     @torch.no_grad()
-    def _apply_lora_layer_patch(
+    def _apply_model_layer_patch(
         module_to_patch: torch.nn.Module,
         module_to_patch_key: str,
         patch: BaseLayerPatch,
@@ -146,7 +146,7 @@ def _apply_lora_layer_patch(
     @staticmethod
     @torch.no_grad()
     @contextmanager
-    def apply_lora_sidecar_patches(
+    def apply_model_sidecar_patches(
         model: torch.nn.Module,
         patches: Iterable[Tuple[LoRAModelRaw, float]],
         prefix: str,
@@ -169,7 +169,7 @@ def apply_lora_sidecar_patches(
         original_modules: dict[str, torch.nn.Module] = {}
         try:
             for patch, patch_weight in patches:
-                ModelPatcher._apply_lora_sidecar_patch(
+                ModelPatcher._apply_model_sidecar_patch(
                     model=model,
                     prefix=prefix,
                     patch=patch,
@@ -187,7 +187,7 @@ def apply_lora_sidecar_patches(
                 ModelPatcher._set_submodule(parent_module, module_name, orig_module)

     @staticmethod
-    def _apply_lora_sidecar_patch(
+    def _apply_model_sidecar_patch(
         model: torch.nn.Module,
         patch: LoRAModelRaw,
         patch_weight: float,
@@ -216,7 +216,7 @@ def _apply_lora_sidecar_patch(
                 model, layer_key[prefix_len:], layer_key_is_flattened=layer_keys_are_flattened
             )

-            ModelPatcher._apply_lora_layer_wrapper_patch(
+            ModelPatcher._apply_model_layer_wrapper_patch(
                 model=model,
                 module_to_patch=module,
                 module_to_patch_key=module_key,
@@ -228,7 +228,7 @@ def _apply_lora_sidecar_patch(

     @staticmethod
     @torch.no_grad()
-    def _apply_lora_layer_wrapper_patch(
+    def _apply_model_layer_wrapper_patch(
         model: torch.nn.Module,
         module_to_patch: torch.nn.Module,
         module_to_patch_key: str,
diff --git a/invokeai/backend/stable_diffusion/extensions/lora.py b/invokeai/backend/stable_diffusion/extensions/lora.py
index 6d5549560b7..27ed0ed7b74 100644
--- a/invokeai/backend/stable_diffusion/extensions/lora.py
+++ b/invokeai/backend/stable_diffusion/extensions/lora.py
@@ -31,7 +31,7 @@ def __init__(
     def patch_unet(self, unet: UNet2DConditionModel, original_weights: OriginalWeightsStorage):
         lora_model = self._node_context.models.load(self._model_id).model
         assert isinstance(lora_model, LoRAModelRaw)
-        ModelPatcher.apply_lora_patch(
+        ModelPatcher.apply_model_patch(
             model=unet,
             prefix="lora_unet_",
             patch=lora_model,
diff --git a/tests/backend/patches/test_lora_patcher.py b/tests/backend/patches/test_lora_patcher.py
index f748145863d..dd486588d92 100644
--- a/tests/backend/patches/test_lora_patcher.py
+++ b/tests/backend/patches/test_lora_patcher.py
@@ -53,7 +53,7 @@ def test_apply_lora_patches(device: str, num_layers: int):
     orig_linear_weight = model.linear_layer_1.weight.data.detach().clone()
     expected_patched_linear_weight = orig_linear_weight + (lora_rank * lora_weight * num_layers)

-    with ModelPatcher.apply_lora_patches(model=model, patches=lora_models, prefix=""):
+    with ModelPatcher.apply_model_patches(model=model, patches=lora_models, prefix=""):
         # After patching, all LoRA layer weights should have been moved back to the cpu.
         for lora, _ in lora_models:
             assert lora.layers["linear_layer_1"].up.device.type == "cpu"
@@ -93,7 +93,7 @@ def test_apply_lora_patches_change_device():

     orig_linear_weight = model.linear_layer_1.weight.data.detach().clone()

-    with ModelPatcher.apply_lora_patches(model=model, patches=[(lora, 0.5)], prefix=""):
+    with ModelPatcher.apply_model_patches(model=model, patches=[(lora, 0.5)], prefix=""):
         # After patching, all LoRA layer weights should have been moved back to the cpu.
         assert lora_layers["linear_layer_1"].up.device.type == "cpu"
         assert lora_layers["linear_layer_1"].down.device.type == "cpu"
@@ -146,7 +146,7 @@ def test_apply_lora_sidecar_patches(device: str, num_layers: int):
     output_before_patch = model(input)

     # Patch the model and run inference during the patch.
-    with ModelPatcher.apply_lora_sidecar_patches(model=model, patches=lora_models, prefix="", dtype=dtype):
+    with ModelPatcher.apply_model_sidecar_patches(model=model, patches=lora_models, prefix="", dtype=dtype):
         output_during_patch = model(input)

     # Run inference after unpatching.
@@ -186,10 +186,10 @@ def test_apply_lora_sidecar_patches_matches_apply_lora_patches(num_layers: int):

     input = torch.randn(1, linear_in_features, device="cpu", dtype=dtype)

-    with ModelPatcher.apply_lora_patches(model=model, patches=lora_models, prefix=""):
+    with ModelPatcher.apply_model_patches(model=model, patches=lora_models, prefix=""):
         output_lora_patches = model(input)

-    with ModelPatcher.apply_lora_sidecar_patches(model=model, patches=lora_models, prefix="", dtype=dtype):
+    with ModelPatcher.apply_model_sidecar_patches(model=model, patches=lora_models, prefix="", dtype=dtype):
         output_lora_sidecar_patches = model(input)

     # Note: We set atol=1e-5 because the test failed occasionally with the default atol=1e-8. Slight numerical
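
A minimal usage sketch of the two renamed context managers, mirroring the call sites in this patch. It assumes `model` is a torch.nn.Module whose layer keys match the given prefix, `input` is a compatible input tensor, and `lora` is an already-loaded LoRAModelRaw (e.g. obtained via the model loader, as in lora.py above); the weight 0.5 and the bfloat16 dtype are illustrative choices, not values mandated by the patch:

    import torch

    from invokeai.backend.patches.model_patcher import ModelPatcher

    # Direct patching: layer weights are modified in place on entry and
    # restored from the stored originals on exit.
    with ModelPatcher.apply_model_patches(model=model, patches=[(lora, 0.5)], prefix="lora_unet_"):
        output_during_patch = model(input)

    # Sidecar patching: target modules are wrapped rather than having their
    # weights edited. Slower at inference time, but agnostic to the model's
    # quantization format.
    with ModelPatcher.apply_model_sidecar_patches(
        model=model, patches=[(lora, 0.5)], prefix="lora_unet_", dtype=torch.bfloat16
    ):
        output_during_sidecar_patch = model(input)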