From 08219dcedb2ffb2cfbd17fbe5acd79743653f866 Mon Sep 17 00:00:00 2001 From: Max Podkorytov <4273004+tenpercent@users.noreply.github.com> Date: Wed, 25 Sep 2024 21:37:35 +0000 Subject: [PATCH] rewrite hipified split-k decoder invocation to ck-tile style --- .../hip_decoder/attention_forward_splitk.cpp | 1080 ++--------------- .../ck_attention_forward_decoder_splitk.h | 919 +++++--------- 2 files changed, 440 insertions(+), 1559 deletions(-) diff --git a/xformers/csrc/attention/hip_decoder/attention_forward_splitk.cpp b/xformers/csrc/attention/hip_decoder/attention_forward_splitk.cpp index fd70436a36..2452204840 100644 --- a/xformers/csrc/attention/hip_decoder/attention_forward_splitk.cpp +++ b/xformers/csrc/attention/hip_decoder/attention_forward_splitk.cpp @@ -4,6 +4,10 @@ #include #include +#include +#include +#include + #include "ck_attention_forward_decoder_splitk.h" namespace { @@ -50,6 +54,40 @@ struct c10_to_data_t { namespace { +template +void instantiate_and_launch_kernels( + typename ck_tile::ForwardDecoderSplitKArgument arg, + dim3 attn_grid_size, + dim3 attn_block_size, + int32_t lds_bytes, + dim3 reduce_grid_size, + dim3 reduce_block_size, + hipStream_t stream) { + auto attn_kernel_impl = ck_tile::ForwardDecoderSplitKAttnKernelImpl< + ck_data_t, + vec_size, + kLoopUnroll, + kLoopUnrollTail, + kMaxKVSequenceLength, + compute_t>{}; + auto reduce_kernel_impl = ck_tile:: + ForwardDecoderSplitKReduceKernelImpl{}; + + (void)ck_tile::launch_kernel( + ck_tile::stream_config{stream, /* benchmark */ false}, + ck_tile::make_kernel( + attn_kernel_impl, attn_grid_size, attn_block_size, lds_bytes, arg)); + + (void)ck_tile::launch_kernel( + ck_tile::stream_config{stream, /* benchmark */ false}, + ck_tile::make_kernel( + reduce_kernel_impl, + reduce_grid_size, + reduce_block_size, + 0 /* lds_bytes */, + arg)); +} + template < int32_t ThreadsPerWavefront, int32_t WavefrontsPerBlock> @@ -58,8 +96,8 @@ at::Tensor& efficient_attention_forward_decoder_splitk_ck_out_impl( const at::Tensor& cache_K, // [B, kMaxKVSequenceLength, G, H or 1, D] const at::Tensor& cache_V, // [B, kMaxKVSequenceLength, G, H or 1, D] at::optional seq_kv_lens, // [B] - double qk_scale, - int64_t split_k, + float qk_scale, + int32_t split_k, at::Tensor& split_max, at::Tensor& split_sumexp, at::Tensor& split_O, @@ -83,19 +121,24 @@ at::Tensor& efficient_attention_forward_decoder_splitk_ck_out_impl( auto M = XQ.size(1); auto G = XQ.size(2); auto H = XQ.size(3); + auto HDim = XQ.size(4); TORCH_CHECK(B <= 1024); TORCH_CHECK(M <= 1024); TORCH_CHECK(H <= 1024); - dim3 blocks(B * H * M * G, split_k); - dim3 threads(ThreadsPerWavefront, WavefrontsPerBlock); + const dim3 attn_grid_size(B * H * M * G, split_k); + const dim3 attn_block_size(ThreadsPerWavefront, WavefrontsPerBlock); + + const dim3 reduce_grid_size = {attn_grid_size.x}; + const dim3 reduce_block_size = {attn_block_size.x}; int32_t smem_softmax = kMaxKVSequenceLength * sizeof(compute_t) + WavefrontsPerBlock * sizeof(compute_t); int32_t smem_output = kMaxHeadDimension * sizeof(compute_t) * - threads.y; // 4 * threadsPerBlock * sizeof(float) == sizeof(O[b][0][h][:]) - const size_t lds_bytes = max(smem_softmax, smem_output); + WavefrontsPerBlock; // 4 * threadsPerBlock * sizeof(float) == + // sizeof(O[b][0][h][:]) + const size_t attn_lds_bytes = max(smem_softmax, smem_output); auto stream = at::cuda::getCurrentHIPStream().stream(); AT_DISPATCH_SWITCH_3( @@ -106,14 +149,6 @@ at::Tensor& efficient_attention_forward_decoder_splitk_ck_out_impl( 
"efficient_attention_forward_decoder_splitk_ck", [&] { using ck_data_t = c10_to_data_t::type; - using device_op_t = - ck::tensor_operation::device::FMHADecoderSplitKDeviceOp< - ck_data_t, - kMaxKVSequenceLength, - kLoopUnroll, - kLoopUnrollTail, - compute_t>; - auto op = device_op_t{}; auto XQ_acc = XQ.packed_accessor32(); @@ -136,7 +171,7 @@ at::Tensor& efficient_attention_forward_decoder_splitk_ck_out_impl( auto split_sumexp_acc = split_sumexp .packed_accessor32(); - auto arg = device_op_t::Argument( + auto arg = ck_tile::ForwardDecoderSplitKArgument{ reinterpret_cast(XQ_acc.data()), reinterpret_cast(K_acc.data()), reinterpret_cast(V_acc.data()), @@ -154,20 +189,59 @@ at::Tensor& efficient_attention_forward_decoder_splitk_ck_out_impl( K_acc.stride(2), K_acc.stride(3), split_O_acc.stride(0), - XQ_acc.size(1), - XQ_acc.size(2), - XQ_acc.size(3), - XQ_acc.size(4), - K_acc.size(1), + static_cast(XQ_acc.size(1)), + static_cast(XQ_acc.size(2)), + static_cast(XQ_acc.size(3)), + static_cast(XQ_acc.size(4)), + static_cast(K_acc.size(1)), K_acc.size(3) == 1, qk_scale, - split_k, - blocks, - threads, - lds_bytes); + split_k}; - auto invoker = device_op_t::Invoker{}; - (void)invoker.Run(&arg, {stream}); + auto required_vec_size = 0; + + for (auto vec_size : {4, 2, 1}) { + if (arg.Q_size_k <= vec_size * ThreadsPerWavefront) { + required_vec_size = vec_size; + } + } + + TORCH_CHECK(required_vec_size > 0); + + switch (required_vec_size) { + case 4: + instantiate_and_launch_kernels( + arg, + attn_grid_size, + attn_block_size, + attn_lds_bytes, + reduce_grid_size, + reduce_block_size, + stream); + break; + case 2: + instantiate_and_launch_kernels( + arg, + attn_grid_size, + attn_block_size, + attn_lds_bytes, + reduce_grid_size, + reduce_block_size, + stream); + break; + case 1: + instantiate_and_launch_kernels( + arg, + attn_grid_size, + attn_block_size, + attn_lds_bytes, + reduce_grid_size, + reduce_block_size, + stream); + break; + default: + break; + } }); return O; @@ -179,8 +253,8 @@ at::Tensor efficient_attention_forward_decoder_splitk_ck_impl( const at::Tensor& cache_K, // [B, kMaxKVSequenceLength, G, H or 1, D] const at::Tensor& cache_V, // [B, kMaxKVSequenceLength, H or 1, D] at::optional seq_kv_lens, // [B] - double qk_scale, - int64_t split_k) { + float qk_scale, + int32_t split_k) { auto O = at::empty_like(XQ); constexpr auto rank = 5; @@ -226,7 +300,12 @@ at::Tensor efficient_attention_forward_decoder_splitk_ck( return efficient_attention_forward_decoder_splitk_ck_impl< kThreadsPerWavefront, kWavefrontsPerBlock>( - XQ, cache_K, cache_V, seq_kv_lens, qk_scale, split_k); + XQ, + cache_K, + cache_V, + seq_kv_lens, + static_cast(qk_scale), + static_cast(split_k)); } } // namespace @@ -237,948 +316,5 @@ TORCH_LIBRARY_IMPL(xformers, CUDA, m) { TORCH_FN(efficient_attention_forward_decoder_splitk_ck)); } -#ifdef ATTN_FWD_SPLITK_DECODER_MAIN - -#include - -// clang-format off - -/* - -(1) hipify - > pip install -e /xformers - - For obtaining the executed build commands, add `--verbose`. - For efficient utilization of CPU cores for compilation use MAX_JOBS env variable. 
- -(2) compile - > mkdir build - > cd build - > cmake /xformers/xformers/csrc/attention/hip_fmha/ \ - -DCMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \ - -D CMAKE_PREFIX_PATH=/opt/rocm \ - -D CMAKE_BUILD_TYPE=Debug \ - -D GPU_TARGETS="native" - > make - -(3a) run correctness check - > ./attention_forward_splitk_decoder_main - -(3b) run specific input shape - > ./attention_forward_splitk_decoder_main n_keys padding batch_size n_heads is_multiquery dtype n_wavefronts_per_block -*/ - -// clang-format on - -static std::tuple split_attention_torch( - const at::Tensor& Q, - const at::Tensor& K, - const at::Tensor& V, - const at::Tensor& k_seqlens, - const int32_t split_k, - const int32_t block_size) { - auto Q_scaled = at::div(Q, sqrt(Q.size(-1))); - - std::vector O_splits; - std::vector m_splits; - std::vector l_splits; - - for (int32_t split_idx = 0; split_idx < split_k; ++split_idx) { - std::vector O_batch; - std::vector m_batch; - std::vector l_batch; - - for (size_t b = 0; b < k_seqlens.numel(); ++b) { - auto seqlen = k_seqlens[b].item(); - const int64_t t_low = - split_idx * (seqlen / split_k / block_size) * block_size; - const int64_t t_high = (split_idx + 1 < split_k) - ? (1 + split_idx) * (seqlen / split_k / block_size) * block_size - : seqlen; - - const bool empty = t_low == t_high; - - auto S = at::einsum( - "mghk, nghk -> mghn", - {Q_scaled[b], - at::slice(K[b], /*dim*/ 0, /*start*/ t_low, /*end*/ t_high)}, - /* einsum eval path */ at::nullopt); - auto m = empty - ? at::empty_like(S) - : std::get<0>(at::max(S, /* dim */ -1, /* keepdim */ true)); - auto s = at::exp(at::sub(S, m)); - auto l = at::sum(s, /* dim */ -1, /* keepdim */ true); - auto O = at::einsum( - "mghn, nghk -> mghk", - {s, at::slice(V[b], /*dim*/ 0, /*start*/ t_low, /*end*/ t_high)}, - /* einsum eval path */ at::nullopt); - if (empty) { - m = at::empty_like(at::slice(O, -1, 0, 1)); - l = at::zeros_like(m); - m.fill_(ck::NumericLimits::Lowest()); - } - O_batch.push_back(O); - m_batch.push_back(m); - l_batch.push_back(l); - } - - auto O_cat = at::stack(O_batch); - auto m_cat = at::stack(m_batch); - auto l_cat = at::stack(l_batch); - - O_splits.push_back(O_cat); - m_splits.push_back(m_cat); - l_splits.push_back(l_cat); - } - - auto O_cat = at::stack(O_splits); - auto m_cat = at::transpose(at::stack(m_splits), 0, -1); - auto l_cat = at::transpose(at::stack(l_splits), 0, -1); - - return std::make_tuple(O_cat, m_cat, l_cat); -} - -static at::Tensor split_reduce_torch( - const at::Tensor& O_splits, - const at::Tensor& m_splits, - const at::Tensor& l_splits, - int32_t split_k) { - auto O = at::zeros_like(at::slice(O_splits, 0, 0, 1)); - auto global_max = - at::empty_like(at::slice(m_splits, -1, 0, 1)).fill_(-65535.); - auto global_sumexp = at::zeros_like(global_max); - - for (int32_t split_idx = 0; split_idx < split_k; ++split_idx) { - auto local_O = at::slice(O_splits, 0, split_idx, split_idx + 1); - auto local_max = at::slice(m_splits, -1, split_idx, split_idx + 1); - auto local_sumexp = at::slice(l_splits, -1, split_idx, split_idx + 1); - - auto log_alpha = at::neg(at::abs(at::sub(local_max, global_max))); - auto alpha = at::exp(log_alpha); - alpha.nan_to_num_(1.); - - auto pick_new = at::less(local_max, global_max); - auto pick_current_coef = at::where(pick_new, 1., alpha); - auto pick_new_coef = at::where(pick_new, alpha, 1.); - - O = at::add(at::mul(pick_current_coef, O), at::mul(pick_new_coef, local_O)); - global_sumexp = at::add( - at::mul(pick_current_coef, global_sumexp), - at::mul(pick_new_coef, local_sumexp)); - 
global_max = at::max(local_max, global_max); - } - - return at::div(O, global_sumexp); -} - -static at::Tensor efficient_attention_forward_decoder_splitk_torch( - const at::Tensor& XQ, // [B, 1, G, H, D] - const at::Tensor& cache_K, // [B, kMaxKVSequenceLength, G, H or 1, D] - const at::Tensor& cache_V, // [B, kMaxKVSequenceLength, G, H or 1, D] - at::optional seq_kv_lens, // [B] - double qk_scale, - int32_t split_k, - int32_t block_size) { - auto [O_split, m, l] = split_attention_torch( - XQ, cache_K, cache_V, *seq_kv_lens, split_k, block_size); - auto O = split_reduce_torch(O_split, m, l, split_k); - return O.reshape_as(XQ); -} - -namespace ck { -namespace tensor_operation { -namespace device { - -template -struct FMHADecoderSplitAttentionDeviceOp : public BaseOperator { - using DeviceOp = FMHADecoderSplitAttentionDeviceOp; - struct Argument : public BaseArgument { - const scalar_t* __restrict__ XQ; - const scalar_t* __restrict__ cache_K; - const scalar_t* __restrict__ cache_V; - scalar_t* __restrict__ O; - scalar_t* __restrict__ split_O; - compute_t* __restrict__ split_max; - compute_t* __restrict__ split_sumexp; - const int32_t* __restrict__ seq_kv_lens; - const ptrdiff_t XQ_stride_b; - const ptrdiff_t XQ_stride_m; - const ptrdiff_t XQ_stride_g; - const ptrdiff_t XQ_stride_h; - const ptrdiff_t K_stride_b; - const ptrdiff_t K_stride_m; - const ptrdiff_t K_stride_g; - const ptrdiff_t K_stride_h; - const ptrdiff_t O_stride_split; - const int32_t Q_size_m; - const int32_t Q_size_g; - const int32_t Q_size_h; - const int32_t Q_size_k; - const int32_t K_size_m; - const bool multiquery; - const float qk_scale; - const int32_t split_k; - - const dim3 grid_dim; - const dim3 block_dim; - const size_t lds_bytes; - - Argument( - const scalar_t* __restrict__ XQ, - const scalar_t* __restrict__ cache_K, - const scalar_t* __restrict__ cache_V, - scalar_t* __restrict__ O, - scalar_t* __restrict__ split_O, - compute_t* __restrict__ split_max, - compute_t* __restrict__ split_sumexp, - const int32_t* __restrict__ seq_kv_lens, - const ptrdiff_t XQ_stride_b, - const ptrdiff_t XQ_stride_m, - const ptrdiff_t XQ_stride_g, - const ptrdiff_t XQ_stride_h, - const ptrdiff_t K_stride_b, - const ptrdiff_t K_stride_m, - const ptrdiff_t K_stride_g, - const ptrdiff_t K_stride_h, - const ptrdiff_t O_stride_split, - const int32_t Q_size_m, - const int32_t Q_size_g, - const int32_t Q_size_h, - const int32_t Q_size_k, - const int32_t K_size_m, - const bool multiquery, - const float qk_scale, - const int32_t split_k, - // launch params - const dim3 grid_dim, - const dim3 block_dim, - const size_t lds_bytes) - : XQ(XQ), - cache_K(cache_K), - cache_V(cache_V), - O(O), - split_O(split_O), - split_max(split_max), - split_sumexp(split_sumexp), - seq_kv_lens(seq_kv_lens), - XQ_stride_b(XQ_stride_b), - XQ_stride_m(XQ_stride_m), - XQ_stride_g(XQ_stride_g), - XQ_stride_h(XQ_stride_h), - K_stride_b(K_stride_b), - K_stride_m(K_stride_m), - K_stride_g(K_stride_g), - K_stride_h(K_stride_h), - O_stride_split(O_stride_split), - Q_size_m(Q_size_m), - Q_size_g(Q_size_g), - Q_size_h(Q_size_h), - Q_size_k(Q_size_k), - K_size_m(K_size_m), - multiquery(multiquery), - qk_scale(qk_scale), - split_k(split_k), - // launch params - grid_dim(grid_dim), - block_dim(block_dim), - lds_bytes(lds_bytes) {} - - std::string str() const { - std::ostringstream oss; - oss << "Argument { " << std::endl - << " XQ: " << XQ << std::endl - << " cache_K: " << cache_K << std::endl - << " cache_V: " << cache_V << std::endl - << " O: " << O << std::endl - << " split_O: 
" << split_O << std::endl - << " split_max: " << split_max << std::endl - << " split_sumexp: " << split_sumexp << std::endl - << " seq_kv_lens: " << seq_kv_lens << std::endl - << " XQ_stride_b: " << XQ_stride_b << std::endl - << " XQ_stride_m: " << XQ_stride_m << std::endl - << " XQ_stride_g: " << XQ_stride_g << std::endl - << " XQ_stride_h: " << XQ_stride_h << std::endl - << " K_stride_b: " << K_stride_b << std::endl - << " K_stride_m: " << K_stride_m << std::endl - << " K_stride_g: " << K_stride_g << std::endl - << " K_stride_h: " << K_stride_h << std::endl - << " O_stride_split: " << O_stride_split << std::endl - << " Q_size_m: " << Q_size_m << std::endl - << " Q_size_g: " << Q_size_g << std::endl - << " Q_size_h: " << Q_size_h << std::endl - << " Q_size_k: " << Q_size_k << std::endl - << " K_size_m: " << K_size_m << std::endl - << " multiquery: " << multiquery << std::endl - << " qk_scale: " << qk_scale << std::endl - << " split_k: " << split_k << std::endl - << std::endl - << " grid_dim: " << grid_dim.x << "." << grid_dim.y << "." - << grid_dim.z << std::endl - << " block_dim: " << block_dim.x << "." << block_dim.y << "." - << block_dim.z << std::endl - << " lds_bytes: " << lds_bytes << std::endl - << "}"; - return oss.str(); - } - }; - - struct Invoker : public BaseInvoker { - using Argument = DeviceOp::Argument; - float Run( - const Argument& arg, - const StreamConfig& stream_config = StreamConfig{}) { - auto threads_per_wavefront = arg.block_dim.x; - auto Q_size_k_alignment_necessary = 0; - - for (auto vec_size : {4, 2, 1}) { - if (arg.Q_size_k <= vec_size * threads_per_wavefront) { - Q_size_k_alignment_necessary = vec_size; - } - } - - if (!Q_size_k_alignment_necessary) { - throw std::runtime_error("Unsupported Q_size_k"); - } - - if (arg.Q_size_k % Q_size_k_alignment_necessary) { - throw std::runtime_error("Unsupported alignment for Q_size_k"); - } - - float split_attention_result = launch_and_time_kernel( - stream_config, - Q_size_k_alignment_necessary == 4 - ? efficient_attention_forward_decoder_splitk_ck_kernel< - scalar_t, - 4, - kLoopUnroll, - kLoopUnrollTail, - kMaxKVSequenceLength, - compute_t> - : Q_size_k_alignment_necessary == 2 - ? efficient_attention_forward_decoder_splitk_ck_kernel< - scalar_t, - 2, - kLoopUnroll, - kLoopUnrollTail, - kMaxKVSequenceLength, - compute_t> - : Q_size_k_alignment_necessary == 1 - ? 
efficient_attention_forward_decoder_splitk_ck_kernel< - scalar_t, - 1, - kLoopUnroll, - kLoopUnrollTail, - kMaxKVSequenceLength, - compute_t> - : nullptr, - arg.grid_dim, - arg.block_dim, - arg.lds_bytes, - arg.XQ, - arg.cache_K, - arg.cache_V, - arg.split_O, - arg.split_max, - arg.split_sumexp, - arg.seq_kv_lens, - arg.XQ_stride_b, - arg.XQ_stride_m, - arg.XQ_stride_g, - arg.XQ_stride_h, - arg.K_stride_b, - arg.K_stride_m, - arg.K_stride_g, - arg.K_stride_h, - arg.O_stride_split, - arg.Q_size_m, - arg.Q_size_g, - arg.Q_size_h, - arg.Q_size_k, - arg.K_size_m, - arg.multiquery, - arg.qk_scale, - arg.split_k); - - return split_attention_result; - } - }; -}; - -template -struct FMHADecoderSplitReduceDeviceOp : public BaseOperator { - using DeviceOp = FMHADecoderSplitReduceDeviceOp; - struct Argument : public BaseArgument { - const scalar_t* __restrict__ split_O; - const compute_t* __restrict__ split_max; - const compute_t* __restrict__ split_sumexp; - scalar_t* __restrict__ O; - - const int32_t O_size_m; - const int32_t O_size_g; - const int32_t O_size_h; - const int32_t O_size_k; - - const ptrdiff_t O_stride_split; - const ptrdiff_t O_stride_b; - const ptrdiff_t O_stride_m; - const ptrdiff_t O_stride_g; - const ptrdiff_t O_stride_h; - - const int32_t split_k; - - const dim3 grid_dim; - const dim3 block_dim; - const size_t lds_bytes; - - Argument( - const scalar_t* __restrict__ split_O, - const compute_t* __restrict__ split_max, - const compute_t* __restrict__ split_sumexp, - scalar_t* __restrict__ O, - const int32_t O_size_m, - const int32_t O_size_g, - const int32_t O_size_h, - const int32_t O_size_k, - const ptrdiff_t O_stride_split, - const ptrdiff_t O_stride_b, - const ptrdiff_t O_stride_m, - const ptrdiff_t O_stride_g, - const ptrdiff_t O_stride_h, - const int32_t split_k, - // launch params - const dim3 grid_dim, - const dim3 block_dim, - const size_t lds_bytes) - : split_O(split_O), - split_max(split_max), - split_sumexp(split_sumexp), - O(O), - O_size_m(O_size_m), - O_size_g(O_size_g), - O_size_h(O_size_h), - O_size_k(O_size_k), - O_stride_split(O_stride_split), - O_stride_b(O_stride_b), - O_stride_m(O_stride_m), - O_stride_g(O_stride_g), - O_stride_h(O_stride_h), - split_k(split_k), - // launch params - grid_dim(grid_dim), - block_dim(block_dim), - lds_bytes(lds_bytes) {} - - std::string str() const { - std::ostringstream oss; - oss << "Argument { " << std::endl - << " O: " << O << std::endl - << " split_O: " << split_O << std::endl - << " split_max: " << split_max << std::endl - << " split_sumexp: " << split_sumexp << std::endl - << " O_stride_b: " << O_stride_b << std::endl - << " O_stride_m: " << O_stride_m << std::endl - << " O_stride_g: " << O_stride_g << std::endl - << " O_stride_h: " << O_stride_h << std::endl - << " O_stride_split: " << O_stride_split << std::endl - << " O_size_m: " << O_size_m << std::endl - << " O_size_g: " << O_size_g << std::endl - << " O_size_h: " << O_size_h << std::endl - << " O_size_k: " << O_size_k << std::endl - << " split_k: " << split_k << std::endl - << std::endl - << " grid_dim: " << grid_dim.x << "." << grid_dim.y << "." - << grid_dim.z << std::endl - << " block_dim: " << block_dim.x << "." << block_dim.y << "." 
- << block_dim.z << std::endl - << " lds_bytes: " << lds_bytes << std::endl - << "}"; - return oss.str(); - } - }; - - struct Invoker : public BaseInvoker { - using Argument = DeviceOp::Argument; - float Run( - const Argument& arg, - const StreamConfig& stream_config = StreamConfig{}) { - auto threads_per_wavefront = arg.block_dim.x; - auto O_size_k_alignment_necessary = 0; - - for (auto vec_size : {4, 2, 1}) { - if (arg.O_size_k <= vec_size * threads_per_wavefront) { - O_size_k_alignment_necessary = vec_size; - } - } - - if (!O_size_k_alignment_necessary) { - throw std::runtime_error("Unsupported O_size_k"); - } - - if (arg.O_size_k % O_size_k_alignment_necessary) { - throw std::runtime_error("Unsupported alignment for O_size_k"); - } - - const dim3 reduce_gridsize = {arg.grid_dim.x}; - const dim3 reduce_blocksize = {arg.block_dim.x}; - constexpr int32_t reduce_lds_bytes = 0; - float reduce_result = launch_and_time_kernel( - stream_config, - O_size_k_alignment_necessary == 4 - ? efficient_attention_forward_decoder_splitk_reduce_ck_kernel< - scalar_t, - 4> - : O_size_k_alignment_necessary == 2 - ? efficient_attention_forward_decoder_splitk_reduce_ck_kernel< - scalar_t, - 2> - : O_size_k_alignment_necessary == 1 - ? efficient_attention_forward_decoder_splitk_reduce_ck_kernel< - scalar_t, - 1> - : nullptr, - reduce_gridsize, - reduce_blocksize, - reduce_lds_bytes, - arg.split_O, - arg.split_max, - arg.split_sumexp, - arg.O, - arg.O_size_m, - arg.O_size_g, - arg.O_size_h, - arg.O_size_k, - arg.O_stride_split, - arg.O_stride_b, - arg.O_stride_m, - arg.O_stride_g, - arg.O_stride_h, - arg.split_k); - return reduce_result; - } - }; -}; -} // namespace device -} // namespace tensor_operation -} // namespace ck - -static std::tuple split_attention_hip( - const at::Tensor& XQ, - const at::Tensor& K, - const at::Tensor& V, - const at::Tensor& seqlen, - const int32_t split_k, - const int32_t wavefronts_per_block) { - at::OptionalDeviceGuard guard(XQ.device()); - - auto B = XQ.size(0); - auto M = XQ.size(1); - auto G = XQ.size(2); - auto H = XQ.size(3); - auto D = XQ.size(4); - - double qk_scale = 1. 
/ sqrt(D); - - auto O = at::empty_like(XQ); - constexpr auto rank = 5; - auto split_O = at::zeros({split_k, B, M, G, H, D}, XQ.options()); - auto split_max = - at::empty({B, M, G, H, split_k}, XQ.options().dtype(at::kFloat)) - .fill_(ck::NumericLimits::Lowest()); - auto split_sumexp = at::zeros_like(split_max); - - dim3 blocks(B * H * M * G, split_k); - dim3 threads(kThreadsPerWavefront, wavefronts_per_block); - - int32_t smem_softmax = - kMaxKVSequenceLength * sizeof(float) + threads.y * sizeof(float); - int32_t smem_output = kMaxHeadDimension * sizeof(float) * - wavefronts_per_block; // 4 * threadsPerBlock * sizeof(float) == - // sizeof(O[b][0][h][:]) - const size_t lds_bytes = max(smem_softmax, smem_output); - auto stream = at::cuda::getCurrentHIPStream().stream(); - - AT_DISPATCH_SWITCH_3( - at::ScalarType::Half, - at::ScalarType::BFloat16, - at::ScalarType::Float, - XQ.scalar_type(), - "efficient_attention_forward_decoder_split_attention_ck_test", - [&] { - using ck_data_t = c10_to_data_t::type; - using device_op_t = - ck::tensor_operation::device::FMHADecoderSplitAttentionDeviceOp< - ck_data_t>; - auto op = device_op_t{}; - - auto XQ_acc = - XQ.packed_accessor32(); - auto K_acc = - K.packed_accessor64(); - auto V_acc = - V.packed_accessor64(); - auto split_O_acc = - split_O - .packed_accessor32(); - auto O_acc = - O.packed_accessor32(); - auto seq_acc = - seqlen.packed_accessor32(); - auto split_max_acc = - split_max.packed_accessor32(); - auto split_sumexp_acc = - split_sumexp - .packed_accessor32(); - auto arg = device_op_t::Argument( - reinterpret_cast(XQ_acc.data()), - reinterpret_cast(K_acc.data()), - reinterpret_cast(V_acc.data()), - reinterpret_cast(O_acc.data()), - reinterpret_cast(split_O_acc.data()), - split_max_acc.data(), - split_sumexp_acc.data(), - seq_acc.data(), - XQ_acc.stride(0), - XQ_acc.stride(1), - XQ_acc.stride(2), - XQ_acc.stride(3), - K_acc.stride(0), - K_acc.stride(1), - K_acc.stride(2), - K_acc.stride(3), - split_O_acc.stride(0), - XQ_acc.size(1), - XQ_acc.size(2), - XQ_acc.size(3), - XQ_acc.size(4), - K_acc.size(1), - K_acc.size(3) == 1, - qk_scale, - split_k, - blocks, - threads, - lds_bytes); - - auto invoker = device_op_t::Invoker{}; - (void)invoker.Run(arg, {stream}); - }); - return std::make_tuple(split_O, split_max, split_sumexp); -} - -static at::Tensor split_reduce_hip( - const at::Tensor& split_O, - const at::Tensor& split_max, - const at::Tensor& split_sumexp, - const int32_t split_k) { - at::OptionalDeviceGuard guard(split_O.device()); - - auto B = split_O.size(1); - auto M = split_O.size(2); - auto G = split_O.size(3); - auto H = split_O.size(4); - auto D = split_O.size(5); - - TORCH_CHECK_EQ(split_k, split_O.size(0)); - TORCH_CHECK_EQ(split_k, split_max.size(-1)); - TORCH_CHECK_EQ(split_k, split_sumexp.size(-1)); - - constexpr auto rank = 5; - - TORCH_CHECK_EQ(split_O.dim(), 1 + rank); - TORCH_CHECK_EQ(split_max.dim(), rank); - TORCH_CHECK_EQ(split_sumexp.dim(), rank); - - auto O = at::zeros({B, M, G, H, D}, split_O.options()); - - auto stream = at::cuda::getCurrentHIPStream().stream(); - auto lds_bytes = 0; - - dim3 blocks(B * H * M * G); - dim3 threads(kThreadsPerWavefront); - - AT_DISPATCH_SWITCH_3( - at::ScalarType::Half, - at::ScalarType::BFloat16, - at::ScalarType::Float, - O.scalar_type(), - "efficient_attention_forward_decoder_split_reduce_ck_test", - [&] { - using ck_data_t = c10_to_data_t::type; - using device_op_t = - ck::tensor_operation::device::FMHADecoderSplitReduceDeviceOp< - ck_data_t>; - auto op = device_op_t{}; - - auto 
split_O_acc = - split_O - .packed_accessor32(); - auto O_acc = - O.packed_accessor32(); - auto split_max_acc = - split_max.packed_accessor32(); - auto split_sumexp_acc = - split_sumexp - .packed_accessor32(); - auto arg = device_op_t::Argument( - reinterpret_cast(split_O_acc.data()), - split_max_acc.data(), - split_sumexp_acc.data(), - reinterpret_cast(O_acc.data()), - O_acc.size(1), - O_acc.size(2), - O_acc.size(3), - O_acc.size(4), - split_O_acc.stride(0), - O_acc.stride(0), - O_acc.stride(1), - O_acc.stride(2), - O_acc.stride(3), - split_k, - blocks, - threads, - lds_bytes); - - auto invoker = device_op_t::Invoker{}; - (void)invoker.Run(arg, {stream}); - }); - return O; -} - -std::tuple generate_inputs( - const int32_t padding, - const int32_t B, - const int32_t Hq, - const int32_t Hkv, - const decltype(torch::kFloat32) dtype = torch::kFloat32) { - const int32_t D = 4 * kThreadsPerWavefront; - const int32_t G = Hq / Hkv; - const int32_t num_queries = 1; - - at::manual_seed(1); - - auto options = torch::TensorOptions() - .dtype(dtype) - .layout(torch::kStrided) - .device(torch::kCUDA, 1) - .requires_grad(false); - auto int_options = options.dtype(torch::kInt); - auto XQ = at::randn({B, num_queries, G, Hq, D}, options); - auto K = (G == 1) ? at::randn({B, padding, G, Hkv, D}, options) - : at::randn({B, padding, G, 1, D}, options) - .expand({B, padding, G, Hq, D}); - auto V = at::randn_like(K); - auto seqlen = at::randint(num_queries, padding + 1, {B}, int_options); - - return std::make_tuple(XQ, K, V, seqlen); -} - -static float percent_mismatch(const at::Tensor& a, const at::Tensor& b) { - auto mask = - at::isclose(a, b, /*atol*/ 1e-3, /*rtol*/ 1e-5, /*equal_nan*/ false); - auto percent_match = at::sum(mask.to(torch::kFloat32)) / mask.numel(); - return 1. 
- percent_match.item(); -} - -static void test_split_attention( - int32_t padding, - int32_t batch_size, - int32_t Hq, - int32_t Hkv, - int32_t split_k) { - auto [XQ, K, V, seqlen] = generate_inputs(padding, batch_size, Hq, Hkv); - - auto [O_ref, m_ref, l_ref] = split_attention_torch( - XQ, K, V, seqlen, split_k, /* block_size */ kWavefrontsPerBlock * 16); - - auto [O_hip, m_hip, l_hip] = - split_attention_hip(XQ, K, V, seqlen, split_k, kWavefrontsPerBlock); - - auto O_percent_mismatch = percent_mismatch(O_ref, O_hip); - auto m_percent_mismatch = percent_mismatch(m_ref, m_hip); - auto l_percent_mismatch = percent_mismatch(l_ref, l_hip); - - printf( - "[Test split attention] Padding=%d BS=%d Hq=%d Hkv=%d split_k=%d Mismatched split_O " - "elements percentage: %.2f Mismatched split_max elements percentage: %.2f Mismatched " - "split_sumexp elements percentage: %.2f\n", - padding, - batch_size, - Hq, - Hkv, - split_k, - O_percent_mismatch, - m_percent_mismatch, - l_percent_mismatch); -} - -static void test_split_reduce( - int32_t padding, - int32_t batch_size, - int32_t Hq, - int32_t Hkv, - int32_t split_k) { - auto [XQ, K, V, seqlen] = generate_inputs(padding, batch_size, Hq, Hkv); - - auto [O_ref, m_ref, l_ref] = - split_attention_hip(XQ, K, V, seqlen, split_k, kWavefrontsPerBlock); - - auto O_torch = split_reduce_torch( - O_ref, m_ref.unsqueeze(0), l_ref.unsqueeze(0), split_k); - auto O_hip = split_reduce_hip(O_ref, m_ref, l_ref, split_k); - - auto hip_torch_mismatch = percent_mismatch(O_hip, O_torch); - printf( - "[Test split reduce] Padding=%d BS=%d Hq=%d Hkv=%d split_k=%d Mismatched elements " - "percentage: %.2f \n", - padding, - batch_size, - Hq, - Hkv, - split_k, - hip_torch_mismatch); -} - -static void test_splitk_decoder_e2e_correctness( - int32_t padding, - int32_t batch_size, - int32_t Hq, - int32_t Hkv, - int32_t split_k) { - auto [XQ, K, V, seqlen] = generate_inputs(padding, batch_size, Hq, Hkv); - - double qk_scale = 1. 
/ sqrt(XQ.size(-1)); - - auto result = efficient_attention_forward_decoder_splitk_ck_impl< - kThreadsPerWavefront, - kWavefrontsPerBlock>(XQ, K, V, seqlen, qk_scale, split_k); - auto gold_result = efficient_attention_forward_decoder_splitk_torch( - XQ, K, V, seqlen, qk_scale, /* split_k */ 1, /* block_size */ 1); - auto e2e_mismatch = percent_mismatch(result, gold_result); - printf( - "[Test e2e split-k decoder] Padding=%d BS=%d Hq=%d Hkv=%d split_k=%d Mismatched " - "elements percentage: %.2f\n", - padding, - batch_size, - Hq, - Hkv, - split_k, - e2e_mismatch); -} - -int main(int argc, char** argv) { - if (argc == 1) { - for (auto padding : {32, 4096}) { - for (auto batch_size : {1, 8}) { - for (auto Hq : {16}) { - for (auto Hkv : {16}) { - for (auto split_k : {1, 2, 4, 8, 16}) { - test_splitk_decoder_e2e_correctness( - padding, batch_size, Hq, Hkv, split_k); - } - } - } - } - } - - for (auto padding : {32, 4096}) { - for (auto batch_size : {1, 8}) { - for (auto Hq : {16}) { - for (auto Hkv : {16}) { - for (auto split_k : {1, 2, 4, 8, 16}) { - test_split_attention(padding, batch_size, Hq, Hkv, split_k); - } - } - } - } - } - - for (auto padding : {32, 4096}) { - for (auto batch_size : {1, 8}) { - for (auto Hq : {16}) { - for (auto Hkv : {16}) { - for (auto split_k : {1, 2}) { - test_split_reduce(padding, batch_size, Hq, Hkv, split_k); - } - } - } - } - } - } else { - const auto args = std::vector(argv + 1, argv + argc); - if (args.size() != 6) { - std::cout << "Usage: ./a.out padding batch_size nq_heads nkv_heads dtype " - "n_wavefronts_per_block" - << std::endl; - return 0; - } - const int32_t padding = std::stoi(args[0]); - const int32_t batch_size = std::stoi(args[1]); - const int32_t nq_heads = std::stoi(args[2]); - const int32_t nkv_heads = std::stoi(args[3]); - const auto dtype = (args[4] == "f32") ? torch::kFloat32 - : (args[4] == "f16") ? torch::kFloat16 - : torch::kBFloat16; - const int32_t n_wavefronts_per_block = std::stoi(args[5]); - - auto [Q, K, V, seq] = - generate_inputs(padding, batch_size, nq_heads, nkv_heads, dtype); - auto O = at::empty_like(Q); - - constexpr auto splitk_dim = 0; - constexpr auto split_k = 1; - auto O_splits = at::stack(O, splitk_dim); - - auto split_max = at::empty( - {batch_size, padding, Q.size(2), Q.size(3), split_k}, - Q.options().dtype(at::kFloat)); - auto split_sumexp = at::empty_like(split_max); - - const double qk_scale = 1. 
/ sqrt(Q.size(-1)); - auto call_ptr = - decltype(&efficient_attention_forward_decoder_splitk_ck_out_impl< - kThreadsPerWavefront, - kWavefrontsPerBlock>){}; - -#define SWITCH_CASE_SET_CALLPTR(n) \ - case (n): \ - call_ptr = &efficient_attention_forward_decoder_splitk_ck_out_impl< \ - kThreadsPerWavefront, \ - (n)>; \ - break; - - switch (n_wavefronts_per_block) { - SWITCH_CASE_SET_CALLPTR(1); - SWITCH_CASE_SET_CALLPTR(2); - SWITCH_CASE_SET_CALLPTR(4); - SWITCH_CASE_SET_CALLPTR(8); - SWITCH_CASE_SET_CALLPTR(16); - - default: - call_ptr = nullptr; - break; - } -#undef SWITCH_CASE_SET_CALLPTR - - if (call_ptr) { - call_ptr( - Q, - K, - V, - seq, - qk_scale, - split_k, - split_max, - split_sumexp, - O_splits, - O); - } else { - std::cout << "Warning: no kernel was found for wavefronts_per_block=" - << n_wavefronts_per_block << std::endl; - } - } - return 0; -} - -#endif // MAIN - #undef AT_DISPATCH_CASE_3 #undef AT_DISPATCH_SWITCH_3 diff --git a/xformers/csrc/attention/hip_decoder/ck_attention_forward_decoder_splitk.h b/xformers/csrc/attention/hip_decoder/ck_attention_forward_decoder_splitk.h index e4d575a588..5389affacc 100644 --- a/xformers/csrc/attention/hip_decoder/ck_attention_forward_decoder_splitk.h +++ b/xformers/csrc/attention/hip_decoder/ck_attention_forward_decoder_splitk.h @@ -1,8 +1,5 @@ #pragma once -#include -#include -#include #include #include @@ -58,98 +55,125 @@ __forceinline__ __device__ void store_v( *(reinterpret_cast(data_ptr) + vector_offset) = value; } +} // namespace + +namespace ck_tile { +template +struct ForwardDecoderSplitKArgument { + const scalar_t* __restrict__ XQ; + const scalar_t* __restrict__ cache_K; + const scalar_t* __restrict__ cache_V; + scalar_t* __restrict__ O; + scalar_t* __restrict__ split_O; + compute_t* __restrict__ split_max; + compute_t* __restrict__ split_sumexp; + const int32_t* __restrict__ seq_kv_lens; + const ptrdiff_t XQ_stride_b; + const ptrdiff_t XQ_stride_m; + const ptrdiff_t XQ_stride_g; + const ptrdiff_t XQ_stride_h; + const ptrdiff_t K_stride_b; + const ptrdiff_t K_stride_m; + const ptrdiff_t K_stride_g; + const ptrdiff_t K_stride_h; + const ptrdiff_t O_stride_split; + const int32_t Q_size_m; + const int32_t Q_size_g; + const int32_t Q_size_h; + const int32_t Q_size_k; + const int32_t K_size_m; + const bool multiquery; + const float qk_scale; + const int32_t split_k; +}; + template -__global__ void efficient_attention_forward_decoder_splitk_reduce_ck_kernel( - const scalar_t* __restrict__ O_splits, - const compute_t* __restrict__ split_max, - const compute_t* __restrict__ split_sumexp, - scalar_t* __restrict__ O, - const int32_t Q_size_m, - const int32_t Q_size_g, - const int32_t Q_size_h, - const int32_t Q_size_k, - const ptrdiff_t O_stride_split, - const ptrdiff_t O_stride_b, - const ptrdiff_t O_stride_m, - const ptrdiff_t O_stride_g, - const ptrdiff_t O_stride_h, - const int32_t split_k) { - // Each block handles a single batch and head and query and group - const int32_t b = blockIdx.x / (Q_size_m * Q_size_g * Q_size_h); - const int32_t m = (blockIdx.x / (Q_size_g * Q_size_h)) % Q_size_m; - const int32_t g = (blockIdx.x / Q_size_h) % Q_size_g; - const int32_t h = blockIdx.x % Q_size_h; - - using data_t = scalar_t; - using data_vec_t = typename ck::vector_type::type; - using compute_vec_t = typename ck::vector_type::type; +struct ForwardDecoderSplitKReduceKernelImpl { + CK_TILE_DEVICE void operator()( + ForwardDecoderSplitKArgument arg) { + // Each block handles a single batch and head and query and group + const int32_t b = 
blockIdx.x / (arg.Q_size_m * arg.Q_size_g * arg.Q_size_h); + const int32_t m = + (blockIdx.x / (arg.Q_size_g * arg.Q_size_h)) % arg.Q_size_m; + const int32_t g = (blockIdx.x / arg.Q_size_h) % arg.Q_size_g; + const int32_t h = blockIdx.x % arg.Q_size_h; + + using data_t = scalar_t; + using data_vec_t = typename ck::vector_type::type; + using compute_vec_t = typename ck::vector_type::type; - union { - data_vec_t vec; - data_t arr[vec_size]; - } O_split_data; - union { - compute_vec_t vec; - compute_t arr[vec_size]; - } O_split_compute; - union { - data_vec_t vec; - data_t arr[vec_size]; - } global_O_data; - union { - compute_vec_t vec; - compute_t arr[vec_size]; - } global_O_compute; + union { + data_vec_t vec; + data_t arr[vec_size]; + } O_split_data; + union { + compute_vec_t vec; + compute_t arr[vec_size]; + } O_split_compute; + union { + data_vec_t vec; + data_t arr[vec_size]; + } global_O_data; + union { + compute_vec_t vec; + compute_t arr[vec_size]; + } global_O_compute; - global_O_compute.vec = 0; + global_O_compute.vec = 0; - const int32_t lane_idx = threadIdx.x; - const bool lane_active_for_io = lane_idx * vec_size < Q_size_k; + const int32_t lane_idx = threadIdx.x; + const bool lane_active_for_io = lane_idx * vec_size < arg.Q_size_k; - if (!lane_active_for_io) { - return; - } + if (!lane_active_for_io) { + return; + } - compute_t global_sumexp = 0; - compute_t global_max = ck::NumericLimits::Lowest(); + compute_t global_sumexp = 0; + compute_t global_max = ck::NumericLimits::Lowest(); - for (int32_t split_idx = 0; split_idx < split_k; ++split_idx) { - load_v( - O_splits + b * O_stride_b + m * O_stride_m + g * O_stride_g + - h * O_stride_h + split_idx * O_stride_split, - lane_idx, - &O_split_data.vec); + for (int32_t split_idx = 0; split_idx < arg.split_k; ++split_idx) { + load_v( + arg.split_O + b * arg.XQ_stride_b + m * arg.XQ_stride_m + + g * arg.XQ_stride_g + h * arg.XQ_stride_h + + split_idx * arg.O_stride_split, + lane_idx, + &O_split_data.vec); #pragma unroll - for (int32_t i = 0; i < vec_size; ++i) { - O_split_compute.arr[i] = ck::type_convert(O_split_data.arr[i]); + for (int32_t i = 0; i < vec_size; ++i) { + O_split_compute.arr[i] = + ck::type_convert(O_split_data.arr[i]); + } + compute_t local_max = + *(arg.split_max + blockIdx.x * arg.split_k + split_idx); + compute_t local_sumexp = + *(arg.split_sumexp + blockIdx.x * arg.split_k + split_idx); + + compute_t log_alpha = -std::abs(local_max - global_max); + compute_t alpha = + isnan(log_alpha) ? compute_t{1.} : ck::math::exp(log_alpha); + + bool pick_new = local_max < global_max; + compute_t pick_current_coef = pick_new ? 1. : alpha; + compute_t pick_new_coef = pick_new ? alpha : 1.; + + global_sumexp = + pick_current_coef * global_sumexp + pick_new_coef * local_sumexp; + global_O_compute.vec = pick_current_coef * global_O_compute.vec + + pick_new_coef * O_split_compute.vec; + global_max = ck::math::max(local_max, global_max); } - compute_t local_max = *(split_max + blockIdx.x * split_k + split_idx); - compute_t local_sumexp = *(split_sumexp + blockIdx.x * split_k + split_idx); - - compute_t log_alpha = -std::abs(local_max - global_max); - compute_t alpha = - isnan(log_alpha) ? compute_t{1.} : ck::math::exp(log_alpha); - - bool pick_new = local_max < global_max; - compute_t pick_current_coef = pick_new ? 1. : alpha; - compute_t pick_new_coef = pick_new ? 
alpha : 1.; - - global_sumexp = - pick_current_coef * global_sumexp + pick_new_coef * local_sumexp; - global_O_compute.vec = pick_current_coef * global_O_compute.vec + - pick_new_coef * O_split_compute.vec; - global_max = ck::math::max(local_max, global_max); - } - global_O_compute.vec /= global_sumexp; + global_O_compute.vec /= global_sumexp; #pragma unroll - for (int32_t i = 0; i < vec_size; ++i) { - global_O_data.arr[i] = ck::type_convert(global_O_compute.arr[i]); + for (int32_t i = 0; i < vec_size; ++i) { + global_O_data.arr[i] = ck::type_convert(global_O_compute.arr[i]); + } + store_v( + arg.O + b * arg.XQ_stride_b + m * arg.XQ_stride_m + + g * arg.XQ_stride_g + h * arg.XQ_stride_h, + lane_idx, + global_O_data.vec); } - store_v( - O + b * O_stride_b + m * O_stride_m + g * O_stride_g + h * O_stride_h, - lane_idx, - global_O_data.vec); -} +}; template < typename scalar_t, @@ -158,556 +182,277 @@ template < int32_t n_loop_unroll_tail, int32_t KV_M_MAX, typename compute_t> -__global__ void efficient_attention_forward_decoder_splitk_ck_kernel( - const scalar_t* __restrict__ XQ, - const scalar_t* __restrict__ cache_K, - const scalar_t* __restrict__ cache_V, - scalar_t* __restrict__ O_splits, - compute_t* __restrict__ split_max, - compute_t* __restrict__ split_sumexp, - const int32_t* __restrict__ seq_kv_lens, - const ptrdiff_t XQ_stride_b, - const ptrdiff_t XQ_stride_m, - const ptrdiff_t XQ_stride_g, - const ptrdiff_t XQ_stride_h, - const ptrdiff_t K_stride_b, - const ptrdiff_t K_stride_m, - const ptrdiff_t K_stride_g, - const ptrdiff_t K_stride_h, - const ptrdiff_t O_stride_split, - const int32_t Q_size_m, - const int32_t Q_size_g, - const int32_t Q_size_h, - const int32_t Q_size_k, - const int32_t K_size_m, - const bool multiquery, - const float qk_scale, - const int32_t split_k) { - static_assert( - n_loop_unroll_tail < n_loop_unroll || n_loop_unroll_tail == 1, - "tail unroll must be smaller than main loop untoll; pragma unroll 0 is illegal " - "(and tail is no-op)"); - - // Each block handles a single batch and head and query and group - const int32_t b = blockIdx.x / (Q_size_m * Q_size_g * Q_size_h); - const int32_t m = (blockIdx.x / (Q_size_g * Q_size_h)) % Q_size_m; - const int32_t g = (blockIdx.x / Q_size_h) % Q_size_g; - const int32_t h = blockIdx.x % Q_size_h; - const int32_t split_idx = blockIdx.y; - - // Note: this is decoding case where we attend to current and all previous - // tokens. - const int32_t t_max = seq_kv_lens ? seq_kv_lens[b] : K_size_m; - - const int32_t lane_idx = threadIdx.x; - const int32_t wavefront_idx = threadIdx.y; - // TODO: `threads_per_wavefront` and `wavefronts_per_block` may be compile - // time constants; investigate when optimizing - const int32_t threads_per_wavefront = blockDim.x; - const int32_t wavefronts_per_block = blockDim.y; - const int32_t threads_per_block = - threads_per_wavefront * wavefronts_per_block; - const int32_t thread_linear_idx = - lane_idx + wavefront_idx * threads_per_wavefront; - // const auto* q_ = &(XQ_acc[b][m][g][h][0]); - const auto XQO_base_offset = - b * XQ_stride_b + m * XQ_stride_m + g * XQ_stride_g + h * XQ_stride_h; - const auto* __restrict__ q_ = XQ + XQO_base_offset; - - const auto cache_KV_base_offset = b * K_stride_b + 0 * K_stride_m + - g * K_stride_g + (multiquery ? 
0 : h * K_stride_h); - const auto* __restrict__ cache_K_base = cache_K + cache_KV_base_offset; - const auto* __restrict__ cache_V_base = cache_V + cache_KV_base_offset; - - using data_t = scalar_t; - using data_vec_t = typename ck::vector_type::type; - using compute_vec_t = typename ck::vector_type::type; - - const bool lane_active_for_io = lane_idx * vec_size < Q_size_k; - - extern __shared__ __align__(16) compute_t smem[]; - - data_vec_t q_thread = 0; - // Load Q into registers in all wavefronts. - // Each thread handles `vec_size` D dimensions - if (lane_active_for_io) { - load_v(q_, lane_idx, &q_thread); - } - - compute_t max_qk_acc = ck::NumericLimits::Lowest(); - - // Compute S[0:t_max] = - // ``` - // for t in range(t_max): - // S[t] = dot(Q, K[t]) - // ``` - // Split the 0:t_max range across wavefronts in a block, - // unroll loads to expose more parallelism. - // Reduce the dot product with cross-lane operation; - // Q and K[t] are in the registers of threads in a single wavefront. - - data_vec_t k_loads[n_loop_unroll] = {}; - - const auto dtt = wavefronts_per_block * n_loop_unroll; - // only last split gets the tail. - // the first (split_k - 1) splits have a number of iterations divisible by - // `dtt` - const auto n_unrolled_loops = t_max / dtt / split_k; // +1? - const int32_t tt_low = - wavefront_idx * n_loop_unroll + n_unrolled_loops * dtt * split_idx; - const int32_t tt_high = - wavefront_idx * n_loop_unroll + n_unrolled_loops * dtt * (split_idx + 1); - const int32_t dtt_tail = wavefronts_per_block * n_loop_unroll_tail; - const int32_t tt_tail_low = wavefront_idx * n_loop_unroll_tail + - n_unrolled_loops * dtt * (split_idx + 1); - const int32_t tt_tail_high = (split_idx == split_k - 1) ? t_max : tt_tail_low; - - for (auto tt = tt_low; tt < tt_high; tt += dtt) { +struct ForwardDecoderSplitKAttnKernelImpl { + CK_TILE_DEVICE void operator()( + ForwardDecoderSplitKArgument arg) { + static_assert( + n_loop_unroll_tail < n_loop_unroll || n_loop_unroll_tail == 1, + "tail unroll must be smaller than main loop untoll; pragma unroll 0 is illegal " + "(and tail is no-op)"); + + // Each block handles a single batch and head and query and group + const int32_t b = blockIdx.x / (arg.Q_size_m * arg.Q_size_g * arg.Q_size_h); + const int32_t m = + (blockIdx.x / (arg.Q_size_g * arg.Q_size_h)) % arg.Q_size_m; + const int32_t g = (blockIdx.x / arg.Q_size_h) % arg.Q_size_g; + const int32_t h = blockIdx.x % arg.Q_size_h; + const int32_t split_idx = blockIdx.y; + + // Note: this is decoding case where we attend to current and all previous + // tokens. + const int32_t t_max = arg.seq_kv_lens ? arg.seq_kv_lens[b] : arg.K_size_m; + + const int32_t lane_idx = threadIdx.x; + const int32_t wavefront_idx = threadIdx.y; + // TODO: `threads_per_wavefront` and `wavefronts_per_block` may be compile + // time constants; investigate when optimizing + const int32_t threads_per_wavefront = blockDim.x; + const int32_t wavefronts_per_block = blockDim.y; + const int32_t threads_per_block = + threads_per_wavefront * wavefronts_per_block; + const int32_t thread_linear_idx = + lane_idx + wavefront_idx * threads_per_wavefront; + // const auto* q_ = &(XQ_acc[b][m][g][h][0]); + const auto XQO_base_offset = b * arg.XQ_stride_b + m * arg.XQ_stride_m + + g * arg.XQ_stride_g + h * arg.XQ_stride_h; + const auto* __restrict__ q_ = arg.XQ + XQO_base_offset; + + const auto cache_KV_base_offset = b * arg.K_stride_b + 0 * arg.K_stride_m + + g * arg.K_stride_g + (arg.multiquery ? 
0 : h * arg.K_stride_h); + const auto* __restrict__ cache_K_base = arg.cache_K + cache_KV_base_offset; + const auto* __restrict__ cache_V_base = arg.cache_V + cache_KV_base_offset; + + using data_t = scalar_t; + using data_vec_t = typename ck::vector_type::type; + using compute_vec_t = typename ck::vector_type::type; + + const bool lane_active_for_io = lane_idx * vec_size < arg.Q_size_k; + + extern __shared__ __align__(16) compute_t smem[]; + + data_vec_t q_thread = 0; + // Load Q into registers in all wavefronts. + // Each thread handles `vec_size` D dimensions if (lane_active_for_io) { -#pragma unroll n_loop_unroll - for (auto ttt = 0; ttt < n_loop_unroll; ++ttt) { - const int32_t t = tt + ttt; - // load the K[b][t][g][h|0][:] row into registers - load_v( - cache_K_base + t * K_stride_m, lane_idx, &k_loads[ttt]); - } + load_v(q_, lane_idx, &q_thread); } -#pragma unroll n_loop_unroll - for (auto ttt = 0; ttt < n_loop_unroll; ++ttt) { - compute_t qk_acc = 0; - ck::inner_product( - q_thread, k_loads[ttt], qk_acc); - qk_acc *= qk_scale; - - qk_acc = wavefrontReduce(qk_acc, [](auto a, auto b) { return a + b; }); - max_qk_acc = ck::math::max(qk_acc, max_qk_acc); - if (lane_idx == 0) { - smem[tt + ttt - n_unrolled_loops * dtt * split_idx] = qk_acc; - } - } - } - for (auto tt = tt_tail_low; tt < tt_tail_high; tt += dtt_tail) { - if (lane_active_for_io) { -#pragma unroll n_loop_unroll_tail - for (auto ttt = 0; ttt < n_loop_unroll_tail; ++ttt) { - const int32_t t = tt + ttt; - if (t < t_max) { + compute_t max_qk_acc = ck::NumericLimits::Lowest(); + + // Compute S[0:t_max] = + // ``` + // for t in range(t_max): + // S[t] = dot(Q, K[t]) + // ``` + // Split the 0:t_max range across wavefronts in a block, + // unroll loads to expose more parallelism. + // Reduce the dot product with cross-lane operation; + // Q and K[t] are in the registers of threads in a single wavefront. + + data_vec_t k_loads[n_loop_unroll] = {}; + + const auto dtt = wavefronts_per_block * n_loop_unroll; + // only last split gets the tail. + // the first (split_k - 1) splits have a number of iterations divisible by + // `dtt` + const auto n_unrolled_loops = t_max / dtt / arg.split_k; // +1? + const int32_t tt_low = + wavefront_idx * n_loop_unroll + n_unrolled_loops * dtt * split_idx; + const int32_t tt_high = wavefront_idx * n_loop_unroll + + n_unrolled_loops * dtt * (split_idx + 1); + const int32_t dtt_tail = wavefronts_per_block * n_loop_unroll_tail; + const int32_t tt_tail_low = wavefront_idx * n_loop_unroll_tail + + n_unrolled_loops * dtt * (split_idx + 1); + const int32_t tt_tail_high = + (split_idx == arg.split_k - 1) ? 
t_max : tt_tail_low; + + for (auto tt = tt_low; tt < tt_high; tt += dtt) { + if (lane_active_for_io) { +#pragma unroll n_loop_unroll + for (auto ttt = 0; ttt < n_loop_unroll; ++ttt) { + const int32_t t = tt + ttt; // load the K[b][t][g][h|0][:] row into registers load_v( - cache_K_base + t * K_stride_m, lane_idx, &k_loads[ttt]); + cache_K_base + t * arg.K_stride_m, lane_idx, &k_loads[ttt]); } } - } -#pragma unroll n_loop_unroll_tail - for (auto ttt = 0; ttt < n_loop_unroll_tail; ++ttt) { - compute_t qk_acc = 0; - const int32_t t = tt + ttt; - if (t < t_max) { +#pragma unroll n_loop_unroll + for (auto ttt = 0; ttt < n_loop_unroll; ++ttt) { + compute_t qk_acc = 0; ck::inner_product( q_thread, k_loads[ttt], qk_acc); - qk_acc *= qk_scale; + qk_acc *= arg.qk_scale; qk_acc = wavefrontReduce(qk_acc, [](auto a, auto b) { return a + b; }); max_qk_acc = ck::math::max(qk_acc, max_qk_acc); - - // write accumulated sums to smem. if (lane_idx == 0) { - smem[t - n_unrolled_loops * dtt * split_idx] = qk_acc; + smem[tt + ttt - n_unrolled_loops * dtt * split_idx] = qk_acc; } } } - } - - // Use shared reduction to compute max and compute softmax on shared memory. - // write max acc - if (lane_idx == 0) { - smem[KV_M_MAX + wavefront_idx] = max_qk_acc; - } - __syncthreads(); - if (lane_idx < wavefronts_per_block) { - max_qk_acc = ck::math::max(max_qk_acc, smem[KV_M_MAX + lane_idx]); - } - // shared across all threads in block - max_qk_acc = - wavefrontReduce(max_qk_acc, [](auto a, auto b) { return a > b ? a : b; }); - - if (wavefront_idx == 0 && lane_idx == 0) { - split_max[blockIdx.x * split_k + split_idx] = max_qk_acc; - } - // each wavefront computes partial sum of exp. - { // softmax reduce begin - compute_t softmax_denominator = 0.0f; - const int32_t t_low = n_unrolled_loops * dtt * split_idx; - const int32_t t_high = (split_idx + 1 < split_k) - ? n_unrolled_loops * dtt * (split_idx + 1) - : t_max; - for (int32_t t = t_low + thread_linear_idx; t < t_high; - t += threads_per_block) { - const auto s = ck::math::exp(smem[t - t_low] - max_qk_acc); - softmax_denominator += s; - smem[t - t_low] = s; + for (auto tt = tt_tail_low; tt < tt_tail_high; tt += dtt_tail) { + if (lane_active_for_io) { +#pragma unroll n_loop_unroll_tail + for (auto ttt = 0; ttt < n_loop_unroll_tail; ++ttt) { + const int32_t t = tt + ttt; + if (t < t_max) { + // load the K[b][t][g][h|0][:] row into registers + load_v( + cache_K_base + t * arg.K_stride_m, lane_idx, &k_loads[ttt]); + } + } + } +#pragma unroll n_loop_unroll_tail + for (auto ttt = 0; ttt < n_loop_unroll_tail; ++ttt) { + compute_t qk_acc = 0; + const int32_t t = tt + ttt; + if (t < t_max) { + ck::inner_product( + q_thread, k_loads[ttt], qk_acc); + qk_acc *= arg.qk_scale; + + qk_acc = + wavefrontReduce(qk_acc, [](auto a, auto b) { return a + b; }); + max_qk_acc = ck::math::max(qk_acc, max_qk_acc); + + // write accumulated sums to smem. + if (lane_idx == 0) { + smem[t - n_unrolled_loops * dtt * split_idx] = qk_acc; + } + } + } } - softmax_denominator = wavefrontReduce( - softmax_denominator, [](auto a, auto b) { return a + b; }); + // Use shared reduction to compute max and compute softmax on shared memory. + // write max acc if (lane_idx == 0) { - smem[KV_M_MAX + wavefront_idx] = softmax_denominator; + smem[KV_M_MAX + wavefront_idx] = max_qk_acc; } __syncthreads(); - - // now, compute sum of exp(x - max(x)) over all intermediate results. 
- softmax_denominator = 0.0; if (lane_idx < wavefronts_per_block) { - softmax_denominator = smem[KV_M_MAX + lane_idx]; + max_qk_acc = ck::math::max(max_qk_acc, smem[KV_M_MAX + lane_idx]); } - softmax_denominator = wavefrontReduce( - softmax_denominator, [](auto a, auto b) { return a + b; }); + // shared across all threads in block + max_qk_acc = wavefrontReduce( + max_qk_acc, [](auto a, auto b) { return a > b ? a : b; }); if (wavefront_idx == 0 && lane_idx == 0) { - split_sumexp[blockIdx.x * split_k + split_idx] = softmax_denominator; + arg.split_max[blockIdx.x * arg.split_k + split_idx] = max_qk_acc; } - } // softmax reduce end - // Split T across wavefronts in a block - // each wavefront compute sum(t_subset) P[t] * V[t_subset, d] - // outputs are of size float[D] + // each wavefront computes partial sum of exp. + { // softmax reduce begin + compute_t softmax_denominator = 0.0f; + const int32_t t_low = n_unrolled_loops * dtt * split_idx; + const int32_t t_high = (split_idx + 1 < arg.split_k) + ? n_unrolled_loops * dtt * (split_idx + 1) + : t_max; + for (int32_t t = t_low + thread_linear_idx; t < t_high; + t += threads_per_block) { + const auto s = ck::math::exp(smem[t - t_low] - max_qk_acc); + softmax_denominator += s; + smem[t - t_low] = s; + } + softmax_denominator = wavefrontReduce( + softmax_denominator, [](auto a, auto b) { return a + b; }); - compute_t ps[n_loop_unroll] = {}; - compute_vec_t o_acc = 0; - if (lane_active_for_io) { - for (auto tt = tt_low; tt < tt_high; tt += dtt) { -#pragma unroll n_loop_unroll - for (auto ttt = 0; ttt < n_loop_unroll; ++ttt) { - const int32_t t = tt + ttt; - // load the V[b][t][g][h|0][:] row into registers, reusing K register - // storage - load_v( - cache_V_base + t * K_stride_m, lane_idx, &k_loads[ttt]); - ps[ttt] = smem[t - n_unrolled_loops * dtt * split_idx]; + if (lane_idx == 0) { + smem[KV_M_MAX + wavefront_idx] = softmax_denominator; } + __syncthreads(); -#pragma unroll n_loop_unroll - for (auto ttt = 0; ttt < n_loop_unroll; ++ttt) { - o_acc = - scalar_scale_acc(o_acc, k_loads[ttt], ps[ttt]); + // now, compute sum of exp(x - max(x)) over all intermediate results. 
+    softmax_denominator = 0.0;
+    if (lane_idx < wavefronts_per_block) {
+      softmax_denominator = smem[KV_M_MAX + lane_idx];
     }
-  }
+    softmax_denominator = wavefrontReduce(
+        softmax_denominator, [](auto a, auto b) { return a + b; });
 
-  for (auto tt = tt_tail_low; tt < tt_tail_high; tt += dtt_tail) {
-#pragma unroll n_loop_unroll_tail
-    for (auto ttt = 0; ttt < n_loop_unroll_tail; ++ttt) {
-      const int32_t t = tt + ttt;
-      if (t < t_max) {
+    if (wavefront_idx == 0 && lane_idx == 0) {
+      arg.split_sumexp[blockIdx.x * arg.split_k + split_idx] =
+          softmax_denominator;
+    }
+  } // softmax reduce end
+
+  // Split T across wavefronts in a block
+  // each wavefront compute sum(t_subset) P[t] * V[t_subset, d]
+  // outputs are of size float[D]
+
+  compute_t ps[n_loop_unroll] = {};
+  compute_vec_t o_acc = 0;
+  if (lane_active_for_io) {
+    for (auto tt = tt_low; tt < tt_high; tt += dtt) {
+#pragma unroll n_loop_unroll
+      for (auto ttt = 0; ttt < n_loop_unroll; ++ttt) {
+        const int32_t t = tt + ttt;
         // load the V[b][t][g][h|0][:] row into registers, reusing K register
         // storage
         load_v(
-            cache_V_base + t * K_stride_m, lane_idx, &k_loads[ttt]);
+            cache_V_base + t * arg.K_stride_m, lane_idx, &k_loads[ttt]);
         ps[ttt] = smem[t - n_unrolled_loops * dtt * split_idx];
+      }
+
+#pragma unroll n_loop_unroll
+      for (auto ttt = 0; ttt < n_loop_unroll; ++ttt) {
         o_acc = scalar_scale_acc(o_acc, k_loads[ttt], ps[ttt]);
       }
     }
-    }
-  }
-  __syncthreads();
-
-  // NB: needs sizeof(smem) >= `vec_size` * (sizeof(float)==4) * threadsPerBlock
-  if (lane_active_for_io) {
-    store_v(&smem[0], thread_linear_idx, o_acc);
-  }
-  __syncthreads();
-  // sum up partial D rows from other wavefronts
-  if (wavefront_idx == 0 && lane_active_for_io) {
-    union {
-      compute_vec_t vec = 0;
-      compute_t arr[vec_size];
-    } r;
-    for (int32_t w = 0; w < wavefronts_per_block; ++w) {
-      compute_vec_t partial_r;
-      load_v(
-          smem, w * threads_per_wavefront + lane_idx, &partial_r);
-      r.vec += partial_r;
-    }
-    // elementwise convert from compute_t result to data_t out to be written
-    union {
-      data_vec_t vec;
-      data_t arr[vec_size];
-    } bf_r;
-#pragma unroll
-    for (int32_t i = 0; i < vec_size; ++i) {
-      bf_r.arr[i] = ck::type_convert<data_t>(r.arr[i]);
+    for (auto tt = tt_tail_low; tt < tt_tail_high; tt += dtt_tail) {
+#pragma unroll n_loop_unroll_tail
+      for (auto ttt = 0; ttt < n_loop_unroll_tail; ++ttt) {
+        const int32_t t = tt + ttt;
+        if (t < t_max) {
+          // load the V[b][t][g][h|0][:] row into registers, reusing K
+          // register storage
+          load_v(
+              cache_V_base + t * arg.K_stride_m, lane_idx, &k_loads[ttt]);
+          ps[ttt] = smem[t - n_unrolled_loops * dtt * split_idx];
+          o_acc = scalar_scale_acc(
+              o_acc, k_loads[ttt], ps[ttt]);
+        }
+      }
+    }
   }
-    // write output row O[b][m][g][h][:]
-    data_t* __restrict__ o_ =
-        O_splits + XQO_base_offset + split_idx * O_stride_split;
-    store_v(o_, lane_idx, bf_r.vec);
-  }
-}
-
-} // namespace
+  __syncthreads();
 
-namespace ck {
-namespace tensor_operation {
-namespace device {
-template <
-    typename scalar_t,
-    int32_t KV_M_MAX,
-    int32_t n_loop_unroll,
-    int32_t n_loop_unroll_tail,
-    typename compute_t>
-struct FMHADecoderSplitKDeviceOp : public BaseOperator {
-  using DeviceOp = FMHADecoderSplitKDeviceOp;
-  struct Argument : public BaseArgument {
-    const scalar_t* __restrict__ XQ;
-    const scalar_t* __restrict__ cache_K;
-    const scalar_t* __restrict__ cache_V;
-    scalar_t* __restrict__ O;
-    scalar_t* __restrict__ split_O;
-    compute_t* __restrict__ split_max;
-    compute_t* __restrict__ split_sumexp;
-    const int32_t* __restrict__ seq_kv_lens;
-    const ptrdiff_t XQ_stride_b;
-    const ptrdiff_t XQ_stride_m;
-    const ptrdiff_t XQ_stride_g;
-    const ptrdiff_t XQ_stride_h;
-    const ptrdiff_t K_stride_b;
-    const ptrdiff_t K_stride_m;
-    const ptrdiff_t K_stride_g;
-    const ptrdiff_t K_stride_h;
-    const ptrdiff_t O_stride_split;
-    const int32_t Q_size_m;
-    const int32_t Q_size_g;
-    const int32_t Q_size_h;
-    const int32_t Q_size_k;
-    const int32_t K_size_m;
-    const bool multiquery;
-    const float qk_scale;
-    const int32_t split_k;
-
-    const dim3 grid_dim;
-    const dim3 block_dim;
-    const size_t lds_bytes;
-
-    Argument(
-        const scalar_t* __restrict__ XQ,
-        const scalar_t* __restrict__ cache_K,
-        const scalar_t* __restrict__ cache_V,
-        scalar_t* __restrict__ O,
-        scalar_t* __restrict__ split_O,
-        compute_t* __restrict__ split_max,
-        compute_t* __restrict__ split_sumexp,
-        const int32_t* __restrict__ seq_kv_lens,
-        const ptrdiff_t XQ_stride_b,
-        const ptrdiff_t XQ_stride_m,
-        const ptrdiff_t XQ_stride_g,
-        const ptrdiff_t XQ_stride_h,
-        const ptrdiff_t K_stride_b,
-        const ptrdiff_t K_stride_m,
-        const ptrdiff_t K_stride_g,
-        const ptrdiff_t K_stride_h,
-        const ptrdiff_t O_stride_split,
-        const int32_t Q_size_m,
-        const int32_t Q_size_g,
-        const int32_t Q_size_h,
-        const int32_t Q_size_k,
-        const int32_t K_size_m,
-        const bool multiquery,
-        const float qk_scale,
-        const int32_t split_k,
-        // launch params
-        const dim3 grid_dim,
-        const dim3 block_dim,
-        const size_t lds_bytes)
-        : XQ(XQ),
-          cache_K(cache_K),
-          cache_V(cache_V),
-          O(O),
-          split_O(split_O),
-          split_max(split_max),
-          split_sumexp(split_sumexp),
-          seq_kv_lens(seq_kv_lens),
-          XQ_stride_b(XQ_stride_b),
-          XQ_stride_m(XQ_stride_m),
-          XQ_stride_g(XQ_stride_g),
-          XQ_stride_h(XQ_stride_h),
-          K_stride_b(K_stride_b),
-          K_stride_m(K_stride_m),
-          K_stride_g(K_stride_g),
-          K_stride_h(K_stride_h),
-          O_stride_split(O_stride_split),
-          Q_size_m(Q_size_m),
-          Q_size_g(Q_size_g),
-          Q_size_h(Q_size_h),
-          Q_size_k(Q_size_k),
-          K_size_m(K_size_m),
-          multiquery(multiquery),
-          qk_scale(qk_scale),
-          split_k(split_k),
-          // launch params
-          grid_dim(grid_dim),
-          block_dim(block_dim),
-          lds_bytes(lds_bytes) {}
-
-    std::string str() const {
-      std::ostringstream oss;
-      oss << "Argument { " << std::endl
-          << " XQ: " << XQ << std::endl
-          << " cache_K: " << cache_K << std::endl
-          << " cache_V: " << cache_V << std::endl
-          << " O: " << O << std::endl
-          << " split_O: " << split_O << std::endl
-          << " split_max: " << split_max << std::endl
-          << " split_sumexp: " << split_sumexp << std::endl
-          << " seq_kv_lens: " << seq_kv_lens << std::endl
-          << " XQ_stride_b: " << XQ_stride_b << std::endl
-          << " XQ_stride_m: " << XQ_stride_m << std::endl
-          << " XQ_stride_g: " << XQ_stride_g << std::endl
-          << " XQ_stride_h: " << XQ_stride_h << std::endl
-          << " K_stride_b: " << K_stride_b << std::endl
-          << " K_stride_m: " << K_stride_m << std::endl
-          << " K_stride_g: " << K_stride_g << std::endl
-          << " K_stride_h: " << K_stride_h << std::endl
-          << " O_stride_split: " << O_stride_split << std::endl
-          << " Q_size_m: " << Q_size_m << std::endl
-          << " Q_size_g: " << Q_size_g << std::endl
-          << " Q_size_h: " << Q_size_h << std::endl
-          << " Q_size_k: " << Q_size_k << std::endl
-          << " K_size_m: " << K_size_m << std::endl
-          << " multiquery: " << multiquery << std::endl
-          << " qk_scale: " << qk_scale << std::endl
-          << " split_k: " << split_k << std::endl
-          << std::endl
-          << " grid_dim: " << grid_dim.x << "." << grid_dim.y << "."
-          << grid_dim.z << std::endl
-          << " block_dim: " << block_dim.x << "." << block_dim.y << "."
-          << block_dim.z << std::endl
-          << " lds_bytes: " << lds_bytes << std::endl
-          << "}";
-      return oss.str();
+  // NB: needs sizeof(smem) >= `vec_size` * (sizeof(float)==4) *
+  // threadsPerBlock
+  if (lane_active_for_io) {
+    store_v(&smem[0], thread_linear_idx, o_acc);
   }
-  };
-
-  struct Invoker : public BaseInvoker {
-    using Argument = DeviceOp::Argument;
-    float Run(
-        const BaseArgument* argp_,
-        const StreamConfig& stream_config = StreamConfig{}) {
-      const Argument* argp = dynamic_cast<const Argument*>(argp_);
-
-      auto threads_per_wavefront = argp->block_dim.x;
-      auto Q_size_k_alignment_necessary = 0;
-
-      for (auto vec_size : {4, 2, 1}) {
-        if (argp->Q_size_k <= vec_size * threads_per_wavefront) {
-          Q_size_k_alignment_necessary = vec_size;
-        }
-      }
-      if (!Q_size_k_alignment_necessary) {
-        throw std::runtime_error("Unsupported Q_size_k");
+  __syncthreads();
+  // sum up partial D rows from other wavefronts
+  if (wavefront_idx == 0 && lane_active_for_io) {
+    union {
+      compute_vec_t vec = 0;
+      compute_t arr[vec_size];
+    } r;
+    for (int32_t w = 0; w < wavefronts_per_block; ++w) {
+      compute_vec_t partial_r;
+      load_v(
+          smem, w * threads_per_wavefront + lane_idx, &partial_r);
+      r.vec += partial_r;
     }
-
-      if (argp->Q_size_k % Q_size_k_alignment_necessary) {
-        throw std::runtime_error("Unsupported alignment for Q_size_k");
+    // elementwise convert from compute_t result to data_t out to be written
+    union {
+      data_vec_t vec;
+      data_t arr[vec_size];
+    } bf_r;
+#pragma unroll
+    for (int32_t i = 0; i < vec_size; ++i) {
+      bf_r.arr[i] = ck::type_convert<data_t>(r.arr[i]);
    }
-
-      float split_attention_result = launch_and_time_kernel(
-          stream_config,
-          Q_size_k_alignment_necessary == 4
-              ? efficient_attention_forward_decoder_splitk_ck_kernel<
-                    scalar_t,
-                    /* vec_size */ 4,
-                    n_loop_unroll,
-                    n_loop_unroll_tail,
-                    KV_M_MAX,
-                    compute_t>
-              : Q_size_k_alignment_necessary == 2
-                  ? efficient_attention_forward_decoder_splitk_ck_kernel<
-                        scalar_t,
-                        /* vec_size */ 2,
-                        n_loop_unroll,
-                        n_loop_unroll_tail,
-                        KV_M_MAX,
-                        compute_t>
-                  : Q_size_k_alignment_necessary == 1
-                      ? efficient_attention_forward_decoder_splitk_ck_kernel<
-                            scalar_t,
-                            /* vec_size */ 1,
-                            n_loop_unroll,
-                            n_loop_unroll_tail,
-                            KV_M_MAX,
-                            compute_t>
-                      : nullptr,
-          argp->grid_dim,
-          argp->block_dim,
-          argp->lds_bytes,
-          argp->XQ,
-          argp->cache_K,
-          argp->cache_V,
-          argp->split_O,
-          argp->split_max,
-          argp->split_sumexp,
-          argp->seq_kv_lens,
-          argp->XQ_stride_b,
-          argp->XQ_stride_m,
-          argp->XQ_stride_g,
-          argp->XQ_stride_h,
-          argp->K_stride_b,
-          argp->K_stride_m,
-          argp->K_stride_g,
-          argp->K_stride_h,
-          argp->O_stride_split,
-          argp->Q_size_m,
-          argp->Q_size_g,
-          argp->Q_size_h,
-          argp->Q_size_k,
-          argp->K_size_m,
-          argp->multiquery,
-          argp->qk_scale,
-          argp->split_k);
-
-      const dim3 reduce_gridsize = {argp->grid_dim.x};
-      const dim3 reduce_blocksize = {argp->block_dim.x};
-      constexpr int32_t reduce_lds_bytes = 0;
-      float reduce_result = launch_and_time_kernel(
-          stream_config,
-          Q_size_k_alignment_necessary == 4
-              ? efficient_attention_forward_decoder_splitk_reduce_ck_kernel<
-                    scalar_t,
-                    4>
-              : Q_size_k_alignment_necessary == 2
-                  ? efficient_attention_forward_decoder_splitk_reduce_ck_kernel<
-                        scalar_t,
-                        2>
-                  : Q_size_k_alignment_necessary == 1
-                      ? efficient_attention_forward_decoder_splitk_reduce_ck_kernel<
-                            scalar_t,
-                            1>
-                      : nullptr,
-          reduce_gridsize,
-          reduce_blocksize,
-          reduce_lds_bytes,
-          argp->split_O,
-          argp->split_max,
-          argp->split_sumexp,
-          argp->O,
-          argp->Q_size_m,
-          argp->Q_size_g,
-          argp->Q_size_h,
-          argp->Q_size_k,
-          argp->O_stride_split,
-          argp->XQ_stride_b,
-          argp->XQ_stride_m,
-          argp->XQ_stride_g,
-          argp->XQ_stride_h,
-          argp->split_k);
-      return split_attention_result + reduce_result;
+    // write output row O[b][m][g][h][:]
+    data_t* __restrict__ o_ =
+        arg.split_O + XQO_base_offset + split_idx * arg.O_stride_split;
+    store_v(o_, lane_idx, bf_r.vec);
   }
-  };
+  }
 };
-} // namespace device
-} // namespace tensor_operation
-} // namespace ck
+
+} // namespace ck_tile
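Note on the split-k merge performed by the reduce kernel above: each attention block writes a partial output row for its slice of the key/value sequence, together with the per-split row max (split_max) and sum of exponentials (split_sumexp), and the reduce kernel rescales and sums these partials into the final row. The sketch below is a minimal host-side illustration of that log-sum-exp merge, under the assumption that the partial rows are unnormalized numerators; the function and variable names (merge_splits, split_o, split_max, split_sumexp) are illustrative and not part of this patch.

    // Host-side sketch (illustrative only, not part of the patch).
    // Assumes, per split s and output dim d:
    //   split_o[s][d]   = sum_t exp(qk[t] - split_max[s]) * V[t][d]
    //   split_max[s]    = max over the split's keys of qk[t]
    //   split_sumexp[s] = sum_t exp(qk[t] - split_max[s])
    #include <algorithm>
    #include <cmath>
    #include <vector>

    std::vector<float> merge_splits(
        const std::vector<std::vector<float>>& split_o, // [split_k][D]
        const std::vector<float>& split_max,            // [split_k]
        const std::vector<float>& split_sumexp) {       // [split_k]
      const float global_max =
          *std::max_element(split_max.begin(), split_max.end());
      const size_t D = split_o.front().size();
      std::vector<float> o(D, 0.0f);
      float global_sumexp = 0.0f;
      for (size_t s = 0; s < split_o.size(); ++s) {
        // Rescale each split's numerator and denominator to the global max
        // so the exponentials stay in range, then accumulate.
        const float alpha = std::exp(split_max[s] - global_max);
        for (size_t d = 0; d < D; ++d) {
          o[d] += alpha * split_o[s][d];
        }
        global_sumexp += alpha * split_sumexp[s];
      }
      for (size_t d = 0; d < D; ++d) {
        o[d] /= global_sumexp; // final softmax normalization
      }
      return o;
    }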