From 8477d175804ed947a4c50249ef3da692f3bed530 Mon Sep 17 00:00:00 2001
From: yan ma
Date: Wed, 22 Jan 2025 18:54:25 +0800
Subject: [PATCH] mllama: fix cross-attention SDPA args and 2D hidden_states reshape

---
 vllm/model_executor/models/mllama.py | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py
index 6162433c5c547..f15318fbd98ff 100644
--- a/vllm/model_executor/models/mllama.py
+++ b/vllm/model_executor/models/mllama.py
@@ -910,7 +910,7 @@ def _attention_with_mask(
         attention_mask = attention_mask.view(1, 1, q_len, kv_len)
         if current_platform.is_hpu():
             from habana_frameworks.torch.hpex.kernels import FusedSDPA
-            output = FusedSDPA.apply(q, k, v, attention_mask, 0.0)
+            output = FusedSDPA.apply(q, k, v, attention_mask)
             output = output.permute(2, 0, 1, 3).reshape(
                 q_len, self.num_local_heads * self.head_dim)
             return output
@@ -919,7 +919,7 @@ def _attention_with_mask(
                                                 k,
                                                 v,
                                                 attn_mask=attention_mask,
-                                                dropout_p=0.0)
+                                                is_causal=False)
         output = output.permute(2, 0, 1, 3).reshape(
             q_len, self.num_local_heads * self.head_dim)
         return output
@@ -987,8 +987,10 @@ def forward(
         # TODO: Change input_tokens tensor at the beginning of model execution
         # to 2D tensor to align with public vllm input_tokens shape. But this
         # will face the graph building failure issue, still need to investigate.
-        if len(hidden_states.shape) == 3:
-            full_text_row_masked_out_mask = full_text_row_masked_out_mask.view(
+        assert len(residual.shape) == 3
+        if len(hidden_states.shape) == 2:
+            hidden_states = hidden_states.view(residual.size(0), residual.size(1), residual.size(2))
+        full_text_row_masked_out_mask = full_text_row_masked_out_mask.view(
                 hidden_states.size(0), -1, 1)
         hidden_states = full_text_row_masked_out_mask * hidden_states
         hidden_states = residual + self.cross_attn_attn_gate.tanh(
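
Reviewer note (not part of the patch): the sketch below is a minimal,
self-contained illustration of the reshape introduced in the last hunk, not the
vLLM implementation itself. The shapes and names (batch, seq, hidden) are
illustrative assumptions; only the view()/broadcast pattern mirrors the patched
code. For the second hunk, torch.nn.functional.scaled_dot_product_attention
already defaults dropout_p to 0.0, so that change only makes is_causal=False
explicit.

    import torch

    # Illustrative shapes (assumptions, not taken from the patch).
    batch, seq, hidden = 2, 4, 8

    residual = torch.randn(batch, seq, hidden)        # stays 3D, per the new assert
    hidden_states = torch.randn(batch * seq, hidden)  # may arrive flattened to 2D
    full_text_row_masked_out_mask = torch.ones(batch, seq)

    # Restore the 3D layout using residual's shape as the reference, then
    # broadcast the per-row mask over the hidden dimension, as the patch does.
    if len(hidden_states.shape) == 2:
        hidden_states = hidden_states.view(residual.size(0), residual.size(1),
                                           residual.size(2))
    mask = full_text_row_masked_out_mask.view(hidden_states.size(0), -1, 1)
    hidden_states = mask * hidden_states
    assert hidden_states.shape == residual.shape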