LLM: enable chatglm3-6b target_model ipex (intel-analytics#10085)
* init

* always make causal_mask

* not return last tensor

* update

* optimize_model = False

* enable optimized=False

* enable optimized_model=true

* speed_up ipex target_model

* remove if True

* use group_size

* update python style

* update

* update
hzjane authored Feb 19, 2024
1 parent 5553f43 commit d05c4d6
Showing 3 changed files with 204 additions and 71 deletions.
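
A minimal usage sketch for the feature this commit enables, assuming the BigDL-LLM CPU speculative-decoding entry points of this release (bigdl.llm.transformers.AutoModelForCausalLM with load_in_low_bit="bf16" and a speculative=True flag) and a locally available chatglm3-6b checkpoint; exact argument names may differ from the shipped examples.

# Rough sketch (assumptions above): run the BF16 chatglm3-6b target model through the
# IPEX path gated by BIGDL_OPT_IPEX while the low-bit draft model proposes tokens.
import os
os.environ["BIGDL_OPT_IPEX"] = "true"      # env var checked in convert.py below

import torch
from transformers import AutoTokenizer
from bigdl.llm.transformers import AutoModelForCausalLM

model_path = "THUDM/chatglm3-6b"           # assumed checkpoint path
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    load_in_low_bit="bf16",                # the qtype gate for the IPEX path
    speculative=True,                      # assumed flag routing generate() to speculative_generate
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
    use_cache=True,
)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

inputs = tokenizer("What is AI?", return_tensors="pt")
with torch.inference_mode():
    output = model.generate(inputs.input_ids, max_new_tokens=64, do_sample=False)
print(tokenizer.decode(output[0], skip_special_tokens=True))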
24 changes: 15 additions & 9 deletions python/llm/src/bigdl/llm/transformers/convert.py
@@ -518,6 +518,16 @@ def ggml_convert_low_bit(model, qtype, optimize_model=True,
f"format......")
modules_to_not_convert = [] if modules_to_not_convert is None else modules_to_not_convert

# using ipex optimizer before changing to bigdl linear
_enable_ipex = os.getenv("BIGDL_OPT_IPEX")
_enable_ipex = (_enable_ipex is not None) and (_enable_ipex.lower() == "true")
_enable_ipex = _enable_ipex and (qtype == ggml_tensor_qtype["bf16"])
if (device == "cpu") and (qtype == ggml_tensor_qtype["bf16"]):
logger.info(f"BIGDL_OPT_IPEX: {_enable_ipex}")
if _enable_ipex:
model = _optimize_ipex(model)
return model

if optimize_model:
model = _optimize_pre(model)

@@ -543,14 +553,6 @@ def ggml_convert_low_bit(model, qtype, optimize_model=True,
# Do nothing here for weights are empty.
pass

_enable_ipex = os.getenv("BIGDL_OPT_IPEX")
_enable_ipex = (_enable_ipex is not None) and (_enable_ipex.lower() == "true")
_enable_ipex = _enable_ipex and (qtype == ggml_tensor_qtype["bf16"])
if (device == "cpu") and (qtype == ggml_tensor_qtype["bf16"]):
logger.info(f"BIGDL_OPT_IPEX: {_enable_ipex}")
if _enable_ipex:
model = _optimize_ipex(model)
return model
if optimize_model:
model = _optimize_post(model, lightweight_bmm)
return model
@@ -590,13 +592,17 @@ def _optimize_ipex(model):
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
from bigdl.llm.transformers.convert_ipex import (
_ipex_optimize_attention, _ipex_optimize_decoder, _ipex_jit, _make_causal_mask,
_ipex_optimize_rmsnorm, _llama_model_forward_4_35
_ipex_optimize_rmsnorm, _llama_model_forward_4_35, convert_function, GLM_get_masks
)

AttentionMaskConverter._make_causal_mask = _make_causal_mask
convert_forward(model, transformers.models.llama.modeling_llama.LlamaModel,
_llama_model_forward_4_35)
model = model_convert_reference(model)
if model.config.architectures is not None \
and model.config.architectures[0] in ["ChatGLMModel", "ChatGLMForConditionalGeneration"]:
convert_function(model.transformer, "get_masks", GLM_get_masks)
model = ipex.optimize(model.eval(), dtype=torch.bfloat16, inplace=True).eval()
_ipex_optimize_rmsnorm(model)
_ipex_optimize_attention(model)
_ipex_optimize_decoder(model)
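
The convert.py change above patches get_masks on the single model.transformer instance via convert_function rather than on the ChatGLM class. A self-contained sketch of that pattern, using a toy class in place of the real module:

# Instance-level monkey-patch as done by convert_function: the descriptor protocol
# (__get__) binds the replacement function to one object, leaving other instances alone.
class ToyTransformer:                       # stand-in for model.transformer
    def get_masks(self, x):
        return "original mask for " + x

def patched_get_masks(self, x):             # stand-in for GLM_get_masks
    return "ipex-friendly mask for " + x

def convert_function(m, func_name, new_function):
    bound_method = new_function.__get__(m, m.__class__)
    setattr(m, func_name, bound_method)

a, b = ToyTransformer(), ToyTransformer()
convert_function(a, "get_masks", patched_get_masks)
print(a.get_masks("q"))   # ipex-friendly mask for q   (patched instance)
print(b.get_masks("q"))   # original mask for q        (class left untouched)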
43 changes: 43 additions & 0 deletions python/llm/src/bigdl/llm/transformers/convert_ipex.py
@@ -142,6 +142,8 @@ def _ipex_jit(model):
sample_inputs = (
get_dummy_input(model, return_dict=True)
)
if "return_last_logit" in sample_inputs:
del sample_inputs["return_last_logit"]
with torch.no_grad(), torch.cpu.amp.autocast(
enabled=True
):
@@ -159,6 +161,47 @@ def _ipex_jit(model):
return model.eval()


def convert_function(m, func_name, new_function):
bound_method = new_function.__get__(m, m.__class__)
setattr(m, func_name, bound_method)


def GLM_get_masks(self, input_ids, past_key_values, padding_mask=None):
batch_size, seq_length = input_ids.shape
full_attention_mask = torch.ones(
batch_size, seq_length, seq_length, device=input_ids.device
)
full_attention_mask.tril_()
past_length = 0
if past_key_values:
if len(past_key_values[0]) != 4: # not discrete kv cache
past_length = past_key_values[0][0].shape[0]
else: # discrete kv cache
past_length = past_key_values[0][0].shape[-2]

import os
_enable_ipex = os.getenv("BIGDL_OPT_IPEX")
_enable_ipex = (_enable_ipex is not None) and (_enable_ipex.lower() == "true")
# always call for jit
if past_length or _enable_ipex:
full_attention_mask = torch.cat(
(
torch.ones(
batch_size, seq_length, past_length, device=input_ids.device
),
full_attention_mask,
),
dim=-1,
)
if padding_mask is not None:
full_attention_mask = full_attention_mask * padding_mask.unsqueeze(1)
# if not past_length and padding_mask is not None:
# full_attention_mask -= padding_mask.unsqueeze(-1) - 1
full_attention_mask = (full_attention_mask < 0.5).bool()
full_attention_mask.unsqueeze_(1)
return full_attention_mask


@staticmethod
def _make_causal_mask(
input_ids_shape: torch.Size,
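
For reference, when BIGDL_OPT_IPEX is set the GLM_get_masks above always concatenates the past-length block, so the traced graph sees the same mask layout at prefill and at decode time. A small sketch of the shapes it produces, with the shape logic copied from the function and dummy sizes standing in for the real cache:

# Boolean mask layout returned by GLM_get_masks:
# [batch, 1, seq_length, past_length + seq_length], True where attention is masked out.
import torch

def glm_mask_shape(batch_size, seq_length, past_length):
    full = torch.ones(batch_size, seq_length, seq_length).tril_()
    # with BIGDL_OPT_IPEX=true this cat runs even when past_length == 0
    full = torch.cat(
        (torch.ones(batch_size, seq_length, past_length), full), dim=-1
    )
    full = (full < 0.5).bool()
    return full.unsqueeze(1)

print(glm_mask_shape(1, 5, 0).shape)   # prefill: torch.Size([1, 1, 5, 5])
print(glm_mask_shape(1, 1, 5).shape)   # decode : torch.Size([1, 1, 1, 6])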
208 changes: 146 additions & 62 deletions python/llm/src/bigdl/llm/transformers/speculative.py
@@ -34,7 +34,7 @@
# patch GenerationMixin.generate
from transformers import GenerationMixin
original_generate = GenerationMixin.generate

query_group_size = 16
logger = logging.getLogger("bigdl.llm.speculative")


@@ -131,62 +131,98 @@ def clear_benchmarks(self):
def _prepare_past_key_values_storage_cpu(self, past_key_values,
max_new_tokens, _enable_ipex=False):
past_key_values_storage = []
# init ipex_past_key_values
if _enable_ipex:
ipex_past_key_values = []
cur_len = past_key_values[0][0].size(1)
ipex_past_key_values = [
[pkv[1].permute(1, 2, 0, 3)[:, :, :cur_len, :],
pkv[2].permute(1, 2, 0, 3)[:, :, :cur_len, :]]
for pkv in past_key_values
]

for i in range(len(past_key_values)):
if not _enable_ipex:
len0 = past_key_values[i][0].size(0)
len1 = past_key_values[i][0].size(1)
len2 = past_key_values[i][0].size(2)
len3 = past_key_values[i][0].size(3)
if self.config.model_type == "chatglm":
len0 = past_key_values[0][1].size(0) # seq max length
len1 = past_key_values[0][1].size(1)
len2 = past_key_values[0][1].size(2)
len3 = past_key_values[0][1].size(3)
for pkv in past_key_values:
key = pkv[1]
value = pkv[2]
key = key.permute(1, 2, 0, 3).unsqueeze(-3)
key = key.expand(-1, -1, query_group_size, -1, -1)
key = key.contiguous().view(len1, len2 * query_group_size,
len0, len3).permute(2, 0, 1, 3)
value = value.permute(1, 2, 0, 3).unsqueeze(-3)
value = value.expand(-1, -1, query_group_size, -1, -1)
value = value.contiguous().view(len1, len2 * query_group_size,
len0, len3).permute(2, 0, 1, 3)
list = [key[:cur_len, :, :, :], value[:cur_len, :, :, :]]
ipex_past_key_values.append(list)
else:
len0 = past_key_values[i][1].size(1)
len1 = past_key_values[i][1].size(2)
len2 = past_key_values[i][0].size(2) # seq length
len3 = past_key_values[i][1].size(3)
if self.config.model_type == "qwen":
k0 = torch.ones(len0, len2, len1 + max_new_tokens, len3,
dtype=torch.float32)
v0 = torch.ones(len0, len2, len1 + max_new_tokens, len3,
dtype=torch.float32)
k0 = k0.transpose(1, 2)
v0 = v0.transpose(1, 2)
past_key_values_storage.append((k0, v0))
past_key_values_storage[i][0][:, :len1, :, :] = past_key_values[i][0].to(
torch.float32)
past_key_values_storage[i][1][:, :len1, :, :] = past_key_values[i][1].to(
torch.float32)
elif self.config.model_type == "chatglm":
k0 = torch.ones(len1, len2, len0 + max_new_tokens, len3,
dtype=torch.float32)
v0 = torch.ones(len1, len2, len0 + max_new_tokens, len3,
dtype=torch.float32)
k0 = k0.permute(2, 0, 1, 3)
v0 = v0.permute(2, 0, 1, 3)
past_key_values_storage.append((k0, v0))
past_key_values_storage[i][0][:len0, :, :, :] = past_key_values[i][0].to(
torch.float32)
past_key_values_storage[i][1][:len0, :, :, :] = past_key_values[i][1].to(
torch.float32)
else:
k0 = torch.ones(len0, len1, len2 + max_new_tokens, len3,
dtype=torch.float32)
v0 = torch.ones(len0, len1, len2 + max_new_tokens, len3,
dtype=torch.float32)
past_key_values_storage.append((k0, v0))
if not _enable_ipex:
ipex_past_key_values = [
[pkv[1].permute(1, 2, 0, 3)[:, :, :cur_len, :],
pkv[2].permute(1, 2, 0, 3)[:, :, :cur_len, :]]
for pkv in past_key_values
]
if not _enable_ipex:
len0 = past_key_values[0][0].size(0)
len1 = past_key_values[0][0].size(1)
len2 = past_key_values[0][0].size(2)
len3 = past_key_values[0][0].size(3)
for i in range(len(past_key_values)):
if self.config.model_type == "qwen":
k0 = torch.ones(len0, len2, len1 + max_new_tokens, len3,
dtype=torch.float32)
v0 = torch.ones(len0, len2, len1 + max_new_tokens, len3,
dtype=torch.float32)
k0 = k0.transpose(1, 2)
v0 = v0.transpose(1, 2)
past_key_values_storage.append((k0, v0))
past_key_values_storage[i][0][:, :len1, :, :] = past_key_values[i][0].to(
torch.float32)
past_key_values_storage[i][1][:, :len1, :, :] = past_key_values[i][1].to(
torch.float32)
elif self.config.model_type == "chatglm":
k0 = torch.ones(len1, len2, len0 + max_new_tokens, len3,
dtype=torch.float32)
v0 = torch.ones(len1, len2, len0 + max_new_tokens, len3,
dtype=torch.float32)
k0 = k0.permute(2, 0, 1, 3)
v0 = v0.permute(2, 0, 1, 3)
past_key_values_storage.append((k0, v0))
past_key_values_storage[i][0][:len0, :, :, :] = past_key_values[i][0].to(
torch.float32)
past_key_values_storage[i][1][:len0, :, :, :] = past_key_values[i][1].to(
torch.float32)
else:
k0 = torch.ones(len0, len1, len2 + max_new_tokens, len3,
dtype=torch.float32)
v0 = torch.ones(len0, len1, len2 + max_new_tokens, len3,
dtype=torch.float32)
past_key_values_storage.append((k0, v0))
past_key_values_storage[i][0][:, :, :len2, :] = past_key_values[i][0].to(
torch.float32)
past_key_values_storage[i][1][:, :, :len2, :] = past_key_values[i][1].to(
torch.float32)
else:
len0 = past_key_values[0][1].size(1)
len1 = past_key_values[0][1].size(2)
len2 = past_key_values[0][0].size(2) # seq length
len3 = past_key_values[0][1].size(3)
for i in range(len(past_key_values)):
if self.config.model_type == "chatglm":
k0 = torch.ones(len0, len1 * query_group_size, len2 + max_new_tokens, len3,
dtype=torch.float32)
v0 = torch.ones(len0, len1 * query_group_size, len2 + max_new_tokens, len3,
dtype=torch.float32)
k0 = k0.permute(2, 0, 1, 3)
v0 = v0.permute(2, 0, 1, 3)
past_key_values_storage.append((k0, v0))
past_key_values_storage[i][0][:len2, :, :, :] = ipex_past_key_values[i][0].to(
torch.float32)
past_key_values_storage[i][1][:len2, :, :, :] = ipex_past_key_values[i][1].to(
torch.float32)
else:
k0 = torch.ones(len0, len1, len2 + max_new_tokens, len3,
dtype=torch.float32)
v0 = torch.ones(len0, len1, len2 + max_new_tokens, len3,
dtype=torch.float32)
past_key_values_storage.append((k0, v0))
past_key_values_storage[i][0][:, :, :len2, :] = ipex_past_key_values[i][0].to(
torch.float32)
past_key_values_storage[i][1][:, :, :len2, :] = ipex_past_key_values[i][1].to(
@@ -195,7 +231,8 @@ def _prepare_past_key_values_storage_cpu(self, past_key_values,
return past_key_values_storage


def _prepare_draft_past_key_values_cpu(self, past_key_values, past_key_values_storage):
def _prepare_draft_past_key_values_cpu(self, past_key_values,
past_key_values_storage, _enable_ipex):
tmp_past_key_values = []
for i in range(len(past_key_values)):
if self.config.model_type == "qwen":
@@ -204,7 +241,10 @@ def _prepare_draft_past_key_values_cpu(self, past_key_values, past_key_values_st
v0 = past_key_values_storage[i][1][:, :len1, :, :]
tmp_past_key_values.append((k0, v0))
elif self.config.model_type == "chatglm":
len0 = past_key_values[0][0].size(0)
if not _enable_ipex:
len0 = past_key_values[0][0].size(0)
else:
len0 = past_key_values[0][0].size(1)
k0 = past_key_values_storage[i][0][:len0, :, :, :]
v0 = past_key_values_storage[i][1][:len0, :, :, :]
tmp_past_key_values.append((k0, v0))
@@ -244,15 +284,41 @@ def _update_past_key_values_storage_cpu(self, past_key_values, past_key_values_s
else:
size = original_draft_past_key_values[i][0].size(2)
size1 = past_key_values[i][0].size(1)
delta_past_key = \
past_key_values[i][1][size:size1, :, :, :].permute(1, 2, 0, 3)
delta_past_value = \
past_key_values[i][2][size:size1, :, :, :].permute(1, 2, 0, 3)
if self.config.model_type == "chatglm":
size = original_draft_past_key_values[0][0].size(0)
size1 = past_key_values[0][0].size(1)
len0 = past_key_values[0][1].size(0) # seq max_length
len1 = past_key_values[0][1].size(1)
len2 = past_key_values[0][1].size(2)
len3 = past_key_values[0][1].size(3)
key0 = torch.ones(size1-size, len1, len2, len3,
dtype=torch.float32)
value0 = torch.ones(size1-size, len1, len2, len3,
dtype=torch.float32)
key0 = past_key_values[i][1][size:size1, :, :, :]
value0 = past_key_values[i][2][size:size1, :, :, :]
key = key0.permute(1, 2, 0, 3).unsqueeze(-3)
key = key.expand(-1, -1, query_group_size, -1, -1)
key = key.contiguous().view(len1, len2 * query_group_size, size1-size, len3)
key = key.permute(2, 0, 1, 3)
value = value0.permute(1, 2, 0, 3).unsqueeze(-3)
value = value.expand(-1, -1, query_group_size, -1, -1)
value = value.contiguous().view(len1, len2 * query_group_size, size1-size, len3)
value = value.permute(2, 0, 1, 3)
past_key_values_storage[i][0][size:size1, :, :, :] = \
key.to(torch.float32)
past_key_values_storage[i][1][size:size1, :, :, :] = \
value.to(torch.float32)
else:
delta_past_key = \
past_key_values[i][1][size:size1, :, :, :].permute(1, 2, 0, 3)
delta_past_value = \
past_key_values[i][2][size:size1, :, :, :].permute(1, 2, 0, 3)

past_key_values_storage[i][0][:, :, size:size1, :] = \
delta_past_key.to(torch.float32)
past_key_values_storage[i][1][:, :, size:size1, :] = \
delta_past_value.to(torch.float32)
past_key_values_storage[i][0][:, :, size:size1, :] = \
delta_past_key.to(torch.float32)
past_key_values_storage[i][1][:, :, size:size1, :] = \
delta_past_value.to(torch.float32)


@torch.no_grad()
@@ -372,10 +438,14 @@ def speculative_generate(self,
_enable_ipex = (_enable_ipex is not None) and (_enable_ipex.lower() == "true")
if _enable_ipex:
if not ((self.config.model_type == 'baichuan' and self.config.hidden_size == 5120) or
('llama' in self.config.model_type) or
('llama' in self.config.model_type) or ("chatglm" in self.config.model_type) or
("mistral" in self.config.model_type)):
invalidInputError(False, "BigDL Speculative Decoding with IPEX BF16 only supports \
Llama, Baichuan2-13b and Mistral models currently.")
if "chatglm" in self.config.model_type:
global query_group_size
query_group_size = draft_model.config.num_attention_heads // \
draft_model.config.multi_query_group_num

tmp_matchness = 0
e2e_tic = 0.0
@@ -437,7 +507,7 @@ def speculative_generate(self,
# each iter cut off cur_len kv_cache from past_key_values1
draft_past_key_values = \
_prepare_draft_past_key_values_cpu(self, past_key_values,
past_key_values_storage)
past_key_values_storage, _enable_ipex)
original_draft_past_key_values = draft_past_key_values
else:
draft_past_key_values = past_key_values
@@ -464,7 +534,10 @@ def speculative_generate(self,
"use_cache": True,
}
if self.config.model_type == "chatglm":
past_key_value_len = past_key_values[0][0].shape[0]
if _enable_ipex:
past_key_value_len = past_key_values[0][0].shape[1]
else:
past_key_value_len = past_key_values[0][0].shape[0]
position_ids = torch.Tensor([[past_key_value_len + step_draft]]).long()
forward_args["position_ids"] = position_ids
elif self.config.model_type == "gptj":
@@ -533,6 +606,16 @@ def speculative_generate(self,
position_ids=position_ids,
past_key_values=past_key_values,
)
elif "chatglm" in self.config.model_type:
past_key_value_len = past_key_values[0][0].shape[2]
position_ids = torch.arange(drafted_input_ids.shape[1], dtype=torch.long,
device=drafted_input_ids.device).unsqueeze(0)
position_ids = position_ids.repeat(1, 1) + past_key_value_len
output = self.trace_graph(input_ids=drafted_input_ids,
attention_mask=cur_attention_mask,
position_ids=position_ids,
# return_last_logit=torch.tensor(False),
past_key_values=past_key_values,)
elif "mistral" in self.config.model_type:
past_key_value_len = past_key_values[0][0].shape[2]
seq_len = drafted_input_ids.shape[1]
@@ -591,7 +674,8 @@ def speculative_generate(self,
self.verify_time.append(toc - tic)
self.generate_time.append(self.draft_time[-1] + self.verify_time[-1])

past_key_values = output['past_key_values']
if past_key_values is None:
past_key_values = output['past_key_values']

if generation_config.do_sample:
draft_tokens = drafted_input_ids[:, 1:].squeeze(0)
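
The chatglm branches added above repeat each cached K/V head query_group_size times to turn the IPEX model's multi-query cache into the per-query-head layout the draft path stores; for chatglm3-6b (32 attention heads, 2 KV groups) that factor is 32 // 2 = 16, matching the module-level default. A standalone sketch of the shape transformation with made-up sizes:

# MQA -> per-head KV expansion used for the chatglm target model.
# Cache layout taken from the code above: [seq, batch, kv_groups, head_dim].
import torch

seq, batch, kv_groups, head_dim = 8, 1, 2, 128     # illustrative sizes
query_group_size = 16                              # num_attention_heads // multi_query_group_num

key = torch.randn(seq, batch, kv_groups, head_dim)

k = key.permute(1, 2, 0, 3)                        # [batch, kv_groups, seq, head_dim]
k = k.unsqueeze(-3)                                # [batch, kv_groups, 1, seq, head_dim]
k = k.expand(-1, -1, query_group_size, -1, -1)     # repeat each KV head for its query heads
k = k.contiguous().view(batch, kv_groups * query_group_size, seq, head_dim)
k = k.permute(2, 0, 1, 3)                          # back to [seq, batch, heads, head_dim]

assert k.shape == (seq, batch, kv_groups * query_group_size, head_dim)
# every query head within a group reads the same cached key
assert torch.equal(k[:, :, 0, :], k[:, :, query_group_size - 1, :])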
