diff --git a/include/interface/blas1_interface.h b/include/interface/blas1_interface.h index 80d104e01..f344e55be 100644 --- a/include/interface/blas1_interface.h +++ b/include/interface/blas1_interface.h @@ -257,9 +257,9 @@ typename sb_handle_t::event_t _nrm2( * \brief Prototype for the internal implementation of the NRM2 operator. See * documentation in the blas1_interface.hpp file for details. */ -template +template typename sb_handle_t::event_t _nrm2_impl( sb_handle_t &sb_handle, index_t _N, container_0_t _vx, increment_t _incx, container_1_t _rs, const index_t number_WG, @@ -269,8 +269,8 @@ typename sb_handle_t::event_t _nrm2_impl( * \brief Prototype for the internal implementation of the Dot operator. See * documentation in the blas1_interface.hpp file for details. */ -template typename sb_handle_t::event_t _dot_impl( sb_handle_t &sb_handle, index_t _N, container_0_t _vx, increment_t _incx, diff --git a/src/interface/blas1/backend/amd_gpu.hpp b/src/interface/blas1/backend/amd_gpu.hpp index 741e7f730..7ec252995 100644 --- a/src/interface/blas1/backend/amd_gpu.hpp +++ b/src/interface/blas1/backend/amd_gpu.hpp @@ -130,16 +130,39 @@ template (localSize), 32>( - sb_handle, _N, _vx, _incx, _rs, number_WG, _dependencies); + /** + * Read comment in _asum above. + **/ + bool managed_mem{false}; + if constexpr (std::is_pointer_v) { + managed_mem = + sycl::usm::alloc::shared == + sycl::get_pointer_type(_rs, sb_handle.get_queue().get_context()); + } + if (managed_mem) { + if (_N < (1 << 18)) { + constexpr index_t localSize = 1024; + const index_t number_WG = (_N + localSize - 1) / localSize; + return blas::internal::_nrm2_impl(localSize), 32, true>( + sb_handle, _N, _vx, _incx, _rs, number_WG, _dependencies); + } else { + constexpr int localSize = 512; + constexpr index_t number_WG = 512; + return blas::internal::_nrm2_impl( + sb_handle, _N, _vx, _incx, _rs, number_WG, _dependencies); + } } else { - constexpr int localSize = 512; - constexpr index_t number_WG = 512; - return blas::internal::_nrm2_impl( - sb_handle, _N, _vx, _incx, _rs, number_WG, _dependencies); + if (_N < (1 << 18)) { + constexpr index_t localSize = 1024; + const index_t number_WG = (_N + localSize - 1) / localSize; + return blas::internal::_nrm2_impl(localSize), 32, false>( + sb_handle, _N, _vx, _incx, _rs, number_WG, _dependencies); + } else { + constexpr int localSize = 512; + constexpr index_t number_WG = 512; + return blas::internal::_nrm2_impl( + sb_handle, _N, _vx, _incx, _rs, number_WG, _dependencies); + } } } } // namespace backend @@ -153,16 +176,39 @@ typename sb_handle_t::event_t _dot( sb_handle_t& sb_handle, index_t _N, container_0_t _vx, increment_t _incx, container_1_t _vy, increment_t _incy, container_2_t _rs, const typename sb_handle_t::event_t& _dependencies) { - if (_N < (1 << 18)) { - constexpr index_t localSize = 1024; - const index_t number_WG = (_N + localSize - 1) / localSize; - return blas::internal::_dot_impl(localSize), 32>( - sb_handle, _N, _vx, _incx, _vy, _incy, _rs, number_WG, _dependencies); + /** + * Read comment in _asum above. + **/ + bool managed_mem{false}; + if constexpr (std::is_pointer_v) { + managed_mem = + sycl::usm::alloc::shared == + sycl::get_pointer_type(_rs, sb_handle.get_queue().get_context()); + } + if (managed_mem) { + if (_N < (1 << 18)) { + constexpr index_t localSize = 1024; + const index_t number_WG = (_N + localSize - 1) / localSize; + return blas::internal::_dot_impl(localSize), 32, true>( + sb_handle, _N, _vx, _incx, _vy, _incy, _rs, number_WG, _dependencies); + } else { + constexpr int localSize = 512; + constexpr index_t number_WG = 512; + return blas::internal::_dot_impl( + sb_handle, _N, _vx, _incx, _vy, _incy, _rs, number_WG, _dependencies); + } } else { - constexpr int localSize = 512; - constexpr index_t number_WG = 512; - return blas::internal::_dot_impl( - sb_handle, _N, _vx, _incx, _vy, _incy, _rs, number_WG, _dependencies); + if (_N < (1 << 18)) { + constexpr index_t localSize = 1024; + const index_t number_WG = (_N + localSize - 1) / localSize; + return blas::internal::_dot_impl(localSize), 32, false>( + sb_handle, _N, _vx, _incx, _vy, _incy, _rs, number_WG, _dependencies); + } else { + constexpr int localSize = 512; + constexpr index_t number_WG = 512; + return blas::internal::_dot_impl( + sb_handle, _N, _vx, _incx, _vy, _incy, _rs, number_WG, _dependencies); + } } } } // namespace backend diff --git a/src/interface/blas1_interface.hpp b/src/interface/blas1_interface.hpp index 7f6ee962e..bc458d42e 100644 --- a/src/interface/blas1_interface.hpp +++ b/src/interface/blas1_interface.hpp @@ -548,9 +548,9 @@ typename sb_handle_t::event_t _nrm2( * implementation use a kernel implementation which doesn't * require local memory. */ -template +template typename sb_handle_t::event_t _nrm2_impl( sb_handle_t &sb_handle, index_t _N, container_0_t _vx, increment_t _incx, container_1_t _rs, const index_t number_WG, @@ -561,7 +561,8 @@ typename sb_handle_t::event_t _nrm2_impl( static_cast(1)); auto prdOp = make_op(vx); - auto assignOp = make_wg_atomic_reduction(rs, prdOp); + auto assignOp = + make_wg_atomic_reduction(rs, prdOp); typename sb_handle_t::event_t ret0; if constexpr (localMemSize != 0) { ret0 = sb_handle.execute(assignOp, static_cast(localSize), @@ -596,8 +597,8 @@ typename sb_handle_t::event_t _nrm2_impl( * implementation use a kernel implementation which doesn't * require local memory. */ -template typename sb_handle_t::event_t _dot_impl( sb_handle_t &sb_handle, index_t _N, container_0_t _vx, increment_t _incx, @@ -613,7 +614,8 @@ typename sb_handle_t::event_t _dot_impl( static_cast(1)); auto prdOp = make_op(vx, vy); - auto wgReductionOp = make_wg_atomic_reduction(rs, prdOp); + auto wgReductionOp = + make_wg_atomic_reduction(rs, prdOp); if constexpr (localMemSize) { ret_event =