From 8e87927ef839e7b0885c2e2d6cb7b1f94d7e769b Mon Sep 17 00:00:00 2001 From: "Jiang, Zhiwei" Date: Wed, 3 Jul 2024 16:27:31 +0800 Subject: [PATCH] [SYCLomatic] Support more gemm_batch data type combinations when using opensource oneMKL and fix a type casting bug Signed-off-by: Jiang, Zhiwei --- clang/runtime/dpct-rt/include/dpct/blas_utils.hpp | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/clang/runtime/dpct-rt/include/dpct/blas_utils.hpp b/clang/runtime/dpct-rt/include/dpct/blas_utils.hpp index 5420fc3a4c28..1efd49aac46e 100644 --- a/clang/runtime/dpct-rt/include/dpct/blas_utils.hpp +++ b/clang/runtime/dpct-rt/include/dpct/blas_utils.hpp @@ -1844,6 +1844,7 @@ inline void gemm_batch(descriptor_ptr desc_ptr, oneapi::mkl::transpose a_trans, ldc, batch_size, cm); break; } +#endif case dpct::detail::get_type_combination_id( library_data_t::real_int8, library_data_t::real_int8, library_data_t::real_int32, library_data_t::real_int32): { @@ -1873,7 +1874,6 @@ inline void gemm_batch(descriptor_ptr desc_ptr, oneapi::mkl::transpose a_trans, batch_size, cm); break; } -#endif case dpct::detail::get_type_combination_id( library_data_t::real_half, library_data_t::real_half, library_data_t::real_half, library_data_t::real_float): { @@ -2022,13 +2022,18 @@ inline void gemm_batch(descriptor_ptr desc_ptr, oneapi::mkl::transpose a_trans, beta, c, ldc, stride_c, batch_size, cm); break; } +#endif case dpct::detail::get_type_combination_id( library_data_t::real_int8, library_data_t::real_int8, library_data_t::real_int32, library_data_t::real_int32): { + float alpha_float = + dpct::get_value(reinterpret_cast(alpha), q); + float beta_float = + dpct::get_value(reinterpret_cast(beta), q); dpct::detail::gemm_batch_impl( - q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, - beta, c, ldc, stride_c, batch_size, cm); + float>( + q, a_trans, b_trans, m, n, k, &alpha_float, a, lda, stride_a, b, ldb, + stride_b, &beta_float, c, ldc, stride_c, batch_size, cm); break; } case dpct::detail::get_type_combination_id( @@ -2047,7 +2052,6 @@ inline void gemm_batch(descriptor_ptr desc_ptr, oneapi::mkl::transpose a_trans, beta, c, ldc, stride_c, batch_size, cm); break; } -#endif case dpct::detail::get_type_combination_id( library_data_t::real_half, library_data_t::real_half, library_data_t::real_half, library_data_t::real_float): {