Skip to content

Commit

Permalink
Support minicpm-1B in level0 pipeline (intel-analytics#12297)
Browse files Browse the repository at this point in the history
  • Loading branch information
plusbang authored Oct 30, 2024
1 parent 46d8300 commit 41b8064
Show file tree
Hide file tree
Showing 7 changed files with 435 additions and 71 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ In this directory, you will find examples on how to directly run HuggingFace `tr
| Llama2 | [meta-llama/Llama-2-7b-chat-hf](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf) |
| Llama3 | [meta-llama/Meta-Llama-3-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct) |
| Baichuan2 | [baichuan-inc/Baichuan2-7B-Chat](https://huggingface.co/baichuan-inc/Baichuan-7B-Chat) |
| MiniCPM | [openbmb/MiniCPM-1B-sft-bf16](https://huggingface.co/openbmb/MiniCPM-1B-sft-bf16) |

## 0. Requirements
To run these examples with IPEX-LLM on Intel NPUs, make sure to install the newest driver version of Intel NPU.
Expand Down Expand Up @@ -47,6 +48,9 @@ python llama3.py
:: to run Baichuan2-7B-Chat
python baichuan2.py
:: to run MiniCPM-1B-sft-bf16
python minicpm.py
```

Arguments info:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#


import torch
import time
import argparse
from ipex_llm.transformers.npu_model import AutoModelForCausalLM
from transformers import AutoTokenizer
from transformers.utils import logging
import os

logger = logging.get_logger(__name__)

if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Predict Tokens using `generate()` API for npu model"
)
parser.add_argument(
"--repo-id-or-model-path",
type=str,
default="openbmb/MiniCPM-1B-sft-bf16",
help="The huggingface repo id for the MiniCPM model to be downloaded"
", or the path to the huggingface checkpoint folder",
)
parser.add_argument("--lowbit-path", type=str,
default="",
help="The path to the lowbit model folder, leave blank if you do not want to save. \
If path not exists, lowbit model will be saved there. \
Else, lowbit model will be loaded.",
)
parser.add_argument('--prompt', type=str, default="What is AI?",
help='Prompt to infer')
parser.add_argument("--n-predict", type=int, default=32, help="Max tokens to predict")
parser.add_argument("--max-context-len", type=int, default=1024)
parser.add_argument("--max-prompt-len", type=int, default=512)
parser.add_argument("--disable-transpose-value-cache", action="store_true", default=False)

args = parser.parse_args()
model_path = args.repo_id_or_model_path

if not args.lowbit_path or not os.path.exists(args.lowbit_path):
model = AutoModelForCausalLM.from_pretrained(model_path,
optimize_model=True,
pipeline=True,
max_context_len=args.max_context_len,
max_prompt_len=args.max_prompt_len,
torch_dtype=torch.float16,
attn_implementation="eager",
transpose_value_cache=not args.disable_transpose_value_cache,
trust_remote_code=True)
else:
model = AutoModelForCausalLM.load_low_bit(
args.lowbit_path,
attn_implementation="eager",
torch_dtype=torch.float16,
max_context_len=args.max_context_len,
max_prompt_len=args.max_prompt_len,
pipeline=True,
transpose_value_cache=not args.disable_transpose_value_cache,
trust_remote_code=True
)

tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

if args.lowbit_path and not os.path.exists(args.lowbit_path):
model.save_low_bit(args.lowbit_path)

print("-" * 80)
print("done")
with torch.inference_mode():
print("finish to load")
for i in range(5):
prompt = "<用户>{}<AI>".format(args.prompt)
_input_ids = tokenizer.encode(prompt, return_tensors="pt")
print("input length:", len(_input_ids[0]))
st = time.time()
output = model.generate(
_input_ids, max_new_tokens=args.n_predict, do_print=True
)
end = time.time()
print(f"Inference time: {end-st} s")
input_str = tokenizer.decode(_input_ids[0], skip_special_tokens=False)
print("-" * 20, "Input", "-" * 20)
print(input_str)
output_str = tokenizer.decode(output[0], skip_special_tokens=False)
print("-" * 20, "Output", "-" * 20)
print(output_str)

print("-" * 80)
print("done")
print("success shut down")
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@

print("finish to load")
for i in range(5):
_input_ids = tokenizer.encode("<用户>{}".format(args.prompt), return_tensors="pt")
_input_ids = tokenizer.encode("<用户>{}<AI>".format(args.prompt), return_tensors="pt")
print("input length:", len(_input_ids[0]))
st = time.time()
output = model.generate(
Expand Down
84 changes: 48 additions & 36 deletions python/llm/src/ipex_llm/transformers/npu_models/convert_mp.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,46 @@ def convert_baichuan(
convert_forward(model, module.BaichuanModel, baichuan_model_forward)


def convert_minicpm(
model: torch.nn.Module,
max_output_len=1024,
max_prompt_len=1024,
decoder=False,
inter_pp=None,
intra_pp=None,
transpose_value_cache=True,
):
from ipex_llm.transformers.npu_models.minicpm_mp import gen_minicpm_fused_model_forward
from ipex_llm.transformers.npu_models.minicpm_mp import DecodeRunner, PrefillRunner
modeling_module_name = model.__class__.__module__
module = importlib.import_module(modeling_module_name)

if decoder:
decode_runner = DecodeRunner(
model,
max_seq_len=max_output_len,
inter_pp=inter_pp,
intra_pp=intra_pp,
transpose_value_cache=transpose_value_cache,
)
else:
decode_runner = None
prefill_runner = PrefillRunner(
model,
max_output_len=max_output_len,
max_prompt_len=max_prompt_len,
transpose_value_cache=transpose_value_cache,
)
minicpm_model_forward = gen_minicpm_fused_model_forward(
prefill_runner=prefill_runner, decode_runner=decode_runner
)
convert_forward(model, module.MiniCPMModel, minicpm_model_forward)
if model.config.num_hidden_layers == 40:
# for minicpm-2b
from ipex_llm.transformers.npu_models.minicpm_mp import minicpm_casullm_forward
convert_forward(model, module.MiniCPMForCausalLM, minicpm_casullm_forward)


def optimize_llm(
model: torch.nn.Module,
max_context_len=1024,
Expand Down Expand Up @@ -291,41 +331,13 @@ def optimize_llm(
intra_pp = 2
if inter_pp is None:
inter_pp = 2

from ipex_llm.transformers.npu_models.minicpm_mp import gen_minicpm_fused_model_forward
from ipex_llm.transformers.npu_models.minicpm_mp import DecodeRunner, PrefillRunner

modeling_module_name = model.__class__.__module__
module = importlib.import_module(modeling_module_name)

if model.config.num_hidden_layers == 52:
# for minicpm-1b
transpose_cache = transpose_value_cache
elif model.config.num_hidden_layers == 40:
# for minicpm-2b
transpose_cache = False

decode_runner = DecodeRunner(
model,
max_seq_len=max_context_len,
inter_pp=inter_pp,
intra_pp=intra_pp,
transpose_value_cache=transpose_cache,
)
prefill_runner = PrefillRunner(
model,
max_output_len=max_context_len,
max_prompt_len=max_prompt_len,
transpose_value_cache=transpose_cache,
)
minicpm_model_forward = gen_minicpm_fused_model_forward(
prefill_runner=prefill_runner, decode_runner=decode_runner
)
convert_forward(model, module.MiniCPMModel, minicpm_model_forward)
if model.config.num_hidden_layers == 40:
# for minicpm-2b
from ipex_llm.transformers.npu_models.minicpm_mp import minicpm_casullm_forward
convert_forward(model, module.MiniCPMForCausalLM, minicpm_casullm_forward)
convert_minicpm(model,
max_output_len=max_context_len,
max_prompt_len=max_prompt_len,
inter_pp=inter_pp,
intra_pp=intra_pp,
decoder=True,
transpose_value_cache=transpose_value_cache)
elif model.config.model_type == "baichuan" and model.config.num_hidden_layers == 32:
# for Baichuan2-7B
if intra_pp is None:
Expand All @@ -339,7 +351,7 @@ def optimize_llm(
intra_pp=intra_pp,
decoder=True,
transpose_value_cache=transpose_value_cache)
if isinstance(model.lm_head, SlicedLMHead):
if hasattr(model, 'lm_head') and isinstance(model.lm_head, SlicedLMHead):
model.lm_head.get_fused_lm_head()


Expand Down
73 changes: 39 additions & 34 deletions python/llm/src/ipex_llm/transformers/npu_models/minicpm_mp.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@
from torch.nn import CrossEntropyLoss


class LowBitLlamaMultiDecoderlayer(LLMBaseNNFactory):
class LowBitMinicpmMultiDecoderlayer(LLMBaseNNFactory):
def __init__(
self,
# batch_size: int,
Expand Down Expand Up @@ -118,31 +118,13 @@ def __init__(

# Self Attention
if mode == "decode":
attention_mask = self.create_input_op((self.batch_size, 1, 1, self.max_seq_len + 1))
attention_mask = self.create_input_op((self.batch_size, 1, 1, self.max_seq_len + 1),
dtype=np.int64)
else:
attention_mask = self.create_input_op((self.batch_size, 1, self.seq_len, self.seq_len))
attention_mask = self.create_input_op((self.batch_size, 1, self.seq_len, self.seq_len),
dtype=np.int64)

position_ids = self.create_input_op((self.batch_size, self.seq_len))
past_keys = []
past_values = []
if mode == "decode":
for i in range(num_layers):
past_key = self.create_cache_op(
(self.batch_size, self.num_key_value_heads, self.max_seq_len, self.head_dim)
)
if transpose_value:
past_value = self.create_cache_op(
(self.batch_size, self.num_key_value_heads, self.head_dim, self.max_seq_len)
)
else:
past_value = self.create_cache_op(
(self.batch_size, self.num_key_value_heads, self.max_seq_len, self.head_dim)
)
past_keys.append(past_key)
past_values.append(past_value)
else:
past_keys = [None] * num_layers
past_values = [None] * num_layers
position_ids = self.create_input_op((self.batch_size, self.seq_len), dtype=np.int64)

if input_layernorm_weights is None:
input_layernorm_weights = []
Expand All @@ -168,6 +150,27 @@ def __init__(
input_layernorm_weights = [self.constant(w) for w in input_layernorm_weights]
post_attn_layernorm_weights = [self.constant(w) for w in post_attn_layernorm_weights]

past_keys = []
past_values = []
if mode == "decode":
for i in range(num_layers):
past_key = self.create_cache_op(
(self.batch_size, self.num_key_value_heads, self.max_seq_len, self.head_dim)
)
if transpose_value:
past_value = self.create_cache_op(
(self.batch_size, self.num_key_value_heads, self.head_dim, self.max_seq_len)
)
else:
past_value = self.create_cache_op(
(self.batch_size, self.num_key_value_heads, self.max_seq_len, self.head_dim)
)
past_keys.append(past_key)
past_values.append(past_value)
else:
past_keys = [None] * num_layers
past_values = [None] * num_layers

hidden_states = input

curr_key_values = []
Expand Down Expand Up @@ -297,7 +300,7 @@ def __init__(
start, end = self.layer_ranges[i]
lm_0 = input_laynorm_weights[start:end]
lm_1 = post_attn_layernorm_weights[start:end]
decoder = LowBitLlamaMultiDecoderlayer(
decoder = LowBitMinicpmMultiDecoderlayer(
[1, 1, num_heads * head_dim],
input_layernorm_weights=lm_0,
post_attn_layernorm_weights=lm_1,
Expand Down Expand Up @@ -334,15 +337,15 @@ def forward(

inputs = (
hidden_states.to(torch.float16),
attention_mask,
position_ids.to(torch.float16),
attention_mask.to(torch.int64),
position_ids.to(torch.int64),
)

for i in range(self.intra_stages):
start, end = self.layer_ranges[i]
self.backend_decoders[i].update_cache(past_key_value, self.layer_indexes[start:end])

hidden_states, new_keys, new_values = LowBitLlamaMultiDecoderlayer.run_decoders(
hidden_states, new_keys, new_values = LowBitMinicpmMultiDecoderlayer.run_decoders(
inputs,
decoders=self.backend_decoders)

Expand Down Expand Up @@ -403,7 +406,7 @@ def __init__(
np_dtype = np.float16

self.backend_cls_prefill = partial(
LowBitLlamaMultiDecoderlayer,
LowBitMinicpmMultiDecoderlayer,
num_heads=num_heads,
num_key_value_heads=num_key_value_heads,
num_layers=1,
Expand Down Expand Up @@ -445,7 +448,9 @@ def forward(
seq_len = hidden_states.shape[1]

backend_cls = self.backend_cls_prefill
inputs = (hidden_states.to(torch.float16), attention_mask, position_ids.to(torch.float16))
inputs = (hidden_states.to(torch.float16),
attention_mask.to(torch.int64),
position_ids.to(torch.int64))
inputs += (self.layer_norm_0, self.layer_norm_1)
hidden_states, past_key, past_value = run_model(
inputs, self.op_parameters, backend_cls, self.op_id, replica=2
Expand Down Expand Up @@ -578,9 +583,9 @@ def run_decode(

pad_mask = (0, pad_len)
padded_causal_mask = F.pad(
causal_mask.to(torch.float16), pad_mask, value=torch.finfo(torch.float16).min
causal_mask.to(torch.int64), pad_mask, value=torch.iinfo(torch.int64).min
)
padded_causal_mask[:, :, :, -1] = 0.0
padded_causal_mask[:, :, :, -1] = 0
dist.recv(hidden_states, src=rank - 1)
layer_outputs = multi_decoder(
hidden_states,
Expand Down Expand Up @@ -831,9 +836,9 @@ def forward(
hidden_states = F.pad(hidden_states.to(torch.float16), (0, 0, 0, pad_len), value=0.0)
position_ids = F.pad(position_ids, (0, pad_len), value=0)
attention_mask = F.pad(
attention_mask.to(torch.float16),
attention_mask.to(torch.int64),
(0, pad_len, 0, pad_len),
value=torch.finfo(torch.float16).min,
value=torch.iinfo(torch.int64).min,
)

args = (hidden_states, position_ids, attention_mask, past_key_value)
Expand Down
Loading

0 comments on commit 41b8064

Please sign in to comment.