Skip to content

Commit

Permalink
Add usm memory fix to other reduction operators
Browse files Browse the repository at this point in the history
Signed-off-by: nscipione <nicolo.scipione@codeplay.com>
  • Loading branch information
s-Nick committed Mar 21, 2024
1 parent a532beb commit 4531e3d
Showing 3 changed files with 78 additions and 30 deletions.
10 changes: 5 additions & 5 deletions include/interface/blas1_interface.h
Original file line number Diff line number Diff line change
@@ -257,9 +257,9 @@ typename sb_handle_t::event_t _nrm2(
* \brief Prototype for the internal implementation of the NRM2 operator. See
* documentation in the blas1_interface.hpp file for details.
*/
template <int localSize, int localMemSize, typename sb_handle_t,
typename container_0_t, typename container_1_t, typename index_t,
typename increment_t>
template <int localSize, int localMemSize, bool usmManagedMem = false,
typename sb_handle_t, typename container_0_t, typename container_1_t,
typename index_t, typename increment_t>
typename sb_handle_t::event_t _nrm2_impl(
sb_handle_t &sb_handle, index_t _N, container_0_t _vx, increment_t _incx,
container_1_t _rs, const index_t number_WG,
@@ -269,8 +269,8 @@ typename sb_handle_t::event_t _nrm2_impl(
* \brief Prototype for the internal implementation of the Dot operator. See
* documentation in the blas1_interface.hpp file for details.
*/
template <int localSize, int localMemSize, typename sb_handle_t,
typename container_0_t, typename container_1_t,
template <int localSize, int localMemSize, bool usmManagedMem = false,
typename sb_handle_t, typename container_0_t, typename container_1_t,
typename container_2_t, typename index_t, typename increment_t>
typename sb_handle_t::event_t _dot_impl(
sb_handle_t &sb_handle, index_t _N, container_0_t _vx, increment_t _incx,
82 changes: 64 additions & 18 deletions src/interface/blas1/backend/amd_gpu.hpp
Original file line number Diff line number Diff line change
@@ -130,16 +130,39 @@ template <typename sb_handle_t, typename container_0_t, typename container_1_t,
typename sb_handle_t::event_t _nrm2(
sb_handle_t& sb_handle, index_t _N, container_0_t _vx, increment_t _incx,
container_1_t _rs, const typename sb_handle_t::event_t& _dependencies) {
if (_N < (1 << 18)) {
constexpr index_t localSize = 1024;
const index_t number_WG = (_N + localSize - 1) / localSize;
return blas::internal::_nrm2_impl<static_cast<int>(localSize), 32>(
sb_handle, _N, _vx, _incx, _rs, number_WG, _dependencies);
/**
* Read comment in _asum above.
**/
bool managed_mem{false};
if constexpr (std::is_pointer_v<decltype(_rs)>) {
managed_mem =
sycl::usm::alloc::shared ==
sycl::get_pointer_type(_rs, sb_handle.get_queue().get_context());
}
if (managed_mem) {
if (_N < (1 << 18)) {
constexpr index_t localSize = 1024;
const index_t number_WG = (_N + localSize - 1) / localSize;
return blas::internal::_nrm2_impl<static_cast<int>(localSize), 32, true>(
sb_handle, _N, _vx, _incx, _rs, number_WG, _dependencies);
} else {
constexpr int localSize = 512;
constexpr index_t number_WG = 512;
return blas::internal::_nrm2_impl<localSize, 32, true>(
sb_handle, _N, _vx, _incx, _rs, number_WG, _dependencies);
}
} else {
constexpr int localSize = 512;
constexpr index_t number_WG = 512;
return blas::internal::_nrm2_impl<localSize, 32>(
sb_handle, _N, _vx, _incx, _rs, number_WG, _dependencies);
if (_N < (1 << 18)) {
constexpr index_t localSize = 1024;
const index_t number_WG = (_N + localSize - 1) / localSize;
return blas::internal::_nrm2_impl<static_cast<int>(localSize), 32, false>(
sb_handle, _N, _vx, _incx, _rs, number_WG, _dependencies);
} else {
constexpr int localSize = 512;
constexpr index_t number_WG = 512;
return blas::internal::_nrm2_impl<localSize, 32, false>(
sb_handle, _N, _vx, _incx, _rs, number_WG, _dependencies);
}
}
}
} // namespace backend
@@ -153,16 +176,39 @@ typename sb_handle_t::event_t _dot(
sb_handle_t& sb_handle, index_t _N, container_0_t _vx, increment_t _incx,
container_1_t _vy, increment_t _incy, container_2_t _rs,
const typename sb_handle_t::event_t& _dependencies) {
if (_N < (1 << 18)) {
constexpr index_t localSize = 1024;
const index_t number_WG = (_N + localSize - 1) / localSize;
return blas::internal::_dot_impl<static_cast<int>(localSize), 32>(
sb_handle, _N, _vx, _incx, _vy, _incy, _rs, number_WG, _dependencies);
/**
* Read comment in _asum above.
**/
bool managed_mem{false};
if constexpr (std::is_pointer_v<decltype(_rs)>) {
managed_mem =
sycl::usm::alloc::shared ==
sycl::get_pointer_type(_rs, sb_handle.get_queue().get_context());
}
if (managed_mem) {
if (_N < (1 << 18)) {
constexpr index_t localSize = 1024;
const index_t number_WG = (_N + localSize - 1) / localSize;
return blas::internal::_dot_impl<static_cast<int>(localSize), 32, true>(
sb_handle, _N, _vx, _incx, _vy, _incy, _rs, number_WG, _dependencies);
} else {
constexpr int localSize = 512;
constexpr index_t number_WG = 512;
return blas::internal::_dot_impl<localSize, 32, true>(
sb_handle, _N, _vx, _incx, _vy, _incy, _rs, number_WG, _dependencies);
}
} else {
constexpr int localSize = 512;
constexpr index_t number_WG = 512;
return blas::internal::_dot_impl<localSize, 32>(
sb_handle, _N, _vx, _incx, _vy, _incy, _rs, number_WG, _dependencies);
if (_N < (1 << 18)) {
constexpr index_t localSize = 1024;
const index_t number_WG = (_N + localSize - 1) / localSize;
return blas::internal::_dot_impl<static_cast<int>(localSize), 32, false>(
sb_handle, _N, _vx, _incx, _vy, _incy, _rs, number_WG, _dependencies);
} else {
constexpr int localSize = 512;
constexpr index_t number_WG = 512;
return blas::internal::_dot_impl<localSize, 32, false>(
sb_handle, _N, _vx, _incx, _vy, _incy, _rs, number_WG, _dependencies);
}
}
}
} // namespace backend
16 changes: 9 additions & 7 deletions src/interface/blas1_interface.hpp
Original file line number Diff line number Diff line change
@@ -548,9 +548,9 @@ typename sb_handle_t::event_t _nrm2(
* implementation use a kernel implementation which doesn't
* require local memory.
*/
template <int localSize, int localMemSize, typename sb_handle_t,
typename container_0_t, typename container_1_t, typename index_t,
typename increment_t>
template <int localSize, int localMemSize, bool usmManagedMem,
typename sb_handle_t, typename container_0_t, typename container_1_t,
typename index_t, typename increment_t>
typename sb_handle_t::event_t _nrm2_impl(
sb_handle_t &sb_handle, index_t _N, container_0_t _vx, increment_t _incx,
container_1_t _rs, const index_t number_WG,
@@ -561,7 +561,8 @@ typename sb_handle_t::event_t _nrm2_impl(
static_cast<index_t>(1));
auto prdOp = make_op<UnaryOp, SquareOperator>(vx);

auto assignOp = make_wg_atomic_reduction<AddOperator>(rs, prdOp);
auto assignOp =
make_wg_atomic_reduction<AddOperator, usmManagedMem>(rs, prdOp);
typename sb_handle_t::event_t ret0;
if constexpr (localMemSize != 0) {
ret0 = sb_handle.execute(assignOp, static_cast<index_t>(localSize),
@@ -596,8 +597,8 @@ typename sb_handle_t::event_t _nrm2_impl(
* implementation use a kernel implementation which doesn't
* require local memory.
*/
template <int localSize, int localMemSize, typename sb_handle_t,
typename container_0_t, typename container_1_t,
template <int localSize, int localMemSize, bool usmManagedMem,
typename sb_handle_t, typename container_0_t, typename container_1_t,
typename container_2_t, typename index_t, typename increment_t>
typename sb_handle_t::event_t _dot_impl(
sb_handle_t &sb_handle, index_t _N, container_0_t _vx, increment_t _incx,
@@ -613,7 +614,8 @@ typename sb_handle_t::event_t _dot_impl(
static_cast<index_t>(1));

auto prdOp = make_op<BinaryOpConst, ProductOperator>(vx, vy);
auto wgReductionOp = make_wg_atomic_reduction<AddOperator>(rs, prdOp);
auto wgReductionOp =
make_wg_atomic_reduction<AddOperator, usmManagedMem>(rs, prdOp);

if constexpr (localMemSize) {
ret_event =

0 comments on commit 4531e3d

Please sign in to comment.