LLM: fix unlora module in qlora finetune (intel#9621)
* fix unlora module

* split train and inference
rnwang04 authored Dec 7, 2023
1 parent 820e575 commit 2618be1
Showing 2 changed files with 21 additions and 2 deletions.
python/llm/src/bigdl/llm/transformers/low_bit_linear.py (19 additions, 2 deletions)

@@ -50,6 +50,7 @@
 from operator import mul
 from functools import reduce
 from bigdl.llm.transformers.xpu_customize_fwd import custom_fwd, custom_bwd
+from bigdl.llm.transformers.utils import get_autocast_dtype

 T = TypeVar("T", bound="torch.nn.Module")

@@ -433,8 +434,17 @@ def __init__(self, input_features, output_features, qtype, bias=True,
         self.qtype = qtype
         self.conver_to_half = conver_to_half
         self.mp_group = mp_group
+        self.compute_dtype = None  # only for training

     def forward(self, x: torch.Tensor):
+        if self.training:
+            # below logic is only for training
+            autocast_dtype = get_autocast_dtype(x)
+            if self.compute_dtype is not None and x.device.type == "xpu":
+                x = x.to(self.compute_dtype)  # solve GC issue for unlora module
+            elif autocast_dtype is not None:
+                x = x.to(autocast_dtype)
+
         if self.bias is not None and self.bias.dtype != x.dtype:
             self.bias.data = self.bias.data.to(x.dtype)

@@ -457,9 +467,16 @@ def forward(self, x: torch.Tensor):
                 x_2d = x_2d.contiguous()

             input_seq_size = x_shape[1]
-            if self.training and x_2d.requires_grad:
-                result = MatMulLowBit.apply(x_2d, self.weight, input_seq_size)
+            if self.training:
+                # training path
+                if x_2d.requires_grad:
+                    result = MatMulLowBit.apply(x_2d, self.weight, input_seq_size)
+                else:
+                    result = linear_q4_0.forward_new(x_2d, self.weight.data,
+                                                     self.weight.qtype,
+                                                     input_seq_size)
             else:
+                # inference path
                 # current workaround to reduce first token latency of fp32 input
                 # sometimes fp16 cause nan and training instability
                 # disable the conversion when training
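
The training branch above leans on get_autocast_dtype to decide which dtype the input should be cast to when no explicit compute_dtype has been set. The helper below is only a hedged sketch of that behaviour for orientation (a hypothetical get_autocast_dtype_sketch, not the actual bigdl.llm.transformers.utils implementation, which also handles XPU autocast):

import torch

def get_autocast_dtype_sketch(x: torch.Tensor):
    # Hypothetical stand-in: return the dtype autocast would use for this
    # tensor's device, or None when autocast is disabled so the caller can
    # skip the cast entirely.
    if x.device.type == "cuda" and torch.is_autocast_enabled():
        return torch.get_autocast_gpu_dtype()
    if x.device.type == "cpu" and torch.is_autocast_cpu_enabled():
        return torch.get_autocast_cpu_dtype()
    # XPU autocast state would be queried via intel_extension_for_pytorch;
    # omitted here to keep the sketch dependency-free.
    return None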
python/llm/src/bigdl/llm/transformers/qlora.py (2 additions, 0 deletions)

@@ -391,6 +391,8 @@ def _setup_devices(self) -> "torch.device":

 def cast_lora_weight(model, dtype=torch.bfloat16):
     for name, module in model.named_modules():
+        if isinstance(module, LowBitLinear):
+            module.compute_dtype = dtype
         if isinstance(module, LoraLayer):
             module = module.to(dtype)
         if 'norm' in name:
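
With this change, cast_lora_weight also stamps a compute_dtype onto every LowBitLinear module, not just the LoRA-wrapped ones, which is what fixes the "unlora" modules during QLoRA fine-tuning on XPU. A minimal usage sketch follows; the model id, LoRA hyperparameters, and the direct call to cast_lora_weight are illustrative (BigDL's own fine-tuning scripts may wire this up for you):

import torch
from peft import LoraConfig
from bigdl.llm.transformers import AutoModelForCausalLM
from bigdl.llm.transformers.qlora import get_peft_model, cast_lora_weight

# Illustrative base model and low-bit setting; adjust to your setup.
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf",
                                             load_in_low_bit="nf4",
                                             optimize_model=False)
model = model.to("xpu")

# Illustrative LoRA hyperparameters.
config = LoraConfig(r=8, lora_alpha=32, lora_dropout=0.05,
                    target_modules=["q_proj", "k_proj", "v_proj"],
                    bias="none", task_type="CAUSAL_LM")
model = get_peft_model(model, config)

# Sets compute_dtype on every LowBitLinear (LoRA-wrapped or not) and casts
# LoRA and norm layers, so plain low-bit layers cast their training inputs
# to bfloat16 as shown in the low_bit_linear.py change above.
cast_lora_weight(model, dtype=torch.bfloat16)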
