From f36c23664f7961cad100746ebb7e599324f93ada Mon Sep 17 00:00:00 2001 From: binbin Deng <108676127+plusbang@users.noreply.github.com> Date: Thu, 12 Dec 2024 17:56:30 +0800 Subject: [PATCH] [NPU] Fix abnormal output with latest driver (#12530) --- .../llm/src/ipex_llm/transformers/npu_models/mp_models_base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py b/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py index 42cf72e353c..610cbc1e189 100644 --- a/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py +++ b/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py @@ -471,7 +471,7 @@ def layer_norm(self, hidden_states, layernorm_weight): ) eps = self.constant(self.rms_norm_eps) hidden_states = self.eltwise_div(hidden_states, self.sqrt(self.eltwise_add(variance, eps))) - layernorm_weight = self.convert_to_fp32(layernorm_weight) + hidden_states = self.convert_to_fp16(hidden_states) hidden_states = self.eltwise_mul(layernorm_weight, hidden_states) hidden_states = self.convert_to_fp16(hidden_states) return hidden_states