From 9942a4ba6920bc1cb6248b55afaf4ea44d7767c5 Mon Sep 17 00:00:00 2001
From: SONG Ge <38711238+sgwhat@users.noreply.github.com>
Date: Wed, 15 May 2024 18:07:00 +0800
Subject: [PATCH] [WIP] Support llama2 with transformers==4.38.0 (#11024)

* support llama2 with transformers==4.38.0

* add support for quantize_qkv

* add original support for 4.38.0 now

* code style fix
---
 .../llm/src/ipex_llm/transformers/convert.py  |  26 ++-
 .../src/ipex_llm/transformers/models/llama.py | 155 +++++++++++-------
 .../src/ipex_llm/transformers/models/utils.py |   6 +
 3 files changed, 123 insertions(+), 64 deletions(-)

diff --git a/python/llm/src/ipex_llm/transformers/convert.py b/python/llm/src/ipex_llm/transformers/convert.py
index c1a3d5e0fc4..e58da1a3918 100644
--- a/python/llm/src/ipex_llm/transformers/convert.py
+++ b/python/llm/src/ipex_llm/transformers/convert.py
@@ -961,16 +961,24 @@ def _optimize_post(model, lightweight_bmm=False):
                             llama_decoder_forward)
         if version.parse(trans_version) >= version.parse("4.36.0"):
             # transformers version >= 4.36.0
-            from ipex_llm.transformers.models.llama import llama_attention_forward_4_36
+            from ipex_llm.transformers.models.llama import llama_attention_forward_4_38
             from ipex_llm.transformers.models.llama import llama_model_forward_4_36
-            convert_forward(
-                model,
-                transformers.models.llama.modeling_llama.LlamaAttention,
-                llama_attention_forward_4_36, )
-            convert_forward(
-                model,
-                transformers.models.llama.modeling_llama.LlamaModel,
-                llama_model_forward_4_36)
+            if version.parse(trans_version) >= version.parse("4.38.0"):
+                from ipex_llm.transformers.models.llama import llama_attention_forward_4_38_original
+                # Todo: support llama_model_forward with transformers version >= 4.38.0
+                convert_forward(
+                    model,
+                    transformers.models.llama.modeling_llama.LlamaAttention,
+                    llama_attention_forward_4_38_original)
+            else:
+                convert_forward(
+                    model,
+                    transformers.models.llama.modeling_llama.LlamaModel,
+                    llama_model_forward_4_36)
+                convert_forward(
+                    model,
+                    transformers.models.llama.modeling_llama.LlamaAttention,
+                    llama_attention_forward_4_38)
         else:
             # transformers version between 4.31.0 - 4.35.2
             convert_forward(
diff --git a/python/llm/src/ipex_llm/transformers/models/llama.py b/python/llm/src/ipex_llm/transformers/models/llama.py
index 45c0c4f4b7b..92333c34ef7 100644
--- a/python/llm/src/ipex_llm/transformers/models/llama.py
+++ b/python/llm/src/ipex_llm/transformers/models/llama.py
@@ -333,6 +333,7 @@ def llama_attention_forward_4_31(
     output_attentions: bool = False,
     use_cache: bool = False,
     padding_mask: Optional[torch.LongTensor] = None,
+    cache_position: Optional[torch.LongTensor] = None,
     **kwargs,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
     if use_quantize_kv_cache(self.q_proj, hidden_states):
@@ -348,6 +349,7 @@ def llama_attention_forward_4_31(
         output_attentions=output_attentions,
         use_cache=use_cache,
         padding_mask=padding_mask,
+        cache_position=cache_position,
         kwargs=kwargs
     )
 
@@ -361,6 +363,7 @@ def llama_attention_forward_4_31_quantized(
     output_attentions: bool = False,
     use_cache: bool = False,
     padding_mask: Optional[torch.LongTensor] = None,
+    cache_position: Optional[torch.LongTensor] = None,
     **kwargs,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
     bsz, q_len, hidden_size = hidden_states.size()
@@ -437,7 +440,8 @@ def llama_attention_forward_4_31_quantized(
         repeated_key_states = repeat_kv(key_states, self.num_key_value_groups)
         repeated_value_states = repeat_kv(value_states, self.num_key_value_groups)
         attn_output, attn_weights = native_sdp(query_states, repeated_key_states,
-                                               repeated_value_states, attention_mask,
+                                               repeated_value_states,
+                                               attention_mask, cache_position,
                                                bsz, q_len, kv_seq_len,
                                                self.head_dim, self.num_heads, output_attentions)
         if use_cache:
@@ -462,7 +466,7 @@ def llama_attention_forward_4_31_quantized(
         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)
         attn_output, attn_weights = native_sdp(query_states, key_states, value_states,
-                                               attention_mask,
+                                               attention_mask, cache_position,
                                                bsz, q_len, kv_seq_len,
                                                self.head_dim, self.num_heads, output_attentions)
     else:
@@ -498,6 +502,7 @@ def llama_attention_forward_4_31_original(
     output_attentions: bool = False,
     use_cache: bool = False,
     padding_mask: Optional[torch.LongTensor] = None,
+    cache_position: Optional[torch.LongTensor] = None,
     **kwargs,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
     bsz, q_len, hidden_size = hidden_states.size()
@@ -683,7 +688,7 @@ def llama_attention_forward_4_31_original(
         value_states = repeat_kv(value_states, self.num_key_value_groups)
         # otherwise, use native attention
         attn_output, attn_weights = native_sdp(query_states, key_states, value_states,
-                                               attention_mask,
+                                               attention_mask, cache_position,
                                                bsz, q_len, kv_seq_len,
                                                self.head_dim, self.num_heads, output_attentions)
         attn_output_size = (bsz, self.num_heads, q_len, self.head_dim)
@@ -919,20 +924,21 @@ def llama_attention_selective_batching_forward_4_31(
     return attn_output.to(original_dtype), attn_weights, updated_past_key_values
 
 
-def llama_attention_forward_4_36(
+def llama_attention_forward_4_38(
     self,
     hidden_states: torch.Tensor,
     attention_mask: Optional[torch.Tensor] = None,
     position_ids: Optional[torch.LongTensor] = None,
-    past_key_value: Optional[Cache] = None,
+    past_key_value: Optional[List[torch.FloatTensor]] = None,
     output_attentions: bool = False,
     use_cache: bool = False,
+    cache_position: Optional[torch.LongTensor] = None,
     **kwargs
-) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[List[torch.FloatTensor]]]:
     if use_quantize_kv_cache(self.q_proj, hidden_states):
-        forward_function = llama_attention_forward_4_36_quantized
+        forward_function = llama_attention_forward_4_38_quantized
     else:
-        forward_function = llama_attention_forward_4_36_original
+        forward_function = llama_attention_forward_4_38_original
     return forward_function(
         self=self,
         hidden_states=hidden_states,
@@ -941,20 +947,22 @@ def llama_attention_forward_4_36(
         past_key_value=past_key_value,
         output_attentions=output_attentions,
         use_cache=use_cache,
+        cache_position=cache_position,
         kwargs=kwargs
     )
 
 
-def llama_attention_forward_4_36_quantized(
+def llama_attention_forward_4_38_quantized(
     self,
     hidden_states: torch.Tensor,
     attention_mask: Optional[torch.Tensor] = None,
     position_ids: Optional[torch.LongTensor] = None,
-    past_key_value: Optional[Cache] = None,
+    past_key_value: Optional[List[torch.FloatTensor]] = None,
     output_attentions: bool = False,
     use_cache: bool = False,
+    cache_position: Optional[torch.LongTensor] = None,
     **kwargs
-) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[List[torch.FloatTensor]]]:
     if "padding_mask" in kwargs:
         warnings.warn(
             "Passing `padding_mask` is deprecated and will be removed in v4.37. "
@@ -1026,9 +1034,15 @@ def llama_attention_forward_4_36_quantized(
                                                                      "llama",
                                                                      rope_theta=rope_theta)
     else:
-        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-        query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
-                                                        cos, sin, position_ids, "llama")
+        if cache_position is not None:
+            # for transformers 4.38.0
+            cos, sin = self.rotary_emb(value_states, position_ids)
+            query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
+                                                            cos, sin, position_ids, "llama2")
+        else:
+            cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+            query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
+                                                            cos, sin, position_ids, "llama")
 
     kv_seq_len = key_states.shape[-2]
     if len(past_key_value.key_cache) <= self.layer_idx:
@@ -1037,7 +1051,8 @@ def llama_attention_forward_4_36_quantized(
         if should_split_qkv_tensor(query_states, bsz, self.num_heads,
                                    q_len, kv_seq_len, output_attentions):
             attn_output, _ = native_sdp_split_qkv_tensor(query_states, repeated_key_states,
-                                                         repeated_value_states, attention_mask,
+                                                         repeated_value_states,
+                                                         attention_mask, cache_position,
                                                          bsz, q_len, kv_seq_len,
                                                          self.head_dim, self.num_heads)
         else:
@@ -1053,13 +1068,17 @@ def llama_attention_forward_4_36_quantized(
             )
 
         if attention_mask is not None:
-            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
-                invalidInputError(
-                    False,
-                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)},"
-                    f" but is {attention_mask.size()}"
-                )
-            attn_weights = attn_weights + attention_mask
+            if cache_position is not None:
+                # for transformers 4.38.0
+                causal_mask = attention_mask[:, :, cache_position, : kv_seq_len]
+                attn_weights = attn_weights + causal_mask
+            else:
+                attn_mask_size = (bsz, 1, q_len, kv_seq_len)
+                if attention_mask.size() != attn_mask_size:
+                    invalidInputError(False,
+                                      f"Attention mask should be of size {attn_mask_size}, "
+                                      f"but is {attention_mask.size()}")
+                attn_weights = attn_weights + attention_mask
 
         if kv_seq_len >= 2048 or bsz >= 64:
             # for memory considerations, do not upcast attention to fp32
@@ -1097,13 +1116,17 @@ def llama_attention_forward_4_36_quantized(
             )
 
         if attention_mask is not None:
-            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
-                invalidInputError(
-                    False,
-                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)},"
-                    f" but is {attention_mask.size()}"
-                )
-            attn_weights = attn_weights + attention_mask
+            if cache_position is not None:
+                # for transformers 4.38.0
+                causal_mask = attention_mask[:, :, cache_position, : kv_seq_len]
+                attn_weights = attn_weights + causal_mask
+            else:
+                attn_mask_size = (bsz, 1, q_len, kv_seq_len)
+                if attention_mask.size() != attn_mask_size:
+                    invalidInputError(False,
+                                      f"Attention mask should be of size {attn_mask_size}, "
+                                      f"but is {attention_mask.size()}")
+                attn_weights = attn_weights + attention_mask
 
         if kv_seq_len >= 2048 or bsz >= 64:
             # for memory considerations, do not upcast attention to fp32
@@ -1146,16 +1169,17 @@ def llama_attention_forward_4_36_quantized(
     return attn_output, attn_weights, past_key_value
 
 
-def llama_attention_forward_4_36_original(
+def llama_attention_forward_4_38_original(
     self,
     hidden_states: torch.Tensor,
     attention_mask: Optional[torch.Tensor] = None,
     position_ids: Optional[torch.LongTensor] = None,
-    past_key_value: Optional[Cache] = None,
+    past_key_value: Optional[List[torch.FloatTensor]] = None,
     output_attentions: bool = False,
     use_cache: bool = False,
+    cache_position: Optional[torch.LongTensor] = None,
     **kwargs
-) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[List[torch.FloatTensor]]]:
     if "padding_mask" in kwargs:
         warnings.warn(
             "Passing `padding_mask` is deprecated and will be removed in v4.37. "
@@ -1293,9 +1317,15 @@ def llama_attention_forward_4_36_original(
                                                                      "llama",
                                                                      rope_theta=rope_theta)
     else:
-        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-        query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
-                                                        cos, sin, position_ids, "llama")
+        if cache_position is not None:
+            # for transformers 4.38.0
+            cos, sin = self.rotary_emb(value_states, position_ids)
+            query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
+                                                            cos, sin, position_ids, "llama2")
+        else:
+            cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+            query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
+                                                            cos, sin, position_ids, "llama")
 
     if past_key_value is not None:
         # update the number of seen tokens
@@ -1335,8 +1365,13 @@ def llama_attention_forward_4_36_original(
             past_key_value.key_cache[self.layer_idx] = key_states
             past_key_value.value_cache[self.layer_idx] = value_states
 
+    if cache_position is not None:
+        new_attention_mask = attention_mask[:, :, kv_seq_len - q_len:kv_seq_len, 0:kv_seq_len]
+    else:
+        new_attention_mask = attention_mask
+
     if not self.training and not hidden_states.requires_grad and \
-            use_flash_attention(query_states, key_states, attention_mask):
+            use_flash_attention(query_states, key_states, new_attention_mask):
         # repeat k/v heads if n_kv_heads < n_heads
         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)
@@ -1349,7 +1384,7 @@ def llama_attention_forward_4_36_original(
     elif not self.training and not hidden_states.requires_grad and \
             use_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
         import linear_q4_0
-        attn_output = linear_q4_0.sdp(query_states, key_states, value_states, attention_mask)
+        attn_output = linear_q4_0.sdp(query_states, key_states, value_states, new_attention_mask)
         attn_output = attn_output.view(query_states.shape)
         attn_weights = None
     else:
@@ -1359,7 +1394,7 @@ def llama_attention_forward_4_36_original(
         # otherwise, use native attention
         if query_states.device.type == "xpu":
             attn_output, attn_weights = native_sdp(query_states, key_states, value_states,
-                                                   attention_mask,
+                                                   new_attention_mask, cache_position,
                                                    bsz, q_len, kv_seq_len,
                                                    self.head_dim, self.num_heads, output_attentions)
         else:
@@ -1369,16 +1404,16 @@ def llama_attention_forward_4_36_original(
                 query_states,
                 key_states,
                 value_states,
-                attn_mask=attention_mask,
+                attn_mask=new_attention_mask,
                 dropout_p=self.attention_dropout if self.training else 0.0,
                 # The q_len > 1 is necessary to match with
                 # AttentionMaskConverter.to_causal_4d that
                 # does not create a causal mask in case q_len == 1.
-                is_causal=self.is_causal and attention_mask is None and q_len > 1,
+                is_causal=self.is_causal and new_attention_mask is None and q_len > 1,
             )
         else:
             attn_output, attn_weights = native_sdp(query_states, key_states, value_states,
-                                                   attention_mask,
+                                                   new_attention_mask, cache_position,
                                                    bsz, q_len, kv_seq_len,
                                                    self.head_dim, self.num_heads,
                                                    output_attentions)
@@ -1407,7 +1442,7 @@ def llama_attention_forward_4_36_original(
     return attn_output.to(original_dtype), attn_weights, past_key_value
 
 
-def native_sdp(query, key, value, attention_mask,
+def native_sdp(query, key, value, attention_mask, cache_position,
                bsz, q_len, kv_seq_len, head_dim, num_heads, output_attentions):
     if should_split_qkv_tensor(query, bsz, num_heads, q_len, kv_seq_len, output_attentions):
-        return native_sdp_split_qkv_tensor(query, key, value, attention_mask,
+        return native_sdp_split_qkv_tensor(query, key, value, attention_mask, cache_position,
@@ -1423,12 +1458,17 @@ def native_sdp(query, key, value, attention_mask,
                           f"but is {attn_weights.size()}")
 
     if attention_mask is not None:
-        attn_mask_size = (bsz, 1, q_len, kv_seq_len)
-        if attention_mask.size() != attn_mask_size:
-            invalidInputError(False,
-                              f"Attention mask should be of size {attn_mask_size}, "
-                              f"but is {attention_mask.size()}")
-        attn_weights = attn_weights + attention_mask
+        if cache_position is not None:
+            # for transformers 4.38.0
+            causal_mask = attention_mask[:, :, cache_position, : kv_seq_len]
+            attn_weights = attn_weights + causal_mask
+        else:
+            attn_mask_size = (bsz, 1, q_len, kv_seq_len)
+            if attention_mask.size() != attn_mask_size:
+                invalidInputError(False,
+                                  f"Attention mask should be of size {attn_mask_size}, "
+                                  f"but is {attention_mask.size()}")
+            attn_weights = attn_weights + attention_mask
 
     if kv_seq_len >= 2048 or bsz >= 64:
         # for memory considerations, do not upcast attention to fp32
@@ -1442,7 +1482,7 @@ def native_sdp(query, key, value, attention_mask,
     return attn_output, attn_weights
 
 
-def native_sdp_split_qkv_tensor(query, key, value, attention_mask,
+def native_sdp_split_qkv_tensor(query, key, value, attention_mask, cache_position,
                                 bsz, q_len, kv_seq_len, head_dim, num_heads):
     block_size = 8
     query_split = torch.split(query.to(key.dtype), block_size, dim=1)
@@ -1459,12 +1499,17 @@ def native_sdp_split_qkv_tensor(query, key, value, attention_mask,
                               f"{attn_weights_split_size}, but is {attn_weights_split.size()}")
 
         if attention_mask is not None:
-            attn_mask_size = (bsz, 1, q_len, kv_seq_len)
-            if attention_mask.size() != attn_mask_size:
-                invalidInputError(False,
-                                  f"Attention mask should be of size {attn_mask_size}, "
-                                  f"but is {attention_mask.size()}")
-            attn_weights_split = attn_weights_split + attention_mask
+            if cache_position is not None:
+                # for transformers 4.38.0
+                causal_mask = attention_mask[:, :, cache_position, : kv_seq_len]
+                attn_weights_split = attn_weights_split + causal_mask
+            else:
+                attn_mask_size = (bsz, 1, q_len, kv_seq_len)
+                if attention_mask.size() != attn_mask_size:
+                    invalidInputError(False,
+                                      f"Attention mask should be of size {attn_mask_size}, "
+                                      f"but is {attention_mask.size()}")
+                attn_weights_split = attn_weights_split + attention_mask
         attn_weights_split = nn.functional.softmax(attn_weights_split, dim=-1)
         attn_outputs.append(torch.matmul(attn_weights_split, v))
     attn_output = torch.cat(attn_outputs, dim=1)
diff --git a/python/llm/src/ipex_llm/transformers/models/utils.py b/python/llm/src/ipex_llm/transformers/models/utils.py
index 6bf69b4cd26..b899370dfbe 100644
--- a/python/llm/src/ipex_llm/transformers/models/utils.py
+++ b/python/llm/src/ipex_llm/transformers/models/utils.py
@@ -178,6 +178,12 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, model_family):
         q_embed = (q * cos) + (rotate_half(q) * sin)
         k_embed = (k * cos) + (rotate_half(k) * sin)
         return q_embed, k_embed
+    elif model_family == "llama2":
+        cos = cos.unsqueeze(1)
+        sin = sin.unsqueeze(1)
+        q_embed = (q * cos) + (rotate_half(q) * sin)
+        k_embed = (k * cos) + (rotate_half(k) * sin)
+        return q_embed, k_embed
     elif model_family == "gptj":
         q_embed = (q * cos) + (rotate_every_two(q) * sin)
         k_embed = (k * cos) + (rotate_every_two(k) * sin)
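
The convert.py hunk above routes LlamaAttention (and, below 4.38.0, LlamaModel) to different ipex-llm forwards depending on the installed transformers version. The following standalone sketch is illustrative only and not part of the patch; pick_llama_forwards is a hypothetical helper, and the pre-4.36 branch is deliberately left vague because that code is outside this hunk:

from packaging import version

def pick_llama_forwards(trans_version):
    # Mirrors the branch added to _optimize_post: 4.38+ currently only
    # replaces LlamaAttention (LlamaModel support is still a TODO there),
    # while 4.36-4.37 replaces both LlamaModel and LlamaAttention.
    if version.parse(trans_version) >= version.parse("4.38.0"):
        return {"LlamaAttention": "llama_attention_forward_4_38_original"}
    elif version.parse(trans_version) >= version.parse("4.36.0"):
        return {"LlamaModel": "llama_model_forward_4_36",
                "LlamaAttention": "llama_attention_forward_4_38"}
    else:
        return {"LlamaAttention": "pre-4.36 forward (not shown in this hunk)"}

print(pick_llama_forwards("4.38.0"))
print(pick_llama_forwards("4.36.2"))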
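
The new "llama2" branch in apply_rotary_pos_emb exists because, from transformers 4.38.0 on, self.rotary_emb(value_states, position_ids) already returns cos/sin gathered per position with shape (bsz, seq_len, head_dim); the helper therefore only needs to unsqueeze the head dimension before broadcasting against q/k of shape (bsz, num_heads, seq_len, head_dim). A self-contained sketch of that broadcast (rotate_half is copied in, and the shapes and frequency setup are chosen purely for illustration):

import torch

def rotate_half(x):
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2:]
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb_llama2(q, k, cos, sin):
    # (bsz, seq, head_dim) -> (bsz, 1, seq, head_dim) so cos/sin broadcast over heads
    cos = cos.unsqueeze(1)
    sin = sin.unsqueeze(1)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed

bsz, num_heads, seq_len, head_dim = 1, 4, 3, 8
q = torch.randn(bsz, num_heads, seq_len, head_dim)
k = torch.randn(bsz, num_heads, seq_len, head_dim)
# build 4.38-style cos/sin of shape (bsz, seq_len, head_dim)
inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
position_ids = torch.arange(seq_len).unsqueeze(0)      # (bsz, seq_len)
freqs = position_ids[..., None].float() * inv_freq      # (bsz, seq_len, head_dim // 2)
emb = torch.cat((freqs, freqs), dim=-1)                 # (bsz, seq_len, head_dim)
q_rot, k_rot = apply_rotary_pos_emb_llama2(q, k, emb.cos(), emb.sin())
print(q_rot.shape, k_rot.shape)   # torch.Size([1, 4, 3, 8]) twice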
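
The repeated `causal_mask = attention_mask[:, :, cache_position, : kv_seq_len]` pattern, and the `new_attention_mask = attention_mask[:, :, kv_seq_len - q_len:kv_seq_len, 0:kv_seq_len]` slice in the `_original` path, both handle the same 4.38 change: the model now hands attention a full causal mask plus the row indices of the current query tokens, rather than a mask already shaped (bsz, 1, q_len, kv_seq_len). A minimal sketch of the slicing, with the mask construction and sizes chosen only for illustration:

import torch

bsz, max_len = 1, 8
# full additive causal mask: 0 where attending is allowed, -inf above the diagonal
full_mask = torch.triu(torch.full((max_len, max_len), float("-inf")), diagonal=1)
full_mask = full_mask[None, None]                  # (bsz, 1, max_len, max_len)

# decoding step: 6 tokens are already cached, one new query token arrives
kv_seq_len = 7
cache_position = torch.tensor([6])                 # row index of the new query token
causal_mask = full_mask[:, :, cache_position, :kv_seq_len]
print(causal_mask.shape)                           # torch.Size([1, 1, 1, 7])
print(causal_mask[0, 0, 0])                        # all zeros: the new token may attend to every cached key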