Skip to content

Commit

Permalink
Update ger implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
pgorlani committed Apr 11, 2024
1 parent e5f9738 commit 6e18114
Show file tree
Hide file tree
Showing 3 changed files with 276 additions and 16 deletions.
58 changes: 58 additions & 0 deletions include/operations/blas2_trees.h
Original file line number Diff line number Diff line change
Expand Up @@ -502,6 +502,64 @@ make_trsv(vector_t &lhs_, matrix_t &matrix_, sync_t &sync_) {
subgroups, is_upper, is_transposed, is_unit>(lhs_, matrix_, k_,
sync_);
}
/**
* @struct Ger
* @brief Tree node representing the sum of scalar-vector-vector product with a
* matrix, i.e., it computes lhs_ such that
*
* lhs_ = scalar_ * ( rhs_1_ * rhs_2_^t ) + lhs_
*
* @param lhs_ input/output matrix
* @param scalar_ value for scaling vector product
* @param rhs_1_ first input vector
* @param rhs_2_ second input vector
* @param nRowsWG_ rows of the workgroup tile
* @param nColsWG_ cols of the workgroup tile
* @param nWG_row_ number of tiles per global size row
* @param nWG_col_ number of tiles per global size column
*
*/
template <typename lhs_t, typename rhs_1_t, typename rhs_2_t>
struct Ger {
using value_t = typename rhs_2_t::value_t;
using index_t = typename rhs_2_t::index_t;

lhs_t lhs_;
value_t scalar_;
rhs_1_t rhs_1_;
rhs_2_t rhs_2_;
index_t nRowsWG_;
index_t nColsWG_;
index_t nWG_row_;
index_t nWG_col_;

Ger(lhs_t &_l, value_t _scl, rhs_1_t &_r1, rhs_2_t &_r2, index_t &_nRowsWG,
index_t &_nColsWG, index_t &_nWG_row, index_t &_nWG_col);

index_t get_size() const;
bool valid_thread(cl::sycl::nd_item<1> ndItem) const;
value_t eval(index_t i);
value_t eval(cl::sycl::nd_item<1> ndItem);
template <typename sharedT>
value_t eval(sharedT shrMem, cl::sycl::nd_item<1> ndItem);
void bind(cl::sycl::handler &h);
void adjust_access_displacement();
};

/*!
@brief Generator/factory for GER trees.
*/
template <typename lhs_t, typename rhs_1_t, typename rhs_2_t>
Ger<lhs_t, rhs_1_t, rhs_2_t> make_ger(lhs_t &lhs_,
typename lhs_t::value_t scalar_,
rhs_1_t &rhs_1_, rhs_2_t &rhs_2_,
typename rhs_2_t::index_t nRowsWG_,
typename rhs_2_t::index_t nColsWG_,
typename rhs_2_t::index_t nWG_row_,
typename rhs_2_t::index_t nWG_col_) {
return Ger<lhs_t, rhs_1_t, rhs_2_t>(lhs_, scalar_, rhs_1_, rhs_2_, nRowsWG_,
nColsWG_, nWG_row_, nWG_col_);
}

/**** GER BY ROWS M ROWS x N BLOCK USING PROPERLY THE SHARED MEMORY ****/
// template <typename lhs_t,typename rhs_1_t,typename rhs_2_t>
Expand Down
67 changes: 51 additions & 16 deletions src/interface/blas2_interface.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -878,7 +878,7 @@ typename sb_handle_t::event_t _ger_impl(
container_t0 _vx, increment_t _incx, container_t1 _vy, increment_t _incy,
container_t2 _mA, index_t _lda,
const typename sb_handle_t::event_t& _dependencies, index_t _localSize = 0,
index_t _scratchPadSize = 0, index_t _nRowsWG = 0, index_t _nColsWG = 0) {
bool _useLocalMem = true, index_t _nRowsWG = 0, index_t _nColsWG = 0) {
index_t M = _M;
index_t N = _N;
auto mA = make_matrix_view<col_major>(_mA, M, N, _lda);
Expand All @@ -887,24 +887,39 @@ typename sb_handle_t::event_t _ger_impl(
typename VectorViewType<container_t1, index_t, increment_t>::type vy =
make_vector_view(_vy, _incy, N);

const index_t localSize =
(_localSize == 0) ? sb_handle.get_work_group_size() : _localSize;
const index_t nRowsWG = (_nRowsWG == 0) ? localSize : std::min(M, _nRowsWG);
_localSize = (_localSize == 0) ? sb_handle.get_work_group_size() : _localSize;
_nRowsWG = (_nRowsWG == 0) ? _localSize : _nRowsWG;
_nColsWG = (_nColsWG == 0) ? _localSize : _nColsWG;

const index_t nColsWG = (_nColsWG == 0) ? localSize : std::min(N, _nColsWG);
assert(_localSize % _nRowsWG == 0);
assert((_nRowsWG * _nColsWG) % _localSize == 0);
assert(_nColsWG % (_localSize / _nRowsWG) == 0);

const index_t scratchPadSize =
(_localSize == 0) ? localSize : _scratchPadSize;
if (_useLocalMem) {
assert((_nRowsWG <= _localSize) && (_nColsWG <= _localSize));
} else {
std::vector<size_t> subgroup_sizes =
sb_handle.get_queue()
.get_device()
.template get_info<sycl::info::device::sub_group_sizes>();
size_t min_subgroup_size = *subgroup_sizes.begin();
size_t max_subgroup_size = *subgroup_sizes.rbegin();
assert(((_nRowsWG * _nColsWG) / _localSize) <= min_subgroup_size);
assert(_nRowsWG % max_subgroup_size == 0);
}

const index_t nWGPerCol = (N - 1) / nColsWG + 1;
const index_t nWGPerRow = (M - 1) / nRowsWG + 1;
const index_t globalSize = localSize * nWGPerRow * nWGPerCol;
const index_t nWGPerCol = (N - 1) / _nColsWG + 1;
const index_t nWGPerRow = (M - 1) / _nRowsWG + 1;
const index_t globalSize = _localSize * nWGPerRow * nWGPerCol;

typename sb_handle_t::event_t ret;
auto assignOp =
make_ger_col(mA, _alpha, vx, vy, nWGPerRow, nWGPerCol, scratchPadSize);
return sb_handle.execute(assignOp, localSize, globalSize, scratchPadSize,
_dependencies);
make_ger(mA, _alpha, vx, vy, _nRowsWG, _nColsWG, nWGPerRow, nWGPerCol);

return _useLocalMem ? sb_handle.execute(assignOp, _localSize, globalSize,
_nRowsWG + _nColsWG, _dependencies)
: sb_handle.execute(assignOp, _localSize, globalSize,
_dependencies);
}

/*! _SYR.
Expand Down Expand Up @@ -1280,10 +1295,30 @@ typename sb_handle_t::event_t inline _ger(
container_t0 _vx, increment_t _incx, container_t1 _vy, increment_t _incy,
container_t2 _mA, index_t _lda,
const typename sb_handle_t::event_t& _dependencies) {
// TODO: Here we can use some heuristics to select localn global, local, and
// scratch size per device
index_t localSize = 0;
bool useLocalMem = true;
index_t nRowsWG = 0;
index_t nColsWG = 0;

#if defined(INTEL_GPU)
localSize = 32;
useLocalMem = false;
nRowsWG = 32;
nColsWG = 8;
#elif defined(NVIDIA_GPU)
localSize = 256;
useLocalMem = (_N < 8192 && _M < 8192) ? false : true;
nRowsWG = 32;
nColsWG = 32;
#elif defined(AMD_GPU)
localSize = (_N < 8192 && _M < 8192) ? 512 : 256;
useLocalMem = (_N < 8192 && _M < 8192) ? false : true;
nRowsWG = (_N < 8192 && _M < 8192) ? 64 : 128;
nColsWG = (_N < 8192 && _M < 8192) ? 64 : 256;
#endif

return _ger_impl(sb_handle, _M, _N, _alpha, _vx, _incx, _vy, _incy, _mA, _lda,
_dependencies);
_dependencies, localSize, useLocalMem, nRowsWG, nColsWG);
}

template <typename sb_handle_t, typename index_t, typename element_t,
Expand Down
167 changes: 167 additions & 0 deletions src/operations/blas2/ger.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,173 @@

namespace blas {

template <typename lhs_t, typename rhs_1_t, typename rhs_2_t>
PORTBLAS_INLINE Ger<lhs_t, rhs_1_t, rhs_2_t>::Ger(
lhs_t &_l, value_t _scl, rhs_1_t &_r1, rhs_2_t &_r2, index_t &_nRowsWG,
index_t &_nColsWG, index_t &_nWG_row, index_t &_nWG_col)
: lhs_(_l),
scalar_(_scl),
rhs_1_(_r1),
rhs_2_(_r2),
nRowsWG_(_nRowsWG),
nColsWG_(_nColsWG),
nWG_row_(_nWG_row),
nWG_col_(_nWG_col) {}

template <typename lhs_t, typename rhs_1_t, typename rhs_2_t>
PORTBLAS_INLINE typename Ger<lhs_t, rhs_1_t, rhs_2_t>::index_t
Ger<lhs_t, rhs_1_t, rhs_2_t>::get_size() const {
return rhs_1_.get_size();
}
template <typename lhs_t, typename rhs_1_t, typename rhs_2_t>
PORTBLAS_INLINE bool Ger<lhs_t, rhs_1_t, rhs_2_t>::valid_thread(
cl::sycl::nd_item<1> ndItem) const {
return true;
}

template <typename lhs_t, typename rhs_1_t, typename rhs_2_t>
PORTBLAS_INLINE typename Ger<lhs_t, rhs_1_t, rhs_2_t>::value_t
Ger<lhs_t, rhs_1_t, rhs_2_t>::eval(cl::sycl::nd_item<1> ndItem) {
using index_t = typename Ger<lhs_t, rhs_1_t, rhs_2_t>::index_t;

const index_t subgroup_size = ndItem.get_sub_group().get_local_range().get(0);
const index_t subgroups_per_col = nRowsWG_ / subgroup_size;
const index_t subgroups_per_group =
ndItem.get_sub_group().get_group_range().get(0);

const index_t group_size = ndItem.get_local_range(0);

// col_per_workitem <= subgroup_size
const index_t col_per_workitem = nColsWG_ * nRowsWG_ / group_size;

const index_t group_id = ndItem.get_group(0);
const index_t idWFR = group_id % nWG_row_;
const index_t idWFC = group_id / nWG_row_;

const index_t subgroup_id = ndItem.get_sub_group().get_group_id().get(0);
const index_t subgroup_local_id =
ndItem.get_sub_group().get_local_id().get(0);

const index_t id_row0 = idWFR * nRowsWG_ +
subgroup_size * (subgroup_id % subgroups_per_col) +
subgroup_local_id;
const index_t id_col0 =
idWFC * nColsWG_ + col_per_workitem * (subgroup_id / subgroups_per_col);

const index_t dimR = lhs_.get_size_row();
const index_t dimC = lhs_.get_size_col();
const bool id_row_active = id_row0 < dimR;

#ifndef __ADAPTIVECPP__
const value_t rhs_2 = (subgroup_local_id < col_per_workitem &&
id_col0 + subgroup_local_id < dimC)
? rhs_2_.eval(id_col0 + subgroup_local_id)
: 0;
#endif

const value_t scal_rhs_1 = id_row_active ? scalar_ * rhs_1_.eval(id_row0) : 0;

value_t prefetch_lhs_ =
(id_row_active && id_col0 < dimC) ? lhs_.eval(id_row0, id_col0) : 0;

for (index_t sub_id_col = 0; sub_id_col < col_per_workitem; sub_id_col++) {
const value_t rhs_2_sub_id_col =
#ifndef __ADAPTIVECPP__
cl::sycl::group_broadcast(ndItem.get_sub_group(), rhs_2, sub_id_col);
#else
rhs_2_.eval(id_col0 + sub_id_col);
#endif
if (id_row_active && id_col0 + sub_id_col < dimC) {
lhs_.eval(id_row0, id_col0 + sub_id_col) =
prefetch_lhs_ + scal_rhs_1 * rhs_2_sub_id_col;
prefetch_lhs_ = (id_col0 + sub_id_col + 1 < dimC)
? lhs_.eval(id_row0, id_col0 + sub_id_col + 1)
: 0;
}
}

return 0;
}

template <typename lhs_t, typename rhs_1_t, typename rhs_2_t>
template <typename sharedT>
PORTBLAS_INLINE typename Ger<lhs_t, rhs_1_t, rhs_2_t>::value_t
Ger<lhs_t, rhs_1_t, rhs_2_t>::eval(sharedT shrMem,
cl::sycl::nd_item<1> ndItem) {
using index_t = typename Ger<lhs_t, rhs_1_t, rhs_2_t>::index_t;

const index_t group_id = ndItem.get_group(0);
const index_t idWFR = group_id % nWG_row_;
const index_t idWFC = group_id / nWG_row_;
const index_t frs_row = idWFR * nRowsWG_;
const index_t group_local_id = ndItem.get_local_id(0);

// group_size%nRowsWG_ == 0
const index_t id_row0 = group_local_id % nRowsWG_;
const index_t id_row1 = frs_row + id_row0;

index_t frs_col = idWFC * nColsWG_;

const index_t dimR = lhs_.get_size_row();
const index_t dimC = lhs_.get_size_col();

value_t *l_rhs_1 = shrMem.localAcc.get_pointer();
value_t *l_rhs_2 = shrMem.localAcc.get_pointer() + nRowsWG_;

// nRowsWG_ <= group_size
if (group_local_id < nRowsWG_)
l_rhs_1[group_local_id] =
(frs_row + group_local_id < dimR)
? scalar_ * rhs_1_.eval(frs_row + group_local_id)
: 0;

// nColsWG_ <= group_size
if (group_local_id < nColsWG_)
l_rhs_2[group_local_id] = (frs_col + group_local_id < dimC)
? rhs_2_.eval(frs_col + group_local_id)
: 0;

const index_t group_size = ndItem.get_local_range(0);

// nRowsWG_ * nColsWG_ % group_size == 0
const index_t col_per_workitem = nRowsWG_ * nColsWG_ / group_size;
const index_t subgroup_col_id = group_local_id / nRowsWG_;

const index_t id_col0 = subgroup_col_id * col_per_workitem;
const index_t id_col1 = frs_col + id_col0;

value_t prefetch_lhs_ =
(id_row1 < dimR && id_col1 < dimC) ? lhs_.eval(id_row1, id_col1) : 0;

ndItem.barrier(cl::sycl::access::fence_space::local_space);

for (index_t id_col = 0; id_col < col_per_workitem; id_col++) {
const value_t val = l_rhs_1[id_row0] * l_rhs_2[id_col0 + id_col];
if (id_row1 < dimR && id_col1 + id_col < dimC) {
lhs_.eval(id_row1, id_col1 + id_col) = prefetch_lhs_ + val;
prefetch_lhs_ = (id_col1 + id_col + 1 < dimC)
? lhs_.eval(id_row1, id_col1 + id_col + 1)
: 0;
}
}

return 0;
}

template <typename lhs_t, typename rhs_1_t, typename rhs_2_t>
PORTBLAS_INLINE void Ger<lhs_t, rhs_1_t, rhs_2_t>::bind(cl::sycl::handler &h) {
lhs_.bind(h);
rhs_1_.bind(h);
rhs_2_.bind(h);
}
template <typename lhs_t, typename rhs_1_t, typename rhs_2_t>
PORTBLAS_INLINE void
Ger<lhs_t, rhs_1_t, rhs_2_t>::adjust_access_displacement() {
lhs_.adjust_access_displacement();
rhs_1_.adjust_access_displacement();
rhs_2_.adjust_access_displacement();
}

/**** GER BY ROWS M ROWS x N BLOCK USING PROPERLY THE SHARED MEMORY ****/
// template <typename lhs_t, typename rhs_1_t, typename rhs_2_t>
template <bool Single, bool Lower, bool Diag, bool Upper, typename lhs_t,
Expand Down

0 comments on commit 6e18114

Please sign in to comment.