Skip to content

Commit

Permalink
bias and gel_aux_bias epilogues
Browse files Browse the repository at this point in the history
  • Loading branch information
abhilash1910 committed Nov 6, 2024
1 parent f38770f commit 0a657ab
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 5 deletions.
6 changes: 6 additions & 0 deletions clang/lib/DPCT/MapNames.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2072,6 +2072,12 @@ void MapNames::setExplicitNamespaceMap(
{"CUBLASLT_EPILOGUE_RELU",
getLibraryHelperNamespace() +
"blas_gemm::experimental::epilogue_t::relu"},
{"CUBLASLT_EPILOGUE_BIAS",
getLibraryHelperNamespace() +
"blas_gemm::experimental::epilogue_t::bias"},
{"CUBLASLT_EPILOGUE_GELU_AUX_BIAS",
getLibraryHelperNamespace() +
"blas_gemm::experimental::epilogue_t::gelu_aux_bias"},
{"CUBLASLT_MATRIX_TRANSFORM_DESC_SCALE_TYPE",
getLibraryHelperNamespace() +
"blas_gemm::experimental::transform_desc_t::attribute::scale_type"},
Expand Down
19 changes: 14 additions & 5 deletions clang/runtime/dpct-rt/include/dpct/blas_gemm_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ enum class pointer_mode_t {
alpha_device_vector_beta_zero,
alpha_device_vector_beta_host
};
enum class epilogue_t { nop = 1, relu };
enum class epilogue_t { nop = 1, relu, bias, gelu_aux_bias };

class descriptor;
using descriptor_ptr = descriptor *;
Expand Down Expand Up @@ -714,7 +714,7 @@ template <typename T> struct absmax_impl {
/// scale_type==float && a_type==int8 && b_type==int8 && c_type==int32;
/// scale_type==float && a_type==float && b_type==float && c_type==float.
/// Currently, this function only supports beta==0 or beta==1.
/// Currently, this function only supports the relu epilogue.
/// Currently, this function only supports the relu, bias, and gelu_aux_bias epilogues.
/// NOTE: Non-col-major matrix will be converted to col-major matrix before
/// multiplication and converted back after multiplication.
/// TODO: Impl row-major matmul without layout conversion.
Expand Down Expand Up @@ -777,9 +777,11 @@ inline sycl::event matmul(descriptor_ptr handle, matmul_desc_ptr compute_desc,
}

if (compute_desc->_epilogue != epilogue_t::nop &&
compute_desc->_epilogue != epilogue_t::relu) {
compute_desc->_epilogue != epilogue_t::relu &&
compute_desc->_epilogue != epilogue_t::bias &&
compute_desc->_epilogue != epilogue_t::gelu_aux_bias) {
throw std::runtime_error("dpct::blas_gemm::experimental::matmul() only "
"supports relu epilogue currently.");
"supports relu, bias, gelu_aux_bias epilogue currently.");
}

if (!(compute_desc->_scale_type == library_data_t::real_int32 &&
Expand Down Expand Up @@ -1024,7 +1026,14 @@ inline sycl::event matmul(descriptor_ptr handle, matmul_desc_ptr compute_desc,

if (compute_desc->_epilogue != epilogue_t::nop) {
::dnnl::post_ops matmul_ops;
matmul_ops.append_eltwise(::dnnl::algorithm::eltwise_relu, 0.f, 0.f);
// Translate the requested epilogue into oneDNN post-ops. Post-ops are
// applied to the matmul result in the order they are appended.
if (compute_desc->_epilogue == epilogue_t::relu) {
  matmul_ops.append_eltwise(::dnnl::algorithm::eltwise_relu, 0.f, 0.f);
} else if (compute_desc->_epilogue == epilogue_t::bias) {
  matmul_ops.append_binary(::dnnl::algorithm::binary_add, bias_md);
} else if (compute_desc->_epilogue == epilogue_t::gelu_aux_bias) {
  // cuBLASLt defines GELU_AUX_BIAS as "apply bias, then GELU", so the bias
  // add must be appended before the GELU eltwise op.
  // NOTE(review): the auxiliary (pre-GELU) output of this epilogue is not
  // produced by these post-ops -- confirm it is handled elsewhere.
  // NOTE(review): cuBLASLt appears to document a tanh-approximated GELU;
  // confirm whether eltwise_gelu_tanh is the closer match than gelu_erf.
  matmul_ops.append_binary(::dnnl::algorithm::binary_add, bias_md);
  matmul_ops.append_eltwise(::dnnl::algorithm::eltwise_gelu_erf, 0.f, 0.f);
}
matmul_attr.set_post_ops(matmul_ops);
}

Expand Down
4 changes: 4 additions & 0 deletions clang/test/dpct/cublaslt.cu
Original file line number Diff line number Diff line change
Expand Up @@ -232,9 +232,13 @@ void foo3() {
// CHECK: dpct::blas_gemm::experimental::epilogue_t e;
// CHECK-NEXT: e = dpct::blas_gemm::experimental::epilogue_t::nop;
// CHECK-NEXT: e = dpct::blas_gemm::experimental::epilogue_t::relu;
// CHECK-NEXT: e = dpct::blas_gemm::experimental::epilogue_t::bias;
// CHECK-NEXT: e = dpct::blas_gemm::experimental::epilogue_t::gelu_aux_bias;
cublasLtEpilogue_t e;
e = CUBLASLT_EPILOGUE_DEFAULT;
e = CUBLASLT_EPILOGUE_RELU;
e = CUBLASLT_EPILOGUE_BIAS;
e = CUBLASLT_EPILOGUE_GELU_AUX_BIAS;
}

void foo4() {
Expand Down

0 comments on commit 0a657ab

Please sign in to comment.