From a4a758656aaf5fb1a52452f77c7ca2f836995465 Mon Sep 17 00:00:00 2001 From: Yishuo Wang Date: Wed, 16 Oct 2024 17:40:28 +0800 Subject: [PATCH] refactor gemma to reduce old fuse rope usage (#12215) --- .../llm/src/ipex_llm/transformers/convert.py | 34 +- .../src/ipex_llm/transformers/models/gemma.py | 454 ++++++------------ 2 files changed, 145 insertions(+), 343 deletions(-) diff --git a/python/llm/src/ipex_llm/transformers/convert.py b/python/llm/src/ipex_llm/transformers/convert.py index 28b927b2d5b..17c7978e9e9 100644 --- a/python/llm/src/ipex_llm/transformers/convert.py +++ b/python/llm/src/ipex_llm/transformers/convert.py @@ -1017,6 +1017,10 @@ def _optimize_pre(model, qtype=None): model.apply(pre_process_attn_and_mlp) if model.config.model_type == "internvl_chat": _optimize_pre(model.language_model, qtype=qtype) + if model.config.model_type == "gemma": + from ipex_llm.transformers.models.gemma import merge_qkv, pre_compute_inv_freq + model.apply(merge_qkv) + model.apply(pre_compute_inv_freq) if model.config.model_type == "gemma2": from ipex_llm.transformers.models.gemma2 import merge_qkv model.apply(merge_qkv) @@ -1846,32 +1850,16 @@ def _optimize_post(model, lightweight_bmm=False): module.MistralMLP, llama_mlp_forward) elif model.config.model_type == "gemma": - invalidInputError(version.parse(trans_version) >= version.parse("4.38.0"), - "Please upgrade transformers to 4.38.0 or higher version " - "to run Mixtral models.") modeling_module_name = model.__class__.__module__ module = importlib.import_module(modeling_module_name) - if version.parse(trans_version) >= version.parse("4.39.0"): - from ipex_llm.transformers.models.gemma import gemma_attention_forward_4_39 - convert_forward(model, - module.GemmaAttention, - gemma_attention_forward_4_39 - ) - else: - from ipex_llm.transformers.models.gemma import gemma_attention_forward - convert_forward(model, - module.GemmaAttention, - gemma_attention_forward, - ) + from ipex_llm.transformers.models.gemma import gemma_model_forward + from ipex_llm.transformers.models.gemma import gemma_attention_forward from ipex_llm.transformers.models.gemma import gemma_rms_norm_forward - from ipex_llm.transformers.models.gemma import gemma_mlp_forward - convert_forward(model, - module.GemmaRMSNorm, - gemma_rms_norm_forward) - convert_forward(model, - module.GemmaMLP, - gemma_mlp_forward) - + from ipex_llm.transformers.models.common import mlp_gelu_forward + convert_forward(model, module.GemmaModel, gemma_model_forward) + convert_forward(model, module.GemmaAttention, gemma_attention_forward) + convert_forward(model, module.GemmaRMSNorm, gemma_rms_norm_forward) + convert_forward(model, module.GemmaMLP, mlp_gelu_forward) elif model.config.model_type == "gemma2": modeling_module_name = model.__class__.__module__ module = importlib.import_module(modeling_module_name) diff --git a/python/llm/src/ipex_llm/transformers/models/gemma.py b/python/llm/src/ipex_llm/transformers/models/gemma.py index f542a4290b3..1731266b27d 100644 --- a/python/llm/src/ipex_llm/transformers/models/gemma.py +++ b/python/llm/src/ipex_llm/transformers/models/gemma.py @@ -31,50 +31,31 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import math -from typing import Optional, Tuple +from typing import Optional, Tuple, Union import torch from torch import nn from ipex_llm.utils.common import invalidInputError -from ipex_llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache -from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_cache_freq_xpu -from ipex_llm.transformers.models.utils import mlp_fusion_check, GELU -from ipex_llm.transformers.models.utils import is_enough_kv_cache_room_4_36, rotate_half -from ipex_llm.transformers.low_bit_linear import SYM_INT4, FP8E5 -from ipex_llm.transformers.models.utils import use_decoding_fast_path +from ipex_llm.transformers.kv import DynamicNormalCache +from ipex_llm.transformers.models.common import merge_qkv_base, attention_softmax +from ipex_llm.transformers.models.utils import should_use_fuse_rope -import os +from transformers.cache_utils import Cache +from transformers.modeling_outputs import BaseModelOutputWithPast +from transformers.models.gemma.modeling_gemma import apply_rotary_pos_emb, repeat_kv +from transformers.models.gemma.modeling_gemma import GemmaRotaryEmbedding, GemmaAttention -KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256)) +def merge_qkv(module: torch.nn.Module): + merge_qkv_base(module, GemmaAttention) -def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): - cos = cos.unsqueeze(unsqueeze_dim) - sin = sin.unsqueeze(unsqueeze_dim) - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - return q_embed, k_embed - -def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: - """ - This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). - The hidden states go from (batch, num_key_value_heads, seqlen, head_dim) - to (batch, num_attention_heads, seqlen, head_dim) - """ - batch, num_key_value_heads, slen, head_dim = hidden_states.shape - if n_rep == 1: - return hidden_states - hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, - n_rep, slen, head_dim) - return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) - - -def should_use_fuse_rope(self, hidden_states, position_ids): - use_fuse_rope = hidden_states.device.type == "xpu" - use_fuse_rope = use_fuse_rope and not (self.training and hidden_states.requires_grad) - use_fuse_rope = use_fuse_rope and position_ids is not None - return use_fuse_rope +def pre_compute_inv_freq(module: torch.nn.Module): + if isinstance(module, GemmaRotaryEmbedding): + module.inv_freq = 1.0 / ( + module.base ** + (torch.arange(0, module.dim, 2, dtype=torch.int64).float() / module.dim) + ) def gemma_rms_norm_forward(self, hidden_states): @@ -91,185 +72,113 @@ def gemma_rms_norm_forward(self, hidden_states): return (1 + self.weight) * hidden_states.to(input_dtype) -def gemma_mlp_forward( +def gemma_model_forward( self, - x: torch.Tensor, - residual=None -) -> torch.Tensor: - x_2d = x.view(-1, x.shape[-1]) - bsz, hidden_size = x_2d.shape - qtype = getattr(self.gate_proj, "qtype", None) - if mlp_fusion_check(x_2d, qtype, self.training) and not self.down_proj.enable_xetla: - import xe_linear - if not x_2d.is_contiguous(): - x_2d = x_2d.contiguous() - out = self.down_proj(xe_linear.mlp_forward_xpu( - x_2d, self.gate_proj.weight.data, self.up_proj.weight.data, - x_2d.shape[0], x_2d.shape[1], self.gate_proj.out_len, - GELU, qtype - )) - else: - out = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) - if 
residual is not None: - return out + residual - else: - return out - - -def gemma_attention_forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor]=None, - position_ids: Optional[torch.LongTensor]=None, - past_key_value: Optional[Tuple[torch.Tensor]]=None, - output_attentions: bool=False, - use_cache: bool=False, - cache_position: Optional[torch.Tensor]=None, -) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - bsz, q_len, hidden_size = hidden_states.size() - device = hidden_states.device - # for flash attention - original_dtype = hidden_states.dtype - - use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids) - enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx) - decoding_fast_path = use_decoding_fast_path(self.q_proj, - use_fuse_rope, - enough_kv_room, - bsz * q_len) - - if decoding_fast_path: - hidden_states = hidden_states.view(1, -1) - - cache_k = past_key_value.key_cache[self.layer_idx] - cache_v = past_key_value.value_cache[self.layer_idx] - - kv_seq_len = cache_k.shape[-2] - - import xe_linear - query_states, key_states, value_states = xe_linear.forward_qkv(hidden_states, - self.q_proj.weight, - self.k_proj.weight, - self.v_proj.weight, - position_ids, - cache_k, cache_v, - self.q_proj.weight.qtype, - self.v_proj.weight.qtype, - kv_seq_len, - self.head_dim) - kv_seq_len += 1 - - # update past_key_value's seem_tokens and kv caches. - if self.layer_idx == 0: - past_key_value.seen_tokens = kv_seq_len - past_key_value.key_cache[self.layer_idx] = key_states - past_key_value.value_cache[self.layer_idx] = value_states - - else: - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, - self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, - self.num_key_value_heads, self.head_dim).transpose(1, 2) - - kv_seq_len = key_states.shape[-2] - - if past_key_value is not None: - if self.layer_idx is None: - invalidInputError(False, - "The cache structure has changed since version v4.36. 
" - f"If you are using {self.__class__.__name__} for " - "auto-regressive decodingwith k/v caching, please make sure " - "to initialize the attention class with a layer index.") - kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - - if use_fuse_rope: - cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None) - query_states, key_states = apply_rotary_pos_emb_cache_freq_xpu(query_states, key_states, - sin, cos, "gemma") - else: - cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, - cos, sin, None) - - if past_key_value is not None: - # update the number of seen tokens - if self.layer_idx == 0: - past_key_value.seen_tokens += key_states.shape[-2] - - # reuse k, v, self_attention - # update `past_key_value` with `key_states` and `value_states` for layer `layer_idx` - if len(past_key_value.key_cache) <= self.layer_idx: - past_key_value.key_cache.append(key_states) - past_key_value.value_cache.append(value_states) - else: - cache_k = past_key_value.key_cache[self.layer_idx] - cache_v = past_key_value.value_cache[self.layer_idx] - - if not enough_kv_room: - # allocate new - new_c_k, new_c_v = extend_kv_cache(bsz, - self.num_key_value_heads, # Support GQA - self.head_dim, - cache_k.size(2), - kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH, - dtype=cache_k.dtype, - device=device) - - new_c_k[:] = cache_k - new_c_v[:] = cache_v - cache_k = new_c_k - cache_v = new_c_v - - key_states, value_states = append_kv_cache(cache_k, cache_v, - key_states, value_states) - - # update past_key_value - past_key_value.key_cache[self.layer_idx] = key_states - past_key_value.value_cache[self.layer_idx] = value_states + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, +) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = ( + output_attentions if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None + else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + # IPEX-LLM OPT start: kv cache and quantize kv cache + if use_cache and not isinstance(past_key_values, DynamicNormalCache): + past_key_values = DynamicNormalCache.from_legacy_cache(past_key_values) + # IPEX-LLM OPT end + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + invalidInputError((input_ids is None) ^ (inputs_embeds is None), + "You cannot specify both input_ids and inputs_embeds at the same time, " + "and must specify either one") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) - # repeat k/v heads if n_kv_heads < n_heads - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, 
self.num_key_value_groups) - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + # IPEX-LLM changes start: support both transformers 4.38.1 and 4.39 + try: + causal_mask = self._update_causal_mask(attention_mask, inputs_embeds) + causal_mask = causal_mask[:, :, cache_position, :] + except TypeError as _e: + causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position) + # IPEX-LLM changes end + + # embed positions + hidden_states = inputs_embeds + + # normalized + hidden_states = hidden_states * (self.config.hidden_size**0.5) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) - if attention_mask is not None: # no matter the length, we just slice it - if cache_position is not None: - causal_mask = attention_mask[:, :, cache_position, : key_states.shape[-2]] - else: - causal_mask = attention_mask - attn_weights = attn_weights + causal_mask + hidden_states = layer_outputs[0] - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, - dtype=torch.float32).to(query_states.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, - training=self.training) - attn_output = torch.matmul(attn_weights, value_states) + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] - if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - invalidInputError( - False, - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) + if output_attentions: + all_self_attns += (layer_outputs[1],) - attn_output = attn_output.transpose(1, 2).contiguous() + hidden_states = self.norm(hidden_states) - attn_output = attn_output.view(bsz, q_len, -1) - attn_output = self.o_proj(attn_output) + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) - if not output_attentions: - attn_weights = None + next_cache = next_decoder_cache if use_cache else None - return attn_output.to(original_dtype), attn_weights, past_key_value + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] + if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) -def gemma_attention_forward_4_39( +def gemma_attention_forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor]=None, @@ -279,111 +188,27 @@ def gemma_attention_forward_4_39( use_cache: bool=False, cache_position: Optional[torch.Tensor]=None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - bsz, q_len, hidden_size = hidden_states.size() - device = hidden_states.device - # for flash attention - original_dtype = hidden_states.dtype - - use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids) - enough_kv_room = 
is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx) - decoding_fast_path = use_decoding_fast_path(self.q_proj, - use_fuse_rope, - enough_kv_room, - bsz * q_len) - - if decoding_fast_path: - hidden_states = hidden_states.view(1, -1) - - cache_k = past_key_value.key_cache[self.layer_idx] - cache_v = past_key_value.value_cache[self.layer_idx] - - kv_seq_len = cache_k.shape[-2] - - import xe_linear - query_states, key_states, value_states = xe_linear.forward_qkv(hidden_states, - self.q_proj.weight, - self.k_proj.weight, - self.v_proj.weight, - position_ids, - cache_k, cache_v, - self.q_proj.weight.qtype, - self.v_proj.weight.qtype, - kv_seq_len, - self.head_dim) - kv_seq_len += 1 - - # update past_key_value's seem_tokens and kv caches. - if self.layer_idx == 0: - past_key_value._seen_tokens = kv_seq_len - past_key_value.key_cache[self.layer_idx] = key_states - past_key_value.value_cache[self.layer_idx] = value_states + bsz, q_len, _ = hidden_states.size() + qkv = self.qkv_proj(hidden_states) + qkv = qkv.view(bsz, q_len, self.num_heads + 2 * self.num_key_value_heads, self.head_dim) + qkv = qkv.transpose(1, 2) + query_states, key_states, value_states = qkv.split([self.num_heads, + self.num_key_value_heads, + self.num_key_value_heads], dim=1) + + if should_use_fuse_rope(hidden_states, position_ids, self.training): + import xe_addons + xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids, + query_states, key_states) else: - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, - self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, - self.num_key_value_heads, self.head_dim).transpose(1, 2) - - kv_seq_len = key_states.shape[-2] - - if past_key_value is not None: - if self.layer_idx is None: - invalidInputError(False, - "The cache structure has changed since version v4.36. 
" - f"If you are using {self.__class__.__name__} for " - "auto-regressive decodingwith k/v caching, please make sure " - "to initialize the attention class with a layer index.") - kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - - if use_fuse_rope: - cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None) - query_states, key_states = apply_rotary_pos_emb_cache_freq_xpu(query_states, key_states, - sin, cos, "gemma") - else: - cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, - cos, sin, None) - - if past_key_value is not None: - # update the number of seen tokens - if self.layer_idx == 0: - past_key_value._seen_tokens += key_states.shape[-2] - - # reuse k, v, self_attention - # update `past_key_value` with `key_states` and `value_states` for layer `layer_idx` - if len(past_key_value.key_cache) <= self.layer_idx: - past_key_value.key_cache.append(key_states) - past_key_value.value_cache.append(value_states) - else: - cache_k = past_key_value.key_cache[self.layer_idx] - cache_v = past_key_value.value_cache[self.layer_idx] - - if not enough_kv_room: - # allocate new - new_c_k, new_c_v = extend_kv_cache(bsz, - self.num_key_value_heads, # Support GQA - self.head_dim, - cache_k.size(2), - kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH, - dtype=cache_k.dtype, - device=device) - - new_c_k[:] = cache_k - new_c_v[:] = cache_v - cache_k = new_c_k - cache_v = new_c_v - - key_states, value_states = append_kv_cache(cache_k, cache_v, - key_states, value_states) - - # update past_key_value - past_key_value.key_cache[self.layer_idx] = key_states - past_key_value.value_cache[self.layer_idx] = value_states + cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, + cos, sin, None) + + if past_key_value is not None: + key_states, value_states = past_key_value.update(key_states, value_states, + self.layer_idx, None) # repeat k/v heads if n_kv_heads < n_heads key_states = repeat_kv(key_states, self.num_key_value_groups) @@ -391,26 +216,15 @@ def gemma_attention_forward_4_39( attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) if attention_mask is not None: # no matter the length, we just slice it - if cache_position is not None: - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - else: - causal_mask = attention_mask + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] attn_weights = attn_weights + causal_mask # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, - dtype=torch.float32).to(query_states.dtype) + attn_weights = attention_softmax(attn_weights, self.training) attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) attn_output = torch.matmul(attn_weights, value_states) - if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - invalidInputError( - False, - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.view(bsz, q_len, -1) @@ -419,4 +233,4 @@ def gemma_attention_forward_4_39( if not output_attentions: attn_weights = None - return attn_output.to(original_dtype), attn_weights, past_key_value + return attn_output, attn_weights, past_key_value