LLM: GPU QLoRA update to bf16 to accelerate gradient checkpointing (intel-analytics#9499)

* update to bf16 to accelerate gradient checkpointing

* add utils and fix unit test
rnwang04 authored Nov 21, 2023
1 parent daa8e9a commit dfd0871
Showing 4 changed files with 27 additions and 11 deletions.
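
Taken together, the diffs below add a cast_lora_weight helper to bigdl.llm.transformers.qlora and call it from the GPU QLoRA examples so the LoRA weights run in bf16. The sketch below condenses that flow; it is not part of the commit. The model id, LoRA hyperparameters, and loading arguments are placeholders abbreviated from the example scripts.

import torch
import intel_extension_for_pytorch as ipex  # needed for the xpu device
from peft import LoraConfig
from bigdl.llm.transformers import AutoModelForCausalLM
from bigdl.llm.transformers.qlora import (
    get_peft_model,
    prepare_model_for_kbit_training,
    cast_lora_weight,  # helper added in this commit
)

# Load a 4-bit base model roughly as the example scripts do (placeholder model id).
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    load_in_low_bit="nf4",
)
model = model.to("xpu")
model.gradient_checkpointing_enable()
model = prepare_model_for_kbit_training(model)

# Attach LoRA adapters (placeholder hyperparameters).
config = LoraConfig(
    r=8, lora_alpha=32, lora_dropout=0.05,
    target_modules=["q_proj", "k_proj", "v_proj"],
    bias="none", task_type="CAUSAL_LM",
)
model = get_peft_model(model, config)

# The point of the commit: cast the LoRA weights to bf16 so the recomputation
# done by gradient checkpointing also runs in bf16 on the GPU.
cast_lora_weight(model, torch.bfloat16)
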
@@ -51,7 +51,8 @@
from bigdl.llm.transformers import AutoModelForCausalLM

# import them from bigdl.llm.transformers.qlora to get a BigDL-LLM compatible Peft model
from bigdl.llm.transformers.qlora import get_peft_model, prepare_model_for_kbit_training
from bigdl.llm.transformers.qlora import get_peft_model, prepare_model_for_kbit_training,\
    cast_lora_weight

def get_int_from_env(env_keys, default):
"""Returns the first positive env value found in the `env_keys` list or the default."""
@@ -76,6 +77,7 @@ def train(
data_path: str = "yahma/alpaca-cleaned",
output_dir: str = "./bigdl-qlora-alpaca",
# training hyperparams
bf16: bool = True, # default to bf16
batch_size: int = 128,
micro_batch_size: int = 2, # default to be 2, limited by GPU memory
num_epochs: int = 3,
@@ -301,6 +303,9 @@ def generate_and_tokenize_prompt(data_point):
# model.is_parallelizable = True
# model.model_parallel = True

if bf16:
    cast_lora_weight(model, torch.bfloat16)

trainer = transformers.Trainer(
    model=model,
    train_dataset=train_data,
@@ -72,16 +72,11 @@
0
].self_attn.q_proj.weight

assert torch.allclose(first_weight_old, first_weight)

# merge weights - new merging method from peft
lora_model = lora_model.merge_and_unload()

lora_model.train(False)

# did we do anything?
assert not torch.allclose(first_weight_old, first_weight)

lora_model_sd = lora_model.state_dict()
deloreanized_sd = {
k.replace("base_model.model.", ""): v
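
For reference, a minimal merge-and-export flow with stock peft looks like the sketch below; the paths and base model id are placeholders, and this is a generic flow rather than the exact low-bit export path used by the script above.

import torch
from transformers import AutoModelForCausalLM
from peft import PeftModel

# Reload the base model and the trained adapters (placeholder paths).
base_model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf", torch_dtype=torch.bfloat16
)
lora_model = PeftModel.from_pretrained(base_model, "./bigdl-qlora-alpaca")

# Fold the LoRA deltas into the base weights and save a standalone checkpoint.
merged = lora_model.merge_and_unload()
merged.train(False)
merged.save_pretrained("./merged-model")
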
7 changes: 4 additions & 3 deletions python/llm/example/GPU/QLoRA-FineTuning/qlora_finetuning.py
@@ -19,10 +19,10 @@

import transformers
from transformers import LlamaTokenizer

from peft import LoraConfig
import intel_extension_for_pytorch as ipex
from bigdl.llm.transformers.qlora import get_peft_model, prepare_model_for_kbit_training
from bigdl.llm.transformers.qlora import get_peft_model, prepare_model_for_kbit_training, \
    cast_lora_weight
from bigdl.llm.transformers import AutoModelForCausalLM
from datasets import load_dataset
import argparse
@@ -61,6 +61,8 @@
task_type="CAUSAL_LM"
)
model = get_peft_model(model, config)
cast_lora_weight(model, torch.bfloat16)

tokenizer.pad_token_id = 0
tokenizer.padding_side = "left"
trainer = transformers.Trainer(
@@ -73,7 +75,6 @@
max_steps=200,
learning_rate=2e-5,
save_steps=100,
# fp16=True,
bf16=True, # bf16 is more stable in training
logging_steps=20,
output_dir="outputs",
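
The Trainer wiring around the change above looks roughly like the sketch below. Only the values visible in the diff (max_steps, learning_rate, save_steps, bf16, logging_steps, output_dir) are taken from the example; the other arguments, and the model/train_data/tokenizer objects, are illustrative assumptions.

import transformers

trainer = transformers.Trainer(
    model=model,               # PEFT model after cast_lora_weight(model, torch.bfloat16)
    train_dataset=train_data,  # tokenized dataset prepared earlier in the script
    args=transformers.TrainingArguments(
        per_device_train_batch_size=4,   # illustrative
        gradient_accumulation_steps=1,   # illustrative
        warmup_steps=20,                 # illustrative
        max_steps=200,
        learning_rate=2e-5,
        save_steps=100,
        bf16=True,                       # bf16 is more stable in training
        logging_steps=20,
        output_dir="outputs",
        gradient_checkpointing=True,     # illustrative; pairs with the bf16 cast
    ),
    data_collator=transformers.DataCollatorForSeq2Seq(
        tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
    ),
)
trainer.train()
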
19 changes: 17 additions & 2 deletions python/llm/src/bigdl/llm/transformers/qlora.py
@@ -87,15 +87,18 @@ def __init__(

    def forward(self, x: torch.Tensor):
        autocast_dtype = get_autocast_dtype(x)
        if autocast_dtype is not None:
        if x.device.type == "xpu":
            # force to use bf16 on gpu
            x = x.to(torch.bfloat16)
        elif autocast_dtype is not None:
            x = x.to(autocast_dtype)
        result = super().forward(x)

        if self.disable_adapters or self.active_adapter not in self.lora_A.keys():
            return result
        elif self.r[self.active_adapter] > 0:
            result = result.clone()
            if autocast_dtype is None:
            if autocast_dtype is None and x.device.type == "cpu":
                expected_dtype = result.dtype
                x = x.to(self.lora_A[self.active_adapter].weight.dtype)
                output = (
@@ -357,3 +360,15 @@ def _setup_devices(self) -> "torch.device":
# patch transformers for xpu DDP training
from transformers import TrainingArguments
TrainingArguments._setup_devices = _setup_devices


def cast_lora_weight(model, dtype=torch.bfloat16):
    for name, module in model.named_modules():
        if isinstance(module, LoraLayer):
            # cast LoRA adapter weights to the training dtype (bf16 by default)
            module = module.to(dtype)
        if 'norm' in name:
            # keep normalization layers in fp32 for numerical stability
            module = module.to(torch.float32)
        if 'lm_head' in name or 'embed_tokens' in name:
            if hasattr(module, 'weight'):
                if module.weight.dtype == torch.float32:
                    # cast fp32 output/embedding weights down to the training dtype
                    module = module.to(dtype)
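
A quick way to sanity-check the dtype layout that cast_lora_weight produces (LoRA adapters in bf16, normalization layers in fp32) is a small helper like the one below; summarize_dtypes and the model object are illustrative, not part of this commit.

import torch

def summarize_dtypes(model):
    # Record one representative parameter dtype per keyword of interest.
    seen = {}
    for name, param in model.named_parameters():
        for key in ("lora_A", "lora_B", "norm", "lm_head", "embed_tokens"):
            if key in name and key not in seen:
                seen[key] = param.dtype
    return seen

cast_lora_weight(model, torch.bfloat16)
print(summarize_dtypes(model))
# expected to show something like bf16 for lora_A/lora_B and fp32 for norm layers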
