diff --git a/python/llm/src/ipex_llm/transformers/models/qwen2_vl.py b/python/llm/src/ipex_llm/transformers/models/qwen2_vl.py index 9696723f127..b310b1d277a 100644 --- a/python/llm/src/ipex_llm/transformers/models/qwen2_vl.py +++ b/python/llm/src/ipex_llm/transformers/models/qwen2_vl.py @@ -191,37 +191,66 @@ def qwen2_vision_attention_forward( ).permute(1, 0, 2, 3).unbind(0) q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0) k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0) + # q, k, v: [seq_length, num_heads, head_dim] - q = q.transpose(0, 1) - k = k.transpose(0, 1) - v = v.transpose(0, 1) + seq_lens = cu_seqlens.tolist() + invalidInputError(seq_lens[0] == 0 and seq_lens[-1] == seq_length, + "unexpected input") - if len(cu_seqlens) == 2 and cu_seqlens.tolist() == [0, seq_length]: - attention_mask = None + if use_sdp_non_causal(self.head_dim, q.device, q.dtype): + import xe_addons + image_num = len(seq_lens) - 1 + image_size = seq_lens[1] - seq_lens[0] + guessed_seq_lens = torch.arange(0, (image_num + 1) * image_size, image_size, + dtype=cu_seqlens.dtype, device=cu_seqlens.device) + if (guessed_seq_lens == cu_seqlens).all(): + q = q.view(image_num, image_size, self.num_heads, self.head_dim).permute(0, 2, 1, 3) + k = k.view(image_num, image_size, self.num_heads, self.head_dim).permute(0, 2, 1, 3) + v = v.view(image_num, image_size, self.num_heads, self.head_dim).permute(0, 2, 1, 3) + # q, k, v: [image_num, num_heads, image_size, head_dim] + + attn_output = xe_addons.sdp_non_causal(q, k.contiguous(), v.contiguous(), None) + attn_output = attn_output.permute(0, 2, 1, 3).contiguous() + attn_output = attn_output.view(seq_length, self.num_heads, self.head_dim) + # attn_output: [seq_length, num_heads, head_dim] + else: + q = q.transpose(0, 1).unsqueeze(0) + k = k.transpose(0, 1).unsqueeze(0).contiguous() + v = v.transpose(0, 1).unsqueeze(0).contiguous() + # q, k, v: [1, num_heads, seq_length, head_dim] + + attn_outputs = [] + for i in range(image_num): + start_idx = seq_lens[i] + end_idx = seq_lens[i + 1] + tmp_q = q[:, :, start_idx:end_idx, :] + tmp_k = k[:, :, start_idx:end_idx, :] + tmp_v = v[:, :, start_idx:end_idx, :] + attn_output = xe_addons.sdp_non_causal(tmp_q, tmp_k, tmp_v, None) + attn_output = attn_output.permute(0, 2, 1, 3) + # attn_output: [1, seq_length, num_heads, head_dim] + attn_outputs.append(attn_output) + attn_output = torch.cat(attn_outputs, dim=1).squeeze(0) + # attn_output: [seq_length, num_heads, head_dim] else: attention_mask = torch.full( [1, seq_length, seq_length], torch.finfo(q.dtype).min, device=q.device, dtype=q.dtype ) - for i in range(1, len(cu_seqlens)): - attention_mask[..., cu_seqlens[i - 1]:cu_seqlens[i], - cu_seqlens[i - 1]:cu_seqlens[i]] = 0 + for i in range(1, len(seq_lens)): + attention_mask[..., seq_lens[i - 1]:seq_lens[i], seq_lens[i - 1]:seq_lens[i]] = 0 + + q = q.transpose(0, 1) + k = k.transpose(0, 1) + v = v.transpose(0, 1) + # q, k, v: [num_heads, seq_length, head_dim] - if use_sdp_non_causal(self.head_dim, q.device, q.dtype): - import xe_addons - q = q.unsqueeze(0) - k = k.unsqueeze(0) - v = v.unsqueeze(0) - if attention_mask is not None: - attention_mask = attention_mask.unsqueeze(0) - attn_output = xe_addons.sdp_non_causal(q, k.contiguous(), v.contiguous(), attention_mask) - attn_output = attn_output.squeeze(0) - else: attn_weights = torch.matmul(q, k.transpose(1, 2)) / math.sqrt(self.head_dim) - if attention_mask is not None: - attn_weights = attn_weights + attention_mask + attn_weights = attn_weights + attention_mask attn_weights = attention_softmax(attn_weights) attn_output = torch.matmul(attn_weights, v) - attn_output = attn_output.transpose(0, 1) + attn_output = attn_output.transpose(0, 1) + # attn_output: [seq_length, num_heads, head_dim] + attn_output = attn_output.reshape(seq_length, -1) attn_output = self.proj(attn_output) return attn_output