From 6019a73acf2f3a748f154022a070dd4206159c91 Mon Sep 17 00:00:00 2001 From: "Jiang, Zhiwei" Date: Tue, 16 Jul 2024 10:20:16 +0800 Subject: [PATCH] [SYCLomatic] Support more gemm_batch data type combinations when using opensource oneMKL and fix a type casting bug (#2121) Signed-off-by: Jiang, Zhiwei --- .../dpct-rt/include/dpct/blas_utils.hpp | 29 +++++++++++-------- 1 file changed, 17 insertions(+), 12 deletions(-) diff --git a/clang/runtime/dpct-rt/include/dpct/blas_utils.hpp b/clang/runtime/dpct-rt/include/dpct/blas_utils.hpp index ef406a29a025..7ca0d9855115 100644 --- a/clang/runtime/dpct-rt/include/dpct/blas_utils.hpp +++ b/clang/runtime/dpct-rt/include/dpct/blas_utils.hpp @@ -1846,6 +1846,7 @@ inline void gemm_batch(descriptor_ptr desc_ptr, oneapi::mkl::transpose a_trans, ldc, batch_size, cm); break; } +#endif case dpct::detail::get_type_combination_id( library_data_t::real_int8, library_data_t::real_int8, library_data_t::real_int32, library_data_t::real_int32): { @@ -1854,9 +1855,9 @@ inline void gemm_batch(descriptor_ptr desc_ptr, oneapi::mkl::transpose a_trans, float beta_float = dpct::get_value(reinterpret_cast(beta), q); dpct::detail::gemm_batch_impl(q, a_trans, b_trans, m, n, k, - &alpha_float, a, lda, b, ldb, - &beta_float, c, ldc, batch_size, cm); + float>( + q, a_trans, b_trans, m, n, k, &alpha_float, a, lda, b, ldb, &beta_float, + c, ldc, batch_size DPCT_COMPUTE_MODE_ARG); break; } case dpct::detail::get_type_combination_id( @@ -1864,7 +1865,7 @@ inline void gemm_batch(descriptor_ptr desc_ptr, oneapi::mkl::transpose a_trans, library_data_t::real_float, library_data_t::real_float): { dpct::detail::gemm_batch_impl( q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, - batch_size, cm); + batch_size DPCT_COMPUTE_MODE_ARG); break; } case dpct::detail::get_type_combination_id( @@ -1872,10 +1873,9 @@ inline void gemm_batch(descriptor_ptr desc_ptr, oneapi::mkl::transpose a_trans, library_data_t::real_float, library_data_t::real_float): { dpct::detail::gemm_batch_impl( q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, - batch_size, cm); + batch_size DPCT_COMPUTE_MODE_ARG); break; } -#endif case dpct::detail::get_type_combination_id( library_data_t::real_half, library_data_t::real_half, library_data_t::real_half, library_data_t::real_float): { @@ -2024,13 +2024,19 @@ inline void gemm_batch(descriptor_ptr desc_ptr, oneapi::mkl::transpose a_trans, beta, c, ldc, stride_c, batch_size, cm); break; } +#endif case dpct::detail::get_type_combination_id( library_data_t::real_int8, library_data_t::real_int8, library_data_t::real_int32, library_data_t::real_int32): { + float alpha_float = + dpct::get_value(reinterpret_cast(alpha), q); + float beta_float = + dpct::get_value(reinterpret_cast(beta), q); dpct::detail::gemm_batch_impl( - q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, - beta, c, ldc, stride_c, batch_size, cm); + float>( + q, a_trans, b_trans, m, n, k, &alpha_float, a, lda, stride_a, b, ldb, + stride_b, &beta_float, c, ldc, stride_c, + batch_size DPCT_COMPUTE_MODE_ARG); break; } case dpct::detail::get_type_combination_id( @@ -2038,7 +2044,7 @@ inline void gemm_batch(descriptor_ptr desc_ptr, oneapi::mkl::transpose a_trans, library_data_t::real_float, library_data_t::real_float): { dpct::detail::gemm_batch_impl( q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, - beta, c, ldc, stride_c, batch_size, cm); + beta, c, ldc, stride_c, batch_size DPCT_COMPUTE_MODE_ARG); break; } case dpct::detail::get_type_combination_id( @@ -2046,10 +2052,9 @@ inline void gemm_batch(descriptor_ptr desc_ptr, oneapi::mkl::transpose a_trans, library_data_t::real_float, library_data_t::real_float): { dpct::detail::gemm_batch_impl( q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, - beta, c, ldc, stride_c, batch_size, cm); + beta, c, ldc, stride_c, batch_size DPCT_COMPUTE_MODE_ARG); break; } -#endif case dpct::detail::get_type_combination_id( library_data_t::real_half, library_data_t::real_half, library_data_t::real_half, library_data_t::real_float): {