This repository has been archived by the owner on Jan 13, 2025. It is now read-only.

Address PR comments and fix small bug on nvidia gpu configuration
s-Nick committed Nov 20, 2023
1 parent 7dd74d1 commit 0ccfe3e
Showing 5 changed files with 39 additions and 15 deletions.
27 changes: 26 additions & 1 deletion include/operations/extension/axpy_batch.h
@@ -27,6 +27,30 @@

namespace blas {

/*!
 * This class holds the kernel implementation of the axpy_batch operator.
 *
 * It has three additional template parameters to keep the operation simple
 * and to avoid extra computation or code divergence inside the kernel code.
 *
 * If sameSign is false the kernel always assumes that inc_r is negative. This
 * is true by construction: when the increments have different signs the
 * result positions are swapped and the indexes must be computed accordingly.
 * Keeping inc_r negative and inc_l positive keeps the index computation
 * consistent and yields the correct result.
 *
 * sameSign indicates whether inc_r and inc_l have the same sign. The two
 * cases require different index computations; resolving the condition with a
 * template parameter at compile time avoids code divergence.
 *
 * localSize is the local size of each work-group, allowing some device
 * tailoring at compile time.
 *
 * maxBlockPerBatch sets the number of device work-groups used for each
 * batch. When possible, multiple batches are computed concurrently.
 */

template <bool sameSign, int localSize, int maxBlockPerBatch, typename lhs_t,
typename rhs_t>
struct Axpy_batch {
@@ -36,7 +60,8 @@ struct Axpy_batch {
lhs_t lhs_;
rhs_t rhs_;
value_t alpha_;
index_t n_, inc_r, inc_l, lhs_stride_, rhs_stride_, batch_size_;
index_t n_, inc_r, inc_l, lhs_stride_, rhs_stride_, batch_size_,
n_block_per_loop;

Axpy_batch(lhs_t _lhs, rhs_t _rhs_1, value_t _alpha, index_t _N,
index_t _inc_l, index_t _lhs_stride, index_t _inc_r,
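The new doc comment above describes how the sameSign template parameter selects between two index computations at compile time. A minimal standalone sketch of that idea, not taken from the repository (the struct, member names and values below are illustrative placeholders), might look like this:

#include <algorithm>
#include <cstdio>

// Illustrative stand-in for Axpy_batch: the template parameters mirror
// sameSign / localSize / maxBlockPerBatch, but the body is simplified.
template <bool sameSign, int localSize, int maxBlockPerBatch>
struct AxpyBatchSketch {
  int n_;

  int blocks_per_batch() const {
    // Same shape as the n_block_per_loop member introduced by this commit.
    return std::min((n_ + localSize - 1) / localSize, maxBlockPerBatch);
  }

  int first_x_index(int l_id, int inc_r) const {
    if constexpr (sameSign) {
      // Both increments positive: start at the thread offset.
      return l_id * inc_r;
    } else {
      // inc_r is negative by construction: start from the far end so the
      // element order matches BLAS negative-increment semantics.
      return inc_r + n_ * (-inc_r) + l_id * inc_r;
    }
  }
};

int main() {
  AxpyBatchSketch</*sameSign=*/false, 256, 32> op{1000};
  std::printf("blocks per batch: %d\n", op.blocks_per_batch());
  std::printf("thread 0 starts at x index %d\n", op.first_x_index(0, -2));
  return 0;
}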
2 changes: 1 addition & 1 deletion src/interface/extension/backend/nvidia_gpu.hpp
@@ -152,7 +152,7 @@ typename sb_handle_t::event_t _axpy_batch(
constexpr index_t local_size = static_cast<index_t>(256);
const auto nWG = (_N + local_size - 1) / local_size;
// the limit for _N*batch_size is taken empirically from test on A100
if (_N * _batch_size <= 81920) {
if (_N * _batch_size <= 81920 || _N <= 16384) {
const index_t global_size = local_size * nWG * _batch_size;
return blas::internal::_axpy_batch_impl<256, 32>(
sb_handle, _N, _alpha, _vx, _incx, _stride_x, _vy, _incy, _stride_y,
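The hunk above only relaxes the dispatch condition. As a hedged sketch of the heuristic it implements (the else-branch configuration is an assumption, since the larger-size path lies outside the visible diff):

#include <cstdint>

// Illustrative only: the 81920 and 16384 thresholds come from the hunk above
// (taken empirically from tests on an A100); the large-problem parameters
// below are placeholders, not the backend's actual choice.
struct axpy_batch_config {
  int local_size;
  int max_block_per_batch;
};

inline axpy_batch_config pick_axpy_batch_config(std::int64_t n,
                                                std::int64_t batch_size) {
  // The commit adds the `n <= 16384` escape so that a large batch count no
  // longer forces short vectors onto the large-problem path.
  if (n * batch_size <= 81920 || n <= 16384) {
    return {256, 32};
  }
  return {512, 64};  // assumed configuration for larger problems
}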
7 changes: 4 additions & 3 deletions src/interface/extension_interface.hpp
@@ -624,10 +624,11 @@ typename sb_handle_t::event_t _axpy_batch_impl(
_incx = -_incx;
_incy = -_incy;
}
// if _stride_x is zero use _N as vx size
const index_t overall_vx_size = (_stride_x) ? _stride_x * _batch_size : _N;
typename VectorViewType<container_0_t, index_t, index_t>::type vx =
make_vector_view(_vx, static_cast<index_t>(_incx),
static_cast<index_t>(_N * _batch_size));
auto vy = make_vector_view(_vy, _incy, _N * _batch_size);
make_vector_view(_vx, static_cast<index_t>(_incx), overall_vx_size);
auto vy = make_vector_view(_vy, _incy, _stride_y * _batch_size);
// If both vectors are read from the same side, the sign of the increment
// does not matter
if (_incx * _incy > 0) {
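A simplified host-side sketch of the sizing rule this hunk introduces (names are illustrative, not the interface's API): with _stride_x == 0 the x buffer holds a single vector of _N elements that every batch reads, so the x view must not be sized as _N * _batch_size, while the y view always spans _stride_y * _batch_size.

#include <cstdint>

// Illustrative helpers mirroring the hunk above.
inline std::int64_t axpy_batch_x_view_size(std::int64_t n, std::int64_t stride_x,
                                           std::int64_t batch_size) {
  // Zero stride means the same input vector is reused for every batch.
  return stride_x != 0 ? stride_x * batch_size : n;
}

inline std::int64_t axpy_batch_y_view_size(std::int64_t stride_y,
                                           std::int64_t batch_size) {
  // The output is written per batch, so its view covers all batches.
  return stride_y * batch_size;
}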
14 changes: 6 additions & 8 deletions src/operations/extension/axpy_batch.hpp
@@ -45,7 +45,9 @@ Axpy_batch<sameSign, localSize, maxBlockPerBatch, lhs_t, rhs_t>::Axpy_batch(
lhs_stride_(_lhs_stride),
inc_r(_inc_r),
rhs_stride_(_rhs_stride),
batch_size_(_batch_size){};
batch_size_(_batch_size),
n_block_per_loop(std::min((n_ + localSize - 1) / localSize,
static_cast<index_t>(maxBlockPerBatch))){};

template <bool sameSign, int localSize, int maxBlockPerBatch, typename lhs_t,
typename rhs_t>
@@ -62,15 +64,12 @@ Axpy_batch<sameSign, localSize, maxBlockPerBatch, lhs_t, rhs_t>::eval(
const value_t alpha{alpha_};
const auto vx = rhs_.get_data();
const auto vy = lhs_.get_data();
const auto nbl = sycl::min((n + localSize - 1) / localSize,
static_cast<index_t>(maxBlockPerBatch));
const auto nbl{n_block_per_loop};

const index_t block_id = ndItem.get_group(0) % nbl;
const index_t l_id =
static_cast<index_t>(ndItem.get_local_range(0)) * block_id +
ndItem.get_local_id(0);
// const index_t group_id =
// static_cast<index_t>(ndItem.get_global_linear_id() / n);
const index_t group_id = static_cast<index_t>(ndItem.get_group(0) / nbl);

const index_t size_compute_rateo =
@@ -83,7 +82,7 @@ Axpy_batch<sameSign, localSize, maxBlockPerBatch, lhs_t, rhs_t>::eval(
const index_t stride_y = ndItem.get_local_range(0) * nbl * inc_l;
index_t x_index{};
index_t y_index{};
int j{0};
int j{};

if constexpr (sameSign) {
for (auto out_loop = group_id; out_loop < batch_size_;
@@ -100,8 +99,7 @@ Axpy_batch<sameSign, localSize, maxBlockPerBatch, lhs_t, rhs_t>::eval(
} else {
for (auto out_loop = group_id; out_loop < batch_size_;
out_loop += jump_value) {
x_index =
out_loop * rhs_stride_ + inc_r + n * sycl::abs(inc_r) + l_id * inc_r;
x_index = out_loop * rhs_stride_ + inc_r + n * (-inc_r) + l_id * inc_r;
y_index = out_loop * lhs_stride_ + l_id * inc_l;
j = y_index;
for (auto i = x_index; i >= (out_loop * rhs_stride_);
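The x_index change in the last hunk replaces n * sycl::abs(inc_r) with n * (-inc_r); this is only valid because inc_r is negative by construction in the opposite-sign path. A small host-side check of that equivalence (illustrative, not kernel code):

#include <cassert>
#include <cstdlib>

int main() {
  // inc_r is negative by construction in the opposite-sign path.
  const int n = 1000, inc_r = -2, rhs_stride = 4096;
  for (int out_loop = 0; out_loop < 3; ++out_loop) {
    for (int l_id = 0; l_id < 256; ++l_id) {
      // Old expression (sycl::abs replaced by std::abs on the host).
      const int old_idx =
          out_loop * rhs_stride + inc_r + n * std::abs(inc_r) + l_id * inc_r;
      // New expression from this commit; equal whenever inc_r < 0.
      const int new_idx =
          out_loop * rhs_stride + inc_r + n * (-inc_r) + l_id * inc_r;
      assert(old_idx == new_idx);
      // Thread l_id starts at offset (n - 1 - l_id) * |inc_r| within its
      // batch, matching BLAS negative-increment element ordering.
      assert(new_idx - out_loop * rhs_stride == (n - 1 - l_id) * (-inc_r));
    }
  }
  return 0;
}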
4 changes: 2 additions & 2 deletions test/unittest/extension/axpy_batch_test.cpp
@@ -45,7 +45,7 @@ void run_test(const combination_t<scalar_t> combi) {
const index_t stride_x{size * std::abs(incX) * stride_mul_x};
const index_t stride_y{size * std::abs(incY) * stride_mul_y};

auto x_size = stride_x * batch_size;
auto x_size = (stride_x) ? stride_x * batch_size : size * std::abs(incX);
auto y_size = stride_y * batch_size;
// Input vector
std::vector<scalar_t> x_v(x_size);
@@ -132,7 +132,7 @@ const auto combi =
::testing::Values<scalar_t>(0.0, 1.3), // alpha
::testing::Values(1, -1, 2, -4), // incX
::testing::Values(1, -1, 3, -5), // incY
::testing::Values(1, 2, 3), // stride_mul_x
::testing::Values(0, 1, 2, 3), // stride_mul_x
::testing::Values(1, 2, 3), // stride_mul_y
::testing::Values(5) // batch_size
);
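The new stride_mul_x value 0 makes the test exercise a zero x stride, i.e. one shared input vector across the whole batch. A naive reference loop for that case (a sketch assuming positive increments only; the test's actual reference computation is not shown in this hunk) could be:

#include <cstddef>
#include <vector>

// Illustrative reference for axpy_batch, restricted to positive increments.
// With stride_x == 0 every batch reads the same x vector.
template <typename scalar_t>
void axpy_batch_reference(std::size_t n, scalar_t alpha,
                          const std::vector<scalar_t>& x, std::size_t inc_x,
                          std::size_t stride_x, std::vector<scalar_t>& y,
                          std::size_t inc_y, std::size_t stride_y,
                          std::size_t batch_size) {
  for (std::size_t b = 0; b < batch_size; ++b) {
    for (std::size_t i = 0; i < n; ++i) {
      y[b * stride_y + i * inc_y] += alpha * x[b * stride_x + i * inc_x];
    }
  }
}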
