Skip to content

Commit

Permalink
Update configuration for GEMM on AMD GPUs (#494)
Browse files Browse the repository at this point in the history
  • Loading branch information
s-Nick authored Feb 6, 2024
1 parent 9d515e3 commit cb69d68
Show file tree
Hide file tree
Showing 2 changed files with 113 additions and 6 deletions.
33 changes: 30 additions & 3 deletions cmake/CmakeFunctionHelper.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -495,13 +495,12 @@ elseif(${TUNING_TARGET} STREQUAL "AMD_GPU") # need investigation
set(twr "${workgroup_${data}}")
set(twc "${workgroup_${data}}")

add_gemm_configuration(
"${data}" 256 "false" "false" "false"
64 1 1 ${twr} ${twc} 1 1 1 1 1 1 1 1 1 float float "local" "standard" "full" 1 "strided" "false")
# General configuration
add_gemm_configuration(
"${data}" 256 "false" "false" "false"
64 4 4 ${twr} ${twc} 1 1 1 1 1 1 1 1 1 float float "local" "standard" "full" 2 "strided" "false")

# configuration for tall_skinny
add_gemm_configuration(
"${data}" 256 "true" "true" "true"
64 1 1 ${twr} ${twc} 1 1 1 1 1 1 1 1 1 float float "local" "tall_skinny" "none" 2 "strided" "false")
Expand All @@ -518,9 +517,37 @@ elseif(${TUNING_TARGET} STREQUAL "AMD_GPU") # need investigation
"${data}" 256 "true" "true" "true"
64 4 1 ${twr} ${twc} 1 1 1 1 1 1 1 1 1 float float "local" "tall_skinny" "none" 2 "strided" "false")

# configuration for batch
add_gemm_configuration(
"${data}" 64 "false" "false" "false"
64 4 4 4 4 1 1 1 1 4 4 1 1 1 float float "no_local" "standard" "full" 4 "interleaved" "false")

# Configurations for gemm

# low arithmetic intensity
add_gemm_configuration(
"${data}" 256 "false" "false" "true"
128 1 1 16 8 1 1 1 1 1 1 1 1 1 float float "local" "standard" "full" 1 "strided" "false")
add_gemm_configuration(
"${data}" 256 "false" "false" "true"
64 4 8 16 16 1 1 1 1 1 1 1 1 1 float float "local" "standard" "full" 1 "strided" "false")
# highest arithmetic intensity
add_gemm_configuration(
"${data}" 256 "false" "false" "true"
32 8 8 16 16 1 1 1 1 1 1 1 1 1 float float "local" "standard" "full" 1 "strided" "false")
# high arithmetic intensity
add_gemm_configuration(
"${data}" 256 "false" "false" "true"
64 4 4 16 8 1 1 1 1 1 1 1 1 1 float float "local" "standard" "full" 1 "strided" "false")
# mid high 162 < a < 240
add_gemm_configuration(
"${data}" 256 "false" "false" "true"
128 4 4 16 8 1 1 1 1 1 1 1 1 1 float float "local" "standard" "full" 1 "strided" "false")
# mid low 100 < a < 162
add_gemm_configuration(
"${data}" 256 "false" "true" "true"
128 2 2 16 8 1 1 1 1 1 1 1 1 1 float float "local" "standard" "full" 1 "strided" "false")

endforeach()
if(BLAS_ENABLE_COMPLEX)
# Extract list of complex<data> for each data in supported_types
Expand Down
86 changes: 83 additions & 3 deletions src/interface/blas3/backend/amd_gpu.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,13 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K,
if constexpr (s_a && s_b || ((s_a && _t_b) || (s_b && _t_a))) {
return _dependencies;
} else {
// computing arithmetic ratio with combination of input to use it as
// heuristic numerator is the number of fma, denominator is the number of
// bytes access.
const auto n_fma = (static_cast<int64_t>(_M) * static_cast<int64_t>(_K) *
static_cast<int64_t>(_N));
const auto n_elem_access = (_M * _K + _K * _N + _M * _N);
const auto arith_ratio = n_fma / n_elem_access;
static constexpr int ClSize = 64;
static constexpr int tileWgSize = ClSize / sizeof(element_t);
if (batch_type == gemm_batch_type_t::interleaved) {
Expand All @@ -59,6 +66,7 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K,
_stridea, _b, _ldb, _strideb, _beta, _c, _ldc,
_stridec, batch_size, _dependencies);
}

/* Tall & Skinny matrices. */
#ifdef GEMM_TALL_SKINNY_SUPPORT
if (batch_size == 1 &&
Expand Down Expand Up @@ -123,18 +131,90 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K,
}
} else
#endif // GEMM_TALL_SKINNY_SUPPORT
if (_M * _N <= 65536) {
// Following configurations are taken using the auto tuner on amd-mi210
// and divided following their arith_ratio or another ratio between _N
// and _K input size
if ((_N >> 4) > _K) {
if (arith_ratio <= 100) {
return blas::Gemm_Launcher<
container_0_t, container_1_t, container_2_t, 256, false, false,
true, 64, Tile<4, 8, 16, 16>, _t_a, _t_b, s_a, s_b,
static_cast<int>(gemm_memory_t::local),
static_cast<int>(gemm_algorithm_t::standard),
static_cast<int>(gemm_vectorization_t::full), is_beta_zero, 1,
static_cast<int>(gemm_batch_type_t::strided)>::
template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda,
_stridea, _b, _ldb, _strideb, _beta, _c,
_ldc, _stridec, batch_size, _dependencies);
} else {
return blas::Gemm_Launcher<
container_0_t, container_1_t, container_2_t, 256, false, false,
true, 32, Tile<8, 8, 16, 16>, _t_a, _t_b, s_a, s_b,
static_cast<int>(gemm_memory_t::local),
static_cast<int>(gemm_algorithm_t::standard),
static_cast<int>(gemm_vectorization_t::full), is_beta_zero, 1,
static_cast<int>(gemm_batch_type_t::strided)>::
template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda,
_stridea, _b, _ldb, _strideb, _beta, _c,
_ldc, _stridec, batch_size, _dependencies);
}
} else if (arith_ratio >= 360) {
return blas::Gemm_Launcher<
container_0_t, container_1_t, container_2_t, 256, false, false,
false, ClSize, Tile<1, 1, tileWgSize, tileWgSize>, _t_a, _t_b, s_a,
s_b, static_cast<int>(gemm_memory_t::local),
true, 32, Tile<8, 8, 16, 16>, _t_a, _t_b, s_a, s_b,
static_cast<int>(gemm_memory_t::local),
static_cast<int>(gemm_algorithm_t::standard),
static_cast<int>(gemm_vectorization_t::full), is_beta_zero, 1,
static_cast<int>(gemm_batch_type_t::strided)>::
template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda,
_stridea, _b, _ldb, _strideb, _beta, _c, _ldc,
_stridec, batch_size, _dependencies);
} else if (arith_ratio >= 240) {
return blas::Gemm_Launcher<
container_0_t, container_1_t, container_2_t, 256, false, false,
true, 64, Tile<4, 4, 16, 8>, _t_a, _t_b, s_a, s_b,
static_cast<int>(gemm_memory_t::local),
static_cast<int>(gemm_algorithm_t::standard),
static_cast<int>(gemm_vectorization_t::full), is_beta_zero, 1,
static_cast<int>(gemm_batch_type_t::strided)>::
template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda,
_stridea, _b, _ldb, _strideb, _beta, _c, _ldc,
_stridec, batch_size, _dependencies);
} else if (arith_ratio > 162) {
return blas::Gemm_Launcher<
container_0_t, container_1_t, container_2_t, 256, false, false,
true, 128, Tile<4, 4, 16, 8>, _t_a, _t_b, s_a, s_b,
static_cast<int>(gemm_memory_t::local),
static_cast<int>(gemm_algorithm_t::standard),
static_cast<int>(gemm_vectorization_t::full), is_beta_zero, 1,
static_cast<int>(gemm_batch_type_t::strided)>::
template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda,
_stridea, _b, _ldb, _strideb, _beta, _c, _ldc,
_stridec, batch_size, _dependencies);
} else if (arith_ratio >= 100 && arith_ratio <= 162) {
return blas::Gemm_Launcher<
container_0_t, container_1_t, container_2_t, 256, false, true, true,
128, Tile<2, 2, 16, 8>, _t_a, _t_b, s_a, s_b,
static_cast<int>(gemm_memory_t::local),
static_cast<int>(gemm_algorithm_t::standard),
static_cast<int>(gemm_vectorization_t::full), is_beta_zero, 1,
static_cast<int>(gemm_batch_type_t::strided)>::
template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda,
_stridea, _b, _ldb, _strideb, _beta, _c, _ldc,
_stridec, batch_size, _dependencies);
} else if (arith_ratio <= 100) {
return blas::Gemm_Launcher<
container_0_t, container_1_t, container_2_t, 256, false, false,
true, 128, Tile<1, 1, 16, 8>, _t_a, _t_b, s_a, s_b,
static_cast<int>(gemm_memory_t::local),
static_cast<int>(gemm_algorithm_t::standard),
static_cast<int>(gemm_vectorization_t::full), is_beta_zero, 1,
static_cast<int>(gemm_batch_type_t::strided)>::
template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda,
_stridea, _b, _ldb, _strideb, _beta, _c, _ldc,
_stridec, batch_size, _dependencies);
} else {
// this branch is a safe net just in case no other branch is taken
return blas::Gemm_Launcher<
container_0_t, container_1_t, container_2_t, 256, false, false,
false, ClSize, Tile<4, 4, tileWgSize, tileWgSize>, _t_a, _t_b, s_a,
Expand Down

0 comments on commit cb69d68

Please sign in to comment.