Optimize zero fill (#3666)
Summary:
Pull Request resolved: #3666

X-link: facebookresearch/FBGEMM#741

We were spending more time than necessary setting the output tensor to zero during kernel setup. Assuming that N is divisible by 8 and using float4 vectorized writes saves us a good bit of time.

This can yield as much as a 10% overall speedup for fp8 grouped gemm.
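
A minimal, self-contained sketch of the trick (not the code from this diff): each thread clears 8 bf16 values with one 16-byte float4 store. The buffer is viewed as raw 16-bit storage here to keep the example header-free; bf16 zero is the all-zero bit pattern, so writing float4 zeros is equivalent. The kernel name, launch shape, and parameter names are illustrative only.

#include <hip/hip_runtime.h>
#include <cstdint>

// Zero-fill a bf16 buffer (viewed as uint16_t) holding num_elements values,
// where num_elements is assumed to be divisible by 8.
__global__ void zero_fill_bf16_vec(uint16_t* out, long num_elements) {
  long chunk = (long)blockIdx.x * blockDim.x + threadIdx.x;  // one float4 per thread
  long offset = chunk * 8;                                   // 8 bf16 values
  if (offset < num_elements) {
    float4* block = reinterpret_cast<float4*>(out + offset);
    *block = {0.f, 0.f, 0.f, 0.f};
  }
}

// Example launch: one thread per 8 elements.
//   long chunks = num_elements / 8;
//   int block = 256;
//   int grid = (int)((chunks + block - 1) / block);
//   zero_fill_bf16_vec<<<grid, block, 0, stream>>>(out, num_elements);

Because one float4 covers exactly 8 bf16 values, the grid only needs total_elements / 8 threads, which is where the new N % 8 requirement in the diff comes from; it also assumes the output base pointer is 16-byte aligned, which typically holds for freshly allocated PyTorch tensors.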

Reviewed By: jiawenliu64, mxz297

Differential Revision: D69267443

fbshipit-source-id: 527b81f69fc3792c2b41fad0ba8f123de5bafde6
jwfromm authored and facebook-github-bot committed Feb 7, 2025
1 parent d564c8c commit 2cef43a
Showing 1 changed file with 80 additions and 65 deletions.
@@ -18,7 +18,6 @@

#include <ATen/ATen.h>
#include <c10/hip/HIPStream.h>
#include <hip_bf16.h>
#include <torch/torch.h>

#include "ck/ck.hpp"
@@ -27,7 +26,7 @@

namespace fbgemm_gpu {

template<typename InputType, typename OutputType>
template <typename InputType, typename OutputType>
using RowwiseGroupedKernel = std::function<OutputType(
InputType,
InputType,
@@ -46,49 +45,76 @@ using D1DataType = float;
using DsDataType = ck::Tuple<D0DataType, D1DataType>;
using EDataType = ck::bhalf_t;

template<typename InputType, typename OutputType>
RowwiseGroupedKernel<InputType, OutputType> rowwise_grouped_heuristic_dispatch(int M, int N, int K) {
template <typename InputType, typename OutputType>
RowwiseGroupedKernel<InputType, OutputType>
rowwise_grouped_heuristic_dispatch(int M, int N, int K) {
// We use shape heuristics to find the best kernel.
// To do this, we divide by the size of M and find the best
// option within that grouping.
if (M <= 16) {
if (N < 8192 && K <= 8192) {
return fp8_rowwise_grouped_64x16x16x256_16x16_1x1_16x4x1_16x4x1_1x4x1x16_4x4x1_1x1_intrawave_v1<InputType, OutputType>;
return fp8_rowwise_grouped_64x16x16x256_16x16_1x1_16x4x1_16x4x1_1x4x1x16_4x4x1_1x1_intrawave_v1<
InputType,
OutputType>;
}
if (K <= 8192) {
return fp8_rowwise_grouped_128x16x64x128_16x16_1x2_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2<InputType, OutputType>;
return fp8_rowwise_grouped_128x16x64x128_16x16_1x2_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2<
InputType,
OutputType>;
}
return fp8_rowwise_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2<InputType, OutputType>;
return fp8_rowwise_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2<
InputType,
OutputType>;
}
if (M <= 32) {
if (N < 8192 && K <= 8192) {
return fp8_rowwise_grouped_128x32x64x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2<InputType, OutputType>;
return fp8_rowwise_grouped_128x32x64x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2<
InputType,
OutputType>;
}
if (K <= 8192) {
return fp8_rowwise_grouped_128x32x128x128_32x32_1x2_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2<InputType, OutputType>;
return fp8_rowwise_grouped_128x32x128x128_32x32_1x2_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2<
InputType,
OutputType>;
}
return fp8_rowwise_grouped_128x32x128x128_32x32_1x2_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_intrawave_v2<InputType, OutputType>;
return fp8_rowwise_grouped_128x32x128x128_32x32_1x2_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_intrawave_v2<
InputType,
OutputType>;
}
if (M <= 64) {
return fp8_rowwise_grouped_256x64x64x128_32x32_1x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3<InputType, OutputType>;
return fp8_rowwise_grouped_256x64x64x128_32x32_1x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3<
InputType,
OutputType>;
}
if (M <= 128) {
if (N < 8192 && K <= 8192) {
return fp8_rowwise_grouped_256x128x64x128_32x32_2x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3<InputType, OutputType>;
return fp8_rowwise_grouped_256x128x64x128_32x32_2x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3<
InputType,
OutputType>;
}
return fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3<InputType, OutputType>;
return fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3<
InputType,
OutputType>;
}
if (M <= 256) {
return fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3<InputType, OutputType>;
return fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3<
InputType,
OutputType>;
}
if (M <= 512) {
if (K <= 8192) {
return fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_interwave_v1<InputType, OutputType>;
return fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_interwave_v1<
InputType,
OutputType>;
}
return fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3<InputType, OutputType>;
return fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3<
InputType,
OutputType>;
}
// Default kernel for all other shapes.
return fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_interwave_v1<InputType, OutputType>;
return fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_interwave_v1<
InputType,
OutputType>;
}
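
For example, a group shape of M = 100, N = 4096, K = 4096 falls into the M <= 128 bucket with N < 8192 and K <= 8192, so the 256x128x64x128 intrawave_v3 instance above is returned.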

__global__ void set_kernel_args_kernel(
@@ -139,9 +165,10 @@ void set_static_kernel_args(
if constexpr (std::is_same_v<OutputType, std::vector<at::Tensor>>) {
// Output is a list of tensors and we can access each individually.
output_ptr = reinterpret_cast<EDataType*>(output[i].data_ptr());
} else{
} else {
// Output is a single contiguous tensor and must be accessed via offset.
output_ptr = reinterpret_cast<EDataType*>(output.data_ptr()) + output_offset;
output_ptr =
reinterpret_cast<EDataType*>(output.data_ptr()) + output_offset;
output_offset += M * N;
}

@@ -165,7 +192,6 @@ void set_static_kernel_args(
M,
N,
K);

}
}

@@ -180,8 +206,7 @@ __global__ void set_kernel_args_fixed_nk_kernel(
int M,
int N,
int K,
int group_count,
const int BLOCK_SIZE) {
int group_count) {
int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
// Each thread is responsible for setting up the arguments for one group.
if (thread_idx < group_count) {
@@ -203,33 +228,21 @@ __global__ void set_kernel_args_fixed_nk_kernel(
kernel_args[thread_idx] = kernel_group_args;
}

// We also fuse in initialization of the output tensor.
// We write in chunks of 2 bfloats at a time for efficiency.
for (int i = 0; i < BLOCK_SIZE / 2; i++) {
// Figure out where in memory we are.
int output_offset = (thread_idx * BLOCK_SIZE) + (i * 2);
int current_group = output_offset / (M * N);
// Skip if outside of valid groups.
if (current_group < group_count) {
int nonzeros = prepad_M[current_group];
int current_M = output_offset / N;
// Only write if this block needs initialization.
// Avoid writing to final element if number of elements is odd.
if (current_M >= nonzeros && output_offset < (M * N * group_count) - 1) {
__hip_bfloat162* output_block =
reinterpret_cast<__hip_bfloat162*>(output + output_offset);
*output_block = __hip_bfloat162(0, 0);
}
// Figure out where in memory we are.
// Each thread sets one float 4 which corresponds to 8 bf16 values.
int output_offset = (thread_idx * 8);
int current_group = output_offset / (M * N);
// Skip if outside of valid groups.
if (current_group < group_count) {
int nonzeros = prepad_M[current_group];
int current_M = (output_offset % (M * N)) / N;
// Only write zeros if we're currently in a sparse row.
if (current_M >= nonzeros) {
// Write out a block of 8 output values via vectorized float4.
float4* output_block = reinterpret_cast<float4*>(output + output_offset);
*output_block = {0, 0, 0, 0};
}
}
// Handle case where there are an odd number of total elements.
if (((M * N * group_count) % 2) != 0 &&
((M * N * group_count) - (thread_idx * BLOCK_SIZE) < BLOCK_SIZE)) {
// Write out the final element.
__hip_bfloat16* output_block =
reinterpret_cast<__hip_bfloat16*>(output + (M * N * group_count) - 1);
*output_block = __hip_bfloat16(0);
}
}
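
To make the new per-thread index math easier to follow, here is a small host-side sketch, not part of the commit, that mirrors the arithmetic above: each thread owns one 8-element chunk, maps it to a (group, row) pair, and zero-fills it only when the row is at or past that group's valid row count (prepad_M). The shapes in main() are made up for illustration.

#include <cstdio>
#include <vector>

// Returns true if the 8-element chunk owned by thread_idx falls in a padded
// (invalid) row and would therefore be zero-filled by the kernel.
bool chunk_is_zero_filled(
    int thread_idx, int M, int N, const std::vector<int>& prepad_M) {
  int output_offset = thread_idx * 8;           // 8 bf16 values per thread
  int current_group = output_offset / (M * N);  // which group this chunk is in
  if (current_group >= static_cast<int>(prepad_M.size())) {
    return false;  // Outside all groups; the kernel skips these threads.
  }
  int current_M = (output_offset % (M * N)) / N;  // row within the group
  return current_M >= prepad_M[current_group];    // past the valid rows?
}

int main() {
  int M = 4;
  int N = 8;                            // tiny example, N divisible by 8
  std::vector<int> prepad_M = {2, 3};   // valid (nonzero) rows per group
  int total_chunks = static_cast<int>(prepad_M.size()) * M * N / 8;
  for (int t = 0; t < total_chunks; ++t) {
    std::printf("thread %d zero-fills its chunk: %s\n", t,
                chunk_is_zero_filled(t, M, N, prepad_M) ? "yes" : "no");
  }
  return 0;
}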

void set_dynamic_kernel_args(
Expand Down Expand Up @@ -261,9 +274,12 @@ void set_dynamic_kernel_args(
int N = WQ.size(1);

// Launch a kernel that sets kernel argument memory.
// Each thread sets one float4 which corresponds to 8 bf16 values.
const int BLOCK_SIZE = 8;
TORCH_CHECK(
N % BLOCK_SIZE == 0, "N must be divisible by 8 for dynamic grouped gemm.");
int block_factor = std::max(group_count, (group_count * M * N) / BLOCK_SIZE);
int blockSize = std::min(1024, block_factor);
int blockSize = std::min(512, block_factor);
int numBlocks = (block_factor + blockSize - 1) / blockSize;
set_kernel_args_fixed_nk_kernel<<<numBlocks, blockSize, 0, stream>>>(
reinterpret_cast<KernelArguments*>(kernel_args.data_ptr()),
@@ -276,8 +292,7 @@ void set_dynamic_kernel_args(
M,
N,
K,
group_count,
BLOCK_SIZE);
group_count);
}
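
For context on the launch configuration just above, a quick host-side sketch of the grid math, with made-up shapes: one thread per float4 (8 bf16 elements), but never fewer threads than there are groups so every group's kernel arguments still get written.

#include <algorithm>
#include <cstdio>

int main() {
  int group_count = 4, M = 1024, N = 4096;  // hypothetical shapes
  const int BLOCK_SIZE = 8;                 // bf16 elements cleared per thread
  int block_factor =
      std::max(group_count, (group_count * M * N) / BLOCK_SIZE);
  int blockSize = std::min(512, block_factor);
  int numBlocks = (block_factor + blockSize - 1) / blockSize;
  // Prints: threads=2097152 blockSize=512 numBlocks=4096
  std::printf("threads=%d blockSize=%d numBlocks=%d\n",
              block_factor, blockSize, numBlocks);
  return 0;
}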

template <typename OutputType>
Expand Down Expand Up @@ -347,22 +362,25 @@ OutputType _f8f8bf16_rowwise_grouped(
Y.push_back(at::empty({M, N}, XQ[i].options().dtype(at::kBFloat16)));
}
}
// Now handle single tensor output.
// Now handle single tensor output.
} else {
// Compute total M across groups.
int total_M = 0;
int N = WQ[0].size(0);
for (int i = 0; i < group_count; i++) {
total_M += XQ[i].size(0);
// Also make sure N is constant across shapes.
TORCH_CHECK(WQ[i].size(0) == N, "N must be constant across groups for stacked output.");
TORCH_CHECK(
WQ[i].size(0) == N,
"N must be constant across groups for stacked output.");
}
if (output.has_value()) {
Y = output.value();
// Check that shape is expected.
TORCH_CHECK(Y.size(0) == total_M && Y.size(1) == N, "Preallocated output should have size [total_M, N].");
}
else {
TORCH_CHECK(
Y.size(0) == total_M && Y.size(1) == N,
"Preallocated output should have size [total_M, N].");
} else {
Y = at::empty({total_M, N}, XQ[0].options().dtype(at::kBFloat16));
}
}
@@ -383,7 +401,8 @@ OutputType _f8f8bf16_rowwise_grouped(
MaxK = max(MaxK, XQ[i].size(1));
}
RowwiseGroupedKernel<at::TensorList, OutputType> selected_kernel =
rowwise_grouped_heuristic_dispatch<at::TensorList, OutputType>(MaxM, MaxN, MaxK);
rowwise_grouped_heuristic_dispatch<at::TensorList, OutputType>(
MaxM, MaxN, MaxK);
return selected_kernel(XQ, WQ, x_scale, w_scale, kernel_args, Y);
}

@@ -394,7 +413,8 @@ std::vector<at::Tensor> f8f8bf16_rowwise_grouped(
at::TensorList x_scale,
at::TensorList w_scale,
std::optional<std::vector<at::Tensor>> output = std::nullopt) {
return _f8f8bf16_rowwise_grouped<std::vector<at::Tensor>>(XQ, WQ, x_scale, w_scale, output);
return _f8f8bf16_rowwise_grouped<std::vector<at::Tensor>>(
XQ, WQ, x_scale, w_scale, output);
}

// Wrapper function for list input single tensor output.
@@ -404,7 +424,8 @@ at::Tensor f8f8bf16_rowwise_grouped_stacked(
at::TensorList x_scale,
at::TensorList w_scale,
std::optional<at::Tensor> output = std::nullopt) {
return _f8f8bf16_rowwise_grouped<at::Tensor>(XQ, WQ, x_scale, w_scale, output);
return _f8f8bf16_rowwise_grouped<at::Tensor>(
XQ, WQ, x_scale, w_scale, output);
}

at::Tensor f8f8bf16_rowwise_grouped_dynamic(
@@ -452,13 +473,7 @@ at::Tensor f8f8bf16_rowwise_grouped_dynamic(
{static_cast<long>(group_count * sizeof(KernelArguments))},
XQ.options().dtype(at::kByte));
set_dynamic_kernel_args(
kernel_args,
XQ,
WQ,
x_scale,
w_scale,
Y,
zero_start_index_M);
kernel_args, XQ, WQ, x_scale, w_scale, Y, zero_start_index_M);

RowwiseGroupedKernel<at::Tensor, at::Tensor> selected_kernel =
rowwise_grouped_heuristic_dispatch<at::Tensor, at::Tensor>(M, N, K);