Commit

optimize internlm-7b
MeouSker77 committed Jun 11, 2024
1 parent 9831ca2 commit 1b19454
Showing 2 changed files with 82 additions and 115 deletions.
python/llm/src/ipex_llm/transformers/convert.py (31 changes: 11 additions & 20 deletions)
@@ -719,6 +719,10 @@ def _optimize_pre(model):
        # For stablelm-zephyr-3b and stablelm-2-zephyr-1_6b
        from ipex_llm.transformers.models.stablelm import merge_qkv
        model.apply(merge_qkv)
    # for internlm
    if model.config.model_type == "internlm":
        from ipex_llm.transformers.models.internlm import merge_qkv
        model.apply(merge_qkv)
    # for internlm-xcomposer2-vl
    if model.config.model_type == "internlmxcomposer2":
        from ipex_llm.transformers.models.internlm import pre_process_attn_and_mlp
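For context: _optimize_pre runs while a checkpoint is being loaded through ipex-llm, before the weights are converted, so with this change every InternLMAttention module gets its q/k/v projections fused automatically at load time. Below is a rough usage sketch; the AutoModelForCausalLM loading path is ipex-llm's standard API, but the checkpoint id and the final check are assumptions added for illustration, not part of this commit.

# Hypothetical usage sketch: load an InternLM-7B checkpoint through ipex-llm.
# The merge_qkv pre-optimization registered above is applied inside from_pretrained();
# the caller does not need to do anything extra.
from ipex_llm.transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "internlm/internlm-chat-7b",   # assumed checkpoint id
    load_in_4bit=True,             # triggers the ipex-llm optimization passes
    trust_remote_code=True,        # InternLM ships custom modeling code
)

# Attention modules should now expose the fused projection and no longer have q_proj.
attn_modules = [m for m in model.modules()
                if m.__class__.__name__ == "InternLMAttention"]
print(all(hasattr(m, "qkv_proj") and not hasattr(m, "q_proj") for m in attn_modules))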
@@ -1167,27 +1171,14 @@ def _optimize_post(model, lightweight_bmm=False):
        modeling_module_name = model.__class__.__module__
        module = importlib.import_module(modeling_module_name)
        from ipex_llm.transformers.models.internlm import internlm_attention_forward
        convert_forward(model, module.InternLMAttention, internlm_attention_forward)
        convert_forward(model, module.InternLMRMSNorm, llama_rms_norm_forward)
    elif model.config.model_type == "internlm2":
        modeling_module_name = model.__class__.__module__
        module = importlib.import_module(modeling_module_name)
        from ipex_llm.transformers.models.internlm import internlm2_attention_forward
        try:
            convert_forward(model,
                            module.InternLM2Attention,
                            internlm2_attention_forward
                            )
        except:
            convert_forward(model,
                            module.InternLMAttention,
                            internlm_attention_forward
                            )
        try:
            convert_forward(model,
                            module.InternLM2RMSNorm,
                            llama_rms_norm_forward
                            )
        except:
            convert_forward(model,
                            module.InternLMRMSNorm,
                            llama_rms_norm_forward
                            )
        convert_forward(model, module.InternLM2Attention, internlm2_attention_forward)
        convert_forward(model, module.InternLM2RMSNorm, llama_rms_norm_forward)
    elif model.config.model_type == "internlmxcomposer2":
        modeling_module_name = model.model.__class__.__module__
        module = importlib.import_module(modeling_module_name)
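convert_forward itself is defined elsewhere in convert.py and is not part of this diff; conceptually it swaps the forward method of every matching module for the optimized implementation. A minimal sketch of that idea follows, under the assumption that per-instance rebinding is acceptable; the real helper may differ in its details.

# Minimal sketch of a convert_forward-style helper (assumption: the actual
# implementation in ipex_llm/transformers/convert.py may handle this differently).
import types
import torch

def convert_forward_sketch(model: torch.nn.Module, target_class, new_forward):
    # Rebind new_forward as a bound method on every module of the target class,
    # so all decoder layers pick up the optimized attention/RMSNorm code at once.
    for module in model.modules():
        if isinstance(module, target_class):
            module.forward = types.MethodType(new_forward, module)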
python/llm/src/ipex_llm/transformers/models/internlm.py (166 changes: 71 additions & 95 deletions)
@@ -42,20 +42,35 @@
import torch
import torch.utils.checkpoint
from torch import nn
from ipex_llm.utils.common import invalidInputError
from ipex_llm.transformers.models.utils import init_kv_cache, extend_kv_cache, \
    append_kv_cache, is_enough_kv_cache_room_4_31
from ipex_llm.transformers.models.utils import should_use_fuse_rope, apply_rotary_pos_emb
from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_cache_freq_xpu
from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache
from ipex_llm.transformers.models.utils import update_past_key_value
from ipex_llm.transformers.models.utils import use_sdp, use_sdp_causal
from einops import rearrange
import os


KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256))
def merge_qkv(module: torch.nn.Module):
    if module.__class__.__name__ == "InternLMAttention":
        new_weight = torch.cat([
            module.q_proj.weight.data,
            module.k_proj.weight.data,
            module.v_proj.weight.data,
        ], dim=0)
        new_bias = torch.cat([
            module.q_proj.bias.data,
            module.k_proj.bias.data,
            module.v_proj.bias.data,
        ], dim=-1)

        qkv_proj = torch.nn.Linear(0, 0, bias=True)
        qkv_proj.weight = torch.nn.Parameter(new_weight, requires_grad=False)
        qkv_proj.bias = torch.nn.Parameter(new_bias, requires_grad=False)
        qkv_proj.in_features = new_weight.size(1)
        qkv_proj.out_features = new_weight.size(0)
        module.qkv_proj = qkv_proj

        del module.q_proj, module.k_proj, module.v_proj
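merge_qkv stacks the three projection weights row-wise, so one matmul now yields Q, K and V concatenated along the last dimension; the rewritten forward below then splits the result per head. A quick self-contained equivalence check (the dummy module and sizes are made up for illustration; only the class name matters, since merge_qkv dispatches on __class__.__name__):

# Illustrative check that the fused projection reproduces the three original ones.
# Everything here except merge_qkv itself is hypothetical test scaffolding.
import torch
from ipex_llm.transformers.models.internlm import merge_qkv

class InternLMAttention(torch.nn.Module):  # name chosen so merge_qkv matches it
    def __init__(self, hidden_size=32):
        super().__init__()
        self.q_proj = torch.nn.Linear(hidden_size, hidden_size, bias=True)
        self.k_proj = torch.nn.Linear(hidden_size, hidden_size, bias=True)
        self.v_proj = torch.nn.Linear(hidden_size, hidden_size, bias=True)

attn = InternLMAttention()
x = torch.randn(2, 5, 32)
expected = torch.cat([attn.q_proj(x), attn.k_proj(x), attn.v_proj(x)], dim=-1)

merge_qkv(attn)  # replaces q/k/v_proj with a single fused qkv_proj
fused = attn.qkv_proj(x)
print(torch.allclose(fused, expected, atol=1e-6))  # expected: True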


def internlm_attention_forward(
@@ -68,109 +83,69 @@ def internlm_attention_forward(
    use_cache: bool=False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    bsz, q_len, _ = hidden_states.size()
    device = hidden_states.device
    query_states = self.q_proj(hidden_states) \
        .view(bsz, q_len, self.num_heads, self.head_dim) \
        .transpose(1, 2)
    key_states = self.k_proj(hidden_states) \
        .view(bsz, q_len, self.num_heads, self.head_dim) \
        .transpose(1, 2)
    value_states = self.v_proj(hidden_states) \
        .view(bsz, q_len, self.num_heads, self.head_dim) \
        .transpose(1, 2)

    qkv = self.qkv_proj(hidden_states)
    qkv = qkv.view(bsz, q_len, self.num_heads * 3, self.head_dim)
    qkv = qkv.transpose(1, 2)
    query_states, key_states, value_states = qkv.split([self.num_heads,
                                                        self.num_heads,
                                                        self.num_heads], dim=1)

    kv_seq_len = key_states.shape[-2]
    enough_kv_room = True
    if past_key_value is not None:
        enough_kv_room = is_enough_kv_cache_room_4_31(past_key_value, seq_len=kv_seq_len)
        kv_seq_len += past_key_value[0].shape[-2]

    # IPEX-LLM OPT: fuse rope
    if should_use_fuse_rope(hidden_states, position_ids, self.training):
        query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
                                                                     key_states,
                                                                     position_ids,
                                                                     "internlm")
        import xe_addons
        xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids,
                                       query_states, key_states)
    else:
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        query_states, key_states = apply_rotary_pos_emb(
            query_states,
            key_states,
            cos,
            sin,
            position_ids,
            "internlm")
    # [bsz, nh, t, hd]

    if past_key_value is not None:
        # reuse k, v, self_attention
        cache_k = past_key_value[0]
        cache_v = past_key_value[1]
        if not enough_kv_room:
            # allocate new
            new_cache_k, new_cache_v = extend_kv_cache(
                bsz,
                self.num_heads,
                self.head_dim,
                cache_k.size(2),
                kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
                dtype=cache_k.dtype,
                device=device
            )
            new_cache_k[:] = cache_k
            new_cache_v[:] = cache_v
            cache_k = new_cache_k
            cache_v = new_cache_v

        key_states, value_states = append_kv_cache(cache_k, cache_v, key_states, value_states)

    elif use_cache:
        max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
        new_key_states, new_value_states = init_kv_cache(
            bsz,
            self.num_heads,
            self.head_dim,
            kv_seq_len,
            max_cache_length,
            dtype=key_states.dtype,
            device=device
            query_states, key_states, cos, sin, position_ids, "internlm"
        )
        new_key_states[:] = key_states
        new_value_states[:] = value_states
        key_states = new_key_states
        value_states = new_value_states

    # IPEX-LLM OPT: kv cache and quantzie kv cache
    use_quantize_kv = use_quantize_kv_cache(self.qkv_proj, hidden_states)
    key_states, value_states = update_past_key_value(
        past_key_value, key_states, value_states,
        kv_seq_len, use_quantize_kv, hidden_states.device
    )
    past_key_value = (key_states, value_states) if use_cache else None

    attn_weights = torch.matmul(query_states,
                                key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
    # IPEX-LLM OPT: sdp
    attn_weights = None
    if use_sdp(q_len, kv_seq_len, self.head_dim, query_states):
        import xe_addons
        if use_quantize_kv:
            attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states,
                                            attention_mask)
        else:
            attn_output = xe_addons.sdp(query_states, key_states, value_states, attention_mask)
    elif use_sdp_causal(q_len, kv_seq_len, self.head_dim, query_states, self.training):
        import xe_addons
        if use_quantize_kv:
            attn_output = xe_addons.sdp_fp8_causal(query_states, key_states,
                                                   value_states, attention_mask)
        else:
            attn_output = xe_addons.sdp_causal(query_states, key_states,
                                               value_states, attention_mask)
    else:
        if use_quantize_kv:
            key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
                                                            query_states.dtype)

        attn_weights = torch.matmul(query_states,
                                    key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

    if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
        invalidInputError(
            False,
            f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, "
            f"but is {attn_weights.size()}"
        )
        if attention_mask is not None:
            attn_weights = attn_weights + attention_mask

    if attention_mask is not None:
        if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
            invalidInputError(
                False,
                f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, "
                f"but is {attention_mask.size()}"
            )
        attn_weights = attn_weights + attention_mask
        attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))

    # upcast attention to fp32
    attn_weights = nn.functional.softmax(attn_weights,
                                         dim=-1, dtype=torch.float32).to(query_states.dtype)
    attn_output = torch.matmul(attn_weights, value_states)

    if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
        invalidInputError(
            False,
            f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, "
            f"but is {attn_output.size()}"
        )
        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights,
                                             dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_output = torch.matmul(attn_weights, value_states)

    attn_output = attn_output.transpose(1, 2)
    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
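The rewritten forward prefers the fused xe_addons SDP kernels and only falls back to the eager matmul/softmax path when neither use_sdp nor use_sdp_causal applies. The xe_addons kernels are Intel-GPU specific; as a rough, purely illustrative stand-in for the same dispatch shape (not ipex-llm code), PyTorch's built-in scaled_dot_product_attention can play the role of the fused path:

# Illustration only: mimics the "fused kernel if possible, otherwise eager" structure
# above using stock PyTorch; the real code calls xe_addons.sdp / sdp_causal on XPU.
import math
import torch
import torch.nn.functional as F

def attention_with_fallback(q, k, v, attn_mask=None, prefer_fused=True):
    # q, k, v: [bsz, num_heads, seq_len, head_dim]
    if prefer_fused:
        return F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
    # eager path, matching the matmul/softmax branch in internlm_attention_forward
    weights = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(q.size(-1))
    if attn_mask is not None:
        weights = weights + attn_mask
    weights = F.softmax(weights, dim=-1, dtype=torch.float32).to(q.dtype)
    return torch.matmul(weights, v)

q = k = v = torch.randn(1, 8, 16, 64)
out_fused = attention_with_fallback(q, k, v, prefer_fused=True)
out_eager = attention_with_fallback(q, k, v, prefer_fused=False)
print(torch.allclose(out_fused, out_eager, atol=1e-5))  # expected: True (up to numerics)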
@@ -250,6 +225,7 @@ def internlm2_attention_forward(
    past_key_value = (key_states, value_states) if use_cache else None

    # IPEX-LLM OPT: sdp
    attn_weights = None
    if use_sdp(q_len, kv_seq_len, self.head_dim, query_states):
        import xe_addons
        if use_quantize_kv:
