Merge pull request #31 from ROCm/bwd_hd96_perf
Bwd hd96 performance improvement
qianfengz authored Oct 16, 2024
2 parents 93524db + 2773383 commit d4437ad
Showing 60 changed files with 1,483 additions and 118 deletions.
2 changes: 1 addition & 1 deletion .gitmodules
@@ -7,4 +7,4 @@
[submodule "third_party/composable_kernel_tiled"]
path = third_party/composable_kernel_tiled
url = https://github.com/ROCm/composable_kernel.git
- branch = develop
+ branch = develop
2 changes: 1 addition & 1 deletion third_party/composable_kernel_tiled
Submodule composable_kernel_tiled updated 64 files
+7 −3 CMakeLists.txt
+43 −16 Jenkinsfile
+1 −0 README.md
+29 −31 codegen/CMakeLists.txt
+20 −18 codegen/test/CMakeLists.txt
+0 −0 codegen/test/include/common.hpp
+2 −0 codegen/test/rtc/CMakeLists.txt
+2 −2 codegen/test/rtc/include/rtc/compile_kernel.hpp
+60 −0 codegen/test/rtc/include/rtc/filesystem.hpp
+2 −2 codegen/test/rtc/include/rtc/tmp_dir.hpp
+5 −5 codegen/test/rtc/src/compile_kernel.cpp
+3 −3 codegen/test/rtc/src/tmp_dir.cpp
+0 −6 docs/reference/API_Reference_Guide.rst
+17 −16 example/01_gemm/common.hpp
+12 −1 example/01_gemm/gemm_dl_fp16.cpp
+12 −1 example/01_gemm/gemm_dl_fp32.cpp
+12 −1 example/01_gemm/gemm_dl_int8.cpp
+4 −1 example/01_gemm/gemm_dpp_fp16.cpp
+12 −1 example/01_gemm/gemm_wmma_fp16.cpp
+15 −1 example/01_gemm/gemm_xdl_bf16.cpp
+15 −1 example/01_gemm/gemm_xdl_bf16_rtn.cpp
+12 −1 example/01_gemm/gemm_xdl_fp16.cpp
+12 −1 example/01_gemm/gemm_xdl_fp16_fp8.cpp
+12 −1 example/01_gemm/gemm_xdl_fp16_v2.cpp
+12 −1 example/01_gemm/gemm_xdl_fp64.cpp
+14 −0 example/01_gemm/gemm_xdl_fp8.cpp
+12 −1 example/01_gemm/gemm_xdl_fp8_bf8.cpp
+12 −1 example/01_gemm/gemm_xdl_int8.cpp
+12 −1 example/01_gemm/gemm_xdl_lds_direct_load_fp16.cpp
+12 −1 example/01_gemm/gemm_xdl_lds_direct_load_fp32.cpp
+12 −1 example/01_gemm/gemm_xdl_streamk.cpp
+12 −1 example/01_gemm/gemm_xdl_wavelet_fp16.cpp
+40 −6 example/01_gemm/run_gemm_example.inc
+2 −2 example/01_gemm/run_gemm_example_streamk_v2.inc
+1 −0 example/44_elementwise_permute/CMakeLists.txt
+247 −0 example/44_elementwise_permute/elementwise_scale_permute_amax_2D_fp16_fp8.cpp
+44 −16 example/ck_tile/03_gemm/gemm_basic.cpp
+6 −0 include/ck/host_utility/kernel_launch.hpp
+3 −3 include/ck/tensor_operation/gpu/device/device_cgemm.hpp
+17 −1 include/ck/tensor_operation/gpu/device/impl/device_cgemm_4gemm_xdl_cshuffle.hpp
+6 −0 include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
+625 −30 include/ck/utility/data_type.hpp
+4 −0 include/ck/utility/math_v2.hpp
+4 −1 include/ck_tile/core/config.hpp
+1 −1 include/ck_tile/core/container/thread_buffer.hpp
+37 −10 include/ck_tile/host/reference/reference_gemm.hpp
+1 −0 include/ck_tile/ops/epilogue.hpp
+171 −0 include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp
+12 −15 include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp
+12 −15 include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp
+83 −143 include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp
+35 −31 include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp
+7 −5 include/ck_tile/ops/gemm.hpp
+6 −9 include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
+7 −3 include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp
+2 −2 include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp
+3 −3 include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp
+4 −5 include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp
+10 −7 include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp
+424 −0 include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp
+27 −0 include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp
+245 −0 library/include/ck/library/reference_tensor_operation/gpu/reference_gemm.hpp
+5 −0 test/data_type/CMakeLists.txt
+874 −0 test/data_type/test_custom_type.cpp
102 changes: 48 additions & 54 deletions xformers/csrc/attention/hip_fmha/ck_tiled_fmha_batched_backward.h
@@ -60,8 +60,7 @@ struct batched_backward_causalmask_bias_dropout_dispatch {
constexpr ck_tile::index_t kBlockSize = 64;

const bool pad_seqlen_q = !(param.M % kBlockSize == 0);
- const bool pad_headdim_v =
-     !(param.Kv % FmhaBwdShape<MaxK>::kVHeaddim == 0);
+ const bool pad_headdim_v = !(param.Kv % MaxK == 0);

BOOL_SWITCH_2(
pad_seqlen_q, kPadSeqLenQ, pad_headdim_v, kPadHeadDimV, [&] {
@@ -78,7 +77,7 @@ struct batched_backward_causalmask_bias_dropout_dispatch {
typename FmhaBwdTypeConfig<ScalarType>::OGradDataType,
typename FmhaBwdTypeConfig<ScalarType>::DDataType,
kBlockSize,
- FmhaBwdShape<MaxK>::kVHeaddim,
+ MaxK, // kVHeaddim
false, // kIsGroupMode
FmhaOGradDotOTraits_>;

@@ -114,63 +113,58 @@ struct batched_backward_causalmask_bias_dropout_dispatch {
const bool pad_headdim_v =
!(param.Kv % FmhaBwdShape<MaxK>::kVHeaddim == 0);

- // usually headdim_q and headdim_v are same, consider them together
- // to determine whether to do padding saving some compiling time
- const bool pad_headdim = (pad_headdim_q || pad_headdim_v);
-
- BOOL_SWITCH(pad_headdim, kPadHeadDim, [&] {
- using FmhaBwdTraits_ = ck_tile::TileFmhaTraits<
- kPadSeqLenQ,
- kPadSeqLenK,
- kPadHeadDim, // kPadHeadDimQ,
- kPadHeadDim, // kPadHeadDimV,
- kBiasEnum,
- kHasBiasGrad,
- false, // kStoreLSE
- false, // place-holder for kHasDropout, not used actually
- false, // kDoFp8StaticQuant place-holder
- occupancy>;
-
- using FmhaBwdPipelineProblem =
- FmhaBwdPipelineProblemTemp<FmhaBwdTraits_, FmhaMask>;
-
- constexpr auto FmhaBwdPipelineEnum_ =
- FmhaBwdPipelineEnumSelector<MaxK, kPadHeadDim, kPadHeadDim>::
- value;
-
- using FmhaBwdPipeline_ = typename FmhaBwdPipelineMaker<
- FmhaBwdPipelineEnum_,
- FmhaBwdPipelineProblem>::pipeline;
-
- using FmhaBwdKGradEpilogue_ =
- ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<
- typename FmhaBwdTypeConfig<ScalarType>::AccDataType,
- typename FmhaBwdTypeConfig<ScalarType>::KGradDataType,
+ BOOL_SWITCH_2(
+ pad_headdim_q, kPadHeadDimQ, pad_headdim_v, kPadHeadDimV, [&] {
+ using FmhaBwdTraits_ = ck_tile::TileFmhaTraits<
+ kPadSeqLenQ,
kPadSeqLenK,
- kPadHeadDim>>;
-
- using FmhaBwdVGradEpilogue_ =
- ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<
- typename FmhaBwdTypeConfig<ScalarType>::AccDataType,
- typename FmhaBwdTypeConfig<ScalarType>::VGradDataType,
- kPadSeqLenK,
- kPadHeadDim>>;
-
- using FmhaBwdDQDKDVKernel_ = ck_tile::FmhaBwdDQDKDVKernel<
- FmhaBwdPipeline_,
- FmhaBwdKGradEpilogue_,
- FmhaBwdVGradEpilogue_>;
-
- RunWithBwdDQDKDVKernel<FmhaBwdDQDKDVKernel_>(param, stream);
- });
+ kPadHeadDimQ,
+ kPadHeadDimV,
+ kBiasEnum,
+ kHasBiasGrad,
+ false, // kStoreLSE
+ false, // place-holder for kHasDropout, not used actually
+ false, // kDoFp8StaticQuant place-holder
+ occupancy>;
+
+ using FmhaBwdPipelineProblem =
+ FmhaBwdPipelineProblemTemp<FmhaBwdTraits_, FmhaMask>;
+
+ constexpr auto FmhaBwdPipelineEnum_ =
+ FmhaBwdPipelineEnumSelector<MaxK>::value;
+
+ using FmhaBwdPipeline_ = typename FmhaBwdPipelineMaker<
+ FmhaBwdPipelineEnum_,
+ FmhaBwdPipelineProblem>::pipeline;
+
+ using FmhaBwdKGradEpilogue_ =
+ ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<
+ typename FmhaBwdTypeConfig<ScalarType>::AccDataType,
+ typename FmhaBwdTypeConfig<ScalarType>::KGradDataType,
+ kPadSeqLenK,
+ kPadHeadDimQ>>;
+
+ using FmhaBwdVGradEpilogue_ =
+ ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<
+ typename FmhaBwdTypeConfig<ScalarType>::AccDataType,
+ typename FmhaBwdTypeConfig<ScalarType>::VGradDataType,
+ kPadSeqLenK,
+ kPadHeadDimV>>;
+
+ using FmhaBwdDQDKDVKernel_ = ck_tile::FmhaBwdDQDKDVKernel<
+ FmhaBwdPipeline_,
+ FmhaBwdKGradEpilogue_,
+ FmhaBwdVGradEpilogue_>;
+
+ RunWithBwdDQDKDVKernel<FmhaBwdDQDKDVKernel_>(param, stream);
+ });
});
};
if constexpr (NeedConvertGradQ) {
constexpr ck_tile::index_t kBlockSize = 256;

const bool pad_seqlen_q = !(param.M % kBlockSize == 0);
- const bool pad_headdim_q =
-     !(param.K % FmhaBwdShape<MaxK>::kQKHeaddim == 0);
+ const bool pad_headdim_q = !(param.K % MaxK == 0);

BOOL_SWITCH_2(
pad_seqlen_q, kPadSeqLenQ, pad_headdim_q, kPadHeadDimQ, [&] {
@@ -189,7 +183,7 @@ struct batched_backward_causalmask_bias_dropout_dispatch {
kBlockSize,
FmhaBwdShape<MaxK>::kM0,
FmhaBwdShape<MaxK>::kN0,
- FmhaBwdShape<MaxK>::kQKHeaddim,
+ MaxK, // kQKHeaddim
false, // kIsGroupMode
false, // kIsDeterministic
FmhaBwdConvertQGradTraits_>;
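Note on the change above: the batched dispatch used to fold pad_headdim_q and pad_headdim_v into one pad_headdim flag (the deleted comment explains this was done to save compile time) and switch on it once; it now switches on the two flags independently via BOOL_SWITCH_2, so for example a 96-wide V head dim can stay unpadded even when the Q head dim needs padding. The grouped-mode file below receives the same treatment. As a rough illustration of what such a switch does — a minimal sketch only, the real BOOL_SWITCH_2 macro lives in the xformers hip_fmha headers and may differ — a dispatcher like this lifts two runtime booleans into compile-time constants:

#include <type_traits>

// Hypothetical stand-in for BOOL_SWITCH_2 (not the actual xformers macro).
// The callable is instantiated once per (bool, bool) combination, so each
// runtime padding decision becomes a compile-time constant inside the body.
template <typename F>
void dispatch_bool2(bool c1, bool c2, F&& f) {
  if (c1 && c2)
    f(std::true_type{}, std::true_type{});
  else if (c1)
    f(std::true_type{}, std::false_type{});
  else if (c2)
    f(std::false_type{}, std::true_type{});
  else
    f(std::false_type{}, std::false_type{});
}

// Usage mirroring the hunk above (names hypothetical):
//   dispatch_bool2(pad_headdim_q, pad_headdim_v,
//                  [&](auto kPadHeadDimQ, auto kPadHeadDimV) {
//                    using Traits = SomeTraits<kPadHeadDimQ.value, kPadHeadDimV.value>;
//                    run_kernel<Traits>(param, stream);
//                  });

The cost is up to four instantiations of the lambda body instead of two, which is the compile-time trade-off the deleted comment referred to.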
24 changes: 23 additions & 1 deletion xformers/csrc/attention/hip_fmha/ck_tiled_fmha_bwd_setting.h
@@ -70,6 +70,14 @@ struct FmhaBwdBlockTile<64> {
using gemm4_warps = ck_tile::sequence<1, 4, 1>; // default for gemm4
};

+ template <>
+ struct FmhaBwdBlockTile<96> {
+ using tile_lengths = ck_tile::sequence<16, 128, 96, 16, 96, 16, 32, 128, 128>;
+ using gemm02_warps = ck_tile::sequence<1, 4, 1>; // default for gemm0/gemm2
+ using gemm13_warps = ck_tile::sequence<4, 1, 1>; // default for gemm1/gemm3
+ using gemm4_warps = ck_tile::sequence<1, 4, 1>; // default for gemm4
+ };
+
template <>
struct FmhaBwdBlockTile<128> {
using tile_lengths =
@@ -123,6 +131,20 @@ struct FmhaBwdShape<64> : ck_tile::TileFmhaBwdShape<
typename FmhaBwdBlockTile<64>::gemm4_warps,
FmhaBwdWarpTile2> {};

+ template <>
+ struct FmhaBwdShape<96> : ck_tile::TileFmhaBwdShape<
+ typename FmhaBwdBlockTile<96>::tile_lengths,
+ typename FmhaBwdBlockTile<96>::gemm02_warps,
+ FmhaBwdWarpTile2,
+ typename FmhaBwdBlockTile<96>::gemm13_warps,
+ FmhaBwdWarpTile3,
+ typename FmhaBwdBlockTile<96>::gemm02_warps,
+ FmhaBwdWarpTile2,
+ typename FmhaBwdBlockTile<96>::gemm13_warps,
+ FmhaBwdWarpTile3,
+ typename FmhaBwdBlockTile<96>::gemm4_warps,
+ FmhaBwdWarpTile2> {};
+
template <>
struct FmhaBwdShape<128> : ck_tile::TileFmhaBwdShape<
typename FmhaBwdBlockTile<128>::tile_lengths,
@@ -151,7 +173,7 @@ struct FmhaBwdShape<256> : ck_tile::TileFmhaBwdShape<
typename FmhaBwdBlockTile<256>::gemm4_warps,
FmhaBwdWarpTile2> {};

- template <ck_tile::index_t MaxK, bool kPadHeadDimQK, bool kPadHeadDimV>
+ template <ck_tile::index_t MaxK>
struct FmhaBwdPipelineEnumSelector {
static constexpr ck_tile::BlockFmhaBwdPipelineEnum value =
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP;
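The two new specializations above (FmhaBwdBlockTile<96> and FmhaBwdShape<96>) give the backward head-dim-96 path its own tile configuration instead of falling through to the padded 128 settings, and FmhaBwdPipelineEnumSelector is simplified to depend only on MaxK. The pattern is a compile-time lookup table keyed by head dimension; a minimal sketch, with hypothetical names and numbers rather than the actual CK definitions:

// Per-head-dim settings live in template specializations; adding head-dim 96
// support means adding one more entry, which MaxK-templated dispatch code then
// picks up at compile time. Illustrative only.
template <int MaxK>
struct BwdTileConfigSketch;  // primary template left undefined: unsupported MaxK fails to compile

template <>
struct BwdTileConfigSketch<64> {
  static constexpr int kM0 = 32;  // hypothetical tile extents
  static constexpr int kN0 = 128;
};

template <>
struct BwdTileConfigSketch<96> {  // new entry, analogous to FmhaBwdBlockTile<96>
  static constexpr int kM0 = 16;
  static constexpr int kN0 = 128;
};

static_assert(BwdTileConfigSketch<96>::kM0 == 16,
              "the 96 entry resolves at compile time");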
101 changes: 48 additions & 53 deletions xformers/csrc/attention/hip_fmha/ck_tiled_fmha_grouped_backward.h
@@ -58,7 +58,7 @@ struct grouped_backward_causalmask_bias_dropout_dispatch {
static void Run(GroupedBackwardParams& param, hipStream_t stream) {
{
constexpr ck_tile::index_t kBlockSize = 64;
- bool pad_headdim_v = !(param.Kv % FmhaBwdShape<MaxK>::kVHeaddim == 0);
+ bool pad_headdim_v = !(param.Kv % MaxK == 0);

constexpr bool kPadSeqLenQ = true;

@@ -74,7 +74,7 @@ struct grouped_backward_causalmask_bias_dropout_dispatch {
typename FmhaBwdTypeConfig<ScalarType>::OGradDataType,
typename FmhaBwdTypeConfig<ScalarType>::DDataType,
kBlockSize,
- FmhaBwdShape<MaxK>::kVHeaddim,
+ MaxK, // kVHeaddim
true, // kIsGroupMode
FmhaOGradDotOTraits_>;

@@ -111,64 +111,59 @@ struct grouped_backward_causalmask_bias_dropout_dispatch {
const bool pad_headdim_v =
!(param.Kv % FmhaBwdShape<MaxK>::kVHeaddim == 0);

- // usually headdim_q and headdim_v are same, consider them together
- // to determine whether to do padding saving some compiling time
- const bool pad_headdim = (pad_headdim_q || pad_headdim_v);
-
- BOOL_SWITCH(pad_headdim, kPadHeadDim, [&] {
- using FmhaBwdTraits_ = ck_tile::TileFmhaTraits<
- kPadSeqLenQ,
- kPadSeqLenK,
- kPadHeadDim, // kPadHeadDimQ,
- kPadHeadDim, // kPadHeadDimV,
- kBiasEnum,
- kHasBiasGrad,
- false, // kStoreLSE
- false, // place-holder for kHasDropout, not used actually
- false, // kDoFp8StaticQuant place-holder
- occupancy>;
-
- using FmhaBwdPipelineProblem =
- FmhaBwdPipelineProblemTemp<FmhaBwdTraits_, FmhaMask>;
-
- constexpr auto FmhaBwdPipelineEnum_ =
- FmhaBwdPipelineEnumSelector<MaxK, kPadHeadDim, kPadHeadDim>::
- value;
-
- using FmhaBwdPipeline_ = typename FmhaBwdPipelineMaker<
- FmhaBwdPipelineEnum_,
- FmhaBwdPipelineProblem>::pipeline;
-
- using FmhaBwdKGradEpilogue_ =
- ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<
- typename FmhaBwdTypeConfig<ScalarType>::AccDataType,
- typename FmhaBwdTypeConfig<ScalarType>::KGradDataType,
+ BOOL_SWITCH_2(
+ pad_headdim_q, kPadHeadDimQ, pad_headdim_v, kPadHeadDimV, [&] {
+ using FmhaBwdTraits_ = ck_tile::TileFmhaTraits<
+ kPadSeqLenQ,
kPadSeqLenK,
- kPadHeadDim>>;
-
- using FmhaBwdVGradEpilogue_ =
- ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<
- typename FmhaBwdTypeConfig<ScalarType>::AccDataType,
- typename FmhaBwdTypeConfig<ScalarType>::VGradDataType,
- kPadSeqLenK,
- kPadHeadDim>>;
-
- using FmhaBwdDQDKDVKernel_ = ck_tile::FmhaBwdDQDKDVKernel<
- FmhaBwdPipeline_,
- FmhaBwdKGradEpilogue_,
- FmhaBwdVGradEpilogue_>;
-
- RunWithBwdDQDKDVKernel<FmhaBwdDQDKDVKernel_>(param, stream);
- });
+ kPadHeadDimQ,
+ kPadHeadDimV,
+ kBiasEnum,
+ kHasBiasGrad,
+ false, // kStoreLSE
+ false, // place-holder for kHasDropout, not used actually
+ false, // kDoFp8StaticQuant place-holder
+ occupancy>;
+
+ using FmhaBwdPipelineProblem =
+ FmhaBwdPipelineProblemTemp<FmhaBwdTraits_, FmhaMask>;
+
+ constexpr auto FmhaBwdPipelineEnum_ =
+ FmhaBwdPipelineEnumSelector<MaxK>::value;
+
+ using FmhaBwdPipeline_ = typename FmhaBwdPipelineMaker<
+ FmhaBwdPipelineEnum_,
+ FmhaBwdPipelineProblem>::pipeline;
+
+ using FmhaBwdKGradEpilogue_ =
+ ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<
+ typename FmhaBwdTypeConfig<ScalarType>::AccDataType,
+ typename FmhaBwdTypeConfig<ScalarType>::KGradDataType,
+ kPadSeqLenK,
+ kPadHeadDimQ>>;
+
+ using FmhaBwdVGradEpilogue_ =
+ ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<
+ typename FmhaBwdTypeConfig<ScalarType>::AccDataType,
+ typename FmhaBwdTypeConfig<ScalarType>::VGradDataType,
+ kPadSeqLenK,
+ kPadHeadDimV>>;
+
+ using FmhaBwdDQDKDVKernel_ = ck_tile::FmhaBwdDQDKDVKernel<
+ FmhaBwdPipeline_,
+ FmhaBwdKGradEpilogue_,
+ FmhaBwdVGradEpilogue_>;
+
+ RunWithBwdDQDKDVKernel<FmhaBwdDQDKDVKernel_>(param, stream);
+ });
});
};

if constexpr (NeedConvertGradQ) {
constexpr ck_tile::index_t kBlockSize = 128;

const bool pad_seqlen_q = true;
- const bool pad_headdim_q =
-     !(param.K % FmhaBwdShape<MaxK>::kQKHeaddim == 0);
+ const bool pad_headdim_q = !(param.K % MaxK == 0);

BOOL_SWITCH_2(
pad_seqlen_q, kPadSeqLenQ, pad_headdim_q, kPadHeadDimQ, [&] {
@@ -187,7 +182,7 @@ struct grouped_backward_causalmask_bias_dropout_dispatch {
kBlockSize,
64, // kM0
1, // kN0, no use
- FmhaBwdShape<MaxK>::kQKHeaddim,
+ MaxK, // kQKHeaddim
true, // kIsGroupMode
false, // kIsDeterministic
FmhaBwdConvertQGradTraits_>;
6 changes: 6 additions & 0 deletions xformers/csrc/attention/hip_fmha/ck_tiled_headdim_switch.h
@@ -39,6 +39,9 @@
} else if (HEAD_DIM1 <= 64 && HEAD_DIM2 <= 64) { \
constexpr ck_tile::index_t CONST_NAME = 64; \
__VA_ARGS__(); \
+ } else if (HEAD_DIM1 <= 96 && HEAD_DIM2 <= 96) { \
+ constexpr ck_tile::index_t CONST_NAME = 96; \
+ __VA_ARGS__(); \
} else if (HEAD_DIM1 <= 128 && HEAD_DIM2 <= 128) { \
constexpr ck_tile::index_t CONST_NAME = 128; \
__VA_ARGS__(); \
@@ -76,6 +79,9 @@
} else if (HEAD_DIM1 <= 64 && HEAD_DIM2 <= 64) { \
constexpr ck_tile::index_t CONST_NAME = 64; \
__VA_ARGS__(); \
+ } else if (HEAD_DIM1 <= 96 && HEAD_DIM2 <= 96) { \
+ constexpr ck_tile::index_t CONST_NAME = 96; \
+ __VA_ARGS__(); \
} else if (HEAD_DIM1 <= 128 && HEAD_DIM2 <= 128) { \
constexpr ck_tile::index_t CONST_NAME = 128; \
__VA_ARGS__(); \
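These two macro hunks add a 96 bucket to the head-dimension switch, so head dims in (64, 96] now map to a compile-time MaxK of 96 instead of falling into the 128 branch. A minimal sketch of the same dispatch ladder written as a function template — illustrative only, the actual macros take the HEAD_DIM1/HEAD_DIM2/CONST_NAME arguments shown above:

#include <stdexcept>
#include <type_traits>

// Map two runtime head dimensions to the smallest supported compile-time
// MaxK bucket, mirroring the if/else ladder in the macros above.
template <typename F>
void dispatch_max_headdim(int hd_q, int hd_v, F&& f) {
  if (hd_q <= 32 && hd_v <= 32)
    f(std::integral_constant<int, 32>{});
  else if (hd_q <= 64 && hd_v <= 64)
    f(std::integral_constant<int, 64>{});
  else if (hd_q <= 96 && hd_v <= 96)
    f(std::integral_constant<int, 96>{});  // branch added by this commit
  else if (hd_q <= 128 && hd_v <= 128)
    f(std::integral_constant<int, 128>{});
  else if (hd_q <= 256 && hd_v <= 256)
    f(std::integral_constant<int, 256>{});
  else
    throw std::runtime_error("unsupported head dimension");
}

Before this change a backward pass with head dim 96 landed in the 128 bucket and ran padded kernels; with the new branch, and the maxk_96 backward instances generated by the script in the next file, it can presumably use the tighter head-dim-96 configuration instead.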
19 changes: 11 additions & 8 deletions xformers/csrc/attention/hip_fmha/generate_instances.py
@@ -106,6 +106,7 @@
INT_MAP_MAX_K = {
32: "maxk_32",
64: "maxk_64",
+ 96: "maxk_96",
128: "maxk_128",
256: "maxk_256",
}
@@ -368,9 +369,11 @@ def create_backward_instances_ref(instance_dir: Path, headdims: List) -> None:
disable_hd256 = True

if disable_hd256:
- headdims = [32, 64, 128]
+ headdims_fwd = [32, 64, 128]
+ headdims_bwd = [32, 64, 96, 128]
else:
- headdims = [32, 64, 128, 256]
+ headdims_fwd = [32, 64, 128, 256]
+ headdims_bwd = [32, 64, 96, 128, 256]

this_dir = os.path.dirname(__file__)
output_dir = Path(this_dir) / "instances"
@@ -382,9 +385,9 @@ def create_backward_instances_ref(instance_dir: Path, headdims: List) -> None:
file_path = os.path.join(output_dir, ff)
os.remove(file_path)

- create_infer_instances(output_dir, headdims)
- create_infer_instances_ref(output_dir, headdims)
- create_forward_instances(output_dir, headdims)
- create_forward_instances_ref(output_dir, headdims)
- create_backward_instances(output_dir, headdims)
- create_backward_instances_ref(output_dir, headdims)
+ create_infer_instances(output_dir, headdims_fwd)
+ create_infer_instances_ref(output_dir, headdims_fwd)
+ create_forward_instances(output_dir, headdims_fwd)
+ create_forward_instances_ref(output_dir, headdims_fwd)
+ create_backward_instances(output_dir, headdims_bwd)
+ create_backward_instances_ref(output_dir, headdims_bwd)