From 821f85f1153361fcdee328b702090a5282386afe Mon Sep 17 00:00:00 2001 From: Muhammad Tanvir Date: Fri, 17 Nov 2023 15:33:46 +0000 Subject: [PATCH 1/2] Fixed is_device_copyable error --- include/operations/blas_constants.h | 12 +++++ src/interface/blas2_interface.hpp | 71 ++++++++++++++++------------- src/operations/blas2/trsv.hpp | 0 3 files changed, 52 insertions(+), 31 deletions(-) delete mode 100644 src/operations/blas2/trsv.hpp diff --git a/include/operations/blas_constants.h b/include/operations/blas_constants.h index 637f23f95..a5f77d1ba 100644 --- a/include/operations/blas_constants.h +++ b/include/operations/blas_constants.h @@ -265,4 +265,16 @@ struct constant_pair { } // namespace blas +#define PORTBLAS_DEVICE_COPY(T1, T2) \ + template <> \ + struct sycl::is_device_copyable> \ + : std::true_type {}; + +PORTBLAS_DEVICE_COPY(int32_t, float) +PORTBLAS_DEVICE_COPY(int32_t, double) +PORTBLAS_DEVICE_COPY(int64_t, float) +PORTBLAS_DEVICE_COPY(int64_t, double) + +#undef PORTBLAS_DEVICE_COPY + #endif // BLAS_CONSTANTS_H diff --git a/src/interface/blas2_interface.hpp b/src/interface/blas2_interface.hpp index ee9ea5fc0..3e6ee2e31 100644 --- a/src/interface/blas2_interface.hpp +++ b/src/interface/blas2_interface.hpp @@ -345,10 +345,10 @@ template -typename sb_handle_t::event_t _trsv_impl(sb_handle_t& sb_handle, index_t _N, - container_t0 _mA, index_t _lda, - container_t1 _vx, increment_t _incx, - const typename sb_handle_t::event_t& _dependencies) { +typename sb_handle_t::event_t _trsv_impl( + sb_handle_t& sb_handle, index_t _N, container_t0 _mA, index_t _lda, + container_t1 _vx, increment_t _incx, + const typename sb_handle_t::event_t& _dependencies) { #if (SYCL_LANGUAGE_VERSION < 202000) || (defined __HIPSYCL__) throw std::runtime_error("trsv requires SYCL 2020"); #else @@ -392,7 +392,8 @@ typename sb_handle_t::event_t _trsv_impl(sb_handle_t& sb_handle, index_t _N, auto ret = sb_handle.execute( trsv, static_cast(sub_num * subgroup_size), roundUp(sub_num * _N, sub_num * subgroup_size), - static_cast(subgroup_size * (subgroup_size + 2 + sub_num)), _dependencies); + static_cast(subgroup_size * (subgroup_size + 2 + sub_num)), + _dependencies); blas::helper::enqueue_deallocate(ret, sync_buffer, queue); @@ -727,18 +728,16 @@ template -typename sb_handle_t::event_t _tbsv_impl(sb_handle_t& sb_handle, index_t _N, - index_t _K, container_t0 _mA, - index_t _lda, container_t1 _vx, - increment_t _incx, - const typename sb_handle_t::event_t& _dependencies) { +typename sb_handle_t::event_t _tbsv_impl( + sb_handle_t& sb_handle, index_t _N, index_t _K, container_t0 _mA, + index_t _lda, container_t1 _vx, increment_t _incx, + const typename sb_handle_t::event_t& _dependencies) { #if (SYCL_LANGUAGE_VERSION < 202000) || (defined __HIPSYCL__) throw std::runtime_error("tbsv requires SYCL 2020"); #else static_assert(subgroup_size % subgroups == 0, "`subgroups` needs to be a multiple of `subgroup_size`."); - if (_K >= _N) throw std::invalid_argument("Erroneous parameter: _K >= _N"); using one = constant; @@ -780,7 +779,8 @@ typename sb_handle_t::event_t _tbsv_impl(sb_handle_t& sb_handle, index_t _N, auto ret = sb_handle.execute( tbsv, static_cast(sub_num * subgroup_size), roundUp(sub_num * _N, sub_num * subgroup_size), - static_cast(subgroup_size * (subgroup_size + 2 + sub_num)), _dependencies); + static_cast(subgroup_size * (subgroup_size + 2 + sub_num)), + _dependencies); blas::helper::enqueue_deallocate(ret, sync_buffer, queue); @@ -792,10 +792,9 @@ template -typename sb_handle_t::event_t _tpsv_impl(sb_handle_t& sb_handle, index_t _N, - container_t0 _mA, container_t1 _vx, - increment_t _incx, - const typename sb_handle_t::event_t& _dependencies) { +typename sb_handle_t::event_t _tpsv_impl( + sb_handle_t& sb_handle, index_t _N, container_t0 _mA, container_t1 _vx, + increment_t _incx, const typename sb_handle_t::event_t& _dependencies) { #if (SYCL_LANGUAGE_VERSION < 202000) || (defined __HIPSYCL__) throw std::runtime_error("tpsv requires SYCL 2020"); #else @@ -823,8 +822,18 @@ typename sb_handle_t::event_t _tpsv_impl(sb_handle_t& sb_handle, index_t _N, : ((roundUp(_N, subgroup_size) / subgroup_size) - 1); sync_vec[1] = sync_vec[0]; - auto sync_buffer = - blas::make_sycl_iterator_buffer(sync_vec, sync_vec.size()); + constexpr bool is_usm = std::is_pointer::value; + auto queue = sb_handle.get_queue(); + + auto sync_buffer = blas::helper::allocate < is_usm + ? blas::helper::AllocType::usm + : blas::helper::AllocType::buffer, + int32_t > (sync_vec.size(), queue); + + auto copy_sync = blas::helper::copy_to_device( + queue, sync_vec.data(), sync_buffer, sync_vec.size()); + sb_handle.wait(copy_sync); + auto sync = make_vector_view(sync_buffer, one_increment_t::value(), sync_vec.size()); @@ -833,11 +842,13 @@ typename sb_handle_t::event_t _tpsv_impl(sb_handle_t& sb_handle, index_t _N, vx, mA, sync); const index_t sub_num = subgroups; - return sb_handle.execute( + auto ret = sb_handle.execute( tpsv, static_cast(sub_num * subgroup_size), roundUp(sub_num * _N, sub_num * subgroup_size), static_cast(subgroup_size * (subgroup_size + 2 + sub_num)), _dependencies); + blas::helper::enqueue_deallocate(ret, sync_buffer, queue); + return ret; #endif } @@ -1329,11 +1340,10 @@ typename sb_handle_t::event_t inline _spr2( template -typename sb_handle_t::event_t inline _syr2(sb_handle_t& sb_handle, char _Uplo, - index_t _N, element_t _alpha, - container_t0 _vx, increment_t _incx, - container_t1 _vy, increment_t _incy, - container_t2 _mA, index_t _lda, +typename sb_handle_t::event_t inline _syr2( + sb_handle_t& sb_handle, char _Uplo, index_t _N, element_t _alpha, + container_t0 _vx, increment_t _incx, container_t1 _vy, increment_t _incy, + container_t2 _mA, index_t _lda, const typename sb_handle_t::event_t& _dependencies) { // TODO: Here we can use some heuristics to select localn global, local, and // scratch size per device @@ -1366,17 +1376,16 @@ typename sb_handle_t::event_t _tpmv( sb_handle_t& sb_handle, char _Uplo, char _trans, char _Diag, index_t _N, container_t0 _mA, container_t1 _vx, increment_t _incx, const typename sb_handle_t::event_t& _dependencies) { -INST_UPLO_TRANS_DIAG(blas::tpmv::backend::_tpmv, sb_handle, _N, _mA, _vx, - _incx, _dependencies) + INST_UPLO_TRANS_DIAG(blas::tpmv::backend::_tpmv, sb_handle, _N, _mA, _vx, + _incx, _dependencies) } template -typename sb_handle_t::event_t _tpsv(sb_handle_t& sb_handle, char _Uplo, - char _trans, char _Diag, index_t _N, - container_t0 _mA, container_t1 _vx, - increment_t _incx, -const typename sb_handle_t::event_t& _dependencies) { +typename sb_handle_t::event_t _tpsv( + sb_handle_t& sb_handle, char _Uplo, char _trans, char _Diag, index_t _N, + container_t0 _mA, container_t1 _vx, increment_t _incx, + const typename sb_handle_t::event_t& _dependencies) { INST_UPLO_TRANS_DIAG(blas::tpsv::backend::_tpsv, sb_handle, _N, _mA, _vx, _incx, _dependencies) } diff --git a/src/operations/blas2/trsv.hpp b/src/operations/blas2/trsv.hpp deleted file mode 100644 index e69de29bb..000000000 From 18c96e10c979efc991c367333876ab542349e234 Mon Sep 17 00:00:00 2001 From: Muhammad Tanvir Date: Wed, 22 Nov 2023 12:26:20 +0000 Subject: [PATCH 2/2] Addressed comment --- include/operations/blas_constants.h | 14 +++----------- 1 file changed, 3 insertions(+), 11 deletions(-) diff --git a/include/operations/blas_constants.h b/include/operations/blas_constants.h index a5f77d1ba..8214b2cf8 100644 --- a/include/operations/blas_constants.h +++ b/include/operations/blas_constants.h @@ -265,16 +265,8 @@ struct constant_pair { } // namespace blas -#define PORTBLAS_DEVICE_COPY(T1, T2) \ - template <> \ - struct sycl::is_device_copyable> \ - : std::true_type {}; - -PORTBLAS_DEVICE_COPY(int32_t, float) -PORTBLAS_DEVICE_COPY(int32_t, double) -PORTBLAS_DEVICE_COPY(int64_t, float) -PORTBLAS_DEVICE_COPY(int64_t, double) - -#undef PORTBLAS_DEVICE_COPY +template +struct sycl::is_device_copyable> + : std::true_type {}; #endif // BLAS_CONSTANTS_H