diff --git a/CMakeLists.txt b/CMakeLists.txt index e0c2de1ff..184314e7b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -62,6 +62,7 @@ option(ENABLE_PORTFFT_BACKEND "Enable the portFFT DFT backend for the DFT interf # sparse option(ENABLE_CUSPARSE_BACKEND "Enable the cuSPARSE backend for the SPARSE_BLAS interface" OFF) +option(ENABLE_ROCSPARSE_BACKEND "Enable the rocSPARSE backend for the SPARSE_BLAS interface" OFF) set(ONEMATH_SYCL_IMPLEMENTATION "dpc++" CACHE STRING "Name of the SYCL compiler") set(HIP_TARGETS "" CACHE STRING "Target HIP architectures") @@ -106,7 +107,8 @@ if(ENABLE_MKLGPU_BACKEND endif() if(ENABLE_MKLCPU_BACKEND OR ENABLE_MKLGPU_BACKEND - OR ENABLE_CUSPARSE_BACKEND) + OR ENABLE_CUSPARSE_BACKEND + OR ENABLE_ROCSPARSE_BACKEND) list(APPEND DOMAINS_LIST "sparse_blas") endif() @@ -134,7 +136,7 @@ if(CMAKE_CXX_COMPILER OR NOT ONEMATH_SYCL_IMPLEMENTATION STREQUAL "dpc++") endif() else() if(ENABLE_CUBLAS_BACKEND OR ENABLE_CURAND_BACKEND OR ENABLE_CUSOLVER_BACKEND OR ENABLE_CUFFT_BACKEND OR ENABLE_CUSPARSE_BACKEND - OR ENABLE_ROCBLAS_BACKEND OR ENABLE_ROCRAND_BACKEND OR ENABLE_ROCSOLVER_BACKEND OR ENABLE_ROCFFT_BACKEND) + OR ENABLE_ROCBLAS_BACKEND OR ENABLE_ROCRAND_BACKEND OR ENABLE_ROCSOLVER_BACKEND OR ENABLE_ROCFFT_BACKEND OR ENABLE_ROCSPARSE_BACKEND) set(CMAKE_CXX_COMPILER "clang++") elseif(ENABLE_MKLGPU_BACKEND) if(UNIX) diff --git a/README.md b/README.md index 8fb74f316..f74d6600a 100644 --- a/README.md +++ b/README.md @@ -18,8 +18,8 @@ oneMath is part of the [UXL Foundation](http://www.uxlfoundation.org). - oneMath - oneMath selector + oneMath + oneMath selector Intel(R) oneAPI Math Kernel Library (oneMKL) x86 CPU, Intel GPU @@ -61,7 +61,11 @@ oneMath is part of the [UXL Foundation](http://www.uxlfoundation.org). AMD GPU - AMD rocFFT + AMD rocFFT + AMD GPU + + + AMD rocSPARSE AMD GPU @@ -333,7 +337,7 @@ Supported compilers include: Dynamic, Static - SPARSE_BLAS + SPARSE_BLAS x86 CPU Intel(R) oneMKL Intel DPC++ @@ -351,6 +355,12 @@ Supported compilers include: Open DPC++ Dynamic, Static + + AMD GPU + AMD rocSPARSE + Open DPC++ + Dynamic, Static + @@ -537,6 +547,7 @@ Product | Supported Version | License [AMD rocRAND](https://github.com/ROCm/rocRAND) | 5.1.0 | [AMD License](https://github.com/ROCm/rocRAND/blob/develop/LICENSE.txt) [AMD rocSOLVER](https://github.com/ROCm/rocSOLVER) | 5.0.0 | [AMD License](https://github.com/ROCm/rocSOLVER/blob/develop/LICENSE.md) [AMD rocFFT](https://github.com/ROCm/rocFFT) | rocm-5.4.3 | [AMD License](https://github.com/ROCm/rocFFT/blob/rocm-5.4.3/LICENSE.md) +[AMD rocSPARSE](https://github.com/ROCm/rocSPARSE) | 3.1.2 | [AMD License](https://github.com/ROCm/rocSPARSE/blob/develop/LICENSE.md) [NETLIB LAPACK](https://www.netlib.org/) | [5d4180c](https://github.com/Reference-LAPACK/lapack/commit/5d4180cf8288ae6ad9a771d18793d15bd0c5643c) | [BSD like license](http://www.netlib.org/lapack/LICENSE.txt) [portBLAS](https://github.com/codeplaysoftware/portBLAS) | 0.1 | [Apache License v2.0](https://github.com/codeplaysoftware/portBLAS/blob/main/LICENSE) [portFFT](https://github.com/codeplaysoftware/portFFT) | 0.1 | [Apache License v2.0](https://github.com/codeplaysoftware/portFFT/blob/main/LICENSE) diff --git a/cmake/FindCompiler.cmake b/cmake/FindCompiler.cmake index 3467f23d9..b4f185c1d 100644 --- a/cmake/FindCompiler.cmake +++ b/cmake/FindCompiler.cmake @@ -43,7 +43,7 @@ if(is_dpcpp) list(APPEND UNIX_INTERFACE_LINK_OPTIONS -fsycl-targets=nvptx64-nvidia-cuda) elseif(ENABLE_ROCBLAS_BACKEND OR ENABLE_ROCRAND_BACKEND - OR ENABLE_ROCSOLVER_BACKEND) + OR ENABLE_ROCSOLVER_BACKEND OR ENABLE_ROCSPARSE_BACKEND) list(APPEND UNIX_INTERFACE_COMPILE_OPTIONS -fsycl-targets=amdgcn-amd-amdhsa -fsycl-unnamed-lambda -Xsycl-target-backend --offload-arch=${HIP_TARGETS}) @@ -52,7 +52,7 @@ if(is_dpcpp) --offload-arch=${HIP_TARGETS}) endif() if(ENABLE_CURAND_BACKEND OR ENABLE_CUSOLVER_BACKEND OR ENABLE_CUSPARSE_BACKEND OR ENABLE_ROCBLAS_BACKEND - OR ENABLE_ROCRAND_BACKEND OR ENABLE_ROCSOLVER_BACKEND) + OR ENABLE_ROCRAND_BACKEND OR ENABLE_ROCSOLVER_BACKEND OR ENABLE_ROCSPARSE_BACKEND) set_target_properties(ONEMATH::SYCL::SYCL PROPERTIES INTERFACE_COMPILE_OPTIONS "${UNIX_INTERFACE_COMPILE_OPTIONS}" INTERFACE_LINK_OPTIONS "${UNIX_INTERFACE_LINK_OPTIONS}" @@ -69,7 +69,7 @@ if(is_dpcpp) INTERFACE_LINK_LIBRARIES ${SYCL_LIBRARY}) endif() - if(ENABLE_ROCBLAS_BACKEND OR ENABLE_ROCRAND_BACKEND OR ENABLE_ROCSOLVER_BACKEND) + if(ENABLE_ROCBLAS_BACKEND OR ENABLE_ROCRAND_BACKEND OR ENABLE_ROCSOLVER_BACKEND OR ENABLE_ROCSPARSE_BACKEND) # Allow find_package(HIP) to find the correct path to libclang_rt.builtins.a # HIP's CMake uses the command `${HIP_CXX_COMPILER} -print-libgcc-file-name --rtlib=compiler-rt` to find this path. # This can print a non-existing file if the compiler used is icpx. diff --git a/docs/building_the_project_with_dpcpp.rst b/docs/building_the_project_with_dpcpp.rst index 49fd51abb..4f1c076ef 100644 --- a/docs/building_the_project_with_dpcpp.rst +++ b/docs/building_the_project_with_dpcpp.rst @@ -121,6 +121,9 @@ The most important supported build options are: * - ENABLE_ROCRAND_BACKEND - True, False - False + * - ENABLE_ROCSPARSE_BACKEND + - True, False + - False * - ENABLE_MKLCPU_THREAD_TBB - True, False - True @@ -197,14 +200,14 @@ Building for ROCm ^^^^^^^^^^^^^^^^^ The ROCm backends can be enabled with ``ENABLE_ROCBLAS_BACKEND``, -``ENABLE_ROCFFT_BACKEND``, ``ENABLE_ROCSOLVER_BACKEND`` and -``ENABLE_ROCRAND_BACKEND``. +``ENABLE_ROCFFT_BACKEND``, ``ENABLE_ROCSOLVER_BACKEND``, +``ENABLE_ROCRAND_BACKEND``, and ``ENABLE_ROCSPARSE_BACKEND``. -For *RocBLAS*, *RocSOLVER* and *RocRAND*, the target device architecture must be -set. This can be set with using the ``HIP_TARGETS`` parameter. For example, to -enable a build for MI200 series GPUs, ``-DHIP_TARGETS=gfx90a`` should be set. -Currently, DPC++ can only build for a single HIP target at a time. This may -change in future versions. +For *RocBLAS*, *RocSOLVER*, *RocRAND*, and *RocSPARSE*, the target device +architecture must be set. This can be set with using the ``HIP_TARGETS`` +parameter. For example, to enable a build for MI200 series GPUs, +``-DHIP_TARGETS=gfx90a`` should be set. Currently, DPC++ can only build for a +single HIP target at a time. This may change in future versions. A few often-used architectures are listed below: @@ -393,7 +396,8 @@ disabled: -DENABLE_MKLGPU_BACKEND=False \ -DENABLE_ROCFFT_BACKEND=True \ -DENABLE_ROCBLAS_BACKEND=True \ - -DENABLE_ROCSOLVER_BACKEND=True \ + -DENABLE_ROCSOLVER_BACKEND=True \ + -DENABLE_ROCSPARSE_BACKEND=True \ -DHIP_TARGETS=gfx90a \ -DBUILD_FUNCTIONAL_TESTS=False diff --git a/docs/domains/sparse_linear_algebra.rst b/docs/domains/sparse_linear_algebra.rst index 915151a36..0af792a3e 100644 --- a/docs/domains/sparse_linear_algebra.rst +++ b/docs/domains/sparse_linear_algebra.rst @@ -68,6 +68,31 @@ Currently known limitations: ``cusparseSpMV_preprocess``. Feel free to create an issue if this is needed. +rocSPARSE backend +---------------- + +Currently known limitations: + +- Using ``spmv`` with a ``type_view`` other than ``matrix_descr::general`` will + throw a ``oneapi::math::unimplemented`` exception. +- The COO format requires the indices to be sorted by row then by column. See + the `rocSPARSE COO documentation + `_. + Sparse operations using matrices with the COO format without the property + ``matrix_property::sorted`` will throw a ``oneapi::math::unimplemented`` + exception. +- The CSR format requires the column indices to be sorted within each row. See + the `rocSPARSE CSR documentation + `_. + Sparse operations using matrices with the CSR format without the property + ``matrix_property::sorted`` will throw a ``oneapi::math::unimplemented`` + exception. +- The same sparse matrix handle cannot be reused for multiple operations + ``spmm``, ``spmv``, or ``spsv``. Doing so will throw a + ``oneapi::math::unimplemented`` exception. See `#332 + `_. + + Operation algorithms mapping ---------------------------- @@ -89,33 +114,43 @@ spmm * - ``spmm_alg`` value - MKLCPU/MKLGPU - cuSPARSE + - rocSPARSE * - ``default_alg`` - none - ``CUSPARSE_SPMM_ALG_DEFAULT`` + - ``rocsparse_spmm_alg_default`` * - ``no_optimize_alg`` - none - ``CUSPARSE_SPMM_ALG_DEFAULT`` + - ``rocsparse_spmm_alg_default`` * - ``coo_alg1`` - none - ``CUSPARSE_SPMM_COO_ALG1`` + - ``rocsparse_spmm_alg_coo_segmented`` * - ``coo_alg2`` - none - ``CUSPARSE_SPMM_COO_ALG2`` + - ``rocsparse_spmm_alg_coo_atomic`` * - ``coo_alg3`` - none - ``CUSPARSE_SPMM_COO_ALG3`` + - ``rocsparse_spmm_alg_coo_segmented_atomic`` * - ``coo_alg4`` - none - ``CUSPARSE_SPMM_COO_ALG4`` + - ``rocsparse_spmm_alg_default`` * - ``csr_alg1`` - none - ``CUSPARSE_SPMM_CSR_ALG1`` + - ``rocsparse_spmm_alg_csr`` * - ``csr_alg2`` - none - ``CUSPARSE_SPMM_CSR_ALG2`` + - ``rocsparse_spmm_alg_csr_row_split`` * - ``csr_alg3`` - none - ``CUSPARSE_SPMM_CSR_ALG3`` + - ``rocsparse_spmm_alg_csr_merge`` spmv @@ -128,27 +163,35 @@ spmv * - ``spmv_alg`` value - MKLCPU/MKLGPU - cuSPARSE + - rocSPARSE * - ``default_alg`` - none - ``CUSPARSE_SPMV_ALG_DEFAULT`` + - ``rocsparse_spmv_alg_default`` * - ``no_optimize_alg`` - none - ``CUSPARSE_SPMV_ALG_DEFAULT`` + - ``rocsparse_spmv_alg_default`` * - ``coo_alg1`` - none - ``CUSPARSE_SPMV_COO_ALG1`` + - ``rocsparse_spmv_alg_coo`` * - ``coo_alg2`` - none - ``CUSPARSE_SPMV_COO_ALG2`` + - ``rocsparse_spmv_alg_coo_atomic`` * - ``csr_alg1`` - none - ``CUSPARSE_SPMV_CSR_ALG1`` + - ``rocsparse_spmv_alg_csr_adaptive`` * - ``csr_alg2`` - none - ``CUSPARSE_SPMV_CSR_ALG2`` + - ``rocsparse_spmv_alg_csr_stream`` * - ``csr_alg3`` - none - ``CUSPARSE_SPMV_ALG_DEFAULT`` + - ``rocsparse_spmv_alg_csr_lrb`` spsv @@ -161,9 +204,12 @@ spsv * - ``spsv_alg`` value - MKLCPU/MKLGPU - cuSPARSE + - rocSPARSE * - ``default_alg`` - none - ``CUSPARSE_SPSV_ALG_DEFAULT`` + - ``rocsparse_spsv_alg_default`` * - ``no_optimize_alg`` - none - ``CUSPARSE_SPSV_ALG_DEFAULT`` + - ``rocsparse_spsv_alg_default`` diff --git a/examples/sparse_blas/run_time_dispatching/CMakeLists.txt b/examples/sparse_blas/run_time_dispatching/CMakeLists.txt index 97bd407fb..1274b8c6b 100644 --- a/examples/sparse_blas/run_time_dispatching/CMakeLists.txt +++ b/examples/sparse_blas/run_time_dispatching/CMakeLists.txt @@ -36,6 +36,9 @@ endif() if(ENABLE_CUSPARSE_BACKEND) list(APPEND DEVICE_FILTERS "cuda:gpu") endif() +if(ENABLE_ROCSPARSE_BACKEND) + list(APPEND DEVICE_FILTERS "hip:gpu") +endif() message(STATUS "ONEAPI_DEVICE_SELECTOR will be set to the following value(s): [${DEVICE_FILTERS}] for run-time dispatching examples") diff --git a/examples/sparse_blas/run_time_dispatching/sparse_blas_spmv_usm.cpp b/examples/sparse_blas/run_time_dispatching/sparse_blas_spmv_usm.cpp index a6ff30354..5dbee489e 100644 --- a/examples/sparse_blas/run_time_dispatching/sparse_blas_spmv_usm.cpp +++ b/examples/sparse_blas/run_time_dispatching/sparse_blas_spmv_usm.cpp @@ -146,6 +146,11 @@ int run_sparse_matrix_vector_multiply_example(const sycl::device& dev) { oneapi::math::sparse::init_csr_matrix(main_queue, &A_handle, nrows, nrows, nnz, oneapi::math::index_base::zero, ia, ja, a); + // rocSPARSE backend requires that the property sorted is set when using matrices in CSR format. + // Setting this property is also the best practice to get best performance. + oneapi::math::sparse::set_matrix_property(main_queue, A_handle, + oneapi::math::sparse::matrix_property::sorted); + // Create and initialize dense vector handles oneapi::math::sparse::dense_vector_handle_t x_handle = nullptr; oneapi::math::sparse::dense_vector_handle_t y_handle = nullptr; diff --git a/include/oneapi/math/detail/backends.hpp b/include/oneapi/math/detail/backends.hpp index e8f66e021..58f61e4c1 100644 --- a/include/oneapi/math/detail/backends.hpp +++ b/include/oneapi/math/detail/backends.hpp @@ -41,6 +41,7 @@ enum class backend { rocfft, portfft, cusparse, + rocsparse, unsupported }; @@ -63,6 +64,7 @@ static backendmap backend_map = { { backend::mklcpu, "mklcpu" }, { backend::rocfft, "rocfft" }, { backend::portfft, "portfft" }, { backend::cusparse, "cusparse" }, + { backend::rocsparse, "rocsparse" }, { backend::unsupported, "unsupported" } }; // clang-format on diff --git a/include/oneapi/math/detail/backends_table.hpp b/include/oneapi/math/detail/backends_table.hpp index 1b7e1d723..3e83b070f 100644 --- a/include/oneapi/math/detail/backends_table.hpp +++ b/include/oneapi/math/detail/backends_table.hpp @@ -204,6 +204,12 @@ static std::map>> libraries = { #ifdef ONEMATH_ENABLE_CUSPARSE_BACKEND LIB_NAME("sparse_blas_cusparse") +#endif + } }, + { device::amdgpu, + { +#ifdef ONEMATH_ENABLE_ROCSPARSE_BACKEND + LIB_NAME("sparse_blas_rocsparse") #endif } } } }, }; diff --git a/include/oneapi/math/sparse_blas.hpp b/include/oneapi/math/sparse_blas.hpp index ee9735374..59a203070 100644 --- a/include/oneapi/math/sparse_blas.hpp +++ b/include/oneapi/math/sparse_blas.hpp @@ -37,6 +37,9 @@ #ifdef ONEMATH_ENABLE_CUSPARSE_BACKEND #include "sparse_blas/detail/cusparse/sparse_blas_ct.hpp" #endif +#ifdef ONEMATH_ENABLE_ROCSPARSE_BACKEND +#include "sparse_blas/detail/rocsparse/sparse_blas_ct.hpp" +#endif #include "sparse_blas/detail/sparse_blas_rt.hpp" diff --git a/include/oneapi/math/sparse_blas/detail/rocsparse/onemath_sparse_blas_rocsparse.hpp b/include/oneapi/math/sparse_blas/detail/rocsparse/onemath_sparse_blas_rocsparse.hpp new file mode 100644 index 000000000..2e727996a --- /dev/null +++ b/include/oneapi/math/sparse_blas/detail/rocsparse/onemath_sparse_blas_rocsparse.hpp @@ -0,0 +1,33 @@ +/*************************************************************************** +* Copyright (C) Codeplay Software Limited +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* For your convenience, a copy of the License has been included in this +* repository. +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +* +**************************************************************************/ + +#ifndef _ONEMATH_SPARSE_BLAS_DETAIL_ROCSPARSE_ONEMATH_SPARSE_BLAS_ROCSPARSE_HPP_ +#define _ONEMATH_SPARSE_BLAS_DETAIL_ROCSPARSE_ONEMATH_SPARSE_BLAS_ROCSPARSE_HPP_ + +#include "oneapi/math/detail/export.hpp" +#include "oneapi/math/sparse_blas/detail/helper_types.hpp" +#include "oneapi/math/sparse_blas/types.hpp" + +namespace oneapi::math::sparse::rocsparse { + +#include "oneapi/math/sparse_blas/detail/onemath_sparse_blas_backends.hxx" + +} // namespace oneapi::math::sparse::rocsparse + +#endif // _ONEMATH_SPARSE_BLAS_DETAIL_ROCSPARSE_ONEMATH_SPARSE_BLAS_ROCSPARSE_HPP_ diff --git a/include/oneapi/math/sparse_blas/detail/rocsparse/sparse_blas_ct.hpp b/include/oneapi/math/sparse_blas/detail/rocsparse/sparse_blas_ct.hpp new file mode 100644 index 000000000..fbf7d46aa --- /dev/null +++ b/include/oneapi/math/sparse_blas/detail/rocsparse/sparse_blas_ct.hpp @@ -0,0 +1,40 @@ +/*************************************************************************** +* Copyright (C) Codeplay Software Limited +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* For your convenience, a copy of the License has been included in this +* repository. +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +* +**************************************************************************/ + +#ifndef _ONEMATH_SPARSE_BLAS_DETAIL_ROCSPARSE_SPARSE_BLAS_CT_HPP_ +#define _ONEMATH_SPARSE_BLAS_DETAIL_ROCSPARSE_SPARSE_BLAS_CT_HPP_ + +#include "oneapi/math/detail/backends.hpp" +#include "oneapi/math/detail/backend_selector.hpp" + +#include "onemath_sparse_blas_rocsparse.hpp" + +namespace oneapi { +namespace math { +namespace sparse { + +#define BACKEND rocsparse +#include "oneapi/math/sparse_blas/detail/sparse_blas_ct.hxx" +#undef BACKEND + +} //namespace sparse +} //namespace math +} //namespace oneapi + +#endif // _ONEMATH_SPARSE_BLAS_DETAIL_ROCSPARSE_SPARSE_BLAS_CT_HPP_ diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 9ff24a721..ba0374e3f 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -95,6 +95,7 @@ function(generate_header_file) set(ONEMATH_ENABLE_ROCFFT_BACKEND ${ENABLE_ROCFFT_BACKEND}) set(ONEMATH_ENABLE_PORTFFT_BACKEND ${ENABLE_PORTFFT_BACKEND}) set(ONEMATH_ENABLE_CUSPARSE_BACKEND ${ENABLE_CUSPARSE_BACKEND}) + set(ONEMATH_ENABLE_ROCSPARSE_BACKEND ${ENABLE_ROCSPARSE_BACKEND}) configure_file(config.hpp.in "${CMAKE_CURRENT_BINARY_DIR}/oneapi/math/config.hpp.configured") file(GENERATE diff --git a/src/config.hpp.in b/src/config.hpp.in index 63393e2d8..3fde31c79 100644 --- a/src/config.hpp.in +++ b/src/config.hpp.in @@ -38,6 +38,7 @@ #cmakedefine ONEMATH_ENABLE_ROCFFT_BACKEND #cmakedefine ONEMATH_ENABLE_ROCRAND_BACKEND #cmakedefine ONEMATH_ENABLE_ROCSOLVER_BACKEND +#cmakedefine ONEMATH_ENABLE_ROCSPARSE_BACKEND #cmakedefine ONEMATH_BUILD_SHARED_LIBS #endif diff --git a/src/sparse_blas/backends/CMakeLists.txt b/src/sparse_blas/backends/CMakeLists.txt index 405c79ce7..6663c1365 100644 --- a/src/sparse_blas/backends/CMakeLists.txt +++ b/src/sparse_blas/backends/CMakeLists.txt @@ -31,3 +31,7 @@ endif() if(ENABLE_CUSPARSE_BACKEND) add_subdirectory(cusparse) endif() + +if(ENABLE_ROCSPARSE_BACKEND) + add_subdirectory(rocsparse) +endif() diff --git a/src/sparse_blas/backends/common_launch_task.hpp b/src/sparse_blas/backends/common_launch_task.hpp new file mode 100644 index 000000000..fbd5fa200 --- /dev/null +++ b/src/sparse_blas/backends/common_launch_task.hpp @@ -0,0 +1,413 @@ +/*************************************************************************** +* Copyright (C) Codeplay Software Limited +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* For your convenience, a copy of the License has been included in this +* repository. +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +* +**************************************************************************/ + +#ifndef _ONEMATH_SPARSE_BLAS_BACKENDS_COMMON_LAUNCH_TASK_HPP_ +#define _ONEMATH_SPARSE_BLAS_BACKENDS_COMMON_LAUNCH_TASK_HPP_ + +/// This file provide a helper function to submit host_task using buffers or USM seamlessly + +namespace oneapi::math::sparse::detail { + +template +auto get_value_accessor(sycl::handler& cgh, Container container) { + auto buffer_ptr = + reinterpret_cast*>(container->value_container.buffer_ptr.get()); + return buffer_ptr->template get_access(cgh); +} + +template +auto get_fp_accessors(sycl::handler& cgh, Ts... containers) { + return std::array, sizeof...(containers)>{ get_value_accessor( + cgh, containers)... }; +} + +template +auto get_row_accessor(sycl::handler& cgh, matrix_handle_t smhandle) { + auto buffer_ptr = + reinterpret_cast*>(smhandle->row_container.buffer_ptr.get()); + return buffer_ptr->template get_access(cgh); +} + +template +auto get_col_accessor(sycl::handler& cgh, matrix_handle_t smhandle) { + auto buffer_ptr = + reinterpret_cast*>(smhandle->col_container.buffer_ptr.get()); + return buffer_ptr->template get_access(cgh); +} + +template +auto get_int_accessors(sycl::handler& cgh, matrix_handle_t smhandle) { + return std::array, 2>{ get_row_accessor(cgh, smhandle), + get_col_accessor(cgh, smhandle) }; +} + +template +void submit_host_task(sycl::handler& cgh, sycl::queue& queue, Functor functor, + CaptureOnlyAcc... capture_only_accessors) { + // Only capture the accessors to ensure the dependencies are properly + // handled. The accessors's pointer have already been set to the native + // container types in previous functions. This assumes the underlying + // pointer of the buffer does not change. This is not guaranteed by the SYCL + // specification but should be true for all the implementations. This + // assumption avoids the overhead of resetting the pointer of all data + // handles for each enqueued command. + cgh.host_task([functor, queue, capture_only_accessors...](sycl::interop_handle ih) { + auto unused = std::make_tuple(capture_only_accessors...); + (void)unused; + functor(ih); + }); +} + +template +void submit_host_task_with_acc(sycl::handler& cgh, sycl::queue& queue, Functor functor, + sycl::accessor workspace_acc, + CaptureOnlyAcc... capture_only_accessors) { + // Only capture the accessors to ensure the dependencies are properly + // handled. The accessors's pointer have already been set to the native + // container types in previous functions. This assumes the underlying + // pointer of the buffer does not change. This is not guaranteed by the SYCL + // specification but should be true for all the implementations. This + // assumption avoids the overhead of resetting the pointer of all data + // handles for each enqueued command. + cgh.host_task( + [functor, queue, workspace_acc, capture_only_accessors...](sycl::interop_handle ih) { + auto unused = std::make_tuple(capture_only_accessors...); + (void)unused; + functor(ih, workspace_acc); + }); +} + +template +void submit_native_command_ext(sycl::handler& cgh, sycl::queue& queue, Functor functor, + const std::vector& dependencies, + CaptureOnlyAcc... capture_only_accessors) { + // Only capture the accessors to ensure the dependencies are properly + // handled. The accessors's pointer have already been set to the native + // container types in previous functions. This assumes the underlying + // pointer of the buffer does not change. This is not guaranteed by the SYCL + // specification but should be true for all the implementations. This + // assumption avoids the overhead of resetting the pointer of all data + // handles for each enqueued command. +#ifdef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND + cgh.ext_codeplay_enqueue_native_command( + [functor, queue, dependencies, capture_only_accessors...](sycl::interop_handle ih) { + auto unused = std::make_tuple(capture_only_accessors...); + (void)unused; + // The functor using ext_codeplay_enqueue_native_command need to + // explicitly wait on the events for the SPARSE domain. The + // extension ext_codeplay_enqueue_native_command is used to launch + // the compute operation which depends on the previous optimize + // step. In cuSPARSE the optimize step is synchronous but it is + // asynchronous in oneMath. The optimize step may not use the CUDA + // stream which would make it impossible for + // ext_codeplay_enqueue_native_command to automatically ensure it + // has completed before the compute function starts. These waits are + // used to ensure the optimize step has completed before starting + // the computation. + for (auto event : dependencies) { + event.wait(); + } + functor(ih); + }); +#else + (void)dependencies; + submit_host_task(cgh, queue, functor, capture_only_accessors...); +#endif +} + +template +void submit_native_command_ext_with_acc(sycl::handler& cgh, sycl::queue& queue, Functor functor, + const std::vector& dependencies, + sycl::accessor workspace_acc, + CaptureOnlyAcc... capture_only_accessors) { + // Only capture the accessors to ensure the dependencies are properly + // handled. The accessors's pointer have already been set to the native + // container types in previous functions. This assumes the underlying + // pointer of the buffer does not change. This is not guaranteed by the SYCL + // specification but should be true for all the implementations. This + // assumption avoids the overhead of resetting the pointer of all data + // handles for each enqueued command. +#ifdef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND + cgh.ext_codeplay_enqueue_native_command([functor, queue, dependencies, workspace_acc, + capture_only_accessors...](sycl::interop_handle ih) { + auto unused = std::make_tuple(capture_only_accessors...); + (void)unused; + // The functor using ext_codeplay_enqueue_native_command need to + // explicitly wait on the events for the SPARSE domain. The extension + // ext_codeplay_enqueue_native_command is used to launch the compute + // operation which depends on the previous optimize step. In cuSPARSE + // the optimize step is synchronous but it is asynchronous in oneMath. + // The optimize step may not use the CUDA stream which would make it + // impossible for ext_codeplay_enqueue_native_command to automatically + // ensure it has completed before the compute function starts. These + // waits are used to ensure the optimize step has completed before + // starting the computation. + for (auto event : dependencies) { + event.wait(); + } + functor(ih, workspace_acc); + }); +#else + (void)dependencies; + submit_host_task_with_acc(cgh, queue, functor, workspace_acc, capture_only_accessors...); +#endif +} + +/// Helper submit functions to capture all accessors from the generic containers +/// \p other_containers and ensure the dependencies of buffers are respected. +/// The accessors are not directly used as the underlying data pointer has +/// already been captured in previous functions. +/// \p workspace_buffer is an optional buffer. Its accessor will be given to the +/// functor as a last argument if \p UseWorkspace is true. +/// \p UseWorkspace must be true to use the given \p workspace_buffer. +/// \p UseEnqueueNativeCommandExt controls whether host_task are used or the +/// extension ext_codeplay_enqueue_native_command is used to launch tasks. The +/// extension should only be used for asynchronous functions using native +/// backend's functions. The extension can only be used for in-order queues as +/// the same cuStream needs to be used for the 3 steps to run an operation: +/// querying the buffer size, optimizing and running the computation. This means +/// a different cuStream can be used inside the native_command than the native +/// cuStream used by the extension. +template +sycl::event dispatch_submit_impl_fp_int(const std::string& function_name, sycl::queue queue, + const std::vector& dependencies, + Functor functor, matrix_handle_t sm_handle, + sycl::buffer workspace_buffer, + Ts... other_containers) { + bool is_in_order_queue = queue.is_in_order(); + if (sm_handle->all_use_buffer()) { + data_type value_type = sm_handle->get_value_type(); + data_type int_type = sm_handle->get_int_type(); + +#define ONEMATH_SUBMIT(FP_TYPE, INT_TYPE) \ + return queue.submit([&](sycl::handler& cgh) { \ + cgh.depends_on(dependencies); \ + auto fp_accs = get_fp_accessors(cgh, sm_handle, other_containers...); \ + auto int_accs = get_int_accessors(cgh, sm_handle); \ + auto workspace_acc = workspace_buffer.get_access(cgh); \ + if constexpr (UseWorkspace) { \ + if constexpr (UseEnqueueNativeCommandExt) { \ + if (is_in_order_queue) { \ + submit_native_command_ext_with_acc(cgh, queue, functor, dependencies, \ + workspace_acc, fp_accs, int_accs); \ + } \ + else { \ + submit_host_task_with_acc(cgh, queue, functor, workspace_acc, fp_accs, \ + int_accs); \ + } \ + } \ + else { \ + submit_host_task_with_acc(cgh, queue, functor, workspace_acc, fp_accs, int_accs); \ + } \ + } \ + else { \ + (void)workspace_buffer; \ + if constexpr (UseEnqueueNativeCommandExt) { \ + if (is_in_order_queue) { \ + submit_native_command_ext(cgh, queue, functor, dependencies, fp_accs, \ + int_accs); \ + } \ + else { \ + submit_host_task(cgh, queue, functor, fp_accs, int_accs); \ + } \ + } \ + else { \ + submit_host_task(cgh, queue, functor, fp_accs, int_accs); \ + } \ + } \ + }) +#define ONEMATH_SUBMIT_INT(FP_TYPE) \ + if (int_type == data_type::int32) { \ + ONEMATH_SUBMIT(FP_TYPE, std::int32_t); \ + } \ + else if (int_type == data_type::int64) { \ + ONEMATH_SUBMIT(FP_TYPE, std::int64_t); \ + } + + if (value_type == data_type::real_fp32) { + ONEMATH_SUBMIT_INT(float) + } + else if (value_type == data_type::real_fp64) { + ONEMATH_SUBMIT_INT(double) + } + else if (value_type == data_type::complex_fp32) { + ONEMATH_SUBMIT_INT(std::complex) + } + else if (value_type == data_type::complex_fp64) { + ONEMATH_SUBMIT_INT(std::complex) + } + +#undef ONEMATH_SUBMIT_INT +#undef ONEMATH_SUBMIT + + throw oneapi::math::exception("sparse_blas", function_name, + "Could not dispatch buffer kernel to a supported type"); + } + else { + // USM submit does not need to capture accessors + if constexpr (!UseWorkspace) { + return queue.submit([&](sycl::handler& cgh) { + cgh.depends_on(dependencies); + if constexpr (UseEnqueueNativeCommandExt) { + if (is_in_order_queue) { + submit_native_command_ext(cgh, queue, functor, dependencies); + } + else { + submit_host_task(cgh, queue, functor); + } + } + else { + submit_host_task(cgh, queue, functor); + } + }); + } + else { + throw oneapi::math::exception("sparse_blas", function_name, + "Internal error: Cannot use accessor workspace with USM"); + } + } +} + +/// Similar to dispatch_submit_impl_fp_int but only dispatches the host_task based on the floating point value type. +template +sycl::event dispatch_submit_impl_fp(const std::string& function_name, sycl::queue queue, + const std::vector& dependencies, Functor functor, + ContainerT container_handle) { + if (container_handle->all_use_buffer()) { + data_type value_type = container_handle->get_value_type(); + +#define ONEMATH_SUBMIT(FP_TYPE) \ + return queue.submit([&](sycl::handler& cgh) { \ + cgh.depends_on(dependencies); \ + auto fp_accs = get_fp_accessors(cgh, container_handle); \ + submit_host_task(cgh, queue, functor, fp_accs); \ + }) + + if (value_type == data_type::real_fp32) { + ONEMATH_SUBMIT(float); + } + else if (value_type == data_type::real_fp64) { + ONEMATH_SUBMIT(double); + } + else if (value_type == data_type::complex_fp32) { + ONEMATH_SUBMIT(std::complex); + } + else if (value_type == data_type::complex_fp64) { + ONEMATH_SUBMIT(std::complex); + } + +#undef ONEMATH_SUBMIT + + throw oneapi::math::exception("sparse_blas", function_name, + "Could not dispatch buffer kernel to a supported type"); + } + else { + return queue.submit([&](sycl::handler& cgh) { + cgh.depends_on(dependencies); + submit_host_task(cgh, queue, functor); + }); + } +} + +/// Helper function for dispatch_submit_impl_fp_int +template +sycl::event dispatch_submit(const std::string& function_name, sycl::queue queue, Functor functor, + matrix_handle_t sm_handle, sycl::buffer workspace_buffer, + Ts... other_containers) { + constexpr bool UseWorkspace = true; + constexpr bool UseEnqueueNativeCommandExt = false; + return dispatch_submit_impl_fp_int( + function_name, queue, {}, functor, sm_handle, workspace_buffer, other_containers...); +} + +/// Helper function for dispatch_submit_impl_fp_int +template +sycl::event dispatch_submit(const std::string& function_name, sycl::queue queue, + const std::vector& dependencies, Functor functor, + matrix_handle_t sm_handle, Ts... other_containers) { + constexpr bool UseWorkspace = false; + constexpr bool UseEnqueueNativeCommandExt = false; + sycl::buffer no_workspace(sycl::range<1>(0)); + return dispatch_submit_impl_fp_int( + function_name, queue, dependencies, functor, sm_handle, no_workspace, other_containers...); +} + +/// Helper function for dispatch_submit_impl_fp_int +template +sycl::event dispatch_submit(const std::string& function_name, sycl::queue queue, Functor functor, + matrix_handle_t sm_handle, Ts... other_containers) { + constexpr bool UseWorkspace = false; + constexpr bool UseEnqueueNativeCommandExt = false; + sycl::buffer no_workspace(sycl::range<1>(0)); + return dispatch_submit_impl_fp_int( + function_name, queue, {}, functor, sm_handle, no_workspace, other_containers...); +} + +/// Helper function for dispatch_submit_impl_fp_int +template +sycl::event dispatch_submit_native_ext(const std::string& function_name, sycl::queue queue, + Functor functor, matrix_handle_t sm_handle, + sycl::buffer workspace_buffer, + Ts... other_containers) { + constexpr bool UseWorkspace = true; +#ifdef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND + constexpr bool UseEnqueueNativeCommandExt = true; +#else + constexpr bool UseEnqueueNativeCommandExt = false; +#endif + return dispatch_submit_impl_fp_int( + function_name, queue, {}, functor, sm_handle, workspace_buffer, other_containers...); +} + +/// Helper function for dispatch_submit_impl_fp_int +template +sycl::event dispatch_submit_native_ext(const std::string& function_name, sycl::queue queue, + const std::vector& dependencies, + Functor functor, matrix_handle_t sm_handle, + Ts... other_containers) { + constexpr bool UseWorkspace = false; +#ifdef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND + constexpr bool UseEnqueueNativeCommandExt = true; +#else + constexpr bool UseEnqueueNativeCommandExt = false; +#endif + sycl::buffer no_workspace(sycl::range<1>(0)); + return dispatch_submit_impl_fp_int( + function_name, queue, dependencies, functor, sm_handle, no_workspace, other_containers...); +} + +/// Helper function for dispatch_submit_impl_fp_int +template +sycl::event dispatch_submit_native_ext(const std::string& function_name, sycl::queue queue, + Functor functor, matrix_handle_t sm_handle, + Ts... other_containers) { + constexpr bool UseWorkspace = false; +#ifdef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND + constexpr bool UseEnqueueNativeCommandExt = true; +#else + constexpr bool UseEnqueueNativeCommandExt = false; +#endif + sycl::buffer no_workspace(sycl::range<1>(0)); + return dispatch_submit_impl_fp_int( + function_name, queue, {}, functor, sm_handle, no_workspace, other_containers...); +} + +} // namespace oneapi::math::sparse::detail + +#endif // _ONEMATH_SPARSE_BLAS_BACKENDS_COMMON_LAUNCH_TASK_HPP_ diff --git a/src/sparse_blas/backends/cusparse/cusparse_handles.cpp b/src/sparse_blas/backends/cusparse/cusparse_handles.cpp index 95a615b05..153de54de 100644 --- a/src/sparse_blas/backends/cusparse/cusparse_handles.cpp +++ b/src/sparse_blas/backends/cusparse/cusparse_handles.cpp @@ -22,6 +22,7 @@ #include "cusparse_error.hpp" #include "cusparse_helper.hpp" #include "cusparse_handles.hpp" +#include "cusparse_scope_handle.hpp" #include "cusparse_task.hpp" #include "sparse_blas/macros.hpp" diff --git a/src/sparse_blas/backends/cusparse/cusparse_task.hpp b/src/sparse_blas/backends/cusparse/cusparse_task.hpp index 043cfaaf8..fd5414a11 100644 --- a/src/sparse_blas/backends/cusparse/cusparse_task.hpp +++ b/src/sparse_blas/backends/cusparse/cusparse_task.hpp @@ -17,400 +17,14 @@ * **************************************************************************/ -#ifndef _ONEMATH_SPARSE_BLAS_BACKENDS_CUSPARSE_TASKS_HPP_ -#define _ONEMATH_SPARSE_BLAS_BACKENDS_CUSPARSE_TASKS_HPP_ +#ifndef _ONEMATH_SPARSE_BLAS_BACKENDS_CUSPARSE_TASK_HPP_ +#define _ONEMATH_SPARSE_BLAS_BACKENDS_CUSPARSE_TASK_HPP_ -#include "cusparse_handles.hpp" -#include "cusparse_scope_handle.hpp" - -/// This file provide a helper function to submit host_task using buffers or USM seamlessly +#include "cusparse_error.hpp" +#include "sparse_blas/backends/common_launch_task.hpp" namespace oneapi::math::sparse::cusparse::detail { -template -auto get_value_accessor(sycl::handler& cgh, Container container) { - auto buffer_ptr = - reinterpret_cast*>(container->value_container.buffer_ptr.get()); - return buffer_ptr->template get_access(cgh); -} - -template -auto get_fp_accessors(sycl::handler& cgh, Ts... containers) { - return std::array, sizeof...(containers)>{ get_value_accessor( - cgh, containers)... }; -} - -template -auto get_row_accessor(sycl::handler& cgh, matrix_handle_t smhandle) { - auto buffer_ptr = - reinterpret_cast*>(smhandle->row_container.buffer_ptr.get()); - return buffer_ptr->template get_access(cgh); -} - -template -auto get_col_accessor(sycl::handler& cgh, matrix_handle_t smhandle) { - auto buffer_ptr = - reinterpret_cast*>(smhandle->col_container.buffer_ptr.get()); - return buffer_ptr->template get_access(cgh); -} - -template -auto get_int_accessors(sycl::handler& cgh, matrix_handle_t smhandle) { - return std::array, 2>{ get_row_accessor(cgh, smhandle), - get_col_accessor(cgh, smhandle) }; -} - -template -void submit_host_task(sycl::handler& cgh, sycl::queue& queue, Functor functor, - CaptureOnlyAcc... capture_only_accessors) { - // Only capture the accessors to ensure the dependencies are properly - // handled. The accessors's pointer have already been set to the native - // container types in previous functions. This assumes the underlying - // pointer of the buffer does not change. This is not guaranteed by the SYCL - // specification but should be true for all the implementations. This - // assumption avoids the overhead of resetting the pointer of all data - // handles for each enqueued command. - cgh.host_task([functor, queue, capture_only_accessors...](sycl::interop_handle ih) { - auto unused = std::make_tuple(capture_only_accessors...); - (void)unused; - functor(ih); - }); -} - -template -void submit_host_task_with_acc(sycl::handler& cgh, sycl::queue& queue, Functor functor, - sycl::accessor workspace_acc, - CaptureOnlyAcc... capture_only_accessors) { - // Only capture the accessors to ensure the dependencies are properly - // handled. The accessors's pointer have already been set to the native - // container types in previous functions. This assumes the underlying - // pointer of the buffer does not change. This is not guaranteed by the SYCL - // specification but should be true for all the implementations. This - // assumption avoids the overhead of resetting the pointer of all data - // handles for each enqueued command. - cgh.host_task( - [functor, queue, workspace_acc, capture_only_accessors...](sycl::interop_handle ih) { - auto unused = std::make_tuple(capture_only_accessors...); - (void)unused; - functor(ih, workspace_acc); - }); -} - -template -void submit_native_command_ext(sycl::handler& cgh, sycl::queue& queue, Functor functor, - const std::vector& dependencies, - CaptureOnlyAcc... capture_only_accessors) { - // Only capture the accessors to ensure the dependencies are properly - // handled. The accessors's pointer have already been set to the native - // container types in previous functions. This assumes the underlying - // pointer of the buffer does not change. This is not guaranteed by the SYCL - // specification but should be true for all the implementations. This - // assumption avoids the overhead of resetting the pointer of all data - // handles for each enqueued command. -#ifdef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND - cgh.ext_codeplay_enqueue_native_command( - [functor, queue, dependencies, capture_only_accessors...](sycl::interop_handle ih) { - auto unused = std::make_tuple(capture_only_accessors...); - (void)unused; - // The functor using ext_codeplay_enqueue_native_command need to - // explicitly wait on the events for the SPARSE domain. The - // extension ext_codeplay_enqueue_native_command is used to launch - // the compute operation which depends on the previous optimize - // step. In cuSPARSE the optimize step is synchronous but it is - // asynchronous in oneMath. The optimize step may not use the CUDA - // stream which would make it impossible for - // ext_codeplay_enqueue_native_command to automatically ensure it - // has completed before the compute function starts. These waits are - // used to ensure the optimize step has completed before starting - // the computation. - for (auto event : dependencies) { - event.wait(); - } - functor(ih); - }); -#else - (void)dependencies; - submit_host_task(cgh, queue, functor, capture_only_accessors...); -#endif -} - -template -void submit_native_command_ext_with_acc(sycl::handler& cgh, sycl::queue& queue, Functor functor, - const std::vector& dependencies, - sycl::accessor workspace_acc, - CaptureOnlyAcc... capture_only_accessors) { - // Only capture the accessors to ensure the dependencies are properly - // handled. The accessors's pointer have already been set to the native - // container types in previous functions. This assumes the underlying - // pointer of the buffer does not change. This is not guaranteed by the SYCL - // specification but should be true for all the implementations. This - // assumption avoids the overhead of resetting the pointer of all data - // handles for each enqueued command. -#ifdef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND - cgh.ext_codeplay_enqueue_native_command([functor, queue, dependencies, workspace_acc, - capture_only_accessors...](sycl::interop_handle ih) { - auto unused = std::make_tuple(capture_only_accessors...); - (void)unused; - // The functor using ext_codeplay_enqueue_native_command need to - // explicitly wait on the events for the SPARSE domain. The extension - // ext_codeplay_enqueue_native_command is used to launch the compute - // operation which depends on the previous optimize step. In cuSPARSE - // the optimize step is synchronous but it is asynchronous in oneMath. - // The optimize step may not use the CUDA stream which would make it - // impossible for ext_codeplay_enqueue_native_command to automatically - // ensure it has completed before the compute function starts. These - // waits are used to ensure the optimize step has completed before - // starting the computation. - for (auto event : dependencies) { - event.wait(); - } - functor(ih, workspace_acc); - }); -#else - (void)dependencies; - submit_host_task_with_acc(cgh, queue, functor, workspace_acc, capture_only_accessors...); -#endif -} - -/// Helper submit functions to capture all accessors from the generic containers -/// \p other_containers and ensure the dependencies of buffers are respected. -/// The accessors are not directly used as the underlying data pointer has -/// already been captured in previous functions. -/// \p workspace_buffer is an optional buffer. Its accessor will be given to the -/// functor as a last argument if \p UseWorkspace is true. -/// \p UseWorkspace must be true to use the given \p workspace_buffer. -/// \p UseEnqueueNativeCommandExt controls whether host_task are used or the -/// extension ext_codeplay_enqueue_native_command is used to launch tasks. The -/// extension should only be used for asynchronous functions using native -/// backend's functions. The extension can only be used for in-order queues as -/// the same cuStream needs to be used for the 3 steps to run an operation: -/// querying the buffer size, optimizing and running the computation. This means -/// a different cuStream can be used inside the native_command than the native -/// cuStream used by the extension. -template -sycl::event dispatch_submit_impl_fp_int(const std::string& function_name, sycl::queue queue, - const std::vector& dependencies, - Functor functor, matrix_handle_t sm_handle, - sycl::buffer workspace_buffer, - Ts... other_containers) { - bool is_in_order_queue = queue.is_in_order(); - if (sm_handle->all_use_buffer()) { - data_type value_type = sm_handle->get_value_type(); - data_type int_type = sm_handle->get_int_type(); - -#define ONEMATH_CUSPARSE_SUBMIT(FP_TYPE, INT_TYPE) \ - return queue.submit([&](sycl::handler& cgh) { \ - cgh.depends_on(dependencies); \ - auto fp_accs = get_fp_accessors(cgh, sm_handle, other_containers...); \ - auto int_accs = get_int_accessors(cgh, sm_handle); \ - auto workspace_acc = workspace_buffer.get_access(cgh); \ - if constexpr (UseWorkspace) { \ - if constexpr (UseEnqueueNativeCommandExt) { \ - if (is_in_order_queue) { \ - submit_native_command_ext_with_acc(cgh, queue, functor, dependencies, \ - workspace_acc, fp_accs, int_accs); \ - } \ - else { \ - submit_host_task_with_acc(cgh, queue, functor, workspace_acc, fp_accs, \ - int_accs); \ - } \ - } \ - else { \ - submit_host_task_with_acc(cgh, queue, functor, workspace_acc, fp_accs, int_accs); \ - } \ - } \ - else { \ - (void)workspace_buffer; \ - if constexpr (UseEnqueueNativeCommandExt) { \ - if (is_in_order_queue) { \ - submit_native_command_ext(cgh, queue, functor, dependencies, fp_accs, \ - int_accs); \ - } \ - else { \ - submit_host_task(cgh, queue, functor, fp_accs, int_accs); \ - } \ - } \ - else { \ - submit_host_task(cgh, queue, functor, fp_accs, int_accs); \ - } \ - } \ - }) -#define ONEMATH_CUSPARSE_SUBMIT_INT(FP_TYPE) \ - if (int_type == data_type::int32) { \ - ONEMATH_CUSPARSE_SUBMIT(FP_TYPE, std::int32_t); \ - } \ - else if (int_type == data_type::int64) { \ - ONEMATH_CUSPARSE_SUBMIT(FP_TYPE, std::int64_t); \ - } - - if (value_type == data_type::real_fp32) { - ONEMATH_CUSPARSE_SUBMIT_INT(float) - } - else if (value_type == data_type::real_fp64) { - ONEMATH_CUSPARSE_SUBMIT_INT(double) - } - else if (value_type == data_type::complex_fp32) { - ONEMATH_CUSPARSE_SUBMIT_INT(std::complex) - } - else if (value_type == data_type::complex_fp64) { - ONEMATH_CUSPARSE_SUBMIT_INT(std::complex) - } - -#undef ONEMATH_CUSPARSE_SUBMIT_INT -#undef ONEMATH_CUSPARSE_SUBMIT - - throw oneapi::math::exception("sparse_blas", function_name, - "Could not dispatch buffer kernel to a supported type"); - } - else { - // USM submit does not need to capture accessors - if constexpr (!UseWorkspace) { - return queue.submit([&](sycl::handler& cgh) { - cgh.depends_on(dependencies); - if constexpr (UseEnqueueNativeCommandExt) { - if (is_in_order_queue) { - submit_native_command_ext(cgh, queue, functor, dependencies); - } - else { - submit_host_task(cgh, queue, functor); - } - } - else { - submit_host_task(cgh, queue, functor); - } - }); - } - else { - throw oneapi::math::exception("sparse_blas", function_name, - "Internal error: Cannot use accessor workspace with USM"); - } - } -} - -/// Similar to dispatch_submit_impl_fp_int but only dispatches the host_task based on the floating point value type. -template -sycl::event dispatch_submit_impl_fp(const std::string& function_name, sycl::queue queue, - const std::vector& dependencies, Functor functor, - ContainerT container_handle) { - if (container_handle->all_use_buffer()) { - data_type value_type = container_handle->get_value_type(); - -#define ONEMATH_CUSPARSE_SUBMIT(FP_TYPE) \ - return queue.submit([&](sycl::handler& cgh) { \ - cgh.depends_on(dependencies); \ - auto fp_accs = get_fp_accessors(cgh, container_handle); \ - submit_host_task(cgh, queue, functor, fp_accs); \ - }) - - if (value_type == data_type::real_fp32) { - ONEMATH_CUSPARSE_SUBMIT(float); - } - else if (value_type == data_type::real_fp64) { - ONEMATH_CUSPARSE_SUBMIT(double); - } - else if (value_type == data_type::complex_fp32) { - ONEMATH_CUSPARSE_SUBMIT(std::complex); - } - else if (value_type == data_type::complex_fp64) { - ONEMATH_CUSPARSE_SUBMIT(std::complex); - } - -#undef ONEMATH_CUSPARSE_SUBMIT - - throw oneapi::math::exception("sparse_blas", function_name, - "Could not dispatch buffer kernel to a supported type"); - } - else { - return queue.submit([&](sycl::handler& cgh) { - cgh.depends_on(dependencies); - submit_host_task(cgh, queue, functor); - }); - } -} - -/// Helper function for dispatch_submit_impl_fp_int -template -sycl::event dispatch_submit(const std::string& function_name, sycl::queue queue, Functor functor, - matrix_handle_t sm_handle, sycl::buffer workspace_buffer, - Ts... other_containers) { - constexpr bool UseWorkspace = true; - constexpr bool UseEnqueueNativeCommandExt = false; - return dispatch_submit_impl_fp_int( - function_name, queue, {}, functor, sm_handle, workspace_buffer, other_containers...); -} - -/// Helper function for dispatch_submit_impl_fp_int -template -sycl::event dispatch_submit(const std::string& function_name, sycl::queue queue, - const std::vector& dependencies, Functor functor, - matrix_handle_t sm_handle, Ts... other_containers) { - constexpr bool UseWorkspace = false; - constexpr bool UseEnqueueNativeCommandExt = false; - sycl::buffer no_workspace(sycl::range<1>(0)); - return dispatch_submit_impl_fp_int( - function_name, queue, dependencies, functor, sm_handle, no_workspace, other_containers...); -} - -/// Helper function for dispatch_submit_impl_fp_int -template -sycl::event dispatch_submit(const std::string& function_name, sycl::queue queue, Functor functor, - matrix_handle_t sm_handle, Ts... other_containers) { - constexpr bool UseWorkspace = false; - constexpr bool UseEnqueueNativeCommandExt = false; - sycl::buffer no_workspace(sycl::range<1>(0)); - return dispatch_submit_impl_fp_int( - function_name, queue, {}, functor, sm_handle, no_workspace, other_containers...); -} - -/// Helper function for dispatch_submit_impl_fp_int -template -sycl::event dispatch_submit_native_ext(const std::string& function_name, sycl::queue queue, - Functor functor, matrix_handle_t sm_handle, - sycl::buffer workspace_buffer, - Ts... other_containers) { - constexpr bool UseWorkspace = true; -#ifdef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND - constexpr bool UseEnqueueNativeCommandExt = true; -#else - constexpr bool UseEnqueueNativeCommandExt = false; -#endif - return dispatch_submit_impl_fp_int( - function_name, queue, {}, functor, sm_handle, workspace_buffer, other_containers...); -} - -/// Helper function for dispatch_submit_impl_fp_int -template -sycl::event dispatch_submit_native_ext(const std::string& function_name, sycl::queue queue, - const std::vector& dependencies, - Functor functor, matrix_handle_t sm_handle, - Ts... other_containers) { - constexpr bool UseWorkspace = false; -#ifdef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND - constexpr bool UseEnqueueNativeCommandExt = true; -#else - constexpr bool UseEnqueueNativeCommandExt = false; -#endif - sycl::buffer no_workspace(sycl::range<1>(0)); - return dispatch_submit_impl_fp_int( - function_name, queue, dependencies, functor, sm_handle, no_workspace, other_containers...); -} - -/// Helper function for dispatch_submit_impl_fp_int -template -sycl::event dispatch_submit_native_ext(const std::string& function_name, sycl::queue queue, - Functor functor, matrix_handle_t sm_handle, - Ts... other_containers) { - constexpr bool UseWorkspace = false; -#ifdef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND - constexpr bool UseEnqueueNativeCommandExt = true; -#else - constexpr bool UseEnqueueNativeCommandExt = false; -#endif - sycl::buffer no_workspace(sycl::range<1>(0)); - return dispatch_submit_impl_fp_int( - function_name, queue, {}, functor, sm_handle, no_workspace, other_containers...); -} - // Helper function for functors submitted to host_task or native_command. // When the extension is disabled, host_task are used and the synchronization is needed to ensure the sycl::event corresponds to the end of the whole functor. // When the extension is enabled, host_task are still used for out-of-order queues, see description of dispatch_submit_impl_fp_int. @@ -427,4 +41,4 @@ inline void synchronize_if_needed(bool is_in_order_queue, CUstream cu_stream) { } // namespace oneapi::math::sparse::cusparse::detail -#endif // _ONEMATH_SPARSE_BLAS_BACKENDS_CUSPARSE_TASKS_HPP_ +#endif // _ONEMATH_SPARSE_BLAS_BACKENDS_CUSPARSE_TASK_HPP_ diff --git a/src/sparse_blas/backends/cusparse/operations/cusparse_spmm.cpp b/src/sparse_blas/backends/cusparse/operations/cusparse_spmm.cpp index e4131bec6..e517d66e0 100644 --- a/src/sparse_blas/backends/cusparse/operations/cusparse_spmm.cpp +++ b/src/sparse_blas/backends/cusparse/operations/cusparse_spmm.cpp @@ -20,9 +20,10 @@ #include "oneapi/math/sparse_blas/detail/cusparse/onemath_sparse_blas_cusparse.hpp" #include "sparse_blas/backends/cusparse/cusparse_error.hpp" +#include "sparse_blas/backends/cusparse/cusparse_handles.hpp" #include "sparse_blas/backends/cusparse/cusparse_helper.hpp" #include "sparse_blas/backends/cusparse/cusparse_task.hpp" -#include "sparse_blas/backends/cusparse/cusparse_handles.hpp" +#include "sparse_blas/backends/cusparse/cusparse_scope_handle.hpp" #include "sparse_blas/common_op_verification.hpp" #include "sparse_blas/macros.hpp" #include "sparse_blas/matrix_view_comparison.hpp" diff --git a/src/sparse_blas/backends/cusparse/operations/cusparse_spmv.cpp b/src/sparse_blas/backends/cusparse/operations/cusparse_spmv.cpp index 2af4a4e98..63d8339ab 100644 --- a/src/sparse_blas/backends/cusparse/operations/cusparse_spmv.cpp +++ b/src/sparse_blas/backends/cusparse/operations/cusparse_spmv.cpp @@ -20,9 +20,10 @@ #include "oneapi/math/sparse_blas/detail/cusparse/onemath_sparse_blas_cusparse.hpp" #include "sparse_blas/backends/cusparse/cusparse_error.hpp" +#include "sparse_blas/backends/cusparse/cusparse_handles.hpp" #include "sparse_blas/backends/cusparse/cusparse_helper.hpp" #include "sparse_blas/backends/cusparse/cusparse_task.hpp" -#include "sparse_blas/backends/cusparse/cusparse_handles.hpp" +#include "sparse_blas/backends/cusparse/cusparse_scope_handle.hpp" #include "sparse_blas/common_op_verification.hpp" #include "sparse_blas/macros.hpp" #include "sparse_blas/matrix_view_comparison.hpp" diff --git a/src/sparse_blas/backends/cusparse/operations/cusparse_spsv.cpp b/src/sparse_blas/backends/cusparse/operations/cusparse_spsv.cpp index affad658b..22fa772ba 100644 --- a/src/sparse_blas/backends/cusparse/operations/cusparse_spsv.cpp +++ b/src/sparse_blas/backends/cusparse/operations/cusparse_spsv.cpp @@ -20,9 +20,10 @@ #include "oneapi/math/sparse_blas/detail/cusparse/onemath_sparse_blas_cusparse.hpp" #include "sparse_blas/backends/cusparse/cusparse_error.hpp" +#include "sparse_blas/backends/cusparse/cusparse_handles.hpp" #include "sparse_blas/backends/cusparse/cusparse_helper.hpp" #include "sparse_blas/backends/cusparse/cusparse_task.hpp" -#include "sparse_blas/backends/cusparse/cusparse_handles.hpp" +#include "sparse_blas/backends/cusparse/cusparse_scope_handle.hpp" #include "sparse_blas/common_op_verification.hpp" #include "sparse_blas/macros.hpp" #include "sparse_blas/matrix_view_comparison.hpp" diff --git a/src/sparse_blas/backends/rocsparse/CMakeLists.txt b/src/sparse_blas/backends/rocsparse/CMakeLists.txt new file mode 100644 index 000000000..c6b78aa47 --- /dev/null +++ b/src/sparse_blas/backends/rocsparse/CMakeLists.txt @@ -0,0 +1,81 @@ +#=============================================================================== +# Copyright 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions +# and limitations under the License. +# +# +# SPDX-License-Identifier: Apache-2.0 +#=============================================================================== + +set(LIB_NAME onemath_sparse_blas_rocsparse) +set(LIB_OBJ ${LIB_NAME}_obj) + +include(WarningsUtils) + +add_library(${LIB_NAME}) +add_library(${LIB_OBJ} OBJECT + rocsparse_handles.cpp + rocsparse_scope_handle.cpp + operations/rocsparse_spmm.cpp + operations/rocsparse_spmv.cpp + operations/rocsparse_spsv.cpp + $<$: rocsparse_wrappers.cpp> +) +add_dependencies(onemath_backend_libs_sparse_blas ${LIB_NAME}) + +target_include_directories(${LIB_OBJ} + PRIVATE ${PROJECT_SOURCE_DIR}/include + ${PROJECT_SOURCE_DIR}/src + ${CMAKE_BINARY_DIR}/bin + ${ONEMATH_GENERATED_INCLUDE_PATH} +) + +target_compile_options(${LIB_OBJ} PRIVATE ${ONEMATH_BUILD_COPT}) + +find_package(HIP REQUIRED) +find_package(rocsparse 3.1.2 REQUIRED) # ROCm 6.1.0 or above + +target_link_libraries(${LIB_OBJ} PRIVATE hip::host roc::rocsparse) + +target_link_libraries(${LIB_OBJ} + PUBLIC ONEMATH::SYCL::SYCL + PRIVATE onemath_warnings +) + +set_target_properties(${LIB_OBJ} PROPERTIES + POSITION_INDEPENDENT_CODE ON +) +target_link_libraries(${LIB_NAME} PUBLIC ${LIB_OBJ}) + +#Set oneMATH libraries as not transitive for dynamic +if(BUILD_SHARED_LIBS) + set_target_properties(${LIB_NAME} PROPERTIES + INTERFACE_LINK_LIBRARIES ONEMATH::SYCL::SYCL + ) +endif() + +# Add major version to the library +set_target_properties(${LIB_NAME} PROPERTIES + SOVERSION ${PROJECT_VERSION_MAJOR} +) + +# Add dependencies rpath to the library +list(APPEND CMAKE_BUILD_RPATH $) + +# Add the library to install package +install(TARGETS ${LIB_OBJ} EXPORT oneMATHTargets) +install(TARGETS ${LIB_NAME} EXPORT oneMATHTargets + RUNTIME DESTINATION bin + ARCHIVE DESTINATION lib + LIBRARY DESTINATION lib +) diff --git a/src/sparse_blas/backends/rocsparse/operations/rocsparse_spmm.cpp b/src/sparse_blas/backends/rocsparse/operations/rocsparse_spmm.cpp new file mode 100644 index 000000000..c810e207c --- /dev/null +++ b/src/sparse_blas/backends/rocsparse/operations/rocsparse_spmm.cpp @@ -0,0 +1,352 @@ +/*************************************************************************** +* Copyright (C) Codeplay Software Limited +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* For your convenience, a copy of the License has been included in this +* repository. +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +* +**************************************************************************/ + +#include "oneapi/math/sparse_blas/detail/rocsparse/onemath_sparse_blas_rocsparse.hpp" + +#include "sparse_blas/backends/rocsparse/rocsparse_error.hpp" +#include "sparse_blas/backends/rocsparse/rocsparse_handles.hpp" +#include "sparse_blas/backends/rocsparse/rocsparse_helper.hpp" +#include "sparse_blas/backends/rocsparse/rocsparse_task.hpp" +#include "sparse_blas/backends/rocsparse/rocsparse_scope_handle.hpp" +#include "sparse_blas/common_op_verification.hpp" +#include "sparse_blas/macros.hpp" +#include "sparse_blas/matrix_view_comparison.hpp" +#include "sparse_blas/sycl_helper.hpp" + +namespace oneapi::math::sparse { + +// Complete the definition of the incomplete type +struct spmm_descr { + // Cache the hipStream_t and global handle to avoid relying on RocsparseScopedContextHandler to retrieve them. + hipStream_t hip_stream; + rocsparse_handle roc_handle; + + detail::generic_container workspace; + std::size_t temp_buffer_size = 0; + bool buffer_size_called = false; + bool optimized_called = false; + oneapi::math::transpose last_optimized_opA; + oneapi::math::transpose last_optimized_opB; + oneapi::math::sparse::matrix_view last_optimized_A_view; + oneapi::math::sparse::matrix_handle_t last_optimized_A_handle; + oneapi::math::sparse::dense_matrix_handle_t last_optimized_B_handle; + oneapi::math::sparse::dense_matrix_handle_t last_optimized_C_handle; + oneapi::math::sparse::spmm_alg last_optimized_alg; +}; + +} // namespace oneapi::math::sparse + +namespace oneapi::math::sparse::rocsparse { + +namespace detail { + +inline auto get_roc_spmm_alg(spmm_alg alg) { + switch (alg) { + case spmm_alg::coo_alg1: return rocsparse_spmm_alg_coo_segmented; + case spmm_alg::coo_alg2: return rocsparse_spmm_alg_coo_atomic; + case spmm_alg::coo_alg3: return rocsparse_spmm_alg_coo_segmented_atomic; + case spmm_alg::csr_alg1: return rocsparse_spmm_alg_csr; + case spmm_alg::csr_alg2: return rocsparse_spmm_alg_csr_row_split; + case spmm_alg::csr_alg3: return rocsparse_spmm_alg_csr_merge; + default: return rocsparse_spmm_alg_default; + } +} + +void check_valid_spmm(const std::string& function_name, matrix_view A_view, + matrix_handle_t A_handle, dense_matrix_handle_t B_handle, + dense_matrix_handle_t C_handle, bool is_alpha_host_accessible, + bool is_beta_host_accessible) { + check_valid_spmm_common(function_name, A_view, A_handle, B_handle, C_handle, + is_alpha_host_accessible, is_beta_host_accessible); + A_handle->check_valid_handle(function_name); +} + +inline void common_spmm_optimize( + oneapi::math::transpose opA, oneapi::math::transpose opB, bool is_alpha_host_accessible, + oneapi::math::sparse::matrix_view A_view, oneapi::math::sparse::matrix_handle_t A_handle, + oneapi::math::sparse::dense_matrix_handle_t B_handle, bool is_beta_host_accessible, + oneapi::math::sparse::dense_matrix_handle_t C_handle, oneapi::math::sparse::spmm_alg alg, + oneapi::math::sparse::spmm_descr_t spmm_descr) { + check_valid_spmm("spmm_optimize", A_view, A_handle, B_handle, C_handle, + is_alpha_host_accessible, is_beta_host_accessible); + if (!spmm_descr->buffer_size_called) { + throw math::uninitialized("sparse_blas", "spmm_optimize", + "spmm_buffer_size must be called before spmm_optimize."); + } + spmm_descr->optimized_called = true; + spmm_descr->last_optimized_opA = opA; + spmm_descr->last_optimized_opB = opB; + spmm_descr->last_optimized_A_view = A_view; + spmm_descr->last_optimized_A_handle = A_handle; + spmm_descr->last_optimized_B_handle = B_handle; + spmm_descr->last_optimized_C_handle = C_handle; + spmm_descr->last_optimized_alg = alg; +} + +void spmm_optimize_impl(rocsparse_handle roc_handle, oneapi::math::transpose opA, + oneapi::math::transpose opB, const void* alpha, + oneapi::math::sparse::matrix_handle_t A_handle, + oneapi::math::sparse::dense_matrix_handle_t B_handle, const void* beta, + oneapi::math::sparse::dense_matrix_handle_t C_handle, + oneapi::math::sparse::spmm_alg alg, std::size_t buffer_size, + void* workspace_ptr, bool is_alpha_host_accessible) { + auto roc_a = A_handle->backend_handle; + auto roc_b = B_handle->backend_handle; + auto roc_c = C_handle->backend_handle; + auto roc_op_a = get_roc_operation(opA); + auto roc_op_b = get_roc_operation(opB); + auto roc_type = get_roc_value_type(A_handle->value_container.data_type); + auto roc_alg = get_roc_spmm_alg(alg); + set_pointer_mode(roc_handle, is_alpha_host_accessible); + // rocsparse_spmm_stage_preprocess stage is blocking + auto status = + rocsparse_spmm(roc_handle, roc_op_a, roc_op_b, alpha, roc_a, roc_b, beta, roc_c, roc_type, + roc_alg, rocsparse_spmm_stage_preprocess, &buffer_size, workspace_ptr); + check_status(status, "spmm_optimize"); +} + +} // namespace detail + +void init_spmm_descr(sycl::queue& /*queue*/, spmm_descr_t* p_spmm_descr) { + *p_spmm_descr = new spmm_descr(); +} + +sycl::event release_spmm_descr(sycl::queue& queue, spmm_descr_t spmm_descr, + const std::vector& dependencies) { + if (!spmm_descr) { + return detail::collapse_dependencies(queue, dependencies); + } + + auto release_functor = [=]() { + spmm_descr->roc_handle = nullptr; + spmm_descr->last_optimized_A_handle = nullptr; + spmm_descr->last_optimized_B_handle = nullptr; + spmm_descr->last_optimized_C_handle = nullptr; + delete spmm_descr; + }; + + // Use dispatch_submit to ensure the descriptor is kept alive as long as the buffers are used + // dispatch_submit can only be used if the descriptor's handles are valid + if (spmm_descr->last_optimized_A_handle && + spmm_descr->last_optimized_A_handle->all_use_buffer() && + spmm_descr->last_optimized_B_handle && spmm_descr->last_optimized_C_handle && + spmm_descr->workspace.use_buffer()) { + auto dispatch_functor = [=](sycl::interop_handle, sycl::accessor) { + release_functor(); + }; + return detail::dispatch_submit( + __func__, queue, dispatch_functor, spmm_descr->last_optimized_A_handle, + spmm_descr->workspace.get_buffer(), spmm_descr->last_optimized_B_handle, + spmm_descr->last_optimized_C_handle); + } + + // Release used if USM is used or if the descriptor has been released before spmm_optimize has succeeded + sycl::event event = queue.submit([&](sycl::handler& cgh) { + cgh.depends_on(dependencies); + cgh.host_task(release_functor); + }); + return event; +} + +void spmm_buffer_size(sycl::queue& queue, oneapi::math::transpose opA, oneapi::math::transpose opB, + const void* alpha, oneapi::math::sparse::matrix_view A_view, + oneapi::math::sparse::matrix_handle_t A_handle, + oneapi::math::sparse::dense_matrix_handle_t B_handle, const void* beta, + oneapi::math::sparse::dense_matrix_handle_t C_handle, + oneapi::math::sparse::spmm_alg alg, + oneapi::math::sparse::spmm_descr_t spmm_descr, + std::size_t& temp_buffer_size) { + bool is_alpha_host_accessible = detail::is_ptr_accessible_on_host(queue, alpha); + bool is_beta_host_accessible = detail::is_ptr_accessible_on_host(queue, beta); + detail::check_valid_spmm(__func__, A_view, A_handle, B_handle, C_handle, + is_alpha_host_accessible, is_beta_host_accessible); + bool is_in_order_queue = queue.is_in_order(); + auto functor = [=, &temp_buffer_size](sycl::interop_handle ih) { + detail::RocsparseScopedContextHandler sc(queue, ih); + auto [roc_handle, hip_stream] = sc.get_handle_and_stream(queue); + spmm_descr->roc_handle = roc_handle; + spmm_descr->hip_stream = hip_stream; + auto roc_a = A_handle->backend_handle; + auto roc_b = B_handle->backend_handle; + auto roc_c = C_handle->backend_handle; + auto roc_op_a = detail::get_roc_operation(opA); + auto roc_op_b = detail::get_roc_operation(opB); + auto roc_type = detail::get_roc_value_type(A_handle->value_container.data_type); + auto roc_alg = detail::get_roc_spmm_alg(alg); + detail::set_pointer_mode(roc_handle, is_alpha_host_accessible); + auto status = rocsparse_spmm(roc_handle, roc_op_a, roc_op_b, alpha, roc_a, roc_b, beta, + roc_c, roc_type, roc_alg, rocsparse_spmm_stage_buffer_size, + &temp_buffer_size, nullptr); + detail::check_status(status, __func__); + detail::synchronize_if_needed(is_in_order_queue, hip_stream); + }; + auto event = detail::dispatch_submit(__func__, queue, functor, A_handle, B_handle, C_handle); + event.wait_and_throw(); + spmm_descr->temp_buffer_size = temp_buffer_size; + spmm_descr->buffer_size_called = true; +} + +void spmm_optimize(sycl::queue& queue, oneapi::math::transpose opA, oneapi::math::transpose opB, + const void* alpha, oneapi::math::sparse::matrix_view A_view, + oneapi::math::sparse::matrix_handle_t A_handle, + oneapi::math::sparse::dense_matrix_handle_t B_handle, const void* beta, + oneapi::math::sparse::dense_matrix_handle_t C_handle, + oneapi::math::sparse::spmm_alg alg, + oneapi::math::sparse::spmm_descr_t spmm_descr, + sycl::buffer workspace) { + bool is_alpha_host_accessible = detail::is_ptr_accessible_on_host(queue, alpha); + bool is_beta_host_accessible = detail::is_ptr_accessible_on_host(queue, beta); + if (!A_handle->all_use_buffer()) { + detail::throw_incompatible_container(__func__); + } + detail::common_spmm_optimize(opA, opB, is_alpha_host_accessible, A_view, A_handle, B_handle, + is_beta_host_accessible, C_handle, alg, spmm_descr); + // Copy the buffer to extend its lifetime until the descriptor is free'd. + spmm_descr->workspace.set_buffer_untyped(workspace); + if (alg == oneapi::math::sparse::spmm_alg::no_optimize_alg) { + return; + } + std::size_t buffer_size = spmm_descr->temp_buffer_size; + + // The accessor can only be created if the buffer size is greater than 0 + if (buffer_size > 0) { + auto functor = [=](sycl::interop_handle ih, sycl::accessor workspace_acc) { + auto roc_handle = spmm_descr->roc_handle; + auto workspace_ptr = detail::get_mem(ih, workspace_acc); + detail::spmm_optimize_impl(roc_handle, opA, opB, alpha, A_handle, B_handle, beta, + C_handle, alg, buffer_size, workspace_ptr, + is_alpha_host_accessible); + }; + + detail::dispatch_submit(__func__, queue, functor, A_handle, workspace, B_handle, C_handle); + } + else { + auto functor = [=](sycl::interop_handle) { + auto roc_handle = spmm_descr->roc_handle; + detail::spmm_optimize_impl(roc_handle, opA, opB, alpha, A_handle, B_handle, beta, + C_handle, alg, buffer_size, nullptr, + is_alpha_host_accessible); + }; + + detail::dispatch_submit(__func__, queue, functor, A_handle, B_handle, C_handle); + } +} + +sycl::event spmm_optimize(sycl::queue& queue, oneapi::math::transpose opA, + oneapi::math::transpose opB, const void* alpha, + oneapi::math::sparse::matrix_view A_view, + oneapi::math::sparse::matrix_handle_t A_handle, + oneapi::math::sparse::dense_matrix_handle_t B_handle, const void* beta, + oneapi::math::sparse::dense_matrix_handle_t C_handle, + oneapi::math::sparse::spmm_alg alg, + oneapi::math::sparse::spmm_descr_t spmm_descr, void* workspace, + const std::vector& dependencies) { + bool is_alpha_host_accessible = detail::is_ptr_accessible_on_host(queue, alpha); + bool is_beta_host_accessible = detail::is_ptr_accessible_on_host(queue, beta); + if (A_handle->all_use_buffer()) { + detail::throw_incompatible_container(__func__); + } + detail::common_spmm_optimize(opA, opB, is_alpha_host_accessible, A_view, A_handle, B_handle, + is_beta_host_accessible, C_handle, alg, spmm_descr); + spmm_descr->workspace.usm_ptr = workspace; + if (alg == oneapi::math::sparse::spmm_alg::no_optimize_alg) { + return detail::collapse_dependencies(queue, dependencies); + } + std::size_t buffer_size = spmm_descr->temp_buffer_size; + auto functor = [=](sycl::interop_handle) { + auto roc_handle = spmm_descr->roc_handle; + detail::spmm_optimize_impl(roc_handle, opA, opB, alpha, A_handle, B_handle, beta, C_handle, + alg, buffer_size, workspace, is_alpha_host_accessible); + }; + + return detail::dispatch_submit(__func__, queue, dependencies, functor, A_handle, B_handle, + C_handle); +} + +sycl::event spmm(sycl::queue& queue, oneapi::math::transpose opA, oneapi::math::transpose opB, + const void* alpha, oneapi::math::sparse::matrix_view A_view, + oneapi::math::sparse::matrix_handle_t A_handle, + oneapi::math::sparse::dense_matrix_handle_t B_handle, const void* beta, + oneapi::math::sparse::dense_matrix_handle_t C_handle, + oneapi::math::sparse::spmm_alg alg, oneapi::math::sparse::spmm_descr_t spmm_descr, + const std::vector& dependencies) { + bool is_alpha_host_accessible = detail::is_ptr_accessible_on_host(queue, alpha); + bool is_beta_host_accessible = detail::is_ptr_accessible_on_host(queue, beta); + if (A_handle->all_use_buffer() != spmm_descr->workspace.use_buffer()) { + detail::throw_incompatible_container(__func__); + } + if (!spmm_descr->optimized_called) { + throw math::uninitialized( + "sparse_blas", __func__, + "spmm_optimize must be called with the same arguments before spmm."); + } + CHECK_DESCR_MATCH(spmm_descr, opA, "spmm_optimize"); + CHECK_DESCR_MATCH(spmm_descr, opB, "spmm_optimize"); + CHECK_DESCR_MATCH(spmm_descr, A_view, "spmm_optimize"); + CHECK_DESCR_MATCH(spmm_descr, A_handle, "spmm_optimize"); + CHECK_DESCR_MATCH(spmm_descr, B_handle, "spmm_optimize"); + CHECK_DESCR_MATCH(spmm_descr, C_handle, "spmm_optimize"); + CHECK_DESCR_MATCH(spmm_descr, alg, "spmm_optimize"); + detail::check_valid_spmm(__func__, A_view, A_handle, B_handle, C_handle, + is_alpha_host_accessible, is_beta_host_accessible); + A_handle->mark_used(); + bool is_in_order_queue = queue.is_in_order(); + auto compute_functor = [=](void* workspace_ptr) { + auto roc_handle = spmm_descr->roc_handle; + auto hip_stream = spmm_descr->hip_stream; + auto buffer_size = spmm_descr->temp_buffer_size; + auto roc_a = A_handle->backend_handle; + auto roc_b = B_handle->backend_handle; + auto roc_c = C_handle->backend_handle; + auto roc_op_a = detail::get_roc_operation(opA); + auto roc_op_b = detail::get_roc_operation(opB); + auto roc_type = detail::get_roc_value_type(A_handle->value_container.data_type); + auto roc_alg = detail::get_roc_spmm_alg(alg); + detail::set_pointer_mode(roc_handle, is_alpha_host_accessible); + auto status = rocsparse_spmm(roc_handle, roc_op_a, roc_op_b, alpha, roc_a, roc_b, beta, + roc_c, roc_type, roc_alg, rocsparse_spmm_stage_compute, + &buffer_size, workspace_ptr); + detail::check_status(status, __func__); + detail::synchronize_if_needed(is_in_order_queue, hip_stream); + }; + // The accessor can only be created if the buffer size is greater than 0 + if (A_handle->all_use_buffer() && spmm_descr->temp_buffer_size > 0) { + auto functor_buffer = [=](sycl::interop_handle ih, + sycl::accessor workspace_acc) { + auto workspace_ptr = detail::get_mem(ih, workspace_acc); + compute_functor(workspace_ptr); + }; + return detail::dispatch_submit_native_ext(__func__, queue, functor_buffer, A_handle, + spmm_descr->workspace.get_buffer(), + B_handle, C_handle); + } + else { + // The same dispatch_submit can be used for USM or buffers if no + // workspace accessor is needed. + // workspace_ptr will be a nullptr in the latter case. + auto workspace_ptr = spmm_descr->workspace.usm_ptr; + auto functor_usm = [=](sycl::interop_handle) { + compute_functor(workspace_ptr); + }; + return detail::dispatch_submit_native_ext(__func__, queue, dependencies, functor_usm, + A_handle, B_handle, C_handle); + } +} + +} // namespace oneapi::math::sparse::rocsparse diff --git a/src/sparse_blas/backends/rocsparse/operations/rocsparse_spmv.cpp b/src/sparse_blas/backends/rocsparse/operations/rocsparse_spmv.cpp new file mode 100644 index 000000000..ead19ec40 --- /dev/null +++ b/src/sparse_blas/backends/rocsparse/operations/rocsparse_spmv.cpp @@ -0,0 +1,352 @@ +/*************************************************************************** +* Copyright (C) Codeplay Software Limited +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* For your convenience, a copy of the License has been included in this +* repository. +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +* +**************************************************************************/ + +#include "oneapi/math/sparse_blas/detail/rocsparse/onemath_sparse_blas_rocsparse.hpp" + +#include "sparse_blas/backends/rocsparse/rocsparse_error.hpp" +#include "sparse_blas/backends/rocsparse/rocsparse_handles.hpp" +#include "sparse_blas/backends/rocsparse/rocsparse_helper.hpp" +#include "sparse_blas/backends/rocsparse/rocsparse_task.hpp" +#include "sparse_blas/backends/rocsparse/rocsparse_scope_handle.hpp" +#include "sparse_blas/common_op_verification.hpp" +#include "sparse_blas/macros.hpp" +#include "sparse_blas/matrix_view_comparison.hpp" +#include "sparse_blas/sycl_helper.hpp" + +namespace oneapi::math::sparse { + +// Complete the definition of the incomplete type +struct spmv_descr { + // Cache the hipStream_t and global handle to avoid relying on RocsparseScopedContextHandler to retrieve them. + hipStream_t hip_stream; + rocsparse_handle roc_handle; + + detail::generic_container workspace; + std::size_t temp_buffer_size = 0; + bool buffer_size_called = false; + bool optimized_called = false; + oneapi::math::transpose last_optimized_opA; + oneapi::math::sparse::matrix_view last_optimized_A_view; + oneapi::math::sparse::matrix_handle_t last_optimized_A_handle; + oneapi::math::sparse::dense_vector_handle_t last_optimized_x_handle; + oneapi::math::sparse::dense_vector_handle_t last_optimized_y_handle; + oneapi::math::sparse::spmv_alg last_optimized_alg; +}; + +} // namespace oneapi::math::sparse + +namespace oneapi::math::sparse::rocsparse { + +namespace detail { + +inline auto get_roc_spmv_alg(spmv_alg alg) { + switch (alg) { + case spmv_alg::coo_alg1: return rocsparse_spmv_alg_coo; + case spmv_alg::coo_alg2: return rocsparse_spmv_alg_coo_atomic; + case spmv_alg::csr_alg1: return rocsparse_spmv_alg_csr_adaptive; + case spmv_alg::csr_alg2: return rocsparse_spmv_alg_csr_stream; + case spmv_alg::csr_alg3: return rocsparse_spmv_alg_csr_lrb; + default: return rocsparse_spmv_alg_default; + } +} + +void check_valid_spmv(const std::string& function_name, oneapi::math::transpose opA, + oneapi::math::sparse::matrix_view A_view, + oneapi::math::sparse::matrix_handle_t A_handle, + oneapi::math::sparse::dense_vector_handle_t x_handle, + oneapi::math::sparse::dense_vector_handle_t y_handle, + bool is_alpha_host_accessible, bool is_beta_host_accessible) { + check_valid_spmv_common(function_name, opA, A_view, A_handle, x_handle, y_handle, + is_alpha_host_accessible, is_beta_host_accessible); + A_handle->check_valid_handle(__func__); + if (A_view.type_view != oneapi::math::sparse::matrix_descr::general) { + throw math::unimplemented( + "sparse_blas", function_name, + "The backend does not support spmv with a `type_view` other than `matrix_descr::general`."); + } +} + +inline void common_spmv_optimize(oneapi::math::transpose opA, bool is_alpha_host_accessible, + oneapi::math::sparse::matrix_view A_view, + oneapi::math::sparse::matrix_handle_t A_handle, + oneapi::math::sparse::dense_vector_handle_t x_handle, + bool is_beta_host_accessible, + oneapi::math::sparse::dense_vector_handle_t y_handle, + oneapi::math::sparse::spmv_alg alg, + oneapi::math::sparse::spmv_descr_t spmv_descr) { + check_valid_spmv("spmv_optimize", opA, A_view, A_handle, x_handle, y_handle, + is_alpha_host_accessible, is_beta_host_accessible); + if (!spmv_descr->buffer_size_called) { + throw math::uninitialized( + "sparse_blas", "spmv_optimize", + "spmv_buffer_size must be called with the same arguments before spmv_optimize."); + } + spmv_descr->optimized_called = true; + spmv_descr->last_optimized_opA = opA; + spmv_descr->last_optimized_A_view = A_view; + spmv_descr->last_optimized_A_handle = A_handle; + spmv_descr->last_optimized_x_handle = x_handle; + spmv_descr->last_optimized_y_handle = y_handle; + spmv_descr->last_optimized_alg = alg; +} + +void spmv_optimize_impl(rocsparse_handle roc_handle, oneapi::math::transpose opA, const void* alpha, + oneapi::math::sparse::matrix_handle_t A_handle, + oneapi::math::sparse::dense_vector_handle_t x_handle, const void* beta, + oneapi::math::sparse::dense_vector_handle_t y_handle, + oneapi::math::sparse::spmv_alg alg, std::size_t buffer_size, + void* workspace_ptr, bool is_alpha_host_accessible) { + auto roc_a = A_handle->backend_handle; + auto roc_x = x_handle->backend_handle; + auto roc_y = y_handle->backend_handle; + auto roc_op = get_roc_operation(opA); + auto roc_type = get_roc_value_type(A_handle->value_container.data_type); + auto roc_alg = get_roc_spmv_alg(alg); + set_pointer_mode(roc_handle, is_alpha_host_accessible); + // rocsparse_spmv_stage_preprocess stage is blocking + auto status = + rocsparse_spmv(roc_handle, roc_op, alpha, roc_a, roc_x, beta, roc_y, roc_type, roc_alg, + rocsparse_spmv_stage_preprocess, &buffer_size, workspace_ptr); + check_status(status, "spmv_optimize"); +} + +} // namespace detail + +void init_spmv_descr(sycl::queue& /*queue*/, spmv_descr_t* p_spmv_descr) { + *p_spmv_descr = new spmv_descr(); +} + +sycl::event release_spmv_descr(sycl::queue& queue, spmv_descr_t spmv_descr, + const std::vector& dependencies) { + if (!spmv_descr) { + return detail::collapse_dependencies(queue, dependencies); + } + + auto release_functor = [=]() { + spmv_descr->roc_handle = nullptr; + spmv_descr->last_optimized_A_handle = nullptr; + spmv_descr->last_optimized_x_handle = nullptr; + spmv_descr->last_optimized_y_handle = nullptr; + delete spmv_descr; + }; + + // Use dispatch_submit to ensure the descriptor is kept alive as long as the buffers are used + // dispatch_submit can only be used if the descriptor's handles are valid + if (spmv_descr->last_optimized_A_handle && + spmv_descr->last_optimized_A_handle->all_use_buffer() && + spmv_descr->last_optimized_x_handle && spmv_descr->last_optimized_y_handle && + spmv_descr->workspace.use_buffer()) { + auto dispatch_functor = [=](sycl::interop_handle, sycl::accessor) { + release_functor(); + }; + return detail::dispatch_submit( + __func__, queue, dispatch_functor, spmv_descr->last_optimized_A_handle, + spmv_descr->workspace.get_buffer(), spmv_descr->last_optimized_x_handle, + spmv_descr->last_optimized_y_handle); + } + + // Release used if USM is used or if the descriptor has been released before spmv_optimize has succeeded + sycl::event event = queue.submit([&](sycl::handler& cgh) { + cgh.depends_on(dependencies); + cgh.host_task(release_functor); + }); + return event; +} + +void spmv_buffer_size(sycl::queue& queue, oneapi::math::transpose opA, const void* alpha, + oneapi::math::sparse::matrix_view A_view, + oneapi::math::sparse::matrix_handle_t A_handle, + oneapi::math::sparse::dense_vector_handle_t x_handle, const void* beta, + oneapi::math::sparse::dense_vector_handle_t y_handle, + oneapi::math::sparse::spmv_alg alg, + oneapi::math::sparse::spmv_descr_t spmv_descr, + std::size_t& temp_buffer_size) { + bool is_alpha_host_accessible = detail::is_ptr_accessible_on_host(queue, alpha); + bool is_beta_host_accessible = detail::is_ptr_accessible_on_host(queue, beta); + detail::check_valid_spmv(__func__, opA, A_view, A_handle, x_handle, y_handle, + is_alpha_host_accessible, is_beta_host_accessible); + bool is_in_order_queue = queue.is_in_order(); + auto functor = [=, &temp_buffer_size](sycl::interop_handle ih) { + detail::RocsparseScopedContextHandler sc(queue, ih); + auto [roc_handle, hip_stream] = sc.get_handle_and_stream(queue); + spmv_descr->roc_handle = roc_handle; + spmv_descr->hip_stream = hip_stream; + auto roc_a = A_handle->backend_handle; + auto roc_x = x_handle->backend_handle; + auto roc_y = y_handle->backend_handle; + auto roc_op = detail::get_roc_operation(opA); + auto roc_type = detail::get_roc_value_type(A_handle->value_container.data_type); + auto roc_alg = detail::get_roc_spmv_alg(alg); + detail::set_pointer_mode(roc_handle, is_alpha_host_accessible); + auto status = + rocsparse_spmv(roc_handle, roc_op, alpha, roc_a, roc_x, beta, roc_y, roc_type, roc_alg, + rocsparse_spmv_stage_buffer_size, &temp_buffer_size, nullptr); + detail::check_status(status, __func__); + detail::synchronize_if_needed(is_in_order_queue, hip_stream); + }; + auto event = detail::dispatch_submit(__func__, queue, functor, A_handle, x_handle, y_handle); + event.wait_and_throw(); + spmv_descr->temp_buffer_size = temp_buffer_size; + spmv_descr->buffer_size_called = true; +} + +void spmv_optimize(sycl::queue& queue, oneapi::math::transpose opA, const void* alpha, + oneapi::math::sparse::matrix_view A_view, + oneapi::math::sparse::matrix_handle_t A_handle, + oneapi::math::sparse::dense_vector_handle_t x_handle, const void* beta, + oneapi::math::sparse::dense_vector_handle_t y_handle, + oneapi::math::sparse::spmv_alg alg, + oneapi::math::sparse::spmv_descr_t spmv_descr, + sycl::buffer workspace) { + bool is_alpha_host_accessible = detail::is_ptr_accessible_on_host(queue, alpha); + bool is_beta_host_accessible = detail::is_ptr_accessible_on_host(queue, beta); + if (!A_handle->all_use_buffer()) { + detail::throw_incompatible_container(__func__); + } + detail::common_spmv_optimize(opA, is_alpha_host_accessible, A_view, A_handle, x_handle, + is_beta_host_accessible, y_handle, alg, spmv_descr); + // Copy the buffer to extend its lifetime until the descriptor is free'd. + spmv_descr->workspace.set_buffer_untyped(workspace); + if (alg == oneapi::math::sparse::spmv_alg::no_optimize_alg) { + return; + } + std::size_t buffer_size = spmv_descr->temp_buffer_size; + // The accessor can only be created if the buffer size is greater than 0 + if (buffer_size > 0) { + auto functor = [=](sycl::interop_handle ih, sycl::accessor workspace_acc) { + auto roc_handle = spmv_descr->roc_handle; + auto workspace_ptr = detail::get_mem(ih, workspace_acc); + detail::spmv_optimize_impl(roc_handle, opA, alpha, A_handle, x_handle, beta, y_handle, + alg, buffer_size, workspace_ptr, is_alpha_host_accessible); + }; + + detail::dispatch_submit(__func__, queue, functor, A_handle, workspace, x_handle, y_handle); + } + else { + auto functor = [=](sycl::interop_handle) { + auto roc_handle = spmv_descr->roc_handle; + detail::spmv_optimize_impl(roc_handle, opA, alpha, A_handle, x_handle, beta, y_handle, + alg, buffer_size, nullptr, is_alpha_host_accessible); + }; + + detail::dispatch_submit(__func__, queue, functor, A_handle, x_handle, y_handle); + } +} + +sycl::event spmv_optimize(sycl::queue& queue, oneapi::math::transpose opA, const void* alpha, + oneapi::math::sparse::matrix_view A_view, + oneapi::math::sparse::matrix_handle_t A_handle, + oneapi::math::sparse::dense_vector_handle_t x_handle, const void* beta, + oneapi::math::sparse::dense_vector_handle_t y_handle, + oneapi::math::sparse::spmv_alg alg, + oneapi::math::sparse::spmv_descr_t spmv_descr, void* workspace, + const std::vector& dependencies) { + bool is_alpha_host_accessible = detail::is_ptr_accessible_on_host(queue, alpha); + bool is_beta_host_accessible = detail::is_ptr_accessible_on_host(queue, beta); + if (A_handle->all_use_buffer()) { + detail::throw_incompatible_container(__func__); + } + detail::common_spmv_optimize(opA, is_alpha_host_accessible, A_view, A_handle, x_handle, + is_beta_host_accessible, y_handle, alg, spmv_descr); + spmv_descr->workspace.usm_ptr = workspace; + if (alg == oneapi::math::sparse::spmv_alg::no_optimize_alg) { + return detail::collapse_dependencies(queue, dependencies); + } + std::size_t buffer_size = spmv_descr->temp_buffer_size; + auto functor = [=](sycl::interop_handle) { + auto roc_handle = spmv_descr->roc_handle; + detail::spmv_optimize_impl(roc_handle, opA, alpha, A_handle, x_handle, beta, y_handle, alg, + buffer_size, workspace, is_alpha_host_accessible); + }; + + return detail::dispatch_submit(__func__, queue, dependencies, functor, A_handle, x_handle, + y_handle); +} + +sycl::event spmv(sycl::queue& queue, oneapi::math::transpose opA, const void* alpha, + oneapi::math::sparse::matrix_view A_view, + oneapi::math::sparse::matrix_handle_t A_handle, + oneapi::math::sparse::dense_vector_handle_t x_handle, const void* beta, + oneapi::math::sparse::dense_vector_handle_t y_handle, + oneapi::math::sparse::spmv_alg alg, oneapi::math::sparse::spmv_descr_t spmv_descr, + const std::vector& dependencies) { + bool is_alpha_host_accessible = detail::is_ptr_accessible_on_host(queue, alpha); + bool is_beta_host_accessible = detail::is_ptr_accessible_on_host(queue, beta); + if (A_handle->all_use_buffer() != spmv_descr->workspace.use_buffer()) { + detail::throw_incompatible_container(__func__); + } + detail::check_valid_spmv(__func__, opA, A_view, A_handle, x_handle, y_handle, + is_alpha_host_accessible, is_beta_host_accessible); + + if (!spmv_descr->optimized_called) { + throw math::uninitialized( + "sparse_blas", __func__, + "spmv_optimize must be called with the same arguments before spmv."); + } + CHECK_DESCR_MATCH(spmv_descr, opA, "spmv_optimize"); + CHECK_DESCR_MATCH(spmv_descr, A_view, "spmv_optimize"); + CHECK_DESCR_MATCH(spmv_descr, A_handle, "spmv_optimize"); + CHECK_DESCR_MATCH(spmv_descr, x_handle, "spmv_optimize"); + CHECK_DESCR_MATCH(spmv_descr, y_handle, "spmv_optimize"); + CHECK_DESCR_MATCH(spmv_descr, alg, "spmv_optimize"); + + A_handle->mark_used(); + bool is_in_order_queue = queue.is_in_order(); + auto compute_functor = [=](void* workspace_ptr) { + auto roc_handle = spmv_descr->roc_handle; + auto hip_stream = spmv_descr->hip_stream; + auto buffer_size = spmv_descr->temp_buffer_size; + auto roc_a = A_handle->backend_handle; + auto roc_x = x_handle->backend_handle; + auto roc_y = y_handle->backend_handle; + auto roc_op = detail::get_roc_operation(opA); + auto roc_type = detail::get_roc_value_type(A_handle->value_container.data_type); + auto roc_alg = detail::get_roc_spmv_alg(alg); + detail::set_pointer_mode(roc_handle, is_alpha_host_accessible); + auto status = + rocsparse_spmv(roc_handle, roc_op, alpha, roc_a, roc_x, beta, roc_y, roc_type, roc_alg, + rocsparse_spmv_stage_compute, &buffer_size, workspace_ptr); + detail::check_status(status, __func__); + detail::synchronize_if_needed(is_in_order_queue, hip_stream); + }; + // The accessor can only be created if the buffer size is greater than 0 + if (A_handle->all_use_buffer() && spmv_descr->temp_buffer_size > 0) { + auto functor_buffer = [=](sycl::interop_handle ih, + sycl::accessor workspace_acc) { + auto workspace_ptr = detail::get_mem(ih, workspace_acc); + compute_functor(workspace_ptr); + }; + return detail::dispatch_submit_native_ext(__func__, queue, functor_buffer, A_handle, + spmv_descr->workspace.get_buffer(), + x_handle, y_handle); + } + else { + // The same dispatch_submit can be used for USM or buffers if no + // workspace accessor is needed. + // workspace_ptr will be a nullptr in the latter case. + auto workspace_ptr = spmv_descr->workspace.usm_ptr; + auto functor_usm = [=](sycl::interop_handle) { + compute_functor(workspace_ptr); + }; + return detail::dispatch_submit_native_ext(__func__, queue, dependencies, functor_usm, + A_handle, x_handle, y_handle); + } +} + +} // namespace oneapi::math::sparse::rocsparse diff --git a/src/sparse_blas/backends/rocsparse/operations/rocsparse_spsv.cpp b/src/sparse_blas/backends/rocsparse/operations/rocsparse_spsv.cpp new file mode 100644 index 000000000..07a6c72d9 --- /dev/null +++ b/src/sparse_blas/backends/rocsparse/operations/rocsparse_spsv.cpp @@ -0,0 +1,333 @@ +/*************************************************************************** +* Copyright (C) Codeplay Software Limited +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* For your convenience, a copy of the License has been included in this +* repository. +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +* +**************************************************************************/ + +#include "oneapi/math/sparse_blas/detail/rocsparse/onemath_sparse_blas_rocsparse.hpp" + +#include "sparse_blas/backends/rocsparse/rocsparse_error.hpp" +#include "sparse_blas/backends/rocsparse/rocsparse_handles.hpp" +#include "sparse_blas/backends/rocsparse/rocsparse_helper.hpp" +#include "sparse_blas/backends/rocsparse/rocsparse_task.hpp" +#include "sparse_blas/backends/rocsparse/rocsparse_scope_handle.hpp" +#include "sparse_blas/common_op_verification.hpp" +#include "sparse_blas/macros.hpp" +#include "sparse_blas/matrix_view_comparison.hpp" +#include "sparse_blas/sycl_helper.hpp" + +namespace oneapi::math::sparse { + +// Complete the definition of the incomplete type +struct spsv_descr { + // Cache the hipStream_t and global handle to avoid relying on RocsparseScopedContextHandler to retrieve them. + hipStream_t hip_stream; + rocsparse_handle roc_handle; + + detail::generic_container workspace; + std::size_t temp_buffer_size = 0; + bool buffer_size_called = false; + bool optimized_called = false; + oneapi::math::transpose last_optimized_opA; + oneapi::math::sparse::matrix_view last_optimized_A_view; + oneapi::math::sparse::matrix_handle_t last_optimized_A_handle; + oneapi::math::sparse::dense_vector_handle_t last_optimized_x_handle; + oneapi::math::sparse::dense_vector_handle_t last_optimized_y_handle; + oneapi::math::sparse::spsv_alg last_optimized_alg; +}; + +} // namespace oneapi::math::sparse + +namespace oneapi::math::sparse::rocsparse { + +namespace detail { + +inline auto get_roc_spsv_alg(spsv_alg /*alg*/) { + return rocsparse_spsv_alg_default; +} + +void check_valid_spsv(const std::string& function_name, matrix_view A_view, + matrix_handle_t A_handle, dense_vector_handle_t x_handle, + dense_vector_handle_t y_handle, bool is_alpha_host_accessible) { + check_valid_spsv_common(function_name, A_view, A_handle, x_handle, y_handle, + is_alpha_host_accessible); + A_handle->check_valid_handle(function_name); +} + +inline void common_spsv_optimize(oneapi::math::transpose opA, bool is_alpha_host_accessible, + oneapi::math::sparse::matrix_view A_view, + oneapi::math::sparse::matrix_handle_t A_handle, + oneapi::math::sparse::dense_vector_handle_t x_handle, + oneapi::math::sparse::dense_vector_handle_t y_handle, + oneapi::math::sparse::spsv_alg alg, + oneapi::math::sparse::spsv_descr_t spsv_descr) { + check_valid_spsv("spsv_optimize", A_view, A_handle, x_handle, y_handle, + is_alpha_host_accessible); + if (!spsv_descr->buffer_size_called) { + throw math::uninitialized( + "sparse_blas", "spsv_optimize", + "spsv_buffer_size must be called with the same arguments before spsv_optimize."); + } + spsv_descr->optimized_called = true; + spsv_descr->last_optimized_opA = opA; + spsv_descr->last_optimized_A_view = A_view; + spsv_descr->last_optimized_A_handle = A_handle; + spsv_descr->last_optimized_x_handle = x_handle; + spsv_descr->last_optimized_y_handle = y_handle; + spsv_descr->last_optimized_alg = alg; +} + +void spsv_optimize_impl(rocsparse_handle roc_handle, oneapi::math::transpose opA, const void* alpha, + oneapi::math::sparse::matrix_view A_view, + oneapi::math::sparse::matrix_handle_t A_handle, + oneapi::math::sparse::dense_vector_handle_t x_handle, + oneapi::math::sparse::dense_vector_handle_t y_handle, + oneapi::math::sparse::spsv_alg alg, std::size_t buffer_size, + void* workspace_ptr, bool is_alpha_host_accessible) { + auto roc_a = A_handle->backend_handle; + auto roc_x = x_handle->backend_handle; + auto roc_y = y_handle->backend_handle; + set_matrix_attributes("spsv_optimize", roc_a, A_view); + auto roc_op = get_roc_operation(opA); + auto roc_type = get_roc_value_type(A_handle->value_container.data_type); + auto roc_alg = get_roc_spsv_alg(alg); + set_pointer_mode(roc_handle, is_alpha_host_accessible); + // rocsparse_spsv_stage_preprocess stage is blocking + auto status = rocsparse_spsv(roc_handle, roc_op, alpha, roc_a, roc_x, roc_y, roc_type, roc_alg, + rocsparse_spsv_stage_preprocess, &buffer_size, workspace_ptr); + check_status(status, "spsv_optimize"); +} + +} // namespace detail + +void init_spsv_descr(sycl::queue& /*queue*/, spsv_descr_t* p_spsv_descr) { + *p_spsv_descr = new spsv_descr(); +} + +sycl::event release_spsv_descr(sycl::queue& queue, spsv_descr_t spsv_descr, + const std::vector& dependencies) { + if (!spsv_descr) { + return detail::collapse_dependencies(queue, dependencies); + } + + auto release_functor = [=]() { + spsv_descr->roc_handle = nullptr; + spsv_descr->last_optimized_A_handle = nullptr; + spsv_descr->last_optimized_x_handle = nullptr; + spsv_descr->last_optimized_y_handle = nullptr; + delete spsv_descr; + }; + + // Use dispatch_submit to ensure the descriptor is kept alive as long as the buffers are used + // dispatch_submit can only be used if the descriptor's handles are valid + if (spsv_descr->last_optimized_A_handle && + spsv_descr->last_optimized_A_handle->all_use_buffer() && + spsv_descr->last_optimized_x_handle && spsv_descr->last_optimized_y_handle && + spsv_descr->workspace.use_buffer()) { + auto dispatch_functor = [=](sycl::interop_handle, sycl::accessor) { + release_functor(); + }; + return detail::dispatch_submit( + __func__, queue, dispatch_functor, spsv_descr->last_optimized_A_handle, + spsv_descr->workspace.get_buffer(), spsv_descr->last_optimized_x_handle, + spsv_descr->last_optimized_y_handle); + } + + // Release used if USM is used or if the descriptor has been released before spmv_optimize has succeeded + sycl::event event = queue.submit([&](sycl::handler& cgh) { + cgh.depends_on(dependencies); + cgh.host_task(release_functor); + }); + return event; +} + +void spsv_buffer_size(sycl::queue& queue, oneapi::math::transpose opA, const void* alpha, + oneapi::math::sparse::matrix_view A_view, + oneapi::math::sparse::matrix_handle_t A_handle, + oneapi::math::sparse::dense_vector_handle_t x_handle, + oneapi::math::sparse::dense_vector_handle_t y_handle, + oneapi::math::sparse::spsv_alg alg, + oneapi::math::sparse::spsv_descr_t spsv_descr, + std::size_t& temp_buffer_size) { + bool is_alpha_host_accessible = detail::is_ptr_accessible_on_host(queue, alpha); + detail::check_valid_spsv(__func__, A_view, A_handle, x_handle, y_handle, + is_alpha_host_accessible); + bool is_in_order_queue = queue.is_in_order(); + auto functor = [=, &temp_buffer_size](sycl::interop_handle ih) { + detail::RocsparseScopedContextHandler sc(queue, ih); + auto [roc_handle, hip_stream] = sc.get_handle_and_stream(queue); + spsv_descr->roc_handle = roc_handle; + spsv_descr->hip_stream = hip_stream; + auto roc_a = A_handle->backend_handle; + auto roc_x = x_handle->backend_handle; + auto roc_y = y_handle->backend_handle; + detail::set_matrix_attributes(__func__, roc_a, A_view); + auto roc_op = detail::get_roc_operation(opA); + auto roc_type = detail::get_roc_value_type(A_handle->value_container.data_type); + auto roc_alg = detail::get_roc_spsv_alg(alg); + detail::set_pointer_mode(roc_handle, is_alpha_host_accessible); + auto status = + rocsparse_spsv(roc_handle, roc_op, alpha, roc_a, roc_x, roc_y, roc_type, roc_alg, + rocsparse_spsv_stage_buffer_size, &temp_buffer_size, nullptr); + detail::check_status(status, __func__); + detail::synchronize_if_needed(is_in_order_queue, hip_stream); + }; + auto event = detail::dispatch_submit(__func__, queue, functor, A_handle, x_handle, y_handle); + event.wait_and_throw(); + spsv_descr->temp_buffer_size = temp_buffer_size; + spsv_descr->buffer_size_called = true; +} + +void spsv_optimize(sycl::queue& queue, oneapi::math::transpose opA, const void* alpha, + oneapi::math::sparse::matrix_view A_view, + oneapi::math::sparse::matrix_handle_t A_handle, + oneapi::math::sparse::dense_vector_handle_t x_handle, + oneapi::math::sparse::dense_vector_handle_t y_handle, + oneapi::math::sparse::spsv_alg alg, + oneapi::math::sparse::spsv_descr_t spsv_descr, + sycl::buffer workspace) { + bool is_alpha_host_accessible = detail::is_ptr_accessible_on_host(queue, alpha); + if (!A_handle->all_use_buffer()) { + detail::throw_incompatible_container(__func__); + } + A_handle->check_valid_handle(__func__); + detail::common_spsv_optimize(opA, is_alpha_host_accessible, A_view, A_handle, x_handle, + y_handle, alg, spsv_descr); + // Ignore spsv_alg::no_optimize_alg as this step is mandatory for rocSPARSE + // Copy the buffer to extend its lifetime until the descriptor is free'd. + spsv_descr->workspace.set_buffer_untyped(workspace); + std::size_t buffer_size = spsv_descr->temp_buffer_size; + // The accessor can only be created if the buffer size is greater than 0 + if (buffer_size > 0) { + auto functor = [=](sycl::interop_handle ih, sycl::accessor workspace_acc) { + auto roc_handle = spsv_descr->roc_handle; + auto workspace_ptr = detail::get_mem(ih, workspace_acc); + detail::spsv_optimize_impl(roc_handle, opA, alpha, A_view, A_handle, x_handle, y_handle, + alg, buffer_size, workspace_ptr, is_alpha_host_accessible); + }; + + detail::dispatch_submit(__func__, queue, functor, A_handle, workspace, x_handle, y_handle); + } + else { + auto functor = [=](sycl::interop_handle) { + auto roc_handle = spsv_descr->roc_handle; + detail::spsv_optimize_impl(roc_handle, opA, alpha, A_view, A_handle, x_handle, y_handle, + alg, buffer_size, nullptr, is_alpha_host_accessible); + }; + + detail::dispatch_submit(__func__, queue, functor, A_handle, x_handle, y_handle); + } +} + +sycl::event spsv_optimize(sycl::queue& queue, oneapi::math::transpose opA, const void* alpha, + oneapi::math::sparse::matrix_view A_view, + oneapi::math::sparse::matrix_handle_t A_handle, + oneapi::math::sparse::dense_vector_handle_t x_handle, + oneapi::math::sparse::dense_vector_handle_t y_handle, + oneapi::math::sparse::spsv_alg alg, + oneapi::math::sparse::spsv_descr_t spsv_descr, void* workspace, + const std::vector& dependencies) { + bool is_alpha_host_accessible = detail::is_ptr_accessible_on_host(queue, alpha); + if (A_handle->all_use_buffer()) { + detail::throw_incompatible_container(__func__); + } + A_handle->check_valid_handle(__func__); + detail::common_spsv_optimize(opA, is_alpha_host_accessible, A_view, A_handle, x_handle, + y_handle, alg, spsv_descr); + spsv_descr->workspace.usm_ptr = workspace; + // Ignore spsv_alg::no_optimize_alg as this step is mandatory for rocSPARSE + std::size_t buffer_size = spsv_descr->temp_buffer_size; + auto functor = [=](sycl::interop_handle) { + auto roc_handle = spsv_descr->roc_handle; + detail::spsv_optimize_impl(roc_handle, opA, alpha, A_view, A_handle, x_handle, y_handle, + alg, buffer_size, workspace, is_alpha_host_accessible); + }; + + return detail::dispatch_submit(__func__, queue, dependencies, functor, A_handle, x_handle, + y_handle); +} + +sycl::event spsv(sycl::queue& queue, oneapi::math::transpose opA, const void* alpha, + oneapi::math::sparse::matrix_view A_view, + oneapi::math::sparse::matrix_handle_t A_handle, + oneapi::math::sparse::dense_vector_handle_t x_handle, + oneapi::math::sparse::dense_vector_handle_t y_handle, + oneapi::math::sparse::spsv_alg alg, oneapi::math::sparse::spsv_descr_t spsv_descr, + const std::vector& dependencies) { + bool is_alpha_host_accessible = detail::is_ptr_accessible_on_host(queue, alpha); + if (A_handle->all_use_buffer() != spsv_descr->workspace.use_buffer()) { + detail::throw_incompatible_container(__func__); + } + detail::check_valid_spsv(__func__, A_view, A_handle, x_handle, y_handle, + is_alpha_host_accessible); + + if (!spsv_descr->optimized_called) { + throw math::uninitialized( + "sparse_blas", __func__, + "spsv_optimize must be called with the same arguments before spsv."); + } + CHECK_DESCR_MATCH(spsv_descr, opA, "spsv_optimize"); + CHECK_DESCR_MATCH(spsv_descr, A_view, "spsv_optimize"); + CHECK_DESCR_MATCH(spsv_descr, A_handle, "spsv_optimize"); + CHECK_DESCR_MATCH(spsv_descr, x_handle, "spsv_optimize"); + CHECK_DESCR_MATCH(spsv_descr, y_handle, "spsv_optimize"); + CHECK_DESCR_MATCH(spsv_descr, alg, "spsv_optimize"); + + A_handle->mark_used(); + bool is_in_order_queue = queue.is_in_order(); + auto compute_functor = [=](void* workspace_ptr) { + auto roc_handle = spsv_descr->roc_handle; + auto hip_stream = spsv_descr->hip_stream; + auto buffer_size = spsv_descr->temp_buffer_size; + auto roc_a = A_handle->backend_handle; + auto roc_x = x_handle->backend_handle; + auto roc_y = y_handle->backend_handle; + detail::set_matrix_attributes(__func__, roc_a, A_view); + auto roc_op = detail::get_roc_operation(opA); + auto roc_type = detail::get_roc_value_type(A_handle->value_container.data_type); + auto roc_alg = detail::get_roc_spsv_alg(alg); + detail::set_pointer_mode(roc_handle, is_alpha_host_accessible); + auto status = + rocsparse_spsv(roc_handle, roc_op, alpha, roc_a, roc_x, roc_y, roc_type, roc_alg, + rocsparse_spsv_stage_compute, &buffer_size, workspace_ptr); + detail::check_status(status, __func__); + detail::synchronize_if_needed(is_in_order_queue, hip_stream); + }; + // The accessor can only be created if the buffer size is greater than 0 + if (A_handle->all_use_buffer() && spsv_descr->temp_buffer_size > 0) { + auto functor_buffer = [=](sycl::interop_handle ih, + sycl::accessor workspace_acc) { + auto workspace_ptr = detail::get_mem(ih, workspace_acc); + compute_functor(workspace_ptr); + }; + return detail::dispatch_submit_native_ext(__func__, queue, functor_buffer, A_handle, + spsv_descr->workspace.get_buffer(), + x_handle, y_handle); + } + else { + // The same dispatch_submit can be used for USM or buffers if no + // workspace accessor is needed. + // workspace_ptr will be a nullptr in the latter case. + auto workspace_ptr = spsv_descr->workspace.usm_ptr; + auto functor_usm = [=](sycl::interop_handle) { + compute_functor(workspace_ptr); + }; + return detail::dispatch_submit_native_ext(__func__, queue, dependencies, functor_usm, + A_handle, x_handle, y_handle); + } +} + +} // namespace oneapi::math::sparse::rocsparse diff --git a/src/sparse_blas/backends/rocsparse/rocsparse_error.hpp b/src/sparse_blas/backends/rocsparse/rocsparse_error.hpp new file mode 100644 index 000000000..59da045c9 --- /dev/null +++ b/src/sparse_blas/backends/rocsparse/rocsparse_error.hpp @@ -0,0 +1,126 @@ +/*************************************************************************** +* Copyright (C) Codeplay Software Limited +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* For your convenience, a copy of the License has been included in this +* repository. +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +* +**************************************************************************/ + +#ifndef _ONEMATH_SPARSE_BLAS_BACKENDS_ROCSPARSE_ERROR_HPP_ +#define _ONEMATH_SPARSE_BLAS_BACKENDS_ROCSPARSE_ERROR_HPP_ + +#include + +#include +#include + +#include "oneapi/math/exceptions.hpp" + +namespace oneapi::math::sparse::rocsparse::detail { + +inline std::string hip_result_to_str(hipError_t result) { + switch (result) { +#define ONEMATH_ROCSPARSE_CASE(STATUS) \ + case STATUS: return #STATUS + ONEMATH_ROCSPARSE_CASE(hipSuccess); + ONEMATH_ROCSPARSE_CASE(hipErrorInvalidContext); + ONEMATH_ROCSPARSE_CASE(hipErrorInvalidKernelFile); + ONEMATH_ROCSPARSE_CASE(hipErrorMemoryAllocation); + ONEMATH_ROCSPARSE_CASE(hipErrorInitializationError); + ONEMATH_ROCSPARSE_CASE(hipErrorLaunchFailure); + ONEMATH_ROCSPARSE_CASE(hipErrorLaunchOutOfResources); + ONEMATH_ROCSPARSE_CASE(hipErrorInvalidDevice); + ONEMATH_ROCSPARSE_CASE(hipErrorInvalidValue); + ONEMATH_ROCSPARSE_CASE(hipErrorInvalidDevicePointer); + ONEMATH_ROCSPARSE_CASE(hipErrorInvalidMemcpyDirection); + ONEMATH_ROCSPARSE_CASE(hipErrorUnknown); + ONEMATH_ROCSPARSE_CASE(hipErrorInvalidResourceHandle); + ONEMATH_ROCSPARSE_CASE(hipErrorNotReady); + ONEMATH_ROCSPARSE_CASE(hipErrorNoDevice); + ONEMATH_ROCSPARSE_CASE(hipErrorPeerAccessAlreadyEnabled); + ONEMATH_ROCSPARSE_CASE(hipErrorPeerAccessNotEnabled); + ONEMATH_ROCSPARSE_CASE(hipErrorRuntimeMemory); + ONEMATH_ROCSPARSE_CASE(hipErrorRuntimeOther); + ONEMATH_ROCSPARSE_CASE(hipErrorHostMemoryAlreadyRegistered); + ONEMATH_ROCSPARSE_CASE(hipErrorHostMemoryNotRegistered); + ONEMATH_ROCSPARSE_CASE(hipErrorMapBufferObjectFailed); + ONEMATH_ROCSPARSE_CASE(hipErrorTbd); + default: return ""; + } +} + +#define HIP_ERROR_FUNC(func, ...) \ + do { \ + auto res = func(__VA_ARGS__); \ + if (res != hipSuccess) { \ + throw oneapi::math::exception("sparse_blas", #func, \ + "hip error: " + detail::hip_result_to_str(res)); \ + } \ + } while (0) + +inline std::string rocsparse_status_to_str(rocsparse_status status) { + switch (status) { +#define ONEMATH_ROCSPARSE_CASE(STATUS) \ + case STATUS: return #STATUS + ONEMATH_ROCSPARSE_CASE(rocsparse_status_success); + ONEMATH_ROCSPARSE_CASE(rocsparse_status_invalid_handle); + ONEMATH_ROCSPARSE_CASE(rocsparse_status_not_implemented); + ONEMATH_ROCSPARSE_CASE(rocsparse_status_invalid_pointer); + ONEMATH_ROCSPARSE_CASE(rocsparse_status_invalid_size); + ONEMATH_ROCSPARSE_CASE(rocsparse_status_memory_error); + ONEMATH_ROCSPARSE_CASE(rocsparse_status_internal_error); + ONEMATH_ROCSPARSE_CASE(rocsparse_status_invalid_value); + ONEMATH_ROCSPARSE_CASE(rocsparse_status_arch_mismatch); + ONEMATH_ROCSPARSE_CASE(rocsparse_status_zero_pivot); + ONEMATH_ROCSPARSE_CASE(rocsparse_status_not_initialized); + ONEMATH_ROCSPARSE_CASE(rocsparse_status_type_mismatch); + ONEMATH_ROCSPARSE_CASE(rocsparse_status_requires_sorted_storage); + ONEMATH_ROCSPARSE_CASE(rocsparse_status_thrown_exception); + ONEMATH_ROCSPARSE_CASE(rocsparse_status_continue); +#undef ONEMATH_ROCSPARSE_CASE + default: return ""; + } +} + +inline void check_status(rocsparse_status status, const std::string& function, + std::string error_str = "") { + if (status != rocsparse_status_success) { + if (!error_str.empty()) { + error_str += "; "; + } + error_str += "rocSPARSE status: " + rocsparse_status_to_str(status); + switch (status) { + case rocsparse_status_not_implemented: + throw oneapi::math::unimplemented("sparse_blas", function, error_str); + case rocsparse_status_invalid_handle: + case rocsparse_status_invalid_pointer: + case rocsparse_status_invalid_size: + case rocsparse_status_invalid_value: + throw oneapi::math::invalid_argument("sparse_blas", function, error_str); + case rocsparse_status_not_initialized: + throw oneapi::math::uninitialized("sparse_blas", function, error_str); + default: throw oneapi::math::exception("sparse_blas", function, error_str); + } + } +} + +#define ROCSPARSE_ERR_FUNC(func, ...) \ + do { \ + auto status = func(__VA_ARGS__); \ + detail::check_status(status, #func); \ + } while (0) + +} // namespace oneapi::math::sparse::rocsparse::detail + +#endif // _ONEMATH_SPARSE_BLAS_BACKENDS_ROCSPARSE_ERROR_HPP_ diff --git a/src/sparse_blas/backends/rocsparse/rocsparse_global_handle.hpp b/src/sparse_blas/backends/rocsparse/rocsparse_global_handle.hpp new file mode 100644 index 000000000..78699bc8f --- /dev/null +++ b/src/sparse_blas/backends/rocsparse/rocsparse_global_handle.hpp @@ -0,0 +1,63 @@ +/*************************************************************************** +* Copyright (C) Codeplay Software Limited +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* For your convenience, a copy of the License has been included in this +* repository. +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +* +**************************************************************************/ + +#ifndef _ONEMATH_SPARSE_BLAS_BACKENDS_ROCSPARSE_GLOBAL_HANDLE_HPP_ +#define _ONEMATH_SPARSE_BLAS_BACKENDS_ROCSPARSE_GLOBAL_HANDLE_HPP_ + +/** + * @file Similar to blas_handle.hpp + * Provides a map from a ur_context_handle_t (or equivalent) to a rocsparse_handle. + * @see rocsparse_scope_handle.hpp +*/ + +#include +#include + +namespace oneapi::math::sparse::rocsparse::detail { + +template +struct rocsparse_global_handle { + using handle_container_t = std::unordered_map*>; + handle_container_t rocsparse_global_handle_mapper_{}; + + ~rocsparse_global_handle() noexcept(false) { + for (auto& handle_pair : rocsparse_global_handle_mapper_) { + if (handle_pair.second != nullptr) { + auto handle = handle_pair.second->exchange(nullptr); + if (handle != nullptr) { + ROCSPARSE_ERR_FUNC(rocsparse_destroy_handle, handle); + handle = nullptr; + } + else { + // if the handle is nullptr it means the handle was already + // destroyed by the ContextCallback and we're free to delete the + // atomic object. + delete handle_pair.second; + } + + handle_pair.second = nullptr; + } + } + rocsparse_global_handle_mapper_.clear(); + } +}; + +} // namespace oneapi::math::sparse::rocsparse::detail + +#endif // _ONEMATH_SPARSE_BLAS_BACKENDS_ROCSPARSE_GLOBAL_HANDLE_HPP_ diff --git a/src/sparse_blas/backends/rocsparse/rocsparse_handles.cpp b/src/sparse_blas/backends/rocsparse/rocsparse_handles.cpp new file mode 100644 index 000000000..63cd41c7e --- /dev/null +++ b/src/sparse_blas/backends/rocsparse/rocsparse_handles.cpp @@ -0,0 +1,491 @@ +/*************************************************************************** +* Copyright (C) Codeplay Software Limited +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* For your convenience, a copy of the License has been included in this +* repository. +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +* +**************************************************************************/ + +#include "oneapi/math/sparse_blas/detail/rocsparse/onemath_sparse_blas_rocsparse.hpp" + +#include "rocsparse_error.hpp" +#include "rocsparse_helper.hpp" +#include "rocsparse_handles.hpp" +#include "rocsparse_scope_handle.hpp" +#include "rocsparse_task.hpp" +#include "sparse_blas/macros.hpp" + +namespace oneapi::math::sparse::rocsparse { + +/** + * In this file RocsparseScopedContextHandler are used to ensure that a rocsparse_handle is created before any other rocSPARSE call, as required by the specification. +*/ + +// Dense vector +template +void init_dense_vector(sycl::queue& queue, dense_vector_handle_t* p_dvhandle, std::int64_t size, + sycl::buffer val) { + auto event = queue.submit([&](sycl::handler& cgh) { + auto acc = val.template get_access(cgh); + detail::submit_host_task(cgh, queue, [=](sycl::interop_handle ih) { + // Ensure that a rocsparse handle is created before any other rocSPARSE function is called. + detail::RocsparseScopedContextHandler(queue, ih).get_handle(queue); + auto roc_value_type = detail::RocEnumType::value; + rocsparse_dnvec_descr roc_dvhandle; + ROCSPARSE_ERR_FUNC(rocsparse_create_dnvec_descr, &roc_dvhandle, size, + detail::get_mem(ih, acc), roc_value_type); + *p_dvhandle = new dense_vector_handle(roc_dvhandle, val, size); + }); + }); + event.wait_and_throw(); +} + +template +void init_dense_vector(sycl::queue& queue, dense_vector_handle_t* p_dvhandle, std::int64_t size, + fpType* val) { + auto event = queue.submit([&](sycl::handler& cgh) { + detail::submit_host_task(cgh, queue, [=](sycl::interop_handle ih) { + // Ensure that a rocsparse handle is created before any other rocSPARSE function is called. + detail::RocsparseScopedContextHandler(queue, ih).get_handle(queue); + auto roc_value_type = detail::RocEnumType::value; + rocsparse_dnvec_descr roc_dvhandle; + ROCSPARSE_ERR_FUNC(rocsparse_create_dnvec_descr, &roc_dvhandle, size, val, + roc_value_type); + *p_dvhandle = new dense_vector_handle(roc_dvhandle, val, size); + }); + }); + event.wait_and_throw(); +} + +template +void set_dense_vector_data(sycl::queue& queue, dense_vector_handle_t dvhandle, std::int64_t size, + sycl::buffer val) { + detail::check_can_reset_value_handle(__func__, dvhandle, true); + auto event = queue.submit([&](sycl::handler& cgh) { + auto acc = val.template get_access(cgh); + detail::submit_host_task(cgh, queue, [=](sycl::interop_handle ih) { + if (dvhandle->size != size) { + ROCSPARSE_ERR_FUNC(rocsparse_destroy_dnvec_descr, dvhandle->backend_handle); + auto roc_value_type = detail::RocEnumType::value; + ROCSPARSE_ERR_FUNC(rocsparse_create_dnvec_descr, &dvhandle->backend_handle, size, + detail::get_mem(ih, acc), roc_value_type); + dvhandle->size = size; + } + else { + ROCSPARSE_ERR_FUNC(rocsparse_dnvec_set_values, dvhandle->backend_handle, + detail::get_mem(ih, acc)); + } + dvhandle->set_buffer(val); + }); + }); + event.wait_and_throw(); +} + +template +void set_dense_vector_data(sycl::queue&, dense_vector_handle_t dvhandle, std::int64_t size, + fpType* val) { + detail::check_can_reset_value_handle(__func__, dvhandle, false); + if (dvhandle->size != size) { + ROCSPARSE_ERR_FUNC(rocsparse_destroy_dnvec_descr, dvhandle->backend_handle); + auto roc_value_type = detail::RocEnumType::value; + ROCSPARSE_ERR_FUNC(rocsparse_create_dnvec_descr, &dvhandle->backend_handle, size, val, + roc_value_type); + dvhandle->size = size; + } + else { + ROCSPARSE_ERR_FUNC(rocsparse_dnvec_set_values, dvhandle->backend_handle, val); + } + dvhandle->set_usm_ptr(val); +} + +FOR_EACH_FP_TYPE(INSTANTIATE_DENSE_VECTOR_FUNCS); + +sycl::event release_dense_vector(sycl::queue& queue, dense_vector_handle_t dvhandle, + const std::vector& dependencies) { + // Use dispatch_submit_impl_fp to ensure the backend's handle is kept alive as long as the buffer is used + auto functor = [=](sycl::interop_handle) { + ROCSPARSE_ERR_FUNC(rocsparse_destroy_dnvec_descr, dvhandle->backend_handle); + delete dvhandle; + }; + return detail::dispatch_submit_impl_fp(__func__, queue, dependencies, functor, dvhandle); +} + +// Dense matrix +template +void init_dense_matrix(sycl::queue& queue, dense_matrix_handle_t* p_dmhandle, std::int64_t num_rows, + std::int64_t num_cols, std::int64_t ld, layout dense_layout, + sycl::buffer val) { + auto event = queue.submit([&](sycl::handler& cgh) { + auto acc = val.template get_access(cgh); + detail::submit_host_task(cgh, queue, [=](sycl::interop_handle ih) { + // Ensure that a rocsparse handle is created before any other rocSPARSE function is called. + detail::RocsparseScopedContextHandler(queue, ih).get_handle(queue); + auto roc_value_type = detail::RocEnumType::value; + auto roc_order = detail::get_roc_order(dense_layout); + rocsparse_dnmat_descr roc_dmhandle; + ROCSPARSE_ERR_FUNC(rocsparse_create_dnmat_descr, &roc_dmhandle, num_rows, num_cols, ld, + detail::get_mem(ih, acc), roc_value_type, roc_order); + *p_dmhandle = + new dense_matrix_handle(roc_dmhandle, val, num_rows, num_cols, ld, dense_layout); + }); + }); + event.wait_and_throw(); +} + +template +void init_dense_matrix(sycl::queue& queue, dense_matrix_handle_t* p_dmhandle, std::int64_t num_rows, + std::int64_t num_cols, std::int64_t ld, layout dense_layout, fpType* val) { + auto event = queue.submit([&](sycl::handler& cgh) { + detail::submit_host_task(cgh, queue, [=](sycl::interop_handle ih) { + // Ensure that a rocsparse handle is created before any other rocSPARSE function is called. + detail::RocsparseScopedContextHandler(queue, ih).get_handle(queue); + auto roc_value_type = detail::RocEnumType::value; + auto roc_order = detail::get_roc_order(dense_layout); + rocsparse_dnmat_descr roc_dmhandle; + ROCSPARSE_ERR_FUNC(rocsparse_create_dnmat_descr, &roc_dmhandle, num_rows, num_cols, ld, + val, roc_value_type, roc_order); + *p_dmhandle = + new dense_matrix_handle(roc_dmhandle, val, num_rows, num_cols, ld, dense_layout); + }); + }); + event.wait_and_throw(); +} + +template +void set_dense_matrix_data(sycl::queue& queue, dense_matrix_handle_t dmhandle, + std::int64_t num_rows, std::int64_t num_cols, std::int64_t ld, + oneapi::math::layout dense_layout, sycl::buffer val) { + detail::check_can_reset_value_handle(__func__, dmhandle, true); + auto event = queue.submit([&](sycl::handler& cgh) { + auto acc = val.template get_access(cgh); + detail::submit_host_task(cgh, queue, [=](sycl::interop_handle ih) { + if (dmhandle->num_rows != num_rows || dmhandle->num_cols != num_cols || + dmhandle->ld != ld || dmhandle->dense_layout != dense_layout) { + ROCSPARSE_ERR_FUNC(rocsparse_destroy_dnmat_descr, dmhandle->backend_handle); + auto roc_value_type = detail::RocEnumType::value; + auto roc_order = detail::get_roc_order(dense_layout); + ROCSPARSE_ERR_FUNC(rocsparse_create_dnmat_descr, &dmhandle->backend_handle, + num_rows, num_cols, ld, detail::get_mem(ih, acc), roc_value_type, + roc_order); + dmhandle->num_rows = num_rows; + dmhandle->num_cols = num_cols; + dmhandle->ld = ld; + dmhandle->dense_layout = dense_layout; + } + else { + ROCSPARSE_ERR_FUNC(rocsparse_dnmat_set_values, dmhandle->backend_handle, + detail::get_mem(ih, acc)); + } + dmhandle->set_buffer(val); + }); + }); + event.wait_and_throw(); +} + +template +void set_dense_matrix_data(sycl::queue&, dense_matrix_handle_t dmhandle, std::int64_t num_rows, + std::int64_t num_cols, std::int64_t ld, + oneapi::math::layout dense_layout, fpType* val) { + detail::check_can_reset_value_handle(__func__, dmhandle, false); + if (dmhandle->num_rows != num_rows || dmhandle->num_cols != num_cols || dmhandle->ld != ld || + dmhandle->dense_layout != dense_layout) { + ROCSPARSE_ERR_FUNC(rocsparse_destroy_dnmat_descr, dmhandle->backend_handle); + auto roc_value_type = detail::RocEnumType::value; + auto roc_order = detail::get_roc_order(dense_layout); + ROCSPARSE_ERR_FUNC(rocsparse_create_dnmat_descr, &dmhandle->backend_handle, num_rows, + num_cols, ld, val, roc_value_type, roc_order); + dmhandle->num_rows = num_rows; + dmhandle->num_cols = num_cols; + dmhandle->ld = ld; + dmhandle->dense_layout = dense_layout; + } + else { + ROCSPARSE_ERR_FUNC(rocsparse_dnmat_set_values, dmhandle->backend_handle, val); + } + dmhandle->set_usm_ptr(val); +} + +FOR_EACH_FP_TYPE(INSTANTIATE_DENSE_MATRIX_FUNCS); + +sycl::event release_dense_matrix(sycl::queue& queue, dense_matrix_handle_t dmhandle, + const std::vector& dependencies) { + // Use dispatch_submit_impl_fp to ensure the backend's handle is kept alive as long as the buffer is used + auto functor = [=](sycl::interop_handle) { + ROCSPARSE_ERR_FUNC(rocsparse_destroy_dnmat_descr, dmhandle->backend_handle); + delete dmhandle; + }; + return detail::dispatch_submit_impl_fp(__func__, queue, dependencies, functor, dmhandle); +} + +// COO matrix +template +void init_coo_matrix(sycl::queue& queue, matrix_handle_t* p_smhandle, std::int64_t num_rows, + std::int64_t num_cols, std::int64_t nnz, oneapi::math::index_base index, + sycl::buffer row_ind, sycl::buffer col_ind, + sycl::buffer val) { + auto event = queue.submit([&](sycl::handler& cgh) { + auto row_acc = row_ind.template get_access(cgh); + auto col_acc = col_ind.template get_access(cgh); + auto val_acc = val.template get_access(cgh); + detail::submit_host_task(cgh, queue, [=](sycl::interop_handle ih) { + // Ensure that a rocsparse handle is created before any other rocSPARSE function is called. + detail::RocsparseScopedContextHandler(queue, ih).get_handle(queue); + auto roc_index_type = detail::RocIndexEnumType::value; + auto roc_index_base = detail::get_roc_index_base(index); + auto roc_value_type = detail::RocEnumType::value; + rocsparse_spmat_descr roc_smhandle; + ROCSPARSE_ERR_FUNC(rocsparse_create_coo_descr, &roc_smhandle, num_rows, num_cols, nnz, + detail::get_mem(ih, row_acc), detail::get_mem(ih, col_acc), + detail::get_mem(ih, val_acc), roc_index_type, roc_index_base, + roc_value_type); + *p_smhandle = + new matrix_handle(roc_smhandle, row_ind, col_ind, val, detail::sparse_format::COO, + num_rows, num_cols, nnz, index); + }); + }); + event.wait_and_throw(); +} + +template +void init_coo_matrix(sycl::queue& queue, matrix_handle_t* p_smhandle, std::int64_t num_rows, + std::int64_t num_cols, std::int64_t nnz, oneapi::math::index_base index, + intType* row_ind, intType* col_ind, fpType* val) { + auto event = queue.submit([&](sycl::handler& cgh) { + detail::submit_host_task(cgh, queue, [=](sycl::interop_handle ih) { + // Ensure that a rocsparse handle is created before any other rocSPARSE function is called. + detail::RocsparseScopedContextHandler(queue, ih).get_handle(queue); + auto roc_index_type = detail::RocIndexEnumType::value; + auto roc_index_base = detail::get_roc_index_base(index); + auto roc_value_type = detail::RocEnumType::value; + rocsparse_spmat_descr roc_smhandle; + ROCSPARSE_ERR_FUNC(rocsparse_create_coo_descr, &roc_smhandle, num_rows, num_cols, nnz, + row_ind, col_ind, val, roc_index_type, roc_index_base, + roc_value_type); + *p_smhandle = + new matrix_handle(roc_smhandle, row_ind, col_ind, val, detail::sparse_format::COO, + num_rows, num_cols, nnz, index); + }); + }); + event.wait_and_throw(); +} + +template +void set_coo_matrix_data(sycl::queue& queue, matrix_handle_t smhandle, std::int64_t num_rows, + std::int64_t num_cols, std::int64_t nnz, oneapi::math::index_base index, + sycl::buffer row_ind, sycl::buffer col_ind, + sycl::buffer val) { + detail::check_can_reset_sparse_handle(__func__, smhandle, true); + auto event = queue.submit([&](sycl::handler& cgh) { + auto row_acc = row_ind.template get_access(cgh); + auto col_acc = col_ind.template get_access(cgh); + auto val_acc = val.template get_access(cgh); + detail::submit_host_task(cgh, queue, [=](sycl::interop_handle ih) { + if (smhandle->num_rows != num_rows || smhandle->num_cols != num_cols || + smhandle->nnz != nnz || smhandle->index != index) { + ROCSPARSE_ERR_FUNC(rocsparse_destroy_spmat_descr, smhandle->backend_handle); + auto roc_index_type = detail::RocIndexEnumType::value; + auto roc_index_base = detail::get_roc_index_base(index); + auto roc_value_type = detail::RocEnumType::value; + ROCSPARSE_ERR_FUNC(rocsparse_create_coo_descr, &smhandle->backend_handle, num_rows, + num_cols, nnz, detail::get_mem(ih, row_acc), + detail::get_mem(ih, col_acc), detail::get_mem(ih, val_acc), + roc_index_type, roc_index_base, roc_value_type); + smhandle->num_rows = num_rows; + smhandle->num_cols = num_cols; + smhandle->nnz = nnz; + smhandle->index = index; + } + else { + ROCSPARSE_ERR_FUNC(rocsparse_coo_set_pointers, smhandle->backend_handle, + detail::get_mem(ih, row_acc), detail::get_mem(ih, col_acc), + detail::get_mem(ih, val_acc)); + } + smhandle->row_container.set_buffer(row_ind); + smhandle->col_container.set_buffer(col_ind); + smhandle->value_container.set_buffer(val); + }); + }); + event.wait_and_throw(); +} + +template +void set_coo_matrix_data(sycl::queue&, matrix_handle_t smhandle, std::int64_t num_rows, + std::int64_t num_cols, std::int64_t nnz, oneapi::math::index_base index, + intType* row_ind, intType* col_ind, fpType* val) { + detail::check_can_reset_sparse_handle(__func__, smhandle, false); + if (smhandle->num_rows != num_rows || smhandle->num_cols != num_cols || smhandle->nnz != nnz || + smhandle->index != index) { + ROCSPARSE_ERR_FUNC(rocsparse_destroy_spmat_descr, smhandle->backend_handle); + auto roc_index_type = detail::RocIndexEnumType::value; + auto roc_index_base = detail::get_roc_index_base(index); + auto roc_value_type = detail::RocEnumType::value; + ROCSPARSE_ERR_FUNC(rocsparse_create_coo_descr, &smhandle->backend_handle, num_rows, + num_cols, nnz, row_ind, col_ind, val, roc_index_type, roc_index_base, + roc_value_type); + smhandle->num_rows = num_rows; + smhandle->num_cols = num_cols; + smhandle->nnz = nnz; + smhandle->index = index; + } + else { + ROCSPARSE_ERR_FUNC(rocsparse_coo_set_pointers, smhandle->backend_handle, row_ind, col_ind, + val); + } + smhandle->row_container.set_usm_ptr(row_ind); + smhandle->col_container.set_usm_ptr(col_ind); + smhandle->value_container.set_usm_ptr(val); +} + +FOR_EACH_FP_AND_INT_TYPE(INSTANTIATE_COO_MATRIX_FUNCS); + +// CSR matrix +template +void init_csr_matrix(sycl::queue& queue, matrix_handle_t* p_smhandle, std::int64_t num_rows, + std::int64_t num_cols, std::int64_t nnz, oneapi::math::index_base index, + sycl::buffer row_ptr, sycl::buffer col_ind, + sycl::buffer val) { + auto event = queue.submit([&](sycl::handler& cgh) { + auto row_acc = row_ptr.template get_access(cgh); + auto col_acc = col_ind.template get_access(cgh); + auto val_acc = val.template get_access(cgh); + detail::submit_host_task(cgh, queue, [=](sycl::interop_handle ih) { + // Ensure that a rocsparse handle is created before any other rocSPARSE function is called. + detail::RocsparseScopedContextHandler(queue, ih).get_handle(queue); + auto roc_index_type = detail::RocIndexEnumType::value; + auto roc_index_base = detail::get_roc_index_base(index); + auto roc_value_type = detail::RocEnumType::value; + rocsparse_spmat_descr roc_smhandle; + ROCSPARSE_ERR_FUNC(rocsparse_create_csr_descr, &roc_smhandle, num_rows, num_cols, nnz, + detail::get_mem(ih, row_acc), detail::get_mem(ih, col_acc), + detail::get_mem(ih, val_acc), roc_index_type, roc_index_type, + roc_index_base, roc_value_type); + *p_smhandle = + new matrix_handle(roc_smhandle, row_ptr, col_ind, val, detail::sparse_format::CSR, + num_rows, num_cols, nnz, index); + }); + }); + event.wait_and_throw(); +} + +template +void init_csr_matrix(sycl::queue& queue, matrix_handle_t* p_smhandle, std::int64_t num_rows, + std::int64_t num_cols, std::int64_t nnz, oneapi::math::index_base index, + intType* row_ptr, intType* col_ind, fpType* val) { + auto event = queue.submit([&](sycl::handler& cgh) { + detail::submit_host_task(cgh, queue, [=](sycl::interop_handle ih) { + // Ensure that a rocsparse handle is created before any other rocSPARSE function is called. + detail::RocsparseScopedContextHandler(queue, ih).get_handle(queue); + auto roc_index_type = detail::RocIndexEnumType::value; + auto roc_index_base = detail::get_roc_index_base(index); + auto roc_value_type = detail::RocEnumType::value; + rocsparse_spmat_descr roc_smhandle; + ROCSPARSE_ERR_FUNC(rocsparse_create_csr_descr, &roc_smhandle, num_rows, num_cols, nnz, + row_ptr, col_ind, val, roc_index_type, roc_index_type, + roc_index_base, roc_value_type); + *p_smhandle = + new matrix_handle(roc_smhandle, row_ptr, col_ind, val, detail::sparse_format::CSR, + num_rows, num_cols, nnz, index); + }); + }); + event.wait_and_throw(); +} + +template +void set_csr_matrix_data(sycl::queue& queue, matrix_handle_t smhandle, std::int64_t num_rows, + std::int64_t num_cols, std::int64_t nnz, oneapi::math::index_base index, + sycl::buffer row_ptr, sycl::buffer col_ind, + sycl::buffer val) { + detail::check_can_reset_sparse_handle(__func__, smhandle, true); + auto event = queue.submit([&](sycl::handler& cgh) { + auto row_acc = row_ptr.template get_access(cgh); + auto col_acc = col_ind.template get_access(cgh); + auto val_acc = val.template get_access(cgh); + detail::submit_host_task(cgh, queue, [=](sycl::interop_handle ih) { + if (smhandle->num_rows != num_rows || smhandle->num_cols != num_cols || + smhandle->nnz != nnz || smhandle->index != index) { + ROCSPARSE_ERR_FUNC(rocsparse_destroy_spmat_descr, smhandle->backend_handle); + auto roc_index_type = detail::RocIndexEnumType::value; + auto roc_index_base = detail::get_roc_index_base(index); + auto roc_value_type = detail::RocEnumType::value; + ROCSPARSE_ERR_FUNC(rocsparse_create_csr_descr, &smhandle->backend_handle, num_rows, + num_cols, nnz, detail::get_mem(ih, row_acc), + detail::get_mem(ih, col_acc), detail::get_mem(ih, val_acc), + roc_index_type, roc_index_type, roc_index_base, roc_value_type); + smhandle->num_rows = num_rows; + smhandle->num_cols = num_cols; + smhandle->nnz = nnz; + smhandle->index = index; + } + else { + ROCSPARSE_ERR_FUNC(rocsparse_csr_set_pointers, smhandle->backend_handle, + detail::get_mem(ih, row_acc), detail::get_mem(ih, col_acc), + detail::get_mem(ih, val_acc)); + } + smhandle->row_container.set_buffer(row_ptr); + smhandle->col_container.set_buffer(col_ind); + smhandle->value_container.set_buffer(val); + }); + }); + event.wait_and_throw(); +} + +template +void set_csr_matrix_data(sycl::queue&, matrix_handle_t smhandle, std::int64_t num_rows, + std::int64_t num_cols, std::int64_t nnz, oneapi::math::index_base index, + intType* row_ptr, intType* col_ind, fpType* val) { + detail::check_can_reset_sparse_handle(__func__, smhandle, false); + if (smhandle->num_rows != num_rows || smhandle->num_cols != num_cols || smhandle->nnz != nnz || + smhandle->index != index) { + ROCSPARSE_ERR_FUNC(rocsparse_destroy_spmat_descr, smhandle->backend_handle); + auto roc_index_type = detail::RocIndexEnumType::value; + auto roc_index_base = detail::get_roc_index_base(index); + auto roc_value_type = detail::RocEnumType::value; + ROCSPARSE_ERR_FUNC(rocsparse_create_csr_descr, &smhandle->backend_handle, num_rows, + num_cols, nnz, row_ptr, col_ind, val, roc_index_type, roc_index_type, + roc_index_base, roc_value_type); + smhandle->num_rows = num_rows; + smhandle->num_cols = num_cols; + smhandle->nnz = nnz; + smhandle->index = index; + } + else { + ROCSPARSE_ERR_FUNC(rocsparse_csr_set_pointers, smhandle->backend_handle, row_ptr, col_ind, + val); + } + smhandle->row_container.set_usm_ptr(row_ptr); + smhandle->col_container.set_usm_ptr(col_ind); + smhandle->value_container.set_usm_ptr(val); +} + +FOR_EACH_FP_AND_INT_TYPE(INSTANTIATE_CSR_MATRIX_FUNCS); + +sycl::event release_sparse_matrix(sycl::queue& queue, matrix_handle_t smhandle, + const std::vector& dependencies) { + // Use dispatch_submit to ensure the backend's handle is kept alive as long as the buffers are used + auto functor = [=](sycl::interop_handle) { + ROCSPARSE_ERR_FUNC(rocsparse_destroy_spmat_descr, smhandle->backend_handle); + delete smhandle; + }; + return detail::dispatch_submit(__func__, queue, dependencies, functor, smhandle); +} + +// Matrix property +bool set_matrix_property(sycl::queue&, matrix_handle_t smhandle, matrix_property property) { + // No equivalent in rocSPARSE + // Store the matrix property internally for future usages + smhandle->set_matrix_property(property); + return false; +} + +} // namespace oneapi::math::sparse::rocsparse diff --git a/src/sparse_blas/backends/rocsparse/rocsparse_handles.hpp b/src/sparse_blas/backends/rocsparse/rocsparse_handles.hpp new file mode 100644 index 000000000..a7c429fad --- /dev/null +++ b/src/sparse_blas/backends/rocsparse/rocsparse_handles.hpp @@ -0,0 +1,111 @@ +/*************************************************************************** +* Copyright (C) Codeplay Software Limited +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* For your convenience, a copy of the License has been included in this +* repository. +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +* +**************************************************************************/ + +#ifndef _ONEMATH_SRC_SPARSE_BLAS_BACKENDS_ROCSPARSE_HANDLES_HPP_ +#define _ONEMATH_SRC_SPARSE_BLAS_BACKENDS_ROCSPARSE_HANDLES_HPP_ + +#include + +#include "sparse_blas/generic_container.hpp" + +namespace oneapi::math::sparse { + +// Complete the definition of incomplete types dense_vector_handle, dense_matrix_handle and matrix_handle. + +struct dense_vector_handle : public detail::generic_dense_vector_handle { + template + dense_vector_handle(rocsparse_dnvec_descr roc_descr, T* value_ptr, std::int64_t size) + : detail::generic_dense_vector_handle(roc_descr, value_ptr, + size) {} + + template + dense_vector_handle(rocsparse_dnvec_descr roc_descr, const sycl::buffer value_buffer, + std::int64_t size) + : detail::generic_dense_vector_handle(roc_descr, value_buffer, + size) {} +}; + +struct dense_matrix_handle : public detail::generic_dense_matrix_handle { + template + dense_matrix_handle(rocsparse_dnmat_descr roc_descr, T* value_ptr, std::int64_t num_rows, + std::int64_t num_cols, std::int64_t ld, layout dense_layout) + : detail::generic_dense_matrix_handle( + roc_descr, value_ptr, num_rows, num_cols, ld, dense_layout) {} + + template + dense_matrix_handle(rocsparse_dnmat_descr roc_descr, const sycl::buffer value_buffer, + std::int64_t num_rows, std::int64_t num_cols, std::int64_t ld, + layout dense_layout) + : detail::generic_dense_matrix_handle( + roc_descr, value_buffer, num_rows, num_cols, ld, dense_layout) {} +}; + +struct matrix_handle : public detail::generic_sparse_handle { + // A matrix handle should only be used once per operation to be safe with the rocSPARSE backend. + // An operation can store information in the handle. See details in https://github.com/ROCm/rocSPARSE/issues/332. +private: + bool used = false; + +public: + template + matrix_handle(rocsparse_spmat_descr roc_descr, intType* row_ptr, intType* col_ptr, + fpType* value_ptr, detail::sparse_format format, std::int64_t num_rows, + std::int64_t num_cols, std::int64_t nnz, oneapi::math::index_base index) + : detail::generic_sparse_handle( + roc_descr, row_ptr, col_ptr, value_ptr, format, num_rows, num_cols, nnz, index) {} + + template + matrix_handle(rocsparse_spmat_descr roc_descr, const sycl::buffer row_buffer, + const sycl::buffer col_buffer, + const sycl::buffer value_buffer, detail::sparse_format format, + std::int64_t num_rows, std::int64_t num_cols, std::int64_t nnz, + oneapi::math::index_base index) + : detail::generic_sparse_handle(roc_descr, row_buffer, + col_buffer, value_buffer, format, + num_rows, num_cols, nnz, index) { + } + + void check_valid_handle(const std::string& function_name) { + if (used) { + throw math::unimplemented( + "sparse_blas", function_name, + "The backend does not support re-using the same sparse matrix handle in multiple operations."); + } + if (this->format == detail::sparse_format::COO && + !this->has_matrix_property(matrix_property::sorted)) { + throw math::unimplemented( + "sparse_blas", function_name, + "The backend does not support unsorted COO format. Use `set_matrix_property` to set the property `matrix_property::sorted`"); + } + if (this->format == detail::sparse_format::CSR && + !this->has_matrix_property(matrix_property::sorted)) { + throw math::unimplemented( + "sparse_blas", function_name, + "The backend does not support unsorted CSR format. Use `set_matrix_property` to set the property `matrix_property::sorted`"); + } + } + + void mark_used() { + used = true; + } +}; + +} // namespace oneapi::math::sparse + +#endif // _ONEMATH_SRC_SPARSE_BLAS_BACKENDS_ROCSPARSE_HANDLES_HPP_ diff --git a/src/sparse_blas/backends/rocsparse/rocsparse_helper.hpp b/src/sparse_blas/backends/rocsparse/rocsparse_helper.hpp new file mode 100644 index 000000000..56e3e6b75 --- /dev/null +++ b/src/sparse_blas/backends/rocsparse/rocsparse_helper.hpp @@ -0,0 +1,162 @@ +/*************************************************************************** +* Copyright (C) Codeplay Software Limited +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* For your convenience, a copy of the License has been included in this +* repository. +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +* +**************************************************************************/ +#ifndef _ONEMATH_SPARSE_BLAS_BACKENDS_ROCSPARSE_HELPER_HPP_ +#define _ONEMATH_SPARSE_BLAS_BACKENDS_ROCSPARSE_HELPER_HPP_ + +#include +#include +#include +#include + +#include + +#include "oneapi/math/sparse_blas/types.hpp" +#include "sparse_blas/enum_data_types.hpp" +#include "sparse_blas/sycl_helper.hpp" +#include "rocsparse_error.hpp" + +namespace oneapi::math::sparse::rocsparse::detail { + +using namespace oneapi::math::sparse::detail; + +template +struct RocEnumType; +template <> +struct RocEnumType { + static constexpr rocsparse_datatype value = rocsparse_datatype_f32_r; +}; +template <> +struct RocEnumType { + static constexpr rocsparse_datatype value = rocsparse_datatype_f64_r; +}; +template <> +struct RocEnumType> { + static constexpr rocsparse_datatype value = rocsparse_datatype_f32_c; +}; +template <> +struct RocEnumType> { + static constexpr rocsparse_datatype value = rocsparse_datatype_f64_c; +}; + +template +struct RocIndexEnumType; +template <> +struct RocIndexEnumType { + static constexpr rocsparse_indextype value = rocsparse_indextype_i32; +}; +template <> +struct RocIndexEnumType { + static constexpr rocsparse_indextype value = rocsparse_indextype_i64; +}; + +template +inline std::string cast_enum_to_str(E e) { + return std::to_string(static_cast(e)); +} + +inline auto get_roc_value_type(data_type onemath_data_type) { + switch (onemath_data_type) { + case data_type::real_fp32: return rocsparse_datatype_f32_r; + case data_type::real_fp64: return rocsparse_datatype_f64_r; + case data_type::complex_fp32: return rocsparse_datatype_f32_c; + case data_type::complex_fp64: return rocsparse_datatype_f64_c; + default: + throw oneapi::math::invalid_argument( + "sparse_blas", "get_roc_value_type", + "Invalid data type: " + cast_enum_to_str(onemath_data_type)); + } +} + +inline auto get_roc_order(layout l) { + switch (l) { + case layout::row_major: return rocsparse_order_row; + case layout::col_major: return rocsparse_order_column; + default: + throw oneapi::math::invalid_argument("sparse_blas", "get_roc_order", + "Unknown layout: " + cast_enum_to_str(l)); + } +} + +inline auto get_roc_index_base(index_base index) { + switch (index) { + case index_base::zero: return rocsparse_index_base_zero; + case index_base::one: return rocsparse_index_base_one; + default: + throw oneapi::math::invalid_argument("sparse_blas", "get_roc_index_base", + "Unknown index_base: " + cast_enum_to_str(index)); + } +} + +inline auto get_roc_operation(transpose op) { + switch (op) { + case transpose::nontrans: return rocsparse_operation_none; + case transpose::trans: return rocsparse_operation_transpose; + case transpose::conjtrans: return rocsparse_operation_conjugate_transpose; + default: + throw oneapi::math::invalid_argument( + "sparse_blas", "get_roc_operation", + "Unknown transpose operation: " + cast_enum_to_str(op)); + } +} + +inline auto get_roc_uplo(uplo uplo_val) { + switch (uplo_val) { + case uplo::upper: return rocsparse_fill_mode_upper; + case uplo::lower: return rocsparse_fill_mode_lower; + default: + throw oneapi::math::invalid_argument("sparse_blas", "get_roc_uplo", + "Unknown uplo: " + cast_enum_to_str(uplo_val)); + } +} + +inline auto get_roc_diag(diag diag_val) { + switch (diag_val) { + case diag::nonunit: return rocsparse_diag_type_non_unit; + case diag::unit: return rocsparse_diag_type_unit; + default: + throw oneapi::math::invalid_argument("sparse_blas", "get_roc_diag", + "Unknown diag: " + cast_enum_to_str(diag_val)); + } +} + +inline void set_matrix_attributes(const std::string& func_name, rocsparse_spmat_descr roc_a, + oneapi::math::sparse::matrix_view A_view) { + auto roc_fill_mode = get_roc_uplo(A_view.uplo_view); + auto status = rocsparse_spmat_set_attribute(roc_a, rocsparse_spmat_fill_mode, &roc_fill_mode, + sizeof(roc_fill_mode)); + check_status(status, func_name + "/set_uplo"); + + auto roc_diag_type = get_roc_diag(A_view.diag_view); + status = rocsparse_spmat_set_attribute(roc_a, rocsparse_spmat_diag_type, &roc_diag_type, + sizeof(roc_diag_type)); + check_status(status, func_name + "/set_diag"); +} + +/** + * rocSPARSE requires to set the pointer mode for scalars parameters (typically alpha and beta). + */ +inline void set_pointer_mode(rocsparse_handle roc_handle, bool is_ptr_accessible_on_host) { + rocsparse_set_pointer_mode(roc_handle, is_ptr_accessible_on_host + ? rocsparse_pointer_mode_host + : rocsparse_pointer_mode_device); +} + +} // namespace oneapi::math::sparse::rocsparse::detail + +#endif //_ONEMATH_SPARSE_BLAS_BACKENDS_ROCSPARSE_HELPER_HPP_ diff --git a/src/sparse_blas/backends/rocsparse/rocsparse_scope_handle.cpp b/src/sparse_blas/backends/rocsparse/rocsparse_scope_handle.cpp new file mode 100644 index 000000000..b933b4ba4 --- /dev/null +++ b/src/sparse_blas/backends/rocsparse/rocsparse_scope_handle.cpp @@ -0,0 +1,135 @@ +/*************************************************************************** +* Copyright (C) Codeplay Software Limited +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* For your convenience, a copy of the License has been included in this +* repository. +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +* +**************************************************************************/ + +/** + * @file Similar to rocblas_scope_handle.cpp +*/ + +#include "rocsparse_scope_handle.hpp" + +namespace oneapi::math::sparse::rocsparse::detail { + +/** + * Inserts a new element in the map if its key is unique. This new element + * is constructed in place using args as the arguments for the construction + * of a value_type (which is an object of a pair type). The insertion only + * takes place if no other element in the container has a key equivalent to + * the one being emplaced (keys in a map container are unique). + */ +#ifdef ONEAPI_ONEMATH_PI_INTERFACE_REMOVED +thread_local rocsparse_handle_container + RocsparseScopedContextHandler::handle_helper = + rocsparse_handle_container{}; +#else +thread_local rocsparse_handle_container RocsparseScopedContextHandler::handle_helper = + rocsparse_handle_container{}; +#endif + +// Disable warning for deprecated hipCtxGetCurrent and similar hip runtime functions +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wdeprecated-declarations" + +RocsparseScopedContextHandler::RocsparseScopedContextHandler(sycl::queue queue, + sycl::interop_handle& ih) + : ih(ih), + needToRecover_(false) { + placedContext_ = new sycl::context(queue.get_context()); + auto hipDevice = ih.get_native_device(); + hipCtx_t desired; + HIP_ERROR_FUNC(hipCtxGetCurrent, &original_); + HIP_ERROR_FUNC(hipDevicePrimaryCtxRetain, &desired, hipDevice); + if (original_ != desired) { + // Sets the desired context as the active one for the thread + HIP_ERROR_FUNC(hipCtxSetCurrent, desired); + // No context is installed and the suggested context is primary + // This is the most common case. We can activate the context in the + // thread and leave it there until all the PI context referring to the + // same underlying rocsparse primary context are destroyed. This emulates + // the behaviour of the rocsparse runtime api, and avoids costly context + // switches. No action is required on this side of the if. + needToRecover_ = !(original_ == nullptr); + } +} + +RocsparseScopedContextHandler::~RocsparseScopedContextHandler() noexcept(false) { + if (needToRecover_) { + HIP_ERROR_FUNC(hipCtxSetCurrent, original_); + } + delete placedContext_; +} + +void ContextCallback(void* userData) { + auto* atomic_ptr = static_cast*>(userData); + if (!atomic_ptr) { + return; + } + auto handle = atomic_ptr->exchange(nullptr); + if (handle != nullptr) { + ROCSPARSE_ERR_FUNC(rocsparse_destroy_handle, handle); + } + delete atomic_ptr; +} + +std::pair RocsparseScopedContextHandler::get_handle_and_stream( + const sycl::queue& queue) { + auto hipDevice = ih.get_native_device(); + hipCtx_t desired; + HIP_ERROR_FUNC(hipDevicePrimaryCtxRetain, &desired, hipDevice); +#ifdef ONEAPI_ONEMATH_PI_INTERFACE_REMOVED + auto piPlacedContext_ = reinterpret_cast(desired); +#else + auto piPlacedContext_ = reinterpret_cast(desired); +#endif + hipStream_t streamId = sycl::get_native(queue); + auto it = handle_helper.rocsparse_handle_container_mapper_.find(piPlacedContext_); + if (it != handle_helper.rocsparse_handle_container_mapper_.end()) { + if (it->second == nullptr) { + handle_helper.rocsparse_handle_container_mapper_.erase(it); + } + else { + auto handle = it->second->load(); + if (handle != nullptr) { + hipStream_t currentStreamId; + ROCSPARSE_ERR_FUNC(rocsparse_get_stream, handle, ¤tStreamId); + if (currentStreamId != streamId) { + ROCSPARSE_ERR_FUNC(rocsparse_set_stream, handle, streamId); + } + return { handle, streamId }; + } + else { + handle_helper.rocsparse_handle_container_mapper_.erase(it); + } + } + } + + rocsparse_handle handle; + ROCSPARSE_ERR_FUNC(rocsparse_create_handle, &handle); + ROCSPARSE_ERR_FUNC(rocsparse_set_stream, handle, streamId); + + auto atomic_ptr = new std::atomic(handle); + handle_helper.rocsparse_handle_container_mapper_.insert( + std::make_pair(piPlacedContext_, atomic_ptr)); + + sycl::detail::pi::contextSetExtendedDeleter(*placedContext_, ContextCallback, atomic_ptr); + return { handle, streamId }; +} + +#pragma clang diagnostic pop + +} // namespace oneapi::math::sparse::rocsparse::detail diff --git a/src/sparse_blas/backends/rocsparse/rocsparse_scope_handle.hpp b/src/sparse_blas/backends/rocsparse/rocsparse_scope_handle.hpp new file mode 100644 index 000000000..3cb72f455 --- /dev/null +++ b/src/sparse_blas/backends/rocsparse/rocsparse_scope_handle.hpp @@ -0,0 +1,87 @@ +/*************************************************************************** +* Copyright (C) Codeplay Software Limited +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* For your convenience, a copy of the License has been included in this +* repository. +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +* +**************************************************************************/ +#ifndef _ONEMATH_SPARSE_BLAS_BACKENDS_ROCSPARSE_SCOPE_HANDLE_HPP_ +#define _ONEMATH_SPARSE_BLAS_BACKENDS_ROCSPARSE_SCOPE_HANDLE_HPP_ + +/** + * @file Similar to rocblas_scope_handle.hpp +*/ + +#if __has_include() +#include +#else +#include +#endif + +// After Plugin Interface removal in DPC++ ur.hpp is the new include +#if __has_include() && !defined(ONEAPI_ONEMATH_PI_INTERFACE_REMOVED) +#define ONEAPI_ONEMATH_PI_INTERFACE_REMOVED +#endif + +#include + +#include "rocsparse_error.hpp" +#include "rocsparse_global_handle.hpp" +#include "rocsparse_helper.hpp" + +namespace oneapi::math::sparse::rocsparse::detail { + +template +struct rocsparse_handle_container { + using handle_container_t = std::unordered_map*>; + handle_container_t rocsparse_handle_container_mapper_{}; + + // Do not free any pointer nor handle in this destructor. The resources are + // free'd via the PI ContextCallback to ensure the context is still alive. + ~rocsparse_handle_container() = default; +}; + +class RocsparseScopedContextHandler { + HIPcontext original_; + sycl::context* placedContext_; + sycl::interop_handle& ih; + bool needToRecover_; +#ifdef ONEAPI_ONEMATH_PI_INTERFACE_REMOVED + static thread_local rocsparse_handle_container handle_helper; +#else + static thread_local rocsparse_handle_container handle_helper; +#endif + +public: + RocsparseScopedContextHandler(sycl::queue queue, sycl::interop_handle& ih); + ~RocsparseScopedContextHandler() noexcept(false); + + std::pair get_handle_and_stream(const sycl::queue& queue); + + rocsparse_handle get_handle(const sycl::queue& queue) { + return get_handle_and_stream(queue).first; + } +}; + +// Get the native pointer from an accessor. This is a different pointer than +// what can be retrieved with get_multi_ptr. +template +inline void* get_mem(sycl::interop_handle ih, AccT acc) { + auto hipPtr = ih.get_native_mem(acc); + return reinterpret_cast(hipPtr); +} + +} // namespace oneapi::math::sparse::rocsparse::detail + +#endif //_ONEMATH_SPARSE_BLAS_BACKENDS_ROCSPARSE_SCOPE_HANDLE_HPP_ diff --git a/src/sparse_blas/backends/rocsparse/rocsparse_task.hpp b/src/sparse_blas/backends/rocsparse/rocsparse_task.hpp new file mode 100644 index 000000000..29a4ef901 --- /dev/null +++ b/src/sparse_blas/backends/rocsparse/rocsparse_task.hpp @@ -0,0 +1,44 @@ +/*************************************************************************** +* Copyright (C) Codeplay Software Limited +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* For your convenience, a copy of the License has been included in this +* repository. +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +* +**************************************************************************/ + +#ifndef _ONEMATH_SPARSE_BLAS_BACKENDS_ROCSPARSE_TASK_HPP_ +#define _ONEMATH_SPARSE_BLAS_BACKENDS_ROCSPARSE_TASK_HPP_ + +#include "rocsparse_error.hpp" +#include "sparse_blas/backends/common_launch_task.hpp" + +namespace oneapi::math::sparse::rocsparse::detail { + +// Helper function for functors submitted to host_task or native_command. +// When the extension is disabled, host_task are used and the synchronization is needed to ensure the sycl::event corresponds to the end of the whole functor. +// When the extension is enabled, host_task are still used for out-of-order queues, see description of dispatch_submit_impl_fp_int. +inline void synchronize_if_needed(bool is_in_order_queue, hipStream_t hip_stream) { +#ifndef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND + (void)is_in_order_queue; + HIP_ERROR_FUNC(hipStreamSynchronize, hip_stream); +#else + if (!is_in_order_queue) { + HIP_ERROR_FUNC(hipStreamSynchronize, hip_stream); + } +#endif +} + +} // namespace oneapi::math::sparse::rocsparse::detail + +#endif // _ONEMATH_SPARSE_BLAS_BACKENDS_ROCSPARSE_TASK_HPP_ diff --git a/src/sparse_blas/backends/rocsparse/rocsparse_wrappers.cpp b/src/sparse_blas/backends/rocsparse/rocsparse_wrappers.cpp new file mode 100644 index 000000000..0c23372d8 --- /dev/null +++ b/src/sparse_blas/backends/rocsparse/rocsparse_wrappers.cpp @@ -0,0 +1,32 @@ +/*************************************************************************** +* Copyright (C) Codeplay Software Limited +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* For your convenience, a copy of the License has been included in this +* repository. +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +* +**************************************************************************/ + +#include "oneapi/math/sparse_blas/types.hpp" + +#include "oneapi/math/sparse_blas/detail/rocsparse/onemath_sparse_blas_rocsparse.hpp" + +#include "sparse_blas/function_table.hpp" + +#define WRAPPER_VERSION 1 +#define BACKEND rocsparse + +extern "C" sparse_blas_function_table_t onemath_sparse_blas_table = { + WRAPPER_VERSION, +#include "sparse_blas/backends/backend_wrappers.cxx" +}; diff --git a/tests/unit_tests/CMakeLists.txt b/tests/unit_tests/CMakeLists.txt index 13108dd32..e335d4315 100644 --- a/tests/unit_tests/CMakeLists.txt +++ b/tests/unit_tests/CMakeLists.txt @@ -197,6 +197,11 @@ foreach(domain ${TEST_TARGET_DOMAINS}) list(APPEND ONEMATH_LIBRARIES_${domain} onemath_${domain}_cusparse) endif() + if(domain STREQUAL "sparse_blas" AND ENABLE_ROCSPARSE_BACKEND) + add_dependencies(test_main_${domain}_ct onemath_${domain}_rocsparse) + list(APPEND ONEMATH_LIBRARIES_${domain} onemath_${domain}_rocsparse) + endif() + target_link_libraries(test_main_${domain}_ct PUBLIC gtest gtest_main diff --git a/tests/unit_tests/include/test_helper.hpp b/tests/unit_tests/include/test_helper.hpp index 5d2996b57..f91d71d96 100644 --- a/tests/unit_tests/include/test_helper.hpp +++ b/tests/unit_tests/include/test_helper.hpp @@ -183,6 +183,13 @@ #define TEST_RUN_NVIDIAGPU_CUSPARSE_SELECT(q, func, ...) #endif +#ifdef ONEMATH_ENABLE_ROCSPARSE_BACKEND +#define TEST_RUN_AMDGPU_ROCSPARSE_SELECT(q, func, ...) \ + func(oneapi::math::backend_selector{ q }, __VA_ARGS__) +#else +#define TEST_RUN_AMDGPU_ROCSPARSE_SELECT(q, func, ...) +#endif + #ifndef __HIPSYCL__ #define CHECK_HOST_OR_CPU(q) q.get_device().is_cpu() #else @@ -278,6 +285,9 @@ else if (vendor_id == NVIDIA_ID) { \ TEST_RUN_NVIDIAGPU_CUSPARSE_SELECT(q, func, __VA_ARGS__); \ } \ + else if (vendor_id == AMD_ID) { \ + TEST_RUN_AMDGPU_ROCSPARSE_SELECT(q, func, __VA_ARGS__); \ + } \ } \ } while (0); diff --git a/tests/unit_tests/main_test.cpp b/tests/unit_tests/main_test.cpp index a170c45da..bcc0bec38 100644 --- a/tests/unit_tests/main_test.cpp +++ b/tests/unit_tests/main_test.cpp @@ -135,7 +135,8 @@ int main(int argc, char** argv) { #if !defined(ONEMATH_ENABLE_ROCBLAS_BACKEND) && !defined(ONEMATH_ENABLE_ROCRAND_BACKEND) && \ !defined(ONEMATH_ENABLE_ROCSOLVER_BACKEND) && \ !defined(ONEMATH_ENABLE_PORTBLAS_BACKEND_AMD_GPU) && \ - !defined(ONEMATH_ENABLE_ROCFFT_BACKEND) && !defined(ONEMATH_ENABLE_PORTFFT_BACKEND) + !defined(ONEMATH_ENABLE_ROCFFT_BACKEND) && !defined(ONEMATH_ENABLE_PORTFFT_BACKEND) && \ + !defined(ONEMATH_ENABLE_ROCSPARSE_BACKEND) if (dev.is_gpu() && vendor_id == AMD_ID) continue; #endif diff --git a/tests/unit_tests/sparse_blas/include/test_common.hpp b/tests/unit_tests/sparse_blas/include/test_common.hpp index 6a577a389..da466e6cb 100644 --- a/tests/unit_tests/sparse_blas/include/test_common.hpp +++ b/tests/unit_tests/sparse_blas/include/test_common.hpp @@ -65,6 +65,10 @@ inline std::set get_default_matrix_proper if (vendor_id == oneapi::math::device::nvidiagpu && format == sparse_matrix_format_t::COO) { return { oneapi::math::sparse::matrix_property::sorted_by_rows }; } + if (vendor_id == oneapi::math::device::amdgpu && + (format == sparse_matrix_format_t::COO || format == sparse_matrix_format_t::CSR)) { + return { oneapi::math::sparse::matrix_property::sorted }; + } return {}; } @@ -72,15 +76,6 @@ inline std::set get_default_matrix_proper inline std::vector> get_all_matrix_properties_combinations(sycl::queue queue, sparse_matrix_format_t format) { auto vendor_id = oneapi::math::get_device_id(queue); - if (vendor_id == oneapi::math::device::nvidiagpu && format == sparse_matrix_format_t::COO) { - // Ensure all the sets have the sorted or sorted_by_rows properties - return { { oneapi::math::sparse::matrix_property::sorted }, - { oneapi::math::sparse::matrix_property::sorted_by_rows, - oneapi::math::sparse::matrix_property::symmetric }, - { oneapi::math::sparse::matrix_property::sorted, - oneapi::math::sparse::matrix_property::symmetric } }; - } - std::vector> properties_combinations{ { oneapi::math::sparse::matrix_property::sorted }, { oneapi::math::sparse::matrix_property::symmetric }, @@ -91,6 +86,10 @@ get_all_matrix_properties_combinations(sycl::queue queue, sparse_matrix_format_t properties_combinations.push_back( { oneapi::math::sparse::matrix_property::sorted_by_rows }); } + if (vendor_id == oneapi::math::device::nvidiagpu || vendor_id == oneapi::math::device::amdgpu) { + // Test without any properties set since for backends for which this is not the default behavior + properties_combinations.push_back({}); + } return properties_combinations; } @@ -230,7 +229,7 @@ void rand_matrix(std::vector& m, oneapi::math::layout layout_val, std::s } /// Generate random value in the range [-0.5, 0.5] -/// The amplitude is guaranteed to be >= 0.1 if is_diag is true +/// The amplitude is guaranteed to be >= 10 if is_diag is true template fpType generate_data(bool is_diag) { rand_scalar rand_data;