diff --git a/python/llm/src/ipex_llm/transformers/npu_models/common.py b/python/llm/src/ipex_llm/transformers/npu_models/common.py index bb08b1abea5..32841838d6d 100644 --- a/python/llm/src/ipex_llm/transformers/npu_models/common.py +++ b/python/llm/src/ipex_llm/transformers/npu_models/common.py @@ -30,3 +30,13 @@ def merge_linear(linears: List[torch.nn.Linear]) -> torch.nn.Linear: new_linear.in_features = new_weight.size(1) new_linear.out_features = new_weight.size(0) return new_linear + + +def reshape_lm_head_input(x): + if x.dim() > 3: + x = x.reshape([-1, x.shape[-2], x.shape[-1]]) + shape = list(x.size()) + if shape[1] > 10: + shape[1] = 1 + x = x[:, -1, :].view(shape) + return x diff --git a/python/llm/src/ipex_llm/transformers/npu_models/convert_mp.py b/python/llm/src/ipex_llm/transformers/npu_models/convert_mp.py index 150788be4ec..7056f1f9923 100644 --- a/python/llm/src/ipex_llm/transformers/npu_models/convert_mp.py +++ b/python/llm/src/ipex_llm/transformers/npu_models/convert_mp.py @@ -54,6 +54,9 @@ def optimize_llm( prefill_runner=prefill_runner, decode_runner=decode_runner ) convert_forward(model, LlamaModel, llama_model_forward) + from transformers.models.llama.modeling_llama import LlamaForCausalLM + from ipex_llm.transformers.npu_models.llama_mp import llama2_casullm_forward + convert_forward(model, LlamaForCausalLM, llama2_casullm_forward) elif model.config.model_type == "qwen2" and model.config.intermediate_size == 8960: # for qwen2-1.5B from ipex_llm.transformers.npu_models.qwen2_mp import gen_qwen2_fused_model_forward @@ -77,3 +80,6 @@ def optimize_llm( prefill_runner=prefill_runner, decode_runner=decode_runner ) convert_forward(model, Qwen2Model, qwen2_model_forward) + from transformers.models.qwen2.modeling_qwen2 import Qwen2ForCausalLM + from ipex_llm.transformers.npu_models.qwen2_mp import qwen2_casullm_forward + convert_forward(model, Qwen2ForCausalLM, qwen2_casullm_forward) diff --git a/python/llm/src/ipex_llm/transformers/npu_models/llama_mp.py b/python/llm/src/ipex_llm/transformers/npu_models/llama_mp.py index 0e6d113cae3..46c4236f2f1 100644 --- a/python/llm/src/ipex_llm/transformers/npu_models/llama_mp.py +++ b/python/llm/src/ipex_llm/transformers/npu_models/llama_mp.py @@ -39,6 +39,9 @@ from transformers.modeling_outputs import BaseModelOutputWithPast from ipex_llm.transformers.npu_models.mp_models_base import run_model from ipex_llm.transformers.npu_models.mp_models_base import LLMBaseNNFactory +from ipex_llm.transformers.npu_models.common import reshape_lm_head_input +from transformers.modeling_outputs import CausalLMOutputWithPast +from torch.nn import CrossEntropyLoss class LowBitLlamaMultiDecoderlayer(LLMBaseNNFactory): @@ -944,3 +947,79 @@ def llama_fused_model_forward( ) return llama_fused_model_forward + + +def llama2_casullm_forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, +) -> Union[Tuple, CausalLMOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None \ + else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + hidden_states = outputs[0] + # ipex-llm change start + hidden_states = reshape_lm_head_input(hidden_states) + # ipex-llm change end + if self.config.pretraining_tp > 1: + lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, + dim=0) + logits = [F.linear(hidden_states, lm_head_slices[i]) + for i in range(self.config.pretraining_tp)] + logits = torch.cat(logits, dim=-1) + else: + logits = self.lm_head(hidden_states) + logits = logits.float() + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/python/llm/src/ipex_llm/transformers/npu_models/qwen2_mp.py b/python/llm/src/ipex_llm/transformers/npu_models/qwen2_mp.py index 7a61ad9d24b..ec5e701fd4b 100644 --- a/python/llm/src/ipex_llm/transformers/npu_models/qwen2_mp.py +++ b/python/llm/src/ipex_llm/transformers/npu_models/qwen2_mp.py @@ -39,6 +39,9 @@ from transformers.modeling_outputs import BaseModelOutputWithPast from ipex_llm.transformers.npu_models.mp_models_base import run_model from ipex_llm.transformers.npu_models.mp_models_base import LLMBaseNNFactory +from ipex_llm.transformers.npu_models.common import reshape_lm_head_input +from transformers.modeling_outputs import CausalLMOutputWithPast +from torch.nn import CrossEntropyLoss class LowBitQwenMultiDecoderlayer(LLMBaseNNFactory): @@ -981,3 +984,72 @@ def qwen2_fused_model_forward( ) return qwen2_fused_model_forward + + +def qwen2_casullm_forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, +) -> Union[Tuple, CausalLMOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None \ + else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + # cache_position=cache_position, + ) + + hidden_states = outputs[0] + # ipex-llm change start + hidden_states = reshape_lm_head_input(hidden_states) + # ipex-llm change end + logits = self.lm_head(hidden_states) + logits = logits.float() + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + )