From 10e480ee96f1746392ef9baca93585435f6f9523 Mon Sep 17 00:00:00 2001 From: Yishuo Wang Date: Tue, 11 Jun 2024 14:19:19 +0800 Subject: [PATCH] refactor internlm and internlm2 (#11274) --- .../llm/src/ipex_llm/transformers/convert.py | 31 +-- .../ipex_llm/transformers/models/internlm.py | 255 ++++++++---------- 2 files changed, 125 insertions(+), 161 deletions(-) diff --git a/python/llm/src/ipex_llm/transformers/convert.py b/python/llm/src/ipex_llm/transformers/convert.py index 29e7c7cd0e2..55565a49bbb 100644 --- a/python/llm/src/ipex_llm/transformers/convert.py +++ b/python/llm/src/ipex_llm/transformers/convert.py @@ -719,6 +719,10 @@ def _optimize_pre(model): # For stablelm-zephyr-3b and stablelm-2-zephyr-1_6b from ipex_llm.transformers.models.stablelm import merge_qkv model.apply(merge_qkv) + # for internlm + if model.config.model_type == "internlm": + from ipex_llm.transformers.models.internlm import merge_qkv + model.apply(merge_qkv) # for internlm-xcomposer2-vl if model.config.model_type == "internlmxcomposer2": from ipex_llm.transformers.models.internlm import pre_process_attn_and_mlp @@ -1167,27 +1171,14 @@ def _optimize_post(model, lightweight_bmm=False): modeling_module_name = model.__class__.__module__ module = importlib.import_module(modeling_module_name) from ipex_llm.transformers.models.internlm import internlm_attention_forward + convert_forward(model, module.InternLMAttention, internlm_attention_forward) + convert_forward(model, module.InternLMRMSNorm, llama_rms_norm_forward) + elif model.config.model_type == "internlm2": + modeling_module_name = model.__class__.__module__ + module = importlib.import_module(modeling_module_name) from ipex_llm.transformers.models.internlm import internlm2_attention_forward - try: - convert_forward(model, - module.InternLM2Attention, - internlm2_attention_forward - ) - except: - convert_forward(model, - module.InternLMAttention, - internlm_attention_forward - ) - try: - convert_forward(model, - module.InternLM2RMSNorm, - llama_rms_norm_forward - ) - except: - convert_forward(model, - module.InternLMRMSNorm, - llama_rms_norm_forward - ) + convert_forward(model, module.InternLM2Attention, internlm2_attention_forward) + convert_forward(model, module.InternLM2RMSNorm, llama_rms_norm_forward) elif model.config.model_type == "internlmxcomposer2": modeling_module_name = model.model.__class__.__module__ module = importlib.import_module(modeling_module_name) diff --git a/python/llm/src/ipex_llm/transformers/models/internlm.py b/python/llm/src/ipex_llm/transformers/models/internlm.py index 7c17a9b306f..7ec7059739b 100644 --- a/python/llm/src/ipex_llm/transformers/models/internlm.py +++ b/python/llm/src/ipex_llm/transformers/models/internlm.py @@ -42,20 +42,35 @@ import torch import torch.utils.checkpoint from torch import nn -from ipex_llm.utils.common import invalidInputError -from ipex_llm.transformers.models.utils import init_kv_cache, extend_kv_cache, \ - append_kv_cache, is_enough_kv_cache_room_4_31 from ipex_llm.transformers.models.utils import should_use_fuse_rope, apply_rotary_pos_emb -from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_cache_freq_xpu from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache from ipex_llm.transformers.models.utils import update_past_key_value from ipex_llm.transformers.models.utils import use_sdp, use_sdp_causal from einops import rearrange -import os -KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256)) +def merge_qkv(module: torch.nn.Module): + if module.__class__.__name__ == "InternLMAttention": + new_weight = torch.cat([ + module.q_proj.weight.data, + module.k_proj.weight.data, + module.v_proj.weight.data, + ], dim=0) + new_bias = torch.cat([ + module.q_proj.bias.data, + module.k_proj.bias.data, + module.v_proj.bias.data, + ], dim=-1) + + qkv_proj = torch.nn.Linear(0, 0, bias=True) + qkv_proj.weight = torch.nn.Parameter(new_weight, requires_grad=False) + qkv_proj.bias = torch.nn.Parameter(new_bias, requires_grad=False) + qkv_proj.in_features = new_weight.size(1) + qkv_proj.out_features = new_weight.size(0) + module.qkv_proj = qkv_proj + + del module.q_proj, module.k_proj, module.v_proj def internlm_attention_forward( @@ -68,109 +83,69 @@ def internlm_attention_forward( use_cache: bool=False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() - device = hidden_states.device - query_states = self.q_proj(hidden_states) \ - .view(bsz, q_len, self.num_heads, self.head_dim) \ - .transpose(1, 2) - key_states = self.k_proj(hidden_states) \ - .view(bsz, q_len, self.num_heads, self.head_dim) \ - .transpose(1, 2) - value_states = self.v_proj(hidden_states) \ - .view(bsz, q_len, self.num_heads, self.head_dim) \ - .transpose(1, 2) + + qkv = self.qkv_proj(hidden_states) + qkv = qkv.view(bsz, q_len, self.num_heads * 3, self.head_dim) + qkv = qkv.transpose(1, 2) + query_states, key_states, value_states = qkv.split([self.num_heads, + self.num_heads, + self.num_heads], dim=1) kv_seq_len = key_states.shape[-2] - enough_kv_room = True if past_key_value is not None: - enough_kv_room = is_enough_kv_cache_room_4_31(past_key_value, seq_len=kv_seq_len) kv_seq_len += past_key_value[0].shape[-2] + + # IPEX-LLM OPT: fuse rope if should_use_fuse_rope(hidden_states, position_ids, self.training): - query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states, - key_states, - position_ids, - "internlm") + import xe_addons + xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids, + query_states, key_states) else: cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) query_states, key_states = apply_rotary_pos_emb( - query_states, - key_states, - cos, - sin, - position_ids, - "internlm") - # [bsz, nh, t, hd] - - if past_key_value is not None: - # reuse k, v, self_attention - cache_k = past_key_value[0] - cache_v = past_key_value[1] - if not enough_kv_room: - # allocate new - new_cache_k, new_cache_v = extend_kv_cache( - bsz, - self.num_heads, - self.head_dim, - cache_k.size(2), - kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH, - dtype=cache_k.dtype, - device=device - ) - new_cache_k[:] = cache_k - new_cache_v[:] = cache_v - cache_k = new_cache_k - cache_v = new_cache_v - - key_states, value_states = append_kv_cache(cache_k, cache_v, key_states, value_states) - - elif use_cache: - max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH - new_key_states, new_value_states = init_kv_cache( - bsz, - self.num_heads, - self.head_dim, - kv_seq_len, - max_cache_length, - dtype=key_states.dtype, - device=device + query_states, key_states, cos, sin, position_ids, "internlm" ) - new_key_states[:] = key_states - new_value_states[:] = value_states - key_states = new_key_states - value_states = new_value_states + # IPEX-LLM OPT: kv cache and quantzie kv cache + use_quantize_kv = use_quantize_kv_cache(self.qkv_proj, hidden_states) + key_states, value_states = update_past_key_value( + past_key_value, key_states, value_states, + kv_seq_len, use_quantize_kv, hidden_states.device + ) past_key_value = (key_states, value_states) if use_cache else None - attn_weights = torch.matmul(query_states, - key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + # IPEX-LLM OPT: sdp + attn_weights = None + if use_sdp(q_len, kv_seq_len, self.head_dim, query_states): + import xe_addons + if use_quantize_kv: + attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states, + attention_mask) + else: + attn_output = xe_addons.sdp(query_states, key_states, value_states, attention_mask) + elif use_sdp_causal(q_len, kv_seq_len, self.head_dim, query_states, self.training): + import xe_addons + if use_quantize_kv: + attn_output = xe_addons.sdp_fp8_causal(query_states, key_states, + value_states, attention_mask) + else: + attn_output = xe_addons.sdp_causal(query_states, key_states, + value_states, attention_mask) + else: + if use_quantize_kv: + key_states, value_states = restore_fp8_kv_cache(key_states, value_states, + query_states.dtype) + + attn_weights = torch.matmul(query_states, + key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): - invalidInputError( - False, - f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, " - f"but is {attn_weights.size()}" - ) + if attention_mask is not None: + attn_weights = attn_weights + attention_mask - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): - invalidInputError( - False, - f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, " - f"but is {attention_mask.size()}" - ) - attn_weights = attn_weights + attention_mask - attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)) - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, - dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_output = torch.matmul(attn_weights, value_states) - - if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - invalidInputError( - False, - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, " - f"but is {attn_output.size()}" - ) + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, + dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_output = torch.matmul(attn_weights, value_states) attn_output = attn_output.transpose(1, 2) attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) @@ -229,62 +204,60 @@ def internlm2_attention_forward( kv_seq_len = key_states.shape[-2] if past_key_value is not None: kv_seq_len += past_key_value[0].shape[-2] + + # IPEX-LLM OPT: fuse rope if should_use_fuse_rope(hidden_states, position_ids, self.training): - query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states, - key_states, - position_ids, - "internlm") + import xe_addons + xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids, + query_states, key_states) else: cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - # query_states, key_states = apply_rotary_pos_emb(query_states, - # key_states, cos, sin, position_ids) query_states, key_states = apply_rotary_pos_emb( - query_states, - key_states, - cos, - sin, - position_ids, - "internlm") - - if past_key_value is not None: - # reuse k, v, self_attention - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) + query_states, key_states, cos, sin, position_ids, "internlm" + ) + # IPEX-LLM OPT: kv cache and quantzie kv cache + use_quantize_kv = use_quantize_kv_cache(self.wqkv, hidden_states) + key_states, value_states = update_past_key_value( + past_key_value, key_states, value_states, + kv_seq_len, use_quantize_kv, hidden_states.device + ) past_key_value = (key_states, value_states) if use_cache else None - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) + # IPEX-LLM OPT: sdp + attn_weights = None + if use_sdp(q_len, kv_seq_len, self.head_dim, query_states): + import xe_addons + if use_quantize_kv: + attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states, + attention_mask) + else: + attn_output = xe_addons.sdp(query_states, key_states, value_states, attention_mask) + elif use_sdp_causal(q_len, kv_seq_len, self.head_dim, query_states, self.training): + import xe_addons + if use_quantize_kv: + attn_output = xe_addons.sdp_fp8_causal(query_states, key_states, + value_states, attention_mask) + else: + attn_output = xe_addons.sdp_causal(query_states, key_states, + value_states, attention_mask) + else: + if use_quantize_kv: + key_states, value_states = restore_fp8_kv_cache(key_states, value_states, + query_states.dtype) + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + attn_weights = torch.matmul(query_states, + key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): - invalidInputError( - False, - f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, " - f"but is {attn_weights.size()}" - ) + if attention_mask is not None: + attn_weights = attn_weights + attention_mask - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): - invalidInputError( - False, - f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, " - f"but is {attention_mask.size()}" - ) - attn_weights = attn_weights + attention_mask - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, - dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_output = torch.matmul(attn_weights, value_states) - - if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - invalidInputError( - False, - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, " - f"but is {attn_output.size()}" - ) + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, + dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_output = torch.matmul(attn_weights, value_states) attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)