diff --git a/python/llm/src/bigdl/llm/transformers/models/llama.py b/python/llm/src/bigdl/llm/transformers/models/llama.py
index 50abe9f9112f..e4344984fdbe 100644
--- a/python/llm/src/bigdl/llm/transformers/models/llama.py
+++ b/python/llm/src/bigdl/llm/transformers/models/llama.py
@@ -180,18 +180,18 @@ def llama_attention_forward_4_31(
 
     attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
 
-    attn_weights_size = (bsz, self.num_heads, q_len, kv_seq_len)
-    if attn_weights.size() != attn_weights_size:
-        invalidInputError(False,
-                          f"Attention weights should be of size {attn_weights_size}, "
-                          f"but is {attn_weights.size()}")
+    # attn_weights_size = (bsz, self.num_heads, q_len, kv_seq_len)
+    # if attn_weights.size() != attn_weights_size:
+    #     invalidInputError(False,
+    #                       f"Attention weights should be of size {attn_weights_size}, "
+    #                       f"but is {attn_weights.size()}")
 
     if attention_mask is not None:
-        attn_mask_size = (bsz, 1, q_len, kv_seq_len)
-        if attention_mask.size() != attn_mask_size:
-            invalidInputError(False,
-                              f"Attention mask should be of size {attn_mask_size}, "
-                              f"but is {attention_mask.size()}")
+        # attn_mask_size = (bsz, 1, q_len, kv_seq_len)
+        # if attention_mask.size() != attn_mask_size:
+        #     invalidInputError(False,
+        #                       f"Attention mask should be of size {attn_mask_size}, "
+        #                       f"but is {attention_mask.size()}")
         attn_weights = attn_weights + attention_mask
 
     # upcast attention to fp32
@@ -200,10 +200,10 @@ def llama_attention_forward_4_31(
     attn_output = torch.matmul(attn_weights, value_states)
 
     attn_output_size = (bsz, self.num_heads, q_len, self.head_dim)
-    if attn_output.size() != attn_output_size:
-        invalidInputError(False,
-                          f"`attn_output` should be of size {attn_output_size},"
-                          f" but is {attn_output.size()}")
+    # if attn_output.size() != attn_output_size:
+    #     invalidInputError(False,
+    #                       f"`attn_output` should be of size {attn_output_size},"
+    #                       f" but is {attn_output.size()}")
 
     attn_output = attn_output.transpose(1, 2).contiguous()
     attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
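
For context, below is a minimal, self-contained sketch of the scaled dot-product attention path that these shape checks guarded. It is not the repo's llama_attention_forward_4_31; the function name sdpa_with_optional_checks and the check_shapes flag are purely illustrative (the patch itself simply comments the checks out), and plain assert statements stand in for invalidInputError.

import math
import torch

def sdpa_with_optional_checks(query_states, key_states, value_states,
                              attention_mask=None, check_shapes=False):
    # query_states: (bsz, num_heads, q_len, head_dim)
    # key_states / value_states: (bsz, num_heads, kv_seq_len, head_dim)
    bsz, num_heads, q_len, head_dim = query_states.shape
    kv_seq_len = key_states.shape[2]

    # QK^T scaled by sqrt(head_dim) -> (bsz, num_heads, q_len, kv_seq_len)
    attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(head_dim)

    if check_shapes:
        # Equivalent of the checks commented out in the diff above (hypothetical flag).
        assert attn_weights.size() == (bsz, num_heads, q_len, kv_seq_len)
        if attention_mask is not None:
            assert attention_mask.size() == (bsz, 1, q_len, kv_seq_len)

    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask

    # upcast attention to fp32 for the softmax, then cast back to the query dtype
    attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1,
                                               dtype=torch.float32).to(query_states.dtype)
    attn_output = torch.matmul(attn_weights, value_states)

    if check_shapes:
        assert attn_output.size() == (bsz, num_heads, q_len, head_dim)

    # (bsz, num_heads, q_len, head_dim) -> (bsz, q_len, num_heads * head_dim)
    attn_output = attn_output.transpose(1, 2).contiguous()
    return attn_output.reshape(bsz, q_len, num_heads * head_dim)

Example usage (shapes only, random tensors):

q = k = v = torch.randn(1, 32, 16, 128)
out = sdpa_with_optional_checks(q, k, v, check_shapes=True)   # out.shape == (1, 16, 4096)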