diff --git a/python/llm/src/bigdl/llm/transformers/convert.py b/python/llm/src/bigdl/llm/transformers/convert.py index 2acab799108..ae7df6b3d46 100644 --- a/python/llm/src/bigdl/llm/transformers/convert.py +++ b/python/llm/src/bigdl/llm/transformers/convert.py @@ -267,6 +267,7 @@ def _optimize_post(model): from bigdl.llm.transformers.models.chatglm2 import chatglm2_attention_forward_8eb45c from bigdl.llm.transformers.models.chatglm2 import core_attn_forward_8eb45c from bigdl.llm.transformers.models.chatglm2 import chatglm_rms_norm_forward + from bigdl.llm.transformers.models.chatglm2 import chatglm2_model_forward convert_forward(model, module.SelfAttention, chatglm2_attention_forward_8eb45c @@ -274,6 +275,9 @@ def _optimize_post(model): convert_forward(model, module.CoreAttention, core_attn_forward_8eb45c) + convert_forward(model, + module.ChatGLMModel, + chatglm2_model_forward) convert_forward(model, module.RMSNorm, chatglm_rms_norm_forward) diff --git a/python/llm/src/bigdl/llm/transformers/models/chatglm2.py b/python/llm/src/bigdl/llm/transformers/models/chatglm2.py index fa54ea3e152..0373f2aad29 100644 --- a/python/llm/src/bigdl/llm/transformers/models/chatglm2.py +++ b/python/llm/src/bigdl/llm/transformers/models/chatglm2.py @@ -18,8 +18,9 @@ # import torch -from typing import Optional, Tuple, Union, List, Callable, Dict, Any +from typing import Optional, Tuple, List import torch.nn.functional as F +from transformers.modeling_outputs import BaseModelOutputWithPast from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache @@ -54,7 +55,7 @@ def split_tensor_along_last_dim( @torch.jit.script -def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor: +def apply_rotary_pos_emb_chatglm(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor: # x: [sq, b, np, hn] sq, b, np, hn = x.size(0), x.size(1), x.size(2), x.size(3) rot_dim = rope_cache.shape[-2] * 2 @@ -87,6 +88,77 @@ def chatglm_rms_norm_forward(self, hidden_states): return hidden_states +def chatglm2_model_forward( + self, + input_ids, + position_ids: Optional[torch.Tensor]=None, + attention_mask: Optional[torch.BoolTensor]=None, + full_attention_mask: Optional[torch.BoolTensor]=None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]]=None, + inputs_embeds: Optional[torch.Tensor]=None, + use_cache: Optional[bool]=None, + output_hidden_states: Optional[bool]=None, + return_dict: Optional[bool]=None, +): + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None + else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + batch_size, seq_length = input_ids.shape + + if inputs_embeds is None: + inputs_embeds = self.embedding(input_ids) + + if full_attention_mask is None: + if (attention_mask is not None and not attention_mask.all()) or ( + past_key_values and seq_length != 1): + full_attention_mask = self.get_masks(input_ids, + past_key_values, + padding_mask=attention_mask) + + use_fuse_rope = input_ids.device.type == "xpu" + use_fuse_rope = use_fuse_rope and not self.training + + # Rotary positional embeddings + rotary_pos_emb = self.rotary_pos_emb(self.seq_length) + if position_ids is not None: + rotary_pos_emb = rotary_pos_emb[position_ids] + else: + rotary_pos_emb = rotary_pos_emb[None, :seq_length] + if use_fuse_rope: + # Repeat cos sin here, call only once for each token. + # Chatglm2's rotary embedding is similar to gptj's, is rotate_every_two. + # If put this to attension forward, it will generate too many times. + cos, sin = rotary_pos_emb.split(rotary_pos_emb.shape[-1] // 2, dim=-1) + cos = cos.squeeze(-1) + sin = sin.squeeze(-1) + cos = torch.repeat_interleave(cos[:, :, None, :], 2, 3) + sin = torch.repeat_interleave(sin[:, :, None, :], 2, 3) + rotary_pos_emb = (cos, sin) + else: + rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous() + + # Run encoder. + hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder( + inputs_embeds, full_attention_mask, rotary_pos_emb=rotary_pos_emb, + kv_caches=past_key_values, use_cache=use_cache, output_hidden_states=output_hidden_states + ) + + if not return_dict: + return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] + if v is not None) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + def chatglm2_attention_forward_8eb45c( self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True ): @@ -132,12 +204,26 @@ def chatglm2_attention_forward_8eb45c( # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn] (query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3) + cur_length, batch_size = query_layer.shape[0], query_layer.shape[1] + # apply relative positional encoding (rotary embedding) if rotary_pos_emb is not None: - query_layer = apply_rotary_pos_emb(query_layer, rotary_pos_emb) - key_layer = apply_rotary_pos_emb(key_layer, rotary_pos_emb) - - cur_length, batch_size = query_layer.shape[0], query_layer.shape[1] + if len(rotary_pos_emb) == 2: # use_fuse_rope, see chatglm2_model_forward + cos, sin = rotary_pos_emb + rot_dim = cos.shape[-1] + query_layer = query_layer.transpose(0, 1) + key_layer = key_layer.transpose(0, 1) + query_layer_cur = query_layer[..., :rot_dim] + key_layer_cur = key_layer[..., :rot_dim] + # ipex's apply_rotary_embedding can change the origin storage, so query_layer will get + # the result directly. + torch.ops.torch_ipex.apply_rotary_embedding(query_layer_cur, sin, cos, query_layer_cur) + torch.ops.torch_ipex.apply_rotary_embedding(key_layer_cur, sin, cos, key_layer_cur) + query_layer = query_layer.transpose(0, 1) + key_layer = key_layer.transpose(0, 1) + else: + query_layer = apply_rotary_pos_emb_chatglm(query_layer, rotary_pos_emb) + key_layer = apply_rotary_pos_emb_chatglm(key_layer, rotary_pos_emb) if self.multi_query_attention: key_length = key_layer.size(0) @@ -200,7 +286,6 @@ def chatglm2_attention_forward_8eb45c( # ================================== # core attention computation # ================================== - context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask) # =================