diff --git a/src/sparse_blas/backends/mkl_common/mkl_handles.cxx b/src/sparse_blas/backends/mkl_common/mkl_handles.cxx index f3ff5afa2..0a80130f2 100644 --- a/src/sparse_blas/backends/mkl_common/mkl_handles.cxx +++ b/src/sparse_blas/backends/mkl_common/mkl_handles.cxx @@ -34,7 +34,7 @@ void init_dense_vector(sycl::queue & /*queue*/, template void check_can_reset_value_handle(const std::string &function_name, - InternalHandleT *internal_handle) { + InternalHandleT *internal_handle, bool expect_buffer) { if (internal_handle->get_value_type() != detail::get_data_type()) { throw oneapi::mkl::invalid_argument( "sparse_blas", function_name, @@ -42,13 +42,17 @@ void check_can_reset_value_handle(const std::string &function_name, data_type_to_str(internal_handle->get_value_type()) + " but got " + data_type_to_str(detail::get_data_type())); } + if (internal_handle->all_use_buffer() != expect_buffer) { + throw oneapi::mkl::invalid_argument( + "sparse_blas", function_name, "Cannot change the container type between buffer or USM"); + } } template void set_dense_vector_data(sycl::queue & /*queue*/, oneapi::mkl::sparse::dense_vector_handle_t dvhandle, std::int64_t size, sycl::buffer val) { - check_can_reset_value_handle(__FUNCTION__, dvhandle); + check_can_reset_value_handle(__FUNCTION__, dvhandle, true); dvhandle->size = size; dvhandle->set_buffer(val); } @@ -57,7 +61,7 @@ template void set_dense_vector_data(sycl::queue & /*queue*/, oneapi::mkl::sparse::dense_vector_handle_t dvhandle, std::int64_t size, fpType *val) { - check_can_reset_value_handle(__FUNCTION__, dvhandle); + check_can_reset_value_handle(__FUNCTION__, dvhandle, false); dvhandle->size = size; dvhandle->set_usm_ptr(val); } @@ -108,7 +112,7 @@ void set_dense_matrix_data(sycl::queue & /*queue*/, oneapi::mkl::sparse::dense_matrix_handle_t dmhandle, std::int64_t num_rows, std::int64_t num_cols, std::int64_t ld, oneapi::mkl::layout dense_layout, sycl::buffer val) { - check_can_reset_value_handle(__FUNCTION__, dmhandle); + check_can_reset_value_handle(__FUNCTION__, dmhandle, true); dmhandle->num_rows = num_rows; dmhandle->num_cols = num_cols; dmhandle->ld = ld; @@ -121,7 +125,7 @@ void set_dense_matrix_data(sycl::queue & /*queue*/, oneapi::mkl::sparse::dense_matrix_handle_t dmhandle, std::int64_t num_rows, std::int64_t num_cols, std::int64_t ld, oneapi::mkl::layout dense_layout, fpType *val) { - check_can_reset_value_handle(__FUNCTION__, dmhandle); + check_can_reset_value_handle(__FUNCTION__, dmhandle, false); dmhandle->num_rows = num_rows; dmhandle->num_cols = num_cols; dmhandle->ld = ld; @@ -190,8 +194,9 @@ void init_coo_matrix(sycl::queue &queue, oneapi::mkl::sparse::matrix_handle_t *p template void check_can_reset_sparse_handle(const std::string &function_name, - detail::sparse_matrix_handle *internal_smhandle) { - check_can_reset_value_handle(function_name, internal_smhandle); + detail::sparse_matrix_handle *internal_smhandle, + bool expect_buffer) { + check_can_reset_value_handle(function_name, internal_smhandle, expect_buffer); if (internal_smhandle->get_int_type() != detail::get_data_type()) { throw oneapi::mkl::invalid_argument( "sparse_blas", function_name, @@ -212,7 +217,7 @@ void set_coo_matrix_data(sycl::queue &queue, oneapi::mkl::sparse::matrix_handle_ oneapi::mkl::index_base index, sycl::buffer row_ind, sycl::buffer col_ind, sycl::buffer val) { auto internal_smhandle = detail::get_internal_handle(smhandle); - check_can_reset_sparse_handle(__FUNCTION__, internal_smhandle); + check_can_reset_sparse_handle(__FUNCTION__, internal_smhandle, true); internal_smhandle->row_container.set_buffer(row_ind); internal_smhandle->col_container.set_buffer(col_ind); internal_smhandle->value_container.set_buffer(val); @@ -231,7 +236,7 @@ void set_coo_matrix_data(sycl::queue &queue, oneapi::mkl::sparse::matrix_handle_ oneapi::mkl::index_base index, intType *row_ind, intType *col_ind, fpType *val) { auto internal_smhandle = detail::get_internal_handle(smhandle); - check_can_reset_sparse_handle(__FUNCTION__, internal_smhandle); + check_can_reset_sparse_handle(__FUNCTION__, internal_smhandle, false); internal_smhandle->row_container.set_usm_ptr(row_ind); internal_smhandle->col_container.set_usm_ptr(col_ind); internal_smhandle->value_container.set_usm_ptr(val); @@ -308,7 +313,7 @@ void set_csr_matrix_data(sycl::queue &queue, oneapi::mkl::sparse::matrix_handle_ oneapi::mkl::index_base index, sycl::buffer row_ptr, sycl::buffer col_ind, sycl::buffer val) { auto internal_smhandle = detail::get_internal_handle(smhandle); - check_can_reset_sparse_handle(__FUNCTION__, internal_smhandle); + check_can_reset_sparse_handle(__FUNCTION__, internal_smhandle, true); internal_smhandle->row_container.set_buffer(row_ptr); internal_smhandle->col_container.set_buffer(col_ind); internal_smhandle->value_container.set_buffer(val); @@ -328,7 +333,7 @@ void set_csr_matrix_data(sycl::queue &queue, oneapi::mkl::sparse::matrix_handle_ oneapi::mkl::index_base index, intType *row_ptr, intType *col_ind, fpType *val) { auto internal_smhandle = detail::get_internal_handle(smhandle); - check_can_reset_sparse_handle(__FUNCTION__, internal_smhandle); + check_can_reset_sparse_handle(__FUNCTION__, internal_smhandle, false); internal_smhandle->row_container.set_usm_ptr(row_ptr); internal_smhandle->col_container.set_usm_ptr(col_ind); internal_smhandle->value_container.set_usm_ptr(val);