From 3f68197fa011d3088560d23769d906aecb339913 Mon Sep 17 00:00:00 2001 From: nscipione Date: Tue, 23 Jan 2024 11:26:23 +0100 Subject: [PATCH 1/4] Update configuration for gemm on AMD GPUs Following preliminary investigation and tuning with the auto tuner, these are the new configurations for gemm that provide the best performance. The selection of the configuration is now based on the arithmetic intensity and not only on the _N and _M dimensions. Apart from general gemm, no other implementations are affected. --- cmake/CmakeFunctionHelper.cmake | 33 ++++++++++- src/interface/blas3/backend/amd_gpu.hpp | 75 ++++++++++++++++++++++++- 2 files changed, 102 insertions(+), 6 deletions(-) diff --git a/cmake/CmakeFunctionHelper.cmake b/cmake/CmakeFunctionHelper.cmake index cd84d38cd..beae299e1 100644 --- a/cmake/CmakeFunctionHelper.cmake +++ b/cmake/CmakeFunctionHelper.cmake @@ -495,13 +495,12 @@ elseif(${TUNING_TARGET} STREQUAL "AMD_GPU") # need investigation set(twr "${workgroup_${data}}") set(twc "${workgroup_${data}}") - add_gemm_configuration( - "${data}" 256 "false" "false" "false" - 64 1 1 ${twr} ${twc} 1 1 1 1 1 1 1 1 1 float float "local" "standard" "full" 1 "strided" "false") + # General configuration add_gemm_configuration( "${data}" 256 "false" "false" "false" 64 4 4 ${twr} ${twc} 1 1 1 1 1 1 1 1 1 float float "local" "standard" "full" 2 "strided" "false") + # configuration for tall_skinny add_gemm_configuration( "${data}" 256 "true" "true" "true" 64 1 1 ${twr} ${twc} 1 1 1 1 1 1 1 1 1 float float "local" "tall_skinny" "none" 2 "strided" "false") @@ -518,9 +517,37 @@ elseif(${TUNING_TARGET} STREQUAL "AMD_GPU") # need investigation "${data}" 256 "true" "true" "true" 64 4 1 ${twr} ${twc} 1 1 1 1 1 1 1 1 1 float float "local" "tall_skinny" "none" 2 "strided" "false") + # configuration for batch add_gemm_configuration( "${data}" 64 "false" "false" "false" 64 4 4 4 4 1 1 1 1 4 4 1 1 1 float float "no_local" "standard" "full" 4 "interleaved" "false") + + # Configurations for 
gemm + + # low arithmetic intensity + add_gemm_configuration( + "${data}" 256 "false" "false" "true" + 128 1 1 16 8 1 1 1 1 1 1 1 1 1 float float "local" "standard" "full" 1 "strided" "false") + add_gemm_configuration( + "${data}" 256 "false" "false" "true" + 64 4 8 16 16 1 1 1 1 1 1 1 1 1 float float "local" "standard" "full" 1 "strided" "false") + # highest arithmetic intensity + add_gemm_configuration( + "${data}" 256 "false" "false" "true" + 32 8 8 16 16 1 1 1 1 1 1 1 1 1 float float "local" "standard" "full" 1 "strided" "false") + # high arithmetic intensity + add_gemm_configuration( + "${data}" 256 "false" "false" "true" + 64 4 4 16 8 1 1 1 1 1 1 1 1 1 float float "local" "standard" "full" 1 "strided" "false") + # mid high 162 < a < 240 + add_gemm_configuration( + "${data}" 256 "false" "false" "true" + 128 4 4 16 8 1 1 1 1 1 1 1 1 1 float float "local" "standard" "full" 1 "strided" "false") + # mid low 100 < a < 162 + add_gemm_configuration( + "${data}" 256 "false" "true" "true" + 128 2 2 16 8 1 1 1 1 1 1 1 1 1 float float "local" "standard" "full" 1 "strided" "false") + endforeach() if(BLAS_ENABLE_COMPLEX) # Extract list of complex for each data in supported_types diff --git a/src/interface/blas3/backend/amd_gpu.hpp b/src/interface/blas3/backend/amd_gpu.hpp index 9d6ffa424..52527d145 100644 --- a/src/interface/blas3/backend/amd_gpu.hpp +++ b/src/interface/blas3/backend/amd_gpu.hpp @@ -45,6 +45,13 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, if constexpr (s_a && s_b || ((s_a && _t_b) || (s_b && _t_a))) { return _dependencies; } else { + // computing arithmetic intensity with combination of input to use it as + // heuristic numerator is the number of fma, denominator is the number of + // bytes access. 
+ const auto n_fma = (static_cast(_M) * static_cast(_K) * + static_cast(_N)); + const auto n_bytes_access = (_M * _K + _K * _N + _M * _N); + const auto arith_intensity = n_fma / n_bytes_access; static constexpr int ClSize = 64; static constexpr int tileWgSize = ClSize / sizeof(element_t); if (batch_type == gemm_batch_type_t::interleaved) { @@ -59,6 +66,7 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, _stridea, _b, _ldb, _strideb, _beta, _c, _ldc, _stridec, batch_size, _dependencies); } + /* Tall & Skinny matrices. */ #ifdef GEMM_TALL_SKINNY_SUPPORT if (batch_size == 1 && @@ -123,18 +131,79 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, } } else #endif // GEMM_TALL_SKINNY_SUPPORT - if (_M * _N <= 65536) { + // Folling configuration are taken using the auto tuner on amd-mi210 + // and divided following their arith_intensity or another ratio between _N + // and _K input size + if (arith_intensity >= 360 || (_N >> 4) > _K) { return blas::Gemm_Launcher< container_0_t, container_1_t, container_2_t, 256, false, false, - false, ClSize, Tile<1, 1, tileWgSize, tileWgSize>, _t_a, _t_b, s_a, - s_b, static_cast(gemm_memory_t::local), + true, 32, Tile<8, 8, 16, 16>, _t_a, _t_b, s_a, s_b, + static_cast(gemm_memory_t::local), + static_cast(gemm_algorithm_t::standard), + static_cast(gemm_vectorization_t::full), is_beta_zero, 1, + static_cast(gemm_batch_type_t::strided)>:: + template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, + _stridea, _b, _ldb, _strideb, _beta, _c, _ldc, + _stridec, batch_size, _dependencies); + } else if (arith_intensity >= 240) { + return blas::Gemm_Launcher< + container_0_t, container_1_t, container_2_t, 256, false, false, + true, 64, Tile<4, 4, 16, 8>, _t_a, _t_b, s_a, s_b, + static_cast(gemm_memory_t::local), + static_cast(gemm_algorithm_t::standard), + static_cast(gemm_vectorization_t::full), is_beta_zero, 1, + static_cast(gemm_batch_type_t::strided)>:: + template _select_gemm(sb_handle, _M, _N, _K, 
_alpha, _a, _lda, + _stridea, _b, _ldb, _strideb, _beta, _c, _ldc, + _stridec, batch_size, _dependencies); + } else if (arith_intensity > 162) { + return blas::Gemm_Launcher< + container_0_t, container_1_t, container_2_t, 256, false, false, + true, 128, Tile<4, 4, 16, 8>, _t_a, _t_b, s_a, s_b, + static_cast(gemm_memory_t::local), + static_cast(gemm_algorithm_t::standard), + static_cast(gemm_vectorization_t::full), is_beta_zero, 1, + static_cast(gemm_batch_type_t::strided)>:: + template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, + _stridea, _b, _ldb, _strideb, _beta, _c, _ldc, + _stridec, batch_size, _dependencies); + } else if (arith_intensity >= 100 && arith_intensity <= 162) { + return blas::Gemm_Launcher< + container_0_t, container_1_t, container_2_t, 256, false, true, true, + 128, Tile<2, 2, 16, 8>, _t_a, _t_b, s_a, s_b, + static_cast(gemm_memory_t::local), static_cast(gemm_algorithm_t::standard), static_cast(gemm_vectorization_t::full), is_beta_zero, 1, static_cast(gemm_batch_type_t::strided)>:: template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, _stridea, _b, _ldb, _strideb, _beta, _c, _ldc, _stridec, batch_size, _dependencies); + } else if (arith_intensity <= 100) { + if ((_N >> 4) > _K) { + return blas::Gemm_Launcher< + container_0_t, container_1_t, container_2_t, 256, false, false, + true, 64, Tile<4, 8, 16, 16>, _t_a, _t_b, s_a, s_b, + static_cast(gemm_memory_t::local), + static_cast(gemm_algorithm_t::standard), + static_cast(gemm_vectorization_t::full), is_beta_zero, 1, + static_cast(gemm_batch_type_t::strided)>:: + template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, + _stridea, _b, _ldb, _strideb, _beta, _c, + _ldc, _stridec, batch_size, _dependencies); + } else { + return blas::Gemm_Launcher< + container_0_t, container_1_t, container_2_t, 256, false, false, + true, 128, Tile<1, 1, 16, 8>, _t_a, _t_b, s_a, s_b, + static_cast(gemm_memory_t::local), + static_cast(gemm_algorithm_t::standard), + 
static_cast(gemm_vectorization_t::full), is_beta_zero, 1, + static_cast(gemm_batch_type_t::strided)>:: + template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, + _stridea, _b, _ldb, _strideb, _beta, _c, + _ldc, _stridec, batch_size, _dependencies); + } } else { + // this branch is a safe net just in case no other branch is taken return blas::Gemm_Launcher< container_0_t, container_1_t, container_2_t, 256, false, false, false, ClSize, Tile<4, 4, tileWgSize, tileWgSize>, _t_a, _t_b, s_a, From 94d570479732c1c9ab67aab60b4a9567bd65c554 Mon Sep 17 00:00:00 2001 From: nscipione Date: Tue, 30 Jan 2024 10:21:20 +0100 Subject: [PATCH 2/4] Update gemm configuration selection following PR comments --- src/interface/blas3/backend/amd_gpu.hpp | 59 +++++++++++++++---------- 1 file changed, 35 insertions(+), 24 deletions(-) diff --git a/src/interface/blas3/backend/amd_gpu.hpp b/src/interface/blas3/backend/amd_gpu.hpp index 52527d145..42666438a 100644 --- a/src/interface/blas3/backend/amd_gpu.hpp +++ b/src/interface/blas3/backend/amd_gpu.hpp @@ -134,7 +134,31 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, // Folling configuration are taken using the auto tuner on amd-mi210 // and divided following their arith_intensity or another ratio between _N // and _K input size - if (arith_intensity >= 360 || (_N >> 4) > _K) { + if ((_N >> 4) > _K) { + if (arith_intensity <= 100) { + return blas::Gemm_Launcher< + container_0_t, container_1_t, container_2_t, 256, false, false, + true, 64, Tile<4, 8, 16, 16>, _t_a, _t_b, s_a, s_b, + static_cast(gemm_memory_t::local), + static_cast(gemm_algorithm_t::standard), + static_cast(gemm_vectorization_t::full), is_beta_zero, 1, + static_cast(gemm_batch_type_t::strided)>:: + template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, + _stridea, _b, _ldb, _strideb, _beta, _c, + _ldc, _stridec, batch_size, _dependencies); + } else { + return blas::Gemm_Launcher< + container_0_t, container_1_t, container_2_t, 256, false, 
false, + true, 32, Tile<8, 8, 16, 16>, _t_a, _t_b, s_a, s_b, + static_cast(gemm_memory_t::local), + static_cast(gemm_algorithm_t::standard), + static_cast(gemm_vectorization_t::full), is_beta_zero, 1, + static_cast(gemm_batch_type_t::strided)>:: + template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, + _stridea, _b, _ldb, _strideb, _beta, _c, + _ldc, _stridec, batch_size, _dependencies); + } + } else if (arith_intensity >= 360) { return blas::Gemm_Launcher< container_0_t, container_1_t, container_2_t, 256, false, false, true, 32, Tile<8, 8, 16, 16>, _t_a, _t_b, s_a, s_b, @@ -179,29 +203,16 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, _stridea, _b, _ldb, _strideb, _beta, _c, _ldc, _stridec, batch_size, _dependencies); } else if (arith_intensity <= 100) { - if ((_N >> 4) > _K) { - return blas::Gemm_Launcher< - container_0_t, container_1_t, container_2_t, 256, false, false, - true, 64, Tile<4, 8, 16, 16>, _t_a, _t_b, s_a, s_b, - static_cast(gemm_memory_t::local), - static_cast(gemm_algorithm_t::standard), - static_cast(gemm_vectorization_t::full), is_beta_zero, 1, - static_cast(gemm_batch_type_t::strided)>:: - template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, - _stridea, _b, _ldb, _strideb, _beta, _c, - _ldc, _stridec, batch_size, _dependencies); - } else { - return blas::Gemm_Launcher< - container_0_t, container_1_t, container_2_t, 256, false, false, - true, 128, Tile<1, 1, 16, 8>, _t_a, _t_b, s_a, s_b, - static_cast(gemm_memory_t::local), - static_cast(gemm_algorithm_t::standard), - static_cast(gemm_vectorization_t::full), is_beta_zero, 1, - static_cast(gemm_batch_type_t::strided)>:: - template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, - _stridea, _b, _ldb, _strideb, _beta, _c, - _ldc, _stridec, batch_size, _dependencies); - } + return blas::Gemm_Launcher< + container_0_t, container_1_t, container_2_t, 256, false, false, + true, 128, Tile<1, 1, 16, 8>, _t_a, _t_b, s_a, s_b, + 
static_cast(gemm_memory_t::local), + static_cast(gemm_algorithm_t::standard), + static_cast(gemm_vectorization_t::full), is_beta_zero, 1, + static_cast(gemm_batch_type_t::strided)>:: + template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, + _stridea, _b, _ldb, _strideb, _beta, _c, _ldc, + _stridec, batch_size, _dependencies); } else { // this branch is a safe net just in case no other branch is taken return blas::Gemm_Launcher< From 68c608f47c5238ffaa580802ce1674c9a5b577da Mon Sep 17 00:00:00 2001 From: nscipione Date: Tue, 30 Jan 2024 10:24:07 +0100 Subject: [PATCH 3/4] Fix typo --- src/interface/blas3/backend/amd_gpu.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/interface/blas3/backend/amd_gpu.hpp b/src/interface/blas3/backend/amd_gpu.hpp index 42666438a..c8748829d 100644 --- a/src/interface/blas3/backend/amd_gpu.hpp +++ b/src/interface/blas3/backend/amd_gpu.hpp @@ -131,7 +131,7 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, } } else #endif // GEMM_TALL_SKINNY_SUPPORT - // Folling configuration are taken using the auto tuner on amd-mi210 + // Following configurations are taken using the auto tuner on amd-mi210 // and divided following their arith_intensity or another ratio between _N // and _K input size if ((_N >> 4) > _K) { From a9c85d45a8b3417fad6fb6831e6087b1fae8a90d Mon Sep 17 00:00:00 2001 From: nscipione Date: Mon, 5 Feb 2024 14:51:06 +0100 Subject: [PATCH 4/4] Change variable name for better naming conventions --- src/interface/blas3/backend/amd_gpu.hpp | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/src/interface/blas3/backend/amd_gpu.hpp b/src/interface/blas3/backend/amd_gpu.hpp index c8748829d..8257c76ba 100644 --- a/src/interface/blas3/backend/amd_gpu.hpp +++ b/src/interface/blas3/backend/amd_gpu.hpp @@ -45,13 +45,13 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, if constexpr (s_a && s_b || ((s_a && _t_b) || (s_b && _t_a))) { 
return _dependencies; } else { - // computing arithmetic intensity with combination of input to use it as + // computing arithmetic ratio with combination of input to use it as // heuristic numerator is the number of fma, denominator is the number of // bytes access. const auto n_fma = (static_cast(_M) * static_cast(_K) * static_cast(_N)); - const auto n_bytes_access = (_M * _K + _K * _N + _M * _N); - const auto arith_intensity = n_fma / n_bytes_access; + const auto n_elem_access = (_M * _K + _K * _N + _M * _N); + const auto arith_ratio = n_fma / n_elem_access; static constexpr int ClSize = 64; static constexpr int tileWgSize = ClSize / sizeof(element_t); if (batch_type == gemm_batch_type_t::interleaved) { @@ -132,10 +132,10 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, } else #endif // GEMM_TALL_SKINNY_SUPPORT // Following configurations are taken using the auto tuner on amd-mi210 - // and divided following their arith_intensity or another ratio between _N + // and divided following their arith_ratio or another ratio between _N // and _K input size if ((_N >> 4) > _K) { - if (arith_intensity <= 100) { + if (arith_ratio <= 100) { return blas::Gemm_Launcher< container_0_t, container_1_t, container_2_t, 256, false, false, true, 64, Tile<4, 8, 16, 16>, _t_a, _t_b, s_a, s_b, @@ -158,7 +158,7 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, _stridea, _b, _ldb, _strideb, _beta, _c, _ldc, _stridec, batch_size, _dependencies); } - } else if (arith_intensity >= 360) { + } else if (arith_ratio >= 360) { return blas::Gemm_Launcher< container_0_t, container_1_t, container_2_t, 256, false, false, true, 32, Tile<8, 8, 16, 16>, _t_a, _t_b, s_a, s_b, @@ -169,7 +169,7 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, _stridea, _b, _ldb, _strideb, _beta, _c, _ldc, _stridec, batch_size, _dependencies); - } else if (arith_intensity >= 240) { + } else if 
(arith_ratio >= 240) { return blas::Gemm_Launcher< container_0_t, container_1_t, container_2_t, 256, false, false, true, 64, Tile<4, 4, 16, 8>, _t_a, _t_b, s_a, s_b, @@ -180,7 +180,7 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, _stridea, _b, _ldb, _strideb, _beta, _c, _ldc, _stridec, batch_size, _dependencies); - } else if (arith_intensity > 162) { + } else if (arith_ratio > 162) { return blas::Gemm_Launcher< container_0_t, container_1_t, container_2_t, 256, false, false, true, 128, Tile<4, 4, 16, 8>, _t_a, _t_b, s_a, s_b, @@ -191,7 +191,7 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, _stridea, _b, _ldb, _strideb, _beta, _c, _ldc, _stridec, batch_size, _dependencies); - } else if (arith_intensity >= 100 && arith_intensity <= 162) { + } else if (arith_ratio >= 100 && arith_ratio <= 162) { return blas::Gemm_Launcher< container_0_t, container_1_t, container_2_t, 256, false, true, true, 128, Tile<2, 2, 16, 8>, _t_a, _t_b, s_a, s_b, @@ -202,7 +202,7 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, _stridea, _b, _ldb, _strideb, _beta, _c, _ldc, _stridec, batch_size, _dependencies); - } else if (arith_intensity <= 100) { + } else if (arith_ratio <= 100) { return blas::Gemm_Launcher< container_0_t, container_1_t, container_2_t, 256, false, false, true, 128, Tile<1, 1, 16, 8>, _t_a, _t_b, s_a, s_b,