diff --git a/python/llm/src/ipex_llm/transformers/convert.py b/python/llm/src/ipex_llm/transformers/convert.py
index c419eebf65b..09c69232a5f 100644
--- a/python/llm/src/ipex_llm/transformers/convert.py
+++ b/python/llm/src/ipex_llm/transformers/convert.py
@@ -1261,6 +1261,7 @@ def _optimize_post(model, lightweight_bmm=False):
         convert_forward(model, module.InternLM2Attention, internlm_xcomposser2_attention_forward)
         from ipex_llm.transformers.models.internlm import internlm_xcomposser2_mlp_forward
         convert_forward(model, module.InternLM2MLP, internlm_xcomposser2_mlp_forward)
+        convert_forward(model, module.InternLM2RMSNorm, llama_rms_norm_forward)
         from ipex_llm.transformers.models.internlm import internlm_xcomposser2_chat
         model.chat = MethodType(internlm_xcomposser2_chat, model)
     elif model.config.model_type == "qwen":
diff --git a/python/llm/src/ipex_llm/transformers/models/internlm.py b/python/llm/src/ipex_llm/transformers/models/internlm.py
index 227293e497d..df0ffd5d810 100644
--- a/python/llm/src/ipex_llm/transformers/models/internlm.py
+++ b/python/llm/src/ipex_llm/transformers/models/internlm.py
@@ -359,12 +359,14 @@ def internlm_xcomposser2_attention_forward(
         kv_seq_len += past_key_value[0].shape[-2]
 
     # IPEX-LLM OPT: fuse rope
-    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
     if should_use_fuse_rope(hidden_states, position_ids, self.training):
-        query_states, key_states = apply_rotary_pos_emb_cache_freq_xpu(
-            query_states, key_states, sin, cos, "internlm", position_ids
-        )
+        # This fuse rope will get wrong result if context_length > max_position_embeddings (32768)
+        # we assume context_length <= 32768
+        import xe_addons
+        xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids,
+                                       query_states, key_states)
     else:
+        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
         query_states, key_states = apply_rotary_pos_emb(
             query_states, key_states, cos, sin, position_ids, "internlm")
 
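
Note for reviewers: the snippet below is a minimal, illustrative sketch of the rotate-half RoPE math that the fused `xe_addons.rotary_half_inplaced` call is assumed to perform in place on XPU. It is not part of this patch; the helper names (`rotate_half`, `apply_rotary_half`) and the tensor layout are assumptions for illustration, not the ipex_llm API.

import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Split the head dimension in half, swap the halves, and negate the second half.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_half(q, k, inv_freq, position_ids):
    # q, k: [batch, num_heads, seq_len, head_dim]; position_ids: [batch, seq_len];
    # inv_freq: [head_dim // 2]. This sketch uses the unscaled inv_freq directly,
    # matching the patch's stated assumption that context_length <= 32768.
    freqs = torch.einsum("bs,d->bsd", position_ids.float(), inv_freq)
    emb = torch.cat((freqs, freqs), dim=-1)
    cos = emb.cos().unsqueeze(1)  # broadcast over the head dimension
    sin = emb.sin().unsqueeze(1)
    q_rot = q * cos + rotate_half(q) * sin
    k_rot = k * cos + rotate_half(k) * sin
    return q_rot, k_rot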