add boft support in stable-diffusion #1295

Open · wants to merge 4 commits into base: main
Changes from 3 commits
4 changes: 2 additions & 2 deletions Makefile
@@ -80,7 +80,7 @@ slow_tests_custom_file_input: test_installs
 # Run single-card non-regression tests
 slow_tests_1x: test_installs
 	python -m pytest tests/test_examples.py -v -s -k "single_card"
-	python -m pip install peft==0.10.0
+	python -m pip install peft==0.12.0
 	python -m pytest tests/test_peft_inference.py
 	python -m pytest tests/test_pipeline.py

@@ -96,7 +96,7 @@ slow_tests_deepspeed: test_installs
 slow_tests_diffusers: test_installs
 	python -m pytest tests/test_diffusers.py -v -s -k "test_no_"
 	python -m pytest tests/test_diffusers.py -v -s -k "test_textual_inversion"
-	python -m pip install peft==0.7.0
+	python -m pip install peft==0.12.0
 	python -m pytest tests/test_diffusers.py -v -s -k "test_train_text_to_image_"
 	python -m pytest tests/test_diffusers.py -v -s -k "test_train_controlnet"
 	python -m pytest tests/test_diffusers.py -v -s -k "test_deterministic_image_generation"
41 changes: 36 additions & 5 deletions examples/stable-diffusion/text_to_image_generation.py
@@ -464,19 +464,50 @@ def main():
         )
 
         if args.unet_adapter_name_or_path is not None:
-            from peft import PeftModel
+            from peft import PeftModel, tuners
+            from peft.utils import PeftType
+
+            from optimum.habana.peft.layer import GaudiBoftGetDeltaWeight
+
+            tuners.boft.layer.Linear.get_delta_weight = GaudiBoftGetDeltaWeight
+            tuners.boft.layer.Conv2d.get_delta_weight = GaudiBoftGetDeltaWeight
+            tuners.boft.layer._FBD_CUDA = False
 
             pipeline.unet = PeftModel.from_pretrained(pipeline.unet, args.unet_adapter_name_or_path)
-            pipeline.unet = pipeline.unet.merge_and_unload()
+            if pipeline.unet.peft_type in [PeftType.OFT, PeftType.BOFT]:
+                # Workaround for the torch.inverse issue in Synapse AI 1.17 for oft and boft
+                if args.bf16:
+                    pipeline.unet = pipeline.unet.to(torch.float32)
+                pipeline.unet = pipeline.unet.merge_and_unload()
+                if args.bf16:
+                    pipeline.unet = pipeline.unet.to(torch.bfloat16)
+            else:
+                with torch.autocast(device_type="hpu", dtype=torch.bfloat16, enabled=args.bf16):
+                    pipeline.unet = pipeline.unet.merge_and_unload()
 
         if args.text_encoder_adapter_name_or_path is not None:
-            from peft import PeftModel
+            from peft import PeftModel, tuners
+            from peft.utils import PeftType
+
+            from optimum.habana.peft.layer import GaudiBoftGetDeltaWeight
+
+            tuners.boft.layer.Linear.get_delta_weight = GaudiBoftGetDeltaWeight
+            tuners.boft.layer.Conv2d.get_delta_weight = GaudiBoftGetDeltaWeight
+            tuners.boft.layer._FBD_CUDA = False
 
             pipeline.text_encoder = PeftModel.from_pretrained(
                 pipeline.text_encoder, args.text_encoder_adapter_name_or_path
             )
-            pipeline.text_encoder = pipeline.text_encoder.merge_and_unload()
-
+            if pipeline.text_encoder.peft_type in [PeftType.OFT, PeftType.BOFT]:
+                # Workaround for the torch.inverse issue in Synapse AI 1.17 for oft and boft
+                if args.bf16:
+                    pipeline.text_encoder = pipeline.text_encoder.to(torch.float32)
+                pipeline.text_encoder = pipeline.text_encoder.merge_and_unload()
+                if args.bf16:
+                    pipeline.text_encoder = pipeline.text_encoder.to(torch.bfloat16)
+            else:
+                with torch.autocast(device_type="hpu", dtype=torch.bfloat16, enabled=args.bf16):
+                    pipeline.text_encoder = pipeline.text_encoder.merge_and_unload()
     else:
         # SD LDM3D use-case
         from optimum.habana.diffusers import GaudiStableDiffusionLDM3DPipeline as GaudiStableDiffusionPipeline
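Distilled from the hunk above, here is a minimal, hedged sketch of the merge workaround; the `merge_adapter` helper and its arguments are illustrative only, not part of the PR:

```python
import torch
from peft import PeftModel
from peft.utils import PeftType


def merge_adapter(module, adapter_path, use_bf16):
    # Illustrative helper (not in the PR): `module` stands in for
    # pipeline.unet or pipeline.text_encoder.
    module = PeftModel.from_pretrained(module, adapter_path)
    if module.peft_type in [PeftType.OFT, PeftType.BOFT]:
        # Merge the orthogonal weights in fp32 to sidestep the Synapse AI 1.17
        # torch.inverse issue, then cast back to bf16.
        if use_bf16:
            module = module.to(torch.float32)
        module = module.merge_and_unload()
        if use_bf16:
            module = module.to(torch.bfloat16)
    else:
        with torch.autocast(device_type="hpu", dtype=torch.bfloat16, enabled=use_bf16):
            module = module.merge_and_unload()
    return module
```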
4 changes: 2 additions & 2 deletions examples/stable-diffusion/training/README.md
@@ -355,15 +355,15 @@ python ../../gaudi_spawn.py --world_size 8 --use_mpi train_dreambooth.py \
 	lora --unet_r 8 --unet_alpha 8
 
 ```
-Similar command could be applied to loha, lokr, oft.
+A similar command can be applied to loha, lokr, oft, and boft.
 You can check each adapter's specific arguments with `--help`; for example, the following command shows the oft-specific arguments:
 
 ```bash
 python3 train_dreambooth.py oft --help
 
 ```
 
-**___Note: oft could not work with hpu graphs mode. since "torch.inverse" need to fallback to cpu.
-there's error like "cpu fallback is not supported during hpu graph capturing"___**
+**___Note: boft/oft cannot work with HPU graphs mode, since `torch.inverse` and `torch.linalg.solve` need to fall back to the CPU,
+which raises an error like "cpu fallback is not supported during hpu graph capturing".___**


2 changes: 1 addition & 1 deletion examples/stable-diffusion/training/requirements.txt
@@ -1,2 +1,2 @@
 imagesize
-peft == 0.10.0
+peft == 0.12.0
75 changes: 72 additions & 3 deletions examples/stable-diffusion/training/train_dreambooth.py
@@ -52,7 +52,7 @@
 from diffusers.utils.torch_utils import is_compiled_module
 from habana_frameworks.torch.hpu import memory_stats
 from huggingface_hub import HfApi
-from peft import LoHaConfig, LoKrConfig, LoraConfig, OFTConfig, get_peft_model
+from peft import BOFTConfig, LoHaConfig, LoKrConfig, LoraConfig, OFTConfig, get_peft_model, tuners
 from PIL import Image
 from torch.utils.data import Dataset
 from torchvision import transforms
@@ -108,7 +108,9 @@ def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: st
     raise ValueError(f"{model_class} is not supported.")
 
 
-def create_unet_adapter_config(args: argparse.Namespace) -> Union[LoraConfig, LoHaConfig, LoKrConfig, OFTConfig]:
+def create_unet_adapter_config(
+    args: argparse.Namespace,
+) -> Union[LoraConfig, LoHaConfig, LoKrConfig, OFTConfig, BOFTConfig]:
     if args.adapter == "full":
         raise ValueError("Cannot create unet adapter config for full parameter")
 
@@ -152,6 +154,21 @@ def create_unet_adapter_config(args: argparse.Namespace) -> Union[LoraConfig, Lo
             coft=args.unet_use_coft,
             eps=args.unet_eps,
         )
+    elif args.adapter == "boft":
+        config = BOFTConfig(
+            boft_block_size=args.unet_block_size,
+            boft_block_num=args.unet_block_num,
+            boft_n_butterfly_factor=args.unet_n_butterfly_factor,
+            target_modules=UNET_TARGET_MODULES,
+            boft_dropout=args.unet_dropout,
+            bias=args.unet_bias,
+        )
+        from optimum.habana.peft.layer import GaudiBoftConv2dForward, GaudiBoftLinearForward
+
+        tuners.boft.layer.Linear.forward = GaudiBoftLinearForward
+        tuners.boft.layer.Conv2d.forward = GaudiBoftConv2dForward
+        tuners.boft.layer._FBD_CUDA = False
+
     else:
         raise ValueError(f"Unknown adapter type {args.adapter}")

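For reference, a minimal standalone sketch of applying a `BOFTConfig` like the one built above; the toy module and target name are assumptions, while the real script targets `UNET_TARGET_MODULES`:

```python
import torch.nn as nn
from peft import BOFTConfig, get_peft_model

# Toy module: in_features must be divisible by boft_block_size.
toy = nn.Sequential(nn.Linear(16, 16))

config = BOFTConfig(
    boft_block_size=8,          # mirrors the --unet_block_size default below
    boft_block_num=0,           # 0: derive the block count from block_size
    boft_n_butterfly_factor=1,  # mirrors the --unet_n_butterfly_factor default below
    target_modules=["0"],       # toy target; the script uses UNET_TARGET_MODULES
    boft_dropout=0.1,
    bias="boft_only",
)
model = get_peft_model(toy, config)
model.print_trainable_parameters()
```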
@@ -160,7 +177,7 @@ def create_unet_adapter_config(args: argparse.Namespace) -> Union[LoraConfig, Lo
 
 def create_text_encoder_adapter_config(
     args: argparse.Namespace,
-) -> Union[LoraConfig, LoHaConfig, LoKrConfig, OFTConfig]:
+) -> Union[LoraConfig, LoHaConfig, LoKrConfig, OFTConfig, BOFTConfig]:
     if args.adapter == "full":
         raise ValueError("Cannot create text_encoder adapter config for full parameter")

@@ -202,6 +219,20 @@ def create_text_encoder_adapter_config(
             coft=args.te_use_coft,
             eps=args.te_eps,
         )
+    elif args.adapter == "boft":
+        config = BOFTConfig(
+            boft_block_size=args.te_block_size,
+            boft_block_num=args.te_block_num,
+            boft_n_butterfly_factor=args.te_n_butterfly_factor,
+            target_modules=TEXT_ENCODER_TARGET_MODULES,
+            boft_dropout=args.te_dropout,
+            bias=args.te_bias,
+        )
+        from optimum.habana.peft.layer import GaudiBoftConv2dForward, GaudiBoftLinearForward
+
+        tuners.boft.layer.Linear.forward = GaudiBoftLinearForward
+        tuners.boft.layer.Conv2d.forward = GaudiBoftConv2dForward
+        tuners.boft.layer._FBD_CUDA = False
     else:
         raise ValueError(f"Unknown adapter type {args.adapter}")

@@ -632,6 +663,44 @@ def parse_args(input_args=None):
         help="The control strength of COFT for text_encoder, only used if `train_text_encoder` is True",
     )
 
+    # boft adapter
+    boft = subparsers.add_parser("boft", help="Use Boft adapter")
+    boft.add_argument("--unet_block_size", type=int, default=8, help="Boft block_size for unet")
+    boft.add_argument("--unet_block_num", type=int, default=0, help="Boft block_num for unet")
+    boft.add_argument("--unet_n_butterfly_factor", type=int, default=1, help="Boft n_butterfly_factor for unet")
+    boft.add_argument("--unet_dropout", type=float, default=0.1, help="Boft dropout for unet")
+    boft.add_argument("--unet_bias", type=str, default="boft_only", help="Boft bias for unet")
+    boft.add_argument(
+        "--te_block_size",
+        type=int,
+        default=8,
+        help="Boft block_size for text_encoder, only used if `train_text_encoder` is True",
+    )
+    boft.add_argument(
+        "--te_block_num",
+        type=int,
+        default=0,
+        help="Boft block_num for text_encoder, only used if `train_text_encoder` is True",
+    )
+    boft.add_argument(
+        "--te_n_butterfly_factor",
+        type=int,
+        default=1,
+        help="Boft n_butterfly_factor for text_encoder, only used if `train_text_encoder` is True",
+    )
+    boft.add_argument(
+        "--te_dropout",
+        type=float,
+        default=0.1,
+        help="Boft dropout for text_encoder, only used if `train_text_encoder` is True",
+    )
+    boft.add_argument(
+        "--te_bias",
+        type=str,
+        default="boft_only",
+        help="Boft bias for text_encoder, only used if `train_text_encoder` is True",
+    )

     if input_args is not None:
         args = parser.parse_args(input_args)
     else:
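Since each adapter is exposed as an argparse subcommand, the boft flags are only valid after the `boft` subcommand. A hedged toy reconstruction of that wiring (the real script wires up many more options):

```python
import argparse

parser = argparse.ArgumentParser()
subparsers = parser.add_subparsers(dest="adapter")  # assumed dest; the script reads args.adapter
boft = subparsers.add_parser("boft", help="Use Boft adapter")
boft.add_argument("--unet_block_size", type=int, default=8)
boft.add_argument("--unet_n_butterfly_factor", type=int, default=1)

# Flags placed after the subcommand are routed to the boft subparser.
args = parser.parse_args(["boft", "--unet_block_size", "4"])
assert args.adapter == "boft"
assert args.unet_block_size == 4
```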
3 changes: 3 additions & 0 deletions optimum/habana/peft/__init__.py
@@ -2,6 +2,9 @@
     GaudiAdaloraLayerSVDLinearForward,
     GaudiAdaptedAttention_getattr,
     GaudiAdaptedAttentionPreAttnForward,
+    GaudiBoftConv2dForward,
+    GaudiBoftGetDeltaWeight,
+    GaudiBoftLinearForward,
     GaudiPolyLayerLinearForward,
 )
 from .peft_model import gaudi_generate, gaudi_prepare_inputs_for_generation
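These exports support the monkey-patch pattern used throughout the PR; consolidated here as a sketch assembled from the hunks above (the attribute names come from the diff itself, not a separate public API):

```python
from peft import tuners

from optimum.habana.peft.layer import (
    GaudiBoftConv2dForward,
    GaudiBoftGetDeltaWeight,
    GaudiBoftLinearForward,
)

# Route peft's BOFT layers through the Gaudi-friendly implementations
# and disable the CUDA fast block-diagonal extension.
tuners.boft.layer.Linear.forward = GaudiBoftLinearForward
tuners.boft.layer.Conv2d.forward = GaudiBoftConv2dForward
tuners.boft.layer.Linear.get_delta_weight = GaudiBoftGetDeltaWeight
tuners.boft.layer.Conv2d.get_delta_weight = GaudiBoftGetDeltaWeight
tuners.boft.layer._FBD_CUDA = False
```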
171 changes: 171 additions & 0 deletions optimum/habana/peft/layer.py
@@ -217,3 +217,174 @@ def GaudiAdaptedAttention_getattr(self, name: str):
     # This is necessary as e.g. causal models have various methods that we
     # don't want to re-implement here.
     return getattr(self.model, name)
+
+
+def GaudiBoftConv2dForward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
+    """
+    Copied from Conv2d.forward: https://github.com/huggingface/peft/blob/v0.12.0/src/peft/tuners/boft/layer.py#L899
+    The only differences are:
+    - torch.block_diag operates on the CPU, otherwise lazy mode will hang
+    - the fbd_cuda_available logic is removed
+    """
+    previous_dtype = x.dtype
+
+    if self.disable_adapters:
+        if self.merged:
+            self.unmerge()
+        result = self.base_layer(x, *args, **kwargs)
+    elif self.merged:
+        result = self.base_layer(x, *args, **kwargs)
+    else:
+        boft_rotation = torch.eye(
+            self.in_features * self.base_layer.kernel_size[0] * self.base_layer.kernel_size[0],
+            device=x.device,
+            dtype=x.dtype,
+        )
+        boft_scale = torch.ones((1, int(self.out_features)), device=x.device, dtype=x.dtype)
+
+        for active_adapter in self.active_adapters:
+            if active_adapter not in self.boft_R.keys():
+                continue
+            boft_R = self.boft_R[active_adapter]
+            boft_s = self.boft_s[active_adapter]
+            dropout = self.boft_dropout[active_adapter]
+
+            N, D, H, _ = boft_R.shape
+            boft_R = boft_R.view(N * D, H, H)
+            orth_rotate_butterfly = self.cayley_batch(boft_R)
+            orth_rotate_butterfly = orth_rotate_butterfly.view(N, D, H, H)
+            orth_rotate_butterfly = dropout(orth_rotate_butterfly)
+            orth_rotate_butterfly = orth_rotate_butterfly.squeeze(0).cpu()
+            block_diagonal_butterfly = torch.block_diag(*torch.unbind(orth_rotate_butterfly))
+            block_diagonal_butterfly = block_diagonal_butterfly.unsqueeze(0)
+
+            boft_P = self.boft_P.to(x)
+            block_diagonal_butterfly = block_diagonal_butterfly.to(x)
+            butterfly_oft_mat_batch = torch.bmm(block_diagonal_butterfly, boft_P.permute(0, 2, 1))
+            butterfly_oft_mat_batch = torch.bmm(boft_P, butterfly_oft_mat_batch)
+            butterfly_oft_mat = butterfly_oft_mat_batch[0]
+
+            for i in range(1, butterfly_oft_mat_batch.shape[0]):
+                butterfly_oft_mat = butterfly_oft_mat_batch[i] @ butterfly_oft_mat
+
+            boft_rotation = butterfly_oft_mat @ boft_rotation
+            boft_scale = boft_s * boft_scale
+
+        x = x.to(self.base_layer.weight.data.dtype)
+
+        orig_weight = self.base_layer.weight.data
+        orig_weight = orig_weight.view(
+            self.in_features * self.base_layer.kernel_size[0] * self.base_layer.kernel_size[0],
+            self.out_features,
+        )
+        rotated_weight = torch.mm(boft_rotation, orig_weight)
+
+        scaled_rotated_weight = rotated_weight * boft_scale
+
+        scaled_rotated_weight = scaled_rotated_weight.view(
+            self.out_features, self.in_features, self.base_layer.kernel_size[0], self.base_layer.kernel_size[0]
+        )
+        result = F.conv2d(
+            input=x,
+            weight=scaled_rotated_weight,
+            bias=self.base_layer.bias,
+            padding=self.base_layer.padding[0],
+            stride=self.base_layer.stride[0],
+        )
+
+    result = result.to(previous_dtype)
+    return result
+
+
+def GaudiBoftLinearForward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
+    """
+    Copied from Linear.forward: https://github.com/huggingface/peft/blob/v0.12.0/src/peft/tuners/boft/layer.py#L587
+    The only differences are:
+    - torch.block_diag operates on the CPU, otherwise lazy mode will hang
+    - the fbd_cuda_available logic is removed
+    """
+    previous_dtype = x.dtype
+
+    if self.disable_adapters:
+        if self.merged:
+            self.unmerge()
+        result = self.base_layer(x, *args, **kwargs)
+    elif self.merged:
+        result = self.base_layer(x, *args, **kwargs)
+    else:
+        boft_rotation = torch.eye(self.in_features, device=x.device)
+        boft_scale = torch.ones((int(self.out_features), 1), device=x.device)
+
+        for active_adapter in self.active_adapters:
+            if active_adapter not in self.boft_R.keys():
+                continue
+            boft_R = self.boft_R[active_adapter]
+            boft_s = self.boft_s[active_adapter]
+            dropout = self.boft_dropout[active_adapter]
+
+            N, D, H, _ = boft_R.shape
+            boft_R = boft_R.view(N * D, H, H)
+            orth_rotate_butterfly = self.cayley_batch(boft_R)
+            orth_rotate_butterfly = orth_rotate_butterfly.view(N, D, H, H)
+            orth_rotate_butterfly = dropout(orth_rotate_butterfly)
+            orth_rotate_butterfly = orth_rotate_butterfly.squeeze(0).cpu()
+            block_diagonal_butterfly = torch.block_diag(*torch.unbind(orth_rotate_butterfly))
+            block_diagonal_butterfly = block_diagonal_butterfly.unsqueeze(0)
+
+            # The BOFT author's cayley_batch, dropout and FastBlockDiag ONLY return fp32 outputs.
+            boft_P = self.boft_P.to(x)
+            block_diagonal_butterfly = block_diagonal_butterfly.to(x)
+            butterfly_oft_mat_batch = torch.bmm(block_diagonal_butterfly, boft_P.permute(0, 2, 1))
+            butterfly_oft_mat_batch = torch.bmm(boft_P, butterfly_oft_mat_batch)
+            butterfly_oft_mat = butterfly_oft_mat_batch[0]
+
+            for i in range(1, butterfly_oft_mat_batch.shape[0]):
+                butterfly_oft_mat = butterfly_oft_mat_batch[i] @ butterfly_oft_mat
+
+            boft_rotation = butterfly_oft_mat @ boft_rotation
+            boft_scale = boft_s * boft_scale
+
+        x = x.to(self.get_base_layer().weight.data.dtype)
+
+        orig_weight = self.get_base_layer().weight.data
+        orig_weight = torch.transpose(orig_weight, 0, 1)
+        rotated_weight = torch.mm(boft_rotation, orig_weight)
+        rotated_weight = torch.transpose(rotated_weight, 0, 1)
+
+        scaled_rotated_weight = rotated_weight * boft_scale
+
+        result = F.linear(input=x, weight=scaled_rotated_weight, bias=self.base_layer.bias)
+
+    result = result.to(previous_dtype)
+    return result
+
+
+def GaudiBoftGetDeltaWeight(self, adapter) -> tuple[torch.Tensor, torch.Tensor]:
+    """
+    Copied from Linear.get_delta_weight: https://github.com/huggingface/peft/blob/v0.12.0/src/peft/tuners/boft/layer.py#L555
+    The only differences are:
+    - torch.block_diag operates on the CPU, otherwise lazy mode will hang
+    - the fbd_cuda_available logic is removed
+    """
+
+    boft_R = self.boft_R[adapter]
+    boft_s = self.boft_s[adapter]
+
+    N, D, H, _ = boft_R.shape
+    boft_R = boft_R.view(N * D, H, H)
+    orth_rotate_butterfly = self.cayley_batch(boft_R)
+    orth_rotate_butterfly = orth_rotate_butterfly.view(N, D, H, H)
+    orth_rotate_butterfly = orth_rotate_butterfly.squeeze(0).cpu()
+    block_diagonal_butterfly = torch.block_diag(*torch.unbind(orth_rotate_butterfly))
+    block_diagonal_butterfly = block_diagonal_butterfly.unsqueeze(0)
+
+    boft_P = self.boft_P
+    block_diagonal_butterfly = block_diagonal_butterfly.to(boft_P)
+    butterfly_oft_mat_batch = torch.bmm(block_diagonal_butterfly, boft_P.permute(0, 2, 1))
+    butterfly_oft_mat_batch = torch.bmm(boft_P, butterfly_oft_mat_batch)
+    butterfly_oft_mat = butterfly_oft_mat_batch[0]
+
+    for i in range(1, butterfly_oft_mat_batch.shape[0]):
+        butterfly_oft_mat = butterfly_oft_mat_batch[i] @ butterfly_oft_mat
+
+    return butterfly_oft_mat, boft_s
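For intuition: `cayley_batch` maps each trainable block to an orthogonal matrix via the Cayley transform, and the loops above compose the butterfly factors by matrix product, which preserves orthogonality. A self-contained sketch of both facts (illustrative only, not PEFT's exact implementation):

```python
import torch


def cayley(data):
    # Cayley transform: for a skew-symmetric S, Q = (I - S)^(-1) @ (I + S)
    # is orthogonal.
    n = data.shape[-1]
    skew = 0.5 * (data - data.transpose(-1, -2))  # project onto skew-symmetric matrices
    eye = torch.eye(n, dtype=data.dtype)
    return torch.linalg.solve(eye - skew, eye + skew)


# Each butterfly factor is orthogonal, so the composed rotation built by
# the `butterfly_oft_mat_batch[i] @ butterfly_oft_mat` loops is too.
factors = [cayley(torch.randn(8, 8)) for _ in range(3)]
rotation = factors[0]
for i in range(1, len(factors)):
    rotation = factors[i] @ rotation

assert torch.allclose(rotation @ rotation.T, torch.eye(8), atol=1e-4)
```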