[NPU] Optimize Qwen2 lm_head to use INT4 (#12072)
* temp save

* update

* fix

* fix

* Split lm_head into 7 parts & remove int8 for lm_head when sym_int4

* Simplify and add condition to code

* Small fix

* refactor some code

* fix style

* fix style

* fix style

* fix

* fix

* temp save

* refactor

* fix style

* further refactor

* simplify code

* meet code review

* fix style

---------

Co-authored-by: Yuwen Hu <[email protected]>
rnwang04 and Oscilloscope98 authored Sep 14, 2024
1 parent 18714ce commit 081af41
Showing 3 changed files with 171 additions and 8 deletions.
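
The core trick in this commit is to split the lm_head weight along its input dimension and compute the full vocabulary projection as a sum of smaller matmuls. For Qwen2-7B-Instruct the lm_head weight is vocab_size x hidden_size = 152064 x 3584, so a 7-way split yields 512-column slices. The sketch below reproduces the slice-boundary arithmetic used in the new code; the per-slice memory figures and the assumption that smaller per-op constants are friendlier to the NPU are my own back-of-envelope reasoning, not statements from the diff.

# Slice boundaries as computed in SlicedLMHead / LMHeadLinear: an even split,
# rounded down to a multiple of 2, with the last slice absorbing any remainder.
hidden_size, vocab_size, split_num = 3584, 152064, 7
split_size = hidden_size // split_num // 2 * 2        # 512 for Qwen2-7B
bounds = [(i * split_size,
           (i + 1) * split_size if i < split_num - 1 else hidden_size)
          for i in range(split_num)]
print(bounds)                                         # (0, 512), (512, 1024), ..., (3072, 3584)

# Rough weight footprint at INT4 (0.5 byte per element), per slice vs. unsplit.
per_slice_mb = vocab_size * split_size * 0.5 / 1024 ** 2   # ~37 MB
full_mb = vocab_size * hidden_size * 0.5 / 1024 ** 2       # ~260 MB
print(f"{per_slice_mb:.0f} MB per slice vs {full_mb:.0f} MB unsplit")
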
16 changes: 15 additions & 1 deletion python/llm/src/ipex_llm/transformers/npu_models/convert_mp.py
@@ -16,7 +16,9 @@
import os
import torch
import importlib
import numpy as np
from ipex_llm.transformers.low_bit_linear import LowBitLinear, FP4Params
from ipex_llm.transformers.npu_models.lm_head import LMHeadLinear, SlicedLMHead


def convert_forward(m, target_m, new_forward):
@@ -85,9 +87,16 @@ def optimize_llm_pre(model: torch.nn.Module, qtype):

if model.config.model_type == "qwen2":
from ipex_llm.transformers.npu_models.qwen2_mp import split_mlp_down_proj
from ipex_llm.transformers.npu_models.qwen2_mp import split_mlp_forward
model.apply(split_mlp_down_proj)

        # for Qwen2-7B-Instruct, divide lm_head into 7 parts
if model.config.hidden_size == 3584 and model.config.vocab_size == 152064 and \
not cpu_lm_head:
new_lm_head = SlicedLMHead(model.lm_head.weight, split_num=7,
bias=model.lm_head.bias)
del model.lm_head
model.lm_head = new_lm_head

# lm_head to cpu optimization
if cpu_lm_head:
# disable the optimization by default
@@ -182,6 +191,11 @@ def optimize_llm(
from transformers.models.qwen2.modeling_qwen2 import Qwen2ForCausalLM
from ipex_llm.transformers.npu_models.qwen2_mp import qwen2_casullm_forward
convert_forward(model, Qwen2ForCausalLM, qwen2_casullm_forward)

        # for Qwen2-7B-Instruct, divide lm_head into 7 parts
if model.config.hidden_size == 3584 and model.config.vocab_size == 152064 and \
isinstance(model.lm_head, SlicedLMHead):
model.lm_head.get_fused_lm_head()
elif model.config.model_type == "minicpm":
# for minicpm-1b
if intra_pp is None:
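
Together, the two hunks above form a two-phase flow: optimize_llm_pre swaps model.lm_head for a SlicedLMHead before low-bit conversion, and optimize_llm later calls get_fused_lm_head() to fuse the quantized slices into a single NPU graph. A hedged usage sketch follows; the loader entry point and keyword arguments mirror typical ipex-llm NPU examples and are assumptions on my part, not something introduced by this diff.

import torch
# Assumed entry point for ipex-llm's NPU path; not part of this commit.
from ipex_llm.transformers.npu_model import AutoModelForCausalLM

# With sym_int4, the sliced lm_head now stays INT4 on the NPU instead of
# falling back to INT8 (per the commit message).
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2-7B-Instruct",
    torch_dtype=torch.float16,
    load_in_low_bit="sym_int4",
    optimize_model=True,          # routes through optimize_llm_pre / optimize_llm above
    trust_remote_code=True,
)
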
11 changes: 4 additions & 7 deletions python/llm/src/ipex_llm/transformers/npu_models/linear.py
@@ -22,16 +22,14 @@
#

from intel_npu_acceleration_library.quantization import quantize_tensor, compress_to_i4
from intel_npu_acceleration_library.nn.autograd import AutogradMatMul
from intel_npu_acceleration_library.backend import run_matmul
from intel_npu_acceleration_library.dtypes import NPUDtype
from typing import Optional, Union
import os
import torch
from torch.nn import Parameter
import uuid
import math

from intel_npu_acceleration_library.backend import run_matmul
from typing import Optional, Union
from ipex_llm.utils.common import invalidInputError


@@ -52,7 +50,6 @@ def __init__(self, weight: torch.Tensor, bias: Optional[torch.Tensor] = None):
self.bias = torch.nn.Parameter(bias) if isinstance(bias, torch.Tensor) else None
self.outC, self.inC = self.weight.shape
self.op_id = str(uuid.uuid4())
self._mm = AutogradMatMul.apply

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Torch module forward method.
@@ -147,7 +144,7 @@ def __init__(
"""
super().__init__()

self.weight = Parameter(weight, requires_grad=False)
self.weight = Parameter(weight, requires_grad=False).contiguous()
if self.weight.dtype not in (torch.int8, torch.uint8):
invalidInputError(
False,
@@ -163,7 +160,6 @@ def __init__(
self.scale = Parameter(scale * math.sqrt(self.inC), requires_grad=False)
self.bias = bias
self.op_id = str(uuid.uuid4())
self._mm = AutogradMatMul.apply

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Torch module forward method.
@@ -194,6 +190,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
"Use `.eval()` to do inference only"
)
)

out = run_matmul(x, self.weight.data, self.scale.data, self.op_id)

if self.bias is None:
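
One easy-to-miss change above is the added .contiguous() on QuantizedLinear's weight. A plausible reason (my reading, not stated in the diff) is that column slices of a weight matrix are strided views, and the NPU backend wants densely packed buffers. A quick illustration with toy shapes:

import torch

w = torch.randn(8, 16)                 # toy stand-in for a (out_features, in_features) weight
s = w[:, 0:8]                          # column slice, like SlicedLMHead's weight[:, start:end]
print(s.is_contiguous())               # False: a strided view over the parent tensor
print(s.contiguous().is_contiguous())  # True: .contiguous() copies it into packed memory
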
152 changes: 152 additions & 0 deletions python/llm/src/ipex_llm/transformers/npu_models/lm_head.py
@@ -0,0 +1,152 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import torch
from torch import nn
import numpy as np
from intel_npu_acceleration_library.backend import NNFactory
from intel_npu_acceleration_library.backend.bindings import lib as backend_lib


class LMHeadLinear(NNFactory):
"""Quantized Linear class for sliced lm_head, computing a matrix matrix multiplication
with weights prefetching."""

def __init__(
self,
inC: int,
outC: int,
batch: int,
split_num: int = 2,
profile: bool = False,
device: str = "NPU",
dtype: np.dtype = np.int8,
):
"""Initialize the LMHeadLinear class.
Args:
inC (int): input channels
outC (int): output channels
batch (int): batch
split_num (int): split in_features of lm_head to how many parts
profile (bool): Enable/Disable profiling. Defaults to False.
device (str): Target device, default to "NPU".
dtype (np.dtype): weights datatype. Defaults to np.int8.
"""
super().__init__(profile, device)
self.inC, self.outC = inC, outC
self.batch = batch

input = self.parameter((self.batch, self.inC))

self.split_num = split_num
split_size = self.inC // split_num // 2 * 2

for i in range(self.split_num):
start_idx = i * split_size
end_idx = (i + 1) * split_size if i < self.split_num - 1 else self.inC
input_slice = self.slice(input, begin=[0, start_idx],
end=[self.batch, end_idx])
linear_slice = self.linear(input_slice, outC, split_size, bias=False, wt_dtype=dtype)
if i == 0:
res = linear_slice
else:
res += linear_slice

print("start compiling lm_head")
self.compile()
print("end compiling lm_head")

def run(
self, X: np.ndarray
) -> np.ndarray:
"""Run the layer: $X * (W * S)^T$ .
Args:
X (np.ndarray): activation
Raises:
RuntimeError: Input, weights or scale shape mismatch
Returns:
np.ndarray: result
"""
self.prefetchWeights(1, verify_size=False)
self.set_input_tensor(X, 0)
self.elapsed = backend_lib.run(self._mm)
if len(self.out) == 1:
return self.out[0]
return self.out


class SlicedLMHead(nn.Module):
def __init__(self, weight, bias, split_num):
super().__init__()
self.split_num = split_num
self.outC, self.inC = weight.shape
split_size = weight.size(1) // split_num // 2 * 2
self.lm_heads = nn.Sequential()
for i in range(split_num):
new_linear = torch.nn.Linear(0, 0, bias=False)
start_idx = i * split_size
end_idx = (i + 1) * split_size if i < split_num - 1 else weight.size(1)
new_weight = torch.nn.Parameter(weight[:, start_idx:end_idx],
requires_grad=False)
new_linear.weight = new_weight
new_linear.in_features = new_weight.size(1)
new_linear.out_features = new_weight.size(0)
self.lm_heads.append(new_linear)
self.bias = bias

def forward(self, hidden_states):
if hidden_states.size(0) * hidden_states.size(1) == 1:
original_shape = hidden_states.shape
x_2d = hidden_states.view(-1, hidden_states.shape[-1])
target_shape = tuple(list(original_shape[:-1]) + [self.outC])

out = self.fused_lm_head.run(x_2d.numpy())
logits = torch.from_numpy(out)
logits = logits.view(target_shape)
else:
split_size = hidden_states.size(-1) // self.split_num // 2 * 2
logits = None
for i in range(self.split_num):
start_idx = i * split_size
end_idx = (i + 1) * split_size if i < self.split_num - 1 else self.inC
hidden_states_slice = hidden_states[:, :, start_idx:end_idx]
logits_slice = self.lm_heads[i](hidden_states_slice)
if logits is None:
logits = logits_slice
else:
logits += logits_slice

if self.bias is None:
return logits
return logits + self.bias

def get_weight_dtype(self):
return self.lm_heads[0].weight.dtype

def get_fused_lm_head(self):
np_dtype = np.uint8 if self.get_weight_dtype() == torch.uint8 else np.int8
self.fused_lm_head = LMHeadLinear(self.inC, self.outC, 1, self.split_num,
False, "NPU", dtype=np_dtype)
fused_lm_head_weights = [(self.lm_heads[i].weight.data.numpy(),
self.lm_heads[i].scale.data.numpy())
for i in range(self.split_num)]
self.fused_lm_head.setWeights(1, self.lm_heads[0].op_id,
*fused_lm_head_weights)
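
A quick CPU sanity check that the slice-and-sum branch of SlicedLMHead.forward matches a plain linear projection (toy float weights; the fused NPU path via get_fused_lm_head is not exercised here):

import torch
from ipex_llm.transformers.npu_models.lm_head import SlicedLMHead

torch.manual_seed(0)
weight = torch.randn(160, 64)            # toy (vocab, hidden); Qwen2-7B uses (152064, 3584)
x = torch.randn(1, 4, 64)                # (batch, seq_len, hidden); seq_len > 1 hits the slice-and-sum branch

sliced = SlicedLMHead(weight, bias=None, split_num=7)
ref = torch.nn.functional.linear(x, weight)

print(torch.allclose(ref, sliced(x), atol=1e-4))   # True: the 7 column-split matmuls sum to the full matmul
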
