diff --git a/include/interface/extension_interface.h b/include/interface/extension_interface.h
index 94fc13d80..2e78d7935 100644
--- a/include/interface/extension_interface.h
+++ b/include/interface/extension_interface.h
@@ -154,6 +154,15 @@ typename sb_handle_t::event_t _axpy_batch(
     index_t _stride_y, index_t _batch_size,
     const typename sb_handle_t::event_t& _dependencies);
 
+template <int localSize, int maxBlockPerBatch, typename sb_handle_t,
+          typename container_0_t, typename element_t, typename container_1_t,
+          typename index_t>
+typename sb_handle_t::event_t _axpy_batch_impl(
+    sb_handle_t& sb_handle, index_t _N, element_t _alpha, container_0_t _vx,
+    index_t _incx, index_t _stride_x, container_1_t _vy, index_t _incy,
+    index_t _stride_y, index_t _batch_size,
+    const typename sb_handle_t::event_t& _dependencies, index_t global_size);
+
 }  // namespace internal
 
 /**
diff --git a/include/operations/extension/axpy_batch.h b/include/operations/extension/axpy_batch.h
index b11bee990..ed96432d4 100644
--- a/include/operations/extension/axpy_batch.h
+++ b/include/operations/extension/axpy_batch.h
@@ -27,7 +27,8 @@
 
 namespace blas {
 
-template <bool same_sign, typename lhs_t, typename rhs_t>
+template <bool sameSign, int localSize, int maxBlockPerBatch, typename lhs_t,
+          typename rhs_t>
 struct Axpy_batch {
   using value_t = typename lhs_t::value_t;
   using index_t = typename rhs_t::index_t;
@@ -50,15 +51,16 @@ struct Axpy_batch {
   void adjust_access_displacement();
 };
 
-template <bool same_sign, typename lhs_t, typename rhs_t>
-Axpy_batch<same_sign, lhs_t, rhs_t> make_axpy_batch(
+template <bool sameSign, int localSize, int maxBlockPerBatch, typename lhs_t,
+          typename rhs_t>
+Axpy_batch<sameSign, localSize, maxBlockPerBatch, lhs_t, rhs_t> make_axpy_batch(
     lhs_t _lhs, rhs_t _rhs_1, typename rhs_t::value_t _alpha,
     typename rhs_t::index_t _N, typename rhs_t::index_t _inc_l,
     typename rhs_t::index_t _lhs_stride, typename rhs_t::index_t _inc_r,
     typename rhs_t::index_t _rhs_stride, typename rhs_t::index_t _batch_size) {
-  return Axpy_batch<same_sign, lhs_t, rhs_t>(_lhs, _rhs_1, _alpha, _N, _inc_l,
-                                             _lhs_stride, _inc_r, _rhs_stride,
-                                             _batch_size);
+  return Axpy_batch<sameSign, localSize, maxBlockPerBatch, lhs_t, rhs_t>(
+      _lhs, _rhs_1, _alpha, _N, _inc_l, _lhs_stride, _inc_r, _rhs_stride,
+      _batch_size);
 }
 
 }  // namespace blas
diff --git a/src/interface/extension/backend/amd_gpu.hpp b/src/interface/extension/backend/amd_gpu.hpp
index 3ee9db746..63ae8725d 100644
--- a/src/interface/extension/backend/amd_gpu.hpp
+++ b/src/interface/extension/backend/amd_gpu.hpp
@@ -138,6 +138,31 @@ typename sb_handle_t::event_t _omatadd_batch(
   }
 }
 }  // namespace backend
 }  // namespace omatadd_batch
+
+namespace axpy_batch {
+namespace backend {
+template <typename sb_handle_t, typename container_0_t, typename element_t,
+          typename container_1_t, typename index_t>
+typename sb_handle_t::event_t _axpy_batch(
+    sb_handle_t& sb_handle, index_t _N, element_t _alpha, container_0_t _vx,
+    index_t _incx, index_t _stride_x, container_1_t _vy, index_t _incy,
+    index_t _stride_y, index_t _batch_size,
+    const typename sb_handle_t::event_t& _dependencies) {
+  // local_size chosen empirically
+  constexpr index_t local_size = static_cast<index_t>(256);
+  const auto nWG = (_N + local_size - 1) / local_size;
+  // the limit on _N * _batch_size was found empirically in tests on an AMD W6800
+  const index_t global_size =
+      (_N * _batch_size >= 327680)
+          ? (_N > (1 << 19)) ? (local_size * nWG) / 4 : local_size * nWG
+          : local_size * nWG * _batch_size;
+  return blas::internal::_axpy_batch_impl<256, 32>(
+      sb_handle, _N, _alpha, _vx, _incx, _stride_x, _vy, _incy, _stride_y,
+      _batch_size, _dependencies, global_size);
+}
+}  // namespace backend
+}  // namespace axpy_batch
+
 }  // namespace blas
 
 #endif
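Note: every backend picks global_size with the same three-way heuristic before delegating to _axpy_batch_impl<256, 32>. A standalone sketch of that selection, for reference; the helper name pick_global_size and the main driver are illustrative and not part of the patch, while the 327680 and 1 << 19 thresholds and local_size = 256 are the values used above:

#include <cstdint>
#include <iostream>

int64_t pick_global_size(int64_t n, int64_t batch_size) {
  constexpr int64_t local_size = 256;
  const int64_t nWG = (n + local_size - 1) / local_size;  // work-groups per batch
  if (n * batch_size < 327680)      // small problems: one nd-range per batch
    return local_size * nWG * batch_size;
  if (n > (1 << 19))                // very long vectors: quarter the range and
    return (local_size * nWG) / 4;  // let each work-item loop over more blocks
  return local_size * nWG;          // one batch-sized range; groups stride batches
}

int main() {
  // n = 1 << 20, batch = 4: nWG = 4096, so (256 * 4096) / 4 = 262144 work-items
  std::cout << pick_global_size(1 << 20, 4) << '\n';
}

When this launches fewer work-items than n * batch_size, the kernel's grid-stride loops (see axpy_batch.hpp below) cover the remainder.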
diff --git a/src/interface/extension/backend/default_cpu.hpp b/src/interface/extension/backend/default_cpu.hpp
index b168bb6c5..3e11f4587 100644
--- a/src/interface/extension/backend/default_cpu.hpp
+++ b/src/interface/extension/backend/default_cpu.hpp
@@ -111,6 +111,30 @@ typename sb_handle_t::event_t _omatadd_batch(
   }
 }
 }  // namespace backend
 }  // namespace omatadd_batch
+
+namespace axpy_batch {
+namespace backend {
+template <typename sb_handle_t, typename container_0_t, typename element_t,
+          typename container_1_t, typename index_t>
+typename sb_handle_t::event_t _axpy_batch(
+    sb_handle_t& sb_handle, index_t _N, element_t _alpha, container_0_t _vx,
+    index_t _incx, index_t _stride_x, container_1_t _vy, index_t _incy,
+    index_t _stride_y, index_t _batch_size,
+    const typename sb_handle_t::event_t& _dependencies) {
+  // local_size chosen empirically
+  constexpr index_t local_size = static_cast<index_t>(256);
+  const auto nWG = (_N + local_size - 1) / local_size;
+  // the limit on _N * _batch_size was found empirically in tests on an AMD W6800
+  const index_t global_size =
+      (_N * _batch_size >= 327680)
+          ? (_N > (1 << 19)) ? (local_size * nWG) / 4 : local_size * nWG
+          : local_size * nWG * _batch_size;
+  return blas::internal::_axpy_batch_impl<256, 32>(
+      sb_handle, _N, _alpha, _vx, _incx, _stride_x, _vy, _incy, _stride_y,
+      _batch_size, _dependencies, global_size);
+}
+}  // namespace backend
+}  // namespace axpy_batch
 }  // namespace blas
 
 #endif
diff --git a/src/interface/extension/backend/intel_gpu.hpp b/src/interface/extension/backend/intel_gpu.hpp
index 9e2566aa7..b2d9a7b8e 100644
--- a/src/interface/extension/backend/intel_gpu.hpp
+++ b/src/interface/extension/backend/intel_gpu.hpp
@@ -132,6 +132,30 @@ typename sb_handle_t::event_t _omatadd_batch(
   }
 }
 }  // namespace backend
 }  // namespace omatadd_batch
+
+namespace axpy_batch {
+namespace backend {
+template <typename sb_handle_t, typename container_0_t, typename element_t,
+          typename container_1_t, typename index_t>
+typename sb_handle_t::event_t _axpy_batch(
+    sb_handle_t& sb_handle, index_t _N, element_t _alpha, container_0_t _vx,
+    index_t _incx, index_t _stride_x, container_1_t _vy, index_t _incy,
+    index_t _stride_y, index_t _batch_size,
+    const typename sb_handle_t::event_t& _dependencies) {
+  // local_size chosen empirically
+  constexpr index_t local_size = static_cast<index_t>(256);
+  const auto nWG = (_N + local_size - 1) / local_size;
+  // the limit on _N * _batch_size was found empirically in tests on an Intel GPU
+  const index_t global_size =
+      (_N * _batch_size >= 327680)
+          ? (_N > (1 << 19)) ? (local_size * nWG) / 4 : local_size * nWG
+          : local_size * nWG * _batch_size;
+  return blas::internal::_axpy_batch_impl<256, 32>(
+      sb_handle, _N, _alpha, _vx, _incx, _stride_x, _vy, _incy, _stride_y,
+      _batch_size, _dependencies, global_size);
+}
+}  // namespace backend
+}  // namespace axpy_batch
 }  // namespace blas
 
 #endif
diff --git a/src/interface/extension/backend/nvidia_gpu.hpp b/src/interface/extension/backend/nvidia_gpu.hpp
index e3aac7028..2cc56941d 100644
--- a/src/interface/extension/backend/nvidia_gpu.hpp
+++ b/src/interface/extension/backend/nvidia_gpu.hpp
@@ -138,6 +138,30 @@ typename sb_handle_t::event_t _omatadd_batch(
   }
 }
 }  // namespace backend
 }  // namespace omatadd_batch
+
+namespace axpy_batch {
+namespace backend {
+template <typename sb_handle_t, typename container_0_t, typename element_t,
+          typename container_1_t, typename index_t>
+typename sb_handle_t::event_t _axpy_batch(
+    sb_handle_t& sb_handle, index_t _N, element_t _alpha, container_0_t _vx,
+    index_t _incx, index_t _stride_x, container_1_t _vy, index_t _incy,
+    index_t _stride_y, index_t _batch_size,
+    const typename sb_handle_t::event_t& _dependencies) {
+  // local_size chosen empirically
+  constexpr index_t local_size = static_cast<index_t>(256);
+  const auto nWG = (_N + local_size - 1) / local_size;
+  // the limit on _N * _batch_size was found empirically in tests on an AMD W6800
+  const index_t global_size =
+      (_N * _batch_size >= 327680)
+          ? (_N > (1 << 19)) ? (local_size * nWG) / 4 : local_size * nWG
+          : local_size * nWG * _batch_size;
+  return blas::internal::_axpy_batch_impl<256, 32>(
+      sb_handle, _N, _alpha, _vx, _incx, _stride_x, _vy, _incy, _stride_y,
+      _batch_size, _dependencies, global_size);
+}
+}  // namespace backend
+}  // namespace axpy_batch
 }  // namespace blas
 
 #endif
diff --git a/src/interface/extension_interface.hpp b/src/interface/extension_interface.hpp
index c7906973f..7ef7859f6 100644
--- a/src/interface/extension_interface.hpp
+++ b/src/interface/extension_interface.hpp
@@ -597,7 +597,6 @@ typename sb_handle_t::event_t _reduction(
         sb_handle, buffer_in, ld, buffer_out, rows, cols, dependencies);
   }
 }
-
 template <typename sb_handle_t, typename container_0_t, typename element_t,
           typename container_1_t, typename index_t>
 typename sb_handle_t::event_t _axpy_batch(
@@ -605,6 +604,19 @@ typename sb_handle_t::event_t _axpy_batch(
     index_t _incx, index_t _stride_x, container_1_t _vy, index_t _incy,
     index_t _stride_y, index_t _batch_size,
     const typename sb_handle_t::event_t& _dependencies) {
+  return blas::axpy_batch::backend::_axpy_batch(
+      sb_handle, _N, _alpha, _vx, _incx, _stride_x, _vy, _incy, _stride_y,
+      _batch_size, _dependencies);
+}
+
+template <int localSize, int maxBlockPerBatch, typename sb_handle_t,
+          typename container_0_t, typename element_t, typename container_1_t,
+          typename index_t>
+typename sb_handle_t::event_t _axpy_batch_impl(
+    sb_handle_t& sb_handle, index_t _N, element_t _alpha, container_0_t _vx,
+    index_t _incx, index_t _stride_x, container_1_t _vy, index_t _incy,
+    index_t _stride_y, index_t _batch_size,
+    const typename sb_handle_t::event_t& _dependencies, index_t global_size) {
   // If the increments are of opposite sign the values are exchanged. It
   // doesn't matter which one is positive or negative, so to simplify the index
   // computation in the kernel we always make incx negative and incy positive.
@@ -616,23 +628,20 @@ typename sb_handle_t::event_t _axpy_batch(
       make_vector_view(_vx, static_cast<index_t>(_incx),
                        static_cast<index_t>(_N * _batch_size));
   auto vy = make_vector_view(_vy, _incy, _N * _batch_size);
-  const auto local_size = sb_handle.get_work_group_size();
-  const auto nWG = (_N + local_size - 1) / local_size;
-  const auto global_size = local_size * nWG * _batch_size;
   // If both vectors are read from the same side, the sign of the increment
   // does not matter
   if (_incx * _incy > 0) {
-    auto op =
-        make_axpy_batch<true>(vy, vx, _alpha, _N, std::abs(_incy), _stride_y,
-                              std::abs(_incx), _stride_x, _batch_size);
-    typename sb_handle_t::event_t ret =
-        sb_handle.execute(op, local_size, global_size, _dependencies);
+    auto op = make_axpy_batch<true, localSize, maxBlockPerBatch>(
+        vy, vx, _alpha, _N, std::abs(_incy), _stride_y, std::abs(_incx),
+        _stride_x, _batch_size);
+    typename sb_handle_t::event_t ret = sb_handle.execute(
+        op, static_cast<index_t>(localSize), global_size, _dependencies);
     return ret;
   } else {
-    auto op = make_axpy_batch<false>(vy, vx, _alpha, _N, _incy, _stride_y,
-                                     _incx, _stride_x, _batch_size);
-    typename sb_handle_t::event_t ret =
-        sb_handle.execute(op, local_size, global_size, _dependencies);
+    auto op = make_axpy_batch<false, localSize, maxBlockPerBatch>(
+        vy, vx, _alpha, _N, _incy, _stride_y, _incx, _stride_x, _batch_size);
+    typename sb_handle_t::event_t ret = sb_handle.execute(
+        op, static_cast<index_t>(localSize), global_size, _dependencies);
     return ret;
   }
 }
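Note: _axpy_batch_impl normalizes opposite-sign increments so that incx is negative and incy positive, and the kernel's opposite-sign path then starts inc_r + n * |inc_r| past the batch base and walks backwards. A scalar model of that index math, assuming a single batch; the function name axpy_opposite_sign is illustrative and not part of the patch:

#include <cstdlib>
#include <vector>

// Element l_id of y pairs with element (n - 1 - l_id) * |inc_x| of x,
// which matches BLAS semantics for increments of opposite sign.
void axpy_opposite_sign(int n, float alpha, const std::vector<float>& x,
                        int inc_x /* < 0 */, std::vector<float>& y,
                        int inc_y /* > 0 */) {
  for (int l_id = 0; l_id < n; ++l_id) {
    // same formula as the kernel: inc_r + n * |inc_r| + l_id * inc_r
    const int x_index = inc_x + n * std::abs(inc_x) + l_id * inc_x;
    y[l_id * inc_y] += alpha * x[x_index];
  }
}

For inc_x = -1 this reads x[n - 1], x[n - 2], ..., x[0] while y is written front to back.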
diff --git a/src/operations/extension/axpy_batch.hpp b/src/operations/extension/axpy_batch.hpp
index 2012137f9..ec48e09d0 100644
--- a/src/operations/extension/axpy_batch.hpp
+++ b/src/operations/extension/axpy_batch.hpp
@@ -30,8 +30,9 @@
 
 namespace blas {
 
-template <bool same_sign, typename lhs_t, typename rhs_t>
-Axpy_batch<same_sign, lhs_t, rhs_t>::Axpy_batch(
+template <bool sameSign, int localSize, int maxBlockPerBatch, typename lhs_t,
+          typename rhs_t>
+Axpy_batch<sameSign, localSize, maxBlockPerBatch, lhs_t, rhs_t>::Axpy_batch(
     lhs_t _lhs, rhs_t _rhs, typename lhs_t::value_t _alpha,
     typename rhs_t::index_t _N, typename rhs_t::index_t _inc_l,
     typename rhs_t::index_t _lhs_stride, typename rhs_t::index_t _inc_r,
@@ -46,69 +47,107 @@ Axpy_batch<same_sign, lhs_t, rhs_t>::Axpy_batch(
       rhs_stride_(_rhs_stride),
       batch_size_(_batch_size){};
 
-template <bool same_sign, typename lhs_t, typename rhs_t>
+template <bool sameSign, int localSize, int maxBlockPerBatch, typename lhs_t,
+          typename rhs_t>
 PORTBLAS_INLINE typename lhs_t::value_t
-Axpy_batch<same_sign, lhs_t, rhs_t>::eval(index_t i) {}
+Axpy_batch<sameSign, localSize, maxBlockPerBatch, lhs_t, rhs_t>::eval(
+    index_t i) {}
 
-template <bool same_sign, typename lhs_t, typename rhs_t>
+template <bool sameSign, int localSize, int maxBlockPerBatch, typename lhs_t,
+          typename rhs_t>
 PORTBLAS_INLINE typename lhs_t::value_t
-Axpy_batch<same_sign, lhs_t, rhs_t>::eval(cl::sycl::nd_item<1> ndItem) {
+Axpy_batch<sameSign, localSize, maxBlockPerBatch, lhs_t, rhs_t>::eval(
+    cl::sycl::nd_item<1> ndItem) {
   const index_t n{n_};
   const value_t alpha{alpha_};
-
-  const index_t l_id = static_cast<index_t>(ndItem.get_global_linear_id() % n);
-  const index_t group_id =
-      static_cast<index_t>(ndItem.get_global_linear_id() / n);
-
-  if (group_id >= batch_size_) return {};
-
-  if constexpr (same_sign) {
-    const index_t x_index = group_id * rhs_stride_ + l_id * inc_r;
-    const index_t y_index = group_id * lhs_stride_ + l_id * inc_l;
-
-    const value_t ax = alpha * rhs_.get_data()[x_index];
-    lhs_.get_data()[y_index] += ax;
+  const auto vx = rhs_.get_data();
+  const auto vy = lhs_.get_data();
+  const auto nbl = sycl::min((n + localSize - 1) / localSize,
+                             static_cast<index_t>(maxBlockPerBatch));
+
+  const index_t block_id = ndItem.get_group(0) % nbl;
+  const index_t l_id =
+      static_cast<index_t>(ndItem.get_local_range(0)) * block_id +
+      ndItem.get_local_id(0);
+  // const index_t group_id =
+  //     static_cast<index_t>(ndItem.get_global_linear_id() / n);
+  const index_t group_id = static_cast<index_t>(ndItem.get_group(0) / nbl);
+
+  const index_t size_compute_ratio =
+      (n > nbl * localSize) ? n / (nbl * localSize) : batch_size_;
+  const index_t jump_value{sycl::min(batch_size_, size_compute_ratio)};
+
+  if (group_id >= jump_value || l_id > n) return {};
+
+  const index_t stride_x = ndItem.get_local_range(0) * nbl * inc_r;
+  const index_t stride_y = ndItem.get_local_range(0) * nbl * inc_l;
+  index_t x_index{};
+  index_t y_index{};
+  int j{0};
+
+  if constexpr (sameSign) {
+    for (auto out_loop = group_id; out_loop < batch_size_;
+         out_loop += jump_value) {
+      x_index = out_loop * rhs_stride_ + l_id * inc_r;
+      y_index = out_loop * lhs_stride_ + l_id * inc_l;
+      j = y_index;
+      for (auto i = x_index; i < (out_loop * rhs_stride_) + n * inc_r;
+           i += stride_x, j += stride_y) {
+        vy[j] += alpha * vx[i];
+      }
+    }
   } else {
-    const index_t x_index =
-        group_id * rhs_stride_ + inc_r + n * sycl::abs(inc_r) + l_id * inc_r;
-    const index_t y_index = group_id * lhs_stride_ + l_id * inc_l;
-
-    const value_t ax = alpha * rhs_.get_data()[x_index];
-    lhs_.get_data()[y_index] += ax;
+    for (auto out_loop = group_id; out_loop < batch_size_;
+         out_loop += jump_value) {
+      x_index =
+          out_loop * rhs_stride_ + inc_r + n * sycl::abs(inc_r) + l_id * inc_r;
+      y_index = out_loop * lhs_stride_ + l_id * inc_l;
+      j = y_index;
+      for (auto i = x_index; i >= (out_loop * rhs_stride_);
+           i += stride_x, j += stride_y) {
+        vy[j] += alpha * vx[i];
+      }
+    }
   }
   return {};
 }
 
-template <bool same_sign, typename lhs_t, typename rhs_t>
+template <bool sameSign, int localSize, int maxBlockPerBatch, typename lhs_t,
+          typename rhs_t>
 template <typename sharedT>
 PORTBLAS_INLINE typename lhs_t::value_t
-Axpy_batch<same_sign, lhs_t, rhs_t>::eval(sharedT shMem,
-                                          sycl::nd_item<1> ndItem){};
+Axpy_batch<sameSign, localSize, maxBlockPerBatch, lhs_t, rhs_t>::eval(
+    sharedT shMem, sycl::nd_item<1> ndItem){};
 
-template <bool same_sign, typename lhs_t, typename rhs_t>
-PORTBLAS_INLINE void Axpy_batch<same_sign, lhs_t, rhs_t>::bind(
-    cl::sycl::handler& h) {
+template <bool sameSign, int localSize, int maxBlockPerBatch, typename lhs_t,
+          typename rhs_t>
+PORTBLAS_INLINE void Axpy_batch<sameSign, localSize, maxBlockPerBatch, lhs_t,
+                                rhs_t>::bind(cl::sycl::handler& h) {
   lhs_.bind(h);
   rhs_.bind(h);
 }
 
-template <bool same_sign, typename lhs_t, typename rhs_t>
-PORTBLAS_INLINE void
-Axpy_batch<same_sign, lhs_t, rhs_t>::adjust_access_displacement() {
+template <bool sameSign, int localSize, int maxBlockPerBatch, typename lhs_t,
+          typename rhs_t>
+PORTBLAS_INLINE void Axpy_batch<sameSign, localSize, maxBlockPerBatch, lhs_t,
+                                rhs_t>::adjust_access_displacement() {
   lhs_.adjust_access_displacement();
   rhs_.adjust_access_displacement();
 }
 
-template <bool same_sign, typename lhs_t, typename rhs_t>
-PORTBLAS_INLINE typename rhs_t::index_t
-Axpy_batch<same_sign, lhs_t, rhs_t>::get_size() const {
+template <bool sameSign, int localSize, int maxBlockPerBatch, typename lhs_t,
+          typename rhs_t>
+PORTBLAS_INLINE typename rhs_t::index_t Axpy_batch<
+    sameSign, localSize, maxBlockPerBatch, lhs_t, rhs_t>::get_size() const {
   return n_ * batch_size_;
 }
 
-template <bool same_sign, typename lhs_t, typename rhs_t>
-PORTBLAS_INLINE bool Axpy_batch<same_sign, lhs_t, rhs_t>::valid_thread(
+template <bool sameSign, int localSize, int maxBlockPerBatch, typename lhs_t,
+          typename rhs_t>
+PORTBLAS_INLINE bool
+Axpy_batch<sameSign, localSize, maxBlockPerBatch, lhs_t, rhs_t>::valid_thread(
     cl::sycl::nd_item<1> ndItem) const {
   return true;
 }
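Note: the reworked eval caps the work-groups cooperating on one batch at maxBlockPerBatch (nbl) and then grid-strides twice: the inner loop walks a batch's n elements in steps of local_range * nbl, and the outer loop jumps across batches in steps of jump_value, so whichever global_size the backends choose still covers the whole problem. A host-side model of the work a single thread performs on the sameSign path, assuming unit increments for clarity; names mirror the kernel but the function itself is illustrative:

// chunk = local_range * nbl work-items cooperate on each batch;
// group_id in [0, jump_value) selects the first batch this thread touches.
void one_thread_work(int n, float alpha, const float* x, int x_stride,
                     float* y, int y_stride, int batch_size, int l_id,
                     int group_id, int jump_value, int chunk) {
  for (int b = group_id; b < batch_size; b += jump_value)  // outer: batches
    for (int i = l_id; i < n; i += chunk)                  // inner: elements
      y[b * y_stride + i] += alpha * x[b * x_stride + i];
}

Running this for every (l_id, group_id) pair with l_id in [0, chunk) and group_id in [0, jump_value) touches each element of each batch exactly once, which is why eval can early-return for group_id >= jump_value.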