From cb69d68e81887f38461998d67102bd14a1ae179e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicol=C3=B2=20Scipione?= <9421873+s-Nick@users.noreply.github.com> Date: Tue, 6 Feb 2024 12:21:23 +0100 Subject: [PATCH] Update configuration for GEMM on AMD GPUs (#494) --- cmake/CmakeFunctionHelper.cmake | 33 +++++++++- src/interface/blas3/backend/amd_gpu.hpp | 86 ++++++++++++++++++++++++- 2 files changed, 113 insertions(+), 6 deletions(-) diff --git a/cmake/CmakeFunctionHelper.cmake b/cmake/CmakeFunctionHelper.cmake index cd84d38cd..beae299e1 100644 --- a/cmake/CmakeFunctionHelper.cmake +++ b/cmake/CmakeFunctionHelper.cmake @@ -495,13 +495,12 @@ elseif(${TUNING_TARGET} STREQUAL "AMD_GPU") # need investigation set(twr "${workgroup_${data}}") set(twc "${workgroup_${data}}") - add_gemm_configuration( - "${data}" 256 "false" "false" "false" - 64 1 1 ${twr} ${twc} 1 1 1 1 1 1 1 1 1 float float "local" "standard" "full" 1 "strided" "false") + # General configuration add_gemm_configuration( "${data}" 256 "false" "false" "false" 64 4 4 ${twr} ${twc} 1 1 1 1 1 1 1 1 1 float float "local" "standard" "full" 2 "strided" "false") + # configuration for tall_skinny add_gemm_configuration( "${data}" 256 "true" "true" "true" 64 1 1 ${twr} ${twc} 1 1 1 1 1 1 1 1 1 float float "local" "tall_skinny" "none" 2 "strided" "false") @@ -518,9 +517,37 @@ elseif(${TUNING_TARGET} STREQUAL "AMD_GPU") # need investigation "${data}" 256 "true" "true" "true" 64 4 1 ${twr} ${twc} 1 1 1 1 1 1 1 1 1 float float "local" "tall_skinny" "none" 2 "strided" "false") + # configuration for batch add_gemm_configuration( "${data}" 64 "false" "false" "false" 64 4 4 4 4 1 1 1 1 4 4 1 1 1 float float "no_local" "standard" "full" 4 "interleaved" "false") + + # Configurations for gemm + + # low arithmetic intensity + add_gemm_configuration( + "${data}" 256 "false" "false" "true" + 128 1 1 16 8 1 1 1 1 1 1 1 1 1 float float "local" "standard" "full" 1 "strided" "false") + add_gemm_configuration( + "${data}" 256 "false" "false" "true" + 64 4 8 16 16 1 1 1 1 1 1 1 1 1 float float "local" "standard" "full" 1 "strided" "false") + # highest arithmetic intensity + add_gemm_configuration( + "${data}" 256 "false" "false" "true" + 32 8 8 16 16 1 1 1 1 1 1 1 1 1 float float "local" "standard" "full" 1 "strided" "false") + # high arithmetic intensity + add_gemm_configuration( + "${data}" 256 "false" "false" "true" + 64 4 4 16 8 1 1 1 1 1 1 1 1 1 float float "local" "standard" "full" 1 "strided" "false") + # mid high 162 < a < 240 + add_gemm_configuration( + "${data}" 256 "false" "false" "true" + 128 4 4 16 8 1 1 1 1 1 1 1 1 1 float float "local" "standard" "full" 1 "strided" "false") + # mid low 100 < a < 162 + add_gemm_configuration( + "${data}" 256 "false" "true" "true" + 128 2 2 16 8 1 1 1 1 1 1 1 1 1 float float "local" "standard" "full" 1 "strided" "false") + endforeach() if(BLAS_ENABLE_COMPLEX) # Extract list of complex for each data in supported_types diff --git a/src/interface/blas3/backend/amd_gpu.hpp b/src/interface/blas3/backend/amd_gpu.hpp index 9d6ffa424..8257c76ba 100644 --- a/src/interface/blas3/backend/amd_gpu.hpp +++ b/src/interface/blas3/backend/amd_gpu.hpp @@ -45,6 +45,13 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, if constexpr (s_a && s_b || ((s_a && _t_b) || (s_b && _t_a))) { return _dependencies; } else { + // computing arithmetic ratio with combination of input to use it as + // heuristic numerator is the number of fma, denominator is the number of + // bytes access. + const auto n_fma = (static_cast(_M) * static_cast(_K) * + static_cast(_N)); + const auto n_elem_access = (_M * _K + _K * _N + _M * _N); + const auto arith_ratio = n_fma / n_elem_access; static constexpr int ClSize = 64; static constexpr int tileWgSize = ClSize / sizeof(element_t); if (batch_type == gemm_batch_type_t::interleaved) { @@ -59,6 +66,7 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, _stridea, _b, _ldb, _strideb, _beta, _c, _ldc, _stridec, batch_size, _dependencies); } + /* Tall & Skinny matrices. */ #ifdef GEMM_TALL_SKINNY_SUPPORT if (batch_size == 1 && @@ -123,11 +131,82 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, } } else #endif // GEMM_TALL_SKINNY_SUPPORT - if (_M * _N <= 65536) { + // Following configurations are taken using the auto tuner on amd-mi210 + // and divided following their arith_ratio or another ratio between _N + // and _K input size + if ((_N >> 4) > _K) { + if (arith_ratio <= 100) { + return blas::Gemm_Launcher< + container_0_t, container_1_t, container_2_t, 256, false, false, + true, 64, Tile<4, 8, 16, 16>, _t_a, _t_b, s_a, s_b, + static_cast(gemm_memory_t::local), + static_cast(gemm_algorithm_t::standard), + static_cast(gemm_vectorization_t::full), is_beta_zero, 1, + static_cast(gemm_batch_type_t::strided)>:: + template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, + _stridea, _b, _ldb, _strideb, _beta, _c, + _ldc, _stridec, batch_size, _dependencies); + } else { + return blas::Gemm_Launcher< + container_0_t, container_1_t, container_2_t, 256, false, false, + true, 32, Tile<8, 8, 16, 16>, _t_a, _t_b, s_a, s_b, + static_cast(gemm_memory_t::local), + static_cast(gemm_algorithm_t::standard), + static_cast(gemm_vectorization_t::full), is_beta_zero, 1, + static_cast(gemm_batch_type_t::strided)>:: + template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, + _stridea, _b, _ldb, _strideb, _beta, _c, + _ldc, _stridec, batch_size, _dependencies); + } + } else if (arith_ratio >= 360) { return blas::Gemm_Launcher< container_0_t, container_1_t, container_2_t, 256, false, false, - false, ClSize, Tile<1, 1, tileWgSize, tileWgSize>, _t_a, _t_b, s_a, - s_b, static_cast(gemm_memory_t::local), + true, 32, Tile<8, 8, 16, 16>, _t_a, _t_b, s_a, s_b, + static_cast(gemm_memory_t::local), + static_cast(gemm_algorithm_t::standard), + static_cast(gemm_vectorization_t::full), is_beta_zero, 1, + static_cast(gemm_batch_type_t::strided)>:: + template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, + _stridea, _b, _ldb, _strideb, _beta, _c, _ldc, + _stridec, batch_size, _dependencies); + } else if (arith_ratio >= 240) { + return blas::Gemm_Launcher< + container_0_t, container_1_t, container_2_t, 256, false, false, + true, 64, Tile<4, 4, 16, 8>, _t_a, _t_b, s_a, s_b, + static_cast(gemm_memory_t::local), + static_cast(gemm_algorithm_t::standard), + static_cast(gemm_vectorization_t::full), is_beta_zero, 1, + static_cast(gemm_batch_type_t::strided)>:: + template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, + _stridea, _b, _ldb, _strideb, _beta, _c, _ldc, + _stridec, batch_size, _dependencies); + } else if (arith_ratio > 162) { + return blas::Gemm_Launcher< + container_0_t, container_1_t, container_2_t, 256, false, false, + true, 128, Tile<4, 4, 16, 8>, _t_a, _t_b, s_a, s_b, + static_cast(gemm_memory_t::local), + static_cast(gemm_algorithm_t::standard), + static_cast(gemm_vectorization_t::full), is_beta_zero, 1, + static_cast(gemm_batch_type_t::strided)>:: + template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, + _stridea, _b, _ldb, _strideb, _beta, _c, _ldc, + _stridec, batch_size, _dependencies); + } else if (arith_ratio >= 100 && arith_ratio <= 162) { + return blas::Gemm_Launcher< + container_0_t, container_1_t, container_2_t, 256, false, true, true, + 128, Tile<2, 2, 16, 8>, _t_a, _t_b, s_a, s_b, + static_cast(gemm_memory_t::local), + static_cast(gemm_algorithm_t::standard), + static_cast(gemm_vectorization_t::full), is_beta_zero, 1, + static_cast(gemm_batch_type_t::strided)>:: + template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, + _stridea, _b, _ldb, _strideb, _beta, _c, _ldc, + _stridec, batch_size, _dependencies); + } else if (arith_ratio <= 100) { + return blas::Gemm_Launcher< + container_0_t, container_1_t, container_2_t, 256, false, false, + true, 128, Tile<1, 1, 16, 8>, _t_a, _t_b, s_a, s_b, + static_cast(gemm_memory_t::local), static_cast(gemm_algorithm_t::standard), static_cast(gemm_vectorization_t::full), is_beta_zero, 1, static_cast(gemm_batch_type_t::strided)>:: @@ -135,6 +214,7 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, _stridea, _b, _ldb, _strideb, _beta, _c, _ldc, _stridec, batch_size, _dependencies); } else { + // this branch is a safe net just in case no other branch is taken return blas::Gemm_Launcher< container_0_t, container_1_t, container_2_t, 256, false, false, false, ClSize, Tile<4, 4, tileWgSize, tileWgSize>, _t_a, _t_b, s_a,