diff --git a/src/blas/backends/cublas/cublas_batch.cpp b/src/blas/backends/cublas/cublas_batch.cpp index 009bb9541..9f198b653 100644 --- a/src/blas/backends/cublas/cublas_batch.cpp +++ b/src/blas/backends/cublas/cublas_batch.cpp @@ -167,12 +167,21 @@ inline void gemm_batch_impl(sycl::queue &queue, transpose transa, transpose tran auto b_ = sc.get_mem(b_acc); auto c_ = sc.get_mem(c_acc); cublasStatus_t err; +#ifdef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND + CUBLAS_ERROR_FUNC_T( + "cublasGemmStridedBatchedEx", cublasGemmStridedBatchedEx, err, handle, + get_cublas_operation(transa), get_cublas_operation(transb), m, n, k, &alpha, a_, + get_cublas_datatype(), lda, stride_a, b_, get_cublas_datatype(), + ldb, stride_b, &beta, c_, get_cublas_datatype(), ldc, stride_c, batch_size, + get_cublas_datatype(), cublas_gemm_algo); +#else CUBLAS_ERROR_FUNC_T_SYNC( "cublasGemmStridedBatchedEx", cublasGemmStridedBatchedEx, err, handle, get_cublas_operation(transa), get_cublas_operation(transb), m, n, k, &alpha, a_, get_cublas_datatype(), lda, stride_a, b_, get_cublas_datatype(), ldb, stride_b, &beta, c_, get_cublas_datatype(), ldc, stride_c, batch_size, get_cublas_datatype(), cublas_gemm_algo); +#endif }); }); } @@ -608,12 +617,21 @@ inline sycl::event gemm_batch_strided_usm_impl(sycl::queue &queue, transpose tra onemkl_cublas_host_task(cgh, queue, [=](CublasScopedContextHandler &sc) { auto handle = sc.get_handle(queue); cublasStatus_t err; +#ifdef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND + CUBLAS_ERROR_FUNC_T( + "cublasGemmStridedBatchedEx", cublasGemmStridedBatchedEx, err, handle, + get_cublas_operation(transa), get_cublas_operation(transb), m, n, k, &alpha, a, + get_cublas_datatype(), lda, stride_a, b, get_cublas_datatype(), + ldb, stride_b, &beta, c, get_cublas_datatype(), ldc, stride_c, batch_size, + get_cublas_datatype(), cublas_gemm_algo); +#else CUBLAS_ERROR_FUNC_T_SYNC( "cublasGemmStridedBatchedEx", cublasGemmStridedBatchedEx, err, handle, get_cublas_operation(transa), 
get_cublas_operation(transb), m, n, k, &alpha, a, get_cublas_datatype(), lda, stride_a, b, get_cublas_datatype(), ldb, stride_b, &beta, c, get_cublas_datatype(), ldc, stride_c, batch_size, get_cublas_datatype(), cublas_gemm_algo); +#endif }); }); return done; @@ -687,6 +705,16 @@ inline sycl::event gemm_batch_usm_impl(sycl::queue &queue, transpose *transa, tr int64_t offset = 0; cublasStatus_t err; for (int64_t i = 0; i < group_count; i++) { +#ifdef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND + CUBLAS_ERROR_FUNC_T( + "cublasGemmBatchedEx", cublasGemmBatchedEx, err, handle, + get_cublas_operation(transa[i]), get_cublas_operation(transb[i]), (int)m[i], + (int)n[i], (int)k[i], &alpha[i], (const void *const *)(a + offset), + get_cublas_datatype(), (int)lda[i], (const void *const *)(b + offset), + get_cublas_datatype(), (int)ldb[i], &beta[i], + (void *const *)(c + offset), get_cublas_datatype(), (int)ldc[i], + (int)group_size[i], get_cublas_datatype(), cublas_gemm_algo); +#else CUBLAS_ERROR_FUNC_T_SYNC( "cublasGemmBatchedEx", cublasGemmBatchedEx, err, handle, get_cublas_operation(transa[i]), get_cublas_operation(transb[i]), (int)m[i], @@ -695,6 +723,7 @@ inline sycl::event gemm_batch_usm_impl(sycl::queue &queue, transpose *transa, tr get_cublas_datatype(), (int)ldb[i], &beta[i], (void *const *)(c + offset), get_cublas_datatype(), (int)ldc[i], (int)group_size[i], get_cublas_datatype(), cublas_gemm_algo); +#endif offset += group_size[i]; } }); @@ -792,12 +821,13 @@ inline sycl::event trsm_batch(const char *func_name, Func func, sycl::queue &que for (int64_t i = 0; i < group_count; i++) { auto **a_ = reinterpret_cast(a); auto **b_ = reinterpret_cast(b); - CUBLAS_ERROR_FUNC_T_SYNC( + cublas_native_named_func( func_name, func, err, handle, get_cublas_side_mode(left_right[i]), get_cublas_fill_mode(upper_lower[i]), get_cublas_operation(trans[i]), get_cublas_diag_type(unit_diag[i]), (int)m[i], (int)n[i], (cuDataType *)&alpha[i], a_ + offset, (int)lda[i], b_ + offset, (int)ldb[i], 
(int)group_size[i]); + offset += group_size[i]; } }); diff --git a/src/blas/backends/cublas/cublas_helper.hpp b/src/blas/backends/cublas/cublas_helper.hpp index 0fe7e7c5a..0bd4e6274 100644 --- a/src/blas/backends/cublas/cublas_helper.hpp +++ b/src/blas/backends/cublas/cublas_helper.hpp @@ -190,6 +190,12 @@ class cuda_error : virtual public std::runtime_error { CUBLAS_ERROR_FUNC(cublasGetStream, err, handle, &currentStreamId); \ cuStreamSynchronize(currentStreamId); +#define CUBLAS_ERROR_FUNC_T(name, func, err, handle, ...) \ + err = func(handle, __VA_ARGS__); \ + if (err != CUBLAS_STATUS_SUCCESS) { \ + throw cublas_error(std::string(name) + std::string(" : "), err); \ + } + #define CUBLAS_ERROR_FUNC_T_SYNC(name, func, err, handle, ...) \ err = func(handle, __VA_ARGS__); \ if (err != CUBLAS_STATUS_SUCCESS) { \ @@ -199,6 +205,27 @@ class cuda_error : virtual public std::runtime_error { CUBLAS_ERROR_FUNC(cublasGetStream, err, handle, &currentStreamId); \ cuStreamSynchronize(currentStreamId); +template <typename Func, typename... Types> +inline void cublas_native_func(Func func, cublasStatus_t err, + cublasHandle_t handle, Types... args) { +#ifdef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND + CUBLAS_ERROR_FUNC(func, err, handle, args...) +#else + CUBLAS_ERROR_FUNC_SYNC(func, err, handle, args...) +#endif +}; + +template <typename Func, typename... Types> +inline void cublas_native_named_func(const char *func_name, Func func, + cublasStatus_t err, cublasHandle_t handle, + Types... args) { +#ifdef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND + CUBLAS_ERROR_FUNC_T(func_name, func, err, handle, args...) +#else + CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, args...) 
+#endif +}; + inline cublasOperation_t get_cublas_operation(oneapi::mkl::transpose trn) { switch (trn) { case oneapi::mkl::transpose::nontrans: return CUBLAS_OP_N; diff --git a/src/blas/backends/cublas/cublas_level1.cpp b/src/blas/backends/cublas/cublas_level1.cpp index 5f7087727..3b0699c87 100644 --- a/src/blas/backends/cublas/cublas_level1.cpp +++ b/src/blas/backends/cublas/cublas_level1.cpp @@ -53,7 +53,7 @@ inline void asum(const char *func_name, Func func, sycl::queue &queue, int64_t n auto res_ = sc.get_mem(res_acc); cublasStatus_t err; // ASUM does not support negative index - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, n, x_, std::abs(incx), res_); + cublas_native_named_func(func_name, func, err, handle, n, x_, std::abs(incx), res_); // Higher level BLAS functions expect CUBLAS_POINTER_MODE_HOST // to be set, therfore we need to reset this to the default value // in order to avoid CUDA_ERROR_ILLEGAL_ADRESS errors @@ -86,7 +86,7 @@ inline void scal(const char *func_name, Func func, sycl::queue &queue, int64_t n auto x_ = sc.get_mem(x_acc); cublasStatus_t err; // SCAL does not support negative incx - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, n, (cuDataType1 *)&a, x_, + cublas_native_named_func(func_name, func, err, handle, n, (cuDataType1 *)&a, x_, std::abs(incx)); }); }); @@ -117,7 +117,7 @@ inline void axpy(const char *func_name, Func func, sycl::queue &queue, int64_t n auto x_ = sc.get_mem(x_acc); auto y_ = sc.get_mem(y_acc); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, n, (cuDataType *)&alpha, x_, + cublas_native_named_func(func_name, func, err, handle, n, (cuDataType *)&alpha, x_, incx, y_, incy); }); }); @@ -180,7 +180,7 @@ inline void rotg(const char *func_name, Func func, sycl::queue &queue, sycl::buf auto c_ = sc.get_mem(c_acc); auto s_ = sc.get_mem(s_acc); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, a_, b_, c_, s_); + cublas_native_named_func(func_name, func, 
err, handle, a_, b_, c_, s_); // Higher level BLAS functions expect CUBLAS_POINTER_MODE_HOST // to be set, therfore we need to reset this to the default value // in order to avoid CUDA_ERROR_ILLEGAL_ADRESS errors @@ -223,7 +223,7 @@ inline void rotm(const char *func_name, Func func, sycl::queue &queue, int64_t n auto y_ = sc.get_mem(y_acc); auto param_ = sc.get_mem(param_acc); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, n, x_, incx, y_, incy, param_); + cublas_native_named_func(func_name, func, err, handle, n, x_, incx, y_, incy, param_); // Higher level BLAS functions expect CUBLAS_POINTER_MODE_HOST // to be set, therfore we need to reset this to the default value // in order to avoid CUDA_ERROR_ILLEGAL_ADRESS errors @@ -255,7 +255,7 @@ inline void copy(const char *func_name, Func func, sycl::queue &queue, int64_t n auto x_ = sc.get_mem(x_acc); auto y_ = sc.get_mem(y_acc); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, n, x_, incx, y_, incy); + cublas_native_named_func(func_name, func, err, handle, n, x_, incx, y_, incy); }); }); } @@ -294,7 +294,7 @@ inline void dot(const char *func_name, Func func, sycl::queue &queue, int64_t n, auto y_ = sc.get_mem(y_acc); auto res_ = sc.get_mem(res_acc); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, n, x_, incx, y_, incy, res_); + cublas_native_named_func(func_name, func, err, handle, n, x_, incx, y_, incy, res_); // Higher level BLAS functions expect CUBLAS_POINTER_MODE_HOST // to be set, therfore we need to reset this to the default value // in order to avoid CUDA_ERROR_ILLEGAL_ADRESS errors @@ -338,7 +338,7 @@ inline void rot(const char *func_name, Func func, sycl::queue &queue, int64_t n, auto x_ = sc.get_mem(x_acc); auto y_ = sc.get_mem(y_acc); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, n, x_, incx, y_, incy, + cublas_native_named_func(func_name, func, err, handle, n, x_, incx, y_, incy, 
(cuDataType2 *)&c, (cuDataType3 *)&s); }); }); @@ -376,7 +376,7 @@ void sdsdot(sycl::queue &queue, int64_t n, float sb, sycl::buffer &x, auto y_ = sc.get_mem(y_acc); auto res_ = sc.get_mem(res_acc); cublasStatus_t err; - CUBLAS_ERROR_FUNC_SYNC(cublasSdot, err, handle, n, x_, incx, y_, incy, res_); + cublas_native_func(cublasSdot, err, handle, n, x_, incx, y_, incy, res_); // Higher level BLAS functions expect CUBLAS_POINTER_MODE_HOST // to be set, therfore we need to reset this to the default value // in order to avoid CUDA_ERROR_ILLEGAL_ADRESS errors @@ -418,7 +418,7 @@ inline void rotmg(const char *func_name, Func func, sycl::queue &queue, sycl::bu auto y1_ = sc.get_mem(y1_acc); auto param_ = sc.get_mem(param_acc); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, d1_, d2_, x1_, y1_, param_); + cublas_native_named_func(func_name, func, err, handle, d1_, d2_, x1_, y1_, param_); // Higher level BLAS functions expect CUBLAS_POINTER_MODE_HOST // to be set, therfore we need to reset this to the default value // in order to avoid CUDA_ERROR_ILLEGAL_ADRESS errors @@ -466,7 +466,7 @@ inline void iamax(const char *func_name, Func func, sycl::queue &queue, int64_t cublasStatus_t err; // For negative incx, iamax returns 0. This behaviour is similar to that of // reference netlib BLAS. 
- CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, n, x_, incx, int_res_); + cublas_native_named_func(func_name, func, err, handle, n, x_, incx, int_res_); // Higher level BLAS functions expect CUBLAS_POINTER_MODE_HOST // to be set, therfore we need to reset this to the default value // in order to avoid CUDA_ERROR_ILLEGAL_ADRESS errors @@ -506,7 +506,7 @@ inline void swap(const char *func_name, Func func, sycl::queue &queue, int64_t n auto x_ = sc.get_mem(x_acc); auto y_ = sc.get_mem(y_acc); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, n, x_, incx, y_, incy); + cublas_native_named_func(func_name, func, err, handle, n, x_, incx, y_, incy); }); }); } @@ -552,7 +552,7 @@ inline void iamin(const char *func_name, Func func, sycl::queue &queue, int64_t cublasStatus_t err; // For negative incx, iamin returns 0. This behaviour is similar to that of // implemented as a reference IAMIN. - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, n, x_, incx, int_res_); + cublas_native_named_func(func_name, func, err, handle, n, x_, incx, int_res_); // Higher level BLAS functions expect CUBLAS_POINTER_MODE_HOST // to be set, therfore we need to reset this to the default value // in order to avoid CUDA_ERROR_ILLEGAL_ADRESS errors @@ -601,7 +601,7 @@ inline void nrm2(const char *func_name, Func func, sycl::queue &queue, int64_t n auto res_ = sc.get_mem(res_acc); cublasStatus_t err; // NRM2 does not support negative index - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, n, x_, std::abs(incx), res_); + cublas_native_named_func(func_name, func, err, handle, n, x_, std::abs(incx), res_); // Higher level BLAS functions expect CUBLAS_POINTER_MODE_HOST // to be set, therfore we need to reset this to the default value // in order to avoid CUDA_ERROR_ILLEGAL_ADRESS errors @@ -648,7 +648,7 @@ inline sycl::event asum(const char *func_name, Func func, sycl::queue &queue, in } cublasStatus_t err; // ASUM does not support negative index - 
CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, n, x_, std::abs(incx), res_); + cublas_native_named_func(func_name, func, err, handle, n, x_, std::abs(incx), res_); if (result_on_device) { cublasSetPointerMode(handle, CUBLAS_POINTER_MODE_HOST); } @@ -684,7 +684,7 @@ inline sycl::event scal(const char *func_name, Func func, sycl::queue &queue, in auto x_ = reinterpret_cast(x); cublasStatus_t err; // SCAL does not support negative incx - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, n, (cuDataType1 *)&a, x_, + cublas_native_named_func(func_name, func, err, handle, n, (cuDataType1 *)&a, x_, std::abs(incx)); }); }); @@ -720,7 +720,7 @@ inline sycl::event axpy(const char *func_name, Func func, sycl::queue &queue, in auto x_ = reinterpret_cast(x); auto y_ = reinterpret_cast(y); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, n, (cuDataType *)&alpha, x_, + cublas_native_named_func(func_name, func, err, handle, n, (cuDataType *)&alpha, x_, incx, y_, incy); }); }); @@ -798,7 +798,7 @@ inline sycl::event rotg(const char *func_name, Func func, sycl::queue &queue, T1 cublasSetPointerMode(handle, CUBLAS_POINTER_MODE_DEVICE); } cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, a_, b_, c_, s_); + cublas_native_named_func(func_name, func, err, handle, a_, b_, c_, s_); if (results_on_device) { cublasSetPointerMode(handle, CUBLAS_POINTER_MODE_HOST); } @@ -836,7 +836,7 @@ inline sycl::event rotm(const char *func_name, Func func, sycl::queue &queue, in auto y_ = reinterpret_cast(y); auto param_ = reinterpret_cast(param); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, n, x_, incx, y_, incy, param_); + cublas_native_named_func(func_name, func, err, handle, n, x_, incx, y_, incy, param_); }); }); return done; @@ -869,7 +869,7 @@ inline sycl::event copy(const char *func_name, Func func, sycl::queue &queue, in auto x_ = reinterpret_cast(x); auto y_ = reinterpret_cast(y); cublasStatus_t 
err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, n, x_, incx, y_, incy); + cublas_native_named_func(func_name, func, err, handle, n, x_, incx, y_, incy); }); }); return done; @@ -909,7 +909,7 @@ inline sycl::event dot(const char *func_name, Func func, sycl::queue &queue, int cublasSetPointerMode(handle, CUBLAS_POINTER_MODE_DEVICE); } cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, n, x_, incx, y_, incy, res_); + cublas_native_named_func(func_name, func, err, handle, n, x_, incx, y_, incy, res_); if (result_on_device) { cublasSetPointerMode(handle, CUBLAS_POINTER_MODE_HOST); } @@ -951,7 +951,7 @@ inline sycl::event rot(const char *func_name, Func func, sycl::queue &queue, int auto x_ = reinterpret_cast(x); auto y_ = reinterpret_cast(y); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, n, x_, incx, y_, incy, + cublas_native_named_func(func_name, func, err, handle, n, x_, incx, y_, incy, (cuDataType2 *)&c, (cuDataType3 *)&s); }); }); @@ -993,7 +993,7 @@ sycl::event sdsdot(sycl::queue &queue, int64_t n, float sb, const float *x, int6 cublasSetPointerMode(handle, CUBLAS_POINTER_MODE_DEVICE); } cublasStatus_t err; - CUBLAS_ERROR_FUNC_SYNC(cublasSdot, err, handle, n, x_, incx, y_, incy, res_); + cublas_native_func(cublasSdot, err, handle, n, x_, incx, y_, incy, res_); if (result_on_device) { cublasSetPointerMode(handle, CUBLAS_POINTER_MODE_HOST); } @@ -1058,12 +1058,12 @@ inline sycl::event rotmg(const char *func_name, Func func, sycl::queue &queue, T cublasStatus_t err; if (results_on_device) { cublasSetPointerMode(handle, CUBLAS_POINTER_MODE_DEVICE); - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, d1_, d2_, x1_, y1_, param_); + cublas_native_named_func(func_name, func, err, handle, d1_, d2_, x1_, y1_, param_); cublasSetPointerMode(handle, CUBLAS_POINTER_MODE_HOST); } else { auto y1_c = reinterpret_cast(&y1); - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, d1_, d2_, x1_, y1_c, 
param_); + cublas_native_named_func(func_name, func, err, handle, d1_, d2_, x1_, y1_c, param_); } }); }); @@ -1120,7 +1120,7 @@ inline sycl::event iamax(const char *func_name, Func func, sycl::queue &queue, i cublasStatus_t err; // For negative incx, iamax returns 0. This behaviour is similar to that of // reference iamax. - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, n, x_, incx, int_res_p); + cublas_native_named_func(func_name, func, err, handle, n, x_, incx, int_res_p); if (result_on_device) { cublasSetPointerMode(handle, CUBLAS_POINTER_MODE_HOST); } @@ -1168,7 +1168,7 @@ inline sycl::event swap(const char *func_name, Func func, sycl::queue &queue, in auto x_ = reinterpret_cast(x); auto y_ = reinterpret_cast(y); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, n, x_, incx, y_, incy); + cublas_native_named_func(func_name, func, err, handle, n, x_, incx, y_, incy); }); }); return done; @@ -1221,7 +1221,7 @@ inline sycl::event iamin(const char *func_name, Func func, sycl::queue &queue, i cublasStatus_t err; // For negative incx, iamin returns 0. This behaviour is similar to that of // implemented iamin. 
- CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, n, x_, incx, int_res_p); + cublas_native_named_func(func_name, func, err, handle, n, x_, incx, int_res_p); if (result_on_device) { cublasSetPointerMode(handle, CUBLAS_POINTER_MODE_HOST); } @@ -1277,7 +1277,7 @@ inline sycl::event nrm2(const char *func_name, Func func, sycl::queue &queue, in } cublasStatus_t err; // NRM2 does not support negative index - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, n, x_, std::abs(incx), res_); + cublas_native_named_func(func_name, func, err, handle, n, x_, std::abs(incx), res_); if (result_on_device) { cublasSetPointerMode(handle, CUBLAS_POINTER_MODE_HOST); } diff --git a/src/blas/backends/cublas/cublas_level2.cpp b/src/blas/backends/cublas/cublas_level2.cpp index 8f711243b..5ce6e5eaf 100644 --- a/src/blas/backends/cublas/cublas_level2.cpp +++ b/src/blas/backends/cublas/cublas_level2.cpp @@ -46,7 +46,7 @@ inline void gemv(const char *func_name, Func func, sycl::queue &queue, transpose auto x_ = sc.get_mem(x_acc); auto y_ = sc.get_mem(y_acc); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_operation(trans), m, + cublas_native_named_func(func_name, func, err, handle, get_cublas_operation(trans), m, n, (cuDataType *)&alpha, a_, lda, x_, incx, (cuDataType *)&beta, y_, incy); }); @@ -83,7 +83,7 @@ inline void gbmv(const char *func_name, Func func, sycl::queue &queue, transpose auto x_ = sc.get_mem(x_acc); auto y_ = sc.get_mem(y_acc); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_operation(trans), m, + cublas_native_named_func(func_name, func, err, handle, get_cublas_operation(trans), m, n, kl, ku, (cuDataType *)&alpha, a_, lda, x_, incx, (cuDataType *)&beta, y_, incy); }); @@ -120,7 +120,7 @@ inline void ger(const char *func_name, Func func, sycl::queue &queue, int64_t m, auto x_ = sc.get_mem(x_acc); auto y_ = sc.get_mem(y_acc); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, 
func, err, handle, m, n, (cuDataType *)&alpha, x_, + cublas_native_named_func(func_name, func, err, handle, m, n, (cuDataType *)&alpha, x_, incx, y_, incy, a_, lda); }); }); @@ -157,7 +157,7 @@ inline void hbmv(const char *func_name, Func func, sycl::queue &queue, uplo uppe auto x_ = sc.get_mem(x_acc); auto y_ = sc.get_mem(y_acc); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, + cublas_native_named_func(func_name, func, err, handle, get_cublas_fill_mode(upper_lower), n, k, (cuDataType *)&alpha, a_, lda, x_, incx, (cuDataType *)&beta, y_, incy); }); @@ -192,7 +192,7 @@ inline void hemv(const char *func_name, Func func, sycl::queue &queue, uplo uppe auto x_ = sc.get_mem(x_acc); auto y_ = sc.get_mem(y_acc); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, + cublas_native_named_func(func_name, func, err, handle, get_cublas_fill_mode(upper_lower), n, (cuDataType *)&alpha, a_, lda, x_, incx, (cuDataType *)&beta, y_, incy); }); @@ -227,7 +227,7 @@ inline void her(const char *func_name, Func func, sycl::queue &queue, uplo upper auto a_ = sc.get_mem(a_acc); auto x_ = sc.get_mem(x_acc); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, + cublas_native_named_func(func_name, func, err, handle, get_cublas_fill_mode(upper_lower), n, (cuScalarType *)&alpha, x_, incx, a_, lda); }); @@ -262,7 +262,7 @@ inline void her2(const char *func_name, Func func, sycl::queue &queue, uplo uppe auto x_ = sc.get_mem(x_acc); auto y_ = sc.get_mem(y_acc); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, + cublas_native_named_func(func_name, func, err, handle, get_cublas_fill_mode(upper_lower), n, (cuDataType *)&alpha, x_, incx, y_, incy, a_, lda); }); @@ -298,7 +298,7 @@ inline void hpmv(const char *func_name, Func func, sycl::queue &queue, uplo uppe auto x_ = sc.get_mem(x_acc); auto y_ = sc.get_mem(y_acc); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, + 
cublas_native_named_func(func_name, func, err, handle, get_cublas_fill_mode(upper_lower), n, (cuDataType *)&alpha, a_, x_, incx, (cuDataType *)&beta, y_, incy); }); @@ -333,7 +333,7 @@ inline void hpr(const char *func_name, Func func, sycl::queue &queue, uplo upper auto a_ = sc.get_mem(a_acc); auto x_ = sc.get_mem(x_acc); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, + cublas_native_named_func(func_name, func, err, handle, get_cublas_fill_mode(upper_lower), n, (cuScalarType *)&alpha, x_, incx, a_); }); @@ -367,7 +367,7 @@ inline void hpr2(const char *func_name, Func func, sycl::queue &queue, uplo uppe auto x_ = sc.get_mem(x_acc); auto y_ = sc.get_mem(y_acc); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, + cublas_native_named_func(func_name, func, err, handle, get_cublas_fill_mode(upper_lower), n, (cuDataType *)&alpha, x_, incx, y_, incy, a_); }); @@ -402,7 +402,7 @@ inline void sbmv(const char *func_name, Func func, sycl::queue &queue, uplo uppe auto x_ = sc.get_mem(x_acc); auto y_ = sc.get_mem(y_acc); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, + cublas_native_named_func(func_name, func, err, handle, get_cublas_fill_mode(upper_lower), n, k, (cuDataType *)&alpha, a_, lda, x_, incx, (cuDataType *)&beta, y_, incy); }); @@ -438,7 +438,7 @@ inline void symv(const char *func_name, Func func, sycl::queue &queue, uplo uppe auto x_ = sc.get_mem(x_acc); auto y_ = sc.get_mem(y_acc); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, + cublas_native_named_func(func_name, func, err, handle, get_cublas_fill_mode(upper_lower), n, (cuDataType *)&alpha, a_, lda, x_, incx, (cuDataType *)&beta, y_, incy); }); @@ -471,7 +471,7 @@ inline void syr(const char *func_name, Func func, sycl::queue &queue, uplo upper auto a_ = sc.get_mem(a_acc); auto x_ = sc.get_mem(x_acc); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, + 
cublas_native_named_func(func_name, func, err, handle, get_cublas_fill_mode(upper_lower), n, (cuDataType *)&alpha, x_, incx, a_, lda); }); @@ -507,7 +507,7 @@ inline void syr2(const char *func_name, Func func, sycl::queue &queue, uplo uppe auto x_ = sc.get_mem(x_acc); auto y_ = sc.get_mem(y_acc); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, + cublas_native_named_func(func_name, func, err, handle, get_cublas_fill_mode(upper_lower), n, (cuDataType *)&alpha, x_, incx, y_, incy, a_, lda); }); @@ -546,7 +546,7 @@ inline void spmv(const char *func_name, Func func, sycl::queue &queue, uplo uppe auto x_ = sc.get_mem(x_acc); auto y_ = sc.get_mem(y_acc); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, + cublas_native_named_func(func_name, func, err, handle, get_cublas_fill_mode(upper_lower), n, (cuDataType *)&alpha, a_, x_, incx, (cuDataType *)&beta, y_, incy); }); @@ -579,7 +579,7 @@ inline void spr(const char *func_name, Func func, sycl::queue &queue, uplo upper auto a_ = sc.get_mem(a_acc); auto x_ = sc.get_mem(x_acc); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, + cublas_native_named_func(func_name, func, err, handle, get_cublas_fill_mode(upper_lower), n, (cuDataType *)&alpha, x_, incx, a_); }); @@ -613,7 +613,7 @@ inline void spr2(const char *func_name, Func func, sycl::queue &queue, uplo uppe auto x_ = sc.get_mem(x_acc); auto y_ = sc.get_mem(y_acc); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, + cublas_native_named_func(func_name, func, err, handle, get_cublas_fill_mode(upper_lower), n, (cuDataType *)&alpha, x_, incx, y_, incy, a_); }); @@ -646,7 +646,7 @@ inline void tbmv(const char *func_name, Func func, sycl::queue &queue, uplo uppe auto a_ = sc.get_mem(a_acc); auto x_ = sc.get_mem(x_acc); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, + cublas_native_named_func(func_name, func, err, handle, 
get_cublas_fill_mode(upper_lower), get_cublas_operation(trans), get_cublas_diag_type(unit_diag), n, k, a_, lda, x_, incx); }); @@ -682,7 +682,7 @@ inline void tbsv(const char *func_name, Func func, sycl::queue &queue, uplo uppe auto a_ = sc.get_mem(a_acc); auto x_ = sc.get_mem(x_acc); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, + cublas_native_named_func(func_name, func, err, handle, get_cublas_fill_mode(upper_lower), get_cublas_operation(trans), get_cublas_diag_type(unit_diag), n, k, a_, lda, x_, incx); }); @@ -718,7 +718,7 @@ inline void tpmv(const char *func_name, Func func, sycl::queue &queue, uplo uppe auto a_ = sc.get_mem(a_acc); auto x_ = sc.get_mem(x_acc); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, + cublas_native_named_func(func_name, func, err, handle, get_cublas_fill_mode(upper_lower), get_cublas_operation(trans), get_cublas_diag_type(unit_diag), n, a_, x_, incx); }); @@ -753,7 +753,7 @@ inline void tpsv(const char *func_name, Func func, sycl::queue &queue, uplo uppe auto a_ = sc.get_mem(a_acc); auto x_ = sc.get_mem(x_acc); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, + cublas_native_named_func(func_name, func, err, handle, get_cublas_fill_mode(upper_lower), get_cublas_operation(trans), get_cublas_diag_type(unit_diag), n, a_, x_, incx); }); @@ -788,7 +788,7 @@ inline void trmv(const char *func_name, Func func, sycl::queue &queue, uplo uppe auto a_ = sc.get_mem(a_acc); auto x_ = sc.get_mem(x_acc); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, + cublas_native_named_func(func_name, func, err, handle, get_cublas_fill_mode(upper_lower), get_cublas_operation(trans), get_cublas_diag_type(unit_diag), n, a_, lda, x_, incx); }); @@ -823,7 +823,7 @@ inline void trsv(const char *func_name, Func func, sycl::queue &queue, uplo uppe auto a_ = sc.get_mem(a_acc); auto x_ = sc.get_mem(x_acc); cublasStatus_t err; - 
CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, + cublas_native_named_func(func_name, func, err, handle, get_cublas_fill_mode(upper_lower), get_cublas_operation(trans), get_cublas_diag_type(unit_diag), n, a_, lda, x_, incx); }); @@ -864,7 +864,7 @@ inline sycl::event gemv(const char *func_name, Func func, sycl::queue &queue, tr auto x_ = reinterpret_cast(x); auto y_ = reinterpret_cast(y); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_operation(trans), m, + cublas_native_named_func(func_name, func, err, handle, get_cublas_operation(trans), m, n, (cuDataType *)&alpha, a_, lda, x_, incx, (cuDataType *)&beta, y_, incy); }); @@ -904,7 +904,7 @@ inline sycl::event gbmv(const char *func_name, Func func, sycl::queue &queue, tr auto x_ = reinterpret_cast(x); auto y_ = reinterpret_cast(y); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_operation(trans), m, + cublas_native_named_func(func_name, func, err, handle, get_cublas_operation(trans), m, n, kl, ku, (cuDataType *)&alpha, a_, lda, x_, incx, (cuDataType *)&beta, y_, incy); }); @@ -944,7 +944,7 @@ inline sycl::event ger(const char *func_name, Func func, sycl::queue &queue, int auto x_ = reinterpret_cast(x); auto y_ = reinterpret_cast(y); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, m, n, (cuDataType *)&alpha, x_, + cublas_native_named_func(func_name, func, err, handle, m, n, (cuDataType *)&alpha, x_, incx, y_, incy, a_, lda); }); }); @@ -985,7 +985,7 @@ inline sycl::event hbmv(const char *func_name, Func func, sycl::queue &queue, up auto x_ = reinterpret_cast(x); auto y_ = reinterpret_cast(y); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, + cublas_native_named_func(func_name, func, err, handle, get_cublas_fill_mode(upper_lower), n, k, (cuDataType *)&alpha, a_, lda, x_, incx, (cuDataType *)&beta, y_, incy); }); @@ -1022,7 +1022,7 @@ inline sycl::event hemv(const char 
*func_name, Func func, sycl::queue &queue, up auto x_ = reinterpret_cast(x); auto y_ = reinterpret_cast(y); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, + cublas_native_named_func(func_name, func, err, handle, get_cublas_fill_mode(upper_lower), n, (cuDataType *)&alpha, a_, lda, x_, incx, (cuDataType *)&beta, y_, incy); }); @@ -1060,7 +1060,7 @@ inline sycl::event her(const char *func_name, Func func, sycl::queue &queue, upl auto a_ = reinterpret_cast(a); auto x_ = reinterpret_cast(x); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, + cublas_native_named_func(func_name, func, err, handle, get_cublas_fill_mode(upper_lower), n, (cuScalarType *)&alpha, x_, incx, a_, lda); }); @@ -1098,7 +1098,7 @@ inline sycl::event her2(const char *func_name, Func func, sycl::queue &queue, up auto x_ = reinterpret_cast(x); auto y_ = reinterpret_cast(y); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, + cublas_native_named_func(func_name, func, err, handle, get_cublas_fill_mode(upper_lower), n, (cuDataType *)&alpha, x_, incx, y_, incy, a_, lda); }); @@ -1136,7 +1136,7 @@ inline sycl::event hpmv(const char *func_name, Func func, sycl::queue &queue, up auto x_ = reinterpret_cast(x); auto y_ = reinterpret_cast(y); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, + cublas_native_named_func(func_name, func, err, handle, get_cublas_fill_mode(upper_lower), n, (cuDataType *)&alpha, a_, x_, incx, (cuDataType *)&beta, y_, incy); }); @@ -1174,7 +1174,7 @@ inline sycl::event hpr(const char *func_name, Func func, sycl::queue &queue, upl auto a_ = reinterpret_cast(a); auto x_ = reinterpret_cast(x); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, + cublas_native_named_func(func_name, func, err, handle, get_cublas_fill_mode(upper_lower), n, (cuScalarType *)&alpha, x_, incx, a_); }); @@ -1212,7 +1212,7 @@ inline sycl::event hpr2(const char *func_name, Func 
func, sycl::queue &queue, up auto x_ = reinterpret_cast(x); auto y_ = reinterpret_cast(y); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, + cublas_native_named_func(func_name, func, err, handle, get_cublas_fill_mode(upper_lower), n, (cuDataType *)&alpha, x_, incx, y_, incy, a_); }); @@ -1251,7 +1251,7 @@ inline sycl::event sbmv(const char *func_name, Func func, sycl::queue &queue, up auto x_ = reinterpret_cast(x); auto y_ = reinterpret_cast(y); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, + cublas_native_named_func(func_name, func, err, handle, get_cublas_fill_mode(upper_lower), n, k, (cuDataType *)&alpha, a_, lda, x_, incx, (cuDataType *)&beta, y_, incy); }); @@ -1289,7 +1289,7 @@ inline sycl::event symv(const char *func_name, Func func, sycl::queue &queue, up auto x_ = reinterpret_cast(x); auto y_ = reinterpret_cast(y); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, + cublas_native_named_func(func_name, func, err, handle, get_cublas_fill_mode(upper_lower), n, (cuDataType *)&alpha, a_, lda, x_, incx, (cuDataType *)&beta, y_, incy); }); @@ -1326,7 +1326,7 @@ inline sycl::event syr(const char *func_name, Func func, sycl::queue &queue, upl auto a_ = reinterpret_cast(a); auto x_ = reinterpret_cast(x); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, + cublas_native_named_func(func_name, func, err, handle, get_cublas_fill_mode(upper_lower), n, (cuDataType *)&alpha, x_, incx, a_, lda); }); @@ -1366,7 +1366,7 @@ inline sycl::event syr2(const char *func_name, Func func, sycl::queue &queue, up auto x_ = reinterpret_cast(x); auto y_ = reinterpret_cast(y); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, + cublas_native_named_func(func_name, func, err, handle, get_cublas_fill_mode(upper_lower), n, (cuDataType *)&alpha, x_, incx, y_, incy, a_, lda); }); @@ -1407,7 +1407,7 @@ inline sycl::event spmv(const char *func_name, Func 
func, sycl::queue &queue, up auto x_ = reinterpret_cast(x); auto y_ = reinterpret_cast(y); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, + cublas_native_named_func(func_name, func, err, handle, get_cublas_fill_mode(upper_lower), n, (cuDataType *)&alpha, a_, x_, incx, (cuDataType *)&beta, y_, incy); }); @@ -1444,7 +1444,7 @@ inline sycl::event spr(const char *func_name, Func func, sycl::queue &queue, upl auto a_ = reinterpret_cast(a); auto x_ = reinterpret_cast(x); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, + cublas_native_named_func(func_name, func, err, handle, get_cublas_fill_mode(upper_lower), n, (cuDataType *)&alpha, x_, incx, a_); }); @@ -1481,7 +1481,7 @@ inline sycl::event spr2(const char *func_name, Func func, sycl::queue &queue, up auto x_ = reinterpret_cast(x); auto y_ = reinterpret_cast(y); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, + cublas_native_named_func(func_name, func, err, handle, get_cublas_fill_mode(upper_lower), n, (cuDataType *)&alpha, x_, incx, y_, incy, a_); }); @@ -1519,7 +1519,7 @@ inline sycl::event tbmv(const char *func_name, Func func, sycl::queue &queue, up auto a_ = reinterpret_cast(a); auto x_ = reinterpret_cast(x); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, + cublas_native_named_func(func_name, func, err, handle, get_cublas_fill_mode(upper_lower), get_cublas_operation(trans), get_cublas_diag_type(unit_diag), n, k, a_, lda, x_, incx); }); @@ -1559,7 +1559,7 @@ inline sycl::event tbsv(const char *func_name, Func func, sycl::queue &queue, up auto a_ = reinterpret_cast(a); auto x_ = reinterpret_cast(x); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, + cublas_native_named_func(func_name, func, err, handle, get_cublas_fill_mode(upper_lower), get_cublas_operation(trans), get_cublas_diag_type(unit_diag), n, k, a_, lda, x_, incx); }); @@ -1598,7 +1598,7 @@ inline sycl::event 
tpmv(const char *func_name, Func func, sycl::queue &queue, up auto a_ = reinterpret_cast(a); auto x_ = reinterpret_cast(x); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, + cublas_native_named_func(func_name, func, err, handle, get_cublas_fill_mode(upper_lower), get_cublas_operation(trans), get_cublas_diag_type(unit_diag), n, a_, x_, incx); }); @@ -1637,7 +1637,7 @@ inline sycl::event tpsv(const char *func_name, Func func, sycl::queue &queue, up auto a_ = reinterpret_cast(a); auto x_ = reinterpret_cast(x); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, + cublas_native_named_func(func_name, func, err, handle, get_cublas_fill_mode(upper_lower), get_cublas_operation(trans), get_cublas_diag_type(unit_diag), n, a_, x_, incx); }); @@ -1676,7 +1676,7 @@ inline sycl::event trmv(const char *func_name, Func func, sycl::queue &queue, up auto a_ = reinterpret_cast(a); auto x_ = reinterpret_cast(x); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, + cublas_native_named_func(func_name, func, err, handle, get_cublas_fill_mode(upper_lower), get_cublas_operation(trans), get_cublas_diag_type(unit_diag), n, a_, lda, x_, incx); }); @@ -1715,7 +1715,7 @@ inline sycl::event trsv(const char *func_name, Func func, sycl::queue &queue, up auto a_ = reinterpret_cast(a); auto x_ = reinterpret_cast(x); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, + cublas_native_named_func(func_name, func, err, handle, get_cublas_fill_mode(upper_lower), get_cublas_operation(trans), get_cublas_diag_type(unit_diag), n, a_, lda, x_, incx); }); diff --git a/src/blas/backends/cublas/cublas_level3.cpp b/src/blas/backends/cublas/cublas_level3.cpp index 5ea4e2152..be634a15c 100644 --- a/src/blas/backends/cublas/cublas_level3.cpp +++ b/src/blas/backends/cublas/cublas_level3.cpp @@ -47,7 +47,7 @@ inline void gemm(const char *func_name, Func func, sycl::queue &queue, transpose auto b_ = 
sc.get_mem(b_acc); auto c_ = sc.get_mem(c_acc); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_operation(transa), + cublas_native_named_func(func_name, func, err, handle, get_cublas_operation(transa), get_cublas_operation(transb), m, n, k, (cuDataType *)&alpha, a_, lda, b_, ldb, (cuDataType *)&beta, c_, ldc); }); @@ -94,10 +94,17 @@ inline void gemm_ex(DATATYPE_A DT_A, DATATYPE_B DT_B, DATATYPE_C DT_C, sycl::que auto b_ = sc.get_mem(b_acc); auto c_ = sc.get_mem(c_acc); cublasStatus_t err; +#ifdef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND /* native-command path uses the non-synchronizing macro, matching the other call sites */ + CUBLAS_ERROR_FUNC(cublasGemmEx, err, handle, get_cublas_operation(transa), + get_cublas_operation(transb), m, n, k, (cuDataType_C *)&alpha, + a_, DT_A, lda, b_, DT_B, ldb, (cuDataType_C *)&beta, c_, DT_C, + ldc, DT_C, CUBLAS_GEMM_DEFAULT); +#else /* host_task fallback keeps the synchronizing macro */ CUBLAS_ERROR_FUNC_SYNC(cublasGemmEx, err, handle, get_cublas_operation(transa), get_cublas_operation(transb), m, n, k, (cuDataType_C *)&alpha, a_, DT_A, lda, b_, DT_B, ldb, (cuDataType_C *)&beta, c_, DT_C, ldc, DT_C, CUBLAS_GEMM_DEFAULT); +#endif }); }); } @@ -139,7 +146,7 @@ inline void symm(const char *func_name, Func func, sycl::queue &queue, side left auto b_ = sc.get_mem(b_acc); auto c_ = sc.get_mem(c_acc); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_side_mode(left_right), + cublas_native_named_func(func_name, func, err, handle, get_cublas_side_mode(left_right), get_cublas_fill_mode(upper_lower), m, n, (cuDataType *)&alpha, a_, lda, b_, ldb, (cuDataType *)&beta, c_, ldc); }); @@ -178,7 +185,7 @@ inline void hemm(const char *func_name, Func func, sycl::queue &queue, side left auto b_ = sc.get_mem(b_acc); auto c_ = sc.get_mem(c_acc); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_side_mode(left_right), + cublas_native_named_func(func_name, func, err, handle, get_cublas_side_mode(left_right), get_cublas_fill_mode(upper_lower), m, n, (cuDataType *)&alpha, a_, lda, 
b_, ldb, (cuDataType *)&beta, c_, ldc); }); @@ -211,7 +218,7 @@ inline void syrk(const char *func_name, Func func, sycl::queue &queue, uplo uppe auto a_ = sc.get_mem(a_acc); auto c_ = sc.get_mem(c_acc); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, + cublas_native_named_func(func_name, func, err, handle, get_cublas_fill_mode(upper_lower), get_cublas_operation(trans), n, k, (cuDataType *)&alpha, a_, lda, (cuDataType *)&beta, c_, ldc); @@ -250,7 +257,7 @@ inline void herk(const char *func_name, Func func, sycl::queue &queue, uplo uppe auto a_ = sc.get_mem(a_acc); auto c_ = sc.get_mem(c_acc); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, + cublas_native_named_func(func_name, func, err, handle, get_cublas_fill_mode(upper_lower), get_cublas_operation(trans), n, k, (cuScalarType *)&alpha, a_, lda, (cuScalarType *)&beta, c_, ldc); @@ -288,7 +295,7 @@ inline void syr2k(const char *func_name, Func func, sycl::queue &queue, uplo upp auto b_ = sc.get_mem(b_acc); auto c_ = sc.get_mem(c_acc); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, + cublas_native_named_func(func_name, func, err, handle, get_cublas_fill_mode(upper_lower), get_cublas_operation(trans), n, k, (cuDataType *)&alpha, a_, lda, b_, ldb, (cuDataType *)&beta, c_, ldc); @@ -328,7 +335,7 @@ inline void her2k(const char *func_name, Func func, sycl::queue &queue, uplo upp auto b_ = sc.get_mem(b_acc); auto c_ = sc.get_mem(c_acc); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, + cublas_native_named_func(func_name, func, err, handle, get_cublas_fill_mode(upper_lower), get_cublas_operation(trans), n, k, (cuDataType *)&alpha, a_, lda, b_, ldb, (cuScalarType *)&beta, c_, ldc); @@ -368,7 +375,7 @@ inline void trmm(const char *func_name, Func func, sycl::queue &queue, side left auto a_ = sc.get_mem(a_acc); auto b_ = sc.get_mem(b_acc); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, 
handle, get_cublas_side_mode(left_right), + cublas_native_named_func(func_name, func, err, handle, get_cublas_side_mode(left_right), get_cublas_fill_mode(upper_lower), get_cublas_operation(trans), get_cublas_diag_type(unit_diag), m, n, (cuDataType *)&alpha, a_, lda, b_, ldb, b_, ldb); @@ -404,7 +411,7 @@ inline void trsm(const char *func_name, Func func, sycl::queue &queue, side left auto a_ = sc.get_mem(a_acc); auto b_ = sc.get_mem(b_acc); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_side_mode(left_right), + cublas_native_named_func(func_name, func, err, handle, get_cublas_side_mode(left_right), get_cublas_fill_mode(upper_lower), get_cublas_operation(trans), get_cublas_diag_type(unit_diag), m, n, (cuDataType *)&alpha, a_, lda, b_, ldb); @@ -446,7 +453,7 @@ inline sycl::event gemm(const char *func_name, Func func, sycl::queue &queue, tr auto b_ = reinterpret_cast(b); auto c_ = reinterpret_cast(c); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_operation(transa), + cublas_native_named_func(func_name, func, err, handle, get_cublas_operation(transa), get_cublas_operation(transb), m, n, k, (cuDataType *)&alpha, a_, lda, b_, ldb, (cuDataType *)&beta, c_, ldc); }); @@ -492,10 +499,17 @@ inline sycl::event gemm_ex_usm(DATATYPE_A DT_A, DATATYPE_B DT_B, DATATYPE_C DT_C auto b_ = reinterpret_cast(b); auto c_ = reinterpret_cast(c); cublasStatus_t err; +#ifdef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND /* native-command path uses the non-synchronizing macro, matching the other call sites */ + CUBLAS_ERROR_FUNC(cublasGemmEx, err, handle, get_cublas_operation(transa), + get_cublas_operation(transb), m, n, k, (cuDataType_C *)&alpha, + a_, DT_A, lda, b_, DT_B, ldb, (cuDataType_C *)&beta, c_, DT_C, + ldc, DT_C, CUBLAS_GEMM_DEFAULT); +#else /* host_task fallback keeps the synchronizing macro */ CUBLAS_ERROR_FUNC_SYNC(cublasGemmEx, err, handle, get_cublas_operation(transa), get_cublas_operation(transb), m, n, k, (cuDataType_C *)&alpha, a_, DT_A, lda, b_, DT_B, ldb, (cuDataType_C *)&beta, c_, DT_C, ldc, DT_C, CUBLAS_GEMM_DEFAULT); +#endif }); }); 
return done; @@ -541,7 +555,7 @@ inline sycl::event symm(const char *func_name, Func func, sycl::queue &queue, si auto b_ = reinterpret_cast(b); auto c_ = reinterpret_cast(c); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_side_mode(left_right), + cublas_native_named_func(func_name, func, err, handle, get_cublas_side_mode(left_right), get_cublas_fill_mode(upper_lower), m, n, (cuDataType *)&alpha, a_, lda, b_, ldb, (cuDataType *)&beta, c_, ldc); }); @@ -583,7 +597,7 @@ inline sycl::event hemm(const char *func_name, Func func, sycl::queue &queue, si auto b_ = reinterpret_cast(b); auto c_ = reinterpret_cast(c); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_side_mode(left_right), + cublas_native_named_func(func_name, func, err, handle, get_cublas_side_mode(left_right), get_cublas_fill_mode(upper_lower), m, n, (cuDataType *)&alpha, a_, lda, b_, ldb, (cuDataType *)&beta, c_, ldc); }); @@ -620,7 +634,7 @@ inline sycl::event syrk(const char *func_name, Func func, sycl::queue &queue, up auto a_ = reinterpret_cast(a); auto c_ = reinterpret_cast(c); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, + cublas_native_named_func(func_name, func, err, handle, get_cublas_fill_mode(upper_lower), get_cublas_operation(trans), n, k, (cuDataType *)&alpha, a_, lda, (cuDataType *)&beta, c_, ldc); @@ -662,7 +676,7 @@ inline sycl::event herk(const char *func_name, Func func, sycl::queue &queue, up auto a_ = reinterpret_cast(a); auto c_ = reinterpret_cast(c); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, + cublas_native_named_func(func_name, func, err, handle, get_cublas_fill_mode(upper_lower), get_cublas_operation(trans), n, k, (cuScalarType *)&alpha, a_, lda, (cuScalarType *)&beta, c_, ldc); @@ -703,7 +717,7 @@ inline sycl::event syr2k(const char *func_name, Func func, sycl::queue &queue, u auto b_ = reinterpret_cast(b); auto c_ = reinterpret_cast(c); 
cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, + cublas_native_named_func(func_name, func, err, handle, get_cublas_fill_mode(upper_lower), get_cublas_operation(trans), n, k, (cuDataType *)&alpha, a_, lda, b_, ldb, (cuDataType *)&beta, c_, ldc); @@ -747,7 +761,7 @@ inline sycl::event her2k(const char *func_name, Func func, sycl::queue &queue, u auto b_ = reinterpret_cast(b); auto c_ = reinterpret_cast(c); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, + cublas_native_named_func(func_name, func, err, handle, get_cublas_fill_mode(upper_lower), get_cublas_operation(trans), n, k, (cuDataType *)&alpha, a_, lda, b_, ldb, (cuScalarType *)&beta, c_, ldc); @@ -791,7 +805,7 @@ inline sycl::event trmm(const char *func_name, Func func, sycl::queue &queue, si auto a_ = reinterpret_cast(a); auto b_ = reinterpret_cast(b); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_side_mode(left_right), + cublas_native_named_func(func_name, func, err, handle, get_cublas_side_mode(left_right), get_cublas_fill_mode(upper_lower), get_cublas_operation(trans), get_cublas_diag_type(unit_diag), m, n, (cuDataType *)&alpha, a_, lda, b_, ldb, b_, ldb); @@ -831,7 +845,7 @@ inline sycl::event trsm(const char *func_name, Func func, sycl::queue &queue, si auto a_ = reinterpret_cast(a); auto b_ = reinterpret_cast(b); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_side_mode(left_right), + cublas_native_named_func(func_name, func, err, handle, get_cublas_side_mode(left_right), get_cublas_fill_mode(upper_lower), get_cublas_operation(trans), get_cublas_diag_type(unit_diag), m, n, (cuDataType *)&alpha, a_, lda, b_, ldb); diff --git a/src/blas/backends/cublas/cublas_task.hpp b/src/blas/backends/cublas/cublas_task.hpp index a486aafee..4fbdfdda2 100644 --- a/src/blas/backends/cublas/cublas_task.hpp +++ b/src/blas/backends/cublas/cublas_task.hpp @@ -67,7 +67,11 @@ static inline 
void host_task_internal(H &cgh, sycl::queue queue, F f) { #else template static inline void host_task_internal(H &cgh, sycl::queue queue, F f) { +#ifdef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND + cgh.ext_codeplay_enqueue_native_command([f, queue](sycl::interop_handle ih){ +#else cgh.host_task([f, queue](sycl::interop_handle ih) { +#endif auto sc = CublasScopedContextHandler(queue, ih); f(sc); }); diff --git a/src/lapack/backends/cusolver/cusolver_batch.cpp b/src/lapack/backends/cusolver/cusolver_batch.cpp index 59fa47f84..f4017f873 100644 --- a/src/lapack/backends/cusolver/cusolver_batch.cpp +++ b/src/lapack/backends/cusolver/cusolver_batch.cpp @@ -53,7 +53,7 @@ inline void geqrf_batch(const char *func_name, Func func, sycl::queue &queue, st // Uses scratch so sync between each cuSolver call for (int64_t i = 0; i < batch_size; ++i) { - CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, m, n, a_ + stride_a * i, + cusolver_native_named_func(func_name, func, err, handle, m, n, a_ + stride_a * i, lda, tau_ + stride_tau * i, scratch_, scratchpad_size, nullptr); } @@ -137,8 +137,8 @@ inline void getri_batch(const char *func_name, Func func, sycl::queue &queue, st sizeof(T *) * batch_size); auto **scratch_dev_ = reinterpret_cast(scratch_dev); - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, cublas_handle, n, a_dev_, lda, ipiv32_, - scratch_dev_, lda, info_, batch_size) + blas::cublas::cublas_native_named_func(func_name, func, err, cublas_handle, n, a_dev_, lda, ipiv32_, + scratch_dev_, lda, info_, batch_size); free(a_batched); free(scratch_batched); @@ -227,7 +227,9 @@ inline void getrs_batch(const char *func_name, Func func, sycl::queue &queue, nrhs, a_ + stride_a * i, lda, ipiv_ + stride_ipiv * i, b_ + stride_b * i, ldb, nullptr); } +#ifndef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND CUSOLVER_SYNC(err, handle) +#endif }); }); } @@ -283,7 +285,7 @@ inline void getrf_batch(const char *func_name, Func func, sycl::queue &queue, st // Uses scratch so sync between each 
cuSolver call for (std::int64_t i = 0; i < batch_size; ++i) { - CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, m, n, a_ + stride_a * i, + cusolver_native_named_func(func_name, func, err, handle, m, n, a_ + stride_a * i, lda, scratch_, ipiv_ + stride_ipiv * i, devInfo_ + i); } }); @@ -340,7 +342,7 @@ inline void orgqr_batch(const char *func_name, Func func, sycl::queue &queue, st // Uses scratch so sync between each cuSolver call for (int64_t i = 0; i < batch_size; ++i) { - CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, m, n, k, a_ + stride_a * i, + cusolver_native_named_func(func_name, func, err, handle, m, n, k, a_ + stride_a * i, lda, tau_ + stride_tau * i, scratch_, scratchpad_size, nullptr); } @@ -388,7 +390,7 @@ inline void potrf_batch(const char *func_name, Func func, sycl::queue &queue, auto **a_dev_ = reinterpret_cast(a_dev); - CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_fill_mode(uplo), + cusolver_native_named_func(func_name, func, err, handle, get_cublas_fill_mode(uplo), (int)n, a_dev_, (int)lda, nullptr, (int)batch_size); free(a_batched); @@ -452,7 +454,7 @@ inline void potrs_batch(const char *func_name, Func func, sycl::queue &queue, auto **a_dev_ = reinterpret_cast(a_dev); auto **b_dev_ = reinterpret_cast(b_dev); - CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_fill_mode(uplo), + cusolver_native_named_func(func_name, func, err, handle, get_cublas_fill_mode(uplo), (int)n, (int)nrhs, a_dev_, (int)lda, b_dev_, ldb, nullptr, (int)batch_size); @@ -506,7 +508,7 @@ inline void ungqr_batch(const char *func_name, Func func, sycl::queue &queue, st // Uses scratch so sync between each cuSolver call for (int64_t i = 0; i < batch_size; ++i) { - CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, m, n, k, a_ + stride_a * i, + cusolver_native_named_func(func_name, func, err, handle, m, n, k, a_ + stride_a * i, lda, tau_ + stride_tau * i, scratch_, scratchpad_size, nullptr); } @@ -551,7 +553,7 @@ 
inline sycl::event geqrf_batch(const char *func_name, Func func, sycl::queue &qu // Uses scratch so sync between each cuSolver call for (int64_t i = 0; i < batch_size; ++i) { - CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, m, n, a_ + stride_a * i, + cusolver_native_named_func(func_name, func, err, handle, m, n, a_ + stride_a * i, lda, tau_ + stride_tau * i, scratch_, scratchpad_size, nullptr); } @@ -605,7 +607,7 @@ inline sycl::event geqrf_batch(const char *func_name, Func func, sycl::queue &qu for (int64_t group_id = 0; group_id < group_count; ++group_id) { for (int64_t local_id = 0; local_id < group_sizes[group_id]; ++local_id, ++global_id) { - CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, m[group_id], + cusolver_native_named_func(func_name, func, err, handle, m[group_id], n[group_id], a_[global_id], lda[group_id], tau_[global_id], scratch_, scratchpad_size, nullptr); } @@ -661,7 +663,7 @@ inline sycl::event getrf_batch(const char *func_name, Func func, sycl::queue &qu // Uses scratch so sync between each cuSolver call for (int64_t i = 0; i < batch_size; ++i) { - CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, m, n, a_ + stride_a * i, + cusolver_native_named_func(func_name, func, err, handle, m, n, a_ + stride_a * i, lda, scratchpad_, ipiv_ + stride_ipiv * i, devInfo_ + i); } }); @@ -744,7 +746,7 @@ inline sycl::event getrf_batch(const char *func_name, Func func, sycl::queue &qu for (int64_t group_id = 0; group_id < group_count; ++group_id) { for (int64_t local_id = 0; local_id < group_sizes[group_id]; ++local_id, ++global_id) { - CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, m[group_id], + cusolver_native_named_func(func_name, func, err, handle, m[group_id], n[group_id], a_[global_id], lda[group_id], scratch_, ipiv32[global_id], devInfo + global_id); } @@ -857,8 +859,8 @@ sycl::event getri_batch(const char *func_name, Func func, sycl::queue &queue, st sizeof(T *) * batch_size); auto **scratch_dev_ = 
reinterpret_cast(scratch_dev); - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, cublas_handle, n, a_dev_, lda, ipiv32, - scratch_dev_, lda, devInfo, batch_size) + blas::cublas::cublas_native_named_func(func_name, func, err, cublas_handle, n, a_dev_, lda, ipiv32, + scratch_dev_, lda, devInfo, batch_size); free(a_batched); free(scratch_batched); @@ -972,7 +974,9 @@ inline sycl::event getrs_batch(const char *func_name, Func func, sycl::queue &qu nrhs, a_ + stride_a * i, lda, ipiv_ + stride_ipiv * i, b_ + stride_b * i, ldb, nullptr); } +#ifndef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND CUSOLVER_SYNC(err, handle) +#endif sycl::free(ipiv32, queue); }); @@ -1062,7 +1066,9 @@ inline sycl::event getrs_batch(const char *func_name, Func func, sycl::queue &qu ipiv32[global_id], b_[global_id], ldb[group_id], nullptr); } } +#ifndef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND CUSOLVER_SYNC(err, handle) +#endif for (int64_t i = 0; i < batch_size; ++i) sycl::free(ipiv32[i], queue); @@ -1112,7 +1118,7 @@ inline sycl::event orgqr_batch(const char *func_name, Func func, sycl::queue &qu // Uses scratch so sync between each cuSolver call for (int64_t i = 0; i < batch_size; ++i) { - CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, m, n, k, a_ + stride_a * i, + cusolver_native_named_func(func_name, func, err, handle, m, n, k, a_ + stride_a * i, lda, tau_ + stride_tau * i, scratch_, scratchpad_size, nullptr); } @@ -1165,7 +1171,7 @@ inline sycl::event orgqr_batch(const char *func_name, Func func, sycl::queue &qu for (int64_t group_id = 0; group_id < group_count; ++group_id) { for (int64_t local_id = 0; local_id < group_sizes[group_id]; ++local_id, ++global_id) { - CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, m[group_id], + cusolver_native_named_func(func_name, func, err, handle, m[group_id], n[group_id], k[group_id], a_[global_id], lda[group_id], tau_[global_id], scratch_, scratchpad_size, nullptr); @@ -1219,7 +1225,7 @@ inline sycl::event potrf_batch(const char *func_name, 
Func func, sycl::queue &qu auto **a_dev_ = reinterpret_cast(a_dev); - CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_fill_mode(uplo), + cusolver_native_named_func(func_name, func, err, handle, get_cublas_fill_mode(uplo), (int)n, a_dev_, (int)lda, nullptr, (int)batch_size); free(a_batched); @@ -1281,7 +1287,9 @@ inline sycl::event potrf_batch(const char *func_name, Func func, sycl::queue &qu (int)group_sizes[i]); offset += group_sizes[i]; } +#ifndef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND CUSOLVER_SYNC(err, handle) +#endif cuMemFree(a_dev); }); @@ -1342,7 +1350,7 @@ inline sycl::event potrs_batch(const char *func_name, Func func, sycl::queue &qu auto **a_dev_ = reinterpret_cast(a_dev); auto **b_dev_ = reinterpret_cast(b_dev); - CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_fill_mode(uplo), + cusolver_native_named_func(func_name, func, err, handle, get_cublas_fill_mode(uplo), (int)n, (int)nrhs, a_dev_, (int)lda, b_dev_, ldb, nullptr, (int)batch_size); @@ -1421,7 +1429,9 @@ inline sycl::event potrs_batch(const char *func_name, Func func, sycl::queue &qu b_ + offset, (int)ldb[i], info_, (int)group_sizes[i]); offset += group_sizes[i]; } +#ifndef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND CUSOLVER_SYNC(err, handle) +#endif }); }); return done; @@ -1467,7 +1477,7 @@ inline sycl::event ungqr_batch(const char *func_name, Func func, sycl::queue &qu // Uses scratch so sync between each cuSolver call for (int64_t i = 0; i < batch_size; ++i) { - CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, m, n, k, a_ + stride_a * i, + cusolver_native_named_func(func_name, func, err, handle, m, n, k, a_ + stride_a * i, lda, tau_ + stride_tau * i, scratch_, scratchpad_size, nullptr); } @@ -1520,7 +1530,7 @@ inline sycl::event ungqr_batch(const char *func_name, Func func, sycl::queue &qu for (int64_t group_id = 0; group_id < group_count; ++group_id) { for (int64_t local_id = 0; local_id < group_sizes[group_id]; ++local_id, ++global_id) { - 
CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, m[group_id], + cusolver_native_named_func(func_name, func, err, handle, m[group_id], n[group_id], k[group_id], a_[global_id], lda[group_id], tau_[global_id], scratch_, scratchpad_size, nullptr); diff --git a/src/lapack/backends/cusolver/cusolver_helper.hpp b/src/lapack/backends/cusolver/cusolver_helper.hpp index e10f56b36..954d41246 100644 --- a/src/lapack/backends/cusolver/cusolver_helper.hpp +++ b/src/lapack/backends/cusolver/cusolver_helper.hpp @@ -200,6 +200,17 @@ class cuda_error : virtual public std::runtime_error { } \ CUSOLVER_SYNC(err, handle) +template +inline void cusolver_native_named_func(const char *func_name, Func func, + cusolverStatus_t err, + cusolverDnHandle_t handle, Types... args){ +#ifdef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND + CUSOLVER_ERROR_FUNC_T(func_name, func, err, handle, args...) +#else + CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, args...) +#endif +}; + inline cusolverEigType_t get_cusolver_itype(std::int64_t itype) { switch (itype) { case 1: return CUSOLVER_EIG_TYPE_1; diff --git a/src/lapack/backends/cusolver/cusolver_lapack.cpp b/src/lapack/backends/cusolver/cusolver_lapack.cpp index 0c7aaefc8..c8190f50d 100644 --- a/src/lapack/backends/cusolver/cusolver_lapack.cpp +++ b/src/lapack/backends/cusolver/cusolver_lapack.cpp @@ -57,7 +57,7 @@ inline void gebrd(const char *func_name, Func func, sycl::queue &queue, std::int auto taup_ = sc.get_mem(taup_acc); auto scratch_ = sc.get_mem(scratch_acc); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, m, n, a_, lda, d_, e_, tauq_, + cusolver_native_named_func(func_name, func, err, handle, m, n, a_, lda, d_, e_, tauq_, taup_, scratch_, scratchpad_size, nullptr); }); }); @@ -117,7 +117,7 @@ inline void geqrf(const char *func_name, Func func, sycl::queue &queue, std::int auto tau_ = sc.get_mem(tau_acc); auto scratch_ = sc.get_mem(scratch_acc); cusolverStatus_t err; - 
CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, m, n, a_, lda, tau_, scratch_, + cusolver_native_named_func(func_name, func, err, handle, m, n, a_, lda, tau_, scratch_, scratchpad_size, nullptr); }); }); @@ -164,7 +164,7 @@ void getrf(const char *func_name, Func func, sycl::queue &queue, std::int64_t m, auto devInfo_ = sc.get_mem(devInfo_acc); auto scratch_ = sc.get_mem(scratch_acc); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, m, n, a_, lda, scratch_, + cusolver_native_named_func(func_name, func, err, handle, m, n, a_, lda, scratch_, ipiv32_, devInfo_); }); }); @@ -243,7 +243,7 @@ inline void getrs(const char *func_name, Func func, sycl::queue &queue, auto ipiv_ = sc.get_mem(ipiv_acc); auto b_ = sc.get_mem(b_acc); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_operation(trans), n, + cusolver_native_named_func(func_name, func, err, handle, get_cublas_operation(trans), n, nrhs, a_, lda, ipiv_, b_, ldb, nullptr); }); }); @@ -292,7 +292,7 @@ inline void gesvd(const char *func_name, Func func, sycl::queue &queue, oneapi:: auto scratch_ = sc.get_mem(scratch_acc); cusolverStatus_t err; // rwork is set to nullptr. If set it is filled with information from the superdiagonal. 
- CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cusolver_jobsvd(jobu), + cusolver_native_named_func(func_name, func, err, handle, get_cusolver_jobsvd(jobu), get_cusolver_jobsvd(jobvt), m, n, a_, lda, s_, u_, ldu, vt_, ldvt, scratch_, scratchpad_size, nullptr, devInfo_); }); @@ -338,7 +338,7 @@ inline void heevd(const char *func_name, Func func, sycl::queue &queue, oneapi:: auto devInfo_ = sc.get_mem(devInfo_acc); auto scratch_ = sc.get_mem(scratch_acc); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cusolver_job(jobz), + cusolver_native_named_func(func_name, func, err, handle, get_cusolver_job(jobz), get_cublas_fill_mode(uplo), n, a_, lda, w_, scratch_, scratchpad_size, devInfo_); }); @@ -383,7 +383,7 @@ inline void hegvd(const char *func_name, Func func, sycl::queue &queue, std::int auto devInfo_ = sc.get_mem(devInfo_acc); auto scratch_ = sc.get_mem(scratch_acc); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cusolver_itype(itype), + cusolver_native_named_func(func_name, func, err, handle, get_cusolver_itype(itype), get_cusolver_job(jobz), get_cublas_fill_mode(uplo), n, a_, lda, b_, ldb, w_, scratch_, scratchpad_size, devInfo_); }); @@ -430,7 +430,7 @@ inline void hetrd(const char *func_name, Func func, sycl::queue &queue, oneapi:: auto devInfo_ = sc.get_mem(devInfo_acc); auto scratch_ = sc.get_mem(scratch_acc); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_fill_mode(uplo), n, + cusolver_native_named_func(func_name, func, err, handle, get_cublas_fill_mode(uplo), n, a_, lda, d_, e_, tau_, scratch_, scratchpad_size, devInfo_); }); }); @@ -480,7 +480,7 @@ inline void orgbr(const char *func_name, Func func, sycl::queue &queue, oneapi:: auto tau_ = sc.get_mem(tau_acc); auto scratch_ = sc.get_mem(scratch_acc); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_generate(vec), m, n, + 
cusolver_native_named_func(func_name, func, err, handle, get_cublas_generate(vec), m, n, k, a_, lda, tau_, scratch_, scratchpad_size, nullptr); }); }); @@ -515,7 +515,7 @@ inline void orgqr(const char *func_name, Func func, sycl::queue &queue, std::int auto tau_ = sc.get_mem(tau_acc); auto scratch_ = sc.get_mem(scratch_acc); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, m, n, k, a_, lda, tau_, + cusolver_native_named_func(func_name, func, err, handle, m, n, k, a_, lda, tau_, scratch_, scratchpad_size, nullptr); }); }); @@ -550,7 +550,7 @@ inline void orgtr(const char *func_name, Func func, sycl::queue &queue, oneapi:: auto tau_ = sc.get_mem(tau_acc); auto scratch_ = sc.get_mem(scratch_acc); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_fill_mode(uplo), n, + cusolver_native_named_func(func_name, func, err, handle, get_cublas_fill_mode(uplo), n, a_, lda, tau_, scratch_, scratchpad_size, nullptr); }); }); @@ -589,7 +589,7 @@ inline void ormtr(const char *func_name, Func func, sycl::queue &queue, oneapi:: auto c_ = sc.get_mem(c_acc); auto scratch_ = sc.get_mem(scratch_acc); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_side_mode(side), + cusolver_native_named_func(func_name, func, err, handle, get_cublas_side_mode(side), get_cublas_fill_mode(uplo), get_cublas_operation(trans), m, n, a_, lda, tau_, c_, ldc, scratch_, scratchpad_size, nullptr); @@ -644,7 +644,7 @@ inline void ormqr(const char *func_name, Func func, sycl::queue &queue, oneapi:: auto c_ = sc.get_mem(c_acc); auto scratch_ = sc.get_mem(scratch_acc); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_side_mode(side), + cusolver_native_named_func(func_name, func, err, handle, get_cublas_side_mode(side), get_cublas_operation(trans), m, n, k, a_, lda, tau_, c_, ldc, scratch_, scratchpad_size, nullptr); }); @@ -682,7 +682,7 @@ inline void 
potrf(const char *func_name, Func func, sycl::queue &queue, oneapi:: auto devInfo_ = sc.get_mem(devInfo_acc); auto scratch_ = sc.get_mem(scratch_acc); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_fill_mode(uplo), n, + cusolver_native_named_func(func_name, func, err, handle, get_cublas_fill_mode(uplo), n, a_, lda, scratch_, scratchpad_size, devInfo_); }); }); @@ -720,7 +720,7 @@ inline void potri(const char *func_name, Func func, sycl::queue &queue, oneapi:: auto devInfo_ = sc.get_mem(devInfo_acc); auto scratch_ = sc.get_mem(scratch_acc); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_fill_mode(uplo), n, + cusolver_native_named_func(func_name, func, err, handle, get_cublas_fill_mode(uplo), n, a_, lda, scratch_, scratchpad_size, devInfo_); }); }); @@ -757,7 +757,7 @@ inline void potrs(const char *func_name, Func func, sycl::queue &queue, oneapi:: auto a_ = sc.get_mem(a_acc); auto b_ = sc.get_mem(b_acc); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_fill_mode(uplo), n, + cusolver_native_named_func(func_name, func, err, handle, get_cublas_fill_mode(uplo), n, nrhs, a_, lda, b_, ldb, nullptr); }); }); @@ -797,7 +797,7 @@ inline void syevd(const char *func_name, Func func, sycl::queue &queue, oneapi:: auto devInfo_ = sc.get_mem(devInfo_acc); auto scratch_ = sc.get_mem(scratch_acc); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cusolver_job(jobz), + cusolver_native_named_func(func_name, func, err, handle, get_cusolver_job(jobz), get_cublas_fill_mode(uplo), n, a_, lda, w_, scratch_, scratchpad_size, devInfo_); }); @@ -840,7 +840,7 @@ inline void sygvd(const char *func_name, Func func, sycl::queue &queue, std::int auto devInfo_ = sc.get_mem(devInfo_acc); auto scratch_ = sc.get_mem(scratch_acc); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cusolver_itype(itype), 
+ cusolver_native_named_func(func_name, func, err, handle, get_cusolver_itype(itype), get_cusolver_job(jobz), get_cublas_fill_mode(uplo), n, a_, lda, b_, ldb, w_, scratch_, scratchpad_size, devInfo_); }); @@ -886,7 +886,7 @@ inline void sytrd(const char *func_name, Func func, sycl::queue &queue, oneapi:: auto devInfo_ = sc.get_mem(devInfo_acc); auto scratch_ = sc.get_mem(scratch_acc); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_fill_mode(uplo), n, + cusolver_native_named_func(func_name, func, err, handle, get_cublas_fill_mode(uplo), n, a_, lda, d_, e_, tau_, scratch_, scratchpad_size, devInfo_); }); }); @@ -934,7 +934,7 @@ inline void sytrf(const char *func_name, Func func, sycl::queue &queue, oneapi:: auto devInfo_ = sc.get_mem(devInfo_acc); auto scratch_ = sc.get_mem(scratch_acc); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_fill_mode(uplo), n, + cusolver_native_named_func(func_name, func, err, handle, get_cublas_fill_mode(uplo), n, a_, lda, ipiv32_, scratch_, scratchpad_size, devInfo_); }); }); @@ -1009,7 +1009,7 @@ inline void ungbr(const char *func_name, Func func, sycl::queue &queue, oneapi:: auto tau_ = sc.get_mem(tau_acc); auto scratch_ = sc.get_mem(scratch_acc); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_generate(vec), m, n, + cusolver_native_named_func(func_name, func, err, handle, get_cublas_generate(vec), m, n, k, a_, lda, tau_, scratch_, scratchpad_size, nullptr); }); }); @@ -1044,7 +1044,7 @@ inline void ungqr(const char *func_name, Func func, sycl::queue &queue, std::int auto tau_ = sc.get_mem(tau_acc); auto scratch_ = sc.get_mem(scratch_acc); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, m, n, k, a_, lda, tau_, + cusolver_native_named_func(func_name, func, err, handle, m, n, k, a_, lda, tau_, scratch_, scratchpad_size, nullptr); }); }); @@ -1079,7 +1079,7 @@ inline void 
ungtr(const char *func_name, Func func, sycl::queue &queue, oneapi:: auto tau_ = sc.get_mem(tau_acc); auto scratch_ = sc.get_mem(scratch_acc); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_fill_mode(uplo), n, + cusolver_native_named_func(func_name, func, err, handle, get_cublas_fill_mode(uplo), n, a_, lda, tau_, scratch_, scratchpad_size, nullptr); }); }); @@ -1132,7 +1132,7 @@ inline void unmqr(const char *func_name, Func func, sycl::queue &queue, oneapi:: auto c_ = sc.get_mem(c_acc); auto scratch_ = sc.get_mem(scratch_acc); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_side_mode(side), + cusolver_native_named_func(func_name, func, err, handle, get_cublas_side_mode(side), get_cublas_operation(trans), m, n, k, a_, lda, tau_, c_, ldc, scratch_, scratchpad_size, nullptr); }); @@ -1173,7 +1173,7 @@ inline void unmtr(const char *func_name, Func func, sycl::queue &queue, oneapi:: auto c_ = sc.get_mem(c_acc); auto scratch_ = sc.get_mem(scratch_acc); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_side_mode(side), + cusolver_native_named_func(func_name, func, err, handle, get_cublas_side_mode(side), get_cublas_fill_mode(uplo), get_cublas_operation(trans), m, n, a_, lda, tau_, c_, ldc, scratch_, scratchpad_size, nullptr); @@ -1224,7 +1224,7 @@ inline sycl::event gebrd(const char *func_name, Func func, sycl::queue &queue, s auto taup_ = reinterpret_cast(taup); auto scratch_ = reinterpret_cast(scratchpad); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, m, n, a_, lda, d_, e_, tauq_, + cusolver_native_named_func(func_name, func, err, handle, m, n, a_, lda, d_, e_, tauq_, taup_, scratch_, scratchpad_size, nullptr); }); }); @@ -1286,7 +1286,7 @@ inline sycl::event geqrf(const char *func_name, Func func, sycl::queue &queue, s auto tau_ = reinterpret_cast(tau); auto scratch_ = reinterpret_cast(scratchpad); 
cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, m, n, a_, lda, tau_, scratch_, + cusolver_native_named_func(func_name, func, err, handle, m, n, a_, lda, tau_, scratch_, scratchpad_size, nullptr); }); }); @@ -1335,7 +1335,7 @@ inline sycl::event getrf(const char *func_name, Func func, sycl::queue &queue, s auto scratch_ = reinterpret_cast(scratchpad); auto ipiv_ = reinterpret_cast(ipiv32); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, m, n, a_, lda, scratch_, ipiv_, + cusolver_native_named_func(func_name, func, err, handle, m, n, a_, lda, scratch_, ipiv_, devInfo_); }); }); @@ -1422,7 +1422,7 @@ inline sycl::event getrs(const char *func_name, Func func, sycl::queue &queue, auto ipiv_ = reinterpret_cast(ipiv32); auto b_ = reinterpret_cast(b); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_operation(trans), n, + cusolver_native_named_func(func_name, func, err, handle, get_cublas_operation(trans), n, nrhs, a_, lda, ipiv_, b_, ldb, nullptr); }); }); @@ -1475,7 +1475,7 @@ inline sycl::event gesvd(const char *func_name, Func func, sycl::queue &queue, auto scratch_ = reinterpret_cast(scratchpad); cusolverStatus_t err; // rwork is set to nullptr. If set it is filled with information from the superdiagonal. 
- CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cusolver_jobsvd(jobu), + cusolver_native_named_func(func_name, func, err, handle, get_cusolver_jobsvd(jobu), get_cusolver_jobsvd(jobvt), m, n, a_, lda, s_, u_, ldu, vt_, ldvt, scratch_, scratchpad_size, nullptr, devInfo_); }); @@ -1523,7 +1523,7 @@ inline sycl::event heevd(const char *func_name, Func func, sycl::queue &queue, auto devInfo_ = reinterpret_cast(devInfo); auto scratch_ = reinterpret_cast(scratchpad); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cusolver_job(jobz), + cusolver_native_named_func(func_name, func, err, handle, get_cusolver_job(jobz), get_cublas_fill_mode(uplo), n, a_, lda, w_, scratch_, scratchpad_size, devInfo_); }); @@ -1570,7 +1570,7 @@ inline sycl::event hegvd(const char *func_name, Func func, sycl::queue &queue, s auto devInfo_ = reinterpret_cast(devInfo); auto scratch_ = reinterpret_cast(scratchpad); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cusolver_itype(itype), + cusolver_native_named_func(func_name, func, err, handle, get_cusolver_itype(itype), get_cusolver_job(jobz), get_cublas_fill_mode(uplo), n, a_, lda, b_, ldb, w_, scratch_, scratchpad_size, devInfo); }); @@ -1618,7 +1618,7 @@ inline sycl::event hetrd(const char *func_name, Func func, sycl::queue &queue, auto devInfo_ = reinterpret_cast(devInfo); auto scratch_ = reinterpret_cast(scratchpad); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_fill_mode(uplo), n, + cusolver_native_named_func(func_name, func, err, handle, get_cublas_fill_mode(uplo), n, a_, lda, d_, e_, tau_, scratch_, scratchpad_size, devInfo_); }); }); @@ -1673,7 +1673,7 @@ inline sycl::event orgbr(const char *func_name, Func func, sycl::queue &queue, auto tau_ = reinterpret_cast(tau); auto scratch_ = reinterpret_cast(scratchpad); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, 
get_cublas_generate(vec), m, n, + cusolver_native_named_func(func_name, func, err, handle, get_cublas_generate(vec), m, n, k, a_, lda, tau_, scratch_, scratchpad_size, nullptr); }); }); @@ -1712,7 +1712,7 @@ inline sycl::event orgqr(const char *func_name, Func func, sycl::queue &queue, s auto tau_ = reinterpret_cast(tau); auto scratch_ = reinterpret_cast(scratchpad); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, m, n, k, a_, lda, tau_, + cusolver_native_named_func(func_name, func, err, handle, m, n, k, a_, lda, tau_, scratch_, scratchpad_size, nullptr); }); }); @@ -1750,7 +1750,7 @@ inline sycl::event orgtr(const char *func_name, Func func, sycl::queue &queue, auto tau_ = reinterpret_cast(tau); auto scratch_ = reinterpret_cast(scratchpad); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_fill_mode(uplo), n, + cusolver_native_named_func(func_name, func, err, handle, get_cublas_fill_mode(uplo), n, a_, lda, tau_, scratch_, scratchpad_size, nullptr); }); }); @@ -1791,7 +1791,7 @@ inline sycl::event ormtr(const char *func_name, Func func, sycl::queue &queue, auto c_ = reinterpret_cast(c); auto scratch_ = reinterpret_cast(scratchpad); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_side_mode(side), + cusolver_native_named_func(func_name, func, err, handle, get_cublas_side_mode(side), get_cublas_fill_mode(uplo), get_cublas_operation(trans), m, n, a_, lda, tau_, c_, ldc, scratch_, scratchpad_size, nullptr); @@ -1848,7 +1848,7 @@ inline sycl::event ormqr(const char *func_name, Func func, sycl::queue &queue, auto c_ = reinterpret_cast(c); auto scratch_ = reinterpret_cast(scratchpad); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_side_mode(side), + cusolver_native_named_func(func_name, func, err, handle, get_cublas_side_mode(side), get_cublas_operation(trans), m, n, k, a_, lda, tau_, c_, ldc, scratch_, 
scratchpad_size, nullptr); }); @@ -1890,7 +1890,7 @@ inline sycl::event potrf(const char *func_name, Func func, sycl::queue &queue, auto devInfo_ = reinterpret_cast(devInfo); auto scratch_ = reinterpret_cast(scratchpad); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_fill_mode(uplo), n, + cusolver_native_named_func(func_name, func, err, handle, get_cublas_fill_mode(uplo), n, a_, lda, scratch_, scratchpad_size, devInfo_); }); }); @@ -1933,7 +1933,7 @@ inline sycl::event potri(const char *func_name, Func func, sycl::queue &queue, auto scratch_ = reinterpret_cast(scratchpad); auto devInfo_ = reinterpret_cast(devInfo); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_fill_mode(uplo), n, + cusolver_native_named_func(func_name, func, err, handle, get_cublas_fill_mode(uplo), n, a_, lda, scratch_, scratchpad_size, devInfo_); }); }); @@ -1976,7 +1976,7 @@ inline sycl::event potrs(const char *func_name, Func func, sycl::queue &queue, auto a_ = reinterpret_cast(a); auto b_ = reinterpret_cast(b); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_fill_mode(uplo), n, + cusolver_native_named_func(func_name, func, err, handle, get_cublas_fill_mode(uplo), n, nrhs, a_, lda, b_, ldb, nullptr); }); }); @@ -2019,7 +2019,7 @@ inline sycl::event syevd(const char *func_name, Func func, sycl::queue &queue, auto scratch_ = reinterpret_cast(scratchpad); auto devInfo_ = reinterpret_cast(devInfo); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cusolver_job(jobz), + cusolver_native_named_func(func_name, func, err, handle, get_cusolver_job(jobz), get_cublas_fill_mode(uplo), n, a_, lda, w_, scratch_, scratchpad_size, devInfo_); }); @@ -2065,7 +2065,7 @@ inline sycl::event sygvd(const char *func_name, Func func, sycl::queue &queue, s auto devInfo_ = reinterpret_cast(devInfo); auto scratch_ = reinterpret_cast(scratchpad); 
cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cusolver_itype(itype), + cusolver_native_named_func(func_name, func, err, handle, get_cusolver_itype(itype), get_cusolver_job(jobz), get_cublas_fill_mode(uplo), n, a_, lda, b_, ldb, w_, scratch_, scratchpad_size, devInfo); }); @@ -2111,7 +2111,7 @@ inline sycl::event sytrd(const char *func_name, Func func, sycl::queue &queue, auto devInfo_ = reinterpret_cast(devInfo); auto scratch_ = reinterpret_cast(scratchpad); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_fill_mode(uplo), n, + cusolver_native_named_func(func_name, func, err, handle, get_cublas_fill_mode(uplo), n, a_, lda, d_, e_, tau_, scratch_, scratchpad_size, devInfo_); }); }); @@ -2161,7 +2161,7 @@ inline sycl::event sytrf(const char *func_name, Func func, sycl::queue &queue, auto ipiv_ = reinterpret_cast(ipiv32); auto devInfo_ = reinterpret_cast(devInfo); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_fill_mode(uplo), n, + cusolver_native_named_func(func_name, func, err, handle, get_cublas_fill_mode(uplo), n, a_, lda, ipiv_, scratch_, scratchpad_size, devInfo_); }); }); @@ -2245,7 +2245,7 @@ inline sycl::event ungbr(const char *func_name, Func func, sycl::queue &queue, auto tau_ = reinterpret_cast(tau); auto scratch_ = reinterpret_cast(scratchpad); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_generate(vec), m, n, + cusolver_native_named_func(func_name, func, err, handle, get_cublas_generate(vec), m, n, k, a_, lda, tau_, scratch_, scratchpad_size, nullptr); }); }); @@ -2284,7 +2284,7 @@ inline sycl::event ungqr(const char *func_name, Func func, sycl::queue &queue, s auto tau_ = reinterpret_cast(tau); auto scratch_ = reinterpret_cast(scratchpad); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, m, n, k, a_, lda, tau_, + cusolver_native_named_func(func_name, 
func, err, handle, m, n, k, a_, lda, tau_, scratch_, scratchpad_size, nullptr); }); }); @@ -2322,7 +2322,7 @@ inline sycl::event ungtr(const char *func_name, Func func, sycl::queue &queue, auto tau_ = reinterpret_cast(tau); auto scratch_ = reinterpret_cast(scratchpad); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_fill_mode(uplo), n, + cusolver_native_named_func(func_name, func, err, handle, get_cublas_fill_mode(uplo), n, a_, lda, tau_, scratch_, scratchpad_size, nullptr); }); }); @@ -2377,7 +2377,7 @@ inline sycl::event unmqr(const char *func_name, Func func, sycl::queue &queue, auto c_ = reinterpret_cast(c); auto scratch_ = reinterpret_cast(scratchpad); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_side_mode(side), + cusolver_native_named_func(func_name, func, err, handle, get_cublas_side_mode(side), get_cublas_operation(trans), m, n, k, a_, lda, tau_, c_, ldc, scratch_, scratchpad_size, nullptr); }); @@ -2421,7 +2421,7 @@ inline sycl::event unmtr(const char *func_name, Func func, sycl::queue &queue, auto c_ = reinterpret_cast(c); auto scratch_ = reinterpret_cast(scratchpad); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_side_mode(side), + cusolver_native_named_func(func_name, func, err, handle, get_cublas_side_mode(side), get_cublas_fill_mode(uplo), get_cublas_operation(trans), m, n, a_, lda, tau_, c_, ldc, scratch_, scratchpad_size, nullptr); @@ -2455,10 +2455,9 @@ inline void gebrd_scratchpad_size(const char *func_name, Func func, sycl::queue onemkl_cusolver_host_task(cgh, queue, [=](CusolverScopedContextHandler &sc) { auto handle = sc.get_handle(queue); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T(func_name, func, err, handle, m, n, scratch_size); + CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, m, n, scratch_size); }); - }); - queue.wait(); + }).wait(); } #define GEBRD_LAUNCHER_SCRATCH(TYPE, CUSOLVER_ROUTINE) \ 
@@ -2507,10 +2506,9 @@ inline void geqrf_scratchpad_size(const char *func_name, Func func, sycl::queue onemkl_cusolver_host_task(cgh, queue, [=](CusolverScopedContextHandler &sc) { auto handle = sc.get_handle(queue); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T(func_name, func, err, handle, m, n, nullptr, lda, scratch_size); + CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, m, n, nullptr, lda, scratch_size); }); - }); - queue.wait(); + }).wait(); } #define GEQRF_LAUNCHER_SCRATCH(TYPE, CUSOLVER_ROUTINE) \ @@ -2539,10 +2537,9 @@ inline void gesvd_scratchpad_size(const char *func_name, Func func, sycl::queue onemkl_cusolver_host_task(cgh, queue, [=](CusolverScopedContextHandler &sc) { auto handle = sc.get_handle(queue); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T(func_name, func, err, handle, m, n, scratch_size); + CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, m, n, scratch_size); }); - }); - queue.wait(); + }).wait(); } #define GESVD_LAUNCHER_SCRATCH(TYPE, CUSOLVER_ROUTINE) \ @@ -2571,10 +2568,9 @@ inline void getrf_scratchpad_size(const char *func_name, Func func, sycl::queue onemkl_cusolver_host_task(cgh, queue, [=](CusolverScopedContextHandler &sc) { auto handle = sc.get_handle(queue); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T(func_name, func, err, handle, m, n, nullptr, lda, scratch_size); + CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, m, n, nullptr, lda, scratch_size); }); - }); - queue.wait(); + }).wait(); } #define GETRF_LAUNCHER_SCRATCH(TYPE, CUSOLVER_ROUTINE) \ @@ -2632,12 +2628,11 @@ inline void heevd_scratchpad_size(const char *func_name, Func func, sycl::queue onemkl_cusolver_host_task(cgh, queue, [=](CusolverScopedContextHandler &sc) { auto handle = sc.get_handle(queue); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T(func_name, func, err, handle, get_cusolver_job(jobz), + CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cusolver_job(jobz), get_cublas_fill_mode(uplo), n, nullptr, lda, nullptr, 
scratch_size); }); - }); - queue.wait(); + }).wait(); } #define HEEVD_LAUNCHER_SCRATCH(TYPE, CUSOLVER_ROUTINE) \ @@ -2665,12 +2660,11 @@ inline void hegvd_scratchpad_size(const char *func_name, Func func, sycl::queue onemkl_cusolver_host_task(cgh, queue, [=](CusolverScopedContextHandler &sc) { auto handle = sc.get_handle(queue); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T(func_name, func, err, handle, get_cusolver_itype(itype), + CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cusolver_itype(itype), get_cusolver_job(jobz), get_cublas_fill_mode(uplo), n, nullptr, lda, nullptr, ldb, nullptr, scratch_size); }); - }); - queue.wait(); + }).wait(); } #define HEGVD_LAUNCHER_SCRATCH(TYPE, CUSOLVER_ROUTINE) \ @@ -2697,11 +2691,10 @@ inline void hetrd_scratchpad_size(const char *func_name, Func func, sycl::queue onemkl_cusolver_host_task(cgh, queue, [=](CusolverScopedContextHandler &sc) { auto handle = sc.get_handle(queue); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T(func_name, func, err, handle, get_cublas_fill_mode(uplo), n, + CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_fill_mode(uplo), n, nullptr, lda, nullptr, nullptr, nullptr, scratch_size); }); - }); - queue.wait(); + }).wait(); } #define HETRD_LAUNCHER_SCRATCH(TYPE, CUSOLVER_ROUTINE) \ @@ -2738,11 +2731,10 @@ inline void orgbr_scratchpad_size(const char *func_name, Func func, sycl::queue onemkl_cusolver_host_task(cgh, queue, [=](CusolverScopedContextHandler &sc) { auto handle = sc.get_handle(queue); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T(func_name, func, err, handle, get_cublas_generate(vec), m, n, k, + CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_generate(vec), m, n, k, nullptr, lda, nullptr, scratch_size); }); - }); - queue.wait(); + }).wait(); } #define ORGBR_LAUNCHER_SCRATCH(TYPE, CUSOLVER_ROUTINE) \ @@ -2769,11 +2761,10 @@ inline void orgtr_scratchpad_size(const char *func_name, Func func, sycl::queue onemkl_cusolver_host_task(cgh, queue, 
[=](CusolverScopedContextHandler &sc) { auto handle = sc.get_handle(queue); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T(func_name, func, err, handle, get_cublas_fill_mode(uplo), n, + CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_fill_mode(uplo), n, nullptr, lda, nullptr, scratch_size); }); - }); - queue.wait(); + }).wait(); } #define ORGTR_LAUNCHER_SCRATCH(TYPE, CUSOLVER_ROUTINE) \ @@ -2799,11 +2790,10 @@ inline void orgqr_scratchpad_size(const char *func_name, Func func, sycl::queue onemkl_cusolver_host_task(cgh, queue, [=](CusolverScopedContextHandler &sc) { auto handle = sc.get_handle(queue); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T(func_name, func, err, handle, m, n, k, nullptr, lda, nullptr, + CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, m, n, k, nullptr, lda, nullptr, scratch_size); }); - }); - queue.wait(); + }).wait(); } #define ORGQR_LAUNCHER_SCRATCH(TYPE, CUSOLVER_ROUTINE) \ @@ -2845,12 +2835,11 @@ inline void ormqr_scratchpad_size(const char *func_name, Func func, sycl::queue onemkl_cusolver_host_task(cgh, queue, [=](CusolverScopedContextHandler &sc) { auto handle = sc.get_handle(queue); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T(func_name, func, err, handle, get_cublas_side_mode(side), + CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_side_mode(side), get_cublas_operation(trans), m, n, k, nullptr, lda, nullptr, nullptr, ldc, scratch_size); }); - }); - queue.wait(); + }).wait(); } #define ORMQRF_LAUNCHER_SCRATCH(TYPE, CUSOLVER_ROUTINE) \ @@ -2878,12 +2867,11 @@ inline void ormtr_scratchpad_size(const char *func_name, Func func, sycl::queue onemkl_cusolver_host_task(cgh, queue, [=](CusolverScopedContextHandler &sc) { auto handle = sc.get_handle(queue); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T(func_name, func, err, handle, get_cublas_side_mode(side), + CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_side_mode(side), get_cublas_fill_mode(uplo), 
get_cublas_operation(trans), m, n, nullptr, lda, nullptr, nullptr, ldc, scratch_size); }); - }); - queue.wait(); + }).wait(); } #define ORMTR_LAUNCHER_SCRATCH(TYPE, CUSOLVER_ROUTINE) \ @@ -2911,11 +2899,10 @@ inline void potrf_scratchpad_size(const char *func_name, Func func, sycl::queue onemkl_cusolver_host_task(cgh, queue, [=](CusolverScopedContextHandler &sc) { auto handle = sc.get_handle(queue); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T(func_name, func, err, handle, get_cublas_fill_mode(uplo), n, + CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_fill_mode(uplo), n, nullptr, lda, scratch_size); }); - }); - queue.wait(); + }).wait(); } #define POTRF_LAUNCHER_SCRATCH(TYPE, CUSOLVER_ROUTINE) \ @@ -2959,11 +2946,10 @@ inline void potri_scratchpad_size(const char *func_name, Func func, sycl::queue onemkl_cusolver_host_task(cgh, queue, [=](CusolverScopedContextHandler &sc) { auto handle = sc.get_handle(queue); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T(func_name, func, err, handle, get_cublas_fill_mode(uplo), n, + CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_fill_mode(uplo), n, nullptr, lda, scratch_size); }); - }); - queue.wait(); + }).wait(); } #define POTRI_LAUNCHER_SCRATCH(TYPE, CUSOLVER_ROUTINE) \ @@ -2991,10 +2977,9 @@ inline void sytrf_scratchpad_size(const char *func_name, Func func, sycl::queue onemkl_cusolver_host_task(cgh, queue, [=](CusolverScopedContextHandler &sc) { auto handle = sc.get_handle(queue); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T(func_name, func, err, handle, n, nullptr, lda, scratch_size); + CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, n, nullptr, lda, scratch_size); }); - }); - queue.wait(); + }).wait(); } #define SYTRF_LAUNCHER_SCRATCH(TYPE, CUSOLVER_ROUTINE) \ @@ -3022,12 +3007,11 @@ inline void syevd_scratchpad_size(const char *func_name, Func func, sycl::queue onemkl_cusolver_host_task(cgh, queue, [=](CusolverScopedContextHandler &sc) { auto handle = 
sc.get_handle(queue); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T(func_name, func, err, handle, get_cusolver_job(jobz), + CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cusolver_job(jobz), get_cublas_fill_mode(uplo), n, nullptr, lda, nullptr, scratch_size); }); - }); - queue.wait(); + }).wait(); } #define SYEVD_LAUNCHER_SCRATCH(TYPE, CUSOLVER_ROUTINE) \ @@ -3055,12 +3039,11 @@ inline void sygvd_scratchpad_size(const char *func_name, Func func, sycl::queue onemkl_cusolver_host_task(cgh, queue, [=](CusolverScopedContextHandler &sc) { auto handle = sc.get_handle(queue); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T(func_name, func, err, handle, get_cusolver_itype(itype), + CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cusolver_itype(itype), get_cusolver_job(jobz), get_cublas_fill_mode(uplo), n, nullptr, lda, nullptr, ldb, nullptr, scratch_size); }); - }); - queue.wait(); + }).wait(); } #define SYGVD_LAUNCHER_SCRATCH(TYPE, CUSOLVER_ROUTINE) \ @@ -3087,11 +3070,10 @@ inline void sytrd_scratchpad_size(const char *func_name, Func func, sycl::queue onemkl_cusolver_host_task(cgh, queue, [=](CusolverScopedContextHandler &sc) { auto handle = sc.get_handle(queue); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T(func_name, func, err, handle, get_cublas_fill_mode(uplo), n, + CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_fill_mode(uplo), n, nullptr, lda, nullptr, nullptr, nullptr, scratch_size); }); - }); - queue.wait(); + }).wait(); } #define SYTRD_LAUNCHER_SCRATCH(TYPE, CUSOLVER_ROUTINE) \ @@ -3148,11 +3130,10 @@ inline void ungbr_scratchpad_size(const char *func_name, Func func, sycl::queue onemkl_cusolver_host_task(cgh, queue, [=](CusolverScopedContextHandler &sc) { auto handle = sc.get_handle(queue); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T(func_name, func, err, handle, get_cublas_generate(vec), m, n, k, + CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_generate(vec), m, n, k, nullptr, lda, nullptr, 
scratch_size); }); - }); - queue.wait(); + }).wait(); } #define UNGBR_LAUNCHER_SCRATCH(TYPE, CUSOLVER_ROUTINE) \ @@ -3179,11 +3160,10 @@ inline void ungqr_scratchpad_size(const char *func_name, Func func, sycl::queue onemkl_cusolver_host_task(cgh, queue, [=](CusolverScopedContextHandler &sc) { auto handle = sc.get_handle(queue); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T(func_name, func, err, handle, m, n, k, nullptr, lda, nullptr, + CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, m, n, k, nullptr, lda, nullptr, scratch_size); }); - }); - queue.wait(); + }).wait(); } #define UNGQR_LAUNCHER_SCRATCH(TYPE, CUSOLVER_ROUTINE) \ @@ -3209,11 +3189,10 @@ inline void ungtr_scratchpad_size(const char *func_name, Func func, sycl::queue onemkl_cusolver_host_task(cgh, queue, [=](CusolverScopedContextHandler &sc) { auto handle = sc.get_handle(queue); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T(func_name, func, err, handle, get_cublas_fill_mode(uplo), n, + CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_fill_mode(uplo), n, nullptr, lda, nullptr, scratch_size); }); - }); - queue.wait(); + }).wait(); } #define UNGTR_LAUNCHER_SCRATCH(TYPE, CUSOLVER_ROUTINE) \ @@ -3257,12 +3236,11 @@ inline void unmqr_scratchpad_size(const char *func_name, Func func, sycl::queue onemkl_cusolver_host_task(cgh, queue, [=](CusolverScopedContextHandler &sc) { auto handle = sc.get_handle(queue); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T(func_name, func, err, handle, get_cublas_side_mode(side), + CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_side_mode(side), get_cublas_operation(trans), m, n, k, nullptr, lda, nullptr, nullptr, ldc, scratch_size); }); - }); - queue.wait(); + }).wait(); } #define UNMQR_LAUNCHER_SCRATCH(TYPE, CUSOLVER_ROUTINE) \ @@ -3290,12 +3268,11 @@ inline void unmtr_scratchpad_size(const char *func_name, Func func, sycl::queue onemkl_cusolver_host_task(cgh, queue, [=](CusolverScopedContextHandler &sc) { auto handle = 
sc.get_handle(queue); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T(func_name, func, err, handle, get_cublas_side_mode(side), + CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_side_mode(side), get_cublas_fill_mode(uplo), get_cublas_operation(trans), m, n, nullptr, lda, nullptr, nullptr, ldc, scratch_size); }); - }); - queue.wait(); + }).wait(); } #define UNMTR_LAUNCHER_SCRATCH(TYPE, CUSOLVER_ROUTINE) \ diff --git a/src/lapack/backends/cusolver/cusolver_task.hpp b/src/lapack/backends/cusolver/cusolver_task.hpp index 00e6e26be..6a35dea84 100644 --- a/src/lapack/backends/cusolver/cusolver_task.hpp +++ b/src/lapack/backends/cusolver/cusolver_task.hpp @@ -50,10 +50,13 @@ namespace cusolver { template static inline void host_task_internal(H &cgh, sycl::queue queue, F f) { +#ifdef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND + cgh.ext_codeplay_enqueue_native_command([f, queue](sycl::interop_handle ih){ +#else cgh.host_task([f, queue](sycl::interop_handle ih) { +#endif auto sc = CusolverScopedContextHandler(queue, ih); f(sc); - sc.wait_stream(queue); }); } diff --git a/tests/unit_tests/lapack/common/dependency_check.cpp b/tests/unit_tests/lapack/common/dependency_check.cpp index 30d2d1d4a..86e313aa3 100644 --- a/tests/unit_tests/lapack/common/dependency_check.cpp +++ b/tests/unit_tests/lapack/common/dependency_check.cpp @@ -56,8 +56,7 @@ bool check_dependency(sycl::queue queue, sycl::event in_event, sycl::event func_ do { func_status = func_event.get_info(); - } while (func_status != sycl::info::event_command_status::running && - func_status != sycl::info::event_command_status::complete); + } while (func_status != sycl::info::event_command_status::complete); in_status = in_event.get_info(); /* Print results */