From 4981c792b19950141956868c55a73abd784143b8 Mon Sep 17 00:00:00 2001 From: Yishuo Wang Date: Wed, 1 Nov 2023 13:58:10 +0800 Subject: [PATCH] [LLM] Replace Embedding layer to fix it on CPU (#9254) --- python/llm/src/bigdl/llm/optimize.py | 8 ++++-- .../llm/src/bigdl/llm/transformers/convert.py | 23 ++++++++++++++--- .../src/bigdl/llm/transformers/embedding.py | 25 +++++++++++++++++++ .../llm/src/bigdl/llm/transformers/model.py | 10 ++++++-- 4 files changed, 59 insertions(+), 7 deletions(-) create mode 100644 python/llm/src/bigdl/llm/transformers/embedding.py diff --git a/python/llm/src/bigdl/llm/optimize.py b/python/llm/src/bigdl/llm/optimize.py index 880629a2520..6396f7faa6b 100644 --- a/python/llm/src/bigdl/llm/optimize.py +++ b/python/llm/src/bigdl/llm/optimize.py @@ -192,7 +192,8 @@ def load_low_bit(model, model_path): return model -def optimize_model(model, low_bit='sym_int4', optimize_llm=True, modules_to_not_convert=None): +def optimize_model(model, low_bit='sym_int4', optimize_llm=True, modules_to_not_convert=None, + replace_embedding=False): """ A method to optimize any pytorch model. @@ -202,6 +203,8 @@ def optimize_model(model, low_bit='sym_int4', optimize_llm=True, modules_to_not_ :param optimize_llm: Whether to further optimize llm model. :param modules_to_not_convert: list of str value, modules (nn.Module) that are skipped when conducting model optimizations. Default to be None. + :param replace_embedding: Whether to replace the Embedding layer, may need to set it + to `True` when running BigDL-LLM on GPU on Windows. Default to be `False`. :return: The optimized model. @@ -227,7 +230,8 @@ def optimize_model(model, low_bit='sym_int4', optimize_llm=True, modules_to_not_ model = ggml_convert_low_bit(model, qtype=qtype, optimize_model=optimize_llm, - modules_to_not_convert=modules_to_not_convert) + modules_to_not_convert=modules_to_not_convert, + replace_embedding=replace_embedding) # add save_low_bit to pretrained model dynamically import types model._bigdl_config = dict() diff --git a/python/llm/src/bigdl/llm/transformers/convert.py b/python/llm/src/bigdl/llm/transformers/convert.py index 2acab799108..8f864862a74 100644 --- a/python/llm/src/bigdl/llm/transformers/convert.py +++ b/python/llm/src/bigdl/llm/transformers/convert.py @@ -35,6 +35,7 @@ # limitations under the License. 
+import platform import torch import torch.nn as nn from accelerate import init_empty_weights @@ -82,8 +83,10 @@ def is_linear_module(module): def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None, - current_key_name=None, convert_shape_only=False): + current_key_name=None, convert_shape_only=False, + replace_embedding=False): from bigdl.llm.transformers.low_bit_linear import LowBitLinear, FP4Params, FP16Linear + from bigdl.llm.transformers.embedding import LLMEmbedding has_been_replaced = False for name, module in model.named_children(): @@ -147,6 +150,19 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None, model._modules[name].requires_grad_(False) module.weight = None + elif replace_embedding and type(module) == nn.Embedding: + # skip user-defined Embedding layer + if platform.system().lower() == 'windows': + model._modules[name] = LLMEmbedding( + num_embeddings=module.num_embeddings, + embedding_dim=module.embedding_dim, + padding_idx=module.padding_idx, + max_norm=module.max_norm, + norm_type=module.norm_type, + scale_grad_by_freq=module.scale_grad_by_freq, + sparse=module.sparse, + _weight=module.weight.data, + ) # Remove the last key for recursion if len(list(module.children())) > 0: @@ -156,6 +172,7 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None, modules_to_not_convert, current_key_name, convert_shape_only, + replace_embedding, ) has_been_replaced = _flag or has_been_replaced return model, has_been_replaced @@ -185,7 +202,7 @@ def _optimize_pre(model): def ggml_convert_low_bit(model, qtype, optimize_model=True, convert_shape_only=False, device="cpu", - modules_to_not_convert=None): + modules_to_not_convert=None, replace_embedding=False): logger.info(f"Converting the current model to " f"{list(ggml_tensor_qtype.keys())[list(ggml_tensor_qtype.values()).index(qtype)]} " f"format......") @@ -196,7 +213,7 @@ def ggml_convert_low_bit(model, qtype, optimize_model=True, model, has_been_replaced = _replace_with_low_bit_linear( model, qtype, modules_to_not_convert, - None, convert_shape_only, + None, convert_shape_only, replace_embedding, ) if not has_been_replaced: warnings.warn( diff --git a/python/llm/src/bigdl/llm/transformers/embedding.py b/python/llm/src/bigdl/llm/transformers/embedding.py new file mode 100644 index 00000000000..2764d01ee48 --- /dev/null +++ b/python/llm/src/bigdl/llm/transformers/embedding.py @@ -0,0 +1,25 @@ +# +# Copyright 2016 The BigDL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + + +import torch +from torch import Tensor + + +class LLMEmbedding(torch.nn.Embedding): + def forward(self, x: Tensor): + x_shape = x.shape + return self.weight[x.reshape(-1)].reshape(*x_shape, -1) diff --git a/python/llm/src/bigdl/llm/transformers/model.py b/python/llm/src/bigdl/llm/transformers/model.py index a46631455fb..98a54ee13f6 100644 --- a/python/llm/src/bigdl/llm/transformers/model.py +++ b/python/llm/src/bigdl/llm/transformers/model.py @@ -68,6 +68,8 @@ def from_pretrained(cls, Default to be True. :param modules_to_not_convert: list of str value, modules (nn.Module) that are skipped when conducting model optimizations. Default to be None. + :param replace_embedding: Whether to replace the Embedding layer, may need to set it + to `True` when running BigDL-LLM on GPU on Windows. Default to be `False`. :return: a model instance """ @@ -118,6 +120,7 @@ def load_convert(cls, q_k, optimize_model, *args, **kwargs): # `from_pretrained`` may pop items out in dict # and lead to args missing. modules_to_not_convert = kwargs.pop("modules_to_not_convert", None) + replace_embedding = kwargs.pop("replace_embedding", False) _args = copy.deepcopy(args) _kwargs = copy.deepcopy(kwargs) try: @@ -130,7 +133,8 @@ def load_convert(cls, q_k, optimize_model, *args, **kwargs): model.config.update({"bigdl_lcmu_enabled": False}) model = model.to("cpu") model = ggml_convert_low_bit(model, qtype, optimize_model, - modules_to_not_convert=modules_to_not_convert) + modules_to_not_convert=modules_to_not_convert, + replace_embedding=replace_embedding) model.config.update({"bigdl_transformers_low_bit": q_k}) model.config.update({"tie_word_embeddings": False}) @@ -167,6 +171,7 @@ def load_low_bit(cls, import os modules_to_not_convert = kwargs.pop("modules_to_not_convert", None) + replace_embedding = kwargs.pop("replace_embedding", False) # Autofactory trust_remote_code = kwargs.pop("trust_remote_code", None) kwargs_orig = copy.deepcopy(kwargs) @@ -277,7 +282,8 @@ def load_low_bit(cls, # Loading args may differ based on their usage quant_device = "meta" if bigdl_lcmu_enabled else "cpu" model = ggml_convert_low_bit(model, qtype, optimize_model, device=quant_device, - modules_to_not_convert=modules_to_not_convert) + modules_to_not_convert=modules_to_not_convert, + replace_embedding=replace_embedding) if is_sharded: loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"]
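
Notes (not part of the diff): below is a minimal usage sketch of the new `replace_embedding` flag this patch threads through `optimize_model`, `from_pretrained`/`load_convert`, `load_low_bit`, and `ggml_convert_low_bit`. It assumes the usual BigDL-LLM entry points (`bigdl.llm.transformers.AutoModelForCausalLM` and `bigdl.llm.optimize_model`); the model path is a hypothetical placeholder. Per the patch, the replacement only actually happens when `platform.system()` reports Windows, where `nn.Embedding` modules are swapped for `LLMEmbedding`, whose forward is a plain weight-indexing lookup.

# Sketch A: pass the flag through from_pretrained; the patch pops
# `replace_embedding` from kwargs in load_convert and forwards it to
# ggml_convert_low_bit.
from bigdl.llm.transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "your/model-path",        # hypothetical placeholder model id
    load_in_4bit=True,
    replace_embedding=True,   # new flag added by this patch; effective on Windows only
)

# Sketch B: optimize an already-loaded PyTorch model via the updated
# optimize_model signature in bigdl/llm/optimize.py.
from transformers import AutoModelForCausalLM as HFAutoModel
from bigdl.llm import optimize_model

hf_model = HFAutoModel.from_pretrained("your/model-path")  # hypothetical placeholder
opt_model = optimize_model(hf_model, low_bit='sym_int4', replace_embedding=True)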