Chatglm2 rope optimization on xpu #9350

Merged (17 commits) on Nov 6, 2023
4 changes: 4 additions & 0 deletions python/llm/src/bigdl/llm/transformers/convert.py
@@ -267,13 +267,17 @@ def _optimize_post(model):
from bigdl.llm.transformers.models.chatglm2 import chatglm2_attention_forward_8eb45c
from bigdl.llm.transformers.models.chatglm2 import core_attn_forward_8eb45c
from bigdl.llm.transformers.models.chatglm2 import chatglm_rms_norm_forward
from bigdl.llm.transformers.models.chatglm2 import chatglm2_model_forward
convert_forward(model,
module.SelfAttention,
chatglm2_attention_forward_8eb45c
)
convert_forward(model,
module.CoreAttention,
core_attn_forward_8eb45c)
convert_forward(model,
module.ChatGLMModel,
chatglm2_model_forward)
convert_forward(model,
module.RMSNorm,
chatglm_rms_norm_forward)
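
For context, convert_forward is the hook that swaps in these optimized forwards. A minimal sketch of what it is assumed to do (convert_forward_sketch is an illustrative stand-in, not the actual bigdl.llm implementation):

def convert_forward_sketch(model, target_module_class, new_forward):
    # Walk the model recursively and rebind forward on every instance of the
    # target module class to the replacement function, bound as a method.
    for _, sub_module in model.named_children():
        if isinstance(sub_module, target_module_class):
            setattr(sub_module, "forward",
                    new_forward.__get__(sub_module, sub_module.__class__))
        convert_forward_sketch(sub_module, target_module_class, new_forward)

With the lines above, every ChatGLMModel instance in the loaded model would then run chatglm2_model_forward in place of its original forward.
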
99 changes: 92 additions & 7 deletions python/llm/src/bigdl/llm/transformers/models/chatglm2.py
@@ -18,8 +18,9 @@
#

import torch
from typing import Optional, Tuple, Union, List, Callable, Dict, Any
from typing import Optional, Tuple, List
import torch.nn.functional as F
from transformers.modeling_outputs import BaseModelOutputWithPast
from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache


@@ -54,7 +55,7 @@ def split_tensor_along_last_dim(


@torch.jit.script
def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
def apply_rotary_pos_emb_chatglm(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
# x: [sq, b, np, hn]
sq, b, np, hn = x.size(0), x.size(1), x.size(2), x.size(3)
rot_dim = rope_cache.shape[-2] * 2
@@ -87,6 +88,77 @@ def chatglm_rms_norm_forward(self, hidden_states):
return hidden_states


def chatglm2_model_forward(
self,
input_ids,
position_ids: Optional[torch.Tensor]=None,
attention_mask: Optional[torch.BoolTensor]=None,
full_attention_mask: Optional[torch.BoolTensor]=None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]]=None,
inputs_embeds: Optional[torch.Tensor]=None,
use_cache: Optional[bool]=None,
output_hidden_states: Optional[bool]=None,
return_dict: Optional[bool]=None,
):
output_hidden_states = (
output_hidden_states if output_hidden_states is not None
else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

batch_size, seq_length = input_ids.shape

if inputs_embeds is None:
inputs_embeds = self.embedding(input_ids)

if full_attention_mask is None:
if (attention_mask is not None and not attention_mask.all()) or (
past_key_values and seq_length != 1):
full_attention_mask = self.get_masks(input_ids,
past_key_values,
padding_mask=attention_mask)

use_fuse_rope = input_ids.device.type == "xpu"
use_fuse_rope = use_fuse_rope and not self.training
Review comment (Contributor) on lines +122 to +123: we can combine this into one line, use_fuse_rope = input_ids.device.type == "xpu" and not self.training.

# Rotary positional embeddings
rotary_pos_emb = self.rotary_pos_emb(self.seq_length)
if position_ids is not None:
rotary_pos_emb = rotary_pos_emb[position_ids]
else:
rotary_pos_emb = rotary_pos_emb[None, :seq_length]
if use_fuse_rope:
# Repeat cos/sin here so they are computed only once for each token.
# ChatGLM2's rotary embedding is similar to GPT-J's (rotate_every_two).
# Doing this in the attention forward would recompute it in every layer.
cos, sin = rotary_pos_emb.split(rotary_pos_emb.shape[-1] // 2, dim=-1)
cos = cos.squeeze(-1)
sin = sin.squeeze(-1)
cos = torch.repeat_interleave(cos[:, :, None, :], 2, 3)
sin = torch.repeat_interleave(sin[:, :, None, :], 2, 3)
rotary_pos_emb = (cos, sin)
else:
rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous()

# Run encoder.
hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder(
inputs_embeds, full_attention_mask, rotary_pos_emb=rotary_pos_emb,
kv_caches=past_key_values, use_cache=use_cache, output_hidden_states=output_hidden_states
)

if not return_dict:
return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions]
if v is not None)

return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=presents,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
)
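
The comments above describe the rotation these repeated cos/sin tensors feed into: ChatGLM2 uses a GPT-J-style rotate-every-two embedding, and hoisting the repeat_interleave here means it runs once per forward pass rather than once per layer. A plain-PyTorch reference of that rotation, for illustration only (rotate_every_two and apply_rotary_ref are hypothetical helpers, not part of this PR):

import torch

def rotate_every_two(x: torch.Tensor) -> torch.Tensor:
    # [x0, x1, x2, x3, ...] -> [-x1, x0, -x3, x2, ...] along the last dim
    x_even = x[..., 0::2]
    x_odd = x[..., 1::2]
    return torch.stack((-x_odd, x_even), dim=-1).flatten(-2)

def apply_rotary_ref(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # x: [b, sq, np, head_dim]; cos/sin: repeat-interleaved to [b or 1, sq, 1, rot_dim].
    # Only the first rot_dim channels are rotated; the rest pass through unchanged.
    rot_dim = cos.shape[-1]
    x_rot, x_pass = x[..., :rot_dim], x[..., rot_dim:]
    return torch.cat((x_rot * cos + rotate_every_two(x_rot) * sin, x_pass), dim=-1)

On the fused XPU path, torch.ops.torch_ipex.apply_rotary_embedding is expected to perform the same per-pair rotation in a single kernel.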


def chatglm2_attention_forward_8eb45c(
self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True
):
@@ -132,12 +204,26 @@ def chatglm2_attention_forward_8eb45c(
# [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
(query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3)

cur_length, batch_size = query_layer.shape[0], query_layer.shape[1]

# apply relative positional encoding (rotary embedding)
if rotary_pos_emb is not None:
query_layer = apply_rotary_pos_emb(query_layer, rotary_pos_emb)
key_layer = apply_rotary_pos_emb(key_layer, rotary_pos_emb)

cur_length, batch_size = query_layer.shape[0], query_layer.shape[1]
if len(rotary_pos_emb) == 2: # use_fuse_rope, see chatglm2_model_forward
cos, sin = rotary_pos_emb
rot_dim = cos.shape[-1]
query_layer = query_layer.transpose(0, 1)
key_layer = key_layer.transpose(0, 1)
query_layer_cur = query_layer[..., :rot_dim]
key_layer_cur = key_layer[..., :rot_dim]
# ipex's apply_rotary_embedding writes the result into the original storage,
# so query_layer gets the rotated values directly.
torch.ops.torch_ipex.apply_rotary_embedding(query_layer_cur, sin, cos, query_layer_cur)
torch.ops.torch_ipex.apply_rotary_embedding(key_layer_cur, sin, cos, key_layer_cur)
query_layer = query_layer.transpose(0, 1)
key_layer = key_layer.transpose(0, 1)
else:
query_layer = apply_rotary_pos_emb_chatglm(query_layer, rotary_pos_emb)
key_layer = apply_rotary_pos_emb_chatglm(key_layer, rotary_pos_emb)

if self.multi_query_attention:
key_length = key_layer.size(0)
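
The in-place note in the rope branch above relies on query_layer[..., :rot_dim] being a view of query_layer, so a kernel that writes into that view updates the parent tensor's storage. A tiny standalone illustration of that behavior (q is a made-up tensor and copy_ stands in for the ipex kernel's in-place write):

import torch

q = torch.zeros(2, 4, 8)            # stand-in for query_layer
q_rot = q[..., :4]                  # view over the first rot_dim channels
q_rot.copy_(torch.ones(2, 4, 4))    # in-place write through the view
assert torch.equal(q[..., :4], torch.ones(2, 4, 4))  # parent tensor sees the update
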
@@ -200,7 +286,6 @@ def chatglm2_attention_forward_8eb45c(
# ==================================
# core attention computation
# ==================================

context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask)

# =================