Skip to content

Commit

Permalink
Throw unimplemented for spsv using no_optimize_alg
Browse files Browse the repository at this point in the history
  • Loading branch information
Rbiessy committed Sep 30, 2024
1 parent 325f794 commit 9bc77d1
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 11 deletions.
2 changes: 2 additions & 0 deletions docs/domains/sparse_linear_algebra.rst
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ Currently known limitations:
an ``oneapi::mkl::unimplemented`` exception.
- Using ``spmv`` with a ``type_view`` other than ``matrix_descr::general`` will
throw an ``oneapi::mkl::unimplemented`` exception.
- Using ``spsv`` with the algorithm ``spsv_alg::no_optimize_alg`` will throw an
``oneapi::mkl::unimplemented`` exception.
- oneMKL Interface does not provide a way to use non-default algorithms without
calling preprocess functions such as ``cusparseSpMM_preprocess`` or
``cusparseSpMV_preprocess``. Feel free to create an issue if this is needed.
Expand Down
45 changes: 34 additions & 11 deletions src/sparse_blas/backends/cusparse/operations/cusparse_spsv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,14 +55,34 @@ void init_spsv_descr(sycl::queue & /*queue*/, spsv_descr_t *p_spsv_descr) {

sycl::event release_spsv_descr(sycl::queue &queue, spsv_descr_t spsv_descr,
const std::vector<sycl::event> &dependencies) {
// Use dispatch_submit to ensure the backend's descriptor is kept alive as long as the buffers are used
auto functor = [=](CusparseScopedContextHandler &) {
if (!spsv_descr) {
return {};
}

auto release_functor = [=]() {
CUSPARSE_ERR_FUNC(cusparseSpSV_destroyDescr, spsv_descr->cu_descr);
delete spsv_descr;
};
return dispatch_submit(__func__, queue, dependencies, functor,
spsv_descr->last_optimized_A_handle, spsv_descr->last_optimized_x_handle,
spsv_descr->last_optimized_y_handle);

// Use dispatch_submit to ensure the backend's descriptor is kept alive as long as the buffers are used
// dispatch_submit can only be used if the descriptor's handles are valid
if (spsv_descr->last_optimized_A_handle &&
spsv_descr->last_optimized_A_handle->all_use_buffer() &&
spsv_descr->last_optimized_x_handle && spsv_descr->last_optimized_y_handle) {
auto dispatch_functor = [=](CusparseScopedContextHandler &) {
release_functor();
};
return dispatch_submit(
__func__, queue, dependencies, dispatch_functor, spsv_descr->last_optimized_A_handle,
spsv_descr->last_optimized_x_handle, spsv_descr->last_optimized_y_handle);
}

// Release used if USM is used or the descriptor has been released before spsv_optimize has succeeded
sycl::event event = queue.submit([&](sycl::handler &cgh) {
cgh.depends_on(dependencies);
cgh.host_task(release_functor);
});
return event;
}

inline auto get_cuda_spsv_alg(spsv_alg /*alg*/) {
Expand All @@ -71,18 +91,23 @@ inline auto get_cuda_spsv_alg(spsv_alg /*alg*/) {

void check_valid_spsv(const std::string &function_name, matrix_view A_view,
matrix_handle_t A_handle, dense_vector_handle_t x_handle,
dense_vector_handle_t y_handle, bool is_alpha_host_accessible) {
dense_vector_handle_t y_handle, spsv_alg alg, bool is_alpha_host_accessible) {
detail::check_valid_spsv_common(function_name, A_view, A_handle, x_handle, y_handle,
is_alpha_host_accessible);
check_valid_matrix_properties(function_name, A_handle);
if (alg == spsv_alg::no_optimize_alg) {
throw mkl::unimplemented(
"sparse_blas", function_name,
"The backend does not support the algorithm ``spsv_alg::no_optimize_alg``.");
}
}

void spsv_buffer_size(sycl::queue &queue, oneapi::mkl::transpose opA, const void *alpha,
matrix_view A_view, matrix_handle_t A_handle, dense_vector_handle_t x_handle,
dense_vector_handle_t y_handle, spsv_alg alg, spsv_descr_t spsv_descr,
std::size_t &temp_buffer_size) {
bool is_alpha_host_accessible = detail::is_ptr_accessible_on_host(queue, alpha);
check_valid_spsv(__func__, A_view, A_handle, x_handle, y_handle, is_alpha_host_accessible);
check_valid_spsv(__func__, A_view, A_handle, x_handle, y_handle, alg, is_alpha_host_accessible);
auto functor = [=, &temp_buffer_size](CusparseScopedContextHandler &sc) {
auto cu_handle = sc.get_handle(queue);
auto cu_a = A_handle->backend_handle;
Expand All @@ -108,7 +133,7 @@ inline void common_spsv_optimize(oneapi::mkl::transpose opA, bool is_alpha_host_
matrix_view A_view, matrix_handle_t A_handle,
dense_vector_handle_t x_handle, dense_vector_handle_t y_handle,
spsv_alg alg, spsv_descr_t spsv_descr) {
check_valid_spsv("spsv_optimize", A_view, A_handle, x_handle, y_handle,
check_valid_spsv("spsv_optimize", A_view, A_handle, x_handle, y_handle, alg,
is_alpha_host_accessible);
if (!spsv_descr->buffer_size_called) {
throw mkl::uninitialized("sparse_blas", "spsv_optimize",
Expand Down Expand Up @@ -153,7 +178,6 @@ void spsv_optimize(sycl::queue &queue, oneapi::mkl::transpose opA, const void *a
}
common_spsv_optimize(opA, is_alpha_host_accessible, A_view, A_handle, x_handle, y_handle, alg,
spsv_descr);
// Ignore spsv_alg::no_optimize_alg as this step is mandatory for cuSPARSE
// Copy the buffer to extend its lifetime until the descriptor is free'd.
spsv_descr->workspace.set_buffer_untyped(workspace);

Expand Down Expand Up @@ -191,7 +215,6 @@ sycl::event spsv_optimize(sycl::queue &queue, oneapi::mkl::transpose opA, const
}
common_spsv_optimize(opA, is_alpha_host_accessible, A_view, A_handle, x_handle, y_handle, alg,
spsv_descr);
// Ignore spsv_alg::no_optimize_alg as this step is mandatory for cuSPARSE
auto functor = [=](CusparseScopedContextHandler &sc) {
auto cu_handle = sc.get_handle(queue);
spsv_optimize_impl(cu_handle, opA, alpha, A_view, A_handle, x_handle, y_handle, alg,
Expand All @@ -206,7 +229,7 @@ sycl::event spsv(sycl::queue &queue, oneapi::mkl::transpose opA, const void *alp
dense_vector_handle_t y_handle, spsv_alg alg, spsv_descr_t spsv_descr,
const std::vector<sycl::event> &dependencies) {
bool is_alpha_host_accessible = detail::is_ptr_accessible_on_host(queue, alpha);
check_valid_spsv(__func__, A_view, A_handle, x_handle, y_handle, is_alpha_host_accessible);
check_valid_spsv(__func__, A_view, A_handle, x_handle, y_handle, alg, is_alpha_host_accessible);
if (A_handle->all_use_buffer() != spsv_descr->workspace.use_buffer()) {
detail::throw_incompatible_container(__func__);
}
Expand Down

0 comments on commit 9bc77d1

Please sign in to comment.