From b4c1ece071f7df6cb9e9289608e57e97000381b1 Mon Sep 17 00:00:00 2001 From: qiuxin2012 Date: Fri, 27 Oct 2023 10:05:09 +0800 Subject: [PATCH 01/17] rebase kai's code --- .../llm/src/bigdl/llm/transformers/convert.py | 5 +- .../bigdl/llm/transformers/models/chatglm2.py | 104 ++++++++++++++++-- 2 files changed, 100 insertions(+), 9 deletions(-) diff --git a/python/llm/src/bigdl/llm/transformers/convert.py b/python/llm/src/bigdl/llm/transformers/convert.py index 2acab799108..8775948da57 100644 --- a/python/llm/src/bigdl/llm/transformers/convert.py +++ b/python/llm/src/bigdl/llm/transformers/convert.py @@ -266,7 +266,7 @@ def _optimize_post(model): module = importlib.import_module(modeling_module_name) from bigdl.llm.transformers.models.chatglm2 import chatglm2_attention_forward_8eb45c from bigdl.llm.transformers.models.chatglm2 import core_attn_forward_8eb45c - from bigdl.llm.transformers.models.chatglm2 import chatglm_rms_norm_forward + from bigdl.llm.transformers.models.chatglm2 import chatglm_rms_norm_forward, chatglm2_model_forward convert_forward(model, module.SelfAttention, chatglm2_attention_forward_8eb45c @@ -274,6 +274,9 @@ def _optimize_post(model): convert_forward(model, module.CoreAttention, core_attn_forward_8eb45c) + convert_forward(model, + module.ChatGLMModel, + chatglm2_model_forward) convert_forward(model, module.RMSNorm, chatglm_rms_norm_forward) diff --git a/python/llm/src/bigdl/llm/transformers/models/chatglm2.py b/python/llm/src/bigdl/llm/transformers/models/chatglm2.py index fa54ea3e152..1393ed6803b 100644 --- a/python/llm/src/bigdl/llm/transformers/models/chatglm2.py +++ b/python/llm/src/bigdl/llm/transformers/models/chatglm2.py @@ -18,9 +18,10 @@ # import torch -from typing import Optional, Tuple, Union, List, Callable, Dict, Any +from typing import Optional, Tuple, List import torch.nn.functional as F -from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache +from transformers.modeling_outputs import BaseModelOutputWithPast +from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache, apply_rotary_pos_emb, apply_rotary_pos_emb_no_cache_xpu KV_CACHE_ALLOC_BLOCK_LENGTH = 256 @@ -54,7 +55,7 @@ def split_tensor_along_last_dim( @torch.jit.script -def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor: +def apply_rotary_pos_emb_chatglm(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor: # x: [sq, b, np, hn] sq, b, np, hn = x.size(0), x.size(1), x.size(2), x.size(3) rot_dim = rope_cache.shape[-2] * 2 @@ -86,6 +87,64 @@ def chatglm_rms_norm_forward(self, hidden_states): return self.weight * hidden_states.to(input_dtype) return hidden_states +def chatglm2_model_forward( + self, + input_ids, + position_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.BoolTensor] = None, + full_attention_mask: Optional[torch.BoolTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + inputs_embeds: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, +): + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + batch_size, seq_length = input_ids.shape + + if inputs_embeds is None: + 
inputs_embeds = self.embedding(input_ids) + + if full_attention_mask is None: + if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1): + full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask) + + # use_fuse_rope = input_ids.device.type == "xpu" + # use_fuse_rope = use_fuse_rope and not self.training + use_fuse_rope = True + + if use_fuse_rope: + rotary_pos_emb = position_ids + else: + # Rotary positional embeddings + rotary_pos_emb = self.rotary_pos_emb(self.seq_length) + if position_ids is not None: + rotary_pos_emb = rotary_pos_emb[position_ids] + else: + rotary_pos_emb = rotary_pos_emb[None, :seq_length] + rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous() + + # Run encoder. + hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder( + inputs_embeds, full_attention_mask, rotary_pos_emb=rotary_pos_emb, + kv_caches=past_key_values, use_cache=use_cache, output_hidden_states=output_hidden_states + ) + + if not return_dict: + return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + def chatglm2_attention_forward_8eb45c( self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True @@ -132,19 +191,47 @@ def chatglm2_attention_forward_8eb45c( # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn] (query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3) + cur_length, batch_size = query_layer.shape[0], query_layer.shape[1] + # apply relative positional encoding (rotary embedding) if rotary_pos_emb is not None: - query_layer = apply_rotary_pos_emb(query_layer, rotary_pos_emb) - key_layer = apply_rotary_pos_emb(key_layer, rotary_pos_emb) - - cur_length, batch_size = query_layer.shape[0], query_layer.shape[1] + if len(rotary_pos_emb.shape) == 2: # use_fuse_rope, actually it is position_ids + rot_dim = rotary_pos_emb.shape[-2] * 2 + query_layer = query_layer.permute(1, 2, 0, 3) + query_layer, query_layer_pass = query_layer[..., :rot_dim], query_layer[..., rot_dim:] + key_layer = key_layer.permute(1, 2, 0, 3) + key_layer, key_layer_pass = key_layer[..., :rot_dim], key_layer[..., rot_dim:] + dummy_q = torch.empty(query_layer.shape, dtype=query_layer.dtype, device=query_layer.device) + dummy_k = torch.empty(key_layer.shape, dtype=key_layer.dtype, device=key_layer.device) + _, query_layer = apply_rotary_pos_emb_no_cache_xpu(dummy_q, + query_layer, + rotary_pos_emb, + "llama") + _, key_layer = apply_rotary_pos_emb_no_cache_xpu(dummy_k, + key_layer, + rotary_pos_emb, + "llama") + if query_layer_pass.shape[-1] > 0: + query_layer = torch.cat((query_layer, query_layer_pass), dim=-1) + query_layer = query_layer.permute(2, 0, 1, 3) + if key_layer_pass.shape[-1] > 0: + key_layer = torch.cat((key_layer, key_layer_pass), dim=-1) + key_layer = key_layer.permute(2, 0, 1, 3) + else: + query_layer = apply_rotary_pos_emb_chatglm(query_layer, rotary_pos_emb) + key_layer = apply_rotary_pos_emb_chatglm(key_layer, rotary_pos_emb) if self.multi_query_attention: key_length = key_layer.size(0) query_group_size = self.num_attention_heads_per_partition // \ self.num_multi_query_groups_per_partition + # if rotary_pos_emb is not None and use_fuse_rope: + # key_layer = key_layer.unsqueeze(-3) + # else: + # key_layer = key_layer.permute(1, 2, 0, 
3).unsqueeze(-3) # [bs, nh/k, sl, hn] key_layer = key_layer.permute(1, 2, 0, 3).unsqueeze(-3) # [bs, nh/k, sl, hn] key_layer = key_layer.expand(-1, -1, query_group_size, -1, -1) + # print("----key", key_layer.shape) key_layer = key_layer.contiguous().view((batch_size, self.num_attention_heads_per_partition, key_length, @@ -200,7 +287,8 @@ def chatglm2_attention_forward_8eb45c( # ================================== # core attention computation # ================================== - + # print("----query", query_layer.shape) + # print("----key", key_layer.shape) context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask) # ================= From 03ee5bab2336aab2488ee493804448db181a8206 Mon Sep 17 00:00:00 2001 From: qiuxin2012 Date: Fri, 3 Nov 2023 08:26:38 +0800 Subject: [PATCH 02/17] update --- .../bigdl/llm/transformers/models/chatglm2.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/python/llm/src/bigdl/llm/transformers/models/chatglm2.py b/python/llm/src/bigdl/llm/transformers/models/chatglm2.py index 1393ed6803b..2ca8167cc51 100644 --- a/python/llm/src/bigdl/llm/transformers/models/chatglm2.py +++ b/python/llm/src/bigdl/llm/transformers/models/chatglm2.py @@ -54,7 +54,7 @@ def split_tensor_along_last_dim( return tensor_list -@torch.jit.script +# @torch.jit.script def apply_rotary_pos_emb_chatglm(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor: # x: [sq, b, np, hn] sq, b, np, hn = x.size(0), x.size(1), x.size(2), x.size(3) @@ -116,7 +116,7 @@ def chatglm2_model_forward( # use_fuse_rope = input_ids.device.type == "xpu" # use_fuse_rope = use_fuse_rope and not self.training - use_fuse_rope = True + use_fuse_rope = False if use_fuse_rope: rotary_pos_emb = position_ids @@ -218,6 +218,20 @@ def chatglm2_attention_forward_8eb45c( key_layer = torch.cat((key_layer, key_layer_pass), dim=-1) key_layer = key_layer.permute(2, 0, 1, 3) else: + rot_dim = rotary_pos_emb.shape[-2] * 2 + #query_states = query_layer.permute(1, 2, 0, 3).contiguous() + query_states = query_layer.transpose(0, 1)#.contiguous() + query_states, query_states_pass = query_states[..., :rot_dim], query_states[..., rot_dim:] + key_states = key_layer.permute(1, 2, 0, 3).contiguous() + key_states = key_layer.transpose(0, 1)#.contiguous() + key_states, key_states_pass = key_states[..., :rot_dim], key_states[..., rot_dim:] + cos, sin = rotary_pos_emb.split([1, 1], -1) + cos = cos.squeeze().unsqueeze(0)#.unsqueeze(1) + sin = sin.squeeze().unsqueeze(0)#.unsqueeze(1) + position_ids = torch.range(0, cur_length-1, dtype=torch.long).reshape((1, cur_length)) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, + cos, sin, position_ids, "gptj") + query_layer = apply_rotary_pos_emb_chatglm(query_layer, rotary_pos_emb) key_layer = apply_rotary_pos_emb_chatglm(key_layer, rotary_pos_emb) From c3f02c6a2b3964c2edd4ee1038b2dff6c1b1db9d Mon Sep 17 00:00:00 2001 From: qiuxin2012 Date: Fri, 3 Nov 2023 08:29:31 +0800 Subject: [PATCH 03/17] update --- python/llm/src/bigdl/llm/transformers/models/chatglm2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/llm/src/bigdl/llm/transformers/models/chatglm2.py b/python/llm/src/bigdl/llm/transformers/models/chatglm2.py index 2ca8167cc51..f703ed96d73 100644 --- a/python/llm/src/bigdl/llm/transformers/models/chatglm2.py +++ b/python/llm/src/bigdl/llm/transformers/models/chatglm2.py @@ -222,7 +222,7 @@ def chatglm2_attention_forward_8eb45c( #query_states = query_layer.permute(1, 2, 0, 
3).contiguous() query_states = query_layer.transpose(0, 1)#.contiguous() query_states, query_states_pass = query_states[..., :rot_dim], query_states[..., rot_dim:] - key_states = key_layer.permute(1, 2, 0, 3).contiguous() + # key_states = key_layer.permute(1, 2, 0, 3).contiguous() key_states = key_layer.transpose(0, 1)#.contiguous() key_states, key_states_pass = key_states[..., :rot_dim], key_states[..., rot_dim:] cos, sin = rotary_pos_emb.split([1, 1], -1) From 1a86e760baa22989d7dbb4dfe7e10a4b5f5eb29b Mon Sep 17 00:00:00 2001 From: qiuxin2012 Date: Fri, 3 Nov 2023 08:35:41 +0800 Subject: [PATCH 04/17] update --- python/llm/src/bigdl/llm/transformers/models/chatglm2.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/python/llm/src/bigdl/llm/transformers/models/chatglm2.py b/python/llm/src/bigdl/llm/transformers/models/chatglm2.py index f703ed96d73..4c48a2357d3 100644 --- a/python/llm/src/bigdl/llm/transformers/models/chatglm2.py +++ b/python/llm/src/bigdl/llm/transformers/models/chatglm2.py @@ -235,6 +235,9 @@ def chatglm2_attention_forward_8eb45c( query_layer = apply_rotary_pos_emb_chatglm(query_layer, rotary_pos_emb) key_layer = apply_rotary_pos_emb_chatglm(key_layer, rotary_pos_emb) + assert(torch.equal(query_states, query_layer.transpose(0, 1)[..., :64])) + assert(torch.equal(key_states, key_layer.transpose(0, 1)[..., :64])) + if self.multi_query_attention: key_length = key_layer.size(0) query_group_size = self.num_attention_heads_per_partition // \ From bfaf886cf6756c644f0592b90c82f4a9622cdd59 Mon Sep 17 00:00:00 2001 From: qiuxin2012 Date: Fri, 3 Nov 2023 15:57:30 +0800 Subject: [PATCH 05/17] update --- .../bigdl/llm/transformers/models/chatglm2.py | 49 ++++++++++--------- .../bigdl/llm/transformers/models/utils.py | 4 ++ 2 files changed, 31 insertions(+), 22 deletions(-) diff --git a/python/llm/src/bigdl/llm/transformers/models/chatglm2.py b/python/llm/src/bigdl/llm/transformers/models/chatglm2.py index 4c48a2357d3..7ecbf1b557e 100644 --- a/python/llm/src/bigdl/llm/transformers/models/chatglm2.py +++ b/python/llm/src/bigdl/llm/transformers/models/chatglm2.py @@ -127,7 +127,15 @@ def chatglm2_model_forward( rotary_pos_emb = rotary_pos_emb[position_ids] else: rotary_pos_emb = rotary_pos_emb[None, :seq_length] - rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous() + # rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous() + cos, sin = rotary_pos_emb.split(rotary_pos_emb.shape[-1] // 2, dim=-1) + cos = cos.squeeze(-1) + sin = sin.squeeze(-1) + cos = torch.repeat_interleave(cos[:, :, None, :], 2, 3) + sin = torch.repeat_interleave(sin[:, :, None, :], 2, 3) + rotary_pos_emb = (cos, sin, rotary_pos_emb) + self.cossin = rotary_pos_emb + self.position_ids = position_ids # Run encoder. 
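
This patch routes ChatGLM2's rotary embedding through the generic gptj-style "rotate every two" path and, further down, adds torch.equal asserts checking that it reproduces apply_rotary_pos_emb_chatglm on the rotated channels. A minimal, self-contained sketch of that equivalence, with shapes reduced to [seq, heads, rot_dim] and all helper names illustrative rather than taken from the BigDL code:

import torch

def rotate_every_two(x: torch.Tensor) -> torch.Tensor:
    # (x0, x1, x2, x3, ...) -> (-x1, x0, -x3, x2, ...)
    a, b = x[..., ::2], x[..., 1::2]
    return torch.stack((-b, a), dim=-1).flatten(-2)

def apply_rope_gptj_style(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
    # x: [seq, heads, rot_dim]; rope_cache: [seq, rot_dim/2, 2] holding (cos, sin)
    cos = rope_cache[..., 0].repeat_interleave(2, dim=-1).unsqueeze(1)  # [seq, 1, rot_dim]
    sin = rope_cache[..., 1].repeat_interleave(2, dim=-1).unsqueeze(1)
    return x * cos + rotate_every_two(x) * sin

def apply_rope_chatglm_reference(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
    # pairwise complex rotation, as in the upstream ChatGLM2 modeling code
    xs = x.reshape(*x.shape[:-1], -1, 2)
    rc = rope_cache.unsqueeze(1)  # broadcast over the heads dimension
    out = torch.stack((xs[..., 0] * rc[..., 0] - xs[..., 1] * rc[..., 1],
                       xs[..., 1] * rc[..., 0] + xs[..., 0] * rc[..., 1]), dim=-1)
    return out.flatten(-2)

seq, heads, rot_dim = 4, 2, 8
x = torch.randn(seq, heads, rot_dim)
angles = torch.randn(seq, rot_dim // 2)
rope_cache = torch.stack((angles.cos(), angles.sin()), dim=-1)
assert torch.allclose(apply_rope_gptj_style(x, rope_cache),
                      apply_rope_chatglm_reference(x, rope_cache), atol=1e-6)

Repeating each (cos, sin) frequency twice turns ChatGLM2's pairwise complex rotation into a plain elementwise multiply-add, which is the form the fused path in the later patches consumes.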
hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder( @@ -195,7 +203,24 @@ def chatglm2_attention_forward_8eb45c( # apply relative positional encoding (rotary embedding) if rotary_pos_emb is not None: - if len(rotary_pos_emb.shape) == 2: # use_fuse_rope, actually it is position_ids + if len(rotary_pos_emb) == 3: + cos, sin, cossin = rotary_pos_emb + rot_dim = cos.shape[-1] + #query_states = query_layer.permute(1, 2, 0, 3).contiguous() + query_states = query_layer.transpose(0, 1)#.contiguous() + query_states, query_states_pass = query_states[..., :rot_dim], query_states[..., rot_dim:] + # key_states = key_layer.permute(1, 2, 0, 3).contiguous() + key_states = key_layer.transpose(0, 1)#.contiguous() + key_states, key_states_pass = key_states[..., :rot_dim], key_states[..., rot_dim:] + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, + cos, sin, None, "chatglm2") + cossin = cossin.transpose(0, 1).contiguous() + query_layer = apply_rotary_pos_emb_chatglm(query_layer, cossin) + key_layer = apply_rotary_pos_emb_chatglm(key_layer, cossin) + + assert(torch.equal(query_states, query_layer.transpose(0, 1)[..., :64])) + assert(torch.equal(key_states, key_layer.transpose(0, 1)[..., :64])) + elif len(rotary_pos_emb.shape) == 2: # use_fuse_rope, actually it is position_ids rot_dim = rotary_pos_emb.shape[-2] * 2 query_layer = query_layer.permute(1, 2, 0, 3) query_layer, query_layer_pass = query_layer[..., :rot_dim], query_layer[..., rot_dim:] @@ -217,26 +242,6 @@ def chatglm2_attention_forward_8eb45c( if key_layer_pass.shape[-1] > 0: key_layer = torch.cat((key_layer, key_layer_pass), dim=-1) key_layer = key_layer.permute(2, 0, 1, 3) - else: - rot_dim = rotary_pos_emb.shape[-2] * 2 - #query_states = query_layer.permute(1, 2, 0, 3).contiguous() - query_states = query_layer.transpose(0, 1)#.contiguous() - query_states, query_states_pass = query_states[..., :rot_dim], query_states[..., rot_dim:] - # key_states = key_layer.permute(1, 2, 0, 3).contiguous() - key_states = key_layer.transpose(0, 1)#.contiguous() - key_states, key_states_pass = key_states[..., :rot_dim], key_states[..., rot_dim:] - cos, sin = rotary_pos_emb.split([1, 1], -1) - cos = cos.squeeze().unsqueeze(0)#.unsqueeze(1) - sin = sin.squeeze().unsqueeze(0)#.unsqueeze(1) - position_ids = torch.range(0, cur_length-1, dtype=torch.long).reshape((1, cur_length)) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, - cos, sin, position_ids, "gptj") - - query_layer = apply_rotary_pos_emb_chatglm(query_layer, rotary_pos_emb) - key_layer = apply_rotary_pos_emb_chatglm(key_layer, rotary_pos_emb) - - assert(torch.equal(query_states, query_layer.transpose(0, 1)[..., :64])) - assert(torch.equal(key_states, key_layer.transpose(0, 1)[..., :64])) if self.multi_query_attention: key_length = key_layer.size(0) diff --git a/python/llm/src/bigdl/llm/transformers/models/utils.py b/python/llm/src/bigdl/llm/transformers/models/utils.py index b5888319e1e..abd9535dda9 100644 --- a/python/llm/src/bigdl/llm/transformers/models/utils.py +++ b/python/llm/src/bigdl/llm/transformers/models/utils.py @@ -86,6 +86,10 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, model_family): q_embed = (q * cos) + (rotate_every_two(q) * sin) k_embed = (k * cos) + (rotate_every_two(k) * sin) return q_embed, k_embed + elif model_family == "chatglm2": + q_embed = (q * cos) + (rotate_every_two(q) * sin) + k_embed = (k * cos) + (rotate_every_two(k) * sin) + return q_embed, k_embed else: invalidInputError(False, 
f"{model_family} is not supported.") From 93aec0f7e600d3a9e66f83916d071b9b280b4d09 Mon Sep 17 00:00:00 2001 From: qiuxin2012 Date: Fri, 3 Nov 2023 16:21:57 +0800 Subject: [PATCH 06/17] update --- .../bigdl/llm/transformers/models/chatglm2.py | 69 +++++++------------ 1 file changed, 23 insertions(+), 46 deletions(-) diff --git a/python/llm/src/bigdl/llm/transformers/models/chatglm2.py b/python/llm/src/bigdl/llm/transformers/models/chatglm2.py index 7ecbf1b557e..e422984f79a 100644 --- a/python/llm/src/bigdl/llm/transformers/models/chatglm2.py +++ b/python/llm/src/bigdl/llm/transformers/models/chatglm2.py @@ -114,28 +114,30 @@ def chatglm2_model_forward( if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1): full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask) - # use_fuse_rope = input_ids.device.type == "xpu" - # use_fuse_rope = use_fuse_rope and not self.training - use_fuse_rope = False + use_fuse_rope = input_ids.device.type == "xpu" + use_fuse_rope = use_fuse_rope and not self.training + # Rotary positional embeddings if use_fuse_rope: - rotary_pos_emb = position_ids - else: - # Rotary positional embeddings rotary_pos_emb = self.rotary_pos_emb(self.seq_length) if position_ids is not None: rotary_pos_emb = rotary_pos_emb[position_ids] else: rotary_pos_emb = rotary_pos_emb[None, :seq_length] - # rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous() cos, sin = rotary_pos_emb.split(rotary_pos_emb.shape[-1] // 2, dim=-1) cos = cos.squeeze(-1) sin = sin.squeeze(-1) cos = torch.repeat_interleave(cos[:, :, None, :], 2, 3) sin = torch.repeat_interleave(sin[:, :, None, :], 2, 3) - rotary_pos_emb = (cos, sin, rotary_pos_emb) - self.cossin = rotary_pos_emb - self.position_ids = position_ids + rotary_pos_emb = (cos, sin) + else: + # Rotary positional embeddings + rotary_pos_emb = self.rotary_pos_emb(self.seq_length) + if position_ids is not None: + rotary_pos_emb = rotary_pos_emb[position_ids] + else: + rotary_pos_emb = rotary_pos_emb[None, :seq_length] + rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous() # Run encoder. 
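
The hunk above makes chatglm2_model_forward expand the rope cache into repeated cos/sin tables once per call when the fused path is used. A hedged, standalone sketch of that pre-processing step, assuming ChatGLM2's rope cache layout of [batch, seq, rot_dim/2, 2] (the function name is illustrative):

import torch

def expand_rope_cache(rope_cache: torch.Tensor):
    # rope_cache: [batch, seq, rot_dim/2, 2]; last dim holds (cos, sin)
    cos, sin = rope_cache.split(1, dim=-1)
    cos, sin = cos.squeeze(-1), sin.squeeze(-1)                  # [batch, seq, rot_dim/2]
    # duplicate each frequency so it lines up with the rotate-every-two
    # layout: (c0, c0, c1, c1, ...)
    cos = torch.repeat_interleave(cos[:, :, None, :], 2, dim=3)  # [batch, seq, 1, rot_dim]
    sin = torch.repeat_interleave(sin[:, :, None, :], 2, dim=3)
    return cos, sin

rope_cache = torch.randn(1, 16, 32, 2)       # batch=1, seq=16, rot_dim=64
cos, sin = expand_rope_cache(rope_cache)     # built once, reused by every layer
assert cos.shape == (1, 16, 1, 64)

Building the tables in the model forward rather than inside each SelfAttention layer means the interleaving runs once per forward pass instead of once per layer, which is the reasoning spelled out in the comments added by the later patches in this series.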
hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder( @@ -203,45 +205,20 @@ def chatglm2_attention_forward_8eb45c( # apply relative positional encoding (rotary embedding) if rotary_pos_emb is not None: - if len(rotary_pos_emb) == 3: - cos, sin, cossin = rotary_pos_emb + if len(rotary_pos_emb) == 2: # use_fuse_rope + cos, sin = rotary_pos_emb rot_dim = cos.shape[-1] - #query_states = query_layer.permute(1, 2, 0, 3).contiguous() - query_states = query_layer.transpose(0, 1)#.contiguous() + query_states = query_layer.transpose(0, 1) query_states, query_states_pass = query_states[..., :rot_dim], query_states[..., rot_dim:] - # key_states = key_layer.permute(1, 2, 0, 3).contiguous() - key_states = key_layer.transpose(0, 1)#.contiguous() + key_states = key_layer.transpose(0, 1) key_states, key_states_pass = key_states[..., :rot_dim], key_states[..., rot_dim:] - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, - cos, sin, None, "chatglm2") - cossin = cossin.transpose(0, 1).contiguous() - query_layer = apply_rotary_pos_emb_chatglm(query_layer, cossin) - key_layer = apply_rotary_pos_emb_chatglm(key_layer, cossin) - - assert(torch.equal(query_states, query_layer.transpose(0, 1)[..., :64])) - assert(torch.equal(key_states, key_layer.transpose(0, 1)[..., :64])) - elif len(rotary_pos_emb.shape) == 2: # use_fuse_rope, actually it is position_ids - rot_dim = rotary_pos_emb.shape[-2] * 2 - query_layer = query_layer.permute(1, 2, 0, 3) - query_layer, query_layer_pass = query_layer[..., :rot_dim], query_layer[..., rot_dim:] - key_layer = key_layer.permute(1, 2, 0, 3) - key_layer, key_layer_pass = key_layer[..., :rot_dim], key_layer[..., rot_dim:] - dummy_q = torch.empty(query_layer.shape, dtype=query_layer.dtype, device=query_layer.device) - dummy_k = torch.empty(key_layer.shape, dtype=key_layer.dtype, device=key_layer.device) - _, query_layer = apply_rotary_pos_emb_no_cache_xpu(dummy_q, - query_layer, - rotary_pos_emb, - "llama") - _, key_layer = apply_rotary_pos_emb_no_cache_xpu(dummy_k, - key_layer, - rotary_pos_emb, - "llama") - if query_layer_pass.shape[-1] > 0: - query_layer = torch.cat((query_layer, query_layer_pass), dim=-1) - query_layer = query_layer.permute(2, 0, 1, 3) - if key_layer_pass.shape[-1] > 0: - key_layer = torch.cat((key_layer, key_layer_pass), dim=-1) - key_layer = key_layer.permute(2, 0, 1, 3) + torch.ops.torch_ipex.apply_rotary_embedding(query_states, sin, cos, query_states) + torch.ops.torch_ipex.apply_rotary_embedding(key_states, sin, cos, key_states) + query_layer = torch.cat((query_states, query_states_pass), dim=-1).transpose(0, 1) + key_layer = torch.cat((key_states, key_states_pass), dim=-1).transpose(0, 1) + else: + query_layer = apply_rotary_pos_emb_chatglm(query_layer, rotary_pos_emb) + key_layer = apply_rotary_pos_emb_chatglm(key_layer, rotary_pos_emb) if self.multi_query_attention: key_length = key_layer.size(0) From 77be988b8299ac8710d05b6dccf279f572decf55 Mon Sep 17 00:00:00 2001 From: qiuxin2012 Date: Fri, 3 Nov 2023 18:56:54 +0800 Subject: [PATCH 07/17] code cleanup --- python/llm/src/bigdl/llm/transformers/models/chatglm2.py | 9 +-------- python/llm/src/bigdl/llm/transformers/models/utils.py | 4 ---- 2 files changed, 1 insertion(+), 12 deletions(-) diff --git a/python/llm/src/bigdl/llm/transformers/models/chatglm2.py b/python/llm/src/bigdl/llm/transformers/models/chatglm2.py index e422984f79a..619e4afc85f 100644 --- a/python/llm/src/bigdl/llm/transformers/models/chatglm2.py +++ 
b/python/llm/src/bigdl/llm/transformers/models/chatglm2.py @@ -21,7 +21,7 @@ from typing import Optional, Tuple, List import torch.nn.functional as F from transformers.modeling_outputs import BaseModelOutputWithPast -from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache, apply_rotary_pos_emb, apply_rotary_pos_emb_no_cache_xpu +from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache KV_CACHE_ALLOC_BLOCK_LENGTH = 256 @@ -224,13 +224,8 @@ def chatglm2_attention_forward_8eb45c( key_length = key_layer.size(0) query_group_size = self.num_attention_heads_per_partition // \ self.num_multi_query_groups_per_partition - # if rotary_pos_emb is not None and use_fuse_rope: - # key_layer = key_layer.unsqueeze(-3) - # else: - # key_layer = key_layer.permute(1, 2, 0, 3).unsqueeze(-3) # [bs, nh/k, sl, hn] key_layer = key_layer.permute(1, 2, 0, 3).unsqueeze(-3) # [bs, nh/k, sl, hn] key_layer = key_layer.expand(-1, -1, query_group_size, -1, -1) - # print("----key", key_layer.shape) key_layer = key_layer.contiguous().view((batch_size, self.num_attention_heads_per_partition, key_length, @@ -286,8 +281,6 @@ def chatglm2_attention_forward_8eb45c( # ================================== # core attention computation # ================================== - # print("----query", query_layer.shape) - # print("----key", key_layer.shape) context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask) # ================= diff --git a/python/llm/src/bigdl/llm/transformers/models/utils.py b/python/llm/src/bigdl/llm/transformers/models/utils.py index abd9535dda9..b5888319e1e 100644 --- a/python/llm/src/bigdl/llm/transformers/models/utils.py +++ b/python/llm/src/bigdl/llm/transformers/models/utils.py @@ -86,10 +86,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, model_family): q_embed = (q * cos) + (rotate_every_two(q) * sin) k_embed = (k * cos) + (rotate_every_two(k) * sin) return q_embed, k_embed - elif model_family == "chatglm2": - q_embed = (q * cos) + (rotate_every_two(q) * sin) - k_embed = (k * cos) + (rotate_every_two(k) * sin) - return q_embed, k_embed else: invalidInputError(False, f"{model_family} is not supported.") From 1fa88c052378209ffdc40a529277a0648e2d8faa Mon Sep 17 00:00:00 2001 From: qiuxin2012 Date: Fri, 3 Nov 2023 19:19:16 +0800 Subject: [PATCH 08/17] update --- .../llm/src/bigdl/llm/transformers/convert.py | 3 ++- .../bigdl/llm/transformers/models/chatglm2.py | 17 ++++++++++++----- 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/python/llm/src/bigdl/llm/transformers/convert.py b/python/llm/src/bigdl/llm/transformers/convert.py index 8775948da57..ae7df6b3d46 100644 --- a/python/llm/src/bigdl/llm/transformers/convert.py +++ b/python/llm/src/bigdl/llm/transformers/convert.py @@ -266,7 +266,8 @@ def _optimize_post(model): module = importlib.import_module(modeling_module_name) from bigdl.llm.transformers.models.chatglm2 import chatglm2_attention_forward_8eb45c from bigdl.llm.transformers.models.chatglm2 import core_attn_forward_8eb45c - from bigdl.llm.transformers.models.chatglm2 import chatglm_rms_norm_forward, chatglm2_model_forward + from bigdl.llm.transformers.models.chatglm2 import chatglm_rms_norm_forward + from bigdl.llm.transformers.models.chatglm2 import chatglm2_model_forward convert_forward(model, module.SelfAttention, chatglm2_attention_forward_8eb45c diff --git a/python/llm/src/bigdl/llm/transformers/models/chatglm2.py 
b/python/llm/src/bigdl/llm/transformers/models/chatglm2.py index 619e4afc85f..ef71509036e 100644 --- a/python/llm/src/bigdl/llm/transformers/models/chatglm2.py +++ b/python/llm/src/bigdl/llm/transformers/models/chatglm2.py @@ -87,6 +87,7 @@ def chatglm_rms_norm_forward(self, hidden_states): return self.weight * hidden_states.to(input_dtype) return hidden_states + def chatglm2_model_forward( self, input_ids, @@ -100,7 +101,8 @@ def chatglm2_model_forward( return_dict: Optional[bool] = None, ): output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + output_hidden_states if output_hidden_states is not None + else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict @@ -111,8 +113,11 @@ def chatglm2_model_forward( inputs_embeds = self.embedding(input_ids) if full_attention_mask is None: - if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1): - full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask) + if (attention_mask is not None and not attention_mask.all()) or + (past_key_values and seq_length != 1): + full_attention_mask = self.get_masks(input_ids, + past_key_values, + padding_mask=attention_mask) use_fuse_rope = input_ids.device.type == "xpu" use_fuse_rope = use_fuse_rope and not self.training @@ -146,7 +151,8 @@ def chatglm2_model_forward( ) if not return_dict: - return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None) + return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] + if v is not None) return BaseModelOutputWithPast( last_hidden_state=hidden_states, @@ -209,7 +215,8 @@ def chatglm2_attention_forward_8eb45c( cos, sin = rotary_pos_emb rot_dim = cos.shape[-1] query_states = query_layer.transpose(0, 1) - query_states, query_states_pass = query_states[..., :rot_dim], query_states[..., rot_dim:] + query_states = query_states[..., :rot_dim] + query_states_pass = query_states[..., rot_dim:] key_states = key_layer.transpose(0, 1) key_states, key_states_pass = key_states[..., :rot_dim], key_states[..., rot_dim:] torch.ops.torch_ipex.apply_rotary_embedding(query_states, sin, cos, query_states) From 9dc26012311b0fe0d341c05c5fa64ae64922eea7 Mon Sep 17 00:00:00 2001 From: qiuxin2012 Date: Fri, 3 Nov 2023 19:28:38 +0800 Subject: [PATCH 09/17] chatglm2 --- python/llm/src/bigdl/llm/transformers/models/chatglm2.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/llm/src/bigdl/llm/transformers/models/chatglm2.py b/python/llm/src/bigdl/llm/transformers/models/chatglm2.py index ef71509036e..61196ff4a80 100644 --- a/python/llm/src/bigdl/llm/transformers/models/chatglm2.py +++ b/python/llm/src/bigdl/llm/transformers/models/chatglm2.py @@ -94,7 +94,7 @@ def chatglm2_model_forward( position_ids: Optional[torch.Tensor] = None, attention_mask: Optional[torch.BoolTensor] = None, full_attention_mask: Optional[torch.BoolTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor,torch.Tensor],...]] = None, inputs_embeds: Optional[torch.Tensor] = None, use_cache: Optional[bool] = None, output_hidden_states: Optional[bool] = None, @@ -102,7 +102,7 @@ def chatglm2_model_forward( ): output_hidden_states = ( 
output_hidden_states if output_hidden_states is not None - else self.config.output_hidden_states + else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict @@ -114,7 +114,7 @@ def chatglm2_model_forward( if full_attention_mask is None: if (attention_mask is not None and not attention_mask.all()) or - (past_key_values and seq_length != 1): + (past_key_values and seq_length != 1): full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask) From 4fbe3354100004ee7cfe7cadd23e1dbba91298e4 Mon Sep 17 00:00:00 2001 From: qiuxin2012 Date: Fri, 3 Nov 2023 19:38:18 +0800 Subject: [PATCH 10/17] fix --- .../bigdl/llm/transformers/models/chatglm2.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/python/llm/src/bigdl/llm/transformers/models/chatglm2.py b/python/llm/src/bigdl/llm/transformers/models/chatglm2.py index 61196ff4a80..177cfe0bc19 100644 --- a/python/llm/src/bigdl/llm/transformers/models/chatglm2.py +++ b/python/llm/src/bigdl/llm/transformers/models/chatglm2.py @@ -91,14 +91,14 @@ def chatglm_rms_norm_forward(self, hidden_states): def chatglm2_model_forward( self, input_ids, - position_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.BoolTensor] = None, - full_attention_mask: Optional[torch.BoolTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor,torch.Tensor],...]] = None, - inputs_embeds: Optional[torch.Tensor] = None, - use_cache: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, + position_ids: Optional[torch.Tensor]=None, + attention_mask: Optional[torch.BoolTensor]=None, + full_attention_mask: Optional[torch.BoolTensor]=None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]]=None, + inputs_embeds: Optional[torch.Tensor]=None, + use_cache: Optional[bool]=None, + output_hidden_states: Optional[bool]=None, + return_dict: Optional[bool]=None, ): output_hidden_states = ( output_hidden_states if output_hidden_states is not None @@ -114,7 +114,7 @@ def chatglm2_model_forward( if full_attention_mask is None: if (attention_mask is not None and not attention_mask.all()) or - (past_key_values and seq_length != 1): + (past_key_values and seq_length != 1): full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask) From 8ba2819c4015648bb30597f100aa7e5d824ff5fe Mon Sep 17 00:00:00 2001 From: qiuxin2012 Date: Fri, 3 Nov 2023 19:44:04 +0800 Subject: [PATCH 11/17] fix style --- python/llm/src/bigdl/llm/transformers/models/chatglm2.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/llm/src/bigdl/llm/transformers/models/chatglm2.py b/python/llm/src/bigdl/llm/transformers/models/chatglm2.py index 177cfe0bc19..09b21fd3973 100644 --- a/python/llm/src/bigdl/llm/transformers/models/chatglm2.py +++ b/python/llm/src/bigdl/llm/transformers/models/chatglm2.py @@ -113,8 +113,8 @@ def chatglm2_model_forward( inputs_embeds = self.embedding(input_ids) if full_attention_mask is None: - if (attention_mask is not None and not attention_mask.all()) or - (past_key_values and seq_length != 1): + if (attention_mask is not None and not attention_mask.all()) or ( + past_key_values and seq_length != 1): full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask) @@ -152,7 +152,7 @@ def 
chatglm2_model_forward( if not return_dict: return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] - if v is not None) + if v is not None) return BaseModelOutputWithPast( last_hidden_state=hidden_states, From 90e82c66cc9b95712bef88e9c61bccfe477d8256 Mon Sep 17 00:00:00 2001 From: qiuxin2012 Date: Fri, 3 Nov 2023 19:56:09 +0800 Subject: [PATCH 12/17] fix --- python/llm/src/bigdl/llm/transformers/models/chatglm2.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/python/llm/src/bigdl/llm/transformers/models/chatglm2.py b/python/llm/src/bigdl/llm/transformers/models/chatglm2.py index 09b21fd3973..db26f67aafd 100644 --- a/python/llm/src/bigdl/llm/transformers/models/chatglm2.py +++ b/python/llm/src/bigdl/llm/transformers/models/chatglm2.py @@ -215,8 +215,7 @@ def chatglm2_attention_forward_8eb45c( cos, sin = rotary_pos_emb rot_dim = cos.shape[-1] query_states = query_layer.transpose(0, 1) - query_states = query_states[..., :rot_dim] - query_states_pass = query_states[..., rot_dim:] + query_states = query_states[..., :rot_dim], query_states[..., rot_dim:] key_states = key_layer.transpose(0, 1) key_states, key_states_pass = key_states[..., :rot_dim], key_states[..., rot_dim:] torch.ops.torch_ipex.apply_rotary_embedding(query_states, sin, cos, query_states) From 9b7f1dbe92cd762894da7ef1970f548b8ba09fcb Mon Sep 17 00:00:00 2001 From: qiuxin2012 Date: Fri, 3 Nov 2023 20:06:12 +0800 Subject: [PATCH 13/17] fix --- .../bigdl/llm/transformers/models/chatglm2.py | 35 +++++++++---------- 1 file changed, 16 insertions(+), 19 deletions(-) diff --git a/python/llm/src/bigdl/llm/transformers/models/chatglm2.py b/python/llm/src/bigdl/llm/transformers/models/chatglm2.py index db26f67aafd..6e12efa376a 100644 --- a/python/llm/src/bigdl/llm/transformers/models/chatglm2.py +++ b/python/llm/src/bigdl/llm/transformers/models/chatglm2.py @@ -123,12 +123,13 @@ def chatglm2_model_forward( use_fuse_rope = use_fuse_rope and not self.training # Rotary positional embeddings + rotary_pos_emb = self.rotary_pos_emb(self.seq_length) + if position_ids is not None: + rotary_pos_emb = rotary_pos_emb[position_ids] + else: + rotary_pos_emb = rotary_pos_emb[None, :seq_length] if use_fuse_rope: - rotary_pos_emb = self.rotary_pos_emb(self.seq_length) - if position_ids is not None: - rotary_pos_emb = rotary_pos_emb[position_ids] - else: - rotary_pos_emb = rotary_pos_emb[None, :seq_length] + # repeat cos sin here. cos, sin = rotary_pos_emb.split(rotary_pos_emb.shape[-1] // 2, dim=-1) cos = cos.squeeze(-1) sin = sin.squeeze(-1) @@ -136,12 +137,6 @@ def chatglm2_model_forward( sin = torch.repeat_interleave(sin[:, :, None, :], 2, 3) rotary_pos_emb = (cos, sin) else: - # Rotary positional embeddings - rotary_pos_emb = self.rotary_pos_emb(self.seq_length) - if position_ids is not None: - rotary_pos_emb = rotary_pos_emb[position_ids] - else: - rotary_pos_emb = rotary_pos_emb[None, :seq_length] rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous() # Run encoder. 
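
The attention-side hunk that follows consumes the pre-computed (cos, sin) pair, rotates only the first rot_dim channels of each head, and leaves the remaining channels untouched. A rough pure-PyTorch sketch of that partial-rotary step; fused_rope_stand_in is an illustrative placeholder for the torch.ops.torch_ipex.apply_rotary_embedding call used on XPU in the diff, not its actual implementation:

import torch

def rotate_every_two(x: torch.Tensor) -> torch.Tensor:
    a, b = x[..., ::2], x[..., 1::2]
    return torch.stack((-b, a), dim=-1).flatten(-2)

def fused_rope_stand_in(x: torch.Tensor, sin: torch.Tensor, cos: torch.Tensor) -> torch.Tensor:
    # placeholder for the fused XPU kernel used in the diff
    return x * cos + rotate_every_two(x) * sin

def apply_partial_rotary(layer: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # layer: [seq, batch, heads, head_dim]; cos/sin: [batch, seq, 1, rot_dim]
    rot_dim = cos.shape[-1]
    layer = layer.transpose(0, 1)                        # [batch, seq, heads, head_dim]
    rotated, passthrough = layer[..., :rot_dim], layer[..., rot_dim:]
    rotated = fused_rope_stand_in(rotated, sin, cos)
    return torch.cat((rotated, passthrough), dim=-1).transpose(0, 1)

q = torch.randn(16, 1, 32, 128)              # [seq, batch, heads, head_dim]
cos = torch.randn(1, 16, 1, 64)
sin = torch.randn(1, 16, 1, 64)
q = apply_partial_rotary(q, cos, sin)        # only the first 64 channels are rotated
assert q.shape == (16, 1, 32, 128)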
@@ -214,14 +209,16 @@ def chatglm2_attention_forward_8eb45c( if len(rotary_pos_emb) == 2: # use_fuse_rope cos, sin = rotary_pos_emb rot_dim = cos.shape[-1] - query_states = query_layer.transpose(0, 1) - query_states = query_states[..., :rot_dim], query_states[..., rot_dim:] - key_states = key_layer.transpose(0, 1) - key_states, key_states_pass = key_states[..., :rot_dim], key_states[..., rot_dim:] - torch.ops.torch_ipex.apply_rotary_embedding(query_states, sin, cos, query_states) - torch.ops.torch_ipex.apply_rotary_embedding(key_states, sin, cos, key_states) - query_layer = torch.cat((query_states, query_states_pass), dim=-1).transpose(0, 1) - key_layer = torch.cat((key_states, key_states_pass), dim=-1).transpose(0, 1) + query_layer = query_layer.transpose(0, 1) + query_layer_pass = query_layer[..., rot_dim:] + query_layer = query_layer[..., :rot_dim] + key_layer = key_layer.transpose(0, 1) + key_layer_pass = key_layer[..., rot_dim:] + key_layer = key_layer[..., :rot_dim] + torch.ops.torch_ipex.apply_rotary_embedding(query_layer, sin, cos, query_layer) + torch.ops.torch_ipex.apply_rotary_embedding(key_layer, sin, cos, key_layer) + query_layer = torch.cat((query_layer, query_layer_pass), dim=-1).transpose(0, 1) + key_layer = torch.cat((key_layer, key_layer_pass), dim=-1).transpose(0, 1) else: query_layer = apply_rotary_pos_emb_chatglm(query_layer, rotary_pos_emb) key_layer = apply_rotary_pos_emb_chatglm(key_layer, rotary_pos_emb) From 41eee8519a81fc51716f93b526474a96c013d695 Mon Sep 17 00:00:00 2001 From: qiuxin2012 Date: Fri, 3 Nov 2023 20:15:46 +0800 Subject: [PATCH 14/17] cleanup --- python/llm/src/bigdl/llm/transformers/models/chatglm2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/llm/src/bigdl/llm/transformers/models/chatglm2.py b/python/llm/src/bigdl/llm/transformers/models/chatglm2.py index 6e12efa376a..73d813d36d0 100644 --- a/python/llm/src/bigdl/llm/transformers/models/chatglm2.py +++ b/python/llm/src/bigdl/llm/transformers/models/chatglm2.py @@ -54,7 +54,7 @@ def split_tensor_along_last_dim( return tensor_list -# @torch.jit.script +@torch.jit.script def apply_rotary_pos_emb_chatglm(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor: # x: [sq, b, np, hn] sq, b, np, hn = x.size(0), x.size(1), x.size(2), x.size(3) From 3ea9e10c3c8073f8306b7652f51fdd1167faa246 Mon Sep 17 00:00:00 2001 From: qiuxin2012 Date: Mon, 6 Nov 2023 11:17:14 +0800 Subject: [PATCH 15/17] update --- .../bigdl/llm/transformers/models/chatglm2.py | 21 ++++++++++--------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/python/llm/src/bigdl/llm/transformers/models/chatglm2.py b/python/llm/src/bigdl/llm/transformers/models/chatglm2.py index 73d813d36d0..632e17d641c 100644 --- a/python/llm/src/bigdl/llm/transformers/models/chatglm2.py +++ b/python/llm/src/bigdl/llm/transformers/models/chatglm2.py @@ -129,7 +129,8 @@ def chatglm2_model_forward( else: rotary_pos_emb = rotary_pos_emb[None, :seq_length] if use_fuse_rope: - # repeat cos sin here. + # repeat cos sin here, call only once for each token. + # if put this to attension forward, it will generate too many times. 
cos, sin = rotary_pos_emb.split(rotary_pos_emb.shape[-1] // 2, dim=-1) cos = cos.squeeze(-1) sin = sin.squeeze(-1) @@ -206,19 +207,19 @@ def chatglm2_attention_forward_8eb45c( # apply relative positional encoding (rotary embedding) if rotary_pos_emb is not None: - if len(rotary_pos_emb) == 2: # use_fuse_rope + if len(rotary_pos_emb) == 2: # use_fuse_rope, see chatglm2_model_forward cos, sin = rotary_pos_emb rot_dim = cos.shape[-1] query_layer = query_layer.transpose(0, 1) - query_layer_pass = query_layer[..., rot_dim:] - query_layer = query_layer[..., :rot_dim] key_layer = key_layer.transpose(0, 1) - key_layer_pass = key_layer[..., rot_dim:] - key_layer = key_layer[..., :rot_dim] - torch.ops.torch_ipex.apply_rotary_embedding(query_layer, sin, cos, query_layer) - torch.ops.torch_ipex.apply_rotary_embedding(key_layer, sin, cos, key_layer) - query_layer = torch.cat((query_layer, query_layer_pass), dim=-1).transpose(0, 1) - key_layer = torch.cat((key_layer, key_layer_pass), dim=-1).transpose(0, 1) + query_layer_cur = query_layer[..., :rot_dim] + key_layer_cur = key_layer[..., :rot_dim] + # ipex's apply_rotary_embedding can change the origin storage, so query_layer will get + # the result directly. + torch.ops.torch_ipex.apply_rotary_embedding(query_layer_cur, sin, cos, query_layer_cur) + torch.ops.torch_ipex.apply_rotary_embedding(key_layer_cur, sin, cos, key_layer_cur) + query_layer = query_layer.transpose(0, 1) + key_layer = key_layer.transpose(0, 1) else: query_layer = apply_rotary_pos_emb_chatglm(query_layer, rotary_pos_emb) key_layer = apply_rotary_pos_emb_chatglm(key_layer, rotary_pos_emb) From fc547d21f29db5766b6d55e3353fcfcf78314235 Mon Sep 17 00:00:00 2001 From: qiuxin2012 Date: Mon, 6 Nov 2023 11:23:55 +0800 Subject: [PATCH 16/17] change comments --- python/llm/src/bigdl/llm/transformers/models/chatglm2.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/python/llm/src/bigdl/llm/transformers/models/chatglm2.py b/python/llm/src/bigdl/llm/transformers/models/chatglm2.py index 632e17d641c..6216a8e7e8d 100644 --- a/python/llm/src/bigdl/llm/transformers/models/chatglm2.py +++ b/python/llm/src/bigdl/llm/transformers/models/chatglm2.py @@ -129,8 +129,9 @@ def chatglm2_model_forward( else: rotary_pos_emb = rotary_pos_emb[None, :seq_length] if use_fuse_rope: - # repeat cos sin here, call only once for each token. - # if put this to attension forward, it will generate too many times. + # Repeat cos sin here, call only once for each token. + # Chatglm2's rotary embedding is similar to gptj's, is rotate_every_two. + # If put this to attension forward, it will generate too many times. cos, sin = rotary_pos_emb.split(rotary_pos_emb.shape[-1] // 2, dim=-1) cos = cos.squeeze(-1) sin = sin.squeeze(-1) From 60a434979b2f03685fe96505dee6d7d6dcf9af91 Mon Sep 17 00:00:00 2001 From: qiuxin2012 Date: Mon, 6 Nov 2023 11:32:26 +0800 Subject: [PATCH 17/17] fix style check --- python/llm/src/bigdl/llm/transformers/models/chatglm2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/llm/src/bigdl/llm/transformers/models/chatglm2.py b/python/llm/src/bigdl/llm/transformers/models/chatglm2.py index 6216a8e7e8d..0373f2aad29 100644 --- a/python/llm/src/bigdl/llm/transformers/models/chatglm2.py +++ b/python/llm/src/bigdl/llm/transformers/models/chatglm2.py @@ -129,7 +129,7 @@ def chatglm2_model_forward( else: rotary_pos_emb = rotary_pos_emb[None, :seq_length] if use_fuse_rope: - # Repeat cos sin here, call only once for each token. 
+ # Repeat cos sin here, call only once for each token. # Chatglm2's rotary embedding is similar to gptj's (rotate_every_two). # If this were done in the attention forward, it would be recomputed too many times. cos, sin = rotary_pos_emb.split(rotary_pos_emb.shape[-1] // 2, dim=-1)
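
The last few patches lean on one property of the fused kernel: it writes its result back into the tensor it is given, so rotating the query_layer[..., :rot_dim] and key_layer[..., :rot_dim] views updates the full tensors in place and the pass-through channels no longer need to be concatenated back. A small sketch of that idea, using a plain PyTorch in-place copy as a stand-in for torch.ops.torch_ipex.apply_rotary_embedding (names and shapes are illustrative):

import torch

def rotate_every_two(x: torch.Tensor) -> torch.Tensor:
    a, b = x[..., ::2], x[..., 1::2]
    return torch.stack((-b, a), dim=-1).flatten(-2)

def apply_rotary_inplace(x_slice: torch.Tensor, sin: torch.Tensor, cos: torch.Tensor) -> None:
    # stand-in for the fused kernel: compute, then write into the same storage
    x_slice.copy_(x_slice * cos + rotate_every_two(x_slice) * sin)

seq, batch, heads, head_dim, rot_dim = 4, 1, 2, 8, 4
layer = torch.randn(seq, batch, heads, head_dim).transpose(0, 1)  # [batch, seq, heads, head_dim]
cos = torch.randn(batch, seq, 1, rot_dim)
sin = torch.randn(batch, seq, 1, rot_dim)

untouched = layer[..., rot_dim:].clone()
cur = layer[..., :rot_dim]          # a view sharing storage with `layer`
apply_rotary_inplace(cur, sin, cos)
# the parent tensor sees the rotation; the tail channels stay unchanged
assert torch.equal(layer[..., rot_dim:], untouched)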