From 081af41defcc52c520654021413f90d23a064f2d Mon Sep 17 00:00:00 2001
From: Ruonan Wang
Date: Sat, 14 Sep 2024 00:26:46 -0700
Subject: [PATCH] [NPU] Optimize Qwen2 lm_head to use INT4 (#12072)

* temp save

* update

* fix

* fix

* Split lm_head into 7 parts & remove int8 for lm_head when sym_int4

* Simplify and add condition to code

* Small fix

* refactor some code

* fix style

* fix style

* fix style

* fix

* fix

* temp save

* refactor

* fix style

* further refactor

* simplify code

* meet code review

* fix style

---------

Co-authored-by: Yuwen Hu
---
 .../transformers/npu_models/convert_mp.py |  16 +-
 .../transformers/npu_models/linear.py     |  11 +-
 .../transformers/npu_models/lm_head.py    | 152 ++++++++++++++++++
 3 files changed, 171 insertions(+), 8 deletions(-)
 create mode 100644 python/llm/src/ipex_llm/transformers/npu_models/lm_head.py

diff --git a/python/llm/src/ipex_llm/transformers/npu_models/convert_mp.py b/python/llm/src/ipex_llm/transformers/npu_models/convert_mp.py
index bf55a9945ac..237ceb0077a 100644
--- a/python/llm/src/ipex_llm/transformers/npu_models/convert_mp.py
+++ b/python/llm/src/ipex_llm/transformers/npu_models/convert_mp.py
@@ -16,7 +16,9 @@
 import os
 import torch
 import importlib
+import numpy as np
 from ipex_llm.transformers.low_bit_linear import LowBitLinear, FP4Params
+from ipex_llm.transformers.npu_models.lm_head import LMHeadLinear, SlicedLMHead
 
 
 def convert_forward(m, target_m, new_forward):
@@ -85,9 +87,16 @@ def optimize_llm_pre(model: torch.nn.Module, qtype):
 
     if model.config.model_type == "qwen2":
         from ipex_llm.transformers.npu_models.qwen2_mp import split_mlp_down_proj
-        from ipex_llm.transformers.npu_models.qwen2_mp import split_mlp_forward
         model.apply(split_mlp_down_proj)
 
+        # for Qwen2-7B-Instruct, divide lm_head into 7 parts
+        if model.config.hidden_size == 3584 and model.config.vocab_size == 152064 and \
+                not cpu_lm_head:
+            new_lm_head = SlicedLMHead(model.lm_head.weight, split_num=7,
+                                       bias=model.lm_head.bias)
+            del model.lm_head
+            model.lm_head = new_lm_head
+
     # lm_head to cpu optimization
     if cpu_lm_head:
         # disable the optimization by default
@@ -182,6 +191,11 @@ def optimize_llm(
         from transformers.models.qwen2.modeling_qwen2 import Qwen2ForCausalLM
         from ipex_llm.transformers.npu_models.qwen2_mp import qwen2_casullm_forward
         convert_forward(model, Qwen2ForCausalLM, qwen2_casullm_forward)
+
+        # for Qwen2-7B-Instruct, divide lm_head into 7 parts
+        if model.config.hidden_size == 3584 and model.config.vocab_size == 152064 and \
+                isinstance(model.lm_head, SlicedLMHead):
+            model.lm_head.get_fused_lm_head()
     elif model.config.model_type == "minicpm":
         # for minicpm-1b
         if intra_pp is None:
diff --git a/python/llm/src/ipex_llm/transformers/npu_models/linear.py b/python/llm/src/ipex_llm/transformers/npu_models/linear.py
index d38b1e43707..804751d2a9b 100644
--- a/python/llm/src/ipex_llm/transformers/npu_models/linear.py
+++ b/python/llm/src/ipex_llm/transformers/npu_models/linear.py
@@ -22,16 +22,14 @@
 #
 
 from intel_npu_acceleration_library.quantization import quantize_tensor, compress_to_i4
-from intel_npu_acceleration_library.nn.autograd import AutogradMatMul
-from intel_npu_acceleration_library.backend import run_matmul
 from intel_npu_acceleration_library.dtypes import NPUDtype
-from typing import Optional, Union
 import os
 import torch
 from torch.nn import Parameter
 import uuid
 import math
-
+from intel_npu_acceleration_library.backend import run_matmul
+from typing import Optional, Union
 from ipex_llm.utils.common import invalidInputError
 
 
@@ -52,7 +50,6 @@ def __init__(self, weight: torch.Tensor, bias: Optional[torch.Tensor] = None):
         self.bias = torch.nn.Parameter(bias) if isinstance(bias, torch.Tensor) else None
         self.outC, self.inC = self.weight.shape
         self.op_id = str(uuid.uuid4())
-        self._mm = AutogradMatMul.apply
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """Torch module forward method.
@@ -147,7 +144,7 @@ def __init__(
         """
         super().__init__()
 
-        self.weight = Parameter(weight, requires_grad=False)
+        self.weight = Parameter(weight, requires_grad=False).contiguous()
         if self.weight.dtype not in (torch.int8, torch.uint8):
             invalidInputError(
                 False,
@@ -163,7 +160,6 @@ def __init__(
         self.scale = Parameter(scale * math.sqrt(self.inC), requires_grad=False)
         self.bias = bias
         self.op_id = str(uuid.uuid4())
-        self._mm = AutogradMatMul.apply
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """Torch module forward method.
@@ -194,6 +190,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
                     "Use `.eval()` to do inference only"
                 )
             )
+
         out = run_matmul(x, self.weight.data, self.scale.data, self.op_id)
 
         if self.bias is None:
diff --git a/python/llm/src/ipex_llm/transformers/npu_models/lm_head.py b/python/llm/src/ipex_llm/transformers/npu_models/lm_head.py
new file mode 100644
index 00000000000..357eddb5e88
--- /dev/null
+++ b/python/llm/src/ipex_llm/transformers/npu_models/lm_head.py
@@ -0,0 +1,152 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import torch
+from torch import nn
+import numpy as np
+from intel_npu_acceleration_library.backend import NNFactory
+from intel_npu_acceleration_library.backend.bindings import lib as backend_lib
+
+
+class LMHeadLinear(NNFactory):
+    """Quantized Linear class for the sliced lm_head, computing a matrix-matrix multiplication
+    with weight prefetching."""
+
+    def __init__(
+        self,
+        inC: int,
+        outC: int,
+        batch: int,
+        split_num: int = 2,
+        profile: bool = False,
+        device: str = "NPU",
+        dtype: np.dtype = np.int8,
+    ):
+        """Initialize the LMHeadLinear class.
+
+        Args:
+            inC (int): input channels
+            outC (int): output channels
+            batch (int): batch size
+            split_num (int): number of parts to split the lm_head in_features into
+            profile (bool): Enable/Disable profiling. Defaults to False.
+            device (str): Target device, defaults to "NPU".
+            dtype (np.dtype): weights datatype. Defaults to np.int8.
+
+        """
+        super().__init__(profile, device)
+        self.inC, self.outC = inC, outC
+        self.batch = batch
+
+        input = self.parameter((self.batch, self.inC))
+
+        self.split_num = split_num
+        split_size = self.inC // split_num // 2 * 2
+
+        for i in range(self.split_num):
+            start_idx = i * split_size
+            end_idx = (i + 1) * split_size if i < self.split_num - 1 else self.inC
+            input_slice = self.slice(input, begin=[0, start_idx],
+                                     end=[self.batch, end_idx])
+            linear_slice = self.linear(input_slice, outC, split_size, bias=False, wt_dtype=dtype)
+            if i == 0:
+                res = linear_slice
+            else:
+                res += linear_slice
+
+        print("start compiling lm_head")
+        self.compile()
+        print("end compiling lm_head")
+
+    def run(
+        self, X: np.ndarray
+    ) -> np.ndarray:
+        """Run the layer: $X * (W * S)^T$ .
+
+        Args:
+            X (np.ndarray): activation
+
+        Raises:
+            RuntimeError: Input, weights or scale shape mismatch
+
+        Returns:
+            np.ndarray: result
+        """
+        self.prefetchWeights(1, verify_size=False)
+        self.set_input_tensor(X, 0)
+        self.elapsed = backend_lib.run(self._mm)
+        if len(self.out) == 1:
+            return self.out[0]
+        return self.out
+
+
+class SlicedLMHead(nn.Module):
+    def __init__(self, weight, bias, split_num):
+        super().__init__()
+        self.split_num = split_num
+        self.outC, self.inC = weight.shape
+        split_size = weight.size(1) // split_num // 2 * 2
+        self.lm_heads = nn.Sequential()
+        for i in range(split_num):
+            new_linear = torch.nn.Linear(0, 0, bias=False)
+            start_idx = i * split_size
+            end_idx = (i + 1) * split_size if i < split_num - 1 else weight.size(1)
+            new_weight = torch.nn.Parameter(weight[:, start_idx:end_idx],
+                                            requires_grad=False)
+            new_linear.weight = new_weight
+            new_linear.in_features = new_weight.size(1)
+            new_linear.out_features = new_weight.size(0)
+            self.lm_heads.append(new_linear)
+        self.bias = bias
+
+    def forward(self, hidden_states):
+        if hidden_states.size(0) * hidden_states.size(1) == 1:
+            original_shape = hidden_states.shape
+            x_2d = hidden_states.view(-1, hidden_states.shape[-1])
+            target_shape = tuple(list(original_shape[:-1]) + [self.outC])
+
+            out = self.fused_lm_head.run(x_2d.numpy())
+            logits = torch.from_numpy(out)
+            logits = logits.view(target_shape)
+        else:
+            split_size = hidden_states.size(-1) // self.split_num // 2 * 2
+            logits = None
+            for i in range(self.split_num):
+                start_idx = i * split_size
+                end_idx = (i + 1) * split_size if i < self.split_num - 1 else self.inC
+                hidden_states_slice = hidden_states[:, :, start_idx:end_idx]
+                logits_slice = self.lm_heads[i](hidden_states_slice)
+                if logits is None:
+                    logits = logits_slice
+                else:
+                    logits += logits_slice
+
+        if self.bias is None:
+            return logits
+        return logits + self.bias
+
+    def get_weight_dtype(self):
+        return self.lm_heads[0].weight.dtype
+
+    def get_fused_lm_head(self):
+        np_dtype = np.uint8 if self.get_weight_dtype() == torch.uint8 else np.int8
+        self.fused_lm_head = LMHeadLinear(self.inC, self.outC, 1, self.split_num,
+                                          False, "NPU", dtype=np_dtype)
+        fused_lm_head_weights = [(self.lm_heads[i].weight.data.numpy(),
+                                  self.lm_heads[i].scale.data.numpy())
+                                 for i in range(self.split_num)]
+        self.fused_lm_head.setWeights(1, self.lm_heads[0].op_id,
+                                      *fused_lm_head_weights)
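
For reference, below is a minimal CPU-only sketch (not part of the patch) of the column-slicing idea that SlicedLMHead.forward applies on its non-fused path: the lm_head weight is split along in_features into split_num column blocks and the partial matmuls are summed, which is mathematically identical to the full projection. The function name sliced_lm_head and the reduced demo vocabulary size are illustrative only; the NPU graph fusion done by LMHeadLinear / get_fused_lm_head is not reproduced here.

# CPU-only sketch of the lm_head slicing idea; Qwen2-7B-Instruct shapes would be
# hidden_size=3584, vocab_size=152064 with 7 slices, a smaller vocab keeps the demo light.
import torch


def sliced_lm_head(hidden_states: torch.Tensor,
                   weight: torch.Tensor,
                   split_num: int = 7) -> torch.Tensor:
    in_features = weight.size(1)
    # mirror the patch's `inC // split_num // 2 * 2` so every slice has an even width
    split_size = in_features // split_num // 2 * 2
    logits = None
    for i in range(split_num):
        start = i * split_size
        end = (i + 1) * split_size if i < split_num - 1 else in_features
        # partial projection over one column block of the weight
        partial = hidden_states[..., start:end] @ weight[:, start:end].t()
        logits = partial if logits is None else logits + partial
    return logits


if __name__ == "__main__":
    hidden, vocab = 3584, 2048          # real vocab_size would be 152064
    w = torch.randn(vocab, hidden)      # lm_head weight: (out_features, in_features)
    x = torch.randn(1, 1, hidden)       # (batch, seq, hidden), the single-token decode case
    ref = x @ w.t()
    out = sliced_lm_head(x, w, split_num=7)
    print("max abs diff vs. unsliced matmul:", (out - ref).abs().max().item())

Because the sliced sum reproduces the unsliced projection exactly (up to float rounding), the per-slice nn.Linear path and the fused NPU path in the patch can share the same sliced weights.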