diff --git a/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/Pipeline-Models/README.md b/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/Pipeline-Models/README.md index fcb79e5afc6..98f2c161070 100644 --- a/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/Pipeline-Models/README.md +++ b/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/Pipeline-Models/README.md @@ -8,6 +8,7 @@ In this directory, you will find examples on how to directly run HuggingFace `tr |------------|----------------------------------------------------------------| | Llama2 | [meta-llama/Llama-2-7b-chat-hf](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf) | | Llama3 | [meta-llama/Meta-Llama-3-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct) | +| Baichuan2 | [baichuan-inc/Baichuan2-7B-Chat](https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat) | ## 0. Requirements To run these examples with IPEX-LLM on Intel NPUs, make sure to install the newest driver version of Intel NPU. @@ -43,6 +44,9 @@ python llama2.py :: to run Meta-Llama-3-8B-Instruct python llama3.py + +:: to run Baichuan2-7B-Chat +python baichuan2.py ``` Arguments info: diff --git a/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/Pipeline-Models/baichuan2.py b/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/Pipeline-Models/baichuan2.py new file mode 100644 index 00000000000..04e4a0ff8b6 --- /dev/null +++ b/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/Pipeline-Models/baichuan2.py @@ -0,0 +1,99 @@ +# +# Copyright 2016 The BigDL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + + +import torch +import time +import argparse +from ipex_llm.transformers.npu_model import AutoModelForCausalLM +from transformers import AutoTokenizer +from transformers.utils import logging + +logger = logging.get_logger(__name__) + +def get_prompt(message: str, chat_history: list[tuple[str, str]], + system_prompt: str) -> str: + texts = [f'[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n'] + # The first user input is _not_ stripped + do_strip = False + for user_input, response in chat_history: + user_input = user_input.strip() if do_strip else user_input + do_strip = True + texts.append(f'{user_input} [/INST] {response.strip()} [INST] ') + message = message.strip() if do_strip else message + texts.append(f'{message} [/INST]') + return ''.join(texts) + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Predict Tokens using `generate()` API for npu model" + ) + parser.add_argument( + "--repo-id-or-model-path", + type=str, + default="baichuan-inc/Baichuan2-7B-Chat", + help="The huggingface repo id for the Baichuan2 model to be downloaded" + ", or the path to the huggingface checkpoint folder", + ) + parser.add_argument('--prompt', type=str, default="What is AI?", + help='Prompt to infer') + parser.add_argument("--n-predict", type=int, default=32, help="Max tokens to predict") + parser.add_argument("--max-context-len", type=int, default=1024) + parser.add_argument("--max-prompt-len", type=int, default=960) + parser.add_argument("--disable-transpose-value-cache", action="store_true", default=False) + + args = parser.parse_args() + model_path = args.repo_id_or_model_path + + model = AutoModelForCausalLM.from_pretrained(model_path, + optimize_model=True, + pipeline=True, + max_context_len=args.max_context_len, + max_prompt_len=args.max_prompt_len, + torch_dtype=torch.float16, + attn_implementation="eager", + transpose_value_cache=not args.disable_transpose_value_cache, + trust_remote_code=True) + + tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) + + DEFAULT_SYSTEM_PROMPT = """\ + """ + + print("-" * 80) + print("done") + with torch.inference_mode(): + print("finish to load") + for i in range(5): + prompt = get_prompt(args.prompt, [], system_prompt=DEFAULT_SYSTEM_PROMPT) + _input_ids = tokenizer.encode(prompt, return_tensors="pt") + print("input length:", len(_input_ids[0])) + st = time.time() + output = model.generate( + _input_ids, max_new_tokens=args.n_predict, do_print=True + ) + end = time.time() + print(f"Inference time: {end-st} s") + input_str = tokenizer.decode(_input_ids[0], skip_special_tokens=False) + print("-" * 20, "Input", "-" * 20) + print(input_str) + output_str = tokenizer.decode(output[0], skip_special_tokens=False) + print("-" * 20, "Output", "-" * 20) + print(output_str) + + print("-" * 80) + print("done") + print("success shut down") diff --git a/python/llm/src/ipex_llm/transformers/npu_models/baichuan_mp.py b/python/llm/src/ipex_llm/transformers/npu_models/baichuan_mp.py index 453aee0e966..c8d64c1e5cf 100644 --- a/python/llm/src/ipex_llm/transformers/npu_models/baichuan_mp.py +++ b/python/llm/src/ipex_llm/transformers/npu_models/baichuan_mp.py @@ -112,32 +112,14 @@ def __init__( # Self Attention if mode == "decode": - attention_mask = self.create_input_op((self.batch_size, 1, 1, self.max_seq_len + 1)) + attention_mask = self.create_input_op((self.batch_size, 1, 1, self.max_seq_len + 1), + dtype=np.int64) else: - attention_mask = self.create_input_op((self.batch_size, 1, self.seq_len, self.seq_len)) + attention_mask = 
self.create_input_op((self.batch_size, 1, self.seq_len, self.seq_len), + dtype=np.int64) - position_ids = self.create_input_op((self.batch_size, self.seq_len)) + position_ids = self.create_input_op((self.batch_size, self.seq_len), dtype=np.int64) # self.num_key_value_heads = num_key_value_heads - past_keys = [] - past_values = [] - if mode == "decode": - for i in range(num_layers): - past_key = self.create_cache_op( - (self.batch_size, self.num_heads, self.max_seq_len, self.head_dim) - ) - if transpose_value: - past_value = self.create_cache_op( - (self.batch_size, self.num_heads, self.head_dim, self.max_seq_len) - ) - else: - past_value = self.create_cache_op( - (self.batch_size, self.num_heads, self.max_seq_len, self.head_dim) - ) - past_keys.append(past_key) - past_values.append(past_value) - else: - past_keys = [None] * num_layers - past_values = [None] * num_layers if input_layernorm_weights is None: input_layernorm_weights = [] @@ -163,6 +145,27 @@ def __init__( input_layernorm_weights = [self.constant(w) for w in input_layernorm_weights] post_attn_layernorm_weights = [self.constant(w) for w in post_attn_layernorm_weights] + past_keys = [] + past_values = [] + if mode == "decode": + for i in range(num_layers): + past_key = self.create_cache_op( + (self.batch_size, self.num_heads, self.max_seq_len, self.head_dim) + ) + if transpose_value: + past_value = self.create_cache_op( + (self.batch_size, self.num_heads, self.head_dim, self.max_seq_len) + ) + else: + past_value = self.create_cache_op( + (self.batch_size, self.num_heads, self.max_seq_len, self.head_dim) + ) + past_keys.append(past_key) + past_values.append(past_value) + else: + past_keys = [None] * num_layers + past_values = [None] * num_layers + hidden_states = input curr_key_values = [] @@ -251,6 +254,7 @@ def attention(self, attn_weight = self.matmul(query_states, key_states, False, True) / ( math.sqrt(self.head_dim)) + attention_mask = self.convert_to_fp16(attention_mask) attn_weight = self.eltwise_add(attn_weight, attention_mask) attn_weight = self.convert_to_fp32(attn_weight) attn_weight = self.softmax(attn_weight, -1) @@ -395,8 +399,8 @@ def forward( inputs = ( hidden_states.to(torch.float16), - attention_mask, - position_ids.to(torch.float16), + attention_mask.to(torch.int64), + position_ids.to(torch.int64), ) for i in range(self.intra_stages): @@ -502,7 +506,9 @@ def forward( seq_len = hidden_states.shape[1] backend_cls = self.backend_cls_prefill - inputs = (hidden_states.to(torch.float16), attention_mask, position_ids.to(torch.float16)) + inputs = (hidden_states.to(torch.float16), + attention_mask.to(torch.int64), + position_ids.to(torch.int64)) inputs += (self.layer_norm_0, self.layer_norm_1) hidden_states, past_key, past_value = run_model( inputs, self.op_parameters, backend_cls, self.op_id, replica=2 @@ -625,9 +631,9 @@ def run_decode( pad_mask = (0, pad_len) padded_causal_mask = F.pad( - attention_mask.to(torch.float16), pad_mask, value=torch.finfo(torch.float16).min + attention_mask.to(torch.int64), pad_mask, value=torch.iinfo(torch.int64).min ) - padded_causal_mask[:, :, :, -1] = 0.0 + padded_causal_mask[:, :, :, -1] = 0 dist.recv(hidden_states, src=rank - 1) layer_outputs = multi_decoder( hidden_states, @@ -869,9 +875,9 @@ def forward( hidden_states = F.pad(hidden_states.to(torch.float16), (0, 0, 0, pad_len), value=0.0) position_ids = F.pad(position_ids, (0, pad_len), value=0) attention_mask = F.pad( - attention_mask.to(torch.float16), + attention_mask.to(torch.int64), (0, pad_len, 0, pad_len), - 
value=torch.finfo(torch.float16).min, + value=torch.iinfo(torch.int64).min, ) args = (hidden_states, position_ids, attention_mask, past_key_value) diff --git a/python/llm/src/ipex_llm/transformers/npu_models/convert_mp.py b/python/llm/src/ipex_llm/transformers/npu_models/convert_mp.py index cb4e94320cb..39999ce77f3 100644 --- a/python/llm/src/ipex_llm/transformers/npu_models/convert_mp.py +++ b/python/llm/src/ipex_llm/transformers/npu_models/convert_mp.py @@ -192,6 +192,41 @@ def convert_llama( convert_forward(model, LlamaForCausalLM, llama2_casullm_forward) +def convert_baichuan( + model: torch.nn.Module, + max_output_len=1024, + max_prompt_len=1024, + decoder=False, + inter_pp=None, + intra_pp=None, + transpose_value_cache=True, +): + from ipex_llm.transformers.npu_models.baichuan_mp import gen_baichuan_fused_model_forward + from ipex_llm.transformers.npu_models.baichuan_mp import DecodeRunner, PrefillRunner + if decoder: + decode_runner = DecodeRunner( + model, + max_seq_len=max_output_len, + inter_pp=inter_pp, + intra_pp=intra_pp, + transpose_value_cache=transpose_value_cache, + ) + else: + decode_runner = None + prefill_runner = PrefillRunner( + model, + max_output_len=max_output_len, + max_prompt_len=max_prompt_len, + transpose_value_cache=transpose_value_cache, + ) + baichuan_model_forward = gen_baichuan_fused_model_forward( + prefill_runner=prefill_runner, decode_runner=decode_runner + ) + modeling_module_name = model.__class__.__module__ + module = importlib.import_module(modeling_module_name) + convert_forward(model, module.BaichuanModel, baichuan_model_forward) + + def optimize_llm( model: torch.nn.Module, max_context_len=1024, @@ -297,28 +332,13 @@ def optimize_llm( intra_pp = 2 if inter_pp is None: inter_pp = 2 - from ipex_llm.transformers.npu_models.baichuan_mp import gen_baichuan_fused_model_forward - from ipex_llm.transformers.npu_models.baichuan_mp import DecodeRunner, PrefillRunner - decode_runner = DecodeRunner( - model, - max_seq_len=max_context_len, - inter_pp=inter_pp, - intra_pp=intra_pp, - transpose_value_cache=transpose_value_cache, - ) - prefill_runner = PrefillRunner( - model, - max_output_len=max_context_len, - max_prompt_len=max_prompt_len, - transpose_value_cache=transpose_value_cache, - ) - baichuan_model_forward = gen_baichuan_fused_model_forward( - prefill_runner=prefill_runner, decode_runner=decode_runner - ) - modeling_module_name = model.__class__.__module__ - module = importlib.import_module(modeling_module_name) - convert_forward(model, module.BaichuanModel, baichuan_model_forward) - + convert_baichuan(model, + max_output_len=max_context_len, + max_prompt_len=max_prompt_len, + inter_pp=inter_pp, + intra_pp=intra_pp, + decoder=True, + transpose_value_cache=transpose_value_cache) if isinstance(model.lm_head, SlicedLMHead): model.lm_head.get_fused_lm_head() diff --git a/python/llm/src/ipex_llm/transformers/npu_pipeline_model/baichuan.py b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/baichuan.py new file mode 100644 index 00000000000..0ceaf93100f --- /dev/null +++ b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/baichuan.py @@ -0,0 +1,131 @@ +# +# Copyright 2016 The BigDL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + + +import torch +import numpy as np +import os +from .common import update_names_of_IR_and_export_blob, LLMEmbedding, LowBitLLMLMHead + + +def convert_lm_head_and_embedding(model, n_splits_linear, temp_dir, weight_dir): + num_heads = model.model.layers[0].self_attn.num_heads + head_dim = model.model.layers[0].self_attn.head_dim + rms_norm_eps = model.config.rms_norm_eps + vocab_size = model.config.vocab_size + model_norm = model.model.norm + lm_head = model.lm_head + weights = [(lm_head.weight, lm_head.scale)] + if isinstance(weights[0], tuple): + np_dtype = np.int8 if weights[0][0].dtype == torch.int8 else np.uint8 + else: # FP16 Linear + np_dtype = np.float16 + + new_lm_head = LowBitLLMLMHead( + [1, 1, num_heads * head_dim], + num_heads=num_heads, + max_seq_len=1, + rms_norm_eps=rms_norm_eps, + mode="decode", + transpose_value=False, + dtype=np_dtype, + model_norm_weight=model_norm.weight.to(torch.float16), + vocab_size=vocab_size, + ) + last_blob_path = update_names_of_IR_and_export_blob(new_lm_head, "lm_head", temp_dir) + + # save weights bins files + weight_numpy = [ + lm_head.weight.data.numpy(), lm_head.scale.data.numpy(), + ] + + for idx, weight in enumerate(weight_numpy): + bin_file = os.path.join(weight_dir, f"model_lm_head_input_{1+idx}.bin") + weight.tofile(bin_file) + + embedding_layer = model.model.embed_tokens + new_embedding = LLMEmbedding( + vocab_size=model.config.vocab_size, + embedding_dim=model.config.hidden_size, + embedding_weight=embedding_layer.weight.to(torch.float16).detach().numpy(), + padding_idx=model.config.pad_token_id, + dtype=np.float16, + ) + first_blob_path = update_names_of_IR_and_export_blob(new_embedding, "embedding", + temp_dir) + return first_blob_path, last_blob_path + + +def convert_baichuan_layer(model, layer_idx, n_splits_linear, n_splits_down_proj, + temp_dir, weight_dir, transpose_value_cache, kv_len, group_size): + num_heads = model.model.layers[0].self_attn.num_heads + head_dim = model.model.layers[0].self_attn.head_dim + intermediate_size = model.config.intermediate_size + rms_norm_eps = model.config.rms_norm_eps + + from ipex_llm.transformers.npu_models.baichuan_mp import LowBitBaichuanMultiDecoderlayer + curr_layer = model.model.layers[layer_idx] + attn_layer = curr_layer.self_attn + mlp_layer = curr_layer.mlp + + weights = [] + if n_splits_linear == 1: + weights = [ + (attn_layer.W_pack.weight, attn_layer.W_pack.scale), + (attn_layer.o_proj.weight, attn_layer.o_proj.scale), + (mlp_layer.gate_proj.weight, mlp_layer.gate_proj.scale), + (mlp_layer.up_proj.weight, mlp_layer.up_proj.scale), + (mlp_layer.down_proj.weight, mlp_layer.down_proj.scale), + ] + else: + # TODO + pass + + cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16) + cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16) + layer_norm_0 = curr_layer.input_layernorm.weight.to(torch.float16) + layer_norm_1 = curr_layer.post_attention_layernorm.weight.to(torch.float16) + + if isinstance(weights[0], tuple): + np_dtype = np.int8 if weights[0][0].dtype == torch.int8 else np.uint8 + else: # FP16 Linear + np_dtype = np.float16 + + 
single_decoder = LowBitBaichuanMultiDecoderlayer( + [1, 1, num_heads * head_dim], + input_layernorm_weights=[layer_norm_0], + post_attn_layernorm_weights=[layer_norm_1], + cached_cos=cached_cos, + cached_sin=cached_sin, + num_heads=num_heads, + num_layers=1, + max_seq_len=kv_len, + rms_norm_eps=rms_norm_eps, + intermediate_size=intermediate_size, + mode="decode", + transpose_value=transpose_value_cache, + dtype=np_dtype, + ) + rest_blob_path = update_names_of_IR_and_export_blob(single_decoder, + f"decoder_layer_{layer_idx}", + temp_dir) + + for idx, (weight, scale) in enumerate(weights): + bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{5+idx*2}.bin") + weight.numpy().tofile(bin_file) + bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{5+idx*2+1}.bin") + scale.numpy().tofile(bin_file) + del single_decoder diff --git a/python/llm/src/ipex_llm/transformers/npu_pipeline_model/common.py b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/common.py index 0e3da6e62ad..3cccb9fd422 100644 --- a/python/llm/src/ipex_llm/transformers/npu_pipeline_model/common.py +++ b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/common.py @@ -17,6 +17,10 @@ from openvino.runtime import Core, serialize import os +from ipex_llm.transformers.npu_models.mp_models_base import LLMBaseNNFactory +from typing import Sequence +from intel_npu_acceleration_library.backend.factory import NNFactory +import numpy as np def update_names_of_IR_and_export_blob(model, model_name, dir): @@ -52,3 +56,101 @@ def update_names_of_IR_and_export_blob(model, model_name, dir): os.remove(new_ir_path) return blob_path + + +class LowBitLLMLMHead(LLMBaseNNFactory): + def __init__( + self, + hidden_shape: Sequence[int], + num_heads: int, + rms_norm_eps: float, + model_norm_weight, + vocab_size: int, + mode: str = "decode", + dtype: np.dtype = np.int8, + max_seq_len: int = 1024, + transpose_value: bool = False, + profile: bool = False, + device: str = "NPU", + n_splits: int = 1, + ): + super().__init__(max_seq_len=max_seq_len, + transpose_value=transpose_value, + dtype=dtype, + profile=profile, + device=device) + self.max_seq_len = max_seq_len + self.dtype = dtype + self.batch_size, self.seq_len, self.hidden_size = hidden_shape + self.mode = mode + self.rms_norm_eps = rms_norm_eps + self.transpose_value = transpose_value + self.vocab_size = vocab_size + + self.num_heads = num_heads + self.head_dim = self.hidden_size // self.num_heads + + # define input, the order self.parameter matters + input = self.create_input_op((self.batch_size, self.seq_len, self.hidden_size)) + + hidden_states = input + + # model norm and lm head + model_norm_weight = self.constant(model_norm_weight) + hidden_states = self.layer_norm(hidden_states, model_norm_weight) + if n_splits == 1: + hidden_states = self.linear( + hidden_states, self.vocab_size, self.hidden_size, bias=False, wt_dtype=self.dtype + ) + else: + hidden_states = self.dq_split_linear( + hidden_states, self.vocab_size, self.hidden_size, n_splits, + wt_dtype=dtype, scale_factor=False + ) + + # define outputs + hidden_states = self.convert_to_fp32(hidden_states) + + print("start compiling") + self.compile() + + +class LLMEmbedding(NNFactory): + def __init__( + self, + vocab_size, + embedding_dim, + embedding_weight, + padding_idx, + dtype, # fp16 + device: str = "NPU", + ): + super().__init__(False, device) + self.vocab_size = vocab_size + self.embedding_dim = embedding_dim + self.padding_idx = padding_idx + self.dtype = dtype + + # define input + weight = 
self.constant(embedding_weight) + input = self.parameter((1, 1), dtype=np.int32) + + if padding_idx == -1: + padding_idx += vocab_size + + axis_node = self.constant(np.array([0], dtype=np.int64)) + if padding_idx is not None: + masked_embeddings = np.ones(weight.shape, dtype=np.float16) + masked_embeddings[padding_idx, :] = 0.0 # mask + + node_mask = self.constant(masked_embeddings) + node_masked_w = self.eltwise_mul(weight, node_mask) + res = self.gather(node_masked_w, input, axis_node, 0) + else: + res = self.gather(weight, input, axis_node, 0) + + # define outputs + res = self.convert_to_fp16(res) + + print("start compiling") + self.compile() diff --git a/python/llm/src/ipex_llm/transformers/npu_pipeline_model/convert_pipeline.py b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/convert_pipeline.py index e5db1bb2ee5..3eacdd6bee0 100644 --- a/python/llm/src/ipex_llm/transformers/npu_pipeline_model/convert_pipeline.py +++ b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/convert_pipeline.py @@ -122,6 +122,7 @@ def generate( thread = threading.Thread(target=generate_serve, args=(self.kv_len, self.num_head, self.head_dim, self.num_layers, + self.vocab_size, self.transpose_value_cache, new_tokens - 2)) thread.start() @@ -163,11 +164,11 @@ def generate( break token = int.from_bytes(data, sys.byteorder) idx += 1 + if token == eos: + break output_tokens.append(torch.tensor([token])) if streamer is not None: streamer.put(torch.tensor([token])) - if token == eos: - break output = torch.stack(output_tokens, dim=1) output = torch.cat((inputs, output), dim=1) @@ -231,7 +232,47 @@ def convert_llm(model: torch.nn.Module, model.transpose_value_cache = transpose_value_cache try: - res = InitLLMPipeline(kv_len, model.num_head, model.head_dim, layer_num, + res = InitLLMPipeline("llama", kv_len, model.num_head, model.head_dim, layer_num, + model.vocab_size, weight_dir, "model", + first_blob_path, last_blob_path, + os.path.join(temp_dir, "decoder_layer")) + except: + invalidInputError(False, + "False to InitLLMPipeline.") + elif model.config.model_type == "baichuan": + with tempfile.TemporaryDirectory() as temp_dir: + weight_dir = os.path.join(temp_dir, "model_weights") + os.mkdir(weight_dir) + layer_num = len(model.model.layers) + from .baichuan import convert_baichuan_layer, convert_lm_head_and_embedding + first_blob_path, last_blob_path = convert_lm_head_and_embedding(model, n_splits_linear, + temp_dir, weight_dir) + + param_list = [] + for layer_idx in range(0, layer_num): + param_list.append((model, layer_idx, n_splits_linear, n_splits_down_proj, + temp_dir, weight_dir, transpose_value_cache, kv_len, group_size)) + with Pool() as pool: + result = pool.starmap(convert_baichuan_layer, param_list) + + # Prefill Runner + from ipex_llm.transformers.npu_models.convert_mp import convert_baichuan + convert_baichuan(model, + max_output_len=kv_len, + max_prompt_len=max_prompt_len, + decoder=False, + transpose_value_cache=transpose_value_cache) + + # patch attrs for generate + model.kv_len = kv_len + model.num_head = model.model.layers[0].self_attn.num_heads + model.head_dim = model.model.layers[0].self_attn.head_dim + model.num_layers = layer_num + model.transpose_value_cache = transpose_value_cache + model.vocab_size = model.config.vocab_size + + try: + res = InitLLMPipeline("baichuan", kv_len, model.num_head, model.head_dim, layer_num, model.vocab_size, weight_dir, "model", first_blob_path, last_blob_path, os.path.join(temp_dir, "decoder_layer")) @@ -240,7 +281,7 @@ def convert_llm(model: 
torch.nn.Module, "False to InitLLMPipeline.") else: invalidInputError(False, - "Now we only support Llama2 for pipeline running.") + "Now we only support Llama2 / Llama3 / Baichuan2 for pipeline running.") if isinstance(model.lm_head, SlicedLMHead): model.lm_head.get_fused_lm_head() diff --git a/python/llm/src/ipex_llm/transformers/npu_pipeline_model/llama.py b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/llama.py index 9392c8470fd..1203214c0de 100644 --- a/python/llm/src/ipex_llm/transformers/npu_pipeline_model/llama.py +++ b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/llama.py @@ -17,112 +17,8 @@ import torch import numpy as np -from ipex_llm.transformers.npu_models.mp_models_base import LLMBaseNNFactory -from typing import Sequence -from intel_npu_acceleration_library.backend.factory import NNFactory import os -from .common import update_names_of_IR_and_export_blob - - -class LowBitLlamaLMHead(LLMBaseNNFactory): - def __init__( - self, - hidden_shape: Sequence[int], - num_heads: int, - num_key_value_heads: int, - rms_norm_eps: float, - model_norm_weight, - vocab_size: int, - mode: str = "decode", - dtype: np.dtype = np.int8, - max_seq_len: int = 1024, - transpose_value: bool = False, - profile: bool = False, - device: str = "NPU", - n_splits: int = 1, - ): - super().__init__(max_seq_len=max_seq_len, - transpose_value=transpose_value, - dtype=dtype, - profile=profile, - device=device) - self.max_seq_len = max_seq_len - self.dtype = dtype - self.batch_size, self.seq_len, self.hidden_size = hidden_shape - self.mode = mode - self.rms_norm_eps = rms_norm_eps - self.transpose_value = transpose_value - self.vocab_size = vocab_size - - self.num_heads = num_heads - self.num_key_value_heads = num_key_value_heads - - self.head_dim = self.hidden_size // self.num_heads - self.num_key_value_groups = self.num_heads // self.num_key_value_heads - - # define input, the order self.parameter matters - input = self.create_input_op((self.batch_size, self.seq_len, self.hidden_size)) - - hidden_states = input - - # model norm and lm head - model_norm_weight = self.constant(model_norm_weight) - hidden_states = self.layer_norm(hidden_states, model_norm_weight) - if n_splits == 1: - hidden_states = self.linear( - hidden_states, self.vocab_size, self.hidden_size, bias=False, wt_dtype=self.dtype - ) - else: - hidden_states = self.dq_split_linear( - hidden_states, self.vocab_size, self.hidden_size, n_splits, - wt_dtype=dtype, scale_factor=False - ) - - # define outputs - hidden_states = self.convert_to_fp32(hidden_states) - - print("start compiling") - self.compile() - - -class LlamaEmbedding(NNFactory): - def __init__( - self, - vocab_size, - embedding_dim, - embedding_weight, - padding_idx, - dtype, # fp16 - device: str = "NPU", - ): - super().__init__(False, device) - self.vocab_size = vocab_size - self.embedding_dim = embedding_dim - self.padding_idx = padding_idx - self.dtype = dtype - - # define input - weight = self.constant(embedding_weight) - input = self.parameter((1, 1), dtype=np.int32) - - if padding_idx == -1: - padding_idx += vocab_size - - if padding_idx is not None: - masked_embeddings = np.ones(weight.shape, dtype='int64') - masked_embeddings[padding_idx, :] = 0 # mask - - node_mask = self.constant(masked_embeddings) - node_masked_w = self.matmul(weight, node_mask, False, True) - - axis_node = self.constant(np.array([0], dtype=np.int64)) - res = self.gather(node_masked_w if padding_idx else weight, input, axis_node, 0) - - # define outputs - res = self.convert_to_fp16(res) - - 
print("start compiling") - self.compile() +from .common import update_names_of_IR_and_export_blob, LLMEmbedding, LowBitLLMLMHead def convert_lm_head_and_embedding(model, n_splits_linear, temp_dir, weight_dir): @@ -149,10 +45,9 @@ def convert_lm_head_and_embedding(model, n_splits_linear, temp_dir, weight_dir): else: # FP16 Linear np_dtype = np.float16 - new_lm_head = LowBitLlamaLMHead( + new_lm_head = LowBitLLMLMHead( [1, 1, num_heads * head_dim], num_heads=num_heads, - num_key_value_heads=num_key_value_heads, max_seq_len=1, rms_norm_eps=rms_norm_eps, mode="decode", @@ -177,7 +72,7 @@ def convert_lm_head_and_embedding(model, n_splits_linear, temp_dir, weight_dir): weight.tofile(bin_file) embedding_layer = model.model.embed_tokens - new_embedding = LlamaEmbedding( + new_embedding = LLMEmbedding( vocab_size=model.config.vocab_size, embedding_dim=model.config.hidden_size, embedding_weight=embedding_layer.weight.to(torch.float16).detach().numpy(), diff --git a/python/llm/src/ipex_llm/transformers/npu_pipeline_model/pipeline_cpp.py b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/pipeline_cpp.py index 366ed744b31..41d4f95d854 100644 --- a/python/llm/src/ipex_llm/transformers/npu_pipeline_model/pipeline_cpp.py +++ b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/pipeline_cpp.py @@ -43,23 +43,23 @@ def get_shared_lib_info(lib_base_name: str): # Load the library _lib = ctypes.cdll.LoadLibrary(_lib_path) -_lib.InitLLMPipeline.argtypes = [ctypes.c_int] * 5 + [ctypes.c_char_p] * 5 +_lib.InitLLMPipeline.argtypes = [ctypes.c_char_p] + [ctypes.c_int] * 5 + [ctypes.c_char_p] * 5 _lib.InitLLMPipeline.restype = ctypes.c_int -_lib.generate_serve.argtypes = [ctypes.c_int] * 4 + [ctypes.c_bool] + [ctypes.c_int] +_lib.generate_serve.argtypes = [ctypes.c_int] * 5 + [ctypes.c_bool] + [ctypes.c_int] _lib.generate_serve.restype = ctypes.c_int -def InitLLMPipeline(kv_len: int, num_head: int, head_dim: int, num_layers: int, vocab_size: int, - model_weight_dir: str, model_name: str, +def InitLLMPipeline(model_type: str, kv_len: int, num_head: int, head_dim: int, num_layers: int, + vocab_size: int, model_weight_dir: str, model_name: str, first_blob_name: str, last_blob_name: str, rest_blob_name: str): - return _lib.InitLLMPipeline(kv_len, num_head, head_dim, num_layers, vocab_size, - model_weight_dir.encode('utf-8'), model_name.encode('utf-8'), - first_blob_name.encode('utf-8'), last_blob_name.encode('utf-8'), - rest_blob_name.encode('utf-8')) + return _lib.InitLLMPipeline(model_type.encode('utf-8'), kv_len, num_head, head_dim, num_layers, + vocab_size, model_weight_dir.encode('utf-8'), + model_name.encode('utf-8'), first_blob_name.encode('utf-8'), + last_blob_name.encode('utf-8'), rest_blob_name.encode('utf-8')) def generate_serve(kv_len: int, num_head: int, head_dim: int, num_layers: int, - transpose_value_cache: bool, param_n_output: int): + vocab_size: int, transpose_value_cache: bool, param_n_output: int): _lib.generate_serve(kv_len, num_head, head_dim, num_layers, - transpose_value_cache, param_n_output) + vocab_size, transpose_value_cache, param_n_output)