diff --git a/clang/lib/DPCT/MapNames.cpp b/clang/lib/DPCT/MapNames.cpp index e2c857643643..78e02a5c8657 100644 --- a/clang/lib/DPCT/MapNames.cpp +++ b/clang/lib/DPCT/MapNames.cpp @@ -2072,6 +2072,12 @@ void MapNames::setExplicitNamespaceMap( {"CUBLASLT_EPILOGUE_RELU", getLibraryHelperNamespace() + "blas_gemm::experimental::epilogue_t::relu"}, + {"CUBLASLT_EPILOGUE_GELU", + getLibraryHelperNamespace() + + "blas_gemm::experimental::epilogue_t::gelu"}, + {"CUBLASLT_EPILOGUE_GELU_AUX", + getLibraryHelperNamespace() + + "blas_gemm::experimental::epilogue_t::gelu"}, {"CUBLASLT_MATRIX_TRANSFORM_DESC_SCALE_TYPE", getLibraryHelperNamespace() + "blas_gemm::experimental::transform_desc_t::attribute::scale_type"}, diff --git a/clang/runtime/dpct-rt/include/dpct/blas_gemm_utils.hpp b/clang/runtime/dpct-rt/include/dpct/blas_gemm_utils.hpp index 9d9d9718eacf..8b02c4aafb94 100644 --- a/clang/runtime/dpct-rt/include/dpct/blas_gemm_utils.hpp +++ b/clang/runtime/dpct-rt/include/dpct/blas_gemm_utils.hpp @@ -33,7 +33,7 @@ enum class pointer_mode_t { alpha_device_vector_beta_zero, alpha_device_vector_beta_host }; -enum class epilogue_t { nop = 1, relu }; +enum class epilogue_t { nop = 1, relu, gelu }; class descriptor; using descriptor_ptr = descriptor *; @@ -714,7 +714,7 @@ template struct absmax_impl { /// scale_type==float && a_type==int8 && b_type==int8 && c_type==int32; /// scale_type==float && a_type==float && b_type==float && c_type==float. /// Currently, this function only supports beta==0 or beta==1. -/// Currently, this function only supports the relu epilogue. +/// Currently, this function only supports the relu and gelu epilogue. /// NOTE: Non-col-major matrix will be converted to col-major matrix before. /// TODO: Impl row-major matmul without layout conversion. /// multiplication and converted back after multiplication. @@ -777,9 +777,10 @@ inline sycl::event matmul(descriptor_ptr handle, matmul_desc_ptr compute_desc, } if (compute_desc->_epilogue != epilogue_t::nop && - compute_desc->_epilogue != epilogue_t::relu) { + compute_desc->_epilogue != epilogue_t::relu && + compute_desc->epilogue != epilogue_t::gelu) { throw std::runtime_error("dpct::blas_gemm::experimental::matmul() only " - "supports relu epilogue currently."); + "supports relu and gelu epilogue currently."); } if (!(compute_desc->_scale_type == library_data_t::real_int32 && @@ -1024,7 +1025,11 @@ inline sycl::event matmul(descriptor_ptr handle, matmul_desc_ptr compute_desc, if (compute_desc->_epilogue != epilogue_t::nop) { ::dnnl::post_ops matmul_ops; - matmul_ops.append_eltwise(::dnnl::algorithm::eltwise_relu, 0.f, 0.f); + if (compute_desc->_epilogue == epilogue_t::relu) { + matmul_ops.append_eltwise(::dnnl::algorithm::eltwise_relu, 0.f, 0.f); + } else if (compute_desc->_epilogue == epilogue_t::gelu) { + matmul_ops.append_eltwise(::dnnl::algorithm::eltwise_gelu_erf, 0.f, 0.f); + } matmul_attr.set_post_ops(matmul_ops); } diff --git a/clang/test/dpct/cublaslt.cu b/clang/test/dpct/cublaslt.cu index a753badb08a9..b5e2cdb83887 100644 --- a/clang/test/dpct/cublaslt.cu +++ b/clang/test/dpct/cublaslt.cu @@ -232,9 +232,13 @@ void foo3() { // CHECK: dpct::blas_gemm::experimental::epilogue_t e; // CHECK-NEXT: e = dpct::blas_gemm::experimental::epilogue_t::nop; // CHECK-NEXT: e = dpct::blas_gemm::experimental::epilogue_t::relu; + // CHECK-NEXT: e = dpct::blas_gemm::experimental::epilogue_t::gelu; + // CHECK-NEXT: e = dpct::blas_gemm::experimental::epilogue_t::gelu; cublasLtEpilogue_t e; e = CUBLASLT_EPILOGUE_DEFAULT; e = CUBLASLT_EPILOGUE_RELU; + e = CUBLASLT_EPILOGUE_GELU; + e = CUBLASLT_EPILOGUE_GELU_AUX; } void foo4() {