diff --git a/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/Pipeline-Models/README.md b/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/Pipeline-Models/README.md
index 5ad0627b7b7..e3569496184 100644
--- a/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/Pipeline-Models/README.md
+++ b/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/Pipeline-Models/README.md
@@ -11,7 +11,7 @@ In this directory, you will find examples on how to directly run HuggingFace `tr
 | Qwen2 | [Qwen/Qwen2-1.5B-Instruct](https://huggingface.co/Qwen/Qwen2-1.5B-Instruct) |
 | Qwen2.5 | [Qwen/Qwen2.5-7b-Instruct](https://huggingface.co/Qwen/Qwen2.5-7b-Instruct) |
 | Baichuan2 | [baichuan-inc/Baichuan2-7B-Chat](https://huggingface.co/baichuan-inc/Baichuan-7B-Chat) |
-| MiniCPM | [openbmb/MiniCPM-1B-sft-bf16](https://huggingface.co/openbmb/MiniCPM-1B-sft-bf16) |
+| MiniCPM | [openbmb/MiniCPM-1B-sft-bf16](https://huggingface.co/openbmb/MiniCPM-1B-sft-bf16), [openbmb/MiniCPM-2B-sft-bf16](https://huggingface.co/openbmb/MiniCPM-2B-sft-bf16) |
 
 ## 0. Requirements
 To run these examples with IPEX-LLM on Intel NPUs, make sure to install the newest driver version of Intel NPU.
@@ -59,6 +59,9 @@ python baichuan2.py
 
 :: to run MiniCPM-1B-sft-bf16
 python minicpm.py
+
+:: to run MiniCPM-2B-sft-bf16
+python minicpm.py --repo-id-or-model-path "openbmb/MiniCPM-2B-sft-bf16"
 ```
 
 Arguments info:
diff --git a/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/Pipeline-Models/minicpm.py b/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/Pipeline-Models/minicpm.py
index 9fd854898b0..9cd01218852 100644
--- a/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/Pipeline-Models/minicpm.py
+++ b/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/Pipeline-Models/minicpm.py
@@ -32,7 +32,7 @@
     parser.add_argument(
         "--repo-id-or-model-path",
         type=str,
-        default="openbmb/MiniCPM-1B-sft-bf16",
+        default="openbmb/MiniCPM-1B-sft-bf16",  # or "openbmb/MiniCPM-2B-sft-bf16"
         help="The huggingface repo id for the MiniCPM model to be downloaded"
         ", or the path to the huggingface checkpoint folder",
     )
diff --git a/python/llm/src/ipex_llm/transformers/npu_pipeline_model/convert_pipeline.py b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/convert_pipeline.py
index 21603467a24..39760cf9eee 100644
--- a/python/llm/src/ipex_llm/transformers/npu_pipeline_model/convert_pipeline.py
+++ b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/convert_pipeline.py
@@ -362,7 +362,7 @@ def convert_llm(model: torch.nn.Module,
         invalidInputError(False,
                           "Now we only support Llama2 / Llama3 / Baichuan2 / "
                           "Qwen2 / Qwen2.5 / Minicpm for pipeline running.")
-    if isinstance(model.lm_head, SlicedLMHead):
+    if hasattr(model, "lm_head") and isinstance(model.lm_head, SlicedLMHead):
         model.lm_head.get_fused_lm_head()
 
     # patch generate function
diff --git a/python/llm/src/ipex_llm/transformers/npu_pipeline_model/minicpm.py b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/minicpm.py
index 6f18b579dcf..f39a73a1b39 100644
--- a/python/llm/src/ipex_llm/transformers/npu_pipeline_model/minicpm.py
+++ b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/minicpm.py
@@ -18,8 +18,10 @@
 import torch
 import numpy as np
 import os
-from .common import update_names_of_IR_and_export_blob, LowBitLLMLMHead
+from .common import update_names_of_IR_and_export_blob
 from intel_npu_acceleration_library.backend.factory import NNFactory
+from ipex_llm.transformers.npu_models.mp_models_base import LLMBaseNNFactory
+from typing import Sequence
 
 
 class MiniCPMEmbedding(NNFactory):
@@ -65,6 +67,68 @@ def __init__(
         self.compile()
 
 
+class MiniCPMLMHead(LLMBaseNNFactory):
+    def __init__(
+        self,
+        hidden_shape: Sequence[int],
+        num_heads: int,
+        rms_norm_eps: float,
+        model_norm_weight,
+        vocab_size: int,
+        mode: str = "decode",
+        dtype: np.dtype = np.int8,
+        max_seq_len: int = 1024,
+        transpose_value: bool = False,
+        profile: bool = False,
+        device: str = "NPU",
+    ):
+        super().__init__(max_seq_len=max_seq_len,
+                         transpose_value=transpose_value,
+                         dtype=dtype,
+                         profile=profile,
+                         device=device)
+        self.max_seq_len = max_seq_len
+        self.dtype = dtype
+        self.batch_size, self.seq_len, self.hidden_size = hidden_shape
+        self.mode = mode
+        self.rms_norm_eps = rms_norm_eps
+        self.transpose_value = transpose_value
+        self.vocab_size = vocab_size
+
+        self.num_heads = num_heads
+        self.head_dim = self.hidden_size // self.num_heads
+
+        # define input, the order self.parameter matters
+        input = self.create_input_op((self.batch_size, self.seq_len, self.hidden_size))
+
+        hidden_states = input
+
+        # model norm and lm head
+        model_norm_weight = self.constant(model_norm_weight)
+        hidden_states = self.layer_norm(hidden_states, model_norm_weight)
+        if vocab_size == 122753:
+            # for MiniCPM-2B-sft-bf16
+            hidden_states_1 = self.linear(
+                hidden_states, 73440, self.hidden_size, bias=False, wt_dtype=self.dtype
+            )
+            hidden_states_2 = self.linear(
+                hidden_states, 73440, self.hidden_size, bias=False, wt_dtype=self.dtype
+            )
+            hidden_states_2 = self.slice(hidden_states_2, begin=[0, 0, 0], end=[1, 1, 49313])
+            hidden_states = self.concat(hidden_states_1, hidden_states_2, axis=2)
+        else:
+            # for MiniCPM-1B-sft-bf16
+            hidden_states = self.linear(
+                hidden_states, self.vocab_size, self.hidden_size, bias=False, wt_dtype=self.dtype
+            )
+
+        # define outputs
+        hidden_states = self.convert_to_fp32(hidden_states)
+
+        print("start compiling")
+        self.compile()
+
+
 def convert_lm_head_and_embedding(model, n_splits_linear, temp_dir, weight_dir):
     num_heads = model.model.layers[0].self_attn.num_heads
     num_key_value_heads = model.model.layers[0].self_attn.num_key_value_heads
@@ -72,24 +136,23 @@ def convert_lm_head_and_embedding(model, n_splits_linear, temp_dir, weight_dir):
     rms_norm_eps = model.config.rms_norm_eps
     vocab_size = model.config.vocab_size
     model_norm = model.model.norm
-    lm_head = model.lm_head
     if n_splits_linear == 1:
-        weights = [(lm_head.weight, lm_head.scale)]
+        if vocab_size == 122753:
+            # for MiniCPM-2B-sft-bf16
+            weights = [(model.lm_head_0.weight, model.lm_head_0.scale),
+                       (model.lm_head_1.weight, model.lm_head_1.scale)]
+        else:
+            # for MiniCPM-1B-sft-bf16
+            weights = [(model.lm_head.weight, model.lm_head.scale)]
     else:
-        lm_heads = lm_head.lm_heads
-        lm_head_weights = []
-        scales = []
-        for i in range(n_splits_linear):
-            lm_head_weights.append(lm_heads[i].weight)
-            scales.append(lm_heads[i].scale)
-        weights = [(torch.stack(lm_head_weights, axis=0),
-                    torch.stack(scales, axis=0))]
+        # TODO
+        pass
     if isinstance(weights[0], tuple):
         np_dtype = np.int8 if weights[0][0].dtype == torch.int8 else np.uint8
     else:  # FP16 Linear
         np_dtype = np.float16
 
-    new_lm_head = LowBitLLMLMHead(
+    new_lm_head = MiniCPMLMHead(
         [1, 1, num_heads * head_dim],
         num_heads=num_heads,
         max_seq_len=1,
@@ -99,17 +162,21 @@ def convert_lm_head_and_embedding(model, n_splits_linear, temp_dir, weight_dir):
         dtype=np_dtype,
         model_norm_weight=model_norm.weight.to(torch.float16),
         vocab_size=vocab_size,
-        n_splits=n_splits_linear
     )
     last_blob_path = update_names_of_IR_and_export_blob(new_lm_head, "lm_head", temp_dir)
 
     # save weights bins files
     if n_splits_linear == 1:
-        weight_numpy = [
-            lm_head.weight.data.numpy(), lm_head.scale.data.numpy(),
-        ]
+        if vocab_size == 122753:
+            weight_numpy = [model.lm_head_0.weight.data.numpy(),
+                            model.lm_head_0.scale.data.numpy(),
+                            model.lm_head_1.weight.data.numpy(),
+                            model.lm_head_1.scale.data.numpy(), ]
+        else:
+            weight_numpy = [model.lm_head.weight.data.numpy(), model.lm_head.scale.data.numpy(), ]
     else:
-        weight_numpy = [v.numpy() for v in weights[0]]
+        # TODO
+        pass
 
     for idx, weight in enumerate(weight_numpy):
         bin_file = os.path.join(weight_dir, f"model_lm_head_input_{1+idx}.bin")
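
Note on the MiniCPM-2B branch above: the split works because 73440 + 49313 = 122753, i.e. the oversized vocabulary is projected by two identically shaped linears and the surplus columns of the second projection are sliced away before concatenation. The following is a minimal plain-PyTorch sketch of that slice-and-concat idea only, not the NPU code path; `hidden_size`, `w0`, and `w1` are illustrative placeholders and not values taken from this patch.

```python
# Illustrative sketch: the slice-and-concat used for the MiniCPM-2B lm_head,
# where vocab_size 122753 is produced as 73440 + 49313.
import torch

hidden_size = 2304                    # assumed hidden size, for illustration only
split, tail = 73440, 49313            # 73440 + 49313 == 122753

hidden_states = torch.randn(1, 1, hidden_size)
w0 = torch.randn(split, hidden_size)  # stands in for lm_head_0.weight
w1 = torch.randn(split, hidden_size)  # stands in for lm_head_1.weight (rows past `tail` unused)

logits_0 = hidden_states @ w0.T                    # [1, 1, 73440]
logits_1 = (hidden_states @ w1.T)[..., :tail]      # sliced back to [1, 1, 49313]
logits = torch.cat([logits_0, logits_1], dim=-1)   # [1, 1, 122753]
assert logits.shape[-1] == 73440 + 49313 == 122753
```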