Commit

fix
yma11 committed Jan 22, 2025
1 parent 6049ef4 commit 8477d17
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions vllm/model_executor/models/mllama.py
@@ -910,7 +910,7 @@ def _attention_with_mask(
         attention_mask = attention_mask.view(1, 1, q_len, kv_len)
         if current_platform.is_hpu():
             from habana_frameworks.torch.hpex.kernels import FusedSDPA
-            output = FusedSDPA.apply(q, k, v, attention_mask, 0.0)
+            output = FusedSDPA.apply(q, k, v, attention_mask)
             output = output.permute(2, 0, 1, 3).reshape(
                 q_len, self.num_local_heads * self.head_dim)
             return output
@@ -919,7 +919,7 @@ def _attention_with_mask(
             k,
             v,
             attn_mask=attention_mask,
-            dropout_p=0.0)
+            is_causal=False)
         output = output.permute(2, 0, 1, 3).reshape(
             q_len, self.num_local_heads * self.head_dim)
         return output
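
The two hunks above drop the explicit dropout argument on the HPU FusedSDPA path and pass is_causal=False to the standard SDPA fallback. Below is a minimal standalone sketch of that fallback path after the change; the tensor sizes and the num_heads / head_dim names are illustrative stand-ins, not the module's real attributes or the model's actual configuration.

import torch
import torch.nn.functional as F

# Illustrative shapes only.
num_heads, head_dim, q_len, kv_len = 8, 64, 16, 32

q = torch.randn(1, num_heads, q_len, head_dim)
k = torch.randn(1, num_heads, kv_len, head_dim)
v = torch.randn(1, num_heads, kv_len, head_dim)

# Additive float mask broadcast over (batch, heads, q_len, kv_len),
# mirroring attention_mask.view(1, 1, q_len, kv_len) in the diff.
attention_mask = torch.zeros(q_len, kv_len).view(1, 1, q_len, kv_len)

# dropout_p keeps its default of 0.0, so only the mask and the explicit
# is_causal=False are passed, matching the updated call in the diff.
output = F.scaled_dot_product_attention(q, k, v,
                                        attn_mask=attention_mask,
                                        is_causal=False)
output = output.permute(2, 0, 1, 3).reshape(q_len, num_heads * head_dim)
print(output.shape)  # torch.Size([16, 512])
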
@@ -987,8 +987,10 @@ def forward(
         # TODO: Change input_tokens tensor at the beginning of model execution
         # to 2D tensor to align with public vllm input_tokens shape. But this
         # will face the graph building failure issue, still need to investigate.
-        if len(hidden_states.shape) == 3:
-            full_text_row_masked_out_mask = full_text_row_masked_out_mask.view(
+        assert len(residual.shape) == 3
+        if len(hidden_states.shape) == 2:
+            hidden_states = hidden_states.view(residual.size(0), residual.size(1), residual.size(2))
+        full_text_row_masked_out_mask = full_text_row_masked_out_mask.view(
             hidden_states.size(0), -1, 1)
         hidden_states = full_text_row_masked_out_mask * hidden_states
         hidden_states = residual + self.cross_attn_attn_gate.tanh(
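
For context on the third hunk: when the incoming hidden_states arrives flattened as a 2D (batch * seq_len, hidden) tensor, it is reshaped back to the 3D shape of residual before the row mask is applied. A hedged sketch of that logic with made-up sizes follows; batch, seq_len, hidden, and the mask values are illustrative only, not the layer's real inputs.

import torch

batch, seq_len, hidden = 2, 4, 8
residual = torch.randn(batch, seq_len, hidden)        # always 3D at this point
hidden_states = torch.randn(batch * seq_len, hidden)  # may arrive flattened as 2D
full_text_row_masked_out_mask = torch.tensor([1., 0., 1., 1., 1., 1., 0., 1.])

assert len(residual.shape) == 3
if len(hidden_states.shape) == 2:
    # Bring the flattened (batch * seq_len, hidden) tensor back to the
    # (batch, seq_len, hidden) shape of the residual before masking.
    hidden_states = hidden_states.view(residual.size(0), residual.size(1),
                                       residual.size(2))
# One scalar per row, broadcast over the hidden dimension.
full_text_row_masked_out_mask = full_text_row_masked_out_mask.view(
    hidden_states.size(0), -1, 1)
hidden_states = full_text_row_masked_out_mask * hidden_states
print(hidden_states.shape)  # torch.Size([2, 4, 8])
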
