From 5a4a0eaa75c6b133aee060770b9ee8ee7db366ab Mon Sep 17 00:00:00 2001 From: nscipione Date: Tue, 5 Dec 2023 13:43:50 +0000 Subject: [PATCH] Update configuration for gemm operator on CPU --- cmake/CmakeFunctionHelper.cmake | 13 ++++++++---- src/interface/blas3/backend/default_cpu.hpp | 23 +++++++++++++++------ 2 files changed, 26 insertions(+), 10 deletions(-) diff --git a/cmake/CmakeFunctionHelper.cmake b/cmake/CmakeFunctionHelper.cmake index f2d3244dc..c7804d0aa 100644 --- a/cmake/CmakeFunctionHelper.cmake +++ b/cmake/CmakeFunctionHelper.cmake @@ -616,19 +616,24 @@ else() # default cpu backend 64 8 8 8 8 1 1 1 1 1 1 1 1 1 float float "no_local" "naive" "none" 1 "strided" "false" "false") else() add_gemm_configuration( - "${data}" 64 "false" "false" "false" - 64 2 2 8 8 1 1 1 1 1 1 1 1 1 float float "no_local" "standard" "full" 2 "strided" "false" "false") + "${data}" 128 "false" "false" "false" + 64 2 2 2 2 1 1 1 1 1 1 1 1 1 float float "no_local" "standard" "full" 2 "strided" "false" "false") add_gemm_configuration( - "${data}" 64 "false" "false" "false" - 64 8 8 8 8 1 1 1 1 1 1 1 1 1 float float "no_local" "standard" "partial" 1 "strided" "false" "false") + "${data}" 128 "false" "false" "false" + 64 4 4 8 8 1 1 1 1 1 1 1 1 1 float float "no_local" "standard" "full" 1 "strided" "false" "false") + add_gemm_configuration( + "${data}" 128 "false" "false" "false" + 64 4 4 4 4 1 1 1 1 1 1 1 1 1 float float "no_local" "standard" "partial" 1 "strided" "false" "false") add_gemm_configuration( "${data}" 64 "false" "false" "false" 64 2 2 8 8 1 1 1 1 1 1 1 1 1 float float "local" "standard" "full" 2 "strided" "false" "false") + endif() add_gemm_configuration( "${data}" 64 "false" "false" "false" 64 2 2 4 4 1 1 1 1 4 4 1 1 1 float float "no_local" "standard" "full" 4 "interleaved" "false" "false") + endforeach() if(BLAS_ENABLE_COMPLEX) # Extract list of complex for each data in supported_types diff --git a/src/interface/blas3/backend/default_cpu.hpp b/src/interface/blas3/backend/default_cpu.hpp index 4fbf341e2..e0b519e61 100644 --- a/src/interface/blas3/backend/default_cpu.hpp +++ b/src/interface/blas3/backend/default_cpu.hpp @@ -69,10 +69,10 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, _b, _ldb, _strideb, _beta, _c, _ldc, _stridec, batch_size, _dependencies); #else - if (_M <= 128 && _N <= 128 && _K <= 128 && !s_a && !s_b) { + if (_M <= 128 && _N <= 128 && _K <= 256 && !s_a && !s_b) { return blas::Gemm_Launcher< - container_0_t, container_1_t, container_2_t, 64, false, false, false, - 64, Tile<2, 2, 8, 8>, _t_a, _t_b, s_a, s_b, + container_0_t, container_1_t, container_2_t, 128, false, false, false, + 64, Tile<2, 2, 2, 2>, _t_a, _t_b, s_a, s_b, static_cast(gemm_memory_t::no_local), static_cast(gemm_algorithm_t::standard), static_cast(gemm_vectorization_t::full), is_beta_zero, 2, @@ -80,10 +80,10 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, _stridea, _b, _ldb, _strideb, _beta, _c, _ldc, _stridec, batch_size, _dependencies); - } else if (!s_a && !s_b) { + } else if ((_M * _N) >= 524288 && !s_a && !s_b) { return blas::Gemm_Launcher< - container_0_t, container_1_t, container_2_t, 64, false, false, false, - 64, Tile<8, 8, 8, 8>, _t_a, _t_b, s_a, s_b, + container_0_t, container_1_t, container_2_t, 128, false, false, false, + 64, Tile<4, 4, 4, 4>, _t_a, _t_b, s_a, s_b, static_cast(gemm_memory_t::no_local), static_cast(gemm_algorithm_t::standard), static_cast(gemm_vectorization_t::partial), is_beta_zero, 1, @@ -91,6 +91,17 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, _stridea, _b, _ldb, _strideb, _beta, _c, _ldc, _stridec, batch_size, _dependencies); + } else if (!s_a && !s_b) { + return blas::Gemm_Launcher< + container_0_t, container_1_t, container_2_t, 128, false, false, false, + 64, Tile<4, 4, 8, 8>, _t_a, _t_b, s_a, s_b, + static_cast(gemm_memory_t::no_local), + static_cast(gemm_algorithm_t::standard), + static_cast(gemm_vectorization_t::full), is_beta_zero, 1, + static_cast(gemm_batch_type_t::strided)>:: + template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, + _stridea, _b, _ldb, _strideb, _beta, _c, _ldc, + _stridec, batch_size, _dependencies); } else { return blas::Gemm_Launcher< container_0_t, container_1_t, container_2_t, 64, false, false, false,