
Commit

Rename ModelPatcher methods to reflect that they are general model patching methods and are not LoRA-specific.
RyanJDick committed Dec 17, 2024
1 parent c604a09 commit b820862
Showing 9 changed files with 24 additions and 24 deletions.
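In short, each LoRA-specific apply_lora_* entry point on ModelPatcher becomes the corresponding apply_model_* method with unchanged arguments. A minimal call-site migration sketch, assuming a caller that already holds a model and an iterable of (LoRAModelRaw, weight) patches as in the hunks below (unet and lora_models are placeholder names taken from the call sites, not new API):

```python
from invokeai.backend.patches.model_patcher import ModelPatcher

# Before this commit (LoRA-specific name):
#   with ModelPatcher.apply_lora_patches(model=unet, patches=lora_models, prefix="lora_unet_"):
#       ...

# After this commit (general name, same arguments):
with ModelPatcher.apply_model_patches(model=unet, patches=lora_models, prefix="lora_unet_"):
    ...  # run inference here; the original weights are restored when the context exits
```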
4 changes: 2 additions & 2 deletions invokeai/app/invocations/compel.py
@@ -82,7 +82,7 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
# apply all patches while the model is on the target device
text_encoder_info.model_on_device() as (cached_weights, text_encoder),
tokenizer_info as tokenizer,
- ModelPatcher.apply_lora_patches(
+ ModelPatcher.apply_model_patches(
model=text_encoder,
patches=_lora_loader(),
prefix="lora_te_",
@@ -179,7 +179,7 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
# apply all patches while the model is on the target device
text_encoder_info.model_on_device() as (cached_weights, text_encoder),
tokenizer_info as tokenizer,
- ModelPatcher.apply_lora_patches(
+ ModelPatcher.apply_model_patches(
text_encoder,
patches=_lora_loader(),
prefix=lora_prefix,
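The "# apply all patches while the model is on the target device" comment above reflects the ordering used at every call site in this commit: the model-loading context is entered first, then the patcher, so patching runs against weights that are already on the target device. A condensed sketch of that ordering, mirroring the compel.py hunk (the cached_weights keyword is an assumption based on the OriginalWeightsStorage(cached_weights) line in the model_patcher.py hunk further down; it is collapsed out of this hunk):

```python
with (
    # Load/lock the text encoder on its target device first...
    text_encoder_info.model_on_device() as (cached_weights, text_encoder),
    tokenizer_info as tokenizer,
    # ...then apply the (renamed) patches while the weights are already on-device.
    ModelPatcher.apply_model_patches(
        model=text_encoder,
        patches=_lora_loader(),
        prefix="lora_te_",
        cached_weights=cached_weights,  # assumed keyword; used to restore the original weights on exit
    ),
):
    ...  # run the prompt encode while the patches are active
```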
2 changes: 1 addition & 1 deletion invokeai/app/invocations/denoise_latents.py
@@ -1003,7 +1003,7 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
ModelPatcher.apply_freeu(unet, self.unet.freeu_config),
SeamlessExt.static_patch_model(unet, self.unet.seamless_axes), # FIXME
# Apply the LoRA after unet has been moved to its target device for faster patching.
- ModelPatcher.apply_lora_patches(
+ ModelPatcher.apply_model_patches(
model=unet,
patches=_lora_loader(),
prefix="lora_unet_",
4 changes: 2 additions & 2 deletions invokeai/app/invocations/flux_denoise.py
@@ -306,7 +306,7 @@ def _run_diffusion(
if config.format in [ModelFormat.Checkpoint]:
# The model is non-quantized, so we can apply the LoRA weights directly into the model.
exit_stack.enter_context(
- ModelPatcher.apply_lora_patches(
+ ModelPatcher.apply_model_patches(
model=transformer,
patches=self._lora_iterator(context),
prefix=FLUX_LORA_TRANSFORMER_PREFIX,
@@ -321,7 +321,7 @@ def _run_diffusion(
# The model is quantized, so apply the LoRA weights as sidecar layers. This results in slower inference
# than directly patching the weights, but is agnostic to the quantization format.
exit_stack.enter_context(
- ModelPatcher.apply_lora_sidecar_patches(
+ ModelPatcher.apply_model_sidecar_patches(
model=transformer,
patches=self._lora_iterator(context),
prefix=FLUX_LORA_TRANSFORMER_PREFIX,
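For context, the two flux_denoise.py hunks above implement the choice between the two patching strategies these renames cover: direct in-place weight patching for non-quantized checkpoints, and sidecar-layer patching for quantized models. A condensed sketch of that branch using the new names (the else-branch structure and the dtype argument name are assumptions, since those lines are collapsed in the hunks; the surrounding exit_stack, config, and transformer come from the original method):

```python
if config.format in [ModelFormat.Checkpoint]:
    # Non-quantized model: patch the LoRA weights directly into the transformer weights.
    exit_stack.enter_context(
        ModelPatcher.apply_model_patches(
            model=transformer,
            patches=self._lora_iterator(context),
            prefix=FLUX_LORA_TRANSFORMER_PREFIX,
        )
    )
else:
    # Quantized model: attach the LoRA weights as sidecar layers. Slower at inference time,
    # but agnostic to the quantization format.
    exit_stack.enter_context(
        ModelPatcher.apply_model_sidecar_patches(
            model=transformer,
            patches=self._lora_iterator(context),
            prefix=FLUX_LORA_TRANSFORMER_PREFIX,
            dtype=inference_dtype,  # assumed variable name; the dtype keyword is shown in the tests below
        )
    )
```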
2 changes: 1 addition & 1 deletion invokeai/app/invocations/flux_text_encoder.py
@@ -111,7 +111,7 @@ def _clip_encode(self, context: InvocationContext) -> torch.Tensor:
if clip_text_encoder_config.format in [ModelFormat.Diffusers]:
# The model is non-quantized, so we can apply the LoRA weights directly into the model.
exit_stack.enter_context(
- ModelPatcher.apply_lora_patches(
+ ModelPatcher.apply_model_patches(
model=clip_text_encoder,
patches=self._clip_lora_iterator(context),
prefix=FLUX_LORA_CLIP_PREFIX,
2 changes: 1 addition & 1 deletion invokeai/app/invocations/sd3_text_encoder.py
@@ -150,7 +150,7 @@ def _clip_encode(
if clip_text_encoder_config.format in [ModelFormat.Diffusers]:
# The model is non-quantized, so we can apply the LoRA weights directly into the model.
exit_stack.enter_context(
- ModelPatcher.apply_lora_patches(
+ ModelPatcher.apply_model_patches(
model=clip_text_encoder,
patches=self._clip_lora_iterator(context, clip_model),
prefix=FLUX_LORA_CLIP_PREFIX,
@@ -207,7 +207,7 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
with (
ExitStack() as exit_stack,
unet_info as unet,
- ModelPatcher.apply_lora_patches(model=unet, patches=_lora_loader(), prefix="lora_unet_"),
+ ModelPatcher.apply_model_patches(model=unet, patches=_lora_loader(), prefix="lora_unet_"),
):
assert isinstance(unet, UNet2DConditionModel)
latents = latents.to(device=unet.device, dtype=unet.dtype)
20 changes: 10 additions & 10 deletions invokeai/backend/patches/model_patcher.py
@@ -17,7 +17,7 @@ class ModelPatcher:
@staticmethod
@torch.no_grad()
@contextmanager
- def apply_lora_patches(
+ def apply_model_patches(
model: torch.nn.Module,
patches: Iterable[Tuple[LoRAModelRaw, float]],
prefix: str,
@@ -37,7 +37,7 @@ def apply_lora_patches(
original_weights = OriginalWeightsStorage(cached_weights)
try:
for patch, patch_weight in patches:
- ModelPatcher.apply_lora_patch(
+ ModelPatcher.apply_model_patch(
model=model,
prefix=prefix,
patch=patch,
@@ -54,7 +54,7 @@ def apply_lora_patches(

@staticmethod
@torch.no_grad()
- def apply_lora_patch(
+ def apply_model_patch(
model: torch.nn.Module,
prefix: str,
patch: LoRAModelRaw,
@@ -89,7 +89,7 @@ def apply_lora_patch(
model, layer_key[prefix_len:], layer_key_is_flattened=layer_keys_are_flattened
)

- ModelPatcher._apply_lora_layer_patch(
+ ModelPatcher._apply_model_layer_patch(
module_to_patch=module,
module_to_patch_key=module_key,
patch=layer,
@@ -99,7 +99,7 @@ def apply_lora_patch(

@staticmethod
@torch.no_grad()
- def _apply_lora_layer_patch(
+ def _apply_model_layer_patch(
module_to_patch: torch.nn.Module,
module_to_patch_key: str,
patch: BaseLayerPatch,
@@ -146,7 +146,7 @@ def _apply_lora_layer_patch(
@staticmethod
@torch.no_grad()
@contextmanager
- def apply_lora_sidecar_patches(
+ def apply_model_sidecar_patches(
model: torch.nn.Module,
patches: Iterable[Tuple[LoRAModelRaw, float]],
prefix: str,
@@ -169,7 +169,7 @@ def apply_lora_sidecar_patches(
original_modules: dict[str, torch.nn.Module] = {}
try:
for patch, patch_weight in patches:
- ModelPatcher._apply_lora_sidecar_patch(
+ ModelPatcher._apply_model_sidecar_patch(
model=model,
prefix=prefix,
patch=patch,
@@ -187,7 +187,7 @@ def _apply_lora_sidecar_patch(
ModelPatcher._set_submodule(parent_module, module_name, orig_module)

@staticmethod
- def _apply_lora_sidecar_patch(
+ def _apply_model_sidecar_patch(
model: torch.nn.Module,
patch: LoRAModelRaw,
patch_weight: float,
@@ -216,7 +216,7 @@ def _apply_lora_sidecar_patch(
model, layer_key[prefix_len:], layer_key_is_flattened=layer_keys_are_flattened
)

- ModelPatcher._apply_lora_layer_wrapper_patch(
+ ModelPatcher._apply_model_layer_wrapper_patch(
model=model,
module_to_patch=module,
module_to_patch_key=module_key,
@@ -228,7 +228,7 @@ def _apply_lora_sidecar_patch(

@staticmethod
@torch.no_grad()
- def _apply_lora_layer_wrapper_patch(
+ def _apply_model_layer_wrapper_patch(
model: torch.nn.Module,
module_to_patch: torch.nn.Module,
module_to_patch_key: str,
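The model_patcher.py hunks above also make the internal call chain easier to see: each renamed public context manager delegates, per patch, to a renamed helper. A comment-only summary of the structure visible in the diff (descriptions paraphrase the comments in the hunks):

```python
# Call graph of the renamed ModelPatcher methods, as visible in the hunks above:
#
#   apply_model_patches          (context manager; patches weights directly/in place)
#     -> apply_model_patch
#       -> _apply_model_layer_patch
#
#   apply_model_sidecar_patches  (context manager; wraps modules with sidecar layers,
#     -> _apply_model_sidecar_patch        restoring the original modules on exit)
#       -> _apply_model_layer_wrapper_patch
```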
2 changes: 1 addition & 1 deletion invokeai/backend/stable_diffusion/extensions/lora.py
@@ -31,7 +31,7 @@ def __init__(
def patch_unet(self, unet: UNet2DConditionModel, original_weights: OriginalWeightsStorage):
lora_model = self._node_context.models.load(self._model_id).model
assert isinstance(lora_model, LoRAModelRaw)
- ModelPatcher.apply_lora_patch(
+ ModelPatcher.apply_model_patch(
model=unet,
prefix="lora_unet_",
patch=lora_model,
10 changes: 5 additions & 5 deletions tests/backend/patches/test_lora_patcher.py
@@ -53,7 +53,7 @@ def test_apply_lora_patches(device: str, num_layers: int):
orig_linear_weight = model.linear_layer_1.weight.data.detach().clone()
expected_patched_linear_weight = orig_linear_weight + (lora_rank * lora_weight * num_layers)

- with ModelPatcher.apply_lora_patches(model=model, patches=lora_models, prefix=""):
+ with ModelPatcher.apply_model_patches(model=model, patches=lora_models, prefix=""):
# After patching, all LoRA layer weights should have been moved back to the cpu.
for lora, _ in lora_models:
assert lora.layers["linear_layer_1"].up.device.type == "cpu"
@@ -93,7 +93,7 @@ def test_apply_lora_patches_change_device():

orig_linear_weight = model.linear_layer_1.weight.data.detach().clone()

- with ModelPatcher.apply_lora_patches(model=model, patches=[(lora, 0.5)], prefix=""):
+ with ModelPatcher.apply_model_patches(model=model, patches=[(lora, 0.5)], prefix=""):
# After patching, all LoRA layer weights should have been moved back to the cpu.
assert lora_layers["linear_layer_1"].up.device.type == "cpu"
assert lora_layers["linear_layer_1"].down.device.type == "cpu"
@@ -146,7 +146,7 @@ def test_apply_lora_sidecar_patches(device: str, num_layers: int):
output_before_patch = model(input)

# Patch the model and run inference during the patch.
- with ModelPatcher.apply_lora_sidecar_patches(model=model, patches=lora_models, prefix="", dtype=dtype):
+ with ModelPatcher.apply_model_sidecar_patches(model=model, patches=lora_models, prefix="", dtype=dtype):
output_during_patch = model(input)

# Run inference after unpatching.
@@ -186,10 +186,10 @@ def test_apply_lora_sidecar_patches_matches_apply_lora_patches(num_layers: int):

input = torch.randn(1, linear_in_features, device="cpu", dtype=dtype)

- with ModelPatcher.apply_lora_patches(model=model, patches=lora_models, prefix=""):
+ with ModelPatcher.apply_model_patches(model=model, patches=lora_models, prefix=""):
output_lora_patches = model(input)

- with ModelPatcher.apply_lora_sidecar_patches(model=model, patches=lora_models, prefix="", dtype=dtype):
+ with ModelPatcher.apply_model_sidecar_patches(model=model, patches=lora_models, prefix="", dtype=dtype):
output_lora_sidecar_patches = model(input)

# Note: We set atol=1e-5 because the test failed occasionally with the default atol=1e-8. Slight numerical