diff --git a/src/sparse_blas/backends/mkl_common/mkl_spmm.cxx b/src/sparse_blas/backends/mkl_common/mkl_spmm.cxx index e870341ff..6e3038122 100644 --- a/src/sparse_blas/backends/mkl_common/mkl_spmm.cxx +++ b/src/sparse_blas/backends/mkl_common/mkl_spmm.cxx @@ -94,10 +94,13 @@ void spmm_optimize(sycl::queue &queue, oneapi::mkl::transpose opA, oneapi::mkl:: oneapi::mkl::sparse::spmm_descr_t /*spmm_descr*/, sycl::buffer /*workspace*/) { check_valid_spmm(__FUNCTION__, queue, opA, A_view, A_handle, B_handle, C_handle, alpha, beta); + auto internal_A_handle = detail::get_internal_handle(A_handle); + if (!internal_A_handle->all_use_buffer()) { + detail::throw_incompatible_container(__FUNCTION__); + } if (alg == oneapi::mkl::sparse::spmm_alg::no_optimize_alg) { return; } - auto internal_A_handle = detail::get_internal_handle(A_handle); internal_A_handle->can_be_reset = false; // TODO: Add support for spmm_optimize once the close-source oneMKL backend supports it. } @@ -112,10 +115,13 @@ sycl::event spmm_optimize(sycl::queue &queue, oneapi::mkl::transpose opA, oneapi::mkl::sparse::spmm_descr_t /*spmm_descr*/, void * /*workspace*/, const std::vector &dependencies) { check_valid_spmm(__FUNCTION__, queue, opA, A_view, A_handle, B_handle, C_handle, alpha, beta); + auto internal_A_handle = detail::get_internal_handle(A_handle); + if (internal_A_handle->all_use_buffer()) { + detail::throw_incompatible_container(__FUNCTION__); + } if (alg == oneapi::mkl::sparse::spmm_alg::no_optimize_alg) { return detail::collapse_dependencies(queue, dependencies); } - auto internal_A_handle = detail::get_internal_handle(A_handle); internal_A_handle->can_be_reset = false; // TODO: Add support for spmm_optimize once the close-source oneMKL backend supports it. return detail::collapse_dependencies(queue, dependencies); diff --git a/src/sparse_blas/backends/mkl_common/mkl_spmv.cxx b/src/sparse_blas/backends/mkl_common/mkl_spmv.cxx index 73efe4e7d..6950dc700 100644 --- a/src/sparse_blas/backends/mkl_common/mkl_spmv.cxx +++ b/src/sparse_blas/backends/mkl_common/mkl_spmv.cxx @@ -80,11 +80,14 @@ void spmv_optimize(sycl::queue &queue, oneapi::mkl::transpose opA, const void *a oneapi::mkl::sparse::spmv_descr_t /*spmv_descr*/, sycl::buffer /*workspace*/) { check_valid_spmv(__FUNCTION__, queue, A_view, A_handle, x_handle, y_handle, alpha, beta); + auto internal_A_handle = detail::get_internal_handle(A_handle); + if (!internal_A_handle->all_use_buffer()) { + detail::throw_incompatible_container(__FUNCTION__); + } if (alg == oneapi::mkl::sparse::spmv_alg::no_optimize_alg) { return; } sycl::event event; - auto internal_A_handle = detail::get_internal_handle(A_handle); internal_A_handle->can_be_reset = false; if (A_view.type_view == matrix_descr::triangular) { event = oneapi::mkl::sparse::optimize_trmv(queue, A_view.uplo_view, opA, A_view.diag_view, @@ -111,10 +114,13 @@ sycl::event spmv_optimize(sycl::queue &queue, oneapi::mkl::transpose opA, const oneapi::mkl::sparse::spmv_descr_t /*spmv_descr*/, void * /*workspace*/, const std::vector &dependencies) { check_valid_spmv(__FUNCTION__, queue, A_view, A_handle, x_handle, y_handle, alpha, beta); + auto internal_A_handle = detail::get_internal_handle(A_handle); + if (internal_A_handle->all_use_buffer()) { + detail::throw_incompatible_container(__FUNCTION__); + } if (alg == oneapi::mkl::sparse::spmv_alg::no_optimize_alg) { return detail::collapse_dependencies(queue, dependencies); } - auto internal_A_handle = detail::get_internal_handle(A_handle); internal_A_handle->can_be_reset = false; if (A_view.type_view == matrix_descr::triangular) { return oneapi::mkl::sparse::optimize_trmv(queue, A_view.uplo_view, opA, A_view.diag_view, diff --git a/src/sparse_blas/backends/mkl_common/mkl_spsv.cxx b/src/sparse_blas/backends/mkl_common/mkl_spsv.cxx index bd8094f90..8fef1339d 100644 --- a/src/sparse_blas/backends/mkl_common/mkl_spsv.cxx +++ b/src/sparse_blas/backends/mkl_common/mkl_spsv.cxx @@ -80,10 +80,13 @@ void spsv_optimize(sycl::queue &queue, oneapi::mkl::transpose opA, const void *a oneapi::mkl::sparse::spsv_descr_t /*spsv_descr*/, sycl::buffer /*workspace*/) { check_valid_spsv(__FUNCTION__, queue, A_view, A_handle, x_handle, y_handle, alpha, alg); + auto internal_A_handle = detail::get_internal_handle(A_handle); + if (!internal_A_handle->all_use_buffer()) { + detail::throw_incompatible_container(__FUNCTION__); + } if (alg == oneapi::mkl::sparse::spsv_alg::no_optimize_alg) { return; } - auto internal_A_handle = detail::get_internal_handle(A_handle); internal_A_handle->can_be_reset = false; auto event = oneapi::mkl::sparse::optimize_trsv(queue, A_view.uplo_view, opA, A_view.diag_view, internal_A_handle->backend_handle); @@ -100,10 +103,13 @@ sycl::event spsv_optimize(sycl::queue &queue, oneapi::mkl::transpose opA, const oneapi::mkl::sparse::spsv_descr_t /*spsv_descr*/, void * /*workspace*/, const std::vector &dependencies) { check_valid_spsv(__FUNCTION__, queue, A_view, A_handle, x_handle, y_handle, alpha, alg); + auto internal_A_handle = detail::get_internal_handle(A_handle); + if (internal_A_handle->all_use_buffer()) { + detail::throw_incompatible_container(__FUNCTION__); + } if (alg == oneapi::mkl::sparse::spsv_alg::no_optimize_alg) { return detail::collapse_dependencies(queue, dependencies); } - auto internal_A_handle = detail::get_internal_handle(A_handle); internal_A_handle->can_be_reset = false; return oneapi::mkl::sparse::optimize_trsv(queue, A_view.uplo_view, opA, A_view.diag_view, internal_A_handle->backend_handle, dependencies); diff --git a/src/sparse_blas/generic_container.hpp b/src/sparse_blas/generic_container.hpp index 46732722d..53bd50837 100644 --- a/src/sparse_blas/generic_container.hpp +++ b/src/sparse_blas/generic_container.hpp @@ -269,6 +269,12 @@ struct generic_sparse_handle { } }; +inline void throw_incompatible_container(const std::string& function_name) { + throw oneapi::mkl::invalid_argument( + "sparse_blas", function_name, + "Incompatible container types. All inputs and outputs must use the same container: buffer or USM"); +} + /** * Check that all internal containers use the same container. */ @@ -279,9 +285,7 @@ void check_all_containers_use_buffers(const std::string& function_name, bool first_use_buffer = first_internal_container->all_use_buffer(); for (const auto internal_container : { internal_containers... }) { if (internal_container->all_use_buffer() != first_use_buffer) { - throw oneapi::mkl::invalid_argument( - "sparse_blas", function_name, - "Incompatible container types. All inputs and outputs must use the same container: buffer or USM"); + throw_incompatible_container(function_name); } } }