diff --git a/CMakeLists.txt b/CMakeLists.txt index 0b0143e64..e0d418c57 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -212,6 +212,7 @@ export(EXPORT portblas option(BLAS_ENABLE_TESTING "Whether to enable testing" ON) option(ENABLE_EXPRESSION_TESTS "Whether to build expression tree fusion tests" OFF) +option(ENABLE_JOINTMATRIX_TESTS "Whether to build joint_matrix GEMM tests" OFF) if (INSTALL_HEADER_ONLY AND BLAS_ENABLE_TESTING) message(STATUS "Tests are disabled when installing portBLAS in header only mode") set(BLAS_ENABLE_TESTING OFF) diff --git a/README.md b/README.md index 5186b98d0..fb210470d 100644 --- a/README.md +++ b/README.md @@ -458,6 +458,7 @@ Some of the supported options are: | `CMAKE_INSTALL_PREFIX` | path | Specify the install location, used when invoking `ninja install` | | `BUILD_SHARED_LIBS` | `ON`/`OFF` | Build as shared library (`ON` by default) | | `ENABLE_EXPRESSION_TESTS` | `ON`/`OFF` | Build additional tests that use the header-only framework (e.g to test expression trees); `OFF` by default | +| `ENABLE_JOINTMATRIX_TESTS` | `ON`/`OFF` | Build additional tests that use joint_matrix extension; `OFF` by default | | `BLAS_VERIFY_BENCHMARK` | `ON`/`OFF` | Verify the results of the benchmarks instead of only measuring the performance. See the documentation of the benchmarks for more details. `ON` by default | | `BLAS_MEMPOOL_BENCHMARK` | `ON`/`OFF` | Determines whether to enable the scratchpad memory pool for benchmark execution. `OFF` by default | | `BLAS_ENABLE_CONST_INPUT` | `ON`/`OFF` | Determines whether to enable kernel instantiation with const input buffer (`ON` by default) | diff --git a/benchmark/portblas/blas3/trsm.cpp b/benchmark/portblas/blas3/trsm.cpp index 0afd170ec..8d28cec4c 100644 --- a/benchmark/portblas/blas3/trsm.cpp +++ b/benchmark/portblas/blas3/trsm.cpp @@ -97,7 +97,13 @@ void run(benchmark::State& state, blas::SB_Handle* sb_handle_ptr, char side, } std::ostringstream err_stream; - if (!utils::compare_vectors(b_temp, x_ref, err_stream, "")) { + const char* en_joint_matrix = std::getenv("SB_ENABLE_JOINT_MATRIX"); + if (!utils::compare_vectors(b_temp, x_ref, err_stream, "", + (en_joint_matrix != NULL) && + (std::is_same::value) && + (*en_joint_matrix == '1') + ? 2 + : 1)) { const std::string& err_str = err_stream.str(); state.SkipWithError(err_str.c_str()); *success = false; @@ -181,8 +187,8 @@ void register_benchmark(blas::SB_Handle* sb_handle_ptr, bool* success, }; benchmark::RegisterBenchmark( blas_benchmark::utils::get_name( - side, uplo, trans, diag, m, n, - mem_type).c_str(), + side, uplo, trans, diag, m, n, mem_type) + .c_str(), BM_lambda, sb_handle_ptr, side, uplo, trans, diag, m, n, alpha, success) ->UseRealTime(); } @@ -193,7 +199,8 @@ void register_benchmark(blas_benchmark::Args& args, blas::SB_Handle* sb_handle_ptr, bool* success) { auto trsm_params = blas_benchmark::utils::get_trsm_params(args); register_benchmark( - sb_handle_ptr, success, blas_benchmark::utils::MEM_TYPE_BUFFER, trsm_params); + sb_handle_ptr, success, blas_benchmark::utils::MEM_TYPE_BUFFER, + trsm_params); #ifdef SB_ENABLE_USM register_benchmark( sb_handle_ptr, success, blas_benchmark::utils::MEM_TYPE_USM, trsm_params); @@ -201,8 +208,8 @@ void register_benchmark(blas_benchmark::Args& args, } namespace blas_benchmark { -void create_benchmark(blas_benchmark::Args& args, blas::SB_Handle* sb_handle_ptr, - bool* success) { +void create_benchmark(blas_benchmark::Args& args, + blas::SB_Handle* sb_handle_ptr, bool* success) { BLAS_REGISTER_BENCHMARK(args, sb_handle_ptr, success); } } // namespace blas_benchmark diff --git a/common/include/common/float_comparison.hpp b/common/include/common/float_comparison.hpp index 3cfc0885e..5c634279e 100644 --- a/common/include/common/float_comparison.hpp +++ b/common/include/common/float_comparison.hpp @@ -115,17 +115,20 @@ scalar_t clamp_to_limits(scalar_t v) { * Indicates the tolerated margin for relative differences */ template -inline scalar_t getRelativeErrorMargin() { +inline scalar_t getRelativeErrorMargin(const int32_t margin_multiplier = 1) { /* Measured empirically with gemm. The dimensions of the matrices (even k) * don't seem to have an impact on the observed relative differences * In the cases where the relative error is relevant (non close to zero), * relative differences of up to 0.002 were observed for float */ - return static_cast(0.005); + scalar_t margin = 0.005; + // increase error margin for mixed precision calculation + // for trsm operator. + return margin * margin_multiplier; } template <> -inline double getRelativeErrorMargin() { +inline double getRelativeErrorMargin(const int32_t) { /* Measured empirically with gemm. The dimensions of the matrices (even k) * don't seem to have an impact on the observed relative differences * In the cases where the relative error is relevant (non close to zero), @@ -135,7 +138,7 @@ inline double getRelativeErrorMargin() { } template <> -inline cl::sycl::half getRelativeErrorMargin() { +inline cl::sycl::half getRelativeErrorMargin(const int32_t) { // Measured empirically with gemm return 0.05f; } @@ -145,16 +148,19 @@ inline cl::sycl::half getRelativeErrorMargin() { * scalars are close to 0) */ template -inline scalar_t getAbsoluteErrorMargin() { +inline scalar_t getAbsoluteErrorMargin(const int32_t margin_multiplier = 1) { /* Measured empirically with gemm. * In the cases where the relative error is irrelevant (close to zero), * absolute differences of up to 0.0006 were observed for float */ - return 0.001f; + scalar_t margin = 0.001f; + // increase error margin for mixed precision calculation + // for trsm operator. + return margin * margin_multiplier; } template <> -inline double getAbsoluteErrorMargin() { +inline double getAbsoluteErrorMargin(const int32_t) { /* Measured empirically with gemm. * In the cases where the relative error is irrelevant (close to zero), * absolute differences of up to 10^-12 were observed for double @@ -163,7 +169,7 @@ inline double getAbsoluteErrorMargin() { } template <> -inline cl::sycl::half getAbsoluteErrorMargin() { +inline cl::sycl::half getAbsoluteErrorMargin(const int32_t) { // Measured empirically with gemm. return 1.0f; } @@ -172,7 +178,8 @@ inline cl::sycl::half getAbsoluteErrorMargin() { * Compare two scalars and returns false if the difference is not acceptable. */ template -inline bool almost_equal(scalar_t const& scalar1, scalar_t const& scalar2) { +inline bool almost_equal(scalar_t const& scalar1, scalar_t const& scalar2, + const int32_t margin_multiplier = 1) { // Shortcut, also handles case where both are zero if (scalar1 == scalar2) { return true; @@ -187,12 +194,14 @@ inline bool almost_equal(scalar_t const& scalar1, scalar_t const& scalar2) { // Close to zero, the relative error doesn't work, use absolute error if (scalar1 == scalar_t{0} || scalar2 == scalar_t{0} || - absolute_diff < getAbsoluteErrorMargin()) { - return (absolute_diff < getAbsoluteErrorMargin()); + absolute_diff < getAbsoluteErrorMargin(margin_multiplier)) { + return (absolute_diff < + getAbsoluteErrorMargin(margin_multiplier)); } // Use relative error const auto absolute_sum = utils::abs(scalar1) + utils::abs(scalar2); - return (absolute_diff / absolute_sum) < getRelativeErrorMargin(); + return (absolute_diff / absolute_sum) < + getRelativeErrorMargin(margin_multiplier); } /** @@ -206,7 +215,8 @@ template inline bool compare_vectors(std::vector const& vec, std::vector const& ref, std::ostream& err_stream = std::cerr, - std::string end_line = "\n") { + std::string end_line = "\n", + const int32_t margin_multiplier = 1) { if (vec.size() != ref.size()) { err_stream << "Error: tried to compare vectors of different sizes" << std::endl; @@ -214,7 +224,7 @@ inline bool compare_vectors(std::vector const& vec, } for (int i = 0; i < vec.size(); ++i) { - if (!almost_equal(vec[i], ref[i])) { + if (!almost_equal(vec[i], ref[i], margin_multiplier)) { err_stream << "Value mismatch at index " << i << ": " << vec[i] << "; expected " << ref[i] << end_line; return false; diff --git a/src/operations/blas3/gemm_load_store_joint_matrix.hpp b/src/operations/blas3/gemm_load_store_joint_matrix.hpp index 81eb0625e..876817158 100644 --- a/src/operations/blas3/gemm_load_store_joint_matrix.hpp +++ b/src/operations/blas3/gemm_load_store_joint_matrix.hpp @@ -57,18 +57,16 @@ struct PacketizeJointMatrix { /*! @brief Performs a coalesced non-vectorized load when the current block is * not internal. - * @tparam trans Whether the source matrix is transposed or not. * @tparam internal True if the current block is internal and no bounds * checking is required. - * @tparam ld The leading dimension of the destination memory. */ - template + template static PORTBLAS_INLINE typename std::enable_if::type load( const bool in_range, SrcPointerType src, DestPointerType dest, EdgePredicate) { - value_t val = in_range ? *(src) : value_t{0}; + value_t val = in_range ? *src : value_t{0}; using address_t = cl::sycl::access::address_space; if constexpr (std::is_same, @@ -79,93 +77,96 @@ struct PacketizeJointMatrix { cl::sycl::ext::oneapi::bfloat16, address_t::local_space>, DestPointerType>::value) { - using dtype = cl::sycl::ext::oneapi::bfloat16; - *dest = static_cast(val); + using namespace cl::sycl::ext::oneapi; + *dest = bfloat16(val); } else { using namespace cl::sycl::ext::oneapi::experimental::matrix; *dest = round_to_tf32(val); } } + /*! @brief Performs a vectorised load using sycl::vec::load when the current * block is internal. In the case where k < the * number of elements being loaded then edge loads will be element wise with * additional bounds checking. - * @tparam trans Whether the source matrix is transposed or not. * @tparam internal True if the current block is internal and no bounds * checking is required. - * @tparam ld The leading dimension of the destination memory. */ - template + */ + template static PORTBLAS_INLINE typename std::enable_if::type load( const bool in_range, SrcPointerType src, DestPointerType dest, EdgePredicate edge_in_range) { PacketType packet{}; + using address_t = cl::sycl::access::address_space; if (in_range) { - using address_t = cl::sycl::access::address_space; packet.template load( 0, cl::sycl::multi_ptr(src)); + store(packet, dest); } else { + // avoid writing to variable, instead directly write to + // shared local memory to avoid race condition experienced + // with release compiler. #pragma unroll - for (index_t i = 0; i < packet_size; i++) { - reinterpret_cast(&packet)[i] = - edge_in_range(i) ? *(src + i) : value_t{0}; - } - } - store(packet, dest); - } - /*! @brief Store a vector packet into local memory when the source is - * transposed. This will untranspose the elements individually when storing so - * the data in local memory is always consistent. - * @tparam trans Whether the source matrix is transposed or not. - * @tparam ld The leading dimension of the destination memory.*/ - template - static PORTBLAS_INLINE typename std::enable_if::type store( - PacketType &packet, DestPointerType dest) { - using address_t = cl::sycl::access::address_space; -#pragma unroll - for (index_t i = 0; i < packet_size; i++) { - value_t val = reinterpret_cast(&packet)[i]; - if constexpr (std::is_same, - DestPointerType>::value) { - using dtype = cl::sycl::half; - *(dest + ld * i) = static_cast(val); - } else if constexpr (std::is_same, - DestPointerType>::value) { - using dtype = cl::sycl::ext::oneapi::bfloat16; - *(dest + ld * i) = static_cast(val); - } else { - using namespace cl::sycl::ext::oneapi::experimental::matrix; - *(dest + ld * i) = round_to_tf32(val); + for (index_t i = 0; i < packet_size; i++, dest++, src++) { + if constexpr (std::is_same, + DestPointerType>::value) { + using dtype = cl::sycl::half; + *dest = static_cast(edge_in_range(i) ? *src : 0); + } else if constexpr (std::is_same, + DestPointerType>::value) { + using namespace cl::sycl::ext::oneapi; + *dest = bfloat16(edge_in_range(i) ? *src : 0.f); + } else { + using namespace cl::sycl::ext::oneapi::experimental::matrix; + *dest = edge_in_range(i) ? round_to_tf32(*src) : 0.f; + } } } } - /*! @brief Store a vector packet into local memory when the source is not - * transposed. This will use sycl::vec::store function. - * @tparam trans Whether the source matrix is transposed or not. - * @tparam ld The leading dimension of the destination memory.*/ - template - static PORTBLAS_INLINE typename std::enable_if::type store( - PacketType &packet, DestPointerType dest) { + /*! @brief Store a vector packet into local memory. This will use + * sycl::vec::store function. + */ + template + static PORTBLAS_INLINE void store(PacketType &packet, DestPointerType dest) { using address_t = cl::sycl::access::address_space; if constexpr (std::is_same, DestPointerType>::value) { using dtype = cl::sycl::half; - *dest = static_cast(packet[0]); + cl::sycl::vec new_vec{}; + for (index_t i = 0; i < packet_size; i++) { + reinterpret_cast(&new_vec)[i] = + static_cast(reinterpret_cast(&packet)[i]); + } + new_vec.template store( + 0, cl::sycl::multi_ptr(dest)); } else if constexpr (std::is_same, DestPointerType>::value) { - using dtype = cl::sycl::ext::oneapi::bfloat16; - *dest = static_cast(packet[0]); + // sycl::vec doesn't accept bfloat16 as a valid input type + // so we need to write the packet elements individually to + // the shared memory. + using namespace cl::sycl::ext::oneapi; + for (index_t i = 0; i < packet_size; i++, dest++) { + *dest = bfloat16(reinterpret_cast(&packet)[i]); + } } else { using namespace cl::sycl::ext::oneapi::experimental::matrix; - *dest = round_to_tf32(packet[0]); + using dtype = float; + cl::sycl::vec new_vec; + for (index_t i = 0; i < packet_size; i++) { + reinterpret_cast(&new_vec)[i] = + round_to_tf32(reinterpret_cast(&packet)[i]); + } + new_vec.template store( + 0, cl::sycl::multi_ptr(dest)); } } }; diff --git a/src/operations/blas3/gemm_local_joint_matrix.hpp b/src/operations/blas3/gemm_local_joint_matrix.hpp index 664eed416..440229fef 100644 --- a/src/operations/blas3/gemm_local_joint_matrix.hpp +++ b/src/operations/blas3/gemm_local_joint_matrix.hpp @@ -83,7 +83,6 @@ class Gemm::type; using packetize_t = PacketizeJointMatrix; - using vector_t = typename packetize_t::PacketType; using address_t = cl::sycl::access::address_space; // enable easier access to tile dimensions @@ -156,20 +155,31 @@ class Gemm::value, "This code is only supported for float data type."); + static_assert(VectorSize == 1, + "Vectorization not supported for joint_matrix."); + + static_assert( + (frags_per_sg > 1 && jm_row_frags == num_sub_groups) || + (frags_per_sg == 1 && num_jm_frags == num_sub_groups), + "Joint Matrix Row Fragments needs to map 1:1 with total sub_groups."); + //! @brief leading dimension of block of A in local static constexpr index_t ldsa = - tile_type::joint_matrix_M + - nbc_a * tile_type::joint_matrix_K / sizeof(float) * 2; + (trans_a ? cl_elems : block_rows) + + nbc_a * tile_type::joint_matrix_M / sizeof(float) * 2; //! @brief leading dimension of block of B in local static constexpr index_t ldsb = - tile_type::joint_matrix_K + + (trans_b ? block_cols : cl_elems) + nbc_b * tile_type::joint_matrix_K / sizeof(float) * 2; - //! @brief size (in elements) of local (local) memory required by each + //! @brief leading dimension of block of output C in local + static constexpr index_t ldsc = block_rows + (nbc_a | nbc_b) * + tile_type::joint_matrix_M / + sizeof(float) * 2; + //! @brief size (in elements) of local memory required by each // work group static constexpr index_t local_memory_size = - (double_buffer + 1) * - (ldsa * cl_elems * (block_rows / tile_type::joint_matrix_M) + - ldsb * block_cols * (cl_elems / tile_type::joint_matrix_K)); + (double_buffer + 1) * (ldsa * (trans_a ? block_rows : cl_elems) + + ldsb * (trans_b ? cl_elems : block_cols)); input_t a_; input_t b_; @@ -182,8 +192,8 @@ class Gemm PORTBLAS_INLINE void eval(local_memory_t scratch_acc, - const cl::sycl::nd_item<1> &id) noexcept { + const cl::sycl::nd_item<1> &id) noexcept { index_t m = a_.get_size_row(); index_t n = b_.get_size_col(); index_t k = a_.get_size_col(); @@ -270,10 +280,12 @@ class Gemm(id.get_group(0)); // The batch index that each workgroup should start working with const index_t x_groups = (get_wg_x_cluster() - 1) / jm_row_frags + 1; const index_t y_groups = (get_wg_y_cluster() - 1) / jm_col_frags + 1; - const index_t wg_batch_id = id.get_group(0) / (x_groups * y_groups); + const index_t wg_batch_id = wg_id / (x_groups * y_groups); // This will disable all workgroups that dont have any batch to work on if (wg_batch_id >= batch_size_) { return; @@ -283,20 +295,12 @@ class Gemm - (a_.get_pointer()) + (wg_batch_id * stridea_); - auto ptr_B = cl::sycl::multi_ptr - (b_.get_pointer()) + (wg_batch_id * strideb_); - auto ptr_C = cl::sycl::multi_ptr - (c_.get_pointer()) + (wg_batch_id * stridec_); + auto ptr_A = a_.get_pointer() + wg_batch_id * stridea_; + auto ptr_B = b_.get_pointer() + wg_batch_id * strideb_; + auto ptr_C = c_.get_pointer() + wg_batch_id * stridec_; auto sg = id.get_sub_group(); const index_t sg_id = sg.get_group_linear_id(); @@ -307,9 +311,7 @@ class Gemm= m || wg_col >= n); const bool internal = m - wg_row >= block_rows && n - wg_col >= block_cols; - ptr_C += - (wg_row + (sg_id % jm_row_frags) * tile_type::joint_matrix_M) + - (wg_col + (sg_id / jm_row_frags) * tile_type::joint_matrix_N) * ldc; + ptr_C += (wg_row + wg_col * ldc); const index_t mc = m - wg_row; const index_t nc = n - wg_col; @@ -323,8 +325,7 @@ class Gemm( - id, item_id, m, n, k, mc, nc, a_size, b_size, c_size, ptr_A, lda, - ptr_B, ldb, ptr_C, ldc, scratch, s1, s2, s3, s4, out_of_range, - batch_stride, wg_batch_id, batch_size_); + id, item_id, m, n, k, mc, nc, ptr_A, lda, ptr_B, ldb, ptr_C, ldc, + scratch, s1, s2, s3, s4, out_of_range, batch_stride, wg_batch_id, + batch_size_); } else { compute_panel_gemm( - id, item_id, m, n, k, mc, nc, a_size, b_size, c_size, ptr_A, lda, - ptr_B, ldb, ptr_C, ldc, scratch, s1, s2, s3, s4, out_of_range, - batch_stride, wg_batch_id, batch_size_); + id, item_id, m, n, k, mc, nc, ptr_A, lda, ptr_B, ldb, ptr_C, ldc, + scratch, s1, s2, s3, s4, out_of_range, batch_stride, wg_batch_id, + batch_size_); } } else { auto input_scratch = *reinterpret_cast( - id, item_id, m, n, k, mc, nc, a_size, b_size, c_size, ptr_A, lda, - ptr_B, ldb, ptr_C, ldc, scratch, s1, s2, s3, s4, out_of_range, - batch_stride, wg_batch_id, batch_size_); + id, item_id, m, n, k, mc, nc, ptr_A, lda, ptr_B, ldb, ptr_C, ldc, + scratch, s1, s2, s3, s4, out_of_range, batch_stride, wg_batch_id, + batch_size_); } else { compute_panel_gemm( - id, item_id, m, n, k, mc, nc, a_size, b_size, c_size, ptr_A, lda, - ptr_B, ldb, ptr_C, ldc, scratch, s1, s2, s3, s4, out_of_range, - batch_stride, wg_batch_id, batch_size_); + id, item_id, m, n, k, mc, nc, ptr_A, lda, ptr_B, ldb, ptr_C, ldc, + scratch, s1, s2, s3, s4, out_of_range, batch_stride, wg_batch_id, + batch_size_); } } } @@ -430,8 +428,7 @@ class Gemm &id, const index_t &item_id, const index_t &m, const index_t &n, const index_t &orig_k, const index_t &mc, - const index_t &nc, const index_t &a_size, const index_t &b_size, - const index_t &c_size, InputPointerType orig_A, const index_t &lda, + const index_t &nc, InputPointerType orig_A, const index_t &lda, InputPointerType orig_B, const index_t &ldb, OutputPointerType orig_C, const index_t &ldc, OutputScratchPointerType s0, InputScratchPointerType s1, InputScratchPointerType s2, @@ -454,16 +451,15 @@ class Gemm( item_id, m, n, k, A, lda, B, ldb, s1, s3, out_of_range); id.barrier(cl::sycl::access::fence_space::local_space); - compute_block_gemm(id, s2, s4, reg_res); + compute_block_gemm(id, s2, s4, reg_res); A += cl_elems * (trans_a ? 1 : lda); B += cl_elems * (trans_b ? ldb : 1); - sync_smem( - id, ofs, s1, s2, s3, s4); + sync_smem(id, ofs, s1, s2, s3, + s4); k -= cl_elems; } @@ -471,14 +467,13 @@ class Gemm( item_id, m, n, k, A, lda, B, ldb, s1, s3, out_of_range); id.barrier(cl::sycl::access::fence_space::local_space); - compute_block_gemm(id, s2, s4, reg_res); - - sync_smem( - id, ofs, s1, s2, s3, s4); + compute_block_gemm(id, s2, s4, reg_res); + + sync_smem(id, ofs, s1, s2, s3, + s4); } // store the output @@ -506,8 +501,6 @@ class Gemm PORTBLAS_INLINE void store_output_block(cl::sycl::nd_item<1> id, index_t mc, - index_t nc, OutputPointerType C, - ScratchPointerType scratch, - index_t ldc, - CType (®_res)[frags_per_sg], - const bool out_of_range) noexcept { + index_t nc, OutputPointerType C, + ScratchPointerType scratch, + index_t ldc, + CType (®_res)[frags_per_sg], + const bool out_of_range) noexcept { if (out_of_range) { return; } @@ -533,126 +526,134 @@ class Gemm(id.get_local_linear_id()); const index_t sg_id = static_cast(sg.get_group_linear_id()); - const index_t sg_range = static_cast(sg.get_group_linear_range()); - const index_t sg_item_id = static_cast(sg.get_local_linear_id()); - - const index_t sg_mc = - mc - (sg_id % jm_row_frags) * tile_type::joint_matrix_M; - const index_t sg_nc = - nc - (sg_id / jm_row_frags) * tile_type::joint_matrix_N; - const bool jm_store_feasible = (mc < block_rows || nc < block_cols) - ? (sg_mc >= tile_type::joint_matrix_M && - sg_nc >= tile_type::joint_matrix_N) - ? true - : false - : true; - - if (jm_store_feasible) { - const index_t loop_limit = - (tile_type::joint_matrix_M * tile_type::joint_matrix_N) / sg_size; - if constexpr (is_beta_zero) { - for (index_t frag = 0; frag < frags_per_sg; frag++) { - for (index_t i = 0; i < loop_limit; i++) { - element_t data_left = - static_cast(get_wi_data(sg, reg_res[frag])[i]); - get_wi_data(sg, float_out)[i] = alpha_ * data_left; - } - - joint_matrix_store(sg, float_out, C, ldc, layout::col_major); - - C += (tile_type::joint_matrix_N * ldc); - } + const index_t output_local_store_offset = + (sg_id % jm_row_frags) * tile_type::joint_matrix_M + + (sg_id / jm_row_frags) * ldsc * tile_type::joint_matrix_N; + const index_t output_local_load_offset = + item_id % block_rows + (item_id / block_rows) * ldsc; - } else { - for (index_t frag = 0; frag < frags_per_sg; frag++) { - joint_matrix_load(sg, float_out, C, ldc, layout::col_major); - - for (index_t i = 0; i < loop_limit; i++) { - element_t data_left = - static_cast(get_wi_data(sg, reg_res[frag])[i]); - element_t data_right = get_wi_data(sg, float_out)[i]; - get_wi_data(sg, float_out)[i] = - beta_ * data_right + alpha_ * data_left; - } + const index_t it_mod_brows = item_id % block_rows; + const index_t it_div_brows = item_id / block_rows; - joint_matrix_store(sg, float_out, C, ldc, layout::col_major); + C += (it_mod_brows + it_div_brows * ldc); - C += (tile_type::joint_matrix_N * ldc); - } - } - return; - } else if (sg_mc <= 0 || sg_nc <= 0) { - return; - } - - id.barrier(cl::sycl::access::fence_space::local_space); + const index_t output_global_outer_offset = ldc * tile_type::joint_matrix_N; + constexpr index_t nc_conditional = + frags_per_sg > 1 ? tile_type::joint_matrix_N : block_cols; - scratch += sg_id * tile_type::joint_matrix_M; - const index_t sg_store_ld = sg_range * tile_type::joint_matrix_M; - const index_t loop_limit = - sg_nc >= tile_type::joint_matrix_N ? tile_type::joint_matrix_N : sg_nc; +#pragma unroll + for (index_t frag = 0; frag < frags_per_sg; frag++, + C += output_global_outer_offset, + nc -= tile_type::joint_matrix_N) { + const index_t rows_per_iter = + nc < nc_conditional ? 1 : wg_size / block_rows; + const index_t output_global_inner_offset = ldc * rows_per_iter; + const index_t output_local_inner_offset = ldsc * rows_per_iter; - for (index_t frag = 0; frag < frags_per_sg; frag++, C += ldc * loop_limit) { auto new_C = C; - auto new_scratch = scratch; + auto new_scratch = scratch + output_local_load_offset; - for (index_t i = 0; - i < - (tile_type::joint_matrix_M * tile_type::joint_matrix_N) / sg_size; - i++) { +#if __INTEL_LLVM_COMPILER && __INTEL_LLVM_COMPILER <= 20240002 + constexpr index_t conv_loop_limit = + (tile_type::joint_matrix_M * tile_type::joint_matrix_N) / sg_size; + for (index_t i = 0; i < conv_loop_limit; i++) { get_wi_data(sg, float_out)[i] = static_cast(get_wi_data(sg, reg_res[frag])[i]); } +#else + joint_matrix_copy(sg, reg_res[frag], float_out); +#endif + joint_matrix_apply(sg, float_out, [=](element_t &x) { x *= alpha_; }); + + id.barrier(cl::sycl::access::fence_space::local_space); - joint_matrix_store(sg, float_out, new_scratch, sg_store_ld, - layout::col_major); + joint_matrix_store(sg, float_out, scratch + output_local_store_offset, + ldsc, layout::col_major); id.barrier(cl::sycl::access::fence_space::local_space); - new_C += sg_item_id; - new_scratch += sg_item_id; - - if (sg_mc < tile_type::joint_matrix_M && - sg_nc < tile_type::joint_matrix_N) { - if (sg_item_id < sg_mc) { - for (index_t i = 0; i < loop_limit; i++) { - if constexpr (is_beta_zero) { - element_t data_right = *(new_scratch + i * sg_store_ld); - *(new_C + i * ldc) = alpha_ * data_right; - } else { - element_t data_left = *(new_C + i * ldc); - element_t data_right = *(new_scratch + i * sg_store_ld); - *(new_C + i * ldc) = alpha_ * data_right + beta_ * data_left; + if constexpr (check_m_limit && check_n_limit) { + if (mc >= block_rows && nc >= nc_conditional) { + const index_t loop_limit = nc_conditional / rows_per_iter; +#pragma unroll + for (int i = 0; i < loop_limit; i++, + new_C += output_global_inner_offset, + new_scratch += output_local_inner_offset) { + if constexpr (is_beta_zero) + *new_C = *new_scratch; + else { + auto val = *new_C; + *new_C = *new_scratch + beta_ * val; } } + continue; } - } else if (sg_mc < tile_type::joint_matrix_M) { - if (sg_item_id < sg_mc) { - for (index_t i = 0; i < loop_limit; i++) { - if constexpr (is_beta_zero) { - element_t data_right = *(new_scratch + i * sg_store_ld); - *(new_C + i * ldc) = alpha_ * data_right; - } else { - element_t data_left = *(new_C + i * ldc); - element_t data_right = *(new_scratch + i * sg_store_ld); - *(new_C + i * ldc) = alpha_ * data_right + beta_ * data_left; + if (mc < block_rows && nc < nc_conditional) { + if (item_id < mc) { + const index_t loop_limit = nc; +#pragma unroll + for (int i = 0; i < loop_limit; i++, + new_C += output_global_inner_offset, + new_scratch += output_local_inner_offset) { + if constexpr (is_beta_zero) + *new_C = *new_scratch; + else { + auto val = *new_C; + *new_C = *new_scratch + beta_ * val; + } } } + continue; } - } else { - if (sg_item_id < tile_type::joint_matrix_M) { - for (index_t i = 0; i < loop_limit; i++) { - if constexpr (is_beta_zero) { - element_t data_right = *(new_scratch + i * sg_store_ld); - *(new_C + i * ldc) = alpha_ * data_right; - } else { - element_t data_left = *(new_C + i * ldc); - element_t data_right = *(new_scratch + i * sg_store_ld); - *(new_C + i * ldc) = alpha_ * data_right + beta_ * data_left; + if (mc < block_rows) { + if (it_mod_brows < mc) { + const index_t loop_limit = nc_conditional / rows_per_iter; +#pragma unroll + for (int i = 0; i < loop_limit; i++, + new_C += output_global_inner_offset, + new_scratch += output_local_inner_offset) { + if constexpr (is_beta_zero) + *new_C = *new_scratch; + else { + auto val = *new_C; + *new_C = *new_scratch + beta_ * val; + } + } + } + continue; + } + if (nc < nc_conditional) { + if (item_id < block_rows) { + const index_t loop_limit = nc; +#pragma unroll + for (int i = 0; i < loop_limit; i++, + new_C += output_global_inner_offset, + new_scratch += output_local_inner_offset) { + if constexpr (is_beta_zero) + *new_C = *new_scratch; + else { + auto val = *new_C; + *new_C = *new_scratch + beta_ * val; + } } } + continue; + } + } else { + const index_t loop_limit = nc_conditional / rows_per_iter; +#pragma unroll + for (int i = 0; i < loop_limit; i++, + new_C += output_global_inner_offset, + new_scratch += output_local_inner_offset) { + if constexpr (is_beta_zero) + *new_C = *new_scratch; + else { + auto val = *new_C; + *new_C = *new_scratch + beta_ * val; + } } } } @@ -688,8 +689,7 @@ class Gemm( - item_id + i * (wg_size * multiplier) < bs)) + for (index_t i = 0; i < loop_iterations; ++i) { + if (!do_check<((bs % wg_size) != 0)>(item_id + i * wg_size < bs)) continue; - const index_t col_ofs = i * ((wg_size * multiplier) / rows); + const index_t col_ofs = i * (wg_size / rows); const bool in_range = - do_check( - in_row(((item_id * multiplier) % rows), multiplier - 1)) && - do_check( - in_col((item_id * multiplier / rows), col_ofs)); + do_check(in_row((item_id % rows), 0)) && + do_check(in_col((item_id / rows), col_ofs)); - packetize_t::template load( + packetize_t::template load( in_range, ptr + col_ofs * ld, scratch + col_ofs * lds, [&](const index_t &ofs) { - return in_row((item_id * multiplier) % rows, ofs) && - in_col((item_id * multiplier) / rows, col_ofs); + return in_row(item_id % rows, ofs) && + in_col(item_id / rows, col_ofs); }); } } @@ -756,29 +753,21 @@ class Gemm cols ? block_rows == 128 ? 2 : 4 : loop_iterations / 2; + constexpr index_t loop_iterations = (bs - 1) / wg_size + 1; #pragma unroll for (index_t i = 0; i < loop_iterations; ++i) { - if (!do_check<((bs % (wg_size * multiplier)) != 0)>( - item_id + i * (wg_size * multiplier) < bs)) + if (!do_check<((bs % wg_size) != 0)>(item_id + i * wg_size < bs)) continue; - const index_t local_row_ofs = - (i % divisor) * ((wg_size * multiplier) / cols) + - i / divisor * lds * cols; - const index_t row_ofs = i * ((wg_size * multiplier) / cols); - const bool in_range = do_check(in_row( - (item_id * multiplier) / cols, row_ofs)) && - do_check(in_col( - (item_id * multiplier) % cols, multiplier - 1)); - - packetize_t::template load( - in_range, ptr + row_ofs * ld, scratch + local_row_ofs, + const index_t row_ofs = i * (wg_size / cols); + const bool in_range = + do_check(in_row(item_id / cols, row_ofs)) && + do_check(in_col(item_id % cols, 0)); + + packetize_t::template load( + in_range, ptr + row_ofs * ld, scratch + row_ofs * lds, [&](const index_t &ofs) PORTBLAS_ALWAYS_INLINE { - return in_col((item_id * multiplier) % cols, ofs) && - in_row((item_id * multiplier) / cols, row_ofs); + return in_col(item_id % cols, ofs) && + in_row(item_id / cols, row_ofs); }); } } @@ -787,49 +776,59 @@ class Gemm + template PORTBLAS_INLINE void compute_block_gemm( const cl::sycl::nd_item<1> &id, InputPointerType s2, InputPointerType s4, CType (®_res)[frags_per_sg]) noexcept { using namespace cl::sycl::ext::oneapi::experimental::matrix; + constexpr layout pattern_a = + trans_a ? layout::row_major : layout::col_major; + constexpr layout pattern_b = + trans_b ? layout::row_major : layout::col_major; using AType = joint_matrix; + pattern_a>; using BType = joint_matrix; - - AType inA; - BType inB; + pattern_b>; const index_t strideA = ldsa; const index_t strideB = ldsb; auto sg = id.get_sub_group(); + constexpr index_t loop_limit = cl_elems / tile_type::joint_matrix_K; #pragma unroll - for (index_t frag = 0; frag < frags_per_sg; frag++) { - auto new_B = s2 + frag * tile_type::joint_matrix_N * ldsb; - auto new_A = s4; + for (index_t i = 0; i < loop_limit; i++) { + auto new_B = s2; + AType inA; - for (index_t i = 0; i < cl_elems / tile_type::joint_matrix_K; i++) { - joint_matrix_load(sg, inA, new_A, strideA); // M - joint_matrix_load(sg, inB, new_B, strideB); // N + joint_matrix_load(sg, inA, s4, strideA); // M + for (index_t frag = 0; frag < frags_per_sg; frag++) { + BType inB; + joint_matrix_load(sg, inB, new_B, strideB); // N +#if __INTEL_LLVM_COMPILER && __INTEL_LLVM_COMPILER <= 20240002 reg_res[frag] = joint_matrix_mad(sg, inA, inB, reg_res[frag]); - - new_A += ldsa * tile_type::joint_matrix_K; - new_B += ldsb * block_cols; +#else + joint_matrix_mad(sg, reg_res[frag], inA, inB, reg_res[frag]); +#endif + new_B += (trans_b ? tile_type::joint_matrix_N + : tile_type::joint_matrix_N * ldsb); } + s4 += (trans_a ? tile_type::joint_matrix_K + : tile_type::joint_matrix_K * strideA); + s2 += (trans_b ? tile_type::joint_matrix_K * strideB + : tile_type::joint_matrix_K); } } @@ -852,7 +851,7 @@ class Gemm static PORTBLAS_INLINE typename std::enable_if::type sync_smem( const cl::sycl::nd_item<1> &id, index_t &ofs_sign, P &s, - Ps &... ss) noexcept { + Ps &...ss) noexcept { s += ofs_sign * o; sync_smem(id, ofs_sign, ss...); } diff --git a/test/blas_test.hpp b/test/blas_test.hpp index 5d0193518..6c31a4db0 100644 --- a/test/blas_test.hpp +++ b/test/blas_test.hpp @@ -226,6 +226,35 @@ static inline void fill_trsm_matrix(std::vector &A, size_t k, } } +/** + * @brief Set to zero the last n bits of a float. + * @tparam T value type. + * @param val input/output float value. + * @param nbits number of last bit set to zero. It is set by default to 13 since + * this is the difference of the number of bits of the mantissa between floats + * (23) and FP16 / NVIDIA TF32 (10). For bfloat16, this value needs to be set to + * 16 to get correct result. + */ +template +void set_to_zero_last_nbits(T &val, int32_t nbits = 13) { + static_assert(sizeof(T) <= 64); + using integer_t = + std::conditional_t; + integer_t *int_pntr = reinterpret_cast(&val); +} + +/** + * @brief Set to zero the last n bits of floats contained in a vector. + * @tparam T value type. + * @param val input/output float vector. + * @param nbits number of last bit set to zero. + */ +template +void set_to_zero_last_nbits(std::vector &vec, int32_t nbits = 13) { + for (T &val : vec) set_to_zero_last_nbits(val, nbits); +} + /** * @brief Helper class for dumping arguments to a stream, in a format compatible * with google test test names. diff --git a/test/unittest/CMakeLists.txt b/test/unittest/CMakeLists.txt index a871040d4..7b398a2a7 100644 --- a/test/unittest/CMakeLists.txt +++ b/test/unittest/CMakeLists.txt @@ -152,3 +152,21 @@ foreach(blas_test ${SYCL_UNITTEST_SRCS}) COMPONENT tests ) endforeach() + +if(${ENABLE_JOINTMATRIX_TESTS}) + if (${DPCPP_SYCL_TARGET} STREQUAL "nvptx64-nvidia-cuda") + string(FIND ${DPCPP_SYCL_ARCH} "_" start_idx) + if(start_idx) + MATH(EXPR start_idx "${start_idx} + 1") + string(SUBSTRING ${DPCPP_SYCL_ARCH} ${start_idx} "2" sm_val) + endif() + + if (${start_idx} AND ${sm_val} GREATER_EQUAL "80") + add_subdirectory(joint_matrix) + else() + message(FATAL_ERROR "Joint Matrix Tests only supported for NVIDIA GPUs with sm_80 arch and above.") + endif() + else() + message(FATAL_ERROR "Joint Matrix Tests only supported for NVIDIA GPUs.") + endif() +endif() diff --git a/test/unittest/blas3/blas3_gemm_common.hpp b/test/unittest/blas3/blas3_gemm_common.hpp index 5308d70e5..3f06d58ec 100644 --- a/test/unittest/blas3/blas3_gemm_common.hpp +++ b/test/unittest/blas3/blas3_gemm_common.hpp @@ -135,6 +135,17 @@ inline void verify_gemm(const gemm_arguments_t arguments) { fill_random(a_m); fill_random(b_m); fill_random(c_m_gpu); + + const char* en_joint_matrix = std::getenv("SB_ENABLE_JOINT_MATRIX"); + if (en_joint_matrix != NULL && std::is_same::value && + *en_joint_matrix == '1') { + set_to_zero_last_nbits(a_m); + set_to_zero_last_nbits(b_m); + set_to_zero_last_nbits(c_m_gpu); + set_to_zero_last_nbits(alpha); + set_to_zero_last_nbits(beta); + } + std::vector c_m_cpu = c_m_gpu; // Use system blas to create a reference output @@ -302,6 +313,17 @@ inline void verify_gemm( fill_random(a_m); fill_random(b_m); fill_random(c_m_gpu); + + const char* en_joint_matrix = std::getenv("SB_ENABLE_JOINT_MATRIX"); + if (en_joint_matrix != NULL && std::is_same::value && + *en_joint_matrix == '1') { + set_to_zero_last_nbits(a_m); + set_to_zero_last_nbits(b_m); + set_to_zero_last_nbits(c_m_gpu); + set_to_zero_last_nbits(alpha); + set_to_zero_last_nbits(beta); + } + std::vector c_m_cpu = c_m_gpu; // Use system blas to create a reference output diff --git a/test/unittest/blas3/blas3_trsm_test.cpp b/test/unittest/blas3/blas3_trsm_test.cpp index 793bd7ee5..6f44dab93 100644 --- a/test/unittest/blas3/blas3_trsm_test.cpp +++ b/test/unittest/blas3/blas3_trsm_test.cpp @@ -62,6 +62,14 @@ void run_test(const combination_t combi) { static_cast(unusedValue)); fill_random(B); + const char* en_joint_matrix = std::getenv("SB_ENABLE_JOINT_MATRIX"); + if (en_joint_matrix != NULL && std::is_same::value && + *en_joint_matrix == '1') { + set_to_zero_last_nbits(A); + set_to_zero_last_nbits(B); + set_to_zero_last_nbits(alpha); + } + // Create a copy of B to calculate the reference outputs cpu_B = B; reference_blas::trsm(&side, &uplo, &trans, &diag, m, n, @@ -84,7 +92,12 @@ void run_test(const combination_t combi) { blas::helper::copy_to_host(q, b_gpu, B.data(), B.size()); sb_handle.wait(event); - bool isAlmostEqual = utils::compare_vectors(cpu_B, B); + bool isAlmostEqual = utils::compare_vectors( + cpu_B, B, std::cerr, "\n", + (en_joint_matrix != NULL) && (std::is_same::value) && + (*en_joint_matrix == '1') + ? 3 + : 1); ASSERT_TRUE(isAlmostEqual); diff --git a/test/unittest/joint_matrix/CMakeLists.txt b/test/unittest/joint_matrix/CMakeLists.txt new file mode 100644 index 000000000..efba7944c --- /dev/null +++ b/test/unittest/joint_matrix/CMakeLists.txt @@ -0,0 +1,74 @@ +#/*************************************************************************** +# * +# * @license +# * Copyright (C) Codeplay Software Limited +# * Licensed under the Apache License, Version 2.0 (the "License"); +# * you may not use this file except in compliance with the License. +# * You may obtain a copy of the License at +# * +# * http://www.apache.org/licenses/LICENSE-2.0 +# * +# * For your convenience, a copy of the License has been included in this +# * repository. +# * +# * Unless required by applicable law or agreed to in writing, software +# * distributed under the License is distributed on an "AS IS" BASIS, +# * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# * See the License for the specific language governing permissions and +# * limitations under the License. +# * +# * portBLAS: BLAS implementation using SYCL +# * +# * @filename CMakeLists.txt +# * +# **************************************************************************/ + +set(PORTBLAS_INCLUDE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../../../include) +set(PORTBLAS_SRC_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../../../src) +list(APPEND CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/cmake) +list(APPEND CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/../../../cmake/Modules) +list(APPEND CMAKE_PREFIX_PATH ${CMAKE_CURRENT_SOURCE_DIR}/../../..) +include(ConfigurePORTBLAS) +include(SYCL) +find_package(PORTBLAS REQUIRED) + +set(PORTBLAS_JOINTMATRIX_TEST ${CMAKE_CURRENT_SOURCE_DIR}) + +include_directories(${PORTBLAS_TEST} ${BLAS_INCLUDE_DIRS}) + +# compiling tests +set(SYCL_UNITTEST_SRCS + ${PORTBLAS_JOINTMATRIX_TEST}/half_half_16_16_16.cpp + ${PORTBLAS_JOINTMATRIX_TEST}/half_half_32_8_16.cpp + ${PORTBLAS_JOINTMATRIX_TEST}/half_half_8_32_16.cpp + ${PORTBLAS_JOINTMATRIX_TEST}/half_float_16_16_16.cpp + ${PORTBLAS_JOINTMATRIX_TEST}/half_float_32_8_16.cpp + ${PORTBLAS_JOINTMATRIX_TEST}/half_float_8_32_16.cpp + ${PORTBLAS_JOINTMATRIX_TEST}/bfloat16_float_16_16_16.cpp + ${PORTBLAS_JOINTMATRIX_TEST}/bfloat16_float_32_8_16.cpp + ${PORTBLAS_JOINTMATRIX_TEST}/bfloat16_float_8_32_16.cpp + ${PORTBLAS_JOINTMATRIX_TEST}/tf32_float_16_16_8.cpp +) + +foreach(blas_test ${SYCL_UNITTEST_SRCS}) + get_filename_component(test_exec ${blas_test} NAME_WE) + add_executable(joint_matrix_${test_exec}_test ../main.cpp ${blas_test}) + target_compile_definitions(joint_matrix_${test_exec}_test PRIVATE -DBLAS_INDEX_T=${BLAS_TEST_INDEX_TYPE}) + target_link_libraries(joint_matrix_${test_exec}_test PRIVATE gtest_main Clara::Clara blas::blas PORTBLAS::PORTBLAS) + target_include_directories(joint_matrix_${test_exec}_test PRIVATE ${SYCL_INCLUDE_DIRS}) + target_include_directories(joint_matrix_${test_exec}_test PRIVATE ${CBLAS_INCLUDE} ${PORTBLAS_COMMON_INCLUDE_DIR}) + target_compile_options(joint_matrix_${test_exec}_test PRIVATE ${DPCPP_FLAGS}) + target_link_options(joint_matrix_${test_exec}_test PRIVATE ${DPCPP_FLAGS}) + + if(TEST_DEVICE) + add_test(NAME joint_matrix_${test_exec}_test COMMAND ${CMAKE_CURRENT_BINARY_DIR}/joint_matrix_${test_exec}_test --device ${TEST_DEVICE} --gtest_output=xml:output/) + else() + add_test(NAME joint_matrix_${test_exec}_test COMMAND ${CMAKE_CURRENT_BINARY_DIR}/joint_matrix_${test_exec}_test --gtest_output=xml:output/) + endif() + message(STATUS "Created google test joint_matrix_${test_exec}_test") + install(TARGETS joint_matrix_${test_exec}_test + RUNTIME + DESTINATION ${CMAKE_INSTALL_BINDIR} + COMPONENT tests + ) +endforeach() diff --git a/test/unittest/joint_matrix/bfloat16_float_16_16_16.cpp b/test/unittest/joint_matrix/bfloat16_float_16_16_16.cpp new file mode 100644 index 000000000..ba4d35f7b --- /dev/null +++ b/test/unittest/joint_matrix/bfloat16_float_16_16_16.cpp @@ -0,0 +1,99 @@ +/*************************************************************************** + * + * @license + * Copyright (C) Codeplay Software Limited + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * For your convenience, a copy of the License has been included in this + * repository. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * portBLAS: BLAS implementation using SYCL + * + * @filename bfloat16_float_16_16_16.cpp + * + **************************************************************************/ + +#include "blas_test.hpp" +#include "joint_matrix_common.hpp" + +template +const auto SmallMatricesBfloat16Floatm16n16k16 = ::testing::Combine( + ::testing::Values("bfloat16"), // input type + ::testing::Values("float"), // output type + ::testing::Values(16), // jm_m + ::testing::Values(16), // jm_n + ::testing::Values(16), // jm_k + ::testing::Values("usm", "buf"), // allocation type + ::testing::Values(33), // offset + ::testing::Values(1), // batch + ::testing::Values(11, 16, 32, 63), // m + ::testing::Values(11, 16, 32, 63), // n + ::testing::Values(17, 33, 64), // k + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + ::testing::Values(1.5), // alpha + ::testing::Values(0, 1.5), // beta + ::testing::Values(1), // lda_mul + ::testing::Values(1), // ldb_mul + ::testing::Values(1), // ldc_mul + ::testing::Values(gemm_batch_type_t::strided) // batch_type +); +GENERATE_JOINTMATRIX_TEST(JointMatrix, SmallMatricesBfloat16Floatm16n16k16); + +template +const auto MediumMatricesBfloat16Floatm16n16k16 = ::testing::Combine( + ::testing::Values("bfloat16"), // input type + ::testing::Values("float"), // output type + ::testing::Values(16), // jm_m + ::testing::Values(16), // jm_n + ::testing::Values(16), // jm_k + ::testing::Values("usm", "buf"), // allocation type + ::testing::Values(33), // offset + ::testing::Values(1), // batch + ::testing::Values(65, 127, 234, 511), // m + ::testing::Values(65, 127, 234, 511), // n + ::testing::Values(65, 127), // k + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + ::testing::Values(1.5), // alpha + ::testing::Values(0, 1.5), // beta + ::testing::Values(1), // lda_mul + ::testing::Values(1), // ldb_mul + ::testing::Values(1), // ldc_mul + ::testing::Values(gemm_batch_type_t::strided) // batch_type +); +GENERATE_JOINTMATRIX_TEST(JointMatrix, MediumMatricesBfloat16Floatm16n16k16); + +template +const auto LargeMatricesBfloat16Floatm16n16k16 = ::testing::Combine( + ::testing::Values("bfloat16"), // input type + ::testing::Values("float"), // output type + ::testing::Values(16), // jm_m + ::testing::Values(16), // jm_n + ::testing::Values(16), // jm_k + ::testing::Values("usm", "buf"), // allocation type + ::testing::Values(33), // offset + ::testing::Values(1), // batch + ::testing::Values(1024, 1535, 2024), // m + ::testing::Values(1024, 1535, 2024), // n + ::testing::Values(1536, 2049), // k + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + ::testing::Values(1.5), // alpha + ::testing::Values(0, 1.5), // beta + ::testing::Values(1), // lda_mul + ::testing::Values(1), // ldb_mul + ::testing::Values(1), // ldc_mul + ::testing::Values(gemm_batch_type_t::strided) // batch_type +); +GENERATE_JOINTMATRIX_TEST(JointMatrix, LargeMatricesBfloat16Floatm16n16k16); diff --git a/test/unittest/joint_matrix/bfloat16_float_32_8_16.cpp b/test/unittest/joint_matrix/bfloat16_float_32_8_16.cpp new file mode 100644 index 000000000..49cc187ff --- /dev/null +++ b/test/unittest/joint_matrix/bfloat16_float_32_8_16.cpp @@ -0,0 +1,99 @@ +/*************************************************************************** + * + * @license + * Copyright (C) Codeplay Software Limited + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * For your convenience, a copy of the License has been included in this + * repository. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * portBLAS: BLAS implementation using SYCL + * + * @filename bfloat16_float_32_8_16.cpp + * + **************************************************************************/ + +#include "blas_test.hpp" +#include "joint_matrix_common.hpp" + +template +const auto SmallMatricesBfloat16Floatm32n8k16 = ::testing::Combine( + ::testing::Values("bfloat16"), // input type + ::testing::Values("float"), // output type + ::testing::Values(32), // jm_m + ::testing::Values(8), // jm_n + ::testing::Values(16), // jm_k + ::testing::Values("usm", "buf"), // allocation type + ::testing::Values(33), // offset + ::testing::Values(1), // batch + ::testing::Values(11, 16, 32, 63), // m + ::testing::Values(11, 16, 32, 63), // n + ::testing::Values(17, 33, 64), // k + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + ::testing::Values(1.5), // alpha + ::testing::Values(0, 1.5), // beta + ::testing::Values(1), // lda_mul + ::testing::Values(1), // ldb_mul + ::testing::Values(1), // ldc_mul + ::testing::Values(gemm_batch_type_t::strided) // batch_type +); +GENERATE_JOINTMATRIX_TEST(JointMatrix, SmallMatricesBfloat16Floatm32n8k16); + +template +const auto MediumMatricesBfloat16Floatm32n8k16 = ::testing::Combine( + ::testing::Values("bfloat16"), // input type + ::testing::Values("float"), // output type + ::testing::Values(32), // jm_m + ::testing::Values(8), // jm_n + ::testing::Values(16), // jm_k + ::testing::Values("usm", "buf"), // allocation type + ::testing::Values(33), // offset + ::testing::Values(1), // batch + ::testing::Values(65, 127, 234, 511), // m + ::testing::Values(65, 127, 234, 511), // n + ::testing::Values(65, 127), // k + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + ::testing::Values(1.5), // alpha + ::testing::Values(0, 1.5), // beta + ::testing::Values(1), // lda_mul + ::testing::Values(1), // ldb_mul + ::testing::Values(1), // ldc_mul + ::testing::Values(gemm_batch_type_t::strided) // batch_type +); +GENERATE_JOINTMATRIX_TEST(JointMatrix, MediumMatricesBfloat16Floatm32n8k16); + +template +const auto LargeMatricesBfloat16Floatm32n8k16 = ::testing::Combine( + ::testing::Values("bfloat16"), // input type + ::testing::Values("float"), // output type + ::testing::Values(32), // jm_m + ::testing::Values(8), // jm_n + ::testing::Values(16), // jm_k + ::testing::Values("usm", "buf"), // allocation type + ::testing::Values(33), // offset + ::testing::Values(1), // batch + ::testing::Values(1024, 1535, 2024), // m + ::testing::Values(1024, 1535, 2024), // n + ::testing::Values(1536, 2049), // k + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + ::testing::Values(1.5), // alpha + ::testing::Values(0, 1.5), // beta + ::testing::Values(1), // lda_mul + ::testing::Values(1), // ldb_mul + ::testing::Values(1), // ldc_mul + ::testing::Values(gemm_batch_type_t::strided) // batch_type +); +GENERATE_JOINTMATRIX_TEST(JointMatrix, LargeMatricesBfloat16Floatm32n8k16); diff --git a/test/unittest/joint_matrix/bfloat16_float_8_32_16.cpp b/test/unittest/joint_matrix/bfloat16_float_8_32_16.cpp new file mode 100644 index 000000000..6583fd450 --- /dev/null +++ b/test/unittest/joint_matrix/bfloat16_float_8_32_16.cpp @@ -0,0 +1,99 @@ +/*************************************************************************** + * + * @license + * Copyright (C) Codeplay Software Limited + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * For your convenience, a copy of the License has been included in this + * repository. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * portBLAS: BLAS implementation using SYCL + * + * @filename bfloat16_float_8_32_16.cpp + * + **************************************************************************/ + +#include "blas_test.hpp" +#include "joint_matrix_common.hpp" + +template +const auto SmallMatricesBfloat16Floatm8n32k16 = ::testing::Combine( + ::testing::Values("bfloat16"), // input type + ::testing::Values("float"), // output type + ::testing::Values(8), // jm_m + ::testing::Values(32), // jm_n + ::testing::Values(16), // jm_k + ::testing::Values("usm", "buf"), // allocation type + ::testing::Values(33), // offset + ::testing::Values(1), // batch + ::testing::Values(11, 16, 32, 63), // m + ::testing::Values(11, 16, 32, 63), // n + ::testing::Values(17, 33, 64), // k + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + ::testing::Values(1.5), // alpha + ::testing::Values(0, 1.5), // beta + ::testing::Values(1), // lda_mul + ::testing::Values(1), // ldb_mul + ::testing::Values(1), // ldc_mul + ::testing::Values(gemm_batch_type_t::strided) // batch_type +); +GENERATE_JOINTMATRIX_TEST(JointMatrix, SmallMatricesBfloat16Floatm8n32k16); + +template +const auto MediumMatricesBfloat16Floatm8n32k16 = ::testing::Combine( + ::testing::Values("bfloat16"), // input type + ::testing::Values("float"), // output type + ::testing::Values(8), // jm_m + ::testing::Values(32), // jm_n + ::testing::Values(16), // jm_k + ::testing::Values("usm", "buf"), // allocation type + ::testing::Values(33), // offset + ::testing::Values(1), // batch + ::testing::Values(65, 127, 234, 511), // m + ::testing::Values(65, 127, 234, 511), // n + ::testing::Values(65, 127), // k + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + ::testing::Values(1.5), // alpha + ::testing::Values(0, 1.5), // beta + ::testing::Values(1), // lda_mul + ::testing::Values(1), // ldb_mul + ::testing::Values(1), // ldc_mul + ::testing::Values(gemm_batch_type_t::strided) // batch_type +); +GENERATE_JOINTMATRIX_TEST(JointMatrix, MediumMatricesBfloat16Floatm8n32k16); + +template +const auto LargeMatricesBfloat16Floatm8n32k16 = ::testing::Combine( + ::testing::Values("bfloat16"), // input type + ::testing::Values("float"), // output type + ::testing::Values(8), // jm_m + ::testing::Values(32), // jm_n + ::testing::Values(16), // jm_k + ::testing::Values("usm", "buf"), // allocation type + ::testing::Values(33), // offset + ::testing::Values(1), // batch + ::testing::Values(1024, 1535, 2024), // m + ::testing::Values(1024, 1535, 2024), // n + ::testing::Values(1536, 2049), // k + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + ::testing::Values(1.5), // alpha + ::testing::Values(0, 1.5), // beta + ::testing::Values(1), // lda_mul + ::testing::Values(1), // ldb_mul + ::testing::Values(1), // ldc_mul + ::testing::Values(gemm_batch_type_t::strided) // batch_type +); +GENERATE_JOINTMATRIX_TEST(JointMatrix, LargeMatricesBfloat16Floatm8n32k16); diff --git a/test/unittest/joint_matrix/cmake/FindPORTBLAS.cmake b/test/unittest/joint_matrix/cmake/FindPORTBLAS.cmake new file mode 100644 index 000000000..f102b0247 --- /dev/null +++ b/test/unittest/joint_matrix/cmake/FindPORTBLAS.cmake @@ -0,0 +1,65 @@ +#/*************************************************************************** +# * +# * @license +# * Copyright (C) Codeplay Software Limited +# * Licensed under the Apache License, Version 2.0 (the "License"); +# * you may not use this file except in compliance with the License. +# * You may obtain a copy of the License at +# * +# * http://www.apache.org/licenses/LICENSE-2.0 +# * +# * For your convenience, a copy of the License has been included in this +# * repository. +# * +# * Unless required by applicable law or agreed to in writing, software +# * distributed under the License is distributed on an "AS IS" BASIS, +# * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# * See the License for the specific language governing permissions and +# * limitations under the License. +# * +# * portBLAS: BLAS implementation using SYCL +# * +# * @filename FindPORTBLAS.cmake +# * +# **************************************************************************/ + +find_path(PORTBLAS_INCLUDE_DIR + NAMES portblas.h + PATH_SUFFIXES include + HINTS ${PORTBLAS_DIR} + DOC "The PORTBLAS include directory" +) + +find_path(PORTBLAS_SRC_DIR + NAMES portblas.hpp + PATH_SUFFIXES src + HINTS ${PORTBLAS_DIR} + DOC "The PORTBLAS source directory" +) + + +include(FindPackageHandleStandardArgs) +find_package_handle_standard_args(PORTBLAS + FOUND_VAR PORTBLAS_FOUND + REQUIRED_VARS PORTBLAS_INCLUDE_DIR + PORTBLAS_SRC_DIR +) + +mark_as_advanced(PORTBLAS_FOUND + PORTBLAS_SRC_DIR + PORTBLAS_INCLUDE_DIR +) + +if(PORTBLAS_FOUND) + set(PORTBLAS_INCLUDE_DIRS + ${PORTBLAS_INCLUDE_DIR} + ${PORTBLAS_SRC_DIR} + ) +endif() + +if(PORTBLAS_FOUND AND NOT TARGET PORTBLAS::PORTBLAS) + add_library(PORTBLAS::PORTBLAS INTERFACE IMPORTED) + set_target_properties(PORTBLAS::PORTBLAS PROPERTIES + INTERFACE_INCLUDE_DIRECTORIES "${PORTBLAS_INCLUDE_DIRS}" + ) +endif() diff --git a/test/unittest/joint_matrix/half_float_16_16_16.cpp b/test/unittest/joint_matrix/half_float_16_16_16.cpp new file mode 100644 index 000000000..88b3bac8b --- /dev/null +++ b/test/unittest/joint_matrix/half_float_16_16_16.cpp @@ -0,0 +1,99 @@ +/*************************************************************************** + * + * @license + * Copyright (C) Codeplay Software Limited + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * For your convenience, a copy of the License has been included in this + * repository. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * portBLAS: BLAS implementation using SYCL + * + * @filename half_float_16_16_16.cpp + * + **************************************************************************/ + +#include "blas_test.hpp" +#include "joint_matrix_common.hpp" + +template +const auto SmallMatricesHalfFloat161616 = ::testing::Combine( + ::testing::Values("half"), // input type + ::testing::Values("float"), // output type + ::testing::Values(16), // jm_m + ::testing::Values(16), // jm_n + ::testing::Values(16), // jm_n + ::testing::Values("usm", "buf"), // allocation type + ::testing::Values(33), // offset + ::testing::Values(1), // batch + ::testing::Values(11, 16, 32, 63), // m + ::testing::Values(11, 16, 32, 63), // n + ::testing::Values(17, 33, 64), // k + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + ::testing::Values(1.5), // alpha + ::testing::Values(0, 1.5), // beta + ::testing::Values(1), // lda_mul + ::testing::Values(1), // ldb_mul + ::testing::Values(1), // ldc_mul + ::testing::Values(gemm_batch_type_t::strided) // batch_type +); +GENERATE_JOINTMATRIX_TEST(JointMatrix, SmallMatricesHalfFloat161616); + +template +const auto MediumMatricesHalfFloat161616 = ::testing::Combine( + ::testing::Values("half"), // input type + ::testing::Values("float"), // output type + ::testing::Values(16), // jm_m + ::testing::Values(16), // jm_n + ::testing::Values(16), // jm_n + ::testing::Values("usm", "buf"), // allocation type + ::testing::Values(33), // offset + ::testing::Values(1), // batch + ::testing::Values(65, 127, 234, 511), // m + ::testing::Values(65, 127, 234, 511), // n + ::testing::Values(65, 127), // k + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + ::testing::Values(1.5), // alpha + ::testing::Values(0, 1.5), // beta + ::testing::Values(1), // lda_mul + ::testing::Values(1), // ldb_mul + ::testing::Values(1), // ldc_mul + ::testing::Values(gemm_batch_type_t::strided) // batch_type +); +GENERATE_JOINTMATRIX_TEST(JointMatrix, MediumMatricesHalfFloat161616); + +template +const auto LargeMatricesHalfFloat161616 = ::testing::Combine( + ::testing::Values("half"), // input type + ::testing::Values("float"), // output type + ::testing::Values(16), // jm_m + ::testing::Values(16), // jm_n + ::testing::Values(16), // jm_n + ::testing::Values("usm", "buf"), // allocation type + ::testing::Values(33), // offset + ::testing::Values(1), // batch + ::testing::Values(1024, 1535, 2024), // m + ::testing::Values(1024, 1535, 2024), // n + ::testing::Values(1536, 2049), // k + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + ::testing::Values(1.5), // alpha + ::testing::Values(0, 1.5), // beta + ::testing::Values(1), // lda_mul + ::testing::Values(1), // ldb_mul + ::testing::Values(1), // ldc_mul + ::testing::Values(gemm_batch_type_t::strided) // batch_type +); +GENERATE_JOINTMATRIX_TEST(JointMatrix, LargeMatricesHalfFloat161616); diff --git a/test/unittest/joint_matrix/half_float_32_8_16.cpp b/test/unittest/joint_matrix/half_float_32_8_16.cpp new file mode 100644 index 000000000..b370e67f2 --- /dev/null +++ b/test/unittest/joint_matrix/half_float_32_8_16.cpp @@ -0,0 +1,99 @@ +/*************************************************************************** + * + * @license + * Copyright (C) Codeplay Software Limited + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * For your convenience, a copy of the License has been included in this + * repository. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * portBLAS: BLAS implementation using SYCL + * + * @filename half_float_32_8_16.cpp + * + **************************************************************************/ + +#include "blas_test.hpp" +#include "joint_matrix_common.hpp" + +template +const auto SmallMatricesHalfFloatm32n8k16 = ::testing::Combine( + ::testing::Values("half"), // input type + ::testing::Values("float"), // output type + ::testing::Values(32), // jm_m + ::testing::Values(8), // jm_n + ::testing::Values(16), // jm_n + ::testing::Values("usm", "buf"), // allocation type + ::testing::Values(33), // offset + ::testing::Values(1), // batch + ::testing::Values(11, 16, 32, 63), // m + ::testing::Values(11, 16, 32, 63), // n + ::testing::Values(17, 33, 64), // k + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + ::testing::Values(1.5), // alpha + ::testing::Values(0, 1.5), // beta + ::testing::Values(1), // lda_mul + ::testing::Values(1), // ldb_mul + ::testing::Values(1), // ldc_mul + ::testing::Values(gemm_batch_type_t::strided) // batch_type +); +GENERATE_JOINTMATRIX_TEST(JointMatrix, SmallMatricesHalfFloatm32n8k16); + +template +const auto MediumMatricesHalfFloatm32n8k16 = ::testing::Combine( + ::testing::Values("half"), // input type + ::testing::Values("float"), // output type + ::testing::Values(32), // jm_m + ::testing::Values(8), // jm_n + ::testing::Values(16), // jm_n + ::testing::Values("usm", "buf"), // allocation type + ::testing::Values(33), // offset + ::testing::Values(1), // batch + ::testing::Values(65, 127, 234, 511), // m + ::testing::Values(65, 127, 234, 511), // n + ::testing::Values(65, 127), // k + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + ::testing::Values(1.5), // alpha + ::testing::Values(0, 1.5), // beta + ::testing::Values(1), // lda_mul + ::testing::Values(1), // ldb_mul + ::testing::Values(1), // ldc_mul + ::testing::Values(gemm_batch_type_t::strided) // batch_type +); +GENERATE_JOINTMATRIX_TEST(JointMatrix, MediumMatricesHalfFloatm32n8k16); + +template +const auto LargeMatricesHalfFloatm32n8k16 = ::testing::Combine( + ::testing::Values("half"), // input type + ::testing::Values("float"), // output type + ::testing::Values(32), // jm_m + ::testing::Values(8), // jm_n + ::testing::Values(16), // jm_n + ::testing::Values("usm", "buf"), // allocation type + ::testing::Values(33), // offset + ::testing::Values(1), // batch + ::testing::Values(1024, 1535, 2024), // m + ::testing::Values(1024, 1535, 2024), // n + ::testing::Values(1536, 2049), // k + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + ::testing::Values(1.5), // alpha + ::testing::Values(0, 1.5), // beta + ::testing::Values(1), // lda_mul + ::testing::Values(1), // ldb_mul + ::testing::Values(1), // ldc_mul + ::testing::Values(gemm_batch_type_t::strided) // batch_type +); +GENERATE_JOINTMATRIX_TEST(JointMatrix, LargeMatricesHalfFloatm32n8k16); diff --git a/test/unittest/joint_matrix/half_float_8_32_16.cpp b/test/unittest/joint_matrix/half_float_8_32_16.cpp new file mode 100644 index 000000000..9853332ae --- /dev/null +++ b/test/unittest/joint_matrix/half_float_8_32_16.cpp @@ -0,0 +1,99 @@ +/*************************************************************************** + * + * @license + * Copyright (C) Codeplay Software Limited + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * For your convenience, a copy of the License has been included in this + * repository. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * portBLAS: BLAS implementation using SYCL + * + * @filename half_float_8_32_16.cpp + * + **************************************************************************/ + +#include "blas_test.hpp" +#include "joint_matrix_common.hpp" + +template +const auto SmallMatricesHalfFloatm8n32k16 = ::testing::Combine( + ::testing::Values("half"), // input type + ::testing::Values("float"), // output type + ::testing::Values(8), // jm_m + ::testing::Values(32), // jm_n + ::testing::Values(16), // jm_n + ::testing::Values("usm", "buf"), // allocation type + ::testing::Values(33), // offset + ::testing::Values(1), // batch + ::testing::Values(11, 16, 32, 63), // m + ::testing::Values(11, 16, 32, 63), // n + ::testing::Values(17, 33, 64), // k + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + ::testing::Values(1.5), // alpha + ::testing::Values(0, 1.5), // beta + ::testing::Values(1), // lda_mul + ::testing::Values(1), // ldb_mul + ::testing::Values(1), // ldc_mul + ::testing::Values(gemm_batch_type_t::strided) // batch_type +); +GENERATE_JOINTMATRIX_TEST(JointMatrix, SmallMatricesHalfFloatm8n32k16); + +template +const auto MediumMatricesHalfFloatm8n32k16 = ::testing::Combine( + ::testing::Values("half"), // input type + ::testing::Values("float"), // output type + ::testing::Values(8), // jm_m + ::testing::Values(32), // jm_n + ::testing::Values(16), // jm_n + ::testing::Values("usm", "buf"), // allocation type + ::testing::Values(33), // offset + ::testing::Values(1), // batch + ::testing::Values(65, 127, 234, 511), // m + ::testing::Values(65, 127, 234, 511), // n + ::testing::Values(65, 127), // k + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + ::testing::Values(1.5), // alpha + ::testing::Values(0, 1.5), // beta + ::testing::Values(1), // lda_mul + ::testing::Values(1), // ldb_mul + ::testing::Values(1), // ldc_mul + ::testing::Values(gemm_batch_type_t::strided) // batch_type +); +GENERATE_JOINTMATRIX_TEST(JointMatrix, MediumMatricesHalfFloatm8n32k16); + +template +const auto LargeMatricesHalfFloatm8n32k16 = ::testing::Combine( + ::testing::Values("half"), // input type + ::testing::Values("float"), // output type + ::testing::Values(8), // jm_m + ::testing::Values(32), // jm_n + ::testing::Values(16), // jm_n + ::testing::Values("usm", "buf"), // allocation type + ::testing::Values(33), // offset + ::testing::Values(1), // batch + ::testing::Values(1024, 1535, 2024), // m + ::testing::Values(1024, 1535, 2024), // n + ::testing::Values(1536, 2049), // k + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + ::testing::Values(1.5), // alpha + ::testing::Values(0, 1.5), // beta + ::testing::Values(1), // lda_mul + ::testing::Values(1), // ldb_mul + ::testing::Values(1), // ldc_mul + ::testing::Values(gemm_batch_type_t::strided) // batch_type +); +GENERATE_JOINTMATRIX_TEST(JointMatrix, LargeMatricesHalfFloatm8n32k16); diff --git a/test/unittest/joint_matrix/half_half_16_16_16.cpp b/test/unittest/joint_matrix/half_half_16_16_16.cpp new file mode 100644 index 000000000..58241fca0 --- /dev/null +++ b/test/unittest/joint_matrix/half_half_16_16_16.cpp @@ -0,0 +1,99 @@ +/*************************************************************************** + * + * @license + * Copyright (C) Codeplay Software Limited + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * For your convenience, a copy of the License has been included in this + * repository. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * portBLAS: BLAS implementation using SYCL + * + * @filename half_half_16_16_16.cpp + * + **************************************************************************/ + +#include "blas_test.hpp" +#include "joint_matrix_common.hpp" + +template +const auto SmallMatricesHalfHalfm16n16k16 = ::testing::Combine( + ::testing::Values("half"), // input type + ::testing::Values("half"), // output type + ::testing::Values(16), // jm_m + ::testing::Values(16), // jm_n + ::testing::Values(16), // jm_n + ::testing::Values("usm", "buf"), // allocation type + ::testing::Values(33), // offset + ::testing::Values(1), // batch + ::testing::Values(11, 16, 32, 63), // m + ::testing::Values(11, 16, 32, 63), // n + ::testing::Values(17, 33, 64), // k + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + ::testing::Values(1.5), // alpha + ::testing::Values(0, 1.5), // beta + ::testing::Values(1), // lda_mul + ::testing::Values(1), // ldb_mul + ::testing::Values(1), // ldc_mul + ::testing::Values(gemm_batch_type_t::strided) // batch_type +); +GENERATE_JOINTMATRIX_TEST(JointMatrix, SmallMatricesHalfHalfm16n16k16); + +template +const auto MediumMatricesHalfHalfm16n16k16 = ::testing::Combine( + ::testing::Values("half"), // input type + ::testing::Values("half"), // output type + ::testing::Values(16), // jm_m + ::testing::Values(16), // jm_n + ::testing::Values(16), // jm_n + ::testing::Values("usm", "buf"), // allocation type + ::testing::Values(33), // offset + ::testing::Values(1), // batch + ::testing::Values(65, 127, 234, 511), // m + ::testing::Values(65, 127, 234, 511), // n + ::testing::Values(65, 127), // k + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + ::testing::Values(1.5), // alpha + ::testing::Values(0, 1.5), // beta + ::testing::Values(1), // lda_mul + ::testing::Values(1), // ldb_mul + ::testing::Values(1), // ldc_mul + ::testing::Values(gemm_batch_type_t::strided) // batch_type +); +GENERATE_JOINTMATRIX_TEST(JointMatrix, MediumMatricesHalfHalfm16n16k16); + +template +const auto LargeMatricesHalfHalfm16n16k16 = ::testing::Combine( + ::testing::Values("half"), // input type + ::testing::Values("half"), // output type + ::testing::Values(16), // jm_m + ::testing::Values(16), // jm_n + ::testing::Values(16), // jm_n + ::testing::Values("usm", "buf"), // allocation type + ::testing::Values(33), // offset + ::testing::Values(1), // batch + ::testing::Values(1024, 1535, 2024), // m + ::testing::Values(1024, 1535, 2024), // n + ::testing::Values(1536, 2049), // k + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + ::testing::Values(1.5), // alpha + ::testing::Values(0, 1.5), // beta + ::testing::Values(1), // lda_mul + ::testing::Values(1), // ldb_mul + ::testing::Values(1), // ldc_mul + ::testing::Values(gemm_batch_type_t::strided) // batch_type +); +GENERATE_JOINTMATRIX_TEST(JointMatrix, LargeMatricesHalfHalfm16n16k16); diff --git a/test/unittest/joint_matrix/half_half_32_8_16.cpp b/test/unittest/joint_matrix/half_half_32_8_16.cpp new file mode 100644 index 000000000..e220b884b --- /dev/null +++ b/test/unittest/joint_matrix/half_half_32_8_16.cpp @@ -0,0 +1,99 @@ +/*************************************************************************** + * + * @license + * Copyright (C) Codeplay Software Limited + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * For your convenience, a copy of the License has been included in this + * repository. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * portBLAS: BLAS implementation using SYCL + * + * @filename half_half_32_8_16.cpp + * + **************************************************************************/ + +#include "blas_test.hpp" +#include "joint_matrix_common.hpp" + +template +const auto SmallMatricesHalfHalfm32n8k16 = ::testing::Combine( + ::testing::Values("half"), // input type + ::testing::Values("half"), // output type + ::testing::Values(32), // jm_m + ::testing::Values(8), // jm_n + ::testing::Values(16), // jm_n + ::testing::Values("usm", "buf"), // allocation type + ::testing::Values(33), // offset + ::testing::Values(1), // batch + ::testing::Values(11, 16, 32, 63), // m + ::testing::Values(11, 16, 32, 63), // n + ::testing::Values(17, 33, 64), // k + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + ::testing::Values(1.5), // alpha + ::testing::Values(0, 1.5), // beta + ::testing::Values(1), // lda_mul + ::testing::Values(1), // ldb_mul + ::testing::Values(1), // ldc_mul + ::testing::Values(gemm_batch_type_t::strided) // batch_type +); +GENERATE_JOINTMATRIX_TEST(JointMatrix, SmallMatricesHalfHalfm32n8k16); + +template +const auto MediumMatricesHalfHalfm32n8k16 = ::testing::Combine( + ::testing::Values("half"), // input type + ::testing::Values("half"), // output type + ::testing::Values(32), // jm_m + ::testing::Values(8), // jm_n + ::testing::Values(16), // jm_n + ::testing::Values("usm", "buf"), // allocation type + ::testing::Values(33), // offset + ::testing::Values(1), // batch + ::testing::Values(65, 127, 234, 511), // m + ::testing::Values(65, 127, 234, 511), // n + ::testing::Values(65, 127), // k + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + ::testing::Values(1.5), // alpha + ::testing::Values(0, 1.5), // beta + ::testing::Values(1), // lda_mul + ::testing::Values(1), // ldb_mul + ::testing::Values(1), // ldc_mul + ::testing::Values(gemm_batch_type_t::strided) // batch_type +); +GENERATE_JOINTMATRIX_TEST(JointMatrix, MediumMatricesHalfHalfm32n8k16); + +template +const auto LargeMatricesHalfHalfm32n8k16 = ::testing::Combine( + ::testing::Values("half"), // input type + ::testing::Values("half"), // output type + ::testing::Values(32), // jm_m + ::testing::Values(8), // jm_n + ::testing::Values(16), // jm_n + ::testing::Values("usm", "buf"), // allocation type + ::testing::Values(33), // offset + ::testing::Values(1), // batch + ::testing::Values(1024, 1535, 2024), // m + ::testing::Values(1024, 1535, 2024), // n + ::testing::Values(1536, 2049), // k + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + ::testing::Values(1.5), // alpha + ::testing::Values(0, 1.5), // beta + ::testing::Values(1), // lda_mul + ::testing::Values(1), // ldb_mul + ::testing::Values(1), // ldc_mul + ::testing::Values(gemm_batch_type_t::strided) // batch_type +); +GENERATE_JOINTMATRIX_TEST(JointMatrix, LargeMatricesHalfHalfm32n8k16); diff --git a/test/unittest/joint_matrix/half_half_8_32_16.cpp b/test/unittest/joint_matrix/half_half_8_32_16.cpp new file mode 100644 index 000000000..94a4ebeaf --- /dev/null +++ b/test/unittest/joint_matrix/half_half_8_32_16.cpp @@ -0,0 +1,99 @@ +/*************************************************************************** + * + * @license + * Copyright (C) Codeplay Software Limited + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * For your convenience, a copy of the License has been included in this + * repository. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * portBLAS: BLAS implementation using SYCL + * + * @filename half_half_8_32_16.cpp + * + **************************************************************************/ + +#include "blas_test.hpp" +#include "joint_matrix_common.hpp" + +template +const auto SmallMatricesHalfHalfm8n32k16 = ::testing::Combine( + ::testing::Values("half"), // input type + ::testing::Values("half"), // output type + ::testing::Values(8), // jm_m + ::testing::Values(32), // jm_n + ::testing::Values(16), // jm_n + ::testing::Values("usm", "buf"), // allocation type + ::testing::Values(33), // offset + ::testing::Values(1), // batch + ::testing::Values(11, 16, 32, 63), // m + ::testing::Values(11, 16, 32, 63), // n + ::testing::Values(17, 33, 64), // k + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + ::testing::Values(1.5), // alpha + ::testing::Values(0, 1.5), // beta + ::testing::Values(1), // lda_mul + ::testing::Values(1), // ldb_mul + ::testing::Values(1), // ldc_mul + ::testing::Values(gemm_batch_type_t::strided) // batch_type +); +GENERATE_JOINTMATRIX_TEST(JointMatrix, SmallMatricesHalfHalfm8n32k16); + +template +const auto MediumMatricesHalfHalfm8n32k16 = ::testing::Combine( + ::testing::Values("half"), // input type + ::testing::Values("half"), // output type + ::testing::Values(8), // jm_m + ::testing::Values(32), // jm_n + ::testing::Values(16), // jm_n + ::testing::Values("usm", "buf"), // allocation type + ::testing::Values(33), // offset + ::testing::Values(1), // batch + ::testing::Values(65, 127, 234, 511), // m + ::testing::Values(65, 127, 234, 511), // n + ::testing::Values(65, 127), // k + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + ::testing::Values(1.5), // alpha + ::testing::Values(0, 1.5), // beta + ::testing::Values(1), // lda_mul + ::testing::Values(1), // ldb_mul + ::testing::Values(1), // ldc_mul + ::testing::Values(gemm_batch_type_t::strided) // batch_type +); +GENERATE_JOINTMATRIX_TEST(JointMatrix, MediumMatricesHalfHalfm8n32k16); + +template +const auto LargeMatricesHalfHalfm8n32k16 = ::testing::Combine( + ::testing::Values("half"), // input type + ::testing::Values("half"), // output type + ::testing::Values(8), // jm_m + ::testing::Values(32), // jm_n + ::testing::Values(16), // jm_n + ::testing::Values("usm", "buf"), // allocation type + ::testing::Values(33), // offset + ::testing::Values(1), // batch + ::testing::Values(1024, 1535, 2024), // m + ::testing::Values(1024, 1535, 2024), // n + ::testing::Values(1536, 2049), // k + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + ::testing::Values(1.5), // alpha + ::testing::Values(0, 1.5), // beta + ::testing::Values(1), // lda_mul + ::testing::Values(1), // ldb_mul + ::testing::Values(1), // ldc_mul + ::testing::Values(gemm_batch_type_t::strided) // batch_type +); +GENERATE_JOINTMATRIX_TEST(JointMatrix, LargeMatricesHalfHalfm8n32k16); diff --git a/test/unittest/joint_matrix/joint_matrix_common.hpp b/test/unittest/joint_matrix/joint_matrix_common.hpp new file mode 100644 index 000000000..c18366a83 --- /dev/null +++ b/test/unittest/joint_matrix/joint_matrix_common.hpp @@ -0,0 +1,268 @@ +/*************************************************************************** + * + * @license + * Copyright (C) Codeplay Software Limited + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * For your convenience, a copy of the License has been included in this + * repository. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * portBLAS: BLAS implementation using SYCL + * + * @filename joint_matrix_common.hpp + * + **************************************************************************/ + +#include "launch_gemm.hpp" + +template +using joint_matrix_arguments_t = + std::tuple; + +template +inline void verify_gemm(const joint_matrix_arguments_t arguments) { + std::string jm_inType, jm_outType; + index_t jm_m, jm_n, jm_k; + std::string alloc; + index_t offset; + index_t batch; + index_t m; + index_t n; + index_t k; + char transa; + char transb; + scalar_t alpha; + scalar_t beta; + index_t lda_mul; + index_t ldb_mul; + index_t ldc_mul; + gemm_batch_type_t batch_type; + std::tie(jm_inType, jm_outType, jm_m, jm_n, jm_k, alloc, offset, batch, m, n, + k, transa, transb, alpha, beta, lda_mul, ldb_mul, ldc_mul, + batch_type) = arguments; + + assert(batch_type == gemm_batch_type_t::strided); + + const char ta_str[2] = {transa, '\0'}; + const char tb_str[2] = {transb, '\0'}; + + auto q = make_queue(); + blas::SB_Handle sb_handle(q); + + const index_t lda = ((transa != 'n') ? k : m) * lda_mul; + const index_t ldb = ((transb != 'n') ? n : k) * ldb_mul; + const index_t ldc = m * ldc_mul; + + const index_t size_a = m * k * lda_mul; + const index_t size_b = k * n * ldb_mul; + const index_t size_c = m * n * ldc_mul; + + const index_t buffer_size_a = batch * size_a + offset; + const index_t buffer_size_b = batch * size_b + offset; + const index_t buffer_size_c = batch * size_c + offset; + + std::vector a_m(buffer_size_a); + std::vector b_m(buffer_size_b); + std::vector c_m_gpu(buffer_size_c); + + if (jm_outType == "half") { + // initialize the vectors with positive values + // to avoid test failures for half precision + // accumulation + fill_random_with_range(a_m, scalar_t{1}, scalar_t{2}); + fill_random_with_range(b_m, scalar_t{1}, scalar_t{2}); + fill_random_with_range(c_m_gpu, scalar_t{1}, scalar_t{2}); + } else { + fill_random(a_m); + fill_random(b_m); + fill_random(c_m_gpu); + } + + index_t nbits = 13; + if (jm_inType == "bfloat16") { + nbits = 16; + } + set_to_zero_last_nbits(a_m, nbits); + set_to_zero_last_nbits(b_m, nbits); + set_to_zero_last_nbits(c_m_gpu, nbits); + set_to_zero_last_nbits(alpha, nbits); + set_to_zero_last_nbits(beta, nbits); + + std::vector c_m_cpu = c_m_gpu; + + // Use system blas to create a reference output + for (int i = 0; i < batch; ++i) { + reference_blas::gemm(ta_str, tb_str, m, n, k, alpha, + a_m.data() + i * size_a + offset, lda, + b_m.data() + i * size_b + offset, ldb, beta, + c_m_cpu.data() + i * size_c + offset, ldc); + } + + auto m_a_gpu = blas::helper::allocate(buffer_size_a, q); + auto m_b_gpu = blas::helper::allocate(buffer_size_b, q); + auto m_c_gpu = blas::helper::allocate(buffer_size_c, q); + + auto copy_a = + blas::helper::copy_to_device(q, a_m.data(), m_a_gpu, buffer_size_a); + auto copy_b = + blas::helper::copy_to_device(q, b_m.data(), m_b_gpu, buffer_size_b); + auto copy_c = + blas::helper::copy_to_device(q, c_m_gpu.data(), m_c_gpu, buffer_size_c); + + // portBLAS GEMM implementation + typename blas::SB_Handle::event_t gemm_event; + if (jm_inType == "half" && jm_outType == "float") { + if (jm_m == 16 && jm_n == 16) { + gemm_event = launch_gemm_with_beta<16, 16, 16, cl::sycl::half, float>( + sb_handle, transa, transb, m, n, k, alpha, m_a_gpu + offset, lda, + size_a, m_b_gpu + offset, ldb, size_b, beta, m_c_gpu + offset, ldc, + size_c, batch, batch_type, {copy_a, copy_b, copy_c}); + } else if (jm_m == 32 && jm_n == 8) { + gemm_event = launch_gemm_with_beta<32, 8, 16, cl::sycl::half, float>( + sb_handle, transa, transb, m, n, k, alpha, m_a_gpu + offset, lda, + size_a, m_b_gpu + offset, ldb, size_b, beta, m_c_gpu + offset, ldc, + size_c, batch, batch_type, {copy_a, copy_b, copy_c}); + } else if (jm_n == 32 && jm_m == 8) { + gemm_event = launch_gemm_with_beta<8, 32, 16, cl::sycl::half, float>( + sb_handle, transa, transb, m, n, k, alpha, m_a_gpu + offset, lda, + size_a, m_b_gpu + offset, ldb, size_b, beta, m_c_gpu + offset, ldc, + size_c, batch, batch_type, {copy_a, copy_b, copy_c}); + } + } else if (jm_inType == "half" && jm_outType == "half") { + if (jm_m == 16 && jm_n == 16) { + gemm_event = + launch_gemm_with_beta<16, 16, 16, cl::sycl::half, cl::sycl::half>( + sb_handle, transa, transb, m, n, k, alpha, m_a_gpu + offset, lda, + size_a, m_b_gpu + offset, ldb, size_b, beta, m_c_gpu + offset, + ldc, size_c, batch, batch_type, {copy_a, copy_b, copy_c}); + } else if (jm_m == 32 && jm_n == 8) { + gemm_event = + launch_gemm_with_beta<32, 8, 16, cl::sycl::half, cl::sycl::half>( + sb_handle, transa, transb, m, n, k, alpha, m_a_gpu + offset, lda, + size_a, m_b_gpu + offset, ldb, size_b, beta, m_c_gpu + offset, + ldc, size_c, batch, batch_type, {copy_a, copy_b, copy_c}); + } else if (jm_n == 32 && jm_m == 8) { + gemm_event = + launch_gemm_with_beta<8, 32, 16, cl::sycl::half, cl::sycl::half>( + sb_handle, transa, transb, m, n, k, alpha, m_a_gpu + offset, lda, + size_a, m_b_gpu + offset, ldb, size_b, beta, m_c_gpu + offset, + ldc, size_c, batch, batch_type, {copy_a, copy_b, copy_c}); + } + } else if (jm_inType == "bfloat16" && jm_outType == "float") { + if (jm_m == 16 && jm_n == 16) { + gemm_event = + launch_gemm_with_beta<16, 16, 16, sycl::ext::oneapi::bfloat16, float>( + sb_handle, transa, transb, m, n, k, alpha, m_a_gpu + offset, lda, + size_a, m_b_gpu + offset, ldb, size_b, beta, m_c_gpu + offset, + ldc, size_c, batch, batch_type, {copy_a, copy_b, copy_c}); + } else if (jm_m == 32 && jm_n == 8) { + gemm_event = + launch_gemm_with_beta<32, 8, 16, sycl::ext::oneapi::bfloat16, float>( + sb_handle, transa, transb, m, n, k, alpha, m_a_gpu + offset, lda, + size_a, m_b_gpu + offset, ldb, size_b, beta, m_c_gpu + offset, + ldc, size_c, batch, batch_type, {copy_a, copy_b, copy_c}); + } else if (jm_n == 32 && jm_m == 8) { + gemm_event = + launch_gemm_with_beta<8, 32, 16, sycl::ext::oneapi::bfloat16, float>( + sb_handle, transa, transb, m, n, k, alpha, m_a_gpu + offset, lda, + size_a, m_b_gpu + offset, ldb, size_b, beta, m_c_gpu + offset, + ldc, size_c, batch, batch_type, {copy_a, copy_b, copy_c}); + } + } else if (jm_inType == "tf32" && jm_outType == "float") { + using namespace sycl::ext::oneapi::experimental::matrix; + gemm_event = launch_gemm_with_beta<16, 16, 8, precision::tf32, float>( + sb_handle, transa, transb, m, n, k, alpha, m_a_gpu + offset, lda, + size_a, m_b_gpu + offset, ldb, size_b, beta, m_c_gpu + offset, ldc, + size_c, batch, batch_type, {copy_a, copy_b, copy_c}); + } + sb_handle.wait(gemm_event); + + auto event = + blas::helper::copy_to_host(q, m_c_gpu, c_m_gpu.data(), buffer_size_c); + sb_handle.wait(event); + + const bool isAlmostEqual = utils::compare_vectors( + c_m_gpu, c_m_cpu, std::cerr, "\n", jm_outType == "half" ? 3 : 1); + ASSERT_TRUE(isAlmostEqual); + + helper::deallocate(m_a_gpu, q); + helper::deallocate(m_b_gpu, q); + helper::deallocate(m_c_gpu, q); +} + +template +inline void verify_gemm(const joint_matrix_arguments_t arguments) { + std::string jm_inType, jm_OutType; + index_t jm_m, jm_n, jm_k; + std::string alloc; + index_t offset; + index_t batch; + index_t m; + index_t n; + index_t k; + char transa; + char transb; + scalar_t alpha; + scalar_t beta; + index_t lda_mul; + index_t ldb_mul; + index_t ldc_mul; + gemm_batch_type_t batch_type; + std::tie(jm_inType, jm_OutType, jm_m, jm_n, jm_k, alloc, offset, batch, m, n, + k, transa, transb, alpha, beta, lda_mul, ldb_mul, ldc_mul, + batch_type) = arguments; + + if (alloc == "usm") { +#ifdef SB_ENABLE_USM + verify_gemm(arguments); +#else + GTEST_SKIP(); +#endif + } else { + verify_gemm(arguments); + } +} + +template <> +inline void dump_arg(std::ostream& ss, + gemm_batch_type_t batch_type) { + ss << (int)batch_type; +} + +template +static std::string generate_name( + const ::testing::TestParamInfo>& info) { + std::string jm_inType, jm_OutType; + int jm_m, jm_n, jm_k; + std::string alloc; + int offset, batch, m, n, k, ldaMul, ldbMul, ldcMul; + char transa, transb; + T alpha, beta; + gemm_batch_type_t batchType; + BLAS_GENERATE_NAME(info.param, jm_inType, jm_OutType, jm_m, jm_n, jm_k, alloc, + offset, batch, m, n, k, transa, transb, alpha, beta, + ldaMul, ldbMul, ldcMul, batchType); +} + +/** Registers Joint Matrix test for all supported data types (only float for + * now) + * @param test_suite Name of the test suite + * @param combination Combinations object + * @see BLAS_REGISTER_TEST_CUSTOM_NAME + */ +#define GENERATE_JOINTMATRIX_TEST(test_suite, combination) \ + BLAS_REGISTER_TEST_FLOAT_CUSTOM_NAME(test_suite, test_suite##combination, \ + verify_gemm, joint_matrix_arguments_t, \ + combination, generate_name); diff --git a/test/unittest/joint_matrix/launch_gemm.hpp b/test/unittest/joint_matrix/launch_gemm.hpp new file mode 100644 index 000000000..afab6a0b3 --- /dev/null +++ b/test/unittest/joint_matrix/launch_gemm.hpp @@ -0,0 +1,247 @@ +/*************************************************************************** + * + * @license + * Copyright (C) Codeplay Software Limited + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * For your convenience, a copy of the License has been included in this + * repository. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * portBLAS: BLAS implementation using SYCL + * + * @filename launch_gemm.hpp + * + **************************************************************************/ + +#include "blas_test.hpp" +#include "interface/gemm_launcher.hpp" +#include "portblas.hpp" +#include + +template +PORTBLAS_ALWAYS_INLINE + typename std::enable_if::type + launch_gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, + element_t _alpha, container_0_t _a, index_t _lda, + index_t _stridea, container_1_t _b, index_t _ldb, + index_t _strideb, element_t _beta, container_2_t _c, + index_t _ldc, index_t _stridec, index_t batch_size, + gemm_batch_type_t batch_type, + const typename sb_handle_t::event_t& _dependencies) { + if (_M > 1024 && _N > 1024) { + return blas::Gemm_Launcher< + container_0_t, container_1_t, container_2_t, 256, false, true, true, + 128, + Tile<8, 8, 16, 16, 16, 2, 1, 1, 1, 1, jm_M, jm_N, jm_K, inp_jmT, + out_jmT>, + _t_a, _t_b, s_a, s_b, static_cast(gemm_memory_t::local), + static_cast(gemm_algorithm_t::standard), + static_cast(gemm_vectorization_t::none), is_beta_zero, 1, + static_cast(gemm_batch_type_t::strided), + true>::template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, + _stridea, _b, _ldb, _strideb, _beta, _c, + _ldc, _stridec, batch_size, _dependencies); + } else if (_M > 64 && _N > 64) { + return blas::Gemm_Launcher< + container_0_t, container_1_t, container_2_t, 128, false, true, true, + 128, + Tile<4, 8, 16, 8, 16, 2, 1, 1, 1, 1, jm_M, jm_N, jm_K, inp_jmT, + out_jmT>, + _t_a, _t_b, s_a, s_b, static_cast(gemm_memory_t::local), + static_cast(gemm_algorithm_t::standard), + static_cast(gemm_vectorization_t::none), is_beta_zero, 1, + static_cast(gemm_batch_type_t::strided), + true>::template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, + _stridea, _b, _ldb, _strideb, _beta, _c, + _ldc, _stridec, batch_size, _dependencies); + + } else { + return blas::Gemm_Launcher< + container_0_t, container_1_t, container_2_t, 128, false, true, true, + 128, + Tile<2, 4, 16, 8, 16, 2, 1, 1, 1, 1, jm_M, jm_N, jm_K, inp_jmT, + out_jmT>, + _t_a, _t_b, s_a, s_b, static_cast(gemm_memory_t::local), + static_cast(gemm_algorithm_t::standard), + static_cast(gemm_vectorization_t::none), is_beta_zero, 1, + static_cast(gemm_batch_type_t::strided), + true>::template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, + _stridea, _b, _ldb, _strideb, _beta, _c, + _ldc, _stridec, batch_size, _dependencies); + } +} + +template +PORTBLAS_ALWAYS_INLINE + typename std::enable_if::type + launch_gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, + element_t _alpha, container_0_t _a, index_t _lda, + index_t _stridea, container_1_t _b, index_t _ldb, + index_t _strideb, element_t _beta, container_2_t _c, + index_t _ldc, index_t _stridec, index_t batch_size, + gemm_batch_type_t batch_type, + const typename sb_handle_t::event_t& _dependencies) { + if (_M > 1024 && _N > 1024) { + return blas::Gemm_Launcher< + container_0_t, container_1_t, container_2_t, 128, false, true, true, + 128, + Tile<8, 16, 16, 8, 16, 2, 1, 1, 1, 1, jm_M, jm_N, jm_K, inp_jmT, + out_jmT>, + _t_a, _t_b, s_a, s_b, static_cast(gemm_memory_t::local), + static_cast(gemm_algorithm_t::standard), + static_cast(gemm_vectorization_t::none), is_beta_zero, 1, + static_cast(gemm_batch_type_t::strided), + true>::template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, + _stridea, _b, _ldb, _strideb, _beta, _c, + _ldc, _stridec, batch_size, _dependencies); + } else if (_M > 64 && _N > 64) { + return blas::Gemm_Launcher< + container_0_t, container_1_t, container_2_t, 64, false, true, true, 128, + Tile<8, 8, 8, 8, 16, 2, 1, 1, 1, 1, jm_M, jm_N, jm_K, inp_jmT, out_jmT>, + _t_a, _t_b, s_a, s_b, static_cast(gemm_memory_t::local), + static_cast(gemm_algorithm_t::standard), + static_cast(gemm_vectorization_t::none), is_beta_zero, 1, + static_cast(gemm_batch_type_t::strided), + true>::template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, + _stridea, _b, _ldb, _strideb, _beta, _c, + _ldc, _stridec, batch_size, _dependencies); + + } else { + return blas::Gemm_Launcher< + container_0_t, container_1_t, container_2_t, 128, false, true, true, + 128, + Tile<2, 4, 16, 8, 16, 2, 1, 1, 1, 1, jm_M, jm_N, jm_K, inp_jmT, + out_jmT>, + _t_a, _t_b, s_a, s_b, static_cast(gemm_memory_t::local), + static_cast(gemm_algorithm_t::standard), + static_cast(gemm_vectorization_t::none), is_beta_zero, 1, + static_cast(gemm_batch_type_t::strided), + true>::template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, + _stridea, _b, _ldb, _strideb, _beta, _c, + _ldc, _stridec, batch_size, _dependencies); + } +} + +template +PORTBLAS_ALWAYS_INLINE + typename std::enable_if::type + launch_gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, + element_t _alpha, container_0_t _a, index_t _lda, + index_t _stridea, container_1_t _b, index_t _ldb, + index_t _strideb, element_t _beta, container_2_t _c, + index_t _ldc, index_t _stridec, index_t batch_size, + gemm_batch_type_t batch_type, + const typename sb_handle_t::event_t& _dependencies) { + if (_M > 1024 && _N > 1024) { + return blas::Gemm_Launcher< + container_0_t, container_1_t, container_2_t, 256, false, true, true, + 128, + Tile<4, 4, 16, 16, 16, 2, 1, 1, 1, 1, jm_M, jm_N, jm_K, inp_jmT, + out_jmT>, + _t_a, _t_b, s_a, s_b, static_cast(gemm_memory_t::local), + static_cast(gemm_algorithm_t::standard), + static_cast(gemm_vectorization_t::none), is_beta_zero, 1, + static_cast(gemm_batch_type_t::strided), + true>::template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, + _stridea, _b, _ldb, _strideb, _beta, _c, + _ldc, _stridec, batch_size, _dependencies); + } else { + return blas::Gemm_Launcher< + container_0_t, container_1_t, container_2_t, 128, false, true, true, + 128, + Tile<2, 4, 16, 8, 16, 2, 1, 1, 1, 1, jm_M, jm_N, jm_K, inp_jmT, + out_jmT>, + _t_a, _t_b, s_a, s_b, static_cast(gemm_memory_t::local), + static_cast(gemm_algorithm_t::standard), + static_cast(gemm_vectorization_t::none), is_beta_zero, 1, + static_cast(gemm_batch_type_t::strided), + true>::template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, + _stridea, _b, _ldb, _strideb, _beta, _c, + _ldc, _stridec, batch_size, _dependencies); + } +} + +template +PORTBLAS_ALWAYS_INLINE typename sb_handle_t::event_t launch_gemm_with_transpose( + sb_handle_t& sb_handle, char _trans_a, char _trans_b, index_t _M, + index_t _N, index_t _K, element_t _alpha, container_0_t _a, index_t _lda, + index_t _stridea, container_1_t _b, index_t _ldb, index_t _strideb, + element_t _beta, container_2_t _c, index_t _ldc, index_t _stridec, + index_t batch_size, gemm_batch_type_t batch_type, + const typename sb_handle_t::event_t& _dependencies) { + typename sb_handle_t::event_t gemm_event; + if (_trans_a == 't' && _trans_b == 't') { + gemm_event = launch_gemm( + sb_handle, _M, _N, _K, _alpha, _a, _lda, _stridea, _b, _ldb, _strideb, + _beta, _c, _ldc, _stridec, batch_size, batch_type, _dependencies); + } else if (_trans_a == 'n' && _trans_b == 'n') { + gemm_event = launch_gemm( + sb_handle, _M, _N, _K, _alpha, _a, _lda, _stridea, _b, _ldb, _strideb, + _beta, _c, _ldc, _stridec, batch_size, batch_type, _dependencies); + } else if (_trans_a == 't' && _trans_b == 'n') { + gemm_event = launch_gemm( + sb_handle, _M, _N, _K, _alpha, _a, _lda, _stridea, _b, _ldb, _strideb, + _beta, _c, _ldc, _stridec, batch_size, batch_type, _dependencies); + } else if (_trans_a == 'n' && _trans_b == 't') { + gemm_event = launch_gemm( + sb_handle, _M, _N, _K, _alpha, _a, _lda, _stridea, _b, _ldb, _strideb, + _beta, _c, _ldc, _stridec, batch_size, batch_type, _dependencies); + } + return gemm_event; +} + +template +PORTBLAS_ALWAYS_INLINE typename sb_handle_t::event_t launch_gemm_with_beta( + sb_handle_t& sb_handle, char _trans_a, char _trans_b, index_t _M, + index_t _N, index_t _K, element_t _alpha, container_0_t _a, index_t _lda, + index_t _stridea, container_1_t _b, index_t _ldb, index_t _strideb, + element_t _beta, container_2_t _c, index_t _ldc, index_t _stridec, + index_t batch_size, gemm_batch_type_t batch_type, + const typename sb_handle_t::event_t& _dependencies) { + typename sb_handle_t::event_t gemm_event; + if (_beta == (element_t)0) { + gemm_event = + launch_gemm_with_transpose( + sb_handle, _trans_a, _trans_b, _M, _N, _K, _alpha, _a, _lda, + _stridea, _b, _ldb, _strideb, _beta, _c, _ldc, _stridec, batch_size, + batch_type, _dependencies); + } else { + gemm_event = + launch_gemm_with_transpose( + sb_handle, _trans_a, _trans_b, _M, _N, _K, _alpha, _a, _lda, + _stridea, _b, _ldb, _strideb, _beta, _c, _ldc, _stridec, batch_size, + batch_type, _dependencies); + } + return gemm_event; +} diff --git a/test/unittest/joint_matrix/tf32_float_16_16_8.cpp b/test/unittest/joint_matrix/tf32_float_16_16_8.cpp new file mode 100644 index 000000000..f7facadc3 --- /dev/null +++ b/test/unittest/joint_matrix/tf32_float_16_16_8.cpp @@ -0,0 +1,99 @@ +/*************************************************************************** + * + * @license + * Copyright (C) Codeplay Software Limited + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * For your convenience, a copy of the License has been included in this + * repository. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * portBLAS: BLAS implementation using SYCL + * + * @filename tf32_float_16_16_8.cpp + * + **************************************************************************/ + +#include "blas_test.hpp" +#include "joint_matrix_common.hpp" + +template +const auto SmallMatricesTF32Floatm16n16k8 = ::testing::Combine( + ::testing::Values("tf32"), // input type + ::testing::Values("float"), // output type + ::testing::Values(16), // jm_m + ::testing::Values(16), // jm_n + ::testing::Values(8), // jm_k + ::testing::Values("usm", "buf"), // allocation type + ::testing::Values(33), // offset + ::testing::Values(1), // batch + ::testing::Values(11, 16, 32, 63), // m + ::testing::Values(11, 16, 32, 63), // n + ::testing::Values(17, 33, 64), // k + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + ::testing::Values(1.5), // alpha + ::testing::Values(0, 1.5), // beta + ::testing::Values(1), // lda_mul + ::testing::Values(1), // ldb_mul + ::testing::Values(1), // ldc_mul + ::testing::Values(gemm_batch_type_t::strided) // batch_type +); +GENERATE_JOINTMATRIX_TEST(JointMatrix, SmallMatricesTF32Floatm16n16k8); + +template +const auto MediumMatricesTF32Floatm16n16k8 = ::testing::Combine( + ::testing::Values("tf32"), // input type + ::testing::Values("float"), // output type + ::testing::Values(16), // jm_m + ::testing::Values(16), // jm_n + ::testing::Values(8), // jm_k + ::testing::Values("usm", "buf"), // allocation type + ::testing::Values(33), // offset + ::testing::Values(1), // batch + ::testing::Values(65, 127, 234, 511), // m + ::testing::Values(65, 127, 234, 511), // n + ::testing::Values(65, 127), // k + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + ::testing::Values(1.5), // alpha + ::testing::Values(0, 1.5), // beta + ::testing::Values(1), // lda_mul + ::testing::Values(1), // ldb_mul + ::testing::Values(1), // ldc_mul + ::testing::Values(gemm_batch_type_t::strided) // batch_type +); +GENERATE_JOINTMATRIX_TEST(JointMatrix, MediumMatricesTF32Floatm16n16k8); + +template +const auto LargeMatricesTF32Floatm16n16k8 = ::testing::Combine( + ::testing::Values("tf32"), // input type + ::testing::Values("float"), // output type + ::testing::Values(16), // jm_m + ::testing::Values(16), // jm_n + ::testing::Values(8), // jm_k + ::testing::Values("usm", "buf"), // allocation type + ::testing::Values(33), // offset + ::testing::Values(1), // batch + ::testing::Values(1024, 1535, 2024), // m + ::testing::Values(1024, 1535, 2024), // n + ::testing::Values(1536, 2049), // k + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + ::testing::Values(1.5), // alpha + ::testing::Values(0, 1.5), // beta + ::testing::Values(1), // lda_mul + ::testing::Values(1), // ldb_mul + ::testing::Values(1), // ldc_mul + ::testing::Values(gemm_batch_type_t::strided) // batch_type +); +GENERATE_JOINTMATRIX_TEST(JointMatrix, LargeMatricesTF32Floatm16n16k8);