From 5ee6f97d6fd7a9c78b81a3e3d67272599b988a12 Mon Sep 17 00:00:00 2001
From: binbin Deng <108676127+plusbang@users.noreply.github.com>
Date: Mon, 4 Nov 2024 15:46:24 +0800
Subject: [PATCH] [NPU L0] Add layernorm weight as const / input setting
 (#12322)

---
 .../npu_pipeline_model/baichuan.py            | 19 ++++++++----
 .../npu_pipeline_model/convert_pipeline.py    | 23 ++++++++++-----
 .../transformers/npu_pipeline_model/llama.py  | 19 ++++++++----
 .../npu_pipeline_model/minicpm.py             | 19 ++++++++----
 .../npu_pipeline_model/pipeline_cpp.py        |  9 ++++--
 .../transformers/npu_pipeline_model/qwen.py   | 29 +++++++++++--------
 6 files changed, 80 insertions(+), 38 deletions(-)

diff --git a/python/llm/src/ipex_llm/transformers/npu_pipeline_model/baichuan.py b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/baichuan.py
index 0ceaf93100f..078490925ac 100644
--- a/python/llm/src/ipex_llm/transformers/npu_pipeline_model/baichuan.py
+++ b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/baichuan.py
@@ -70,7 +70,8 @@ def convert_lm_head_and_embedding(model, n_splits_linear, temp_dir, weight_dir):
 
 
 def convert_baichuan_layer(model, layer_idx, n_splits_linear, n_splits_down_proj,
-                           temp_dir, weight_dir, transpose_value_cache, kv_len, group_size):
+                           temp_dir, weight_dir, transpose_value_cache, kv_len, group_size,
+                           layernorm_const):
     num_heads = model.model.layers[0].self_attn.num_heads
     head_dim = model.model.layers[0].self_attn.head_dim
     intermediate_size = model.config.intermediate_size
@@ -106,8 +107,8 @@ def convert_baichuan_layer(model, layer_idx, n_splits_linear, n_splits_down_proj
 
     single_decoder = LowBitBaichuanMultiDecoderlayer(
         [1, 1, num_heads * head_dim],
-        input_layernorm_weights=[layer_norm_0],
-        post_attn_layernorm_weights=[layer_norm_1],
+        input_layernorm_weights=[layer_norm_0] if layernorm_const else None,
+        post_attn_layernorm_weights=[layer_norm_1] if layernorm_const else None,
         cached_cos=cached_cos,
         cached_sin=cached_sin,
         num_heads=num_heads,
@@ -123,9 +124,17 @@ def convert_baichuan_layer(model, layer_idx, n_splits_linear, n_splits_down_proj
                                                         f"decoder_layer_{layer_idx}",
                                                         temp_dir)
 
+    if layernorm_const:
+        st_idx = 5
+    else:
+        input_lm_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_3.bin")
+        post_lm_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_4.bin")
+        layer_norm_0.data.numpy().tofile(input_lm_bin_file)
+        layer_norm_1.data.numpy().tofile(post_lm_bin_file)
+        st_idx = 7
     for idx, (weight, scale) in enumerate(weights):
-        bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{5+idx*2}.bin")
+        bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{st_idx+idx*2}.bin")
         weight.numpy().tofile(bin_file)
-        bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{5+idx*2+1}.bin")
+        bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{st_idx+idx*2+1}.bin")
         scale.numpy().tofile(bin_file)
     del single_decoder
diff --git a/python/llm/src/ipex_llm/transformers/npu_pipeline_model/convert_pipeline.py b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/convert_pipeline.py
index 05f70dbc863..4e7394d3b40 100644
--- a/python/llm/src/ipex_llm/transformers/npu_pipeline_model/convert_pipeline.py
+++ b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/convert_pipeline.py
@@ -189,6 +189,8 @@ def convert_llm(model: torch.nn.Module,
                 max_prompt_len: int,
                 transpose_value_cache: bool,
                 group_size: int):
+    # whether to set layernorm weight as const
+    layernorm_const = os.environ.get("IPEX_LLM_LAYERNORM_CONST", "1") == "1"
     if group_size == 0:
         n_splits_linear = 1
         n_splits_down_proj = 2 if model.config.intermediate_size == 18944 else 1
@@ -207,7 +209,8 @@ def convert_llm(model: torch.nn.Module,
             param_list = []
             for layer_idx in range(0, layer_num):
                 param_list.append((model, layer_idx, n_splits_linear, n_splits_down_proj,
-                                   temp_dir, weight_dir, transpose_value_cache, kv_len, group_size))
+                                   temp_dir, weight_dir, transpose_value_cache, kv_len, group_size,
+                                   layernorm_const))
             with Pool() as pool:
                 result = pool.starmap(convert_llama_layer, param_list)
 
@@ -230,7 +233,7 @@ def convert_llm(model: torch.nn.Module,
             res = InitLLMPipeline("llama", kv_len, model.num_head, model.head_dim,
                                   layer_num, model.vocab_size, weight_dir, "model",
                                   first_blob_path, last_blob_path,
-                                  os.path.join(temp_dir, "decoder_layer"))
+                                  os.path.join(temp_dir, "decoder_layer"), layernorm_const)
         except:
             invalidInputError(False,
                               "False to InitLLMPipeline.")
@@ -246,7 +249,8 @@ def convert_llm(model: torch.nn.Module,
             param_list = []
             for layer_idx in range(0, layer_num):
                 param_list.append((model, layer_idx, n_splits_linear, n_splits_down_proj,
-                                   temp_dir, weight_dir, transpose_value_cache, kv_len, group_size))
+                                   temp_dir, weight_dir, transpose_value_cache, kv_len, group_size,
+                                   layernorm_const))
             with Pool() as pool:
                 result = pool.starmap(convert_baichuan_layer, param_list)
 
@@ -270,7 +274,7 @@ def convert_llm(model: torch.nn.Module,
             res = InitLLMPipeline("baichuan", kv_len, model.num_head, model.head_dim,
                                   layer_num, model.vocab_size, weight_dir, "model",
                                   first_blob_path, last_blob_path,
-                                  os.path.join(temp_dir, "decoder_layer"))
+                                  os.path.join(temp_dir, "decoder_layer"), layernorm_const)
         except:
             invalidInputError(False,
                               "False to InitLLMPipeline.")
@@ -286,7 +290,8 @@ def convert_llm(model: torch.nn.Module,
             param_list = []
             for layer_idx in range(0, layer_num):
                 param_list.append((model, layer_idx, n_splits_linear, n_splits_down_proj,
-                                   temp_dir, weight_dir, transpose_value_cache, kv_len, group_size))
+                                   temp_dir, weight_dir, transpose_value_cache, kv_len, group_size,
+                                   layernorm_const))
             with Pool() as pool:
                 result = pool.starmap(convert_minicpm_layer, param_list)
 
@@ -309,11 +314,12 @@ def convert_llm(model: torch.nn.Module,
             res = InitLLMPipeline("minicpm", kv_len, model.num_head, model.head_dim,
                                   layer_num, model.vocab_size, weight_dir, "model",
                                   first_blob_path, last_blob_path,
-                                  os.path.join(temp_dir, "decoder_layer"))
+                                  os.path.join(temp_dir, "decoder_layer"), layernorm_const)
         except:
             invalidInputError(False,
                               "False to InitLLMPipeline.")
     elif model.config.model_type == "qwen2":
+        layernorm_const = os.environ.get("IPEX_LLM_LAYERNORM_CONST", "0") == "1"
         with tempfile.TemporaryDirectory() as temp_dir:
             weight_dir = os.path.join(temp_dir, "model_weights")
             os.mkdir(weight_dir)
@@ -325,7 +331,8 @@ def convert_llm(model: torch.nn.Module,
             param_list = []
             for layer_idx in range(0, layer_num):
                 param_list.append((model, layer_idx, n_splits_linear, n_splits_down_proj,
-                                   temp_dir, weight_dir, transpose_value_cache, kv_len, group_size))
+                                   temp_dir, weight_dir, transpose_value_cache, kv_len, group_size,
+                                   layernorm_const))
             with Pool() as pool:
                 result = pool.starmap(convert_qwen_layer, param_list)
 
@@ -349,7 +356,7 @@ def convert_llm(model: torch.nn.Module,
             res = InitLLMPipeline("qwen", kv_len, model.num_head, model.head_dim,
                                   layer_num, model.vocab_size, weight_dir, "model",
                                   first_blob_path, last_blob_path,
-                                  os.path.join(temp_dir, "decoder_layer"))
+                                  os.path.join(temp_dir, "decoder_layer"), layernorm_const)
         except:
             invalidInputError(False,
                               "False to InitLLMPipeline.")
diff --git a/python/llm/src/ipex_llm/transformers/npu_pipeline_model/llama.py b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/llama.py
index 1203214c0de..09faad35854 100644
--- a/python/llm/src/ipex_llm/transformers/npu_pipeline_model/llama.py
+++ b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/llama.py
@@ -85,7 +85,8 @@ def convert_lm_head_and_embedding(model, n_splits_linear, temp_dir, weight_dir):
 
 
 def convert_llama_layer(model, layer_idx, n_splits_linear, n_splits_down_proj,
-                        temp_dir, weight_dir, transpose_value_cache, kv_len, group_size):
+                        temp_dir, weight_dir, transpose_value_cache, kv_len, group_size,
+                        layernorm_const):
     num_heads = model.model.layers[0].self_attn.num_heads
     num_key_value_heads = model.model.layers[0].self_attn.num_key_value_heads
     head_dim = model.model.layers[0].self_attn.head_dim
@@ -146,8 +147,8 @@ def convert_llama_layer(model, layer_idx, n_splits_linear, n_splits_down_proj,
 
     single_decoder = LowBitLlamaMultiDecoderlayer(
         [1, 1, num_heads * head_dim],
-        input_layernorm_weights=[layer_norm_0],
-        post_attn_layernorm_weights=[layer_norm_1],
+        input_layernorm_weights=[layer_norm_0] if layernorm_const else None,
+        post_attn_layernorm_weights=[layer_norm_1] if layernorm_const else None,
         cached_cos=cached_cos,
         cached_sin=cached_sin,
         num_heads=num_heads,
@@ -167,9 +168,17 @@ def convert_llama_layer(model, layer_idx, n_splits_linear, n_splits_down_proj,
                                                         f"decoder_layer_{layer_idx}",
                                                         temp_dir)
 
+    if layernorm_const:
+        st_idx = 5
+    else:
+        input_lm_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_3.bin")
+        post_lm_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_4.bin")
+        layer_norm_0.data.numpy().tofile(input_lm_bin_file)
+        layer_norm_1.data.numpy().tofile(post_lm_bin_file)
+        st_idx = 7
     for idx, (weight, scale) in enumerate(weights):
-        bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{5+idx*2}.bin")
+        bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{st_idx+idx*2}.bin")
         weight.numpy().tofile(bin_file)
-        bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{5+idx*2+1}.bin")
+        bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{st_idx+idx*2+1}.bin")
         scale.numpy().tofile(bin_file)
     del single_decoder
diff --git a/python/llm/src/ipex_llm/transformers/npu_pipeline_model/minicpm.py b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/minicpm.py
index f39a73a1b39..07017efc588 100644
--- a/python/llm/src/ipex_llm/transformers/npu_pipeline_model/minicpm.py
+++ b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/minicpm.py
@@ -197,7 +197,8 @@ def convert_lm_head_and_embedding(model, n_splits_linear, temp_dir, weight_dir):
 
 
 def convert_minicpm_layer(model, layer_idx, n_splits_linear, n_splits_down_proj,
-                          temp_dir, weight_dir, transpose_value_cache, kv_len, group_size):
+                          temp_dir, weight_dir, transpose_value_cache, kv_len, group_size,
+                          layernorm_const):
     num_heads = model.model.layers[0].self_attn.num_heads
     num_key_value_heads = model.model.layers[0].self_attn.num_key_value_heads
     head_dim = model.model.layers[0].self_attn.head_dim
@@ -238,8 +239,8 @@ def convert_minicpm_layer(model, layer_idx, n_splits_linear, n_splits_down_proj,
 
     single_decoder = LowBitMinicpmMultiDecoderlayer(
         [1, 1, num_heads * head_dim],
-        input_layernorm_weights=[layer_norm_0],
-        post_attn_layernorm_weights=[layer_norm_1],
+        input_layernorm_weights=[layer_norm_0] if layernorm_const else None,
+        post_attn_layernorm_weights=[layer_norm_1] if layernorm_const else None,
         cached_cos=cached_cos,
         cached_sin=cached_sin,
         num_heads=num_heads,
@@ -258,9 +259,17 @@ def convert_minicpm_layer(model, layer_idx, n_splits_linear, n_splits_down_proj,
                                                         f"decoder_layer_{layer_idx}",
                                                         temp_dir)
 
+    if layernorm_const:
+        st_idx = 5
+    else:
+        input_lm_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_3.bin")
+        post_lm_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_4.bin")
+        layer_norm_0.data.numpy().tofile(input_lm_bin_file)
+        layer_norm_1.data.numpy().tofile(post_lm_bin_file)
+        st_idx = 7
     for idx, (weight, scale) in enumerate(weights):
-        bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{5+idx*2}.bin")
+        bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{st_idx+idx*2}.bin")
         weight.numpy().tofile(bin_file)
-        bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{5+idx*2+1}.bin")
+        bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{st_idx+idx*2+1}.bin")
         scale.numpy().tofile(bin_file)
     del single_decoder
diff --git a/python/llm/src/ipex_llm/transformers/npu_pipeline_model/pipeline_cpp.py b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/pipeline_cpp.py
index 41d4f95d854..1ab40ffdbc6 100644
--- a/python/llm/src/ipex_llm/transformers/npu_pipeline_model/pipeline_cpp.py
+++ b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/pipeline_cpp.py
@@ -43,7 +43,8 @@ def get_shared_lib_info(lib_base_name: str):
 # Load the library
 _lib = ctypes.cdll.LoadLibrary(_lib_path)
 
-_lib.InitLLMPipeline.argtypes = [ctypes.c_char_p] + [ctypes.c_int] * 5 + [ctypes.c_char_p] * 5
+_lib.InitLLMPipeline.argtypes = [ctypes.c_char_p] + [ctypes.c_int] * 5 + \
+    [ctypes.c_char_p] * 5 + [ctypes.c_bool]
 _lib.InitLLMPipeline.restype = ctypes.c_int
 
 _lib.generate_serve.argtypes = [ctypes.c_int] * 5 + [ctypes.c_bool] + [ctypes.c_int]
@@ -52,11 +53,13 @@ def get_shared_lib_info(lib_base_name: str):
 
 def InitLLMPipeline(model_type: str, kv_len: int, num_head: int, head_dim: int,
                     num_layers: int, vocab_size: int, model_weight_dir: str, model_name: str,
-                    first_blob_name: str, last_blob_name: str, rest_blob_name: str):
+                    first_blob_name: str, last_blob_name: str, rest_blob_name: str,
+                    layernorm_const: bool):
     return _lib.InitLLMPipeline(model_type.encode('utf-8'), kv_len, num_head, head_dim,
                                 num_layers, vocab_size, model_weight_dir.encode('utf-8'),
                                 model_name.encode('utf-8'), first_blob_name.encode('utf-8'),
-                                last_blob_name.encode('utf-8'), rest_blob_name.encode('utf-8'))
+                                last_blob_name.encode('utf-8'), rest_blob_name.encode('utf-8'),
+                                layernorm_const)
 
 
 def generate_serve(kv_len: int, num_head: int, head_dim: int, num_layers: int,
diff --git a/python/llm/src/ipex_llm/transformers/npu_pipeline_model/qwen.py b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/qwen.py
index be0244e9020..80f15aa49fc 100644
--- a/python/llm/src/ipex_llm/transformers/npu_pipeline_model/qwen.py
+++ b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/qwen.py
@@ -86,7 +86,8 @@ def convert_lm_head_and_embedding(model, n_splits_linear, temp_dir, weight_dir):
 
 
 def convert_qwen_layer(model, layer_idx, n_splits_linear, n_splits_down_proj,
-                       temp_dir, weight_dir, transpose_value_cache, kv_len, group_size):
+                       temp_dir, weight_dir, transpose_value_cache, kv_len, group_size,
+                       layernorm_const):
     num_heads = model.model.layers[0].self_attn.num_heads
     num_key_value_heads = model.model.layers[0].self_attn.num_key_value_heads
     head_dim = model.model.layers[0].self_attn.head_dim
@@ -149,8 +150,8 @@ def convert_qwen_layer(model, layer_idx, n_splits_linear, n_splits_down_proj,
 
     single_decoder = LowBitQwenMultiDecoderlayer(
         [1, 1, num_heads * head_dim],
-        input_layernorm_weights=None,
-        post_attn_layernorm_weights=None,
+        input_layernorm_weights=[layer_norm_0] if layernorm_const else None,
+        post_attn_layernorm_weights=[layer_norm_1] if layernorm_const else None,
         q_biases=None,
         k_biases=None,
         v_biases=None,
@@ -174,21 +175,25 @@ def convert_qwen_layer(model, layer_idx, n_splits_linear, n_splits_down_proj,
                                                         temp_dir)
 
     # 0, 1, 2 are input_embed/attention_mask/position_id
-    input_lm_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_3.bin")
-    post_lm_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_4.bin")
-    layer_norm_0.data.numpy().tofile(input_lm_bin_file)
-    layer_norm_1.data.numpy().tofile(post_lm_bin_file)
-    q_bias_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_5.bin")
-    k_bias_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_6.bin")
-    v_bias_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_7.bin")
+    if layernorm_const:
+        st_idx = 3
+    else:
+        input_lm_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_3.bin")
+        post_lm_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_4.bin")
+        layer_norm_0.data.numpy().tofile(input_lm_bin_file)
+        layer_norm_1.data.numpy().tofile(post_lm_bin_file)
+        st_idx = 5
+    q_bias_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{st_idx}.bin")
+    k_bias_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{st_idx+1}.bin")
+    v_bias_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{st_idx+2}.bin")
     q_bias.data.numpy().tofile(q_bias_bin_file)
     k_bias.data.numpy().tofile(k_bias_bin_file)
     v_bias.data.numpy().tofile(v_bias_bin_file)
     # 6, 7 are past k/v
     for idx, (weight, scale) in enumerate(weights):
-        bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{10+idx*2}.bin")
+        bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{st_idx+5+idx*2}.bin")
         weight.numpy().tofile(bin_file)
-        bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{10+idx*2+1}.bin")
+        bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{st_idx+5+idx*2+1}.bin")
         scale.numpy().tofile(bin_file)
     del single_decoder
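
Reviewer note (not part of the patch): the standalone sketch below illustrates the bin-file layout that the new layernorm_const toggle selects in the llama/baichuan/minicpm converters above. Per the diff, inputs 0/1/2 of each decoder blob are input_embed/attention_mask/position_id; with IPEX_LLM_LAYERNORM_CONST=1 (the default read in convert_llm for those models) the two layernorm weights are folded into the blob as constants and the quantized (weight, scale) pairs are dumped starting at input index 5, while with IPEX_LLM_LAYERNORM_CONST=0 they are written out as runtime inputs 3 and 4 and the pairs shift to index 7. Qwen2 differs: its default is "0", its q/k/v bias files occupy st_idx..st_idx+2, and its weight pairs start at st_idx+5. The helper name, array dtypes, and shapes below are illustrative only, not code from the repository.

import os
import tempfile

import numpy as np


def dump_decoder_layer_inputs(weight_dir, layer_idx, layer_norm_0, layer_norm_1,
                              weights, layernorm_const):
    """Illustrative sketch of the per-layer weight dump, not the actual converter."""
    if layernorm_const:
        # Layernorm weights are baked into the decoder blob as constants,
        # so the quantized weights are dumped starting at input index 5.
        st_idx = 5
    else:
        # Layernorm weights become runtime inputs 3 and 4; the weights shift to index 7.
        layer_norm_0.tofile(os.path.join(weight_dir, f"model_{layer_idx}_input_3.bin"))
        layer_norm_1.tofile(os.path.join(weight_dir, f"model_{layer_idx}_input_4.bin"))
        st_idx = 7
    for idx, (weight, scale) in enumerate(weights):
        weight.tofile(os.path.join(weight_dir, f"model_{layer_idx}_input_{st_idx + idx * 2}.bin"))
        scale.tofile(os.path.join(weight_dir, f"model_{layer_idx}_input_{st_idx + idx * 2 + 1}.bin"))


if __name__ == "__main__":
    # Same toggle convert_llm reads in the patch ("1" by default for llama/baichuan/minicpm).
    layernorm_const = os.environ.get("IPEX_LLM_LAYERNORM_CONST", "1") == "1"
    ln0 = np.ones(4096, dtype=np.float16)   # stand-in for the input layernorm weight
    ln1 = np.ones(4096, dtype=np.float16)   # stand-in for the post-attention layernorm weight
    toy_weights = [(np.zeros((16, 16), dtype=np.uint8), np.ones(16, dtype=np.float16))]
    with tempfile.TemporaryDirectory() as weight_dir:
        dump_decoder_layer_inputs(weight_dir, 0, ln0, ln1, toy_weights, layernorm_const)
        print(sorted(os.listdir(weight_dir)))

Running the sketch with IPEX_LLM_LAYERNORM_CONST=0 lists input_3/input_4 plus weight files at indices 7/8, while the default lists only 5/6; this mirrors why st_idx changes in each convert_*_layer and why the layernorm_const flag is now threaded through InitLLMPipeline.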