Commit daca9e8: Change axpy_batch implementation to improve performance
s-Nick committed Nov 9, 2023 (1 parent: 19c9f0e)
Showing 8 changed files with 214 additions and 58 deletions.
9 changes: 9 additions & 0 deletions include/interface/extension_interface.h
@@ -154,6 +154,15 @@ typename sb_handle_t::event_t _axpy_batch(
index_t _stride_y, index_t _batch_size,
const typename sb_handle_t::event_t& _dependencies);

template <int localSize, int maxBlockPerBatch, typename sb_handle_t,
typename container_0_t, typename container_1_t, typename element_t,
typename index_t>
typename sb_handle_t::event_t _axpy_batch_impl(
sb_handle_t& sb_handle, index_t _N, element_t _alpha, container_0_t _vx,
index_t _incx, index_t _stride_x, container_1_t _vy, index_t _incy,
index_t _stride_y, index_t _batch_size,
const typename sb_handle_t::event_t& _dependencies, index_t global_size);

} // namespace internal

/**
14 changes: 8 additions & 6 deletions include/operations/extension/axpy_batch.h
@@ -27,7 +27,8 @@

namespace blas {

-template <bool same_sign, typename lhs_t, typename rhs_t>
+template <bool sameSign, int localSize, int maxBlockPerBatch, typename lhs_t,
+          typename rhs_t>
struct Axpy_batch {
using value_t = typename lhs_t::value_t;
using index_t = typename rhs_t::index_t;
@@ -50,15 +51,16 @@ struct Axpy_batch {
void adjust_access_displacement();
};

-template <bool same_sign, typename lhs_t, typename rhs_t>
-Axpy_batch<same_sign, lhs_t, rhs_t> make_axpy_batch(
+template <bool sameSign, int localSize, int maxBlockPerBatch, typename lhs_t,
+          typename rhs_t>
+Axpy_batch<sameSign, localSize, maxBlockPerBatch, lhs_t, rhs_t> make_axpy_batch(
lhs_t _lhs, rhs_t _rhs_1, typename rhs_t::value_t _alpha,
typename rhs_t::index_t _N, typename rhs_t::index_t _inc_l,
typename rhs_t::index_t _lhs_stride, typename rhs_t::index_t _inc_r,
typename rhs_t::index_t _rhs_stride, typename rhs_t::index_t _batch_size) {
-  return Axpy_batch<same_sign, lhs_t, rhs_t>(_lhs, _rhs_1, _alpha, _N, _inc_l,
-                                             _lhs_stride, _inc_r, _rhs_stride,
-                                             _batch_size);
+  return Axpy_batch<sameSign, localSize, maxBlockPerBatch, lhs_t, rhs_t>(
+      _lhs, _rhs_1, _alpha, _N, _inc_l, _lhs_stride, _inc_r, _rhs_stride,
+      _batch_size);
}

} // namespace blas
25 changes: 25 additions & 0 deletions src/interface/extension/backend/amd_gpu.hpp
@@ -138,6 +138,31 @@ typename sb_handle_t::event_t _omatadd_batch(
}
} // namespace backend
} // namespace omatadd_batch

namespace axpy_batch {
namespace backend {
template <typename sb_handle_t, typename container_0_t, typename container_1_t,
typename element_t, typename index_t>
typename sb_handle_t::event_t _axpy_batch(
sb_handle_t& sb_handle, index_t _N, element_t _alpha, container_0_t _vx,
index_t _incx, index_t _stride_x, container_1_t _vy, index_t _incy,
index_t _stride_y, index_t _batch_size,
const typename sb_handle_t::event_t& _dependencies) {
  // local_size chosen empirically
  constexpr index_t local_size = static_cast<index_t>(256);
  const auto nWG = (_N + local_size - 1) / local_size;
  // the limit for _N * _batch_size was determined empirically on an AMD W6800
const index_t global_size =
(_N * _batch_size >= 327680)
? (_N > (1 << 19)) ? (local_size * nWG) / 4 : local_size * nWG
: local_size * nWG * _batch_size;
return blas::internal::_axpy_batch_impl<256, 32>(
sb_handle, _N, _alpha, _vx, _incx, _stride_x, _vy, _incy, _stride_y,
_batch_size, _dependencies, global_size);
}
} // namespace backend
} // namespace axpy_batch

} // namespace blas

#endif
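The same launch heuristic is repeated in all four backends before dispatching to blas::internal::_axpy_batch_impl<256, 32>. The standalone sketch below is illustrative only: the helper name axpy_batch_global_size is invented here, while the thresholds 327680 and 1 << 19 are the empirical constants from this commit. It shows the three regimes: small problems get one thread per element across all batches, larger ones get one batch's worth of threads that the kernel reuses across batches, and very long vectors (_N > 2^19) get a quarter of that.

#include <cstdint>
#include <iostream>

// Illustrative restatement of the backend heuristic above.
int64_t axpy_batch_global_size(int64_t n, int64_t batch_size) {
  const int64_t local_size = 256;
  const int64_t nWG = (n + local_size - 1) / local_size;  // work-groups per batch
  if (n * batch_size < 327680)
    return local_size * nWG * batch_size;  // one thread per element, all batches
  if (n > (int64_t{1} << 19))
    return (local_size * nWG) / 4;  // very long vectors: 1/4 of a batch's threads
  return local_size * nWG;  // one batch's worth; the kernel loops over batches
}

int main() {
  std::cout << axpy_batch_global_size(1024, 10) << "\n";    // 10240
  std::cout << axpy_batch_global_size(4096, 100) << "\n";   // 4096
  std::cout << axpy_batch_global_size(1 << 20, 4) << "\n";  // 262144
}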
24 changes: 24 additions & 0 deletions src/interface/extension/backend/default_cpu.hpp
@@ -111,6 +111,30 @@ typename sb_handle_t::event_t _omatadd_batch(
}
} // namespace backend
} // namespace omatadd_batch

namespace axpy_batch {
namespace backend {
template <typename sb_handle_t, typename container_0_t, typename container_1_t,
typename element_t, typename index_t>
typename sb_handle_t::event_t _axpy_batch(
sb_handle_t& sb_handle, index_t _N, element_t _alpha, container_0_t _vx,
index_t _incx, index_t _stride_x, container_1_t _vy, index_t _incy,
index_t _stride_y, index_t _batch_size,
const typename sb_handle_t::event_t& _dependencies) {
  // local_size chosen empirically
  constexpr index_t local_size = static_cast<index_t>(256);
  const auto nWG = (_N + local_size - 1) / local_size;
  // the limit for _N * _batch_size was determined empirically on an AMD W6800
const index_t global_size =
(_N * _batch_size >= 327680)
? (_N > (1 << 19)) ? (local_size * nWG) / 4 : local_size * nWG
: local_size * nWG * _batch_size;
return blas::internal::_axpy_batch_impl<256, 32>(
sb_handle, _N, _alpha, _vx, _incx, _stride_x, _vy, _incy, _stride_y,
_batch_size, _dependencies, global_size);
}
} // namespace backend
} // namespace axpy_batch
} // namespace blas

#endif
24 changes: 24 additions & 0 deletions src/interface/extension/backend/intel_gpu.hpp
@@ -132,6 +132,30 @@ typename sb_handle_t::event_t _omatadd_batch(
}
} // namespace backend
} // namespace omatadd_batch

namespace axpy_batch {
namespace backend {
template <typename sb_handle_t, typename container_0_t, typename container_1_t,
typename element_t, typename index_t>
typename sb_handle_t::event_t _axpy_batch(
sb_handle_t& sb_handle, index_t _N, element_t _alpha, container_0_t _vx,
index_t _incx, index_t _stride_x, container_1_t _vy, index_t _incy,
index_t _stride_y, index_t _batch_size,
const typename sb_handle_t::event_t& _dependencies) {
  // local_size chosen empirically
  constexpr index_t local_size = static_cast<index_t>(256);
  const auto nWG = (_N + local_size - 1) / local_size;
  // the limit for _N * _batch_size was determined empirically on an Intel GPU
const index_t global_size =
(_N * _batch_size >= 327680)
? (_N > (1 << 19)) ? (local_size * nWG) / 4 : local_size * nWG
: local_size * nWG * _batch_size;
return blas::internal::_axpy_batch_impl<256, 32>(
sb_handle, _N, _alpha, _vx, _incx, _stride_x, _vy, _incy, _stride_y,
_batch_size, _dependencies, global_size);
}
} // namespace backend
} // namespace axpy_batch
} // namespace blas

#endif
24 changes: 24 additions & 0 deletions src/interface/extension/backend/nvidia_gpu.hpp
@@ -138,6 +138,30 @@ typename sb_handle_t::event_t _omatadd_batch(
}
} // namespace backend
} // namespace omatadd_batch

namespace axpy_batch {
namespace backend {
template <typename sb_handle_t, typename container_0_t, typename container_1_t,
typename element_t, typename index_t>
typename sb_handle_t::event_t _axpy_batch(
sb_handle_t& sb_handle, index_t _N, element_t _alpha, container_0_t _vx,
index_t _incx, index_t _stride_x, container_1_t _vy, index_t _incy,
index_t _stride_y, index_t _batch_size,
const typename sb_handle_t::event_t& _dependencies) {
  // local_size chosen empirically
  constexpr index_t local_size = static_cast<index_t>(256);
  const auto nWG = (_N + local_size - 1) / local_size;
  // the limit for _N * _batch_size was determined empirically on an AMD W6800
const index_t global_size =
(_N * _batch_size >= 327680)
? (_N > (1 << 19)) ? (local_size * nWG) / 4 : local_size * nWG
: local_size * nWG * _batch_size;
return blas::internal::_axpy_batch_impl<256, 32>(
sb_handle, _N, _alpha, _vx, _incx, _stride_x, _vy, _incy, _stride_y,
_batch_size, _dependencies, global_size);
}
} // namespace backend
} // namespace axpy_batch
} // namespace blas

#endif
35 changes: 22 additions & 13 deletions src/interface/extension_interface.hpp
@@ -597,14 +597,26 @@ typename sb_handle_t::event_t _reduction(
sb_handle, buffer_in, ld, buffer_out, rows, cols, dependencies);
}
}

template <typename sb_handle_t, typename container_0_t, typename container_1_t,
typename element_t, typename index_t>
typename sb_handle_t::event_t _axpy_batch(
sb_handle_t& sb_handle, index_t _N, element_t _alpha, container_0_t _vx,
index_t _incx, index_t _stride_x, container_1_t _vy, index_t _incy,
index_t _stride_y, index_t _batch_size,
const typename sb_handle_t::event_t& _dependencies) {
return blas::axpy_batch::backend::_axpy_batch(
sb_handle, _N, _alpha, _vx, _incx, _stride_x, _vy, _incy, _stride_y,
_batch_size, _dependencies);
}

template <int localSize, int maxBlockPerBatch, typename sb_handle_t,
typename container_0_t, typename container_1_t, typename element_t,
typename index_t>
typename sb_handle_t::event_t _axpy_batch_impl(
sb_handle_t& sb_handle, index_t _N, element_t _alpha, container_0_t _vx,
index_t _incx, index_t _stride_x, container_1_t _vy, index_t _incy,
index_t _stride_y, index_t _batch_size,
const typename sb_handle_t::event_t& _dependencies, index_t global_size) {
  // If the increments are of opposite sign the values are exchanged. It does
  // not matter which one is positive or negative, so to simplify the index
  // computation in the kernel we always set incx to be negative and incy to
  // be positive.
@@ -616,23 +628,20 @@ typename sb_handle_t::event_t _axpy_batch(
make_vector_view(_vx, static_cast<index_t>(_incx),
static_cast<index_t>(_N * _batch_size));
auto vy = make_vector_view(_vy, _incy, _N * _batch_size);
-  const auto local_size = sb_handle.get_work_group_size();
-  const auto nWG = (_N + local_size - 1) / local_size;
-  const auto global_size = local_size * nWG * _batch_size;
  // If both vectors are read from the same side, the sign of the increment
  // does not matter.
if (_incx * _incy > 0) {
-    auto op =
-        make_axpy_batch<true>(vy, vx, _alpha, _N, std::abs(_incy), _stride_y,
-                              std::abs(_incx), _stride_x, _batch_size);
-    typename sb_handle_t::event_t ret =
-        sb_handle.execute(op, local_size, global_size, _dependencies);
+    auto op = make_axpy_batch<true, localSize, maxBlockPerBatch>(
+        vy, vx, _alpha, _N, std::abs(_incy), _stride_y, std::abs(_incx),
+        _stride_x, _batch_size);
+    typename sb_handle_t::event_t ret = sb_handle.execute(
+        op, static_cast<index_t>(localSize), global_size, _dependencies);
return ret;
} else {
-    auto op = make_axpy_batch<false>(vy, vx, _alpha, _N, _incy, _stride_y,
-                                     _incx, _stride_x, _batch_size);
-    typename sb_handle_t::event_t ret =
-        sb_handle.execute(op, local_size, global_size, _dependencies);
+    auto op = make_axpy_batch<false, localSize, maxBlockPerBatch>(
+        vy, vx, _alpha, _N, _incy, _stride_y, _incx, _stride_x, _batch_size);
+    typename sb_handle_t::event_t ret = sb_handle.execute(
+        op, static_cast<index_t>(localSize), global_size, _dependencies);
return ret;
}
}
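One detail of the opposite-sign branch is worth spelling out: after the normalization above, inc_r is negative, and the kernel's starting offset inc_r + n * |inc_r| + l_id * inc_r (used in axpy_batch.hpp below) is exactly the conventional BLAS reversed access (n - 1 - l_id) * |inc_r|. A minimal standalone check of that identity (illustrative, not library code):

#include <cassert>
#include <cstdlib>

int main() {
  const int n = 7;
  const int incs[] = {-1, -2, -3};  // negative increments, as normalized above
  for (int inc_r : incs) {
    for (int l_id = 0; l_id < n; ++l_id) {
      const int kernel_index = inc_r + n * std::abs(inc_r) + l_id * inc_r;
      const int blas_index = (n - 1 - l_id) * std::abs(inc_r);
      assert(kernel_index == blas_index);  // both walk x back to front
    }
  }
  return 0;
}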
117 changes: 78 additions & 39 deletions src/operations/extension/axpy_batch.hpp
@@ -30,8 +30,9 @@

namespace blas {

-template <bool same_sign, typename lhs_t, typename rhs_t>
-Axpy_batch<same_sign, lhs_t, rhs_t>::Axpy_batch(
+template <bool sameSign, int localSize, int maxBlockPerBatch, typename lhs_t,
+          typename rhs_t>
+Axpy_batch<sameSign, localSize, maxBlockPerBatch, lhs_t, rhs_t>::Axpy_batch(
lhs_t _lhs, rhs_t _rhs, typename lhs_t::value_t _alpha,
typename rhs_t::index_t _N, typename rhs_t::index_t _inc_l,
typename rhs_t::index_t _lhs_stride, typename rhs_t::index_t _inc_r,
@@ -46,69 +47,107 @@ Axpy_batch<same_sign, lhs_t, rhs_t>::Axpy_batch(
rhs_stride_(_rhs_stride),
batch_size_(_batch_size){};

-template <bool same_sign, typename lhs_t, typename rhs_t>
+template <bool sameSign, int localSize, int maxBlockPerBatch, typename lhs_t,
+          typename rhs_t>
 PORTBLAS_INLINE typename lhs_t::value_t
-Axpy_batch<same_sign, lhs_t, rhs_t>::eval(index_t i) {}
+Axpy_batch<sameSign, localSize, maxBlockPerBatch, lhs_t, rhs_t>::eval(
+    index_t i) {}

-template <bool same_sign, typename lhs_t, typename rhs_t>
+template <bool sameSign, int localSize, int maxBlockPerBatch, typename lhs_t,
+          typename rhs_t>
 PORTBLAS_INLINE typename lhs_t::value_t
-Axpy_batch<same_sign, lhs_t, rhs_t>::eval(cl::sycl::nd_item<1> ndItem) {
+Axpy_batch<sameSign, localSize, maxBlockPerBatch, lhs_t, rhs_t>::eval(
+    cl::sycl::nd_item<1> ndItem) {
const index_t n{n_};
const value_t alpha{alpha_};

-  const index_t l_id = static_cast<index_t>(ndItem.get_global_linear_id() % n);
-  const index_t group_id =
-      static_cast<index_t>(ndItem.get_global_linear_id() / n);
-
-  if (group_id >= batch_size_) return {};
-
-  if constexpr (same_sign) {
-    const index_t x_index = group_id * rhs_stride_ + l_id * inc_r;
-    const index_t y_index = group_id * lhs_stride_ + l_id * inc_l;
-
-    const value_t ax = alpha * rhs_.get_data()[x_index];
-    lhs_.get_data()[y_index] += ax;
+  const auto vx = rhs_.get_data();
+  const auto vy = lhs_.get_data();
+  const auto nbl = sycl::min((n + localSize - 1) / localSize,
+                             static_cast<index_t>(maxBlockPerBatch));
+
+  const index_t block_id = ndItem.get_group(0) % nbl;
+  const index_t l_id =
+      static_cast<index_t>(ndItem.get_local_range(0)) * block_id +
+      ndItem.get_local_id(0);
+  const index_t group_id = static_cast<index_t>(ndItem.get_group(0) / nbl);
+
+  const index_t size_compute_ratio =
+      (n > nbl * localSize) ? n / (nbl * localSize) : batch_size_;
+  const index_t jump_value{sycl::min(batch_size_, size_compute_ratio)};
+
+  if (group_id >= jump_value || l_id >= n) return {};
+
+  const index_t stride_x = ndItem.get_local_range(0) * nbl * inc_r;
+  const index_t stride_y = ndItem.get_local_range(0) * nbl * inc_l;
+  index_t x_index{};
+  index_t y_index{};
+  index_t j{0};
+
+  if constexpr (sameSign) {
+    for (auto out_loop = group_id; out_loop < batch_size_;
+         out_loop += jump_value) {
+      x_index = out_loop * rhs_stride_ + l_id * inc_r;
+      y_index = out_loop * lhs_stride_ + l_id * inc_l;
+      j = y_index;
+      for (auto i = x_index; i < (out_loop * rhs_stride_) + n * inc_r;
+           i += stride_x, j += stride_y) {
+        vy[j] += alpha * vx[i];
+      }
+    }
+
   } else {
-    const index_t x_index =
-        group_id * rhs_stride_ + inc_r + n * sycl::abs(inc_r) + l_id * inc_r;
-    const index_t y_index = group_id * lhs_stride_ + l_id * inc_l;
-
-    const value_t ax = alpha * rhs_.get_data()[x_index];
-    lhs_.get_data()[y_index] += ax;
+    for (auto out_loop = group_id; out_loop < batch_size_;
+         out_loop += jump_value) {
+      x_index =
+          out_loop * rhs_stride_ + inc_r + n * sycl::abs(inc_r) + l_id * inc_r;
+      y_index = out_loop * lhs_stride_ + l_id * inc_l;
+      j = y_index;
+      for (auto i = x_index; i >= (out_loop * rhs_stride_);
+           i += stride_x, j += stride_y) {
+        vy[j] += alpha * vx[i];
+      }
+    }
   }

return {};
}
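Compared with the old one-thread-per-element mapping, the new eval assigns each batch at most maxBlockPerBatch work-groups; a work-group strides through its batch in steps of localSize * nbl elements, then jumps ahead by jump_value batches, so a launch sized for a single batch (or a quarter of one) still covers all of batch_size. A host-side model of those quantities, a sketch under assumed problem sizes rather than library code:

#include <algorithm>
#include <cstdio>

int main() {
  // Same constants the backends pass to _axpy_batch_impl<256, 32>.
  const int localSize = 256, maxBlockPerBatch = 32;
  const int n = 100000, batch_size = 8;  // example problem size (assumed)

  // Work-groups assigned to one batch, capped at maxBlockPerBatch.
  const int nbl = std::min((n + localSize - 1) / localSize, maxBlockPerBatch);
  // How many batches are worked on concurrently before groups jump ahead.
  const int ratio = (n > nbl * localSize) ? n / (nbl * localSize) : batch_size;
  const int jump_value = std::min(batch_size, ratio);

  std::printf("nbl = %d, jump_value = %d\n", nbl, jump_value);  // 32 and 8
  // A work-group with linear id g handles batches g / nbl, g / nbl + jump_value,
  // ... and inside each batch strides by localSize * nbl elements per iteration.
}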

-template <bool same_sign, typename lhs_t, typename rhs_t>
+template <bool sameSign, int localSize, int maxBlockPerBatch, typename lhs_t,
+          typename rhs_t>
 template <typename sharedT>
 PORTBLAS_INLINE typename lhs_t::value_t
-Axpy_batch<same_sign, lhs_t, rhs_t>::eval(sharedT shMem,
-                                          sycl::nd_item<1> ndItem){};
+Axpy_batch<sameSign, localSize, maxBlockPerBatch, lhs_t, rhs_t>::eval(
+    sharedT shMem, sycl::nd_item<1> ndItem){};

-template <bool same_sign, typename lhs_t, typename rhs_t>
-PORTBLAS_INLINE void Axpy_batch<same_sign, lhs_t, rhs_t>::bind(
-    cl::sycl::handler& h) {
+template <bool sameSign, int localSize, int maxBlockPerBatch, typename lhs_t,
+          typename rhs_t>
+PORTBLAS_INLINE void Axpy_batch<sameSign, localSize, maxBlockPerBatch, lhs_t,
+                                rhs_t>::bind(cl::sycl::handler& h) {
lhs_.bind(h);
rhs_.bind(h);
}

-template <bool same_sign, typename lhs_t, typename rhs_t>
-PORTBLAS_INLINE void
-Axpy_batch<same_sign, lhs_t, rhs_t>::adjust_access_displacement() {
+template <bool sameSign, int localSize, int maxBlockPerBatch, typename lhs_t,
+          typename rhs_t>
+PORTBLAS_INLINE void Axpy_batch<sameSign, localSize, maxBlockPerBatch, lhs_t,
+                                rhs_t>::adjust_access_displacement() {
lhs_.adjust_access_displacement();
rhs_.adjust_access_displacement();
}

-template <bool same_sign, typename lhs_t, typename rhs_t>
-PORTBLAS_INLINE typename rhs_t::index_t
-Axpy_batch<same_sign, lhs_t, rhs_t>::get_size() const {
+template <bool sameSign, int localSize, int maxBlockPerBatch, typename lhs_t,
+          typename rhs_t>
+PORTBLAS_INLINE typename rhs_t::index_t Axpy_batch<
+    sameSign, localSize, maxBlockPerBatch, lhs_t, rhs_t>::get_size() const {
return n_ * batch_size_;
}

-template <bool same_sign, typename lhs_t, typename rhs_t>
-PORTBLAS_INLINE bool Axpy_batch<same_sign, lhs_t, rhs_t>::valid_thread(
+template <bool sameSign, int localSize, int maxBlockPerBatch, typename lhs_t,
+          typename rhs_t>
+PORTBLAS_INLINE bool
+Axpy_batch<sameSign, localSize, maxBlockPerBatch, lhs_t, rhs_t>::valid_thread(
cl::sycl::nd_item<1> ndItem) const {
return true;
}
