diff --git a/python/llm/src/bigdl/llm/transformers/convert.py b/python/llm/src/bigdl/llm/transformers/convert.py
index 4e7c3807a74..7e99e00742e 100644
--- a/python/llm/src/bigdl/llm/transformers/convert.py
+++ b/python/llm/src/bigdl/llm/transformers/convert.py
@@ -742,10 +742,15 @@ def _optimize_post(model, lightweight_bmm=False):
         if version.parse(trans_version) >= version.parse("4.36.0"):
             # transformers version >= 4.36.0
             from bigdl.llm.transformers.models.llama import llama_attention_forward_4_36
+            from bigdl.llm.transformers.models.llama import llama_model_forward_4_36
             convert_forward(
                 model,
                 transformers.models.llama.modeling_llama.LlamaAttention,
                 llama_attention_forward_4_36, )
+            convert_forward(
+                model,
+                transformers.models.llama.modeling_llama.LlamaModel,
+                llama_model_forward_4_36)
         else:
             # transformers version between 4.31.0 - 4.35.2
             convert_forward(
diff --git a/python/llm/src/bigdl/llm/transformers/models/llama.py b/python/llm/src/bigdl/llm/transformers/models/llama.py
index 826cdb4cd4d..29d6fe358d8 100644
--- a/python/llm/src/bigdl/llm/transformers/models/llama.py
+++ b/python/llm/src/bigdl/llm/transformers/models/llama.py
@@ -49,6 +49,7 @@
 from bigdl.llm.transformers.models.utils import use_flash_attention, use_esimd_sdp
 from bigdl.llm.transformers.models.utils import mlp_fusion_check, fp16_fusion_check
 from transformers.modeling_outputs import BaseModelOutputWithPast
+from transformers.models.llama.modeling_llama import LlamaModel
 from bigdl.llm.transformers.low_bit_linear import SYM_INT4, FP8E5, IQ2_XXS
 from bigdl.llm.ggml.quantize import ggml_tensor_qtype
 from bigdl.llm.utils.common import invalidInputError
@@ -84,6 +85,37 @@ def get_ipex_version():
     return _ipex_version
 
 
+def llama_model_forward_4_36(
+    self,
+    input_ids: torch.LongTensor = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_values: Optional[List[torch.FloatTensor]] = None,
+    inputs_embeds: Optional[torch.FloatTensor] = None,
+    use_cache: Optional[bool] = None,
+    output_attentions: Optional[bool] = None,
+    output_hidden_states: Optional[bool] = None,
+    return_dict: Optional[bool] = None,
+) -> Union[Tuple, BaseModelOutputWithPast]:
+    from bigdl.llm.transformers.kv import DynamicFp8Cache
+    use_cache = use_cache if use_cache is not None else self.config.use_cache
+    if use_cache and use_quantize_kv_cache(self.layers[0].mlp.up_proj, input_ids):
+        if not isinstance(past_key_values, DynamicFp8Cache):
+            past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
+    return LlamaModel.forward(
+        self=self,
+        input_ids=input_ids,
+        attention_mask=attention_mask,
+        position_ids=position_ids,
+        past_key_values=past_key_values,
+        inputs_embeds=inputs_embeds,
+        use_cache=use_cache,
+        output_attentions=output_attentions,
+        output_hidden_states=output_hidden_states,
+        return_dict=return_dict,
+    )
+
+
 def llama_rms_norm_forward(self, hidden_states):
     if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
         import linear_q4_0
@@ -906,6 +938,212 @@ def llama_attention_forward_4_36(
     output_attentions: bool = False,
     use_cache: bool = False,
     **kwargs
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+    if use_quantize_kv_cache(self.q_proj, hidden_states):
+        forward_function = llama_attention_forward_4_36_quantized
+    else:
+        forward_function = llama_attention_forward_4_36_original
+    return forward_function(
+        self=self,
+        hidden_states=hidden_states,
+        attention_mask=attention_mask,
+        position_ids=position_ids,
+        past_key_value=past_key_value,
+        output_attentions=output_attentions,
+        use_cache=use_cache,
+        **kwargs
+    )
+
+
+def llama_attention_forward_4_36_quantized(
+    self,
+    hidden_states: torch.Tensor,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_value: Optional[Tuple[torch.Tensor]] = None,
+    output_attentions: bool = False,
+    use_cache: bool = False,
+    **kwargs
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+    if "padding_mask" in kwargs:
+        warnings.warn(
+            "Passing `padding_mask` is deprecated and will be removed in v4.37. "
+            "Please make sure to use `attention_mask` instead."
+        )
+
+    bsz, q_len, _ = hidden_states.size()
+    device = hidden_states.device
+    use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
+    enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx, seq_len=q_len)
+    qtype = getattr(self.q_proj, "qtype", None)
+    qtype_check = qtype in [SYM_INT4, FP8E5]
+    no_tp = not self.config.pretraining_tp > 1
+    decoding_fast_path = (no_tp and qtype_check and use_fuse_rope
+                          and enough_kv_room and bsz * q_len == 1)
+    if decoding_fast_path:
+        hidden_states = hidden_states.view(1, -1)
+        tmp_cache_k, tmp_cache_v = init_kv_cache(
+            bsz,
+            self.num_key_value_heads,
+            self.head_dim,
+            0,
+            1,
+            dtype=hidden_states.dtype,
+            device=device
+        )
+        import linear_q4_0
+        query_states, key_states, value_states = linear_q4_0.forward_qkv(hidden_states,
+                                                                         self.q_proj.weight,
+                                                                         self.k_proj.weight,
+                                                                         self.v_proj.weight,
+                                                                         position_ids,
+                                                                         tmp_cache_k, tmp_cache_v,
+                                                                         self.q_proj.weight.qtype,
+                                                                         0,
+                                                                         self.head_dim)
+    else:
+        query_states = self.q_proj(hidden_states)
+        key_states = self.k_proj(hidden_states)
+        value_states = self.v_proj(hidden_states)
+
+        query_states = query_states.view(bsz, q_len,
+                                         self.num_heads, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(bsz, q_len,
+                                     self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(bsz, q_len,
+                                         self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+    kv_seq_len = key_states.shape[-2]
+    if past_key_value is not None:
+        if self.layer_idx is None:
+            invalidInputError(
+                False,
+                f"The cache structure has changed since version v4.36."
+                f" If you are using {self.__class__.__name__} "
+                f"for auto-regressive decoding with k/v caching,"
+                f" please make sure to initialize the attention class "
+                "with a layer index."
+            )
+        kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+    if use_fuse_rope:
+        query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
+                                                                     key_states,
+                                                                     position_ids,
+                                                                     "llama")
+    else:
+        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
+                                                        cos, sin, position_ids, "llama")
+    kv_seq_len = key_states.shape[-2]
+
+    if len(past_key_value.key_cache) <= self.layer_idx:
+        attn_weights = torch.matmul(query_states,
+                                    key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+
+        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
+            invalidInputError(
+                False,
+                f"Attention weights should be of size "
+                f"{(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
+                f" {attn_weights.size()}"
+            )
+
+        if attention_mask is not None:
+            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+                invalidInputError(
+                    False,
+                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)},"
+                    f" but is {attention_mask.size()}"
+                )
+            attn_weights = attn_weights + attention_mask
+
+        # upcast attention to fp32
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1,
+                                             dtype=torch.float32).to(query_states.dtype)
+        attn_output = torch.matmul(attn_weights, value_states)
+        if use_cache:
+            cache_kwargs = None
+            key_states, value_states = past_key_value.update(key_states, value_states,
+                                                             self.layer_idx, cache_kwargs)
+    else:
+        cache_kwargs = None  # Specific to RoPE models
+        key_states, value_states = past_key_value.update(key_states, value_states,
+                                                         self.layer_idx, cache_kwargs)
+        kv_seq_len = key_states.shape[-2]
+        if query_states.size(2) != 1 or query_states.device.type != 'xpu':
+            key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
+                                                            query_states.dtype)
+            key_states = repeat_kv(key_states, self.num_key_value_groups)\
+                .to(device, dtype=query_states.dtype)
+            value_states = repeat_kv(value_states, self.num_key_value_groups)\
+                .to(device, dtype=query_states.dtype)
+            attn_weights = torch.matmul(query_states, key_states.transpose(2, 3))
+        else:
+            import linear_q4_0
+            attn_weights = linear_q4_0.query_key_fp8_matmul(query_states, key_states)
+        attn_weights = attn_weights / math.sqrt(self.head_dim)
+        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
+            invalidInputError(
+                False,
+                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)},"
+                f" but is {attn_weights.size()}"
+            )
+
+        if attention_mask is not None:
+            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+                invalidInputError(
+                    False,
+                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)},"
+                    f" but is {attention_mask.size()}"
+                )
+            attn_weights = attn_weights + attention_mask
+
+        # upcast attention to fp32
+        attn_weights = nn.functional.softmax(attn_weights,
+                                             dim=-1, dtype=torch.float32).to(query_states.dtype)
+
+        if query_states.size(2) != 1 or query_states.device.type != 'xpu':
+            attn_output = torch.matmul(attn_weights, value_states)
+        else:
+            import linear_q4_0
+            attn_output = linear_q4_0.attn_value_fp8_matmul(attn_weights,
+                                                            value_states.transpose(-1, -2))
+
+    if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+        invalidInputError(
+            False,
+            f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)},"
+            f" but is {attn_output.size()}"
+        )
+
+    attn_output = attn_output.transpose(1, 2).contiguous()
+
+    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+    if self.config.pretraining_tp > 1:
+        attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
+        o_proj_slices = self.o_proj.weight.split(self.hidden_size
+                                                 // self.config.pretraining_tp, dim=1)
+        attn_output = sum([F.linear(attn_output[i],
+                                    o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
+    else:
+        attn_output = self.o_proj(attn_output)
+
+    if not output_attentions:
+        attn_weights = None
+
+    return attn_output, attn_weights, past_key_value
+
+
+def llama_attention_forward_4_36_original(
+    self,
+    hidden_states: torch.Tensor,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_value: Optional[Tuple[torch.Tensor]] = None,
+    output_attentions: bool = False,
+    use_cache: bool = False,
+    **kwargs
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
     if "padding_mask" in kwargs:
         warnings.warn(
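
Note on the model-level patch: llama_model_forward_4_36 leans on the Cache API introduced in transformers 4.36, wrapping legacy tuple caches via from_legacy_cache and appending per-layer tensors via update(). A minimal sketch of those calls, using the stock DynamicCache (per this diff, DynamicFp8Cache is driven through the same interface, with fp8 storage behind it):

    import torch
    from transformers.cache_utils import DynamicCache

    # from_legacy_cache(None) yields an empty cache, as on the first forward pass.
    cache = DynamicCache.from_legacy_cache(None)
    k = torch.randn(1, 32, 16, 128)   # (bsz, kv_heads, seq_len, head_dim)
    v = torch.randn(1, 32, 16, 128)
    # update() appends this layer's K/V and returns the accumulated tensors.
    k_all, v_all = cache.update(k, v, layer_idx=0)
    print(k_all.shape)                               # torch.Size([1, 32, 16, 128])
    print(cache.get_usable_length(1, layer_idx=0))   # 16 cached positions for layer 0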
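The quantized path stores K/V in FP8 (FP8E5, i.e. the e5m2 format) and restores them to fp16 before the plain matmul whenever the fp8 kernels do not apply. The actual kernels live in linear_q4_0 and restore_fp8_kv_cache; the round-trip below is only a rough illustration of the memory/precision trade, and assumes PyTorch >= 2.1 for torch.float8_e5m2:

    import torch

    k = torch.randn(1, 32, 128, 128, dtype=torch.float16)  # (bsz, kv_heads, seq, head_dim)
    k_fp8 = k.to(torch.float8_e5m2)    # stored form: half the bytes of fp16
    k_back = k_fp8.to(torch.float16)   # restored before the fp16 attention matmul
    print((k - k_back).abs().max())    # worst-case quantization error of the cache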
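After restoring the cache, the code expands grouped-query K/V heads with repeat_kv before the matmul. For reference, the standard transformers implementation is equivalent to torch.repeat_interleave along the head dimension:

    import torch

    def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
        # (bsz, kv_heads, seq, head_dim) -> (bsz, kv_heads * n_rep, seq, head_dim)
        bsz, kv_heads, slen, head_dim = hidden_states.shape
        if n_rep == 1:
            return hidden_states
        hidden_states = hidden_states[:, :, None, :, :].expand(bsz, kv_heads,
                                                               n_rep, slen, head_dim)
        return hidden_states.reshape(bsz, kv_heads * n_rep, slen, head_dim)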
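Finally, the pretraining_tp > 1 branch splits attn_output along the hidden dimension and o_proj.weight along its input dimension, then sums the per-slice projections. That sum is numerically identical to a single full projection, which a quick self-contained check (with illustrative sizes) confirms:

    import torch
    import torch.nn.functional as F

    tp, hidden = 2, 64
    x = torch.randn(1, 8, hidden)            # attn_output: (bsz, q_len, hidden)
    w = torch.randn(hidden, hidden)          # o_proj.weight: (out_features, in_features)
    x_slices = x.split(hidden // tp, dim=2)
    w_slices = w.split(hidden // tp, dim=1)  # slice the input dimension
    partial = sum(F.linear(x_slices[i], w_slices[i]) for i in range(tp))
    assert torch.allclose(partial, F.linear(x, w), atol=1e-4)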