From b3f7108002f91a2f11e4de02c0c8d28f79086e2c Mon Sep 17 00:00:00 2001 From: Abhilash Majumder Date: Wed, 30 Oct 2024 04:22:55 -0700 Subject: [PATCH 1/2] add gelu epilogue --- clang/lib/DPCT/MapNames.cpp | 4 ++++ .../dpct-rt/include/dpct/blas_gemm_utils.hpp | 15 +++++++++++---- clang/test/dpct/cublaslt.cu | 4 ++++ 3 files changed, 19 insertions(+), 4 deletions(-) diff --git a/clang/lib/DPCT/MapNames.cpp b/clang/lib/DPCT/MapNames.cpp index c4f89eba62a0..8a17df672847 100644 --- a/clang/lib/DPCT/MapNames.cpp +++ b/clang/lib/DPCT/MapNames.cpp @@ -1985,6 +1985,10 @@ void MapNames::setExplicitNamespaceMap( getLibraryHelperNamespace() + "blas_gemm::experimental::epilogue_t::nop"}, {"CUBLASLT_EPILOGUE_RELU", getLibraryHelperNamespace() + "blas_gemm::experimental::epilogue_t::relu"}, + {"CUBLASLT_EPILOGUE_GELU", + getLibraryHelperNamespace() + "blas_gemm::experimental::epilogue_t::gelu"}, + {"CUBLASLT_EPILOGUE_GELU_AUX", + getLibraryHelperNamespace() + "blas_gemm::experimental::epilogue_t::gelu"}, {"CUBLASLT_MATRIX_TRANSFORM_DESC_SCALE_TYPE", getLibraryHelperNamespace() + "blas_gemm::experimental::transform_desc_t::attribute::scale_type"}, diff --git a/clang/runtime/dpct-rt/include/dpct/blas_gemm_utils.hpp b/clang/runtime/dpct-rt/include/dpct/blas_gemm_utils.hpp index 22cfda9dae8a..996729c3f98c 100644 --- a/clang/runtime/dpct-rt/include/dpct/blas_gemm_utils.hpp +++ b/clang/runtime/dpct-rt/include/dpct/blas_gemm_utils.hpp @@ -33,7 +33,7 @@ enum class pointer_mode_t { alpha_device_vector_beta_zero, alpha_device_vector_beta_host }; -enum class epilogue_t { nop = 1, relu }; +enum class epilogue_t { nop = 1, relu, gelu }; class descriptor; using descriptor_ptr = descriptor *; @@ -690,7 +690,7 @@ template struct absmax_impl { /// scale_type==float && a_type==int8 && b_type==int8 && c_type==int32; /// scale_type==float && a_type==float && b_type==float && c_type==float. /// Currently, this function only supports beta==0 or beta==1. 
-/// Currently, this function only supports the relu epilogue. +/// Currently, this function only supports the relu and gelu epilogue. /// NOTE: Non-col-major matrix will be converted to col-major matrix before. /// TODO: Impl row-major matmul without layout conversion. /// multiplication and converted back after multiplication. @@ -753,9 +753,10 @@ inline sycl::event matmul(descriptor_ptr handle, matmul_desc_ptr compute_desc, } if (compute_desc->_epilogue != epilogue_t::nop && - compute_desc->_epilogue != epilogue_t::relu) { + compute_desc->_epilogue != epilogue_t::relu && + compute_desc->_epilogue != epilogue_t::gelu) { throw std::runtime_error("dpct::blas_gemm::experimental::matmul() only " - "supports relu epilogue currently."); + "supports relu and gelu epilogue currently."); } if (!(compute_desc->_scale_type == library_data_t::real_int32 && @@ -999,8 +1000,14 @@ inline sycl::event matmul(descriptor_ptr handle, matmul_desc_ptr compute_desc, } if (compute_desc->_epilogue != epilogue_t::nop) { + ::dnnl::post_ops matmul_ops; + if(compute_desc->_epilogue == epilogue_t::relu){ matmul_ops.append_eltwise(::dnnl::algorithm::eltwise_relu, 0.f, 0.f); + } + else if(compute_desc->_epilogue == epilogue_t::gelu){ + matmul_ops.append_eltwise(::dnnl::algorithm::eltwise_gelu_erf, 0.f, 0.f); + } matmul_attr.set_post_ops(matmul_ops); } diff --git a/clang/test/dpct/cublaslt.cu b/clang/test/dpct/cublaslt.cu index 9efb693b56b6..333956f45578 100644 --- a/clang/test/dpct/cublaslt.cu +++ b/clang/test/dpct/cublaslt.cu @@ -220,9 +220,13 @@ void foo3() { // CHECK: dpct::blas_gemm::experimental::epilogue_t e; // CHECK-NEXT: e = dpct::blas_gemm::experimental::epilogue_t::nop; // CHECK-NEXT: e = dpct::blas_gemm::experimental::epilogue_t::relu; + // CHECK-NEXT: e = dpct::blas_gemm::experimental::epilogue_t::gelu; + // CHECK-NEXT: e = dpct::blas_gemm::experimental::epilogue_t::gelu; cublasLtEpilogue_t e; e = CUBLASLT_EPILOGUE_DEFAULT; e = CUBLASLT_EPILOGUE_RELU; + e = CUBLASLT_EPILOGUE_GELU; + 
e = CUBLASLT_EPILOGUE_GELU_AUX; } void foo4() { From 255bce7cca42a61bd57bec34bdc0e02b5362438d Mon Sep 17 00:00:00 2001 From: Abhilash Majumder <30946547+abhilash1910@users.noreply.github.com> Date: Fri, 1 Nov 2024 22:13:21 +0530 Subject: [PATCH 2/2] Update clang/runtime/dpct-rt/include/dpct/blas_gemm_utils.hpp Co-authored-by: Jiang, Zhiwei --- clang/runtime/dpct-rt/include/dpct/blas_gemm_utils.hpp | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/clang/runtime/dpct-rt/include/dpct/blas_gemm_utils.hpp b/clang/runtime/dpct-rt/include/dpct/blas_gemm_utils.hpp index 996729c3f98c..421a29168744 100644 --- a/clang/runtime/dpct-rt/include/dpct/blas_gemm_utils.hpp +++ b/clang/runtime/dpct-rt/include/dpct/blas_gemm_utils.hpp @@ -1000,13 +1000,11 @@ inline sycl::event matmul(descriptor_ptr handle, matmul_desc_ptr compute_desc, } if (compute_desc->_epilogue != epilogue_t::nop) { - ::dnnl::post_ops matmul_ops; - if(compute_desc->_epilogue == epilogue_t::relu){ - matmul_ops.append_eltwise(::dnnl::algorithm::eltwise_relu, 0.f, 0.f); - } - else if(compute_desc->_epilogue == epilogue_t::gelu){ - matmul_ops.append_eltwise(::dnnl::algorithm::eltwise_gelu_erf, 0.f, 0.f); + if (compute_desc->_epilogue == epilogue_t::relu) { + matmul_ops.append_eltwise(::dnnl::algorithm::eltwise_relu, 0.f, 0.f); + } else if (compute_desc->_epilogue == epilogue_t::gelu) { + matmul_ops.append_eltwise(::dnnl::algorithm::eltwise_gelu_erf, 0.f, 0.f); } matmul_attr.set_post_ops(matmul_ops); }