Skip to content

Commit

Permalink
Update configuration for gemm on AMD GPUs
Browse files Browse the repository at this point in the history
Following preliminary investigation and tuning with the auto tuner, these are the
new configurations for gemm that provide the best performance.
The selection of the configuration is now based on the arithmetic
intensity, and not only on the _N and _M dimensions. Apart from the general
gemm, no other implementations are affected.
  • Loading branch information
s-Nick committed Jan 23, 2024
1 parent 89bc647 commit 1217acc
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 6 deletions.
33 changes: 30 additions & 3 deletions cmake/CmakeFunctionHelper.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -495,13 +495,12 @@ elseif(${TUNING_TARGET} STREQUAL "AMD_GPU") # need investigation
set(twr "${workgroup_${data}}")
set(twc "${workgroup_${data}}")

add_gemm_configuration(
"${data}" 256 "false" "false" "false"
64 1 1 ${twr} ${twc} 1 1 1 1 1 1 1 1 1 float float "local" "standard" "full" 1 "strided" "false")
# General configuration
add_gemm_configuration(
"${data}" 256 "false" "false" "false"
64 4 4 ${twr} ${twc} 1 1 1 1 1 1 1 1 1 float float "local" "standard" "full" 2 "strided" "false")

# configuration for tall_skinny
add_gemm_configuration(
"${data}" 256 "true" "true" "true"
64 1 1 ${twr} ${twc} 1 1 1 1 1 1 1 1 1 float float "local" "tall_skinny" "none" 2 "strided" "false")
Expand All @@ -518,9 +517,37 @@ elseif(${TUNING_TARGET} STREQUAL "AMD_GPU") # need investigation
"${data}" 256 "true" "true" "true"
64 4 1 ${twr} ${twc} 1 1 1 1 1 1 1 1 1 float float "local" "tall_skinny" "none" 2 "strided" "false")

# configuration for batch
add_gemm_configuration(
"${data}" 64 "false" "false" "false"
64 4 4 4 4 1 1 1 1 4 4 1 1 1 float float "no_local" "standard" "full" 4 "interleaved" "false")

# Configurations for gemm

# low arithmetic intensity
add_gemm_configuration(
"${data}" 256 "false" "false" "true"
128 1 1 16 8 1 1 1 1 1 1 1 1 1 float float "local" "standard" "full" 1 "strided" "false")
add_gemm_configuration(
"${data}" 256 "false" "false" "true"
64 4 8 16 16 1 1 1 1 1 1 1 1 1 float float "local" "standard" "full" 1 "strided" "false")
# highest arithmetic intensity
add_gemm_configuration(
"${data}" 256 "false" "false" "true"
32 8 8 16 16 1 1 1 1 1 1 1 1 1 float float "local" "standard" "full" 1 "strided" "false")
# high arithmetic intensity
add_gemm_configuration(
"${data}" 256 "false" "false" "true"
64 4 4 16 8 1 1 1 1 1 1 1 1 1 float float "local" "standard" "full" 1 "strided" "false")
# mid high 162 < a < 240
add_gemm_configuration(
"${data}" 256 "false" "false" "true"
128 4 4 16 8 1 1 1 1 1 1 1 1 1 float float "local" "standard" "full" 1 "strided" "false")
# mid low 100 < a < 162
add_gemm_configuration(
"${data}" 256 "false" "true" "true"
128 2 2 16 8 1 1 1 1 1 1 1 1 1 float float "local" "standard" "full" 1 "strided" "false")

endforeach()
if(BLAS_ENABLE_COMPLEX)
# Extract list of complex<data> for each data in supported_types
Expand Down
75 changes: 72 additions & 3 deletions src/interface/blas3/backend/amd_gpu.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,13 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K,
if constexpr (s_a && s_b || ((s_a && _t_b) || (s_b && _t_a))) {
return _dependencies;
} else {
// computing arithmetic intensity from a combination of the inputs, to use it
// as a heuristic: the numerator is the number of fma operations, the
// denominator is the number of bytes accessed.
const auto n_fma = (static_cast<int64_t>(_M) * static_cast<int64_t>(_K) *
static_cast<int64_t>(_N));
const auto n_bytes_access = (_M * _K + _K * _N + _M * _N);
const auto arith_intensity = n_fma / n_bytes_access;
static constexpr int ClSize = 64;
static constexpr int tileWgSize = ClSize / sizeof(element_t);
if (batch_type == gemm_batch_type_t::interleaved) {
Expand All @@ -59,6 +66,7 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K,
_stridea, _b, _ldb, _strideb, _beta, _c, _ldc,
_stridec, batch_size, _dependencies);
}

/* Tall & Skinny matrices. */
#ifdef GEMM_TALL_SKINNY_SUPPORT
if (batch_size == 1 &&
Expand Down Expand Up @@ -123,18 +131,79 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K,
}
} else
#endif // GEMM_TALL_SKINNY_SUPPORT
if (_M * _N <= 65536) {
// The following configurations were obtained with the auto tuner on amd-mi210
// and are selected according to their arith_intensity or to the ratio between
// the _N and _K input sizes
if (arith_intensity >= 360 || (_N >> 4) > _K) {
return blas::Gemm_Launcher<
container_0_t, container_1_t, container_2_t, 256, false, false,
false, ClSize, Tile<1, 1, tileWgSize, tileWgSize>, _t_a, _t_b, s_a,
s_b, static_cast<int>(gemm_memory_t::local),
true, 32, Tile<8, 8, 16, 16>, _t_a, _t_b, s_a, s_b,
static_cast<int>(gemm_memory_t::local),
static_cast<int>(gemm_algorithm_t::standard),
static_cast<int>(gemm_vectorization_t::full), is_beta_zero, 1,
static_cast<int>(gemm_batch_type_t::strided)>::
template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda,
_stridea, _b, _ldb, _strideb, _beta, _c, _ldc,
_stridec, batch_size, _dependencies);
} else if (arith_intensity >= 240) {
return blas::Gemm_Launcher<
container_0_t, container_1_t, container_2_t, 256, false, false,
true, 64, Tile<4, 4, 16, 8>, _t_a, _t_b, s_a, s_b,
static_cast<int>(gemm_memory_t::local),
static_cast<int>(gemm_algorithm_t::standard),
static_cast<int>(gemm_vectorization_t::full), is_beta_zero, 1,
static_cast<int>(gemm_batch_type_t::strided)>::
template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda,
_stridea, _b, _ldb, _strideb, _beta, _c, _ldc,
_stridec, batch_size, _dependencies);
} else if (arith_intensity > 162) {
return blas::Gemm_Launcher<
container_0_t, container_1_t, container_2_t, 256, false, false,
true, 128, Tile<4, 4, 16, 8>, _t_a, _t_b, s_a, s_b,
static_cast<int>(gemm_memory_t::local),
static_cast<int>(gemm_algorithm_t::standard),
static_cast<int>(gemm_vectorization_t::full), is_beta_zero, 1,
static_cast<int>(gemm_batch_type_t::strided)>::
template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda,
_stridea, _b, _ldb, _strideb, _beta, _c, _ldc,
_stridec, batch_size, _dependencies);
} else if (arith_intensity >= 100 && arith_intensity <= 162) {
return blas::Gemm_Launcher<
container_0_t, container_1_t, container_2_t, 256, false, true, true,
128, Tile<2, 2, 16, 8>, _t_a, _t_b, s_a, s_b,
static_cast<int>(gemm_memory_t::local),
static_cast<int>(gemm_algorithm_t::standard),
static_cast<int>(gemm_vectorization_t::full), is_beta_zero, 1,
static_cast<int>(gemm_batch_type_t::strided)>::
template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda,
_stridea, _b, _ldb, _strideb, _beta, _c, _ldc,
_stridec, batch_size, _dependencies);
} else if (arith_intensity <= 100) {
if ((_N >> 4) > _K) {
return blas::Gemm_Launcher<
container_0_t, container_1_t, container_2_t, 256, false, false,
true, 64, Tile<4, 8, 16, 16>, _t_a, _t_b, s_a, s_b,
static_cast<int>(gemm_memory_t::local),
static_cast<int>(gemm_algorithm_t::standard),
static_cast<int>(gemm_vectorization_t::full), is_beta_zero, 1,
static_cast<int>(gemm_batch_type_t::strided)>::
template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda,
_stridea, _b, _ldb, _strideb, _beta, _c,
_ldc, _stridec, batch_size, _dependencies);
} else {
return blas::Gemm_Launcher<
container_0_t, container_1_t, container_2_t, 256, false, false,
true, 128, Tile<1, 1, 16, 8>, _t_a, _t_b, s_a, s_b,
static_cast<int>(gemm_memory_t::local),
static_cast<int>(gemm_algorithm_t::standard),
static_cast<int>(gemm_vectorization_t::full), is_beta_zero, 1,
static_cast<int>(gemm_batch_type_t::strided)>::
template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda,
_stridea, _b, _ldb, _strideb, _beta, _c,
_ldc, _stridec, batch_size, _dependencies);
}
} else {
// this branch is a safe net just in case no other branch is taken
return blas::Gemm_Launcher<
container_0_t, container_1_t, container_2_t, 256, false, false,
false, ClSize, Tile<4, 4, tileWgSize, tileWgSize>, _t_a, _t_b, s_a,
Expand Down

0 comments on commit 1217acc

Please sign in to comment.