diff --git a/benchmarks/kernels/benchmark_paged_attention.py b/benchmarks/kernels/benchmark_paged_attention.py
index d0d990410bc6e..f95a1f488bcf7 100644
--- a/benchmarks/kernels/benchmark_paged_attention.py
+++ b/benchmarks/kernels/benchmark_paged_attention.py
@@ -9,7 +9,7 @@
 from vllm._custom_C import paged_attention_custom
 from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, create_kv_caches_with_random
 
-NUM_BLOCKS = 1024
+NUM_BLOCKS = 1024 * 1024
 PARTITION_SIZE = 256
 
@@ -176,7 +176,7 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
     if do_profile:
         latency = run_benchmark(num_iters=1, profile=True)
     else:
-        latency = run_benchmark(num_iters=100, profile=False)
+        latency = run_benchmark(num_iters=1000, profile=False)
 
     print(f"Kernel running time: {latency * 1000000:.3f} us")
 
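Note on the benchmark change above: raising NUM_BLOCKS from 1024 to 1024 * 1024 spreads accesses over a much larger KV cache, so repeated iterations no longer keep hitting the same cache-resident blocks, and 1000 timed iterations reduce timer noise for a short kernel. A rough footprint sketch (the head count, head size, block size, and element width below are illustrative assumptions, not values fixed by the benchmark):

```python
# Rough KV-cache footprint implied by NUM_BLOCKS = 1024 * 1024.
# num_kv_heads, head_size, block_size, and dtype_bytes are assumed examples;
# the benchmark exposes these as command-line arguments.
NUM_BLOCKS = 1024 * 1024
num_kv_heads, head_size, block_size, dtype_bytes = 8, 128, 16, 2  # fp16/bf16

bytes_per_block = 2 * num_kv_heads * head_size * block_size * dtype_bytes  # K + V
print(f"{NUM_BLOCKS * bytes_per_block / 2**30:.0f} GiB")  # 64 GiB with these assumptions
```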
diff --git a/csrc/custom/paged_attention/attention_ll4mi.cu b/csrc/custom/paged_attention/attention_ll4mi.cu
index 0401bd21f7784..e359c98fa4b19 100644
--- a/csrc/custom/paged_attention/attention_ll4mi.cu
+++ b/csrc/custom/paged_attention/attention_ll4mi.cu
@@ -2,6 +2,7 @@
 #include <torch/extension.h>
 #include <ATen/cuda/CUDAContext.h>
 #include <c10/cuda/CUDAGuard.h>
+#include <hip/hip_bf16.h>
 
 #include <algorithm>
 
@@ -27,6 +28,14 @@
 typedef float16x4 _Half4;
 typedef struct _Half8 {
   _Half4 xy[2];
 } _Half8;
+
+using bit16_t = uint16_t;
+using bit16x4 = __attribute__((__vector_size__(4 * sizeof(uint16_t)))) uint16_t;
+typedef bit16x4 _B16x4;
+typedef struct _B16x8 {
+  _B16x4 xy[2];
+} _B16x8;
+
 ////// Non temporal load stores ///////
 
 #if 1
@@ -118,6 +127,102 @@ __device__ __forceinline__ void store(T value, T* addr) {
 
 #endif
 
+template <typename T, int absz, int cbid, int blgp>
+__device__ __forceinline__ floatx4 gcn_mfma_instr(const _B16x4& inpA,
+                                                  const _B16x4& inpB,
+                                                  const floatx4& inpC) {
+  if constexpr (std::is_same<T, _Float16>::value) {
+    return __builtin_amdgcn_mfma_f32_4x4x4f16(inpA, inpB, inpC, absz, cbid,
+                                              blgp);
+  } else if constexpr (std::is_same<T, __hip_bfloat16>::value) {
+    return __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(inpA, inpB, inpC, absz, cbid,
+                                                  blgp);
+  } else {
+    static_assert(false, "unsupported 16b dtype");
+  }
+}
+
+template <typename T>
+__device__ __forceinline__ float to_float(const T& inp) {
+  if constexpr (std::is_same<T, _Float16>::value) {
+    return (float)inp;
+  } else if constexpr (std::is_same<T, __hip_bfloat16>::value) {
+    return __bfloat162float(inp);
+  } else {
+    static_assert(false, "unsupported 16b dtype");
+  }
+}
+
+template <typename T>
+__device__ __forceinline__ T from_float(const float& inp) {
+  if constexpr (std::is_same<T, _Float16>::value) {
+    return (_Float16)inp;
+  } else if constexpr (std::is_same<T, __hip_bfloat16>::value) {
+    return __float2bfloat16(inp);
+  } else {
+    static_assert(false, "unsupported 16b dtype");
+  }
+}
+
+template <typename T>
+__device__ __forceinline__ _B16x4 from_floatx4(const floatx4& inp) {
+  union tmpcvt {
+    uint16_t u;
+    _Float16 f;
+    __hip_bfloat16 b;
+  } t16;
+  _B16x4 ret;
+  if constexpr (std::is_same<T, _Float16>::value) {
+  #pragma unroll
+    for (int i = 0; i < 4; i++) {
+      t16.f = (_Float16)inp[i];
+      ret[i] = t16.u;
+    }
+    return ret;
+  } else if constexpr (std::is_same<T, __hip_bfloat16>::value) {
+  #pragma unroll
+    for (int i = 0; i < 4; i++) {
+      t16.b = __float2bfloat16(inp[i]);
+      ret[i] = t16.u;
+    }
+    return ret;
+  } else {
+    static_assert(false, "unsupported 16b dtype");
+  }
+}
+
+template <typename T>
+__device__ __forceinline__ _B16x4 addx4(const _B16x4& inp1,
+                                        const _B16x4& inp2) {
+  union tmpcvt {
+    uint16_t u;
+    _Float16 f;
+    __hip_bfloat16 b;
+  } t1, t2, res;
+  _B16x4 ret;
+  if constexpr (std::is_same<T, _Float16>::value) {
+  #pragma unroll
+    for (int i = 0; i < 4; i++) {
+      t1.u = inp1[i];
+      t2.u = inp2[i];
+      res.f = t1.f + t2.f;
+      ret[i] = res.u;
+    }
+    return ret;
+  } else if constexpr (std::is_same<T, __hip_bfloat16>::value) {
+  #pragma unroll
+    for (int i = 0; i < 4; i++) {
+      t1.u = inp1[i];
+      t2.u = inp2[i];
+      res.b = t1.b + t2.b;
+      ret[i] = res.u;
+    }
+    return ret;
+  } else {
+    static_assert(false, "unsupported 16b dtype");
+  }
+}
+
 ///////////////////////////////////////
 
 // grid (num_seqs, num_partitions,num_heads/gqa_ratio)
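The hunk above is the heart of the bf16 enablement: Q, K, and V fragments are now carried as dtype-agnostic 16-bit lanes (_B16x4/_B16x8, i.e. uint16_t vectors) and only interpreted as _Float16 or __hip_bfloat16 where arithmetic actually happens (gcn_mfma_instr, to_float/from_float, from_floatx4, addx4), selected at compile time with if constexpr. A small host-side sketch of the same bit-carrier idea, using torch purely for illustration:

```python
# Sketch of the _B16x4/addx4 idea: 16-bit values travel as raw bit patterns
# and are only reinterpreted as fp16/bf16 (or widened to fp32) for arithmetic.
import torch

a = torch.tensor([1.5, -2.25, 0.125, 3.0], dtype=torch.bfloat16)
b = torch.tensor([0.5, 0.25, -1.0, 2.0], dtype=torch.bfloat16)

a_bits = a.view(torch.int16)  # raw 16-bit lanes, analogous to _B16x4
b_bits = b.view(torch.int16)

# Adding the bit patterns themselves would be meaningless; reinterpret back
# to bf16 and widen to fp32 first, as addx4/to_float do in the kernel.
summed = (a_bits.view(torch.bfloat16).float() +
          b_bits.view(torch.bfloat16).float()).to(torch.bfloat16)
print(summed)  # tensor([ 2.0000, -2.0000, -0.8750,  5.0000], dtype=torch.bfloat16)
```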
@@ -168,16 +273,16 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel(
   constexpr int GQA_RATIO4 = 4 * QHLOOP;
   __shared__ float shared_qk_max[NWARPS][GQA_RATIO4 + 1];
   __shared__ float shared_exp_sum[NWARPS][GQA_RATIO4 + 1];
-  _Half8 Qlocal[QHLOOP];
+  _B16x8 Qlocal[QHLOOP];
   constexpr int x = 16 / sizeof(scalar_t);
   constexpr int KHELOOP = HEAD_SIZE / x;
-  _Half8 Klocal[KHELOOP];
+  _B16x8 Klocal[KHELOOP];
   constexpr int VHELOOP =
       HEAD_SIZE /
      WARP_SIZE;  // v head_size dimension is distributed across lanes
   constexpr int VTLOOP = 8;  // 16 separate 4xtokens across warp -> 16/2
                              // 8xtokens
-  _Half8 Vlocal[VHELOOP][VTLOOP];
+  _B16x8 Vlocal[VHELOOP][VTLOOP];
   floatx4 dout[QHLOOP];
   float qk_max[QHLOOP];
 #pragma unroll
@@ -211,16 +316,28 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel(
   const int block_idx = (global_token_idx < context_len)
                             ? global_token_idx / BLOCK_SIZE
                             : last_ctx_block;
-
+
   // fetch block number for q and k
   // int32 physical_block_number leads to overflow when multiplied with
   // kv_block_stride
   const int64_t physical_block_number =
       static_cast<int64_t>(block_table[block_idx]);
+  // fetch vphysical block numbers up front
+  constexpr int VBLOCKS = 8 * VTLOOP / BLOCK_SIZE;
+  int vphysical_blocks[VBLOCKS];
+
+  const int warp_start_block_idx = warp_start_token_idx / BLOCK_SIZE;
+  #pragma unroll
+  for (int b = 0; b < VBLOCKS; b++) {
+    const int vblock_idx = warp_start_block_idx + b;
+    const int vblock_idx_ctx =
+        (vblock_idx <= last_ctx_block) ? vblock_idx : last_ctx_block;
+    vphysical_blocks[b] = block_table[vblock_idx_ctx];
+  }
   // each 4 lanes fetch 8 helems, so warp fetches 8*16 = 128 helems
   const scalar_t* q_ptr =
       q + seq_idx * q_stride + wg_start_head_idx * HEAD_SIZE;
-  const _Half8* q_ptrh8 = reinterpret_cast<const _Half8*>(q_ptr);
+  const _B16x8* q_ptrh8 = reinterpret_cast<const _B16x8*>(q_ptr);
   const int qhead_elemh8 = laneid / 4;
 #pragma unroll
   for (int h = 0; h < QHLOOP - 1; h++) {
@@ -238,12 +355,12 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel(
 
   const scalar_t* k_ptr = k_cache + physical_block_number * kv_block_stride +
                           wg_start_kv_head_idx * kv_head_stride;
-  const _Half8* k_ptrh8 = reinterpret_cast<const _Half8*>(k_ptr);
 
   const int physical_block_offset =
       local_token_idx % BLOCK_SIZE;  // since x=half8, physical_block_offset
                                      // is already cast as _H8
+  const _B16x8* k_ptrh8 = reinterpret_cast<const _B16x8*>(k_ptr);
 #pragma unroll
   for (int d = 0; d < KHELOOP; d++) {
     Klocal[d] = k_ptrh8[d * BLOCK_SIZE + physical_block_offset];
   }
@@ -260,21 +377,8 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel(
     }
   }
 
-  constexpr int VBLOCKS = 8 * VTLOOP / BLOCK_SIZE;
-  int vphysical_blocks[VBLOCKS];
-
-  const int warp_start_block_idx = warp_start_token_idx / BLOCK_SIZE;
-  // fetch vphysical block numbers
-  #pragma unroll
-  for (int b = 0; b < VBLOCKS; b++) {
-    const int vblock_idx = warp_start_block_idx + b;
-    const int vblock_idx_ctx =
-        (vblock_idx <= last_ctx_block) ? vblock_idx : last_ctx_block;
-    vphysical_blocks[b] = block_table[vblock_idx_ctx];
-  }
-
   const scalar_t* v_ptr = v_cache + wg_start_kv_head_idx * kv_head_stride;
-  const _Half8* v_ptrh8 = reinterpret_cast<const _Half8*>(v_ptr);
+  const _B16x8* v_ptrh8 = reinterpret_cast<const _B16x8*>(v_ptr);
   // iterate over each v block
 #pragma unroll
   for (int b = 0; b < VBLOCKS; b++) {
@@ -282,13 +386,13 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel(
     // kv_block_stride
     const int64_t vphysical_block_number =
         static_cast<int64_t>(vphysical_blocks[b]);
-    const _Half8* v_ptrh8b =
+    const _B16x8* v_ptrh8b =
         v_ptrh8 + (vphysical_block_number * kv_block_stride) / 8;
     // iterate over each head elem (within head_size)
 #pragma unroll
     for (int h = 0; h < VHELOOP; h++) {
       const int head_size_elem = h * WARP_SIZE + laneid;
-      const _Half8* v_ptrh8be = v_ptrh8b + head_size_elem * BLOCK_SIZE / 8;
+      const _B16x8* v_ptrh8be = v_ptrh8b + head_size_elem * BLOCK_SIZE / 8;
       // iterate over all velems within block
 #pragma unroll
       for (int d = 0; d < BLOCK_SIZE / 8; d++) {
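Two things to note in the hunks above: the per-warp V physical block numbers are now gathered once, up front, next to the K block lookup (the later copy of that loop is deleted), and block numbers are still widened to int64 before being multiplied by kv_block_stride, since with caches of a million-plus blocks the product overflows int32. A quick illustration of that overflow (the stride value is an assumed example, e.g. 8 KV heads x head size 128 x block size 16):

```python
# Why physical_block_number is cast to int64 before multiplying by the stride.
import numpy as np

block_number = np.array([1_000_000], dtype=np.int32)
kv_block_stride = 16_384  # assumed example: 8 kv heads * 128 head size * 16 tokens

bad = block_number * np.int32(kv_block_stride)          # silently wraps in int32
good = block_number.astype(np.int64) * kv_block_stride  # correct in int64
print(int(bad[0]), int(good[0]))  # -795869184 vs 16384000000
```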
@@ -299,71 +403,71 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel(
 
 #pragma unroll
   for (int h = 0; h < QHLOOP; h++) {
-    dout[h] =
-        GCN_MFMA_INSTR(Qlocal[h].xy[0], Klocal[0].xy[0], dout[h], 4, 0, 0);
-    dout[h] =
-        GCN_MFMA_INSTR(Qlocal[h].xy[1], Klocal[0].xy[1], dout[h], 4, 0, 0);
-    dout[h] =
-        GCN_MFMA_INSTR(Qlocal[h].xy[0], Klocal[1].xy[0], dout[h], 4, 1, 0);
-    dout[h] =
-        GCN_MFMA_INSTR(Qlocal[h].xy[1], Klocal[1].xy[1], dout[h], 4, 1, 0);
-    dout[h] =
-        GCN_MFMA_INSTR(Qlocal[h].xy[0], Klocal[2].xy[0], dout[h], 4, 2, 0);
-    dout[h] =
-        GCN_MFMA_INSTR(Qlocal[h].xy[1], Klocal[2].xy[1], dout[h], 4, 2, 0);
-    dout[h] =
-        GCN_MFMA_INSTR(Qlocal[h].xy[0], Klocal[3].xy[0], dout[h], 4, 3, 0);
-    dout[h] =
-        GCN_MFMA_INSTR(Qlocal[h].xy[1], Klocal[3].xy[1], dout[h], 4, 3, 0);
-    dout[h] =
-        GCN_MFMA_INSTR(Qlocal[h].xy[0], Klocal[4].xy[0], dout[h], 4, 4, 0);
-    dout[h] =
-        GCN_MFMA_INSTR(Qlocal[h].xy[1], Klocal[4].xy[1], dout[h], 4, 4, 0);
-    dout[h] =
-        GCN_MFMA_INSTR(Qlocal[h].xy[0], Klocal[5].xy[0], dout[h], 4, 5, 0);
-    dout[h] =
-        GCN_MFMA_INSTR(Qlocal[h].xy[1], Klocal[5].xy[1], dout[h], 4, 5, 0);
-    dout[h] =
-        GCN_MFMA_INSTR(Qlocal[h].xy[0], Klocal[6].xy[0], dout[h], 4, 6, 0);
-    dout[h] =
-        GCN_MFMA_INSTR(Qlocal[h].xy[1], Klocal[6].xy[1], dout[h], 4, 6, 0);
-    dout[h] =
-        GCN_MFMA_INSTR(Qlocal[h].xy[0], Klocal[7].xy[0], dout[h], 4, 7, 0);
-    dout[h] =
-        GCN_MFMA_INSTR(Qlocal[h].xy[1], Klocal[7].xy[1], dout[h], 4, 7, 0);
+    dout[h] = gcn_mfma_instr<scalar_t, 4, 0, 0>(Qlocal[h].xy[0],
+                                                Klocal[0].xy[0], dout[h]);
+    dout[h] = gcn_mfma_instr<scalar_t, 4, 0, 0>(Qlocal[h].xy[1],
+                                                Klocal[0].xy[1], dout[h]);
+    dout[h] = gcn_mfma_instr<scalar_t, 4, 1, 0>(Qlocal[h].xy[0],
+                                                Klocal[1].xy[0], dout[h]);
+    dout[h] = gcn_mfma_instr<scalar_t, 4, 1, 0>(Qlocal[h].xy[1],
+                                                Klocal[1].xy[1], dout[h]);
+    dout[h] = gcn_mfma_instr<scalar_t, 4, 2, 0>(Qlocal[h].xy[0],
+                                                Klocal[2].xy[0], dout[h]);
+    dout[h] = gcn_mfma_instr<scalar_t, 4, 2, 0>(Qlocal[h].xy[1],
+                                                Klocal[2].xy[1], dout[h]);
+    dout[h] = gcn_mfma_instr<scalar_t, 4, 3, 0>(Qlocal[h].xy[0],
+                                                Klocal[3].xy[0], dout[h]);
+    dout[h] = gcn_mfma_instr<scalar_t, 4, 3, 0>(Qlocal[h].xy[1],
+                                                Klocal[3].xy[1], dout[h]);
+    dout[h] = gcn_mfma_instr<scalar_t, 4, 4, 0>(Qlocal[h].xy[0],
+                                                Klocal[4].xy[0], dout[h]);
+    dout[h] = gcn_mfma_instr<scalar_t, 4, 4, 0>(Qlocal[h].xy[1],
+                                                Klocal[4].xy[1], dout[h]);
+    dout[h] = gcn_mfma_instr<scalar_t, 4, 5, 0>(Qlocal[h].xy[0],
+                                                Klocal[5].xy[0], dout[h]);
+    dout[h] = gcn_mfma_instr<scalar_t, 4, 5, 0>(Qlocal[h].xy[1],
+                                                Klocal[5].xy[1], dout[h]);
+    dout[h] = gcn_mfma_instr<scalar_t, 4, 6, 0>(Qlocal[h].xy[0],
+                                                Klocal[6].xy[0], dout[h]);
+    dout[h] = gcn_mfma_instr<scalar_t, 4, 6, 0>(Qlocal[h].xy[1],
+                                                Klocal[6].xy[1], dout[h]);
+    dout[h] = gcn_mfma_instr<scalar_t, 4, 7, 0>(Qlocal[h].xy[0],
+                                                Klocal[7].xy[0], dout[h]);
+    dout[h] = gcn_mfma_instr<scalar_t, 4, 7, 0>(Qlocal[h].xy[1],
+                                                Klocal[7].xy[1], dout[h]);
     if constexpr (KHELOOP > 8) {
-      dout[h] =
-          GCN_MFMA_INSTR(Qlocal[h].xy[0], Klocal[8].xy[0], dout[h], 4, 8, 0);
-      dout[h] =
-          GCN_MFMA_INSTR(Qlocal[h].xy[1], Klocal[8].xy[1], dout[h], 4, 8, 0);
-      dout[h] =
-          GCN_MFMA_INSTR(Qlocal[h].xy[0], Klocal[9].xy[0], dout[h], 4, 9, 0);
-      dout[h] =
-          GCN_MFMA_INSTR(Qlocal[h].xy[1], Klocal[9].xy[1], dout[h], 4, 9, 0);
-      dout[h] = GCN_MFMA_INSTR(Qlocal[h].xy[0], Klocal[10].xy[0], dout[h], 4,
-                               10, 0);
-      dout[h] = GCN_MFMA_INSTR(Qlocal[h].xy[1], Klocal[10].xy[1], dout[h], 4,
-                               10, 0);
-      dout[h] = GCN_MFMA_INSTR(Qlocal[h].xy[0], Klocal[11].xy[0], dout[h], 4,
-                               11, 0);
-      dout[h] = GCN_MFMA_INSTR(Qlocal[h].xy[1], Klocal[11].xy[1], dout[h], 4,
-                               11, 0);
-      dout[h] = GCN_MFMA_INSTR(Qlocal[h].xy[0], Klocal[12].xy[0], dout[h], 4,
-                               12, 0);
-      dout[h] = GCN_MFMA_INSTR(Qlocal[h].xy[1], Klocal[12].xy[1], dout[h], 4,
-                               12, 0);
-      dout[h] = GCN_MFMA_INSTR(Qlocal[h].xy[0], Klocal[13].xy[0], dout[h], 4,
-                               13, 0);
-      dout[h] = GCN_MFMA_INSTR(Qlocal[h].xy[1], Klocal[13].xy[1], dout[h], 4,
-                               13, 0);
-      dout[h] = GCN_MFMA_INSTR(Qlocal[h].xy[0], Klocal[14].xy[0], dout[h], 4,
-                               14, 0);
-      dout[h] = GCN_MFMA_INSTR(Qlocal[h].xy[1], Klocal[14].xy[1], dout[h], 4,
-                               14, 0);
-      dout[h] = GCN_MFMA_INSTR(Qlocal[h].xy[0], Klocal[15].xy[0], dout[h], 4,
-                               15, 0);
-      dout[h] = GCN_MFMA_INSTR(Qlocal[h].xy[1], Klocal[15].xy[1], dout[h], 4,
-                               15, 0);
+      dout[h] = gcn_mfma_instr<scalar_t, 4, 8, 0>(Qlocal[h].xy[0],
+                                                  Klocal[8].xy[0], dout[h]);
+      dout[h] = gcn_mfma_instr<scalar_t, 4, 8, 0>(Qlocal[h].xy[1],
+                                                  Klocal[8].xy[1], dout[h]);
+      dout[h] = gcn_mfma_instr<scalar_t, 4, 9, 0>(Qlocal[h].xy[0],
+                                                  Klocal[9].xy[0], dout[h]);
+      dout[h] = gcn_mfma_instr<scalar_t, 4, 9, 0>(Qlocal[h].xy[1],
+                                                  Klocal[9].xy[1], dout[h]);
+      dout[h] = gcn_mfma_instr<scalar_t, 4, 10, 0>(Qlocal[h].xy[0],
+                                                   Klocal[10].xy[0], dout[h]);
+      dout[h] = gcn_mfma_instr<scalar_t, 4, 10, 0>(Qlocal[h].xy[1],
+                                                   Klocal[10].xy[1], dout[h]);
+      dout[h] = gcn_mfma_instr<scalar_t, 4, 11, 0>(Qlocal[h].xy[0],
+                                                   Klocal[11].xy[0], dout[h]);
+      dout[h] = gcn_mfma_instr<scalar_t, 4, 11, 0>(Qlocal[h].xy[1],
+                                                   Klocal[11].xy[1], dout[h]);
+      dout[h] = gcn_mfma_instr<scalar_t, 4, 12, 0>(Qlocal[h].xy[0],
+                                                   Klocal[12].xy[0], dout[h]);
+      dout[h] = gcn_mfma_instr<scalar_t, 4, 12, 0>(Qlocal[h].xy[1],
+                                                   Klocal[12].xy[1], dout[h]);
+      dout[h] = gcn_mfma_instr<scalar_t, 4, 13, 0>(Qlocal[h].xy[0],
+                                                   Klocal[13].xy[0], dout[h]);
+      dout[h] = gcn_mfma_instr<scalar_t, 4, 13, 0>(Qlocal[h].xy[1],
+                                                   Klocal[13].xy[1], dout[h]);
+      dout[h] = gcn_mfma_instr<scalar_t, 4, 14, 0>(Qlocal[h].xy[0],
+                                                   Klocal[14].xy[0], dout[h]);
+      dout[h] = gcn_mfma_instr<scalar_t, 4, 14, 0>(Qlocal[h].xy[1],
+                                                   Klocal[14].xy[1], dout[h]);
+      dout[h] = gcn_mfma_instr<scalar_t, 4, 15, 0>(Qlocal[h].xy[0],
+                                                   Klocal[15].xy[0], dout[h]);
+      dout[h] = gcn_mfma_instr<scalar_t, 4, 15, 0>(Qlocal[h].xy[1],
+                                                   Klocal[15].xy[1], dout[h]);
     }  // KHELOOP>8
     dout[h] *= scale;
   }
@@ -469,16 +573,13 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel(
   }
   // logits[h] -> every 4 lanes hold 4 heads, each lane holds 4 tokens, there
   // are 4x16 tokens across warp
-  float16x4 logits[QHLOOP];
+  _B16x4 logits[QHLOOP];
 #pragma unroll
   for (int h = 0; h < QHLOOP; h++) {
-  #pragma unroll
-    for (int i = 0; i < 4; i++) {
-      logits[h][i] = (scalar_t)dout[h][i];
-    }
+    logits[h] = from_floatx4<scalar_t>(dout[h]);
   }
 
-  __shared__ float16x4 vout_shared[QHLOOP][VHELOOP][WARP_SIZE][NWARPS + 1];
+  __shared__ _B16x4 vout_shared[QHLOOP][VHELOOP][WARP_SIZE][NWARPS + 1];
 
   if (warp_start_token_idx >= context_len) {  // warp out of context
 #pragma unroll
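The rewrite above is mechanical: each GCN_MFMA_INSTR(A, B, C, 4, n, 0) macro call becomes gcn_mfma_instr<scalar_t, 4, n, 0>(A, B, C), so the same unrolled code emits the fp16 MFMA for _Float16 and the bf16_1k MFMA for __hip_bfloat16, with fp32 accumulation either way. Functionally this stage is just a scaled Q.K dot product per cached token; a rough reference of what it computes (shapes are illustrative and the MFMA lane layout is ignored):

```python
# Functional reference for the QK stage: dot each cached key token against the
# query head and scale, accumulating in fp32 as the f32_4x4x4{f16,bf16} MFMAs do.
import numpy as np

head_size, num_tokens = 128, 16
scale = 1.0 / np.sqrt(head_size)
q = np.random.randn(head_size).astype(np.float16)              # one query head
k = np.random.randn(num_tokens, head_size).astype(np.float16)  # cached keys

qk = (k.astype(np.float32) @ q.astype(np.float32)) * scale     # [num_tokens], fp32
print(qk)
```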
@@ -497,28 +598,39 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel(
       for (int vh = 0; vh < VHELOOP; vh++) {
         floatx4 acc = {0};
         // iterate over tokens
-        acc = GCN_MFMA_INSTR(logits[qh], Vlocal[vh][0].xy[0], acc, 4, 0, 0);
-        acc = GCN_MFMA_INSTR(logits[qh], Vlocal[vh][0].xy[1], acc, 4, 1, 0);
-        acc = GCN_MFMA_INSTR(logits[qh], Vlocal[vh][1].xy[0], acc, 4, 2, 0);
-        acc = GCN_MFMA_INSTR(logits[qh], Vlocal[vh][1].xy[1], acc, 4, 3, 0);
-        acc = GCN_MFMA_INSTR(logits[qh], Vlocal[vh][2].xy[0], acc, 4, 4, 0);
-        acc = GCN_MFMA_INSTR(logits[qh], Vlocal[vh][2].xy[1], acc, 4, 5, 0);
-        acc = GCN_MFMA_INSTR(logits[qh], Vlocal[vh][3].xy[0], acc, 4, 6, 0);
-        acc = GCN_MFMA_INSTR(logits[qh], Vlocal[vh][3].xy[1], acc, 4, 7, 0);
-        acc = GCN_MFMA_INSTR(logits[qh], Vlocal[vh][4].xy[0], acc, 4, 8, 0);
-        acc = GCN_MFMA_INSTR(logits[qh], Vlocal[vh][4].xy[1], acc, 4, 9, 0);
-        acc = GCN_MFMA_INSTR(logits[qh], Vlocal[vh][5].xy[0], acc, 4, 10, 0);
-        acc = GCN_MFMA_INSTR(logits[qh], Vlocal[vh][5].xy[1], acc, 4, 11, 0);
-        acc = GCN_MFMA_INSTR(logits[qh], Vlocal[vh][6].xy[0], acc, 4, 12, 0);
-        acc = GCN_MFMA_INSTR(logits[qh], Vlocal[vh][6].xy[1], acc, 4, 13, 0);
-        acc = GCN_MFMA_INSTR(logits[qh], Vlocal[vh][7].xy[0], acc, 4, 14, 0);
-        acc = GCN_MFMA_INSTR(logits[qh], Vlocal[vh][7].xy[1], acc, 4, 15, 0);
-        float16x4 tmp;
-  #pragma unroll
-        for (int i = 0; i < 4; i++) {
-          tmp[i] = (scalar_t)acc[i];
-        }
-        vout_shared[qh][vh][laneid][warpid] = tmp;
+        acc = gcn_mfma_instr<scalar_t, 4, 0, 0>(logits[qh], Vlocal[vh][0].xy[0],
+                                                acc);
+        acc = gcn_mfma_instr<scalar_t, 4, 1, 0>(logits[qh], Vlocal[vh][0].xy[1],
+                                                acc);
+        acc = gcn_mfma_instr<scalar_t, 4, 2, 0>(logits[qh], Vlocal[vh][1].xy[0],
+                                                acc);
+        acc = gcn_mfma_instr<scalar_t, 4, 3, 0>(logits[qh], Vlocal[vh][1].xy[1],
+                                                acc);
+        acc = gcn_mfma_instr<scalar_t, 4, 4, 0>(logits[qh], Vlocal[vh][2].xy[0],
+                                                acc);
+        acc = gcn_mfma_instr<scalar_t, 4, 5, 0>(logits[qh], Vlocal[vh][2].xy[1],
+                                                acc);
+        acc = gcn_mfma_instr<scalar_t, 4, 6, 0>(logits[qh], Vlocal[vh][3].xy[0],
+                                                acc);
+        acc = gcn_mfma_instr<scalar_t, 4, 7, 0>(logits[qh], Vlocal[vh][3].xy[1],
+                                                acc);
+        acc = gcn_mfma_instr<scalar_t, 4, 8, 0>(logits[qh], Vlocal[vh][4].xy[0],
+                                                acc);
+        acc = gcn_mfma_instr<scalar_t, 4, 9, 0>(logits[qh], Vlocal[vh][4].xy[1],
+                                                acc);
+        acc = gcn_mfma_instr<scalar_t, 4, 10, 0>(logits[qh],
+                                                 Vlocal[vh][5].xy[0], acc);
+        acc = gcn_mfma_instr<scalar_t, 4, 11, 0>(logits[qh],
+                                                 Vlocal[vh][5].xy[1], acc);
+        acc = gcn_mfma_instr<scalar_t, 4, 12, 0>(logits[qh],
+                                                 Vlocal[vh][6].xy[0], acc);
+        acc = gcn_mfma_instr<scalar_t, 4, 13, 0>(logits[qh],
+                                                 Vlocal[vh][6].xy[1], acc);
+        acc = gcn_mfma_instr<scalar_t, 4, 14, 0>(logits[qh],
+                                                 Vlocal[vh][7].xy[0], acc);
+        acc = gcn_mfma_instr<scalar_t, 4, 15, 0>(logits[qh],
+                                                 Vlocal[vh][7].xy[1], acc);
+        vout_shared[qh][vh][laneid][warpid] = from_floatx4<scalar_t>(acc);
       }
     }
   }  // warp in context
@@ -526,7 +638,7 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel(
   __syncthreads();
 
   if (warpid == 0) {
-    float16x4 vout[QHLOOP][VHELOOP];
+    _B16x4 vout[QHLOOP][VHELOOP];
     // iterate across heads
     scalar_t* out_ptr;
     int out_num_partitions;
@@ -546,18 +658,18 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel(
         vout[qh][vh] = {0};
 #pragma unroll
         for (int w = 0; w < NWARPS; w++) {
-          vout[qh][vh] += vout_shared[qh][vh][laneid][w];
+          vout[qh][vh] =
+              addx4<scalar_t>(vout[qh][vh], vout_shared[qh][vh][laneid][w]);
         }
         const int head_size_elem = vh * WARP_SIZE + laneid;
+        bit16_t* out_ptr_b16 = reinterpret_cast<bit16_t*>(out_ptr);
 #pragma unroll
         for (int i = 0; i < 4; i++) {
           const int head_idx = 4 * qh + i;
           if (head_idx < GQA_RATIO) {
-            // out_ptr[(wg_start_head_idx + head_idx) * max_num_partitions *
-            // HEAD_SIZE + head_size_elem] = vout[qh][vh][i];
-            out_ptr[(wg_start_head_idx + head_idx) * out_num_partitions *
-                        HEAD_SIZE +
-                    head_size_elem] = vout[qh][vh][i];
+            out_ptr_b16[(wg_start_head_idx + head_idx) * out_num_partitions *
+                            HEAD_SIZE +
+                        head_size_elem] = vout[qh][vh][i];
           }
         }
       }
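After the softmax, the logits are packed back into 16-bit lanes with from_floatx4 and multiplied against V with fp32 accumulators; each warp's partial result is then combined across warps with addx4, and the final lanes are written through a bit16_t view of the output pointer so the same store path serves fp16 and bf16. A functional reference for this half of the kernel (softmax over the QK row, then a probability-weighted sum of V, narrowed only at the end):

```python
# Continues the QK reference above: softmax over the token scores, then a
# probability-weighted sum of the cached V vectors; fp32 accumulation with a
# single final cast to the 16-bit output dtype (the from_floatx4 step).
import numpy as np

def ref_decode_attention(qk: np.ndarray, v: np.ndarray, out_dtype=np.float16):
    p = np.exp(qk - qk.max())
    p /= p.sum()                                        # softmax over tokens
    out = p.astype(np.float32) @ v.astype(np.float32)   # [head_size] in fp32
    return out.astype(out_dtype)                        # single narrowing cast
```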
@@ -663,9 +775,10 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
       head_idx * max_num_partitions * HEAD_SIZE + threadIdx.x;
   constexpr int MAX_NPAR = 64;
   scalar_t tmps[MAX_NPAR];
+  const float dzero = 0.0f;
 #pragma unroll
   for (int j = 0; j < MAX_NPAR; j++) {
-    tmps[j] = 0.0f;
+    tmps[j] = from_float<scalar_t>(dzero);
   }
   const int last_partition_offset = (num_partitions - 1) * HEAD_SIZE;
   const int num_partition_offset = (num_partitions)*HEAD_SIZE;
@@ -709,17 +822,17 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
   float acc = 0.0f;
 #pragma unroll
   for (int j = 0; j < JCHUNK; j++) {
-    acc += tmps[j] * shared_exp_sums[j];
+    acc += to_float<scalar_t>(tmps[j]) * shared_exp_sums[j];
   }
   if (num_partitions > JCHUNK) {
 #pragma unroll
     for (int j = JCHUNK; j < 2 * JCHUNK; j++) {
-      acc += tmps[j] * shared_exp_sums[j];
+      acc += to_float<scalar_t>(tmps[j]) * shared_exp_sums[j];
     }
     if (num_partitions > 2 * JCHUNK) {
 #pragma unroll
       for (int j = 2 * JCHUNK; j < MAX_NPAR; j++) {
-        acc += tmps[j] * shared_exp_sums[j];
+        acc += to_float<scalar_t>(tmps[j]) * shared_exp_sums[j];
       }
     }
   }
@@ -738,17 +851,16 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
 
 #pragma unroll
     for (int j = 0; j < MAX_NPAR; j++) {
-      acc += tmps[j] * shared_exp_sums[j + MAX_NPAR];
+      acc += to_float<scalar_t>(tmps[j]) * shared_exp_sums[j + MAX_NPAR];
     }
   }
 
   const float inv_global_exp_sum =
       __fdividef(1.0f, shared_global_exp_sum + 1e-6f);
   acc *= inv_global_exp_sum;
-  // from_float(out_ptr[threadIdx.x], acc);
   scalar_t* out_ptr =
       out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
-  out_ptr[threadIdx.x] = (scalar_t)acc;
+  out_ptr[threadIdx.x] = from_float<scalar_t>(acc);
 }
 
 #else  // !defined(__HIP__MI300__) TODO: Add NAVI support
@@ -941,9 +1053,6 @@ void paged_attention_custom_launcher(
 
 #define CALL_CUSTOM_LAUNCHER_BLK(T, HEAD_SIZE)    \
   switch (block_size) {                           \
-    case 8:                                       \
-      CALL_CUSTOM_LAUNCHER(T, 8, HEAD_SIZE);      \
-      break;                                      \
     case 16:                                      \
       CALL_CUSTOM_LAUNCHER(T, 16, HEAD_SIZE);     \
       break;                                      \
@@ -989,9 +1098,12 @@ void paged_attention_custom(
 #endif
     const c10::optional<torch::Tensor>& alibi_slopes,
     const std::string& kv_cache_dtype) {
+  assert(kv_cache_dtype == "auto");
   const int head_size = query.size(2);
   if (query.dtype() == at::ScalarType::Half) {
     CALL_CUSTOM_LAUNCHER_BLK_HEAD(_Float16);
+  } else if (query.dtype() == at::ScalarType::BFloat16) {
+    CALL_CUSTOM_LAUNCHER_BLK_HEAD(__hip_bfloat16);
   } else {
     TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
   }
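In the reduce kernel above, the per-partition partials stay in the 16-bit scalar type (tmps) but are widened with to_float before the exp-sum-weighted accumulation and narrowed once with from_float at the store, so the running sum lives in fp32. That matters more for bf16 than for fp16, which is one reason the test below uses looser tolerances for bf16. A small sketch of the difference, with random values purely for illustration:

```python
# fp32 accumulation of 16-bit partials (what the reduce kernel does) versus a
# naive running sum kept in bf16.
import torch

parts = torch.rand(64)     # per-partition partial outputs
weights = torch.rand(64)   # per-partition exp sums

acc_fp32 = (parts.to(torch.bfloat16).float() * weights).sum()

acc_bf16 = torch.zeros((), dtype=torch.bfloat16)
for p, w in zip(parts.to(torch.bfloat16), weights.to(torch.bfloat16)):
    acc_bf16 = acc_bf16 + p * w   # rounding error compounds in the bf16 sum

print(acc_fp32.item(), acc_bf16.item())
```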
diff --git a/tests/kernels/test_attention_custom.py b/tests/kernels/test_attention_custom.py
index d9b53ed7bd0d9..6ecc348e017e9 100644
--- a/tests/kernels/test_attention_custom.py
+++ b/tests/kernels/test_attention_custom.py
@@ -4,35 +4,28 @@
 import pytest
 import torch
 
-from vllm._C import cache_ops, ops
+from vllm import _custom_ops as ops
 from vllm._custom_C import paged_attention_custom
-from vllm.utils import get_max_shared_memory_bytes, is_hip
+from vllm.utils import is_hip
 
 from .allclose_default import get_default_atol, get_default_rtol
 
-FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
-# This will change depending on the compute capability.
-# - 512 as a buffer
-MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512
+MAX_SEQ_LEN = 32 * 1024
 # There may not be enough gpu memory due to large NUM_BLOCKS.
 # Reduce NUM_BLOCKS when it happens.
-NUM_BLOCKS = 4321  # Arbitrary values for testing
+NUM_BLOCKS = 128 * 1024 + 4321  # Arbitrary values for testing
 PARTITION_SIZE = 256
-# flshattF and tritonflashattF supported: {torch.float16, torch.bfloat16}
-DTYPES = [torch.half, torch.bfloat16, torch.float
-          ] if not is_hip() else [torch.half]
-NUM_GEN_SEQS = [1, 17, 64]  # Arbitrary values for testing
+DTYPES = [torch.bfloat16, torch.half]
+NUM_GEN_SEQS = [1, 17]  # Arbitrary values for testing
 NUM_HEADS = [(8 * x, 8) for x in range(1, 17)]  # Arbitrary values for testing
-# FlashAttention forward only supports head dimension at most 128
-# https://github.com/ROCmSoftwarePlatform/flash-attention/blob/3d2b6f5d037782cc2c906909a46fb7e2e1b48b25/csrc/flash_attn_rocm/flash_api.cpp#L62
-HEAD_SIZES = [128]
-BLOCK_SIZES = [16]
-USE_ALIBI = [False, True]
+HEAD_SIZES = [64, 128]
+BLOCK_SIZES = [16, 32]
+USE_ALIBI = [True, False]
 KV_CACHE_DTYPE = ["auto"]
-SEEDS = [0]
+SEEDS = [37]
 CUDA_DEVICES = [
-    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
+    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 1)
 ]
@@ -255,14 +248,14 @@ def test_paged_attention(
         dequantized_key_cache = torch.empty(size=key_cache_shape,
                                             dtype=dtype,
                                             device=device)
-        cache_ops.convert_fp8(key_cache, dequantized_key_cache)
+        ops.convert_fp8(key_cache, dequantized_key_cache)
         key_cache = dequantized_key_cache
 
         value_cache_shape = value_cache.shape
        dequantized_value_cache = torch.empty(size=value_cache_shape,
                                               dtype=dtype,
                                               device=device)
-        cache_ops.convert_fp8(value_cache, dequantized_value_cache)
+        ops.convert_fp8(value_cache, dequantized_value_cache)
         value_cache = dequantized_value_cache
 
     ref_output = torch.empty_like(query)
@@ -286,8 +279,14 @@ def test_paged_attention(
 
     # NOTE(zhaoyang): FP8 KV Cache will introduce quantization error,
     # so we use a relaxed tolerance for the test.
-    atol, rtol = 1e-3, 1e-5
-    atol = 5e-3
+    atol, rtol = 1e-4, 1e-5
+    if dtype == torch.bfloat16:
+        atol, rtol = 2e-4, 1e-5
+    if use_alibi:
+        if dtype == torch.half:
+            atol, rtol = 5e-4, 1e-5
+        if dtype == torch.bfloat16:
+            atol, rtol = 1e-3, 1e-5
     if kv_cache_dtype == "fp8":
         atol, rtol = 1e-2, 1e-5
     assert torch.allclose(output, ref_output, atol=atol, rtol=rtol)
diff --git a/vllm/attention/ops/paged_attn.py b/vllm/attention/ops/paged_attn.py
index ee2e83f6b272c..d78aa975f61ff 100644
--- a/vllm/attention/ops/paged_attn.py
+++ b/vllm/attention/ops/paged_attn.py
@@ -117,8 +117,11 @@ def forward_decode(
         block_size = value_cache.shape[3]
         num_seqs, num_heads, head_size = query.shape
         gqa_ratio = num_heads // num_kv_heads
-        use_custom = (custom_attn_available and query.dtype == torch.half
-                      and head_size == 128 and block_size == 16
+        use_custom = (custom_attn_available
+                      and (query.dtype == torch.half
+                           or query.dtype == torch.bfloat16)
+                      and (head_size == 128 or head_size == 64)
+                      and (block_size == 16 or block_size == 32)
                       and kv_cache_dtype == "auto"
                       and (gqa_ratio >= 1 and gqa_ratio <= 16)
                       and max_seq_len <= 32768)
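The forward_decode change is the user-visible summary of the whole patch: the custom ROCm kernel now also accepts bfloat16 queries, head sizes 64 and 128, and block sizes 16 and 32 (block-size-8 support was dropped from the launcher), still only for kv_cache_dtype == "auto", GQA ratios 1 through 16, and sequences up to 32k. Restated as a standalone predicate, a simplified sketch:

```python
# Simplified sketch of the updated gating; anything that fails the check falls
# back to the stock paged-attention decode path.
import torch

def use_custom_paged_attention(custom_attn_available: bool, dtype: torch.dtype,
                               head_size: int, block_size: int,
                               kv_cache_dtype: str, gqa_ratio: int,
                               max_seq_len: int) -> bool:
    return (custom_attn_available
            and dtype in (torch.half, torch.bfloat16)
            and head_size in (64, 128)
            and block_size in (16, 32)
            and kv_cache_dtype == "auto"
            and 1 <= gqa_ratio <= 16
            and max_seq_len <= 32768)
```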