From efb04327a7bcf5c88bb939835632de6e123e3667 Mon Sep 17 00:00:00 2001 From: Aleksandr Malyshev <164964928+maleksan85@users.noreply.github.com> Date: Wed, 13 Nov 2024 10:22:45 -0800 Subject: [PATCH] corrected types for strides in triton FA (#274) (#276) Co-authored-by: Aleksandr Malyshev (cherry picked from commit 9a46e97c1e63cbb5223a10a86705063b00e55576) --- vllm/attention/backends/rocm_flash_attn.py | 3 +- vllm/attention/ops/triton_flash_attention.py | 40 ++++++++++---------- 2 files changed, 22 insertions(+), 21 deletions(-) diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index 7d2d87176800c..e5df445d8449b 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -619,7 +619,8 @@ def forward( # QKV for prefill. query = query[:num_prefill_tokens] - if key is not None and value is not None: + if key is not None and value is not None \ + and attn_type != AttentionType.ENCODER_DECODER: key = key[:num_prefill_tokens] value = value[:num_prefill_tokens] diff --git a/vllm/attention/ops/triton_flash_attention.py b/vllm/attention/ops/triton_flash_attention.py index f94211116a746..2019ed184e5a1 100644 --- a/vllm/attention/ops/triton_flash_attention.py +++ b/vllm/attention/ops/triton_flash_attention.py @@ -314,26 +314,26 @@ def attn_fwd( sm_scale, L, Out, - stride_qz, - stride_qh, - stride_qm, - stride_qk, - stride_kz, - stride_kh, - stride_kn, - stride_kk, - stride_vz, - stride_vh, - stride_vk, - stride_vn, - stride_oz, - stride_oh, - stride_om, - stride_on, - stride_bz, - stride_bh, - stride_bm, - stride_bn, + stride_qz: tl.int64, + stride_qh: tl.int64, + stride_qm: tl.int64, + stride_qk: tl.int64, + stride_kz: tl.int64, + stride_kh: tl.int64, + stride_kn: tl.int64, + stride_kk: tl.int64, + stride_vz: tl.int64, + stride_vh: tl.int64, + stride_vk: tl.int64, + stride_vn: tl.int64, + stride_oz: tl.int64, + stride_oh: tl.int64, + stride_om: tl.int64, + stride_on: tl.int64, + stride_bz: tl.int64, + stride_bh: tl.int64, + stride_bm: tl.int64, + stride_bn: tl.int64, cu_seqlens_q, cu_seqlens_k, dropout_p,