Skip to content

Commit

Permalink
Add check container type is not changed
Browse files Browse the repository at this point in the history
  • Loading branch information
Rbiessy committed Jun 7, 2024
1 parent 637eba6 commit 7f8d181
Showing 1 changed file with 16 additions and 11 deletions.
27 changes: 16 additions & 11 deletions src/sparse_blas/backends/mkl_common/mkl_handles.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -34,21 +34,25 @@ void init_dense_vector(sycl::queue & /*queue*/,

template <typename fpType, typename InternalHandleT>
void check_can_reset_value_handle(const std::string &function_name,
InternalHandleT *internal_handle) {
InternalHandleT *internal_handle, bool expect_buffer) {
if (internal_handle->get_value_type() != detail::get_data_type<fpType>()) {
throw oneapi::mkl::invalid_argument(
"sparse_blas", function_name,
"Incompatible data types expected " +
data_type_to_str(internal_handle->get_value_type()) + " but got " +
data_type_to_str(detail::get_data_type<fpType>()));
}
if (internal_handle->all_use_buffer() != expect_buffer) {
throw oneapi::mkl::invalid_argument(
"sparse_blas", function_name, "Cannot change the container type between buffer or USM");
}
}

template <typename fpType>
void set_dense_vector_data(sycl::queue & /*queue*/,
oneapi::mkl::sparse::dense_vector_handle_t dvhandle, std::int64_t size,
sycl::buffer<fpType, 1> val) {
check_can_reset_value_handle<fpType>(__FUNCTION__, dvhandle);
check_can_reset_value_handle<fpType>(__FUNCTION__, dvhandle, true);
dvhandle->size = size;
dvhandle->set_buffer(val);
}
Expand All @@ -57,7 +61,7 @@ template <typename fpType>
void set_dense_vector_data(sycl::queue & /*queue*/,
oneapi::mkl::sparse::dense_vector_handle_t dvhandle, std::int64_t size,
fpType *val) {
check_can_reset_value_handle<fpType>(__FUNCTION__, dvhandle);
check_can_reset_value_handle<fpType>(__FUNCTION__, dvhandle, false);
dvhandle->size = size;
dvhandle->set_usm_ptr(val);
}
Expand Down Expand Up @@ -108,7 +112,7 @@ void set_dense_matrix_data(sycl::queue & /*queue*/,
oneapi::mkl::sparse::dense_matrix_handle_t dmhandle,
std::int64_t num_rows, std::int64_t num_cols, std::int64_t ld,
oneapi::mkl::layout dense_layout, sycl::buffer<fpType, 1> val) {
check_can_reset_value_handle<fpType>(__FUNCTION__, dmhandle);
check_can_reset_value_handle<fpType>(__FUNCTION__, dmhandle, true);
dmhandle->num_rows = num_rows;
dmhandle->num_cols = num_cols;
dmhandle->ld = ld;
Expand All @@ -121,7 +125,7 @@ void set_dense_matrix_data(sycl::queue & /*queue*/,
oneapi::mkl::sparse::dense_matrix_handle_t dmhandle,
std::int64_t num_rows, std::int64_t num_cols, std::int64_t ld,
oneapi::mkl::layout dense_layout, fpType *val) {
check_can_reset_value_handle<fpType>(__FUNCTION__, dmhandle);
check_can_reset_value_handle<fpType>(__FUNCTION__, dmhandle, false);
dmhandle->num_rows = num_rows;
dmhandle->num_cols = num_cols;
dmhandle->ld = ld;
Expand Down Expand Up @@ -190,8 +194,9 @@ void init_coo_matrix(sycl::queue &queue, oneapi::mkl::sparse::matrix_handle_t *p

template <typename fpType, typename intType>
void check_can_reset_sparse_handle(const std::string &function_name,
detail::sparse_matrix_handle *internal_smhandle) {
check_can_reset_value_handle<fpType>(function_name, internal_smhandle);
detail::sparse_matrix_handle *internal_smhandle,
bool expect_buffer) {
check_can_reset_value_handle<fpType>(function_name, internal_smhandle, expect_buffer);
if (internal_smhandle->get_int_type() != detail::get_data_type<intType>()) {
throw oneapi::mkl::invalid_argument(
"sparse_blas", function_name,
Expand All @@ -212,7 +217,7 @@ void set_coo_matrix_data(sycl::queue &queue, oneapi::mkl::sparse::matrix_handle_
oneapi::mkl::index_base index, sycl::buffer<intType, 1> row_ind,
sycl::buffer<intType, 1> col_ind, sycl::buffer<fpType, 1> val) {
auto internal_smhandle = detail::get_internal_handle(smhandle);
check_can_reset_sparse_handle<fpType, intType>(__FUNCTION__, internal_smhandle);
check_can_reset_sparse_handle<fpType, intType>(__FUNCTION__, internal_smhandle, true);
internal_smhandle->row_container.set_buffer(row_ind);
internal_smhandle->col_container.set_buffer(col_ind);
internal_smhandle->value_container.set_buffer(val);
Expand All @@ -231,7 +236,7 @@ void set_coo_matrix_data(sycl::queue &queue, oneapi::mkl::sparse::matrix_handle_
oneapi::mkl::index_base index, intType *row_ind, intType *col_ind,
fpType *val) {
auto internal_smhandle = detail::get_internal_handle(smhandle);
check_can_reset_sparse_handle<fpType, intType>(__FUNCTION__, internal_smhandle);
check_can_reset_sparse_handle<fpType, intType>(__FUNCTION__, internal_smhandle, false);
internal_smhandle->row_container.set_usm_ptr(row_ind);
internal_smhandle->col_container.set_usm_ptr(col_ind);
internal_smhandle->value_container.set_usm_ptr(val);
Expand Down Expand Up @@ -308,7 +313,7 @@ void set_csr_matrix_data(sycl::queue &queue, oneapi::mkl::sparse::matrix_handle_
oneapi::mkl::index_base index, sycl::buffer<intType, 1> row_ptr,
sycl::buffer<intType, 1> col_ind, sycl::buffer<fpType, 1> val) {
auto internal_smhandle = detail::get_internal_handle(smhandle);
check_can_reset_sparse_handle<fpType, intType>(__FUNCTION__, internal_smhandle);
check_can_reset_sparse_handle<fpType, intType>(__FUNCTION__, internal_smhandle, true);
internal_smhandle->row_container.set_buffer(row_ptr);
internal_smhandle->col_container.set_buffer(col_ind);
internal_smhandle->value_container.set_buffer(val);
Expand All @@ -328,7 +333,7 @@ void set_csr_matrix_data(sycl::queue &queue, oneapi::mkl::sparse::matrix_handle_
oneapi::mkl::index_base index, intType *row_ptr, intType *col_ind,
fpType *val) {
auto internal_smhandle = detail::get_internal_handle(smhandle);
check_can_reset_sparse_handle<fpType, intType>(__FUNCTION__, internal_smhandle);
check_can_reset_sparse_handle<fpType, intType>(__FUNCTION__, internal_smhandle, false);
internal_smhandle->row_container.set_usm_ptr(row_ptr);
internal_smhandle->col_container.set_usm_ptr(col_ind);
internal_smhandle->value_container.set_usm_ptr(val);
Expand Down

0 comments on commit 7f8d181

Please sign in to comment.