diff --git a/include/operations/extension/axpy_batch.h b/include/operations/extension/axpy_batch.h
index ed96432d4..d25b0536f 100644
--- a/include/operations/extension/axpy_batch.h
+++ b/include/operations/extension/axpy_batch.h
@@ -27,6 +27,30 @@ namespace blas {
 
+/*!
+ * This class holds the kernel implementation of the axpy_batch operator.
+ *
+ * It has three additional template parameters that keep the operation simple
+ * and avoid extra computation or code divergence inside the kernel code.
+ *
+ * If sameSign is false the kernel always assumes that inc_r is negative. This
+ * is true by construction: when the increments have different signs the
+ * result positions are swapped and the indexes must be computed accordingly.
+ * Always keeping inc_r negative and inc_l positive keeps the index
+ * computation consistent and yields the correct result.
+ *
+ * sameSign indicates whether inc_r and inc_l have the same sign. The two
+ * cases require different index computations; making this a template
+ * parameter resolves the condition at compile time and avoids code
+ * divergence.
+ *
+ * localSize is the work-group local size, which allows some device tailoring
+ * at compile time.
+ *
+ * maxBlockPerBatch sets the number of work-groups used for each batch. When
+ * possible, multiple batches are computed concurrently.
+ */
+
 template <bool sameSign, int localSize, int maxBlockPerBatch, typename lhs_t,
           typename rhs_t>
 struct Axpy_batch {
@@ -36,7 +60,8 @@ struct Axpy_batch {
   lhs_t lhs_;
   rhs_t rhs_;
   value_t alpha_;
-  index_t n_, inc_r, inc_l, lhs_stride_, rhs_stride_, batch_size_;
+  index_t n_, inc_r, inc_l, lhs_stride_, rhs_stride_, batch_size_,
+      n_block_per_loop;
 
   Axpy_batch(lhs_t _lhs, rhs_t _rhs_1, value_t _alpha, index_t _N,
              index_t _inc_l, index_t _lhs_stride, index_t _inc_r,
diff --git a/src/interface/extension/backend/nvidia_gpu.hpp b/src/interface/extension/backend/nvidia_gpu.hpp
index d378abccf..b21228f5d 100644
--- a/src/interface/extension/backend/nvidia_gpu.hpp
+++ b/src/interface/extension/backend/nvidia_gpu.hpp
@@ -152,7 +152,7 @@ typename sb_handle_t::event_t _axpy_batch(
   constexpr index_t local_size = static_cast<index_t>(256);
   const auto nWG = (_N + local_size - 1) / local_size;
   // the limit for _N*batch_size is taken empirically from test on A100
-  if (_N * _batch_size <= 81920) {
+  if (_N * _batch_size <= 81920 || _N <= 16384) {
     const index_t global_size = local_size * nWG * _batch_size;
     return blas::internal::_axpy_batch_impl<256, 32>(
         sb_handle, _N, _alpha, _vx, _incx, _stride_x, _vy, _incy, _stride_y,
diff --git a/src/interface/extension_interface.hpp b/src/interface/extension_interface.hpp
index 7ef7859f6..9613a4aeb 100644
--- a/src/interface/extension_interface.hpp
+++ b/src/interface/extension_interface.hpp
@@ -624,10 +624,11 @@ typename sb_handle_t::event_t _axpy_batch_impl(
     _incx = -_incx;
     _incy = -_incy;
   }
+  // if _stride_x is zero, use _N as the vx size
+  const index_t overall_vx_size = (_stride_x) ? _stride_x * _batch_size : _N;
   typename VectorViewType<container_0_t, index_t, increment_t>::type vx =
-      make_vector_view(_vx, static_cast<increment_t>(_incx),
-                       static_cast<index_t>(_N * _batch_size));
-  auto vy = make_vector_view(_vy, _incy, _N * _batch_size);
+      make_vector_view(_vx, static_cast<increment_t>(_incx), overall_vx_size);
+  auto vy = make_vector_view(_vy, _incy, _stride_y * _batch_size);
   // If both vectors are read from the same side it doesn't matter the sign of
   // the increment
   if (_incx * _incy > 0) {
diff --git a/src/operations/extension/axpy_batch.hpp b/src/operations/extension/axpy_batch.hpp
index ec48e09d0..bb797b34a 100644
--- a/src/operations/extension/axpy_batch.hpp
+++ b/src/operations/extension/axpy_batch.hpp
@@ -45,7 +45,9 @@ Axpy_batch<sameSign, localSize, maxBlockPerBatch, lhs_t, rhs_t>::Axpy_batch(
       lhs_stride_(_lhs_stride),
       inc_r(_inc_r),
       rhs_stride_(_rhs_stride),
-      batch_size_(_batch_size){};
+      batch_size_(_batch_size),
+      n_block_per_loop(std::min((n_ + localSize - 1) / localSize,
+                                static_cast<index_t>(maxBlockPerBatch))){};
 
 template <bool sameSign, int localSize, int maxBlockPerBatch, typename lhs_t,
           typename rhs_t>
@@ -62,15 +64,12 @@ Axpy_batch<sameSign, localSize, maxBlockPerBatch, lhs_t, rhs_t>::eval(
   const value_t alpha{alpha_};
   const auto vx = rhs_.get_data();
   const auto vy = lhs_.get_data();
-  const auto nbl = sycl::min((n + localSize - 1) / localSize,
-                             static_cast<index_t>(maxBlockPerBatch));
+  const auto nbl{n_block_per_loop};
 
   const index_t block_id = ndItem.get_group(0) % nbl;
   const index_t l_id =
       static_cast<index_t>(ndItem.get_local_range(0)) * block_id +
       ndItem.get_local_id(0);
 
-  // const index_t group_id =
-  //     static_cast<index_t>(ndItem.get_global_linear_id() / n);
   const index_t group_id = static_cast<index_t>(ndItem.get_group(0) / nbl);
   const index_t size_compute_rateo =
@@ -83,7 +82,7 @@
   const index_t stride_y = ndItem.get_local_range(0) * nbl * inc_l;
   index_t x_index{};
   index_t y_index{};
-  int j{0};
+  int j{};
 
   if constexpr (sameSign) {
     for (auto out_loop = group_id; out_loop < batch_size_;
@@ -100,8 +99,7 @@
   } else {
     for (auto out_loop = group_id; out_loop < batch_size_;
          out_loop += jump_value) {
-      x_index =
-          out_loop * rhs_stride_ + inc_r + n * sycl::abs(inc_r) + l_id * inc_r;
+      x_index = out_loop * rhs_stride_ + inc_r + n * (-inc_r) + l_id * inc_r;
       y_index = out_loop * lhs_stride_ + l_id * inc_l;
       j = y_index;
       for (auto i = x_index; i >= (out_loop * rhs_stride_);
diff --git a/test/unittest/extension/axpy_batch_test.cpp b/test/unittest/extension/axpy_batch_test.cpp
index de2668c8e..7c7bce441 100644
--- a/test/unittest/extension/axpy_batch_test.cpp
+++ b/test/unittest/extension/axpy_batch_test.cpp
@@ -45,7 +45,7 @@ void run_test(const combination_t<scalar_t> combi) {
   const index_t stride_x{size * std::abs(incX) * stride_mul_x};
   const index_t stride_y{size * std::abs(incY) * stride_mul_y};
 
-  auto x_size = stride_x * batch_size;
+  auto x_size = (stride_x) ? stride_x * batch_size : size * std::abs(incX);
   auto y_size = stride_y * batch_size;
   // Input vector
   std::vector<scalar_t> x_v(x_size);
@@ -132,7 +132,7 @@ const auto combi =
     ::testing::Values(0.0, 1.3),      // alpha
     ::testing::Values(1, -1, 2, -4),  // incX
     ::testing::Values(1, -1, 3, -5),  // incY
-    ::testing::Values(1, 2, 3),       // stride_mul_x
+    ::testing::Values(0, 1, 2, 3),    // stride_mul_x
     ::testing::Values(1, 2, 3),       // stride_mul_y
     ::testing::Values(5)              // batch_size
 );
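
For context on the stride-0 convention these changes introduce, here is a minimal reference sketch of what a strided batched axpy is expected to compute. It is plain C++ written only for this note, not portBLAS code; the name axpy_batch_ref and its signature are illustrative assumptions. It shows why a zero x-stride means every batch reads the same input vector, so the vx view only needs _N elements (the overall_vx_size change above) and the test can include stride_mul_x == 0; negative increments follow the usual BLAS convention of walking a vector backwards from its last element.

#include <vector>

// Illustrative reference only (not the portBLAS kernel): computes
// y_b[i] += alpha * x_b[i] for every batch b. A zero stride on x makes
// all batches read the same input vector; a negative increment walks a
// vector backwards from its last element, as in standard BLAS.
template <typename scalar_t, typename index_t>
void axpy_batch_ref(index_t n, scalar_t alpha, const std::vector<scalar_t>& x,
                    index_t incx, index_t stride_x, std::vector<scalar_t>& y,
                    index_t incy, index_t stride_y, index_t batch_size) {
  for (index_t b = 0; b < batch_size; ++b) {
    for (index_t i = 0; i < n; ++i) {
      // Element i of a vector with increment inc sits at i * inc when inc is
      // positive and at (n - 1 - i) * |inc| when inc is negative.
      const index_t xi =
          b * stride_x + (incx > 0 ? i * incx : (n - 1 - i) * (-incx));
      const index_t yi =
          b * stride_y + (incy > 0 ? i * incy : (n - 1 - i) * (-incy));
      y[yi] += alpha * x[xi];
    }
  }
}

With stride_x == 0 and batch_size == 5, for example, the same x is accumulated into five distinct y vectors, and x only needs n * |incx| elements, which matches the new x_size computation in the test.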