From aff966c4a1a0e645784c9d27a384322093a3b71c Mon Sep 17 00:00:00 2001 From: Ouadie EL FAROUKI Date: Thu, 28 Mar 2024 11:16:28 +0000 Subject: [PATCH] removed non-necessary gemm template parameter element in/out --- CMakeLists.txt | 5 + include/operations/blas3_trees.h | 35 ++-- include/sb_handle/portblas_handle.h | 55 +++--- src/interface/blas3/backend/amd_gpu.hpp | 24 +-- src/interface/blas3/backend/default_cpu.hpp | 38 ++--- src/interface/blas3/backend/intel_gpu.hpp | 41 +++-- src/interface/blas3/backend/nvidia_gpu.hpp | 38 ++--- src/interface/blas3/backend/power_vr.hpp | 9 +- src/interface/gemm_interface.hpp | 6 +- src/interface/gemm_launcher.hpp | 17 +- src/operations/blas3/gemm_interleaved.hpp | 50 +++--- src/operations/blas3/gemm_local.hpp | 80 +++++---- .../blas3/gemm_local_joint_matrix.hpp | 16 +- .../blas3/gemm_no_local_full_vec.hpp | 146 ++++++++-------- .../blas3/gemm_no_local_partial_vec.hpp | 83 +++++---- src/operations/blas3/gemm_partial_local.hpp | 26 ++- src/operations/blas3/gemm_ref.hpp | 157 ++++++++---------- src/sb_handle/portblas_handle.hpp | 54 +++--- tools/auto_tuner/include/tune_impl.hpp | 2 +- 19 files changed, 407 insertions(+), 475 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 15aa67f69..954f32cac 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -113,6 +113,11 @@ option(BLAS_ENABLE_EXTENSIONS "Whether to enable portBLAS extensions" ON) option(BLAS_ENABLE_COMPLEX "Whether to enable complex data type for GEMM" OFF) option(BLAS_ENABLE_HALF "Whether to enable sycl::half data type for supported operators" OFF) +if(((NOT INSTALL_HEADER_ONLY) AND (TUNING_TARGET STREQUAL "DEFAULT_CPU")) + OR (INSTALL_HEADER_ONLY AND (NOT TUNING_TARGET))) + set(BLAS_ENABLE_HALF OFF) + message(STATUS "FP16 operations are not supported for CPU targets. BLAS_ENABLE_HALF is disabled") +endif() if (SYCL_COMPILER MATCHES "adaptivecpp") if(BLAS_ENABLE_COMPLEX) diff --git a/include/operations/blas3_trees.h b/include/operations/blas3_trees.h index 19a6772eb..d8ca1dc9f 100644 --- a/include/operations/blas3_trees.h +++ b/include/operations/blas3_trees.h @@ -169,8 +169,7 @@ struct Tile { * @tparam TransB iff true, matrix B will be transposed on the fly * @tparam SymmA whether the matrix A is a symmetric triangular matrix * @tparam SymmB whether the matrix B is a symmetric triangular matrix - * @tparam element_in_t type of input matrix elements (A & B) - * @tparam element_t type of output matrix elements (C) and scaling parameters + * @tparam element_t type of scalar alpha & beta * @tparam UseJointMatrix boolean parameter to decide whether to use * joint_matrix or not * @param a_ the lhs_t matrix @@ -191,13 +190,12 @@ struct Tile { */ template + bool SymmA, bool SymmB, typename element_t, bool is_beta_zero, + int GemmMemoryType, int GemmAlgorithm, int GemmVectorization, + int VectorSize, int BatchType, bool UseJointMatrix = false> class Gemm { public: - using value_t = element_t; + using value_t = typename input_t::value_t; using index_t = typename std::make_signed::type; static constexpr int wg_size = tile_type::wg_rows * tile_type::wg_cols; static constexpr bool trans_a = TransA; @@ -251,8 +249,7 @@ class Gemm { */ template + bool IsFinal, bool IsBetaZero, typename element_t, int GemmMemoryType> class GemmPartial {}; /* @@ -263,21 +260,21 @@ template + typename input_t, typename output_t, typename element_t, + typename index_t> inline Gemm + TileType, TransA, TransB, SymmA, SymmB, element_t, is_beta_zero, + GemmMemoryType, GemmAlgorithm, GemmVectorization, VectorSize, + BatchType, UseJointMatrix> make_gemm(input_t buffer_a, input_t buffer_b, output_t buffer_c, element_t alpha, element_t beta, index_t batch_size, index_t _stridea, index_t _strideb, index_t _stridec) { return Gemm( - buffer_a, buffer_b, buffer_c, alpha, beta, batch_size, _stridea, _strideb, - _stridec); + TileType, TransA, TransB, SymmA, SymmB, element_t, is_beta_zero, + GemmMemoryType, GemmAlgorithm, GemmVectorization, VectorSize, + BatchType, UseJointMatrix>(buffer_a, buffer_b, buffer_c, alpha, + beta, batch_size, _stridea, _strideb, + _stridec); } /** diff --git a/include/sb_handle/portblas_handle.h b/include/sb_handle/portblas_handle.h index a21177a08..836b37c61 100644 --- a/include/sb_handle/portblas_handle.h +++ b/include/sb_handle/portblas_handle.h @@ -123,43 +123,40 @@ class SB_Handle { template - event_t execute( - Gemm - gemm_tree, - const event_t& dependencies = {}); + bool SymmA, bool SymmB, typename element_t, bool is_beta_zero, + int GemmMemoryType, int GemmAlgorithm, int GemmVectorization, + int VectorSize, int BatchType, bool UseJointMatrix> + event_t execute(Gemm + gemm_tree, + const event_t& dependencies = {}); // Tall and skinny Gemm specialization template - event_t execute(Gemm(gemm_algorithm_t::tall_skinny), - GemmVectorization, VectorSize, BatchType> - gemm_wrapper, - const event_t& dependencies = {}); + bool SymmA, bool SymmB, typename element_t, bool is_beta_zero, + int GemmMemoryType, int GemmVectorization, int VectorSize, + int BatchType> + event_t execute( + Gemm(gemm_algorithm_t::tall_skinny), + GemmVectorization, VectorSize, BatchType> + gemm_wrapper, + const event_t& dependencies = {}); // GemmPartial specialization template - event_t execute( - GemmPartial - gemm_partial, - const event_t& dependencies = {}); + bool IsFinal, bool IsBetaZero, typename element_t, + int GemmMemoryType> + event_t execute(GemmPartial + gemm_partial, + const event_t& dependencies = {}); // Reduction specialization (inner or outer dimension) template -typename std::enable_if::value, + typename sb_handle_t, typename container_0_t, typename container_1_t, + typename container_2_t, typename element_t, typename index_t> +typename std::enable_if::value, typename sb_handle_t::event_t>::type _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, - element_out_t _alpha, container_0_t _a, index_t _lda, index_t _stridea, - container_1_t _b, index_t _ldb, index_t _strideb, element_out_t _beta, + element_t _alpha, container_0_t _a, index_t _lda, index_t _stridea, + container_1_t _b, index_t _ldb, index_t _strideb, element_t _beta, container_2_t _c, index_t _ldc, index_t _stridec, index_t batch_size, gemm_batch_type_t batch_type, const typename sb_handle_t::event_t& _dependencies) { + using element_in_t = typename ValueType::type; // Unused configuration cases if constexpr (s_a && s_b || ((s_a && _t_b) || (s_b && _t_a))) { return _dependencies; @@ -233,17 +233,17 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, // Complex Configurations #ifdef BLAS_ENABLE_COMPLEX template -typename std::enable_if::value, + typename sb_handle_t, typename container_0_t, typename container_1_t, + typename container_2_t, typename element_t, typename index_t> +typename std::enable_if::value, typename sb_handle_t::event_t>::type _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, - element_out_t _alpha, container_0_t _a, index_t _lda, index_t _stridea, - container_1_t _b, index_t _ldb, index_t _strideb, element_out_t _beta, + element_t _alpha, container_0_t _a, index_t _lda, index_t _stridea, + container_1_t _b, index_t _ldb, index_t _strideb, element_t _beta, container_2_t _c, index_t _ldc, index_t _stridec, index_t batch_size, gemm_batch_type_t batch_type, const typename sb_handle_t::event_t& _dependencies) { + using element_in_t = typename ValueType::type; static constexpr int ClSize = 64; static constexpr int tileWgSize = ClSize / sizeof(element_in_t); /* Tall & Skinny matrices. */ diff --git a/src/interface/blas3/backend/default_cpu.hpp b/src/interface/blas3/backend/default_cpu.hpp index 3bbc4f6b9..7942efb1c 100644 --- a/src/interface/blas3/backend/default_cpu.hpp +++ b/src/interface/blas3/backend/default_cpu.hpp @@ -31,15 +31,15 @@ namespace gemm { namespace backend { template -typename std::enable_if::value && - !is_half::value, - typename sb_handle_t::event_t>::type + typename sb_handle_t, typename container_0_t, typename container_1_t, + typename container_2_t, typename element_t, typename index_t> +typename std::enable_if< + is_sycl_scalar::value && + !is_half::type>::value, + typename sb_handle_t::event_t>::type _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, - element_out_t _alpha, container_0_t _a, index_t _lda, index_t _stridea, - container_1_t _b, index_t _ldb, index_t _strideb, element_out_t _beta, + element_t _alpha, container_0_t _a, index_t _lda, index_t _stridea, + container_1_t _b, index_t _ldb, index_t _strideb, element_t _beta, container_2_t _c, index_t _ldc, index_t _stridec, index_t batch_size, gemm_batch_type_t batch_type, const typename sb_handle_t::event_t& _dependencies) { @@ -123,14 +123,13 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, // Half Configurations template -typename std::enable_if::value, + typename sb_handle_t, typename container_0_t, typename container_1_t, + typename container_2_t, typename element_t, typename index_t> +typename std::enable_if::type>::value, typename sb_handle_t::event_t>::type _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, - element_out_t _alpha, container_0_t _a, index_t _lda, index_t _stridea, - container_1_t _b, index_t _ldb, index_t _strideb, element_out_t _beta, + element_t _alpha, container_0_t _a, index_t _lda, index_t _stridea, + container_1_t _b, index_t _ldb, index_t _strideb, element_t _beta, container_2_t _c, index_t _ldc, index_t _stridec, index_t batch_size, gemm_batch_type_t batch_type, const typename sb_handle_t::event_t& _dependencies) { @@ -167,14 +166,13 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, // Complex Configurations #ifdef BLAS_ENABLE_COMPLEX template -typename std::enable_if::value, + typename sb_handle_t, typename container_0_t, typename container_1_t, + typename container_2_t, typename element_t, typename index_t> +typename std::enable_if::value, typename sb_handle_t::event_t>::type _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, - element_out_t _alpha, container_0_t _a, index_t _lda, index_t _stridea, - container_1_t _b, index_t _ldb, index_t _strideb, element_out_t _beta, + element_t _alpha, container_0_t _a, index_t _lda, index_t _stridea, + container_1_t _b, index_t _ldb, index_t _strideb, element_t _beta, container_2_t _c, index_t _ldc, index_t _stridec, index_t batch_size, gemm_batch_type_t batch_type, const typename sb_handle_t::event_t& _dependencies) { diff --git a/src/interface/blas3/backend/intel_gpu.hpp b/src/interface/blas3/backend/intel_gpu.hpp index 3b891fef3..f23a8a220 100644 --- a/src/interface/blas3/backend/intel_gpu.hpp +++ b/src/interface/blas3/backend/intel_gpu.hpp @@ -30,18 +30,19 @@ namespace blas { namespace gemm { namespace backend { template -typename std::enable_if::value && - !is_half::value, - typename sb_handle_t::event_t>::type + typename sb_handle_t, typename container_0_t, typename container_1_t, + typename container_2_t, typename element_t, typename index_t> +typename std::enable_if< + is_sycl_scalar::value && + !is_half::type>::value, + typename sb_handle_t::event_t>::type _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, - element_out_t _alpha, container_0_t _a, index_t _lda, index_t _stridea, - container_1_t _b, index_t _ldb, index_t _strideb, element_out_t _beta, + element_t _alpha, container_0_t _a, index_t _lda, index_t _stridea, + container_1_t _b, index_t _ldb, index_t _strideb, element_t _beta, container_2_t _c, index_t _ldc, index_t _stridec, index_t batch_size, gemm_batch_type_t batch_type, const typename sb_handle_t::event_t& _dependencies) { + using element_in_t = typename ValueType::type; // Unused configuration cases if constexpr (s_a && s_b || ((s_a && _t_b) || (s_b && _t_a))) { return _dependencies; @@ -213,14 +214,13 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, // Half Configurations template -typename std::enable_if::value, + typename sb_handle_t, typename container_0_t, typename container_1_t, + typename container_2_t, typename element_t, typename index_t> +typename std::enable_if::type>::value, typename sb_handle_t::event_t>::type _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, - element_out_t _alpha, container_0_t _a, index_t _lda, index_t _stridea, - container_1_t _b, index_t _ldb, index_t _strideb, element_out_t _beta, + element_t _alpha, container_0_t _a, index_t _lda, index_t _stridea, + container_1_t _b, index_t _ldb, index_t _strideb, element_t _beta, container_2_t _c, index_t _ldc, index_t _stridec, index_t batch_size, gemm_batch_type_t batch_type, const typename sb_handle_t::event_t& _dependencies) { @@ -275,20 +275,19 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, // Complex Configurations #ifdef BLAS_ENABLE_COMPLEX template -typename std::enable_if::value, + typename sb_handle_t, typename container_0_t, typename container_1_t, + typename container_2_t, typename element_t, typename index_t> +typename std::enable_if::value, typename sb_handle_t::event_t>::type _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, - element_out_t _alpha, container_0_t _a, index_t _lda, index_t _stridea, - container_1_t _b, index_t _ldb, index_t _strideb, element_out_t _beta, + element_t _alpha, container_0_t _a, index_t _lda, index_t _stridea, + container_1_t _b, index_t _ldb, index_t _strideb, element_t _beta, container_2_t _c, index_t _ldc, index_t _stridec, index_t batch_size, gemm_batch_type_t batch_type, const typename sb_handle_t::event_t& _dependencies) { #ifdef GEMM_TALL_SKINNY_SUPPORT if (batch_size == 1) { - constexpr int wg_size = sizeof(element_in_t) == 16 ? 4 : 8; + constexpr int wg_size = sizeof(element_t) == 16 ? 4 : 8; return blas::Gemm_Launcher< container_0_t, container_1_t, container_2_t, 64, true, true, true, 64, Tile<4, 4, wg_size, wg_size>, _t_a, _t_b, false, false, diff --git a/src/interface/blas3/backend/nvidia_gpu.hpp b/src/interface/blas3/backend/nvidia_gpu.hpp index 3650cccd2..cc42bfbb3 100644 --- a/src/interface/blas3/backend/nvidia_gpu.hpp +++ b/src/interface/blas3/backend/nvidia_gpu.hpp @@ -31,15 +31,15 @@ namespace gemm { namespace backend { template -typename std::enable_if::value && - !is_half::value, - typename sb_handle_t::event_t>::type + typename sb_handle_t, typename container_0_t, typename container_1_t, + typename container_2_t, typename element_t, typename index_t> +typename std::enable_if< + is_sycl_scalar::value && + !is_half::type>::value, + typename sb_handle_t::event_t>::type _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, - element_out_t _alpha, container_0_t _a, index_t _lda, index_t _stridea, - container_1_t _b, index_t _ldb, index_t _strideb, element_out_t _beta, + element_t _alpha, container_0_t _a, index_t _lda, index_t _stridea, + container_1_t _b, index_t _ldb, index_t _strideb, element_t _beta, container_2_t _c, index_t _ldc, index_t _stridec, index_t batch_size, gemm_batch_type_t batch_type, const typename sb_handle_t::event_t& _dependencies) { @@ -179,14 +179,13 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, // Half Configurations template -typename std::enable_if::value, + typename sb_handle_t, typename container_0_t, typename container_1_t, + typename container_2_t, typename element_t, typename index_t> +typename std::enable_if::type>::value, typename sb_handle_t::event_t>::type _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, - element_out_t _alpha, container_0_t _a, index_t _lda, index_t _stridea, - container_1_t _b, index_t _ldb, index_t _strideb, element_out_t _beta, + element_t _alpha, container_0_t _a, index_t _lda, index_t _stridea, + container_1_t _b, index_t _ldb, index_t _strideb, element_t _beta, container_2_t _c, index_t _ldc, index_t _stridec, index_t batch_size, gemm_batch_type_t batch_type, const typename sb_handle_t::event_t& _dependencies) { @@ -238,14 +237,13 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, // Complex Configurations #ifdef BLAS_ENABLE_COMPLEX template -typename std::enable_if::value, + typename sb_handle_t, typename container_0_t, typename container_1_t, + typename container_2_t, typename element_t, typename index_t> +typename std::enable_if::value, typename sb_handle_t::event_t>::type _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, - element_out_t _alpha, container_0_t _a, index_t _lda, index_t _stridea, - container_1_t _b, index_t _ldb, index_t _strideb, element_out_t _beta, + element_t _alpha, container_0_t _a, index_t _lda, index_t _stridea, + container_1_t _b, index_t _ldb, index_t _strideb, element_t _beta, container_2_t _c, index_t _ldc, index_t _stridec, index_t batch_size, gemm_batch_type_t batch_type, const typename sb_handle_t::event_t& _dependencies) { diff --git a/src/interface/blas3/backend/power_vr.hpp b/src/interface/blas3/backend/power_vr.hpp index 3ec335ea2..7b7b2473f 100644 --- a/src/interface/blas3/backend/power_vr.hpp +++ b/src/interface/blas3/backend/power_vr.hpp @@ -291,13 +291,12 @@ struct Gemm_Launcher { #endif template + typename sb_handle_t, typename container_0_t, typename container_1_t, + typename container_2_t, typename element_t, typename index_t> typename sb_handle_t::event_t _gemm( sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, - element_out_t _alpha, container_0_t _a, index_t _lda, index_t _stridea, - container_1_t _b, index_t _ldb, index_t _strideb, element_out_t _beta, + element_t _alpha, container_0_t _a, index_t _lda, index_t _stridea, + container_1_t _b, index_t _ldb, index_t _strideb, element_t _beta, container_2_t _c, index_t _ldc, index_t _stridec, index_t batch_size, gemm_batch_type_t batch_type) { #ifdef IMGDNN_LIBRARY diff --git a/src/interface/gemm_interface.hpp b/src/interface/gemm_interface.hpp index 659f04084..b777d1215 100644 --- a/src/interface/gemm_interface.hpp +++ b/src/interface/gemm_interface.hpp @@ -74,11 +74,7 @@ typename sb_handle_t::event_t _gemm_platform_specific( container_2_t _C, index_t _ldc, index_t _stridec, index_t batch_size, gemm_batch_type_t batch_type, const typename sb_handle_t::event_t& _dependencies) { - // element_in_t is passed here explicitly to differentiate mixed-precision - // cases (element_in_t != element_t) from the default ones. - using element_in_t = typename ValueType::type; - return blas::gemm::backend::_gemm<_t_a, _t_b, s_a, s_b, is_beta_zero, - element_in_t, element_t>( + return blas::gemm::backend::_gemm<_t_a, _t_b, s_a, s_b, is_beta_zero>( sb_handle, _M, _N, _K, _alpha, a_, _lda, _stridea, b_, _ldb, _strideb, _beta, _C, _ldc, _stridec, batch_size, batch_type, _dependencies); } diff --git a/src/interface/gemm_launcher.hpp b/src/interface/gemm_launcher.hpp index 00b97cb39..b067e0e74 100644 --- a/src/interface/gemm_launcher.hpp +++ b/src/interface/gemm_launcher.hpp @@ -59,17 +59,12 @@ typename sb_handle_t::event_t Gemm_Launcher< auto b_view = make_matrix_view(b_, _K, _N, _ldb); auto c_view = make_matrix_view(_C, _M, _N, _ldc); - // element_in_t refers here to input matrices (A & B) underlying types, and - // having it separated from element_t helps distinguish Gemm mixed-precision - // cases. - using element_in_t = typename ValueType::type; - auto gemm = - make_gemm( - a_view, b_view, c_view, element_t(_alpha), element_t(_beta), - batch_size, index_t(_stridea), index_t(_strideb), index_t(_stridec)); + auto gemm = make_gemm( + a_view, b_view, c_view, element_t(_alpha), element_t(_beta), batch_size, + index_t(_stridea), index_t(_strideb), index_t(_stridec)); return sb_handle.execute(gemm, _dependencies); } diff --git a/src/operations/blas3/gemm_interleaved.hpp b/src/operations/blas3/gemm_interleaved.hpp index e8bb0e038..ae7987265 100644 --- a/src/operations/blas3/gemm_interleaved.hpp +++ b/src/operations/blas3/gemm_interleaved.hpp @@ -96,25 +96,22 @@ PORTBLAS_INLINE void store(const cl::sycl::vec &packet, PtrT ptr) { * level tiles to use, see Tile * @tparam TransA if true, matrix A will be transposed on the fly * @tparam TransB if true, matrix B will be transposed on the fly - * @tparam element_in_t type of input matrices elements (A, B) - * @tparam element_out_t type of output matrix elements (C) and scalars + * @tparam element_t type of scalar alpha & beta * @tparam is_beta_zero whether to optimize away the beta * C addition * @tparam UseJointMatrix boolean parameter to decide whether to use * joint_matrix or not */ template + bool TransA, bool TransB, bool SymmA, bool SymmB, typename element_t, + bool is_beta_zero, int VectorSize> class Gemm(gemm_memory_t::no_local), + element_t, is_beta_zero, static_cast(gemm_memory_t::no_local), static_cast(gemm_algorithm_t::standard), static_cast(gemm_vectorization_t::full), VectorSize, static_cast(gemm_batch_type_t::interleaved), false> { public: - using value_t = element_in_t; + using value_t = typename input_t::value_t; using index_t = typename std::make_signed::type; using address_t = cl::sycl::access::address_space; static constexpr int local_memory_size = 0; @@ -147,21 +144,21 @@ class Gemm::type; using packet_out_type = - typename internal::packet::type; + typename internal::packet::type; static_assert(item_batchs % VectorSize == 0, "Item batch must be divisible by vector size"); #ifdef BLAS_ENABLE_COMPLEX - static_assert(!is_complex_sycl::value, + static_assert(!is_complex_sycl::value, "Interleaved GEMM is not supported for Complex Data types"); #endif input_t a_; input_t b_; output_t c_; - const element_out_t alpha_; - const element_out_t beta_; + const element_t alpha_; + const element_t beta_; const index_t m_; const index_t n_; index_t k_; @@ -169,8 +166,8 @@ class Gemm::get_value() << "_" - << type_string::get_value() << "gemm_memory:no_local, " - << "gemm_algorithm:standard, " - << "gemm_vectorization:full, " + << type_string::get_value() << "_" + << type_string::get_value() << "gemm_memory:no_local, " + << "gemm_algorithm:standard, " << "gemm_vectorization:full, " << "vector size" << VectorSize << ", batch_type:interleaved>"; return str.str(); } @@ -357,8 +353,8 @@ class Gemm(reg_res)[p] = - is_in ? input[(j * wg_batchs) + p] : element_in_t(0); + reinterpret_cast(reg_res)[p] = + is_in ? input[(j * wg_batchs) + p] : value_t(0); } ++reg_res; continue; @@ -398,7 +394,7 @@ class Gemm(reg_res)[p]; + reinterpret_cast(reg_res)[p]; } } ++reg_res; @@ -457,9 +453,9 @@ class Gemm( boundary_check(mb_start + (b * wg_batchs) + p, batch_size_)); - reinterpret_cast(reg_res)[p] = - is_in ? element_out_t{output[b * wg_batchs + p] * beta_} - : element_out_t{0}; + reinterpret_cast(reg_res)[p] = + is_in ? element_t{output[b * wg_batchs + p] * beta_} + : element_t{0}; } ++reg_res; continue; @@ -501,8 +497,8 @@ class Gemm::value || - !std::is_same_v) { + if constexpr (is_half::value || + !std::is_same_v) { #pragma unroll for (int v = 0; v < VectorSize; ++v) { (*reg_res)[v] = reg_a[j * (item_batchs / VectorSize) + b][v] * @@ -515,7 +511,7 @@ class Gemm) { + if constexpr (std::is_same_v) { *reg_res = cl::sycl::mad(reg_a[j * (item_batchs / VectorSize) + b], reg_b[i * (item_batchs / VectorSize) + b], *reg_res); diff --git a/src/operations/blas3/gemm_local.hpp b/src/operations/blas3/gemm_local.hpp index 899869dda..0624f7f03 100644 --- a/src/operations/blas3/gemm_local.hpp +++ b/src/operations/blas3/gemm_local.hpp @@ -64,8 +64,7 @@ namespace blas { * @tparam TransB iff true, matrix B will be transposed on the fly * @tparam SymmA whether the matrix A is a symmetric triangular matrix * @tparam SymmB whether the matrix B is a symmetric triangular matrix - * @tparam element_in_t type of input matrices elements (A, B) - * @tparam element_out_t type of output matrix elements (C) and scalars + * @tparam element_t type of scalar alpha & beta * @tparam is_beta_zero True if beta == 0. * @tparam VectorSize The packet size to be used for vectorization. * @tparam batch_type the type of batch strideded /interleaved @@ -74,20 +73,20 @@ namespace blas { */ template + bool SymmA, bool SymmB, typename element_t, bool is_beta_zero, + int VectorSize> class Gemm(gemm_memory_t::local), + TransA, TransB, SymmA, SymmB, element_t, is_beta_zero, + static_cast(gemm_memory_t::local), static_cast(gemm_algorithm_t::standard), static_cast(gemm_vectorization_t::full), VectorSize, static_cast(gemm_batch_type_t::strided), false> { public: using tile_type = TileType; - using value_t = element_in_t; + using value_t = typename input_t::value_t; using index_t = typename std::make_signed::type; - using packetize_t = Packetize; - using packetize_out_t = Packetize; + using packetize_t = Packetize; + using packetize_out_t = Packetize; using vector_t = typename packetize_t::PacketType; using vector_out_t = typename packetize_out_t::PacketType; using address_t = cl::sycl::access::address_space; @@ -113,8 +112,8 @@ class Gemm::value) || - !is_complex_sycl::value, + static_assert((VectorSize == 1 && is_complex_sycl::value) || + !is_complex_sycl::value, "Vector size should be equal to 1 for Complex Data types"); #endif @@ -166,15 +165,15 @@ class Gemm::get_value() << "_" - << type_string::get_value() << "gemm_memory:local, " - << "gemm_algorithm:standard, " - << "gemm_vectorization:full, " + << cl_elems * sizeof(value_t) << ", " << tile_type::get_type_string() + << ", " << type_string::get_value() << "_" + << type_string::get_value() << "gemm_memory:local, " + << "gemm_algorithm:standard, " << "gemm_vectorization:full, " << "vector size" << VectorSize << ", batch_type:strided>"; return str.str(); } @@ -309,8 +306,8 @@ class Gemm PORTBLAS_INLINE typename std::enable_if::type scaling_c( - element_out_t *reg_res, InputPointerType C, const index_t &mc, + element_t *reg_res, InputPointerType C, const index_t &mc, const index_t &nc, const index_t &ldc, const bool out_of_range) { if (out_of_range) { return; @@ -404,8 +401,8 @@ class Gemm PORTBLAS_INLINE typename std::enable_if::type scaling_c( - element_out_t *reg_res, InputPointerType, const index_t &, - const index_t &, const index_t &, const bool) { + element_t *reg_res, InputPointerType, const index_t &, const index_t &, + const index_t &, const bool) { for (index_t i = 0; i < item_cols * item_rows; ++i) { reg_res[i] = 0; } @@ -436,7 +433,7 @@ class Gemm(reg_res, C, mc, nc, ldc, out_of_range); while (k >= cl_elems) { @@ -523,24 +520,22 @@ class Gemm PORTBLAS_INLINE typename std::enable_if::type store_packet( - element_out_t *reg, OutputPointerType out_ptr) { + element_t *reg, OutputPointerType out_ptr) { *out_ptr = alpha_ * (*reg); } template PORTBLAS_INLINE typename std::enable_if::type store_packet( - element_out_t *reg, OutputPointerType out_ptr) { + element_t *reg, OutputPointerType out_ptr) { vector_out_t out_vec{}; out_vec.template load( - 0, cl::sycl::multi_ptr( - reg)); + 0, cl::sycl::multi_ptr(reg)); out_vec *= alpha_; out_vec.template store( - 0, - cl::sycl::multi_ptr(out_ptr)); + 0, cl::sycl::multi_ptr(out_ptr)); } /*! * @brief Store the computed gemm result to the C matrix @@ -563,7 +558,7 @@ class Gemm PORTBLAS_INLINE void store_output_block(index_t, index_t mc, index_t nc, OutputPointerType C, index_t ldc, - element_out_t *reg_res, + element_t *reg_res, const bool out_of_range) noexcept { if (out_of_range) { return; @@ -742,10 +737,9 @@ class Gemm PORTBLAS_INLINE void compute_block_gemm(index_t, InputPointerType B, - InputPointerType A, - element_in_t *reg_a, - element_in_t ®_b, - element_out_t *reg_res) noexcept { + InputPointerType A, value_t *reg_a, + value_t ®_b, + element_t *reg_res) noexcept { // NOTE: Adding "#pragma unroll" here reduces performance on AMD R9 // Nano. // Seems that the small reduction of arithmetic operations does diff --git a/src/operations/blas3/gemm_local_joint_matrix.hpp b/src/operations/blas3/gemm_local_joint_matrix.hpp index 6953838f9..2473ae562 100644 --- a/src/operations/blas3/gemm_local_joint_matrix.hpp +++ b/src/operations/blas3/gemm_local_joint_matrix.hpp @@ -61,8 +61,7 @@ namespace blas { * level tiles to use, see Tile * @tparam TransA iff true, matrix A will be transposed on the fly * @tparam TransB iff true, matrix B will be transposed on the fly - * @tparam element_t type of matrix elements - * @tparam element_out_t type of output matrix elements (C) and scalars + * @tparam element_t type of scalar alpha & beta * @tparam is_beta_zero True if beta == 0. * @tparam VectorSize The packet size to be used for vectorization. * @tparam batch_type the type of batch strideded /interleaved @@ -71,17 +70,17 @@ namespace blas { */ template + bool SymmA, bool SymmB, typename element_t, bool is_beta_zero, + int VectorSize> class Gemm(gemm_memory_t::local), static_cast(gemm_algorithm_t::standard), static_cast(gemm_vectorization_t::none), VectorSize, static_cast(gemm_batch_type_t::strided), true> { public: using tile_type = TileType; - using value_t = element_t; + using value_t = typename input_t::value_t; using index_t = typename std::make_signed::type; using packetize_t = PacketizeJointMatrix; using address_t = cl::sycl::access::address_space; @@ -213,9 +212,8 @@ class Gemm::get_value() << "_" - << type_string::get_value() << "gemm_memory:local, " - << "gemm_algorithm:standard, " - << "gemm_vectorization:none, " + << type_string::get_value() << "gemm_memory:local, " + << "gemm_algorithm:standard, " << "gemm_vectorization:none, " << "vector size" << VectorSize << ", batch_type:strided> " << "with joint_matrix extension"; return str.str(); diff --git a/src/operations/blas3/gemm_no_local_full_vec.hpp b/src/operations/blas3/gemm_no_local_full_vec.hpp index 92fb4317c..7ff2b2215 100644 --- a/src/operations/blas3/gemm_no_local_full_vec.hpp +++ b/src/operations/blas3/gemm_no_local_full_vec.hpp @@ -53,23 +53,22 @@ namespace blas { * level tiles to use, see Tile * @tparam TransA iff true, matrix A will be transposed on the fly * @tparam TransB iff true, matrix B will be transposed on the fly - * @tparam element_in_t type of input matrices elements (A, B) - * @tparam element_out_t type of output matrix elements (C) and scalars + * @tparam element_t type of scalar alpha & beta * @tparam UseJointMatrix boolean parameter to decide whether to use * joint_matrix or not */ template + bool SymmA, bool SymmB, typename element_t, bool is_beta_zero, + int VectorSize> class Gemm(gemm_memory_t::no_local), + TransA, TransB, SymmA, SymmB, element_t, is_beta_zero, + static_cast(gemm_memory_t::no_local), static_cast(gemm_algorithm_t::standard), static_cast(gemm_vectorization_t::full), VectorSize, static_cast(gemm_batch_type_t::strided), false> { public: - using value_t = element_in_t; + using value_t = typename input_t::value_t; using index_t = typename std::make_signed::type; using address_t = cl::sycl::access::address_space; using packetize_t = Packetize; @@ -108,23 +107,23 @@ class Gemm::value) || - !is_complex_sycl::value, + static_assert((VectorSize == 1 && is_complex_sycl::value) || + !is_complex_sycl::value, "Vector size should be equal to 1 for Complex Data types"); #endif input_t a_; input_t b_; output_t c_; - const element_out_t alpha_; - const element_out_t beta_; + const element_t alpha_; + const element_t beta_; index_t batch_size_; index_t stridea_; index_t strideb_; index_t stridec_; - PORTBLAS_INLINE Gemm(input_t A, input_t B, output_t C, element_out_t alpha, - element_out_t beta, index_t batch_size, index_t stride_a, + PORTBLAS_INLINE Gemm(input_t A, input_t B, output_t C, element_t alpha, + element_t beta, index_t batch_size, index_t stride_a, index_t stride_b, index_t stride_c) : a_(A), b_(B), @@ -143,10 +142,9 @@ class Gemm::get_value() << "_" - << type_string::get_value() << "gemm_memory:no_local, " - << "gemm_algorithm:standard, " - << "gemm_vectorization:full, " + << type_string::get_value() << "_" + << type_string::get_value() << "gemm_memory:no_local, " + << "gemm_algorithm:standard, " << "gemm_vectorization:full, " << "vector size" << VectorSize << ", batch_type:strided>"; return str.str(); } @@ -317,7 +315,7 @@ class Gemm PORTBLAS_INLINE typename std::enable_if::type scaling_c( - element_out_t *reg_res, InputPointerType C, const index_t &ldc, + element_t *reg_res, InputPointerType C, const index_t &ldc, const index_t &dim_m_c_start, const index_t &dim_n_c_start, CheckBoundaryType check_boundary, bool out_of_range) { if (out_of_range) { @@ -329,18 +327,17 @@ class Gemm(check_boundary( dim_m_c_start + j * wg_rows, dim_n_c_start + i * wg_cols))) { - using l_vector_t = typename Packetize::PacketType; + using l_vector_t = + typename Packetize::PacketType; l_vector_t out_vec{}; out_vec.template load( - 0, - cl::sycl::multi_ptr( - C + j * wg_rows * packet_size)); + 0, cl::sycl::multi_ptr( + C + j * wg_rows * packet_size)); out_vec *= beta_; out_vec.template store( - 0, cl::sycl::multi_ptr( + 0, cl::sycl::multi_ptr( reg_res + i * item_rows + j * packet_size)); } } @@ -352,8 +349,8 @@ class Gemm PORTBLAS_INLINE typename std::enable_if::type scaling_c( - element_out_t *reg_res, InputPointerType, const index_t &, - const index_t &, const index_t &, CheckBoundaryType, bool) { + element_t *reg_res, InputPointerType, const index_t &, const index_t &, + const index_t &, CheckBoundaryType, bool) { #pragma unroll for (index_t i = 0; i < item_cols * item_rows; ++i) { reg_res[i] = 0; @@ -384,7 +381,7 @@ class Gemm( reg_res, C, ldc, dim_m_a_start, dim_n_b_start, boundary_check_c, out_of_range); @@ -470,8 +467,8 @@ class Gemm PORTBLAS_INLINE typename std::enable_if::type load_block_a( - PointerType ptr, element_in_t *reg, const index_t &ptr_next, - const index_t &ld, const RowCheckType &is_valid_row, - const ColCheckType &is_valid_col, const bool out_of_range) noexcept { + PointerType ptr, value_t *reg, const index_t &ptr_next, const index_t &ld, + const RowCheckType &is_valid_row, const ColCheckType &is_valid_col, + const bool out_of_range) noexcept { if (out_of_range) { return; } @@ -566,22 +563,21 @@ class Gemm(is_valid_row(j * ptr_next + work_per_load - 1)); - using l_vector_t = typename Packetize::PacketType; + using l_vector_t = + typename Packetize::PacketType; l_vector_t in_vec{}; if (in_range) { // if in range perform a vectorised load in_vec.template load( - 0, - cl::sycl::multi_ptr( - ptr + i * ld + j * ptr_next)); + 0, cl::sycl::multi_ptr( + ptr + i * ld + j * ptr_next)); } else { // if not in range perform element-wise load checking boundaries at // each load. #pragma unroll for (int l = 0; l < work_per_load; l++) { if (do_check(is_valid_row(j * ptr_next + l))) { - reinterpret_cast(&in_vec)[l] = + reinterpret_cast(&in_vec)[l] = *(ptr + i * ld + j * ptr_next + l); } else { break; @@ -590,8 +586,7 @@ class Gemm( - 0, cl::sycl::multi_ptr( - out_reg)); + 0, cl::sycl::multi_ptr(out_reg)); } } } @@ -632,7 +627,7 @@ class Gemm PORTBLAS_INLINE typename std::enable_if::type load_block_a( - PointerType ptr, element_in_t *reg, const index_t &, const index_t &ld, + PointerType ptr, value_t *reg, const index_t &, const index_t &ld, const RowCheckType &is_valid_row, const ColCheckType &is_valid_col, const bool out_of_range) noexcept { if (out_of_range) { @@ -647,15 +642,14 @@ class Gemm(is_valid_col(work_per_load - 1)); - using l_vector_t = typename Packetize::PacketType; + using l_vector_t = + typename Packetize::PacketType; l_vector_t in_vec{}; if (in_range) { // if in range perform a vectorised load in_vec.template load( - 0, - cl::sycl::multi_ptr( - ptr + (i * next_element + j) * ld)); + 0, cl::sycl::multi_ptr( + ptr + (i * next_element + j) * ld)); } else { // if not in range perform element-wise load checking boundaries at @@ -663,7 +657,7 @@ class Gemm(is_valid_col(l))) { - reinterpret_cast(&in_vec)[l] = + reinterpret_cast(&in_vec)[l] = *(ptr + (i * next_element + j) * ld + l); } else { break; @@ -674,7 +668,7 @@ class Gemm(&in_vec)[k]; + reinterpret_cast(&in_vec)[k]; } } } @@ -714,9 +708,9 @@ class Gemm PORTBLAS_INLINE typename std::enable_if::type load_single_b( - PointerType ptr, element_in_t *reg, const index_t &, - const index_t &col_ofs, const RowCheckType &is_valid_row, - const ColCheckType &is_valid_col, const bool out_of_range) noexcept { + PointerType ptr, value_t *reg, const index_t &, const index_t &col_ofs, + const RowCheckType &is_valid_row, const ColCheckType &is_valid_col, + const bool out_of_range) noexcept { if (out_of_range) { return; } @@ -726,25 +720,24 @@ class Gemm(is_valid_col(col_ofs)); using l_vector_t = - typename Packetize::PacketType; + typename Packetize::PacketType; l_vector_t in_vec{}; if (in_range) { // If in range perform a vectorised load. in_vec.template load( - 0, cl::sycl::multi_ptr( - ptr)); + 0, cl::sycl::multi_ptr(ptr)); } else { // Otherwise perform an element-wise load, checking boundaries each load. #pragma unroll for (int k = 0; k < work_per_load; k++) { if (do_check(is_valid_row(k)) && do_check(is_valid_col(col_ofs))) { - reinterpret_cast(&in_vec)[k] = *(ptr + k); + reinterpret_cast(&in_vec)[k] = *(ptr + k); } } } in_vec.template store( - 0, cl::sycl::multi_ptr(reg)); + 0, cl::sycl::multi_ptr(reg)); } /*! @@ -779,9 +772,9 @@ class Gemm PORTBLAS_INLINE typename std::enable_if::type load_single_b( - PointerType ptr, element_in_t *reg, const index_t &row_ofs, - const index_t &, const RowCheckType &is_valid_row, - const ColCheckType &is_valid_col, const bool out_of_range) noexcept { + PointerType ptr, value_t *reg, const index_t &row_ofs, const index_t &, + const RowCheckType &is_valid_row, const ColCheckType &is_valid_col, + const bool out_of_range) noexcept { if (out_of_range) { return; } @@ -791,25 +784,24 @@ class Gemm(is_valid_col(work_per_load - 1)); using l_vector_t = - typename Packetize::PacketType; + typename Packetize::PacketType; l_vector_t in_vec{}; if (in_range) { // If in range perform a vectorised load. in_vec.template load( - 0, cl::sycl::multi_ptr( - ptr)); + 0, cl::sycl::multi_ptr(ptr)); } else { // Otherwise perform an element-wise load, checking boundaries each load. #pragma unroll for (int k = 0; k < work_per_load; k++) { if (do_check(is_valid_row(row_ofs)) && do_check(is_valid_col(k))) { - reinterpret_cast(&in_vec)[k] = *(ptr + k); + reinterpret_cast(&in_vec)[k] = *(ptr + k); } } } in_vec.template store( - 0, cl::sycl::multi_ptr(reg)); + 0, cl::sycl::multi_ptr(reg)); } /*! * @brief The following function computes the partial GEMM for the input @@ -825,9 +817,8 @@ class Gemm PORTBLAS_INLINE typename std::enable_if::type - compute_block_gemm_no_shared(index_t iteration, element_in_t *reg_a, - element_in_t *reg_b, - element_out_t *reg_res) noexcept { + compute_block_gemm_no_shared(index_t iteration, value_t *reg_a, + value_t *reg_b, element_t *reg_res) noexcept { reg_res += iteration * item_rows; #pragma unroll for (int k = 0; k < packet_size; k++) { @@ -854,9 +845,8 @@ class Gemm PORTBLAS_INLINE typename std::enable_if<(packet_size != 1 && trans)>::type - compute_block_gemm_no_shared(index_t iteration, element_in_t *reg_a, - element_in_t *reg_b, - element_out_t *reg_res) noexcept { + compute_block_gemm_no_shared(index_t iteration, value_t *reg_a, + value_t *reg_b, element_t *reg_res) noexcept { reg_a += iteration * item_rows; #pragma unroll for (int i = 0; i < packet_size; i++) { @@ -881,9 +871,8 @@ class Gemm PORTBLAS_INLINE typename std::enable_if<(packet_size == 1 && trans)>::type - compute_block_gemm_no_shared(index_t iteration, element_in_t *reg_a, - element_in_t *reg_b, - element_out_t *reg_res) noexcept { + compute_block_gemm_no_shared(index_t iteration, value_t *reg_a, + value_t *reg_b, element_t *reg_res) noexcept { reg_res += iteration * item_rows; #pragma unroll for (int j = 0; j < item_rows; j++) { @@ -913,7 +902,7 @@ class Gemm - PORTBLAS_INLINE void store(PointerType C, element_out_t *reg_res, + PORTBLAS_INLINE void store(PointerType C, element_t *reg_res, const index_t &dim_m_c_start, const index_t &dim_n_c_start, const check_boundary &chk_boundary, @@ -928,18 +917,17 @@ class Gemm(chk_boundary(dim_m_c_start + j * wg_rows, dim_n_c_start + i * wg_cols))) { - using l_vector_t = typename Packetize::PacketType; + using l_vector_t = + typename Packetize::PacketType; l_vector_t out_vec{}; out_vec.template load( - 0, cl::sycl::multi_ptr( + 0, cl::sycl::multi_ptr( reg_res + i * item_rows + j * packet_size)); out_vec *= alpha_; out_vec.template store( - 0, cl::sycl::multi_ptr( + 0, cl::sycl::multi_ptr( C + j * wg_rows * packet_size)); } } diff --git a/src/operations/blas3/gemm_no_local_partial_vec.hpp b/src/operations/blas3/gemm_no_local_partial_vec.hpp index ee91398bc..591ee604f 100644 --- a/src/operations/blas3/gemm_no_local_partial_vec.hpp +++ b/src/operations/blas3/gemm_no_local_partial_vec.hpp @@ -53,23 +53,22 @@ namespace blas { * kernel. * @tparam TransA if true, matrix A will be transposed on the fly * @tparam TransB if true, matrix B will be transposed on the fly - * @tparam element_in_t type of input matrices elements (A, B) - * @tparam element_out_t type of output matrix elements (C) and scalars + * @tparam element_t type of scalar alpha & beta * @tparam UseJointMatrix boolean parameter to decide whether to use * joint_matrix or not */ template + bool SymmA, bool SymmB, typename element_t, bool is_beta_zero, + int VectorSize> class Gemm(gemm_memory_t::no_local), + TransA, TransB, SymmA, SymmB, element_t, is_beta_zero, + static_cast(gemm_memory_t::no_local), static_cast(gemm_algorithm_t::standard), static_cast(gemm_vectorization_t::partial), VectorSize, static_cast(gemm_batch_type_t::strided), false> { public: - using value_t = element_in_t; + using value_t = typename input_t::value_t; using index_t = typename std::make_signed::type; using address_t = cl::sycl::access::address_space; using packetize_t = Packetize; @@ -105,23 +104,23 @@ class Gemm::value) || - !is_complex_sycl::value, + static_assert((VectorSize == 1 && is_complex_sycl::value) || + !is_complex_sycl::value, "Vector size should be equal to 1 for Complex Data types"); #endif input_t a_; input_t b_; output_t c_; - const element_out_t alpha_; - const element_out_t beta_; + const element_t alpha_; + const element_t beta_; index_t batch_size_; index_t stridea_; index_t strideb_; index_t stridec_; - PORTBLAS_INLINE Gemm(input_t A, input_t B, output_t C, element_out_t alpha, - element_out_t beta, index_t batch_size, index_t stride_a, + PORTBLAS_INLINE Gemm(input_t A, input_t B, output_t C, element_t alpha, + element_t beta, index_t batch_size, index_t stride_a, index_t stride_b, index_t stride_c) : a_(A), b_(B), @@ -140,10 +139,9 @@ class Gemm::get_value() << "_" - << type_string::get_value() << "gemm_memory:no_local, " - << "gemm_algorithm:standard, " - << "gemm_vectorization:partial, " + << type_string::get_value() << "_" + << type_string::get_value() << "gemm_memory:no_local, " + << "gemm_algorithm:standard, " << "gemm_vectorization:partial, " << "vector size" << VectorSize << ", batch_type:strided>"; return str.str(); } @@ -313,7 +311,7 @@ class Gemm PORTBLAS_INLINE typename std::enable_if::type scaling_c( - element_out_t *reg_res, InputPointerType C, const index_t &ldc, + element_t *reg_res, InputPointerType C, const index_t &ldc, const index_t &dim_m_c_start, const index_t &dim_n_c_start, CheckBoundaryType check_boundary, bool out_of_range) { if (out_of_range) { @@ -340,8 +338,8 @@ class Gemm PORTBLAS_INLINE typename std::enable_if::type scaling_c( - element_out_t *reg_res, InputPointerType, const index_t &, - const index_t &, const index_t &, CheckBoundaryType, bool) { + element_t *reg_res, InputPointerType, const index_t &, const index_t &, + const index_t &, CheckBoundaryType, bool) { #pragma unroll for (index_t i = 0; i < item_cols * item_rows; ++i) { reg_res[i] = 0; @@ -359,8 +357,8 @@ class Gemm( reg_res, C, ldc, dim_m_a_start, dim_n_b_start, boundary_check_c, out_of_range); @@ -456,9 +454,8 @@ class Gemm - PORTBLAS_INLINE void load(PointerType ptr, element_in_t *reg, - const index_t &ld, index_t index, - const check_boundary &chk_boundary, + PORTBLAS_INLINE void load(PointerType ptr, value_t *reg, const index_t &ld, + index_t index, const check_boundary &chk_boundary, const bool out_of_range) noexcept { if (out_of_range) { return; @@ -472,15 +469,15 @@ class Gemm(chk_boundary(index + (work_per_load - 1))); using l_vector_t = - typename Packetize::PacketType; + typename Packetize::PacketType; l_vector_t in_vec{0}; if (in_range) { in_vec.template load( - 0, cl::sycl::multi_ptr( - ptr)); + 0, + cl::sycl::multi_ptr(ptr)); } in_vec.template store( - 0, cl::sycl::multi_ptr(reg)); + 0, cl::sycl::multi_ptr(reg)); // Move pointers and update index for next load ptr += ld; @@ -497,8 +494,7 @@ class Gemm PORTBLAS_INLINE typename std::enable_if::type store_packet( - element_out_t *reg, OutputPointerType out_ptr) { + element_t *reg, OutputPointerType out_ptr) { *out_ptr = alpha_ * (*reg); } template PORTBLAS_INLINE typename std::enable_if::type store_packet( - element_out_t *reg, OutputPointerType out_ptr) { + element_t *reg, OutputPointerType out_ptr) { using l_vector_t = - typename Packetize::PacketType; + typename Packetize::PacketType; l_vector_t out_vec{0}; out_vec.template load( - 0, cl::sycl::multi_ptr( - reg)); + 0, cl::sycl::multi_ptr(reg)); out_vec *= alpha_; out_vec.template store( - 0, - cl::sycl::multi_ptr(out_ptr)); + 0, cl::sycl::multi_ptr(out_ptr)); } /*! @@ -550,7 +544,7 @@ class Gemm - PORTBLAS_INLINE void store(PointerType C, element_out_t *reg_res, + PORTBLAS_INLINE void store(PointerType C, element_t *reg_res, const index_t &dim_m_c_start, const index_t &dim_n_c_start, const check_boundary &chk_boundary, @@ -565,18 +559,17 @@ class Gemm(chk_boundary(dim_m_c_start + j * wg_rows, dim_n_c_start + i * wg_cols))) { - using l_vector_t = typename Packetize::PacketType; + using l_vector_t = + typename Packetize::PacketType; l_vector_t out_vec{0}; out_vec.template load( - 0, cl::sycl::multi_ptr( + 0, cl::sycl::multi_ptr( reg_res + i * item_rows + j * a_packet_size)); out_vec *= alpha_; out_vec.template store( - 0, cl::sycl::multi_ptr( + 0, cl::sycl::multi_ptr( C + j * wg_rows * a_packet_size)); } } diff --git a/src/operations/blas3/gemm_partial_local.hpp b/src/operations/blas3/gemm_partial_local.hpp index 8d255f22c..4730eda4b 100644 --- a/src/operations/blas3/gemm_partial_local.hpp +++ b/src/operations/blas3/gemm_partial_local.hpp @@ -31,14 +31,13 @@ namespace blas { template + bool IsFinal, bool IsBetaZero, typename element_t> class GemmPartial(gemm_memory_t::local)> { + tile_type, TransA, TransB, IsFinal, IsBetaZero, element_t, + static_cast(gemm_memory_t::local)> { public: using index_t = typename std::make_signed::type; - using value_t = element_in_t; + using value_t = typename input_t::value_t; private: /* This structure holds information about the block loading pattern */ @@ -60,8 +59,8 @@ class GemmPartial(global_col_index + lpt * BlockPropertiesType::col_stride < global_cols); - element_in_t val = in_range - ? in_view.template eval(global_mem_index) - : element_in_t(0); + value_t val = + in_range ? in_view.template eval(global_mem_index) : value_t(0); local_ptr[local_mem_index + lpt * BlockPropertiesType::local_mem_increment] = val; diff --git a/src/operations/blas3/gemm_ref.hpp b/src/operations/blas3/gemm_ref.hpp index 992c3d943..8a7ad4f40 100644 --- a/src/operations/blas3/gemm_ref.hpp +++ b/src/operations/blas3/gemm_ref.hpp @@ -42,21 +42,18 @@ namespace blas { * @tparam WgSize the number of items in a work group * @tparam TransA iff true, A will be transposed on the fly * @tparam TransB iff true, B will be transposed on the fly - * @tparam element_in_t the type of input matrix elements (A, B) - * @tparam element_out_t the type of output matrix elements (C) and scalars + * @tparam element_t type of scalar alpha & beta */ template -PORTBLAS_INLINE Gemm:: - Gemm(input_t A, input_t B, output_t C, element_out_t alpha, - element_out_t beta, + bool SymmA, bool SymmB, typename element_t, bool is_beta_zero, + int GemmMemoryType, int GemmAlgorithm, int GemmVectorization, + int VectorSize, int BatchType, bool UseJointMatrix> +PORTBLAS_INLINE +Gemm:: + Gemm(input_t A, input_t B, output_t C, element_t alpha, element_t beta, typename std::make_signed::type batch_size, index_t stride_a, index_t stride_b, index_t stride_c) : a_(A), @@ -76,14 +73,13 @@ PORTBLAS_INLINE Gemm + bool SymmA, bool SymmB, typename element_t, bool is_beta_zero, + int GemmMemoryType, int GemmAlgorithm, int GemmVectorization, + int VectorSize, int BatchType, bool UseJointMatrix> PORTBLAS_INLINE std::string Gemm::get_type_string() noexcept { std::ostringstream str{}; str << "ReferenceGemmFactory<" << wg_size << ", " @@ -97,18 +93,17 @@ Gemm + bool SymmA, bool SymmB, typename element_t, bool is_beta_zero, + int GemmMemoryType, int GemmAlgorithm, int GemmVectorization, + int VectorSize, int BatchType, bool UseJointMatrix> PORTBLAS_INLINE typename Gemm::index_t + element_t, is_beta_zero, GemmMemoryType, + GemmAlgorithm, GemmVectorization, VectorSize, + BatchType, UseJointMatrix>::index_t Gemm::get_workgroup_cluster() const noexcept { return ((m_ * n_ - 1) / wg_size + 1); } @@ -121,97 +116,91 @@ Gemm + bool SymmA, bool SymmB, typename element_t, bool is_beta_zero, + int GemmMemoryType, int GemmAlgorithm, int GemmVectorization, + int VectorSize, int BatchType, bool UseJointMatrix> PORTBLAS_INLINE typename Gemm::index_t + element_t, is_beta_zero, GemmMemoryType, + GemmAlgorithm, GemmVectorization, VectorSize, + BatchType, UseJointMatrix>::index_t Gemm::get_num_workgroup_cluster(index_t compute_units) const noexcept { constexpr index_t num_gemm_per_compute_units = 4; return ( (num_gemm_per_compute_units * compute_units - 1) / Gemm::get_workgroup_cluster() + + TransA, TransB, SymmA, SymmB, element_t, is_beta_zero, + GemmMemoryType, GemmAlgorithm, GemmVectorization, VectorSize, + BatchType, UseJointMatrix>::get_workgroup_cluster() + 1); } template + bool SymmA, bool SymmB, typename element_t, bool is_beta_zero, + int GemmMemoryType, int GemmAlgorithm, int GemmVectorization, + int VectorSize, int BatchType, bool UseJointMatrix> PORTBLAS_INLINE cl::sycl::nd_range<1> Gemm::get_nd_range(index_t compute_units) const noexcept { const cl::sycl::range<1> nwg( Gemm::get_workgroup_cluster() * + TransA, TransB, SymmA, SymmB, element_t, is_beta_zero, + GemmMemoryType, GemmAlgorithm, GemmVectorization, VectorSize, + BatchType, UseJointMatrix>::get_workgroup_cluster() * Gemm::get_num_workgroup_cluster(compute_units)); const cl::sycl::range<1> wgs(wg_size); return cl::sycl::nd_range<1>(nwg * wgs, wgs); } - template + bool SymmA, bool SymmB, typename element_t, bool is_beta_zero, + int GemmMemoryType, int GemmAlgorithm, int GemmVectorization, + int VectorSize, int BatchType, bool UseJointMatrix> PORTBLAS_INLINE typename Gemm::index_t + element_t, is_beta_zero, GemmMemoryType, + GemmAlgorithm, GemmVectorization, VectorSize, + BatchType, UseJointMatrix>::index_t Gemm::get_size() const { return m_ * n_; } template + bool SymmA, bool SymmB, typename element_t, bool is_beta_zero, + int GemmMemoryType, int GemmAlgorithm, int GemmVectorization, + int VectorSize, int BatchType, bool UseJointMatrix> PORTBLAS_INLINE bool Gemm::valid_thread(const cl::sycl::nd_item<1>& ndItem) const { return true; } template + bool SymmA, bool SymmB, typename element_t, bool is_beta_zero, + int GemmMemoryType, int GemmAlgorithm, int GemmVectorization, + int VectorSize, int BatchType, bool UseJointMatrix> PORTBLAS_INLINE void Gemm::eval(cl::sycl::nd_item<1> id) noexcept { const index_t wg_batch_id = id.get_group(0) / get_workgroup_cluster(); // This will disable all workgroups that dont have any batch to work on @@ -272,14 +261,13 @@ Gemm + bool SymmA, bool SymmB, typename element_t, bool is_beta_zero, + int GemmMemoryType, int GemmAlgorithm, int GemmVectorization, + int VectorSize, int BatchType, bool UseJointMatrix> PORTBLAS_INLINE void Gemm::bind(cl::sycl::handler& h) { a_.bind(h); b_.bind(h); @@ -288,14 +276,13 @@ Gemm + bool SymmA, bool SymmB, typename element_t, bool is_beta_zero, + int GemmMemoryType, int GemmAlgorithm, int GemmVectorization, + int VectorSize, int BatchType, bool UseJointMatrix> PORTBLAS_INLINE void Gemm::adjust_access_displacement() { a_.adjust_access_displacement(); b_.adjust_access_displacement(); diff --git a/src/sb_handle/portblas_handle.hpp b/src/sb_handle/portblas_handle.hpp index 30f586555..f158de242 100644 --- a/src/sb_handle/portblas_handle.hpp +++ b/src/sb_handle/portblas_handle.hpp @@ -271,22 +271,20 @@ inline typename SB_Handle::event_t SB_Handle::execute( template + bool SymmA, bool SymmB, typename element_t, bool is_beta_zero, + int GemmMemoryType, int GemmAlgorithm, int GemmVectorization, + int VectorSize, int BatchType, bool UseJointMatrix> inline typename SB_Handle::event_t SB_Handle::execute( Gemm + TransB, SymmA, SymmB, element_t, is_beta_zero, GemmMemoryType, + GemmAlgorithm, GemmVectorization, VectorSize, BatchType, + UseJointMatrix> gemm_tree, const typename SB_Handle::event_t& dependencies) { - using gemm_t = - Gemm; + using gemm_t = Gemm; auto rng = gemm_tree.get_nd_range(SB_Handle::get_num_compute_units()); return {execute_tree< Choose(gemm_memory_t::local), int, @@ -298,14 +296,14 @@ inline typename SB_Handle::event_t SB_Handle::execute( /* Tall and skinny Gemm */ template + bool SymmA, bool SymmB, typename element_t, bool is_beta_zero, + int GemmMemoryType, int GemmVectorization, int VectorSize, + int BatchType> inline typename SB_Handle::event_t SB_Handle::execute( Gemm(gemm_algorithm_t::tall_skinny), - GemmVectorization, VectorSize, BatchType> + TransB, SymmA, SymmB, element_t, is_beta_zero, GemmMemoryType, + static_cast(gemm_algorithm_t::tall_skinny), GemmVectorization, + VectorSize, BatchType> gemm_wrapper, const typename SB_Handle::event_t& dependencies) { using index_t = typename std::make_signed::type; @@ -317,15 +315,14 @@ inline typename SB_Handle::event_t SB_Handle::execute( /* Depth of the cube buffer */ const index_t depth = GemmPartial< input_t, output_t, DoubleBuffer, NbcA, NbcB, ClSize, tile_type, TransA, - TransB, false, is_beta_zero, element_in_t, element_out_t, + TransB, false, is_beta_zero, element_t, GemmMemoryType>::get_ideal_cube_depth(SB_Handle::get_num_compute_units(), rows, cols, gemm_wrapper.k_); /* In some cases, use the tsgemm kernel as a normal gemm operation */ if (depth == 1 || gemm_wrapper.k_ <= 2048) { GemmPartial + TransA, TransB, true, is_beta_zero, element_t, GemmMemoryType> gemm_partial(gemm_wrapper.a_, gemm_wrapper.b_, gemm_wrapper.c_, gemm_wrapper.alpha_, gemm_wrapper.beta_, 1); auto events = execute(gemm_partial, dependencies); @@ -339,7 +336,7 @@ inline typename SB_Handle::event_t SB_Handle::execute( /* Create the cube buffer that will hold the output of the partial gemm */ auto cube_buffer = acquire_temp_mem < is_usm ? helper::AllocType::usm : helper::AllocType::buffer, - element_out_t > (rows * cols * depth); + element_t > (rows * cols * depth); /* Create a first matrix view used for the partial gemm */ auto cube_gemm = @@ -348,8 +345,7 @@ inline typename SB_Handle::event_t SB_Handle::execute( /* Note: we set is_beta_zero to true regardless of the value of beta * because this option is meant for use with a simple Gemm only */ GemmPartial + TransA, TransB, false, true, element_t, GemmMemoryType> gemm_partial(gemm_wrapper.a_, gemm_wrapper.b_, cube_gemm, gemm_wrapper.alpha_, gemm_wrapper.beta_, depth); auto events = execute(gemm_partial, dependencies); @@ -361,7 +357,7 @@ inline typename SB_Handle::event_t SB_Handle::execute( constexpr auto reductions_per_thread = 64; constexpr int work_group_size = tile_type::wg_rows * tile_type::wg_cols; using params_t = - blas::ReductionParams(reduction_dim_t::outer)>; /* Second step: reduction */ @@ -376,7 +372,7 @@ inline typename SB_Handle::event_t SB_Handle::execute( /* Create a temporary buffer to hold alpha * A * B */ auto temp_buffer = acquire_temp_mem < is_usm ? helper::AllocType::usm : helper::AllocType::buffer, - element_out_t > (rows * cols); + element_t > (rows * cols); auto temp = make_matrix_view(temp_buffer, rows, cols, rows); /* Execute the reduction */ @@ -409,12 +405,10 @@ inline typename SB_Handle::event_t SB_Handle::execute( /* GemmPartial */ template + bool IsFinal, bool IsBetaZero, typename element_t, int GemmMemoryType> inline typename SB_Handle::event_t SB_Handle::execute( GemmPartial + TransA, TransB, IsFinal, IsBetaZero, element_t, GemmMemoryType> gemm_partial, const typename SB_Handle::event_t& dependencies) { auto gemm_partial_range = diff --git a/tools/auto_tuner/include/tune_impl.hpp b/tools/auto_tuner/include/tune_impl.hpp index 74b740f80..f456ff9df 100644 --- a/tools/auto_tuner/include/tune_impl.hpp +++ b/tools/auto_tuner/include/tune_impl.hpp @@ -36,7 +36,7 @@ template a) { using Gemm = ::blas::Gemm< MatrixContainer, MatrixContainer, DoubleBuffer, Nbca, Nbcb, Cls, - Tile, Config::TransA, Config::TransB, Config::SymmA, Config::SymmB, T, T, + Tile, Config::TransA, Config::TransB, Config::SymmA, Config::SymmB, T, false, static_cast(Config::MemoryMode), static_cast(Config::ShapeMode), static_cast(Config::VecType), VecSize, static_cast(Config::BatchType)>;