diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/fp8_rowwise_grouped_gemm.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/fp8_rowwise_grouped_gemm.hip
index ca3afc48f..c2a3c8039 100644
--- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/fp8_rowwise_grouped_gemm.hip
+++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/fp8_rowwise_grouped_gemm.hip
@@ -18,7 +18,6 @@
 #include
 #include
-#include
 #include
 #include "ck/ck.hpp"
@@ -27,7 +26,7 @@
 namespace fbgemm_gpu {
 
-template
+template
 using RowwiseGroupedKernel = std::function;
 
 using EDataType = ck::bhalf_t;
 
-template
-RowwiseGroupedKernel rowwise_grouped_heuristic_dispatch(int M, int N, int K) {
+template
+RowwiseGroupedKernel
+rowwise_grouped_heuristic_dispatch(int M, int N, int K) {
   // We use shape heuristics to find the best kernel.
   // To do this, we divide by the size of M and find the best
   // option within that grouping.
   if (M <= 16) {
     if (N < 8192 && K <= 8192) {
-      return fp8_rowwise_grouped_64x16x16x256_16x16_1x1_16x4x1_16x4x1_1x4x1x16_4x4x1_1x1_intrawave_v1;
+      return fp8_rowwise_grouped_64x16x16x256_16x16_1x1_16x4x1_16x4x1_1x4x1x16_4x4x1_1x1_intrawave_v1<
+          InputType,
+          OutputType>;
     }
     if (K <= 8192) {
-      return fp8_rowwise_grouped_128x16x64x128_16x16_1x2_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2;
+      return fp8_rowwise_grouped_128x16x64x128_16x16_1x2_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2<
+          InputType,
+          OutputType>;
     }
-    return fp8_rowwise_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2;
+    return fp8_rowwise_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2<
+        InputType,
+        OutputType>;
   }
   if (M <= 32) {
     if (N < 8192 && K <= 8192) {
-      return fp8_rowwise_grouped_128x32x64x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2;
+      return fp8_rowwise_grouped_128x32x64x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2<
+          InputType,
+          OutputType>;
     }
     if (K <= 8192) {
-      return fp8_rowwise_grouped_128x32x128x128_32x32_1x2_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2;
+      return fp8_rowwise_grouped_128x32x128x128_32x32_1x2_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2<
+          InputType,
+          OutputType>;
     }
-    return fp8_rowwise_grouped_128x32x128x128_32x32_1x2_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_intrawave_v2;
+    return fp8_rowwise_grouped_128x32x128x128_32x32_1x2_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_intrawave_v2<
+        InputType,
+        OutputType>;
   }
   if (M <= 64) {
-    return fp8_rowwise_grouped_256x64x64x128_32x32_1x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3;
+    return fp8_rowwise_grouped_256x64x64x128_32x32_1x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3<
+        InputType,
+        OutputType>;
   }
   if (M <= 128) {
     if (N < 8192 && K <= 8192) {
-      return fp8_rowwise_grouped_256x128x64x128_32x32_2x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3;
+      return fp8_rowwise_grouped_256x128x64x128_32x32_2x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3<
+          InputType,
+          OutputType>;
     }
-    return fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3;
+    return fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3<
+        InputType,
+        OutputType>;
   }
   if (M <= 256) {
-    return fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3;
+    return fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3<
+        InputType,
+        OutputType>;
   }
   if (M <= 512) {
     if (K <= 8192) {
-      return fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_interwave_v1;
+      return fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_interwave_v1<
+          InputType,
+          OutputType>;
     }
-    return fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3;
+    return fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3<
+        InputType,
+        OutputType>;
   }
   // Default kernel for all other shapes.
-  return fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_interwave_v1;
+  return fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_interwave_v1<
+      InputType,
+      OutputType>;
 }
 
 __global__ void set_kernel_args_kernel(
@@ -139,9 +165,10 @@ void set_static_kernel_args(
     if constexpr (std::is_same_v>) {
       // Output is a list of tensors and we can access each individually.
       output_ptr = reinterpret_cast(output[i].data_ptr());
-    } else{
+    } else {
      // Output is a single contiguous tensor and must be accessed via offset.
-      output_ptr = reinterpret_cast(output.data_ptr()) + output_offset;
+      output_ptr =
+          reinterpret_cast(output.data_ptr()) + output_offset;
       output_offset += M * N;
     }
@@ -165,7 +192,6 @@ void set_static_kernel_args(
         M,
         N,
         K);
-
   }
 }
@@ -180,8 +206,7 @@ __global__ void set_kernel_args_fixed_nk_kernel(
     int M,
     int N,
     int K,
-    int group_count,
-    const int BLOCK_SIZE) {
+    int group_count) {
   int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
   // Each thread is responsible for setting up the arguments for one group.
   if (thread_idx < group_count) {
@@ -203,33 +228,21 @@ __global__ void set_kernel_args_fixed_nk_kernel(
     kernel_args[thread_idx] = kernel_group_args;
   }
 
-  // We also fuse in initialization of the output tensor.
-  // We write in chunks of 2 bfloats at a time for efficiency.
-  for (int i = 0; i < BLOCK_SIZE / 2; i++) {
-    // Figure out where in memory we are.
-    int output_offset = (thread_idx * BLOCK_SIZE) + (i * 2);
-    int current_group = output_offset / (M * N);
-    // Skip if outside of valid groups.
-    if (current_group < group_count) {
-      int nonzeros = prepad_M[current_group];
-      int current_M = output_offset / N;
-      // Only write if this block needs initialization.
-      // Avoid writing to final element if number of elements is odd.
-      if (current_M >= nonzeros && output_offset < (M * N * group_count) - 1) {
-        __hip_bfloat162* output_block =
-            reinterpret_cast<__hip_bfloat162*>(output + output_offset);
-        *output_block = __hip_bfloat162(0, 0);
-      }
+  // Figure out where in memory we are.
+  // Each thread sets one float 4 which corresponds to 8 bf16 values.
+  int output_offset = (thread_idx * 8);
+  int current_group = output_offset / (M * N);
+  // Skip if outside of valid groups.
+  if (current_group < group_count) {
+    int nonzeros = prepad_M[current_group];
+    int current_M = (output_offset % (M * N)) / N;
+    // Only write zeros if we're currently in a sparse row.
+    if (current_M >= nonzeros) {
+      // Write out a block of 8 output values via vectorized float4.
+      float4* output_block = reinterpret_cast(output + output_offset);
+      *output_block = {0, 0, 0, 0};
     }
   }
-  // Handle case where there are an odd number of total elements.
-  if (((M * N * group_count) % 2) != 0 &&
-      ((M * N * group_count) - (thread_idx * BLOCK_SIZE) < BLOCK_SIZE)) {
-    // Write out the final element.
-    __hip_bfloat16* output_block =
-        reinterpret_cast<__hip_bfloat16*>(output + (M * N * group_count) - 1);
-    *output_block = __hip_bfloat16(0);
-  }
 }
 
 void set_dynamic_kernel_args(
@@ -261,9 +274,12 @@ void set_dynamic_kernel_args(
   int N = WQ.size(1);
 
   // Launch a kernel that sets kernel argument memory.
+  // Each thread sets one float4 which corresponds to 8 bf16 values.
   const int BLOCK_SIZE = 8;
+  TORCH_CHECK(
+      N % BLOCK_SIZE == 0, "N must be divisible 8 for dynamic grouped gemm.");
   int block_factor = std::max(group_count, (group_count * M * N) / BLOCK_SIZE);
-  int blockSize = std::min(1024, block_factor);
+  int blockSize = std::min(512, block_factor);
   int numBlocks = (block_factor + blockSize - 1) / blockSize;
   set_kernel_args_fixed_nk_kernel<<>>(
       reinterpret_cast(kernel_args.data_ptr()),
@@ -276,8 +292,7 @@
       M,
       N,
       K,
-      group_count,
-      BLOCK_SIZE);
+      group_count);
 }
 
 template
@@ -347,7 +362,7 @@ OutputType _f8f8bf16_rowwise_grouped(
       Y.push_back(at::empty({M, N}, XQ[i].options().dtype(at::kBFloat16)));
     }
   }
-    // Now handle single tensor output.
+    // Now handle single tensor output.
   } else {
     // Compute total M across groups.
     int total_M = 0;
@@ -355,14 +370,17 @@ OutputType _f8f8bf16_rowwise_grouped(
     for (int i = 0; i < group_count; i++) {
       total_M += XQ[i].size(0);
       // Also make sure N is constant across shapes.
-      TORCH_CHECK(WQ[i].size(0) == N, "N must be constant across groups for stacked output.");
+      TORCH_CHECK(
+          WQ[i].size(0) == N,
+          "N must be constant across groups for stacked output.");
     }
     if (output.has_value()) {
       Y = output.value();
       // Check that shape is expected.
-      TORCH_CHECK(Y.size(0) == total_M && Y.size(1) == N, "Preallocated output should have size [total_M, N].");
-    }
-    else {
+      TORCH_CHECK(
+          Y.size(0) == total_M && Y.size(1) == N,
+          "Preallocated output should have size [total_M, N].");
+    } else {
       Y = at::empty({total_M, N}, XQ[0].options().dtype(at::kBFloat16));
     }
   }
@@ -383,7 +401,8 @@ OutputType _f8f8bf16_rowwise_grouped(
     MaxK = max(MaxK, XQ[i].size(1));
   }
   RowwiseGroupedKernel selected_kernel =
-      rowwise_grouped_heuristic_dispatch(MaxM, MaxN, MaxK);
+      rowwise_grouped_heuristic_dispatch(
+          MaxM, MaxN, MaxK);
   return selected_kernel(XQ, WQ, x_scale, w_scale, kernel_args, Y);
 }
@@ -394,7 +413,8 @@ std::vector f8f8bf16_rowwise_grouped(
     at::TensorList x_scale,
     at::TensorList w_scale,
     std::optional> output = std::nullopt) {
-  return _f8f8bf16_rowwise_grouped>(XQ, WQ, x_scale, w_scale, output);
+  return _f8f8bf16_rowwise_grouped>(
+      XQ, WQ, x_scale, w_scale, output);
 }
 
 // Wrapper function for list input single tensor output.
@@ -404,7 +424,8 @@ at::Tensor f8f8bf16_rowwise_grouped_stacked(
     at::TensorList x_scale,
     at::TensorList w_scale,
     std::optional output = std::nullopt) {
-  return _f8f8bf16_rowwise_grouped(XQ, WQ, x_scale, w_scale, output);
+  return _f8f8bf16_rowwise_grouped(
+      XQ, WQ, x_scale, w_scale, output);
 }
 
 at::Tensor f8f8bf16_rowwise_grouped_dynamic(
@@ -452,13 +473,7 @@ at::Tensor f8f8bf16_rowwise_grouped_dynamic(
       {static_cast(group_count * sizeof(KernelArguments))},
       XQ.options().dtype(at::kByte));
   set_dynamic_kernel_args(
-      kernel_args,
-      XQ,
-      WQ,
-      x_scale,
-      w_scale,
-      Y,
-      zero_start_index_M);
+      kernel_args, XQ, WQ, x_scale, w_scale, Y, zero_start_index_M);
   RowwiseGroupedKernel selected_kernel =
       rowwise_grouped_heuristic_dispatch(M, N, K);
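For context on the fused zero-initialization changed above: each thread of set_kernel_args_fixed_nk_kernel now owns one float4 (8 bf16 values) of the stacked [group_count, M, N] output and zeroes it only when its row lies past that group's valid row count, which is why the host side adds the N % 8 check and sizes the grid by group_count * M * N / 8. The sketch below is a minimal standalone illustration of that indexing under those assumptions, not the library's API; the kernel name, the launcher, and the int64_t element type assumed for prepad_M are hypothetical.

#include <algorithm>
#include <cstdint>

#include <hip/hip_bf16.h>
#include <hip/hip_runtime.h>

// Hypothetical sketch: zero the padded rows of a contiguous
// [group_count, M, N] bf16 buffer, one float4 (8 bf16 values) per thread.
__global__ void zero_padded_rows_sketch(
    __hip_bfloat16* output,
    const int64_t* prepad_M, // valid (nonzero) rows per group; type assumed
    int M,
    int N,
    int group_count) {
  int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
  int output_offset = thread_idx * 8; // first bf16 element owned by this thread
  int current_group = output_offset / (M * N);
  if (current_group >= group_count) {
    return; // past the end of the buffer
  }
  int current_M = (output_offset % (M * N)) / N; // row index within the group
  if (current_M >= prepad_M[current_group]) {
    // Padded row: clear 8 bf16 values with a single vectorized 16-byte store.
    float4* output_block = reinterpret_cast<float4*>(output + output_offset);
    *output_block = {0, 0, 0, 0};
  }
}

// Hypothetical launcher mirroring the sizing in set_dynamic_kernel_args:
// one thread per float4 of output, capped at 512 threads per block.
void launch_zero_padded_rows_sketch(
    __hip_bfloat16* output,
    const int64_t* prepad_M,
    int M,
    int N,
    int group_count,
    hipStream_t stream) {
  int work_items = (group_count * M * N) / 8; // requires N % 8 == 0
  if (work_items == 0) {
    return;
  }
  int blockSize = std::min(512, work_items);
  int numBlocks = (work_items + blockSize - 1) / blockSize;
  zero_padded_rows_sketch<<<numBlocks, blockSize, 0, stream>>>(
      output, prepad_M, M, N, group_count);
}

Compared with the removed __hip_bfloat162 path, one 16-byte store per thread replaces four 4-byte stores plus a separate odd-element tail case, which is exactly why the divisibility guard on N becomes necessary.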