From 40c6d79fb52f995f47507fedfeaae2ac05d9b35c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicol=C3=B2=20Scipione?= Date: Wed, 4 Dec 2024 02:29:20 +0100 Subject: [PATCH] SYCL : Move to compile time oneMKL interface backend selection for NVIDIA backend (#10584) * [SYCL] Move to Compile Time backend selection on oneMKL Interface for NVIDIA backend Move to compile time selection to backend to avoid latency at run time. Add it to all mkl gemm calls and only for NVIDIA backend. Signed-off-by: nscipione * Formatting * Address PR comments to increase readibility --------- Signed-off-by: nscipione --- ggml/src/ggml-sycl/CMakeLists.txt | 3 ++- ggml/src/ggml-sycl/dpct/helper.hpp | 43 +++++++++++++++++++++--------- ggml/src/ggml-sycl/ggml-sycl.cpp | 13 ++++++--- ggml/src/ggml-sycl/outprod.cpp | 16 +++++------ 4 files changed, 50 insertions(+), 25 deletions(-) diff --git a/ggml/src/ggml-sycl/CMakeLists.txt b/ggml/src/ggml-sycl/CMakeLists.txt index 83f223fd7b6fc..3579a311aac07 100644 --- a/ggml/src/ggml-sycl/CMakeLists.txt +++ b/ggml/src/ggml-sycl/CMakeLists.txt @@ -68,7 +68,8 @@ else() target_link_libraries(ggml-sycl PRIVATE sycl OpenCL mkl_core pthread m dl mkl_sycl_blas mkl_intel_ilp64 mkl_tbb_thread) elseif (GGML_SYCL_TARGET STREQUAL "NVIDIA") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsycl-targets=nvptx64-nvidia-cuda") - target_link_libraries(ggml-sycl PRIVATE sycl pthread m dl onemkl) + add_compile_definitions(GGML_SYCL_NVIDIA) + target_link_libraries(ggml-sycl PRIVATE sycl pthread m dl onemkl_blas_cublas) elseif (GGML_SYCL_TARGET STREQUAL "AMD") if (NOT GGML_SYCL_DEVICE_ARCH) message(ERROR "Can't enable SYCL hip backend, GGML_SYCL_DEVICE_ARCH has not been set.") diff --git a/ggml/src/ggml-sycl/dpct/helper.hpp b/ggml/src/ggml-sycl/dpct/helper.hpp index c2f28bb49579e..d1b5dd87c6922 100644 --- a/ggml/src/ggml-sycl/dpct/helper.hpp +++ b/ggml/src/ggml-sycl/dpct/helper.hpp @@ -1689,9 +1689,14 @@ namespace dpct auto data_a = get_memory(a); auto data_b = get_memory(b); auto data_c = get_memory(c); - oneapi::mkl::blas::column_major::gemm( - q, a_trans, b_trans, m, n, k, alpha_value, data_a, lda, - data_b, ldb, beta_value, data_c, ldc); +#ifdef GGML_SYCL_NVIDIA + oneapi::mkl::blas::column_major::gemm(oneapi::mkl::backend_selector{ q }, + a_trans, b_trans, m, n, k, alpha_value, data_a, lda, data_b, ldb, + beta_value, data_c, ldc); +#else + oneapi::mkl::blas::column_major::gemm(q, a_trans, b_trans, m, n, k, alpha_value, data_a, lda, data_b, ldb, + beta_value, data_c, ldc); +#endif } template @@ -1754,14 +1759,22 @@ namespace dpct matrix_info->ld_info[2] = ldc; matrix_info->groupsize_info = batch_size; +#ifdef GGML_SYCL_NVIDIA + sycl::event e = oneapi::mkl::blas::column_major::gemm_batch( + oneapi::mkl::backend_selector{ q }, matrix_info->transpose_info, + matrix_info->transpose_info + 1, matrix_info->size_info, matrix_info->size_info + 1, + matrix_info->size_info + 2, matrix_info->value_info, reinterpret_cast(a), + matrix_info->ld_info, reinterpret_cast(b), matrix_info->ld_info + 1, + matrix_info->value_info + 1, reinterpret_cast(c), matrix_info->ld_info + 2, 1, + &(matrix_info->groupsize_info)); +#else sycl::event e = oneapi::mkl::blas::column_major::gemm_batch( - q, matrix_info->transpose_info, matrix_info->transpose_info + 1, - matrix_info->size_info, matrix_info->size_info + 1, - matrix_info->size_info + 2, matrix_info->value_info, - reinterpret_cast(a), matrix_info->ld_info, - reinterpret_cast(b), matrix_info->ld_info + 1, - matrix_info->value_info + 1, reinterpret_cast(c), + q, matrix_info->transpose_info, matrix_info->transpose_info + 1, matrix_info->size_info, + matrix_info->size_info + 1, matrix_info->size_info + 2, matrix_info->value_info, + reinterpret_cast(a), matrix_info->ld_info, reinterpret_cast(b), + matrix_info->ld_info + 1, matrix_info->value_info + 1, reinterpret_cast(c), matrix_info->ld_info + 2, 1, &(matrix_info->groupsize_info)); +#endif q.submit([&](sycl::handler &cgh) { @@ -1783,10 +1796,16 @@ namespace dpct auto data_a = get_memory(a); auto data_b = get_memory(b); auto data_c = get_memory(c); +#ifdef GGML_SYCL_NVIDIA oneapi::mkl::blas::column_major::gemm_batch( - q, a_trans, b_trans, m, n, k, alpha_value, data_a, lda, - stride_a, data_b, ldb, stride_b, beta_value, - data_c, ldc, stride_c, batch_size); + oneapi::mkl::backend_selector{ q }, a_trans, b_trans, m, n, k, + alpha_value, data_a, lda, stride_a, data_b, ldb, stride_b, beta_value, data_c, ldc, stride_c, + batch_size); +#else + oneapi::mkl::blas::column_major::gemm_batch(q, a_trans, b_trans, m, n, k, alpha_value, data_a, lda, + stride_a, data_b, ldb, stride_b, beta_value, data_c, ldc, + stride_c, batch_size); +#endif } } // namespace detail diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index 1310981e52f4c..135efb521a980 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -2573,12 +2573,17 @@ inline void ggml_sycl_op_mul_mat_sycl( const float alpha = 1.0f; const float beta = 0.0f; #if !GGML_SYCL_DNNL +# ifdef GGML_SYCL_NVIDIA SYCL_CHECK(CHECK_TRY_ERROR(oneapi::mkl::blas::column_major::gemm( - *stream, oneapi::mkl::transpose::trans, - oneapi::mkl::transpose::nontrans, row_diff, src1_ncols, ne10, - dpct::get_value(&alpha, *stream), src0_ddf_i, ne00, - src1_ddf1_i, ne10, dpct::get_value(&beta, *stream), + oneapi::mkl::backend_selector{ *stream }, oneapi::mkl::transpose::trans, + oneapi::mkl::transpose::nontrans, row_diff, src1_ncols, ne10, dpct::get_value(&alpha, *stream), src0_ddf_i, + ne00, src1_ddf1_i, ne10, dpct::get_value(&beta, *stream), dst_dd_i, ldc))); +# else + SYCL_CHECK(CHECK_TRY_ERROR(oneapi::mkl::blas::column_major::gemm( + *stream, oneapi::mkl::transpose::trans, oneapi::mkl::transpose::nontrans, row_diff, src1_ncols, ne10, + dpct::get_value(&alpha, *stream), src0_ddf_i, ne00, src1_ddf1_i, ne10, dpct::get_value(&beta, *stream), dst_dd_i, ldc))); +# endif #else auto dnnl_stream = ctx.stream_dnnl(stream); DnnlGemmWrapper::row_gemm(dnnl_stream, false, true, src1_ncols, row_diff, ne10, src1_ddf1_i, DnnlGemmWrapper::to_dt(), diff --git a/ggml/src/ggml-sycl/outprod.cpp b/ggml/src/ggml-sycl/outprod.cpp index e61cdc2ca5d53..ef9af0b7633ab 100644 --- a/ggml/src/ggml-sycl/outprod.cpp +++ b/ggml/src/ggml-sycl/outprod.cpp @@ -40,14 +40,14 @@ void ggml_sycl_op_out_prod(ggml_backend_sycl_context& ctx, const ggml_tensor* sr try { // Perform matrix multiplication using oneMKL GEMM - oneapi::mkl::blas::column_major::gemm(*stream, - oneapi::mkl::transpose::nontrans, src1_op, - ne0, ne1, ne01, - alpha, - src0_d, ne00, - src1_d, ldb, - beta, - dst_d, ne0); +#ifdef GGML_SYCL_NVIDIA + oneapi::mkl::blas::column_major::gemm(oneapi::mkl::backend_selector{ *stream }, + oneapi::mkl::transpose::nontrans, src1_op, ne0, ne1, ne01, alpha, src0_d, + ne00, src1_d, ldb, beta, dst_d, ne0); +#else + oneapi::mkl::blas::column_major::gemm(*stream, oneapi::mkl::transpose::nontrans, src1_op, ne0, ne1, ne01, alpha, + src0_d, ne00, src1_d, ldb, beta, dst_d, ne0); +#endif } catch (sycl::exception const& exc) { std::cerr << exc.what() << std::endl;