remove check
yangw1234 committed Sep 27, 2023
1 parent 6d79c53 commit 3c8e52a
Showing 1 changed file with 14 additions and 14 deletions.
python/llm/src/bigdl/llm/transformers/models/llama.py: 14 additions & 14 deletions
@@ -180,18 +180,18 @@ def llama_attention_forward_4_31(
     attn_weights = torch.matmul(query_states,
                                 key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
 
-    attn_weights_size = (bsz, self.num_heads, q_len, kv_seq_len)
-    if attn_weights.size() != attn_weights_size:
-        invalidInputError(False,
-                          f"Attention weights should be of size {attn_weights_size}, "
-                          f"but is {attn_weights.size()}")
+    # attn_weights_size = (bsz, self.num_heads, q_len, kv_seq_len)
+    # if attn_weights.size() != attn_weights_size:
+    #     invalidInputError(False,
+    #                       f"Attention weights should be of size {attn_weights_size}, "
+    #                       f"but is {attn_weights.size()}")
 
     if attention_mask is not None:
-        attn_mask_size = (bsz, 1, q_len, kv_seq_len)
-        if attention_mask.size() != attn_mask_size:
-            invalidInputError(False,
-                              f"Attention mask should be of size {attn_mask_size}, "
-                              f"but is {attention_mask.size()}")
+        # attn_mask_size = (bsz, 1, q_len, kv_seq_len)
+        # if attention_mask.size() != attn_mask_size:
+        #     invalidInputError(False,
+        #                       f"Attention mask should be of size {attn_mask_size}, "
+        #                       f"but is {attention_mask.size()}")
         attn_weights = attn_weights + attention_mask
 
     # upcast attention to fp32
@@ -200,10 +200,10 @@ def llama_attention_forward_4_31(
     attn_output = torch.matmul(attn_weights, value_states)
 
     attn_output_size = (bsz, self.num_heads, q_len, self.head_dim)
-    if attn_output.size() != attn_output_size:
-        invalidInputError(False,
-                          f"`attn_output` should be of size {attn_output_size},"
-                          f" but is {attn_output.size()}")
+    # if attn_output.size() != attn_output_size:
+    #     invalidInputError(False,
+    #                       f"`attn_output` should be of size {attn_output_size},"
+    #                       f" but is {attn_output.size()}")
 
     attn_output = attn_output.transpose(1, 2).contiguous()
     attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
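For context (not part of the commit): the commented-out `invalidInputError` branches asserted the standard attention tensor shapes that follow from the two matmuls above. Below is a minimal, standalone sketch of those invariants, independent of the repository, with arbitrarily chosen sizes.

```python
# Standalone sketch (not from llama.py) of the shape invariants the removed checks verified.
import math
import torch

bsz, num_heads, q_len, kv_seq_len, head_dim = 2, 8, 16, 16, 64

query_states = torch.randn(bsz, num_heads, q_len, head_dim)
key_states = torch.randn(bsz, num_heads, kv_seq_len, head_dim)
value_states = torch.randn(bsz, num_heads, kv_seq_len, head_dim)

# Scaled dot-product scores: (bsz, num_heads, q_len, kv_seq_len)
attn_weights = torch.matmul(query_states,
                            key_states.transpose(2, 3)) / math.sqrt(head_dim)
assert attn_weights.size() == (bsz, num_heads, q_len, kv_seq_len)

# Weighted sum over values: (bsz, num_heads, q_len, head_dim)
attn_output = torch.matmul(attn_weights.softmax(dim=-1), value_states)
assert attn_output.size() == (bsz, num_heads, q_len, head_dim)
```

The two `assert` lines mirror the size comparisons in the removed branches; an additive attention mask, when present, would be expected to broadcast against the `(bsz, 1, q_len, kv_seq_len)` score shape.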
