From 8b796d120ae581459ddd2b6f9304dae2e131e987 Mon Sep 17 00:00:00 2001 From: JackAKirk Date: Fri, 4 Oct 2024 09:00:31 -0400 Subject: [PATCH] [rocsolver] Use enqueue_native_command ext This makes use of the enqueue_native_command dpc++ extension if it is available. This improves performance and integrates correctly with the dpc++ scheduler. Signed-off-by: JackAKirk --- .../backends/rocsolver/rocsolver_batch.cpp | 4 +- .../backends/rocsolver/rocsolver_helper.hpp | 11 ++ .../backends/rocsolver/rocsolver_lapack.cpp | 100 +++++++++--------- .../backends/rocsolver/rocsolver_task.hpp | 4 + 4 files changed, 67 insertions(+), 52 deletions(-) diff --git a/src/lapack/backends/rocsolver/rocsolver_batch.cpp b/src/lapack/backends/rocsolver/rocsolver_batch.cpp index 0b4b877e8..4363be306 100644 --- a/src/lapack/backends/rocsolver/rocsolver_batch.cpp +++ b/src/lapack/backends/rocsolver/rocsolver_batch.cpp @@ -527,7 +527,7 @@ inline sycl::event potrf_batch(const char *func_name, Func func, sycl::queue &qu for (int64_t i = 0; i < group_count; i++) { auto **a_ = reinterpret_cast(a_dev); auto *info_ = reinterpret_cast(info); - ROCSOLVER_ERROR_FUNC_T(func_name, func, err, handle, get_rocblas_fill_mode(uplo[i]), + rocsolver_native_named_func(func_name, func, err, handle, get_rocblas_fill_mode(uplo[i]), (int)n[i], a_ + offset, (int)lda[i], info_ + offset, (int)group_sizes[i]); offset += group_sizes[i]; @@ -627,7 +627,7 @@ inline sycl::event potrs_batch(const char *func_name, Func func, sycl::queue &qu for (int64_t i = 0; i < group_count; i++) { auto **a_ = reinterpret_cast(a_dev); auto **b_ = reinterpret_cast(b_dev); - ROCSOLVER_ERROR_FUNC_T(func_name, func, err, handle, get_rocblas_fill_mode(uplo[i]), + rocsolver_native_named_func(func_name, func, err, handle, get_rocblas_fill_mode(uplo[i]), (int)n[i], (int)nrhs[i], a_ + offset, (int)lda[i], b_ + offset, (int)ldb[i], (int)group_sizes[i]); offset += group_sizes[i]; diff --git a/src/lapack/backends/rocsolver/rocsolver_helper.hpp b/src/lapack/backends/rocsolver/rocsolver_helper.hpp index dade1df64..34064fb09 100644 --- a/src/lapack/backends/rocsolver/rocsolver_helper.hpp +++ b/src/lapack/backends/rocsolver/rocsolver_helper.hpp @@ -166,6 +166,17 @@ class hip_error : virtual public std::runtime_error { hipError_t hip_err; \ HIP_ERROR_FUNC(hipStreamSynchronize, hip_err, currentStreamId); +template +inline void rocsolver_native_named_func(const char *func_name, Func func, + rocsolver_status err, + rocsolver_handle handle, Types... args){ +#ifdef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND + ROCSOLVER_ERROR_FUNC_T(func_name, func, err, handle, args...) +#else + ROCSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, args...) +#endif +}; + inline rocblas_eform get_rocsolver_itype(std::int64_t itype) { switch (itype) { case 1: return rocblas_eform_ax; diff --git a/src/lapack/backends/rocsolver/rocsolver_lapack.cpp b/src/lapack/backends/rocsolver/rocsolver_lapack.cpp index e5e634ad0..d3e1b9e26 100644 --- a/src/lapack/backends/rocsolver/rocsolver_lapack.cpp +++ b/src/lapack/backends/rocsolver/rocsolver_lapack.cpp @@ -54,7 +54,7 @@ inline void gebrd(const char *func_name, Func func, sycl::queue &queue, std::int auto tauq_ = sc.get_mem(tauq_acc); auto taup_ = sc.get_mem(taup_acc); rocblas_status err; - ROCSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, m, n, a_, lda, d_, e_, tauq_, + rocsolver_native_named_func(func_name, func, err, handle, m, n, a_, lda, d_, e_, tauq_, taup_); }); }); @@ -112,7 +112,7 @@ inline void geqrf(const char *func_name, Func func, sycl::queue &queue, std::int auto a_ = sc.get_mem(a_acc); auto tau_ = sc.get_mem(tau_acc); rocblas_status err; - ROCSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, m, n, a_, lda, tau_); + rocsolver_native_named_func(func_name, func, err, handle, m, n, a_, lda, tau_); }); }); } @@ -156,7 +156,7 @@ void getrf(const char *func_name, Func func, sycl::queue &queue, std::int64_t m, auto ipiv32_ = sc.get_mem(ipiv32_acc); auto devInfo_ = sc.get_mem(devInfo_acc); rocblas_status err; - ROCSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, m, n, a_, lda, ipiv32_, + rocsolver_native_named_func(func_name, func, err, handle, m, n, a_, lda, ipiv32_, devInfo_); }); }); @@ -242,7 +242,7 @@ inline void getrs(const char *func_name, Func func, sycl::queue &queue, auto ipiv_ = sc.get_mem(ipiv_acc); auto b_ = sc.get_mem(b_acc); rocblas_status err; - ROCSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_rocblas_operation(trans), + rocsolver_native_named_func(func_name, func, err, handle, get_rocblas_operation(trans), n, nrhs, a_, lda, ipiv_, b_, ldb); }); }); @@ -290,7 +290,7 @@ inline void gesvd(const char *func_name, Func func, sycl::queue &queue, oneapi:: auto devInfo_ = sc.get_mem(devInfo_acc); auto scratch_ = sc.get_mem(scratch_acc); rocblas_status err; - ROCSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_rocsolver_jobsvd(jobu), + rocsolver_native_named_func(func_name, func, err, handle, get_rocsolver_jobsvd(jobu), get_rocsolver_jobsvd(jobvt), m, n, a_, lda, s_, u_, ldu, vt_, ldvt, scratch_, rocblas_workmode::rocblas_outofplace, devInfo_); @@ -337,7 +337,7 @@ inline void heevd(const char *func_name, Func func, sycl::queue &queue, oneapi:: auto devInfo_ = sc.get_mem(devInfo_acc); auto scratch_ = sc.get_mem(scratch_acc); rocblas_status err; - ROCSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_rocsolver_job(jobz), + rocsolver_native_named_func(func_name, func, err, handle, get_rocsolver_job(jobz), get_rocblas_fill_mode(uplo), n, a_, lda, w_, scratch_, devInfo_); }); @@ -382,7 +382,7 @@ inline void hegvd(const char *func_name, Func func, sycl::queue &queue, std::int auto devInfo_ = sc.get_mem(devInfo_acc); auto scratch_ = sc.get_mem(scratch_acc); rocblas_status err; - ROCSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_rocsolver_itype(itype), + rocsolver_native_named_func(func_name, func, err, handle, get_rocsolver_itype(itype), get_rocsolver_job(jobz), get_rocblas_fill_mode(uplo), n, a_, lda, b_, ldb, w_, scratch_, devInfo_); }); @@ -424,7 +424,7 @@ inline void hetrd(const char *func_name, Func func, sycl::queue &queue, oneapi:: auto e_ = sc.get_mem(e_acc); auto tau_ = sc.get_mem(tau_acc); rocblas_status err; - ROCSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_rocblas_fill_mode(uplo), + rocsolver_native_named_func(func_name, func, err, handle, get_rocblas_fill_mode(uplo), n, a_, lda, d_, e_, tau_); }); }); @@ -471,7 +471,7 @@ inline void orgbr(const char *func_name, Func func, sycl::queue &queue, oneapi:: auto a_ = sc.get_mem(a_acc); auto tau_ = sc.get_mem(tau_acc); rocblas_status err; - ROCSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_rocblas_generate(vec), m, + rocsolver_native_named_func(func_name, func, err, handle, get_rocblas_generate(vec), m, n, k, a_, lda, tau_); }); }); @@ -504,7 +504,7 @@ inline void orgqr(const char *func_name, Func func, sycl::queue &queue, std::int auto a_ = sc.get_mem(a_acc); auto tau_ = sc.get_mem(tau_acc); rocblas_status err; - ROCSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, m, n, k, a_, lda, tau_); + rocsolver_native_named_func(func_name, func, err, handle, m, n, k, a_, lda, tau_); }); }); } @@ -536,7 +536,7 @@ inline void orgtr(const char *func_name, Func func, sycl::queue &queue, oneapi:: auto a_ = sc.get_mem(a_acc); auto tau_ = sc.get_mem(tau_acc); rocblas_status err; - ROCSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_rocblas_fill_mode(uplo), + rocsolver_native_named_func(func_name, func, err, handle, get_rocblas_fill_mode(uplo), n, a_, lda, tau_); }); }); @@ -573,7 +573,7 @@ inline void ormtr(const char *func_name, Func func, sycl::queue &queue, oneapi:: auto tau_ = sc.get_mem(tau_acc); auto c_ = sc.get_mem(c_acc); rocblas_status err; - ROCSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_rocblas_side_mode(side), + rocsolver_native_named_func(func_name, func, err, handle, get_rocblas_side_mode(side), get_rocblas_fill_mode(uplo), get_rocblas_operation(trans), m, n, a_, lda, tau_, c_, ldc); }); @@ -625,7 +625,7 @@ inline void ormqr(const char *func_name, Func func, sycl::queue &queue, oneapi:: auto tau_ = sc.get_mem(tau_acc); auto c_ = sc.get_mem(c_acc); rocblas_status err; - ROCSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_rocblas_side_mode(side), + rocsolver_native_named_func(func_name, func, err, handle, get_rocblas_side_mode(side), get_rocblas_operation(trans), m, n, k, a_, lda, tau_, c_, ldc); }); @@ -661,7 +661,7 @@ inline void potrf(const char *func_name, Func func, sycl::queue &queue, oneapi:: auto a_ = sc.get_mem(a_acc); auto devInfo_ = sc.get_mem(devInfo_acc); rocblas_status err; - ROCSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_rocblas_fill_mode(uplo), + rocsolver_native_named_func(func_name, func, err, handle, get_rocblas_fill_mode(uplo), n, a_, lda, devInfo_); }); }); @@ -697,7 +697,7 @@ inline void potri(const char *func_name, Func func, sycl::queue &queue, oneapi:: auto a_ = sc.get_mem(a_acc); auto devInfo_ = sc.get_mem(devInfo_acc); rocblas_status err; - ROCSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_rocblas_fill_mode(uplo), + rocsolver_native_named_func(func_name, func, err, handle, get_rocblas_fill_mode(uplo), n, a_, lda, devInfo_); }); }); @@ -733,7 +733,7 @@ inline void potrs(const char *func_name, Func func, sycl::queue &queue, oneapi:: auto a_ = sc.get_mem(a_acc); auto b_ = sc.get_mem(b_acc); rocblas_status err; - ROCSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_rocblas_fill_mode(uplo), + rocsolver_native_named_func(func_name, func, err, handle, get_rocblas_fill_mode(uplo), n, nrhs, a_, lda, b_, ldb); }); }); @@ -773,7 +773,7 @@ inline void syevd(const char *func_name, Func func, sycl::queue &queue, oneapi:: auto devInfo_ = sc.get_mem(devInfo_acc); auto scratch_ = sc.get_mem(scratch_acc); rocblas_status err; - ROCSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_rocsolver_job(jobz), + rocsolver_native_named_func(func_name, func, err, handle, get_rocsolver_job(jobz), get_rocblas_fill_mode(uplo), n, a_, lda, w_, scratch_, devInfo_); }); @@ -816,7 +816,7 @@ inline void sygvd(const char *func_name, Func func, sycl::queue &queue, std::int auto devInfo_ = sc.get_mem(devInfo_acc); auto scratch_ = sc.get_mem(scratch_acc); rocblas_status err; - ROCSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_rocsolver_itype(itype), + rocsolver_native_named_func(func_name, func, err, handle, get_rocsolver_itype(itype), get_rocsolver_job(jobz), get_rocblas_fill_mode(uplo), n, a_, lda, b_, ldb, w_, scratch_, devInfo_); }); @@ -857,7 +857,7 @@ inline void sytrd(const char *func_name, Func func, sycl::queue &queue, oneapi:: auto e_ = sc.get_mem(e_acc); auto tau_ = sc.get_mem(tau_acc); rocblas_status err; - ROCSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_rocblas_fill_mode(uplo), + rocsolver_native_named_func(func_name, func, err, handle, get_rocblas_fill_mode(uplo), n, a_, lda, d_, e_, tau_); }); }); @@ -902,7 +902,7 @@ inline void sytrf(const char *func_name, Func func, sycl::queue &queue, oneapi:: auto ipiv32_ = sc.get_mem(ipiv32_acc); auto devInfo_ = sc.get_mem(devInfo_acc); rocblas_status err; - ROCSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_rocblas_fill_mode(uplo), + rocsolver_native_named_func(func_name, func, err, handle, get_rocblas_fill_mode(uplo), n, a_, lda, ipiv32_, devInfo_); }); }); @@ -976,7 +976,7 @@ inline void ungbr(const char *func_name, Func func, sycl::queue &queue, oneapi:: auto a_ = sc.get_mem(a_acc); auto tau_ = sc.get_mem(tau_acc); rocblas_status err; - ROCSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_rocblas_generate(vec), m, + rocsolver_native_named_func(func_name, func, err, handle, get_rocblas_generate(vec), m, n, k, a_, lda, tau_); }); }); @@ -1009,7 +1009,7 @@ inline void ungqr(const char *func_name, Func func, sycl::queue &queue, std::int auto a_ = sc.get_mem(a_acc); auto tau_ = sc.get_mem(tau_acc); rocblas_status err; - ROCSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, m, n, k, a_, lda, tau_); + rocsolver_native_named_func(func_name, func, err, handle, m, n, k, a_, lda, tau_); }); }); } @@ -1041,7 +1041,7 @@ inline void ungtr(const char *func_name, Func func, sycl::queue &queue, oneapi:: auto a_ = sc.get_mem(a_acc); auto tau_ = sc.get_mem(tau_acc); rocblas_status err; - ROCSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_rocblas_fill_mode(uplo), + rocsolver_native_named_func(func_name, func, err, handle, get_rocblas_fill_mode(uplo), n, a_, lda, tau_); }); }); @@ -1092,7 +1092,7 @@ inline void unmqr(const char *func_name, Func func, sycl::queue &queue, oneapi:: auto tau_ = sc.get_mem(tau_acc); auto c_ = sc.get_mem(c_acc); rocblas_status err; - ROCSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_rocblas_side_mode(side), + rocsolver_native_named_func(func_name, func, err, handle, get_rocblas_side_mode(side), get_rocblas_operation(trans), m, n, k, a_, lda, tau_, c_, ldc); }); @@ -1131,7 +1131,7 @@ inline void unmtr(const char *func_name, Func func, sycl::queue &queue, oneapi:: auto tau_ = sc.get_mem(tau_acc); auto c_ = sc.get_mem(c_acc); rocblas_status err; - ROCSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_rocblas_side_mode(side), + rocsolver_native_named_func(func_name, func, err, handle, get_rocblas_side_mode(side), get_rocblas_fill_mode(uplo), get_rocblas_operation(trans), m, n, a_, lda, tau_, c_, ldc); }); @@ -1177,7 +1177,7 @@ inline sycl::event gebrd(const char *func_name, Func func, sycl::queue &queue, s auto tauq_ = reinterpret_cast(tauq); auto taup_ = reinterpret_cast(taup); rocblas_status err; - ROCSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, m, n, a_, lda, d_, e_, tauq_, + rocsolver_native_named_func(func_name, func, err, handle, m, n, a_, lda, d_, e_, tauq_, taup_); }); }); @@ -1238,7 +1238,7 @@ inline sycl::event geqrf(const char *func_name, Func func, sycl::queue &queue, s auto a_ = reinterpret_cast(a); auto tau_ = reinterpret_cast(tau); rocblas_status err; - ROCSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, m, n, a_, lda, tau_); + rocsolver_native_named_func(func_name, func, err, handle, m, n, a_, lda, tau_); }); }); return done; @@ -1285,7 +1285,7 @@ inline sycl::event getrf(const char *func_name, Func func, sycl::queue &queue, s auto devInfo_ = reinterpret_cast(devInfo); auto ipiv_ = reinterpret_cast(ipiv32); rocblas_status err; - ROCSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, m, n, a_, lda, ipiv_, + rocsolver_native_named_func(func_name, func, err, handle, m, n, a_, lda, ipiv_, devInfo_); }); }); @@ -1374,7 +1374,7 @@ inline sycl::event getrs(const char *func_name, Func func, sycl::queue &queue, auto ipiv_ = reinterpret_cast(ipiv32); auto b_ = reinterpret_cast(b); rocblas_status err; - ROCSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_rocblas_operation(trans), + rocsolver_native_named_func(func_name, func, err, handle, get_rocblas_operation(trans), n, nrhs, a_, lda, ipiv_, b_, ldb); }); }); @@ -1426,7 +1426,7 @@ inline sycl::event gesvd(const char *func_name, Func func, sycl::queue &queue, auto devInfo_ = reinterpret_cast(devInfo); auto scratch_ = reinterpret_cast(scratchpad); rocblas_status err; - ROCSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_rocsolver_jobsvd(jobu), + rocsolver_native_named_func(func_name, func, err, handle, get_rocsolver_jobsvd(jobu), get_rocsolver_jobsvd(jobvt), m, n, a_, lda, s_, u_, ldu, vt_, ldvt, scratch_, rocblas_workmode::rocblas_outofplace, devInfo_); @@ -1475,7 +1475,7 @@ inline sycl::event heevd(const char *func_name, Func func, sycl::queue &queue, auto devInfo_ = reinterpret_cast(devInfo); auto scratch_ = reinterpret_cast(scratchpad); rocblas_status err; - ROCSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_rocsolver_job(jobz), + rocsolver_native_named_func(func_name, func, err, handle, get_rocsolver_job(jobz), get_rocblas_fill_mode(uplo), n, a_, lda, w_, scratch_, devInfo_); }); @@ -1522,7 +1522,7 @@ inline sycl::event hegvd(const char *func_name, Func func, sycl::queue &queue, s auto devInfo_ = reinterpret_cast(devInfo); auto scratch_ = reinterpret_cast(scratchpad); rocblas_status err; - ROCSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_rocsolver_itype(itype), + rocsolver_native_named_func(func_name, func, err, handle, get_rocsolver_itype(itype), get_rocsolver_job(jobz), get_rocblas_fill_mode(uplo), n, a_, lda, b_, ldb, w_, scratch_, devInfo_); }); @@ -1567,7 +1567,7 @@ inline sycl::event hetrd(const char *func_name, Func func, sycl::queue &queue, auto e_ = reinterpret_cast(e); auto tau_ = reinterpret_cast(tau); rocblas_status err; - ROCSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_rocblas_fill_mode(uplo), + rocsolver_native_named_func(func_name, func, err, handle, get_rocblas_fill_mode(uplo), n, a_, lda, d_, e_, tau_); }); }); @@ -1619,7 +1619,7 @@ inline sycl::event orgbr(const char *func_name, Func func, sycl::queue &queue, auto a_ = reinterpret_cast(a); auto tau_ = reinterpret_cast(tau); rocblas_status err; - ROCSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_rocblas_generate(vec), m, + rocsolver_native_named_func(func_name, func, err, handle, get_rocblas_generate(vec), m, n, k, a_, lda, tau_); }); }); @@ -1657,7 +1657,7 @@ inline sycl::event orgqr(const char *func_name, Func func, sycl::queue &queue, s auto a_ = reinterpret_cast(a); auto tau_ = reinterpret_cast(tau); rocblas_status err; - ROCSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, m, n, k, a_, lda, tau_); + rocsolver_native_named_func(func_name, func, err, handle, m, n, k, a_, lda, tau_); }); }); return done; @@ -1693,7 +1693,7 @@ inline sycl::event orgtr(const char *func_name, Func func, sycl::queue &queue, auto a_ = reinterpret_cast(a); auto tau_ = reinterpret_cast(tau); rocblas_status err; - ROCSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_rocblas_fill_mode(uplo), + rocsolver_native_named_func(func_name, func, err, handle, get_rocblas_fill_mode(uplo), n, a_, lda, tau_); }); }); @@ -1733,7 +1733,7 @@ inline sycl::event ormtr(const char *func_name, Func func, sycl::queue &queue, auto tau_ = reinterpret_cast(tau); auto c_ = reinterpret_cast(c); rocblas_status err; - ROCSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_rocblas_side_mode(side), + rocsolver_native_named_func(func_name, func, err, handle, get_rocblas_side_mode(side), get_rocblas_fill_mode(uplo), get_rocblas_operation(trans), m, n, a_, lda, tau_, c_, ldc); }); @@ -1788,7 +1788,7 @@ inline sycl::event ormqr(const char *func_name, Func func, sycl::queue &queue, auto tau_ = reinterpret_cast(tau); auto c_ = reinterpret_cast(c); rocblas_status err; - ROCSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_rocblas_side_mode(side), + rocsolver_native_named_func(func_name, func, err, handle, get_rocblas_side_mode(side), get_rocblas_operation(trans), m, n, k, a_, lda, tau_, c_, ldc); }); @@ -1829,7 +1829,7 @@ inline sycl::event potrf(const char *func_name, Func func, sycl::queue &queue, auto a_ = reinterpret_cast(a); auto devInfo_ = reinterpret_cast(devInfo); rocblas_status err; - ROCSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_rocblas_fill_mode(uplo), + rocsolver_native_named_func(func_name, func, err, handle, get_rocblas_fill_mode(uplo), n, a_, lda, devInfo_); }); }); @@ -1872,7 +1872,7 @@ inline sycl::event potri(const char *func_name, Func func, sycl::queue &queue, auto scratch_ = reinterpret_cast(scratchpad); auto devInfo_ = reinterpret_cast(devInfo); rocblas_status err; - ROCSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_rocblas_fill_mode(uplo), + rocsolver_native_named_func(func_name, func, err, handle, get_rocblas_fill_mode(uplo), n, a_, lda, devInfo_); }); }); @@ -1914,7 +1914,7 @@ inline sycl::event potrs(const char *func_name, Func func, sycl::queue &queue, auto a_ = reinterpret_cast(a); auto b_ = reinterpret_cast(b); rocblas_status err; - ROCSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_rocblas_fill_mode(uplo), + rocsolver_native_named_func(func_name, func, err, handle, get_rocblas_fill_mode(uplo), n, nrhs, a_, lda, b_, ldb); }); }); @@ -1957,7 +1957,7 @@ inline sycl::event syevd(const char *func_name, Func func, sycl::queue &queue, auto scratch_ = reinterpret_cast(scratchpad); auto devInfo_ = reinterpret_cast(devInfo); rocblas_status err; - ROCSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_rocsolver_job(jobz), + rocsolver_native_named_func(func_name, func, err, handle, get_rocsolver_job(jobz), get_rocblas_fill_mode(uplo), n, a_, lda, w_, scratch_, devInfo_); }); @@ -2003,7 +2003,7 @@ inline sycl::event sygvd(const char *func_name, Func func, sycl::queue &queue, s auto devInfo_ = reinterpret_cast(devInfo); auto scratch_ = reinterpret_cast(scratchpad); rocblas_status err; - ROCSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_rocsolver_itype(itype), + rocsolver_native_named_func(func_name, func, err, handle, get_rocsolver_itype(itype), get_rocsolver_job(jobz), get_rocblas_fill_mode(uplo), n, a_, lda, b_, ldb, w_, scratch_, devInfo_); }); @@ -2046,7 +2046,7 @@ inline sycl::event sytrd(const char *func_name, Func func, sycl::queue &queue, auto e_ = reinterpret_cast(e); auto tau_ = reinterpret_cast(tau); rocblas_status err; - ROCSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_rocblas_fill_mode(uplo), + rocsolver_native_named_func(func_name, func, err, handle, get_rocblas_fill_mode(uplo), n, a_, lda, d_, e_, tau_); }); }); @@ -2093,7 +2093,7 @@ inline sycl::event sytrf(const char *func_name, Func func, sycl::queue &queue, auto ipiv_ = reinterpret_cast(ipiv32); auto devInfo_ = reinterpret_cast(devInfo); rocblas_status err; - ROCSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_rocblas_fill_mode(uplo), + rocsolver_native_named_func(func_name, func, err, handle, get_rocblas_fill_mode(uplo), n, a_, lda, ipiv_, devInfo_); }); }); @@ -2173,7 +2173,7 @@ inline sycl::event ungbr(const char *func_name, Func func, sycl::queue &queue, auto a_ = reinterpret_cast(a); auto tau_ = reinterpret_cast(tau); rocblas_status err; - ROCSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_rocblas_generate(vec), m, + rocsolver_native_named_func(func_name, func, err, handle, get_rocblas_generate(vec), m, n, k, a_, lda, tau_); }); }); @@ -2211,7 +2211,7 @@ inline sycl::event ungqr(const char *func_name, Func func, sycl::queue &queue, s auto a_ = reinterpret_cast(a); auto tau_ = reinterpret_cast(tau); rocblas_status err; - ROCSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, m, n, k, a_, lda, tau_); + rocsolver_native_named_func(func_name, func, err, handle, m, n, k, a_, lda, tau_); }); }); return done; @@ -2247,7 +2247,7 @@ inline sycl::event ungtr(const char *func_name, Func func, sycl::queue &queue, auto a_ = reinterpret_cast(a); auto tau_ = reinterpret_cast(tau); rocblas_status err; - ROCSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_rocblas_fill_mode(uplo), + rocsolver_native_named_func(func_name, func, err, handle, get_rocblas_fill_mode(uplo), n, a_, lda, tau_); }); }); @@ -2301,7 +2301,7 @@ inline sycl::event unmqr(const char *func_name, Func func, sycl::queue &queue, auto tau_ = reinterpret_cast(tau); auto c_ = reinterpret_cast(c); rocblas_status err; - ROCSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_rocblas_side_mode(side), + rocsolver_native_named_func(func_name, func, err, handle, get_rocblas_side_mode(side), get_rocblas_operation(trans), m, n, k, a_, lda, tau_, c_, ldc); }); @@ -2344,7 +2344,7 @@ inline sycl::event unmtr(const char *func_name, Func func, sycl::queue &queue, auto tau_ = reinterpret_cast(tau); auto c_ = reinterpret_cast(c); rocblas_status err; - ROCSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_rocblas_side_mode(side), + rocsolver_native_named_func(func_name, func, err, handle, get_rocblas_side_mode(side), get_rocblas_fill_mode(uplo), get_rocblas_operation(trans), m, n, a_, lda, tau_, c_, ldc); }); diff --git a/src/lapack/backends/rocsolver/rocsolver_task.hpp b/src/lapack/backends/rocsolver/rocsolver_task.hpp index 902b2f080..43be4b543 100644 --- a/src/lapack/backends/rocsolver/rocsolver_task.hpp +++ b/src/lapack/backends/rocsolver/rocsolver_task.hpp @@ -52,7 +52,11 @@ namespace rocsolver { template static inline void host_task_internal(H &cgh, sycl::queue queue, F f) { +#ifdef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND + cgh.ext_codeplay_enqueue_native_command([f, queue](sycl::interop_handle ih){ +#else cgh.host_task([f, queue](cl::sycl::interop_handle ih) { +#endif auto sc = RocsolverScopedContextHandler(queue, ih); f(sc); });