[LLM] Replace Embedding layer to fix it on CPU (#9254)
MeouSker77 authored Nov 1, 2023
1 parent 3290ede commit 4981c79
Showing 4 changed files with 59 additions and 7 deletions.
8 changes: 6 additions & 2 deletions python/llm/src/bigdl/llm/optimize.py
@@ -192,7 +192,8 @@ def load_low_bit(model, model_path):
    return model


-def optimize_model(model, low_bit='sym_int4', optimize_llm=True, modules_to_not_convert=None):
+def optimize_model(model, low_bit='sym_int4', optimize_llm=True, modules_to_not_convert=None,
+                   replace_embedding=False):
    """
    A method to optimize any pytorch model.
@@ -202,6 +203,8 @@ def optimize_model(model, low_bit='sym_int4', optimize_llm=True, modules_to_not_convert=None,
    :param optimize_llm: Whether to further optimize llm model.
    :param modules_to_not_convert: list of str value, modules (nn.Module) that are skipped
           when conducting model optimizations. Default to be None.
+   :param replace_embedding: Whether to replace the Embedding layer; you may need to set it
+          to `True` when running BigDL-LLM on GPU on Windows. Default to be `False`.
    :return: The optimized model.
@@ -227,7 +230,8 @@ def optimize_model(model, low_bit='sym_int4', optimize_llm=True, modules_to_not_convert=None,
    model = ggml_convert_low_bit(model,
                                 qtype=qtype,
                                 optimize_model=optimize_llm,
-                                modules_to_not_convert=modules_to_not_convert)
+                                modules_to_not_convert=modules_to_not_convert,
+                                replace_embedding=replace_embedding)
    # add save_low_bit to pretrained model dynamically
    import types
    model._bigdl_config = dict()
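With this change, `optimize_model` forwards the new flag to `ggml_convert_low_bit`. A minimal usage sketch, not part of the commit; the Hugging Face model id and the top-level `from bigdl.llm import optimize_model` import are assumptions:

```python
# Hypothetical usage of the new `replace_embedding` flag; the model id below is
# only an example and is not taken from this commit.
import torch
from transformers import AutoModelForCausalLM
from bigdl.llm import optimize_model  # assumed top-level re-export of optimize.py

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf",
                                             torch_dtype=torch.float16)

# Quantize to sym_int4; on Windows GPU setups, also swap nn.Embedding for the
# CPU-friendly LLMEmbedding introduced in this commit.
model = optimize_model(model, low_bit='sym_int4', replace_embedding=True)
```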
23 changes: 20 additions & 3 deletions python/llm/src/bigdl/llm/transformers/convert.py
@@ -35,6 +35,7 @@
# limitations under the License.


+import platform
import torch
import torch.nn as nn
from accelerate import init_empty_weights
@@ -82,8 +83,10 @@ def is_linear_module(module):


def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
-                                 current_key_name=None, convert_shape_only=False):
+                                 current_key_name=None, convert_shape_only=False,
+                                 replace_embedding=False):
    from bigdl.llm.transformers.low_bit_linear import LowBitLinear, FP4Params, FP16Linear
+   from bigdl.llm.transformers.embedding import LLMEmbedding
    has_been_replaced = False

    for name, module in model.named_children():
@@ -147,6 +150,19 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                    model._modules[name].requires_grad_(False)

                    module.weight = None
+       elif replace_embedding and type(module) == nn.Embedding:
+           # skip user-defined Embedding layer
+           if platform.system().lower() == 'windows':
+               model._modules[name] = LLMEmbedding(
+                   num_embeddings=module.num_embeddings,
+                   embedding_dim=module.embedding_dim,
+                   padding_idx=module.padding_idx,
+                   max_norm=module.max_norm,
+                   norm_type=module.norm_type,
+                   scale_grad_by_freq=module.scale_grad_by_freq,
+                   sparse=module.sparse,
+                   _weight=module.weight.data,
+               )

        # Remove the last key for recursion
        if len(list(module.children())) > 0:
@@ -156,6 +172,7 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                modules_to_not_convert,
                current_key_name,
                convert_shape_only,
+               replace_embedding,
            )
            has_been_replaced = _flag or has_been_replaced
    return model, has_been_replaced
@@ -185,7 +202,7 @@ def _optimize_pre(model):

def ggml_convert_low_bit(model, qtype, optimize_model=True,
                         convert_shape_only=False, device="cpu",
-                         modules_to_not_convert=None):
+                         modules_to_not_convert=None, replace_embedding=False):
    logger.info(f"Converting the current model to "
                f"{list(ggml_tensor_qtype.keys())[list(ggml_tensor_qtype.values()).index(qtype)]} "
                f"format......")
@@ -196,7 +213,7 @@ def ggml_convert_low_bit(model, qtype, optimize_model=True,

    model, has_been_replaced = _replace_with_low_bit_linear(
        model, qtype, modules_to_not_convert,
-        None, convert_shape_only,
+        None, convert_shape_only, replace_embedding,
    )
    if not has_been_replaced:
        warnings.warn(
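A note on the exact-type check above: `type(module) == nn.Embedding` matches only plain `nn.Embedding` modules, so user-defined `Embedding` subclasses are skipped (which is what the in-line comment refers to), whereas `isinstance` would also match them. A small illustration (the subclass name below is made up):

```python
import torch.nn as nn

class MyCustomEmbedding(nn.Embedding):  # made-up user-defined subclass
    pass

plain = nn.Embedding(10, 4)
custom = MyCustomEmbedding(10, 4)

print(type(plain) == nn.Embedding)       # True  -> eligible for replacement
print(type(custom) == nn.Embedding)      # False -> skipped by the exact-type check
print(isinstance(custom, nn.Embedding))  # True  -> why isinstance() is not used here
```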
25 changes: 25 additions & 0 deletions python/llm/src/bigdl/llm/transformers/embedding.py
@@ -0,0 +1,25 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+
+import torch
+from torch import Tensor
+
+
+class LLMEmbedding(torch.nn.Embedding):
+    def forward(self, x: Tensor):
+        x_shape = x.shape
+        return self.weight[x.reshape(-1)].reshape(*x_shape, -1)
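The overridden `forward` gathers rows of the weight matrix by plain tensor indexing instead of calling `torch.nn.functional.embedding`, so an ordinary lookup returns the same values as `nn.Embedding`. A quick self-contained check (the class is re-declared here so the snippet runs on its own):

```python
import torch
from torch import Tensor, nn

# Re-declared copy of the LLMEmbedding class from the diff above.
class LLMEmbedding(torch.nn.Embedding):
    def forward(self, x: Tensor):
        x_shape = x.shape
        return self.weight[x.reshape(-1)].reshape(*x_shape, -1)

ref = nn.Embedding(100, 16)
new = LLMEmbedding(100, 16, _weight=ref.weight.data)  # share the same weights

ids = torch.randint(0, 100, (2, 8))
assert torch.equal(ref(ids), new(ids))  # identical results for a plain lookup
```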
10 changes: 8 additions & 2 deletions python/llm/src/bigdl/llm/transformers/model.py
@@ -68,6 +68,8 @@ def from_pretrained(cls,
                             Default to be True.
        :param modules_to_not_convert: list of str value, modules (nn.Module) that are skipped when
                                       conducting model optimizations. Default to be None.
+       :param replace_embedding: Whether to replace the Embedding layer; you may need to set it
+                                 to `True` when running BigDL-LLM on GPU on Windows. Default to be `False`.
        :return: a model instance
        """
@@ -118,6 +120,7 @@ def load_convert(cls, q_k, optimize_model, *args, **kwargs):
        # `from_pretrained` may pop items out in dict
        # and lead to args missing.
        modules_to_not_convert = kwargs.pop("modules_to_not_convert", None)
+       replace_embedding = kwargs.pop("replace_embedding", False)
        _args = copy.deepcopy(args)
        _kwargs = copy.deepcopy(kwargs)
        try:
@@ -130,7 +133,8 @@ def load_convert(cls, q_k, optimize_model, *args, **kwargs):
            model.config.update({"bigdl_lcmu_enabled": False})
            model = model.to("cpu")
            model = ggml_convert_low_bit(model, qtype, optimize_model,
-                                         modules_to_not_convert=modules_to_not_convert)
+                                         modules_to_not_convert=modules_to_not_convert,
+                                         replace_embedding=replace_embedding)
            model.config.update({"bigdl_transformers_low_bit": q_k})
            model.config.update({"tie_word_embeddings": False})

@@ -167,6 +171,7 @@ def load_low_bit(cls,
        import os

        modules_to_not_convert = kwargs.pop("modules_to_not_convert", None)
+       replace_embedding = kwargs.pop("replace_embedding", False)
        # Autofactory
        trust_remote_code = kwargs.pop("trust_remote_code", None)
        kwargs_orig = copy.deepcopy(kwargs)
@@ -277,7 +282,8 @@ def load_low_bit(cls,
        # Loading args may differ based on their usage
        quant_device = "meta" if bigdl_lcmu_enabled else "cpu"
        model = ggml_convert_low_bit(model, qtype, optimize_model, device=quant_device,
-                                     modules_to_not_convert=modules_to_not_convert)
+                                     modules_to_not_convert=modules_to_not_convert,
+                                     replace_embedding=replace_embedding)

        if is_sharded:
            loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"]
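For the `transformers`-style API, the new keyword is popped from `**kwargs` in `load_convert` and `load_low_bit` and forwarded to `ggml_convert_low_bit`. A hedged usage sketch (the model id and save directory are examples, not taken from this commit):

```python
# Sketch of passing `replace_embedding` through the transformers-style API.
from bigdl.llm.transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-chat-hf",   # example model id
    load_in_4bit=True,
    replace_embedding=True,            # swap nn.Embedding for LLMEmbedding (takes effect on Windows)
)

# A saved low-bit checkpoint accepts the same keyword when reloaded:
# model = AutoModelForCausalLM.load_low_bit("./llama2-7b-low-bit", replace_embedding=True)
```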
