Skip to content

Commit

Permalink
[SYCLomatic] Support more gemm_batch data type combinations when usin…
Browse files Browse the repository at this point in the history
…g opensource oneMKL and fix a type casting bug

Signed-off-by: Jiang, Zhiwei <[email protected]>
  • Loading branch information
zhiweij1 committed Jul 3, 2024
1 parent b68d623 commit 8e87927
Showing 1 changed file with 9 additions and 5 deletions.
14 changes: 9 additions & 5 deletions clang/runtime/dpct-rt/include/dpct/blas_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1844,6 +1844,7 @@ inline void gemm_batch(descriptor_ptr desc_ptr, oneapi::mkl::transpose a_trans,
ldc, batch_size, cm);
break;
}
#endif
case dpct::detail::get_type_combination_id(
library_data_t::real_int8, library_data_t::real_int8,
library_data_t::real_int32, library_data_t::real_int32): {
Expand Down Expand Up @@ -1873,7 +1874,6 @@ inline void gemm_batch(descriptor_ptr desc_ptr, oneapi::mkl::transpose a_trans,
batch_size, cm);
break;
}
#endif
case dpct::detail::get_type_combination_id(
library_data_t::real_half, library_data_t::real_half,
library_data_t::real_half, library_data_t::real_float): {
Expand Down Expand Up @@ -2022,13 +2022,18 @@ inline void gemm_batch(descriptor_ptr desc_ptr, oneapi::mkl::transpose a_trans,
beta, c, ldc, stride_c, batch_size, cm);
break;
}
#endif
case dpct::detail::get_type_combination_id(
library_data_t::real_int8, library_data_t::real_int8,
library_data_t::real_int32, library_data_t::real_int32): {
float alpha_float =
dpct::get_value(reinterpret_cast<const std::int32_t *>(alpha), q);
float beta_float =
dpct::get_value(reinterpret_cast<const std::int32_t *>(beta), q);
dpct::detail::gemm_batch_impl<std::int8_t, std::int8_t, std::int32_t,
std::int32_t>(
q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b,
beta, c, ldc, stride_c, batch_size, cm);
float>(
q, a_trans, b_trans, m, n, k, &alpha_float, a, lda, stride_a, b, ldb,
stride_b, &beta_float, c, ldc, stride_c, batch_size, cm);
break;
}
case dpct::detail::get_type_combination_id(
Expand All @@ -2047,7 +2052,6 @@ inline void gemm_batch(descriptor_ptr desc_ptr, oneapi::mkl::transpose a_trans,
beta, c, ldc, stride_c, batch_size, cm);
break;
}
#endif
case dpct::detail::get_type_combination_id(
library_data_t::real_half, library_data_t::real_half,
library_data_t::real_half, library_data_t::real_float): {
Expand Down

0 comments on commit 8e87927

Please sign in to comment.