Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SYCLomatic] Enable migration for CUBLASLT_EPILOGUE_DGELU & EPILOGUE_BGRADB #2449

Open
wants to merge 1 commit into
base: SYCLomatic
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions clang/lib/DPCT/RulesMathLib/MapNamesBlas.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1624,6 +1624,12 @@ void MapNamesBlas::setExplicitNamespaceMap(
{"CUBLASLT_EPILOGUE_RELU",
MapNames::getLibraryHelperNamespace() +
"blas_gemm::experimental::epilogue_t::relu"},
{"CUBLASLT_EPILOGUE_DGELU",
MapNames::getLibraryHelperNamespace() +
"blas_gemm::experimental::epilogue_t::dgelu_epilogue"},
{"CUBLASLT_EPILOGUE_bgradb",
MapNames::getLibraryHelperNamespace() +
"blas_gemm::experimental::epilogue_t::bgradb_epilogue"},
{"CUBLASLT_MATRIX_TRANSFORM_DESC_SCALE_TYPE",
MapNames::getLibraryHelperNamespace() +
"blas_gemm::experimental::transform_desc_t::attribute::scale_type"},
Expand Down
32 changes: 32 additions & 0 deletions clang/runtime/dpct-rt/include/dpct/blas_gemm_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,8 @@ class matmul_desc_t {
b_scale_pointer,
d_scale_pointer,
absmax_d_pointer,
dgelu_epilogue,
bgradb_epilogue,
unsupport
};

Expand Down Expand Up @@ -188,6 +190,8 @@ class matmul_desc_t {
CASE(epilogue_aux_ld)
abhilash1910 marked this conversation as resolved.
Show resolved Hide resolved
CASE(epilogue_aux_pointer)
CASE(epilogue_aux_data_type)
CASE(dgelu_epilogue)
CASE(bgradb_epilogue)
default:
break;
}
Expand All @@ -210,6 +214,9 @@ class matmul_desc_t {
void *_absmax_d_pointer = nullptr;
void *_bias_pointer = nullptr;
void *_epilogue_aux_pointer = nullptr;
abhilash1910 marked this conversation as resolved.
Show resolved Hide resolved
auto *_dgelu_epilogue = detail::sync_gelu_backward<::dnnl::eltwise_backward>(0.f, 0.f, new memory_desc_ext(), new memory_desc_ext());
auto *_bgradb_epilogue = detail::sync_gelu_backward<::dnnl::reduction>(0.f, 0.f, new memory_desc_ext(), new memory_desc_ext());


friend sycl::event matmul(descriptor_ptr handle, matmul_desc_ptr computeDesc,
const void *alpha, const void *a,
Expand All @@ -222,6 +229,31 @@ class matmul_desc_t {

namespace detail {
/// Sacling each row of matrix D with the corresponding element of vector alpha.

template <typename primitive_type, typename... args_type>
inline
typename primitive_type::primitive_desc sync_gelu_backward(
float alpha, float beta, const memory_desc_ext &src_desc,
const memory_desc_ext &dest_desc) {

auto alg = ::dnnl::algorithm::eltwise_gelu_erf;
return create_primitive_desc<primitive_type>(
::dnnl::prop_kind::backward, alg, src_desc.get_desc(),
dst_desc.get_desc(), alpha, beta);
}

template <typename primitive_type, typename... args_type>
inline
typename primitive_type::primitive_desc bias_backward(
float alpha, float beta, const memory_desc_ext &src_desc,
const memory_desc_ext &dest_desc) {

auto alg = ::dnnl::algorithm::reduction_sum;
return create_primitive_desc<primitive_type>(
::dnnl::prop_kind::backward, alg, src_desc.get_desc(),
dst_desc.get_desc(), alpha, beta);
}

template <class T, class Talpha>
sycl::event scale_d_with_vector_alpha_impl(::dpct::cs::queue_ptr q_ptr,
int rows, int cols, T *d,
Expand Down
4 changes: 4 additions & 0 deletions clang/test/dpct/cublaslt.cu
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,8 @@ void foo3() {
// CHECK-NEXT: d = dpct::blas_gemm::experimental::matmul_desc_t::attribute::epilogue_aux_ld;
// CHECK-NEXT: d = dpct::blas_gemm::experimental::matmul_desc_t::attribute::epilogue_aux_pointer;
// CHECK-NEXT: d = dpct::blas_gemm::experimental::matmul_desc_t::attribute::epilogue_aux_data_type;
// CHECK-NEXT: d = dpct::blas_gemm::experimental::matmul_desc_t::attribute::dgelu_epilogue;
// CHECK-NEXT: d = dpct::blas_gemm::experimental::matmul_desc_t::attribute::bgradb_epilogue;
// CHECK-NEXT: d = dpct::blas_gemm::experimental::matmul_desc_t::attribute::unsupport;
// CHECK-NEXT: d = dpct::blas_gemm::experimental::matmul_desc_t::attribute::unsupport;
// CHECK-NEXT: d = dpct::blas_gemm::experimental::matmul_desc_t::attribute::a_scale_pointer;
Expand All @@ -220,6 +222,8 @@ void foo3() {
d = CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD;
d = CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER;
d = CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_DATA_TYPE;
d = CUBLASLT_EPILOGUE_DGELU;
d = CUBLASLT_EPILOGUE_BGRADB;
d = CUBLASLT_MATMUL_DESC_SM_COUNT_TARGET;
d = CUBLASLT_MATMUL_DESC_FAST_ACCUM;
d = CUBLASLT_MATMUL_DESC_A_SCALE_POINTER;
Expand Down
Loading