Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SYCLomatic] Enable migration of CUBLASLT: EPILOGUE_BIAS, EPILOGUE_GELU, EPILOGUE_GELU_AUX, EPILOGUE_GELU_AUX_BIAS #2460

Open
wants to merge 13 commits into
base: SYCLomatic
Choose a base branch
from
12 changes: 12 additions & 0 deletions clang/lib/DPCT/RulesMathLib/MapNamesBlas.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1624,6 +1624,18 @@ void MapNamesBlas::setExplicitNamespaceMap(
{"CUBLASLT_EPILOGUE_RELU",
MapNames::getLibraryHelperNamespace() +
"blas_gemm::experimental::epilogue_t::relu"},
{"CUBLASLT_EPILOGUE_GELU",
MapNames::getLibraryHelperNamespace() +
"blas_gemm::experimental::epilogue_t::gelu"},
{"CUBLASLT_EPILOGUE_GELU_AUX",
MapNames::getLibraryHelperNamespace() +
"blas_gemm::experimental::epilogue_t::gelu_aux"},
{"CUBLASLT_EPILOGUE_BIAS",
MapNames::getLibraryHelperNamespace() +
"blas_gemm::experimental::epilogue_t::bias"},
{"CUBLASLT_EPILOGUE_GELU_AUX_BIAS",
MapNames::getLibraryHelperNamespace() +
"blas_gemm::experimental::epilogue_t::gelu_aux_bias"},
{"CUBLASLT_MATRIX_TRANSFORM_DESC_SCALE_TYPE",
MapNames::getLibraryHelperNamespace() +
"blas_gemm::experimental::transform_desc_t::attribute::scale_type"},
Expand Down
32 changes: 27 additions & 5 deletions clang/runtime/dpct-rt/include/dpct/blas_gemm_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

#include "compat_service.hpp"
#include "dnnl_utils.hpp"
#include "blas_utils.hpp"

namespace dpct {
namespace blas_gemm {
Expand All @@ -33,7 +34,7 @@ enum class pointer_mode_t {
alpha_device_vector_beta_zero,
alpha_device_vector_beta_host
};
enum class epilogue_t { nop = 1, relu };
/// Selects the epilogue that matmul() applies; the value is stored in the
/// matmul descriptor and compared against these enumerators.
/// Numbering starts at 1 and is consecutive; values are spelled out so the
/// mapping stays obvious when new enumerators are appended.
enum class epilogue_t {
  nop = 1,          // no epilogue
  relu = 2,         // ReLU
  bias = 3,         // bias
  gelu = 4,         // GELU
  gelu_aux = 5,     // GELU with auxiliary output
  gelu_aux_bias = 6 // GELU with auxiliary output and bias
};

class descriptor;
using descriptor_ptr = descriptor *;
Expand Down Expand Up @@ -780,9 +781,13 @@ inline sycl::event matmul(descriptor_ptr handle, matmul_desc_ptr compute_desc,
}

if (compute_desc->_epilogue != epilogue_t::nop &&
compute_desc->_epilogue != epilogue_t::relu) {
compute_desc->_epilogue != epilogue_t::relu &&
compute_desc->_epilogue != epilogue_t::bias &&
compute_desc->_epilogue != epilogue_t::gelu &&
compute_desc->_epilogue != epilogue_t::gelu_aux &&
compute_desc->_epilogue != epilogue_t::gelu_aux_bias) {
throw std::runtime_error("dpct::blas_gemm::experimental::matmul() only "
"supports relu epilogue currently.");
"supports relu, gelu, gelu_aux, gelu with bias epilogue currently.");
}

if (!(compute_desc->_scale_type == library_data_t::real_int32 &&
Expand Down Expand Up @@ -1024,10 +1029,27 @@ inline sycl::event matmul(descriptor_ptr handle, matmul_desc_ptr compute_desc,
matmul_args.insert(
{DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, *scales_alpha});
}

sycl::queue &queue = ::dpct::cs::get_default_queue()
if (compute_desc->_epilogue != epilogue_t::nop) {
::dnnl::post_ops matmul_ops;
matmul_ops.append_eltwise(::dnnl::algorithm::eltwise_relu, 0.f, 0.f);
if (compute_desc->_epilogue == epilogue_t::relu) {
matmul_ops.append_eltwise(::dnnl::algorithm::eltwise_relu, 0.f, 0.f);
} else if (compute_desc->_epilogue == epilogue_t::gelu) {
matmul_ops.append_eltwise(::dnnl::algorithm::eltwise_gelu_erf, 0.f, 0.f);
} else if (compute_desc->_epilogue == epilogue_t::gelu_aux) {
matmul_ops.append_eltwise(::dnnl::algorithm::eltwise_gelu_erf, 0.f, 0.f);
dpct::blas::matrix_mem_copy(compute_desc->_epilogue_aux_pointer, new_c,
compute_desc->_epilogue_aux_ld, new_ldc, m, n,
sizeof(size_t) , dpct::device_to_device, queue);
} else if (compute_desc->_epilogue == epilogue_t::bias) {
matmul_ops.append_binary(::dnnl::algorithm::binary_add, bias_md);
} else if (compute_desc->_epilogue == epilogue_t::gelu_aux_bias) {
matmul_ops.append_eltwise(::dnnl::algorithm::eltwise_gelu_erf, 0.f, 0.f);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. CUBLASLT_EPILOGUE_GELU_AUX_BIAS needs to output the bias result to CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER, but your implementation here does not.
  2. Please write an E2E test to validate your implementation. I tried something similar before but ran into either a runtime error or an incorrect result.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK. Please help review.

matmul_ops.append_binary(::dnnl::algorithm::binary_add, bias_md);
dpct::blas::matrix_mem_copy(compute_desc->_epilogue_aux_pointer, new_c,
compute_desc->_epilogue_aux_ld, new_ldc, m, n,
sizeof(size_t) , dpct::device_to_device, queue);
}
matmul_attr.set_post_ops(matmul_ops);
}

Expand Down
8 changes: 8 additions & 0 deletions clang/test/dpct/cublaslt.cu
Original file line number Diff line number Diff line change
Expand Up @@ -234,9 +234,17 @@ void foo3() {
// CHECK: dpct::blas_gemm::experimental::epilogue_t e;
// CHECK-NEXT: e = dpct::blas_gemm::experimental::epilogue_t::nop;
// CHECK-NEXT: e = dpct::blas_gemm::experimental::epilogue_t::relu;
// CHECK-NEXT: e = dpct::blas_gemm::experimental::epilogue_t::gelu;
// CHECK-NEXT: e = dpct::blas_gemm::experimental::epilogue_t::bias;
// CHECK-NEXT: e = dpct::blas_gemm::experimental::epilogue_t::gelu_aux_bias;
// CHECK-NEXT: e = dpct::blas_gemm::experimental::epilogue_t::gelu_aux;
cublasLtEpilogue_t e;
e = CUBLASLT_EPILOGUE_DEFAULT;
e = CUBLASLT_EPILOGUE_RELU;
e = CUBLASLT_EPILOGUE_GELU;
e = CUBLASLT_EPILOGUE_BIAS;
e = CUBLASLT_EPILOGUE_GELU_AUX_BIAS;
e = CUBLASLT_EPILOGUE_GELU_AUX;
}

void foo4() {
Expand Down
Loading