diff --git a/cmake/CmakeFunctionHelper.cmake b/cmake/CmakeFunctionHelper.cmake index c7804d0aa..cd84d38cd 100644 --- a/cmake/CmakeFunctionHelper.cmake +++ b/cmake/CmakeFunctionHelper.cmake @@ -592,6 +592,9 @@ elseif(${TUNING_TARGET} STREQUAL "NVIDIA_GPU") add_gemm_configuration( "${data}" 256 "false" "true" "true" 128 8 8 16 16 1 1 1 1 1 1 1 1 1 float float "local" "standard" "full" 1 "strided" "false") + add_gemm_configuration( + "${data}" 64 "false" "false" "true" + 64 8 8 8 8 1 1 2 2 1 1 1 1 1 float float "local" "standard" "full" 1 "strided" "false") endforeach() if(BLAS_ENABLE_COMPLEX) # Extract list of complex for each data in supported_types diff --git a/src/interface/blas3/backend/nvidia_gpu.hpp b/src/interface/blas3/backend/nvidia_gpu.hpp index 72ef7160d..b2bca28ef 100644 --- a/src/interface/blas3/backend/nvidia_gpu.hpp +++ b/src/interface/blas3/backend/nvidia_gpu.hpp @@ -113,9 +113,9 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, if (batch_size > 1) { return blas::Gemm_Launcher< - container_0_t, container_1_t, container_2_t, 256, false, true, true, - 128, Tile<8, 8, 16, 16, 1, 1, 1, 1, 1, 1, 1, 1, 1, float, float>, - _t_a, _t_b, s_a, s_b, static_cast(gemm_memory_t::local), + container_0_t, container_1_t, container_2_t, 64, false, false, true, + 64, Tile<8, 8, 8, 8, 1, 1, 2, 2, 1, 1, 1, 1, 1, float, float>, _t_a, + _t_b, s_a, s_b, static_cast(gemm_memory_t::local), static_cast(gemm_algorithm_t::standard), static_cast(gemm_vectorization_t::full), is_beta_zero, 1, static_cast(gemm_batch_type_t::strided),