Performance Optimization: Optimized TileShape Configuration for bf16 and Mixed Formats (#3710)

Summary:
Pull Request resolved: #3710

X-link: facebookresearch/FBGEMM#783

## Performance Issue with the Current bf16 and Mixed-Format TileShape Configuration
The current FBGEMM bf16 kernel uses a TileShape configuration of 128x128x128,
while the optimal shape for the dense bf16 tensor cores on H100 is m64n256k16.
The current configuration leads to suboptimal tensor-core and bandwidth utilization,
as evidenced by PTX warnings about
'wgmma.mma_async instruction serialization due to insufficient register resources'.

## Optimized TileShape (128x256x64) Implementation
This change switches the TileShape configuration from 128x128x128 to 128x256x64 for large GEMM
operations, using a cooperative kernel to enable optimal bandwidth and tensor-core utilization.
The same configuration is notably used in Flash Attention V3 and was identified by Colfax-intl
as the optimal configuration for bf16 kernels in their empirical study.

## Benchmark Results on H100 GPU
### Benchmark configuration
- PyTorch 2.6
- CUDA 12.4
- CPU: AMD EPYC
- GPU: NVIDIA H100

Benchmarks run 30 kernel launch iterations, averaged over 25 benchmark repetitions,
using the same GEMM sizes as the Colfax benchmarks.

### Benchmark
#### bf16bf16bf16_grouped (G = 4, M = 2,048, N = 8,192, K = 8,192)
| TileShape   | TFlops  |
|-------------|-------- |
| 128-128-128 | 606     |
| 128-256-64  | 790     |

#### bf16i4bf16_rowwise_batched (B = 4, M = 2,048, N = 8,192, K = 8,192)
| TileShape   | TFlops bf16*| TFlops fp16*| TFlops float*|
|-------------|-------------|-------------|------------- |
| 128-128-128 |         354 |         341 |          383 |
| 128-256-64  |         704 |         727 |          763 |

#### bf16i4bf16_rowwise (M=N=K = 8,192)
| TileShape   | TFlops bf16*| TFlops fp16*| TFlops float*|
|-------------|-------------|-------------|------------- |
| 128-128-128 |         349 |         351 |          381 |
| 128-256-64  |         652 |         663 |          693 |

#### f8i4bf16_rowwise (M=N=K = 8,192)
| TileShape   | TFlops bf16*| TFlops fp16*| TFlops float*|
|-------------|-------------|-------------|------------- |
| 128-128-128 |         407 |         542 |          606 |
| 128-256-64  |         921 |         942 |         1088 |

*Columns report TFlops for each WEIGHT_SCALE_DTYPE (weight scale data type).

## Technical Implementation
Modified TileShape from 128-128-128 to 128-256-64 for:
 - bf16bf16bf16_grouped
 - bf16i4bf16_rowwise_batched
 - bf16i4bf16_rowwise
 - f8i4bf16_rowwise

Added cooperative kernel by default for:
 - bf16i4bf16_rowwise_batched
 - bf16i4bf16_rowwise
 - f8i4bf16_rowwise

These modifications only affect the Large kernel mode and Default kernels where N > 128.
The changes touch the minimum necessary code while following existing FBGEMM coding practices.

## Test Coverage
### Unit Tests Results
The unit tests in fbgemm_gpu/experimental/gen_ai/test/quantize
have been run against the modified kernels and pass.

jiawenliu64 jwfromm Hello! I wanted to share this contribution to FBGEMM.
While this is my first PR, I hope these changes could be useful for this great project.
I'd welcome any feedback if you have time to take a look. Thank you!

Pull Request resolved: #3591

Reviewed By: jianyuh

Differential Revision: D68609243

Pulled By: jiawenliu64

fbshipit-source-id: e6cc2a9e42f2fc7da76f5fa7189fe773a8c69e51
MatrixAssembler authored and facebook-github-bot committed Feb 18, 2025
1 parent f8f6f43 commit 19f3713
Showing 4 changed files with 57 additions and 42 deletions.
#### bf16bf16bf16_grouped

```diff
@@ -138,20 +138,20 @@ __global__ void set_dynamic_kernel_args_kernel(
           GroupedGemmBF16Args::ProblemShape::UnderlyingProblemShape*>(
           problem_shape_buf);
   // Pass dummy configs to get Stride structure
-  GroupedGemmBF16Args::GroupedGemmConfigs<128, 128, 128, 2, 1, 1, true>::
+  GroupedGemmBF16Args::GroupedGemmConfigs<128, 256, 64, 2, 1, 1, false>::
       StrideInputA* stride_input_A_ptr = reinterpret_cast<
           GroupedGemmBF16Args::
-              GroupedGemmConfigs<128, 128, 128, 2, 1, 1, true>::
+              GroupedGemmConfigs<128, 256, 64, 2, 1, 1, false>::
                   StrideInputA*>(stride_buf);
-  GroupedGemmBF16Args::GroupedGemmConfigs<128, 128, 128, 2, 1, 1, true>::
+  GroupedGemmBF16Args::GroupedGemmConfigs<128, 256, 64, 2, 1, 1, false>::
       StrideInputB* stride_input_B_ptr = reinterpret_cast<
           GroupedGemmBF16Args::
-              GroupedGemmConfigs<128, 128, 128, 2, 1, 1, true>::
+              GroupedGemmConfigs<128, 256, 64, 2, 1, 1, false>::
                   StrideInputB*>(stride_buf + stride_size);
-  GroupedGemmBF16Args::GroupedGemmConfigs<128, 128, 128, 2, 1, 1, true>::
+  GroupedGemmBF16Args::GroupedGemmConfigs<128, 256, 64, 2, 1, 1, false>::
       StrideOutput* stride_output_ptr = reinterpret_cast<
           GroupedGemmBF16Args::
-              GroupedGemmConfigs<128, 128, 128, 2, 1, 1, true>::
+              GroupedGemmConfigs<128, 256, 64, 2, 1, 1, false>::
                   StrideOutput*>(stride_buf + (stride_size * 2));
 
     output_args_ptr[group_index] =
@@ -167,15 +167,15 @@ __global__ void set_dynamic_kernel_args_kernel(
             zero_start_index_M[group_index], N, K);
     stride_input_A_ptr[group_index] = cutlass::make_cute_packed_stride(
         typename GroupedGemmBF16Args::
-            GroupedGemmConfigs<128, 128, 128, 2, 1, 1, true>::StrideInputA{},
+            GroupedGemmConfigs<128, 256, 64, 2, 1, 1, false>::StrideInputA{},
         {zero_start_index_M[group_index], K, 1});
     stride_input_B_ptr[group_index] = cutlass::make_cute_packed_stride(
         typename GroupedGemmBF16Args::
-            GroupedGemmConfigs<128, 128, 128, 2, 1, 1, true>::StrideInputB{},
+            GroupedGemmConfigs<128, 256, 64, 2, 1, 1, false>::StrideInputB{},
         {N, K, 1});
     stride_output_ptr[group_index] = cutlass::make_cute_packed_stride(
         typename GroupedGemmBF16Args::
-            GroupedGemmConfigs<128, 128, 128, 2, 1, 1, true>::StrideOutput{},
+            GroupedGemmConfigs<128, 256, 64, 2, 1, 1, false>::StrideOutput{},
         {zero_start_index_M[group_index], N, 1});
   }
 }
@@ -212,20 +212,20 @@ __global__ void set_static_kernel_args_kernel(
           GroupedGemmBF16Args::ProblemShape::UnderlyingProblemShape*>(
           problem_shape_buf);
   // Pass dummy configs to get Stride structure
-  GroupedGemmBF16Args::GroupedGemmConfigs<128, 128, 128, 2, 1, 1, true>::
+  GroupedGemmBF16Args::GroupedGemmConfigs<128, 256, 64, 2, 1, 1, false>::
       StrideInputA* stride_input_A_ptr = reinterpret_cast<
           GroupedGemmBF16Args::
-              GroupedGemmConfigs<128, 128, 128, 2, 1, 1, true>::
+              GroupedGemmConfigs<128, 256, 64, 2, 1, 1, false>::
                   StrideInputA*>(stride_buf);
-  GroupedGemmBF16Args::GroupedGemmConfigs<128, 128, 128, 2, 1, 1, true>::
+  GroupedGemmBF16Args::GroupedGemmConfigs<128, 256, 64, 2, 1, 1, false>::
       StrideInputB* stride_input_B_ptr = reinterpret_cast<
           GroupedGemmBF16Args::
-              GroupedGemmConfigs<128, 128, 128, 2, 1, 1, true>::
+              GroupedGemmConfigs<128, 256, 64, 2, 1, 1, false>::
                   StrideInputB*>(stride_buf + stride_size);
-  GroupedGemmBF16Args::GroupedGemmConfigs<128, 128, 128, 2, 1, 1, true>::
+  GroupedGemmBF16Args::GroupedGemmConfigs<128, 256, 64, 2, 1, 1, false>::
       StrideOutput* stride_output_ptr = reinterpret_cast<
           GroupedGemmBF16Args::
-              GroupedGemmConfigs<128, 128, 128, 2, 1, 1, true>::
+              GroupedGemmConfigs<128, 256, 64, 2, 1, 1, false>::
                   StrideOutput*>(stride_buf + (stride_size * 2));
 
     output_args_ptr[group_index] = reinterpret_cast<int64_t>(output_data);
@@ -237,15 +237,15 @@ __global__ void set_static_kernel_args_kernel(
         GroupedGemmBF16Args::ProblemShape::UnderlyingProblemShape(M, N, K);
     stride_input_A_ptr[group_index] = cutlass::make_cute_packed_stride(
         typename GroupedGemmBF16Args::
-            GroupedGemmConfigs<128, 128, 128, 2, 1, 1, true>::StrideInputA{},
+            GroupedGemmConfigs<128, 256, 64, 2, 1, 1, false>::StrideInputA{},
         {M, K, 1});
     stride_input_B_ptr[group_index] = cutlass::make_cute_packed_stride(
         typename GroupedGemmBF16Args::
-            GroupedGemmConfigs<128, 128, 128, 2, 1, 1, true>::StrideInputB{},
+            GroupedGemmConfigs<128, 256, 64, 2, 1, 1, false>::StrideInputB{},
         {N, K, 1});
     stride_output_ptr[group_index] = cutlass::make_cute_packed_stride(
         typename GroupedGemmBF16Args::
-            GroupedGemmConfigs<128, 128, 128, 2, 1, 1, true>::StrideOutput{},
+            GroupedGemmConfigs<128, 256, 64, 2, 1, 1, false>::StrideOutput{},
         {M, N, 1});
   }
 }
@@ -470,10 +470,10 @@ std::vector<at::Tensor> dispatch_bf16_grouped_kernel(
     return bf16bf16bf16_grouped_impl<64, 128, 128, 2, 1, 1, true>(
         x_group, w_group, output_tensor, zero_start_index_M);
   } else if (kernel == KernelMode::Large) {
-    return bf16bf16bf16_grouped_impl<128, 128, 128, 2, 1, 1, true>(
+    return bf16bf16bf16_grouped_impl<128, 256, 64, 2, 1, 1, false>(
         x_group, w_group, output_tensor, zero_start_index_M);
   } else {
-    return bf16bf16bf16_grouped_impl<128, 128, 128, 1, 2, 1, true>(
+    return bf16bf16bf16_grouped_impl<128, 256, 64, 2, 1, 1, false>(
         x_group, w_group, output_tensor, zero_start_index_M);
   }
 }
```
#### bf16i4bf16_rowwise

```diff
@@ -98,13 +98,18 @@ at::Tensor bf16i4bf16_rowwise_impl(
       cute::Int<TBS_K>>; // Shape of the
                          // threadblocks in a
                          // cluster
-  using DefaultSchedule = cutlass::gemm::KernelTmaWarpSpecializedMixedInput;
+  using CooperativeSchedule =
+      cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput;
   using PongSchedule =
       cutlass::gemm::KernelTmaWarpSpecializedPingpongMixedInput;
-  using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized;
+  using CooperativeEpilogueSchedule =
+      cutlass::epilogue::TmaWarpSpecializedCooperative;
+  using PongEpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized;
   using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
   using MainLoopSchedule =
-      cute::conditional_t<PONG, PongSchedule, DefaultSchedule>;
+      cute::conditional_t<PONG, PongSchedule, CooperativeSchedule>;
+  using EpilogueSchedule = cute::
+      conditional_t<PONG, PongEpilogueSchedule, CooperativeEpilogueSchedule>;
 
   using CollectiveEpilogue =
       typename cutlass::epilogue::collective::CollectiveBuilder<
@@ -231,18 +236,18 @@ at::Tensor dispatch_bf16i4bf16_rowwise_kernel(
   } else if (kernel == KernelMode::Large) {
     return bf16i4bf16_rowwise_impl<
         128,
-        128,
-        128,
+        256,
+        64,
         2,
         1,
         1,
-        true,
+        false,
         WEIGHT_SCALE_DTYPE>(X, WQ, w_scale, w_zp);
   } else {
     return bf16i4bf16_rowwise_impl<
         128,
-        128,
-        128,
+        256,
+        64,
         2,
         1,
         1,
```
#### bf16i4bf16_rowwise_batched

```diff
@@ -102,13 +102,18 @@ at::Tensor bf16i4bf16_rowwise_batched_impl(
       cute::Int<TBS_K>>; // Shape of the
                          // threadblocks in a
                          // cluster
-  using DefaultSchedule = cutlass::gemm::KernelTmaWarpSpecializedMixedInput;
+  using CooperativeSchedule =
+      cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput;
   using PongSchedule =
       cutlass::gemm::KernelTmaWarpSpecializedPingpongMixedInput;
-  using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized;
+  using CooperativeEpilogueSchedule =
+      cutlass::epilogue::TmaWarpSpecializedCooperative;
+  using PongEpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized;
   using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
   using MainLoopSchedule =
-      cute::conditional_t<PONG, PongSchedule, DefaultSchedule>;
+      cute::conditional_t<PONG, PongSchedule, CooperativeSchedule>;
+  using EpilogueSchedule = cute::
+      conditional_t<PONG, PongEpilogueSchedule, CooperativeEpilogueSchedule>;
 
   using CollectiveEpilogue =
       typename cutlass::epilogue::collective::CollectiveBuilder<
@@ -235,17 +240,17 @@ at::Tensor dispatch_bf16i4bf16_rowwise_batched_kernel(
   } else if (kernel == KernelMode::Large) {
     return bf16i4bf16_rowwise_batched_impl<
         128,
-        128,
-        128,
+        256,
+        64,
         2,
         1,
         1,
-        true,
+        false,
         WEIGHT_SCALE_DTYPE>(X, WQ, w_scale, w_zp);
   } else {
     return bf16i4bf16_rowwise_batched_impl<
         128,
-        128,
-        128,
+        256,
+        64,
         2,
         1,
```
#### f8i4bf16_rowwise

```diff
@@ -92,13 +92,18 @@ at::Tensor f8i4bf16_rowwise_impl(
       cute::Int<TBS_K>>; // Shape of the
                          // threadblocks in a
                          // cluster
-  using DefaultSchedule = cutlass::gemm::KernelTmaWarpSpecializedMixedInput;
+  using CooperativeSchedule =
+      cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput;
   using PongSchedule =
       cutlass::gemm::KernelTmaWarpSpecializedPingpongMixedInput;
-  using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized;
+  using CooperativeEpilogueSchedule =
+      cutlass::epilogue::TmaWarpSpecializedCooperative;
+  using PongEpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized;
   using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
   using MainLoopSchedule =
-      cute::conditional_t<PONG, PongSchedule, DefaultSchedule>;
+      cute::conditional_t<PONG, PongSchedule, CooperativeSchedule>;
+  using EpilogueSchedule = cute::
+      conditional_t<PONG, PongEpilogueSchedule, CooperativeEpilogueSchedule>;
 
   // Implement rowwise scaling epilogue for x
   using XScale = cutlass::epilogue::fusion::Sm90RowBroadcast<
@@ -254,19 +259,19 @@ at::Tensor dispatch_f8i4bf16_rowwise_kernel(
   } else if (kernel == KernelMode::Large) {
     return f8i4bf16_rowwise_impl<
         128,
-        128,
-        128,
+        256,
+        64,
         2,
         1,
         1,
-        true,
+        false,
         InputDType,
         WEIGHT_SCALE_DTYPE>(XQ, WQ, x_scale, w_scale, w_zp);
   } else {
     return f8i4bf16_rowwise_impl<
         128,
-        128,
-        128,
+        256,
+        64,
         2,
         1,
         1,
```
